├── .dockerignore ├── .github └── workflows │ └── tests.yml ├── .gitignore ├── CHANGELOG.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── Setup.hs ├── app └── Main.hs ├── entrypoint.sh ├── src └── Network │ └── Wai │ ├── Auth │ ├── AppRoot.hs │ ├── ClientSession.hs │ ├── Config.hs │ ├── Executable.hs │ ├── Internal.hs │ └── Tools.hs │ └── Middleware │ ├── Auth.hs │ └── Auth │ ├── OAuth2.hs │ ├── OAuth2 │ ├── Github.hs │ ├── Gitlab.hs │ └── Google.hs │ ├── OIDC.hs │ └── Provider.hs ├── stack-lts-14.yaml ├── stack-nightly.yaml ├── stack.yaml ├── stack.yaml.lock ├── test ├── Main.hs ├── Network │ └── Wai │ │ └── Auth │ │ └── Test.hs └── Spec │ └── Network │ └── Wai │ ├── Auth │ └── Internal.hs │ └── Middleware │ └── Auth │ ├── OAuth2.hs │ └── OIDC.hs └── wai-middleware-auth.cabal /.dockerignore: -------------------------------------------------------------------------------- 1 | Dockerfile 2 | .git 3 | .stack-work 4 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | 9 | jobs: 10 | build: 11 | name: CI 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ubuntu-latest, macos-latest] 17 | args: 18 | - "--resolver nightly --stack-yaml stack-nightly.yaml" 19 | - "--resolver lts-17" 20 | - "--resolver lts-16" 21 | - "--resolver lts-14 --stack-yaml stack-lts-14.yaml" 22 | 23 | steps: 24 | - name: Clone project 25 | uses: actions/checkout@v2 26 | 27 | # Getting weird OS X errors... 
28 | # - name: Cache dependencies 29 | # uses: actions/cache@v1 30 | # with: 31 | # path: ~/.stack 32 | # key: ${{ runner.os }}-${{ matrix.resolver }}-${{ hashFiles('stack.yaml') }} 33 | # restore-keys: | 34 | # ${{ runner.os }}-${{ matrix.resolver }}- 35 | 36 | - name: Build and run tests 37 | shell: bash 38 | run: | 39 | set -ex 40 | stack upgrade 41 | stack --version 42 | stack test --fast --no-terminal ${{ matrix.args }} 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | dist-* 3 | cabal-dev 4 | *.o 5 | *.hi 6 | *.chi 7 | *.chs.h 8 | *.dyn_o 9 | *.dyn_hi 10 | .hpc 11 | .hsenv 12 | .cabal-sandbox/ 13 | cabal.sandbox.config 14 | *.prof 15 | *.aux 16 | *.hp 17 | *.eventlog 18 | .stack-work/ 19 | cabal.project.local 20 | client_session_key.aes 21 | *.swp 22 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 0.2.6.0 2 | 3 | - Add support for https in reverse_proxy 4 | 5 | # 0.2.5.1 6 | 7 | - Add support for GHC 9.0. 8 | 9 | # 0.2.5.0 10 | 11 | - Add `getAuthUserFromVault` for `Servant.Api.Vault` user. 12 | 13 | # 0.2.4.0 14 | 15 | - Add GitLab provider. 16 | 17 | # 0.2.3.1 18 | 19 | - Expose `discoverURI` in `Network.Wai.Middleware.Auth.OIDC` 20 | - Fix bug with OAuth2 and OpenID Connect authentication where scopes were 21 | separated using commas instead of spaces. 22 | 23 | # 0.2.3.0 24 | 25 | - Support `hoauth2-1.11.0` 26 | - Drop support for `jose` versions < 0.8 27 | - Expose `decodeKey` 28 | - OAuth2 provider removes a session when an access token expires. It will use a 29 | refresh token if one is available to create a new session. If no refresh token 30 | is available it will redirect the user to re-authenticate.
31 | - Providers can define logic for refreshing a session without user intervention. 32 | - Add an OpenID Connect provider. 33 | 34 | # 0.2.2.0 35 | 36 | - Add request logging to executable 37 | - Newer multistage Docker build system 38 | 39 | # 0.2.1.0 40 | 41 | - Fix a bug in deserialization of `UserIdentity` 42 | 43 | # 0.2.0.0 44 | 45 | - Drop compatibility with hoauth2 versions <= 1.0.0. 46 | - Add a function for getting the oauth2 token from an authenticated request. 47 | - Modify encoding of oauth2 session cookies. As a consequence existing cookies will be invalid. 48 | 49 | # 0.1.2.1 50 | 51 | - Compatibility with hoauth2-1.3.0 - fixed: [#4](https://github.com/fpco/wai-middleware-auth/issues/4) 52 | 53 | # 0.1.2.0 54 | 55 | - Implemented compatibility with hoauth2 >= 1.0.0 - fixed: [#3](https://github.com/fpco/wai-middleware-auth/issues/3) 56 | 57 | # 0.1.1.2 58 | 59 | - Fixed [wai-middleware-auth-0.1.1.1 does not compile in 32 bit Linux](https://github.com/fpco/wai-middleware-auth/issues/2) 60 | 61 | # 0.1.1.1 62 | 63 | - Disallow an empty `userIdentity` from producing a successful login. 64 | - Produces a 404 on `/favicon.ico` page if not logged in: work around for issue 65 | with Chrome requesting it first and messing up the redirect url. 66 | - Added JQuery to the template, since it's bootstrap's requirement. 67 | 68 | # 0.1.1.0 69 | 70 | - Fixed whitelist email regex matching for Github and Google auth. 71 | 72 | # 0.1.0.0 73 | 74 | - Initial implementation.
75 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Just a preexisting build image that has everything we need 2 | FROM snoyberg/haskellers-build-image:e17739d1c2c043aae11924fee66c9ee4304ad37d as build 3 | 4 | # Get the compiler in place and cached 5 | COPY stack.yaml /tmp/stack.yaml 6 | RUN stack setup --stack-yaml /tmp/stack.yaml 7 | 8 | # Build just the dependencies in the cache 9 | COPY wai-middleware-auth.cabal /tmp/ 10 | RUN stack build --only-dependencies --stack-yaml /tmp/stack.yaml 11 | 12 | # Build the actual project 13 | COPY . /src 14 | RUN stack install --local-bin-path /output --stack-yaml /src/stack.yaml 15 | 16 | # Runtime image 17 | FROM fpco/pid1 18 | 19 | # Set lang env var appropriately 20 | ENV LANG C.UTF-8 21 | 22 | # Install necessary dependencies for making SSL connections 23 | ENV DEBIAN_FRONTEND noninteractive 24 | 25 | RUN apt-get update && apt-get install -y \ 26 | ca-certificates \ 27 | libgmp-dev \ 28 | netbase 29 | 30 | # Copy over the executable from the build image 31 | COPY --from=build /output/wai-auth /usr/local/bin/wai-auth 32 | 33 | # Set up the entrypoint correctly for local users 34 | COPY entrypoint.sh /usr/local/bin/entrypoint.sh 35 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Alexey Kuleshevich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help build build-image push-image dinamo 2 | 3 | .DEFAULT_GOAL := help 4 | 5 | VERSION ?= $(shell awk '/^version:/ {print $$2}' wai-middleware-auth.cabal) 6 | IMAGE_NAME := fpco/wai-auth 7 | ## Print the package version read from the cabal file 8 | dinamo: 9 | @echo ${VERSION} 10 | 11 | ## Build stack project (natively) 12 | build: 13 | @stack build 14 | 15 | ## Build docker image (builds project in a container first) 16 | build-image: 17 | @docker build . --tag ${IMAGE_NAME}:${VERSION} 18 | @docker tag ${IMAGE_NAME}:${VERSION} ${IMAGE_NAME} 19 | 20 | ## Push docker image 21 | push-image: 22 | @docker push ${IMAGE_NAME}:${VERSION} 23 | @docker push ${IMAGE_NAME} 24 | 25 | ## Show help screen.
26 | help: 27 | @echo "Please use \`make <target>' where <target> is one of\n\n" 28 | @awk '/^[a-zA-Z\-\_0-9]+:/ { \ 29 | helpMessage = match(lastLine, /^## (.*)/); \ 30 | if (helpMessage) { \ 31 | helpCommand = substr($$1, 0, index($$1, ":")-1); \ 32 | helpMessage = substr(lastLine, RSTART + 3, RLENGTH); \ 33 | printf "%-30s %s\n", helpCommand, helpMessage; \ 34 | } \ 35 | } \ 36 | { lastLine = $$0 }' $(MAKEFILE_LIST) 37 | 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wai-middleware-auth 2 | 3 | [![Build Status](https://github.com/fpco/wai-middleware-auth/actions/workflows/tests.yml/badge.svg)](https://github.com/fpco/wai-middleware-auth/actions/workflows/tests.yml) 4 | 5 | Middleware that secures WAI application 6 | 7 | ## Installation 8 | 9 | ```shell 10 | $ stack install wai-middleware-auth 11 | ``` 12 | OR 13 | ```shell 14 | $ cabal install wai-middleware-auth 15 | ``` 16 | 17 | ## wai-auth 18 | 19 | Along with middleware this package ships with an executable `wai-auth`, which 20 | can function as a protected file server or a reverse proxy. Right from the box 21 | it supports OAuth2 authentication as well as its custom implementations for 22 | Google and Github. 23 | 24 | Configuration is done using a yaml config file.
Here is a sample file that will 25 | configure `wai-auth` to run a file server with Google, GitHub, and GitLab 26 | authentication on `http://localhost:3000`: 27 | 28 | ```yaml 29 | app_root: "_env:APPROOT:http://localhost:3000" 30 | app_port: 3000 31 | cookie_age: 3600 32 | secret_key: "...+vwscbKR4DyPT" 33 | file_server: 34 | root_folder: "/path/to/html/files" 35 | redirect_to_index: true 36 | add_trailing_slash: true 37 | providers: 38 | github: 39 | client_id: "...94cc" 40 | client_secret: "...166f" 41 | app_name: "Dev App for wai-middleware-auth" 42 | email_white_list: 43 | - "^[a-zA-Z0-9._%+-]+@example.com$" 44 | google: 45 | client_id: "...qlj.apps.googleusercontent.com" 46 | client_secret: "...oxW" 47 | email_white_list: 48 | - "^[a-zA-Z0-9._%+-]+@example.com$" 49 | gitlab: 50 | client_id: "...9cfc" 51 | client_secret: "...f0d0" 52 | app_name: "Dev App for wai-middleware-auth" 53 | email_white_list: 54 | - "^[a-zA-Z0-9._%+-]+@example.com$" 55 | ``` 56 | 57 | Above configuration will also block access to users that don't have an email 58 | with `example.com` domain. There is also a `secret_key` field which will be used 59 | to encrypt the session cookie. In order to generate a new random key run this command: 60 | 61 | ```shell 62 | $ echo $(wai-auth key --base64) 63 | azuCFq0zEBkLSXhQrhliZzZD8Kblo... 64 | ``` 65 | 66 | Make sure you have proper callback/redirect urls registered with 67 | google/github/gitlab apps, eg: 68 | `http://localhost:3000/_auth_middleware/google/complete`. 
69 | 70 | After configuration file is ready, running application is very easy: 71 | 72 | ```shell 73 | $ wai-auth --config-file=/path/to/config.yaml 74 | Listening on port 3000 75 | ``` 76 | 77 | ### Reverse proxy 78 | 79 | To use a reverse proxy instead of a file server, replace `file_server` with 80 | `reverse_proxy`, eg: 81 | 82 | ```yaml 83 | reverse_proxy: 84 | host: myapp.example.com 85 | port: 80 86 | secure: false 87 | ``` 88 | 89 | ### Self-hosted GitLab 90 | 91 | The GitLab provider also supports using a self-hosted GitLab instance by 92 | setting the `gitlab_host` field. In this case you may also want to override 93 | the `provider_info` to change the title, logo, and description. For example: 94 | 95 | ```yaml 96 | providers: 97 | gitlab: 98 | gitlab_host: gitlab.mycompany.com 99 | client_id: "...9cfc" 100 | client_secret: "...f0d0" 101 | app_name: "Dev App for wai-middleware-auth" 102 | email_white_list: 103 | - "^[a-zA-Z0-9._%+-]+@mycompany.com$" 104 | provider_info: 105 | title: My Company's GitLab 106 | logo_url: https://mycompany.com/logo.png 107 | descr: Use your My Company GitLab account to access this page. 
108 | ``` 109 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main :: IO () 3 | main = defaultMain 4 | -------------------------------------------------------------------------------- /app/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE OverloadedStrings #-} 3 | {-# LANGUAGE RecordWildCards #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | module Main where 6 | import qualified Data.ByteString as S 7 | import Data.Serialize (put, runPut) 8 | import Network.Wai.Auth.Executable 9 | import Network.Wai.Handler.Warp (run) 10 | import Network.Wai.Middleware.Auth 11 | import Network.Wai.Middleware.Auth.OAuth2 12 | import Network.Wai.Middleware.Auth.OAuth2.Gitlab 13 | import Network.Wai.Middleware.Auth.OAuth2.Github 14 | import Network.Wai.Middleware.Auth.OAuth2.Google 15 | import Network.Wai.Middleware.RequestLogger (logStdout) 16 | import Options.Applicative.Simple 17 | import Web.ClientSession 18 | 19 | data BasicOptions 20 | = ConfigFile FilePath 21 | | KeyFile KeyOptions 22 | 23 | showHelpText :: ParseError 24 | showHelpText = ShowHelpText 25 | #if MIN_VERSION_optparse_applicative(0,16,0) 26 | Nothing 27 | #endif 28 | 29 | 30 | 31 | basicSettingsParser :: String -> Parser BasicOptions 32 | basicSettingsParser version = 33 | (ConfigFile <$> 34 | strOption 35 | (long "config-file" <> short 'c' <> metavar "CONFIG" <> 36 | help "File with configuration for the Auth application.") <* 37 | abortOption 38 | (InfoMsg version) 39 | (long "version" <> short 'v' <> help "Current version.") <* 40 | abortOption 41 | showHelpText 42 | (long "help" <> short 'h' <> help "Display this message.")) <|> 43 | (subparser 44 | (command 45 | "key" 46 | (info 47 | (KeyFile <$> keyOptionsParser) 48 | (progDesc 49 | ("Command for creating a secret key or 
converting one into base64 " ++ 50 | "form, which can then be directly used inside a config file.") <> 51 | fullDesc)))) 52 | 53 | 54 | data KeyOptions = KeyOptions 55 | { keyInput :: FilePath 56 | , keyOutput :: FilePath 57 | , keyBase64 :: Bool 58 | } 59 | 60 | keyOptionsParser :: Parser KeyOptions 61 | keyOptionsParser = 62 | KeyOptions <$> 63 | strOption 64 | (long "input-file" <> short 'i' <> metavar "INPUT" <> value "" <> 65 | help "Read key from a file, instead of generating a new one.") <*> 66 | strOption 67 | (long "output-file" <> short 'o' <> metavar "OUTPUT" <> value "" <> 68 | help "Write key into a file, instead of stdout. File will be overwritten.") <*> 69 | switch 70 | (long "base64" <> short 'b' <> 71 | help "Produce a key in a base64 encoded form.") <* 72 | abortOption 73 | showHelpText 74 | (long "help" <> short 'h' <> help "Display this message.") 75 | 76 | 77 | main :: IO () 78 | main = do 79 | opts <- 80 | execParser 81 | (info 82 | (basicSettingsParser $(simpleVersion waiMiddlewareAuthVersion)) 83 | (header "wai-auth - Authentication server" <> 84 | progDesc "Run a protected file server or reverse proxy." <> 85 | fullDesc)) 86 | case opts of 87 | ConfigFile configFile -> do 88 | authConfig <- readAuthConfig configFile 89 | mkMain authConfig [gitlabParser, githubParser, googleParser, oAuth2Parser] $ \port app -> do 90 | putStrLn $ "Listening on port " ++ show port 91 | run port $ logStdout app 92 | KeyFile (KeyOptions {..}) -> do 93 | let key2str = 94 | if keyBase64 95 | then encodeKey 96 | else (runPut . 
put) 97 | key <- 98 | key2str <$> 99 | if null keyInput 100 | then snd <$> randomKey 101 | else do 102 | keyContent <- S.readFile keyInput 103 | either error return (decodeKey keyContent <|> initKey keyContent) 104 | if null keyOutput 105 | then S.putStr key 106 | else S.writeFile keyOutput key 107 | -------------------------------------------------------------------------------- /entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Add local user 4 | # Either use the LOCAL_USER_ID if passed in at runtime or 5 | # fallback 6 | 7 | USER_ID=${LOCAL_USER_ID:-9001} 8 | APP_DIR=${APP_DIR:-/opt/app} 9 | 10 | echo "Starting with UID : $USER_ID" 11 | useradd --shell /bin/bash -u $USER_ID -o -c "" -m user 12 | export HOME=/home/user 13 | 14 | exec /sbin/pid1 -u user -g user "$@" 15 | 16 | -------------------------------------------------------------------------------- /src/Network/Wai/Auth/AppRoot.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | module Network.Wai.Auth.AppRoot 3 | ( smartAppRoot 4 | ) where 5 | 6 | import Data.ByteString (ByteString) 7 | import Data.CaseInsensitive (CI, mk) 8 | import qualified Data.HashMap.Lazy as HM 9 | import qualified Data.Text as T 10 | import Data.Text.Encoding (decodeUtf8With) 11 | import Data.Text.Encoding.Error (lenientDecode) 12 | import Network.HTTP.Types (Header) 13 | import Network.Wai (Request, isSecure, requestHeaderHost, 14 | requestHeaders) 15 | 16 | 17 | -- | Determine approot by: 18 | -- 19 | -- * Respect the Host header and isSecure property, together with the following de facto standards: x-forwarded-protocol, x-forwarded-ssl, x-url-scheme, x-forwarded-proto, front-end-https. (Note: this list may be updated at will in the future without doc updates.) 
20 | -- 21 | -- Normally trusting headers in this way is insecure, however in the case of approot, the worst that can happen is that the client will get an incorrect URL. Note that this does not work for some situations, e.g.: 22 | -- 23 | -- * Reverse proxies not setting one of the above mentioned headers 24 | -- 25 | -- * Applications hosted somewhere besides the root of the domain name 26 | -- 27 | -- * Reverse proxies that modify the host header 28 | -- 29 | -- @since 0.1.0.0 30 | smartAppRoot :: Request -> T.Text 31 | smartAppRoot req = 32 | let secure = isSecure req || any isSecureHeader (requestHeaders req) 33 | host = 34 | maybe "localhost" (decodeUtf8With lenientDecode) (requestHeaderHost req) 35 | in (if secure 36 | then "https://" 37 | else "http://") <> 38 | host 39 | 40 | -- | 41 | -- 42 | -- See: http://stackoverflow.com/a/16042648/369198 43 | httpsHeaders :: HM.HashMap (CI ByteString) (CI ByteString) 44 | httpsHeaders = 45 | HM.fromList 46 | [ ("X-Forwarded-Protocol", "https") 47 | , ("X-Forwarded-Ssl", "on") 48 | , ("X-Url-Scheme", "https") 49 | , ("X-Forwarded-Proto", "https") 50 | , ("Front-End-Https", "on") 51 | ] 52 | 53 | isSecureHeader :: Header -> Bool 54 | isSecureHeader (key, value) = 55 | case HM.lookup key httpsHeaders of 56 | Nothing -> False 57 | Just value' -> valueCI == value' 58 | where 59 | valueCI = mk value 60 | -------------------------------------------------------------------------------- /src/Network/Wai/Auth/ClientSession.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveGeneric #-} 2 | {-# LANGUAGE OverloadedStrings #-} 3 | module Network.Wai.Auth.ClientSession 4 | ( loadCookieValue 5 | , saveCookieValue 6 | , deleteCookieValue 7 | , Key 8 | , getDefaultKey 9 | ) where 10 | 11 | import Blaze.ByteString.Builder (toByteString) 12 | import Control.Monad (guard) 13 | import Data.Binary (Binary, decodeOrFail, encode) 14 | import qualified Data.ByteString as S 15 | import 
qualified Data.ByteString.Base64.URL as B64 16 | import qualified Data.ByteString.Lazy as L 17 | import Data.Int (Int64) 18 | import Data.Maybe (listToMaybe) 19 | import Data.Time.Clock (UTCTime(UTCTime)) 20 | import Data.Time.Calendar (fromGregorian) 21 | import Foreign.C.Types (CTime (..)) 22 | import GHC.Generics (Generic) 23 | import Network.HTTP.Types (Header) 24 | import Network.Wai (Request, requestHeaders) 25 | import System.PosixCompat.Time (epochTime) 26 | import Web.ClientSession (Key, decrypt, encryptIO, 27 | getDefaultKey) 28 | import Web.Cookie (def, parseCookies, renderSetCookie, 29 | sameSiteLax, setCookieExpires, 30 | setCookieHttpOnly, setCookieMaxAge, 31 | setCookieName, setCookiePath, 32 | setCookieSameSite, setCookieValue) 33 | 34 | data Wrapper value = Wrapper 35 | { contained :: value 36 | , expires :: !Int64 -- ^ should really be EpochTime or CTime, but there's no Binary instance 37 | } deriving (Generic) 38 | instance Binary value => Binary (Wrapper value) 39 | 40 | loadCookieValue 41 | :: Binary value 42 | => Key 43 | -> S.ByteString -- ^ cookie name 44 | -> Request 45 | -> IO (Maybe value) 46 | loadCookieValue key name req = do 47 | CTime now <- epochTime 48 | return $ 49 | listToMaybe $ do 50 | (k, v) <- requestHeaders req 51 | guard $ k == "cookie" 52 | (name', v') <- parseCookies v 53 | guard $ name == name' 54 | Right v'' <- return $ B64.decode v' 55 | Just v''' <- return $ decrypt key v'' 56 | Right (_, _, Wrapper res expi) <- 57 | return $ decodeOrFail $ L.fromStrict v''' 58 | guard $ expi >= fromIntegral now 59 | return res 60 | 61 | saveCookieValue 62 | :: Binary value 63 | => Key 64 | -> S.ByteString -- ^ cookie name 65 | -> Int -- ^ age in seconds 66 | -> value 67 | -> IO Header 68 | saveCookieValue key name age value = do 69 | CTime now <- epochTime 70 | value' <- 71 | encryptIO key $ 72 | L.toStrict $ 73 | encode 74 | Wrapper {contained = value, expires = fromIntegral now + fromIntegral age} 75 | return 76 | ( "Set-Cookie" 77 
| , toByteString $ 78 | renderSetCookie 79 | def 80 | { setCookieName = name 81 | , setCookieValue = B64.encode value' 82 | , setCookiePath = Just "/" 83 | , setCookieHttpOnly = True 84 | , setCookieMaxAge = Just $ fromIntegral age 85 | , setCookieSameSite = Just sameSiteLax 86 | }) 87 | 88 | deleteCookieValue 89 | :: S.ByteString -- ^ cookie name 90 | -> Header 91 | deleteCookieValue name = 92 | ( "Set-Cookie" 93 | , toByteString $ 94 | renderSetCookie 95 | def 96 | { setCookieName = name 97 | , setCookieValue = "" 98 | , setCookiePath = Just "/" 99 | , setCookieHttpOnly = True 100 | , setCookieExpires = Just $ UTCTime (fromGregorian 1970 01 01) 0 101 | , setCookieSameSite = Just sameSiteLax 102 | }) 103 | -------------------------------------------------------------------------------- /src/Network/Wai/Auth/Config.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | {-# LANGUAGE TemplateHaskell #-} 4 | module Network.Wai.Auth.Config 5 | ( AuthConfig(..) 6 | , SecretKey(..) 7 | , Service(..) 8 | , FileServer(..) 9 | , ReverseProxy(..) 10 | , encodeKey 11 | , decodeKey 12 | ) where 13 | 14 | import Data.Aeson 15 | import Data.Aeson.TH (deriveJSON) 16 | import qualified Data.Text as T 17 | import Data.Text.Encoding (encodeUtf8) 18 | import Network.Wai.Auth.Tools (decodeKey, encodeKey, 19 | toLowerUnderscore) 20 | import Web.ClientSession (Key) 21 | 22 | -- | Configuration for a secret key that will be used to encrypt authenticated 23 | -- user as client side cookie. 24 | data SecretKey 25 | = SecretKeyFile FilePath -- ^ Path to a secret key file in binary form, if it 26 | -- is malformed or doesn't exist it will be 27 | -- (re)created. If empty "client_session_key.aes" 28 | -- name will be used 29 | | SecretKey Key -- ^ Serialized and base64 encoded form of a secret key. Use 30 | -- `encodeKey` to get a proper encoded form. 
31 | 32 | 33 | -- | Configuration for file server application. 34 | data FileServer = FileServer 35 | { fsRootFolder :: FilePath -- ^ Path to a folder containing files 36 | -- that will be served by this app. 37 | , fsRedirectToIndex :: Bool -- ^ Redirect to the actual index file, not 38 | -- leaving the URL containing the directory 39 | -- name 40 | , fsAddTrailingSlash :: Bool -- ^ Add a trailing slash to directory names 41 | } 42 | 43 | -- | Configuration for reverse proxy application. 44 | data ReverseProxy = ReverseProxy 45 | { rpHost :: T.Text -- ^ Hostname of the destination webserver 46 | , rpPort :: Int -- ^ Port of the destination webserver 47 | , rpSecure :: Maybe Bool -- ^ Should the request be sent to destination webserver using https or not (default: false) 48 | } 49 | 50 | -- | Available services. 51 | data Service = ServiceFiles FileServer 52 | | ServiceProxy ReverseProxy 53 | 54 | -- | Configuration for @wai-auth@ executable and any other, that is created using 55 | -- `Network.Wai.Auth.Executable.mkMain` 56 | data AuthConfig = AuthConfig 57 | { configAppRoot :: Maybe T.Text -- ^ Root Url of the website, eg: 58 | -- http://example.com or 59 | -- https://example.com. It will be used to 60 | -- perform redirects back from external 61 | -- authentication providers. 62 | , configAppPort :: Int -- ^ Port number. Default is 3000 63 | , configRequireTls :: Bool -- ^ Require requests come in over a secure 64 | -- connection (determined via headers). Will 65 | -- redirect to HTTPS if non-secure 66 | -- detected. Default is @False@ 67 | , configSkipAuth :: Bool -- ^ Turn off authentication middleware, useful for 68 | -- testing. Default is @False@ 69 | , configCookieAge :: Int -- ^ Duration of the session in seconds. Default is 70 | -- one hour (3600 seconds). 71 | , configSecretKey :: SecretKey -- ^ Secret key.
Default is "client_session_key.aes" 72 | , configService :: Service 73 | , configProviders :: Object 74 | } 75 | 76 | $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 2} ''FileServer) 77 | 78 | $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 2} ''ReverseProxy) 79 | 80 | instance FromJSON AuthConfig where 81 | parseJSON = 82 | withObject "Auth Config Object" $ \obj -> do 83 | configAppRoot <- obj .:? "app_root" 84 | configAppPort <- obj .:? "app_port" .!= 3000 85 | configRequireTls <- (obj .:? "require_tls" .!= False) 86 | configSkipAuth <- obj .:? "skip_auth" .!= False 87 | configCookieAge <- obj .:? "cookie_age" .!= 3600 88 | mSecretKeyB64T <- obj .:? "secret_key" 89 | configSecretKey <- 90 | case mSecretKeyB64T of 91 | Just secretKeyB64T -> 92 | either fail (return . SecretKey) $ decodeKey (encodeUtf8 secretKeyB64T) 93 | Nothing -> SecretKeyFile <$> (obj .:? "secret_key_file" .!= "") 94 | mFileServer <- obj .:? "file_server" 95 | mReverseProxy <- obj .:? "reverse_proxy" 96 | let sErrMsg = 97 | "Either 'file_server' or 'reverse_proxy' is required, but not both." 98 | configService <- 99 | case (mFileServer, mReverseProxy) of 100 | (Just fileServer, Nothing) -> ServiceFiles <$> parseJSON fileServer 101 | (Nothing, Just reverseProxy) -> 102 | ServiceProxy <$> parseJSON reverseProxy 103 | (Just _, Just _) -> fail $ "Too many services. " ++ sErrMsg 104 | (Nothing, Nothing) -> fail $ "No service is supplied. 
" ++ sErrMsg 105 | configProviders <- obj .: "providers" 106 | return AuthConfig {..} 107 | -------------------------------------------------------------------------------- /src/Network/Wai/Auth/Executable.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE OverloadedStrings #-} 3 | {-# LANGUAGE RecordWildCards #-} 4 | module Network.Wai.Auth.Executable 5 | ( mkMain 6 | , readAuthConfig 7 | , serviceToApp 8 | , module Network.Wai.Auth.Config 9 | , Port 10 | ) where 11 | import Data.Aeson (Result (..)) 12 | import Data.String (fromString) 13 | import Data.Text.Encoding (encodeUtf8) 14 | import Data.Yaml.Config (loadYamlSettings, useEnv) 15 | import Network.HTTP.Client.TLS (getGlobalManager) 16 | import Network.HTTP.ReverseProxy (ProxyDest (..), 17 | WaiProxyResponse (WPRProxyDest, WPRProxyDestSecure), 18 | defaultOnExc, waiProxyTo) 19 | import Network.Wai (Application) 20 | import Network.Wai.Application.Static (defaultFileServerSettings, 21 | ssAddTrailingSlash, 22 | ssRedirectToIndex, 23 | staticApp) 24 | import Network.Wai.Auth.Config 25 | import Network.Wai.Middleware.Auth 26 | import Network.Wai.Middleware.Auth.Provider 27 | import Network.Wai.Middleware.ForceSSL (forceSSL) 28 | import Web.ClientSession (getKey) 29 | 30 | 31 | type Port = Int 32 | 33 | -- | Create an `Application` from a `Service` 34 | -- 35 | -- @since 0.1.0 36 | serviceToApp :: Service -> IO Application 37 | serviceToApp (ServiceFiles FileServer {..}) = do 38 | return $ 39 | staticApp 40 | (defaultFileServerSettings $ fromString fsRootFolder) 41 | { ssRedirectToIndex = fsRedirectToIndex 42 | , ssAddTrailingSlash = fsAddTrailingSlash 43 | } 44 | serviceToApp (ServiceProxy (ReverseProxy host port secure)) = do 45 | manager <- getGlobalManager 46 | return $ 47 | waiProxyTo 48 | (const $ return $ proxydest $ ProxyDest (encodeUtf8 host) port) 49 | defaultOnExc 50 | manager 51 | where 52 | proxydest = 53 | case 
secure of 54 | Just True -> WPRProxyDestSecure 55 | _ -> WPRProxyDest 56 | 57 | 58 | -- | Read configuration from a yaml file with ability to use environment 59 | -- variables. See "Data.Yaml.Config" module for details. 60 | -- 61 | -- @since 0.1.0 62 | readAuthConfig :: FilePath -> IO AuthConfig 63 | readAuthConfig confFile = loadYamlSettings [confFile] [] useEnv 64 | 65 | 66 | -- | Construct a @main@ function. 67 | -- 68 | -- @since 0.1.0 69 | mkMain 70 | :: AuthConfig -- ^ Use `readAuthConfig` to read config from a file. 71 | -> [ProviderParser] 72 | -- ^ Parsers for supported providers. `ProviderParser` can be created with 73 | -- `Network.Wai.Middleware.Auth.Provider.mkProviderParser`. 74 | -> (Port -> Application -> IO ()) 75 | -- ^ Application runner, for instance Warp's @run@ function. 76 | -> IO () 77 | mkMain AuthConfig {..} providerParsers run = do 78 | let !providers = 79 | case parseProviders configProviders providerParsers of 80 | Error errMsg -> error errMsg 81 | Success providers' -> providers' 82 | let authSettings = 83 | (case configSecretKey of 84 | SecretKey key -> setAuthKey $ return key 85 | SecretKeyFile "" -> id 86 | SecretKeyFile keyPath -> setAuthKey (getKey keyPath)) 87 | . (case configAppRoot of 88 | Just appRoot -> setAuthAppRootStatic appRoot 89 | Nothing -> id) 90 | . setAuthProviders providers 91 | . 
setAuthSessionAge configCookieAge 92 | $ defaultAuthSettings 93 | authMiddleware <- mkAuthMiddleware authSettings 94 | app <- serviceToApp configService 95 | run configAppPort $ 96 | (if configRequireTls 97 | then forceSSL 98 | else id) 99 | (if configSkipAuth 100 | then app 101 | else authMiddleware app) 102 | -------------------------------------------------------------------------------- /src/Network/Wai/Auth/Internal.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_HADDOCK hide, not-home #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE RecordWildCards #-} 4 | {-# LANGUAGE OverloadedStrings #-} 5 | {-# LANGUAGE TupleSections #-} 6 | module Network.Wai.Auth.Internal 7 | ( OAuth2TokenBinary(..) 8 | , Metadata(..) 9 | , encodeToken 10 | , decodeToken 11 | , oauth2Login 12 | , refreshTokens 13 | ) where 14 | 15 | import qualified Data.Aeson as Aeson 16 | import Data.Binary (Binary(get, put), encode, 17 | decodeOrFail) 18 | import qualified Data.ByteString as S 19 | import qualified Data.ByteString.Char8 as S8 (pack) 20 | import qualified Data.ByteString.Lazy as SL 21 | import qualified Data.Text as T 22 | import Data.Text.Encoding (encodeUtf8, 23 | decodeUtf8With) 24 | import Data.Text.Encoding.Error (lenientDecode) 25 | import GHC.Generics (Generic) 26 | import Network.HTTP.Client (Manager) 27 | import Network.HTTP.Types (Status, status303, 28 | status403, status404, 29 | status501) 30 | import qualified Network.OAuth.OAuth2 as OA2 31 | import Network.Wai (Request, Response, 32 | queryString, responseLBS) 33 | import Network.Wai.Middleware.Auth.Provider 34 | import qualified URI.ByteString as U 35 | import URI.ByteString (URI) 36 | 37 | decodeToken :: S.ByteString -> Either String OA2.OAuth2Token 38 | decodeToken bs = 39 | case decodeOrFail $ SL.fromStrict bs of 40 | Right (_, _, token) -> Right $ unOAuth2TokenBinary token 41 | Left (_, _, err) -> Left err 42 | 43 | encodeToken :: OA2.OAuth2Token -> 
S.ByteString 44 | encodeToken = SL.toStrict . encode . OAuth2TokenBinary 45 | 46 | newtype OAuth2TokenBinary = 47 | OAuth2TokenBinary { unOAuth2TokenBinary :: OA2.OAuth2Token } 48 | deriving (Show) 49 | 50 | instance Binary OAuth2TokenBinary where 51 | put (OAuth2TokenBinary token) = do 52 | put $ OA2.atoken $ OA2.accessToken token 53 | put $ OA2.rtoken <$> OA2.refreshToken token 54 | put $ OA2.expiresIn token 55 | put $ OA2.tokenType token 56 | put $ OA2.idtoken <$> OA2.idToken token 57 | get = do 58 | accessToken <- OA2.AccessToken <$> get 59 | refreshToken <- fmap OA2.RefreshToken <$> get 60 | expiresIn <- get 61 | tokenType <- get 62 | idToken <- fmap OA2.IdToken <$> get 63 | pure $ OAuth2TokenBinary $ 64 | OA2.OAuth2Token accessToken refreshToken expiresIn tokenType idToken 65 | 66 | oauth2Login 67 | :: OA2.OAuth2 68 | -> Manager 69 | -> Maybe [T.Text] 70 | -> T.Text 71 | -> Request 72 | -> [T.Text] 73 | -> (AuthLoginState -> IO Response) 74 | -> (Status -> S.ByteString -> IO Response) 75 | -> IO Response 76 | oauth2Login oauth2 man oa2Scope providerName req suffix onSuccess onFailure = 77 | case suffix of 78 | [] -> do 79 | -- https://tools.ietf.org/html/rfc6749#section-3.3 80 | let scope = (encodeUtf8 . T.intercalate " ") <$> oa2Scope 81 | let redirectUrl = 82 | getRedirectURI $ 83 | appendQueryParams 84 | (OA2.authorizationUrl oauth2) 85 | (maybe [] ((: []) . 
("scope", )) scope) 86 | return $ 87 | responseLBS 88 | status303 89 | [("Location", redirectUrl)] 90 | "Redirect to OAuth2 Authentication server" 91 | ["complete"] -> 92 | let params = queryString req 93 | in case lookup "code" params of 94 | Just (Just code) -> do 95 | eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code 96 | case eRes of 97 | Left err -> onFailure status501 $ S8.pack $ show err 98 | Right token -> onSuccess $ encodeToken token 99 | _ -> 100 | case lookup "error" params of 101 | (Just (Just "access_denied")) -> 102 | onFailure 103 | status403 104 | "User rejected access to the application." 105 | (Just (Just error_code)) -> 106 | onFailure status501 $ "Received an error: " <> error_code 107 | (Just Nothing) -> 108 | onFailure status501 $ 109 | "Unknown error connecting to " <> 110 | encodeUtf8 providerName 111 | Nothing -> 112 | onFailure 113 | status404 114 | "Page not found. Please continue with login." 115 | _ -> onFailure status404 "Page not found. Please continue with login." 116 | 117 | refreshTokens :: OA2.OAuth2Token -> Manager -> OA2.OAuth2 -> IO (Maybe OA2.OAuth2Token) 118 | refreshTokens tokens manager oauth2 = 119 | case OA2.refreshToken tokens of 120 | Nothing -> pure Nothing 121 | Just refreshToken -> do 122 | res <- OA2.refreshAccessToken manager oauth2 refreshToken 123 | case res of 124 | Left _ -> pure Nothing 125 | Right newTokens -> pure (Just newTokens) 126 | 127 | getExchangeToken :: S.ByteString -> OA2.ExchangeToken 128 | getExchangeToken = OA2.ExchangeToken . 
decodeUtf8With lenientDecode 129 | 130 | appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI 131 | appendQueryParams uri params = 132 | OA2.appendQueryParams params uri 133 | 134 | getRedirectURI :: U.URIRef a -> S.ByteString 135 | getRedirectURI = U.serializeURIRef' 136 | 137 | data Metadata 138 | = Metadata 139 | { issuer :: T.Text 140 | , authorizationEndpoint :: U.URI 141 | , tokenEndpoint :: U.URI 142 | , userinfoEndpoint :: Maybe T.Text 143 | , revocationEndpoint :: Maybe T.Text 144 | , jwksUri :: T.Text 145 | , responseTypesSupported :: [T.Text] 146 | , subjectTypesSupported :: [T.Text] 147 | , idTokenSigningAlgValuesSupported :: [T.Text] 148 | , scopesSupported :: Maybe [T.Text] 149 | , tokenEndpointAuthMethodsSupported :: Maybe [T.Text] 150 | , claimsSupported :: Maybe [T.Text] 151 | } 152 | deriving (Generic) 153 | 154 | instance Aeson.FromJSON Metadata where 155 | parseJSON = Aeson.genericParseJSON metadataAesonOptions 156 | 157 | instance Aeson.ToJSON Metadata where 158 | 159 | toJSON = Aeson.genericToJSON metadataAesonOptions 160 | 161 | toEncoding = Aeson.genericToEncoding metadataAesonOptions 162 | 163 | metadataAesonOptions :: Aeson.Options 164 | metadataAesonOptions = 165 | Aeson.defaultOptions {Aeson.fieldLabelModifier = Aeson.camelTo2 '_'} 166 | -------------------------------------------------------------------------------- /src/Network/Wai/Auth/Tools.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | module Network.Wai.Auth.Tools 3 | ( encodeKey 4 | , decodeKey 5 | , toLowerUnderscore 6 | , getValidEmail 7 | ) where 8 | 9 | import qualified Data.ByteString as S 10 | import Data.ByteString.Base64 as B64 11 | import Data.Char (isLower, toLower) 12 | import Data.Foldable (foldr') 13 | import Data.Maybe (listToMaybe) 14 | import Data.Serialize (Get, get, put, runGet, runPut) 15 | import Text.Regex.Posix ((=~)) 16 | import Web.ClientSession (Key) 17 | 18 | 19 | -- 
| Decode a `Key` that is in a base64 encoded serialized form 20 | decodeKey :: S.ByteString -> Either String Key 21 | decodeKey secretKeyB64 = B64.decode secretKeyB64 >>= runGet (get :: Get Key) 22 | 23 | 24 | -- | Serialize and base64 encode a secret `Key` 25 | encodeKey :: Key -> S.ByteString 26 | encodeKey = B64.encode . runPut . put 27 | 28 | 29 | -- | Lower-case the first character; every later character that is not lower case 30 | -- (per `isLower`) is lower-cased and prefixed with an underscore. 31 | toLowerUnderscore :: String -> String 32 | toLowerUnderscore [] = [] 33 | toLowerUnderscore (x:xs) = toLower x : foldr' toLowerWithUnder [] xs 34 | where 35 | toLowerWithUnder !y !acc 36 | | isLower y = y : acc 37 | | otherwise = '_' : toLower y : acc 38 | 39 | 40 | -- | Check email list against a whitelist and pick first one that matches or 41 | -- Nothing otherwise. Empty match results are discarded. NOTE(review): the value returned is whatever `=~` yields at ByteString type, presumably the matched portion rather than the whole email - confirm this is intended. 42 | getValidEmail :: [S.ByteString] -> [S.ByteString] -> Maybe S.ByteString 43 | getValidEmail whitelist emails = 44 | listToMaybe $ filter (not . S.null) [e =~ w | e <- emails, w <- whitelist] 45 | -------------------------------------------------------------------------------- /src/Network/Wai/Middleware/Auth.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 4 | {-# LANGUAGE OverloadedStrings #-} 5 | {-# LANGUAGE RecordWildCards #-} 6 | module Network.Wai.Middleware.Auth 7 | ( -- * Settings 8 | AuthSettings 9 | , defaultAuthSettings 10 | , setAuthKey 11 | , setAuthAppRootStatic 12 | , setAuthAppRootGeneric 13 | , setAuthSessionAge 14 | , setAuthPrefix 15 | , setAuthCookieName 16 | , setAuthProviders 17 | , setAuthProvidersTemplate 18 | -- * Middleware 19 | , mkAuthMiddleware 20 | -- * Helpers 21 | , smartAppRoot 22 | , waiMiddlewareAuthVersion 23 | , getAuthUser 24 | , getAuthUserFromVault 25 | , getDeleteSessionHeader 26 | , decodeKey 27 | ) where 28 | 29 | import Blaze.ByteString.Builder 
(fromByteString) 30 | import Data.Binary (Binary) 31 | import qualified Data.ByteString as S 32 | import Data.ByteString.Builder (Builder) 33 | import qualified Data.HashMap.Strict as HM 34 | import qualified Data.Text as T 35 | import Data.Text.Encoding (decodeUtf8With, 36 | encodeUtf8) 37 | import Data.Text.Encoding.Error (lenientDecode) 38 | import qualified Data.Vault.Lazy as Vault 39 | import Data.Version (Version) 40 | import Foreign.C.Types (CTime (..)) 41 | import GHC.Generics (Generic) 42 | import Network.HTTP.Types (Header, status200, 43 | status303, status404, 44 | status501) 45 | import Network.Wai (mapResponseHeaders, 46 | Middleware, Request, 47 | pathInfo, rawPathInfo, 48 | rawQueryString, 49 | responseBuilder, 50 | responseLBS, vault) 51 | import Network.Wai.Auth.AppRoot 52 | import Network.Wai.Auth.ClientSession 53 | import Network.Wai.Middleware.Auth.Provider 54 | import Network.Wai.Auth.Tools (decodeKey) 55 | import qualified Paths_wai_middleware_auth as Paths 56 | import System.IO.Unsafe (unsafePerformIO) 57 | import System.PosixCompat.Time (epochTime) 58 | import Text.Hamlet (Render) 59 | 60 | 61 | 62 | 63 | -- | Settings for creating the Auth middleware. 64 | -- 65 | -- To create a value, use 'defaultAuthSettings' and then various setter 66 | -- functions. 67 | -- 68 | -- @since 0.1.0 69 | data AuthSettings = AuthSettings 70 | { asGetKey :: IO Key 71 | , asGetAppRoot :: Request -> IO T.Text 72 | , asSessionAge :: Int -- ^ default: 3600 seconds (1 hour) 73 | , asAuthPrefix :: T.Text -- ^ default: _auth_middleware 74 | , asStateKey :: S.ByteString -- ^ Cookie name, default: auth_state 75 | , asProviders :: Providers 76 | , asProvidersTemplate :: Maybe T.Text -> Render Provider -> Providers -> Builder 77 | } 78 | 79 | -- | Default middleware settings. 
See various setters in order to change 80 | -- available settings 81 | -- 82 | -- @since 0.1.0 83 | defaultAuthSettings :: AuthSettings 84 | defaultAuthSettings = 85 | AuthSettings 86 | { asGetKey = getDefaultKey 87 | , asGetAppRoot = return <$> smartAppRoot 88 | , asSessionAge = 3600 89 | , asAuthPrefix = "_auth_middleware" 90 | , asStateKey = "auth_state" 91 | , asProviders = HM.empty 92 | , asProvidersTemplate = providersTemplate 93 | } 94 | 95 | 96 | -- | Set the function to get client session key for encrypting cookie data. 97 | -- 98 | -- Default: 'getDefaultKey' 99 | -- 100 | -- @since 0.1.0 101 | setAuthKey :: IO Key -> AuthSettings -> AuthSettings 102 | setAuthKey x as = as { asGetKey = x } 103 | 104 | -- | Set the cookie name. 105 | -- 106 | -- Default: "auth_state" 107 | -- 108 | -- @since 0.1.0 109 | setAuthCookieName :: S.ByteString -> AuthSettings -> AuthSettings 110 | setAuthCookieName x as = as { asStateKey = x } 111 | 112 | 113 | -- | Set the URL path prefix under which the authentication routes are served. 114 | -- 115 | -- Default: "_auth_middleware" 116 | -- 117 | -- @since 0.1.0 118 | setAuthPrefix :: T.Text -> AuthSettings -> AuthSettings 119 | setAuthPrefix x as = as { asAuthPrefix = x } 120 | 121 | 122 | -- | The application root for this application. 123 | -- 124 | -- Set the root for this Application. Required for external Authentication 125 | -- providers to perform proper redirect. 126 | -- 127 | -- Default: use the APPROOT environment variable. 128 | -- 129 | -- @since 0.1.0 130 | setAuthAppRootStatic :: T.Text -> AuthSettings -> AuthSettings 131 | setAuthAppRootStatic = setAuthAppRootGeneric . const . return 132 | 133 | -- | More generalized version of 'setAuthAppRootStatic'. 
134 | -- 135 | -- @since 0.1.0 136 | setAuthAppRootGeneric :: (Request -> IO T.Text) -> AuthSettings -> AuthSettings 137 | setAuthAppRootGeneric x as = as { asGetAppRoot = x } 138 | 139 | -- | Number of seconds to keep an authentication cookie active 140 | -- 141 | -- Default: 3600 142 | -- 143 | -- @since 0.1.0 144 | setAuthSessionAge :: Int -> AuthSettings -> AuthSettings 145 | setAuthSessionAge x as = as { asSessionAge = x } 146 | 147 | 148 | -- | Set Authentication providers to be used. 149 | -- 150 | -- Default is empty. 151 | -- 152 | -- @since 0.1.0 153 | setAuthProviders :: Providers -> AuthSettings -> AuthSettings 154 | setAuthProviders !ps as = as { asProviders = ps } 155 | 156 | 157 | -- | Set a custom template that will be rendered for a providers page 158 | -- 159 | -- Default: `providersTemplate` 160 | -- 161 | -- @since 0.1.0 162 | setAuthProvidersTemplate :: (Maybe T.Text -> Render Provider -> Providers -> Builder) 163 | -> AuthSettings 164 | -> AuthSettings 165 | setAuthProvidersTemplate t as = as { asProvidersTemplate = t } 166 | 167 | 168 | -- | Current state of the user: either not yet logged in (carrying the path to redirect back to after login) or logged in as an `AuthUser`. 169 | data AuthState = AuthNeedRedirect !S.ByteString 170 | | AuthLoggedIn !AuthUser 171 | deriving (Generic, Show) 172 | 173 | instance Binary AuthState 174 | 175 | 176 | -- | Creates an Authentication middleware that will make sure the application is 177 | -- protected, thus allowing access only to users that go through an 178 | -- authentication process with one of the available providers. If more than one 179 | -- provider is specified, user will be directed to a page where one can be chosen 180 | -- from a list. 
181 | -- 182 | -- @since 0.1.0 183 | mkAuthMiddleware :: AuthSettings -> IO Middleware 184 | mkAuthMiddleware AuthSettings {..} = do 185 | secretKey <- asGetKey 186 | let saveAuthState = saveCookieValue secretKey asStateKey asSessionAge 187 | authRouteRender = mkRouteRender Nothing asAuthPrefix [] 188 | -- Redirect to a list of providers if more than one is available, otherwise 189 | -- start login process with the only provider. 190 | let enforceLogin protectedPath req respond = 191 | case pathInfo req of 192 | (prefix:rest) 193 | | prefix == asAuthPrefix -> 194 | case rest of 195 | [] -> 196 | case HM.elems asProviders of 197 | [] -> 198 | respond $ 199 | responseLBS 200 | status501 201 | [] 202 | "No Authentication providers available." 203 | [soleProvider] -> 204 | let loginUrl = 205 | encodeUtf8 $ authRouteRender soleProvider [] 206 | in respond $ 207 | responseLBS 208 | status303 209 | [("Location", loginUrl)] 210 | "Redirecting to Login page" 211 | _ -> 212 | respond $ 213 | responseBuilder status200 [] $ 214 | asProvidersTemplate Nothing authRouteRender asProviders 215 | (providerName:pathSuffix) 216 | | Just provider <- HM.lookup providerName asProviders -> do 217 | appRoot <- asGetAppRoot req 218 | let onFailure status errMsg = 219 | return $ 220 | responseBuilder status [] $ 221 | asProvidersTemplate 222 | (Just $ decodeUtf8With lenientDecode errMsg) 223 | authRouteRender 224 | asProviders 225 | let onSuccess "" = 226 | onFailure 227 | status501 228 | "Empty user identity is not allowed" 229 | onSuccess authLoginState = do 230 | CTime now <- epochTime 231 | cookie <- 232 | saveAuthState $ 233 | AuthLoggedIn $ 234 | AuthUser 235 | { authLoginState = authLoginState 236 | , authProviderName = 237 | encodeUtf8 $ getProviderName provider 238 | , authLoginTime = fromIntegral now 239 | } 240 | return $ 241 | responseBuilder 242 | status303 243 | [("Location", protectedPath), cookie] 244 | (fromByteString "Redirecting to " <> 245 | fromByteString protectedPath) 
246 | let providerUrlRenderer (ProviderUrl suffix) = 247 | mkRouteRender 248 | (Just appRoot) 249 | asAuthPrefix 250 | suffix 251 | provider 252 | respond =<< 253 | handleLogin 254 | provider 255 | req 256 | pathSuffix 257 | providerUrlRenderer 258 | onSuccess 259 | onFailure 260 | ["health"] -> respond $ responseLBS status200 [] "OK" 261 | _ -> respond $ responseLBS status404 [] "Unknown URL" 262 | -- Workaround for Chrome asking for favicon.ico, causing a wrong 263 | -- redirect url to be stored in a cookie. 264 | ["favicon.ico"] -> respond $ responseLBS status404 [] "No favicon.ico" 265 | _ -> do 266 | cookie <- 267 | saveAuthState $ 268 | AuthNeedRedirect (rawPathInfo req <> rawQueryString req) 269 | respond $ 270 | responseBuilder 271 | status303 272 | [("Location", "/" <> encodeUtf8 asAuthPrefix), cookie] 273 | "Redirecting to Login Page" 274 | return $ \app req respond -> do 275 | authState <- loadCookieValue secretKey asStateKey req 276 | case authState of 277 | Just (AuthLoggedIn user) -> 278 | let providerName = decodeUtf8With lenientDecode (authProviderName user) 279 | in case HM.lookup providerName asProviders of 280 | Nothing -> 281 | -- We can no longer find the provider the user originally 282 | -- authenticated with, and as a result have no way to check if the 283 | -- session is still valid. For backwards compatibility with older 284 | -- versions of this library we'll assume the session remains valid. 285 | let req' = req {vault = Vault.insert userKey user $ vault req} 286 | in app req' respond 287 | Just provider -> do 288 | refreshResult <- refreshLoginState provider req user 289 | case refreshResult of 290 | Nothing -> 291 | -- The session has expired, the user needs to re-authenticate. 
292 | enforceLogin "/" req respond 293 | Just (req', user') -> 294 | let req'' = req' {vault = Vault.insert userKey user' $ vault req'} 295 | respond' response 296 | | user' == user = respond response 297 | | otherwise = do 298 | cookieHeader <- saveAuthState (AuthLoggedIn user') 299 | respond $ mapResponseHeaders (cookieHeader :) response 300 | in app req'' respond' 301 | Just (AuthNeedRedirect url) -> enforceLogin url req respond 302 | Nothing -> enforceLogin "/" req respond 303 | 304 | 305 | userKey :: Vault.Key AuthUser 306 | userKey = unsafePerformIO Vault.newKey 307 | {-# NOINLINE userKey #-} 308 | 309 | 310 | -- | Get the 'AuthUser' record for the current user. 311 | -- 312 | -- If called on a @Request@ behind the middleware, should always return a 313 | -- @Just@ value. 314 | -- 315 | -- @since 0.1.0 316 | getAuthUser :: Request -> Maybe AuthUser 317 | getAuthUser = Vault.lookup userKey . vault 318 | 319 | -- | Get the 'AuthUser' record for the current user from the given Vault. 320 | -- 321 | -- May be used instead of 'getAuthUser' in libraries that do not provide a 322 | -- 'Request' value, such as @Servant.Api.Vault@. 323 | -- 324 | -- @since 0.2.5.0 325 | getAuthUserFromVault :: Vault.Vault -> Maybe AuthUser 326 | getAuthUserFromVault = Vault.lookup userKey 327 | 328 | -- | Current version 329 | -- 330 | -- @since 0.1.0 331 | waiMiddlewareAuthVersion :: Version 332 | waiMiddlewareAuthVersion = Paths.version 333 | 334 | -- | Get a response header to delete the user's current session. 335 | -- 336 | -- @since 0.2.0 337 | getDeleteSessionHeader :: AuthSettings -> Header 338 | getDeleteSessionHeader = deleteCookieValue . 
asStateKey 339 | -------------------------------------------------------------------------------- /src/Network/Wai/Middleware/Auth/OAuth2.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | {-# LANGUAGE TemplateHaskell #-} 4 | {-# LANGUAGE TupleSections #-} 5 | module Network.Wai.Middleware.Auth.OAuth2 6 | ( OAuth2(..) 7 | , oAuth2Parser 8 | , URIParseException(..) 9 | , parseAbsoluteURI 10 | , getAccessToken 11 | ) where 12 | 13 | import Control.Monad.Catch 14 | import Data.Aeson.TH (defaultOptions, 15 | deriveJSON, 16 | fieldLabelModifier) 17 | import Data.Functor ((<&>)) 18 | import Data.Int (Int64) 19 | import Data.Proxy (Proxy (..)) 20 | import qualified Data.Text as T 21 | import Data.Text.Encoding (encodeUtf8) 22 | import Foreign.C.Types (CTime (..)) 23 | import Network.HTTP.Client.TLS (getGlobalManager) 24 | import qualified Network.OAuth.OAuth2 as OA2 25 | import Network.Wai (Request) 26 | import Network.Wai.Auth.Internal (decodeToken, encodeToken, 27 | oauth2Login, 28 | refreshTokens) 29 | import Network.Wai.Auth.Tools (toLowerUnderscore) 30 | import qualified Network.Wai.Middleware.Auth as MA 31 | import Network.Wai.Middleware.Auth.Provider 32 | import System.PosixCompat.Time (epochTime) 33 | import qualified URI.ByteString as U 34 | 35 | -- | General OAuth2 authentication `Provider`. 36 | data OAuth2 = OAuth2 37 | { oa2ClientId :: T.Text 38 | , oa2ClientSecret :: T.Text 39 | , oa2AuthorizeEndpoint :: T.Text 40 | , oa2AccessTokenEndpoint :: T.Text 41 | , oa2Scope :: Maybe [T.Text] 42 | , oa2ProviderInfo :: ProviderInfo 43 | } 44 | 45 | -- | Used for validating proper url structure. Can be thrown by 46 | -- `parseAbsoluteURI` and consequently by `handleLogin` for `OAuth2` `Provider` 47 | -- instance. 
48 | -- 49 | -- @since 0.1.2.0 50 | data URIParseException = URIParseException U.URIParseError deriving Show 51 | 52 | instance Exception URIParseException 53 | 54 | -- | Parse absolute URI and throw `URIParseException` in case it is malformed 55 | -- 56 | -- @since 0.1.2.0 57 | parseAbsoluteURI :: MonadThrow m => T.Text -> m U.URI 58 | parseAbsoluteURI urlTxt = do 59 | case U.parseURI U.strictURIParserOptions (encodeUtf8 urlTxt) of 60 | Left err -> throwM $ URIParseException err 61 | Right url -> return url 62 | 63 | getClientId :: T.Text -> T.Text 64 | getClientId = id 65 | 66 | getClientSecret :: T.Text -> T.Text 67 | getClientSecret = id 68 | 69 | $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2) 70 | 71 | -- | Aeson parser for `OAuth2` provider. 72 | -- 73 | -- @since 0.1.0 74 | oAuth2Parser :: ProviderParser 75 | oAuth2Parser = mkProviderParser (Proxy :: Proxy OAuth2) 76 | 77 | 78 | instance AuthProvider OAuth2 where 79 | getProviderName _ = "oauth2" 80 | getProviderInfo = oa2ProviderInfo 81 | handleLogin oa2@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do 82 | authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint 83 | accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint 84 | callbackURI <- parseAbsoluteURI $ renderUrl (ProviderUrl ["complete"]) [] 85 | let oauth2 = 86 | OA2.OAuth2 87 | { oauthClientId = getClientId oa2ClientId 88 | , oauthClientSecret = Just $ getClientSecret oa2ClientSecret 89 | , oauthOAuthorizeEndpoint = authEndpointURI 90 | , oauthAccessTokenEndpoint = accessTokenEndpointURI 91 | , oauthCallback = Just callbackURI 92 | } 93 | man <- getGlobalManager 94 | oauth2Login 95 | oauth2 96 | man 97 | oa2Scope 98 | (getProviderName oa2) 99 | req 100 | suffix 101 | onSuccess 102 | onFailure 103 | refreshLoginState OAuth2 {..} req user = do 104 | authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint 105 | accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint 106 | let 
loginState = authLoginState user 107 | case decodeToken loginState of 108 | Left _ -> pure Nothing 109 | Right tokens -> do 110 | CTime now <- epochTime 111 | if tokenExpired user now tokens then do 112 | let oauth2 = 113 | OA2.OAuth2 114 | { oauthClientId = getClientId oa2ClientId 115 | , oauthClientSecret = Just (getClientSecret oa2ClientSecret) 116 | , oauthOAuthorizeEndpoint = authEndpointURI 117 | , oauthAccessTokenEndpoint = accessTokenEndpointURI 118 | -- Setting callback endpoint to `Nothing` below is a lie. 119 | -- We do have a callback endpoint but in this context 120 | -- don't have access to the function that can render it. 121 | -- We get away with this because the callback endpoint is 122 | -- not needed for obtaining a refresh token, the only 123 | -- way we use the config here constructed. 124 | , oauthCallback = Nothing 125 | } 126 | man <- getGlobalManager 127 | rRes <- refreshTokens tokens man oauth2 128 | pure (rRes <&> \newTokens -> (req, user { 129 | authLoginState = encodeToken newTokens, 130 | authLoginTime = fromIntegral now 131 | })) 132 | else 133 | pure (Just (req, user)) 134 | 135 | tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool 136 | tokenExpired user now tokens = 137 | case OA2.expiresIn tokens of 138 | Nothing -> False 139 | Just expiresIn -> authLoginTime user + (fromIntegral expiresIn) < now 140 | 141 | -- | Get the @AccessToken@ for the current user. 142 | -- 143 | -- If called on a @Request@ behind the middleware, should always return a 144 | -- @Just@ value. 
145 | -- 146 | -- @since 0.2.0.0 147 | getAccessToken :: Request -> Maybe OA2.OAuth2Token 148 | getAccessToken req = do 149 | user <- MA.getAuthUser req 150 | either (const Nothing) Just $ decodeToken (authLoginState user) 151 | -------------------------------------------------------------------------------- /src/Network/Wai/Middleware/Auth/OAuth2/Github.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | module Network.Wai.Middleware.Auth.OAuth2.Github 4 | ( Github(..) 5 | , mkGithubProvider 6 | , githubParser 7 | ) where 8 | import Control.Exception.Safe (catchAny) 9 | import Data.Maybe (fromMaybe) 10 | import Data.Aeson 11 | import qualified Data.ByteString as S 12 | import Data.Proxy (Proxy (..)) 13 | import qualified Data.Text as T 14 | import Data.Text.Encoding (encodeUtf8) 15 | import Network.HTTP.Simple (getResponseBody, 16 | httpJSON, parseRequest, 17 | setRequestHeaders) 18 | import Network.HTTP.Types 19 | import qualified Network.OAuth.OAuth2 as OA2 20 | import Network.Wai.Auth.Internal (decodeToken) 21 | import Network.Wai.Auth.Tools (getValidEmail) 22 | import Network.Wai.Middleware.Auth.OAuth2 23 | import Network.Wai.Middleware.Auth.Provider 24 | 25 | 26 | -- | Create a github authentication provider 27 | -- 28 | -- @since 0.1.0 29 | mkGithubProvider 30 | :: T.Text -- ^ Name of the application as it is registered on github 31 | -> T.Text -- ^ @client_id@ from github 32 | -> T.Text -- ^ @client_secret@ from github 33 | -> [S.ByteString] -- ^ White list of posix regular expressions for emails 34 | -- attached to github account. 
35 | -> Maybe ProviderInfo -- ^ Replacement for default info 36 | -> Github 37 | mkGithubProvider appName clientId clientSecret emailWhiteList mProviderInfo = 38 | Github 39 | appName 40 | "https://api.github.com/user/emails" 41 | emailWhiteList 42 | OAuth2 43 | { oa2ClientId = clientId 44 | , oa2ClientSecret = clientSecret 45 | , oa2AuthorizeEndpoint = "https://github.com/login/oauth/authorize" 46 | , oa2AccessTokenEndpoint = "https://github.com/login/oauth/access_token" 47 | , oa2Scope = Just ["user:email"] 48 | , oa2ProviderInfo = fromMaybe defProviderInfo mProviderInfo 49 | } 50 | where 51 | defProviderInfo = 52 | ProviderInfo 53 | { providerTitle = "GitHub" 54 | , providerLogoUrl = 55 | "https://assets-cdn.github.com/images/modules/logos_page/Octocat.png" 56 | , providerDescr = "Use your GitHub account to access this page." 57 | } 58 | 59 | -- | Aeson parser for `Github` provider. 60 | -- 61 | -- @since 0.1.0 62 | githubParser :: ProviderParser 63 | githubParser = mkProviderParser (Proxy :: Proxy Github) 64 | 65 | 66 | -- | Github authentication provider 67 | data Github = Github 68 | { githubAppName :: T.Text 69 | , githubAPIEmailEndpoint :: T.Text 70 | , githubEmailWhitelist :: [S.ByteString] 71 | , githubOAuth2 :: OAuth2 72 | } 73 | 74 | instance FromJSON Github where 75 | parseJSON = 76 | withObject "Github Provider Object" $ \obj -> do 77 | appName <- obj .: "app_name" 78 | clientId <- obj .: "client_id" 79 | clientSecret <- obj .: "client_secret" 80 | emailWhiteList <- obj .:? "email_white_list" .!= [] 81 | mProviderInfo <- obj .:? 
"provider_info" 82 | return $ 83 | mkGithubProvider 84 | appName 85 | clientId 86 | clientSecret 87 | (map encodeUtf8 emailWhiteList) 88 | mProviderInfo 89 | 90 | -- | Newtype wrapper for a github verified email 91 | newtype GithubEmail = GithubEmail { githubEmail :: S.ByteString } deriving Show 92 | 93 | instance FromJSON GithubEmail where 94 | parseJSON = withObject "Github Verified Email" $ \ obj -> do 95 | True <- obj .: "verified" 96 | email <- obj .: "email" 97 | return (GithubEmail $ encodeUtf8 email) 98 | 99 | 100 | -- | Makes an API call to github and retrieves all user's verified emails. 101 | retrieveEmails :: T.Text -> T.Text -> S.ByteString -> IO [GithubEmail] 102 | retrieveEmails appName emailApiEndpoint accessToken = do 103 | req <- parseRequest (T.unpack emailApiEndpoint) 104 | resp <- httpJSON $ setRequestHeaders headers req 105 | return $ getResponseBody resp 106 | where 107 | headers = 108 | [ ("Accept", "application/vnd.github.v3+json") 109 | , ("Authorization", "token " <> accessToken) 110 | , ("User-Agent", encodeUtf8 appName) 111 | ] 112 | 113 | 114 | instance AuthProvider Github where 115 | getProviderName _ = "github" 116 | getProviderInfo = getProviderInfo . githubOAuth2 117 | handleLogin Github {..} req suffix renderUrl onSuccess onFailure = do 118 | let onOAuth2Success oauth2Tokens = do 119 | catchAny 120 | (do accessToken <- 121 | case decodeToken oauth2Tokens of 122 | Left err -> fail err 123 | Right tokens -> pure $ encodeUtf8 $ OA2.atoken $ OA2.accessToken tokens 124 | emails <- 125 | map githubEmail <$> 126 | retrieveEmails 127 | githubAppName 128 | githubAPIEmailEndpoint 129 | accessToken 130 | let mEmail = getValidEmail githubEmailWhitelist emails 131 | case mEmail of 132 | Just email -> onSuccess email 133 | Nothing -> 134 | onFailure status403 $ 135 | "No valid email was found with permission to access this resource. 
" <> 136 | "Please contact the administrator.") 137 | (\_err -> onFailure status501 "Issue communicating with github") 138 | handleLogin githubOAuth2 req suffix renderUrl onOAuth2Success onFailure 139 | -------------------------------------------------------------------------------- /src/Network/Wai/Middleware/Auth/OAuth2/Gitlab.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | module Network.Wai.Middleware.Auth.OAuth2.Gitlab 4 | ( Gitlab(..) 5 | , mkGitlabProvider 6 | , gitlabParser 7 | ) where 8 | import Control.Exception.Safe (catchAny) 9 | import Data.Maybe (fromMaybe) 10 | import Data.Aeson 11 | import qualified Data.ByteString as S 12 | import Data.Proxy (Proxy (..)) 13 | import qualified Data.Text as T 14 | import Data.Text.Encoding (encodeUtf8) 15 | import Network.HTTP.Simple (getResponseBody, 16 | httpJSON, parseRequest, 17 | setRequestHeaders) 18 | import Network.HTTP.Types 19 | import qualified Network.OAuth.OAuth2 as OA2 20 | import Network.Wai.Auth.Internal (decodeToken) 21 | import Network.Wai.Auth.Tools (getValidEmail) 22 | import Network.Wai.Middleware.Auth.OAuth2 23 | import Network.Wai.Middleware.Auth.Provider 24 | 25 | -- | Create a gitlab authentication provider 26 | -- 27 | -- @since 0.2.4.0 28 | mkGitlabProvider 29 | :: T.Text -- ^ Hostname of GitLab instance (e.g. @gitlab.com@) 30 | -> T.Text -- ^ Name of the application as it is registered on gitlab 31 | -> T.Text -- ^ @client_id@ from gitlab 32 | -> T.Text -- ^ @client_secret@ from gitlab 33 | -> [S.ByteString] -- ^ White list of posix regular expressions for emails 34 | -- attached to gitlab account. 
35 | -> Maybe ProviderInfo -- ^ Replacement for default info 36 | -> Gitlab 37 | mkGitlabProvider gitlabHost appName clientId clientSecret emailWhiteList mProviderInfo = 38 | Gitlab 39 | appName 40 | ("https://" <> gitlabHost <> "/api/v4/user") 41 | emailWhiteList 42 | OAuth2 43 | { oa2ClientId = clientId 44 | , oa2ClientSecret = clientSecret 45 | , oa2AuthorizeEndpoint = ("https://" <> gitlabHost <> "/oauth/authorize") 46 | , oa2AccessTokenEndpoint = ("https://" <> gitlabHost <> "/oauth/token") 47 | , oa2Scope = Just ["read_user"] 48 | , oa2ProviderInfo = fromMaybe defProviderInfo mProviderInfo 49 | } 50 | where 51 | defProviderInfo = 52 | ProviderInfo 53 | { providerTitle = "GitLab" 54 | , providerLogoUrl = 55 | "https://about.gitlab.com/images/press/logo/png/gitlab-icon-rgb.png" 56 | , providerDescr = "Use your GitLab account to access this page." 57 | } 58 | 59 | -- | Aeson parser for `Gitlab` provider. 60 | -- 61 | -- @since 0.2.4.0 62 | gitlabParser :: ProviderParser 63 | gitlabParser = mkProviderParser (Proxy :: Proxy Gitlab) 64 | 65 | 66 | -- | Gitlab authentication provider 67 | data Gitlab = Gitlab 68 | { gitlabAppName :: T.Text 69 | , gitlabAPIUserEndpoint :: T.Text 70 | , gitlabEmailWhitelist :: [S.ByteString] 71 | , gitlabOAuth2 :: OAuth2 72 | } 73 | 74 | instance FromJSON Gitlab where 75 | parseJSON = 76 | withObject "Gitlab Provider Object" $ \obj -> do 77 | gitlabHost <- obj .:? "gitlab_host" 78 | appName <- obj .: "app_name" 79 | clientId <- obj .: "client_id" 80 | clientSecret <- obj .: "client_secret" 81 | emailWhiteList <- obj .:? "email_white_list" .!= [] 82 | mProviderInfo <- obj .:? 
"provider_info" 83 | return $ 84 | mkGitlabProvider 85 | (fromMaybe "gitlab.com" gitlabHost) 86 | appName 87 | clientId 88 | clientSecret 89 | (map encodeUtf8 emailWhiteList) 90 | mProviderInfo 91 | 92 | -- | Newtype wrapper for a gitlab user 93 | newtype GitlabEmail = GitlabEmail { gitlabEmail :: S.ByteString } deriving Show 94 | 95 | instance FromJSON GitlabEmail where 96 | parseJSON = withObject "Gitlab Email" $ \ obj -> do 97 | email <- obj .: "email" 98 | return (GitlabEmail $ encodeUtf8 email) 99 | 100 | 101 | -- | Makes an API call to gitlab and retrieves user's verified email. 102 | -- Note: we only retrieve the PRIMARY email, because there is no way 103 | -- to tell whether secondary emails are verified or not. 104 | retrieveUser :: T.Text -> T.Text -> S.ByteString -> IO GitlabEmail 105 | retrieveUser appName userApiEndpoint accessToken = do 106 | req <- parseRequest (T.unpack userApiEndpoint) 107 | resp <- httpJSON $ setRequestHeaders headers req 108 | return $ getResponseBody resp 109 | where 110 | headers = 111 | [ ("Authorization", "Bearer " <> accessToken) 112 | , ("User-Agent", encodeUtf8 appName) 113 | ] 114 | 115 | 116 | instance AuthProvider Gitlab where 117 | getProviderName _ = "gitlab" 118 | getProviderInfo = getProviderInfo . gitlabOAuth2 119 | handleLogin Gitlab {..} req suffix renderUrl onSuccess onFailure = do 120 | let onOAuth2Success oauth2Tokens = do 121 | catchAny 122 | (do accessToken <- 123 | case decodeToken oauth2Tokens of 124 | Left err -> fail err 125 | Right tokens -> pure $ encodeUtf8 $ OA2.atoken $ OA2.accessToken tokens 126 | email <- 127 | gitlabEmail <$> 128 | retrieveUser 129 | gitlabAppName 130 | gitlabAPIUserEndpoint 131 | accessToken 132 | let mValidEmail = getValidEmail gitlabEmailWhitelist [email] 133 | case mValidEmail of 134 | Just validEmail -> onSuccess validEmail 135 | Nothing -> 136 | onFailure status403 $ 137 | "Your primary email address does not have permission to access this resource. 
" <> 138 | "Please contact the administrator.") 139 | (\_err -> onFailure status501 "Issue communicating with gitlab") 140 | handleLogin gitlabOAuth2 req suffix renderUrl onOAuth2Success onFailure 141 | -------------------------------------------------------------------------------- /src/Network/Wai/Middleware/Auth/OAuth2/Google.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | module Network.Wai.Middleware.Auth.OAuth2.Google 4 | ( Google(..) 5 | , mkGoogleProvider 6 | , googleParser 7 | ) where 8 | import Control.Exception.Safe (catchAny) 9 | import Control.Monad (guard) 10 | import Data.Aeson 11 | import qualified Data.ByteString as S 12 | import Data.Maybe (fromMaybe) 13 | import Data.Proxy (Proxy (..)) 14 | import qualified Data.Text as T 15 | import Data.Text.Encoding (encodeUtf8) 16 | import Network.HTTP.Simple (getResponseBody, 17 | httpJSON, parseRequestThrow, 18 | setRequestHeaders) 19 | import Network.HTTP.Types 20 | import qualified Network.OAuth.OAuth2 as OA2 21 | import Network.Wai.Auth.Internal (decodeToken) 22 | import Network.Wai.Auth.Tools (getValidEmail) 23 | import Network.Wai.Middleware.Auth.OAuth2 24 | import Network.Wai.Middleware.Auth.Provider 25 | import System.IO (hPutStrLn, stderr) 26 | 27 | 28 | -- | Create a google authentication provider 29 | -- 30 | -- @since 0.1.0 31 | mkGoogleProvider 32 | :: T.Text -- ^ @client_id@ from google 33 | -> T.Text -- ^ @client_secret@ from google 34 | -> [S.ByteString] -- ^ White list of posix regular expressions for emails 35 | -- attached to github account. 
36 | -> Maybe ProviderInfo -- ^ Replacement for default info 37 | -> Google 38 | mkGoogleProvider clientId clientSecret emailWhiteList mProviderInfo = 39 | Google 40 | "https://www.googleapis.com/oauth2/v3/userinfo" 41 | emailWhiteList 42 | OAuth2 43 | { oa2ClientId = clientId 44 | , oa2ClientSecret = clientSecret 45 | , oa2AuthorizeEndpoint = "https://accounts.google.com/o/oauth2/v2/auth" 46 | , oa2AccessTokenEndpoint = "https://www.googleapis.com/oauth2/v4/token" 47 | , oa2Scope = Just ["https://www.googleapis.com/auth/userinfo.email"] 48 | , oa2ProviderInfo = fromMaybe defProviderInfo mProviderInfo 49 | } 50 | where 51 | defProviderInfo = 52 | ProviderInfo 53 | { providerTitle = "Google" 54 | , providerLogoUrl = 55 | "https://upload.wikimedia.org/wikipedia/commons/thumb/5/53/Google_%22G%22_Logo.svg/200px-Google_%22G%22_Logo.svg.png" 56 | , providerDescr = "Use your Google account to access this page." 57 | } 58 | 59 | -- | Aeson parser for `Google` provider. 60 | -- 61 | -- @since 0.1.0 62 | googleParser :: ProviderParser 63 | googleParser = mkProviderParser (Proxy :: Proxy Google) 64 | 65 | 66 | data Google = Google 67 | { googleAPIEmailEndpoint :: T.Text 68 | , googleEmailWhitelist :: [S.ByteString] 69 | , googleOAuth2 :: OAuth2 70 | } 71 | 72 | instance FromJSON Google where 73 | parseJSON = 74 | withObject "Google Provider Object" $ \obj -> do 75 | clientId <- obj .: "client_id" 76 | clientSecret <- obj .: "client_secret" 77 | emailWhiteList <- obj .:? "email_white_list" .!= [] 78 | mProviderInfo <- obj .:? 
"provider_info" 79 | return $ 80 | mkGoogleProvider 81 | clientId 82 | clientSecret 83 | (map encodeUtf8 emailWhiteList) 84 | mProviderInfo 85 | 86 | 87 | newtype GoogleEmail = GoogleEmail { googleEmail :: S.ByteString } deriving Show 88 | 89 | instance FromJSON GoogleEmail where 90 | parseJSON = withObject "Google Verified Email" $ \ obj -> do 91 | verified <- obj .: "email_verified" 92 | guard verified 93 | email <- obj .: "email" 94 | return (GoogleEmail $ encodeUtf8 email) 95 | 96 | 97 | 98 | -- | Makes a call to google API and retrieves user's main email. 99 | retrieveEmail :: T.Text -> S.ByteString -> IO GoogleEmail 100 | retrieveEmail emailApiEndpoint accessToken = do 101 | req <- parseRequestThrow (T.unpack emailApiEndpoint) 102 | resp <- httpJSON $ setRequestHeaders headers req 103 | return $ getResponseBody resp 104 | where 105 | headers = [("Authorization", "Bearer " <> accessToken)] 106 | 107 | 108 | instance AuthProvider Google where 109 | getProviderName _ = "google" 110 | getProviderInfo = getProviderInfo . googleOAuth2 111 | handleLogin Google {..} req suffix renderUrl onSuccess onFailure = do 112 | let onOAuth2Success oauth2Tokens = do 113 | catchAny 114 | (do accessToken <- 115 | case decodeToken oauth2Tokens of 116 | Left err -> fail err 117 | Right tokens -> pure $ encodeUtf8 $ OA2.atoken $ OA2.accessToken tokens 118 | email <- 119 | googleEmail <$> 120 | retrieveEmail googleAPIEmailEndpoint accessToken 121 | let mEmail = getValidEmail googleEmailWhitelist [email] 122 | case mEmail of 123 | Just email' -> onSuccess email' 124 | Nothing -> 125 | onFailure 126 | status403 127 | "No valid email with permission to access was found.") $ \err -> do 128 | hPutStrLn stderr $ "Issue communicating with Google: " ++ show err 129 | onFailure status501 "Issue communicating with Google." 
130 | handleLogin googleOAuth2 req suffix renderUrl onOAuth2Success onFailure 131 | -------------------------------------------------------------------------------- /src/Network/Wai/Middleware/Auth/OIDC.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | -- | An OpenID connect provider. 5 | -- 6 | -- OpenID Connect is a simple identity layer on top of the OAuth2 protocol. 7 | -- Learn more about it here: 8 | -- 9 | -- @since 0.2.3.0 10 | module Network.Wai.Middleware.Auth.OIDC 11 | ( -- * Creating a provider 12 | OpenIDConnect 13 | , discover 14 | , discoverURI 15 | -- * Customizing a provider 16 | , oidcClientId 17 | , oidcClientSecret 18 | , oidcProviderInfo 19 | , oidcManager 20 | , oidcScopes 21 | , oidcAllowedSkew 22 | -- * Accessing session data 23 | , getAccessToken 24 | , getIdToken 25 | ) where 26 | 27 | import Control.Applicative ((<|>)) 28 | import qualified Crypto.JOSE as JOSE 29 | import qualified Crypto.JWT as JWT 30 | import Control.Monad.Except (runExceptT) 31 | import Data.Aeson (FromJSON(parseJSON), 32 | withObject, (.:), (.:?), (.!=)) 33 | import qualified Data.ByteString.Char8 as S8 34 | import Data.Function ((&)) 35 | import qualified Data.Time.Clock as Clock 36 | import Data.Traversable (for) 37 | import qualified Data.Text as T 38 | import qualified Data.Text.Lazy as TL 39 | import qualified Data.Text.Lazy.Encoding as TLE 40 | import qualified Data.Vault.Lazy as Vault 41 | import Foreign.C.Types (CTime (..)) 42 | import qualified Lens.Micro as Lens 43 | import qualified Lens.Micro.Extras as Lens.Extras 44 | import Network.HTTP.Simple (httpJSON, 45 | getResponseBody, 46 | parseRequestThrow) 47 | import Network.Wai.Middleware.Auth.OAuth2 (parseAbsoluteURI, 48 | getAccessToken) 49 | import qualified Network.OAuth.OAuth2 as OA2 50 | import Network.HTTP.Client (Manager) 51 | import Network.HTTP.Client.TLS
(getGlobalManager)
import           Network.Wai                          (Request, vault)
import           Network.Wai.Auth.Internal            (Metadata(..),
                                                       decodeToken, encodeToken,
                                                       oauth2Login,
                                                       refreshTokens)
import           Network.Wai.Middleware.Auth.Provider
import           System.IO.Unsafe                     (unsafePerformIO)
import           System.PosixCompat.Time              (epochTime)
import qualified Text.Hamlet
import qualified URI.ByteString                       as U

-- | An Open ID Connect provider.
--
-- To create a value use `discover` to download configuration for an existing
-- provider, then use various setter functions to customize it.
--
-- @since 0.2.3.0
data OpenIDConnect
  = OpenIDConnect
      { oidcMetadata :: Metadata
      , oidcJwkSet :: JOSE.JWKSet
      -- | The client id this application is registered with at the Open ID
      -- Connect provider. The default is an empty string, you will need to
      -- overwrite this.
      --
      -- @since 0.2.3.0
      , oidcClientId :: T.Text
      -- | The client secret of this application. The default is an empty
      -- string, you will need to overwrite this.
      --
      -- @since 0.2.3.0
      , oidcClientSecret :: T.Text
      -- | The information for this provider. The default contains some
      -- placeholder texts. If you're using the provider screen you'll want to
      -- overwrite this.
      --
      -- @since 0.2.3.0
      , oidcProviderInfo :: ProviderInfo
      -- | The HTTP manager to use. Defaults to the global manager when not set.
      --
      -- @since 0.2.3.0
      , oidcManager :: Maybe Manager
      -- | The scopes to set. Defaults to only the "openid" scope.
      --
      -- @since 0.2.3.0
      , oidcScopes :: [T.Text]
      -- | The amount of clock skew to allow when validating id tokens. Defaults
      -- to 0.
      --
      -- @since 0.2.3.0
      , oidcAllowedSkew :: Clock.NominalDiffTime
      }

-- | Parse a provider from a JSON object.
--
-- Required keys: @metadata@, @jwk_set@, @client_id@, @client_secret@.
--
-- Optional keys (missing or @null@ falls back to the default):
-- @provider_info@ (`defProviderInfo`), @scopes@ (@["openid"]@) and
-- @allowed_skew@ (@0@).
--
-- The HTTP manager cannot be expressed in JSON, so `oidcManager` is always
-- `Nothing` for a parsed provider.
instance FromJSON OpenIDConnect where
  parseJSON =
    withObject "OpenIDConnect Object" $ \obj -> do
      metadata <- obj .: "metadata"
      jwkSet <- obj .: "jwk_set"
      clientId <- obj .: "client_id"
      clientSecret <- obj .: "client_secret"
      -- `.:?` (rather than `.:`) is what makes these keys optional: with `.:`
      -- a missing key fails the whole parse and the `.!=` default would only
      -- ever apply to an explicit @null@.
      providerInfo <- obj .:? "provider_info" .!= defProviderInfo
      scopes <- obj .:? "scopes" .!= ["openid"]
      allowedSkew <- obj .:? "allowed_skew" .!= 0
      pure OpenIDConnect {
        oidcMetadata = metadata,
        oidcJwkSet = jwkSet,
        oidcClientId = clientId,
        oidcClientSecret = clientSecret,
        oidcProviderInfo = providerInfo,
        oidcManager = Nothing,
        oidcScopes = scopes,
        oidcAllowedSkew = allowedSkew
      }

instance AuthProvider OpenIDConnect where
  getProviderName _ = "oidc"
  getProviderInfo = oidcProviderInfo
  -- Delegates the whole login flow to `oauth2Login`, using the configured
  -- manager (or the global one) and the configured scopes.
  handleLogin oidc@OpenIDConnect {.. } req suffix renderUrl onSuccess onFailure = do
    oauth2 <- mkOauth2 oidc (Just renderUrl)
    manager <- maybe getGlobalManager pure oidcManager
    oauth2Login
      oauth2
      manager
      (Just oidcScopes)
      (getProviderName oidc)
      req
      suffix
      onSuccess
      onFailure
  -- A session stays valid for as long as its stored ID token validates. When
  -- it no longer does, one refresh is attempted; the session survives only if
  -- the refreshed tokens carry an ID token that validates. On success the
  -- validated claims are stored in the request vault (see `getIdToken`).
  refreshLoginState oidc req user =
    let loginState = authLoginState user
    in case decodeToken loginState of
      -- Session payload is not a token set we can read: treat as expired.
      Left _ -> pure Nothing
      Right tokens -> do
        vRes <- validateIdToken' oidc tokens
        case vRes of
          Nothing -> do
            -- No callback URI is needed for a refresh request.
            oauth2 <- mkOauth2 oidc Nothing
            manager <- maybe getGlobalManager pure (oidcManager oidc)
            rRes <- refreshTokens tokens manager oauth2
            case rRes of
              Nothing -> pure Nothing
              Just newTokens -> do
                v2Res <- validateIdToken' oidc newTokens
                case v2Res of
                  Nothing -> pure Nothing
                  Just claims -> do
                    CTime now <- epochTime
                    let newUser =
                          user {
                            authLoginState = encodeToken newTokens,
                            authLoginTime = fromIntegral now
                          }
                    pure (Just (storeClaims claims req, newUser))
          Just claims ->
            pure (Just (storeClaims claims req, user))

-- | Fetch configuration for a provider from its discovery
-- endpoint. Sets the path to @/.well-known/..@.
--
-- Any path already present in the given URL is replaced; use `discoverURI`
-- when the discovery document lives elsewhere.
--
-- @since 0.2.3.0
discover :: T.Text -> IO OpenIDConnect
discover urlText = do
  base <- parseAbsoluteURI urlText
  let uri = base { U.uriPath = "/.well-known/openid-configuration" }
  discoverURI uri

-- | Fetch configuration for a provider from an exact URI.
--
-- Downloads the metadata document and then the key set it points to. Client
-- id and secret are left empty and must be set by the caller.
--
-- @since 0.2.3.1
discoverURI :: U.URI -> IO OpenIDConnect
discoverURI uri = do
  metadata <- fetchMetadata uri
  jwkset <- fetchJWKSet (jwksUri metadata)
  pure OpenIDConnect
    { oidcClientId = ""
    , oidcClientSecret = ""
    , oidcMetadata = metadata
    , oidcJwkSet = jwkset
    , oidcProviderInfo = defProviderInfo
    , oidcManager = Nothing
    , oidcScopes = ["openid"]
    , oidcAllowedSkew = 0
    }

-- | Placeholder provider info: a generic title, no logo and no description.
defProviderInfo :: ProviderInfo
defProviderInfo = ProviderInfo "OpenID Connect Provider" "" ""

-- | Download and decode the provider metadata document.
fetchMetadata :: U.URI -> IO Metadata
fetchMetadata metadataEndpoint = do
  req <- parseRequestThrow (S8.unpack $ U.serializeURIRef' metadataEndpoint)
  getResponseBody <$> httpJSON req

-- | Download and decode the provider's JSON Web Key Set.
fetchJWKSet :: T.Text -> IO JOSE.JWKSet
fetchJWKSet jwkSetEndpoint = do
  req <- parseRequestThrow (T.unpack jwkSetEndpoint)
  getResponseBody <$> httpJSON req

-- | Build the @hoauth2@ client configuration for this provider.
--
-- When a url renderer is given, the callback is the provider's @complete@
-- route; without one no callback is set.
mkOauth2 :: OpenIDConnect -> Maybe (Text.Hamlet.Render ProviderUrl) -> IO OA2.OAuth2
mkOauth2 OpenIDConnect {..} renderUrl = do
  callbackURI <- for renderUrl $ \render -> parseAbsoluteURI $ render (ProviderUrl ["complete"]) []
  pure OA2.OAuth2
    { oauthClientId = oidcClientId
    , oauthClientSecret = Just oidcClientSecret
    , oauthOAuthorizeEndpoint = authorizationEndpoint oidcMetadata
    , oauthAccessTokenEndpoint = tokenEndpoint oidcMetadata
    , oauthCallback = callbackURI
    }

-- | Decode a compact-serialized ID token and verify its signature and claims
-- against the provider's key set and `validationSettings`.
validateIdToken :: OpenIDConnect -> OA2.IdToken -> IO (Either JWT.JWTError JWT.ClaimsSet)
validateIdToken oidc (OA2.IdToken idToken) = runExceptT $ do
  signedJwt <- JOSE.decodeCompact (TLE.encodeUtf8 $ TL.fromStrict idToken)
  JWT.verifyClaims (validationSettings oidc) (oidcJwkSet oidc) signedJwt

-- | Like `validateIdToken`, but works on a full token set and collapses both
-- "no ID token present" and "ID token invalid" into `Nothing`.
validateIdToken' :: OpenIDConnect -> OA2.OAuth2Token -> IO (Maybe JWT.ClaimsSet)
validateIdToken' oidc tokens =
  case OA2.idToken tokens of
    Nothing -> pure Nothing
    Just idToken ->
      either (const Nothing) Just <$> validateIdToken oidc idToken

-- The validation of the ID token below is stricter than specified in the OIDC
-- spec, to make the job of validating tokens easier. If this is too limiting
-- for your use case please open an issue.
--
-- Full spec for ID token validation:
-- https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
--
-- Ways in which the validation below is stricter than the spec requires:
-- - We don't allow the `aud` claim to contain any audiences beyond ourselves.
validationSettings :: OpenIDConnect -> JWT.JWTValidationSettings
validationSettings oidc =
  -- The Client MUST validate that the aud (audience) Claim contains its
  -- client_id value registered at the Issuer identified by the iss (issuer)
  -- Claim as an audience. The aud (audience) Claim MAY contain an array with
  -- more than one element. The ID Token MUST be rejected if the ID Token does
  -- not list the Client as a valid audience, or if it contains additional
  -- audiences not trusted by the Client.
  validateAudience oidc
  -- If the ID Token is encrypted, decrypt it using the keys and algorithms
  -- that the Client specified during Registration that the OP was to use to
  -- encrypt the ID Token. If encryption was negotiated with the OP at
  -- Registration time and the ID Token is not encrypted, the RP SHOULD
  -- reject it.
  & JWT.defaultJWTValidationSettings
  -- Additionally check the iat (issued at) Claim.
  -- NOTE(review): the exp check the spec asks for ("the current time MUST be
  -- before the time represented by the exp Claim") is presumably performed by
  -- `JWT.verifyClaims` itself; this flag only concerns iat -- confirm against
  -- the jose documentation.
  & Lens.set JWT.jwtValidationSettingsCheckIssuedAt True
  -- The Issuer Identifier for the OpenID Provider (which is typically
  -- obtained during Discovery) MUST exactly match the value of the iss
  -- (issuer) Claim.
  & Lens.set JWT.jwtValidationSettingsIssuerPredicate (validateIssuer oidc)
  & Lens.set JWT.jwtValidationSettingsAllowedSkew (oidcAllowedSkew oidc)

-- | Accept an audience only when it is exactly our own client id.
validateAudience :: OpenIDConnect -> JWT.StringOrURI -> Bool
validateAudience oidc audClaim =
  audienceFromJWT == Just correctClientId
  where
    correctClientId = oidcClientId oidc
    audienceFromJWT = fromStringOrURI audClaim

-- | Accept an issuer only when it is exactly the issuer from the metadata.
validateIssuer :: OpenIDConnect -> JWT.StringOrURI -> Bool
validateIssuer oidc issClaim =
  issuerFromJWT == Just correctIssuer
  where
    correctIssuer = issuer (oidcMetadata oidc)
    issuerFromJWT = fromStringOrURI issClaim

-- | Render a `JWT.StringOrURI` as text, trying the string form first and
-- falling back to `show` of the URI form.
fromStringOrURI :: JWT.StringOrURI -> Maybe T.Text
fromStringOrURI stringOrURI =
  Lens.Extras.preview JWT.string stringOrURI
    <|> fmap (T.pack . show) (Lens.Extras.preview JWT.uri stringOrURI)

-- | Put validated claims into the request vault under `idTokenKey`.
storeClaims :: JWT.ClaimsSet -> Request -> Request
storeClaims claims req =
  req { vault = Vault.insert idTokenKey claims (vault req) }

-- | Get the @IdToken@ for the current user.
--
-- If called on a @Request@ behind the middleware, should always return a
-- @Just@ value.
--
-- The token returned was validated when the request was processed by the
-- middleware.
--
-- @since 0.2.3.0
getIdToken :: Request -> Maybe JWT.ClaimsSet
getIdToken = Vault.lookup idTokenKey . vault

-- | Vault key under which the validated ID token claims are stored.
--
-- Created with `unsafePerformIO`; the @NOINLINE@ pragma below keeps the key
-- creation from being inlined at use sites.
idTokenKey :: Vault.Key JWT.ClaimsSet
idTokenKey = unsafePerformIO Vault.newKey
{-# NOINLINE idTokenKey #-}

--------------------------------------------------------------------------------
/src/Network/Wai/Middleware/Auth/Provider.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Network.Wai.Middleware.Auth.Provider
  ( AuthProvider(..)
  -- * Provider
  , Provider(..)
  , ProviderUrl(..)
  , ProviderInfo(..)
  , Providers
  -- * Provider Parsing
  , ProviderParser
  , mkProviderParser
  , parseProviders
  -- * User
  , AuthUser(..)
20 | , AuthLoginState 21 | , UserIdentity 22 | , authUserIdentity 23 | -- * Template 24 | , mkRouteRender 25 | , providersTemplate 26 | ) where 27 | 28 | import Blaze.ByteString.Builder (toByteString) 29 | import Control.Arrow (second) 30 | import Data.Aeson (FromJSON (..), Object, 31 | Result (..), Value) 32 | import Data.Aeson.Types (parseEither) 33 | 34 | import Data.Aeson.TH (defaultOptions, deriveJSON, 35 | fieldLabelModifier) 36 | import Data.Aeson.Types (Parser) 37 | import Data.Binary (Binary) 38 | import qualified Data.ByteString as S 39 | import qualified Data.ByteString.Builder as B 40 | import qualified Data.HashMap.Strict as HM 41 | import Data.Int 42 | import Data.Maybe (fromMaybe) 43 | import Data.Proxy (Proxy) 44 | import qualified Data.Text as T 45 | import Data.Text.Encoding (decodeUtf8With) 46 | import Data.Text.Encoding.Error (lenientDecode) 47 | import GHC.Generics (Generic) 48 | import Network.HTTP.Types (Status, renderQueryText) 49 | import Network.Wai (Request, Response) 50 | import Network.Wai.Auth.Tools (toLowerUnderscore) 51 | import Text.Blaze.Html.Renderer.Utf8 (renderHtmlBuilder) 52 | import Text.Hamlet (Render, hamlet) 53 | 54 | -- | Core Authentication class, that allows for extensibility of the Auth 55 | -- middleware created by `Network.Wai.Middleware.Auth.mkAuthMiddleware`. Most 56 | -- important function is `handleLogin`, which implements the actual behavior of a 57 | -- provider. It's function arguments in order: 58 | -- 59 | -- * @`ap`@ - Current provider. 60 | -- * @`Request`@ - Request made to the login page 61 | -- * @[`T.Text`]@ - Url suffix, i.e. last part of the Url split by @\'/\'@ character, 62 | -- for instance @["login", "complete"]@ suffix in the example below. 63 | -- * @`Render` `ProviderUrl`@ - 64 | -- Url renderer. It takes desired suffix as first argument and produces an 65 | -- absolute Url renderer. 
--     It can further be used to generate provider urls, for instance in
--     Hamlet templates, where the snippet below will result in
--     @"https:\/\/approot.com\/_auth_middleware\/providerName\/login\/complete?user=Hamlet"@,
--     or to generate Urls for callbacks.
--
-- @
-- \@?{(ProviderUrl ["login", "complete"], [("user", "Hamlet")])}
-- @
--
-- * @(`AuthLoginState` -> `IO` `Response`)@ - Action to call on a successful login.
-- * @(`Status` -> `S.ByteString` -> `IO` `Response`)@ - Should be called in case of
--                                   a failure with login process by supplying a
--                                   status and a short error message.
class AuthProvider ap where

  -- | Unique identifier of the provider. Implementations must not evaluate
  -- the argument: callers may pass an `undefined` value just to select the
  -- instance.
  --
  -- @since 0.1.0
  getProviderName :: ap -> T.Text

  -- | Descriptive information about the provider, shown on the page that
  -- lists the available providers.
  --
  -- @since 0.1.0
  getProviderInfo :: ap -> ProviderInfo

  -- | Process a login request. An implementation may render its own login
  -- form or redirect to an external authentication service such as OpenID or
  -- OAuth2.
  --
  -- @since 0.1.0
  handleLogin
    :: ap
    -> Request
    -> [T.Text]
    -> Render ProviderUrl
    -> (AuthLoginState -> IO Response)
    -> (Status -> S.ByteString -> IO Response)
    -> IO Response

  -- | Decide whether the login state kept in a session is still valid,
  -- optionally updating it. `Nothing` means the session has expired and the
  -- user will be directed to re-authenticate.
  --
  -- The default implementation keeps every session valid and unchanged.
  --
  -- @since 0.2.3.0
  refreshLoginState
    :: ap
    -> Request
    -> AuthUser
    -> IO (Maybe (Request, AuthUser))
  refreshLoginState _ request user = return $ Just (request, user)

-- | Existential wrapper that hides the concrete provider type.
data Provider where
  Provider :: AuthProvider p => p -> Provider


-- | Every method simply forwards to the wrapped provider.
instance AuthProvider Provider where

  getProviderName wrapped = case wrapped of
    Provider p -> getProviderName p

  getProviderInfo wrapped = case wrapped of
    Provider p -> getProviderInfo p

  handleLogin wrapped = case wrapped of
    Provider p -> handleLogin p

  refreshLoginState wrapped = case wrapped of
    Provider p -> refreshLoginState p

-- | Supported providers, keyed by text.
type Providers = HM.HashMap T.Text Provider

-- | A provider name (the same one `getProviderName` returns) paired with the
-- Aeson parser for that provider's configuration.
type ProviderParser = (T.Text, Value -> Parser Provider)

-- | Route type for provider specific urls; wraps the url suffix segments.
newtype ProviderUrl = ProviderUrl [T.Text]

-- | What is shown about a provider on the page listing supported providers.
data ProviderInfo = ProviderInfo
  { providerTitle   :: T.Text
  , providerLogoUrl :: T.Text
  , providerDescr   :: T.Text
  } deriving (Show)


-- | Opaque state attached to a logged in user, eg. a username, token or an
-- email address.
type AuthLoginState = S.ByteString

type UserIdentity = S.ByteString
{-# DEPRECATED UserIdentity "In favor of `AuthLoginState`" #-}

authUserIdentity :: AuthUser -> UserIdentity
authUserIdentity user = authLoginState user
{-# DEPRECATED authUserIdentity "In favor of `authLoginState`" #-}

-- | Representation of a user for a particular `Provider`.
166 | data AuthUser = AuthUser 167 | { authLoginState :: !UserIdentity 168 | , authProviderName :: !S.ByteString 169 | , authLoginTime :: !Int64 170 | } deriving (Eq, Generic, Show) 171 | 172 | instance Binary AuthUser 173 | 174 | 175 | 176 | -- | First argument is not evaluated and is only needed for restricting the type. 177 | mkProviderParser :: forall ap . (FromJSON ap, AuthProvider ap) => Proxy ap -> ProviderParser 178 | mkProviderParser _ = 179 | ( getProviderName nameProxyError 180 | , fmap Provider <$> (parseJSON :: Value -> Parser ap)) 181 | where 182 | nameProxyError :: ap 183 | nameProxyError = error "AuthProvider.getProviderName should not evaluate it's argument." 184 | 185 | -- | Parse configuration for providers from an `Object`. 186 | parseProviders :: Object -> [ProviderParser] -> Result Providers 187 | parseProviders unparsedProvidersHM providerParsers = 188 | if HM.null unrecognized 189 | then sequence $ HM.intersectionWith parseProvider unparsedProvidersHM parsersHM 190 | else Error $ 191 | "Provider name(s) are not recognized: " ++ 192 | T.unpack (T.intercalate ", " $ HM.keys unrecognized) 193 | where 194 | parsersHM = HM.fromList providerParsers 195 | unrecognized = HM.difference unparsedProvidersHM parsersHM 196 | parseProvider v p = either Error Success $ parseEither p v 197 | 198 | -- | Create a url renderer for a provider. 199 | mkRouteRender :: Maybe T.Text -> T.Text -> [T.Text] -> Render Provider 200 | mkRouteRender appRoot authPrefix authSuffix (Provider p) params = 201 | (T.intercalate "/" $ [root, authPrefix, getProviderName p] ++ authSuffix) <> 202 | decodeUtf8With 203 | lenientDecode 204 | (toByteString $ renderQueryText True (map (second Just) params)) 205 | where 206 | root = fromMaybe "" appRoot 207 | 208 | 209 | $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . 
drop 8} ''ProviderInfo) 210 | 211 | 212 | -- | Template for the providers page 213 | providersTemplate :: Maybe T.Text -- ^ Error message to display, if any. 214 | -> Render Provider -- ^ Renderer function for provider urls. 215 | -> Providers -- ^ List of available providers. 216 | -> B.Builder 217 | providersTemplate merrMsg render providers = 218 | renderHtmlBuilder $ [hamlet| 219 | $doctype 5 220 | 221 | 222 | WAI Auth Middleware - Authentication Providers. 223 | <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous"> 224 | <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap-theme.min.css" integrity="sha384-rHyoN1iRsVXV4nD0JutlnGaslCJuC7uwjduW9SVrLvRYooPp2bWYgmgJQIXwl/Sp" crossorigin="anonymous"> 225 | <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.1.1/jquery.min.js"> 226 | <script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js" integrity="sha384-Tc5IQib027qvyjSMfHjOMaLkfuWVxZxUPnCJA7l2mCWNIpG9mGCD8wGNIcPD7Txa" crossorigin="anonymous"> 227 | <style> 228 | .provider-logo { 229 | max-height: 64px; 230 | max-width: 64px; 231 | padding: 5px; 232 | margin: auto; 233 | position: absolute; 234 | top: 0; 235 | bottom: 0; 236 | left: 0; 237 | right: 0; 238 | } 239 | .media-container { 240 | width: 600px; 241 | position: absolute; 242 | top: 100px; 243 | bottom: 0; 244 | left: 0; 245 | right: 0; 246 | margin: auto; 247 | } 248 | .provider.media { 249 | border: 1px solid #e1e1e8; 250 | padding: 5px; 251 | height: 82px; 252 | text-overflow: ellipsis; 253 | margin-top: 5px; 254 | } 255 | .provider.media:hover { 256 | background-color: #f5f5f5; 257 | border: 1px solid #337ab7; 258 | } 259 | .provider .media-left { 260 | height: 70px; 261 | width: 0px; 262 | padding-right: 70px; 263 | position: relative; 264 | } 265 | a:hover { 266 | 
text-decoration: none; 267 | } 268 | <body> 269 | <div .media-container> 270 | <h3>Select one of available authentication methods: 271 | $maybe errMsg <- merrMsg 272 | <div .alert .alert-danger role="alert"> 273 | #{errMsg} 274 | $forall provider <- providers 275 | $with info <- getProviderInfo provider 276 | <div .media.provider> 277 | <a href=@{provider}> 278 | <div .media-left .container> 279 | <img .provider-logo src=#{providerLogoUrl info}> 280 | <div .media-body> 281 | <h3 .media-heading> 282 | #{providerTitle info} 283 | #{providerDescr info} 284 | |] render 285 | -------------------------------------------------------------------------------- /stack-lts-14.yaml: -------------------------------------------------------------------------------- 1 | resolver: lts-14.27 2 | extra-deps: 3 | - hoauth2-1.11.0@rev:0 # For older resolvers on CI 4 | -------------------------------------------------------------------------------- /stack-nightly.yaml: -------------------------------------------------------------------------------- 1 | resolver: nightly-2021-08-06 2 | extra-deps: 3 | - hoauth2-1.16.0 4 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: lts-17.12 2 | -------------------------------------------------------------------------------- /stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 
3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: [] 7 | snapshots: 8 | - completed: 9 | size: 565712 10 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/17/6.yaml 11 | sha256: 4e5e581a709c88e3fe26a9ce8bf331435729bead762fb5c190064c6c5bb1b835 12 | original: lts-17.6 13 | -------------------------------------------------------------------------------- /test/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# OPTIONS_GHC -fno-warn-orphans #-} 3 | module Main (main) where 4 | 5 | import Test.Tasty 6 | import qualified Spec.Network.Wai.Auth.Internal 7 | import qualified Spec.Network.Wai.Middleware.Auth.OAuth2 8 | import qualified Spec.Network.Wai.Middleware.Auth.OIDC 9 | 10 | main :: IO () 11 | main = defaultMain tests 12 | 13 | tests :: TestTree 14 | tests = testGroup "wai-middleware-auth" 15 | [ Spec.Network.Wai.Auth.Internal.tests 16 | , Spec.Network.Wai.Middleware.Auth.OAuth2.tests 17 | , Spec.Network.Wai.Middleware.Auth.OIDC.tests 18 | ] 19 | -------------------------------------------------------------------------------- /test/Network/Wai/Auth/Test.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | 4 | module Network.Wai.Auth.Test 5 | (ChangeProvider 6 | , FakeProviderConf(..) 
, fakeProvider
, const200
, get
) where

import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import qualified Data.IORef as IORef
import qualified Crypto.JOSE as JOSE
import qualified Crypto.JWT as JWT
import qualified Control.Monad.Except
import qualified Data.Aeson as Aeson
import Data.Function ((&))
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TLE
import qualified Data.Time.Clock as Clock
import GHC.Exts (fromString)
import qualified Network.HTTP.Types.Status as Status
import qualified Network.OAuth.OAuth2 as OA2
import qualified Network.Wai as Wai
import Network.Wai.Auth.Internal (Metadata(..))
import Network.Wai.Test (Session, SResponse,
                         defaultRequest,
                         request, setPath)
import qualified Lens.Micro as Lens
import qualified URI.ByteString as U

-- | Send a request for the given path (built from 'defaultRequest' with
-- 'setPath') inside a test 'Session'.
get :: ByteString -> Session SResponse
get path = request (setPath defaultRequest path)

-- | Application that answers every request with an empty @200@ response.
const200 :: Wai.Application
const200 _request sendResponse =
  sendResponse (Wai.responseLBS Status.ok200 [] "")

-- | Knobs controlling what the fake identity provider returns.
data FakeProviderConf = FakeProviderConf
  { jwtExpiresIn :: Clock.NominalDiffTime
    -- ^ Offset added to the current time to form the ID token's @exp@ claim.
  , jwtAudience :: JWT.StringOrURI
    -- ^ Sole entry of the ID token's @aud@ claim.
  , jwtIssuer :: T.Text
    -- ^ Used for the ID token's @iss@ claim and the discovery document's issuer.
  , jwtJWK :: JOSE.JWK
    -- ^ Key that signs ID tokens and is served from the @/jwks@ route.
  , jwtSub :: String
    -- ^ The ID token's @sub@ claim.
  , accessTokenExpiresIn :: Int
    -- ^ Value placed in the token response's @expiresIn@ field.
  , returnIdToken :: Bool
    -- ^ Whether the token response carries an ID token.
  , returnRefreshToken :: Bool
    -- ^ Whether the token response carries a refresh token.
  }

-- | Baseline provider configuration, built around a freshly generated RSA key.
defaultConfig :: IO FakeProviderConf
defaultConfig = withSigningKey <$> JOSE.genJWK (JOSE.RSAGenParam 256)
  where
    withSigningKey signingKey =
      FakeProviderConf
        { jwtExpiresIn = 600
        , jwtAudience = "client-id"
        , jwtIssuer = "test-oidc-provider"
        , jwtJWK = signingKey
        , jwtSub = "1234"
        , accessTokenExpiresIn = 600
        , returnIdToken = True
        , returnRefreshToken = True
        }

-- | Action handed to tests for adjusting the running provider's configuration.
type ChangeProvider = (FakeProviderConf -> FakeProviderConf) -> Session ()

-- | Create a fake provider application together with the action that mutates
-- its configuration. Both share one 'IORef.IORef' seeded from 'defaultConfig'.
fakeProvider :: IO (Wai.Application, ChangeProvider)
fakeProvider = do
  configRef <- IORef.newIORef =<< defaultConfig
  let applyChange modify = liftIO (IORef.modifyIORef configRef modify)
  pure (fakeProvider' configRef, applyChange)

-- | The provider application itself. The configuration is re-read on every
-- request. Routes served:
--
-- * @/.well-known/openid-configuration@: discovery metadata as JSON, or an
--   empty @400@ when the request has no @Host@ header;
-- * @/jwks@: the configured key wrapped in a JWK set;
-- * @/token@: a token response shaped by the configuration;
-- * anything else: an empty @404@.
fakeProvider' :: IORef.IORef FakeProviderConf -> Wai.Application
fakeProvider' configRef req respond = do
  config <- IORef.readIORef configRef
  case Wai.pathInfo req of
    [".well-known", "openid-configuration"] ->
      respond $
        maybe
          (Wai.responseLBS Status.badRequest400 [] "")
          (jsonOk . discoveryDocument config . TE.decodeUtf8)
          (Wai.requestHeaderHost req)
    ["jwks"] ->
      respond (jsonOk (JOSE.JWKSet [jwtJWK config]))
    ["token"] -> do
      issuedAt <- Clock.getCurrentTime
      let expiresAt = Clock.addUTCTime (jwtExpiresIn config) issuedAt
          claims =
            JWT.emptyClaimsSet
              & Lens.set JWT.claimIss (Just (fromString (T.unpack (jwtIssuer config))))
              & Lens.set JWT.claimAud (Just (JWT.Audience [jwtAudience config]))
              & Lens.set JWT.claimIat (Just (JWT.NumericDate issuedAt))
              & Lens.set JWT.claimExp (Just (JWT.NumericDate expiresAt))
              & Lens.set JWT.claimSub (Just (fromString (jwtSub config)))
      -- The token is signed unconditionally; 'returnIdToken' only decides
      -- whether it ends up in the response.
      signedIdToken <- doJwtSign (jwtJWK config) claims
      respond . jsonOk $
        OA2.OAuth2Token
          { OA2.accessToken = OA2.AccessToken "access-granted"
          , OA2.refreshToken =
              includeIf (returnRefreshToken config) (OA2.RefreshToken "refresh-token")
          , OA2.expiresIn = Just (accessTokenExpiresIn config)
          , OA2.tokenType = Nothing
          , OA2.idToken =
              includeIf (returnIdToken config) (OA2.IdToken signedIdToken)
          }
    _ ->
      respond (Wai.responseLBS Status.notFound404 [] "")
  where
    -- Encode a value as the body of a @200@ JSON response.
    jsonOk :: Aeson.ToJSON body => body -> Wai.Response
    jsonOk =
      Wai.responseLBS Status.ok200 [("Content-Type", "application/json")]
        . Aeson.encode

    -- Wrap a value in 'Just' only when the flag is set.
    includeIf :: Bool -> a -> Maybe a
    includeIf enabled value
      | enabled = Just value
      | otherwise = Nothing

    -- Discovery metadata whose endpoints point back at the requesting host.
    discoveryDocument :: FakeProviderConf -> T.Text -> Metadata
    discoveryDocument config host =
      Metadata
        { issuer = jwtIssuer config
        , authorizationEndpoint = parseURI (baseUrl <> "/authorize")
        , tokenEndpoint = parseURI (baseUrl <> "/token")
        , userinfoEndpoint = Nothing
        , revocationEndpoint = Nothing
        , jwksUri = baseUrl <> "/jwks"
        , responseTypesSupported = ["code"]
        , subjectTypesSupported = ["public"]
        , idTokenSigningAlgValuesSupported = ["RS256"]
        , scopesSupported = Just ["openid"]
        , tokenEndpointAuthMethodsSupported = Just ["client_secret_basic"]
        , claimsSupported = Just ["iss", "sub", "aud", "exp", "iat"]
        }
      where
        baseUrl = "http://" <> host

-- | Sign the claims with the given key, using the algorithm 'JOSE.bestJWSAlg'
-- picks for it, and return the compact serialisation as strict text. A signing
-- error is raised through 'fail' with the error's 'show' output.
doJwtSign :: JOSE.JWK -> JWT.ClaimsSet -> IO T.Text
doJwtSign jwk claims =
  Control.Monad.Except.runExceptT sign >>= \outcome ->
    case outcome of
      Left signingError -> fail (show (signingError :: JOSE.Error))
      Right signedJwt ->
        pure (TL.toStrict (TLE.decodeUtf8 (JOSE.encodeCompact signedJwt)))
  where
    sign = do
      algorithm <- JOSE.bestJWSAlg jwk
      JWT.signClaims jwk (JOSE.newJWSHeader ((), algorithm)) claims

-- | Parse an absolute URI with the lax parser options; calls 'error' with the
-- shown parse failure when the text does not parse.
parseURI :: T.Text -> U.URIRef U.Absolute
parseURI uri =
  case U.parseURI U.laxURIParserOptions (TE.encodeUtf8 uri) of
    Left parseFailure -> error (show parseFailure)
    Right parsed -> parsed
--------------------------------------------------------------------------------
/test/Spec/Network/Wai/Auth/Internal.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Spec.Network.Wai.Auth.Internal (tests) where

import Data.Binary (encode, decodeOrFail)
import qualified Data.ByteString.Lazy.Char8 as BSL8
import qualified Data.Text as T
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.Hedgehog (testProperty)
import Hedgehog
import Hedgehog.Gen as Gen
import Hedgehog.Range as Range
import Network.Wai.Auth.Internal
import qualified Network.OAuth.OAuth2.Internal as OA2

-- | Property tests for "Network.Wai.Auth.Internal".
tests :: TestTree
tests =
  testGroup
    "Network.Wai.Auth.Internal"
    [ testProperty "oAuth2TokenBinaryDuality" oAuth2TokenBinaryDuality
    ]

-- | A generated token must survive two round trips: through the 'Data.Binary'
-- 'encode' / 'decodeOrFail' pair (with no bytes left over), and through
-- 'encodeToken' / 'decodeToken'.
oAuth2TokenBinaryDuality :: Property
oAuth2TokenBinaryDuality = property $ do
  token <- forAll oauth2TokenBinary
  tripping token encode (fmap requireFullyConsumed . decodeOrFail)
  tripping token (encodeToken . unOAuth2TokenBinary) (fmap OAuth2TokenBinary . decodeToken)
  where
    -- Keep the decoded value only when the decoder consumed all of its input.
    requireFullyConsumed (leftover, _, decoded)
      | BSL8.null leftover = decoded
      | otherwise =
          error $ "Unexpected unconsumed in bytes: " <> BSL8.unpack leftover

-- | Generator for wrapped OAuth2 tokens; every optional field may be absent.
oauth2TokenBinary :: Gen OAuth2TokenBinary
oauth2TokenBinary = do
  access <- OA2.AccessToken <$> anyText
  refresh <- Gen.maybe (OA2.RefreshToken <$> anyText)
  lifetime <- Gen.maybe (Gen.int (Range.linear 0 1000))
  kind <- Gen.maybe anyText
  identity <- Gen.maybe (OA2.IdToken <$> anyText)
  pure (OAuth2TokenBinary (OA2.OAuth2Token access refresh lifetime kind identity))

-- | Unicode text of length 0 to 100 (linear range).
anyText :: Gen T.Text
anyText = Gen.text (Range.linear 0 100) Gen.unicodeAll

-- The `OAuth2Token` type from the `hoauth2` library does not have an `Eq`
-- instance, and its constituent parts don't have a `Generic` instance. Hence
-- this orphan instance here, which compares the unwrapped fields one by one.
instance Eq OAuth2TokenBinary where
  OAuth2TokenBinary lhs == OAuth2TokenBinary rhs = fields lhs == fields rhs
    where
      fields token =
        ( OA2.atoken (OA2.accessToken token)
        , OA2.rtoken <$> OA2.refreshToken token
        , OA2.expiresIn token
        , OA2.tokenType token
        , OA2.idtoken <$> OA2.idToken token
        )
--------------------------------------------------------------------------------
/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings #-}

module Spec.Network.Wai.Middleware.Auth.OAuth2 (tests) where

import Control.Monad (void)
import Data.Function ((&))
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import GHC.Exts (fromList)
import qualified Network.HTTP.Types.Status as Status
import qualified Network.Wai as Wai
import Network.Wai.Auth.Test (ChangeProvider,
                              FakeProviderConf(..),
                              fakeProvider,
                              const200, get)
import qualified Network.Wai.Handler.Warp as Warp
import qualified
Network.Wai.Middleware.Auth as Auth
import Network.Wai.Middleware.Auth.OAuth2 (OAuth2(..),
                                           getAccessToken)
import Network.Wai.Middleware.Auth.Provider (Provider(..),
                                             ProviderInfo(..))
import Network.Wai.Test (Session, assertHeader,
                         assertStatus,
                         runSession,
                         setClientCookie)
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (testCase)
import qualified Web.Cookie as Cookie
import qualified Web.ClientSession

-- | Tests that run the auth middleware, configured with the OAuth2 provider,
-- against the fake provider application.
tests :: TestTree
tests =
  testGroup
    "Network.Wai.Auth.OAuth2"
    [ testCase "when a request without a session is made then redirect to re-authorize" $
        runSessionWithProvider const200 $ \providerUrl _ -> do
          toPrefix <- get "/hi"
          assertStatus 303 toPrefix
          assertHeader "Location" "/prefix" toPrefix
          toProvider <- get "/prefix"
          assertStatus 303 toProvider
          assertHeader "location" "/prefix/oauth2" toProvider
          toAuthorize <- get "/prefix/oauth2"
          assertStatus 303 toAuthorize
          assertHeader
            "location"
            (TE.encodeUtf8 providerUrl <> "/authorize?scope=scope1%20scope2&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foauth2%2Fcomplete")
            toAuthorize

    , testCase "when a request is made with a valid session then pass the request through" $
        runSessionWithProvider const200 $ \_ _ -> do
          createSession
          get "/some/endpoint" >>= assertStatus 200

    , testCase "when an access token expired and no refresh token is available then redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          changeProvider $ \conf ->
            conf { accessTokenExpiresIn = -600, returnRefreshToken = False }
          createSession
          get "/some/endpoint" >>= assertStatus 303

    , testCase "when an access token expired then use a refresh token" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          changeProvider $ \conf -> conf { accessTokenExpiresIn = -600 }
          createSession
          get "/some/endpoint" >>= assertStatus 200

    , testCase "when a request is made with an invalid session redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ _ -> do
          -- Start from a known-good session, so that the redirect asserted
          -- below can only be caused by overwriting the cookie with garbage.
          createSession
          let corruptedCookie =
                Cookie.defaultSetCookie
                  { Cookie.setCookieName = "auth-cookie"
                  , Cookie.setCookieValue = "garbage"
                  }
          setClientCookie corruptedCookie
          get "/some/endpoint" >>= assertStatus 303

    , testCase "when a request is made to the complete endpoint then create a session" $
        runSessionWithProvider const200 $ \_ _ -> do
          completion <- get "/prefix/oauth2/complete?code=1234"
          assertStatus 303 completion
          assertHeader "location" "/" completion

    , testCase "when a request with a valid session is made then the app can access the session" $
        runSessionWithProvider tokenGatedApp $ \_ _ -> do
          createSession
          get "/some/endpoint" >>= assertStatus 200
    ]
  where
    -- Answers 200 only when 'getAccessToken' finds a token on the request,
    -- and 400 otherwise.
    tokenGatedApp :: Wai.Application
    tokenGatedApp req respond =
      respond $
        Wai.responseLBS
          (maybe Status.badRequest400 (const Status.ok200) (getAccessToken req))
          []
          ""

-- | Hit the provider's completion endpoint and discard the response.
createSession :: Session ()
createSession = void (get "/prefix/oauth2/complete?code=1234")

-- | Middleware settings: one provider registered as @oauth2@, the @prefix@
-- auth prefix, the @auth-cookie@ cookie name and a random session key.
authSettings :: T.Text -> Auth.AuthSettings
authSettings host =
  Auth.defaultAuthSettings
    & Auth.setAuthProviders providers
    & Auth.setAuthPrefix "prefix"
    & Auth.setAuthCookieName "auth-cookie"
    & Auth.setAuthKey freshKey
  where
    providers = fromList [("oauth2", provider host)]
    freshKey = snd <$> Web.ClientSession.randomKey

-- | OAuth2 provider whose authorize and token endpoints live under @host@.
provider :: T.Text -> Provider
provider host =
  Provider
    OAuth2
      { oa2ClientId = "client-id"
      , oa2ClientSecret = "client-secret"
      , oa2AuthorizeEndpoint = endpoint "/authorize"
      , oa2AccessTokenEndpoint = endpoint "/token"
      , oa2Scope = Just ["scope1", "scope2"]
      , oa2ProviderInfo =
          ProviderInfo
            { providerTitle = ""
            , providerLogoUrl = ""
            , providerDescr = ""
            }
      }
  where
    endpoint path = host <> path

-- | Serve the fake provider with 'Warp.testWithApplication', wrap @app@ in the
-- auth middleware configured for that provider's URL, and run the session
-- against the wrapped application. The session receives the provider's URL
-- and the action for changing the provider's configuration.
runSessionWithProvider :: Wai.Application -> (T.Text -> ChangeProvider -> Session a) -> IO a
runSessionWithProvider app session = do
  (providerApp, changeProvider) <- fakeProvider
  Warp.testWithApplication (pure providerApp) $ \port -> do
    let providerUrl = "http://localhost:" <> T.pack (show port)
    protect <- Auth.mkAuthMiddleware (authSettings providerUrl)
    runSession (session providerUrl changeProvider) (protect app)
--------------------------------------------------------------------------------
/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Spec.Network.Wai.Middleware.Auth.OIDC (tests) where

import Control.Monad (void)
import Control.Monad.IO.Class (liftIO)
import qualified Crypto.JOSE as JOSE
import Data.Function ((&))
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import GHC.Exts (fromList, fromString)
import qualified Network.HTTP.Types.Status as Status
import qualified Network.Wai as Wai
import Network.Wai.Auth.Test (ChangeProvider,
                              FakeProviderConf(..),
                              fakeProvider,
                              const200, get)
import qualified Network.Wai.Handler.Warp as Warp
import qualified Network.Wai.Middleware.Auth as Auth
import Network.Wai.Middleware.Auth.OIDC
import Network.Wai.Middleware.Auth.Provider (Provider(..))
import Network.Wai.Test (Session, assertHeader,
                         assertStatus,
                         runSession,
                         setClientCookie)
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (testCase)
import qualified Web.Cookie as
Cookie
import qualified Web.ClientSession

-- | Tests that run the auth middleware, configured with the OpenID Connect
-- provider, against the fake provider application.
tests :: TestTree
tests =
  testGroup
    "Network.Wai.Auth.OIDC"
    [ testCase "when a request without a session is made then redirect to re-authorize" $
        runSessionWithProvider const200 $ \providerUrl _ -> do
          toPrefix <- get "/hi"
          assertStatus 303 toPrefix
          assertHeader "Location" "/prefix" toPrefix
          toProvider <- get "/prefix"
          assertStatus 303 toProvider
          assertHeader "location" "/prefix/oidc" toProvider
          toAuthorize <- get "/prefix/oidc"
          assertStatus 303 toAuthorize
          assertHeader
            "location"
            (TE.encodeUtf8 providerUrl <> "/authorize?scope=openid%20scope1&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foidc%2Fcomplete")
            toAuthorize

    , testCase "when a request is made with a valid session then pass the request through" $
        runSessionWithProvider const200 $ \_ _ -> do
          createSession
          get "/some/endpoint" >>= assertStatus 200

    , testCase "when an ID token expired and no refresh token is available then redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          changeProvider $ \conf ->
            conf { jwtExpiresIn = -600, returnRefreshToken = False }
          createSession
          get "/some/endpoint" >>= assertStatus 303

    , testCase "when an ID token expired then use a refresh token" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          changeProvider $ \conf -> conf { jwtExpiresIn = -600 }
          createSession
          -- Tokens issued from here on carry a positive expiry offset again.
          changeProvider $ \conf -> conf { jwtExpiresIn = 600 }
          get "/some/endpoint" >>= assertStatus 200

    , testCase "when a request is made with an invalid session redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ _ -> do
          -- Start from a known-good session, so that the redirect asserted
          -- below can only be caused by overwriting the cookie with garbage.
          createSession
          let corruptedCookie =
                Cookie.defaultSetCookie
                  { Cookie.setCookieName = "auth-cookie"
                  , Cookie.setCookieValue = "garbage"
                  }
          setClientCookie corruptedCookie
          get "/some/endpoint" >>= assertStatus 303

    , testCase "when a request is made to the complete endpoint then create a session" $
        runSessionWithProvider const200 $ \_ _ -> do
          completion <- get "/prefix/oidc/complete?code=1234"
          assertStatus 303 completion
          assertHeader "location" "/" completion

    , testCase "when a request with a valid session is made then the app can access the access token" $
        runSessionWithProvider (requiring getAccessToken) $ \_ _ -> do
          createSession
          get "/some/endpoint" >>= assertStatus 200

    , testCase "when a request with a valid session is made then the app can access the id token" $
        runSessionWithProvider (requiring getIdToken) $ \_ _ -> do
          createSession
          get "/some/endpoint" >>= assertStatus 200

    , testCase "when an ID token has an invalid audience then redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          changeProvider $ \conf -> conf { jwtAudience = fromString "wrong-audience" }
          createSession
          get "/some/endpoint" >>= assertStatus 303

    , testCase "when an ID token has an invalid issuer then redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          changeProvider $ \conf -> conf { jwtIssuer = "wrong-issuer" }
          createSession
          get "/some/endpoint" >>= assertStatus 303

    , testCase "when a session does not contain an ID token then redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          changeProvider $ \conf -> conf { returnIdToken = False }
          createSession
          get "/some/endpoint" >>= assertStatus 303

    , testCase "when an ID token has an invalid signature then redirect to re-authorize" $
        runSessionWithProvider const200 $ \_ changeProvider -> do
          -- The provider is switched to this key only after the middleware
          -- settings were built by 'runSessionWithProvider'.
          replacementKey <- liftIO (JOSE.genJWK (JOSE.RSAGenParam 256))
          changeProvider $ \conf -> conf { jwtJWK = replacementKey }
          createSession
          get "/some/endpoint" >>= assertStatus 303
    ]
  where
    -- Application answering 200 only when the lookup finds a token on the
    -- request, and 400 otherwise.
    requiring :: (Wai.Request -> Maybe token) -> Wai.Application
    requiring lookupToken req respond =
      respond $
        Wai.responseLBS
          (maybe Status.badRequest400 (const Status.ok200) (lookupToken req))
          []
          ""

-- | Hit the provider's completion endpoint and discard the response.
createSession :: Session ()
createSession = void (get "/prefix/oidc/complete?code=1234")

-- | Serve the fake provider with 'Warp.testWithApplication', wrap @app@ in the
-- auth middleware configured for that provider's URL, and run the session
-- against the wrapped application. The session receives the provider's URL
-- and the action for changing the provider's configuration.
runSessionWithProvider :: Wai.Application -> (T.Text -> ChangeProvider -> Session a) -> IO a
runSessionWithProvider app session = do
  (providerApp, changeProvider) <- fakeProvider
  Warp.testWithApplication (pure providerApp) $ \port -> do
    let providerUrl = "http://localhost:" <> T.pack (show port)
    protect <- Auth.mkAuthMiddleware =<< authSettings providerUrl
    runSession (session providerUrl changeProvider) (protect app)

-- | Middleware settings built from the result of 'discover' on @host@, with
-- the test client credentials and scopes filled in. The provider is registered
-- as @oidc@ under the @prefix@ auth prefix, with the @auth-cookie@ cookie name
-- and a random session key.
authSettings :: T.Text -> IO Auth.AuthSettings
authSettings host = settingsFor <$> discover host
  where
    settingsFor discovered =
      Auth.defaultAuthSettings
        & Auth.setAuthProviders (fromList [("oidc", Provider (asTestClient discovered))])
        & Auth.setAuthPrefix "prefix"
        & Auth.setAuthCookieName "auth-cookie"
        & Auth.setAuthKey (snd <$> Web.ClientSession.randomKey)
    asTestClient discovered =
      discovered
        { oidcClientId = "client-id"
        , oidcClientSecret = "client-secret"
        , oidcScopes = ["openid", "scope1"]
        }
--------------------------------------------------------------------------------
/wai-middleware-auth.cabal:
-------------------------------------------------------------------------------- 1 | cabal-version: 1.18 2 | name: wai-middleware-auth 3 | version: 0.2.6.0 4 | synopsis: Authentication middleware that secures WAI application 5 | description: Please see the README and Haddocks at <https://www.stackage.org/package/wai-middleware-auth> 6 | license: MIT 7 | license-file: LICENSE 8 | author: Alexey Kuleshevich 9 | maintainer: alexey@fpcomplete.com 10 | category: Web 11 | build-type: Simple 12 | extra-doc-files: README.md CHANGELOG.md 13 | 14 | library 15 | exposed-modules: Network.Wai.Middleware.Auth 16 | Network.Wai.Middleware.Auth.OAuth2 17 | Network.Wai.Middleware.Auth.OAuth2.Gitlab 18 | Network.Wai.Middleware.Auth.OAuth2.Github 19 | Network.Wai.Middleware.Auth.OAuth2.Google 20 | Network.Wai.Middleware.Auth.OIDC 21 | Network.Wai.Middleware.Auth.Provider 22 | Network.Wai.Auth.Executable 23 | Network.Wai.Auth.Internal 24 | other-modules: Paths_wai_middleware_auth 25 | Network.Wai.Auth.AppRoot 26 | Network.Wai.Auth.Config 27 | Network.Wai.Auth.ClientSession 28 | Network.Wai.Auth.Tools 29 | build-depends: aeson 30 | , base >= 4.12 && < 5 31 | , base64-bytestring 32 | , binary 33 | , blaze-builder 34 | , blaze-html 35 | , bytestring 36 | , case-insensitive 37 | , cereal 38 | , clientsession 39 | , cookie >= 0.4.2 40 | , exceptions 41 | , hoauth2 >= 1.11 42 | , http-client 43 | , http-client-tls 44 | , http-conduit 45 | , http-reverse-proxy 46 | , http-types 47 | , jose >= 0.8.0 48 | , microlens 49 | , mtl 50 | , regex-posix 51 | , safe-exceptions 52 | , shakespeare 53 | , text 54 | , time 55 | , unix-compat 56 | , unordered-containers 57 | , uri-bytestring 58 | , vault 59 | , wai >= 3.0 && < 4 60 | , wai-app-static 61 | , wai-extra >= 3.0.7 62 | , yaml 63 | hs-source-dirs: src 64 | default-language: Haskell2010 65 | ghc-options: -Wall 66 | 67 | executable wai-auth 68 | default-language: Haskell2010 69 | hs-source-dirs: app 70 | main-is: Main.hs 71 | build-depends: base 72 
| , bytestring 73 | , cereal 74 | , clientsession 75 | , optparse-simple 76 | , optparse-applicative 77 | , wai-extra 78 | , wai-middleware-auth 79 | , warp 80 | ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N 81 | 82 | test-suite spec 83 | default-language: Haskell2010 84 | type: exitcode-stdio-1.0 85 | main-is: Main.hs 86 | hs-source-dirs: test 87 | other-modules: Network.Wai.Auth.Test 88 | , Spec.Network.Wai.Auth.Internal 89 | , Spec.Network.Wai.Middleware.Auth.OAuth2 90 | , Spec.Network.Wai.Middleware.Auth.OIDC 91 | build-depends: base 92 | , aeson 93 | , binary 94 | , bytestring 95 | , clientsession 96 | , cookie 97 | , hedgehog 98 | , hoauth2 99 | , http-types 100 | , jose 101 | , microlens 102 | , mtl 103 | , tasty 104 | , tasty-hedgehog 105 | , tasty-hunit 106 | , text 107 | , time 108 | , uri-bytestring 109 | , wai 110 | , wai-extra 111 | , wai-middleware-auth 112 | , warp 113 | ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N 114 | 115 | source-repository head 116 | type: git 117 | location: https://github.com/fpco/wai-middleware-auth 118 | --------------------------------------------------------------------------------