├── .github
└── workflows
│ └── build.yml
├── .gitignore
├── .golangci.yml
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── api
└── swagger.yaml
├── auth-config.yaml
├── cmd
├── root.go
└── root_test.go
├── docker-compose-traefik.yaml
├── docker-compose.yaml
├── docs
├── .dockerignore
├── .gitignore
└── architecture-diagram.xml
├── go.mod
├── go.sum
├── main.go
├── pkg
├── auth
│ ├── callbacks
│ │ └── callback.go
│ ├── constants
│ │ └── constants.go
│ ├── errors
│ │ └── errors.go
│ ├── flow.go
│ ├── flow_test.go
│ ├── modules
│ │ ├── credentials.go
│ │ ├── credentials_test.go
│ │ ├── hydra.go
│ │ ├── hydra_test.go
│ │ ├── kerberos.go
│ │ ├── kerberos_test.go
│ │ ├── login.go
│ │ ├── module.go
│ │ ├── module_test.go
│ │ ├── otp.go
│ │ ├── otp
│ │ │ ├── email.go
│ │ │ ├── email_test.go
│ │ │ └── sender.go
│ │ ├── otp_test.go
│ │ ├── qr.go
│ │ ├── qr_test.go
│ │ ├── registration.go
│ │ └── registration_test.go
│ └── state
│ │ └── state.go
├── config
│ ├── config.go
│ └── config_test.go
├── controller
│ ├── auth.go
│ ├── auth_test.go
│ ├── controller_test.go
│ ├── passwordless.go
│ ├── passwordless_test.go
│ ├── session.go
│ └── session_test.go
├── crypt
│ └── crypt.go
├── log
│ └── logger.go
├── middleware
│ ├── authenticated.go
│ ├── authenticated_test.go
│ └── requesturi.go
├── server
│ ├── server.go
│ └── server_test.go
├── session
│ ├── config.go
│ ├── service.go
│ ├── session.go
│ ├── session_repository.go
│ ├── session_repository_mongo.go
│ ├── session_repository_mongo_test.go
│ └── session_rest.go
└── user
│ ├── config.go
│ ├── service.go
│ ├── user.go
│ ├── user_repository.go
│ ├── user_repository_ldap.go
│ ├── user_repository_ldap_test.go
│ ├── user_repository_mongo.go
│ ├── user_repository_mongo_test.go
│ └── user_repository_rest.go
├── test
├── auth-config-dev.yaml
├── auth-config-hydra.yaml
├── auth-config-otp.yaml
├── docker-compose-ldap.yaml
├── docker-compose-mongodb-ldap.yaml
├── docker-compose-mongodb.yaml
├── docker-compose-pgsql.yaml
└── integration
│ ├── otp_auth_test.go
│ └── utils_test.go
├── traefik
├── conf
│ └── config.yml
├── plugins-local
│ └── src
│ │ └── github.com
│ │ └── maximthomas
│ │ └── gortas_traefik_plugin
│ │ ├── .traefik.yml
│ │ ├── LICENSE
│ │ ├── go.mod
│ │ ├── gortas_plugin.go
│ │ └── gortas_plugin_test.go
└── traefik.yml
└── ui.Dockerfile
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on:
4 | push:
5 | pull_request:
6 | branches: [master]
7 |
8 | jobs:
9 | golangci:
10 | name: lint
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v3
14 |
15 | - name: golangci-lint
16 | uses: golangci/golangci-lint-action@v3
17 | with:
18 | version: latest
19 |
20 | build:
21 | runs-on: ubuntu-latest
22 | steps:
23 | - uses: actions/checkout@v3
24 |
25 | - name: Set up Go
26 | uses: actions/setup-go@v3
27 | with:
28 | go-version: 1.19
29 |
30 | - name: Build project
31 | run: make default
32 |
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, build with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | .idea/
15 | .vscode/
16 | .DS_Store
17 | *.code-workspace
18 |
19 | gortas
20 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | run:
2 | tests: false
3 |
4 | linters-settings:
5 | dupl:
6 | threshold: 100
7 | funlen:
8 | lines: 100
9 | statements: 50
10 | goconst:
11 | min-len: 2
12 | min-occurrences: 3
13 | gocritic:
14 | enabled-tags:
15 | - diagnostic
16 | - experimental
17 | - opinionated
18 | - performance
19 | - style
20 | disabled-checks:
21 | - dupImport # https://github.com/go-critic/go-critic/issues/845
22 | - ifElseChain
23 | - octalLiteral
24 | - whyNoLint
25 | - commentedOutCode
26 | # revive:
27 | # -
28 | linters:
29 | disable-all: true
30 | enable:
31 | # - dupl
32 | - goconst
33 | - gocritic
34 | - goimports
35 | - errcheck
36 | - gosimple
37 | - ineffassign
38 | - staticcheck
39 | - typecheck
40 | - revive
41 | - govet
42 | #- funlen
43 | #- gosec
44 | #- unused
45 |
46 | exclude-rules:
47 | - path: _test\.go
48 | linters:
49 | - gocyclo
50 | - errcheck
51 | - dupl
52 | - gosec
53 | - funlen
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM golang:1.20-alpine as builder
2 | RUN mkdir /build
3 | ADD . /build/
4 | WORKDIR /build
5 | RUN go build -o gortas .
6 |
7 | FROM alpine
8 | RUN adduser -S -D -H -h /app appuser
9 | USER appuser
10 | COPY --from=builder /build/gortas /usr/bin/gortas
11 | ADD ./auth-config.yaml /app/config/auth-config.yaml
12 | WORKDIR /app
13 | ENTRYPOINT ["gortas"]
14 | CMD ["--config", "./config/auth-config.yaml"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | GORTAS_RELEASE_TAGS := $(shell go list -f ':{{join (context.ReleaseTags) ":"}}:' runtime)
2 |
3 | # Only use the `-race` flag on newer versions of Go (version 1.3 and newer)
4 | ifeq (,$(findstring :go1.3:,$(GO_RELEASE_TAGS)))
5 | RACE_FLAG :=
6 | else
7 | RACE_FLAG := -race -cpu 1,2,4
8 | endif
9 |
10 | default: build quicktest
11 |
12 | install:
13 | go get -t -v ./...
14 |
15 | build:
16 | go build -v ./...
17 |
18 | test:
19 | go test -v $(RACE_FLAG) -cover ./...
20 |
21 | quicktest:
22 | go test ./...
23 |
24 | vet:
25 | go vet ./...
26 |
27 | docker:
28 | docker build -f Dockerfile -t maximthomas/gortas .
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gortas
2 |
3 |
4 |
5 |
6 | **Gortas** (Golang Authentication Service) is a flexible API based authentication service, allows adding authentication to your site or service with minimum effort.
7 | **Gortas** supports multiple authentication methods across various data sources. You can authenticate against your Active Directory or other LDAP user directory or MongoDB.
8 |
9 | It allows building complex authentication processes with various steps and different authentication methods.
10 |
11 | For example, you can build login and password authentication with an SMS confirmation code
12 |
13 | ## Deeper Into the Details
14 |
15 | ### Supported Authentication methods
16 | * Username and password - authenticates against an existing user datastore
17 | * Registration - creates a user account in a user data store for further authentication
18 | * Kerberos - uses Kerberos authentication
19 | * OTP - one-time password sent via email or SMS
20 |
21 | It is possible to develop custom authentication methods.
22 |
23 | ### Supported Data Sources
24 | * LDAP
25 | * NoSQL
26 | * MongoDB
27 | * SQL databases (in development)
28 |
29 | ## Main concepts
30 |
31 | With **Gortas** you can build an authentication system with any desired complexity across different data sources simultaneously.
32 |
33 | ### Realm
34 |
35 | There could be different realms in Gortas - for example, the `staff` realm for employees and the `clients` realm for clients.
36 | All realms use their own user data stores. For example, for staff users, we will use an enterprise LDAP user directory, for clients we could use another database, for example, MongoDB.
37 | Every realm contains authentication modules, authentication chains, and user data store.
38 |
39 | ### Authentication Module
40 |
41 | Single authentication module, responsible for authentication or authorization step.
42 | For example - prompt username and password or send and verify a one-time password.
43 |
44 | ### Authentication FLow
45 |
46 | Authentication modules are organized in flows.
47 | Every authentication chain is a sequence of authentication modules to orchestrate complex authentication processes.
48 | For example, we have two modules: Login module - prompts a user to provide login and password, and OTP module - sends SMS with a one-time password to the user.
49 |
50 | When a user tries to authenticate he will be prompted to enter login and password.
51 | If the credentials are correct authentication service sends OTP via SMS and prompts the user to enter the one-time password as a second authentication factor.
52 | On the other hand, we can line up Kerberos and login and password in the same chain.
53 | So if a user was not authenticated in Kerberos automatically, he will be prompted for username and password
54 |
55 | ## Quick Start with docker-compose
56 |
57 | Clone **Gortas** repository
58 |
59 | ```
60 | git clone https://github.com/maximthomas/gortas.git
61 | ```
62 |
63 | Then go to `gortas` directory and run `docker-compose`
64 |
65 | ```
66 | docker-compose up
67 | ```
68 |
69 | This command will create services:
70 | 1. `gortas` - authentication API service
71 | 1. `mongo` - Mongo database for user and session storage
72 |
73 | Further reading: [Quick Start](https://github.com/maximthomas/gortas/wiki/Quick-Start)
--------------------------------------------------------------------------------
/api/swagger.yaml:
--------------------------------------------------------------------------------
1 | openapi: 3.0.1
2 | info:
3 | title: Gortas Authentication Service
4 | description: ''
5 | termsOfService: 'TODO'
6 | contact:
7 | email: 'maxim.thomas@gmail.com'
8 | license:
9 | name: Apache 2.0
10 | url: http://www.apache.org/licenses/LICENSE-2.0.html
11 | version: 1.0.0
12 | servers:
13 | - url: https://gortas:8443/v1/
14 | tags:
15 | - name: authentication
16 | description: Authentication service
17 | externalDocs:
18 | description: Find out more
19 | url: http://swagger.io
20 | paths:
21 | /auth/{realm}/{flow}:
22 | get:
23 | tags:
24 | - authentication
25 | summary: start authentication
26 | operationId: startAuth
27 | parameters:
28 | - name: realm
29 | in: path
30 | description: Realm to authenticate
31 | required: true
32 | schema:
33 | type: string
34 | - name: flow
35 | in: path
36 | description: Authentication service
37 | required: true
38 | schema:
39 | type: string
40 | responses:
41 | 200:
42 | description: successful operation
43 | content:
44 | application/json:
45 | schema:
46 | $ref: '#/components/schemas/CredentialsRequest'
47 | post:
48 | tags:
49 | - authentication
50 | summary: submit credentials
51 | operationId: submit authentication data
52 | parameters:
53 | - name: realm
54 | in: path
55 | description: Realm to authenticate
56 | required: true
57 | schema:
58 | type: string
59 | - name: flow
60 | in: path
61 | description: Service ti authenticate to authenticate
62 | required: true
63 | schema:
64 | type: string
65 | requestBody:
66 | description: Credentials data
67 | content:
68 | application/json:
69 | schema:
70 | $ref: '#/components/schemas/CredentialsResponse'
71 | required: true
72 | responses:
73 | 405:
74 | description: Invalid input
75 | content: {}
76 |
77 | components:
78 | schemas:
79 | Credential:
80 | type: object
81 | properties:
82 | type:
83 | type: string
84 | value:
85 | type: string
86 | validation:
87 | type: string
88 | description: 'Regular expression to validate field'
89 | required:
90 | type: boolean
91 | properties:
92 | additionalProperties:
93 | type: object
94 | error:
95 | type: string
96 | description: 'Error message if input is incorrect'
97 | required:
98 | - type
99 | - value
100 | CredentialsRequest:
101 | type: object
102 | properties:
103 | flowId:
104 | type: string
105 | module:
106 | type: string
107 | credentials:
108 | type: object
109 | additionalProperties:
110 | $ref: '#/components/schemas/Credential'
111 | CredentialsResponse:
112 | allOf:
113 | - $ref: '#/components/schemas/CredentialsRequest'
114 |
--------------------------------------------------------------------------------
/auth-config.yaml:
--------------------------------------------------------------------------------
1 | flows:
2 | login:
3 | modules:
4 | - id: "login"
5 | type: "login"
6 | properties:
7 | registration:
8 | modules:
9 | - id: "registration"
10 | type: "registration"
11 | properties:
12 | primaryField:
13 | dataStore: "email"
14 | name: "email"
15 | prompt: "Email"
16 | additionalFields:
17 | - dataStore: "name"
18 | prompt: "Name"
19 | name: "name"
20 |
21 | userDataStore:
22 | type: "mongodb"
23 | properties:
24 | url: "mongodb://root:changeme@localhost:27017"
25 | database: "users"
26 | collection: "users"
27 | userAttributes:
28 | - "name"
29 | - "email"
30 |
31 | session:
32 | type: "stateful" #could be also stateful
33 | expires: 60000
34 | dataStore:
35 | type: "mongo"
36 | properties:
37 | url: "mongodb://root:changeme@localhost:27017"
38 | database: "session"
39 | collection: "sessions"
40 |
41 | server:
42 | cors:
43 | allowedOrigins:
44 | - http://localhost:3000
45 |
46 |
47 |
--------------------------------------------------------------------------------
/cmd/root.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "strings"
7 |
8 | "github.com/maximthomas/gortas/pkg/config"
9 | "github.com/maximthomas/gortas/pkg/server"
10 |
11 | "github.com/mitchellh/go-homedir"
12 | "github.com/spf13/viper"
13 |
14 | "github.com/spf13/cobra"
15 | )
16 |
17 | var (
18 | cfgFile string
19 | rootCmd = &cobra.Command{
20 | Use: "gortas",
21 | Short: "Gortas is a golang authentication service",
22 | Run: func(cmd *cobra.Command, args []string) {
23 | server.RunServer()
24 | },
25 | }
26 |
27 | versionCmd = &cobra.Command{
28 | Use: "version",
29 | Short: "Shown version",
30 | Run: func(cmd *cobra.Command, args []string) {
31 | fmt.Println("0.0.1")
32 | },
33 | }
34 | )
35 |
36 | func Execute() {
37 | if err := rootCmd.Execute(); err != nil {
38 | fmt.Println(err)
39 | os.Exit(1)
40 | }
41 | }
42 |
43 | func init() {
44 | cobra.OnInitialize(initConfig)
45 | rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/auth-config.yaml)")
46 | rootCmd.AddCommand(versionCmd)
47 | }
48 |
49 | func er(msg interface{}) {
50 | fmt.Println("Error:", msg)
51 | os.Exit(1)
52 | }
53 |
54 | func initConfig() {
55 | fmt.Println(os.Getwd())
56 | if cfgFile != "" {
57 | // Use config file from the flag.
58 | viper.SetConfigFile(cfgFile)
59 | } else {
60 | // Find home directory.
61 | home, err := homedir.Dir()
62 | if err != nil {
63 | er(err)
64 | }
65 |
66 | // Search config in home directory with name ".cobra" (without extension).
67 | viper.AddConfigPath(home)
68 | viper.SetConfigName("auth-config")
69 | }
70 |
71 | viper.AutomaticEnv()
72 | viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
73 |
74 | if err := viper.ReadInConfig(); err == nil {
75 | fmt.Println("Using config file:", viper.ConfigFileUsed())
76 | err = config.InitConfig()
77 | if err != nil {
78 | er(err)
79 | }
80 | } else {
81 | er(err)
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/cmd/root_test.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/maximthomas/gortas/pkg/config"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestExecute(t *testing.T) {
12 | args := []string{"version", "--config", "../test/auth-config-dev.yaml"}
13 | rootCmd.SetArgs(args)
14 | err := rootCmd.Execute()
15 | assert.NoError(t, err)
16 | conf := config.GetConfig()
17 | assert.True(t, len(conf.Flows) > 0)
18 |
19 | }
20 |
--------------------------------------------------------------------------------
/docker-compose-traefik.yaml:
--------------------------------------------------------------------------------
1 | version: '3.7'
2 |
3 | services:
4 | gortas:
5 | image: maximthomas/gortas:latest
6 | ports:
7 | - "8080"
8 | volumes:
9 | - '$PWD/test/auth-config-dev.yaml:/app/config/auth-config.yaml'
10 |
11 | traefik:
12 | image: traefik:latest
13 | volumes:
14 | - '$PWD/traefik/traefik.yml:/etc/traefik/traefik.yml'
15 | - '$PWD/traefik/conf:/etc/traefik/conf'
16 | - '$PWD/traefik/plugins-local:/plugins-local'
17 | - '/var/run/docker.sock:/var/run/docker.sock'
18 | ports:
19 | - "8081:8080"
20 | - "8080:80"
21 |
22 | sample-service:
23 | image: maximthomas/sample-service
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: '3.7'
2 |
3 | services:
4 | gortas:
5 | build:
6 | context: .
7 | image: maximthomas/gortas:latest
8 | ports:
9 | - "8080:8080"
10 | depends_on:
11 | - mongo
12 | environment:
13 | SESSION_DATASTORE_PROPERTIES_URL: "mongodb://root:changeme@mongo:27017"
14 | AUTHENTICATION_REALMS_USERS_USERDATASTORE_PROPERTIES_URL: "mongodb://root:changeme@mongo:27017"
15 |
16 | mongo:
17 | image: mongo:latest
18 | restart: always
19 | ports:
20 | - "27017:27017"
21 | environment:
22 | MONGO_INITDB_ROOT_USERNAME: root
23 | MONGO_INITDB_ROOT_PASSWORD: changeme
24 |
--------------------------------------------------------------------------------
/docs/.dockerignore:
--------------------------------------------------------------------------------
1 | */node_modules
2 | *.log
3 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
3 | node_modules
4 |
5 | lib/core/metadata.js
6 | lib/core/MetadataBlog.js
7 |
8 | website/translated_docs
9 | website/build/
10 | website/yarn.lock
11 | website/node_modules
12 | website/i18n/*
13 |
--------------------------------------------------------------------------------
/docs/architecture-diagram.xml:
--------------------------------------------------------------------------------
1 | 7Vldb5swFP01eeyEDYbmsW36MW3ruqZTu725cJegEhwZp0n262eCDcFOS1olkEqT8mAfro19fHw/SM89mywuOZ2Ov7EIkh52okXPHfQwRh7GvfznRMsCCXwFjHgcKaMKGMZ/QYGOQmdxBFnNUDCWiHhaB0OWphCKGkY5Z/O62R+W1N86pSOwgGFIExu9jyMxLtBjHFT4FcSjsX4z8vvFkwnVxmon2ZhGbL4Guec994wzJorWZHEGSU6e5qUYd/HC03JhHFKxzYBr5+qJ/JjcoOP7Ob/8/ZmSpy9H6jCeaTJTG749H96pBYulZoGzWRpBPpHTc0/n41jAcErD/OlcnrvExmKSyB6STTUlcAGLF9eKSgakdIBNQPClNNEDiCJtqWlV/Xl1BtpkXKNfgVQd+6icumJGNhQ5byDKtYg6mcmVYGcI/DmWTHRNGd6CMuS0yhmxOLtI5A2QnAkqYJ07P5GLOX3ksjXKW7eQCcY3PLiEFLgcfA3zzhn3nDrj/gbC8QbC/X3x7Vt8WxxBGp3kXlH2woRmWRxKLjJBubDhNbbq1MIiFg/qSd7+leOfMFHdwWLNbrDUnVRu8WFl6aBAA8VQT3eroatebewN8FjyBFyBxd4gsvy3cX5y/2zGQ2i+3JKGEYgmQdt6aHBKGuOQUBE/15e7SQTqDTcslhup5IbqcisDi56i2KYatR4IzIl8w1MExkQFD9ZEK02W236/TANLpkOQkmOpBAenlmRl3JzmzXCZxFKE3G2+3I+FXL8+lgANn0YrEX+fCTkNKDwr8g1EduMRSN9gFhPbJaANEjnel0tAnkX2z0xeoQOKW54h687dKLIFqjj72OrEHqkx7fY3MN2uOvttRawq+JD10IMaAs8OY4y+iY1Bxu0yyBwZWaTnvjXIaHGZotlzUNEe4mClpJMmNUynTa/nTLKzx7RHO7pGSRba7UqTyMhXPNPzb5v4YCOD8szKZ98aRa9kPocSj82qBvU35DCtxmN9kzuqa4Jt65p6TUPaCy3bli8vnXxL9YuR6Hnvrl+MLNttuX7Bdk7dph73I8d9l9jlx9smkQadanRXoYZ0HWrsb2//NdqkUfIRJEoCQ1neOyXqY8ON9luWqF1mH45Ed+n5tlRVp3UfMsRQZoFvVRUy5OmSXalKdqs/ywrz6i9H9/wf
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/maximthomas/gortas
2 |
3 | go 1.20
4 |
5 | require (
6 | github.com/dgrijalva/jwt-go v3.2.0+incompatible
7 | github.com/fsnotify/fsnotify v1.6.0 // indirect
8 | github.com/gin-gonic/gin v1.9.1
9 | github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
10 | github.com/go-ldap/ldap/v3 v3.4.4
11 | github.com/go-playground/validator/v10 v10.14.0 // indirect
12 | github.com/golang/snappy v0.0.4 // indirect
13 | github.com/google/uuid v1.3.0
14 | github.com/jcmturner/gokrb5/v8 v8.4.4
15 | github.com/klauspost/compress v1.16.5 // indirect
16 | github.com/leodido/go-urn v1.2.4 // indirect
17 | github.com/magiconair/properties v1.8.7 // indirect
18 | github.com/mitchellh/go-homedir v1.1.0
19 | github.com/mitchellh/mapstructure v1.5.0
20 | github.com/pkg/errors v0.9.1
21 | github.com/sirupsen/logrus v1.9.0
22 | github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
23 | github.com/spf13/afero v1.9.5 // indirect
24 | github.com/spf13/cast v1.5.0 // indirect
25 | github.com/spf13/cobra v1.7.0
26 | github.com/spf13/jwalterweatherman v1.1.0 // indirect
27 | github.com/spf13/viper v1.15.0
28 | github.com/stretchr/testify v1.8.3
29 | github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect
30 | go.mongodb.org/mongo-driver v1.11.4
31 | golang.org/x/crypto v0.31.0
32 | golang.org/x/sync v0.10.0 // indirect
33 | golang.org/x/sys v0.28.0 // indirect
34 | golang.org/x/text v0.21.0 // indirect
35 | gopkg.in/ini.v1 v1.67.0 // indirect
36 | gopkg.in/yaml.v3 v3.0.1 // indirect
37 | )
38 |
39 | require (
40 | github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
41 | github.com/davecgh/go-spew v1.1.1 // indirect
42 | github.com/gin-contrib/sse v0.1.0 // indirect
43 | github.com/go-playground/locales v0.14.1 // indirect
44 | github.com/go-playground/universal-translator v0.18.1 // indirect
45 | github.com/hashicorp/go-uuid v1.0.3 // indirect
46 | github.com/hashicorp/hcl v1.0.0 // indirect
47 | github.com/inconshreveable/mousetrap v1.1.0 // indirect
48 | github.com/jcmturner/aescts/v2 v2.0.0 // indirect
49 | github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
50 | github.com/jcmturner/gofork v1.7.6 // indirect
51 | github.com/jcmturner/goidentity/v6 v6.0.1 // indirect
52 | github.com/jcmturner/rpc/v2 v2.0.3 // indirect
53 | github.com/json-iterator/go v1.1.12 // indirect
54 | github.com/mattn/go-isatty v0.0.19 // indirect
55 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
56 | github.com/modern-go/reflect2 v1.0.2 // indirect
57 | github.com/pmezard/go-difflib v1.0.0 // indirect
58 | github.com/rs/cors/wrapper/gin v0.0.0-20230301160956-5c2b877d2a03
59 | github.com/spf13/pflag v1.0.5 // indirect
60 | github.com/subosito/gotenv v1.4.2 // indirect
61 | github.com/ugorji/go/codec v1.2.11 // indirect
62 | github.com/xdg-go/pbkdf2 v1.0.0 // indirect
63 | github.com/xdg-go/scram v1.1.2 // indirect
64 | github.com/xdg-go/stringprep v1.0.4 // indirect
65 | github.com/xhit/go-simple-mail/v2 v2.13.0
66 | golang.org/x/net v0.23.0 // indirect
67 | google.golang.org/protobuf v1.33.0 // indirect
68 | )
69 |
70 | require (
71 | github.com/bytedance/sonic v1.9.1 // indirect
72 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
73 | github.com/gabriel-vasile/mimetype v1.4.2 // indirect
74 | github.com/go-test/deep v1.1.0 // indirect
75 | github.com/goccy/go-json v0.10.2 // indirect
76 | github.com/klauspost/cpuid/v2 v2.2.4 // indirect
77 | github.com/montanaflynn/stats v0.7.0 // indirect
78 | github.com/pelletier/go-toml/v2 v2.0.8 // indirect
79 | github.com/rs/cors v1.11.0 // indirect
80 | github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 // indirect
81 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
82 | golang.org/x/arch v0.3.0 // indirect
83 | )
84 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/maximthomas/gortas/cmd"
5 | )
6 |
7 | func main() {
8 | cmd.Execute()
9 | }
10 |
--------------------------------------------------------------------------------
/pkg/auth/callbacks/callback.go:
--------------------------------------------------------------------------------
1 | package callbacks
2 |
3 | const (
4 | TypeText = "text"
5 | TypePassword = "password"
6 | TypeImage = "image"
7 | TypeHTTPStatus = "httpstatus"
8 | TypeAutoSubmit = "autosubmit"
9 | TypeOptions = "options"
10 | TypeActions = "actions"
11 | )
12 |
13 | type Callback struct {
14 | Name string `json:"name,omitempty"`
15 | Type string `json:"type"`
16 | Value string `json:"value"`
17 | Prompt string `json:"prompt,omitempty"`
18 | Validation string `json:"validation,omitempty"`
19 | Required bool `json:"required,omitempty"`
20 | Options []string `json:"options,omitempty"`
21 | Properties map[string]string `json:"properties,omitempty"` // TODO v2 move to map[string]interface{}
22 | Error string `json:"error,omitempty"`
23 | }
24 |
25 | // Request TODO move to more appropriate package
26 | type Request struct {
27 | Module string `json:"module,omitempty"`
28 | Callbacks []Callback `json:"callbacks,omitempty"`
29 | FlowID string `json:"flowId,omitempty"`
30 | }
31 |
32 | // Response TODO move to more appropriate package
33 | type Response struct {
34 | Module string `json:"module,omitempty"`
35 | Callbacks []Callback `json:"callbacks,omitempty"`
36 | Token string `json:"token,omitempty"`
37 | Type string `json:"type,omitempty"` // returns token type
38 | FlowID string `json:"flowId,omitempty"` // TODO add error
39 | }
40 |
--------------------------------------------------------------------------------
/pkg/auth/constants/constants.go:
--------------------------------------------------------------------------------
1 | package constants
2 |
3 | const CriteriaSufficient = "sufficient"
4 |
5 | const FlowStateSessionProperty = "fs"
6 |
--------------------------------------------------------------------------------
/pkg/auth/errors/errors.go:
--------------------------------------------------------------------------------
1 | package errors
2 |
3 | type AuthFailed struct {
4 | msg string
5 | }
6 |
7 | func NewAuthFailed(msg string) *AuthFailed {
8 | return &AuthFailed{msg: msg}
9 | }
10 |
11 | func (e *AuthFailed) Error() string { return e.msg }
12 |
--------------------------------------------------------------------------------
/pkg/auth/flow.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "net/http"
7 |
8 | "github.com/google/uuid"
9 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
10 | "github.com/maximthomas/gortas/pkg/auth/constants"
11 | "github.com/maximthomas/gortas/pkg/auth/modules"
12 | "github.com/maximthomas/gortas/pkg/auth/state"
13 | "github.com/maximthomas/gortas/pkg/config"
14 | "github.com/maximthomas/gortas/pkg/log"
15 | "github.com/maximthomas/gortas/pkg/session"
16 | "github.com/pkg/errors"
17 | "github.com/sirupsen/logrus"
18 |
19 | autherrors "github.com/maximthomas/gortas/pkg/auth/errors"
20 | )
21 |
22 | type FlowProcessor interface {
23 | Process(flowName string, cbReq callbacks.Request, r *http.Request, w http.ResponseWriter) (cbResp callbacks.Response, err error)
24 | }
25 |
26 | type flowProcessor struct {
27 | logger logrus.FieldLogger
28 | }
29 |
30 | func NewFlowProcessor() FlowProcessor {
31 | return &flowProcessor{
32 | logger: log.WithField("module", "FlowProcessor"),
33 | }
34 | }
35 |
36 | // Process responsible for the authentication process
37 | // goes through flow state authentication modules, requests and processes callbacks
38 | func (f *flowProcessor) Process(flowName string, cbReq callbacks.Request, r *http.Request, w http.ResponseWriter) (cbResp callbacks.Response, err error) {
39 |
40 | fs, err := f.getFlowState(flowName, cbReq.FlowID)
41 | if err != nil {
42 | return cbResp, fmt.Errorf("Process: error getting flow state %w", err)
43 | }
44 | // TODO extract process callbacks to a separate function
45 |
46 | inCbs := cbReq.Callbacks
47 | var outCbs []callbacks.Callback
48 | modules:
49 | for moduleIndex, moduleInfo := range fs.Modules {
50 | switch moduleInfo.Status {
51 | // TODO v2 match module names in a callback request
52 | case state.Start, state.InProgress:
53 | var instance modules.AuthModule
54 | instance, err = modules.GetAuthModule(moduleInfo, r, w)
55 | if err != nil {
56 | return cbResp, fmt.Errorf("Process: error getting auth module %v %w", moduleInfo, err)
57 | }
58 | var newState state.ModuleStatus
59 | // if module is the first in the flow, then pass callbacks directly to the module
60 | if (len(cbReq.Callbacks) == 0 || moduleIndex > 0) && moduleInfo.Status == state.Start {
61 | newState, outCbs, err = instance.Process(&fs)
62 | if err != nil {
63 | return cbResp, err
64 | }
65 | } else {
66 | if err != nil {
67 | f.logger.Error("error parsing request body: ", err)
68 | return cbResp, errors.New("bad request")
69 | }
70 | err = instance.ValidateCallbacks(inCbs)
71 | if err != nil {
72 | return cbResp, err
73 | }
74 | newState, outCbs, err = instance.ProcessCallbacks(inCbs, &fs)
75 | if err != nil {
76 | return cbResp, err
77 | }
78 | }
79 |
80 | moduleInfo.Status = newState
81 |
82 | fs.UpdateModuleInfo(moduleIndex, moduleInfo)
83 | err = f.updateFlowState(&fs)
84 | if err != nil {
85 | return cbResp, errors.Wrap(err, "error update flowstate")
86 | }
87 |
88 | switch moduleInfo.Status {
89 | case state.InProgress, state.Start:
90 | cbResp = callbacks.Response{
91 | Callbacks: outCbs,
92 | Module: moduleInfo.ID,
93 | FlowID: fs.ID,
94 | }
95 | return cbResp, err
96 | case state.Pass:
97 | if moduleInfo.Criteria == constants.CriteriaSufficient { // TODO v2 refactor move to function
98 | break modules
99 | }
100 | continue
101 | case state.Fail:
102 | if moduleInfo.Criteria == constants.CriteriaSufficient { // TODO v2 refactor move to function
103 | continue
104 | }
105 | return cbResp, autherrors.NewAuthFailed("auth failed")
106 | }
107 | }
108 | }
109 | authSucceeded := true
110 | for _, moduleInfo := range fs.Modules {
111 |
112 | if moduleInfo.Criteria == constants.CriteriaSufficient { // TODO v2 refactor move to function
113 | if moduleInfo.Status == state.Pass {
114 | break
115 | }
116 | } else if moduleInfo.Status != state.Pass {
117 | authSucceeded = false
118 | break
119 | }
120 | }
121 |
122 | if authSucceeded {
123 | for _, moduleInfo := range fs.Modules {
124 | var am modules.AuthModule
125 | am, err = modules.GetAuthModule(moduleInfo, r, w)
126 | if err != nil {
127 | return cbResp, errors.Wrap(err, "error getting auth module for postprocess")
128 | }
129 | err = am.PostProcess(&fs)
130 | if err != nil {
131 | return cbResp, errors.Wrap(err, "error while postprocess")
132 | }
133 | }
134 |
135 | var sessID string
136 | sessID, err = f.createSession(&fs)
137 | if err != nil {
138 | return cbResp, errors.Wrap(err, "error creating session")
139 | }
140 | cbResp = callbacks.Response{
141 | Token: sessID,
142 | Type: "Bearer",
143 | }
144 | err = session.GetSessionService().DeleteSession(fs.ID)
145 | if err != nil {
146 | f.logger.Warnf("error clearing session %s %v", fs.ID, err)
147 | }
148 |
149 | return cbResp, err
150 | }
151 |
152 | return cbResp, err
153 | }
154 |
155 | func (f *flowProcessor) createSession(fs *state.FlowState) (sessID string, err error) {
156 | if fs.UserID == "" {
157 | return sessID, errors.New("user id is not set")
158 | }
159 |
160 | return session.GetSessionService().CreateUserSession(fs.UserID)
161 |
162 | }
163 |
164 | func (f *flowProcessor) updateFlowState(fs *state.FlowState) error {
165 | sessionProp, err := json.Marshal(*fs)
166 | if err != nil {
167 | return errors.Wrap(err, "error marshaling flow sate")
168 | }
169 |
170 | ss := session.GetSessionService()
171 |
172 | sess, err := ss.GetSession(fs.ID)
173 | if err != nil {
174 | sess = session.Session{
175 | ID: fs.ID,
176 | Properties: make(map[string]string),
177 | }
178 | sess.Properties[constants.FlowStateSessionProperty] = string(sessionProp)
179 | _, err = ss.CreateSession(sess)
180 | } else {
181 | sess.Properties[constants.FlowStateSessionProperty] = string(sessionProp)
182 | err = ss.UpdateSession(sess)
183 | }
184 | if err != nil {
185 | return err
186 | }
187 | return nil
188 |
189 | }
190 |
191 | func (f *flowProcessor) getFlowState(name, id string) (state.FlowState, error) {
192 | c := config.GetConfig()
193 | ss := session.GetSessionService()
194 | sess, err := ss.GetSession(id)
195 | var fs state.FlowState
196 | if err != nil {
197 | flow, ok := c.Flows[name]
198 | if !ok {
199 | return fs, errors.Errorf("auth flow %v not found", name)
200 | }
201 | fs = f.newFlowState(name, flow)
202 | } else {
203 | err = json.Unmarshal([]byte(sess.Properties[constants.FlowStateSessionProperty]), &fs)
204 | if err != nil {
205 | return fs, errors.New("session property fs does not exsit")
206 | }
207 | }
208 |
209 | return fs, nil
210 | }
211 |
212 | // createNewFlowState - creates new flow state from the Realm and AuthFlow settings, generates new flow Id and fill module properties
213 | func (f *flowProcessor) newFlowState(flowName string, flow config.Flow) state.FlowState {
214 |
215 | fs := state.FlowState{
216 | Modules: make([]state.FlowStateModuleInfo, len(flow.Modules)),
217 | SharedState: make(map[string]string),
218 | UserID: "",
219 | ID: uuid.New().String(),
220 | Name: flowName,
221 | }
222 |
223 | for i, module := range flow.Modules {
224 | fs.Modules[i].ID = module.ID
225 | fs.Modules[i].Type = module.Type
226 | fs.Modules[i].Properties = make(state.FlowStateModuleProperties)
227 | for k, v := range module.Properties {
228 | fs.Modules[i].Properties[k] = v
229 | }
230 | for k, v := range module.Properties {
231 | fs.Modules[i].Properties[k] = v
232 | }
233 | fs.Modules[i].State = make(map[string]interface{})
234 | fs.Modules[i].Criteria = module.Criteria
235 | }
236 | return fs
237 | }
238 |
--------------------------------------------------------------------------------
/pkg/auth/flow_test.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
7 | "github.com/maximthomas/gortas/pkg/auth/constants"
8 | "github.com/maximthomas/gortas/pkg/auth/state"
9 | "github.com/maximthomas/gortas/pkg/session"
10 |
11 | "github.com/maximthomas/gortas/pkg/config"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | const testFlowID = "test-flow-id"
16 | const corruptedFlowID = "corrupted-flow-id"
17 |
18 | func init() {
19 | flows := map[string]config.Flow{
20 | "login": {Modules: []config.Module{
21 | {
22 | ID: "login",
23 | Type: "login",
24 | },
25 | }},
26 | "register": {Modules: []config.Module{
27 | {
28 | ID: "registration",
29 | Properties: map[string]interface{}{
30 | "testProp": "testVal",
31 | "additionalFields": []map[interface{}]interface{}{{
32 | "dataStore": "name",
33 | "prompt": "Name",
34 | },
35 | },
36 | },
37 | },
38 | },
39 | },
40 | "sso": {Modules: []config.Module{}},
41 | }
42 |
43 | conf := config.Config{
44 | Flows: flows,
45 | Session: session.Config{
46 | Type: "stateful",
47 | },
48 | }
49 | config.SetConfig(&conf)
50 |
51 | s := session.Session{
52 | ID: testFlowID,
53 | Properties: map[string]string{
54 | constants.FlowStateSessionProperty: "{}",
55 | },
56 | }
57 | _, _ = session.GetSessionService().CreateSession(s)
58 | corruptedSession := session.Session{
59 | ID: corruptedFlowID,
60 | Properties: map[string]string{
61 | constants.FlowStateSessionProperty: "bad",
62 | },
63 | }
64 | _, _ = session.GetSessionService().CreateSession(corruptedSession)
65 | }
66 |
67 | func TestGetFlowState(t *testing.T) {
68 | fp := flowProcessor{}
69 | tests := []struct {
70 | name string
71 | realm string
72 | flowName string
73 | flowID string
74 | checkError func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
75 | checkFlow func(t assert.TestingT, fs state.FlowState)
76 | }{
77 | {name: "existing flow", flowName: "login", checkError: assert.NoError,
78 | checkFlow: func(t assert.TestingT, fs state.FlowState) { assert.NotNil(t, fs) }},
79 | {name: "non existing flow", flowName: "bad", checkError: assert.Error,
80 | checkFlow: func(t assert.TestingT, fs state.FlowState) { assert.True(t, fs.ID == "") }},
81 | {name: "existing flowId", flowID: testFlowID, checkError: assert.NoError,
82 | checkFlow: func(t assert.TestingT, fs state.FlowState) { assert.NotNil(t, fs) }},
83 | {name: "corrupted flowId", flowID: corruptedFlowID, checkError: assert.Error,
84 | checkFlow: func(t assert.TestingT, fs state.FlowState) { assert.True(t, fs.ID == "") }},
85 | {name: "non existing flowId", flowName: "login", flowID: "bad-flow-id",
86 | checkError: assert.NoError, checkFlow: func(t assert.TestingT, fs state.FlowState) { assert.NotNil(t, fs) }},
87 | }
88 |
89 | for _, tt := range tests {
90 | t.Run(tt.name, func(t *testing.T) {
91 | fs, err := fp.getFlowState(tt.flowName, tt.flowID)
92 | tt.checkError(t, err)
93 | tt.checkFlow(t, fs)
94 | })
95 | }
96 | }
97 |
98 | func TestProcess(t *testing.T) {
99 | fp := NewFlowProcessor()
100 | var cbReq callbacks.Request
101 | cbResp, err := fp.Process("login", cbReq, nil, nil)
102 | assert.NoError(t, err)
103 | assert.True(t, len(cbResp.Callbacks) > 0)
104 | assert.Equal(t, "login", cbResp.Module)
105 | assert.NotEmpty(t, cbResp.FlowID)
106 |
107 | //invalid login and password
108 | cbReq = callbacks.Request{
109 | Module: cbResp.Module,
110 | Callbacks: cbResp.Callbacks,
111 | }
112 | cbReq.Callbacks[0].Value = "test"
113 | cbReq.Callbacks[1].Value = "test"
114 | cbReq.FlowID = cbResp.FlowID
115 | cbResp, err = fp.Process("login", cbReq, nil, nil)
116 | assert.NoError(t, err)
117 | assert.True(t, len(cbResp.Callbacks) > 0)
118 | assert.Equal(t, "login", cbResp.Module)
119 | assert.NotEmpty(t, cbResp.FlowID)
120 | assert.Equal(t, "Invalid username or password", cbResp.Callbacks[0].Error)
121 |
122 | //valid login and password
123 | cbReq.Callbacks[0].Value = "user1"
124 | cbReq.Callbacks[1].Value = "password"
125 | cbResp, err = fp.Process("login", cbReq, nil, nil)
126 | assert.NoError(t, err)
127 | assert.True(t, len(cbResp.Callbacks) == 0)
128 | assert.Empty(t, cbResp.FlowID)
129 | assert.NotEmpty(t, cbResp.Token)
130 | }
131 |
132 | //TODO v0 add test with complex flow (2FA)
133 |
--------------------------------------------------------------------------------
/pkg/auth/modules/credentials.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "regexp"
5 |
6 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
7 | "github.com/maximthomas/gortas/pkg/auth/state"
8 | "github.com/maximthomas/gortas/pkg/user"
9 | "github.com/mitchellh/mapstructure"
10 | "github.com/pkg/errors"
11 | )
12 |
13 | type Credentials struct {
14 | BaseAuthModule
15 | PrimaryField Field
16 | AdditionalFields []Field
17 | credentialsState *credentialsState
18 | }
19 |
20 | type credentialsState struct {
21 | UserID string
22 | Properties map[string]string
23 | }
24 |
25 | func (cm *Credentials) Process(s *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
26 | return state.InProgress, cm.Callbacks, nil
27 | }
28 |
29 | func (cm *Credentials) ProcessCallbacks(inCbs []callbacks.Callback, s *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
30 | cbs = make([]callbacks.Callback, len(cm.Callbacks))
31 | copy(cbs, cm.Callbacks)
32 |
33 | callbacksValid := true
34 |
35 | // callbacks validation
36 | for i := range inCbs {
37 | cb := inCbs[i]
38 | if cb.Value == "" && cbs[i].Required {
39 | (&cbs[i]).Error = (&cbs[i]).Prompt + " required"
40 | callbacksValid = false
41 | } else if cbs[i].Validation != "" {
42 | var re *regexp.Regexp
43 | re, err = regexp.Compile(cbs[i].Validation)
44 | if err != nil {
45 | cm.l.Errorf("error compiling regex for callback %v", cb.Validation)
46 | return state.Fail, nil, errors.Wrapf(err, "error compiling regex for callback %v", cb.Validation)
47 | }
48 | match := re.MatchString(cb.Value)
49 | if !match {
50 | (&cbs[i]).Error = (&cbs[i]).Prompt + " invalid"
51 | callbacksValid = false
52 | }
53 | }
54 | }
55 |
56 | if !callbacksValid {
57 | return state.InProgress, cbs, err
58 | }
59 |
60 | // fill state values
61 |
62 | for i := range inCbs {
63 | cb := inCbs[i]
64 | if cb.Name == cm.PrimaryField.Name {
65 | cm.credentialsState.UserID = cb.Value
66 | } else {
67 | cm.credentialsState.Properties[cb.Name] = cb.Value
68 | }
69 | }
70 | cm.updateState()
71 | s.UserID = cm.credentialsState.UserID
72 | return state.Pass, nil, err
73 | }
74 |
75 | func (cm *Credentials) updateState() {
76 | cm.State["userId"] = cm.credentialsState.UserID
77 | cm.State["properties"] = cm.credentialsState.Properties
78 | }
79 |
80 | func (cm *Credentials) ValidateCallbacks(cbs []callbacks.Callback) error {
81 | return cm.BaseAuthModule.ValidateCallbacks(cbs)
82 | }
83 |
84 | func (cm *Credentials) PostProcess(fs *state.FlowState) error {
85 | moduleUser := user.User{
86 | ID: cm.credentialsState.UserID,
87 | Properties: cm.credentialsState.Properties,
88 | }
89 | us := user.GetUserService()
90 | u, ok := us.GetUser(moduleUser.ID)
91 | var err error
92 | if !ok {
93 | u, err = us.CreateUser(moduleUser)
94 | if err != nil {
95 | return errors.Wrap(err, "error creating user")
96 | }
97 | } else {
98 | u.Properties = moduleUser.Properties
99 | err = us.UpdateUser(u)
100 | if err != nil {
101 | return errors.Wrap(err, "error updating user")
102 | }
103 | }
104 |
105 | return nil
106 | }
107 |
108 | func init() {
109 | RegisterModule("credentials", newCredentials)
110 | }
111 |
112 | func newCredentials(base BaseAuthModule) AuthModule {
113 | var cm Credentials
114 | err := mapstructure.Decode(base.Properties, &cm)
115 | if err != nil {
116 | panic(err) // TODO add error processing
117 | }
118 |
119 | if cm.PrimaryField.Name == "" {
120 | cm.PrimaryField = Field{
121 | Name: "login",
122 | Prompt: "Login",
123 | Required: true,
124 | }
125 | }
126 | cs := credentialsState{
127 | UserID: "",
128 | Properties: make(map[string]string),
129 | }
130 | _ = mapstructure.Decode(base.State, &cs)
131 |
132 | cm.BaseAuthModule = base
133 | cm.credentialsState = &cs
134 |
135 | cbLen := len(cm.AdditionalFields) + 1
136 |
137 | adcbs := make([]callbacks.Callback, cbLen)
138 | if cm.AdditionalFields != nil {
139 | for i, af := range cm.AdditionalFields {
140 | adcbs[i+1] = callbacks.Callback{
141 | Name: af.Name,
142 | Type: callbacks.TypeText,
143 | Value: "",
144 | Prompt: af.Prompt,
145 | Required: af.Required,
146 | Validation: af.Validation,
147 | }
148 | }
149 | }
150 | pf := cm.PrimaryField
151 | adcbs[0] = callbacks.Callback{
152 | Name: pf.Name,
153 | Type: callbacks.TypeText,
154 | Prompt: pf.Prompt,
155 | Value: "",
156 | Required: true,
157 | Validation: pf.Validation,
158 | }
159 |
160 | (&cm.BaseAuthModule).Callbacks = adcbs
161 | return &cm
162 | }
163 |
--------------------------------------------------------------------------------
/pkg/auth/modules/credentials_test.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
7 | "github.com/maximthomas/gortas/pkg/auth/state"
8 | "github.com/maximthomas/gortas/pkg/config"
9 | "github.com/maximthomas/gortas/pkg/user"
10 | "github.com/sirupsen/logrus"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestCredentialsProcess(t *testing.T) {
15 | cm := getCredentialsModule(t)
16 | ms, cbs, err := cm.Process(nil)
17 | assert.NoError(t, err)
18 | assert.Equal(t, 2, len(cbs))
19 | assert.Equal(t, state.InProgress, ms)
20 |
21 | assert.Equal(t, "login", cbs[0].Name)
22 | assert.Equal(t, "name", cbs[1].Name)
23 | }
24 |
25 | func TestCredentialsProcessCallbacks(t *testing.T) {
26 |
27 | cm := getCredentialsModule(t)
28 | const testEmail = "test@test.com"
29 | const testName = "John Doe"
30 |
31 | var tests = []struct {
32 | test string
33 | email string
34 | name string
35 | emailErr string
36 | nameErr string
37 | status state.ModuleStatus
38 | }{
39 | {
40 | test: "empty name email",
41 | email: "",
42 | name: "",
43 | emailErr: "Email required",
44 | nameErr: "Name required",
45 | status: state.InProgress,
46 | },
47 | {
48 | test: "invalid email",
49 | email: "bad",
50 | name: testName,
51 | emailErr: "Email invalid",
52 | nameErr: "",
53 | status: state.InProgress,
54 | },
55 | {
56 | test: "valid name email",
57 | email: testEmail,
58 | name: testName,
59 | emailErr: "",
60 | nameErr: "",
61 | status: state.Pass,
62 | },
63 | }
64 | for _, tt := range tests {
65 | t.Run(tt.test, func(t *testing.T) {
66 | inCbs := []callbacks.Callback{
67 | {
68 | Name: "login",
69 | Value: tt.email,
70 | },
71 | {
72 | Name: "name",
73 | Value: tt.name,
74 | },
75 | }
76 | var fs state.FlowState
77 | ms, cbs, err := cm.ProcessCallbacks(inCbs, &fs)
78 | assert.NoError(t, err)
79 | assert.Equal(t, tt.status, ms)
80 | switch ms {
81 | case state.InProgress:
82 | assert.Equal(t, 2, len(cbs))
83 | assert.Equal(t, tt.emailErr, cbs[0].Error)
84 | assert.Equal(t, tt.nameErr, cbs[1].Error)
85 | case state.Pass:
86 | assert.Equal(t, testEmail, cm.credentialsState.UserID)
87 | assert.Equal(t, testName, cm.credentialsState.Properties["name"])
88 |
89 | assert.Equal(t, testEmail, cm.State["userId"].(string))
90 | props := cm.State["properties"].(map[string]string)
91 | name := props["name"]
92 | assert.Equal(t, testName, name)
93 | }
94 | })
95 | }
96 | }
97 |
98 | func TestCredentiaslPostProcess(t *testing.T) {
99 | const testEmail = "test@test.com"
100 | const testName = "John Doe"
101 |
102 | cm := getCredentialsModule(t)
103 | cm.credentialsState = &credentialsState{
104 | UserID: testEmail,
105 | Properties: map[string]string{
106 | "name": testName,
107 | },
108 | }
109 |
110 | us := user.GetUserService()
111 | _, ok := us.GetUser(testEmail)
112 | assert.False(t, ok, "User does not exists")
113 | fs := &state.FlowState{}
114 | err := cm.PostProcess(fs)
115 | assert.NoError(t, err)
116 |
117 | u, ok := us.GetUser(testEmail)
118 | assert.True(t, ok, "user exists")
119 | assert.Equal(t, testEmail, u.ID)
120 | assert.Equal(t, testName, u.Properties["name"])
121 | }
122 |
123 | func TestGetCredentialsModule(t *testing.T) {
124 | cm := getCredentialsModule(t)
125 | assert.NotNil(t, cm)
126 |
127 | }
128 |
129 | func getCredentialsModule(t *testing.T) *Credentials {
130 | conf := config.Config{}
131 | config.SetConfig(&conf)
132 |
133 | const emailRegexp = "^([a-z0-9_-]+)(@[a-z0-9-]+)(\\.[a-z]+|\\.[a-z]+\\.[a-z]+)?$"
134 | var b = BaseAuthModule{
135 | l: logrus.New().WithField("module", "credentials"),
136 | Properties: map[string]interface{}{
137 | "primaryField": Field{
138 | Name: "login",
139 | Prompt: "Email",
140 | Required: true,
141 | Validation: emailRegexp,
142 | },
143 | "additionalFields": []Field{{
144 | Name: "name",
145 | Prompt: "Name",
146 | Required: true,
147 | },
148 | },
149 | },
150 | State: make(map[string]interface{}),
151 | }
152 | var m = newCredentials(b)
153 | cm, ok := m.(*Credentials)
154 | assert.True(t, ok)
155 | return cm
156 | }
157 |
--------------------------------------------------------------------------------
/pkg/auth/modules/hydra.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "crypto/tls"
7 | "encoding/json"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "net/url"
12 |
13 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
14 | "github.com/maximthomas/gortas/pkg/auth/state"
15 | )
16 |
17 | // Hydra ORY Hydra authentication module
18 | type Hydra struct {
19 | BaseAuthModule
20 | URI string // hydra URI
21 | client *http.Client
22 | }
23 |
24 | type hydraLoginData struct {
25 | Skip bool `json:"skip"`
26 | Subject string `json:"subject"`
27 | }
28 |
29 | type hydraSubject struct {
30 | Subject string `json:"subject"`
31 | Remember bool `json:"remember"`
32 | RememberFor int32 `json:"remember_for"`
33 | ACR string `json:"acr"`
34 | }
35 |
36 | func (h *Hydra) getLoginChallenge() string {
37 | return url.PathEscape(h.req.URL.Query().Get("login_challenge"))
38 | }
39 |
40 | func (h *Hydra) Process(_ *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
41 | hydraLoginURL := fmt.Sprintf("%s/oauth2/auth/requests/login?login_challenge=%s", h.URI, h.getLoginChallenge())
42 | ctx := context.Background()
43 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, hydraLoginURL, http.NoBody)
44 | if err != nil {
45 | return state.Fail, h.Callbacks, fmt.Errorf("Process %v: %v", hydraLoginURL, err)
46 | }
47 | resp, err := h.client.Do(req)
48 | if err != nil {
49 | return state.Fail, h.Callbacks, fmt.Errorf("Process %v: %v", hydraLoginURL, err)
50 | }
51 |
52 | defer resp.Body.Close()
53 | body, err := io.ReadAll(resp.Body)
54 | if err != nil {
55 | return state.Fail, h.Callbacks, err
56 | }
57 | var hld hydraLoginData
58 | err = json.Unmarshal(body, &hld)
59 | if err != nil {
60 | return state.Fail, h.Callbacks, err
61 | }
62 |
63 | return state.Pass, h.Callbacks, err
64 | }
65 |
66 | func (h *Hydra) ProcessCallbacks(_ []callbacks.Callback, s *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
67 | return h.Process(s)
68 | }
69 |
70 | func (h *Hydra) ValidateCallbacks(cbs []callbacks.Callback) error {
71 | return h.BaseAuthModule.ValidateCallbacks(cbs)
72 | }
73 |
74 | func (h *Hydra) PostProcess(fs *state.FlowState) error {
75 |
76 | hs := hydraSubject{
77 | Subject: fs.UserID,
78 | Remember: false,
79 | RememberFor: 0,
80 | ACR: "gortas",
81 | }
82 |
83 | // marshal User to json
84 | jsonBody, err := json.Marshal(hs)
85 | if err != nil {
86 | return err
87 | }
88 | uri := fmt.Sprintf("%s/oauth2/auth/requests/login/accept?login_challenge=%s", h.URI, h.getLoginChallenge())
89 | ctx := context.Background()
90 | req, err := http.NewRequestWithContext(ctx, http.MethodPut, uri, bytes.NewBuffer(jsonBody))
91 | if err != nil {
92 | return err
93 | }
94 |
95 | req.Header.Set("Content-Type", "application/json; charset=utf-8")
96 | resp, err := h.client.Do(req)
97 | if err != nil {
98 | panic(err)
99 | }
100 |
101 | defer resp.Body.Close()
102 | body, err := io.ReadAll(resp.Body)
103 | if err != nil {
104 | return err
105 | }
106 | var hRes struct {
107 | RedirectTo string `json:"redirect_to"`
108 | }
109 | err = json.Unmarshal(body, &hRes)
110 | if err != nil {
111 | return err
112 | }
113 | fs.RedirectURI = hRes.RedirectTo
114 | return nil
115 |
116 | }
117 |
118 | func init() {
119 | RegisterModule("hydra", newHydraModule)
120 | }
121 |
122 | func newHydraModule(base BaseAuthModule) AuthModule {
123 | skipTLS, _ := base.Properties["skiptls"].(bool)
124 |
125 | uri, ok := base.Properties["uri"].(string)
126 |
127 | if !ok {
128 | panic("hydra module missing uri property")
129 | }
130 | client := &http.Client{
131 | Transport: &http.Transport{
132 | TLSClientConfig: &tls.Config{
133 | InsecureSkipVerify: skipTLS,
134 | },
135 | },
136 | }
137 |
138 | return &Hydra{URI: uri, client: client}
139 | }
140 |
--------------------------------------------------------------------------------
/pkg/auth/modules/hydra_test.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "log"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/maximthomas/gortas/pkg/auth/state"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func TestHydra(t *testing.T) {
14 |
15 | var loginChallenge = "12345"
16 |
17 | server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
18 | assert.Equal(t, loginChallenge, req.URL.Query()["login_challenge"][0])
19 | if req.Method == http.MethodGet && req.URL.Path == "/oauth2/auth/requests/login" {
20 |
21 | _, _ = rw.Write([]byte(`{
22 | "skip": false,
23 | "subject": "user-id",
24 | "client": {"id": "test_client"},
25 | "request_url": "https://hydra/oauth2/auth?client_id=1234&scope=foo+bar&response_type=code",
26 | "requested_scope": ["foo", "bar"],
27 | "oidc_context": {"ui_locales": []},
28 | "context": {}
29 | }`))
30 | return
31 | } else if req.Method == http.MethodPut && req.URL.Path == "/oauth2/auth/requests/login/accept" {
32 | _, _ = rw.Write([]byte(`{
33 | "redirect_to": "https://hydra/"
34 | }`))
35 | return
36 | }
37 |
38 | _, _ = rw.Write([]byte(`{"ok":"ok"}`))
39 | }))
40 |
41 | defer server.Close()
42 |
43 | b := BaseAuthModule{
44 | Properties: map[string]interface{}{
45 | "uri": server.URL,
46 | },
47 | }
48 | am := newHydraModule(b)
49 | h, _ := am.(*Hydra)
50 |
51 | assert.Equal(t, server.URL, h.URI)
52 |
53 | t.Run("Test process", func(t *testing.T) {
54 | h.req = httptest.NewRequest("GET", "/login?login_challenge="+loginChallenge, nil)
55 | h.w = httptest.NewRecorder()
56 | fs := &state.FlowState{}
57 |
58 | status, cbs, err := h.Process(fs)
59 |
60 | assert.NoError(t, err)
61 |
62 | log.Print(status, cbs, err)
63 | })
64 |
65 | t.Run("Test PostProcess", func(t *testing.T) {
66 | h.req = httptest.NewRequest("GET", "/login?login_challenge="+loginChallenge, nil)
67 | h.w = httptest.NewRecorder()
68 |
69 | fs := &state.FlowState{}
70 |
71 | err := h.PostProcess(fs)
72 |
73 | assert.NoError(t, err)
74 | assert.Equal(t, "https://hydra/", fs.RedirectURI)
75 | })
76 | }
77 |
--------------------------------------------------------------------------------
/pkg/auth/modules/kerberos.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/hex"
6 | "fmt"
7 | "log"
8 | "strings"
9 |
10 | "github.com/pkg/errors"
11 |
12 | "github.com/jcmturner/gokrb5/v8/credentials"
13 | "github.com/jcmturner/gokrb5/v8/gssapi"
14 | "github.com/jcmturner/gokrb5/v8/keytab"
15 | "github.com/jcmturner/gokrb5/v8/service"
16 | "github.com/jcmturner/gokrb5/v8/spnego"
17 | "github.com/jcmturner/gokrb5/v8/types"
18 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
19 | "github.com/maximthomas/gortas/pkg/auth/state"
20 | )
21 |
22 | type Kerberos struct {
23 | BaseAuthModule
24 | servicePrincipal string
25 | kt *keytab.Keytab
26 | }
27 |
28 | const (
29 | keyTabFileProperty = "keytabfile"
30 | keyTabDataProperty = "keytabdata"
31 | servicePrincipalProperty = "serviceprincipal"
32 | ctxCredentials = "github.com/jcmturner/gokrb5/v8/ctxCredentials"
33 | )
34 |
35 | var outCallback = []callbacks.Callback{
36 | {
37 | Name: "httpstatus",
38 | Value: "401",
39 | Properties: map[string]string{
40 | spnego.HTTPHeaderAuthResponse: spnego.HTTPHeaderAuthResponseValueKey,
41 | },
42 | },
43 | }
44 |
45 | func init() {
46 | RegisterModule("kerberos", newKerberosModule)
47 | }
48 |
49 | func newKerberosModule(base BaseAuthModule) AuthModule {
50 | k := &Kerberos{
51 | BaseAuthModule: base,
52 | }
53 | var kt *keytab.Keytab
54 | var err error
55 | if ktFileProp, ok := k.BaseAuthModule.Properties[keyTabFileProperty]; ok {
56 | ktFile, _ := ktFileProp.(string)
57 | kt, err = keytab.Load(ktFile)
58 | if err != nil {
59 | panic(err) // If the "krb5.keytab" file is not available the application will show an error message.
60 | }
61 | } else if ktDataProp, ok := k.BaseAuthModule.Properties[keyTabDataProperty]; ok {
62 | ktData := ktDataProp.(string)
63 | b, _ := hex.DecodeString(ktData)
64 | kt = keytab.New()
65 | err = kt.Unmarshal(b)
66 | if err != nil {
67 | panic(err)
68 | }
69 | }
70 | k.kt = kt
71 | if spProp, ok := k.BaseAuthModule.Properties[servicePrincipalProperty]; ok {
72 | k.servicePrincipal = spProp.(string)
73 | }
74 |
75 | return k
76 | }
77 |
78 | func (k *Kerberos) Process(fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
79 |
80 | servicePrincipal := k.servicePrincipal
81 | kt := k.kt
82 | log.Print(kt)
83 | if err != nil {
84 | panic(err) // If the "krb5.keytab" file is not available the application will show an error message.
85 | }
86 | r := k.req
87 | s := strings.SplitN(r.Header.Get(spnego.HTTPHeaderAuthRequest), " ", 2)
88 | if len(s) != 2 || s[0] != spnego.HTTPHeaderAuthResponseValueKey {
89 | return state.InProgress, outCallback, err
90 | }
91 |
92 | settings := service.KeytabPrincipal(servicePrincipal)
93 | // Set up the SPNEGO GSS-API mechanism
94 | var spnegoMech *spnego.SPNEGO
95 | h, err := types.GetHostAddress(k.req.RemoteAddr)
96 | if err == nil {
97 | // put in this order so that if the user provides a ClientAddress it will override the one here.
98 | o := append([]func(*service.Settings){service.ClientAddress(h)}, settings)
99 | spnegoMech = spnego.SPNEGOService(kt, o...)
100 | } else {
101 | spnegoMech = spnego.SPNEGOService(kt)
102 | log.Printf("%s - SPNEGO could not parse client address: %v", r.RemoteAddr, err)
103 | }
104 |
105 | // Decode the header into an SPNEGO context token
106 | b, err := base64.StdEncoding.DecodeString(s[1])
107 | if err != nil {
108 | errText := fmt.Sprintf("%s - SPNEGO error in base64 decoding negotiation header: %v", r.RemoteAddr, err)
109 | log.Print(errText)
110 | return ms, cbs, errors.New(errText)
111 | }
112 | var st spnego.SPNEGOToken
113 | err = st.Unmarshal(b)
114 | if err != nil {
115 | errText := fmt.Sprintf("%s - SPNEGO error in unmarshaling SPNEGO token: %v", r.RemoteAddr, err)
116 | log.Print(errText)
117 | return ms, cbs, errors.New(errText)
118 | }
119 |
120 | // Validate the context token
121 | authed, ctx, status := spnegoMech.AcceptSecContext(&st)
122 | if status.Code != gssapi.StatusComplete && status.Code != gssapi.StatusContinueNeeded {
123 | errText := fmt.Sprintf("%s - SPNEGO validation error: %v", r.RemoteAddr, status)
124 | log.Print(errText)
125 | return ms, cbs, errors.New(errText)
126 | }
127 | if status.Code == gssapi.StatusContinueNeeded {
128 | errText := fmt.Sprintf("%s - SPNEGO GSS-API continue needed", r.RemoteAddr)
129 | log.Print(errText)
130 | return ms, cbs, errors.New(errText)
131 | }
132 | if authed {
133 | // Authentication successful; get user's credentials from the context
134 | id := ctx.Value(ctxCredentials).(*credentials.Credentials)
135 | fs.UserID = id.UserName()
136 | log.Printf("%s %s@%s - SPNEGO authentication succeeded", r.RemoteAddr, id.UserName(), id.Domain())
137 | return state.Pass, k.Callbacks, err
138 | }
139 | errText := fmt.Sprintf("%s - SPNEGO Kerberos authentication failed", r.RemoteAddr)
140 | log.Print(errText)
141 | return ms, cbs, errors.New(errText)
142 |
143 | }
144 |
145 | func (k *Kerberos) ProcessCallbacks(_ []callbacks.Callback, _ *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
146 | return state.InProgress, outCallback, err
147 | }
148 |
149 | func (k *Kerberos) ValidateCallbacks(cbs []callbacks.Callback) error {
150 | return k.BaseAuthModule.ValidateCallbacks(cbs)
151 | }
152 |
153 | func (k *Kerberos) PostProcess(_ *state.FlowState) error {
154 | return nil
155 | }
156 |
--------------------------------------------------------------------------------
/pkg/auth/modules/kerberos_test.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "log"
5 | "net/http/httptest"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 |
10 | "github.com/jcmturner/gokrb5/v8/test/testdata"
11 | "github.com/maximthomas/gortas/pkg/auth/state"
12 | )
13 |
14 | func TestKerberos(t *testing.T) {
15 |
16 | b := BaseAuthModule{
17 | Properties: map[string]interface{}{
18 | keyTabDataProperty: testdata.KEYTAB_TESTUSER1_TEST_GOKRB5,
19 | servicePrincipalProperty: "HTTP/authservice@ADKERBEROS",
20 | },
21 | }
22 |
23 | m := newKerberosModule(b)
24 | k, _ := m.(*Kerberos)
25 |
26 | t.Run("Test request negotiate", func(t *testing.T) {
27 |
28 | recorder := httptest.NewRecorder()
29 | k.req = httptest.NewRequest("GET", "/login", nil)
30 | k.w = recorder
31 | fs := &state.FlowState{}
32 |
33 | status, cbs, err := k.Process(fs)
34 |
35 | assert.NoError(t, err)
36 | log.Print(status, cbs, err)
37 | assert.Equal(t, 1, len(cbs))
38 | assert.Equal(t, "httpstatus", cbs[0].Name)
39 | assert.Equal(t, "Negotiate", cbs[0].Properties["WWW-Authenticate"])
40 |
41 | assert.Equal(t, state.InProgress, status)
42 | })
43 |
44 | t.Run("Test failed authentication", func(t *testing.T) {
45 | req := httptest.NewRequest("GET", "/login", nil)
46 | req.Header.Add("Authorization", "Negotiate bad token")
47 | k.req = req
48 | k.w = httptest.NewRecorder()
49 | fs := &state.FlowState{}
50 |
51 | status, cbs, err := k.Process(fs)
52 |
53 | log.Print(status, cbs, err)
54 |
55 | assert.Error(t, err)
56 | })
57 | }
58 |
--------------------------------------------------------------------------------
/pkg/auth/modules/login.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
5 | "github.com/maximthomas/gortas/pkg/auth/state"
6 | "github.com/maximthomas/gortas/pkg/user"
7 | )
8 |
9 | type LoginPassword struct {
10 | BaseAuthModule
11 | }
12 |
13 | func (lm *LoginPassword) Process(_ *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
14 | return state.InProgress, lm.Callbacks, err
15 | }
16 |
17 | func (lm *LoginPassword) ProcessCallbacks(inCbs []callbacks.Callback, fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
18 | var username string
19 | var password string
20 |
21 | for i := range inCbs {
22 | cb := inCbs[i]
23 | switch cb.Name {
24 | case "login":
25 | username = cb.Value
26 | case "password":
27 | password = cb.Value
28 | }
29 | }
30 | us := user.GetUserService()
31 | valid := us.ValidatePassword(username, password)
32 | if valid {
33 | fs.UserID = username
34 | return state.Pass, cbs, err
35 | }
36 | cbs = lm.Callbacks
37 | (&cbs[0]).Error = "Invalid username or password"
38 | return state.InProgress, cbs, err
39 |
40 | }
41 |
42 | func (lm *LoginPassword) ValidateCallbacks(cbs []callbacks.Callback) error {
43 | return lm.BaseAuthModule.ValidateCallbacks(cbs)
44 | }
45 |
46 | func (lm *LoginPassword) PostProcess(_ *state.FlowState) error {
47 | return nil
48 | }
49 |
50 | func init() {
51 | RegisterModule("login", newLoginPassword)
52 | }
53 |
54 | func newLoginPassword(base BaseAuthModule) AuthModule {
55 | (&base).Callbacks = []callbacks.Callback{
56 | {
57 | Name: "login",
58 | Type: callbacks.TypeText,
59 | Prompt: "Login",
60 | Value: "",
61 | Required: true,
62 | },
63 | {
64 | Name: "password",
65 | Type: callbacks.TypePassword,
66 | Prompt: "Password",
67 | Value: "",
68 | Required: true,
69 | },
70 | }
71 | return &LoginPassword{
72 | base,
73 | }
74 | }
75 |
--------------------------------------------------------------------------------
/pkg/auth/modules/module.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "sync"
7 |
8 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
9 | "github.com/maximthomas/gortas/pkg/auth/state"
10 | "github.com/maximthomas/gortas/pkg/log"
11 | "github.com/sirupsen/logrus"
12 | )
13 |
14 | type AuthModule interface {
15 | Process(s *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error)
16 | ProcessCallbacks(inCbs []callbacks.Callback, s *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error)
17 | ValidateCallbacks(cbs []callbacks.Callback) error
18 | PostProcess(fs *state.FlowState) error
19 | }
20 |
21 | type Field struct {
22 | Name string
23 | Prompt string
24 | Required bool
25 | Validation string
26 | }
27 |
28 | var modulesRegistry = &sync.Map{}
29 |
30 | func RegisterModule(mt string, constructor func(BaseAuthModule) AuthModule) {
31 | logrus.Infof("registered %v module", mt)
32 | modulesRegistry.Store(mt, constructor)
33 | }
34 |
35 | type moduleConstructor = func(base BaseAuthModule) AuthModule
36 |
37 | func GetAuthModule(mi state.FlowStateModuleInfo, req *http.Request, w http.ResponseWriter) (AuthModule, error) {
38 | base := BaseAuthModule{
39 | Properties: mi.Properties,
40 | State: mi.State,
41 | req: req,
42 | w: w,
43 | l: log.WithField("module", mi.Type),
44 | }
45 | constructor, ok := modulesRegistry.Load(mi.Type)
46 | if !ok {
47 | return nil, fmt.Errorf("module %v does not exists", mi.Type)
48 | }
49 | if c, ok := constructor.(moduleConstructor); ok {
50 | return moduleConstructor(c)(base), nil
51 | }
52 | return nil, fmt.Errorf("error converting %v to module constructor", constructor)
53 | }
54 |
55 | type BaseAuthModule struct {
56 | Properties map[string]interface{}
57 | Callbacks []callbacks.Callback
58 | State map[string]interface{}
59 | req *http.Request
60 | w http.ResponseWriter
61 | l *logrus.Entry
62 | }
63 |
64 | func (b BaseAuthModule) ValidateCallbacks(cbs []callbacks.Callback) error {
65 | err := fmt.Errorf("callbacks does not match %v %v", b.Callbacks, cbs)
66 | if len(cbs) == len(b.Callbacks) {
67 | for i := range cbs {
68 | if cbs[i].Name != b.Callbacks[i].Name {
69 | return err
70 | }
71 | }
72 | return nil
73 | }
74 | return err
75 | }
76 |
--------------------------------------------------------------------------------
/pkg/auth/modules/module_test.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
7 | "github.com/maximthomas/gortas/pkg/auth/state"
8 | "github.com/maximthomas/gortas/pkg/config"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | type SimpleModule struct {
13 | BaseAuthModule
14 | }
15 |
16 | func (sm *SimpleModule) Process(_ *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
17 | return state.InProgress, sm.Callbacks, err
18 | }
19 |
20 | func (sm *SimpleModule) ProcessCallbacks(inCbs []callbacks.Callback, fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
21 | return state.InProgress, sm.Callbacks, err
22 | }
23 |
24 | func (sm *SimpleModule) ValidateCallbacks(cbs []callbacks.Callback) error {
25 | return sm.BaseAuthModule.ValidateCallbacks(cbs)
26 | }
27 |
28 | func (sm *SimpleModule) PostProcess(_ *state.FlowState) error {
29 | return nil
30 | }
31 |
32 | func init() {
33 | RegisterModule("simple", newSimpleModule)
34 | }
35 |
36 | func newSimpleModule(base BaseAuthModule) AuthModule {
37 | return &SimpleModule{
38 | base,
39 | }
40 | }
41 |
42 | func TestModuleRegistered(t *testing.T) {
43 | constructor, ok := modulesRegistry.Load("simple")
44 | assert.True(t, ok)
45 |
46 | _, ok = constructor.(moduleConstructor)
47 | assert.True(t, ok)
48 | }
49 |
50 | func TestGetModuleFromRegistry(t *testing.T) {
51 | config.SetConfig(&config.Config{})
52 | mi := state.FlowStateModuleInfo{
53 | ID: "simple",
54 | Type: "simple",
55 | }
56 | m, err := GetAuthModule(mi, nil, nil)
57 | assert.NoError(t, err)
58 |
59 | _, ok := m.(*SimpleModule)
60 | assert.True(t, ok)
61 | }
62 |
--------------------------------------------------------------------------------
/pkg/auth/modules/otp.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "fmt"
7 | "os"
8 | "strconv"
9 | "strings"
10 | "text/template"
11 | "time"
12 |
13 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
14 | "github.com/maximthomas/gortas/pkg/auth/constants"
15 | "github.com/maximthomas/gortas/pkg/auth/modules/otp"
16 | "github.com/maximthomas/gortas/pkg/auth/state"
17 | "github.com/maximthomas/gortas/pkg/crypt"
18 | "github.com/maximthomas/gortas/pkg/session"
19 | "github.com/mitchellh/mapstructure"
20 | "github.com/pkg/errors"
21 | )
22 |
23 | const (
24 | actionSend = "send"
25 | actionCheck = "check"
26 | otpSenderProperty = "sender"
27 | otpMagicLinkParameter = "code"
28 | )
29 |
30 | type OTP struct {
31 | BaseAuthModule
32 | OtpLength int
33 | UseLetters bool
34 | UseDigits bool
35 | OtpTimeoutSec int
36 | OtpResendSec int
37 | OtpRetryCount int
38 | OtpMessageTemplate string
39 | OtpCheckMagicLink bool
40 | otpState *otpState
41 | otpSender otp.Sender
42 | }
43 |
44 | type otpState struct {
45 | Retries int
46 | GeneratedAt int64
47 | Otp string
48 | }
49 |
50 | type otpSenderProperties struct {
51 | SenderType string
52 | Properties map[string]interface{}
53 | }
54 |
55 | func (lm *OTP) Process(fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
56 | defer lm.updateState()
57 |
58 | // TODO add check expired date
59 | if lm.OtpCheckMagicLink { // TODO refactor move to function and code to constant
60 | return lm.checkMagicLink(fs)
61 | }
62 | return lm.generateAndSendOTP(fs)
63 | }
64 |
65 | func (lm *OTP) checkMagicLink(fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
66 |
67 | if lm.req.URL.Query().Get(otpMagicLinkParameter) == "" {
68 | return state.Fail, lm.Callbacks, err
69 | }
70 |
71 | code := lm.req.URL.Query().Get(otpMagicLinkParameter)
72 | codeDecrypted, err := crypt.DecryptWithConfig(code)
73 | if err != nil {
74 | return state.Fail, lm.Callbacks, err
75 | }
76 |
77 | codeParts := strings.Split(codeDecrypted, "|")
78 | sessionID := codeParts[0]
79 | expired, err := strconv.ParseInt(codeParts[1], 10, 0)
80 | if err != nil {
81 | return state.Fail, lm.Callbacks, err
82 | }
83 |
84 | if time.Now().UnixMilli() > expired {
85 | return state.Fail, lm.Callbacks, errors.New("code link expired")
86 | }
87 |
88 | sess, err := session.GetSessionService().GetSession(sessionID)
89 | if err != nil {
90 | return state.Fail, lm.Callbacks, err
91 | }
92 | var oldFlowState state.FlowState
93 | err = json.Unmarshal([]byte(sess.Properties[constants.FlowStateSessionProperty]), &oldFlowState)
94 | if err != nil {
95 | return state.Fail, lm.Callbacks, err
96 | }
97 | fs.UserID = oldFlowState.UserID
98 | for k, v := range oldFlowState.SharedState {
99 | fs.SharedState[k] = v
100 | }
101 |
102 | for i, m := range oldFlowState.Modules {
103 | fs.Modules[i].State = m.State
104 | }
105 |
106 | return state.Pass, lm.Callbacks, err
107 | }
108 |
109 | func (lm *OTP) ProcessCallbacks(inCbs []callbacks.Callback, fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
110 | defer lm.updateState()
111 | var o string
112 | var action string
113 | for i := range inCbs {
114 | cb := inCbs[i]
115 | switch cb.Name {
116 | case "otp":
117 | o = cb.Value
118 | case "action":
119 | action = cb.Value
120 | }
121 | }
122 |
123 | if action == actionSend {
124 | return lm.generateAndSendOTP(fs)
125 | }
126 | // TODO move to BaseAuthModule
127 | cbs = make([]callbacks.Callback, len(lm.Callbacks))
128 | copy(cbs, lm.Callbacks)
129 |
130 | generatedTime := lm.otpState.GeneratedAt
131 | expiresAt := generatedTime + int64(lm.OtpTimeoutSec*1000)
132 | if time.Now().UnixMilli() > expiresAt {
133 | (&cbs[0]).Error = "OTP expired"
134 | return state.InProgress, cbs, err
135 | }
136 |
137 | generatedOtp := lm.otpState.Otp
138 | if generatedOtp == "" {
139 | cbs = lm.Callbacks
140 | (&cbs[0]).Error = "OTP was not generated"
141 | return state.InProgress, cbs, err
142 | }
143 |
144 | rc := lm.getRetryCount()
145 | if rc <= 0 {
146 | (&cbs[0]).Error = "OTP retries excceded"
147 | (&cbs[0]).Properties["retryCount"] = strconv.Itoa(rc)
148 | return state.InProgress, cbs, err
149 | }
150 |
151 | valid := generatedOtp == o || os.Getenv("GORTAS_OTP_TEST") == o
152 | if valid {
153 | return state.Pass, cbs, err
154 | }
155 | cbs = lm.Callbacks
156 | lm.incrementRetries()
157 | (&cbs[0]).Error = "Invalid OTP"
158 | lm.updateOTPCallbackProperties(&cbs[0])
159 | return state.InProgress, cbs, err
160 | }
161 |
162 | func (lm *OTP) updateState() {
163 | lm.State["generatedAt"] = lm.otpState.GeneratedAt
164 | lm.State["otp"] = lm.otpState.Otp
165 | lm.State["retries"] = lm.otpState.Retries
166 | }
167 |
168 | func (lm *OTP) generateAndSendOTP(fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
169 | cbs = make([]callbacks.Callback, len(lm.Callbacks))
170 | copy(cbs, lm.Callbacks)
171 | generatedAt := lm.otpState.GeneratedAt
172 | // check if send allowed
173 | if generatedAt > time.Now().UnixMilli()-int64(lm.OtpResendSec)*1000 {
174 | (&cbs[1]).Error = "Sending not allowed yet"
175 | lm.updateOTPCallbackProperties(&cbs[0])
176 | return state.InProgress, cbs, err
177 | }
178 |
179 | err = lm.generate()
180 | if err != nil {
181 | return state.Fail, cbs, err
182 | }
183 | err = lm.send(fs)
184 | if err != nil {
185 | return state.Fail, cbs, err
186 | }
187 | lm.updateOTPCallbackProperties(&cbs[0])
188 | return state.InProgress, cbs, err
189 | }
190 |
191 | func (lm *OTP) generate() error {
192 | o, err := crypt.RandomString(lm.OtpLength, lm.UseLetters, lm.UseDigits)
193 | if err != nil {
194 | return errors.Wrap(err, "error generating OTP")
195 | }
196 |
197 | lm.otpState.Otp = o
198 | lm.otpState.GeneratedAt = time.Now().UnixMilli()
199 | return nil
200 | }
201 |
202 | func (lm *OTP) send(fs *state.FlowState) error {
203 | msg, err := lm.getMessage(fs)
204 | if err != nil {
205 | return errors.Wrap(err, "error generating message")
206 | }
207 | err = lm.otpSender.Send(fs.UserID, msg)
208 | if err != nil {
209 | return errors.Wrap(err, "error sending message")
210 | }
211 | return nil
212 |
213 | }
214 |
215 | const millisecondsMultiplier = 1000
216 |
217 | // TODO add authentication link
218 | func (lm *OTP) getMessage(fs *state.FlowState) (string, error) {
219 | tmpl, err := template.New("message").Parse(lm.OtpMessageTemplate)
220 | if err != nil {
221 | return "", err
222 | }
223 |
224 | minutes := lm.OtpTimeoutSec / 60
225 | seconds := lm.OtpTimeoutSec % 60
226 | otpExpiresAt := time.Now().UnixMilli() + int64(lm.OtpTimeoutSec*millisecondsMultiplier)
227 |
228 | otpTimeoutFormatted := fmt.Sprintf("%02d:%02d", minutes, seconds)
229 | magicLink, err := crypt.EncryptWithConfig(fs.ID + "|" + strconv.FormatInt(otpExpiresAt, 10))
230 | if err != nil {
231 | return "", err
232 | }
233 |
234 | otpData := struct {
235 | OTP string
236 | ValidFor string
237 | MagicLink string
238 | }{
239 | OTP: lm.otpState.Otp,
240 | ValidFor: otpTimeoutFormatted,
241 | MagicLink: magicLink,
242 | }
243 |
244 | var b bytes.Buffer
245 | err = tmpl.Execute(&b, otpData)
246 | if err != nil {
247 | return "", err
248 | }
249 |
250 | return b.String(), nil
251 | }
252 |
253 | func (lm *OTP) updateOTPCallbackProperties(cb *callbacks.Callback) {
254 | rc := lm.getRetryCount()
255 |
256 | generatedAt := lm.otpState.GeneratedAt
257 |
258 | sinceGeneratedSec := (time.Now().UnixMilli() - generatedAt) / 1000
259 |
260 | rs := lm.OtpResendSec - int(sinceGeneratedSec)
261 | ts := lm.OtpTimeoutSec - int(sinceGeneratedSec)
262 | cb.Properties["retryCount"] = strconv.Itoa(rc)
263 | cb.Properties["resendSec"] = strconv.Itoa(rs)
264 | cb.Properties["timeoutSec"] = strconv.Itoa(ts)
265 | }
266 |
267 | func (lm *OTP) incrementRetries() {
268 | lm.otpState.Retries++
269 | }
270 |
271 | // TODO deal with retry count and retries to eliminate confusion
272 | func (lm *OTP) getRetryCount() int {
273 | return lm.OtpRetryCount - lm.otpState.Retries
274 | }
275 |
276 | func (lm *OTP) ValidateCallbacks(cbs []callbacks.Callback) error {
277 | return lm.BaseAuthModule.ValidateCallbacks(cbs)
278 | }
279 |
280 | func (lm *OTP) PostProcess(_ *state.FlowState) error {
281 | return nil
282 | }
283 |
284 | func init() {
285 | RegisterModule("otp", newOTP)
286 | }
287 |
288 | func newOTP(base BaseAuthModule) AuthModule {
289 |
290 | var om OTP
291 | err := mapstructure.Decode(base.Properties, &om)
292 | if err != nil {
293 | panic(err) // TODO add error processing
294 | }
295 |
296 | (&base).Callbacks = []callbacks.Callback{
297 | {
298 | Name: "otp",
299 | Type: callbacks.TypeText,
300 | Prompt: "One Time Password",
301 | Value: "",
302 | Required: true,
303 | Properties: map[string]string{
304 | "timeoutSec": strconv.Itoa(om.OtpTimeoutSec),
305 | "resendSec": strconv.Itoa(om.OtpResendSec),
306 | "retryCount": strconv.Itoa(om.OtpRetryCount),
307 | },
308 | },
309 | {
310 | Name: "action",
311 | Type: callbacks.TypeActions,
312 | Required: true,
313 | Value: "check",
314 | Properties: map[string]string{
315 | "values": "send|check", // TODO add Camel case
316 | "skipVerifyFor": "send",
317 | },
318 | },
319 | }
320 |
321 | om.BaseAuthModule = base
322 |
323 | var st otpState
324 | _ = mapstructure.Decode(base.State, &st)
325 | om.otpState = &st
326 |
327 | if !om.OtpCheckMagicLink { // if module just checks magic link, there's no need to init OTP sender
328 | var osp otpSenderProperties
329 | err = mapstructure.Decode(base.Properties[otpSenderProperty], &osp)
330 | if err != nil {
331 | panic(err)
332 | }
333 | var sender otp.Sender
334 | sender, err = otp.GetSender(osp.SenderType, osp.Properties)
335 |
336 | if err != nil {
337 | panic(err)
338 | }
339 |
340 | om.otpSender = sender
341 | }
342 |
343 | return &om
344 | }
345 |
--------------------------------------------------------------------------------
/pkg/auth/modules/otp/email.go:
--------------------------------------------------------------------------------
1 | package otp
2 |
3 | import (
4 | "crypto/tls"
5 | "time"
6 |
7 | "github.com/mitchellh/mapstructure"
8 | mail "github.com/xhit/go-simple-mail/v2"
9 | )
10 |
11 | type EmailSender struct {
12 | From string
13 | Subject string
14 | server *mail.SMTPServer
15 | }
16 |
17 | type smtpProperties struct {
18 | Host string
19 | Port int
20 | Username string
21 | Password string
22 | From string
23 | Subject string
24 | }
25 |
26 | func init() {
27 | RegisterSender("email", NewEmailSender)
28 | }
29 |
30 | func NewEmailSender(props map[string]interface{}) (Sender, error) {
31 | var sp smtpProperties
32 | var sender Sender
33 | err := mapstructure.Decode(props, &sp)
34 | if err != nil {
35 | return sender, err
36 | }
37 |
38 | server := mail.NewSMTPClient()
39 |
40 | server.Host = sp.Host
41 | server.Port = sp.Port
42 | server.Username = sp.Username
43 | server.Password = sp.Password
44 | server.Encryption = mail.EncryptionNone
45 |
46 | server.KeepAlive = false
47 |
48 | server.ConnectTimeout = 5 * time.Second
49 | server.SendTimeout = 5 * time.Second
50 |
51 | server.TLSConfig = &tls.Config{InsecureSkipVerify: true}
52 |
53 | return EmailSender{server: server, From: sp.From, Subject: sp.Subject}, nil
54 | }
55 |
56 | func (es EmailSender) Send(to, text string) error {
57 |
58 | smtpClient, err := es.server.Connect()
59 |
60 | if err != nil {
61 | return err
62 | }
63 |
64 | email := mail.NewMSG()
65 | email.SetFrom(es.From).
66 | AddTo(to).
67 | SetSubject(es.Subject)
68 |
69 | email.SetBody(mail.TextHTML, text)
70 |
71 | if email.Error != nil {
72 | return email.Error
73 | }
74 |
75 | // Call Send and pass the client
76 | err = email.Send(smtpClient)
77 | if err != nil {
78 | return err
79 | }
80 |
81 | return nil
82 | }
83 |
--------------------------------------------------------------------------------
/pkg/auth/modules/otp/email_test.go:
--------------------------------------------------------------------------------
1 | //go:build integration
2 |
3 | package otp
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | const (
12 | emailFrom = "john@test.com"
13 | emailSubject = "OTP"
14 | )
15 |
16 | func TestSendEmail(t *testing.T) {
17 | es := getEmailSender(t)
18 | err := es.Send("test@test.com", "hello email")
19 | assert.NoError(t, err)
20 | }
21 |
22 | func TestCreateEmailSender(t *testing.T) {
23 | es := getEmailSender(t)
24 | assert.Equal(t, emailSubject, es.Subject)
25 | assert.Equal(t, emailFrom, es.From)
26 | }
27 |
28 | func getEmailSender(t *testing.T) EmailSender {
29 |
30 | props := map[string]interface{}{
31 | "Host": "localhost",
32 | "Port": 1025,
33 | "Username": "",
34 | "Password": "",
35 | "From": emailFrom,
36 | "Subject": emailSubject,
37 | }
38 |
39 | s, err := NewEmailSender(props)
40 | es := s.(EmailSender)
41 | assert.NoError(t, err)
42 | assert.NotNil(t, es.server)
43 | return es
44 | }
45 |
--------------------------------------------------------------------------------
/pkg/auth/modules/otp/sender.go:
--------------------------------------------------------------------------------
1 | package otp
2 |
3 | import (
4 | "fmt"
5 | "sync"
6 |
7 | "github.com/mitchellh/mapstructure"
8 | "github.com/pkg/errors"
9 | "github.com/sirupsen/logrus"
10 | )
11 |
12 | type Sender interface {
13 | Send(to string, text string) error
14 | }
15 |
16 | var senderRegistry = &sync.Map{}
17 |
18 | type senderConstructor func(map[string]interface{}) (Sender, error)
19 |
20 | func RegisterSender(id string, constructor senderConstructor) {
21 | logrus.Infof("registered %v sender", id)
22 | senderRegistry.Store(id, constructor)
23 | }
24 |
25 | func GetSender(id string, props map[string]interface{}) (Sender, error) {
26 | c, ok := senderRegistry.Load(id)
27 | if !ok {
28 | return nil, fmt.Errorf("sender %s does not exists", id)
29 | }
30 | s, err := c.(senderConstructor)(props)
31 | if err != nil {
32 | return s, errors.Wrapf(err, "error creating sender %s", id)
33 | }
34 | return s, err
35 | }
36 |
37 | var ts *TestSender
38 |
39 | type TestSender struct {
40 | Host string
41 | Port int
42 | Messages map[string]string
43 | }
44 |
45 | func init() {
46 | RegisterSender("test", NewTestSender)
47 | }
48 |
49 | func NewTestSender(props map[string]interface{}) (Sender, error) {
50 | if ts != nil {
51 | return ts, nil
52 | }
53 | var newTS TestSender
54 | err := mapstructure.Decode(props, &newTS)
55 | if err != nil {
56 | return nil, err
57 | }
58 | ts = &newTS
59 | ts.Messages = make(map[string]string)
60 | return ts, nil
61 | }
62 |
63 | func (ts *TestSender) Send(to, text string) error {
64 | ts.Messages[to] = text
65 | return nil
66 | }
67 |
--------------------------------------------------------------------------------
/pkg/auth/modules/otp_test.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "crypto/rand"
5 | "encoding/base64"
6 | "fmt"
7 | "net/http/httptest"
8 | "strconv"
9 | "testing"
10 | "time"
11 |
12 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
13 | "github.com/maximthomas/gortas/pkg/auth/constants"
14 | "github.com/maximthomas/gortas/pkg/auth/modules/otp"
15 | "github.com/maximthomas/gortas/pkg/auth/state"
16 | "github.com/maximthomas/gortas/pkg/config"
17 | "github.com/maximthomas/gortas/pkg/crypt"
18 | "github.com/maximthomas/gortas/pkg/session"
19 | "github.com/stretchr/testify/assert"
20 | )
21 |
22 | func TestNewOTP(t *testing.T) {
23 | m := getOTPModule(t)
24 | assert.Equal(t, 4, m.OtpLength)
25 | assert.Equal(t, false, m.UseLetters)
26 | assert.Equal(t, true, m.UseDigits)
27 | assert.Equal(t, 180, m.OtpTimeoutSec)
28 | assert.Equal(t, 90, m.OtpResendSec)
29 | assert.Equal(t, 5, m.OtpRetryCount)
30 | }
31 |
32 | func TestProcess(t *testing.T) {
33 | key := make([]byte, 32)
34 | rand.Read(key)
35 | keyStr := base64.StdEncoding.EncodeToString(key)
36 | conf := config.Config{}
37 | conf.EncryptionKey = keyStr
38 | config.SetConfig(&conf)
39 |
40 | m := getOTPModule(t)
41 | var fs state.FlowState
42 | status, cbs, err := m.Process(&fs)
43 | assert.NoError(t, err)
44 | assert.Equal(t, state.InProgress, status)
45 | assert.Equal(t, 2, len(cbs))
46 |
47 | //check otp callback
48 | otpCb := cbs[0]
49 | assert.Equal(t, "otp", otpCb.Name)
50 | assert.Equal(t, "180", otpCb.Properties["timeoutSec"]) //TODO to int values
51 | assert.Equal(t, "90", otpCb.Properties["resendSec"])
52 | assert.Equal(t, "5", otpCb.Properties["retryCount"])
53 |
54 | //check action callback
55 | actionCb := cbs[1]
56 | assert.Equal(t, "action", actionCb.Name)
57 | assert.Equal(t, callbacks.TypeActions, actionCb.Type)
58 | }
59 |
60 | func TestProcess_MagicLink(t *testing.T) {
61 | key := make([]byte, 32)
62 | rand.Read(key)
63 | keyStr := base64.StdEncoding.EncodeToString(key)
64 | conf := config.Config{}
65 | conf.EncryptionKey = keyStr
66 | config.SetConfig(&conf)
67 | sessionID := "test_session"
68 | code := sessionID + "|" + strconv.FormatInt(time.Now().UnixMilli()+10000, 10)
69 | encrypted, err := crypt.EncryptWithConfig(code)
70 | assert.NoError(t, err)
71 | sess := session.Session{}
72 | sess.Properties = make(map[string]string, 1)
73 | sess.Properties[constants.FlowStateSessionProperty] = "{}"
74 | sess.ID = sessionID
75 | session.GetSessionService().CreateSession(sess)
76 |
77 | m := getOTPModule(t)
78 | m.req = httptest.NewRequest("GET", "http://localhost/gortas?code="+encrypted, nil)
79 | m.OtpCheckMagicLink = true
80 | var fs state.FlowState
81 | st, _, err := m.Process(&fs)
82 | assert.NoError(t, err)
83 | assert.Equal(t, state.Pass, st)
84 | }
85 |
86 | func TestGenerateOTP(t *testing.T) {
87 | m := getOTPModule(t)
88 | err := m.generate()
89 | assert.NoError(t, err)
90 | otpCode := m.otpState.Otp
91 | generated := m.otpState.GeneratedAt
92 | assert.NotEmpty(t, otpCode)
93 | assert.Equal(t, 4, len(otpCode))
94 | assert.True(t, generated > time.Now().UnixMilli()-10000)
95 | }
96 |
97 | func TestProcessCallbacks_CodeExpired(t *testing.T) {
98 | const testOTP = "1234"
99 | m := getOTPModule(t)
100 | err := m.generate()
101 | assert.NoError(t, err)
102 | m.otpState.Otp = testOTP
103 | m.otpState.GeneratedAt = int64(0)
104 | inCbs := []callbacks.Callback{
105 | {
106 | Name: "otp",
107 | Value: testOTP,
108 | },
109 | {
110 | Name: "action",
111 | Value: "check",
112 | },
113 | }
114 | st, cbs, err := m.ProcessCallbacks(inCbs, nil)
115 | assert.NoError(t, err)
116 | assert.Equal(t, state.InProgress, st)
117 | assert.Equal(t, "OTP expired", cbs[0].Error)
118 | }
119 |
120 | func TestProcessCallbacks_BadOTP(t *testing.T) {
121 | const testOTP = "1234"
122 | m := getOTPModule(t)
123 | err := m.generate()
124 | assert.NoError(t, err)
125 | m.otpState.Otp = testOTP
126 | m.otpState.GeneratedAt = time.Now().UnixMilli()
127 | inCbs := []callbacks.Callback{
128 | {
129 | Name: "otp",
130 | Value: "bad",
131 | },
132 | {
133 | Name: "action",
134 | Value: "check",
135 | },
136 | }
137 | st, cbs, err := m.ProcessCallbacks(inCbs, nil)
138 | assert.NoError(t, err)
139 | assert.Equal(t, state.InProgress, st)
140 | assert.Equal(t, "Invalid OTP", cbs[0].Error)
141 | assert.Equal(t, "4", cbs[0].Properties["retryCount"])
142 | }
143 |
144 | func TestProcessCallbacks_SendNotAllowed(t *testing.T) {
145 | m := getOTPModule(t)
146 | inCbs := []callbacks.Callback{
147 | {
148 | Name: "otp",
149 | Value: "",
150 | },
151 | {
152 | Name: "action",
153 | Value: "send",
154 | },
155 | }
156 | m.otpState.GeneratedAt = time.Now().UnixMilli() - 1000
157 | var fs state.FlowState
158 | st, cbs, err := m.ProcessCallbacks(inCbs, &fs)
159 | assert.NoError(t, err)
160 | assert.Equal(t, state.InProgress, st)
161 | assert.Equal(t, "Sending not allowed yet", cbs[1].Error)
162 | resend, err := strconv.Atoi(cbs[0].Properties["resendSec"])
163 | assert.NoError(t, err, "incorrect converted")
164 | assert.True(t, resend < 1800 && resend > 0, fmt.Sprintf("resend %v", resend))
165 |
166 | }
167 |
168 | func TestProcessCallbacks_Send(t *testing.T) {
169 | const testOTP = "1234"
170 | m := getOTPModule(t)
171 | inCbs := []callbacks.Callback{
172 | {
173 | Name: "otp",
174 | Value: "",
175 | },
176 | {
177 | Name: "action",
178 | Value: "send",
179 | },
180 | }
181 | m.otpState.GeneratedAt = int64(1000)
182 | m.otpState.Otp = testOTP
183 | st, cbs, err := m.ProcessCallbacks(inCbs, &state.FlowState{})
184 | assert.NoError(t, err)
185 | assert.Equal(t, state.InProgress, st)
186 | assert.Empty(t, cbs[1].Error)
187 | assert.NotEqual(t, testOTP, m.State["otp"])
188 | otpCb := cbs[0]
189 | assert.Equal(t, "otp", otpCb.Name)
190 | assert.Equal(t, "180", otpCb.Properties["timeoutSec"]) //TODO to int values
191 | assert.Equal(t, "90", otpCb.Properties["resendSec"])
192 | assert.Equal(t, "5", otpCb.Properties["retryCount"])
193 | }
194 |
195 | func TestProcessCallbacks_CodeValid(t *testing.T) {
196 | const testOTP = "1234"
197 | m := getOTPModule(t)
198 | m.generate()
199 | m.otpState.Otp = testOTP
200 | m.otpState.GeneratedAt = time.Now().UnixMilli()
201 | inCbs := []callbacks.Callback{
202 | {
203 | Name: "otp",
204 | Value: testOTP,
205 | },
206 | {
207 | Name: "action",
208 | Value: "check",
209 | },
210 | }
211 | st, cbs, err := m.ProcessCallbacks(inCbs, nil)
212 | assert.NoError(t, err)
213 | assert.Equal(t, state.Pass, st)
214 | assert.Empty(t, cbs[0].Error)
215 | }
216 |
217 | func TestGetMessage(t *testing.T) {
218 | fs := &state.FlowState{
219 | ID: "test",
220 | }
221 | m := getOTPModule(t)
222 | m.otpState.Otp = "1234"
223 | msg, err := m.getMessage(fs)
224 | assert.NoError(t, err)
225 | const expectedMessage = "Code 1234 valid for 03:00 min"
226 | assert.Equal(t, expectedMessage, msg)
227 | }
228 |
229 | func TestSend(t *testing.T) {
230 | fs := &state.FlowState{
231 | ID: "test",
232 | }
233 | m := getOTPModule(t)
234 | err := m.send(fs)
235 | assert.NoError(t, err)
236 | ts := m.otpSender.(*otp.TestSender)
237 | assert.Equal(t, 1, len(ts.Messages))
238 | }
239 |
240 | func getOTPModule(t *testing.T) *OTP {
241 | var b = BaseAuthModule{
242 | State: make(map[string]interface{}, 1),
243 | Properties: map[string]interface{}{
244 | "otpLength": float64(4),
245 | "useLetters": false,
246 | "useDigits": true,
247 | "otpTimeoutSec": float64(180),
248 | "otpResendSec": float64(90),
249 | "otpRetryCount": float64(5),
250 | "otpMessageTemplate": "Code {{.OTP}} valid for {{.ValidFor}} min",
251 | "sender": map[string]interface{}{
252 | "senderType": "test",
253 | "properties": map[string]interface{}{
254 | "host": "localhost",
255 | "port": 1234,
256 | },
257 | },
258 | },
259 | }
260 | am := newOTP(b)
261 | assert.NotNil(t, am)
262 | o, ok := am.(*OTP)
263 | assert.True(t, ok)
264 | return o
265 | }
266 |
--------------------------------------------------------------------------------
/pkg/auth/modules/qr.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "crypto/rand"
5 | "encoding/base64"
6 | "fmt"
7 | "strconv"
8 | "time"
9 |
10 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
11 | "github.com/maximthomas/gortas/pkg/auth/state"
12 | "github.com/maximthomas/gortas/pkg/crypt"
13 | "github.com/skip2/go-qrcode"
14 | )
15 |
16 | const (
17 | qrCodeSize = 256
18 | secretLen = 32
19 | )
20 |
21 | type QR struct {
22 | BaseAuthModule
23 | qrTimeout int64
24 | }
25 |
26 | func (q *QR) Process(lss *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
27 |
28 | var qrT int64
29 | qrTf, ok := q.State["qrT"].(float64)
30 | if ok {
31 | qrT = int64(qrTf)
32 | } else {
33 | seconds := time.Now().Unix()
34 | qrT = seconds / q.qrTimeout
35 | q.State["qrT"] = qrT
36 | }
37 |
38 | image, err := q.generateQRImage(lss.ID, qrT)
39 | if err != nil {
40 | return state.Fail, q.Callbacks, err
41 | }
42 |
43 | q.Callbacks[0].Properties["image"] = image
44 | return state.InProgress, q.Callbacks, err
45 | }
46 |
47 | func (q *QR) ProcessCallbacks(_ []callbacks.Callback, lss *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
48 |
49 | uid, ok := q.BaseAuthModule.State["qrUserId"].(string)
50 | if !ok {
51 | // check if qr is outdated
52 | var qrT int64
53 | qrTf, ok := q.State["qrT"].(float64)
54 | seconds := time.Now().Unix()
55 | if ok {
56 | qrT = int64(qrTf)
57 | } else {
58 | qrT = seconds / q.qrTimeout
59 | q.State["qrT"] = qrT
60 | }
61 |
62 | newQrT := seconds / q.qrTimeout
63 | if newQrT > qrT {
64 | q.State["qrT"] = newQrT
65 | }
66 | var image string
67 | image, err = q.generateQRImage(lss.ID, qrT)
68 | if err != nil {
69 | return state.Fail, cbs, err
70 | }
71 |
72 | q.Callbacks[0].Properties["image"] = image
73 |
74 | return state.InProgress, q.Callbacks, err
75 | }
76 | lss.UserID = uid
77 | return state.Pass, cbs, err
78 |
79 | }
80 |
81 | func (q *QR) ValidateCallbacks(_ []callbacks.Callback) error {
82 | return nil
83 | }
84 |
85 | func (q *QR) PostProcess(_ *state.FlowState) error {
86 | return nil
87 | }
88 |
89 | func (q *QR) getSecret() (secret string, err error) {
90 | secret, ok := q.State["secret"].(string)
91 | if !ok {
92 |
93 | key := make([]byte, secretLen)
94 | _, err = rand.Read(key)
95 | if err != nil {
96 | return secret, err
97 | }
98 | secret = base64.StdEncoding.EncodeToString(key)
99 | q.State["secret"] = secret
100 | }
101 | return secret, err
102 | }
103 |
104 | func (q *QR) generateQRImage(sessID string, qrT int64) (string, error) {
105 | var image string
106 | secret, err := q.getSecret()
107 | if err != nil {
108 | return image, err
109 | }
110 |
111 | h := crypt.MD5(secret + strconv.FormatInt(qrT, 10))
112 | qrValue := fmt.Sprintf("?sid=%s;%s&action=login", sessID, h)
113 | png, err := qrcode.Encode(qrValue, qrcode.Medium, qrCodeSize)
114 | if err != nil {
115 | return image, err
116 | }
117 |
118 | image = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
119 | return image, nil
120 | }
121 |
122 | func init() {
123 | RegisterModule("qr", newQRModule)
124 | }
125 |
126 | func newQRModule(base BaseAuthModule) AuthModule {
127 | (&base).Callbacks = []callbacks.Callback{
128 | {
129 | Name: "qr",
130 | Type: callbacks.TypeImage,
131 | Prompt: "Enter QR code",
132 | Value: "",
133 | Properties: map[string]string{},
134 | },
135 | {
136 | Name: "submit",
137 | Type: callbacks.TypeAutoSubmit,
138 | Properties: map[string]string{
139 | "interval": "5",
140 | },
141 | },
142 | }
143 |
144 | qrTimeout := 30
145 | if qrTimeoutProp, ok := base.Properties["qrTimeout"]; ok {
146 | qrTimeout = qrTimeoutProp.(int)
147 | }
148 | return &QR{
149 | BaseAuthModule: base,
150 | qrTimeout: int64(qrTimeout),
151 | }
152 | }
153 |
--------------------------------------------------------------------------------
/pkg/auth/modules/qr_test.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "log"
5 | "net/http/httptest"
6 | "testing"
7 |
8 | "github.com/gin-gonic/gin"
9 | "github.com/google/uuid"
10 | "github.com/maximthomas/gortas/pkg/auth/state"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestQR(t *testing.T) {
15 |
16 | t.Run("Test request new qr", func(t *testing.T) {
17 | q := getQRModule()
18 | recorder := httptest.NewRecorder()
19 | c, _ := gin.CreateTestContext(recorder)
20 | c.Request = httptest.NewRequest("GET", "/login", nil)
21 |
22 | lss := &state.FlowState{}
23 | lss.ID = uuid.New().String()
24 |
25 | status, cbs, err := q.Process(lss)
26 | img, ok := cbs[0].Properties["image"]
27 | assert.True(t, ok)
28 | assert.NotEmpty(t, img)
29 | log.Print(status, cbs, err)
30 | })
31 |
32 | t.Run("Test process successful auth", func(t *testing.T) {
33 | q := getQRModule()
34 | recorder := httptest.NewRecorder()
35 | c, _ := gin.CreateTestContext(recorder)
36 | c.Request = httptest.NewRequest("POST", "/login", nil)
37 | lss := &state.FlowState{SharedState: map[string]string{}}
38 | lss.ID = uuid.New().String()
39 | q.BaseAuthModule.State["qrUserId"] = "ivan"
40 | ms, _, err := q.ProcessCallbacks(q.Callbacks, lss)
41 | assert.Equal(t, state.Pass, ms)
42 | assert.NoError(t, err)
43 | })
44 |
45 | t.Run("Test process update QR", func(t *testing.T) {
46 | q := getQRModule()
47 | recorder := httptest.NewRecorder()
48 | c, _ := gin.CreateTestContext(recorder)
49 | c.Request = httptest.NewRequest("POST", "/login", nil)
50 | lss := &state.FlowState{SharedState: map[string]string{}}
51 | lss.ID = uuid.New().String()
52 | ms, cbs, err := q.ProcessCallbacks(q.Callbacks, lss)
53 | assert.Equal(t, state.InProgress, ms)
54 | assert.NoError(t, err)
55 | image := cbs[0].Properties["image"]
56 | assert.NotEmpty(t, image)
57 | })
58 |
59 | }
60 |
61 | func getQRModule() *QR {
62 | b := BaseAuthModule{
63 | Properties: map[string]interface{}{
64 | "qrTimeout": 10,
65 | },
66 | State: map[string]interface{}{},
67 | }
68 | m := newQRModule(b)
69 | q, _ := m.(*QR)
70 | return q
71 | }
72 |
--------------------------------------------------------------------------------
/pkg/auth/modules/registration.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "regexp"
5 |
6 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
7 | "github.com/maximthomas/gortas/pkg/auth/state"
8 | "github.com/maximthomas/gortas/pkg/user"
9 | "github.com/mitchellh/mapstructure"
10 | "github.com/pkg/errors"
11 | )
12 |
13 | // TODO add password format
14 | // TODO add confirmation password callback
15 | type Registration struct {
16 | BaseAuthModule
17 | PrimaryField Field
18 | UsePassword bool
19 | UseRepeatPassword bool
20 | AdditionalFields []Field
21 | }
22 |
23 | func (f *Field) initField() error {
24 | if f.Validation == "" {
25 | return nil
26 | }
27 | _, err := regexp.Compile(f.Validation)
28 | if err != nil {
29 | return err
30 | }
31 | return nil
32 | }
33 |
34 | func (rm *Registration) Process(_ *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
35 | return state.InProgress, rm.Callbacks, err
36 | }
37 |
38 | func (rm *Registration) ProcessCallbacks(inCbs []callbacks.Callback, fs *state.FlowState) (ms state.ModuleStatus, cbs []callbacks.Callback, err error) {
39 | if inCbs == nil {
40 | return state.Fail, cbs, errors.New("callbacks can't be nil")
41 | }
42 | callbacksValid := true
43 | errCbs := make([]callbacks.Callback, len(rm.Callbacks))
44 | copy(errCbs, rm.Callbacks)
45 |
46 | for i := range inCbs {
47 | cb := inCbs[i]
48 | if cb.Value == "" && errCbs[i].Required {
49 | (&errCbs[i]).Error = (&errCbs[i]).Prompt + " required"
50 | callbacksValid = false
51 | } else if errCbs[i].Validation != "" {
52 | var re *regexp.Regexp
53 | re, err = regexp.Compile(errCbs[i].Validation)
54 | if err != nil {
55 | rm.l.Errorf("error compiling regex for callback %v", cb.Validation)
56 | return state.Fail, cbs, errors.Wrapf(err, "error compiling regex for callback %v", cb.Validation)
57 | }
58 | match := re.MatchString(cb.Value)
59 | if !match {
60 | (&errCbs[i]).Error = (&errCbs[i]).Prompt + " invalid"
61 | callbacksValid = false
62 | }
63 | }
64 | }
65 |
66 | if !callbacksValid {
67 | return state.InProgress, errCbs, nil
68 | }
69 |
70 | var username string
71 | var password string
72 | var repeatPassword string
73 |
74 | fields := make(map[string]string, len(inCbs)-2)
75 |
76 | for i := range inCbs {
77 | cb := inCbs[i]
78 | switch cb.Name {
79 | case rm.PrimaryField.Name:
80 | username = cb.Value
81 | case "password":
82 | password = cb.Value
83 | case "repeatPassword":
84 | repeatPassword = cb.Value
85 | default:
86 | fields[cb.Name] = cb.Value
87 | }
88 | }
89 |
90 | if repeatPassword != password {
91 | (&errCbs[len(inCbs)-1]).Error = "Passwords do not match"
92 | return state.InProgress, errCbs, nil
93 | }
94 |
95 | us := user.GetUserService()
96 | _, exists := us.GetUser(username)
97 | if exists {
98 | (&errCbs[0]).Error = "User exists"
99 | return state.InProgress, errCbs, nil
100 | }
101 |
102 | u := user.User{
103 | ID: username,
104 | Properties: fields,
105 | }
106 |
107 | _, err = us.CreateUser(u)
108 | if err != nil {
109 | return state.Fail, cbs, err
110 | }
111 |
112 | err = us.SetPassword(u.ID, password)
113 | if err != nil {
114 | return state.Fail, cbs, err
115 | }
116 |
117 | fs.UserID = u.ID
118 |
119 | return state.Pass, rm.Callbacks, err
120 | }
121 |
122 | func (rm *Registration) ValidateCallbacks(cbs []callbacks.Callback) error {
123 | return rm.BaseAuthModule.ValidateCallbacks(cbs)
124 | }
125 |
126 | func (rm *Registration) PostProcess(_ *state.FlowState) error {
127 | return nil
128 | }
129 |
130 | func init() {
131 | RegisterModule("registration", newRegistrationModule)
132 | }
133 |
134 | func newRegistrationModule(base BaseAuthModule) AuthModule {
135 | var rm Registration
136 | rm.UsePassword = true // default value
137 | rm.UseRepeatPassword = true
138 | err := mapstructure.Decode(base.Properties, &rm)
139 | if err != nil {
140 | panic(err) // TODO add error processing
141 | }
142 | rm.BaseAuthModule = base
143 |
144 | cbLen := len(rm.AdditionalFields) + 1
145 | if rm.UsePassword {
146 | cbLen++
147 | }
148 |
149 | if rm.UseRepeatPassword {
150 | cbLen++
151 | }
152 |
153 | adcbs := make([]callbacks.Callback, cbLen)
154 | if rm.AdditionalFields != nil {
155 | for i, af := range rm.AdditionalFields {
156 | adcbs[i+1] = callbacks.Callback{
157 | Name: af.Name,
158 | Type: callbacks.TypeText,
159 | Value: "",
160 | Prompt: af.Prompt,
161 | Required: af.Required,
162 | Validation: af.Validation,
163 | }
164 | }
165 | }
166 | adcbs[0] = callbacks.Callback{
167 | Name: rm.PrimaryField.Name,
168 | Type: callbacks.TypeText,
169 | Prompt: rm.PrimaryField.Prompt,
170 | Value: "",
171 | Required: true,
172 | Validation: rm.PrimaryField.Validation,
173 | }
174 | if rm.UsePassword {
175 | adcbs[cbLen-2] = callbacks.Callback{
176 | Name: "password",
177 | Type: callbacks.TypePassword,
178 | Prompt: "Password",
179 | Value: "",
180 | Required: true,
181 | }
182 | }
183 | if rm.UseRepeatPassword {
184 | adcbs[cbLen-1] = callbacks.Callback{
185 | Name: "repeatPassword",
186 | Type: callbacks.TypePassword,
187 | Prompt: "Repeat password",
188 | Value: "",
189 | Required: true,
190 | }
191 | }
192 |
193 | (&rm.BaseAuthModule).Callbacks = adcbs
194 | return &rm
195 | }
196 |
--------------------------------------------------------------------------------
/pkg/auth/modules/registration_test.go:
--------------------------------------------------------------------------------
1 | package modules
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/gin-gonic/gin"
10 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
11 | "github.com/maximthomas/gortas/pkg/auth/state"
12 | "github.com/maximthomas/gortas/pkg/config"
13 | "github.com/maximthomas/gortas/pkg/log"
14 | "github.com/maximthomas/gortas/pkg/user"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | const (
19 | userName = "johnDoe"
20 | password = "passw0rdJ0hn"
21 | )
22 |
23 | func TestNewRegistrationModule(t *testing.T) {
24 | rm := getNewRegistrationModule(t)
25 | assert.Equal(t, "login", rm.PrimaryField.Name)
26 | assert.Equal(t, 1, len(rm.AdditionalFields))
27 | assert.Equal(t, true, rm.UsePassword)
28 | assert.Equal(t, true, rm.UseRepeatPassword)
29 | }
30 |
31 | func TestRegistration_Process_InvalidLogin(t *testing.T) {
32 | tests := []struct {
33 | email string
34 | name string
35 | emailError string
36 | nameError string
37 | }{
38 | {email: "", name: "", emailError: "Email required", nameError: "Name required"},
39 | {email: "123", name: "John Doe", emailError: "Email invalid", nameError: ""},
40 | }
41 | rm := getNewRegistrationModule(t)
42 | for _, tt := range tests {
43 | t.Run(tt.name, func(t *testing.T) {
44 | inCbs := []callbacks.Callback{
45 | {
46 | Name: "login",
47 | Value: tt.email,
48 | },
49 | {
50 | Name: "name",
51 | Value: tt.name,
52 | },
53 | }
54 | fs := &state.FlowState{}
55 | ms, cbs, err := rm.ProcessCallbacks(inCbs, fs)
56 | assert.NoError(t, err)
57 | assert.Equal(t, 4, len(cbs))
58 | assert.Equal(t, tt.emailError, cbs[0].Error)
59 | assert.Equal(t, tt.nameError, cbs[1].Error)
60 | assert.Equal(t, state.InProgress, ms)
61 | })
62 | }
63 |
64 | }
65 |
66 | func TestRegistration_Process(t *testing.T) {
67 | rm := getNewRegistrationModule(t)
68 | t.Run("Test request callbacks", func(t *testing.T) {
69 | recorder := httptest.NewRecorder()
70 | c, _ := gin.CreateTestContext(recorder)
71 | c.Request = httptest.NewRequest("GET", "/login", nil)
72 | lss := &state.FlowState{}
73 | status, cbs, err := rm.Process(lss)
74 | fmt.Print(status, cbs, err)
75 | assert.Equal(t, 4, len(cbs))
76 | assert.NoError(t, err)
77 | assert.Equal(t, state.InProgress, status)
78 | assert.Equal(t, http.StatusOK, recorder.Code)
79 | })
80 | }
81 |
82 | func TestRegistration_ProcessCallbacks(t *testing.T) {
83 | rm := getNewRegistrationModule(t)
84 |
85 | tests := []struct {
86 | name string
87 | inCbs []callbacks.Callback
88 | assertions func(t *testing.T, status state.ModuleStatus, cbs []callbacks.Callback, err error)
89 | }{
90 | {
91 | name: "test empty callbacks",
92 | inCbs: nil,
93 | assertions: func(t *testing.T, status state.ModuleStatus, cbs []callbacks.Callback, err error) {
94 | assert.Error(t, err)
95 | assert.Equal(t, state.Fail, status)
96 | },
97 | },
98 | {
99 | name: "test empty fields",
100 | inCbs: []callbacks.Callback{
101 | {
102 | Name: "login",
103 | Value: "",
104 | },
105 | {
106 | Name: "name",
107 | Value: "",
108 | },
109 | {
110 | Name: "password",
111 | Value: "",
112 | },
113 | },
114 | assertions: func(t *testing.T, status state.ModuleStatus, cbs []callbacks.Callback, err error) {
115 | assert.NoError(t, err)
116 | assert.Equal(t, state.InProgress, status)
117 | assert.Equal(t, "Email required", cbs[0].Error)
118 | assert.Equal(t, "Name required", cbs[1].Error)
119 | assert.Equal(t, "Password required", cbs[2].Error)
120 | },
121 | },
122 | {
123 | name: "test user exists",
124 | inCbs: []callbacks.Callback{
125 | {
126 | Name: "login",
127 | Value: "user1",
128 | },
129 | {
130 | Name: "name",
131 | Value: "John Doe",
132 | },
133 | {
134 | Name: "password",
135 | Value: password,
136 | },
137 | {
138 | Name: "repeatPassword",
139 | Value: password,
140 | },
141 | },
142 | assertions: func(t *testing.T, status state.ModuleStatus, cbs []callbacks.Callback, err error) {
143 | assert.Equal(t, state.InProgress, status)
144 | assert.Equal(t, "User exists", cbs[0].Error)
145 | },
146 | },
147 | {
148 | name: "test passwords do not match",
149 | inCbs: []callbacks.Callback{
150 | {
151 | Name: "login",
152 | Value: userName,
153 | },
154 | {
155 | Name: "name",
156 | Value: "John Doe",
157 | },
158 | {
159 | Name: "password",
160 | Value: password,
161 | },
162 | {
163 | Name: "repeatPassword",
164 | Value: "bad",
165 | },
166 | },
167 | assertions: func(t *testing.T, status state.ModuleStatus, cbs []callbacks.Callback, err error) {
168 | assert.NoError(t, err)
169 | assert.Equal(t, state.InProgress, status)
170 | assert.Equal(t, "Passwords do not match", cbs[3].Error)
171 | },
172 | },
173 | {
174 | name: "test successful registration",
175 | inCbs: []callbacks.Callback{
176 | {
177 | Name: "login",
178 | Value: userName,
179 | },
180 | {
181 | Name: "name",
182 | Value: "John Doe",
183 | },
184 | {
185 | Name: "password",
186 | Value: password,
187 | },
188 | {
189 | Name: "repeatPassword",
190 | Value: password,
191 | },
192 | },
193 | assertions: func(t *testing.T, status state.ModuleStatus, cbs []callbacks.Callback, err error) {
194 | assert.NoError(t, err)
195 | assert.Equal(t, state.Pass, status)
196 | us := user.GetUserService()
197 | _, ok := us.GetUser(userName)
198 | assert.True(t, ok)
199 | pValid := us.ValidatePassword(userName, password)
200 | assert.True(t, pValid)
201 | },
202 | },
203 | }
204 |
205 | for _, tt := range tests {
206 | t.Run(tt.name, func(t *testing.T) {
207 | recorder := httptest.NewRecorder()
208 | c, _ := gin.CreateTestContext(recorder)
209 | c.Request = httptest.NewRequest("POST", "/login", nil)
210 | lss := &state.FlowState{}
211 | status, cbs, err := rm.ProcessCallbacks(tt.inCbs, lss)
212 | tt.assertions(t, status, cbs, err)
213 | })
214 | }
215 | }
216 |
217 | func getNewRegistrationModule(t *testing.T) *Registration {
218 | config.SetConfig(&config.Config{})
219 | var b = BaseAuthModule{
220 | l: log.WithField("module", "registration"),
221 | Properties: map[string]interface{}{
222 | "primaryField": Field{
223 | Name: "login",
224 | Prompt: "Email",
225 | Required: true,
226 | Validation: "^\\w{4,}$",
227 | },
228 | "additionalFields": []Field{{
229 | Name: "name",
230 | Prompt: "Name",
231 | Required: true,
232 | },
233 | },
234 | },
235 | }
236 |
237 | var m = newRegistrationModule(b)
238 | rm, ok := m.(*Registration)
239 | assert.True(t, ok)
240 | return rm
241 | }
242 |
--------------------------------------------------------------------------------
/pkg/auth/state/state.go:
--------------------------------------------------------------------------------
1 | package state
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "reflect"
7 | )
8 |
9 | type FlowState struct {
10 | Modules []FlowStateModuleInfo
11 | SharedState map[string]string
12 | UserID string
13 | ID string
14 | RedirectURI string
15 | Name string
16 | }
17 |
18 | type FlowStateModuleInfo struct {
19 | ID string
20 | Type string
21 | Properties FlowStateModuleProperties
22 | Status ModuleStatus
23 | State map[string]interface{}
24 | Criteria string
25 | }
26 |
27 | type FlowStateModuleProperties map[string]interface{}
28 |
29 | func (mp FlowStateModuleProperties) MarshalJSON() ([]byte, error) {
30 | var cp = make(map[string]interface{})
31 | for k, v := range mp {
32 | cp[k] = convertInterface(v)
33 | }
34 | return json.Marshal(cp)
35 | }
36 |
37 | func convertInterface(v interface{}) interface{} {
38 | var res interface{}
39 | switch v.(type) {
40 | case []map[interface{}]interface{}:
41 | rVal := reflect.ValueOf(v)
42 | mMaps := make([]map[string]string, rVal.Len())
43 | for i := 0; i < rVal.Len(); i++ {
44 | mMap := make(map[string]string)
45 | iter := rVal.Index(i).MapRange()
46 | for iter.Next() {
47 | mk := fmt.Sprintf("%v", iter.Key())
48 | mv := fmt.Sprintf("%v", iter.Value())
49 | mMap[mk] = mv
50 | }
51 | mMaps[i] = mMap
52 | }
53 | res = mMaps
54 | case map[interface{}]interface{}:
55 | rVal := reflect.ValueOf(v)
56 | mMap := make(map[string]string)
57 | iter := rVal.MapRange()
58 | for iter.Next() {
59 | k := fmt.Sprintf("%v", iter.Key())
60 | v := fmt.Sprintf("%v", iter.Value())
61 | mMap[k] = v
62 | }
63 | res = mMap
64 | case []interface{}:
65 | rVal := reflect.ValueOf(v)
66 | ar := make([]interface{}, rVal.Len())
67 | for i := 0; i < rVal.Len(); i++ {
68 | ar[i] = convertInterface(rVal.Index(i).Interface())
69 | }
70 | res = ar
71 | default:
72 | res = v
73 | }
74 | return res
75 | }
76 |
77 | func (f *FlowState) UpdateModuleInfo(mIndex int, mInfo FlowStateModuleInfo) {
78 | f.Modules[mIndex] = mInfo
79 | }
80 |
81 | type ModuleStatus int
82 |
83 | const (
84 | Fail ModuleStatus = -1 + iota
85 | Start
86 | InProgress // callbacks requested
87 | Pass
88 | )
89 |
90 | const (
91 | FlowCookieName = "GortasAuthFlow"
92 | SessionCookieName = "GortasSession"
93 | )
94 |
--------------------------------------------------------------------------------
/pkg/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "github.com/maximthomas/gortas/pkg/log"
5 | "github.com/maximthomas/gortas/pkg/session"
6 | "github.com/maximthomas/gortas/pkg/user"
7 | "github.com/spf13/viper"
8 | )
9 |
10 | type Config struct {
11 | Flows map[string]Flow `yaml:"flows"`
12 | Session session.Config `yaml:"session"`
13 | Server Server `yaml:"server"`
14 | EncryptionKey string `yaml:"encryptionKey"`
15 | UserDataStore user.Config `yaml:"userDataStore"`
16 | }
17 |
18 | type Flow struct {
19 | Modules []Module `yaml:"modules"`
20 | }
21 |
22 | type Module struct {
23 | ID string `yaml:"id"`
24 | Type string `yaml:"type"`
25 | Properties map[string]interface{} `yaml:"properties,omitempty"`
26 | Criteria string `yaml:"criteria"`
27 | }
28 |
29 | type Server struct {
30 | Cors Cors
31 | }
32 |
33 | type Cors struct {
34 | AllowedOrigins []string
35 | }
36 |
37 | var configLogger = log.WithField("module", "config")
38 |
39 | var config Config
40 |
41 | func InitConfig() error {
42 |
43 | // newLogger.SetFormatter(&logrus.JSONFormatter{})
44 | // newLogger.SetReportCaller(true)
45 |
46 | err := viper.Unmarshal(&config)
47 |
48 | if err != nil { // Handle errors reading the config file
49 | configLogger.Errorf("Fatal error config file: %s \n", err)
50 | panic(err)
51 | }
52 | err = user.InitUserService(config.UserDataStore)
53 | if err != nil {
54 | configLogger.Errorf("Fatal error config file: %s \n", err)
55 | panic(err)
56 | }
57 | err = session.InitSessionService(&config.Session)
58 |
59 | if err != nil {
60 | configLogger.Errorf("error while init session service: %s \n", err)
61 | panic(err)
62 | }
63 |
64 | configLogger.Debugf("got configuration %+v\n", config)
65 |
66 | return nil
67 | }
68 |
69 | func GetConfig() Config {
70 | return config
71 | }
72 |
73 | func SetConfig(newConfig *Config) {
74 | config = *newConfig
75 | err := user.InitUserService(newConfig.UserDataStore)
76 | if err != nil {
77 | configLogger.Warnf("error %v", err)
78 | }
79 | err = session.InitSessionService(&newConfig.Session)
80 | if err != nil {
81 | configLogger.Warnf("error %v", err)
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/pkg/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/spf13/viper"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestReadConfigFileViper(t *testing.T) {
12 | viper.SetConfigName("auth-config-dev") // name of config file (without extension)
13 | viper.AddConfigPath("../../test") // optionally look for config in the working directory
14 | err := viper.ReadInConfig() // Find and read the config file
15 | assert.NoError(t, err)
16 | err = InitConfig()
17 | assert.NoError(t, err)
18 | conf := GetConfig()
19 | assert.True(t, len(conf.Flows) > 0)
20 | assert.NotEmpty(t, config.Session.Jwt.PrivateKeyPem)
21 | assert.Equal(t, 1, len(conf.Server.Cors.AllowedOrigins))
22 | }
23 |
--------------------------------------------------------------------------------
/pkg/controller/auth.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "strconv"
7 |
8 | "github.com/gin-gonic/gin"
9 | "github.com/maximthomas/gortas/pkg/auth"
10 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
11 | "github.com/maximthomas/gortas/pkg/auth/state"
12 | "github.com/maximthomas/gortas/pkg/log"
13 | "github.com/sirupsen/logrus"
14 | )
15 |
16 | // AuthController rest controller for authentication
17 | type AuthController struct {
18 | logger logrus.FieldLogger
19 | }
20 |
21 | // Auth gin handler function
22 | func (a *AuthController) Auth(c *gin.Context) {
23 | fn := c.Param("flow")
24 |
25 | var cbReq callbacks.Request
26 | var fID string
27 | if c.Request.Method == http.MethodPost {
28 | err := c.ShouldBindJSON(&cbReq)
29 | if err != nil {
30 | logrus.Errorf("error binding json body %v", err)
31 | c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
32 | return
33 | }
34 | fID = cbReq.FlowID
35 | if fID == "" {
36 | fID, _ = c.Cookie(state.FlowCookieName)
37 | }
38 | }
39 | (&cbReq).FlowID = fID
40 | fp := auth.NewFlowProcessor()
41 | cbResp, err := fp.Process(fn, cbReq, c.Request, c.Writer)
42 | a.generateResponse(c, &cbResp, err)
43 | }
44 |
45 | func (a *AuthController) generateResponse(c *gin.Context, cbResp *callbacks.Response, err error) {
46 | if err != nil {
47 | logrus.Errorf("authentication error %v", err)
48 | deleteCookie(state.FlowCookieName, c)
49 | c.JSON(http.StatusUnauthorized, gin.H{"status": "fail"})
50 | return
51 | }
52 |
53 | if cbResp.Token != "" {
54 | setCookie(state.SessionCookieName, cbResp.Token, c)
55 | deleteCookie(state.FlowCookieName, c)
56 | c.JSON(http.StatusOK, cbResp)
57 | } else if cbResp.FlowID != "" {
58 | status := http.StatusOK
59 | outCb := make([]callbacks.Callback, 0)
60 | for i := range cbResp.Callbacks {
61 | cb := cbResp.Callbacks[i]
62 | if cb.Type == callbacks.TypeHTTPStatus {
63 | status, err = strconv.Atoi(cb.Value)
64 | if err != nil {
65 | errMsg := fmt.Sprintf("error parsing status %v", cb.Value)
66 | a.logger.Errorf(errMsg)
67 | c.JSON(http.StatusInternalServerError, gin.H{"status": "fail", "message": errMsg})
68 | return
69 | }
70 | for k, val := range cb.Properties {
71 | c.Header(k, val)
72 | }
73 | } else {
74 | outCb = append(outCb, cb)
75 | }
76 | }
77 | cbOutResp := callbacks.Response{
78 | Callbacks: outCb,
79 | Module: cbResp.Module,
80 | FlowID: cbResp.FlowID,
81 | }
82 | setCookie(state.FlowCookieName, cbResp.FlowID, c)
83 | c.JSON(status, cbOutResp)
84 | } else {
85 | a.logger.Error("this should be never happen")
86 | c.JSON(http.StatusInternalServerError, gin.H{"status": "fail"})
87 | }
88 | }
89 |
90 | func setCookie(name, value string, c *gin.Context) {
91 | c.SetCookie(name, value, 0, "/", "", false, true)
92 | }
93 |
94 | func deleteCookie(name string, c *gin.Context) {
95 | c.SetCookie(name, "", -1, "/", "", false, true)
96 | }
97 |
98 | func NewAuthController() *AuthController {
99 | logger := log.WithField("module", "AuthController")
100 | return &AuthController{logger}
101 | }
102 |
--------------------------------------------------------------------------------
/pkg/controller/auth_test.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/gin-gonic/gin"
10 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
11 | "github.com/maximthomas/gortas/pkg/config"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func init() {
16 | conf := config.Config{}
17 | config.SetConfig(&conf)
18 | }
19 |
20 | func TestGenerateResponse(t *testing.T) {
21 |
22 | type cookie struct {
23 | name string
24 | value string
25 | }
26 | type header struct {
27 | name string
28 | value string
29 | }
30 | tests := []struct {
31 | name string
32 | cbResp callbacks.Response
33 | err error
34 | expectedStatus int
35 | expectedBody string
36 | expectedCookies []cookie
37 | expectedHeaders []header
38 | }{
39 | {
40 | name: "auth error",
41 | cbResp: callbacks.Response{},
42 | err: errors.New("authError"),
43 | expectedStatus: 401,
44 | expectedBody: `{"status":"fail"}`,
45 | },
46 | {
47 | name: "auth in progress",
48 | cbResp: callbacks.Response{
49 | Module: "login",
50 | Callbacks: []callbacks.Callback{
51 | {Type: "text", Name: "login", Value: ""},
52 | {Type: "password", Name: "password", Value: ""},
53 | },
54 | FlowID: "test-flow-id",
55 | },
56 | err: nil,
57 | expectedStatus: 200,
58 | expectedBody: `{"module":"login","callbacks":[{"name":"login","type":"text","value":""},` +
59 | `{"name":"password","type":"password","value":""}],"flowId":"test-flow-id"}`,
60 | expectedCookies: []cookie{
61 | {
62 | name: "GortasAuthFlow",
63 | value: "test-flow-id",
64 | },
65 | },
66 | },
67 | {
68 | name: "auth in progress with httpcallback",
69 | cbResp: callbacks.Response{
70 | Module: "kerberos",
71 | Callbacks: []callbacks.Callback{
72 | {Type: "text", Name: "login", Value: ""},
73 | {Type: "httpstatus", Name: "httpstatus", Value: "401", Properties: map[string]string{"Authenticate": "WWW-Negotiate"}},
74 | },
75 | FlowID: "test-flow-id",
76 | },
77 | err: nil,
78 | expectedStatus: 401,
79 | expectedBody: `{"module":"kerberos","callbacks":[{"name":"login","type":"text","value":""}],"flowId":"test-flow-id"}`,
80 | expectedCookies: []cookie{
81 | {
82 | name: "GortasAuthFlow",
83 | value: "test-flow-id",
84 | },
85 | },
86 | expectedHeaders: []header{
87 | {
88 | name: "Authenticate", value: "WWW-Negotiate",
89 | },
90 | },
91 | },
92 | {
93 | name: "auth succeed",
94 | cbResp: callbacks.Response{
95 | Token: "test-token",
96 | Type: "Bearer",
97 | },
98 | err: nil,
99 | expectedStatus: 200,
100 | expectedBody: `{"token":"test-token","type":"Bearer"}`,
101 | expectedCookies: []cookie{
102 | {
103 | name: "GortasSession",
104 | value: "test-token",
105 | },
106 | },
107 | },
108 | }
109 |
110 | var getCookie = func(name string, cookies []*http.Cookie) *http.Cookie {
111 | for _, c := range cookies {
112 | if c.Name == name {
113 | return c
114 | }
115 | }
116 | return nil
117 | }
118 |
119 | ac := NewAuthController()
120 | for _, tt := range tests {
121 | t.Run(tt.name, func(t *testing.T) {
122 | recorder := httptest.NewRecorder()
123 | c, _ := gin.CreateTestContext(recorder)
124 | ac.generateResponse(c, &tt.cbResp, tt.err)
125 | resp := recorder.Result()
126 | defer resp.Body.Close()
127 | assert.Equal(t, tt.expectedStatus, resp.StatusCode)
128 | assert.Equal(t, tt.expectedBody, recorder.Body.String())
129 | for _, ec := range tt.expectedCookies {
130 | c := getCookie(ec.name, resp.Cookies())
131 | assert.NotNil(t, c)
132 | assert.Equal(t, ec.value, c.Value)
133 | }
134 | for _, eh := range tt.expectedHeaders {
135 | assert.Equal(t, eh.value, resp.Header.Get(eh.name))
136 | }
137 |
138 | })
139 | }
140 |
141 | }
142 |
--------------------------------------------------------------------------------
/pkg/controller/controller_test.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "crypto/rand"
5 | "crypto/rsa"
6 | "crypto/x509"
7 | "encoding/pem"
8 |
9 | "github.com/maximthomas/gortas/pkg/config"
10 | "github.com/maximthomas/gortas/pkg/session"
11 | )
12 |
13 | var (
14 | privateKey, _ = rsa.GenerateKey(rand.Reader, 1024)
15 |
16 | privateKeyStr = string(pem.EncodeToMemory(
17 | &pem.Block{
18 | Type: "RSA PRIVATE KEY",
19 | Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
20 | },
21 | ))
22 |
23 | flows = map[string]config.Flow{
24 | "default": {Modules: []config.Module{
25 | {
26 | ID: "login",
27 | Type: "login",
28 | },
29 | }},
30 | "register": {Modules: []config.Module{
31 | {
32 | ID: "registration",
33 | Type: "registration",
34 | Properties: map[string]interface{}{
35 | "testProp": "testVal",
36 | "additionalFields": []map[interface{}]interface{}{{
37 | "dataStore": "name",
38 | "prompt": "Name",
39 | }},
40 | },
41 | },
42 | }},
43 | "sso": {Modules: []config.Module{}},
44 | }
45 |
46 | conf = config.Config{
47 | Flows: flows,
48 | Session: session.Config{
49 | Type: "stateless",
50 | Expires: 60000,
51 | Jwt: session.JWT{
52 | Issuer: "http://gortas",
53 | },
54 | },
55 | }
56 | )
57 |
--------------------------------------------------------------------------------
/pkg/controller/passwordless.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/json"
6 | "fmt"
7 | "net/http"
8 | "strings"
9 |
10 | "github.com/maximthomas/gortas/pkg/auth/constants"
11 | "github.com/maximthomas/gortas/pkg/auth/state"
12 | "github.com/maximthomas/gortas/pkg/log"
13 | "github.com/maximthomas/gortas/pkg/middleware"
14 | "github.com/maximthomas/gortas/pkg/session"
15 | "github.com/maximthomas/gortas/pkg/user"
16 |
17 | "github.com/gin-gonic/gin"
18 | "github.com/google/uuid"
19 | "github.com/maximthomas/gortas/pkg/config"
20 | "github.com/sirupsen/logrus"
21 | "github.com/skip2/go-qrcode"
22 | )
23 |
24 | const qrSize = 256
25 |
26 | // TODO v2 refactor passwordless architecture
27 | type PasswordlessServicesController struct {
28 | logger logrus.FieldLogger
29 | conf config.Config
30 | }
31 |
32 | func NewPasswordlessServicesController(c *config.Config) *PasswordlessServicesController {
33 | logger := log.WithField("module", "PasswordlessServicesController")
34 | return &PasswordlessServicesController{logger, *c}
35 | }
36 |
37 | type QRProps struct {
38 | Secret string `json:"secret"`
39 | }
40 |
41 | func (pc *PasswordlessServicesController) RegisterGenerateQR(c *gin.Context) {
42 | si, ok := c.Get("session")
43 | if !ok {
44 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"})
45 | return
46 | }
47 | s := si.(session.Session)
48 | uid := s.GetUserID()
49 | us := user.GetUserService()
50 |
51 | _, ok = us.GetUser(uid)
52 | if !ok {
53 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "No user found in the repository"})
54 | return
55 | }
56 | requestURI := middleware.GetRequestURI(c)
57 | imageData := fmt.Sprintf("%s?sid=%s&action=register", requestURI, s.ID)
58 |
59 | png, err := qrcode.Encode(imageData, qrcode.Medium, qrSize)
60 | if err != nil {
61 | pc.logger.Error(err)
62 | c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error generate QR code"})
63 | return
64 | }
65 |
66 | image := "data:image/png;base64," + base64.StdEncoding.EncodeToString(png)
67 | c.JSON(http.StatusOK, gin.H{"qr": image})
68 | }
69 |
70 | func (pc *PasswordlessServicesController) RegisterConfirmQR(c *gin.Context) {
71 | si, ok := c.Get("session")
72 | if !ok {
73 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"})
74 | return
75 | }
76 | s := si.(session.Session)
77 | uid := s.GetUserID()
78 | us := user.GetUserService()
79 |
80 | u, ok := us.GetUser(uid)
81 | if !ok {
82 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "No user found in the repository"})
83 | return
84 | }
85 | // generate secret key
86 | secret := uuid.New().String()
87 | qrProps := QRProps{Secret: secret}
88 | qrPropsJSON, err := json.Marshal(qrProps)
89 | if err != nil {
90 | c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error updating user"})
91 | return
92 | }
93 | if u.Properties == nil {
94 | u.Properties = make(map[string]string)
95 | }
96 | u.Properties["passwordless.qr"] = string(qrPropsJSON)
97 | err = us.UpdateUser(u)
98 | if err != nil {
99 | c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error updating user"})
100 | return
101 | }
102 | requestURI := middleware.GetRequestURI(c)
103 | authURI := strings.ReplaceAll(requestURI, "/idm/otp/qr", "/service/otp/qr/login")
104 |
105 | c.JSON(http.StatusOK, gin.H{"secret": secret, "userId": u.ID, "authURI": authURI})
106 | }
107 |
108 | func (pc *PasswordlessServicesController) AuthQR(c *gin.Context) {
109 | var authQRRequest struct {
110 | SID string `json:"sid"`
111 | UID string `json:"uid"`
112 | Realm string `json:"realm"`
113 | Secret string `json:"secret"`
114 | }
115 |
116 | err := c.ShouldBindJSON(&authQRRequest)
117 | if err != nil {
118 | pc.logger.Warn("invalid request body", err)
119 | c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
120 | return
121 | }
122 |
123 | sess, err := session.GetSessionService().GetSession(authQRRequest.SID)
124 | if err != nil {
125 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "there is no valid authentication session"})
126 | return
127 | }
128 |
129 | us := user.GetUserService()
130 | u, ok := us.GetUser(authQRRequest.UID)
131 | if !ok {
132 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "error updating user"})
133 | return
134 | }
135 | jsonProp, ok := u.Properties["passwordless.qr"]
136 | if !ok {
137 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "the user is not bound to QR"})
138 | return
139 | }
140 | var qrProps QRProps
141 | err = json.Unmarshal([]byte(jsonProp), &qrProps)
142 | if err != nil {
143 | pc.logger.Warn("AuthQR: the user is not bound to QR")
144 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "the user is not bound to QR"})
145 | return
146 | }
147 |
148 | if qrProps.Secret != authQRRequest.Secret {
149 | pc.logger.Warn("AuthQR: user qr secrets does not match")
150 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "the user is not bound to QR"})
151 | return
152 | }
153 |
154 | // authorise session
155 | var fs state.FlowState
156 | err = json.Unmarshal([]byte(sess.Properties[constants.FlowStateSessionProperty]), &fs)
157 | if err != nil {
158 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "there is no valid authentication session"})
159 | return
160 | }
161 |
162 | moduleFound := false
163 | for _, m := range fs.Modules {
164 | if m.Type == "qr" && m.Status == state.InProgress {
165 | m.State["qrUserId"] = authQRRequest.UID
166 | moduleFound = true
167 | break
168 | }
169 | }
170 | if !moduleFound {
171 | pc.logger.Warn("AuthQR: no active qr module in the chain")
172 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "there is no valid authentication session"})
173 | return
174 | }
175 | fsJSON, err := json.Marshal(fs)
176 | if err != nil {
177 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "there is no valid authentication session"})
178 | return
179 | }
180 | sess.Properties[constants.FlowStateSessionProperty] = string(fsJSON)
181 | err = session.GetSessionService().UpdateSession(sess)
182 | if err != nil {
183 | c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
184 | return
185 | }
186 |
187 | c.JSON(http.StatusOK, gin.H{"status": "success"})
188 |
189 | }
190 |
--------------------------------------------------------------------------------
/pkg/controller/passwordless_test.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 | "time"
11 |
12 | "github.com/gin-gonic/gin"
13 | "github.com/google/uuid"
14 | "github.com/maximthomas/gortas/pkg/auth/state"
15 | "github.com/maximthomas/gortas/pkg/session"
16 | "github.com/maximthomas/gortas/pkg/user"
17 | "github.com/stretchr/testify/assert"
18 | )
19 |
20 | type qrTestCaseArgs struct {
21 | session interface{}
22 | }
23 | type qrTestCaseWant struct {
24 | code int
25 | }
26 | type qrTestCase struct {
27 | name string
28 | args qrTestCaseArgs
29 | want qrTestCaseWant
30 | }
31 |
32 | var qrTests = []qrTestCase{
33 | {
34 | "no session",
35 | qrTestCaseArgs{
36 | session: nil,
37 | },
38 | qrTestCaseWant{
39 | code: http.StatusUnauthorized,
40 | },
41 | },
42 | {
43 | "no valid user in session",
44 | qrTestCaseArgs{
45 | session: session.Session{
46 | ID: uuid.New().String(),
47 | CreatedAt: time.Time{},
48 | Properties: map[string]string{
49 | "sub": "bad",
50 | "realm": "staff",
51 | },
52 | },
53 | },
54 | qrTestCaseWant{
55 | code: http.StatusUnauthorized,
56 | },
57 | },
58 | {
59 | "valid user in session",
60 | qrTestCaseArgs{
61 | session: session.Session{
62 | ID: uuid.New().String(),
63 | CreatedAt: time.Time{},
64 | Properties: map[string]string{
65 | "sub": "user1",
66 | "realm": "staff",
67 | },
68 | },
69 | },
70 | qrTestCaseWant{
71 | code: http.StatusOK,
72 | },
73 | },
74 | }
75 |
76 | func TestPasswordlessServicesController_RegisterGenerateQR(t *testing.T) {
77 | pc := NewPasswordlessServicesController(&conf)
78 |
79 | for _, tt := range qrTests {
80 | t.Run(tt.name, func(t *testing.T) {
81 |
82 | recorder := httptest.NewRecorder()
83 | c, _ := gin.CreateTestContext(recorder)
84 | if tt.args.session != nil {
85 | c.Set("session", tt.args.session)
86 | }
87 | c.Request = httptest.NewRequest("POST", "/", nil)
88 |
89 | pc.RegisterGenerateQR(c)
90 | assert.Equal(t, recorder.Code, tt.want.code)
91 | })
92 | }
93 | }
94 |
95 | func TestPasswordlessServicesController_RegisterConfirmQR(t *testing.T) {
96 | pc := NewPasswordlessServicesController(&conf)
97 | type args struct {
98 | session interface{}
99 | }
100 | type want struct {
101 | code int
102 | }
103 |
104 | for _, tt := range qrTests {
105 | t.Run(tt.name, func(t *testing.T) {
106 | recorder := httptest.NewRecorder()
107 | c, _ := gin.CreateTestContext(recorder)
108 | c.Keys = make(map[string]interface{})
109 | if tt.args.session != nil {
110 | c.Set("session", tt.args.session)
111 | }
112 | c.Request = httptest.NewRequest("POST", "/", nil)
113 |
114 | pc.RegisterConfirmQR(c)
115 |
116 | assert.Equal(t, recorder.Code, tt.want.code)
117 | })
118 | }
119 | }
120 |
121 | func TestPasswordlessServicesController_AuthQR(t *testing.T) {
122 | t.Skip() //TODO implement test
123 | badSess := session.Session{
124 | ID: uuid.New().String(),
125 | CreatedAt: time.Now(),
126 | Properties: nil,
127 | }
128 | _, err := session.GetSessionService().CreateSession(badSess)
129 | assert.NoError(t, err)
130 |
131 | lss := state.FlowState{
132 | Modules: []state.FlowStateModuleInfo{
133 | {
134 | ID: "",
135 | Type: "qr",
136 | Properties: nil,
137 | Status: state.InProgress,
138 | State: map[string]interface{}{},
139 | },
140 | },
141 | SharedState: map[string]string{},
142 | UserID: "",
143 | ID: "",
144 | RedirectURI: "",
145 | }
146 | lssBytes, _ := json.Marshal(lss)
147 | validSess := session.Session{
148 | ID: uuid.New().String(),
149 | CreatedAt: time.Now(),
150 | Properties: map[string]string{
151 | "lss": string(lssBytes),
152 | },
153 | }
154 | _, err = session.GetSessionService().CreateSession(validSess)
155 | assert.NoError(t, err)
156 |
157 | us := user.GetUserService()
158 | u, _ := us.GetUser("user1")
159 | u.Properties = map[string]string{
160 | "passwordless.qr": `{"secret": "s3cr3t"}`,
161 | }
162 | err = us.UpdateUser(u)
163 | assert.NoError(t, err)
164 |
165 | pc := NewPasswordlessServicesController(&conf)
166 | type args struct {
167 | body string
168 | }
169 | type want struct {
170 | code int
171 | errMessage string
172 | }
173 | tests := []struct {
174 | name string
175 | args args
176 | want want
177 | }{
178 | {
179 | name: "bad request body",
180 | args: args{
181 | body: "bad",
182 | },
183 | want: want{
184 | code: http.StatusBadRequest,
185 | errMessage: "invalid request body",
186 | },
187 | },
188 | {
189 | name: "no valid session",
190 | args: args{
191 | body: `{"sid":"bad","uid":"user1","realm":"staff","secret":"s3cr3t"}`,
192 | },
193 | want: want{
194 | code: http.StatusUnauthorized,
195 | errMessage: "there is no valid authentication session",
196 | },
197 | },
198 | {
199 | name: "no valid user in a repo",
200 | args: args{
201 | body: fmt.Sprintf(`{"sid":"%q","uid":"bad","realm":"staff","secret":"s3cr3t"}`, validSess.ID),
202 | },
203 | want: want{
204 | code: http.StatusUnauthorized,
205 | errMessage: "error updating user",
206 | },
207 | },
208 | {
209 | name: "user not bound",
210 | args: args{
211 | body: fmt.Sprintf(`{"sid":"%q","uid":"user2","realm":"staff","secret":"s3cr3t"}`, validSess.ID),
212 | },
213 | want: want{
214 | code: http.StatusUnauthorized,
215 | errMessage: "the user is not bound to QR",
216 | },
217 | },
218 | {
219 | name: "secret does not match",
220 | args: args{
221 | body: fmt.Sprintf(`{"sid":"%q","uid":"user1","realm":"staff","secret":"secret"}`, validSess.ID),
222 | },
223 | want: want{
224 | code: http.StatusUnauthorized,
225 | errMessage: "the user is not bound to QR",
226 | },
227 | },
228 |
229 | {
230 | name: "bad authentication session",
231 | args: args{
232 | body: fmt.Sprintf(`{"sid":"%q","uid":"user1","realm":"staff","secret":"s3cr3t"}`, badSess.ID),
233 | },
234 | want: want{
235 | code: http.StatusUnauthorized,
236 | errMessage: "there is no valid authentication session",
237 | },
238 | },
239 | {
240 | name: "valid authentication session",
241 | args: args{
242 | body: fmt.Sprintf(`{"sid":"%q","uid":"user1","realm":"staff","secret":"s3cr3t"}`, validSess.ID),
243 | },
244 | want: want{
245 | code: http.StatusOK,
246 | },
247 | },
248 | }
249 | for _, tt := range tests {
250 | t.Run(tt.name, func(t *testing.T) {
251 | recorder := httptest.NewRecorder()
252 | c, _ := gin.CreateTestContext(recorder)
253 | c.Request = httptest.NewRequest("POST", "/", strings.NewReader(tt.args.body))
254 | pc.AuthQR(c)
255 | assert.Equal(t, tt.want.code, recorder.Code)
256 | var respJSON = make(map[string]interface{})
257 | err := json.Unmarshal(recorder.Body.Bytes(), &respJSON)
258 | assert.NoError(t, err)
259 | if tt.want.errMessage != "" {
260 | assert.Equal(t, respJSON["error"], tt.want.errMessage)
261 | }
262 | })
263 | }
264 | }
265 |
--------------------------------------------------------------------------------
/pkg/controller/session.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "net/http"
5 | "strings"
6 |
7 | "github.com/gin-gonic/gin"
8 | "github.com/maximthomas/gortas/pkg/auth/state"
9 | "github.com/maximthomas/gortas/pkg/log"
10 | "github.com/maximthomas/gortas/pkg/session"
11 | "github.com/sirupsen/logrus"
12 | )
13 |
14 | type SessionController struct {
15 | logger logrus.FieldLogger
16 | }
17 |
18 | func (sc *SessionController) SessionInfo(c *gin.Context) {
19 | var sessionID string
20 | authHeader := c.Request.Header.Get("Authorization")
21 | if authHeader != "" { // from header
22 | sessionID = strings.TrimPrefix(authHeader, "Bearer ")
23 | }
24 | if sessionID == "" { // from cookie
25 | cookie, err := c.Request.Cookie(state.SessionCookieName)
26 | if err == nil {
27 | sessionID = cookie.Value
28 | }
29 | }
30 |
31 | if sessionID == "" {
32 | sc.logger.Warn("session not found in the request")
33 | sc.generateErrorResponse(c)
34 | return
35 | }
36 |
37 | sess, err := session.GetSessionService().GetSessionData(sessionID)
38 | if err != nil {
39 | sc.logger.Warnf("error validating sessionId %s", sessionID)
40 | sc.generateErrorResponse(c)
41 | }
42 | c.JSON(http.StatusOK, sess)
43 | }
44 |
45 | func (sc *SessionController) SessionJwt(c *gin.Context) {
46 | var sessionID string
47 | authHeader := c.Request.Header.Get("Authorization")
48 | if authHeader != "" { // from header
49 | sessionID = strings.TrimPrefix(authHeader, "Bearer ")
50 | }
51 | if sessionID == "" { // from cookie
52 | cookie, err := c.Request.Cookie(state.SessionCookieName)
53 | if err == nil {
54 | sessionID = cookie.Value
55 | }
56 | }
57 |
58 | if sessionID == "" {
59 | sc.logger.Warn("session not found in the request")
60 | sc.generateErrorResponse(c)
61 | return
62 | }
63 | jwt, err := session.GetSessionService().ConvertSessionToJwt(sessionID)
64 | if err != nil {
65 | sc.logger.Warnf("error validating sessionId %s", sessionID)
66 | sc.generateErrorResponse(c)
67 | }
68 | c.JSON(http.StatusOK, gin.H{"jwt": jwt})
69 | }
70 |
71 | func (sc *SessionController) generateErrorResponse(c *gin.Context) {
72 | c.JSON(http.StatusNotFound, gin.H{"error": "token not found"})
73 | }
74 |
75 | func NewSessionController() *SessionController {
76 | return &SessionController{
77 | logger: log.WithField("module", "SessionController"),
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/pkg/controller/session_test.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "math/rand"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | "time"
9 |
10 | "github.com/dgrijalva/jwt-go"
11 | "github.com/gin-gonic/gin"
12 | "github.com/maximthomas/gortas/pkg/auth/state"
13 | "github.com/maximthomas/gortas/pkg/config"
14 | "github.com/maximthomas/gortas/pkg/session"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | func setupConfig(sessRepoType string) {
19 |
20 | conf := config.Config{
21 | Session: session.Config{
22 | Type: sessRepoType,
23 | Expires: 60000,
24 | DataStore: session.DataStore{
25 | Type: "in_memory",
26 | Properties: nil,
27 | },
28 | Jwt: session.JWT{
29 | Issuer: "http://gortas",
30 | PrivateKeyPem: privateKeyStr,
31 | },
32 | },
33 | }
34 | config.SetConfig(&conf)
35 | statefulSession := session.Session{
36 | ID: "testSessionId",
37 | Properties: map[string]string{
38 | "sub": "user1",
39 | "userId": "user1",
40 | "realm": "users",
41 | },
42 | }
43 | _, err := session.GetSessionService().CreateSession(statefulSession)
44 | if err != nil {
45 | panic(err)
46 | }
47 | }
48 |
49 | func getTestJWT() string {
50 | token := jwt.New(jwt.SigningMethodRS256)
51 | claims := token.Claims.(jwt.MapClaims)
52 | exp := time.Second * time.Duration(rand.Intn(1000))
53 | claims["id"] = "testSessionId"
54 | claims["exp"] = time.Now().Add(exp).Unix()
55 | claims["jti"] = "test"
56 | claims["iat"] = time.Now().Unix()
57 | claims["iss"] = "http://gortas"
58 | claims["sub"] = "user1"
59 | claims["realm"] = "realm1"
60 | statelessID, _ := token.SignedString(privateKey)
61 | return statelessID
62 | }
63 |
64 | func TestSessionController_SessionInfo(t *testing.T) {
65 | testJWT := getTestJWT()
66 | setupConfig("stateless")
67 | sc := NewSessionController()
68 | assert.NotNil(t, sc)
69 |
70 | tests := []struct {
71 | name string
72 | getRequest func() *http.Request
73 | wantStatus int
74 | }{
75 | {
76 | name: "get_existing_session_header",
77 | getRequest: func() *http.Request {
78 | request := httptest.NewRequest("GET", "/", nil)
79 | request.Header.Set("Authorization", "Bearer "+testJWT)
80 | return request
81 | },
82 | wantStatus: 200,
83 | },
84 | {
85 | name: "get_existing_session_cookie",
86 | getRequest: func() *http.Request {
87 | request := httptest.NewRequest("GET", "/", nil)
88 | authCookie := &http.Cookie{
89 | Name: state.SessionCookieName,
90 | Value: testJWT,
91 | }
92 | request.AddCookie(authCookie)
93 | return request
94 | },
95 | wantStatus: 200,
96 | },
97 | {
98 | name: "get_empty_session",
99 | getRequest: func() *http.Request {
100 | request := httptest.NewRequest("GET", "/", nil)
101 | return request
102 | },
103 | wantStatus: 404,
104 | },
105 | {
106 | name: "get_bad_session",
107 | getRequest: func() *http.Request {
108 | request := httptest.NewRequest("GET", "/", nil)
109 | return request
110 | },
111 | wantStatus: 404,
112 | },
113 | }
114 | for _, tt := range tests {
115 | t.Run(tt.name, func(t *testing.T) {
116 | recorder := httptest.NewRecorder()
117 | c, _ := gin.CreateTestContext(recorder)
118 | c.Request = tt.getRequest()
119 | sc.SessionInfo(c)
120 | resp := recorder.Result()
121 | defer resp.Body.Close()
122 | assert.Equal(t, tt.wantStatus, resp.StatusCode)
123 | })
124 | }
125 | }
126 |
--------------------------------------------------------------------------------
/pkg/crypt/crypt.go:
--------------------------------------------------------------------------------
1 | package crypt
2 |
3 | import (
4 | "crypto/aes"
5 | "crypto/cipher"
6 | "crypto/md5"
7 | "crypto/rand"
8 | "encoding/base64"
9 | "encoding/hex"
10 | "io"
11 | "math/big"
12 |
13 | "github.com/maximthomas/gortas/pkg/config"
14 | "github.com/pkg/errors"
15 | )
16 |
17 | func EncryptWithConfig(message string) (encmess string, err error) {
18 | encKey := config.GetConfig().EncryptionKey
19 | key, err := base64.StdEncoding.DecodeString(encKey)
20 | if err != nil {
21 | return "", errors.Wrap(err, "error encrypt wuth config")
22 | }
23 | return Encrypt(key, message)
24 | }
25 |
26 | func DecryptWithConfig(message string) (encmess string, err error) {
27 | encKey := config.GetConfig().EncryptionKey
28 | key, err := base64.StdEncoding.DecodeString(encKey)
29 | if err != nil {
30 | return "", errors.Wrap(err, "error encrypt wuth config")
31 | }
32 | return Decrypt(key, message)
33 | }
34 |
35 | func Encrypt(key []byte, message string) (encmess string, err error) {
36 | plainText := []byte(message)
37 |
38 | block, err := aes.NewCipher(key)
39 | if err != nil {
40 | return
41 | }
42 |
43 | cipherText := make([]byte, aes.BlockSize+len(plainText))
44 | iv := cipherText[:aes.BlockSize]
45 | if _, err = io.ReadFull(rand.Reader, iv); err != nil {
46 | return
47 | }
48 |
49 | stream := cipher.NewCFBEncrypter(block, iv)
50 | stream.XORKeyStream(cipherText[aes.BlockSize:], plainText)
51 |
52 | encmess = base64.URLEncoding.EncodeToString(cipherText)
53 | return
54 | }
55 |
56 | func Decrypt(key []byte, securemess string) (decodedmess string, err error) {
57 | cipherText, err := base64.URLEncoding.DecodeString(securemess)
58 | if err != nil {
59 | return
60 | }
61 |
62 | block, err := aes.NewCipher(key)
63 | if err != nil {
64 | return
65 | }
66 |
67 | if len(cipherText) < aes.BlockSize {
68 | err = errors.New("ciphertext block size is too short")
69 | return
70 | }
71 |
72 | iv := cipherText[:aes.BlockSize]
73 | cipherText = cipherText[aes.BlockSize:]
74 |
75 | stream := cipher.NewCFBDecrypter(block, iv)
76 |
77 | stream.XORKeyStream(cipherText, cipherText)
78 |
79 | decodedmess = string(cipherText)
80 | return
81 | }
82 |
83 | func MD5(str string) string {
84 | h := md5.New()
85 | h.Write([]byte(str))
86 | return hex.EncodeToString(h.Sum(nil))
87 | }
88 |
89 | func RandomString(length int, useLetters, useDigits bool) (string, error) {
90 | var runes string
91 | if useLetters {
92 | runes += "abcdefghijklmnopqrstuvwxyz"
93 | }
94 | if useDigits {
95 | runes += "0123456789"
96 | }
97 |
98 | if runes == "" {
99 | return "", errors.New("at least letters or numbers should be specified")
100 | }
101 | ret := make([]byte, length)
102 | for i := 0; i < length; i++ {
103 | num, err := rand.Int(rand.Reader, big.NewInt(int64(len(runes))))
104 | if err != nil {
105 | return "", err
106 | }
107 | ret[i] = runes[num.Int64()]
108 | }
109 | return string(ret), nil
110 | }
111 |
--------------------------------------------------------------------------------
/pkg/log/logger.go:
--------------------------------------------------------------------------------
1 | package log
2 |
3 | import "github.com/sirupsen/logrus"
4 |
5 | var logger = logrus.New()
6 |
7 | func WithField(key string, value interface{}) *logrus.Entry {
8 | return logger.WithField(key, value)
9 | }
10 |
--------------------------------------------------------------------------------
/pkg/middleware/authenticated.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "net/http"
7 | "strings"
8 | "time"
9 |
10 | "github.com/dgrijalva/jwt-go"
11 | "github.com/gin-gonic/gin"
12 | "github.com/maximthomas/gortas/pkg/auth/state"
13 | "github.com/maximthomas/gortas/pkg/session"
14 | )
15 |
16 | func NewAuthenticatedMiddleware(s *session.Config) gin.HandlerFunc {
17 | am := authenticatedMiddleware{*s}
18 | return am.build()
19 | }
20 |
21 | type authenticatedMiddleware struct {
22 | sc session.Config
23 | }
24 |
25 | func (a *authenticatedMiddleware) build() gin.HandlerFunc {
26 | return func(c *gin.Context) {
27 | sessionID := getSessionIDFromRequest(c)
28 | if sessionID == "" {
29 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"})
30 | return
31 | }
32 | var sess session.Session
33 | var err error
34 | if a.sc.Type == "stateless" {
35 | claims := jwt.MapClaims{}
36 | _, err = jwt.ParseWithClaims(sessionID, claims, func(token *jwt.Token) (interface{}, error) {
37 | return session.GetSessionService().GetJwtPublicKey(), nil
38 | })
39 | if err != nil {
40 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"})
41 | return
42 | }
43 | if !claims.VerifyExpiresAt(time.Now().Unix(), true) {
44 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"})
45 | return
46 | }
47 |
48 | sessionProps := make(map[string]string)
49 | for key, value := range claims {
50 | if value == nil {
51 | continue
52 | }
53 | var strVal string
54 | if key == "props" {
55 | var bytes []byte
56 | bytes, err = json.Marshal(value)
57 | if err != nil {
58 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Error parsing token attrs"})
59 | return
60 | }
61 | strVal = string(bytes)
62 | } else {
63 | strVal = fmt.Sprintf("%v", value)
64 | }
65 | sessionProps[key] = strVal
66 | }
67 |
68 | sess = session.Session{
69 | ID: sessionID,
70 | Properties: sessionProps,
71 | }
72 | } else {
73 | sess, err = session.GetSessionService().GetSession(sessionID)
74 | if err != nil {
75 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"})
76 | return
77 | }
78 | }
79 | uid := sess.GetUserID()
80 | if uid == "" {
81 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Session is not valid"})
82 | return
83 | }
84 |
85 | c.Set("session", sess)
86 |
87 | c.Next()
88 | }
89 | }
90 |
91 | func getSessionIDFromRequest(c *gin.Context) string {
92 | sessionCookie, err := c.Request.Cookie(state.SessionCookieName)
93 | if err == nil {
94 | return sessionCookie.Value
95 | }
96 | reqToken := c.Request.Header.Get("Authorization")
97 | splitToken := strings.Split(reqToken, "Bearer ")
98 | if len(splitToken) == 2 {
99 | return splitToken[1]
100 | }
101 |
102 | return ""
103 | }
104 |
--------------------------------------------------------------------------------
/pkg/middleware/authenticated_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "crypto/rand"
5 | "crypto/rsa"
6 | "crypto/x509"
7 | "encoding/pem"
8 | "net/http"
9 | "net/http/httptest"
10 | "testing"
11 | "time"
12 |
13 | "github.com/dgrijalva/jwt-go"
14 | "github.com/maximthomas/gortas/pkg/auth/state"
15 | "github.com/maximthomas/gortas/pkg/session"
16 |
17 | "github.com/stretchr/testify/assert"
18 |
19 | "github.com/gin-gonic/gin"
20 | "github.com/google/uuid"
21 | )
22 |
23 | func TestMiddleware(t *testing.T) {
24 | var privateKey, _ = rsa.GenerateKey(rand.Reader, 2048)
25 | var privateKeyStr = string(pem.EncodeToMemory(
26 | &pem.Block{
27 | Type: "RSA PRIVATE KEY",
28 | Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
29 | },
30 | ))
31 | s := session.Config{
32 | Type: "stateless",
33 | Expires: 0,
34 | Jwt: session.JWT{
35 | Issuer: "http://gortas",
36 | PrivateKeyPem: privateKeyStr,
37 | },
38 | DataStore: session.DataStore{
39 | Type: "",
40 | Properties: nil,
41 | },
42 | }
43 | err := session.InitSessionService(&s)
44 | assert.NoError(t, err)
45 |
46 | // Create the Claims
47 | claims := &jwt.StandardClaims{
48 | ExpiresAt: time.Now().Add(time.Minute * 1).Unix(),
49 | Issuer: "test",
50 | Subject: "user1",
51 | }
52 |
53 | token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
54 | sessJWT, _ := token.SignedString(privateKey)
55 |
56 | m := NewAuthenticatedMiddleware(&s)
57 |
58 | var tests = []struct {
59 | expectedStatus int
60 | sessionID string
61 | name string
62 | exists bool
63 | }{
64 | {401, "bad", "Bad token", false},
65 | {200, sessJWT, "Valid token", true},
66 | }
67 | for _, tt := range tests {
68 | t.Run(tt.name, func(t *testing.T) {
69 | recorder := httptest.NewRecorder()
70 | c, _ := gin.CreateTestContext(recorder)
71 | c.Keys = make(map[string]any)
72 | c.Request = httptest.NewRequest("GET", "/login", nil)
73 | authCookie := &http.Cookie{
74 | Name: state.SessionCookieName,
75 | Value: tt.sessionID,
76 | }
77 | c.Request.AddCookie(authCookie)
78 | m(c)
79 | assert.Equal(t, tt.expectedStatus, recorder.Code)
80 | _, ok := c.Get("session")
81 | assert.Equal(t, tt.exists, ok)
82 | })
83 | }
84 |
85 | }
86 |
87 | func TestGetSessionFormRequest(t *testing.T) {
88 | sessID := uuid.New().String()
89 | sess := session.Session{
90 | ID: sessID,
91 |
92 | Properties: map[string]string{
93 | "test": "test",
94 | "sub": "ivan",
95 | },
96 | }
97 | t.Run("Test get session from cookie", func(t *testing.T) {
98 | recorder := httptest.NewRecorder()
99 | c, _ := gin.CreateTestContext(recorder)
100 | c.Request = httptest.NewRequest("GET", "/", nil)
101 | authCookie := &http.Cookie{
102 | Name: state.SessionCookieName,
103 | Value: sess.ID,
104 | }
105 | c.Request.AddCookie(authCookie)
106 | sessionID := getSessionIDFromRequest(c)
107 | assert.NotEmpty(t, sessionID)
108 | })
109 |
110 | t.Run("Test get session from auth header", func(t *testing.T) {
111 | recorder := httptest.NewRecorder()
112 | c, _ := gin.CreateTestContext(recorder)
113 | c.Request = httptest.NewRequest("GET", "/", nil)
114 |
115 | c.Request.Header.Add("Authorization", "Bearer "+sess.ID)
116 | sessionID := getSessionIDFromRequest(c)
117 | assert.NotEmpty(t, sessionID)
118 | })
119 |
120 | t.Run("Test no session in request", func(t *testing.T) {
121 | recorder := httptest.NewRecorder()
122 | c, _ := gin.CreateTestContext(recorder)
123 | c.Request = httptest.NewRequest("GET", "/", nil)
124 |
125 | sessionID := getSessionIDFromRequest(c)
126 | assert.Empty(t, sessionID)
127 | })
128 | }
129 |
--------------------------------------------------------------------------------
/pkg/middleware/requesturi.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | )
6 |
7 | const requestURIKey = "request.uri"
8 |
9 | func GetRequestURI(c *gin.Context) string {
10 | requestURI, ok := c.Get(requestURIKey)
11 | if ok {
12 | return requestURI.(string)
13 | }
14 | return c.Request.RequestURI
15 |
16 | }
17 | func NewRequestURIMiddleware() gin.HandlerFunc {
18 | return requestURIMiddleware{}.build()
19 | }
20 |
21 | type requestURIMiddleware struct {
22 | }
23 |
24 | func (r requestURIMiddleware) build() gin.HandlerFunc {
25 | return func(c *gin.Context) {
26 | scheme := r.getScheme(c)
27 | host := r.getHost(c)
28 | path := r.getPath(c)
29 | c.Set(requestURIKey, scheme+host+path)
30 | c.Next()
31 | }
32 | }
33 |
34 | func (r requestURIMiddleware) getScheme(c *gin.Context) string {
35 | if r.getHost(c) != "" {
36 | if c.Request.TLS != nil {
37 | return "https://"
38 | }
39 | return "http://"
40 | }
41 | return ""
42 | }
43 |
44 | func (r requestURIMiddleware) getHost(c *gin.Context) string {
45 | return c.Request.Host
46 | }
47 |
48 | func (r requestURIMiddleware) getPath(c *gin.Context) string {
49 | return c.Request.RequestURI
50 | }
51 |
--------------------------------------------------------------------------------
/pkg/server/server.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "github.com/maximthomas/gortas/pkg/config"
6 | "github.com/maximthomas/gortas/pkg/controller"
7 | "github.com/maximthomas/gortas/pkg/middleware"
8 | cors "github.com/rs/cors/wrapper/gin"
9 | )
10 |
11 | func SetupRouter(conf *config.Config) *gin.Engine {
12 | router := gin.Default()
13 | c := cors.New(cors.Options{
14 | AllowedOrigins: conf.Server.Cors.AllowedOrigins,
15 | AllowCredentials: true,
16 | Debug: gin.IsDebugging(),
17 | })
18 |
19 | ru := middleware.NewRequestURIMiddleware()
20 |
21 | router.Use(c, ru)
22 | var ac = controller.NewAuthController()
23 | var sc = controller.NewSessionController()
24 |
25 | v1 := router.Group("/gortas/v1")
26 | {
27 | auth := v1.Group("/auth")
28 | {
29 | route := "/:flow"
30 | auth.GET(route, ac.Auth)
31 | auth.POST(route, ac.Auth)
32 | }
33 | session := v1.Group("/session")
34 | session.GET("/info", sc.SessionInfo)
35 | session.GET("/jwt", sc.SessionJwt)
36 |
37 | }
38 | return router
39 | }
40 |
41 | func RunServer() {
42 | ac := config.GetConfig()
43 | router := SetupRouter(&ac)
44 | err := router.Run(":" + "8080")
45 | if err != nil {
46 | panic(err)
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/pkg/session/config.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | type Config struct {
4 | Type string `yaml:"type"`
5 | Expires int `yaml:"expires"`
6 | Jwt JWT `yaml:"jwt,omitempty"`
7 | DataStore DataStore `yaml:"dataStore,omitempty"`
8 | }
9 |
10 | type JWT struct {
11 | Issuer string `yaml:"issuer"`
12 | PrivateKeyPem string `yml:"privateKeyPem"`
13 | }
14 |
15 | type DataStore struct {
16 | Type string
17 | Properties map[string]string
18 | }
19 |
--------------------------------------------------------------------------------
/pkg/session/service.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "crypto/rsa"
5 | "crypto/x509"
6 | "encoding/pem"
7 | "errors"
8 | "math/rand"
9 | "time"
10 |
11 | jwt "github.com/dgrijalva/jwt-go"
12 | "github.com/google/uuid"
13 | "github.com/maximthomas/gortas/pkg/user"
14 | "github.com/mitchellh/mapstructure"
15 | )
16 |
17 | type Service struct {
18 | repo sessionRepository
19 | sessionType string
20 | jwt Jwt
21 | expires int
22 | }
23 |
24 | type Jwt struct {
25 | PrivateKeyID string
26 | Issuer string
27 | PrivateKey *rsa.PrivateKey
28 | PublicKey *rsa.PublicKey
29 | }
30 |
31 | func (ss *Service) CreateSession(session Session) (Session, error) {
32 | return ss.repo.CreateSession(session)
33 | }
34 |
35 | func (ss *Service) DeleteSession(id string) error {
36 | return ss.repo.DeleteSession(id)
37 | }
38 |
39 | func (ss *Service) GetSession(id string) (Session, error) {
40 | return ss.repo.GetSession(id)
41 | }
42 |
43 | func (ss *Service) UpdateSession(session Session) error {
44 | return ss.repo.UpdateSession(session)
45 | }
46 |
47 | func (ss *Service) ConvertSessionToJwt(sessID string) (string, error) {
48 | sess, err := ss.GetSession(sessID)
49 | if err != nil {
50 | return "", err
51 | }
52 |
53 | token := jwt.New(jwt.SigningMethodRS256)
54 | claims := token.Claims.(jwt.MapClaims)
55 | exp := time.Second * time.Duration(rand.Intn(ss.expires))
56 | claims["exp"] = time.Now().Add(exp).Unix()
57 | claims["jti"] = ss.jwt.PrivateKeyID
58 | claims["iat"] = time.Now().Unix()
59 | claims["iss"] = ss.jwt.Issuer
60 | claims["sub"] = sess.GetUserID()
61 | claims["props"] = sess.Properties
62 | token.Header["jks"] = ss.jwt.PrivateKeyID
63 | return token.SignedString(ss.jwt.PrivateKey)
64 | }
65 |
66 | func (ss *Service) CreateUserSession(userID string) (sessID string, err error) {
67 | var sessionID string
68 | u, userExists := user.GetUserService().GetUser(userID)
69 | if ss.sessionType == "stateless" {
70 | token := jwt.New(jwt.SigningMethodRS256)
71 | claims := token.Claims.(jwt.MapClaims)
72 | exp := time.Second * time.Duration(rand.Intn(ss.expires))
73 | claims["exp"] = time.Now().Add(exp).Unix()
74 | claims["jti"] = ss.jwt.PrivateKeyID
75 | claims["iat"] = time.Now().Unix()
76 | claims["iss"] = ss.jwt.Issuer
77 | claims["sub"] = userID
78 | if userExists {
79 | claims["props"] = u.Properties
80 | }
81 |
82 | token.Header["jks"] = ss.jwt.PrivateKeyID
83 | ss, _ := token.SignedString(ss.jwt.PrivateKey)
84 | sessionID = ss
85 | } else {
86 | sessionID = uuid.New().String()
87 | newSession := Session{
88 | ID: sessionID,
89 | Properties: map[string]string{
90 | "userId": u.ID,
91 | "sub": userID,
92 | },
93 | }
94 | if userExists {
95 | for k, v := range u.Properties {
96 | newSession.Properties[k] = v
97 | }
98 | }
99 |
100 | newSession, err = ss.CreateSession(newSession)
101 | if err != nil {
102 | return sessID, err
103 | }
104 | }
105 | return sessionID, nil
106 | }
107 |
108 | func (ss *Service) GetSessionData(sessionID string) (sess map[string]interface{}, err error) {
109 | sess = make(map[string]interface{})
110 |
111 | if ss.sessionType == "stateless" {
112 | publicKey := ss.jwt.PublicKey
113 | claims := jwt.MapClaims{}
114 | _, err = jwt.ParseWithClaims(sessionID, claims, func(token *jwt.Token) (interface{}, error) {
115 | return publicKey, nil
116 | })
117 | if err != nil {
118 | return sess, err
119 | }
120 | sess = claims
121 | } else {
122 | var statefulSession Session
123 | statefulSession, err = ss.GetSession(sessionID)
124 | if statefulSession.GetUserID() == "" {
125 | return sess, errors.New("user session not found")
126 | }
127 | if err != nil {
128 | return sess, err
129 | }
130 | sess["id"] = statefulSession.ID
131 | sess["created"] = statefulSession.CreatedAt
132 | sess["properties"] = statefulSession.Properties
133 | }
134 | return sess, err
135 | }
136 |
137 | func (ss *Service) GetJwtPublicKey() *rsa.PublicKey {
138 | return ss.jwt.PublicKey
139 | }
140 |
141 | var ss Service
142 |
143 | func InitSessionService(sc *Config) error {
144 | newSs, err := newSessionServce(sc)
145 | if err != nil {
146 | return err
147 | }
148 | ss = newSs
149 | return nil
150 | }
151 |
152 | func newSessionServce(sc *Config) (ss Service, err error) {
153 | token := sc.Jwt
154 |
155 | if token.PrivateKeyPem != "" {
156 | var privateKey *rsa.PrivateKey
157 | privateKeyBlock, _ := pem.Decode([]byte(token.PrivateKeyPem))
158 | privateKey, err = x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
159 | if err != nil {
160 | return ss, err
161 | }
162 | ss.jwt.PrivateKey = privateKey
163 | ss.jwt.PublicKey = &privateKey.PublicKey
164 | ss.jwt.PrivateKeyID = uuid.New().String()
165 | ss.jwt.Issuer = token.Issuer
166 | }
167 |
168 | if sc.DataStore.Type == "mongo" {
169 | prop := sc.DataStore.Properties
170 | params := make(map[string]string)
171 | err = mapstructure.Decode(&prop, ¶ms)
172 | if err != nil {
173 | return ss, err
174 | }
175 | url := params["url"]
176 | db := params["database"]
177 | col := params["collection"]
178 | ss.repo, err = newMongoSessionRepository(url, db, col)
179 | if err != nil {
180 | return ss, err
181 | }
182 | } else {
183 | ss.repo = newInMemorySessionRepository()
184 | }
185 | ss.sessionType = sc.Type
186 | ss.expires = sc.Expires
187 | return ss, err
188 | }
189 |
190 | func GetSessionService() *Service {
191 | return &ss
192 | }
193 |
194 | func SetSessionService(newSs *Service) {
195 | ss = *newSs
196 | }
197 |
--------------------------------------------------------------------------------
/pkg/session/session.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import "time"
4 |
5 | // SessionDataStore struct represents session object from session service
6 | type Session struct {
7 | ID string `json:"id,omitempty"`
8 | CreatedAt time.Time `json:"createdat,omitempty" bson:"createdAt"`
9 | Properties map[string]string `json:"properties,omitempty"`
10 | }
11 |
12 | func (s *Session) GetUserID() string {
13 | userID, ok := s.Properties["sub"]
14 | if !ok {
15 | return ""
16 | }
17 | return userID
18 | }
19 |
20 | func (s *Session) SetUserID(userID string) {
21 | s.Properties["sub"] = userID
22 | }
23 |
24 | func (s *Session) GetRealm() string {
25 | realm, ok := s.Properties["realm"]
26 | if !ok {
27 | return ""
28 | }
29 | return realm
30 | }
31 |
--------------------------------------------------------------------------------
/pkg/session/session_repository.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "errors"
5 | "time"
6 |
7 | "github.com/sirupsen/logrus"
8 |
9 | "github.com/google/uuid"
10 | )
11 |
12 | type sessionRepository interface {
13 | CreateSession(session Session) (Session, error)
14 | DeleteSession(id string) error
15 | GetSession(id string) (Session, error)
16 | UpdateSession(session Session) error
17 | }
18 |
19 | type inMemorySessionRepository struct {
20 | sessions map[string]Session
21 | logger logrus.FieldLogger
22 | }
23 |
24 | func (sr *inMemorySessionRepository) CreateSession(session Session) (Session, error) {
25 | if session.ID == "" {
26 | session.ID = uuid.New().String()
27 | }
28 | session.CreatedAt = time.Now()
29 | sr.sessions[session.ID] = session
30 | return session, nil
31 | }
32 |
33 | func (sr *inMemorySessionRepository) DeleteSession(id string) error {
34 | if _, ok := sr.sessions[id]; ok {
35 | delete(sr.sessions, id)
36 | return nil
37 | }
38 | return errors.New("session does not exist")
39 | }
40 |
41 | func (sr *inMemorySessionRepository) GetSession(id string) (Session, error) {
42 | if session, ok := sr.sessions[id]; ok {
43 | return session, nil
44 | }
45 | return Session{}, errors.New("session does not exist")
46 | }
47 |
48 | func (sr *inMemorySessionRepository) UpdateSession(session Session) error {
49 | if _, ok := sr.sessions[session.ID]; ok {
50 | sr.sessions[session.ID] = session
51 | return nil
52 | }
53 | return errors.New("session does not exist")
54 | }
55 |
56 | const cleanupIntervalSeconds = 10
57 |
58 | func (sr *inMemorySessionRepository) cleanupExpired() {
59 | ticker := time.NewTicker(time.Second * cleanupIntervalSeconds)
60 | defer ticker.Stop()
61 | for {
62 | <-ticker.C
63 | for k := range sr.sessions {
64 | sess := sr.sessions[k]
65 | if (sess.CreatedAt.Second() + 60*60*24) < time.Now().Second() {
66 | sr.logger.Infof("delete session %s due to timeout", sess.ID)
67 | delete(sr.sessions, k)
68 | }
69 | }
70 | }
71 | }
72 |
73 | func newInMemorySessionRepository() sessionRepository {
74 | repo := &inMemorySessionRepository{
75 | sessions: make(map[string]Session),
76 | }
77 |
78 | go repo.cleanupExpired()
79 | return repo
80 | }
81 |
--------------------------------------------------------------------------------
/pkg/session/session_repository_mongo.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "log"
6 | "time"
7 |
8 | "go.mongodb.org/mongo-driver/bson"
9 |
10 | "github.com/google/uuid"
11 |
12 | "go.mongodb.org/mongo-driver/mongo"
13 | "go.mongodb.org/mongo-driver/mongo/options"
14 | )
15 |
16 | type mongoSessionRepository struct {
17 | client *mongo.Client
18 | db string
19 | collection string
20 | }
21 |
22 | type mongoRepoSession struct {
23 | Session `bson:",inline"`
24 | }
25 |
26 | const mongoSessionExpireSeconds = 24
27 |
28 | func newMongoSessionRepository(uri, db, c string) (*mongoSessionRepository, error) {
29 | ctx, cancel := context.WithTimeout(context.Background(), cleanupIntervalSeconds*time.Second)
30 | defer cancel()
31 | log.Printf("connecting to mongo, uri: %v", uri)
32 | client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
33 | if err != nil {
34 | return nil, err
35 | }
36 | idxOpt := options.Index()
37 | idxOpt.SetExpireAfterSeconds(60 * 60 * mongoSessionExpireSeconds)
38 | mod := mongo.IndexModel{
39 | Keys: bson.M{
40 | "createdAt": 1, // index in ascending order
41 | }, Options: idxOpt,
42 | }
43 |
44 | rep := &mongoSessionRepository{
45 | client: client,
46 | db: db,
47 | collection: c,
48 | }
49 |
50 | _, err = rep.getCollection().Indexes().CreateOne(ctx, mod)
51 | if err != nil {
52 | return nil, err
53 | }
54 | return rep, nil
55 |
56 | }
57 |
58 | func (sr *mongoSessionRepository) CreateSession(session Session) (Session, error) {
59 | if session.ID == "" {
60 | session.ID = uuid.New().String()
61 | }
62 |
63 | session.CreatedAt = time.Now()
64 |
65 | repoSession := mongoRepoSession{
66 | Session: session,
67 | }
68 |
69 | collection := sr.getCollection()
70 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
71 | defer cancel()
72 |
73 | _, err := collection.InsertOne(ctx, &repoSession)
74 | if err != nil {
75 | return session, err
76 | }
77 |
78 | return session, nil
79 | }
80 |
81 | func (sr *mongoSessionRepository) DeleteSession(id string) error {
82 | collection := sr.getCollection()
83 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
84 | defer cancel()
85 | filter := bson.M{"id": id}
86 | err := collection.FindOneAndDelete(ctx, filter).Err()
87 | if err != nil {
88 | return err
89 | }
90 | return nil
91 | }
92 |
93 | func (sr *mongoSessionRepository) GetSession(id string) (Session, error) {
94 | var session Session
95 | collection := sr.getCollection()
96 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
97 | defer cancel()
98 | filter := bson.M{"id": id}
99 |
100 | var repoSession mongoRepoSession
101 | err := collection.FindOne(ctx, filter).Decode(&repoSession)
102 |
103 | if err != nil {
104 | return session, err
105 | }
106 |
107 | return repoSession.Session, nil
108 | }
109 |
110 | func (sr *mongoSessionRepository) UpdateSession(session Session) error {
111 | collection := sr.getCollection()
112 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
113 | defer cancel()
114 |
115 | filter := bson.M{"id": session.ID}
116 | var repoSession mongoRepoSession
117 | err := collection.FindOneAndUpdate(ctx, filter, bson.M{"$set": session}).Decode(&repoSession)
118 |
119 | if err != nil {
120 | return err
121 | }
122 | return nil
123 | }
124 |
125 | func (sr *mongoSessionRepository) getCollection() *mongo.Collection {
126 | return sr.client.Database(sr.db).Collection(sr.collection)
127 | }
128 |
--------------------------------------------------------------------------------
/pkg/session/session_repository_mongo_test.go:
--------------------------------------------------------------------------------
1 | //go:build integration
2 | // +build integration
3 |
4 | package session
5 |
6 | import (
7 | "context"
8 | "os"
9 | "os/exec"
10 | "testing"
11 | "time"
12 |
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | const testSessionID = "c48abbfc-93f9-46d6-b568-5a9d8394a156"
17 |
18 | func TestGetSession(t *testing.T) {
19 |
20 | repo := getRepo(t, true)
21 | t.Run("test get session", func(t *testing.T) {
22 | sess, err := repo.GetSession(testSessionID)
23 | assert.NoError(t, err)
24 | assert.Equal(t, testSessionID, sess.ID)
25 | })
26 | t.Run("test get not existing session", func(t *testing.T) {
27 | sess, err := repo.GetSession("bad")
28 | assert.Error(t, err)
29 | assert.Empty(t, sess.ID)
30 | })
31 | }
32 |
33 | func TestCreateSession(t *testing.T) {
34 | getRepo(t, true)
35 | repo := getRepo(t, false)
36 | t.Run("test create session", func(t *testing.T) {
37 | session := models.Session{
38 | Properties: map[string]string{
39 | "prop1": "value",
40 | },
41 | }
42 | newSession, err := repo.CreateSession(session)
43 | assert.NoError(t, err)
44 | assert.NotEmpty(t, newSession.ID)
45 | })
46 | }
47 |
48 | func TestDeleteSession(t *testing.T) {
49 | repo := getRepo(t, true)
50 | t.Run("test delete session", func(t *testing.T) {
51 | _, err := repo.GetSession(testSessionID)
52 | assert.NoError(t, err)
53 |
54 | err = repo.DeleteSession(testSessionID)
55 | assert.NoError(t, err)
56 |
57 | _, err = repo.GetSession(testSessionID)
58 | assert.Error(t, err)
59 |
60 | })
61 | }
62 |
63 | func TestUpdateSession(t *testing.T) {
64 | repo := getRepo(t, true)
65 | t.Run("test update session", func(t *testing.T) {
66 | sess, err := repo.GetSession(testSessionID)
67 | assert.NoError(t, err)
68 | assert.Equal(t, testSessionID, sess.ID)
69 | (&sess).Properties["prop2"] = "value2"
70 | err = repo.UpdateSession(sess)
71 | assert.NoError(t, err)
72 | newSess, _ := repo.GetSession(testSessionID)
73 | assert.Equal(t, sess.Properties["prop2"], newSess.Properties["prop2"])
74 |
75 | })
76 | t.Run("test get not existing session", func(t *testing.T) {
77 | err := repo.UpdateSession(models.Session{
78 | ID: "bad",
79 | Properties: nil,
80 | })
81 | assert.Error(t, err)
82 | })
83 | }
84 |
85 | func TestGepRepoMultipleTimes(t *testing.T) {
86 | _, err := NewMongoSessionRepository("mongodb://root:changeme@localhost:27017", "test_sessions", "sessions")
87 | assert.NoError(t, err)
88 |
89 | _, err = NewMongoSessionRepository("mongodb://root:changeme@localhost:27017", "test_sessions", "sessions")
90 | assert.NoError(t, err)
91 | }
92 |
93 | func getRepo(t *testing.T, drop bool) SessionRepository {
94 | repo, err := NewMongoSessionRepository("mongodb://root:changeme@localhost:27017", "test_sessions", "sessions")
95 | assert.NoError(t, err)
96 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
97 | defer cancel()
98 | if drop {
99 | err = repo.client.Database(repo.db).Drop(ctx)
100 | assert.NoError(t, err)
101 | }
102 | session := models.Session{
103 | ID: testSessionID,
104 | Properties: map[string]string{
105 | "prop1": "value",
106 | },
107 | }
108 | _, err = repo.CreateSession(session)
109 | assert.NoError(t, err)
110 |
111 | return repo
112 | }
113 |
114 | func checkErr(err error) {
115 | if err != nil {
116 | panic(err)
117 | }
118 | }
119 |
120 | func startDockerCompose() {
121 | dir, err := os.Getwd()
122 | checkErr(err)
123 | cmd := exec.Command("docker-compose", "-f", "docker-compose-mongodb-ldap.yaml", "--project-directory", dir, "up", "-d")
124 | cmd.Stdout = os.Stdout
125 | cmd.Stderr = os.Stderr
126 | err = cmd.Run()
127 | checkErr(err)
128 | }
129 |
130 | func stopDockerCompose() {
131 | dir, err := os.Getwd()
132 | checkErr(err)
133 | cmd := exec.Command("docker-compose", "-f", "docker-compose-mongodb-ldap.yaml", "--project-directory", dir, "down")
134 | cmd.Stdout = os.Stdout
135 | cmd.Stderr = os.Stderr
136 | err = cmd.Run()
137 | checkErr(err)
138 | }
139 |
--------------------------------------------------------------------------------
/pkg/session/session_rest.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "io"
8 | "log"
9 | "net/http"
10 | )
11 |
12 | type restSessionRepository struct {
13 | Endpoint string
14 | client http.Client
15 | }
16 |
17 | func (sr *restSessionRepository) CreateSession(session Session) (Session, error) {
18 | var newSession Session
19 | sessBytes, err := json.Marshal(session)
20 | if err != nil {
21 | return newSession, err
22 | }
23 | buf := bytes.NewBuffer(sessBytes)
24 |
25 | resp, err := sr.client.Post(sr.Endpoint, "application/json", buf)
26 | if err != nil {
27 | log.Printf("error creating session: %v", err)
28 | return newSession, err
29 | }
30 | defer resp.Body.Close()
31 |
32 | body, err := io.ReadAll(resp.Body)
33 | if err != nil {
34 | log.Printf("error creating session: %v", err)
35 | return newSession, err
36 | }
37 |
38 | err = json.Unmarshal(body, &newSession)
39 | if err != nil {
40 | log.Printf("error creating session: %v", err)
41 | return newSession, err
42 | }
43 | log.Printf("created new session: %v", newSession)
44 | return newSession, err
45 | }
46 |
47 | func (sr *restSessionRepository) DeleteSession(id string) error {
48 | ctx := context.Background()
49 | req, err := http.NewRequestWithContext(ctx, http.MethodDelete, sr.Endpoint+"/"+id, http.NoBody)
50 | if err != nil {
51 | return err
52 | }
53 |
54 | resp, err := sr.client.Do(req)
55 | if err != nil {
56 | return err
57 | }
58 | defer resp.Body.Close()
59 |
60 | return err
61 | }
62 |
63 | func (sr *restSessionRepository) UpdateSession(id string, session Session) error {
64 | return nil
65 | }
66 |
--------------------------------------------------------------------------------
/pkg/user/config.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | type Config struct {
4 | Type string `yaml:"type"`
5 | Properties map[string]interface{} `yaml:"properties,omitempty"`
6 | }
7 |
--------------------------------------------------------------------------------
/pkg/user/service.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | import (
4 | "github.com/mitchellh/mapstructure"
5 | )
6 |
7 | type Service struct {
8 | repo userRepository
9 | }
10 |
11 | func (us Service) GetUser(id string) (user User, exists bool) {
12 | return us.repo.GetUser(id)
13 | }
14 |
15 | func (us Service) ValidatePassword(id, password string) (valid bool) {
16 | return us.repo.ValidatePassword(id, password)
17 | }
18 |
19 | func (us Service) CreateUser(user User) (User, error) {
20 | return us.repo.CreateUser(user)
21 | }
22 |
23 | func (us Service) UpdateUser(usr User) error {
24 | return us.repo.UpdateUser(usr)
25 | }
26 |
27 | func (us Service) SetPassword(id, password string) error {
28 | return us.repo.SetPassword(id, password)
29 | }
30 |
31 | var us Service
32 |
33 | func InitUserService(uc Config) error {
34 | newUs, err := newUserService(uc)
35 | if err != nil {
36 | return err
37 | }
38 | us = newUs
39 | return nil
40 | }
41 |
42 | func GetUserService() Service {
43 | return us
44 | }
45 |
46 | func SetUserService(newUs Service) {
47 | us = newUs
48 | }
49 |
50 | func newUserService(uc Config) (us Service, err error) {
51 |
52 | if uc.Type == "ldap" {
53 | prop := uc.Properties
54 | ur := &userLdapRepository{}
55 | err = mapstructure.Decode(prop, ur)
56 | if err != nil {
57 | return us, err
58 | }
59 | us.repo = ur
60 | } else if uc.Type == "mongodb" {
61 | prop := uc.Properties
62 | params := make(map[string]interface{})
63 | err = mapstructure.Decode(&prop, ¶ms)
64 | if err != nil {
65 | return us, err
66 | }
67 | url, _ := params["url"].(string)
68 | db, _ := params["database"].(string)
69 | col, _ := params["collection"].(string)
70 | var ur *userMongoRepository
71 | ur, err = newUserMongoRepository(url, db, col)
72 | if err != nil {
73 | return us, err
74 | }
75 | us.repo = ur
76 | } else {
77 | us.repo = NewInMemoryUserRepository()
78 | }
79 |
80 | return us, err
81 | }
82 |
--------------------------------------------------------------------------------
/pkg/user/user.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | type User struct {
4 | ID string `json:"id,omitempty"`
5 | Realm string `json:"realm,omitempty"`
6 | Roles []string `json:"roles,omitempty"`
7 | Properties map[string]string `json:"properties,omitempty"`
8 | }
9 |
10 | type Password struct {
11 | Password string `json:"password,omitempty"`
12 | }
13 |
14 | type ValidatePasswordResult struct {
15 | Valid bool `json:"valid,omitempty"`
16 | }
17 |
18 | func (u *User) SetProperty(prop, val string) {
19 | if u.Properties == nil {
20 | u.Properties = make(map[string]string)
21 | }
22 | u.Properties[prop] = val
23 | }
24 |
--------------------------------------------------------------------------------
/pkg/user/user_repository.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | import (
4 | "github.com/google/uuid"
5 | )
6 |
7 | type userRepository interface {
8 | GetUser(id string) (User, bool)
9 | ValidatePassword(id, password string) bool
10 | CreateUser(user User) (User, error)
11 | UpdateUser(user User) error
12 | SetPassword(id, password string) error
13 | }
14 |
15 | type inMemoryUserRepository struct {
16 | Users []User
17 | Realm string
18 | passwords map[string]string
19 | }
20 |
21 | func (ur *inMemoryUserRepository) GetUser(id string) (user User, exists bool) {
22 | for _, u := range ur.Users {
23 | if u.ID == id {
24 | user = u
25 | exists = true
26 | break
27 | }
28 | }
29 | return user, exists
30 | }
31 |
32 | func (ur *inMemoryUserRepository) ValidatePassword(id, password string) (valid bool) {
33 | if password == "password" {
34 | valid = true
35 | }
36 | if setPass, ok := ur.passwords[id]; ok {
37 | if setPass == password {
38 | valid = true
39 | }
40 | }
41 | return valid
42 | }
43 |
44 | func (ur *inMemoryUserRepository) CreateUser(user User) (User, error) {
45 | if user.ID == "" {
46 | user.ID = uuid.New().String()
47 | }
48 | ur.Users = append(ur.Users, user)
49 | return user, nil
50 | }
51 |
52 | func (ur *inMemoryUserRepository) UpdateUser(usr User) error {
53 | for i, u := range ur.Users {
54 | if u.ID == usr.ID {
55 | ur.Users[i] = usr
56 | break
57 | }
58 | }
59 | return nil
60 | }
61 |
62 | func (ur *inMemoryUserRepository) SetPassword(id, password string) error {
63 | ur.passwords[id] = password
64 | return nil
65 | }
66 |
67 | func NewInMemoryUserRepository() userRepository {
68 |
69 | ds := &inMemoryUserRepository{}
70 | ds.Users = []User{
71 | {
72 | ID: "user1",
73 | Realm: "users",
74 | Roles: []string{"admin"},
75 | },
76 | {
77 | ID: "user2",
78 | Realm: "users",
79 | Roles: []string{"manager"},
80 | },
81 | {
82 | ID: "staff1",
83 | Realm: "staff",
84 | Roles: []string{"head_of_it"},
85 | },
86 | }
87 | ds.passwords = make(map[string]string)
88 | for _, u := range ds.Users {
89 | ds.passwords[u.ID] = "password"
90 | }
91 | return ds
92 | }
93 |
--------------------------------------------------------------------------------
/pkg/user/user_repository_ldap.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "log"
7 |
8 | "github.com/go-ldap/ldap/v3"
9 | )
10 |
11 | const ldapSearchTimeout = 100
12 |
13 | type userLdapRepository struct {
14 | Address string
15 | BindDN string
16 | Password string
17 | BaseDN string
18 | ObjectClasses []string
19 | UserAttributes []string
20 | }
21 |
22 | func (ur *userLdapRepository) getConnection() (*ldap.Conn, error) {
23 | conn, err := ldap.Dial("tcp", ur.Address)
24 | if err != nil {
25 | return nil, err
26 | }
27 | err = conn.Bind(ur.BindDN, ur.Password)
28 | if err != nil {
29 | return nil, err
30 | }
31 | return conn, nil
32 | }
33 |
34 | func (ur *userLdapRepository) getLdapEntry(id string, conn *ldap.Conn) (*ldap.Entry, error) {
35 | fields := append([]string{"dn", "uid"}, ur.UserAttributes...)
36 | result, err := conn.Search(ldap.NewSearchRequest(
37 | ur.BaseDN,
38 | ldap.ScopeSingleLevel,
39 | ldap.NeverDerefAliases,
40 | 0,
41 | ldapSearchTimeout,
42 | false,
43 | fmt.Sprintf("(uid=%v)", id),
44 | fields,
45 | nil,
46 | ))
47 | if err != nil {
48 | log.Print(err)
49 | return nil, err
50 | }
51 |
52 | if len(result.Entries) != 1 {
53 | return nil, fmt.Errorf("found multiple entries %v", result.Entries)
54 | }
55 |
56 | return result.Entries[0], nil
57 | }
58 |
59 | func (ur *userLdapRepository) GetUser(id string) (user User, exists bool) {
60 | conn, err := ur.getConnection()
61 | if err != nil {
62 | log.Print(err)
63 | return user, exists
64 | }
65 | defer conn.Close()
66 |
67 | entry, err := ur.getLdapEntry(id, conn)
68 | if err != nil {
69 | log.Print(err)
70 | return user, exists
71 | }
72 |
73 | properties := make(map[string]string)
74 | for _, attr := range ur.UserAttributes {
75 | properties[attr] = entry.GetAttributeValue(attr)
76 | }
77 |
78 | user = User{
79 | ID: entry.GetAttributeValue("uid"),
80 | Properties: properties,
81 | }
82 | exists = true
83 |
84 | return user, exists
85 | }
86 |
87 | func (ur *userLdapRepository) ValidatePassword(id, password string) bool {
88 | conn, err := ur.getConnection()
89 | if err != nil {
90 | log.Print(err)
91 | return false
92 | }
93 | defer conn.Close()
94 | entry, err := ur.getLdapEntry(id, conn)
95 | if err != nil {
96 | log.Print(err)
97 | return false
98 | }
99 |
100 | if err := conn.Bind(entry.DN, password); err != nil {
101 | return false
102 | }
103 | return true
104 | }
105 | func (ur *userLdapRepository) CreateUser(user User) (User, error) {
106 | conn, err := ur.getConnection()
107 | if err != nil {
108 | log.Print(err)
109 | return user, err
110 | }
111 | defer conn.Close()
112 | dn := fmt.Sprintf("uid=%v,"+ur.BaseDN, user.ID)
113 | addRequest := ldap.NewAddRequest(dn, nil)
114 | addRequest.Attribute("objectClass", ur.ObjectClasses)
115 | addRequest.Attribute("sn", []string{user.ID})
116 | addRequest.Attribute("cn", []string{user.ID})
117 | err = conn.Add(addRequest)
118 | if err != nil {
119 | log.Print(err)
120 | return user, err
121 | }
122 | return user, err
123 |
124 | }
125 | func (ur *userLdapRepository) UpdateUser(user User) error {
126 | return errors.New("not implemented")
127 | }
128 |
129 | func (ur *userLdapRepository) SetPassword(id, password string) error {
130 | conn, err := ur.getConnection()
131 | if err != nil {
132 | log.Print(err)
133 | return err
134 | }
135 | defer conn.Close()
136 | entry, err := ur.getLdapEntry(id, conn)
137 | if err != nil {
138 | log.Print(err)
139 | return err
140 | }
141 |
142 | passwordModifyRequest := ldap.NewPasswordModifyRequest(entry.DN, "", password)
143 | _, err = conn.PasswordModify(passwordModifyRequest)
144 |
145 | if err != nil {
146 | log.Printf("Password could not be changed: %s", err.Error())
147 | }
148 | return nil
149 | }
150 |
--------------------------------------------------------------------------------
/pkg/user/user_repository_ldap_test.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/google/uuid"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestLdapConnection(t *testing.T) {
12 | t.Skip("mock LDAP later...")
13 | ur := getUserLdapRepository()
14 | conn, err := ur.getConnection()
15 | assert.NoError(t, err)
16 | conn.Close()
17 | }
18 |
19 | func TestGetUser(t *testing.T) {
20 | t.Skip("mock LDAP later...")
21 | ur := getUserLdapRepository()
22 | user, exists := ur.GetUser("jerso")
23 | assert.True(t, exists)
24 | assert.Equal(t, "jerso", user.ID)
25 |
26 | _, exists2 := ur.GetUser("bad")
27 | assert.False(t, exists2)
28 | }
29 |
30 | func TestValidatePassword(t *testing.T) {
31 | t.Skip("mock LDAP later...")
32 | ur := getUserLdapRepository()
33 | err := ur.SetPassword("jerso", "passw0rd")
34 | assert.NoError(t, err)
35 | tests := []struct {
36 | name string
37 | user string
38 | password string
39 | result bool
40 | }{
41 | {"valid password", "jerso", "passw0rd", true},
42 | {"invalid password", "jerso", "bad", false},
43 | {"invalid user", "bad", "passw0rd", false},
44 | }
45 |
46 | for _, tt := range tests {
47 | t.Run(tt.name, func(t *testing.T) {
48 | result := ur.ValidatePassword(tt.user, tt.password)
49 | assert.Equal(t, tt.result, result)
50 | })
51 | }
52 | }
53 |
54 | func TestCreateUser(t *testing.T) {
55 | t.Skip("mock LDAP later...")
56 | ur := getUserLdapRepository()
57 |
58 | userID := uuid.New().String()
59 |
60 | user := User{
61 | ID: userID,
62 | }
63 | _, err := ur.CreateUser(user)
64 | assert.NoError(t, err)
65 |
66 | _, exists := ur.GetUser("jerso")
67 | assert.True(t, exists)
68 | }
69 |
70 | func TestSetPassword(t *testing.T) {
71 | t.Skip("mock LDAP later...")
72 |
73 | ur := getUserLdapRepository()
74 | var user = "jerso"
75 | newPassword := "newPassw0rd"
76 |
77 | err := ur.SetPassword(user, uuid.New().String())
78 | assert.NoError(t, err)
79 |
80 | result := ur.ValidatePassword(user, newPassword)
81 | assert.False(t, result)
82 |
83 | err = ur.SetPassword(user, newPassword)
84 | assert.NoError(t, err)
85 |
86 | result = ur.ValidatePassword(user, newPassword)
87 | assert.True(t, result)
88 | }
89 |
90 | func TestModifyUser(t *testing.T) {
91 | t.Skip("mock LDAP later...")
92 | assert.Fail(t, "not implemented")
93 | }
94 |
95 | func getUserLdapRepository() *userLdapRepository {
96 | return &userLdapRepository{
97 | Address: "localhost:50389",
98 | BindDN: "cn=admin,dc=farawaygalaxy,dc=net",
99 | Password: "passw0rd",
100 | BaseDN: "ou=users,dc=farawaygalaxy,dc=net",
101 | ObjectClasses: []string{"inetOrgPerson"},
102 | UserAttributes: []string{"sn", "cn"},
103 | }
104 | }
105 |
--------------------------------------------------------------------------------
/pkg/user/user_repository_mongo.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | import (
4 | "context"
5 | "log"
6 | "time"
7 |
8 | "github.com/google/uuid"
9 | "go.mongodb.org/mongo-driver/bson"
10 | "go.mongodb.org/mongo-driver/mongo"
11 | "go.mongodb.org/mongo-driver/mongo/options"
12 | "golang.org/x/crypto/bcrypt"
13 | )
14 |
15 | const (
16 | dbSessionExpireHours = 24
17 | mongoDBConnectTimeoutSeconds = 10
18 | )
19 |
20 | type userMongoRepository struct {
21 | client *mongo.Client
22 | db string
23 | collection string
24 | }
25 |
26 | type mongoRepoUser struct {
27 | User `bson:",inline"`
28 | Password string `json:"password,omitempty"`
29 | }
30 |
31 | func (ur *userMongoRepository) GetUser(id string) (User, bool) {
32 | var user User
33 | collection := ur.getCollection()
34 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
35 | defer cancel()
36 | filter := bson.M{"id": id}
37 |
38 | var repoUser mongoRepoUser
39 | err := collection.FindOne(ctx, filter).Decode(&repoUser)
40 |
41 | if err != nil {
42 | return user, false
43 | }
44 |
45 | return repoUser.User, true
46 | }
47 |
48 | func (ur *userMongoRepository) ValidatePassword(id, password string) bool {
49 | collection := ur.getCollection()
50 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
51 | defer cancel()
52 | filter := bson.M{"id": id}
53 | var repoUser mongoRepoUser
54 | err := collection.FindOne(ctx, filter).Decode(&repoUser)
55 | var valid bool
56 | if err != nil {
57 | return valid
58 | }
59 |
60 | // Comparing the password with the hash
61 | err = bcrypt.CompareHashAndPassword([]byte(repoUser.Password), []byte(password))
62 | if err == nil {
63 | valid = true
64 | }
65 | // valid = repoUser.Password == password
66 |
67 | return valid
68 | }
69 |
70 | func (ur *userMongoRepository) CreateUser(user User) (User, error) {
71 | if user.ID == "" {
72 | user.ID = uuid.New().String()
73 | }
74 | repoUser := mongoRepoUser{
75 | User: user,
76 | Password: "",
77 | }
78 | collection := ur.getCollection()
79 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
80 | defer cancel()
81 |
82 | _, err := collection.InsertOne(ctx, &repoUser)
83 | if err != nil {
84 | return user, err
85 | }
86 |
87 | return user, nil
88 | }
89 |
90 | func (ur *userMongoRepository) UpdateUser(user User) error {
91 | collection := ur.getCollection()
92 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
93 | defer cancel()
94 |
95 | filter := bson.M{"id": user.ID, "realm": user.Realm}
96 | var updatedUser mongoRepoUser
97 | err := collection.FindOneAndUpdate(ctx, filter, bson.M{"$set": user}).Decode(&updatedUser)
98 |
99 | if err != nil {
100 | return err
101 | }
102 |
103 | return nil
104 | }
105 |
106 | func (ur *userMongoRepository) SetPassword(id, password string) error {
107 | collection := ur.getCollection()
108 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
109 | defer cancel()
110 |
111 | hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
112 | if err != nil {
113 | return err
114 | }
115 |
116 | filter := bson.M{"id": id}
117 | var updatedUser mongoRepoUser
118 | err = collection.FindOneAndUpdate(ctx, filter, bson.M{"$set": bson.M{"password": hashedPassword}}).Decode(&updatedUser)
119 |
120 | if err != nil {
121 | return err
122 | }
123 |
124 | return nil
125 | }
126 |
127 | func newUserMongoRepository(uri, db, c string) (*userMongoRepository, error) {
128 | ctx, cancel := context.WithTimeout(context.Background(), mongoDBConnectTimeoutSeconds*time.Second)
129 | defer cancel()
130 | log.Printf("connecting to mongo, uri: %v", uri)
131 | client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
132 | if err != nil {
133 | return nil, err
134 | }
135 | idxOpt := options.Index()
136 | idxOpt.SetExpireAfterSeconds(60 * 60 * dbSessionExpireHours)
137 | mod := mongo.IndexModel{
138 | Keys: bson.M{
139 | "createdAt": 1, // index in ascending order
140 | }, Options: idxOpt,
141 | }
142 |
143 | rep := &userMongoRepository{
144 | client: client,
145 | db: db,
146 | collection: c,
147 | }
148 |
149 | _, err = rep.getCollection().Indexes().CreateOne(ctx, mod)
150 | if err != nil {
151 | return nil, err
152 | }
153 | return rep, nil
154 | }
155 |
156 | func (ur *userMongoRepository) getCollection() *mongo.Collection {
157 | return ur.client.Database(ur.db).Collection(ur.collection)
158 | }
159 |
--------------------------------------------------------------------------------
/pkg/user/user_repository_mongo_test.go:
--------------------------------------------------------------------------------
1 | //go:build integration
2 | // +build integration
3 |
4 | package user
5 |
6 | import (
7 | "context"
8 | "github.com/google/uuid"
9 | "github.com/stretchr/testify/assert"
10 | "testing"
11 | "time"
12 | )
13 |
14 | const (
15 | testUserId = "user1"
16 | testUser2Id = "user2"
17 | testPassword = "passw0rd"
18 | )
19 |
20 | func TestUserMongoRepository_GetUser(t *testing.T) {
21 | ur := getUserMongoRepo(t, true)
22 | user, exists := ur.GetUser(testUserId)
23 | assert.True(t, exists)
24 | assert.Equal(t, testUserId, user.ID)
25 |
26 | _, exists2 := ur.GetUser("bad")
27 | assert.False(t, exists2)
28 | }
29 |
30 | func TestUserMongoRepository_ValidatePassword(t *testing.T) {
31 | ur := getUserMongoRepo(t, true)
32 | tests := []struct {
33 | name string
34 | user string
35 | password string
36 | result bool
37 | }{
38 | {"valid password", testUserId, testPassword, true},
39 | {"invalid password", testUserId, "bad", false},
40 | {"invalid user", "bad", testPassword, false},
41 | }
42 |
43 | for _, tt := range tests {
44 | t.Run(tt.name, func(t *testing.T) {
45 | result := ur.ValidatePassword(tt.user, tt.password)
46 | assert.Equal(t, tt.result, result)
47 | })
48 | }
49 | }
50 |
51 | func TestUserMongoRepository_CreateUser(t *testing.T) {
52 | user := models.User{
53 | ID: testUser2Id,
54 | }
55 | ur := getUserMongoRepo(t, true)
56 | user, err := ur.CreateUser(user)
57 | assert.NoError(t, err)
58 |
59 | user, exists := ur.GetUser(testUser2Id)
60 | assert.True(t, exists)
61 | }
62 |
63 | func TestUserMongoRepository_SetPassword(t *testing.T) {
64 | var user = testUserId
65 | newPassword := "newPassw0rd"
66 | ur := getUserMongoRepo(t, true)
67 | err := ur.SetPassword(user, uuid.New().String())
68 | assert.NoError(t, err)
69 |
70 | result := ur.ValidatePassword(user, newPassword)
71 | assert.False(t, result)
72 |
73 | err = ur.SetPassword(user, newPassword)
74 | assert.NoError(t, err)
75 |
76 | result = ur.ValidatePassword(user, newPassword)
77 | assert.True(t, result)
78 | }
79 |
80 | func TestUserMongoRepository_ModifyUser(t *testing.T) {
81 | repo := getUserMongoRepo(t, true)
82 | t.Run("test update user", func(t *testing.T) {
83 | user, ok := repo.GetUser(testUserId)
84 | assert.True(t, ok)
85 | assert.Equal(t, testUserId, user.ID)
86 | (&user).Properties["prop2"] = "value2"
87 | err := repo.UpdateUser(user)
88 | assert.NoError(t, err)
89 | newUser, _ := repo.GetUser(testUserId)
90 | assert.Equal(t, user.Properties["prop2"], newUser.Properties["prop2"])
91 |
92 | })
93 | t.Run("test update not existing user", func(t *testing.T) {
94 | err := repo.UpdateUser(models.User{
95 | ID: "bad",
96 | Properties: nil,
97 | })
98 | assert.Error(t, err)
99 | })
100 | }
101 |
102 | func getUserMongoRepo(t *testing.T, drop bool) UserRepository {
103 | repo, err := NewUserMongoRepository("mongodb://root:changeme@localhost:27017", "users", "users")
104 | assert.NoError(t, err)
105 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
106 | defer cancel()
107 | if drop {
108 | err = repo.client.Database(repo.db).Drop(ctx)
109 | assert.NoError(t, err)
110 | }
111 | user := models.User{
112 | ID: testUserId,
113 | Properties: map[string]string{
114 | "prop1": "value",
115 | },
116 | }
117 | _, err = repo.CreateUser(user)
118 | assert.NoError(t, err)
119 | repo.SetPassword(testUserId, testPassword)
120 |
121 | return repo
122 | }
123 |
--------------------------------------------------------------------------------
/pkg/user/user_repository_rest.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "io"
8 | "log"
9 | "net/http"
10 | )
11 |
12 | type userRestRepository struct {
13 | realm string
14 | endpoint string
15 | client http.Client
16 | }
17 |
18 | func (ur *userRestRepository) GetUser(id string) (user User, exists bool) {
19 | ctx := context.Background()
20 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, ur.endpoint+"/users/"+id, http.NoBody)
21 | if err != nil {
22 | log.Printf("error crearing request: %v", err)
23 | return user, exists
24 | }
25 | resp, err := ur.client.Do(req)
26 | if err != nil {
27 | log.Printf("error getting user: %v", err)
28 | return user, exists
29 | }
30 | defer resp.Body.Close()
31 |
32 | if resp.StatusCode >= http.StatusMultipleChoices {
33 | log.Printf("got bad response from user service: %v", resp)
34 | return user, exists
35 | }
36 |
37 | body, err := io.ReadAll(resp.Body)
38 | if err != nil {
39 | log.Printf("error getting user: %v", err)
40 | return user, exists
41 | }
42 |
43 | err = json.Unmarshal(body, &user)
44 | if err != nil {
45 | log.Printf("error unmarshalling user: %v", err)
46 | return user, exists
47 | }
48 | log.Printf("got user user: %v", user)
49 | exists = true
50 | return user, exists
51 | }
52 |
53 | func (ur *userRestRepository) ValidatePassword(id, password string) (valid bool) {
54 | pr := Password{
55 | Password: password,
56 | }
57 |
58 | prBytes, err := json.Marshal(pr)
59 | if err != nil {
60 | return valid
61 | }
62 |
63 | buf := bytes.NewBuffer(prBytes)
64 | ctx := context.Background()
65 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, ur.endpoint+"/users/"+id+"/validatepassword", buf)
66 | if err != nil {
67 | log.Printf("error crearing request: %v", err)
68 | return valid
69 | }
70 | req.Header.Set("Content-Type", "application/json")
71 |
72 | resp, err := ur.client.Do(req)
73 | if err != nil {
74 | log.Printf("error validating password: %v", err)
75 | return valid
76 | }
77 | defer resp.Body.Close()
78 |
79 | if resp.StatusCode >= http.StatusMultipleChoices {
80 | log.Printf("got bad response from user service: %v", resp)
81 | return valid
82 | }
83 |
84 | body, err := io.ReadAll(resp.Body)
85 | if err != nil {
86 | log.Printf("error validating password: %v", err)
87 | return valid
88 | }
89 | var vpr ValidatePasswordResult
90 |
91 | err = json.Unmarshal(body, &vpr)
92 | if err != nil {
93 | log.Printf("error validating password: %v", err)
94 | return valid
95 | }
96 | valid = vpr.Valid
97 |
98 | log.Printf("password validation result for user: %v %v", id, valid)
99 |
100 | return valid
101 | }
102 |
103 | func (ur *userRestRepository) CreateUser(user User) (User, error) {
104 | return user, nil
105 | }
106 |
107 | func (ur *userRestRepository) UpdateUser(user User) error {
108 | return nil
109 | }
110 | func (ur *userRestRepository) SetPassword(id, password string) error {
111 | return nil
112 | }
113 |
114 | func newUserRestRepository(realm, endpoint string) userRepository {
115 | return &userRestRepository{
116 | realm: realm,
117 | endpoint: endpoint,
118 | }
119 | }
120 |
--------------------------------------------------------------------------------
/test/auth-config-dev.yaml:
--------------------------------------------------------------------------------
1 | flows:
2 | login:
3 | modules:
4 | - id: "login"
5 | type: "login"
6 | properties:
7 | registration:
8 | modules:
9 | - id: "registration"
10 | type: "registration"
11 | properties:
12 | primaryField:
13 | name: "login"
14 | prompt: "Login"
15 | additionalFileds:
16 | - dataStore: "name"
17 | prompt: "Name"
18 | qr:
19 | modules:
20 | - id: "qr"
21 | type: "qr"
22 |
23 | kerberos:
24 | modules:
25 | - id: "kerberos"
26 | type: "kerberos"
27 | properties:
28 | keyTabFile: ""
29 | servicePrincipal: ""
30 |
31 | userDataStore:
32 | type: "inMemory"
33 |
34 | session:
35 | type: "stateful"
36 | expires: 60000
37 | jwt:
38 | issuer: 'http://gortas'
39 | privateKeyPem: |
40 | -----BEGIN RSA PRIVATE KEY-----
41 | MIIBOQIBAAJATmLeD2qa5ejVKJ3rwcSJaZAeRw4CVrUHvi1uVvBah6+6qCdjvH8N
42 | RT+GOI3ymdnilILPHcn51A0XQAXyrvFkgwIDAQABAkAPZUvIK2ARGBIF0D6l6Dw1
43 | B6Fqw02iShwjNjkdykd9rsZ+UwsYHJ9xXSa2xp7eGurIUqyaDxF+53xpE9AH72PB
44 | AiEAlEOIScKvyIqp3ZAxjYUd3feke2AGq4ckoq/dXFvxKHcCIQCHWH+6xKyXqaDL
45 | bG5rq18VQR2Nj7VknY4Eir6Z6LrzVQIgSz3WbXBi2wgb2ngx3ZsfpCToEUCTQftM
46 | iU9srFFwmlMCIFPUbMixqHUHi6BzuLDXpDz15+gWarO3Io+NoCCUFbdBAiEAinVf
47 | Lnb+YDP3L5ZzSNF92P9yBQaopFCifjrUqSS85uw=
48 | -----END RSA PRIVATE KEY-----
49 |
50 | dataStore:
51 | type: "inMemory"
52 |
53 | server:
54 | cors:
55 | allowedOrigins:
56 | - http://localhost:3000
57 |
--------------------------------------------------------------------------------
/test/auth-config-hydra.yaml:
--------------------------------------------------------------------------------
1 | flows:
2 | login:
3 | modules:
4 | - id: "hydra"
5 | type: "hydra"
6 | properties:
7 | uri: "https://localhost:4445"
8 | skipTLS: true #disable in production!
9 | - id: "login"
10 | type: "login"
11 | properties:
12 |
13 | registration:
14 | modules:
15 | - id: "hydra"
16 | type: "hydra"
17 | properties:
18 | uri: "https://localhost:4445"
19 | skipTLS: true #disable in production!
20 | - id: "registration"
21 | type: "registration"
22 | properties:
23 | additionalFileds:
24 | - dataStore: "name"
25 | prompt: "Name"
26 |
27 | userDataStore:
28 | type: "inMemory"
29 |
30 | session:
31 | type: "stateless" #could be also stateful, implement later
32 | expires: 60000
33 | jwt:
34 | issuer: 'http://gortas'
35 | privateKeyPem: |
36 | -----BEGIN RSA PRIVATE KEY-----
37 | MIIBOQIBAAJATmLeD2qa5ejVKJ3rwcSJaZAeRw4CVrUHvi1uVvBah6+6qCdjvH8N
38 | RT+GOI3ymdnilILPHcn51A0XQAXyrvFkgwIDAQABAkAPZUvIK2ARGBIF0D6l6Dw1
39 | B6Fqw02iShwjNjkdykd9rsZ+UwsYHJ9xXSa2xp7eGurIUqyaDxF+53xpE9AH72PB
40 | AiEAlEOIScKvyIqp3ZAxjYUd3feke2AGq4ckoq/dXFvxKHcCIQCHWH+6xKyXqaDL
41 | bG5rq18VQR2Nj7VknY4Eir6Z6LrzVQIgSz3WbXBi2wgb2ngx3ZsfpCToEUCTQftM
42 | iU9srFFwmlMCIFPUbMixqHUHi6BzuLDXpDz15+gWarO3Io+NoCCUFbdBAiEAinVf
43 | Lnb+YDP3L5ZzSNF92P9yBQaopFCifjrUqSS85uw=
44 | -----END RSA PRIVATE KEY-----
45 |
46 | dataStore:
47 | type: "inMemory"
48 |
49 | server:
50 | cors:
51 | allowedOrigins:
52 | - http://localhost:3000
53 |
--------------------------------------------------------------------------------
/test/auth-config-otp.yaml:
--------------------------------------------------------------------------------
1 | flows:
2 | otp:
3 | modules:
4 | - id: "email"
5 | type: "credentials"
6 | properties:
7 | primaryField:
8 | name: "email"
9 | prompt: "Email"
10 | required: true
11 |
12 | - id: "otp"
13 | type: "otp"
14 | properties:
15 | otpLength: 4
16 | useLetters: false
17 | useDigits: true
18 | otpTimeoutSec: 180
19 | otpResendSec: 90
20 | otpRetryCount: 5
21 | otpMessageTemplate: Code {{.OTP}} valid for {{.ValidFor}} min
22 | sender:
23 | senderType: "test"
24 | properties:
25 | host: "localhost"
26 | port: 1234
27 |
28 | userDataStore:
29 | type: "inMemory"
30 |
31 | session:
32 | type: "stateless" #could be also stateful
33 | expires: 60000
34 | jwt:
35 | issuer: 'http://gortas'
36 | privateKeyPem: | #generate your own: openssl genrsa -out key.pem 2048
37 | -----BEGIN RSA PRIVATE KEY-----
38 | MIIBOQIBAAJATmLeD2qa5ejVKJ3rwcSJaZAeRw4CVrUHvi1uVvBah6+6qCdjvH8N
39 | RT+GOI3ymdnilILPHcn51A0XQAXyrvFkgwIDAQABAkAPZUvIK2ARGBIF0D6l6Dw1
40 | B6Fqw02iShwjNjkdykd9rsZ+UwsYHJ9xXSa2xp7eGurIUqyaDxF+53xpE9AH72PB
41 | AiEAlEOIScKvyIqp3ZAxjYUd3feke2AGq4ckoq/dXFvxKHcCIQCHWH+6xKyXqaDL
42 | bG5rq18VQR2Nj7VknY4Eir6Z6LrzVQIgSz3WbXBi2wgb2ngx3ZsfpCToEUCTQftM
43 | iU9srFFwmlMCIFPUbMixqHUHi6BzuLDXpDz15+gWarO3Io+NoCCUFbdBAiEAinVf
44 | Lnb+YDP3L5ZzSNF92P9yBQaopFCifjrUqSS85uw=
45 | -----END RSA PRIVATE KEY-----
46 |
47 | dataStore:
48 | type: "inMemory"
49 |
50 | server:
51 | cors:
52 | allowedOrigins:
53 | - http://localhost:3000
54 |
--------------------------------------------------------------------------------
/test/docker-compose-ldap.yaml:
--------------------------------------------------------------------------------
1 | version: '3.7'
2 | services:
3 | ldap:
4 | image: kazhar/openldap-demo
5 | ports:
6 | - "50389:389"
--------------------------------------------------------------------------------
/test/docker-compose-mongodb-ldap.yaml:
--------------------------------------------------------------------------------
1 | version: '3.7'
2 |
3 | services:
4 |
5 | mongo:
6 | image: mongo:latest
7 | restart: always
8 | ports:
9 | - 27017:27017
10 | environment:
11 | MONGO_INITDB_ROOT_USERNAME: root
12 | MONGO_INITDB_ROOT_PASSWORD: changeme
13 |
14 | openldap:
15 | image: kazhar/openldap-demo
16 | ports:
17 | - "50389:389"
18 |
19 | mongo-express:
20 | image: mongo-express
21 | restart: always
22 | ports:
23 | - 8081:8081
24 | environment:
25 | ME_CONFIG_MONGODB_ADMINUSERNAME: root
26 | ME_CONFIG_MONGODB_ADMINPASSWORD: changeme
--------------------------------------------------------------------------------
/test/docker-compose-mongodb.yaml:
--------------------------------------------------------------------------------
1 | version: '3.7'
2 | services:
3 | mongo:
4 | image: mongo:latest
5 | restart: always
6 | ports:
7 | - 27017:27017
8 | environment:
9 | MONGO_INITDB_ROOT_USERNAME: root
10 | MONGO_INITDB_ROOT_PASSWORD: changeme
11 |
--------------------------------------------------------------------------------
/test/docker-compose-pgsql.yaml:
--------------------------------------------------------------------------------
1 | version: '3.7'
2 | services:
3 | postgres:
4 | image: postgres:latest
5 | restart: always
6 | ports:
7 | - 5432:5432
8 | environment:
9 | POSTGRES_PASSWORD: passw0rd
10 | POSTGRES_DB: gortas
11 |
--------------------------------------------------------------------------------
/test/integration/otp_auth_test.go:
--------------------------------------------------------------------------------
1 | package integration_test
2 |
3 | import (
4 | "bytes"
5 | "crypto/rand"
6 | "crypto/rsa"
7 | "crypto/x509"
8 | "encoding/json"
9 | "encoding/pem"
10 | "fmt"
11 | "net/http"
12 | "net/http/httptest"
13 | "strings"
14 | "testing"
15 |
16 | "github.com/gin-gonic/gin"
17 | "github.com/maximthomas/gortas/pkg/auth/callbacks"
18 | "github.com/maximthomas/gortas/pkg/auth/constants"
19 | "github.com/maximthomas/gortas/pkg/auth/modules/otp"
20 | "github.com/maximthomas/gortas/pkg/auth/state"
21 | "github.com/maximthomas/gortas/pkg/config"
22 | "github.com/maximthomas/gortas/pkg/server"
23 | "github.com/maximthomas/gortas/pkg/session"
24 | "github.com/stretchr/testify/assert"
25 | )
26 |
27 | var privateKey, _ = rsa.GenerateKey(rand.Reader, 2048)
28 | var privateKeyStr = string(pem.EncodeToMemory(
29 | &pem.Block{
30 | Type: "RSA PRIVATE KEY",
31 | Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
32 | },
33 | ))
34 | var (
35 | flows = map[string]config.Flow{
36 | "otp": {Modules: []config.Module{
37 | {
38 | ID: "otp_link",
39 | Type: "otp",
40 | Criteria: constants.CriteriaSufficient,
41 | Properties: map[string]interface{}{
42 | "otpCheckMagicLink": true,
43 | }},
44 | {
45 | ID: "phone",
46 | Type: "credentials",
47 | Properties: map[string]interface{}{
48 | "primaryField": map[string]interface{}{
49 | "Name": "phone",
50 | "Prompt": "Phone",
51 | "Required": true,
52 | "Validation": "^\\d{4,20}$",
53 | },
54 | },
55 | },
56 | {
57 | ID: "otp",
58 | Type: "otp",
59 | Properties: map[string]interface{}{
60 | "otpLength": 4,
61 | "useLetters": false,
62 | "useDigits": true,
63 | "otpTimeoutSec": 180,
64 | "otpResendSec": 90,
65 | "otpRetryCount": 5,
66 | "OtpMessageTemplate": "Code {{.OTP}} valid for {{.ValidFor}} min, link code {{.MagicLink}}",
67 | "sender": map[string]interface{}{
68 | "senderType": "test",
69 | "properties": map[string]interface{}{
70 | "host": "localhost",
71 | "port": 1234,
72 | },
73 | },
74 | },
75 | },
76 | }},
77 | }
78 |
79 | conf = config.Config{
80 | Flows: flows,
81 | Session: session.Config{
82 | Type: "stateless",
83 | Expires: 60000,
84 | Jwt: session.JWT{
85 | Issuer: "http://gortas",
86 | PrivateKeyPem: privateKeyStr,
87 | },
88 | },
89 | EncryptionKey: "Gb8l9wSZzEjeL2FTRG0k6bBnw7AZ/rBCcZfDDGLVreY=",
90 | }
91 | router *gin.Engine
92 | )
93 |
94 | const (
95 | authURL = "http://localhost/gortas/v1/auth/otp"
96 | badPhone = "123"
97 | validPhone = "5551112233"
98 | )
99 |
100 | func init() {
101 | config.SetConfig(&conf)
102 | router = server.SetupRouter(&conf)
103 | }
104 |
105 | func TestOTPAuth(t *testing.T) {
106 |
107 | flowCookie := initAuth(t)
108 |
109 | // send phone
110 | requestBody := fmt.Sprintf(`{"callbacks":[{"name":"phone", "value": "%v"}]}`, validPhone)
111 | request := httptest.NewRequest("POST", authURL, bytes.NewBuffer([]byte(requestBody)))
112 | request.AddCookie(flowCookie)
113 |
114 | cbReq := &callbacks.Request{}
115 | executeRequest(t, request, cbReq)
116 | assert.Equal(t, "otp", cbReq.Module)
117 | assert.Equal(t, 2, len(cbReq.Callbacks))
118 | assert.Equal(t, "otp", cbReq.Callbacks[0].Name)
119 | assert.Equal(t, "action", cbReq.Callbacks[1].Name)
120 |
121 | // send OTP
122 | sess, _ := session.GetSessionService().GetSession(flowCookie.Value)
123 | var fs state.FlowState
124 | err := json.Unmarshal([]byte(sess.Properties[constants.FlowStateSessionProperty]), &fs)
125 | if err != nil {
126 | panic(err)
127 | }
128 | fs.Modules[2].State["otp"] = "1234"
129 | sd, _ := json.Marshal(fs)
130 | sess.Properties[constants.FlowStateSessionProperty] = string(sd)
131 | err = session.GetSessionService().UpdateSession(sess)
132 | if err != nil {
133 | panic(err)
134 | }
135 |
136 | requestBody = fmt.Sprintf(`{"callbacks":[{"name":"otp", "value": "%v"},{"name":"action", "value": "%v"}]}`, "1234", "check")
137 | request = httptest.NewRequest("POST", authURL, bytes.NewBuffer([]byte(requestBody)))
138 | request.AddCookie(flowCookie)
139 | cbReq = &callbacks.Request{}
140 | resp := executeRequest(t, request, cbReq)
141 | cookie, err := GetCookieValue("GortasSession", resp.Cookies())
142 | assert.NotEmpty(t, cookie)
143 | assert.NoError(t, err)
144 | }
145 |
146 | func TestOTPAuth_invalidPhone(t *testing.T) {
147 | flowCookie := initAuth(t)
148 | // send invalid phone
149 | requestBody := fmt.Sprintf(`{"callbacks":[{"name":"phone", "value": "%v"}]}`, badPhone)
150 |
151 | request := httptest.NewRequest("POST", authURL, bytes.NewBuffer([]byte(requestBody)))
152 | request.AddCookie(flowCookie)
153 | cbReq := &callbacks.Request{}
154 | executeRequest(t, request, cbReq)
155 | assert.Equal(t, "phone", cbReq.Module)
156 | assert.Equal(t, 1, len(cbReq.Callbacks))
157 | assert.Equal(t, "phone", cbReq.Callbacks[0].Name)
158 | assert.Equal(t, "Phone invalid", cbReq.Callbacks[0].Error)
159 | }
160 |
161 | func TestOTPAuth_invalidOtp(t *testing.T) {
162 | flowCookie := authPhone(validPhone, t)
163 |
164 | //send invalid OTP
165 | requestBody := fmt.Sprintf(`{"callbacks":[{"name":"otp", "value": "%v"},{"name":"action", "value": "%v"}]}`, "1234", "check")
166 | request := httptest.NewRequest("POST", authURL, bytes.NewBuffer([]byte(requestBody)))
167 | request.AddCookie(flowCookie)
168 | cbReq := &callbacks.Request{}
169 | executeRequest(t, request, cbReq)
170 | assert.Equal(t, "Invalid OTP", cbReq.Callbacks[0].Error)
171 | }
172 |
173 | func initAuth(t *testing.T) *http.Cookie {
174 | // init auth
175 | request := httptest.NewRequest("GET", authURL, nil)
176 | cbReq := &callbacks.Request{}
177 | resp := executeRequest(t, request, cbReq)
178 | assert.Equal(t, "phone", cbReq.Module)
179 | assert.Equal(t, 1, len(cbReq.Callbacks))
180 | assert.Equal(t, "phone", cbReq.Callbacks[0].Name)
181 |
182 | cookieVal, _ := GetCookieValue(state.FlowCookieName, resp.Cookies())
183 |
184 | return &http.Cookie{
185 | Name: state.FlowCookieName,
186 | Value: cookieVal,
187 | }
188 | }
189 |
190 | func authPhone(phone string, t *testing.T) *http.Cookie {
191 | flowCookie := initAuth(t)
192 |
193 | requestBody := fmt.Sprintf(`{"callbacks":[{"name":"phone", "value": "%v"}]}`, validPhone)
194 | request := httptest.NewRequest("POST", authURL, bytes.NewBuffer([]byte(requestBody)))
195 | request.AddCookie(flowCookie)
196 |
197 | cbReq := &callbacks.Request{}
198 | executeRequest(t, request, cbReq)
199 | return flowCookie
200 | }
201 |
202 | func TestOTPAuthMagicLink(t *testing.T) {
203 |
204 | request := httptest.NewRequest("GET", authURL, nil)
205 | cbReq := &callbacks.Request{}
206 | resp := executeRequest(t, request, cbReq)
207 | assert.Equal(t, "phone", cbReq.Module)
208 | assert.Equal(t, 1, len(cbReq.Callbacks))
209 | assert.Equal(t, "phone", cbReq.Callbacks[0].Name)
210 |
211 | flowID, _ := GetCookieValue(state.FlowCookieName, resp.Cookies())
212 |
213 | flowCookie := &http.Cookie{
214 | Name: state.FlowCookieName,
215 | Value: flowID,
216 | }
217 |
218 | //send valid phone
219 | requestBody := fmt.Sprintf(`{"callbacks":[{"name":"phone", "value": "%v"}]}`, validPhone)
220 | request = httptest.NewRequest("POST", authURL, bytes.NewBuffer([]byte(requestBody)))
221 | request.AddCookie(flowCookie)
222 |
223 | cbReq = &callbacks.Request{}
224 | executeRequest(t, request, cbReq)
225 | assert.Equal(t, "otp", cbReq.Module)
226 | assert.Equal(t, 2, len(cbReq.Callbacks))
227 | assert.Equal(t, "otp", cbReq.Callbacks[0].Name)
228 | assert.Equal(t, "action", cbReq.Callbacks[1].Name)
229 |
230 | ms, err := otp.GetSender("test", make(map[string]interface{}, 0))
231 | assert.NoError(t, err)
232 | ts := ms.(*otp.TestSender)
233 | msg := ts.Messages[validPhone]
234 | msgCode := strings.Split(msg, "link code")
235 |
236 | request = httptest.NewRequest("GET", authURL+"?code="+strings.TrimSpace(msgCode[1]), nil)
237 | request.AddCookie(flowCookie)
238 | cbReq = &callbacks.Request{}
239 | resp = executeRequest(t, request, cbReq)
240 | cookie, err := GetCookieValue("GortasSession", resp.Cookies())
241 | assert.NotEmpty(t, cookie)
242 | assert.NoError(t, err)
243 | }
244 |
245 | func executeRequest(t *testing.T, r *http.Request, res interface{}) *http.Response {
246 | recorder := httptest.NewRecorder()
247 | router.ServeHTTP(recorder, r)
248 | assert.Equal(t, http.StatusOK, recorder.Result().StatusCode)
249 | response := recorder.Body.String()
250 | fmt.Println(response)
251 | err := json.Unmarshal([]byte(response), res)
252 | assert.NoError(t, err)
253 | return recorder.Result()
254 | }
255 |
--------------------------------------------------------------------------------
/test/integration/utils_test.go:
--------------------------------------------------------------------------------
1 | package integration_test
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | )
7 |
8 | //helper functions
9 | func GetCookieValue(name string, c []*http.Cookie) (string, error) {
10 |
11 | for _, cookie := range c {
12 | if cookie.Name == name {
13 | return cookie.Value, nil
14 | }
15 | }
16 | return "", errors.New("cookie not found")
17 | }
18 |
--------------------------------------------------------------------------------
/traefik/conf/config.yml:
--------------------------------------------------------------------------------
1 | http:
2 | routers:
3 | gortas-auth:
4 | rule: "PathPrefix(`/gortas/v1/auth`)"
5 | service: gortas
6 |
7 | sample-service:
8 | rule: "PathPrefix(`/secured`)"
9 | service: sample-service
10 | middlewares:
11 | - gortas-plugin
12 |
13 | services:
14 | gortas:
15 | loadBalancer:
16 | servers:
17 | - url: http://gortas:8080
18 | sample-service:
19 | loadBalancer:
20 | servers:
21 | - url: http://sample-service:8080
22 |
23 | middlewares:
24 | gortas-plugin:
25 | plugin:
26 | gortas:
27 | gortasUrl: http://gortas:8080/gortas
28 |
--------------------------------------------------------------------------------
/traefik/plugins-local/src/github.com/maximthomas/gortas_traefik_plugin/.traefik.yml:
--------------------------------------------------------------------------------
1 | displayName: Gortas Traefik Plugin
2 | type: middleware
3 |
4 | import: github.com/maximthomas/gortas_traefik_plugin
5 |
6 | summary: 'Gortas Traefik Plugin'
7 |
8 | testData:
9 | gortasUrl: http://gortas:8080
10 |
--------------------------------------------------------------------------------
/traefik/plugins-local/src/github.com/maximthomas/gortas_traefik_plugin/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/maximthomas/gortas_traefik_plugin
2 |
3 | go 1.19
4 |
--------------------------------------------------------------------------------
/traefik/plugins-local/src/github.com/maximthomas/gortas_traefik_plugin/gortas_plugin.go:
--------------------------------------------------------------------------------
1 | // Package plugindemo a demo plugin.
2 | package gortas_traefik_plugin
3 |
4 | import (
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "log"
9 | "net/http"
10 | "strings"
11 | "time"
12 | )
13 |
14 | // Config the plugin configuration.
15 | type Config struct {
16 | GortasUrl string `yaml:"gortasUrl,omitempty"`
17 | }
18 |
19 | // CreateConfig creates the default plugin configuration.
20 | func CreateConfig() *Config {
21 | return &Config{GortasUrl: ""}
22 | }
23 |
24 | // GortasPlugin a GortasPlugin plugin.
25 | type GortasPlugin struct {
26 | gortasUrl string
27 | name string
28 | tr *http.Transport
29 | next http.Handler
30 | }
31 |
32 | // New created a new Demo plugin.
33 | func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
34 |
35 | log.Printf("got configuration %v", config)
36 |
37 | return &GortasPlugin{
38 | gortasUrl: config.GortasUrl + "/v1/session/jwt",
39 | next: next,
40 | name: name,
41 | tr: &http.Transport{
42 | MaxIdleConns: 10,
43 | IdleConnTimeout: 10 * time.Second,
44 | DisableCompression: true,
45 | },
46 | }, nil
47 | }
48 |
49 | func (a *GortasPlugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
50 |
51 | bearerToken := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
52 | if bearerToken == "" {
53 | a.next.ServeHTTP(rw, req)
54 | return
55 | }
56 |
57 | jwt, err := a.convertToken(bearerToken)
58 | if err != nil {
59 | log.Printf("error converting token: %v", err)
60 | } else {
61 | req.Header.Set("Authorization", "Bearer "+jwt)
62 | }
63 | a.next.ServeHTTP(rw, req)
64 | }
65 |
66 | func (a *GortasPlugin) convertToken(token string) (jwt string, err error) {
67 | client := &http.Client{Transport: a.tr}
68 | req, err := http.NewRequest("GET", a.gortasUrl, nil)
69 |
70 | req.Header.Add("Authorization", "Bearer "+token)
71 | resp, err := client.Do(req)
72 |
73 | if err != nil {
74 | log.Print(err)
75 | return jwt, err
76 | }
77 |
78 | defer resp.Body.Close()
79 |
80 | jwtReponse := &struct {
81 | Jwt string
82 | Error string
83 | }{}
84 |
85 | err = json.NewDecoder(resp.Body).Decode(jwtReponse)
86 | if err != nil {
87 | log.Print(err)
88 | return jwt, err
89 | }
90 |
91 | if jwtReponse.Error != "" {
92 | return jwt, errors.New(jwtReponse.Error)
93 | }
94 |
95 | return jwtReponse.Jwt, err
96 | }
97 |
--------------------------------------------------------------------------------
/traefik/plugins-local/src/github.com/maximthomas/gortas_traefik_plugin/gortas_plugin_test.go:
--------------------------------------------------------------------------------
1 | package gortas_traefik_plugin
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestGortasPlugin(t *testing.T) {
8 | t.Skip()
9 | }
10 |
--------------------------------------------------------------------------------
/traefik/traefik.yml:
--------------------------------------------------------------------------------
1 | ## traefik.yml
2 |
3 | # Docker configuration backend
4 | providers:
5 | docker:
6 | defaultRule: "Host(`{{ trimPrefix `/` .Name }}.docker.localhost`)"
7 | file:
8 | directory: "/etc/traefik/conf"
9 |
10 | # API and dashboard configuration
11 | api:
12 | insecure: true
13 |
14 | # plugins
15 | experimental:
16 | localPlugins:
17 | gortas:
18 | moduleName: github.com/maximthomas/gortas_traefik_plugin
--------------------------------------------------------------------------------
/ui.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM node:lts as build-deps
2 | WORKDIR /usr/src/app
3 | ARG REACT_APP_GORTAS_URL=http://localhost:8080
4 | ENV REACT_APP_GORTAS_URL=${REACT_APP_GORTAS_URL}
5 | ARG REACT_APP_GORTAS_SIGN_UP_PATH=/gortas/v1/login/users/registration
6 | ENV REACT_APP_GORTAS_SIGN_UP_PATH=${REACT_APP_GORTAS_SIGN_UP_PATH}
7 | ARG REACT_APP_GORTAS_SIGN_IN_PATH=/gortas/v1/login/users/login
8 | ENV REACT_APP_GORTAS_SIGN_IN_PATH=${REACT_APP_GORTAS_SIGN_IN_PATH}
9 | RUN git clone https://github.com/maximthomas/gortas-ui.git .
10 | RUN yarn && yarn build
11 |
12 | FROM nginx:1.17-alpine
13 | COPY --from=build-deps /usr/src/app/build /usr/share/nginx/html
14 | EXPOSE 80
15 | CMD ["nginx", "-g", "daemon off;"]
--------------------------------------------------------------------------------