├── .gitignore ├── dev.json ├── .travis.yml ├── main.go ├── Dockerfile ├── cmd ├── version_cmd.go └── root_cmd.go ├── models ├── user.go ├── db.go └── subscriptions.go ├── example-config.json ├── glide.yaml ├── Makefile ├── api ├── error.go ├── payers.go ├── context.go ├── api.go ├── api_test.go ├── subscriptions.go └── subscriptions_test.go ├── CONTRIBUTING.md ├── .github ├── PULL_REQUEST_TEMPLATE.md └── ISSUE_TEMPLATE.md ├── LICENSE ├── conf ├── reflect_test.go ├── logging.go ├── reflect.go └── configuration.go ├── README.md ├── CODE_OF_CONDUCT.md └── glide.lock /.gitignore: -------------------------------------------------------------------------------- 1 | vendor 2 | .idea 3 | *.iml 4 | netlify-subscriptions 5 | -------------------------------------------------------------------------------- /dev.json: -------------------------------------------------------------------------------- 1 | { 2 | "log": { 3 | "level": "debug" 4 | }, 5 | "port": 1000 6 | } 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | --- 2 | language: go 3 | 4 | go: 5 | - 1.8 6 | 7 | install: make deps 8 | script: make all 9 | notifications: 10 | email: false 11 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/netlify/gojoin/cmd" 7 | ) 8 | 9 | func main() { 10 | if err := cmd.RootCommand().Execute(); err != nil { 11 | log.Fatal(err) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM calavera/go-glide:v0.12.2 2 | 3 | ADD . 
/go/src/github.com/netlify/gojoin 4 | 5 | RUN useradd -m netlify && cd /go/src/github.com/netlify/gojoin && make deps build && mv gojoin /usr/local/bin/ 6 | 7 | USER netlify 8 | CMD ["gojoin"] 9 | -------------------------------------------------------------------------------- /cmd/version_cmd.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | ) 8 | 9 | var Version string 10 | 11 | var versionCmd = cobra.Command{ 12 | Run: showVersion, 13 | Use: "version", 14 | } 15 | 16 | func showVersion(cmd *cobra.Command, args []string) { 17 | fmt.Println(Version) 18 | } 19 | -------------------------------------------------------------------------------- /models/user.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "time" 4 | 5 | type User struct { 6 | ID string `json:"id"` 7 | Email string `json:"email"` 8 | RemoteID string `json:"remote_id"` 9 | 10 | CreatedAt time.Time 11 | UpdatedAt time.Time 12 | DeletedAt *time.Time 13 | } 14 | 15 | func (User) TableName() string { 16 | return tableName("users") 17 | } 18 | -------------------------------------------------------------------------------- /example-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "port": 9000, 3 | "jwt_secret": "super-secret-value", 4 | "admin_group_name": "admin", 5 | "stripe_key": "stripe-key", 6 | "log": { 7 | "level": "debug", 8 | "file": "" 9 | }, 10 | "db": { 11 | "driver": "", 12 | "url": "", 13 | "namespace": "sk_prpgMJNTV1xr8tw3t5iFs0FhhamYb", 14 | "automigrate": true 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /glide.yaml: -------------------------------------------------------------------------------- 1 | package: github.com/netlify/gojoin 2 | import: 3 | - package: github.com/spf13/cobra 4 | - package: 
github.com/spf13/viper 5 | - package: github.com/rs/cors 6 | version: v1.0 7 | - package: github.com/guregu/kami 8 | version: v2.1.0 9 | - package: github.com/pborman/uuid 10 | version: v1.0 11 | - package: github.com/dgrijalva/jwt-go 12 | version: v3.0.0 13 | - package: github.com/stretchr/testify 14 | version: v1.1.4 15 | subpackages: 16 | - assert 17 | - package: github.com/stripe/stripe-go 18 | version: v18.1.0 19 | - package: github.com/valyala/fasthttp 20 | version: v20160617 21 | - package: github.com/sirupsen/logrus 22 | version: 1.0.2 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all build deps help image lint test 2 | 3 | help: ## Show this help. 4 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {sub("\\\\n",sprintf("\n%22c"," "), $$2);printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) 5 | 6 | all: test build ## Run the tests and build the binary. 7 | 8 | build: ## Build the binary. 9 | go build -ldflags "-X github.com/netlify/gojoin/cmd.Version=`git rev-parse HEAD`" 10 | 11 | deps: ## Install dependencies. 12 | @go get -u github.com/golang/lint/golint 13 | @go get -u github.com/Masterminds/glide && glide install 14 | 15 | image: ## Build the Docker image. 16 | docker build . 17 | 18 | lint: ## Lint the code 19 | golint `go list ./... | grep -v /vendor/` 20 | 21 | test: ## Run tests. 22 | go test -v `go list ./... 
| grep -v /vendor/` 23 | -------------------------------------------------------------------------------- /api/error.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | // HTTPError is an error with a message 9 | type HTTPError struct { 10 | Code int `json:"code"` 11 | Message string `json:"msg"` 12 | } 13 | 14 | func (e HTTPError) Error() string { 15 | return fmt.Sprintf("%d: %s", e.Code, e.Message) 16 | } 17 | 18 | func httpError(code int, fmtString string, args ...interface{}) *HTTPError { 19 | return &HTTPError{ 20 | Code: code, 21 | Message: fmt.Sprintf(fmtString, args...), 22 | } 23 | } 24 | 25 | func writeError(w http.ResponseWriter, code int, msg string, args ...interface{}) *HTTPError { 26 | err := httpError(code, msg, args...) 27 | sendJSON(w, err.Code, err) 28 | return err 29 | } 30 | 31 | func notFoundError(w http.ResponseWriter, msg string, args ...interface{}) *HTTPError { 32 | return writeError(w, http.StatusNotFound, msg, args...) 33 | } 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # CONTRIBUTING 2 | 3 | Contributions are always welcome, no matter how large or small. Before contributing, 4 | please read the [code of conduct](CODE_OF_CONDUCT.md). 5 | 6 | ## Setup 7 | 8 | > Install Go and Glide https://github.com/Masterminds/glide 9 | 10 | ```sh 11 | $ git clone https://github.com/netlify/gojoin 12 | $ cd gojoin 13 | $ make deps 14 | ``` 15 | 16 | ## Building 17 | 18 | ```sh 19 | $ make build 20 | ``` 21 | 22 | ## Testing 23 | 24 | ```sh 25 | $ make test 26 | ``` 27 | 28 | ## Pull Requests 29 | 30 | We actively welcome your pull requests. 31 | 32 | 1. Fork the repo and create your branch from `master`. 33 | 2. If you've added code that should be tested, add tests. 34 | 3. 
If you've changed APIs, update the documentation. 35 | 4. Ensure the test suite passes. 36 | 5. Make sure your code lints. 37 | 38 | ## License 39 | 40 | By contributing to GoJoin, you agree that your contributions will be licensed 41 | under its [MIT license](LICENSE). 42 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 13 | 14 | **- Summary** 15 | 16 | 20 | 21 | **- Test plan** 22 | 23 | 27 | 28 | **- Description for the changelog** 29 | 30 | 34 | 35 | **- A picture of a cute animal (not mandatory but encouraged)** 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Netlify 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 20 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 22 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /conf/reflect_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/spf13/viper" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSimpleValues(t *testing.T) { 12 | c := struct { 13 | Simple string `json:"simple"` 14 | }{} 15 | 16 | viper.SetDefault("simple", "i am a simple string") 17 | 18 | assert.Nil(t, recursivelySet(reflect.ValueOf(&c), "")) 19 | assert.Equal(t, "i am a simple string", c.Simple) 20 | } 21 | 22 | func TestNestedValues(t *testing.T) { 23 | c := struct { 24 | Simple string `json:"simple"` 25 | Nested struct { 26 | BoolVal bool `json:"bool"` 27 | StringVal string `json:"string"` 28 | NumberVal int `json:"number"` 29 | } `json:"nested"` 30 | }{} 31 | 32 | viper.SetDefault("simple", "simple") 33 | viper.SetDefault("nested.bool", true) 34 | viper.SetDefault("nested.string", "i am a simple string") 35 | viper.SetDefault("nested.number", 4) 36 | 37 | assert.Nil(t, recursivelySet(reflect.ValueOf(&c), "")) 38 | assert.Equal(t, "simple", c.Simple) 39 | assert.Equal(t, 4, c.Nested.NumberVal) 40 | assert.Equal(t, "i am a simple string", c.Nested.StringVal) 41 | assert.Equal(t, true, c.Nested.BoolVal) 42 | } 43 | -------------------------------------------------------------------------------- /conf/logging.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | 7 | "github.com/sirupsen/logrus" 8 | ) 9 | 10 | // LoggingConfig specifies all the parameters needed for logging 11 | type 
LoggingConfig struct { 12 | Level string `json:"level"` 13 | File string `json:"file"` 14 | } 15 | 16 | // ConfigureLogging will take the logging configuration and also adds 17 | // a few default parameters 18 | func ConfigureLogging(config *LoggingConfig) (*logrus.Entry, error) { 19 | hostname, err := os.Hostname() 20 | if err != nil { 21 | return nil, err 22 | } 23 | 24 | // use a file if you want 25 | if config.File != "" { 26 | f, errOpen := os.OpenFile(config.File, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0660) 27 | if errOpen != nil { 28 | return nil, errOpen 29 | } 30 | logrus.SetOutput(f) 31 | } 32 | 33 | if config.Level != "" { 34 | level, err := logrus.ParseLevel(strings.ToUpper(config.Level)) 35 | if err != nil { 36 | return nil, err 37 | } 38 | logrus.SetLevel(level) 39 | } 40 | 41 | // always use the fulltimestamp 42 | logrus.SetFormatter(&logrus.TextFormatter{ 43 | FullTimestamp: true, 44 | DisableTimestamp: false, 45 | }) 46 | 47 | return logrus.StandardLogger().WithField("hostname", hostname), nil 48 | } 49 | -------------------------------------------------------------------------------- /models/db.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | _ "github.com/go-sql-driver/mysql" 5 | _ "github.com/lib/pq" 6 | _ "github.com/mattn/go-sqlite3" 7 | "github.com/pkg/errors" 8 | 9 | "github.com/jinzhu/gorm" 10 | 11 | "github.com/netlify/gojoin/conf" 12 | ) 13 | 14 | // Namespace puts all tables names under a common 15 | // namespace. This is useful if you want to use 16 | // the same database for several services and don't 17 | // want table names to collide. 
18 | var Namespace string 19 | 20 | func Connect(config *conf.DBConfig) (*gorm.DB, error) { 21 | db, err := gorm.Open(config.Driver, config.ConnURL) 22 | if err != nil { 23 | return nil, errors.Wrap(err, "opening database connection") 24 | } 25 | 26 | err = db.DB().Ping() 27 | if err != nil { 28 | return nil, errors.Wrap(err, "checking database connection") 29 | } 30 | 31 | if config.Automigrate { 32 | if err := AutoMigrate(db); err != nil { 33 | return nil, errors.Wrap(err, "migrating tables") 34 | } 35 | } 36 | 37 | return db, nil 38 | } 39 | 40 | func AutoMigrate(db *gorm.DB) error { 41 | return db.AutoMigrate(Subscription{}, User{}).Error 42 | } 43 | func tableName(defaultName string) string { 44 | if Namespace != "" { 45 | return Namespace + "_" + defaultName 46 | } 47 | return defaultName 48 | } 49 | -------------------------------------------------------------------------------- /models/subscriptions.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | 6 | "errors" 7 | "strings" 8 | 9 | "github.com/jinzhu/gorm" 10 | "github.com/pborman/uuid" 11 | ) 12 | 13 | type Subscription struct { 14 | ID string `gorm:"unique;primary_key" json:"id"` 15 | Type string `json:"type"` 16 | 17 | User *User `json:"user,omitempty"` 18 | UserID string `json:"user_id,omitempty"` 19 | 20 | RemoteID string `json:"remote_id"` 21 | Plan string `json:"plan"` 22 | 23 | CreatedAt time.Time `json:"created_at"` 24 | UpdatedAt time.Time `json:"updated_at"` 25 | DeletedAt *time.Time `json:"-"` 26 | } 27 | 28 | func (s *Subscription) BeforeCreate(scope *gorm.Scope) error { 29 | s.ID = uuid.NewRandom().String() 30 | scope.SetColumn("ID", s.ID) 31 | 32 | fields := map[string]string{ 33 | "user_id": s.UserID, 34 | "plan": s.Plan, 35 | "remote_id": s.RemoteID, 36 | "type": s.Type, 37 | } 38 | 39 | members := []string{} 40 | for k, v := range fields { 41 | if v == "" { 42 | members = append(members, k) 43 | } 44 | } 45 
| if len(members) > 0 { 46 | return errors.New("Missing required fields: " + strings.Join(members, ",")) 47 | } 48 | 49 | return nil 50 | } 51 | 52 | func (Subscription) TableName() string { 53 | return tableName("subscriptions") 54 | } 55 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 21 | 22 | **- Do you want to request a *feature* or report a *bug*?** 23 | 24 | **- What is the current behavior?** 25 | 26 | **- If the current behavior is a bug, please provide the steps to reproduce.** 27 | 28 | **- What is the expected behavior?** 29 | 30 | **- Please mention your Go version, and operating system version.** 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoJoin 2 | 3 | This acts as a proxy to Stripe. It exposes a very simple way to call Stripe's subscription endpoints. 4 | 5 | GoJoin is released under the [MIT License](LICENSE). 6 | Please make sure you understand its [implications and guarantees](https://writing.kemitchell.com/2016/09/21/MIT-License-Line-by-Line.html). 7 | 8 | ## Authentication 9 | All of the endpoints rely on a JWT token. The user ID set in that token is used to identify the user to Stripe. 10 | 11 | The API is as follows: 12 | 13 | GET /subscriptions -- list all the subscriptions for the user 14 | 15 | This endpoint will return a list of subscriptions, but also a JWT token that has been decorated with an `app_metadata.subscriptions` property which is a map of the user's subscriptions. 16 | 17 | These endpoints are all grouped by a `type` of subscription. For instance, you might have a `membership` type with 18 | plan levels gold, silver, and bronze. 
19 | 20 | GET /subscriptions/:type 21 | POST /subscriptions/:type 22 | DELETE /subscriptions/:type 23 | 24 | The POST endpoint takes a payload like so 25 | 26 | ``` json 27 | { 28 | "stripe_key": "xxxxx", 29 | "plan": "silver" 30 | } 31 | ``` 32 | 33 | Using this endpoint will create the subscription if it doesn't exist; otherwise it will change the existing subscription to that plan. 34 | The other responses are defined in `api/subscriptions.go`. 35 | -------------------------------------------------------------------------------- /cmd/root_cmd.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "log" 5 | 6 | "os" 7 | 8 | "github.com/netlify/gojoin/api" 9 | "github.com/netlify/gojoin/conf" 10 | "github.com/netlify/gojoin/models" 11 | "github.com/spf13/cobra" 12 | "github.com/stripe/stripe-go" 13 | ) 14 | 15 | var rootCmd = cobra.Command{ 16 | Run: run, 17 | } 18 | 19 | // RootCommand will setup and return the root command 20 | func RootCommand() *cobra.Command { 21 | rootCmd.PersistentFlags().StringP("config", "c", "", "the config file to use") 22 | rootCmd.Flags().IntP("port", "p", 0, "the port to use") 23 | 24 | rootCmd.AddCommand(&versionCmd) 25 | 26 | return &rootCmd 27 | } 28 | 29 | func run(cmd *cobra.Command, args []string) { 30 | config, err := conf.LoadConfig(cmd) 31 | if err != nil { 32 | log.Fatal("Failed to load config: " + err.Error()) 33 | } 34 | 35 | logger, err := conf.ConfigureLogging(&config.LogConfig) 36 | if err != nil { 37 | log.Fatal("Failed to configure logging: " + err.Error()) 38 | } 39 | 40 | logger.Infof("Connecting to DB") 41 | db, err := models.Connect(&config.DBConfig) 42 | if err != nil { 43 | logger.Fatal("Failed to connect to db: " + err.Error()) 44 | } 45 | 46 | logger.Info("Configuring stripe access") 47 | stripe.Key = config.StripeKey 48 | 49 | logger.Infof("Starting API on port %d", config.Port) 50 | a := api.NewAPI(config, db, &api.StripeProxy{}, Version) 51 | err = 
a.Serve() 52 | if err != nil { 53 | logger.WithError(err).Errorf("Error while running API: %v", err) 54 | os.Exit(1) 55 | } 56 | logger.Info("API Shutdown") 57 | } 58 | -------------------------------------------------------------------------------- /conf/reflect.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/spf13/viper" 9 | ) 10 | 11 | const tagPrefix = "viper" 12 | 13 | func populateConfig(config *Config) (*Config, error) { 14 | err := recursivelySet(reflect.ValueOf(config), "") 15 | if err != nil { 16 | return nil, err 17 | } 18 | 19 | return config, nil 20 | } 20 | 21 | func recursivelySet(val reflect.Value, prefix string) error { 23 | if val.Kind() != reflect.Ptr { 24 | return errors.New("config target must be a pointer to a struct") 25 | } 26 | 27 | // dereference 28 | val = reflect.Indirect(val) 29 | if val.Kind() != reflect.Struct { 30 | return errors.New("config target must point to a struct") 31 | } 32 | 33 | // grab the type for this instance 34 | vType := val.Type() 35 | 36 | // go through child fields 37 | for i := 0; i < val.NumField(); i++ { 38 | thisField := val.Field(i) 39 | thisType := vType.Field(i) 40 | tag := prefix + getTag(thisType) 41 | 42 | switch thisField.Kind() { 43 | case reflect.Struct: 44 | if err := recursivelySet(thisField.Addr(), tag+"."); err != nil { 45 | return err 46 | } 47 | case reflect.Int: 48 | fallthrough 49 | case reflect.Int32: 50 | fallthrough 51 | case reflect.Int64: 52 | // you can only set with an int64 -> int 53 | configVal := int64(viper.GetInt(tag)) 54 | thisField.SetInt(configVal) 55 | case reflect.String: 56 | thisField.SetString(viper.GetString(tag)) 57 | case reflect.Bool: 58 | thisField.SetBool(viper.GetBool(tag)) 59 | default: 60 | return fmt.Errorf("unexpected type detected ~ aborting: %s", thisField.Kind()) 61 | } 62 | } 63 | 64 | return nil 65 | } 66 | 67 | func getTag(field reflect.StructField) string { 68 | // check if maybe we 
have a special magic tag 69 | tag := field.Tag 70 | if tag != "" { 71 | for _, prefix := range []string{tagPrefix, "mapstructure", "json"} { 72 | if v := tag.Get(prefix); v != "" { 73 | return v 74 | } 75 | } 76 | } 77 | 78 | return field.Name 79 | } 80 | -------------------------------------------------------------------------------- /api/payers.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/stripe/stripe-go" 7 | "github.com/stripe/stripe-go/customer" 8 | "github.com/stripe/stripe-go/sub" 9 | ) 10 | 11 | type payerProxy interface { 12 | createCustomer(userID, email, payToken string) (string, error) 13 | create(userID, plan, token string) (string, error) 14 | update(subID, plan, token string) (string, error) 15 | delete(subID string) error 16 | } 17 | 18 | type StripeProxy struct { 19 | } 20 | 21 | func (StripeProxy) create(userID, plan, token string) (string, error) { 22 | s, err := sub.New(&stripe.SubParams{ 23 | Customer: userID, 24 | Plan: plan, 25 | }) 26 | if err != nil { 27 | return "", err 28 | } 29 | return s.ID, nil 30 | } 31 | 32 | func (StripeProxy) update(subID, plan, token string) (string, error) { 33 | s, err := sub.Update(subID, &stripe.SubParams{ 34 | Plan: plan, 35 | }) 36 | if err != nil { 37 | return "", err 38 | } 39 | 40 | return s.ID, nil 41 | } 42 | 43 | func (StripeProxy) delete(subID string) error { 44 | _, err := sub.Cancel(subID, &stripe.SubParams{}) 45 | return err 46 | } 47 | 48 | func (StripeProxy) createCustomer(userID, email, payToken string) (string, error) { 49 | params := &stripe.CustomerParams{ 50 | Email: email, 51 | Source: &stripe.SourceParams{ 52 | Token: payToken, 53 | }, 54 | } 55 | params.Meta = map[string]string{"nf_id": userID} 56 | c, err := customer.New(params) 57 | if err != nil { 58 | return "", err 59 | } 60 | return c.ID, nil 61 | } 62 | 63 | /* 64 | 65 | POST /subscriptions/members/smashing 66 | {first_name: 
Matt, last_name: Biilmann, strie_token: "sdfsdfsfsd"} 67 | Signed by {user_id: 1234} 68 | 69 | --- 70 | 71 | existingUser = db.findUser({user_id: 1234}) 72 | 73 | if existingUser 74 | stripApi.subscribeToPlan(existingUser.customer_id, params.plan, source: params.stripe_token) 75 | else 76 | customer = stripeApi.createCustomer({user_id: ...}) 77 | user.create({customer_id: customer.id}) 78 | stripeApi.subscripbeToPlan(user.customer_id, params.plan, source: params.stripe_token) 79 | end 80 | 81 | 82 | */ 83 | 84 | type errorProxy struct { 85 | } 86 | 87 | func (errorProxy) createCustomer(_, _, _ string) (string, error) { 88 | return "", errors.New("No payer proxy provided") 89 | } 90 | 91 | func (errorProxy) create(userID, plan, token string) (string, error) { 92 | return "", errors.New("No payer proxy provided") 93 | } 94 | func (errorProxy) update(subID, plan, token string) (string, error) { 95 | return "", errors.New("No payer proxy provided") 96 | } 97 | func (errorProxy) delete(subID string) error { 98 | return errors.New("No payer proxy provided") 99 | } 100 | -------------------------------------------------------------------------------- /conf/configuration.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "net/url" 5 | "os" 6 | "strconv" 7 | "strings" 8 | 9 | "github.com/pkg/errors" 10 | "github.com/spf13/cobra" 11 | "github.com/spf13/viper" 12 | ) 13 | 14 | // Config the application's configuration 15 | type Config struct { 16 | Port int `mapstructure:"port" json:"port"` 17 | JWTSecret string `mapstructure:"jwt_secret" json:"jwt_secret"` 18 | AdminGroupName string `mapstructure:"admin_group_name" json:"admin_group_name"` 19 | StripeKey string `mapstructure:"stripe_key" json:"stripe_key"` 20 | LogConfig LoggingConfig `mapstructure:"log" json:"log"` 21 | DBConfig DBConfig `mapstructure:"db" json:"db"` 22 | } 23 | 24 | type DBConfig struct { 25 | Driver string `mapstructure:"driver" 
json:"driver"` 26 | ConnURL string `mapstructure:"url" json:"url"` 27 | Namespace string `mapstructure:"namespace" json:"namespace"` 28 | Automigrate bool `mapstructure:"automigrate" json:"automigrate"` 29 | } 30 | 31 | // LoadConfig loads the config from a file if specified, otherwise from the environment 32 | func LoadConfig(cmd *cobra.Command) (*Config, error) { 33 | viper.SetConfigType("json") 34 | 35 | err := viper.BindPFlags(cmd.Flags()) 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | viper.SetEnvPrefix("GOJOIN") 41 | viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 42 | viper.AutomaticEnv() 43 | 44 | if configFile, _ := cmd.Flags().GetString("config"); configFile != "" { 45 | viper.SetConfigFile(configFile) 46 | } else { 47 | viper.SetConfigName("config") 48 | viper.AddConfigPath("./") 49 | viper.AddConfigPath("$HOME/.gojoin/") 50 | } 51 | 52 | if err := viper.ReadInConfig(); err != nil { 53 | _, ok := err.(viper.ConfigFileNotFoundError) 54 | if !ok { 55 | return nil, errors.Wrap(err, "reading configuration from files") 56 | } 57 | } 58 | 59 | config := new(Config) 60 | if err := viper.Unmarshal(config); err != nil { 61 | return nil, errors.Wrap(err, "unmarshaling configuration") 62 | } 63 | 64 | config, err = populateConfig(config) 65 | if err != nil { 66 | return nil, errors.Wrap(err, "populating config") 67 | } 68 | 69 | return validateConfig(config) 70 | } 71 | 72 | func validateConfig(config *Config) (*Config, error) { 73 | if config.DBConfig.ConnURL == "" && os.Getenv("DATABASE_URL") != "" { 74 | config.DBConfig.ConnURL = os.Getenv("DATABASE_URL") 75 | } 76 | 77 | if config.DBConfig.Driver == "" && config.DBConfig.ConnURL != "" { 78 | u, err := url.Parse(config.DBConfig.ConnURL) 79 | if err != nil { 80 | return nil, errors.Wrap(err, "parsing db connection url") 81 | } 82 | config.DBConfig.Driver = u.Scheme 83 | } 84 | 85 | if config.Port == 0 && os.Getenv("PORT") != "" { 86 | port, err := strconv.Atoi(os.Getenv("PORT")) 87 | if err != 
nil { 88 | return nil, errors.Wrap(err, "formatting PORT into int") 89 | } 90 | 91 | config.Port = port 92 | } 93 | 94 | if config.Port == 0 { 95 | config.Port = 7070 96 | } 97 | 98 | return config, nil 99 | } 100 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of experience, 9 | nationality, personal appearance, race, religion, or sexual identity and 10 | orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective 
action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at david@netlify.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at [http://contributor-covenant.org/version/1/4][version] 72 | 73 | [homepage]: http://contributor-covenant.org 74 | [version]: http://contributor-covenant.org/version/1/4/ 75 | -------------------------------------------------------------------------------- /api/context.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/sirupsen/logrus" 8 | "github.com/netlify/gojoin/conf" 9 | 10 | "github.com/dgrijalva/jwt-go" 11 | "github.com/jinzhu/gorm" 12 | "golang.org/x/net/context" 13 | ) 14 | 15 | /* Context keys for the request-scoped values handled below. NOTE(review): these are plain untyped string constants, so any other code storing string keys such as "db" or "token" in the same context would collide with them (staticcheck SA1029 recommends an unexported key type). versionKey has no setter or getter in this file. */ const ( 16 | dbKey = "db" 17 | startTimeKey = "start_time" 18 | versionKey = "app_version" 19 | configKey = "app_config" 20 | loggerKey = "app_logger" 21 | reqIDKey = "request_id" 22 | adminFlagKey = "admin_flag" 23 | tokenKey = "token" 24 | payerProxyKey = "payer_proxy" 25 | ) 26 | 27 | /* setStartTime stores a pointer to a copy of startTime. */ func setStartTime(ctx context.Context, startTime time.Time) context.Context { 28 | return context.WithValue(ctx, startTimeKey, &startTime) 29 | } 30 | /* getStartTime returns the stored start time, or nil when none was set. */ func getStartTime(ctx context.Context) *time.Time { 31 | obj := ctx.Value(startTimeKey) 32 | if obj == nil { 33 | return nil 34 | } 35 | return obj.(*time.Time) 36 | } 37 | 38 | /* setConfig stores the application config. */ func setConfig(ctx context.Context, config *conf.Config) context.Context { 39 | return context.WithValue(ctx, configKey, config) 40 | } 41 | /* getConfig returns the stored config, or nil when none was set. */ func getConfig(ctx context.Context) *conf.Config { 42 | obj := ctx.Value(configKey) 43 | if obj == nil { 44 | return nil 45 | } 46 | return obj.(*conf.Config) 47 | } 48 | 49 | /* setLogger stores the request-scoped log entry. */ func setLogger(ctx context.Context, log *logrus.Entry) context.Context { 50 | return context.WithValue(ctx, loggerKey, log) 51 | } 52 | /* getLogger returns the stored entry, falling back to a new entry on logrus' standard logger when none was set. */ func getLogger(ctx context.Context) *logrus.Entry { 53 | obj := ctx.Value(loggerKey) 54 | if obj == nil { 55 | return logrus.NewEntry(logrus.StandardLogger()) 56 | } 57 | return
obj.(*logrus.Entry) 58 | } 59 | 60 | /* setRequestID stores the request ID. */ func setRequestID(ctx context.Context, reqID string) context.Context { 61 | return context.WithValue(ctx, reqIDKey, reqID) 62 | } 63 | /* getRequestID returns the stored request ID, or "" when none was set. */ func getRequestID(ctx context.Context) string { 64 | obj := ctx.Value(reqIDKey) 65 | if obj == nil { 66 | return "" 67 | } 68 | return obj.(string) 69 | } 70 | 71 | /* setAdminFlag records whether the caller is an admin. */ func setAdminFlag(ctx context.Context, isAdmin bool) context.Context { 72 | return context.WithValue(ctx, adminFlagKey, isAdmin) 73 | } 74 | 75 | /* isAdmin returns the stored admin flag; false when none was set. */ func isAdmin(ctx context.Context) bool { 76 | obj := ctx.Value(adminFlagKey) 77 | if obj == nil { 78 | return false 79 | } 80 | return obj.(bool) 81 | } 82 | 83 | /* setDB stores the database handle. */ func setDB(ctx context.Context, db *gorm.DB) context.Context { 84 | return context.WithValue(ctx, dbKey, db) 85 | } 86 | /* getDB returns the stored database handle. It panics (unchecked type assertion) if none was set. */ func getDB(ctx context.Context) *gorm.DB { 87 | return ctx.Value(dbKey).(*gorm.DB) 88 | } 89 | 90 | /* getClaims returns the typed claims of the stored token. It panics if no token was stored or if its claims are not a JWTClaims pointer. */ func getClaims(ctx context.Context) *JWTClaims { 91 | return ctx.Value(tokenKey).(*jwt.Token).Claims.(*JWTClaims) 92 | } 93 | 94 | /* getClaimsAsMap re-parses the stored token's raw string into jwt.MapClaims, verifying the signature with config.JWTSecret and accepting only HS256. It returns nil when no config is stored or when parsing/validation fails, and panics if no token was stored. */ func getClaimsAsMap(ctx context.Context) jwt.MapClaims { 95 | token := ctx.Value(tokenKey).(*jwt.Token) 96 | config := getConfig(ctx) 97 | if config == nil { 98 | return nil 99 | } 100 | claims := jwt.MapClaims{} 101 | /* reassigns the outer token; the returned value is not used afterwards */ token, err := jwt.ParseWithClaims(token.Raw, &claims, func(token *jwt.Token) (interface{}, error) { 102 | if token.Header["alg"] != jwt.SigningMethodHS256.Name { 103 | return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) 104 | } 105 | return []byte(config.JWTSecret), nil 106 | }) 107 | if err != nil { 108 | return nil 109 | } 110 | 111 | return claims 112 | } 113 | 114 | /* setToken stores the parsed JWT. */ func setToken(ctx context.Context, token *jwt.Token) context.Context { 115 | return context.WithValue(ctx, tokenKey, token) 116 | } 117 | 118 | /* setPayerProxy stores the payer proxy (the subscription handlers use it for their Stripe calls). */ func setPayerProxy(ctx context.Context, proxy payerProxy) context.Context { 119 | return context.WithValue(ctx, payerProxyKey, proxy) 120 | } 121 | /* getPayerProxy returns the stored proxy, falling back to &errorProxy{} when none was set. */ func getPayerProxy(ctx context.Context) payerProxy { 122 | obj := ctx.Value(payerProxyKey)
123 | if obj == nil { 124 | return &errorProxy{} 125 | } 126 | return obj.(payerProxy) 127 | } 128 | -------------------------------------------------------------------------------- /glide.lock: -------------------------------------------------------------------------------- 1 | hash: b95e2f65098f25b21f5f94b1aecaef84ba115c24eb7b013656e597cc195a69ed 2 | updated: 2017-07-28T16:18:17.301657768-04:00 3 | imports: 4 | - name: github.com/dgrijalva/jwt-go 5 | version: d2709f9f1f31ebcda9651b03077758c1f3a0018c 6 | - name: github.com/dimfeld/httptreemux 7 | version: 86f7c217d9043ebc6adfd8e2ed04a0bb1e1db651 8 | - name: github.com/fsnotify/fsnotify 9 | version: 7d7316ed6e1ed2de075aab8dfc76de5d158d66e1 10 | - name: github.com/go-sql-driver/mysql 11 | version: 2e00b5cd70399450106cec6431c2e2ce3cae5034 12 | - name: github.com/golang/protobuf 13 | version: 69b215d01a5606c843240eab4937eab3acee6530 14 | subpackages: 15 | - proto 16 | - name: github.com/guregu/kami 17 | version: 556ef16b10fbac3cec79e38bbf26ce4af543608f 18 | subpackages: 19 | - treemux 20 | - name: github.com/hashicorp/hcl 21 | version: 630949a3c5fa3c613328e1b8256052cbc2327c9b 22 | subpackages: 23 | - hcl/ast 24 | - hcl/parser 25 | - hcl/scanner 26 | - hcl/strconv 27 | - hcl/token 28 | - json/parser 29 | - json/scanner 30 | - json/token 31 | - name: github.com/inconshreveable/mousetrap 32 | version: 76626ae9c91c4f2a10f34cad8ce83ea42c93bb75 33 | - name: github.com/jinzhu/gorm 34 | version: 5409931a1bb87e484d68d649af9367c207713ea2 35 | - name: github.com/jinzhu/inflection 36 | version: 1c35d901db3da928c72a72d8458480cc9ade058f 37 | - name: github.com/lib/pq 38 | version: ca5bc43047f2138703da0f3d3ca89a59f3d597f1 39 | subpackages: 40 | - oid 41 | - name: github.com/magiconair/properties 42 | version: b3b15ef068fd0b17ddf408a23669f20811d194d2 43 | - name: github.com/mattn/go-sqlite3 44 | version: eac1dfa2a61ebccaa117538a5bb12044f6700cd0 45 | - name: github.com/mitchellh/mapstructure 46 | version:
db1efb556f84b25a0a13a04aad883943538ad2e0 47 | - name: github.com/pborman/uuid 48 | version: a97ce2ca70fa5a848076093f05e639a89ca34d06 49 | - name: github.com/pelletier/go-buffruneio 50 | version: c37440a7cf42ac63b919c752ca73a85067e05992 51 | - name: github.com/pelletier/go-toml 52 | version: 13d49d4606eb801b8f01ae542b4afc4c6ee3d84a 53 | - name: github.com/pkg/errors 54 | version: bfd5150e4e41705ded2129ec33379de1cb90b513 55 | - name: github.com/rs/cors 56 | version: a62a804a8a009876ca59105f7899938a1349f4b3 57 | - name: github.com/rs/xhandler 58 | version: ed27b6fd65218132ee50cd95f38474a3d8a2cd12 59 | - name: github.com/sirupsen/logrus 60 | version: a3f95b5c423586578a4e099b11a46c2479628cac 61 | - name: github.com/spf13/afero 62 | version: 9be650865eab0c12963d8753212f4f9c66cdcf12 63 | subpackages: 64 | - mem 65 | - name: github.com/spf13/cast 66 | version: ce135a4ebeee6cfe9a26c93ee0d37825f26113c7 67 | - name: github.com/spf13/cobra 68 | version: fcd0c5a1df88f5d6784cb4feead962c3f3d0b66c 69 | - name: github.com/spf13/jwalterweatherman 70 | version: fa7ca7e836cf3a8bb4ebf799f472c12d7e903d66 71 | - name: github.com/spf13/pflag 72 | version: 9ff6c6923cfffbcd502984b8e0c80539a94968b7 73 | - name: github.com/spf13/viper 74 | version: 7538d73b4eb9511d85a9f1dfef202eeb8ac260f4 75 | - name: github.com/stretchr/testify 76 | version: 69483b4bd14f5845b5a1e55bca19e954e827f1d0 77 | subpackages: 78 | - assert 79 | - name: github.com/stripe/stripe-go 80 | version: fd0493806620259f607f131587fa2bf1cfdc5218 81 | subpackages: 82 | - customer 83 | - orderitem 84 | - sub 85 | - name: github.com/valyala/fasthttp 86 | version: d42167fd04f636e20b005e9934159e95454233c7 87 | subpackages: 88 | - fasthttputil 89 | - name: github.com/zenazn/goji 90 | version: 4d7077956293261309684d3cf1af673f773c6819 91 | subpackages: 92 | - bind 93 | - graceful 94 | - graceful/listener 95 | - web/mutil 96 | - name: golang.org/x/net 97 | version: d379faa25cbdc04d653984913a2ceb43b0bc46d7 98 | subpackages: 99 | - context 
100 | - name: golang.org/x/sys 101 | version: e48874b42435b4347fc52bdee0424a52abc974d7 102 | subpackages: 103 | - unix 104 | - name: golang.org/x/text 105 | version: f28f36722d5ef2f9655ad3de1f248e3e52ad5ebd 106 | subpackages: 107 | - transform 108 | - unicode/norm 109 | - name: google.golang.org/appengine 110 | version: 5403c08c6e8fb3b2dc1209d2d833d8e8ac8240de 111 | subpackages: 112 | - internal 113 | - internal/app_identity 114 | - internal/base 115 | - internal/datastore 116 | - internal/log 117 | - internal/modules 118 | - internal/remote_api 119 | - name: gopkg.in/square/go-jose.v1 120 | version: aa2e30fdd1fe9dd3394119af66451ae790d50e0d 121 | subpackages: 122 | - json 123 | - name: gopkg.in/yaml.v2 124 | version: a3f3340b5840cee44f372bddb5880fcbc419b46a 125 | testImports: 126 | - name: github.com/davecgh/go-spew 127 | version: 6d212800a42e8ab5c146b8ace3490ee17e5225f9 128 | subpackages: 129 | - spew 130 | - name: github.com/klauspost/compress 131 | version: 14c9a76e3c95e47f8ccce949bba2c1101a8b85e6 132 | subpackages: 133 | - flate 134 | - gzip 135 | - zlib 136 | - name: github.com/klauspost/cpuid 137 | version: 09cded8978dc9e80714c4d85b0322337b0a1e5e0 138 | - name: github.com/klauspost/crc32 139 | version: 1bab8b35b6bb565f92cbc97939610af9369f942a 140 | - name: github.com/pmezard/go-difflib 141 | version: d8ed2627bdf02c080bf22230dbb337003b7aba2d 142 | subpackages: 143 | - difflib 144 | -------------------------------------------------------------------------------- /api/api.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "regexp" 9 | "time" 10 | 11 | jwt "github.com/dgrijalva/jwt-go" 12 | "github.com/guregu/kami" 13 | "github.com/pborman/uuid" 14 | "github.com/rs/cors" 15 | "github.com/sirupsen/logrus" 16 | 17 | "github.com/jinzhu/gorm" 18 | "github.com/netlify/gojoin/conf" 19 | "github.com/zenazn/goji/web/mutil" 20 | ) 21 | 22 | 
/* API holds the HTTP handler together with the config, database and payer proxy that populateConfig injects into each /subscriptions request. */ type API struct { 23 | log *logrus.Entry 24 | config *conf.Config 25 | port int 26 | handler http.Handler 27 | db *gorm.DB 28 | payerProxy payerProxy 29 | version string 30 | } 31 | 32 | /* JWTClaims are the claims expected in bearer tokens: the standard claims plus the user's email and group memberships. */ type JWTClaims struct { 33 | jwt.StandardClaims 34 | Email string `json:"email"` 35 | Groups []string `json:"groups"` 36 | } 37 | 38 | /* bearerRegexp captures the token from a "Bearer <token>" or "bearer <token>" Authorization header. */ var bearerRegexp = regexp.MustCompile(`^(?:B|b)earer (\S+$)`) 39 | var signingMethod = jwt.SigningMethodHS256 40 | 41 | /* NewAPI builds the kami router (GET / plus the /subscriptions routes, which run populateConfig first) and wraps it in a CORS handler. */ func NewAPI(config *conf.Config, db *gorm.DB, proxy payerProxy, version string) *API { 42 | api := &API{ 43 | log: logrus.WithField("component", "api"), 44 | config: config, 45 | port: config.Port, 46 | db: db, 47 | payerProxy: proxy, 48 | version: version, 49 | } 50 | 51 | k := kami.New() 52 | k.LogHandler = logCompleted 53 | 54 | k.Get("/", api.hello) 55 | 56 | k.Use("/subscriptions/", api.populateConfig) 57 | k.Use("/subscriptions", api.populateConfig) 58 | 59 | k.Get("/subscriptions", listSubs) 60 | k.Get("/subscriptions/:type", viewSub) 61 | k.Put("/subscriptions/:type", createOrModSub) 62 | k.Delete("/subscriptions/:type", deleteSub) 63 | 64 | corsHandler := cors.New(cors.Options{ 65 | AllowedMethods: []string{"GET", "POST", "PATCH", "PUT", "DELETE"}, 66 | AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, 67 | AllowCredentials: true, 68 | }) 69 | 70 | api.handler = corsHandler.Handler(k) 71 | return api 72 | } 73 | 74 | /* Serve listens on the configured port; it only returns the error that stopped the server. */ func (a *API) Serve() error { 75 | l := fmt.Sprintf(":%d", a.port) 76 | a.log.Infof("GoJoin API started on: %s", l) 77 | return http.ListenAndServe(l, a.handler) 78 | } 79 | 80 | /* logCompleted is the router's LogHandler: it logs the status, the duration in nanoseconds (when a start time was stored) and the request details. */ func logCompleted(ctx context.Context, wp mutil.WriterProxy, r *http.Request) { 81 | log := getLogger(ctx).WithField("status", wp.Status()) 82 | 83 | start := getStartTime(ctx) 84 | if start != nil { 85 | log = log.WithField("duration", time.Since(*start).Nanoseconds()) 86 | } 87 | 88 | log.Infof("Completed request %s.
path: %s, method: %s, status: %d", getRequestID(ctx), r.URL.Path, r.Method, wp.Status()) 89 | } 90 | 91 | /* populateConfig is the /subscriptions middleware. It stores the request ID, start time, config, DB and payer proxy in the context, then authenticates the bearer token and stores the token, the admin flag (membership in config.AdminGroupName) and an enriched logger. On failure it writes an error response and returns nil. NOTE(review): this relies on kami stopping the middleware chain on a nil context; confirm against the vendored version. */ func (a *API) populateConfig(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 92 | reqID := uuid.NewRandom().String() 93 | log := a.log.WithFields(logrus.Fields{ 94 | "request_id": reqID, 95 | "method": r.Method, 96 | "path": r.URL.Path, 97 | }) 98 | log.Info("Started request") 99 | 100 | ctx = setRequestID(ctx, reqID) 101 | ctx = setStartTime(ctx, time.Now()) 102 | ctx = setConfig(ctx, a.config) 103 | ctx = setDB(ctx, a.db) 104 | 105 | ctx = setPayerProxy(ctx, a.payerProxy) 106 | 107 | token, err := extractToken(a.config.JWTSecret, r) 108 | if err != nil { 109 | log.WithError(err).Info("Failed to parse token") 110 | sendJSON(w, err.Code, err) 111 | return nil 112 | } 113 | 114 | if token == nil { 115 | log.Info("Attempted to make unauthenticated request") 116 | writeError(w, http.StatusBadRequest, "Must provide a valid JWT Token") 117 | return nil 118 | } 119 | 120 | claims := token.Claims.(*JWTClaims) 121 | if claims.Subject == "" { 122 | log.Info("JWT token did not contain a sub") 123 | writeError(w, http.StatusBadRequest, "JWT Token must contain a sub") 124 | return nil 125 | } 126 | 127 | adminFlag := false 128 | for _, g := range claims.Groups { 129 | if g == a.config.AdminGroupName { 130 | adminFlag = true 131 | break 132 | } 133 | } 134 | log = log.WithFields(logrus.Fields{ 135 | "is_admin": adminFlag, 136 | "user_id": claims.Subject, 137 | }) 138 | ctx = setAdminFlag(ctx, adminFlag) 139 | ctx = setToken(ctx, token) 140 | ctx = setLogger(ctx, log) 141 | 142 | return ctx 143 | } 144 | 145 | /* extractToken verifies the bearer token in the Authorization header (HS256 only). It returns (nil, nil) when the header is absent, 400 for a malformed header, and 401 for an unverifiable, exp-less or expired token. */ func extractToken(secret string, r *http.Request) (*jwt.Token, *HTTPError) { 146 | authHeader := r.Header.Get("Authorization") 147 | if authHeader == "" { 148 | return nil, nil 149 | } 150 | 151 | matches := bearerRegexp.FindStringSubmatch(authHeader) 152 | if len(matches) != 2 { 153 | return nil, httpError(http.StatusBadRequest, "Bad
authentication header") 154 | } 155 | 156 | token, err := jwt.ParseWithClaims(matches[1], &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { 157 | if token.Header["alg"] != signingMethod.Name { 158 | return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) 159 | } 160 | return []byte(secret), nil 161 | }) 162 | if err != nil { 163 | return nil, httpError(http.StatusUnauthorized, "Invalid Token") 164 | } 165 | claims := token.Claims.(*JWTClaims) 166 | if claims.StandardClaims.ExpiresAt == 0 { /* jwt-go's StandardClaims.Valid accepts tokens without exp, and the check below would then report "Token expired at 1970-..."; reject them with an accurate message instead */ return nil, httpError(http.StatusUnauthorized, "Token must contain an exp claim") } 167 | if claims.StandardClaims.ExpiresAt < time.Now().Unix() { 168 | return nil, httpError(http.StatusUnauthorized, fmt.Sprintf("Token expired at %v", time.Unix(claims.StandardClaims.ExpiresAt, 0))) 169 | } 170 | return token, nil 171 | } 172 | 173 | /* hello reports the application name and version. */ func (a *API) hello(ctx context.Context, w http.ResponseWriter, r *http.Request) { 174 | sendJSON(w, http.StatusOK, map[string]string{ 175 | "version": a.version, 176 | "application": "gojoin", 177 | }) 178 | } 179 | 180 | /* sendJSON writes obj as JSON with the given status. The status line is already sent when encoding runs, so an encoding failure can only be logged. */ func sendJSON(w http.ResponseWriter, status int, obj interface{}) { 181 | w.Header().Set("Content-Type", "application/json") 182 | w.WriteHeader(status) 183 | encoder := json.NewEncoder(w) 184 | if err := encoder.Encode(obj); err != nil { logrus.WithError(err).Warn("Failed to encode JSON response") } 185 | } 186 | -------------------------------------------------------------------------------- /api/api_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "net/http/httptest" 11 | "os" 12 | "testing" 13 | "time" 14 | 15 | "github.com/dgrijalva/jwt-go" 16 | "github.com/jinzhu/gorm" 17 | "github.com/sirupsen/logrus" 18 | "github.com/stretchr/testify/assert" 19 | 20 | "github.com/netlify/gojoin/conf" 21 | "github.com/netlify/gojoin/models" 22 | ) 23 | 24 | /* Package-level fixtures shared by the tests; TestMain initialises them. */ var db *gorm.DB 25 | var config *conf.Config 26 | var api *API 27 | 28 | var serverURL string 29 | var client *http.Client 30 | 31 | var testUserID = "joker" 32 | var testUserEmail = "joker@dc.com" 33 | 
34 | /* TestMain sets up a sqlite-backed API on a temporary file, serves it through httptest for every test in the package, and cleans up afterwards. os.Exit skips deferred calls, so the cleanup runs explicitly before exiting. */ func TestMain(m *testing.M) { 35 | f, err := ioutil.TempFile("", "test-db") 36 | if err != nil { 37 | panic(err) 38 | } 39 | f.Close() /* only the file's path is handed to the sqlite driver below */ 40 | 41 | config = &conf.Config{ 42 | AdminGroupName: "admin", 43 | JWTSecret: "secret", 44 | DBConfig: conf.DBConfig{ 45 | Automigrate: true, 46 | Namespace: "test", 47 | Driver: "sqlite3", 48 | ConnURL: f.Name(), 49 | }, 50 | } 51 | db, err = models.Connect(&config.DBConfig) 52 | if err != nil { 53 | fmt.Println("Failed to connect to db:", err) 54 | os.Remove(f.Name()) 55 | os.Exit(1) 56 | } 57 | logrus.SetLevel(logrus.DebugLevel) 58 | api = NewAPI(config, db, errorProxy{}, "test") 59 | server := httptest.NewServer(api.handler) 60 | serverURL = server.URL 61 | client = new(http.Client) 62 | code := m.Run() 63 | server.Close() 64 | db.Close() 65 | os.Remove(f.Name()) 66 | os.Exit(code) 67 | } 68 | /* TestTokenExtraction checks that a valid admin token round-trips through extractToken. */ func TestTokenExtraction(t *testing.T) { 69 | tokenString := testToken(t, testUserID, testUserEmail, config.JWTSecret, true) 70 | r, _ := http.NewRequest("GET", "http://doesnotmatter", nil) 71 | r.Header.Add("Authorization", "Bearer "+tokenString) 72 | 73 | token, err := extractToken("secret", r) 74 | assert.Nil(t, err) 75 | if assert.NotNil(t, token) { 76 | assert.Nil(t, token.Claims.Valid()) 77 | outClaims := token.Claims.(*JWTClaims) 78 | assert.Equal(t, "joker", outClaims.Subject) 79 | 80 | foundAdmin := false 81 | for _, g := range outClaims.Groups { 82 | switch g { 83 | case "admin": 84 | foundAdmin = true 85 | default: 86 | assert.Fail(t, "unexpected group: "+g) 87 | } 88 | } 89 | assert.True(t, foundAdmin) 90 | } 91 | } 92 | 93 | /* An unparsable bearer token is rejected with 401. */ func TestBadAuthHeader(t *testing.T) { 94 | r, _ := http.NewRequest("GET", serverURL+"/subscriptions", nil) 95 | r.Header.Add("Authorization", "Bearer NONSENSE") 96 | 97 | rsp, _ := client.Do(r) 98 | extractError(t, http.StatusUnauthorized, rsp) 99 | } 100 | 101 | /* A request without an Authorization header is rejected with 400. */ func TestMissingAuthHeader(t *testing.T) { 102 | r, _ := http.NewRequest("GET", serverURL+"/subscriptions", nil) 103 | 104 | rsp, _ := client.Do(r) 105 | extractError(t,
http.StatusBadRequest, rsp) 106 | } 107 | 108 | /* TestMiddleware checks that populateConfig fills the context and the logger fields. */ func TestMiddleware(t *testing.T) { 109 | tokenString := testToken(t, testUserID, testUserEmail, config.JWTSecret, true) 110 | r, _ := http.NewRequest("GET", serverURL+"/subscriptions", nil) 111 | r.Header.Add("Authorization", "Bearer "+tokenString) 112 | 113 | ctx := context.Background() 114 | ctx = api.populateConfig(ctx, nil, r) 115 | 116 | assert.Equal(t, db, getDB(ctx)) 117 | assert.NotEqual(t, "", getRequestID(ctx)) 118 | assert.True(t, isAdmin(ctx)) 119 | assert.Equal(t, config, getConfig(ctx)) 120 | log := getLogger(ctx) 121 | 122 | expectedFields := map[string]bool{ 123 | "request_id": false, 124 | "method": false, 125 | "path": false, 126 | "is_admin": false, 127 | "user_id": false, 128 | } 129 | 130 | for k, v := range log.Data { 131 | assert.NotEqual(t, "", v, k+" was empty") 132 | expectedFields[k] = true 133 | } 134 | for k, v := range expectedFields { 135 | assert.True(t, v, k+" is missing") 136 | } 137 | } 138 | 139 | /* GET / reports the version and application name. */ func TestGetHello(t *testing.T) { 140 | rsp := request(t, "GET", "", nil, false) 141 | payload := make(map[string]interface{}) 142 | extractPayload(t, rsp, &payload) 143 | 144 | _, exists := payload["version"] 145 | assert.True(t, exists) 146 | _, exists = payload["application"] 147 | assert.True(t, exists) 148 | } 149 | 150 | // ------------------------------------------------------------------------------------------------ 151 | // utilities 152 | // ------------------------------------------------------------------------------------------------ 153 | 154 | /* request sends a request authenticated as the test user, JSON-encoding body when it is non-nil and adding the admin group when isAdmin is set. */ func request(t *testing.T, method, path string, body interface{}, isAdmin bool) *http.Response { 155 | var r *http.Request 156 | if body != nil { 157 | b, err := json.Marshal(body) 158 | if err != nil { 159 | assert.FailNow(t, "failed to make request: "+err.Error()) 160 | } 161 | 162 | r, _ = http.NewRequest(method, serverURL+path, bytes.NewBuffer(b)) 163 | } else { 164 | r, _ = http.NewRequest(method, serverURL+path, nil) 165 | }
166 | tokenString := testToken(t, testUserID, testUserEmail, config.JWTSecret, isAdmin) 167 | r.Header.Add("Authorization", "Bearer "+tokenString) 168 | 169 | rsp, err := client.Do(r) 170 | if !assert.NoError(t, err) { 171 | assert.FailNow(t, "failed to make request: "+r.URL.String()) 172 | } 173 | 174 | return rsp 175 | } 176 | 177 | /* extractPayload requires a 200 response and decodes its body into payload. */ func extractPayload(t *testing.T, rsp *http.Response, payload interface{}) { 178 | b, _ := ioutil.ReadAll(rsp.Body) 179 | defer rsp.Body.Close() 180 | if rsp.StatusCode != http.StatusOK { 181 | assert.FailNow(t, fmt.Sprintf("Expected a 200 - %d: with payload: %s", rsp.StatusCode, string(b))) 182 | } 183 | 184 | err := json.Unmarshal(b, payload) 185 | if !assert.NoError(t, err) { 186 | assert.FailNow(t, "Failed to parse payload: "+string(b)) 187 | } 188 | } 189 | 190 | /* extractError requires status errCode and decodes the body as an HTTPError with that code and a non-empty message; the response body is always closed. */ func extractError(t *testing.T, errCode int, rsp *http.Response) *HTTPError { 191 | var err *HTTPError 192 | defer rsp.Body.Close() 193 | if assert.Equal(t, errCode, rsp.StatusCode) { 194 | b, _ := ioutil.ReadAll(rsp.Body) 195 | err = new(HTTPError) 196 | e := json.Unmarshal(b, err) 197 | if !assert.NoError(t, e) { 198 | assert.FailNow(t, "Failed to parse payload: "+string(b)) 199 | } 200 | assert.Equal(t, errCode, err.Code) 201 | assert.NotEmpty(t, err.Message) 202 | } 203 | 204 | return err 205 | } 206 | 207 | /* testToken signs an HS256 token for the user that is valid for one hour. */ func testToken(t *testing.T, name, email, secret string, isAdmin bool) string { 208 | return testTokenWithGroups(t, name, email, secret, isAdmin, []string{}) 209 | } 210 | 211 | /* testTokenWithGroups is testToken with extra groups appended after the optional "admin" group. */ func testTokenWithGroups(t *testing.T, name, email, secret string, isAdmin bool, groups []string) string { 212 | claims := &JWTClaims{ 213 | StandardClaims: jwt.StandardClaims{ 214 | Subject: name, 215 | }, 216 | Email: email, 217 | } 218 | claims.ExpiresAt = time.Now().Add(time.Hour).Unix() 219 | 220 | if isAdmin { 221 | claims.Groups = []string{"admin"} 222 | } 223 | 224 | claims.Groups = append(claims.Groups, groups...) 225 | 226 | 227 | 228 | tokenString, err :=
jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret)) 229 | if !assert.NoError(t, err) { 230 | assert.FailNow(t, "failed to create token") 231 | } 232 | return tokenString 233 | } 234 | 235 | /* decodeToken verifies jwtString with secret (HS256 only) and returns its map claims, failing the test on error. */ func decodeToken(t *testing.T, jwtString, secret string) jwt.MapClaims { 236 | claims := jwt.MapClaims{} 237 | _, err := jwt.ParseWithClaims(jwtString, &claims, func(token *jwt.Token) (interface{}, error) { 238 | if assert.Equal(t, token.Header["alg"], signingMethod.Name) { 239 | return []byte(secret), nil 240 | } 241 | return nil, nil 242 | }) 243 | if assert.NoError(t, err) { 244 | return claims 245 | } 246 | return nil 247 | } 248 | -------------------------------------------------------------------------------- /api/subscriptions.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "fmt" 8 | "strings" 9 | 10 | jwt "github.com/dgrijalva/jwt-go" 11 | "github.com/guregu/kami" 12 | "github.com/netlify/gojoin/models" 13 | "github.com/sirupsen/logrus" 14 | "gopkg.in/square/go-jose.v1/json" 15 | ) 16 | 17 | /* subscriptionRequest is the JSON body of PUT /subscriptions/:type. */ type subscriptionRequest struct { 18 | StripeKey string `json:"stripe_key"` 19 | Plan string `json:"plan"` 20 | } 21 | 22 | /* Valid reports the required fields that are missing, if any. */ func (s subscriptionRequest) Valid() error { 23 | missing := []string{} 24 | if s.StripeKey == "" { 25 | missing = append(missing, "stripe_key") 26 | } 27 | if s.Plan == "" { 28 | missing = append(missing, "plan") 29 | } 30 | 31 | if len(missing) > 0 { 32 | return fmt.Errorf("Missing fields: %s", strings.Join(missing, ",")) 33 | } 34 | 35 | return nil 36 | } 37 | 38 | /* getAllResponse is the body returned by listSubs. */ type getAllResponse struct { 39 | Subscriptions []models.Subscription `json:"subscriptions"` 40 | Token string `json:"token"` 41 | } 42 | 43 | // listSubs loads the caller's subscriptions from the database (not from Stripe) and returns them together 44 | // with a re-signed copy of the caller's token whose app_metadata.subscriptions maps each subscription type to its plan.
45 | /* listSubs reads the caller's subscriptions from the database and answers with them plus a re-signed copy of the caller's token whose app_metadata.subscriptions maps each subscription type to its plan. */ func listSubs(ctx context.Context, w http.ResponseWriter, r *http.Request) { 46 | log := getLogger(ctx) 47 | claims := getClaims(ctx) 48 | db := getDB(ctx) 49 | 50 | subs := []models.Subscription{} 51 | if rsp := db.Where("user_id = ? ", claims.Subject).Find(&subs); rsp.Error != nil { 52 | if rsp.RecordNotFound() { 53 | notFoundError(w, "Found no records associated with user id %s", claims.Subject) 54 | } else { 55 | log.WithError(rsp.Error).Warnf("Failed to find records associated with %s", claims.Subject) 56 | writeError(w, http.StatusInternalServerError, "DB error while searching for subscriptions") 57 | } 58 | return 59 | } 60 | 61 | log.Debugf("Found %d subscriptions associated with id %s", len(subs), claims.Subject) 62 | 63 | response := &getAllResponse{ 64 | Subscriptions: subs, 65 | } 66 | 67 | claimsMap := getClaimsAsMap(ctx) 68 | if claimsMap == nil { 69 | /* getClaimsAsMap returns nil when the config is missing or the token no longer verifies; writing into the nil map below would panic */ log.Warn("Failed to re-parse token claims") 70 | writeError(w, http.StatusInternalServerError, "Error while reading token claims") 71 | return 72 | } 73 | metadata, ok := claimsMap["app_metadata"].(map[string]interface{}) 74 | if !ok { /* missing, null or non-object app_metadata: start a fresh one instead of panicking on the type assertion */ metadata = map[string]interface{}{} } 75 | 76 | subsClaim := map[string]string{} 77 | metadata["subscriptions"] = subsClaim 78 | 79 | for _, sub := range subs { 80 | subsClaim[sub.Type] = sub.Plan 81 | } 82 | claimsMap["app_metadata"] = metadata 83 | 84 | // now we need to re-serialize the token 85 | config := getConfig(ctx) 86 | signed, err := jwt.NewWithClaims(signingMethod, claimsMap).SignedString([]byte(config.JWTSecret)) 87 | if err != nil { 88 | log.WithError(err).Warnf("Error while creating new signed token") 89 | writeError(w, http.StatusInternalServerError, "Error while creating new signed token") 90 | return 91 | } 92 | response.Token = signed 93 | sendJSON(w, http.StatusOK, response) 94 | } 95 | 96 | /* viewSub answers with the caller's subscription of the :type route parameter, or 404 when none exists. */ func viewSub(ctx context.Context, w http.ResponseWriter, r *http.Request) { 97 | subType := kami.Param(ctx, "type") 98 | claims := getClaims(ctx) 99 | sub, err := getSubscription(ctx, claims.Subject,
subType) 100 | if err != nil { 101 | sendJSON(w, err.Code, err) 102 | return 103 | } 104 | if sub == nil { 105 | writeError(w, http.StatusNotFound, "No subscription found") 106 | return 107 | } 108 | sendJSON(w, http.StatusOK, sub) 109 | } 110 | 111 | /* deleteSub cancels the caller's :type subscription through the payer proxy and then deletes its DB row; it answers 202 with {} whether or not a subscription existed. */ func deleteSub(ctx context.Context, w http.ResponseWriter, r *http.Request) { 112 | subType := kami.Param(ctx, "type") 113 | claims := getClaims(ctx) 114 | sub, err := getSubscription(ctx, claims.Subject, subType) 115 | if err != nil { 116 | sendJSON(w, err.Code, err) 117 | return 118 | } 119 | if sub != nil { 120 | log := getLogger(ctx).WithField("type", subType) 121 | pp := getPayerProxy(ctx) 122 | err := pp.delete(sub.RemoteID) 123 | if err != nil { 124 | log.WithError(err).Warn("Failed to remove subscription from stripe") 125 | writeError(w, http.StatusBadRequest, "Error communicating with stripe: %s", err) 126 | return 127 | } 128 | 129 | log.Info("Removed subscription from stripe") 130 | rsp := getDB(ctx).Delete(sub) 131 | if rsp.Error != nil { 132 | log.WithError(rsp.Error).Warnf("Error while deleting subscription %+v", sub) 133 | writeError(w, http.StatusInternalServerError, "Error while deleting subscription") 134 | return 135 | } 136 | 137 | log.Info("Removed subscription from db") 138 | } 139 | 140 | sendJSON(w, http.StatusAccepted, struct{}{}) 141 | } 142 | 143 | /* createOrModSub validates the JSON body, then either creates the caller's :type subscription or switches the existing one to the requested plan, answering with the resulting subscription. */ func createOrModSub(ctx context.Context, w http.ResponseWriter, r *http.Request) { 144 | payload, httpErr := extractValidPayload(r) 145 | if httpErr != nil { 146 | sendJSON(w, httpErr.Code, httpErr) 147 | return 148 | } 149 | 150 | subType := kami.Param(ctx, "type") 151 | log := getLogger(ctx).WithFields(logrus.Fields{ 152 | "plan": payload.Plan, 153 | "type": subType, 154 | }) 155 | ctx = setLogger(ctx, log) 156 | 157 | // do we have a subscription already?
158 | claims := getClaims(ctx) 159 | sub, httpErr := getSubscription(ctx, claims.Subject, subType) 160 | if httpErr != nil { 161 | sendJSON(w, httpErr.Code, httpErr) 162 | return 163 | } 164 | 165 | if sub == nil { 166 | log.Debug("Starting to create new subscription") 167 | sub, httpErr = createSub(ctx, subType, payload) 168 | } else { 169 | log.WithField("old_plan", sub.Plan).Debug("Starting to update subscription") 170 | httpErr = updateSub(ctx, sub, payload) 171 | } 172 | 173 | if httpErr != nil { 174 | sendJSON(w, httpErr.Code, httpErr) 175 | return 176 | } 177 | 178 | sendJSON(w, http.StatusOK, sub) 179 | } 180 | 181 | /* createSub makes sure a users row exists for the caller (creating the remote customer on first use), then creates the subscription remotely and stores it. */ func createSub(ctx context.Context, subType string, payload *subscriptionRequest) (*models.Subscription, *HTTPError) { 182 | log := getLogger(ctx) 183 | pp := getPayerProxy(ctx) 184 | claims := getClaims(ctx) 185 | db := getDB(ctx) 186 | 187 | // do we have a user? 188 | user := &models.User{ 189 | ID: claims.Subject, 190 | } 191 | if rsp := db.Where(user).Find(user); rsp.Error != nil { 192 | if rsp.RecordNotFound() { 193 | remoteID, err := pp.createCustomer(claims.Subject, claims.Email, payload.StripeKey) 194 | if err != nil { 195 | log.WithError(err).Warn("Failed to create new customer in stripe"); return nil, httpError(http.StatusInternalServerError, "Failed to create new customer in stripe") 196 | } 197 | user.RemoteID = remoteID 198 | user.Email = claims.Email 199 | 200 | if rsp := db.Save(user); rsp.Error != nil { 201 | log.WithError(rsp.Error).Warnf("Failed to save new user with remote ID %s", remoteID) 202 | return nil, httpError(http.StatusInternalServerError, "Failed to save customer to db: %s", remoteID) 203 | } 204 | log.Infof("Created new user with remote ID: %s", user.RemoteID) 205 | } else { 206 | log.WithError(rsp.Error).Warn("Failed to find user") 207 | return nil, httpError(http.StatusInternalServerError, "Failed to find the user specified") 208 | } 209 | } else { 210 | log.WithField("remote_id", user.RemoteID).Debug("Found existing user") 211 | } 212 | 213 | // create the subscription 214 |
subRemoteID, err := pp.create(user.RemoteID, payload.Plan, payload.StripeKey) 215 | if err != nil { 216 | log.WithError(err).Info("Failed to create sub in stripe") 217 | return nil, httpError(http.StatusBadRequest, "Failed create new subscription for plan %s", payload.Plan) 218 | } 219 | 220 | sub := &models.Subscription{ 221 | RemoteID: subRemoteID, 222 | UserID: user.ID, 223 | Plan: payload.Plan, 224 | Type: subType, 225 | } 226 | 227 | rsp := getDB(ctx).Create(sub) 228 | if rsp.Error != nil { 229 | log.WithError(rsp.Error).Warnf("Failed to create new subscription after successful stripe call: %+v", sub) 230 | return nil, httpError(http.StatusInternalServerError, "Error while creating db entry, but stripe call was successful") 231 | } 232 | 233 | return sub, nil 234 | } 235 | 236 | /* updateSub moves an existing subscription to payload.Plan remotely, then saves the returned remote ID and plan. */ func updateSub(ctx context.Context, existing *models.Subscription, payload *subscriptionRequest) *HTTPError { 237 | log := getLogger(ctx) 238 | pp := getPayerProxy(ctx) 239 | 240 | remoteID, err := pp.update(existing.RemoteID, payload.Plan, payload.StripeKey) 241 | if err != nil { 242 | log.WithError(err).Info("Failed to update sub in stripe") 243 | return httpError(http.StatusBadRequest, "Failed updating subscription %s to plan %s", existing.RemoteID, payload.Plan) 244 | } 245 | 246 | existing.RemoteID = remoteID 247 | existing.Plan = payload.Plan 248 | 249 | rsp := getDB(ctx).Save(existing) 250 | if rsp.Error != nil { 251 | log.WithError(rsp.Error).Warnf("Failed to update subscription after successful stripe call: %+v", existing) 252 | return httpError(http.StatusInternalServerError, "Error while creating db entry, but stripe call was successful") 253 | } 254 | 255 | return nil 256 | } 257 | 258 | /* getSubscription loads the subscription matching userID and planType; it returns (nil, nil) when none exists. */ func getSubscription(ctx context.Context, userID string, planType string) (*models.Subscription, *HTTPError) { 259 | log := getLogger(ctx).WithField("type", planType) 260 | db := getDB(ctx) 261 | sub := &models.Subscription{ 262 | Type: planType, 263 | UserID: userID, 264 | } 265 | 266
| if rsp := db.Where(sub).First(sub); rsp.Error != nil { 267 | 268 | if rsp.RecordNotFound() { 269 | log.Debug("Didn't find record") 270 | return nil, nil 271 | } 272 | forString := fmt.Sprintf("Error while searching for subscription user %s and type %s", userID, planType) 273 | log.WithError(rsp.Error).Warn(forString) 274 | return nil, httpError(http.StatusInternalServerError, "%s", forString) /* httpError formats its message like fmt.Sprintf (see its %s callers); userID and planType may contain '%' and must not be read as verbs */ 275 | } 276 | 277 | log.Debug("Successfully retrieved subscription") 278 | return sub, nil 279 | } 280 | 281 | /* extractValidPayload decodes the request body and checks the required fields, answering 400 on failure. Error text is passed as an argument, never as the format string. */ func extractValidPayload(r *http.Request) (*subscriptionRequest, *HTTPError) { 282 | payload := new(subscriptionRequest) 283 | defer r.Body.Close() 284 | if err := json.NewDecoder(r.Body).Decode(payload); err != nil { 285 | return nil, httpError(http.StatusBadRequest, "failed to decode payload: %s", err.Error()) 286 | } 287 | if err := payload.Valid(); err != nil { 288 | return nil, httpError(http.StatusBadRequest, "Failed to provide a valid request: %s", err.Error()) 289 | } 290 | return payload, nil 291 | } 292 | -------------------------------------------------------------------------------- /api/subscriptions_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "testing" 5 | 6 | "io/ioutil" 7 | 8 | "net/http" 9 | 10 | "github.com/netlify/gojoin/models" 11 | "github.com/pborman/uuid" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/valyala/fasthttp" 14 | ) 15 | 16 | func TestQueryForAllSubsAsUser(t *testing.T) { 17 | tu1 := createUser(testUserID, testUserEmail, "some-stripe-value") 18 | tu2 := createUser("batman", "bruce@dc.com", "eulav-epits-emos") 19 | s1 := createSubscription(testUserID, "membership", "nonsense") 20 | s2 := createSubscription(testUserID, "revenue", "more-nonsense") 21 | s3 := createSubscription("batman", "membership", "nonsense") 22 | defer cleanup(s1, s2, s3, tu1, tu2) 23 | 24 | rsp := request(t, "GET", "/subscriptions", nil, false) 25 | body := new(getAllResponse) 26 |
extractPayload(t, rsp, &body) 27 | 28 | assert.Equal(t, 2, len(body.Subscriptions)) 29 | for _, s := range body.Subscriptions { 30 | switch s.ID { 31 | case s1.ID: 32 | validateSub(t, s1, &s) 33 | case s2.ID: 34 | validateSub(t, s2, &s) 35 | default: 36 | assert.Fail(t, "unexpected sub: "+s.ID) 37 | } 38 | } 39 | 40 | assert.NotEmpty(t, body.Token) 41 | claims := decodeToken(t, body.Token, config.JWTSecret) 42 | if assert.NotNil(t, claims) { 43 | meta, ok := claims["app_metadata"].(map[string]interface{}) 44 | if !ok { 45 | assert.Fail(t, "No app_metadata in token") 46 | } 47 | subs, ok := meta["subscriptions"] 48 | if !ok { 49 | assert.Fail(t, "No subscriptions in metadata") 50 | } 51 | 52 | subsMap, ok := subs.(map[string]interface{}) 53 | if !ok { 54 | assert.Fail(t, "Subscriptions is not a map") 55 | } 56 | assert.Equal(t, "nonsense", subsMap["membership"]) 57 | assert.Equal(t, "more-nonsense", subsMap["revenue"]) 58 | } 59 | } 60 | 61 | func TestQueryForSingleSubAsUser(t *testing.T) { 62 | tu := createUser(testUserID, testUserEmail, "stripe-given-value") 63 | s1 := createSubscription(testUserID, "membership", "nonsense") 64 | s2 := createSubscription(testUserID, "revenue", "more-nonsense") 65 | defer cleanup(s1, s2, tu) 66 | 67 | rsp := request(t, "GET", "/subscriptions/membership", nil, false) 68 | sub := new(models.Subscription) 69 | extractPayload(t, rsp, sub) 70 | 71 | validateSub(t, s1, sub) 72 | } 73 | 74 | func TestRemoveSubscriptionAsUser(t *testing.T) { 75 | tp := &testProxy{} 76 | api.payerProxy = tp 77 | tu := createUser(testUserID, testUserEmail, "stripe-given-value") 78 | s1 := createSubscription(testUserID, "membership", "nonsense") 79 | s2 := createSubscription(testUserID, "revenue", "more-nonsense") 80 | defer cleanup(s1, s2, tu) 81 | defer func() { api.payerProxy = &errorProxy{} }() 82 | rsp := request(t, "DELETE", "/subscriptions/membership", nil, false) 83 | 84 | b, _ := ioutil.ReadAll(rsp.Body) 85 | assert.Equal(t, "{}\n", string(b)) 86 | assert.Equal(t, 202, rsp.StatusCode) 87 | 88 |
found := &models.Subscription{ID: s1.ID} 89 | dbRsp := db.Unscoped().Find(found) 90 | if assert.Nil(t, dbRsp.Error) { 91 | assert.NotNil(t, found.DeletedAt) 92 | } 93 | 94 | // validate it was removed in stripe 95 | assert.Len(t, tp.deleteCalls, 1) 96 | assert.Equal(t, s1.RemoteID, tp.deleteCalls[0]) 97 | } 98 | 99 | func TestRemoveSubNotFound(t *testing.T) { 100 | rsp := request(t, "DELETE", "/subscriptions/membership", nil, false) 101 | 102 | b, _ := ioutil.ReadAll(rsp.Body) 103 | assert.Equal(t, "{}\n", string(b)) 104 | assert.Equal(t, 202, rsp.StatusCode) 105 | } 106 | 107 | func TestGetSubNotFound(t *testing.T) { 108 | rsp := request(t, "GET", "/subscriptions/membership", nil, false) 109 | extractError(t, 404, rsp) 110 | } 111 | 112 | func TestCreateNewSubscription(t *testing.T) { 113 | tp := &testProxy{createSubID: "remote-id", createCustomerID: "remote-user-id"} 114 | api.payerProxy = tp 115 | defer func() { api.payerProxy = &errorProxy{} }() 116 | 117 | payload := &subscriptionRequest{ 118 | StripeKey: "something", 119 | Plan: "super-important", 120 | } 121 | rsp := request(t, "PUT", "/subscriptions/membership", payload, false) 122 | 123 | expectedSub := models.Subscription{ 124 | Type: "membership", 125 | UserID: testUserID, 126 | Plan: "super-important", 127 | RemoteID: "remote-id", 128 | } 129 | expectedUser := models.User{ 130 | ID: testUserID, 131 | Email: testUserEmail, 132 | RemoteID: "remote-user-id", 133 | } 134 | 135 | dbRsp, dbUser := validateResponseAndDBVal(t, rsp, &expectedSub, &expectedUser) 136 | cleanup(dbRsp, dbUser) 137 | 138 | assert.Len(t, tp.createCalls, 1) 139 | call := tp.createCalls[0] 140 | assert.Equal(t, "super-important", call.plan) 141 | assert.Equal(t, "something", call.token) 142 | assert.Equal(t, "remote-user-id", call.userID) 143 | assert.Empty(t, tp.updateCalls) 144 | } 145 | 146 | func TestModifySubscription(t *testing.T) { 147 | tp := &testProxy{updateSubID: "remote-id"} 148 | api.payerProxy = tp 149 | defer func() {
api.payerProxy = &errorProxy{} }() 150 | 151 | tu := createUser(testUserID, testUserEmail, "stripe-given-value") 152 | s1 := createSubscription(testUserID, "pokemon", "magicarp") 153 | defer cleanup(s1, tu) 154 | 155 | payload := &subscriptionRequest{ 156 | StripeKey: "something", 157 | Plan: "charizard", 158 | } 159 | rsp := request(t, "PUT", "/subscriptions/pokemon", payload, false) 160 | expectedSub := &models.Subscription{ 161 | Type: "pokemon", 162 | UserID: testUserID, 163 | Plan: "charizard", 164 | RemoteID: "remote-id", 165 | } 166 | expectedUser := &models.User{ 167 | ID: testUserID, 168 | Email: testUserEmail, 169 | RemoteID: "stripe-given-value", 170 | } 171 | dbRsp, dbUser := validateResponseAndDBVal(t, rsp, expectedSub, expectedUser) 172 | cleanup(dbRsp, dbUser) 173 | 174 | assert.Len(t, tp.updateCalls, 1) 175 | call := tp.updateCalls[0] 176 | assert.Equal(t, "charizard", call.plan) 177 | assert.Equal(t, "something", call.token) 178 | assert.Equal(t, s1.RemoteID, call.subID) 179 | assert.Empty(t, tp.createCalls) 180 | 181 | assert.Len(t, tp.createCustomerCalls, 0) 182 | } 183 | 184 | func TestCreateNewSubscriptionWithBadPayload(t *testing.T) { 185 | payload := &subscriptionRequest{ 186 | StripeKey: "something", 187 | Plan: "", 188 | } 189 | rsp := request(t, "PUT", "/subscriptions/membership", payload, false) 190 | extractError(t, fasthttp.StatusBadRequest, rsp) 191 | 192 | payload.StripeKey = "" 193 | payload.Plan = "something" 194 | rsp = request(t, "PUT", "/subscriptions/membership", payload, false) 195 | extractError(t, fasthttp.StatusBadRequest, rsp) 196 | } 197 | 198 | func TestCreateNewSubscriptionWithStripeError(t *testing.T) { 199 | defer cleanup(createUser(testUserID, testUserEmail, "remote-id")) 200 | api.payerProxy = errorProxy{} 201 | payload := &subscriptionRequest{ 202 | StripeKey: "something", 203 | Plan: "unicorn", 204 | } 205 | rsp := request(t, "PUT", "/subscriptions/membership", payload, false) 206 | extractError(t, 
fasthttp.StatusBadRequest, rsp) 207 | } 208 | 209 | // ------------------------------------------------------------------------------------------------ 210 | // helpers 211 | // ------------------------------------------------------------------------------------------------ 212 | 213 | func validateSub(t *testing.T, expected *models.Subscription, actual *models.Subscription) { 214 | assert.Equal(t, expected.ID, actual.ID) 215 | assert.Equal(t, expected.UserID, actual.UserID) 216 | assert.Equal(t, expected.Type, actual.Type) 217 | assert.Equal(t, expected.RemoteID, actual.RemoteID) 218 | assert.Equal(t, expected.Plan, actual.Plan) 219 | 220 | assert.NotEmpty(t, actual.ID) 221 | assert.NotEmpty(t, actual.UserID) 222 | assert.NotEmpty(t, actual.Type) 223 | assert.NotEmpty(t, actual.RemoteID) 224 | assert.NotEmpty(t, actual.Plan) 225 | } 226 | 227 | func createUser(userID, email, remoteID string) *models.User { 228 | user := &models.User{ 229 | Email: email, 230 | ID: userID, 231 | RemoteID: remoteID, 232 | } 233 | db.Create(user) 234 | return user 235 | } 236 | 237 | func createSubscription(userID, planType string, plan string) *models.Subscription { 238 | sub := &models.Subscription{ 239 | UserID: userID, 240 | Plan: plan, 241 | Type: planType, 242 | RemoteID: uuid.NewRandom().String(), 243 | } 244 | 245 | db.Create(sub) 246 | return sub 247 | } 248 | 249 | func cleanup(todelete ...interface{}) { 250 | for _, td := range todelete { 251 | db.Unscoped().Delete(td) 252 | } 253 | } 254 | 255 | type testProxy struct { 256 | createSubID string 257 | createCalls []struct { 258 | userID string 259 | plan string 260 | token string 261 | } 262 | updateSubID string 263 | updateCalls []struct { 264 | subID string 265 | plan string 266 | token string 267 | } 268 | deleteCalls []string 269 | 270 | createCustomerID string 271 | createCustomerCalls []struct { 272 | userID string 273 | email string 274 | token string 275 | } 276 | } 277 | 278 | func (tp *testProxy) 
createCustomer(userID, email, payToken string) (string, error) { 279 | tp.createCustomerCalls = append(tp.createCustomerCalls, struct { 280 | userID string 281 | email string 282 | token string 283 | }{userID, email, payToken}) 284 | return tp.createCustomerID, nil 285 | } 286 | 287 | func (tp *testProxy) delete(subID string) error { 288 | tp.deleteCalls = append(tp.deleteCalls, subID) 289 | return nil 290 | } 291 | 292 | func (tp *testProxy) create(userID, plan, token string) (string, error) { 293 | tp.createCalls = append(tp.createCalls, struct { 294 | userID string 295 | plan string 296 | token string 297 | }{userID, plan, token}) 298 | return tp.createSubID, nil 299 | } 300 | 301 | func (tp *testProxy) update(subID, plan, token string) (string, error) { 302 | tp.updateCalls = append(tp.updateCalls, struct { 303 | subID string 304 | plan string 305 | token string 306 | }{subID, plan, token}) 307 | return tp.updateSubID, nil 308 | } 309 | 310 | func validateResponseAndDBVal(t *testing.T, rsp *http.Response, expected *models.Subscription, expectedUser *models.User) (*models.Subscription, *models.User) { 311 | var dbSub *models.Subscription 312 | var dbUser *models.User 313 | 314 | if assert.Equal(t, http.StatusOK, rsp.StatusCode) { 315 | rspSub := new(models.Subscription) 316 | extractPayload(t, rsp, rspSub) 317 | if rspSub.ID == "" { 318 | assert.FailNow(t, "Failed to get a valid subscription") 319 | } 320 | 321 | expected.ID = rspSub.ID 322 | validateSub(t, expected, rspSub) 323 | 324 | dbSub = &models.Subscription{ 325 | ID: rspSub.ID, 326 | } 327 | dbRsp := db.Where(dbSub).First(dbSub) 328 | if assert.NoError(t, dbRsp.Error) { 329 | validateSub(t, expected, dbSub) 330 | } 331 | 332 | dbUser = new(models.User) 333 | userRsp := db.Where("id = ?", expectedUser.ID).Find(dbUser) 334 | if assert.NoError(t, userRsp.Error) { 335 | assert.Equal(t, expectedUser.Email, dbUser.Email) 336 | assert.Equal(t, expectedUser.RemoteID, dbUser.RemoteID) 337 | } 338 | } 339 | 340 | 
return dbSub, dbUser 341 | } 342 | --------------------------------------------------------------------------------