50 |
⚠️
51 |
Access Denied
52 |
Sorry, you cannot access this resource.
53 |
Please contact the customer support.
54 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/internal/emitter/http/handler.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/sitebatch/waffle-go/handler"
7 | "github.com/sitebatch/waffle-go/handler/response"
8 | )
9 |
// Options configures the behaviour of WrapHandler.
type Options struct {
	// OnBlockFunc, when non-nil, is called once the request handler operation
	// finishes with a blocking error, just before the buffered response is
	// replaced by the block response.
	OnBlockFunc func()
}
13 |
14 | func handle(w http.ResponseWriter, r *http.Request, options Options) (http.ResponseWriter, *http.Request, bool, func()) {
15 | ww, waffleResponseWriter := response.NewWaffleResponseWriter(w)
16 | op, ctx := StartHTTPRequestHandlerOperation(r.Context(), BuildHttpRequestHandlerOperationArg(r))
17 | rr := r.WithContext(ctx)
18 |
19 | blocked := false
20 | afterHandler := func() {
21 | result := &HTTPRequestHandlerOperationResult{}
22 | op.Finish(result)
23 |
24 | if result.BlockErr != nil {
25 | if options.OnBlockFunc != nil {
26 | options.OnBlockFunc()
27 | }
28 |
29 | contentType := waffleResponseWriter.Header().Get("Content-Type")
30 | waffleResponseWriter.Reset()
31 | blocked = true
32 |
33 | response.BlockResponseHandler(contentType).ServeHTTP(waffleResponseWriter, rr)
34 | }
35 |
36 | if err := waffleResponseWriter.Commit(); err != nil {
37 | handler.GetErrorHandler().HandleError(err)
38 | }
39 | }
40 |
41 | if op.IsBlock() {
42 | blocked = true
43 |
44 | contentType := waffleResponseWriter.Header().Get("Content-Type")
45 | waffleResponseWriter.Reset()
46 | response.BlockResponseHandler(contentType).ServeHTTP(waffleResponseWriter, rr)
47 | }
48 |
49 | return ww, rr, blocked, afterHandler
50 | }
51 |
52 | func WrapHandler(handler http.Handler, options Options) http.Handler {
53 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
54 | tw, tr, blocked, afterHandler := handle(w, r, options)
55 | defer afterHandler()
56 |
57 | if blocked {
58 | return
59 | }
60 |
61 | handler.ServeHTTP(tw, tr)
62 | })
63 | }
64 |
--------------------------------------------------------------------------------
/waf/wafcontext/operation.go:
--------------------------------------------------------------------------------
1 | package wafcontext
2 |
// HttpRequest is the HTTP request data that can be attached to a
// WafOperationContext.
type HttpRequest struct {
	URL     string
	Headers map[string][]string
	RawBody []byte
	// Body presumably holds the parsed body parameters (RawBody is the
	// unparsed form) — confirm against the producer.
	Body     map[string][]string
	ClientIP string
}

// WafOperationContext holds context information for a WAF operation: a
// string-to-string metadata map and an optional HTTP request.
type WafOperationContext struct {
	meta        map[string]string
	httpRequest *HttpRequest
}

// WafOperationContextOption mutates a WafOperationContext. Options are applied
// in order by NewWafOperationContext and WithWafOperationContext.
type WafOperationContextOption func(*WafOperationContext)

// WithMeta replaces the metadata map. The map is stored as-is rather than
// copied, so later SetMeta calls write into the caller's map.
func WithMeta(meta map[string]string) WafOperationContextOption {
	return func(oc *WafOperationContext) {
		oc.meta = meta
	}
}

// WithHttpRequstContext attaches req to the context. (The misspelled name is
// part of the public API.) Every context this option is applied to receives
// the same *HttpRequest, namely the address of the option's captured req.
func WithHttpRequstContext(req HttpRequest) WafOperationContextOption {
	return func(oc *WafOperationContext) {
		oc.httpRequest = &req
	}
}

// NewWafOperationContext creates a context with an empty metadata map and then
// applies opts in order.
func NewWafOperationContext(opts ...WafOperationContextOption) *WafOperationContext {
	fresh := &WafOperationContext{meta: make(map[string]string)}
	return fresh.WithWafOperationContext(opts...)
}

// WithWafOperationContext applies the given options to the WafOperationContext
// in order and returns the same (mutated) context.
func (c *WafOperationContext) WithWafOperationContext(opts ...WafOperationContextOption) *WafOperationContext {
	for i := range opts {
		opts[i](c)
	}
	return c
}

// GetMeta returns the metadata map. When no map is set, a fresh empty map is
// returned; it is not stored, so writes to it are not retained.
func (c *WafOperationContext) GetMeta() map[string]string {
	if c.meta != nil {
		return c.meta
	}
	return map[string]string{}
}

// SetMeta stores value under key, creating the metadata map if necessary.
func (c *WafOperationContext) SetMeta(key, value string) {
	if c.meta == nil {
		c.meta = map[string]string{}
	}
	c.meta[key] = value
}

