├── .github ├── release-drafter.yml └── workflows │ ├── draft.yaml │ └── test.yaml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── README.md ├── api.go ├── api_test.go ├── client.go ├── client_test.go ├── codecov.yml ├── config.go ├── config_test.go ├── credential.go ├── credential_test.go ├── devicecheck_test.go ├── error.go ├── error_test.go ├── go.mod ├── go.sum ├── invalid_private_key.p8 ├── jwt.go ├── jwt_test.go ├── query.go ├── query_test.go ├── renovate.json ├── revoked_private_key.p8 ├── update.go ├── update_test.go ├── validate.go └── validate_test.go /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v$RESOLVED_VERSION' 2 | tag-template: 'v$RESOLVED_VERSION' 3 | template: | 4 | ## Changes 5 | $CHANGES 6 | version-resolver: 7 | major: 8 | labels: 9 | - 'major' 10 | minor: 11 | labels: 12 | - 'minor' 13 | patch: 14 | labels: 15 | - 'patch' 16 | default: patch 17 | -------------------------------------------------------------------------------- /.github/workflows/draft.yaml: -------------------------------------------------------------------------------- 1 | name: Draft 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | 10 | draft: 11 | name: Draft 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Draft Release 15 | uses: release-drafter/release-drafter@v5 16 | env: 17 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 18 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | 13 | go-versions: 14 | name: Fetch Go versions 15 | runs-on: ubuntu-latest 16 | outputs: 17 | versions: ${{ steps.versions.outputs.value }} 18 | steps: 19 | - name: Fetch versions 20 | id: versions 21 | run: | 22 | versions=$(curl -s 
'https://go.dev/dl/?mode=json' | jq -c 'map(.version[2:])') 23 | echo "value=${versions}" >> $GITHUB_OUTPUT 24 | 25 | lint: 26 | name: Lint 27 | needs: 28 | - go-versions 29 | strategy: 30 | matrix: 31 | go-version: ${{ fromJson(needs.go-versions.outputs.versions) }} 32 | runs-on: ubuntu-latest 33 | steps: 34 | - name: Checkout 35 | uses: actions/checkout@v4 36 | - name: Setup Go 37 | uses: actions/setup-go@v4 38 | with: 39 | go-version: ${{ matrix.go-version }} 40 | - name: Run golangci-lint 41 | uses: golangci/golangci-lint-action@v3 42 | with: 43 | version: v1.54 44 | - name: Run gosec 45 | run: go run github.com/securego/gosec/v2/cmd/gosec@latest ./... 46 | - name: Run govulncheck 47 | run: go run golang.org/x/vuln/cmd/govulncheck@latest ./... 48 | 49 | test: 50 | name: Test 51 | needs: 52 | - go-versions 53 | strategy: 54 | matrix: 55 | go-version: ${{ fromJson(needs.go-versions.outputs.versions) }} 56 | os: [ubuntu-latest, macos-latest] 57 | runs-on: ${{ matrix.os }} 58 | steps: 59 | - name: Checkout 60 | uses: actions/checkout@v4 61 | - name: Setup Go 62 | uses: actions/setup-go@v4 63 | with: 64 | go-version: ${{ matrix.go-version }} 65 | - name: Run test 66 | run: go test ./... 67 | 68 | coverage: 69 | name: Coverage 70 | runs-on: ubuntu-latest 71 | steps: 72 | - name: Checkout 73 | uses: actions/checkout@v4 74 | - name: Setup Go 75 | uses: actions/setup-go@v4 76 | with: 77 | go-version-file: ./go.mod 78 | - name: Generate coverage 79 | run: go test ./... 
-cover -coverprofile coverage.out -covermode atomic 80 | - name: Upload coverage 81 | uses: codecov/codecov-action@v3 82 | with: 83 | file: ./coverage.out 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinchsan/device-check-go/57eb89b7b335243b3857814bfddabe22be0d04b3/.gitignore -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | timeout: 1m 3 | tests: false 4 | 5 | linters-settings: 6 | tagliatelle: 7 | case: 8 | use-field-name: true 9 | rules: 10 | json: snake 11 | 12 | linters: 13 | enable-all: true 14 | disable: 15 | - ireturn 16 | - errchkjson 17 | - depguard 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Masaya Hayashi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # device-check-go 2 | 3 | ![](https://github.com/rinchsan/device-check-go/actions/workflows/test.yaml/badge.svg) 4 | ![](https://img.shields.io/github/release/rinchsan/device-check-go.svg?colorB=7E7E7E) 5 | [![](https://pkg.go.dev/badge/github.com/rinchsan/device-check-go.svg)](https://pkg.go.dev/github.com/rinchsan/device-check-go) 6 | [![](https://codecov.io/github/rinchsan/device-check-go/coverage.svg?branch=main)](https://codecov.io/github/rinchsan/device-check-go?branch=main) 7 | [![](https://goreportcard.com/badge/github.com/rinchsan/device-check-go)](https://goreportcard.com/report/github.com/rinchsan/device-check-go) 8 | [![](https://awesome.re/mentioned-badge.svg)](https://awesome-go.com/#third-party-apis) 9 | [![](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](LICENSE) 10 | 11 | :iphone: iOS DeviceCheck SDK for Go - query and modify the per-device bits 12 | 13 | ## Installation 14 | 15 | ```bash 16 | go get github.com/rinchsan/device-check-go/v2 17 | ``` 18 | 19 | ## Getting started 20 | 21 | ### Initialize SDK 22 | 23 | ```go 24 | import devicecheck "github.com/rinchsan/device-check-go/v2" 25 | 26 | cred := devicecheck.NewCredentialFile("/path/to/private/key/file") // or NewCredentialString / NewCredentialBytes for a raw key 27 | cfg := devicecheck.NewConfig("ISSUER", "KEY_ID", devicecheck.Development) 28 | client := devicecheck.New(cred, cfg) 29 | ``` 30 | 31 | ### Use DeviceCheck API 32 | 33 | #### Query two bits 34 | 35 | ```go 36 | var result devicecheck.QueryTwoBitsResult 37 | if err :=
client.QueryTwoBits(ctx, "DEVICE_TOKEN", &result); err != nil { 38 | switch { 39 | // Note that QueryTwoBits returns ErrBitStateNotFound if no bit state is found for the device 40 | case errors.Is(err, devicecheck.ErrBitStateNotFound): 41 | // handle ErrBitStateNotFound error 42 | default: 43 | // handle other errors 44 | } 45 | } 46 | ``` 47 | 48 | #### Update two bits 49 | 50 | ```go 51 | if err := client.UpdateTwoBits(ctx, "DEVICE_TOKEN", true, true); err != nil { 52 | // handle errors 53 | } 54 | ``` 55 | 56 | #### Validate device token 57 | 58 | ```go 59 | if err := client.ValidateDeviceToken(ctx, "DEVICE_TOKEN"); err != nil { 60 | // handle errors 61 | } 62 | ``` 63 | 64 | ## Apple documentation 65 | 66 | - [iOS DeviceCheck API for Swift](https://developer.apple.com/documentation/devicecheck) 67 | - [HTTP commands to query and modify the per-device bits](https://developer.apple.com/documentation/devicecheck/accessing_and_modifying_per-device_data) 68 | -------------------------------------------------------------------------------- /api.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | ) 11 | 12 | const ( 13 | developmentBaseURL = "https://api.development.devicecheck.apple.com/v1" 14 | productionBaseURL = "https://api.devicecheck.apple.com/v1" 15 | ) 16 | 17 | func newBaseURL(env Environment) string { 18 | switch env { 19 | case Development: 20 | return developmentBaseURL 21 | case Production: 22 | return productionBaseURL 23 | default: 24 | return developmentBaseURL 25 | } 26 | } 27 | 28 | type api struct { 29 | client *http.Client 30 | baseURL string 31 | } 32 | 33 | func newAPI(env Environment) api { 34 | return api{ 35 | client: http.DefaultClient, 36 | baseURL: newBaseURL(env), 37 | } 38 | } 39 | 40 | func newAPIWithHTTPClient(client *http.Client, env Environment) api { 41 | return api{ 42 | client: client, 43 | baseURL:
newBaseURL(env), 44 | } 45 | } 46 | 47 | func (api api) do(ctx context.Context, jwt, path string, requestBody interface{}) (int, string, error) { 48 | buf := new(bytes.Buffer) 49 | if err := json.NewEncoder(buf).Encode(requestBody); err != nil { 50 | return 0, "", fmt.Errorf("json: %w", err) 51 | } 52 | 53 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, api.baseURL+path, buf) 54 | if err != nil { 55 | return 0, "", fmt.Errorf("http: %w", err) 56 | } 57 | 58 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt)) 59 | req.Header.Set("User-Agent", "device-check-go (+https://github.com/rinchsan/device-check-go)") 60 | 61 | resp, err := api.client.Do(req) 62 | if err != nil { 63 | var traceID string 64 | if resp != nil { 65 | traceID = resp.Header.Get("x-b3-traceid") 66 | } 67 | 68 | return 0, "", fmt.Errorf("http: %w: x-b3-traceid: %s", err, traceID) 69 | } 70 | defer resp.Body.Close() 71 | 72 | respBody, err := io.ReadAll(resp.Body) 73 | if err != nil { 74 | return 0, "", fmt.Errorf("io: %w", err) 75 | } 76 | 77 | return resp.StatusCode, string(respBody), nil 78 | } 79 | -------------------------------------------------------------------------------- /api_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func Test_newBaseURL(t *testing.T) { 11 | t.Parallel() 12 | 13 | cases := map[string]struct { 14 | env Environment 15 | want string 16 | }{ 17 | "development": { 18 | env: Development, 19 | want: "https://api.development.devicecheck.apple.com/v1", 20 | }, 21 | "production": { 22 | env: Production, 23 | want: "https://api.devicecheck.apple.com/v1", 24 | }, 25 | "unknown": { 26 | env: -1, 27 | want: "https://api.development.devicecheck.apple.com/v1", 28 | }, 29 | } 30 | 31 | for name, c := range cases { 32 | c := c 33 | t.Run(name, func(t *testing.T) { 34 | t.Parallel() 35 | 36 | got := 
newBaseURL(c.env) 37 | 38 | if !reflect.DeepEqual(got, c.want) { 39 | t.Errorf("want '%+v', got '%+v'", c.want, got) 40 | } 41 | }) 42 | } 43 | } 44 | 45 | func Test_newAPI(t *testing.T) { 46 | t.Parallel() 47 | 48 | cases := map[string]struct { 49 | env Environment 50 | want api 51 | }{ 52 | "development": { 53 | env: Development, 54 | want: api{ 55 | client: http.DefaultClient, 56 | baseURL: "https://api.development.devicecheck.apple.com/v1", 57 | }, 58 | }, 59 | "production": { 60 | env: Production, 61 | want: api{ 62 | client: http.DefaultClient, 63 | baseURL: "https://api.devicecheck.apple.com/v1", 64 | }, 65 | }, 66 | "unknown environment": { 67 | env: -1, 68 | want: api{ 69 | client: http.DefaultClient, 70 | baseURL: "https://api.development.devicecheck.apple.com/v1", 71 | }, 72 | }, 73 | } 74 | 75 | for name, c := range cases { 76 | c := c 77 | t.Run(name, func(t *testing.T) { 78 | t.Parallel() 79 | 80 | got := newAPI(c.env) 81 | 82 | if !reflect.DeepEqual(got, c.want) { 83 | t.Errorf("want '%+v', got '%+v'", c.want, got) 84 | } 85 | }) 86 | } 87 | } 88 | 89 | func Test_newAPIWithHTTPClient(t *testing.T) { 90 | t.Parallel() 91 | 92 | client := new(http.Client) 93 | cases := map[string]struct { 94 | client *http.Client 95 | env Environment 96 | want api 97 | }{ 98 | "development": { 99 | client: client, 100 | env: Development, 101 | want: api{ 102 | client: client, 103 | baseURL: "https://api.development.devicecheck.apple.com/v1", 104 | }, 105 | }, 106 | "production": { 107 | client: client, 108 | env: Production, 109 | want: api{ 110 | client: client, 111 | baseURL: "https://api.devicecheck.apple.com/v1", 112 | }, 113 | }, 114 | "unknown environment": { 115 | client: client, 116 | env: -1, 117 | want: api{ 118 | client: client, 119 | baseURL: "https://api.development.devicecheck.apple.com/v1", 120 | }, 121 | }, 122 | } 123 | 124 | for name, c := range cases { 125 | c := c 126 | t.Run(name, func(t *testing.T) { 127 | t.Parallel() 128 | 129 | got := 
newAPIWithHTTPClient(c.client, c.env) 130 | 131 | if !reflect.DeepEqual(got, c.want) { 132 | t.Errorf("want '%+v', got '%+v'", c.want, got) 133 | } 134 | }) 135 | } 136 | } 137 | 138 | func TestAPI_do(t *testing.T) { 139 | t.Parallel() 140 | 141 | cases := map[string]struct { 142 | baseURL string 143 | path string 144 | body interface{} 145 | noErr bool 146 | }{ 147 | "empty body": { 148 | baseURL: "http://example.com", 149 | path: "/", 150 | body: nil, 151 | noErr: true, 152 | }, 153 | "invalid url": { 154 | baseURL: "invalid url", 155 | path: "/", 156 | body: nil, 157 | noErr: false, 158 | }, 159 | "invalid path": { 160 | baseURL: "http://example.com", 161 | path: "invalid path", 162 | body: nil, 163 | noErr: false, 164 | }, 165 | "invalid body": { 166 | baseURL: "http://example.com", 167 | path: "/", 168 | body: func() {}, 169 | noErr: false, 170 | }, 171 | } 172 | 173 | for name, c := range cases { 174 | c := c 175 | t.Run(name, func(t *testing.T) { 176 | t.Parallel() 177 | 178 | api := api{ 179 | client: http.DefaultClient, 180 | baseURL: c.baseURL, 181 | } 182 | code, body, err := api.do(context.Background(), "jwt", c.path, c.body) 183 | 184 | if c.noErr { 185 | if err != nil { 186 | t.Errorf("want 'nil', got '%+v'", err) 187 | } 188 | if code != http.StatusOK { 189 | t.Errorf("want '200', got '%d'", code) 190 | } 191 | if len(body) == 0 { 192 | t.Error("want non-empty body, got empty") 193 | } 194 | } else { 195 | if err == nil { 196 | t.Error("want 'not nil', got 'nil'") 197 | } 198 | if code == http.StatusOK { 199 | t.Errorf("want 'not 200', got '200'") 200 | } 201 | if len(body) != 0 { 202 | t.Error("want empty body, got non-empty") 203 | } 204 | } 205 | }) 206 | } 207 | } 208 | 209 | type roundTripFunc func(req *http.Request) *http.Response 210 | 211 | func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { 212 | return f(req), nil 213 | } 214 | 215 | func newMockHTTPClient(resp *http.Response) *http.Client { 216 | return 
&http.Client{ 217 | Transport: roundTripFunc(func(r *http.Request) *http.Response { 218 | return resp 219 | }), 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import "net/http" 4 | 5 | // Client provides methods to use DeviceCheck API. 6 | type Client struct { 7 | api api 8 | cred Credential 9 | jwt jwt 10 | } 11 | 12 | // New returns a new DeviceCheck API client. 13 | func New(cred Credential, cfg Config) *Client { 14 | return &Client{ 15 | api: newAPI(cfg.env), 16 | cred: cred, 17 | jwt: newJWT(cfg.issuer, cfg.keyID), 18 | } 19 | } 20 | 21 | // NewWithHTTPClient returns a new DeviceCheck API client with specified http client. 22 | func NewWithHTTPClient(httpClient *http.Client, cred Credential, cfg Config) *Client { 23 | return &Client{ 24 | api: newAPIWithHTTPClient(httpClient, cfg.env), 25 | cred: cred, 26 | jwt: newJWT(cfg.issuer, cfg.keyID), 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "net/http" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestNew(t *testing.T) { 10 | t.Parallel() 11 | 12 | cases := map[string]struct { 13 | cred Credential 14 | cfg Config 15 | want *Client 16 | }{ 17 | "development": { 18 | cred: NewCredentialFile("revoked_private_key.p8"), 19 | cfg: NewConfig("issuer", "keyID", Development), 20 | want: &Client{ 21 | api: api{ 22 | client: http.DefaultClient, 23 | baseURL: "https://api.development.devicecheck.apple.com/v1", 24 | }, 25 | cred: credentialFile{ 26 | filename: "revoked_private_key.p8", 27 | }, 28 | jwt: jwt{ 29 | issuer: "issuer", 30 | keyID: "keyID", 31 | }, 32 | }, 33 | }, 34 | "production": { 35 | cred: NewCredentialFile("revoked_private_key.p8"), 36 | 
cfg: NewConfig("issuer", "keyID", Production), 37 | want: &Client{ 38 | api: api{ 39 | client: http.DefaultClient, 40 | baseURL: "https://api.devicecheck.apple.com/v1", 41 | }, 42 | cred: credentialFile{ 43 | filename: "revoked_private_key.p8", 44 | }, 45 | jwt: jwt{ 46 | issuer: "issuer", 47 | keyID: "keyID", 48 | }, 49 | }, 50 | }, 51 | } 52 | 53 | for name, c := range cases { 54 | c := c 55 | t.Run(name, func(t *testing.T) { 56 | t.Parallel() 57 | 58 | got := New(c.cred, c.cfg) 59 | 60 | if !reflect.DeepEqual(got, c.want) { 61 | t.Errorf("want '%+v', got '%+v'", c.want, got) 62 | } 63 | }) 64 | } 65 | } 66 | 67 | func TestNewWithHTTPClient(t *testing.T) { 68 | t.Parallel() 69 | 70 | client := new(http.Client) 71 | cases := map[string]struct { 72 | client *http.Client 73 | cred Credential 74 | cfg Config 75 | want *Client 76 | }{ 77 | "development": { 78 | client: client, 79 | cred: NewCredentialFile("revoked_private_key.p8"), 80 | cfg: NewConfig("issuer", "keyID", Development), 81 | want: &Client{ 82 | api: api{ 83 | client: client, 84 | baseURL: "https://api.development.devicecheck.apple.com/v1", 85 | }, 86 | cred: credentialFile{ 87 | filename: "revoked_private_key.p8", 88 | }, 89 | jwt: jwt{ 90 | issuer: "issuer", 91 | keyID: "keyID", 92 | }, 93 | }, 94 | }, 95 | "production": { 96 | client: client, 97 | cred: NewCredentialFile("revoked_private_key.p8"), 98 | cfg: NewConfig("issuer", "keyID", Production), 99 | want: &Client{ 100 | api: api{ 101 | client: client, 102 | baseURL: "https://api.devicecheck.apple.com/v1", 103 | }, 104 | cred: credentialFile{ 105 | filename: "revoked_private_key.p8", 106 | }, 107 | jwt: jwt{ 108 | issuer: "issuer", 109 | keyID: "keyID", 110 | }, 111 | }, 112 | }, 113 | } 114 | 115 | for name, c := range cases { 116 | c := c 117 | t.Run(name, func(t *testing.T) { 118 | t.Parallel() 119 | 120 | got := NewWithHTTPClient(c.client, c.cred, c.cfg) 121 | 122 | if !reflect.DeepEqual(got, c.want) { 123 | t.Errorf("want '%+v', got '%+v'", 
c.want, got) 124 | } 125 | }) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: 90% 6 | patch: off 7 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | // Environment specifies DeviceCheck API environment. 4 | type Environment int 5 | 6 | const ( 7 | // Development specifies Apple's development environment. 8 | Development Environment = iota + 1 9 | // Production specifies Apple's production environment. 10 | Production 11 | ) 12 | 13 | // Config provides configuration for DeviceCheck API. 14 | type Config struct { 15 | env Environment 16 | issuer string 17 | keyID string 18 | } 19 | 20 | // NewConfig returns a new configuration. 
21 | func NewConfig(issuer, keyID string, env Environment) Config { 22 | return Config{ 23 | env: env, 24 | issuer: issuer, 25 | keyID: keyID, 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestNewConfig(t *testing.T) { 9 | t.Parallel() 10 | 11 | cases := map[string]struct { 12 | issuer string 13 | keyID string 14 | env Environment 15 | want Config 16 | }{ 17 | "development": { 18 | issuer: "issuer", 19 | keyID: "keyID", 20 | env: Development, 21 | want: Config{ 22 | env: Development, 23 | issuer: "issuer", 24 | keyID: "keyID", 25 | }, 26 | }, 27 | "production": { 28 | issuer: "issuer", 29 | keyID: "keyID", 30 | env: Production, 31 | want: Config{ 32 | env: Production, 33 | issuer: "issuer", 34 | keyID: "keyID", 35 | }, 36 | }, 37 | } 38 | 39 | for name, c := range cases { 40 | c := c 41 | t.Run(name, func(t *testing.T) { 42 | t.Parallel() 43 | 44 | got := NewConfig(c.issuer, c.keyID, c.env) 45 | 46 | if !reflect.DeepEqual(got, c.want) { 47 | t.Errorf("want '%+v', got '%+v'", c.want, got) 48 | } 49 | }) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /credential.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/dvsekhvalnov/jose2go/keys/ecc" 9 | ) 10 | 11 | // Credential provides credential for DeviceCheck API. 12 | type Credential interface { 13 | key() (*ecdsa.PrivateKey, error) 14 | } 15 | 16 | type credentialFile struct { 17 | filename string 18 | } 19 | 20 | // NewCredentialFile returns credential from private key file. 
21 | func NewCredentialFile(filename string) Credential { 22 | return credentialFile{ 23 | filename: filename, 24 | } 25 | } 26 | 27 | func (cred credentialFile) key() (*ecdsa.PrivateKey, error) { 28 | raw, err := os.ReadFile(cred.filename) 29 | if err != nil { 30 | return nil, fmt.Errorf("os: %w", err) 31 | } 32 | 33 | key, err := ecc.ReadPrivate(raw) 34 | if err != nil { 35 | return nil, fmt.Errorf("ecc: %w", err) 36 | } 37 | 38 | return key, nil 39 | } 40 | 41 | type credentialBytes struct { 42 | raw []byte 43 | } 44 | 45 | // NewCredentialBytes returns credential from private key bytes. 46 | func NewCredentialBytes(raw []byte) Credential { 47 | return credentialBytes{ 48 | raw: raw, 49 | } 50 | } 51 | 52 | func (cred credentialBytes) key() (*ecdsa.PrivateKey, error) { 53 | key, err := ecc.ReadPrivate(cred.raw) 54 | if err != nil { 55 | return nil, fmt.Errorf("ecc: %w", err) 56 | } 57 | 58 | return key, nil 59 | } 60 | 61 | type credentialString struct { 62 | str string 63 | } 64 | 65 | // NewCredentialString returns credential from private key string. 
66 | func NewCredentialString(str string) Credential { 67 | return credentialString{ 68 | str: str, 69 | } 70 | } 71 | 72 | func (cred credentialString) key() (*ecdsa.PrivateKey, error) { 73 | key, err := ecc.ReadPrivate([]byte(cred.str)) 74 | if err != nil { 75 | return nil, fmt.Errorf("ecc: %w", err) 76 | } 77 | 78 | return key, nil 79 | } 80 | -------------------------------------------------------------------------------- /credential_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "os" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestNewCredentialFile(t *testing.T) { 10 | t.Parallel() 11 | 12 | cases := map[string]struct { 13 | filename string 14 | want credentialFile 15 | }{ 16 | "valid filename": { 17 | filename: "revoked_private_key.p8", 18 | want: credentialFile{ 19 | filename: "revoked_private_key.p8", 20 | }, 21 | }, 22 | } 23 | 24 | for name, c := range cases { 25 | c := c 26 | t.Run(name, func(t *testing.T) { 27 | t.Parallel() 28 | 29 | got := NewCredentialFile(c.filename) 30 | 31 | if !reflect.DeepEqual(got, c.want) { 32 | t.Errorf("want '%+v', got '%+v'", c.want, got) 33 | } 34 | }) 35 | } 36 | } 37 | 38 | func TestCredentialFile_key(t *testing.T) { 39 | t.Parallel() 40 | 41 | cases := map[string]struct { 42 | cred credentialFile 43 | noErr bool 44 | }{ 45 | "valid credential": { 46 | cred: credentialFile{ 47 | filename: "revoked_private_key.p8", 48 | }, 49 | noErr: true, 50 | }, 51 | "invalid credential": { 52 | cred: credentialFile{ 53 | filename: "credential_test.go", 54 | }, 55 | noErr: false, 56 | }, 57 | "unknown filename": { 58 | cred: credentialFile{ 59 | filename: "unknown_file.p8", 60 | }, 61 | noErr: false, 62 | }, 63 | } 64 | 65 | for name, c := range cases { 66 | c := c 67 | t.Run(name, func(t *testing.T) { 68 | t.Parallel() 69 | 70 | key, err := c.cred.key() 71 | 72 | if c.noErr { 73 | if err != nil { 74 | t.Errorf("want 'nil', got '%+v'", err) 75 | 
} 76 | if key == nil { 77 | t.Error("want 'not nil', got 'nil'") 78 | } 79 | } else { 80 | if err == nil { 81 | t.Error("want 'not nil', got 'nil'") 82 | } 83 | if key != nil { 84 | t.Errorf("want 'nil', got '%+v'", key) 85 | } 86 | } 87 | }) 88 | } 89 | } 90 | 91 | func TestNewCredentialBytes(t *testing.T) { 92 | t.Parallel() 93 | 94 | cases := map[string]struct { 95 | filename string 96 | }{ 97 | "valid filename": { 98 | filename: "revoked_private_key.p8", 99 | }, 100 | } 101 | 102 | for name, c := range cases { 103 | c := c 104 | t.Run(name, func(t *testing.T) { 105 | t.Parallel() 106 | 107 | raw, err := os.ReadFile(c.filename) 108 | if err != nil { 109 | t.Fatalf("os.ReadFile(%q): %v", c.filename, err) // fixture is required; stop instead of asserting on nil bytes 110 | } 111 | 112 | got := NewCredentialBytes(raw) 113 | want := credentialBytes{raw: raw} 114 | 115 | if !reflect.DeepEqual(got, want) { 116 | t.Errorf("want '%+v', got '%+v'", want, got) 117 | } 118 | }) 119 | } 120 | } 121 | 122 | func TestCredentialBytes_key(t *testing.T) { 123 | t.Parallel() 124 | 125 | cases := map[string]struct { 126 | filename string 127 | noErr bool 128 | }{ 129 | "valid filename": { 130 | filename: "revoked_private_key.p8", 131 | noErr: true, 132 | }, 133 | "invalid private key": { 134 | filename: "invalid_private_key.p8", 135 | noErr: false, 136 | }, 137 | } 138 | 139 | for name, c := range cases { 140 | c := c 141 | t.Run(name, func(t *testing.T) { 142 | t.Parallel() 143 | 144 | raw, err := os.ReadFile(c.filename) 145 | if err != nil { 146 | t.Fatalf("os.ReadFile(%q): %v", c.filename, err) // fixture is required; key() on nil bytes would mask the real cause 147 | } 148 | 149 | cred := NewCredentialBytes(raw) 150 | key, err := cred.key() 151 | 152 | if c.noErr { 153 | if err != nil { 154 | t.Errorf("want 'nil', got '%+v'", err) 155 | } 156 | if key == nil { 157 | t.Error("want 'not nil', got 'nil'") 158 | } 159 | } else { 160 | if err == nil { 161 | t.Error("want 'not nil', got 'nil'") 162 | } 163 | if key != nil { 164 | t.Errorf("want 'nil', got '%+v'", key) 165 | } 166 | } 167 | }) 168 | } 169 | } 170 | 171 | func
TestNewCredentialString(t *testing.T) { 172 | t.Parallel() 173 | 174 | cases := map[string]struct { 175 | filename string 176 | }{ 177 | "valid filename": { 178 | filename: "revoked_private_key.p8", 179 | }, 180 | } 181 | 182 | for name, c := range cases { 183 | c := c 184 | t.Run(name, func(t *testing.T) { 185 | t.Parallel() 186 | 187 | raw, err := os.ReadFile(c.filename) 188 | if err != nil { 189 | t.Fatalf("os.ReadFile(%q): %v", c.filename, err) // fixture is required; stop instead of asserting on an empty string 190 | } 191 | 192 | got := NewCredentialString(string(raw)) 193 | want := credentialString{str: string(raw)} 194 | 195 | if !reflect.DeepEqual(got, want) { 196 | t.Errorf("want '%+v', got '%+v'", want, got) 197 | } 198 | }) 199 | } 200 | } 201 | 202 | func TestCredentialString_key(t *testing.T) { 203 | t.Parallel() 204 | 205 | cases := map[string]struct { 206 | filename string 207 | noErr bool 208 | }{ 209 | "valid credential": { 210 | filename: "revoked_private_key.p8", 211 | noErr: true, 212 | }, 213 | "invalid credential": { 214 | filename: "credential_test.go", 215 | noErr: false, 216 | }, 217 | } 218 | 219 | for name, c := range cases { 220 | c := c 221 | t.Run(name, func(t *testing.T) { 222 | t.Parallel() 223 | 224 | raw, err := os.ReadFile(c.filename) 225 | if err != nil { 226 | t.Fatalf("os.ReadFile(%q): %v", c.filename, err) // fixture is required; key() on an empty string would mask the real cause 227 | } 228 | 229 | cred := NewCredentialString(string(raw)) 230 | key, err := cred.key() 231 | 232 | if c.noErr { 233 | if err != nil { 234 | t.Errorf("want 'nil', got '%+v'", err) 235 | } 236 | if key == nil { 237 | t.Error("want 'not nil', got 'nil'") 238 | } 239 | } else { 240 | if err == nil { 241 | t.Error("want 'not nil', got 'nil'") 242 | } 243 | if key != nil { 244 | t.Errorf("want 'nil', got '%+v'", key) 245 | } 246 | } 247 | }) 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /devicecheck_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck_test 2 | 3 | import ( 4 | "context" 5 |
"errors" 6 | "testing" 7 | 8 | devicecheck "github.com/rinchsan/device-check-go/v2" 9 | ) 10 | 11 | func Test(t *testing.T) { 12 | t.Parallel() 13 | 14 | cred := devicecheck.NewCredentialFile("revoked_private_key.p8") 15 | cfg := devicecheck.NewConfig("ISSUER", "KEY_ID", devicecheck.Development) 16 | client := devicecheck.New(cred, cfg) 17 | 18 | err := client.ValidateDeviceToken(context.Background(), "token") 19 | 20 | if !errors.Is(err, devicecheck.ErrUnauthorized) { 21 | t.Error("want 'devicecheck.ErrUnauthorized'") 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "strings" 8 | ) 9 | 10 | const ( 11 | bitStateNotFoundStr = "Failed to find bit state" 12 | ) 13 | 14 | var ( 15 | ErrBadRequest = errors.New("bad request") 16 | ErrUnauthorized = errors.New("invalid or expired token") 17 | ErrForbidden = errors.New("action not allowed") 18 | ErrMethodNotAllowed = errors.New("method not allowed") 19 | ErrTooManyRequests = errors.New("too many requests") 20 | ErrServer = errors.New("server error") 21 | ErrServiceUnavailable = errors.New("service unavailable") 22 | ErrUnknown = errors.New("unknown error") 23 | ErrBitStateNotFound = errors.New("bit state not found") 24 | ) 25 | 26 | func isErrBitStateNotFound(body string) bool { 27 | return strings.Contains(body, bitStateNotFoundStr) 28 | } 29 | 30 | func newError(code int, body string) error { 31 | switch code { 32 | case http.StatusBadRequest: 33 | return fmt.Errorf("%w: %s", ErrBadRequest, body) 34 | case http.StatusUnauthorized: 35 | return fmt.Errorf("%w: %s", ErrUnauthorized, body) 36 | case http.StatusForbidden: 37 | return fmt.Errorf("%w: %s", ErrForbidden, body) 38 | case http.StatusMethodNotAllowed: 39 | return fmt.Errorf("%w: %s", ErrMethodNotAllowed, body) 40 | case 
http.StatusTooManyRequests: 41 | return fmt.Errorf("%w: %s", ErrTooManyRequests, body) 42 | case http.StatusInternalServerError: 43 | return fmt.Errorf("%w: %s", ErrServer, body) 44 | case http.StatusServiceUnavailable: 45 | return fmt.Errorf("%w: %s", ErrServiceUnavailable, body) 46 | default: 47 | return fmt.Errorf("%w: %s", ErrUnknown, body) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /error_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func Test_isErrBitStateNotFound(t *testing.T) { 11 | t.Parallel() 12 | 13 | cases := map[string]struct { 14 | body string 15 | want bool 16 | }{ 17 | "is ErrBitStateNotFound": { 18 | body: "Failed to find bit state", 19 | want: true, 20 | }, 21 | "is not ErrBitStateNotFound": { 22 | body: "Missing or incorrectly formatted bits", 23 | want: false, 24 | }, 25 | } 26 | 27 | for name, c := range cases { 28 | c := c 29 | t.Run(name, func(t *testing.T) { 30 | t.Parallel() 31 | 32 | got := isErrBitStateNotFound(c.body) 33 | 34 | if !reflect.DeepEqual(got, c.want) { 35 | t.Errorf("want '%+v', got '%+v'", c.want, got) 36 | } 37 | }) 38 | } 39 | } 40 | 41 | func Test_newError(t *testing.T) { 42 | t.Parallel() 43 | 44 | cases := map[string]struct { 45 | code int 46 | want error 47 | }{ 48 | "bad request": { 49 | code: http.StatusBadRequest, 50 | want: ErrBadRequest, 51 | }, 52 | "unauthorized": { 53 | code: http.StatusUnauthorized, 54 | want: ErrUnauthorized, 55 | }, 56 | "forbidden": { 57 | code: http.StatusForbidden, 58 | want: ErrForbidden, 59 | }, 60 | "method not allowed": { 61 | code: http.StatusMethodNotAllowed, 62 | want: ErrMethodNotAllowed, 63 | }, 64 | "too many requests": { 65 | code: http.StatusTooManyRequests, 66 | want: ErrTooManyRequests, 67 | }, 68 | "server error": { 69 | code: http.StatusInternalServerError, 70 | 
want: ErrServer, 71 | }, 72 | "service unavailable": { 73 | code: http.StatusServiceUnavailable, 74 | want: ErrServiceUnavailable, 75 | }, 76 | "unknown": { 77 | code: http.StatusBadGateway, 78 | want: ErrUnknown, 79 | }, 80 | } 81 | 82 | for name, c := range cases { 83 | c := c 84 | t.Run(name, func(t *testing.T) { 85 | t.Parallel() 86 | 87 | got := newError(c.code, "body") 88 | 89 | if !errors.Is(got, c.want) { 90 | t.Error("got error does not wrap expected error") 91 | } 92 | }) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/rinchsan/device-check-go/v2 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/dvsekhvalnov/jose2go v1.5.0 7 | github.com/google/uuid v1.3.1 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM= 2 | github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= 3 | github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= 4 | github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 5 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 6 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 7 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 8 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= 9 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= 10 | gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= 11 | gopkg.in/check.v1 
v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 12 | -------------------------------------------------------------------------------- /invalid_private_key.p8: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | invalid 3 | -----END EC PRIVATE KEY----- 4 | -------------------------------------------------------------------------------- /jwt.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "encoding/json" 6 | "fmt" 7 | "time" 8 | 9 | jose "github.com/dvsekhvalnov/jose2go" 10 | ) 11 | 12 | type jwt struct { 13 | issuer string 14 | keyID string 15 | } 16 | 17 | func newJWT(issuer, keyID string) jwt { 18 | return jwt{ 19 | issuer: issuer, 20 | keyID: keyID, 21 | } 22 | } 23 | 24 | func (jwt jwt) generate(key *ecdsa.PrivateKey) (string, error) { 25 | claims := map[string]interface{}{ 26 | "iss": jwt.issuer, 27 | "iat": time.Now().UTC().Unix(), 28 | } 29 | 30 | // Ignoring error, because json.Marshal never fails. 
31 | claimsJSON, _ := json.Marshal(claims) 32 | 33 | headers := map[string]interface{}{ 34 | "alg": jose.ES256, 35 | "kid": jwt.keyID, 36 | } 37 | 38 | token, err := jose.Sign(string(claimsJSON), jose.ES256, key, jose.Headers(headers)) 39 | if err != nil { 40 | return "", fmt.Errorf("jose: %w", err) 41 | } 42 | 43 | return token, nil 44 | } 45 | -------------------------------------------------------------------------------- /jwt_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestJWT_newJWT(t *testing.T) { 9 | t.Parallel() 10 | 11 | cases := map[string]struct { 12 | issuer string 13 | keyID string 14 | want jwt 15 | }{ 16 | "valid issuer/keyID": { 17 | issuer: "issuer", 18 | keyID: "keyID", 19 | want: jwt{ 20 | issuer: "issuer", 21 | keyID: "keyID", 22 | }, 23 | }, 24 | } 25 | 26 | for name, c := range cases { 27 | c := c 28 | t.Run(name, func(t *testing.T) { 29 | t.Parallel() 30 | 31 | got := newJWT(c.issuer, c.keyID) 32 | 33 | if !reflect.DeepEqual(got, c.want) { 34 | t.Errorf("want '%+v', got '%+v'", c.want, got) 35 | } 36 | }) 37 | } 38 | } 39 | 40 | func TestJWT_generate(t *testing.T) { 41 | t.Parallel() 42 | 43 | cases := map[string]struct { 44 | filename string 45 | jwt jwt 46 | }{ 47 | "invalid filename": { 48 | filename: "revoked_private_key.p8", 49 | jwt: jwt{ 50 | issuer: "issuer", 51 | keyID: "keyID", 52 | }, 53 | }, 54 | } 55 | 56 | for name, c := range cases { 57 | c := c 58 | t.Run(name, func(t *testing.T) { 59 | t.Parallel() 60 | 61 | cred := NewCredentialFile(c.filename) 62 | key, err := cred.key() 63 | 64 | if err != nil { 65 | t.Error("want 'nil', got 'not nil'") 66 | } 67 | if key == nil { 68 | t.Error("want 'not nil', got 'nil'") 69 | } 70 | 71 | token, err := c.jwt.generate(key) 72 | 73 | if err != nil { 74 | t.Errorf("want 'nil', got '%+v'", err) 75 | } 76 | if token == "" { 77 | t.Error("want 'not empty', got 
'empty'") 78 | } 79 | }) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "strings" 9 | "time" 10 | 11 | "github.com/google/uuid" 12 | ) 13 | 14 | const queryTwoBitsPath = "/query_two_bits" 15 | 16 | type queryTwoBitsRequestBody struct { 17 | DeviceToken string `json:"device_token"` 18 | TransactionID string `json:"transaction_id"` 19 | Timestamp int64 `json:"timestamp"` 20 | } 21 | 22 | // QueryTwoBitsResult provides a result of query-two-bits method. 23 | type QueryTwoBitsResult struct { 24 | Bit0 bool `json:"bit0"` 25 | Bit1 bool `json:"bit1"` 26 | LastUpdateTime Time `json:"last_update_time"` 27 | } 28 | 29 | type Time struct { 30 | time.Time 31 | } 32 | 33 | const timeFormat = "2006-01" 34 | 35 | func (t Time) MarshalJSON() ([]byte, error) { 36 | b, err := json.Marshal(t.Format(timeFormat)) 37 | if err != nil { 38 | return nil, fmt.Errorf("json: %w", err) 39 | } 40 | 41 | return b, nil 42 | } 43 | 44 | func (t *Time) UnmarshalJSON(b []byte) error { 45 | tm, err := time.Parse(timeFormat, strings.Trim(string(b), `"`)) 46 | if err != nil { 47 | return fmt.Errorf("time: %w", err) 48 | } 49 | 50 | t.Time = tm 51 | 52 | return nil 53 | } 54 | 55 | // QueryTwoBits queries two bits for device token. Returns ErrBitStateNotFound if the bits have not been set. 
56 | func (client *Client) QueryTwoBits(ctx context.Context, deviceToken string, result *QueryTwoBitsResult) error { 57 | key, err := client.cred.key() 58 | if err != nil { 59 | return fmt.Errorf("devicecheck: failed to create key: %w", err) 60 | } 61 | 62 | jwt, err := client.jwt.generate(key) 63 | if err != nil { 64 | return fmt.Errorf("devicecheck: failed to generate jwt: %w", err) 65 | } 66 | 67 | body := queryTwoBitsRequestBody{ 68 | DeviceToken: deviceToken, 69 | TransactionID: uuid.New().String(), 70 | Timestamp: time.Now().UTC().UnixNano() / int64(time.Millisecond), 71 | } 72 | 73 | code, respBody, err := client.api.do(ctx, jwt, queryTwoBitsPath, body) 74 | if err != nil { 75 | return fmt.Errorf("devicecheck: failed to query two bits: %w: %s", err, respBody) 76 | } 77 | 78 | if code != http.StatusOK { 79 | return fmt.Errorf("devicecheck: %w", newError(code, respBody)) 80 | } 81 | 82 | if isErrBitStateNotFound(respBody) { 83 | return fmt.Errorf("devicecheck: %w", ErrBitStateNotFound) 84 | } 85 | 86 | if err := json.NewDecoder(strings.NewReader(respBody)).Decode(result); err != nil { 87 | return fmt.Errorf("json: %w", err) 88 | } 89 | 90 | return nil 91 | } 92 | -------------------------------------------------------------------------------- /query_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "reflect" 9 | "strings" 10 | "testing" 11 | "testing/iotest" 12 | "time" 13 | ) 14 | 15 | func TestTime_MarshalJSON(t *testing.T) { 16 | t.Parallel() 17 | 18 | cases := map[string]struct { 19 | year int 20 | month time.Month 21 | want string 22 | }{ 23 | "2019-04": { 24 | year: 2019, 25 | month: time.April, 26 | want: `"2019-04"`, 27 | }, 28 | } 29 | 30 | for name, c := range cases { 31 | c := c 32 | t.Run(name, func(t *testing.T) { 33 | t.Parallel() 34 | 35 | tm := Time{Time: time.Date(c.year, c.month, 1, 0, 0, 0, 0, time.UTC)} 36 
| got, err := tm.MarshalJSON() 37 | 38 | if err != nil { 39 | t.Errorf("want 'nil', got '%+v'", err) 40 | } 41 | if !reflect.DeepEqual(string(got), c.want) { 42 | t.Errorf("want '%+v', got '%+v'", c.want, string(got)) 43 | } 44 | }) 45 | } 46 | } 47 | 48 | func TestTime_UnmarshalJSON(t *testing.T) { 49 | t.Parallel() 50 | 51 | cases := map[string]struct { 52 | b []byte 53 | noErr bool 54 | want Time 55 | }{ 56 | "2019-04": { 57 | b: []byte("2019-04"), 58 | noErr: true, 59 | want: Time{Time: time.Date(2019, time.April, 1, 0, 0, 0, 0, time.UTC)}, 60 | }, 61 | "invalid format": { 62 | b: []byte("2019-04-01"), 63 | noErr: false, 64 | want: Time{}, 65 | }, 66 | } 67 | 68 | for name, c := range cases { 69 | c := c 70 | t.Run(name, func(t *testing.T) { 71 | t.Parallel() 72 | 73 | var got Time 74 | err := got.UnmarshalJSON(c.b) 75 | 76 | if c.noErr { 77 | if err != nil { 78 | t.Errorf("want 'nil', got '%+v'", err) 79 | } 80 | } else { 81 | if err == nil { 82 | t.Error("want 'not nil', got 'nil'") 83 | } 84 | } 85 | if !reflect.DeepEqual(got, c.want) { 86 | t.Errorf("want '%+v', got '%+v'", c.want, got) 87 | } 88 | }) 89 | } 90 | } 91 | 92 | func TestClient_QueryTwoBits(t *testing.T) { 93 | t.Parallel() 94 | 95 | cases := map[string]struct { 96 | client Client 97 | noErr bool 98 | }{ 99 | "invalid key": { 100 | client: Client{ 101 | api: newAPI(Development), 102 | cred: NewCredentialFile("unknown_file.p8"), 103 | jwt: newJWT("issuer", "keyID"), 104 | }, 105 | noErr: false, 106 | }, 107 | "invalid url": { 108 | client: Client{ 109 | api: api{ 110 | client: new(http.Client), 111 | baseURL: "invalid url", 112 | }, 113 | cred: NewCredentialFile("revoked_private_key.p8"), 114 | jwt: newJWT("issuer", "keyID"), 115 | }, 116 | noErr: false, 117 | }, 118 | "invalid device token": { 119 | client: Client{ 120 | api: newAPI(Development), 121 | cred: NewCredentialFile("revoked_private_key.p8"), 122 | jwt: newJWT("issuer", "keyID"), 123 | }, 124 | noErr: false, 125 | }, 126 | "status ok 
with ErrBitStateNotFound": { 127 | client: Client{ 128 | api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{ 129 | StatusCode: http.StatusOK, 130 | Body: io.NopCloser(strings.NewReader("Failed to find bit state")), 131 | }), Development), 132 | cred: NewCredentialFile("revoked_private_key.p8"), 133 | jwt: newJWT("issuer", "keyID"), 134 | }, 135 | noErr: false, 136 | }, 137 | "status ok with valid response": { 138 | client: Client{ 139 | api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{ 140 | StatusCode: http.StatusOK, 141 | Body: io.NopCloser(strings.NewReader(`{"bit0":true,"bit1":false,"last_update_time":"2006-01"}`)), 142 | }), Development), 143 | cred: NewCredentialFile("revoked_private_key.p8"), 144 | jwt: newJWT("issuer", "keyID"), 145 | }, 146 | noErr: true, 147 | }, 148 | "status ok with invalid response": { 149 | client: Client{ 150 | api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{ 151 | StatusCode: http.StatusOK, 152 | Body: io.NopCloser(iotest.ErrReader(errors.New("io.Reader error"))), 153 | }), Development), 154 | cred: NewCredentialFile("revoked_private_key.p8"), 155 | jwt: newJWT("issuer", "keyID"), 156 | }, 157 | noErr: false, 158 | }, 159 | } 160 | 161 | for name, c := range cases { 162 | c := c 163 | t.Run(name, func(t *testing.T) { 164 | t.Parallel() 165 | 166 | var result QueryTwoBitsResult 167 | err := c.client.QueryTwoBits(context.Background(), "device_token", &result) 168 | 169 | if c.noErr { 170 | if err != nil { 171 | t.Errorf("want 'nil', got '%+v'", err) 172 | } 173 | } else { 174 | if err == nil { 175 | t.Error("want 'not nil', got 'nil'") 176 | } 177 | } 178 | }) 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "github>rinchsan/renovate-config" 4 | ], 5 | "ignorePresets": [ 6 | ":dependencyDashboard" 7 | ] 8 | } 9 | 
-------------------------------------------------------------------------------- /revoked_private_key.p8: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQg8U7V49VQAxublOj9 3 | JyxgONgRw/CoRe0fylMYeJDXdhWgCgYIKoZIzj0DAQehRANCAATvN8FjG+f8qgl9 4 | rmSTd+w5hmtg+JnwqGWuTgSp10nX/RNSX157oIVNEbI7eSgwTC33pAzgGhwy2nbU 5 | NqWsaLXE 6 | -----END EC PRIVATE KEY----- 7 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/google/uuid" 10 | ) 11 | 12 | const updateTwoBitsPath = "/update_two_bits" 13 | 14 | type updateTwoBitsRequestBody struct { 15 | DeviceToken string `json:"device_token"` 16 | TransactionID string `json:"transaction_id"` 17 | Timestamp int64 `json:"timestamp"` 18 | Bit0 bool `json:"bit0"` 19 | Bit1 bool `json:"bit1"` 20 | } 21 | 22 | // UpdateTwoBits updates two bits for device token. 
23 | func (client *Client) UpdateTwoBits(ctx context.Context, deviceToken string, bit0, bit1 bool) error { 24 | key, err := client.cred.key() 25 | if err != nil { 26 | return fmt.Errorf("devicecheck: failed to create key: %w", err) 27 | } 28 | 29 | jwt, err := client.jwt.generate(key) 30 | if err != nil { 31 | return fmt.Errorf("devicecheck: failed to generate jwt: %w", err) 32 | } 33 | 34 | body := updateTwoBitsRequestBody{ 35 | DeviceToken: deviceToken, 36 | TransactionID: uuid.New().String(), 37 | Timestamp: time.Now().UTC().UnixNano() / int64(time.Millisecond), 38 | Bit0: bit0, 39 | Bit1: bit1, 40 | } 41 | 42 | code, respBody, err := client.api.do(ctx, jwt, updateTwoBitsPath, body) 43 | if err != nil { 44 | return fmt.Errorf("devicecheck: failed to update two bits: %w: %s", err, respBody) 45 | } 46 | 47 | if code != http.StatusOK { 48 | return fmt.Errorf("devicecheck: %w", newError(code, respBody)) 49 | } 50 | 51 | return nil 52 | } 53 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestClient_UpdateTwoBits(t *testing.T) { 12 | t.Parallel() 13 | 14 | cases := map[string]struct { 15 | client Client 16 | noErr bool 17 | }{ 18 | "invalid key": { 19 | client: Client{ 20 | api: newAPI(Development), 21 | cred: NewCredentialFile("unknown_file.p8"), 22 | jwt: newJWT("issuer", "keyID"), 23 | }, 24 | noErr: false, 25 | }, 26 | "invalid url": { 27 | client: Client{ 28 | api: api{ 29 | client: new(http.Client), 30 | baseURL: "invalid url", 31 | }, 32 | cred: NewCredentialFile("revoked_private_key.p8"), 33 | jwt: newJWT("issuer", "keyID"), 34 | }, 35 | noErr: false, 36 | }, 37 | "invalid device token": { 38 | client: Client{ 39 | api: newAPI(Development), 40 | cred: NewCredentialFile("revoked_private_key.p8"), 41 | 
jwt: newJWT("issuer", "keyID"), 42 | }, 43 | noErr: false, 44 | }, 45 | "status ok": { 46 | client: Client{ 47 | api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{ 48 | StatusCode: http.StatusOK, 49 | Body: io.NopCloser(strings.NewReader("success")), 50 | }), Development), 51 | cred: NewCredentialFile("revoked_private_key.p8"), 52 | jwt: newJWT("issuer", "keyID"), 53 | }, 54 | noErr: true, 55 | }, 56 | } 57 | 58 | for name, c := range cases { 59 | c := c 60 | t.Run(name, func(t *testing.T) { 61 | t.Parallel() 62 | 63 | err := c.client.UpdateTwoBits(context.Background(), "device_token", true, true) 64 | 65 | if c.noErr { 66 | if err != nil { 67 | t.Errorf("want 'nil', got '%+v'", err) 68 | } 69 | } else { 70 | if err == nil { 71 | t.Error("want 'not nil', got 'nil'") 72 | } 73 | } 74 | }) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /validate.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/google/uuid" 10 | ) 11 | 12 | const validateDeviceTokenPath = "/validate_device_token" 13 | 14 | type validateDeviceTokenRequestBody struct { 15 | DeviceToken string `json:"device_token"` 16 | TransactionID string `json:"transaction_id"` 17 | Timestamp int64 `json:"timestamp"` 18 | } 19 | 20 | // ValidateDeviceToken validates a device for device token. 
21 | func (client *Client) ValidateDeviceToken(ctx context.Context, deviceToken string) error { 22 | key, err := client.cred.key() 23 | if err != nil { 24 | return fmt.Errorf("devicecheck: failed to create key: %w", err) 25 | } 26 | 27 | jwt, err := client.jwt.generate(key) 28 | if err != nil { 29 | return fmt.Errorf("devicecheck: failed to generate jwt: %w", err) 30 | } 31 | 32 | body := validateDeviceTokenRequestBody{ 33 | DeviceToken: deviceToken, 34 | TransactionID: uuid.New().String(), 35 | Timestamp: time.Now().UTC().UnixNano() / int64(time.Millisecond), 36 | } 37 | 38 | code, respBody, err := client.api.do(ctx, jwt, validateDeviceTokenPath, body) 39 | if err != nil { 40 | return fmt.Errorf("devicecheck: failed to validate device token: %w", err) 41 | } 42 | 43 | if code != http.StatusOK { 44 | return fmt.Errorf("devicecheck: %w", newError(code, respBody)) 45 | } 46 | 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /validate_test.go: -------------------------------------------------------------------------------- 1 | package devicecheck 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestClient_ValidateDeviceToken(t *testing.T) { 12 | t.Parallel() 13 | 14 | cases := map[string]struct { 15 | client Client 16 | noErr bool 17 | }{ 18 | "invalid key": { 19 | client: Client{ 20 | api: newAPI(Development), 21 | cred: NewCredentialFile("unknown_file.p8"), 22 | jwt: newJWT("issuer", "keyID"), 23 | }, 24 | noErr: false, 25 | }, 26 | "invalid url": { 27 | client: Client{ 28 | api: api{ 29 | client: new(http.Client), 30 | baseURL: "invalid url", 31 | }, 32 | cred: NewCredentialFile("revoked_private_key.p8"), 33 | jwt: newJWT("issuer", "keyID"), 34 | }, 35 | noErr: false, 36 | }, 37 | "invalid device token": { 38 | client: Client{ 39 | api: newAPI(Development), 40 | cred: NewCredentialFile("revoked_private_key.p8"), 41 | jwt: newJWT("issuer", "keyID"), 42 
| }, 43 | noErr: false, 44 | }, 45 | "status ok": { 46 | client: Client{ 47 | api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{ 48 | StatusCode: http.StatusOK, 49 | Body: io.NopCloser(strings.NewReader("success")), 50 | }), Development), 51 | cred: NewCredentialFile("revoked_private_key.p8"), 52 | jwt: newJWT("issuer", "keyID"), 53 | }, 54 | noErr: true, 55 | }, 56 | } 57 | 58 | for name, c := range cases { 59 | c := c 60 | t.Run(name, func(t *testing.T) { 61 | t.Parallel() 62 | 63 | err := c.client.ValidateDeviceToken(context.Background(), "device_token") 64 | 65 | if c.noErr { 66 | if err != nil { 67 | t.Errorf("want 'nil', got '%+v'", err) 68 | } 69 | } else { 70 | if err == nil { 71 | t.Error("want 'not nil', got 'nil'") 72 | } 73 | } 74 | }) 75 | } 76 | } 77 | --------------------------------------------------------------------------------