├── testdata ├── i18n │ ├── messages │ │ ├── invalid_message_file_name.txt │ │ ├── english_messages2.en │ │ ├── dutch_messages.nl │ │ └── english_messages.en │ └── config │ │ └── test_app.conf ├── conf │ ├── mime-types.conf │ ├── routes │ └── app.conf ├── views │ ├── i18n.html │ ├── i18n_ctx.html │ ├── footer.html │ ├── hotels │ │ └── show.html │ └── header.html └── public │ └── js │ └── sessvars.js ├── templates └── errors │ ├── 404.xml │ ├── 403.txt │ ├── 403.xml │ ├── 404.txt │ ├── 405.txt │ ├── 405.xml │ ├── 403.json │ ├── 404.json │ ├── 405.json │ ├── 500.json │ ├── 500.xml │ ├── 403.html │ ├── 405.html │ ├── 500.txt │ ├── 500.html │ ├── 404.html │ ├── 404-dev.html │ └── 500-dev.html ├── .gitignore ├── docs ├── faq.md ├── code-generation.md ├── getting-started.md ├── testing.md ├── index.md └── migration.md ├── examples └── servethis │ └── main.go ├── mkdocs.yml ├── watchfilter.go ├── sanitize.go ├── sign_fuzz_test.go ├── go.mod ├── cookie.go ├── reflection.go ├── panic.go ├── cmd └── mars-gen │ ├── filesorting.go │ ├── main.go │ └── main_test.go ├── cert_test.go ├── mime_test.go ├── reflection_test.go ├── session_fuzz_test.go ├── go.sum ├── sanitize_test.go ├── internal ├── pathtree │ ├── LICENSE │ └── tree_test.go └── watcher │ ├── watcher_test.go │ └── watcher.go ├── LICENSE ├── .github └── workflows │ ├── build-and-test.yml │ └── codeql-analysis.yml ├── filter.go ├── invoker.go ├── testing ├── equal.go └── equal_test.go ├── results_test.go ├── sign.go ├── compress_test.go ├── panic_test.go ├── field.go ├── session_test.go ├── cert.go ├── sign_test.go ├── flash.go ├── hooks.go ├── validation_test.go ├── fakeapp_test.go ├── errors.go ├── intercept_test.go ├── server_test.go ├── templates_test.go ├── config.go ├── invoker_test.go ├── filterconfig_test.go ├── params.go ├── static.go ├── README.md ├── csrf.go ├── validators.go ├── server.go ├── http.go ├── params_test.go ├── compress.go ├── session.go ├── intercept.go ├── CHANGELOG.md ├── filterconfig.go ├── 
validation.go ├── i18n.go └── validators_test.go /testdata/i18n/messages/invalid_message_file_name.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /testdata/conf/mime-types.conf: -------------------------------------------------------------------------------- 1 | bkng=application/x-booking -------------------------------------------------------------------------------- /testdata/views/i18n.html: -------------------------------------------------------------------------------- 1 | {{msg $ `arguments.html` .input}} -------------------------------------------------------------------------------- /testdata/views/i18n_ctx.html: -------------------------------------------------------------------------------- 1 | {{t `arguments.html` .input}} -------------------------------------------------------------------------------- /testdata/i18n/messages/english_messages2.en: -------------------------------------------------------------------------------- 1 | greeting2=Yo! 
2 | -------------------------------------------------------------------------------- /testdata/public/js/sessvars.js: -------------------------------------------------------------------------------- 1 | console.log('Test file'); 2 | -------------------------------------------------------------------------------- /templates/errors/404.xml: -------------------------------------------------------------------------------- 1 | {{.Error.Description}} 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/ 2 | routes/ 3 | test-results/ 4 | 5 | # editor 6 | *.swp 7 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | 3 | **WORK IN PROGRESS** 4 | -------------------------------------------------------------------------------- /templates/errors/403.txt: -------------------------------------------------------------------------------- 1 | {{.Error.Title}} 2 | 3 | {{.Error.Description}} 4 | -------------------------------------------------------------------------------- /templates/errors/403.xml: -------------------------------------------------------------------------------- 1 | {{.Error.Description}} 2 | -------------------------------------------------------------------------------- /templates/errors/404.txt: -------------------------------------------------------------------------------- 1 | {{.Error.Title}} 2 | 3 | {{.Error.Description}} 4 | -------------------------------------------------------------------------------- /templates/errors/405.txt: -------------------------------------------------------------------------------- 1 | {{.Error.Title}} 2 | 3 | {{.Error.Description}} 4 | -------------------------------------------------------------------------------- 
/docs/code-generation.md: -------------------------------------------------------------------------------- 1 | # Code generation with mars-gen 2 | 3 | **WORK IN PROGRESS** 4 | -------------------------------------------------------------------------------- /templates/errors/405.xml: -------------------------------------------------------------------------------- 1 | {{.Error.Description}} 2 | -------------------------------------------------------------------------------- /templates/errors/403.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "{{js .Error.Title}}", 3 | "description": "{{js .Error.Description}}" 4 | } 5 | -------------------------------------------------------------------------------- /templates/errors/404.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "{{js .Error.Title}}", 3 | "description": "{{js .Error.Description}}" 4 | } 5 | -------------------------------------------------------------------------------- /templates/errors/405.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "{{js .Error.Title}}", 3 | "description": "{{js .Error.Description}}" 4 | } 5 | -------------------------------------------------------------------------------- /templates/errors/500.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "{{js .Error.Title}}", 3 | "description": "{{js .Error.Description}}" 4 | } 5 | -------------------------------------------------------------------------------- /templates/errors/500.xml: -------------------------------------------------------------------------------- 1 | 2 | {{.Error.Title}} 3 | {{.Error.Description}} 4 | 5 | -------------------------------------------------------------------------------- /examples/servethis/main.go: -------------------------------------------------------------------------------- 1 | 
package main 2 | 3 | import ( 4 | "github.com/roblillack/mars" 5 | ) 6 | // main starts the Mars application by handing control to mars.Run. 7 | func main() { 8 | mars.Run() 9 | } 10 | -------------------------------------------------------------------------------- /testdata/i18n/messages/dutch_messages.nl: -------------------------------------------------------------------------------- 1 | greeting=Hallo 2 | greeting.name=Rob 3 | greeting.suffix=, welkom bij Mars! 4 | 5 | [NL] 6 | greeting=Goeiedag 7 | 8 | [BE] 9 | greeting=Hallokes 10 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Mars documentation 2 | pages: 3 | - Home: index.md 4 | - Getting Started: getting-started.md 5 | - Migrating from Revel: migration.md 6 | - Code generation: code-generation.md 7 | - Testing: testing.md 8 | - F.A.Q.: faq.md 9 | theme: readthedocs -------------------------------------------------------------------------------- /watchfilter.go: -------------------------------------------------------------------------------- 1 | package mars 2 | // WatchFilter calls mainWatcher.Notify (only when a watcher is set) before running the rest of the filter chain; if Notify returns an error, that error is rendered as the result and the chain is not continued. 3 | var WatchFilter = func(c *Controller, fc []Filter) { 4 | if mainWatcher != nil { 5 | err := mainWatcher.Notify() 6 | if err != nil { 7 | c.Result = c.RenderError(err) 8 | return 9 | } 10 | } 11 | fc[0](c, fc[1:]) 12 | } 13 | -------------------------------------------------------------------------------- /templates/errors/403.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Forbidden 5 | 6 | 7 | {{with .Error}} 8 |

9 | {{.Title}} 10 |

11 |

12 | {{.Description}} 13 |

14 | {{end}} 15 | 16 | 17 | -------------------------------------------------------------------------------- /templates/errors/405.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Method not allowed 5 | 6 | 7 | {{with .Error}} 8 |

9 | {{.Title}} 10 |

11 |

12 | {{.Description}} 13 |

14 | {{end}} 15 | 16 | 17 | -------------------------------------------------------------------------------- /sanitize.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import "regexp" 4 | // lineBreakPattern matches a run of one or more carriage-return and/or line-feed characters. 5 | var lineBreakPattern = regexp.MustCompile(`[\r\n]+`) 6 | // removeLineBreaks replaces every run of CR/LF characters in s with a single space. 7 | func removeLineBreaks(s string) string { 8 | return lineBreakPattern.ReplaceAllString(s, " ") 9 | } 10 | // removeAllWhitespace deletes every match of whiteSpacePattern from s. NOTE(review): whiteSpacePattern is not declared in this file; presumably it is a package-level whitespace regexp defined elsewhere in package mars (sanitize_test.go expects spaces, CR and LF to be stripped). 11 | func removeAllWhitespace(s string) string { 12 | return whiteSpacePattern.ReplaceAllString(s, "") 13 | } 14 | -------------------------------------------------------------------------------- /templates/errors/500.txt: -------------------------------------------------------------------------------- 1 | {{.Error.Title}} 2 | {{.Error.Description}} 3 | 4 | {{if eq .RunMode "dev"}} 5 | {{with .Error}} 6 | {{if .Path}} 7 | ---------- 8 | In {{.Path}} {{if .Line}}(around line {{.Line}}){{end}} 9 | 10 | {{range .ContextSource}} 11 | {{if .IsError}}>{{else}} {{end}} {{.Line}}: {{.Source}}{{end}} 12 | 13 | {{end}} 14 | {{end}} 15 | {{end}} 16 | -------------------------------------------------------------------------------- /sign_fuzz_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package mars 5 | 6 | import ( 7 | "testing" 8 | ) 9 | // FuzzSignatureVerification passes arbitrary signature strings to Verify for the fixed message "Untouchable" after installing a key from generateRandomSecretKey; the result is discarded, so the target can only fail by panicking. 10 | func FuzzSignatureVerification(f *testing.F) { 11 | secretKey = generateRandomSecretKey() 12 | f.Add("4UcW-3rLvaGGxmA2KUPQgS30MVK7ESKKEPhs4Gir_-E") 13 | f.Fuzz(func(t *testing.T, sig string) { 14 | Verify("Untouchable", sig) 15 | }) 16 | } 17 | -------------------------------------------------------------------------------- /templates/errors/500.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Application error 5 | 6 | 7 | {{if .DevMode}} 8 | {{template "errors/500-dev.html" .}} 9 | {{else}} 10 |

Oops, an error occurred.

11 |

{{.Error.Title}}

12 |

13 | {{.Error.Description}} 14 |

15 | {{end}} 16 | 17 | 18 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/roblillack/mars 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | github.com/agtorre/gocolorize v1.0.0 9 | github.com/codegangsta/cli v1.20.0 10 | github.com/fsnotify/fsnotify v1.7.0 11 | github.com/robfig/config v0.0.0-20141207224736-0f78529c8c7e 12 | golang.org/x/net v0.38.0 13 | ) 14 | 15 | require golang.org/x/sys v0.31.0 // indirect 16 | -------------------------------------------------------------------------------- /templates/errors/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Not found 5 | 6 | 7 | 8 | {{if .DevMode}} 9 | 10 | {{template "errors/404-dev.html" .}} 11 | 12 | {{else}} 13 | 14 | {{with .Error}} 15 |

16 | {{.Title}} 17 |

18 |

19 | {{.Description}} 20 |

21 | {{end}} 22 | 23 | {{end}} 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /cookie.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "net/url" 5 | "regexp" 6 | ) 7 | // cookieKeyValueParser matches one NUL-delimited "key:value" entry (\x00key:value\x00); the key cannot contain ':' and the value cannot contain NUL. 8 | var ( 9 | cookieKeyValueParser = regexp.MustCompile("\x00([^:]*):([^\x00]*)\x00") 10 | ) 11 | 12 | // parseKeyValueCookie URL-unescapes the raw cookie value and calls cb with each key/value entry, in order of appearance. The unescape error is discarded: for malformed escapes url.QueryUnescape returns an empty string, so cb is not called at all. 13 | func parseKeyValueCookie(val string, cb func(key, val string)) { 14 | val, _ = url.QueryUnescape(val) 15 | if matches := cookieKeyValueParser.FindAllStringSubmatch(val, -1); matches != nil { 16 | for _, match := range matches { 17 | cb(match[1], match[2]) 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /testdata/views/footer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /reflection.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | // findMethod returns the method of recvType whose Func has the same code pointer as funcVal, or nil if no method matches. 8 | func findMethod(recvType reflect.Type, funcVal reflect.Value) *reflect.Method { 9 | // It is not possible to get the name of the method from the Func. 10 | // Instead, compare it to each method of the receiver type.
11 | for i := 0; i < recvType.NumMethod(); i++ { 12 | method := recvType.Method(i) 13 | if method.Func.Pointer() == funcVal.Pointer() { 14 | return &method 15 | } 16 | } 17 | return nil 18 | } 19 | -------------------------------------------------------------------------------- /testdata/i18n/messages/english_messages.en: -------------------------------------------------------------------------------- 1 | greeting=Hello 2 | greeting.name=Rob 3 | greeting.suffix=, welcome to Mars! 4 | 5 | folded=Greeting is '%(greeting)s' 6 | folded.arguments=%(greeting.name)s is %d years old 7 | 8 | arguments.string=My name is %s 9 | arguments.hex=The number %d in hexadecimal notation would be %x 10 | arguments.none=No arguments here son 11 | arguments.html=

Hey, there %s!

12 | 13 | only_exists_in_default=Default 14 | 15 | [AU] 16 | greeting=G'day 17 | 18 | [US] 19 | greeting=Howdy 20 | 21 | [GB] 22 | greeting=All right -------------------------------------------------------------------------------- /testdata/i18n/config/test_app.conf: -------------------------------------------------------------------------------- 1 | app.name={{ .AppName }} 2 | app.secret={{ .Secret }} 3 | http.addr= 4 | http.port=9000 5 | cookie.prefix=MARS 6 | 7 | i18n.default_language=en 8 | i18n.cookie=APP_LANG 9 | 10 | [dev] 11 | results.pretty=true 12 | results.staging=true 13 | watch=true 14 | 15 | log.trace.output = off 16 | log.info.output = stderr 17 | log.warn.output = stderr 18 | log.error.output = stderr 19 | 20 | [prod] 21 | results.pretty=false 22 | results.staging=false 23 | watch=false 24 | 25 | log.trace.output = off 26 | log.info.output = off 27 | log.warn.output = %(app.name)s.log 28 | log.error.output = %(app.name)s.log 29 | -------------------------------------------------------------------------------- /panic.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "runtime/debug" 6 | ) 7 | 8 | // PanicFilter wraps the action invocation in a protective defer blanket that 9 | // converts panics into 500 "Runtime Error" pages.
10 | func PanicFilter(c *Controller, fc []Filter) { 11 | defer func() { 12 | if err := recover(); err != nil { 13 | e := &Error{ 14 | Title: "Runtime Error", 15 | Description: fmt.Sprint(err), 16 | } 17 | // The stack trace is only captured in dev mode. 18 | if DevMode { 19 | e.Stack = string(debug.Stack()) 20 | } 21 | // Log the error (with the stack, if captured) and replace the result with the rendered error. 22 | ERROR.Println(e, "\n", e.Stack) 23 | c.Result = c.RenderError(e) 24 | } 25 | }() 26 | fc[0](c, fc[1:]) 27 | } 28 | -------------------------------------------------------------------------------- /cmd/mars-gen/filesorting.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "go/ast" 5 | "sort" 6 | ) 7 | // fInfo pairs a parsed file with the name it is stored under in ast.Package.Files. 8 | type fInfo struct { 9 | Filename string 10 | File *ast.File 11 | } 12 | // byName implements sort.Interface, ordering fInfo values by Filename. 13 | type byName []fInfo 14 | 15 | func (s byName) Len() int { 16 | return len(s) 17 | } 18 | func (s byName) Swap(i, j int) { 19 | s[i], s[j] = s[j], s[i] 20 | } 21 | func (s byName) Less(i, j int) bool { 22 | return s[i].Filename < s[j].Filename 23 | } 24 | // getSortedFiles returns the files of pkg as fInfo values ordered by file name. 25 | func getSortedFiles(pkg *ast.Package) []fInfo { 26 | entries := make([]fInfo, 0, len(pkg.Files)) 27 | for fn, f := range pkg.Files { 28 | entries = append(entries, fInfo{Filename: fn, File: f}) 29 | } 30 | // pkg.Files is a map with random iteration order; sort to get a deterministic result. 31 | sort.Sort(byName(entries)) 32 | 33 | return entries 34 | } 35 | -------------------------------------------------------------------------------- /testdata/conf/routes: -------------------------------------------------------------------------------- 1 | # Routes 2 | # This file defines all application routes (Higher priority routes first) 3 | # ~~~~ 4 | 5 | GET /hotels Hotels.Index 6 | GET /hotels/:id Hotels.Show 7 | GET /hotels/:id/booking Hotels.Book 8 | GET /boom Hotels.Boom 9 | 10 | # Map static resources from the /app/public folder to the /public path 11 | GET /public/*filepath Static.Serve("public") 12 | GET /favicon.ico Static.Serve("public/img","favicon.png") 13 | 14 | # Catch all 15 | * /:controller/:action :controller.:action 16 |
-------------------------------------------------------------------------------- /cert_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "crypto/x509" 5 | "strings" 6 | "testing" 7 | ) 8 | // TestCertificateCreation creates a certificate for each organization from a comma-separated host list and checks that the parsed certificate validates every listed host, including IP addresses. 9 | func TestCertificateCreation(t *testing.T) { 10 | for org, domains := range map[string][]string{ 11 | "ACME Inc.": {"acme.com", "acme.biz"}, 12 | "Me": {"::1", "127.0.0.1"}, 13 | } { 14 | keypair, err := createCertificate(org, strings.Join(domains, ", ")) 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | 19 | cert, err := x509.ParseCertificate(keypair.Certificate[0]) 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | 24 | for _, i := range domains { 25 | if err := cert.VerifyHostname(i); err != nil { 26 | t.Errorf("Unable to validate host %s for %s: %s", i, org, err) 27 | } 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /testdata/conf/app.conf: -------------------------------------------------------------------------------- 1 | # Application 2 | app.name=Booking example 3 | app.secret=secret 4 | 5 | # Server 6 | http.addr= 7 | http.port=9000 8 | http.ssl=false 9 | http.sslcert= 10 | http.sslkey= 11 | 12 | # Logging 13 | log.trace.output = stderr 14 | log.info.output = stderr 15 | log.warn.output = stderr 16 | log.error.output = stderr 17 | 18 | log.trace.prefix = "TRACE " 19 | log.info.prefix = "INFO " 20 | log.warn.prefix = "WARN " 21 | log.error.prefix = "ERROR " 22 | 23 | db.import = github.com/mattn/go-sqlite3 24 | db.driver = sqlite3 25 | db.spec = :memory: 26 | 27 | build.tags=gorp 28 | 29 | [dev] 30 | mode.dev=true 31 | watch=true 32 | 33 | [prod] 34 | watch=false 35 | 36 | log.trace.output = off 37 | log.info.output = off 38 | log.warn.output = stderr 39 | log.error.output = stderr 40 | -------------------------------------------------------------------------------- /testdata/views/hotels/show.html: -------------------------------------------------------------------------------- 1 | {{template "header.html" .}} 2 | 3 |
-------------------------------------------------------------------------------- 1 | {{template "header.html" .}} 2 | 3 |

View hotel

4 | 5 | {{with .hotel}} 6 |
7 | 8 |

9 | Name: {{.Name}} 10 |

11 |

12 | Address: {{.Address}} 13 |

14 |

15 | City: {{.City}} 16 |

17 |

18 | State: {{.State}} 19 |

20 |

21 | Zip: {{.Zip}} 22 |

23 |

24 | Country: {{.Country}} 25 |

26 |

27 | Nightly rate: {{.Price}} 28 |

29 | 30 |

31 | 32 | Back to search 33 |

34 |
35 | {{end}} 36 | 37 | {{template "footer.html" .}} 38 | -------------------------------------------------------------------------------- /docs/getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting started with Mars 2 | 3 | **WORK IN PROGRESS** 4 | 5 | There is _no_ fixed directory hierarchy with Mars, but a projects' structure typically looks like this: 6 | 7 | ``` 8 | - myProject (Your project) 9 | | 10 | |-- main.go (Your main go code might live here, but can also be in ./cmd/something) 11 | | 12 | |-- conf (Directory with configuration files which are needed at runtime) 13 | | | 14 | | |-- app.conf (main configuration file) 15 | | | 16 | | +-- routes (Configuration of the routes) 17 | | 18 | |-- views (All the view templates are here) 19 | | | 20 | | |-- hotel (view templates for the “Hotel” controller) 21 | | | 22 | | +-- other (view templates for the “Other” controller) 23 | | 24 | +-- mySubpackage (Your code is in arbitrary sub packages) 25 | | 26 | +-- foo.go 27 | ``` 28 | -------------------------------------------------------------------------------- /mime_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "testing" 5 | ) 6 | // TestContentTypeByFilename checks the MIME types derived from file names, including names with no extension or an empty one. 7 | func TestContentTypeByFilename(t *testing.T) { 8 | testCases := map[string]string{ 9 | "xyz.jpg": "image/jpeg", 10 | "helloworld.c": "text/x-c; charset=utf-8", 11 | "helloworld.": "application/octet-stream", 12 | "helloworld": "application/octet-stream", 13 | "hello.world.c": "text/x-c; charset=utf-8", 14 | } 15 | for filename, expected := range testCases { 16 | actual := ContentTypeByFilename(filename) 17 | if actual != expected { 18 | t.Errorf("%s: %s, Expected %s", filename, actual, expected) 19 | } 20 | } 21 | } 22 | // TestCustomMimeTypes checks that the custom "bkng=application/x-booking" mapping from testdata/conf/mime-types.conf is honored after startFakeBookingApp (presumably that call loads it; confirm). 23 | func TestCustomMimeTypes(t *testing.T) { 24 | startFakeBookingApp() 25 | 26 | if ct := ContentTypeByFilename("B1F1AA4C-8156-4649-9248-0DE19BD63164.bkng"); ct != 
"application/x-booking" { 27 | t.Errorf("Wrong MIME type returned: %s", ct) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /reflection_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | // T is a minimal receiver type used by TestFindMethod. 8 | type T struct{} 9 | // Hello is a no-op method that TestFindMethod looks up via findMethod. 10 | func (t *T) Hello() {} 11 | // TestFindMethod checks that findMethod maps a method expression back to the named method of the given receiver type, and returns nil for a function that is not a method of that type. 12 | func TestFindMethod(t *testing.T) { 13 | for name, tv := range map[string]struct { 14 | reflect.Type 15 | reflect.Value 16 | }{ 17 | "Hello": {reflect.TypeOf(&T{}), reflect.ValueOf((*T).Hello)}, 18 | "Helper": {reflect.TypeOf(t), reflect.ValueOf((*testing.T).Helper)}, 19 | "": {reflect.TypeOf(t), reflect.ValueOf((reflect.Type).Comparable)}, 20 | } { 21 | m := findMethod(tv.Type, tv.Value) 22 | if name == "" { 23 | if m != nil { 24 | t.Errorf("method found that shouldn't be here: %v", m) 25 | } 26 | continue 27 | } 28 | if m == nil { 29 | t.Errorf("No method found when looking for %s", name) 30 | continue 31 | } 32 | if m.Name != name { 33 | t.Errorf("Expected method %s, got %s: %v", name, m.Name, m) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /session_fuzz_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package mars 5 | 6 | //go 7 | 8 | import ( 9 | "fmt" 10 | "net/http" 11 | "testing" 12 | ) 13 | // makeCookie builds a session from args (each value formatted with fmt.Sprint), applies the default expiration and returns the value of the resulting session cookie. 14 | func makeCookie(args Args) string { 15 | session := make(Session) 16 | session.SetDefaultExpiration() 17 | for k, v := range args { 18 | session[k] = fmt.Sprint(v) 19 | } 20 | return session.Cookie().Value 21 | } 22 | // FuzzSessionDecoding feeds arbitrary cookie values to GetSessionFromCookie, seeded with cookies built by makeCookie, and fails if a nil session is ever returned. 23 | func FuzzSessionDecoding(f *testing.F) { 24 | secretKey = generateRandomSecretKey() 25 | 26 | f.Add(makeCookie(Args{"username": "roblillack"})) 27 | f.Add(makeCookie(Args{"username": "roblillack", "lang": "de"})) 28 | f.Add(makeCookie(Args{"username": "roblillack", "lang": "de", 
"orientation": "portrait"})) 29 | f.Add(makeCookie(Args{"username": "roblillack", "bw": true})) 30 | f.Add(makeCookie(Args{"no": 28963473, "bw": true})) 31 | 32 | f.Fuzz(func(t *testing.T, cookieContent string) { 33 | cookie := &http.Cookie{Value: cookieContent} 34 | if session := GetSessionFromCookie(cookie); session == nil { 35 | t.Fail() 36 | } 37 | }) 38 | } 39 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/agtorre/gocolorize v1.0.0 h1:TvGQd+fAqWQlDjQxSKe//Y6RaxK+RHpEU9X/zPmHW50= 2 | github.com/agtorre/gocolorize v1.0.0/go.mod h1:cH6imfTkHVBRJhSOeSeEZhB4zqEYSq0sXuIyehgZMIY= 3 | github.com/codegangsta/cli v1.20.0 h1:iX1FXEgwzd5+XN6wk5cVHOGQj6Q3Dcp20lUeS4lHNTw= 4 | github.com/codegangsta/cli v1.20.0/go.mod h1:/qJNoX69yVSKu5o4jLyXAENLRyk1uhi7zkbQ3slBdOA= 5 | github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= 6 | github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= 7 | github.com/robfig/config v0.0.0-20141207224736-0f78529c8c7e h1:3/9k/etUfgykjM3Rx8X0echJzo7gNNeND/ubPkqYw1k= 8 | github.com/robfig/config v0.0.0-20141207224736-0f78529c8c7e/go.mod h1:Zerq1qYbCKtIIU9QgPydffGlpYfZ8KI/si49wuTLY/Q= 9 | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= 10 | golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 11 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 12 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 13 | -------------------------------------------------------------------------------- /sanitize_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "testing" 5 | ) 6 | // TestRemovingLineBreaks checks that each run of CR/LF characters is collapsed into a single space. 7 | func TestRemovingLineBreaks(t *testing.T) { 8 | for i, exp := range map[string]string{ 9 | "This is a 
test.": "This is a test.", 10 | "This is\n a test.": "This is a test.", 11 | "This is\r a test.": "This is a test.", 12 | "This is\r\n a test.": "This is a test.", 13 | "\n\n\n\n\nThis is\r a test.": " This is a test.", 14 | } { 15 | if res := removeLineBreaks(i); res != exp { 16 | t.Errorf("Unexpected result '%s' when removing line breaks from '%s'.\n", res, i) 17 | } 18 | } 19 | } 20 | // TestRemovingAllWhitespace checks that spaces, CR and LF characters are all removed. 21 | func TestRemovingAllWhitespace(t *testing.T) { 22 | for i, exp := range map[string]string{ 23 | "This is a test.": "Thisisatest.", 24 | "This is\n a test.": "Thisisatest.", 25 | "This is\r a test.": "Thisisatest.", 26 | "This is\r\n a test.": "Thisisatest.", 27 | "\n\n\n\n\nThis is\r a test.": "Thisisatest.", 28 | } { 29 | if res := removeAllWhitespace(i); res != exp { 30 | t.Errorf("Unexpected result '%s' when removing all whitespace from '%s'.\n", res, i) 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /internal/pathtree/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2013 Rob Figueiredo 2 | All Rights Reserved. 3 | 4 | MIT LICENSE 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | this software and associated documentation files (the "Software"), to deal in 8 | the Software without restriction, including without limitation the rights to 9 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 10 | the Software, and to permit persons to whom the Software is furnished to do so, 11 | subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 18 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR 19 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 20 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 21 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015–2021 Rob Lillack 2 | Copyright (C) 2012 Rob Figueiredo 3 | All Rights Reserved. 4 | 5 | MIT LICENSE 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of 8 | this software and associated documentation files (the "Software"), to deal in 9 | the Software without restriction, including without limitation the rights to 10 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 11 | the Software, and to permit persons to whom the Software is furnished to do so, 12 | subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 19 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 20 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 21 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 22 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
23 | -------------------------------------------------------------------------------- /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | name: Build & Test 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | go-version: ["1.22", "1.23", "1.24"] 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Setup Go ${{ matrix.go-version }} 20 | uses: actions/setup-go@v5 21 | with: 22 | go-version: ${{ matrix.go-version }} 23 | 24 | - name: Build 25 | run: go build -v ./... 26 | 27 | - name: Test 28 | run: go test -v ./... 29 | 30 | - name: Code coverage analysis 31 | run: go test -v -coverprofile=profile.cov ./... 32 | 33 | - name: Send coverage 34 | uses: shogo82148/actions-goveralls@v1 35 | with: 36 | path-to-profile: profile.cov 37 | flag-name: Go-${{ matrix.go-version }} 38 | parallel: true 39 | 40 | finish: 41 | needs: build 42 | runs-on: ubuntu-latest 43 | steps: 44 | - uses: shogo82148/actions-goveralls@v1 45 | with: 46 | parallel-finished: true 47 | -------------------------------------------------------------------------------- /filter.go: -------------------------------------------------------------------------------- 1 | package mars 2 | // Filter is one stage of request processing: it receives the controller and the filters that follow it; filters such as PanicFilter and WatchFilter continue processing by calling filterChain[0](c, filterChain[1:]). 3 | type Filter func(c *Controller, filterChain []Filter) 4 | 5 | // Filters is the default set of global filters. 6 | // It may be set by the application on initialization. 7 | var Filters = []Filter{ 8 | PanicFilter, // Recover from panics and display an error page instead. 9 | RouterFilter, // Use the routing table to select the right Action. 10 | FilterConfiguringFilter, // A hook for adding or removing per-Action filters. 11 | ParamsFilter, // Parse parameters into Controller.Params. 12 | SessionFilter, // Restore and write the session cookie. 13 | FlashFilter, // Restore and write the flash cookie.
14 | ValidationFilter, // Restore kept validation errors and save new ones from cookie. 15 | I18nFilter, // Resolve the requested language. 16 | InterceptorFilter, // Run interceptors around the action. 17 | CompressFilter, // Compress the result. 18 | CSRFFilter, // Protect against cross-site request forgery (CSRF). 19 | ActionInvoker, // Invoke the action. 20 | } 21 | 22 | // NilFilter does nothing and NilChain is a chain holding only NilFilter; they are helpful as the remaining chain when testing a single filter. 23 | var ( 24 | NilFilter = func(_ *Controller, _ []Filter) {} 25 | NilChain = []Filter{NilFilter} 26 | ) 27 | -------------------------------------------------------------------------------- /templates/errors/404-dev.html: -------------------------------------------------------------------------------- 1 | 45 | 46 | 56 | {{if .Router}} 57 |
58 |