// GetHttpRequest returns the attached request, or a pointer to a new zero
// HttpRequest when none has been attached. It never returns nil.
func (c *WafOperationContext) GetHttpRequest() *HttpRequest {
	if req := c.httpRequest; req != nil {
		return req
	}
	return &HttpRequest{}
}
74 |
--------------------------------------------------------------------------------
/handler/response/handler.go:
--------------------------------------------------------------------------------
1 | package response
2 |
3 | import (
4 | _ "embed"
5 | "net/http"
6 | "strings"
7 | "sync/atomic"
8 | )
9 |
10 | var (
11 | //go:embed templates/blocked.html
12 | defaultBlockResponseTemplateHTMLBytes []byte
13 |
14 | defaultBlockResponseTemplateJSONBytes = []byte(`{"error": "access denied. Sorry, you cannnot access this resource. Please contact the customer support."}`)
15 |
16 | blockResponseTemplateHTML = blockResponseTemplateHTMLValue()
17 | blockResponseTemplateJSON = blockResponseTemplateJSONValue()
18 | )
19 |
20 | func SetBlockResponseTemplateHTML(html []byte) {
21 | blockResponseTemplateHTML.Store(html)
22 | }
23 |
24 | func SetBlockResponseTemplateJSON(json []byte) {
25 | blockResponseTemplateJSON.Store(json)
26 | }
27 |
28 | func GetBlockResponseTemplateHTML() []byte {
29 | return blockResponseTemplateHTML.Load().([]byte)
30 | }
31 |
32 | func GetBlockResponseTemplateJSON() []byte {
33 | return blockResponseTemplateJSON.Load().([]byte)
34 | }
35 |
36 | func BlockResponseHandler(contentType string) http.Handler {
37 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38 | w.WriteHeader(http.StatusForbidden)
39 |
40 | accept := r.Header.Get("Accept")
41 |
42 | if strings.Contains(contentType, "application/json") || strings.Contains(accept, "application/json") {
43 | w.Header().Add("Content-Type", "application/json")
44 | _, _ = w.Write(GetBlockResponseTemplateJSON())
45 | return
46 | }
47 |
48 | w.Header().Add("Content-Type", "text/html")
49 | _, _ = w.Write(GetBlockResponseTemplateHTML())
50 | })
51 | }
52 |
53 | func blockResponseTemplateHTMLValue() *atomic.Value {
54 | v := &atomic.Value{}
55 | v.Store(defaultBlockResponseTemplateHTMLBytes)
56 |
57 | return v
58 | }
59 |
60 | func blockResponseTemplateJSONValue() *atomic.Value {
61 | v := &atomic.Value{}
62 | v.Store(defaultBlockResponseTemplateJSONBytes)
63 |
64 | return v
65 | }
66 |
--------------------------------------------------------------------------------
/internal/inspector/account_takeover/login.go:
--------------------------------------------------------------------------------
1 | package account_takeover
2 |
3 | import (
4 | "fmt"
5 | "sync"
6 |
7 | "github.com/sitebatch/waffle-go/lib/limitter"
8 | "golang.org/x/time/rate"
9 | )
10 |
11 | var loginLimittersPerIPAddress = make(map[string]RateLimitter)
12 | var loginLimittersPerIPAddressLock sync.Mutex
13 |
14 | var loginLimittersPerUserID = make(map[string]RateLimitter)
15 | var loginLimittersPerUserIDLock sync.Mutex
16 |
17 | type RateLimitter interface {
18 | Allow() bool
19 | }
20 |
21 | func IsLimit(ip string, userID string, rate rate.Limit) error {
22 | if IsLimitedByIPAddress(ip, rate) {
23 | return fmt.Errorf("IP address %s is reached limit", ip)
24 | }
25 |
26 | if IsLimitedByUserID(userID, rate) {
27 | return fmt.Errorf("userID %s is reached limited", userID)
28 | }
29 |
30 | return nil
31 | }
32 |
33 | func IsLimitedByIPAddress(ip string, rate rate.Limit) bool {
34 | loginLimittersPerIPAddressLock.Lock()
35 | defer loginLimittersPerIPAddressLock.Unlock()
36 |
37 | if _, ok := loginLimittersPerIPAddress[ip]; !ok {
38 | loginLimittersPerIPAddress[ip] = limitter.NewLimitter(rate, int(rate))
39 | }
40 |
41 | return !loginLimittersPerIPAddress[ip].Allow()
42 | }
43 |
44 | func IsLimitedByUserID(userID string, rate rate.Limit) bool {
45 | loginLimittersPerUserIDLock.Lock()
46 | defer loginLimittersPerUserIDLock.Unlock()
47 |
48 | if _, ok := loginLimittersPerUserID[userID]; !ok {
49 | loginLimittersPerUserID[userID] = limitter.NewLimitter(rate, int(rate))
50 | }
51 |
52 | return !loginLimittersPerUserID[userID].Allow()
53 | }
54 |
55 | func ClearLimitByIP(ip string) {
56 | loginLimittersPerIPAddressLock.Lock()
57 | defer loginLimittersPerIPAddressLock.Unlock()
58 |
59 | delete(loginLimittersPerIPAddress, ip)
60 | }
61 |
62 | func ClearLimitByUserID(userID string) {
63 | loginLimittersPerUserIDLock.Lock()
64 | defer loginLimittersPerUserIDLock.Unlock()
65 |
66 | delete(loginLimittersPerUserID, userID)
67 | }
68 |
--------------------------------------------------------------------------------
/contrib/net/http/client_test.go:
--------------------------------------------------------------------------------
1 | package http_test
2 |
3 | import (
4 | "context"
5 | "io"
6 | stdhttp "net/http"
7 | "testing"
8 |
9 | "github.com/sitebatch/waffle-go"
10 | "github.com/sitebatch/waffle-go/contrib/net/http"
11 | emitterHttp "github.com/sitebatch/waffle-go/internal/emitter/http"
12 | "github.com/sitebatch/waffle-go/waf"
13 | "github.com/stretchr/testify/assert"
14 | "github.com/stretchr/testify/require"
15 | )
16 |
17 | func TestWrapClient(t *testing.T) {
18 | t.Parallel()
19 | waffle.Start()
20 |
21 | testCases := map[string]struct {
22 | ctx context.Context
23 | url string
24 | expectErr bool
25 | }{
26 | "when through http operation and non-attack request": {
27 | ctx: buildHttpOperationCtx(t),
28 | url: "https://example.com",
29 | expectErr: false,
30 | },
31 | "when through http operation and attack request": {
32 | ctx: buildHttpOperationCtx(t),
33 | url: "http://169.254.169.254",
34 | expectErr: true,
35 | },
36 | }
37 |
38 | for name, tt := range testCases {
39 | t.Run(name, func(t *testing.T) {
40 | t.Parallel()
41 |
42 | c := http.WrapClient(stdhttp.DefaultClient)
43 | req, _ := stdhttp.NewRequestWithContext(tt.ctx, "GET", tt.url, nil)
44 |
45 | resp, err := c.Do(req)
46 | if tt.expectErr {
47 | assert.Error(t, err)
48 |
49 | var secErr *waf.SecurityBlockingError
50 | assert.ErrorAs(t, err, &secErr)
51 | return
52 | }
53 |
54 | assert.NoError(t, err)
55 |
56 | defer resp.Body.Close()
57 |
58 | assert.Equal(t, 200, resp.StatusCode)
59 |
60 | b, err := io.ReadAll(resp.Body)
61 | require.NoError(t, err)
62 | assert.Contains(t, string(b), "Example Domain")
63 | })
64 | }
65 | }
66 |
67 | func buildHttpOperationCtx(t *testing.T) context.Context {
68 | t.Helper()
69 |
70 | _, ctx := emitterHttp.StartHTTPRequestHandlerOperation(context.Background(), emitterHttp.HTTPRequestHandlerOperationArg{})
71 | return ctx
72 | }
73 |
--------------------------------------------------------------------------------
/internal/emitter/sql/handler_test.go:
--------------------------------------------------------------------------------
1 | package sql_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/sitebatch/waffle-go"
8 | "github.com/sitebatch/waffle-go/internal/emitter/http"
9 | "github.com/sitebatch/waffle-go/internal/emitter/sql"
10 | "github.com/sitebatch/waffle-go/waf"
11 | "github.com/stretchr/testify/assert"
12 | "github.com/stretchr/testify/require"
13 | )
14 |
15 | func TestProtectSQLOperation(t *testing.T) {
16 | t.Parallel()
17 |
18 | require.NoError(t, waffle.Start())
19 |
20 | testCases := map[string]struct {
21 | ctx context.Context
22 | query string
23 | expectErr bool
24 | }{
25 | "when through http operation and non-attack request": {
26 | ctx: buildHttpOperationCtx(t),
27 | query: "SELECT * FROM users",
28 | expectErr: false,
29 | },
30 | "when through http operation and attack request": {
31 | ctx: buildHttpOperationCtx(t),
32 | query: "SELECT * FROM users WHERE id = '1' OR 1=1--",
33 | expectErr: true,
34 | },
35 | "when not through http operation": {
36 | ctx: context.Background(),
37 | query: "SELECT * FROM users",
38 | expectErr: false,
39 | },
40 | "not through http operation and attack request": {
41 | ctx: context.Background(),
42 | query: "SELECT * FROM users WHERE id = '1' OR 1=1--",
43 | expectErr: true,
44 | },
45 | }
46 |
47 | for name, tt := range testCases {
48 | t.Run(name, func(t *testing.T) {
49 | t.Parallel()
50 |
51 | err := sql.ProtectSQLOperation(tt.ctx, tt.query)
52 | if tt.expectErr {
53 | assert.Error(t, err)
54 |
55 | var secErr *waf.SecurityBlockingError
56 | assert.ErrorAs(t, err, &secErr)
57 | return
58 | }
59 | assert.NoError(t, err)
60 | })
61 | }
62 | }
63 |
64 | func buildHttpOperationCtx(t *testing.T) context.Context {
65 | t.Helper()
66 |
67 | _, ctx := http.StartHTTPRequestHandlerOperation(context.Background(), http.HTTPRequestHandlerOperationArg{})
68 | return ctx
69 | }
70 |
--------------------------------------------------------------------------------
/internal/emitter/os/handler_test.go:
--------------------------------------------------------------------------------
1 | package os_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/sitebatch/waffle-go"
8 | "github.com/sitebatch/waffle-go/internal/emitter/http"
9 | "github.com/sitebatch/waffle-go/internal/emitter/os"
10 | "github.com/sitebatch/waffle-go/waf"
11 | "github.com/stretchr/testify/assert"
12 | "github.com/stretchr/testify/require"
13 | )
14 |
15 | func TestProtectFileOperation(t *testing.T) {
16 | t.Parallel()
17 |
18 | require.NoError(t, waffle.Start())
19 |
20 | testCases := map[string]struct {
21 | ctx context.Context
22 | filePath string
23 | expectErr bool
24 | }{
25 | "when through http operation and non-attack request": {
26 | ctx: buildHttpOperationCtx(t),
27 | filePath: "file.txt",
28 | expectErr: false,
29 | },
30 | "when through http operation and attack request": {
31 | ctx: buildHttpOperationCtx(t),
32 | filePath: "/var/run/secrets/kubernetes.io/serviceaccount/token",
33 | expectErr: true,
34 | },
35 | "when not through http operation": {
36 | ctx: context.Background(),
37 | filePath: "file.txt",
38 | expectErr: false,
39 | },
40 | "not through http operation and attack request": {
41 | ctx: context.Background(),
42 | filePath: "/var/run/secrets/kubernetes.io/serviceaccount/token",
43 | expectErr: true,
44 | },
45 | }
46 |
47 | for name, tt := range testCases {
48 | tt := tt
49 |
50 | t.Run(name, func(t *testing.T) {
51 | t.Parallel()
52 |
53 | err := os.ProtectFileOperation(tt.ctx, tt.filePath)
54 | if tt.expectErr {
55 | assert.Error(t, err)
56 |
57 | var secErr *waf.SecurityBlockingError
58 | assert.ErrorAs(t, err, &secErr)
59 | return
60 | }
61 | assert.NoError(t, err)
62 | })
63 | }
64 | }
65 |
66 | func buildHttpOperationCtx(t *testing.T) context.Context {
67 | t.Helper()
68 |
69 | _, ctx := http.StartHTTPRequestHandlerOperation(context.Background(), http.HTTPRequestHandlerOperationArg{})
70 | return ctx
71 | }
72 |
--------------------------------------------------------------------------------
/contrib/gorm.io/gorm/README.md:
--------------------------------------------------------------------------------
1 | # gorm
2 |
3 | This package provides integration instructions for using [GORM](https://gorm.io/) with Waffle protection. While GORM itself doesn't require wrapping, you can apply Waffle's SQL injection protection by using the Waffle database driver.
4 |
5 | ## Installation
6 |
7 | ```bash
8 | go get github.com/sitebatch/waffle-go/contrib/gorm.io/gorm
9 | ```
10 |
11 | ## Usage
12 |
13 | To apply Waffle protection to GORM applications, use the Waffle database driver with GORM:
14 |
15 | ```go
16 | package main
17 |
18 | import (
19 | "context"
20 | "fmt"
21 | "log"
22 |
23 | "github.com/sitebatch/waffle-go"
24 | waffleSql "github.com/sitebatch/waffle-go/contrib/database/sql"
25 | "github.com/sitebatch/waffle-go/waf"
26 |
27 | "gorm.io/driver/sqlite"
28 | "gorm.io/gorm"
29 | )
30 |
31 | type Product struct {
32 | gorm.Model
33 | Code string
34 | Name string
35 | }
36 |
37 | func main() {
38 | // Register Waffle driver
39 | driverName, err := waffleSql.Register(sqlite.DriverName)
40 | if err != nil {
41 | log.Fatal(err)
42 | }
43 |
44 | // Open database connection using Waffle's driver
45 | sqlDB, err := waffleSql.Open(driverName, "file:test.db?cache=shared&mode=memory")
46 | if err != nil {
47 | log.Fatal(err)
48 | }
49 |
50 | db, err := gorm.Open(sqlite.New(sqlite.Config{Conn: sqlDB}), &gorm.Config{})
51 | if err != nil {
52 | log.Fatal(err)
53 | }
54 |
55 | db.AutoMigrate(&Product{})
56 |
57 | // Start Waffle
58 | if err := waffle.Start(); err != nil {
59 | log.Fatal(err)
60 | }
61 |
62 | ctx := context.Background()
63 | var product Product
64 |
65 | // Execute queries - Waffle will prevent SQL injection
66 | maliciousCode := "D42') OR 1=1--"
67 | query := fmt.Sprintf("code = '%s'", maliciousCode)
68 | result := db.WithContext(ctx).Where(query).First(&product)
69 |
70 | if result.Error != nil {
71 | if waf.IsSecurityBlockingError(result.Error) {
72 | // Handle blocked query
73 | log.Printf("Blocked SQL injection attempt: %v", result.Error)
74 | }
75 | }
76 | }
77 | ```
78 |
--------------------------------------------------------------------------------
/internal/emitter/http/client_handler_test.go:
--------------------------------------------------------------------------------
1 | package http_test
2 |
3 | import (
4 | "context"
5 | "io"
6 | stdhttp "net/http"
7 | "testing"
8 |
9 | "github.com/sitebatch/waffle-go"
10 | "github.com/sitebatch/waffle-go/internal/emitter/http"
11 | "github.com/sitebatch/waffle-go/waf"
12 | "github.com/stretchr/testify/assert"
13 | "github.com/stretchr/testify/require"
14 | )
15 |
16 | func TestWrapClient(t *testing.T) {
17 | t.Parallel()
18 |
19 | waffle.Start()
20 |
21 | testCases := map[string]struct {
22 | ctx context.Context
23 | url string
24 | expectErr bool
25 | }{
26 | "when through http operation and non-attack request": {
27 | ctx: buildHttpOperationCtx(t),
28 | url: "https://example.com",
29 | expectErr: false,
30 | },
31 | "when through http operation and attack request": {
32 | ctx: buildHttpOperationCtx(t),
33 | url: "http://169.254.169.254",
34 | expectErr: true,
35 | },
36 | "when not through http operation": {
37 | ctx: context.Background(),
38 | url: "http://169.254.169.254",
39 | expectErr: true,
40 | },
41 | }
42 |
43 | for name, tt := range testCases {
44 | t.Run(name, func(t *testing.T) {
45 | t.Parallel()
46 |
47 | c := http.WrapClient(stdhttp.DefaultClient)
48 | req, _ := stdhttp.NewRequestWithContext(tt.ctx, "GET", tt.url, nil)
49 |
50 | resp, err := c.Do(req)
51 | if tt.expectErr {
52 | assert.Error(t, err)
53 | var secErr *waf.SecurityBlockingError
54 | assert.ErrorAs(t, err, &secErr)
55 | return
56 | }
57 |
58 | assert.NoError(t, err)
59 |
60 | defer resp.Body.Close()
61 |
62 | assert.Equal(t, 200, resp.StatusCode)
63 |
64 | b, err := io.ReadAll(resp.Body)
65 | require.NoError(t, err)
66 | assert.Contains(t, string(b), "Example Domain")
67 | })
68 | }
69 | }
70 |
71 | func buildHttpOperationCtx(t *testing.T) context.Context {
72 | t.Helper()
73 |
74 | _, ctx := http.StartHTTPRequestHandlerOperation(context.Background(), http.HTTPRequestHandlerOperationArg{})
75 | return ctx
76 | }
77 |
--------------------------------------------------------------------------------
/internal/inspector/sqli/tautology.go:
--------------------------------------------------------------------------------
1 | package sqli
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/xwb1989/sqlparser"
7 | )
8 |
9 | // IsWhereTautologyFull checks if the given SQL query is a tautology.
10 | func IsWhereTautologyFull(sql string) (bool, error) {
11 | stmt, err := sqlparser.Parse(sql)
12 | if err != nil {
13 | return false, fmt.Errorf("failed to parse SQL: %w", err)
14 | }
15 |
16 | selectStmt, ok := stmt.(*sqlparser.Select)
17 | if !ok || selectStmt.Where == nil {
18 | return false, nil
19 | }
20 |
21 | return checkExprForTautology(selectStmt.Where.Expr), nil
22 | }
23 |
24 | func checkExprForTautology(expr sqlparser.Expr) bool {
25 | switch e := expr.(type) {
26 | case *sqlparser.AndExpr:
27 | return checkExprForTautology(e.Left) || checkExprForTautology(e.Right)
28 | case *sqlparser.OrExpr:
29 | return checkExprForTautology(e.Left) || checkExprForTautology(e.Right)
30 | case *sqlparser.ComparisonExpr:
31 | return isComparisonTautology(e)
32 | case *sqlparser.ParenExpr:
33 | return checkExprForTautology(e.Expr)
34 | }
35 | return false
36 | }
37 |
38 | func isComparisonTautology(expr *sqlparser.ComparisonExpr) bool {
39 | // Check for patterns like "1 = 1", "TRUE = TRUE", etc.
40 | if isLiteralOrBoolean(expr.Left) && isLiteralOrBoolean(expr.Right) {
41 | return expr.Operator == "=" || expr.Operator == "!=" || expr.Operator == "<>"
42 | }
43 |
44 | // Check for patterns like "column = column"
45 | if isColumn(expr.Left) && isColumn(expr.Right) {
46 | leftCol := expr.Left.(*sqlparser.ColName)
47 | rightCol := expr.Right.(*sqlparser.ColName)
48 | return leftCol.Name.String() == rightCol.Name.String() && expr.Operator == "="
49 | }
50 |
51 | return false
52 | }
53 |
54 | func isLiteralOrBoolean(expr sqlparser.Expr) bool {
55 | switch e := expr.(type) {
56 | case *sqlparser.SQLVal:
57 | return e.Type == sqlparser.IntVal || e.Type == sqlparser.StrVal
58 | case sqlparser.BoolVal:
59 | return true
60 | }
61 | return false
62 | }
63 |
64 | func isColumn(expr sqlparser.Expr) bool {
65 | _, ok := expr.(*sqlparser.ColName)
66 | return ok
67 | }
68 |
--------------------------------------------------------------------------------
/contrib/99designs/gqlgen/testserver/graph/schema.resolvers.go:
--------------------------------------------------------------------------------
1 | package graph
2 |
3 | // This file will be automatically regenerated based on the schema, any resolver implementations
4 | // will be copied through when generating and any unknown code will be moved to the end.
5 | // Code generated by github.com/99designs/gqlgen version v0.17.62
6 |
7 | import (
8 | "context"
9 | "fmt"
10 |
11 | "github.com/sitebatch/waffle-go/contrib/99designs/gqlgen/testserver/graph/model"
12 | )
13 |
14 | // CreateTodo is the resolver for the createTodo field.
15 | func (r *mutationResolver) CreateTodo(ctx context.Context, input model.NewTodo) (*model.Todo, error) {
16 | return &model.Todo{
17 | Text: input.Text,
18 | ID: "1",
19 | Done: false,
20 | User: &model.User{
21 | ID: input.UserID,
22 | Name: fmt.Sprintf("user%s", input.UserID),
23 | },
24 | }, nil
25 | }
26 |
27 | // Todos is the resolver for the todos field.
28 | func (r *queryResolver) Todos(ctx context.Context) ([]*model.Todo, error) {
29 | return []*model.Todo{
30 | {
31 | Text: "todo1",
32 | ID: "1",
33 | Done: false,
34 | User: &model.User{
35 | ID: "1",
36 | Name: "user1",
37 | },
38 | },
39 | {
40 | Text: "todo2",
41 | ID: "2",
42 | Done: false,
43 | User: &model.User{
44 | ID: "1",
45 | Name: "user1",
46 | },
47 | },
48 | }, nil
49 | }
50 |
51 | // SearchTodo is the resolver for the searchTodo field.
52 | func (r *queryResolver) SearchTodo(ctx context.Context, id string, text string) ([]*model.Todo, error) {
53 | return []*model.Todo{
54 | {
55 | Text: text,
56 | ID: id,
57 | Done: false,
58 | User: &model.User{
59 | ID: "1",
60 | Name: "user1",
61 | },
62 | },
63 | }, nil
64 | }
65 |
66 | // Mutation returns MutationResolver implementation.
67 | func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
68 |
69 | // Query returns QueryResolver implementation.
70 | func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
71 |
72 | type mutationResolver struct{ *Resolver }
73 | type queryResolver struct{ *Resolver }
74 |
--------------------------------------------------------------------------------
/handler/response/response_writer.go:
--------------------------------------------------------------------------------
1 | package response
2 |
3 | import (
4 | "bytes"
5 | "net/http"
6 | )
7 |
// WaffleResponseWriter buffers a handler's status code and body so they can be
// discarded and replaced (for example by a block response) before anything
// reaches the client. Commit sends the buffered status and body to the
// wrapped writer.
//
// Headers are not buffered: no Header method is defined here, so Header() is
// promoted from the wrapped http.ResponseWriter.
type WaffleResponseWriter struct {
	http.ResponseWriter

	// status is the HTTP status code recorded by WriteHeader; 0 means none yet.
	status int

	// buf accumulates everything passed to Write until Commit or Reset.
	buf *bytes.Buffer
}

