├── README.md
├── cmd
│   └── app
│       ├── app.go
│       └── main.go
└── pkg
    ├── config.go
    ├── config
    │   └── config.go
    ├── credentials.go
    ├── crypto.go
    ├── crypto
    │   ├── crypto.go
    │   └── crypto_test.go
    ├── mock
    │   └── user_service_mock.go
    ├── mongo
    │   ├── mongo_test.go
    │   ├── session.go
    │   ├── user_model.go
    │   └── user_service.go
    ├── server
    │   ├── auth.go
    │   ├── response.go
    │   ├── server.go
    │   ├── user_router.go
    │   └── user_router_test.go
    └── user.go

/README.md:
--------------------------------------------------------------------------------
# go_rest_api - https://hackernoon.com/make-yourself-a-go-web-server-with-mongodb-go-on-go-on-go-on-48f394f24e
A Go web server built with MongoDB, the Gorilla toolkit, and JWT authentication.

This project follows the package layout recommended by Ben Johnson in his Medium article "Standard Package Layout":
https://medium.com/@benbjohnson/standard-package-layout-7cdbc8391fc1

## Structure
~~~
go_rest_api
  /cmd
    /app
      -app.go
      -main.go
  /pkg
    /mongo
      -session.go
      -user_service.go
      -user_model.go
    /server
      -server.go
      -user_router.go
      -response.go
      -auth.go
    /mock
      -user_service_mock.go
    /config
      -config.go
    /crypto
      -crypto.go
    -user.go
    -credentials.go
    -config.go
    -crypto.go
~~~

## Dependencies
- go get gopkg.in/mgo.v2
- go get github.com/gorilla/mux
- go get github.com/gorilla/handlers
- go get golang.org/x/crypto/bcrypt
- go get github.com/google/uuid
- go get github.com/dgrijalva/jwt-go

## References
- https://medium.com/@benbjohnson/standard-package-layout-7cdbc8391fc1#.sdcvblyts
- https://semaphoreci.com/community/tutorials/building-and-testing-a-rest-api-in-go-with-gorilla-mux-and-postgresql
- https://dinosaurscode.xyz/go/2016/06/17/golang-jwt-authentication/
- https://medium.com/@matryer/context-keys-in-go-5312346a868d#.hb4spbx1a

