├── README.md └── src ├── token ├── validator.go ├── mock.go ├── claim_test.go └── claim.go ├── test ├── sample-key.pub └── sample-key └── cognito ├── server-token.go ├── validator_test.go └── validator.go /README.md: -------------------------------------------------------------------------------- 1 | # Go Cognito 2 | A golang library to handle cognito JWTs. 3 | 4 | **Note** This library is a work in progress. 5 | -------------------------------------------------------------------------------- /src/token/validator.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import "net/http" 4 | 5 | // Validator can Validate a JWT and return a Claim object 6 | type Validator interface { 7 | Validate(jwt string) (Claim, error) 8 | ValidateRequest(r *http.Request) (Claim, error) 9 | } 10 | -------------------------------------------------------------------------------- /src/test/sample-key.pub: -------------------------------------------------------------------------------- 1 | -----BEGIN PUBLIC KEY----- 2 | MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41 3 | fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7 4 | mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp 5 | HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2 6 | XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b 7 | ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy 8 | 7wIDAQAB 9 | -----END PUBLIC KEY----- 10 | -------------------------------------------------------------------------------- /src/token/mock.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import "net/http" 4 | 5 | type mockValidator struct { 6 | claim Claim 7 | err error 8 | } 9 | 10 | // NewMockValidator will return a validator that will return the provided claim and error 11 | func NewMockValidator(c Claim, err error) Validator { 12 | return &mockValidator{claim: c, err: err} 13 | } 14 | 15 | func (m *mockValidator) Validate(jwt string) (Claim, error) { 16 | return m.claim, m.err 17 | } 18 | 19 | func (m *mockValidator) ValidateRequest(r *http.Request) (Claim, error) { 20 | return m.claim, m.err 21 | } 22 | -------------------------------------------------------------------------------- /src/cognito/server-token.go: -------------------------------------------------------------------------------- 1 | package cognito 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/aws/aws-sdk-go/aws" 7 | "github.com/aws/aws-sdk-go/aws/session" 8 | "github.com/aws/aws-sdk-go/service/cognitoidentityprovider" 9 | ) 10 | 11 | // GetAdminAccessToken makes an admin auth request to cognito and returns an access token 12 | // This is useful for getting credentials for backend services when necessary 13 | // This should probably only be used against a segregated user pool that doesn't have public users 14 | // This will require ADMIN_NO_SRP_AUTH to be enabled an no additional challenges 15 | func GetAdminAccessToken(username, password, clientID, userPoolID string) (string, error) { 16 | sess, err := session.NewSession() 17 | if err != nil { 18 | return "", err 19 | } 20 | 21 | cog := cognitoidentityprovider.New(sess, nil) 22 | req := &cognitoidentityprovider.AdminInitiateAuthInput{ 23 | AuthFlow: aws.String(cognitoidentityprovider.AuthFlowTypeAdminNoSrpAuth), 24 | AuthParameters: map[string]*string{ 25 | "USERNAME": aws.String(username), 26 | "PASSWORD": aws.String(password), 27 | }, 28 | ClientId: aws.String(clientID), 29 | UserPoolId: aws.String(userPoolID), 30 | } 31 | 32 | resp, err := cog.AdminInitiateAuth(req) 33 | if err != nil { 34 | return "", err 35 | } 36 | if resp.AuthenticationResult == nil || resp.AuthenticationResult.AccessToken == nil { 37 | return "", fmt.Errorf("Error initiating auth, authentication result or token was nil") 38 | } 39 | return *resp.AuthenticationResult.AccessToken, nil 40 | } 41 | -------------------------------------------------------------------------------- /src/test/sample-key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn 3 | SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i 4 | cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC 5 | PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR 6 | ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA 7 | Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3 8 | n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy 9 | MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9 10 | POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE 11 | KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM 12 | IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn 13 | FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY 14 | mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj 15 | FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U 16 | I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs 17 | 2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn 18 | /iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT 19 | OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86 20 | EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+ 21 | hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0 22 | 4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb 23 | mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry 24 | eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3 25 | CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+ 26 | 9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq 27 | -----END RSA PRIVATE KEY----- -------------------------------------------------------------------------------- /src/token/claim_test.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import "testing" 4 | 5 | func Test_GetClaim(t *testing.T) { 6 | c := Claim{} 7 | c.Claims = make(map[string]interface{}) 8 | c.Claims["bool"] = true 9 | c.Claims["float32"] = float32(123.45) 10 | c.Claims["float64"] = float64(123.45) 11 | c.Claims["int"] = int(123) 12 | c.Claims["int32"] = int32(123) 13 | c.Claims["int64"] = int64(123) 14 | c.Claims["string"] = "foobar" 15 | 16 | tests := []string{"bool", "float32", "float64", "int", "int32", "int64", "string", "noValue"} 17 | 18 | for _, v := range tests { 19 | t.Run(v, func(t *testing.T) { 20 | interfaceVal, err := c.GetClaim(v) 21 | if err != nil && v != "noValue" { 22 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 23 | } 24 | if interfaceVal == nil && v != "noValue" { 25 | t.Fatalf("Unexpected nil value for %s", v) 26 | } 27 | 28 | b, err := c.GetClaimBool(v) 29 | if err != nil && v == "bool" { 30 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 31 | } 32 | if v == "bool" && !b { 33 | t.Fatalf("Unexpected bool value for %s - %v", v, b) 34 | } 35 | 36 | f32, err := c.GetClaimFloat32(v) 37 | if err != nil && v == "float32" { 38 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 39 | } 40 | if v == "float32" && f32 != float32(123.45) { 41 | t.Fatalf("Unexpected float32 value for %s - %v", v, f32) 42 | } 43 | 44 | f64, err := c.GetClaimFloat64(v) 45 | if err != nil && v == "float64" { 46 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 47 | } 48 | if v == "float64" && f64 != float64(123.45) { 49 | t.Fatalf("Unexpected float64 value for %s - %v", v, f64) 50 | } 51 | 52 | i, err := c.GetClaimInt(v) 53 | if err != nil && v == "int" { 54 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 55 | } 56 | if v == "int" && i != int(123) { 57 | t.Fatalf("Unexpected int value for %s - %v", v, i) 58 | } 59 | 60 | i32, err := c.GetClaimInt32(v) 61 | if err != nil && v == "int32" { 62 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 63 | } 64 | if v == "int32" && i32 != int32(123) { 65 | t.Fatalf("Unexpected int32 value for %s - %v", v, i32) 66 | } 67 | 68 | i64, err := c.GetClaimInt64(v) 69 | if err != nil && v == "int64" { 70 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 71 | } 72 | if v == "int64" && i64 != int64(123) { 73 | t.Fatalf("Unexpected int64 value for %s - %v", v, i64) 74 | } 75 | 76 | s, err := c.GetClaimString(v) 77 | if err != nil && v == "string" { 78 | t.Fatalf("Unexpected error occurred: %s", err.Error()) 79 | } 80 | if v == "string" && s != "foobar" { 81 | t.Fatalf("Unexpected string value for %s - %v", v, s) 82 | } 83 | }) 84 | } 85 | 86 | } 87 | -------------------------------------------------------------------------------- /src/cognito/validator_test.go: -------------------------------------------------------------------------------- 1 | package cognito 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "testing" 7 | "time" 8 | 9 | jwt "github.com/dgrijalva/jwt-go" 10 | ) 11 | 12 | func Test_Validate(t *testing.T) { 13 | // rsa keys are copied from jwt package tests since I know they work there and this test isn't designed to test that package. 14 | pubKeyData, err := ioutil.ReadFile("../test/sample-key.pub") 15 | if err != nil { 16 | t.Fatalf("error reading public key: %s", err.Error()) 17 | } 18 | publicKey, err := jwt.ParseRSAPublicKeyFromPEM(pubKeyData) 19 | if err != nil { 20 | t.Fatalf("Error parsing public key: %s", err.Error()) 21 | } 22 | 23 | privateKeyData, err := ioutil.ReadFile("../test/sample-key") 24 | if err != nil { 25 | t.Fatalf("error reading private key: %s", err.Error()) 26 | } 27 | privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyData) 28 | if err != nil { 29 | t.Fatalf("Error parsing private key: %s", err.Error()) 30 | } 31 | 32 | token := jwt.New(jwt.SigningMethodRS512) 33 | token.Claims = make(map[string]interface{}) 34 | token.Claims["sub"] = "userid" 35 | token.Claims["iss"] = "https://cognito-idp.us-east-1.amazonaws.com/test-pool" 36 | token.Claims["exp"] = fmt.Sprintf("%d", time.Now().Add(5*time.Minute).Unix()) 37 | token.Claims["bool"] = true 38 | token.Header["kid"] = "testKey" 39 | 40 | tokenString, err := token.SignedString(privateKey) 41 | if err != nil { 42 | t.Fatalf("Could not sign token: %s", err.Error()) 43 | } 44 | 45 | keys := map[string]awsRSAKey{"testKey": awsRSAKey{KID: "testKey", Pub: publicKey}} 46 | 47 | v := NewRS256ValidatorFromKeys([]string{"test-pool"}, "us-east-1", keys) 48 | 49 | c, err := v.Validate(tokenString) 50 | if err != nil { 51 | t.Fatalf("Unexpected error validating string: %s", err.Error()) 52 | } 53 | 54 | if c.GetUserID() != "userid" { 55 | t.Fatalf("Unexpected userID: %v", c.GetUserID()) 56 | } 57 | if c.GetRole() != "test-pool" { 58 | t.Fatalf("Unexpected userID: %v", c.GetRole()) 59 | } 60 | b, err := c.GetClaimBool("bool") 61 | if !b || err != nil { 62 | t.Fatalf("Unexpected bool: %v | %s", b, err.Error()) 63 | } 64 | } 65 | 66 | func Test_DownloadCerts(t *testing.T) { 67 | //https://cognito-idp.us-east-1.amazonaws.com/us-east-1_gENunu2aW/.well-known/jwks.json 68 | type test struct { 69 | name string 70 | region string 71 | userPoolID string 72 | keyCount int 73 | expectedError bool 74 | } 75 | tests := []test{ 76 | { 77 | name: "base path", 78 | region: "us-east-1", 79 | userPoolID: "us-east-1_gENunu2aW", // this test user pool may not exist in the future 80 | keyCount: 2, 81 | }, 82 | { 83 | name: "exceptional path", 84 | region: "us-east-1", 85 | userPoolID: "foobar", 86 | expectedError: true, 87 | }, 88 | } 89 | 90 | for _, tc := range tests { 91 | t.Run(tc.name, func(t *testing.T) { 92 | keys, err := downloadCerts(tc.region, tc.userPoolID) 93 | if err != nil && !tc.expectedError { 94 | t.Fatalf("Unexpected error occurred: %v", err) 95 | } 96 | if len(keys) != tc.keyCount { 97 | t.Fatalf("Expected %d keys, got %d", tc.keyCount, len(keys)) 98 | } 99 | }) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/token/claim.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Claim holds all the claim information from a JWT and provides helpers to access them 8 | type Claim struct { 9 | ID string 10 | Role string 11 | Claims map[string]interface{} 12 | } 13 | 14 | // GetUserID will return the user's id 15 | func (c Claim) GetUserID() string { 16 | return c.ID 17 | } 18 | 19 | // IsUser returns true if this claim is for the provided user 20 | func (c Claim) IsUser(userID string) bool { 21 | return c.ID == userID 22 | } 23 | 24 | // GetRole will return the user's role 25 | func (c Claim) GetRole() string { 26 | return c.Role 27 | } 28 | 29 | // IsRole returns true if this claim is for the provided role 30 | func (c Claim) IsRole(role string) bool { 31 | return c.Role == role 32 | } 33 | 34 | // GetClaim returns the claim as an interface or an error if it does not exist 35 | func (c Claim) GetClaim(key string) (interface{}, error) { 36 | if v, ok := c.Claims[key]; ok { 37 | return v, nil 38 | } 39 | return nil, fmt.Errorf("claim %s does not exist", key) 40 | } 41 | 42 | // GetClaimString returns the claim as a string if possible 43 | func (c Claim) GetClaimString(key string) (string, error) { 44 | v, err := c.GetClaim(key) 45 | if err != nil { 46 | return "", err 47 | } 48 | if v != nil { 49 | if ret, ok := v.(string); ok { 50 | return ret, nil 51 | } 52 | } 53 | return "", fmt.Errorf("claim %s is not a string", key) 54 | } 55 | 56 | // GetClaimBool returns the claim as a bool if possible 57 | func (c Claim) GetClaimBool(key string) (bool, error) { 58 | v, err := c.GetClaim(key) 59 | if err != nil { 60 | return false, err 61 | } 62 | if v != nil { 63 | if ret, ok := v.(bool); ok { 64 | return ret, nil 65 | } 66 | } 67 | return false, fmt.Errorf("claim %s is not a bool", key) 68 | } 69 | 70 | // GetClaimInt64 returns the claim as a int64 if possible 71 | func (c *Claim) GetClaimInt64(key string) (int64, error) { 72 | v, err := c.GetClaim(key) 73 | if err != nil { 74 | return 0, err 75 | } 76 | if v != nil { 77 | if ret, ok := v.(int64); ok { 78 | return ret, nil 79 | } 80 | } 81 | return 0, fmt.Errorf("claim %s is not a int64", key) 82 | } 83 | 84 | // GetClaimInt32 returns the claim as a int32 if possible 85 | func (c Claim) GetClaimInt32(key string) (int32, error) { 86 | v, err := c.GetClaim(key) 87 | if err != nil { 88 | return 0, err 89 | } 90 | if v != nil { 91 | if ret, ok := v.(int32); ok { 92 | return ret, nil 93 | } 94 | } 95 | return 0, fmt.Errorf("claim %s is not a int32", key) 96 | } 97 | 98 | // GetClaimInt returns the claim as a int if possible 99 | func (c Claim) GetClaimInt(key string) (int, error) { 100 | v, err := c.GetClaim(key) 101 | if err != nil { 102 | return 0, err 103 | } 104 | if v != nil { 105 | if ret, ok := v.(int); ok { 106 | return ret, nil 107 | } 108 | } 109 | return 0, fmt.Errorf("claim %s is not a int", key) 110 | } 111 | 112 | // GetClaimFloat64 returns the claim as a Float64 if possible 113 | func (c Claim) GetClaimFloat64(key string) (float64, error) { 114 | v, err := c.GetClaim(key) 115 | if err != nil { 116 | return 0, err 117 | } 118 | if v != nil { 119 | if ret, ok := v.(float64); ok { 120 | return ret, nil 121 | } 122 | } 123 | return 0, fmt.Errorf("claim %s is not a float64", key) 124 | } 125 | 126 | // GetClaimFloat32 returns the claim as a float32 if possible 127 | func (c Claim) GetClaimFloat32(key string) (float32, error) { 128 | v, err := c.GetClaim(key) 129 | if err != nil { 130 | return 0, err 131 | } 132 | if v != nil { 133 | if ret, ok := v.(float32); ok { 134 | return ret, nil 135 | } 136 | } 137 | return 0, fmt.Errorf("claim %s is not a float32", key) 138 | } 139 | -------------------------------------------------------------------------------- /src/cognito/validator.go: -------------------------------------------------------------------------------- 1 | package cognito 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rsa" 6 | "encoding/base64" 7 | "encoding/binary" 8 | "encoding/json" 9 | "fmt" 10 | "io" 11 | "io/ioutil" 12 | "math/big" 13 | "net/http" 14 | "strings" 15 | "time" 16 | 17 | jwt "github.com/dgrijalva/jwt-go" 18 | "github.com/divideandconquer/go-cognito/src/token" 19 | ) 20 | 21 | const ( 22 | cognitoIssuer = "https://cognito-idp.%s.amazonaws.com/" 23 | certURL = "https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json" 24 | ) 25 | 26 | type validator struct { 27 | keys map[string]awsRSAKey 28 | region string 29 | issuer string 30 | validUserPoolIDs []string 31 | } 32 | 33 | type awsRSAKey struct { 34 | ALG string `json:"alg"` 35 | E string `json:"e"` 36 | KID string `json:"kid"` 37 | KTY string `json:"kty"` 38 | N string `json:"n"` 39 | Use string `json:"use"` 40 | Pub *rsa.PublicKey `json:"-"` 41 | } 42 | 43 | type awsKeyObj struct { 44 | Keys []awsRSAKey `json:"keys"` 45 | } 46 | 47 | // NewRS256Validator returns a token.Validator that is specifically designed for Cognito generated JWTs 48 | func NewRS256Validator(validUserPoolIDs []string, awsRegion string) (token.Validator, error) { 49 | result := validator{region: awsRegion, issuer: fmt.Sprintf(cognitoIssuer, awsRegion), validUserPoolIDs: validUserPoolIDs} 50 | result.keys = make(map[string]awsRSAKey) 51 | 52 | var allKeys []awsRSAKey 53 | for _, v := range result.validUserPoolIDs { 54 | keys, err := downloadCerts(awsRegion, v) 55 | if err != nil { 56 | return nil, fmt.Errorf("Could not download signing keys from AWS: %s", err.Error()) 57 | } 58 | allKeys = append(allKeys, keys...) 59 | } 60 | 61 | for _, v := range allKeys { 62 | result.keys[v.KID] = v 63 | } 64 | 65 | return &result, nil 66 | } 67 | 68 | // NewRS256ValidatorFromKeys returns a token.Validator that is specifically designed for Cognito generated JWTs 69 | // To use this signature you must pass in the RSA public key information yourself. To automatically download them use 70 | // NewRS256Validator instead. 71 | func NewRS256ValidatorFromKeys(validUserPoolIDs []string, awsRegion string, keys map[string]awsRSAKey) token.Validator { 72 | return &validator{region: awsRegion, issuer: fmt.Sprintf(cognitoIssuer, awsRegion), validUserPoolIDs: validUserPoolIDs, keys: keys} 73 | } 74 | 75 | func downloadCerts(region string, userPoolID string) ([]awsRSAKey, error) { 76 | url := fmt.Sprintf(certURL, region, userPoolID) 77 | resp, err := http.Get(url) 78 | if err != nil { 79 | return nil, fmt.Errorf("Could not download RSA cert info: %s", err.Error()) 80 | } 81 | defer func() { 82 | io.Copy(ioutil.Discard, resp.Body) 83 | resp.Body.Close() 84 | }() 85 | 86 | if resp.StatusCode != http.StatusOK { 87 | return nil, fmt.Errorf("Could not download RSA cert info. Status code %d", resp.StatusCode) 88 | } 89 | 90 | var keyObj awsKeyObj 91 | dec := json.NewDecoder(resp.Body) 92 | err = dec.Decode(&keyObj) 93 | if err != nil { 94 | return nil, fmt.Errorf("Could not parse RSA cert JSON: %s", err.Error()) 95 | } 96 | 97 | for k, v := range keyObj.Keys { 98 | v.Pub, err = parseRSAPublicKey(v.N, v.E) 99 | if err != nil { 100 | return nil, err 101 | } 102 | keyObj.Keys[k] = v 103 | } 104 | 105 | return keyObj.Keys, nil 106 | } 107 | 108 | // converts base64 encoded N and E strings into a rsa.PublicKey 109 | func parseRSAPublicKey(nStr string, eStr string) (*rsa.PublicKey, error) { 110 | decN, err := base64.RawURLEncoding.DecodeString(nStr) 111 | if err != nil { 112 | return nil, fmt.Errorf("Error decoding N string from public key: %s", err.Error()) 113 | } 114 | n := big.NewInt(0) 115 | n.SetBytes(decN) 116 | 117 | decE, err := base64.RawURLEncoding.DecodeString(eStr) 118 | if err != nil { 119 | return nil, fmt.Errorf("Error decoding E string from public key: %s", err.Error()) 120 | } 121 | var eBytes []byte 122 | if len(decE) < 8 { 123 | eBytes = make([]byte, 8-len(decE), 8) 124 | eBytes = append(eBytes, decE...) 125 | } else { 126 | eBytes = decE 127 | } 128 | eReader := bytes.NewReader(eBytes) 129 | var e uint64 130 | err = binary.Read(eReader, binary.BigEndian, &e) 131 | if err != nil { 132 | return nil, fmt.Errorf("Error reading E as int: %s", err.Error()) 133 | } 134 | pKey := rsa.PublicKey{N: n, E: int(e)} 135 | return &pKey, nil 136 | } 137 | 138 | func (v *validator) Validate(tokenString string) (token.Claim, error) { 139 | result := token.Claim{} 140 | t, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { 141 | // validate the alg is what is RSA 142 | if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok { 143 | return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) 144 | } 145 | // Get the matching key 146 | if keyID, ok := t.Header["kid"]; ok { 147 | if keyIDStr, ok := keyID.(string); ok { 148 | if signingKey, ok := v.keys[keyIDStr]; ok { 149 | return signingKey.Pub, nil 150 | } 151 | } 152 | } 153 | 154 | return nil, fmt.Errorf("Could not find matching key for kid %s", t.Header["kid"]) 155 | }) 156 | if err != nil { 157 | return result, err 158 | } 159 | 160 | // make sure the token is valid 161 | if !t.Valid { 162 | return result, fmt.Errorf("Token not valid") 163 | } 164 | 165 | // check user id 166 | var userID string 167 | if sub, ok := t.Claims["sub"]; ok { 168 | if subStr, ok := sub.(string); ok { 169 | userID = subStr 170 | } else { 171 | return result, fmt.Errorf("Token sub not properly formatted") 172 | } 173 | } else { 174 | return result, fmt.Errorf("Token is missing the 'sub' claim") 175 | } 176 | 177 | // check iss and parse role (user pool id) 178 | var role string 179 | if iss, ok := t.Claims["iss"]; ok { 180 | if issStr, ok := iss.(string); ok { 181 | if !strings.HasPrefix(issStr, v.issuer) { 182 | return result, fmt.Errorf("Token issuer [%s] does not match expected cognito-idp url [%s]", issStr, v.issuer) 183 | } 184 | role = strings.Replace(issStr, v.issuer, "", -1) 185 | } else { 186 | return result, fmt.Errorf("Token issuer not properly formatted") 187 | } 188 | } else { 189 | return result, fmt.Errorf("Token is missing the 'iss' claim") 190 | } 191 | 192 | // make sure the role is a valid user pool id 193 | roleValid := false 194 | for _, upID := range v.validUserPoolIDs { 195 | if upID == role { 196 | roleValid = true 197 | break 198 | } 199 | } 200 | if !roleValid { 201 | return result, fmt.Errorf("Token role [%s] is not valid", role) 202 | } 203 | 204 | // check expiration if set 205 | if exp, ok := t.Claims["exp"]; ok { 206 | if expFloat, ok := exp.(float64); ok { 207 | expInt := int64(expFloat) 208 | if expInt < time.Now().UTC().Unix() { 209 | return result, fmt.Errorf("Token expired") 210 | } 211 | } else { 212 | return result, fmt.Errorf("Token expiration not properly formatted %#v", exp) 213 | } 214 | } 215 | 216 | result.ID = userID 217 | result.Role = role 218 | result.Claims = t.Claims 219 | return result, nil 220 | } 221 | 222 | func (v *validator) ValidateRequest(r *http.Request) (token.Claim, error) { 223 | jwt := r.Header.Get("Authorization") 224 | jwt = strings.TrimPrefix(jwt, "Bearer ") 225 | return v.Validate(jwt) 226 | } 227 | --------------------------------------------------------------------------------