├── .github ├── CODEOWNERS └── workflows │ ├── fossa.yml │ └── test.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── chestnut.go ├── chestnut_test.go ├── codecov.yml ├── encoding ├── compress │ ├── compress.go │ ├── compress_test.go │ └── zstd │ │ ├── zstd.go │ │ └── zstd_test.go ├── json │ ├── decode.go │ ├── decode_test.go │ ├── encode.go │ ├── encode_test.go │ ├── encoders │ │ ├── encoder.go │ │ ├── hash │ │ │ ├── encoder.go │ │ │ ├── encoder_test.go │ │ │ ├── hash.go │ │ │ └── hash_test.go │ │ ├── lookup │ │ │ ├── decoder.go │ │ │ ├── decoder_test.go │ │ │ ├── encoder.go │ │ │ ├── encoder_test.go │ │ │ ├── lookup.go │ │ │ └── lookup_test.go │ │ └── secure │ │ │ ├── decoder.go │ │ │ ├── decoder_test.go │ │ │ ├── encoder.go │ │ │ ├── encoder_test.go │ │ │ ├── options.go │ │ │ ├── secure_test.go │ │ │ ├── tags_all_test.go │ │ │ ├── tags_both_test.go │ │ │ ├── tags_esc_test.go │ │ │ ├── tags_hash_test.go │ │ │ ├── tags_json_test.go │ │ │ ├── tags_none_test.go │ │ │ └── tags_secure_test.go │ ├── packager │ │ ├── encoding.go │ │ ├── package.go │ │ └── package_test.go │ ├── secure_test.go │ └── tags_test.go └── tags │ ├── tags.go │ └── tags_test.go ├── encryptor ├── aes.go ├── aes │ ├── aes.go │ ├── aes_test.go │ ├── cfb.go │ ├── cfb_test.go │ ├── ctr.go │ ├── ctr_test.go │ ├── gcm.go │ ├── gcm_test.go │ └── stream.go ├── aes_test.go ├── chain.go ├── chian_test.go └── crypto │ ├── data.go │ ├── data_test.go │ ├── encryptor.go │ ├── hash.go │ ├── hash_test.go │ ├── header.go │ ├── header_test.go │ ├── key.go │ ├── key_test.go │ ├── mode.go │ ├── rand.go │ ├── rand_test.go │ └── secret.go ├── examples ├── README.md ├── hash │ └── main.go ├── keystore │ └── main.go └── sparse │ └── main.go ├── go.mod ├── go.sum ├── keystore ├── README.md ├── keystore.go ├── keystore_test.go ├── keyutils.go └── keyutils_test.go ├── log ├── level.go ├── level_test.go ├── logger.go ├── logrus.go ├── named.go ├── named_test.go ├── std.go └── zap.go ├── 
options.go ├── storage ├── bolt │ ├── store.go │ └── store_test.go ├── nuts │ ├── store.go │ └── store_test.go ├── options.go ├── storage.go └── store_test │ └── test_suite.go └── value ├── id.go ├── id_test.go ├── keyed.go └── secure.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @jrapoport 2 | -------------------------------------------------------------------------------- /.github/workflows/fossa.yml: -------------------------------------------------------------------------------- 1 | name: Dependency License Scanning 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | defaults: 9 | run: 10 | shell: bash 11 | 12 | jobs: 13 | fossa-scan: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Dependencies 17 | uses: actions/checkout@v2 18 | - name: Fossa 19 | uses: fossas/fossa-action@v1 20 | with: 21 | api-key: ${{secrets.FOSSA_API_KEY}} 22 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | pull_request: 4 | types: [opened, synchronize, reopened] 5 | name: test 6 | jobs: 7 | test: 8 | strategy: 9 | matrix: 10 | go-version: [1.23.x] 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Install Go 14 | uses: actions/setup-go@v2 15 | with: 16 | go-version: ${{ matrix.go-version }} 17 | - name: Checkout code 18 | uses: actions/checkout@v2 19 | - name: Install dependencies 20 | run: make deps 21 | - name: Lint and test 22 | run: make all TEST_FLAGS="-covermode=atomic -coverpkg=./... 
-coverprofile=coverage.txt" 23 | - name: Upload coverage to Codecov 24 | uses: codecov/codecov-action@v1 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.txt 2 | 3 | # Created by https://www.toptal.com/developers/gitignore/api/go,visualstudiocode,jetbrains+all,macos 4 | # Edit at https://www.toptal.com/developers/gitignore?templates=go,visualstudiocode,jetbrains+all,macos 5 | 6 | ### Go ### 7 | # Binaries for programs and plugins 8 | *.exe 9 | *.exe~ 10 | *.dll 11 | *.so 12 | *.dylib 13 | 14 | # Test binary, built with `go test -c` 15 | *.test 16 | 17 | # Output of the go coverage tool, specifically when used with LiteIDE 18 | *.out 19 | 20 | # Dependency directories (remove the comment below to include it) 21 | # vendor/ 22 | 23 | ### Go Patch ### 24 | /vendor/ 25 | /Godeps/ 26 | 27 | ### JetBrains+all ### 28 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 29 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 30 | 31 | # User-specific stuff 32 | .idea/**/workspace.xml 33 | .idea/**/tasks.xml 34 | .idea/**/usage.statistics.xml 35 | .idea/**/dictionaries 36 | .idea/**/shelf 37 | 38 | # Generated files 39 | .idea/**/contentModel.xml 40 | 41 | # Sensitive or high-churn files 42 | .idea/**/dataSources/ 43 | .idea/**/dataSources.ids 44 | .idea/**/dataSources.local.xml 45 | .idea/**/sqlDataSources.xml 46 | .idea/**/dynamic.xml 47 | .idea/**/uiDesigner.xml 48 | .idea/**/dbnavigator.xml 49 | 50 | # Gradle 51 | .idea/**/gradle.xml 52 | .idea/**/libraries 53 | 54 | # Gradle and Maven with auto-import 55 | # When using Gradle or Maven with auto-import, you should exclude module files, 56 | # since they will be recreated, and may cause churn. Uncomment if using 57 | # auto-import. 
58 | # .idea/artifacts 59 | # .idea/compiler.xml 60 | # .idea/jarRepositories.xml 61 | # .idea/modules.xml 62 | # .idea/*.iml 63 | # .idea/modules 64 | # *.iml 65 | # *.ipr 66 | 67 | # CMake 68 | cmake-build-*/ 69 | 70 | # Mongo Explorer plugin 71 | .idea/**/mongoSettings.xml 72 | 73 | # File-based project format 74 | *.iws 75 | 76 | # IntelliJ 77 | out/ 78 | 79 | # mpeltonen/sbt-idea plugin 80 | .idea_modules/ 81 | 82 | # JIRA plugin 83 | atlassian-ide-plugin.xml 84 | 85 | # Cursive Clojure plugin 86 | .idea/replstate.xml 87 | 88 | # Crashlytics plugin (for Android Studio and IntelliJ) 89 | com_crashlytics_export_strings.xml 90 | crashlytics.properties 91 | crashlytics-build.properties 92 | fabric.properties 93 | 94 | # Editor-based Rest Client 95 | .idea/httpRequests 96 | 97 | # Android studio 3.1+ serialized cache file 98 | .idea/caches/build_file_checksums.ser 99 | 100 | ### JetBrains+all Patch ### 101 | # Ignores the whole .idea folder and all .iml files 102 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 103 | 104 | .idea/ 105 | 106 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 107 | 108 | *.iml 109 | modules.xml 110 | .idea/misc.xml 111 | *.ipr 112 | 113 | # Sonarlint plugin 114 | .idea/sonarlint 115 | 116 | ### macOS ### 117 | # General 118 | .DS_Store 119 | .AppleDouble 120 | .LSOverride 121 | 122 | # Icon must end with two \r 123 | Icon 124 | 125 | 126 | # Thumbnails 127 | ._* 128 | 129 | # Files that might appear in the root of a volume 130 | .DocumentRevisions-V100 131 | .fseventsd 132 | .Spotlight-V100 133 | .TemporaryItems 134 | .Trashes 135 | .VolumeIcon.icns 136 | .com.apple.timemachine.donotpresent 137 | 138 | # Directories potentially created on remote AFP share 139 | .AppleDB 140 | .AppleDesktop 141 | Network Trash Folder 142 | Temporary Items 143 | .apdisk 144 | 145 | ### VisualStudioCode ### 146 | .vscode/* 147 | !.vscode/tasks.json 148 | 
!.vscode/launch.json 149 | *.code-workspace 150 | 151 | ### VisualStudioCode Patch ### 152 | # Ignore all local history of files 153 | .history 154 | .ionide 155 | 156 | # End of https://www.toptal.com/developers/gitignore/api/go,visualstudiocode,jetbrains+all,macos 157 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # CONTRIBUTING 2 | 3 | Contributions are always welcome, no matter how large or small. Before contributing, 4 | please check to see if the issue is already being tracked and if there is already a PR. 5 | 6 | ## Setup 7 | 8 | > Install Go 1.23.x (the version used by the CI test workflow) 9 | 10 | Chestnut uses Go Modules (supported since Go 1.11) to build. 11 | The easiest way is to clone Chestnut into a directory outside of GOPATH, 12 | as in the following example: 13 | 14 | ```sh 15 | $ git clone https://github.com/jrapoport/chestnut 16 | $ cd chestnut 17 | $ make deps 18 | ``` 19 | 20 | ## Running examples 21 | 22 | ```sh 23 | $ make hash # run an example by its directory name under examples/ (hash, keystore, sparse) 24 | ``` 25 | 26 | ## Testing 27 | 28 | ```sh 29 | $ make test 30 | ``` 31 | 32 | ## Pull Requests 33 | 34 | Pull requests are welcome! 35 | 36 | 1. Fork the repo and create your branch from `master`. 37 | 2. If you've added code that should be tested, add tests. 38 | 3. If you've changed APIs, update the documentation. 39 | 4. Ensure the test suite passes. 40 | 5. Make sure your code lints (`make lint`). 41 | 42 | ```sh 43 | # currently runs go vet only (lint is disabled in the Makefile's pr target) 44 | $ make pr 45 | ``` 46 | 47 | ## License 48 | 49 | By contributing to Chestnut, you agree that your contributions will be licensed 50 | under its [MIT license](LICENSE).
51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 jrapoport 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GO := go 2 | GO_PATH := $(shell $(GO) env GOPATH) 3 | GO_BIN := $(GO_PATH)/bin 4 | 5 | GO_MOD := $(GO) mod 6 | GO_GET := $(GO) get -u -v 7 | GO_FMT := $(GO) fmt 8 | GO_RUN := $(GO) run 9 | GO_TEST:= $(GO) test -p 1 -v -failfast 10 | GO_LINT := golangci-lint run 11 | # BUG: go vet: structtag field repeats json warning with valid override #40102 12 | # https://github.com/golang/go/issues/40102 13 | GO_VET:= $(GO) vet -v -structtag=false 14 | 15 | #$(GO_LINT): 16 | # $(GO_GET) golang.org/x/lint/golint 17 | # brew install golangci-lint 18 | 19 | deps: 20 | $(GO_MOD) tidy 21 | $(GO_MOD) download 22 | 23 | fmt: 24 | $(GO_FMT) ./... 25 | 26 | lint: 27 | $(GO_LINT) ./... 28 | 29 | vet: 30 | $(GO_VET) ./... 31 | 32 | # disable lint for now 33 | pr: vet 34 | 35 | test: 36 | $(GO_TEST) $(TEST_FLAGS) ./... 37 | 38 | # build any example with make 39 | EXAMPLE_NAME := $(word 1, $(MAKECMDGOALS)) 40 | $(EXAMPLE_NAME): 41 | ifneq ($(filter examples/$(EXAMPLE_NAME),$(wildcard examples/*)),) 42 | $(GO_RUN) examples/$(EXAMPLE_NAME)/main.go 43 | endif 44 | 45 | all: pr test 46 | 47 | .DEFAULT_GOAL := all 48 | 49 | .PHONY: deps fmt lint vet keystore 50 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: auto 6 | threshold: 5% 7 | if_not_found: success 8 | patch: 9 | default: 10 | target: auto 11 | threshold: 5% 12 | if_not_found: success 13 | 14 | ignore: 15 | - "examples/" 16 | -------------------------------------------------------------------------------- /encoding/compress/compress.go: -------------------------------------------------------------------------------- 1 | package compress 2 | 3 | import ( 4 | "bytes" 5 | 
"encoding/hex" 6 | ) 7 | 8 | // Format is the supporter compression algorithm type. 9 | type Format string 10 | 11 | // TODO: support additional compression algorithms besides 12 | // Zstandard from https://github.com/klauspost/compress 13 | const ( 14 | // None no compression. 15 | None Format = "" 16 | 17 | // Custom a custom compression format is being used. 18 | Custom Format = "custom" 19 | 20 | // Zstd Zstandard compression https://facebook.github.io/zstd/. 21 | Zstd Format = "zstd" 22 | ) 23 | 24 | // Valid returns true if the Format is valid. 25 | func (f Format) Valid() bool { 26 | switch f { 27 | case None: 28 | break 29 | case Custom: 30 | break 31 | case Zstd: 32 | break 33 | default: 34 | return false 35 | } 36 | return true 37 | } 38 | 39 | // CompressorFunc is the function the prototype for compression. 40 | type CompressorFunc func(data []byte) (compressed []byte, err error) 41 | 42 | // PassthroughCompressor is a dummy function for development and testing *ONLY*. 43 | /* 44 | * WARNING: DO NOT USE IN PRODUCTION. 45 | * PassthroughCompressor is *NOT* compression and *DOES NOT* compress data. 46 | */ 47 | var PassthroughCompressor CompressorFunc = func(data []byte) ([]byte, error) { 48 | return []byte(hex.EncodeToString(data)), nil 49 | } 50 | 51 | // DecompressorFunc is the function the prototype for decompression. 52 | type DecompressorFunc func(compressed []byte) (data []byte, err error) 53 | 54 | // PassthroughDecompressor is a dummy function for development and testing *ONLY*. 55 | /* 56 | * WARNING: DO NOT USE IN PRODUCTION. 57 | * PassthroughDecompressor is *NOT* decompression and *DOES NOT* decompress data. 
58 | */ 59 | var PassthroughDecompressor DecompressorFunc = func(compressed []byte) ([]byte, error) { 60 | return hex.DecodeString(string(compressed)) 61 | } 62 | 63 | var ( 64 | formatTag = []byte{0xB, 0xA, 0xD, 0xA, 0x5, 0x5, 0x5, 0xB} 65 | formatSep = []byte{0x1e} // US-ASCII Record Separator 66 | ) 67 | 68 | // EncodeFormat adds the compression format to the compressed data. 69 | func EncodeFormat(data []byte, f Format) []byte { 70 | if f == None || len(data) <= 0 { 71 | return data 72 | } 73 | return bytes.Join([][]byte{formatTag, []byte(f), data}, formatSep) 74 | } 75 | 76 | // DecodeFormat removes and returns the compression format from the compressed data. 77 | // If no compression format is found DecodeFormat returns the original the data. 78 | func DecodeFormat(data []byte) ([]byte, Format) { 79 | if len(data) <= 0 { 80 | return data, None 81 | } 82 | if !bytes.HasPrefix(data, formatTag) { 83 | return data, None 84 | } 85 | parts := bytes.SplitN(data, formatSep, 3) 86 | if len(parts) < 3 { 87 | return data, None 88 | } 89 | // double check 90 | if !bytes.Equal(parts[0], formatTag) { 91 | return data, None 92 | } 93 | format := Format(parts[1]) 94 | switch format { 95 | case Zstd: 96 | break 97 | case Custom: 98 | break 99 | default: 100 | return data, None 101 | } 102 | return parts[2], format 103 | } 104 | -------------------------------------------------------------------------------- /encoding/compress/compress_test.go: -------------------------------------------------------------------------------- 1 | package compress 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | var ( 11 | empty = []byte("") 12 | value = []byte("i-am-a-test-in") 13 | valueFmt = []byte{0xb, 0xa, 0xd, 0xa, 0x5, 0x5, 0x5, 0xb, 0x1e, 0x7a, 0x73, 0x74, 0x64, 0x1e, 14 | 0x69, 0x2d, 0x61, 0x6d, 0x2d, 0x61, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x2d, 0x69, 0x6e} 15 | comp = []byte{ 16 | 0x28, 0xb5, 0x2f, 0xfd, 0x4, 0x0, 0x71, 0x0, 0x0, 0x69, 
0x2d, 0x61, 0x6d, 0x2d, 0x61, 0x2d, 17 | 0x74, 0x65, 0x73, 0x74, 0x2d, 0x69, 0x6e, 0x31, 0x49, 0x18, 0x48} 18 | compFmt = []byte{ 19 | 0xb, 0xa, 0xd, 0xa, 0x5, 0x5, 0x5, 0xb, 0x1e, 0x7a, 0x73, 0x74, 0x64, 0x1e, 0x28, 0xb5, 20 | 0x2f, 0xfd, 0x4, 0x0, 0x71, 0x0, 0x0, 0x69, 0x2d, 0x61, 0x6d, 0x2d, 0x61, 0x2d, 0x74, 0x65, 21 | 0x73, 0x74, 0x2d, 0x69, 0x6e, 0x31, 0x49, 0x18, 0x48} 22 | extra = []byte{ 23 | 0x69, 0x2d, 0x61, 0x6d, 0x2d, 0x1e, 0x2d, 0x74, 0x65, 0x73, 0x1e, 0x2d, 0x69, 0x6e} 24 | extraFmt = []byte{ 25 | 0xb, 0xa, 0xd, 0xa, 0x5, 0x5, 0x5, 0xb, 0x1e, 0x7a, 0x73, 0x74, 0x64, 0x1e, 0x69, 0x2d, 26 | 0x61, 0x6d, 0x2d, 0x1e, 0x2d, 0x74, 0x65, 0x73, 0x1e, 0x2d, 0x69, 0x6e} 27 | badFmt1 = []byte{0xb, 0xa, 0xd, 0xa, 0x5, 0x5, 0x5, 0xb, 0x1e, 0xa, 0x73, 0x74, 0x64, 0x1e, 0x69, 28 | 0x2d, 0x61, 0x6d, 0x2d, 0x61, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x2d, 0x69, 0x6e} 29 | badFmt2 = bytes.Join([][]byte{formatTag, empty}, formatSep) 30 | ) 31 | 32 | func TestEncodeFormat(t *testing.T) { 33 | type testCase struct { 34 | in []byte 35 | format Format 36 | out []byte 37 | } 38 | var tests = []testCase{ 39 | {nil, None, nil}, 40 | {nil, Zstd, nil}, 41 | {empty, None, empty}, 42 | {empty, Zstd, empty}, 43 | {value, None, value}, 44 | {value, Zstd, valueFmt}, 45 | {comp, Zstd, compFmt}, 46 | } 47 | for _, test := range tests { 48 | out := EncodeFormat(test.in, test.format) 49 | assert.Equal(t, test.out, out) 50 | } 51 | } 52 | 53 | func TestDecodeFormat(t *testing.T) { 54 | type testCase struct { 55 | in []byte 56 | out []byte 57 | format Format 58 | } 59 | var tests = []testCase{ 60 | {nil, nil, None}, 61 | {empty, empty, None}, 62 | {value, value, None}, 63 | {valueFmt, value, Zstd}, 64 | {compFmt, comp, Zstd}, 65 | {extraFmt, extra, Zstd}, 66 | {badFmt1, badFmt1, None}, 67 | {badFmt2, badFmt2, None}, 68 | } 69 | for _, test := range tests { 70 | out, format := DecodeFormat(test.in) 71 | assert.Equal(t, test.format, format) 72 | assert.Equal(t, test.out, out) 73 | } 74 | } 75 | 76 | 
func TestPassthrough(t *testing.T) { 77 | testString := []byte("test-string") 78 | c, err := PassthroughCompressor(testString) 79 | assert.NoError(t, err) 80 | assert.NotEmpty(t, c) 81 | d, err := PassthroughDecompressor(c) 82 | assert.NoError(t, err) 83 | assert.Equal(t, testString, d) 84 | } 85 | -------------------------------------------------------------------------------- /encoding/compress/zstd/zstd.go: -------------------------------------------------------------------------------- 1 | package zstd 2 | 3 | import ( 4 | "github.com/jrapoport/chestnut/encoding/compress" 5 | "github.com/klauspost/compress/zstd" 6 | ) 7 | 8 | // Zstandard compression 9 | // https://facebook.github.io/zstd/ 10 | 11 | var ( 12 | _ compress.CompressorFunc = Compress // Compress conforms to CompressorFunc 13 | _ compress.DecompressorFunc = Decompress // Decompress conforms to DecompressorFunc 14 | ) 15 | 16 | // encoderZStd is a shared writer that caches compressors. It is built with a nil io.Writer 17 | // because it is only used through EncodeAll. NOTE(review): the constructor error is discarded. 18 | var encoderZStd, _ = zstd.NewWriter(nil) 19 | 20 | // Compress compresses src with the shared Zstandard encoder. The destination buffer 21 | // is preallocated with len(src) capacity; the returned error is always nil. 22 | func Compress(src []byte) ([]byte, error) { 23 | return encoderZStd.EncodeAll(src, make([]byte, 0, len(src))), nil 24 | } 25 | 26 | // decoderZStd is a shared reader that caches decompressors. It is built with a nil io.Reader 27 | // because it is only used through DecodeAll. NOTE(review): the constructor error is discarded. 28 | var decoderZStd, _ = zstd.NewReader(nil) 29 | 30 | // Decompress decompresses src with the shared Zstandard decoder. No destination 31 | // buffer is supplied, so DecodeAll allocates the returned slice.
32 | func Decompress(src []byte) ([]byte, error) { 33 | return decoderZStd.DecodeAll(src, nil) 34 | } 35 | -------------------------------------------------------------------------------- /encoding/compress/zstd/zstd_test.go: -------------------------------------------------------------------------------- 1 | package zstd 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | var ( 10 | testNil = []byte(nil) 11 | testEmpty = []byte("") 12 | testSpace = []byte(" ") 13 | testValue = []byte("i-am-an-uncompressed-string") 14 | compSpace = []byte{0x28, 0xb5, 0x2f, 0xfd, 0x4, 0x0, 0x9, 0x0, 0x0, 0x20, 0x8d, 0x63, 0x68, 0xb6} 15 | compValue = []byte{0x28, 0xb5, 0x2f, 0xfd, 0x4, 0x0, 0xd9, 0x0, 0x0, 0x69, 0x2d, 0x61, 0x6d, 0x2d, 16 | 0x61, 0x6e, 0x2d, 0x75, 0x6e, 0x63, 0x6f, 0x6d, 0x70, 0x72, 0x65, 0x73, 0x73, 0x65, 0x64, 0x2d, 17 | 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0xab, 0x52, 0xd3, 0x9d} 18 | ) 19 | 20 | func TestCompressZStd(t *testing.T) { 21 | tests := []struct { 22 | src []byte 23 | out []byte 24 | err assert.ErrorAssertionFunc 25 | }{ 26 | {testNil, testEmpty, assert.NoError}, 27 | {testEmpty, testEmpty, assert.NoError}, 28 | {testSpace, compSpace, assert.NoError}, 29 | {testValue, compValue, assert.NoError}, 30 | } 31 | for _, test := range tests { 32 | bytes, err := Compress(test.src) 33 | test.err(t, err) 34 | assert.Equal(t, test.out, bytes) 35 | } 36 | } 37 | 38 | func TestDecompressZStd(t *testing.T) { 39 | tests := []struct { 40 | src []byte 41 | out []byte 42 | err assert.ErrorAssertionFunc 43 | }{ 44 | {testNil, testNil, assert.NoError}, 45 | {testEmpty, testNil, assert.NoError}, 46 | {testValue, testNil, assert.Error}, 47 | {compSpace, testSpace, assert.NoError}, 48 | {compValue, testValue, assert.NoError}, 49 | } 50 | for _, test := range tests { 51 | bytes, err := Decompress(test.src) 52 | test.err(t, err) 53 | assert.Equal(t, test.out, bytes) 54 | } 55 | } 56 | 57 | func TestZStd(t *testing.T) { 58 | // compress the
src 59 | buf, err := Compress(testValue) 60 | assert.NoError(t, err) 61 | assert.NotNil(t, buf) 62 | // decompress the result 63 | src, err := Decompress(buf) 64 | assert.NoError(t, err) 65 | assert.Equal(t, testValue, src) 66 | } 67 | -------------------------------------------------------------------------------- /encoding/json/decode.go: -------------------------------------------------------------------------------- 1 | package json 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/jrapoport/chestnut/encoding/json/encoders" 7 | "github.com/jrapoport/chestnut/encoding/json/encoders/secure" 8 | ) 9 | 10 | // SecureUnmarshal decrypts & parses the JSON-encoded data returned by SecureMarshal and stores 11 | // the result in the value pointed to by v. If v is nil or not a pointer, SecureUnmarshal returns an 12 | // error. SecureUnmarshal adds support for sparse decryption via JSON struct tag options. If 13 | // SecureMarshal was called on a value with at least one 'secure' option set on a struct field JSON 14 | // tag, only those fields were encrypted and the remaining encoded data was stored as sparse plaintext. 15 | // If SecureUnmarshal is called on a sparse encoding with the sparse option set, it skips the decryption 16 | // step and returns only the plaintext decoding of v, with encrypted fields replaced by empty values. 17 | // For more detail, SEE: https://github.com/jrapoport/chestnut/blob/master/README.md 18 | func SecureUnmarshal(data []byte, v interface{}, decryptFunc secure.DecryptionFunction, opt ...secure.Option) error { 19 | if v == nil { 20 | return errors.New("nil value") 21 | } 22 | enc := encoders.NewEncoder() 23 | ext := secure.NewSecureDecoderExtension(encoders.DefaultID, decryptFunc, opt...)
24 | enc.RegisterExtension(ext) 25 | defer ext.Close() // deferred before Unseal/Open so the extension is closed on every return path 26 | unsealed, err := ext.Unseal(data) 27 | if err != nil { 28 | return err 29 | } 30 | if err = ext.Open(); err != nil { 31 | return err 32 | } 33 | return enc.Unmarshal(unsealed, v) 34 | } 35 | -------------------------------------------------------------------------------- /encoding/json/decode_test.go: -------------------------------------------------------------------------------- 1 | package json 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestSecureUnmarshal(t *testing.T) { 10 | // uncompressed secure 11 | secureObj := &Family{} 12 | err := SecureUnmarshal(familyEnc, secureObj, decrypt) 13 | assertDecoding(t, familyDec, secureObj, err) 14 | // compressed secure 15 | secureObj = &Family{} 16 | err = SecureUnmarshal(familyComp, secureObj, decrypt, compOpt) 17 | assertDecoding(t, familyDec, secureObj, err) 18 | // uncompressed sparse 19 | sparseObj := &Family{} 20 | err = SecureUnmarshal(familyEnc, sparseObj, decrypt, sparseOpt) 21 | assertDecoding(t, familySpr, sparseObj, err) 22 | // compressed sparse 23 | sparseObj = &Family{} 24 | err = SecureUnmarshal(familyComp, sparseObj, decrypt, compOpt, sparseOpt) 25 | assertDecoding(t, familySpr, sparseObj, err) 26 | } 27 | 28 | func TestSecureUnmarshal_Error(t *testing.T) { 29 | secureObj := &Family{} 30 | assert.Panics(t, func() { 31 | _ = SecureUnmarshal(familyEnc, secureObj, nil) 32 | }) 33 | err := SecureUnmarshal(familyEnc, nil, decrypt) 34 | assert.Error(t, err) 35 | err = SecureUnmarshal(nil, secureObj, decrypt) 36 | assert.Error(t, err) 37 | err = SecureUnmarshal([]byte("bad encoding"), secureObj, decrypt) 38 | assert.Error(t, err) 39 | var p chan bool 40 | err = SecureUnmarshal(familyEnc, &p, decrypt) 41 | assert.Error(t, err) 42 | } 43 | -------------------------------------------------------------------------------- /encoding/json/encode.go:
-------------------------------------------------------------------------------- 1 | package json 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/jrapoport/chestnut/encoding/json/encoders" 7 | "github.com/jrapoport/chestnut/encoding/json/encoders/secure" 8 | ) 9 | 10 | // SecureMarshal returns an encrypted JSON encoding of v. It adds support for sparse encryption and 11 | // hashing via JSON struct tag options. If at least one struct field JSON tag of v sets the 'secure' 12 | // option, only those fields will be encrypted and the remaining encoded data is stored as sparse 13 | // plaintext. If no secure tag option is found, all the encoded data will be encrypted. 14 | // SecureMarshal returns an error if v is nil, or if opening the extension, marshaling, or sealing fails. 15 | // For more detail, SEE: https://github.com/jrapoport/chestnut/blob/master/README.md 16 | func SecureMarshal(v interface{}, encryptFunc secure.EncryptionFunction, opt ...secure.Option) ([]byte, error) { 17 | if v == nil { 18 | return nil, errors.New("nil value") 19 | } 20 | enc := encoders.NewEncoder() 21 | ext := secure.NewSecureEncoderExtension(encoders.DefaultID, encryptFunc, opt...)
22 | enc.RegisterExtension(ext) 23 | if err := ext.Open(); err != nil { 24 | ext.Close() // mirror SecureUnmarshal, which closes the extension on every return path 25 | return nil, err 26 | } 27 | buf, err := enc.Marshal(v) 28 | // Close before Seal (the existing order) and also when Marshal fails, so the extension is never left open. 29 | ext.Close() 30 | if err != nil { 31 | return nil, err 32 | } 33 | return ext.Seal(buf) 34 | } 35 | -------------------------------------------------------------------------------- /encoding/json/encode_test.go: -------------------------------------------------------------------------------- 1 | package json 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestSecureMarshal(t *testing.T) { 10 | // uncompressed 11 | bytes, err := SecureMarshal(family, encrypt) 12 | assert.NoError(t, err) 13 | assert.Equal(t, familyEnc, bytes) 14 | // compressed 15 | bytes, err = SecureMarshal(family, encrypt, compOpt) 16 | assert.NoError(t, err) 17 | assert.Equal(t, familyComp, bytes) 18 | } 19 | 20 | func TestSecureMarshal_Error(t *testing.T) { 21 | assert.Panics(t, func() { 22 | _, _ = SecureMarshal(family, nil) 23 | }) 24 | bytes, err := SecureMarshal(nil, nil) 25 | assert.Error(t, err) 26 | assert.Nil(t, bytes) 27 | bytes, err = SecureMarshal(nil, encrypt) 28 | assert.Error(t, err) 29 | assert.Nil(t, bytes) 30 | var p chan bool 31 | bytes, err = SecureMarshal(p, encrypt) 32 | assert.Error(t, err) 33 | assert.Nil(t, bytes) 34 | } 35 | -------------------------------------------------------------------------------- /encoding/json/encoders/encoder.go: -------------------------------------------------------------------------------- 1 | package encoders 2 | 3 | import ( 4 | "encoding/hex" 5 | "log" 6 | 7 | "github.com/jrapoport/chestnut/encryptor/crypto" 8 | "github.com/json-iterator/go" 9 | ) 10 | 11 | // NewEncoder returns a new encoder with a _clean_ configuration and _no_ registered 12 | // extensions. 
Extensions registered to this encoder will not impact the global encoder. 13 | // Config options match jsoniter ConfigCompatibleWithStandardLibrary.
14 | func NewEncoder() jsoniter.API { 15 | return jsoniter.Config{ 16 | EscapeHTML: true, 17 | SortMapKeys: true, 18 | ValidateJsonRawMessage: true, 19 | }.Froze() 20 | } 21 | 22 | // DefaultID can be used with a SecureEncoderExtension instead of a set id. When used, 23 | // it will be replaced with a randomly generated 8 character hex id for the encoder. 24 | // #954535 is hex color code for Chestnut. https://en.wikipedia.org/wiki/Chestnut_(color) 25 | var DefaultID = "0x954535" 26 | 27 | // InvalidID is an invalid encoder id. 28 | const InvalidID = "" 29 | 30 | // NewEncoderID returns a new random encoder id: 4 random bytes hex-encoded as an 8 character 31 | // string. The id is not guaranteed to be unique and does not have to be; it is only used 32 | // internally by the encoder. If random bytes cannot be generated, NewEncoderID panics 33 | // (via log.Panic) rather than exiting the process, so deferred cleanup in callers still runs. 34 | func NewEncoderID() string { 35 | id, err := crypto.MakeRand(4) 36 | if err != nil { 37 | log.Panic(err) 38 | } 39 | return hex.EncodeToString(id) 40 | } 41 | -------------------------------------------------------------------------------- /encoding/json/encoders/hash/encoder.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | import ( 4 | "strings" 5 | "unsafe" 6 | 7 | "github.com/jrapoport/chestnut/log" 8 | jsoniter "github.com/json-iterator/go" 9 | ) 10 | 11 | // Encoder is a string ValEncoder that hashes string data 12 | // with HashingFunction before encoding it to stream. 13 | type Encoder struct { 14 | hashName string 15 | hashFunc HashingFunction 16 | encoder jsoniter.ValEncoder 17 | log log.Logger 18 | } 19 | 20 | // NewHashEncoder returns a string encoder that will encode string values using the supplied hashFn. 21 | // The hash encoder will run before the other encoders, ensuring that struct fields are hashed first.
22 | func NewHashEncoder(name string, hashFn HashingFunction, encoder jsoniter.ValEncoder) jsoniter.ValEncoder { 23 | if name == "" { 24 | name = "hash" 25 | } 26 | return &Encoder{hashName: name, hashFunc: hashFn, encoder: encoder, log: log.Log} 27 | } 28 | 29 | // SetLogger changes the logger for the encoder. 30 | func (e *Encoder) SetLogger(l log.Logger) { 31 | e.log = l 32 | } 33 | 34 | // Encode writes "<hashName>:" + hashFunc(value) for the string at ptr; empty or already-prefixed values (or a nil hashFunc) are written unchanged. 35 | func (e *Encoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { 36 | e.log.Debug("encoding hash") 37 | if e.IsEmpty(ptr) || e.hashFunc == nil { 38 | e.log.Warn("cannot encode empty ptr or nil hash function") 39 | e.encoder.Encode(ptr, stream) 40 | return 41 | } 42 | // ptr may point at a string header, which has no cap word, so never reinterpret it as a []byte header. 43 | value := *((*string)(ptr)) 44 | prefix := e.hashName + ":" 45 | if strings.HasPrefix(value, prefix) { 46 | e.log.Warn("do not re-hash field") 47 | e.encoder.Encode(ptr, stream) 48 | return 49 | } 50 | hash, err := e.hashFunc([]byte(value)) // the pre-image is deliberately not logged 51 | if err != nil { 52 | e.log.Error(err) 53 | stream.Error = err // fail the Marshal instead of writing the unhashed value 54 | return 55 | } 56 | hash = prefix + hash 57 | e.log.Debugf("encoding hash: %s", hash) 58 | e.encoder.Encode(unsafe.Pointer(&hash), stream) 59 | } 60 | 61 | // IsEmpty returns true if ptr is empty, otherwise false.
62 | func (e *Encoder) IsEmpty(ptr unsafe.Pointer) bool { // IsEmpty defers to the wrapped value encoder. 63 | return e.encoder.IsEmpty(ptr) 64 | } 65 | -------------------------------------------------------------------------------- /encoding/json/encoders/hash/encoder_test.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | import ( 4 | "bytes" 5 | "reflect" 6 | "testing" 7 | "unsafe" 8 | 9 | "github.com/jrapoport/chestnut/encoding/tags" 10 | jsoniter "github.com/json-iterator/go" 11 | "github.com/modern-go/reflect2" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | // TestHashEncoder checks that nil or empty input encodes as an empty JSON string and that other input encodes as a "sha256:"-prefixed hex digest. 15 | func TestHashEncoder(t *testing.T) { 16 | var tests = []struct { 17 | in []byte 18 | out string 19 | assertEmpty assert.BoolAssertionFunc 20 | }{ 21 | { 22 | nil, 23 | `""`, 24 | assert.True, 25 | }, 26 | { 27 | []byte(""), 28 | `""`, 29 | assert.True, 30 | }, 31 | { 32 | []byte("abcdefghijklmnopqrstuvwxyz"), 33 | `"sha256:71c480df93d6ae2f1efad1447c66c9525e316218cf51fc8d9ed832f2daf18b73"`, 34 | assert.False, 35 | }, 36 | { 37 | []byte("abcdefghijklmnopqrstuvwxyz1234567890"), 38 | `"sha256:77d721c817f9d216c1fb783bcad9cdc20aaa2427402683f1f75dd6dfbe657470"`, 39 | assert.False, 40 | }, 41 | } 42 | for _, test := range tests { 43 | var buf bytes.Buffer 44 | conf := jsoniter.ConfigDefault 45 | valEncoder := conf.EncoderOf(reflect2.DefaultTypeOfKind(reflect.String)) 46 | stream := jsoniter.NewStream(conf, &buf, 100) 47 | stream.Reset(&buf) 48 | he := NewHashEncoder(tags.HashSHA256, EncodeToSHA256, valEncoder) 49 | he.Encode(unsafe.Pointer(&test.in), stream) 50 | assert.Equal(t, test.out, string(stream.Buffer())) 51 | test.assertEmpty(t, he.IsEmpty(unsafe.Pointer(&test.in))) 52 | } 53 | } 54 | // TestHashEncoder_NoRehash checks that input already carrying the "sha256:" prefix is written unchanged instead of being hashed again. 55 | func TestHashEncoder_NoRehash(t *testing.T) { 56 | var testIn = "sha256:71c480df93d6ae2f1efad1447c66c9525e316218cf51fc8d9ed832f2daf18b73" 57 | const testOut = `"sha256:71c480df93d6ae2f1efad1447c66c9525e316218cf51fc8d9ed832f2daf18b73"` 58 | var buf bytes.Buffer 59 | conf := jsoniter.ConfigDefault
60 | valEncoder := conf.EncoderOf(reflect2.DefaultTypeOfKind(reflect.String)) 61 | stream := jsoniter.NewStream(conf, &buf, 100) 62 | stream.Reset(&buf) 63 | he := NewHashEncoder(tags.HashSHA256, EncodeToSHA256, valEncoder) 64 | he.Encode(unsafe.Pointer(&testIn), stream) 65 | assert.Equal(t, testOut, string(stream.Buffer())) 66 | } 67 | -------------------------------------------------------------------------------- /encoding/json/encoders/hash/hash.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | import ( 4 | "encoding/hex" 5 | 6 | "github.com/jrapoport/chestnut/encoding/tags" 7 | "github.com/jrapoport/chestnut/encryptor/crypto" 8 | ) 9 | 10 | // HashingFunction defines the prototype for the hash callback. Defaults to EncodeToSHA256. 11 | type HashingFunction func(buf []byte) (hash string, err error) 12 | 13 | // FunctionForName returns the hashing function registered for name, or nil (passthrough) if name is not recognized. 14 | func FunctionForName(name tags.Hash) HashingFunction { 15 | switch name { 16 | case tags.HashSHA256: 17 | return EncodeToSHA256 18 | default: 19 | return nil 20 | } 21 | } 22 | 23 | // EncodeToSHA256 returns the SHA-256 digest of buf as a lowercase hex string.
24 | var EncodeToSHA256 = func(buf []byte) (string, error) { 25 | hash, err := crypto.HashSHA256(buf) 26 | if err != nil { 27 | return "", err 28 | } 29 | return hex.EncodeToString(hash), nil // hex digest only; no "sha256:" prefix is added here 30 | } 31 | -------------------------------------------------------------------------------- /encoding/json/encoders/hash/hash_test.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jrapoport/chestnut/encoding/tags" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | // TestHashFunctionForName checks that an unknown name maps to nil and that tags.HashSHA256 maps to a function matching EncodeToSHA256. 10 | func TestHashFunctionForName(t *testing.T) { 11 | fn := FunctionForName(tags.HashNone) 12 | assert.Nil(t, fn) 13 | fn = FunctionForName(tags.HashSHA256) 14 | assert.NotNil(t, fn) 15 | h1, err := EncodeToSHA256([]byte("test")) 16 | assert.NoError(t, err) 17 | h2, err := fn([]byte("test")) 18 | assert.NoError(t, err) 19 | assert.Equal(t, h1, h2) 20 | } 21 | // TestEncodeToSHA256 checks EncodeToSHA256 against known digests; nil and empty input share the digest of zero bytes. 22 | func TestEncodeToSHA256(t *testing.T) { 23 | var tests = []struct { 24 | in []byte 25 | out string 26 | }{ 27 | { 28 | nil, 29 | "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", 30 | }, 31 | { 32 | []byte(""), 33 | "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", 34 | }, 35 | { 36 | []byte("abcdefghijklmnopqrstuvwxyz"), 37 | "71c480df93d6ae2f1efad1447c66c9525e316218cf51fc8d9ed832f2daf18b73", 38 | }, 39 | {[]byte("abcdefghijklmnopqrstuvwxyz1234567890"), 40 | "77d721c817f9d216c1fb783bcad9cdc20aaa2427402683f1f75dd6dfbe657470", 41 | }, 42 | } 43 | for _, test := range tests { 44 | h, err := EncodeToSHA256(test.in) 45 | assert.NoError(t, err) 46 | assert.Equal(t, test.out, h, test.in) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /encoding/json/encoders/lookup/decoder.go: -------------------------------------------------------------------------------- 1 | package lookup 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | "unsafe" 8 | 9 |
"github.com/jrapoport/chestnut/log" 10 | jsoniter "github.com/json-iterator/go" 11 | "github.com/modern-go/reflect2" 12 | ) 13 | 14 | // Decoder is a ValDecoder that reads lookup table key strings from the iterator. When 15 | // a key is found in the lookup table it decodes the lookup table data in place of the key. 16 | type Decoder struct { 17 | token string 18 | stream *jsoniter.Stream 19 | valType reflect2.Type 20 | decoder jsoniter.ValDecoder 21 | log log.Logger 22 | } 23 | 24 | // NewLookupDecoder returns a decoder that reads a lookup table. It will check the 25 | // iterated string values to see if they match our lookup token. If there is a match, 26 | // it will replace it with a decoded value from the lookup table or an empty value. 27 | func NewLookupDecoder(ctx *Context, typ reflect2.Type, decoder jsoniter.ValDecoder) jsoniter.ValDecoder { 28 | logger := log.Log 29 | if decoder == nil { 30 | logger.Panic(errors.New("value decoder required")) 31 | return nil 32 | } 33 | if typ == nil { 34 | logger.Panic(errors.New("decoder type required")) 35 | return nil 36 | } 37 | if ctx == nil { 38 | logger.Panic(errors.New("lookup context required")) 39 | return nil 40 | } 41 | if ctx.Token == InvalidToken { 42 | logger.Panic(errors.New("lookup token required")) 43 | return nil 44 | } 45 | if ctx.Stream == nil { 46 | logger.Panic(errors.New("lookup stream required")) 47 | return nil 48 | } 49 | return &Decoder{ 50 | token: ctx.Token, 51 | stream: ctx.Stream, 52 | valType: typ, 53 | decoder: decoder, 54 | log: logger, 55 | } 56 | } 57 | 58 | // SetLogger changes the logger for the decoder. 59 | func (d *Decoder) SetLogger(l log.Logger) { 60 | d.log = l 61 | } 62 | 63 | // Decode sets ptr to the next value of iter, replacing lookup keys with their lookup table entries. 64 | func (d *Decoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { 65 | d.log.Debugf("decoding type %s", d.valType) 66 | // if ptr holds a nil interface value, skip it.
67 | if d.isEmptyInterface(ptr) { 68 | d.log.Warn("cannot decode to empty interface") 69 | iter.Skip() 70 | return 71 | } 72 | // we really shouldn't be here with an invalid token, if for 73 | // some reason we are, call the default decoder and bail. 74 | if d.token == InvalidToken { 75 | d.log.Warn("invalid token") 76 | d.decoder.Decode(ptr, iter) 77 | return 78 | } 79 | // get the 'from' type 80 | fromType := iter.WhatIsNext() 81 | // secure tokens will be type string. if this is not 82 | // a string, call the default decoder and bail. 83 | if fromType != jsoniter.StringValue { 84 | d.log.Debug("skipping non-string value") 85 | d.decoder.Decode(ptr, iter) 86 | return 87 | } 88 | // read the string & format a key 89 | val := iter.ReadRawString() 90 | key := Key(val) 91 | // check to see if it is one of ours 92 | if !key.IsTokenKey(d.token) { 93 | // we use an Iterator to avoid setting the ptr directly since it might be a string 94 | // or an interface or who knows what. this way the codecs handle it for us. 95 | subIter := iter.Pool().BorrowIterator([]byte(`"` + key + `"`)) 96 | defer iter.Pool().ReturnIterator(subIter) 97 | d.log.Debugf("decode value: %s", val) 98 | d.decoder.Decode(ptr, subIter) 99 | return 100 | } 101 | // we have a valid lookup key. look it up in our table 102 | ent, err := d.lookupKey(key) 103 | // did we find something in the lookup table? 104 | if err != nil || ent == nil { 105 | d.log.Debugf("lookup entry not found: %s", key) 106 | // this is expected when sparse decoding a struct.
107 | if d.valType.Kind() == reflect.Interface { 108 | d.log.Debugf("decode empty %s for interface", key.Kind()) 109 | // we are decoding into an interface, so store an explicitly typed empty value for the key's kind 110 | *(*interface{})(ptr) = emptyValueOfKind(key.Kind()) 111 | } 112 | return 113 | } 114 | // reset the shared lookup stream so its buffer holds only this entry 115 | d.stream.Reset(nil) 116 | ent.WriteTo(d.stream) 117 | subIter := iter.Pool().BorrowIterator(d.stream.Buffer()) 118 | defer iter.Pool().ReturnIterator(subIter) 119 | // decode the entry's JSON into ptr 120 | d.decoder.Decode(ptr, subIter) 121 | d.log.Debugf("decoded lookup entry for %s: %s", key, string(d.stream.Buffer())) 122 | } 123 | // lookupKey returns the entry for key from the lookup table attached to d.stream, or an error if the stream, the table, or the key is missing. 124 | func (d *Decoder) lookupKey(key Key) (jsoniter.Any, error) { 125 | d.log.Debugf("lookup key: %s", key) 126 | logErr := func(err error) error { 127 | d.log.Error(err) 128 | return err 129 | } 130 | if d.stream == nil { 131 | return nil, logErr(errors.New("lookup stream not found")) 132 | } 133 | table, ok := d.stream.Attachment.(jsoniter.Any) 134 | if !ok || table == nil { 135 | return nil, logErr(errors.New("lookup table not found")) 136 | } 137 | val := table.Get(key.String()) 138 | if val.ValueType() == jsoniter.InvalidValue { 139 | err := fmt.Errorf("lookup key not found: %s", key) 140 | d.log.Debug(err) // this is an expected error 141 | return nil, err 142 | } 143 | d.log.Debugf("lookup found %v for key %s: %s", val.ValueType(), key, val.ToString()) 144 | return val, nil 145 | } 146 | // isEmptyInterface reports whether ptr holds a nil value of an interface type represented by *reflect2.UnsafeIFaceType; all other types report false. 147 | func (d *Decoder) isEmptyInterface(ptr unsafe.Pointer) bool { 148 | if d.valType.Kind() != reflect.Interface { 149 | return false 150 | } 151 | i, ok := d.valType.(*reflect2.UnsafeIFaceType) 152 | if !ok { 153 | return false 154 | } 155 | return reflect2.IsNil(i.UnsafeIndirect(ptr)) 156 | } 157 | // emptyValueOfKind returns an explicitly typed empty value for kind: "" for strings, false for bools, float64(0) for numeric kinds (the type JSON numbers decode to in an interface{}), and nil otherwise. 158 | func emptyValueOfKind(kind reflect.Kind) interface{} { 159 | var v interface{} 160 | switch kind { 161 | case reflect.String: 162 | v = "" 163 | case reflect.Bool: 164 | v = false 165 | case reflect.Uint8, reflect.Int8, 166 | reflect.Uint16, reflect.Int16, 167 |
reflect.Uint32, reflect.Int32, 168 | reflect.Uint64, reflect.Int64, 169 | reflect.Uint, reflect.Int, 170 | reflect.Float32, reflect.Float64, 171 | reflect.Uintptr: 172 | v = 0.0 173 | default: // other kinds leave v nil 174 | } 175 | return v 176 | } 177 | -------------------------------------------------------------------------------- /encoding/json/encoders/lookup/decoder_test.go: -------------------------------------------------------------------------------- 1 | package lookup 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | "unsafe" 8 | 9 | "github.com/jrapoport/chestnut/encoding/json/encoders" 10 | jsoniter "github.com/json-iterator/go" 11 | "github.com/modern-go/reflect2" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | // TestLookupDecoder_Decode builds a lookup table from the test encodings and checks that decoding each lookup key writes the matching entry to the lookup stream. 15 | func TestLookupDecoder_Decode(t *testing.T) { 16 | type testObject struct { 17 | Value string 18 | } 19 | tests := []struct { 20 | value interface{} 21 | key string 22 | encoding string 23 | }{ 24 | { 25 | "a-string-value", 26 | `"tst0xtesting%d_24"`, 27 | `"a-string-value"`, 28 | }, 29 | { 30 | []string{"a-string-slice"}, 31 | `"tst0xtesting%d_23"`, 32 | `["a-string-slice"]`, 33 | }, 34 | { 35 | 99.9, 36 | `"tst0xtesting%d_14"`, 37 | `99.9`, 38 | }, 39 | { 40 | testObject{"a-struct-value"}, 41 | `"tst0xtesting%d_25"`, 42 | `{"Value":"a-struct-value"}`, 43 | }, 44 | { 45 | &testObject{"a-struct-ptr-value"}, 46 | `"tst0xtesting%d_22"`, 47 | `{"Value":"a-struct-ptr-value"}`, 48 | }, 49 | { 50 | &testObject{"a\nstruct\tptr\"value"}, 51 | `"tst0xtesting%d_22"`, 52 | `{"Value":"a\nstruct\tptr\"value"}`, 53 | }, 54 | } 55 | lookUpTable := "{" 56 | for i, test := range tests { 57 | key := fmt.Sprintf(test.key, i) 58 | if i > 0 { 59 | lookUpTable += "," 60 | } 61 | entry := fmt.Sprintf("%s:%s", key, test.encoding) 62 | lookUpTable += entry 63 | } 64 | lookUpTable += "}" 65 | ctx := &Context{ 66 | NewLookupToken(testPrefix, testID), 67 | newTestStream(t), 68 | } 69 | enc := encoders.NewEncoder() 70 | ctx.Stream.Attachment = enc.Get([]byte(lookUpTable)) 71 |
for i, test := range tests { 72 | typ := reflect2.TypeOf(&test.value) 73 | decoder := enc.DecoderOf(typ) 74 | le := NewLookupDecoder(ctx, typ, decoder) 75 | key := fmt.Sprintf(test.key, i) 76 | iter := enc.BorrowIterator([]byte(key)) 77 | ptr := reflect.New(reflect.TypeOf(test.value)).Interface() 78 | le.Decode(unsafe.Pointer(&ptr), iter) 79 | enc.ReturnIterator(iter) 80 | assert.Equal(t, test.encoding, string(ctx.Stream.Buffer())) 81 | parsed := jsoniter.Get(ctx.Stream.Buffer()) 82 | assert.NotEqual(t, jsoniter.InvalidValue, parsed.ValueType()) 83 | } 84 | } 85 | // TestLookupDecoder_NewLookupDecoder checks that every incomplete combination of context, type, and value decoder panics. 86 | func TestLookupDecoder_NewLookupDecoder(t *testing.T) { 87 | encoder := encoders.NewEncoder() 88 | str := "a-string" 89 | typ := reflect2.TypeOf(&str) 90 | enc := encoder.DecoderOf(typ) 91 | bad1 := &Context{} 92 | bad2 := &Context{InvalidToken, newTestStream(t)} 93 | bad3 := &Context{"a-string-value", nil} 94 | good := &Context{"a-string-value", newTestStream(t)} 95 | for _, ctx := range []*Context{nil, bad1, bad2, bad3, good} { 96 | for _, tp := range []reflect2.Type{nil, typ} { 97 | for _, ve := range []jsoniter.ValDecoder{nil, enc} { 98 | if ctx == good && tp == typ && ve == enc { 99 | continue 100 | } 101 | assert.Panics(t, func() { 102 | _ = NewLookupDecoder(ctx, tp, ve) 103 | }, ctx, tp, ve) 104 | } 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /encoding/json/encoders/lookup/encoder.go: -------------------------------------------------------------------------------- 1 | package lookup 2 | 3 | import ( 4 | "errors" 5 | "unsafe" 6 | 7 | "github.com/jrapoport/chestnut/encoding/json/encoders" 8 | "github.com/jrapoport/chestnut/log" 9 | jsoniter "github.com/json-iterator/go" 10 | "github.com/modern-go/reflect2" 11 | ) 12 | // cleanEncoder supplies unwrapped value encoders for recursive Encode calls on the lookup stream. 13 | var cleanEncoder = encoders.NewEncoder() 14 | 15 | // Encoder is a ValEncoder that encodes the data to a lookup table and encodes an 16 | // entry key for the data into the stream that can be read later by the
decoder. 17 | type Encoder struct { 18 | token string // prefix for generated lookup keys 19 | stream *jsoniter.Stream // lookup table stream; its Attachment holds the next entry index 20 | valType reflect2.Type // type of the values being encoded 21 | encoder jsoniter.ValEncoder // wrapped encoder used for lookup table entries and IsEmpty 22 | log log.Logger 23 | } 24 | 25 | // NewLookupEncoder returns an encoder that builds a lookup table. It will strip out tagged 26 | // struct fields and collect the encoded values in the provided stream as a map. As it strips 27 | // out values, it replaces them with a token key for the lookup table. Later we can use this 28 | // key as a lookup to reconstruct the encoded struct as it is decoded. The hash encoder must 29 | // be run before this encoder, so the struct fields are hashed before they are stripped. 30 | func NewLookupEncoder(ctx *Context, typ reflect2.Type, encoder jsoniter.ValEncoder) jsoniter.ValEncoder { 31 | logger := log.Log 32 | if encoder == nil { 33 | logger.Panic(errors.New("value encoder required")) 34 | return nil 35 | } 36 | if typ == nil { 37 | logger.Panic(errors.New("encoder type required")) 38 | return nil 39 | } 40 | if ctx == nil { 41 | logger.Panic(errors.New("lookup context required")) 42 | return nil 43 | } 44 | if ctx.Token == InvalidToken { 45 | logger.Panic(errors.New("lookup token required")) 46 | return nil 47 | } 48 | if ctx.Stream == nil { 49 | logger.Panic(errors.New("lookup stream required")) 50 | return nil 51 | } 52 | return &Encoder{ 53 | token: ctx.Token, 54 | stream: ctx.Stream, 55 | valType: typ, 56 | encoder: encoder, 57 | log: logger, 58 | } 59 | } 60 | 61 | // SetLogger changes the logger for the encoder. 62 | func (e *Encoder) SetLogger(l log.Logger) { 63 | e.log = l 64 | } 65 | 66 | // Encode writes the value of ptr to the lookup table and writes its lookup key to stream. Calls that target the lookup stream itself fall back to a clean encoder. 67 | func (e *Encoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { 68 | e.log.Debugf("encoding type %s", e.valType) 69 | // FIXME: I've looked around for a way to avoid this, or unwrap the encoder, but it's 70 | // not clear what the best way to do that is or if it's possible with jsoniter as-is. 71 | // NOTE: This is *SUPER important*.
This is so when UpdateStructDescriptor is called 72 | // recursively for nested structs the ValEncoder we use is the ORIGINAL ValEncoder, and 73 | // NOT a copy of our modified Encode (that strips out values). If we don't do this, 74 | // tagged fields will also be stripped out of our stream and not just the encoded stream. 75 | // We know this is happening because when it does: encoding stream == lookup stream. 76 | if stream == e.stream { 77 | // we are being called recursively so try and get a clean encoder. 78 | if subEncoder := cleanEncoder.EncoderOf(e.valType); subEncoder != nil { 79 | e.log.Debugf("use sub-encoder type %s", e.valType) 80 | // use the clean encoder to encode to our own stream. 81 | subEncoder.Encode(ptr, stream) 82 | } 83 | return 84 | } 85 | // encode the ptr to the lookup table 86 | key := e.encodeLookup(ptr, e.nextIndex()) 87 | // encode our lookup key to the main stream 88 | e.log.Debugf("encoded lookup key: %s", key) 89 | stream.WriteString(key.String()) 90 | } 91 | 92 | // IsEmpty returns true if ptr is empty, otherwise false. 93 | func (e *Encoder) IsEmpty(ptr unsafe.Pointer) bool { 94 | return e.encoder.IsEmpty(ptr) 95 | } 96 | // encodeLookup writes the value of ptr into the lookup stream as an object field named by a new lookup key for tableIndex, and returns that key. 97 | func (e *Encoder) encodeLookup(ptr unsafe.Pointer, tableIndex int) Key { 98 | key := NewLookupKey(e.token, tableIndex, e.valType) 99 | // encode the actual data into our lookup table 100 | if tableIndex > 0 { 101 | e.stream.WriteMore() 102 | } 103 | e.stream.WriteObjectField(key.String()) 104 | e.encoder.Encode(ptr, e.stream) 105 | e.log.Debugf("encoded lookup for key %s: %s", key, string(e.stream.Buffer())) 106 | return key 107 | } 108 | 109 | // we shouldn't need locking here since it should not be called concurrently.
110 | func (e *Encoder) nextIndex() int { 111 | idx, _ := e.stream.Attachment.(int) 112 | e.stream.Attachment = idx + 1 113 | return idx 114 | } 115 | -------------------------------------------------------------------------------- /encoding/json/encoders/lookup/encoder_test.go: -------------------------------------------------------------------------------- 1 | package lookup 2 | 3 | import ( 4 | "fmt" 5 | "github.com/jrapoport/chestnut/log" 6 | jsoniter "github.com/json-iterator/go" 7 | "testing" 8 | 9 | "github.com/jrapoport/chestnut/encoding/json/encoders" 10 | "github.com/modern-go/reflect2" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestLookupEncoder_Encode(t *testing.T) { 15 | type testObject struct { 16 | Value string 17 | } 18 | tests := []struct { 19 | value interface{} 20 | key string 21 | encoding string 22 | }{ 23 | { 24 | "a-string-value", 25 | `"tst0xtesting%d_24"`, 26 | `"a-string-value"`, 27 | }, 28 | { 29 | []string{"a-string-slice"}, 30 | `"tst0xtesting%d_23"`, 31 | `["a-string-slice"]`, 32 | }, 33 | { 34 | 99.9, 35 | `"tst0xtesting%d_14"`, 36 | `99.9`, 37 | }, 38 | { 39 | testObject{"a-struct-value"}, 40 | `"tst0xtesting%d_25"`, 41 | `{"Value":"a-struct-value"}`, 42 | }, 43 | { 44 | &testObject{"a-struct-ptr-value"}, 45 | `"tst0xtesting%d_22"`, 46 | `{"Value":"a-struct-ptr-value"}`, 47 | }, 48 | { 49 | &testObject{"a\nstruct\tptr\"value"}, 50 | `"tst0xtesting%d_22"`, 51 | `{"Value":"a\nstruct\tptr\"value"}`, 52 | }, 53 | } 54 | encoded := "" 55 | lookup := "" 56 | stream := newTestStream(t) 57 | ctx := &Context{ 58 | NewLookupToken(testPrefix, testID), 59 | newTestStream(t), 60 | } 61 | enc := encoders.NewEncoder() 62 | for i, test := range tests { 63 | typ := reflect2.TypeOf(test.value) 64 | encoder := enc.EncoderOf(typ) 65 | le := NewLookupEncoder(ctx, typ, encoder) 66 | le.Encode(reflect2.PtrOf(test.value), stream) 67 | key := fmt.Sprintf(test.key, i) 68 | encoded += key 69 | assert.Equal(t, encoded, 
string(stream.Buffer())) 70 | if i > 0 { 71 | lookup += "," 72 | } 73 | entry := fmt.Sprintf("%s:%s", key, test.encoding) 74 | lookup += entry 75 | assert.Equal(t, lookup, string(ctx.Stream.Buffer())) 76 | } 77 | } 78 | 79 | func TestLookupEncoder_IsEmpty(t *testing.T) { 80 | tests := []struct { 81 | value interface{} 82 | assertEmpty assert.BoolAssertionFunc 83 | }{ 84 | {"", assert.True}, 85 | {"not-empty", assert.False}, 86 | {[]string{}, assert.True}, 87 | {[]string{"not-empty"}, assert.False}, 88 | } 89 | encoder := encoders.NewEncoder() 90 | for _, test := range tests { 91 | enc := encoder.EncoderOf(reflect2.TypeOf(test.value)) 92 | le := &Encoder{encoder: enc} 93 | empty := le.IsEmpty(reflect2.PtrOf(test.value)) 94 | test.assertEmpty(t, empty, "value: %v", test.value) 95 | } 96 | } 97 | 98 | func TestLookupEncoder_NewLookupEncoder(t *testing.T) { 99 | encoder := encoders.NewEncoder() 100 | typ := reflect2.TypeOf("a-string") 101 | enc := encoder.EncoderOf(typ) 102 | bad1 := &Context{} 103 | bad2 := &Context{InvalidToken, newTestStream(t)} 104 | bad3 := &Context{"a-string-value", nil} 105 | good := &Context{"a-string-value", newTestStream(t)} 106 | for _, ctx := range []*Context{nil, bad1, bad2, bad3, good} { 107 | for _, tp := range []reflect2.Type{nil, typ} { 108 | for _, ve := range []jsoniter.ValEncoder{nil, enc} { 109 | if ctx == good && tp == typ && ve == enc { 110 | continue 111 | } 112 | assert.Panics(t, func() { 113 | _ = NewLookupEncoder(ctx, tp, ve) 114 | }, ctx, tp, enc) 115 | } 116 | } 117 | } 118 | } 119 | 120 | func TestLookupEncoder_Fallback(t *testing.T) { 121 | strVal := "not-empty" 122 | stream := newTestStream(t) 123 | encoder := encoders.NewEncoder() 124 | kty := reflect2.TypeOf("a-string") 125 | enc := encoder.EncoderOf(kty) 126 | le := &Encoder{stream: stream, valType: kty, encoder: enc, log: log.Log} 127 | le.Encode(reflect2.PtrOf(strVal), stream) 128 | } 129 | 
-------------------------------------------------------------------------------- /encoding/json/encoders/lookup/lookup.go: -------------------------------------------------------------------------------- 1 | package lookup 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | 9 | jsoniter "github.com/json-iterator/go" 10 | "github.com/modern-go/reflect2" 11 | ) 12 | 13 | // InvalidToken is an invalid lookup token. 14 | const InvalidToken = "" 15 | // tokenSeparator separates a lookup key's index from its encoded reflect.Kind. 16 | const tokenSeparator = "_" 17 | 18 | // NewLookupToken returns the field name for sparse encoded data 19 | // by concatenating prefix and encoderID, i.e. "[prefix][encoder id]" with no separator. 20 | func NewLookupToken(prefix, encoderID string) string { 21 | return prefix + encoderID 22 | } 23 | 24 | // Key is an encoded lookup data key of the form "[token][index]_[kind number]". 25 | type Key string 26 | 27 | // NewLookupKey creates a new lookup table key with the encoding field index and type. The field 28 | // index is *not* the index relative to a StructField, but relative to the JSON encoding itself. 29 | func NewLookupKey(token string, index int, typ reflect2.Type) Key { 30 | return Key(fmt.Sprintf("%s%d%s%d", token, index, tokenSeparator, typ.Kind())) 31 | } 32 | 33 | // IsTokenKey returns true if the key starts with token. An empty (invalid) token never matches. 34 | func (k Key) IsTokenKey(token string) bool { 35 | return token != InvalidToken && strings.HasPrefix(k.String(), token) 36 | } 37 | 38 | // Kind returns the encoded reflect.Kind for the key, or reflect.Invalid if the key has no valid kind suffix. 39 | func (k Key) Kind() reflect.Kind { 40 | idx := strings.LastIndex(k.String(), tokenSeparator) 41 | if idx < 0 { 42 | return reflect.Invalid 43 | } 44 | // the text after the last separator should be the kind 45 | kind, err := strconv.Atoi(k.String()[idx+len(tokenSeparator):]) 46 | if err != nil || kind < 0 || kind > int(reflect.UnsafePointer) { 47 | return reflect.Invalid 48 | } 49 | return reflect.Kind(kind) 50 | } 51 | // String returns the key as a plain string. 52 | func (k Key) String() string { 53 | return string(k) 54 | } 55 | 56 | // Context holds the context for the lookup coders.
57 | type Context struct { 58 | Token string // prefix shared by every lookup key 59 | Stream *jsoniter.Stream // lookup table stream: entries are written here when encoding; the parsed table is read from its Attachment when decoding 60 | } 61 | -------------------------------------------------------------------------------- /encoding/json/encoders/lookup/lookup_test.go: -------------------------------------------------------------------------------- 1 | package lookup 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | jsoniter "github.com/json-iterator/go" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | const testID = "0xtesting" 12 | const testPrefix = "tst" 13 | // newTestStream returns an empty stream that uses jsoniter.ConfigDefault with a 4096-byte buffer. 14 | func newTestStream(t *testing.T) *jsoniter.Stream { 15 | var buf bytes.Buffer 16 | conf := jsoniter.ConfigDefault 17 | stream := jsoniter.NewStream(conf, &buf, 4096) 18 | stream.Reset(&buf) 19 | assert.NotNil(t, stream) 20 | return stream 21 | } 22 | -------------------------------------------------------------------------------- /encoding/json/encoders/secure/decoder_test.go: -------------------------------------------------------------------------------- 1 | package secure 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/jrapoport/chestnut/encoding/compress/zstd" 9 | "github.com/jrapoport/chestnut/encoding/json/encoders" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // decoderTest pairs a source fixture with its expected unsealed bytes and decoded results, decoded with or without the sparse option. 13 | type decoderTest struct { 14 | src []byte 15 | unsealed []byte 16 | dst interface{} 17 | res interface{} 18 | mp map[string]interface{} 19 | sparse Option 20 | } 21 | // decoderTests covers every fixture in its ...Sealed and ...Comp forms, with and without sparseOpt. 22 | var decoderTests = []decoderTest{ 23 | {noneSealed, noneUnsealed, &None{}, noneDecoded, noneMap, noOpt}, 24 | {noneSealed, noneUnsealed, &None{}, noneDecoded, noneMap, sparseOpt}, 25 | {noneComp, noneUnsealed, &None{}, noneDecoded, noneMap, noOpt}, 26 | {noneComp, noneUnsealed, &None{}, noneDecoded, noneMap, sparseOpt}, 27 | {jsonSealed, jsonUnsealed, &JSON{}, jsonDecoded, jsonMap, noOpt}, 28 | {jsonSealed, jsonUnsealed, &JSON{}, jsonDecoded, jsonMap, sparseOpt}, 29 | {jsonComp, jsonUnsealed, &JSON{}, jsonDecoded, jsonMap, noOpt}, 30 | {jsonComp, jsonUnsealed, &JSON{},
jsonDecoded, jsonMap, sparseOpt}, 31 | {ctrlSealed, ctrlUnsealed, &Escape{}, ctrlDecoded, ctrlMap, noOpt}, 32 | {ctrlSealed, ctrlUnsealed, &Escape{}, ctrlSparse, ctrlMapSparse, sparseOpt}, 33 | {ctrlComp, ctrlUnsealed, &Escape{}, ctrlDecoded, ctrlMap, noOpt}, 34 | {ctrlComp, ctrlUnsealed, &Escape{}, ctrlSparse, ctrlMapSparse, sparseOpt}, 35 | {hashSealed, hashUnsealed, &Hash{}, hashDecoded, hashMap, noOpt}, 36 | {hashSealed, hashUnsealed, &Hash{}, hashDecoded, hashMap, sparseOpt}, 37 | {hashComp, hashUnsealed, &Hash{}, hashDecoded, hashMap, noOpt}, 38 | {hashComp, hashUnsealed, &Hash{}, hashDecoded, hashMap, sparseOpt}, 39 | {secSealed, secUnsealed, &Secure{}, secDecoded, secMap, noOpt}, 40 | {secSealed, secUnsealed, &Secure{}, secSparse, secMapSparse, sparseOpt}, 41 | {secComp, secUnsealed, &Secure{}, secDecoded, secMap, noOpt}, 42 | {secComp, secUnsealed, &Secure{}, secSparse, secMapSparse, sparseOpt}, 43 | {bothSealed, bothUnsealed, &Both{}, bothDecoded, bothMap, noOpt}, 44 | {bothSealed, bothUnsealed, &Both{}, bothSparse, bothMapSparse, sparseOpt}, 45 | {bothComp, bothUnsealed, &Both{}, bothDecoded, bothMap, noOpt}, 46 | {bothComp, bothUnsealed, &Both{}, bothSparse, bothMapSparse, sparseOpt}, 47 | {allSealed, allUnsealed, &All{SI: ifc{}}, allDecoded, allMap, noOpt}, 48 | {allSealed, allUnsealed, &All{SI: ifc{}}, allSparse, allMapSparse, sparseOpt}, 49 | {allComp, allUnsealed, &All{SI: ifc{}}, allDecoded, allMap, noOpt}, 50 | {allComp, allUnsealed, &All{SI: ifc{}}, allSparse, allMapSparse, sparseOpt}, 51 | } 52 | // TestSecureDecoderExtension unseals, opens, and decodes each fixture into its typed value, a reflected copy of that type, and a generic interface{}. 53 | func TestSecureDecoderExtension(t *testing.T) { 54 | for _, test := range decoderTests { 55 | testName := reflect.TypeOf(test.dst).Elem().Name() 56 | if test.sparse != noOpt { 57 | testName += " sparse" 58 | } 59 | t.Run(testName, func(t *testing.T) { 60 | encoder := encoders.NewEncoder() 61 | // register decoding extension 62 | decoderExt := NewSecureDecoderExtension(testEncoderID, 63 | PassthroughDecryption, 64 |
WithDecompressor(zstd.Decompress), 65 | test.sparse) 66 | encoder.RegisterExtension(decoderExt) 67 | // unseal the encoding 68 | unsealed, err := decoderExt.Unseal(test.src) 69 | assert.NoError(t, err) 70 | assert.Equal(t, test.unsealed, unsealed) 71 | // open the decoder 72 | err = decoderExt.Open() 73 | assert.NoError(t, err) 74 | // securely decode the value 75 | err = encoder.Unmarshal(unsealed, test.dst) 76 | assert.NoError(t, err) 77 | assertDecoding(t, test.res, test.dst, err) 78 | // securely decode the reflected interface 79 | typ := reflect.ValueOf(test.dst).Elem().Type() 80 | ptr := reflect.New(typ).Interface() 81 | err = encoder.Unmarshal(unsealed, ptr) 82 | assertDecoding(t, test.res, ptr, err) 83 | // securely decode the mapped struct 84 | var mapped interface{} 85 | err = encoder.Unmarshal(unsealed, &mapped) 86 | assertDecoding(t, test.mp, mapped, err) 87 | // close the decoder 88 | decoderExt.Close() 89 | }) 90 | } 91 | d := NewSecureDecoderExtension(encoders.InvalidID, PassthroughDecryption) 92 | assert.NotNil(t, d) 93 | assert.Empty(t, d.encoderID) 94 | assert.Panics(t, func() { 95 | _ = NewSecureDecoderExtension(encoders.InvalidID, nil) 96 | }) 97 | } 98 | // TestSecureDecoderExtension_BadUnseal exercises the Open and Unseal error paths using failing decryption and compression callbacks. 99 | func TestSecureDecoderExtension_BadUnseal(t *testing.T) { 100 | var i int 101 | badCompressor := func(data []byte) (compressed []byte, err error) { 102 | if i%2 != 0 && i < 10 { 103 | i++ 104 | return nil, errors.New("compression error") 105 | } 106 | i++ 107 | return nil, err 108 | } 109 | bade := true 110 | ext := NewSecureDecoderExtension(testEncoderID, func(plaintext []byte) (ciphertext []byte, err error) { 111 | if bade { 112 | return nil, errors.New("encryption error") 113 | } 114 | return nil, err 115 | }, 116 | WithCompressor(badCompressor)) 117 | err := ext.Open() 118 | assert.NoError(t, err) 119 | err = ext.Open() 120 | assert.Error(t, err) 121 | _, err = ext.Unseal(bothEncoded) 122 | assert.Error(t, err) 123 | ext.Close() 124 | _, err = ext.Unseal(bothEncoded) 125 |
assert.Error(t, err) 126 | _, err = ext.Unseal(bothSealed) 127 | assert.Error(t, err) 128 | bade = false 129 | _, err = ext.Unseal(bothComp) 130 | assert.Error(t, err) 131 | i = 1 132 | _, err = ext.Unseal(bothComp) // NOTE(review): this result is not asserted 133 | i = 0 134 | ext.Close() 135 | encoder := encoders.NewEncoder() 136 | encoder.RegisterExtension(ext) 137 | err = encoder.Unmarshal(allComp, &None{}) 138 | assert.Error(t, err) 139 | err = ext.Open() 140 | assert.NoError(t, err) 141 | assert.Panics(t, func() { 142 | ext.decryptFunc = nil 143 | _, err = ext.Unseal(bothComp) 144 | assert.Error(t, err) 145 | }) 146 | } 147 | // TestSecureDecoderExtension_BadOpen checks that Open fails when the extension is already open or has no lookup context. 148 | func TestSecureDecoderExtension_BadOpen(t *testing.T) { 149 | ext := NewSecureDecoderExtension(testEncoderID, PassthroughDecryption) 150 | err := ext.Open() 151 | assert.NoError(t, err) 152 | err = ext.Open() 153 | assert.Error(t, err) 154 | ext.Close() 155 | ext.lookupCtx = nil 156 | err = ext.Open() 157 | assert.Error(t, err) 158 | ext.Close() 159 | } 160 | -------------------------------------------------------------------------------- /encoding/json/encoders/secure/encoder_test.go: -------------------------------------------------------------------------------- 1 | package secure 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/jrapoport/chestnut/encoding/json/encoders" 9 | "github.com/jrapoport/chestnut/encoding/json/packager" 10 | "github.com/jrapoport/chestnut/log" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | // encoderTest pairs a source object with its expected encoding and sealed package, with or without the compression option. 14 | type encoderTest struct { 15 | src interface{} 16 | dst []byte 17 | sealed []byte 18 | compressed Option 19 | } 20 | // encoderTests covers every fixture object with and without compOpt. 21 | var encoderTests = []encoderTest{ 22 | {noneObj, noneEncoded, noneSealed, noOpt}, 23 | {noneObj, noneEncoded, noneComp, compOpt}, 24 | {jsonObj, jsonEncoded, jsonSealed, noOpt}, 25 | {jsonObj, jsonEncoded, jsonComp, compOpt}, 26 | {ctrlObj, ctrlEncoded, ctrlSealed, noOpt}, 27 | {ctrlObj, ctrlEncoded, ctrlComp, compOpt}, 28 | {hashObj, hashEncoded, hashSealed, noOpt}, 29 | {hashObj, hashEncoded,
hashComp, compOpt}, 30 | {secObj, secEncoded, secSealed, noOpt}, 31 | {secObj, secEncoded, secComp, compOpt}, 32 | {bothObj, bothEncoded, bothSealed, noOpt}, 33 | {bothObj, bothEncoded, bothComp, compOpt}, 34 | {allObj, allEncoded, allSealed, noOpt}, 35 | {allObj, allEncoded, allComp, compOpt}, 36 | } 37 | // TestSecureEncoderExtension encodes each fixture, seals it, and checks that the sealed package decodes and is valid. 38 | func TestSecureEncoderExtension(t *testing.T) { 39 | for _, test := range encoderTests { 40 | testName := reflect.TypeOf(test.src).Elem().Name() 41 | if test.compressed != noOpt { 42 | testName += " compressed" 43 | } 44 | t.Run(testName, func(t *testing.T) { 45 | encoder := encoders.NewEncoder() 46 | // register encoding extension 47 | encoderExt := NewSecureEncoderExtension(testEncoderID, 48 | PassthroughEncryption, 49 | WithLogger(log.Log), 50 | test.compressed) 51 | encoder.RegisterExtension(encoderExt) 52 | // open the encoder 53 | err := encoderExt.Open() 54 | assert.NoError(t, err) 55 | // securely encode the value 56 | encoded, err := encoder.Marshal(test.src) 57 | assertJSON(t, test.dst, encoded, err) 58 | // close the encoder 59 | encoderExt.Close() 60 | // seal the encoding 61 | sealed, err := encoderExt.Seal(encoded) 62 | assert.NoError(t, err) 63 | assert.Equal(t, test.sealed, sealed) 64 | // unwrap the sealed package & make sure it is valid 65 | pkg, err := packager.DecodePackage(sealed) 66 | assert.NoError(t, err) 67 | assert.NotNil(t, pkg) 68 | assert.NoError(t, pkg.Valid()) 69 | }) 70 | } 71 | e := NewSecureEncoderExtension(encoders.InvalidID, PassthroughEncryption) 72 | assert.NotNil(t, e) 73 | assert.NotEmpty(t, e.encoderID) 74 | assert.Panics(t, func() { 75 | _ = NewSecureEncoderExtension(encoders.InvalidID, nil) 76 | }) 77 | } 78 | // TestSecureEncoderExtension_BadSeal exercises Seal error paths with failing encryption and compression callbacks, an invalid encoder ID, and a missing lookup stream. 79 | func TestSecureEncoderExtension_BadSeal(t *testing.T) { 80 | var i int 81 | badCompressor := func(data []byte) (compressed []byte, err error) { 82 | if i%2 != 0 && i < 10 { 83 | i++ 84 | return nil, errors.New("compression error") 85 | } 86 | i++ 87 | return nil, err 88 | } 89 | bade := true 90 | ext :=
NewSecureEncoderExtension(testEncoderID, func(plaintext []byte) (ciphertext []byte, err error) { 91 | if bade { 92 | return nil, errors.New("encryption error") 93 | } 94 | return nil, err 95 | }, 96 | WithCompressor(badCompressor)) 97 | err := ext.Open() 98 | assert.NoError(t, err) 99 | i = 0 100 | ext.Close() 101 | ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") 102 | _, err = ext.Seal(bothEncoded) // NOTE(review): this err is overwritten at line 106 before it is checked 103 | i = 1 104 | ext.Close() 105 | ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") 106 | _, err = ext.Seal(bothEncoded) 107 | i = 10 108 | ext.Close() 109 | assert.Error(t, err) 110 | ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") 111 | _, err = ext.Seal(bothEncoded) 112 | assert.Error(t, err) 113 | i = 10 114 | bade = false 115 | ext.Close() 116 | assert.Error(t, err) 117 | ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") 118 | ext.encoderID = encoders.InvalidID 119 | _, err = ext.Seal(bothEncoded) 120 | assert.Error(t, err) 121 | i = 10 122 | bade = false 123 | ext.Close() 124 | assert.Error(t, err) 125 | ext.lookupBuffer = []byte("121343546432343546576453423142534653423142536435243142536463524") 126 | ext.encoderID = testEncoderID 127 | ext.lookupCtx.Stream = nil 128 | _, err = ext.Seal(bothEncoded) 129 | assert.Error(t, err) 130 | } 131 | // TestSecureEncoderExtension_BadOpen checks that Open fails when already open or when the lookup context, token, or stream is invalid. 132 | func TestSecureEncoderExtension_BadOpen(t *testing.T) { 133 | ext := NewSecureEncoderExtension(testEncoderID, PassthroughEncryption) 134 | err := ext.Open() 135 | assert.NoError(t, err) 136 | err = ext.Open() 137 | assert.Error(t, err) 138 | ext.Close() 139 | ctx := ext.lookupCtx 140 | ext.lookupCtx = nil 141 | err = ext.Open() 142 | assert.Error(t, err) 143 | ext.lookupCtx = ctx 144 | ext.lookupCtx.Token = encoders.InvalidID 145 | err = ext.Open() 146 | assert.Error(t, err) 147 | ext.lookupCtx = ctx 148 | ext.lookupCtx.Stream = nil // NOTE(review): ctx aliases ext.lookupCtx, so Token is still encoders.InvalidID here and this case does not isolate the nil stream 149 | err =
ext.Open() 150 | assert.Error(t, err) 151 | } 152 | -------------------------------------------------------------------------------- /encoding/json/encoders/secure/options.go: -------------------------------------------------------------------------------- 1 | package secure 2 | 3 | import ( 4 | "github.com/jrapoport/chestnut/encoding/compress" 5 | "github.com/jrapoport/chestnut/encoding/compress/zstd" 6 | "github.com/jrapoport/chestnut/log" 7 | ) 8 | 9 | // Options provides a default implementation for common options for a secure encoding. 10 | type Options struct { 11 | // compressor is only valid for encoders 12 | compressor compress.CompressorFunc 13 | 14 | // decompressor is only valid for decoders 15 | decompressor compress.DecompressorFunc 16 | 17 | // sparse is only valid for decoding sparse packages 18 | sparse bool 19 | 20 | // log is the logger to use 21 | log log.Logger 22 | } 23 | 24 | // DefaultOptions represents the recommended default Options for secure encoding. 25 | var DefaultOptions = Options{ 26 | log: log.Log, 27 | } 28 | 29 | // An Option sets options such as compression or sparse decoding. 30 | type Option interface { 31 | apply(*Options) 32 | } 33 | 34 | // EmptyOption does not alter the encoder configuration. It can be embedded 35 | // in another structure to build custom encoder options. 36 | type EmptyOption struct{} 37 | // apply is a no-op, so EmptyOption leaves Options unchanged. 38 | func (EmptyOption) apply(*Options) {} 39 | 40 | // funcOption wraps a function that modifies Options 41 | // into an implementation of the Option interface. 42 | type funcOption struct { 43 | f func(*Options) 44 | } 45 | 46 | // apply applies an Option to Options. 47 | func (fdo *funcOption) apply(do *Options) { 48 | fdo.f(do) 49 | } 50 | // newFuncOption wraps f in a funcOption. 51 | func newFuncOption(f func(*Options)) *funcOption { 52 | return &funcOption{ 53 | f: f, 54 | } 55 | } 56 | 57 | // applyOptions accepts an Options struct and applies the Option(s) to it.
58 | func applyOptions(opts Options, opt ...Option) Options { 59 | if opt != nil { 60 | for _, o := range opt { 61 | o.apply(&opts) 62 | } 63 | } 64 | return opts 65 | } 66 | 67 | // SparseDecode returns an Option that set the decoder to return sparsely 68 | // decoded data. If the JSON data was not sparely encoded, this does nothing. 69 | func SparseDecode() Option { 70 | return newFuncOption(func(o *Options) { 71 | o.sparse = true 72 | }) 73 | } 74 | 75 | // WithCompressor returns an Option that compresses data. 76 | func WithCompressor(compressor compress.CompressorFunc) Option { 77 | return newFuncOption(func(o *Options) { 78 | o.compressor = compressor 79 | }) 80 | } 81 | 82 | // WithDecompressor returns an Option that decompresses data. 83 | func WithDecompressor(decompressor compress.DecompressorFunc) Option { 84 | return newFuncOption(func(o *Options) { 85 | o.decompressor = decompressor 86 | }) 87 | } 88 | 89 | // WithCompression returns an Option that compresses & decompresses data with Zstd. 90 | func WithCompression(format compress.Format) Option { 91 | return newFuncOption(func(o *Options) { 92 | switch format { 93 | case compress.Zstd: 94 | o.compressor = zstd.Compress 95 | o.decompressor = zstd.Decompress 96 | default: 97 | break 98 | } 99 | }) 100 | } 101 | 102 | // WithLogger returns an Option which sets the logger for the extension. 
103 | func WithLogger(l log.Logger) Option { 104 | return newFuncOption(func(o *Options) { 105 | o.log = l 106 | }) 107 | } 108 | -------------------------------------------------------------------------------- /encoding/json/encoders/secure/secure_test.go: -------------------------------------------------------------------------------- 1 | package secure 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/jrapoport/chestnut/encoding/compress" 8 | "github.com/jrapoport/chestnut/encoding/json/encoders" 9 | jsoniter "github.com/json-iterator/go" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | const testEncoderID = "86fb3fa0" 14 | 15 | type testCase struct { 16 | src interface{} 17 | dst interface{} 18 | res interface{} 19 | mp map[string]interface{} 20 | sparse Option 21 | } 22 | 23 | var ( 24 | noOpt = EmptyOption{} 25 | // ignored on non-sparse packages 26 | sparseOpt = SparseDecode() 27 | compOpt = WithCompression(compress.Zstd) 28 | ) 29 | 30 | var tests = []testCase{ 31 | {noneObj, &None{}, noneDecoded, noneMap, noOpt}, 32 | {noneObj, &None{}, noneDecoded, noneMap, sparseOpt}, 33 | {jsonObj, &JSON{}, jsonDecoded, jsonMap, noOpt}, 34 | {jsonObj, &JSON{}, jsonDecoded, jsonMap, sparseOpt}, 35 | {ctrlObj, &Escape{}, ctrlDecoded, ctrlMap, noOpt}, 36 | {ctrlObj, &Escape{}, ctrlSparse, ctrlMapSparse, sparseOpt}, 37 | {hashObj, &Hash{}, hashDecoded, hashMap, noOpt}, 38 | {hashObj, &Hash{}, hashDecoded, hashMap, sparseOpt}, 39 | {secObj, &Secure{}, secDecoded, secMap, noOpt}, 40 | {secObj, &Secure{}, secSparse, secMapSparse, sparseOpt}, 41 | {bothObj, &Both{}, bothDecoded, bothMap, noOpt}, 42 | {bothObj, &Both{}, bothSparse, bothMapSparse, sparseOpt}, 43 | {allObj, &All{SI: ifc{}}, allDecoded, allMap, noOpt}, 44 | {allObj, &All{SI: ifc{}}, allSparse, allMapSparse, sparseOpt}, 45 | } 46 | 47 | func TestSecureExtension(t *testing.T) { 48 | comps := []Option{noOpt, compOpt} 49 | for _, compressed := range comps { 50 | for _, test := range tests { 51 
| testName := reflect.TypeOf(test.dst).Elem().Name() 52 | if test.sparse != noOpt { 53 | testName += " sparse" 54 | } 55 | if compressed != noOpt { 56 | testName += " compressed" 57 | } 58 | t.Run(testName, func(t *testing.T) { 59 | encoder := encoders.NewEncoder() 60 | // register encoding extension 61 | encoderExt := NewSecureEncoderExtension(testEncoderID, 62 | PassthroughEncryption, compressed) 63 | encoder.RegisterExtension(encoderExt) 64 | // register decoding extension 65 | decoderExt := NewSecureDecoderExtension(testEncoderID, 66 | PassthroughDecryption, compressed, test.sparse) 67 | encoder.RegisterExtension(decoderExt) 68 | // open the encoder 69 | err := encoderExt.Open() 70 | assert.NoError(t, err) 71 | // securely encode the value 72 | encoded, err := encoder.Marshal(test.src) 73 | assert.NoError(t, err) 74 | // close the encoder 75 | encoderExt.Close() 76 | // seal the encoding 77 | sealed, err := encoderExt.Seal(encoded) 78 | assert.NoError(t, err) 79 | // unseal the encoding 80 | unsealed, err := decoderExt.Unseal(sealed) 81 | assert.NoError(t, err) 82 | // open the decoder 83 | err = decoderExt.Open() 84 | assert.NoError(t, err) 85 | // securely decode the value 86 | err = encoder.Unmarshal(unsealed, test.dst) 87 | assert.NoError(t, err) 88 | assertDecoding(t, test.res, test.dst, err) 89 | // securely decode the reflected interface 90 | typ := reflect.ValueOf(test.dst).Elem().Type() 91 | ptr := reflect.New(typ).Interface() 92 | err = encoder.Unmarshal(unsealed, ptr) 93 | assertDecoding(t, test.res, ptr, err) 94 | // securely decode the mapped struct 95 | var mapped interface{} 96 | err = encoder.Unmarshal(unsealed, &mapped) 97 | assertDecoding(t, test.mp, mapped, err) 98 | // close the decoder 99 | decoderExt.Close() 100 | }) 101 | } 102 | } 103 | } 104 | 105 | func assertJSON(t *testing.T, expected, actual []byte, err error) { 106 | e := assert.NoError(t, err) 107 | if !e { 108 | t.Fatal(err) 109 | } 110 | valid := jsoniter.Valid(actual) 111 | 
assert.True(t, valid, "invalid JSON") 112 | assert.Equal(t, string(expected), string(actual)) 113 | } 114 | 115 | func assertDecoding(t *testing.T, expected, actual interface{}, err error) { 116 | e := assert.NoError(t, err) 117 | if !e { 118 | t.Fatal(err) 119 | } 120 | assert.Equal(t, expected, actual) 121 | deep := reflect.DeepEqual(expected, actual) 122 | assert.True(t, deep, "values are not deep equal") 123 | } 124 | -------------------------------------------------------------------------------- /encoding/json/encoders/secure/tags_json_test.go: -------------------------------------------------------------------------------- 1 | package secure 2 | 3 | type JSON struct { 4 | TagDefault string 5 | TagBlank string `json:""` 6 | TagIgnore string `json:"-"` 7 | TagNamed string `json:"tag_named"` 8 | TagEmpty string `json:"tag_empty"` 9 | TagEscaped string `json:"tag_escaped"` 10 | TagOmit string `json:"tag_omit,omitempty"` 11 | TagNumber int `json:"tag_number"` 12 | TagFloat float64 `json:"tag_float"` 13 | } 14 | 15 | var jsonObj = &JSON{ 16 | TagDefault: "default-value", 17 | TagBlank: "blank-value", 18 | TagIgnore: "ignore-value", 19 | TagNamed: "named-value", 20 | TagEmpty: "", 21 | TagEscaped: "escaped\n-\tvalue\\\"", 22 | TagNumber: 42, 23 | TagFloat: 99.9, 24 | } 25 | 26 | var jsonDecoded = &JSON{ 27 | TagDefault: "default-value", 28 | TagBlank: "blank-value", 29 | TagNamed: "named-value", 30 | TagEscaped: "escaped\n-\tvalue\\\"", 31 | TagNumber: 42, 32 | TagFloat: 99.9, 33 | } 34 | 35 | var jsonMap = map[string]interface{}{ 36 | "TagBlank": "blank-value", 37 | "TagDefault": "default-value", 38 | "tag_empty": "", 39 | "tag_float": 99.9, 40 | "tag_named": "named-value", 41 | "tag_escaped": "escaped\n-\tvalue\\\"", 42 | "tag_number": 42.0, 43 | } 44 | 45 | var jsonEncoded = []byte(`{"TagDefault":"default-value","TagBlank":"blank-val` + 46 | `ue","tag_named":"named-value","tag_empty":"","tag_escaped":"escaped\n-\t` + 47 | 
`value\\\"","tag_number":42,"tag_float":99.9}`) 48 | 49 | var jsonUnsealed = []byte(`{"TagDefault":"default-value","TagBlank":"blank-va` + 50 | `lue","tag_named":"named-value","tag_empty":"","tag_escaped":"escaped\n-\` + 51 | `tvalue\\\"","tag_number":42,"tag_float":99.9}`) 52 | 53 | var jsonSealed = []byte{0x69, 0x7f, 0x3, 0x1, 0x1, 0x7, 0x50, 0x61, 0x63, 0x6b, 54 | 0x61, 0x67, 0x65, 0x1, 0xff, 0x80, 0x0, 0x1, 0x7, 0x1, 0x7, 0x56, 0x65, 55 | 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x46, 0x6f, 0x72, 56 | 0x6d, 0x61, 0x74, 0x1, 0xc, 0x0, 0x1, 0xa, 0x43, 0x6f, 0x6d, 0x70, 0x72, 57 | 0x65, 0x73, 0x73, 0x65, 0x64, 0x1, 0x2, 0x0, 0x1, 0x9, 0x45, 0x6e, 0x63, 58 | 0x6f, 0x64, 0x65, 0x72, 0x49, 0x44, 0x1, 0xc, 0x0, 0x1, 0x5, 0x54, 0x6f, 59 | 0x6b, 0x65, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x43, 0x69, 0x70, 0x68, 0x65, 60 | 0x72, 0x1, 0xa, 0x0, 0x1, 0x7, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x65, 0x64, 61 | 0x1, 0xa, 0x0, 0x0, 0x0, 0xfe, 0x1, 0x6e, 0xff, 0x80, 0x1, 0x5, 0x30, 0x2e, 62 | 0x30, 0x2e, 0x31, 0x1, 0x6, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x2, 0x8, 63 | 0x38, 0x36, 0x66, 0x62, 0x33, 0x66, 0x61, 0x30, 0x2, 0xfe, 0x1, 0x4e, 0x37, 64 | 0x62, 0x32, 0x32, 0x35, 0x34, 0x36, 0x31, 0x36, 0x37, 0x34, 0x34, 0x36, 65 | 0x35, 0x36, 0x36, 0x36, 0x31, 0x37, 0x35, 0x36, 0x63, 0x37, 0x34, 0x32, 66 | 0x32, 0x33, 0x61, 0x32, 0x32, 0x36, 0x34, 0x36, 0x35, 0x36, 0x36, 0x36, 67 | 0x31, 0x37, 0x35, 0x36, 0x63, 0x37, 0x34, 0x32, 0x64, 0x37, 0x36, 0x36, 68 | 0x31, 0x36, 0x63, 0x37, 0x35, 0x36, 0x35, 0x32, 0x32, 0x32, 0x63, 0x32, 69 | 0x32, 0x35, 0x34, 0x36, 0x31, 0x36, 0x37, 0x34, 0x32, 0x36, 0x63, 0x36, 70 | 0x31, 0x36, 0x65, 0x36, 0x62, 0x32, 0x32, 0x33, 0x61, 0x32, 0x32, 0x36, 71 | 0x32, 0x36, 0x63, 0x36, 0x31, 0x36, 0x65, 0x36, 0x62, 0x32, 0x64, 0x37, 72 | 0x36, 0x36, 0x31, 0x36, 0x63, 0x37, 0x35, 0x36, 0x35, 0x32, 0x32, 0x32, 73 | 0x63, 0x32, 0x32, 0x37, 0x34, 0x36, 0x31, 0x36, 0x37, 0x35, 0x66, 0x36, 74 | 0x65, 0x36, 0x31, 0x36, 0x64, 0x36, 0x35, 0x36, 0x34, 0x32, 0x32, 0x33, 
75 | 0x61, 0x32, 0x32, 0x36, 0x65, 0x36, 0x31, 0x36, 0x64, 0x36, 0x35, 0x36, 76 | 0x34, 0x32, 0x64, 0x37, 0x36, 0x36, 0x31, 0x36, 0x63, 0x37, 0x35, 0x36, 77 | 0x35, 0x32, 0x32, 0x32, 0x63, 0x32, 0x32, 0x37, 0x34, 0x36, 0x31, 0x36, 78 | 0x37, 0x35, 0x66, 0x36, 0x35, 0x36, 0x64, 0x37, 0x30, 0x37, 0x34, 0x37, 79 | 0x39, 0x32, 0x32, 0x33, 0x61, 0x32, 0x32, 0x32, 0x32, 0x32, 0x63, 0x32, 80 | 0x32, 0x37, 0x34, 0x36, 0x31, 0x36, 0x37, 0x35, 0x66, 0x36, 0x35, 0x37, 81 | 0x33, 0x36, 0x33, 0x36, 0x31, 0x37, 0x30, 0x36, 0x35, 0x36, 0x34, 0x32, 82 | 0x32, 0x33, 0x61, 0x32, 0x32, 0x36, 0x35, 0x37, 0x33, 0x36, 0x33, 0x36, 83 | 0x31, 0x37, 0x30, 0x36, 0x35, 0x36, 0x34, 0x35, 0x63, 0x36, 0x65, 0x32, 84 | 0x64, 0x35, 0x63, 0x37, 0x34, 0x37, 0x36, 0x36, 0x31, 0x36, 0x63, 0x37, 85 | 0x35, 0x36, 0x35, 0x35, 0x63, 0x35, 0x63, 0x35, 0x63, 0x32, 0x32, 0x32, 86 | 0x32, 0x32, 0x63, 0x32, 0x32, 0x37, 0x34, 0x36, 0x31, 0x36, 0x37, 0x35, 87 | 0x66, 0x36, 0x65, 0x37, 0x35, 0x36, 0x64, 0x36, 0x32, 0x36, 0x35, 0x37, 88 | 0x32, 0x32, 0x32, 0x33, 0x61, 0x33, 0x34, 0x33, 0x32, 0x32, 0x63, 0x32, 89 | 0x32, 0x37, 0x34, 0x36, 0x31, 0x36, 0x37, 0x35, 0x66, 0x36, 0x36, 0x36, 90 | 0x63, 0x36, 0x66, 0x36, 0x31, 0x37, 0x34, 0x32, 0x32, 0x33, 0x61, 0x33, 91 | 0x39, 0x33, 0x39, 0x32, 0x65, 0x33, 0x39, 0x37, 0x64, 0x0} 92 | 93 | var jsonComp = []byte{0x69, 0x7f, 0x3, 0x1, 0x1, 0x7, 0x50, 0x61, 0x63, 0x6b, 94 | 0x61, 0x67, 0x65, 0x1, 0xff, 0x80, 0x0, 0x1, 0x7, 0x1, 0x7, 0x56, 0x65, 95 | 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x46, 0x6f, 0x72, 96 | 0x6d, 0x61, 0x74, 0x1, 0xc, 0x0, 0x1, 0xa, 0x43, 0x6f, 0x6d, 0x70, 0x72, 97 | 0x65, 0x73, 0x73, 0x65, 0x64, 0x1, 0x2, 0x0, 0x1, 0x9, 0x45, 0x6e, 0x63, 98 | 0x6f, 0x64, 0x65, 0x72, 0x49, 0x44, 0x1, 0xc, 0x0, 0x1, 0x5, 0x54, 0x6f, 99 | 0x6b, 0x65, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x43, 0x69, 0x70, 0x68, 0x65, 100 | 0x72, 0x1, 0xa, 0x0, 0x1, 0x7, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x65, 0x64, 101 | 0x1, 0xa, 0x0, 0x0, 0x0, 0xfe, 0x1, 0x1d, 0xff, 0x80, 0x1, 0x5, 
0x30, 0x2e, 102 | 0x30, 0x2e, 0x31, 0x1, 0x6, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x1, 0x1, 103 | 0x1, 0x8, 0x38, 0x36, 0x66, 0x62, 0x33, 0x66, 0x61, 0x30, 0x2, 0xff, 0xfc, 104 | 0x32, 0x38, 0x62, 0x35, 0x32, 0x66, 0x66, 0x64, 0x30, 0x34, 0x30, 0x30, 105 | 0x38, 0x64, 0x30, 0x33, 0x30, 0x30, 0x30, 0x32, 0x63, 0x36, 0x31, 0x34, 106 | 0x31, 0x61, 0x38, 0x30, 0x33, 0x39, 0x30, 0x64, 0x64, 0x30, 0x30, 0x35, 107 | 0x33, 0x37, 0x38, 0x36, 0x38, 0x64, 0x37, 0x32, 0x37, 0x62, 0x35, 0x62, 108 | 0x31, 0x32, 0x34, 0x32, 0x35, 0x66, 0x35, 0x63, 0x39, 0x30, 0x34, 0x33, 109 | 0x30, 0x63, 0x64, 0x63, 0x30, 0x33, 0x30, 0x65, 0x37, 0x63, 0x32, 0x30, 110 | 0x34, 0x63, 0x30, 0x38, 0x65, 0x38, 0x30, 0x64, 0x30, 0x34, 0x34, 0x30, 111 | 0x63, 0x38, 0x62, 0x63, 0x33, 0x66, 0x65, 0x34, 0x31, 0x32, 0x31, 0x30, 112 | 0x39, 0x38, 0x31, 0x31, 0x39, 0x38, 0x61, 0x38, 0x61, 0x35, 0x37, 0x37, 113 | 0x62, 0x37, 0x37, 0x37, 0x66, 0x34, 0x39, 0x61, 0x31, 0x62, 0x31, 0x32, 114 | 0x36, 0x38, 0x66, 0x33, 0x36, 0x32, 0x31, 0x62, 0x38, 0x35, 0x33, 0x36, 115 | 0x32, 0x62, 0x61, 0x38, 0x61, 0x65, 0x38, 0x35, 0x66, 0x34, 0x36, 0x35, 116 | 0x66, 0x34, 0x39, 0x39, 0x64, 0x38, 0x63, 0x63, 0x65, 0x38, 0x37, 0x33, 117 | 0x39, 0x30, 0x38, 0x65, 0x66, 0x38, 0x33, 0x30, 0x61, 0x65, 0x62, 0x63, 118 | 0x65, 0x32, 0x32, 0x38, 0x33, 0x36, 0x36, 0x66, 0x61, 0x65, 0x32, 0x66, 119 | 0x34, 0x38, 0x34, 0x38, 0x34, 0x37, 0x63, 0x63, 0x30, 0x38, 0x30, 0x61, 120 | 0x30, 0x30, 0x33, 0x33, 0x30, 0x37, 0x34, 0x61, 0x34, 0x61, 0x65, 0x33, 121 | 0x30, 0x63, 0x35, 0x37, 0x35, 0x62, 0x33, 0x32, 0x33, 0x38, 0x64, 0x35, 122 | 0x33, 0x34, 0x35, 0x33, 0x32, 0x65, 0x66, 0x37, 0x62, 0x63, 0x62, 0x30, 123 | 0x37, 0x32, 0x35, 0x33, 0x33, 0x63, 0x63, 0x62, 0x64, 0x30, 0x61, 0x31, 124 | 0x38, 0x30, 0x30, 0x31, 0x36, 0x39, 0x61, 0x38, 0x38, 0x33, 0x39, 0x30, 125 | 0x0} 126 | -------------------------------------------------------------------------------- /encoding/json/encoders/secure/tags_none_test.go: 
-------------------------------------------------------------------------------- 1 | package secure 2 | 3 | type None struct { 4 | Value string 5 | Empty string 6 | Escaped string 7 | Number int 8 | Float float64 9 | } 10 | 11 | var noneObj = &None{ 12 | Value: "object-value", 13 | Empty: "", 14 | Escaped: "escaped\n-\tvalue\\\"", 15 | Number: 42, 16 | Float: 99.9, 17 | } 18 | 19 | var noneDecoded = &None{ 20 | Value: "object-value", 21 | Escaped: "escaped\n-\tvalue\\\"", 22 | Number: 42, 23 | Float: 99.9, 24 | } 25 | 26 | var noneMap = map[string]interface{}{ 27 | "Empty": "", 28 | "Escaped": "escaped\n-\tvalue\\\"", 29 | "Float": 99.9, 30 | "Number": 42.0, 31 | "Value": "object-value", 32 | } 33 | 34 | var noneEncoded = []byte(`{"Value":"object-value","Empty":"","Escaped":"escap` + 35 | `ed\n-\tvalue\\\"","Number":42,"Float":99.9}`) 36 | 37 | var noneUnsealed = []byte(`{"Value":"object-value","Empty":"","Escaped":"esca` + 38 | `ped\n-\tvalue\\\"","Number":42,"Float":99.9}`) 39 | 40 | var noneSealed = []byte{0x69, 0x7f, 0x3, 0x1, 0x1, 0x7, 0x50, 0x61, 0x63, 0x6b, 41 | 0x61, 0x67, 0x65, 0x1, 0xff, 0x80, 0x0, 0x1, 0x7, 0x1, 0x7, 0x56, 0x65, 42 | 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x46, 0x6f, 0x72, 43 | 0x6d, 0x61, 0x74, 0x1, 0xc, 0x0, 0x1, 0xa, 0x43, 0x6f, 0x6d, 0x70, 0x72, 44 | 0x65, 0x73, 0x73, 0x65, 0x64, 0x1, 0x2, 0x0, 0x1, 0x9, 0x45, 0x6e, 0x63, 45 | 0x6f, 0x64, 0x65, 0x72, 0x49, 0x44, 0x1, 0xc, 0x0, 0x1, 0x5, 0x54, 0x6f, 46 | 0x6b, 0x65, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x43, 0x69, 0x70, 0x68, 0x65, 47 | 0x72, 0x1, 0xa, 0x0, 0x1, 0x7, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x65, 0x64, 48 | 0x1, 0xa, 0x0, 0x0, 0x0, 0xff, 0xdb, 0xff, 0x80, 0x1, 0x5, 0x30, 0x2e, 0x30, 49 | 0x2e, 0x31, 0x1, 0x6, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x2, 0x8, 0x38, 50 | 0x36, 0x66, 0x62, 0x33, 0x66, 0x61, 0x30, 0x2, 0xff, 0xbc, 0x37, 0x62, 0x32, 51 | 0x32, 0x35, 0x36, 0x36, 0x31, 0x36, 0x63, 0x37, 0x35, 0x36, 0x35, 0x32, 52 | 0x32, 0x33, 0x61, 0x32, 0x32, 0x36, 0x66, 0x36, 
0x32, 0x36, 0x61, 0x36, 53 | 0x35, 0x36, 0x33, 0x37, 0x34, 0x32, 0x64, 0x37, 0x36, 0x36, 0x31, 0x36, 54 | 0x63, 0x37, 0x35, 0x36, 0x35, 0x32, 0x32, 0x32, 0x63, 0x32, 0x32, 0x34, 55 | 0x35, 0x36, 0x64, 0x37, 0x30, 0x37, 0x34, 0x37, 0x39, 0x32, 0x32, 0x33, 56 | 0x61, 0x32, 0x32, 0x32, 0x32, 0x32, 0x63, 0x32, 0x32, 0x34, 0x35, 0x37, 57 | 0x33, 0x36, 0x33, 0x36, 0x31, 0x37, 0x30, 0x36, 0x35, 0x36, 0x34, 0x32, 58 | 0x32, 0x33, 0x61, 0x32, 0x32, 0x36, 0x35, 0x37, 0x33, 0x36, 0x33, 0x36, 59 | 0x31, 0x37, 0x30, 0x36, 0x35, 0x36, 0x34, 0x35, 0x63, 0x36, 0x65, 0x32, 60 | 0x64, 0x35, 0x63, 0x37, 0x34, 0x37, 0x36, 0x36, 0x31, 0x36, 0x63, 0x37, 61 | 0x35, 0x36, 0x35, 0x35, 0x63, 0x35, 0x63, 0x35, 0x63, 0x32, 0x32, 0x32, 62 | 0x32, 0x32, 0x63, 0x32, 0x32, 0x34, 0x65, 0x37, 0x35, 0x36, 0x64, 0x36, 63 | 0x32, 0x36, 0x35, 0x37, 0x32, 0x32, 0x32, 0x33, 0x61, 0x33, 0x34, 0x33, 64 | 0x32, 0x32, 0x63, 0x32, 0x32, 0x34, 0x36, 0x36, 0x63, 0x36, 0x66, 0x36, 65 | 0x31, 0x37, 0x34, 0x32, 0x32, 0x33, 0x61, 0x33, 0x39, 0x33, 0x39, 0x32, 66 | 0x65, 0x33, 0x39, 0x37, 0x64, 0x0} 67 | 68 | var noneComp = []byte{0x69, 0x7f, 0x3, 0x1, 0x1, 0x7, 0x50, 0x61, 0x63, 0x6b, 69 | 0x61, 0x67, 0x65, 0x1, 0xff, 0x80, 0x0, 0x1, 0x7, 0x1, 0x7, 0x56, 0x65, 70 | 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x46, 0x6f, 0x72, 71 | 0x6d, 0x61, 0x74, 0x1, 0xc, 0x0, 0x1, 0xa, 0x43, 0x6f, 0x6d, 0x70, 0x72, 72 | 0x65, 0x73, 0x73, 0x65, 0x64, 0x1, 0x2, 0x0, 0x1, 0x9, 0x45, 0x6e, 0x63, 73 | 0x6f, 0x64, 0x65, 0x72, 0x49, 0x44, 0x1, 0xc, 0x0, 0x1, 0x5, 0x54, 0x6f, 74 | 0x6b, 0x65, 0x6e, 0x1, 0xc, 0x0, 0x1, 0x6, 0x43, 0x69, 0x70, 0x68, 0x65, 75 | 0x72, 0x1, 0xa, 0x0, 0x1, 0x7, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x65, 0x64, 76 | 0x1, 0xa, 0x0, 0x0, 0x0, 0xff, 0xf7, 0xff, 0x80, 0x1, 0x5, 0x30, 0x2e, 0x30, 77 | 0x2e, 0x31, 0x1, 0x6, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x1, 0x1, 0x1, 78 | 0x8, 0x38, 0x36, 0x66, 0x62, 0x33, 0x66, 0x61, 0x30, 0x2, 0xff, 0xd6, 0x32, 79 | 0x38, 0x62, 0x35, 0x32, 0x66, 0x66, 0x64, 0x30, 0x34, 
0x30, 0x30, 0x66, 80 | 0x31, 0x30, 0x32, 0x30, 0x30, 0x37, 0x62, 0x32, 0x32, 0x35, 0x36, 0x36, 81 | 0x31, 0x36, 0x63, 0x37, 0x35, 0x36, 0x35, 0x32, 0x32, 0x33, 0x61, 0x32, 82 | 0x32, 0x36, 0x66, 0x36, 0x32, 0x36, 0x61, 0x36, 0x35, 0x36, 0x33, 0x37, 83 | 0x34, 0x32, 0x64, 0x37, 0x36, 0x36, 0x31, 0x36, 0x63, 0x37, 0x35, 0x36, 84 | 0x35, 0x32, 0x32, 0x32, 0x63, 0x32, 0x32, 0x34, 0x35, 0x36, 0x64, 0x37, 85 | 0x30, 0x37, 0x34, 0x37, 0x39, 0x32, 0x32, 0x33, 0x61, 0x32, 0x32, 0x32, 86 | 0x32, 0x32, 0x63, 0x32, 0x32, 0x34, 0x35, 0x37, 0x33, 0x36, 0x33, 0x36, 87 | 0x31, 0x37, 0x30, 0x36, 0x35, 0x36, 0x34, 0x32, 0x32, 0x33, 0x61, 0x32, 88 | 0x32, 0x36, 0x35, 0x37, 0x33, 0x36, 0x33, 0x36, 0x31, 0x37, 0x30, 0x36, 89 | 0x35, 0x36, 0x34, 0x35, 0x63, 0x36, 0x65, 0x32, 0x64, 0x35, 0x63, 0x37, 90 | 0x34, 0x37, 0x36, 0x36, 0x31, 0x36, 0x63, 0x37, 0x35, 0x36, 0x35, 0x35, 91 | 0x63, 0x35, 0x63, 0x35, 0x63, 0x32, 0x32, 0x32, 0x32, 0x32, 0x63, 0x32, 92 | 0x32, 0x34, 0x65, 0x37, 0x35, 0x36, 0x64, 0x36, 0x32, 0x36, 0x35, 0x37, 93 | 0x32, 0x32, 0x32, 0x33, 0x61, 0x33, 0x34, 0x33, 0x32, 0x32, 0x63, 0x32, 94 | 0x32, 0x34, 0x36, 0x36, 0x63, 0x36, 0x66, 0x36, 0x31, 0x37, 0x34, 0x32, 95 | 0x32, 0x33, 0x61, 0x33, 0x39, 0x33, 0x39, 0x32, 0x65, 0x33, 0x39, 0x37, 96 | 0x64, 0x36, 0x33, 0x63, 0x63, 0x39, 0x31, 0x39, 0x37, 0x0} 97 | -------------------------------------------------------------------------------- /encoding/json/packager/encoding.go: -------------------------------------------------------------------------------- 1 | package packager 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "errors" 7 | 8 | "github.com/jrapoport/chestnut/encoding/json/encoders" 9 | ) 10 | 11 | // EncodePackage returns a valid binary enc package for storage. 12 | func EncodePackage(encoderID, token string, cipher, encoded []byte, compressed bool) ([]byte, error) { 13 | if encoderID == encoders.InvalidID { 14 | return nil, errors.New("invalid encoder id") 15 | } 16 | format := Secure 17 | // are we sparse? 
18 | sparse := len(encoded) >= minSparse 19 | if sparse { 20 | format = Sparse 21 | } 22 | // start the package 23 | pkg := &Package{ 24 | Version: currentVer.String(), 25 | Format: format, 26 | Compressed: compressed, 27 | EncoderID: encoderID, 28 | Token: token, 29 | Cipher: cipher, 30 | Encoded: encoded, 31 | } 32 | if err := pkg.Valid(); err != nil { 33 | return nil, err 34 | } 35 | return encode(pkg) 36 | } 37 | 38 | func encode(pkg *Package) ([]byte, error) { 39 | if err := pkg.Valid(); err != nil { 40 | return nil, err 41 | } 42 | b := bytes.Buffer{} 43 | e := gob.NewEncoder(&b) 44 | if err := e.Encode(pkg); err != nil { 45 | return nil, err 46 | } 47 | return b.Bytes(), nil 48 | } 49 | 50 | // DecodePackage takes packaged data and returns the ciphertext and encoding block. 51 | func DecodePackage(bytes []byte) (*Package, error) { 52 | pkg, err := decode(bytes) 53 | if err != nil { 54 | return nil, err 55 | } 56 | // check the package ver 57 | if err = pkg.checkVersion(); err != nil { 58 | return nil, err 59 | } 60 | if err = pkg.Valid(); err != nil { 61 | return nil, err 62 | } 63 | return pkg, err 64 | } 65 | 66 | func decode(data []byte) (*Package, error) { 67 | pkg := &Package{} 68 | buf := bytes.Buffer{} 69 | buf.Write(data) 70 | d := gob.NewDecoder(&buf) 71 | return pkg, d.Decode(pkg) 72 | } 73 | -------------------------------------------------------------------------------- /encoding/json/packager/package.go: -------------------------------------------------------------------------------- 1 | package packager 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/hashicorp/go-version" 8 | "github.com/jrapoport/chestnut/encoding/json/encoders" 9 | ) 10 | 11 | const ( 12 | // Version is the current package fmt ver. 13 | Version = "0.0.1" 14 | 15 | // InvalidToken is an invalid sparse token. 16 | InvalidToken = "" 17 | 18 | // minCipher is the min length of base 64 enc ciphertext. 
19 | minCipher = 4 20 | 21 | // minSparse is the min length of sparse enc data. 22 | minSparse = 2 // "{}" is an empty JSON object 23 | 24 | // minCompressed is the min length of compressed base 64 enc data. 25 | minCompressed = 8 26 | ) 27 | 28 | // Format is the package fmt. Currently 29 | // only secure & sparse formats are supported. 30 | type Format string 31 | 32 | const ( 33 | // Secure indicates the package contains a fully encrypted JSON object. 34 | Secure Format = "secure" 35 | 36 | // Sparse indicates the package supports sparse decryption. 37 | Sparse Format = "sparse" 38 | ) 39 | 40 | // Valid returns true if the fmt is valid. 41 | func (f Format) Valid() bool { 42 | switch f { 43 | case Secure: 44 | return true 45 | case Sparse: 46 | return true 47 | default: 48 | return false 49 | } 50 | } 51 | 52 | // Package is returned by DecodePackage. 53 | // 54 | // - Secure: If the package fmt Format is Secure, Package contains the encrypted ciphertext 55 | // Cipher containing a fully encoded JSON object. 56 | // 57 | // - Sparse: If the package Format is Sparse, the encrypted ciphertext Cipher contains a lookup 58 | // table of secure values. Encoded contains a plaintext enc JSON with its secure fields 59 | // removed and replaced with a secure lookup token consisting of a prefixed EncoderID with the 60 | // fmt "[prefix]-[encoder id]" (SEE: NewLookupToken) and an index into the lookup table. 61 | type Package struct { 62 | Version string 63 | Format Format 64 | Compressed bool 65 | EncoderID string 66 | Token string 67 | Cipher []byte 68 | Encoded []byte 69 | } 70 | 71 | // Valid returns true if the package fmt is valid. 
72 | func (p *Package) Valid() error { 73 | if len(p.Version) <= 0 { 74 | return errors.New("ver required") 75 | } 76 | _, err := version.NewVersion(p.Version) 77 | if err != nil { 78 | return fmt.Errorf("invalid ver %w", err) 79 | } 80 | if p.EncoderID == encoders.InvalidID { 81 | return errors.New("invalid encoder id") 82 | } 83 | if !p.Format.Valid() { 84 | return fmt.Errorf("invalid fmt %s", p.Format) 85 | } 86 | err = p.validateData() 87 | if err != nil { 88 | return err 89 | } 90 | sparse := len(p.Encoded) >= minSparse 91 | if sparse && p.Token == InvalidToken { 92 | return errors.New("invalid sparse token") 93 | } 94 | return nil 95 | } 96 | 97 | func (p *Package) validateData() error { 98 | if len(p.Cipher) < minCipher { 99 | return errors.New("invalid ciphertext") 100 | } 101 | if p.Compressed && len(p.Cipher) < minCompressed { 102 | return errors.New("invalid compressed ciphertext") 103 | } 104 | switch p.Format { 105 | case Secure: 106 | // this was handled above 107 | break 108 | case Sparse: 109 | if len(p.Encoded) < minSparse { 110 | return errors.New("invalid enc data") 111 | } 112 | if p.Compressed { 113 | if len(p.Encoded) < minCompressed { 114 | return errors.New("invalid compressed enc data") 115 | } 116 | break 117 | } 118 | // check that we have what looks like JSON 119 | if p.Encoded[0] != '{' { 120 | return errors.New("invalid enc data") 121 | } 122 | default: 123 | return fmt.Errorf("unsupported fmt: %s", p.Format) 124 | } 125 | return nil 126 | } 127 | 128 | // the currently supported package ver 129 | var currentVer = version.Must(version.NewVersion(Version)) 130 | 131 | func (p *Package) checkVersion() error { 132 | if len(p.Version) <= 0 { 133 | return errors.New("ver required") 134 | } 135 | ver, err := version.NewVersion(p.Version) 136 | if err != nil { 137 | return err 138 | } 139 | if ver.GreaterThan(currentVer) { 140 | return fmt.Errorf("supported ver %s", ver) 141 | } 142 | return nil 143 | } 144 | 
-------------------------------------------------------------------------------- /encoding/json/packager/package_test.go: -------------------------------------------------------------------------------- 1 | package packager 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/suite" 10 | ) 11 | 12 | const ( 13 | noComp = false 14 | comp = true 15 | ver = Version 16 | ) 17 | 18 | var ( 19 | empty = "" 20 | id = "c1ff7755" 21 | token = "lookup-token" 22 | sec = []byte("AAAAB3NzaC1yc2EAAAABJQAAAQB/nAmOjTmezNUDKYvEeIRf2Ynw") 23 | enc = []byte("{\"test_object\":{\"" + token + "\":0}}") 24 | zstd = []byte("KLUv/QQAAQEAeyJ0ZXN0X29iamVjdCI6eyJjbmMxZmY3NzU1IjowfX1hE1Nm") 25 | emptyZstd = []byte("KLUv/QQACQAAII1jaLY=") 26 | badVer = "999.999.999" 27 | badVer2 = ".*" 28 | badFormat = Format("invalid") 29 | badData = []byte("==") 30 | badZstd = []byte("bm9wZQ") 31 | comps = []bool{noComp, comp} 32 | secIns = [][]byte{[]byte(nil), []byte(empty), badData, badZstd, sec, emptyZstd, zstd} 33 | encIns = [][]byte{badData, enc, badZstd, emptyZstd, zstd} 34 | ) 35 | 36 | type TestCase struct { 37 | ver string 38 | fmt Format 39 | id string 40 | token string 41 | comp bool 42 | sec []byte 43 | enc []byte 44 | wrapErr assert.ErrorAssertionFunc 45 | unwrapErr assert.ErrorAssertionFunc 46 | } 47 | 48 | var tests = []TestCase{ 49 | // malformed packages 50 | {empty, "", empty, empty, noComp, nil, nil, 51 | assert.Error, assert.Error}, 52 | {"0", "", empty, empty, noComp, nil, nil, 53 | assert.Error, assert.Error}, 54 | {badVer, "", empty, empty, noComp, nil, nil, 55 | assert.Error, assert.Error}, 56 | {badVer2, "", empty, empty, noComp, nil, nil, 57 | assert.Error, assert.Error}, 58 | {ver, "", empty, empty, noComp, nil, nil, 59 | assert.Error, assert.Error}, 60 | {ver, badFormat, empty, empty, noComp, nil, nil, 61 | assert.Error, assert.Error}, 62 | {ver, badFormat, id, empty, noComp, nil, nil, 63 | 
assert.Error, assert.Error}, 64 | {ver, Secure, id, empty, noComp, nil, nil, 65 | assert.Error, assert.Error}, 66 | {ver, Sparse, empty, empty, noComp, nil, nil, 67 | assert.Error, assert.Error}, 68 | {ver, Sparse, id, empty, noComp, nil, nil, 69 | assert.Error, assert.Error}, 70 | // valid packages 71 | {ver, Secure, id, empty, noComp, sec, nil, 72 | assert.NoError, assert.NoError}, 73 | {ver, Sparse, id, token, noComp, sec, enc, 74 | assert.NoError, assert.NoError}, 75 | // valid compressed packages 76 | {ver, Secure, id, empty, comp, zstd, nil, 77 | assert.NoError, assert.NoError}, 78 | {ver, Sparse, id, token, comp, zstd, zstd, 79 | assert.NoError, assert.NoError}, 80 | } 81 | 82 | func genSecureTestCases() { 83 | for _, c := range comps { 84 | for secIdx, secIn := range secIns { 85 | wrapErr := assert.Error 86 | unwrapErr := assert.Error 87 | if c { 88 | if secIdx >= 4 { 89 | wrapErr = assert.NoError 90 | unwrapErr = assert.NoError 91 | } 92 | } else { 93 | if secIdx >= 3 { 94 | wrapErr = assert.NoError 95 | unwrapErr = assert.NoError 96 | } 97 | } 98 | tc := TestCase{ 99 | ver: ver, 100 | fmt: Secure, 101 | id: id, 102 | comp: c, 103 | sec: secIn, 104 | wrapErr: wrapErr, 105 | unwrapErr: unwrapErr, 106 | } 107 | tests = append(tests, tc) 108 | } 109 | } 110 | } 111 | 112 | func genSparseTestCases() { 113 | for _, c := range comps { 114 | for secIdx, secIn := range secIns { 115 | for encIdx, encIn := range encIns { 116 | wrapErr := assert.Error 117 | unwrapErr := assert.Error 118 | if c { 119 | if encIdx == 1 { 120 | continue 121 | } else if secIdx >= 4 && encIdx > 2 { 122 | wrapErr = assert.NoError 123 | unwrapErr = assert.NoError 124 | } 125 | } else { 126 | if secIdx >= 3 && encIdx == 1 { 127 | wrapErr = assert.NoError 128 | unwrapErr = assert.NoError 129 | } 130 | } 131 | tc := TestCase{ 132 | ver: ver, 133 | fmt: Sparse, 134 | id: id, 135 | token: token, 136 | comp: c, 137 | sec: secIn, 138 | enc: encIn, 139 | wrapErr: wrapErr, 140 | unwrapErr: unwrapErr, 
141 | } 142 | tests = append(tests, tc) 143 | } 144 | } 145 | } 146 | } 147 | 148 | type PackageTestSuite struct { 149 | suite.Suite 150 | } 151 | 152 | func TestStore(t *testing.T) { 153 | suite.Run(t, new(PackageTestSuite)) 154 | } 155 | 156 | func (ts *PackageTestSuite) SetupSuite() { 157 | genSecureTestCases() 158 | genSparseTestCases() 159 | } 160 | 161 | func (ts *PackageTestSuite) TestPackage_Encode() { 162 | for _, test := range tests { 163 | bytes, err := EncodePackage(test.id, test.token, test.sec, test.enc, test.comp) 164 | test.wrapErr(ts.T(), err) 165 | if err == nil { 166 | ts.NotEmpty(bytes) 167 | } else { 168 | ts.Empty(bytes) 169 | } 170 | } 171 | } 172 | 173 | func (ts *PackageTestSuite) TestPackage_Decode() { 174 | for _, test := range tests { 175 | testPkg := &Package{ 176 | Version: test.ver, 177 | Format: test.fmt, 178 | Compressed: test.comp, 179 | EncoderID: test.id, 180 | Token: test.token, 181 | Cipher: test.sec, 182 | Encoded: test.enc, 183 | } 184 | _, err := encode(testPkg) 185 | test.unwrapErr(ts.T(), err) 186 | } 187 | 188 | for _, test := range tests { 189 | testPkg := &Package{ 190 | Version: test.ver, 191 | Format: test.fmt, 192 | Compressed: test.comp, 193 | EncoderID: test.id, 194 | Token: test.token, 195 | Cipher: test.sec, 196 | Encoded: test.enc, 197 | } 198 | b := bytes.Buffer{} 199 | e := gob.NewEncoder(&b) 200 | err := e.Encode(testPkg) 201 | ts.NoError(err) 202 | pkg, err := DecodePackage(b.Bytes()) 203 | test.unwrapErr(ts.T(), err) 204 | if err != nil { 205 | ts.Nil(pkg) 206 | } else { 207 | assertPackage(ts.T(), test, pkg) 208 | } 209 | } 210 | } 211 | 212 | func (ts *PackageTestSuite) TestPackage() { 213 | for _, test := range tests { 214 | bytes, err := EncodePackage(test.id, test.token, test.sec, test.enc, test.comp) 215 | test.wrapErr(ts.T(), err, string(bytes)) 216 | if err != nil { 217 | ts.Empty(string(bytes)) 218 | continue 219 | } else { 220 | ts.NotEmpty(string(bytes)) 221 | } 222 | pkg, err := 
DecodePackage(bytes) 223 | test.unwrapErr(ts.T(), err) 224 | if err != nil { 225 | ts.Nil(pkg) 226 | } else { 227 | assertPackage(ts.T(), test, pkg) 228 | } 229 | } 230 | } 231 | 232 | func assertPackage(t *testing.T, test TestCase, pkg *Package) { 233 | assert.NotNil(t, pkg) 234 | assert.NoError(t, pkg.Valid()) 235 | assert.Equal(t, test.ver, pkg.Version) 236 | assert.Equal(t, test.fmt, pkg.Format) 237 | assert.Equal(t, test.comp, pkg.Compressed) 238 | assert.Equal(t, test.id, pkg.EncoderID) 239 | assert.Equal(t, test.token, pkg.Token) 240 | assert.Equal(t, test.sec, pkg.Cipher) 241 | assert.Equal(t, test.enc, pkg.Encoded) 242 | } 243 | -------------------------------------------------------------------------------- /encoding/json/secure_test.go: -------------------------------------------------------------------------------- 1 | package json 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/jrapoport/chestnut/encoding/compress" 8 | "github.com/jrapoport/chestnut/encoding/json/encoders/secure" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var ( 13 | encrypt = secure.PassthroughEncryption 14 | decrypt = secure.PassthroughDecryption 15 | compOpt = secure.WithCompression(compress.Zstd) 16 | sparseOpt = secure.SparseDecode() 17 | ) 18 | 19 | func TestSecureEncoding(t *testing.T) { 20 | secureObj := &Family{} 21 | bytes, err := SecureMarshal(family, encrypt) 22 | assert.NoError(t, err) 23 | assert.Equal(t, familyEnc, bytes) 24 | err = SecureUnmarshal(bytes, secureObj, decrypt) 25 | assertDecoding(t, familyDec, secureObj, err) 26 | } 27 | 28 | func TestCompressedEncoding(t *testing.T) { 29 | secureObj := &Family{} 30 | bytes, err := SecureMarshal(family, encrypt, compOpt) 31 | assert.NoError(t, err) 32 | assert.Equal(t, familyComp, bytes) 33 | err = SecureUnmarshal(bytes, secureObj, decrypt, compOpt) 34 | assertDecoding(t, familyDec, secureObj, err) 35 | } 36 | 37 | func TestSparseDecoding(t *testing.T) { 38 | sparseObj := &Family{} 39 | 
bytes, err := SecureMarshal(family, encrypt) 40 | assert.NoError(t, err) 41 | assert.Equal(t, familyEnc, bytes) 42 | err = SecureUnmarshal(bytes, sparseObj, decrypt, sparseOpt) 43 | assertDecoding(t, familySpr, sparseObj, err) 44 | } 45 | 46 | func TestCompressedSparseDecoding(t *testing.T) { 47 | sparseObj := &Family{} 48 | bytes, err := SecureMarshal(family, encrypt, compOpt) 49 | assert.NoError(t, err) 50 | assert.Equal(t, familyComp, bytes) 51 | err = SecureUnmarshal(bytes, sparseObj, decrypt, compOpt, sparseOpt) 52 | assertDecoding(t, familySpr, sparseObj, err) 53 | } 54 | 55 | func assertDecoding(t *testing.T, expected, actual interface{}, err error) { 56 | e := assert.NoError(t, err) 57 | if !e { 58 | t.Fatal(err) 59 | } 60 | assert.Equal(t, expected, actual) 61 | deep := reflect.DeepEqual(expected, actual) 62 | assert.True(t, deep, "values are not deep equal") 63 | } 64 | -------------------------------------------------------------------------------- /encoding/tags/tags.go: -------------------------------------------------------------------------------- 1 | package tags 2 | 3 | import "strings" 4 | 5 | const ( 6 | // TODO: Add support for chestnut & gorm struct field tags 7 | // ChestnutTag = "cn" 8 | // GORMTag = "gorm" 9 | 10 | // JSONTag is the default JSON struct tag to use. 11 | JSONTag = "json" 12 | 13 | // SecureOption is the tag option to enable sparse encryption of a struct field. 14 | SecureOption = "secure" 15 | 16 | // HashOption is the tag option to hash a struct field of type string. Defaults to SHA256. 17 | HashOption = "hash" 18 | 19 | jsonSeparator = "," 20 | jsonNameIgnore = "-" 21 | ) 22 | 23 | // Hash provides a type for supported hash function names 24 | type Hash string 25 | 26 | const ( 27 | // HashNone is used to indicate no has function was found in the tag options. 28 | HashNone Hash = "" 29 | 30 | // HashSHA256 sets the HashOption to use sha256. This is the default. 
31 | // TODO: support parsing this from the struct field tag hash option e.g. `...,hash=md5"` 32 | HashSHA256 = "sha256" 33 | ) 34 | 35 | func (h Hash) String() string { 36 | return string(h) 37 | } 38 | 39 | // ParseJSONTag returns the name and options for a JSON struct field tag. 40 | func ParseJSONTag(tag string) (name string, opts []string) { 41 | parts := strings.Split(tag, jsonSeparator) 42 | switch len(parts) { 43 | case 0: 44 | return "", []string{} 45 | case 1: 46 | return parts[0], []string{} 47 | default: 48 | if IgnoreField(parts[0]) { 49 | return parts[0], []string{} 50 | } 51 | return parts[0], parts[1:] 52 | } 53 | } 54 | 55 | // IgnoreField checks the name to see if field should be ignored. 56 | func IgnoreField(name string) bool { 57 | return name == jsonNameIgnore 58 | } 59 | 60 | // HasOption checks to see if the tag options contain a specific option. 61 | func HasOption(opts []string, opt string) bool { 62 | for _, s := range opts { 63 | if s == opt { 64 | return true 65 | } 66 | } 67 | return false 68 | } 69 | 70 | // HashName checks to see if the hash option is set. The struct field *MUST BE* 71 | // type string and capable of holding the decoded hash as a string. If no hash option 72 | // is found it will return HashNone. Defaults to HashSHA256 (sha256). 73 | func HashName(opts []string) Hash { 74 | if HasOption(opts, HashOption) { 75 | return HashSHA256 76 | } 77 | return HashNone // do not hash 78 | } 79 | 80 | // IsSecure checks to see if the secure option is set. 
81 | func IsSecure(opts []string) bool { 82 | return HasOption(opts, SecureOption) 83 | } 84 | -------------------------------------------------------------------------------- /encoding/tags/tags_test.go: -------------------------------------------------------------------------------- 1 | package tags 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestParseJSONTag(t *testing.T) { 10 | tests := []struct { 11 | tag string 12 | name string 13 | opts []string 14 | }{ 15 | {"", "", []string{}}, 16 | {"-", "-", []string{}}, 17 | {"test", "test", []string{}}, 18 | {",secure", "", []string{SecureOption}}, 19 | {"-,secure", "-", []string{}}, 20 | {"test,secure", "test", []string{SecureOption}}, 21 | {",secure,hash", "", []string{SecureOption, HashOption}}, 22 | {"-,secure,hash", "-", []string{}}, 23 | {"test,secure,hash", "test", []string{SecureOption, HashOption}}, 24 | {",secure,hash,omitempty", "", []string{SecureOption, HashOption, "omitempty"}}, 25 | {"-,secure,hash,omitempty", "-", []string{}}, 26 | {"test,secure,hash,omitempty", "test", []string{SecureOption, HashOption, "omitempty"}}, 27 | } 28 | for _, test := range tests { 29 | name, opts := ParseJSONTag(test.tag) 30 | assert.Equal(t, test.name, name) 31 | assert.ElementsMatch(t, test.opts, opts) 32 | } 33 | } 34 | 35 | func TestIgnoreField(t *testing.T) { 36 | tests := []struct { 37 | tag string 38 | assertBool assert.BoolAssertionFunc 39 | }{ 40 | {"", assert.False}, 41 | {"-", assert.True}, 42 | {"test", assert.False}, 43 | {",secure", assert.False}, 44 | {"-,secure", assert.True}, 45 | {"test,secure", assert.False}, 46 | {",secure,hash", assert.False}, 47 | {"-,secure,hash", assert.True}, 48 | {"test,secure,hash", assert.False}, 49 | {",secure,hash,omitempty", assert.False}, 50 | {"-,secure,hash,omitempty", assert.True}, 51 | {"test,secure,hash,omitempty", assert.False}, 52 | } 53 | for _, test := range tests { 54 | name, _ := ParseJSONTag(test.tag) 55 | 
test.assertBool(t, IgnoreField(name), "unexpected") 56 | } 57 | } 58 | 59 | func TestHasOption(t *testing.T) { 60 | tests := []struct { 61 | opts []string 62 | opt string 63 | has bool 64 | }{ 65 | {[]string{}, "", false}, 66 | {[]string{}, SecureOption, false}, 67 | {[]string{HashOption}, SecureOption, false}, 68 | {[]string{SecureOption}, SecureOption, true}, 69 | {nil, SecureOption, false}, 70 | } 71 | for _, test := range tests { 72 | has := HasOption(test.opts, test.opt) 73 | assert.Equal(t, test.has, has) 74 | } 75 | } 76 | 77 | func TestHashFunction(t *testing.T) { 78 | tests := []struct { 79 | opts []string 80 | name Hash 81 | }{ 82 | {nil, HashNone}, 83 | {[]string{}, HashNone}, 84 | {[]string{HashOption}, HashSHA256}, 85 | } 86 | for _, test := range tests { 87 | name := HashName(test.opts) 88 | assert.Equal(t, test.name, name) 89 | } 90 | } 91 | 92 | func TestIsSecure(t *testing.T) { 93 | tests := []struct { 94 | opts []string 95 | is bool 96 | }{ 97 | {nil, false}, 98 | {[]string{}, false}, 99 | {[]string{SecureOption}, true}, 100 | } 101 | for _, test := range tests { 102 | is := IsSecure(test.opts) 103 | assert.Equal(t, test.is, is) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /encryptor/aes.go: -------------------------------------------------------------------------------- 1 | package encryptor 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/jrapoport/chestnut/encryptor/aes" 7 | "github.com/jrapoport/chestnut/encryptor/crypto" 8 | ) 9 | 10 | // AESEncryptor is an encryptor that supports the 11 | // following AES keyLen lengths & cipher modes: 12 | // - AES128-CFB, AES192-CFB, AES256-CFB 13 | // - AES128-CTR, AES192-CTR, AES256-CTR 14 | // - AES128-GCM, AES192-GCM, AES256-GCM 15 | type AESEncryptor struct { 16 | secret crypto.Secret 17 | keyLen crypto.KeyLen 18 | mode crypto.Mode 19 | } 20 | 21 | var _ crypto.Encryptor = (*AESEncryptor)(nil) 22 | 23 | // NewAESEncryptor returns a new 
AESEncryptor configured 24 | // with an AES keyLen length and mode for a secret. 25 | func NewAESEncryptor(keyLen crypto.KeyLen, mode crypto.Mode, secret crypto.Secret) *AESEncryptor { 26 | ae := new(AESEncryptor) 27 | ae.secret = secret 28 | ae.keyLen = keyLen 29 | ae.mode = mode 30 | return ae 31 | } 32 | 33 | // ID returns the id of the encryptor (secret) that 34 | // was used to encrypt the data (for tracking). 35 | func (e *AESEncryptor) ID() string { 36 | return e.secret.ID() 37 | } 38 | 39 | // Name returns the name of the configured AES encryption cipher 40 | // in following format "[cipher][keyLen length]-[mode]" e.g. "aes192-ctr". 41 | func (e *AESEncryptor) Name() string { 42 | return crypto.CipherName("aes", e.keyLen, e.mode) 43 | } 44 | 45 | // Encrypt returns the plain data encrypted with the configured cipher mode and secret. 46 | func (e *AESEncryptor) Encrypt(plaintext []byte) ([]byte, error) { 47 | var encryptCall aes.CipherCall 48 | switch e.mode { 49 | case aes.CFB: 50 | encryptCall = aes.EncryptCFB 51 | case aes.CTR: 52 | encryptCall = aes.EncryptCTR 53 | case aes.GCM: 54 | encryptCall = aes.EncryptGCM 55 | default: 56 | return nil, fmt.Errorf("unsupported encryption cipher mode: %s", e.mode) 57 | } 58 | return encryptCall(e.keyLen, e.secret.Open(), plaintext) 59 | } 60 | 61 | // Decrypt returns the cipher data decrypted with the configured cipher mode and secret. 
62 | func (e *AESEncryptor) Decrypt(ciphertext []byte) ([]byte, error) { 63 | var decryptCall aes.CipherCall 64 | switch e.mode { 65 | case aes.CFB: 66 | decryptCall = aes.DecryptCFB 67 | case aes.CTR: 68 | decryptCall = aes.DecryptCTR 69 | case aes.GCM: 70 | decryptCall = aes.DecryptGCM 71 | default: 72 | return nil, fmt.Errorf("unsupported decryption cipher mode: %s", e.mode) 73 | } 74 | return decryptCall(e.keyLen, e.secret.Open(), ciphertext) 75 | } 76 | -------------------------------------------------------------------------------- /encryptor/aes/aes.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "errors" 7 | 8 | "github.com/jrapoport/chestnut/encryptor/crypto" 9 | ) 10 | 11 | // currently supported modes 12 | const ( 13 | CFB crypto.Mode = "cfb" 14 | CTR = "ctr" 15 | GCM = "gcm" 16 | ) 17 | 18 | // CipherCall is function the prototype for the encryption and decryption. 19 | type CipherCall func(length crypto.KeyLen, secret, data []byte) ([]byte, error) 20 | 21 | // cipherTransform preforms the encryption or decryption and returns the result. 22 | type cipherTransform func(header crypto.Header, block cipher.Block, data []byte) ([]byte, error) 23 | 24 | // encrypt is a generalized AES decryption function that takes plaintext and return a serialized Entry. 
25 | func encrypt(keyLen crypto.KeyLen, secret, plaintext []byte, header crypto.Header, encryptT cipherTransform) ([]byte, error) { 26 | if plaintext == nil || len(plaintext) <= 0 { 27 | return nil, errors.New("invalid plain data") 28 | } 29 | // create the cipher key 30 | key, err := crypto.NewCipherKey(keyLen, secret, header.Salt) 31 | if err != nil { 32 | return nil, err 33 | } 34 | // create the cipher block 35 | block, err := aes.NewCipher(key) 36 | if err != nil { 37 | return nil, err 38 | } 39 | ciphertext, err := encryptT(header, block, plaintext) 40 | if err != nil { 41 | return nil, err 42 | } 43 | // check the result 44 | data := crypto.NewData(header, ciphertext) 45 | if err = isDataValid(data); err != nil { 46 | return nil, err 47 | } 48 | // encode the encrypted data and return the result 49 | return crypto.EncodeData(data) 50 | } 51 | 52 | // decrypt is a generalized AES decryption function that takes a serialized Entry and returns plaintext. 53 | func decrypt(keyLen crypto.KeyLen, secret, ciphertext []byte, decryptT cipherTransform) ([]byte, error) { 54 | if ciphertext == nil || len(ciphertext) <= 0 { 55 | return nil, errors.New("invalid cipher data") 56 | } 57 | // decode the encrypted data 58 | data, err := crypto.DecodeData(ciphertext) 59 | if err != nil { 60 | return nil, err 61 | } 62 | // check the encoding 63 | if err = isDataValid(data); err != nil { 64 | return nil, err 65 | } 66 | // get the cipher key 67 | key, err := crypto.NewCipherKey(keyLen, secret, data.Salt) 68 | if err != nil { 69 | return nil, err 70 | } 71 | // get the cipher block 72 | block, err := aes.NewCipher(key) 73 | if err != nil { 74 | return nil, err 75 | } 76 | // decrypt the data 77 | return decryptT(data.Header, block, data.Bytes) 78 | } 79 | 80 | func isDataValid(data crypto.Data) error { 81 | if err := data.Valid(); err != nil { 82 | return err 83 | } 84 | // check the iv 85 | if data.IV != nil && len(data.IV) < aes.BlockSize { 86 | return errors.New("invalid iv") 
87 | } 88 | return nil 89 | } 90 | -------------------------------------------------------------------------------- /encryptor/aes/aes_test.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/jrapoport/chestnut/encryptor/crypto" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestAllCiphers(t *testing.T) { 12 | ciphers := []struct { 13 | name crypto.Mode 14 | encryptCall CipherCall 15 | decryptCall CipherCall 16 | }{ 17 | {CFB, EncryptCFB, DecryptCFB}, 18 | {CTR, EncryptCTR, DecryptCTR}, 19 | {GCM, EncryptGCM, DecryptGCM}, 20 | } 21 | for _, cipher := range ciphers { 22 | t.Run(cipher.name.String(), func(t *testing.T) { 23 | testCipher(t, cipher.encryptCall, cipher.decryptCall) 24 | }) 25 | } 26 | } 27 | 28 | func testCipher(t *testing.T, encryptCall, decryptCall CipherCall) { 29 | const ( 30 | secret = "i-am-a-good-secret" 31 | plaintext = "Lorem ipsum dolor sit amet" 32 | ) 33 | lens := []crypto.KeyLen{ 34 | crypto.Key128, 35 | crypto.Key192, 36 | crypto.Key256, 37 | } 38 | for _, l := range lens { 39 | t.Run(l.String(), func(t *testing.T) { 40 | encrypted, err := encryptCall(l, []byte(secret), []byte(plaintext)) 41 | assert.NoError(t, err) 42 | assert.NotEmpty(t, encrypted) 43 | data, err := crypto.DecodeData(encrypted) 44 | assert.NoError(t, err) 45 | assert.NotNil(t, data) 46 | assert.NoError(t, isDataValid(data)) 47 | assert.Equal(t, l, data.KeyLen) 48 | decrypted, err := decryptCall(l, []byte(secret), encrypted) 49 | assert.NoError(t, err) 50 | assert.NotEmpty(t, decrypted) 51 | assert.Equal(t, plaintext, string(decrypted)) 52 | }) 53 | } 54 | // bad plain data 55 | _, err := encryptCall(crypto.Key256, []byte(secret), nil) 56 | assert.Error(t, err) 57 | // mismatch 58 | e, _ := encryptCall(crypto.Key256, []byte(secret), []byte(plaintext)) 59 | d, _ := decryptCall(crypto.Key128, []byte(secret), e) 60 | assert.NotEqual(t, plaintext, 
string(d)) 61 | // bad cipher data 62 | badData := [][]byte{ 63 | nil, 64 | []byte(""), 65 | []byte("bad"), 66 | } 67 | for _, bd := range badData { 68 | _, err = decryptCall(crypto.Key256, []byte(secret), bd) 69 | assert.Error(t, err) 70 | } 71 | for _, bd := range badData { 72 | _, err = decryptCall(0, nil, bd) 73 | assert.Error(t, err) 74 | } 75 | for _, bd := range badData { 76 | _, err = decryptCall(math.MaxInt64, nil, bd) 77 | assert.Error(t, err) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /encryptor/aes/cfb.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import ( 4 | "crypto/cipher" 5 | 6 | "github.com/jrapoport/chestnut/encryptor/crypto" 7 | ) 8 | 9 | var ( 10 | _ CipherCall = EncryptCFB // EncryptCFB conforms to CipherCall 11 | _ CipherCall = DecryptCFB // DecryptCFB conforms to CipherCall 12 | ) 13 | 14 | // EncryptCFB supports AES128-CFB, AES192-CFB, and AES256-CFB encryption. 15 | func EncryptCFB(length crypto.KeyLen, secret, plaintext []byte) ([]byte, error) { 16 | // encrypt the data 17 | return xorStreamEncrypt(length, CFB, secret, plaintext, cipher.NewCFBEncrypter) 18 | } 19 | 20 | // DecryptCFB supports AES128-CFB, AES192-CFB, and AES256-CFB decryption. 
21 | func DecryptCFB(length crypto.KeyLen, secret, ciphertext []byte) ([]byte, error) { 22 | // decrypt the data 23 | return xorStreamDecrypt(length, CFB, secret, ciphertext, cipher.NewCFBDecrypter) 24 | } 25 | -------------------------------------------------------------------------------- /encryptor/aes/cfb_test.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import "testing" 4 | 5 | func TestCipherCFB(t *testing.T) { 6 | testCipher(t, EncryptCFB, DecryptCFB) 7 | } 8 | -------------------------------------------------------------------------------- /encryptor/aes/ctr.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import ( 4 | "crypto/cipher" 5 | 6 | "github.com/jrapoport/chestnut/encryptor/crypto" 7 | ) 8 | 9 | var ( 10 | _ CipherCall = EncryptCTR // EncryptCTR conforms to CipherCall 11 | _ CipherCall = DecryptCTR // DecryptCTR conforms to CipherCall 12 | ) 13 | 14 | // EncryptCTR supports AES128-CTR, AES192-CTR, and AES256-CTR encryption. 15 | func EncryptCTR(length crypto.KeyLen, secret, plaintext []byte) ([]byte, error) { 16 | // encrypt the data 17 | return xorStreamEncrypt(length, CTR, secret, plaintext, cipher.NewCTR) 18 | } 19 | 20 | // DecryptCTR supports AES128-CTR, AES192-CTR, and AES256-CTR decryption. 
21 | func DecryptCTR(length crypto.KeyLen, secret, ciphertext []byte) ([]byte, error) { 22 | // decrypt the data 23 | return xorStreamDecrypt(length, CTR, secret, ciphertext, cipher.NewCTR) 24 | } 25 | -------------------------------------------------------------------------------- /encryptor/aes/ctr_test.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import "testing" 4 | 5 | func TestCipherCTR(t *testing.T) { 6 | testCipher(t, EncryptCTR, DecryptCTR) 7 | } 8 | -------------------------------------------------------------------------------- /encryptor/aes/gcm.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import ( 4 | "crypto/cipher" 5 | 6 | "github.com/jrapoport/chestnut/encryptor/crypto" 7 | ) 8 | 9 | var ( 10 | _ CipherCall = EncryptGCM // EncryptGCM conforms to CipherCall 11 | _ CipherCall = DecryptGCM // DecryptGCM conforms to CipherCall 12 | ) 13 | 14 | // newGMCHeader returns a header containing a nonce suitable for a gcm cipher. 15 | func newGMCHeader(keyLen crypto.KeyLen) (crypto.Header, error) { 16 | salt, err := crypto.MakeSalt() 17 | if err != nil { 18 | return crypto.Header{}, err 19 | } 20 | nonce, err := crypto.MakeNonce() 21 | if err != nil { 22 | return crypto.Header{}, err 23 | } 24 | return crypto.NewHeader("aes", keyLen, GCM, salt, nil, nonce) 25 | } 26 | 27 | // EncryptGCM supports AES128-GCM, AES192-GCM, and AES256-GCM encryption. 
28 | func EncryptGCM(keyLen crypto.KeyLen, secret, plaintext []byte) ([]byte, error) { 29 | // create the header 30 | header, err := newGMCHeader(keyLen) 31 | if err != nil { 32 | return nil, err 33 | } 34 | // seal the data with gcms 35 | sealData := func(_ crypto.Header, block cipher.Block, _ []byte) ([]byte, error) { 36 | // create the AHEAD 37 | gcm, gcmErr := cipher.NewGCM(block) 38 | if gcmErr != nil { 39 | return nil, gcmErr 40 | } 41 | // encrypt the data 42 | return gcm.Seal(nil, header.Nonce, plaintext, nil), nil 43 | } 44 | return encrypt(keyLen, secret, plaintext, header, sealData) 45 | } 46 | 47 | // DecryptGCM supports AES128-GCM, AES192-GCM, and AES256-GCM decryption. 48 | func DecryptGCM(keyLen crypto.KeyLen, secret, ciphertext []byte) ([]byte, error) { 49 | // open the data with gcm 50 | openData := func(header crypto.Header, block cipher.Block, data []byte) ([]byte, error) { 51 | // create the AHEAD 52 | gcm, err := cipher.NewGCM(block) 53 | if err != nil { 54 | return nil, err 55 | } 56 | // decrypt the data 57 | return gcm.Open(nil, header.Nonce, data, nil) 58 | } 59 | return decrypt(keyLen, secret, ciphertext, openData) 60 | } 61 | -------------------------------------------------------------------------------- /encryptor/aes/gcm_test.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import "testing" 4 | 5 | func TestCipherGCM(t *testing.T) { 6 | testCipher(t, EncryptGCM, DecryptGCM) 7 | } 8 | -------------------------------------------------------------------------------- /encryptor/aes/stream.go: -------------------------------------------------------------------------------- 1 | package aes 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | 7 | "github.com/jrapoport/chestnut/encryptor/crypto" 8 | ) 9 | 10 | type streamCipher func(block cipher.Block, iv []byte) cipher.Stream 11 | 12 | // newStreamHeader returns a generic header suitable for aes stream ciphers that require an 
iv. 13 | func newStreamHeader(keyLen crypto.KeyLen, mode crypto.Mode) (crypto.Header, error) { 14 | salt, err := crypto.MakeRand(crypto.SaltLength) 15 | if err != nil { 16 | return crypto.Header{}, err 17 | } 18 | iv, err := crypto.MakeRand(aes.BlockSize) 19 | if err != nil { 20 | return crypto.Header{}, err 21 | } 22 | return crypto.NewHeader("aes", keyLen, mode, salt, iv, nil) 23 | } 24 | 25 | // xorStreamEncrypt is a generic function for AES XOR stream encryption ciphers. 26 | func xorStreamEncrypt(keyLen crypto.KeyLen, mode crypto.Mode, secret, 27 | plaintext []byte, newEncryptor streamCipher) ([]byte, error) { 28 | // create the header 29 | header, err := newStreamHeader(keyLen, mode) 30 | if err != nil { 31 | return nil, err 32 | } 33 | // encrypt the data 34 | encryptStream := func(_ crypto.Header, block cipher.Block, _ []byte) ([]byte, error) { 35 | ciphertext := make([]byte, len(plaintext)) 36 | stream := newEncryptor(block, header.IV) 37 | stream.XORKeyStream(ciphertext, plaintext) 38 | return ciphertext, nil 39 | } 40 | return encrypt(keyLen, secret, plaintext, header, encryptStream) 41 | } 42 | 43 | // xorStreamDecrypt is a generic function for AES XOR stream decryption ciphers. 
44 | func xorStreamDecrypt(keyLen crypto.KeyLen, _ crypto.Mode, secret, 45 | ciphertext []byte, newDecrypter streamCipher) ([]byte, error) { 46 | // decrypt the data 47 | var decryptStream = func(header crypto.Header, block cipher.Block, data []byte) ([]byte, error) { 48 | plaintext := make([]byte, len(data)) 49 | stream := newDecrypter(block, header.IV) 50 | stream.XORKeyStream(plaintext, data) 51 | // return the plain data 52 | return plaintext, nil 53 | } 54 | return decrypt(keyLen, secret, ciphertext, decryptStream) 55 | } 56 | -------------------------------------------------------------------------------- /encryptor/aes_test.go: -------------------------------------------------------------------------------- 1 | package encryptor 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/google/uuid" 8 | "github.com/jrapoport/chestnut/encryptor/aes" 9 | "github.com/jrapoport/chestnut/encryptor/crypto" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | const testPlainText = "Lorem ipsum dolor sit amet" 14 | 15 | var ( 16 | textSecret = crypto.TextSecret("i-am-a-good-secret") 17 | managedSecret = crypto.NewManagedSecret(uuid.New().String(), "i-am-a-managed-secret") 18 | secureSecret = crypto.NewSecureSecret(uuid.New().String(), func(s crypto.Secret) []byte { 19 | return []byte(s.ID()) 20 | }) 21 | ) 22 | 23 | func testAESEncryptor(t *testing.T, secret crypto.Secret, keyLen crypto.KeyLen, mode crypto.Mode) { 24 | ae := &AESEncryptor{secret, keyLen, mode} 25 | assert.Equal(t, secret.ID(), ae.ID()) 26 | assert.Equal(t, crypto.CipherName("aes", keyLen, mode), ae.Name()) 27 | e, err := ae.Encrypt([]byte(testPlainText)) 28 | assert.NoError(t, err) 29 | assert.NotEmpty(t, e) 30 | d, err := ae.Decrypt(e) 31 | assert.NoError(t, err) 32 | assert.NotEmpty(t, d) 33 | assert.Equal(t, testPlainText, string(d)) 34 | } 35 | 36 | func TestAESEncryptor(t *testing.T) { 37 | secrets := []struct { 38 | name string 39 | crypto.Secret 40 | }{ 41 | {"TextSecret", 
textSecret}, 42 | {"ManagedSecret", managedSecret}, 43 | {"SecureSecret", secureSecret}, 44 | } 45 | modes := []crypto.Mode{ 46 | aes.CFB, 47 | aes.CTR, 48 | aes.GCM, 49 | } 50 | keyLens := []crypto.KeyLen{ 51 | crypto.Key128, 52 | crypto.Key192, 53 | crypto.Key256, 54 | } 55 | testSecrets := func(t *testing.T, keyLen crypto.KeyLen, mode crypto.Mode) { 56 | for _, secret := range secrets { 57 | t.Run(secret.name, func(t *testing.T) { 58 | testAESEncryptor(t, secret, keyLen, mode) 59 | }) 60 | } 61 | } 62 | testKeyLens := func(t *testing.T, mode crypto.Mode) { 63 | for _, keyLen := range keyLens { 64 | t.Run(keyLen.String(), func(t *testing.T) { 65 | testSecrets(t, keyLen, mode) 66 | }) 67 | } 68 | } 69 | for _, mode := range modes { 70 | t.Run(mode.String(), func(t *testing.T) { 71 | testKeyLens(t, mode) 72 | }) 73 | } 74 | // load a bad cipher 75 | const invalidMode = "Invalid_Cipher_Mode" 76 | ae := &AESEncryptor{textSecret, crypto.Key128, invalidMode} 77 | t.Run(fmt.Sprintf("%s_%s", invalidMode, "Encrypt"), func(t *testing.T) { 78 | // try to encrypt with a bad cipher 79 | _, err := ae.Encrypt([]byte(testPlainText)) 80 | assert.Error(t, err) 81 | }) 82 | t.Run(fmt.Sprintf("%s_%s", invalidMode, "Decrypt"), func(t *testing.T) { 83 | // try to decrypt with a bad cipher 84 | _, err := ae.Decrypt([]byte(testPlainText)) 85 | assert.Error(t, err) 86 | }) 87 | } 88 | -------------------------------------------------------------------------------- /encryptor/chain.go: -------------------------------------------------------------------------------- 1 | package encryptor 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/jrapoport/chestnut/encryptor/crypto" 7 | ) 8 | 9 | // ChainEncryptor is an encryptor that supports an chain of other Encryptors. 10 | // Bytes will be encrypted by chaining the Encryptors in a FIFO order. 
11 | type ChainEncryptor struct { 12 | id string 13 | name string 14 | ids []string 15 | names []string 16 | encryption []crypto.Encryptor 17 | decryption []crypto.Encryptor 18 | } 19 | 20 | var _ crypto.Encryptor = (*ChainEncryptor)(nil) 21 | 22 | const chainSep = " " 23 | 24 | // NewChainEncryptor creates a new ChainEncryptor consisting of a chain 25 | // of the supplied Encryptors. 26 | func NewChainEncryptor(encryptors ...crypto.Encryptor) *ChainEncryptor { 27 | if len(encryptors) == 0 { 28 | return nil 29 | } 30 | // reverse the encryptors from FIFO to LIFO 31 | decryptors := make([]crypto.Encryptor, len(encryptors)) 32 | for i := range encryptors { 33 | decryptors[len(encryptors)-1-i] = encryptors[i] 34 | } 35 | chain := new(ChainEncryptor) 36 | chain.encryption = encryptors 37 | chain.decryption = decryptors 38 | chain.ids = make([]string, len(encryptors)) 39 | chain.names = make([]string, len(encryptors)) 40 | for i, e := range chain.encryption { 41 | chain.ids[i] = e.ID() 42 | chain.names[i] = e.Name() 43 | } 44 | chain.id = strings.Join(chain.ids, chainSep) 45 | chain.name = strings.Join(chain.names, chainSep) 46 | return chain 47 | } 48 | 49 | // ID returns a concatenated list of the ids of chained encryptor(s) / secrets 50 | // that were used to encrypt the data (for tracking) separated by spaces. 51 | func (e *ChainEncryptor) ID() string { 52 | return e.id 53 | } 54 | 55 | // Name returns a concatenated list of the cipher names of the chained encryptor(s) 56 | // that were used to encrypt the data separated by spaces. 57 | func (e *ChainEncryptor) Name() string { 58 | return e.name 59 | } 60 | 61 | // Encrypt returns data encrypted with the chain of Encryptors. 
62 | func (e *ChainEncryptor) Encrypt(plaintext []byte) ([]byte, error) { 63 | var err error 64 | ciphertext := plaintext 65 | for _, en := range e.encryption { 66 | ciphertext, err = en.Encrypt(ciphertext) 67 | if err != nil { 68 | break 69 | } 70 | } 71 | return ciphertext, err 72 | } 73 | 74 | // Decrypt returns data decrypted with the chain of Encryptors. 75 | func (e *ChainEncryptor) Decrypt(ciphertext []byte) ([]byte, error) { 76 | var err error 77 | plaintext := ciphertext 78 | for _, de := range e.decryption { 79 | plaintext, err = de.Decrypt(plaintext) 80 | if err != nil { 81 | break 82 | } 83 | } 84 | return plaintext, err 85 | } 86 | -------------------------------------------------------------------------------- /encryptor/chian_test.go: -------------------------------------------------------------------------------- 1 | package encryptor 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/jrapoport/chestnut/encryptor/aes" 8 | "github.com/jrapoport/chestnut/encryptor/crypto" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestChainEncryptor_Nil(t *testing.T) { 13 | assert.Nil(t, NewChainEncryptor()) 14 | } 15 | 16 | func TestChainEncryptor_Single(t *testing.T) { 17 | ae := NewAESEncryptor(crypto.Key128, aes.CFB, textSecret) 18 | assert.NotNil(t, ae) 19 | chain := NewChainEncryptor(ae) 20 | assert.Equal(t, ae.Name(), chain.Name()) 21 | assert.Equal(t, ae.ID(), chain.ID()) 22 | testChainEncryptor(t, chain) 23 | } 24 | 25 | func TestChainEncryptor_Chained(t *testing.T) { 26 | encryptors := []crypto.Encryptor{ 27 | &AESEncryptor{textSecret, crypto.Key128, aes.CFB}, 28 | &AESEncryptor{managedSecret, crypto.Key192, aes.CTR}, 29 | &AESEncryptor{secureSecret, crypto.Key256, aes.GCM}, 30 | } 31 | chain := NewChainEncryptor(encryptors...) 
32 | testChainName(t, chain, encryptors) 33 | testChainID(t, chain, encryptors) 34 | testChainEncryptor(t, chain) 35 | } 36 | 37 | func testChainName(t *testing.T, chain *ChainEncryptor, encryptors []crypto.Encryptor) { 38 | var names []string 39 | for _, e := range encryptors { 40 | names = append(names, e.Name()) 41 | } 42 | name := strings.Join(names, chainSep) 43 | assert.Equal(t, name, chain.Name()) 44 | } 45 | 46 | func testChainID(t *testing.T, chain *ChainEncryptor, encryptors []crypto.Encryptor) { 47 | var ids []string 48 | for _, e := range encryptors { 49 | ids = append(ids, e.ID()) 50 | } 51 | id := strings.Join(ids, chainSep) 52 | assert.Equal(t, id, chain.ID()) 53 | } 54 | 55 | func testChainEncryptor(t *testing.T, chain *ChainEncryptor) { 56 | assert.NotNil(t, chain) 57 | e, err := chain.Encrypt([]byte(testPlainText)) 58 | assert.NoError(t, err) 59 | assert.NotEmpty(t, e) 60 | d, err := chain.Decrypt(e) 61 | assert.NoError(t, err) 62 | assert.NotEmpty(t, d) 63 | assert.Equal(t, testPlainText, string(d)) 64 | } 65 | -------------------------------------------------------------------------------- /encryptor/crypto/data.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "errors" 7 | "fmt" 8 | ) 9 | 10 | // Data is a serializable wrapper for encrypted 11 | // bytes with additional metadata in the Header. 12 | type Data struct { 13 | Header 14 | Bytes []byte 15 | } 16 | 17 | // NewData returns an Data initialized 18 | // with a Header and encrypted data. 19 | func NewData(h Header, data []byte) Data { 20 | return Data{h, data} 21 | } 22 | 23 | // Valid returns an error if the Data is not valid. 
24 | func (e Data) Valid() error { 25 | if err := e.Header.Valid(); err != nil { 26 | return fmt.Errorf("invalid header %w", err) 27 | } 28 | // check the data 29 | if len(e.Bytes) <= 0 { 30 | return errors.New("invalid data") 31 | } 32 | return nil 33 | } 34 | 35 | // EncodeData encodes Data to a byte representation. This provides a small abstraction 36 | // in case we want to swap out the gob encoder for something else. 37 | func EncodeData(data Data) ([]byte, error) { 38 | if err := data.Valid(); err != nil { 39 | return nil, err 40 | } 41 | return GobEncodeData(data) 42 | } 43 | 44 | // DecodeData decodes a byte representation to Data. This provides a small abstraction 45 | // in case we want to swap out the gob decoder for something else. 46 | func DecodeData(b []byte) (Data, error) { 47 | return GobDecodeData(b) 48 | } 49 | 50 | // GobEncodeData serializes Data to a gob binary representation. 51 | func GobEncodeData(data Data) ([]byte, error) { 52 | b := bytes.Buffer{} 53 | e := gob.NewEncoder(&b) 54 | if err := e.Encode(data); err != nil { 55 | return nil, err 56 | } 57 | return b.Bytes(), nil 58 | } 59 | 60 | // GobDecodeData deserializes a gob binary representation to Data. 
61 | func GobDecodeData(b []byte) (Data, error) { 62 | data := Data{} 63 | buf := bytes.Buffer{} 64 | buf.Write(b) 65 | d := gob.NewDecoder(&buf) 66 | return data, d.Decode(&data) 67 | } 68 | -------------------------------------------------------------------------------- /encryptor/crypto/data_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestData(t *testing.T) { 10 | const name = "aes256-gcm" 11 | s, err := MakeRand(SaltLength) 12 | assert.NoError(t, err) 13 | iv, err := MakeRand(8) 14 | assert.NoError(t, err) 15 | nonce, err := MakeRand(NonceLength) 16 | assert.NoError(t, err) 17 | bytes, err := MakeRand(512) 18 | assert.NoError(t, err) 19 | type testCase struct { 20 | cipher string 21 | key KeyLen 22 | mode Mode 23 | salt []byte 24 | iv []byte 25 | nonce []byte 26 | bytes []byte 27 | err assert.ErrorAssertionFunc 28 | } 29 | tests := []testCase{ 30 | {"aes", Key256, "gcm", nil, nil, nil, nil, assert.Error}, 31 | {"aes", Key256, "gcm", s, nil, nonce, nil, assert.Error}, 32 | {"aes", Key256, "gcm", s, iv, nil, nil, assert.Error}, 33 | {"aes", Key256, "gcm", s, iv, nonce, nil, assert.Error}, 34 | {"aes", Key256, "gcm", s, iv, nil, bytes, assert.NoError}, 35 | {"aes", Key256, "gcm", s, nil, nonce, bytes, assert.NoError}, 36 | {"aes", Key256, "gcm", s, iv, nonce, bytes, assert.NoError}, 37 | } 38 | for _, test := range tests { 39 | data := NewData(Header{test.cipher, test.key, test.mode, 40 | test.salt, test.iv, test.nonce}, test.bytes) 41 | test.err(t, data.Valid()) 42 | } 43 | } 44 | 45 | func makeHeader(t *testing.T) Header { 46 | s, err := MakeRand(SaltLength) 47 | assert.NoError(t, err) 48 | iv, err := MakeRand(NonceLength) 49 | assert.NoError(t, err) 50 | nonce, err := MakeRand(NonceLength) 51 | assert.NoError(t, err) 52 | h, err := NewHeader("aes", Key256, "gcm", s, iv, nonce) 53 | assert.NoError(t, err) 
54 | return h 55 | } 56 | 57 | func TestEncodeData(t *testing.T) { 58 | bytes, err := MakeRand(512) 59 | assert.NoError(t, err) 60 | data := NewData(makeHeader(t), bytes) 61 | assert.NoError(t, data.Valid()) 62 | enc, err := EncodeData(data) 63 | assert.NoError(t, err) 64 | assert.NotEmpty(t, enc) 65 | dec, err := DecodeData(enc) 66 | assert.NoError(t, err) 67 | assert.Equal(t, data, dec) 68 | } 69 | 70 | func TestGobEncodeData(t *testing.T) { 71 | bytes, err := MakeRand(512) 72 | assert.NoError(t, err) 73 | data := NewData(makeHeader(t), bytes) 74 | assert.NoError(t, data.Valid()) 75 | enc, err := GobEncodeData(data) 76 | assert.NoError(t, err) 77 | assert.NotEmpty(t, enc) 78 | dec, err := GobDecodeData(enc) 79 | assert.NoError(t, err) 80 | assert.Equal(t, data, dec) 81 | } 82 | -------------------------------------------------------------------------------- /encryptor/crypto/encryptor.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | // Encryptor is the interface use to supply cipher implementations to the datastore. 4 | type Encryptor interface { 5 | // ID returns the id of the secret used to encrypt the data. 6 | ID() string 7 | 8 | // Name returns the name of encryption cipher, keyLen length 9 | // and mode used to encrypt the data ("aes192-ctr"). 10 | Name() string 11 | 12 | // Encrypt returns data encrypted with the secret. 13 | Encrypt(plaintext []byte) (ciphertext []byte, err error) 14 | 15 | // Decrypt returns data decrypted with the secret. 16 | Decrypt(ciphertext []byte) (plaintext []byte, err error) 17 | } 18 | -------------------------------------------------------------------------------- /encryptor/crypto/hash.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import "crypto/sha256" 4 | 5 | // HashSHA256 returns a sha256 hash of data. 
6 | func HashSHA256(data []byte) ([]byte, error) { 7 | h := sha256.New() 8 | if _, err := h.Write(data); err != nil { 9 | return nil, err 10 | } 11 | return h.Sum(nil), nil 12 | } 13 | -------------------------------------------------------------------------------- /encryptor/crypto/hash_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestHashSHA256(t *testing.T) { 10 | var tests = []struct { 11 | in string 12 | out []byte 13 | }{ 14 | { 15 | "", 16 | []byte{0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 17 | 0x99, 0x6f, 0xb9, 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 18 | 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55}, 19 | }, 20 | { 21 | "abcdefghijklmnopqrstuvwxyz", 22 | []byte{0x71, 0xc4, 0x80, 0xdf, 0x93, 0xd6, 0xae, 0x2f, 0x1e, 0xfa, 0xd1, 0x44, 23 | 0x7c, 0x66, 0xc9, 0x52, 0x5e, 0x31, 0x62, 0x18, 0xcf, 0x51, 0xfc, 0x8d, 0x9e, 0xd8, 24 | 0x32, 0xf2, 0xda, 0xf1, 0x8b, 0x73}, 25 | }, 26 | { 27 | "abcdefghijklmnopqrstuvwxyz1234567890", 28 | []byte{0x77, 0xd7, 0x21, 0xc8, 0x17, 0xf9, 0xd2, 0x16, 0xc1, 0xfb, 0x78, 0x3b, 0xca, 29 | 0xd9, 0xcd, 0xc2, 0xa, 0xaa, 0x24, 0x27, 0x40, 0x26, 0x83, 0xf1, 0xf7, 0x5d, 0xd6, 30 | 0xdf, 0xbe, 0x65, 0x74, 0x70}, 31 | }, 32 | } 33 | for _, test := range tests { 34 | h, err := HashSHA256([]byte(test.in)) 35 | assert.NoError(t, err) 36 | assert.Equal(t, test.out, h, test.in) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /encryptor/crypto/header.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // MinSaltLength is the minimum length of the salt buffer. 10 | const MinSaltLength = 8 11 | 12 | // A Header describes an encryption block. 
It contains the cipher name, 13 | // key length, mode used as well as the cipher key salt, iv or nonce. 14 | type Header struct { 15 | Cipher string // e.g. "aes" 16 | KeyLen KeyLen // e.g. 128 17 | Mode Mode // e.g. "gcm" 18 | Salt []byte 19 | IV []byte 20 | Nonce []byte 21 | } 22 | 23 | // NewHeader create a new Header checking the length of the 24 | // salt buffer against MinSaltLength. If the length of the 25 | // salt buffer is less than MinSaltLength it returns an error. 26 | func NewHeader(cipher string, keyLen KeyLen, mode Mode, salt []byte, iv []byte, nonce []byte) (Header, error) { 27 | cipher = strings.ToLower(cipher) 28 | mode = Mode(strings.ToLower(mode.String())) 29 | h := Header{cipher, keyLen, mode, salt, iv, nonce} 30 | if err := h.Valid(); err != nil { 31 | return Header{}, err 32 | } 33 | return h, nil 34 | } 35 | 36 | // Valid returns an error if the Header is not valid. 37 | func (h Header) Valid() error { 38 | if h.Cipher == "" { 39 | return errors.New("cipher required") 40 | } 41 | if h.KeyLen <= 0 { 42 | return errors.New("key length required") 43 | } 44 | if h.Mode == "" { 45 | return errors.New("mode required") 46 | } 47 | if len(h.Salt) < MinSaltLength { 48 | return fmt.Errorf("salt length %d < %d minimum", len(h.Salt), MinSaltLength) 49 | } 50 | if h.Nonce != nil && len(h.Nonce) < NonceLength { 51 | return fmt.Errorf("nonce length %d < %d minimum", len(h.Nonce), NonceLength) 52 | } 53 | return nil 54 | } 55 | 56 | // Name returns the name of the cipher in following 57 | // format "[cipher][key length]-[mode]" e.g. "aes192-ctr". 58 | func (h *Header) Name() string { 59 | return CipherName(h.Cipher, h.KeyLen, h.Mode) 60 | } 61 | 62 | // CipherName is a convenience function that returns the name, 63 | // key length, and mode of a cipher in the following format 64 | // "[cipher][key length]-[mode]" e.g. "aes192-ctr". 
65 | func CipherName(cipher string, keyLen KeyLen, mode Mode) string { 66 | return fmt.Sprintf("%s%s-%s", strings.ToLower(cipher), keyLen, strings.ToLower(mode.String())) 67 | } 68 | -------------------------------------------------------------------------------- /encryptor/crypto/header_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestHeader(t *testing.T) { 10 | const name = "aes256-gcm" 11 | s, err := MakeRand(SaltLength) 12 | assert.NoError(t, err) 13 | iv, err := MakeRand(8) 14 | assert.NoError(t, err) 15 | nonce, err := MakeRand(NonceLength) 16 | assert.NoError(t, err) 17 | type testCase struct { 18 | cipher string 19 | key KeyLen 20 | mode Mode 21 | salt []byte 22 | iv []byte 23 | nonce []byte 24 | name string 25 | err assert.ErrorAssertionFunc 26 | } 27 | tests := []testCase{ 28 | {"", 0, "", nil, nil, nil, "", assert.Error}, 29 | {"aes", 0, "", nil, nil, nil, "", assert.Error}, 30 | {"aes", Key256, "", nil, nil, nil, "", assert.Error}, 31 | {"aes", Key256, "gcm", nil, nil, nil, "", assert.Error}, 32 | {"aes", Key256, "gcm", []byte(""), nil, nil, "", assert.Error}, 33 | {"aes", Key256, "gcm", s, nil, []byte(""), "", assert.Error}, 34 | {"aes", Key256, "gcm", s, nil, nil, name, assert.NoError}, 35 | {"aes", Key256, "gcm", s, iv, nil, name, assert.NoError}, 36 | {"aes", Key256, "gcm", s, nil, nonce, name, assert.NoError}, 37 | {"aes", Key256, "gcm", s, iv, nonce, name, assert.NoError}, 38 | {"AES", Key256, "GCM", s, iv, nonce, name, assert.NoError}, 39 | } 40 | for _, test := range tests { 41 | h, err := NewHeader(test.cipher, test.key, test.mode, test.salt, test.iv, test.nonce) 42 | test.err(t, err) 43 | if err == nil { 44 | assert.NotNil(t, h) 45 | assert.Equal(t, test.name, h.Name()) 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- 
/encryptor/crypto/key.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto/sha512" 5 | 6 | "golang.org/x/crypto/pbkdf2" 7 | "golang.org/x/crypto/scrypt" 8 | ) 9 | 10 | // KeyLen is used to select 128, 192, or 256 bit keys. 11 | type KeyLen int 12 | 13 | // key lengths 14 | const ( 15 | Key128 KeyLen = 16 // 128 bit 16 | Key192 = 24 // 128 bit 17 | Key256 = 32 // 128 bit 18 | ) 19 | 20 | func (k KeyLen) String() string { 21 | switch k { 22 | case Key128: 23 | return "128" 24 | case Key192: 25 | return "192" 26 | case Key256: 27 | return "256" 28 | default: 29 | return "" 30 | } 31 | } 32 | 33 | // NewCipherKey generate a new cipher key of the appropriate key length. 34 | // Note: Currently this is hard-coded to 4096 key iterations. The thinking here is that 35 | // the strength of secret was determined externally and therefore it less important to 36 | // iterate (again) a large number of times. 1<<15 (or 32768) key iterations, seems to 37 | // be the current consensus for passwords in general (2020). 38 | func NewCipherKey(l KeyLen, secret, salt []byte) ([]byte, error) { 39 | const keyIterations = 4096 40 | return NewScryptCipherKey(l, keyIterations, secret, salt) 41 | } 42 | 43 | // NewPBKDF2CipherKey generate a new cipher key using pbkdf2. 44 | func NewPBKDF2CipherKey(l KeyLen, iterations int, secret, salt []byte) ([]byte, error) { 45 | // sha512, in addition to being more secure, should be faster on 64-bit systems 46 | return pbkdf2.Key(secret, salt, iterations, int(l), sha512.New), nil 47 | } 48 | 49 | // NewScryptCipherKey generate a new cipher key using scrypt. 
50 | func NewScryptCipherKey(l KeyLen, iterations int, secret, salt []byte) ([]byte, error) { 51 | return scrypt.Key(secret, salt, iterations, 8, 1, int(l)) 52 | } 53 | -------------------------------------------------------------------------------- /encryptor/crypto/key_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | const ( 10 | secret = "i-am-a-secret" 11 | saltLen = 8 12 | keyLen = 32 13 | iterations = 1024 14 | ) 15 | 16 | func TestNewCipherKeys(t *testing.T) { 17 | salt, err := MakeRand(saltLen) 18 | assert.NoError(t, err) 19 | assert.Len(t, salt, saltLen) 20 | sec := []byte(secret) 21 | cipher := func() ([]byte, error) { return NewCipherKey(keyLen, sec, salt) } 22 | pbkdf2 := func() ([]byte, error) { return NewPBKDF2CipherKey(keyLen, iterations, sec, salt) } 23 | scrypt := func() ([]byte, error) { return NewScryptCipherKey(keyLen, iterations, sec, salt) } 24 | test := func(newKey func() ([]byte, error)) { 25 | key1, err1 := newKey() 26 | key2, err2 := newKey() 27 | assert.NoError(t, err1) 28 | assert.NoError(t, err2) 29 | assert.Equal(t, key1, key2) 30 | } 31 | t.Run("NewCipherKey", func(t *testing.T) { test(cipher) }) 32 | t.Run("NewPBKDF2CipherKey", func(t *testing.T) { test(pbkdf2) }) 33 | t.Run("NewScryptCipherKey", func(t *testing.T) { test(scrypt) }) 34 | } 35 | -------------------------------------------------------------------------------- /encryptor/crypto/mode.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | // Mode are the supported modes for a cipher. 
4 | type Mode string 5 | 6 | func (m Mode) String() string { 7 | return string(m) 8 | } 9 | -------------------------------------------------------------------------------- /encryptor/crypto/rand.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | const ( 10 | // SaltLength is the default salt length. 11 | SaltLength = 32 12 | 13 | // NonceLength is the default nonce length. 14 | NonceLength = 12 15 | ) 16 | 17 | // MakeRand returns a buffer of size length filled with random bytes. 18 | func MakeRand(length uint) ([]byte, error) { 19 | // generate random bytes 20 | r := make([]byte, length) 21 | n, err := io.ReadFull(rand.Reader, r) 22 | if err != nil { 23 | return nil, err 24 | } else if uint(n) != length { 25 | return nil, fmt.Errorf("invalid buffer length %d != %d", n, length) 26 | } 27 | return r, nil 28 | } 29 | 30 | // MakeSalt returns random salt of size length. 31 | func MakeSalt() ([]byte, error) { 32 | return MakeRand(SaltLength) 33 | } 34 | 35 | // MakeNonce returns random nonce of size length. 
36 | func MakeNonce() ([]byte, error) { 37 | return MakeRand(NonceLength) 38 | } 39 | -------------------------------------------------------------------------------- /encryptor/crypto/rand_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestMakeRand(t *testing.T) { 10 | const testLength = 20 11 | buf, err := MakeRand(0) 12 | assert.NoError(t, err) 13 | assert.Len(t, buf, 0) 14 | buf, err = MakeRand(testLength) 15 | assert.NoError(t, err) 16 | assert.Len(t, buf, testLength) 17 | } 18 | -------------------------------------------------------------------------------- /encryptor/crypto/secret.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | // Secret is the interface that wraps a cipher keyLen and its id. 4 | type Secret interface { 5 | // ID return the id of the secret for tracking, or rollover etc. 6 | ID() string 7 | // Open returns a byte representation of the secret for encryption and decryption. 8 | Open() []byte 9 | } 10 | 11 | // A TextSecret provides a simple plaintext secret. 12 | type TextSecret string 13 | 14 | var _ Secret = (*TextSecret)(nil) 15 | 16 | // ID return the id of the secret for tracking, or rollover etc. 17 | func (s TextSecret) ID() string { 18 | return "text" 19 | } 20 | 21 | // Open returns a byte representation of the secret for encryption and decryption. 22 | func (s TextSecret) Open() []byte { 23 | return []byte(s) 24 | } 25 | 26 | // A ManagedSecret provides a simple plaintext secret alongside a unique id. 27 | type ManagedSecret struct { 28 | id string 29 | TextSecret 30 | } 31 | 32 | var _ Secret = (*ManagedSecret)(nil) 33 | 34 | // NewManagedSecret creates a new ManagedSecret with a secret with its corresponding id. 
35 | func NewManagedSecret(id, secret string) *ManagedSecret { 36 | return &ManagedSecret{id, TextSecret(secret)} 37 | } 38 | 39 | // ID return the id of the secret for tracking, or rollover etc. 40 | func (s ManagedSecret) ID() string { 41 | return s.id 42 | } 43 | 44 | // SecureSecret provides a unique id for a secret alongside an openSecret callback which 45 | // returns a byte representation of the secret for encryption and decryption on Open. 46 | // When SecureSecret calls openSecret it will pass a copy of itself as a Secret. This allows 47 | // for remote loading of the secret based on its id, or using a secure in-memory storage 48 | // solution for the secret like memguarded (https://github.com/n0rad/memguarded). 49 | type SecureSecret struct { 50 | id string 51 | open func(Secret) []byte 52 | } 53 | 54 | var _ Secret = (*SecureSecret)(nil) 55 | 56 | // NewSecureSecret creates a new SecureSecret with an id and an callback function which 57 | // returns a byte representation of the secret for encryption and decryption. 58 | func NewSecureSecret(id string, openSecret func(Secret) []byte) *SecureSecret { 59 | return &SecureSecret{id, openSecret} 60 | } 61 | 62 | // ID return the id of the secret for tracking, or rollover etc. 63 | func (s SecureSecret) ID() string { 64 | return s.id 65 | } 66 | 67 | // Open returns a byte representation of the secret for encryption and decryption. 68 | func (s SecureSecret) Open() []byte { 69 | return s.open(s) 70 | } 71 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This is a repository for Chestnut examples. Feel free to contribute. 
4 | 5 | ## The Goods 6 | 7 | - [sparse](sparse) - Provides an sparse encryption example 8 | 9 | - [hash](hash) - Provides a hash example 10 | 11 | - [keystore](keystore) - Provides an encrypted keystore example 12 | 13 | ## Running the examples 14 | 15 | ```shell 16 | # run any example with make 17 | $ make sparse 18 | ``` 19 | -------------------------------------------------------------------------------- /examples/hash/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/jrapoport/chestnut" 10 | "github.com/jrapoport/chestnut/encryptor/aes" 11 | "github.com/jrapoport/chestnut/encryptor/crypto" 12 | "github.com/jrapoport/chestnut/storage/nuts" 13 | ) 14 | 15 | func main() { 16 | path := filepath.Join(os.TempDir(), "hash") 17 | defer os.RemoveAll(path) 18 | 19 | // use nutsdb 20 | store := nuts.NewStore(path) 21 | 22 | // use a simple text secret 23 | textSecret := crypto.TextSecret("i-am-a-good-secret") 24 | 25 | // use AES256-CFB encryption 26 | opt := chestnut.WithAES(crypto.Key256, aes.CFB, textSecret) 27 | 28 | // open the storage chest with nutsdb and the aes encryptor 29 | cn := chestnut.NewChestnut(store, opt) 30 | if err := cn.Open(); err != nil { 31 | log.Panic(err) 32 | } 33 | 34 | // define an struct with a hash field 35 | type HashValue struct { 36 | // ClearString will not be hashed 37 | ClearString string 38 | // HashString with the 'hash' tag option 39 | HashString string `json:",hash"` 40 | } 41 | 42 | src := &HashValue{ 43 | ClearString: "I am a string", 44 | HashString: "I will be hashed", 45 | } 46 | 47 | // a key for the value 48 | namespace := "sparse-values" 49 | key := []byte("sparse-value-id") 50 | 51 | // save the struct with sparse encryption 52 | if err := cn.Save(namespace, key, src); err != nil { 53 | log.Panic(err) 54 | } 55 | 56 | // load the value 57 | err := cn.Load(namespace, key, src) 58 | 
if err != nil { 59 | log.Panic(err) 60 | } 61 | 62 | fmt.Println("clear field:", src.ClearString) 63 | fmt.Println("hashed field:", src.HashString) 64 | } 65 | -------------------------------------------------------------------------------- /examples/keystore/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "log" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/btcsuite/btcd/btcec/v2" 11 | "github.com/jrapoport/chestnut" 12 | "github.com/jrapoport/chestnut/encryptor/aes" 13 | "github.com/jrapoport/chestnut/encryptor/crypto" 14 | "github.com/jrapoport/chestnut/keystore" 15 | "github.com/jrapoport/chestnut/storage/nuts" 16 | ) 17 | 18 | func main() { 19 | 20 | path := filepath.Join(os.TempDir(), "keystore") 21 | defer os.RemoveAll(path) 22 | 23 | // use nutsdb 24 | store := nuts.NewStore(path) 25 | 26 | // use a simple text secret 27 | textSecret := crypto.TextSecret("i-am-a-good-secret") 28 | 29 | opts := []chestnut.ChestOption{ 30 | // use AES256-CFB encryption 31 | chestnut.WithAES(crypto.Key256, aes.CFB, textSecret), 32 | } 33 | 34 | // open the keystore with nutsdb and the aes encryptor 35 | ks := keystore.NewKeystore(store, opts...) 36 | if err := ks.Open(); err != nil { 37 | log.Panic(err) 38 | } 39 | 40 | // generate a new *btcec.PrivateKey 41 | pk1, err := btcec.NewPrivateKey() 42 | if err != nil { 43 | log.Panic(err) 44 | } 45 | 46 | // convert pk from *btcec.PrivateKey to ci.PrivKey. 
47 | privKey1 := keystore.BTCECPrivateKeyToPrivKey(pk1) 48 | 49 | // encrypt the private key and put in the keystore 50 | if err = ks.Put("my private key", privKey1); err != nil { 51 | log.Panic(err) 52 | } 53 | 54 | // get the private key from the store and decrypt it 55 | privKey2, err := ks.Get("my private key") 56 | if err != nil { 57 | log.Panic(err) 58 | } 59 | 60 | // convert the saved private key to *btcec.PrivateKey 61 | pk2 := keystore.PrivKeyToBTCECPrivateKey(privKey2) 62 | 63 | // compare the keys 64 | if bytes.Equal(pk1.Serialize(), pk2.Serialize()) { 65 | fmt.Println("private keys are equal") 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /examples/sparse/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/jrapoport/chestnut" 10 | "github.com/jrapoport/chestnut/encryptor/aes" 11 | "github.com/jrapoport/chestnut/encryptor/crypto" 12 | "github.com/jrapoport/chestnut/storage/nuts" 13 | ) 14 | 15 | func main() { 16 | path := filepath.Join(os.TempDir(), "sparse") 17 | defer os.RemoveAll(path) 18 | 19 | // use nutsdb 20 | store := nuts.NewStore(path) 21 | 22 | // use a simple text secret 23 | textSecret := crypto.TextSecret("i-am-a-good-secret") 24 | 25 | // use AES256-CFB encryption 26 | opt := chestnut.WithAES(crypto.Key256, aes.CFB, textSecret) 27 | 28 | // open the storage chest with nutsdb and the aes encryptor 29 | cn := chestnut.NewChestnut(store, opt) 30 | if err := cn.Open(); err != nil { 31 | log.Panic(err) 32 | } 33 | 34 | // define a sparse struct with a secure field 35 | type Sparse struct { 36 | // SecretString with the 'secure' tag option 37 | SecretString string `json:",secure"` 38 | // PublicString will not be encrypted 39 | PublicString string 40 | } 41 | 42 | src := &Sparse{ 43 | SecretString: "I am secret", 44 | PublicString: "I am visible", 45 | 
} 46 | 47 | // a key for the value 48 | namespace := "sparse-values" 49 | key := []byte("sparse-value-id") 50 | 51 | // save the value with sparse encryption 52 | if err := cn.Save(namespace, key, src); err != nil { 53 | log.Panic(err) 54 | } 55 | 56 | sparse := &Sparse{} 57 | 58 | // load a sparse copy 59 | err := cn.Sparse(namespace, key, sparse) 60 | if err != nil { 61 | log.Panic(err) 62 | } 63 | 64 | fmt.Println("-- sparse --") 65 | fmt.Println("secure field:", sparse.SecretString) 66 | fmt.Println("public field:", sparse.PublicString) 67 | 68 | load := &Sparse{} 69 | 70 | // load a full copy 71 | err = cn.Load(namespace, key, load) 72 | if err != nil { 73 | log.Panic(err) 74 | } 75 | 76 | fmt.Println("-- load --") 77 | fmt.Println("secure field:", load.SecretString) 78 | fmt.Println("public field:", load.PublicString) 79 | } 80 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/jrapoport/chestnut 2 | 3 | go 1.23 4 | 5 | toolchain go1.23.2 6 | 7 | require ( 8 | github.com/btcsuite/btcd/btcec/v2 v2.3.4 9 | github.com/google/uuid v1.6.0 10 | github.com/hashicorp/go-version v1.7.0 11 | github.com/ipfs/boxo v0.24.2 12 | github.com/json-iterator/go v1.1.12 13 | github.com/klauspost/compress v1.17.11 14 | github.com/libp2p/go-libp2p v0.37.0 15 | github.com/modern-go/reflect2 v1.0.2 16 | github.com/nutsdb/nutsdb v1.0.4 17 | github.com/sirupsen/logrus v1.9.3 18 | github.com/stretchr/testify v1.9.0 19 | go.etcd.io/bbolt v1.3.11 20 | go.uber.org/zap v1.27.0 21 | golang.org/x/crypto v0.28.0 22 | ) 23 | 24 | replace github.com/json-iterator/go => github.com/jrapoport/jsoniter v0.0.0-20241027074812-b8ebffc46abb 25 | 26 | require ( 27 | github.com/antlabs/stl v0.0.2 // indirect 28 | github.com/antlabs/timer v0.1.4 // indirect 29 | github.com/bwmarrin/snowflake v0.3.0 // indirect 30 | github.com/davecgh/go-spew v1.1.1 // indirect 31 
| github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect 32 | github.com/gofrs/flock v0.12.1 // indirect 33 | github.com/google/go-cmp v0.6.0 // indirect 34 | github.com/ipfs/go-log/v2 v2.5.1 // indirect 35 | github.com/mattn/go-isatty v0.0.20 // indirect 36 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 37 | github.com/pkg/errors v0.9.1 // indirect 38 | github.com/pmezard/go-difflib v1.0.0 // indirect 39 | github.com/rogpeppe/go-internal v1.13.1 // indirect 40 | github.com/tidwall/btree v1.7.0 // indirect 41 | github.com/xujiajun/mmap-go v1.0.1 // indirect 42 | github.com/xujiajun/utils v0.0.0-20220904132955-5f7c5b914235 // indirect 43 | go.uber.org/multierr v1.11.0 // indirect 44 | golang.org/x/sys v0.26.0 // indirect 45 | google.golang.org/protobuf v1.35.1 // indirect 46 | gopkg.in/yaml.v3 v3.0.1 // indirect 47 | ) 48 | -------------------------------------------------------------------------------- /keystore/README.md: -------------------------------------------------------------------------------- 1 | # Keystore 2 | 3 | Keystore is an IPFS compliant keystore built on Chestnut. It implements an IPFS keystore interface, allowing it to be used natively with many existing IPFS implementations, and tools. 4 | 5 | We recommend using AES256-CTR for encryption based in part on this 6 | [helpful analysis](https://www.highgo.ca/2019/08/08/the-difference-in-five-modes-in-the-aes-encryption-algorithm/) 7 | of database encryption approaches and trade-offs from Shawn Wang, PostgreSQL Database Core. 8 | 9 | For a detailed example on importing and using the Keystore, please check out the [Keystore](../examples/keystore) 10 | example under the `examples` folder. 11 | 12 | ### IMPORTANT! 
13 | 14 | ```go 15 | package main 16 | 17 | import ( 18 | "github.com/ipfs/go-ipfs/keystore" 19 | "github.com/libp2p/go-libp2p/core/crypto" 20 | ) 21 | ``` 22 | 23 | Please **make sure** you import 24 | [go-ipfs](github.com/ipfs/go-ipfs) and [go-libp2p-core](https://github.com/libp2p/go-libp2p-core/), 25 | and are **NOT** importing [go-ipfs-keystore](github.com/ipfs/go-ipfs-keystore) and 26 | [go-libp2p-crypto](github.com/libp2p/go-libp2p-crypto). Those repos are **DEPRECATED**, 27 | out of date, archived, etc. This will save you time and sanity. 28 | -------------------------------------------------------------------------------- /keystore/keystore.go: -------------------------------------------------------------------------------- 1 | package keystore 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/ipfs/boxo/keystore" 7 | "github.com/jrapoport/chestnut" 8 | "github.com/jrapoport/chestnut/log" 9 | "github.com/jrapoport/chestnut/storage" 10 | ci "github.com/libp2p/go-libp2p/core/crypto" 11 | ) 12 | 13 | const ( 14 | namespace = "keys" 15 | logName = "keystore" 16 | ) 17 | 18 | // Keystore is used to manage an encrypted IPFS-compliant keystore. 19 | type Keystore struct { 20 | cn *chestnut.Chestnut 21 | store storage.Storage 22 | log log.Logger 23 | } 24 | 25 | var _ keystore.Keystore = (*Keystore)(nil) 26 | 27 | // NewKeystore is used to create a new chestnut ipfs-compliant keystore. 28 | // Suggest using this with AES256-CTR encryption based in part 29 | // on this helpful analysis from Shawn Wang, PostgreSQL Database Core: 30 | // https://www.highgo.ca/2019/08/08/the-difference-in-five-modes-in-the-aes-encryption-algorithm/ 31 | func NewKeystore(store storage.Storage, opt ...chestnut.ChestOption) *Keystore { 32 | // keystore requires that overwrites are forbidden 33 | opt = append(opt, chestnut.OverwritesForbidden()) 34 | cn := chestnut.NewChestnut(store, opt...) 
35 | logger := log.Named(cn.Logger(), logName) 36 | ks := &Keystore{cn, store, logger} 37 | if err := ks.validConfig(); err != nil { 38 | logger.Panic(err) 39 | return nil 40 | } 41 | return ks 42 | } 43 | 44 | func (ks *Keystore) validConfig() error { 45 | if ks.store == nil { 46 | return errors.New("store required") 47 | } 48 | return nil 49 | } 50 | 51 | // Open the Keystore 52 | func (ks *Keystore) Open() error { 53 | if err := ks.validConfig(); err != nil { 54 | return err 55 | } 56 | return ks.cn.Open() 57 | } 58 | 59 | // Has returns whether or not a key exists in the Keystore 60 | func (ks *Keystore) Has(s string) (bool, error) { 61 | return ks.cn.Has(namespace, []byte(s)) 62 | } 63 | 64 | // Put stores a key in the Keystore, if a key with 65 | // the same name already exists, returns ErrKeyExists 66 | func (ks *Keystore) Put(s string, key ci.PrivKey) error { 67 | if key == nil { 68 | return errors.New("invalid key") 69 | } 70 | data, err := ci.MarshalPrivateKey(key) 71 | if err != nil { 72 | return err 73 | } 74 | err = ks.cn.Put(namespace, []byte(s), data) 75 | if errors.Is(err, chestnut.ErrForbidden) { 76 | return keystore.ErrKeyExists 77 | } 78 | return err 79 | } 80 | 81 | // Get retrieves a key from the Keystore if it 82 | // exists, and returns ErrNoSuchKey otherwise. 
83 | func (ks *Keystore) Get(s string) (ci.PrivKey, error) { 84 | data, err := ks.cn.Get(namespace, []byte(s)) 85 | if err != nil { 86 | return nil, keystore.ErrNoSuchKey 87 | } 88 | return ci.UnmarshalPrivateKey(data) 89 | } 90 | 91 | // Delete removes a key from the Keystore 92 | func (ks *Keystore) Delete(s string) error { 93 | return ks.cn.Delete(namespace, []byte(s)) 94 | } 95 | 96 | // List returns a list of key identifier 97 | func (ks *Keystore) List() ([]string, error) { 98 | list, err := ks.cn.List(namespace) 99 | if err != nil { 100 | return nil, err 101 | } 102 | keys := make([]string, len(list)) 103 | for i, key := range list { 104 | keys[i] = string(key) 105 | } 106 | return keys, nil 107 | } 108 | 109 | // Export the Keystore 110 | func (ks *Keystore) Export(path string) error { 111 | return ks.cn.Export(path) 112 | } 113 | 114 | // Close the Keystore 115 | func (ks *Keystore) Close() error { 116 | return ks.cn.Close() 117 | } 118 | -------------------------------------------------------------------------------- /keystore/keystore_test.go: -------------------------------------------------------------------------------- 1 | package keystore 2 | 3 | import ( 4 | "log" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/google/uuid" 9 | "github.com/jrapoport/chestnut" 10 | "github.com/jrapoport/chestnut/encryptor/aes" 11 | "github.com/jrapoport/chestnut/encryptor/crypto" 12 | "github.com/jrapoport/chestnut/storage" 13 | "github.com/jrapoport/chestnut/storage/nuts" 14 | ci "github.com/libp2p/go-libp2p/core/crypto" 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/suite" 17 | ) 18 | 19 | var ( 20 | testName = uuid.New().String() 21 | textSecret = crypto.TextSecret("i-am-a-good-secret") 22 | encryptorOpt = chestnut.WithAES(crypto.Key256, aes.CFB, textSecret) 23 | privateKey = func() ci.PrivKey { 24 | pk, _, err := ci.GenerateKeyPair(ci.ECDSA, 512) 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | return pk 29 | }() 30 | ) 31 | 32 | 
type testCase struct { 33 | name string 34 | key ci.PrivKey 35 | err assert.ErrorAssertionFunc 36 | exists bool 37 | } 38 | 39 | var tests = []testCase{ 40 | {"", nil, assert.Error, false}, 41 | {"", nil, assert.Error, false}, 42 | {"f", nil, assert.Error, false}, 43 | {"g", privateKey, assert.NoError, true}, 44 | {"h", privateKey, assert.NoError, true}, 45 | {"i/i", privateKey, assert.NoError, true}, 46 | {".j", privateKey, assert.NoError, true}, 47 | {testName, privateKey, assert.NoError, true}, 48 | } 49 | 50 | var testCaseNotFound = testCase{"not-found", nil, assert.Error, false} 51 | 52 | type KeystoreTestSuite struct { 53 | suite.Suite 54 | keystore *Keystore 55 | } 56 | 57 | func newNutsDBStore(t *testing.T) storage.Storage { 58 | path := t.TempDir() 59 | store := nuts.NewStore(path) 60 | assert.NotNil(t, store) 61 | return store 62 | } 63 | 64 | func TestKeystore(t *testing.T) { 65 | suite.Run(t, new(KeystoreTestSuite)) 66 | } 67 | 68 | func (ts *KeystoreTestSuite) SetupTest() { 69 | store := newNutsDBStore(ts.T()) 70 | ts.keystore = NewKeystore(store, encryptorOpt) 71 | ts.NotNil(ts.keystore) 72 | err := ts.keystore.Open() 73 | ts.NoError(err) 74 | } 75 | 76 | func (ts *KeystoreTestSuite) TearDownTest() { 77 | err := ts.keystore.Close() 78 | ts.NoError(err) 79 | } 80 | 81 | func (ts *KeystoreTestSuite) BeforeTest(_, testName string) { 82 | switch testName { 83 | case "TestKeystore_Encryptor", 84 | "TestKeystore_Put", 85 | "TestKeystore_List": 86 | break 87 | default: 88 | ts.TestKeystore_Put() 89 | } 90 | } 91 | 92 | func TestInvalidConfig(t *testing.T) { 93 | assert.Panics(t, func() { 94 | NewKeystore(nil, encryptorOpt) 95 | }) 96 | } 97 | 98 | func (ts *KeystoreTestSuite) TestKeystore_Encryptor() { 99 | err := ts.keystore.Put(testName, privateKey) 100 | ts.NoError(err) 101 | pk, err := ts.keystore.Get(testName) 102 | ts.NotNil(pk) 103 | ts.NoError(err) 104 | ts.Equal(privateKey.Type().String(), pk.Type().String()) 105 | } 106 | 107 | func (ts 
*KeystoreTestSuite) TestKeystore_Put() { 108 | for i, test := range tests { 109 | err := ts.keystore.Put(test.name, test.key) 110 | test.err(ts.T(), err, "%d test name: %s", i, test.name) 111 | } 112 | err := ts.keystore.Put(testName, privateKey) 113 | ts.Error(err) 114 | } 115 | 116 | func (ts *KeystoreTestSuite) TestKeystore_Get() { 117 | getTests := append(tests, testCaseNotFound) 118 | for i, test := range getTests { 119 | key, err := ts.keystore.Get(test.name) 120 | test.err(ts.T(), err, "%d test name: %s", i, test.name) 121 | ts.Equal(test.key, key, "%d test name: %s", i, test.name) 122 | } 123 | } 124 | 125 | func (ts *KeystoreTestSuite) TestKeystore_Has() { 126 | for _, test := range tests { 127 | has, _ := ts.keystore.Has(test.name) 128 | ts.Equal(test.exists, has) 129 | } 130 | } 131 | 132 | func (ts *KeystoreTestSuite) TestKeystore_List() { 133 | const listLen = 100 134 | list := make([]string, listLen) 135 | for i := 0; i < listLen; i++ { 136 | list[i] = uuid.New().String() 137 | err := ts.keystore.Put(list[i], privateKey) 138 | ts.NoError(err) 139 | } 140 | keys, err := ts.keystore.List() 141 | ts.NoError(err) 142 | ts.Len(keys, listLen) 143 | // put both lists in the same order so we can compare them 144 | sort.Strings(list) 145 | sort.Strings(keys) 146 | ts.Equal(list, keys) 147 | } 148 | 149 | func (ts *KeystoreTestSuite) TestKeystore_Delete() { 150 | for i, test := range tests { 151 | if test.exists == false { 152 | continue 153 | } 154 | err := ts.keystore.Delete(test.name) 155 | test.err(ts.T(), err, "%d test key: %s", i, test.key) 156 | } 157 | } 158 | 159 | func (ts *KeystoreTestSuite) TestKeystore_Export() { 160 | err := ts.keystore.Export(ts.T().TempDir()) 161 | ts.NoError(err) 162 | } 163 | 164 | func TestKeystore_OpenErr(t *testing.T) { 165 | ks := &Keystore{} 166 | err := ks.Open() 167 | assert.Error(t, err) 168 | } 169 | -------------------------------------------------------------------------------- /keystore/keyutils.go: 
-------------------------------------------------------------------------------- 1 | package keystore 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "crypto/rsa" 6 | "log" 7 | 8 | "github.com/btcsuite/btcd/btcec/v2" 9 | "github.com/libp2p/go-libp2p/core/crypto" 10 | "golang.org/x/crypto/ed25519" 11 | ) 12 | 13 | // PrivKeyToRSAPrivateKey converts libp2p/go-libp2p/core/crypto 14 | // private keys to standard library rsa private keys. 15 | func PrivKeyToRSAPrivateKey(privKey crypto.PrivKey) *rsa.PrivateKey { 16 | key, err := crypto.PrivKeyToStdKey(privKey) 17 | if err != nil { 18 | log.Panic(err) 19 | return nil 20 | } 21 | if pk, ok := key.(*rsa.PrivateKey); ok { 22 | return pk 23 | } 24 | return nil 25 | } 26 | 27 | // RSAPrivateKeyToPrivKey converts standard library rsa 28 | // private keys to libp2p/go-libp2p/core/crypto private keys. 29 | func RSAPrivateKeyToPrivKey(privateKey *rsa.PrivateKey) crypto.PrivKey { 30 | // because we are strongly typing the interface it will never fail 31 | pk, _, _ := crypto.KeyPairFromStdKey(privateKey) 32 | return pk 33 | } 34 | 35 | // PrivKeyToECDSAPrivateKey converts libp2p/go-libp2p/core/crypto 36 | // private keys to new standard library ecdsa private keys. 37 | func PrivKeyToECDSAPrivateKey(privKey crypto.PrivKey) *ecdsa.PrivateKey { 38 | key, err := crypto.PrivKeyToStdKey(privKey) 39 | if err != nil { 40 | log.Panic(err) 41 | return nil 42 | } 43 | if pk, ok := key.(*ecdsa.PrivateKey); ok { 44 | return pk 45 | } 46 | return nil 47 | } 48 | 49 | // ECDSAPrivateKeyToPrivKey converts standard library ecdsa 50 | // private keys to libp2p/go-libp2p/core/crypto private keys. 51 | func ECDSAPrivateKeyToPrivKey(privateKey *ecdsa.PrivateKey) crypto.PrivKey { 52 | // because we are strongly typing the interface it will never fail 53 | pk, _, _ := crypto.KeyPairFromStdKey(privateKey) 54 | return pk 55 | } 56 | 57 | // PrivKeyToEd25519PrivateKey converts libp2p/go-libp2p/core/crypto 58 | // private keys to ed25519 private keys. 
59 | func PrivKeyToEd25519PrivateKey(privKey crypto.PrivKey) *ed25519.PrivateKey { 60 | key, err := crypto.PrivKeyToStdKey(privKey) 61 | if err != nil { 62 | log.Panic(err) 63 | return nil 64 | } 65 | if pk, ok := key.(*ed25519.PrivateKey); ok { 66 | return pk 67 | } 68 | return nil 69 | } 70 | 71 | // Ed25519PrivateKeyToPrivKey converts ed25519 private keys 72 | // to libp2p/go-libp2p/core/crypto private keys. 73 | func Ed25519PrivateKeyToPrivKey(privateKey *ed25519.PrivateKey) crypto.PrivKey { 74 | // because we are strongly typing the interface it will never fail 75 | pk, _, _ := crypto.KeyPairFromStdKey(privateKey) 76 | return pk 77 | } 78 | 79 | // PrivKeyToBTCECPrivateKey converts libp2p/go-libp2p/core/crypto 80 | // private keys to standard library btcec (and secp256k1) private keys. 81 | // Internally equivalent to (*btcec.PrivateKey)(privKey.(*crypto.Secp256k1PrivateKey)). 82 | func PrivKeyToBTCECPrivateKey(privKey crypto.PrivKey) *btcec.PrivateKey { 83 | key, err := crypto.PrivKeyToStdKey(privKey) 84 | if err != nil { 85 | log.Panic(err) 86 | return nil 87 | } 88 | if pk, ok := key.(*crypto.Secp256k1PrivateKey); ok { 89 | return (*btcec.PrivateKey)(pk) 90 | } 91 | return nil 92 | } 93 | 94 | // BTCECPrivateKeyToPrivKey converts standard library btcec (and secp256k1) 95 | // private keys to libp2p/go-libp2p/core/crypto private keys. Internally 96 | // equivalent to (*crypto.Secp256k1PrivateKey)(privateKey). 
97 | func BTCECPrivateKeyToPrivKey(privateKey *btcec.PrivateKey) crypto.PrivKey { 98 | // because we are strongly typing the interface it will never fail 99 | pk, _, _ := crypto.KeyPairFromStdKey(privateKey) 100 | return pk 101 | } 102 | -------------------------------------------------------------------------------- /keystore/keyutils_test.go: -------------------------------------------------------------------------------- 1 | package keystore 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "crypto/elliptic" 6 | "crypto/rand" 7 | "crypto/rsa" 8 | "testing" 9 | 10 | "github.com/btcsuite/btcd/btcec/v2" 11 | "github.com/libp2p/go-libp2p/core/crypto" 12 | "github.com/stretchr/testify/assert" 13 | "golang.org/x/crypto/ed25519" 14 | ) 15 | 16 | func testPrivKeyToPrivateKey(t *testing.T, pk1 interface{}, conv func() interface{}) { 17 | assert.NotNil(t, pk1) 18 | stdKey := conv() 19 | assert.NotNil(t, stdKey) 20 | pk2, _, err := crypto.KeyPairFromStdKey(stdKey) 21 | assert.NoError(t, err) 22 | assert.Equal(t, pk1, pk2) 23 | } 24 | 25 | func testPrivateKeyToPrivKey(t *testing.T, pk1 interface{}, conv func() crypto.PrivKey) { 26 | assert.NotNil(t, pk1) 27 | privKey := conv() 28 | assert.NotNil(t, privKey) 29 | pk2, err := crypto.PrivKeyToStdKey(privKey) 30 | assert.NoError(t, err) 31 | assert.Equal(t, pk1, pk2) 32 | } 33 | 34 | func TestPrivKeyToRSAPrivateKey(t *testing.T) { 35 | privKey, _, err := crypto.GenerateRSAKeyPair(2048, rand.Reader) 36 | assert.NoError(t, err) 37 | testPrivKeyToPrivateKey(t, privKey, func() interface{} { 38 | return PrivKeyToRSAPrivateKey(privKey) 39 | }) 40 | assert.Panics(t, func() { 41 | _ = PrivKeyToRSAPrivateKey(nil) 42 | }) 43 | } 44 | 45 | func TestPrivKeyToECDSAPrivateKey(t *testing.T) { 46 | privKey, _, err := crypto.GenerateECDSAKeyPair(rand.Reader) 47 | assert.NoError(t, err) 48 | testPrivKeyToPrivateKey(t, privKey, func() interface{} { 49 | return PrivKeyToECDSAPrivateKey(privKey) 50 | }) 51 | assert.Panics(t, func() { 52 | _ = 
PrivKeyToECDSAPrivateKey(nil) 53 | }) 54 | } 55 | 56 | func TestPrivKeyToEd25519PrivateKey(t *testing.T) { 57 | privKey, _, err := crypto.GenerateEd25519Key(rand.Reader) 58 | assert.NoError(t, err) 59 | testPrivKeyToPrivateKey(t, privKey, func() interface{} { 60 | return PrivKeyToEd25519PrivateKey(privKey) 61 | }) 62 | assert.Panics(t, func() { 63 | _ = PrivKeyToEd25519PrivateKey(nil) 64 | }) 65 | } 66 | 67 | func TestPrivKeyToBTCECPrivateKey(t *testing.T) { 68 | privKey, _, err := crypto.GenerateSecp256k1Key(rand.Reader) 69 | assert.NoError(t, err) 70 | testPrivKeyToPrivateKey(t, privKey, func() interface{} { 71 | return PrivKeyToBTCECPrivateKey(privKey) 72 | }) 73 | assert.Panics(t, func() { 74 | _ = PrivKeyToBTCECPrivateKey(nil) 75 | }) 76 | } 77 | 78 | func TestRSAPrivateKeyToPrivKey(t *testing.T) { 79 | rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) 80 | assert.NoError(t, err) 81 | testPrivateKeyToPrivKey(t, rsaKey, func() crypto.PrivKey { 82 | return RSAPrivateKeyToPrivKey(rsaKey) 83 | }) 84 | assert.Panics(t, func() { 85 | _ = RSAPrivateKeyToPrivKey(nil) 86 | }) 87 | } 88 | 89 | func TestECDSAPrivateKeyToPrivKey(t *testing.T) { 90 | ecdsaKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) 91 | assert.NoError(t, err) 92 | testPrivateKeyToPrivKey(t, ecdsaKey, func() crypto.PrivKey { 93 | return ECDSAPrivateKeyToPrivKey(ecdsaKey) 94 | }) 95 | assert.Panics(t, func() { 96 | _ = ECDSAPrivateKeyToPrivKey(nil) 97 | }) 98 | } 99 | 100 | func TestEd25519PrivateKeyToPrivKey(t *testing.T) { 101 | _, edKey, err := ed25519.GenerateKey(rand.Reader) 102 | assert.NoError(t, err) 103 | testPrivateKeyToPrivKey(t, &edKey, func() crypto.PrivKey { 104 | return Ed25519PrivateKeyToPrivKey(&edKey) 105 | }) 106 | assert.Panics(t, func() { 107 | _ = Ed25519PrivateKeyToPrivKey(nil) 108 | }) 109 | } 110 | 111 | func TestBTCECPrivateKeyToPrivKey(t *testing.T) { 112 | btcecKey, err := btcec.NewPrivateKey() 113 | key := (*crypto.Secp256k1PrivateKey)(btcecKey) 114 | 
assert.NoError(t, err) 115 | testPrivateKeyToPrivKey(t, key, func() crypto.PrivKey { 116 | return BTCECPrivateKeyToPrivKey(btcecKey) 117 | }) 118 | assert.Panics(t, func() { 119 | _ = BTCECPrivateKeyToPrivKey(nil) 120 | }) 121 | } 122 | -------------------------------------------------------------------------------- /log/level.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | // A Level is a logging priority. Higher levels are more important. 4 | // This is here as a convenience when using the various log options. 5 | type Level int 6 | 7 | const ( 8 | // DebugLevel logs are typically voluminous, 9 | // and are usually disabled in production. 10 | DebugLevel Level = iota - 1 11 | 12 | // InfoLevel is the default logging priority. 13 | InfoLevel 14 | 15 | // WarnLevel logs are more important than Info, 16 | // but don't need individual human review. 17 | WarnLevel 18 | 19 | // ErrorLevel logs are high-priority. If an application runs 20 | // smoothly, it shouldn't generate any error-level logs. 21 | ErrorLevel 22 | 23 | // PanicLevel logs a message, then panics. 24 | PanicLevel 25 | 26 | // FatalLevel logs a message, then calls os.Exit(1). 
27 | FatalLevel 28 | ) 29 | -------------------------------------------------------------------------------- /log/level_test.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestLevel(t *testing.T) { 10 | levels := []Level{ 11 | DebugLevel, 12 | InfoLevel, 13 | WarnLevel, 14 | ErrorLevel, 15 | PanicLevel, 16 | FatalLevel, 17 | } 18 | type NewLoggerFunc func(Level) Logger 19 | tests := []struct { 20 | name string 21 | logFn NewLoggerFunc 22 | }{ 23 | {"logrus", NewLogrusLoggerWithLevel}, 24 | {"std", NewStdLoggerWithLevel}, 25 | {"zap", NewZapLoggerWithLevel}, 26 | } 27 | for _, level := range levels { 28 | for _, test := range tests { 29 | logger := test.logFn(level) 30 | // debug 31 | logger.Debug(test.name, " ", "debug") 32 | logger.Debugf("%s %s", test.name, "debug") 33 | // info 34 | logger.Info(test.name, " ", "info") 35 | logger.Infof("%s %s", test.name, "info") 36 | // warn 37 | logger.Warn(test.name, " ", "warn") 38 | logger.Warnf("%s %s", test.name, "warn") 39 | // error 40 | logger.Error(test.name, " ", "error") 41 | logger.Errorf("%s %s", test.name, "error") 42 | // panic 43 | assert.Panics(t, func() { 44 | logger.Panic(test.name, " ", "panic") 45 | }) 46 | assert.Panics(t, func() { 47 | logger.Panicf("%s %s", test.name, "panic") 48 | }) 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /log/logger.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | // Log is the same as the default standard logger from "log". 4 | var Log = NewStdLoggerWithLevel(PanicLevel) 5 | 6 | // Logger is a generic logger interface. 7 | type Logger interface { 8 | // Debug logs args when the logger level is debug. 9 | Debug(v ...interface{}) 10 | 11 | // Debugf formats args and logs the result when the logger level is debug. 
12 | Debugf(format string, v ...interface{}) 13 | 14 | // Info logs args when the logger level is info. 15 | Info(args ...interface{}) 16 | 17 | // Infof formats args and logs the result when the logger level is info. 18 | Infof(format string, v ...interface{}) 19 | 20 | // Warn logs args when the logger level is warn. 21 | Warn(v ...interface{}) 22 | 23 | // Warnf formats args and logs the result when the logger level is warn. 24 | Warnf(format string, v ...interface{}) 25 | 26 | // Error logs args when the logger level is error. 27 | Error(v ...interface{}) 28 | 29 | // Errorf formats args and logs the result when the logger level is debug. 30 | Errorf(format string, v ...interface{}) 31 | 32 | // Panic logs args on panic. 33 | Panic(v ...interface{}) 34 | 35 | // Panicf formats args and logs the result on panic. 36 | Panicf(format string, v ...interface{}) 37 | 38 | // Fatal logs args when the error is fatal. 39 | Fatal(v ...interface{}) 40 | 41 | // Fatalf formats args and logs the result when the error is fatal. 42 | Fatalf(format string, v ...interface{}) 43 | } 44 | -------------------------------------------------------------------------------- /log/logrus.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import "github.com/sirupsen/logrus" 4 | 5 | var _ Logger = (*logrus.Logger)(nil) 6 | var _ Logger = (*logrus.Entry)(nil) 7 | 8 | // NewLogrusLoggerWithLevel returns a new production logrus logger with the log level. 9 | func NewLogrusLoggerWithLevel(lvl Level) Logger { 10 | l := logrus.New() 11 | l.SetLevel(levelToLogrusLevel(lvl)) 12 | return l.WithContext(nil) 13 | } 14 | 15 | // NOTE: for logrus panic is a higher level than fatal. 
16 | func levelToLogrusLevel(lvl Level) logrus.Level { 17 | switch lvl { 18 | case DebugLevel: 19 | return logrus.DebugLevel 20 | case InfoLevel: 21 | return logrus.InfoLevel 22 | case WarnLevel: 23 | return logrus.WarnLevel 24 | case ErrorLevel: 25 | return logrus.ErrorLevel 26 | case PanicLevel: 27 | return logrus.PanicLevel 28 | case FatalLevel: 29 | return logrus.FatalLevel 30 | default: 31 | return logrus.InfoLevel 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /log/named.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/sirupsen/logrus" 7 | "go.uber.org/zap" 8 | ) 9 | 10 | // logrusField matches zap 11 | const logrusField = "logger" 12 | 13 | // Named adds a name string to the logger. How the name is added is 14 | // logger specific i.e. a logrus field or std logger prefix, etc. 15 | func Named(logger interface{}, name string) Logger { 16 | switch l := logger.(type) { 17 | case *logrus.Logger: 18 | return l.WithField(logrusField, name) 19 | case *logrus.Entry: 20 | return l.WithField(logrusField, name) 21 | case *log.Logger: 22 | l.SetPrefix(name + " ") 23 | return &stdLogger{l, InfoLevel} 24 | case *stdLogger: 25 | l.SetPrefix(name + " ") 26 | return l 27 | case *zap.SugaredLogger: 28 | return l.Named(name) 29 | case *zap.Logger: 30 | return l.Sugar().Named(name) 31 | } 32 | return nil 33 | } 34 | -------------------------------------------------------------------------------- /log/named_test.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "testing" 7 | 8 | "github.com/sirupsen/logrus" 9 | "github.com/stretchr/testify/assert" 10 | "go.uber.org/zap" 11 | ) 12 | 13 | func TestWrapper(t *testing.T) { 14 | const ( 15 | testName = "test" 16 | emptyName = "" 17 | ) 18 | tests := []struct { 19 | logger interface{} 20 | name 
string 21 | assertNil assert.ValueAssertionFunc 22 | }{ 23 | {nil, emptyName, assert.Nil}, 24 | {logrus.New(), emptyName, assert.NotNil}, 25 | {logrus.New(), testName, assert.NotNil}, 26 | {logrus.New().WithContext(nil), emptyName, assert.NotNil}, 27 | {logrus.New().WithContext(nil), testName, assert.NotNil}, 28 | {NewLogrusLoggerWithLevel(ErrorLevel), emptyName, assert.NotNil}, 29 | {NewLogrusLoggerWithLevel(ErrorLevel), testName, assert.NotNil}, 30 | {log.New(os.Stderr, "", 0), emptyName, assert.NotNil}, 31 | {log.New(os.Stderr, "", 0), testName, assert.NotNil}, 32 | {NewStdLoggerWithLevel(ErrorLevel), emptyName, assert.NotNil}, 33 | {NewStdLoggerWithLevel(ErrorLevel), testName, assert.NotNil}, 34 | {zap.NewExample(), emptyName, assert.NotNil}, 35 | {zap.NewExample(), testName, assert.NotNil}, 36 | {zap.NewExample().Sugar(), emptyName, assert.NotNil}, 37 | {zap.NewExample().Sugar(), testName, assert.NotNil}, 38 | {NewZapLoggerWithLevel(ErrorLevel), emptyName, assert.NotNil}, 39 | {NewZapLoggerWithLevel(ErrorLevel), testName, assert.NotNil}, 40 | } 41 | 42 | for _, test := range tests { 43 | logger := Named(test.logger, "name") 44 | test.assertNil(t, logger) 45 | if logger != nil { 46 | _, ok := logger.(Logger) 47 | assert.True(t, ok) 48 | // error 49 | logger.Error(testName) 50 | logger.Errorf("%s", testName) 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /log/std.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "os" 7 | ) 8 | 9 | // stdLogger is a wrapper of standard log. 10 | type stdLogger struct { 11 | *log.Logger 12 | level Level 13 | } 14 | 15 | var _ Logger = (*stdLogger)(nil) 16 | 17 | // NewStdLoggerWithLevel is a stderr logger with the log level. 
18 | func NewStdLoggerWithLevel(lvl Level) Logger { 19 | return NewStdLogger(lvl, os.Stderr, "", log.LstdFlags) 20 | } 21 | 22 | // NewStdLogger returns a new standard logger with the log level. 23 | func NewStdLogger(lvl Level, out io.Writer, prefix string, flag int) Logger { 24 | return &stdLogger{log.New(out, prefix, flag), lvl} 25 | } 26 | 27 | // Debug logs args when the logger level is debug. 28 | func (l *stdLogger) Debug(v ...interface{}) { 29 | if l.level > DebugLevel { 30 | return 31 | } 32 | l.Print(v...) 33 | } 34 | 35 | // Debugf formats args and logs the result when the logger level is debug. 36 | func (l *stdLogger) Debugf(format string, v ...interface{}) { 37 | if l.level > DebugLevel { 38 | return 39 | } 40 | l.Printf(format, v...) 41 | } 42 | 43 | // Info logs args when the logger level is info. 44 | func (l *stdLogger) Info(v ...interface{}) { 45 | if l.level > InfoLevel { 46 | return 47 | } 48 | l.Print(v...) 49 | } 50 | 51 | // Infof formats args and logs the result when the logger level is info. 52 | func (l *stdLogger) Infof(format string, v ...interface{}) { 53 | if l.level > InfoLevel { 54 | return 55 | } 56 | l.Printf(format, v...) 57 | } 58 | 59 | // Warn logs args when the logger level is warn. 60 | func (l *stdLogger) Warn(v ...interface{}) { 61 | if l.level > WarnLevel { 62 | return 63 | } 64 | l.Print(v...) 65 | } 66 | 67 | // Warnf formats args and logs the result when the logger level is warn. 68 | func (l *stdLogger) Warnf(format string, v ...interface{}) { 69 | if l.level > WarnLevel { 70 | return 71 | } 72 | l.Printf(format, v...) 73 | } 74 | 75 | // Error logs args when the logger level is error. 76 | func (l *stdLogger) Error(v ...interface{}) { 77 | if l.level > ErrorLevel { 78 | return 79 | } 80 | l.Print(v...) 81 | } 82 | 83 | // Errorf formats args and logs the result when the logger level is debug. 
84 | func (l *stdLogger) Errorf(format string, v ...interface{}) { 85 | if l.level > ErrorLevel { 86 | return 87 | } 88 | l.Printf(format, v...) 89 | } 90 | 91 | // Panic logs args on panic. 92 | func (l *stdLogger) Panic(v ...interface{}) { 93 | l.Logger.Panic(v...) 94 | } 95 | 96 | // Panicf formats args and logs the result on panic. 97 | func (l *stdLogger) Panicf(format string, v ...interface{}) { 98 | l.Logger.Panicf(format, v...) 99 | } 100 | 101 | // Fatal logs args when the error is fatal. 102 | func (l *stdLogger) Fatal(v ...interface{}) { 103 | l.Logger.Fatal(v...) 104 | } 105 | 106 | // Fatalf formats args and logs the result when the error is fatal. 107 | func (l *stdLogger) Fatalf(format string, v ...interface{}) { 108 | l.Logger.Fatalf(format, v...) 109 | } 110 | -------------------------------------------------------------------------------- /log/zap.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "log" 5 | 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | ) 9 | 10 | var _ Logger = (*zap.SugaredLogger)(nil) 11 | 12 | // NewZapLoggerWithLevel returns a new production zap logger with the log level. 
13 | func NewZapLoggerWithLevel(lvl Level) Logger { 14 | zlvl := levelToZapLevel(lvl) 15 | opt := zap.IncreaseLevel(zlvl) 16 | l, err := zap.NewProduction(opt) 17 | if err != nil { 18 | log.Fatal(err.Error()) 19 | return nil 20 | } 21 | return l.Sugar() 22 | } 23 | 24 | func levelToZapLevel(lvl Level) zapcore.Level { 25 | switch lvl { 26 | case DebugLevel: 27 | return zapcore.DebugLevel 28 | case InfoLevel: 29 | return zapcore.InfoLevel 30 | case WarnLevel: 31 | return zapcore.WarnLevel 32 | case ErrorLevel: 33 | return zapcore.ErrorLevel 34 | case PanicLevel: 35 | return zapcore.PanicLevel 36 | case FatalLevel: 37 | return zapcore.FatalLevel 38 | default: 39 | return zapcore.InfoLevel 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package chestnut 2 | 3 | import ( 4 | "github.com/jrapoport/chestnut/encoding/compress" 5 | "github.com/jrapoport/chestnut/encryptor" 6 | "github.com/jrapoport/chestnut/encryptor/crypto" 7 | "github.com/jrapoport/chestnut/log" 8 | ) 9 | 10 | // ChestOptions provides a default implementation for common options for a secure store. 11 | type ChestOptions struct { 12 | encryptor crypto.Encryptor 13 | chainEncryptors []crypto.Encryptor 14 | compression compress.Format 15 | compressor compress.CompressorFunc 16 | decompressor compress.DecompressorFunc 17 | // overwrites allows a storage chest to save data over existing data with the same storage key. 18 | // if Overwrite is true, overwrite are enabled and successive calls to save data 19 | // with the same key will succeed. The existing data will be overwritten by the new data. 20 | // if Overwrite is false, overwrite are disabled and successive calls to save data 21 | // with the same key will fail with an error. The existing data will not be overwritten. 
22 | overwrites bool 23 | log log.Logger 24 | } 25 | 26 | // DefaultChestOptions represents the recommended default ChestOptions for a store. 27 | var DefaultChestOptions = ChestOptions{ 28 | overwrites: true, 29 | log: log.Log, 30 | } 31 | 32 | // A ChestOption sets options such as encryptors, key rolling, and other parameters, etc. 33 | type ChestOption interface { 34 | apply(*ChestOptions) 35 | } 36 | 37 | // EmptyChestOption does not alter the encrypted store's configuration. 38 | // It can be embedded in another structure to build custom options. 39 | type EmptyChestOption struct{} 40 | 41 | func (EmptyChestOption) apply(*ChestOptions) {} 42 | 43 | // funcOption wraps a function that modifies ChestOptions 44 | // into an implementation of the ChestOption interface. 45 | type funcOption struct { 46 | f func(*ChestOptions) 47 | } 48 | 49 | // apply applies an Option to ChestOptions. 50 | func (fdo *funcOption) apply(do *ChestOptions) { 51 | fdo.f(do) 52 | } 53 | 54 | func newFuncOption(f func(*ChestOptions)) *funcOption { 55 | return &funcOption{ 56 | f: f, 57 | } 58 | } 59 | 60 | // applyOptions accepts a ChestOptions struct and applies the ChestOption(s) to it. 61 | func applyOptions(opts ChestOptions, opt ...ChestOption) ChestOptions { 62 | for _, o := range opt { 63 | o.apply(&opts) 64 | } 65 | chainEncryptors(&opts) 66 | return opts 67 | } 68 | 69 | // chainEncryptors chains all encryptors into one. 70 | func chainEncryptors(opts *ChestOptions) { 71 | // Prepend opts.encryptor to the chaining encryptors if it exists, so that single 72 | // encryptor will be executed before any other chained encryptor. 73 | encryptors := opts.chainEncryptors 74 | if opts.encryptor != nil { 75 | encryptors = append([]crypto.Encryptor{opts.encryptor}, opts.chainEncryptors...) 
76 | } 77 | var chained crypto.Encryptor 78 | if len(encryptors) == 0 { 79 | chained = nil 80 | } else if len(encryptors) == 1 { 81 | chained = encryptors[0] 82 | } else { 83 | chained = encryptor.NewChainEncryptor(encryptors...) 84 | } 85 | opts.encryptor = chained 86 | } 87 | 88 | // WithEncryptor returns a ChestOption that specifies the encryptor to use. 89 | func WithEncryptor(e crypto.Encryptor) ChestOption { 90 | return newFuncOption(func(o *ChestOptions) { 91 | if o.encryptor != nil { 92 | panic("The encryptor was already set and may not be reset.") 93 | } 94 | o.encryptor = e 95 | }) 96 | } 97 | 98 | // WithEncryptorChain returns a ChestOption that specifies an encryptor chain. 99 | // for encrypted stores. The first encryptor will be the outer most, 100 | // while the last encryptor will be the inner most wrapper around the real call. 101 | // All encryptors added by this method will be chained. If a single encryptor 102 | // has also been set, it will be *prepended* to the encryptor chain, 103 | // making it the outer most encryptor in the encryptor chain. 104 | func WithEncryptorChain(encryptors ...crypto.Encryptor) ChestOption { 105 | return newFuncOption(func(o *ChestOptions) { 106 | o.chainEncryptors = append(o.chainEncryptors, encryptors...) 107 | }) 108 | } 109 | 110 | // WithAES is a convenience that returns a ChestOption which sets the encryptor 111 | // to be an AESEncryptor initialized with a key length, cipher mode, and Secret. 112 | func WithAES(keyLen crypto.KeyLen, mode crypto.Mode, secret crypto.Secret) ChestOption { 113 | return WithEncryptor(encryptor.NewAESEncryptor(keyLen, mode, secret)) 114 | } 115 | 116 | // WithCompressors instructs the storage chest to compress/decompress data with these compressor 117 | // functions before committing it. If this option is set, WithCompression is ignored. 
118 | func WithCompressors(c compress.CompressorFunc, d compress.DecompressorFunc) ChestOption { 119 | return newFuncOption(func(o *ChestOptions) { 120 | o.compression = compress.Custom 121 | o.compressor = c 122 | o.decompressor = d 123 | }) 124 | } 125 | 126 | // WithCompression instructs the storage chest to compress data using the this compression format 127 | // before committing it. Compression this way is self-contained, meaning changes only effect data 128 | // going forward. Previously saved data, compressed or uncompressed, will be transparently retrieved 129 | // regardless of a change to this setting. 130 | func WithCompression(format compress.Format) ChestOption { 131 | return newFuncOption(func(o *ChestOptions) { 132 | o.compression = format 133 | }) 134 | } 135 | 136 | // OverwritesForbidden prevents the store from overwriting existing data. 137 | func OverwritesForbidden() ChestOption { 138 | return newFuncOption(func(o *ChestOptions) { 139 | o.overwrites = false 140 | }) 141 | } 142 | 143 | // WithLogger returns a StoreOption which sets the logger to use for the encrypted store. 144 | func WithLogger(l log.Logger) ChestOption { 145 | return newFuncOption(func(o *ChestOptions) { 146 | o.log = l 147 | }) 148 | } 149 | 150 | // WithStdLogger is a convenience that returns a StoreOption for a standard err logger. 151 | func WithStdLogger(lvl log.Level) ChestOption { 152 | return WithLogger(log.NewStdLoggerWithLevel(lvl)) 153 | } 154 | 155 | // WithLogrusLogger is a convenience that returns a StoreOption for a default logrus logger. 156 | func WithLogrusLogger(lvl log.Level) ChestOption { 157 | return WithLogger(log.NewLogrusLoggerWithLevel(lvl)) 158 | } 159 | 160 | // WithZapLogger is a convenience that returns a StoreOption for a production zap logger. 
161 | func WithZapLogger(lvl log.Level) ChestOption { 162 | return WithLogger(log.NewZapLoggerWithLevel(lvl)) 163 | } 164 | -------------------------------------------------------------------------------- /storage/bolt/store.go: -------------------------------------------------------------------------------- 1 | package bolt 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/jrapoport/chestnut/log" 10 | "github.com/jrapoport/chestnut/storage" 11 | jsoniter "github.com/json-iterator/go" 12 | bolt "go.etcd.io/bbolt" 13 | ) 14 | 15 | const ( 16 | logName = "bolt" 17 | storeName = "chest.db" 18 | storeExt = ".db" 19 | ) 20 | 21 | // boltStore is an implementation the Storage interface for bbolt 22 | // https://github.com/etcd-io/bbolt. 23 | type boltStore struct { 24 | opts storage.StoreOptions 25 | path string 26 | db *bolt.DB 27 | log log.Logger 28 | } 29 | 30 | var _ storage.Storage = (*boltStore)(nil) 31 | 32 | // NewStore is used to instantiate a datastore backed by bbolt. 33 | func NewStore(path string, opt ...storage.StoreOption) storage.Storage { 34 | opts := storage.ApplyOptions(storage.DefaultStoreOptions, opt...) 35 | logger := log.Named(opts.Logger(), logName) 36 | if path == "" { 37 | logger.Panic("store path required") 38 | } 39 | return &boltStore{path: path, opts: opts, log: logger} 40 | } 41 | 42 | // Options returns the configuration options for the store. 43 | func (s *boltStore) Options() storage.StoreOptions { 44 | return s.opts 45 | } 46 | 47 | // Open opens the store. 
48 | func (s *boltStore) Open() (err error) { 49 | s.log.Debugf("opening store at path: %s", s.path) 50 | var path string 51 | path, err = ensureDBPath(s.path) 52 | if err != nil { 53 | err = s.logError("open", err) 54 | return 55 | } 56 | s.db, err = bolt.Open(path, 0600, nil) 57 | if err != nil { 58 | err = s.logError("open", err) 59 | return 60 | } 61 | if s.db == nil { 62 | err = errors.New("unable to open backing store") 63 | err = s.logError("open", err) 64 | return 65 | } 66 | s.log.Infof("opened store at path: %s", s.path) 67 | return 68 | } 69 | 70 | // Put an entry in the store. 71 | func (s *boltStore) Put(name string, key []byte, value []byte) error { 72 | s.log.Debugf("put: %d value bytes to key: %s", len(value), key) 73 | if err := storage.ValidKey(name, key); err != nil { 74 | return s.logError("put", err) 75 | } else if len(value) <= 0 { 76 | err = errors.New("value cannot be empty") 77 | return s.logError("put", err) 78 | } 79 | putValue := func(tx *bolt.Tx) error { 80 | s.log.Debugf("put: tx %d bytes to key: %s.%s", 81 | len(value), name, string(key)) 82 | b, err := tx.CreateBucketIfNotExists([]byte(name)) 83 | if err != nil { 84 | return err 85 | } 86 | return b.Put(key, value) 87 | } 88 | return s.logError("put", s.db.Update(putValue)) 89 | } 90 | 91 | // Get a value from the store. 
92 | func (s *boltStore) Get(name string, key []byte) ([]byte, error) { 93 | s.log.Debugf("get: value at key: %s", key) 94 | if err := storage.ValidKey(name, key); err != nil { 95 | return nil, s.logError("get", err) 96 | } 97 | var value []byte 98 | getValue := func(tx *bolt.Tx) error { 99 | s.log.Debugf("get: tx key: %s.%s", name, key) 100 | b := tx.Bucket([]byte(name)) 101 | if b == nil { 102 | return fmt.Errorf("bucket not found: %s", name) 103 | } 104 | v := b.Get(key) 105 | if len(v) <= 0 { 106 | return errors.New("nil value") 107 | } 108 | value = v 109 | s.log.Debugf("get: tx key: %s.%s value (%d bytes)", 110 | name, string(key), len(value)) 111 | return nil 112 | } 113 | if err := s.db.View(getValue); err != nil { 114 | return nil, s.logError("get", err) 115 | } 116 | return value, nil 117 | } 118 | 119 | // Save the value in v and store the result at key. 120 | func (s *boltStore) Save(name string, key []byte, v interface{}) error { 121 | b, err := jsoniter.Marshal(v) 122 | if err != nil { 123 | return s.logError("save", err) 124 | } 125 | return s.Put(name, key, b) 126 | } 127 | 128 | // Load the value at key and stores the result in v. 129 | func (s *boltStore) Load(name string, key []byte, v interface{}) error { 130 | b, err := s.Get(name, key) 131 | if err != nil { 132 | return s.logError("load", err) 133 | } 134 | return s.logError("load", jsoniter.Unmarshal(b, v)) 135 | } 136 | 137 | // Has checks for a key in the store. 
138 | func (s *boltStore) Has(name string, key []byte) (bool, error) { 139 | s.log.Debugf("has: key: %s", key) 140 | if err := storage.ValidKey(name, key); err != nil { 141 | return false, s.logError("has", err) 142 | } 143 | var has bool 144 | hasKey := func(tx *bolt.Tx) error { 145 | s.log.Debugf("has: tx get namespace: %s", name) 146 | b := tx.Bucket([]byte(name)) 147 | if b == nil { 148 | err := fmt.Errorf("bucket not found: %s", name) 149 | return err 150 | } 151 | v := b.Get(key) 152 | has = len(v) > 0 153 | if has { 154 | s.log.Debugf("has: tx key found: %s.%s", name, string(key)) 155 | } 156 | return nil 157 | } 158 | if err := s.db.View(hasKey); err != nil { 159 | return false, s.logError("has", err) 160 | } 161 | s.log.Debugf("has: found key %s: %t", key, has) 162 | return has, nil 163 | } 164 | 165 | // Delete removes a key from the store. 166 | func (s *boltStore) Delete(name string, key []byte) error { 167 | s.log.Debugf("delete: key: %s", key) 168 | if err := storage.ValidKey(name, key); err != nil { 169 | return s.logError("delete", err) 170 | } 171 | del := func(tx *bolt.Tx) error { 172 | s.log.Debugf("delete: tx key: %s.%s", name, string(key)) 173 | b := tx.Bucket([]byte(name)) 174 | if b == nil { 175 | err := fmt.Errorf("bucket not found: %s", name) 176 | // an error just means we couldn't find the bucket 177 | s.log.Warn(err) 178 | return nil 179 | } 180 | return b.Delete(key) 181 | } 182 | return s.logError("delete", s.db.Update(del)) 183 | } 184 | 185 | // List returns a list of all keys in the namespace. 
186 | func (s *boltStore) List(name string) (keys [][]byte, err error) { 187 | s.log.Debugf("list: keys in namespace: %s", name) 188 | listKeys := func(tx *bolt.Tx) error { 189 | b := tx.Bucket([]byte(name)) 190 | if b == nil { 191 | err = fmt.Errorf("bucket not found: %s", name) 192 | return err 193 | } 194 | keys, err = s.listKeys(name, b) 195 | return err 196 | } 197 | if err = s.db.View(listKeys); err != nil { 198 | return nil, s.logError("list", err) 199 | } 200 | s.log.Debugf("list: found %d keys: %s", len(keys), keys) 201 | return 202 | } 203 | 204 | func (s *boltStore) listKeys(name string, b *bolt.Bucket) ([][]byte, error) { 205 | if b == nil { 206 | err := fmt.Errorf("invalid bucket: %s", name) 207 | return nil, err 208 | } 209 | var keys [][]byte 210 | s.log.Debugf("list: tx scan namespace: %s", name) 211 | count := b.Stats().KeyN 212 | keys = make([][]byte, count) 213 | s.log.Debugf("list: tx found %d keys in: %s", count, name) 214 | var i int 215 | _ = b.ForEach(func(k, _ []byte) error { 216 | s.log.Debugf("list: tx found key: %s.%s", name, string(k)) 217 | keys[i] = k 218 | i++ 219 | return nil 220 | }) 221 | return keys, nil 222 | } 223 | 224 | // ListAll returns a mapped list of all keys in the store. 
225 | func (s *boltStore) ListAll() (map[string][][]byte, error) { 226 | s.log.Debugf("list: all keys") 227 | var total int 228 | allKeys := map[string][][]byte{} 229 | listKeys := func(tx *bolt.Tx) error { 230 | err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { 231 | keys, err := s.listKeys(string(name), b) 232 | if err != nil { 233 | return err 234 | } 235 | if len(keys) <= 0 { 236 | return nil 237 | } 238 | allKeys[string(name)] = keys 239 | total += len(keys) 240 | return nil 241 | }) 242 | return err 243 | } 244 | if err := s.db.View(listKeys); err != nil { 245 | return nil, s.logError("list", err) 246 | } 247 | s.log.Debugf("list: found %d keys: %s", total, allKeys) 248 | return allKeys, nil 249 | } 250 | 251 | // Export copies the datastore to directory at path. 252 | func (s *boltStore) Export(path string) error { 253 | s.log.Debugf("export: to path: %s", path) 254 | if path == "" { 255 | err := fmt.Errorf("invalid path: %s", path) 256 | return s.logError("export", err) 257 | } else if s.path == path { 258 | err := fmt.Errorf("path cannot be store path: %s", path) 259 | return s.logError("export", err) 260 | } 261 | var err error 262 | path, err = ensureDBPath(path) 263 | if err != nil { 264 | return s.logError("export", err) 265 | } 266 | err = s.db.View(func(tx *bolt.Tx) error { 267 | return tx.CopyFile(path, 0600) 268 | }) 269 | if err != nil { 270 | return s.logError("export", err) 271 | } 272 | s.log.Debugf("export: to path complete: %s", path) 273 | return nil 274 | } 275 | 276 | // Close closes the datastore and releases all db resources. 
277 | func (s *boltStore) Close() error { 278 | s.log.Debugf("closing store at path: %s", s.path) 279 | err := s.db.Close() 280 | s.db = nil 281 | s.log.Info("store closed") 282 | return s.logError("close", err) 283 | } 284 | 285 | func (s *boltStore) logError(name string, err error) error { 286 | if err == nil { 287 | return nil 288 | } 289 | if name != "" { 290 | err = fmt.Errorf("%s: %w", name, err) 291 | } 292 | s.log.Error(err) 293 | return err 294 | } 295 | 296 | func ensureDBPath(path string) (string, error) { 297 | if path == "" { 298 | return "", errors.New("path not found") 299 | } 300 | // does the path exist? 301 | info, err := os.Stat(path) 302 | exists := !os.IsNotExist(err) 303 | // this is some kind of actual error 304 | if err != nil && exists { 305 | return "", err 306 | } 307 | if exists && info.Mode().IsDir() { 308 | // if we have a directory, then append our default name 309 | path = filepath.Join(path, storeName) 310 | } 311 | ext := filepath.Ext(path) 312 | if ext == "" { 313 | path += storeExt 314 | } 315 | dir, _ := filepath.Split(path) 316 | // make sure the directory path exists 317 | if err = os.MkdirAll(dir, 0700); err != nil { 318 | return "", err 319 | } 320 | _, err = os.Stat(path) 321 | exists = !os.IsNotExist(err) 322 | // this is some kind of actual error 323 | if err != nil && exists { 324 | return "", err 325 | } 326 | if exists { 327 | return path, nil 328 | } 329 | f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) 330 | if err != nil { 331 | return "", err 332 | } 333 | defer f.Close() 334 | return path, nil 335 | } 336 | -------------------------------------------------------------------------------- /storage/bolt/store_test.go: -------------------------------------------------------------------------------- 1 | package bolt 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jrapoport/chestnut/storage/store_test" 7 | ) 8 | 9 | func TestStore(t *testing.T) { 10 | store_test.TestStore(t, NewStore) 11 | } 12 | 
-------------------------------------------------------------------------------- /storage/nuts/store.go: -------------------------------------------------------------------------------- 1 | package nuts 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/jrapoport/chestnut/log" 9 | "github.com/jrapoport/chestnut/storage" 10 | jsoniter "github.com/json-iterator/go" 11 | "github.com/nutsdb/nutsdb" 12 | ) 13 | 14 | const logName = "nutsdb" 15 | 16 | // nutsDBStore is an implementation the Storage interface for nutsdb 17 | // https://github.com/nutsdb/nutsdb. 18 | type nutsDBStore struct { 19 | opts storage.StoreOptions 20 | path string 21 | db *nutsdb.DB 22 | log log.Logger 23 | } 24 | 25 | var _ storage.Storage = (*nutsDBStore)(nil) 26 | 27 | // NewStore is used to instantiate a datastore backed by nutsdb. 28 | func NewStore(path string, opt ...storage.StoreOption) storage.Storage { 29 | opts := storage.ApplyOptions(storage.DefaultStoreOptions, opt...) 30 | logger := log.Named(opts.Logger(), logName) 31 | if path == "" { 32 | logger.Panic("store path required") 33 | } 34 | return &nutsDBStore{path: path, opts: opts, log: logger} 35 | } 36 | 37 | // Options returns the configuration options for the store. 38 | func (s *nutsDBStore) Options() storage.StoreOptions { 39 | return s.opts 40 | } 41 | 42 | // Open opens the store. 43 | func (s *nutsDBStore) Open() (err error) { 44 | s.log.Debugf("opening store at path: %s", s.path) 45 | opt := nutsdb.DefaultOptions 46 | opt.Dir = s.path 47 | if s.db, err = nutsdb.Open(opt); err != nil { 48 | err = s.logError("open", err) 49 | return 50 | } 51 | if s.db == nil { 52 | err = errors.New("unable to open backing store") 53 | err = s.logError("open", err) 54 | return 55 | } 56 | s.log.Infof("opened store at path: %s", s.path) 57 | return 58 | } 59 | 60 | // Put an entry in the store. 
61 | func (s *nutsDBStore) Put(name string, key []byte, value []byte) error { 62 | s.log.Debugf("put: %d value bytes to key: %s", len(value), key) 63 | if err := storage.ValidKey(name, key); err != nil { 64 | return s.logError("put", err) 65 | } else if len(value) <= 0 { 66 | err = errors.New("value cannot be empty") 67 | return s.logError("put", err) 68 | } 69 | newBucket := func(tx *nutsdb.Tx) error { 70 | e := tx.NewBucket(nutsdb.DataStructureBTree, name) 71 | if e != nil && !errors.Is(e, nutsdb.ErrBucketAlreadyExist) { 72 | return e 73 | } 74 | return nil 75 | } 76 | if err := s.db.Update(newBucket); err != nil { 77 | return s.logError("put", err) 78 | } 79 | putValue := func(tx *nutsdb.Tx) error { 80 | s.log.Debugf("put: tx %d bytes to key: %s.%s", 81 | len(value), name, string(key)) 82 | return tx.Put(name, key, value, 0) 83 | } 84 | return s.logError("put", s.db.Update(putValue)) 85 | } 86 | 87 | // Get a value from the store. 88 | func (s *nutsDBStore) Get(name string, key []byte) ([]byte, error) { 89 | s.log.Debugf("get: value at key: %s", key) 90 | if err := storage.ValidKey(name, key); err != nil { 91 | return nil, s.logError("get", err) 92 | } 93 | var value []byte 94 | var err error 95 | getValue := func(tx *nutsdb.Tx) error { 96 | s.log.Debugf("get: tx key: %s.%s", name, key) 97 | value, err = tx.Get(name, key) 98 | if err != nil { 99 | return err 100 | } 101 | s.log.Debugf("get: tx key: %s.%s value (%d bytes)", 102 | name, string(key), len(value)) 103 | return nil 104 | } 105 | if err := s.db.View(getValue); err != nil { 106 | return nil, s.logError("get", err) 107 | } 108 | return value, nil 109 | } 110 | 111 | // Save the value in v and store the result at key. 112 | func (s *nutsDBStore) Save(name string, key []byte, v interface{}) error { 113 | b, err := jsoniter.Marshal(v) 114 | if err != nil { 115 | return s.logError("save", err) 116 | } 117 | return s.Put(name, key, b) 118 | } 119 | 120 | // Load the value at key and stores the result in v. 
121 | func (s *nutsDBStore) Load(name string, key []byte, v interface{}) error { 122 | b, err := s.Get(name, key) 123 | if err != nil { 124 | return s.logError("load", err) 125 | } 126 | return s.logError("load", jsoniter.Unmarshal(b, v)) 127 | } 128 | 129 | // Has checks for a key in the store. 130 | func (s *nutsDBStore) Has(name string, key []byte) (bool, error) { 131 | s.log.Debugf("has: key: %s", key) 132 | if err := storage.ValidKey(name, key); err != nil { 133 | return false, s.logError("has", err) 134 | } 135 | var has bool 136 | hasKey := func(tx *nutsdb.Tx) error { 137 | s.log.Debugf("has: tx get namespace: %s", name) 138 | keys, err := tx.GetKeys(name) 139 | if err != nil { 140 | return err 141 | } 142 | s.log.Debugf("has: tx found %d keys in: %s", len(keys), name) 143 | for _, k := range keys { 144 | has = bytes.Equal(key, k) 145 | if has { 146 | s.log.Debugf("has: tx key found: %s.%s", name, string(key)) 147 | break 148 | } 149 | } 150 | return nil 151 | } 152 | if err := s.db.View(hasKey); err != nil { 153 | return false, s.logError("has", err) 154 | } 155 | s.log.Debugf("has: found key %s: %t", key, has) 156 | return has, nil 157 | } 158 | 159 | // Delete removes a key from the store. 160 | func (s *nutsDBStore) Delete(name string, key []byte) error { 161 | s.log.Debugf("delete: key: %s", key) 162 | if err := storage.ValidKey(name, key); err != nil { 163 | return s.logError("delete", err) 164 | } 165 | del := func(tx *nutsdb.Tx) error { 166 | s.log.Debugf("delete: tx key: %s.%s", name, string(key)) 167 | err := tx.Delete(name, key) 168 | if errors.Is(err, nutsdb.ErrKeyNotFound) { 169 | return nil 170 | } 171 | return err 172 | } 173 | return s.logError("delete", s.db.Update(del)) 174 | } 175 | 176 | // List returns a list of all keys in the namespace. 
177 | func (s *nutsDBStore) List(name string) (keys [][]byte, err error) { 178 | s.log.Debugf("list: keys in namespace: %s", name) 179 | listKeys := func(tx *nutsdb.Tx) error { 180 | keys, err = s.listKeys(name, tx) 181 | return err 182 | } 183 | if err = s.db.View(listKeys); err != nil { 184 | return nil, s.logError("list", err) 185 | } 186 | s.log.Debugf("list: found %d keys: %s", len(keys), keys) 187 | return 188 | } 189 | 190 | func (s *nutsDBStore) listKeys(name string, tx *nutsdb.Tx) ([][]byte, error) { 191 | var keys [][]byte 192 | s.log.Debugf("list: tx scan namespace: %s", name) 193 | keys, err := tx.GetKeys(name) 194 | if err != nil { 195 | return nil, err 196 | } 197 | s.log.Debugf("list: tx found %d keys in: %s", len(keys), name) 198 | // for _, key := range keys { 199 | // s.log.Debugf("list: tx found key: %s.%s", name, key) 200 | // } 201 | return keys, nil 202 | } 203 | 204 | // ListAll returns a mapped list of all keys in the store. 205 | func (s *nutsDBStore) ListAll() (map[string][][]byte, error) { 206 | s.log.Debugf("list: all keys") 207 | var total int 208 | allKeys := map[string][][]byte{} 209 | listKeys := func(tx *nutsdb.Tx) error { 210 | err := tx.IterateBuckets(nutsdb.DataStructureBTree, "*", func(bucket string) bool { 211 | keys, err := s.listKeys(bucket, tx) 212 | if err != nil { 213 | return false 214 | } 215 | if len(keys) <= 0 { 216 | return true 217 | } 218 | allKeys[bucket] = keys 219 | total += len(keys) 220 | return true 221 | }) 222 | return err 223 | } 224 | if err := s.db.View(listKeys); err != nil { 225 | return nil, s.logError("list", err) 226 | } 227 | s.log.Debugf("list: found %d keys: %s", total, allKeys) 228 | return allKeys, nil 229 | } 230 | 231 | // Export copies the datastore to directory at path. 
232 | func (s *nutsDBStore) Export(path string) error { 233 | s.log.Debugf("export: to path: %s", path) 234 | if path == "" { 235 | err := fmt.Errorf("invalid path: %s", path) 236 | return s.logError("export", err) 237 | } else if s.path == path { 238 | err := fmt.Errorf("path cannot be store path: %s", path) 239 | return s.logError("export", err) 240 | } 241 | if err := s.db.Backup(path); err != nil { 242 | return s.logError("export", err) 243 | } 244 | s.log.Debugf("export: to path complete: %s", path) 245 | return nil 246 | } 247 | 248 | // Close closes the datastore and releases all db resources. 249 | func (s *nutsDBStore) Close() error { 250 | s.log.Debugf("closing store at path: %s", s.path) 251 | err := s.db.Close() 252 | s.db = nil 253 | s.log.Info("store closed") 254 | return s.logError("close", err) 255 | } 256 | 257 | func (s *nutsDBStore) logError(name string, err error) error { 258 | if err == nil { 259 | return nil 260 | } 261 | if name != "" { 262 | err = fmt.Errorf("%s: %w", name, err) 263 | } 264 | s.log.Error(err) 265 | return err 266 | } 267 | -------------------------------------------------------------------------------- /storage/nuts/store_test.go: -------------------------------------------------------------------------------- 1 | package nuts 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jrapoport/chestnut/storage/store_test" 7 | ) 8 | 9 | func TestStore(t *testing.T) { 10 | store_test.TestStore(t, NewStore) 11 | } 12 | -------------------------------------------------------------------------------- /storage/options.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import "github.com/jrapoport/chestnut/log" 4 | 5 | // StoreOptions provides a default implementation for common storage Options stores should support. 6 | type StoreOptions struct { 7 | log log.Logger 8 | } 9 | 10 | // Logger returns the configured logger for the store. 
11 | func (o StoreOptions) Logger() log.Logger { 12 | return o.log 13 | } 14 | 15 | // DefaultStoreOptions represents the recommended default StoreOptions for a store. 16 | var DefaultStoreOptions = StoreOptions{ 17 | log: log.Log, 18 | } 19 | 20 | // A StoreOption sets options such disabling overwrite, and other parameters, etc. 21 | type StoreOption interface { 22 | apply(*StoreOptions) 23 | } 24 | 25 | // EmptyStoreOption does not alter the store configuration. 26 | // It can be embedded in another structure to build custom options. 27 | type EmptyStoreOption struct{} 28 | 29 | func (EmptyStoreOption) apply(*StoreOptions) {} 30 | 31 | // funcOption wraps a function that modifies StoreOptions 32 | // into an implementation of the StoreOption interface. 33 | type funcOption struct { 34 | f func(*StoreOptions) 35 | } 36 | 37 | // Apply applies an StoreOption to StoreOptions. 38 | func (fdo *funcOption) apply(do *StoreOptions) { 39 | fdo.f(do) 40 | } 41 | 42 | func newFuncOption(f func(*StoreOptions)) *funcOption { 43 | return &funcOption{ 44 | f: f, 45 | } 46 | } 47 | 48 | // ApplyOptions accepts an StoreOptions struct and applies the StoreOption(s) to it. 49 | func ApplyOptions(opts StoreOptions, opt ...StoreOption) StoreOptions { 50 | for _, o := range opt { 51 | o.apply(&opts) 52 | } 53 | return opts 54 | } 55 | 56 | // WithLogger returns a StoreOption which sets the logger to use for the encrypted store. 57 | func WithLogger(l log.Logger) StoreOption { 58 | return newFuncOption(func(o *StoreOptions) { 59 | o.log = l 60 | }) 61 | } 62 | 63 | // WithStdLogger is a convenience that returns a StoreOption for a standard err logger. 64 | func WithStdLogger(lvl log.Level) StoreOption { 65 | return WithLogger(log.NewStdLoggerWithLevel(lvl)) 66 | } 67 | 68 | // WithLogrusLogger is a convenience that returns a StoreOption for a default logrus logger. 
69 | func WithLogrusLogger(lvl log.Level) StoreOption { 70 | return WithLogger(log.NewLogrusLoggerWithLevel(lvl)) 71 | } 72 | 73 | // WithZapLogger is a convenience that returns a StoreOption for a production zap logger. 74 | func WithZapLogger(lvl log.Level) StoreOption { 75 | return WithLogger(log.NewZapLoggerWithLevel(lvl)) 76 | } 77 | -------------------------------------------------------------------------------- /storage/storage.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // Storage provides a management interface for a datastore. 9 | type Storage interface { 10 | // Open opens the store. 11 | Open() error 12 | 13 | // Put a value in the store. 14 | Put(namespace string, key []byte, value []byte) error 15 | 16 | // Get a value from the store. 17 | Get(namespace string, key []byte) (value []byte, err error) 18 | 19 | // Has checks for a key in the store. 20 | Has(namespace string, key []byte) (bool, error) 21 | 22 | // Save the value in v and stores the result at key. 23 | Save(namespace string, key []byte, v interface{}) error 24 | 25 | // Load the value at key and stores the result in v. 26 | Load(namespace string, key []byte, v interface{}) error 27 | 28 | // List returns a list of all keys in the namespace. 29 | List(namespace string) ([][]byte, error) 30 | 31 | // ListAll returns a mapped list of all keys in the store. 32 | ListAll() (map[string][][]byte, error) 33 | 34 | // Delete removes a key from the store. 35 | Delete(name string, key []byte) error 36 | 37 | // Close closes the store. 38 | Close() error 39 | 40 | // Export saves the store to path. 41 | Export(path string) error 42 | } 43 | 44 | // ErrInvalidKey the storage key is invalid. 45 | var ErrInvalidKey = errors.New("invalid storage key") 46 | 47 | // ValidKey returns nil if the key is valid, otherwise ErrInvalidKey. 
48 | func ValidKey(name string, key []byte) error { 49 | if name == "" { 50 | return fmt.Errorf("%w namespace: %s", ErrInvalidKey, name) 51 | } 52 | if len(key) <= 0 { 53 | return fmt.Errorf("%w: %s", ErrInvalidKey, key) 54 | } 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /storage/store_test/test_suite.go: -------------------------------------------------------------------------------- 1 | package store_test 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/google/uuid" 9 | "github.com/jrapoport/chestnut/log" 10 | "github.com/jrapoport/chestnut/storage" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/suite" 13 | ) 14 | 15 | type testCase struct { 16 | name string 17 | key string 18 | value string 19 | err assert.ErrorAssertionFunc 20 | has assert.BoolAssertionFunc 21 | } 22 | 23 | type testObject struct { 24 | Value string 25 | } 26 | 27 | var ( 28 | testName = "test-name" 29 | testKey = "test-key" 30 | testValue = "test-value" 31 | testObj = &testObject{"hello"} 32 | ) 33 | 34 | var putTests = []testCase{ 35 | {"", "", "", assert.Error, assert.False}, 36 | {"a", testKey, "", assert.Error, assert.False}, 37 | {"b", testKey, testValue, assert.NoError, assert.True}, 38 | {"c/c", testKey, testValue, assert.NoError, assert.True}, 39 | {".d", testKey, testValue, assert.NoError, assert.True}, 40 | {testName, "", "", assert.Error, assert.False}, 41 | {testName, "a", "", assert.Error, assert.False}, 42 | {testName, "b", testValue, assert.NoError, assert.True}, 43 | {testName, "c/c", testValue, assert.NoError, assert.True}, 44 | {testName, ".d", testValue, assert.NoError, assert.True}, 45 | {testName, testKey, testValue, assert.NoError, assert.True}, 46 | } 47 | 48 | var tests = append(putTests, 49 | testCase{testName, "not-found", "", assert.Error, assert.False}, 50 | ) 51 | 52 | type storeFunc = func(string, ...storage.StoreOption) storage.Storage 53 | 54 | 
type storeTestSuite struct { 55 | suite.Suite 56 | storeFunc 57 | store storage.Storage 58 | path string 59 | } 60 | 61 | // TestStore tests a store 62 | func TestStore(t *testing.T, fn storeFunc) { 63 | ts := new(storeTestSuite) 64 | ts.storeFunc = fn 65 | suite.Run(t, ts) 66 | } 67 | 68 | // SetupTest 69 | func (ts *storeTestSuite) SetupTest() { 70 | ts.path = ts.T().TempDir() 71 | ts.store = ts.storeFunc(ts.path) 72 | err := ts.store.Open() 73 | ts.NoError(err) 74 | } 75 | 76 | // TearDownTest 77 | func (ts *storeTestSuite) TearDownTest() { 78 | err := ts.store.Close() 79 | ts.NoError(err) 80 | } 81 | 82 | // BeforeTest 83 | func (ts *storeTestSuite) BeforeTest(_, testName string) { 84 | switch testName { 85 | case "TestStorePut", 86 | "TestStoreSave", 87 | "TestStoreLoad", 88 | "TestStoreList", 89 | "TestStoreListAll", 90 | "TestStoreWithLogger": 91 | break 92 | default: 93 | ts.TestStorePut() 94 | } 95 | } 96 | 97 | func (ts *storeTestSuite) TestInvalidPath() { 98 | ts.Panics(func() { 99 | ts.storeFunc("") 100 | }) 101 | } 102 | 103 | // TestStorePut 104 | func (ts *storeTestSuite) TestStorePut() { 105 | for i, test := range putTests { 106 | err := ts.store.Put(test.name, []byte(test.key), []byte(test.value)) 107 | test.err(ts.T(), err, "%d test name: %s key: %s", i, test.name, test.key) 108 | } 109 | } 110 | 111 | // TestStoreSave 112 | func (ts *storeTestSuite) TestStoreSave() { 113 | err := ts.store.Save(testName, []byte(testKey), testObj) 114 | ts.NoError(err) 115 | } 116 | 117 | // TestStoreLoad 118 | func (ts *storeTestSuite) TestStoreLoad() { 119 | ts.T().Run("Setup", func(t *testing.T) { 120 | ts.TestStoreSave() 121 | }) 122 | to := &testObject{} 123 | err := ts.store.Load(testName, []byte(testKey), to) 124 | ts.NoError(err) 125 | ts.Equal(testObj, to) 126 | } 127 | 128 | // TestStoreGet 129 | func (ts *storeTestSuite) TestStoreGet() { 130 | for i, test := range tests { 131 | value, err := ts.store.Get(test.name, []byte(test.key)) 132 | 
test.err(ts.T(), err, "%d test name: %s key: %s", i, test.name, test.key) 133 | ts.Equal(test.value, string(value), 134 | "%d test key: %s", i, test.key) 135 | } 136 | } 137 | 138 | // TestStoreHas 139 | func (ts *storeTestSuite) TestStoreHas() { 140 | for i, test := range tests { 141 | has, _ := ts.store.Has(test.name, []byte(test.key)) 142 | test.has(ts.T(), has, "%d test key: %s", i, test.key) 143 | } 144 | } 145 | 146 | // TestStoreList 147 | func (ts *storeTestSuite) TestStoreList() { 148 | const listLen = 100 149 | list := make([]string, listLen) 150 | for i := 0; i < listLen; i++ { 151 | list[i] = uuid.New().String() 152 | err := ts.store.Put(testName, []byte(list[i]), []byte(testValue)) 153 | ts.NoError(err) 154 | } 155 | keys, err := ts.store.List(testName) 156 | ts.NoError(err) 157 | ts.Len(keys, listLen) 158 | // put both lists in the same order so we can compare them 159 | strKeys := make([]string, len(keys)) 160 | for i, k := range keys { 161 | strKeys[i] = string(k) 162 | } 163 | sort.Strings(list) 164 | sort.Strings(strKeys) 165 | ts.Equal(list, strKeys) 166 | } 167 | 168 | // TestStoreListAll 169 | func (ts *storeTestSuite) TestStoreListAll() { 170 | const listLen = 100 171 | list := make([]string, listLen) 172 | for i := 0; i < listLen; i++ { 173 | list[i] = uuid.New().String() 174 | ns := fmt.Sprintf("%s%d", testName, i) 175 | err := ts.store.Put(ns, []byte(list[i]), []byte(testValue)) 176 | ts.NoError(err) 177 | } 178 | keyMap, err := ts.store.ListAll() 179 | ts.NoError(err) 180 | var keys []string 181 | for _, ks := range keyMap { 182 | for _, k := range ks { 183 | keys = append(keys, string(k)) 184 | } 185 | } 186 | ts.Len(keys, listLen) 187 | sort.Strings(list) 188 | sort.Strings(keys) 189 | ts.Equal(list, keys) 190 | } 191 | 192 | // TestStoreDelete 193 | func (ts *storeTestSuite) TestStoreDelete() { 194 | var deleteTests = []struct { 195 | key string 196 | err assert.ErrorAssertionFunc 197 | }{ 198 | {"", assert.Error}, 199 | {"a", 
assert.NoError}, 200 | {"b", assert.NoError}, 201 | {"c/c", assert.NoError}, 202 | {".d", assert.NoError}, 203 | {"eee", assert.NoError}, 204 | {"not-found", assert.NoError}, 205 | } 206 | for i, test := range deleteTests { 207 | err := ts.store.Delete(testName, []byte(test.key)) 208 | test.err(ts.T(), err, "%d test key: %s", i, test.key) 209 | } 210 | } 211 | 212 | // TestStoreExport 213 | func (ts *storeTestSuite) TestStoreExport() { 214 | exTests := []struct { 215 | path string 216 | Err assert.ErrorAssertionFunc 217 | }{ 218 | {"", assert.Error}, 219 | {ts.path, assert.Error}, 220 | {ts.T().TempDir(), assert.NoError}, 221 | } 222 | for _, test := range exTests { 223 | err := ts.store.Export(test.path) 224 | test.Err(ts.T(), err) 225 | if err == nil { 226 | s2 := ts.storeFunc(test.path) 227 | ts.NotNil(s2) 228 | err = s2.Open() 229 | ts.NoError(err) 230 | keys, err := s2.ListAll() 231 | ts.NoError(err) 232 | ts.NotEmpty(keys) 233 | err = s2.Close() 234 | ts.NoError(err) 235 | } 236 | } 237 | } 238 | 239 | // TestStoreWithLogger 240 | func (ts *storeTestSuite) TestStoreWithLogger() { 241 | levels := []log.Level{ 242 | log.DebugLevel, 243 | log.InfoLevel, 244 | log.WarnLevel, 245 | log.ErrorLevel, 246 | log.PanicLevel, 247 | } 248 | type LoggerOpt func(log.Level) storage.StoreOption 249 | logOpts := []LoggerOpt{ 250 | storage.WithLogrusLogger, 251 | storage.WithStdLogger, 252 | storage.WithZapLogger, 253 | } 254 | path := ts.T().TempDir() 255 | for _, level := range levels { 256 | for _, logOpt := range logOpts { 257 | opt := logOpt(level) 258 | store := ts.storeFunc(path, opt) 259 | ts.NotNil(store) 260 | err := store.Open() 261 | ts.NoError(err) 262 | err = store.Close() 263 | ts.NoError(err) 264 | } 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /value/id.go: -------------------------------------------------------------------------------- 1 | package value 2 | 3 | import 
"github.com/jrapoport/chestnut/storage" 4 | 5 | // ID provides a implementation of the Keyed interface. 6 | // It can be embedded in another structure to build custom Keyed values. 7 | type ID struct { 8 | ID string `json:"id"` 9 | } 10 | 11 | var _ Keyed = (*ID)(nil) 12 | 13 | // Key returns the key as bytes. 14 | func (k *ID) Key() []byte { 15 | return []byte(k.ID) 16 | } 17 | 18 | // Namespace is the namespace to use when storing the key. 19 | func (k *ID) Namespace() string { 20 | name := "" 21 | if k.ID != "" { 22 | name = k.ID[:1] 23 | } 24 | return name 25 | } 26 | 27 | // ValidKey checks the key is valid. 28 | func (k *ID) ValidKey() error { 29 | return storage.ValidKey(k.Namespace(), k.Key()) 30 | } 31 | 32 | // String returns the key as a string. 33 | func (k *ID) String() string { 34 | return k.ID 35 | } 36 | -------------------------------------------------------------------------------- /value/id_test.go: -------------------------------------------------------------------------------- 1 | package value 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_ID(t *testing.T) { 10 | keyTest := []struct { 11 | name string 12 | key string 13 | err assert.ErrorAssertionFunc 14 | }{ 15 | {"", "", assert.Error}, 16 | {"a", "a", assert.NoError}, 17 | {"t", "test", assert.NoError}, 18 | } 19 | for _, test := range keyTest { 20 | key := &ID{test.key} 21 | assert.Equal(t, test.name, key.Namespace()) 22 | assert.Equal(t, []byte(test.key), key.Key()) 23 | assert.Equal(t, test.key, key.String()) 24 | test.err(t, key.ValidKey()) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /value/keyed.go: -------------------------------------------------------------------------------- 1 | package value 2 | 3 | // Keyed provides a management interface for keyed values. 4 | type Keyed interface { 5 | // Key is the byte representation of the key. 
6 | Key() []byte 7 | 8 | // Namespace is the namespace to use when storing the key. 9 | Namespace() string 10 | 11 | // ValidKey returns nil if the key is valid, otherwise ErrInvalidKey. 12 | ValidKey() error 13 | } 14 | -------------------------------------------------------------------------------- /value/secure.go: -------------------------------------------------------------------------------- 1 | package value 2 | 3 | // Secure provides a simple value for storing sparsely 4 | // encrypted blobs and plaintext metadata. 5 | type Secure struct { 6 | ID 7 | Data []byte `json:"data,secure"` 8 | Metadata map[string]interface{} `json:"metadata"` 9 | } 10 | 11 | // NewSecureValue returns a new Secure value. 12 | func NewSecureValue(id string, data []byte) *Secure { 13 | return &Secure{ 14 | ID: ID{ID: id}, 15 | Data: data, 16 | Metadata: map[string]interface{}{}, 17 | } 18 | } 19 | 20 | // SetMetadata sets the metadata entry for k to v. 21 | func (e *Secure) SetMetadata(k string, v interface{}) { 22 | e.Metadata[k] = v 23 | } 24 | 25 | // GetMetadata gets the metadata entry for k and returns it as v. 26 | func (e *Secure) GetMetadata(k string) (v interface{}) { 27 | v = e.Metadata[k] 28 | return 29 | } 30 | --------------------------------------------------------------------------------