// Compile-time check that *WaffleResponseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*WaffleResponseWriter)(nil)
20 |
21 | // NewWaffleResponseWriter returns a new WaffleResponseWriter.
22 | // The http.ResponseWriter should be the original value passed to the handler, or have an Unwrap method returning the original http.ResponseWriter.
23 | func NewWaffleResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *WaffleResponseWriter) {
24 | feature := 0
25 | if _, ok := w.(http.CloseNotifier); ok {
26 | feature |= closeNotifier
27 | }
28 | if _, ok := w.(http.Flusher); ok {
29 | feature |= flusher
30 | }
31 | if _, ok := w.(http.Hijacker); ok {
32 | feature |= hijacker
33 | }
34 | if _, ok := w.(http.Pusher); ok {
35 | feature |= pusher
36 | }
37 |
38 | return featurePicker[feature](w)
39 | }
40 |
41 | func (w *WaffleResponseWriter) WriteHeader(status int) {
42 | if w.status != 0 {
43 | return
44 | }
45 |
46 | w.status = status
47 | }
48 |
49 | func (w *WaffleResponseWriter) Write(b []byte) (int, error) {
50 | return w.buf.Write(b)
51 | }
52 |
53 | func (w *WaffleResponseWriter) Status() int {
54 | return w.status
55 | }
56 |
57 | func (w *WaffleResponseWriter) Unwrap() http.ResponseWriter {
58 | return w.ResponseWriter
59 | }
60 |
61 | func (w *WaffleResponseWriter) Reset() {
62 | w.buf.Reset()
63 | w.status = 0
64 | }
65 |
66 | func (w *WaffleResponseWriter) Flush() {
67 | if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
68 | flusher.Flush()
69 | }
70 | }
71 |
72 | func (w *WaffleResponseWriter) Commit() error {
73 | if w.status == 0 {
74 | w.ResponseWriter.WriteHeader(http.StatusOK)
75 | } else {
76 | w.ResponseWriter.WriteHeader(w.status)
77 | }
78 |
79 | if _, err := w.ResponseWriter.Write(w.buf.Bytes()); err != nil {
80 | return err
81 | }
82 |
83 | w.buf.Reset()
84 |
85 | return nil
86 | }
87 |
--------------------------------------------------------------------------------
/internal/emitter/http/parser/json_test.go:
--------------------------------------------------------------------------------
1 | package parser_test
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "net/http/httptest"
7 | "strings"
8 | "testing"
9 |
10 | "github.com/sitebatch/waffle-go/internal/emitter/http/parser"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestJSONParser_Parse(t *testing.T) {
15 | t.Parallel()
16 |
17 | testCases := map[string]struct {
18 | body string
19 | expected map[string][]string
20 | }{
21 | "success": {
22 | body: `{"key": "value"}`,
23 | expected: map[string][]string{
24 | "key": {"value"},
25 | },
26 | },
27 | "nested": {
28 | body: `{"key": {"nested": "value"}}`,
29 | expected: map[string][]string{
30 | "key.nested": {"value"},
31 | },
32 | },
33 | "array": {
34 | body: `{"key": ["value1", "value2"]}`,
35 | expected: map[string][]string{
36 | "key.0": {"value1"},
37 | "key.1": {"value2"},
38 | },
39 | },
40 | "array_nested": {
41 | body: `{"key": [{"nested": "value1"}, {"nested": "value2"}]}`,
42 | expected: map[string][]string{
43 | "key.0.nested": {"value1"},
44 | "key.1.nested": {"value2"},
45 | },
46 | },
47 | "array_nested_array": {
48 | body: `{"key": [{"nested": ["value1", "value2"]}]}`,
49 | expected: map[string][]string{
50 | "key.0.nested.0": {"value1"},
51 | "key.0.nested.1": {"value2"},
52 | },
53 | },
54 | "invalid_json": {
55 | body: `{"key": "value"`,
56 | expected: nil,
57 | },
58 | }
59 |
60 | for name, tc := range testCases {
61 | t.Run(name, func(t *testing.T) {
62 | t.Parallel()
63 |
64 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.body))
65 | req.Header.Set("Content-Type", "application/json")
66 |
67 | got, err := parser.ParseHTTPRequestBody(req)
68 | if tc.expected == nil {
69 | assert.Error(t, err)
70 | return
71 | }
72 |
73 | assert.NoError(t, err)
74 | assert.Equal(t, tc.expected, got)
75 |
76 | // Ensure the body is still readable
77 | b, err := io.ReadAll(req.Body)
78 | assert.NoError(t, err)
79 | assert.Equal(t, tc.body, string(b))
80 | })
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/contrib/net/http/README.md:
--------------------------------------------------------------------------------
1 | # net/http
2 |
3 | This package provides Waffle middleware and an HTTP client wrapper for the Go standard library's [net/http](https://pkg.go.dev/net/http).
4 |
5 | - **HTTP Server Protection**: `WafMiddleware` that automatically analyzes incoming requests for malicious patterns
6 | - **HTTP Client Protection**: `WrapClient` that prevents outbound requests to dangerous destinations (SSRF protection)
7 |
8 | ## Installation
9 |
10 | ```bash
11 | go get github.com/sitebatch/waffle-go/contrib/net/http
12 | ```
13 |
14 | ## Usage
15 |
16 | ### HTTP Server Protection
17 |
18 | ```go
19 | package main
20 |
21 | import (
22 | "net/http"
23 | "github.com/sitebatch/waffle-go"
24 | waffleHttp "github.com/sitebatch/waffle-go/contrib/net/http"
25 | )
26 |
27 | func main() {
28 | mux := http.NewServeMux()
29 |
30 | // Apply Waffle WAF middleware
31 | handler := waffleHttp.WafMiddleware(mux)
32 |
33 | // Start Waffle
34 | if err := waffle.Start(); err != nil {
35 | panic(err)
36 | }
37 |
38 | srv := &http.Server{
39 | Addr: ":8000",
40 | Handler: handler,
41 | }
42 |
43 | srv.ListenAndServe()
44 | }
45 | ```
46 |
47 | ### HTTP Client Protection (SSRF Prevention)
48 |
49 | ```go
50 | package main
51 |
52 | import (
53 | "context"
54 | "fmt"
55 | "net/http"
56 |
57 | "github.com/sitebatch/waffle-go"
58 | waffleHttp "github.com/sitebatch/waffle-go/contrib/net/http"
59 | "github.com/sitebatch/waffle-go/waf"
60 | )
61 |
62 | func main() {
63 | // Start Waffle
64 | if err := waffle.Start(); err != nil {
65 | panic(err)
66 | }
67 |
68 | // Wrap HTTP client for SSRF protection
69 | client := waffleHttp.WrapClient(http.DefaultClient)
70 |
71 | // Protected request - Waffle will prevent SSRF attempts
72 | ctx := context.Background()
73 | req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/", nil)
74 | if err != nil {
75 | panic(err)
76 | }
77 |
78 | resp, err := client.Do(req)
79 | if err != nil {
80 | if waf.IsSecurityBlockingError(err) {
81 | fmt.Printf("Request blocked by Waffle: %v\n", err)
82 | }
83 | return
84 | }
85 | defer resp.Body.Close()
86 | }
87 | ```
88 |
--------------------------------------------------------------------------------
/internal/inspector/types/values.go:
--------------------------------------------------------------------------------
1 | package types
2 |
// GetInspectTargetValueOptions holds the options accepted by
// InspectTargetValue.GetValues.
type GetInspectTargetValueOptions struct {
	// ParamNames restricts GetValues to the values stored under these names.
	// An empty slice means all values.
	ParamNames []string
}

