├── .gitignore
├── example
├── update-user.sh
├── refresh.sh
├── login.sh
├── create-password-reset-request.sh
├── create-user.sh
├── reset-password.sh
├── get-user.sh
├── delete-user.sh
├── urlencode.sh
└── docker-compose.yml
├── component-tests.Dockerfile
├── mail-templates
├── password-reset-request.yml
├── password-reset-request.txt
└── password-reset-request.html
├── internal
├── web
│ ├── internal.go
│ ├── middleware
│ │ ├── basicauth.go
│ │ └── basicauth_test.go
│ ├── internal_test.go
│ ├── server.go
│ ├── admin.go
│ ├── auth.go
│ ├── server_test.go
│ └── provider_moq_test.go
├── storage
│ ├── claims.go
│ ├── storage.go
│ ├── token.go
│ └── user.go
├── jwt
│ ├── validate.go
│ ├── provider.go
│ ├── generate.go
│ ├── validate_test.go
│ ├── provider_test.go
│ └── generate_test.go
├── provider.go
├── mailer
│ ├── template_moq_test.go
│ ├── template.go
│ ├── smtp.go
│ ├── dialer_moq_test.go
│ ├── send_closer_moq_test.go
│ ├── template_test.go
│ └── smtp_test.go
├── mailer_moq_test.go
├── admin.go
├── jwt_generator_moq_test.go
├── auth.go
├── storage_moq_test.go
└── admin_test.go
├── .github
└── workflows
│ ├── gosec.yml
│ ├── go.yml
│ ├── codecov.yml
│ └── codeql-analysis.yml
├── codecov.yml
├── docs
└── password-reset-flow.puml
├── cmd
└── provider
│ ├── alive_resource_component_test.go
│ ├── crud_user_component_test.go
│ ├── refresh_component_test.go
│ ├── main.go
│ ├── config.go
│ ├── reset_password_component_test.go
│ ├── helper_component_test.go
│ ├── login_component_test.go
│ └── config_test.go
├── Dockerfile
├── CHANGELOG.md
├── go.mod
├── component-tests.sh
├── LICENSE
├── component-tests.docker-compose.yml
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | ecdsa-p521-public.pem
3 | ecdsa-p521-private.pem
4 | coverage.out
--------------------------------------------------------------------------------
/example/update-user.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
if [ "$#" -ne "2" ]; then
  echo "Two arguments must be set e.g. ./update-user.sh email user(json)"
  exit 1
fi

# The email is used as a path segment of the URL and therefore has to be URL-encoded
# (the same way get-user.sh and delete-user.sh do it).
DIR="$(cd "$(dirname "$0")" && pwd)"
URLENCODED_EMAIL=$("$DIR"/urlencode.sh "$1")

curl -X PUT --data "$2" "username:password@localhost:8080/v1/admin/users/$URLENCODED_EMAIL" -v
9 |
--------------------------------------------------------------------------------
/example/refresh.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
# Exactly one argument (the refresh token) is required.
[ "$#" -eq 1 ] || {
  echo "One argument must be set e.g. ./refresh.sh refreshToken"
  exit 1
}

# Request body: {"refresh_token":"<token>"}
refresh_token="$1"
curl -X POST --data "{\"refresh_token\":\"${refresh_token}\"}" "username:password@localhost:8080/v1/auth/refresh" -v
9 |
--------------------------------------------------------------------------------
/example/login.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
# Exactly two arguments (email and password) are required.
[ "$#" -eq 2 ] || {
  echo "Two arguments must be set e.g. ./login.sh email password"
  exit 1
}

email="$1"
password="$2"
curl -X POST --data "{\"email\":\"${email}\", \"password\":\"${password}\"}" "username:password@localhost:8080/v1/auth/login" -v
9 |
--------------------------------------------------------------------------------
/example/create-password-reset-request.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
# Exactly one argument (the email of the user) is required.
[ "$#" -eq 1 ] || {
  echo "One argument must be set e.g. ./create-password-reset-request.sh email"
  exit 1
}

email="$1"
curl -X POST --data "{\"email\":\"${email}\"}" "localhost:8080/v1/auth/password-reset-request" -v
8 |
--------------------------------------------------------------------------------
/component-tests.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM golang:latest
2 |
3 | RUN mkdir -p /go/mods/tests
4 | WORKDIR /go/mods/tests
5 |
6 | COPY go.mod .
7 | COPY go.sum .
8 |
9 | RUN go mod download
10 |
11 | COPY . /go/mods/tests
12 |
13 | ENV RUN_TESTS=""
14 | CMD go test -v -count=1 -tags=component ${RUN_TESTS} ./cmd/provider
--------------------------------------------------------------------------------
/example/create-user.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
# Exactly three arguments (email, password, claims as JSON) are required.
[ "$#" -eq 3 ] || {
  echo "Three arguments must be set e.g. ./create-user.sh email password claim(json)"
  exit 1
}

email="$1"
password="$2"
claims_json="$3" # embedded verbatim, so it must already be valid JSON
curl -X POST --data "{\"email\":\"${email}\", \"password\":\"${password}\", \"claims\":${claims_json}}" "username:password@localhost:8080/v1/admin/users" -v
9 |
--------------------------------------------------------------------------------
/example/reset-password.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
# Exactly three arguments (email, new password, reset token) are required.
[ "$#" -eq 3 ] || {
  echo "Three arguments must be set e.g. ./reset-password.sh email password reset-token"
  exit 1
}

email="$1"
password="$2"
reset_token="$3"
curl -X POST --data "{\"email\":\"${email}\", \"password\":\"${password}\", \"reset_token\": \"${reset_token}\"}" "localhost:8080/v1/auth/password-reset" -v
9 |
--------------------------------------------------------------------------------
/example/get-user.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
# Exactly one argument (the email of the user) is required.
[ "$#" -eq 1 ] || {
  echo "One argument must be set e.g. ./get-user.sh email"
  exit 1
}

script_dir="$(cd "$(dirname "$0")" && pwd)"
# The email is part of the URL path and therefore has to be URL-encoded.
encoded_email="$("$script_dir"/urlencode.sh "$1")"

curl -X GET "username:password@localhost:8080/v1/admin/users/${encoded_email}" -v
12 |
--------------------------------------------------------------------------------
/mail-templates/password-reset-request.yml:
--------------------------------------------------------------------------------
1 | From:
2 | - "test@leberkleber.io"
3 | To:
4 | - "{{.Recipient}}"
5 | Subject:
6 | - "Password Reset"
7 | # Note: this file must be decodable into the type map[string][]string (header name -> values)
8 | # e.g.:
9 | # Bcc:
10 | # - "myBCC"
11 | # Reply-To:
12 | # - "dsd"
13 | # mail-headers could be set here (incl. go templating).
--------------------------------------------------------------------------------
/example/delete-user.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
# Exactly one argument (the email of the user) is required.
[ "$#" -eq 1 ] || {
  echo "One argument must be set e.g. ./delete-user.sh email"
  exit 1
}

script_dir="$(cd "$(dirname "$0")" && pwd)"
# The email is part of the URL path and therefore has to be URL-encoded.
encoded_email="$("$script_dir"/urlencode.sh "$1")"

curl -X DELETE "username:password@localhost:8080/v1/admin/users/${encoded_email}" -v
12 |
--------------------------------------------------------------------------------
/internal/web/internal.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "github.com/sirupsen/logrus"
5 | "net/http"
6 | )
7 |
8 | func (s *Server) aliveHandler(w http.ResponseWriter, _ *http.Request) {
9 | _, err := w.Write([]byte(`{"alive":true}`))
10 | if err != nil {
11 | logrus.WithError(err).Error("Failed to write alive response body")
12 | writeInternalServerError(w)
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/.github/workflows/gosec.yml:
--------------------------------------------------------------------------------
1 | name: gosec
2 | on: [ push ]
3 |
4 | jobs:
5 | tests:
6 | runs-on: ubuntu-latest
7 | env:
8 | GO111MODULE: on
9 | steps:
10 | - name: Checkout Source
11 | uses: actions/checkout@v2
12 | - name: Run Gosec Security Scanner
13 | uses: securego/gosec@master
14 | with:
15 | args: -exclude=G402,G601 ./...
16 |
--------------------------------------------------------------------------------
/mail-templates/password-reset-request.txt:
--------------------------------------------------------------------------------
1 | Dear {{.Recipient}},
2 | you forgot your password. No Problem!
3 |
4 | {{/* replace 'www.leberkleber.io/passwordReset' with your exposed endpoint */}}
5 | You can change it at 'http://www.leberkleber.io/passwordReset?token={{.PasswordResetToken}}'.
6 |
7 | ({{.PasswordResetToken}})
8 |
9 | {{if index .Claims "myCustomClaim"}} ({{index .Claims "myCustomClaim"}}) {{end}}
10 |
11 | Greetings
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | range: 80..90 # coverage lower than 80 is red, higher than 90 is green, values in between are color-graded
3 |
4 | status:
5 | project: # settings affecting project coverage
6 | enabled: yes
7 | target: auto # auto % coverage target
8 | threshold: 5% # allow for 5% reduction of coverage without failing
9 |
10 | # do not run coverage on patch nor changes
11 | patch: no
12 | changes: no
13 |
--------------------------------------------------------------------------------
/docs/password-reset-flow.puml:
--------------------------------------------------------------------------------
1 | @startuml
2 | actor User as u
3 | participant "simple-jwt-provider" as sjp
4 | participant "mail-server" as ms
5 | participant "mail-client" as mc
6 |
7 | u -> sjp: trigger password-reset request
8 | sjp->sjp: generate password-reset mail
9 | sjp->ms: send password-reset mail with reset token
10 | ms->mc: receive password-reset mail
11 | mc->u: receive password-reset mail and extract reset-token
12 |
13 | u->sjp: reset password with received token
14 |
15 | @enduml
16 |
--------------------------------------------------------------------------------
/mail-templates/password-reset-request.html:
--------------------------------------------------------------------------------
1 | Dear {{.Recipient}},
2 | you forgot your password. No Problem!
3 | You can change it at
4 | {{/* replace 'www.leberkleber.io/passwordReset' with your exposed endpoint */}}
5 | here ({{.PasswordResetToken}}).
6 |
7 |
8 |
9 | ({{.PasswordResetToken}})
10 |
11 | {{if index .Claims "myCustomClaim"}} ({{index .Claims "myCustomClaim"}}) {{end}}
12 | Greetings
--------------------------------------------------------------------------------
/internal/storage/claims.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | "database/sql/driver"
5 | "encoding/json"
6 | "fmt"
7 | )
8 |
// Claims holds user-defined claims that are persisted as a JSON document. It implements
// sql.Scanner and driver.Valuer so it can be used directly as a database column type.
type Claims map[string]interface{}

// Scan implements sql.Scanner by decoding the JSON document read from the database.
//
// Drivers may deliver the column as []byte or as string, so both are accepted; a NULL
// column (nil) resets the Claims to nil. Any other source type yields an error.
func (j *Claims) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		*j = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("failed to unmarshal Claims value: unsupported type %T", value)
	}

	// Decode into a fresh map: json.Unmarshal merges into an existing non-nil map, which
	// would otherwise keep keys from a previously scanned value.
	var decoded Claims
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return err
	}
	*j = decoded
	return nil
}

