├── auth ├── auth.go ├── auth_test.go ├── passwords.go └── passwords_test.go ├── go.mod ├── go.sum ├── httpd.go ├── route ├── cond.go ├── match.go ├── route.go ├── router.go └── router_test.go ├── router.go ├── server.go ├── server_generic.go ├── server_linux.go ├── session ├── session.go ├── storage.go ├── storage_test.go └── store.go ├── template.go ├── template_helpers.go ├── transaction.go ├── util.go ├── util └── util.go └── util1.go /auth/auth.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/subtle" 6 | "encoding/base64" 7 | "encoding/binary" 8 | "errors" 9 | "fmt" 10 | 11 | "golang.org/x/crypto/scrypt" 12 | ) 13 | 14 | // Config holds the scrypt parameters, and the salt and hash lengths, used when hashing passwords. 15 | type Config struct { 16 | N int // scrypt CPU/memory cost parameter, which must be a power of two greater than 1. 17 | R int // scrypt block size parameter (must satisfy R * P < 2^30) 18 | P int // scrypt parallelisation parameter (must satisfy R * P < 2^30) 19 | 20 | SaltLen int // length of generated salt, in bytes 21 | HashLen int // length of generated hash, in bytes 22 | } 23 | 24 | // DefaultConfig holds the default configuration parameters. 25 | // The recommended parameters for interactive logins as of 2017 are N=32768, r=8 and p=1. 26 | var DefaultConfig = Config{ 27 | N: 32768, // CPU/memory cost parameter 28 | R: 8, // block size parameter 29 | P: 1, // parallelisation parameter 30 | SaltLen: 32, 31 | HashLen: 32, 32 | } 33 | 34 | // ErrInvalidPassword is returned by CheckPassword if the input password is not a match 35 | var ErrInvalidPassword = errors.New("invalid password") 36 | 37 | // HashPassword takes an input password and a salt, returning a hash (or "derived key"). 38 | // The returned hash, together with the input salt, can be used to verify a password using 39 | // CheckPassword.
40 | // 41 | // In case you need different groups of passwords, you could hash the password with a 42 | // private key before passing it to HashPassword: 43 | // 44 | // import ( 45 | // "crypto/hmac" 46 | // "crypto/sha256" 47 | // ) 48 | // hm := hmac.New(sha256.New, privateKey) 49 | // _, err := hm.Write(password) 50 | // if err != nil { 51 | // return nil, err 52 | // } 53 | // return HashPassword(hm.Sum(nil), salt) 54 | // 55 | func (c Config) HashPassword(password, salt []byte) ([]byte, error) { 56 | return scrypt.Key(password, salt, c.N, c.R, c.P, c.HashLen) 57 | } 58 | 59 | // CheckPassword verifies a password by hashing it with salt using the parameters in c (not DefaultConfig); returns nil if password is correct and ErrInvalidPassword if it is not 60 | func (c Config) CheckPassword(password, salt, hash []byte) error { 61 | hash2, err := c.HashPassword(password, salt) 62 | if err == nil { 63 | if subtle.ConstantTimeCompare(hash2, hash) != 1 { 64 | err = ErrInvalidPassword 65 | } 66 | } 67 | return err 68 | } 69 | 70 | // GenSalt generates a new cryptographically-strong salt to be used with HashPassword 71 | func (c Config) GenSalt() ([]byte, error) { 72 | salt := make([]byte, c.SaltLen) 73 | _, err := rand.Read(salt) 74 | return salt, err 75 | } 76 | 77 | // Encode config along with salt and hash, returning base-64 data 78 | func (c Config) Encode(salt, hash []byte) []byte { 79 | b := c.EncodeRaw(salt, hash) 80 | out := make([]byte, base64.RawStdEncoding.EncodedLen(len(b))) 81 | base64.RawStdEncoding.Encode(out, b) 82 | return out 83 | } 84 | 85 | // Decode decodes a base-64 encoded config, salt and hash 86 | // previously encoded with c.Encode 87 | func Decode(data []byte) (c Config, salt, hash []byte, err error) { 88 | b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) 89 | _, err = base64.RawStdEncoding.Decode(b, data) 90 | if err != nil { 91 | return 92 | } 93 | return DecodeRaw(b) 94 | } 95 | 96 | // EncodeRaw encodes the config along with salt and hash (a header of unsigned varints, then salt, then hash) 97 | func (c Config) EncodeRaw(salt, hash []byte) []byte { 98 | z := binary.MaxVarintLen64*5
+ len(salt) + len(hash) 99 | b := make([]byte, z) 100 | 101 | i := binary.PutUvarint(b, uint64(c.N)) 102 | i += binary.PutUvarint(b[i:], uint64(c.R)) 103 | i += binary.PutUvarint(b[i:], uint64(c.P)) 104 | i += binary.PutUvarint(b[i:], uint64(len(salt))) 105 | i += binary.PutUvarint(b[i:], uint64(len(hash))) 106 | copy(b[i:], salt) 107 | i += len(salt) 108 | copy(b[i:], hash) 109 | i += len(hash) 110 | 111 | return b[:i] 112 | } 113 | 114 | // DecodeRaw decodes a config, salt and hash previously encoded with c.EncodeRaw. Malformed or truncated data yields an error instead of a panic. 115 | func DecodeRaw(data []byte) (c Config, salt, hash []byte, err error) { 116 | // header: N, R, P, SaltLen, HashLen as unsigned varints, followed by salt and hash 117 | var header [5]uint64 118 | i := 0 119 | for k := range header { 120 | v, n := binary.Uvarint(data[i:]) 121 | if n <= 0 { // n == 0: data too short; n < 0: value overflows 64 bits 122 | return c, nil, nil, fmt.Errorf("invalid data (header)") 123 | } 124 | header[k] = v 125 | i += n 126 | } 127 | // make sure salt and hash fit in the remaining bytes so the slicing below can't panic 128 | rest := uint64(len(data) - i) 129 | if header[3] > rest || header[4] > rest-header[3] { 130 | return c, nil, nil, fmt.Errorf("invalid data (truncated)") 131 | } 132 | c.N = int(header[0]) 133 | c.R = int(header[1]) 134 | c.P = int(header[2]) 135 | c.SaltLen = int(header[3]) 136 | c.HashLen = int(header[4]) 137 | salt = data[i : i+c.SaltLen] 138 | hash = data[i+c.SaltLen : i+c.SaltLen+c.HashLen] 139 | return c, salt, hash, nil 140 | } 141 | 142 | // Package-level functions on DefaultConfig: 143 | 144 | // HashPassword takes an input password and a salt, returning a hash 145 | func HashPassword(password, salt []byte) ([]byte, error) { 146 | return DefaultConfig.HashPassword(password, salt) 147 | } 148 | 149 | // CheckPassword verifies a password; returns nil if password is correct 150 | func CheckPassword(password, salt, hash []byte) error { 151 | return DefaultConfig.CheckPassword(password, salt, hash) 152 | } 153 | 154 | // GenSalt generates a new cryptographically-strong salt to be used with HashPassword 155 | func GenSalt() ([]byte, error) { 156 | return DefaultConfig.GenSalt() 157 | } 158 | 159 | // Encode config along with salt and hash, returning base-64 data 160 | func Encode(salt, hash []byte) []byte { 161 | return
DefaultConfig.Encode(salt, hash) 162 | } 163 | -------------------------------------------------------------------------------- /auth/auth_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/rsms/go-testutil" 8 | ) 9 | 10 | func Test_init(t *testing.T) { 11 | // These tests uses intentionally-slow computations which means timeout can't be too short 12 | deadline, usesTimeout := t.Deadline() 13 | if usesTimeout && time.Until(deadline) < 10*time.Second { 14 | t.Errorf( 15 | "This test suite needs to be run with a longer timeout since password verification"+ 16 | " makes use of timing. Tests usually need 2-5s to complete. (Current timeout in %s)", 17 | time.Until(deadline), 18 | ) 19 | } 20 | } 21 | 22 | func TestFundamentals(t *testing.T) { 23 | assert := testutil.NewAssert(t) 24 | 25 | password1 := []byte("lolcat") 26 | password2 := []byte("hotdog") 27 | 28 | salt, err := GenSalt() 29 | assert.NoErr("GenSalt", err) 30 | assert.Eq("GenSalt generated SaltLen long salt", DefaultConfig.SaltLen, len(salt)) 31 | 32 | hash, err := HashPassword(password1, salt) 33 | assert.NoErr("HashPassword", err) 34 | assert.Eq("HashPassword generates HashLen long hash", DefaultConfig.HashLen, len(hash)) 35 | 36 | err = CheckPassword(password1, salt, hash) 37 | assert.NoErr("CheckPassword succeeded", err) 38 | 39 | err = CheckPassword(password2, salt, hash) 40 | assert.Err("CheckPassword failed with wrong password", "invalid password", err) 41 | 42 | err = CheckPassword(password1, append(salt, 'x'), hash) 43 | assert.Err("CheckPassword failed with different salt", "invalid password", err) 44 | 45 | err = CheckPassword(password1, salt, append(hash, 'x')) 46 | assert.Err("CheckPassword failed with different hash", "invalid password", err) 47 | } 48 | 49 | func TestConfigEncode(t *testing.T) { 50 | assert := testutil.NewAssert(t) 51 | password1 := []byte("lolcat") 52 | 53 |
salt, err := GenSalt() 54 | assert.NoErr("GenSalt", err) 55 | hash, err := HashPassword(password1, salt) 56 | assert.NoErr("HashPassword", err) 57 | 58 | config1 := DefaultConfig 59 | 60 | data := config1.Encode(salt, hash) 61 | t.Logf("config1.Encode => %q", data) 62 | 63 | config2, salt2, hash2, err := Decode(data) 64 | assert.NoErr("Decode", err) 65 | assert.Eq("config N", config2.N, config1.N) 66 | assert.Eq("config R", config2.R, config1.R) 67 | assert.Eq("config P", config2.P, config1.P) 68 | assert.Eq("config SaltLen", config2.SaltLen, config1.SaltLen) 69 | assert.Eq("config HashLen", config2.HashLen, config1.HashLen) 70 | assert.Eq("salt2", salt, salt2) 71 | assert.Eq("hash2", hash, hash2) 72 | } 73 | 74 | // CheckPassword must use the receiver's parameters, not DefaultConfig 75 | func TestCheckPasswordUsesConfig(t *testing.T) { 76 | assert := testutil.NewAssert(t) 77 | config := Config{N: 1024, R: 8, P: 1, SaltLen: 16, HashLen: 24} 78 | salt, err := config.GenSalt() 79 | assert.NoErr("GenSalt", err) 80 | hash, err := config.HashPassword([]byte("lolcat"), salt) 81 | assert.NoErr("HashPassword", err) 82 | assert.NoErr("CheckPassword with non-default config", config.CheckPassword([]byte("lolcat"), salt, hash)) 83 | } 84 | 85 | // DecodeRaw must return an error (not panic) for every truncation of valid data 86 | func TestDecodeRawTruncated(t *testing.T) { 87 | assert := testutil.NewAssert(t) 88 | data := DefaultConfig.EncodeRaw([]byte("salt"), []byte("hash")) 89 | for i := 0; i < len(data); i++ { 90 | _, _, _, err := DecodeRaw(data[:i]) 91 | assert.Ok("DecodeRaw fails on truncated input", err != nil) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /auth/passwords.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | // ErrInvalidAccount is returned by Change and Verify if there is no password data for the 8 | // requested account. Usually this means that the account does not exist. 9 | var ErrInvalidAccount = errors.New("invalid account") 10 | 11 | // Passwords represents a group of passwords 12 | type Passwords struct { 13 | // Config is the crypto configuration used when setting or changing password. 14 | // This must be initialized before using any of the methods of this type. 15 | Config 16 | 17 | // SetAccountPasswordData is called to store password data for an account identified by a. 18 | // It's called as a result from calling Passwords.Set or Passwords.Change 19 | // The data is base-64 encoded and can thus safely be printed or transmitted via e.g. JSON. 20 | SetAccountPasswordData func(a interface{}, data []byte) error 21 | 22 | // GetAccountPasswordData is called to load password data for an account identified by a. 23 | // It must return data previously stored via SetAccountPasswordData.
24 | // It should return an error if no data exists for the account, however an implementation 25 | // can choose to return nil or an empty byte slice in this case instead, which leads to 26 | // ErrInvalidAccount being returned from the calling function. 27 | GetAccountPasswordData func(a interface{}) ([]byte, error) 28 | } 29 | 30 | // Set computes a hash from salt + password and assigns the result to the account identified by a. 31 | // This is usually used when creating new accounts or during password recovery. 32 | func (s *Passwords) Set(a interface{}, password string) error { 33 | // generate a salt 34 | salt, err := s.Config.GenSalt() 35 | if err != nil { 36 | return err 37 | } 38 | // compute derived key which we call "hash" 39 | hash, err := s.Config.HashPassword([]byte(password), salt) 40 | if err != nil { 41 | return err 42 | } 43 | // encode config, salt and hash 44 | data := s.Config.Encode(salt, hash) 45 | return s.SetAccountPasswordData(a, data) 46 | } 47 | 48 | // Verify checks if the provided password is correct for the account identified by a. 49 | // This is usually used during sign in. 50 | func (s *Passwords) Verify(a interface{}, password string) error { 51 | data, err := s.GetAccountPasswordData(a) 52 | if err != nil { 53 | return err 54 | } 55 | if len(data) == 0 { 56 | return ErrInvalidAccount 57 | } 58 | c, salt, hash, err := Decode(data) 59 | if err != nil { 60 | return err 61 | } 62 | return c.CheckPassword([]byte(password), salt, hash) // c holds the parameters stored alongside the hash 63 | } 64 | 65 | // Change is like a conditional Set: sets the password of account identified by a to newPassword 66 | // only if the currentPassword passes Verify. 67 | // It's essentially a wrapper around Verify() & Set(). 68 | // This is usually used when the user changes their password.
69 | func (s *Passwords) Change(a interface{}, currentPassword, newPassword string) error { 70 | err := s.Verify(a, currentPassword) 71 | if err != nil { 72 | return err 73 | } 74 | return s.Set(a, newPassword) 75 | } 76 | -------------------------------------------------------------------------------- /auth/passwords_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/rsms/go-testutil" 8 | ) 9 | 10 | type TestAccounts struct { 11 | Passwords 12 | db map[int]*TestAccount 13 | } 14 | 15 | type TestAccount struct { 16 | id int 17 | passwordData []byte 18 | } 19 | 20 | func TestPasswords(t *testing.T) { 21 | assert := testutil.NewAssert(t) 22 | assert.Ok("a", true) 23 | 24 | var accounts TestAccounts 25 | accounts.Passwords.Config = DefaultConfig 26 | accounts.db = make(map[int]*TestAccount) 27 | 28 | accounts.Passwords.SetAccountPasswordData = func(id interface{}, data []byte) error { 29 | a := accounts.db[id.(int)] 30 | if a == nil { 31 | return fmt.Errorf("account %v not found", id) 32 | } 33 | a.passwordData = data 34 | return nil 35 | } 36 | 37 | accounts.Passwords.GetAccountPasswordData = func(id interface{}) ([]byte, error) { 38 | a := accounts.db[id.(int)] 39 | if a == nil { 40 | return nil, fmt.Errorf("account %v not found", id) 41 | } 42 | return a.passwordData, nil 43 | } 44 | 45 | account1 := &TestAccount{id: 1} 46 | account2 := &TestAccount{id: 2} 47 | accounts.db[account1.id] = account1 48 | accounts.db[account2.id] = account2 49 | 50 | err := accounts.Passwords.Set(account1.id, "lolcat") 51 | assert.NoErr("Passwords.Set account1", err) 52 | 53 | err = accounts.Passwords.Verify(account1.id, "lolcat") 54 | assert.NoErr("Passwords.Verify account1", err) 55 | 56 | // account2 has no password data so the error message is "invalid account" 57 | err = accounts.Passwords.Verify(account2.id, "lolcat") 58 | assert.Err("Passwords.Verify account2 with
wrong password", "invalid account", err) 59 | 60 | err = accounts.Passwords.Set(account2.id, "hotdog") 61 | assert.NoErr("Passwords.Set account2", err) 62 | 63 | err = accounts.Passwords.Verify(account2.id, "hotdog") 64 | assert.NoErr("Passwords.Verify account2", err) 65 | 66 | // account2 now has password data so the error message should be different 67 | err = accounts.Passwords.Verify(account2.id, "lolcat") 68 | assert.Err("Passwords.Verify account2 with wrong password", "invalid password", err) 69 | 70 | // change password 71 | err = accounts.Passwords.Change(account1.id, "lolcat", "monorail") 72 | assert.NoErr("Passwords.Change account1", err) 73 | 74 | // change password fails with the wrong "currentPassword" argument 75 | err = accounts.Passwords.Change(account1.id, "lolcat", "monorail") 76 | assert.Err("Passwords.Change account1", "invalid password", err) 77 | } 78 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/rsms/go-httpd 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/rsms/go-log v0.1.2 7 | github.com/rsms/go-testutil v0.1.1 8 | github.com/rsms/go-uuid v0.1.2 9 | github.com/rsms/gotalk v1.3.6 10 | golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 11 | ) 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= 2 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 3 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 4 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 5 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 6 | github.com/rsms/go-log v0.1.2
h1:rnA+V1ccv+fsZcDtj5rwZ0B1PPS+iEJVB+F2emQs2s8= 7 | github.com/rsms/go-log v0.1.2/go.mod h1:yhPudnQQV6DCSCmGPQa4agpQB6jtXokgjEHGIgurDms= 8 | github.com/rsms/go-testutil v0.1.0 h1:UkSEZKtXXK3B4kvpjNJs2GePVRuWIER4nOYyL72QTqs= 9 | github.com/rsms/go-testutil v0.1.0/go.mod h1:Jm6EzhXOLcqNmqWbqOYMXOat3diHHyH1L5MLuP+6PyI= 10 | github.com/rsms/go-testutil v0.1.1 h1:IC5+Iruf368jqSovAvQCC1bQyiIo8+gCeLSCCxExmbY= 11 | github.com/rsms/go-testutil v0.1.1/go.mod h1:Jm6EzhXOLcqNmqWbqOYMXOat3diHHyH1L5MLuP+6PyI= 12 | github.com/rsms/go-uuid v0.1.1 h1:ULJayqoAvyyQTCnuCVaFCZMUu9gkLuASvzcksGYd/UE= 13 | github.com/rsms/go-uuid v0.1.1/go.mod h1:9EDSCn1x9u88kXbj+LAzIcOaGWbTK09TyS28Nd2HH/s= 14 | github.com/rsms/go-uuid v0.1.2 h1:Ul5LNJR4IXA1q1KHU0RCn67nXpaftFEpmDnYreaz1Wg= 15 | github.com/rsms/go-uuid v0.1.2/go.mod h1:9EDSCn1x9u88kXbj+LAzIcOaGWbTK09TyS28Nd2HH/s= 16 | github.com/rsms/gotalk v1.3.6 h1:3o/WtbkXna4RPiG3/qOqpDg4iiPAED1C2NWvuKtkPH0= 17 | github.com/rsms/gotalk v1.3.6/go.mod h1:4GeESylOuAJwY9Dbl1UPRBOdHbAiB5/XLxz8zLcLBak= 18 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 19 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 20 | golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E= 21 | golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 22 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 23 | golang.org/x/net v0.0.0-20200930145003-4acb6c075d10 h1:YfxMZzv3PjGonQYNUaeU2+DhAdqOxerQ30JFB6WgAXo= 24 | golang.org/x/net v0.0.0-20200930145003-4acb6c075d10/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= 25 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 26 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 27 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 28 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 29 | -------------------------------------------------------------------------------- /httpd.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | // DevMode can be enabled to allow development features: 4 | // - allow storing cookies from unencrypted connections 5 | // 6 | var DevMode bool 7 | -------------------------------------------------------------------------------- /route/cond.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type CondFlags uint64 9 | 10 | const ( 11 | CondMethodGET = (1 << iota) 12 | CondMethodCONNECT 13 | CondMethodDELETE 14 | CondMethodHEAD 15 | CondMethodOPTIONS 16 | CondMethodPATCH 17 | CondMethodPOST 18 | CondMethodPUT 19 | CondMethodTRACE 20 | ) 21 | 22 | func (fl CondFlags) String() string { 23 | if fl == 0 { 24 | return "*" 25 | } 26 | var sb strings.Builder 27 | if (fl & CondMethodGET) != 0 { 28 | sb.WriteString("|GET") 29 | } 30 | if (fl & CondMethodCONNECT) != 0 { 31 | sb.WriteString("|CONNECT") 32 | } 33 | if (fl & CondMethodDELETE) != 0 { 34 | sb.WriteString("|DELETE") 35 | } 36 | if (fl & CondMethodHEAD) != 0 { 37 | sb.WriteString("|HEAD") 38 | } 39 | if (fl & CondMethodOPTIONS) != 0 { 40 | sb.WriteString("|OPTIONS") 41 | } 42 | if (fl & CondMethodPATCH) != 0 { 43 | sb.WriteString("|PATCH") 44 | } 45 | if (fl & CondMethodPOST) != 0 { 46 | sb.WriteString("|POST") 47 | } 48 | if (fl & CondMethodPUT) != 0 { 49 | sb.WriteString("|PUT") 50 | } 51 | if (fl & CondMethodTRACE) != 0 { 52 | sb.WriteString("|TRACE") 53 | } 54 | b := sb.String() 55 | if len(b) == 0 { 56 | return "*" 57 | } 58 | return b[1:] 59 | } 60 | 61 | func
ParseCondFlags(tokens []string) (CondFlags, error) { 62 | var f CondFlags 63 | if len(tokens) == 1 && tokens[0] == "*" { 64 | // special case: "*" for "any" which is the same as no conditions 65 | return f, nil 66 | } 67 | for _, tok := range tokens { 68 | switch tok { 69 | case "GET": 70 | f |= CondMethodGET 71 | case "CONNECT": 72 | f |= CondMethodCONNECT 73 | case "DELETE": 74 | f |= CondMethodDELETE 75 | case "HEAD": 76 | f |= CondMethodHEAD 77 | case "OPTIONS": 78 | f |= CondMethodOPTIONS 79 | case "PATCH": 80 | f |= CondMethodPATCH 81 | case "POST": 82 | f |= CondMethodPOST 83 | case "PUT": 84 | f |= CondMethodPUT 85 | case "TRACE": 86 | f |= CondMethodTRACE 87 | default: 88 | return f, fmt.Errorf("invalid condition %q", tok) 89 | } 90 | } 91 | return f, nil 92 | } 93 | -------------------------------------------------------------------------------- /route/match.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | type Match struct { 4 | *Route 5 | Path string // relative to router's BasePath 6 | values []string // variable values 7 | } 8 | 9 | // Values returns all variable values 10 | func (m Match) Values() []string { return m.values } 11 | 12 | // Vars returns all variable names and values as a map 13 | func (m Match) Vars() map[string]string { 14 | kv := make(map[string]string, len(m.Route.Vars)) 15 | if len(m.Route.Vars) > 0 { 16 | for name, index := range m.Route.Vars { 17 | kv[name] = m.values[index] 18 | } 19 | } 20 | return kv 21 | } 22 | 23 | // Var retrieves the value of a variable by name 24 | func (m Match) Var(name string, fallback ...string) string { 25 | if i, ok := m.Route.Vars[name]; ok && m.values != nil { 26 | return m.values[i] 27 | } 28 | if len(fallback) > 0 { 29 | return fallback[0] 30 | } 31 | return "" 32 | } 33 | -------------------------------------------------------------------------------- /route/route.go:
-------------------------------------------------------------------------------- 1 | package route 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | var ( 10 | reSplitOR = regexp.MustCompile(`\s*\|\s*`) 11 | reMatchVars = regexp.MustCompile(`\{(?P\w+)(?:\:\s*(?P[^\}]*)|)\s*\}`) 12 | ) 13 | 14 | const defaultVarPattern = `[^/]+` // implicit pattern in "{name}" (no ":pattern") 15 | 16 | type Route struct { 17 | Conditions CondFlags 18 | Pattern *regexp.Regexp 19 | Vars map[string]int // name => match position 20 | EntryPrefix string 21 | IsPrefix bool 22 | Handler interface{} 23 | } 24 | 25 | func (r *Route) String() string { 26 | pattern := r.EntryPrefix 27 | if r.Pattern != nil { 28 | pattern = r.Pattern.String() 29 | } 30 | return fmt.Sprintf("{%s %s}", r.Conditions, pattern) 31 | } 32 | 33 | func (r *Route) Parse(pathPattern string) error { 34 | // parse: "COND|COND /path/pattern" -> {{"COND", "COND"}, "path/pattern"} 35 | pathPattern = strings.TrimSpace(pathPattern) 36 | i := strings.IndexByte(pathPattern, '/') 37 | if i == -1 { 38 | return fmt.Errorf("invalid route pattern %q; missing leading \"/\" in path", pathPattern) 39 | } 40 | var conditions []string 41 | condstr := strings.Trim(pathPattern[:i], "| \t\r\n") 42 | if len(condstr) > 0 { 43 | conditions = reSplitOR.Split(condstr, -1) 44 | pathPattern = pathPattern[i:] 45 | } 46 | if len(pathPattern) == 0 { 47 | return fmt.Errorf("empty route pattern") 48 | } 49 | 50 | // parse conditions 51 | conds, err := ParseCondFlags(conditions) 52 | if err != nil { 53 | return err 54 | } 55 | r.Conditions = conds 56 | 57 | // prefix? i.e. "/foo/" is a prefix while "/foo" and "/foo/!" are not. 58 | c := pathPattern[len(pathPattern)-1] 59 | if c == '/' { 60 | r.IsPrefix = true 61 | } else if c == '!' { 62 | // "/foo/!" => "/foo/" 63 | // "/foo/!!" => "/foo/!" 
64 | pathPattern = pathPattern[:len(pathPattern)-1] 65 | } 66 | 67 | // find vars 68 | pathPatternBytes := []byte(pathPattern) 69 | locations := reMatchVars.FindAllSubmatchIndex(pathPatternBytes, -1) 70 | if len(locations) == 0 { 71 | // no vars 72 | r.EntryPrefix = pathPattern 73 | r.Vars = nil 74 | r.Pattern = nil 75 | return nil 76 | } 77 | 78 | // has vars; will build r.Pattern 79 | r.EntryPrefix = "" 80 | r.Vars = make(map[string]int, len(locations)) 81 | resultPattern := make([]byte, 1, len(pathPatternBytes)*2) 82 | resultPattern[0] = '^' 83 | plainStart := 0 84 | 85 | for varIndex, loc := range locations { 86 | varStart, varEnd := loc[0], loc[1] // range of whole "{...}" chunk 87 | 88 | // add plain chunk to resultPattern (whatever comes before the var) 89 | if plainStart < varStart { 90 | chunk := pathPattern[plainStart:varStart] 91 | if plainStart == 0 { 92 | r.EntryPrefix = chunk 93 | } 94 | resultPattern = append(resultPattern, regexp.QuoteMeta(chunk)...) 95 | } 96 | plainStart = varEnd 97 | 98 | // extract var name and pattern 99 | varName := pathPattern[loc[2]:loc[3]] 100 | pat := defaultVarPattern 101 | if loc[4] > -1 { 102 | // trim away leading "^" and trailing "$" in pattern. E.g. "^\w+$" -> "\w+" 103 | pat = pathPattern[loc[4]:loc[5]] 104 | if len(pat) == 0 { 105 | pat = defaultVarPattern 106 | } 107 | } 108 | 109 | // memorize var 110 | if varName != "_" { 111 | if _, ok := r.Vars[varName]; ok { 112 | return fmt.Errorf("duplicate var %q in route pattern %q", varName, pathPattern) 113 | } 114 | r.Vars[varName] = varIndex 115 | } 116 | 117 | // add var capture pattern 118 | resultPattern = append(resultPattern, '(') 119 | resultPattern = append(resultPattern, pat...) 120 | resultPattern = append(resultPattern, ')') 121 | } 122 | 123 | // add any trailing plain chunk 124 | if plainStart < len(pathPattern) { 125 | resultPattern = append(resultPattern, regexp.QuoteMeta(pathPattern[plainStart:])...) 
126 | } 127 | 128 | // terminating "$", unless r.IsPrefix 129 | if !r.IsPrefix { 130 | resultPattern = append(resultPattern, '$') 131 | } 132 | 133 | // compile regexp pattern 134 | re, err := regexp.Compile(string(resultPattern)) 135 | if err != nil { 136 | return err 137 | } 138 | r.Pattern = re 139 | return nil 140 | } 141 | -------------------------------------------------------------------------------- /route/router.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type Router struct { 9 | // BasePath is the URL path prefix where these routes begin. 10 | // All rules within are relative to this path. 11 | BasePath string 12 | Routes []Route 13 | } 14 | 15 | func (r *Router) Add(pattern string, handler interface{}) (*Route, error) { 16 | // perform some generic checks on Router, since Add is called a lot less often than ServeHTTP. 17 | if len(r.BasePath) > 0 { 18 | if r.BasePath == "/" { 19 | r.BasePath = "" 20 | } else if r.BasePath[0] != '/' { 21 | // the Dispatch function expects BasePath to start with "/"; add it 22 | r.BasePath = "/" + r.BasePath 23 | } else { 24 | // the Dispatch function expects BasePath to not end in "/"; trim away 25 | for i := len(r.BasePath) - 1; i >= 0 && r.BasePath[i] == '/'; i-- { 26 | r.BasePath = r.BasePath[:i] 27 | } 28 | } 29 | } 30 | 31 | // new Route 32 | r.Routes = append(r.Routes, Route{}) 33 | route := &(r.Routes[len(r.Routes)-1]) 34 | 35 | // parse 36 | if err := route.Parse(pattern); err != nil { 37 | r.Routes = r.Routes[:len(r.Routes)-1] 38 | return nil, err 39 | } 40 | 41 | route.Handler = handler 42 | return route, nil 43 | } 44 | 45 | func (r *Router) Match(conditions CondFlags, path string) (*Match, error) { 46 | // trim BasePath off of URL path 47 | if len(r.BasePath) > 0 { 48 | // when BasePath is non-empty it... 
49 | // - always begins with "/" 50 | // - never ends with "/" 51 | // - is never just "/" 52 | // 53 | if !strings.HasPrefix(path, r.BasePath) { 54 | return nil, fmt.Errorf("requested path %q outside of BasePath %q", path, r.BasePath) 55 | } 56 | path = path[len(r.BasePath):] 57 | } 58 | 59 | // This could be a lot more efficient with something fancy like a b-tree. 60 | // For now, keep it simple and just do a linear scan. 61 | for i := range r.Routes { 62 | route := &r.Routes[i] 63 | 64 | // check conditions 65 | if route.Conditions != 0 && (route.Conditions&conditions) == 0 { 66 | continue 67 | } 68 | 69 | // check constant prefix 70 | if len(route.EntryPrefix) > 0 && !strings.HasPrefix(path, route.EntryPrefix) { 71 | continue 72 | } 73 | 74 | if route.Pattern == nil { 75 | // no variables 76 | if route.IsPrefix || path == route.EntryPrefix { 77 | return &Match{Route: route, Path: path}, nil 78 | } 79 | } else { 80 | // check regexp 81 | values := route.Pattern.FindStringSubmatch(path) 82 | if len(values) == 1+len(route.Vars) { 83 | return &Match{Route: route, Path: path, values: values[1:]}, nil 84 | } 85 | } 86 | } 87 | 88 | // no route found 89 | return nil, nil 90 | } 91 | -------------------------------------------------------------------------------- /route/router_test.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/rsms/go-testutil" 7 | ) 8 | 9 | func TestRouter(t *testing.T) { 10 | assert := testutil.NewAssert(t) 11 | 12 | var r Router 13 | 14 | r.BasePath = "/foo" 15 | 16 | r.Add(`/us.er/{id:[0-9a-zA-Z]+}/{action:\w+}/thing`, 1) 17 | r.Add("/us.er/{q}", 2) 18 | r.Add("GET|POST /", 99) 19 | 20 | m, err := r.Match(CondMethodPUT, r.BasePath+"/") 21 | assert.NoErr("no input error", err) 22 | assert.Ok("PUT method not in condition for '/'", m == nil) 23 | 24 | m, err = r.Match(CondMethodGET, r.BasePath+"/") 25 | assert.NoErr("no input error", err) 26 | 
assert.Eq("route 99", m.Handler.(int), 99) 27 | 28 | m, err = r.Match(CondMethodGET, r.BasePath+"/us.er/bob") 29 | assert.NoErr("no input error", err) 30 | assert.Eq("route 2", m.Handler.(int), 2) 31 | 32 | m, err = r.Match(CondMethodGET, r.BasePath+"/us.er/bob/lol/thing") 33 | assert.NoErr("no input error", err) 34 | assert.Eq("route 1", m.Handler.(int), 1) 35 | 36 | assert.Eq("Var", m.Var("id"), "bob") 37 | assert.Eq("Var", m.Var("action"), "lol") 38 | assert.Eq("Var", m.Var("unknown"), "") 39 | 40 | assert.Eq("Values", len(m.Values()), 2) 41 | assert.Eq("Values", m.Values()[0], "bob") 42 | assert.Eq("Values", m.Values()[1], "lol") 43 | 44 | assert.Eq("Vars", len(m.Vars()), 2) 45 | assert.Eq("Vars", m.Vars()["id"], "bob") 46 | assert.Eq("Vars", m.Vars()["action"], "lol") 47 | } 48 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "github.com/rsms/go-httpd/route" 5 | ) 6 | 7 | type handlerFunc func(*Transaction) 8 | 9 | func (f handlerFunc) ServeHTTP(t *Transaction) { f(t) } 10 | 11 | // Router is a HTTP-specific kind of route.Router 12 | type Router struct { 13 | route.Router 14 | } 15 | 16 | func (r *Router) HandleFunc(pattern string, f func(*Transaction)) (*route.Route, error) { 17 | return r.Handle(pattern, handlerFunc(f)) 18 | } 19 | 20 | func (r *Router) Handle(pattern string, handler Handler) (*route.Route, error) { 21 | return r.Add(pattern, handler) 22 | } 23 | 24 | func (r *Router) Match(t *Transaction) (Handler, error) { 25 | // effective conditions of the transaction 26 | conditions, _ := route.ParseCondFlags([]string{t.Method()}) 27 | 28 | // find a matching route 29 | m, err := r.Router.Match(conditions, t.URL.Path) 30 | if err != nil || m == nil { 31 | return nil, err 32 | } 33 | t.routeMatch = m 34 | return m.Route.Handler.(Handler), nil 35 | } 36 | 37 | func (r *Router) ServeHTTP(t 
*Transaction) { 38 | if !r.MaybeServeHTTP(t) { 39 | t.RespondWithStatusNotFound() 40 | } 41 | } 42 | 43 | func (r *Router) MaybeServeHTTP(t *Transaction) bool { 44 | route, err := r.Match(t) 45 | if err != nil { 46 | t.Server.LogError("Router configuration error: %v", err) 47 | t.RespondWithStatusInternalServerError() 48 | return true 49 | } 50 | if route == nil { 51 | return false 52 | } 53 | route.ServeHTTP(t) 54 | return true 55 | } 56 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "path" 10 | "runtime/debug" 11 | "strings" 12 | "sync" 13 | "syscall" 14 | "time" 15 | 16 | "github.com/rsms/go-httpd/session" 17 | "github.com/rsms/go-log" 18 | "github.com/rsms/gotalk" 19 | ) 20 | 21 | type Server struct { 22 | Logger *log.Logger // defaults to log.RootLogger 23 | PubDir string // directory to serve files from. File serving is disabled if empty. 
24 | Routes Router // http request routes 25 | Server http.Server // underlying http server 26 | Sessions session.Store // Call Sessions.SetStorage(s) to enable sessions 27 | 28 | Gotalk *gotalk.WebSocketServer // set to nil to disable gotalk 29 | GotalkPath string // defaults to "/gotalk/" 30 | 31 | fileHandler http.Handler // serves pubdir (nil if len(PubDir)==0) 32 | 33 | gotalkSocksMu sync.RWMutex // protects gotalkSocks field 34 | gotalkSocks map[*gotalk.WebSocket]int // currently connected gotalk sockets 35 | gotalkOnConnectUser func(sock *gotalk.WebSocket) // saved value of .Gotalk.OnConnect 36 | 37 | gracefulShutdownTimeout time.Duration 38 | } 39 | 40 | func NewServer(pubDir, addr string) *Server { 41 | s := &Server{ 42 | PubDir: pubDir, 43 | Logger: log.RootLogger, 44 | Server: http.Server{ 45 | Addr: addr, 46 | WriteTimeout: 10 * time.Second, 47 | ReadTimeout: 10 * time.Second, 48 | MaxHeaderBytes: 1 << 20, // 1MB 49 | }, 50 | Gotalk: gotalk.WebSocketHandler(), 51 | GotalkPath: "/gotalk/", 52 | } 53 | 54 | if len(pubDir) > 0 { 55 | s.fileHandler = http.FileServer(http.Dir(pubDir)) 56 | } 57 | 58 | s.Server.Handler = s 59 | s.Gotalk.Handlers = gotalk.NewHandlers() 60 | s.Server.RegisterOnShutdown(func() { 61 | // close all connected sockets 62 | s.gotalkSocksMu.RLock() 63 | defer s.gotalkSocksMu.RUnlock() 64 | for s := range s.gotalkSocks { 65 | s.Close() 66 | } 67 | }) 68 | 69 | return s 70 | } 71 | 72 | // ServeHTTP serves a HTTP request using this server 73 | func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { 74 | if r.RequestURI == "*" { 75 | if r.ProtoAtLeast(1, 1) { 76 | w.Header().Set("Connection", "close") 77 | } 78 | w.WriteHeader(http.StatusBadRequest) 79 | return 80 | } 81 | 82 | // CONNECT requests are not canonicalized 83 | if r.Method != "CONNECT" { 84 | // strip port and clean path 85 | url := *r.URL 86 | path := cleanPath(r.URL.Path) 87 | 88 | // redirect if the path was not canonical 89 | if path != r.URL.Path { 90 | 
url.Path = path 91 | http.Redirect(w, r, url.String(), http.StatusMovedPermanently) 92 | return 93 | } 94 | 95 | // set cleaned valued 96 | url.Path = path 97 | r.Host = stripHostPort(r.Host) 98 | } 99 | 100 | // gotalk? 101 | if s.Gotalk != nil && s.GotalkPath != "" && strings.HasPrefix(r.URL.Path, s.GotalkPath) { 102 | // Note: s.Gotalk.OnAccept handler is installed in prepareToServe 103 | s.Gotalk.ServeHTTP(w, r) 104 | return 105 | } 106 | 107 | // create a new transaction 108 | t := NewTransaction(s, w, r) 109 | 110 | // recover panic and turn it into an error 111 | defer func() { 112 | if err := recover(); err != nil { 113 | s.LogError("ServeHTTP error: %v", err) 114 | if s.Logger.Level <= log.LevelDebug { 115 | s.LogDebug("ServeHTTP error: %s\n%s", err, string(debug.Stack())) 116 | } 117 | t.RespondWithMessage(500, err) 118 | } 119 | }() 120 | 121 | // serve 122 | if s.Routes.MaybeServeHTTP(t) { 123 | return 124 | } 125 | 126 | // fallback to serving files, if configured 127 | if s.fileHandler != nil { 128 | s.fileHandler.ServeHTTP(w, r) 129 | return 130 | } 131 | 132 | // 404 not found 133 | t.RespondWithStatusNotFound() 134 | } 135 | 136 | // cleanPath returns the canonical path for p, eliminating . and .. elements. 137 | func cleanPath(p string) string { 138 | if p == "" { 139 | return "/" 140 | } 141 | if p[0] != '/' { 142 | p = "/" + p 143 | } 144 | np := path.Clean(p) 145 | // path.Clean removes trailing slash except for root; 146 | // put the trailing slash back if necessary. 147 | if p[len(p)-1] == '/' && np != "/" { 148 | // Fast path for common case of p being the string we want: 149 | if len(p) == len(np)+1 && strings.HasPrefix(p, np) { 150 | np = p 151 | } else { 152 | np += "/" 153 | } 154 | } 155 | return np 156 | } 157 | 158 | // stripHostPort returns h without any trailing ":". 
159 | func stripHostPort(h string) string { 160 | // If no port on host, return unchanged 161 | if strings.IndexByte(h, ':') == -1 { 162 | return h 163 | } 164 | host, _, err := net.SplitHostPort(h) 165 | if err != nil { 166 | return h // on error, return unchanged 167 | } 168 | return host 169 | } 170 | 171 | // ----------------------------------------------------------------------------------------------- 172 | // routes 173 | 174 | // Handler responds to a HTTP request 175 | type Handler interface { 176 | ServeHTTP(*Transaction) 177 | } 178 | 179 | // Handle registers a HTTP request handler for the given pattern. 180 | // 181 | // The server takes care of sanitizing the URL request path and the Host header, 182 | // stripping the port number and redirecting any request containing . or .. 183 | // elements or repeated slashes to an equivalent, cleaner URL. 184 | func (s *Server) Handle(pattern string, handler Handler) { 185 | s.Routes.Handle(pattern, handler) 186 | } 187 | 188 | // HandleFunc registers a HTTP request handler function for the given pattern. 189 | // 190 | // The server takes care of sanitizing the URL request path and the Host header, 191 | // stripping the port number and redirecting any request containing . or .. 192 | // elements or repeated slashes to an equivalent, cleaner URL. 193 | func (s *Server) HandleFunc(pattern string, handler func(*Transaction)) { 194 | s.Routes.HandleFunc(pattern, handler) 195 | } 196 | 197 | // HandleGotalk registers a Gotalk request handler for the given operation, 198 | // with automatic JSON encoding of values. 
199 | // 200 | // `handler` must conform to one of the following signatures: 201 | // func(*WebSocket, string, interface{}) (interface{}, error) ; takes socket, op and parameters 202 | // func(*WebSocket, interface{}) (interface{}, error) ; takes socket and parameters 203 | // func(*WebSocket) (interface{}, error) ; takes socket only 204 | // func(interface{}) (interface{}, error) ; takes parameters, but no socket 205 | // func() (interface{},error) ; takes no socket or parameters 206 | // 207 | // Optionally the `interface{}` return value can be omitted, i.e: 208 | // func(*WebSocket, string, interface{}) error 209 | // func(*WebSocket, interface{}) error 210 | // func(*WebSocket) error 211 | // func(interface{}) error 212 | // func() error 213 | // 214 | // If `op` is empty, handle all requests which doesn't have a specific handler registered. 215 | func (s *Server) HandleGotalk(op string, handler interface{}) { 216 | s.Gotalk.Handlers.Handle(op, handler) 217 | } 218 | 219 | // ----------------------------------------------------------------------------------------------- 220 | 221 | // protoname should be "http" or "https" 222 | func (s *Server) bindListener(protoname string) (net.Listener, error) { 223 | addr := s.Server.Addr 224 | if addr == "" { 225 | addr = ":" + protoname 226 | } 227 | return net.Listen("tcp", addr) 228 | } 229 | 230 | func (s *Server) prepareToServe() { 231 | // Configure logger 232 | if s.Logger == nil { 233 | s.Logger = log.RootLogger 234 | } 235 | if s.Server.ErrorLog == nil { 236 | // From go's net/http documentation on Server.ErrorLog: 237 | // ErrorLog specifies an optional logger for errors accepting connections, 238 | // unexpected behavior from handlers, and underlying FileSystem errors. 239 | // If nil, logging is done via the log package's standard logger. 240 | s.Server.ErrorLog = s.Logger.GoLogger(log.LevelError) 241 | } 242 | 243 | // // Unless there's already a handler registered for "/", install a "catch all" file handler. 
244 | // // s.fileHandler is nil if PubDir is empty. 245 | // if s.fileHandler != nil { 246 | // func() { 247 | // defer func() { 248 | // if r := recover(); r != nil { 249 | // // ignore error "http: multiple registrations for /" 250 | // } 251 | // }() 252 | // s.Routes.Handle("/", s.fileHandler) 253 | // }() 254 | // } 255 | 256 | if s.Gotalk != nil { 257 | // Install the gotalk connect handler here rather than when creating the Server struct so that 258 | // in case the user installed a handler, we can wrap it. 259 | s.gotalkOnConnectUser = s.Gotalk.OnConnect // save any user handler 260 | s.Gotalk.OnConnect = s.gotalkOnConnect 261 | } 262 | } 263 | 264 | func (s *Server) justBeforeServing(ln net.Listener, protoname, extraLogMsg string) { 265 | s.LogInfo("listening on %s://%s (pubdir %q%s)", protoname, ln.Addr(), s.PubDir, extraLogMsg) 266 | } 267 | 268 | func (s *Server) returnFromServe(err error) error { 269 | if s.Gotalk != nil { 270 | // Restore previously replaced Gotalk.OnConnect 271 | s.Gotalk.OnConnect = s.gotalkOnConnectUser 272 | s.gotalkOnConnectUser = nil 273 | } 274 | 275 | if err == http.ErrServerClosed { 276 | // returned from Serve functions when server.Shutdown() was initiated 277 | err = nil 278 | } 279 | return err 280 | } 281 | 282 | func (s *Server) gotalkOnConnect(sock *gotalk.WebSocket) { 283 | s.LogDebug("gotalk sock#%p connected", sock) 284 | 285 | // call user handler 286 | if s.gotalkOnConnectUser != nil { 287 | s.gotalkOnConnectUser(sock) 288 | if sock.IsClosed() { 289 | return 290 | } 291 | } 292 | 293 | // register close handler 294 | userCloseHandler := sock.CloseHandler 295 | sock.CloseHandler = func(sock *gotalk.WebSocket, closeCode int) { 296 | s.LogDebug("gotalk sock#%p disconnected", sock) 297 | s.gotalkSocksMu.Lock() 298 | delete(s.gotalkSocks, sock) 299 | s.gotalkSocksMu.Unlock() 300 | sock.CloseHandler = nil 301 | if userCloseHandler != nil { 302 | userCloseHandler(sock, closeCode) 303 | } 304 | } 305 | 306 | // register 
connection 307 | s.gotalkSocksMu.Lock() 308 | if s.gotalkSocks == nil { 309 | s.gotalkSocks = make(map[*gotalk.WebSocket]int) 310 | } 311 | s.gotalkSocks[sock] = 1 312 | s.gotalkSocksMu.Unlock() 313 | } 314 | 315 | // RangeGotalkSockets calls f with each currently-connected gotalk socket. 316 | // Gotalk socket acceptance will be blocked while this method is called as the underlying 317 | // socket list is locked during the call. 318 | // If f returns false then iteration stops early. 319 | func (s *Server) RangeGotalkSockets(f func(*gotalk.WebSocket) bool) { 320 | s.gotalkSocksMu.RLock() 321 | defer s.gotalkSocksMu.RUnlock() 322 | for s := range s.gotalkSocks { 323 | if !f(s) { 324 | break 325 | } 326 | } 327 | } 328 | 329 | func (s *Server) Serve(ln net.Listener) error { 330 | s.prepareToServe() 331 | s.justBeforeServing(ln, "http", "") 332 | return s.Server.Serve(ln) 333 | } 334 | 335 | func (s *Server) Addr() string { 336 | return s.Server.Addr 337 | } 338 | 339 | func (s *Server) Close() error { 340 | return s.Server.Close() 341 | } 342 | 343 | func (s *Server) Shutdown(ctx context.Context, stoppedAcceptingCallback func()) error { 344 | if stoppedAcceptingCallback != nil { 345 | s.Server.RegisterOnShutdown(stoppedAcceptingCallback) 346 | } 347 | s.Server.SetKeepAlivesEnabled(false) 348 | return s.Server.Shutdown(ctx) 349 | } 350 | 351 | // 352 | // ————————————————————————————————————————————————————————————————————————————————————————————— 353 | // logging 354 | // 355 | 356 | func (s *Server) LogError(format string, v ...interface{}) { 357 | s.Logger.Error(format, v...) 358 | } 359 | func (s *Server) LogWarn(format string, v ...interface{}) { 360 | s.Logger.Warn(format, v...) 361 | } 362 | func (s *Server) LogInfo(format string, v ...interface{}) { 363 | s.Logger.Info(format, v...) 364 | } 365 | func (s *Server) LogDebug(format string, v ...interface{}) { 366 | s.Logger.LogDebug(1, format, v...) 
367 | } 368 | 369 | // ----------------------------------------------------------------------------------------------- 370 | // Graceful shutdown 371 | 372 | var ( 373 | // protects the following fields 374 | gracefulShutdownMu sync.Mutex 375 | 376 | // listening servers which opted in to graceful shutdown 377 | gracefulShutdownServers []*Server 378 | 379 | // channel that closes when all servers has completed shutdown 380 | gracefulShutdownChan chan struct{} 381 | ) 382 | 383 | // EnableGracefulShutdown enables the server to be shut down gracefully, allowing active 384 | // connections to end within shutdownTimeout. 385 | // 386 | // When graceful shutdown is enabled, SIGINT and SIGTERM signals to the process initiates 387 | // shutdown of all servers which opted in to graceful shutdown. Servers which didn't will close 388 | // as usual (immediately.) 389 | // 390 | // See net/http.Server.Shutdown for details on shutdown semantics. 391 | // 392 | func (s *Server) EnableGracefulShutdown(shutdownTimeout time.Duration) chan struct{} { 393 | if shutdownTimeout == 0 { 394 | panic("timeout can not be 0") 395 | } 396 | s.gracefulShutdownTimeout = shutdownTimeout 397 | gracefulShutdownMu.Lock() 398 | defer gracefulShutdownMu.Unlock() 399 | gracefulShutdownServers = append(gracefulShutdownServers, s) 400 | if gracefulShutdownChan == nil { 401 | // Install signal handler 402 | gracefulShutdownChan = make(chan struct{}) 403 | quit := make(chan os.Signal, 1) 404 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 405 | go func() { 406 | <-quit 407 | gracefulShutdownAll() 408 | }() 409 | } 410 | return gracefulShutdownChan 411 | } 412 | 413 | func (s *Server) DisableGracefulShutdown() { 414 | s.gracefulShutdownTimeout = 0 415 | // Note: It may seem like a good idea to remove s from gracefulShutdownServers but that 416 | // would not work. 
If the caller is waiting for gracefulShutdownChan to close, after Listen 417 | // returns, Listen would never return since the server would never be shut down. 418 | } 419 | 420 | func gracefulShutdownAll() { 421 | gracefulShutdownMu.Lock() 422 | defer gracefulShutdownMu.Unlock() 423 | 424 | var wg sync.WaitGroup 425 | 426 | shutdownServer := func(server *Server) { 427 | defer wg.Done() 428 | if server.gracefulShutdownTimeout == 0 { 429 | // DisableGracefulShutdown was called; close server to make sure that the caller's 430 | // Listen call ends & returns. 431 | server.Server.Close() 432 | return 433 | } 434 | server.Logger.Debug("graceful shutdown initiated") 435 | ctx, cancel := context.WithTimeout(context.Background(), server.gracefulShutdownTimeout) 436 | defer cancel() 437 | server.Server.SetKeepAlivesEnabled(false) 438 | if err := server.Server.Shutdown(ctx); err != nil { 439 | server.Logger.Error("graceful shutdown error: %s", err) 440 | } else { 441 | server.Logger.Debug("graceful shutdown complete") 442 | } 443 | } 444 | 445 | for i, server := range gracefulShutdownServers { 446 | wg.Add(1) 447 | if i == len(gracefulShutdownServers)-1 { 448 | shutdownServer(server) 449 | } else { 450 | go shutdownServer(server) 451 | } 452 | } 453 | 454 | wg.Wait() 455 | close(gracefulShutdownChan) 456 | gracefulShutdownChan = nil 457 | gracefulShutdownServers = nil 458 | } 459 | -------------------------------------------------------------------------------- /server_generic.go: -------------------------------------------------------------------------------- 1 | // +build !linux 2 | 3 | package httpd 4 | 5 | func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { 6 | s.prepareToServe() 7 | ln, err := s.bindListener("https") 8 | if err == nil { 9 | defer ln.Close() 10 | s.justBeforeServing(ln, "https", "") 11 | err = s.Server.ServeTLS(ln, certFile, keyFile) 12 | } 13 | return s.returnFromServe(err) 14 | } 15 | 16 | func (s *Server) ListenAndServe() error { 17 | 
s.prepareToServe() 18 | ln, err := s.bindListener("http") 19 | if err == nil { 20 | defer ln.Close() 21 | s.justBeforeServing(ln, "http", "") 22 | err = s.Server.Serve(ln) 23 | } 24 | return s.returnFromServe(err) 25 | } 26 | -------------------------------------------------------------------------------- /server_linux.go: -------------------------------------------------------------------------------- 1 | // +build linux 2 | 3 | package httpd 4 | 5 | import ( 6 | "net" 7 | 8 | "github.com/coreos/go-systemd/activation" 9 | ) 10 | 11 | // This file implements ListenAndServer which works with systemd socket activation. 12 | // Inspired by https://vincent.bernat.ch/en/blog/2018-systemd-golang-socket-activation 13 | 14 | func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { 15 | ln, err := s.listenSystemd("https") 16 | if err == nil { 17 | defer ln.Close() 18 | err = s.Server.ServeTLS(ln, certFile, keyFile) 19 | } 20 | return s.returnFromServe(err) 21 | } 22 | 23 | func (s *Server) ListenAndServe() error { 24 | ln, err := s.listenSystemd("http") 25 | if err == nil { 26 | defer ln.Close() 27 | err = s.Server.Serve(ln) 28 | } 29 | return s.returnFromServe(err) 30 | } 31 | 32 | func (s *Server) listenSystemd(proto string) (net.Listener, error) { 33 | var ln net.Listener 34 | s.prepareToServe() 35 | listeners, err := activation.Listeners() 36 | if err == nil && len(listeners) == 1 { 37 | // use systemd listener 38 | s.Logger.Debug("using socket from systemd socket activation") 39 | ln = listeners[0] 40 | s.justBeforeServing(ln, proto, ", systemd-socket") 41 | } else { 42 | if len(listeners) > 1 { 43 | // We can only handle a single socket; fail if we get more than 1. 
44 | // If multiple sockets are provided by systemd for the process, it's better to call Serve(l) 45 | // directly instead of using ListenSystemd() 46 | s.Logger.Warn("More than one socket fds from systemd: %d (ignoring systemd socket)", 47 | len(listeners)) 48 | } 49 | ln, err = s.bindListener(proto) 50 | if err == nil { 51 | s.justBeforeServing(ln, proto, "") 52 | } 53 | } 54 | return ln, err 55 | } 56 | -------------------------------------------------------------------------------- /session/session.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/rsms/go-httpd/util" 11 | "github.com/rsms/go-uuid" 12 | ) 13 | 14 | // Session holds a set of keys & values associated with an ID 15 | type Session struct { 16 | ID string // globally unique session identifier 17 | 18 | store *Store // parent store 19 | values map[string]interface{} // cached values (including pending changes, if dirty=true) 20 | dirty bool // true if values have been modified 21 | } 22 | 23 | func (s *Session) String() string { 24 | return s.ID 25 | } 26 | 27 | func (s *Session) Set(key string, value interface{}) { 28 | if s.values == nil { 29 | if value != nil { 30 | s.values = map[string]interface{}{key: value} 31 | s.dirty = true 32 | } 33 | } else { 34 | if value == nil { 35 | delete(s.values, key) 36 | } else { 37 | s.values[key] = value 38 | } 39 | s.dirty = true 40 | } 41 | } 42 | 43 | func (s *Session) Get(key string) interface{} { 44 | if s.values != nil { 45 | return s.values[key] 46 | } 47 | return nil 48 | } 49 | 50 | func (s *Session) Del(key string) { 51 | if s.values != nil { 52 | if _, ok := s.values[key]; ok { 53 | delete(s.values, key) 54 | s.dirty = true 55 | } 56 | } 57 | } 58 | 59 | // Clear removes all data for the session. 
60 | // A subsequent call to Save or SaveHTTP will remove the session data from the db 61 | // (and the cookie from the HTTP client in case of calling SaveHTTP.) 62 | func (s *Session) Clear() { 63 | if len(s.values) != 0 { 64 | s.values = nil 65 | s.dirty = true 66 | } 67 | } 68 | 69 | func (s *Session) LoadHTTP(r *http.Request) error { 70 | c, err := r.Cookie(s.store.CookieName) 71 | if err != nil { 72 | return err 73 | } 74 | id := c.Value 75 | if !isValidSessionID(id) { 76 | return fmt.Errorf("invalid session id in session cookie") 77 | } 78 | return s.Load(id) 79 | } 80 | 81 | // SaveHTTP persists the session's data if needed and refreshes its expiration by 82 | // calling Save() and then sets a corresponding cookie in the response header. 83 | // 84 | // SaveHTTP should be called after a session's Set or Del methods have been called. 85 | // 86 | // Note that a session in a Transaction is saved automatically. 87 | // 88 | // If the session does not have an ID (i.e. the session is new), then s.ID 89 | // is assigned a new identifier in the case the session has any data. 
90 | // 91 | func (s *Session) SaveHTTP(w http.ResponseWriter) error { 92 | // Save() might clear s.dirty and/or s.ID so check if s will modify db storage 93 | // before we call Save() 94 | shouldSetCookie := s.dirty || len(s.ID) > 0 95 | 96 | // Set if dirty and refresh TTL 97 | if err := s.Save(); err != nil || !shouldSetCookie { 98 | // either save failed or we only needed to refresh TTL 99 | return err 100 | } 101 | 102 | // Set cookie 103 | cookie := s.bakeSessionIDCookie() 104 | return util.HeaderSetCookie(w.Header(), cookie) 105 | } 106 | 107 | // bakeSessionIDCookie creates a cookie named s.store.CookieName 108 | // with max-age s.store.TTL and value s.ID 109 | func (s *Session) bakeSessionIDCookie() string { 110 | // See https://tools.ietf.org/html/rfc6265 111 | 112 | // MaxAge=0 means no Max-Age attribute specified and the cookie will be 113 | // deleted after the browser session ends. 114 | // MaxAge<0 means delete cookie immediately. 115 | // MaxAge>0 means Max-Age attribute present and given in seconds. 116 | maxAgeSec := -1 117 | if len(s.values) > 0 { 118 | maxAgeSec = int(s.store.TTL / time.Second) 119 | } 120 | 121 | cookie := fmt.Sprintf("%s=%s;Path=/;Max-Age=%d;HttpOnly;SameSite=Strict", 122 | s.store.CookieName, 123 | s.ID, 124 | maxAgeSec, 125 | ) 126 | // Note: "HttpOnly" = don't expose to javascript 127 | 128 | // "Secure" instructs the requestor to only store this cookie if the connection over which 129 | // it's transmitted is secure (i.e. only over HTTPS.) 130 | if !s.store.AllowInsecureCookies { 131 | cookie += ";Secure" 132 | } 133 | 134 | return cookie 135 | } 136 | 137 | // Load restores session data for id. s.ID is assigned id on success and "" on error. 
138 | // 139 | func (s *Session) Load(id string) error { 140 | data, err := s.store.storage.GetSessionData(id) 141 | s.ID = "" 142 | if err == nil && len(data) > 0 { 143 | var values map[string]interface{} 144 | values, err = decodeSessionValues(data) 145 | if err == nil { 146 | s.values = values 147 | s.ID = id 148 | } 149 | } 150 | return err 151 | } 152 | 153 | // Save persists the session's data if needed and refreshes its expiration. 154 | // If the session does not have an ID (i.e. the session is new), then s.ID 155 | // is assigned a new identifier in the case the session has any data. 156 | // 157 | func (s *Session) Save() (err error) { 158 | if s.dirty { 159 | var data []byte 160 | if len(s.values) == 0 { 161 | err = s.store.storage.DelSessionData(s.ID) 162 | if err == nil { 163 | s.ID = "" 164 | } 165 | } else if data, err = encodeSessionValues(s.values); err == nil { 166 | if len(s.ID) == 0 { 167 | id, err1 := uuid.Gen() 168 | if err1 != nil { 169 | err = err1 170 | } else { 171 | s.ID = id.String() 172 | } 173 | } 174 | err = s.store.storage.SetSessionData(s.ID, data, s.store.TTL) 175 | } 176 | if err == nil { 177 | s.dirty = false 178 | } 179 | } else if len(s.ID) > 0 { 180 | // refresh timestamp in db 181 | err = s.store.storage.RefreshSessionData(s.ID, s.store.TTL) 182 | } 183 | return 184 | } 185 | 186 | func decodeSessionValues(data []byte) (values map[string]interface{}, err error) { 187 | buf := bytes.NewBuffer(data) 188 | err = gob.NewDecoder(buf).Decode(&values) 189 | return 190 | } 191 | 192 | func encodeSessionValues(values map[string]interface{}) ([]byte, error) { 193 | var buf bytes.Buffer 194 | err := gob.NewEncoder(&buf).Encode(values) 195 | return buf.Bytes(), err 196 | } 197 | 198 | func isValidSessionID(id string) bool { 199 | if len(id) < 4 || len(id) > uuid.StringMaxLen { 200 | return false 201 | } 202 | for i := 0; i < len(id); i++ { 203 | b := id[i] 204 | if !((b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 
'z') || 205 | b == '-' || b == '_') { 206 | // invalid byte 207 | return false 208 | } 209 | } 210 | return true 211 | } 212 | -------------------------------------------------------------------------------- /session/storage.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // Storage provides persistance for session data. 10 | // 11 | // The information stored in a session is usually sensitive, i.e. it may contain user 12 | // authentication data or server-private details like internal IDs. For this reason a 13 | // Storage implementation should do its best to store data in a secure manner. 14 | // 15 | type Storage interface { 16 | // GetSessionData retrieves data for a session. 17 | // It returns nil if not found along with a decriptive error. 18 | GetSessionData(sessionId string) ([]byte, error) 19 | 20 | // SetSessionData adds or replaces data for a session. 21 | // 22 | // The TTL value is relative to "now" and dictates the maximum age of session data. 23 | // The implementation may choose how to enforce this requirement. Storage backed by for 24 | // example memcached or redis could use the built-in TTL functionality of those storage 25 | // mechanisms while a simpler implementation could store the time + ttl along with the 26 | // data and check if data is expired or not in its GetSessionData method. 27 | // 28 | SetSessionData(sessionId string, data []byte, ttl time.Duration) error 29 | 30 | // RefreshSessionData is similar to SetSessionData and is used to extend the expiration 31 | // time of a session without changing its data. 32 | // 33 | // This method is usually called whenever a session has been used (accessed) and so 34 | // implementations should make sure this is efficient. 
A trivial implementation, or an 35 | // implementation not concerned with performance, may implement this method as a call to 36 | // GetSessionData followed by a call to SetSessionData. 37 | // 38 | RefreshSessionData(sessionId string, ttl time.Duration) error 39 | 40 | // DelSessionData removes any data associated with sessionId. 41 | // If there is no data for sessionId then nil is returned rather than a "not found" error. 42 | // An error is returned when there was data for sessionId but the delete operation failed. 43 | DelSessionData(sessionId string) error 44 | } 45 | 46 | // MemoryStorage is an implementation of Storage that keeps session data in memory. 47 | // Useful for testing and also demonstrates a concrete implementation. 48 | type MemoryStorage struct { 49 | sync.Map 50 | } 51 | 52 | type memStorageEntry struct { 53 | data []byte 54 | expires time.Time 55 | } 56 | 57 | var ( 58 | ErrStorageNotFound = errors.New("not found") 59 | ErrStorageExpired = errors.New("expired") 60 | ) 61 | 62 | func (s *MemoryStorage) GetSessionData(sessionId string) ([]byte, error) { 63 | if v, ok := s.Load(sessionId); ok { 64 | entry := v.(memStorageEntry) 65 | if time.Now().After(entry.expires) { 66 | s.Delete(sessionId) 67 | return nil, ErrStorageExpired 68 | } 69 | return entry.data, nil 70 | } 71 | return nil, ErrStorageNotFound 72 | } 73 | 74 | func (s *MemoryStorage) SetSessionData(sessionId string, data []byte, ttl time.Duration) error { 75 | s.Store(sessionId, memStorageEntry{data, time.Now().Add(ttl)}) 76 | return nil 77 | } 78 | 79 | func (s *MemoryStorage) RefreshSessionData(sessionId string, ttl time.Duration) error { 80 | v, ok := s.Load(sessionId) 81 | if !ok { 82 | return ErrStorageNotFound 83 | } 84 | entry := v.(memStorageEntry) 85 | entry.expires = time.Now().Add(ttl) 86 | s.Store(sessionId, entry) 87 | return nil 88 | } 89 | 90 | func (s *MemoryStorage) DelSessionData(sessionId string) error { 91 | s.Delete(sessionId) 92 | return nil 93 | } 94 | 
--------------------------------------------------------------------------------
/session/storage_test.go:
--------------------------------------------------------------------------------
package session

import (
	"testing"
	"time"

	"github.com/rsms/go-testutil"
)

// TestMemoryStorage exercises MemoryStorage's get/set/delete round trip, lazy expiry (an
// expired entry reports "expired" once, then "not found"), and TTL refresh.
func TestMemoryStorage(t *testing.T) {
	assert := testutil.NewAssert(t)

	var s MemoryStorage
	sessionId := "abc123"

	data, err := s.GetSessionData(sessionId)
	assert.Err("get non-existing should fail", "not found", err)
	assert.Eq("get non-existing should yield nil", data, nil)

	indata := []byte("hello")
	err = s.SetSessionData(sessionId, indata, time.Second)
	assert.NoErr("SetSessionData should succeed", err)

	data, err = s.GetSessionData(sessionId)
	assert.NoErr("GetSessionData should succeed after SetSessionData", err)
	assert.Eq("data", data, indata)

	err = s.DelSessionData(sessionId)
	assert.NoErr("DelSessionData should succeed", err)

	data, err = s.GetSessionData(sessionId)
	assert.Err("get non-existing should fail after DelSessionData", "not found", err)
	assert.Eq("get non-existing should yield nil after DelSessionData", data, nil)

	// expiry
	// (SetSessionData's error is not checked here; MemoryStorage.SetSessionData always
	// returns nil.)
	s.SetSessionData(sessionId, indata, time.Nanosecond /* expire immediately */)
	time.Sleep(time.Millisecond)
	data, err = s.GetSessionData(sessionId)
	assert.Err("get expired should fail", "expired", err)
	assert.Eq("get expired should yield nil", data, nil)

	// subsequent attempts to load should error with "not found" rather than "expired"
	// (GetSessionData deletes an entry the first time it finds it expired)
	_, err = s.GetSessionData(sessionId)
	assert.Err("get expired should fail", "not found", err)

	// expiration refresh
	// Note: The following two lines should be identical to the lines above under the "expiry"
	// comment.
	s.SetSessionData(sessionId, indata, time.Nanosecond /* expire immediately */)
	time.Sleep(time.Millisecond)
	// extend expiration time & fetch to verify.
	// The entry has already expired but has not been read yet, so it is still present and
	// RefreshSessionData revives it.
	err = s.RefreshSessionData(sessionId, time.Second)
	assert.NoErr("RefreshSessionData should succeed", err)
	// we should now be able to fetch the data within the next second
	data, err = s.GetSessionData(sessionId)
	assert.NoErr("GetSessionData should succeed after RefreshSessionData", err)
	assert.Eq("data", data, indata)
}

--------------------------------------------------------------------------------
/session/store.go:
--------------------------------------------------------------------------------
package session

import (
	"context"
	"errors"
	"net/http"
	"time"
)

// Store represents a domain of sessions. Usually an app has just one of these.
// A Store manages Sessions and persists their data using Storage.
type Store struct {
	// TTL defines the time-to-live for sessions; how old a session can be before it
	// expires & is considered dead. This is relative to the last Save() call.
	TTL time.Duration // defaults to 30 days

	// CookieName is the name of the HTTP cookie to use for session ID transmission
	CookieName string // defaults to "session"

	// AllowInsecureCookies can be set to true to omit the "Secure" directive in cookies.
	// This is needed for cookies to "stick" when serving over unencrypted http (i.e. no TLS.)
	AllowInsecureCookies bool

	storage Storage // backing storage; nil until SetStorage is called
}

// NewStore returns a Store backed by storage, with TTL and CookieName set to their
// defaults (see SetStorage).
func NewStore(storage Storage) *Store {
	ss := &Store{}
	ss.SetStorage(storage)
	return ss
}

// Storage returns the storage used by this store
func (ss *Store) Storage() Storage { return ss.storage }

// SetStorage sets the storage used for persistence.
// This method also initializes TTL and CookieName to default values, if they are empty.
38 | func (ss *Store) SetStorage(storage Storage) { 39 | ss.storage = storage 40 | if ss.TTL == 0 { 41 | ss.TTL = 30 * 24 * time.Hour // 30 days 42 | } 43 | if ss.CookieName == "" { 44 | ss.CookieName = "session" 45 | } 46 | } 47 | 48 | // GetHTTP retrieves a session from a http request. 49 | // The results are cached for the same request making this function efficient to call frequently. 50 | func (ss *Store) GetHTTP(r *http.Request) *Session { 51 | var ctx = r.Context() 52 | v := ctx.Value(ss) 53 | if v != nil { 54 | return v.(*Session) 55 | } 56 | s, _ := ss.LoadHTTP(r) // ignore error 57 | *r = *r.WithContext(context.WithValue(ctx, ss, s)) // cache 58 | return s 59 | } 60 | 61 | // LoadHTTP attempts to load a session from a http request by reading session id from cookie and 62 | // loading the session data from storage. 63 | // 64 | // A valid Session object is always returned. The returned error value indicates if loading of a 65 | // session succeeded (err=nil) and should be considered informative rather than a hard error. 
66 | // 67 | func (ss *Store) LoadHTTP(r *http.Request) (*Session, error) { 68 | s := &Session{store: ss} 69 | if ss.storage == nil { 70 | return s, ErrNoStorage 71 | } 72 | return s, s.LoadHTTP(r) 73 | } 74 | 75 | // ErrNoStorage is returned when loading a session with a Store that has no backing storage 76 | var ErrNoStorage = errors.New("no session storage configured") 77 | -------------------------------------------------------------------------------- /template.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "bytes" 5 | html_template "html/template" 6 | "io" 7 | "io/ioutil" 8 | "path/filepath" 9 | text_template "text/template" 10 | tparse "text/template/parse" 11 | ) 12 | 13 | type Template interface { 14 | Name() string 15 | AddParseTree(name string, tree *tparse.Tree) (Template, error) // returns possibly new template 16 | Option(option string) 17 | Funcs(funcs map[string]interface{}) 18 | Exec(w io.Writer, data interface{}) error 19 | ExecNamed(w io.Writer, name string, data interface{}) error 20 | ExecBuf(data interface{}) ([]byte, error) 21 | Templates() []Template 22 | Tree() *tparse.Tree 23 | } 24 | 25 | func NewHtmlTemplate(name string) Template { 26 | return &htmlTemplate{html_template.New(name)} 27 | } 28 | 29 | func NewTextTemplate(name string) Template { 30 | return &textTemplate{text_template.New(name)} 31 | } 32 | 33 | func ParseHtmlTemplate(name, text string) (t Template, err error) { 34 | tpl := html_template.New(name) 35 | tpl.Funcs(standardTemplateHelpers()) 36 | if _, err := tpl.Parse(text); err != nil { 37 | return nil, err 38 | } 39 | return &htmlTemplate{tpl}, nil 40 | } 41 | 42 | func ParseTextTemplate(name, text string) (Template, error) { 43 | tpl := text_template.New(name) 44 | tpl.Funcs(standardTemplateHelpers()) 45 | if _, err := tpl.Parse(text); err != nil { 46 | return nil, err 47 | } 48 | return &textTemplate{tpl}, nil 49 | } 50 | 51 | func 
ParseHtmlTemplateFile(filename string) (t Template, err error) { 52 | b, err := ioutil.ReadFile(filename) 53 | if err != nil { 54 | return nil, err 55 | } 56 | return ParseHtmlTemplate(filepath.Base(filename), string(b)) 57 | } 58 | 59 | func ParseTextTemplateFile(filename string) (t Template, err error) { 60 | b, err := ioutil.ReadFile(filename) 61 | if err != nil { 62 | return nil, err 63 | } 64 | return ParseTextTemplate(filepath.Base(filename), string(b)) 65 | } 66 | 67 | // func parseTemplate(name string, text string) (*tparse.Tree, error) { 68 | // helpers := standardTemplateHelpers() 69 | // asts, err := tparse.Parse(name, text, "{{", "}}", helpers) 70 | // if err != nil { 71 | // return nil, err 72 | // } 73 | // return asts[name], nil 74 | // } 75 | 76 | // ------------------------------------------------------------------------ 77 | 78 | type htmlTemplate struct { 79 | t *html_template.Template 80 | } 81 | 82 | func (t *htmlTemplate) AddParseTree(name string, tree *tparse.Tree) (Template, error) { 83 | t2, err := t.t.AddParseTree(name, tree) 84 | if err != nil { 85 | return nil, err 86 | } 87 | return &htmlTemplate{t2}, err 88 | } 89 | 90 | func (t *htmlTemplate) Name() string { return t.t.Name() } 91 | func (t *htmlTemplate) Option(option string) { t.t.Option(option) } 92 | func (t *htmlTemplate) Funcs(funcs map[string]interface{}) { t.t.Funcs(funcs) } 93 | func (t *htmlTemplate) Exec(w io.Writer, data interface{}) error { return t.t.Execute(w, data) } 94 | func (t *htmlTemplate) ExecNamed(w io.Writer, name string, data interface{}) error { 95 | return t.t.ExecuteTemplate(w, name, data) 96 | } 97 | func (t *htmlTemplate) Tree() *tparse.Tree { return t.t.Tree } 98 | 99 | func (t *htmlTemplate) ExecBuf(data interface{}) ([]byte, error) { 100 | var buf bytes.Buffer 101 | err := t.Exec(&buf, data) 102 | return buf.Bytes(), err 103 | } 104 | 105 | func (t *htmlTemplate) Templates() []Template { 106 | src := t.t.Templates() 107 | tv := make([]Template, 
len(src)) 108 | for i, t2 := range src { 109 | tv[i] = &htmlTemplate{t2} 110 | } 111 | return tv 112 | } 113 | 114 | // ------------------------------------------------------------------------ 115 | 116 | type textTemplate struct { 117 | t *text_template.Template 118 | } 119 | 120 | func (t *textTemplate) AddParseTree(name string, tree *tparse.Tree) (Template, error) { 121 | t2, err := t.t.AddParseTree(name, tree) 122 | if err != nil { 123 | return nil, err 124 | } 125 | return &textTemplate{t2}, err 126 | } 127 | 128 | func (t *textTemplate) Name() string { return t.t.Name() } 129 | func (t *textTemplate) Option(option string) { t.t.Option(option) } 130 | func (t *textTemplate) Funcs(funcs map[string]interface{}) { t.t.Funcs(funcs) } 131 | func (t *textTemplate) Exec(w io.Writer, data interface{}) error { return t.t.Execute(w, data) } 132 | func (t *textTemplate) ExecNamed(w io.Writer, name string, data interface{}) error { 133 | return t.t.ExecuteTemplate(w, name, data) 134 | } 135 | func (t *textTemplate) Tree() *tparse.Tree { return t.t.Tree } 136 | 137 | func (t *textTemplate) ExecBuf(data interface{}) ([]byte, error) { 138 | var buf bytes.Buffer 139 | err := t.Exec(&buf, data) 140 | return buf.Bytes(), err 141 | } 142 | 143 | func (t *textTemplate) Templates() []Template { 144 | src := t.t.Templates() 145 | tv := make([]Template, len(src)) 146 | for i, t2 := range src { 147 | tv[i] = &textTemplate{t2} 148 | } 149 | return tv 150 | } 151 | -------------------------------------------------------------------------------- /template_helpers.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "fmt" 5 | "path" 6 | "strings" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | type TemplateHelpersMap = map[string]interface{} 12 | 13 | func NewTemplateHelpersMap(base TemplateHelpersMap) TemplateHelpersMap { 14 | h := make(TemplateHelpersMap) 15 | for k, v := range base { 16 | h[k] = v 17 | } 18 | return h 19 | } 
20 | 21 | var ( 22 | standardTemplateHelpersOnce sync.Once 23 | standardTemplateHelpersMap TemplateHelpersMap 24 | ) 25 | 26 | func standardTemplateHelpers() TemplateHelpersMap { 27 | standardTemplateHelpersOnce.Do(func() { 28 | standardTemplateHelpersMap = buildStandardTemplateHelpers() 29 | }) 30 | return standardTemplateHelpersMap 31 | } 32 | 33 | func buildStandardTemplateHelpers() TemplateHelpersMap { 34 | // helper functions shared by everything 35 | h := make(TemplateHelpersMap) 36 | 37 | h["ServerDevMode"] = func() bool { 38 | return DevMode 39 | } 40 | 41 | h["now"] = func() time.Time { 42 | return time.Now() 43 | } 44 | 45 | h["cat"] = func(args ...interface{}) string { 46 | var b strings.Builder 47 | fmt.Fprint(&b, args...) 48 | return b.String() 49 | } 50 | 51 | h["url"] = func(args ...string) string { 52 | return path.Join(args...) 53 | } 54 | 55 | h["timestamp"] = func(v ...interface{}) int64 { 56 | if len(v) == 0 { 57 | return time.Now().UTC().Unix() 58 | } else { 59 | if t, ok := v[0].(time.Time); ok { 60 | return t.UTC().Unix() 61 | } 62 | } 63 | return 0 64 | } 65 | 66 | return h 67 | } 68 | 69 | // ---------------- 70 | 71 | // func cleanFileName(basedir, name string) string { 72 | // var fn string 73 | // if runtime.GOOS == "windows" { 74 | // name = strings.Replace(name, "/", "\\", -1) 75 | // fn = filepath.Join(basedir, strings.TrimLeft(name, "\\")) 76 | // } else { 77 | // fn = filepath.Join(basedir, strings.TrimLeft(name, "/")) 78 | // } 79 | // fn = filepath.Clean(fn) 80 | // if !strings.HasPrefix(fn, basedir) { 81 | // return "" 82 | // } 83 | // return fn 84 | // } 85 | 86 | // func (service *Service) buildHelpers(base TemplateHelpersMap) TemplateHelpersMap { 87 | // // helper functions shared by everything in the same Ghp instance. 
88 | // h := NewTemplateHelpersMap(base) 89 | 90 | // // readfile reads a file relative to PubDir 91 | // h["readfile"] = func (name string) (string, error) { 92 | // fn := cleanFileName(g.config.PubDir, name) 93 | // if fn == "" { 94 | // return "", errorf("file not found %v", name) 95 | // } 96 | // data, err := ioutil.ReadFile(fn) 97 | // if err != nil { 98 | // return "", err 99 | // } 100 | // return string(data), nil 101 | // } 102 | 103 | // return h 104 | // } 105 | -------------------------------------------------------------------------------- /transaction.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "html" 7 | "net/http" 8 | "net/url" 9 | "path/filepath" 10 | "runtime" 11 | "strconv" 12 | "sync" 13 | "time" 14 | 15 | "github.com/rsms/go-httpd/route" 16 | "github.com/rsms/go-httpd/session" 17 | "github.com/rsms/go-httpd/util" 18 | ) 19 | 20 | // Transaction represents a HTTP request + response. 21 | // Implements io.Writable 22 | // Implements http.ResponseWriter 23 | // 24 | type Transaction struct { 25 | http.ResponseWriter 26 | Request *http.Request 27 | Server *Server 28 | URL *url.URL 29 | Status int // response status code (200 by default) 30 | AuxData map[string]interface{} // can be used to associate arbitrary data with a transaction 31 | 32 | headerWritten bool 33 | query url.Values // initially nil (it's a map); cached value of .URL.Query() 34 | session *session.Session 35 | routeMatch *route.Match // non-nil when the transaction went through HttpRouter 36 | } 37 | 38 | // thread-safe pool of free Transaction objects reduces memory thrash 39 | var httpTransactionFreePool = sync.Pool{} 40 | 41 | func init() { 42 | // must set in init function since the New function references gcTransaction 43 | // which depends on httpTransactionFreePool. 
44 | httpTransactionFreePool.New = func() interface{} { 45 | // called when there are no free items in the pool 46 | return new(Transaction) 47 | } 48 | } 49 | 50 | func gcTransaction(t *Transaction) { 51 | // called by the garbage collector when t is definitely garbage. 52 | // instead of letting t be collected, we put it into our free list. 53 | // 54 | // Clear references to data to allow that data to be garbage collected. 55 | t.ResponseWriter = nil 56 | t.Request = nil 57 | t.URL = nil 58 | t.headerWritten = false 59 | t.query = nil 60 | // t.user = nil 61 | // t.userLoaded = false 62 | t.session = nil 63 | httpTransactionFreePool.Put(t) 64 | } 65 | 66 | func NewTransaction(server *Server, w http.ResponseWriter, r *http.Request) *Transaction { 67 | t := httpTransactionFreePool.Get().(*Transaction) 68 | runtime.SetFinalizer(t, gcTransaction) 69 | t.ResponseWriter = w 70 | t.Request = r 71 | t.Server = server 72 | t.URL = r.URL 73 | t.Status = 200 74 | return t 75 | } 76 | 77 | // Method returns the HTTP request method (i.e. GET, POST, etc.) 78 | func (t *Transaction) Method() string { return t.Request.Method } 79 | 80 | // Var returns the first value for the named component of the query. 81 | // 82 | // Search order: 83 | // 1. URL route parameter (e.g. "id" in "/user/{id}") 84 | // 2. FORM or PUT parameters 85 | // 3. URL query-string parameters 86 | // 87 | // This function calls Request.ParseMultipartForm and Request.ParseForm if necessary and ignores 88 | // any errors returned by these functions. If key is not present, Var returns the empty string. 89 | // 90 | // To access multiple values of the same key, call Request.ParseForm and then inspect 91 | // Request.Form directly. 
92 | func (t *Transaction) Var(name string) string { 93 | if t.routeMatch != nil { 94 | if value := t.routeMatch.Var(name); value != "" { 95 | return value 96 | } 97 | } 98 | return t.Request.FormValue(name) 99 | } 100 | 101 | // parameter from URL route 102 | func (t *Transaction) RouteVar(name string) string { 103 | if t.routeMatch == nil { 104 | return "" 105 | } 106 | return t.routeMatch.Var(name) 107 | } 108 | 109 | // FormVar retrieves a POST, PATCH or PUT form parameter 110 | func (t *Transaction) FormVar(name string) string { 111 | return t.Request.PostFormValue(name) 112 | } 113 | 114 | // QueryVar retrieves a URL query-string parameter 115 | func (t *Transaction) QueryVar(name string) string { 116 | return t.Query().Get(name) 117 | } 118 | 119 | // SessionVar retrieves a session parameter (nil if not found or no session) 120 | func (t *Transaction) SessionVar(name string) interface{} { 121 | return t.Session().Get(name) 122 | } 123 | 124 | // Query returns all URL query-string parameters 125 | func (t *Transaction) Query() url.Values { 126 | if t.query == nil { 127 | t.query = t.URL.Query() 128 | } 129 | return t.query 130 | } 131 | 132 | // Form returns all POST, PATCH or PUT parameters 133 | func (t *Transaction) Form() url.Values { 134 | // cause ParseMultipartForm to be called with 135 | const maxMemory = 32 << 20 // 32 MB (matches defaultMaxMemory of go/net/http/request.go) 136 | if t.Request.PostForm == nil { 137 | t.Request.ParseMultipartForm(maxMemory) 138 | } 139 | // ignore error (complains if the conent type is not multipart) 140 | return t.Request.PostForm 141 | } 142 | 143 | func (t *Transaction) AuxVar(name string) interface{} { 144 | return t.AuxData[name] 145 | } 146 | 147 | func (t *Transaction) SetAuxVar(name string, value interface{}) { 148 | if t.AuxData == nil { 149 | t.AuxData = make(map[string]interface{}) 150 | } 151 | t.AuxData[name] = value 152 | } 153 | 154 | // RoutePath returns the request URL path relative to the Router that 
dispatched this request. 155 | // If the dispatch was done by t.Server.Router then this is identical to t.URL.Path. 156 | func (t *Transaction) RoutePath() string { 157 | if t.routeMatch != nil { 158 | return t.routeMatch.Path 159 | } 160 | return t.URL.Path 161 | } 162 | 163 | // -------------------------------------------------------------------------------------- 164 | // Responding 165 | 166 | // SetLastModified sets Last-Modified header if modtime != 0 167 | func (t *Transaction) SetLastModified(modtime time.Time) { 168 | if !isZeroTime(modtime) { 169 | t.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat)) 170 | } 171 | } 172 | 173 | func (t *Transaction) SetNoCacheHeaders() { 174 | h := t.Header() 175 | h.Set("Cache-Control", "no-cache, no-store, must-revalidate, pre-check=0, post-check=0") 176 | h.Set("Pragma", "no-cache") 177 | h.Set("Strict-Transport-Security", "max-age=15552000; preload") // for HTTPS 178 | } 179 | 180 | // SetCookie sets or adds a cookie to the response header. 181 | // See HeaderSetCookie for more details. 
182 | func (t *Transaction) SetCookie(cookie string) error { 183 | return util.HeaderSetCookie(t.Header(), cookie) 184 | } 185 | 186 | func (t *Transaction) WriteHeader(statusCode int) { 187 | if !t.headerWritten { 188 | t.headerWritten = true 189 | if t.session != nil { 190 | if err := t.session.SaveHTTP(t); err != nil { 191 | t.Server.LogError("Transaction.WriteHeader;Session.SaveHTTP error: %v", err) 192 | } 193 | } 194 | t.ResponseWriter.WriteHeader(statusCode) 195 | } 196 | } 197 | 198 | func (t *Transaction) Write(data []byte) (int, error) { 199 | t.WriteHeader(t.Status) 200 | return t.ResponseWriter.Write(data) 201 | } 202 | 203 | func (t *Transaction) WriteString(s string) (int, error) { 204 | return t.Write([]byte(s)) 205 | } 206 | 207 | func (t *Transaction) Print(a interface{}) (int, error) { 208 | return fmt.Fprint(t, a) 209 | } 210 | 211 | func (t *Transaction) Printf(format string, arg ...interface{}) (int, error) { 212 | return fmt.Fprintf(t, format, arg...) 213 | } 214 | 215 | func (t *Transaction) Flush() bool { 216 | t.WriteHeader(t.Status) 217 | flusher, ok := t.ResponseWriter.(http.Flusher) 218 | if ok { 219 | flusher.Flush() 220 | } 221 | return ok 222 | } 223 | 224 | func (t *Transaction) WriteTemplate(tpl Template, data interface{}) error { 225 | buf, err := tpl.ExecBuf(data) 226 | if err != nil { 227 | return err 228 | } 229 | t.Header().Set("Content-Type", "text/html; charset=utf-8") 230 | t.Header().Set("Content-Length", strconv.Itoa(len(buf))) 231 | t.Write(buf) 232 | return nil 233 | } 234 | 235 | func (t *Transaction) WriteHtmlTemplateFile(filename string, data interface{}) { 236 | filename = t.AbsFilePath(filename) 237 | // TODO: cache 238 | tpl, err := ParseHtmlTemplateFile(filename) 239 | if err == nil { 240 | err = t.WriteTemplate(tpl, data) 241 | } 242 | if err != nil { 243 | panic(err) 244 | } 245 | } 246 | 247 | func (t *Transaction) WriteHtmlTemplateStr(templateSource string, data interface{}) { 248 | tpl, err := 
ParseHtmlTemplate("main", templateSource) 249 | if err == nil { 250 | err = t.WriteTemplate(tpl, data) 251 | } 252 | if err != nil { 253 | panic(err) 254 | } 255 | } 256 | 257 | // AbsFilePath takes a relative or absolute filename and returns an absolute filename. 258 | // If the filename is relative it will be resolved to t.Server.PubDir (if PubDir is empty, 259 | // filename is resolve to current working directory.) An absolute filename is returned verbatim. 260 | func (t *Transaction) AbsFilePath(filename string) string { 261 | if filepath.IsAbs(filename) { 262 | return filename 263 | } 264 | if len(t.Server.PubDir) != 0 { 265 | return filepath.Join(t.Server.PubDir, filename) 266 | } 267 | filename, err := filepath.Abs(filename) 268 | if err != nil { 269 | panic(err) 270 | } 271 | return filename 272 | } 273 | 274 | func (t *Transaction) ServeFile(filename string) { 275 | filename = t.AbsFilePath(filename) 276 | http.ServeFile(t, t.Request, filename) 277 | } 278 | 279 | func (t *Transaction) RespondWithMessage(statusCode int, msg interface{}) { 280 | body := "

" + http.StatusText(statusCode) + "

" 281 | if msg != nil { 282 | body += "" + html.EscapeString(fmt.Sprint(msg)) + "" 283 | } 284 | body += "" 285 | t.Header().Set("Content-Type", "text/html; charset=utf-8") 286 | t.Header().Set("Content-Length", strconv.Itoa(len(body))) 287 | t.Status = statusCode 288 | t.Write([]byte(body)) 289 | } 290 | 291 | func (t *Transaction) RespondWithStatus(statusCode int) { 292 | t.RespondWithMessage(statusCode, t.URL.String()) 293 | } 294 | 295 | func (t *Transaction) rws(statusCode int) { t.RespondWithStatus(statusCode) } 296 | 297 | func (t *Transaction) RespondWithStatusContinue() { t.rws(100) } 298 | func (t *Transaction) RespondWithStatusSwitchingProtocols() { t.rws(101) } 299 | func (t *Transaction) RespondWithStatusProcessing() { t.rws(102) } 300 | func (t *Transaction) RespondWithStatusEarlyHints() { t.rws(103) } 301 | 302 | func (t *Transaction) RespondWithStatusOK() { t.rws(200) } 303 | func (t *Transaction) RespondWithStatusCreated() { t.rws(201) } 304 | func (t *Transaction) RespondWithStatusAccepted() { t.rws(202) } 305 | func (t *Transaction) RespondWithStatusNonAuthoritativeInfo() { t.rws(203) } 306 | func (t *Transaction) RespondWithStatusNoContent() { t.rws(204) } 307 | func (t *Transaction) RespondWithStatusResetContent() { t.rws(205) } 308 | func (t *Transaction) RespondWithStatusPartialContent() { t.rws(206) } 309 | func (t *Transaction) RespondWithStatusMultiStatus() { t.rws(207) } 310 | func (t *Transaction) RespondWithStatusAlreadyReported() { t.rws(208) } 311 | func (t *Transaction) RespondWithStatusIMUsed() { t.rws(226) } 312 | 313 | func (t *Transaction) RespondWithStatusMultipleChoices() { t.rws(300) } 314 | func (t *Transaction) RespondWithStatusMovedPermanently() { t.rws(301) } 315 | func (t *Transaction) RespondWithStatusFound() { t.rws(302) } 316 | func (t *Transaction) RespondWithStatusSeeOther() { t.rws(303) } 317 | func (t *Transaction) RespondWithStatusNotModified() { t.rws(304) } 318 | func (t *Transaction) RespondWithStatusUseProxy() 
{ t.rws(305) } 319 | 320 | func (t *Transaction) RespondWithStatusTemporaryRedirect() { t.rws(307) } 321 | func (t *Transaction) RespondWithStatusPermanentRedirect() { t.rws(308) } 322 | 323 | func (t *Transaction) RespondWithStatusBadRequest() { t.rws(400) } 324 | func (t *Transaction) RespondWithStatusUnauthorized() { t.rws(401) } 325 | func (t *Transaction) RespondWithStatusPaymentRequired() { t.rws(402) } 326 | func (t *Transaction) RespondWithStatusForbidden() { t.rws(403) } 327 | func (t *Transaction) RespondWithStatusNotFound() { t.rws(404) } 328 | func (t *Transaction) RespondWithStatusMethodNotAllowed() { t.rws(405) } 329 | func (t *Transaction) RespondWithStatusNotAcceptable() { t.rws(406) } 330 | func (t *Transaction) RespondWithStatusProxyAuthRequired() { t.rws(407) } 331 | func (t *Transaction) RespondWithStatusRequestTimeout() { t.rws(408) } 332 | func (t *Transaction) RespondWithStatusConflict() { t.rws(409) } 333 | func (t *Transaction) RespondWithStatusGone() { t.rws(410) } 334 | func (t *Transaction) RespondWithStatusLengthRequired() { t.rws(411) } 335 | func (t *Transaction) RespondWithStatusPreconditionFailed() { t.rws(412) } 336 | func (t *Transaction) RespondWithStatusRequestEntityTooLarge() { t.rws(413) } 337 | func (t *Transaction) RespondWithStatusRequestURITooLong() { t.rws(414) } 338 | func (t *Transaction) RespondWithStatusUnsupportedMediaType() { t.rws(415) } 339 | func (t *Transaction) RespondWithStatusRequestedRangeNotSatisfiable() { t.rws(416) } 340 | func (t *Transaction) RespondWithStatusExpectationFailed() { t.rws(417) } 341 | func (t *Transaction) RespondWithStatusTeapot() { t.rws(418) } 342 | func (t *Transaction) RespondWithStatusMisdirectedRequest() { t.rws(421) } 343 | func (t *Transaction) RespondWithStatusUnprocessableEntity() { t.rws(422) } 344 | func (t *Transaction) RespondWithStatusLocked() { t.rws(423) } 345 | func (t *Transaction) RespondWithStatusFailedDependency() { t.rws(424) } 346 | func (t *Transaction) 
RespondWithStatusTooEarly() { t.rws(425) } 347 | func (t *Transaction) RespondWithStatusUpgradeRequired() { t.rws(426) } 348 | func (t *Transaction) RespondWithStatusPreconditionRequired() { t.rws(428) } 349 | func (t *Transaction) RespondWithStatusTooManyRequests() { t.rws(429) } 350 | func (t *Transaction) RespondWithStatusRequestHeaderFieldsTooLarge() { t.rws(431) } 351 | func (t *Transaction) RespondWithStatusUnavailableForLegalReasons() { t.rws(451) } 352 | 353 | func (t *Transaction) RespondWithStatusInternalServerError() { t.rws(500) } 354 | func (t *Transaction) RespondWithStatusNotImplemented() { t.rws(501) } 355 | func (t *Transaction) RespondWithStatusBadGateway() { t.rws(502) } 356 | func (t *Transaction) RespondWithStatusServiceUnavailable() { t.rws(503) } 357 | func (t *Transaction) RespondWithStatusGatewayTimeout() { t.rws(504) } 358 | func (t *Transaction) RespondWithStatusHTTPVersionNotSupported() { t.rws(505) } 359 | func (t *Transaction) RespondWithStatusVariantAlsoNegotiates() { t.rws(506) } 360 | func (t *Transaction) RespondWithStatusInsufficientStorage() { t.rws(507) } 361 | func (t *Transaction) RespondWithStatusLoopDetected() { t.rws(508) } 362 | func (t *Transaction) RespondWithStatusNotExtended() { t.rws(510) } 363 | func (t *Transaction) RespondWithStatusNetworkAuthenticationRequired() { t.rws(511) } 364 | 365 | // Redirect sends a redirection response by setting the "location" header field. 366 | // The url may be a path relative to the request path. 367 | // 368 | // If the Content-Type header has not been set, Redirect sets it to "text/html; charset=utf-8" 369 | // and writes a small HTML body. Setting the Content-Type header to any value, including nil, 370 | // disables that behavior. 371 | func (t *Transaction) Redirect(url string, code int) { 372 | http.Redirect(t, t.Request, url, code) 373 | } 374 | 375 | // TemporaryRedirect sends a redirection response HTTP 302. 
376 | // The user agent may change method (usually it uses GET) but it's ambiguous. 377 | func (t *Transaction) TemporaryRedirect(url string) { 378 | t.Redirect(url, http.StatusFound) 379 | } 380 | 381 | // TemporaryRedirectGET sends a redirection response HTTP 303. 382 | // The new location will be requested using the GET method. 383 | func (t *Transaction) TemporaryRedirectGET(url string) { 384 | code := http.StatusFound 385 | if t.Request.ProtoAtLeast(1, 1) { 386 | code = http.StatusSeeOther 387 | } 388 | t.Redirect(url, code) 389 | } 390 | 391 | // TemporaryRedirectSameMethod sends a redirection response HTTP 307. 392 | // The new location will be requested using the same method as the current request. 393 | func (t *Transaction) TemporaryRedirectSameMethod(url string) { 394 | code := http.StatusFound 395 | if t.Request.ProtoAtLeast(1, 1) { 396 | code = http.StatusTemporaryRedirect 397 | } 398 | t.Redirect(url, code) 399 | } 400 | 401 | // PermanentRedirect sends a redirection response HTTP 301. 402 | func (t *Transaction) PermanentRedirect(url string) { 403 | t.Redirect(url, http.StatusMovedPermanently) 404 | } 405 | 406 | // PermanentRedirectSameMethod sends a redirection response HTTP 308. 407 | // The new location will be requested using the same method as the current request. 408 | func (t *Transaction) PermanentRedirectSameMethod(url string) { 409 | code := http.StatusMovedPermanently 410 | if t.Request.ProtoAtLeast(1, 1) { 411 | code = http.StatusPermanentRedirect 412 | } 413 | t.Redirect(url, code) 414 | } 415 | 416 | // ReferrerURL returns a URL of the "referer" request field if present. 417 | // Returns fallback if the value of the "referer" header is not a valid URL. 
418 | func (t *Transaction) ReferrerURL(fallback *url.URL) *url.URL { 419 | referrer := t.Request.Header.Get("referer") // yup, it's misspelled 420 | if referrer != "" { 421 | refurl, err := url.Parse(referrer) 422 | if err == nil { 423 | return refurl 424 | } 425 | } 426 | return fallback 427 | } 428 | 429 | // DifferentReferrerURL returns a URL of the "referer" request field if present. 430 | // If the referrer's path is the same as t.URL.Path then nil is returned. 431 | func (t *Transaction) DifferentReferrerURL() *url.URL { 432 | referrer := t.ReferrerURL(nil) 433 | if referrer != nil && referrer.Path != t.URL.Path { 434 | return referrer 435 | } 436 | return nil 437 | } 438 | 439 | // ------------------------------------------------------------------------------------------- 440 | // Context 441 | 442 | func (t *Transaction) Context() context.Context { 443 | return t.Request.Context() 444 | } 445 | 446 | func (t *Transaction) ContextWithTimeout( 447 | timeout time.Duration) (context.Context, context.CancelFunc) { 448 | return context.WithTimeout(t.Request.Context(), timeout) 449 | } 450 | 451 | // ------------------------------------------------------------------------------------------- 452 | // Session 453 | 454 | // Session returns the session for the transaction. 455 | // If the server has a valid session store, this always returns a valid Session object, which 456 | // may be empty in case there's no session. 457 | // Returns nil only when the server does not have a valid session store. 
458 | // 459 | func (t *Transaction) Session() *session.Session { 460 | if t.session == nil { 461 | // Note: LoadHTTP always returns a valid Session object 462 | t.session, _ = t.Server.Sessions.LoadHTTP(t.Request) // ignore error 463 | } 464 | return t.session 465 | } 466 | 467 | func (t *Transaction) SaveSession() { 468 | if t.session != nil { 469 | if err := t.session.SaveHTTP(t); err != nil { 470 | t.Server.Logger.Warn("Transaction.SaveSession: %v", err) 471 | } 472 | } 473 | } 474 | 475 | func (t *Transaction) ClearSession() { 476 | s := t.Session() 477 | if s != nil { 478 | s.Clear() 479 | if err := s.SaveHTTP(t); err != nil { 480 | t.Server.Logger.Warn("Transaction.ClearSession: %v", err) 481 | } 482 | } 483 | } 484 | 485 | // // RequireUser verifies that the request is from an authenticated user. 486 | // // 487 | // // If the user is authenticated, their corresponding User object is returned. 488 | // // In this case, the caller should complete the response. 489 | // // 490 | // // If the user is not authenticated, or their session has become invalid, a redirect 491 | // // response is sent to the sign-in page and nil is returned. 492 | // // In the case of nil being returned, the caller should NOT modify the response. 
493 | // // 494 | // func (t *Transaction) RequireUser1() *User1 { 495 | // setNoCacheHeaders(t) 496 | // user := t.User1() 497 | // if user == nil { 498 | // // logf("unathuenticated request; redirecting to sign-in") 499 | // redirectToLogIn(t) 500 | // return nil 501 | // } 502 | // return user 503 | // } 504 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "time" 8 | ) 9 | 10 | func absfile(filename string) string { 11 | if s, err := filepath.Abs(filename); err == nil { 12 | return s 13 | } 14 | return filename 15 | } 16 | 17 | func crashOnErr(err error) { 18 | if err != nil { 19 | fatalf(err) 20 | } 21 | } 22 | 23 | func errorf(format string, v ...interface{}) error { 24 | return fmt.Errorf(format, v...) 25 | } 26 | 27 | func fatalf(msg interface{}, arg ...interface{}) { 28 | var format string 29 | if s, ok := msg.(string); ok { 30 | format = s 31 | } else if s, ok := msg.(fmt.Stringer); ok { 32 | format = s.String() 33 | } else { 34 | format = fmt.Sprintf("%v", msg) 35 | } 36 | fmt.Fprintf(os.Stderr, format+"\n", arg...) 37 | os.Exit(1) 38 | } 39 | 40 | var unixEpochTime = time.Unix(0, 0) 41 | 42 | // isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). 43 | func isZeroTime(t time.Time) bool { 44 | return t.IsZero() || t.Equal(unixEpochTime) 45 | } 46 | -------------------------------------------------------------------------------- /util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/textproto" 7 | "strings" 8 | ) 9 | 10 | // HeaderSetCookie sets or adds a cookie in or to the provided HTTP header. 
11 | // This is different from header.Set("Set-Cookie") as it replaces cookies with the same 12 | // cookie name but adds cookies with different names. 13 | // 14 | // In other words, if you call: 15 | // HeaderSetCookie("foo=123;...") 16 | // HeaderSetCookie("bar=456;...") 17 | // HeaderSetCookie("foo=789;...") // replaces foo cookie 18 | // then the actual header will contain: 19 | // Set-Cookie: foo=789;... 20 | // Set-Cookie: bar=456;... 21 | // 22 | func HeaderSetCookie(header http.Header, cookie string) error { 23 | name := parseCookieName(cookie) 24 | if len(name) == 0 { 25 | return errors.New("invalid cookie") 26 | } 27 | // find existing cookie with same name 28 | existingCookies := header.Values("Set-Cookie") 29 | for i, line := range existingCookies { 30 | name2 := parseCookieName(line) 31 | if name == name2 { 32 | // replace cookie 33 | existingCookies[i] = cookie 34 | return nil 35 | } 36 | } 37 | // cookie not yet set in header; add it 38 | header.Add("Set-Cookie", cookie) 39 | return nil 40 | } 41 | 42 | // returns the name in a string like "name=value;..." 43 | func parseCookieName(cookie string) string { 44 | i := strings.IndexByte(cookie, ';') 45 | if i < 0 { 46 | return "" 47 | } 48 | e := strings.IndexByte(cookie, '=') 49 | if e < 0 || e > i { 50 | return "" 51 | } 52 | return textproto.TrimString(cookie[:e]) 53 | } 54 | -------------------------------------------------------------------------------- /util1.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "time" 8 | ) 9 | 10 | func absfile(filename string) string { 11 | if s, err := filepath.Abs(filename); err == nil { 12 | return s 13 | } 14 | return filename 15 | } 16 | 17 | func crashOnErr(err error) { 18 | if err != nil { 19 | fatalf(err) 20 | } 21 | } 22 | 23 | func errorf(format string, v ...interface{}) error { 24 | return fmt.Errorf(format, v...) 
25 | } 26 | 27 | func fatalf(msg interface{}, arg ...interface{}) { 28 | var format string 29 | if s, ok := msg.(string); ok { 30 | format = s 31 | } else if s, ok := msg.(fmt.Stringer); ok { 32 | format = s.String() 33 | } else { 34 | format = fmt.Sprintf("%v", msg) 35 | } 36 | fmt.Fprintf(os.Stderr, format+"\n", arg...) 37 | os.Exit(1) 38 | } 39 | 40 | var unixEpochTime = time.Unix(0, 0) 41 | 42 | // isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). 43 | func isZeroTime(t time.Time) bool { 44 | return t.IsZero() || t.Equal(unixEpochTime) 45 | } 46 | --------------------------------------------------------------------------------