// WithGetInspectTargetValueOptions is a functional option for GetValues.
type WithGetInspectTargetValueOptions func(o *GetInspectTargetValueOptions)

// WithParamNames restricts GetValues to the given parameter names.
func WithParamNames(paramNames []string) WithGetInspectTargetValueOptions {
	return func(o *GetInspectTargetValueOptions) {
		o.ParamNames = paramNames
	}
}

// InspectTargetValue is a value that an inspector can examine.
type InspectTargetValue interface {
	// GetValue returns the value of the target
	GetValue() string
	// GetValues returns the values of the target.
	// If WithParamNames supplies a non-empty list of names, only the values
	// stored under those names are returned; otherwise all values are returned.
	GetValues(opts ...WithGetInspectTargetValueOptions) []string
}

// StringValue is an InspectTargetValue wrapping a single string.
type StringValue struct {
	Value string
}

// NewStringValue wraps value as an InspectTargetValue.
func NewStringValue(value string) InspectTargetValue {
	return &StringValue{
		Value: value,
	}
}

// GetValue returns the wrapped string.
func (v *StringValue) GetValue() string {
	return v.Value
}

// GetValues returns the wrapped string as a one-element slice; opts are ignored.
func (v *StringValue) GetValues(opts ...WithGetInspectTargetValueOptions) []string {
	return []string{v.Value}
}

