├── .gitignore ├── README.md ├── authenticator.go ├── conf.go ├── config_sample.yml ├── config_test.go ├── httpd.go ├── httpd_test.go └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | config.yml 2 | gate -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # "gate" for your private resources 2 | 3 | gate is a static file server and reverse proxy integrated with OAuth2 account authentication. 4 | 5 | With gate, you can safely serve your private resources based on whether or not the requesting user is a member of your company's Google Apps or GitHub organizations. 6 | 7 | ## Usage 8 | 9 | 1. Download a [binary](https://github.com/typester/gate/releases) or install it with `go get` 10 | 2. Rename `config_sample.yml` to `config.yml` 11 | 3. Edit `config.yml` to fit your environment 12 | 4. Run `gate` 13 | 14 | ## Example config 15 | 16 | ```yaml 17 | # address to bind 18 | address: :9999 19 | 20 | # # ssl keys (optional) 21 | # ssl: 22 | # cert: ./ssl/ssl.cer 23 | # key: ./ssl/ssl.key 24 | 25 | auth: 26 | session: 27 | # authentication key for cookie store 28 | key: secret123 29 | 30 | info: 31 | # oauth2 provider name (`google` or `github`) 32 | service: google 33 | # your app keys for the service 34 | client_id: your client id 35 | client_secret: your client secret 36 | # your app redirect_url for the service: if the service is Google, path is always "/oauth2callback" 37 | redirect_url: https://yourapp.example.com/oauth2callback 38 | 39 | # # restrict user request.
(optional) 40 | # restrictions: 41 | # - yourdomain.com # domain of your Google App (Google) 42 | # - example@gmail.com # specific email address (same as above) 43 | # - your_company_org # organization name (GitHub) 44 | 45 | # document root for static files 46 | htdocs: ./ 47 | 48 | # proxy definitions 49 | proxy: 50 | - path: /elasticsearch 51 | dest: http://127.0.0.1:9200 52 | strip_path: yes 53 | 54 | - path: /influxdb 55 | dest: http://127.0.0.1:8086 56 | strip_path: yes 57 | ``` 58 | 59 | ## Authentication Strategy 60 | 61 | gate currently supports Google Apps and GitHub for user authentication. 62 | 63 | ### Example config for Google 64 | 65 | ```yaml 66 | auth: 67 | info: 68 | service: google 69 | client_id: your client id 70 | client_secret: your client secret 71 | redirect_url: https://yourapp.example.com/oauth2callback 72 | 73 | # restrict user request. (optional) 74 | restrictions: 75 | - yourdomain.com # domain of your Google App 76 | - example@gmail.com # specific email address 77 | ``` 78 | 79 | ### Example config for GitHub 80 | 81 | Unlike the Google Apps example above, when `service` is `github`, gate checks whether the requesting user is a member of at least one of the organizations listed below: 82 | 83 | ```yaml 84 | auth: 85 | info: 86 | service: github 87 | client_id: your client id 88 | client_secret: your client secret 89 | redirect_url: https://yourapp.example.com/oauth2callback 90 | 91 | # restrict user request. (optional) 92 | restrictions: 93 | - foo_organization 94 | - bar_organization 95 | ``` 96 | 97 | #### GitHub Enterprise support 98 | 99 | GitHub Enterprise is also supported.
To authenticate via GitHub Enterprise, add the API endpoint information to the config as follows: 100 | 101 | ```yaml 102 | auth: 103 | info: 104 | service: github 105 | client_id: your client id 106 | client_secret: your client secret 107 | redirect_url: https://yourapp.example.com/oauth2callback 108 | endpoint: https://github.yourcompany.com 109 | api_endpoint: https://github.yourcompany.com/api 110 | ``` 111 | 112 | ## Name Based Virtual Host 113 | 114 | An example of a "Name Based Virtual Host" setting. 115 | 116 | ```yaml 117 | auth: 118 | session: 119 | # authentication key for cookie store 120 | key: secret123 121 | # domain of virtual hosts base host 122 | cookie_domain: gate.example.com 123 | 124 | # proxy definitions 125 | proxy: 126 | - path: / 127 | host: elasticsearch.gate.example.com 128 | dest: http://127.0.0.1:9200 129 | 130 | - path: / 131 | host: influxdb.gate.example.com 132 | dest: http://127.0.0.1:8086 133 | ``` 134 | 135 | ## License 136 | 137 | MIT 138 | -------------------------------------------------------------------------------- /authenticator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/go-martini/martini" 7 | gooauth2 "github.com/golang/oauth2" 8 | "github.com/martini-contrib/oauth2" 9 | "io/ioutil" 10 | "log" 11 | "net/http" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | // githubAPIClient is used for GitHub organization lookups; its timeout keeps a slow 17 | // or unreachable API endpoint from blocking request handling indefinitely. 18 | var githubAPIClient = &http.Client{Timeout: 10 * time.Second} 19 | 20 | type Authenticator interface { 21 | Authenticate([]string, martini.Context, oauth2.Tokens, http.ResponseWriter, *http.Request) 22 | Handler() martini.Handler 23 | } 24 | 25 | func NewAuthenticator(conf *Conf) Authenticator { 26 | var authenticator Authenticator 27 | 28 | if conf.Auth.Info.Service == "google" { 29 | handler := oauth2.Google(&gooauth2.Options{ 30 | ClientID: conf.Auth.Info.ClientId, 31 | ClientSecret: conf.Auth.Info.ClientSecret, 32 | RedirectURL: conf.Auth.Info.RedirectURL, 33 | Scopes: []string{"email"}, 34 | }) 35 | authenticator = &GoogleAuth{&BaseAuth{handler, conf}}
36 | } else if conf.Auth.Info.Service == "github" { 37 | handler := GithubGeneral(&gooauth2.Options{ 38 | ClientID: conf.Auth.Info.ClientId, 39 | ClientSecret: conf.Auth.Info.ClientSecret, 40 | RedirectURL: conf.Auth.Info.RedirectURL, 41 | Scopes: []string{"read:org"}, 42 | }, conf) 43 | authenticator = &GitHubAuth{&BaseAuth{handler, conf}} 44 | } else { 45 | panic("unsupported authentication method") 46 | } 47 | 48 | return authenticator 49 | } 50 | 51 | // Currently, martini-contrib/oauth2 doesn't support github enterprise directly. 52 | func GithubGeneral(opts *gooauth2.Options, conf *Conf) martini.Handler { 53 | authUrl := fmt.Sprintf("%s/login/oauth/authorize", conf.Auth.Info.Endpoint) 54 | tokenUrl := fmt.Sprintf("%s/login/oauth/access_token", conf.Auth.Info.Endpoint) 55 | 56 | return oauth2.NewOAuth2Provider(opts, authUrl, tokenUrl) 57 | } 58 | 59 | type BaseAuth struct { 60 | handler martini.Handler 61 | conf *Conf 62 | } 63 | 64 | func (b *BaseAuth) Handler() martini.Handler { 65 | return b.handler 66 | } 67 | 68 | type GoogleAuth struct { 69 | *BaseAuth 70 | } 71 | 72 | // Authenticate admits the user whose id_token email matches an entry of domain -- a 73 | // bare domain ("example.com") or a full address ("a@example.com"), both compared 74 | // case-insensitively -- or any user when domain is empty. On success the *User is 75 | // mapped into c; otherwise a 403 is written. 76 | func (a *GoogleAuth) Authenticate(domain []string, c martini.Context, tokens oauth2.Tokens, w http.ResponseWriter, r *http.Request) { 77 | extra := tokens.ExtraData() 78 | if _, ok := extra["id_token"]; ok == false { 79 | log.Printf("id_token not found") 80 | forbidden(w) 81 | return 82 | } 83 | 84 | // Only the payload (second segment) of the id_token is decoded; its signature 85 | // is not checked. NOTE(review): this relies on the token having come directly 86 | // from the provider's token endpoint -- confirm that holds. 87 | keys := strings.Split(extra["id_token"], ".") 88 | if len(keys) < 2 { 89 | log.Printf("invalid id_token") 90 | forbidden(w) 91 | return 92 | } 93 | 94 | data, err := base64Decode(keys[1]) 95 | if err != nil { 96 | log.Printf("failed to decode base64: %s", err.Error()) 97 | forbidden(w) 98 | return 99 | } 100 | 101 | var info map[string]interface{} 102 | if err := json.Unmarshal(data, &info); err != nil { 103 | log.Printf("failed to decode json: %s", err.Error()) 104 | forbidden(w) 105 | return 106 | } 107 | 108 | // Refuse addresses the token explicitly marks as unverified (the claim is 109 | // accepted as either a JSON boolean or a string). 110 | if v := info["email_verified"]; v == false || v == "false" { 111 | log.Printf("email not verified: %v", info["email"]) 112 | forbidden(w) 113 | return 114 | } 115 |
116 | if email, ok := info["email"].(string); ok { 117 | var user *User 118 | if len(domain) > 0 { 119 | for _, d := range domain { 120 | if strings.Contains(d, "@") { 121 | // a full address must match the whole email 122 | if strings.EqualFold(d, email) { 123 | user = &User{email} 124 | break 125 | } 126 | } else { 127 | // a bare domain must match the part after "@" 128 | if strings.HasSuffix(strings.ToLower(email), "@"+strings.ToLower(d)) { 129 | user = &User{email} 130 | break 131 | } 132 | } 133 | } 134 | } else { 135 | user = &User{email} 136 | } 137 | 138 | if user != nil { 139 | log.Printf("user %s logged in", email) 140 | c.Map(user) 141 | } else { 142 | log.Printf("email doesn't allow: %s", email) 143 | forbidden(w) 144 | return 145 | } 146 | } else { 147 | log.Printf("email not found") 148 | forbidden(w) 149 | return 150 | } 151 | } 152 | 153 | type GitHubAuth struct { 154 | *BaseAuth 155 | } 156 | 157 | // Authenticate admits the user when organizations is empty, or when the user is a 158 | // member of at least one of them (org logins compare case-insensitively); 159 | // otherwise a 403 is written. 160 | func (a *GitHubAuth) Authenticate(organizations []string, c martini.Context, tokens oauth2.Tokens, w http.ResponseWriter, r *http.Request) { 161 | if len(organizations) > 0 { 162 | // per_page=100 is the API maximum; with the default page size a user in many 163 | // organizations was refused when the match sat on a later page (memberships 164 | // beyond the first 100 are still not seen). 165 | req, err := http.NewRequest("GET", fmt.Sprintf("%s/user/orgs?per_page=100", a.conf.Auth.Info.ApiEndpoint), nil) 166 | if err != nil { 167 | log.Printf("failed to create a request to retrieve organizations: %s", err) 168 | forbidden(w) 169 | return 170 | } 171 | 172 | req.SetBasicAuth(tokens.Access(), "x-oauth-basic") 173 | 174 | res, err := githubAPIClient.Do(req) 175 | if err != nil { 176 | log.Printf("failed to retrieve organizations: %s", err) 177 | forbidden(w) 178 | return 179 | } 180 | 181 | data, err := ioutil.ReadAll(res.Body) 182 | res.Body.Close() 183 | 184 | if err != nil { 185 | log.Printf("failed to read body of GitHub response: %s", err) 186 | forbidden(w) 187 | return 188 | } 189 | if res.StatusCode != http.StatusOK { 190 | log.Printf("unexpected GitHub response status: %s", res.Status) 191 | forbidden(w) 192 | return 193 | } 194 | 195 | var info []map[string]interface{} 196 | if err := json.Unmarshal(data, &info); err != nil { 197 | log.Printf("failed to decode json: %s", err.Error()) 198 | forbidden(w) 199 | return 200 | } 201 | 202 | for _, userOrg := range info { 203 | login, _ := userOrg["login"].(string) 204 | for _, org := range organizations { 205 | if login != "" && strings.EqualFold(login, org) { 206 | return 207 | } 208 | } 209 | } 210 | 211 | log.Print("not a member of designated organizations") 212 | forbidden(w) 213 | return 214 | } 215 | } 216 | 217 | // forbidden writes a 403 "Access denied" response. 218 | func forbidden(w http.ResponseWriter) { 219 | w.WriteHeader(http.StatusForbidden) 220 | w.Write([]byte("Access denied")) 221 | }
222 | -------------------------------------------------------------------------------- /conf.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "gopkg.in/yaml.v1" 6 | "io/ioutil" 7 | "github.com/martini-contrib/oauth2" 8 | ) 9 | 10 | const ( 11 | noAuthServiceName = "nothing" // for testing only (undocumented) 12 | ) 13 | 14 | type Conf struct { 15 | Addr string `yaml:"address"` 16 | SSL SSLConf `yaml:"ssl"` 17 | Auth AuthConf `yaml:"auth"` 18 | Restrictions []string `yaml:"restrictions"` 19 | Proxies []ProxyConf `yaml:"proxy"` 20 | Paths PathConf `yaml:"paths"` 21 | Htdocs string `yaml:"htdocs"` 22 | } 23 | 24 | type SSLConf struct { 25 | Cert string `yaml:"cert"` 26 | Key string `yaml:"key"` 27 | } 28 | 29 | type AuthConf struct { 30 | Session AuthSessionConf `yaml:"session"` 31 | Info AuthInfoConf `yaml:"info"` 32 | } 33 | 34 | type AuthSessionConf struct { 35 | Key string `yaml:"key"` 36 | CookieDomain string `yaml:"cookie_domain"` 37 | } 38 | 39 | type AuthInfoConf struct { 40 | Service string `yaml:"service"` 41 | ClientId string `yaml:"client_id"` 42 | ClientSecret string `yaml:"client_secret"` 43 | RedirectURL string `yaml:"redirect_url"` 44 | Endpoint string `yaml:"endpoint"` 45 | ApiEndpoint string `yaml:"api_endpoint"` 46 | } 47 | 48 | type ProxyConf struct { 49 | Path string `yaml:"path"` 50 | Dest string `yaml:"dest"` 51 | Strip bool `yaml:"strip_path"` 52 | Host string `yaml:"host"` 53 | } 54 | 55 | type PathConf struct { 56 | Login string `yaml:"login"` 57 | Logout string `yaml:"logout"` 58 | Callback string `yaml:"callback"` 59 | Error string `yaml:"error"` 60 | } 61 | 62 | func ParseConf(path string) (*Conf, error) { 63 | data, err := ioutil.ReadFile(path) 64 | if err
!= nil { 65 | return nil, err 66 | } 67 | 68 | c := &Conf{} 69 | if err := yaml.Unmarshal(data, c); err != nil { 70 | return nil, err 71 | } 72 | 73 | if c.Addr == "" { 74 | return nil, errors.New("address config is required") 75 | } 76 | 77 | if c.Auth.Session.Key == "" { 78 | return nil, errors.New("auth.session.key config is required") 79 | } 80 | if c.Auth.Info.Service == "" { 81 | return nil, errors.New("auth.info.service config is required") 82 | } 83 | // Reject unknown services here so a typo fails with a clear error instead of 84 | // the panic NewAuthenticator raises for an unsupported service. 85 | if s := c.Auth.Info.Service; s != "google" && s != "github" && s != noAuthServiceName { 86 | return nil, errors.New(`auth.info.service must be "google" or "github"`) 87 | } 88 | if c.Auth.Info.ClientId == "" { 89 | return nil, errors.New("auth.info.client_id config is required") 90 | } 91 | if c.Auth.Info.ClientSecret == "" { 92 | return nil, errors.New("auth.info.client_secret config is required") 93 | } 94 | if c.Auth.Info.RedirectURL == "" { 95 | return nil, errors.New("auth.info.redirect_url config is required") 96 | } 97 | 98 | if c.Htdocs == "" { 99 | c.Htdocs = "." 100 | } 101 | 102 | if c.Auth.Info.Service == "github" && c.Auth.Info.Endpoint == "" { 103 | c.Auth.Info.Endpoint = "https://github.com" 104 | } 105 | if c.Auth.Info.Service == "github" && c.Auth.Info.ApiEndpoint == "" { 106 | c.Auth.Info.ApiEndpoint = "https://api.github.com" 107 | } 108 | 109 | return c, nil 110 | } 111 | 112 | // SetOAuth2Paths overrides the martini-contrib/oauth2 login, logout, callback and 113 | // error paths with each non-empty value from the paths section. 114 | func (c *Conf) SetOAuth2Paths() { 115 | if c.Paths.Login != "" { 116 | oauth2.PathLogin = c.Paths.Login 117 | } 118 | if c.Paths.Logout != "" { 119 | oauth2.PathLogout = c.Paths.Logout 120 | } 121 | if c.Paths.Callback != "" { 122 | oauth2.PathCallback = c.Paths.Callback 123 | } 124 | if c.Paths.Error != "" { 125 | oauth2.PathError = c.Paths.Error 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /config_sample.yml: -------------------------------------------------------------------------------- 1 | # address to bind 2 | address: :9999 3 | 4 | # # ssl keys (optional) 5 | # ssl: 6 | # cert: ./ssl/ssl.cer 7 | # key: ./ssl/ssl.key 8 | 9 | auth: 10 | session: 11 | # authentication key for cookie store 12 | key: secret123 13 | 14 | info: 15 | #
oauth2 provider name (`google` or `github`) 16 | service: google 17 | # your app keys for the service 18 | client_id: your client id 19 | client_secret: your client secret 20 | # your app redirect_url for the service: if the service is Google, path is always "/oauth2callback" 21 | redirect_url: https://yourapp.example.com/oauth2callback 22 | 23 | # # restrict user request. (optional) 24 | # restrictions: 25 | # - yourdomain.com # domain of your Google App (Google) 26 | # - example@gmail.com # specific email address (same as above) 27 | # - your_company_org # organization name (GitHub) 28 | 29 | # document root for static files 30 | htdocs: ./ 31 | 32 | # proxy definitions 33 | proxy: 34 | - path: /elasticsearch 35 | dest: http://127.0.0.1:9200 36 | strip_path: yes 37 | 38 | - path: /influxdb 39 | dest: http://127.0.0.1:8086 40 | strip_path: yes 41 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "testing" 7 | "github.com/martini-contrib/oauth2" 8 | ) 9 | 10 | func TestParse(t *testing.T) { 11 | f, err := ioutil.TempFile("", "") 12 | if err != nil { 13 | t.Fatal(err) 14 | } 15 | defer func() { 16 | f.Close() 17 | os.Remove(f.Name()) 18 | }() 19 | 20 | data := `--- 21 | address: ":9999" 22 | 23 | auth: 24 | session: 25 | key: secret 26 | 27 | info: 28 | service: 'google' 29 | client_id: 'secret client id' 30 | client_secret: 'secret client secret' 31 | redirect_url: 'http://example.com/oauth2callback' 32 | 33 | htdocs: ./ 34 | 35 | proxy: 36 | - path: /foo 37 | dest: http://example.com/bar 38 | strip_path: yes 39 | ` 40 | if err := ioutil.WriteFile(f.Name(), []byte(data), 0644); err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | conf, err := ParseConf(f.Name()) 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | 49 | if conf.Addr != ":9999" { 50 | t.Errorf("unexpected address:
%s", conf.Addr) 51 | } 52 | } 53 | 54 | func TestParseMultiRestrictions(t *testing.T) { 55 | f, err := ioutil.TempFile("", "") 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | defer func() { 60 | f.Close() 61 | os.Remove(f.Name()) 62 | }() 63 | 64 | data := `--- 65 | address: ":9999" 66 | 67 | auth: 68 | session: 69 | key: secret 70 | 71 | info: 72 | service: 'google' 73 | client_id: 'secret client id' 74 | client_secret: 'secret client secret' 75 | redirect_url: 'http://example.com/oauth2callback' 76 | 77 | htdocs: ./ 78 | 79 | proxy: 80 | - path: /foo 81 | dest: http://example.com/bar 82 | strip_path: yes 83 | 84 | restrictions: 85 | - 'example1.com' 86 | - 'example2.com' 87 | ` 88 | if err := ioutil.WriteFile(f.Name(), []byte(data), 0644); err != nil { 89 | t.Fatal(err) 90 | } 91 | 92 | conf, err := ParseConf(f.Name()) 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | 97 | if len(conf.Restrictions) != 2 { 98 | t.Fatalf("unexpected restrictions num: %d", len(conf.Restrictions)) 99 | } 100 | 101 | if conf.Restrictions[0] != "example1.com" || conf.Restrictions[1] != "example2.com" { 102 | t.Errorf("unexpected restrictions: %+v", conf.Restrictions) 103 | } 104 | } 105 | 106 | func TestParseGithubServiceShouldSetDefaultValue(t *testing.T) { 107 | f, err := ioutil.TempFile("", "") 108 | if err != nil { 109 | t.Fatal(err) 110 | } 111 | defer func() { 112 | f.Close() 113 | os.Remove(f.Name()) 114 | }() 115 | 116 | data := `--- 117 | address: ":9999" 118 | 119 | auth: 120 | session: 121 | key: secret 122 | 123 | info: 124 | service: 'github' 125 | client_id: 'secret client id' 126 | client_secret: 'secret client secret' 127 | redirect_url: 'http://example.com/oauth2callback' 128 | ` 129 | if err := ioutil.WriteFile(f.Name(), []byte(data), 0644); err != nil { 130 | t.Fatal(err) 131 | } 132 | 133 | conf, err := ParseConf(f.Name()) 134 | if err != nil { 135 | t.Fatal(err) 136 | } 137 | 138 | if conf.Auth.Info.Endpoint != "https://github.com" { 139 | t.Errorf("unexpected
endpoint address: %s", conf.Auth.Info.Endpoint) 140 | } 141 | if conf.Auth.Info.ApiEndpoint != "https://api.github.com" { 142 | t.Errorf("unexpected api endpoint address: %s", conf.Auth.Info.ApiEndpoint) 143 | } 144 | } 145 | 146 | func TestParseNamebasedVhosts(t *testing.T) { 147 | f, err := ioutil.TempFile("", "") 148 | if err != nil { 149 | t.Fatal(err) 150 | } 151 | defer func() { 152 | f.Close() 153 | os.Remove(f.Name()) 154 | }() 155 | 156 | data := `--- 157 | address: ":9999" 158 | 159 | auth: 160 | session: 161 | key: secret 162 | cookie_domain: example.com 163 | 164 | info: 165 | service: 'google' 166 | client_id: 'secret client id' 167 | client_secret: 'secret client secret' 168 | redirect_url: 'http://example.com/oauth2callback' 169 | 170 | htdocs: ./ 171 | 172 | proxy: 173 | - path: / 174 | host: elasticsearch.example.com 175 | dest: http://127.0.0.1:9200 176 | - path: / 177 | host: influxdb.example.com 178 | dest: http://127.0.0.1:8086 179 | ` 180 | if err := ioutil.WriteFile(f.Name(), []byte(data), 0644); err != nil { 181 | t.Fatal(err) 182 | } 183 | 184 | conf, err := ParseConf(f.Name()) 185 | if err != nil { 186 | t.Fatal(err) 187 | } 188 | 189 | if conf.Auth.Session.CookieDomain != "example.com" { 190 | t.Errorf("unexpected cookie_domain: %s", conf.Auth.Session.CookieDomain) 191 | } 192 | 193 | if len(conf.Proxies) != 2 { 194 | t.Fatalf("unexpected number of proxy definitions: %d", len(conf.Proxies)) 195 | } 196 | es := conf.Proxies[0] 197 | if es.Path != "/" || es.Host != "elasticsearch.example.com" || es.Dest != "http://127.0.0.1:9200" { 198 | t.Errorf("unexpected proxy[0]: %#v", es) 199 | } 200 | 201 | ifdb := conf.Proxies[1] 202 | if ifdb.Path != "/" || ifdb.Host != "influxdb.example.com" || ifdb.Dest != "http://127.0.0.1:8086" { 203 | t.Errorf("unexpected proxy[1]: %#v", ifdb) 204 | } 205 | } 206 | 207 | func TestPathConf(t *testing.T) { 208 | f, err := ioutil.TempFile("", "") 209 | if err != nil { 210 | t.Fatal(err) 211 | } 212 | defer func() { 213 | f.Close() 214 |
os.Remove(f.Name()) 215 | }() 216 | 217 | data := `--- 218 | address: ":9999" 219 | 220 | auth: 221 | session: 222 | key: secret 223 | 224 | info: 225 | service: 'github' 226 | client_id: 'secret client id' 227 | client_secret: 'secret client secret' 228 | redirect_url: 'http://example.com/_gate_callback' 229 | 230 | paths: 231 | login: "/_gate_login" 232 | logout: "/_gate_logout" 233 | callback: "/_gate_callback" 234 | error: "/_gate_error" 235 | ` 236 | if err := ioutil.WriteFile(f.Name(), []byte(data), 0644); err != nil { 237 | t.Fatal(err) 238 | } 239 | 240 | conf, err := ParseConf(f.Name()) 241 | if err != nil { 242 | t.Fatal(err) 243 | } 244 | 245 | conf.SetOAuth2Paths() 246 | 247 | if oauth2.PathLogin != "/_gate_login" { 248 | t.Errorf("unexpected oauth2.PathLogin: %s", oauth2.PathLogin) 249 | } 250 | if oauth2.PathLogout != "/_gate_logout" { 251 | t.Errorf("unexpected oauth2.PathLogout: %s", oauth2.PathLogout) 252 | } 253 | if oauth2.PathCallback != "/_gate_callback" { 254 | t.Errorf("unexpected oauth2.PathCallback: %s", oauth2.PathCallback) 255 | } 256 | if oauth2.PathError != "/_gate_error" { 257 | t.Errorf("unexpected oauth2.PathError: %s", oauth2.PathError) 258 | } 259 | } 260 | -------------------------------------------------------------------------------- /httpd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/base64" 5 | "io" 6 | "log" 7 | "net" 8 | "net/http" 9 | "net/http/httputil" 10 | "net/url" 11 | "path/filepath" 12 | "strings" 13 | 14 | "github.com/go-martini/martini" 15 | "github.com/martini-contrib/oauth2" 16 | "github.com/martini-contrib/sessions" 17 | ) 18 | 19 | type Server struct { 20 | Conf *Conf 21 | } 22 | 23 | type User struct { 24 | Email string 25 | } 26 | 27 | type Backend struct { 28 | Host string 29 | URL *url.URL 30 | Strip bool 31 | StripPath string 32 | } 33 | 34 | const ( 35 | BackendHostHeader = "X-Gate-Backend-Host" 36 | ) 37 | 38 | func
NewServer(conf *Conf) *Server { 39 | return &Server{conf} 40 | } 41 | 42 | func (s *Server) Run() error { 43 | m := martini.Classic() 44 | 45 | cookieStore := sessions.NewCookieStore([]byte(s.Conf.Auth.Session.Key)) 46 | if domain := s.Conf.Auth.Session.CookieDomain; domain != "" { 47 | cookieStore.Options(sessions.Options{Domain: domain}) 48 | } 49 | m.Use(sessions.Sessions("session", cookieStore)) 50 | 51 | if s.Conf.Auth.Info.Service != noAuthServiceName { 52 | a := NewAuthenticator(s.Conf) 53 | m.Use(a.Handler()) 54 | m.Use(loginRequired()) 55 | m.Use(restrictRequest(s.Conf.Restrictions, a)) 56 | } 57 | 58 | backendsFor := make(map[string][]Backend) 59 | backendIndex := make([]string, len(s.Conf.Proxies)) 60 | rawPaths := make([]string, len(s.Conf.Proxies)) 61 | 62 | for i := range s.Conf.Proxies { 63 | p := s.Conf.Proxies[i] 64 | 65 | rawPath := "" 66 | if strings.HasSuffix(p.Path, "/") == false { 67 | rawPath = p.Path 68 | p.Path += "/" 69 | } 70 | strip_path := p.Path 71 | 72 | if strings.HasSuffix(p.Path, "**") == false { 73 | p.Path += "**" 74 | } 75 | 76 | u, err := url.Parse(p.Dest) 77 | if err != nil { 78 | return err 79 | } 80 | backendsFor[p.Path] = append(backendsFor[p.Path], Backend{ 81 | Host: p.Host, 82 | URL: u, 83 | Strip: p.Strip, 84 | StripPath: strip_path, 85 | }) 86 | backendIndex[i] = p.Path 87 | rawPaths[i] = rawPath 88 | log.Printf("register proxy host:%s path:%s dest:%s strip_path:%v", p.Host, strip_path, u.String(), p.Strip) 89 | } 90 | 91 | registered := make(map[string]bool) 92 | for i, path := range backendIndex { 93 | if registered[path] { 94 | continue 95 | } 96 | proxy := newVirtualHostReverseProxy(backendsFor[path]) 97 | m.Any(path, proxyHandleWrapper(proxy)) 98 | registered[path] = true 99 | rawPath := rawPaths[i] 100 | if rawPath != "" { 101 | m.Get(rawPath, func(w http.ResponseWriter, r *http.Request) { 102 | http.Redirect(w, r, rawPath+"/", http.StatusFound) 103 | }) 104 | } 105 | } 106 | 107 | path, err :=
filepath.Abs(s.Conf.Htdocs) 108 | if err != nil { 109 | return err 110 | } 111 | 112 | log.Printf("starting static file server for: %s", path) 113 | fileServer := http.FileServer(http.Dir(path)) 114 | m.Get("/**", fileServer.ServeHTTP) 115 | 116 | log.Printf("starting server at %s", s.Conf.Addr) 117 | 118 | if s.Conf.SSL.Cert != "" && s.Conf.SSL.Key != "" { 119 | return http.ListenAndServeTLS(s.Conf.Addr, s.Conf.SSL.Cert, s.Conf.SSL.Key, m) 120 | } else { 121 | return http.ListenAndServe(s.Conf.Addr, m) 122 | } 123 | } 124 | 125 | func newVirtualHostReverseProxy(backends []Backend) http.Handler { 126 | bmap := make(map[string]Backend) 127 | for _, b := range backends { 128 | bmap[b.Host] = b 129 | } 130 | defaultBackend, ok := bmap[""] 131 | if !ok { 132 | defaultBackend = backends[0] 133 | } 134 | 135 | director := func(req *http.Request) { 136 | b, ok := bmap[req.Host] 137 | if !ok { 138 | b = defaultBackend 139 | } 140 | req.URL.Scheme = b.URL.Scheme 141 | req.URL.Host = b.URL.Host 142 | if b.Strip { 143 | if p := strings.TrimPrefix(req.URL.Path, b.StripPath); len(p) < len(req.URL.Path) { 144 | req.URL.Path = "/" + p 145 | } 146 | } 147 | req.Header.Set(BackendHostHeader, req.URL.Host) 148 | log.Println("backend url", req.URL.String()) 149 | } 150 | return &httputil.ReverseProxy{Director: director} 151 | } 152 | 153 | func isWebsocket(r *http.Request) bool { 154 | if strings.ToLower(r.Header.Get("Connection")) == "upgrade" && 155 | strings.ToLower(r.Header.Get("Upgrade")) == "websocket" { 156 | return true 157 | } else { 158 | return false 159 | } 160 | } 161 | 162 | func proxyHandleWrapper(handler http.Handler) http.Handler { 163 | proxy, _ := handler.(*httputil.ReverseProxy) 164 | director := proxy.Director 165 | 166 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 167 | // websocket? 
168 | if isWebsocket(r) { 169 | director(r) // rewrite request headers for backend 170 | target := r.Header.Get(BackendHostHeader) 171 | 172 | if strings.HasPrefix(r.URL.Path, "/") == false { 173 | r.URL.Path = "/" + r.URL.Path 174 | } 175 | 176 | log.Printf("proxy ws request: %s", r.URL.String()) 177 | 178 | // websocket proxy by bradfitz https://groups.google.com/forum/#!topic/golang-nuts/KBx9pDlvFOc 179 | d, err := net.Dial("tcp", target) 180 | if err != nil { 181 | http.Error(w, "Error contacting backend server.", 500) 182 | log.Printf("Error dialing websocket backend %s: %v", target, err) 183 | return 184 | } 185 | hj, ok := w.(http.Hijacker) 186 | if !ok { 187 | http.Error(w, "Not a hijacker?", 500) 188 | return 189 | } 190 | nc, _, err := hj.Hijack() 191 | if err != nil { 192 | log.Printf("Hijack error: %v", err) 193 | return 194 | } 195 | defer nc.Close() 196 | defer d.Close() 197 | 198 | err = r.Write(d) 199 | if err != nil { 200 | log.Printf("Error copying request to target: %v", err) 201 | return 202 | } 203 | 204 | errc := make(chan error, 2) 205 | cp := func(dst io.Writer, src io.Reader) { 206 | _, err := io.Copy(dst, src) 207 | errc <- err 208 | } 209 | go cp(d, nc) 210 | go cp(nc, d) 211 | for i := 0; i < cap(errc); i++ { 212 | <-errc 213 | } 214 | } else { 215 | handler.ServeHTTP(w, r) 216 | } 217 | }) 218 | } 219 | 220 | // base64Decode decodes the Base64url encoded string 221 | // 222 | // steel from code.google.com/p/goauth2/oauth/jwt 223 | func base64Decode(s string) ([]byte, error) { 224 | // add back missing padding 225 | switch len(s) % 4 { 226 | case 2: 227 | s += "==" 228 | case 3: 229 | s += "=" 230 | } 231 | return base64.URLEncoding.DecodeString(s) 232 | } 233 | 234 | func restrictRequest(restrictions []string, authenticator Authenticator) martini.Handler { 235 | return func(c martini.Context, tokens oauth2.Tokens, w http.ResponseWriter, r *http.Request) { 236 | // skip websocket 237 | if isWebsocket(r) { 238 | return 239 | } 240 | 241 | 
authenticator.Authenticate(restrictions, c, tokens, w, r) 242 | } 243 | } 244 | 245 | func loginRequired() martini.Handler { 246 | return func(s sessions.Session, c martini.Context, w http.ResponseWriter, r *http.Request) { 247 | if isWebsocket(r) { 248 | return 249 | } 250 | c.Invoke(oauth2.LoginRequired) 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /httpd_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "net/http" 7 | "os" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestPrepareFoo(t *testing.T) { 13 | http.HandleFunc("/foo/", func(w http.ResponseWriter, r *http.Request) { 14 | fmt.Fprint(w, "hello Foo\n") 15 | }) 16 | go func() { 17 | err := http.ListenAndServe(":10001", nil) 18 | if err != nil { 19 | t.Error(err) 20 | } 21 | }() 22 | time.Sleep(1 * time.Second) 23 | } 24 | 25 | func TestPrepareBar(t *testing.T) { 26 | http.HandleFunc("/bar/", func(w http.ResponseWriter, r *http.Request) { 27 | fmt.Fprint(w, "hello Bar\n") 28 | }) 29 | go func() { 30 | err := http.ListenAndServe(":10002", nil) 31 | if err != nil { 32 | t.Error(err) 33 | } 34 | }() 35 | time.Sleep(1 * time.Second) 36 | } 37 | 38 | func TestRunHTTPd(t *testing.T) { 39 | f, err := ioutil.TempFile("", "") 40 | if err != nil { 41 | t.Error(err) 42 | } 43 | defer func() { 44 | f.Close() 45 | os.Remove(f.Name()) 46 | }() 47 | data := ` 48 | address: "127.0.0.1:9999" 49 | auth: 50 | session: 51 | key: dummy 52 | info: 53 | service: nothing 54 | client_id: dummy 55 | client_secret: dummy 56 | redirect_url: "http://example.com/oauth2callback" 57 | proxy: 58 | - path: /foo 59 | dest: http://127.0.0.1:10001 60 | strip_path: no 61 | 62 | - path: /bar 63 | dest: http://127.0.0.1:10002 64 | strip_path: no 65 | ` 66 | if err := ioutil.WriteFile(f.Name(), []byte(data), 0644); err != nil { 67 | t.Error(err) 68 | } 69 | conf, err := ParseConf(f.Name()) 
70 | if err != nil { 71 | t.Error(err) 72 | } 73 | server := NewServer(conf) 74 | if server == nil { 75 | t.Error("NewServer failed") 76 | } 77 | go server.Run() 78 | time.Sleep(1 * time.Second) 79 | 80 | // backend foo 81 | if res, err := http.Get("http://127.0.0.1:9999/foo/"); err == nil { 82 | defer res.Body.Close() 83 | body, _ := ioutil.ReadAll(res.Body) 84 | if string(body) != "hello Foo\n" { 85 | t.Errorf("unexpected foo body %s", body) 86 | } 87 | } else { 88 | t.Error(err) 89 | } 90 | 91 | // backend bar 92 | if res, err := http.Get("http://127.0.0.1:9999/bar/"); err == nil { 93 | defer res.Body.Close() 94 | body, _ := ioutil.ReadAll(res.Body) 95 | if string(body) != "hello Bar\n" { 96 | t.Errorf("unexpected bar body %s", body) 97 | } 98 | } else { 99 | t.Error(err) 100 | } 101 | } 102 | 103 | func TestRunVhost(t *testing.T) { 104 | f, err := ioutil.TempFile("", "") 105 | if err != nil { 106 | t.Error(err) 107 | } 108 | defer func() { 109 | f.Close() 110 | os.Remove(f.Name()) 111 | }() 112 | data := ` 113 | address: "127.0.0.1:10000" 114 | auth: 115 | session: 116 | key: dummy 117 | cookie_domain: example.com 118 | info: 119 | service: nothing 120 | client_id: dummy 121 | client_secret: dummy 122 | redirect_url: "http://example.com/oauth2callback" 123 | proxy: 124 | - path: / 125 | dest: http://127.0.0.1:10001 126 | strip_path: no 127 | host: foo.example.com 128 | 129 | - path: / 130 | dest: http://127.0.0.1:10002 131 | strip_path: no 132 | host: bar.example.com 133 | ` 134 | if err := ioutil.WriteFile(f.Name(), []byte(data), 0644); err != nil { 135 | t.Error(err) 136 | } 137 | conf, err := ParseConf(f.Name()) 138 | if err != nil { 139 | t.Error(err) 140 | } 141 | server := NewServer(conf) 142 | if server == nil { 143 | t.Error("NewServer failed") 144 | } 145 | go server.Run() 146 | time.Sleep(1 * time.Second) 147 | 148 | var req *http.Request 149 | client := &http.Client{} 150 | 151 | // backend foo 152 | req, _ = http.NewRequest("GET", 
"http://127.0.0.1:10000/foo/", nil) 153 | req.Header.Add("Host", "foo.example.com") 154 | if res, err := client.Do(req); err == nil { 155 | defer res.Body.Close() 156 | body, _ := ioutil.ReadAll(res.Body) 157 | if string(body) != "hello Foo\n" { 158 | t.Errorf("unexpected foo body %s", body) 159 | } 160 | } else { 161 | t.Error(err) 162 | } 163 | 164 | // backend bar 165 | req, _ = http.NewRequest("GET", "http://127.0.0.1:10000/bar/", nil) 166 | req.Header.Add("Host", "bar.example.com") 167 | if res, err := http.Get("http://127.0.0.1:10000/bar/"); err == nil { 168 | defer res.Body.Close() 169 | body, _ := ioutil.ReadAll(res.Body) 170 | if string(body) != "hello Bar\n" { 171 | t.Errorf("unexpected bar body %s", body) 172 | } 173 | } else { 174 | t.Error(err) 175 | } 176 | 177 | } 178 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | ) 7 | 8 | var ( 9 | confFile = flag.String("conf", "config.yml", "config file path") 10 | ) 11 | 12 | func main() { 13 | flag.Parse() 14 | 15 | conf, err := ParseConf(*confFile) 16 | if err != nil { 17 | panic(err) 18 | } 19 | 20 | conf.SetOAuth2Paths() 21 | 22 | server := NewServer(conf) 23 | log.Fatal(server.Run()) 24 | } 25 | --------------------------------------------------------------------------------