├── .gitignore ├── Makefile ├── .travis.yml ├── go.mod ├── pkg ├── authorizer │ ├── context.go │ ├── mocks.go │ ├── builder │ │ ├── default_context_builder.go │ │ └── default_policy_builder.go │ ├── response.go │ ├── jwt_test.go │ ├── utils.go │ ├── jwt.go │ └── response_test.go └── request │ └── auth │ ├── authorizer.go │ ├── README.md │ ├── cognito.go │ └── cognito_test.go ├── LICENSE ├── go.sum └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | go test ./... -covermode=count 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.12 5 | - master 6 | 7 | script: 8 | - env GO111MODULE=on go test -v ./... 9 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/nordcloud/cognito-authorizer 2 | 3 | go 1.11 4 | 5 | require ( 6 | github.com/aws/aws-lambda-go v1.9.0 7 | github.com/aws/aws-sdk-go v1.18.6 8 | github.com/dgrijalva/jwt-go v3.2.0+incompatible 9 | github.com/pkg/errors v0.8.1 10 | github.com/sirupsen/logrus v1.4.0 11 | github.com/stretchr/testify v1.3.0 12 | golang.org/x/net v0.0.0-20190628185345-da137c7871d7 // indirect 13 | ) 14 | -------------------------------------------------------------------------------- /pkg/authorizer/context.go: -------------------------------------------------------------------------------- 1 | package authorizer 2 | 3 | /* 4 | Package abstracts out work needed to retrieve AWS Cognito JW token claims. 5 | */ 6 | 7 | // Context is a preset of data needed to build a response. 
8 | type Context struct { 9 | Region string 10 | ApplicationID string 11 | Stage string 12 | AllowedUserPoolID string 13 | CognitoClients []string 14 | DecryptionKeys []JWKey 15 | } 16 | -------------------------------------------------------------------------------- /pkg/request/auth/authorizer.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | /* 4 | This package delivers the Signer that can be used to sign the http Request 5 | 6 | For now only Cognito M2M signer is implemented. It would be nice to add other ones like IAM signer 7 | */ 8 | 9 | import ( 10 | "net/http" 11 | ) 12 | 13 | // HeaderAdder is an interface to setup Authorization HTTP header. 14 | type HeaderAdder interface { 15 | Add(key, value string) 16 | } 17 | 18 | // RequestAuthorizer interface delivers method to authorize the http.Request 19 | type RequestAuthorizer interface { 20 | AuthorizeRequest(*http.Request) (*http.Request, error) 21 | AddAuthorizationHeader(headerAdder HeaderAdder) error 22 | } 23 | -------------------------------------------------------------------------------- /pkg/authorizer/mocks.go: -------------------------------------------------------------------------------- 1 | package authorizer 2 | 3 | import ( 4 | "github.com/aws/aws-lambda-go/events" 5 | "github.com/stretchr/testify/mock" 6 | ) 7 | 8 | type policyBuilderMock struct { 9 | mock.Mock 10 | } 11 | 12 | func (m *policyBuilderMock) BuildPolicy(encodedToken string) (events.APIGatewayCustomAuthorizerPolicy, error) { 13 | args := m.Called(encodedToken) 14 | return args.Get(0).(events.APIGatewayCustomAuthorizerPolicy), args.Error(1) 15 | } 16 | 17 | type contextBuilderMock struct { 18 | mock.Mock 19 | } 20 | 21 | func (m *contextBuilderMock) BuildContext(encodedToken string) (map[string]interface{}, error) { 22 | args := m.Called(encodedToken) 23 | return args.Get(0).(map[string]interface{}), args.Error(1) 24 | } 25 | 
-------------------------------------------------------------------------------- /pkg/request/auth/README.md: -------------------------------------------------------------------------------- 1 | ## Cognito M2M signer 2 | 3 | The Cognito M2M signer generates an authorization token using the Cognito app client secret stored in SSM Parameter Store. It can be used to authorize a Lambda function or another compute resource with Cognito machine-to-machine (OAuth 2.0 client credentials) authorization. 4 | 5 | Usage example: 6 | 7 | ``` 8 | ssmSession := session.Must(session.NewSession(&aws.Config{ 9 | Region: aws.String(os.Getenv("REGION")), 10 | })) 11 | sesCli := ssm.New(ssmSession) 12 | 13 | &auth.CognitoM2MAuthorizer{ 14 | CognitoAPIURL: os.Getenv("COGNITO_API_URL"), 15 | ClientID: os.Getenv("COGNITO_APP_ID"), 16 | Scope: "https://scope_identifier_url/full-access", 17 | SsmClient: sesCli, 18 | SsmSecretName: "cognitoM2mSecret", 19 | } 20 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Nordcloud 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pkg/authorizer/builder/default_context_builder.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/nordcloud/cognito-authorizer/pkg/authorizer" 7 | log "github.com/sirupsen/logrus" 8 | ) 9 | 10 | // DefaultContextBuilder implements the ContextBuilder interface 11 | // It creates context with the list of scopes when using M2M authorization 12 | type DefaultContextBuilder struct { 13 | Context *authorizer.Context 14 | } 15 | 16 | // BuildContext builds a context that is passed to resource server. 17 | func (c *DefaultContextBuilder) BuildContext(encodedToken string) (map[string]interface{}, error) { 18 | baseClaims := authorizer.BaseTokenClaims{} 19 | err := authorizer.GetBaseClaims(encodedToken, c.Context.DecryptionKeys, &baseClaims) 20 | if err != nil { 21 | log.Error("Failed to get base claims.") 22 | return map[string]interface{}{}, err 23 | } 24 | 25 | if baseClaims.TokenUse == "access" { 26 | accessClaims := authorizer.AccessTokenClaims{} 27 | err := authorizer.GetAccessClaims(encodedToken, c.Context.DecryptionKeys, &accessClaims) 28 | if err != nil { 29 | return map[string]interface{}{}, err 30 | } 31 | 32 | return c.buildContextForAccessClaims(accessClaims) 33 | } 34 | 35 | idClaims := authorizer.IDTokenClaims{} 36 | err = authorizer.GetIDClaims(encodedToken, c.Context.DecryptionKeys, &idClaims) 37 | 38 | if err != nil { 39 | log.Error("Failed to get id claims.") 40 | return map[string]interface{}{}, err 41 | } 42 | 43 | return c.buildContextForIDClaims(idClaims) 44 | } 45 | 46 | func (c 
*DefaultContextBuilder) buildContextForAccessClaims(claims authorizer.AccessTokenClaims) (map[string]interface{}, error) { 47 | var scopesWithoutPrefix []string 48 | 49 | scopes := strings.Split(claims.Scope, " ") 50 | for _, s := range scopes { 51 | scopeStr := getScopeFromFullString(s) 52 | scopesWithoutPrefix = append(scopesWithoutPrefix, scopeStr) 53 | } 54 | 55 | return map[string]interface{}{ 56 | "scope": strings.Join(scopesWithoutPrefix, " "), 57 | }, nil 58 | } 59 | 60 | func (c *DefaultContextBuilder) buildContextForIDClaims(claims authorizer.IDTokenClaims) (map[string]interface{}, error) { 61 | return map[string]interface{}{ 62 | "email": claims.Email, 63 | }, nil 64 | } 65 | 66 | func getScopeFromFullString(scopeString string) string { 67 | segments := strings.Split(scopeString, "/") 68 | return segments[len(segments)-1] 69 | } 70 | -------------------------------------------------------------------------------- /pkg/authorizer/response.go: -------------------------------------------------------------------------------- 1 | package authorizer 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/aws/aws-lambda-go/events" 7 | log "github.com/sirupsen/logrus" 8 | ) 9 | 10 | // PolicyBuilder interface for building API GW custom authorizer policy. 11 | type PolicyBuilder interface { 12 | BuildPolicy(encodedToken string) (events.APIGatewayCustomAuthorizerPolicy, error) 13 | } 14 | 15 | // ContextBuilder interface for building context passed to resource server. 16 | type ContextBuilder interface { 17 | BuildContext(encodedToken string) (map[string]interface{}, error) 18 | } 19 | 20 | // ResponseBuilder struct for building proper custom authorizer response. 21 | type ResponseBuilder struct { 22 | Context *Context 23 | PolicyBuilder PolicyBuilder 24 | ContextBuilder ContextBuilder 25 | } 26 | 27 | // BuildResponse builds a proper custom authorizer response based on context, policy and context builders. 
28 | func (b ResponseBuilder) BuildResponse(encodedToken string) (events.APIGatewayCustomAuthorizerResponse, error) { 29 | baseClaims := &BaseTokenClaims{} 30 | err := GetBaseClaims(encodedToken, b.Context.DecryptionKeys, baseClaims) 31 | if err != nil { 32 | log.WithField("error", err).Info("Failed to get token standard claims.") 33 | return events.APIGatewayCustomAuthorizerResponse{}, errors.New("Unauthorized") 34 | } 35 | 36 | valid := false 37 | for _, client := range b.Context.CognitoClients { 38 | valid = valid || baseClaims.VerifyAudience(client, true) || baseClaims.TokenUse == "access" // Only ID Token has audience field. 39 | } 40 | 41 | if !valid { 42 | log.WithField("audience", baseClaims.Audience).Error("Failed to verify token audience.") 43 | return events.APIGatewayCustomAuthorizerResponse{}, errors.New("Unauthorized") 44 | } 45 | 46 | policy, err := b.PolicyBuilder.BuildPolicy(encodedToken) 47 | if err != nil { 48 | log.WithField("error", err).Error("Failed to build policy document.") 49 | return events.APIGatewayCustomAuthorizerResponse{}, errors.New("Unauthorized") 50 | } 51 | 52 | context, err := b.ContextBuilder.BuildContext(encodedToken) 53 | if err != nil { 54 | log.WithField("error", err).Error("Failed to build context.") 55 | return events.APIGatewayCustomAuthorizerResponse{}, errors.New("Unauthorized") 56 | } 57 | 58 | return events.APIGatewayCustomAuthorizerResponse{ 59 | PrincipalID: baseClaims.Subject, 60 | PolicyDocument: policy, 61 | Context: context, 62 | }, nil 63 | } 64 | -------------------------------------------------------------------------------- /pkg/authorizer/builder/default_policy_builder.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/aws/aws-lambda-go/events" 9 | "github.com/nordcloud/cognito-authorizer/pkg/authorizer" 10 | log "github.com/sirupsen/logrus" 11 | ) 12 | 13 | type PolicyEffect 
string 14 | 15 | const ( 16 | allow PolicyEffect = "allow" 17 | deny PolicyEffect = "deny" 18 | ) 19 | 20 | // DefaultPolicyBuilder Implements Policy builder interface. 21 | // It grants full access for the M2M authorized compute resources 22 | // Other methods like `TokenUse = id` are forbidden 23 | type DefaultPolicyBuilder struct { 24 | Context *authorizer.Context 25 | Region string 26 | } 27 | 28 | // BuildPolicy builds proper apigw policy based on scope from claims. 29 | func (p *DefaultPolicyBuilder) BuildPolicy(encodedToken string) (events.APIGatewayCustomAuthorizerPolicy, error) { 30 | baseClaims := authorizer.BaseTokenClaims{} 31 | err := authorizer.GetBaseClaims(encodedToken, p.Context.DecryptionKeys, &baseClaims) 32 | if err != nil { 33 | log.Error("Failed to get base claims.") 34 | return events.APIGatewayCustomAuthorizerPolicy{}, err 35 | } 36 | 37 | var resources []string 38 | 39 | log.WithField("token_use", baseClaims.TokenUse).Debug("Token type.") 40 | if baseClaims.TokenUse == "access" { 41 | accessClaims := authorizer.AccessTokenClaims{} 42 | err = authorizer.GetAccessClaims(encodedToken, p.Context.DecryptionKeys, &accessClaims) 43 | if err != nil { 44 | log.Error("Failed to get access claims.") 45 | return events.APIGatewayCustomAuthorizerPolicy{}, err 46 | } 47 | 48 | resources = p.buildResourcesForAccessClaims(accessClaims) 49 | 50 | } else if baseClaims.TokenUse == "id" { 51 | resources = []string{} 52 | } else { 53 | log.WithField("token_use", baseClaims.TokenUse).Error("Unkown token use. 
Aborting") 54 | return events.APIGatewayCustomAuthorizerPolicy{}, errors.New("unknown token use") 55 | } 56 | 57 | policy := events.APIGatewayCustomAuthorizerPolicy{ 58 | Version: "2012-10-17", 59 | Statement: []events.IAMPolicyStatement{ 60 | { 61 | Action: []string{"execute-api:Invoke"}, 62 | Effect: string(allow), 63 | Resource: resources, 64 | }, 65 | }, 66 | } 67 | 68 | return policy, nil 69 | } 70 | 71 | func (p *DefaultPolicyBuilder) buildResourcesForAccessClaims(claims authorizer.AccessTokenClaims) []string { 72 | scopes := strings.Split(claims.Scope, " ") 73 | log.WithField("scopes", scopes).Info("Generating access for the scope") 74 | return []string{fmt.Sprintf( 75 | "arn:aws:execute-api:%s:*:%s/%s/*/*", p.Region, p.Context.ApplicationID, p.Context.Stage, 76 | )} 77 | } 78 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/aws/aws-lambda-go v1.9.0 h1:r9TWtk8ozLYdMW+aelUeWny8z2mjghJCMx6/uUwOLNo= 2 | github.com/aws/aws-lambda-go v1.9.0/go.mod h1:zUsUQhAUjYzR8AuduJPCfhBuKWUaDbQiPOG+ouzmE1A= 3 | github.com/aws/aws-sdk-go v1.18.6 h1:NuUz/+bi6C5v3BpIXW/VfovfMpvlhl1WUnD0EiDkOwQ= 4 | github.com/aws/aws-sdk-go v1.18.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= 9 | github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= 10 | github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= 11 | github.com/jmespath/go-jmespath 
v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= 12 | github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= 13 | github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 14 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 15 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 16 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/sirupsen/logrus v1.4.0 h1:yKenngtzGh+cUSSh6GWbxW2abRqhYUSR/t/6+2QqNvE= 19 | github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= 20 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 21 | github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= 22 | github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 23 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 24 | github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= 25 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 26 | golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I= 27 | golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= 28 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= 29 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 30 | golang.org/x/net v0.0.0-20190628185345-da137c7871d7 h1:rTIdg5QFRR7XCaK4LCjBiPbx8j4DQRpdYMnGn/bJUEU= 31 | golang.org/x/net 
v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 32 | golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= 33 | golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 34 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= 35 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 36 | golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= 37 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 38 | -------------------------------------------------------------------------------- /pkg/authorizer/jwt_test.go: -------------------------------------------------------------------------------- 1 | package authorizer 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var testKeyServerResponseBody = ` 13 | { 14 | "keys": [ 15 | { 16 | "alg": "RS256", 17 | "e": "AQAB", 18 | "kid": "abcdefghijklmnopqrsexample=", 19 | "kty": "RSA", 20 | "n": "lsjhglskjhgslkjgh43lj5h34lkjh34lkjht3example", 21 | "use": "sig" 22 | }, { 23 | "alg": "RS256", 24 | "e": "AQAB", 25 | "kid": "123456789", 26 | "kty": "RSA", 27 | "n": "3Nzq67VGAE3RNBN9DWuK-eIQ8LscppizsW9G1U7pUmqOM3-FgYYlWS-cMyIDROzyGNM6R6n0hwTehxyMiX9Ucwf6Q2Z9z0OMb8I0m918CBAYC3NJKWlpxt7O3keZam_U7wY4woYGBt01epJGi5-dIq8N5X2yQ2kx654YfTzrBR-23u8TC_05E1sYyqKPZtO2aasHGC9lFQD9-B2LeBEBChnDpc9pb8JriDibA5NNh-4ZC8RjqBkKTLGphkTDJ28HXYjtwV0yZJ05zwKlW_YWSCdiIh_nzaVKVziboCBaJVVknCEy5brjvLy_5v0HGxRzyeA0xkCauinS2L57JfO_SQ", 28 | "use": "sig" 29 | } 30 | ] 31 | }` 32 | 33 | func createTestKeyServer(body string) *httptest.Server { 34 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 | w.WriteHeader(200) 36 | w.Write([]byte(body)) 37 | })) 
38 | } 39 | 40 | func TestGetDecryptionKey(t *testing.T) { 41 | testServer := createTestKeyServer(testKeyServerResponseBody) 42 | defer testServer.Close() 43 | 44 | keys, err := RequestKeys(testServer.URL) 45 | 46 | assert.Nil(t, err) 47 | assert.Equal(t, createTestKeys(), keys) 48 | } 49 | 50 | func TestGetDecryptionKeyNonRSAKeyError(t *testing.T) { 51 | body := ` 52 | { 53 | "keys": [ 54 | { 55 | "alg": "RS256", 56 | "e": "AQAB", 57 | "kid": "abcdefghijklmnopqrsexample=", 58 | "kty": "RSA", 59 | "n": "abcd", 60 | "use": "sig" 61 | }, { 62 | "alg": "RS256", 63 | "e": "AQAB", 64 | "kid": "123456789", 65 | "kty": "non-rsa-key", 66 | "n": "accd", 67 | "use": "sig" 68 | } 69 | ] 70 | }` 71 | testServer := createTestKeyServer(body) 72 | defer testServer.Close() 73 | 74 | keys, err := RequestKeys(testServer.URL) 75 | 76 | assert.NotNil(t, err) 77 | assert.Nil(t, keys) 78 | } 79 | 80 | func TestGetIDClaims(t *testing.T) { 81 | testEmail := "test@example.com" 82 | testSubject := "test-subject" 83 | testKeys := createTestKeys() 84 | idClaims := &IDTokenClaims{} 85 | token := createTestIDToken(testEmail, testSubject, "", nil) 86 | 87 | err := GetIDClaims(token, testKeys, idClaims) 88 | 89 | assert.Nil(t, err) 90 | assert.Equal(t, testEmail, idClaims.Email) 91 | assert.Equal(t, testSubject, idClaims.StandardClaims.Subject) 92 | } 93 | 94 | func TestGetIDClaimsExpired(t *testing.T) { 95 | testEmail := "test@example.com" 96 | testSubject := "test-subject" 97 | expiresAt := time.Now().Add(-time.Hour) 98 | testKeys := createTestKeys() 99 | idClaims := &IDTokenClaims{} 100 | token := createTestIDToken(testEmail, testSubject, "", &expiresAt) 101 | 102 | err := GetIDClaims(token, testKeys, idClaims) 103 | 104 | assert.NotNil(t, err) 105 | } 106 | 107 | func TestGetAccessClaims(t *testing.T) { 108 | testScope := "test-scope" 109 | testSubject := "test-subject" 110 | testKeys := createTestKeys() 111 | accessClaims := &AccessTokenClaims{} 112 | token := 
createTestAccessToken(testScope, testSubject, nil) 113 | 114 | err := GetAccessClaims(token, testKeys, accessClaims) 115 | 116 | assert.Nil(t, err) 117 | assert.Equal(t, testScope, accessClaims.Scope) 118 | assert.Equal(t, testSubject, accessClaims.StandardClaims.Subject) 119 | } 120 | 121 | func TestGetBaseClaims(t *testing.T) { 122 | testUse := "test-use" 123 | testSubject := "test-subject" 124 | testAudience := "test-audience" 125 | testKeys := createTestKeys() 126 | baseClaims := &BaseTokenClaims{} 127 | token := createTestBaseToken(testUse, testSubject, testAudience, nil) 128 | 129 | err := GetBaseClaims(token, testKeys, baseClaims) 130 | 131 | assert.Nil(t, err) 132 | assert.Equal(t, testUse, baseClaims.TokenUse) 133 | assert.Equal(t, testAudience, baseClaims.Audience) 134 | assert.Equal(t, testSubject, baseClaims.Subject) 135 | } 136 | -------------------------------------------------------------------------------- /pkg/authorizer/utils.go: -------------------------------------------------------------------------------- 1 | package authorizer 2 | 3 | import ( 4 | "time" 5 | 6 | jwt "github.com/dgrijalva/jwt-go" 7 | ) 8 | 9 | const rawKey = ` 10 | -----BEGIN RSA PRIVATE KEY----- 11 | MIIEpQIBAAKCAQEA3Nzq67VGAE3RNBN9DWuK+eIQ8LscppizsW9G1U7pUmqOM3+F 12 | gYYlWS+cMyIDROzyGNM6R6n0hwTehxyMiX9Ucwf6Q2Z9z0OMb8I0m918CBAYC3NJ 13 | KWlpxt7O3keZam/U7wY4woYGBt01epJGi5+dIq8N5X2yQ2kx654YfTzrBR+23u8T 14 | C/05E1sYyqKPZtO2aasHGC9lFQD9+B2LeBEBChnDpc9pb8JriDibA5NNh+4ZC8Rj 15 | qBkKTLGphkTDJ28HXYjtwV0yZJ05zwKlW/YWSCdiIh/nzaVKVziboCBaJVVknCEy 16 | 5brjvLy/5v0HGxRzyeA0xkCauinS2L57JfO/SQIDAQABAoIBAFrykb5MIC5B3RLv 17 | r4AWN91cTROESXEE0oIPS4DNBOFORY5JRcWnYrvOEikwKV55n9u/J3GZN7tdsvC6 18 | Pdjk2PahY1nb25S8wRjIRPemBcwgLHaSm5707HzbBR6dJzygHnPrAPaBT/wFnV8C 19 | 2w/lw0QkB7nnv79okwjuSjFQI4sw1xhtf/u809EUm1sDwPupKo0n096AiOMJfixi 20 | HnL27rJr9K31D1qTrjVJ0PTBVrOa+88kVX10y12iX+UkJMwJZ9i5misznPEeCSx+ 21 | BCkBwMGGK9a+6QGxkQuM6HXIkfidf+ITxRaYkRWvcx+qsqApWGEyARchlahYdf49 22 | 
a7icNe0CgYEA9wKA7dj0RsNSRR6nJ2ebunAcfIBmWttmTEMLqwS7NHDywfgdSkdB 23 | GR6Ef2ZjFYamjjipOlBR5ZtwsRcVLmn/BRGpgnl6B8ZTSTycNR/RhvPXPv+xvPPP 24 | Vk32GHvMstTxg0OUPqSkFLYHxPnAuszvsdlq68/b75C7Qnm3XO8Xp3sCgYEA5ObK 25 | +M1ppsuJ7/QL6XfVFxenBk6Ml1WEoxmXOUbv09BTlwvxMQYeL228IXiRxNq4vq8F 26 | Em00RcQkmiXITpbFPgKDSqQkKdyrVvgA3+UeXiQ9CdOVrdtlQIL3XRsT3LprtgSE 27 | WUSJKoHb+DtHabadHIajO5ONO6KRTuHiSLiSlwsCgYEA1YiNipAmRFIv+d7Q48i2 28 | oEqw5ZReZ6cJXV4MZTB24ZPO2I40S/UjOqLeKgCKIZ7At2wWJ3ouAk8I8Z6hyfkJ 29 | 5AjrwAZhzvzNHR/PbkFucbq0VhrXPSCMGfDVkT7cq7BYhIBUVH8h9WGTf93kldf6 30 | UoZA31BWslgs+f+c2zM6AKcCgYEA48CfhB8eWE9817vDfnE1HNzz21qcmJcGeiIk 31 | TWE/j0lhYpEHUvf7YMWWwtbscyoNV+1c5pCxyhj3MkkVnNx3NNPbPpFDSkO+V7I7 32 | bIrURGdaNETKUUpS3HVzGriucpkqQtkLtqZytFCxRbP1wkFo4dE06TpO9F80pYAr 33 | XqAHezECgYEA8nbfXFP1NjftiHS/6R6edwTcdK+AFruP1t4M/0njh/5aaWzJ8aBh 34 | KGnN0M08uvAI8cj3D86h+45gvkVV+ghB+MNh7uStVW5UNyiqPPqtCecm2YTenpnP 35 | bsPcF5WhzjCGzwujQaYtl5tySIWj1+wfiCzBq55jk5Tr74bWu8j1isg= 36 | -----END RSA PRIVATE KEY-----` 37 | 38 | func createTestKeys() []JWKey { 39 | return []JWKey{ 40 | JWKey{ 41 | Algorithm: "RS256", 42 | Exponent: "AQAB", 43 | KeyID: "abcdefghijklmnopqrsexample=", 44 | KeyType: "RSA", 45 | N: "lsjhglskjhgslkjgh43lj5h34lkjh34lkjht3example", 46 | Use: "sig", 47 | }, 48 | JWKey{ 49 | Algorithm: "RS256", 50 | Exponent: "AQAB", 51 | KeyID: "123456789", 52 | KeyType: "RSA", 53 | N: "3Nzq67VGAE3RNBN9DWuK-eIQ8LscppizsW9G1U7pUmqOM3-FgYYlWS-cMyIDROzyGNM6R6n0hwTehxyMiX9Ucwf6Q2Z9z0OMb8I0m918CBAYC3NJKWlpxt7O3keZam_U7wY4woYGBt01epJGi5-dIq8N5X2yQ2kx654YfTzrBR-23u8TC_05E1sYyqKPZtO2aasHGC9lFQD9-B2LeBEBChnDpc9pb8JriDibA5NNh-4ZC8RjqBkKTLGphkTDJ28HXYjtwV0yZJ05zwKlW_YWSCdiIh_nzaVKVziboCBaJVVknCEy5brjvLy_5v0HGxRzyeA0xkCauinS2L57JfO_SQ", 54 | Use: "sig", 55 | }, 56 | } 57 | } 58 | 59 | func createTestBaseToken(tokenUse, subject, audience string, expiresAt *time.Time) string { 60 | claims := BaseTokenClaims{ 61 | TokenUse: tokenUse, 62 | } 63 | claims.Subject = subject 64 | claims.Audience = audience 
65 | 66 | if expiresAt != nil { 67 | claims.ExpiresAt = expiresAt.Unix() 68 | } 69 | 70 | rsaKey, _ := jwt.ParseRSAPrivateKeyFromPEM([]byte(rawKey)) 71 | 72 | token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) 73 | token.Header["kid"] = "123456789" 74 | tokenString, _ := token.SignedString(rsaKey) 75 | 76 | return tokenString 77 | } 78 | 79 | func createTestIDToken(username, subject, audience string, expiresAt *time.Time) string { 80 | claims := IDTokenClaims{ 81 | Email: username, 82 | } 83 | claims.Subject = subject 84 | claims.Audience = audience 85 | claims.TokenUse = "id" 86 | 87 | if expiresAt != nil { 88 | claims.ExpiresAt = expiresAt.Unix() 89 | } 90 | 91 | rsaKey, _ := jwt.ParseRSAPrivateKeyFromPEM([]byte(rawKey)) 92 | 93 | token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) 94 | token.Header["kid"] = "123456789" 95 | tokenString, _ := token.SignedString(rsaKey) 96 | 97 | return tokenString 98 | } 99 | 100 | func createTestAccessToken(scope, subject string, expiresAt *time.Time) string { 101 | claims := AccessTokenClaims{ 102 | Scope: scope, 103 | } 104 | claims.Subject = subject 105 | claims.TokenUse = "access" 106 | 107 | if expiresAt != nil { 108 | claims.ExpiresAt = expiresAt.Unix() 109 | } 110 | 111 | rsaKey, _ := jwt.ParseRSAPrivateKeyFromPEM([]byte(rawKey)) 112 | 113 | token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) 114 | token.Header["kid"] = "123456789" 115 | tokenString, _ := token.SignedString(rsaKey) 116 | 117 | return tokenString 118 | 119 | } 120 | -------------------------------------------------------------------------------- /pkg/request/auth/cognito.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/json" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "time" 11 | 12 | "github.com/aws/aws-sdk-go/aws" 13 | "github.com/aws/aws-sdk-go/service/ssm" 14 | 
"github.com/aws/aws-sdk-go/service/ssm/ssmiface" 15 | "github.com/pkg/errors" 16 | 17 | log "github.com/sirupsen/logrus" 18 | ) 19 | 20 | const GrantClientCredentials = "client_credentials" 21 | 22 | var ( 23 | cachedToken *tokenCache 24 | cachedSecretKey *string 25 | ) 26 | 27 | type token struct { 28 | AccessToken string `json:"access_token"` 29 | ExpiresIn int `json:"expires_in"` 30 | TokenType string `json:"token_type"` 31 | } 32 | 33 | type tokenCache struct { 34 | token *token 35 | timestamp time.Time 36 | } 37 | 38 | // CognitoM2MAuthorizer implements the Signer interface 39 | // It reads the Cognito App secret key from the SSM parameter store and uses it to create the Authorization token. 40 | // cognitoAPIURL is the URL configured in the Cognito Resource servers 41 | // clientID is the cognito app client ID 42 | // scope is the OAuth scope name without the API URL - it will be concatenated automatically 43 | type CognitoM2MAuthorizer struct { 44 | CognitoAPIURL string 45 | ClientID string 46 | Scope string 47 | 48 | SsmClient ssmiface.SSMAPI 49 | SsmSecretName string 50 | } 51 | 52 | // Sign method signs request using cognito M2M authentication token 53 | func (s *CognitoM2MAuthorizer) AuthorizeRequest(request *http.Request) (*http.Request, error) { 54 | err := s.AddAuthorizationHeader(request.Header) 55 | return request, err 56 | } 57 | 58 | // AddAuthorizationHeader adds Authorization HTTP header. 59 | func (s *CognitoM2MAuthorizer) AddAuthorizationHeader(headerAdder HeaderAdder) error { 60 | secret, err := s.getSecretKey() 61 | if err != nil { 62 | return errors.Wrap(err, "Failed to sign http Request") 63 | } 64 | token, err := s.getCognitoToken(secret) 65 | if err != nil { 66 | return errors.Wrap(err, "Failed to sign http Request") 67 | } 68 | headerAdder.Add("Authorization", *token) 69 | return nil 70 | } 71 | 72 | // GetSecretKey retrieves an secret API key from SSM parameter store using provided parameter name. 
73 | // It returns the API key value, SSM parameter version and an error if any occurred. 74 | func (s *CognitoM2MAuthorizer) getSecretKey() (*string, error) { 75 | if cachedSecretKey != nil { 76 | return cachedSecretKey, nil 77 | } 78 | input := ssm.GetParameterInput{ 79 | Name: &s.SsmSecretName, 80 | WithDecryption: aws.Bool(true), 81 | } 82 | param, err := s.SsmClient.GetParameter(&input) 83 | if err != nil { 84 | log.WithError(err).WithField("ssmSecretName", s.SsmSecretName).Error("Failed to get secret API key from SSM") 85 | return nil, errors.Wrap(err, "Failed to get secret API key from SSM") 86 | } 87 | cachedSecretKey = param.Parameter.Value 88 | 89 | return param.Parameter.Value, nil 90 | } 91 | 92 | func (s *CognitoM2MAuthorizer) getCognitoToken(secret *string) (*string, error) { 93 | if token := getTokenFromCache(); token != nil { 94 | return token, nil 95 | } 96 | 97 | req, err := s.buildTokenRequest(secret) 98 | if err != nil { 99 | return nil, errors.Wrap(err, "Failed to build token request") 100 | } 101 | 102 | resp, err := http.DefaultClient.Do(req) 103 | if err != nil { 104 | return nil, errors.Wrap(err, "Failed to send token request") 105 | } 106 | if resp != nil && resp.StatusCode > 299 { 107 | resBytes, _ := ioutil.ReadAll(resp.Body) 108 | log.WithFields(log.Fields{ 109 | "code": resp.StatusCode, 110 | "method": req.Method, 111 | "body": string(resBytes), 112 | "url": req.URL.String()}).Error("Cognito API token call returned error") 113 | return nil, fmt.Errorf("cognito API token call returned error code: %d", resp.StatusCode) 114 | } 115 | 116 | var responseToken token 117 | err = json.NewDecoder(resp.Body).Decode(&responseToken) 118 | if err != nil { 119 | return nil, errors.Wrap(err, "Failed to decode cognito token") 120 | } 121 | 122 | saveTokenInCache(&responseToken) 123 | return &responseToken.AccessToken, nil 124 | 125 | } 126 | 127 | func (s *CognitoM2MAuthorizer) buildTokenRequest(secret *string) (*http.Request, error) { 128 | payload 
:= fmt.Sprintf("grant_type=%s&scope=%s", GrantClientCredentials, s.Scope) 129 | reader := bytes.NewReader([]byte(payload)) 130 | 131 | req, err := http.NewRequest(http.MethodPost, s.CognitoAPIURL, reader) 132 | if err != nil { 133 | return nil, err 134 | } 135 | req.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"} 136 | req.Header["Authorization"] = []string{buildAuthHeader(s.ClientID, *secret)} 137 | 138 | return req, nil 139 | } 140 | 141 | func buildAuthHeader(clientID, secret string) string { 142 | auth := fmt.Sprintf("%s:%s", clientID, secret) 143 | return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(auth))) 144 | } 145 | 146 | func getTokenFromCache() *string { 147 | if cachedToken == nil || cachedToken.token == nil { 148 | return nil 149 | } 150 | 151 | if cachedToken.timestamp.Add(time.Duration(cachedToken.token.ExpiresIn-5) * time.Second).Before(time.Now()) { 152 | return nil 153 | } 154 | return &cachedToken.token.AccessToken 155 | } 156 | 157 | func saveTokenInCache(token *token) { 158 | cachedToken = &tokenCache{ 159 | token: token, 160 | timestamp: time.Now(), 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/nordcloud/cognito-authorizer.svg?branch=master)](https://travis-ci.org/nordcloud/cognito-authorizer) [![Go Report Card](https://goreportcard.com/badge/github.com/nordcloud/cognito-authorizer)](https://goreportcard.com/report/github.com/nordcloud/cognito-authorizer) 2 | 3 | # Cognito authorizer 4 | A golang packages that abstract out work with JSON web access/identity tokens for AWS API Gateway custom authorizer. 
5 | 6 | These packages handle: 7 | 8 | - access, ID and base (standard) tokens 9 | - token verification 10 | - token payload decoding (claims) 11 | - building proper responses from a custom authorizer 12 | - an M2M (machine-to-machine) token signer helper 13 | 14 | You don't need to deal with JWT details yourself. The `GetIDClaims`, `GetAccessClaims` and `GetBaseClaims` functions do the work for you, so you can focus only on building `APIGatewayCustomAuthorizerPolicy`. 15 | 16 | ### Docs 17 | 18 | - [authorizer](https://godoc.org/github.com/nordcloud/cognito-authorizer/pkg/authorizer#pkg-index) 19 | - [default builder](https://godoc.org/github.com/nordcloud/cognito-authorizer/pkg/authorizer/builder) 20 | - [request signer](https://godoc.org/github.com/nordcloud/cognito-authorizer/pkg/request/auth) 21 | 22 | 23 | ### About resource server context 24 | You can pass a context created by your custom authorizer to the resource server. This is done by implementing the `ContextBuilder` interface. Its `BuildContext` method returns a `map[string]interface{}` (the type used by the AWS Lambda Go library), but the keys and values in this map have to be *strings*. More info [here](https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-lambda-authorizer-output.html).
25 | 26 | 27 | ## Example 28 | 29 | Custom authorizer main package 30 | ```go 31 | package main 32 | 33 | import ( 34 | "context" 35 | "fmt" 36 | "os" 37 | "strings" 38 | 39 | "github.com/aws/aws-lambda-go/events" 40 | "github.com/aws/aws-lambda-go/lambda" 41 | log "github.com/sirupsen/logrus" 42 | 43 | cognitoAuthorizer "github.com/nordcloud/cognito-authorizer/pkg/authorizer" 44 | ) 45 | 46 | type PolicyEffect string 47 | 48 | const ( 49 | allow PolicyEffect = "Allow" 50 | deny PolicyEffect = "Deny" 51 | ) 52 | 53 | type Policy struct { 54 | Context *cognitoAuthorizer.Context 55 | } 56 | 57 | func (b Policy) BuildPolicy(encodedToken string) (events.APIGatewayCustomAuthorizerPolicy, error) { 58 | accessClaims := &cognitoAuthorizer.AccessTokenClaims{} 59 | err := cognitoAuthorizer.GetAccessClaims(encodedToken, b.Context.DecryptionKeys, accessClaims) 60 | if err != nil { 61 | return events.APIGatewayCustomAuthorizerPolicy{}, err 62 | } 63 | 64 | resources := []string{ 65 | fmt.Sprintf( 66 | "arn:aws:execute-api:%s:*:%s/%s/*/*", 67 | b.Context.Region, 68 | b.Context.ApplicationID, 69 | b.Context.Stage, 70 | ), 71 | } 72 | 73 | policy := events.APIGatewayCustomAuthorizerPolicy{ 74 | Version: "2012-10-17", 75 | Statement: []events.IAMPolicyStatement{ 76 | { 77 | Action: []string{"execute-api:Invoke"}, 78 | Effect: string(allow), 79 | Resource: resources, 80 | }, 81 | }, 82 | } 83 | 84 | return policy, nil 85 | } 86 | 87 | func (b Policy) BuildContext(encodedToken string) (map[string]interface{}, error) { 88 | accessClaims := &cognitoAuthorizer.AccessTokenClaims{} 89 | err := cognitoAuthorizer.GetAccessClaims(encodedToken, b.Context.DecryptionKeys, accessClaims) 90 | if err != nil { 91 | return map[string]interface{}{}, err 92 | } 93 | 94 | context := map[string]interface{}{ 95 | "token_scope": accessClaims.Scope, 96 | } 97 | 98 | return context, nil 99 | } 100 | 101 | var ( 102 | sharedContext *cognitoAuthorizer.Context 103 | ) 104 | 105 | // init is called on Lambda
cold start. In this function we pull Cognito keys to verify tokens. 106 | func init() { 107 | sharedContext = &cognitoAuthorizer.Context{ 108 | Region: os.Getenv("REGION"), 109 | ApplicationID: os.Getenv("API_APPLICATION_ID"), 110 | Stage: os.Getenv("API_STAGE"), 111 | AllowedUserPoolID: os.Getenv("API_ALLOWED_USER_POOL_ID"), 112 | CognitoClients: strings.Split(os.Getenv("COGNITO_CLIENTS"), ","), 113 | DecryptionKeys: nil, 114 | } 115 | 116 | keys, err := cognitoAuthorizer.GetDecryptionKeys(sharedContext.Region, sharedContext.AllowedUserPoolID) 117 | if err != nil { 118 | log.WithField("error", err).Error("Unable to get decryption keys.") 119 | } 120 | 121 | sharedContext.DecryptionKeys = keys 122 | 123 | log.WithFields(log.Fields{ 124 | "region": sharedContext.Region, 125 | "application_id": sharedContext.ApplicationID, 126 | "stage": sharedContext.Stage, 127 | "allowed_user_pool_id": sharedContext.AllowedUserPoolID, 128 | "cognito_clients": sharedContext.CognitoClients, 129 | }).Info("Finished initialization.") 130 | } 131 | 132 | func handler(ctx context.Context, event events.APIGatewayCustomAuthorizerRequest) ( 133 | events.APIGatewayCustomAuthorizerResponse, error) { 134 | log.WithField("method_arn", event.MethodArn).Info("Authorizer lambda invoked.") 135 | 136 | policy := &Policy{ 137 | Context: sharedContext, 138 | } 139 | 140 | responseBuilder := &cognitoAuthorizer.ResponseBuilder{ 141 | Context: sharedContext, 142 | PolicyBuilder: policy, 143 | ContextBuilder: policy, 144 | } 145 | 146 | return responseBuilder.BuildResponse(event.AuthorizationToken) 147 | } 148 | 149 | func main() { 150 | lambda.Start(handler) 151 | } 152 | ``` 153 | ## Authors 154 | - Grzegorz Bednarski, Nordcloud 🇵🇱 155 | - Kamil Piotrowski, Nordcloud 🇵🇱 156 | - Artur Kowalski, Nordcloud 🇵🇱 157 | -------------------------------------------------------------------------------- /pkg/authorizer/jwt.go: -------------------------------------------------------------------------------- 1 | 
package authorizer 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "encoding/base64" 8 | "encoding/binary" 9 | "encoding/json" 10 | "encoding/pem" 11 | "errors" 12 | "fmt" 13 | "io/ioutil" 14 | "math/big" 15 | "net/http" 16 | 17 | jwt "github.com/dgrijalva/jwt-go" 18 | ) 19 | 20 | // BaseTokenClaims is a common structure for token data. 21 | type BaseTokenClaims struct { 22 | TokenUse string `json:"token_use"` 23 | jwt.StandardClaims 24 | } 25 | 26 | // IDTokenClaims represents claims stored in ID type JW token 27 | type IDTokenClaims struct { 28 | EmailVerified bool `json:"email_verified"` 29 | AuthTime int64 `json:"auth_time"` 30 | CognitoUsername string `json:"cognito:username"` 31 | GivenName string `json:"given_name"` 32 | Email string `json:"email"` 33 | BaseTokenClaims 34 | } 35 | 36 | // AccessTokenClaims represents claims stored in Access type JW token. 37 | type AccessTokenClaims struct { 38 | AuthTime int64 `json:"auth_time"` 39 | Scope string `json:"scope"` 40 | Username string `json:"username"` 41 | BaseTokenClaims 42 | } 43 | 44 | // JWKey struct holds information about JSON web key. 45 | type JWKey struct { 46 | Algorithm string `json:"alg"` 47 | Exponent string `json:"e"` 48 | KeyID string `json:"kid"` 49 | KeyType string `json:"kty"` 50 | N string `json:"n"` 51 | Use string `json:"use"` 52 | } 53 | 54 | const cognitoKeyRetrieveURLTemplate = "https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json" 55 | 56 | type jwkResponse struct { 57 | Keys []JWKey `json:"keys"` 58 | } 59 | 60 | // GetDecryptionKeys gets JW token description keys from AWS Cognito service. 61 | func GetDecryptionKeys(region, userPoolID string) ([]JWKey, error) { 62 | url := fmt.Sprintf(cognitoKeyRetrieveURLTemplate, region, userPoolID) 63 | return RequestKeys(url) 64 | } 65 | 66 | // RequestKeys retrieves decryption keys from external service. 
67 | func RequestKeys(url string) ([]JWKey, error) { 68 | response, err := http.Get(url) 69 | 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | defer response.Body.Close() 75 | 76 | body, err := ioutil.ReadAll(response.Body) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | var tokenResponse jwkResponse 82 | err = json.Unmarshal(body, &tokenResponse) 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | for _, key := range tokenResponse.Keys { 88 | if key.KeyType != "RSA" { 89 | return nil, errors.New("key type is not an RSA") 90 | } 91 | } 92 | 93 | return tokenResponse.Keys, nil 94 | } 95 | 96 | // getDecryptionKey searches for key by ID in list of keys. 97 | func getDecryptionKey(keys []JWKey, keyID string) (*JWKey, error) { 98 | for _, jwt := range keys { 99 | if jwt.KeyID == keyID { 100 | return &jwt, nil 101 | } 102 | } 103 | 104 | return nil, fmt.Errorf("%s key not found", keyID) 105 | } 106 | 107 | func getKeyForToken(keys []JWKey) func(token *jwt.Token) (interface{}, error) { 108 | f := func(token *jwt.Token) (interface{}, error) { 109 | keyID, ok := token.Header["kid"].(string) 110 | if !ok { 111 | return nil, errors.New("key id is not a string") 112 | } 113 | 114 | jwk, err := getDecryptionKey(keys, keyID) 115 | 116 | if err != nil { 117 | return nil, err 118 | } 119 | 120 | pemString, err := convertJWKtoPEMString(*jwk) 121 | 122 | if err != nil { 123 | return nil, err 124 | } 125 | 126 | return jwt.ParseRSAPublicKeyFromPEM([]byte(*pemString)) 127 | } 128 | return f 129 | } 130 | 131 | // GetIDClaims fills claims with ID type token data. 132 | func GetIDClaims(encodedToken string, keys []JWKey, claims *IDTokenClaims) error { 133 | _, err := jwt.ParseWithClaims(encodedToken, claims, getKeyForToken(keys)) 134 | if err != nil { 135 | return err 136 | } 137 | 138 | return nil 139 | } 140 | 141 | // GetAccessClaims fills claims with Access type token data. 
142 | func GetAccessClaims(encodedToken string, keys []JWKey, claims *AccessTokenClaims) error { 143 | _, err := jwt.ParseWithClaims(encodedToken, claims, getKeyForToken(keys)) 144 | 145 | if err != nil { 146 | return err 147 | } 148 | 149 | return nil 150 | } 151 | 152 | // GetStandardClaims fills claims with standard token type data. 153 | func GetBaseClaims(encodedToken string, keys []JWKey, claims *BaseTokenClaims) error { 154 | _, err := jwt.ParseWithClaims(encodedToken, claims, getKeyForToken(keys)) 155 | 156 | if err != nil { 157 | return err 158 | } 159 | 160 | return nil 161 | } 162 | 163 | // converts JWK key type to PEM key type. 164 | func convertJWKtoPEMString(jwk JWKey) (*string, error) { 165 | nb, err := base64.RawURLEncoding.DecodeString(jwk.N) 166 | if err != nil { 167 | return nil, err 168 | } 169 | 170 | eb, err := base64.RawURLEncoding.DecodeString(jwk.Exponent) 171 | if err != nil { 172 | return nil, err 173 | } 174 | 175 | if len(eb) > 4 { 176 | return nil, errors.New("e is not a uint32") 177 | } 178 | 179 | // if byte array has less than four items we need to add leading zeros to match uint32 byte lengths 180 | if len(eb) < 4 { 181 | ndata := make([]byte, 4) 182 | copy(ndata[4-len(eb):], eb) 183 | eb = ndata 184 | } 185 | 186 | e := binary.BigEndian.Uint32(eb) 187 | 188 | pk := &rsa.PublicKey{ 189 | N: new(big.Int).SetBytes(nb), 190 | E: int(e), 191 | } 192 | 193 | der, err := x509.MarshalPKIXPublicKey(pk) 194 | if err != nil { 195 | return nil, err 196 | } 197 | 198 | block := &pem.Block{ 199 | Type: "RSA PUBLIC KEY", 200 | Bytes: der, 201 | } 202 | 203 | var out bytes.Buffer 204 | pem.Encode(&out, block) 205 | 206 | output := out.String() 207 | 208 | return &output, nil 209 | } 210 | -------------------------------------------------------------------------------- /pkg/request/auth/cognito_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | 
"fmt" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/aws/aws-sdk-go/aws" 13 | "github.com/aws/aws-sdk-go/service/ssm" 14 | "github.com/aws/aws-sdk-go/service/ssm/ssmiface" 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/mock" 17 | ) 18 | 19 | const ( 20 | appID = "testAppID" 21 | expectedToken = "Basic dGVzdEFwcElEOnRlc3RTZWNyZXQ=" 22 | testScope = "full" 23 | testCognitoURL = "testAPIURL" 24 | testSecretName = "secret" 25 | ) 26 | 27 | var ( 28 | tokenValue = "accessToken" 29 | secret = "testSecret" 30 | ) 31 | 32 | type MockedSSM struct { 33 | ssmiface.SSMAPI 34 | mock.Mock 35 | } 36 | 37 | // GetParameter mocks ssm.GetParameter. 38 | func (m *MockedSSM) GetParameter(in *ssm.GetParameterInput) (*ssm.GetParameterOutput, error) { 39 | args := m.Called(in) 40 | if args.Get(0) == nil { 41 | return nil, args.Error(1) 42 | } 43 | return args.Get(0).(*ssm.GetParameterOutput), args.Error(1) 44 | } 45 | 46 | func Test_saveTokenInCache(t *testing.T) { 47 | token := &token{ 48 | AccessToken: "access", 49 | ExpiresIn: 3600, 50 | TokenType: "B", 51 | } 52 | 53 | cachedToken = nil 54 | saveTokenInCache(token) 55 | assert.NotNil(t, cachedToken) 56 | assert.Equal(t, token, cachedToken.token) 57 | } 58 | 59 | func Test_buildAuthHeader(t *testing.T) { 60 | token := buildAuthHeader(appID, secret) 61 | assert.Equal(t, expectedToken, token) 62 | } 63 | 64 | func Test_getTokenFromCache(t *testing.T) { 65 | tests := []struct { 66 | name string 67 | want *string 68 | testCache *tokenCache 69 | }{ 70 | {name: "emptyCache", want: nil, testCache: nil}, 71 | {name: "oldCache", want: nil, testCache: &tokenCache{ 72 | token: &token{AccessToken: tokenValue, ExpiresIn: 3000}, 73 | timestamp: time.Now().Add(-3500 * time.Second), 74 | }}, 75 | {name: "okCache", want: &tokenValue, testCache: &tokenCache{ 76 | token: &token{AccessToken: tokenValue, ExpiresIn: 3000}, 77 | timestamp: time.Now(), 78 | }}, 79 | } 80 | for _, tt := 
range tests { 81 | t.Run(tt.name, func(t *testing.T) { 82 | cachedToken = tt.testCache 83 | 84 | got := getTokenFromCache() 85 | if tt.want != nil { 86 | assert.Equal(t, *tt.want, *got) 87 | } else { 88 | assert.Nil(t, got) 89 | } 90 | }) 91 | } 92 | } 93 | 94 | func TestCognitoM2MSigner_getCognitoToken(t *testing.T) { 95 | testSecret := "897wgagf97w9f" 96 | testToken := fmt.Sprintf("{\"access_token\": \"%s\"}, \"expires_in\": 3000}", tokenValue) 97 | cachedToken = nil 98 | 99 | signer := &CognitoM2MAuthorizer{ 100 | CognitoAPIURL: testCognitoURL, 101 | ClientID: appID, 102 | Scope: testScope, 103 | } 104 | 105 | tests := []struct { 106 | name string 107 | want *string 108 | wantErr bool 109 | respCode int 110 | respBody string 111 | }{ 112 | {name: "responseError", want: nil, wantErr: true, respCode: 400, respBody: ""}, 113 | {name: "decodeError", want: nil, wantErr: true, respCode: 200, respBody: "invalid json"}, 114 | {name: "getTokenOk", want: &tokenValue, wantErr: false, respCode: 200, respBody: testToken}, 115 | } 116 | testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 117 | res.Write([]byte("body")) 118 | })) 119 | defer func() { testServer.Close() }() 120 | 121 | for _, tt := range tests { 122 | t.Run(tt.name, func(t *testing.T) { 123 | testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 124 | res.WriteHeader(tt.respCode) 125 | res.Write([]byte(tt.respBody)) 126 | })) 127 | 128 | signer.CognitoAPIURL = testServer.URL 129 | got, err := signer.getCognitoToken(&testSecret) 130 | if (err != nil) != tt.wantErr { 131 | t.Errorf("CognitoM2MSigner.getCognitoToken() error = %v, wantErr %v", err, tt.wantErr) 132 | return 133 | } 134 | if tt.want != nil { 135 | assert.Equal(t, *tt.want, *got) 136 | } 137 | }) 138 | } 139 | } 140 | 141 | func Test_buildTokenRequest(t *testing.T) { 142 | signer := &CognitoM2MAuthorizer{ 143 | CognitoAPIURL: "testAPIURL", 144 | ClientID: 
appID, 145 | Scope: "testAPIURL/" + testScope, 146 | } 147 | 148 | req, err := signer.buildTokenRequest(aws.String(secret)) 149 | 150 | assert.Nil(t, err) 151 | assert.NotNil(t, req) 152 | assert.Equal(t, expectedToken, req.Header["Authorization"][0]) 153 | assert.Equal(t, "application/x-www-form-urlencoded", req.Header["Content-Type"][0]) 154 | 155 | buf := new(bytes.Buffer) 156 | buf.ReadFrom(req.Body) 157 | data := buf.String() 158 | 159 | assert.Equal(t, data, "grant_type=client_credentials&scope=testAPIURL/full") 160 | } 161 | 162 | func TestCognitoM2MSigner_getSecretKey(t *testing.T) { 163 | signer := &CognitoM2MAuthorizer{ 164 | CognitoAPIURL: "testAPIURL", 165 | ClientID: appID, 166 | Scope: testScope, 167 | SsmSecretName: testSecretName, 168 | } 169 | 170 | tests := []getSecretKeyTestCase{ 171 | {name: "GetSSMErr", want: nil, wantErr: true}, 172 | {name: "GetSSMOk", want: &secret, wantErr: false}, 173 | {name: "GetSecretCachedInPrevSSMCall", want: &secret, readFromCache: true}, 174 | } 175 | for _, tt := range tests { 176 | t.Run(tt.name, func(t *testing.T) { 177 | mockSSM := mockSSMCall(tt, signer.SsmSecretName) 178 | signer.SsmClient = mockSSM 179 | 180 | got, err := signer.getSecretKey() 181 | if (err != nil) != tt.wantErr { 182 | t.Errorf("CognitoM2MSigner.getSecretKey() error = %v, wantErr %v", err, tt.wantErr) 183 | return 184 | } 185 | if got != tt.want { 186 | t.Errorf("CognitoM2MSigner.getSecretKey() = %v, want %v", got, tt.want) 187 | } 188 | }) 189 | } 190 | } 191 | 192 | type getSecretKeyTestCase struct { 193 | name string 194 | want *string 195 | wantErr bool 196 | readFromCache bool 197 | } 198 | 199 | func mockSSMCall(testCase getSecretKeyTestCase, ssmSecretName string) *MockedSSM { 200 | ssmMock := new(MockedSSM) 201 | if testCase.readFromCache { 202 | return ssmMock 203 | } 204 | if testCase.wantErr { 205 | ssmMock.On("GetParameter", mock.Anything).Return(nil, errors.New("err")) 206 | return ssmMock 207 | } 208 | ssmMock.On("GetParameter", 
&ssm.GetParameterInput{ 209 | Name: &ssmSecretName, 210 | WithDecryption: aws.Bool(true), 211 | }).Return(&ssm.GetParameterOutput{Parameter: &ssm.Parameter{Value: &secret}}, nil) 212 | return ssmMock 213 | } 214 | -------------------------------------------------------------------------------- /pkg/authorizer/response_test.go: -------------------------------------------------------------------------------- 1 | package authorizer 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/aws/aws-lambda-go/events" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestBuildResponseOk(t *testing.T) { 12 | testEmail := "test@example.com" 13 | testSubject := "test-subject" 14 | testAudience := "test-audience" 15 | testKeys := createTestKeys() 16 | token := createTestIDToken(testEmail, testSubject, testAudience, nil) 17 | 18 | policyBuilderMock := new(policyBuilderMock) 19 | policyBuilderMock.On("BuildPolicy", token).Return(events.APIGatewayCustomAuthorizerPolicy{}, nil).Once() 20 | contextBuilderMock := new(contextBuilderMock) 21 | contextBuilderMock.On("BuildContext", token).Return(map[string]interface{}{}, nil).Once() 22 | 23 | responseBuilder := ResponseBuilder{ 24 | Context: &Context{ 25 | DecryptionKeys: testKeys, 26 | CognitoClients: []string{testAudience}, 27 | }, 28 | PolicyBuilder: policyBuilderMock, 29 | ContextBuilder: contextBuilderMock, 30 | } 31 | 32 | response, err := responseBuilder.BuildResponse(token) 33 | 34 | assert.Nil(t, err) 35 | assert.Equal(t, events.APIGatewayCustomAuthorizerResponse{ 36 | PrincipalID: testSubject, 37 | PolicyDocument: events.APIGatewayCustomAuthorizerPolicy{}, 38 | Context: map[string]interface{}{}, 39 | }, response) 40 | policyBuilderMock.AssertExpectations(t) 41 | contextBuilderMock.AssertExpectations(t) 42 | } 43 | 44 | func TestBuildResponsePolicyBuilderError(t *testing.T) { 45 | testKeys := createTestKeys() 46 | testAudience := "test-audience" 47 | token := createTestIDToken("", "", testAudience, nil) 48 | 
49 | policyBuilderMock := new(policyBuilderMock) 50 | policyBuilderMock.On("BuildPolicy", token).Return(events.APIGatewayCustomAuthorizerPolicy{}, errors.New("error")).Once() 51 | contextBuilderMock := new(contextBuilderMock) 52 | 53 | responseBuilder := ResponseBuilder{ 54 | Context: &Context{ 55 | DecryptionKeys: testKeys, 56 | CognitoClients: []string{testAudience}, 57 | }, 58 | PolicyBuilder: policyBuilderMock, 59 | ContextBuilder: contextBuilderMock, 60 | } 61 | 62 | response, err := responseBuilder.BuildResponse(token) 63 | 64 | assert.NotNil(t, err) 65 | assert.Equal(t, events.APIGatewayCustomAuthorizerResponse{}, response) 66 | policyBuilderMock.AssertExpectations(t) 67 | contextBuilderMock.AssertExpectations(t) 68 | } 69 | 70 | func TestBuildResponseContextBuilderError(t *testing.T) { 71 | testKeys := createTestKeys() 72 | testAudience := "test-audience" 73 | token := createTestIDToken("", "", testAudience, nil) 74 | 75 | policyBuilderMock := new(policyBuilderMock) 76 | policyBuilderMock.On("BuildPolicy", token).Return(events.APIGatewayCustomAuthorizerPolicy{}, nil).Once() 77 | contextBuilderMock := new(contextBuilderMock) 78 | contextBuilderMock.On("BuildContext", token).Return(map[string]interface{}{}, errors.New("error")).Once() 79 | 80 | responseBuilder := ResponseBuilder{ 81 | Context: &Context{ 82 | DecryptionKeys: testKeys, 83 | CognitoClients: []string{testAudience}, 84 | }, 85 | PolicyBuilder: policyBuilderMock, 86 | ContextBuilder: contextBuilderMock, 87 | } 88 | 89 | response, err := responseBuilder.BuildResponse(token) 90 | 91 | assert.NotNil(t, err) 92 | assert.Equal(t, events.APIGatewayCustomAuthorizerResponse{}, response) 93 | policyBuilderMock.AssertExpectations(t) 94 | contextBuilderMock.AssertExpectations(t) 95 | } 96 | 97 | func TestBuildResponseBadToken(t *testing.T) { 98 | testKeys := createTestKeys() 99 | token := "bad-token" 100 | policyBuilderMock := new(policyBuilderMock) 101 | contextBuilderMock := new(contextBuilderMock) 102 | 
103 | responseBuilder := ResponseBuilder{ 104 | Context: &Context{ 105 | DecryptionKeys: testKeys, 106 | }, 107 | PolicyBuilder: policyBuilderMock, 108 | ContextBuilder: contextBuilderMock, 109 | } 110 | 111 | response, err := responseBuilder.BuildResponse(token) 112 | 113 | assert.NotNil(t, err) 114 | assert.Equal(t, events.APIGatewayCustomAuthorizerResponse{}, response) 115 | policyBuilderMock.AssertExpectations(t) 116 | contextBuilderMock.AssertExpectations(t) 117 | } 118 | 119 | func TestBuildResponseDescryptionKeysNotFound(t *testing.T) { 120 | token := createTestIDToken("", "", "", nil) 121 | 122 | policyBuilderMock := new(policyBuilderMock) 123 | contextBuilderMock := new(contextBuilderMock) 124 | 125 | responseBuilder := ResponseBuilder{ 126 | Context: &Context{ 127 | DecryptionKeys: []JWKey{}, 128 | }, 129 | PolicyBuilder: policyBuilderMock, 130 | ContextBuilder: contextBuilderMock, 131 | } 132 | 133 | response, err := responseBuilder.BuildResponse(token) 134 | 135 | assert.NotNil(t, err) 136 | assert.Equal(t, events.APIGatewayCustomAuthorizerResponse{}, response) 137 | policyBuilderMock.AssertExpectations(t) 138 | contextBuilderMock.AssertExpectations(t) 139 | } 140 | 141 | func TestBuildResponseTokenAudienceError(t *testing.T) { 142 | testEmail := "test@example.com" 143 | testSubject := "test-subject" 144 | testAudience := "test-audience" 145 | testKeys := createTestKeys() 146 | token := createTestIDToken(testEmail, testSubject, testAudience, nil) 147 | 148 | policyBuilderMock := new(policyBuilderMock) 149 | contextBuilderMock := new(contextBuilderMock) 150 | 151 | responseBuilder := ResponseBuilder{ 152 | Context: &Context{ 153 | DecryptionKeys: testKeys, 154 | CognitoClients: []string{"wrong-audience"}, 155 | }, 156 | PolicyBuilder: policyBuilderMock, 157 | ContextBuilder: contextBuilderMock, 158 | } 159 | 160 | response, err := responseBuilder.BuildResponse(token) 161 | 162 | assert.NotNil(t, err) 163 | assert.Equal(t, 
events.APIGatewayCustomAuthorizerResponse{}, response) 164 | policyBuilderMock.AssertExpectations(t) 165 | contextBuilderMock.AssertExpectations(t) 166 | } 167 | 168 | func TestBuildResponseTokenMultipeAudience(t *testing.T) { 169 | testEmail := "test@example.com" 170 | testSubject := "test-subject" 171 | testAudience := "test-audience-1" 172 | testKeys := createTestKeys() 173 | token := createTestIDToken(testEmail, testSubject, testAudience, nil) 174 | 175 | policyBuilderMock := new(policyBuilderMock) 176 | policyBuilderMock.On("BuildPolicy", token).Return(events.APIGatewayCustomAuthorizerPolicy{}, nil).Once() 177 | contextBuilderMock := new(contextBuilderMock) 178 | contextBuilderMock.On("BuildContext", token).Return(map[string]interface{}{}, nil).Once() 179 | 180 | responseBuilder := ResponseBuilder{ 181 | Context: &Context{ 182 | DecryptionKeys: testKeys, 183 | CognitoClients: []string{"test-audience-2", testAudience}, 184 | }, 185 | PolicyBuilder: policyBuilderMock, 186 | ContextBuilder: contextBuilderMock, 187 | } 188 | 189 | response, err := responseBuilder.BuildResponse(token) 190 | 191 | assert.Nil(t, err) 192 | assert.Equal(t, events.APIGatewayCustomAuthorizerResponse{ 193 | PrincipalID: testSubject, 194 | PolicyDocument: events.APIGatewayCustomAuthorizerPolicy{}, 195 | Context: map[string]interface{}{}, 196 | }, response) 197 | policyBuilderMock.AssertExpectations(t) 198 | contextBuilderMock.AssertExpectations(t) 199 | } 200 | 201 | func TestBuildResponseTokenMultipeAudienceError(t *testing.T) { 202 | testEmail := "test@example.com" 203 | testSubject := "test-subject" 204 | testAudience := "test-audience-bad" 205 | testKeys := createTestKeys() 206 | token := createTestIDToken(testEmail, testSubject, testAudience, nil) 207 | 208 | policyBuilderMock := new(policyBuilderMock) 209 | contextBuilderMock := new(contextBuilderMock) 210 | 211 | responseBuilder := ResponseBuilder{ 212 | Context: &Context{ 213 | DecryptionKeys: testKeys, 214 | CognitoClients: 
[]string{"test-audience-1", "test-audience-2"}, 215 | }, 216 | PolicyBuilder: policyBuilderMock, 217 | ContextBuilder: contextBuilderMock, 218 | } 219 | 220 | response, err := responseBuilder.BuildResponse(token) 221 | 222 | assert.NotNil(t, err) 223 | assert.Equal(t, events.APIGatewayCustomAuthorizerResponse{}, response) 224 | policyBuilderMock.AssertExpectations(t) 225 | contextBuilderMock.AssertExpectations(t) 226 | } 227 | --------------------------------------------------------------------------------