// KeyValues is a struct that contains key-values (map[string][]string, like http.Header and url.Values and more...) of the target.
type KeyValues struct {
	Values map[string][]string
}

// NewKeyValues wraps values as an InspectTargetValue. The map is not copied.
func NewKeyValues(values map[string][]string) InspectTargetValue {
	return &KeyValues{
		Values: values,
	}
}

// GetValue is not supported for KeyValues and always returns "".
func (v *KeyValues) GetValue() string {
	return ""
}

// GetValues returns the stored values.
//
// Without ParamNames, all values are returned grouped by key. Keys are visited
// in Go map iteration order, which is unspecified, so callers must not rely on
// the ordering.
//
// With ParamNames, values are returned in ParamNames order. Names that are
// absent are skipped, and repeated names yield repeated values.
//
// The result is nil when nothing matches.
func (v *KeyValues) GetValues(opts ...WithGetInspectTargetValueOptions) []string {
	var o GetInspectTargetValueOptions
	for _, opt := range opts {
		opt(&o)
	}

	var values []string

	if len(o.ParamNames) == 0 {
		for _, vals := range v.Values {
			values = append(values, vals...)
		}
		return values
	}

	for _, name := range o.ParamNames {
		// A missing name yields a nil slice, and appending it is a no-op.
		values = append(values, v.Values[name]...)
	}

	return values
}
83 |
--------------------------------------------------------------------------------
/internal/inspector/inspector.go:
--------------------------------------------------------------------------------
1 | package inspector
2 |
// InspectorName identifies an Inspector implementation. It is the key type
// of the map returned by NewInspectors.
type InspectorName string

