├── README.md ├── client ├── client.go ├── client_test.go └── example │ └── client.go ├── glide.yaml ├── httputil ├── httputil.go └── httputil_test.go ├── pubsub ├── cmds.go ├── cmds_test.go ├── history.go ├── history_test.go ├── topic.go └── topic_test.go ├── server └── server.go └── transport └── transport.go /README.md: -------------------------------------------------------------------------------- 1 | # Drift: An HTTP/2 Pub/Sub service 2 | 3 | [![Stability: Experimental](https://masterminds.github.io/stability/experimental.svg)](https://masterminds.github.io/stability/experimental.html) 4 | 5 | Drift is a topic-based PubSub service based on HTTP/2. It uses HTTP/2's 6 | ability to stream data as a simple mechanism for managing subscriptions. 7 | 8 | In a nutshell, one or more _publishers_ send data to _topics_. One or 9 | more _subscribers_ can listen on that topic. Every time a publisher 10 | sends a message, all of the subscribers will receive it. 11 | 12 | Features: 13 | 14 | - Uses the HTTP/2 standard with no additions. 15 | - JSON? Thrift? ProtoBuf? Use whatever. 16 | - Service metadata is confined to HTTP headers. The payload is all 17 | yours. 18 | - Configurable history lets clients quickly catch up on what they missed. 19 | - Extensible architecture makes it easy for you to add your own flair. 20 | - And more in the works... 21 | 22 | The current implementation streams Data Frames. Once the Go libraries 23 | mature, we may instead opt to use full pushes (though the overhead for 24 | that may be higher than we want). 25 | 26 | **This library is not stable. The interfaces may change before the 0.1 27 | release.** 28 | 29 | **Currently, the library ONLY supports HTTPS.** 30 | 31 | ## Installation 32 | 33 | ``` 34 | $ brew install glide 35 | $ git clone $THIS_REPO 36 | $ glide init 37 | ``` 38 | 39 | From there, you can build the server (`go build server/server.go`) or 40 | the example client (`go build client/example/client.go`).
41 | 42 | ## Simple Client Example 43 | 44 | To use Drift as a client, import the client library: 45 | 46 | ```go 47 | import "github.com/technosophos/drift/client" 48 | ``` 49 | 50 | Here is a simple publisher: 51 | 52 | ```go 53 | c := client.New("https://localhost:5500") 54 | c.Publish("example", []byte("Hello World")) 55 | ``` 56 | 57 | The above sends the "Hello World" message over the `example` topic. 58 | 59 | A subscriber looks like this: 60 | 61 | ```go 62 | s := client.New("https://localhost:5500") 63 | subscription, err := s.Subscribe("example") 64 | if err != nil { 65 | fmt.Printf("Failed subscription: %s", err) 66 | return 67 | } 68 | 69 | // Now listen on the subscription's channel. 70 | for msg := range subscription.C { 71 | fmt.Printf("Received: %s\n", msg) 72 | } 73 | 74 | // When you're done... 75 | subscription.Cancel() 76 | ``` 77 | 78 | A more advanced API is provided for configuring history and adding 79 | arbitrary HTTP headers. 80 | 81 | ## About the Server 82 | 83 | The server lives in `server/server.go`. The basic server provides 84 | convenient features for getting running quickly. 85 | 86 | But the server was also designed as a composable system. You can easily 87 | take the parts here and add your own. Take a look at the registry in 88 | `server/server.go` to see how this is done. 89 | 90 | ## API 91 | 92 | `GET /` 93 | 94 | Prints the runtime API documentation. 95 | 96 | `DELETE /v1/t/TOPIC` 97 | 98 | Destroy a topic named `TOPIC`. 99 | 100 | This will destroy the history and cancel subscriptions for all 101 | subscribed clients. 102 | 103 | `GET /v1/t/TOPIC` 104 | 105 | Subscribe to a topic named `TOPIC`. The client is expected to hold open 106 | a connection for the duration of its subscription. 107 | 108 | This method **does not support HTTP/1 at all!** You must use HTTP/2. 109 | 110 | 111 | `POST /v1/t/TOPIC` 112 | 113 | Post a new message into the topic named `TOPIC`.
114 | 115 | The body of the post message is pushed wholesale into the queue. 116 | 117 | This method accepts HTTP/1.1 POST content in addition to HTTP/2 POST. 118 | Only one data frame of HTTP/2 POST data is accepted. Streamed POST is 119 | currently not supported (though it will be). 120 | 121 | `PUT /v1/t/TOPIC` 122 | 123 | Create a new topic named `TOPIC`. 124 | 125 | The body of this message is a well-defined JSON data structure that 126 | describes the topic. 127 | 128 | `GET /v1/time` 129 | 130 | Get the current time. This returns a plain text value with nothing but a 131 | timestamp. 132 | 133 | ``` 134 | $ curl -k https://localhost:5500/v1/time 135 | 1436464998 136 | ``` 137 | 138 | The purpose of this callback is to give client libraries a timestamp to 139 | use as the base time for calculating dates. This can reduce problems 140 | with clock skew. 141 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | // Package client provides a client library for Drift. 2 | package client 3 | 4 | import ( 5 | "bytes" 6 | "crypto/tls" 7 | "errors" 8 | "fmt" 9 | "net/http" 10 | "path" 11 | "time" 12 | 13 | "github.com/technosophos/drift/transport" 14 | ) 15 | 16 | const v1Path = "/v1/t/" 17 | 18 | // Client provides consumer functions for Drift. 19 | // 20 | // Client contains the simple methods for working with subscriptions 21 | // and publishing. The Subscriber and Publisher objects can be used for 22 | // more detailed work. 23 | type Client struct { 24 | Url string 25 | // InsecureTLSDial skips verifying the server cert against authorities (honored by Create, Delete, and Exists only). 26 | InsecureTLSDial bool 27 | } 28 | 29 | // New creates and initializes a new client. 30 | func New(url string) *Client { 31 | return &Client{ 32 | Url: url, 33 | } 34 | } 35 | 36 | // Create creates a new topic on the pubsub server.
37 | func (c *Client) Create(topic string) error { 38 | url := c.Url + path.Join(v1Path, topic) 39 | _, err := c.basicRoundTrip("PUT", url) 40 | return err 41 | } 42 | 43 | // Delete removes an existing topic from the pubsub server. 44 | func (c *Client) Delete(topic string) error { 45 | url := c.Url + path.Join(v1Path, topic) 46 | _, err := c.basicRoundTrip("DELETE", url) 47 | return err 48 | } 49 | 50 | // Checks whether the server already has the topic. 51 | func (c *Client) Exists(topic string) bool { 52 | url := c.Url + path.Join(v1Path, topic) 53 | _, err := c.basicRoundTrip("HEAD", url) 54 | return err == nil 55 | } 56 | 57 | func (c *Client) Publish(topic string, msg []byte) error { 58 | p := NewPublisher(c.Url) 59 | _, err := p.Publish(topic, msg) 60 | return err 61 | } 62 | 63 | func (c *Client) Subscribe(topic string) (*Subscription, error) { 64 | s := NewSubscriber(c.Url) 65 | s.History.Len = 100 66 | return s.Subscribe(topic) 67 | } 68 | 69 | func (c *Client) basicRoundTrip(verb, url string) (*http.Response, error) { 70 | t := &transport.Transport{InsecureTLSDial: c.InsecureTLSDial} 71 | 72 | req, err := http.NewRequest(verb, url, nil) 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | return t.RoundTrip(req) 78 | } 79 | 80 | // Publisher is responsible for publishing messages to the service. 81 | type Publisher struct { 82 | Url string 83 | Header http.Header 84 | } 85 | 86 | // NewPublisher creates a new Publisher. 87 | func NewPublisher(url string) *Publisher { 88 | return &Publisher{ 89 | Url: url, 90 | Header: map[string][]string{}, 91 | } 92 | } 93 | 94 | // Publish sends the service a message for a particular topic. 
95 | func (p *Publisher) Publish(topic string, message []byte) (*http.Response, error) { 96 | 97 | if len(message) == 0 { 98 | return nil, errors.New("Cannot send an empty message") 99 | } 100 | if len(topic) == 0 { 101 | return nil, errors.New("Cannot publish to an empty topic.") 102 | } 103 | /* HTTP2 does not currently send the body! So we have to go to HTTP1 104 | t := &transport.Transport{InsecureTLSDial: true} 105 | */ 106 | t := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} 107 | 108 | url := p.Url + path.Join(v1Path, topic) 109 | 110 | var body bytes.Buffer 111 | body.Write(message) 112 | 113 | req, _ := http.NewRequest("POST", url, &body) 114 | 115 | return t.RoundTrip(req) 116 | } 117 | 118 | // Subscription represents an existing subscription that a subscriber 119 | // has subscribed to. 120 | type Subscription struct { 121 | C chan []byte 122 | listener transport.Listener 123 | } 124 | 125 | func (s *Subscription) Cancel() { 126 | // Signal the transport that the clientStream should be removed. 127 | s.listener.Cancel() 128 | } 129 | 130 | // Subscriber defines a client that subscribes to a topic on a PubSub. 131 | type Subscriber struct { 132 | Url string 133 | History History 134 | Header http.Header 135 | } 136 | 137 | func NewSubscriber(url string) *Subscriber { 138 | return &Subscriber{ 139 | Url: url, 140 | Header: map[string][]string{}, 141 | } 142 | } 143 | 144 | // History describes how much history a subscriber should ask for. 145 | // 146 | // Be default, Subscribers do not ask for any history. 
147 | type History struct { 148 | Since time.Time 149 | Len int 150 | } 151 | 152 | func (s *Subscriber) Subscribe(topic string) (*Subscription, error) { 153 | if len(topic) == 0 { 154 | return nil, errors.New("Cannot subscribe to an empty channel.") 155 | } 156 | 157 | url := s.Url + path.Join(v1Path, topic) 158 | fmt.Printf("URL: %s\n", url) 159 | 160 | t := &transport.Transport{InsecureTLSDial: true} 161 | 162 | req, err := http.NewRequest("GET", url, nil) 163 | if err != nil { 164 | return nil, err 165 | } 166 | 167 | s.setHeaders(req) 168 | 169 | _, listener, err := t.Listen(req) 170 | if err != nil { 171 | return nil, err 172 | } 173 | 174 | stream, err := listener.Stream() 175 | if err != nil { 176 | return nil, err 177 | } 178 | 179 | return &Subscription{C: stream, listener: listener}, nil 180 | } 181 | 182 | func (s *Subscriber) setHeaders(req *http.Request) { 183 | req.Header = s.Header 184 | if s.History.Len > 0 { 185 | req.Header.Add("X-History-Length", fmt.Sprintf("%d", s.History.Len)) 186 | } 187 | if s.History.Since.After(time.Unix(0, 0)) { 188 | req.Header.Add("X-History-Since", fmt.Sprintf("%d", s.History.Since.Unix())) 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /client/client_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | "time" 7 | 8 | "github.com/Masterminds/cookoo" 9 | "github.com/Masterminds/cookoo/web" 10 | "github.com/bradfitz/http2" 11 | "github.com/technosophos/drift/httputil" 12 | "github.com/technosophos/drift/pubsub" 13 | ) 14 | 15 | var hostport = "127.0.0.1:5500" 16 | var baseurl = "https://127.0.0.1:5500" 17 | var topicname = "test.topic" 18 | 19 | func TestClient(t *testing.T) { 20 | // Lots of timing to simulate networkiness. Because that makes the 21 | // test nondeterministic, we use fairly large times. 
22 | 23 | go standUpServer() 24 | time.Sleep(2 * time.Second) 25 | 26 | cli := New(baseurl) 27 | 28 | go func() { 29 | if err := cli.Publish(topicname, []byte("test")); err != nil { 30 | t.Fatal(err) 31 | } 32 | 33 | time.Sleep(50 * time.Millisecond) 34 | 35 | cli.Publish(topicname, []byte("Again")) 36 | }() 37 | 38 | println("Subscribing") 39 | si, err := cli.Subscribe(topicname) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | go func() { 45 | time.Sleep(500 * time.Millisecond) 46 | si.Cancel() 47 | }() 48 | 49 | if first := <-si.C; string(first) != "test" { 50 | t.Errorf("expected test, got %s", first) 51 | } 52 | if second := <-si.C; string(second) != "Again" { 53 | t.Errorf("expected Again, got %s", second) 54 | } 55 | 56 | time.Sleep(1 * time.Second) 57 | cli.Delete(topicname) 58 | } 59 | 60 | func standUpServer() error { 61 | srv := &http.Server{Addr: hostport} 62 | 63 | reg, router, cxt := cookoo.Cookoo() 64 | 65 | buildRegistry(reg, router, cxt) 66 | 67 | // Our main datasource is the Medium, which manages channels. 
68 | m := pubsub.NewMedium() 69 | cxt.AddDatasource(pubsub.MediumDS, m) 70 | cxt.Put("routes", reg.Routes()) 71 | 72 | http2.ConfigureServer(srv, &http2.Server{}) 73 | 74 | srv.Handler = web.NewCookooHandler(reg, router, cxt) 75 | 76 | srv.ListenAndServeTLS("../server/server.crt", "../server/server.key") 77 | return nil 78 | } 79 | 80 | func buildRegistry(reg *cookoo.Registry, router *cookoo.Router, cxt cookoo.Context) { 81 | 82 | reg.AddRoute(cookoo.Route{ 83 | Name: "PUT /v1/t/*", 84 | Help: "Create a new topic.", 85 | Does: cookoo.Tasks{ 86 | cookoo.Cmd{ 87 | Name: "topic", 88 | Fn: pubsub.CreateTopic, 89 | Using: []cookoo.Param{ 90 | {Name: "topic", From: "path:2"}, 91 | }, 92 | }, 93 | }, 94 | }) 95 | 96 | reg.AddRoute(cookoo.Route{ 97 | Name: "POST /v1/t/*", 98 | Help: "Publish a message to a channel.", 99 | Does: cookoo.Tasks{ 100 | cookoo.Cmd{ 101 | Name: "postBody", 102 | Fn: httputil.BufferPost, 103 | }, 104 | cookoo.Cmd{ 105 | Name: "publish", 106 | Fn: pubsub.Publish, 107 | Using: []cookoo.Param{ 108 | {Name: "message", From: "cxt:postBody"}, 109 | {Name: "topic", From: "path:2"}, 110 | }, 111 | }, 112 | }, 113 | }) 114 | 115 | reg.AddRoute(cookoo.Route{ 116 | Name: "GET /v1/t/*", 117 | Help: "Subscribe to a topic.", 118 | Does: cookoo.Tasks{ 119 | cookoo.Cmd{ 120 | Name: "history", 121 | Fn: pubsub.ReplayHistory, 122 | Using: []cookoo.Param{ 123 | {Name: "topic", From: "path:2"}, 124 | }, 125 | }, 126 | cookoo.Cmd{ 127 | Name: "subscribe", 128 | Fn: pubsub.Subscribe, 129 | Using: []cookoo.Param{ 130 | {Name: "topic", From: "path:2"}, 131 | }, 132 | }, 133 | }, 134 | }) 135 | 136 | reg.AddRoute(cookoo.Route{ 137 | Name: "HEAD /v1/t/*", 138 | Help: "Check whether a topic exists.", 139 | Does: cookoo.Tasks{ 140 | cookoo.Cmd{ 141 | Name: "has", 142 | Fn: pubsub.TopicExists, 143 | Using: []cookoo.Param{ 144 | {Name: "topic", From: "path:2"}, 145 | }, 146 | }, 147 | }, 148 | }) 149 | 150 | reg.AddRoute(cookoo.Route{ 151 | Name: "DELETE /v1/t/*", 152 | Help: 
"Delete a topic and close all subscriptions to the topic.", 153 | Does: cookoo.Tasks{ 154 | cookoo.Cmd{ 155 | Name: "delete", 156 | Fn: pubsub.DeleteTopic, 157 | Using: []cookoo.Param{ 158 | {Name: "topic", From: "path:2"}, 159 | }, 160 | }, 161 | }, 162 | }) 163 | } 164 | -------------------------------------------------------------------------------- /client/example/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | //"io/ioutil" 7 | "os" 8 | //"net" 9 | "bytes" 10 | "crypto/tls" 11 | "net/http" 12 | "time" 13 | 14 | "github.com/bradfitz/http2" 15 | "github.com/bradfitz/http2/hpack" 16 | "github.com/technosophos/drift/client" 17 | "github.com/technosophos/drift/transport" 18 | ) 19 | 20 | func main() { 21 | 22 | cmd := os.Args[1] 23 | 24 | /* 25 | //t := &http2.Transport{InsecureTLSDial: true} 26 | t := &transport.Transport{InsecureTLSDial: true} 27 | 28 | req, _ := http.NewRequest("GET", "https://localhost:5500/", nil) 29 | 30 | //res, err := t.RoundTrip(req) 31 | res, stream, err := t.Listen(req) 32 | if err != nil { 33 | fmt.Printf("Failed: %s\n", err) 34 | } 35 | data, _ := ioutil.ReadAll(res.Body) 36 | fmt.Printf("Response: %s %d, %q\n", res.Proto, res.StatusCode, data) 37 | 38 | fmt.Println("Waiting for the next message.") 39 | moredata := <-stream 40 | fmt.Printf("Final data: %s", moredata) 41 | 42 | // Next, hit the ticker 43 | getTicker() 44 | */ 45 | switch cmd { 46 | case "publish": 47 | //publish() 48 | p := client.NewPublisher("https://localhost:5500") 49 | p.Publish("example", []byte("Hello World")) 50 | time.Sleep(200 * time.Millisecond) 51 | p.Publish("example", []byte("Hello again")) 52 | case "subscribe": 53 | //subscribe() 54 | s := client.NewSubscriber("https://localhost:5500") 55 | s.History = &client.History{Len: 5} 56 | stream, err := s.Subscribe("example") 57 | if err != nil { 58 | fmt.Printf("Failed subscription: %s", err) 59 | return 60 | } 
61 | for msg := range stream { 62 | fmt.Printf("Received: %s\n", msg) 63 | } 64 | default: 65 | fmt.Printf("Unknown command: %s\n", cmd) 66 | } 67 | } 68 | 69 | func publish() { 70 | 71 | /* HTTP2 does not currently send the body! So we have to go to HTTP1 72 | t := &transport.Transport{InsecureTLSDial: true} 73 | */ 74 | t := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} 75 | var body bytes.Buffer 76 | 77 | body.Write([]byte("test")) 78 | 79 | req, _ := http.NewRequest("POST", "https://localhost:5500/v1/t/TEST", &body) 80 | 81 | res, err := t.RoundTrip(req) 82 | if err != nil { 83 | fmt.Printf("Error during round trip: %s", err) 84 | } 85 | 86 | fmt.Printf("Status: %s", res.Status) 87 | 88 | } 89 | 90 | func subscribe() { 91 | t := &transport.Transport{InsecureTLSDial: true} 92 | 93 | req, _ := http.NewRequest("GET", "https://localhost:5500/v1/t/TEST", nil) 94 | 95 | //res, err := t.RoundTrip(req) 96 | res, stream, err := t.Listen(req) 97 | if err != nil { 98 | fmt.Printf("Failed: %s\n", err) 99 | } 100 | fmt.Printf("Headers: %s\n", res.Status) 101 | //data, _ := ioutil.ReadAll(res.Body) 102 | //fmt.Printf("Response: %s %d, %q\n", res.Proto, res.StatusCode, data) 103 | 104 | fmt.Println("Waiting for the next message.") 105 | for data := range stream { 106 | fmt.Printf("\nReceived: %s\n", data) 107 | } 108 | } 109 | 110 | func getTicker() *http.Response { 111 | t := &transport.Transport{InsecureTLSDial: true} 112 | 113 | req, _ := http.NewRequest("GET", "https://localhost:5500/tick", nil) 114 | 115 | //res, err := t.RoundTrip(req) 116 | res, stream, err := t.Listen(req) 117 | if err != nil { 118 | fmt.Printf("Failed: %s\n", err) 119 | } 120 | 121 | fmt.Printf("Status: %s\n", res.Status) 122 | fmt.Printf("Body: %s\n", res.Body) 123 | 124 | for data := range stream { 125 | fmt.Printf("\nReceived: %s\n", data) 126 | } 127 | 128 | //io.Copy(os.Stdout, res.Body) 129 | return res 130 | } 131 | 132 | func custom() { 133 | dest := "localhost" 134 
| port := ":5500" 135 | // Create a new client. 136 | 137 | tlscfg := &tls.Config{ 138 | ServerName: dest, 139 | NextProtos: []string{http2.NextProtoTLS}, 140 | InsecureSkipVerify: true, 141 | } 142 | 143 | conn, err := tls.Dial("tcp", dest+port, tlscfg) 144 | if err != nil { 145 | panic(err) 146 | } 147 | defer conn.Close() 148 | 149 | conn.Handshake() 150 | state := conn.ConnectionState() 151 | fmt.Printf("Protocol is : %q\n", state.NegotiatedProtocol) 152 | 153 | if _, err := io.WriteString(conn, http2.ClientPreface); err != nil { 154 | fmt.Printf("Preface failed: %s", err) 155 | return 156 | } 157 | 158 | var hbuf bytes.Buffer 159 | framer := http2.NewFramer(conn, conn) 160 | 161 | enc := hpack.NewEncoder(&hbuf) 162 | writeHeader(enc, ":authority", "localhost") 163 | writeHeader(enc, ":method", "GET") 164 | writeHeader(enc, ":path", "/ping") 165 | writeHeader(enc, ":scheme", "https") 166 | writeHeader(enc, "Accept", "*/*") 167 | 168 | if len(hbuf.Bytes()) > 16<<10 { 169 | fmt.Printf("Need CONTINUATION\n") 170 | } 171 | 172 | headers := http2.HeadersFrameParam{ 173 | StreamID: 1, 174 | EndStream: true, 175 | EndHeaders: true, 176 | BlockFragment: hbuf.Bytes(), 177 | } 178 | 179 | fmt.Printf("All the stuff: %q\n", headers.BlockFragment) 180 | 181 | go listen(framer) 182 | 183 | framer.WriteSettings() 184 | framer.WriteWindowUpdate(0, 1<<30) 185 | framer.WriteSettingsAck() 186 | 187 | time.Sleep(time.Second * 2) 188 | framer.WriteHeaders(headers) 189 | time.Sleep(time.Second * 2) 190 | 191 | /* A ping HTTP request 192 | var payload [8]byte 193 | copy(payload[:], "_c0ffee_") 194 | framer.WritePing(false, payload) 195 | rawpong, err := framer.ReadFrame() 196 | if err != nil { 197 | panic(err) 198 | } 199 | 200 | pong, ok := rawpong.(*http2.PingFrame) 201 | if !ok { 202 | fmt.Printf("Instead of a Ping, I got this: %v\n", pong) 203 | return 204 | } 205 | 206 | fmt.Printf("Pong: %q\n", pong.Data) 207 | */ 208 | 209 | } 210 | 211 | func listen(framer *http2.Framer) { 
212 | for { 213 | response, err := framer.ReadFrame() 214 | if err != nil { 215 | if err == io.EOF { 216 | return 217 | } 218 | fmt.Printf("Error: Got %q\n", err) 219 | } 220 | switch t := response.(type) { 221 | case *http2.SettingsFrame: 222 | t.ForeachSetting(func(s http2.Setting) error { 223 | fmt.Printf("Setting: %q\n", s) 224 | return nil 225 | }) 226 | case *http2.GoAwayFrame: 227 | fmt.Printf("Go Away code = %q, stream ID = %d\n", t.ErrCode, t.StreamID) 228 | } 229 | //data := response.(*http2.DataFrame) 230 | fmt.Printf("Got %q\n", response) 231 | } 232 | } 233 | 234 | func writeHeader(enc *hpack.Encoder, name, value string) { 235 | enc.WriteField(hpack.HeaderField{Name: name, Value: value}) 236 | } 237 | -------------------------------------------------------------------------------- /glide.yaml: -------------------------------------------------------------------------------- 1 | package: github.com/technosophos/drift 2 | import: 3 | - package: github.com/bradfitz/http2 4 | - package: github.com/Masterminds/cookoo 5 | -------------------------------------------------------------------------------- /httputil/httputil.go: -------------------------------------------------------------------------------- 1 | package httputil 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/Masterminds/cookoo" 11 | ) 12 | 13 | // BufferPost buffers the body of the POST request into the context. 14 | // 15 | // Params: 16 | // 17 | // Returns: 18 | // - []byte with the content of the request. 19 | func BufferPost(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 20 | req := c.Get("http.Request", nil).(*http.Request) 21 | var b bytes.Buffer 22 | _, err := io.Copy(&b, req.Body) 23 | c.Logf("info", "Received POST: %s", b.Bytes()) 24 | return b.Bytes(), err 25 | } 26 | 27 | // Timestamp returns a UNIX timestamp. 28 | // 29 | // Params: 30 | // 31 | // Returns: 32 | // - int64 timestamp as seconds since epoch. 
33 | // 34 | func Timestamp(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 35 | return fmt.Sprintf("%d", time.Now().Unix()), nil 36 | } 37 | 38 | // Debug displays debugging info. 39 | func Debug(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 40 | w := c.Get("http.ResponseWriter", nil).(http.ResponseWriter) 41 | r := c.Get("http.Request", nil).(*http.Request) 42 | reqInfoHandler(w, r) 43 | return nil, nil 44 | } 45 | func reqInfoHandler(w http.ResponseWriter, r *http.Request) { 46 | w.Header().Set("Content-Type", "text/plain") 47 | fmt.Fprintf(w, "Method: %s\n", r.Method) 48 | fmt.Fprintf(w, "Protocol: %s\n", r.Proto) 49 | fmt.Fprintf(w, "Host: %s\n", r.Host) 50 | fmt.Fprintf(w, "RemoteAddr: %s\n", r.RemoteAddr) 51 | fmt.Fprintf(w, "RequestURI: %q\n", r.RequestURI) 52 | fmt.Fprintf(w, "URL: %#v\n", r.URL) 53 | fmt.Fprintf(w, "Body.ContentLength: %d (-1 means unknown)\n", r.ContentLength) 54 | fmt.Fprintf(w, "Close: %v (relevant for HTTP/1 only)\n", r.Close) 55 | fmt.Fprintf(w, "TLS: %#v\n", r.TLS) 56 | fmt.Fprintf(w, "\nHeaders:\n") 57 | r.Header.Write(w) 58 | } 59 | -------------------------------------------------------------------------------- /httputil/httputil_test.go: -------------------------------------------------------------------------------- 1 | package httputil 2 | 3 | import ( 4 | "github.com/Masterminds/cookoo" 5 | "strconv" 6 | "testing" 7 | ) 8 | 9 | func TestTimestamp(t *testing.T) { 10 | reg, router, cxt := cookoo.Cookoo() 11 | 12 | reg.Route("test", "Test route"). 
13 | Does(Timestamp, "res") 14 | 15 | err := router.HandleRequest("test", cxt, true) 16 | if err != nil { 17 | t.Error(err) 18 | } 19 | 20 | ts := cxt.Get("res", "").(string) 21 | 22 | if len(ts) == 0 { 23 | t.Errorf("Expected timestamp, not empty string.") 24 | } 25 | 26 | tsInt, err := strconv.Atoi(ts) 27 | if err != nil { 28 | t.Error(err) 29 | } 30 | 31 | if tsInt <= 5 { 32 | t.Error("Dude, you're stuck in the '70s.") 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /pubsub/cmds.go: -------------------------------------------------------------------------------- 1 | /* Package pubsub provides publish/subscribe operations for HTTP/2. 2 | 3 | */ 4 | package pubsub 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "net/http" 10 | "strconv" 11 | "time" 12 | 13 | "github.com/Masterminds/cookoo" 14 | ) 15 | 16 | const MediumDS = "drift.Medium" 17 | 18 | const ( 19 | // XHistorySince is an HTTP header for the client to send a request for history since TIMESTAMP. 20 | XHistorySince = "x-history-since" 21 | // XHistoryLength is an HTTP header for the client to send a request for the last N records. 22 | XHistoryLength = "x-history-length" 23 | // XHistoryEnabled is a flag for the server to notify the client whether history is enabled. 24 | XHistoryEnabled = "x-history-enabled" 25 | ) 26 | 27 | // Publish sends a new message to a topic. 28 | // 29 | // Params: 30 | // - topic (string): The topic to send to. 31 | // - message ([]byte): The message to send. 32 | // - withHistory (bool): Turn on history. Default is true. This only takes 33 | // effect when the channel is created. 34 | // 35 | // Datasources: 36 | // - This uses the 'drift.Medium' datasource. 
37 | // 38 | // Returns: 39 | // 40 | func Publish(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 41 | hist := p.Get("withHistory", true).(bool) 42 | topic := p.Get("topic", "").(string) 43 | if len(topic) == 0 { 44 | return nil, errors.New("No topic supplied.") 45 | } 46 | 47 | medium, _ := getMedium(c) 48 | 49 | // Is there any reason to disallow empty messages? 50 | msg := p.Get("message", []byte{}).([]byte) 51 | c.Logf("info", "Msg: %s", msg) 52 | 53 | t := fetchOrCreateTopic(medium, topic, hist, DefaultMaxHistory) 54 | return nil, t.Publish(msg) 55 | 56 | } 57 | 58 | // Subscribe allows an request to subscribe to topic updates. 59 | // 60 | // Params: 61 | // - topic (string): The topic to subscribe to. 62 | // - 63 | // 64 | // Returns: 65 | // 66 | func Subscribe(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 67 | medium, err := getMedium(c) 68 | if err != nil { 69 | return nil, &cookoo.FatalError{"No medium."} 70 | } 71 | topic := p.Get("topic", "").(string) 72 | if len(topic) == 0 { 73 | return nil, errors.New("No topic is set.") 74 | } 75 | 76 | rw := c.Get("http.ResponseWriter", nil).(ResponseWriterFlusher) 77 | clientGone := rw.(http.CloseNotifier).CloseNotify() 78 | 79 | sub := NewSubscription(rw) 80 | t := fetchOrCreateTopic(medium, topic, true, DefaultMaxHistory) 81 | t.Subscribe(sub) 82 | 83 | defer func() { 84 | t.Unsubscribe(sub) 85 | sub.Close() 86 | }() 87 | 88 | sub.Listen(clientGone) 89 | 90 | return nil, nil 91 | } 92 | 93 | // CreateTopic creates a new topic. 94 | // 95 | // Params: 96 | // - topic (string) 97 | // - history (bool): whether or not to track history 98 | // - historyLength (int): How much history to track. Default is DefaultMaxHistory. 99 | // 100 | // Returns: 101 | // Topic the new topic. 
102 | func CreateTopic(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 103 | name := p.Get("topic", "").(string) 104 | if len(name) == 0 { 105 | return nil, &cookoo.FatalError{"Topic name required."} 106 | } 107 | 108 | hist := p.Get("history", true).(bool) 109 | histLen := p.Get("historyLength", DefaultMaxHistory).(int) 110 | 111 | m, err := getMedium(c) 112 | if err != nil { 113 | return nil, &cookoo.FatalError{"No medium."} 114 | } 115 | 116 | t := fetchOrCreateTopic(m, name, hist, histLen) 117 | 118 | return t, nil 119 | 120 | } 121 | 122 | // DeleteTopic deletes a topic and its history. 123 | // 124 | // Params: 125 | // - name (string) 126 | // 127 | // Returns: 128 | // 129 | func DeleteTopic(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 130 | name := p.Get("topic", "").(string) 131 | if len(name) == 0 { 132 | return nil, &cookoo.FatalError{"Topic name required."} 133 | } 134 | 135 | m, err := getMedium(c) 136 | if err != nil { 137 | return nil, &cookoo.FatalError{"No medium."} 138 | } 139 | 140 | err = m.Delete(name) 141 | if err != nil { 142 | c.Logf("warn", "Failed to delete topic: %s", err) 143 | } 144 | 145 | return nil, nil 146 | } 147 | 148 | // TopicExists tests whether a topic exists, and sends an HTTP 200 if yes, 404 if no. 149 | // 150 | // Params: 151 | // - topic (string): The topic to look up. 
152 | // Returns: 153 | // 154 | func TopicExists(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 155 | res := c.Get("http.ResponseWriter", nil).(ResponseWriterFlusher) 156 | name := p.Get("topic", "").(string) 157 | if len(name) == 0 { 158 | res.WriteHeader(404) 159 | return nil, nil 160 | } 161 | 162 | medium, err := getMedium(c) 163 | if err != nil { 164 | res.WriteHeader(404) 165 | return nil, nil 166 | } 167 | 168 | if _, ok := medium.Topic(name); ok { 169 | res.WriteHeader(200) 170 | return nil, nil 171 | } 172 | res.WriteHeader(404) 173 | return nil, nil 174 | } 175 | 176 | // ReplayHistory sends back the history to a subscriber. 177 | // 178 | // This should be called before the client goes into active listening. 179 | // 180 | // Params: 181 | // - topic (string): The topic to fetch. 182 | // 183 | // Returns: 184 | // - int: The number of history messages sent to the client. 185 | func ReplayHistory(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 186 | req := c.Get("http.Request", nil).(*http.Request) 187 | res := c.Get("http.ResponseWriter", nil).(ResponseWriterFlusher) 188 | medium, _ := getMedium(c) 189 | name := p.Get("topic", "").(string) 190 | 191 | // This does not manage topics. If there is no topic set, we silently fail. 192 | if len(name) == 0 { 193 | c.Log("info", "No topic name given to ReplayHistory.") 194 | return 0, nil 195 | } 196 | top, ok := medium.Topic(name) 197 | if !ok { 198 | c.Logf("info", "No topic named %s exists yet. No history replayed.", name) 199 | return 0, nil 200 | } 201 | 202 | topic, ok := top.(HistoriedTopic) 203 | if !ok { 204 | c.Logf("info", "No history for topic %s.", name) 205 | res.Header().Add(XHistoryEnabled, "False") 206 | return 0, nil 207 | } 208 | res.Header().Add(XHistoryEnabled, "True") 209 | 210 | since := req.Header.Get(XHistorySince) 211 | max := req.Header.Get(XHistoryLength) 212 | 213 | // maxLen can be used either on its own or paired with X-History-Since. 
214 | maxLen := 0 215 | if len(max) > 0 { 216 | m, err := parseHistLen(max) 217 | if err != nil { 218 | c.Logf("info", "failed to parse X-History-Length %s", max) 219 | } else { 220 | maxLen = m 221 | } 222 | } 223 | if len(since) > 0 { 224 | ts, err := parseSince(since) 225 | if err != nil { 226 | c.Logf("warn", "Failed to parse X-History-Since field %s: %s", since, err) 227 | return 0, nil 228 | } 229 | toSend := topic.Since(ts) 230 | 231 | // If maxLen is also set, we trim the list by sending the newest. 232 | ls := len(toSend) 233 | if maxLen > 0 && ls > maxLen { 234 | offset := ls - maxLen - 1 235 | toSend = toSend[offset:] 236 | } 237 | return sendHistory(c, res, toSend) 238 | } else if maxLen > 0 { 239 | toSend := topic.Last(maxLen) 240 | return sendHistory(c, res, toSend) 241 | } 242 | 243 | return 0, nil 244 | } 245 | 246 | // sendHistory sends the accumulated history to the writer. 247 | func sendHistory(c cookoo.Context, writer ResponseWriterFlusher, data [][]byte) (int, error) { 248 | c.Logf("info", "Sending history.") 249 | var i int 250 | var d []byte 251 | for i, d = range data { 252 | _, err := writer.Write(d) 253 | if err != nil { 254 | c.Logf("warn", "Failed to write history message: %s", err) 255 | return i + 1, nil 256 | } 257 | writer.Flush() 258 | } 259 | return i + 1, nil 260 | } 261 | 262 | // parseSince parses the X-History-Since value. 263 | func parseSince(s string) (time.Time, error) { 264 | tint, err := strconv.ParseInt(s, 0, 64) 265 | if err != nil { 266 | return time.Unix(0, 0), fmt.Errorf("Could not parse as time: %s", s) 267 | } 268 | return time.Unix(tint, 0), nil 269 | } 270 | 271 | // parseHistLen parses the X-History-Length value. 272 | func parseHistLen(s string) (int, error) { 273 | return strconv.Atoi(s) 274 | } 275 | 276 | // fetchOrCreateTopic gets a topic if it exists, and creates one if it doesn't. 
277 | func fetchOrCreateTopic(m *Medium, name string, hist bool, l int) Topic { 278 | t, ok := m.Topic(name) 279 | if !ok { 280 | t = NewTopic(name) 281 | if hist && l > 0 { 282 | t = TrackHistory(t, l) 283 | } 284 | m.Add(t) 285 | } 286 | return t 287 | } 288 | -------------------------------------------------------------------------------- /pubsub/cmds_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | "testing" 7 | 8 | "github.com/Masterminds/cookoo" 9 | ) 10 | 11 | func TestReplayHistory(t *testing.T) { 12 | reg, router, cxt := cookoo.Cookoo() 13 | cxt.AddLogger("out", os.Stdout) 14 | 15 | medium := NewMedium() 16 | cxt.AddDatasource(MediumDS, medium) 17 | 18 | topic := NewHistoriedTopic("test", 5) 19 | medium.Add(topic) 20 | 21 | topic.Publish([]byte("first")) 22 | topic.Publish([]byte("second")) 23 | 24 | req, _ := http.NewRequest("GET", "https://localhost/v1/t/test", nil) 25 | req.Header.Add(XHistoryLength, "4") 26 | res := &mockResponseWriter{} 27 | 28 | cxt.Put("http.Request", req) 29 | cxt.Put("http.ResponseWriter", res) 30 | 31 | reg.Route("test", "Test route"). 32 | Does(ReplayHistory, "res").Using("topic").WithDefault("test") 33 | 34 | err := router.HandleRequest("test", cxt, true) 35 | if err != nil { 36 | t.Error(err) 37 | } 38 | 39 | last := res.String() 40 | if last != "firstsecond" { 41 | t.Errorf("Expected 'firstsecond', got '%s'", last) 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /pubsub/history.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "container/list" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | var DefaultMaxHistory = 1000 10 | 11 | // historyTopic maintains the history for a channel. 
12 | type historyTopic struct { 13 | Topic 14 | buffer *list.List 15 | max int 16 | mx sync.Mutex 17 | } 18 | 19 | type entry struct { 20 | msg []byte 21 | ts time.Time 22 | } 23 | 24 | // TrackHistory takes an existing topic and adds history tracking. 25 | // 26 | // The mechanism for history tracking is a doubly linked list no longer than 27 | // maxLen. 28 | func TrackHistory(t Topic, maxLen int) HistoriedTopic { 29 | return &historyTopic{ 30 | Topic: t, 31 | buffer: list.New(), 32 | max: maxLen, 33 | } 34 | } 35 | 36 | // Since fetches an array of history entries. 37 | // 38 | // The entries will be in order, oldest to newest. And the list will not 39 | // exceed the maximum number of histry items. 40 | // 41 | // If the history list grows beyond its max size, the history list is pruned, 42 | // oldest to youngest. 43 | func (h *historyTopic) Since(t time.Time) [][]byte { 44 | 45 | accumulator := [][]byte{} 46 | 47 | for v := h.buffer.Front(); v != nil; v = v.Next() { 48 | e, ok := v.Value.(*entry) 49 | if !ok { 50 | // Skip anything that's not an entry. 51 | continue 52 | } 53 | if e.ts.After(t) { 54 | accumulator = append(accumulator, e.msg) 55 | } else { 56 | return accumulator 57 | } 58 | } 59 | return accumulator 60 | } 61 | 62 | // Last fetches the last n items from the history, regardless of their time. 63 | // 64 | // Of course, it will return fewer than n if n is larger than the max length 65 | // or if the total stored history is less than n. 66 | func (h *historyTopic) Last(n int) [][]byte { 67 | acc := make([][]byte, 0, n) 68 | i := 0 69 | for v := h.buffer.Front(); v != nil; v = v.Next() { 70 | e, ok := v.Value.(*entry) 71 | if !ok { 72 | // Skip anything that's not an entry. 
73 | continue 74 | } 75 | if i < n { 76 | acc = append(acc, e.msg) 77 | } else { 78 | return acc 79 | } 80 | i++ 81 | } 82 | return acc 83 | } 84 | 85 | func (h *historyTopic) add(msg []byte) { 86 | h.mx.Lock() 87 | defer h.mx.Unlock() 88 | e := &entry{ 89 | msg: msg, 90 | ts: time.Now(), 91 | } 92 | 93 | h.buffer.PushBack(e) 94 | 95 | for h.buffer.Len() > h.max { 96 | h.buffer.Remove(h.buffer.Front()) 97 | } 98 | } 99 | 100 | // Publish stores this msg as history and then forwards the publish request to the Topic. 101 | func (h *historyTopic) Publish(msg []byte) error { 102 | h.add(msg) 103 | h.Topic.Publish(msg) 104 | return nil 105 | } 106 | 107 | func (h *historyTopic) Close() error { 108 | err := h.Topic.Close() 109 | // We don't want nil pointers during shutdown. 110 | h.buffer = list.New() 111 | return err 112 | } 113 | -------------------------------------------------------------------------------- /pubsub/history_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestHistory(t *testing.T) { 10 | topic := NewHistoriedTopic("test", 5) 11 | 12 | for _, s := range []string{"a", "b", "c", "d", "e", "f"} { 13 | topic.Publish([]byte(s)) 14 | } 15 | 16 | short := topic.Last(1) 17 | if len(short) != 1 { 18 | t.Errorf("Expected 1 in list, got %d", len(short)) 19 | } 20 | if string(short[0]) != "b" { 21 | t.Errorf("Expected 'b', got '%s'", short[0]) 22 | } 23 | 24 | long := topic.Last(6) 25 | if len(long) != 5 { 26 | t.Errorf("Expected 5 in list, got %d", len(long)) 27 | } 28 | 29 | str := string(bytes.Join(long, []byte(""))) 30 | if str != "bcdef" { 31 | t.Errorf("Expected bcdef, got %s", str) 32 | } 33 | } 34 | 35 | func TestHistorySince(t *testing.T) { 36 | topic := NewHistoriedTopic("test", 5) 37 | 38 | now := time.Now() 39 | 40 | for _, s := range []string{"a", "b", "c", "d", "e", "f"} { 41 | topic.Publish([]byte(s)) 42 | // Current 
resolution on timer is at seconds. 43 | time.Sleep(time.Second) 44 | } 45 | 46 | since := topic.Since(now) 47 | 48 | str := string(bytes.Join(since, []byte(""))) 49 | if str != "bcdef" { 50 | t.Errorf("Expected bcdef, got %s", str) 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /pubsub/topic.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math" 7 | "net/http" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/Masterminds/cookoo" 13 | ) 14 | 15 | // ResponseWriterFlusher handles both HTTP response writing and flushing. 16 | // 17 | // We use this simply to declare which interfaces we require support for. 18 | type ResponseWriterFlusher interface { 19 | http.ResponseWriter 20 | http.Flusher 21 | http.CloseNotifier 22 | } 23 | 24 | // Topic is the main channel for sending messages to subscribers. 25 | // 26 | // A publisher is anything that sends a message to a Topic. All 27 | // attached subscribers will receive that message. 28 | type Topic interface { 29 | // Publish sends a message to all subscribers. 30 | Publish([]byte) error 31 | // Subscribe attaches a subscription to this topic. 32 | Subscribe(*Subscription) 33 | // Unsubscribe detaches a subscription from the topic. 34 | Unsubscribe(*Subscription) 35 | // Name returns the topic name. 36 | Name() string 37 | // Subscribers returns a list of subscriptions attached to this topic. 38 | Subscribers() []*Subscription 39 | // Close and destroy the topic. 40 | Close() error 41 | } 42 | 43 | // History provides access too the last N messages on a particular Topic. 44 | type History interface { 45 | // Last provides access to up to N messages. 46 | Last(int) [][]byte 47 | // Since provides access to all messages in history since the given time. 48 | Since(time.Time) [][]byte 49 | } 50 | 51 | // HistoriedTopic is a topic that has an attached history. 
52 | type HistoriedTopic interface { 53 | History 54 | Topic 55 | } 56 | 57 | // NewTopic creates a new Topic with no history capabilities. 58 | func NewTopic(name string) Topic { 59 | ct := &channeledTopic{ 60 | name: name, 61 | subscribers: make(map[uint64]*Subscription, 512), // Sane default space? 62 | } 63 | return ct 64 | } 65 | 66 | // NewHistoriedTopic creates a new HistoriedTopic. 67 | // 68 | // This topic will retain `length` history items for the topic. 69 | func NewHistoriedTopic(name string, length int) HistoriedTopic { 70 | return TrackHistory(NewTopic(name), length) 71 | } 72 | 73 | type channeledTopic struct { 74 | name string 75 | subscribers map[uint64]*Subscription 76 | mx sync.RWMutex 77 | closed bool 78 | } 79 | 80 | func (t *channeledTopic) Close() error { 81 | t.mx.Lock() 82 | t.closed = true 83 | for _, s := range t.subscribers { 84 | s.Close() 85 | } 86 | t.subscribers = map[uint64]*Subscription{} 87 | t.mx.Unlock() 88 | return nil 89 | } 90 | 91 | func (t *channeledTopic) Publish(msg []byte) error { 92 | if t.closed { 93 | return errors.New("Topic is being deleted.") 94 | } 95 | t.mx.Lock() 96 | defer func() { 97 | t.mx.Unlock() 98 | if err := recover(); err != nil { 99 | fmt.Printf("Recovered from failed publish. Some messages probably didn't get through. %s\n", err) 100 | } 101 | }() 102 | 103 | for _, s := range t.subscribers { 104 | if s.Queue == nil { 105 | fmt.Printf("Channel appears to be closed. 
Skipping.\n") 106 | continue 107 | } 108 | //fmt.Printf("Sending msg to subscriber %d: %s\n", s.Id, msg) 109 | s.Queue <- msg 110 | } 111 | //fmt.Printf("Message sent.\n") 112 | return nil 113 | } 114 | 115 | func (t *channeledTopic) Subscribe(s *Subscription) { 116 | if t.closed { 117 | return 118 | } 119 | t.mx.Lock() 120 | defer t.mx.Unlock() 121 | //t.subscribers = append(t.subscribers, s) 122 | if _, ok := t.subscribers[s.Id]; ok { 123 | fmt.Printf("Surprisingly got the same ID as an existing subscriber.") 124 | } 125 | t.subscribers[s.Id] = s 126 | //fmt.Printf("There are now %d subscribers", len(t.subscribers)) 127 | } 128 | 129 | func (t *channeledTopic) Unsubscribe(s *Subscription) { 130 | if t.closed { 131 | return 132 | } 133 | t.mx.Lock() 134 | defer t.mx.Unlock() 135 | delete(t.subscribers, s.Id) 136 | } 137 | 138 | func (t *channeledTopic) Name() string { 139 | return t.name 140 | } 141 | 142 | func (t *channeledTopic) Subscribers() []*Subscription { 143 | t.mx.RLock() 144 | defer t.mx.RUnlock() 145 | c := len(t.subscribers) 146 | s := make([]*Subscription, 0, c) 147 | for _, v := range t.subscribers { 148 | s = append(s, v) 149 | } 150 | return s 151 | } 152 | 153 | // Subscription describes a subscriber. 154 | // 155 | // A subscription attaches to ONLY ONE Topic. 156 | type Subscription struct { 157 | Id uint64 158 | Writer ResponseWriterFlusher 159 | Queue chan []byte 160 | } 161 | 162 | // NewSubscription creates a new subscription. 163 | // 164 | // 165 | func NewSubscription(r ResponseWriterFlusher) *Subscription { 166 | // Queue depth should be revisited. 167 | q := make(chan []byte, 10) 168 | return &Subscription{ 169 | Writer: r, 170 | Queue: q, 171 | Id: newSubId(), 172 | } 173 | } 174 | 175 | // Listen copies messages fromt the Queue into the Writer. 176 | // 177 | // It listens on the Queue unless the `stop` channel receives a message. 
178 | func (s *Subscription) Listen(stop <-chan bool) { 179 | // s.Queue <- []byte("SUBSCRIBED") 180 | for { 181 | //for msg := range s.Queue { 182 | select { 183 | case msg := <-s.Queue: 184 | //fmt.Printf("Forwarding message.\n") 185 | // Queue is always serial, and this should be the only writer to the 186 | // RequestWriter, so we don't explicitly sync right now. 187 | s.Writer.Write(msg) 188 | s.Writer.Flush() 189 | case <-stop: 190 | //fmt.Printf("Subscription ended.\n") 191 | return 192 | default: 193 | } 194 | } 195 | } 196 | 197 | // Close closes things and cleans up. 198 | func (s *Subscription) Close() { 199 | close(s.Queue) 200 | } 201 | 202 | // getMedium fetches the Medium from the Datasources list. 203 | func getMedium(c cookoo.Context) (*Medium, error) { 204 | ds, ok := c.HasDatasource(MediumDS) 205 | if !ok { 206 | return nil, errors.New("Cannot find a Medium") 207 | } 208 | return ds.(*Medium), nil 209 | } 210 | 211 | // NewMedium creates and initializes a Medium. 212 | func NewMedium() *Medium { 213 | return &Medium{ 214 | topics: make(map[string]Topic, 256), // Premature optimization... 215 | } 216 | } 217 | 218 | // Medium handles channeling messages to topics. 219 | // 220 | // You should always create one with NewMedium or else you will not be able 221 | // to add new topics. 222 | type Medium struct { 223 | topics map[string]Topic 224 | mx sync.RWMutex 225 | } 226 | 227 | // Topic gets a Topic by name. 228 | // 229 | // If no topic is found, the ok flag will return false. 230 | func (m *Medium) Topic(name string) (Topic, bool) { 231 | m.mx.RLock() 232 | defer m.mx.RUnlock() 233 | t, ok := m.topics[name] 234 | return t, ok 235 | } 236 | 237 | // Add a new Topic to the Medium. 238 | func (m *Medium) Add(t Topic) { 239 | m.mx.Lock() 240 | m.topics[t.Name()] = t 241 | m.mx.Unlock() 242 | } 243 | 244 | // Delete closes a topic and removes it. 
245 | func (m *Medium) Delete(name string) error { 246 | t, ok := m.topics[name] 247 | if !ok { 248 | return fmt.Errorf("Cannot delete. No topic named %s.", name) 249 | } 250 | t.Close() 251 | m.mx.Lock() 252 | delete(m.topics, name) 253 | m.mx.Unlock() 254 | return nil 255 | } 256 | 257 | var lastSubId uint64 = 0 258 | 259 | // newSubId returns an atomically incremented ID. 260 | // 261 | // This can probably be done better. 262 | func newSubId() uint64 { 263 | z := atomic.AddUint64(&lastSubId, uint64(1)) 264 | // FIXME: And when we hit max? Rollover? 265 | if z == math.MaxUint64 { 266 | atomic.StoreUint64(&lastSubId, uint64(0)) 267 | } 268 | return z 269 | } 270 | -------------------------------------------------------------------------------- /pubsub/topic_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "sync" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestSubscription(t *testing.T) { 12 | 13 | // canaries: 14 | var _ http.ResponseWriter = &mockResponseWriter{} 15 | var _ http.Flusher = &mockResponseWriter{} 16 | 17 | rw := &mockResponseWriter{} 18 | sub := NewSubscription(rw) 19 | 20 | rw2 := &mockResponseWriter{} 21 | sub2 := NewSubscription(rw2) 22 | 23 | if sub.Id == sub2.Id { 24 | t.Error("Two subscriptions have the same ID!!!!") 25 | } 26 | 27 | // Make sure the Queue is buffered. 28 | sub.Queue <- []byte("hi") 29 | out := <-sub.Queue 30 | 31 | if string(out) != "hi" { 32 | t.Error("Expected out to be 'hi'") 33 | } 34 | 35 | // Make sure that listen works. 
36 | until := make(chan bool) 37 | go sub.Listen(until) 38 | sub.Queue <- []byte("hi") 39 | 40 | time.Sleep(2 * time.Millisecond) 41 | until <- true 42 | 43 | sub.Close() 44 | if rw.String() != "hi" { 45 | t.Errorf("Expected bytes 'hi', got '%s'", rw.String()) 46 | } 47 | 48 | } 49 | 50 | func TestTopic(t *testing.T) { 51 | topic := NewTopic("test") 52 | 53 | if topic.Name() != "test" { 54 | t.Errorf("Expected name 'test', got '%s'", topic.Name()) 55 | } 56 | 57 | subs := make([]*Subscription, 50) 58 | 59 | // Subscribe 50 times. 60 | for i := 0; i < 50; i++ { 61 | rw := &mockResponseWriter{} 62 | sub := NewSubscription(rw) 63 | subs[i] = sub 64 | done := make(chan bool) 65 | topic.Subscribe(sub) 66 | go sub.Listen(done) 67 | } 68 | 69 | topic.Publish([]byte("hi")) 70 | topic.Publish([]byte("there")) 71 | 72 | if len(topic.Subscribers()) != 50 { 73 | t.Errorf("Expected 50 subscribers, got %d.", len(topic.Subscribers())) 74 | } 75 | 76 | time.Sleep(5 * time.Millisecond) 77 | 78 | for _, s := range topic.Subscribers() { 79 | mw := s.Writer.(*mockResponseWriter).String() 80 | if mw != "hithere" { 81 | t.Errorf("Expected Subscription %d to have 'hithere'. Got '%s'", s.Id, mw) 82 | } 83 | 84 | //topic.Unsubscribe(s) 85 | } 86 | 87 | if err := topic.Close(); err != nil { 88 | t.Errorf("Error closing topic: %s", err) 89 | } 90 | 91 | if len(topic.Subscribers()) > 0 { 92 | t.Errorf("After close, topic should have no subscribers. Got %d", len(topic.Subscribers())) 93 | } 94 | 95 | } 96 | 97 | func BenchmarkTopic1Client(b *testing.B) { 98 | benchmarkTopic(1, b.N) 99 | } 100 | 101 | func BenchmarkTopic5Clients(b *testing.B) { 102 | benchmarkTopic(5, b.N) 103 | } 104 | 105 | // Create 50 subscribers and send 100 messages. 
106 | func BenchmarkTopic50Clients(b *testing.B) { 107 | benchmarkTopic(50, b.N) 108 | } 109 | 110 | /* 111 | func BenchmarkTopic1Message(b *testing.B) { 112 | benchmarkTopic(b.N, 1) 113 | } 114 | func BenchmarkTopic5Message(b *testing.B) { 115 | benchmarkTopic(b.N, 1) 116 | } 117 | */ 118 | 119 | func benchmarkTopic(scount, mcount int) { 120 | topic := NewTopic("test") 121 | subs := make([]*Subscription, scount) 122 | 123 | // Subscribe 50 times. 124 | for i := 0; i < scount; i++ { 125 | rw := &mockResponseWriter{} 126 | sub := NewSubscription(rw) 127 | subs[i] = sub 128 | done := make(chan bool) 129 | topic.Subscribe(sub) 130 | go sub.Listen(done) 131 | } 132 | 133 | for i := 0; i < mcount; i++ { 134 | topic.Publish([]byte("hi")) 135 | } 136 | 137 | //for _, s := range topic.Subscribers() { 138 | // topic.Unsubscribe(s) 139 | //} 140 | 141 | } 142 | 143 | type mockResponseWriter struct { 144 | headers http.Header 145 | writer bytes.Buffer 146 | mx sync.Mutex 147 | } 148 | 149 | func (r *mockResponseWriter) Header() http.Header { 150 | if len(r.headers) == 0 { 151 | r.headers = make(map[string][]string, 1) 152 | } 153 | return r.headers 154 | } 155 | func (r *mockResponseWriter) Write(d []byte) (int, error) { 156 | r.mx.Lock() 157 | defer r.mx.Unlock() 158 | return r.writer.Write(d) 159 | } 160 | func (r *mockResponseWriter) Buf() []byte { 161 | r.mx.Lock() 162 | defer r.mx.Unlock() 163 | return r.writer.Bytes() 164 | } 165 | func (r *mockResponseWriter) String() string { 166 | r.mx.Lock() 167 | defer r.mx.Unlock() 168 | return r.writer.String() 169 | } 170 | func (r *mockResponseWriter) WriteHeader(c int) { 171 | } 172 | 173 | func (r *mockResponseWriter) Flush() {} 174 | 175 | func (r *mockResponseWriter) CloseNotify() <-chan bool { 176 | return make(chan bool, 1) 177 | } 178 | 179 | // For benchmarking. 
180 | type nilResponseWriter struct { 181 | headers http.Header 182 | writer bytes.Buffer 183 | mx sync.Mutex 184 | } 185 | 186 | func (r *nilResponseWriter) Header() http.Header { 187 | if len(r.headers) == 0 { 188 | r.headers = make(map[string][]string, 1) 189 | } 190 | return r.headers 191 | } 192 | func (r *nilResponseWriter) Write(d []byte) (int, error) { 193 | return len(d), nil 194 | } 195 | func (r *nilResponseWriter) WriteHeader(c int) {} 196 | 197 | func (r *nilResponseWriter) Flush() {} 198 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | /* Package main demos a Pub/Sub server. 2 | */ 3 | package main 4 | 5 | import ( 6 | "net/http" 7 | 8 | "github.com/Masterminds/cookoo" 9 | cfmt "github.com/Masterminds/cookoo/fmt" 10 | "github.com/Masterminds/cookoo/web" 11 | 12 | "github.com/technosophos/drift/httputil" 13 | "github.com/technosophos/drift/pubsub" 14 | 15 | "github.com/bradfitz/http2" 16 | ) 17 | 18 | var helpTemplate = ` 19 | 20 | API Reference 21 | 22 |

API Reference

23 |

These are the API endpoints currently defined for this server.

24 |
{{ range .Routes }} 25 | {{.Name}} 26 |
{{.Description}}
27 | {{end}}
28 | 29 | ` 30 | 31 | func main() { 32 | srv := &http.Server{ 33 | Addr: ":5500", 34 | } 35 | 36 | reg, router, cxt := cookoo.Cookoo() 37 | 38 | buildRegistry(reg, router, cxt) 39 | 40 | // Our main datasource is the Medium, which manages channels. 41 | m := pubsub.NewMedium() 42 | cxt.AddDatasource(pubsub.MediumDS, m) 43 | cxt.Put("routes", reg.Routes()) 44 | 45 | http2.ConfigureServer(srv, &http2.Server{}) 46 | 47 | srv.Handler = web.NewCookooHandler(reg, router, cxt) 48 | 49 | srv.ListenAndServeTLS("server.crt", "server.key") 50 | } 51 | 52 | func buildRegistry(reg *cookoo.Registry, router *cookoo.Router, cxt cookoo.Context) { 53 | reg.AddRoute(cookoo.Route{ 54 | Name: "GET /ping", 55 | Help: "Ping the server, get a pong reponse.", 56 | Does: cookoo.Tasks{ 57 | cookoo.Cmd{ 58 | Fn: func(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) { 59 | w := c.Get("http.ResponseWriter", nil).(http.ResponseWriter) 60 | w.Write([]byte("pong")) 61 | return nil, nil 62 | }, 63 | }, 64 | }, 65 | }) 66 | 67 | reg.AddRoute(cookoo.Route{ 68 | Name: "GET /v1/time", 69 | Help: "Print the current server time as a UNIX seconds-since-epoch", 70 | Does: cookoo.Tasks{ 71 | cookoo.Cmd{ 72 | Name: "timestamp", 73 | Fn: httputil.Timestamp, 74 | }, 75 | cookoo.Cmd{ 76 | Name: "_", 77 | Fn: web.Flush, 78 | Using: []cookoo.Param{ 79 | {Name: "content", From: "cxt:timestamp"}, 80 | {Name: "contentType", DefaultValue: "text/plain"}, 81 | }, 82 | }, 83 | }, 84 | }) 85 | 86 | reg.AddRoute(cookoo.Route{ 87 | Name: "GET /", 88 | Help: "API Reference", 89 | Does: cookoo.Tasks{ 90 | cookoo.Cmd{ 91 | Name: "help", 92 | Fn: cfmt.Template, 93 | Using: []cookoo.Param{ 94 | {Name: "template", DefaultValue: helpTemplate}, 95 | {Name: "Routes", From: "cxt:routes"}, 96 | }, 97 | }, 98 | cookoo.Cmd{ 99 | Name: "_", 100 | Fn: web.Flush, 101 | Using: []cookoo.Param{ 102 | {Name: "content", From: "cxt:help"}, 103 | {Name: "contentType", DefaultValue: "text/html"}, 104 | }, 105 | }, 106 | }, 107 
| }) 108 | 109 | reg.AddRoute(cookoo.Route{ 110 | Name: "PUT /v1/t/*", 111 | Help: "Create a new topic.", 112 | Does: cookoo.Tasks{ 113 | cookoo.Cmd{ 114 | Name: "topic", 115 | Fn: pubsub.CreateTopic, 116 | Using: []cookoo.Param{ 117 | {Name: "topic", From: "path:2"}, 118 | }, 119 | }, 120 | }, 121 | }) 122 | 123 | reg.AddRoute(cookoo.Route{ 124 | Name: "POST /v1/t/*", 125 | Help: "Publish a message to a channel.", 126 | Does: cookoo.Tasks{ 127 | cookoo.Cmd{ 128 | Name: "postBody", 129 | Fn: httputil.BufferPost, 130 | }, 131 | cookoo.Cmd{ 132 | Name: "publish", 133 | Fn: pubsub.Publish, 134 | Using: []cookoo.Param{ 135 | {Name: "message", From: "cxt:postBody"}, 136 | {Name: "topic", From: "path:2"}, 137 | }, 138 | }, 139 | }, 140 | }) 141 | 142 | reg.AddRoute(cookoo.Route{ 143 | Name: "GET /v1/t/*", 144 | Help: "Subscribe to a channel.", 145 | Does: cookoo.Tasks{ 146 | cookoo.Cmd{ 147 | Name: "history", 148 | Fn: pubsub.ReplayHistory, 149 | Using: []cookoo.Param{ 150 | {Name: "topic", From: "path:2"}, 151 | }, 152 | }, 153 | cookoo.Cmd{ 154 | Name: "subscribe", 155 | Fn: pubsub.Subscribe, 156 | Using: []cookoo.Param{ 157 | {Name: "topic", From: "path:2"}, 158 | }, 159 | }, 160 | }, 161 | }) 162 | 163 | reg.AddRoute(cookoo.Route{ 164 | Name: "HEAD /v1/t/*", 165 | Help: "Check whether a topic exists.", 166 | Does: cookoo.Tasks{ 167 | cookoo.Cmd{ 168 | Name: "has", 169 | Fn: pubsub.TopicExists, 170 | Using: []cookoo.Param{ 171 | {Name: "topic", From: "path:2"}, 172 | }, 173 | }, 174 | }, 175 | }) 176 | 177 | reg.AddRoute(cookoo.Route{ 178 | Name: "DELETE /v1/t/*", 179 | Help: "Delete a topic and close all subscriptions to the topic.", 180 | Does: cookoo.Tasks{ 181 | cookoo.Cmd{ 182 | Name: "delete", 183 | Fn: pubsub.DeleteTopic, 184 | Using: []cookoo.Param{ 185 | {Name: "topic", From: "path:2"}, 186 | }, 187 | }, 188 | }, 189 | }) 190 | } 191 | -------------------------------------------------------------------------------- /transport/transport.go: 
--------------------------------------------------------------------------------
// This is a fork of http2.Transport and the associated code.
//
// The purpose of this is to expose the data frames so that we can handle
// streams directly. The RoundTripper implementation has thus changed.
//
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE

package transport

import (
	"bufio"
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/bradfitz/http2"
	"github.com/bradfitz/http2/hpack"
)

// MPB: copied from http2

var (
	// clientPreface is written on every new connection right after the TLS
	// handshake (see newClientConn).
	clientPreface = []byte(http2.ClientPreface)
	// initialHeaderTableSize sizes the HPACK decoder created in newClientConn.
	initialHeaderTableSize uint32 = 4096
)

// streamEnder matches frames that can report whether they end their stream.
type streamEnder interface {
	StreamEnded() bool
}

// headersEnder matches frames that can report whether they end a header block.
type headersEnder interface {
	HeadersEnded() bool
}

// MPB: end copy.

// Transport is an HTTP/2 client transport forked from http2.Transport. Besides
// RoundTrip it provides Listen, which hands a response's data frames to the
// caller over a channel.
type Transport struct {
	// Fallback serves non-https requests in RoundTrip.
	Fallback http.RoundTripper

	// InsecureTLSDial disables certificate and hostname verification when
	// dialing (see newClientConn).
	// TODO: remove this and make more general with a TLS dial hook, like http
	InsecureTLSDial bool

	// connMu guards conns.
	connMu sync.Mutex
	conns  map[string][]*clientConn // key is host:port
}

// clientConn is one TLS connection to a server, carrying many clientStreams.
type clientConn struct {
	t        *Transport
	tconn    *tls.Conn
	tlsState *tls.ConnectionState
	connKey  []string // key(s) this connection is cached in, in t.conns

	readerDone chan struct{} // closed on error
	readerErr  error         // set before readerDone is closed
	hdec       *hpack.Decoder
	nextRes    *http.Response

	// The visible methods hold mu while touching the fields below it
	// (closed, goAway, streams, nextStreamID, and the framer/writer).
	mu           sync.Mutex
	closed       bool
	goAway       *http2.GoAwayFrame // if non-nil, the GoAwayFrame we received
	streams      map[uint32]*clientStream
	nextStreamID uint32
	bw           *bufio.Writer
	werr         error // first write error that has occurred
	br           *bufio.Reader
	fr           *http2.Framer
	// Settings from peer:
	maxFrameSize         uint32
	maxConcurrentStreams uint32
	initialWindowSize    uint32
	hbuf                 bytes.Buffer // HPACK encoder writes into this
	henc                 *hpack.Encoder
}

// clientStream is a single request/response exchange on a clientConn.
type clientStream struct {
	ID   uint32
	resc chan resAndError
	pw   *io.PipeWriter
	pr   *io.PipeReader

	// MPB: Allow option to send data frame over a channel.
	// dataToChan is set by listen; data is the channel returned by Stream.
	dataToChan bool
	data       chan []byte
	// cancel receives the signal sent by Cancel.
	cancel chan bool
}

// Listener makes a stream into something that can be listened to.
100 | type Listener interface { 101 | Stream() (chan []byte, error) 102 | Cancel() 103 | } 104 | 105 | func (c *clientStream) Stream() (chan []byte, error) { 106 | if !c.dataToChan { 107 | return nil, errors.New("Listener has no data.") 108 | } 109 | return c.data, nil 110 | } 111 | 112 | func (c *clientStream) Cancel() { 113 | c.cancel <- true 114 | } 115 | 116 | type stickyErrWriter struct { 117 | w io.Writer 118 | err *error 119 | } 120 | 121 | func (sew stickyErrWriter) Write(p []byte) (n int, err error) { 122 | if *sew.err != nil { 123 | return 0, *sew.err 124 | } 125 | n, err = sew.w.Write(p) 126 | *sew.err = err 127 | return 128 | } 129 | 130 | func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { 131 | if req.URL.Scheme != "https" { 132 | if t.Fallback == nil { 133 | return nil, errors.New("http2: unsupported scheme and no Fallback") 134 | } 135 | return t.Fallback.RoundTrip(req) 136 | } 137 | 138 | host, port, err := net.SplitHostPort(req.URL.Host) 139 | if err != nil { 140 | host = req.URL.Host 141 | port = "443" 142 | } 143 | 144 | for { 145 | cc, err := t.getClientConn(host, port) 146 | if err != nil { 147 | return nil, err 148 | } 149 | res, err := cc.roundTrip(req) 150 | if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)? 
151 | continue 152 | } 153 | if err != nil { 154 | return nil, err 155 | } 156 | return res, nil 157 | } 158 | } 159 | func (t *Transport) Listen(req *http.Request) (*http.Response, Listener, error) { 160 | if req.URL.Scheme != "https" { 161 | return nil, nil, errors.New("http2: unsupported scheme and no Fallback") 162 | } 163 | 164 | host, port, err := net.SplitHostPort(req.URL.Host) 165 | if err != nil { 166 | host = req.URL.Host 167 | port = "443" 168 | } 169 | 170 | for { 171 | cc, err := t.getClientConn(host, port) 172 | if err != nil { 173 | return nil, nil, err 174 | } 175 | res, cs, err := cc.listen(req) 176 | if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)? 177 | continue 178 | } 179 | if err != nil { 180 | return nil, nil, err 181 | } 182 | return res, cs, nil 183 | } 184 | } 185 | 186 | // CloseIdleConnections closes any connections which were previously 187 | // connected from previous requests but are now sitting idle. 188 | // It does not interrupt any connections currently in use. 
189 | func (t *Transport) CloseIdleConnections() { 190 | t.connMu.Lock() 191 | defer t.connMu.Unlock() 192 | for _, vv := range t.conns { 193 | for _, cc := range vv { 194 | cc.closeIfIdle() 195 | } 196 | } 197 | } 198 | 199 | var errClientConnClosed = errors.New("http2: client conn is closed") 200 | 201 | func shouldRetryRequest(err error) bool { 202 | // TODO: or GOAWAY graceful shutdown stuff 203 | return err == errClientConnClosed 204 | } 205 | 206 | func (t *Transport) removeClientConn(cc *clientConn) { 207 | t.connMu.Lock() 208 | defer t.connMu.Unlock() 209 | for _, key := range cc.connKey { 210 | vv, ok := t.conns[key] 211 | if !ok { 212 | continue 213 | } 214 | newList := filterOutClientConn(vv, cc) 215 | if len(newList) > 0 { 216 | t.conns[key] = newList 217 | } else { 218 | delete(t.conns, key) 219 | } 220 | } 221 | } 222 | 223 | func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn { 224 | out := in[:0] 225 | for _, v := range in { 226 | if v != exclude { 227 | out = append(out, v) 228 | } 229 | } 230 | return out 231 | } 232 | 233 | func (t *Transport) getClientConn(host, port string) (*clientConn, error) { 234 | t.connMu.Lock() 235 | defer t.connMu.Unlock() 236 | 237 | key := net.JoinHostPort(host, port) 238 | 239 | for _, cc := range t.conns[key] { 240 | if cc.canTakeNewRequest() { 241 | return cc, nil 242 | } 243 | } 244 | if t.conns == nil { 245 | t.conns = make(map[string][]*clientConn) 246 | } 247 | cc, err := t.newClientConn(host, port, key) 248 | if err != nil { 249 | return nil, err 250 | } 251 | t.conns[key] = append(t.conns[key], cc) 252 | return cc, nil 253 | } 254 | 255 | func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) { 256 | cfg := &tls.Config{ 257 | ServerName: host, 258 | NextProtos: []string{http2.NextProtoTLS}, 259 | InsecureSkipVerify: t.InsecureTLSDial, 260 | } 261 | tconn, err := tls.Dial("tcp", host+":"+port, cfg) 262 | if err != nil { 263 | return nil, err 264 | } 265 | if 
err := tconn.Handshake(); err != nil { 266 | return nil, err 267 | } 268 | if !t.InsecureTLSDial { 269 | if err := tconn.VerifyHostname(cfg.ServerName); err != nil { 270 | return nil, err 271 | } 272 | } 273 | state := tconn.ConnectionState() 274 | if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { 275 | // TODO(bradfitz): fall back to Fallback 276 | return nil, fmt.Errorf("bad protocol: %v", p) 277 | } 278 | if !state.NegotiatedProtocolIsMutual { 279 | return nil, errors.New("could not negotiate protocol mutually") 280 | } 281 | if _, err := tconn.Write(clientPreface); err != nil { 282 | return nil, err 283 | } 284 | 285 | cc := &clientConn{ 286 | t: t, 287 | tconn: tconn, 288 | connKey: []string{key}, // TODO: cert's validated hostnames too 289 | tlsState: &state, 290 | readerDone: make(chan struct{}), 291 | nextStreamID: 1, 292 | maxFrameSize: 16 << 10, // spec default 293 | initialWindowSize: 65535, // spec default 294 | maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. 295 | streams: make(map[uint32]*clientStream), 296 | } 297 | cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr}) 298 | cc.br = bufio.NewReader(tconn) 299 | cc.fr = http2.NewFramer(cc.bw, cc.br) 300 | cc.henc = hpack.NewEncoder(&cc.hbuf) 301 | 302 | cc.fr.WriteSettings() 303 | // TODO: re-send more conn-level flow control tokens when server uses all these. 304 | cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs? 
305 | cc.bw.Flush() 306 | if cc.werr != nil { 307 | return nil, cc.werr 308 | } 309 | 310 | // Read the obligatory SETTINGS frame 311 | f, err := cc.fr.ReadFrame() 312 | if err != nil { 313 | return nil, err 314 | } 315 | sf, ok := f.(*http2.SettingsFrame) 316 | if !ok { 317 | return nil, fmt.Errorf("expected settings frame, got: %T", f) 318 | } 319 | cc.fr.WriteSettingsAck() 320 | cc.bw.Flush() 321 | 322 | sf.ForeachSetting(func(s http2.Setting) error { 323 | switch s.ID { 324 | case http2.SettingMaxFrameSize: 325 | cc.maxFrameSize = s.Val 326 | case http2.SettingMaxConcurrentStreams: 327 | cc.maxConcurrentStreams = s.Val 328 | case http2.SettingInitialWindowSize: 329 | cc.initialWindowSize = s.Val 330 | default: 331 | // TODO(bradfitz): handle more 332 | log.Printf("Unhandled Setting: %v", s) 333 | } 334 | return nil 335 | }) 336 | // TODO: figure out henc size 337 | cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField) 338 | 339 | go cc.readLoop() 340 | go cc.cancelLoop() 341 | return cc, nil 342 | } 343 | 344 | func (cc *clientConn) setGoAway(f *http2.GoAwayFrame) { 345 | cc.mu.Lock() 346 | defer cc.mu.Unlock() 347 | cc.goAway = f 348 | } 349 | 350 | func (cc *clientConn) canTakeNewRequest() bool { 351 | cc.mu.Lock() 352 | defer cc.mu.Unlock() 353 | return cc.goAway == nil && 354 | int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && 355 | cc.nextStreamID < 2147483647 356 | } 357 | 358 | func (cc *clientConn) closeIfIdle() { 359 | cc.mu.Lock() 360 | if len(cc.streams) > 0 { 361 | cc.mu.Unlock() 362 | return 363 | } 364 | cc.closed = true 365 | // TODO: do clients send GOAWAY too? maybe? 
Just Close: 366 | cc.mu.Unlock() 367 | 368 | cc.tconn.Close() 369 | } 370 | 371 | func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) { 372 | cc.mu.Lock() 373 | 374 | if cc.closed { 375 | cc.mu.Unlock() 376 | return nil, errClientConnClosed 377 | } 378 | 379 | cs := cc.newStream() 380 | hasBody := false // TODO 381 | 382 | // we send: HEADERS[+CONTINUATION] + (DATA?) 383 | hdrs := cc.encodeHeaders(req) 384 | first := true 385 | for len(hdrs) > 0 { 386 | chunk := hdrs 387 | if len(chunk) > int(cc.maxFrameSize) { 388 | chunk = chunk[:cc.maxFrameSize] 389 | } 390 | hdrs = hdrs[len(chunk):] 391 | endHeaders := len(hdrs) == 0 392 | if first { 393 | cc.fr.WriteHeaders(http2.HeadersFrameParam{ 394 | StreamID: cs.ID, 395 | BlockFragment: chunk, 396 | EndStream: !hasBody, 397 | EndHeaders: endHeaders, 398 | }) 399 | first = false 400 | } else { 401 | cc.fr.WriteContinuation(cs.ID, endHeaders, chunk) 402 | } 403 | } 404 | cc.bw.Flush() 405 | werr := cc.werr 406 | cc.mu.Unlock() 407 | 408 | if hasBody { 409 | // TODO: write data. and it should probably be interleaved: 410 | // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc 411 | } 412 | 413 | if werr != nil { 414 | return nil, werr 415 | } 416 | 417 | re := <-cs.resc 418 | if re.err != nil { 419 | return nil, re.err 420 | } 421 | res := re.res 422 | res.Request = req 423 | res.TLS = cc.tlsState 424 | return res, nil 425 | } 426 | 427 | // listen sends a request, gets the headers of the response, and then listens 428 | // for additional data frames. 429 | // 430 | // Data frames are then shunted onto the data channel. 431 | // 432 | // When the remote indicates that the transaction is complete, the channel is 433 | // closed. 
434 | func (cc *clientConn) listen(req *http.Request) (*http.Response, *clientStream, error) { 435 | cc.mu.Lock() 436 | 437 | if cc.closed { 438 | cc.mu.Unlock() 439 | return nil, nil, errClientConnClosed 440 | } 441 | 442 | cs := cc.newStream() 443 | hasBody := false // TODO 444 | 445 | cs.dataToChan = true 446 | cs.data = make(chan []byte, 1) 447 | 448 | // we send: HEADERS[+CONTINUATION] + (DATA?) 449 | hdrs := cc.encodeHeaders(req) 450 | first := true 451 | for len(hdrs) > 0 { 452 | chunk := hdrs 453 | if len(chunk) > int(cc.maxFrameSize) { 454 | chunk = chunk[:cc.maxFrameSize] 455 | } 456 | hdrs = hdrs[len(chunk):] 457 | endHeaders := len(hdrs) == 0 458 | if first { 459 | cc.fr.WriteHeaders(http2.HeadersFrameParam{ 460 | StreamID: cs.ID, 461 | BlockFragment: chunk, 462 | EndStream: !hasBody, 463 | EndHeaders: endHeaders, 464 | }) 465 | first = false 466 | } else { 467 | cc.fr.WriteContinuation(cs.ID, endHeaders, chunk) 468 | } 469 | } 470 | cc.bw.Flush() 471 | werr := cc.werr 472 | cc.mu.Unlock() 473 | 474 | if hasBody { 475 | // TODO: write data. and it should probably be interleaved: 476 | // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc 477 | } 478 | 479 | if werr != nil { 480 | return nil, nil, werr 481 | } 482 | 483 | re := <-cs.resc 484 | if re.err != nil { 485 | //return nil, re.err 486 | fmt.Printf("Error closing client stream: %s", re.err) 487 | } 488 | res := re.res 489 | res.Request = req 490 | res.TLS = cc.tlsState 491 | return res, cs, nil 492 | } 493 | 494 | // requires cc.mu be held. 
495 | func (cc *clientConn) encodeHeaders(req *http.Request) []byte { 496 | cc.hbuf.Reset() 497 | 498 | // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go 499 | host := req.Host 500 | if host == "" { 501 | host = req.URL.Host 502 | } 503 | 504 | path := req.URL.Path 505 | if path == "" { 506 | path = "/" 507 | } 508 | 509 | cc.writeHeader(":authority", host) // probably not right for all sites 510 | cc.writeHeader(":method", req.Method) 511 | cc.writeHeader(":path", path) 512 | cc.writeHeader(":scheme", "https") 513 | 514 | for k, vv := range req.Header { 515 | lowKey := strings.ToLower(k) 516 | if lowKey == "host" { 517 | continue 518 | } 519 | for _, v := range vv { 520 | cc.writeHeader(lowKey, v) 521 | } 522 | } 523 | return cc.hbuf.Bytes() 524 | } 525 | 526 | func (cc *clientConn) writeHeader(name, value string) { 527 | log.Printf("sending %q = %q", name, value) 528 | cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) 529 | } 530 | 531 | type resAndError struct { 532 | res *http.Response 533 | err error 534 | } 535 | 536 | // requires cc.mu be held. 537 | func (cc *clientConn) newStream() *clientStream { 538 | cs := &clientStream{ 539 | ID: cc.nextStreamID, 540 | resc: make(chan resAndError, 1), 541 | cancel: make(chan bool, 1), 542 | } 543 | cc.nextStreamID += 2 544 | cc.streams[cs.ID] = cs 545 | return cs 546 | } 547 | 548 | func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream { 549 | cc.mu.Lock() 550 | defer cc.mu.Unlock() 551 | cs := cc.streams[id] 552 | if andRemove { 553 | delete(cc.streams, id) 554 | } 555 | return cs 556 | } 557 | 558 | func (cc *clientConn) cancelLoop() { 559 | for { 560 | // Find anything that needs to be removed and schedule it for removal. 561 | remove := map[uint32]*clientStream{} 562 | for streamID, cs := range cc.streams { 563 | if len(cs.cancel) > 0 { 564 | remove[streamID] = cs 565 | } 566 | } 567 | // Run removals. 
568 | for streamID, cs := range remove { 569 | log.Printf("Canceling %d\n", streamID) 570 | if cs.dataToChan { 571 | close(cs.data) 572 | } 573 | cs.pw.Close() 574 | delete(cc.streams, streamID) 575 | } 576 | // Yield to GC. 577 | time.Sleep(3 * time.Millisecond) 578 | } 579 | } 580 | 581 | // runs in its own goroutine. 582 | func (cc *clientConn) readLoop() { 583 | defer cc.t.removeClientConn(cc) 584 | defer close(cc.readerDone) 585 | 586 | activeRes := map[uint32]*clientStream{} // keyed by streamID 587 | // Close any response bodies if the server closes prematurely. 588 | // TODO: also do this if we've written the headers but not 589 | // gotten a response yet. 590 | defer func() { 591 | err := cc.readerErr 592 | if err == io.EOF { 593 | err = io.ErrUnexpectedEOF 594 | } 595 | for _, cs := range activeRes { 596 | if cs.dataToChan { 597 | close(cs.data) 598 | } 599 | cs.pw.CloseWithError(err) 600 | } 601 | }() 602 | 603 | // continueStreamID is the stream ID we're waiting for 604 | // continuation frames for. 605 | var continueStreamID uint32 606 | 607 | for { 608 | f, err := cc.fr.ReadFrame() 609 | if err != nil { 610 | cc.readerErr = err 611 | return 612 | } 613 | log.Printf("Transport received %v: %#v", f.Header(), f) 614 | 615 | streamID := f.Header().StreamID 616 | 617 | _, isContinue := f.(*http2.ContinuationFrame) 618 | if isContinue { 619 | if streamID != continueStreamID { 620 | log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID) 621 | cc.readerErr = http2.ConnectionError(http2.ErrCodeProtocol) 622 | return 623 | } 624 | } else if continueStreamID != 0 { 625 | // Continue frames need to be adjacent in the stream 626 | // and we were in the middle of headers. 
627 | log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID) 628 | cc.readerErr = http2.ConnectionError(http2.ErrCodeProtocol) 629 | return 630 | } 631 | 632 | if streamID%2 == 0 { 633 | // Ignore streams pushed from the server for now. 634 | // These always have an even stream id. 635 | continue 636 | } 637 | streamEnded := false 638 | if ff, ok := f.(streamEnder); ok { 639 | streamEnded = ff.StreamEnded() 640 | } 641 | 642 | cs := cc.streamByID(streamID, streamEnded) 643 | if cs == nil { 644 | log.Printf("Received frame for untracked stream ID %d", streamID) 645 | continue 646 | } 647 | 648 | switch f := f.(type) { 649 | case *http2.HeadersFrame: 650 | cc.nextRes = &http.Response{ 651 | Proto: "HTTP/2.0", 652 | ProtoMajor: 2, 653 | Header: make(http.Header), 654 | } 655 | cs.pr, cs.pw = io.Pipe() 656 | cc.hdec.Write(f.HeaderBlockFragment()) 657 | case *http2.ContinuationFrame: 658 | cc.hdec.Write(f.HeaderBlockFragment()) 659 | case *http2.DataFrame: 660 | log.Printf("DATA: %q", f.Data()) 661 | if cs.dataToChan { 662 | cs.data <- f.Data() 663 | } else { 664 | cs.pw.Write(f.Data()) 665 | } 666 | case *http2.GoAwayFrame: 667 | cc.t.removeClientConn(cc) 668 | if f.ErrCode != 0 { 669 | // TODO: deal with GOAWAY more. 
particularly the error code 670 | log.Printf("transport got GOAWAY with error code = %v", f.ErrCode) 671 | } 672 | cc.setGoAway(f) 673 | default: 674 | log.Printf("Transport: unhandled response frame type %T", f) 675 | } 676 | headersEnded := false 677 | if he, ok := f.(headersEnder); ok { 678 | headersEnded = he.HeadersEnded() 679 | if headersEnded { 680 | continueStreamID = 0 681 | } else { 682 | continueStreamID = streamID 683 | } 684 | } 685 | 686 | if streamEnded { 687 | if cs.dataToChan { 688 | close(cs.data) 689 | } 690 | cs.pw.Close() 691 | delete(activeRes, streamID) 692 | } 693 | if headersEnded { 694 | if cs == nil { 695 | panic("couldn't find stream") // TODO be graceful 696 | } 697 | // TODO: set the Body to one which notes the 698 | // Close and also sends the server a 699 | // RST_STREAM 700 | cc.nextRes.Body = cs.pr 701 | res := cc.nextRes 702 | activeRes[streamID] = cs 703 | cs.resc <- resAndError{res: res} 704 | } 705 | } 706 | } 707 | 708 | func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) { 709 | // TODO: verifiy pseudo headers come before non-pseudo headers 710 | // TODO: verifiy the status is set 711 | log.Printf("Header field: %+v", f) 712 | if f.Name == ":status" { 713 | code, err := strconv.Atoi(f.Value) 714 | if err != nil { 715 | panic("TODO: be graceful") 716 | } 717 | cc.nextRes.Status = f.Value + " " + http.StatusText(code) 718 | cc.nextRes.StatusCode = code 719 | return 720 | } 721 | if strings.HasPrefix(f.Name, ":") { 722 | // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document." 723 | // TODO: treat as invalid? 724 | return 725 | } 726 | cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value) 727 | } 728 | --------------------------------------------------------------------------------