--------------------------------------------------------------------------------
/cmd/app/app.go:
--------------------------------------------------------------------------------
package main 2 | 3 | import ( 4 | "fmt" 5 | "go_rest_api/pkg" 6 | "go_rest_api/pkg/config" 7 | "go_rest_api/pkg/mongo" 8 | "go_rest_api/pkg/server" 9 | "log" 10 | ) 11 | 12 | type App struct { 13 | server *server.Server 14 | session *mongo.Session 15 | config *root.Config 16 | } 17 | 18 | func (a *App) Initialize() { 19 | a.config = config.GetConfig() 20 | var err error 21 | a.session, err = mongo.NewSession(a.config.Mongo) 22 | if err != nil { 23 | log.Fatalln("unable to connect to mongodb") 24 | } 25 | 26 | u := mongo.NewUserService(a.session.Copy(), a.config.Mongo) 27 | a.server = server.NewServer(u, a.config) 28 | } 29 | 30 | func (a *App) Run() { 31 | fmt.Println("Run") 32 | defer a.session.Close() 33 | a.server.Start() 34 | } 35 | -------------------------------------------------------------------------------- /cmd/app/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | a := App{} 5 | a.Initialize() 6 | a.Run() 7 | } 8 | -------------------------------------------------------------------------------- /pkg/config.go: -------------------------------------------------------------------------------- 1 | package root 2 | 3 | type MongoConfig struct { 4 | Ip string `json:"ip"` 5 | DbName string `json:"dbName"` 6 | } 7 | 8 | type ServerConfig struct { 9 | Port string `json:"port"` 10 | } 11 | 12 | type AuthConfig struct { 13 | Secret string `json:"secret"` 14 | } 15 | 16 | type Config struct { 17 | Mongo *MongoConfig `json:"mongo"` 18 | Server *ServerConfig `json:"server"` 19 | Auth *AuthConfig `json:"auth"` 20 | } -------------------------------------------------------------------------------- /pkg/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "go_rest_api/pkg" 6 | ) 7 | 8 | func GetConfig() *root.Config { 9 | return &root.Config { 10 | Mongo: &root.MongoConfig { 11 | Ip: 
envOrDefaultString("go_rest_api:mongo:ip", "127.0.0.1:27017"), 12 | DbName: envOrDefaultString("go_rest_api:mongo:dbName", "myDb")}, 13 | Server: &root.ServerConfig { Port: envOrDefaultString("go_rest_api:server:port", ":1377")}, 14 | Auth: &root.AuthConfig { Secret: envOrDefaultString("go_rest_api:auth:secret", "mysecret")}} 15 | } 16 | 17 | func envOrDefaultString(envVar string, defaultValue string) string { 18 | value := os.Getenv(envVar) 19 | if value == "" { 20 | return defaultValue; 21 | } 22 | 23 | return value 24 | } -------------------------------------------------------------------------------- /pkg/credentials.go: -------------------------------------------------------------------------------- 1 | package root 2 | 3 | type Credentials struct { 4 | Username string `json:"username"` 5 | Password string `json:"password"` 6 | } -------------------------------------------------------------------------------- /pkg/crypto.go: -------------------------------------------------------------------------------- 1 | package root 2 | 3 | type Crypto interface { 4 | Salt(s string) (error, string) 5 | Compare(hash string, s string) (error, bool) 6 | } -------------------------------------------------------------------------------- /pkg/crypto/crypto.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "errors" 5 | "github.com/google/uuid" 6 | "golang.org/x/crypto/bcrypt" 7 | "strings" 8 | ) 9 | 10 | //Hash implements root.Hash 11 | type Crypto struct{} 12 | 13 | var deliminator = "||" 14 | 15 | //Generate a salted hash for the input string 16 | func (c *Crypto) Generate(s string) (string, error) { 17 | salt := uuid.New().String() 18 | saltedBytes := []byte(s + salt) 19 | hashedBytes, err := bcrypt.GenerateFromPassword(saltedBytes, bcrypt.DefaultCost) 20 | if err != nil { 21 | return "", err 22 | } 23 | 24 | hash := string(hashedBytes[:]) 25 | return hash + deliminator + salt, nil 26 | } 27 | 28 | 
//Compare string to generated hash 29 | func (c *Crypto) Compare(hash string, s string) error { 30 | parts := strings.Split(hash, deliminator) 31 | if len(parts) != 2 { 32 | return errors.New("Invalid hash, must have 2 parts") 33 | } 34 | 35 | incoming := []byte(s + parts[1]) 36 | existing := []byte(parts[0]) 37 | return bcrypt.CompareHashAndPassword(existing, incoming) 38 | } 39 | -------------------------------------------------------------------------------- /pkg/crypto/crypto_test.go: -------------------------------------------------------------------------------- 1 | package crypto_test 2 | 3 | import ( 4 | "go_web_server/pkg/crypto" 5 | "testing" 6 | ) 7 | 8 | func Test_Hash(t *testing.T) { 9 | t.Run("Can hash and compare", should_be_able_to_hash_and_compare_strings) 10 | t.Run("Can detect unequal hashes", should_return_error_when_comparing_unequal_hashes) 11 | t.Run("Generates a different salt every time", should_generate_a_different_salt_each_time) 12 | } 13 | 14 | func should_be_able_to_hash_and_compare_strings(t *testing.T) { 15 | //Arrange 16 | c := crypto.Hash{} 17 | testInput := "testInput" 18 | 19 | //Act 20 | generatedHash, generateError := c.Generate(testInput) 21 | compareError := c.Compare(generatedHash, testInput) 22 | 23 | //Assert 24 | if generateError != nil { 25 | t.Error("Error generating hash") 26 | } 27 | if testInput == generatedHash { 28 | t.Error("Generated hash is the same as input") 29 | } 30 | if compareError != nil { 31 | t.Error("Error comparing hash to input") 32 | } 33 | } 34 | 35 | func should_return_error_when_comparing_unequal_hashes(t *testing.T) { 36 | //Arrange 37 | c := crypto.Hash{} 38 | testInput := "testInput" 39 | testCompare := "testCompare" 40 | 41 | //Act 42 | generatedHash, generateError := c.Generate(testInput) 43 | compareError := c.Compare(generatedHash, testCompare) 44 | 45 | //Assert 46 | if generateError != nil { 47 | t.Error("Error generating hash") 48 | } 49 | if testInput == generatedHash { 50 | 
t.Error("Generated hash is the same as input") 51 | } 52 | if compareError == nil { 53 | t.Error("Compare should not have been successful") 54 | } 55 | } 56 | 57 | func should_generate_a_different_salt_each_time(t *testing.T) { 58 | //Arrange 59 | c := crypto.Hash{} 60 | testInput := "testInput" 61 | 62 | hash1, _ := c.Generate(testInput) 63 | hash2, _ := c.Generate(testInput) 64 | 65 | if hash1 == hash2 { 66 | t.Error("Subsequent hashes should not be equal") 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /pkg/mock/user_service_mock.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import "go_rest_api/pkg" 4 | 5 | type UserService struct { 6 | CreateUserFn func(u *root.User) error 7 | CreateUserInvoked bool 8 | 9 | GetUserByUsernameFn func(username string) (error, root.User) 10 | GetUserByUsernameInvoked bool 11 | 12 | LoginFn func(c root.Credentials) (error, root.User) 13 | LoginInvoked bool 14 | } 15 | 16 | func(us *UserService) CreateUser(u *root.User) error { 17 | us.CreateUserInvoked = true 18 | return us.CreateUserFn(u) 19 | } 20 | 21 | func(us *UserService) GetUserByUsername(username string) (error, root.User) { 22 | us.GetUserByUsernameInvoked = true 23 | return us.GetUserByUsernameFn(username) 24 | } 25 | 26 | func(us *UserService) Login(c root.Credentials) (error, root.User) { 27 | us.LoginInvoked = true 28 | return us.LoginFn(c) 29 | } -------------------------------------------------------------------------------- /pkg/mongo/mongo_test.go: -------------------------------------------------------------------------------- 1 | package mongo_test 2 | 3 | import ( 4 | "log" 5 | "testing" 6 | "go_rest_api/pkg" 7 | "go_rest_api/pkg/mongo" 8 | ) 9 | 10 | 11 | const ( 12 | mongoUrl = "localhost:27017" 13 | dbName = "test_db" 14 | userCollectionName = "user" 15 | ) 16 | 17 | func Test_UserService(t *testing.T) { 18 | t.Run("CreateUser", 
createUser_should_insert_user_into_mongo) 19 | } 20 | 21 | func createUser_should_insert_user_into_mongo(t *testing.T) { 22 | //Arrange 23 | mongoConfig := root.MongoConfig { 24 | Ip: "127.0.0.1:27017", 25 | DbName: "myDb" } 26 | session, err := mongo.NewSession(&mongoConfig) 27 | if(err != nil) { 28 | log.Fatalf("Unable to connect to mongo: %s", err) 29 | } 30 | defer func() { 31 | session.DropDatabase(mongoConfig.DbName) 32 | session.Close() 33 | }() 34 | 35 | userService := mongo.NewUserService(session.Copy(), &mongoConfig) 36 | 37 | testUsername := "integration_test_user" 38 | testPassword := "integration_test_password" 39 | user := root.User{ 40 | Username: testUsername, 41 | Password: testPassword } 42 | 43 | //Act 44 | err = userService.CreateUser(&user) 45 | 46 | //Assert 47 | if(err != nil) { 48 | t.Error("Unable to create user: %s", err) 49 | } 50 | 51 | _, resultUser := userService.GetUserByUsername(testUsername) 52 | 53 | if(resultUser.Username != user.Username) { 54 | t.Error("Incorrect Username. 
Expected `%s`, Got: `%s`", testUsername, resultUser.Username) 55 | } 56 | } -------------------------------------------------------------------------------- /pkg/mongo/session.go: -------------------------------------------------------------------------------- 1 | package mongo 2 | 3 | import ( 4 | "go_rest_api/pkg" 5 | "gopkg.in/mgo.v2" 6 | ) 7 | 8 | type Session struct { 9 | session *mgo.Session 10 | } 11 | 12 | func NewSession(config *root.MongoConfig) (*Session,error) { 13 | //var err error 14 | session, err := mgo.Dial(config.Ip) 15 | if err != nil { 16 | return nil,err 17 | } 18 | session.SetMode(mgo.Monotonic, true) 19 | return &Session{session}, err 20 | } 21 | 22 | func(s *Session) Copy() *mgo.Session { 23 | return s.session.Copy() 24 | } 25 | 26 | func(s *Session) Close() { 27 | if(s.session != nil) { 28 | s.session.Close() 29 | } 30 | } 31 | 32 | func(s *Session) DropDatabase(db string) error { 33 | if(s.session != nil) { 34 | return s.session.DB(db).DropDatabase() 35 | } 36 | return nil 37 | } 38 | -------------------------------------------------------------------------------- /pkg/mongo/user_model.go: -------------------------------------------------------------------------------- 1 | package mongo 2 | 3 | import ( 4 | "go_rest_api/pkg" 5 | "gopkg.in/mgo.v2/bson" 6 | "gopkg.in/mgo.v2" 7 | "golang.org/x/crypto/bcrypt" 8 | "github.com/google/uuid" 9 | ) 10 | 11 | type userModel struct { 12 | Id bson.ObjectId `bson:"_id,omitempty"` 13 | Username string 14 | PasswordHash string 15 | Salt string 16 | } 17 | 18 | func userModelIndex() mgo.Index { 19 | return mgo.Index{ 20 | Key: []string{"username"}, 21 | Unique: true, 22 | DropDups: true, 23 | Background: true, 24 | Sparse: true, 25 | } 26 | } 27 | 28 | func newUserModel(u *root.User) (*userModel,error) { 29 | user := userModel{Username: u.Username} 30 | err := user.setSaltedPassword(u.Password) 31 | return &user, err 32 | } 33 | 34 | func(u *userModel) comparePassword(password string) error { 35 | 
incoming := []byte(password+u.Salt) 36 | existing := []byte(u.PasswordHash) 37 | err := bcrypt.CompareHashAndPassword(existing, incoming) 38 | return err 39 | } 40 | 41 | func(u *userModel) setSaltedPassword(password string) error { 42 | salt := uuid.New().String() 43 | passwordBytes := []byte(password + salt) 44 | hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) 45 | if err != nil { 46 | return err 47 | } 48 | 49 | u.PasswordHash = string(hash[:]) 50 | u.Salt = salt 51 | 52 | return nil 53 | } -------------------------------------------------------------------------------- /pkg/mongo/user_service.go: -------------------------------------------------------------------------------- 1 | package mongo 2 | 3 | import ( 4 | "gopkg.in/mgo.v2/bson" 5 | "gopkg.in/mgo.v2" 6 | "go_rest_api/pkg" 7 | ) 8 | 9 | type UserService struct { 10 | collection *mgo.Collection 11 | } 12 | 13 | func NewUserService(session *mgo.Session, config *root.MongoConfig) *UserService { 14 | collection := session.DB(config.DbName).C("user") 15 | collection.EnsureIndex(userModelIndex()) 16 | return &UserService {collection} 17 | } 18 | 19 | func(p *UserService) CreateUser(u *root.User) error { 20 | user, err := newUserModel(u) 21 | if err != nil { 22 | return err 23 | } 24 | return p.collection.Insert(&user) 25 | } 26 | 27 | func (p *UserService) GetUserByUsername(username string) (error, root.User) { 28 | model := userModel{} 29 | err := p.collection.Find(bson.M{"username": username}).One(&model) 30 | return err, root.User{ 31 | Id: model.Id.Hex(), 32 | Username: model.Username, 33 | Password: "-" } 34 | } 35 | 36 | func (p *UserService) Login(c root.Credentials) (error, root.User) { 37 | model := userModel{} 38 | err := p.collection.Find(bson.M{"username": c.Username}).One(&model) 39 | 40 | err = model.comparePassword(c.Password) 41 | if(err != nil) { 42 | return err, root.User{} 43 | } 44 | 45 | return err, root.User{ 46 | Id: model.Id.Hex(), 47 | Username: 
model.Username, 48 | Password: "-" } 49 | } 50 | 51 | -------------------------------------------------------------------------------- /pkg/server/auth.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | "net/http" 7 | "context" 8 | "github.com/dgrijalva/jwt-go" 9 | "go_rest_api/pkg" 10 | ) 11 | 12 | type authHelper struct { 13 | secret string 14 | } 15 | 16 | type claims struct { 17 | Username string `json:"username"` 18 | jwt.StandardClaims 19 | } 20 | 21 | type contextKey string 22 | func (c contextKey) String() string { 23 | return "mypackage context key " + string(c) 24 | } 25 | var ( 26 | contextKeyAuthtoken = contextKey("auth-token") 27 | ) 28 | 29 | func(a *authHelper) newCookie(user root.User) http.Cookie { 30 | expireTime := time.Now().Add(time.Hour * 1) 31 | c := claims { 32 | user.Username, 33 | jwt.StandardClaims { 34 | ExpiresAt: expireTime.Unix(), 35 | Issuer: "localhost!", 36 | }} 37 | 38 | token, _ := jwt.NewWithClaims(jwt.SigningMethodHS256,c).SignedString([]byte(a.secret)) 39 | 40 | cookie := http.Cookie { 41 | Name: "Auth", 42 | Value: token, 43 | Expires: expireTime, 44 | HttpOnly: true } 45 | return cookie 46 | } 47 | 48 | func(a *authHelper) validate(next http.HandlerFunc) http.HandlerFunc { 49 | return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 50 | cookie, err := req.Cookie("Auth") 51 | if err != nil { 52 | Error(res, http.StatusUnauthorized, "No authorization cookie") 53 | return 54 | } 55 | 56 | token, err := jwt.ParseWithClaims(cookie.Value, &claims{}, func(token *jwt.Token) (interface{}, error){ 57 | if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { 58 | return nil, fmt.Errorf("Unexpected siging method") 59 | } 60 | return []byte(a.secret), nil 61 | }) 62 | 63 | if err != nil { 64 | Error(res, http.StatusUnauthorized, "Invalid token") 65 | return 66 | } 67 | 68 | if claims, ok := token.Claims.(*claims); ok && 
token.Valid { 69 | ctx := context.WithValue(req.Context(), contextKeyAuthtoken, *claims) 70 | next(res, req.WithContext(ctx)) 71 | } else { 72 | Error(res, http.StatusUnauthorized, "Unauthorized") 73 | return 74 | } 75 | }) 76 | } -------------------------------------------------------------------------------- /pkg/server/response.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net/http" 5 | "encoding/json" 6 | ) 7 | 8 | func Error(w http.ResponseWriter, code int, message string) { 9 | Json(w, code, map[string]string{"error": message}) 10 | } 11 | 12 | func Json(w http.ResponseWriter, code int, payload interface{}) { 13 | response, _ := json.Marshal(payload) 14 | 15 | w.Header().Set("Content-Type", "application/json") 16 | w.WriteHeader(code) 17 | w.Write(response) 18 | } 19 | 20 | func JsonWithCookie(w http.ResponseWriter, code int, payload interface{}, cookie http.Cookie) { 21 | response, _ := json.Marshal(payload) 22 | http.SetCookie(w, &cookie) 23 | 24 | w.Header().Set("Content-Type", "application/json") 25 | w.WriteHeader(code) 26 | w.Write(response) 27 | } -------------------------------------------------------------------------------- /pkg/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "os" 5 | "log" 6 | "net/http" 7 | "github.com/gorilla/mux" 8 | "github.com/gorilla/handlers" 9 | "go_rest_api/pkg" 10 | ) 11 | 12 | type Server struct { 13 | router *mux.Router 14 | config *root.ServerConfig 15 | } 16 | 17 | func NewServer(u root.UserService, config *root.Config) *Server { 18 | s := Server { 19 | router: mux.NewRouter(), 20 | config: config.Server } 21 | 22 | a := authHelper{config.Auth.Secret} 23 | NewUserRouter(u, s.getSubrouter("/user"), &a) 24 | return &s 25 | } 26 | 27 | func(s *Server) Start() { 28 | log.Println("Listening on port " + s.config.Port) 29 | if err := 
http.ListenAndServe(s.config.Port, handlers.LoggingHandler(os.Stdout, s.router)); err != nil { 30 | log.Fatal("http.ListenAndServe: ", err) 31 | } 32 | } 33 | 34 | func(s *Server) getSubrouter(path string) *mux.Router { 35 | return s.router.PathPrefix(path).Subrouter() 36 | } -------------------------------------------------------------------------------- /pkg/server/user_router.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "encoding/json" 7 | "github.com/gorilla/mux" 8 | "go_rest_api/pkg" 9 | ) 10 | 11 | type userRouter struct { 12 | userService root.UserService 13 | auth *authHelper 14 | } 15 | 16 | func NewUserRouter(u root.UserService, router *mux.Router, a *authHelper) *mux.Router { 17 | userRouter := userRouter{u,a} 18 | router.HandleFunc("/", userRouter.createUserHandler).Methods("PUT") 19 | router.HandleFunc("/profile", a.validate(userRouter.profileHandler)).Methods("GET") 20 | router.HandleFunc("/{username}", userRouter.getUserHandler).Methods("GET") 21 | router.HandleFunc("/login", userRouter.loginHandler).Methods("POST") 22 | return router 23 | } 24 | 25 | func(ur* userRouter) createUserHandler(w http.ResponseWriter, r *http.Request) { 26 | err, user := decodeUser(r) 27 | if err != nil { 28 | Error(w, http.StatusBadRequest, "Invalid request payload") 29 | return 30 | } 31 | 32 | err = ur.userService.CreateUser(&user) 33 | if err != nil { 34 | Error(w, http.StatusInternalServerError, err.Error()) 35 | return 36 | } 37 | 38 | Json(w, http.StatusOK, err) 39 | } 40 | 41 | func(ur* userRouter) profileHandler(w http.ResponseWriter, r *http.Request) { 42 | claim, ok := r.Context().Value(contextKeyAuthtoken).(claims) 43 | if !ok { 44 | Error(w, http.StatusBadRequest, "no context") 45 | return 46 | } 47 | username := claim.Username 48 | 49 | err, user := ur.userService.GetUserByUsername(username) 50 | if err != nil { 51 | Error(w, http.StatusNotFound, 
err.Error()) 52 | return 53 | } 54 | 55 | Json(w, http.StatusOK, user) 56 | } 57 | 58 | func(ur *userRouter) getUserHandler(w http.ResponseWriter, r *http.Request) { 59 | vars := mux.Vars(r) 60 | username := vars["username"] 61 | 62 | err, user := ur.userService.GetUserByUsername(username) 63 | if err != nil { 64 | Error(w, http.StatusNotFound, err.Error()) 65 | return 66 | } 67 | 68 | Json(w, http.StatusOK, user) 69 | } 70 | 71 | func(ur* userRouter) loginHandler(w http.ResponseWriter, r *http.Request) { 72 | err, credentials := decodeCredentials(r) 73 | if err != nil { 74 | Error(w, http.StatusBadRequest, "Invalid request payload") 75 | return 76 | } 77 | 78 | var user root.User 79 | err, user = ur.userService.Login(credentials) 80 | if err == nil { 81 | cookie := ur.auth.newCookie(user) 82 | JsonWithCookie(w, http.StatusOK, user, cookie) 83 | } else { 84 | Error(w, http.StatusInternalServerError, "Incorrect password") 85 | } 86 | } 87 | 88 | func decodeUser(r *http.Request) (error,root.User) { 89 | var u root.User 90 | if r.Body == nil { 91 | return errors.New("no request body"), u 92 | } 93 | decoder := json.NewDecoder(r.Body) 94 | err := decoder.Decode(&u) 95 | return err, u 96 | } 97 | 98 | func decodeCredentials(r *http.Request) (error,root.Credentials) { 99 | var c root.Credentials 100 | if r.Body == nil { 101 | return errors.New("no request body"), c 102 | } 103 | decoder := json.NewDecoder(r.Body) 104 | err := decoder.Decode(&c) 105 | return err, c 106 | } -------------------------------------------------------------------------------- /pkg/server/user_router_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "bytes" 6 | "errors" 7 | "context" 8 | "testing" 9 | "net/http" 10 | "net/http/httptest" 11 | "encoding/json" 12 | "go_rest_api/pkg" 13 | "go_rest_api/pkg/mock" 14 | "github.com/gorilla/mux" 15 | "github.com/dgrijalva/jwt-go" 16 | ) 17 | 18 | //createUserHandler 
tests 19 | func Test_UserRouter_createUserHandler(t *testing.T) { 20 | t.Run("happy path", createUserHandler_should_pass_User_object_to_UserService_CreateUser) 21 | t.Run("invalid payload", createUserHandler_should_return_StatusBadRequest_if_payload_is_invalid) 22 | t.Run("internal error", createUserHandler_should_return_StatusInternalServerError_if_UserService_returns_error) 23 | } 24 | 25 | func createUserHandler_should_pass_User_object_to_UserService_CreateUser(t *testing.T) { 26 | // Arrange 27 | us := mock.UserService{} 28 | ah := authHelper{"secret"} 29 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 30 | var result *root.User 31 | us.CreateUserFn = func(u *root.User) error { 32 | result = u 33 | return nil 34 | } 35 | 36 | testUsername := "test_username" 37 | testPassword := "test_password" 38 | 39 | values := map[string]string{"username": testUsername, "password": testPassword} 40 | jsonValue, _ := json.Marshal(values) 41 | payload := bytes.NewBuffer(jsonValue) 42 | 43 | // Act 44 | w := httptest.NewRecorder() 45 | r, _ := http.NewRequest("PUT", "/", payload) 46 | r.Header.Set("Content-Type", "application/json") 47 | test_mux.ServeHTTP(w,r) 48 | 49 | // Assert 50 | if !us.CreateUserInvoked { 51 | t.Fatal("expected CreateUser() to be invoked") 52 | } 53 | if result.Username != testUsername { 54 | t.Fatalf("expected username to be: `%s`, got: `%s`", testUsername, result.Username) 55 | } 56 | if result.Password != testPassword { 57 | t.Fatalf("expected username to be: `%s`, got: `%s`", testPassword, result.Password) 58 | } 59 | } 60 | 61 | func createUserHandler_should_return_StatusBadRequest_if_payload_is_invalid(t *testing.T) { 62 | //Arrange 63 | us := mock.UserService{} 64 | ah := authHelper{"secret"} 65 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 66 | us.CreateUserFn = func(u *root.User) error { 67 | return nil 68 | } 69 | 70 | //Act 71 | w := httptest.NewRecorder() 72 | r, _ := http.NewRequest("PUT", "/", nil) 73 | 
r.Header.Set("Content-Type", "application/json") 74 | test_mux.ServeHTTP(w,r) 75 | 76 | //Assert 77 | if w.Code != http.StatusBadRequest { 78 | t.Fatalf("expected: http.StatusBadRequest, got: `%i`",w.Code) 79 | } 80 | } 81 | 82 | func createUserHandler_should_return_StatusInternalServerError_if_UserService_returns_error(t *testing.T) { 83 | //Arrange 84 | us := mock.UserService{} 85 | ah := authHelper{"secret"} 86 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 87 | us.CreateUserFn = func(u *root.User) error { 88 | return errors.New("user service error") 89 | } 90 | 91 | values := map[string]string{"username": "", "password": ""} 92 | jsonValue, _ := json.Marshal(values) 93 | payload := bytes.NewBuffer(jsonValue) 94 | 95 | //Act 96 | w := httptest.NewRecorder() 97 | r, _ := http.NewRequest("PUT", "/", payload) 98 | r.Header.Set("Content-Type", "application/json") 99 | test_mux.ServeHTTP(w,r) 100 | 101 | //Assert 102 | if w.Code != http.StatusInternalServerError { 103 | t.Fatalf("expected: http.StatusInternalServerError, got: `%i`", w.Code) 104 | } 105 | } 106 | 107 | //profileHandler tests 108 | func Test_UserRouter_profileHandler(t *testing.T) { 109 | t.Run("happy path", profileHandler_should_return_User_from_context) 110 | t.Run("no context", profileHandler_should_return_StatusBadRequest_if_no_auth_context) 111 | t.Run("user not found", profileHandler_should_return_StatusNotFound_if_no_user_found) 112 | } 113 | 114 | func profileHandler_should_return_User_from_context(t *testing.T) { 115 | // Arrange 116 | us := mock.UserService{} 117 | ah := authHelper{"secret"} 118 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 119 | var result string 120 | us.GetUserByUsernameFn = func(username string) (error, root.User) { 121 | result = username 122 | return nil, root.User{} 123 | } 124 | 125 | testUsername := "test_username" 126 | testUser := root.User{Username:testUsername} 127 | 128 | // Act 129 | w := httptest.NewRecorder() 130 | r, _ := 
http.NewRequest("GET", "/profile", nil) 131 | testCookie := ah.newCookie(testUser) 132 | r.AddCookie(&testCookie) 133 | ctx := context.WithValue(r.Context(), contextKeyAuthtoken, claims { testUsername, jwt.StandardClaims{} }) 134 | test_mux.ServeHTTP(w,r.WithContext(ctx)) 135 | 136 | // Assert 137 | if !us.GetUserByUsernameInvoked { 138 | t.Fatal("expected GetUserByUsername() to be invoked") 139 | } 140 | if result != testUsername { 141 | t.Fatalf("expected username to be: `%s`, got: `%s`", testUsername, result) 142 | } 143 | } 144 | 145 | func profileHandler_should_return_StatusBadRequest_if_no_auth_context(t *testing.T) { 146 | //Arrange 147 | us := mock.UserService{} 148 | ah := authHelper{"secret"} 149 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 150 | 151 | //Act 152 | w := httptest.NewRecorder() 153 | r, _ := http.NewRequest("GET", "/profile", nil) 154 | test_mux.ServeHTTP(w,r) 155 | 156 | //Assert 157 | if w.Code != http.StatusUnauthorized{ 158 | t.Fatalf("expected StatusUnauthorized, got: %s",w.Code) 159 | } 160 | } 161 | 162 | func profileHandler_should_return_StatusNotFound_if_no_user_found(t *testing.T) { 163 | //Arrange 164 | us := mock.UserService{} 165 | ah := authHelper{"secret"} 166 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 167 | var result string 168 | us.GetUserByUsernameFn = func(username string) (error, root.User) { 169 | result = username 170 | return errors.New("user service error"), root.User{} 171 | } 172 | testUsername := "test_username" 173 | testUser := root.User{Username:testUsername} 174 | 175 | //Act 176 | w := httptest.NewRecorder() 177 | r, _ := http.NewRequest("GET", "/profile", nil) 178 | testCookie := ah.newCookie(testUser) 179 | r.AddCookie(&testCookie) 180 | ctx := context.WithValue(r.Context(), contextKeyAuthtoken, claims { testUsername, jwt.StandardClaims{} }) 181 | test_mux.ServeHTTP(w,r.WithContext(ctx)) 182 | 183 | //Assert 184 | if !us.GetUserByUsernameInvoked { 185 | t.Fatal("expected 
GetUserByUsername() to be invoked") 186 | } 187 | if w.Code != http.StatusNotFound { 188 | t.Fatalf("expected: StatusNotFound, got: %s", w.Code) 189 | } 190 | } 191 | 192 | //getUserHandler tests 193 | func Test_UserRouter_getUserHandler(t *testing.T) { 194 | t.Run("happy path", getUserHandler_should_call_GetUserByUsername_with_username_from_querystring) 195 | t.Run("no user found", getUserHandler_should_return_StatusNotFound_if_no_user_found) 196 | } 197 | 198 | func getUserHandler_should_call_GetUserByUsername_with_username_from_querystring(t *testing.T) { 199 | // Arrange 200 | us := mock.UserService{} 201 | ah := authHelper{"secret"} 202 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 203 | var result string 204 | us.GetUserByUsernameFn = func(username string) (error, root.User) { 205 | result = username 206 | return nil, root.User{} 207 | } 208 | 209 | testUsername := "test_username" 210 | 211 | // Act 212 | w := httptest.NewRecorder() 213 | r, _ := http.NewRequest("GET", "/"+testUsername, nil) 214 | test_mux.ServeHTTP(w,r) 215 | 216 | // Assert 217 | if !us.GetUserByUsernameInvoked { 218 | t.Fatal("expected GetUserByUsername() to be invoked") 219 | } 220 | if result != testUsername { 221 | t.Fatalf("expected username to be: `%s`, got: `%s`", testUsername, result) 222 | } 223 | } 224 | 225 | func getUserHandler_should_return_StatusNotFound_if_no_user_found(t *testing.T) { 226 | // Arrange 227 | us := mock.UserService{} 228 | ah := authHelper{"secret"} 229 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 230 | var result string 231 | us.GetUserByUsernameFn = func(username string) (error, root.User) { 232 | result = username 233 | return errors.New("user service error"), root.User{} 234 | } 235 | 236 | testUsername := "test_username" 237 | 238 | // Act 239 | w := httptest.NewRecorder() 240 | r, _ := http.NewRequest("GET", "/"+testUsername, nil) 241 | test_mux.ServeHTTP(w,r) 242 | 243 | // Assert 244 | if !us.GetUserByUsernameInvoked { 245 | 
t.Fatal("expected GetUserByUsername() to be invoked") 246 | } 247 | if w.Code != http.StatusNotFound { 248 | t.Fatalf("expected: StatusNotFound, got: %s", w.Code) 249 | } 250 | } 251 | 252 | //gHandler tests 253 | func Test_UserRouter_loginHandler(t *testing.T) { 254 | t.Run("happy path", loginHandler_should_provide_new_auth_cookie_if_userService_returns_a_user) 255 | t.Run("no user found", getUserHandler_should_return_StatusNotFound_if_no_user_found) 256 | } 257 | 258 | func loginHandler_should_provide_new_auth_cookie_if_userService_returns_a_user(t *testing.T) { 259 | // Arrange 260 | us := mock.UserService{} 261 | ah := authHelper{"secret"} 262 | test_mux := NewUserRouter(&us, mux.NewRouter(), &ah) 263 | var result string 264 | us.LoginFn = func(credentials root.Credentials) (error, root.User) { 265 | result = credentials.Username 266 | return nil, root.User{} 267 | } 268 | 269 | testUsername := "test_username" 270 | testPassword := "test_password" 271 | 272 | values := map[string]string{"username": testUsername, "password": testPassword} 273 | jsonValue, _ := json.Marshal(values) 274 | payload := bytes.NewBuffer(jsonValue) 275 | 276 | // Act 277 | w := httptest.NewRecorder() 278 | r, _ := http.NewRequest("POST", "/login", payload) 279 | test_mux.ServeHTTP(w,r) 280 | 281 | // Assert 282 | if !us.LoginInvoked { 283 | t.Fatal("expected Login() to be invoked") 284 | } 285 | 286 | request := &http.Request{Header: http.Header{"Cookie": w.HeaderMap["Set-Cookie"]}} 287 | cookie, err := request.Cookie("Auth") 288 | if err != nil || cookie == nil { 289 | panic("Expected Cookie named 'Auth'") 290 | } 291 | } -------------------------------------------------------------------------------- /pkg/user.go: -------------------------------------------------------------------------------- 1 | package root 2 | 3 | type User struct { 4 | Id string `json:"id"` 5 | Username string `json:"username"` 6 | Password string `json:"password"` 7 | } 8 | 9 | type UserService interface { 10 | 
CreateUser(u *User) error 11 | GetUserByUsername(username string) (error, User) 12 | Login(c Credentials) (error, User) 13 | } --------------------------------------------------------------------------------