// Value implements driver.Valuer and returns the claims JSON-encoded as a byte slice.
// A nil Claims is encoded as the JSON literal null.
func (j Claims) Value() (driver.Value, error) {
	return json.Marshal(j)
}
27 |
--------------------------------------------------------------------------------
/internal/jwt/validate.go:
--------------------------------------------------------------------------------
1 | package jwt
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/golang-jwt/jwt"
7 | )
8 |
9 | var parseFunc = jwt.ParseWithClaims
10 |
11 | // IsTokenValid validates the given token with the in NewProvider configured privateKey.PublicKeys and return
12 | // isValid indicator, token-claims (when token is valid) and an error when present
13 | // return
14 | func (p Provider) IsTokenValid(tokenAsString string) (isValid bool, claims jwt.MapClaims, err error) {
15 | token, err := parseFunc(tokenAsString, &claims, checkSigningMethodKeyFunc(p.signingMethod, &p.privateKey.PublicKey))
16 | if err != nil {
17 | return false, nil, fmt.Errorf("failed to parse token: %w", err)
18 | }
19 |
20 | if !token.Valid {
21 | return false, nil, errors.New("token is not valid")
22 | }
23 |
24 | return true, claims, nil
25 | }
26 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | name: go
2 | on: [ push ]
3 |
4 | jobs:
5 | build:
6 | name: Build
7 | runs-on: ubuntu-latest
8 | steps:
9 | - name: Set up Go 1.15
10 | uses: actions/setup-go@v1
11 | with:
12 | go-version: 1.15
13 | id: go
14 |
15 | - name: Check out code into the Go module directory
16 | uses: actions/checkout@v2
17 |
18 | - name: Get dependencies
19 | run: |
20 | go get -v -t -d ./...
21 | if [ -f Gopkg.toml ]; then
22 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
23 | dep ensure
24 | fi
25 |
26 | - name: Test
27 | run: go test -cover -count=1 ./...
28 |
29 | - name: Vet
30 | run: go vet ./...
31 |
32 | - name: Component-Tests
33 | run: ./component-tests.sh
34 |
--------------------------------------------------------------------------------
/cmd/provider/alive_resource_component_test.go:
--------------------------------------------------------------------------------
1 | // +build component
2 |
3 | package main
4 |
5 | import (
6 | "io/ioutil"
7 | "net/http"
8 | "testing"
9 | )
10 |
// TestAliveResource calls the alive endpoint of the running provider container and expects
// status 200 with the body {"alive":true}.
func TestAliveResource(t *testing.T) {
	res, err := http.Get("http://simple-jwt-provider/v1/internal/alive")
	if err != nil {
		t.Fatalf("Failed to call simple-jwt-provider: %s", err)
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		t.Errorf("Response status code is not as expected. Expected %d, Given %d", http.StatusOK, res.StatusCode)
	}

	respBody, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %s", err)
	}

	expectedResponseBody := `{"alive":true}`
	if string(respBody) != expectedResponseBody {
		t.Errorf("Response body is not as expected. Expected %q, Given %q", expectedResponseBody, respBody)
	}
}
30 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Go build
2 | FROM golang as build
3 |
4 | ENV CGO_ENABLED=1
5 | ENV GO111MODULE=on
6 | ENV GOOS=linux
7 | ENV GOPATH=/
8 |
9 | WORKDIR /src/simple-jwt-provider/
10 |
11 | COPY go.mod .
12 | COPY go.sum .
13 | RUN go mod download
14 |
15 | COPY . .
16 | RUN go build -a -ldflags "-linkmode external -extldflags '-static' -s -w" -o simple-jwt-provider ./cmd/provider/
17 | #RUN go build -ldflags -s -a -installsuffix cgo
18 |
19 | # Service definition
20 | FROM alpine
21 |
22 | RUN apk add --update libcap tzdata util-linux-dev && rm -rf /var/cache/apk/*
23 |
24 | COPY --from=build /src/simple-jwt-provider/simple-jwt-provider /simple-jwt-provider
25 |
26 | COPY mail-templates /mail-templates
27 |
28 | RUN setcap CAP_NET_BIND_SERVICE=+eip /simple-jwt-provider
29 |
30 | RUN addgroup -g 1000 -S runnergroup && adduser -u 1001 -S apprunner -G runnergroup
31 | USER apprunner
32 |
33 | ENTRYPOINT ["/simple-jwt-provider"]
34 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | ## v?.?.? (unreleased)
3 |
4 | ## v2.0.0
5 | - [[#28] replace github.com/dgrijalva/jwt-go with github.com/golang-jwt/jwt](https://github.com/leberKleber/simple-jwt-provider/issues/28)
6 | - [[#7] file based db support](https://github.com/leberKleber/simple-jwt-provider/issues/7)
7 | - [[#25] use gorm as persistence layer](https://github.com/leberKleber/simple-jwt-provider/issues/25)
8 | - [[#22] make log level configurable and log as json](https://github.com/leberKleber/simple-jwt-provider/issues/22)
9 |
10 | ## v1.1.0
11 | - [[#3] secure admin api password config](https://github.com/leberKleber/simple-jwt-provider/issues/10)
12 | - [[#9] make jwt lifetime configurable](https://github.com/leberKleber/simple-jwt-provider/issues/9)
13 | - [[#6] refresh mechanism](https://github.com/leberKleber/simple-jwt-provider/issues/6)
14 |
15 | ## v1.0.0
16 | - password-reset flow via mail
17 | - admin api CRUD
18 | - login
19 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/leberKleber/simple-jwt-provider
2 |
3 | go 1.13
4 |
5 | require (
6 | github.com/DusanKasan/parsemail v1.2.0
7 | github.com/ardanlabs/conf v1.2.1
8 | github.com/golang-jwt/jwt v3.2.1+incompatible
9 | github.com/golang-migrate/migrate/v4 v4.7.1
10 | github.com/google/go-cmp v0.4.0 // indirect
11 | github.com/google/uuid v1.1.1
12 | github.com/gorilla/mux v1.7.3
13 | github.com/lib/pq v1.3.0
14 | github.com/mattn/go-sqlite3 v1.14.5 // indirect
15 | github.com/sirupsen/logrus v1.4.2
16 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
17 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f // indirect
18 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
19 | gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
20 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
21 | gopkg.in/mail.v2 v2.3.1
22 | gopkg.in/yaml.v2 v2.3.0
23 | gorm.io/driver/postgres v1.0.8
24 | gorm.io/driver/sqlite v1.1.4
25 | gorm.io/gorm v1.21.6
26 | gotest.tools v2.2.0+incompatible
27 | )
28 |
--------------------------------------------------------------------------------
/component-tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash -e
#
# Builds the component-test runner image and starts the stack from
# component-tests.docker-compose.yml on a dedicated docker network. It then waits until
# the provider reports alive and runs the component tests against it. The compose stack
# and the network are removed on exit, also when an earlier step aborts the script.

BUILD_ID=$RANDOM
export BUILD_ID

networkName="component-tests-$BUILD_ID"
composeFile="component-tests.docker-compose.yml"

docker build -f component-tests.Dockerfile -t "ct-runner:$BUILD_ID" .
docker network create "$networkName"

cleanup() {
  docker-compose -f "$composeFile" down
  docker network rm "$networkName"
}
# Registered right after the network exists, so a failing 'up' (or any other 'set -e'
# abort) no longer leaks containers and the network.
trap cleanup EXIT

docker-compose -f "$composeFile" up -d --build

set +e
docker run --rm -e "WUC_EXPECTED=200" \
  -e "WUC_WRITE_OUT=%{http_code}" \
  -e "WUC_URL=http://simple-jwt-provider/v1/internal/alive" \
  -e "WUC_MAX_ITERATIONS=20" \
  --network "${networkName}" leberkleber/wait_until_curl
sjp_alive=$?
set -e

if [[ "$sjp_alive" -ne 0 ]]; then
  echo "provider didn't start successfully"
  test_result=1
else
  set +e
  docker run --rm --network "${networkName}" "ct-runner:${BUILD_ID}"
  test_result=$?
  set -e
fi

if [[ "$test_result" -gt 0 ]]; then
  docker-compose -f "$composeFile" logs
fi

exit "$test_result"
40 |
--------------------------------------------------------------------------------
/.github/workflows/codecov.yml:
--------------------------------------------------------------------------------
1 | name: codecov
2 | on: [ push ]
3 |
4 | jobs:
5 | codecov:
6 | name: codecov
7 | runs-on: ubuntu-latest
8 | steps:
9 |
10 | - name: Set up Go 1.15
11 | uses: actions/setup-go@v1
12 | with:
13 | go-version: 1.15
14 | id: go
15 |
16 | - name: Check out code into the Go module directory
17 | uses: actions/checkout@v1
18 |
19 | - name: Get dependencies
20 | run: |
21 | go get -v -t -d ./...
22 | if [ -f Gopkg.toml ]; then
23 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
24 | dep ensure
25 | fi
26 |
27 | - name: Generate coverage report
28 | run: |
29 | go test `go list ./... | grep -v examples` -coverprofile=coverage.txt -covermode=atomic
30 |
31 | - name: Upload coverage report
32 | uses: codecov/codecov-action@v1.0.2
33 | with:
34 | token: ${{ secrets.CODECOV_TOKEN }}
35 | file: ./coverage.txt
36 | flags: unittests
37 | name: codecov-umbrella
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Max Marche
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cmd/provider/crud_user_component_test.go:
--------------------------------------------------------------------------------
1 | // +build component
2 |
3 | package main
4 |
5 | import (
6 | "fmt"
7 | "testing"
8 | )
9 |
10 | func TestCRUDUser(t *testing.T) {
11 | // 1) create user
12 | // 2) login
13 | // 3) update user
14 | // 4) login
15 | // 5) get user
16 | // 6) delete user
17 | // 7) login
18 |
19 | email := "crud_test@leberkleber.io"
20 | password := "s3cr3t"
21 |
22 | // 1)
23 | createUser(t, email, password)
24 |
25 | // 2)
26 | loginUser(t, email, password)
27 |
28 | // 3)
29 | newPassword := "n3wS3cr3t"
30 | updateUser(t, email, newPassword, map[string]interface{}{
31 | "myClaim": 5,
32 | })
33 |
34 | // 4)
35 | loginUser(t, email, newPassword)
36 |
37 | // 5)
38 | expectedUser := User{
39 | EMail: email,
40 | Password: "**********",
41 | Claims: map[string]interface{}{
42 | "myClaim": 5,
43 | },
44 | }
45 | user := readUser(t, email)
46 | if fmt.Sprint(user) != fmt.Sprint(expectedUser) {
47 | t.Fatalf("user is not as expected. Expected:\n%#v\nGiven:\n%#v", expectedUser, user)
48 | }
49 |
50 | // 6)
51 | deleteUser(t, email)
52 |
53 | // 7)
54 | loginUser(t, email, newPassword)
55 | }
56 |
--------------------------------------------------------------------------------
/example/urlencode.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # URL encode script for encoding text.
4 | #
5 | # Syntax:
6 | #
7 | # urlencode
8 | #
9 | # Example:
10 | #
11 | # $ urlencode "foo bar"
12 | # foo%20bar
13 | #
14 | # This implementation uses just the shell,
15 | # with no extra dependencies or languages.
16 | #
17 | # Credit:
18 | #
19 | # * https://gist.github.com/cdown/1163649
20 | #
21 | # Links:
22 | #
23 | # * https://github.com/sixarm/urlencode.sh
24 | # * https://github.com/sixarm/urldecode.sh
25 | #
26 | # Command: urlencode
27 | # Version: 1.0.0
28 | # Created: 2016-09-12
29 | # Updated: 2016-09-12
30 | # License: MIT
31 | # Contact: Joel Parker Henderson (joel@joelparkerhenderson.com)
32 | ##
33 |
# urlencode prints its first argument percent-encoded: the unreserved characters
# [a-zA-Z0-9.~_-] are printed as-is, every other byte as %XX.
urlencode() {
  # Use the C locale for the whole function. LC_ALL overrides LANG and LC_CTYPE, so this
  # also works when the caller's environment sets LC_ALL to a UTF-8 locale. Otherwise
  # bash iterates over multibyte characters and prints their code point (e.g. %FC for
  # "ü") instead of their UTF-8 bytes (%C3%BC). 'local' keeps the change inside the function.
  local LC_ALL=C

  local length="${#1}"
  local i c
  for ((i = 0; i < length; i++)); do
    c="${1:i:1}"
    case $c in
      # '%s' keeps the character out of printf's format string.
      [a-zA-Z0-9.~_-]) printf '%s' "$c" ;;
      *) printf '%%%02X' "'$c" ;;
    esac
  done
}

urlencode "$1"
57 |
--------------------------------------------------------------------------------
/example/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.5'
2 | services:
3 | simple-jwt-provider:
4 | image: leberkleber/simple-jwt-provider:v2.0.0
5 | restart: on-failure
6 | ports:
7 | - "8080:8080"
8 | environment:
9 | SJP_SERVER_ADDRESS: ":8080"
10 | SJP_JWT_PRIVATE_KEY: "\n-----BEGIN EC PRIVATE KEY-----\nMIHcAgEBBEIASzDZeTVLxcE5KTAmwrKwFjzr5cDrA+tttx9XRUz0K7AlROtj7cMG\nrHu/bdKj7lc2WaW8x/EOrU/FeCcsIL5nTH+gBwYFK4EEACOhgYkDgYYABAFBJr90\nWldGrPppBCbHqw2nGXeafxnSj6qB+A7E8A/G74mmmwIaqtf/pJ5QjqTPcAVUAEYF\nTz/0SPO3tPL1Ym3V0QH7TfnTf7EueabJqPdsSGR6uvbb2YOA9vy4OU8SXp/9a/4x\nr94giWgKjxYkB7xiy+IiZsWEBXU0rz7rb+IwJ82PfQ==\n-----END EC PRIVATE KEY-----"
11 | SJP_DATABASE_TYPE: "sqlite"
12 | SJP_DATABASE_DSN: "file::memory:?cache=shared"
13 | SJP_ADMIN_API_ENABLE: "true"
14 | SJP_ADMIN_API_USERNAME: "username"
15 | SJP_ADMIN_API_PASSWORD: "password"
16 | SJP_MAIL_SMTP_HOST: "smtp"
17 | SJP_MAIL_SMTP_PORT: 1025
18 | SJP_MAIL_SMTP_PASSWORD: ""
19 | SJP_MAIL_SMTP_USERNAME: ""
20 | SJP_MAIL_TLS_INSECURE_SKIP_VERIFY: "true"
21 | SJP_MAIL_TLS_SERVER_NAME: "smtp"
22 |
23 | smtp:
24 | image: mailhog/mailhog
25 | restart: always
26 | ports:
27 | - "8025:8025"
28 |
--------------------------------------------------------------------------------
/internal/storage/storage.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "gorm.io/driver/postgres"
7 | "gorm.io/driver/sqlite"
8 | "gorm.io/gorm"
9 | )
10 |
11 | const dbTypePostgres = "postgres"
12 | const dbTypeSQLite = "sqlite"
13 |
14 | var sqlOpen = gorm.Open
15 |
16 | // Storage should be created via New and provides user and token database operation. Before access database Migrate should be called
17 | type Storage struct {
18 | db *gorm.DB
19 | }
20 |
21 | // New opens a new sql connection with the given configuration
22 | func New(dbType, dsn string) (*Storage, error) {
23 | dialector, err := buildDialector(dbType, dsn)
24 | if err != nil {
25 | return nil, err
26 | }
27 |
28 | db, err := sqlOpen(dialector, &gorm.Config{})
29 | if err != nil {
30 | return nil, fmt.Errorf("failed to open database connection: %w", err)
31 | }
32 |
33 | err = db.AutoMigrate(User{}, Token{})
34 | if err != nil {
35 | return nil, fmt.Errorf("failed to auto-migrate persistence: %w", err)
36 | }
37 |
38 | return &Storage{
39 | db: db,
40 | }, nil
41 | }
42 |
43 | func buildDialector(dbType, dsn string) (gorm.Dialector, error) {
44 | var dialector gorm.Dialector
45 |
46 | switch dbType {
47 | case dbTypePostgres:
48 | dialector = postgres.Open(dsn)
49 | case dbTypeSQLite:
50 | dialector = sqlite.Open(dsn)
51 | default:
52 | return nil, errors.New("unsupported database type")
53 | }
54 |
55 | return dialector, nil
56 | }
57 |
--------------------------------------------------------------------------------
/internal/provider.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "github.com/golang-jwt/jwt"
5 | "github.com/leberKleber/simple-jwt-provider/internal/storage"
6 | )
7 |
8 | // Storage encapsulates storage.Storage to generate mocks
9 | //go:generate moq -out storage_moq_test.go . Storage
10 | type Storage interface {
11 | User(email string) (storage.User, error)
12 | CreateUser(user storage.User) error
13 | UpdateUser(user storage.User) error
14 | DeleteUser(email string) error
15 | CreateToken(t *storage.Token) error
16 | TokensByEMailAndToken(email, token string) ([]storage.Token, error)
17 | DeleteToken(id uint) error
18 | }
19 |
20 | // JWTProvider encapsulates jwt.Provider to generate mocks
21 | //go:generate moq -out jwt_generator_moq_test.go . JWTProvider
22 | type JWTProvider interface {
23 | GenerateAccessToken(email string, userClaims map[string]interface{}) (string, error)
24 | GenerateRefreshToken(email string) (string, string, error)
25 | IsTokenValid(token string) (bool, jwt.MapClaims, error)
26 | }
27 |
28 | // Mailer encapsulates mailer.Mailer to generate mocks
29 | //go:generate moq -out mailer_moq_test.go . Mailer
30 | type Mailer interface {
31 | SendPasswordResetRequestEMail(recipient, passwordResetToken string, claims map[string]interface{}) error
32 | }
33 |
34 | // Provider provides all necessary interfaces for use in internal
35 | type Provider struct {
36 | Storage Storage
37 | JWTProvider JWTProvider
38 | Mailer Mailer
39 | }
40 |
--------------------------------------------------------------------------------
/internal/web/middleware/basicauth.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
import (
	"crypto/subtle"
	"net/http"
	"strings"

	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/bcrypt"
)
9 |
10 | const bcryptedPasswordPrefix = "bcrypt:"
11 |
12 | // BasicAuth builds a basic auth http.Handler middleware which blocks all unauthorized request and respond with a
13 | // http status 403
14 | func BasicAuth(username, password string) func(h http.Handler) http.Handler {
15 | passwordIsBcrypted := strings.HasPrefix(password, bcryptedPasswordPrefix)
16 | if passwordIsBcrypted {
17 | password = strings.Replace(password, bcryptedPasswordPrefix, "", 1)
18 | }
19 |
20 | return func(next http.Handler) http.Handler {
21 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22 | u, p, ok := r.BasicAuth()
23 | if !ok {
24 | unauthorized(w)
25 | return
26 | }
27 |
28 | if u != username {
29 | unauthorized(w)
30 | return
31 | }
32 |
33 | if passwordIsBcrypted {
34 | err := bcrypt.CompareHashAndPassword([]byte(password), []byte(p))
35 | if err != nil {
36 | unauthorized(w)
37 | return
38 | }
39 | } else {
40 | if p != password {
41 | unauthorized(w)
42 | return
43 | }
44 | }
45 |
46 | next.ServeHTTP(w, r)
47 | })
48 | }
49 | }
50 |
51 | func unauthorized(w http.ResponseWriter) {
52 | w.WriteHeader(http.StatusForbidden)
53 | _, err := w.Write([]byte(`{"message": "forbidden"}`))
54 | if err != nil {
55 | logrus.WithError(err).Error("Failed to write unauthorized http response body")
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/internal/web/internal_test.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "io/ioutil"
7 | "net/http"
8 | "net/http/httptest"
9 | "testing"
10 | )
11 |
12 | func TestAliveHandler(t *testing.T) {
13 | expectedResponseCode := http.StatusOK
14 | expectedResponseBody := `{"alive":true}`
15 |
16 | toTest := NewServer(nil, false, "", "")
17 | testServer := httptest.NewServer(toTest.h)
18 |
19 | req, err := http.NewRequest(http.MethodGet, testServer.URL+"/v1/internal/alive", nil)
20 | if err != nil {
21 | t.Fatalf("Failed to build http request: %s", err)
22 | }
23 |
24 | resp, err := http.DefaultClient.Do(req)
25 | if err != nil {
26 | t.Fatalf("Failed to call server cause: %s", err)
27 | }
28 | defer resp.Body.Close()
29 |
30 | if resp.StatusCode != expectedResponseCode {
31 | t.Errorf("Request respond with unexpected status code. Expected: %d, Given: %d", expectedResponseCode, resp.StatusCode)
32 | }
33 |
34 | respBody, err := ioutil.ReadAll(resp.Body)
35 | if err != nil {
36 | t.Fatalf("Failed to read response body: %s", err)
37 | }
38 |
39 | var compactedRespBodyAsBytes []byte
40 | if resp.ContentLength > 0 {
41 | compactedRespBody := &bytes.Buffer{}
42 | err = json.Compact(compactedRespBody, respBody)
43 | if err != nil {
44 | t.Fatalf("Failed to compact json: %s", err)
45 | }
46 |
47 | compactedRespBodyAsBytes = compactedRespBody.Bytes()
48 | }
49 |
50 | if !bytes.Equal(compactedRespBodyAsBytes, []byte(expectedResponseBody)) {
51 | t.Errorf("Request response body is not as expected. Expected: %q, Given: %q", expectedResponseBody, string(compactedRespBodyAsBytes))
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/component-tests.docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.5'
2 | networks:
3 | component-tests:
4 | external: true
5 | name: component-tests-${BUILD_ID}
6 |
7 | services:
8 | simple-jwt-provider:
9 | build:
10 | context: .
11 | restart: on-failure
12 | environment:
13 | SJP_JWT_PRIVATE_KEY: "\n-----BEGIN EC PRIVATE KEY-----\nMIHcAgEBBEIASzDZeTVLxcE5KTAmwrKwFjzr5cDrA+tttx9XRUz0K7AlROtj7cMG\nrHu/bdKj7lc2WaW8x/EOrU/FeCcsIL5nTH+gBwYFK4EEACOhgYkDgYYABAFBJr90\nWldGrPppBCbHqw2nGXeafxnSj6qB+A7E8A/G74mmmwIaqtf/pJ5QjqTPcAVUAEYF\nTz/0SPO3tPL1Ym3V0QH7TfnTf7EueabJqPdsSGR6uvbb2YOA9vy4OU8SXp/9a/4x\nr94giWgKjxYkB7xiy+IiZsWEBXU0rz7rb+IwJ82PfQ==\n-----END EC PRIVATE KEY-----"
14 | SJP_JWT_AUDIENCE: ""
15 | SJP_JWT_ISSUER: ""
16 | SJP_JWT_SUBJECT: ""
17 | SJP_DATABASE_TYPE: "postgres"
18 | SJP_DATABASE_DSN: "host=db user=postgres password=postgres dbname=simple-jwt-provider port=5432 sslmode=disable"
19 | SJP_ADMIN_API_ENABLE: "true"
20 | SJP_ADMIN_API_USERNAME: "username"
21 | # escape $ as $$ (a single $ starts docker-compose variable interpolation)
22 | SJP_ADMIN_API_PASSWORD: "bcrypt:$$2y$$12$$eOiNiEyREa2viPff8suTR.vw.HZSOSLGZE2ozfonFRn6w4HkV4Dbe"
23 | SJP_MAIL_SMTP_HOST: "mail-server"
24 | SJP_MAIL_SMTP_PORT: 1025
25 | SJP_MAIL_SMTP_PASSWORD: ""
26 | SJP_MAIL_SMTP_USERNAME: ""
27 | SJP_MAIL_TLS_INSECURE_SKIP_VERIFY: "true"
28 | SJP_MAIL_TLS_SERVER_NAME: "mail-server"
29 | networks:
30 | - component-tests
31 |
32 | db:
33 | image: postgres
34 | restart: always
35 | environment:
36 | POSTGRES_DB: "simple-jwt-provider"
37 | POSTGRES_USER: "postgres"
38 | POSTGRES_PASSWORD: "postgres"
39 | networks:
40 | - component-tests
41 |
42 | mail-server:
43 | image: mailhog/mailhog
44 | restart: always
45 | networks:
46 | - component-tests
--------------------------------------------------------------------------------
/internal/storage/token.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 |
7 | 	"gorm.io/gorm"
8 | )
9 |
10 | // ErrTokenNotFound is returned when no token could be found.
11 | var ErrTokenNotFound = errors.New("no token found")
12 |
13 | // TokenTypeReset identifies a token as reset-token. Such a token can only be used for a password-reset.
14 | const TokenTypeReset string = "reset"
15 |
16 | // TokenTypeRefresh identifies a token as refresh-token. Such a token can only be used for a refresh.
17 | const TokenTypeRefresh string = "refresh"
18 |
19 | // Token represents a persisted token.
20 | type Token struct {
21 | 	gorm.Model
22 | 	EMail string
23 | 	Token string
24 | 	Type  string
25 | }
26 |
27 | // CreateToken persists the given token in database. EMail must match a user's email. ID will be set automatically.
28 | func (s Storage) CreateToken(t *Token) error {
29 | 	res := s.db.Create(t)
30 | 	if res.Error != nil {
31 | 		return fmt.Errorf("failed to exec create token: %w", res.Error)
32 | 	}
33 |
34 | 	return nil
35 | }
36 |
37 | // TokensByEMailAndToken finds all tokens which match both the given email and the given token.
38 | //
39 | // The conditions are passed as a map on purpose: gorm ignores zero-value fields of a struct condition, so
40 | // Token{EMail: email, Token: token} with an empty token would have matched every token of the given email
41 | // (and an empty email every token with that value). Map conditions always keep all keys in the WHERE clause.
42 | // "e_mail" is the column gorm derives from the EMail field (as in the users.e_mail constraint of the users table).
43 | func (s Storage) TokensByEMailAndToken(email, token string) ([]Token, error) {
44 | 	var tokens []Token
45 | 	res := s.db.Where(map[string]interface{}{"e_mail": email, "token": token}).Find(&tokens)
46 | 	if res.Error != nil {
47 | 		return nil, fmt.Errorf("failed to exec select token stmt: %w", res.Error)
48 | 	}
49 |
50 | 	return tokens, nil
51 | }
52 |
53 | // DeleteToken deletes the token with the given ID.
54 | // return ErrTokenNotFound when there is no token with the given ID
55 | func (s Storage) DeleteToken(id uint) error {
56 | 	res := s.db.Delete(&Token{}, id)
57 | 	if res.Error != nil {
58 | 		return fmt.Errorf("failed to delete token: %w", res.Error)
59 | 	}
60 |
61 | 	if res.RowsAffected < 1 {
62 | 		return ErrTokenNotFound
63 | 	}
64 |
65 | 	return nil
66 | }
67 |
--------------------------------------------------------------------------------
/cmd/provider/refresh_component_test.go:
--------------------------------------------------------------------------------
1 | // +build component
2 |
3 | package main
4 |
5 | import (
6 | 	"bytes"
7 | 	"encoding/json"
8 | 	"net/http"
9 | 	"testing"
10 | )
11 |
12 | // TestRefreshPassword logs a new user in and exchanges its refresh-token twice in a row, passing every returned
13 | // access- and refresh-token to validateJWT (so a refreshed refresh-token must be usable for the next refresh).
14 | func TestRefreshPassword(t *testing.T) {
15 | 	email := "refresh_test@leberkleber.io"
16 | 	password := "s3cr3t"
17 |
18 | 	createUser(t, email, password)
19 | 	_, refreshToken, authorized := loginUser(t, email, password)
20 | 	if !authorized {
21 | 		t.Fatal("failed to auth user")
22 | 	}
23 |
24 | 	accessToken, newRefreshToken := refresh(t, refreshToken)
25 |
26 | 	validateJWT(t, accessToken)
27 | 	validateJWT(t, newRefreshToken)
28 |
29 | 	accessToken, newRefreshToken = refresh(t, newRefreshToken)
30 |
31 | 	validateJWT(t, accessToken)
32 | 	validateJWT(t, newRefreshToken)
33 | }
34 |
35 | // refresh posts the given refresh-token to the refresh endpoint and returns the new access- and refresh-token.
36 | // Both are empty when the provider answers 401 Unauthorized; any other status than 200 fails the test.
37 | func refresh(t *testing.T, refreshToken string) (string, string) {
38 | 	t.Helper()
39 |
40 | 	// json.Marshal yields a proper JSON string; fmt's %q applies Go escaping, which only coincides with JSON
41 | 	// for simple input.
42 | 	reqBody, err := json.Marshal(map[string]string{"refresh_token": refreshToken})
43 | 	if err != nil {
44 | 		t.Fatalf("Failed to build refresh request body: %s", err)
45 | 	}
46 |
47 | 	resp, err := http.Post("http://simple-jwt-provider/v1/auth/refresh", "application/json", bytes.NewReader(reqBody))
48 | 	if err != nil {
49 | 		// net/http documents that the response can be ignored when err is non-nil, so only err is reported.
50 | 		t.Fatalf("Failed to refresh cause: %s", err)
51 | 	}
52 | 	defer resp.Body.Close()
53 |
54 | 	responseBody := struct {
55 | 		AccessToken  string `json:"access_token"`
56 | 		RefreshToken string `json:"refresh_token"`
57 | 		ErrorMessage string `json:"message"`
58 | 	}{}
59 |
60 | 	err = json.NewDecoder(resp.Body).Decode(&responseBody)
61 | 	if err != nil {
62 | 		t.Fatalf("Failed to read response body (status code %d): %s", resp.StatusCode, err)
63 | 	}
64 |
65 | 	if resp.StatusCode == http.StatusUnauthorized {
66 | 		return "", ""
67 | 	}
68 | 	if resp.StatusCode != http.StatusOK {
69 | 		t.Fatalf("Invalid response status code. Expected: %d, Given: %d, Body: %s", http.StatusOK, resp.StatusCode, responseBody.ErrorMessage)
70 | 	}
71 |
72 | 	return responseBody.AccessToken, responseBody.RefreshToken
73 | }
74 |
--------------------------------------------------------------------------------
/internal/mailer/template_moq_test.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package mailer
5 |
6 | import (
7 | "gopkg.in/mail.v2"
8 | "sync"
9 | )
10 |
11 | // Compile-time check that templateMock implements template.
12 | // If this fails to compile, regenerate this file with moq.
13 | var _ template = &templateMock{}
14 |
15 | // templateMock is a mock implementation of template.
16 | //
17 | // func TestSomethingThatUsestemplate(t *testing.T) {
18 | //
19 | // // make and configure a mocked template
20 | // mockedtemplate := &templateMock{
21 | // RenderFunc: func(args interface{}) (*mail.Message, error) {
22 | // panic("mock out the Render method")
23 | // },
24 | // }
25 | //
26 | // // use mockedtemplate in code that requires template
27 | // // and then make assertions.
28 | //
29 | // }
30 | type templateMock struct {
31 | // RenderFunc mocks the Render method.
32 | RenderFunc func(args interface{}) (*mail.Message, error)
33 |
34 | // calls records the arguments of every call; guarded by lockRender.
35 | calls struct {
36 | // Render holds details about calls to the Render method.
37 | Render []struct {
38 | // Args is the args argument value.
39 | Args interface{}
40 | }
41 | }
42 | lockRender sync.RWMutex
43 | }
44 |
45 | // Render records the call, then delegates to RenderFunc; it panics if RenderFunc is nil.
46 | func (mock *templateMock) Render(args interface{}) (*mail.Message, error) {
47 | if mock.RenderFunc == nil {
48 | panic("templateMock.RenderFunc: method is nil but template.Render was just called")
49 | }
50 | callInfo := struct {
51 | Args interface{}
52 | }{
53 | Args: args,
54 | }
55 | mock.lockRender.Lock()
56 | mock.calls.Render = append(mock.calls.Render, callInfo)
57 | mock.lockRender.Unlock()
58 | return mock.RenderFunc(args)
59 | }
60 |
61 | // RenderCalls returns the calls recorded for Render (read under lockRender).
62 | // Check the length with:
63 | // len(mockedtemplate.RenderCalls())
64 | func (mock *templateMock) RenderCalls() []struct {
65 | Args interface{}
66 | } {
67 | var calls []struct {
68 | Args interface{}
69 | }
70 | mock.lockRender.RLock()
71 | calls = mock.calls.Render
72 | mock.lockRender.RUnlock()
73 | return calls
74 | }
75 |
--------------------------------------------------------------------------------
/internal/jwt/provider.go:
--------------------------------------------------------------------------------
1 | package jwt
2 |
3 | import (
4 | 	"crypto/ecdsa"
5 | 	"crypto/x509"
6 | 	"encoding/pem"
7 | 	"errors"
8 | 	"fmt"
9 | 	"reflect"
10 | 	"strings"
11 | 	"time"
12 |
13 | 	"github.com/golang-jwt/jwt"
14 | )
15 |
16 | // x509ParseECPrivateKey indirects x509.ParseECPrivateKey (presumably so tests can replace the parser).
17 | var x509ParseECPrivateKey = x509.ParseECPrivateKey
18 |
19 | // Provider should be created via NewProvider and creates JWTs via Generate with static and custom claims.
20 | type Provider struct {
21 | 	jwtLifetime   time.Duration
22 | 	privateKey    *ecdsa.PrivateKey
23 | 	signingMethod *jwt.SigningMethodECDSA
24 | 	privateClaims struct {
25 | 		audience string
26 | 		issuer   string
27 | 		subject  string
28 | 	}
29 | }
30 |
31 | // NewProvider creates a Provider instance with the given jwt-configuration. Before instantiation the PEM encoded
32 | // EC private key is decoded and parsed; the resulting Provider signs with ES512.
33 | // Literal `\n` sequences in privateKey are turned into real newlines first, so the key can be passed as a
34 | // single-line string.
35 | func NewProvider(privateKey string, jwtLifetime time.Duration, jwtAudience, jwtIssuer, jwtSubject string) (*Provider, error) {
36 | 	privateKey = strings.ReplaceAll(privateKey, `\n`, "\n") //TODO fix me (needed for start via ide)
37 | 	blockPrv, _ := pem.Decode([]byte(privateKey))
38 | 	if blockPrv == nil {
39 | 		return nil, errors.New("no valid private key found")
40 | 	}
41 |
42 | 	pKey, err := x509ParseECPrivateKey(blockPrv.Bytes)
43 | 	if err != nil {
44 | 		return nil, fmt.Errorf("failed to parse private-key: %w", err)
45 | 	}
46 |
47 | 	p := &Provider{
48 | 		jwtLifetime:   jwtLifetime,
49 | 		privateKey:    pKey,
50 | 		signingMethod: jwt.SigningMethodES512,
51 | 	}
52 | 	p.privateClaims.audience = jwtAudience
53 | 	p.privateClaims.issuer = jwtIssuer
54 | 	p.privateClaims.subject = jwtSubject
55 |
56 | 	// err is known to be nil at this point; return nil explicitly instead of the stale variable.
57 | 	return p, nil
58 | }
59 |
60 | // checkSigningMethodKeyFunc returns a jwt.Keyfunc which hands out publicKey only when the token's signing method
61 | // has the same Go type as signingMethod, and fails otherwise.
62 | // NOTE(review): this compares types, not algorithm names, so all *jwt.SigningMethodECDSA variants
63 | // (ES256/ES384/ES512) pass this check — consider comparing token.Method.Alg() if that matters.
64 | var checkSigningMethodKeyFunc = func(signingMethod jwt.SigningMethod, publicKey *ecdsa.PublicKey) jwt.Keyfunc {
65 | 	return func(token *jwt.Token) (interface{}, error) {
66 | 		tokenSigningMethod := reflect.TypeOf(token.Method)
67 | 		expectedSigningMethod := reflect.TypeOf(signingMethod)
68 | 		if tokenSigningMethod != expectedSigningMethod {
69 | 			return nil, fmt.Errorf("unexpected signing method %q, expected: %q", tokenSigningMethod, expectedSigningMethod)
70 | 		}
71 |
72 | 		return publicKey, nil
73 | 	}
74 | }
75 |
--------------------------------------------------------------------------------
/cmd/provider/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"errors"
5 | 	"net/http"
6 |
7 | 	"github.com/ardanlabs/conf"
8 | 	"github.com/leberKleber/simple-jwt-provider/internal"
9 | 	"github.com/leberKleber/simple-jwt-provider/internal/jwt"
10 | 	"github.com/leberKleber/simple-jwt-provider/internal/mailer"
11 | 	"github.com/leberKleber/simple-jwt-provider/internal/storage"
12 | 	"github.com/leberKleber/simple-jwt-provider/internal/web"
13 | 	"github.com/sirupsen/logrus"
14 |
15 | 	// database migration
16 | 	_ "github.com/golang-migrate/migrate/v4/source/file"
17 | 	// sql driver
18 | 	_ "github.com/lib/pq"
19 | )
20 |
21 | // main reads the SJP_ prefixed configuration, wires storage, jwt provider and mailer together and serves the
22 | // HTTP API. Every setup failure terminates the process via logrus' Fatal.
23 | func main() {
24 | 	logrus.SetFormatter(&logrus.JSONFormatter{})
25 |
26 | 	cfg, err := newConfig()
27 | 	if err != nil {
28 | 		logrus.WithError(err).Fatal("Failed to parse config")
29 | 	}
30 |
31 | 	logLvl, err := logrus.ParseLevel(cfg.LogLevel)
32 | 	if err != nil {
33 | 		logrus.WithError(err).Fatal("Failed to parse log-level")
34 | 	}
35 | 	logrus.SetLevel(logLvl)
36 |
37 | 	// Fields tagged noprint (private key, DSN, passwords) are left out of this string by conf.
38 | 	cfgAsString, err := conf.String(&cfg)
39 | 	if err != nil {
40 | 		logrus.WithError(err).Fatal("Could not build config string")
41 | 	}
42 | 	logrus.WithField("configuration", cfgAsString).Info("Starting provider")
43 |
44 | 	s, err := storage.New(cfg.Database.Type, cfg.Database.DSN)
45 | 	if err != nil {
46 | 		logrus.WithError(err).Fatal("Could not create storage")
47 | 	}
48 |
49 | 	jwtGenerator, err := jwt.NewProvider(cfg.JWT.PrivateKey, cfg.JWT.Lifetime, cfg.JWT.Audience, cfg.JWT.Issuer, cfg.JWT.Subject)
50 | 	if err != nil {
51 | 		logrus.WithError(err).Fatal("Failed to create jwt generator")
52 | 	}
53 |
54 | 	m, err := mailer.New(cfg.Mail.TemplatesFolderPath,
55 | 		cfg.Mail.SMTPUsername,
56 | 		cfg.Mail.SMTPPassword,
57 | 		cfg.Mail.SMTPHost,
58 | 		cfg.Mail.SMTPPort,
59 | 		cfg.Mail.TLS.InsecureSkipVerify,
60 | 		cfg.Mail.TLS.ServerName,
61 | 	)
62 | 	if err != nil {
63 | 		logrus.WithError(err).Fatal("Failed to create mailer")
64 | 	}
65 |
66 | 	provider := &internal.Provider{Storage: s, JWTProvider: jwtGenerator, Mailer: m}
67 | 	server := web.NewServer(provider, cfg.AdminAPI.Enable, cfg.AdminAPI.Username, cfg.AdminAPI.Password)
68 |
69 | 	// errors.Is also recognises a wrapped http.ErrServerClosed, which signals a regular Shutdown/Close.
70 | 	err = server.ListenAndServe(cfg.ServerAddress)
71 | 	if err != nil && !errors.Is(err, http.ErrServerClosed) {
72 | 		logrus.WithError(err).Fatal("Failed to run server")
73 | 	}
74 | }
75 |
--------------------------------------------------------------------------------
/internal/mailer/template.go:
--------------------------------------------------------------------------------
1 | package mailer
2 |
3 | import (
4 | 	"bytes"
5 | 	"fmt"
6 | 	htmlTemplate "html/template"
7 | 	"path/filepath"
8 | 	textTemplate "text/template"
9 |
10 | 	"gopkg.in/mail.v2"
11 | 	"gopkg.in/yaml.v2"
12 | )
13 |
14 | // passwordResetRequestTemplateName is the base name of the password-reset-request template files
15 | // (<name>.html, <name>.txt and <name>.yml).
16 | const passwordResetRequestTemplateName = "password-reset-request"
17 |
18 | // Indirections over the template parsers (presumably so tests can replace them).
19 | var htmlTemplateParseFiles = htmlTemplate.ParseFiles
20 | var textTemplateParseFiles = textTemplate.ParseFiles
21 | var ymlTemplateParseFiles = textTemplate.ParseFiles
22 |
23 | // mailTemplate bundles the parsed templates of one mail: HTML body, plain-text body and YAML headers.
24 | type mailTemplate struct {
25 | 	name       string
26 | 	htmlTmpl   *htmlTemplate.Template
27 | 	textTmpl   *textTemplate.Template
28 | 	headerTmpl *textTemplate.Template
29 | }
30 |
31 | // loadTemplates parses <path>/<name>.html, <path>/<name>.txt and <path>/<name>.yml into a mailTemplate.
32 | var loadTemplates = func(path, name string) (mailTemplate, error) {
33 | 	htmlTmpl, err := htmlTemplateParseFiles(filepath.Join(path, fmt.Sprintf("%s.html", name)))
34 | 	if err != nil {
35 | 		return mailTemplate{}, fmt.Errorf("failed to load mail html body template: %w", err)
36 | 	}
37 |
38 | 	textTmpl, err := textTemplateParseFiles(filepath.Join(path, fmt.Sprintf("%s.txt", name)))
39 | 	if err != nil {
40 | 		return mailTemplate{}, fmt.Errorf("failed to load mail text body template: %w", err)
41 | 	}
42 |
43 | 	headerTmpl, err := ymlTemplateParseFiles(filepath.Join(path, fmt.Sprintf("%s.yml", name)))
44 | 	if err != nil {
45 | 		return mailTemplate{}, fmt.Errorf("failed to load mail headers template: %w", err)
46 | 	}
47 |
48 | 	return mailTemplate{
49 | 		name:       name,
50 | 		htmlTmpl:   htmlTmpl,
51 | 		textTmpl:   textTmpl,
52 | 		headerTmpl: headerTmpl,
53 | 	}, nil
54 | }
55 |
56 | // Render executes all templates with args and assembles the mail: headers from the YAML template, a text/plain
57 | // body and a text/html alternative.
58 | func (t mailTemplate) Render(args interface{}) (*mail.Message, error) {
59 | 	msg := mail.NewMessage()
60 |
61 | 	err := renderHeaders(msg, t.headerTmpl, args)
62 | 	if err != nil {
63 | 		return nil, fmt.Errorf("failed to render mail headers: %w", err)
64 | 	}
65 |
66 | 	var buf bytes.Buffer
67 | 	err = t.textTmpl.Execute(&buf, args)
68 | 	if err != nil {
69 | 		return nil, fmt.Errorf("failed to render mail text body: %w", err)
70 | 	}
71 | 	msg.SetBody("text/plain", buf.String())
72 |
73 | 	// The buffer is reused for the HTML part; Reset drops the already consumed text body.
74 | 	buf.Reset()
75 | 	err = t.htmlTmpl.Execute(&buf, args)
76 | 	if err != nil {
77 | 		return nil, fmt.Errorf("failed to render mail html body: %w", err)
78 | 	}
79 | 	msg.AddAlternative("text/html", buf.String())
80 |
81 | 	return msg, nil
82 | }
83 |
84 | // renderHeaders executes tmpl with args, decodes the output as a YAML mapping from header name to values and
85 | // sets those headers on msg. The parameter is named tmpl because "template" would shadow the package's
86 | // template interface.
87 | func renderHeaders(msg *mail.Message, tmpl *textTemplate.Template, args interface{}) error {
88 | 	var buf bytes.Buffer
89 | 	err := tmpl.Execute(&buf, args)
90 | 	if err != nil {
91 | 		return err
92 | 	}
93 |
94 | 	headers := make(map[string][]string)
95 | 	err = yaml.NewDecoder(&buf).Decode(&headers)
96 | 	if err != nil {
97 | 		return err
98 | 	}
99 | 	msg.SetHeaders(headers)
100 | 	return nil
101 | }
102 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ master ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ master ]
20 | schedule:
21 | - cron: '37 21 * * 2'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'go' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
37 | # Learn more:
38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
39 |
40 | steps:
41 | - name: Checkout repository
42 | uses: actions/checkout@v2
43 |
44 | # Initializes the CodeQL tools for scanning.
45 | - name: Initialize CodeQL
46 | uses: github/codeql-action/init@v1
47 | with:
48 | languages: ${{ matrix.language }}
49 | # If you wish to specify custom queries, you can do so here or in a config file.
50 | # By default, queries listed here will override any specified in a config file.
51 | # Prefix the list here with "+" to use these queries and those in the config file.
52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main
53 |
54 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
55 | # If this step fails, then you should remove it and run the build manually (see below)
56 | - name: Autobuild
57 | uses: github/codeql-action/autobuild@v1
58 |
59 | # ℹ️ Command-line programs to run using the OS shell.
60 | # 📚 https://git.io/JvXDl
61 |
62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
63 | # and modify them (or add more) to build your code if your project
64 | # uses a compiled language
65 |
66 | #- run: |
67 | # make bootstrap
68 | # make release
69 |
70 | - name: Perform CodeQL Analysis
71 | uses: github/codeql-action/analyze@v1
72 |
--------------------------------------------------------------------------------
/internal/jwt/generate.go:
--------------------------------------------------------------------------------
1 | package jwt
2 |
3 | import (
4 | 	"fmt"
5 | 	"time"
6 |
7 | 	"github.com/golang-jwt/jwt"
8 | 	"github.com/google/uuid"
9 | )
10 |
11 | // Indirections over time.Now and uuid.NewRandom (presumably so tests can pin time and jwt-id).
12 | var timeNow = time.Now
13 | var uuidNewRandom = uuid.NewRandom
14 |
15 | // refreshTokenLifetime is the lifetime of refresh-tokens; access-tokens use Provider.jwtLifetime instead.
16 | const refreshTokenLifetime = 7 * 24 * time.Hour
17 |
18 | // GenerateAccessToken generates a valid access-jwt based on the Provider.privateKey. The jwt is issued to the given
19 | // email and enriched with the given claims.
20 | // 'userClaims' can contain all json compatible types. The map is copied and never modified; standard claims and
21 | // the "email" claim overwrite user claims with the same name.
22 | func (p Provider) GenerateAccessToken(email string, userClaims map[string]interface{}) (string, error) {
23 | 	now := timeNow()
24 | 	jwtID, err := uuidNewRandom()
25 | 	if err != nil {
26 | 		return "", fmt.Errorf("failed to generate jwt-id: %w", err)
27 | 	}
28 |
29 | 	// Copy rather than alias userClaims: writing the standard claims into the caller's map would leak them
30 | 	// into whatever the caller does with that map afterwards.
31 | 	claims := make(jwt.MapClaims, len(userClaims))
32 | 	for name, value := range userClaims {
33 | 		claims[name] = value
34 | 	}
35 |
36 | 	token, err := p.signWithStandardClaims(claims, email, jwtID.String(), now, now.Add(p.jwtLifetime))
37 | 	if err != nil {
38 | 		return "", fmt.Errorf("failed to sign access-token: %w", err)
39 | 	}
40 |
41 | 	return token, nil
42 | }
43 |
44 | // GenerateRefreshToken generates a valid refresh-jwt based on the Provider.privateKey. The jwt is issued to the
45 | // given email and expires after refreshTokenLifetime. Returns the signed token and its jwt-id.
46 | func (p Provider) GenerateRefreshToken(email string) (string, string, error) {
47 | 	now := timeNow()
48 | 	jwtID, err := uuidNewRandom()
49 | 	if err != nil {
50 | 		return "", "", fmt.Errorf("failed to generate jwt-id: %w", err)
51 | 	}
52 |
53 | 	token, err := p.signWithStandardClaims(jwt.MapClaims{}, email, jwtID.String(), now, now.Add(refreshTokenLifetime))
54 | 	if err != nil {
55 | 		return "", "", fmt.Errorf("failed to sign refresh-token: %w", err)
56 | 	}
57 |
58 | 	return token, jwtID.String(), nil
59 | }
60 |
61 | // signWithStandardClaims sets the standard claims (https://tools.ietf.org/html/rfc7519#section-4.1) and the public
62 | // "email" claim (https://www.iana.org/assignments/jwt/jwt.xhtml#claims) on claims, overwriting existing entries
63 | // with the same names, and signs the result with p.signingMethod and p.privateKey.
64 | func (p Provider) signWithStandardClaims(claims jwt.MapClaims, email, jwtID string, now, expiresAt time.Time) (string, error) {
65 | 	claims["aud"] = p.privateClaims.audience // Audience
66 | 	claims["exp"] = expiresAt.Unix()         // ExpiresAt
67 | 	// NOTE(review): RFC 7519 registers the JWT ID claim as "jti". "jit" is kept as-is because token validation
68 | 	// elsewhere may read it — confirm before renaming.
69 | 	claims["jit"] = jwtID                   // Id
70 | 	claims["iat"] = now.Unix()              // IssuedAt
71 | 	claims["iss"] = p.privateClaims.issuer  // Issuer
72 | 	claims["nbf"] = now.Unix()              // NotBefore
73 | 	claims["sub"] = p.privateClaims.subject // Subject
74 | 	claims["email"] = email                 // Preferred e-mail address
75 |
76 | 	return jwt.NewWithClaims(p.signingMethod, claims).SignedString(p.privateKey)
77 | }
78 |
--------------------------------------------------------------------------------
/cmd/provider/config.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"os"
7 | 	"time"
8 |
9 | 	"github.com/ardanlabs/conf"
10 | )
11 |
12 | // confUsage indirects conf.Usage (presumably so tests can make usage rendering fail).
13 | var confUsage = conf.Usage
14 |
15 | // config holds all settings of the provider. Every field is read from an SJP_ prefixed environment variable
16 | // (e.g. LogLevel from SJP_LOG_LEVEL); see the conf tags for defaults and required fields.
17 | type config struct {
18 | 	LogLevel string `conf:"env:LOG_LEVEL,help:Log-Level can be TRACE DEBUG INFO WARN ERROR FATAL or PANIC,default:INFO"`
19 | 	ServerAddress string `conf:"env:SERVER_ADDRESS,help:Server-address network-interface to bind on e.g.: '127.0.0.1:8080',default:0.0.0.0:80"`
20 | 	JWT struct {
21 | 		Lifetime time.Duration `conf:"env:JWT_LIFETIME,help:Lifetime of JWT,default:4h"`
22 | 		PrivateKey string `conf:"env:JWT_PRIVATE_KEY,help:JWT PrivateKey ECDSA512,required,noprint"`
23 | 		Audience string `conf:"env:JWT_AUDIENCE,help:Audience private claim which will be applied in each JWT"`
24 | 		Issuer string `conf:"env:JWT_ISSUER,help:Issuer private claim which will be applied in each JWT"`
25 | 		Subject string `conf:"env:JWT_SUBJECT,help:Subject private claim which will be applied in each JWT"`
26 | 	}
27 | 	Database struct {
28 | 		Type string `conf:"env:DATABASE_TYPE,help:Database type. Currently supported postgres and sqlite,required"`
29 | 		DSN string `conf:"env:DATABASE_DSN,help:Data Source Name for persistence,required,noprint"`
30 | 	}
31 | 	AdminAPI struct {
32 | 		Enable bool `conf:"env:ADMIN_API_ENABLE,help:Enable admin API to manage stored users (true / false),default:false"`
33 | 		Username string `conf:"env:ADMIN_API_USERNAME,help:Basic Auth Username if enable-admin-api = true"`
34 | 		Password string `conf:"env:ADMIN_API_PASSWORD,help:Basic Auth Password if enable-admin-api = true when is bcrypted prefix with 'bcrypt',noprint"`
35 | 	}
36 | 	Mail struct {
37 | 		TemplatesFolderPath string `conf:"env:MAIL_TEMPLATES_FOLDER_PATH,help:Path to mail-templates folder,default:/mail-templates"`
38 | 		SMTPHost string `conf:"env:MAIL_SMTP_HOST,help:SMTP host to connect to,required"`
39 | 		SMTPPort int `conf:"env:MAIL_SMTP_PORT,help:SMTP port to connect to,default:587"`
40 | 		SMTPUsername string `conf:"env:MAIL_SMTP_USERNAME,help:SMTP username to authorize with,required"`
41 | 		SMTPPassword string `conf:"env:MAIL_SMTP_PASSWORD,help:SMTP password to authorize with,required,noprint"`
42 | 		TLS struct {
43 | 			InsecureSkipVerify bool `conf:"env:MAIL_TLS_INSECURE_SKIP_VERIFY,help:true if certificates should not be verified,default:false"`
44 | 			ServerName string `conf:"env:MAIL_TLS_SERVER_NAME,help:name of the server who expose the certificate"`
45 | 		}
46 | 	}
47 | }
48 |
49 | // newConfig parses the SJP_ environment into a config. On a parse error the usage text is printed to stdout
50 | // and the parse error is returned. It also rejects an enabled admin API without username or password.
51 | func newConfig() (config, error) {
52 | 	var cfg config
53 |
54 | 	parseErr := conf.Parse(os.Environ(), "SJP", &cfg)
55 | 	if parseErr != nil {
56 | 		usage, usageErr := confUsage("SJP", &cfg)
57 | 		if usageErr != nil {
58 | 			return cfg, usageErr
59 | 		}
60 |
61 | 		fmt.Println(usage)
62 | 		return cfg, parseErr
63 | 	}
64 |
65 | 	admin := cfg.AdminAPI
66 | 	credentialsMissing := admin.Username == "" || admin.Password == ""
67 | 	if admin.Enable && credentialsMissing {
68 | 		return cfg, errors.New("admin-api-password and admin-api-username must be set if api has been enabled")
69 | 	}
70 |
71 | 	return cfg, nil
72 | }
73 |
--------------------------------------------------------------------------------
/internal/storage/user.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 |
7 | 	"github.com/lib/pq"
8 | 	"github.com/mattn/go-sqlite3"
9 | 	"gorm.io/gorm"
10 | )
11 |
12 | // User represents a persisted user.
13 | type User struct {
14 | 	gorm.Model
15 | 	EMail    string `gorm:"uniqueIndex:unique_email"`
16 | 	Password []byte
17 | 	Claims   Claims
18 | }
19 |
20 | // ErrUserNotFound is returned when the requested user could not be found.
21 | var ErrUserNotFound = errors.New("user not found")
22 |
23 | // ErrUserAlreadyExists is returned when a user with the given email already exists.
24 | var ErrUserAlreadyExists = errors.New("user already exists")
25 |
26 | // CreateUser persists the given user in database.
27 | // return ErrUserAlreadyExists when a user with the same email already exists
28 | func (s *Storage) CreateUser(u User) error {
29 | 	res := s.db.Create(&u)
30 | 	if res.Error != nil {
31 | 		if isUniqueEMailViolation(res.Error) {
32 | 			return ErrUserAlreadyExists
33 | 		}
34 |
35 | 		return fmt.Errorf("failed to exec create user stmt: %w", res.Error)
36 | 	}
37 |
38 | 	return nil
39 | }
40 |
41 | // isUniqueEMailViolation reports whether err, or an error wrapped by it, is the violation of the unique index on
42 | // the users' email column.
43 | //
44 | // lib/pq returns *pq.Error, which a type switch on the value type pq.Error never matched, so both forms are
45 | // checked. NOTE(review): if storage.New opens postgres through a driver other than lib/pq (e.g. gorm's pgx based
46 | // driver), its error type is not recognised here — confirm which driver is used.
47 | func isUniqueEMailViolation(err error) bool {
48 | 	var pqErrPtr *pq.Error
49 | 	if errors.As(err, &pqErrPtr) {
50 | 		return pqErrPtr.Constraint == "unique_email"
51 | 	}
52 |
53 | 	var pqErr pq.Error
54 | 	if errors.As(err, &pqErr) {
55 | 		return pqErr.Constraint == "unique_email"
56 | 	}
57 |
58 | 	var sqliteErr sqlite3.Error
59 | 	if errors.As(err, &sqliteErr) {
60 | 		return sqliteErr.Error() == "UNIQUE constraint failed: users.e_mail"
61 | 	}
62 |
63 | 	return false
64 | }
65 |
66 | // User finds the user identified by email.
67 | // return ErrUserNotFound when user not found
68 | func (s *Storage) User(email string) (User, error) {
69 | 	var user User
70 |
71 | 	// A map condition keeps the comparison even for an empty email. gorm ignores zero-value fields of a struct
72 | 	// condition, so User{EMail: ""} selected the first user (lowest primary key) instead of none.
73 | 	err := s.db.Where(map[string]interface{}{"e_mail": email}).First(&user).Error
74 | 	if errors.Is(err, gorm.ErrRecordNotFound) {
75 | 		return User{}, ErrUserNotFound
76 | 	}
77 | 	if err != nil {
78 | 		return User{}, fmt.Errorf("failed to query user: %w", err)
79 | 	}
80 |
81 | 	return user, nil
82 | }
83 |
84 | // UpdateUser updates all properties (excluding email) from the given user which will be identified by email
85 | // return ErrUserNotFound when user not found
86 | //
87 | // NOTE(review): gorm's Updates with a struct writes only non-zero fields and builds its WHERE clause from the
88 | // primary key (u.ID) rather than from EMail — callers presumably pass a user loaded via User; confirm.
89 | func (s *Storage) UpdateUser(u User) error {
90 | 	res := s.db.Updates(u)
91 | 	if res.Error != nil {
92 | 		return fmt.Errorf("failed to exec update user stmt: %w", res.Error)
93 | 	}
94 |
95 | 	if res.RowsAffected == 0 {
96 | 		return ErrUserNotFound
97 | 	}
98 |
99 | 	return nil
100 | }
101 |
102 | // DeleteUser deletes the user with the given email and all corresponding tokens in one transaction.
103 | // return ErrUserNotFound when user not found (the token deletion is rolled back in that case)
104 | //
105 | // NOTE(review): User and Token embed gorm.Model, so gorm soft-deletes them (sets deleted_at). The soft-deleted
106 | // row presumably still occupies the unique_email index, so re-creating a user with the same email may fail —
107 | // confirm whether a hard delete (Unscoped) is intended.
108 | func (s *Storage) DeleteUser(email string) error {
109 | 	return s.db.Transaction(func(tx *gorm.DB) error {
110 | 		// Every statement must run on tx; statements on s.db would bypass the transaction and could not be
111 | 		// rolled back. A map condition (instead of a struct condition) keeps an empty email in the WHERE clause.
112 | 		byEMail := map[string]interface{}{"e_mail": email}
113 |
114 | 		err := tx.Where(byEMail).Delete(&Token{}).Error
115 | 		if err != nil {
116 | 			return fmt.Errorf("failed to exec delete tokens from user stmt: %w", err)
117 | 		}
118 |
119 | 		res := tx.Where(byEMail).Delete(&User{})
120 | 		if res.Error != nil {
121 | 			return fmt.Errorf("failed to exec delete user stmt: %w", res.Error)
122 | 		}
123 |
124 | 		if res.RowsAffected == 0 {
125 | 			return ErrUserNotFound
126 | 		}
127 |
128 | 		return nil
129 | 	})
130 | }
131 |
--------------------------------------------------------------------------------
/internal/mailer/smtp.go:
--------------------------------------------------------------------------------
1 | package mailer
2 |
3 | import (
4 | 	"crypto/tls"
5 | 	"fmt"
6 |
7 | 	"gopkg.in/mail.v2"
8 | )
9 |
10 | //go:generate moq -out send_closer_moq_test.go . sendCloser
11 |
12 | // sendCloser mirrors mail.SendCloser under a local name so moq can generate a mock for it.
13 | type sendCloser mail.SendCloser
14 |
15 | //go:generate moq -out dialer_moq_test.go . dialer
16 |
17 | // dialer is the part of *mail.Dialer that Mailer needs.
18 | type dialer interface {
19 | 	DialAndSend(msgs ...*mail.Message) error
20 | 	Dial() (mail.SendCloser, error)
21 | }
22 |
23 | //go:generate moq -out template_moq_test.go . template
24 |
25 | // template renders a complete mail message from the given template arguments.
26 | type template interface {
27 | 	Render(args interface{}) (*mail.Message, error)
28 | }
29 |
30 | // mailNewDialer indirects mail.NewDialer (presumably so tests can replace it).
31 | var mailNewDialer = mail.NewDialer
32 |
33 | // Mailer should be created via New and can send different mails to different recipients based on loaded mail
34 | // templates.
35 | type Mailer struct {
36 | 	dialer    dialer
37 | 	templates map[string]template
38 | }
39 |
40 | // buildDialer creates an SMTP dialer for the given server and credentials and applies the TLS settings to it.
41 | var buildDialer = func(username string, password string, host string, port int, tlsInsecureSkipVerify bool, tlsServerName string) dialer {
42 | 	tlsConfig := &tls.Config{
43 | 		ServerName:         tlsServerName,
44 | 		InsecureSkipVerify: tlsInsecureSkipVerify,
45 | 	}
46 |
47 | 	smtpDialer := mailNewDialer(host, port, username, password)
48 | 	smtpDialer.TLSConfig = tlsConfig
49 |
50 | 	return smtpDialer
51 | }
52 |
53 | // New creates a Mailer instance with the given smtp-configuration. Before instantiation a dial tests the
54 | // configuration and all available templates will be parsed.
55 | // 'tlsServerName' is only required if 'tlsInsecureSkipVerify' is false.
56 | func New(templatesFolderPath, username, password, host string, port int, tlsInsecureSkipVerify bool, tlsServerName string) (*Mailer, error) {
57 | 	d := buildDialer(username, password, host, port, tlsInsecureSkipVerify, tlsServerName)
58 |
59 | 	// Probe connection and authentication once; the probe connection is closed again when New returns.
60 | 	probe, dialErr := d.Dial()
61 | 	if dialErr != nil {
62 | 		return nil, fmt.Errorf("failed to connect to smtp server: %w", dialErr)
63 | 	}
64 | 	defer func() { _ = probe.Close() }()
65 |
66 | 	passwordResetTmpl, loadErr := loadTemplates(templatesFolderPath, passwordResetRequestTemplateName)
67 | 	if loadErr != nil {
68 | 		return nil, fmt.Errorf("failed to load password-reset mailTemplate: %w", loadErr)
69 | 	}
70 |
71 | 	templates := map[string]template{passwordResetRequestTemplateName: passwordResetTmpl}
72 |
73 | 	return &Mailer{templates: templates, dialer: d}, nil
74 | }
75 |
76 | // SendPasswordResetRequestEMail sends a password-reset-request mail to the given recipient. The template is
77 | // rendered with .Recipient, .PasswordResetToken and .Claims.
78 | func (m *Mailer) SendPasswordResetRequestEMail(recipient, passwordResetToken string, claims map[string]interface{}) error {
79 | 	tpl, ok := m.templates[passwordResetRequestTemplateName]
80 | 	if !ok {
81 | 		return fmt.Errorf("could not found mailTemplate with name %q", passwordResetRequestTemplateName)
82 | 	}
83 |
84 | 	var mailData struct {
85 | 		Recipient          string
86 | 		PasswordResetToken string
87 | 		Claims             map[string]interface{}
88 | 	}
89 | 	mailData.Recipient = recipient
90 | 	mailData.PasswordResetToken = passwordResetToken
91 | 	mailData.Claims = claims
92 |
93 | 	msg, err := tpl.Render(mailData)
94 | 	if err != nil {
95 | 		return fmt.Errorf("failed to render mail-template %q: %w", passwordResetRequestTemplateName, err)
96 | 	}
97 |
98 | 	if err := m.dialer.DialAndSend(msg); err != nil {
99 | 		return fmt.Errorf("failed to send email: %w", err)
100 | 	}
101 |
102 | 	return nil
103 | }
104 |
--------------------------------------------------------------------------------
/internal/mailer/dialer_moq_test.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package mailer
5 |
6 | import (
7 | "gopkg.in/mail.v2"
8 | "sync"
9 | )
10 |
11 | // Compile-time check that dialerMock implements dialer.
12 | // If this fails to compile, regenerate this file with moq.
13 | var _ dialer = &dialerMock{}
14 |
15 | // dialerMock is a mock implementation of dialer.
16 | //
17 | // func TestSomethingThatUsesdialer(t *testing.T) {
18 | //
19 | // // make and configure a mocked dialer
20 | // mockeddialer := &dialerMock{
21 | // DialFunc: func() (mail.SendCloser, error) {
22 | // panic("mock out the Dial method")
23 | // },
24 | // DialAndSendFunc: func(msgs ...*mail.Message) error {
25 | // panic("mock out the DialAndSend method")
26 | // },
27 | // }
28 | //
29 | // // use mockeddialer in code that requires dialer
30 | // // and then make assertions.
31 | //
32 | // }
33 | type dialerMock struct {
34 | // DialFunc mocks the Dial method.
35 | DialFunc func() (mail.SendCloser, error)
36 |
37 | // DialAndSendFunc mocks the DialAndSend method.
38 | DialAndSendFunc func(msgs ...*mail.Message) error
39 |
40 | // calls records every call per method; each slice is guarded by the matching lock field below.
41 | calls struct {
42 | // Dial holds details about calls to the Dial method.
43 | Dial []struct {
44 | }
45 | // DialAndSend holds details about calls to the DialAndSend method.
46 | DialAndSend []struct {
47 | // Msgs is the msgs argument value.
48 | Msgs []*mail.Message
49 | }
50 | }
51 | lockDial sync.RWMutex
52 | lockDialAndSend sync.RWMutex
53 | }
54 |
55 | // Dial records the call, then delegates to DialFunc; it panics if DialFunc is nil.
56 | func (mock *dialerMock) Dial() (mail.SendCloser, error) {
57 | if mock.DialFunc == nil {
58 | panic("dialerMock.DialFunc: method is nil but dialer.Dial was just called")
59 | }
60 | callInfo := struct {
61 | }{}
62 | mock.lockDial.Lock()
63 | mock.calls.Dial = append(mock.calls.Dial, callInfo)
64 | mock.lockDial.Unlock()
65 | return mock.DialFunc()
66 | }
67 |
68 | // DialCalls returns the calls recorded for Dial (read under lockDial).
69 | // Check the length with:
70 | // len(mockeddialer.DialCalls())
71 | func (mock *dialerMock) DialCalls() []struct {
72 | } {
73 | var calls []struct {
74 | }
75 | mock.lockDial.RLock()
76 | calls = mock.calls.Dial
77 | mock.lockDial.RUnlock()
78 | return calls
79 | }
80 |
81 | // DialAndSend calls DialAndSendFunc.
82 | func (mock *dialerMock) DialAndSend(msgs ...*mail.Message) error {
83 | if mock.DialAndSendFunc == nil {
84 | panic("dialerMock.DialAndSendFunc: method is nil but dialer.DialAndSend was just called")
85 | }
86 | callInfo := struct {
87 | Msgs []*mail.Message
88 | }{
89 | Msgs: msgs,
90 | }
91 | mock.lockDialAndSend.Lock()
92 | mock.calls.DialAndSend = append(mock.calls.DialAndSend, callInfo)
93 | mock.lockDialAndSend.Unlock()
94 | return mock.DialAndSendFunc(msgs...)
95 | }
96 |
97 | // DialAndSendCalls gets all the calls that were made to DialAndSend.
98 | // Check the length with:
99 | // len(mockeddialer.DialAndSendCalls())
100 | func (mock *dialerMock) DialAndSendCalls() []struct {
101 | Msgs []*mail.Message
102 | } {
103 | var calls []struct {
104 | Msgs []*mail.Message
105 | }
106 | mock.lockDialAndSend.RLock()
107 | calls = mock.calls.DialAndSend
108 | mock.lockDialAndSend.RUnlock()
109 | return calls
110 | }
111 |
--------------------------------------------------------------------------------
/internal/mailer_moq_test.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package internal
5 |
6 | import (
7 | "sync"
8 | )
9 |
// NOTE(review): this code is generated by moq (see the file header). Change the Mailer interface and
// re-run moq instead of editing it by hand. SendPasswordResetRequestEMail records its arguments under
// lockSendPasswordResetRequestEMail and then delegates to SendPasswordResetRequestEMailFunc; it panics
// when that field is nil. SendPasswordResetRequestEMailCalls returns the recorded slice itself (no copy),
// so treat the result as read-only.
10 | // Ensure, that MailerMock does implement Mailer.
11 | // If this is not the case, regenerate this file with moq.
12 | var _ Mailer = &MailerMock{}
13 |
14 | // MailerMock is a mock implementation of Mailer.
15 | //
16 | // func TestSomethingThatUsesMailer(t *testing.T) {
17 | //
18 | // // make and configure a mocked Mailer
19 | // mockedMailer := &MailerMock{
20 | // SendPasswordResetRequestEMailFunc: func(recipient string, passwordResetToken string, claims map[string]interface{}) error {
21 | // panic("mock out the SendPasswordResetRequestEMail method")
22 | // },
23 | // }
24 | //
25 | // // use mockedMailer in code that requires Mailer
26 | // // and then make assertions.
27 | //
28 | // }
29 | type MailerMock struct {
30 | // SendPasswordResetRequestEMailFunc mocks the SendPasswordResetRequestEMail method.
31 | SendPasswordResetRequestEMailFunc func(recipient string, passwordResetToken string, claims map[string]interface{}) error
32 |
33 | // calls tracks calls to the methods.
34 | calls struct {
35 | // SendPasswordResetRequestEMail holds details about calls to the SendPasswordResetRequestEMail method.
36 | SendPasswordResetRequestEMail []struct {
37 | // Recipient is the recipient argument value.
38 | Recipient string
39 | // PasswordResetToken is the passwordResetToken argument value.
40 | PasswordResetToken string
41 | // Claims is the claims argument value.
42 | Claims map[string]interface{}
43 | }
44 | }
45 | lockSendPasswordResetRequestEMail sync.RWMutex
46 | }
47 |
48 | // SendPasswordResetRequestEMail calls SendPasswordResetRequestEMailFunc.
49 | func (mock *MailerMock) SendPasswordResetRequestEMail(recipient string, passwordResetToken string, claims map[string]interface{}) error {
50 | if mock.SendPasswordResetRequestEMailFunc == nil {
51 | panic("MailerMock.SendPasswordResetRequestEMailFunc: method is nil but Mailer.SendPasswordResetRequestEMail was just called")
52 | }
53 | callInfo := struct {
54 | Recipient string
55 | PasswordResetToken string
56 | Claims map[string]interface{}
57 | }{
58 | Recipient: recipient,
59 | PasswordResetToken: passwordResetToken,
60 | Claims: claims,
61 | }
62 | mock.lockSendPasswordResetRequestEMail.Lock()
63 | mock.calls.SendPasswordResetRequestEMail = append(mock.calls.SendPasswordResetRequestEMail, callInfo)
64 | mock.lockSendPasswordResetRequestEMail.Unlock()
65 | return mock.SendPasswordResetRequestEMailFunc(recipient, passwordResetToken, claims)
66 | }
67 |
68 | // SendPasswordResetRequestEMailCalls gets all the calls that were made to SendPasswordResetRequestEMail.
69 | // Check the length with:
70 | // len(mockedMailer.SendPasswordResetRequestEMailCalls())
71 | func (mock *MailerMock) SendPasswordResetRequestEMailCalls() []struct {
72 | Recipient string
73 | PasswordResetToken string
74 | Claims map[string]interface{}
75 | } {
76 | var calls []struct {
77 | Recipient string
78 | PasswordResetToken string
79 | Claims map[string]interface{}
80 | }
81 | mock.lockSendPasswordResetRequestEMail.RLock()
82 | calls = mock.calls.SendPasswordResetRequestEMail
83 | mock.lockSendPasswordResetRequestEMail.RUnlock()
84 | return calls
85 | }
86 |
--------------------------------------------------------------------------------
/internal/mailer/send_closer_moq_test.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package mailer
5 |
6 | import (
7 | "io"
8 | "sync"
9 | )
10 |
// NOTE(review): this code is generated by moq (see the file header). Change the sendCloser interface and
// re-run moq instead of editing it by hand. Close and Send record their calls under lockClose/lockSend
// and then delegate to CloseFunc/SendFunc; they panic when the corresponding field is nil. The *Calls
// accessors return the recorded slice itself (no copy), so treat the result as read-only.
11 | // Ensure, that sendCloserMock does implement sendCloser.
12 | // If this is not the case, regenerate this file with moq.
13 | var _ sendCloser = &sendCloserMock{}
14 |
15 | // sendCloserMock is a mock implementation of sendCloser.
16 | //
17 | // func TestSomethingThatUsessendCloser(t *testing.T) {
18 | //
19 | // // make and configure a mocked sendCloser
20 | // mockedsendCloser := &sendCloserMock{
21 | // CloseFunc: func() error {
22 | // panic("mock out the Close method")
23 | // },
24 | // SendFunc: func(from string, to []string, msg io.WriterTo) error {
25 | // panic("mock out the Send method")
26 | // },
27 | // }
28 | //
29 | // // use mockedsendCloser in code that requires sendCloser
30 | // // and then make assertions.
31 | //
32 | // }
33 | type sendCloserMock struct {
34 | // CloseFunc mocks the Close method.
35 | CloseFunc func() error
36 |
37 | // SendFunc mocks the Send method.
38 | SendFunc func(from string, to []string, msg io.WriterTo) error
39 |
40 | // calls tracks calls to the methods.
41 | calls struct {
42 | // Close holds details about calls to the Close method.
43 | Close []struct {
44 | }
45 | // Send holds details about calls to the Send method.
46 | Send []struct {
47 | // From is the from argument value.
48 | From string
49 | // To is the to argument value.
50 | To []string
51 | // Msg is the msg argument value.
52 | Msg io.WriterTo
53 | }
54 | }
55 | lockClose sync.RWMutex
56 | lockSend sync.RWMutex
57 | }
58 |
59 | // Close calls CloseFunc.
60 | func (mock *sendCloserMock) Close() error {
61 | if mock.CloseFunc == nil {
62 | panic("sendCloserMock.CloseFunc: method is nil but sendCloser.Close was just called")
63 | }
64 | callInfo := struct {
65 | }{}
66 | mock.lockClose.Lock()
67 | mock.calls.Close = append(mock.calls.Close, callInfo)
68 | mock.lockClose.Unlock()
69 | return mock.CloseFunc()
70 | }
71 |
72 | // CloseCalls gets all the calls that were made to Close.
73 | // Check the length with:
74 | // len(mockedsendCloser.CloseCalls())
75 | func (mock *sendCloserMock) CloseCalls() []struct {
76 | } {
77 | var calls []struct {
78 | }
79 | mock.lockClose.RLock()
80 | calls = mock.calls.Close
81 | mock.lockClose.RUnlock()
82 | return calls
83 | }
84 |
85 | // Send calls SendFunc.
86 | func (mock *sendCloserMock) Send(from string, to []string, msg io.WriterTo) error {
87 | if mock.SendFunc == nil {
88 | panic("sendCloserMock.SendFunc: method is nil but sendCloser.Send was just called")
89 | }
90 | callInfo := struct {
91 | From string
92 | To []string
93 | Msg io.WriterTo
94 | }{
95 | From: from,
96 | To: to,
97 | Msg: msg,
98 | }
99 | mock.lockSend.Lock()
100 | mock.calls.Send = append(mock.calls.Send, callInfo)
101 | mock.lockSend.Unlock()
102 | return mock.SendFunc(from, to, msg)
103 | }
104 |
105 | // SendCalls gets all the calls that were made to Send.
106 | // Check the length with:
107 | // len(mockedsendCloser.SendCalls())
108 | func (mock *sendCloserMock) SendCalls() []struct {
109 | From string
110 | To []string
111 | Msg io.WriterTo
112 | } {
113 | var calls []struct {
114 | From string
115 | To []string
116 | Msg io.WriterTo
117 | }
118 | mock.lockSend.RLock()
119 | calls = mock.calls.Send
120 | mock.lockSend.RUnlock()
121 | return calls
122 | }
123 |
--------------------------------------------------------------------------------
/cmd/provider/reset_password_component_test.go:
--------------------------------------------------------------------------------
1 | // +build component
2 |
3 | package main
4 |
5 | import (
6 | "bytes"
7 | "encoding/json"
8 | "fmt"
9 | "net/http"
10 | "regexp"
11 | "strings"
12 | "testing"
13 | )
14 |
15 | func TestResetPassword(t *testing.T) {
16 | email := "reset_test@leberkleber.io"
17 | password := "s3cr3t"
18 | newPassword := "t3rc3s"
19 |
20 | createUser(t, email, password)
21 | createPasswordResetRequest(t, email)
22 | token := findPasswordResetTokenFromMailAndVerifyContent(t, email)
23 | resetPassword(t, email, token, newPassword)
24 |
25 | loginUser(t, email, newPassword)
26 | }
27 |
28 | func findPasswordResetTokenFromMailAndVerifyContent(t *testing.T, email string) string {
29 | resp, err := http.Get("http://mail-server:8025/api/v2/messages")
30 | if err != nil {
31 | t.Fatalf("Failed to login cause: %s", err)
32 | }
33 |
34 | if resp.StatusCode != http.StatusOK {
35 | t.Fatalf("Invalid response status code. Expected: %d, Given: %d", http.StatusOK, resp.StatusCode)
36 | }
37 |
38 | var mailhogRes MailhogResponse
39 |
40 | defer resp.Body.Close()
41 | err = json.NewDecoder(resp.Body).Decode(&mailhogRes)
42 | if err != nil {
43 | t.Fatalf("failed to encode smtp-server api-response: %s", err)
44 | }
45 |
46 | var respMail MailhogResponseItemRaw
47 | respMailFound := false
48 | for _, r := range mailhogRes.Items {
49 | for i := range r.Raw.To {
50 | if r.Raw.To[i] == email {
51 | respMail = r.Raw
52 | respMailFound = true
53 | break
54 | }
55 | }
56 | }
57 |
58 | if !respMailFound {
59 | t.Fatal("could not find mail body")
60 | }
61 |
62 | expectedCustomCalimValue := "customClaimValue"
63 | if !strings.Contains(respMail.Data, "customClaimValue") {
64 | t.Errorf("email body dosent contains custom claim value %q: \n%q", expectedCustomCalimValue, respMail.Data)
65 | }
66 |
67 | reg, err := regexp.Compile("([a-f0-9]{64})")
68 | if err != nil {
69 | t.Fatal("could not compile regex")
70 | }
71 |
72 | res := reg.Find([]byte(respMail.Data))
73 | if len(res) == 0 {
74 | t.Fatalf("no reset token found. Mail content %q", respMail.Data)
75 | }
76 |
77 | return string(res)
78 | }
79 |
// createPasswordResetRequest asks the provider to start the password-reset flow for email. The test
// fails fatally when the request cannot be sent or the provider does not answer with 201 Created.
func createPasswordResetRequest(t *testing.T, email string) {
	t.Helper()

	resp, err := http.Post(
		"http://simple-jwt-provider/v1/auth/password-reset-request",
		"application/json",
		bytes.NewReader([]byte(fmt.Sprintf(`{"email": %q}`, email))),
	)
	if err != nil {
		t.Fatalf("Failed to create password-reset-request cause: %s", err)
	}
	// Always release the response body so the connection is not leaked.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		t.Fatalf("Invalid response status code. Expected: %d, Given: %d", http.StatusCreated, resp.StatusCode)
	}
}
96 |
// resetPassword completes the password-reset flow by sending email, the reset token and the new
// password to the provider. The test fails fatally when the request cannot be sent or the provider
// does not answer with 204 No Content.
func resetPassword(t *testing.T, email, token, newPassword string) {
	t.Helper()

	resp, err := http.Post(
		"http://simple-jwt-provider/v1/auth/password-reset",
		"application/json",
		bytes.NewReader([]byte(fmt.Sprintf(`{"email": %q, "reset_token":%q, "password": %q}`, email, token, newPassword))),
	)
	if err != nil {
		t.Fatalf("Failed to reset password cause: %s", err)
	}
	// Always release the response body so the connection is not leaked.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNoContent {
		t.Fatalf("Invalid response status code. Expected: %d, Given: %d", http.StatusNoContent, resp.StatusCode)
	}
}
111 |
// The types below model only the parts of the MailHog GET /api/v2/messages response that the
// component tests read; all other fields of the response are ignored while decoding.
type (
	// MailhogResponse is the top-level document returned by the messages endpoint.
	MailhogResponse struct {
		Items []MailhogResponseItem `json:"items"`
	}

	// MailhogResponseItem is one stored mail.
	MailhogResponseItem struct {
		Raw MailhogResponseItemRaw `json:"Raw"`
	}

	// MailhogResponseItemRaw carries the recipients (To) and the raw mail content (Data) of a mail.
	MailhogResponseItemRaw struct {
		To   []string `json:"To"`
		Data string   `json:"Data"`
	}
)
123 |
--------------------------------------------------------------------------------
/internal/jwt/validate_test.go:
--------------------------------------------------------------------------------
1 | package jwt
2 |
3 | import (
4 | "crypto/ecdsa"
5 | "errors"
6 | "fmt"
7 | "github.com/golang-jwt/jwt"
8 | "reflect"
9 | "testing"
10 | "time"
11 | )
12 |
13 | func TestProvider_Error_IsTokenValid(t *testing.T) {
14 | tests := []struct {
15 | name string
16 | givenToken string
17 | parseFuncErr error
18 | parseFuncToken *jwt.Token
19 | parseFuncClaims jwt.MapClaims
20 | expectedJWT string
21 | expectedIsValid bool
22 | expectedClaims jwt.MapClaims
23 | expectedErr error
24 | }{
25 | {
26 | name: "Happycase",
27 | givenToken: "myToken",
28 | parseFuncClaims: jwt.MapClaims{"my": "claim"},
29 | parseFuncToken: &jwt.Token{Valid: true},
30 | expectedJWT: "myToken",
31 | expectedClaims: jwt.MapClaims{"my": "claim"},
32 | }, {
33 | name: "parse error",
34 | givenToken: "myToken",
35 | parseFuncErr: errors.New("my error"),
36 | expectedErr: errors.New("failed to parse token: my error"),
37 | expectedJWT: "myToken",
38 | }, {
39 | name: "invalid token",
40 | givenToken: "myToken",
41 | parseFuncClaims: jwt.MapClaims{"my": "claim"},
42 | parseFuncToken: &jwt.Token{Valid: false},
43 | expectedJWT: "myToken",
44 | expectedClaims: jwt.MapClaims{"my": "claim"},
45 | expectedErr: errors.New("token is not valid"),
46 | },
47 | }
48 |
49 | for _, tt := range tests {
50 | t.Run(tt.name, func(t *testing.T) {
51 | oldParseFunc := parseFunc
52 | defer func() {
53 | parseFunc = oldParseFunc
54 | }()
55 |
56 | parseFunc = func(tokenString string, c jwt.Claims, _ jwt.Keyfunc) (token *jwt.Token, e error) {
57 | if tokenString != tt.expectedJWT {
58 | t.Errorf("Unexpected parseFunc>token. Expected: %s. Given: %s", tt.expectedJWT, tokenString)
59 | }
60 |
61 | cc := c.(*jwt.MapClaims)
62 | *cc = tt.parseFuncClaims
63 |
64 | return tt.parseFuncToken, tt.parseFuncErr
65 | }
66 |
67 | isValid, claims, err := Provider{privateKey: &ecdsa.PrivateKey{}}.IsTokenValid(tt.givenToken)
68 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedErr) {
69 | t.Fatalf("Unexpected error. Expected: %q. Given: %q", tt.expectedErr, err)
70 | } else if err != nil {
71 | return
72 | }
73 |
74 | if isValid == tt.expectedIsValid {
75 | t.Errorf("Unexpected response isValid. Expected: %t. Given: %t", tt.expectedIsValid, isValid)
76 | }
77 |
78 | if !reflect.DeepEqual(claims, tt.expectedClaims) {
79 | t.Errorf("Unexpected response claims. Expected: %#v. Given: %#v", tt.expectedClaims, claims)
80 | }
81 | })
82 | }
83 | }
84 |
85 | func TestProvider_IsTokenValid(t *testing.T) {
86 | email := "my.mail@test.de"
87 |
88 | provider, err := NewProvider(jwtPrvKey, time.Minute, "audience", "issuer", "subject")
89 | if err != nil {
90 | t.Fatal("failed to create provider", err)
91 | }
92 |
93 | token, jwtID, err := provider.GenerateRefreshToken(email)
94 | if err != nil {
95 | t.Fatal("failed to generate test refresh-token", err)
96 | }
97 | if jwtID == "" {
98 | t.Error("generate returns no jwtID")
99 | }
100 |
101 | isValid, claims, err := provider.IsTokenValid(token)
102 | if err != nil {
103 | t.Fatal("failed to validate token", err)
104 | }
105 |
106 | if !isValid {
107 | t.Error("token is not valid")
108 | }
109 |
110 | claimEmail, ok := claims["email"].(string)
111 | if !ok {
112 | t.Fatalf("email is not parsable as string. Claims: %#v", claims)
113 | }
114 |
115 | if email != claimEmail {
116 | t.Errorf("claims>email is not as expected. Expected: %q, Given: %q", email, claimEmail)
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/internal/admin.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/leberKleber/simple-jwt-provider/internal/storage"
7 | "golang.org/x/crypto/bcrypt"
8 | )
9 |
// blankedPassword is put in place of the stored password hash in every User returned to callers.
const blankedPassword = "**********"

var (
	// bcryptCost is the bcrypt work factor used for hashing passwords. It is a variable, presumably so
	// tests can lower it — NOTE(review): confirm.
	bcryptCost = 12

	// ErrUserAlreadyExists is returned when a user with the given email already exists.
	ErrUserAlreadyExists = errors.New("user already exists")
)

// User is the representation of a user for use in internal.
type User struct {
	// EMail identifies the user.
	EMail string
	// Password is the plain password on input; on output it is always blankedPassword.
	Password string
	// Claims holds user-specific custom claims.
	Claims map[string]interface{}
}
23 |
24 | // CreateUser creates new user with given email, password and claims.
25 | // return ErrUserAlreadyExists when user already exists
26 | func (p Provider) CreateUser(user User) error {
27 | bcryptedPassword, err := bcryptPassword(user.Password)
28 | if err != nil {
29 | return fmt.Errorf("failed to bcrypt password: %w", err)
30 | }
31 |
32 | err = p.Storage.CreateUser(storage.User{
33 | EMail: user.EMail,
34 | Password: bcryptedPassword,
35 | Claims: user.Claims,
36 | })
37 | if err != nil {
38 | if errors.Is(err, storage.ErrUserAlreadyExists) {
39 | return ErrUserAlreadyExists
40 | }
41 | return fmt.Errorf("failed to query user with email %q: %w", user.EMail, err)
42 | }
43 |
44 | return nil
45 | }
46 |
47 | // GetUser returns a user with the given email.
48 | // return ErrUserNotFound when user does not exist
49 | func (p Provider) GetUser(email string) (User, error) {
50 | user, err := p.Storage.User(email)
51 | if err != nil {
52 | if errors.Is(err, storage.ErrUserNotFound) {
53 | return User{}, ErrUserNotFound
54 | }
55 |
56 | return User{}, fmt.Errorf("failed to find user with email %q: %w", email, err)
57 | }
58 |
59 | return User{
60 | EMail: user.EMail,
61 | Password: blankedPassword,
62 | Claims: user.Claims,
63 | }, nil
64 | }
65 |
66 | // UpdateUser updates user with given email.
67 | // return ErrUserNotFound when user does not exist
68 | func (p Provider) UpdateUser(email string, user User) (User, error) {
69 | dbUser, err := p.Storage.User(email)
70 | if err != nil {
71 | if errors.Is(err, storage.ErrUserNotFound) {
72 | return User{}, ErrUserNotFound
73 | }
74 |
75 | return User{}, fmt.Errorf("failed to find user to update: %w", err)
76 | }
77 |
78 | if user.Password != "" {
79 | bcryptedPassword, err := bcryptPassword(user.Password)
80 | if err != nil {
81 | return User{}, fmt.Errorf("failed to bcrypt new password: %w", err)
82 | }
83 | dbUser.Password = bcryptedPassword
84 | }
85 |
86 | if user.Claims != nil {
87 | dbUser.Claims = user.Claims
88 | }
89 |
90 | err = p.Storage.UpdateUser(dbUser)
91 | if err != nil {
92 | if errors.Is(err, storage.ErrUserNotFound) {
93 | return User{}, ErrUserNotFound
94 | }
95 |
96 | return User{}, fmt.Errorf("failed to update user: %w", err)
97 | }
98 |
99 | return User{
100 | EMail: dbUser.EMail,
101 | Password: blankedPassword,
102 | Claims: dbUser.Claims,
103 | }, nil
104 | }
105 |
106 | // DeleteUser deletes user with given email.
107 | // return ErrUserNotFound when user does not exist
108 | // return ErrUserStillHasTokens when user still has tokens
109 | func (p Provider) DeleteUser(email string) error {
110 | err := p.Storage.DeleteUser(email)
111 | if err != nil {
112 | if errors.Is(err, storage.ErrUserNotFound) {
113 | return ErrUserNotFound
114 | }
115 |
116 | return fmt.Errorf("failed to delete user with email %q: %w", email, err)
117 | }
118 |
119 | return nil
120 | }
121 |
122 | var bcryptPassword = func(password string) ([]byte, error) {
123 | return bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
124 | }
125 |
--------------------------------------------------------------------------------
/internal/jwt/provider_test.go:
--------------------------------------------------------------------------------
1 | package jwt
2 |
3 | import (
4 | "crypto/ecdsa"
5 | "errors"
6 | "fmt"
7 | "github.com/golang-jwt/jwt"
8 | "math/big"
9 | "testing"
10 | "time"
11 | )
12 |
// jwtPubKey is a PEM-encoded EC public key used as a test fixture — presumably the public half of
// jwtPrvKey (NOTE(review): confirm). Test-only material: never reuse it outside tests.
13 | var jwtPubKey = `-----BEGIN PUBLIC KEY-----
14 | MIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBQSa/dFpXRqz6aQQmx6sNpxl3mn8Z
15 | 0o+qgfgOxPAPxu+JppsCGqrX/6SeUI6kz3AFVABGBU8/9Ejzt7Ty9WJt1dEB+035
16 | 03+xLnmmyaj3bEhkerr229mDgPb8uDlPEl6f/Wv+Ma/eIIloCo8WJAe8YsviImbF
17 | hAV1NK8+62/iMCfNj30=
18 | -----END PUBLIC KEY-----
19 | `
// jwtPrvKey is the PEM-encoded EC private key passed to NewProvider in these tests. Test-only
// material: never reuse it outside tests.
20 | var jwtPrvKey = `-----BEGIN EC PRIVATE KEY-----
21 | MIHcAgEBBEIASzDZeTVLxcE5KTAmwrKwFjzr5cDrA+tttx9XRUz0K7AlROtj7cMG
22 | rHu/bdKj7lc2WaW8x/EOrU/FeCcsIL5nTH+gBwYFK4EEACOhgYkDgYYABAFBJr90
23 | WldGrPppBCbHqw2nGXeafxnSj6qB+A7E8A/G74mmmwIaqtf/pJ5QjqTPcAVUAEYF
24 | Tz/0SPO3tPL1Ym3V0QH7TfnTf7EueabJqPdsSGR6uvbb2YOA9vy4OU8SXp/9a/4x
25 | r94giWgKjxYkB7xiy+IiZsWEBXU0rz7rb+IwJ82PfQ==
26 | -----END EC PRIVATE KEY-----`
27 |
28 | func TestNewGenerator_WithoutPrivateKey(t *testing.T) {
29 | _, err := NewProvider("", 4*time.Hour, "audience", "issuer", "subject")
30 |
31 | expectedError := errors.New("no valid private key found")
32 | if fmt.Sprint(err) != fmt.Sprint(expectedError) {
33 | t.Errorf("Unexpected error. Expected: %q, Given: %q", expectedError, err)
34 | }
35 | }
36 |
37 | func TestNewGenerator_InvalidPrivateKey(t *testing.T) {
38 | oldX509ParseECPrivateKey := x509ParseECPrivateKey
39 | defer func() { x509ParseECPrivateKey = oldX509ParseECPrivateKey }()
40 |
41 | x509ParseECPrivateKey = func(der []byte) (*ecdsa.PrivateKey, error) {
42 | return nil, errors.New("errrooooooorrrr")
43 | }
44 |
45 | _, err := NewProvider(jwtPrvKey, 4*time.Hour, "audience", "issuer", "subject")
46 |
47 | expectedError := errors.New("failed to parse private-key: errrooooooorrrr")
48 | if fmt.Sprint(err) != fmt.Sprint(expectedError) {
49 | t.Errorf("Unexpected error. Expected: %q, Given: %q", expectedError, err)
50 | }
51 | }
52 |
53 | func TestCheckSigningMethodKeyFunc(t *testing.T) {
54 | tests := []struct {
55 | name string
56 | givenSigningMethod jwt.SigningMethod
57 | givenPublicKey *ecdsa.PublicKey
58 | givenToken *jwt.Token
59 | expectedResponse interface{}
60 | expectedErr error
61 | }{
62 | {
63 | name: "Happycase",
64 | givenSigningMethod: jwt.SigningMethodES512,
65 | givenPublicKey: &ecdsa.PublicKey{X: big.NewInt(555), Y: big.NewInt(666)},
66 | givenToken: &jwt.Token{
67 | Method: jwt.SigningMethodES512,
68 | },
69 | expectedResponse: &ecdsa.PublicKey{X: big.NewInt(555), Y: big.NewInt(666)},
70 | }, {
71 | name: "Unexpected signing method",
72 | givenSigningMethod: jwt.SigningMethodES512,
73 | givenPublicKey: &ecdsa.PublicKey{X: big.NewInt(555), Y: big.NewInt(666)},
74 | givenToken: &jwt.Token{
75 | Method: jwt.SigningMethodPS256,
76 | },
77 | expectedErr: errors.New("unexpected signing method \"*jwt.SigningMethodRSAPSS\", expected: \"*jwt.SigningMethodECDSA\""),
78 | expectedResponse: &ecdsa.PublicKey{X: big.NewInt(555), Y: big.NewInt(666)},
79 | },
80 | }
81 |
82 | for _, tt := range tests {
83 | t.Run(tt.name, func(t *testing.T) {
84 | jwtKeyFunc := checkSigningMethodKeyFunc(tt.givenSigningMethod, tt.givenPublicKey)
85 |
86 | resp, err := jwtKeyFunc(tt.givenToken)
87 | expectedResponseAsString := fmt.Sprint(tt.expectedResponse)
88 | respAsString := fmt.Sprint(resp)
89 |
90 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedErr) {
91 | t.Errorf("Unexpected error. \nExpected: %q\nGiven:\n%q", tt.expectedErr, err)
92 | } else if err != nil {
93 | return
94 | }
95 |
96 | if expectedResponseAsString != respAsString {
97 | t.Errorf("unexpected response. Given: %q, Expected: %q", respAsString, expectedResponseAsString)
98 | }
99 | })
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/internal/web/server.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "encoding/json"
5 | "github.com/gorilla/mux"
6 | "github.com/leberKleber/simple-jwt-provider/internal"
7 | "github.com/leberKleber/simple-jwt-provider/internal/web/middleware"
8 | "github.com/sirupsen/logrus"
9 | "net/http"
10 | )
11 |
12 | var httpListenAndServe = http.ListenAndServe
13 |
14 | // Provider encapsulates internal.Provider to generate mocks
15 | //go:generate moq -out provider_moq_test.go . Provider
16 | type Provider interface {
17 | Login(email, password string) (string, string, error)
18 | Refresh(refreshToken string) (string, string, error)
19 | CreatePasswordResetRequest(email string) error
20 | ResetPassword(email, resetToken, password string) error
21 | CreateUser(user internal.User) error
22 | UpdateUser(email string, user internal.User) (internal.User, error)
23 | GetUser(email string) (internal.User, error)
24 | DeleteUser(email string) error
25 | }
26 |
27 | // Server should be created via NewServer and starts with ListenAndServe all http endpoints for this service.
28 | type Server struct {
29 | h http.Handler
30 | p Provider
31 | }
32 |
33 | // NewServer returns a Server instance with configure http routs
34 | func NewServer(p Provider, enableAdminAPI bool, adminAPIUsername, adminAPIPassword string) *Server {
35 | s := &Server{}
36 | r := mux.NewRouter()
37 |
38 | r.Use(contentTypeMiddleware)
39 | r.NotFoundHandler = http.HandlerFunc(notFoundHandler)
40 | r.MethodNotAllowedHandler = http.HandlerFunc(methodNotAllowedHandler)
41 |
42 | v1 := r.PathPrefix("/v1").Subrouter()
43 | v1.Path("/internal/alive").Methods(http.MethodGet).HandlerFunc(s.aliveHandler)
44 | v1.Path("/auth/login").Methods(http.MethodPost).HandlerFunc(s.loginHandler)
45 | v1.Path("/auth/refresh").Methods(http.MethodPost).HandlerFunc(s.refreshHandler)
46 | v1.Path("/auth/password-reset-request").Methods(http.MethodPost).HandlerFunc(s.passwordResetRequestHandler)
47 | v1.Path("/auth/password-reset").Methods(http.MethodPost).HandlerFunc(s.passwordResetHandler)
48 |
49 | if enableAdminAPI {
50 | adminAPI := v1.PathPrefix("/admin").Subrouter()
51 | adminAPI.Use(middleware.BasicAuth(adminAPIUsername, adminAPIPassword))
52 |
53 | adminAPI.Path("/users").Methods(http.MethodPost).HandlerFunc(s.createUserHandler)
54 | adminAPI.Path("/users/{email}").Methods(http.MethodGet).HandlerFunc(s.getUserHandler)
55 | adminAPI.Path("/users/{email}").Methods(http.MethodPut).HandlerFunc(s.updateUserHandler)
56 | adminAPI.Path("/users/{email}").Methods(http.MethodDelete).HandlerFunc(s.deleteUserHandler)
57 | }
58 |
59 | s.h = r
60 | s.p = p
61 | return s
62 | }
63 |
64 | // ListenAndServe wraps http.ListenAndServe
65 | func (s *Server) ListenAndServe(address string) error {
66 | return httpListenAndServe(address, s.h)
67 | }
68 |
// contentTypeMiddleware wraps handler so that every response is labelled as JSON. The Content-Type
// header is set before the wrapped handler runs, so the handler can still override it.
func contentTypeMiddleware(handler http.Handler) http.Handler {
	withJSONContentType := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		handler.ServeHTTP(w, r)
	}
	return http.HandlerFunc(withJSONContentType)
}
75 |
// errorResponseBody is the JSON document sent with error responses, e.g. {"message":"..."}.
type errorResponseBody struct {
	// Message is a human-readable description of the error.
	Message string `json:"message"`
}
79 |
80 | func writeError(w http.ResponseWriter, statusCode int, message string) {
81 | respBody, err := json.Marshal(errorResponseBody{
82 | Message: message,
83 | })
84 | if err != nil {
85 | logrus.WithError(err).Error("Failed to marshal json error response")
86 | writeInternalServerError(w)
87 | return
88 | }
89 |
90 | w.WriteHeader(statusCode)
91 | _, err = w.Write(respBody)
92 | if err != nil {
93 | logrus.WithError(err).Error("Failed to write error response")
94 | writeInternalServerError(w)
95 | return
96 | }
97 | }
98 |
99 | func writeInternalServerError(w http.ResponseWriter) {
100 | w.WriteHeader(http.StatusInternalServerError)
101 | _, err := w.Write([]byte(`{"message":"internal server error"}`))
102 | if err != nil {
103 | logrus.WithError(err).Error("Failed to write error response")
104 | }
105 | }
106 |
107 | func notFoundHandler(w http.ResponseWriter, _ *http.Request) {
108 | writeError(w, http.StatusNotFound, "endpoint not found")
109 | }
110 |
111 | func methodNotAllowedHandler(w http.ResponseWriter, _ *http.Request) {
112 | writeError(w, http.StatusMethodNotAllowed, "method not allowed")
113 | }
114 |
--------------------------------------------------------------------------------
/cmd/provider/helper_component_test.go:
--------------------------------------------------------------------------------
1 | // +build component
2 |
3 | package main
4 |
5 | import (
6 | "bytes"
7 | "encoding/json"
8 | "fmt"
9 | "io/ioutil"
10 | "net/http"
11 | "net/url"
12 | "testing"
13 | )
14 |
// User is the JSON representation of a user exchanged with the admin API in component tests. Empty
// fields are omitted when encoding, so an update request only carries the fields being changed.
type User struct {
	EMail    string                 `json:"email,omitempty"`
	Password string                 `json:"password,omitempty"`
	Claims   map[string]interface{} `json:"claims,omitempty"`
}
20 |
// createUser creates a user through the admin API. The user gets the given credentials and the
// custom claim {"myCustomClaim": "customClaimValue"}. A non-201 answer is reported as a test error;
// a failure to build or send the request is fatal.
func createUser(t *testing.T, email, password string) {
	t.Helper()

	req, err := http.NewRequest(
		http.MethodPost,
		"http://simple-jwt-provider/v1/admin/users",
		bytes.NewReader([]byte(fmt.Sprintf(`{"email": %q, "password": %q, "claims": {"myCustomClaim": "customClaimValue"}}`, email, password))),
	)
	if err != nil {
		t.Fatalf("Failed to create http request: %s", err)
	}
	req.SetBasicAuth("username", "password")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Failed to create user cause: %s", err)
	}
	// Always release the response body so the connection is not leaked.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Errorf("Failed to read response body: %s", err)
	}

	if resp.StatusCode != http.StatusCreated {
		t.Errorf("Invalid response status code. Expected: %d, Given: %d, Body: %s", http.StatusCreated, resp.StatusCode, respBody)
	}
}
48 |
49 | func readUser(t *testing.T, email string) User {
50 | t.Helper()
51 | req, err := http.NewRequest(
52 | http.MethodGet,
53 | fmt.Sprintf("http://simple-jwt-provider/v1/admin/users/%s", url.PathEscape(email)),
54 | nil,
55 | )
56 | if err != nil {
57 | t.Fatalf("Failed to create http request")
58 | }
59 |
60 | req.SetBasicAuth("username", "password")
61 |
62 | resp, err := http.DefaultClient.Do(req)
63 | if err != nil {
64 | t.Fatalf("Failed to create user cause: %s", err)
65 | }
66 |
67 | var responseBody User
68 | err = json.NewDecoder(resp.Body).Decode(&responseBody)
69 | if err != nil {
70 | t.Error("Failed to read response body", err)
71 | }
72 |
73 | if resp.StatusCode != http.StatusOK {
74 | t.Errorf("Invalid response status code. Expected: %d, Given: %d, Body: %#v", http.StatusOK, resp.StatusCode, responseBody)
75 | }
76 |
77 | return responseBody
78 | }
79 |
80 | func updateUser(t *testing.T, email string, newPassword string, newClaims map[string]interface{}) {
81 | t.Helper()
82 |
83 | requestBody := User{
84 | Password: newPassword,
85 | Claims: newClaims,
86 | }
87 |
88 | var body bytes.Buffer
89 | err := json.NewEncoder(&body).Encode(requestBody)
90 | if err != nil {
91 | t.Fatal("failed to encode request body", err)
92 | }
93 |
94 | req, err := http.NewRequest(
95 | http.MethodPut,
96 | fmt.Sprintf("http://simple-jwt-provider/v1/admin/users/%s", url.PathEscape(email)),
97 | &body,
98 | )
99 | if err != nil {
100 | t.Fatalf("Failed to create http request")
101 | }
102 |
103 | req.SetBasicAuth("username", "password")
104 |
105 | resp, err := http.DefaultClient.Do(req)
106 | if err != nil {
107 | t.Fatalf("Failed to create user cause: %s", err)
108 | }
109 |
110 | respBody, err := ioutil.ReadAll(resp.Body)
111 | if err != nil {
112 | t.Errorf("Failed to read response body")
113 | }
114 |
115 | if resp.StatusCode != http.StatusOK {
116 | t.Errorf("Invalid response status code. Expected: %d, Given: %d, Body: %s", http.StatusOK, resp.StatusCode, respBody)
117 | }
118 | }
119 |
// deleteUser deletes the user identified by email through the admin API and
// reports a test error unless the provider answers 204 No Content.
//
// The email is placed in a URL path segment, so it is escaped with
// url.PathEscape, the counterpart of the url.PathUnescape used by the admin
// handlers. url.QueryEscape would be wrong here: it encodes a space as '+',
// which PathUnescape leaves untouched.
func deleteUser(t *testing.T, email string) {
	t.Helper()
	req, err := http.NewRequest(
		http.MethodDelete,
		fmt.Sprintf("http://simple-jwt-provider/v1/admin/users/%s", url.PathEscape(email)),
		nil,
	)
	if err != nil {
		t.Fatalf("Failed to create http request: %s", err)
	}

	req.SetBasicAuth("username", "password")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Failed to delete user cause: %s", err)
	}
	// Always release the response body so the connection can be reused.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Errorf("Failed to read response body: %s", err)
	}

	if resp.StatusCode != http.StatusNoContent {
		t.Errorf("Invalid response status code. Expected: %d, Given: %d, Body: %s", http.StatusNoContent, resp.StatusCode, respBody)
	}
}
147 |
--------------------------------------------------------------------------------
/internal/web/admin.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "github.com/gorilla/mux"
7 | "github.com/leberKleber/simple-jwt-provider/internal"
8 | "github.com/sirupsen/logrus"
9 | "net/http"
10 | "net/url"
11 | )
12 |
// User is the JSON representation of a user exchanged by the admin endpoints
// of this package, used both for request bodies (create/update) and for
// response bodies (get/update).
//
// The field order determines the key order of the encoded JSON; keep it stable.
//
// NOTE(review): getUserHandler and updateUserHandler copy the provider's
// Password value into this struct and encode it in the response. Confirm that
// exposing this value (presumably a password hash) to admin clients is intended.
type User struct {
	EMail    string                 `json:"email"`
	Password string                 `json:"password"`
	Claims   map[string]interface{} `json:"claims"`
}
19 |
20 | func (s *Server) createUserHandler(w http.ResponseWriter, r *http.Request) {
21 | var user User
22 |
23 | err := json.NewDecoder(r.Body).Decode(&user)
24 | if err != nil {
25 | writeError(w, http.StatusBadRequest, "invalid JSON")
26 | return
27 | }
28 |
29 | if user.EMail == "" {
30 | writeError(w, http.StatusBadRequest, "email must be set")
31 | return
32 | }
33 |
34 | if user.Password == "" {
35 | writeError(w, http.StatusBadRequest, "password must be set")
36 | return
37 | }
38 |
39 | err = s.p.CreateUser(internal.User{
40 | EMail: user.EMail,
41 | Password: user.Password,
42 | Claims: user.Claims,
43 | })
44 | if err != nil {
45 | if errors.Is(err, internal.ErrUserAlreadyExists) {
46 | writeError(w, http.StatusConflict, "User with given email already exists")
47 | return
48 | }
49 |
50 | logrus.WithError(err).Error("Failed to create User")
51 | writeInternalServerError(w)
52 | return
53 | }
54 |
55 | w.WriteHeader(http.StatusCreated)
56 | }
57 |
58 | func (s *Server) getUserHandler(w http.ResponseWriter, r *http.Request) {
59 | email, err := url.PathUnescape(mux.Vars(r)["email"])
60 | if err != nil {
61 | writeError(w, http.StatusBadRequest, "could not unescape email")
62 | return
63 | }
64 |
65 | // when email has not been set 'notFoundHandler' handler will be used
66 |
67 | user, err := s.p.GetUser(email)
68 | if err != nil {
69 | if errors.Is(err, internal.ErrUserNotFound) {
70 | writeError(w, http.StatusNotFound, "User with given email doesn't exists")
71 | return
72 | }
73 |
74 | logrus.WithError(err).Error("Failed to get User")
75 | writeInternalServerError(w)
76 | return
77 | }
78 |
79 | err = json.NewEncoder(w).Encode(User{
80 | EMail: user.EMail,
81 | Password: user.Password,
82 | Claims: user.Claims,
83 | })
84 | if err != nil {
85 | logrus.WithError(err).Error("Failed to encode User")
86 | writeInternalServerError(w)
87 | return
88 | }
89 | }
90 |
91 | func (s *Server) updateUserHandler(w http.ResponseWriter, r *http.Request) {
92 | email, err := url.PathUnescape(mux.Vars(r)["email"])
93 | if err != nil {
94 | writeError(w, http.StatusBadRequest, "could not unescape email")
95 | return
96 | }
97 |
98 | // when email has not been set 'notFoundHandler' handler will be used
99 |
100 | var user User
101 | err = json.NewDecoder(r.Body).Decode(&user)
102 | if err != nil {
103 | writeError(w, http.StatusBadRequest, "invalid JSON")
104 | return
105 | }
106 |
107 | if user.EMail != "" {
108 | writeError(w, http.StatusBadRequest, "email can not be changed")
109 | return
110 | }
111 |
112 | updatedUser, err := s.p.UpdateUser(email, internal.User{
113 | Password: user.Password,
114 | Claims: user.Claims,
115 | })
116 | if err != nil {
117 | if errors.Is(err, internal.ErrUserNotFound) {
118 | writeError(w, http.StatusNotFound, "User with given email doesn't exists")
119 | return
120 | }
121 |
122 | logrus.WithError(err).Error("Failed to update User")
123 | writeInternalServerError(w)
124 | return
125 | }
126 |
127 | err = json.NewEncoder(w).Encode(User{
128 | EMail: updatedUser.EMail,
129 | Password: updatedUser.Password,
130 | Claims: updatedUser.Claims,
131 | })
132 | if err != nil {
133 | logrus.WithError(err).Error("Failed to encode User")
134 | writeInternalServerError(w)
135 | return
136 | }
137 | }
138 |
139 | func (s *Server) deleteUserHandler(w http.ResponseWriter, r *http.Request) {
140 | email, err := url.PathUnescape(mux.Vars(r)["email"])
141 | if err != nil {
142 | writeError(w, http.StatusBadRequest, "could not unescape email")
143 | return
144 | }
145 |
146 | // when email has not been set 'notFoundHandler' handler will be used
147 |
148 | err = s.p.DeleteUser(email)
149 | if err != nil {
150 | if errors.Is(err, internal.ErrUserNotFound) {
151 | writeError(w, http.StatusNotFound, "User with given email doesnt already exists")
152 | return
153 | }
154 |
155 | logrus.WithError(err).Error("Failed to delete User")
156 | writeInternalServerError(w)
157 | return
158 | }
159 |
160 | w.WriteHeader(http.StatusNoContent)
161 | }
162 |
--------------------------------------------------------------------------------
/cmd/provider/login_component_test.go:
--------------------------------------------------------------------------------
1 | // +build component
2 |
3 | package main
4 |
5 | import (
6 | "bytes"
7 | "crypto/ecdsa"
8 | "crypto/x509"
9 | "encoding/json"
10 | "encoding/pem"
11 | "errors"
12 | "fmt"
13 | "github.com/golang-jwt/jwt"
14 | "net/http"
15 | "testing"
16 | )
17 |
18 | func TestLogin(t *testing.T) {
19 | email := "login_test@leberkleber.io"
20 | password := "s3cr3t"
21 |
22 | createUser(t, email, password)
23 | token, _, authorized := loginUser(t, email, password)
24 | if !authorized {
25 | t.Fatal("could not login user")
26 | }
27 |
28 | claims := validateJWT(t, token)
29 | expectedJWTAudience := ""
30 | if claims["aud"] != expectedJWTAudience {
31 | t.Errorf("unexpected aud-privateClaim value. Expected: %q. Given: %q", expectedJWTAudience, claims["aud"])
32 | }
33 |
34 | expectedJWTIssuer := ""
35 | if claims["iss"] != expectedJWTIssuer {
36 | t.Errorf("unexpected iss-privateClaim value. Expected: %q. Given: %q", expectedJWTIssuer, claims["iss"])
37 | }
38 |
39 | expectedJWTSubject := ""
40 | if claims["sub"] != expectedJWTSubject {
41 | t.Errorf("unexpected sub-privateClaim value. Expected: %q. Given: %q", expectedJWTSubject, claims["sub"])
42 | }
43 |
44 | expectedCustomClaim := "customClaimValue"
45 | if claims["myCustomClaim"] != expectedCustomClaim {
46 | t.Errorf("unexpected myCustomClaim value. Expected: %q. Given: %q", expectedCustomClaim, claims["myCustomClaim"])
47 | }
48 |
49 | if claims["id"] == "" {
50 | t.Error("jwt id has not been set")
51 | }
52 |
53 | if claims["exp"] == "" {
54 | t.Error("jwt exp has not been set")
55 | }
56 |
57 | if claims["iat"] == "" {
58 | t.Error("jwt iat has not been set")
59 | }
60 |
61 | if claims["nbf"] == "" {
62 | t.Error("jwt nbf has not been set")
63 | }
64 |
65 | if claims["email"] != email {
66 | t.Errorf("unexpected email-privateClaim value. Expected: %q. Given: %q", email, claims["email"])
67 | }
68 | }
69 |
70 | func validateJWT(t *testing.T, tokenString string) jwt.MapClaims {
71 | pubKey, err := decodeECDSApubKey(`-----BEGIN PUBLIC KEY-----
72 | MIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBQSa/dFpXRqz6aQQmx6sNpxl3mn8Z
73 | 0o+qgfgOxPAPxu+JppsCGqrX/6SeUI6kz3AFVABGBU8/9Ejzt7Ty9WJt1dEB+035
74 | 03+xLnmmyaj3bEhkerr229mDgPb8uDlPEl6f/Wv+Ma/eIIloCo8WJAe8YsviImbF
75 | hAV1NK8+62/iMCfNj30=
76 | -----END PUBLIC KEY-----
77 | `)
78 | if err != nil {
79 | t.Fatalf("Failed to parse public key: %s", err)
80 | }
81 |
82 | var claims jwt.MapClaims
83 | token, err := jwt.ParseWithClaims(tokenString, &claims, func(token *jwt.Token) (interface{}, error) {
84 | // Don't forget to validate the alg is what you expect:
85 | if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
86 | return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
87 | }
88 |
89 | return pubKey, nil
90 | })
91 | if err != nil {
92 | t.Fatalf("Failed to parse jwt: %s", err)
93 | }
94 |
95 | if !token.Valid {
96 | t.Fatalf("Given token ist not valid. Token: %s", tokenString)
97 | }
98 |
99 | return claims
100 | }
101 |
// decodeECDSApubKey decodes the first PEM block in pemEncodedPub, parses it as
// a PKIX public key and returns it as an ECDSA public key.
//
// It returns an error when:
//   - no PEM block can be decoded
//   - the block's bytes are not a valid PKIX public key
//   - the key uses another algorithm (e.g. RSA or Ed25519)
//
// The key type is checked with a comma-ok assertion, so a non-ECDSA key yields
// an error instead of a panic.
func decodeECDSApubKey(pemEncodedPub string) (*ecdsa.PublicKey, error) {
	blockPub, _ := pem.Decode([]byte(pemEncodedPub))
	if blockPub == nil {
		return nil, errors.New("no valid public key found")
	}

	genericPublicKey, err := x509.ParsePKIXPublicKey(blockPub.Bytes)
	if err != nil {
		return nil, err
	}

	publicKey, ok := genericPublicKey.(*ecdsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("unexpected public key type %T, expected *ecdsa.PublicKey", genericPublicKey)
	}

	return publicKey, nil
}
116 |
// loginUser posts email/password to the login endpoint and returns the access
// token, the refresh token, and whether the login was accepted.
//
// A 401 answer yields ("", "", false). A 200 answer yields the tokens and true.
// Any other status, a transport error, or an undecodable body aborts the test.
func loginUser(t *testing.T, email, password string) (string, string, bool) {
	t.Helper()

	payload := fmt.Sprintf(`{"email": %q, "password": %q}`, email, password)
	resp, err := http.Post("http://simple-jwt-provider/v1/auth/login", "application/json", bytes.NewReader([]byte(payload)))
	if err != nil {
		t.Fatalf("Failed to login with response: %v cause: %s", resp, err)
	}
	defer resp.Body.Close()

	var result struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ErrorMessage string `json:"message"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&result); decodeErr != nil {
		t.Fatalf("Failed to read response body: %s", decodeErr)
	}

	switch resp.StatusCode {
	case http.StatusOK:
		// success: fall through to returning the tokens
	case http.StatusUnauthorized:
		return "", "", false
	default:
		t.Fatalf("Invalid response status code. Expected: %d, Given: %d, Body: %s", http.StatusOK, resp.StatusCode, result.ErrorMessage)
	}

	return result.AccessToken, result.RefreshToken, true
}
149 |
--------------------------------------------------------------------------------
/internal/web/middleware/basicauth_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "io/ioutil"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | )
9 |
10 | func TestBasicAuth(t *testing.T) {
11 | tests := []struct {
12 | name string
13 | configuredUsername string
14 | configuredPassword string
15 | requestUsername string
16 | requestPassword string
17 | expectedNextHasBeenCalled bool
18 | expectedUnauthorizedResponseHeader bool
19 | }{
20 | {
21 | name: "Happycase plain password",
22 | configuredUsername: "username",
23 | configuredPassword: "password",
24 | requestUsername: "username",
25 | requestPassword: "password",
26 | expectedNextHasBeenCalled: true,
27 | expectedUnauthorizedResponseHeader: false,
28 | },
29 | {
30 | name: "Happycase bcrypted password",
31 | configuredUsername: "username",
32 | configuredPassword: "bcrypt:$2y$12$YLjvF/KRsQ6999oazNXBR.DvZ3K2t8boyPFgXt84PFt4yLN3zVKw2",
33 | requestUsername: "username",
34 | requestPassword: "myPassword",
35 | expectedNextHasBeenCalled: true,
36 | expectedUnauthorizedResponseHeader: false,
37 | },
38 | {
39 | name: "Missing auth header",
40 | configuredUsername: "username",
41 | configuredPassword: "password",
42 | expectedNextHasBeenCalled: false,
43 | expectedUnauthorizedResponseHeader: true,
44 | },
45 | {
46 | name: "Invalid username",
47 | configuredUsername: "username",
48 | configuredPassword: "password",
49 | requestUsername: "nope",
50 | requestPassword: "password",
51 | expectedNextHasBeenCalled: false,
52 | expectedUnauthorizedResponseHeader: true,
53 | },
54 | {
55 | name: "Invalid password",
56 | configuredUsername: "username",
57 | configuredPassword: "password",
58 | requestUsername: "username",
59 | requestPassword: "nope",
60 | expectedNextHasBeenCalled: false,
61 | expectedUnauthorizedResponseHeader: true,
62 | },
63 | {
64 | name: "Invalid password (bcrypted)",
65 | configuredUsername: "username",
66 | configuredPassword: "bcrypt:$2y$12$YLjvF/KRsQ6999oazNXBR.DvZ3K2t8boyPFgXt84PFt4yLN3zVKw2",
67 | requestUsername: "username",
68 | requestPassword: "nope",
69 | expectedNextHasBeenCalled: false,
70 | expectedUnauthorizedResponseHeader: true,
71 | },
72 | }
73 |
74 | for _, tt := range tests {
75 | t.Run(tt.name, func(t *testing.T) {
76 | nextHasBeenCalled := false
77 |
78 | w := &httptest.ResponseRecorder{}
79 | r, err := http.NewRequest("GET", "/", nil)
80 | if err != nil {
81 | t.Fatalf("Failed to create test request: %s", err)
82 | }
83 | if tt.requestUsername != "" || tt.requestPassword != "" {
84 | r.SetBasicAuth(tt.requestUsername, tt.requestPassword)
85 | }
86 | next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
87 | _, err := w.Write([]byte("done"))
88 | if err != nil {
89 | t.Fatalf("Could not write http request respons: %s", err)
90 | }
91 | nextHasBeenCalled = true
92 | })
93 |
94 | BasicAuth(tt.configuredUsername, tt.configuredPassword)(next).ServeHTTP(w, r)
95 |
96 | if tt.expectedNextHasBeenCalled != nextHasBeenCalled {
97 | t.Errorf("Call of next handler is not as expected. Given: %t, Exected: %t", nextHasBeenCalled, tt.expectedNextHasBeenCalled)
98 | }
99 |
100 | if tt.expectedUnauthorizedResponseHeader {
101 | response := w.Result()
102 | expectedStatusCode := http.StatusForbidden
103 | if response.StatusCode != expectedStatusCode {
104 | t.Errorf("Unexpected response code. Given: %d, Expected: %d", w.Code, expectedStatusCode)
105 | }
106 |
107 | body, err := ioutil.ReadAll(response.Body)
108 | if err != nil {
109 | t.Fatalf("Failed to read response body: %s", err)
110 | }
111 |
112 | expectedResponseBody := ""
113 | if string(body) != expectedResponseBody {
114 | t.Errorf("Unexpected response body value. \nGiven: %q \nExected: %q",
115 | string(body),
116 | expectedResponseBody)
117 | }
118 | }
119 | })
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/internal/web/auth.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "github.com/leberKleber/simple-jwt-provider/internal"
7 | "github.com/sirupsen/logrus"
8 | "net/http"
9 | )
10 |
11 | func (s *Server) loginHandler(w http.ResponseWriter, r *http.Request) {
12 | requestBody := struct {
13 | EMail string `json:"email"`
14 | Password string `json:"password"`
15 | }{}
16 |
17 | err := json.NewDecoder(r.Body).Decode(&requestBody)
18 | if err != nil {
19 | writeError(w, http.StatusBadRequest, "invalid JSON")
20 | return
21 | }
22 |
23 | if requestBody.EMail == "" {
24 | writeError(w, http.StatusBadRequest, "email must be set")
25 | return
26 | }
27 |
28 | if requestBody.Password == "" {
29 | writeError(w, http.StatusBadRequest, "password must be set")
30 | return
31 | }
32 |
33 | accessToken, refreshToken, err := s.p.Login(requestBody.EMail, requestBody.Password)
34 | if err != nil {
35 | if errors.Is(err, internal.ErrIncorrectPassword) || errors.Is(err, internal.ErrUserNotFound) {
36 | logrus.WithField("email", requestBody.EMail).Warn("Somebody tried to login with invalid credentials")
37 | writeError(w, http.StatusUnauthorized, "invalid credentials")
38 | return
39 | }
40 |
41 | logrus.WithError(err).Error("Failed to login User")
42 | writeInternalServerError(w)
43 | return
44 | }
45 |
46 | err = json.NewEncoder(w).Encode(struct {
47 | AccessToken string `json:"access_token"`
48 | RefreshToken string `json:"refresh_token"`
49 | }{
50 | AccessToken: accessToken,
51 | RefreshToken: refreshToken,
52 | })
53 | if err != nil {
54 | logrus.WithError(err).Error("Failed marshal request response")
55 | writeInternalServerError(w)
56 | return
57 | }
58 | }
59 |
60 | func (s *Server) refreshHandler(w http.ResponseWriter, r *http.Request) {
61 | requestBody := struct {
62 | RefreshToken string `json:"refresh_token"`
63 | }{}
64 |
65 | err := json.NewDecoder(r.Body).Decode(&requestBody)
66 | if err != nil {
67 | writeError(w, http.StatusBadRequest, "invalid JSON")
68 | return
69 | }
70 |
71 | if requestBody.RefreshToken == "" {
72 | writeError(w, http.StatusBadRequest, "refresh_token must be set")
73 | return
74 | }
75 |
76 | newAccessToken, newRefreshToken, err := s.p.Refresh(requestBody.RefreshToken)
77 | if err != nil {
78 | if errors.Is(err, internal.ErrInvalidToken) ||
79 | errors.Is(err, internal.ErrUserNotFound) ||
80 | errors.Is(err, internal.ErrTokenNotParsable) {
81 | logrus.Debug("failed to refresh user auth", err)
82 | writeError(w, http.StatusUnauthorized, "invalid refresh-token and/or email")
83 | return
84 | }
85 |
86 | logrus.WithError(err).Error("Failed to refresh token")
87 | writeInternalServerError(w)
88 | return
89 | }
90 |
91 | err = json.NewEncoder(w).Encode(struct {
92 | AccessToken string `json:"access_token"`
93 | RefreshToken string `json:"refresh_token"`
94 | }{
95 | AccessToken: newAccessToken,
96 | RefreshToken: newRefreshToken,
97 | })
98 | if err != nil {
99 | logrus.WithError(err).Error("Failed marshal request response")
100 | writeInternalServerError(w)
101 | return
102 | }
103 | }
104 |
105 | func (s *Server) passwordResetRequestHandler(w http.ResponseWriter, r *http.Request) {
106 | requestBody := struct {
107 | EMail string `json:"email"`
108 | }{}
109 |
110 | err := json.NewDecoder(r.Body).Decode(&requestBody)
111 | if err != nil {
112 | writeError(w, http.StatusBadRequest, "invalid JSON")
113 | return
114 | }
115 |
116 | if requestBody.EMail == "" {
117 | writeError(w, http.StatusBadRequest, "email must be set")
118 | return
119 | }
120 |
121 | err = s.p.CreatePasswordResetRequest(requestBody.EMail)
122 | if err != nil {
123 | if errors.Is(err, internal.ErrUserNotFound) {
124 | logrus.WithField("email", requestBody.EMail).Warn("Somebody tried to create a reset-password-request for non existing User")
125 | w.WriteHeader(http.StatusCreated)
126 | return
127 | }
128 |
129 | logrus.WithError(err).Error("Failed to create password-reset-request")
130 | writeInternalServerError(w)
131 | return
132 | }
133 |
134 | w.WriteHeader(http.StatusCreated)
135 | }
136 |
137 | func (s *Server) passwordResetHandler(w http.ResponseWriter, r *http.Request) {
138 | requestBody := struct {
139 | EMail string `json:"email"`
140 | ResetToken string `json:"reset_token"`
141 | Password string `json:"password"`
142 | }{}
143 |
144 | err := json.NewDecoder(r.Body).Decode(&requestBody)
145 | if err != nil {
146 | writeError(w, http.StatusBadRequest, "invalid JSON")
147 | return
148 | }
149 |
150 | if requestBody.EMail == "" {
151 | writeError(w, http.StatusBadRequest, "email must be set")
152 | return
153 | }
154 |
155 | if requestBody.ResetToken == "" {
156 | writeError(w, http.StatusBadRequest, "reset-token must be set")
157 | return
158 | }
159 |
160 | if requestBody.Password == "" {
161 | writeError(w, http.StatusBadRequest, "password must be set")
162 | return
163 | }
164 |
165 | err = s.p.ResetPassword(requestBody.EMail, requestBody.ResetToken, requestBody.Password)
166 | if err != nil {
167 | if errors.Is(err, internal.ErrNoValidTokenFound) {
168 | writeError(w, http.StatusBadRequest, "reset-token is invalid or token email combination is not correct")
169 | return
170 | }
171 | logrus.WithError(err).Error("Failed to create password-reset-request")
172 | writeInternalServerError(w)
173 | return
174 | }
175 |
176 | w.WriteHeader(http.StatusNoContent)
177 | }
178 |
--------------------------------------------------------------------------------
/internal/web/server_test.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "errors"
7 | "gotest.tools/assert"
8 | "io/ioutil"
9 | "net/http"
10 | "net/http/httptest"
11 | "testing"
12 | )
13 |
14 | func TestBasicAuth(t *testing.T) {
15 | expectedResponseCode := http.StatusForbidden
16 | expectedResponseBody := `{"message":"forbidden"}`
17 |
18 | toTest := NewServer(nil, true, "un", "pw")
19 | testServer := httptest.NewServer(toTest.h)
20 |
21 | req, err := http.NewRequest(http.MethodPost, testServer.URL+"/v1/admin/users", nil)
22 | if err != nil {
23 | t.Fatalf("Failed to build http request: %s", err)
24 | }
25 |
26 | req.SetBasicAuth("invalid", "invalid")
27 |
28 | resp, err := http.DefaultClient.Do(req)
29 | if err != nil {
30 | t.Fatalf("Failed to call server cause: %s", err)
31 | }
32 | defer resp.Body.Close()
33 |
34 | if resp.StatusCode != expectedResponseCode {
35 | t.Errorf("Request respond with unexpected status code. Expected: %d, Given: %d", expectedResponseCode, resp.StatusCode)
36 | }
37 |
38 | respBody, err := ioutil.ReadAll(resp.Body)
39 | if err != nil {
40 | t.Fatalf("Failed to read response body: %s", err)
41 | }
42 |
43 | var compactedRespBodyAsBytes []byte
44 | if resp.ContentLength > 0 {
45 | compactedRespBody := &bytes.Buffer{}
46 | err = json.Compact(compactedRespBody, respBody)
47 | if err != nil {
48 | t.Fatalf("Failed to compact json: %s", err)
49 | }
50 |
51 | compactedRespBodyAsBytes = compactedRespBody.Bytes()
52 | }
53 |
54 | if !bytes.Equal(compactedRespBodyAsBytes, []byte(expectedResponseBody)) {
55 | t.Errorf("Request response body is not as expected. Expected: %q, Given: %q", expectedResponseBody, string(compactedRespBodyAsBytes))
56 | }
57 | }
58 |
59 | func TestNotFoundHandler(t *testing.T) {
60 | expectedResponseCode := http.StatusNotFound
61 | expectedResponseBody := `{"message":"endpoint not found"}`
62 |
63 | toTest := NewServer(nil, false, "", "")
64 | testServer := httptest.NewServer(toTest.h)
65 |
66 | req, err := http.NewRequest(http.MethodGet, testServer.URL+"/unexpected/endpoint", nil)
67 | if err != nil {
68 | t.Fatalf("Failed to build http request: %s", err)
69 | }
70 |
71 | req.SetBasicAuth("invalid", "invalid")
72 |
73 | resp, err := http.DefaultClient.Do(req)
74 | if err != nil {
75 | t.Fatalf("Failed to call server cause: %s", err)
76 | }
77 | defer resp.Body.Close()
78 |
79 | if resp.StatusCode != expectedResponseCode {
80 | t.Errorf("Request respond with unexpected status code. Expected: %d, Given: %d", expectedResponseCode, resp.StatusCode)
81 | }
82 |
83 | respBody, err := ioutil.ReadAll(resp.Body)
84 | if err != nil {
85 | t.Fatalf("Failed to read response body: %s", err)
86 | }
87 |
88 | var compactedRespBodyAsBytes []byte
89 | if resp.ContentLength > 0 {
90 | compactedRespBody := &bytes.Buffer{}
91 | err = json.Compact(compactedRespBody, respBody)
92 | if err != nil {
93 | t.Fatalf("Failed to compact json: %s", err)
94 | }
95 |
96 | compactedRespBodyAsBytes = compactedRespBody.Bytes()
97 | }
98 |
99 | if !bytes.Equal(compactedRespBodyAsBytes, []byte(expectedResponseBody)) {
100 | t.Errorf("Request response body is not as expected. Expected: %q, Given: %q", expectedResponseBody, string(compactedRespBodyAsBytes))
101 | }
102 | }
103 |
104 | func TestMethodNotAllowedHandler(t *testing.T) {
105 | expectedResponseCode := http.StatusMethodNotAllowed
106 | expectedResponseBody := `{"message":"method not allowed"}`
107 |
108 | toTest := NewServer(nil, false, "", "")
109 | testServer := httptest.NewServer(toTest.h)
110 |
111 | req, err := http.NewRequest(http.MethodGet, testServer.URL+"/v1/auth/password-reset-request", nil)
112 | if err != nil {
113 | t.Fatalf("Failed to build http request: %s", err)
114 | }
115 |
116 | req.SetBasicAuth("invalid", "invalid")
117 |
118 | resp, err := http.DefaultClient.Do(req)
119 | if err != nil {
120 | t.Fatalf("Failed to call server cause: %s", err)
121 | }
122 | defer resp.Body.Close()
123 |
124 | if resp.StatusCode != expectedResponseCode {
125 | t.Errorf("Request respond with unexpected status code. Expected: %d, Given: %d", expectedResponseCode, resp.StatusCode)
126 | }
127 |
128 | respBody, err := ioutil.ReadAll(resp.Body)
129 | if err != nil {
130 | t.Fatalf("Failed to read response body: %s", err)
131 | }
132 |
133 | var compactedRespBodyAsBytes []byte
134 | if resp.ContentLength > 0 {
135 | compactedRespBody := &bytes.Buffer{}
136 | err = json.Compact(compactedRespBody, respBody)
137 | if err != nil {
138 | t.Fatalf("Failed to compact json: %s", err)
139 | }
140 |
141 | compactedRespBodyAsBytes = compactedRespBody.Bytes()
142 | }
143 |
144 | if !bytes.Equal(compactedRespBodyAsBytes, []byte(expectedResponseBody)) {
145 | t.Errorf("Request response body is not as expected. Expected: %q, Given: %q", expectedResponseBody, string(compactedRespBodyAsBytes))
146 | }
147 | }
148 |
149 | func TestServer_ListenAndServe(t *testing.T) {
150 | oldHttpListenAndServe := httpListenAndServe
151 | defer func() {
152 | httpListenAndServe = oldHttpListenAndServe
153 | }()
154 |
155 | addr := "myAddr"
156 | err := errors.New("myErr")
157 | handler := httpHandlerMock{ID: "myHTTPHandlerMock"}
158 |
159 | var givenAddr string
160 | var givenHandler http.Handler
161 |
162 | httpListenAndServe = func(addr string, handler http.Handler) error {
163 | givenAddr = addr
164 | givenHandler = handler
165 | return err
166 | }
167 |
168 | s := Server{
169 | h: httpHandlerMock{ID: "myHTTPHandlerMock"},
170 | }
171 |
172 | returnErr := s.ListenAndServe(addr)
173 |
174 | assert.Equal(t, givenAddr, addr, "Unexpected return addr")
175 | assert.Equal(t, returnErr, err, "Unexpected return error")
176 | assert.Equal(t, givenHandler, handler, "ListenAndServe called with unexpected handler")
177 |
178 | }
179 |
// httpHandlerMock is a no-op http.Handler. Its ID field lets tests tell
// instances apart with a plain == comparison.
type httpHandlerMock struct {
	ID string
}

// ServeHTTP satisfies http.Handler and deliberately does nothing.
func (m httpHandlerMock) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {}
185 |
--------------------------------------------------------------------------------
/internal/jwt_generator_moq_test.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package internal
5 |
6 | import (
7 | "github.com/golang-jwt/jwt"
8 | "sync"
9 | )
10 |
// Ensure, that JWTProviderMock does implement JWTProvider.
// If this is not the case, regenerate this file with moq.
// NOTE: this whole file is generated by moq. Regenerate it rather than editing
// by hand; hand edits (including these comments) are lost on regeneration.
var _ JWTProvider = &JWTProviderMock{}

// JWTProviderMock is a mock implementation of JWTProvider.
//
// Each method records its arguments under a per-method RWMutex and then
// delegates to the matching ...Func field. If that field is nil, the method
// panics. The recorded calls can be read back through the ...Calls accessors.
//
//	func TestSomethingThatUsesJWTProvider(t *testing.T) {
//
//		// make and configure a mocked JWTProvider
//		mockedJWTProvider := &JWTProviderMock{
//			GenerateAccessTokenFunc: func(email string, userClaims map[string]interface{}) (string, error) {
//				panic("mock out the GenerateAccessToken method")
//			},
//			GenerateRefreshTokenFunc: func(email string) (string, string, error) {
//				panic("mock out the GenerateRefreshToken method")
//			},
//			IsTokenValidFunc: func(token string) (bool, jwt.MapClaims, error) {
//				panic("mock out the IsTokenValid method")
//			},
//		}
//
//		// use mockedJWTProvider in code that requires JWTProvider
//		// and then make assertions.
//
//	}
type JWTProviderMock struct {
	// GenerateAccessTokenFunc mocks the GenerateAccessToken method.
	GenerateAccessTokenFunc func(email string, userClaims map[string]interface{}) (string, error)

	// GenerateRefreshTokenFunc mocks the GenerateRefreshToken method.
	GenerateRefreshTokenFunc func(email string) (string, string, error)

	// IsTokenValidFunc mocks the IsTokenValid method.
	IsTokenValidFunc func(token string) (bool, jwt.MapClaims, error)

	// calls tracks calls to the methods.
	calls struct {
		// GenerateAccessToken holds details about calls to the GenerateAccessToken method.
		GenerateAccessToken []struct {
			// Email is the email argument value.
			Email string
			// UserClaims is the userClaims argument value.
			UserClaims map[string]interface{}
		}
		// GenerateRefreshToken holds details about calls to the GenerateRefreshToken method.
		GenerateRefreshToken []struct {
			// Email is the email argument value.
			Email string
		}
		// IsTokenValid holds details about calls to the IsTokenValid method.
		IsTokenValid []struct {
			// Token is the token argument value.
			Token string
		}
	}
	// One lock per method guards the corresponding slice in calls.
	lockGenerateAccessToken  sync.RWMutex
	lockGenerateRefreshToken sync.RWMutex
	lockIsTokenValid         sync.RWMutex
}

// GenerateAccessToken calls GenerateAccessTokenFunc.
// It panics if GenerateAccessTokenFunc is nil; otherwise it records the call
// before delegating.
func (mock *JWTProviderMock) GenerateAccessToken(email string, userClaims map[string]interface{}) (string, error) {
	if mock.GenerateAccessTokenFunc == nil {
		panic("JWTProviderMock.GenerateAccessTokenFunc: method is nil but JWTProvider.GenerateAccessToken was just called")
	}
	callInfo := struct {
		Email      string
		UserClaims map[string]interface{}
	}{
		Email:      email,
		UserClaims: userClaims,
	}
	mock.lockGenerateAccessToken.Lock()
	mock.calls.GenerateAccessToken = append(mock.calls.GenerateAccessToken, callInfo)
	mock.lockGenerateAccessToken.Unlock()
	return mock.GenerateAccessTokenFunc(email, userClaims)
}

// GenerateAccessTokenCalls gets all the calls that were made to GenerateAccessToken.
// Check the length with:
//
//	len(mockedJWTProvider.GenerateAccessTokenCalls())
func (mock *JWTProviderMock) GenerateAccessTokenCalls() []struct {
	Email      string
	UserClaims map[string]interface{}
} {
	var calls []struct {
		Email      string
		UserClaims map[string]interface{}
	}
	mock.lockGenerateAccessToken.RLock()
	calls = mock.calls.GenerateAccessToken
	mock.lockGenerateAccessToken.RUnlock()
	return calls
}

// GenerateRefreshToken calls GenerateRefreshTokenFunc.
// It panics if GenerateRefreshTokenFunc is nil; otherwise it records the call
// before delegating.
func (mock *JWTProviderMock) GenerateRefreshToken(email string) (string, string, error) {
	if mock.GenerateRefreshTokenFunc == nil {
		panic("JWTProviderMock.GenerateRefreshTokenFunc: method is nil but JWTProvider.GenerateRefreshToken was just called")
	}
	callInfo := struct {
		Email string
	}{
		Email: email,
	}
	mock.lockGenerateRefreshToken.Lock()
	mock.calls.GenerateRefreshToken = append(mock.calls.GenerateRefreshToken, callInfo)
	mock.lockGenerateRefreshToken.Unlock()
	return mock.GenerateRefreshTokenFunc(email)
}

// GenerateRefreshTokenCalls gets all the calls that were made to GenerateRefreshToken.
// Check the length with:
//
//	len(mockedJWTProvider.GenerateRefreshTokenCalls())
func (mock *JWTProviderMock) GenerateRefreshTokenCalls() []struct {
	Email string
} {
	var calls []struct {
		Email string
	}
	mock.lockGenerateRefreshToken.RLock()
	calls = mock.calls.GenerateRefreshToken
	mock.lockGenerateRefreshToken.RUnlock()
	return calls
}

// IsTokenValid calls IsTokenValidFunc.
// It panics if IsTokenValidFunc is nil; otherwise it records the call before
// delegating.
func (mock *JWTProviderMock) IsTokenValid(token string) (bool, jwt.MapClaims, error) {
	if mock.IsTokenValidFunc == nil {
		panic("JWTProviderMock.IsTokenValidFunc: method is nil but JWTProvider.IsTokenValid was just called")
	}
	callInfo := struct {
		Token string
	}{
		Token: token,
	}
	mock.lockIsTokenValid.Lock()
	mock.calls.IsTokenValid = append(mock.calls.IsTokenValid, callInfo)
	mock.lockIsTokenValid.Unlock()
	return mock.IsTokenValidFunc(token)
}

// IsTokenValidCalls gets all the calls that were made to IsTokenValid.
// Check the length with:
//
//	len(mockedJWTProvider.IsTokenValidCalls())
func (mock *JWTProviderMock) IsTokenValidCalls() []struct {
	Token string
} {
	var calls []struct {
		Token string
	}
	mock.lockIsTokenValid.RLock()
	calls = mock.calls.IsTokenValid
	mock.lockIsTokenValid.RUnlock()
	return calls
}
167 |
--------------------------------------------------------------------------------
/internal/jwt/generate_test.go:
--------------------------------------------------------------------------------
1 | package jwt
2 |
3 | import (
4 | "crypto/ecdsa"
5 | "crypto/x509"
6 | "encoding/pem"
7 | "errors"
8 | "fmt"
9 | "github.com/golang-jwt/jwt"
10 | "github.com/google/uuid"
11 | "testing"
12 | "time"
13 | )
14 |
// TestGenerator_GenerateAccessToken generates an access token, verifies its
// signature via validateJWT and checks the registered claims, the email claim
// and a custom user claim.
func TestGenerator_GenerateAccessToken(t *testing.T) {
	g, err := NewProvider(jwtPrvKey, 4*time.Hour, "audience", "issuer", "subject")
	if err != nil {
		t.Fatalf("failed to create new generator: %s", err)
	}

	generatedJWT, err := g.GenerateAccessToken("myMailAddress", map[string]interface{}{"myCustomClaim": "mialc"})
	if err != nil {
		t.Fatalf("failed to generate jwt: %s", err)
	}
	claims := validateJWT(t, generatedJWT)

	expectedJWTAudience := "audience"
	if claims["aud"] != expectedJWTAudience {
		t.Errorf("unexpected aud-privateClaim value. Expected: %q. Given: %q", expectedJWTAudience, claims["aud"])
	}

	expectedJWTIssuer := "issuer"
	if claims["iss"] != expectedJWTIssuer {
		t.Errorf("unexpected iss-privateClaim value. Expected: %q. Given: %q", expectedJWTIssuer, claims["iss"])
	}

	expectedJWTSubject := "subject"
	if claims["sub"] != expectedJWTSubject {
		t.Errorf("unexpected sub-privateClaim value. Expected: %q. Given: %q", expectedJWTSubject, claims["sub"])
	}

	// claims is a jwt.MapClaims (map[string]interface{}). Comparing a value
	// with "" can never detect a missing key (nil != "") or a numeric claim
	// such as exp, so check that each key is present and non-empty instead.
	// NOTE(review): the previous check used an "id" key; "jit" is the key the
	// internal Provider.Refresh reads the token id from - confirm against
	// generate.go.
	for _, name := range []string{"jit", "exp", "iat", "nbf"} {
		if v, ok := claims[name]; !ok || v == nil || v == "" {
			t.Errorf("jwt %s has not been set", name)
		}
	}

	expectedCustomClaim := "mialc"
	if claims["myCustomClaim"] != expectedCustomClaim {
		t.Errorf("unexpected myCustomClaim-privateClaim value. Expected: %q. Given: %q", expectedCustomClaim, claims["myCustomClaim"])
	}

	expectedJWTEMail := "myMailAddress"
	if claims["email"] != expectedJWTEMail {
		t.Errorf("unexpected email-privateClaim value. Expected: %q. Given: %q", expectedJWTEMail, claims["email"])
	}
}
67 |
// TestGenerator_GenerateAccessToken_FailedToGenerateUUID checks that a failing
// jwt-id generator (uuidNewRandom) is reported as a wrapped error.
func TestGenerator_GenerateAccessToken_FailedToGenerateUUID(t *testing.T) {
	oldUUIDNewRandom := uuidNewRandom
	defer func() { uuidNewRandom = oldUUIDNewRandom }()

	uuidNewRandom = func() (uuid.UUID, error) {
		return uuid.UUID{}, errors.New("nope")
	}

	_, err := Provider{}.GenerateAccessToken("my.email.de", nil)

	expectedError := errors.New("failed to generate jwt-id: nope")
	if fmt.Sprint(err) != fmt.Sprint(expectedError) {
		t.Fatalf("unexpected error. Expected: %q. Given: %q", expectedError, err)
	}
}
83 |
// TestGenerator_GenerateAccessToken_FailedToSignToken checks that a user claim
// that cannot be JSON-encoded (a channel) surfaces as a signing error.
func TestGenerator_GenerateAccessToken_FailedToSignToken(t *testing.T) {
	p, err := NewProvider(jwtPrvKey, 4*time.Hour, "audience", "issuer", "subject")
	if err != nil {
		t.Fatalf("failed to create new generator: %s", err)
	}

	_, err = p.GenerateAccessToken("my.email.de", map[string]interface{}{
		"unmarshableClaim": make(chan string),
	})

	expectedError := errors.New("failed to sign access-token: json: unsupported type: chan string")
	if fmt.Sprint(err) != fmt.Sprint(expectedError) {
		t.Fatalf("unexpected error. Expected: %q. Given: %q", expectedError, err)
	}
}
99 |
// TestGenerator_GenerateRefreshToken generates a refresh token, checks that a
// jwt id is returned alongside it, verifies the signature via validateJWT and
// checks the registered claims and the email claim.
func TestGenerator_GenerateRefreshToken(t *testing.T) {
	g, err := NewProvider(jwtPrvKey, 4*time.Hour, "audience", "issuer", "subject")
	if err != nil {
		t.Fatalf("failed to create new generator: %s", err)
	}

	generatedJWT, jwtID, err := g.GenerateRefreshToken("myMailAddress")
	if err != nil {
		t.Fatalf("failed to generate jwt: %s", err)
	}
	if jwtID == "" {
		t.Error("generate returns no jwtID")
	}

	claims := validateJWT(t, generatedJWT)

	expectedJWTAudience := "audience"
	if claims["aud"] != expectedJWTAudience {
		t.Errorf("unexpected aud-privateClaim value. Expected: %q. Given: %q", expectedJWTAudience, claims["aud"])
	}

	expectedJWTIssuer := "issuer"
	if claims["iss"] != expectedJWTIssuer {
		t.Errorf("unexpected iss-privateClaim value. Expected: %q. Given: %q", expectedJWTIssuer, claims["iss"])
	}

	expectedJWTSubject := "subject"
	if claims["sub"] != expectedJWTSubject {
		t.Errorf("unexpected sub-privateClaim value. Expected: %q. Given: %q", expectedJWTSubject, claims["sub"])
	}

	// claims is a jwt.MapClaims (map[string]interface{}). Comparing a value
	// with "" can never detect a missing key (nil != "") or a numeric claim
	// such as exp, so check that each key is present and non-empty instead.
	// NOTE(review): the previous check used an "id" key; "jit" is the key the
	// internal Provider.Refresh reads the refresh-token id from - confirm
	// against generate.go.
	for _, name := range []string{"jit", "exp", "iat", "nbf"} {
		if v, ok := claims[name]; !ok || v == nil || v == "" {
			t.Errorf("jwt %s has not been set", name)
		}
	}

	expectedJWTEMail := "myMailAddress"
	if claims["email"] != expectedJWTEMail {
		t.Errorf("unexpected email-privateClaim value. Expected: %q. Given: %q", expectedJWTEMail, claims["email"])
	}
}
151 |
// TestGenerator_GenerateRefreshToken_FailedToGenerateUUID checks that a failing
// jwt-id generator (uuidNewRandom) is reported as a wrapped error.
func TestGenerator_GenerateRefreshToken_FailedToGenerateUUID(t *testing.T) {
	oldUUIDNewRandom := uuidNewRandom
	defer func() { uuidNewRandom = oldUUIDNewRandom }()

	uuidNewRandom = func() (uuid.UUID, error) {
		return uuid.UUID{}, errors.New("nope")
	}

	_, _, err := Provider{}.GenerateRefreshToken("my.email.de")
	expectedError := errors.New("failed to generate jwt-id: nope")
	if fmt.Sprint(err) != fmt.Sprint(expectedError) {
		t.Fatalf("unexpected error. Expected: %q. Given: %q", expectedError, err)
	}
}
166 |
// validateJWT parses tokenString, verifies its signature with the ECDSA public
// key decoded from jwtPubKey (a PEM-encoded key defined elsewhere in this
// package) and returns the decoded claims. It aborts the calling test when the
// key or the token cannot be parsed or the token is not valid.
func validateJWT(t *testing.T, tokenString string) jwt.MapClaims {
	t.Helper() // report failures at the caller's line, not inside this helper

	claims := jwt.MapClaims{}
	pubKey, err := decodeECDSAPubKey(jwtPubKey)
	if err != nil {
		t.Fatalf("Failed to parse public key: %s", err)
	}

	token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
		// Only accept ECDSA-signed tokens so a swapped "alg" header cannot
		// make verification succeed with a different algorithm.
		if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return pubKey, nil
	})
	if err != nil {
		t.Fatalf("Failed to parse jwt: %s", err)
	}

	if !token.Valid {
		t.Fatalf("Given token is not valid. Token: %s", tokenString)
	}

	return claims
}
191 |
// decodeECDSAPubKey decodes the first PEM block of pemEncodedPub as a PKIX
// public key and returns it as an ECDSA public key.
//
// It returns an error when no PEM block is found, when the block is not a
// parsable PKIX public key, or when the key is of another type (e.g. RSA or
// Ed25519) - instead of panicking on an unchecked type assertion.
func decodeECDSAPubKey(pemEncodedPub string) (*ecdsa.PublicKey, error) {
	blockPub, _ := pem.Decode([]byte(pemEncodedPub))
	if blockPub == nil {
		return nil, errors.New("no valid public key found")
	}

	genericPublicKey, err := x509.ParsePKIXPublicKey(blockPub.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse PKIX public key: %w", err)
	}

	publicKey, ok := genericPublicKey.(*ecdsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("public key is of type %T, expected *ecdsa.PublicKey", genericPublicKey)
	}

	return publicKey, nil
}
206 |
--------------------------------------------------------------------------------
/internal/auth.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "crypto/rand"
5 | "errors"
6 | "fmt"
7 | "github.com/leberKleber/simple-jwt-provider/internal/storage"
8 | "golang.org/x/crypto/bcrypt"
9 | )
10 |
11 | // ErrIncorrectPassword returned when user authentication failed cause incorrect password
12 | var ErrIncorrectPassword = errors.New("password incorrect")
13 |
14 | // ErrUserNotFound returned when requested user not found
15 | var ErrUserNotFound = errors.New("user not found")
16 |
17 | // ErrNoValidTokenFound returned when requested user has no valid token
18 | var ErrNoValidTokenFound = errors.New("no valid token found")
19 |
20 | // ErrInvalidToken returned when the give token is not valid
21 | var ErrInvalidToken = errors.New("given token is invalid")
22 |
23 | // ErrTokenNotParsable returned when the give token is not parsable
24 | var ErrTokenNotParsable = errors.New("given token is not parsable")
25 |
// Login authenticates the user identified by email with password. On success
// it returns a new access-token (carrying the user's stored claims) and a new
// refresh-token whose id is persisted in storage as a refresh token.
//
// It returns ErrUserNotFound when no user is stored for email and
// ErrIncorrectPassword when password does not match the stored hash. Storage
// and token-generation failures are returned wrapped.
func (p Provider) Login(email, password string) (accessToken, refreshToken string, err error) {
	user, lookupErr := p.Storage.User(email)
	switch {
	case errors.Is(lookupErr, storage.ErrUserNotFound):
		return "", "", ErrUserNotFound
	case lookupErr != nil:
		return "", "", fmt.Errorf("failed to find user with email %q: %w", email, lookupErr)
	}

	// Any comparison failure is reported as an incorrect password.
	if bcrypt.CompareHashAndPassword(user.Password, []byte(password)) != nil {
		return "", "", ErrIncorrectPassword
	}

	access, genErr := p.JWTProvider.GenerateAccessToken(email, user.Claims)
	if genErr != nil {
		return "", "", fmt.Errorf("failed to generate access-token: %w", genErr)
	}

	refresh, refreshID, genErr := p.JWTProvider.GenerateRefreshToken(email)
	if genErr != nil {
		return "", "", fmt.Errorf("failed to generate refresh-token: %w", genErr)
	}

	// Only the refresh-token's id is stored; Refresh looks it up again later.
	persistErr := p.Storage.CreateToken(&storage.Token{
		EMail: email,
		Token: refreshID,
		Type:  storage.TokenTypeRefresh,
	})
	if persistErr != nil {
		return "", "", fmt.Errorf("failed to persist refresh-token: %w", persistErr)
	}

	return access, refresh, nil
}
64 |
// Refresh exchanges a valid refresh-token for a new access-token and a new
// refresh-token. The presented refresh-token is consumed: its stored entry is
// deleted before the user is looked up and the new tokens are issued.
//
// Returned errors:
//   - ErrTokenNotParsable (wrapped) when the token cannot be parsed
//   - ErrInvalidToken when the token is not valid
//   - a plain error when the "email" or "jit" claim is missing or not a string
//   - ErrNoValidTokenFound when no stored refresh token matches the token's
//     email and "jit" claim
//   - ErrUserNotFound when the referred user could not be found
//   - a wrapped error for storage and token-generation failures
func (p Provider) Refresh(refreshToken string) (newAccessToken, newRefreshToken string, err error) {
	isValid, claims, err := p.JWTProvider.IsTokenValid(refreshToken)
	if err != nil {
		return "", "", fmt.Errorf("%w: %s", ErrTokenNotParsable, err)
	}

	if !isValid {
		return "", "", ErrInvalidToken
	}

	email, ok := claims["email"].(string)
	if !ok {
		return "", "", errors.New("email claim is not parsable as string")
	}

	// The token id identifies the persisted refresh-token entry.
	tokenID, ok := claims["jit"].(string)
	if !ok {
		return "", "", errors.New("jit claim is not parsable as string")
	}

	// TODO do Storage.TokensByEMailAndToken and Storage.DeleteToken in transaction
	tokens, err := p.Storage.TokensByEMailAndToken(email, tokenID)
	if err != nil {
		return "", "", fmt.Errorf("failed to find refresh-tokens: %w", err)
	}

	// Taking the address of the range variable is safe here because the loop
	// stops immediately after the assignment.
	var t *storage.Token
	for _, token := range tokens {
		if token.Type == storage.TokenTypeRefresh {
			// TODO check lifetime
			t = &token
			break
		}
	}

	if t == nil {
		return "", "", ErrNoValidTokenFound
	}

	err = p.Storage.DeleteToken(t.ID)
	if err != nil {
		return "", "", fmt.Errorf("failed to delete refresh-token: %w", err)
	}

	u, err := p.Storage.User(email)
	if err != nil {
		if errors.Is(err, storage.ErrUserNotFound) {
			return "", "", ErrUserNotFound
		}
		return "", "", fmt.Errorf("failed to find user with email %q: %w", email, err)
	}

	newAccessToken, err = p.JWTProvider.GenerateAccessToken(email, u.Claims)
	if err != nil {
		return "", "", fmt.Errorf("failed to generate access-token: %w", err)
	}

	newRefreshToken, jwtID, err := p.JWTProvider.GenerateRefreshToken(email)
	if err != nil {
		return "", "", fmt.Errorf("failed to generate refresh-token: %w", err)
	}

	err = p.Storage.CreateToken(&storage.Token{
		EMail: email,
		Token: jwtID,
		Type:  storage.TokenTypeRefresh,
	})
	if err != nil {
		return "", "", fmt.Errorf("failed to persist refresh-token: %w", err)
	}

	return newAccessToken, newRefreshToken, nil
}
142 |
// CreatePasswordResetRequest generates a random reset token for the user with
// the given email, stores it as a reset token and sends it to that address
// via the mailer, together with the user's stored claims.
//
// It returns ErrUserNotFound when the user does not exist. Storage,
// token-generation and mail failures are returned wrapped.
func (p Provider) CreatePasswordResetRequest(email string) error {
	user, err := p.Storage.User(email)
	switch {
	case errors.Is(err, storage.ErrUserNotFound):
		return ErrUserNotFound
	case err != nil:
		return fmt.Errorf("failed to find user with email %q: %w", email, err)
	}

	resetToken, err := generateHEXToken()
	if err != nil {
		return fmt.Errorf("failed to generate password reset token: %w", err)
	}

	if err := p.Storage.CreateToken(&storage.Token{
		EMail: email,
		Token: resetToken,
		Type:  storage.TokenTypeReset,
	}); err != nil {
		return fmt.Errorf("failed to create password reset token for email %q: %w", email, err)
	}

	if err := p.Mailer.SendPasswordResetRequestEMail(email, resetToken, user.Claims); err != nil {
		return fmt.Errorf("failed to send password reset email: %w", err)
	}

	return nil
}
175 |
// ResetPassword sets newPassword (hashed with bcryptPassword) as the password
// of the user identified by email, provided resetToken matches a stored reset
// token for that email. The matched reset token is deleted after the user has
// been updated.
//
// It returns ErrNoValidTokenFound when no stored reset token matches. Failures
// while looking up tokens or the user, hashing the password, updating the user
// or deleting the token are returned wrapped.
//
// The method uses a value receiver like Login, Refresh and
// CreatePasswordResetRequest. The method set of *Provider still contains it,
// so existing pointer-based callers keep working.
func (p Provider) ResetPassword(email, resetToken, newPassword string) error {
	// TODO do Storage.TokensByEMailAndToken and Storage.DeleteToken in transaction
	tokens, err := p.Storage.TokensByEMailAndToken(email, resetToken)
	if err != nil {
		return fmt.Errorf("failed to find reset-tokens: %w", err)
	}

	// Taking the address of the range variable is safe here because the loop
	// stops immediately after the assignment.
	var t *storage.Token
	for _, token := range tokens {
		if token.Type == storage.TokenTypeReset {
			// TODO check lifetime
			t = &token
			break
		}
	}

	if t == nil {
		return ErrNoValidTokenFound
	}

	u, err := p.Storage.User(email)
	if err != nil {
		return fmt.Errorf("failed to find user with email %q: %w", email, err)
	}

	securedPassword, err := bcryptPassword(newPassword)
	if err != nil {
		return fmt.Errorf("failed to bcrypt password: %w", err)
	}
	u.Password = securedPassword

	err = p.Storage.UpdateUser(u)
	if err != nil {
		return fmt.Errorf("failed to update user: %w", err)
	}

	err = p.Storage.DeleteToken(t.ID)
	if err != nil {
		return fmt.Errorf("failed to delete token: %w", err)
	}

	return nil
}
221 |
222 | // generate 64 char long hex token (32 bytes == 64 hex chars)
223 | var generateHEXToken = func() (string, error) {
224 | b := make([]byte, 32)
225 | _, err := rand.Read(b)
226 | return fmt.Sprintf("%x", b), err
227 | }
228 |
--------------------------------------------------------------------------------
/cmd/provider/config_test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "os"
7 | "reflect"
8 | "testing"
9 | )
10 |
// TestNewConfig sets every SJP_* environment variable and checks that
// newConfig maps each one onto its config field, including the bool and int
// conversions.
func TestNewConfig(t *testing.T) {
	// Unset the variables afterwards so they do not leak into other tests.
	defer cleanupEnvs(t)

	serverAddress := "leberKleber.io"
	setEnv(t, "SJP_SERVER_ADDRESS", serverAddress)
	jwtPrivateKey := "myJWTKey"
	setEnv(t, "SJP_JWT_PRIVATE_KEY", jwtPrivateKey)
	jwtAudience := "myJWTAudience"
	setEnv(t, "SJP_JWT_AUDIENCE", jwtAudience)
	jwtIssuer := "myJWTIssuer"
	setEnv(t, "SJP_JWT_ISSUER", jwtIssuer)
	jwtSubject := "myJWTSubject"
	setEnv(t, "SJP_JWT_SUBJECT", jwtSubject)
	databaseDSN := "dsn"
	setEnv(t, "SJP_DATABASE_DSN", databaseDSN)
	databaseType := "type"
	setEnv(t, "SJP_DATABASE_TYPE", databaseType)
	expectedAdminAPIEnable := true
	adminAPIEnable := "true"
	setEnv(t, "SJP_ADMIN_API_ENABLE", adminAPIEnable)
	adminAPIUsername := "myAdminAPIUsername"
	setEnv(t, "SJP_ADMIN_API_USERNAME", adminAPIUsername)
	adminAPIPassword := "myAdminAPIPassword"
	setEnv(t, "SJP_ADMIN_API_PASSWORD", adminAPIPassword)
	mailTemplatesFolderPath := "myAdminAPIMailTemplatesFolderPath"
	setEnv(t, "SJP_MAIL_TEMPLATES_FOLDER_PATH", mailTemplatesFolderPath)
	mailSMTPHost := "myMailSMTPHost"
	setEnv(t, "SJP_MAIL_SMTP_HOST", mailSMTPHost)
	expectedMailSMTPPort := 42
	mailSMTPPort := "42"
	setEnv(t, "SJP_MAIL_SMTP_PORT", mailSMTPPort)
	mailSMTPUsername := "myMailSMTPUsername"
	setEnv(t, "SJP_MAIL_SMTP_USERNAME", mailSMTPUsername)
	mailSMTPPassword := "myMailSMTPPassword"
	setEnv(t, "SJP_MAIL_SMTP_PASSWORD", mailSMTPPassword)
	expectedMailTLSInsecureSkipVerify := true
	mailTLSInsecureSkipVerify := "true"
	setEnv(t, "SJP_MAIL_TLS_INSECURE_SKIP_VERIFY", mailTLSInsecureSkipVerify)
	mailTLSServerName := "myMailTLSServerName"
	setEnv(t, "SJP_MAIL_TLS_SERVER_NAME", mailTLSServerName)

	cfg, err := newConfig()
	if err != nil {
		t.Fatalf("Unexpected error while building new config cause: %s", err)
	}

	fieldEqual(t, "serverAddress", cfg.ServerAddress, serverAddress)
	fieldEqual(t, "jwt>privateKey", cfg.JWT.PrivateKey, jwtPrivateKey)
	fieldEqual(t, "jwt>audience", cfg.JWT.Audience, jwtAudience)
	fieldEqual(t, "jwt>issuer", cfg.JWT.Issuer, jwtIssuer)
	fieldEqual(t, "jwt>subject", cfg.JWT.Subject, jwtSubject)
	fieldEqual(t, "database>dsn", cfg.Database.DSN, databaseDSN)
	fieldEqual(t, "database>type", cfg.Database.Type, databaseType)
	// noinspection GoBoolExpressions
	fieldEqual(t, "adminAPI>enable", cfg.AdminAPI.Enable, expectedAdminAPIEnable)
	fieldEqual(t, "adminAPI>username", cfg.AdminAPI.Username, adminAPIUsername)
	fieldEqual(t, "adminAPI>password", cfg.AdminAPI.Password, adminAPIPassword)
	fieldEqual(t, "mail>templatesFolderPath", cfg.Mail.TemplatesFolderPath, mailTemplatesFolderPath)
	fieldEqual(t, "mail>smtpHost", cfg.Mail.SMTPHost, mailSMTPHost)
	fieldEqual(t, "mail>smtpPort", cfg.Mail.SMTPPort, expectedMailSMTPPort)
	fieldEqual(t, "mail>smtpUsername", cfg.Mail.SMTPUsername, mailSMTPUsername)
	fieldEqual(t, "mail>smtpPassword", cfg.Mail.SMTPPassword, mailSMTPPassword)
	// noinspection GoBoolExpressions
	fieldEqual(t, "mail>tls>insecureSkipVerify", cfg.Mail.TLS.InsecureSkipVerify, expectedMailTLSInsecureSkipVerify)
	fieldEqual(t, "mail>tls>serverName", cfg.Mail.TLS.ServerName, mailTLSServerName)
}
75 |
// TestNewConfigWithAdminAPIConstraint checks that enabling the admin API
// without a username, or without a password, makes newConfig fail.
func TestNewConfigWithAdminAPIConstraint(t *testing.T) {
	cleanupEnvs(t)
	// Deferred so the environment is also cleaned when a check calls Fatalf.
	defer cleanupEnvs(t)

	setEnv(t, "SJP_SERVER_ADDRESS", "leberKleber.io")
	setEnv(t, "SJP_JWT_PRIVATE_KEY", "myJWTKey")
	setEnv(t, "SJP_JWT_AUDIENCE", "myJWTAudience")
	setEnv(t, "SJP_JWT_ISSUER", "myJWTIssuer")
	setEnv(t, "SJP_JWT_SUBJECT", "myJWTSubject")
	setEnv(t, "SJP_DATABASE_DSN", "myDSN")
	setEnv(t, "SJP_DATABASE_TYPE", "myType")
	setEnv(t, "SJP_MAIL_TEMPLATES_FOLDER_PATH", "myAdminAPIMailTemplatesFolderPath")
	setEnv(t, "SJP_MAIL_SMTP_HOST", "myMailSMTPHost")
	setEnv(t, "SJP_MAIL_SMTP_PORT", "42")
	setEnv(t, "SJP_MAIL_SMTP_USERNAME", "myMailSMTPUsername")
	setEnv(t, "SJP_MAIL_SMTP_PASSWORD", "myMailSMTPPassword")
	setEnv(t, "SJP_MAIL_TLS_INSECURE_SKIP_VERIFY", "true")
	setEnv(t, "SJP_MAIL_TLS_SERVER_NAME", "myMailTLSServerName")

	expectedError := errors.New("admin-api-password and admin-api-username must be set if api has been enabled")

	// without username
	setEnv(t, "SJP_ADMIN_API_ENABLE", "true")
	setEnv(t, "SJP_ADMIN_API_USERNAME", "")
	setEnv(t, "SJP_ADMIN_API_PASSWORD", "myAdminAPIPassword")

	_, err := newConfig()
	if fmt.Sprint(err) != fmt.Sprint(expectedError) {
		// Previously this branch was empty, so the case could never fail.
		t.Fatalf("without username: returned error is not as expected. Expected:\n%s\nGiven:\n%s", expectedError, err)
	}

	// without password
	setEnv(t, "SJP_ADMIN_API_ENABLE", "true")
	setEnv(t, "SJP_ADMIN_API_USERNAME", "myAdminAPIUsername")
	setEnv(t, "SJP_ADMIN_API_PASSWORD", "")

	_, err = newConfig()
	if fmt.Sprint(err) != fmt.Sprint(expectedError) {
		t.Fatalf("without password: returned error is not as expected. Expected:\n%s\nGiven:\n%s", expectedError, err)
	}
}
118 |
// TestNewConfigCfgLibErrorHandling checks that newConfig returns the
// configuration library's error when the required PrivateKey has no value.
func TestNewConfigCfgLibErrorHandling(t *testing.T) {
	cleanupEnvs(t)

	// After cleanupEnvs, SJP_JWT_PRIVATE_KEY is unset although it is required.
	want := errors.New("required field PrivateKey is missing value")
	if _, err := newConfig(); fmt.Sprint(err) != fmt.Sprint(want) {
		t.Fatalf("returned error is not as expected. Expected:\n%s\nGiven:\n%s", want, err)
	}
}
130 |
// TestNewConfigUnableToGenerateUsage checks that newConfig returns the error
// of confUsage when producing the usage text fails. The usage path is reached
// by leaving the required PrivateKey unset.
func TestNewConfigUnableToGenerateUsage(t *testing.T) {
	originalUsage := confUsage
	defer func() { confUsage = originalUsage }()

	confUsage = func(string, interface{}) (string, error) {
		return "", errors.New("failed to generate usage")
	}

	cleanupEnvs(t)

	_, err := newConfig()
	if want := errors.New("failed to generate usage"); fmt.Sprint(want) != fmt.Sprint(err) {
		t.Fatalf("returned error is not as expected. Expected:\n%s\nGiven:\n%s", want, err)
	}
}
152 |
// setEnv sets the environment variable key to value and aborts the test if
// that fails. The variable is not restored automatically; tests in this file
// use cleanupEnvs for that.
func setEnv(t *testing.T, key, value string) {
	t.Helper() // attribute failures to the calling test line
	err := os.Setenv(key, value)
	if err != nil {
		t.Fatalf("failed to set env variable %q cause: %s", key, err)
	}
}
159 |
// unsetEnv removes the environment variable key and aborts the test if that
// fails.
func unsetEnv(t *testing.T, key string) {
	t.Helper() // attribute failures to the calling test line
	err := os.Unsetenv(key)
	if err != nil {
		t.Fatalf("failed to unset env variable %q cause: %s", key, err)
	}
}
166 |
// fieldEqual reports a non-fatal test error when cfgValue is not deeply equal
// (reflect.DeepEqual) to expectedValue. name identifies the config field in
// the failure message.
func fieldEqual(t *testing.T, name string, cfgValue, expectedValue interface{}) {
	t.Helper() // point failures at the assertion line in the calling test
	if !reflect.DeepEqual(cfgValue, expectedValue) {
		// %v rather than %s: values may be bools or ints (e.g. the SMTP port),
		// which %s would render as "%!s(int=42)".
		t.Errorf("unexpected cfg-value in field %q. Given: %v, Expected: %v", name, cfgValue, expectedValue)
	}
}
172 |
// cleanupEnvs unsets every SJP_* environment variable read by newConfig in
// these tests, aborting the test if any of them cannot be unset.
func cleanupEnvs(t *testing.T) {
	t.Helper() // attribute failures to the calling test line
	for _, key := range []string{
		"SJP_SERVER_ADDRESS",
		"SJP_JWT_PRIVATE_KEY",
		"SJP_JWT_AUDIENCE",
		"SJP_JWT_ISSUER",
		"SJP_JWT_SUBJECT",
		"SJP_DATABASE_DSN",
		"SJP_DATABASE_TYPE",
		"SJP_MAIL_TEMPLATES_FOLDER_PATH",
		"SJP_MAIL_SMTP_HOST",
		"SJP_MAIL_SMTP_PORT",
		"SJP_MAIL_SMTP_USERNAME",
		"SJP_MAIL_SMTP_PASSWORD",
		"SJP_MAIL_TLS_INSECURE_SKIP_VERIFY",
		"SJP_MAIL_TLS_SERVER_NAME",
		"SJP_ADMIN_API_ENABLE",
		"SJP_ADMIN_API_USERNAME",
		"SJP_ADMIN_API_PASSWORD",
	} {
		unsetEnv(t, key)
	}
}
192 |
--------------------------------------------------------------------------------
/internal/mailer/template_test.go:
--------------------------------------------------------------------------------
1 | package mailer
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "fmt"
7 | "github.com/DusanKasan/parsemail"
8 | htmlTemplate "html/template"
9 | "os"
10 | "reflect"
11 | "testing"
12 | textTemplate "text/template"
13 | )
14 |
// TestLoadTemplates stubs the html, text and yml template parsers and checks
// that loadTemplates asks each for "<path>/<name>.<ext>", assembles the
// results into a mailTemplate, and reports a descriptive error when a parse
// step fails (later parsers are not expected to be called after a failure).
func TestLoadTemplates(t *testing.T) {
	tests := []struct {
		name                           string
		givenTemplatePath              string
		givenTemplateName              string
		parseHTMLFileExpectedFilenames []string
		parseHTMLFileTemplate          *htmlTemplate.Template
		parseHTMLFileErr               error
		parseTextFileExpectedFilenames []string
		parseTextFileTemplate          *textTemplate.Template
		parseTextFileErr               error
		parseYMLFileExpectedFilenames  []string
		parseYMLFileTemplate           *textTemplate.Template
		parseYMLFileErr                error
		expectedTemplate               mailTemplate
		expectedErr                    error
	}{
		{
			name:                           "happycase",
			givenTemplatePath:              "/my/mailTemplate/path",
			givenTemplateName:              "myTemplate",
			parseHTMLFileExpectedFilenames: []string{"/my/mailTemplate/path/myTemplate.html"},
			parseHTMLFileTemplate:          htmlTemplate.New("myTemplateHTML"),
			parseTextFileExpectedFilenames: []string{"/my/mailTemplate/path/myTemplate.txt"},
			parseTextFileTemplate:          textTemplate.New("myTemplateText"),
			parseYMLFileExpectedFilenames:  []string{"/my/mailTemplate/path/myTemplate.yml"},
			parseYMLFileTemplate:           textTemplate.New("myTemplateYML"),
			expectedTemplate: mailTemplate{
				name:       "myTemplate",
				htmlTmpl:   htmlTemplate.New("myTemplateHTML"),
				textTmpl:   textTemplate.New("myTemplateText"),
				headerTmpl: textTemplate.New("myTemplateYML"),
			},
		}, {
			name:                           "could not load html mailTemplate",
			givenTemplatePath:              "/my/mailTemplate/path",
			givenTemplateName:              "myTemplate",
			parseHTMLFileExpectedFilenames: []string{"/my/mailTemplate/path/myTemplate.html"},
			parseHTMLFileErr:               os.ErrNotExist,
			expectedErr:                    errors.New("failed to load mail html body template: file does not exist"),
		}, {
			name:                           "could not load text mailTemplate",
			givenTemplatePath:              "/my/mailTemplate/path",
			givenTemplateName:              "myTemplate",
			parseHTMLFileExpectedFilenames: []string{"/my/mailTemplate/path/myTemplate.html"},
			parseHTMLFileTemplate:          htmlTemplate.New("myTemplateHTML"),
			parseTextFileExpectedFilenames: []string{"/my/mailTemplate/path/myTemplate.txt"},
			parseTextFileErr:               os.ErrPermission,
			expectedErr:                    errors.New("failed to load mail text body template: permission denied"),
		}, {
			name:                           "could not load yml mailTemplate",
			givenTemplatePath:              "/my/mailTemplate/path",
			givenTemplateName:              "myTemplate",
			parseHTMLFileExpectedFilenames: []string{"/my/mailTemplate/path/myTemplate.html"},
			parseHTMLFileTemplate:          htmlTemplate.New("myTemplateHTML"),
			parseTextFileExpectedFilenames: []string{"/my/mailTemplate/path/myTemplate.txt"},
			parseTextFileTemplate:          textTemplate.New("myTemplateText"),
			parseYMLFileExpectedFilenames:  []string{"/my/mailTemplate/path/myTemplate.yml"},
			parseYMLFileErr:                errors.New("abc error"),
			expectedErr:                    errors.New("failed to load mail headers template: abc error"),
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Swap in stub parsers for this subtest and restore the real ones
			// when it finishes.
			restoreHTML, restoreText, restoreYML := htmlTemplateParseFiles, textTemplateParseFiles, ymlTemplateParseFiles
			defer func() {
				htmlTemplateParseFiles, textTemplateParseFiles, ymlTemplateParseFiles = restoreHTML, restoreText, restoreYML
			}()

			htmlTemplateParseFiles = func(filenames ...string) (*htmlTemplate.Template, error) {
				if !reflect.DeepEqual(filenames, tc.parseHTMLFileExpectedFilenames) {
					t.Errorf("html mailTemplate file path is not as expected. Expected: %q, given: %q", tc.parseHTMLFileExpectedFilenames, filenames)
				}
				return tc.parseHTMLFileTemplate, tc.parseHTMLFileErr
			}
			textTemplateParseFiles = func(filenames ...string) (*textTemplate.Template, error) {
				if !reflect.DeepEqual(filenames, tc.parseTextFileExpectedFilenames) {
					t.Errorf("text mailTemplate file path is not as expected. Expected: %q, given: %q", tc.parseTextFileExpectedFilenames, filenames)
				}
				return tc.parseTextFileTemplate, tc.parseTextFileErr
			}
			ymlTemplateParseFiles = func(filenames ...string) (*textTemplate.Template, error) {
				if !reflect.DeepEqual(filenames, tc.parseYMLFileExpectedFilenames) {
					t.Errorf("yml mailTemplate file path is not as expected. Expected: %q, given: %q", tc.parseYMLFileExpectedFilenames, filenames)
				}
				return tc.parseYMLFileTemplate, tc.parseYMLFileErr
			}

			got, err := loadTemplates(tc.givenTemplatePath, tc.givenTemplateName)
			switch {
			case fmt.Sprint(err) != fmt.Sprint(tc.expectedErr):
				t.Fatalf("Unexpected error. Given:\n%q\nExpected:\n%q", err, tc.expectedErr)
			case err != nil:
				// Expected failure: nothing further to compare.
			case !reflect.DeepEqual(got, tc.expectedTemplate):
				t.Fatalf("Unexpected mailTemplate. Given:\n%#v\nExpected:\n%#v", got, tc.expectedTemplate)
			}
		})
	}
}
121 |
// TestTemplate_Render builds a mailTemplate from in-memory html, text and
// header (YAML) templates, renders it, then parses the written mail and checks
// the html body, the text body and the "MyHeader" header. The error cases
// check that a failure in each template is reported with its own prefix.
func TestTemplate_Render(t *testing.T) {
	tests := []struct {
		name             string
		htmlTplContent   string
		textTplContent   string
		headerTplContent string
		givenRenderArgs  interface{}
		expectedError    error
	}{
		{
			name:           "Happycase",
			htmlTplContent: "html template {{.TestID}}",
			textTplContent: "text template {{.TestID}}",
			headerTplContent: `MyHeader:
- "headaaaaa {{.TestID}}"`,
			givenRenderArgs: struct {
				TestID string
			}{
				"myTestID",
			},
		},
		{
			name: "header-template execute error handling",
			headerTplContent: `MyHeader:
- "headaaaaa {{.notExisting}}"`,
			givenRenderArgs: struct {
				notExisting string
			}{},
			expectedError: errors.New("failed to render mail headers: template: htmlTemplate:2:17: executing \"htmlTemplate\" at <.notExisting>: notExisting is an unexported field of struct type struct { notExisting string }"),
		},
		{
			name: "invalid header-template yml syntax",
			headerTplContent: `MyHeader:
"headaaaaa"`,
			expectedError: errors.New("failed to render mail headers: yaml: unmarshal errors:\n line 2: cannot unmarshal !!str `headaaaaa` into []string"),
		},
		{
			name:           "text-template Execute error handling",
			htmlTplContent: "html template",
			textTplContent: "text template {{.testID}}",
			headerTplContent: `MyHeader:
- "headaaaaa"`,
			givenRenderArgs: struct {
				testID string
			}{
				"myTestID",
			},
			expectedError: errors.New("failed to render mail text body: template: htmlTemplate:1:16: executing \"htmlTemplate\" at <.testID>: testID is an unexported field of struct type struct { testID string }"),
		},
		{
			name:           "html-template Execute error handling",
			htmlTplContent: "html template {{.testID}}",
			textTplContent: "text template",
			headerTplContent: `MyHeader:
- "headaaaaa"`,
			givenRenderArgs: struct {
				testID string
			}{
				"myTestID",
			},
			expectedError: errors.New("failed to render mail html body: template: htmlTemplate:1:16: executing \"htmlTemplate\" at <.testID>: testID is an unexported field of struct type struct { testID string }"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// All three templates are named "htmlTemplate": the expected error
			// strings above embed that name, so it must stay unchanged.
			htmlTpl, err := htmlTemplate.New("htmlTemplate").Parse(tt.htmlTplContent)
			if err != nil {
				t.Fatal("failed to parse test html template", err)
			}
			textTpl, err := textTemplate.New("htmlTemplate").Parse(tt.textTplContent)
			if err != nil {
				t.Fatal("failed to parse test text template", err)
			}
			headerTpl, err := textTemplate.New("htmlTemplate").Parse(tt.headerTplContent)
			if err != nil {
				t.Fatal("failed to parse test header template", err)
			}

			mt := mailTemplate{
				name:       "testMail",
				htmlTmpl:   htmlTpl,
				textTmpl:   textTpl,
				headerTmpl: headerTpl,
			}

			mail, err := mt.Render(tt.givenRenderArgs)
			if fmt.Sprint(err) != fmt.Sprint(tt.expectedError) {
				t.Fatalf("unexpected error while render template. Expected:\n%q\nGiven:\n%q", tt.expectedError, err)
			} else if err != nil {
				return
			}

			var bb bytes.Buffer
			if _, err := mail.WriteTo(&bb); err != nil {
				// Parsing a partially written mail would only produce
				// misleading follow-up failures, so stop here.
				t.Fatal("failed to write mail to bb", err)
			}

			parsedEMail, err := parsemail.Parse(&bb) // returns Email struct and error
			if err != nil {
				t.Fatal("failed to parse written mail", err)
			}

			expectedHTMLBody := "html template myTestID"
			if expectedHTMLBody != parsedEMail.HTMLBody {
				t.Errorf("html body is not as expected. Expected: %q, Given: %q", expectedHTMLBody, parsedEMail.HTMLBody)
			}
			expectedTextBody := "text template myTestID"
			if expectedTextBody != parsedEMail.TextBody {
				t.Errorf("text body is not as expected. Expected: %q, Given: %q", expectedTextBody, parsedEMail.TextBody)
			}
			expectedTestHeaderContent := "headaaaaa myTestID"
			testHeaderValue := parsedEMail.Header.Get("MyHeader")
			if expectedTestHeaderContent != testHeaderValue {
				t.Errorf("test header value is not as expected. Expected: %q, Given: %q", expectedTestHeaderContent, testHeaderValue)
			}
		})
	}
}
242 |
--------------------------------------------------------------------------------
/internal/storage_moq_test.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package internal
5 |
6 | import (
7 | "github.com/leberKleber/simple-jwt-provider/internal/storage"
8 | "sync"
9 | )
10 |
11 | // Ensure, that StorageMock does implement Storage.
12 | // If this is not the case, regenerate this file with moq.
13 | var _ Storage = &StorageMock{}
14 |
// StorageMock is a mock implementation of Storage.
//
// Configure it by setting the XxxFunc field of every method the code under
// test calls; invoking a method whose XxxFunc is nil panics. Each invocation
// records its arguments, which can be read back with the matching XxxCalls
// method.
//
//	func TestSomethingThatUsesStorage(t *testing.T) {
//
//		// make and configure a mocked Storage
//		mockedStorage := &StorageMock{
//			CreateTokenFunc: func(t *storage.Token) error {
//				panic("mock out the CreateToken method")
//			},
//			CreateUserFunc: func(user storage.User) error {
//				panic("mock out the CreateUser method")
//			},
//			DeleteTokenFunc: func(id uint) error {
//				panic("mock out the DeleteToken method")
//			},
//			DeleteUserFunc: func(email string) error {
//				panic("mock out the DeleteUser method")
//			},
//			TokensByEMailAndTokenFunc: func(email string, token string) ([]storage.Token, error) {
//				panic("mock out the TokensByEMailAndToken method")
//			},
//			UpdateUserFunc: func(user storage.User) error {
//				panic("mock out the UpdateUser method")
//			},
//			UserFunc: func(email string) (storage.User, error) {
//				panic("mock out the User method")
//			},
//		}
//
//		// use mockedStorage in code that requires Storage
//		// and then make assertions.
//
//	}
type StorageMock struct {
	// CreateTokenFunc mocks the CreateToken method.
	CreateTokenFunc func(t *storage.Token) error

	// CreateUserFunc mocks the CreateUser method.
	CreateUserFunc func(user storage.User) error

	// DeleteTokenFunc mocks the DeleteToken method.
	DeleteTokenFunc func(id uint) error

	// DeleteUserFunc mocks the DeleteUser method.
	DeleteUserFunc func(email string) error

	// TokensByEMailAndTokenFunc mocks the TokensByEMailAndToken method.
	TokensByEMailAndTokenFunc func(email string, token string) ([]storage.Token, error)

	// UpdateUserFunc mocks the UpdateUser method.
	UpdateUserFunc func(user storage.User) error

	// UserFunc mocks the User method.
	UserFunc func(email string) (storage.User, error)

	// calls records the arguments of every invocation, one slice per method,
	// in call order.
	calls struct {
		// CreateToken holds details about calls to the CreateToken method.
		CreateToken []struct {
			// T is the t argument value.
			T *storage.Token
		}
		// CreateUser holds details about calls to the CreateUser method.
		CreateUser []struct {
			// User is the user argument value.
			User storage.User
		}
		// DeleteToken holds details about calls to the DeleteToken method.
		DeleteToken []struct {
			// ID is the id argument value.
			ID uint
		}
		// DeleteUser holds details about calls to the DeleteUser method.
		DeleteUser []struct {
			// Email is the email argument value.
			Email string
		}
		// TokensByEMailAndToken holds details about calls to the TokensByEMailAndToken method.
		TokensByEMailAndToken []struct {
			// Email is the email argument value.
			Email string
			// Token is the token argument value.
			Token string
		}
		// UpdateUser holds details about calls to the UpdateUser method.
		UpdateUser []struct {
			// User is the user argument value.
			User storage.User
		}
		// User holds details about calls to the User method.
		User []struct {
			// Email is the email argument value.
			Email string
		}
	}
	// Each lock guards the slice of the same name in calls, so the mock can
	// be used from several goroutines.
	lockCreateToken           sync.RWMutex
	lockCreateUser            sync.RWMutex
	lockDeleteToken           sync.RWMutex
	lockDeleteUser            sync.RWMutex
	lockTokensByEMailAndToken sync.RWMutex
	lockUpdateUser            sync.RWMutex
	lockUser                  sync.RWMutex
}
118 |
119 | // CreateToken calls CreateTokenFunc.
120 | func (mock *StorageMock) CreateToken(t *storage.Token) error {
121 | if mock.CreateTokenFunc == nil {
122 | panic("StorageMock.CreateTokenFunc: method is nil but Storage.CreateToken was just called")
123 | }
124 | callInfo := struct {
125 | T *storage.Token
126 | }{
127 | T: t,
128 | }
129 | mock.lockCreateToken.Lock()
130 | mock.calls.CreateToken = append(mock.calls.CreateToken, callInfo)
131 | mock.lockCreateToken.Unlock()
132 | return mock.CreateTokenFunc(t)
133 | }
134 |
135 | // CreateTokenCalls gets all the calls that were made to CreateToken.
136 | // Check the length with:
137 | // len(mockedStorage.CreateTokenCalls())
138 | func (mock *StorageMock) CreateTokenCalls() []struct {
139 | T *storage.Token
140 | } {
141 | var calls []struct {
142 | T *storage.Token
143 | }
144 | mock.lockCreateToken.RLock()
145 | calls = mock.calls.CreateToken
146 | mock.lockCreateToken.RUnlock()
147 | return calls
148 | }
149 |
150 | // CreateUser calls CreateUserFunc.
151 | func (mock *StorageMock) CreateUser(user storage.User) error {
152 | if mock.CreateUserFunc == nil {
153 | panic("StorageMock.CreateUserFunc: method is nil but Storage.CreateUser was just called")
154 | }
155 | callInfo := struct {
156 | User storage.User
157 | }{
158 | User: user,
159 | }
160 | mock.lockCreateUser.Lock()
161 | mock.calls.CreateUser = append(mock.calls.CreateUser, callInfo)
162 | mock.lockCreateUser.Unlock()
163 | return mock.CreateUserFunc(user)
164 | }
165 |
166 | // CreateUserCalls gets all the calls that were made to CreateUser.
167 | // Check the length with:
168 | // len(mockedStorage.CreateUserCalls())
169 | func (mock *StorageMock) CreateUserCalls() []struct {
170 | User storage.User
171 | } {
172 | var calls []struct {
173 | User storage.User
174 | }
175 | mock.lockCreateUser.RLock()
176 | calls = mock.calls.CreateUser
177 | mock.lockCreateUser.RUnlock()
178 | return calls
179 | }
180 |
181 | // DeleteToken calls DeleteTokenFunc.
182 | func (mock *StorageMock) DeleteToken(id uint) error {
183 | if mock.DeleteTokenFunc == nil {
184 | panic("StorageMock.DeleteTokenFunc: method is nil but Storage.DeleteToken was just called")
185 | }
186 | callInfo := struct {
187 | ID uint
188 | }{
189 | ID: id,
190 | }
191 | mock.lockDeleteToken.Lock()
192 | mock.calls.DeleteToken = append(mock.calls.DeleteToken, callInfo)
193 | mock.lockDeleteToken.Unlock()
194 | return mock.DeleteTokenFunc(id)
195 | }
196 |
197 | // DeleteTokenCalls gets all the calls that were made to DeleteToken.
198 | // Check the length with:
199 | // len(mockedStorage.DeleteTokenCalls())
200 | func (mock *StorageMock) DeleteTokenCalls() []struct {
201 | ID uint
202 | } {
203 | var calls []struct {
204 | ID uint
205 | }
206 | mock.lockDeleteToken.RLock()
207 | calls = mock.calls.DeleteToken
208 | mock.lockDeleteToken.RUnlock()
209 | return calls
210 | }
211 |
212 | // DeleteUser calls DeleteUserFunc.
213 | func (mock *StorageMock) DeleteUser(email string) error {
214 | if mock.DeleteUserFunc == nil {
215 | panic("StorageMock.DeleteUserFunc: method is nil but Storage.DeleteUser was just called")
216 | }
217 | callInfo := struct {
218 | Email string
219 | }{
220 | Email: email,
221 | }
222 | mock.lockDeleteUser.Lock()
223 | mock.calls.DeleteUser = append(mock.calls.DeleteUser, callInfo)
224 | mock.lockDeleteUser.Unlock()
225 | return mock.DeleteUserFunc(email)
226 | }
227 |
228 | // DeleteUserCalls gets all the calls that were made to DeleteUser.
229 | // Check the length with:
230 | // len(mockedStorage.DeleteUserCalls())
231 | func (mock *StorageMock) DeleteUserCalls() []struct {
232 | Email string
233 | } {
234 | var calls []struct {
235 | Email string
236 | }
237 | mock.lockDeleteUser.RLock()
238 | calls = mock.calls.DeleteUser
239 | mock.lockDeleteUser.RUnlock()
240 | return calls
241 | }
242 |
243 | // TokensByEMailAndToken calls TokensByEMailAndTokenFunc.
244 | func (mock *StorageMock) TokensByEMailAndToken(email string, token string) ([]storage.Token, error) {
245 | if mock.TokensByEMailAndTokenFunc == nil {
246 | panic("StorageMock.TokensByEMailAndTokenFunc: method is nil but Storage.TokensByEMailAndToken was just called")
247 | }
248 | callInfo := struct {
249 | Email string
250 | Token string
251 | }{
252 | Email: email,
253 | Token: token,
254 | }
255 | mock.lockTokensByEMailAndToken.Lock()
256 | mock.calls.TokensByEMailAndToken = append(mock.calls.TokensByEMailAndToken, callInfo)
257 | mock.lockTokensByEMailAndToken.Unlock()
258 | return mock.TokensByEMailAndTokenFunc(email, token)
259 | }
260 |
261 | // TokensByEMailAndTokenCalls gets all the calls that were made to TokensByEMailAndToken.
262 | // Check the length with:
263 | // len(mockedStorage.TokensByEMailAndTokenCalls())
264 | func (mock *StorageMock) TokensByEMailAndTokenCalls() []struct {
265 | Email string
266 | Token string
267 | } {
268 | var calls []struct {
269 | Email string
270 | Token string
271 | }
272 | mock.lockTokensByEMailAndToken.RLock()
273 | calls = mock.calls.TokensByEMailAndToken
274 | mock.lockTokensByEMailAndToken.RUnlock()
275 | return calls
276 | }
277 |
278 | // UpdateUser calls UpdateUserFunc.
279 | func (mock *StorageMock) UpdateUser(user storage.User) error {
280 | if mock.UpdateUserFunc == nil {
281 | panic("StorageMock.UpdateUserFunc: method is nil but Storage.UpdateUser was just called")
282 | }
283 | callInfo := struct {
284 | User storage.User
285 | }{
286 | User: user,
287 | }
288 | mock.lockUpdateUser.Lock()
289 | mock.calls.UpdateUser = append(mock.calls.UpdateUser, callInfo)
290 | mock.lockUpdateUser.Unlock()
291 | return mock.UpdateUserFunc(user)
292 | }
293 |
294 | // UpdateUserCalls gets all the calls that were made to UpdateUser.
295 | // Check the length with:
296 | // len(mockedStorage.UpdateUserCalls())
297 | func (mock *StorageMock) UpdateUserCalls() []struct {
298 | User storage.User
299 | } {
300 | var calls []struct {
301 | User storage.User
302 | }
303 | mock.lockUpdateUser.RLock()
304 | calls = mock.calls.UpdateUser
305 | mock.lockUpdateUser.RUnlock()
306 | return calls
307 | }
308 |
309 | // User calls UserFunc.
310 | func (mock *StorageMock) User(email string) (storage.User, error) {
311 | if mock.UserFunc == nil {
312 | panic("StorageMock.UserFunc: method is nil but Storage.User was just called")
313 | }
314 | callInfo := struct {
315 | Email string
316 | }{
317 | Email: email,
318 | }
319 | mock.lockUser.Lock()
320 | mock.calls.User = append(mock.calls.User, callInfo)
321 | mock.lockUser.Unlock()
322 | return mock.UserFunc(email)
323 | }
324 |
325 | // UserCalls gets all the calls that were made to User.
326 | // Check the length with:
327 | // len(mockedStorage.UserCalls())
328 | func (mock *StorageMock) UserCalls() []struct {
329 | Email string
330 | } {
331 | var calls []struct {
332 | Email string
333 | }
334 | mock.lockUser.RLock()
335 | calls = mock.calls.User
336 | mock.lockUser.RUnlock()
337 | return calls
338 | }
339 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[Mentioned in Awesome Go](https://github.com/avelino/awesome-go)
[Go build status](https://github.com/leberKleber/simple-jwt-provider/actions?query=workflow%3Ago)
[Go Report Card](https://goreportcard.com/report/github.com/leberKleber/simple-jwt-provider)
[Code coverage](https://codecov.io/gh/leberKleber/simple-jwt-provider)
5 |
6 | # simple-jwt-provider
7 |
Simple and lightweight JWT provider written in Go (golang). It issues JWTs for users persisted in PostgreSQL or SQLite,
which can be managed via an API. A password-reset flow via mail verification is also available. User-specific custom
claims are available for JWT generation and mail rendering as well.
11 |
12 | dockerized: https://hub.docker.com/r/leberkleber/simple-jwt-provider
13 |
14 | build it yourself:
15 |
16 | ```shell script
17 | # as docker-image
18 | docker build . -t leberkleber/simple-jwt-provider
19 |
20 | # as binary
21 | go build -o simple-jwt-provider ./cmd/provider/
22 | ```
23 |
24 | # Table of contents
25 |
26 | - [Try it](#try-it)
27 | - [Getting started](#getting-started)
28 | - [Generate ECDSA-512 key pair](#generate-ecdsa-512-key-pair)
29 | - [Configuration](#configuration)
30 | - [API](#api)
31 | - [POST `/v1/auth/login`](#post-v1authlogin)
32 | - [POST `/v1/auth/refresh`](#post-v1authrefresh)
33 | - [POST `/v1/auth/password-reset-request`](#post-v1authpassword-reset-request)
34 | - [POST `/v1/auth/password-reset`](#post-v1authpassword-reset)
35 | - [POST `/v1/admin/users`](#post-v1adminusers)
36 | - [PUT `/v1/admin/users/{email}`](#put-v1adminusersemail)
37 | - [DELETE `/v1/admin/users/{email}`](#delete-v1adminusersemail)
38 | - [Mail](#mail)
39 | - [Password reset request](#password-reset-request)
40 | - [Development](#development)
41 | - [mocks](#mocks)
42 | - [component tests](#component-tests)
43 |
44 | ## Try it
45 |
46 | ```shell script
47 | git clone git@github.com:leberKleber/simple-jwt-provider.git
48 | docker-compose -f example/docker-compose.yml up
49 |
50 | # create user via admin-api
51 | ./example/create-user.sh test.test@test.test password {}
52 |
53 | # login with created user
./example/login.sh test.test@test.test password
55 |
56 | # reset password
57 | # 1) create password reset request
#    - a mail with the reset token will be sent
59 | # 2) reset password with received token
60 | # 3) do crud operations on user
61 |
62 | # 1) create password reset request
63 | ./example/create-password-reset-request.sh test.test@test.test
# 1.1) open a browser at http://127.0.0.1:8025/ and copy the reset token (only the token, not the URL)
65 | # 2) reset password with received token
66 | ./example/reset-password.sh test.test@test.test newPassword {reset-token}
67 | # verify new password
68 | ./example/login.sh test.test@test.test newPassword
69 |
70 | # 3) do crud operations on user
71 | # see ./example/*.sh
72 | ```
73 |
74 | ## Getting started
75 |
76 | ### Generate ECDSA-512 key pair
77 |
78 | ```sh
79 | # private key
80 | openssl ecparam -genkey -name secp521r1 -noout -out ecdsa-p521-private.pem
81 | # public key
82 | openssl ec -in ecdsa-p521-private.pem -pubout -out ecdsa-p521-public.pem
83 | ```
84 |
85 | ### Configuration
86 |
87 | | Environment variable | Description | Required | Default |
88 | | --------------------------------- |:-------------------------------------------------------------------------------------:| -----------------------------------:|----------------------:|
89 | | SJP_LOG_LEVEL | Log-Level can be TRACE DEBUG INFO WARN ERROR FATAL or PANIC | no | INFO |
90 | | SJP_SERVER_ADDRESS | Server-address network-interface to bind on e.g.: '127.0.0.1:8080' | no | 0.0.0.0:80 |
91 | | SJP_JWT_LIFETIME | Lifetime of JWT | no | 4h |
92 | | SJP_JWT_PRIVATE_KEY | JWT PrivateKey ECDSA512 | yes | - |
| SJP_JWT_AUDIENCE                  | Audience (`aud`) registered claim which will be applied to each JWT                   | no                                  | -                     |
| SJP_JWT_ISSUER                    | Issuer (`iss`) registered claim which will be applied to each JWT                     | no                                  | -                     |
| SJP_JWT_SUBJECT                   | Subject (`sub`) registered claim which will be applied to each JWT                    | no                                  | -                     |
| SJP_DATABASE_TYPE                 | Database type. Currently supported: postgres and sqlite                               | yes                                 | -                     |
97 | | SJP_DATABASE_DSN | Data Source Name for persistence | yes | - |
98 | | SJP_ADMIN_API_ENABLE | Enable admin API to manage stored users (true / false) | no | false |
99 | | SJP_ADMIN_API_USERNAME | Basic Auth Username if enable-admin-api = true | yes, when enable-admin-api = true | - |
| SJP_ADMIN_API_PASSWORD            | Basic Auth password if enable-admin-api = true. If bcrypt-hashed, prefix it with 'bcrypt:' | yes, when enable-admin-api = true | -                     |
101 | | SJP_MAIL_TEMPLATES_FOLDER_PATH | Path to mail-templates folder | no | /mail-templates |
102 | | SJP_MAIL_SMTP_HOST | SMTP host to connect to | yes | - |
103 | | SJP_MAIL_SMTP_PORT | SMTP port to connect to | no | 587 |
104 | | SJP_MAIL_SMTP_USERNAME | SMTP username to authorize with | yes | - |
105 | | SJP_MAIL_SMTP_PASSWORD | SMTP password to authorize with | yes | - |
106 | | SJP_MAIL_TLS_INSECURE_SKIP_VERIFY | true if certificates should not be verified | no | false |
| SJP_MAIL_TLS_SERVER_NAME          | Name of the server that exposes the TLS certificate (used to verify the certificate)  | no                                  | -                     |
108 |
109 | ## API
110 |
111 | ### POST `/v1/auth/login`
112 |
This endpoint checks the email/password combination and, if it is correct, responds with an access token and a refresh token:
114 |
115 | Request body:
116 | ```json
117 | {
118 | "email": "info@leberkleber.io",
119 | "password": "s3cr3t"
120 | }
121 | ```
122 |
123 | Response body (200 - OK):
124 | ```json
125 | {
126 | "access_token": "",
127 | "refresh_token": ""
128 | }
129 | ```
130 |
131 | ### POST `/v1/auth/refresh`
132 |
133 | This endpoint will return a new access and refresh token. The submitted refresh-token will no longer be valid.
134 |
135 | Request body:
136 | ```json
137 | {
138 | "refresh_token": ""
139 | }
140 | ```
141 |
142 | Response body (200 - OK):
143 | ```json
144 | {
145 | "access_token": "",
146 | "refresh_token": ""
147 | }
148 | ```
149 | ### POST `/v1/auth/password-reset-request`
150 |
This endpoint triggers a password reset request. The user receives a token by mail, with which the password can
be reset via POST `/v1/auth/password-reset`.
153 |
154 | Request body:
155 | ```json
156 | {
157 | "email": "info@leberkleber.io"
158 | }
159 | ```
160 |
161 | Response (201 - CREATED)
162 |
163 | ### POST `/v1/auth/password-reset`
164 |
This endpoint resets the password of the given user if the reset token is valid and matches the given email.
166 |
167 | Request body:
168 | ```json
169 | {
170 | "email": "info@leberkleber.io",
171 | "reset_token": "rAnDoMsHiT456",
172 | "password": "SeCReT"
173 | }
174 | ```
175 |
176 | Response (204 - NO CONTENT)
177 |
178 | ### POST `/v1/admin/users`
179 |
This endpoint creates a new user if admin API authentication was successful:
181 |
182 | Request body:
183 | ```json
184 | {
185 | "email": "info@leberkleber.io",
186 | "password": "s3cr3t",
187 | "claims": {
188 | "myCustomClaim": "custom claims for jwt and mail templates"
189 | }
190 | }
191 | ```
192 |
193 | Response body (201 - CREATED)
194 |
195 | ### PUT `/v1/admin/users/{email}`
196 |
This endpoint updates the given properties (excluding email) of the user with the given email if admin API
authentication was successful:
199 |
200 | Request body:
201 | ```json
202 | {
203 | "password": "n3wS3cr3t",
204 | "claims": {
205 | "updatedClaim": "now updated"
206 | }
207 | }
208 | ```
209 |
Response body (200 - OK):
211 |
212 | ```json
213 | {
214 | "email": "info@leberkleber.io",
215 | "password": "**********",
216 | "claims": {
217 | "updatedClaim": "now updated"
218 | }
219 | }
220 | ```
221 |
222 | ### DELETE `/v1/admin/users/{email}`
223 |
This endpoint deletes the user with the given email if no tokens refer to this user and admin API authentication
was successful:
226 |
Response (204 - NO CONTENT)
228 |
229 | ## Mail
230 |
Mails are generated from a set of templates, which should be adapted before being used in production.
232 |
- `.html` represents the HTML body of the mail and can be templated with `html/template` syntax
  (https://golang.org/pkg/html/template/). Available templating arguments are listed in the detailed template type description.
- `.txt` represents the text body of the mail and can be templated with `text/template` syntax
  (https://golang.org/pkg/text/template/). Available templating arguments are listed in the detailed template type description.
- `.yml` represents the headers of the mail. In this template, headers such as `From`, `To` or `Subject`
  can be set using `text/template` syntax (https://golang.org/pkg/text/template/). Available templating arguments are listed
  in the detailed template type description.
240 |
241 | ### Password reset request
242 |
243 | An example of this mail type can be found in `/mail-templates/password-reset-request.*`. Available template arguments:
244 |
245 | | Argument | Content | Example usage |
246 | |--------------------|--------------------------------------------------------|-------------------------------------|
247 | | Recipient | Users email address | `{{.Recipient}}` |
248 | | PasswordResetToken | The token which is required to reset the password | `{{.PasswordResetToken}}` |
| Claims             | All custom claims stored for the user                  | `{{if index .Claims "first_name"}}` |
250 |
251 | ## Development
252 |
253 | ### mocks
254 |
Mocks are generated with github.com/matryer/moq. Run the following to (re)generate them:

```shell script
go install github.com/matryer/moq@latest
259 | go generate ./...
260 | ```
261 |
262 | ### component tests
263 |
264 | Component tests can be executed locally with:
265 |
266 | ```shell script
267 | # build simple-jwt-provider from source code
268 | # setup infrastructure
# run all test files with build tag `component` in /cmd/provider
270 | ./component-tests.sh
271 | ```
--------------------------------------------------------------------------------
/internal/mailer/smtp_test.go:
--------------------------------------------------------------------------------
1 | package mailer
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "gopkg.in/mail.v2"
7 | "gotest.tools/assert"
8 | "reflect"
9 | "testing"
10 | )
11 |
12 | func TestNew(t *testing.T) {
13 | tests := []struct {
14 | name string
15 | dialerDialSendCloser mail.SendCloser
16 | dialerDialErr error
17 | loadTemplatesTemplate mailTemplate
18 | loadTemplatesErr error
19 | expectedErr error
20 | expectedMailerTemplates map[string]template
21 | }{
22 | {
23 | name: "Happycase",
24 | dialerDialSendCloser: &sendCloserMock{
25 | CloseFunc: func() error { return nil },
26 | },
27 | loadTemplatesTemplate: mailTemplate{
28 | name: "password-reset-request",
29 | },
30 | expectedMailerTemplates: map[string]template{
31 | "password-reset-request": mailTemplate{
32 | name: "password-reset-request",
33 | },
34 | },
35 | }, {
36 | name: "Unable to connect to smtp server",
37 | dialerDialErr: errors.New("unable to dial: !42"),
38 | expectedErr: errors.New("failed to connect to smtp server: unable to dial: !42"),
39 | }, {
40 | name: "Unable to load templates",
41 | dialerDialSendCloser: &sendCloserMock{
42 | CloseFunc: func() error { return nil },
43 | },
44 | loadTemplatesErr: errors.New("angry file system: you're stupid peace of s*it"),
45 | expectedErr: errors.New("failed to load password-reset mailTemplate: angry file system: you're stupid peace of s*it"),
46 | },
47 | }
48 | for _, tt := range tests {
49 | t.Run(tt.name, func(t *testing.T) {
50 | oldBuildDialer := buildDialer
51 | oldLoadTemplates := loadTemplates
52 | defer func() {
53 | buildDialer = oldBuildDialer
54 | loadTemplates = oldLoadTemplates
55 | }()
56 |
57 | givenTemplatesFolderPath := "/my/mailTemplate/path"
58 | givenUsername := ">username<"
59 | givenPassword := ">password<"
60 | givenHost := ">host<"
61 | givenPort := 5555
62 | givenTLSInsecureSkipVerify := true
63 | givenTLSServerName := ">tlsServerName<"
64 | givenDialer := &dialerMock{
65 | DialFunc: func() (mail.SendCloser, error) {
66 | return tt.dialerDialSendCloser, tt.dialerDialErr
67 | },
68 | }
69 |
70 | buildDialer = func(username, password, host string, port int, tlsInsecureSkipVerify bool, tlsServerName string) dialer {
71 | return givenDialer
72 | }
73 |
74 | loadTemplates = func(path, name string) (mailTemplate, error) {
75 | if path != givenTemplatesFolderPath {
76 | t.Errorf("unexpected loadTemplates.path. Given: %q, Expected: %q", path, givenTemplatesFolderPath)
77 | }
78 |
79 | expectedLoadedTemplateName := "password-reset-request"
80 | if name != expectedLoadedTemplateName {
81 | t.Errorf("unexpected loadTemplates.name. Given: %q, Expected: %q", name, expectedLoadedTemplateName)
82 | }
83 |
84 | return tt.loadTemplatesTemplate, tt.loadTemplatesErr
85 | }
86 |
87 | mailer, err := New(givenTemplatesFolderPath, givenUsername, givenPassword, givenHost, givenPort, givenTLSInsecureSkipVerify, givenTLSServerName)
88 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedErr) {
89 | t.Fatalf("Unexpected error. Given:\n%q\nExpected:\n%q", err, tt.loadTemplatesErr)
90 | } else if err != nil {
91 | return
92 | }
93 |
94 | if !reflect.DeepEqual(mailer.templates, tt.expectedMailerTemplates) {
95 | t.Fatalf("mailer.templates are not as expected. Given:\n%#v\nExpected:\n%#v", mailer.templates, tt.expectedMailerTemplates)
96 | }
97 |
98 | if !reflect.DeepEqual(mailer.dialer, givenDialer) {
99 | t.Fatalf("mailer.dialer is not as expected. Given:\n%#v\nExpected:\n%#v", mailer.dialer, givenDialer)
100 | }
101 | })
102 | }
103 |
104 | }
105 |
106 | func TestMailer_SendPasswordResetRequestEMail_HappyCase(t *testing.T) {
107 | givenRecipient := ">recipient<"
108 | givenPasswordResetToken := ">passwordResetToken<"
109 | givenClaims := map[string]interface{}{
110 | "customClaim4711": 3,
111 | }
112 |
113 | prrMail := mail.NewMessage(mail.SetCharset("UTF-8"))
114 | prrMail.SetHeader("test_id", "yay")
115 |
116 | var mailsToSend []*mail.Message
117 |
118 | dialer := &dialerMock{
119 | DialAndSendFunc: func(msgs ...*mail.Message) error {
120 | mailsToSend = msgs
121 | return nil
122 | },
123 | }
124 |
125 | var calledMailData interface{}
126 | tplID := "password-reset-request"
127 | tplMock := &templateMock{
128 | RenderFunc: func(mailData interface{}) (*mail.Message, error) {
129 | calledMailData = mailData
130 | return prrMail, nil
131 | },
132 | }
133 | tpls := map[string]template{
134 | tplID: tplMock,
135 | }
136 |
137 | m := Mailer{
138 | dialer: dialer,
139 | templates: tpls,
140 | }
141 |
142 | err := m.SendPasswordResetRequestEMail(givenRecipient, givenPasswordResetToken, givenClaims)
143 | if err != nil {
144 | t.Fatal("Unexpected error", err)
145 | }
146 |
147 | expectedSendMails := []*mail.Message{prrMail}
148 | if !reflect.DeepEqual(mailsToSend, expectedSendMails) {
149 | t.Errorf("The send mail(s) are not the rendered. Rendered: %#v. Send: %#v", mailsToSend, expectedSendMails)
150 | }
151 |
152 | tplMockRenderCalls := tplMock.RenderCalls()
153 | if len(tplMockRenderCalls) != 1 {
154 | t.Errorf("tpls[%q].Render should be called 1 time but was %d", tplID, len(tplMockRenderCalls))
155 | }
156 |
157 | dialerDialAndSendCalls := dialer.DialAndSendCalls()
158 | if len(dialerDialAndSendCalls) != 1 {
159 | t.Errorf("dialer.DialAndSendCalls should be called 1 time but was %d", len(dialerDialAndSendCalls))
160 | }
161 |
162 | expectedMailData := struct {
163 | Recipient string
164 | PasswordResetToken string
165 | Claims map[string]interface{}
166 | }{
167 | Recipient: givenRecipient,
168 | PasswordResetToken: givenPasswordResetToken,
169 | Claims: givenClaims,
170 | }
171 | if !reflect.DeepEqual(expectedMailData, calledMailData) {
172 | t.Errorf("called mail data are not as expected. Expected:\n%#v\nGiven:\n%#v", expectedMailData, calledMailData)
173 | }
174 |
175 | dialerDialCalls := dialer.DialCalls()
176 | if len(dialerDialCalls) != 0 {
177 | t.Errorf("dialer.DialAndSendCalls should be called 0 time but was %d", len(dialerDialCalls))
178 | }
179 | }
180 |
181 | func TestMailer_SendPasswordResetRequestEMail_TemplateNotFound(t *testing.T) {
182 | givenRecipient := ">recipient<"
183 | givenPasswordResetToken := ">passwordResetToken<"
184 | givenClaims := map[string]interface{}{
185 | "customClaim4711": 3,
186 | }
187 |
188 | prrMail := mail.NewMessage(mail.SetCharset("UTF-8"))
189 | prrMail.SetHeader("test_id", "yay")
190 |
191 | m := Mailer{
192 | templates: map[string]template{},
193 | }
194 |
195 | err := m.SendPasswordResetRequestEMail(givenRecipient, givenPasswordResetToken, givenClaims)
196 | expectedError := errors.New("could not found mailTemplate with name \"password-reset-request\"")
197 | if fmt.Sprint(err) != fmt.Sprint(expectedError) {
198 | t.Fatalf("Unexpected error. Error:\n%q,\nExpected:\n%q", err, expectedError)
199 | }
200 | }
201 |
202 | func TestMailer_SendPasswordResetRequestEMail_FailedToRenderTemplate(t *testing.T) {
203 | givenRecipient := ">recipient<"
204 | givenPasswordResetToken := ">passwordResetToken<"
205 | givenClaims := map[string]interface{}{
206 | "customClaim4711": 3,
207 | }
208 |
209 | prrMail := mail.NewMessage(mail.SetCharset("UTF-8"))
210 | prrMail.SetHeader("test_id", "yay")
211 |
212 | var calledMailData interface{}
213 | tplID := "password-reset-request"
214 | tplMock := &templateMock{
215 | RenderFunc: func(mailData interface{}) (*mail.Message, error) {
216 | calledMailData = mailData
217 | return prrMail, errors.New("i dont think so")
218 | },
219 | }
220 | tpls := map[string]template{
221 | tplID: tplMock,
222 | }
223 |
224 | m := Mailer{
225 | templates: tpls,
226 | }
227 |
228 | err := m.SendPasswordResetRequestEMail(givenRecipient, givenPasswordResetToken, givenClaims)
229 | expectedError := errors.New("failed to render mail-template \"password-reset-request\": i dont think so")
230 | if fmt.Sprint(err) != fmt.Sprint(expectedError) {
231 | t.Fatalf("Unexpected error. Error:\n%q,\nExpected:\n%q", err, expectedError)
232 | }
233 |
234 | tplMockRenderCalls := tplMock.RenderCalls()
235 | if len(tplMockRenderCalls) != 1 {
236 | t.Errorf("tpls[%q].Render should be called 1 time but was %d", tplID, len(tplMockRenderCalls))
237 | }
238 |
239 | expectedMailData := struct {
240 | Recipient string
241 | PasswordResetToken string
242 | Claims map[string]interface{}
243 | }{
244 | Recipient: givenRecipient,
245 | PasswordResetToken: givenPasswordResetToken,
246 | Claims: givenClaims,
247 | }
248 | if !reflect.DeepEqual(expectedMailData, calledMailData) {
249 | t.Errorf("called mail data are not as expected. Expected:\n%#v\nGiven:\n%#v", expectedMailData, calledMailData)
250 | }
251 | }
252 |
253 | func TestMailer_SendPasswordResetRequestEMail_FailedToSendMail(t *testing.T) {
254 | givenRecipient := ">recipient<"
255 | givenPasswordResetToken := ">passwordResetToken<"
256 | givenClaims := map[string]interface{}{
257 | "customClaim4711": 3,
258 | }
259 |
260 | prrMail := mail.NewMessage(mail.SetCharset("UTF-8"))
261 | prrMail.SetHeader("test_id", "yay")
262 |
263 | var mailsToSend []*mail.Message
264 |
265 | dialer := &dialerMock{
266 | DialAndSendFunc: func(msgs ...*mail.Message) error {
267 | mailsToSend = msgs
268 | return errors.New("perhaps yes but no")
269 | },
270 | }
271 |
272 | var calledMailData interface{}
273 | tplID := "password-reset-request"
274 | tplMock := &templateMock{
275 | RenderFunc: func(mailData interface{}) (*mail.Message, error) {
276 | calledMailData = mailData
277 | return prrMail, nil
278 | },
279 | }
280 | tpls := map[string]template{
281 | tplID: tplMock,
282 | }
283 |
284 | m := Mailer{
285 | dialer: dialer,
286 | templates: tpls,
287 | }
288 |
289 | err := m.SendPasswordResetRequestEMail(givenRecipient, givenPasswordResetToken, givenClaims)
290 | expectedError := errors.New("failed to send email: perhaps yes but no")
291 | if fmt.Sprint(err) != fmt.Sprint(expectedError) {
292 | t.Fatalf("Unexpected error. Error:\n%q,\nExpected:\n%q", err, expectedError)
293 | }
294 |
295 | expectedSendMails := []*mail.Message{prrMail}
296 | if !reflect.DeepEqual(mailsToSend, expectedSendMails) {
297 | t.Errorf("The send mail(s) are not the rendered. Rendered: %#v. Send: %#v", mailsToSend, expectedSendMails)
298 | }
299 |
300 | tplMockRenderCalls := tplMock.RenderCalls()
301 | if len(tplMockRenderCalls) != 1 {
302 | t.Errorf("tpls[%q].Render should be called 1 time but was %d", tplID, len(tplMockRenderCalls))
303 | }
304 |
305 | dialerDialAndSendCalls := dialer.DialAndSendCalls()
306 | if len(dialerDialAndSendCalls) != 1 {
307 | t.Errorf("dialer.DialAndSendCalls should be called 1 time but was %d", len(dialerDialAndSendCalls))
308 | }
309 |
310 | expectedMailData := struct {
311 | Recipient string
312 | PasswordResetToken string
313 | Claims map[string]interface{}
314 | }{
315 | Recipient: givenRecipient,
316 | PasswordResetToken: givenPasswordResetToken,
317 | Claims: givenClaims,
318 | }
319 | if !reflect.DeepEqual(expectedMailData, calledMailData) {
320 | t.Errorf("called mail data are not as expected. Expected:\n%#v\nGiven:\n%#v", expectedMailData, calledMailData)
321 | }
322 |
323 | dialerDialCalls := dialer.DialCalls()
324 | if len(dialerDialCalls) != 0 {
325 | t.Errorf("dialer.DialAndSendCalls should be called 0 time but was %d", len(dialerDialCalls))
326 | }
327 | }
328 |
329 | func TestBuildDialer(t *testing.T) {
330 | oldMailNewDialer := mailNewDialer
331 | defer func() {
332 | mailNewDialer = oldMailNewDialer
333 | }()
334 |
335 | //given
336 | host := "myHost"
337 | port := 1234
338 | username := "myUsername"
339 | password := "myPassword"
340 | tlsServerName := "myServerName"
341 |
342 | mailNewDialer = func(host string, port int, username, password string) *mail.Dialer {
343 | return &mail.Dialer{
344 | Host: host,
345 | Port: port,
346 | Username: username,
347 | Password: password,
348 | }
349 | }
350 |
351 | //when
352 | buildDialer := buildDialer(username, password, host, port, false, tlsServerName)
353 |
354 | mailDialer := buildDialer.(*mail.Dialer)
355 |
356 | //then
357 | assert.Equal(t, mailDialer.Host, host, "dialer>Host is not as expected")
358 | assert.Equal(t, mailDialer.Port, port, "dialer>Port is not as expected")
359 | assert.Equal(t, mailDialer.Username, username, "dialer>Username is not as expected")
360 | assert.Equal(t, mailDialer.Password, password, "dialer>Password is not as expected")
361 | assert.Equal(t, mailDialer.TLSConfig.InsecureSkipVerify, false, "dialer>TLSConfig.InsecureSkipVerify should be false")
362 | assert.Equal(t, mailDialer.TLSConfig.ServerName, tlsServerName, "buildDialer.TLSConfig.ServerName is not as expected")
363 | }
364 |
--------------------------------------------------------------------------------
/internal/web/provider_moq_test.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package web
5 |
6 | import (
7 | "github.com/leberKleber/simple-jwt-provider/internal"
8 | "sync"
9 | )
10 |
// Ensure, that ProviderMock does implement Provider.
// If this is not the case, regenerate this file with moq.
// This is a compile-time check: the build fails here if Provider gains or
// changes a method that ProviderMock does not provide.
var _ Provider = &ProviderMock{}
14 |
// ProviderMock is a mock implementation of Provider.
//
//     func TestSomethingThatUsesProvider(t *testing.T) {
//
//         // make and configure a mocked Provider
//         mockedProvider := &ProviderMock{
//             CreatePasswordResetRequestFunc: func(email string) error {
//                 panic("mock out the CreatePasswordResetRequest method")
//             },
//             CreateUserFunc: func(user internal.User) error {
//                 panic("mock out the CreateUser method")
//             },
//             DeleteUserFunc: func(email string) error {
//                 panic("mock out the DeleteUser method")
//             },
//             GetUserFunc: func(email string) (internal.User, error) {
//                 panic("mock out the GetUser method")
//             },
//             LoginFunc: func(email string, password string) (string, string, error) {
//                 panic("mock out the Login method")
//             },
//             RefreshFunc: func(refreshToken string) (string, string, error) {
//                 panic("mock out the Refresh method")
//             },
//             ResetPasswordFunc: func(email string, resetToken string, password string) error {
//                 panic("mock out the ResetPassword method")
//             },
//             UpdateUserFunc: func(email string, user internal.User) (internal.User, error) {
//                 panic("mock out the UpdateUser method")
//             },
//         }
//
//         // use mockedProvider in code that requires Provider
//         // and then make assertions.
//
//     }
//
// Each Provider method delegates to the matching XxxFunc field and panics if
// that field is nil, so set every XxxFunc the code under test will call.
// Arguments of every call are recorded in the unexported calls struct, one
// slice per method, each guarded by its own RWMutex. Read them back with the
// XxxCalls accessors.
//
// This type is generated by moq; regenerate it instead of editing by hand.
type ProviderMock struct {
	// CreatePasswordResetRequestFunc mocks the CreatePasswordResetRequest method.
	CreatePasswordResetRequestFunc func(email string) error

	// CreateUserFunc mocks the CreateUser method.
	CreateUserFunc func(user internal.User) error

	// DeleteUserFunc mocks the DeleteUser method.
	DeleteUserFunc func(email string) error

	// GetUserFunc mocks the GetUser method.
	GetUserFunc func(email string) (internal.User, error)

	// LoginFunc mocks the Login method.
	LoginFunc func(email string, password string) (string, string, error)

	// RefreshFunc mocks the Refresh method.
	RefreshFunc func(refreshToken string) (string, string, error)

	// ResetPasswordFunc mocks the ResetPassword method.
	ResetPasswordFunc func(email string, resetToken string, password string) error

	// UpdateUserFunc mocks the UpdateUser method.
	UpdateUserFunc func(email string, user internal.User) (internal.User, error)

	// calls tracks calls to the methods.
	calls struct {
		// CreatePasswordResetRequest holds details about calls to the CreatePasswordResetRequest method.
		CreatePasswordResetRequest []struct {
			// Email is the email argument value.
			Email string
		}
		// CreateUser holds details about calls to the CreateUser method.
		CreateUser []struct {
			// User is the user argument value.
			User internal.User
		}
		// DeleteUser holds details about calls to the DeleteUser method.
		DeleteUser []struct {
			// Email is the email argument value.
			Email string
		}
		// GetUser holds details about calls to the GetUser method.
		GetUser []struct {
			// Email is the email argument value.
			Email string
		}
		// Login holds details about calls to the Login method.
		Login []struct {
			// Email is the email argument value.
			Email string
			// Password is the password argument value.
			Password string
		}
		// Refresh holds details about calls to the Refresh method.
		Refresh []struct {
			// RefreshToken is the refreshToken argument value.
			RefreshToken string
		}
		// ResetPassword holds details about calls to the ResetPassword method.
		ResetPassword []struct {
			// Email is the email argument value.
			Email string
			// ResetToken is the resetToken argument value.
			ResetToken string
			// Password is the password argument value.
			Password string
		}
		// UpdateUser holds details about calls to the UpdateUser method.
		UpdateUser []struct {
			// Email is the email argument value.
			Email string
			// User is the user argument value.
			User internal.User
		}
	}
	// One lock per method: each guards only the matching slice in calls.
	lockCreatePasswordResetRequest sync.RWMutex
	lockCreateUser                 sync.RWMutex
	lockDeleteUser                 sync.RWMutex
	lockGetUser                    sync.RWMutex
	lockLogin                      sync.RWMutex
	lockRefresh                    sync.RWMutex
	lockResetPassword              sync.RWMutex
	lockUpdateUser                 sync.RWMutex
}
136 |
// CreatePasswordResetRequest calls CreatePasswordResetRequestFunc.
// It panics before recording anything if CreatePasswordResetRequestFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) CreatePasswordResetRequest(email string) error {
	if mock.CreatePasswordResetRequestFunc == nil {
		panic("ProviderMock.CreatePasswordResetRequestFunc: method is nil but Provider.CreatePasswordResetRequest was just called")
	}
	callInfo := struct {
		Email string
	}{
		Email: email,
	}
	mock.lockCreatePasswordResetRequest.Lock()
	mock.calls.CreatePasswordResetRequest = append(mock.calls.CreatePasswordResetRequest, callInfo)
	mock.lockCreatePasswordResetRequest.Unlock()
	return mock.CreatePasswordResetRequestFunc(email)
}
152 |
// CreatePasswordResetRequestCalls gets all the calls that were made to CreatePasswordResetRequest.
// Check the length with:
//     len(mockedProvider.CreatePasswordResetRequestCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) CreatePasswordResetRequestCalls() []struct {
	Email string
} {
	var calls []struct {
		Email string
	}
	mock.lockCreatePasswordResetRequest.RLock()
	calls = mock.calls.CreatePasswordResetRequest
	mock.lockCreatePasswordResetRequest.RUnlock()
	return calls
}
167 |
// CreateUser calls CreateUserFunc.
// It panics before recording anything if CreateUserFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) CreateUser(user internal.User) error {
	if mock.CreateUserFunc == nil {
		panic("ProviderMock.CreateUserFunc: method is nil but Provider.CreateUser was just called")
	}
	callInfo := struct {
		User internal.User
	}{
		User: user,
	}
	mock.lockCreateUser.Lock()
	mock.calls.CreateUser = append(mock.calls.CreateUser, callInfo)
	mock.lockCreateUser.Unlock()
	return mock.CreateUserFunc(user)
}
183 |
// CreateUserCalls gets all the calls that were made to CreateUser.
// Check the length with:
//     len(mockedProvider.CreateUserCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) CreateUserCalls() []struct {
	User internal.User
} {
	var calls []struct {
		User internal.User
	}
	mock.lockCreateUser.RLock()
	calls = mock.calls.CreateUser
	mock.lockCreateUser.RUnlock()
	return calls
}
198 |
// DeleteUser calls DeleteUserFunc.
// It panics before recording anything if DeleteUserFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) DeleteUser(email string) error {
	if mock.DeleteUserFunc == nil {
		panic("ProviderMock.DeleteUserFunc: method is nil but Provider.DeleteUser was just called")
	}
	callInfo := struct {
		Email string
	}{
		Email: email,
	}
	mock.lockDeleteUser.Lock()
	mock.calls.DeleteUser = append(mock.calls.DeleteUser, callInfo)
	mock.lockDeleteUser.Unlock()
	return mock.DeleteUserFunc(email)
}
214 |
// DeleteUserCalls gets all the calls that were made to DeleteUser.
// Check the length with:
//     len(mockedProvider.DeleteUserCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) DeleteUserCalls() []struct {
	Email string
} {
	var calls []struct {
		Email string
	}
	mock.lockDeleteUser.RLock()
	calls = mock.calls.DeleteUser
	mock.lockDeleteUser.RUnlock()
	return calls
}
229 |
// GetUser calls GetUserFunc.
// It panics before recording anything if GetUserFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) GetUser(email string) (internal.User, error) {
	if mock.GetUserFunc == nil {
		panic("ProviderMock.GetUserFunc: method is nil but Provider.GetUser was just called")
	}
	callInfo := struct {
		Email string
	}{
		Email: email,
	}
	mock.lockGetUser.Lock()
	mock.calls.GetUser = append(mock.calls.GetUser, callInfo)
	mock.lockGetUser.Unlock()
	return mock.GetUserFunc(email)
}
245 |
// GetUserCalls gets all the calls that were made to GetUser.
// Check the length with:
//     len(mockedProvider.GetUserCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) GetUserCalls() []struct {
	Email string
} {
	var calls []struct {
		Email string
	}
	mock.lockGetUser.RLock()
	calls = mock.calls.GetUser
	mock.lockGetUser.RUnlock()
	return calls
}
260 |
// Login calls LoginFunc.
// It panics before recording anything if LoginFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) Login(email string, password string) (string, string, error) {
	if mock.LoginFunc == nil {
		panic("ProviderMock.LoginFunc: method is nil but Provider.Login was just called")
	}
	callInfo := struct {
		Email    string
		Password string
	}{
		Email:    email,
		Password: password,
	}
	mock.lockLogin.Lock()
	mock.calls.Login = append(mock.calls.Login, callInfo)
	mock.lockLogin.Unlock()
	return mock.LoginFunc(email, password)
}
278 |
// LoginCalls gets all the calls that were made to Login.
// Check the length with:
//     len(mockedProvider.LoginCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) LoginCalls() []struct {
	Email    string
	Password string
} {
	var calls []struct {
		Email    string
		Password string
	}
	mock.lockLogin.RLock()
	calls = mock.calls.Login
	mock.lockLogin.RUnlock()
	return calls
}
295 |
// Refresh calls RefreshFunc.
// It panics before recording anything if RefreshFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) Refresh(refreshToken string) (string, string, error) {
	if mock.RefreshFunc == nil {
		panic("ProviderMock.RefreshFunc: method is nil but Provider.Refresh was just called")
	}
	callInfo := struct {
		RefreshToken string
	}{
		RefreshToken: refreshToken,
	}
	mock.lockRefresh.Lock()
	mock.calls.Refresh = append(mock.calls.Refresh, callInfo)
	mock.lockRefresh.Unlock()
	return mock.RefreshFunc(refreshToken)
}
311 |
// RefreshCalls gets all the calls that were made to Refresh.
// Check the length with:
//     len(mockedProvider.RefreshCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) RefreshCalls() []struct {
	RefreshToken string
} {
	var calls []struct {
		RefreshToken string
	}
	mock.lockRefresh.RLock()
	calls = mock.calls.Refresh
	mock.lockRefresh.RUnlock()
	return calls
}
326 |
// ResetPassword calls ResetPasswordFunc.
// It panics before recording anything if ResetPasswordFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) ResetPassword(email string, resetToken string, password string) error {
	if mock.ResetPasswordFunc == nil {
		panic("ProviderMock.ResetPasswordFunc: method is nil but Provider.ResetPassword was just called")
	}
	callInfo := struct {
		Email      string
		ResetToken string
		Password   string
	}{
		Email:      email,
		ResetToken: resetToken,
		Password:   password,
	}
	mock.lockResetPassword.Lock()
	mock.calls.ResetPassword = append(mock.calls.ResetPassword, callInfo)
	mock.lockResetPassword.Unlock()
	return mock.ResetPasswordFunc(email, resetToken, password)
}
346 |
// ResetPasswordCalls gets all the calls that were made to ResetPassword.
// Check the length with:
//     len(mockedProvider.ResetPasswordCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) ResetPasswordCalls() []struct {
	Email      string
	ResetToken string
	Password   string
} {
	var calls []struct {
		Email      string
		ResetToken string
		Password   string
	}
	mock.lockResetPassword.RLock()
	calls = mock.calls.ResetPassword
	mock.lockResetPassword.RUnlock()
	return calls
}
365 |
// UpdateUser calls UpdateUserFunc.
// It panics before recording anything if UpdateUserFunc is nil.
// The lock is released before the stub runs, so the stub is not serialized.
func (mock *ProviderMock) UpdateUser(email string, user internal.User) (internal.User, error) {
	if mock.UpdateUserFunc == nil {
		panic("ProviderMock.UpdateUserFunc: method is nil but Provider.UpdateUser was just called")
	}
	callInfo := struct {
		Email string
		User  internal.User
	}{
		Email: email,
		User:  user,
	}
	mock.lockUpdateUser.Lock()
	mock.calls.UpdateUser = append(mock.calls.UpdateUser, callInfo)
	mock.lockUpdateUser.Unlock()
	return mock.UpdateUserFunc(email, user)
}
383 |
// UpdateUserCalls gets all the calls that were made to UpdateUser.
// Check the length with:
//     len(mockedProvider.UpdateUserCalls())
// The slice is read under a read lock but is not copied; treat it as read-only.
func (mock *ProviderMock) UpdateUserCalls() []struct {
	Email string
	User  internal.User
} {
	var calls []struct {
		Email string
		User  internal.User
	}
	mock.lockUpdateUser.RLock()
	calls = mock.calls.UpdateUser
	mock.lockUpdateUser.RUnlock()
	return calls
}
400 |
--------------------------------------------------------------------------------
/internal/admin_test.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/leberKleber/simple-jwt-provider/internal/storage"
7 | "golang.org/x/crypto/bcrypt"
8 | "reflect"
9 | "testing"
10 | )
11 |
12 | func TestProvider_CreateUser(t *testing.T) {
13 | tests := []struct {
14 | name string
15 | givenUser User
16 | bcryptPasswordError error
17 | bcryptPasswordPassword []byte
18 | dbExpectedUser storage.User // password not encrypted
19 | dbReturnError error
20 | expectedError error
21 | }{
22 | {
23 | name: "Happycase",
24 | givenUser: User{
25 | EMail: "test@test.test",
26 | Password: "s3cr3t",
27 | Claims: map[string]interface{}{"cLaIM": "as"},
28 | },
29 | dbExpectedUser: storage.User{
30 | EMail: "test@test.test",
31 | Password: []byte("s3cr3t"),
32 | Claims: map[string]interface{}{"cLaIM": "as"},
33 | },
34 | }, {
35 | name: "user already exists",
36 | givenUser: User{
37 | EMail: "test@test.test",
38 | Password: "s3cr3t",
39 | },
40 | dbReturnError: storage.ErrUserAlreadyExists,
41 | dbExpectedUser: storage.User{
42 | EMail: "test@test.test",
43 | Password: []byte("s3cr3t"),
44 | },
45 | expectedError: ErrUserAlreadyExists,
46 | }, {
47 | name: "Some db error",
48 | givenUser: User{
49 | EMail: "test@test.test",
50 | Password: "s3cr3t",
51 | },
52 | dbReturnError: errors.New("my custom error. ALARM"),
53 | dbExpectedUser: storage.User{
54 | EMail: "test@test.test",
55 | Password: []byte("s3cr3t"),
56 | },
57 | expectedError: errors.New(`failed to query user with email "test@test.test": my custom error. ALARM`),
58 | }, {
59 | name: "failed to bcrypt password",
60 | givenUser: User{
61 | EMail: "test@test.test",
62 | Password: "s3cr3t",
63 | },
64 | bcryptPasswordError: errors.New("failed to bcrypt password"),
65 | dbReturnError: errors.New("my custom error. ALARM"),
66 | expectedError: errors.New(`failed to bcrypt password: failed to bcrypt password`),
67 | },
68 | }
69 |
70 | for _, tt := range tests {
71 | t.Run(tt.name, func(t *testing.T) {
72 | if tt.bcryptPasswordError != nil {
73 | oldBcryptPassword := bcryptPassword
74 | defer func() { bcryptPassword = oldBcryptPassword }()
75 | bcryptPassword = func(password string) ([]byte, error) {
76 | return tt.bcryptPasswordPassword, tt.bcryptPasswordError
77 | }
78 | }
79 | var givenDbUser storage.User
80 | toTest := Provider{
81 | Storage: &StorageMock{
82 | CreateUserFunc: func(user storage.User) error {
83 | givenDbUser = user
84 | return tt.dbReturnError
85 | },
86 | },
87 | }
88 |
89 | err := toTest.CreateUser(tt.givenUser)
90 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedError) {
91 | t.Fatalf("Processing error is not as expected: \nExpected:%s\nGiven:%s", tt.expectedError, err)
92 | }
93 |
94 | if givenDbUser.EMail != tt.dbExpectedUser.EMail {
95 | t.Errorf("Given db user > email is not as expected: \nExpected:%s\nGiven:%s", tt.dbExpectedUser.EMail, givenDbUser.EMail)
96 | }
97 |
98 | err = bcrypt.CompareHashAndPassword(givenDbUser.Password, tt.dbExpectedUser.Password)
99 | if err != nil && !reflect.DeepEqual(givenDbUser.Password, tt.dbExpectedUser.Password) {
100 | t.Errorf("Given db user > password is not as expected: \nExpected:%s\nGiven(bcrypted):%s", tt.dbExpectedUser.Password, givenDbUser.Password)
101 | }
102 | })
103 | }
104 | }
105 |
106 | func TestProvider_GetUser(t *testing.T) {
107 | tests := []struct {
108 | name string
109 | givenEMail string
110 | dbExpectedEMail string
111 | dbReturnUser storage.User
112 | dbReturnError error
113 | expectedError error
114 | expectedUser User
115 | }{
116 | {
117 | name: "Happycase",
118 | dbExpectedEMail: "test@test.test",
119 | dbReturnUser: storage.User{
120 | EMail: "test.test@test.test",
121 | Claims: map[string]interface{}{
122 | "claaa": "bbb",
123 | },
124 | Password: []byte("password"),
125 | },
126 | givenEMail: "test@test.test",
127 | expectedUser: User{
128 | EMail: "test.test@test.test",
129 | Password: "**********",
130 | Claims: map[string]interface{}{
131 | "claaa": "bbb",
132 | },
133 | },
134 | }, {
135 | name: "user not found",
136 | givenEMail: "test@test.test",
137 | dbExpectedEMail: "test@test.test",
138 | dbReturnError: storage.ErrUserNotFound,
139 | expectedError: ErrUserNotFound,
140 | }, {
141 | name: "Some db error",
142 | givenEMail: "test@test.test",
143 | dbExpectedEMail: "test@test.test",
144 | dbReturnError: errors.New("my custom error. ALARM"),
145 | expectedError: errors.New(`failed to find user with email "test@test.test": my custom error. ALARM`),
146 | },
147 | }
148 |
149 | for _, tt := range tests {
150 | t.Run(tt.name, func(t *testing.T) {
151 | var givenEMail string
152 | toTest := Provider{
153 | Storage: &StorageMock{
154 | UserFunc: func(email string) (storage.User, error) {
155 | givenEMail = email
156 | return tt.dbReturnUser, tt.dbReturnError
157 | },
158 | },
159 | }
160 |
161 | user, err := toTest.GetUser(tt.givenEMail)
162 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedError) {
163 | t.Fatalf("Processing error is not as expected: \nExpected:%s\nGiven:%s", tt.expectedError, err)
164 | }
165 |
166 | if !reflect.DeepEqual(user, tt.expectedUser) {
167 | t.Errorf("Returned user is not as expected. Given:\n%#v\nExpected:\n%#v", user, tt.expectedUser)
168 | }
169 |
170 | if givenEMail != tt.dbExpectedEMail {
171 | t.Errorf("Given db email is not as expected: \nExpected:%s\nGiven:%s", tt.dbExpectedEMail, givenEMail)
172 | }
173 | })
174 | }
175 |
176 | }
177 |
178 | func TestProvider_UpdateUser_Happycase(t *testing.T) {
179 | dbUserToUpdate := storage.User{
180 | EMail: "test.test@test.test",
181 | Password: []byte("testSecret"),
182 | Claims: map[string]interface{}{
183 | "c": "g",
184 | },
185 | }
186 | var dbUpdateUser storage.User
187 | toTest := Provider{
188 | Storage: &StorageMock{
189 | UserFunc: func(email string) (storage.User, error) {
190 | return dbUserToUpdate, nil
191 | },
192 | UpdateUserFunc: func(user storage.User) error {
193 | dbUpdateUser = user
194 | return nil
195 | },
196 | },
197 | }
198 |
199 | updatedUser, err := toTest.UpdateUser("test.test@test.test", User{
200 | Password: "newPassword",
201 | Claims: map[string]interface{}{
202 | "d": "w",
203 | },
204 | })
205 | if err != nil {
206 | t.Fatal("unexpected error", err)
207 | }
208 |
209 | expectedUpdatedUser := User{
210 | EMail: "test.test@test.test",
211 | Password: "**********",
212 | Claims: map[string]interface{}{
213 | "d": "w",
214 | },
215 | }
216 | if !reflect.DeepEqual(updatedUser, expectedUpdatedUser) {
217 | t.Errorf("returned updated user is not as expected. Expected:\n%#v\nGiven:\n%#v", expectedUpdatedUser, updatedUser)
218 | }
219 |
220 | expectedDBUpdateUser := storage.User{
221 | EMail: "test.test@test.test",
222 | Password: []byte("newPassword"),
223 | Claims: map[string]interface{}{
224 | "d": "w",
225 | },
226 | }
227 | if dbUpdateUser.EMail != expectedDBUpdateUser.EMail {
228 | t.Errorf("user.email to update in db is not as expected. Expected:\n%q\nGiven:\n%q", expectedDBUpdateUser.EMail, dbUpdateUser.EMail)
229 | }
230 |
231 | err = bcrypt.CompareHashAndPassword(dbUpdateUser.Password, expectedDBUpdateUser.Password)
232 | if err != nil {
233 | t.Errorf("user.password to update in db is not as expected. Expected:\n%q", expectedDBUpdateUser.Password)
234 | }
235 |
236 | if !reflect.DeepEqual(dbUpdateUser.Claims, expectedDBUpdateUser.Claims) {
237 | t.Errorf("user.claims to update in db is not as expected. Expected:\n%#v\nGiven:\n%#v", expectedDBUpdateUser.Claims, dbUpdateUser.Claims)
238 | }
239 | }
240 |
241 | func TestProvider_UpdateUser_UnableToGetUser(t *testing.T) {
242 | tests := []struct {
243 | name string
244 | dbUserResponseError error
245 | expectedError error
246 | }{
247 | {
248 | name: "user not found",
249 | dbUserResponseError: storage.ErrUserNotFound,
250 | expectedError: ErrUserNotFound,
251 | },
252 | {
253 | name: "unexpected error",
254 | dbUserResponseError: errors.New("nope"),
255 | expectedError: errors.New("failed to find user to update: nope"),
256 | },
257 | }
258 |
259 | for _, tt := range tests {
260 | t.Run(tt.name, func(t *testing.T) {
261 | userEMail := "test.test@test.test"
262 | var dbUserCalledEMail string
263 | toTest := Provider{
264 | Storage: &StorageMock{
265 | UserFunc: func(email string) (storage.User, error) {
266 | dbUserCalledEMail = email
267 | return storage.User{}, tt.dbUserResponseError
268 | },
269 | },
270 | }
271 |
272 | _, err := toTest.UpdateUser(userEMail, User{
273 | Password: "newPassword",
274 | Claims: map[string]interface{}{
275 | "d": "w",
276 | },
277 | })
278 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedError) {
279 | t.Errorf("unexpected error. Expected:\n%q\nGiven:\n%q", tt.expectedError, err)
280 | }
281 | if dbUserCalledEMail != userEMail {
282 | t.Errorf("db called with unexpected email. Expected: %q, Given: %q", userEMail, dbUserCalledEMail)
283 | }
284 | })
285 | }
286 |
287 | }
288 |
289 | func TestProvider_UpdateUser_UnableToUpdateUser(t *testing.T) {
290 | tests := []struct {
291 | name string
292 | dbUpdateUserResponseError error
293 | expectedError error
294 | }{
295 | {
296 | name: "user not found",
297 | dbUpdateUserResponseError: storage.ErrUserNotFound,
298 | expectedError: ErrUserNotFound,
299 | },
300 | {
301 | name: "unexpected error",
302 | dbUpdateUserResponseError: errors.New("nope"),
303 | expectedError: errors.New("failed to update user: nope"),
304 | },
305 | }
306 |
307 | for _, tt := range tests {
308 | t.Run(tt.name, func(t *testing.T) {
309 | userEMail := "test.test@test.test"
310 | toTest := Provider{
311 | Storage: &StorageMock{
312 | UserFunc: func(_ string) (storage.User, error) {
313 | return storage.User{
314 | EMail: userEMail,
315 | Password: []byte("bycryptedPassword"),
316 | Claims: map[string]interface{}{
317 | "stored": "claim",
318 | },
319 | }, nil
320 | },
321 | UpdateUserFunc: func(_ storage.User) error {
322 | return tt.dbUpdateUserResponseError
323 | },
324 | },
325 | }
326 |
327 | _, err := toTest.UpdateUser(userEMail, User{
328 | Password: "newPassword",
329 | Claims: map[string]interface{}{
330 | "d": "w",
331 | },
332 | })
333 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedError) {
334 | t.Errorf("unexpected error. Expected:\n%q\nGiven:\n%q", tt.expectedError, err)
335 | }
336 | })
337 | }
338 | }
339 |
340 | func TestProvider_UpdateUser_UnableToBcryptPassword(t *testing.T) {
341 | oldBcryptPassword := bcryptPassword
342 | defer func() { bcryptPassword = oldBcryptPassword }()
343 | bcryptPassword = func(password string) ([]byte, error) {
344 | return nil, errors.New("failed to bcryptPassword")
345 | }
346 |
347 | userEMail := "test.test@test.test"
348 | toTest := Provider{
349 | Storage: &StorageMock{
350 | UserFunc: func(_ string) (storage.User, error) {
351 | return storage.User{
352 | EMail: userEMail,
353 | Password: []byte("bycryptedPassword"),
354 | Claims: map[string]interface{}{
355 | "stored": "claim",
356 | },
357 | }, nil
358 | },
359 | UpdateUserFunc: func(_ storage.User) error {
360 | return storage.ErrUserNotFound
361 | },
362 | },
363 | }
364 |
365 | _, err := toTest.UpdateUser(userEMail, User{
366 | Password: "newPassword",
367 | Claims: map[string]interface{}{
368 | "d": "w",
369 | },
370 | })
371 |
372 | expectedErr := errors.New("failed to bcrypt new password: failed to bcryptPassword")
373 | if fmt.Sprint(err) != fmt.Sprint(expectedErr) {
374 | t.Errorf("unexpected error. Expected:\n%q\nGiven:\n%q", expectedErr, err)
375 | }
376 | }
377 |
378 | func TestProvider_DeleteUser(t *testing.T) {
379 | tests := []struct {
380 | name string
381 | givenEMail string
382 | expectedError error
383 | dbExpectedEMail string
384 | dbReturnError error
385 | }{
386 | {
387 | name: "Happycase",
388 | dbExpectedEMail: "test@test.test",
389 | givenEMail: "test@test.test",
390 | }, {
391 | name: "user not found",
392 | givenEMail: "test@test.test",
393 | dbExpectedEMail: "test@test.test",
394 | dbReturnError: storage.ErrUserNotFound,
395 | expectedError: ErrUserNotFound,
396 | }, {
397 | name: "Some db error",
398 | givenEMail: "test@test.test",
399 | dbExpectedEMail: "test@test.test",
400 | dbReturnError: errors.New("my custom error. ALARM"),
401 | expectedError: errors.New(`failed to delete user with email "test@test.test": my custom error. ALARM`),
402 | },
403 | }
404 |
405 | for _, tt := range tests {
406 | t.Run(tt.name, func(t *testing.T) {
407 | var givenEMail string
408 | toTest := Provider{
409 | Storage: &StorageMock{
410 | DeleteUserFunc: func(email string) error {
411 | givenEMail = email
412 | return tt.dbReturnError
413 | },
414 | },
415 | }
416 |
417 | err := toTest.DeleteUser(tt.givenEMail)
418 | if fmt.Sprint(err) != fmt.Sprint(tt.expectedError) {
419 | t.Fatalf("Processing error is not as expected: \nExpected:%s\nGiven:%s", tt.expectedError, err)
420 | }
421 |
422 | if givenEMail != tt.dbExpectedEMail {
423 | t.Errorf("Given db email is not as expected: \nExpected:%s\nGiven:%s", tt.dbExpectedEMail, givenEMail)
424 | }
425 | })
426 | }
427 |
428 | }
429 |
--------------------------------------------------------------------------------