These routes have been tried, in this order:

59 |
    60 | {{range .Router.Routes}} 61 |
  1. {{pad .Method 10}}{{pad .Path 50}}{{.Action}}
  2. 62 | {{end}} 63 |
64 |
65 | {{end}} 66 | -------------------------------------------------------------------------------- /invoker.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "reflect" 5 | 6 | "golang.org/x/net/websocket" 7 | ) 8 | 9 | var ( 10 | controllerType = reflect.TypeOf(Controller{}) 11 | controllerPtrType = reflect.TypeOf(&Controller{}) 12 | websocketType = reflect.TypeOf((*websocket.Conn)(nil)) 13 | ) 14 | 15 | func ActionInvoker(c *Controller, _ []Filter) { 16 | // Instantiate the method. 17 | methodValue := reflect.ValueOf(c.AppController).MethodByName(c.MethodType.Name) 18 | 19 | // Collect the values for the method's arguments. 20 | var methodArgs []reflect.Value 21 | for _, arg := range c.MethodType.Args { 22 | // If they accept a websocket connection, treat that arg specially. 23 | var boundArg reflect.Value 24 | if arg.Type == websocketType { 25 | boundArg = reflect.ValueOf(c.Request.Websocket) 26 | } else { 27 | boundArg = Bind(c.Params, arg.Name, arg.Type) 28 | } 29 | methodArgs = append(methodArgs, boundArg) 30 | } 31 | 32 | var resultValue reflect.Value 33 | if methodValue.Type().IsVariadic() { 34 | resultValue = methodValue.CallSlice(methodArgs)[0] 35 | } else { 36 | resultValue = methodValue.Call(methodArgs)[0] 37 | } 38 | if resultValue.Kind() == reflect.Interface && !resultValue.IsNil() { 39 | c.Result = resultValue.Interface().(Result) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /docs/testing.md: -------------------------------------------------------------------------------- 1 | # Testing with Mars 2 | 3 | As Mars tries to achieve a more idiomatic approach to devloping web applications with Go as Revel does, 4 | unit tests are written using the standard Go `testing` package. 
5 | 6 | On top of this, Mars provides an easy-to-use TestSuite, [github.com/roblillack/mars/testing](https://godoc.org/github.com/roblillack/mars/testing), 7 | which can be used like this: 8 | 9 | package controllers 10 | 11 | import ( 12 | "os" 13 | "path/filepath" 14 | "runtime" 15 | "testing" 16 | "time" 17 | 18 | "github.com/roblillack/mars" 19 | marst "github.com/roblillack/mars/testing" 20 | ) 21 | 22 | func TestMain(m *testing.M) { 23 | setupMars() 24 | retCode := m.Run() 25 | os.Exit(retCode) 26 | } 27 | 28 | func setupMars() { 29 | _, filename, _, _ := runtime.Caller(0) 30 | 31 | RegisterControllers() 32 | mars.ViewsPath = filepath.Join("app", "views") 33 | mars.InitDefaults("dev", filepath.Join(filepath.Dir(filename), "..", "..")) 34 | mars.DevMode = true 35 | 36 | go mars.Run() 37 | 38 | time.Sleep(1 * time.Second) 39 | } 40 | 41 | func Test_Health(t *testing.T) { 42 | ts := marst.NewTestSuite() 43 | ts.Get("/health") 44 | ts.AssertContains("Ok") 45 | ts.AssertOk() 46 | } 47 | -------------------------------------------------------------------------------- /testing/equal.go: -------------------------------------------------------------------------------- 1 | package testing 2 | 3 | import "reflect" 4 | 5 | // Equal is a helper for comparing value equality, following these rules: 6 | // - Values with equivalent types are compared with reflect.DeepEqual 7 | // - int, uint, and float values are compared without regard to the type width. 8 | // for example, Equal(int32(5), int64(5)) == true 9 | // - strings and byte slices are converted to strings before comparison. 10 | // - else, return false.
11 | func Equal(a, b interface{}) bool { 12 | if reflect.TypeOf(a) == reflect.TypeOf(b) { 13 | return reflect.DeepEqual(a, b) 14 | } 15 | switch a.(type) { 16 | case int, int8, int16, int32, int64: 17 | switch b.(type) { 18 | case int, int8, int16, int32, int64: 19 | return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int() 20 | } 21 | case uint, uint8, uint16, uint32, uint64: 22 | switch b.(type) { 23 | case uint, uint8, uint16, uint32, uint64: 24 | return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint() 25 | } 26 | case float32, float64: 27 | switch b.(type) { 28 | case float32, float64: 29 | return reflect.ValueOf(a).Float() == reflect.ValueOf(b).Float() 30 | } 31 | case string: 32 | switch b.(type) { 33 | case []byte: 34 | return a.(string) == string(b.([]byte)) 35 | } 36 | case []byte: 37 | switch b.(type) { 38 | case string: 39 | return b.(string) == string(a.([]byte)) 40 | } 41 | } 42 | return false 43 | } 44 | -------------------------------------------------------------------------------- /results_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "net/http/httptest" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | // Test that the render response is as expected. 
10 | func TestBenchmarkRender(t *testing.T) { 11 | startFakeBookingApp() 12 | resp := httptest.NewRecorder() 13 | c := NewController(NewRequest(showRequest), NewResponse(resp)) 14 | c.SetAction("Hotels", "Show") 15 | result := Hotels{c}.Show(3) 16 | result.Apply(c.Request, c.Response) 17 | if !strings.Contains(resp.Body.String(), "300 Main St.") { 18 | t.Errorf("Failed to find hotel address in action response:\n%s", resp.Body) 19 | } 20 | } 21 | 22 | func BenchmarkRenderChunked(b *testing.B) { 23 | startFakeBookingApp() 24 | resp := httptest.NewRecorder() 25 | resp.Body = nil 26 | c := NewController(NewRequest(showRequest), NewResponse(resp)) 27 | c.SetAction("Hotels", "Show") 28 | Config.SetOption("results.chunked", "true") 29 | b.ResetTimer() 30 | 31 | hotels := Hotels{c} 32 | for i := 0; i < b.N; i++ { 33 | hotels.Show(3).Apply(c.Request, c.Response) 34 | } 35 | } 36 | 37 | func BenchmarkRenderNotChunked(b *testing.B) { 38 | startFakeBookingApp() 39 | resp := httptest.NewRecorder() 40 | resp.Body = nil 41 | c := NewController(NewRequest(showRequest), NewResponse(resp)) 42 | c.SetAction("Hotels", "Show") 43 | Config.SetOption("results.chunked", "false") 44 | b.ResetTimer() 45 | 46 | hotels := Hotels{c} 47 | for i := 0; i < b.N; i++ { 48 | hotels.Show(3).Apply(c.Request, c.Response) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /sign.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/rand" 6 | "crypto/sha256" 7 | "encoding/base64" 8 | "io" 9 | ) 10 | 11 | var HashAlgorithm = sha256.New 12 | var HashBlockSize = sha256.BlockSize 13 | 14 | var ( 15 | // Private 16 | secretKey []byte // Key used to sign cookies. 
17 | ) 18 | 19 | func SetAppSecret(secret string) { 20 | secretKey = []byte(secret) 21 | } 22 | 23 | func generateRandomSecretKey() []byte { 24 | buf := make([]byte, HashBlockSize) 25 | if _, err := rand.Read(buf); err != nil { 26 | panic("Unable to generate random application secret") 27 | } 28 | 29 | return buf 30 | } 31 | 32 | // Sign a given string with the configured or random secret key. 33 | // If no secret key is set, returns the empty string. 34 | // Return the signature in unpadded, URL-safe base64 encoding 35 | // (A-Z, 0-9, a-z, _ and -). 36 | func Sign(message string) string { 37 | mac := hmac.New(HashAlgorithm, secretKey) 38 | io.WriteString(mac, message) 39 | return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) 40 | } 41 | 42 | // Verify returns true if the given signature is correct for the given message. 43 | // e.g. it matches what we generate with Sign() 44 | func Verify(message, sig string) bool { 45 | // return hmac.Equal([]byte(sig), []byte(Sign(message))) 46 | mac := hmac.New(HashAlgorithm, secretKey) 47 | io.WriteString(mac, message) 48 | hash := mac.Sum(nil) 49 | 50 | received, _ := base64.RawURLEncoding.DecodeString(sig) 51 | return hmac.Equal(received, hash) 52 | } 53 | -------------------------------------------------------------------------------- /testdata/views/header.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | {{.title}} 6 | 7 | 8 | {{range .moreStyles}} 9 | 10 | {{end}} 11 | 12 | 13 | {{range .moreScripts}} 14 | 15 | {{end}} 16 | 17 | 18 | 19 | 33 | 34 |
35 | {{if .flash.error}} 36 |

37 | {{.flash.error}} 38 |

39 | {{end}} 40 | {{if .flash.success}} 41 |

42 | {{.flash.success}} 43 |

44 | {{end}} 45 | 46 | -------------------------------------------------------------------------------- /compress_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "net/http/httptest" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | // Test that the render response is as expected. 10 | func TestBenchmarkCompressed(t *testing.T) { 11 | startFakeBookingApp() 12 | resp := httptest.NewRecorder() 13 | c := NewController(NewRequest(showRequest), NewResponse(resp)) 14 | c.SetAction("Hotels", "Show") 15 | Config.SetOption("results.compressed", "true") 16 | result := Hotels{c}.Show(3) 17 | result.Apply(c.Request, c.Response) 18 | if !strings.Contains(resp.Body.String(), "300 Main St.") { 19 | t.Errorf("Failed to find hotel address in action response:\n%s", resp.Body) 20 | } 21 | } 22 | 23 | func BenchmarkRenderCompressed(b *testing.B) { 24 | startFakeBookingApp() 25 | resp := httptest.NewRecorder() 26 | resp.Body = nil 27 | c := NewController(NewRequest(showRequest), NewResponse(resp)) 28 | c.SetAction("Hotels", "Show") 29 | Config.SetOption("results.compressed", "true") 30 | b.ResetTimer() 31 | 32 | hotels := Hotels{c} 33 | for i := 0; i < b.N; i++ { 34 | hotels.Show(3).Apply(c.Request, c.Response) 35 | } 36 | } 37 | 38 | func BenchmarkRenderUnCompressed(b *testing.B) { 39 | startFakeBookingApp() 40 | resp := httptest.NewRecorder() 41 | resp.Body = nil 42 | c := NewController(NewRequest(showRequest), NewResponse(resp)) 43 | c.SetAction("Hotels", "Show") 44 | Config.SetOption("results.compressed", "false") 45 | b.ResetTimer() 46 | 47 | hotels := Hotels{c} 48 | for i := 0; i < b.N; i++ { 49 | hotels.Show(3).Apply(c.Request, c.Response) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /panic_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "net/http" 7 
| "net/http/httptest" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | func TestPanicInAction(t *testing.T) { 13 | startFakeBookingApp() 14 | TRACE = log.New(ioutil.Discard, "", 0) 15 | INFO = TRACE 16 | WARN = TRACE 17 | ERROR = TRACE 18 | DevMode = false 19 | 20 | ts := httptest.NewServer(Handler) 21 | defer ts.Close() 22 | 23 | res, err := http.Get(ts.URL + "/boom") 24 | if err != nil { 25 | log.Fatal(err) 26 | } 27 | resp, err := ioutil.ReadAll(res.Body) 28 | res.Body.Close() 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | 33 | if !strings.Contains(string(resp), "OMG") { 34 | t.Error("Unable to get panic description, got:\n", resp) 35 | } 36 | } 37 | 38 | func containsAll(raw []byte, list ...string) bool { 39 | s := string(raw) 40 | for _, i := range list { 41 | if !strings.Contains(s, i) { 42 | return false 43 | } 44 | } 45 | 46 | return true 47 | } 48 | 49 | func TestPanicInDevMode(t *testing.T) { 50 | startFakeBookingApp() 51 | TRACE = log.New(ioutil.Discard, "", 0) 52 | INFO = TRACE 53 | WARN = TRACE 54 | ERROR = TRACE 55 | DevMode = true 56 | 57 | ts := httptest.NewServer(Handler) 58 | defer ts.Close() 59 | 60 | res, err := http.Get(ts.URL + "/boom") 61 | if err != nil { 62 | log.Fatal(err) 63 | } 64 | resp, err := ioutil.ReadAll(res.Body) 65 | res.Body.Close() 66 | if err != nil { 67 | log.Fatal(err) 68 | } 69 | 70 | if !containsAll(resp, 71 | "mars/fakeapp_test.go", 72 | "Hotels.Boom", 73 | "OMG") { 74 | t.Error("Unable to get full panic info, got:\n", string(resp)) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /field.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | // Field represents a data field that may be collected in a web form. 
9 | type Field struct { 10 | Name string 11 | Error *ValidationError 12 | renderArgs map[string]interface{} 13 | } 14 | 15 | func NewField(name string, renderArgs map[string]interface{}) *Field { 16 | err, _ := renderArgs["errors"].(map[string]*ValidationError)[name] 17 | return &Field{ 18 | Name: name, 19 | Error: err, 20 | renderArgs: renderArgs, 21 | } 22 | } 23 | 24 | // Id returns an identifier suitable for use as an HTML id. 25 | func (f *Field) Id() string { 26 | return strings.Replace(f.Name, ".", "_", -1) 27 | } 28 | 29 | // Flash returns the flashed value of this Field. 30 | func (f *Field) Flash() string { 31 | v, _ := f.renderArgs["flash"].(map[string]string)[f.Name] 32 | return v 33 | } 34 | 35 | // FlashArray returns the flashed value of this Field as a list split on comma. 36 | func (f *Field) FlashArray() []string { 37 | v := f.Flash() 38 | if v == "" { 39 | return []string{} 40 | } 41 | return strings.Split(v, ",") 42 | } 43 | 44 | // Value returns the current value of this Field. 45 | func (f *Field) Value() interface{} { 46 | pieces := strings.Split(f.Name, ".") 47 | answer, ok := f.renderArgs[pieces[0]] 48 | if !ok { 49 | return "" 50 | } 51 | 52 | val := reflect.ValueOf(answer) 53 | for i := 1; i < len(pieces); i++ { 54 | if val.Kind() == reflect.Ptr { 55 | val = val.Elem() 56 | } 57 | val = val.FieldByName(pieces[i]) 58 | if !val.IsValid() { 59 | return "" 60 | } 61 | } 62 | 63 | return val.Interface() 64 | } 65 | 66 | // ErrorClass returns ERROR_CLASS if this field has a validation error, else empty string. 
67 | func (f *Field) ErrorClass() string { 68 | if f.Error != nil { 69 | if errorClass, ok := f.renderArgs["ERROR_CLASS"]; ok { 70 | return errorClass.(string) 71 | } else { 72 | return ERROR_CLASS 73 | } 74 | } 75 | return "" 76 | } 77 | -------------------------------------------------------------------------------- /session_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestSessionRestore(t *testing.T) { 10 | expireAfterDuration = 0 11 | originSession := make(Session) 12 | originSession["foo"] = "foo" 13 | originSession["bar"] = "bar" 14 | cookie := originSession.Cookie() 15 | if !cookie.Expires.IsZero() { 16 | t.Error("incorrect cookie expire", cookie.Expires) 17 | } 18 | 19 | restoredSession := GetSessionFromCookie(cookie) 20 | for k, v := range originSession { 21 | if restoredSession[k] != v { 22 | t.Errorf("session restore failed session[%s] != %s", k, v) 23 | } 24 | } 25 | } 26 | 27 | func TestSessionExpire(t *testing.T) { 28 | expireAfterDuration = time.Hour 29 | session := make(Session) 30 | session["user"] = "Tom" 31 | var cookie *http.Cookie 32 | for i := 0; i < 3; i++ { 33 | cookie = session.Cookie() 34 | time.Sleep(time.Second) 35 | session = GetSessionFromCookie(cookie) 36 | } 37 | expectExpire := time.Now().Add(expireAfterDuration) 38 | if cookie.Expires.Unix() < expectExpire.Add(-time.Second).Unix() { 39 | t.Error("expect expires", cookie.Expires, "after", expectExpire.Add(-time.Second)) 40 | } 41 | if cookie.Expires.Unix() > expectExpire.Unix() { 42 | t.Error("expect expires", cookie.Expires, "before", expectExpire) 43 | } 44 | 45 | session.SetNoExpiration() 46 | for i := 0; i < 3; i++ { 47 | cookie = session.Cookie() 48 | session = GetSessionFromCookie(cookie) 49 | } 50 | cookie = session.Cookie() 51 | if !cookie.Expires.IsZero() { 52 | t.Error("expect cookie expires is zero") 53 | } 54 | 55 | 
session.SetDefaultExpiration() 56 | cookie = session.Cookie() 57 | expectExpire = time.Now().Add(expireAfterDuration) 58 | if cookie.Expires.Unix() < expectExpire.Add(-time.Second).Unix() { 59 | t.Error("expect expires", cookie.Expires, "after", expectExpire.Add(-time.Second)) 60 | } 61 | if cookie.Expires.Unix() > expectExpire.Unix() { 62 | t.Error("expect expires", cookie.Expires, "before", expectExpire) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /cert.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "crypto/rsa" 7 | "crypto/tls" 8 | "crypto/x509" 9 | "crypto/x509/pkix" 10 | "encoding/pem" 11 | "math/big" 12 | "net" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | func createCertificate(organization, domainNames string) (tls.Certificate, error) { 18 | INFO.Printf("Creating self-signed TLS certificate for %s\n", organization) 19 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 20 | if err != nil { 21 | ERROR.Fatalf("Failed to generate private key: %s", err) 22 | } 23 | 24 | serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) 25 | serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) 26 | if err != nil { 27 | ERROR.Fatalf("failed to generate serial number: %s", err) 28 | } 29 | 30 | template := x509.Certificate{ 31 | SerialNumber: serialNumber, 32 | Subject: pkix.Name{ 33 | Organization: []string{organization}, 34 | }, 35 | NotBefore: time.Now(), 36 | NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), 37 | 38 | IsCA: true, 39 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 40 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 41 | BasicConstraintsValid: true, 42 | } 43 | 44 | hosts := strings.Split(strings.TrimSpace(strings.Replace(domainNames, ",", " ", -1)), " ") 45 | for _, h := range hosts { 46 | if ip := net.ParseIP(h); ip != nil 
{ 47 | template.IPAddresses = append(template.IPAddresses, ip) 48 | } else { 49 | template.DNSNames = append(template.DNSNames, h) 50 | } 51 | } 52 | 53 | derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) 54 | if err != nil { 55 | ERROR.Fatalf("Failed to create certificate: %s", err) 56 | } 57 | cert := &bytes.Buffer{} 58 | pem.Encode(cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 59 | 60 | key := &bytes.Buffer{} 61 | pem.Encode(key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) 62 | 63 | return tls.X509KeyPair(cert.Bytes(), key.Bytes()) 64 | } 65 | -------------------------------------------------------------------------------- /testing/equal_test.go: -------------------------------------------------------------------------------- 1 | package testing 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestEqual(t *testing.T) { 9 | type testStruct struct{} 10 | type testStruct2 struct{} 11 | i, i2 := 8, 9 12 | s, s2 := "@朕µ\n\tüöäß", "@朕µ\n\tüöäss" 13 | slice, slice2 := []int{1, 2, 3, 4, 5}, []int{1, 2, 3, 4, 5} 14 | slice3, slice4 := []int{5, 4, 3, 2, 1}, []int{5, 4, 3, 2, 1} 15 | 16 | tm := map[string][]interface{}{ 17 | "slices": {slice, slice2}, 18 | "slices2": {slice3, slice4}, 19 | "types": {new(testStruct), new(testStruct)}, 20 | "types2": {new(testStruct2), new(testStruct2)}, 21 | "ints": {int(i), int8(i), int16(i), int32(i), int64(i)}, 22 | "ints2": {int(i2), int8(i2), int16(i2), int32(i2), int64(i2)}, 23 | "uints": {uint(i), uint8(i), uint16(i), uint32(i), uint64(i)}, 24 | "uints2": {uint(i2), uint8(i2), uint16(i2), uint32(i2), uint64(i2)}, 25 | "floats": {float32(i), float64(i)}, 26 | "floats2": {float32(i2), float64(i2)}, 27 | "strings": {[]byte(s), s}, 28 | "strings2": {[]byte(s2), s2}, 29 | } 30 | 31 | testRow := func(row, row2 string, expected bool) { 32 | for _, a := range tm[row] { 33 | for _, b := range tm[row2] { 34 | ok := Equal(a, b) 35 | if ok 
!= expected { 36 | ak := reflect.TypeOf(a).Kind() 37 | bk := reflect.TypeOf(b).Kind() 38 | t.Errorf("eq(%s=%v,%s=%v) want %t got %t", ak, a, bk, b, expected, ok) 39 | } 40 | } 41 | } 42 | } 43 | 44 | testRow("slices", "slices", true) 45 | testRow("slices", "slices2", false) 46 | testRow("slices2", "slices", false) 47 | 48 | testRow("types", "types", true) 49 | testRow("types2", "types", false) 50 | testRow("types", "types2", false) 51 | 52 | testRow("ints", "ints", true) 53 | testRow("ints", "ints2", false) 54 | testRow("ints2", "ints", false) 55 | 56 | testRow("uints", "uints", true) 57 | testRow("uints2", "uints", false) 58 | testRow("uints", "uints2", false) 59 | 60 | testRow("floats", "floats", true) 61 | testRow("floats2", "floats", false) 62 | testRow("floats", "floats2", false) 63 | 64 | testRow("strings", "strings", true) 65 | testRow("strings2", "strings", false) 66 | testRow("strings", "strings2", false) 67 | } 68 | -------------------------------------------------------------------------------- /sign_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/roblillack/mars/internal/pathtree" 8 | ) 9 | 10 | func TestEnsureSecretKeyIsSet(t *testing.T) { 11 | secretKey = nil 12 | MainTemplateLoader = &TemplateLoader{} 13 | MainRouter = &Router{Tree: pathtree.New()} 14 | setup() 15 | if len(secretKey) == 0 || len(secretKey) != HashBlockSize { 16 | t.Fatalf("Not a valid secret key: %+v", secretKey) 17 | } 18 | } 19 | 20 | func BenchmarkSigning(b *testing.B) { 21 | SetAppSecret("Ludolfs lustige Liegestütze ließen Lolas Lachmuskeln leuchten.") 22 | 23 | for n := 0; n < b.N; n++ { 24 | str := fmt.Sprintf("%d", n) 25 | sig := Sign(str) 26 | if ok := Verify(str, sig); !ok { 27 | b.Fatalf("signature '%s' of '%s' cannot be verified!", sig, str) 28 | } 29 | } 30 | } 31 | 32 | func TestSimpleSignatures(t *testing.T) { 33 | SetAppSecret("Kurts käsiger 
Kugelbauch konterte Karins kichernden Kuss.") 34 | 35 | for msg, sig := range map[string]string{ 36 | "Untouchable": "4UcW-3rLvaGGxmA2KUPQgS30MVK7ESKKEPhs4Gir_-E", 37 | "///a/a///a/": "vYMQQF_m2JnfKa5l0aBt1Iub_IhTu0ZWRcTWDC-oaxE", 38 | } { 39 | if r := Sign(msg); r != sig { 40 | t.Fatalf("wrong signature '%s' for '%s'!", r, msg) 41 | 42 | } 43 | if !Verify(msg, sig) { 44 | t.Fatalf("signature '%s' of '%s' cannot be verified!", sig, msg) 45 | 46 | } 47 | } 48 | } 49 | 50 | func TestSignature(t *testing.T) { 51 | SetAppSecret("Richards rüstige Rottweilerdame Renate riss ruchloserweise Rehe.") 52 | 53 | for n := 0; n < 100; n++ { 54 | msg := "Untouchable " + generateRandomToken() 55 | sig := Sign(msg) 56 | if len(sig) != 43 { 57 | t.Fatalf("wrong signature length %d for '%s' (sig: '%s')!", len(sig), msg, sig) 58 | 59 | } 60 | if !Verify(msg, sig) { 61 | t.Fatalf("signature '%s' of '%s' cannot be verified!", sig, msg) 62 | 63 | } 64 | } 65 | } 66 | 67 | func TestSignatureWithRandomSecret(t *testing.T) { 68 | for i := 0; i < 100; i++ { 69 | secretKey = generateRandomSecretKey() 70 | for n := 0; n < 100; n++ { 71 | msg := "Untouchable " + generateRandomToken() 72 | sig := Sign(msg) 73 | if len(sig) != 43 { 74 | t.Fatalf("wrong signature length %d for '%s' (sig: '%s')!", len(sig), msg, sig) 75 | 76 | } 77 | if !Verify(msg, sig) { 78 | t.Fatalf("signature '%s' of '%s' cannot be verified!", sig, msg) 79 | 80 | } 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /flash.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | ) 8 | 9 | // Flash represents a cookie that is overwritten on each request. 10 | // It allows data to be stored across one page at a time. 11 | // This is commonly used to implement success or error messages. 12 | // E.g. 
the Post/Redirect/Get pattern:
// https://en.wikipedia.org/wiki/Post/Redirect/Get
type Flash struct {
	// Data holds the values restored from the incoming FLASH cookie by
	// restoreFlash.
	Data map[string]string
	// Out collects the values that FlashFilter writes to the FLASH cookie
	// after the rest of the filter chain has run.
	Out map[string]string
}

// Error stores msg under the "error" key of the outgoing Flash cookie.
// When args are given, msg is treated as a fmt.Sprintf format string.
func (f Flash) Error(msg string, args ...interface{}) {
	f.setMessage("error", msg, args...)
}

// Success stores msg under the "success" key of the outgoing Flash cookie.
// When args are given, msg is treated as a fmt.Sprintf format string.
func (f Flash) Success(msg string, args ...interface{}) {
	f.setMessage("success", msg, args...)
}

// setMessage writes msg to f.Out[key], formatting it with args first if any
// were supplied. Without args, msg is stored verbatim.
func (f Flash) setMessage(key, msg string, args ...interface{}) {
	if len(args) > 0 {
		msg = fmt.Sprintf(msg, args...)
	}
	f.Out[key] = msg
}

// FlashFilter is the Mars filter that manages the flash cookie. It restores
// the incoming cookie into c.Flash and exposes its data to templates as the
// "flash" render argument. It then runs the rest of the chain and finally
// writes c.Flash.Out back as a cookie named CookiePrefix + "_FLASH".
func FlashFilter(c *Controller, fc []Filter) {
	c.Flash = restoreFlash(c.Request.Request)
	c.RenderArgs["flash"] = c.Flash.Data

	next, rest := fc[0], fc[1:]
	next(c, rest)

	// Serialize the outgoing values as a run of "\x00key:value\x00" records.
	// c.Flash is read again here in case it was replaced further down the chain.
	var payload []byte
	for key, value := range c.Flash.Out {
		payload = append(payload, '\x00')
		payload = append(payload, key...)
		payload = append(payload, ':')
		payload = append(payload, value...)
		payload = append(payload, '\x00')
	}

	cookie := &http.Cookie{
		Name:     CookiePrefix + "_FLASH",
		Value:    url.QueryEscape(string(payload)),
		HttpOnly: CookieHttpOnly,
		Secure:   CookieSecure,
		Path:     "/",
	}
	c.SetCookie(cookie)
}

// restoreFlash deserializes a Flash cookie struct from a request.
63 | func restoreFlash(req *http.Request) Flash { 64 | flash := Flash{ 65 | Data: make(map[string]string), 66 | Out: make(map[string]string), 67 | } 68 | if cookie, err := req.Cookie(CookiePrefix + "_FLASH"); err == nil { 69 | parseKeyValueCookie(cookie.Value, func(key, val string) { 70 | flash.Data[key] = val 71 | }) 72 | } 73 | return flash 74 | } 75 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Mars: A lightweight web toolkit for the Go programming language 2 | 3 | **WORK IN PROGRESS** 4 | 5 | [Mars](https://github.com/roblillack/mars) is a fork of the fantastic, yet not-that-idiomatic-and-pretty-much-abandoned, [Revel framework](https://github.com/revel/revel). You might take a look at the corresponding documentation for the time being. 6 | 7 | Mars provides the following functionality: 8 | 9 | … 10 | 11 | ## Differences to Revel 12 | 13 | The major changes since forking away from Revel are these: 14 | 15 | - More idiomatic approach to integrating the framework into your application: 16 | + No need to use the `revel` command to build, run, package, or distribute your app. 17 | + Code generation (for registering controllers and reverse routes) is supported using the standard `go generate` way. 18 | + No runtime dependencies anymore. Apps using Mars are truly standalone and do not need access to the sources at runtime (default templates and mime config are embedded assets). 19 | + You are not forced into a fixed directory layout or package names anymore. 20 | + Removed most of the "path magic" that tried to determine where the sources of your application and revel are: No global `AppPath`, `ViewsPath`, `TemplatePaths`, `RevelPath`, and `SourcePath` variables anymore. 21 | - Added support for Go 1.5+ vendoring. 22 | - Vendor Mars' dependencies as Git submodules. 
23 | - Added support for [HTTP dual-stack mode](https://github.com/roblillack/mars/issues/6). 24 | - Added support for [generating self-signed SSL certificates on-the-fly](https://github.com/roblillack/mars/issues/6). 25 | - Added [graceful shutdown](https://godoc.org/github.com/roblillack/mars#OnAppShutdown) functionality. 26 | - Added [CSRF protection](https://godoc.org/github.com/roblillack/mars#CSRFFilter). 27 | - Integrated `Static` controller to support hosting plain HTML files and assets. 28 | - Removed magic that automatically added template parameter names based on variable names in `Controller.Render()` calls using code generation and runtime introspection. 29 | - Removed the cache library. 30 | - Removed module support. 31 | - Removed support for configurable template delimiters. 32 | - Corrected case of render functions (`RenderXml` --> `RenderXML`). 33 | - Fix generating reverse routes for some edge cases: Action parameter is called `args` or action parameter is of type `interface{}`. 34 | - Fixed a [XSS vulnerability](https://github.com/roblillack/mars/issues/1). -------------------------------------------------------------------------------- /hooks.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | func runStartupHooks() { 4 | for _, hook := range startupHooks { 5 | hook() 6 | } 7 | } 8 | 9 | func runShutdownHooks() { 10 | for _, hook := range shutdownHooks { 11 | hook() 12 | } 13 | } 14 | 15 | var startupHooks []func() 16 | var shutdownHooks []func() 17 | 18 | // Register a function to be run at app startup. 19 | // 20 | // The order you register the functions will be the order they are run. 21 | // You can think of it as a FIFO queue. 22 | // This process will happen after the config file is read 23 | // and before the server is listening for connections. 24 | // 25 | // Ideally, your application should have only one call to init() in the file init.go. 
26 | // The reason is that init() functions spread over several files of one package 27 | // run in the order the files are presented to the compiler, which makes cross-file ordering fragile. 28 | // Inside of init() call mars.OnAppStart() for each function you wish to register. 29 | // 30 | // Example: 31 | // 32 | // // from: yourapp/app/controllers/somefile.go 33 | // func InitDB() { 34 | // // do DB connection stuff here 35 | // } 36 | // 37 | // func FillCache() { 38 | // // fill a cache from DB 39 | // // this depends on InitDB having been run 40 | // } 41 | // 42 | // // from: yourapp/app/init.go 43 | // func init() { 44 | // // set up filters... 45 | // 46 | // // register startup functions 47 | // mars.OnAppStart(InitDB) 48 | // mars.OnAppStart(FillCache) 49 | // } 50 | // 51 | // This can be useful when you need to establish connections to databases or third-party services, 52 | // set up app components, compile assets, or anything else you need to do between starting Mars and accepting connections. 53 | // 54 | func OnAppStart(f func()) { 55 | startupHooks = append(startupHooks, f) 56 | } 57 | 58 | // OnAppShutdown registers a function to be run at app shutdown. 59 | // 60 | // The order you register the functions will be the order they are run. 61 | // You can think of it as a FIFO queue. 62 | // This process will happen after the HTTP servers have stopped listening. 63 | // 64 | // Ideally, your application should have only one call to init() in the file init.go. 65 | // The reason is that init() functions spread over several files of one package 66 | // run in the order the files are presented to the compiler, which makes cross-file ordering fragile. 67 | // Inside of init() call mars.OnAppShutdown() for each function you wish to register.
68 | // 69 | // See also OnAppStart 70 | func OnAppShutdown(f func()) { 71 | shutdownHooks = append(shutdownHooks, f) 72 | } 73 | -------------------------------------------------------------------------------- /validation_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | // getRecordedCookie returns the recorded cookie from a ResponseRecorder with 10 | // the given name. It utilizes the cookie reader found in the standard library. 11 | func getRecordedCookie(recorder *httptest.ResponseRecorder, name string) (*http.Cookie, error) { 12 | r := &http.Response{Header: recorder.HeaderMap} 13 | for _, cookie := range r.Cookies() { 14 | if cookie.Name == name { 15 | return cookie, nil 16 | } 17 | } 18 | return nil, http.ErrNoCookie 19 | } 20 | 21 | func validationTester(req *Request, fn func(c *Controller)) *httptest.ResponseRecorder { 22 | recorder := httptest.NewRecorder() 23 | c := NewController(req, NewResponse(recorder)) 24 | ValidationFilter(c, []Filter{func(c *Controller, _ []Filter) { 25 | fn(c) 26 | }}) 27 | return recorder 28 | } 29 | 30 | // Test that errors are encoded into the _ERRORS cookie. 31 | func TestValidationWithError(t *testing.T) { 32 | recorder := validationTester(buildEmptyRequest(), func(c *Controller) { 33 | c.Validation.Required("") 34 | if !c.Validation.HasErrors() { 35 | t.Fatal("errors should be present") 36 | } 37 | c.Validation.Keep() 38 | }) 39 | 40 | if cookie, err := getRecordedCookie(recorder, "MARS_ERRORS"); err != nil { 41 | t.Fatal(err) 42 | } else if cookie.MaxAge < 0 { 43 | t.Fatalf("cookie should not expire") 44 | } 45 | } 46 | 47 | // Test that no cookie is sent if errors are found, but Keep() is not called. 
48 | func TestValidationNoKeep(t *testing.T) { 49 | recorder := validationTester(buildEmptyRequest(), func(c *Controller) { 50 | c.Validation.Required("") 51 | if !c.Validation.HasErrors() { 52 | t.Fatal("errors should be present") 53 | } 54 | }) 55 | 56 | if _, err := getRecordedCookie(recorder, "MARS_ERRORS"); err != http.ErrNoCookie { 57 | t.Fatalf("expected no MARS_ERRORS cookie, got err = %v", err) 58 | } 59 | } 60 | 61 | // Test that a previously set MARS_ERRORS cookie is deleted if no errors are found. 62 | func TestValidationNoKeepCookiePreviouslySet(t *testing.T) { 63 | req := buildRequestWithCookie("MARS_ERRORS", "invalid") 64 | recorder := validationTester(req, func(c *Controller) { 65 | c.Validation.Required("success") 66 | if c.Validation.HasErrors() { 67 | t.Fatal("errors should not be present") 68 | } 69 | }) 70 | 71 | if cookie, err := getRecordedCookie(recorder, "MARS_ERRORS"); err != nil { 72 | t.Fatal(err) 73 | } else if cookie.MaxAge >= 0 { 74 | t.Fatalf("cookie should be deleted") 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /fakeapp_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "os" 7 | "path" 8 | "reflect" 9 | ) 10 | // Hotel is the model rendered by the fake booking app's Hotels actions. 11 | type Hotel struct { 12 | HotelID int 13 | Name, Address string 14 | City, State, Zip string 15 | Country string 16 | Price int 17 | } 18 | 19 | type Hotels struct { 20 | *Controller 21 | } 22 | 23 | type MyStatic struct { 24 | *Controller 25 | } 26 | 27 | func (c Hotels) Show(id int) Result { 28 | title := "View Hotel" 29 | hotel := &Hotel{id, "A Hotel", "300 Main St.", "New York", "NY", "10010", "USA", 300} 30 | return c.Render(Args{"title": title, "hotel": hotel}) 31 | } 32 | 33 | func (c Hotels) Book(id int) Result { 34 | hotel := &Hotel{id, "A Hotel", "300 Main St.", "New York", "NY", "10010", "USA", 300} 35 | return c.RenderJSON(hotel) 36 | } 37 | 38 | func (c Hotels) Index() Result { 39 | return
c.RenderText("Hello, World!") 40 | } 41 | 42 | func (c Hotels) Boom() Result { 43 | panic("OMG") 44 | } 45 | 46 | func (c MyStatic) Serve(prefix, filepath string) Result { 47 | var basePath, dirName string // NOTE(review): dirName is never assigned, so basePath always ends up as BasePath 48 | 49 | if !path.IsAbs(dirName) { 50 | basePath = BasePath 51 | } 52 | 53 | fname := path.Join(basePath, prefix, filepath) 54 | file, err := os.Open(fname) 55 | if os.IsNotExist(err) { 56 | return c.NotFound("") 57 | } else if err != nil { 58 | WARN.Printf("Problem opening file (%s): %s ", fname, err) 59 | return c.NotFound("This was found but not sure why we couldn't open it.") 60 | } 61 | return c.RenderFile(file, "") 62 | } 63 | // startFakeBookingApp registers the fake controllers and initializes the framework from ./testdata in "prod" mode. 64 | func startFakeBookingApp() { 65 | RegisterController((*Hotels)(nil), 66 | []*MethodType{ 67 | { 68 | Name: "Index", 69 | }, 70 | { 71 | Name: "Boom", 72 | }, 73 | { 74 | Name: "Show", 75 | Args: []*MethodArg{ 76 | {"id", reflect.TypeOf((*int)(nil))}, 77 | }, 78 | }, 79 | { 80 | Name: "Book", 81 | Args: []*MethodArg{ 82 | {"id", reflect.TypeOf((*int)(nil))}, 83 | }, 84 | }, 85 | }) 86 | 87 | RegisterController((*Static)(nil), 88 | []*MethodType{ 89 | { 90 | Name: "Serve", 91 | Args: []*MethodArg{ 92 | {Name: "prefix", Type: reflect.TypeOf((*string)(nil))}, 93 | {Name: "filepath", Type: reflect.TypeOf((*string)(nil))}, 94 | }, 95 | }, 96 | }) 97 | 98 | // Send all loggers to stderr; ioutil.Discard stays referenced so it can be swapped in to silence them. 99 | _ = ioutil.Discard 100 | TRACE = log.New(os.Stderr, "", 0) 101 | INFO = TRACE 102 | WARN = TRACE 103 | ERROR = TRACE 104 | 105 | InitDefaults("prod", "testdata") 106 | Setup() 107 | } 108 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | // Error is an error description, used as an argument to the error template. 10 | type Error struct { 11 | SourceType string // The type of source that failed to build.
12 | Title, Path, Description string // Description of the error, as presented to the user. 13 | Line, Column int // Where the error was encountered. 14 | SourceLines []string // The entire source file, split into lines. 15 | Stack string // The raw stack trace string from debug.Stack(). 16 | MetaError string // Error that occurred producing the error page. 17 | Link string // A configurable link to wrap the error source in 18 | } 19 | 20 | var _ error = &Error{} 21 | 22 | // sourceLine holds the details of one line of an error's context snippet. 23 | type sourceLine struct { 24 | Source string 25 | Line int 26 | IsError bool 27 | } 28 | 29 | // Error constructs a plaintext version of the error, taking into account that fields are optionally set. 30 | // Returns e.g. Compilation Error (in views/header.html:51): expected right delim in end; got "}" 31 | func (e *Error) Error() string { 32 | if e == nil { 33 | return "" 34 | } 35 | 36 | loc := "" 37 | if e.Path != "" { 38 | line := "" 39 | if e.Line != 0 { 40 | line = fmt.Sprintf(":%d", e.Line) 41 | } 42 | loc = fmt.Sprintf("(in %s%s)", e.Path, line) 43 | } 44 | header := loc // NOTE(review): without a Title, nothing separates loc from Description below — confirm intended 45 | if e.Title != "" { 46 | if loc != "" { 47 | header = fmt.Sprintf("%s %s: ", e.Title, loc) 48 | } else { 49 | header = fmt.Sprintf("%s: ", e.Title) 50 | } 51 | } 52 | return fmt.Sprintf("%s%s", header, e.Description) 53 | } 54 | 55 | // ContextSource returns a snippet of the source around where the error occurred.
56 | func (e *Error) ContextSource() []sourceLine { 57 | if e.SourceLines == nil { 58 | return nil 59 | } 60 | start := (e.Line - 1) - 5 61 | if start < 0 { 62 | start = 0 63 | } 64 | end := (e.Line - 1) + 5 65 | if end > len(e.SourceLines) { 66 | end = len(e.SourceLines) 67 | } 68 | 69 | var lines []sourceLine = make([]sourceLine, end-start) 70 | for i, src := range e.SourceLines[start:end] { 71 | fileLine := start + i + 1 72 | lines[i] = sourceLine{src, fileLine, fileLine == e.Line} 73 | } 74 | return lines 75 | } 76 | 77 | func (e *Error) SetLink(errorLink string) { 78 | errorLink = strings.Replace(errorLink, "{{Path}}", e.Path, -1) 79 | errorLink = strings.Replace(errorLink, "{{Line}}", strconv.Itoa(e.Line), -1) 80 | 81 | e.Link = "" + e.Path + ":" + strconv.Itoa(e.Line) + "" 82 | } 83 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '16 0 * * 1' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 52 | 53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 54 | # If this step fails, then you should remove it and run the build manually (see below) 55 | - name: Autobuild 56 | uses: github/codeql-action/autobuild@v2 57 | 58 | # ℹ️ Command-line programs to run using the OS shell. 
59 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 60 | 61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 62 | # and modify them (or add more) to build your code if your project 63 | # uses a compiled language 64 | 65 | #- run: | 66 | # make bootstrap 67 | # make release 68 | 69 | - name: Perform CodeQL Analysis 70 | uses: github/codeql-action/analyze@v2 71 | -------------------------------------------------------------------------------- /intercept_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | var funcP = func(c *Controller) Result { return nil } 9 | var funcP2 = func(c *Controller) Result { return nil } 10 | 11 | type InterceptController struct{ *Controller } 12 | type InterceptControllerN struct{ InterceptController } 13 | type InterceptControllerP struct{ *InterceptController } 14 | type InterceptControllerNP struct { 15 | *Controller 16 | InterceptControllerN 17 | InterceptControllerP 18 | } 19 | 20 | func (c InterceptController) methN() Result { return nil } 21 | func (c *InterceptController) methP() Result { return nil } 22 | 23 | func (c InterceptControllerN) methNN() Result { return nil } 24 | func (c *InterceptControllerN) methNP() Result { return nil } 25 | func (c InterceptControllerP) methPN() Result { return nil } 26 | func (c *InterceptControllerP) methPP() Result { return nil } 27 | 28 | // Methods accessible from InterceptControllerN 29 | var MethodsN = []interface{}{ 30 | InterceptController.methN, 31 | (*InterceptController).methP, 32 | InterceptControllerN.methNN, 33 | (*InterceptControllerN).methNP, 34 | } 35 | 36 | // Methods accessible from InterceptControllerP 37 | var MethodsP = []interface{}{ 38 | InterceptController.methN, 39 | (*InterceptController).methP, 40 | InterceptControllerP.methPN, 41 |
(*InterceptControllerP).methPP, 42 | } 43 | 44 | // This checks that all the various kinds of interceptor functions/methods are 45 | // properly invoked. 46 | func TestInvokeArgType(t *testing.T) { 47 | n := InterceptControllerN{InterceptController{&Controller{}}} 48 | p := InterceptControllerP{&InterceptController{&Controller{}}} 49 | np := InterceptControllerNP{&Controller{}, n, p} 50 | testInterceptorController(t, reflect.ValueOf(&n), MethodsN) 51 | testInterceptorController(t, reflect.ValueOf(&p), MethodsP) 52 | testInterceptorController(t, reflect.ValueOf(&np), MethodsN) 53 | testInterceptorController(t, reflect.ValueOf(&np), MethodsP) 54 | } 55 | // testInterceptorController resets the global interceptors, registers funcP, funcP2 and every method in methods, and checks each applies to appControllerPtr. 56 | func testInterceptorController(t *testing.T, appControllerPtr reflect.Value, methods []interface{}) { 57 | interceptors = []*Interception{} 58 | InterceptFunc(funcP, BEFORE, appControllerPtr.Elem().Interface()) 59 | InterceptFunc(funcP2, BEFORE, AllControllers) 60 | for _, m := range methods { 61 | InterceptMethod(m, BEFORE) 62 | } 63 | ints := getInterceptors(BEFORE, appControllerPtr) 64 | 65 | if want := len(methods) + 2; len(ints) != want { 66 | t.Fatalf("Expected %d interceptors, got %d.", want, len(ints)) 67 | } 68 | 69 | testInterception(t, ints[0], reflect.ValueOf(&Controller{})) 70 | testInterception(t, ints[1], reflect.ValueOf(&Controller{})) 71 | for i := range methods { 72 | testInterception(t, ints[i+2], appControllerPtr) 73 | } 74 | } 75 | // testInterception invokes intc on arg and reports an error unless it returns a nil value. 76 | func testInterception(t *testing.T, intc *Interception, arg reflect.Value) { 77 | val := intc.Invoke(arg) 78 | if !val.IsNil() { 79 | t.Errorf("Failed (%v): Expected nil got %v", intc, val) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /docs/migration.md: -------------------------------------------------------------------------------- 1 | ## Moving from Revel to Mars in 7 steps 2 | 3 | 1.
Add the dependency: 4 | - Add `github.com/roblillack/mars` to your dependencies using the Go vendoring tool of your choice, or 5 | - Add said repository as a git submodule, or 6 | - Just run `go get github.com/roblillack/mars` (which is Go's “I'm feeling lucky” button) 7 | 2. Replace all occurrences of the `revel` package with `mars`. This will mainly be import paths and 8 | action results (`mars.Result` instead of `revel.Result`), but also things like accessing the config 9 | or logging. You can pretty much automate this. 10 | 3. Fix the case for some of the rendering functions your code might call: 11 | - RenderJson -> RenderJSON 12 | - RenderJsonP -> RenderJSONP 13 | - RenderXml -> RenderXML 14 | - RenderHtml -> RenderHTML 15 | 4. Set a [Key](https://godoc.org/github.com/roblillack/mars#ValidationResult.Key) for all validation results, 16 | because Mars will _not_ guess this based on variable names. Something like `c.Validation.Required(email)` becomes 17 | `c.Validation.Required(email).Key("email")` 18 | 5. Install mars-gen using `go install github.com/roblillack/mars/cmd/mars-gen@latest` and set it up for 19 | controller registration and reverse route generation by adding comments like these to one of your Go files: 20 | 21 | //go:generate mars-gen register-controllers ./controllers 22 | //go:generate mars-gen reverse-routes -n routes -o routes/routes.gen.go ./controllers 23 | 24 | Make sure to check in the generated sources, too. Run `mars-gen --help` for usage info. 25 | 6.
Set up a main entry point for your server, for example like this: 26 | 27 | package main 28 | 29 | import ( 30 | "flag" 31 | "path" 32 | "github.com/mycompany/myapp/controllers" 33 | "github.com/roblillack/mars" 34 | ) 35 | 36 | func main() { 37 | mode := flag.String("m", "prod", "Runtime mode to select (default: prod)") 38 | flag.Parse() 39 | 40 | // This is the function `mars-gen register-controllers` generates: 41 | controllers.RegisterControllers() 42 | 43 | // Set up some paths to be compatible with the Revel way. Default is not to have an "app" directory below BasePath 44 | mars.ViewsPath = path.Join("app", "views") 45 | mars.ConfigFile = path.Join("app", "conf", "app.conf") 46 | mars.RoutesFile = path.Join("app", "conf", "routes") 47 | 48 | // Ok, we should never, ever, ever disable CSRF protection. 49 | // But to stay compatible with Revel's defaults .... 50 | // Read https://godoc.org/github.com/roblillack/mars#CSRFFilter about what to do to enable this again. 51 | mars.DisableCSRF = true 52 | 53 | // Reads the config, sets up template loader, creates router 54 | mars.InitDefaults(*mode, ".") 55 | 56 | mars.Run() 57 | } 58 | 7. Run `go generate && go build && ./myapp` and be happy.
59 | -------------------------------------------------------------------------------- /internal/watcher/watcher_test.go: -------------------------------------------------------------------------------- 1 | package watcher 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math/rand" 7 | "os" 8 | "path/filepath" 9 | "testing" 10 | "time" 11 | ) 12 | // SimpleRefresher records whether Refresh was called and returns its configured Error. 13 | type SimpleRefresher struct { 14 | Refreshed bool 15 | Error error 16 | } 17 | 18 | func (l *SimpleRefresher) Refresh() error { 19 | l.Refreshed = true 20 | return l.Error 21 | } 22 | 23 | func TestWatcher(t *testing.T) { 24 | w := New() 25 | 26 | tmp := filepath.Join(os.TempDir(), fmt.Sprintf("mars-watcher-test-%d", rand.Uint32())) 27 | err := os.MkdirAll(tmp, 0700) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | defer os.RemoveAll(tmp) // Best-effort cleanup even if the test fails early; the explicit RemoveAll below still asserts success. 32 | bla := &SimpleRefresher{} 33 | if err := w.Listen(bla, tmp); err != nil { 34 | t.Errorf("unable to setup listener: %s", err) 35 | } 36 | 37 | if err := w.Notify(); err != nil { 38 | t.Errorf("unable to notify listeners: %s", err) 39 | } 40 | if bla.Refreshed { 41 | t.Error("No changes to tmp dir yet, should not have been refreshed.") 42 | } 43 | 44 | bla.Refreshed = false 45 | if f, err := os.Create(filepath.Join(tmp, "yep.dada")); err != nil { 46 | t.Fatal(err) 47 | } else { 48 | fmt.Fprintln(f, "Hello world!") 49 | f.Close() 50 | } 51 | 52 | time.Sleep(1 * time.Second) 53 | 54 | if err := w.Notify(); err != nil { 55 | t.Errorf("unable to notify listeners: %s", err) 56 | } 57 | if !bla.Refreshed { 58 | t.Error("Should have been refreshed.") 59 | } 60 | 61 | if err := os.RemoveAll(tmp); err != nil { 62 | t.Fatal(err) 63 | } 64 | } 65 | 66 | func TestErrorWhileRefreshing(t *testing.T) { 67 | w := New() 68 | 69 | tmp := filepath.Join(os.TempDir(), fmt.Sprintf("mars-watcher-test-%d", rand.Uint32())) 70 | err := os.MkdirAll(tmp, 0700) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | defer os.RemoveAll(tmp) // Best-effort cleanup even if the test fails early. 75 | bla := &SimpleRefresher{Error: errors.New("uh-oh something went wrong!!!11")} 76 | if err := w.Listen(bla, tmp); err !=
nil { 77 | t.Errorf("unable to setup listener: %s", err) 78 | } 79 | 80 | if err := w.Notify(); err != nil { 81 | t.Errorf("unable to notify listeners: %s", err) 82 | } 83 | if bla.Refreshed { 84 | t.Error("No changes to tmp dir yet, should not have been refreshed.") 85 | } 86 | 87 | bla.Refreshed = false 88 | if f, err := os.Create(filepath.Join(tmp, "yep.dada")); err != nil { 89 | t.Fatal(err) 90 | } else { 91 | fmt.Fprintln(f, "Hello world!") 92 | f.Close() 93 | } 94 | 95 | time.Sleep(1 * time.Second) 96 | 97 | if err := w.Notify(); err == nil { 98 | t.Error("No error while refreshing") 99 | } else if err != bla.Error { 100 | t.Errorf("Wrong error seen while refreshing: %v", err) 101 | } 102 | if !bla.Refreshed { 103 | t.Error("Should have been refreshed.") 104 | } 105 | 106 | bla.Refreshed = false 107 | bla.Error = nil 108 | time.Sleep(1 * time.Second) 109 | 110 | if err := w.Notify(); err != nil { 111 | t.Errorf("error not resolved yet: %s", err) 112 | } 113 | if !bla.Refreshed { 114 | t.Error("Should have been refreshed.") 115 | } 116 | 117 | if err := os.RemoveAll(tmp); err != nil { 118 | t.Fatal(err) 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "os" 7 | "path" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | // This tries to benchmark the usual request-serving pipeline to get an overall 13 | // performance metric. 14 | // 15 | // Each iteration runs one mock request to display a hotel's detail page by id.
16 | // 17 | // Contributing parts: 18 | // - Routing 19 | // - Controller lookup / invocation 20 | // - Parameter binding 21 | // - Session, flash, i18n cookies 22 | // - Render() call magic 23 | // - Template rendering 24 | func BenchmarkServeAction(b *testing.B) { 25 | benchmarkRequest(b, showRequest) 26 | } 27 | 28 | func BenchmarkServeJson(b *testing.B) { 29 | benchmarkRequest(b, jsonRequest) 30 | } 31 | 32 | func BenchmarkServePlaintext(b *testing.B) { 33 | benchmarkRequest(b, plaintextRequest) 34 | } 35 | 36 | // This tries to benchmark the static serving overhead when serving an "average 37 | // size" 7k file. 38 | func BenchmarkServeStatic(b *testing.B) { 39 | benchmarkRequest(b, staticRequest) 40 | } 41 | // benchmarkRequest starts the fake booking app once, then serves req b.N times. 42 | func benchmarkRequest(b *testing.B, req *http.Request) { 43 | startFakeBookingApp() 44 | b.ResetTimer() 45 | // A fresh recorder per iteration keeps response bodies from piling up in one ever-growing buffer. 46 | for i := 0; i < b.N; i++ { 47 | handle(httptest.NewRecorder(), req) 48 | } 49 | } 50 | 51 | // Test that the booking app can be successfully run for a test.
52 | func TestFakeServer(t *testing.T) { 53 | startFakeBookingApp() 54 | 55 | resp := httptest.NewRecorder() 56 | 57 | // First, test that the expected responses are actually generated 58 | handle(resp, showRequest) 59 | if !strings.Contains(resp.Body.String(), "300 Main St.") { 60 | t.Errorf("Failed to find hotel address in action response:\n%s", resp.Body) 61 | t.FailNow() 62 | } 63 | resp.Body.Reset() 64 | 65 | handle(resp, staticRequest) 66 | sessvarsSize := getFileSize(t, path.Join(BasePath, "public", "js", "sessvars.js")) 67 | if int64(resp.Body.Len()) != sessvarsSize { 68 | t.Errorf("Expected sessvars.js to have %d bytes, got %d:\n%s", sessvarsSize, resp.Body.Len(), resp.Body) 69 | t.FailNow() 70 | } 71 | resp.Body.Reset() 72 | 73 | handle(resp, jsonRequest) 74 | if !strings.Contains(resp.Body.String(), `"Address":"300 Main St."`) { 75 | t.Errorf("Failed to find hotel address in JSON response:\n%s", resp.Body) 76 | t.FailNow() 77 | } 78 | resp.Body.Reset() 79 | 80 | handle(resp, plaintextRequest) 81 | if resp.Body.String() != "Hello, World!" { 82 | t.Errorf("Failed to find greeting in plaintext response:\n%s", resp.Body) 83 | t.FailNow() 84 | } 85 | 86 | resp.Body = nil 87 | } 88 | // getFileSize returns the size of the named file, failing the test if it cannot be stat'ed. 89 | func getFileSize(t *testing.T, name string) int64 { 90 | fi, err := os.Stat(name) 91 | if err != nil { 92 | t.Errorf("Unable to stat file %s: %v", name, err) 93 | t.FailNow() 94 | } 95 | return fi.Size() 96 | } 97 | 98 | var ( 99 | showRequest, _ = http.NewRequest("GET", "/hotels/3", nil) 100 | staticRequest, _ = http.NewRequest("GET", "/public/js/sessvars.js", nil) 101 | jsonRequest, _ = http.NewRequest("GET", "/hotels/3/booking", nil) 102 | plaintextRequest, _ = http.NewRequest("GET", "/hotels", nil) 103 | ) 104 | -------------------------------------------------------------------------------- /templates/errors/500-dev.html: -------------------------------------------------------------------------------- 1 | 93 | {{with .Error}} 94 | 104 | {{if .Path}} 105 |
106 |

In {{.Path}} 107 | {{if .Line}} 108 | (around {{if .Line}}line {{.Line}}{{end}}{{if .Column}} column {{.Column}}{{end}}) 109 | {{end}} 110 |

111 | {{range .ContextSource}} 112 |
113 | {{.Line}}: 114 |
{{.Source}}
115 |
116 | {{end}} 117 |
118 | {{end}} 119 | {{if .Stack}} 120 |
121 |

Call Stack

122 | {{.Stack}} 123 |
124 | {{end}} 125 | {{if .MetaError}} 126 |
127 |

Additionally, an error occurred while handling this error.

128 |
129 | {{.MetaError}} 130 |
131 |
132 | {{end}} 133 | {{end}} 134 | -------------------------------------------------------------------------------- /templates_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "bytes" 5 | "html/template" 6 | "net/http" 7 | "net/http/httptest" 8 | "path/filepath" 9 | "runtime" 10 | "strings" 11 | "testing" 12 | ) 13 | // setupTemplateTestingApp points BasePath at the testdata directory next to this file and calls SetupViews. 14 | func setupTemplateTestingApp() { 15 | _, filename, _, _ := runtime.Caller(0) 16 | BasePath = filepath.Join(filepath.Dir(filename), "testdata") 17 | SetupViews() 18 | } 19 | 20 | func TestContextAwareRenderFuncs(t *testing.T) { 21 | setupTemplateTestingApp() 22 | loadMessages(testDataPath) 23 | 24 | for expected, input := range map[string]interface{}{ 25 | "

Hey, there Rob!

": "Rob", 26 | "

Hey, there <3!

": Blarp("<3"), 27 | } { 28 | result := runRequest("en", "i18n_ctx.html", Args{"input": input}) 29 | if result != expected { 30 | t.Errorf("Expected '%s', got '%s' for input '%s'", expected, result, input) 31 | } 32 | } 33 | } 34 | // simulateRequest renders view for a GET / request with the given format and returns the response body. 35 | func simulateRequest(format, view string) string { 36 | w := httptest.NewRecorder() 37 | httpRequest, _ := http.NewRequest("GET", "/", nil) 38 | req := NewRequest(httpRequest) 39 | req.Format = format 40 | c := NewController(req, &Response{Out: w}) 41 | c.RenderTemplate(view).Apply(c.Request, c.Response) 42 | 43 | buf := &bytes.Buffer{} 44 | buf.ReadFrom(w.Body) 45 | return buf.String() 46 | } 47 | 48 | func TestTemplateNotAvailable(t *testing.T) { 49 | setupTemplateTestingApp() 50 | expectedString := "Template non_existant.html not found." 51 | 52 | if resp := simulateRequest("html", "non_existant.html"); !strings.Contains(resp, expectedString) { 53 | t.Error("Error rendering template error message for HTML requests. Got:", resp) 54 | } 55 | if resp := simulateRequest("txt", "non_existant.html"); !strings.Contains(resp, expectedString) { 56 | t.Error("Error rendering template error message for plaintext requests.
Got:", resp) 57 | } 58 | } 59 | 60 | func TestTemplateFuncs(t *testing.T) { 61 | type Scenario struct { 62 | T string 63 | D Args 64 | R string 65 | E string 66 | } 67 | for _, scenario := range []Scenario{ 68 | { 69 | `{{.title}}`, 70 | Args{"title": "This is a Blog Post!"}, 71 | `This is a Blog Post!`, 72 | ``, 73 | }, 74 | { 75 | `{{raw .title}}`, 76 | Args{"title": "bla"}, 77 | `bla`, 78 | ``, 79 | }, 80 | { 81 | `{{if even .no}}yes{{else}}no{{end}}`, 82 | Args{"no": 0}, 83 | `yes`, 84 | ``, 85 | }, 86 | { 87 | `{{if even .no}}yes{{else}}no{{end}}`, 88 | Args{"no": 1}, 89 | `no`, 90 | ``, 91 | }, 92 | } { 93 | tmpl, err := template.New("foo").Funcs(TemplateFuncs).Parse(scenario.T) 94 | if err != nil { 95 | t.Fatal(err) 96 | } 97 | buf := &strings.Builder{} 98 | err = goTemplateWrapper{loader: nil, funcMap: nil, Template: tmpl}.Template.Execute(buf, scenario.D) 99 | if err != nil { 100 | t.Error(err) 101 | } 102 | if res := buf.String(); res != scenario.R { 103 | t.Errorf("Expected '%s', got '%s' for input '%s'", scenario.R, res, scenario.T) 104 | } 105 | } 106 | } 107 | 108 | func TestTemplateParsingErrors(t *testing.T) { 109 | for _, scenario := range []string{ 110 | `{{.uhoh}`, 111 | `{{if .condition}}look{{else}}there's no end here`, 112 | `{{undefined_function .parameter}}`, 113 | } { 114 | _, err := template.New("foo").Funcs(TemplateFuncs).Parse(scenario) 115 | if err == nil { 116 | t.Errorf("No error when parsing: %s", scenario) 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "github.com/robfig/config" 5 | "strings" 6 | ) 7 | 8 | // MergedConfig handles the parsing of app.conf. 9 | // It has a "preferred" section that is checked first for option queries. 10 | // If the preferred section does not have the option, the DEFAULT section is 11 | // checked as a fallback.
12 | type MergedConfig struct { 13 | config *config.Config 14 | section string // Check this section first, then fall back to DEFAULT 15 | } 16 | // NewEmptyConfig returns a MergedConfig backed by an empty default config and no preferred section. 17 | func NewEmptyConfig() *MergedConfig { 18 | return &MergedConfig{config.NewDefault(), ""} 19 | } 20 | // LoadConfig reads confName and wraps it in a MergedConfig with no preferred section. 21 | func LoadConfig(confName string) (*MergedConfig, error) { 22 | conf, err := config.ReadDefault(confName) 23 | if err == nil { 24 | return &MergedConfig{conf, ""}, nil 25 | } 26 | 27 | return nil, err 28 | } 29 | 30 | func (c *MergedConfig) Raw() *config.Config { 31 | return c.config 32 | } 33 | 34 | func (c *MergedConfig) SetSection(section string) { 35 | c.section = section 36 | } 37 | 38 | func (c *MergedConfig) SetOption(name, value string) { 39 | c.config.AddOption(c.section, name, value) 40 | } 41 | // Int returns option as an int and whether it was found; values that fail to parse are logged via ERROR and reported as not found. 42 | func (c *MergedConfig) Int(option string) (result int, found bool) { 43 | result, err := c.config.Int(c.section, option) 44 | if err == nil { 45 | return result, true 46 | } 47 | if _, ok := err.(config.OptionError); ok { 48 | return 0, false 49 | } 50 | 51 | // If it wasn't an OptionError, it must have failed to parse. 52 | ERROR.Println("Failed to parse config option", option, "as int:", err) 53 | return 0, false 54 | } 55 | 56 | func (c *MergedConfig) IntDefault(option string, dfault int) int { 57 | if r, found := c.Int(option); found { 58 | return r 59 | } 60 | return dfault 61 | } 62 | // Bool returns option as a bool and whether it was found; values that fail to parse are logged via ERROR and reported as not found. 63 | func (c *MergedConfig) Bool(option string) (result, found bool) { 64 | result, err := c.config.Bool(c.section, option) 65 | if err == nil { 66 | return result, true 67 | } 68 | if _, ok := err.(config.OptionError); ok { 69 | return false, false 70 | } 71 | 72 | // If it wasn't an OptionError, it must have failed to parse.
73 | ERROR.Println("Failed to parse config option", option, "as bool:", err) 74 | return false, false 75 | } 76 | 77 | func (c *MergedConfig) BoolDefault(option string, dfault bool) bool { 78 | if r, found := c.Bool(option); found { 79 | return r 80 | } 81 | return dfault 82 | } 83 | // String returns option with one pair of surrounding double quotes removed, and whether it was found. 84 | func (c *MergedConfig) String(option string) (result string, found bool) { 85 | if r, err := c.config.String(c.section, option); err == nil { 86 | return stripQuotes(r), true 87 | } 88 | return "", false 89 | } 90 | 91 | func (c *MergedConfig) StringDefault(option, dfault string) string { 92 | if r, found := c.String(option); found { 93 | return r 94 | } 95 | return dfault 96 | } 97 | 98 | func (c *MergedConfig) HasSection(section string) bool { 99 | return c.config.HasSection(section) 100 | } 101 | 102 | // Options returns all configuration option keys. 103 | // If a prefix is provided, then that is applied as a filter. 104 | func (c *MergedConfig) Options(prefix string) []string { 105 | var options []string 106 | keys, _ := c.config.Options(c.section) 107 | for _, key := range keys { 108 | if strings.HasPrefix(key, prefix) { 109 | options = append(options, key) 110 | } 111 | } 112 | return options 113 | } 114 | 115 | // Helpers 116 | // stripQuotes removes one pair of surrounding double quotes from s, if present. 117 | func stripQuotes(s string) string { 118 | if len(s) < 2 { // A lone `"` would otherwise be sliced as s[1:0] and panic. 119 | return s 120 | } 121 | 122 | if s[0] == '"' && s[len(s)-1] == '"' { 123 | return s[1 : len(s)-1] 124 | } 125 | 126 | return s 127 | } 128 | -------------------------------------------------------------------------------- /invoker_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "net/url" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | // These tests verify that Controllers are initialized properly, given the range 10 | // of embedding possibilities.
11 | 12 | type P struct{ *Controller } 13 | 14 | type PN struct{ P } 15 | 16 | type PNN struct{ PN } 17 | 18 | // Embedded via two paths 19 | type P2 struct{ *Controller } 20 | type PP2 struct { 21 | *Controller // Need to embed this explicitly to avoid duplicate selector. 22 | P 23 | P2 24 | PNN 25 | } 26 | 27 | var GENERATIONS = []interface{}{P{}, PN{}, PNN{}} 28 | 29 | func TestFindControllers(t *testing.T) { 30 | controllers = make(map[string]*ControllerType) 31 | RegisterController((*P)(nil), nil) 32 | RegisterController((*PN)(nil), nil) 33 | RegisterController((*PNN)(nil), nil) 34 | RegisterController((*PP2)(nil), nil) 35 | 36 | // Test construction of indexes to each *Controller 37 | checkSearchResults(t, P{}, [][]int{{0}}) 38 | checkSearchResults(t, PN{}, [][]int{{0, 0}}) 39 | checkSearchResults(t, PNN{}, [][]int{{0, 0, 0}}) 40 | checkSearchResults(t, PP2{}, [][]int{{0}, {1, 0}, {2, 0}, {3, 0, 0, 0}}) 41 | } 42 | // checkSearchResults reports an error unless findControllers returns exactly the expected index paths for obj's type. 43 | func checkSearchResults(t *testing.T, obj interface{}, expected [][]int) { 44 | actual := findControllers(reflect.TypeOf(obj)) 45 | if !reflect.DeepEqual(expected, actual) { 46 | t.Errorf("Indexes do not match. expected %v actual %v", expected, actual) 47 | } 48 | } 49 | 50 | func TestSetAction(t *testing.T) { 51 | controllers = make(map[string]*ControllerType) 52 | RegisterController((*P)(nil), []*MethodType{{Name: "Method"}}) 53 | RegisterController((*PNN)(nil), []*MethodType{{Name: "Method"}}) 54 | RegisterController((*PP2)(nil), []*MethodType{{Name: "Method"}}) 55 | 56 | // Test that all *mars.Controllers are initialized.
57 | c := &Controller{Name: "Test"} 58 | if err := c.SetAction("P", "Method"); err != nil { 59 | t.Error(err) 60 | } else if c.AppController.(*P).Controller != c { 61 | t.Errorf("P not initialized") 62 | } 63 | 64 | if err := c.SetAction("PNN", "Method"); err != nil { 65 | t.Error(err) 66 | } else if c.AppController.(*PNN).Controller != c { 67 | t.Errorf("PNN not initialized") 68 | } 69 | 70 | // PP2 has 4 different slots for *Controller. 71 | if err := c.SetAction("PP2", "Method"); err != nil { 72 | t.Error(err) 73 | } else if pp2 := c.AppController.(*PP2); pp2.Controller != c || 74 | pp2.P.Controller != c || 75 | pp2.P2.Controller != c || 76 | pp2.PNN.Controller != c { 77 | t.Errorf("PP2 not initialized") 78 | } 79 | } 80 | 81 | func BenchmarkSetAction(b *testing.B) { 82 | type Mixin1 struct { 83 | *Controller 84 | x, y int 85 | foo string 86 | } 87 | type Mixin2 struct { 88 | *Controller 89 | a, b float64 90 | bar string 91 | } 92 | 93 | type Benchmark struct { 94 | *Controller 95 | Mixin1 96 | Mixin2 97 | user interface{} 98 | guy string 99 | } 100 | 101 | RegisterController((*Mixin1)(nil), []*MethodType{{Name: "Method"}}) 102 | RegisterController((*Mixin2)(nil), []*MethodType{{Name: "Method"}}) 103 | RegisterController((*Benchmark)(nil), []*MethodType{{Name: "Method"}}) 104 | c := Controller{ 105 | RenderArgs: make(map[string]interface{}), 106 | } 107 | b.ResetTimer() // Keep controller registration out of the measured time. 108 | for i := 0; i < b.N; i++ { 109 | if err := c.SetAction("Benchmark", "Method"); err != nil { 110 | b.Errorf("Failed to set action: %s", err) 111 | return 112 | } 113 | } 114 | } 115 | // BenchmarkInvoker measures ActionInvoker on Hotels.Show with id=3 against the fake booking app. 116 | func BenchmarkInvoker(b *testing.B) { 117 | startFakeBookingApp() 118 | c := Controller{ 119 | RenderArgs: make(map[string]interface{}), 120 | } 121 | if err := c.SetAction("Hotels", "Show"); err != nil { 122 | b.Errorf("Failed to set action: %s", err) 123 | return 124 | } 125 | c.Request = NewRequest(showRequest) 126 | c.Params = &Params{Values: make(url.Values)} 127 | c.Params.Set("id", "3") 128 | 129 | b.ResetTimer()
130 | for i := 0; i < b.N; i++ { 131 | ActionInvoker(&c, nil) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /filterconfig_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import "testing" 4 | 5 | type FakeController struct{} 6 | 7 | func (c FakeController) Foo() {} 8 | func (c *FakeController) Bar() {} 9 | 10 | func TestFilterConfiguratorKey(t *testing.T) { 11 | conf := FilterController(FakeController{}) 12 | if conf.key != "FakeController" { 13 | t.Errorf("Expected key 'FakeController', was %s", conf.key) 14 | } 15 | 16 | conf = FilterController(&FakeController{}) 17 | if conf.key != "FakeController" { 18 | t.Errorf("Expected key 'FakeController', was %s", conf.key) 19 | } 20 | 21 | conf = FilterAction(FakeController.Foo) 22 | if conf.key != "FakeController.Foo" { 23 | t.Errorf("Expected key 'FakeController.Foo', was %s", conf.key) 24 | } 25 | 26 | conf = FilterAction((*FakeController).Bar) 27 | if conf.key != "FakeController.Bar" { 28 | t.Errorf("Expected key 'FakeController.Bar', was %s", conf.key) 29 | } 30 | } 31 | 32 | func TestFilterConfigurator(t *testing.T) { 33 | // Filters is global state. Restore it after this test. 34 | oldFilters := make([]Filter, len(Filters)) 35 | copy(oldFilters, Filters) 36 | defer func() { 37 | Filters = oldFilters 38 | }() 39 | 40 | Filters = []Filter{ 41 | RouterFilter, 42 | FilterConfiguringFilter, 43 | SessionFilter, 44 | FlashFilter, 45 | ActionInvoker, 46 | } 47 | 48 | // Do one of each operation. 49 | conf := FilterAction(FakeController.Foo). 50 | Add(NilFilter). 51 | Remove(FlashFilter). 52 | Insert(ValidationFilter, BEFORE, NilFilter).
53 | Insert(I18nFilter, AFTER, NilFilter) 54 | expected := []Filter{ 55 | SessionFilter, 56 | ValidationFilter, 57 | NilFilter, 58 | I18nFilter, 59 | ActionInvoker, 60 | } 61 | actual := getOverride("Foo") 62 | if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { 63 | t.Errorf("Ops failed.\nActual: %#v\nExpect: %#v\nConf:%v", actual, expected, conf) 64 | } 65 | 66 | // Action2 should be unchanged 67 | if getOverride("Bar") != nil { 68 | t.Errorf("Filtering Action should not affect Action2.") 69 | } 70 | 71 | // Test that combining overrides on both the Controller and Action works. 72 | FilterController(FakeController{}). 73 | Add(PanicFilter) 74 | expected = []Filter{ 75 | SessionFilter, 76 | ValidationFilter, 77 | NilFilter, 78 | I18nFilter, 79 | PanicFilter, 80 | ActionInvoker, 81 | } 82 | actual = getOverride("Foo") 83 | if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { 84 | t.Errorf("Expected PanicFilter added to Foo.\nActual: %#v\nExpect: %#v", actual, expected) 85 | } 86 | 87 | expected = []Filter{ 88 | SessionFilter, 89 | FlashFilter, 90 | PanicFilter, 91 | ActionInvoker, 92 | } 93 | actual = getOverride("Bar") 94 | if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { 95 | t.Errorf("Expected PanicFilter added to Bar.\nActual: %#v\nExpect: %#v", actual, expected) 96 | } 97 | 98 | FilterAction((*FakeController).Bar). 
99 | Add(NilFilter) 100 | expected = []Filter{ 101 | SessionFilter, 102 | ValidationFilter, 103 | NilFilter, 104 | I18nFilter, 105 | PanicFilter, 106 | ActionInvoker, 107 | } 108 | actual = getOverride("Foo") 109 | if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { 110 | t.Errorf("Expected no change to Foo.\nActual: %#v\nExpect: %#v", actual, expected) 111 | } 112 | 113 | expected = []Filter{ 114 | SessionFilter, 115 | FlashFilter, 116 | PanicFilter, 117 | NilFilter, 118 | ActionInvoker, 119 | } 120 | actual = getOverride("Bar") 121 | if len(actual) != len(expected) || !filterSliceEqual(actual, expected) { 122 | t.Errorf("Expected NilFilter added to Bar.\nActual: %#v\nExpect: %#v", actual, expected) 123 | } 124 | } 125 | 126 | func filterSliceEqual(a, e []Filter) bool { 127 | for i, f := range a { 128 | if !FilterEq(f, e[i]) { 129 | return false 130 | } 131 | } 132 | return true 133 | } 134 | 135 | func getOverride(methodName string) []Filter { 136 | return getOverrideChain("FakeController", "FakeController."+methodName) 137 | } 138 | -------------------------------------------------------------------------------- /params.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "mime/multipart" 5 | "net/url" 6 | "os" 7 | "reflect" 8 | ) 9 | 10 | // Params provides a unified view of the request params. 11 | // Includes: 12 | // - URL query string 13 | // - Form values 14 | // - File uploads 15 | // 16 | // Warning: param maps other than Values may be nil if there were none. 17 | type Params struct { 18 | url.Values // A unified view of all the individual param maps below. 19 | 20 | // Set by the router 21 | Fixed url.Values // Fixed parameters from the route, e.g. App.Action("fixed param") 22 | Route url.Values // Parameters extracted from the route, e.g. /customers/{id} 23 | 24 | // Set by the ParamsFilter 25 | Query url.Values // Parameters from the query string, e.g. 
/index?limit=10 26 | Form url.Values // Parameters from the request body. 27 | 28 | Files map[string][]*multipart.FileHeader // Files uploaded in a multipart form 29 | tmpFiles []*os.File // Temp files used during the request. 30 | } 31 | 32 | func ParseParams(params *Params, req *Request) { 33 | params.Query = req.URL.Query() 34 | 35 | // Parse the body depending on the content type. 36 | switch req.ContentType { 37 | case "application/x-www-form-urlencoded": 38 | // Typical form. 39 | if err := req.ParseForm(); err != nil { 40 | WARN.Println("Error parsing request body:", err) 41 | } else { 42 | params.Form = req.Form 43 | } 44 | 45 | case "multipart/form-data": 46 | // Multipart form. 47 | // TODO: Extract the multipart form param so app can set it. 48 | if err := req.ParseMultipartForm(32 << 20 /* 32 MB */); err != nil { 49 | WARN.Println("Error parsing request body:", err) 50 | } else { 51 | params.Form = req.MultipartForm.Value 52 | params.Files = req.MultipartForm.File 53 | } 54 | } 55 | 56 | params.Values = params.calcValues() 57 | } 58 | 59 | // Bind looks for the named parameter, converts it to the requested type, and 60 | // writes it into "dest", which must be settable. If the value can not be 61 | // parsed, "dest" is set to the zero value. 62 | func (p *Params) Bind(dest interface{}, name string) { 63 | value := reflect.ValueOf(dest) 64 | if value.Kind() != reflect.Ptr { 65 | panic("mars/params: non-pointer passed to Bind: " + name) 66 | } 67 | value = value.Elem() 68 | if !value.CanSet() { 69 | panic("mars/params: non-settable variable passed to Bind: " + name) 70 | } 71 | value.Set(Bind(p, name, value.Type())) 72 | } 73 | 74 | // calcValues returns a unified view of the component param maps. 75 | func (p *Params) calcValues() url.Values { 76 | numParams := len(p.Query) + len(p.Fixed) + len(p.Route) + len(p.Form) 77 | 78 | // If there were no params, return an empty map. 
79 | if numParams == 0 { 80 | return make(url.Values, 0) 81 | } 82 | 83 | // If only one of the param sources has anything, return that directly. Note that Values then shares its map with that source, so writes through one are visible through the other. 84 | switch numParams { 85 | case len(p.Query): 86 | return p.Query 87 | case len(p.Route): 88 | return p.Route 89 | case len(p.Fixed): 90 | return p.Fixed 91 | case len(p.Form): 92 | return p.Form 93 | } 94 | 95 | // Copy everything into the same map. For a key present in several sources, the values are appended in the order Fixed, Query, Route, Form. 96 | values := make(url.Values, numParams) 97 | for k, v := range p.Fixed { 98 | values[k] = append(values[k], v...) 99 | } 100 | for k, v := range p.Query { 101 | values[k] = append(values[k], v...) 102 | } 103 | for k, v := range p.Route { 104 | values[k] = append(values[k], v...) 105 | } 106 | for k, v := range p.Form { 107 | values[k] = append(values[k], v...) 108 | } 109 | return values 110 | } 111 | // ParamsFilter parses the request params via ParseParams, runs the remaining filters and afterwards removes the multipart temp files as well as any files recorded in Params.tmpFiles. 112 | func ParamsFilter(c *Controller, fc []Filter) { 113 | ParseParams(c.Params, c.Request) 114 | 115 | // Clean up from the request. 116 | defer func() { 117 | // Delete temp files. 118 | if c.Request.MultipartForm != nil { 119 | err := c.Request.MultipartForm.RemoveAll() 120 | if err != nil { 121 | WARN.Println("Error removing temporary files:", err) 122 | } 123 | } 124 | 125 | for _, tmpFile := range c.Params.tmpFiles { 126 | err := os.Remove(tmpFile.Name()) 127 | if err != nil { 128 | WARN.Println("Could not remove upload temp file:", err) 129 | } 130 | } 131 | }() 132 | 133 | fc[0](c, fc[1:]) 134 | } 135 | -------------------------------------------------------------------------------- /static.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | fpath "path/filepath" 7 | "reflect" 8 | "strings" 9 | "syscall" 10 | ) 11 | 12 | type Static struct { 13 | *Controller 14 | } 15 | 16 | func init() { 17 | RegisterController((*Static)(nil), 18 | []*MethodType{ 19 | { 20 | Name: "ServeFresh", 21 | Args: []*MethodArg{ 22 | {Name: "prefix", Type: reflect.TypeOf((*string)(nil))}, 23 | {Name: 
"filepath", Type: reflect.TypeOf((*string)(nil))}, 24 | }, 25 | }, 26 | { 27 | Name: "Serve", 28 | Args: []*MethodArg{ 29 | {Name: "prefix", Type: reflect.TypeOf((*string)(nil))}, 30 | {Name: "filepath", Type: reflect.TypeOf((*string)(nil))}, 31 | }, 32 | }, 33 | }, 34 | ) 35 | } 36 | 37 | // ServeFresh handles requests for files. The supplied prefix may be absolute 38 | // or relative. If the prefix is relative it is assumed to be relative to the 39 | // application directory. The filepath may either be just a file or an 40 | // additional filepath to search for the given file. Unlike Serve, no caching 41 | // header is sent. The following responses may be returned in the event of an error or an invalid request: 42 | // 403 (Forbidden): If the prefix and filepath combination results in a directory. 43 | // 404 (Not Found): If the prefix and filepath combination results in a non-existent file or in a path outside of the prefix directory. 44 | // 500 (Internal Server Error): There are a few edge cases that would likely indicate some configuration error outside of mars. 45 | // 46 | // Note that when defining routes in conf/routes the parameters must not have 47 | // spaces around the comma. 48 | // Bad: Static.Serve("public/img", "favicon.png") 49 | // Good: Static.Serve("public/img","favicon.png") 50 | // 51 | // Examples: 52 | // Serving a directory 53 | // Route (conf/routes): 54 | // GET /public/{<.*>filepath} Static.Serve("public") 55 | // Request: 56 | // public/js/sessvars.js 57 | // Calls: 58 | // Static.Serve("public","js/sessvars.js") 59 | // 60 | // Serving a file 61 | // Route (conf/routes): 62 | // GET /favicon.ico Static.Serve("public/img","favicon.png") 63 | // Request: 64 | // favicon.ico 65 | // Calls: 66 | // Static.Serve("public/img", "favicon.png") 67 | func (c Static) ServeFresh(prefix, filepath string) Result { 68 | // Fix for #503.
69 | prefix = c.Params.Fixed.Get("prefix") 70 | if prefix == "" { 71 | return c.NotFound("") 72 | } 73 | 74 | return serve(c, prefix, filepath, -1) 75 | } 76 | // Serve is like ServeFresh, but additionally sets a Cache-Control max-age header when MaxAge is at least one second. 77 | func (c Static) Serve(prefix, filepath string) Result { 78 | // Fix for #503. 79 | prefix = c.Params.Fixed.Get("prefix") 80 | if prefix == "" { 81 | return c.NotFound("") 82 | } 83 | 84 | return serve(c, prefix, filepath, int(MaxAge.Seconds())) 85 | } 86 | 87 | // serve allows static serving of application files in a verified manner. A maxAge > 0 adds a Cache-Control header to the response. 88 | func serve(c Static, prefix, filepath string, maxAge int) Result { 89 | var basePath string 90 | if !fpath.IsAbs(prefix) { 91 | basePath = BasePath 92 | } 93 | 94 | basePathPrefix := fpath.Join(basePath, fpath.FromSlash(prefix)) 95 | fname := fpath.Join(basePathPrefix, fpath.FromSlash(filepath)) 96 | // Verify the request file path is within the application's scope of access. A plain string-prefix check would also accept siblings such as "public2/secret" for the prefix "public", so compare path components via fpath.Rel instead. 97 | if rel, relErr := fpath.Rel(basePathPrefix, fname); relErr != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(fpath.Separator)) { 98 | WARN.Printf("Attempted to read file outside of base path: %s", fname) 99 | return c.NotFound("") 100 | } 101 | 102 | // Verify file path is accessible 103 | finfo, err := os.Stat(fname) 104 | if err != nil { 105 | if os.IsNotExist(err) || err.(*os.PathError).Err == syscall.ENOTDIR { 106 | WARN.Printf("File not found (%s): %s ", fname, err) 107 | return c.NotFound("File not found") 108 | } 109 | ERROR.Printf("Error trying to get fileinfo for '%s': %s", fname, err) 110 | return c.RenderError(err) 111 | } 112 | 113 | // Disallow directory listing 114 | if finfo.Mode().IsDir() { 115 | WARN.Printf("Attempted directory listing of %s", fname) 116 | return c.Forbidden("Directory listing not allowed") 117 | } 118 | 119 | // Open request file path 120 | file, err := os.Open(fname) 121 | if err != nil { 122 | if os.IsNotExist(err) { 123 | WARN.Printf("File not found (%s): %s ", fname, err) 124 | return c.NotFound("File not found") 125 | } 126 | ERROR.Printf("Error opening '%s': %s", fname, err) 127 | return c.RenderError(err) 128 | } 129 | 
if maxAge > 0 { 131 | c.Response.Out.Header().Add("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", maxAge)) 132 | } 133 | 134 | return c.RenderFile(file, Inline) 135 | } 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mars 2 | 3 | A lightweight web toolkit for the [Go language](http://www.golang.org). 4 | 5 | [![Go Reference](https://pkg.go.dev/badge/github.com/roblillack/mars.svg)](https://pkg.go.dev/github.com/roblillack/mars) 6 | [![Build status](https://github.com/roblillack/mars/actions/workflows/build-and-test.yml/badge.svg?branch=master)](https://github.com/roblillack/mars/actions) 7 | [![Documentation Status](https://readthedocs.org/projects/mars/badge/?version=latest)](http://mars.readthedocs.org/en/latest/?badge=latest) 8 | [![Coverage Status](https://coveralls.io/repos/github/roblillack/mars/badge.svg?branch=master)](https://coveralls.io/github/roblillack/mars?branch=master) 9 | [![Go Report Card](https://goreportcard.com/badge/github.com/roblillack/mars)](https://goreportcard.com/report/github.com/roblillack/mars) 10 | [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) 11 | 12 | - Latest Mars version: 1.1.0 (released May 1, 2022) 13 | - Supported Go versions: 1.13 … 1.23 14 | 15 | Mars is a fork of the fantastic, yet not-that-idiomatic-and-pretty-much-abandoned, [Revel framework](https://github.com/revel/revel). You might take a look at the corresponding documentation for the time being. 16 | 17 | **Have a question?** Head over to our [Discussions](https://github.com/roblillack/mars/discussions)! 💬 18 | 19 | ## Quick Start 20 | 21 | Getting started with Mars is as easy as: 22 | 23 | 1. Adding the package to your project 24 | 25 | ```sh 26 | $ go get github.com/roblillack/mars 27 | ``` 28 | 29 | 2. 
Creating an empty routes file in `conf/routes` 30 | 31 | ```sh 32 | $ mkdir conf; echo > conf/routes 33 | ``` 34 | 35 | 3. Running the server as part of your main package 36 | 37 | ```go 38 | package main 39 | 40 | import "github.com/roblillack/mars" 41 | 42 | func main() { 43 | mars.Run() 44 | } 45 | ``` 46 | 47 | This essentially sets up an insecure server as part of your application that listens to HTTP (only) and responds to all requests with a 404. To learn where to go from here, please see the [Mars tutorial](http://mars.readthedocs.io/en/latest/getting-started/). 48 | 49 | ## Differences to Revel 50 | 51 | The major changes since forking away from Revel are these: 52 | 53 | - More idiomatic approach to integrating the framework into your application: 54 | - No need to use the `revel` command to build, run, package, or distribute your app. 55 | - Code generation (for registering controllers and reverse routes) is supported using the standard `go generate` way. 56 | - No runtime dependencies anymore. Apps using Mars are truly standalone and do not need access to the sources at runtime (default templates and mime config are embedded assets). 57 | - You are not forced into a fixed directory layout or package names anymore. 58 | - Removed most of the "path magic" that tried to determine where the sources of your application and Revel are: No global `AppPath`, `ViewsPath`, `TemplatePaths`, `RevelPath`, and `SourcePath` variables anymore. 59 | - Added support for Go 1.5+ vendoring. 60 | - Vendored Mars' dependencies as Git submodules. 61 | - Added support for [HTTP dual-stack mode](https://github.com/roblillack/mars/issues/6). 62 | - Added support for [generating self-signed SSL certificates on-the-fly](https://github.com/roblillack/mars/issues/6). 63 | - Added [graceful shutdown](https://godoc.org/github.com/roblillack/mars#OnAppShutdown) functionality. 64 | - Added [CSRF protection](https://godoc.org/github.com/roblillack/mars#CSRFFilter).
65 | - Integrated `Static` controller to support hosting plain HTML files and assets. 66 | - Removed magic that automatically added template parameter names based on variable names in `Controller.Render()` calls using code generation and runtime introspection. 67 | - Removed the cache library. 68 | - Removed module support. 69 | - Removed support for configurable template delimiters. 70 | - Corrected case of render functions (`RenderXml` --> `RenderXML`). 71 | - Fixed generating reverse routes for some edge cases: an action parameter named `args` or an action parameter of type `interface{}`. 72 | - Fixed an [XSS vulnerability](https://github.com/roblillack/mars/issues/1). 73 | 74 | ## Documentation 75 | 76 | - [Getting started with Mars](http://mars.readthedocs.io/en/latest/getting-started/) 77 | - [Moving from Revel to Mars in 7 steps](http://mars.readthedocs.io/en/latest/migration/) 78 | 79 | ## Links 80 | 81 | - [Code Coverage](http://gocover.io/github.com/roblillack/mars) 82 | - [Go Report Card](http://goreportcard.com/report/roblillack/mars) 83 | - [GoDoc](https://godoc.org/github.com/roblillack/mars) 84 | -------------------------------------------------------------------------------- /csrf.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/base64" 6 | "fmt" 7 | "html/template" 8 | "net/http" 9 | "time" 10 | ) 11 | 12 | const csrfCookieKey = "_csrf" 13 | const csrfCookieName = "CSRF" 14 | const csrfHeaderName = "X-CSRF-Token" 15 | const csrfFieldName = "_csrf_token" 16 | 17 | func isSafeMethod(c *Controller) bool { 18 | // Methods deemed safe as defined in RFC 7231, section 4.2.1. 19 | // TODO: We might think about adding the two other idempotent methods here, too.
20 | for _, i := range []string{"GET", "HEAD", "OPTIONS", "TRACE"} { 21 | if c.Request.Method == i { 22 | return true 23 | } 24 | 25 | } 26 | 27 | return false 28 | } 29 | 30 | func findCSRFToken(c *Controller) string { 31 | if h := c.Request.Header.Get(csrfHeaderName); h != "" { 32 | return h 33 | } 34 | 35 | if f := c.Params.Get(csrfFieldName); f != "" { 36 | return f 37 | } 38 | 39 | return "" 40 | } 41 | 42 | func generateRandomToken() string { 43 | buf := make([]byte, 16) // 16 random bytes encode to 22 unpadded base64 characters, the token length CSRFFilter expects. 44 | if _, err := rand.Read(buf); err != nil { 45 | ERROR.Printf("Error generating random CSRF token: %s\n", err) 46 | return "" 47 | } 48 | 49 | return base64.RawURLEncoding.EncodeToString(buf) 50 | } 51 | 52 | // CSRFFilter protects against attacks known as "Cross-site request 53 | // forgery" (CSRF) by providing multiple ways in which the frontend of 54 | // the application can prove that a mutating request to the server was 55 | // actually initiated by said frontend and not by an attacker that 56 | // lured the user into calling unwanted actions on your site. 57 | // 58 | // A random CSRF token is added to the signed session (as key `_csrf`) 59 | // and an additional Cookie (which can be read using JavaScript) called 60 | // `XXX_CSRF`, where `XXX` is the configured CookiePrefix. The token is also available to the template engine as 61 | // `{{.csrfToken}}` or as a ready-made, hidden form field using 62 | // `{{.csrfField}}`. 63 | // 64 | // For each HTTP request not deemed safe according to RFC 7231, 65 | // section 4.2.1, one of these methods MUST be used for the server to 66 | // ascertain that the user actually asked to call this action in the 67 | // first place: 68 | // 69 | // a) The token is sent using a custom header `X-CSRF-Token` with the 70 | // request. This is very useful for single-page applications and AJAX 71 | // requests, as most frontend toolkits can be set up to include this 72 | // header if needed.
An example for jQuery (added to the footer of each 73 | // page) could look like this: 74 | // 75 | // 88 | // 89 | // b) The token is sent as a form field value for forms using non-safe 90 | // actions. Simply adding `{{.csrfField}}` should be enough. 91 | // 92 | // To disable CSRF protection for individual actions or controllers 93 | // (e.g. API calls that authenticate using HTTP Basic Auth or access tokens, 94 | // etc.), add an InterceptorMethod to your Controller that sets 95 | // Controller.SkipCSRF to `true` for said requests. Setting the package-level DisableCSRF variable to `true` bypasses this filter for all requests. 96 | // 97 | // See also: 98 | // https://tools.ietf.org/html/rfc7231#section-4.2.1 99 | func CSRFFilter(c *Controller, fc []Filter) { 100 | if DisableCSRF { 101 | fc[0](c, fc[1:]) 102 | return 103 | } 104 | 105 | csrfToken := c.Session[csrfCookieKey] 106 | if len(csrfToken) != 22 { 107 | csrfToken = generateRandomToken() 108 | c.Session[csrfCookieKey] = csrfToken 109 | } 110 | 111 | c.SetCookie(&http.Cookie{ 112 | Name: fmt.Sprintf("%s_%s", CookiePrefix, csrfCookieName), 113 | Value: csrfToken, 114 | Domain: CookieDomain, 115 | Path: "/", 116 | HttpOnly: false, 117 | Secure: CookieSecure, 118 | Expires: time.Now().Add(12 * time.Hour).UTC(), 119 | }) 120 | c.RenderArgs["csrfToken"] = csrfToken 121 | c.RenderArgs["csrfField"] = template.HTML(``) 122 | 123 | if !isSafeMethod(c) && !c.SkipCSRF { 124 | token := findCSRFToken(c) 125 | if token == "" || token != csrfToken { 126 | c.Result = c.Forbidden("No/wrong CSRF token given.") 127 | return 128 | } 129 | } 130 | 131 | fc[0](c, fc[1:]) 132 | } 133 | -------------------------------------------------------------------------------- /internal/watcher/watcher.go: -------------------------------------------------------------------------------- 1 | package watcher 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "path/filepath" 8 | "strings" 9 | "sync" 10 | 11 | "github.com/fsnotify/fsnotify" 12 | ) 13 | 14 | // Listener is an interface for receivers of filesystem events.
15 | type Listener interface { 16 | // Refresh is invoked by the watcher on relevant filesystem events. 17 | // If the listener returns an error, it is served to the user on the current request. 18 | Refresh() error 19 | } 20 | 21 | // Watcher allows listeners to register to be notified of changes under a given 22 | // directory. 23 | type Watcher struct { 24 | // Parallel arrays of watcher/listener pairs. 25 | watchers []*fsnotify.Watcher 26 | listeners []Listener 27 | lastError int 28 | notifyMutex sync.Mutex 29 | } 30 | 31 | func New() *Watcher { 32 | return &Watcher{ 33 | // forceRefresh: true, 34 | lastError: -1, 35 | } 36 | } 37 | 38 | // Listen registers for events within the given root directories (recursively). 39 | func (w *Watcher) Listen(listener Listener, roots ...string) error { 40 | watcher, err := fsnotify.NewWatcher() 41 | if err != nil { 42 | return err 43 | } 44 | 45 | // Replace the unbuffered Event channel with a buffered one. 46 | // Otherwise multiple change events only come out one at a time, across 47 | // multiple page views. (There appears no way to "pump" the events out of 48 | // the watcher) 49 | watcher.Events = make(chan fsnotify.Event, 100) 50 | watcher.Errors = make(chan error, 10) 51 | 52 | // Walk through all files / directories under the root, adding each to watcher. 53 | for _, p := range roots { 54 | // is the directory / file a symlink? 55 | f, err := os.Lstat(p) 56 | if err == nil && f.Mode()&os.ModeSymlink == os.ModeSymlink { 57 | realPath, err := filepath.EvalSymlinks(p) 58 | if err != nil { 59 | panic(err) 60 | } 61 | p = realPath 62 | } 63 | 64 | fi, err := os.Stat(p) 65 | if err != nil { 66 | return fmt.Errorf("Failed to stat watched path %s: %w", p, err) 67 | } 68 | 69 | // If it is a file, watch that specific file. 
70 | if !fi.IsDir() { 71 | err = watcher.Add(p) 72 | if err != nil { 73 | return fmt.Errorf("Failed to watch %s: %w", p, err) 74 | } 75 | continue 76 | } 77 | 78 | var watcherWalker func(path string, info os.FileInfo, err error) error 79 | 80 | watcherWalker = func(path string, info os.FileInfo, err error) error { 81 | if err != nil { 82 | return err 83 | } 84 | 85 | // is it a symlinked template? 86 | link, err := os.Lstat(path) 87 | if err == nil && link.Mode()&os.ModeSymlink == os.ModeSymlink { 88 | // lookup the actual target & check for goodness 89 | targetPath, err := filepath.EvalSymlinks(path) 90 | if err != nil { 91 | return fmt.Errorf("failed to read symlink %s: %w", path, err) 92 | } 93 | targetInfo, err := os.Stat(targetPath) 94 | if err != nil { 95 | return fmt.Errorf("failed to stat symlink target %s of %s: %w", targetPath, path, err) 96 | } 97 | 98 | // set the template path to the target of the symlink 99 | path = targetPath 100 | info = targetInfo 101 | if err := filepath.Walk(path, watcherWalker); err != nil { 102 | return err 103 | } 104 | } 105 | 106 | if info.IsDir() { 107 | if err := watcher.Add(path); err != nil { 108 | return err 109 | } 110 | } 111 | return nil 112 | } 113 | 114 | // Else, walk the directory tree. 115 | if err := filepath.Walk(p, watcherWalker); err != nil { 116 | return fmt.Errorf("error walking path %s: %w", p, err) 117 | } 118 | } 119 | 120 | w.watchers = append(w.watchers, watcher) 121 | w.listeners = append(w.listeners, listener) 122 | 123 | return nil 124 | } 125 | 126 | // Notify causes the watcher to forward any change events to listeners. 127 | // It returns the first (if any) error returned. 128 | func (w *Watcher) Notify() error { 129 | // Serialize Notify() calls. 130 | w.notifyMutex.Lock() 131 | defer w.notifyMutex.Unlock() 132 | 133 | for idx, watcher := range w.watchers { 134 | listener := w.listeners[idx] 135 | 136 | // Pull all pending events / errors from the watcher. 
137 | refresh := false 138 | for { 139 | select { 140 | case ev := <-watcher.Events: 141 | // Ignore changes to dotfiles. 142 | if !strings.HasPrefix(path.Base(ev.Name), ".") { 143 | refresh = true 144 | } 145 | continue 146 | case <-watcher.Errors: 147 | continue 148 | default: 149 | // No events left to pull 150 | } 151 | break 152 | } 153 | 154 | if refresh || w.lastError == idx { 155 | err := listener.Refresh() 156 | if err != nil { 157 | w.lastError = idx 158 | return err 159 | } 160 | } 161 | } 162 | 163 | w.lastError = -1 164 | return nil 165 | } 166 | -------------------------------------------------------------------------------- /validators.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "regexp" 7 | "time" 8 | ) 9 | 10 | type Validator interface { 11 | IsSatisfied(interface{}) bool 12 | DefaultMessage() string 13 | } 14 | 15 | type Required struct{} 16 | 17 | func ValidRequired() Required { 18 | return Required{} 19 | } 20 | 21 | func (r Required) IsSatisfied(obj interface{}) bool { 22 | if obj == nil { 23 | return false 24 | } 25 | 26 | if str, ok := obj.(string); ok { 27 | return len(str) > 0 28 | } 29 | if b, ok := obj.(bool); ok { 30 | return b 31 | } 32 | if i, ok := obj.(int); ok { 33 | return i != 0 34 | } 35 | if t, ok := obj.(time.Time); ok { 36 | return !t.IsZero() 37 | } 38 | v := reflect.ValueOf(obj) 39 | if v.Kind() == reflect.Slice { 40 | return v.Len() > 0 41 | } 42 | return true 43 | } 44 | 45 | func (r Required) DefaultMessage() string { 46 | return "Required" 47 | } 48 | 49 | type Min struct { 50 | Min int 51 | } 52 | 53 | func ValidMin(min int) Min { 54 | return Min{min} 55 | } 56 | 57 | func (m Min) IsSatisfied(obj interface{}) bool { 58 | num, ok := obj.(int) 59 | if ok { 60 | return num >= m.Min 61 | } 62 | return false 63 | } 64 | 65 | func (m Min) DefaultMessage() string { 66 | return fmt.Sprintln("Minimum is", m.Min) 67 | } 68 | 69 
| type Max struct { 70 | Max int 71 | } 72 | 73 | func ValidMax(max int) Max { 74 | return Max{max} 75 | } 76 | 77 | func (m Max) IsSatisfied(obj interface{}) bool { 78 | num, ok := obj.(int) 79 | if ok { 80 | return num <= m.Max 81 | } 82 | return false 83 | } 84 | 85 | func (m Max) DefaultMessage() string { 86 | return fmt.Sprintln("Maximum is", m.Max) 87 | } 88 | 89 | // Requires an integer to be within Min, Max inclusive. 90 | type Range struct { 91 | Min 92 | Max 93 | } 94 | 95 | func ValidRange(min, max int) Range { 96 | return Range{Min{min}, Max{max}} 97 | } 98 | 99 | func (r Range) IsSatisfied(obj interface{}) bool { 100 | return r.Min.IsSatisfied(obj) && r.Max.IsSatisfied(obj) 101 | } 102 | 103 | func (r Range) DefaultMessage() string { 104 | return fmt.Sprintln("Range is", r.Min.Min, "to", r.Max.Max) 105 | } 106 | 107 | // Requires an array or string to be at least a given length. 108 | type MinSize struct { 109 | Min int 110 | } 111 | 112 | func ValidMinSize(min int) MinSize { 113 | return MinSize{min} 114 | } 115 | 116 | func (m MinSize) IsSatisfied(obj interface{}) bool { 117 | if str, ok := obj.(string); ok { 118 | return len(str) >= m.Min 119 | } 120 | v := reflect.ValueOf(obj) 121 | if v.Kind() == reflect.Slice { 122 | return v.Len() >= m.Min 123 | } 124 | return false 125 | } 126 | 127 | func (m MinSize) DefaultMessage() string { 128 | return fmt.Sprintln("Minimum size is", m.Min) 129 | } 130 | 131 | // Requires an array or string to be at most a given length. 
132 | type MaxSize struct { 133 | Max int 134 | } 135 | 136 | func ValidMaxSize(max int) MaxSize { 137 | return MaxSize{max} 138 | } 139 | 140 | func (m MaxSize) IsSatisfied(obj interface{}) bool { 141 | if str, ok := obj.(string); ok { 142 | return len(str) <= m.Max 143 | } 144 | v := reflect.ValueOf(obj) 145 | if v.Kind() == reflect.Slice { 146 | return v.Len() <= m.Max 147 | } 148 | return false 149 | } 150 | 151 | func (m MaxSize) DefaultMessage() string { 152 | return fmt.Sprintln("Maximum size is", m.Max) 153 | } 154 | 155 | // Requires an array or string to be exactly a given length. 156 | type Length struct { 157 | N int 158 | } 159 | 160 | func ValidLength(n int) Length { 161 | return Length{n} 162 | } 163 | 164 | func (s Length) IsSatisfied(obj interface{}) bool { 165 | if str, ok := obj.(string); ok { 166 | return len(str) == s.N 167 | } 168 | v := reflect.ValueOf(obj) 169 | if v.Kind() == reflect.Slice { 170 | return v.Len() == s.N 171 | } 172 | return false 173 | } 174 | 175 | func (s Length) DefaultMessage() string { 176 | return fmt.Sprintln("Required length is", s.N) 177 | } 178 | 179 | // Requires a string to match a given regex. 
180 | type Match struct { 181 | Regexp *regexp.Regexp 182 | } 183 | 184 | func ValidMatch(regex *regexp.Regexp) Match { 185 | return Match{regex} 186 | } 187 | // IsSatisfied reports whether obj is a string matching m.Regexp. Values of any other type are reported as not satisfied instead of causing a panic. 188 | func (m Match) IsSatisfied(obj interface{}) bool { 189 | str, ok := obj.(string) 190 | return ok && m.Regexp.MatchString(str) 191 | } 192 | 193 | func (m Match) DefaultMessage() string { 194 | return fmt.Sprintln("Must match", m.Regexp) 195 | } 196 | 197 | var emailPattern = regexp.MustCompile("^[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?$") 198 | 199 | type Email struct { 200 | Match 201 | } 202 | 203 | func ValidEmail() Email { 204 | return Email{Match{emailPattern}} 205 | } 206 | 207 | func (e Email) DefaultMessage() string { 208 | return fmt.Sprintln("Must be a valid email address") 209 | } 210 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "io" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "sync" 11 | "syscall" 12 | "time" 13 | 14 | "golang.org/x/net/websocket" 15 | 16 | "github.com/roblillack/mars/internal/watcher" 17 | ) 18 | 19 | var ( 20 | MainRouter *Router 21 | MainTemplateLoader *TemplateLoader 22 | mainWatcher *watcher.Watcher 23 | Server *http.Server 24 | SecureServer *http.Server 25 | ) 26 | 27 | // Handler is a http.HandlerFunc which exposes Mars' filtering, routing, and 28 | // interception functionality for you to use with custom HTTP servers. 29 | var Handler = http.HandlerFunc(handle) 30 | 31 | // This method handles all requests. It dispatches to handleInternal after 32 | // handling / adapting websocket connections.
33 | func handle(w http.ResponseWriter, r *http.Request) { 34 | if maxRequestSize := int64(Config.IntDefault("http.maxrequestsize", 0)); maxRequestSize > 0 { 35 | r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) 36 | } 37 | 38 | upgrade := r.Header.Get("Upgrade") 39 | if upgrade == "websocket" || upgrade == "Websocket" { 40 | websocket.Handler(func(ws *websocket.Conn) { 41 | //Override default Read/Write timeout with sane value for a web socket request 42 | ws.SetDeadline(time.Now().Add(time.Hour * 24)) 43 | r.Method = "WS" 44 | handleInternal(w, r, ws) 45 | }).ServeHTTP(w, r) 46 | } else { 47 | handleInternal(w, r, nil) 48 | } 49 | } 50 | 51 | func handleInternal(w http.ResponseWriter, r *http.Request, ws *websocket.Conn) { 52 | var ( 53 | req = NewRequest(r) 54 | resp = NewResponse(w) 55 | c = NewController(req, resp) 56 | ) 57 | req.Websocket = ws 58 | 59 | Filters[0](c, Filters[1:]) 60 | if c.Result != nil { 61 | c.Result.Apply(req, resp) 62 | } else if c.Response.Status != 0 { 63 | c.Response.Out.WriteHeader(c.Response.Status) 64 | } 65 | // Close the Writer if we can 66 | if w, ok := resp.Out.(io.Closer); ok { 67 | w.Close() 68 | } 69 | } 70 | 71 | func makeServer(addr string) *http.Server { 72 | return &http.Server{ 73 | Addr: addr, 74 | Handler: Handler, 75 | ReadTimeout: time.Duration(Config.IntDefault("timeout.read", 0)) * time.Second, 76 | WriteTimeout: time.Duration(Config.IntDefault("timeout.write", 0)) * time.Second, 77 | } 78 | } 79 | 80 | func initGracefulShutdown() { 81 | stop := make(chan os.Signal, 1) 82 | signal.Notify(stop, os.Interrupt, syscall.SIGTERM) 83 | 84 | go func() { 85 | <-stop 86 | INFO.Println("Shutting down listeners ...") 87 | 88 | ctx := context.Background() 89 | if timeout := Config.IntDefault("timeout.shutdown", 0); timeout != 0 { 90 | newCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) 91 | ctx = newCtx 92 | defer cancel() 93 | } 94 | 95 | if SecureServer != nil { 96 | if err := 
SecureServer.Shutdown(ctx); err != nil { 97 | ERROR.Println(err) 98 | } 99 | } 100 | if Server != nil { 101 | if err := Server.Shutdown(ctx); err != nil { 102 | ERROR.Println(err) 103 | } 104 | } 105 | }() 106 | } 107 | 108 | func Run() { 109 | if !setupDone { 110 | Setup() 111 | } 112 | 113 | if DevMode { 114 | INFO.Printf("Development mode enabled.") 115 | } 116 | 117 | wg := sync.WaitGroup{} 118 | initializeFallbacks() 119 | initGracefulShutdown() 120 | 121 | if !HttpSsl || DualStackHTTP { 122 | go func() { 123 | time.Sleep(100 * time.Millisecond) 124 | INFO.Printf("Listening on %s (HTTP) ...\n", HttpAddr) 125 | }() 126 | 127 | wg.Add(1) 128 | go func() { 129 | defer wg.Done() 130 | 131 | Server = makeServer(HttpAddr) 132 | if err := Server.ListenAndServe(); err != nil && err != http.ErrServerClosed { 133 | ERROR.Fatalln("Failed to serve:", err) 134 | } 135 | }() 136 | } 137 | 138 | if HttpSsl || DualStackHTTP { 139 | go func() { 140 | time.Sleep(100 * time.Millisecond) 141 | INFO.Printf("Listening on %s (HTTPS) ...\n", SSLAddr) 142 | }() 143 | 144 | wg.Add(1) 145 | go func() { 146 | defer wg.Done() 147 | 148 | serveTLS(SSLAddr) 149 | }() 150 | } 151 | 152 | wg.Wait() 153 | 154 | runShutdownHooks() 155 | } 156 | 157 | func serveTLS(addr string) { 158 | SecureServer = makeServer(addr) 159 | 160 | SecureServer.TLSConfig = &tls.Config{ 161 | Certificates: make([]tls.Certificate, 1), 162 | } 163 | if SelfSignedCert { 164 | keypair, err := createCertificate(SelfSignedOrganization, SelfSignedDomains) 165 | if err != nil { 166 | ERROR.Fatalln("Unable to create key pair:", err) 167 | } 168 | SecureServer.TLSConfig.Certificates[0] = keypair 169 | } else { 170 | keypair, err := tls.LoadX509KeyPair(HttpSslCert, HttpSslKey) 171 | if err != nil { 172 | ERROR.Fatalln("Unable to load key pair:", err) 173 | } 174 | SecureServer.TLSConfig.Certificates[0] = keypair 175 | } 176 | 177 | if err := SecureServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { 
178 | ERROR.Fatalln("Failed to serve:", err) 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "sort" 8 | "strconv" 9 | "strings" 10 | 11 | "golang.org/x/net/websocket" 12 | ) 13 | 14 | type Request struct { 15 | *http.Request 16 | ContentType string 17 | Format string // "html", "xml", "json", or "txt" 18 | AcceptLanguages AcceptLanguages 19 | Locale string 20 | Websocket *websocket.Conn 21 | } 22 | 23 | type Response struct { 24 | Status int 25 | ContentType string 26 | 27 | Out http.ResponseWriter 28 | } 29 | 30 | func NewResponse(w http.ResponseWriter) *Response { 31 | return &Response{Out: w} 32 | } 33 | 34 | func NewRequest(r *http.Request) *Request { 35 | return &Request{ 36 | Request: r, 37 | ContentType: ResolveContentType(r), 38 | Format: ResolveFormat(r), 39 | AcceptLanguages: ResolveAcceptLanguage(r), 40 | } 41 | } 42 | 43 | // Write the header (for now, just the status code). 44 | // The status may be set directly by the application (c.Response.Status = 501). 45 | // if it isn't, then fall back to the provided status code. 46 | func (resp *Response) WriteHeader(defaultStatusCode int, defaultContentType string) { 47 | if resp.Status == 0 { 48 | resp.Status = defaultStatusCode 49 | } 50 | if resp.ContentType == "" { 51 | resp.ContentType = defaultContentType 52 | } 53 | resp.Out.Header().Set("Content-Type", resp.ContentType) 54 | resp.Out.WriteHeader(resp.Status) 55 | } 56 | 57 | // Get the content type. 58 | // e.g. From "multipart/form-data; boundary=--" to "multipart/form-data" 59 | // If none is specified, returns "text/html" by default. 
func ResolveContentType(req *http.Request) string {
	raw := req.Header.Get("Content-Type")
	if raw == "" {
		return "text/html"
	}
	mediaType := raw
	if semi := strings.IndexByte(mediaType, ';'); semi >= 0 {
		mediaType = mediaType[:semi]
	}
	return strings.ToLower(strings.TrimSpace(mediaType))
}

// ResolveFormat maps the request's Accept header to one of "html", "json",
// "xml" or "txt". The checks are made in that order, so an Accept header that
// mentions both HTML and JSON resolves to "html". An empty header, one starting
// with the */* wildcard, or one matching none of the known types yields "html".
func ResolveFormat(req *http.Request) string {
	accept := req.Header.Get("accept")

	mentions := func(mimeTypes ...string) bool {
		for _, mimeType := range mimeTypes {
			if strings.Contains(accept, mimeType) {
				return true
			}
		}
		return false
	}

	if accept == "" || strings.HasPrefix(accept, "*/*") || mentions("application/xhtml", "text/html") {
		return "html"
	}
	if mentions("application/json", "text/javascript") {
		return "json"
	}
	if mentions("application/xml", "text/xml") {
		return "xml"
	}
	if mentions("text/plain") {
		return "txt"
	}
	return "html"
}

// AcceptLanguage is one language range from an Accept-Language header together
// with its quality value.
type AcceptLanguage struct {
	Language string
	Quality  float32
}

// AcceptLanguages is a list of AcceptLanguage values implementing
// sort.Interface, ordered by descending Quality.
type AcceptLanguages []AcceptLanguage

func (al AcceptLanguages) Len() int           { return len(al) }
func (al AcceptLanguages) Swap(i, j int)      { al[i], al[j] = al[j], al[i] }
func (al AcceptLanguages) Less(i, j int) bool { return al[i].Quality > al[j].Quality }

// String renders the list as "lang (q), lang (q), ..." with one decimal place
// for each quality.
func (al AcceptLanguages) String() string {
	var buf bytes.Buffer
	for idx, entry := range al {
		if idx > 0 {
			buf.WriteString(", ")
		}
		fmt.Fprintf(&buf, "%s (%1.1f)", entry.Language, entry.Quality)
	}
	return buf.String()
}

// ResolveAcceptLanguage parses the request's Accept-Language header into an
// AcceptLanguages list sorted by descending quality, so the most preferred
// language range comes first. It returns nil if the header is absent.
//
// See RFC 7231, section 5.3.5 (formerly RFC 2616, section 14.4) for the header
// syntax.
126 | func ResolveAcceptLanguage(req *http.Request) AcceptLanguages { 127 | header := req.Header.Get("Accept-Language") 128 | if header == "" { 129 | return nil 130 | } 131 | 132 | acceptLanguageHeaderValues := strings.Split(header, ",") 133 | acceptLanguages := make(AcceptLanguages, len(acceptLanguageHeaderValues)) 134 | 135 | for i, languageRange := range acceptLanguageHeaderValues { 136 | if qualifiedRange := strings.Split(languageRange, ";q="); len(qualifiedRange) == 2 { 137 | quality, error := strconv.ParseFloat(qualifiedRange[1], 32) 138 | if error != nil { 139 | WARN.Printf("Detected malformed Accept-Language header quality in '%s', assuming quality is 1", removeLineBreaks(languageRange)) 140 | acceptLanguages[i] = AcceptLanguage{qualifiedRange[0], 1} 141 | } else { 142 | acceptLanguages[i] = AcceptLanguage{qualifiedRange[0], float32(quality)} 143 | } 144 | } else { 145 | acceptLanguages[i] = AcceptLanguage{languageRange, 1} 146 | } 147 | } 148 | 149 | sort.Sort(acceptLanguages) 150 | return acceptLanguages 151 | } 152 | -------------------------------------------------------------------------------- /params_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "net/http" 8 | "net/url" 9 | "reflect" 10 | "testing" 11 | ) 12 | 13 | // Params: Testing Multipart forms 14 | 15 | const ( 16 | MultipartBoundary = "A" 17 | MultipartFormData = `--A 18 | Content-Disposition: form-data; name="text1" 19 | 20 | data1 21 | --A 22 | Content-Disposition: form-data; name="text2" 23 | 24 | data2 25 | --A 26 | Content-Disposition: form-data; name="text2" 27 | 28 | data3 29 | --A 30 | Content-Disposition: form-data; name="file1"; filename="test.txt" 31 | Content-Type: text/plain 32 | 33 | content1 34 | --A 35 | Content-Disposition: form-data; name="file2[]"; filename="test.txt" 36 | Content-Type: text/plain 37 | 38 | content2 39 | --A 40 | Content-Disposition: 
form-data; name="file2[]"; filename="favicon.ico" 41 | Content-Type: image/x-icon 42 | 43 | xyz 44 | --A 45 | Content-Disposition: form-data; name="file3[0]"; filename="test.txt" 46 | Content-Type: text/plain 47 | 48 | content3 49 | --A 50 | Content-Disposition: form-data; name="file3[1]"; filename="favicon.ico" 51 | Content-Type: image/x-icon 52 | 53 | zzz 54 | --A-- 55 | ` 56 | ) 57 | 58 | // The values represented by the form data. 59 | type fh struct { 60 | filename string 61 | content []byte 62 | } 63 | 64 | var ( 65 | expectedValues = map[string][]string{ 66 | "text1": {"data1"}, 67 | "text2": {"data2", "data3"}, 68 | } 69 | expectedFiles = map[string][]fh{ 70 | "file1": {fh{"test.txt", []byte("content1")}}, 71 | "file2[]": {fh{"test.txt", []byte("content2")}, fh{"favicon.ico", []byte("xyz")}}, 72 | "file3[0]": {fh{"test.txt", []byte("content3")}}, 73 | "file3[1]": {fh{"favicon.ico", []byte("zzz")}}, 74 | } 75 | ) 76 | 77 | func getMultipartRequest() *http.Request { 78 | req, _ := http.NewRequest("POST", "http://localhost/path", 79 | bytes.NewBufferString(MultipartFormData)) 80 | req.Header.Set( 81 | "Content-Type", fmt.Sprintf("multipart/form-data; boundary=%s", MultipartBoundary)) 82 | req.Header.Set( 83 | "Content-Length", fmt.Sprintf("%d", len(MultipartFormData))) 84 | return req 85 | } 86 | 87 | func BenchmarkParams(b *testing.B) { 88 | c := Controller{ 89 | Request: NewRequest(getMultipartRequest()), 90 | Params: &Params{}, 91 | } 92 | for i := 0; i < b.N; i++ { 93 | ParamsFilter(&c, NilChain) 94 | } 95 | } 96 | 97 | func TestMultipartForm(t *testing.T) { 98 | c := Controller{ 99 | Request: NewRequest(getMultipartRequest()), 100 | Params: &Params{}, 101 | } 102 | ParamsFilter(&c, NilChain) 103 | 104 | if !reflect.DeepEqual(expectedValues, map[string][]string(c.Params.Values)) { 105 | t.Errorf("Param values: (expected) %v != %v (actual)", 106 | expectedValues, map[string][]string(c.Params.Values)) 107 | } 108 | 109 | actualFiles := make(map[string][]fh) 
110 | for key, fileHeaders := range c.Params.Files { 111 | for _, fileHeader := range fileHeaders { 112 | file, _ := fileHeader.Open() 113 | content, _ := ioutil.ReadAll(file) 114 | actualFiles[key] = append(actualFiles[key], fh{fileHeader.Filename, content}) 115 | } 116 | } 117 | 118 | if !reflect.DeepEqual(expectedFiles, actualFiles) { 119 | t.Errorf("Param files: (expected) %v != %v (actual)", expectedFiles, actualFiles) 120 | } 121 | } 122 | 123 | func TestBind(t *testing.T) { 124 | params := Params{ 125 | Values: url.Values{ 126 | "x": {"5"}, 127 | }, 128 | } 129 | var x int 130 | params.Bind(&x, "x") 131 | if x != 5 { 132 | t.Errorf("Failed to bind x. Value: %d", x) 133 | } 134 | } 135 | 136 | func TestResolveAcceptLanguage(t *testing.T) { 137 | request := buildHttpRequestWithAcceptLanguage("") 138 | if result := ResolveAcceptLanguage(request); result != nil { 139 | t.Errorf("Expected Accept-Language to resolve to an empty string but it was '%s'", result) 140 | } 141 | 142 | request = buildHttpRequestWithAcceptLanguage("en-GB,en;q=0.8,nl;q=0.6") 143 | if result := ResolveAcceptLanguage(request); len(result) != 3 { 144 | t.Errorf("Unexpected Accept-Language values length of %d (expected %d)", len(result), 3) 145 | } else { 146 | if result[0].Language != "en-GB" { 147 | t.Errorf("Expected '%s' to be most qualified but instead it's '%s'", "en-GB", result[0].Language) 148 | } 149 | if result[1].Language != "en" { 150 | t.Errorf("Expected '%s' to be most qualified but instead it's '%s'", "en", result[1].Language) 151 | } 152 | if result[2].Language != "nl" { 153 | t.Errorf("Expected '%s' to be most qualified but instead it's '%s'", "nl", result[2].Language) 154 | } 155 | } 156 | 157 | request = buildHttpRequestWithAcceptLanguage("en;q=0.8,nl;q=0.6,en-AU;q=malformed") 158 | if result := ResolveAcceptLanguage(request); len(result) != 3 { 159 | t.Errorf("Unexpected Accept-Language values length of %d (expected %d)", len(result), 3) 160 | } else { 161 | if 
result[0].Language != "en-AU" { 162 | t.Errorf("Expected '%s' to be most qualified but instead it's '%s'", "en-AU", result[0].Language) 163 | } 164 | } 165 | } 166 | 167 | func BenchmarkResolveAcceptLanguage(b *testing.B) { 168 | for i := 0; i < b.N; i++ { 169 | request := buildHttpRequestWithAcceptLanguage("en-GB,en;q=0.8,nl;q=0.6,fr;q=0.5,de-DE;q=0.4,no-NO;q=0.4,ru;q=0.2") 170 | ResolveAcceptLanguage(request) 171 | } 172 | } 173 | 174 | func buildHttpRequestWithAcceptLanguage(acceptLanguage string) *http.Request { 175 | request, _ := http.NewRequest("POST", "http://localhost/path", nil) 176 | request.Header.Set("Accept-Language", acceptLanguage) 177 | return request 178 | } 179 | -------------------------------------------------------------------------------- /cmd/mars-gen/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/format" 7 | "os" 8 | "path" 9 | "text/template" 10 | "time" 11 | 12 | "github.com/codegangsta/cli" 13 | ) 14 | 15 | func fatalf(layout string, args ...interface{}) { 16 | fmt.Fprintf(os.Stderr, layout+"\n", args...) 
17 | os.Exit(1) 18 | } 19 | 20 | func main() { 21 | app := cli.NewApp() 22 | app.HideVersion = true 23 | app.Name = "mars-gen" 24 | app.Usage = "Code generation tool for the Mars web framework" 25 | app.Flags = []cli.Flag{ 26 | cli.BoolFlag{ 27 | Name: "verbose, v", 28 | Usage: "Prints the names of the source files as they are parsed", 29 | }, 30 | } 31 | app.Commands = []cli.Command{ 32 | { 33 | Name: "register-controllers", 34 | Usage: "Generates code to register your controllers with the framework", 35 | Action: registerControllers, 36 | Flags: []cli.Flag{ 37 | cli.StringFlag{ 38 | Name: "n", 39 | Value: "RegisterControllers", 40 | Usage: "Function name to generate", 41 | }, 42 | cli.StringFlag{ 43 | Name: "o", 44 | Value: "register_controllers.gen.go", 45 | Usage: "Name of the file to generate", 46 | }, 47 | }, 48 | }, 49 | { 50 | Name: "reverse-routes", 51 | Usage: "Generates code that allows generating reverse routes", 52 | Action: reverseRoutes, 53 | Flags: []cli.Flag{ 54 | cli.StringFlag{ 55 | Name: "n", 56 | Value: "routes", 57 | Usage: "Package name to generate", 58 | }, 59 | cli.StringFlag{ 60 | Name: "o", 61 | Value: "routes.gen.go", 62 | Usage: "Path of the file to generate", 63 | }, 64 | }, 65 | }, 66 | } 67 | 68 | app.Run(os.Args) 69 | } 70 | 71 | func registerControllers(ctx *cli.Context) { 72 | dir := "." 73 | if len(ctx.Args()) > 0 { 74 | dir = ctx.Args()[0] 75 | } 76 | 77 | sourceInfo, procErr := ProcessSource(dir, ctx.GlobalBool("v")) 78 | if procErr != nil { 79 | fatalf(procErr.Error()) 80 | } 81 | 82 | generateSources(registerTemplate, path.Join(dir, ctx.String("o")), map[string]interface{}{ 83 | "packageName": sourceInfo.PackageName, 84 | "functionName": ctx.String("n"), 85 | "controllers": sourceInfo.ControllerSpecs(), 86 | "ImportPaths": sourceInfo.CalcImportAliases(), 87 | "time": time.Now(), 88 | }) 89 | } 90 | 91 | func reverseRoutes(ctx *cli.Context) { 92 | dir := "." 
93 | if len(ctx.Args()) > 0 { 94 | dir = ctx.Args()[0] 95 | } 96 | 97 | sourceInfo, procErr := ProcessSource(dir, ctx.GlobalBool("v")) 98 | if procErr != nil { 99 | fatalf(procErr.Error()) 100 | } 101 | 102 | generateSources(routesTemplate, ctx.String("o"), map[string]interface{}{ 103 | "packageName": ctx.String("n"), 104 | "controllers": sourceInfo.ControllerSpecs(), 105 | "ImportPaths": sourceInfo.CalcImportAliases(), 106 | "time": time.Now(), 107 | }) 108 | } 109 | 110 | func generateSources(tpl, filename string, templateArgs map[string]interface{}) { 111 | var b bytes.Buffer 112 | 113 | tmpl := template.Must(template.New("").Parse(tpl)) 114 | if err := tmpl.Execute(&b, templateArgs); err != nil { 115 | fatalf("Unable to create source file: %v", err) 116 | } 117 | 118 | if err := os.MkdirAll(path.Dir(filename), 0755); err != nil { 119 | fatalf("Unable to create dir: %v", err) 120 | } 121 | 122 | // Create the file 123 | file, err := os.Create(filename) 124 | if err != nil { 125 | fatalf("Failed to create file: %v", err) 126 | } 127 | defer file.Close() 128 | 129 | formatted, err := format.Source(b.Bytes()) 130 | if err != nil { 131 | fatalf("Failed to format file: %v", err) 132 | } 133 | 134 | if _, err := file.Write(formatted); err != nil { 135 | fatalf("Failed to write to file: %v", err) 136 | } 137 | } 138 | 139 | const registerTemplate = `// DO NOT EDIT -- code generated by mars-gen 140 | package {{.packageName}} 141 | 142 | import ( 143 | "reflect" 144 | "github.com/roblillack/mars"{{range $k, $v := $.ImportPaths}} 145 | {{$v}} "{{$k}}"{{end}} 146 | ) 147 | 148 | var ( 149 | // So compiler won't complain if the generated code doesn't reference reflect package... 
150 | _ = reflect.Invalid 151 | ) 152 | 153 | func {{.functionName}}() { 154 | {{range $i, $c := .controllers}} 155 | mars.RegisterController((*{{.StructName}})(nil), 156 | []*mars.MethodType{ 157 | {{range .MethodSpecs}}{ 158 | Name: "{{.Name}}", 159 | Args: []*mars.MethodArg{ {{range .Args}} 160 | {Name: "{{.Name}}", Type: reflect.TypeOf((*{{index $.ImportPaths .ImportPath | .TypeExpr.TypeName}})(nil)) },{{end}} 161 | }, 162 | }, 163 | {{end}} 164 | }) 165 | {{end}} 166 | } 167 | ` 168 | 169 | const routesTemplate = `// DO NOT EDIT -- code generated by mars-gen 170 | package {{.packageName}} 171 | 172 | import ( 173 | "github.com/roblillack/mars"{{range $k, $v := $.ImportPaths}} 174 | {{$v}} "{{$k}}"{{end}} 175 | ) 176 | 177 | {{range $i, $c := .controllers}} 178 | type t{{.StructName}} struct {} 179 | var {{.StructName}} t{{.StructName}} 180 | 181 | {{range .MethodSpecs}} 182 | func (_ t{{$c.StructName}}) {{.Name}}({{range .Args}} 183 | {{.Name}}_ {{if .ImportPath}}{{index $.ImportPaths .ImportPath | .TypeExpr.TypeName}}{{else}}{{.TypeExpr.TypeName ""}}{{end}},{{end}} 184 | ) string { 185 | args := make(map[string]string) 186 | {{range .Args}} 187 | mars.Unbind(args, "{{.Name}}", {{.Name}}_){{end}} 188 | return mars.MainRouter.Reverse("{{$c.StructName}}.{{.Name}}", args).Url 189 | } 190 | {{end}} 191 | {{end}} 192 | ` 193 | -------------------------------------------------------------------------------- /compress.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "compress/gzip" 5 | "compress/zlib" 6 | "io" 7 | "net/http" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | var compressionTypes = [...]string{ 13 | "gzip", 14 | "deflate", 15 | } 16 | 17 | var compressableMimes = [...]string{ 18 | "text/plain", 19 | "text/html", 20 | "text/xml", 21 | "text/css", 22 | "application/json", 23 | "application/xml", 24 | "application/xhtml+xml", 25 | "application/rss+xml", 26 | "application/javascript", 27 
| "application/x-javascript", 28 | "image/svg+xml", 29 | } 30 | 31 | type writeFlusher interface { 32 | io.Writer 33 | io.Closer 34 | Flush() error 35 | } 36 | 37 | type CompressResponseWriter struct { 38 | http.ResponseWriter 39 | compressWriter writeFlusher 40 | compressionType string 41 | headersWritten bool 42 | closeNotify chan bool 43 | parentNotify <-chan bool 44 | closed bool 45 | } 46 | 47 | func CompressFilter(c *Controller, fc []Filter) { 48 | fc[0](c, fc[1:]) 49 | if Config.BoolDefault("results.compressed", false) { 50 | if c.Response.Status != http.StatusNoContent && c.Response.Status != http.StatusNotModified { 51 | writer := CompressResponseWriter{c.Response.Out, nil, "", false, make(chan bool, 1), nil, false} 52 | writer.DetectCompressionType(c.Request, c.Response) 53 | w, ok := c.Response.Out.(http.CloseNotifier) 54 | if ok { 55 | writer.parentNotify = w.CloseNotify() 56 | } 57 | c.Response.Out = &writer 58 | } else { 59 | TRACE.Printf("Compression disabled for response status (%d)", c.Response.Status) 60 | } 61 | } 62 | } 63 | 64 | func (c CompressResponseWriter) CloseNotify() <-chan bool { 65 | if c.parentNotify != nil { 66 | return c.parentNotify 67 | } 68 | return c.closeNotify 69 | } 70 | 71 | func (c *CompressResponseWriter) prepareHeaders() { 72 | if c.compressionType != "" { 73 | responseMime := c.Header().Get("Content-Type") 74 | responseMime = strings.TrimSpace(strings.SplitN(responseMime, ";", 2)[0]) 75 | shouldEncode := false 76 | 77 | if c.Header().Get("Content-Encoding") == "" { 78 | for _, compressableMime := range compressableMimes { 79 | if responseMime == compressableMime { 80 | shouldEncode = true 81 | c.Header().Set("Content-Encoding", c.compressionType) 82 | c.Header().Del("Content-Length") 83 | break 84 | } 85 | } 86 | } 87 | 88 | if !shouldEncode { 89 | c.compressWriter = nil 90 | c.compressionType = "" 91 | } 92 | } 93 | } 94 | 95 | func (c *CompressResponseWriter) WriteHeader(status int) { 96 | c.headersWritten = true 97 | 
c.prepareHeaders() 98 | c.ResponseWriter.WriteHeader(status) 99 | } 100 | 101 | func (c *CompressResponseWriter) Close() error { 102 | if c.compressionType != "" { 103 | c.compressWriter.Close() 104 | } 105 | if w, ok := c.ResponseWriter.(io.Closer); ok { 106 | w.Close() 107 | } 108 | // Non-blocking write to the closenotifier, if we for some reason should 109 | // get called multiple times 110 | select { 111 | case c.closeNotify <- true: 112 | default: 113 | } 114 | c.closed = true 115 | return nil 116 | } 117 | 118 | func (c *CompressResponseWriter) Write(b []byte) (int, error) { 119 | // Abort if parent has been closed 120 | if c.parentNotify != nil { 121 | select { 122 | case <-c.parentNotify: 123 | return 0, io.ErrClosedPipe 124 | default: 125 | } 126 | } 127 | // Abort if we ourselves have been closed 128 | if c.closed { 129 | return 0, io.ErrClosedPipe 130 | } 131 | if !c.headersWritten { 132 | c.prepareHeaders() 133 | c.headersWritten = true 134 | } 135 | 136 | if c.compressionType != "" { 137 | return c.compressWriter.Write(b) 138 | } else { 139 | return c.ResponseWriter.Write(b) 140 | } 141 | } 142 | 143 | func (c *CompressResponseWriter) DetectCompressionType(req *Request, resp *Response) { 144 | if Config.BoolDefault("results.compressed", false) { 145 | acceptedEncodings := strings.Split(req.Request.Header.Get("Accept-Encoding"), ",") 146 | 147 | largestQ := 0.0 148 | chosenEncoding := len(compressionTypes) 149 | 150 | for _, encoding := range acceptedEncodings { 151 | encoding = strings.TrimSpace(encoding) 152 | encodingParts := strings.SplitN(encoding, ";", 2) 153 | 154 | // If we are the format "gzip;q=0.8" 155 | if len(encodingParts) > 1 { 156 | // Strip off the q= 157 | num, err := strconv.ParseFloat(strings.TrimSpace(encodingParts[1])[2:], 32) 158 | if err != nil { 159 | continue 160 | } 161 | 162 | if num >= largestQ && num > 0 { 163 | if encodingParts[0] == "*" { 164 | chosenEncoding = 0 165 | largestQ = num 166 | continue 167 | } 168 | for i, 
encoding := range compressionTypes { 169 | if encoding == encodingParts[0] { 170 | if i < chosenEncoding { 171 | largestQ = num 172 | chosenEncoding = i 173 | } 174 | break 175 | } 176 | } 177 | } 178 | } else { 179 | // If we can accept anything, chose our preferred method. 180 | if encodingParts[0] == "*" { 181 | chosenEncoding = 0 182 | largestQ = 1 183 | break 184 | } 185 | // This is for just plain "gzip" 186 | for i, encoding := range compressionTypes { 187 | if encoding == encodingParts[0] { 188 | if i < chosenEncoding { 189 | largestQ = 1.0 190 | chosenEncoding = i 191 | } 192 | break 193 | } 194 | } 195 | } 196 | } 197 | 198 | if largestQ == 0 { 199 | return 200 | } 201 | 202 | c.compressionType = compressionTypes[chosenEncoding] 203 | 204 | switch c.compressionType { 205 | case "gzip": 206 | c.compressWriter = gzip.NewWriter(resp.Out) 207 | case "deflate": 208 | c.compressWriter = zlib.NewWriter(resp.Out) 209 | } 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "fmt" 7 | "net/http" 8 | "net/url" 9 | "strconv" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | // A signed cookie (and thus limited to 4kb in size). 15 | // Restriction: Keys may not have a colon in them. 16 | type Session map[string]string 17 | 18 | const ( 19 | SESSION_ID_KEY = "_ID" 20 | TIMESTAMP_KEY = "_TS" 21 | ) 22 | 23 | // expireAfterDuration is the time to live, in seconds, of a session cookie. 24 | // It may be specified in config as "session.expires". Values greater than 0 25 | // set a persistent cookie with a time to live as specified, and the value 0 26 | // sets a session cookie. 
27 | var expireAfterDuration time.Duration 28 | 29 | func init() { 30 | // Set expireAfterDuration, default to 24 hours if no value in config 31 | OnAppStart(func() { 32 | var err error 33 | if expiresString, ok := Config.String("session.expires"); !ok { 34 | expireAfterDuration = 24 * time.Hour 35 | } else if expiresString == "session" { 36 | expireAfterDuration = 0 37 | } else if expireAfterDuration, err = time.ParseDuration(expiresString); err != nil { 38 | panic(fmt.Errorf("session.expires invalid: %s", err)) 39 | } 40 | }) 41 | } 42 | 43 | // Id retrieves from the cookie or creates a time-based UUID identifying this 44 | // session. 45 | func (s Session) Id() string { 46 | if sessionIdStr, ok := s[SESSION_ID_KEY]; ok { 47 | return sessionIdStr 48 | } 49 | 50 | buffer := make([]byte, 32) 51 | if _, err := rand.Read(buffer); err != nil { 52 | panic(err) 53 | } 54 | 55 | s[SESSION_ID_KEY] = hex.EncodeToString(buffer) 56 | return s[SESSION_ID_KEY] 57 | } 58 | 59 | // getExpiration return a time.Time with the session's expiration date. 60 | // If previous session has set to "session", remain it 61 | func (s Session) getExpiration() time.Time { 62 | if expireAfterDuration == 0 || s[TIMESTAMP_KEY] == "session" { 63 | // Expire after closing browser 64 | return time.Time{} 65 | } 66 | return time.Now().Add(expireAfterDuration) 67 | } 68 | 69 | // Cookie returns an http.Cookie containing the signed session. 
70 | func (s Session) Cookie() *http.Cookie { 71 | var sessionValue string 72 | ts := s.getExpiration() 73 | s[TIMESTAMP_KEY] = getSessionExpirationCookie(ts) 74 | for key, value := range s { 75 | if strings.ContainsAny(key, ":\x00") { 76 | panic("Session keys may not have colons or null bytes") 77 | } 78 | if strings.Contains(value, "\x00") { 79 | panic("Session values may not have null bytes") 80 | } 81 | sessionValue += "\x00" + key + ":" + value + "\x00" 82 | } 83 | 84 | sessionData := url.QueryEscape(sessionValue) 85 | return &http.Cookie{ 86 | Name: CookiePrefix + "_SESSION", 87 | Value: Sign(sessionData) + "/" + sessionData, 88 | Domain: CookieDomain, 89 | Path: "/", 90 | HttpOnly: CookieHttpOnly, 91 | Secure: CookieSecure, 92 | Expires: ts.UTC(), 93 | } 94 | } 95 | 96 | // sessionTimeoutExpiredOrMissing returns a boolean of whether the session 97 | // cookie is either not present or present but beyond its time to live; i.e., 98 | // whether there is not a valid session. 99 | func sessionTimeoutExpiredOrMissing(session Session) bool { 100 | if exp, present := session[TIMESTAMP_KEY]; !present { 101 | return true 102 | } else if exp == "session" { 103 | return false 104 | } else if expInt, _ := strconv.Atoi(exp); int64(expInt) < time.Now().Unix() { 105 | return true 106 | } 107 | return false 108 | } 109 | 110 | // GetSessionFromCookie returns a Session struct pulled from the signed 111 | // session cookie. 112 | func GetSessionFromCookie(cookie *http.Cookie) Session { 113 | session := make(Session) 114 | 115 | // Separate the data from the signature. 116 | sep := strings.Index(cookie.Value, "/") 117 | if sep == -1 || sep >= len(cookie.Value)-1 { 118 | return session 119 | } 120 | sig, data := cookie.Value[:sep], cookie.Value[sep+1:] 121 | 122 | // Verify the signature. 
123 | if !Verify(data, sig) { 124 | INFO.Println("Session cookie signature failed") 125 | return session 126 | } 127 | 128 | parseKeyValueCookie(data, func(key, val string) { 129 | session[key] = val 130 | }) 131 | 132 | if sessionTimeoutExpiredOrMissing(session) { 133 | session = make(Session) 134 | } 135 | 136 | return session 137 | } 138 | 139 | // SessionFilter is a Mars Filter that retrieves and sets the session cookie. 140 | // Within Mars, it is available as a Session attribute on Controller instances. 141 | // The name of the Session cookie is set as CookiePrefix + "_SESSION". 142 | func SessionFilter(c *Controller, fc []Filter) { 143 | c.Session = restoreSession(c.Request.Request) 144 | sessionWasEmpty := len(c.Session) == 0 145 | 146 | // Make session vars available in templates as {{.session.xyz}} 147 | c.RenderArgs["session"] = c.Session 148 | 149 | fc[0](c, fc[1:]) 150 | 151 | // Store the signed session if it could have changed. 152 | if len(c.Session) > 0 || !sessionWasEmpty { 153 | c.SetCookie(c.Session.Cookie()) 154 | } 155 | } 156 | 157 | // restoreSession returns either the current session, retrieved from the 158 | // session cookie, or a new session. 159 | func restoreSession(req *http.Request) Session { 160 | cookie, err := req.Cookie(CookiePrefix + "_SESSION") 161 | if err != nil { 162 | return make(Session) 163 | } else { 164 | return GetSessionFromCookie(cookie) 165 | } 166 | } 167 | 168 | // getSessionExpirationCookie retrieves the cookie's time to live as a 169 | // string of either the number of seconds, for a persistent cookie, or 170 | // "session". 
171 | func getSessionExpirationCookie(t time.Time) string { 172 | if t.IsZero() { 173 | return "session" 174 | } 175 | return strconv.FormatInt(t.Unix(), 10) 176 | } 177 | 178 | // SetNoExpiration sets session to expire when browser session ends 179 | func (s Session) SetNoExpiration() { 180 | s[TIMESTAMP_KEY] = "session" 181 | } 182 | 183 | // SetDefaultExpiration sets session to expire after default duration 184 | func (s Session) SetDefaultExpiration() { 185 | delete(s, TIMESTAMP_KEY) 186 | } 187 | -------------------------------------------------------------------------------- /internal/pathtree/tree_test.go: -------------------------------------------------------------------------------- 1 | package pathtree 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestColon(t *testing.T) { 10 | n := New() 11 | 12 | n.Add("/:first/:second/", 1) 13 | n.Add("/:first", 2) 14 | n.Add("/", 3) 15 | 16 | found(t, n, "/", nil, 3) 17 | found(t, n, "/a", []string{"a"}, 2) 18 | found(t, n, "/a/", []string{"a"}, 2) 19 | found(t, n, "/a/b", []string{"a", "b"}, 1) 20 | found(t, n, "/a/b/", []string{"a", "b"}, 1) 21 | 22 | notfound(t, n, "/a/b/c") 23 | } 24 | 25 | func TestStar(t *testing.T) { 26 | n := New() 27 | 28 | n.Add("/first/second/*star", 1) 29 | n.Add("/:first/*star/", 2) 30 | n.Add("/*star", 3) 31 | n.Add("/", 4) 32 | 33 | found(t, n, "/", nil, 4) 34 | found(t, n, "/a", []string{"a"}, 3) 35 | found(t, n, "/a/", []string{"a"}, 3) 36 | found(t, n, "/a/b", []string{"a", "b"}, 2) 37 | found(t, n, "/a/b/", []string{"a", "b"}, 2) 38 | found(t, n, "/a/b/c", []string{"a", "b/c"}, 2) 39 | found(t, n, "/a/b/c/", []string{"a", "b/c"}, 2) 40 | found(t, n, "/a/b/c/d", []string{"a", "b/c/d"}, 2) 41 | found(t, n, "/first/second", []string{"first", "second"}, 2) 42 | found(t, n, "/first/second/", []string{"first", "second"}, 2) 43 | found(t, n, "/first/second/third", []string{"third"}, 1) 44 | } 45 | 46 | func TestMixedTree(t *testing.T) { 47 | n := New() 48 | 49
| n.Add("/", 0) 50 | n.Add("/path/to/nowhere", 1) 51 | n.Add("/path/:i/nowhere", 2) 52 | n.Add("/:id/to/nowhere", 3) 53 | n.Add("/:a/:b", 4) 54 | n.Add("/not/found", 5) 55 | 56 | found(t, n, "/", nil, 0) 57 | found(t, n, "/path/to/nowhere", nil, 1) 58 | found(t, n, "/path/to/nowhere/", nil, 1) 59 | found(t, n, "/path/from/nowhere", []string{"from"}, 2) 60 | found(t, n, "/walk/to/nowhere", []string{"walk"}, 3) 61 | found(t, n, "/path/to/", []string{"path", "to"}, 4) 62 | found(t, n, "/path/to", []string{"path", "to"}, 4) 63 | found(t, n, "/not/found", []string{"not", "found"}, 4) 64 | notfound(t, n, "/path/to/somewhere") 65 | notfound(t, n, "/path/to/nowhere/else") 66 | notfound(t, n, "/path") 67 | notfound(t, n, "/path/") 68 | 69 | notfound(t, n, "") 70 | notfound(t, n, "xyz") 71 | notfound(t, n, "/path//to/nowhere") 72 | } 73 | 74 | func TestExtensions(t *testing.T) { 75 | n := New() 76 | 77 | n.Add("/:first/:second.json", 1) 78 | n.Add("/a/:second.xml", 2) 79 | n.Add("/:first/:second", 3) 80 | 81 | found(t, n, "/a/b", []string{"a", "b"}, 3) 82 | found(t, n, "/a/b.json", []string{"a", "b"}, 1) 83 | found(t, n, "/a/b.xml", []string{"b"}, 2) 84 | found(t, n, "/a/b.c.xml", []string{"b.c"}, 2) 85 | found(t, n, "/other/b.xml", []string{"other", "b.xml"}, 3) 86 | } 87 | 88 | func TestNonFittingExtensions(t *testing.T) { 89 | n := New() 90 | 91 | n.Add("/:first/:second.json", 1) 92 | n.Add("/a/:second.xml", 2) 93 | 94 | notfound(t, n, "/a/b") 95 | notfound(t, n, "/a/b.png") 96 | } 97 | 98 | func TestVariableExtensions(t *testing.T) { 99 | n := New() 100 | 101 | n.Add("/a/b.:ext", `ext`) 102 | n.Add("/a/:filename.json", `base`) 103 | n.Add("/a/:filename.:ext", `baseext`) 104 | n.Add("/:first/:second.:ext", `dirbaseext`) 105 | 106 | // Fixed path with variable extension will always get higher priority than variable filename with fixed extension 107 | found(t, n, "/a/b.json", []string{"json"}, `ext`) 108 | found(t, n, "/a/somefile.json", []string{"somefile"}, `base`) 109 |
found(t, n, "/a/somefile.xml", []string{"somefile", "xml"}, `baseext`) 110 | found(t, n, "/a/b.xml", []string{"xml"}, `ext`) 111 | found(t, n, "/other/b.xml", []string{"other", "b", "xml"}, `dirbaseext`) 112 | } 113 | 114 | func TestNonFittingVariableExtensions(t *testing.T) { 115 | n := New() 116 | 117 | n.Add("/first/:second.:ext", 1) 118 | n.Add("/a/b.:ext", 2) 119 | 120 | notfound(t, n, "/a/b") 121 | notfound(t, n, "/first/file") 122 | } 123 | 124 | func TestErrors(t *testing.T) { 125 | n := New() 126 | fails(t, n.Add("//", 1), "empty path elements not allowed") 127 | } 128 | 129 | func BenchmarkTree100(b *testing.B) { 130 | n := New() 131 | n.Add("/", "root") 132 | 133 | // Exact matches 134 | for i := 0; i < 100; i++ { 135 | depth := i%5 + 1 136 | key := "" 137 | for j := 0; j < depth-1; j++ { 138 | key += fmt.Sprintf("/dir%d", j) 139 | } 140 | key += fmt.Sprintf("/resource%d", i) 141 | n.Add(key, "literal") 142 | // b.Logf("Adding %s", key) 143 | } 144 | 145 | // Wildcards at each level if no exact matches work.
146 | for i := 0; i < 5; i++ { 147 | var key string 148 | for j := 0; j < i; j++ { 149 | key += fmt.Sprintf("/dir%d", j) 150 | } 151 | key += "/:var" 152 | n.Add(key, "var") 153 | // b.Logf("Adding %s", key) 154 | } 155 | 156 | n.Add("/public/*filepath", "static") 157 | // b.Logf("Adding /public/*filepath") 158 | 159 | queries := map[string]string{ 160 | "/": "root", 161 | "/dir0/dir1/dir2/dir3/resource4": "literal", 162 | "/dir0/dir1/resource97": "literal", 163 | "/dir0/variable": "var", 164 | "/dir0/dir1/dir2/dir3/variable": "var", 165 | "/public/stylesheets/main.css": "static", 166 | "/public/images/icons/an-image.png": "static", 167 | } 168 | 169 | for query, answer := range queries { 170 | leaf, _ := n.Find(query) 171 | if leaf == nil { 172 | b.Errorf("Failed to find leaf for query %s", query) 173 | return 174 | } 175 | if leaf.Value.(string) != answer { 176 | b.Errorf("Incorrect answer for query %s: expected: %s, actual: %s", 177 | query, answer, leaf.Value.(string)) 178 | return 179 | } 180 | } 181 | 182 | b.ResetTimer() 183 | 184 | for i := 0; i < b.N/len(queries); i++ { 185 | for k := range queries { 186 | n.Find(k) 187 | } 188 | } 189 | } 190 | 191 | // notfound reports a test error if p resolves to a leaf of n. 192 | func notfound(t *testing.T, n *Node, p string) { 193 | t.Helper() 194 | if leaf, _ := n.Find(p); leaf != nil { 195 | t.Errorf("Should not have found: %s", p) 196 | } 197 | } 198 | 199 | // found reports a test error unless p resolves to a leaf of n whose wildcard 200 | // expansions equal expectedExpansions and whose Value equals val. 201 | func found(t *testing.T, n *Node, p string, expectedExpansions []string, val interface{}) { 202 | t.Helper() 203 | leaf, expansions := n.Find(p) 204 | if leaf == nil { 205 | t.Errorf("Didn't find: %s", p) 206 | return 207 | } 208 | if !reflect.DeepEqual(expansions, expectedExpansions) { 209 | t.Errorf("%s: Wildcard expansions (actual) %v != %v (expected)", p, expansions, expectedExpansions) 210 | } 211 | if leaf.Value != val { 212 | t.Errorf("%s: Value (actual) %v != %v (expected)", p, leaf.Value, val) 213 | } 214 | } 215 | 216 | // fails reports a test error, including msg, if err is nil. 217 | func fails(t *testing.T, err error, msg string) { 218 | t.Helper() 219 | if err == nil { 220 | t.Errorf("expected an error.
%s", msg) 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /intercept.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "log" 5 | "reflect" 6 | ) 7 | 8 | // An "interceptor" is functionality invoked by the framework BEFORE or AFTER 9 | // an action. 10 | // 11 | // An interceptor may optionally return a Result (instead of nil). Depending on 12 | // when the interceptor was invoked, the response is different: 13 | // 1. BEFORE: No further interceptors are invoked, and neither is the action. 14 | // 2. AFTER: Further interceptors are still run. 15 | // In all cases, any returned Result will take the place of any existing Result. 16 | // 17 | // In the BEFORE case, that returned Result is guaranteed to be final, while 18 | // in the AFTER case it is possible that a further interceptor could emit its 19 | // own Result. 20 | // 21 | // Interceptors are called in the order that they are added. 22 | // 23 | // *** 24 | // 25 | // Two types of interceptors are provided: Funcs and Methods 26 | // 27 | // Func Interceptors may apply to any / all Controllers. 28 | // 29 | // func example(*mars.Controller) mars.Result 30 | // 31 | // Method Interceptors are provided so that properties can be set on application 32 | // controllers. 33 | // 34 | // func (c AppController) example() mars.Result 35 | // func (c *AppController) example() mars.Result 36 | // 37 | 38 | // InterceptorFunc represents a function that can be used to intercept action 39 | // invocation of a specific or all controllers. 40 | type InterceptorFunc func(*Controller) Result 41 | 42 | // InterceptorMethod represents a method that can be used to intercept action 43 | // invocations of a specific application controller. 44 | type InterceptorMethod interface{} 45 | 46 | // When allows specifying when an interceptor shall be called.
47 | type When int 48 | 49 | const ( 50 | // Interceptor shall be called before invoking the action. 51 | BEFORE When = iota 52 | // Interceptor shall be called after invoking the action. 53 | AFTER 54 | // Interceptor shall be called in case the action panicked. 55 | PANIC 56 | // Interceptor shall be called after invoking the action, and after recovering from a panic. 57 | FINALLY 58 | ) 59 | 60 | // InterceptTarget is a helper to make AllControllers have a valid type. 61 | type InterceptTarget int 62 | 63 | const ( 64 | // AllControllers means that the function will intercept all controllers. 65 | AllControllers InterceptTarget = iota 66 | ) 67 | 68 | // Interception allows specifying a configuration of when to intercept which invocations. 69 | type Interception struct { 70 | When When 71 | 72 | function InterceptorFunc 73 | method InterceptorMethod 74 | 75 | callable reflect.Value 76 | target reflect.Type 77 | interceptAll bool // true when registered via InterceptFunc with AllControllers as target 78 | } 79 | 80 | // Invoke performs the given interception. val is a pointer to the App Controller. 81 | func (i Interception) Invoke(val reflect.Value) reflect.Value { 82 | var arg reflect.Value 83 | if i.function == nil { 84 | // If it's an InterceptorMethod, then we have to pass in the target type. 85 | arg = findTarget(val, i.target) 86 | } else { 87 | // If it's an InterceptorFunc, then the type must be *Controller. 88 | // We can find that by following the embedded types up the chain. 89 | for val.Type() != controllerPtrType { 90 | if val.Kind() == reflect.Ptr { 91 | val = val.Elem() 92 | } 93 | val = val.Field(0) // assumes the next embedded controller is the first field at every level 94 | } 95 | arg = val 96 | } 97 | 98 | vals := i.callable.Call([]reflect.Value{arg}) 99 | return vals[0] // the interceptor's single return value (a Result, possibly a nil interface) 100 | } 101 | 102 | // InterceptorFilter adds the interception functionality to Mars' filter chain.
103 | func InterceptorFilter(c *Controller, fc []Filter) { 104 | defer invokeInterceptors(FINALLY, c) // deferred first, so it runs last, also while a re-raised panic propagates 105 | defer func() { 106 | if err := recover(); err != nil { 107 | invokeInterceptors(PANIC, c) 108 | panic(err) // re-panic after the PANIC interceptors so the panic keeps propagating 109 | } 110 | }() 111 | 112 | // Invoke the BEFORE interceptors and return early, if we get a result. 113 | invokeInterceptors(BEFORE, c) 114 | if c.Result != nil { 115 | return 116 | } 117 | 118 | fc[0](c, fc[1:]) 119 | invokeInterceptors(AFTER, c) 120 | } 121 | 122 | func invokeInterceptors(when When, c *Controller) { // runs the interceptors registered for `when`: for BEFORE the first non-nil Result is used, otherwise the last non-nil Result replaces c.Result 123 | var ( 124 | app = reflect.ValueOf(c.AppController) 125 | result Result 126 | ) 127 | for _, intc := range getInterceptors(when, app) { 128 | resultValue := intc.Invoke(app) 129 | if !resultValue.IsNil() { 130 | result = resultValue.Interface().(Result) 131 | } 132 | if when == BEFORE && result != nil { 133 | c.Result = result 134 | return 135 | } 136 | } 137 | if result != nil { 138 | c.Result = result 139 | } 140 | } 141 | 142 | var interceptors []*Interception // all registered interceptors, in registration order 143 | 144 | // InterceptFunc installs a general interceptor. 145 | // This can be applied to any Controller. 146 | // It must have the signature of: 147 | // func example(c *mars.Controller) mars.Result 148 | func InterceptFunc(intc InterceptorFunc, when When, target interface{}) { 149 | interceptors = append(interceptors, &Interception{ 150 | When: when, 151 | function: intc, 152 | callable: reflect.ValueOf(intc), 153 | target: reflect.TypeOf(target), 154 | interceptAll: target == AllControllers, 155 | }) 156 | } 157 | 158 | // InterceptMethod installs an interceptor method that applies to its own Controller.
159 | // func (c AppController) example() mars.Result 160 | // func (c *AppController) example() mars.Result 161 | func InterceptMethod(intc InterceptorMethod, when When) { 162 | methodType := reflect.TypeOf(intc) 163 | if methodType.Kind() != reflect.Func || methodType.NumOut() != 1 || methodType.NumIn() != 1 { 164 | log.Fatalln("Interceptor method should have signature like", 165 | "'func (c *AppController) example() mars.Result' but was", methodType) 166 | } 167 | interceptors = append(interceptors, &Interception{ 168 | When: when, 169 | method: intc, 170 | callable: reflect.ValueOf(intc), 171 | target: methodType.In(0), 172 | }) 173 | } 174 | 175 | // getInterceptors returns, in registration order, the interceptors for `when` that apply to val. 176 | func getInterceptors(when When, val reflect.Value) []*Interception { 177 | result := []*Interception{} 178 | for _, intc := range interceptors { 179 | if intc.When != when { 180 | continue 181 | } 182 | 183 | if intc.interceptAll || findTarget(val, intc.target).IsValid() { 184 | result = append(result, intc) 185 | } 186 | } 187 | return result 188 | } 189 | 190 | // findTarget finds the value of the target type, starting from val and searching 191 | // breadth-first through embedded (anonymous) fields. It also converts between 192 | // differences in indirection (T vs. *T). If the target couldn't be found, the 193 | // returned Value will have IsValid() == false. Nil embedded pointers and embedded 194 | // non-struct types are skipped instead of making reflect panic. 195 | func findTarget(val reflect.Value, target reflect.Type) reflect.Value { 196 | // Look through the embedded types (until we reach the *mars.Controller at the top). 197 | valueQueue := []reflect.Value{val} 198 | for len(valueQueue) > 0 { 199 | val, valueQueue = valueQueue[0], valueQueue[1:] 200 | if !val.IsValid() { 201 | continue 202 | } 203 | 204 | // Check if val is of a similar type to the target type.
205 | if val.Type() == target { 206 | return val 207 | } 208 | // A nil pointer has nothing to dereference or to search through. 209 | if val.Kind() == reflect.Ptr && val.IsNil() { 210 | continue 211 | } 212 | if val.Kind() == reflect.Ptr && val.Elem().Type() == target { 213 | return val.Elem() 214 | } 215 | // Addr() panics on unaddressable values, so only take the address when allowed. 216 | if target.Kind() == reflect.Ptr && target.Elem() == val.Type() && val.CanAddr() { 217 | return val.Addr() 218 | } 219 | 220 | // If we reached the *mars.Controller and still didn't find what we were 221 | // looking for, give up. 222 | if val.Type() == controllerPtrType { 223 | continue 224 | } 225 | 226 | // Else, add each anonymous field to the queue. 227 | if val.Kind() == reflect.Ptr { 228 | val = val.Elem() 229 | } 230 | // Only structs have fields; NumField panics for any other kind. 231 | if val.Kind() != reflect.Struct { 232 | continue 233 | } 234 | 235 | for i := 0; i < val.NumField(); i++ { 236 | if val.Type().Field(i).Anonymous { 237 | valueQueue = append(valueQueue, val.Field(i)) 238 | } 239 | } 240 | } 241 | 242 | return reflect.Value{} 243 | } 244 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # MARS CHANGELOG 2 | 3 | All notable changes to Mars will be documented in this file. 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/). 5 | 6 | ## [Unreleased](https://github.com/roblillack/mars/compare/v1.1.0...master) 7 | 8 | ## [v1.1.0](https://github.com/roblillack/mars/compare/v1.0.4...v1.1.0) 9 | 10 | - New routing features: 11 | - Add support for variable file extensions. #18 12 | - Security fixes: 13 | - Better sanitize some trace logs that might contain user input. #24 14 | - Improvements to code generation using `mars-gen`: 15 | - Add support for generating routes when query parameters of type `any` are used. [94d2c9e] 16 | - Improve quality of generated code and `go fmt` it. #21 17 | - Code quality improvements: 18 | - Add support for static code analysis using CodeQL. [b8b7f93] 19 | - Add unit test for `cert`. [9fb9184] 20 | - Hide some internal data structures. #16 (#22) 21 | - Fix test vetting errors.
[65ad5a2] 22 | - Improve Go 1.18 support, minimum supported Go version is now 1.13. [a03670b] 23 | 24 | ## [v1.0.4](https://github.com/roblillack/mars/compare/v1.0.3...v1.0.4) 25 | 26 | - Code generation: 27 | - Add package name cache to speed up parsing very large code bases. #19 28 | - Add `--verbose` or `-v` flag to print the names of source files as they are parsed. #19 29 | 30 | ## [v1.0.3](https://github.com/roblillack/mars/compare/v1.0.2...v1.0.3) 31 | 32 | - Infrastructure: 33 | - Remove support for Go 1.8 to 1.11, depend on Go Modules by default. 34 | - Usage: 35 | - Introduce automatic `setup` process for TemplateLoader and Router. Mars can now be started without calling [InitDefaults()](https://godoc.org/github.com/roblillack/mars#InitDefaults) before. #13 36 | - Code generation: 37 | - Add functional tests. #19 38 | - Improve parsing speed, by skipping unnecessary imports of the Go standard library. #19 39 | 40 | ## [v1.0.2](https://github.com/roblillack/mars/compare/v1.0.1...v1.0.2) 41 | 42 | - Defaults: 43 | - Fix mime type for Python bytecode. 44 | - Infrastructure: 45 | - Add AppVeyor for automated builds & tests on Windows. 46 | - Fix tests when running in Go Modules mode. 47 | - Fix `go vet ./...` issues. 48 | 49 | ## [v1.0.1](https://github.com/roblillack/mars/compare/v1.0.0...v1.0.1) 50 | 51 | - Router: Fix panic, if no router initialized. #10 52 | - Templates: 53 | - Fix panic, if no template loader initialized. #11 54 | - Setup fallback template loader to use embedded templates without configuration. #12 55 | 56 | ## [v1.0.0](https://github.com/roblillack/mars/compare/a9a2ff4...v1.0.0) 57 | 58 | - Let's make that 1.0.0. 59 | - Build with current Go versions. 60 | - Setup Go Module. 61 | - Remove old versioning information. 62 | - Remove git submodules. 63 | - mars-gen: More compiler error fixes. 64 | - Fix Go tip compiler errors. 65 | - Router: Add support for file extensions after action arguments. #9 66 | - Enable graceful shutdown.
#8 67 | - Implement shutdown hooks. #8 68 | - Remove support for Go <1.8. #8 69 | - README: Fix links 70 | - router: Fix path escaping for Go<1.8. #7 71 | - travis: Add Go 1.8 support. 72 | - router: Fix building reverse routes with path segments that contain reserved characters. #7 73 | - compress: Compress SVG images, too. 74 | - README: Document changes from #6. 75 | - Add HTTP(S) dualstack support (incl. option to generate self-signed certs). #6 76 | - mars-gen: Sort files by name when processing packages to get stable order of registered controllers. 77 | - mars-gen: Make sure we can generate without a resolvable Mars installation. Fixes #5. 78 | - Merge pull request #4 from ipoerner/issue-3 79 | - Default to no timeout with configurable changes. 80 | - Allow setting absolute ConfigFile path. 81 | - Fix panic_test for Go 1.5. 82 | - Add Go 1.7 to travis builds. 83 | - panic_test: Make go 1.5 error easier to debug. 84 | - Add test for panic filter. 85 | - templates: Add template availability test. 86 | - Fix context-aware translation function, add test. #2 87 | - templates: Fix HTML safe render func, add tests. #2 88 | - templates: First work towards HTML-safe translate func. #2 89 | - Document XSS fix #1. 90 | - templates: Stop implicitly marking translation output as safe. Fixes #1. 91 | - docs: Add testing.md 92 | - CSRF protection: Set debugging messages to trace, fix 'SkipCSRF' check. 93 | - README: Update to reflect CSRF changes. 94 | - csrf: Make this an actual func. 95 | - docs: Update to reflect CSRF stuff. 96 | - Add CSRF protection functionality. 97 | - server: Expose mars.Handler, also fixes refreshing Router. 98 | - server: Allow booting without having a config file. 99 | - Get rid of configurable template delimiters. 100 | - mars-gen: Fix parsing array types. 101 | - Add Coverage Status badge to README 102 | - Add coveralls integration 103 | - server: Don't set up watcher for templates, if we have no TemplateLoader.
104 | - travis: Build with Go 1.5, 1.6, and tip. 105 | - Remove "cron" dependency – not used anymore. 106 | - Remove glide from .gitmodules. 107 | - Change submodule repo paths to match glide config's. 108 | - Update fsnotify and x/net submodules. 109 | - Update fsnotify to v1.3.1, remove gopkg.in dependency. 110 | - Reformat glide.yaml. 111 | - docs: Fix formatting. 112 | - docs: Start working on the documentation. 113 | - Streamline logger configuration. 114 | - Fix fakeapp test. 115 | - Better handling of default values. 116 | - Remove unused code. 117 | - Remove Initialized flag. 118 | - Remove CodePaths, ConfPaths, ImportPath. 119 | - README: Document how to switch from Revel to Mars. 120 | - Sort controllers, when generating code to not pollute your 'git status' all the time. 121 | - Fix overriding embedded templates from application provided ones. 122 | - Add caching functionality to Static controller. 123 | - README: Add GoDoc reference image. 124 | - README: Update regarding the reverse route generation fixes. 125 | - mars-gen: Remove debug messages. 126 | - mars-gen: Add support for action parameters of type interface{} 127 | - mars-gen: Add support for action parameters called "args". 128 | - Further improve documentation of interception functionality. 129 | - Add render function changes to README. 130 | - Add documentation to interception handling. 131 | - Fix case for render functions and result types. 132 | - Code style improvements. 133 | - Remove magic, that adds template parameters based on variable names in Render calls. 134 | - README: Add GoDoc reference. 135 | - Start improving code style to make golint happier. 136 | - README: Document differences to revel. 137 | - Add mars-gen -- the code generator for registering controllers and reverse routes. 138 | - Remove revel modules. 139 | - Fix fakeapp test. 140 | - Add some links to README. 141 | - Add static controller (was a module before). 142 | - templates: Fix preloading of embedded templates. 
143 | - Remove module support. 144 | - Remove AppPath, ViewsPath, TemplatePaths. 145 | - template: Remove possibility of specifying delimiters. 146 | - fix build. 147 | - travis: Fix vendor experiment. 148 | - Add .gitmodules. 149 | - travis: Remove glide, dependencies are Git submodules. 150 | - travis: Enable Go Vendor Experiment 151 | - Enable travis. 152 | - More renaming. 153 | - Remove skeleton. Should really be another repository. 154 | - init: Start removing path magic, RevelPath + SourcePath is gone. 155 | - mime: Embed default mime-types. 156 | - Fix tests. 157 | - template: Add support for embedded error templates. 158 | - router: Remove automatically reading routes file. 159 | - controller: Set HTTP Status OK only after successfully loading template. 160 | - Remove cache/. 161 | - Vendor dependencies using Glide. 162 | - Add support for Go 1.5 Vendor Experiment. 163 | - Rename package to 'github.com/roblillack/mars'. 164 | -------------------------------------------------------------------------------- /filterconfig.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | // filterOverrides maps "Controller" or "Controller.Method" to the Filter chain that replaces the default one. 9 | var filterOverrides = make(map[string][]Filter) 10 | 11 | // FilterConfigurator allows the developer to configure the filter chain on a 12 | // per-controller or per-action basis. The filter configuration is applied by 13 | // the FilterConfiguringFilter, which is itself a filter stage. For example, 14 | // 15 | // Assuming: 16 | // Filters = []Filter{ 17 | // RouterFilter, 18 | // FilterConfiguringFilter, 19 | // SessionFilter, 20 | // ActionInvoker, 21 | // } 22 | // 23 | // Add: 24 | // FilterAction(App.Action). 25 | // Add(OtherFilter) 26 | // 27 | // => RouterFilter, FilterConfiguringFilter, SessionFilter, OtherFilter, ActionInvoker 28 | // 29 | // Remove: 30 | // FilterAction(App.Action).
31 | // Remove(SessionFilter) 32 | // 33 | // => RouterFilter, FilterConfiguringFilter, OtherFilter, ActionInvoker 34 | // 35 | // Insert: 36 | // FilterAction(App.Action). 37 | // Insert(OtherFilter, mars.BEFORE, SessionFilter) 38 | // 39 | // => RouterFilter, FilterConfiguringFilter, OtherFilter, SessionFilter, ActionInvoker 40 | // 41 | // Filter modifications may be combined between Controller and Action. For example: 42 | // FilterController(App{}). 43 | // Add(Filter1) 44 | // FilterAction(App.Action). 45 | // Add(Filter2) 46 | // 47 | // .. would result in App.Action being filtered by both Filter1 and Filter2. 48 | // 49 | // Note: the last filter stage is not subject to the configurator. In 50 | // particular, Add() adds a filter to the second-to-last place. 51 | type FilterConfigurator struct { 52 | key string // e.g. "App", "App.Action" 53 | controllerName string // e.g. "App" 54 | } 55 | 56 | func newFilterConfigurator(controllerName, methodName string) FilterConfigurator { // key is "Controller" for an empty methodName, else "Controller.Method" 57 | if methodName == "" { 58 | return FilterConfigurator{controllerName, controllerName} 59 | } 60 | return FilterConfigurator{controllerName + "." + methodName, controllerName} 61 | } 62 | 63 | // FilterController returns a configurator for the filters applied to all 64 | // actions on the given controller instance. For example: 65 | // FilterController(MyController{}) 66 | func FilterController(controllerInstance interface{}) FilterConfigurator { 67 | t := reflect.TypeOf(controllerInstance) 68 | for t.Kind() == reflect.Ptr { 69 | t = t.Elem() 70 | } 71 | return newFilterConfigurator(t.Name(), "") 72 | } 73 | 74 | // FilterAction returns a configurator for the filters applied to the given 75 | // controller method.
For example: 76 | // FilterAction(MyController.MyAction) 77 | func FilterAction(methodRef interface{}) FilterConfigurator { 78 | var ( 79 | methodValue = reflect.ValueOf(methodRef) 80 | methodType = methodValue.Type() 81 | ) 82 | if methodType.Kind() != reflect.Func || methodType.NumIn() == 0 { 83 | panic("Expecting a controller method reference (e.g. Controller.Action), got a " + 84 | methodType.String()) 85 | } 86 | 87 | controllerType := methodType.In(0) 88 | method := findMethod(controllerType, methodValue) 89 | if method == nil { 90 | panic("Action not found on controller " + controllerType.Name()) 91 | } 92 | 93 | for controllerType.Kind() == reflect.Ptr { 94 | controllerType = controllerType.Elem() 95 | } 96 | 97 | return newFilterConfigurator(controllerType.Name(), method.Name) 98 | } 99 | 100 | // Add the given filter in the second-to-last position in the filter chain. 101 | // (Second-to-last so that it is before ActionInvoker) 102 | func (conf FilterConfigurator) Add(f Filter) FilterConfigurator { 103 | conf.apply(func(fc []Filter) []Filter { 104 | return conf.addFilter(f, fc) 105 | }) 106 | return conf 107 | } 108 | 109 | func (conf FilterConfigurator) addFilter(f Filter, fc []Filter) []Filter { // inserts f before the last element; may write into fc's backing array and panics if fc is empty 110 | return append(fc[:len(fc)-1], f, fc[len(fc)-1]) 111 | } 112 | 113 | // Remove a filter from the filter chain. 114 | func (conf FilterConfigurator) Remove(target Filter) FilterConfigurator { 115 | conf.apply(func(fc []Filter) []Filter { 116 | return conf.rmFilter(target, fc) 117 | }) 118 | return conf 119 | } 120 | 121 | func (conf FilterConfigurator) rmFilter(target Filter, fc []Filter) []Filter { // removes the first match only, shifting later elements within fc's backing array 122 | for i, f := range fc { 123 | if FilterEq(f, target) { 124 | return append(fc[:i], fc[i+1:]...) 125 | } 126 | } 127 | return fc 128 | } 129 | 130 | // Insert a filter into the filter chain before or after another. 131 | // This may be called with the BEFORE or AFTER constants, for example: 132 | // mars.FilterAction(App.Index).
133 | // Insert(MyFilter, mars.BEFORE, mars.ActionInvoker). 134 | // Insert(MyFilter2, mars.AFTER, mars.PanicFilter) 135 | func (conf FilterConfigurator) Insert(insert Filter, where When, target Filter) FilterConfigurator { 136 | if where != BEFORE && where != AFTER { 137 | panic("where must be BEFORE or AFTER") 138 | } 139 | conf.apply(func(fc []Filter) []Filter { 140 | return conf.insertFilter(insert, where, target, fc) 141 | }) 142 | return conf 143 | } 144 | 145 | func (conf FilterConfigurator) insertFilter(insert Filter, where When, target Filter, fc []Filter) []Filter { // may write into fc's backing array; returns fc unchanged if target is absent 146 | for i, f := range fc { 147 | if FilterEq(f, target) { 148 | if where == BEFORE { 149 | return append(fc[:i], append([]Filter{insert}, fc[i:]...)...) 150 | } else { 151 | return append(fc[:i+1], append([]Filter{insert}, fc[i+1:]...)...) 152 | } 153 | } 154 | } 155 | return fc 156 | } 157 | 158 | // getChain returns a fresh copy of the filter chain that applies to the given 159 | // controller or action: its configured override if there is one, otherwise the 160 | // default chain (all of Filters after FilterConfiguringFilter). 161 | func (conf FilterConfigurator) getChain() []Filter { 162 | var filters []Filter 163 | if filters = getOverrideChain(conf.controllerName, conf.key); filters == nil { 164 | // The override starts with all filters after FilterConfiguringFilter 165 | for i, f := range Filters { 166 | if FilterEq(f, FilterConfiguringFilter) { 167 | filters = make([]Filter, len(Filters)-i-1) 168 | copy(filters, Filters[i+1:]) 169 | break 170 | } 171 | } 172 | if filters == nil { 173 | panic("FilterConfiguringFilter not found in mars.Filters.") 174 | } 175 | } 176 | return append(make([]Filter, 0, len(filters)), filters...) // always copy: addFilter/rmFilter/insertFilter append into the slice's backing array, and an action's chain may start from the controller's stored override 177 | } 178 | 179 | // apply applies the given functional change to the filter overrides. 180 | // No other function modifies the filterOverrides map. 181 | func (conf FilterConfigurator) apply(f func([]Filter) []Filter) { 182 | // Updates any actions that have had their filters overridden, if this is a 183 | // Controller configurator.
184 | if conf.controllerName == conf.key { 185 | for k, v := range filterOverrides { 186 | if strings.HasPrefix(k, conf.controllerName+".") { 187 | filterOverrides[k] = f(v) 188 | } 189 | } 190 | } 191 | 192 | // Update the Controller or Action overrides. 193 | filterOverrides[conf.key] = f(conf.getChain()) 194 | } 195 | 196 | // FilterEq returns true if the two filters reference the same filter. 197 | func FilterEq(a, b Filter) bool { 198 | return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer() 199 | } 200 | 201 | // FilterConfiguringFilter is a filter stage that customizes the remaining 202 | // filter chain for the action being invoked. 203 | func FilterConfiguringFilter(c *Controller, fc []Filter) { 204 | if newChain := getOverrideChain(c.Name, c.Action); newChain != nil { 205 | newChain[0](c, newChain[1:]) 206 | return 207 | } 208 | fc[0](c, fc[1:]) 209 | } 210 | 211 | // getOverrideChain returns the override chain stored for action, falling back to the one for controllerName, or nil if neither exists. 212 | func getOverrideChain(controllerName, action string) []Filter { 213 | if newChain, ok := filterOverrides[action]; ok { 214 | return newChain 215 | } 216 | if newChain, ok := filterOverrides[controllerName]; ok { 217 | return newChain 218 | } 219 | return nil 220 | } 221 | -------------------------------------------------------------------------------- /validation.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | "regexp" 8 | "runtime" 9 | ) 10 | 11 | // ValidationError stores the Message & Key of a validation error. 12 | type ValidationError struct { 13 | Message, Key string 14 | } 15 | 16 | // String returns the Message field of the ValidationError struct. 17 | func (e *ValidationError) String() string { 18 | if e == nil { 19 | return "" 20 | } 21 | return e.Message 22 | } 23 | 24 | // A Validation context manages data validation and error messages.
25 | type Validation struct { 26 | Errors []*ValidationError 27 | keep bool 28 | } 29 | 30 | // Keep tells Mars to set a flash cookie on the client to make the validation 31 | // errors available for the next request. 32 | // This is helpful when redirecting the client after the validation failed. 33 | // It is good practice to always redirect upon a HTTP POST request. Thus 34 | // one should use this method when HTTP POST validation failed and redirect 35 | // the user back to the form. 36 | func (v *Validation) Keep() { 37 | v.keep = true 38 | } 39 | 40 | // Clear removes *all* ValidationErrors. 41 | func (v *Validation) Clear() { 42 | v.Errors = []*ValidationError{} 43 | } 44 | 45 | // HasErrors returns true if there are any (i.e. > 0) errors. False otherwise. 46 | func (v *Validation) HasErrors() bool { 47 | return len(v.Errors) > 0 48 | } 49 | 50 | // ErrorMap returns the errors mapped by key. 51 | // If there are multiple validation errors associated with a single key, the 52 | // first one "wins". (Typically the first validation will be the more basic). 53 | func (v *Validation) ErrorMap() map[string]*ValidationError { 54 | m := map[string]*ValidationError{} 55 | for _, e := range v.Errors { 56 | if _, ok := m[e.Key]; !ok { 57 | m[e.Key] = e 58 | } 59 | } 60 | return m 61 | } 62 | 63 | // Error adds an error to the validation context. 64 | func (v *Validation) Error(message string, args ...interface{}) *ValidationResult { 65 | result := (&ValidationResult{ 66 | Ok: false, 67 | Error: &ValidationError{}, 68 | }).Message(message, args...) 69 | v.Errors = append(v.Errors, result.Error) 70 | return result 71 | } 72 | 73 | // A ValidationResult is returned from every validation method. 74 | // It provides an indication of success, and a pointer to the Error (if any).
75 | type ValidationResult struct { 76 | Error *ValidationError 77 | Ok bool 78 | } 79 | 80 | // Key sets the ValidationResult's Error "key" and returns itself for chaining 81 | func (r *ValidationResult) Key(key string) *ValidationResult { 82 | if r.Error != nil { 83 | r.Error.Key = key 84 | } 85 | return r 86 | } 87 | 88 | // Message sets the error message for a ValidationResult. Returns itself to 89 | // allow chaining. Allows Sprintf() type calling with multiple parameters 90 | func (r *ValidationResult) Message(message string, args ...interface{}) *ValidationResult { 91 | if r.Error != nil { 92 | if len(args) == 0 { 93 | r.Error.Message = message 94 | } else { 95 | r.Error.Message = fmt.Sprintf(message, args...) 96 | } 97 | } 98 | return r 99 | } 100 | 101 | // Required tests that the argument is non-nil and non-empty (if string or list) 102 | func (v *Validation) Required(obj interface{}) *ValidationResult { 103 | return v.apply(Required{}, obj) 104 | } 105 | 106 | func (v *Validation) Min(n int, min int) *ValidationResult { 107 | return v.apply(Min{min}, n) 108 | } 109 | 110 | func (v *Validation) Max(n int, max int) *ValidationResult { 111 | return v.apply(Max{max}, n) 112 | } 113 | 114 | func (v *Validation) Range(n, min, max int) *ValidationResult { 115 | return v.apply(Range{Min{min}, Max{max}}, n) 116 | } 117 | 118 | func (v *Validation) MinSize(obj interface{}, min int) *ValidationResult { 119 | return v.apply(MinSize{min}, obj) 120 | } 121 | 122 | func (v *Validation) MaxSize(obj interface{}, max int) *ValidationResult { 123 | return v.apply(MaxSize{max}, obj) 124 | } 125 | 126 | func (v *Validation) Length(obj interface{}, n int) *ValidationResult { 127 | return v.apply(Length{n}, obj) 128 | } 129 | 130 | func (v *Validation) Match(str string, regex *regexp.Regexp) *ValidationResult { 131 | return v.apply(Match{regex}, str) 132 | } 133 | 134 | func (v *Validation) Email(str string) *ValidationResult { 135 | return v.apply(Email{Match{emailPattern}},
str) 136 | } 137 | 138 | func (v *Validation) apply(chk Validator, obj interface{}) *ValidationResult { 139 | if chk.IsSatisfied(obj) { 140 | return &ValidationResult{Ok: true} 141 | } 142 | 143 | // Get the default key. 144 | var key string 145 | if pc, _, line, ok := runtime.Caller(2); ok { // skip apply and the Validation method that called it, to reach the user's call site 146 | f := runtime.FuncForPC(pc) 147 | if defaultKeys, ok := DefaultValidationKeys[f.Name()]; ok { 148 | key = defaultKeys[line] 149 | } 150 | } else { 151 | INFO.Println("Failed to get Caller information to look up Validation key") 152 | } 153 | 154 | // Add the error to the validation context. 155 | err := &ValidationError{ 156 | Message: chk.DefaultMessage(), 157 | Key: key, 158 | } 159 | v.Errors = append(v.Errors, err) 160 | 161 | // Also return it in the result. 162 | return &ValidationResult{ 163 | Ok: false, 164 | Error: err, 165 | } 166 | } 167 | 168 | // Check applies a group of validators to a field, in order, and returns the 169 | // ValidationResult from the first one that fails, or the last one that 170 | // succeeds. Without any validators, it returns a successful (Ok) result. 171 | func (v *Validation) Check(obj interface{}, checks ...Validator) *ValidationResult { 172 | result := &ValidationResult{Ok: true} // never nil, so callers can always read result.Ok 173 | for _, check := range checks { 174 | result = v.apply(check, obj) 175 | if !result.Ok { 176 | return result 177 | } 178 | } 179 | return result 180 | } 181 | 182 | // ValidationFilter is a Mars filter that restores kept validation errors, runs the rest of the chain, and stores kept errors in a cookie. 183 | func ValidationFilter(c *Controller, fc []Filter) { 184 | errors, err := restoreValidationErrors(c.Request.Request) 185 | c.Validation = &Validation{ 186 | Errors: errors, 187 | keep: false, 188 | } 189 | hasCookie := (err != http.ErrNoCookie) 190 | 191 | fc[0](c, fc[1:]) 192 | 193 | // Add Validation errors to RenderArgs.
194 | c.RenderArgs["errors"] = c.Validation.ErrorMap() 195 | 196 | // Store the Validation errors 197 | var errorsValue string 198 | if c.Validation.keep { 199 | for _, verr := range c.Validation.Errors { 200 | if verr.Message != "" { 201 | errorsValue += "\x00" + verr.Key + ":" + verr.Message + "\x00" 202 | } 203 | } 204 | } 205 | 206 | // When there are errors from Validation and Keep() has been called, store the 207 | // values in a cookie. If there previously was a cookie but no errors, remove 208 | // the cookie. 209 | if errorsValue != "" { 210 | c.SetCookie(&http.Cookie{ 211 | Name: CookiePrefix + "_ERRORS", 212 | Value: url.QueryEscape(errorsValue), 213 | Domain: CookieDomain, 214 | Path: "/", 215 | HttpOnly: CookieHttpOnly, 216 | Secure: CookieSecure, 217 | }) 218 | } else if hasCookie { 219 | c.SetCookie(&http.Cookie{ 220 | Name: CookiePrefix + "_ERRORS", 221 | MaxAge: -1, 222 | Domain: CookieDomain, 223 | Path: "/", 224 | HttpOnly: CookieHttpOnly, 225 | Secure: CookieSecure, 226 | }) 227 | } 228 | } 229 | 230 | // restoreValidationErrors restores Validation.Errors from the request's errors cookie. 231 | func restoreValidationErrors(req *http.Request) ([]*ValidationError, error) { 232 | var ( 233 | err error 234 | cookie *http.Cookie 235 | errors = make([]*ValidationError, 0, 5) 236 | ) 237 | if cookie, err = req.Cookie(CookiePrefix + "_ERRORS"); err == nil { 238 | parseKeyValueCookie(cookie.Value, func(key, val string) { 239 | errors = append(errors, &ValidationError{ 240 | Key: key, 241 | Message: val, 242 | }) 243 | }) 244 | } 245 | return errors, err 246 | } 247 | 248 | // Register default validation keys for all calls to Controller.Validation.Func(). 249 | // Map from (package).func => (line => name of first arg to Validation func) 250 | // E.g. "myapp/controllers.helper" or "myapp/controllers.(*Application).Action" 251 | // This is set on initialization in the generated main.go file.
252 | var DefaultValidationKeys map[string]map[int]string 253 | -------------------------------------------------------------------------------- /i18n.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "html" 6 | "html/template" 7 | "os" 8 | "path/filepath" 9 | "regexp" 10 | "strings" 11 | 12 | "github.com/robfig/config" 13 | ) 14 | 15 | const ( 16 | CurrentLocaleRenderArg = "currentLocale" // The key for the current locale render arg value 17 | 18 | messageFilesDirectory = "messages" 19 | messageFilePattern = `^\w+\.[a-zA-Z]{2}$` 20 | unknownValueFormat = "??? %s ???" 21 | defaultLanguageOption = "i18n.default_language" 22 | localeCookieConfigKey = "i18n.cookie" 23 | ) 24 | 25 | var ( 26 | // All currently loaded message configs. 27 | messages map[string]*config.Config 28 | ) 29 | 30 | // Return all currently loaded message languages. 31 | func MessageLanguages() []string { 32 | languages := make([]string, len(messages)) 33 | i := 0 34 | for language := range messages { 35 | languages[i] = language 36 | i++ 37 | } 38 | return languages 39 | } 40 | 41 | // Perform a message look-up for the given locale and message using the given arguments. 42 | // 43 | // When either an unknown locale or message is detected, a specially formatted string is returned. 
44 | func Message(locale, message string, args ...interface{}) string { 45 | language, region := parseLocale(locale) 46 | 47 | messageConfig, knownLanguage := messages[language] 48 | if !knownLanguage { 49 | TRACE.Printf("Unsupported language for locale '%s' and message '%s', trying default language", locale, message) 50 | 51 | if defaultLanguage, found := Config.String(defaultLanguageOption); found { 52 | TRACE.Printf("Using default language '%s'", defaultLanguage) 53 | 54 | messageConfig, knownLanguage = messages[defaultLanguage] 55 | if !knownLanguage { 56 | WARN.Printf("Unsupported default language for locale '%s' and message '%s'", defaultLanguage, message) 57 | return fmt.Sprintf(unknownValueFormat, message) 58 | } 59 | } else { 60 | WARN.Printf("Unable to find default language option (%s); messages for unsupported locales will never be translated", defaultLanguageOption) 61 | return fmt.Sprintf(unknownValueFormat, message) 62 | } 63 | } 64 | 65 | // This works because unlike the goconfig documentation suggests it will actually 66 | // try to resolve message in DEFAULT if it did not find it in the given section. 67 | value, error := messageConfig.String(region, message) 68 | if error != nil { 69 | WARN.Printf("Unknown message '%s' for locale '%s'", message, locale) 70 | return fmt.Sprintf(unknownValueFormat, message) 71 | } 72 | 73 | if len(args) > 0 { 74 | TRACE.Printf("Arguments detected, formatting '%s' with %v", value, args) 75 | value = fmt.Sprintf(value, args...) 76 | } 77 | 78 | return value 79 | } 80 | 81 | // MessageHTML performs a message look-up for the given locale and message using the given arguments 82 | // and guarantees, that safe HTML is always returned. 
83 | func MessageHTML(locale, key string, args ...interface{}) template.HTML { 84 | if !strings.HasSuffix(key, ".html") && !strings.HasSuffix(key, "_html") { 85 | return template.HTML(html.EscapeString(Message(locale, key, args...))) 86 | } 87 | 88 | safeArgs := make([]interface{}, len(args)) 89 | for idx, arg := range args { 90 | switch val := arg.(type) { 91 | case template.HTML: 92 | safeArgs[idx] = val 93 | case string: 94 | safeArgs[idx] = html.EscapeString(val) 95 | case fmt.Stringer: 96 | safeArgs[idx] = html.EscapeString(val.String()) 97 | case []byte: 98 | safeArgs[idx] = []byte(html.EscapeString(string(val))) 99 | case bool: 100 | safeArgs[idx] = val 101 | case float32: 102 | safeArgs[idx] = val 103 | case float64: 104 | safeArgs[idx] = val 105 | case complex64: 106 | safeArgs[idx] = val 107 | case complex128: 108 | safeArgs[idx] = val 109 | case int: 110 | safeArgs[idx] = val 111 | case int8: 112 | safeArgs[idx] = val 113 | case int16: 114 | safeArgs[idx] = val 115 | case int32: 116 | safeArgs[idx] = val 117 | case int64: 118 | safeArgs[idx] = val 119 | case uint: 120 | safeArgs[idx] = val 121 | case uint8: 122 | safeArgs[idx] = val 123 | case uint16: 124 | safeArgs[idx] = val 125 | case uint32: 126 | safeArgs[idx] = val 127 | case uint64: 128 | safeArgs[idx] = val 129 | case uintptr: 130 | safeArgs[idx] = val 131 | default: 132 | safeArgs[idx] = html.EscapeString(fmt.Sprint(val)) 133 | } 134 | } 135 | 136 | return template.HTML(Message(locale, key, safeArgs...)) 137 | } 138 | 139 | func parseLocale(locale string) (language, region string) { 140 | if strings.Contains(locale, "-") { 141 | languageAndRegion := strings.Split(locale, "-") 142 | return languageAndRegion[0], languageAndRegion[1] 143 | } 144 | 145 | return locale, "" 146 | } 147 | 148 | // Recursively read and cache all available messages from all message files on the given path. 
149 | func loadMessages(path string) { 150 | messages = make(map[string]*config.Config) 151 | 152 | if error := filepath.Walk(path, loadMessageFile); error != nil && !os.IsNotExist(error) { 153 | ERROR.Println("Error reading messages files:", error) 154 | } 155 | } 156 | 157 | // Load a single message file 158 | func loadMessageFile(path string, info os.FileInfo, osError error) error { 159 | if osError != nil { 160 | return osError 161 | } 162 | if info.IsDir() { 163 | return nil 164 | } 165 | 166 | if matched, _ := regexp.MatchString(messageFilePattern, info.Name()); matched { 167 | if config, error := parseMessagesFile(path); error != nil { 168 | return error 169 | } else { 170 | locale := parseLocaleFromFileName(info.Name()) 171 | 172 | // If we have already parsed a message file for this locale, merge both 173 | if _, exists := messages[locale]; exists { 174 | messages[locale].Merge(config) 175 | TRACE.Printf("Successfully merged messages for locale '%s'", locale) 176 | } else { 177 | messages[locale] = config 178 | } 179 | 180 | TRACE.Println("Successfully loaded messages from file", info.Name()) 181 | } 182 | } else { 183 | TRACE.Printf("Ignoring file %s because it did not have a valid extension", info.Name()) 184 | } 185 | 186 | return nil 187 | } 188 | 189 | func parseMessagesFile(path string) (messageConfig *config.Config, error error) { 190 | messageConfig, error = config.ReadDefault(path) 191 | return 192 | } 193 | 194 | func parseLocaleFromFileName(file string) string { 195 | extension := filepath.Ext(file)[1:] 196 | return strings.ToLower(extension) 197 | } 198 | 199 | func init() { 200 | OnAppStart(func() { 201 | loadMessages(filepath.Join(BasePath, messageFilesDirectory)) 202 | }) 203 | } 204 | 205 | func I18nFilter(c *Controller, fc []Filter) { 206 | if foundCookie, cookieValue := hasLocaleCookie(c.Request); foundCookie { 207 | TRACE.Printf("Found locale cookie value: %s", cookieValue) 208 | setCurrentLocaleControllerArguments(c, cookieValue) 209 | 
} else if foundHeader, headerValue := hasAcceptLanguageHeader(c.Request); foundHeader { 210 | TRACE.Printf("Found Accept-Language header value: %s", headerValue) 211 | setCurrentLocaleControllerArguments(c, headerValue) 212 | } else { 213 | TRACE.Println("Unable to find locale in cookie or header, using empty string") 214 | setCurrentLocaleControllerArguments(c, "") 215 | } 216 | fc[0](c, fc[1:]) 217 | } 218 | 219 | // Set the current locale controller argument (CurrentLocaleControllerArg) with the given locale. 220 | func setCurrentLocaleControllerArguments(c *Controller, locale string) { 221 | c.Request.Locale = locale 222 | c.RenderArgs[CurrentLocaleRenderArg] = locale 223 | } 224 | 225 | // Determine whether the given request has valid Accept-Language value. 226 | // 227 | // Assumes that the accept languages stored in the request are sorted according to quality, with top 228 | // quality first in the slice. 229 | func hasAcceptLanguageHeader(request *Request) (bool, string) { 230 | if request.AcceptLanguages != nil && len(request.AcceptLanguages) > 0 { 231 | return true, removeAllWhitespace(request.AcceptLanguages[0].Language) 232 | } 233 | 234 | return false, "" 235 | } 236 | 237 | // Determine whether the given request has a valid language cookie value. 
238 | func hasLocaleCookie(request *Request) (bool, string) { 239 | if request != nil && request.Cookies() != nil { 240 | name := Config.StringDefault(localeCookieConfigKey, CookiePrefix+"_LANG") 241 | if cookie, error := request.Cookie(name); error == nil { 242 | return true, removeAllWhitespace(cookie.Value) 243 | } else { 244 | TRACE.Printf("Unable to read locale cookie with name '%s': %s", name, error.Error()) 245 | } 246 | } 247 | 248 | return false, "" 249 | } 250 | -------------------------------------------------------------------------------- /validators_test.go: -------------------------------------------------------------------------------- 1 | package mars 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "regexp" 7 | "strings" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | const ( 13 | errorsMessage = "validation for %s should not be satisfied with %s\n" 14 | noErrorsMessage = "validation for %s should be satisfied with %s\n" 15 | ) 16 | 17 | type Expect struct { 18 | input interface{} 19 | expectedResult bool 20 | errorMessage string 21 | } 22 | 23 | func performTests(validator Validator, tests []Expect, t *testing.T) { 24 | for _, test := range tests { 25 | if validator.IsSatisfied(test.input) != test.expectedResult { 26 | if test.expectedResult == false { 27 | t.Errorf(errorsMessage, reflect.TypeOf(validator), test.errorMessage) 28 | } else { 29 | t.Errorf(noErrorsMessage, reflect.TypeOf(validator), test.errorMessage) 30 | } 31 | } 32 | } 33 | } 34 | 35 | func TestRequired(t *testing.T) { 36 | 37 | tests := []Expect{ 38 | {nil, false, "nil data"}, 39 | {"Testing", true, "non-empty string"}, 40 | {"", false, "empty string"}, 41 | {true, true, "true boolean"}, 42 | {false, false, "false boolean"}, 43 | {1, true, "positive integer"}, 44 | {-1, true, "negative integer"}, 45 | {0, false, "0 integer"}, 46 | {time.Now(), true, "current time"}, 47 | {time.Time{}, false, "a zero time"}, 48 | {func() {}, true, "other non-nil data types"}, 49 | } 50 | 51 | // testing 
both the struct and the helper method 52 | for _, required := range []Required{{}, ValidRequired()} { 53 | performTests(required, tests, t) 54 | } 55 | } 56 | 57 | func TestMin(t *testing.T) { 58 | tests := []Expect{ 59 | {11, true, "val > min"}, 60 | {10, true, "val == min"}, 61 | {9, false, "val < min"}, 62 | {true, false, "TypeOf(val) != int"}, 63 | } 64 | for _, min := range []Min{{10}, ValidMin(10)} { 65 | performTests(min, tests, t) 66 | } 67 | } 68 | 69 | func TestMax(t *testing.T) { 70 | tests := []Expect{ 71 | {9, true, "val < max"}, 72 | {10, true, "val == max"}, 73 | {11, false, "val > max"}, 74 | {true, false, "TypeOf(val) != int"}, 75 | } 76 | for _, max := range []Max{{10}, ValidMax(10)} { 77 | performTests(max, tests, t) 78 | } 79 | } 80 | 81 | func TestRange(t *testing.T) { 82 | tests := []Expect{ 83 | {50, true, "min <= val <= max"}, 84 | {10, true, "val == min"}, 85 | {100, true, "val == max"}, 86 | {9, false, "val < min"}, 87 | {101, false, "val > max"}, 88 | } 89 | 90 | goodValidators := []Range{ 91 | {Min{10}, Max{100}}, 92 | ValidRange(10, 100), 93 | } 94 | for _, rangeValidator := range goodValidators { 95 | performTests(rangeValidator, tests, t) 96 | } 97 | 98 | tests = []Expect{ 99 | {10, true, "min == val == max"}, 100 | {9, false, "val < min && val < max && min == max"}, 101 | {11, false, "val > min && val > max && min == max"}, 102 | } 103 | 104 | goodValidators = []Range{ 105 | {Min{10}, Max{10}}, 106 | ValidRange(10, 10), 107 | } 108 | for _, rangeValidator := range goodValidators { 109 | performTests(rangeValidator, tests, t) 110 | } 111 | 112 | tests = make([]Expect, 7) 113 | for i, num := range []int{50, 100, 10, 9, 101, 0, -1} { 114 | tests[i] = Expect{ 115 | num, 116 | false, 117 | "min > val < max", 118 | } 119 | } 120 | // these are min/max with values swapped, so the min is the high 121 | // and max is the low. 
rangeValidator.IsSatisfied() should ALWAYS 122 | // result in false since val can never be greater than min and less 123 | // than max when min > max 124 | badValidators := []Range{ 125 | {Min{100}, Max{10}}, 126 | ValidRange(100, 10), 127 | } 128 | for _, rangeValidator := range badValidators { 129 | performTests(rangeValidator, tests, t) 130 | } 131 | } 132 | 133 | func TestMinSize(t *testing.T) { 134 | greaterThanMessage := "len(val) >= min" 135 | tests := []Expect{ 136 | {"1", true, greaterThanMessage}, 137 | {"12", true, greaterThanMessage}, 138 | {[]int{1}, true, greaterThanMessage}, 139 | {[]int{1, 2}, true, greaterThanMessage}, 140 | {"", false, "len(val) <= min"}, 141 | {[]int{}, false, "len(val) <= min"}, 142 | {nil, false, "TypeOf(val) != string && TypeOf(val) != slice"}, 143 | } 144 | 145 | for _, minSize := range []MinSize{{1}, ValidMinSize(1)} { 146 | performTests(minSize, tests, t) 147 | } 148 | } 149 | 150 | func TestMaxSize(t *testing.T) { 151 | lessThanMessage := "len(val) <= max" 152 | tests := []Expect{ 153 | {"", true, lessThanMessage}, 154 | {"12", true, lessThanMessage}, 155 | {[]int{}, true, lessThanMessage}, 156 | {[]int{1, 2}, true, lessThanMessage}, 157 | {"123", false, "len(val) >= max"}, 158 | {[]int{1, 2, 3}, false, "len(val) >= max"}, 159 | } 160 | for _, maxSize := range []MaxSize{{2}, ValidMaxSize(2)} { 161 | performTests(maxSize, tests, t) 162 | } 163 | } 164 | 165 | func TestLength(t *testing.T) { 166 | tests := []Expect{ 167 | {"12", true, "len(val) == length"}, 168 | {[]int{1, 2}, true, "len(val) == length"}, 169 | {"123", false, "len(val) > length"}, 170 | {[]int{1, 2, 3}, false, "len(val) > length"}, 171 | {"1", false, "len(val) < length"}, 172 | {[]int{1}, false, "len(val) < length"}, 173 | {nil, false, "TypeOf(val) != string && TypeOf(val) != slice"}, 174 | } 175 | for _, length := range []Length{{2}, ValidLength(2)} { 176 | performTests(length, tests, t) 177 | } 178 | } 179 | 180 | func TestMatch(t *testing.T) { 181 | tests 
:= []Expect{ 182 | {"bca123", true, `"[abc]{3}\d*" matches "bca123"`}, 183 | {"bc123", false, `"[abc]{3}\d*" does not match "bc123"`}, 184 | {"", false, `"[abc]{3}\d*" does not match ""`}, 185 | } 186 | regex := regexp.MustCompile(`[abc]{3}\d*`) 187 | for _, match := range []Match{{regex}, ValidMatch(regex)} { 188 | performTests(match, tests, t) 189 | } 190 | } 191 | 192 | func TestEmail(t *testing.T) { 193 | // unicode char included 194 | validStartingCharacters := strings.Split("!#$%^&*_+1234567890abcdefghijklmnopqrstuvwxyzñ", "") 195 | invalidCharacters := strings.Split(" ()", "") 196 | 197 | definiteInvalidDomains := []string{ 198 | "", // any empty string (x@) 199 | ".com", // only the TLD (x@.com) 200 | ".", // only the . (x@.) 201 | ".*", // TLD containing symbol (x@.*) 202 | "asdf", // no TLD 203 | "a!@#$%^&*()+_.com", // characters which are not ASCII/0-9/dash(-) in a domain 204 | "-a.com", // host starting with any symbol 205 | "a-.com", // host ending with any symbol 206 | "aå.com", // domain containing unicode (however, unicode domains do exist in the state of xn--.com e.g. 
å.com = xn--5ca.com) 207 | } 208 | 209 | for _, email := range []Email{{Match{emailPattern}}, ValidEmail()} { 210 | var currentEmail string 211 | 212 | // test invalid starting chars 213 | for _, startingChar := range validStartingCharacters { 214 | currentEmail = fmt.Sprintf("%sñbc+123@do-main.com", startingChar) 215 | if email.IsSatisfied(currentEmail) { 216 | t.Errorf(noErrorsMessage, "starting characters", fmt.Sprintf("email = %s", currentEmail)) 217 | } 218 | 219 | // validation should fail because of multiple @ symbols 220 | currentEmail = fmt.Sprintf("%s@ñbc+123@do-main.com", startingChar) 221 | if email.IsSatisfied(currentEmail) { 222 | t.Errorf(errorsMessage, "starting characters with multiple @ symbols", fmt.Sprintf("email = %s", currentEmail)) 223 | } 224 | 225 | // should fail simply because of the invalid char 226 | for _, invalidChar := range invalidCharacters { 227 | currentEmail = fmt.Sprintf("%sñbc%s+123@do-main.com", startingChar, invalidChar) 228 | if email.IsSatisfied(currentEmail) { 229 | t.Errorf(errorsMessage, "invalid starting characters", fmt.Sprintf("email = %s", currentEmail)) 230 | } 231 | } 232 | } 233 | 234 | // test invalid domains 235 | for _, invalidDomain := range definiteInvalidDomains { 236 | currentEmail = fmt.Sprintf("a@%s", invalidDomain) 237 | if email.IsSatisfied(currentEmail) { 238 | t.Errorf(errorsMessage, "invalid domain", fmt.Sprintf("email = %s", currentEmail)) 239 | } 240 | } 241 | 242 | // should always be satisfied 243 | if !email.IsSatisfied("t0.est+email123@1abc0-def.com") { 244 | t.Errorf(noErrorsMessage, "guaranteed valid email", fmt.Sprintf("email = %s", "t0.est+email123@1abc0-def.com")) 245 | } 246 | 247 | // should never be satisfied (this is redundant given the loops above) 248 | if email.IsSatisfied("a@xcom") { 249 | t.Errorf(noErrorsMessage, "guaranteed invalid email", fmt.Sprintf("email = %s", "a@xcom")) 250 | } 251 | if email.IsSatisfied("a@@x.com") { 252 | t.Errorf(noErrorsMessage, "guaranteed invaild 
email", fmt.Sprintf("email = %s", "a@@x.com")) 253 | } 254 | } 255 | } 256 | -------------------------------------------------------------------------------- /cmd/mars-gen/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "go/ast" 5 | "go/parser" 6 | "go/token" 7 | "reflect" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | var TypeExprs = map[string]TypeExpr{ 13 | "int": {"int", "", 0, true}, 14 | "*int": {"*int", "", 1, true}, 15 | "[]int": {"[]int", "", 2, true}, 16 | "...int": {"[]int", "", 2, true}, 17 | "[]*int": {"[]*int", "", 3, true}, 18 | "...*int": {"[]*int", "", 3, true}, 19 | "MyType": {"MyType", "pkg", 0, true}, 20 | "*MyType": {"*MyType", "pkg", 1, true}, 21 | "[]MyType": {"[]MyType", "pkg", 2, true}, 22 | "...MyType": {"[]MyType", "pkg", 2, true}, 23 | "[]*MyType": {"[]*MyType", "pkg", 3, true}, 24 | "...*MyType": {"[]*MyType", "pkg", 3, true}, 25 | "interface{}": {"interface{}", "", 0, true}, 26 | "...interface{}": {"[]interface{}", "", 2, true}, 27 | "any": {"any", "", 0, true}, 28 | "...any": {"[]any", "", 2, true}, 29 | } 30 | 31 | func TestTypeExpr(t *testing.T) { 32 | for str, expected := range TypeExprs { 33 | typeStr := str 34 | // Handle arrays and ... myself, since ParseExpr() does not. 
35 | array := strings.HasPrefix(typeStr, "[]") 36 | if array { 37 | typeStr = typeStr[2:] 38 | } 39 | 40 | ellipsis := strings.HasPrefix(typeStr, "...") 41 | if ellipsis { 42 | typeStr = typeStr[3:] 43 | } 44 | 45 | expr, err := parser.ParseExpr(typeStr) 46 | if err != nil { 47 | t.Error("Failed to parse test expr:", typeStr) 48 | continue 49 | } 50 | 51 | if array { 52 | expr = &ast.ArrayType{Lbrack: expr.Pos(), Len: nil, Elt: expr} 53 | } 54 | if ellipsis { 55 | expr = &ast.Ellipsis{Ellipsis: expr.Pos(), Elt: expr} 56 | } 57 | 58 | actual := NewTypeExpr("pkg", expr) 59 | if !reflect.DeepEqual(expected, actual) { 60 | t.Errorf("Fail, expected '%v' for '%s', got '%v'\n", expected, str, actual) 61 | } 62 | } 63 | } 64 | 65 | const testApplication = ` 66 | package test 67 | 68 | import ( 69 | "os" 70 | 71 | "bytes" 72 | "database/sql" 73 | "errors" 74 | "fmt" 75 | "html/template" 76 | "math/rand" 77 | "net/http" 78 | "net/url" 79 | "path" 80 | "runtime" 81 | "sort" 82 | "strconv" 83 | "strings" 84 | "time" 85 | 86 | myMars "github.com/roblillack/mars" 87 | ) 88 | 89 | type Hotel struct { 90 | HotelId int 91 | Name, Address string 92 | City, State, Zip string 93 | Country string 94 | Price int 95 | } 96 | 97 | type Application struct { 98 | myMars.Controller 99 | KnownUser bool 100 | } 101 | 102 | type Hotels struct { 103 | Application 104 | } 105 | 106 | type Static struct { 107 | *myMars.Controller 108 | } 109 | 110 | type Bla struct { 111 | Number int 112 | Text string 113 | } 114 | 115 | type Blurp struct { 116 | Bla 117 | Checkbox bool 118 | } 119 | 120 | func (blurp Blurp) Index() myMars.Result { 121 | return nil 122 | } 123 | 124 | func (c Hotels) Show(id int) myMars.Result { 125 | title := "View Hotel" 126 | hotel := &Hotel{id, "A Hotel", "300 Main St.", "New York", "NY", "10010", "USA", 300} 127 | return c.Render(title, hotel) 128 | } 129 | 130 | func (c Hotels) Book(id int) myMars.Result { 131 | hotel := &Hotel{id, "A Hotel", "300 Main St.", "New York", 
"NY", "10010", "USA", 300} 132 | return c.RenderJson(hotel) 133 | } 134 | 135 | func (c Hotels) Index() myMars.Result { 136 | return c.RenderText("Hello, World!") 137 | } 138 | 139 | func (c Static) Serve(prefix, filepath string) myMars.Result { 140 | var basePath, dirName string 141 | 142 | if !path.IsAbs(dirName) { 143 | basePath = BasePath 144 | } 145 | 146 | fname := path.Join(basePath, prefix, filepath) 147 | file, err := os.Open(fname) 148 | if os.IsNotExist(err) { 149 | return c.NotFound("") 150 | } else if err != nil { 151 | myMars.WARN.Printf("Problem opening file (%s): %s ", fname, err) 152 | return c.NotFound("This was found but not sure why we couldn't open it.") 153 | } 154 | return c.RenderFile(file, "") 155 | } 156 | ` 157 | 158 | func stringSlicesEqual(a, b []string) bool { 159 | type direction struct { 160 | Slice []string 161 | Other []string 162 | } 163 | for _, t := range []direction{{a, b}, {b, a}} { 164 | for idx, v := range t.Slice { 165 | if idx >= len(t.Other) || t.Other[idx] != v { 166 | return false 167 | } 168 | } 169 | } 170 | return true 171 | } 172 | 173 | func (a *MethodArg) Equals(o *MethodArg) bool { 174 | if a == o { 175 | return true 176 | } 177 | if (a == nil && o != nil) || (o == nil && a != nil) { 178 | return false 179 | } 180 | 181 | return a.ImportPath == o.ImportPath && a.Name == o.Name && a.TypeExpr == o.TypeExpr 182 | } 183 | 184 | func (s *MethodSpec) Equals(o *MethodSpec) bool { 185 | if s.Name != o.Name { 186 | return false 187 | } 188 | 189 | type direction struct { 190 | Slice []*MethodArg 191 | Other []*MethodArg 192 | } 193 | for _, t := range []direction{{s.Args, o.Args}, {o.Args, s.Args}} { 194 | for idx, v := range t.Slice { 195 | if idx >= len(t.Other) || !v.Equals(t.Other[idx]) { 196 | return false 197 | } 198 | } 199 | } 200 | return true 201 | } 202 | 203 | func (i *TypeInfo) Equals(o *TypeInfo) bool { 204 | if i.ImportPath != o.ImportPath || i.PackageName != o.PackageName || i.StructName != o.StructName { 
205 | return false 206 | } 207 | 208 | type direction struct { 209 | Slice []*MethodSpec 210 | Other []*MethodSpec 211 | } 212 | for _, t := range []direction{{i.MethodSpecs, o.MethodSpecs}, {o.MethodSpecs, i.MethodSpecs}} { 213 | for idx, v := range t.Slice { 214 | if idx >= len(t.Other) || !v.Equals(t.Other[idx]) { 215 | return false 216 | } 217 | } 218 | } 219 | return true 220 | } 221 | 222 | func TestProcessingSource(t *testing.T) { 223 | fset := token.NewFileSet() 224 | 225 | file, err := parser.ParseFile(fset, "testApplication", testApplication, 0) 226 | if err != nil { 227 | t.Fatal(err) 228 | } 229 | 230 | sourceInfo := ProcessFile(fset, "./test.go", file) 231 | if n := sourceInfo.PackageName; n != "test" { 232 | t.Errorf("wrong package name: %s", n) 233 | } 234 | if v := sourceInfo.InitImportPaths; !stringSlicesEqual(v, []string{}) { 235 | t.Errorf("unexpeced import paths: %+v", v) 236 | } 237 | 238 | if s := sourceInfo.StructSpecs[0]; !s.Equals(&TypeInfo{ 239 | StructName: "Hotel", 240 | ImportPath: "test", 241 | PackageName: "test", 242 | MethodSpecs: []*MethodSpec{}, 243 | }) { 244 | t.Errorf("unexpected struct spec: %+v", s) 245 | } 246 | 247 | if c := sourceInfo.ControllerSpecs()[0]; !c.Equals(&TypeInfo{ 248 | StructName: "Application", 249 | ImportPath: "test", 250 | PackageName: "test", 251 | }) { 252 | t.Errorf("wrong controller spec for Application controller: %+v", c) 253 | } 254 | 255 | if c := sourceInfo.ControllerSpecs()[1]; !c.Equals(&TypeInfo{ 256 | StructName: "Hotels", 257 | ImportPath: "test", 258 | PackageName: "test", 259 | MethodSpecs: []*MethodSpec{ 260 | { 261 | Name: "Show", 262 | Args: []*MethodArg{ 263 | {Name: "id", ImportPath: "", TypeExpr: TypeExpr{"int", "", 0, true}}, 264 | }, 265 | }, 266 | { 267 | Name: "Book", 268 | Args: []*MethodArg{ 269 | {Name: "id", ImportPath: "", TypeExpr: TypeExpr{"int", "", 0, true}}, 270 | }, 271 | }, 272 | { 273 | Name: "Index", 274 | }, 275 | }, 276 | }) { 277 | t.Errorf("wrong controller spec 
for Hotels controller: %+v", c) 278 | } 279 | 280 | if c := sourceInfo.ControllerSpecs()[2]; !c.Equals(&TypeInfo{ 281 | StructName: "Static", 282 | ImportPath: "test", 283 | PackageName: "test", 284 | MethodSpecs: []*MethodSpec{ 285 | { 286 | Name: "Serve", 287 | Args: []*MethodArg{ 288 | {Name: "prefix", ImportPath: "", TypeExpr: TypeExpr{"string", "", 0, true}}, 289 | {Name: "filepath", ImportPath: "", TypeExpr: TypeExpr{"string", "", 0, true}}, 290 | }, 291 | }, 292 | }, 293 | }) { 294 | t.Errorf("wrong controller spec for Static controller: %+v", c) 295 | } 296 | } 297 | 298 | func BenchmarkParsingFile(b *testing.B) { 299 | var fset *token.FileSet 300 | var file *ast.File 301 | 302 | for n := 0; n < b.N; n++ { 303 | fset = token.NewFileSet() 304 | var err error 305 | file, err = parser.ParseFile(fset, "testApplication", testApplication, 0) 306 | if err != nil { 307 | b.Fatal(err) 308 | } 309 | } 310 | 311 | ProcessFile(fset, "./test.go", file) 312 | } 313 | 314 | func BenchmarkProcessingSource(b *testing.B) { 315 | var fset *token.FileSet 316 | var file *ast.File 317 | fset = token.NewFileSet() 318 | var err error 319 | file, err = parser.ParseFile(fset, "testApplication", testApplication, 0) 320 | if err != nil { 321 | b.Fatal(err) 322 | } 323 | 324 | for n := 0; n < b.N; n++ { 325 | ProcessFile(fset, "./test.go", file) 326 | } 327 | } 328 | --------------------------------------------------------------------------------