├── .github └── workflows │ ├── deploy.yml │ └── test.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── api ├── account.go ├── account_test.go ├── main_test.go ├── middleware.go ├── middleware_test.go ├── server.go ├── token.go ├── transfer.go ├── transfer_test.go ├── user.go ├── user_test.go └── validator.go ├── app.env ├── backend-master.png ├── db ├── migration │ ├── 000001_init_schema.down.sql │ ├── 000001_init_schema.up.sql │ ├── 000002_add_users.down.sql │ ├── 000002_add_users.up.sql │ ├── 000003_add_sessions.down.sql │ ├── 000003_add_sessions.up.sql │ ├── 000004_add_verify_emails.down.sql │ ├── 000004_add_verify_emails.up.sql │ ├── 000005_add_role_to_users.down.sql │ └── 000005_add_role_to_users.up.sql ├── mock │ └── store.go ├── query │ ├── account.sql │ ├── entry.sql │ ├── session.sql │ ├── transfer.sql │ ├── user.sql │ └── verify_email.sql └── sqlc │ ├── account.sql.go │ ├── account_test.go │ ├── db.go │ ├── entry.sql.go │ ├── entry_test.go │ ├── error.go │ ├── exec_tx.go │ ├── main_test.go │ ├── models.go │ ├── querier.go │ ├── session.sql.go │ ├── store.go │ ├── store_test.go │ ├── transfer.sql.go │ ├── transfer_test.go │ ├── tx_create_user.go │ ├── tx_transfer.go │ ├── tx_verify_email.go │ ├── user.sql.go │ ├── user_test.go │ └── verify_email.sql.go ├── doc ├── db.dbml ├── schema.sql ├── statik │ └── statik.go └── swagger │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ ├── index.css │ ├── index.html │ ├── oauth2-redirect.html │ ├── simple_bank.swagger.json │ ├── swagger-initializer.js │ ├── swagger-ui-bundle.js │ ├── swagger-ui-bundle.js.map │ ├── swagger-ui-es-bundle-core.js │ ├── swagger-ui-es-bundle-core.js.map │ ├── swagger-ui-es-bundle.js │ ├── swagger-ui-es-bundle.js.map │ ├── swagger-ui-standalone-preset.js │ ├── swagger-ui-standalone-preset.js.map │ ├── swagger-ui.css │ ├── swagger-ui.css.map │ ├── swagger-ui.js │ └── swagger-ui.js.map ├── docker-compose.yaml ├── eks ├── aws-auth.yaml ├── deployment.yaml ├── 
ingress-grpc.yaml ├── ingress-http.yaml ├── ingress-nginx.yaml ├── install.sh ├── issuer.yaml └── service.yaml ├── frontend ├── .eslintrc.cjs ├── .gitignore ├── .prettierrc.json ├── README.md ├── env.d.ts ├── index.html ├── package-lock.json ├── package.json ├── public │ └── favicon.ico ├── src │ ├── App.vue │ ├── assets │ │ └── main.css │ ├── components │ │ ├── LoginUser.vue │ │ └── UserInfo.vue │ ├── main.ts │ ├── router │ │ └── index.ts │ ├── store.ts │ ├── types │ │ ├── auth_state.ts │ │ └── user.ts │ └── views │ │ └── HomeView.vue ├── tsconfig.app.json ├── tsconfig.json ├── tsconfig.node.json ├── tsconfig.vitest.json ├── vite.config.ts └── vitest.config.ts ├── gapi ├── authorization.go ├── converter.go ├── error.go ├── logger.go ├── main_test.go ├── metadata.go ├── rpc_create_user.go ├── rpc_create_user_test.go ├── rpc_login_user.go ├── rpc_update_user.go ├── rpc_update_user_test.go ├── rpc_verify_email.go └── server.go ├── go.mod ├── go.sum ├── mail ├── sender.go └── sender_test.go ├── main.go ├── pb ├── rpc_create_user.pb.go ├── rpc_login_user.pb.go ├── rpc_update_user.pb.go ├── rpc_verify_email.pb.go ├── service_simple_bank.pb.go ├── service_simple_bank.pb.gw.go ├── service_simple_bank_grpc.pb.go └── user.pb.go ├── proto ├── google │ └── api │ │ ├── annotations.proto │ │ ├── field_behavior.proto │ │ ├── http.proto │ │ └── httpbody.proto ├── protoc-gen-openapiv2 │ └── options │ │ ├── annotations.proto │ │ └── openapiv2.proto ├── rpc_create_user.proto ├── rpc_login_user.proto ├── rpc_update_user.proto ├── rpc_verify_email.proto ├── service_simple_bank.proto └── user.proto ├── simplebank ├── sqlc.yaml ├── start.sh ├── token ├── jwt_maker.go ├── jwt_maker_test.go ├── maker.go ├── paseto_maker.go ├── paseto_maker_test.go └── payload.go ├── util ├── config.go ├── currency.go ├── password.go ├── password_test.go ├── random.go └── role.go ├── val └── validator.go ├── wait-for.sh └── worker ├── distributor.go ├── logger.go ├── mock └── distributor.go ├── 
processor.go └── task_send_verify_email.go /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to production 2 | 3 | on: 4 | push: 5 | branches: [release] 6 | 7 | jobs: 8 | deploy: 9 | name: Build image 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Check out code 14 | uses: actions/checkout@v2 15 | 16 | - name: Install kubectl 17 | uses: azure/setup-kubectl@v1 18 | with: 19 | version: "v1.21.3" 20 | id: install 21 | 22 | - name: Configure AWS credentials 23 | uses: aws-actions/configure-aws-credentials@v1 24 | with: 25 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 26 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 27 | aws-region: eu-west-1 28 | 29 | - name: Login to Amazon ECR 30 | id: login-ecr 31 | uses: aws-actions/amazon-ecr-login@v1 32 | 33 | - name: Load secrets and save to app.env 34 | run: aws secretsmanager get-secret-value --secret-id simple_bank --query SecretString --output text | jq -r 'to_entries|map("\(.key)=\(.value)")|.[]' > app.env 35 | 36 | - name: Build, tag, and push image to Amazon ECR 37 | env: 38 | ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} 39 | ECR_REPOSITORY: simplebank 40 | IMAGE_TAG: ${{ github.sha }} 41 | run: | 42 | docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG -t $ECR_REGISTRY/$ECR_REPOSITORY:latest . 
43 | docker push -a $ECR_REGISTRY/$ECR_REPOSITORY 44 | 45 | - name: Update kube config 46 | run: aws eks update-kubeconfig --name simple-bank-eks --region eu-west-1 47 | 48 | - name: Deploy image to Amazon EKS 49 | run: | 50 | kubectl apply -f eks/aws-auth.yaml 51 | kubectl apply -f eks/deployment.yaml 52 | kubectl apply -f eks/service.yaml 53 | kubectl apply -f eks/issuer.yaml 54 | kubectl apply -f eks/ingress-nginx.yaml 55 | kubectl apply -f eks/ingress-http.yaml 56 | kubectl apply -f eks/ingress-grpc.yaml 57 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run unit tests 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | test: 11 | name: Test 12 | runs-on: ubuntu-latest 13 | 14 | services: 15 | postgres: 16 | image: postgres:14-alpine 17 | env: 18 | POSTGRES_USER: root 19 | POSTGRES_PASSWORD: secret 20 | POSTGRES_DB: simple_bank 21 | ports: 22 | - 5432:5432 23 | options: >- 24 | --health-cmd pg_isready 25 | --health-interval 10s 26 | --health-timeout 5s 27 | --health-retries 5 28 | 29 | steps: 30 | - name: Set up Go 1.x 31 | uses: actions/setup-go@v2 32 | with: 33 | go-version: ^1.22 34 | id: go 35 | 36 | - name: Check out code into the Go module directory 37 | uses: actions/checkout@v2 38 | 39 | - name: Install golang-migrate 40 | run: | 41 | curl -L https://github.com/golang-migrate/migrate/releases/download/v4.14.1/migrate.linux-amd64.tar.gz | tar xvz 42 | sudo mv migrate.linux-amd64 /usr/bin/migrate 43 | which migrate 44 | 45 | - name: Run migrations 46 | run: make migrateup 47 | 48 | - name: Test 49 | run: make test 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | .vscode 3 | 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Build stage 2 | FROM golang:1.22-alpine3.19 AS builder 3 | WORKDIR /app 4 | COPY . . 5 | RUN go build -o main main.go 6 | 7 | # Run stage 8 | FROM alpine:3.19 9 | WORKDIR /app 10 | COPY --from=builder /app/main . 11 | COPY app.env . 12 | COPY start.sh . 13 | COPY wait-for.sh . 14 | COPY db/migration ./db/migration 15 | 16 | EXPOSE 8080 9090 17 | CMD [ "/app/main" ] 18 | ENTRYPOINT [ "/app/start.sh" ] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Quang Pham. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | DB_URL=postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable 2 | 3 | network: 4 | docker network create bank-network 5 | 6 | postgres: 7 | docker run --name postgres --network bank-network -p 5432:5432 -e POSTGRES_USER=root -e POSTGRES_PASSWORD=secret -d postgres:14-alpine 8 | 9 | mysql: 10 | docker run --name mysql8 -p 3306:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8 11 | 12 | createdb: 13 | docker exec -it postgres createdb --username=root --owner=root simple_bank 14 | 15 | dropdb: 16 | docker exec -it postgres dropdb simple_bank 17 | 18 | migrateup: 19 | migrate -path db/migration -database "$(DB_URL)" -verbose up 20 | 21 | migrateup1: 22 | migrate -path db/migration -database "$(DB_URL)" -verbose up 1 23 | 24 | migratedown: 25 | migrate -path db/migration -database "$(DB_URL)" -verbose down 26 | 27 | migratedown1: 28 | migrate -path db/migration -database "$(DB_URL)" -verbose down 1 29 | 30 | new_migration: 31 | migrate create -ext sql -dir db/migration -seq $(name) 32 | 33 | db_docs: 34 | dbdocs build doc/db.dbml 35 | 36 | db_schema: 37 | dbml2sql --postgres -o doc/schema.sql doc/db.dbml 38 | 39 | sqlc: 40 | sqlc generate 41 | 42 | test: 43 | go test -v -cover -short ./... 
44 | 45 | server: 46 | go run main.go 47 | 48 | mock: 49 | mockgen -package mockdb -destination db/mock/store.go github.com/techschool/simplebank/db/sqlc Store 50 | mockgen -package mockwk -destination worker/mock/distributor.go github.com/techschool/simplebank/worker TaskDistributor 51 | 52 | proto: 53 | rm -f pb/*.go 54 | rm -f doc/swagger/*.swagger.json 55 | protoc --proto_path=proto --go_out=pb --go_opt=paths=source_relative \ 56 | --go-grpc_out=pb --go-grpc_opt=paths=source_relative \ 57 | --grpc-gateway_out=pb --grpc-gateway_opt=paths=source_relative \ 58 | --openapiv2_out=doc/swagger --openapiv2_opt=allow_merge=true,merge_file_name=simple_bank \ 59 | proto/*.proto 60 | statik -src=./doc/swagger -dest=./doc 61 | 62 | evans: 63 | evans --host localhost --port 9090 -r repl 64 | 65 | redis: 66 | docker run --name redis -p 6379:6379 -d redis:7-alpine 67 | 68 | .PHONY: network postgres createdb dropdb migrateup migratedown migrateup1 migratedown1 new_migration db_docs db_schema sqlc test server mock proto evans redis 69 | -------------------------------------------------------------------------------- /api/account.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | 7 | "github.com/gin-gonic/gin" 8 | db "github.com/techschool/simplebank/db/sqlc" 9 | "github.com/techschool/simplebank/token" 10 | ) 11 | 12 | type createAccountRequest struct { 13 | Currency string `json:"currency" binding:"required,currency"` 14 | } 15 | 16 | func (server *Server) createAccount(ctx *gin.Context) { 17 | var req createAccountRequest 18 | if err := ctx.ShouldBindJSON(&req); err != nil { 19 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 20 | return 21 | } 22 | 23 | authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) 24 | arg := db.CreateAccountParams{ 25 | Owner: authPayload.Username, 26 | Currency: req.Currency, 27 | Balance: 0, 28 | } 29 | 30 | account, err := 
server.store.CreateAccount(ctx, arg) 31 | if err != nil { 32 | errCode := db.ErrorCode(err) 33 | if errCode == db.ForeignKeyViolation || errCode == db.UniqueViolation { 34 | ctx.JSON(http.StatusForbidden, errorResponse(err)) 35 | return 36 | } 37 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 38 | return 39 | } 40 | 41 | ctx.JSON(http.StatusOK, account) 42 | } 43 | 44 | type getAccountRequest struct { 45 | ID int64 `uri:"id" binding:"required,min=1"` 46 | } 47 | 48 | func (server *Server) getAccount(ctx *gin.Context) { 49 | var req getAccountRequest 50 | if err := ctx.ShouldBindUri(&req); err != nil { 51 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 52 | return 53 | } 54 | 55 | account, err := server.store.GetAccount(ctx, req.ID) 56 | if err != nil { 57 | if errors.Is(err, db.ErrRecordNotFound) { 58 | ctx.JSON(http.StatusNotFound, errorResponse(err)) 59 | return 60 | } 61 | 62 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 63 | return 64 | } 65 | 66 | authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) 67 | if account.Owner != authPayload.Username { 68 | err := errors.New("account doesn't belong to the authenticated user") 69 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 70 | return 71 | } 72 | 73 | ctx.JSON(http.StatusOK, account) 74 | } 75 | 76 | type listAccountRequest struct { 77 | PageID int32 `form:"page_id" binding:"required,min=1"` 78 | PageSize int32 `form:"page_size" binding:"required,min=5,max=10"` 79 | } 80 | 81 | func (server *Server) listAccounts(ctx *gin.Context) { 82 | var req listAccountRequest 83 | if err := ctx.ShouldBindQuery(&req); err != nil { 84 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 85 | return 86 | } 87 | 88 | authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) 89 | arg := db.ListAccountsParams{ 90 | Owner: authPayload.Username, 91 | Limit: req.PageSize, 92 | Offset: (req.PageID - 1) * req.PageSize, 93 | } 94 | 95 | accounts, err := 
server.store.ListAccounts(ctx, arg) 96 | if err != nil { 97 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 98 | return 99 | } 100 | 101 | ctx.JSON(http.StatusOK, accounts) 102 | } 103 | -------------------------------------------------------------------------------- /api/main_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | "github.com/gin-gonic/gin" 9 | "github.com/stretchr/testify/require" 10 | db "github.com/techschool/simplebank/db/sqlc" 11 | "github.com/techschool/simplebank/util" 12 | ) 13 | 14 | func newTestServer(t *testing.T, store db.Store) *Server { 15 | config := util.Config{ 16 | TokenSymmetricKey: util.RandomString(32), 17 | AccessTokenDuration: time.Minute, 18 | } 19 | 20 | server, err := NewServer(config, store) 21 | require.NoError(t, err) 22 | 23 | return server 24 | } 25 | 26 | func TestMain(m *testing.M) { 27 | gin.SetMode(gin.TestMode) 28 | 29 | os.Exit(m.Run()) 30 | } 31 | -------------------------------------------------------------------------------- /api/middleware.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "strings" 8 | 9 | "github.com/gin-gonic/gin" 10 | "github.com/techschool/simplebank/token" 11 | ) 12 | 13 | const ( 14 | authorizationHeaderKey = "authorization" 15 | authorizationTypeBearer = "bearer" 16 | authorizationPayloadKey = "authorization_payload" 17 | ) 18 | 19 | // AuthMiddleware creates a gin middleware for authorization 20 | func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc { 21 | return func(ctx *gin.Context) { 22 | authorizationHeader := ctx.GetHeader(authorizationHeaderKey) 23 | 24 | if len(authorizationHeader) == 0 { 25 | err := errors.New("authorization header is not provided") 26 | ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) 27 | return 28 
| } 29 | 30 | fields := strings.Fields(authorizationHeader) 31 | if len(fields) < 2 { 32 | err := errors.New("invalid authorization header format") 33 | ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) 34 | return 35 | } 36 | 37 | authorizationType := strings.ToLower(fields[0]) 38 | if authorizationType != authorizationTypeBearer { 39 | err := fmt.Errorf("unsupported authorization type %s", authorizationType) 40 | ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) 41 | return 42 | } 43 | 44 | accessToken := fields[1] 45 | payload, err := tokenMaker.VerifyToken(accessToken, token.TokenTypeAccessToken) 46 | if err != nil { 47 | ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) 48 | return 49 | } 50 | 51 | ctx.Set(authorizationPayloadKey, payload) 52 | ctx.Next() 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /api/middleware_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | "time" 9 | 10 | "github.com/gin-gonic/gin" 11 | "github.com/stretchr/testify/require" 12 | "github.com/techschool/simplebank/token" 13 | "github.com/techschool/simplebank/util" 14 | ) 15 | 16 | func addAuthorization( 17 | t *testing.T, 18 | request *http.Request, 19 | tokenMaker token.Maker, 20 | authorizationType string, 21 | username string, 22 | role string, 23 | duration time.Duration, 24 | ) { 25 | token, payload, err := tokenMaker.CreateToken(username, role, duration, token.TokenTypeAccessToken) 26 | require.NoError(t, err) 27 | require.NotEmpty(t, payload) 28 | 29 | authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) 30 | request.Header.Set(authorizationHeaderKey, authorizationHeader) 31 | } 32 | 33 | func TestAuthMiddleware(t *testing.T) { 34 | username := util.RandomOwner() 35 | role := util.DepositorRole 36 | 37 | testCases 
:= []struct { 38 | name string 39 | setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) 40 | checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) 41 | }{ 42 | { 43 | name: "OK", 44 | setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { 45 | addAuthorization(t, request, tokenMaker, authorizationTypeBearer, username, role, time.Minute) 46 | }, 47 | checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { 48 | require.Equal(t, http.StatusOK, recorder.Code) 49 | }, 50 | }, 51 | { 52 | name: "NoAuthorization", 53 | setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { 54 | }, 55 | checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { 56 | require.Equal(t, http.StatusUnauthorized, recorder.Code) 57 | }, 58 | }, 59 | { 60 | name: "UnsupportedAuthorization", 61 | setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { 62 | addAuthorization(t, request, tokenMaker, "unsupported", username, role, time.Minute) 63 | }, 64 | checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { 65 | require.Equal(t, http.StatusUnauthorized, recorder.Code) 66 | }, 67 | }, 68 | { 69 | name: "InvalidAuthorizationFormat", 70 | setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { 71 | addAuthorization(t, request, tokenMaker, "", username, role, time.Minute) 72 | }, 73 | checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { 74 | require.Equal(t, http.StatusUnauthorized, recorder.Code) 75 | }, 76 | }, 77 | { 78 | name: "ExpiredToken", 79 | setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { 80 | addAuthorization(t, request, tokenMaker, authorizationTypeBearer, username, role, -time.Minute) 81 | }, 82 | checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { 83 | require.Equal(t, http.StatusUnauthorized, recorder.Code) 84 | }, 85 | }, 86 | } 87 | 88 
| for i := range testCases { 89 | tc := testCases[i] 90 | 91 | t.Run(tc.name, func(t *testing.T) { 92 | server := newTestServer(t, nil) 93 | authPath := "/auth" 94 | server.router.GET( 95 | authPath, 96 | authMiddleware(server.tokenMaker), 97 | func(ctx *gin.Context) { 98 | ctx.JSON(http.StatusOK, gin.H{}) 99 | }, 100 | ) 101 | 102 | recorder := httptest.NewRecorder() 103 | request, err := http.NewRequest(http.MethodGet, authPath, nil) 104 | require.NoError(t, err) 105 | 106 | tc.setupAuth(t, request, server.tokenMaker) 107 | server.router.ServeHTTP(recorder, request) 108 | tc.checkResponse(t, recorder) 109 | }) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /api/server.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/gin-gonic/gin" 7 | "github.com/gin-gonic/gin/binding" 8 | "github.com/go-playground/validator/v10" 9 | db "github.com/techschool/simplebank/db/sqlc" 10 | "github.com/techschool/simplebank/token" 11 | "github.com/techschool/simplebank/util" 12 | ) 13 | 14 | // Server serves HTTP requests for our banking service. 15 | type Server struct { 16 | config util.Config 17 | store db.Store 18 | tokenMaker token.Maker 19 | router *gin.Engine 20 | } 21 | 22 | // NewServer creates a new HTTP server and set up routing. 
23 | func NewServer(config util.Config, store db.Store) (*Server, error) { 24 | tokenMaker, err := token.NewPasetoMaker(config.TokenSymmetricKey) 25 | if err != nil { 26 | return nil, fmt.Errorf("cannot create token maker: %w", err) 27 | } 28 | 29 | server := &Server{ 30 | config: config, 31 | store: store, 32 | tokenMaker: tokenMaker, 33 | } 34 | 35 | if v, ok := binding.Validator.Engine().(*validator.Validate); ok { 36 | v.RegisterValidation("currency", validCurrency) 37 | } 38 | 39 | server.setupRouter() 40 | return server, nil 41 | } 42 | 43 | func (server *Server) setupRouter() { 44 | router := gin.Default() 45 | 46 | router.POST("/users", server.createUser) 47 | router.POST("/users/login", server.loginUser) 48 | router.POST("/tokens/renew_access", server.renewAccessToken) 49 | 50 | authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker)) 51 | authRoutes.POST("/accounts", server.createAccount) 52 | authRoutes.GET("/accounts/:id", server.getAccount) 53 | authRoutes.GET("/accounts", server.listAccounts) 54 | 55 | authRoutes.POST("/transfers", server.createTransfer) 56 | 57 | server.router = router 58 | } 59 | 60 | // Start runs the HTTP server on a specific address. 
61 | func (server *Server) Start(address string) error { 62 | return server.router.Run(address) 63 | } 64 | 65 | func errorResponse(err error) gin.H { 66 | return gin.H{"error": err.Error()} 67 | } 68 | -------------------------------------------------------------------------------- /api/token.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/gin-gonic/gin" 10 | db "github.com/techschool/simplebank/db/sqlc" 11 | "github.com/techschool/simplebank/token" 12 | ) 13 | 14 | type renewAccessTokenRequest struct { 15 | RefreshToken string `json:"refresh_token" binding:"required"` 16 | } 17 | 18 | type renewAccessTokenResponse struct { 19 | AccessToken string `json:"access_token"` 20 | AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` 21 | } 22 | 23 | func (server *Server) renewAccessToken(ctx *gin.Context) { 24 | var req renewAccessTokenRequest 25 | if err := ctx.ShouldBindJSON(&req); err != nil { 26 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 27 | return 28 | } 29 | 30 | refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken, token.TokenTypeRefreshToken) 31 | if err != nil { 32 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 33 | return 34 | } 35 | 36 | session, err := server.store.GetSession(ctx, refreshPayload.ID) 37 | if err != nil { 38 | if errors.Is(err, db.ErrRecordNotFound) { 39 | ctx.JSON(http.StatusNotFound, errorResponse(err)) 40 | return 41 | } 42 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 43 | return 44 | } 45 | 46 | if session.IsBlocked { 47 | err := fmt.Errorf("blocked session") 48 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 49 | return 50 | } 51 | 52 | if session.Username != refreshPayload.Username { 53 | err := fmt.Errorf("incorrect session user") 54 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 55 | return 56 | } 57 | 58 | if 
session.RefreshToken != req.RefreshToken { 59 | err := fmt.Errorf("mismatched session token") 60 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 61 | return 62 | } 63 | 64 | if time.Now().After(session.ExpiresAt) { 65 | err := fmt.Errorf("expired session") 66 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 67 | return 68 | } 69 | 70 | accessToken, accessPayload, err := server.tokenMaker.CreateToken( 71 | refreshPayload.Username, 72 | refreshPayload.Role, 73 | server.config.AccessTokenDuration, 74 | token.TokenTypeAccessToken, 75 | ) 76 | if err != nil { 77 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 78 | return 79 | } 80 | 81 | rsp := renewAccessTokenResponse{ 82 | AccessToken: accessToken, 83 | AccessTokenExpiresAt: accessPayload.ExpiredAt, 84 | } 85 | ctx.JSON(http.StatusOK, rsp) 86 | } 87 | -------------------------------------------------------------------------------- /api/transfer.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/gin-gonic/gin" 9 | db "github.com/techschool/simplebank/db/sqlc" 10 | "github.com/techschool/simplebank/token" 11 | ) 12 | 13 | type transferRequest struct { 14 | FromAccountID int64 `json:"from_account_id" binding:"required,min=1"` 15 | ToAccountID int64 `json:"to_account_id" binding:"required,min=1"` 16 | Amount int64 `json:"amount" binding:"required,gt=0"` 17 | Currency string `json:"currency" binding:"required,currency"` 18 | } 19 | 20 | func (server *Server) createTransfer(ctx *gin.Context) { 21 | var req transferRequest 22 | if err := ctx.ShouldBindJSON(&req); err != nil { 23 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 24 | return 25 | } 26 | 27 | fromAccount, valid := server.validAccount(ctx, req.FromAccountID, req.Currency) 28 | if !valid { 29 | return 30 | } 31 | 32 | authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) 33 | if 
fromAccount.Owner != authPayload.Username { 34 | err := errors.New("from account doesn't belong to the authenticated user") 35 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 36 | return 37 | } 38 | 39 | _, valid = server.validAccount(ctx, req.ToAccountID, req.Currency) 40 | if !valid { 41 | return 42 | } 43 | 44 | arg := db.TransferTxParams{ 45 | FromAccountID: req.FromAccountID, 46 | ToAccountID: req.ToAccountID, 47 | Amount: req.Amount, 48 | } 49 | 50 | result, err := server.store.TransferTx(ctx, arg) 51 | if err != nil { 52 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 53 | return 54 | } 55 | 56 | ctx.JSON(http.StatusOK, result) 57 | } 58 | 59 | func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) (db.Account, bool) { 60 | account, err := server.store.GetAccount(ctx, accountID) 61 | if err != nil { 62 | if errors.Is(err, db.ErrRecordNotFound) { 63 | ctx.JSON(http.StatusNotFound, errorResponse(err)) 64 | return account, false 65 | } 66 | 67 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 68 | return account, false 69 | } 70 | 71 | if account.Currency != currency { 72 | err := fmt.Errorf("account [%d] currency mismatch: %s vs %s", account.ID, account.Currency, currency) 73 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 74 | return account, false 75 | } 76 | 77 | return account, true 78 | } 79 | -------------------------------------------------------------------------------- /api/user.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/gin-gonic/gin" 9 | "github.com/google/uuid" 10 | db "github.com/techschool/simplebank/db/sqlc" 11 | "github.com/techschool/simplebank/token" 12 | "github.com/techschool/simplebank/util" 13 | ) 14 | 15 | type createUserRequest struct { 16 | Username string `json:"username" binding:"required,alphanum"` 17 | Password string 
`json:"password" binding:"required,min=6"` 18 | FullName string `json:"full_name" binding:"required"` 19 | Email string `json:"email" binding:"required,email"` 20 | } 21 | 22 | type userResponse struct { 23 | Username string `json:"username"` 24 | FullName string `json:"full_name"` 25 | Email string `json:"email"` 26 | PasswordChangedAt time.Time `json:"password_changed_at"` 27 | CreatedAt time.Time `json:"created_at"` 28 | } 29 | 30 | func newUserResponse(user db.User) userResponse { 31 | return userResponse{ 32 | Username: user.Username, 33 | FullName: user.FullName, 34 | Email: user.Email, 35 | PasswordChangedAt: user.PasswordChangedAt, 36 | CreatedAt: user.CreatedAt, 37 | } 38 | } 39 | 40 | func (server *Server) createUser(ctx *gin.Context) { 41 | var req createUserRequest 42 | if err := ctx.ShouldBindJSON(&req); err != nil { 43 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 44 | return 45 | } 46 | 47 | hashedPassword, err := util.HashPassword(req.Password) 48 | if err != nil { 49 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 50 | return 51 | } 52 | 53 | arg := db.CreateUserParams{ 54 | Username: req.Username, 55 | HashedPassword: hashedPassword, 56 | FullName: req.FullName, 57 | Email: req.Email, 58 | } 59 | 60 | user, err := server.store.CreateUser(ctx, arg) 61 | if err != nil { 62 | if db.ErrorCode(err) == db.UniqueViolation { 63 | ctx.JSON(http.StatusForbidden, errorResponse(err)) 64 | return 65 | } 66 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 67 | return 68 | } 69 | 70 | rsp := newUserResponse(user) 71 | ctx.JSON(http.StatusOK, rsp) 72 | } 73 | 74 | type loginUserRequest struct { 75 | Username string `json:"username" binding:"required,alphanum"` 76 | Password string `json:"password" binding:"required,min=6"` 77 | } 78 | 79 | type loginUserResponse struct { 80 | SessionID uuid.UUID `json:"session_id"` 81 | AccessToken string `json:"access_token"` 82 | AccessTokenExpiresAt time.Time 
`json:"access_token_expires_at"` 83 | RefreshToken string `json:"refresh_token"` 84 | RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"` 85 | User userResponse `json:"user"` 86 | } 87 | 88 | func (server *Server) loginUser(ctx *gin.Context) { 89 | var req loginUserRequest 90 | if err := ctx.ShouldBindJSON(&req); err != nil { 91 | ctx.JSON(http.StatusBadRequest, errorResponse(err)) 92 | return 93 | } 94 | 95 | user, err := server.store.GetUser(ctx, req.Username) 96 | if err != nil { 97 | if errors.Is(err, db.ErrRecordNotFound) { 98 | ctx.JSON(http.StatusNotFound, errorResponse(err)) 99 | return 100 | } 101 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 102 | return 103 | } 104 | 105 | err = util.CheckPassword(req.Password, user.HashedPassword) 106 | if err != nil { 107 | ctx.JSON(http.StatusUnauthorized, errorResponse(err)) 108 | return 109 | } 110 | 111 | accessToken, accessPayload, err := server.tokenMaker.CreateToken( 112 | user.Username, 113 | user.Role, 114 | server.config.AccessTokenDuration, 115 | token.TokenTypeAccessToken, 116 | ) 117 | if err != nil { 118 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 119 | return 120 | } 121 | 122 | refreshToken, refreshPayload, err := server.tokenMaker.CreateToken( 123 | user.Username, 124 | user.Role, 125 | server.config.RefreshTokenDuration, 126 | token.TokenTypeRefreshToken, 127 | ) 128 | if err != nil { 129 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 130 | return 131 | } 132 | 133 | session, err := server.store.CreateSession(ctx, db.CreateSessionParams{ 134 | ID: refreshPayload.ID, 135 | Username: user.Username, 136 | RefreshToken: refreshToken, 137 | UserAgent: ctx.Request.UserAgent(), 138 | ClientIp: ctx.ClientIP(), 139 | IsBlocked: false, 140 | ExpiresAt: refreshPayload.ExpiredAt, 141 | }) 142 | if err != nil { 143 | ctx.JSON(http.StatusInternalServerError, errorResponse(err)) 144 | return 145 | } 146 | 147 | rsp := loginUserResponse{ 148 | 
SessionID: session.ID, 149 | AccessToken: accessToken, 150 | AccessTokenExpiresAt: accessPayload.ExpiredAt, 151 | RefreshToken: refreshToken, 152 | RefreshTokenExpiresAt: refreshPayload.ExpiredAt, 153 | User: newUserResponse(user), 154 | } 155 | ctx.JSON(http.StatusOK, rsp) 156 | } 157 | -------------------------------------------------------------------------------- /api/validator.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "github.com/go-playground/validator/v10" 5 | "github.com/techschool/simplebank/util" 6 | ) 7 | 8 | var validCurrency validator.Func = func(fieldLevel validator.FieldLevel) bool { 9 | if currency, ok := fieldLevel.Field().Interface().(string); ok { 10 | return util.IsSupportedCurrency(currency) 11 | } 12 | return false 13 | } 14 | -------------------------------------------------------------------------------- /app.env: -------------------------------------------------------------------------------- 1 | ENVIRONMENT=development 2 | ALLOWED_ORIGINS=http://localhost:3000,https://simplebank.com 3 | DB_SOURCE=postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable 4 | MIGRATION_URL=file://db/migration 5 | HTTP_SERVER_ADDRESS=0.0.0.0:8080 6 | GRPC_SERVER_ADDRESS=0.0.0.0:9090 7 | TOKEN_SYMMETRIC_KEY=12345678901234567890123456789012 8 | ACCESS_TOKEN_DURATION=1m 9 | REFRESH_TOKEN_DURATION=24h 10 | REDIS_ADDRESS=0.0.0.0:6379 11 | EMAIL_SENDER_NAME=Simple Bank 12 | EMAIL_SENDER_ADDRESS=simplebanktest@gmail.com 13 | EMAIL_SENDER_PASSWORD=jekfcygyenvzekke 14 | -------------------------------------------------------------------------------- /backend-master.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/techschool/simplebank/97f000fe58ad01a0774179ffa8884ac7784cf263/backend-master.png -------------------------------------------------------------------------------- 
/db/migration/000001_init_schema.down.sql:
--------------------------------------------------------------------------------
1 | DROP TABLE IF EXISTS entries; -- references accounts through a foreign key (see the up migration)
2 | DROP TABLE IF EXISTS transfers; -- references accounts twice: from_account_id and to_account_id
3 | DROP TABLE IF EXISTS accounts; -- dropped last, after both tables that reference it
4 |
--------------------------------------------------------------------------------
/db/migration/000001_init_schema.up.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE "accounts" (
2 |   "id" bigserial PRIMARY KEY,
3 |   "owner" varchar NOT NULL, -- gains a foreign key to users.username in migration 000002
4 |   "balance" bigint NOT NULL,
5 |   "currency" varchar NOT NULL,
6 |   "created_at" timestamptz NOT NULL DEFAULT (now())
7 | );
8 |
9 | CREATE TABLE "entries" (
10 |   "id" bigserial PRIMARY KEY,
11 |   "account_id" bigint NOT NULL,
12 |   "amount" bigint NOT NULL, -- signed; see the column comment at the end of this file
13 |   "created_at" timestamptz NOT NULL DEFAULT (now())
14 | );
15 |
16 | CREATE TABLE "transfers" (
17 |   "id" bigserial PRIMARY KEY,
18 |   "from_account_id" bigint NOT NULL,
19 |   "to_account_id" bigint NOT NULL,
20 |   "amount" bigint NOT NULL, -- see the column comment at the end of this file
21 |   "created_at" timestamptz NOT NULL DEFAULT (now())
22 | );
23 |
24 | ALTER TABLE "entries" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id");
25 |
26 | ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");
27 |
28 | ALTER TABLE "transfers" ADD FOREIGN KEY ("to_account_id") REFERENCES "accounts" ("id");
29 |
30 | CREATE INDEX ON "accounts" ("owner"); -- the ListAccounts query filters on owner
31 |
32 | CREATE INDEX ON "entries" ("account_id"); -- the ListEntries query filters on account_id
33 |
34 | CREATE INDEX ON "transfers" ("from_account_id"); -- the ListTransfers query filters on from_account_id OR to_account_id
35 |
36 | CREATE INDEX ON "transfers" ("to_account_id");
37 |
38 | CREATE INDEX ON "transfers" ("from_account_id", "to_account_id"); -- composite index over the (from, to) pair
39 |
40 | COMMENT ON COLUMN "entries"."amount" IS 'can be negative or positive';
41 |
42 | COMMENT ON COLUMN "transfers"."amount" IS 'must be positive'; -- documentation only: this migration adds no CHECK constraint for it
43 |
--------------------------------------------------------------------------------
/db/migration/000002_add_users.down.sql:
--------------------------------------------------------------------------------
1 | ALTER TABLE IF EXISTS "accounts" DROP CONSTRAINT IF EXISTS "owner_currency_key"; -- the UNIQUE (owner, currency) constraint named in the up migration
2 |
3 | ALTER TABLE IF EXISTS "accounts" DROP CONSTRAINT IF EXISTS "accounts_owner_fkey"; -- NOTE(review): presumably the auto-generated name of the accounts.owner foreign key; confirm against the live schema
4 |
5 | DROP TABLE IF EXISTS "users"; -- dropped only after the constraints on accounts that depend on it
6 |
--------------------------------------------------------------------------------
/db/migration/000002_add_users.up.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE "users" (
2 |   "username" varchar PRIMARY KEY,
3 |   "hashed_password" varchar NOT NULL,
4 |   "full_name" varchar NOT NULL,
5 |   "email" varchar UNIQUE NOT NULL,
6 |   "password_changed_at" timestamptz NOT NULL DEFAULT('0001-01-01 00:00:00Z'), -- NOTE(review): looks like a "never changed" sentinel; confirm with the callers
7 |   "created_at" timestamptz NOT NULL DEFAULT (now())
8 | );
9 |
10 | ALTER TABLE "accounts" ADD FOREIGN KEY ("owner") REFERENCES "users" ("username");
11 |
12 | -- Alternative that was not used: CREATE UNIQUE INDEX ON "accounts" ("owner", "currency");
13 | ALTER TABLE "accounts" ADD CONSTRAINT "owner_currency_key" UNIQUE ("owner", "currency"); -- at most one account per (owner, currency) pair
14 |
--------------------------------------------------------------------------------
/db/migration/000003_add_sessions.down.sql:
--------------------------------------------------------------------------------
1 | DROP TABLE IF EXISTS "sessions";
2 |
--------------------------------------------------------------------------------
/db/migration/000003_add_sessions.up.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE "sessions" (
2 |   "id" uuid PRIMARY KEY, -- supplied by the caller: CreateSession inserts id as its first parameter
3 |   "username" varchar NOT NULL,
4 |   "refresh_token" varchar NOT NULL,
5 |   "user_agent" varchar NOT NULL,
6 |   "client_ip" varchar NOT NULL,
7 |   "is_blocked" boolean NOT NULL DEFAULT false,
8 |   "expires_at" timestamptz NOT NULL,
9 |   "created_at" timestamptz NOT NULL DEFAULT (now())
10 | );
11 |
12 | ALTER TABLE "sessions" ADD FOREIGN KEY ("username") REFERENCES "users" ("username");
13 |
--------------------------------------------------------------------------------
/db/migration/000004_add_verify_emails.down.sql:
--------------------------------------------------------------------------------
1 | -- Reverts 000004: removes the verify_emails table and the users.is_email_verified column.
2 | -- Every statement is guarded with IF EXISTS, like the 000002 down migration, so the
3 | -- rollback can be re-run after a partial failure without erroring.
4 | DROP TABLE IF EXISTS "verify_emails" CASCADE;
5 |
6 | ALTER TABLE IF EXISTS "users" DROP COLUMN IF EXISTS "is_email_verified";
7 |
--------------------------------------------------------------------------------
/db/migration/000004_add_verify_emails.up.sql:
--------------------------------------------------------------------------------
1 | -- Verification codes issued per user and email address; is_used marks a consumed code.
2 | CREATE TABLE "verify_emails" (
3 |   "id" bigserial PRIMARY KEY,
4 |   "username" varchar NOT NULL,
5 |   "email" varchar NOT NULL,
6 |   "secret_code" varchar NOT NULL,
7 |   "is_used" bool NOT NULL DEFAULT false,
8 |   "created_at" timestamptz NOT NULL DEFAULT (now()),
9 |   "expired_at" timestamptz NOT NULL DEFAULT (now() + interval '15 minutes') -- a code expires 15 minutes after insertion unless set explicitly
10 | );
11 |
12 | ALTER TABLE "verify_emails" ADD FOREIGN KEY ("username") REFERENCES "users" ("username");
13 |
14 | -- Existing users start out unverified.
15 | ALTER TABLE "users" ADD COLUMN "is_email_verified" bool NOT NULL DEFAULT false;
16 |
--------------------------------------------------------------------------------
/db/migration/000005_add_role_to_users.down.sql:
--------------------------------------------------------------------------------
1 | -- Reverts 000005; IF EXISTS keeps the rollback re-runnable.
2 | ALTER TABLE IF EXISTS "users" DROP COLUMN IF EXISTS "role";
3 |
--------------------------------------------------------------------------------
/db/migration/000005_add_role_to_users.up.sql:
--------------------------------------------------------------------------------
1 | -- Every existing and new user gets the 'depositor' role unless one is set explicitly.
2 | ALTER TABLE "users" ADD COLUMN "role" varchar NOT NULL DEFAULT 'depositor';
3 |
--------------------------------------------------------------------------------
/db/query/account.sql:
--------------------------------------------------------------------------------
1 | -- name: CreateAccount :one
2 | INSERT INTO accounts (
3 |   owner,
4 |   balance,
5 |   currency
6 | ) VALUES (
7 |   $1, $2, $3
8 | ) RETURNING *;
9 |
10 | -- name: GetAccount :one
11 |
SELECT * FROM accounts 12 | WHERE id = $1 LIMIT 1; 13 | 14 | -- name: GetAccountForUpdate :one 15 | SELECT * FROM accounts 16 | WHERE id = $1 LIMIT 1 17 | FOR NO KEY UPDATE; 18 | 19 | -- name: ListAccounts :many 20 | SELECT * FROM accounts 21 | WHERE owner = $1 22 | ORDER BY id 23 | LIMIT $2 24 | OFFSET $3; 25 | 26 | -- name: UpdateAccount :one 27 | UPDATE accounts 28 | SET balance = $2 29 | WHERE id = $1 30 | RETURNING *; 31 | 32 | -- name: AddAccountBalance :one 33 | UPDATE accounts 34 | SET balance = balance + sqlc.arg(amount) 35 | WHERE id = sqlc.arg(id) 36 | RETURNING *; 37 | 38 | -- name: DeleteAccount :exec 39 | DELETE FROM accounts 40 | WHERE id = $1; 41 | -------------------------------------------------------------------------------- /db/query/entry.sql: -------------------------------------------------------------------------------- 1 | -- name: CreateEntry :one 2 | INSERT INTO entries ( 3 | account_id, 4 | amount 5 | ) VALUES ( 6 | $1, $2 7 | ) RETURNING *; 8 | 9 | -- name: GetEntry :one 10 | SELECT * FROM entries 11 | WHERE id = $1 LIMIT 1; 12 | 13 | -- name: ListEntries :many 14 | SELECT * FROM entries 15 | WHERE account_id = $1 16 | ORDER BY id 17 | LIMIT $2 18 | OFFSET $3; 19 | -------------------------------------------------------------------------------- /db/query/session.sql: -------------------------------------------------------------------------------- 1 | -- name: CreateSession :one 2 | INSERT INTO sessions ( 3 | id, 4 | username, 5 | refresh_token, 6 | user_agent, 7 | client_ip, 8 | is_blocked, 9 | expires_at 10 | ) VALUES ( 11 | $1, $2, $3, $4, $5, $6, $7 12 | ) RETURNING *; 13 | 14 | -- name: GetSession :one 15 | SELECT * FROM sessions 16 | WHERE id = $1 LIMIT 1; 17 | -------------------------------------------------------------------------------- /db/query/transfer.sql: -------------------------------------------------------------------------------- 1 | -- name: CreateTransfer :one 2 | INSERT INTO transfers ( 3 | from_account_id, 4 | 
to_account_id, 5 | amount 6 | ) VALUES ( 7 | $1, $2, $3 8 | ) RETURNING *; 9 | 10 | -- name: GetTransfer :one 11 | SELECT * FROM transfers 12 | WHERE id = $1 LIMIT 1; 13 | 14 | -- name: ListTransfers :many 15 | SELECT * FROM transfers 16 | WHERE 17 | from_account_id = $1 OR 18 | to_account_id = $2 19 | ORDER BY id 20 | LIMIT $3 21 | OFFSET $4; 22 | -------------------------------------------------------------------------------- /db/query/user.sql: -------------------------------------------------------------------------------- 1 | -- name: CreateUser :one 2 | INSERT INTO users ( 3 | username, 4 | hashed_password, 5 | full_name, 6 | email 7 | ) VALUES ( 8 | $1, $2, $3, $4 9 | ) RETURNING *; 10 | 11 | -- name: GetUser :one 12 | SELECT * FROM users 13 | WHERE username = $1 LIMIT 1; 14 | 15 | -- name: UpdateUser :one 16 | UPDATE users 17 | SET 18 | hashed_password = COALESCE(sqlc.narg(hashed_password), hashed_password), 19 | password_changed_at = COALESCE(sqlc.narg(password_changed_at), password_changed_at), 20 | full_name = COALESCE(sqlc.narg(full_name), full_name), 21 | email = COALESCE(sqlc.narg(email), email), 22 | is_email_verified = COALESCE(sqlc.narg(is_email_verified), is_email_verified) 23 | WHERE 24 | username = sqlc.arg(username) 25 | RETURNING *; 26 | -------------------------------------------------------------------------------- /db/query/verify_email.sql: -------------------------------------------------------------------------------- 1 | -- name: CreateVerifyEmail :one 2 | INSERT INTO verify_emails ( 3 | username, 4 | email, 5 | secret_code 6 | ) VALUES ( 7 | $1, $2, $3 8 | ) RETURNING *; 9 | 10 | -- name: UpdateVerifyEmail :one 11 | UPDATE verify_emails 12 | SET 13 | is_used = TRUE 14 | WHERE 15 | id = @id 16 | AND secret_code = @secret_code 17 | AND is_used = FALSE 18 | AND expired_at > now() 19 | RETURNING *; 20 | -------------------------------------------------------------------------------- /db/sqlc/account.sql.go: 
-------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.22.0 4 | // source: account.sql 5 | 6 | package db 7 | 8 | import ( 9 | "context" 10 | ) 11 | 12 | const addAccountBalance = `-- name: AddAccountBalance :one 13 | UPDATE accounts 14 | SET balance = balance + $1 15 | WHERE id = $2 16 | RETURNING id, owner, balance, currency, created_at 17 | ` 18 | 19 | type AddAccountBalanceParams struct { 20 | Amount int64 `json:"amount"` 21 | ID int64 `json:"id"` 22 | } 23 | 24 | func (q *Queries) AddAccountBalance(ctx context.Context, arg AddAccountBalanceParams) (Account, error) { 25 | row := q.db.QueryRow(ctx, addAccountBalance, arg.Amount, arg.ID) 26 | var i Account 27 | err := row.Scan( 28 | &i.ID, 29 | &i.Owner, 30 | &i.Balance, 31 | &i.Currency, 32 | &i.CreatedAt, 33 | ) 34 | return i, err 35 | } 36 | 37 | const createAccount = `-- name: CreateAccount :one 38 | INSERT INTO accounts ( 39 | owner, 40 | balance, 41 | currency 42 | ) VALUES ( 43 | $1, $2, $3 44 | ) RETURNING id, owner, balance, currency, created_at 45 | ` 46 | 47 | type CreateAccountParams struct { 48 | Owner string `json:"owner"` 49 | Balance int64 `json:"balance"` 50 | Currency string `json:"currency"` 51 | } 52 | 53 | func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) { 54 | row := q.db.QueryRow(ctx, createAccount, arg.Owner, arg.Balance, arg.Currency) 55 | var i Account 56 | err := row.Scan( 57 | &i.ID, 58 | &i.Owner, 59 | &i.Balance, 60 | &i.Currency, 61 | &i.CreatedAt, 62 | ) 63 | return i, err 64 | } 65 | 66 | const deleteAccount = `-- name: DeleteAccount :exec 67 | DELETE FROM accounts 68 | WHERE id = $1 69 | ` 70 | 71 | func (q *Queries) DeleteAccount(ctx context.Context, id int64) error { 72 | _, err := q.db.Exec(ctx, deleteAccount, id) 73 | return err 74 | } 75 | 76 | const getAccount = `-- name: GetAccount :one 77 | SELECT id, owner, balance, currency, 
created_at FROM accounts 78 | WHERE id = $1 LIMIT 1 79 | ` 80 | 81 | func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { 82 | row := q.db.QueryRow(ctx, getAccount, id) 83 | var i Account 84 | err := row.Scan( 85 | &i.ID, 86 | &i.Owner, 87 | &i.Balance, 88 | &i.Currency, 89 | &i.CreatedAt, 90 | ) 91 | return i, err 92 | } 93 | 94 | const getAccountForUpdate = `-- name: GetAccountForUpdate :one 95 | SELECT id, owner, balance, currency, created_at FROM accounts 96 | WHERE id = $1 LIMIT 1 97 | FOR NO KEY UPDATE 98 | ` 99 | 100 | func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) { 101 | row := q.db.QueryRow(ctx, getAccountForUpdate, id) 102 | var i Account 103 | err := row.Scan( 104 | &i.ID, 105 | &i.Owner, 106 | &i.Balance, 107 | &i.Currency, 108 | &i.CreatedAt, 109 | ) 110 | return i, err 111 | } 112 | 113 | const listAccounts = `-- name: ListAccounts :many 114 | SELECT id, owner, balance, currency, created_at FROM accounts 115 | WHERE owner = $1 116 | ORDER BY id 117 | LIMIT $2 118 | OFFSET $3 119 | ` 120 | 121 | type ListAccountsParams struct { 122 | Owner string `json:"owner"` 123 | Limit int32 `json:"limit"` 124 | Offset int32 `json:"offset"` 125 | } 126 | 127 | func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) { 128 | rows, err := q.db.Query(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) 129 | if err != nil { 130 | return nil, err 131 | } 132 | defer rows.Close() 133 | items := []Account{} 134 | for rows.Next() { 135 | var i Account 136 | if err := rows.Scan( 137 | &i.ID, 138 | &i.Owner, 139 | &i.Balance, 140 | &i.Currency, 141 | &i.CreatedAt, 142 | ); err != nil { 143 | return nil, err 144 | } 145 | items = append(items, i) 146 | } 147 | if err := rows.Err(); err != nil { 148 | return nil, err 149 | } 150 | return items, nil 151 | } 152 | 153 | const updateAccount = `-- name: UpdateAccount :one 154 | UPDATE accounts 155 | SET balance = $2 156 | WHERE 
id = $1 157 | RETURNING id, owner, balance, currency, created_at 158 | ` 159 | 160 | type UpdateAccountParams struct { 161 | ID int64 `json:"id"` 162 | Balance int64 `json:"balance"` 163 | } 164 | 165 | func (q *Queries) UpdateAccount(ctx context.Context, arg UpdateAccountParams) (Account, error) { 166 | row := q.db.QueryRow(ctx, updateAccount, arg.ID, arg.Balance) 167 | var i Account 168 | err := row.Scan( 169 | &i.ID, 170 | &i.Owner, 171 | &i.Balance, 172 | &i.Currency, 173 | &i.CreatedAt, 174 | ) 175 | return i, err 176 | } 177 | -------------------------------------------------------------------------------- /db/sqlc/account_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/techschool/simplebank/util" 10 | ) 11 | 12 | func createRandomAccount(t *testing.T) Account { 13 | user := createRandomUser(t) 14 | 15 | arg := CreateAccountParams{ 16 | Owner: user.Username, 17 | Balance: util.RandomMoney(), 18 | Currency: util.RandomCurrency(), 19 | } 20 | 21 | account, err := testStore.CreateAccount(context.Background(), arg) 22 | require.NoError(t, err) 23 | require.NotEmpty(t, account) 24 | 25 | require.Equal(t, arg.Owner, account.Owner) 26 | require.Equal(t, arg.Balance, account.Balance) 27 | require.Equal(t, arg.Currency, account.Currency) 28 | 29 | require.NotZero(t, account.ID) 30 | require.NotZero(t, account.CreatedAt) 31 | 32 | return account 33 | } 34 | 35 | func TestCreateAccount(t *testing.T) { 36 | createRandomAccount(t) 37 | } 38 | 39 | func TestGetAccount(t *testing.T) { 40 | account1 := createRandomAccount(t) 41 | account2, err := testStore.GetAccount(context.Background(), account1.ID) 42 | require.NoError(t, err) 43 | require.NotEmpty(t, account2) 44 | 45 | require.Equal(t, account1.ID, account2.ID) 46 | require.Equal(t, account1.Owner, account2.Owner) 47 | require.Equal(t, 
account1.Balance, account2.Balance) 48 | require.Equal(t, account1.Currency, account2.Currency) 49 | require.WithinDuration(t, account1.CreatedAt, account2.CreatedAt, time.Second) 50 | } 51 | 52 | func TestUpdateAccount(t *testing.T) { 53 | account1 := createRandomAccount(t) 54 | 55 | arg := UpdateAccountParams{ 56 | ID: account1.ID, 57 | Balance: util.RandomMoney(), 58 | } 59 | 60 | account2, err := testStore.UpdateAccount(context.Background(), arg) 61 | require.NoError(t, err) 62 | require.NotEmpty(t, account2) 63 | 64 | require.Equal(t, account1.ID, account2.ID) 65 | require.Equal(t, account1.Owner, account2.Owner) 66 | require.Equal(t, arg.Balance, account2.Balance) 67 | require.Equal(t, account1.Currency, account2.Currency) 68 | require.WithinDuration(t, account1.CreatedAt, account2.CreatedAt, time.Second) 69 | } 70 | 71 | func TestDeleteAccount(t *testing.T) { 72 | account1 := createRandomAccount(t) 73 | err := testStore.DeleteAccount(context.Background(), account1.ID) 74 | require.NoError(t, err) 75 | 76 | account2, err := testStore.GetAccount(context.Background(), account1.ID) 77 | require.Error(t, err) 78 | require.EqualError(t, err, ErrRecordNotFound.Error()) 79 | require.Empty(t, account2) 80 | } 81 | 82 | func TestListAccounts(t *testing.T) { 83 | var lastAccount Account 84 | for i := 0; i < 10; i++ { 85 | lastAccount = createRandomAccount(t) 86 | } 87 | 88 | arg := ListAccountsParams{ 89 | Owner: lastAccount.Owner, 90 | Limit: 5, 91 | Offset: 0, 92 | } 93 | 94 | accounts, err := testStore.ListAccounts(context.Background(), arg) 95 | require.NoError(t, err) 96 | require.NotEmpty(t, accounts) 97 | 98 | for _, account := range accounts { 99 | require.NotEmpty(t, account) 100 | require.Equal(t, lastAccount.Owner, account.Owner) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /db/sqlc/db.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. 
DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.22.0 4 | 5 | package db 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/jackc/pgx/v5" 11 | "github.com/jackc/pgx/v5/pgconn" 12 | ) 13 | 14 | type DBTX interface { 15 | Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) 16 | Query(context.Context, string, ...interface{}) (pgx.Rows, error) 17 | QueryRow(context.Context, string, ...interface{}) pgx.Row 18 | } 19 | 20 | func New(db DBTX) *Queries { 21 | return &Queries{db: db} 22 | } 23 | 24 | type Queries struct { 25 | db DBTX 26 | } 27 | 28 | func (q *Queries) WithTx(tx pgx.Tx) *Queries { 29 | return &Queries{ 30 | db: tx, 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /db/sqlc/entry.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.22.0 4 | // source: entry.sql 5 | 6 | package db 7 | 8 | import ( 9 | "context" 10 | ) 11 | 12 | const createEntry = `-- name: CreateEntry :one 13 | INSERT INTO entries ( 14 | account_id, 15 | amount 16 | ) VALUES ( 17 | $1, $2 18 | ) RETURNING id, account_id, amount, created_at 19 | ` 20 | 21 | type CreateEntryParams struct { 22 | AccountID int64 `json:"account_id"` 23 | Amount int64 `json:"amount"` 24 | } 25 | 26 | func (q *Queries) CreateEntry(ctx context.Context, arg CreateEntryParams) (Entry, error) { 27 | row := q.db.QueryRow(ctx, createEntry, arg.AccountID, arg.Amount) 28 | var i Entry 29 | err := row.Scan( 30 | &i.ID, 31 | &i.AccountID, 32 | &i.Amount, 33 | &i.CreatedAt, 34 | ) 35 | return i, err 36 | } 37 | 38 | const getEntry = `-- name: GetEntry :one 39 | SELECT id, account_id, amount, created_at FROM entries 40 | WHERE id = $1 LIMIT 1 41 | ` 42 | 43 | func (q *Queries) GetEntry(ctx context.Context, id int64) (Entry, error) { 44 | row := q.db.QueryRow(ctx, getEntry, id) 45 | var i Entry 46 | err := row.Scan( 47 | &i.ID, 48 | &i.AccountID, 49 | 
&i.Amount, 50 | &i.CreatedAt, 51 | ) 52 | return i, err 53 | } 54 | 55 | const listEntries = `-- name: ListEntries :many 56 | SELECT id, account_id, amount, created_at FROM entries 57 | WHERE account_id = $1 58 | ORDER BY id 59 | LIMIT $2 60 | OFFSET $3 61 | ` 62 | 63 | type ListEntriesParams struct { 64 | AccountID int64 `json:"account_id"` 65 | Limit int32 `json:"limit"` 66 | Offset int32 `json:"offset"` 67 | } 68 | 69 | func (q *Queries) ListEntries(ctx context.Context, arg ListEntriesParams) ([]Entry, error) { 70 | rows, err := q.db.Query(ctx, listEntries, arg.AccountID, arg.Limit, arg.Offset) 71 | if err != nil { 72 | return nil, err 73 | } 74 | defer rows.Close() 75 | items := []Entry{} 76 | for rows.Next() { 77 | var i Entry 78 | if err := rows.Scan( 79 | &i.ID, 80 | &i.AccountID, 81 | &i.Amount, 82 | &i.CreatedAt, 83 | ); err != nil { 84 | return nil, err 85 | } 86 | items = append(items, i) 87 | } 88 | if err := rows.Err(); err != nil { 89 | return nil, err 90 | } 91 | return items, nil 92 | } 93 | -------------------------------------------------------------------------------- /db/sqlc/entry_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/techschool/simplebank/util" 10 | ) 11 | 12 | func createRandomEntry(t *testing.T, account Account) Entry { 13 | arg := CreateEntryParams{ 14 | AccountID: account.ID, 15 | Amount: util.RandomMoney(), 16 | } 17 | 18 | entry, err := testStore.CreateEntry(context.Background(), arg) 19 | require.NoError(t, err) 20 | require.NotEmpty(t, entry) 21 | 22 | require.Equal(t, arg.AccountID, entry.AccountID) 23 | require.Equal(t, arg.Amount, entry.Amount) 24 | 25 | require.NotZero(t, entry.ID) 26 | require.NotZero(t, entry.CreatedAt) 27 | 28 | return entry 29 | } 30 | 31 | func TestCreateEntry(t *testing.T) { 32 | account := createRandomAccount(t) 33 | 
createRandomEntry(t, account) 34 | } 35 | 36 | func TestGetEntry(t *testing.T) { 37 | account := createRandomAccount(t) 38 | entry1 := createRandomEntry(t, account) 39 | entry2, err := testStore.GetEntry(context.Background(), entry1.ID) 40 | require.NoError(t, err) 41 | require.NotEmpty(t, entry2) 42 | 43 | require.Equal(t, entry1.ID, entry2.ID) 44 | require.Equal(t, entry1.AccountID, entry2.AccountID) 45 | require.Equal(t, entry1.Amount, entry2.Amount) 46 | require.WithinDuration(t, entry1.CreatedAt, entry2.CreatedAt, time.Second) 47 | } 48 | 49 | func TestListEntries(t *testing.T) { 50 | account := createRandomAccount(t) 51 | for i := 0; i < 10; i++ { 52 | createRandomEntry(t, account) 53 | } 54 | 55 | arg := ListEntriesParams{ 56 | AccountID: account.ID, 57 | Limit: 5, 58 | Offset: 5, 59 | } 60 | 61 | entries, err := testStore.ListEntries(context.Background(), arg) 62 | require.NoError(t, err) 63 | require.Len(t, entries, 5) 64 | 65 | for _, entry := range entries { 66 | require.NotEmpty(t, entry) 67 | require.Equal(t, arg.AccountID, entry.AccountID) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /db/sqlc/error.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/jackc/pgx/v5" 7 | "github.com/jackc/pgx/v5/pgconn" 8 | ) 9 | 10 | const ( 11 | ForeignKeyViolation = "23503" 12 | UniqueViolation = "23505" 13 | ) 14 | 15 | var ErrRecordNotFound = pgx.ErrNoRows 16 | 17 | var ErrUniqueViolation = &pgconn.PgError{ 18 | Code: UniqueViolation, 19 | } 20 | 21 | func ErrorCode(err error) string { 22 | var pgErr *pgconn.PgError 23 | if errors.As(err, &pgErr) { 24 | return pgErr.Code 25 | } 26 | return "" 27 | } 28 | -------------------------------------------------------------------------------- /db/sqlc/exec_tx.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 
| "context" 5 | "fmt" 6 | ) 7 | 8 | // ExecTx executes a function within a database transaction 9 | func (store *SQLStore) execTx(ctx context.Context, fn func(*Queries) error) error { 10 | tx, err := store.connPool.Begin(ctx) 11 | if err != nil { 12 | return err 13 | } 14 | 15 | q := New(tx) 16 | err = fn(q) 17 | if err != nil { 18 | if rbErr := tx.Rollback(ctx); rbErr != nil { 19 | return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr) 20 | } 21 | return err 22 | } 23 | 24 | return tx.Commit(ctx) 25 | } 26 | -------------------------------------------------------------------------------- /db/sqlc/main_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | "testing" 8 | 9 | "github.com/jackc/pgx/v5/pgxpool" 10 | "github.com/techschool/simplebank/util" 11 | ) 12 | 13 | var testStore Store 14 | 15 | func TestMain(m *testing.M) { 16 | config, err := util.LoadConfig("../..") 17 | if err != nil { 18 | log.Fatal("cannot load config:", err) 19 | } 20 | 21 | connPool, err := pgxpool.New(context.Background(), config.DBSource) 22 | if err != nil { 23 | log.Fatal("cannot connect to db:", err) 24 | } 25 | 26 | testStore = NewStore(connPool) 27 | os.Exit(m.Run()) 28 | } 29 | -------------------------------------------------------------------------------- /db/sqlc/models.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 
2 | // versions: 3 | // sqlc v1.22.0 4 | 5 | package db 6 | 7 | import ( 8 | "time" 9 | 10 | "github.com/google/uuid" 11 | ) 12 | 13 | type Account struct { 14 | ID int64 `json:"id"` 15 | Owner string `json:"owner"` 16 | Balance int64 `json:"balance"` 17 | Currency string `json:"currency"` 18 | CreatedAt time.Time `json:"created_at"` 19 | } 20 | 21 | type Entry struct { 22 | ID int64 `json:"id"` 23 | AccountID int64 `json:"account_id"` 24 | // can be negative or positive 25 | Amount int64 `json:"amount"` 26 | CreatedAt time.Time `json:"created_at"` 27 | } 28 | 29 | type Session struct { 30 | ID uuid.UUID `json:"id"` 31 | Username string `json:"username"` 32 | RefreshToken string `json:"refresh_token"` 33 | UserAgent string `json:"user_agent"` 34 | ClientIp string `json:"client_ip"` 35 | IsBlocked bool `json:"is_blocked"` 36 | ExpiresAt time.Time `json:"expires_at"` 37 | CreatedAt time.Time `json:"created_at"` 38 | } 39 | 40 | type Transfer struct { 41 | ID int64 `json:"id"` 42 | FromAccountID int64 `json:"from_account_id"` 43 | ToAccountID int64 `json:"to_account_id"` 44 | // must be positive 45 | Amount int64 `json:"amount"` 46 | CreatedAt time.Time `json:"created_at"` 47 | } 48 | 49 | type User struct { 50 | Username string `json:"username"` 51 | HashedPassword string `json:"hashed_password"` 52 | FullName string `json:"full_name"` 53 | Email string `json:"email"` 54 | PasswordChangedAt time.Time `json:"password_changed_at"` 55 | CreatedAt time.Time `json:"created_at"` 56 | IsEmailVerified bool `json:"is_email_verified"` 57 | Role string `json:"role"` 58 | } 59 | 60 | type VerifyEmail struct { 61 | ID int64 `json:"id"` 62 | Username string `json:"username"` 63 | Email string `json:"email"` 64 | SecretCode string `json:"secret_code"` 65 | IsUsed bool `json:"is_used"` 66 | CreatedAt time.Time `json:"created_at"` 67 | ExpiredAt time.Time `json:"expired_at"` 68 | } 69 | -------------------------------------------------------------------------------- 
/db/sqlc/querier.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.22.0 4 | 5 | package db 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/google/uuid" 11 | ) 12 | 13 | type Querier interface { 14 | AddAccountBalance(ctx context.Context, arg AddAccountBalanceParams) (Account, error) 15 | CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) 16 | CreateEntry(ctx context.Context, arg CreateEntryParams) (Entry, error) 17 | CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) 18 | CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) 19 | CreateUser(ctx context.Context, arg CreateUserParams) (User, error) 20 | CreateVerifyEmail(ctx context.Context, arg CreateVerifyEmailParams) (VerifyEmail, error) 21 | DeleteAccount(ctx context.Context, id int64) error 22 | GetAccount(ctx context.Context, id int64) (Account, error) 23 | GetAccountForUpdate(ctx context.Context, id int64) (Account, error) 24 | GetEntry(ctx context.Context, id int64) (Entry, error) 25 | GetSession(ctx context.Context, id uuid.UUID) (Session, error) 26 | GetTransfer(ctx context.Context, id int64) (Transfer, error) 27 | GetUser(ctx context.Context, username string) (User, error) 28 | ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) 29 | ListEntries(ctx context.Context, arg ListEntriesParams) ([]Entry, error) 30 | ListTransfers(ctx context.Context, arg ListTransfersParams) ([]Transfer, error) 31 | UpdateAccount(ctx context.Context, arg UpdateAccountParams) (Account, error) 32 | UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) 33 | UpdateVerifyEmail(ctx context.Context, arg UpdateVerifyEmailParams) (VerifyEmail, error) 34 | } 35 | 36 | var _ Querier = (*Queries)(nil) 37 | -------------------------------------------------------------------------------- /db/sqlc/session.sql.go: 
-------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.22.0 4 | // source: session.sql 5 | 6 | package db 7 | 8 | import ( 9 | "context" 10 | "time" 11 | 12 | "github.com/google/uuid" 13 | ) 14 | 15 | const createSession = `-- name: CreateSession :one 16 | INSERT INTO sessions ( 17 | id, 18 | username, 19 | refresh_token, 20 | user_agent, 21 | client_ip, 22 | is_blocked, 23 | expires_at 24 | ) VALUES ( 25 | $1, $2, $3, $4, $5, $6, $7 26 | ) RETURNING id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at 27 | ` 28 | 29 | type CreateSessionParams struct { 30 | ID uuid.UUID `json:"id"` 31 | Username string `json:"username"` 32 | RefreshToken string `json:"refresh_token"` 33 | UserAgent string `json:"user_agent"` 34 | ClientIp string `json:"client_ip"` 35 | IsBlocked bool `json:"is_blocked"` 36 | ExpiresAt time.Time `json:"expires_at"` 37 | } 38 | 39 | func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { 40 | row := q.db.QueryRow(ctx, createSession, 41 | arg.ID, 42 | arg.Username, 43 | arg.RefreshToken, 44 | arg.UserAgent, 45 | arg.ClientIp, 46 | arg.IsBlocked, 47 | arg.ExpiresAt, 48 | ) 49 | var i Session 50 | err := row.Scan( 51 | &i.ID, 52 | &i.Username, 53 | &i.RefreshToken, 54 | &i.UserAgent, 55 | &i.ClientIp, 56 | &i.IsBlocked, 57 | &i.ExpiresAt, 58 | &i.CreatedAt, 59 | ) 60 | return i, err 61 | } 62 | 63 | const getSession = `-- name: GetSession :one 64 | SELECT id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at FROM sessions 65 | WHERE id = $1 LIMIT 1 66 | ` 67 | 68 | func (q *Queries) GetSession(ctx context.Context, id uuid.UUID) (Session, error) { 69 | row := q.db.QueryRow(ctx, getSession, id) 70 | var i Session 71 | err := row.Scan( 72 | &i.ID, 73 | &i.Username, 74 | &i.RefreshToken, 75 | &i.UserAgent, 76 | &i.ClientIp, 77 | &i.IsBlocked, 78 | 
&i.ExpiresAt, 79 | &i.CreatedAt, 80 | ) 81 | return i, err 82 | } 83 | -------------------------------------------------------------------------------- /db/sqlc/store.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/jackc/pgx/v5/pgxpool" 7 | ) 8 | 9 | // Store defines all functions to execute db queries and transactions 10 | type Store interface { 11 | Querier 12 | TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) 13 | CreateUserTx(ctx context.Context, arg CreateUserTxParams) (CreateUserTxResult, error) 14 | VerifyEmailTx(ctx context.Context, arg VerifyEmailTxParams) (VerifyEmailTxResult, error) 15 | } 16 | 17 | // SQLStore provides all functions to execute SQL queries and transactions 18 | type SQLStore struct { 19 | connPool *pgxpool.Pool 20 | *Queries 21 | } 22 | 23 | // NewStore creates a new store 24 | func NewStore(connPool *pgxpool.Pool) Store { 25 | return &SQLStore{ 26 | connPool: connPool, 27 | Queries: New(connPool), 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /db/sqlc/store_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestTransferTx(t *testing.T) { 12 | account1 := createRandomAccount(t) 13 | account2 := createRandomAccount(t) 14 | fmt.Println(">> before:", account1.Balance, account2.Balance) 15 | 16 | n := 5 17 | amount := int64(10) 18 | 19 | errs := make(chan error) 20 | results := make(chan TransferTxResult) 21 | 22 | // run n concurrent transfer transaction 23 | for i := 0; i < n; i++ { 24 | go func() { 25 | result, err := testStore.TransferTx(context.Background(), TransferTxParams{ 26 | FromAccountID: account1.ID, 27 | ToAccountID: account2.ID, 28 | Amount: amount, 29 | }) 30 | 31 | 
errs <- err 32 | results <- result 33 | }() 34 | } 35 | 36 | // check results 37 | existed := make(map[int]bool) 38 | 39 | for i := 0; i < n; i++ { 40 | err := <-errs 41 | require.NoError(t, err) 42 | 43 | result := <-results 44 | require.NotEmpty(t, result) 45 | 46 | // check transfer 47 | transfer := result.Transfer 48 | require.NotEmpty(t, transfer) 49 | require.Equal(t, account1.ID, transfer.FromAccountID) 50 | require.Equal(t, account2.ID, transfer.ToAccountID) 51 | require.Equal(t, amount, transfer.Amount) 52 | require.NotZero(t, transfer.ID) 53 | require.NotZero(t, transfer.CreatedAt) 54 | 55 | _, err = testStore.GetTransfer(context.Background(), transfer.ID) 56 | require.NoError(t, err) 57 | 58 | // check entries 59 | fromEntry := result.FromEntry 60 | require.NotEmpty(t, fromEntry) 61 | require.Equal(t, account1.ID, fromEntry.AccountID) 62 | require.Equal(t, -amount, fromEntry.Amount) 63 | require.NotZero(t, fromEntry.ID) 64 | require.NotZero(t, fromEntry.CreatedAt) 65 | 66 | _, err = testStore.GetEntry(context.Background(), fromEntry.ID) 67 | require.NoError(t, err) 68 | 69 | toEntry := result.ToEntry 70 | require.NotEmpty(t, toEntry) 71 | require.Equal(t, account2.ID, toEntry.AccountID) 72 | require.Equal(t, amount, toEntry.Amount) 73 | require.NotZero(t, toEntry.ID) 74 | require.NotZero(t, toEntry.CreatedAt) 75 | 76 | _, err = testStore.GetEntry(context.Background(), toEntry.ID) 77 | require.NoError(t, err) 78 | 79 | // check accounts 80 | fromAccount := result.FromAccount 81 | require.NotEmpty(t, fromAccount) 82 | require.Equal(t, account1.ID, fromAccount.ID) 83 | 84 | toAccount := result.ToAccount 85 | require.NotEmpty(t, toAccount) 86 | require.Equal(t, account2.ID, toAccount.ID) 87 | 88 | // check balances 89 | fmt.Println(">> tx:", fromAccount.Balance, toAccount.Balance) 90 | 91 | diff1 := account1.Balance - fromAccount.Balance 92 | diff2 := toAccount.Balance - account2.Balance 93 | require.Equal(t, diff1, diff2) 94 | require.True(t, diff1 > 0) 95 
| require.True(t, diff1%amount == 0) // 1 * amount, 2 * amount, 3 * amount, ..., n * amount 96 | 97 | k := int(diff1 / amount) 98 | require.True(t, k >= 1 && k <= n) 99 | require.NotContains(t, existed, k) 100 | existed[k] = true 101 | } 102 | 103 | // check the final updated balance 104 | updatedAccount1, err := testStore.GetAccount(context.Background(), account1.ID) 105 | require.NoError(t, err) 106 | 107 | updatedAccount2, err := testStore.GetAccount(context.Background(), account2.ID) 108 | require.NoError(t, err) 109 | 110 | fmt.Println(">> after:", updatedAccount1.Balance, updatedAccount2.Balance) 111 | 112 | require.Equal(t, account1.Balance-int64(n)*amount, updatedAccount1.Balance) 113 | require.Equal(t, account2.Balance+int64(n)*amount, updatedAccount2.Balance) 114 | } 115 | 116 | func TestTransferTxDeadlock(t *testing.T) { 117 | account1 := createRandomAccount(t) 118 | account2 := createRandomAccount(t) 119 | fmt.Println(">> before:", account1.Balance, account2.Balance) 120 | 121 | n := 10 122 | amount := int64(10) 123 | errs := make(chan error) 124 | 125 | for i := 0; i < n; i++ { 126 | fromAccountID := account1.ID 127 | toAccountID := account2.ID 128 | 129 | if i%2 == 1 { 130 | fromAccountID = account2.ID 131 | toAccountID = account1.ID 132 | } 133 | 134 | go func() { 135 | _, err := testStore.TransferTx(context.Background(), TransferTxParams{ 136 | FromAccountID: fromAccountID, 137 | ToAccountID: toAccountID, 138 | Amount: amount, 139 | }) 140 | 141 | errs <- err 142 | }() 143 | } 144 | 145 | for i := 0; i < n; i++ { 146 | err := <-errs 147 | require.NoError(t, err) 148 | } 149 | 150 | // check the final updated balance 151 | updatedAccount1, err := testStore.GetAccount(context.Background(), account1.ID) 152 | require.NoError(t, err) 153 | 154 | updatedAccount2, err := testStore.GetAccount(context.Background(), account2.ID) 155 | require.NoError(t, err) 156 | 157 | fmt.Println(">> after:", updatedAccount1.Balance, updatedAccount2.Balance) 158 | 
require.Equal(t, account1.Balance, updatedAccount1.Balance) 159 | require.Equal(t, account2.Balance, updatedAccount2.Balance) 160 | } 161 | -------------------------------------------------------------------------------- /db/sqlc/transfer.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.22.0 4 | // source: transfer.sql 5 | 6 | package db 7 | 8 | import ( 9 | "context" 10 | ) 11 | 12 | const createTransfer = `-- name: CreateTransfer :one 13 | INSERT INTO transfers ( 14 | from_account_id, 15 | to_account_id, 16 | amount 17 | ) VALUES ( 18 | $1, $2, $3 19 | ) RETURNING id, from_account_id, to_account_id, amount, created_at 20 | ` 21 | 22 | type CreateTransferParams struct { 23 | FromAccountID int64 `json:"from_account_id"` 24 | ToAccountID int64 `json:"to_account_id"` 25 | Amount int64 `json:"amount"` 26 | } 27 | 28 | func (q *Queries) CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) { 29 | row := q.db.QueryRow(ctx, createTransfer, arg.FromAccountID, arg.ToAccountID, arg.Amount) 30 | var i Transfer 31 | err := row.Scan( 32 | &i.ID, 33 | &i.FromAccountID, 34 | &i.ToAccountID, 35 | &i.Amount, 36 | &i.CreatedAt, 37 | ) 38 | return i, err 39 | } 40 | 41 | const getTransfer = `-- name: GetTransfer :one 42 | SELECT id, from_account_id, to_account_id, amount, created_at FROM transfers 43 | WHERE id = $1 LIMIT 1 44 | ` 45 | 46 | func (q *Queries) GetTransfer(ctx context.Context, id int64) (Transfer, error) { 47 | row := q.db.QueryRow(ctx, getTransfer, id) 48 | var i Transfer 49 | err := row.Scan( 50 | &i.ID, 51 | &i.FromAccountID, 52 | &i.ToAccountID, 53 | &i.Amount, 54 | &i.CreatedAt, 55 | ) 56 | return i, err 57 | } 58 | 59 | const listTransfers = `-- name: ListTransfers :many 60 | SELECT id, from_account_id, to_account_id, amount, created_at FROM transfers 61 | WHERE 62 | from_account_id = $1 OR 63 | to_account_id = $2 64 | 
ORDER BY id 65 | LIMIT $3 66 | OFFSET $4 67 | ` 68 | 69 | type ListTransfersParams struct { 70 | FromAccountID int64 `json:"from_account_id"` 71 | ToAccountID int64 `json:"to_account_id"` 72 | Limit int32 `json:"limit"` 73 | Offset int32 `json:"offset"` 74 | } 75 | 76 | func (q *Queries) ListTransfers(ctx context.Context, arg ListTransfersParams) ([]Transfer, error) { 77 | rows, err := q.db.Query(ctx, listTransfers, 78 | arg.FromAccountID, 79 | arg.ToAccountID, 80 | arg.Limit, 81 | arg.Offset, 82 | ) 83 | if err != nil { 84 | return nil, err 85 | } 86 | defer rows.Close() 87 | items := []Transfer{} 88 | for rows.Next() { 89 | var i Transfer 90 | if err := rows.Scan( 91 | &i.ID, 92 | &i.FromAccountID, 93 | &i.ToAccountID, 94 | &i.Amount, 95 | &i.CreatedAt, 96 | ); err != nil { 97 | return nil, err 98 | } 99 | items = append(items, i) 100 | } 101 | if err := rows.Err(); err != nil { 102 | return nil, err 103 | } 104 | return items, nil 105 | } 106 | -------------------------------------------------------------------------------- /db/sqlc/transfer_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/techschool/simplebank/util" 10 | ) 11 | 12 | func createRandomTransfer(t *testing.T, account1, account2 Account) Transfer { 13 | arg := CreateTransferParams{ 14 | FromAccountID: account1.ID, 15 | ToAccountID: account2.ID, 16 | Amount: util.RandomMoney(), 17 | } 18 | 19 | transfer, err := testStore.CreateTransfer(context.Background(), arg) 20 | require.NoError(t, err) 21 | require.NotEmpty(t, transfer) 22 | 23 | require.Equal(t, arg.FromAccountID, transfer.FromAccountID) 24 | require.Equal(t, arg.ToAccountID, transfer.ToAccountID) 25 | require.Equal(t, arg.Amount, transfer.Amount) 26 | 27 | require.NotZero(t, transfer.ID) 28 | require.NotZero(t, transfer.CreatedAt) 29 | 30 | return transfer 31 | } 32 | 33 | 
func TestCreateTransfer(t *testing.T) { 34 | account1 := createRandomAccount(t) 35 | account2 := createRandomAccount(t) 36 | createRandomTransfer(t, account1, account2) 37 | } 38 | 39 | func TestGetTransfer(t *testing.T) { 40 | account1 := createRandomAccount(t) 41 | account2 := createRandomAccount(t) 42 | transfer1 := createRandomTransfer(t, account1, account2) 43 | 44 | transfer2, err := testStore.GetTransfer(context.Background(), transfer1.ID) 45 | require.NoError(t, err) 46 | require.NotEmpty(t, transfer2) 47 | 48 | require.Equal(t, transfer1.ID, transfer2.ID) 49 | require.Equal(t, transfer1.FromAccountID, transfer2.FromAccountID) 50 | require.Equal(t, transfer1.ToAccountID, transfer2.ToAccountID) 51 | require.Equal(t, transfer1.Amount, transfer2.Amount) 52 | require.WithinDuration(t, transfer1.CreatedAt, transfer2.CreatedAt, time.Second) 53 | } 54 | 55 | func TestListTransfer(t *testing.T) { 56 | account1 := createRandomAccount(t) 57 | account2 := createRandomAccount(t) 58 | 59 | for i := 0; i < 5; i++ { 60 | createRandomTransfer(t, account1, account2) 61 | createRandomTransfer(t, account2, account1) 62 | } 63 | 64 | arg := ListTransfersParams{ 65 | FromAccountID: account1.ID, 66 | ToAccountID: account1.ID, 67 | Limit: 5, 68 | Offset: 5, 69 | } 70 | 71 | transfers, err := testStore.ListTransfers(context.Background(), arg) 72 | require.NoError(t, err) 73 | require.Len(t, transfers, 5) 74 | 75 | for _, transfer := range transfers { 76 | require.NotEmpty(t, transfer) 77 | require.True(t, transfer.FromAccountID == account1.ID || transfer.ToAccountID == account1.ID) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /db/sqlc/tx_create_user.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import "context" 4 | 5 | type CreateUserTxParams struct { 6 | CreateUserParams 7 | AfterCreate func(user User) error 8 | } 9 | 10 | type CreateUserTxResult struct { 11 | User 
User 12 | } 13 | 14 | func (store *SQLStore) CreateUserTx(ctx context.Context, arg CreateUserTxParams) (CreateUserTxResult, error) { 15 | var result CreateUserTxResult 16 | 17 | err := store.execTx(ctx, func(q *Queries) error { 18 | var err error 19 | 20 | result.User, err = q.CreateUser(ctx, arg.CreateUserParams) 21 | if err != nil { 22 | return err 23 | } 24 | 25 | return arg.AfterCreate(result.User) 26 | }) 27 | 28 | return result, err 29 | } 30 | -------------------------------------------------------------------------------- /db/sqlc/tx_transfer.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import "context" 4 | 5 | // TransferTxParams contains the input parameters of the transfer transaction 6 | type TransferTxParams struct { 7 | FromAccountID int64 `json:"from_account_id"` 8 | ToAccountID int64 `json:"to_account_id"` 9 | Amount int64 `json:"amount"` 10 | } 11 | 12 | // TransferTxResult is the result of the transfer transaction 13 | type TransferTxResult struct { 14 | Transfer Transfer `json:"transfer"` 15 | FromAccount Account `json:"from_account"` 16 | ToAccount Account `json:"to_account"` 17 | FromEntry Entry `json:"from_entry"` 18 | ToEntry Entry `json:"to_entry"` 19 | } 20 | 21 | // TransferTx performs a money transfer from one account to the other. 
22 | // It creates the transfer, add account entries, and update accounts' balance within a database transaction 23 | func (store *SQLStore) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) { 24 | var result TransferTxResult 25 | 26 | err := store.execTx(ctx, func(q *Queries) error { 27 | var err error 28 | 29 | result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{ 30 | FromAccountID: arg.FromAccountID, 31 | ToAccountID: arg.ToAccountID, 32 | Amount: arg.Amount, 33 | }) 34 | if err != nil { 35 | return err 36 | } 37 | 38 | result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{ 39 | AccountID: arg.FromAccountID, 40 | Amount: -arg.Amount, 41 | }) 42 | if err != nil { 43 | return err 44 | } 45 | 46 | result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{ 47 | AccountID: arg.ToAccountID, 48 | Amount: arg.Amount, 49 | }) 50 | if err != nil { 51 | return err 52 | } 53 | 54 | if arg.FromAccountID < arg.ToAccountID { 55 | result.FromAccount, result.ToAccount, err = addMoney(ctx, q, arg.FromAccountID, -arg.Amount, arg.ToAccountID, arg.Amount) 56 | } else { 57 | result.ToAccount, result.FromAccount, err = addMoney(ctx, q, arg.ToAccountID, arg.Amount, arg.FromAccountID, -arg.Amount) 58 | } 59 | 60 | return err 61 | }) 62 | 63 | return result, err 64 | } 65 | 66 | func addMoney( 67 | ctx context.Context, 68 | q *Queries, 69 | accountID1 int64, 70 | amount1 int64, 71 | accountID2 int64, 72 | amount2 int64, 73 | ) (account1 Account, account2 Account, err error) { 74 | account1, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{ 75 | ID: accountID1, 76 | Amount: amount1, 77 | }) 78 | if err != nil { 79 | return 80 | } 81 | 82 | account2, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{ 83 | ID: accountID2, 84 | Amount: amount2, 85 | }) 86 | return 87 | } 88 | -------------------------------------------------------------------------------- /db/sqlc/tx_verify_email.go: 
-------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/jackc/pgx/v5/pgtype" 7 | ) 8 | 9 | type VerifyEmailTxParams struct { 10 | EmailId int64 11 | SecretCode string 12 | } 13 | 14 | type VerifyEmailTxResult struct { 15 | User User 16 | VerifyEmail VerifyEmail 17 | } 18 | 19 | func (store *SQLStore) VerifyEmailTx(ctx context.Context, arg VerifyEmailTxParams) (VerifyEmailTxResult, error) { 20 | var result VerifyEmailTxResult 21 | 22 | err := store.execTx(ctx, func(q *Queries) error { 23 | var err error 24 | 25 | result.VerifyEmail, err = q.UpdateVerifyEmail(ctx, UpdateVerifyEmailParams{ 26 | ID: arg.EmailId, 27 | SecretCode: arg.SecretCode, 28 | }) 29 | if err != nil { 30 | return err 31 | } 32 | 33 | result.User, err = q.UpdateUser(ctx, UpdateUserParams{ 34 | Username: result.VerifyEmail.Username, 35 | IsEmailVerified: pgtype.Bool{ 36 | Bool: true, 37 | Valid: true, 38 | }, 39 | }) 40 | return err 41 | }) 42 | 43 | return result, err 44 | } 45 | -------------------------------------------------------------------------------- /db/sqlc/user.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 
2 | // versions: 3 | // sqlc v1.22.0 4 | // source: user.sql 5 | 6 | package db 7 | 8 | import ( 9 | "context" 10 | 11 | "github.com/jackc/pgx/v5/pgtype" 12 | ) 13 | 14 | const createUser = `-- name: CreateUser :one 15 | INSERT INTO users ( 16 | username, 17 | hashed_password, 18 | full_name, 19 | email 20 | ) VALUES ( 21 | $1, $2, $3, $4 22 | ) RETURNING username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified, role 23 | ` 24 | 25 | type CreateUserParams struct { 26 | Username string `json:"username"` 27 | HashedPassword string `json:"hashed_password"` 28 | FullName string `json:"full_name"` 29 | Email string `json:"email"` 30 | } 31 | 32 | func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { 33 | row := q.db.QueryRow(ctx, createUser, 34 | arg.Username, 35 | arg.HashedPassword, 36 | arg.FullName, 37 | arg.Email, 38 | ) 39 | var i User 40 | err := row.Scan( 41 | &i.Username, 42 | &i.HashedPassword, 43 | &i.FullName, 44 | &i.Email, 45 | &i.PasswordChangedAt, 46 | &i.CreatedAt, 47 | &i.IsEmailVerified, 48 | &i.Role, 49 | ) 50 | return i, err 51 | } 52 | 53 | const getUser = `-- name: GetUser :one 54 | SELECT username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified, role FROM users 55 | WHERE username = $1 LIMIT 1 56 | ` 57 | 58 | func (q *Queries) GetUser(ctx context.Context, username string) (User, error) { 59 | row := q.db.QueryRow(ctx, getUser, username) 60 | var i User 61 | err := row.Scan( 62 | &i.Username, 63 | &i.HashedPassword, 64 | &i.FullName, 65 | &i.Email, 66 | &i.PasswordChangedAt, 67 | &i.CreatedAt, 68 | &i.IsEmailVerified, 69 | &i.Role, 70 | ) 71 | return i, err 72 | } 73 | 74 | const updateUser = `-- name: UpdateUser :one 75 | UPDATE users 76 | SET 77 | hashed_password = COALESCE($1, hashed_password), 78 | password_changed_at = COALESCE($2, password_changed_at), 79 | full_name = COALESCE($3, full_name), 80 | email = COALESCE($4, email), 81 
| is_email_verified = COALESCE($5, is_email_verified) 82 | WHERE 83 | username = $6 84 | RETURNING username, hashed_password, full_name, email, password_changed_at, created_at, is_email_verified, role 85 | ` 86 | 87 | type UpdateUserParams struct { 88 | HashedPassword pgtype.Text `json:"hashed_password"` 89 | PasswordChangedAt pgtype.Timestamptz `json:"password_changed_at"` 90 | FullName pgtype.Text `json:"full_name"` 91 | Email pgtype.Text `json:"email"` 92 | IsEmailVerified pgtype.Bool `json:"is_email_verified"` 93 | Username string `json:"username"` 94 | } 95 | 96 | func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) { 97 | row := q.db.QueryRow(ctx, updateUser, 98 | arg.HashedPassword, 99 | arg.PasswordChangedAt, 100 | arg.FullName, 101 | arg.Email, 102 | arg.IsEmailVerified, 103 | arg.Username, 104 | ) 105 | var i User 106 | err := row.Scan( 107 | &i.Username, 108 | &i.HashedPassword, 109 | &i.FullName, 110 | &i.Email, 111 | &i.PasswordChangedAt, 112 | &i.CreatedAt, 113 | &i.IsEmailVerified, 114 | &i.Role, 115 | ) 116 | return i, err 117 | } 118 | -------------------------------------------------------------------------------- /db/sqlc/user_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v5/pgtype" 9 | "github.com/stretchr/testify/require" 10 | "github.com/techschool/simplebank/util" 11 | ) 12 | 13 | func createRandomUser(t *testing.T) User { 14 | hashedPassword, err := util.HashPassword(util.RandomString(6)) 15 | require.NoError(t, err) 16 | 17 | arg := CreateUserParams{ 18 | Username: util.RandomOwner(), 19 | HashedPassword: hashedPassword, 20 | FullName: util.RandomOwner(), 21 | Email: util.RandomEmail(), 22 | } 23 | 24 | user, err := testStore.CreateUser(context.Background(), arg) 25 | require.NoError(t, err) 26 | require.NotEmpty(t, user) 27 | 28 | require.Equal(t, arg.Username, 
user.Username) 29 | require.Equal(t, arg.HashedPassword, user.HashedPassword) 30 | require.Equal(t, arg.FullName, user.FullName) 31 | require.Equal(t, arg.Email, user.Email) 32 | require.True(t, user.PasswordChangedAt.IsZero()) 33 | require.NotZero(t, user.CreatedAt) 34 | 35 | return user 36 | } 37 | 38 | func TestCreateUser(t *testing.T) { 39 | createRandomUser(t) 40 | } 41 | 42 | func TestGetUser(t *testing.T) { 43 | user1 := createRandomUser(t) 44 | user2, err := testStore.GetUser(context.Background(), user1.Username) 45 | require.NoError(t, err) 46 | require.NotEmpty(t, user2) 47 | 48 | require.Equal(t, user1.Username, user2.Username) 49 | require.Equal(t, user1.HashedPassword, user2.HashedPassword) 50 | require.Equal(t, user1.FullName, user2.FullName) 51 | require.Equal(t, user1.Email, user2.Email) 52 | require.WithinDuration(t, user1.PasswordChangedAt, user2.PasswordChangedAt, time.Second) 53 | require.WithinDuration(t, user1.CreatedAt, user2.CreatedAt, time.Second) 54 | } 55 | 56 | func TestUpdateUserOnlyFullName(t *testing.T) { 57 | oldUser := createRandomUser(t) 58 | 59 | newFullName := util.RandomOwner() 60 | updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ 61 | Username: oldUser.Username, 62 | FullName: pgtype.Text{ 63 | String: newFullName, 64 | Valid: true, 65 | }, 66 | }) 67 | 68 | require.NoError(t, err) 69 | require.NotEqual(t, oldUser.FullName, updatedUser.FullName) 70 | require.Equal(t, newFullName, updatedUser.FullName) 71 | require.Equal(t, oldUser.Email, updatedUser.Email) 72 | require.Equal(t, oldUser.HashedPassword, updatedUser.HashedPassword) 73 | } 74 | 75 | func TestUpdateUserOnlyEmail(t *testing.T) { 76 | oldUser := createRandomUser(t) 77 | 78 | newEmail := util.RandomEmail() 79 | updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ 80 | Username: oldUser.Username, 81 | Email: pgtype.Text{ 82 | String: newEmail, 83 | Valid: true, 84 | }, 85 | }) 86 | 87 | require.NoError(t, err) 
88 | require.NotEqual(t, oldUser.Email, updatedUser.Email) 89 | require.Equal(t, newEmail, updatedUser.Email) 90 | require.Equal(t, oldUser.FullName, updatedUser.FullName) 91 | require.Equal(t, oldUser.HashedPassword, updatedUser.HashedPassword) 92 | } 93 | 94 | func TestUpdateUserOnlyPassword(t *testing.T) { 95 | oldUser := createRandomUser(t) 96 | 97 | newPassword := util.RandomString(6) 98 | newHashedPassword, err := util.HashPassword(newPassword) 99 | require.NoError(t, err) 100 | 101 | updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ 102 | Username: oldUser.Username, 103 | HashedPassword: pgtype.Text{ 104 | String: newHashedPassword, 105 | Valid: true, 106 | }, 107 | }) 108 | 109 | require.NoError(t, err) 110 | require.NotEqual(t, oldUser.HashedPassword, updatedUser.HashedPassword) 111 | require.Equal(t, newHashedPassword, updatedUser.HashedPassword) 112 | require.Equal(t, oldUser.FullName, updatedUser.FullName) 113 | require.Equal(t, oldUser.Email, updatedUser.Email) 114 | } 115 | 116 | func TestUpdateUserAllFields(t *testing.T) { 117 | oldUser := createRandomUser(t) 118 | 119 | newFullName := util.RandomOwner() 120 | newEmail := util.RandomEmail() 121 | newPassword := util.RandomString(6) 122 | newHashedPassword, err := util.HashPassword(newPassword) 123 | require.NoError(t, err) 124 | 125 | updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ 126 | Username: oldUser.Username, 127 | FullName: pgtype.Text{ 128 | String: newFullName, 129 | Valid: true, 130 | }, 131 | Email: pgtype.Text{ 132 | String: newEmail, 133 | Valid: true, 134 | }, 135 | HashedPassword: pgtype.Text{ 136 | String: newHashedPassword, 137 | Valid: true, 138 | }, 139 | }) 140 | 141 | require.NoError(t, err) 142 | require.NotEqual(t, oldUser.HashedPassword, updatedUser.HashedPassword) 143 | require.Equal(t, newHashedPassword, updatedUser.HashedPassword) 144 | require.NotEqual(t, oldUser.Email, updatedUser.Email) 145 | 
require.Equal(t, newEmail, updatedUser.Email) 146 | require.NotEqual(t, oldUser.FullName, updatedUser.FullName) 147 | require.Equal(t, newFullName, updatedUser.FullName) 148 | } 149 | -------------------------------------------------------------------------------- /db/sqlc/verify_email.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.22.0 4 | // source: verify_email.sql 5 | 6 | package db 7 | 8 | import ( 9 | "context" 10 | ) 11 | 12 | const createVerifyEmail = `-- name: CreateVerifyEmail :one 13 | INSERT INTO verify_emails ( 14 | username, 15 | email, 16 | secret_code 17 | ) VALUES ( 18 | $1, $2, $3 19 | ) RETURNING id, username, email, secret_code, is_used, created_at, expired_at 20 | ` 21 | 22 | type CreateVerifyEmailParams struct { 23 | Username string `json:"username"` 24 | Email string `json:"email"` 25 | SecretCode string `json:"secret_code"` 26 | } 27 | 28 | func (q *Queries) CreateVerifyEmail(ctx context.Context, arg CreateVerifyEmailParams) (VerifyEmail, error) { 29 | row := q.db.QueryRow(ctx, createVerifyEmail, arg.Username, arg.Email, arg.SecretCode) 30 | var i VerifyEmail 31 | err := row.Scan( 32 | &i.ID, 33 | &i.Username, 34 | &i.Email, 35 | &i.SecretCode, 36 | &i.IsUsed, 37 | &i.CreatedAt, 38 | &i.ExpiredAt, 39 | ) 40 | return i, err 41 | } 42 | 43 | const updateVerifyEmail = `-- name: UpdateVerifyEmail :one 44 | UPDATE verify_emails 45 | SET 46 | is_used = TRUE 47 | WHERE 48 | id = $1 49 | AND secret_code = $2 50 | AND is_used = FALSE 51 | AND expired_at > now() 52 | RETURNING id, username, email, secret_code, is_used, created_at, expired_at 53 | ` 54 | 55 | type UpdateVerifyEmailParams struct { 56 | ID int64 `json:"id"` 57 | SecretCode string `json:"secret_code"` 58 | } 59 | 60 | func (q *Queries) UpdateVerifyEmail(ctx context.Context, arg UpdateVerifyEmailParams) (VerifyEmail, error) { 61 | row := q.db.QueryRow(ctx, 
updateVerifyEmail, arg.ID, arg.SecretCode) 62 | var i VerifyEmail 63 | err := row.Scan( 64 | &i.ID, 65 | &i.Username, 66 | &i.Email, 67 | &i.SecretCode, 68 | &i.IsUsed, 69 | &i.CreatedAt, 70 | &i.ExpiredAt, 71 | ) 72 | return i, err 73 | } 74 | -------------------------------------------------------------------------------- /doc/db.dbml: -------------------------------------------------------------------------------- 1 | Project simple_bank { 2 | database_type: 'PostgreSQL' 3 | Note: ''' 4 | # Simple Bank Database 5 | ''' 6 | } 7 | 8 | Table users as U { 9 | username varchar [pk] 10 | role varchar [not null, default: 'depositor'] 11 | hashed_password varchar [not null] 12 | full_name varchar [not null] 13 | email varchar [unique, not null] 14 | is_email_verified bool [not null, default: false] 15 | password_changed_at timestamptz [not null, default: '0001-01-01'] 16 | created_at timestamptz [not null, default: `now()`] 17 | } 18 | 19 | Table verify_emails { 20 | id bigserial [pk] 21 | username varchar [ref: > U.username, not null] 22 | email varchar [not null] 23 | secret_code varchar [not null] 24 | is_used bool [not null, default: false] 25 | created_at timestamptz [not null, default: `now()`] 26 | expired_at timestamptz [not null, default: `now() + interval '15 minutes'`] 27 | } 28 | 29 | Table accounts as A { 30 | id bigserial [pk] 31 | owner varchar [ref: > U.username, not null] 32 | balance bigint [not null] 33 | currency varchar [not null] 34 | created_at timestamptz [not null, default: `now()`] 35 | 36 | Indexes { 37 | owner 38 | (owner, currency) [unique] 39 | } 40 | } 41 | 42 | Table entries { 43 | id bigserial [pk] 44 | account_id bigint [ref: > A.id, not null] 45 | amount bigint [not null, note: 'can be negative or positive'] 46 | created_at timestamptz [not null, default: `now()`] 47 | 48 | Indexes { 49 | account_id 50 | } 51 | } 52 | 53 | Table transfers { 54 | id bigserial [pk] 55 | from_account_id bigint [ref: > A.id, not null] 56 | to_account_id 
bigint [ref: > A.id, not null] 57 | amount bigint [not null, note: 'must be positive'] 58 | created_at timestamptz [not null, default: `now()`] 59 | 60 | Indexes { 61 | from_account_id 62 | to_account_id 63 | (from_account_id, to_account_id) 64 | } 65 | } 66 | 67 | Table sessions { 68 | id uuid [pk] 69 | username varchar [ref: > U.username, not null] 70 | refresh_token varchar [not null] 71 | user_agent varchar [not null] 72 | client_ip varchar [not null] 73 | is_blocked boolean [not null, default: false] 74 | expires_at timestamptz [not null] 75 | created_at timestamptz [not null, default: `now()`] 76 | } 77 | -------------------------------------------------------------------------------- /doc/schema.sql: -------------------------------------------------------------------------------- 1 | -- SQL dump generated using DBML (dbml-lang.org) 2 | -- Database: PostgreSQL 3 | -- Generated at: 2023-09-30T12:00:38.491Z 4 | 5 | CREATE TABLE "users" ( 6 | "username" varchar PRIMARY KEY, 7 | "role" varchar NOT NULL DEFAULT 'depositor', 8 | "hashed_password" varchar NOT NULL, 9 | "full_name" varchar NOT NULL, 10 | "email" varchar UNIQUE NOT NULL, 11 | "is_email_verified" bool NOT NULL DEFAULT false, 12 | "password_changed_at" timestamptz NOT NULL DEFAULT '0001-01-01', 13 | "created_at" timestamptz NOT NULL DEFAULT (now()) 14 | ); 15 | 16 | CREATE TABLE "verify_emails" ( 17 | "id" bigserial PRIMARY KEY, 18 | "username" varchar NOT NULL, 19 | "email" varchar NOT NULL, 20 | "secret_code" varchar NOT NULL, 21 | "is_used" bool NOT NULL DEFAULT false, 22 | "created_at" timestamptz NOT NULL DEFAULT (now()), 23 | "expired_at" timestamptz NOT NULL DEFAULT (now() + interval '15 minutes') 24 | ); 25 | 26 | CREATE TABLE "accounts" ( 27 | "id" bigserial PRIMARY KEY, 28 | "owner" varchar NOT NULL, 29 | "balance" bigint NOT NULL, 30 | "currency" varchar NOT NULL, 31 | "created_at" timestamptz NOT NULL DEFAULT (now()) 32 | ); 33 | 34 | CREATE TABLE "entries" ( 35 | "id" bigserial PRIMARY KEY, 
36 | "account_id" bigint NOT NULL, 37 | "amount" bigint NOT NULL, 38 | "created_at" timestamptz NOT NULL DEFAULT (now()) 39 | ); 40 | 41 | CREATE TABLE "transfers" ( 42 | "id" bigserial PRIMARY KEY, 43 | "from_account_id" bigint NOT NULL, 44 | "to_account_id" bigint NOT NULL, 45 | "amount" bigint NOT NULL, 46 | "created_at" timestamptz NOT NULL DEFAULT (now()) 47 | ); 48 | 49 | CREATE TABLE "sessions" ( 50 | "id" uuid PRIMARY KEY, 51 | "username" varchar NOT NULL, 52 | "refresh_token" varchar NOT NULL, 53 | "user_agent" varchar NOT NULL, 54 | "client_ip" varchar NOT NULL, 55 | "is_blocked" boolean NOT NULL DEFAULT false, 56 | "expires_at" timestamptz NOT NULL, 57 | "created_at" timestamptz NOT NULL DEFAULT (now()) 58 | ); 59 | 60 | CREATE INDEX ON "accounts" ("owner"); 61 | 62 | CREATE UNIQUE INDEX ON "accounts" ("owner", "currency"); 63 | 64 | CREATE INDEX ON "entries" ("account_id"); 65 | 66 | CREATE INDEX ON "transfers" ("from_account_id"); 67 | 68 | CREATE INDEX ON "transfers" ("to_account_id"); 69 | 70 | CREATE INDEX ON "transfers" ("from_account_id", "to_account_id"); 71 | 72 | COMMENT ON COLUMN "entries"."amount" IS 'can be negative or positive'; 73 | 74 | COMMENT ON COLUMN "transfers"."amount" IS 'must be positive'; 75 | 76 | ALTER TABLE "verify_emails" ADD FOREIGN KEY ("username") REFERENCES "users" ("username"); 77 | 78 | ALTER TABLE "accounts" ADD FOREIGN KEY ("owner") REFERENCES "users" ("username"); 79 | 80 | ALTER TABLE "entries" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id"); 81 | 82 | ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id"); 83 | 84 | ALTER TABLE "transfers" ADD FOREIGN KEY ("to_account_id") REFERENCES "accounts" ("id"); 85 | 86 | ALTER TABLE "sessions" ADD FOREIGN KEY ("username") REFERENCES "users" ("username"); 87 | -------------------------------------------------------------------------------- /doc/swagger/favicon-16x16.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/techschool/simplebank/97f000fe58ad01a0774179ffa8884ac7784cf263/doc/swagger/favicon-16x16.png -------------------------------------------------------------------------------- /doc/swagger/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/techschool/simplebank/97f000fe58ad01a0774179ffa8884ac7784cf263/doc/swagger/favicon-32x32.png -------------------------------------------------------------------------------- /doc/swagger/index.css: -------------------------------------------------------------------------------- 1 | html { 2 | box-sizing: border-box; 3 | overflow: -moz-scrollbars-vertical; 4 | overflow-y: scroll; 5 | } 6 | 7 | *, 8 | *:before, 9 | *:after { 10 | box-sizing: inherit; 11 | } 12 | 13 | body { 14 | margin: 0; 15 | background: #fafafa; 16 | } 17 | -------------------------------------------------------------------------------- /doc/swagger/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 | 6 |This is a test message from Tech School
24 | ` 25 | to := []string{"techschool.guru@gmail.com"} 26 | attachFiles := []string{"../README.md"} 27 | 28 | err = sender.SendEmail(subject, content, to, nil, nil, attachFiles) 29 | require.NoError(t, err) 30 | } 31 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | 12 | "github.com/golang-migrate/migrate/v4" 13 | _ "github.com/golang-migrate/migrate/v4/database/postgres" 14 | _ "github.com/golang-migrate/migrate/v4/source/file" 15 | "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" 16 | "github.com/hibiken/asynq" 17 | "github.com/jackc/pgx/v5/pgxpool" 18 | "github.com/rakyll/statik/fs" 19 | "github.com/rs/cors" 20 | "github.com/rs/zerolog" 21 | "github.com/rs/zerolog/log" 22 | "github.com/techschool/simplebank/api" 23 | db "github.com/techschool/simplebank/db/sqlc" 24 | _ "github.com/techschool/simplebank/doc/statik" 25 | "github.com/techschool/simplebank/gapi" 26 | "github.com/techschool/simplebank/mail" 27 | "github.com/techschool/simplebank/pb" 28 | "github.com/techschool/simplebank/util" 29 | "github.com/techschool/simplebank/worker" 30 | "golang.org/x/sync/errgroup" 31 | "google.golang.org/grpc" 32 | "google.golang.org/grpc/reflection" 33 | "google.golang.org/protobuf/encoding/protojson" 34 | ) 35 | 36 | var interruptSignals = []os.Signal{ 37 | os.Interrupt, 38 | syscall.SIGTERM, 39 | syscall.SIGINT, 40 | } 41 | 42 | func main() { 43 | config, err := util.LoadConfig(".") 44 | if err != nil { 45 | log.Fatal().Err(err).Msg("cannot load config") 46 | } 47 | 48 | if config.Environment == "development" { 49 | log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) 50 | } 51 | 52 | ctx, stop := signal.NotifyContext(context.Background(), interruptSignals...) 
53 | defer stop() 54 | 55 | connPool, err := pgxpool.New(ctx, config.DBSource) 56 | if err != nil { 57 | log.Fatal().Err(err).Msg("cannot connect to db") 58 | } 59 | 60 | runDBMigration(config.MigrationURL, config.DBSource) 61 | 62 | store := db.NewStore(connPool) 63 | 64 | redisOpt := asynq.RedisClientOpt{ 65 | Addr: config.RedisAddress, 66 | } 67 | 68 | taskDistributor := worker.NewRedisTaskDistributor(redisOpt) 69 | 70 | waitGroup, ctx := errgroup.WithContext(ctx) 71 | 72 | runTaskProcessor(ctx, waitGroup, config, redisOpt, store) 73 | runGatewayServer(ctx, waitGroup, config, store, taskDistributor) 74 | runGrpcServer(ctx, waitGroup, config, store, taskDistributor) 75 | 76 | err = waitGroup.Wait() 77 | if err != nil { 78 | log.Fatal().Err(err).Msg("error from wait group") 79 | } 80 | } 81 | 82 | func runDBMigration(migrationURL string, dbSource string) { 83 | migration, err := migrate.New(migrationURL, dbSource) 84 | if err != nil { 85 | log.Fatal().Err(err).Msg("cannot create new migrate instance") 86 | } 87 | 88 | if err = migration.Up(); err != nil && err != migrate.ErrNoChange { 89 | log.Fatal().Err(err).Msg("failed to run migrate up") 90 | } 91 | 92 | log.Info().Msg("db migrated successfully") 93 | } 94 | 95 | func runTaskProcessor( 96 | ctx context.Context, 97 | waitGroup *errgroup.Group, 98 | config util.Config, 99 | redisOpt asynq.RedisClientOpt, 100 | store db.Store, 101 | ) { 102 | mailer := mail.NewGmailSender(config.EmailSenderName, config.EmailSenderAddress, config.EmailSenderPassword) 103 | taskProcessor := worker.NewRedisTaskProcessor(redisOpt, store, mailer) 104 | 105 | log.Info().Msg("start task processor") 106 | err := taskProcessor.Start() 107 | if err != nil { 108 | log.Fatal().Err(err).Msg("failed to start task processor") 109 | } 110 | 111 | waitGroup.Go(func() error { 112 | <-ctx.Done() 113 | log.Info().Msg("graceful shutdown task processor") 114 | 115 | taskProcessor.Shutdown() 116 | log.Info().Msg("task processor is stopped") 117 | 118 | 
return nil 119 | }) 120 | } 121 | 122 | func runGrpcServer( 123 | ctx context.Context, 124 | waitGroup *errgroup.Group, 125 | config util.Config, 126 | store db.Store, 127 | taskDistributor worker.TaskDistributor, 128 | ) { 129 | server, err := gapi.NewServer(config, store, taskDistributor) 130 | if err != nil { 131 | log.Fatal().Err(err).Msg("cannot create server") 132 | } 133 | 134 | gprcLogger := grpc.UnaryInterceptor(gapi.GrpcLogger) 135 | grpcServer := grpc.NewServer(gprcLogger) 136 | pb.RegisterSimpleBankServer(grpcServer, server) 137 | reflection.Register(grpcServer) 138 | 139 | listener, err := net.Listen("tcp", config.GRPCServerAddress) 140 | if err != nil { 141 | log.Fatal().Err(err).Msg("cannot create listener") 142 | } 143 | 144 | waitGroup.Go(func() error { 145 | log.Info().Msgf("start gRPC server at %s", listener.Addr().String()) 146 | 147 | err = grpcServer.Serve(listener) 148 | if err != nil { 149 | if errors.Is(err, grpc.ErrServerStopped) { 150 | return nil 151 | } 152 | log.Error().Err(err).Msg("gRPC server failed to serve") 153 | return err 154 | } 155 | 156 | return nil 157 | }) 158 | 159 | waitGroup.Go(func() error { 160 | <-ctx.Done() 161 | log.Info().Msg("graceful shutdown gRPC server") 162 | 163 | grpcServer.GracefulStop() 164 | log.Info().Msg("gRPC server is stopped") 165 | 166 | return nil 167 | }) 168 | } 169 | 170 | func runGatewayServer( 171 | ctx context.Context, 172 | waitGroup *errgroup.Group, 173 | config util.Config, 174 | store db.Store, 175 | taskDistributor worker.TaskDistributor, 176 | ) { 177 | server, err := gapi.NewServer(config, store, taskDistributor) 178 | if err != nil { 179 | log.Fatal().Err(err).Msg("cannot create server") 180 | } 181 | 182 | jsonOption := runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{ 183 | MarshalOptions: protojson.MarshalOptions{ 184 | UseProtoNames: true, 185 | }, 186 | UnmarshalOptions: protojson.UnmarshalOptions{ 187 | DiscardUnknown: true, 188 | }, 189 | }) 190 | 191 | 
grpcMux := runtime.NewServeMux(jsonOption) 192 | 193 | err = pb.RegisterSimpleBankHandlerServer(ctx, grpcMux, server) 194 | if err != nil { 195 | log.Fatal().Err(err).Msg("cannot register handler server") 196 | } 197 | 198 | mux := http.NewServeMux() 199 | mux.Handle("/", grpcMux) 200 | 201 | statikFS, err := fs.New() 202 | if err != nil { 203 | log.Fatal().Err(err).Msg("cannot create statik fs") 204 | } 205 | 206 | swaggerHandler := http.StripPrefix("/swagger/", http.FileServer(statikFS)) 207 | mux.Handle("/swagger/", swaggerHandler) 208 | 209 | c := cors.New(cors.Options{ 210 | AllowedOrigins: config.AllowedOrigins, 211 | AllowedMethods: []string{ 212 | http.MethodHead, 213 | http.MethodOptions, 214 | http.MethodGet, 215 | http.MethodPost, 216 | http.MethodPut, 217 | http.MethodPatch, 218 | http.MethodDelete, 219 | }, 220 | AllowedHeaders: []string{ 221 | "Content-Type", 222 | "Authorization", 223 | }, 224 | AllowCredentials: true, 225 | }) 226 | handler := c.Handler(gapi.HttpLogger(mux)) 227 | 228 | httpServer := &http.Server{ 229 | Handler: handler, 230 | Addr: config.HTTPServerAddress, 231 | } 232 | 233 | waitGroup.Go(func() error { 234 | log.Info().Msgf("start HTTP gateway server at %s", httpServer.Addr) 235 | err = httpServer.ListenAndServe() 236 | if err != nil { 237 | if errors.Is(err, http.ErrServerClosed) { 238 | return nil 239 | } 240 | log.Error().Err(err).Msg("HTTP gateway server failed to serve") 241 | return err 242 | } 243 | return nil 244 | }) 245 | 246 | waitGroup.Go(func() error { 247 | <-ctx.Done() 248 | log.Info().Msg("graceful shutdown HTTP gateway server") 249 | 250 | err := httpServer.Shutdown(context.Background()) 251 | if err != nil { 252 | log.Error().Err(err).Msg("failed to shutdown HTTP gateway server") 253 | return err 254 | } 255 | 256 | log.Info().Msg("HTTP gateway server is stopped") 257 | return nil 258 | }) 259 | } 260 | 261 | func runGinServer(config util.Config, store db.Store) { 262 | server, err := api.NewServer(config, 
store) 263 | if err != nil { 264 | log.Fatal().Err(err).Msg("cannot create server") 265 | } 266 | 267 | err = server.Start(config.HTTPServerAddress) 268 | if err != nil { 269 | log.Fatal().Err(err).Msg("cannot start server") 270 | } 271 | } 272 | -------------------------------------------------------------------------------- /pb/rpc_verify_email.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.28.1 4 | // protoc v4.24.3 5 | // source: rpc_verify_email.proto 6 | 7 | package pb 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | reflect "reflect" 13 | sync "sync" 14 | ) 15 | 16 | const ( 17 | // Verify that this generated code is sufficiently up-to-date. 18 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 19 | // Verify that runtime/protoimpl is sufficiently up-to-date. 
20 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 21 | ) 22 | 23 | type VerifyEmailRequest struct { 24 | state protoimpl.MessageState 25 | sizeCache protoimpl.SizeCache 26 | unknownFields protoimpl.UnknownFields 27 | 28 | EmailId int64 `protobuf:"varint,1,opt,name=email_id,json=emailId,proto3" json:"email_id,omitempty"` 29 | SecretCode string `protobuf:"bytes,2,opt,name=secret_code,json=secretCode,proto3" json:"secret_code,omitempty"` 30 | } 31 | 32 | func (x *VerifyEmailRequest) Reset() { 33 | *x = VerifyEmailRequest{} 34 | if protoimpl.UnsafeEnabled { 35 | mi := &file_rpc_verify_email_proto_msgTypes[0] 36 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 37 | ms.StoreMessageInfo(mi) 38 | } 39 | } 40 | 41 | func (x *VerifyEmailRequest) String() string { 42 | return protoimpl.X.MessageStringOf(x) 43 | } 44 | 45 | func (*VerifyEmailRequest) ProtoMessage() {} 46 | 47 | func (x *VerifyEmailRequest) ProtoReflect() protoreflect.Message { 48 | mi := &file_rpc_verify_email_proto_msgTypes[0] 49 | if protoimpl.UnsafeEnabled && x != nil { 50 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 51 | if ms.LoadMessageInfo() == nil { 52 | ms.StoreMessageInfo(mi) 53 | } 54 | return ms 55 | } 56 | return mi.MessageOf(x) 57 | } 58 | 59 | // Deprecated: Use VerifyEmailRequest.ProtoReflect.Descriptor instead. 
60 | func (*VerifyEmailRequest) Descriptor() ([]byte, []int) { 61 | return file_rpc_verify_email_proto_rawDescGZIP(), []int{0} 62 | } 63 | 64 | func (x *VerifyEmailRequest) GetEmailId() int64 { 65 | if x != nil { 66 | return x.EmailId 67 | } 68 | return 0 69 | } 70 | 71 | func (x *VerifyEmailRequest) GetSecretCode() string { 72 | if x != nil { 73 | return x.SecretCode 74 | } 75 | return "" 76 | } 77 | 78 | type VerifyEmailResponse struct { 79 | state protoimpl.MessageState 80 | sizeCache protoimpl.SizeCache 81 | unknownFields protoimpl.UnknownFields 82 | 83 | IsVerified bool `protobuf:"varint,1,opt,name=is_verified,json=isVerified,proto3" json:"is_verified,omitempty"` 84 | } 85 | 86 | func (x *VerifyEmailResponse) Reset() { 87 | *x = VerifyEmailResponse{} 88 | if protoimpl.UnsafeEnabled { 89 | mi := &file_rpc_verify_email_proto_msgTypes[1] 90 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 91 | ms.StoreMessageInfo(mi) 92 | } 93 | } 94 | 95 | func (x *VerifyEmailResponse) String() string { 96 | return protoimpl.X.MessageStringOf(x) 97 | } 98 | 99 | func (*VerifyEmailResponse) ProtoMessage() {} 100 | 101 | func (x *VerifyEmailResponse) ProtoReflect() protoreflect.Message { 102 | mi := &file_rpc_verify_email_proto_msgTypes[1] 103 | if protoimpl.UnsafeEnabled && x != nil { 104 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 105 | if ms.LoadMessageInfo() == nil { 106 | ms.StoreMessageInfo(mi) 107 | } 108 | return ms 109 | } 110 | return mi.MessageOf(x) 111 | } 112 | 113 | // Deprecated: Use VerifyEmailResponse.ProtoReflect.Descriptor instead. 
114 | func (*VerifyEmailResponse) Descriptor() ([]byte, []int) { 115 | return file_rpc_verify_email_proto_rawDescGZIP(), []int{1} 116 | } 117 | 118 | func (x *VerifyEmailResponse) GetIsVerified() bool { 119 | if x != nil { 120 | return x.IsVerified 121 | } 122 | return false 123 | } 124 | 125 | var File_rpc_verify_email_proto protoreflect.FileDescriptor 126 | 127 | var file_rpc_verify_email_proto_rawDesc = []byte{ 128 | 0x0a, 0x16, 0x72, 0x70, 0x63, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x5f, 0x65, 0x6d, 0x61, 129 | 0x69, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 0x22, 0x50, 0x0a, 0x12, 130 | 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 131 | 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x5f, 0x69, 0x64, 0x18, 0x01, 132 | 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x49, 0x64, 0x12, 0x1f, 0x0a, 133 | 0x0b, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 134 | 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x36, 135 | 0x0a, 0x13, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x52, 0x65, 0x73, 136 | 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, 137 | 0x66, 0x69, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x69, 0x73, 0x56, 0x65, 138 | 0x72, 0x69, 0x66, 0x69, 0x65, 0x64, 0x42, 0x25, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 139 | 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x74, 0x65, 0x63, 0x68, 0x73, 0x63, 0x68, 0x6f, 0x6f, 0x6c, 0x2f, 140 | 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x62, 0x61, 0x6e, 0x6b, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 141 | 0x72, 0x6f, 0x74, 0x6f, 0x33, 142 | } 143 | 144 | var ( 145 | file_rpc_verify_email_proto_rawDescOnce sync.Once 146 | file_rpc_verify_email_proto_rawDescData = file_rpc_verify_email_proto_rawDesc 147 | ) 148 | 149 | func 
file_rpc_verify_email_proto_rawDescGZIP() []byte { 150 | file_rpc_verify_email_proto_rawDescOnce.Do(func() { 151 | file_rpc_verify_email_proto_rawDescData = protoimpl.X.CompressGZIP(file_rpc_verify_email_proto_rawDescData) 152 | }) 153 | return file_rpc_verify_email_proto_rawDescData 154 | } 155 | 156 | var file_rpc_verify_email_proto_msgTypes = make([]protoimpl.MessageInfo, 2) 157 | var file_rpc_verify_email_proto_goTypes = []interface{}{ 158 | (*VerifyEmailRequest)(nil), // 0: pb.VerifyEmailRequest 159 | (*VerifyEmailResponse)(nil), // 1: pb.VerifyEmailResponse 160 | } 161 | var file_rpc_verify_email_proto_depIdxs = []int32{ 162 | 0, // [0:0] is the sub-list for method output_type 163 | 0, // [0:0] is the sub-list for method input_type 164 | 0, // [0:0] is the sub-list for extension type_name 165 | 0, // [0:0] is the sub-list for extension extendee 166 | 0, // [0:0] is the sub-list for field type_name 167 | } 168 | 169 | func init() { file_rpc_verify_email_proto_init() } 170 | func file_rpc_verify_email_proto_init() { 171 | if File_rpc_verify_email_proto != nil { 172 | return 173 | } 174 | if !protoimpl.UnsafeEnabled { 175 | file_rpc_verify_email_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 176 | switch v := v.(*VerifyEmailRequest); i { 177 | case 0: 178 | return &v.state 179 | case 1: 180 | return &v.sizeCache 181 | case 2: 182 | return &v.unknownFields 183 | default: 184 | return nil 185 | } 186 | } 187 | file_rpc_verify_email_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { 188 | switch v := v.(*VerifyEmailResponse); i { 189 | case 0: 190 | return &v.state 191 | case 1: 192 | return &v.sizeCache 193 | case 2: 194 | return &v.unknownFields 195 | default: 196 | return nil 197 | } 198 | } 199 | } 200 | type x struct{} 201 | out := protoimpl.TypeBuilder{ 202 | File: protoimpl.DescBuilder{ 203 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 204 | RawDescriptor: file_rpc_verify_email_proto_rawDesc, 205 | NumEnums: 0, 206 
| NumMessages: 2, 207 | NumExtensions: 0, 208 | NumServices: 0, 209 | }, 210 | GoTypes: file_rpc_verify_email_proto_goTypes, 211 | DependencyIndexes: file_rpc_verify_email_proto_depIdxs, 212 | MessageInfos: file_rpc_verify_email_proto_msgTypes, 213 | }.Build() 214 | File_rpc_verify_email_proto = out.File 215 | file_rpc_verify_email_proto_rawDesc = nil 216 | file_rpc_verify_email_proto_goTypes = nil 217 | file_rpc_verify_email_proto_depIdxs = nil 218 | } 219 | -------------------------------------------------------------------------------- /pb/user.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.28.1 4 | // protoc v4.24.3 5 | // source: user.proto 6 | 7 | package pb 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | timestamppb "google.golang.org/protobuf/types/known/timestamppb" 13 | reflect "reflect" 14 | sync "sync" 15 | ) 16 | 17 | const ( 18 | // Verify that this generated code is sufficiently up-to-date. 19 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 20 | // Verify that runtime/protoimpl is sufficiently up-to-date. 
21 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 22 | ) 23 | 24 | type User struct { 25 | state protoimpl.MessageState 26 | sizeCache protoimpl.SizeCache 27 | unknownFields protoimpl.UnknownFields 28 | 29 | Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` 30 | FullName string `protobuf:"bytes,2,opt,name=full_name,json=fullName,proto3" json:"full_name,omitempty"` 31 | Email string `protobuf:"bytes,3,opt,name=email,proto3" json:"email,omitempty"` 32 | PasswordChangedAt *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=password_changed_at,json=passwordChangedAt,proto3" json:"password_changed_at,omitempty"` 33 | CreatedAt *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` 34 | } 35 | 36 | func (x *User) Reset() { 37 | *x = User{} 38 | if protoimpl.UnsafeEnabled { 39 | mi := &file_user_proto_msgTypes[0] 40 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 41 | ms.StoreMessageInfo(mi) 42 | } 43 | } 44 | 45 | func (x *User) String() string { 46 | return protoimpl.X.MessageStringOf(x) 47 | } 48 | 49 | func (*User) ProtoMessage() {} 50 | 51 | func (x *User) ProtoReflect() protoreflect.Message { 52 | mi := &file_user_proto_msgTypes[0] 53 | if protoimpl.UnsafeEnabled && x != nil { 54 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 55 | if ms.LoadMessageInfo() == nil { 56 | ms.StoreMessageInfo(mi) 57 | } 58 | return ms 59 | } 60 | return mi.MessageOf(x) 61 | } 62 | 63 | // Deprecated: Use User.ProtoReflect.Descriptor instead. 
64 | func (*User) Descriptor() ([]byte, []int) { 65 | return file_user_proto_rawDescGZIP(), []int{0} 66 | } 67 | 68 | func (x *User) GetUsername() string { 69 | if x != nil { 70 | return x.Username 71 | } 72 | return "" 73 | } 74 | 75 | func (x *User) GetFullName() string { 76 | if x != nil { 77 | return x.FullName 78 | } 79 | return "" 80 | } 81 | 82 | func (x *User) GetEmail() string { 83 | if x != nil { 84 | return x.Email 85 | } 86 | return "" 87 | } 88 | 89 | func (x *User) GetPasswordChangedAt() *timestamppb.Timestamp { 90 | if x != nil { 91 | return x.PasswordChangedAt 92 | } 93 | return nil 94 | } 95 | 96 | func (x *User) GetCreatedAt() *timestamppb.Timestamp { 97 | if x != nil { 98 | return x.CreatedAt 99 | } 100 | return nil 101 | } 102 | 103 | var File_user_proto protoreflect.FileDescriptor 104 | 105 | var file_user_proto_rawDesc = []byte{ 106 | 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 107 | 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 108 | 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 109 | 0x6f, 0x22, 0xdc, 0x01, 0x0a, 0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 110 | 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 111 | 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x75, 0x6c, 0x6c, 0x5f, 0x6e, 112 | 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x75, 0x6c, 0x6c, 0x4e, 113 | 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 114 | 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x4a, 0x0a, 0x13, 0x70, 0x61, 0x73, 115 | 0x73, 0x77, 0x6f, 0x72, 0x64, 0x5f, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x64, 0x5f, 0x61, 0x74, 116 | 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 117 | 0x70, 
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 118 | 0x6d, 0x70, 0x52, 0x11, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x43, 0x68, 0x61, 0x6e, 119 | 0x67, 0x65, 0x64, 0x41, 0x74, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 120 | 0x5f, 0x61, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 121 | 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 122 | 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 123 | 0x42, 0x25, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x74, 124 | 0x65, 0x63, 0x68, 0x73, 0x63, 0x68, 0x6f, 0x6f, 0x6c, 0x2f, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 125 | 0x62, 0x61, 0x6e, 0x6b, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 126 | } 127 | 128 | var ( 129 | file_user_proto_rawDescOnce sync.Once 130 | file_user_proto_rawDescData = file_user_proto_rawDesc 131 | ) 132 | 133 | func file_user_proto_rawDescGZIP() []byte { 134 | file_user_proto_rawDescOnce.Do(func() { 135 | file_user_proto_rawDescData = protoimpl.X.CompressGZIP(file_user_proto_rawDescData) 136 | }) 137 | return file_user_proto_rawDescData 138 | } 139 | 140 | var file_user_proto_msgTypes = make([]protoimpl.MessageInfo, 1) 141 | var file_user_proto_goTypes = []interface{}{ 142 | (*User)(nil), // 0: pb.User 143 | (*timestamppb.Timestamp)(nil), // 1: google.protobuf.Timestamp 144 | } 145 | var file_user_proto_depIdxs = []int32{ 146 | 1, // 0: pb.User.password_changed_at:type_name -> google.protobuf.Timestamp 147 | 1, // 1: pb.User.created_at:type_name -> google.protobuf.Timestamp 148 | 2, // [2:2] is the sub-list for method output_type 149 | 2, // [2:2] is the sub-list for method input_type 150 | 2, // [2:2] is the sub-list for extension type_name 151 | 2, // [2:2] is the sub-list for extension extendee 152 | 0, // [0:2] is the sub-list for field 
type_name 153 | } 154 | 155 | func init() { file_user_proto_init() } 156 | func file_user_proto_init() { 157 | if File_user_proto != nil { 158 | return 159 | } 160 | if !protoimpl.UnsafeEnabled { 161 | file_user_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 162 | switch v := v.(*User); i { 163 | case 0: 164 | return &v.state 165 | case 1: 166 | return &v.sizeCache 167 | case 2: 168 | return &v.unknownFields 169 | default: 170 | return nil 171 | } 172 | } 173 | } 174 | type x struct{} 175 | out := protoimpl.TypeBuilder{ 176 | File: protoimpl.DescBuilder{ 177 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 178 | RawDescriptor: file_user_proto_rawDesc, 179 | NumEnums: 0, 180 | NumMessages: 1, 181 | NumExtensions: 0, 182 | NumServices: 0, 183 | }, 184 | GoTypes: file_user_proto_goTypes, 185 | DependencyIndexes: file_user_proto_depIdxs, 186 | MessageInfos: file_user_proto_msgTypes, 187 | }.Build() 188 | File_user_proto = out.File 189 | file_user_proto_rawDesc = nil 190 | file_user_proto_goTypes = nil 191 | file_user_proto_depIdxs = nil 192 | } 193 | -------------------------------------------------------------------------------- /proto/google/api/annotations.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2015 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | syntax = "proto3"; 16 | 17 | package google.api; 18 | 19 | import "google/api/http.proto"; 20 | import "google/protobuf/descriptor.proto"; 21 | 22 | option go_package = "google.golang.org/genproto/googleapis/api/annotations;annotations"; 23 | option java_multiple_files = true; 24 | option java_outer_classname = "AnnotationsProto"; 25 | option java_package = "com.google.api"; 26 | option objc_class_prefix = "GAPI"; 27 | 28 | extend google.protobuf.MethodOptions { 29 | // See `HttpRule`. 30 | HttpRule http = 72295728; 31 | } 32 | -------------------------------------------------------------------------------- /proto/google/api/field_behavior.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | syntax = "proto3"; 16 | 17 | package google.api; 18 | 19 | import "google/protobuf/descriptor.proto"; 20 | 21 | option go_package = "google.golang.org/genproto/googleapis/api/annotations;annotations"; 22 | option java_multiple_files = true; 23 | option java_outer_classname = "FieldBehaviorProto"; 24 | option java_package = "com.google.api"; 25 | option objc_class_prefix = "GAPI"; 26 | 27 | extend google.protobuf.FieldOptions { 28 | // A designation of a specific field behavior (required, output only, etc.) 29 | // in protobuf messages. 
30 | // 31 | // Examples: 32 | // 33 | // string name = 1 [(google.api.field_behavior) = REQUIRED]; 34 | // State state = 1 [(google.api.field_behavior) = OUTPUT_ONLY]; 35 | // google.protobuf.Duration ttl = 1 36 | // [(google.api.field_behavior) = INPUT_ONLY]; 37 | // google.protobuf.Timestamp expire_time = 1 38 | // [(google.api.field_behavior) = OUTPUT_ONLY, 39 | // (google.api.field_behavior) = IMMUTABLE]; 40 | repeated google.api.FieldBehavior field_behavior = 1052; 41 | } 42 | 43 | // An indicator of the behavior of a given field (for example, that a field 44 | // is required in requests, or given as output but ignored as input). 45 | // This **does not** change the behavior in protocol buffers itself; it only 46 | // denotes the behavior and may affect how API tooling handles the field. 47 | // 48 | // Note: This enum **may** receive new values in the future. 49 | enum FieldBehavior { 50 | // Conventional default for enums. Do not use this. 51 | FIELD_BEHAVIOR_UNSPECIFIED = 0; 52 | 53 | // Specifically denotes a field as optional. 54 | // While all fields in protocol buffers are optional, this may be specified 55 | // for emphasis if appropriate. 56 | OPTIONAL = 1; 57 | 58 | // Denotes a field as required. 59 | // This indicates that the field **must** be provided as part of the request, 60 | // and failure to do so will cause an error (usually `INVALID_ARGUMENT`). 61 | REQUIRED = 2; 62 | 63 | // Denotes a field as output only. 64 | // This indicates that the field is provided in responses, but including the 65 | // field in a request does nothing (the server *must* ignore it and 66 | // *must not* throw an error as a result of the field's presence). 67 | OUTPUT_ONLY = 3; 68 | 69 | // Denotes a field as input only. 70 | // This indicates that the field is provided in requests, and the 71 | // corresponding field is not included in output. 72 | INPUT_ONLY = 4; 73 | 74 | // Denotes a field as immutable. 
75 | // This indicates that the field may be set once in a request to create a 76 | // resource, but may not be changed thereafter. 77 | IMMUTABLE = 5; 78 | 79 | // Denotes that a (repeated) field is an unordered list. 80 | // This indicates that the service may provide the elements of the list 81 | // in any arbitrary order, rather than the order the user originally 82 | // provided. Additionally, the list's order may or may not be stable. 83 | UNORDERED_LIST = 6; 84 | 85 | // Denotes that this field returns a non-empty default value if not set. 86 | // This indicates that if the user provides the empty value in a request, 87 | // a non-empty value will be returned. The user will not be aware of what 88 | // non-empty value to expect. 89 | NON_EMPTY_DEFAULT = 7; 90 | } 91 | -------------------------------------------------------------------------------- /proto/google/api/httpbody.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2015 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | syntax = "proto3"; 16 | 17 | package google.api; 18 | 19 | import "google/protobuf/any.proto"; 20 | 21 | option cc_enable_arenas = true; 22 | option go_package = "google.golang.org/genproto/googleapis/api/httpbody;httpbody"; 23 | option java_multiple_files = true; 24 | option java_outer_classname = "HttpBodyProto"; 25 | option java_package = "com.google.api"; 26 | option objc_class_prefix = "GAPI"; 27 | 28 | // Message that represents an arbitrary HTTP body. It should only be used for 29 | // payload formats that can't be represented as JSON, such as raw binary or 30 | // an HTML page. 31 | // 32 | // 33 | // This message can be used both in streaming and non-streaming API methods in 34 | // the request as well as the response. 35 | // 36 | // It can be used as a top-level request field, which is convenient if one 37 | // wants to extract parameters from either the URL or HTTP template into the 38 | // request fields and also want access to the raw HTTP body. 39 | // 40 | // Example: 41 | // 42 | // message GetResourceRequest { 43 | // // A unique request id. 44 | // string request_id = 1; 45 | // 46 | // // The raw HTTP body is bound to this field. 47 | // google.api.HttpBody http_body = 2; 48 | // 49 | // } 50 | // 51 | // service ResourceService { 52 | // rpc GetResource(GetResourceRequest) 53 | // returns (google.api.HttpBody); 54 | // rpc UpdateResource(google.api.HttpBody) 55 | // returns (google.protobuf.Empty); 56 | // 57 | // } 58 | // 59 | // Example with streaming methods: 60 | // 61 | // service CaldavService { 62 | // rpc GetCalendar(stream google.api.HttpBody) 63 | // returns (stream google.api.HttpBody); 64 | // rpc UpdateCalendar(stream google.api.HttpBody) 65 | // returns (stream google.api.HttpBody); 66 | // 67 | // } 68 | // 69 | // Use of this type only changes how the request and response bodies are 70 | // handled, all other features will continue to work unchanged. 
71 | message HttpBody { 72 | // The HTTP Content-Type header value specifying the content type of the body. 73 | string content_type = 1; 74 | 75 | // The HTTP request/response body as raw binary. 76 | bytes data = 2; 77 | 78 | // Application specific response metadata. Must be set in the first response 79 | // for streaming APIs. 80 | repeated google.protobuf.Any extensions = 3; 81 | } 82 | -------------------------------------------------------------------------------- /proto/protoc-gen-openapiv2/options/annotations.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package grpc.gateway.protoc_gen_openapiv2.options; 4 | 5 | option go_package = "github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2/options"; 6 | 7 | import "google/protobuf/descriptor.proto"; 8 | import "protoc-gen-openapiv2/options/openapiv2.proto"; 9 | 10 | extend google.protobuf.FileOptions { 11 | // ID assigned by protobuf-global-extension-registry@google.com for gRPC-Gateway project. 12 | // 13 | // All IDs are the same, as assigned. It is okay that they are the same, as they extend 14 | // different descriptor messages. 15 | Swagger openapiv2_swagger = 1042; 16 | } 17 | extend google.protobuf.MethodOptions { 18 | // ID assigned by protobuf-global-extension-registry@google.com for gRPC-Gateway project. 19 | // 20 | // All IDs are the same, as assigned. It is okay that they are the same, as they extend 21 | // different descriptor messages. 22 | Operation openapiv2_operation = 1042; 23 | } 24 | extend google.protobuf.MessageOptions { 25 | // ID assigned by protobuf-global-extension-registry@google.com for gRPC-Gateway project. 26 | // 27 | // All IDs are the same, as assigned. It is okay that they are the same, as they extend 28 | // different descriptor messages. 
29 | Schema openapiv2_schema = 1042; 30 | } 31 | extend google.protobuf.ServiceOptions { 32 | // ID assigned by protobuf-global-extension-registry@google.com for gRPC-Gateway project. 33 | // 34 | // All IDs are the same, as assigned. It is okay that they are the same, as they extend 35 | // different descriptor messages. 36 | Tag openapiv2_tag = 1042; 37 | } 38 | extend google.protobuf.FieldOptions { 39 | // ID assigned by protobuf-global-extension-registry@google.com for gRPC-Gateway project. 40 | // 41 | // All IDs are the same, as assigned. It is okay that they are the same, as they extend 42 | // different descriptor messages. 43 | JSONSchema openapiv2_field = 1042; 44 | } 45 | -------------------------------------------------------------------------------- /proto/rpc_create_user.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package pb; 4 | 5 | import "user.proto"; 6 | 7 | option go_package = "github.com/techschool/simplebank/pb"; 8 | 9 | message CreateUserRequest { 10 | string username = 1; 11 | string full_name = 2; 12 | string email = 3; 13 | string password = 4; 14 | } 15 | 16 | message CreateUserResponse { 17 | User user = 1; 18 | } 19 | -------------------------------------------------------------------------------- /proto/rpc_login_user.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package pb; 4 | 5 | import "user.proto"; 6 | import "google/protobuf/timestamp.proto"; 7 | 8 | option go_package = "github.com/techschool/simplebank/pb"; 9 | 10 | message LoginUserRequest { 11 | string username = 1; 12 | string password = 2; 13 | } 14 | 15 | message LoginUserResponse { 16 | User user = 1; 17 | string session_id = 2; 18 | string access_token = 3; 19 | string refresh_token = 4; 20 | google.protobuf.Timestamp access_token_expires_at = 5; 21 | google.protobuf.Timestamp refresh_token_expires_at = 6; 22 | } 23 | 
// UpdateUserRequest is the request message of the UpdateUser RPC.
message UpdateUserRequest {
    // Plain (non-optional) string field.
    // NOTE(review): presumably identifies the user to update — confirm against the handler.
    string username = 1;

    // The remaining fields are declared `optional`, so generated code tracks
    // explicit presence: a field that was never set can be told apart from one
    // that was set to the empty string.
    optional string full_name = 2;
    optional string email = 3;
    optional string password = 4;
}
// User is the user message shared by the RPC definitions; it is imported from
// user.proto by the request/response files.
message User {
    string username = 1;
    string full_name = 2;
    string email = 3;
    // Both timestamps use the well-known google.protobuf.Timestamp type.
    google.protobuf.Timestamp password_changed_at = 4;
    google.protobuf.Timestamp created_at = 5;
}
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/techschool/simplebank/97f000fe58ad01a0774179ffa8884ac7784cf263/simplebank -------------------------------------------------------------------------------- /sqlc.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | sql: 3 | - schema: "db/migration" 4 | queries: "db/query" 5 | engine: "postgresql" 6 | gen: 7 | go: 8 | package: "db" 9 | out: "db/sqlc" 10 | sql_package: "pgx/v5" 11 | emit_json_tags: true 12 | emit_interface: true 13 | emit_empty_slices: true 14 | overrides: 15 | - db_type: "timestamptz" 16 | go_type: "time.Time" 17 | - db_type: "uuid" 18 | go_type: "github.com/google/uuid.UUID" 19 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | echo "start the app" 6 | exec "$@" 7 | -------------------------------------------------------------------------------- /token/jwt_maker.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/golang-jwt/jwt/v5" 9 | ) 10 | 11 | const minSecretKeySize = 32 12 | 13 | // JWTMaker is a JSON Web Token maker 14 | type JWTMaker struct { 15 | secretKey string 16 | } 17 | 18 | // NewJWTMaker creates a new JWTMaker 19 | func NewJWTMaker(secretKey string) (Maker, error) { 20 | if len(secretKey) < minSecretKeySize { 21 | return nil, fmt.Errorf("invalid key size: must be at least %d characters", minSecretKeySize) 22 | } 23 | return &JWTMaker{secretKey}, nil 24 | } 25 | 26 | // CreateToken creates a new token for a specific username and duration 27 | func (maker *JWTMaker) CreateToken(username string, role string, duration time.Duration, tokenType TokenType) (string, *Payload, error) { 28 | payload, 
err := NewPayload(username, role, duration, tokenType) 29 | if err != nil { 30 | return "", payload, err 31 | } 32 | 33 | jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) 34 | token, err := jwtToken.SignedString([]byte(maker.secretKey)) 35 | return token, payload, err 36 | } 37 | 38 | // VerifyToken checks if the token is valid or not 39 | func (maker *JWTMaker) VerifyToken(token string, tokenType TokenType) (*Payload, error) { 40 | keyFunc := func(token *jwt.Token) (interface{}, error) { 41 | _, ok := token.Method.(*jwt.SigningMethodHMAC) 42 | if !ok { 43 | return nil, ErrInvalidToken 44 | } 45 | return []byte(maker.secretKey), nil 46 | } 47 | 48 | jwtToken, err := jwt.ParseWithClaims(token, &Payload{}, keyFunc) 49 | if err != nil { 50 | if errors.Is(err, jwt.ErrTokenExpired) { 51 | return nil, ErrExpiredToken 52 | } 53 | return nil, ErrInvalidToken 54 | } 55 | 56 | payload, ok := jwtToken.Claims.(*Payload) 57 | if !ok { 58 | return nil, ErrInvalidToken 59 | } 60 | 61 | err = payload.Valid(tokenType) 62 | if err != nil { 63 | return nil, err 64 | } 65 | 66 | return payload, nil 67 | } 68 | -------------------------------------------------------------------------------- /token/jwt_maker_test.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/golang-jwt/jwt/v5" 8 | "github.com/stretchr/testify/require" 9 | "github.com/techschool/simplebank/util" 10 | ) 11 | 12 | func TestJWTMaker(t *testing.T) { 13 | maker, err := NewJWTMaker(util.RandomString(32)) 14 | require.NoError(t, err) 15 | 16 | username := util.RandomOwner() 17 | role := util.DepositorRole 18 | duration := time.Minute 19 | 20 | issuedAt := time.Now() 21 | expiredAt := issuedAt.Add(duration) 22 | 23 | token, payload, err := maker.CreateToken(username, role, duration, TokenTypeAccessToken) 24 | require.NoError(t, err) 25 | require.NotEmpty(t, token) 26 | require.NotEmpty(t, payload) 
// TestInvalidJWTTokenAlgNone checks that VerifyToken rejects a token that was
// produced with the "none" signing method, i.e. a token carrying no signature.
func TestInvalidJWTTokenAlgNone(t *testing.T) {
	// A one-minute access-token payload: unexpired and of the type requested
	// below, so the failure asserted at the end is not caused by expiry or by a
	// token-type mismatch.
	payload, err := NewPayload(util.RandomOwner(), util.DepositorRole, time.Minute, TokenTypeAccessToken)
	require.NoError(t, err)

	// Build the token by hand with SigningMethodNone; the library only signs
	// such a token when the UnsafeAllowNoneSignatureType key is passed.
	jwtToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload)
	token, err := jwtToken.SignedString(jwt.UnsafeAllowNoneSignatureType)
	require.NoError(t, err)

	maker, err := NewJWTMaker(util.RandomString(32))
	require.NoError(t, err)

	// Verification must fail with ErrInvalidToken and return no payload.
	payload, err = maker.VerifyToken(token, TokenTypeAccessToken)
	require.Error(t, err)
	require.EqualError(t, err, ErrInvalidToken.Error())
	require.Nil(t, payload)
}
maker.VerifyToken(token, TokenTypeRefreshToken) 81 | require.Error(t, err) 82 | require.EqualError(t, err, ErrInvalidToken.Error()) 83 | require.Nil(t, payload) 84 | } 85 | -------------------------------------------------------------------------------- /token/maker.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // Maker is an interface for managing tokens 8 | type Maker interface { 9 | // CreateToken creates a new token for a specific username and duration 10 | CreateToken(username string, role string, duration time.Duration, tokenType TokenType) (string, *Payload, error) 11 | 12 | // VerifyToken checks if the token is valid or not 13 | VerifyToken(token string, tokenType TokenType) (*Payload, error) 14 | } 15 | -------------------------------------------------------------------------------- /token/paseto_maker.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/aead/chacha20poly1305" 8 | "github.com/o1egl/paseto" 9 | ) 10 | 11 | // PasetoMaker is a PASETO token maker 12 | type PasetoMaker struct { 13 | paseto *paseto.V2 14 | symmetricKey []byte 15 | } 16 | 17 | // NewPasetoMaker creates a new PasetoMaker 18 | func NewPasetoMaker(symmetricKey string) (Maker, error) { 19 | if len(symmetricKey) != chacha20poly1305.KeySize { 20 | return nil, fmt.Errorf("invalid key size: must be exactly %d characters", chacha20poly1305.KeySize) 21 | } 22 | 23 | maker := &PasetoMaker{ 24 | paseto: paseto.NewV2(), 25 | symmetricKey: []byte(symmetricKey), 26 | } 27 | 28 | return maker, nil 29 | } 30 | 31 | // CreateToken creates a new token for a specific username and duration 32 | func (maker *PasetoMaker) CreateToken(username string, role string, duration time.Duration, tokenType TokenType) (string, *Payload, error) { 33 | payload, err := NewPayload(username, role, duration, 
tokenType) 34 | if err != nil { 35 | return "", payload, err 36 | } 37 | 38 | token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil) 39 | return token, payload, err 40 | } 41 | 42 | // VerifyToken checks if the token is valid or not 43 | func (maker *PasetoMaker) VerifyToken(token string, tokenType TokenType) (*Payload, error) { 44 | payload := &Payload{} 45 | 46 | err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil) 47 | if err != nil { 48 | return nil, ErrInvalidToken 49 | } 50 | 51 | err = payload.Valid(tokenType) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return payload, nil 57 | } 58 | -------------------------------------------------------------------------------- /token/paseto_maker_test.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/require" 8 | "github.com/techschool/simplebank/util" 9 | ) 10 | 11 | func TestPasetoMaker(t *testing.T) { 12 | maker, err := NewPasetoMaker(util.RandomString(32)) 13 | require.NoError(t, err) 14 | 15 | username := util.RandomOwner() 16 | role := util.DepositorRole 17 | duration := time.Minute 18 | 19 | issuedAt := time.Now() 20 | expiredAt := issuedAt.Add(duration) 21 | 22 | token, payload, err := maker.CreateToken(username, role, duration, TokenTypeAccessToken) 23 | require.NoError(t, err) 24 | require.NotEmpty(t, token) 25 | require.NotEmpty(t, payload) 26 | 27 | payload, err = maker.VerifyToken(token, TokenTypeAccessToken) 28 | require.NoError(t, err) 29 | require.NotEmpty(t, token) 30 | 31 | require.NotZero(t, payload.ID) 32 | require.Equal(t, username, payload.Username) 33 | require.Equal(t, role, payload.Role) 34 | require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) 35 | require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) 36 | } 37 | 38 | func TestExpiredPasetoToken(t *testing.T) { 39 | maker, err := 
// TokenType identifies the kind of token a payload belongs to. It is stored in
// the payload and compared against the expected type during validation.
type TokenType byte

// Supported token types. The constants are explicitly typed as TokenType so
// they cannot be silently used as plain integers; the numeric values (1 and 2)
// are unchanged.
const (
	TokenTypeAccessToken  TokenType = 1
	TokenTypeRefreshToken TokenType = 2
)
| 34 | // NewPayload creates a new token payload with a specific username and duration 35 | func NewPayload(username string, role string, duration time.Duration, tokenType TokenType) (*Payload, error) { 36 | tokenID, err := uuid.NewRandom() 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | payload := &Payload{ 42 | ID: tokenID, 43 | Type: tokenType, 44 | Username: username, 45 | Role: role, 46 | IssuedAt: time.Now(), 47 | ExpiredAt: time.Now().Add(duration), 48 | } 49 | return payload, nil 50 | } 51 | 52 | // Valid checks if the token payload is valid or not 53 | func (payload *Payload) Valid(tokenType TokenType) error { 54 | if payload.Type != tokenType { 55 | return ErrInvalidToken 56 | } 57 | if time.Now().After(payload.ExpiredAt) { 58 | return ErrExpiredToken 59 | } 60 | return nil 61 | } 62 | 63 | func (payload *Payload) GetExpirationTime() (*jwt.NumericDate, error) { 64 | return &jwt.NumericDate{ 65 | Time: payload.ExpiredAt, 66 | }, nil 67 | } 68 | 69 | func (payload *Payload) GetIssuedAt() (*jwt.NumericDate, error) { 70 | return &jwt.NumericDate{ 71 | Time: payload.IssuedAt, 72 | }, nil 73 | } 74 | 75 | func (payload *Payload) GetNotBefore() (*jwt.NumericDate, error) { 76 | return &jwt.NumericDate{ 77 | Time: payload.IssuedAt, 78 | }, nil 79 | } 80 | 81 | func (payload *Payload) GetIssuer() (string, error) { 82 | return "", nil 83 | } 84 | 85 | func (payload *Payload) GetSubject() (string, error) { 86 | return "", nil 87 | } 88 | 89 | func (payload *Payload) GetAudience() (jwt.ClaimStrings, error) { 90 | return jwt.ClaimStrings{}, nil 91 | } 92 | -------------------------------------------------------------------------------- /util/config.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/spf13/viper" 7 | ) 8 | 9 | // Config stores all configuration of the application. 10 | // The values are read by viper from a config file or environment variable. 
11 | type Config struct { 12 | Environment string `mapstructure:"ENVIRONMENT"` 13 | AllowedOrigins []string `mapstructure:"ALLOWED_ORIGINS"` 14 | DBSource string `mapstructure:"DB_SOURCE"` 15 | MigrationURL string `mapstructure:"MIGRATION_URL"` 16 | RedisAddress string `mapstructure:"REDIS_ADDRESS"` 17 | HTTPServerAddress string `mapstructure:"HTTP_SERVER_ADDRESS"` 18 | GRPCServerAddress string `mapstructure:"GRPC_SERVER_ADDRESS"` 19 | TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` 20 | AccessTokenDuration time.Duration `mapstructure:"ACCESS_TOKEN_DURATION"` 21 | RefreshTokenDuration time.Duration `mapstructure:"REFRESH_TOKEN_DURATION"` 22 | EmailSenderName string `mapstructure:"EMAIL_SENDER_NAME"` 23 | EmailSenderAddress string `mapstructure:"EMAIL_SENDER_ADDRESS"` 24 | EmailSenderPassword string `mapstructure:"EMAIL_SENDER_PASSWORD"` 25 | } 26 | 27 | // LoadConfig reads configuration from file or environment variables. 28 | func LoadConfig(path string) (config Config, err error) { 29 | viper.AddConfigPath(path) 30 | viper.SetConfigName("app") 31 | viper.SetConfigType("env") 32 | 33 | viper.AutomaticEnv() 34 | 35 | err = viper.ReadInConfig() 36 | if err != nil { 37 | return 38 | } 39 | 40 | err = viper.Unmarshal(&config) 41 | return 42 | } 43 | -------------------------------------------------------------------------------- /util/currency.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | // Constants for all supported currencies 4 | const ( 5 | USD = "USD" 6 | EUR = "EUR" 7 | CAD = "CAD" 8 | ) 9 | 10 | // IsSupportedCurrency returns true if the currency is supported 11 | func IsSupportedCurrency(currency string) bool { 12 | switch currency { 13 | case USD, EUR, CAD: 14 | return true 15 | } 16 | return false 17 | } 18 | -------------------------------------------------------------------------------- /util/password.go: 
-------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | 6 | "golang.org/x/crypto/bcrypt" 7 | ) 8 | 9 | // HashPassword returns the bcrypt hash of the password 10 | func HashPassword(password string) (string, error) { 11 | hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) 12 | if err != nil { 13 | return "", fmt.Errorf("failed to hash password: %w", err) 14 | } 15 | return string(hashedPassword), nil 16 | } 17 | 18 | // CheckPassword checks if the provided password is correct or not 19 | func CheckPassword(password string, hashedPassword string) error { 20 | return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) 21 | } 22 | -------------------------------------------------------------------------------- /util/password_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "golang.org/x/crypto/bcrypt" 8 | ) 9 | 10 | func TestPassword(t *testing.T) { 11 | password := RandomString(6) 12 | 13 | hashedPassword1, err := HashPassword(password) 14 | require.NoError(t, err) 15 | require.NotEmpty(t, hashedPassword1) 16 | 17 | err = CheckPassword(password, hashedPassword1) 18 | require.NoError(t, err) 19 | 20 | wrongPassword := RandomString(6) 21 | err = CheckPassword(wrongPassword, hashedPassword1) 22 | require.EqualError(t, err, bcrypt.ErrMismatchedHashAndPassword.Error()) 23 | 24 | hashedPassword2, err := HashPassword(password) 25 | require.NoError(t, err) 26 | require.NotEmpty(t, hashedPassword2) 27 | require.NotEqual(t, hashedPassword1, hashedPassword2) 28 | } 29 | -------------------------------------------------------------------------------- /util/random.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | 
"strings" 7 | "time" 8 | ) 9 | 10 | const alphabet = "abcdefghijklmnopqrstuvwxyz" 11 | 12 | func init() { 13 | rand.Seed(time.Now().UnixNano()) 14 | } 15 | 16 | // RandomInt generates a random integer between min and max 17 | func RandomInt(min, max int64) int64 { 18 | return min + rand.Int63n(max-min+1) 19 | } 20 | 21 | // RandomString generates a random string of length n 22 | func RandomString(n int) string { 23 | var sb strings.Builder 24 | k := len(alphabet) 25 | 26 | for i := 0; i < n; i++ { 27 | c := alphabet[rand.Intn(k)] 28 | sb.WriteByte(c) 29 | } 30 | 31 | return sb.String() 32 | } 33 | 34 | // RandomOwner generates a random owner name 35 | func RandomOwner() string { 36 | return RandomString(6) 37 | } 38 | 39 | // RandomMoney generates a random amount of money 40 | func RandomMoney() int64 { 41 | return RandomInt(0, 1000) 42 | } 43 | 44 | // RandomCurrency generates a random currency code 45 | func RandomCurrency() string { 46 | currencies := []string{USD, EUR, CAD} 47 | n := len(currencies) 48 | return currencies[rand.Intn(n)] 49 | } 50 | 51 | // RandomEmail generates a random email 52 | func RandomEmail() string { 53 | return fmt.Sprintf("%s@email.com", RandomString(6)) 54 | } 55 | -------------------------------------------------------------------------------- /util/role.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | const ( 4 | DepositorRole = "depositor" 5 | BankerRole = "banker" 6 | ) 7 | -------------------------------------------------------------------------------- /val/validator.go: -------------------------------------------------------------------------------- 1 | package val 2 | 3 | import ( 4 | "fmt" 5 | "net/mail" 6 | "regexp" 7 | ) 8 | 9 | var ( 10 | isValidUsername = regexp.MustCompile(`^[a-z0-9_]+$`).MatchString 11 | isValidFullName = regexp.MustCompile(`^[a-zA-Z\s]+$`).MatchString 12 | ) 13 | 14 | func ValidateString(value string, minLength int, maxLength int) error { 15 | 
n := len(value) 16 | if n < minLength || n > maxLength { 17 | return fmt.Errorf("must contain from %d-%d characters", minLength, maxLength) 18 | } 19 | return nil 20 | } 21 | 22 | func ValidateUsername(value string) error { 23 | if err := ValidateString(value, 3, 100); err != nil { 24 | return err 25 | } 26 | if !isValidUsername(value) { 27 | return fmt.Errorf("must contain only lowercase letters, digits, or underscore") 28 | } 29 | return nil 30 | } 31 | 32 | func ValidateFullName(value string) error { 33 | if err := ValidateString(value, 3, 100); err != nil { 34 | return err 35 | } 36 | if !isValidFullName(value) { 37 | return fmt.Errorf("must contain only letters or spaces") 38 | } 39 | return nil 40 | } 41 | 42 | func ValidatePassword(value string) error { 43 | return ValidateString(value, 6, 100) 44 | } 45 | 46 | func ValidateEmail(value string) error { 47 | if err := ValidateString(value, 3, 200); err != nil { 48 | return err 49 | } 50 | if _, err := mail.ParseAddress(value); err != nil { 51 | return fmt.Errorf("is not a valid email address") 52 | } 53 | return nil 54 | } 55 | 56 | func ValidateEmailId(value int64) error { 57 | if value <= 0 { 58 | return fmt.Errorf("must be a positive integer") 59 | } 60 | return nil 61 | } 62 | 63 | func ValidateSecretCode(value string) error { 64 | return ValidateString(value, 32, 128) 65 | } 66 | -------------------------------------------------------------------------------- /wait-for.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2017 Eficode Oy 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and 
# wait_for polls the target once per second until it becomes reachable or the
# TIMEOUT counter runs out.
#
# Reads the globals PROTOCOL ("tcp" or "http"), HOST, PORT, TIMEOUT and QUIET.
# Positional parameters are the optional command to exec on success, followed by
# the 7 bookkeeping values appended by the top-level `set --` call ("--" plus
# the six saved variables).
#
# On success it execs the command if one was given, otherwise exits 0. On
# timeout it prints "Operation timed out" and exits 1.
wait_for() {
  # Pre-flight: make sure the tool needed for the selected protocol exists.
  # The http branch must match "http" (the value PROTOCOL actually takes) and
  # report the tool that is really missing.
  case "$PROTOCOL" in
    tcp)
      if ! command -v nc >/dev/null; then
        echoerr 'nc command is missing!'
        exit 1
      fi
      ;;
    http)
      if ! command -v wget >/dev/null; then
        echoerr 'wget command is missing!'
        exit 1
      fi
      ;;
  esac

  while :; do
    case "$PROTOCOL" in
      tcp)
        nc -z "$HOST" "$PORT" > /dev/null 2>&1
        ;;
      http)
        wget --timeout=1 -q "$HOST" -O /dev/null > /dev/null 2>&1
        ;;
      *)
        echoerr "Unknown protocol '$PROTOCOL'"
        exit 1
        ;;
    esac

    # Exit status of the probe above (the case statement returns it).
    result=$?

    if [ $result -eq 0 ] ; then
      # More than the 7 bookkeeping arguments means a command was supplied.
      if [ $# -gt 7 ] ; then
        # Rotate the command words to the end so the bookkeeping values move
        # to the front of the positional parameters.
        for result in $(seq $(($# - 7))); do
          result=$1
          shift
          set -- "$@" "$result"
        done

        # $1 is the "--" separator; restore the saved variables, drop the
        # bookkeeping values and replace this shell with the command.
        TIMEOUT=$2 QUIET=$3 PROTOCOL=$4 HOST=$5 PORT=$6 result=$7
        shift 7
        exec "$@"
      fi
      exit 0
    fi

    # NOTE(review): with TIMEOUT=0 this gives up after a single attempt, while
    # the usage text says zero means "no timeout" — confirm intended behavior.
    if [ "$TIMEOUT" -le 0 ]; then
      break
    fi
    TIMEOUT=$((TIMEOUT - 1))

    sleep 1
  done
  echo "Operation timed out" >&2
  exit 1
}
[ "$TIMEOUT" -ge 0 ] 2>/dev/null; then 165 | echoerr "Error: invalid timeout '$TIMEOUT'" 166 | usage 3 167 | fi 168 | 169 | case "$PROTOCOL" in 170 | tcp) 171 | if [ "$HOST" = "" ] || [ "$PORT" = "" ]; then 172 | echoerr "Error: you need to provide a host and port to test." 173 | usage 2 174 | fi 175 | ;; 176 | http) 177 | if [ "$HOST" = "" ]; then 178 | echoerr "Error: you need to provide a host to test." 179 | usage 2 180 | fi 181 | ;; 182 | esac 183 | 184 | wait_for "$@" 185 | -------------------------------------------------------------------------------- /worker/distributor.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/hibiken/asynq" 7 | ) 8 | 9 | type TaskDistributor interface { 10 | DistributeTaskSendVerifyEmail( 11 | ctx context.Context, 12 | payload *PayloadSendVerifyEmail, 13 | opts ...asynq.Option, 14 | ) error 15 | } 16 | 17 | type RedisTaskDistributor struct { 18 | client *asynq.Client 19 | } 20 | 21 | func NewRedisTaskDistributor(redisOpt asynq.RedisClientOpt) TaskDistributor { 22 | client := asynq.NewClient(redisOpt) 23 | return &RedisTaskDistributor{ 24 | client: client, 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /worker/logger.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/rs/zerolog" 8 | "github.com/rs/zerolog/log" 9 | ) 10 | 11 | type Logger struct{} 12 | 13 | func NewLogger() *Logger { 14 | return &Logger{} 15 | } 16 | 17 | func (logger *Logger) Print(level zerolog.Level, args ...interface{}) { 18 | log.WithLevel(level).Msg(fmt.Sprint(args...)) 19 | } 20 | 21 | func (logger *Logger) Printf(ctx context.Context, format string, v ...interface{}) { 22 | log.WithLevel(zerolog.DebugLevel).Msgf(format, v...) 
23 | } 24 | 25 | func (logger *Logger) Debug(args ...interface{}) { 26 | logger.Print(zerolog.DebugLevel, args...) 27 | } 28 | 29 | func (logger *Logger) Info(args ...interface{}) { 30 | logger.Print(zerolog.InfoLevel, args...) 31 | } 32 | 33 | func (logger *Logger) Warn(args ...interface{}) { 34 | logger.Print(zerolog.WarnLevel, args...) 35 | } 36 | 37 | func (logger *Logger) Error(args ...interface{}) { 38 | logger.Print(zerolog.ErrorLevel, args...) 39 | } 40 | 41 | func (logger *Logger) Fatal(args ...interface{}) { 42 | logger.Print(zerolog.FatalLevel, args...) 43 | } 44 | -------------------------------------------------------------------------------- /worker/mock/distributor.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/techschool/simplebank/worker (interfaces: TaskDistributor) 3 | 4 | // Package mockwk is a generated GoMock package. 5 | package mockwk 6 | 7 | import ( 8 | context "context" 9 | gomock "github.com/golang/mock/gomock" 10 | asynq "github.com/hibiken/asynq" 11 | worker "github.com/techschool/simplebank/worker" 12 | reflect "reflect" 13 | ) 14 | 15 | // MockTaskDistributor is a mock of TaskDistributor interface 16 | type MockTaskDistributor struct { 17 | ctrl *gomock.Controller 18 | recorder *MockTaskDistributorMockRecorder 19 | } 20 | 21 | // MockTaskDistributorMockRecorder is the mock recorder for MockTaskDistributor 22 | type MockTaskDistributorMockRecorder struct { 23 | mock *MockTaskDistributor 24 | } 25 | 26 | // NewMockTaskDistributor creates a new mock instance 27 | func NewMockTaskDistributor(ctrl *gomock.Controller) *MockTaskDistributor { 28 | mock := &MockTaskDistributor{ctrl: ctrl} 29 | mock.recorder = &MockTaskDistributorMockRecorder{mock} 30 | return mock 31 | } 32 | 33 | // EXPECT returns an object that allows the caller to indicate expected use 34 | func (m *MockTaskDistributor) EXPECT() *MockTaskDistributorMockRecorder { 
35 | return m.recorder 36 | } 37 | 38 | // DistributeTaskSendVerifyEmail mocks base method 39 | func (m *MockTaskDistributor) DistributeTaskSendVerifyEmail(arg0 context.Context, arg1 *worker.PayloadSendVerifyEmail, arg2 ...asynq.Option) error { 40 | m.ctrl.T.Helper() 41 | varargs := []interface{}{arg0, arg1} 42 | for _, a := range arg2 { 43 | varargs = append(varargs, a) 44 | } 45 | ret := m.ctrl.Call(m, "DistributeTaskSendVerifyEmail", varargs...) 46 | ret0, _ := ret[0].(error) 47 | return ret0 48 | } 49 | 50 | // DistributeTaskSendVerifyEmail indicates an expected call of DistributeTaskSendVerifyEmail 51 | func (mr *MockTaskDistributorMockRecorder) DistributeTaskSendVerifyEmail(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { 52 | mr.mock.ctrl.T.Helper() 53 | varargs := append([]interface{}{arg0, arg1}, arg2...) 54 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DistributeTaskSendVerifyEmail", reflect.TypeOf((*MockTaskDistributor)(nil).DistributeTaskSendVerifyEmail), varargs...) 
55 | } 56 | -------------------------------------------------------------------------------- /worker/processor.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/go-redis/redis/v8" 7 | "github.com/hibiken/asynq" 8 | "github.com/rs/zerolog/log" 9 | db "github.com/techschool/simplebank/db/sqlc" 10 | "github.com/techschool/simplebank/mail" 11 | ) 12 | 13 | const ( 14 | QueueCritical = "critical" 15 | QueueDefault = "default" 16 | ) 17 | 18 | type TaskProcessor interface { 19 | Start() error 20 | Shutdown() 21 | ProcessTaskSendVerifyEmail(ctx context.Context, task *asynq.Task) error 22 | } 23 | 24 | type RedisTaskProcessor struct { 25 | server *asynq.Server 26 | store db.Store 27 | mailer mail.EmailSender 28 | } 29 | 30 | func NewRedisTaskProcessor(redisOpt asynq.RedisClientOpt, store db.Store, mailer mail.EmailSender) TaskProcessor { 31 | logger := NewLogger() 32 | redis.SetLogger(logger) 33 | 34 | server := asynq.NewServer( 35 | redisOpt, 36 | asynq.Config{ 37 | Queues: map[string]int{ 38 | QueueCritical: 10, 39 | QueueDefault: 5, 40 | }, 41 | ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) { 42 | log.Error().Err(err).Str("type", task.Type()). 
43 | Bytes("payload", task.Payload()).Msg("process task failed") 44 | }), 45 | Logger: logger, 46 | }, 47 | ) 48 | 49 | return &RedisTaskProcessor{ 50 | server: server, 51 | store: store, 52 | mailer: mailer, 53 | } 54 | } 55 | 56 | func (processor *RedisTaskProcessor) Start() error { 57 | mux := asynq.NewServeMux() 58 | 59 | mux.HandleFunc(TaskSendVerifyEmail, processor.ProcessTaskSendVerifyEmail) 60 | 61 | return processor.server.Start(mux) 62 | } 63 | 64 | func (processor *RedisTaskProcessor) Shutdown() { 65 | processor.server.Shutdown() 66 | } 67 | -------------------------------------------------------------------------------- /worker/task_send_verify_email.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/hibiken/asynq" 9 | "github.com/rs/zerolog/log" 10 | db "github.com/techschool/simplebank/db/sqlc" 11 | "github.com/techschool/simplebank/util" 12 | ) 13 | 14 | const TaskSendVerifyEmail = "task:send_verify_email" 15 | 16 | type PayloadSendVerifyEmail struct { 17 | Username string `json:"username"` 18 | } 19 | 20 | func (distributor *RedisTaskDistributor) DistributeTaskSendVerifyEmail( 21 | ctx context.Context, 22 | payload *PayloadSendVerifyEmail, 23 | opts ...asynq.Option, 24 | ) error { 25 | jsonPayload, err := json.Marshal(payload) 26 | if err != nil { 27 | return fmt.Errorf("failed to marshal task payload: %w", err) 28 | } 29 | 30 | task := asynq.NewTask(TaskSendVerifyEmail, jsonPayload, opts...) 31 | info, err := distributor.client.EnqueueContext(ctx, task) 32 | if err != nil { 33 | return fmt.Errorf("failed to enqueue task: %w", err) 34 | } 35 | 36 | log.Info().Str("type", task.Type()).Bytes("payload", task.Payload()). 
37 | Str("queue", info.Queue).Int("max_retry", info.MaxRetry).Msg("enqueued task") 38 | return nil 39 | } 40 | 41 | func (processor *RedisTaskProcessor) ProcessTaskSendVerifyEmail(ctx context.Context, task *asynq.Task) error { 42 | var payload PayloadSendVerifyEmail 43 | if err := json.Unmarshal(task.Payload(), &payload); err != nil { 44 | return fmt.Errorf("failed to unmarshal payload: %w", asynq.SkipRetry) 45 | } 46 | 47 | user, err := processor.store.GetUser(ctx, payload.Username) 48 | if err != nil { 49 | return fmt.Errorf("failed to get user: %w", err) 50 | } 51 | 52 | verifyEmail, err := processor.store.CreateVerifyEmail(ctx, db.CreateVerifyEmailParams{ 53 | Username: user.Username, 54 | Email: user.Email, 55 | SecretCode: util.RandomString(32), 56 | }) 57 | if err != nil { 58 | return fmt.Errorf("failed to create verify email: %w", err) 59 | } 60 | 61 | subject := "Welcome to Simple Bank" 62 | // TODO: replace this URL with an environment variable that points to a front-end page 63 | verifyUrl := fmt.Sprintf("http://localhost:8080/v1/verify_email?email_id=%d&secret_code=%s", 64 | verifyEmail.ID, verifyEmail.SecretCode) 65 | content := fmt.Sprintf(`Hello %s,