// Names of the built-in inspectors.
const (
	RegexInspectorName            InspectorName = "regex"
	MatchListInspectorName        InspectorName = "match_list"
	LibInjectionSQLIInspectorName InspectorName = "libinjection_sqli"
	LibInjectionXSSInspectorName  InspectorName = "libinjection_xss"
	SQLiInspectorName             InspectorName = "sqli"
	LFIInspectorName              InspectorName = "lfi"
	SSRFInspectorName             InspectorName = "ssrf"
	AccountTakeoverInspectorName  InspectorName = "account_takeover"
)
15 |
16 | func NewInspectors() map[InspectorName]Inspector {
17 | return map[InspectorName]Inspector{
18 | RegexInspectorName: NewRegexInspector(),
19 | MatchListInspectorName: NewMatchListInspector(),
20 | LibInjectionSQLIInspectorName: NewLibInjectionSQLIInspector(),
21 | LibInjectionXSSInspectorName: NewLibInjectionXSSInspector(),
22 | SQLiInspectorName: NewSQLiInspector(),
23 | LFIInspectorName: NewLFIInspector(),
24 | SSRFInspectorName: NewSSRFInspector(),
25 | AccountTakeoverInspectorName: NewAccountTakeoverInspector(),
26 | }
27 | }
28 |
29 | type InspectorArgs struct {
30 | TargetOptions []InspectTargetOptions
31 |
32 | RegexInspectorArgs RegexInspectorArgs
33 | MatchListInspectorArgs MatchListInspectorArgs
34 | AccountTakeoverInspectorArgs AccountTakeoverInspectorArgs
35 | }
36 |
37 | type InspectTargetOptions struct {
38 | Target InspectTarget
39 | Params []string
40 | }
41 |
42 | type Inspector interface {
43 | // Inspect inspects the given data
44 | // Returns SuspiciousResult if the inspected data is determined to be an attack, otherwise returns nil
45 | // If an error occurs during inspection, returns an error
46 | Inspect(inspectData InspectData, inspectorArgs InspectorArgs) (*InspectResult, error)
47 | // IsSupportTarget returns whether the inspector supports the target
48 | IsSupportTarget(target InspectTarget) bool
49 | }
50 |
51 | // InspectResult represents the result of an inspection
52 | type InspectResult struct {
53 | Target InspectTarget // the target that was inspected
54 | Payload string // the payload deemed suspicious
55 | Message string // message describing why it is suspicious
56 | }
57 |
--------------------------------------------------------------------------------
/internal/inspector/libinjection.go:
--------------------------------------------------------------------------------
1 | package inspector
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/sitebatch/waffle-go/internal/inspector/libinjection"
7 | "github.com/sitebatch/waffle-go/internal/inspector/types"
8 | )
9 |
10 | type LibInjectionSQLIInspector struct{}
11 | type LibInjectionXSSInspector struct{}
12 |
13 | type LibInjectionSQLIInspectorArgs struct {
14 | InspectTargetOptions []InspectTargetOptions
15 | }
16 |
17 | type LibInjectionXSSInspectorArgs struct {
18 | InspectTargetOptions []InspectTargetOptions
19 | }
20 |
21 | func NewLibInjectionSQLIInspector() Inspector {
22 | return &LibInjectionSQLIInspector{}
23 | }
24 |
25 | func NewLibInjectionXSSInspector() Inspector {
26 | return &LibInjectionXSSInspector{}
27 | }
28 |
29 | func (r *LibInjectionSQLIInspector) IsSupportTarget(target InspectTarget) bool {
30 | return true
31 | }
32 |
33 | func (r *LibInjectionXSSInspector) IsSupportTarget(target InspectTarget) bool {
34 | return true
35 | }
36 |
37 | func (r *LibInjectionSQLIInspector) Inspect(inspectData InspectData, args InspectorArgs) (*InspectResult, error) {
38 | for _, opt := range args.TargetOptions {
39 | if _, ok := inspectData.Target[opt.Target]; !ok {
40 | continue
41 | }
42 |
43 | values := inspectData.Target[opt.Target].GetValues(
44 | types.WithParamNames(opt.Params),
45 | )
46 |
47 | for _, value := range values {
48 | err := libinjection.IsSQLiPayload(value)
49 | if err != nil {
50 | return &InspectResult{
51 | Target: opt.Target,
52 | Payload: value,
53 | Message: fmt.Sprintf("detected sqli payload: %s", err),
54 | }, nil
55 | }
56 | }
57 | }
58 |
59 | return nil, nil
60 | }
61 |
62 | func (r *LibInjectionXSSInspector) Inspect(inspectData InspectData, args InspectorArgs) (*InspectResult, error) {
63 | for _, opt := range args.TargetOptions {
64 | if _, ok := inspectData.Target[opt.Target]; !ok {
65 | continue
66 | }
67 |
68 | values := inspectData.Target[opt.Target].GetValues(
69 | types.WithParamNames(opt.Params),
70 | )
71 |
72 | for _, value := range values {
73 | err := libinjection.IsXSSPayload(value)
74 | if err != nil {
75 | return &InspectResult{
76 | Target: opt.Target,
77 | Payload: value,
78 | Message: fmt.Sprintf("detected xss payload: %s", err),
79 | }, nil
80 | }
81 | }
82 | }
83 |
84 | return nil, nil
85 | }
86 |
--------------------------------------------------------------------------------
/internal/emitter/graphql/handler.go:
--------------------------------------------------------------------------------
1 | package graphql
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/jeremywohl/flatten"
7 | "github.com/sitebatch/waffle-go/internal/emitter/http"
8 | "github.com/sitebatch/waffle-go/internal/emitter/waf"
9 | "github.com/sitebatch/waffle-go/internal/operation"
10 | )
11 |
12 | type GraphqlRequestHandlerOperation struct {
13 | operation.Operation
14 | *waf.WafOperation
15 | }
16 |
17 | type GraphqlRequestHandlerOperationArg struct {
18 | RawQuery string
19 | OperationName string
20 | Variables map[string][]string
21 | }
22 |
23 | type GraphqlRequestHandlerOperationResult struct {
24 | BlockErr error
25 | }
26 |
27 | func (GraphqlRequestHandlerOperationArg) IsArgOf(*GraphqlRequestHandlerOperation) {}
28 | func (*GraphqlRequestHandlerOperationResult) IsResultOf(*GraphqlRequestHandlerOperation) {}
29 |
30 | func StartGraphQLRequestHandlerOperation(ctx context.Context, args GraphqlRequestHandlerOperationArg) (*GraphqlRequestHandlerOperation, context.Context) {
31 | parent, _ := operation.FindOperationFromContext(ctx)
32 |
33 | var wafop *waf.WafOperation
34 | if parentOp, ok := parent.(*http.HTTPRequestHandlerOperation); ok {
35 | wafop = parentOp.WafOperation
36 | } else {
37 | wafop, _ = waf.InitializeWafOperation(ctx)
38 | }
39 |
40 | op := &GraphqlRequestHandlerOperation{
41 | Operation: operation.NewOperation(parent),
42 | WafOperation: wafop,
43 | }
44 |
45 | return op, operation.StartAndSetOperation(ctx, op, args)
46 | }
47 |
48 | func (op *GraphqlRequestHandlerOperation) Finish(res *GraphqlRequestHandlerOperationResult) {
49 | operation.FinishOperation(op, res)
50 | }
51 |
52 | func BuildGraphqlRequestHandlerOperationArg(
53 | rawQuery string,
54 | operationName string,
55 | variables map[string]interface{},
56 | ) GraphqlRequestHandlerOperationArg {
57 | var graphqlVariables map[string][]string
58 |
59 | flat, err := flatten.Flatten(variables, "", flatten.DotStyle)
60 | if err != nil {
61 | graphqlVariables = map[string][]string{}
62 | } else {
63 | graphqlVariables = make(map[string][]string)
64 | for k, v := range flat {
65 | if s, ok := v.(string); ok {
66 | graphqlVariables[k] = []string{s}
67 | }
68 | }
69 | }
70 |
71 | return GraphqlRequestHandlerOperationArg{
72 | RawQuery: rawQuery,
73 | OperationName: operationName,
74 | Variables: graphqlVariables,
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/internal/emitter/http/parser/form_test.go:
--------------------------------------------------------------------------------
1 | package parser_test
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "strings"
7 | "testing"
8 |
9 | "github.com/sitebatch/waffle-go/internal/emitter/http/parser"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func TestFormParser_Parse(t *testing.T) {
14 | t.Parallel()
15 |
16 | testCases := map[string]struct {
17 | body string
18 | expected map[string][]string
19 | }{
20 | "simple": {
21 | body: "key=value",
22 | expected: map[string][]string{
23 | "key": {"value"},
24 | },
25 | },
26 | "multiple": {
27 | body: "key1=value1&key2=value2",
28 | expected: map[string][]string{
29 | "key1": {"value1"},
30 | "key2": {"value2"},
31 | },
32 | },
33 | "array": {
34 | body: "key=value1&key=value2",
35 | expected: map[string][]string{
36 | "key": {"value1", "value2"},
37 | },
38 | },
39 | }
40 |
41 | for name, tc := range testCases {
42 | t.Run(name, func(t *testing.T) {
43 | t.Parallel()
44 |
45 | r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.body))
46 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
47 |
48 | result, err := parser.ParseHTTPRequestBody(r)
49 | assert.NoError(t, err)
50 | assert.Equal(t, tc.expected, result)
51 | })
52 | }
53 | }
54 |
55 | func TestMultipartParser_Parse(t *testing.T) {
56 | t.Parallel()
57 |
58 | testCases := map[string]struct {
59 | body string
60 | expected map[string][]string
61 | }{
62 | "success": {
63 | body: `--boundary
64 | Content-Disposition: form-data; name="key"
65 |
66 | value
67 | --boundary--`,
68 | expected: map[string][]string{
69 | "key": {"value"},
70 | },
71 | },
72 | "multiple": {
73 | body: `--boundary
74 | Content-Disposition: form-data; name="key1"
75 |
76 | value1
77 | --boundary
78 | Content-Disposition: form-data; name="key2"
79 |
80 | value2
81 | --boundary--`,
82 | expected: map[string][]string{
83 | "key1": {"value1"},
84 | "key2": {"value2"},
85 | },
86 | },
87 | }
88 |
89 | for name, tc := range testCases {
90 | t.Run(name, func(t *testing.T) {
91 | t.Parallel()
92 |
93 | r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.body))
94 | r.Header.Set("Content-Type", "multipart/form-data; boundary=boundary")
95 |
96 | result, err := parser.ParseHTTPRequestBody(r)
97 | assert.NoError(t, err)
98 | assert.Equal(t, tc.expected, result)
99 | })
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/handler/response/response_writer_test.go:
--------------------------------------------------------------------------------
1 | package response_test
2 |
3 | import (
4 | "bufio"
5 | "net"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 |
10 | "github.com/sitebatch/waffle-go/handler/response"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestWaffleResponseWriter_Unwrap(t *testing.T) {
15 | response.InitResponseWriterFeature()
16 |
17 | testWriter := httptest.NewRecorder()
18 | _, waffleWriter := response.NewWaffleResponseWriter(testWriter)
19 | assert.Same(t, testWriter, waffleWriter.Unwrap())
20 | }
21 |
22 | func TestResponseWriter_WriteHeader(t *testing.T) {
23 | response.InitResponseWriterFeature()
24 |
25 | testWriter := httptest.NewRecorder()
26 | writer, waffleWriter := response.NewWaffleResponseWriter(testWriter)
27 |
28 | writer.WriteHeader(200)
29 | waffleWriter.WriteHeader(200)
30 |
31 | assert.Equal(t, 200, testWriter.Code)
32 | assert.Equal(t, 200, waffleWriter.Status())
33 |
34 | writer.WriteHeader(400)
35 | waffleWriter.WriteHeader(400)
36 | assert.Equal(t, 200, testWriter.Code)
37 | assert.Equal(t, 200, waffleWriter.Status())
38 |
39 | waffleWriter.Reset()
40 | waffleWriter.WriteHeader(400)
41 | assert.Equal(t, 400, waffleWriter.Status())
42 | }
43 |
44 | func TestResponseWriter_Write(t *testing.T) {
45 | response.InitResponseWriterFeature()
46 |
47 | testWriter := httptest.NewRecorder()
48 | writer, waffleWriter := response.NewWaffleResponseWriter(testWriter)
49 |
50 | _, _ = writer.Write([]byte("Hello, World!"))
51 | assert.Equal(t, "", testWriter.Body.String())
52 | assert.NoError(t, waffleWriter.Commit())
53 | assert.Equal(t, "Hello, World!", testWriter.Body.String())
54 |
55 | _, _ = writer.Write([]byte("Goodbye, World!"))
56 | assert.NoError(t, waffleWriter.Commit())
57 | assert.Equal(t, "Hello, World!Goodbye, World!", testWriter.Body.String())
58 | }
59 |
60 | func TestResponseWriter_Hijack(t *testing.T) {
61 | response.InitResponseWriterFeature()
62 |
63 | testWriter := httptest.NewRecorder()
64 | writer, _ := response.NewWaffleResponseWriter(testWriter)
65 | assert.Panics(t, func() {
66 | writer.(http.Hijacker).Hijack()
67 | })
68 |
69 | hijacker := &mockHijackerResponseWriter{ResponseWriter: testWriter}
70 | writer, _ = response.NewWaffleResponseWriter(hijacker)
71 |
72 | conn, _, err := writer.(http.Hijacker).Hijack()
73 | assert.Nil(t, conn)
74 | assert.Nil(t, err)
75 | }
76 |
// mockHijackerResponseWriter is an http.ResponseWriter that also implements
// http.Hijacker. Its Hijack reports no connection and no error.
type mockHijackerResponseWriter struct {
	http.ResponseWriter
}

// Hijack satisfies http.Hijacker and returns all-nil values.
func (m *mockHijackerResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	var (
		conn net.Conn
		rw   *bufio.ReadWriter
	)
	return conn, rw, nil
}
84 |
--------------------------------------------------------------------------------
/handler/response/handler_test.go:
--------------------------------------------------------------------------------
1 | package response_test
2 |
3 | import (
4 | "net/http/httptest"
5 | "testing"
6 |
7 | "github.com/sitebatch/waffle-go/handler/response"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestBlockResponseHandler(t *testing.T) {
12 | testCases := map[string]struct {
13 | contentType string
14 | acceptHeader string
15 | expected string
16 | }{
17 | "return JSON response, when Accept header is set to application/json": {
18 | contentType: "text/html",
19 | acceptHeader: "application/json",
20 | expected: "{\"error\": \"access denied.",
21 | },
22 | "return JSON response, when Content-Type header is set to application/json": {
23 | contentType: "application/json",
24 | acceptHeader: "*/*",
25 | expected: "{\"error\": \"access denied.",
26 | },
27 | "return HTML response, when Accept header is set to text/html": {
28 | contentType: "",
29 | acceptHeader: "text/html",
30 | expected: "