├── cmd └── shouter │ └── main.go ├── conn.go ├── server_test.go └── server.go /cmd/shouter/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/sahilm/shouter" 7 | ) 8 | 9 | func main() { 10 | srv := shouter.Server{ 11 | Addr: ":8080", 12 | IdleTimeout: 10 * time.Second, 13 | MaxReadBytes: 1000, 14 | } 15 | go srv.ListenAndServe() 16 | time.Sleep(10 * time.Second) 17 | //srv.Shutdown() 18 | select {} 19 | } 20 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package shouter 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "time" 7 | ) 8 | 9 | type conn struct { 10 | net.Conn 11 | 12 | IdleTimeout time.Duration 13 | MaxReadBuffer int64 14 | } 15 | 16 | func (c *conn) Write(p []byte) (n int, err error) { 17 | c.updateDeadline() 18 | n, err = c.Conn.Write(p) 19 | return 20 | } 21 | 22 | func (c *conn) Read(b []byte) (n int, err error) { 23 | c.updateDeadline() 24 | r := io.LimitReader(c.Conn, c.MaxReadBuffer) 25 | n, err = r.Read(b) 26 | return 27 | } 28 | 29 | func (c *conn) Close() (err error) { 30 | err = c.Conn.Close() 31 | return 32 | } 33 | 34 | func (c *conn) updateDeadline() { 35 | idleDeadline := time.Now().Add(c.IdleTimeout) 36 | c.Conn.SetDeadline(idleDeadline) 37 | } 38 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package shouter_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "bufio" 8 | "net" 9 | 10 | "github.com/sahilm/shouter" 11 | ) 12 | 13 | // These are manual tests 14 | 15 | func TestServerProtectsAgaintSlowloris(t *testing.T) { 16 | srv := shouter.Server{ 17 | Addr: ":8080", 18 | IdleTimeout: 5 * time.Second, 19 | MaxReadBytes: 1000, 20 | } 21 | go srv.ListenAndServe() 22 | 23 | time.Sleep(1 * time.Second) // hack to wait for server to start 24 | conn, err := net.Dial("tcp", "0.0.0.0:8080") 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | // We slowly write to simulate a Slowloris. We should fail in one second 29 | // because we don't satisfy the application level requirement of sending a complete request (with newlines) 30 | // within 1 second 31 | for { 32 | w := bufio.NewWriter(conn) 33 | w.WriteString(".") 34 | time.Sleep(200 * time.Millisecond) 35 | w.Flush() 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package shouter 2 | 3 | import ( 4 | "bufio" 5 | "log" 6 | "net" 7 | "strings" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | type Server struct { 13 | Addr string 14 | IdleTimeout time.Duration 15 | MaxReadBytes int64 16 | 17 | listener net.Listener 18 | conns map[*conn]struct{} 19 | mu sync.Mutex 20 | inShutdown bool 21 | } 22 | 23 | func (srv *Server) ListenAndServe() error { 24 | addr := srv.Addr 25 | if addr == "" { 26 | addr = ":8080" 27 | } 28 | log.Printf("starting server on %v\n", addr) 29 | listener, err := net.Listen("tcp", addr) 30 | if err != nil { 31 | return err 32 | } 33 | defer listener.Close() 34 | srv.listener = listener 35 | for { 36 | // should be guarded by mu 37 | if srv.inShutdown { 38 | break 39 | } 40 | newConn, err := listener.Accept() 41 | if err != nil { 42 | log.Printf("error accepting connection %v", err) 43 | continue 44 | } 45 | log.Printf("accepted connection from %v", newConn.RemoteAddr()) 46 | conn := &conn{ 47 | Conn: newConn, 48 | IdleTimeout: srv.IdleTimeout, 49 | MaxReadBuffer: srv.MaxReadBytes, 50 | } 51 | srv.trackConn(conn) 52 | conn.SetDeadline(time.Now().Add(conn.IdleTimeout)) 53 | go srv.handle(conn) 54 | } 55 | return nil 56 | } 57 | 58 | func (srv *Server) trackConn(c *conn) { 59 | defer srv.mu.Unlock() 60 | srv.mu.Lock() 61 | if srv.conns == nil { 62 | srv.conns = make(map[*conn]struct{}) 63 | } 64 | srv.conns[c] = struct{}{} 65 | } 66 | 67 | func (srv *Server) handle(conn *conn) error { 68 | defer func() { 69 | log.Printf("closing connection from %v", conn.RemoteAddr()) 70 | conn.Close() 71 | srv.deleteConn(conn) 72 | }() 73 | r := bufio.NewReader(conn) 74 | w := bufio.NewWriter(conn) 75 | scanr := bufio.NewScanner(r) 76 | 77 | sc := make(chan bool) 78 | deadline := time.After(conn.IdleTimeout) 79 | for { 80 | go func(s chan bool) { 81 | s <- scanr.Scan() 82 | }(sc) 83 | select { 84 | case <-deadline: 85 | return nil 86 | case scanned := <-sc: 87 | if !scanned { 88 | if err := scanr.Err(); err != nil { 89 | return err 90 | } 91 | return nil 92 | } 93 | w.WriteString(strings.ToUpper(scanr.Text()) + "\n") 94 | w.Flush() 95 | deadline = time.After(conn.IdleTimeout) 96 | } 97 | } 98 | return nil 99 | } 100 | 101 | func (srv *Server) deleteConn(conn *conn) { 102 | defer srv.mu.Unlock() 103 | srv.mu.Lock() 104 | delete(srv.conns, conn) 105 | } 106 | 107 | func (srv *Server) Shutdown() { 108 | // should be guarded by mu 109 | srv.inShutdown = true 110 | log.Println("shutting down...") 111 | srv.listener.Close() 112 | ticker := time.NewTicker(500 * time.Millisecond) 113 | defer ticker.Stop() 114 | for { 115 | select { 116 | case <-ticker.C: 117 | log.Printf("waiting on %v connections", len(srv.conns)) 118 | } 119 | if len(srv.conns) == 0 { 120 | return 121 | } 122 | } 123 | } 124 | --------------------------------------------------------------------------------