├── glide.yaml
├── httputil
├── httputil_test.go
└── httputil.go
├── pubsub
├── cmds_test.go
├── history_test.go
├── history.go
├── topic_test.go
├── topic.go
└── cmds.go
├── client
├── client_test.go
├── client.go
└── example
│ └── client.go
├── README.md
├── server
└── server.go
└── transport
└── transport.go
/glide.yaml:
--------------------------------------------------------------------------------
1 | package: github.com/technosophos/drift
2 | import:
3 | - package: github.com/bradfitz/http2
4 | - package: github.com/Masterminds/cookoo
5 |
--------------------------------------------------------------------------------
/httputil/httputil_test.go:
--------------------------------------------------------------------------------
1 | package httputil
2 |
3 | import (
4 | "github.com/Masterminds/cookoo"
5 | "strconv"
6 | "testing"
7 | )
8 |
9 | func TestTimestamp(t *testing.T) {
10 | reg, router, cxt := cookoo.Cookoo()
11 |
12 | reg.Route("test", "Test route").
13 | Does(Timestamp, "res")
14 |
15 | err := router.HandleRequest("test", cxt, true)
16 | if err != nil {
17 | t.Error(err)
18 | }
19 |
20 | ts := cxt.Get("res", "").(string)
21 |
22 | if len(ts) == 0 {
23 | t.Errorf("Expected timestamp, not empty string.")
24 | }
25 |
26 | tsInt, err := strconv.Atoi(ts)
27 | if err != nil {
28 | t.Error(err)
29 | }
30 |
31 | if tsInt <= 5 {
32 | t.Error("Dude, you're stuck in the '70s.")
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/pubsub/cmds_test.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import (
4 | "net/http"
5 | "os"
6 | "testing"
7 |
8 | "github.com/Masterminds/cookoo"
9 | )
10 |
11 | func TestReplayHistory(t *testing.T) {
12 | reg, router, cxt := cookoo.Cookoo()
13 | cxt.AddLogger("out", os.Stdout)
14 |
15 | medium := NewMedium()
16 | cxt.AddDatasource(MediumDS, medium)
17 |
18 | topic := NewHistoriedTopic("test", 5)
19 | medium.Add(topic)
20 |
21 | topic.Publish([]byte("first"))
22 | topic.Publish([]byte("second"))
23 |
24 | req, _ := http.NewRequest("GET", "https://localhost/v1/t/test", nil)
25 | req.Header.Add(XHistoryLength, "4")
26 | res := &mockResponseWriter{}
27 |
28 | cxt.Put("http.Request", req)
29 | cxt.Put("http.ResponseWriter", res)
30 |
31 | reg.Route("test", "Test route").
32 | Does(ReplayHistory, "res").Using("topic").WithDefault("test")
33 |
34 | err := router.HandleRequest("test", cxt, true)
35 | if err != nil {
36 | t.Error(err)
37 | }
38 |
39 | last := res.String()
40 | if last != "firstsecond" {
41 | t.Errorf("Expected 'firstsecond', got '%s'", last)
42 | }
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/pubsub/history_test.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func TestHistory(t *testing.T) {
10 | topic := NewHistoriedTopic("test", 5)
11 |
12 | for _, s := range []string{"a", "b", "c", "d", "e", "f"} {
13 | topic.Publish([]byte(s))
14 | }
15 |
16 | short := topic.Last(1)
17 | if len(short) != 1 {
18 | t.Errorf("Expected 1 in list, got %d", len(short))
19 | }
20 | if string(short[0]) != "b" {
21 | t.Errorf("Expected 'b', got '%s'", short[0])
22 | }
23 |
24 | long := topic.Last(6)
25 | if len(long) != 5 {
26 | t.Errorf("Expected 5 in list, got %d", len(long))
27 | }
28 |
29 | str := string(bytes.Join(long, []byte("")))
30 | if str != "bcdef" {
31 | t.Errorf("Expected bcdef, got %s", str)
32 | }
33 | }
34 |
35 | func TestHistorySince(t *testing.T) {
36 | topic := NewHistoriedTopic("test", 5)
37 |
38 | now := time.Now()
39 |
40 | for _, s := range []string{"a", "b", "c", "d", "e", "f"} {
41 | topic.Publish([]byte(s))
42 | // Current resolution on timer is at seconds.
43 | time.Sleep(time.Second)
44 | }
45 |
46 | since := topic.Since(now)
47 |
48 | str := string(bytes.Join(since, []byte("")))
49 | if str != "bcdef" {
50 | t.Errorf("Expected bcdef, got %s", str)
51 | }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/httputil/httputil.go:
--------------------------------------------------------------------------------
1 | package httputil
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "time"
9 |
10 | "github.com/Masterminds/cookoo"
11 | )
12 |
13 | // BufferPost buffers the body of the POST request into the context.
14 | //
15 | // Params:
16 | //
17 | // Returns:
18 | // - []byte with the content of the request.
19 | func BufferPost(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
20 | req := c.Get("http.Request", nil).(*http.Request)
21 | var b bytes.Buffer
22 | _, err := io.Copy(&b, req.Body)
23 | c.Logf("info", "Received POST: %s", b.Bytes())
24 | return b.Bytes(), err
25 | }
26 |
27 | // Timestamp returns a UNIX timestamp.
28 | //
29 | // Params:
30 | //
31 | // Returns:
32 | // - int64 timestamp as seconds since epoch.
33 | //
34 | func Timestamp(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
35 | return fmt.Sprintf("%d", time.Now().Unix()), nil
36 | }
37 |
38 | // Debug displays debugging info.
39 | func Debug(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
40 | w := c.Get("http.ResponseWriter", nil).(http.ResponseWriter)
41 | r := c.Get("http.Request", nil).(*http.Request)
42 | reqInfoHandler(w, r)
43 | return nil, nil
44 | }
// reqInfoHandler writes a plain-text description of r to w: one line per
// request property (method, protocol, host, remote address, request URI, URL,
// content length, close flag, TLS state) followed by the request headers.
// The response Content-Type is set to text/plain.
func reqInfoHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/plain")

	// Each row is one output line: a format string and the value it prints.
	rows := []struct {
		format string
		value  interface{}
	}{
		{"Method: %s\n", r.Method},
		{"Protocol: %s\n", r.Proto},
		{"Host: %s\n", r.Host},
		{"RemoteAddr: %s\n", r.RemoteAddr},
		{"RequestURI: %q\n", r.RequestURI},
		{"URL: %#v\n", r.URL},
		{"Body.ContentLength: %d (-1 means unknown)\n", r.ContentLength},
		{"Close: %v (relevant for HTTP/1 only)\n", r.Close},
		{"TLS: %#v\n", r.TLS},
	}
	for _, row := range rows {
		fmt.Fprintf(w, row.format, row.value)
	}

	fmt.Fprintf(w, "\nHeaders:\n")
	r.Header.Write(w)
}
59 |
--------------------------------------------------------------------------------
/pubsub/history.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import (
4 | "container/list"
5 | "sync"
6 | "time"
7 | )
8 |
9 | var DefaultMaxHistory = 1000
10 |
11 | // historyTopic maintains the history for a channel.
12 | type historyTopic struct {
13 | Topic
14 | buffer *list.List
15 | max int
16 | mx sync.Mutex
17 | }
18 |
19 | type entry struct {
20 | msg []byte
21 | ts time.Time
22 | }
23 |
24 | // TrackHistory takes an existing topic and adds history tracking.
25 | //
26 | // The mechanism for history tracking is a doubly linked list no longer than
27 | // maxLen.
28 | func TrackHistory(t Topic, maxLen int) HistoriedTopic {
29 | return &historyTopic{
30 | Topic: t,
31 | buffer: list.New(),
32 | max: maxLen,
33 | }
34 | }
35 |
36 | // Since fetches an array of history entries.
37 | //
38 | // The entries will be in order, oldest to newest. And the list will not
39 | // exceed the maximum number of histry items.
40 | //
41 | // If the history list grows beyond its max size, the history list is pruned,
42 | // oldest to youngest.
43 | func (h *historyTopic) Since(t time.Time) [][]byte {
44 |
45 | accumulator := [][]byte{}
46 |
47 | for v := h.buffer.Front(); v != nil; v = v.Next() {
48 | e, ok := v.Value.(*entry)
49 | if !ok {
50 | // Skip anything that's not an entry.
51 | continue
52 | }
53 | if e.ts.After(t) {
54 | accumulator = append(accumulator, e.msg)
55 | } else {
56 | return accumulator
57 | }
58 | }
59 | return accumulator
60 | }
61 |
62 | // Last fetches the last n items from the history, regardless of their time.
63 | //
64 | // Of course, it will return fewer than n if n is larger than the max length
65 | // or if the total stored history is less than n.
66 | func (h *historyTopic) Last(n int) [][]byte {
67 | acc := make([][]byte, 0, n)
68 | i := 0
69 | for v := h.buffer.Front(); v != nil; v = v.Next() {
70 | e, ok := v.Value.(*entry)
71 | if !ok {
72 | // Skip anything that's not an entry.
73 | continue
74 | }
75 | if i < n {
76 | acc = append(acc, e.msg)
77 | } else {
78 | return acc
79 | }
80 | i++
81 | }
82 | return acc
83 | }
84 |
85 | func (h *historyTopic) add(msg []byte) {
86 | h.mx.Lock()
87 | defer h.mx.Unlock()
88 | e := &entry{
89 | msg: msg,
90 | ts: time.Now(),
91 | }
92 |
93 | h.buffer.PushBack(e)
94 |
95 | for h.buffer.Len() > h.max {
96 | h.buffer.Remove(h.buffer.Front())
97 | }
98 | }
99 |
100 | // Publish stores this msg as history and then forwards the publish request to the Topic.
101 | func (h *historyTopic) Publish(msg []byte) error {
102 | h.add(msg)
103 | h.Topic.Publish(msg)
104 | return nil
105 | }
106 |
107 | func (h *historyTopic) Close() error {
108 | err := h.Topic.Close()
109 | // We don't want nil pointers during shutdown.
110 | h.buffer = list.New()
111 | return err
112 | }
113 |
--------------------------------------------------------------------------------
/client/client_test.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 | "time"
7 |
8 | "github.com/Masterminds/cookoo"
9 | "github.com/Masterminds/cookoo/web"
10 | "github.com/bradfitz/http2"
11 | "github.com/technosophos/drift/httputil"
12 | "github.com/technosophos/drift/pubsub"
13 | )
14 |
// Connection settings shared by the client tests.
var (
	// hostport is the address the test server listens on.
	hostport = "127.0.0.1:5500"
	// baseurl is the URL handed to the client under test.
	baseurl = "https://127.0.0.1:5500"
	// topicname is the topic the test publishes and subscribes to.
	topicname = "test.topic"
)
18 |
19 | func TestClient(t *testing.T) {
20 | // Lots of timing to simulate networkiness. Because that makes the
21 | // test nondeterministic, we use fairly large times.
22 |
23 | go standUpServer()
24 | time.Sleep(2 * time.Second)
25 |
26 | cli := New(baseurl)
27 |
28 | go func() {
29 | if err := cli.Publish(topicname, []byte("test")); err != nil {
30 | t.Fatal(err)
31 | }
32 |
33 | time.Sleep(50 * time.Millisecond)
34 |
35 | cli.Publish(topicname, []byte("Again"))
36 | }()
37 |
38 | println("Subscribing")
39 | si, err := cli.Subscribe(topicname)
40 | if err != nil {
41 | t.Fatal(err)
42 | }
43 |
44 | go func() {
45 | time.Sleep(500 * time.Millisecond)
46 | si.Cancel()
47 | }()
48 |
49 | if first := <-si.C; string(first) != "test" {
50 | t.Errorf("expected test, got %s", first)
51 | }
52 | if second := <-si.C; string(second) != "Again" {
53 | t.Errorf("expected Again, got %s", second)
54 | }
55 |
56 | time.Sleep(1 * time.Second)
57 | cli.Delete(topicname)
58 | }
59 |
60 | func standUpServer() error {
61 | srv := &http.Server{Addr: hostport}
62 |
63 | reg, router, cxt := cookoo.Cookoo()
64 |
65 | buildRegistry(reg, router, cxt)
66 |
67 | // Our main datasource is the Medium, which manages channels.
68 | m := pubsub.NewMedium()
69 | cxt.AddDatasource(pubsub.MediumDS, m)
70 | cxt.Put("routes", reg.Routes())
71 |
72 | http2.ConfigureServer(srv, &http2.Server{})
73 |
74 | srv.Handler = web.NewCookooHandler(reg, router, cxt)
75 |
76 | srv.ListenAndServeTLS("../server/server.crt", "../server/server.key")
77 | return nil
78 | }
79 |
80 | func buildRegistry(reg *cookoo.Registry, router *cookoo.Router, cxt cookoo.Context) {
81 |
82 | reg.AddRoute(cookoo.Route{
83 | Name: "PUT /v1/t/*",
84 | Help: "Create a new topic.",
85 | Does: cookoo.Tasks{
86 | cookoo.Cmd{
87 | Name: "topic",
88 | Fn: pubsub.CreateTopic,
89 | Using: []cookoo.Param{
90 | {Name: "topic", From: "path:2"},
91 | },
92 | },
93 | },
94 | })
95 |
96 | reg.AddRoute(cookoo.Route{
97 | Name: "POST /v1/t/*",
98 | Help: "Publish a message to a channel.",
99 | Does: cookoo.Tasks{
100 | cookoo.Cmd{
101 | Name: "postBody",
102 | Fn: httputil.BufferPost,
103 | },
104 | cookoo.Cmd{
105 | Name: "publish",
106 | Fn: pubsub.Publish,
107 | Using: []cookoo.Param{
108 | {Name: "message", From: "cxt:postBody"},
109 | {Name: "topic", From: "path:2"},
110 | },
111 | },
112 | },
113 | })
114 |
115 | reg.AddRoute(cookoo.Route{
116 | Name: "GET /v1/t/*",
117 | Help: "Subscribe to a topic.",
118 | Does: cookoo.Tasks{
119 | cookoo.Cmd{
120 | Name: "history",
121 | Fn: pubsub.ReplayHistory,
122 | Using: []cookoo.Param{
123 | {Name: "topic", From: "path:2"},
124 | },
125 | },
126 | cookoo.Cmd{
127 | Name: "subscribe",
128 | Fn: pubsub.Subscribe,
129 | Using: []cookoo.Param{
130 | {Name: "topic", From: "path:2"},
131 | },
132 | },
133 | },
134 | })
135 |
136 | reg.AddRoute(cookoo.Route{
137 | Name: "HEAD /v1/t/*",
138 | Help: "Check whether a topic exists.",
139 | Does: cookoo.Tasks{
140 | cookoo.Cmd{
141 | Name: "has",
142 | Fn: pubsub.TopicExists,
143 | Using: []cookoo.Param{
144 | {Name: "topic", From: "path:2"},
145 | },
146 | },
147 | },
148 | })
149 |
150 | reg.AddRoute(cookoo.Route{
151 | Name: "DELETE /v1/t/*",
152 | Help: "Delete a topic and close all subscriptions to the topic.",
153 | Does: cookoo.Tasks{
154 | cookoo.Cmd{
155 | Name: "delete",
156 | Fn: pubsub.DeleteTopic,
157 | Using: []cookoo.Param{
158 | {Name: "topic", From: "path:2"},
159 | },
160 | },
161 | },
162 | })
163 | }
164 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Drift: An HTTP/2 Pub/Sub service
2 |
3 | [Stability: Experimental](https://masterminds.github.io/stability/experimental.html)
4 |
5 | Drift is a topic-based PubSub service based on HTTP/2. It uses HTTP/2's
6 | ability to stream data as a simple mechanism for managing subscriptions.
7 |
8 | In a nutshell, one or more _publishers_ send data to _topics_. One or
9 | more _subscribers_ can listen on that topic. Every time a publisher
10 | sends a message, all of the subscribers will receive it.
11 |
12 | Features:
13 |
14 | - Uses the HTTP/2 standard with no additions.
15 | - JSON? Thrift? ProtoBuf? Use whatever.
16 | - Service metadata is confined to HTTP headers. The payload is all
17 | yours.
18 | - Configurable history lets clients quickly catch up on what they missed.
19 | - Extensible architecture makes it easy for you to add your own flair.
20 | - And more in the works...
21 |
22 | The current implementation streams Data Frames. Once the Go libraries
23 | mature, we may instead opt to use full pushes (though the overhead for
24 | that may be higher than we want).
25 |
26 | **This library is not stable. The interfaces may change before the 0.1
27 | release**
28 |
29 | **Currently, the library ONLY supports HTTPS.**
30 |
31 | ## Installation
32 |
33 | ```
34 | $ brew install glide
35 | $ git clone $THIS_REPO
36 | $ glide init
37 | ```
38 |
39 | From there, you can build the server (`go build server/server.go`) or
40 | the example client (`go build client/example/client.go`).
41 |
42 | ## Simple Client Example
43 |
44 | To use Drift as a client, import the client library:
45 |
46 | ```go
47 | import "github.com/technosophos/drift/client"
48 | ```
49 |
50 | Here is a simple publisher:
51 |
52 | ```go
53 | c := client.New("https://localhost:5500")
54 | c.Publish("example", []byte("Hello World"))
55 | ```
56 |
57 | The above sends the "Hello World" message over the `example` topic.
58 |
59 | A subscriber looks like this:
60 |
61 | ```go
62 | s := client.New("https://localhost:5500")
63 | subscription, err := s.Subscribe("example")
64 | if err != nil {
65 | fmt.Printf("Failed subscription: %s", err)
66 | return
67 | }
68 |
69 | // Now listen on a stream.
70 | for msg := range subscription.C {
71 | fmt.Printf("Received: %s\n", msg)
72 | }
73 |
74 | // When you're done...
75 | subscription.Cancel()
76 | ```
77 |
78 | A more advanced API is provided for configuring history and adding
79 | arbitrary HTTP headers.
80 |
81 | ## About the Server
82 |
83 | The server lives in `server/server.go`. The basic server provides
84 | convenient features for getting running quickly.
85 |
86 | But the server was also designed as a composable system. You can easily
87 | take the parts here and add your own. Take a look at the registry in
88 | `server/server.go` to see how this is done.
89 |
90 | ## API
91 |
92 | `GET /`
93 |
94 | Prints the runtime API documentation.
95 |
96 | `DELETE /v1/t/TOPIC`
97 |
98 | Destroy a topic named `TOPIC`.
99 |
100 | This will destroy the history and cancel subscriptions for all
101 | subscribed clients.
102 |
103 | `GET /v1/t/TOPIC`
104 |
105 | Subscribe to a topic named `TOPIC`. The client is expected to hold open
106 | a connection for the duration of its subscription.
107 |
108 | This method **does not support HTTP/1 at all!** You must use HTTP/2.
109 |
110 |
111 | `POST /v1/t/TOPIC`
112 |
113 | Post a new message into the topic named `TOPIC`.
114 |
115 | The body of the post message is pushed wholesale into the queue.
116 |
117 | This method accepts HTTP/1.1 POST content in addition to HTTP/2 POST.
118 | Only one data frame of HTTP/2 POST data is accepted. Streamed POST is
119 | currently not supported (though it will be).
120 |
121 | `PUT /v1/t/TOPIC`
122 |
123 | Create a new topic named `TOPIC`.
124 |
125 | The body of this message is a well-defined JSON data structure that
126 | describes the topic.
127 |
128 | `GET /v1/time`
129 |
130 | Get the current time. This returns a plain text value with nothing but a
131 | timestamp.
132 |
133 | ```
134 | $ curl -k https://localhost:5500/v1/time
135 | 1436464998
136 | ```
137 |
138 | The purpose of this callback is to give client libraries a timestamp to
139 | use as the base time for calculating dates. This can reduce problems
140 | with clock skew.
141 |
--------------------------------------------------------------------------------
/server/server.go:
--------------------------------------------------------------------------------
1 | /* Package main demos a Pub/Sub server.
2 | */
3 | package main
4 |
5 | import (
6 | "net/http"
7 |
8 | "github.com/Masterminds/cookoo"
9 | cfmt "github.com/Masterminds/cookoo/fmt"
10 | "github.com/Masterminds/cookoo/web"
11 |
12 | "github.com/technosophos/drift/httputil"
13 | "github.com/technosophos/drift/pubsub"
14 |
15 | "github.com/bradfitz/http2"
16 | )
17 |
18 | var helpTemplate = `
19 |
20 | API Reference
21 |
22 | API Reference
23 | These are the API endpoints currently defined for this server.
24 | {{ range .Routes }}
25 | {{.Name}}
26 | - {{.Description}}
27 | {{end}}
28 |
29 | `
30 |
31 | func main() {
32 | srv := &http.Server{
33 | Addr: ":5500",
34 | }
35 |
36 | reg, router, cxt := cookoo.Cookoo()
37 |
38 | buildRegistry(reg, router, cxt)
39 |
40 | // Our main datasource is the Medium, which manages channels.
41 | m := pubsub.NewMedium()
42 | cxt.AddDatasource(pubsub.MediumDS, m)
43 | cxt.Put("routes", reg.Routes())
44 |
45 | http2.ConfigureServer(srv, &http2.Server{})
46 |
47 | srv.Handler = web.NewCookooHandler(reg, router, cxt)
48 |
49 | srv.ListenAndServeTLS("server.crt", "server.key")
50 | }
51 |
52 | func buildRegistry(reg *cookoo.Registry, router *cookoo.Router, cxt cookoo.Context) {
53 | reg.AddRoute(cookoo.Route{
54 | Name: "GET /ping",
55 | Help: "Ping the server, get a pong reponse.",
56 | Does: cookoo.Tasks{
57 | cookoo.Cmd{
58 | Fn: func(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
59 | w := c.Get("http.ResponseWriter", nil).(http.ResponseWriter)
60 | w.Write([]byte("pong"))
61 | return nil, nil
62 | },
63 | },
64 | },
65 | })
66 |
67 | reg.AddRoute(cookoo.Route{
68 | Name: "GET /v1/time",
69 | Help: "Print the current server time as a UNIX seconds-since-epoch",
70 | Does: cookoo.Tasks{
71 | cookoo.Cmd{
72 | Name: "timestamp",
73 | Fn: httputil.Timestamp,
74 | },
75 | cookoo.Cmd{
76 | Name: "_",
77 | Fn: web.Flush,
78 | Using: []cookoo.Param{
79 | {Name: "content", From: "cxt:timestamp"},
80 | {Name: "contentType", DefaultValue: "text/plain"},
81 | },
82 | },
83 | },
84 | })
85 |
86 | reg.AddRoute(cookoo.Route{
87 | Name: "GET /",
88 | Help: "API Reference",
89 | Does: cookoo.Tasks{
90 | cookoo.Cmd{
91 | Name: "help",
92 | Fn: cfmt.Template,
93 | Using: []cookoo.Param{
94 | {Name: "template", DefaultValue: helpTemplate},
95 | {Name: "Routes", From: "cxt:routes"},
96 | },
97 | },
98 | cookoo.Cmd{
99 | Name: "_",
100 | Fn: web.Flush,
101 | Using: []cookoo.Param{
102 | {Name: "content", From: "cxt:help"},
103 | {Name: "contentType", DefaultValue: "text/html"},
104 | },
105 | },
106 | },
107 | })
108 |
109 | reg.AddRoute(cookoo.Route{
110 | Name: "PUT /v1/t/*",
111 | Help: "Create a new topic.",
112 | Does: cookoo.Tasks{
113 | cookoo.Cmd{
114 | Name: "topic",
115 | Fn: pubsub.CreateTopic,
116 | Using: []cookoo.Param{
117 | {Name: "topic", From: "path:2"},
118 | },
119 | },
120 | },
121 | })
122 |
123 | reg.AddRoute(cookoo.Route{
124 | Name: "POST /v1/t/*",
125 | Help: "Publish a message to a channel.",
126 | Does: cookoo.Tasks{
127 | cookoo.Cmd{
128 | Name: "postBody",
129 | Fn: httputil.BufferPost,
130 | },
131 | cookoo.Cmd{
132 | Name: "publish",
133 | Fn: pubsub.Publish,
134 | Using: []cookoo.Param{
135 | {Name: "message", From: "cxt:postBody"},
136 | {Name: "topic", From: "path:2"},
137 | },
138 | },
139 | },
140 | })
141 |
142 | reg.AddRoute(cookoo.Route{
143 | Name: "GET /v1/t/*",
144 | Help: "Subscribe to a channel.",
145 | Does: cookoo.Tasks{
146 | cookoo.Cmd{
147 | Name: "history",
148 | Fn: pubsub.ReplayHistory,
149 | Using: []cookoo.Param{
150 | {Name: "topic", From: "path:2"},
151 | },
152 | },
153 | cookoo.Cmd{
154 | Name: "subscribe",
155 | Fn: pubsub.Subscribe,
156 | Using: []cookoo.Param{
157 | {Name: "topic", From: "path:2"},
158 | },
159 | },
160 | },
161 | })
162 |
163 | reg.AddRoute(cookoo.Route{
164 | Name: "HEAD /v1/t/*",
165 | Help: "Check whether a topic exists.",
166 | Does: cookoo.Tasks{
167 | cookoo.Cmd{
168 | Name: "has",
169 | Fn: pubsub.TopicExists,
170 | Using: []cookoo.Param{
171 | {Name: "topic", From: "path:2"},
172 | },
173 | },
174 | },
175 | })
176 |
177 | reg.AddRoute(cookoo.Route{
178 | Name: "DELETE /v1/t/*",
179 | Help: "Delete a topic and close all subscriptions to the topic.",
180 | Does: cookoo.Tasks{
181 | cookoo.Cmd{
182 | Name: "delete",
183 | Fn: pubsub.DeleteTopic,
184 | Using: []cookoo.Param{
185 | {Name: "topic", From: "path:2"},
186 | },
187 | },
188 | },
189 | })
190 | }
191 |
--------------------------------------------------------------------------------
/pubsub/topic_test.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import (
4 | "bytes"
5 | "net/http"
6 | "sync"
7 | "testing"
8 | "time"
9 | )
10 |
11 | func TestSubscription(t *testing.T) {
12 |
13 | // canaries:
14 | var _ http.ResponseWriter = &mockResponseWriter{}
15 | var _ http.Flusher = &mockResponseWriter{}
16 |
17 | rw := &mockResponseWriter{}
18 | sub := NewSubscription(rw)
19 |
20 | rw2 := &mockResponseWriter{}
21 | sub2 := NewSubscription(rw2)
22 |
23 | if sub.Id == sub2.Id {
24 | t.Error("Two subscriptions have the same ID!!!!")
25 | }
26 |
27 | // Make sure the Queue is buffered.
28 | sub.Queue <- []byte("hi")
29 | out := <-sub.Queue
30 |
31 | if string(out) != "hi" {
32 | t.Error("Expected out to be 'hi'")
33 | }
34 |
35 | // Make sure that listen works.
36 | until := make(chan bool)
37 | go sub.Listen(until)
38 | sub.Queue <- []byte("hi")
39 |
40 | time.Sleep(2 * time.Millisecond)
41 | until <- true
42 |
43 | sub.Close()
44 | if rw.String() != "hi" {
45 | t.Errorf("Expected bytes 'hi', got '%s'", rw.String())
46 | }
47 |
48 | }
49 |
50 | func TestTopic(t *testing.T) {
51 | topic := NewTopic("test")
52 |
53 | if topic.Name() != "test" {
54 | t.Errorf("Expected name 'test', got '%s'", topic.Name())
55 | }
56 |
57 | subs := make([]*Subscription, 50)
58 |
59 | // Subscribe 50 times.
60 | for i := 0; i < 50; i++ {
61 | rw := &mockResponseWriter{}
62 | sub := NewSubscription(rw)
63 | subs[i] = sub
64 | done := make(chan bool)
65 | topic.Subscribe(sub)
66 | go sub.Listen(done)
67 | }
68 |
69 | topic.Publish([]byte("hi"))
70 | topic.Publish([]byte("there"))
71 |
72 | if len(topic.Subscribers()) != 50 {
73 | t.Errorf("Expected 50 subscribers, got %d.", len(topic.Subscribers()))
74 | }
75 |
76 | time.Sleep(5 * time.Millisecond)
77 |
78 | for _, s := range topic.Subscribers() {
79 | mw := s.Writer.(*mockResponseWriter).String()
80 | if mw != "hithere" {
81 | t.Errorf("Expected Subscription %d to have 'hithere'. Got '%s'", s.Id, mw)
82 | }
83 |
84 | //topic.Unsubscribe(s)
85 | }
86 |
87 | if err := topic.Close(); err != nil {
88 | t.Errorf("Error closing topic: %s", err)
89 | }
90 |
91 | if len(topic.Subscribers()) > 0 {
92 | t.Errorf("After close, topic should have no subscribers. Got %d", len(topic.Subscribers()))
93 | }
94 |
95 | }
96 |
97 | func BenchmarkTopic1Client(b *testing.B) {
98 | benchmarkTopic(1, b.N)
99 | }
100 |
101 | func BenchmarkTopic5Clients(b *testing.B) {
102 | benchmarkTopic(5, b.N)
103 | }
104 |
105 | // Create 50 subscribers and send 100 messages.
106 | func BenchmarkTopic50Clients(b *testing.B) {
107 | benchmarkTopic(50, b.N)
108 | }
109 |
110 | /*
111 | func BenchmarkTopic1Message(b *testing.B) {
112 | benchmarkTopic(b.N, 1)
113 | }
114 | func BenchmarkTopic5Message(b *testing.B) {
115 | benchmarkTopic(b.N, 1)
116 | }
117 | */
118 |
119 | func benchmarkTopic(scount, mcount int) {
120 | topic := NewTopic("test")
121 | subs := make([]*Subscription, scount)
122 |
123 | // Subscribe 50 times.
124 | for i := 0; i < scount; i++ {
125 | rw := &mockResponseWriter{}
126 | sub := NewSubscription(rw)
127 | subs[i] = sub
128 | done := make(chan bool)
129 | topic.Subscribe(sub)
130 | go sub.Listen(done)
131 | }
132 |
133 | for i := 0; i < mcount; i++ {
134 | topic.Publish([]byte("hi"))
135 | }
136 |
137 | //for _, s := range topic.Subscribers() {
138 | // topic.Unsubscribe(s)
139 | //}
140 |
141 | }
142 |
// mockResponseWriter is an in-memory response writer for tests. Everything
// written to it is collected in a buffer that can be read back with Buf or
// String.
type mockResponseWriter struct {
	headers http.Header
	writer  bytes.Buffer
	// mx guards writer.
	mx sync.Mutex
}

// Header returns the header map, creating it on first use.
//
// The same map is returned on every call. Testing for nil rather than for an
// empty map matters: replacing an empty map would silently drop a header set
// on a map returned by an earlier call.
func (r *mockResponseWriter) Header() http.Header {
	if r.headers == nil {
		r.headers = make(http.Header, 1)
	}
	return r.headers
}

// Write appends d to the internal buffer.
func (r *mockResponseWriter) Write(d []byte) (int, error) {
	r.mx.Lock()
	defer r.mx.Unlock()
	return r.writer.Write(d)
}

// Buf returns the bytes written so far. The slice comes straight from the
// internal buffer, so it is only valid until the next Write.
func (r *mockResponseWriter) Buf() []byte {
	r.mx.Lock()
	defer r.mx.Unlock()
	return r.writer.Bytes()
}

// String returns everything written so far as a string.
func (r *mockResponseWriter) String() string {
	r.mx.Lock()
	defer r.mx.Unlock()
	return r.writer.String()
}

// WriteHeader ignores the status code.
func (r *mockResponseWriter) WriteHeader(c int) {
}

// Flush does nothing; it exists so the mock satisfies http.Flusher.
func (r *mockResponseWriter) Flush() {}

// CloseNotify returns a fresh channel on which nothing is ever sent.
func (r *mockResponseWriter) CloseNotify() <-chan bool {
	return make(chan bool, 1)
}
178 |
// nilResponseWriter is a response writer for benchmarking: it reports every
// write as successful and discards the data.
type nilResponseWriter struct {
	headers http.Header
	writer  bytes.Buffer
	mx      sync.Mutex
}

// Header returns the header map, creating it on first use.
//
// The same map is returned on every call. Testing for nil rather than for an
// empty map matters: replacing an empty map would silently drop a header set
// on a map returned by an earlier call.
func (r *nilResponseWriter) Header() http.Header {
	if r.headers == nil {
		r.headers = make(http.Header, 1)
	}
	return r.headers
}

// Write discards d and reports all of it as written.
func (r *nilResponseWriter) Write(d []byte) (int, error) {
	return len(d), nil
}

// WriteHeader ignores the status code.
func (r *nilResponseWriter) WriteHeader(c int) {}

// Flush does nothing.
func (r *nilResponseWriter) Flush() {}
198 |
--------------------------------------------------------------------------------
/client/client.go:
--------------------------------------------------------------------------------
1 | // Package client provides a client library for Drift.
2 | package client
3 |
4 | import (
5 | "bytes"
6 | "crypto/tls"
7 | "errors"
8 | "fmt"
9 | "net/http"
10 | "path"
11 | "time"
12 |
13 | "github.com/technosophos/drift/transport"
14 | )
15 |
16 | const v1Path = "/v1/t/"
17 |
// Client provides consumer functions for Drift.
//
// It carries the simple, one-call methods for managing topics, publishing
// and subscribing. Use Subscriber and Publisher directly for more detailed
// work.
type Client struct {
	// Url is the base URL of the server; topic paths are appended to it.
	Url string
	// Does not verify cert against authorities.
	InsecureTLSDial bool
}

// New creates and initializes a new client for the server at url.
func New(url string) *Client {
	c := new(Client)
	c.Url = url
	return c
}
35 |
36 | // Create creates a new topic on the pubsub server.
37 | func (c *Client) Create(topic string) error {
38 | url := c.Url + path.Join(v1Path, topic)
39 | _, err := c.basicRoundTrip("PUT", url)
40 | return err
41 | }
42 |
43 | // Delete removes an existing topic from the pubsub server.
44 | func (c *Client) Delete(topic string) error {
45 | url := c.Url + path.Join(v1Path, topic)
46 | _, err := c.basicRoundTrip("DELETE", url)
47 | return err
48 | }
49 |
50 | // Checks whether the server already has the topic.
51 | func (c *Client) Exists(topic string) bool {
52 | url := c.Url + path.Join(v1Path, topic)
53 | _, err := c.basicRoundTrip("HEAD", url)
54 | return err == nil
55 | }
56 |
57 | func (c *Client) Publish(topic string, msg []byte) error {
58 | p := NewPublisher(c.Url)
59 | _, err := p.Publish(topic, msg)
60 | return err
61 | }
62 |
63 | func (c *Client) Subscribe(topic string) (*Subscription, error) {
64 | s := NewSubscriber(c.Url)
65 | s.History.Len = 100
66 | return s.Subscribe(topic)
67 | }
68 |
69 | func (c *Client) basicRoundTrip(verb, url string) (*http.Response, error) {
70 | t := &transport.Transport{InsecureTLSDial: c.InsecureTLSDial}
71 |
72 | req, err := http.NewRequest(verb, url, nil)
73 | if err != nil {
74 | return nil, err
75 | }
76 |
77 | return t.RoundTrip(req)
78 | }
79 |
// Publisher is responsible for publishing messages to the service.
type Publisher struct {
	// Url is the base URL of the server; topic paths are appended to it.
	Url string
	// Header holds HTTP headers associated with this publisher.
	Header http.Header
}

// NewPublisher creates a new Publisher for the server at url with an empty,
// non-nil header set.
func NewPublisher(url string) *Publisher {
	pub := &Publisher{Url: url}
	pub.Header = make(http.Header)
	return pub
}
93 |
94 | // Publish sends the service a message for a particular topic.
95 | func (p *Publisher) Publish(topic string, message []byte) (*http.Response, error) {
96 |
97 | if len(message) == 0 {
98 | return nil, errors.New("Cannot send an empty message")
99 | }
100 | if len(topic) == 0 {
101 | return nil, errors.New("Cannot publish to an empty topic.")
102 | }
103 | /* HTTP2 does not currently send the body! So we have to go to HTTP1
104 | t := &transport.Transport{InsecureTLSDial: true}
105 | */
106 | t := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
107 |
108 | url := p.Url + path.Join(v1Path, topic)
109 |
110 | var body bytes.Buffer
111 | body.Write(message)
112 |
113 | req, _ := http.NewRequest("POST", url, &body)
114 |
115 | return t.RoundTrip(req)
116 | }
117 |
118 | // Subscription represents an existing subscription that a subscriber
119 | // has subscribed to.
120 | type Subscription struct {
121 | C chan []byte
122 | listener transport.Listener
123 | }
124 |
125 | func (s *Subscription) Cancel() {
126 | // Signal the transport that the clientStream should be removed.
127 | s.listener.Cancel()
128 | }
129 |
// Subscriber defines a client that subscribes to a topic on a PubSub.
type Subscriber struct {
	// Url is the base URL of the server; topic paths are appended to it.
	Url string
	// History says how much history to request when subscribing.
	History History
	// Header holds HTTP headers sent with the subscription request.
	Header http.Header
}

// NewSubscriber creates a Subscriber for the server at url with an empty,
// non-nil header set and no history request.
func NewSubscriber(url string) *Subscriber {
	sub := &Subscriber{Url: url}
	sub.Header = make(http.Header)
	return sub
}

// History describes how much history a subscriber should ask for.
//
// By default, Subscribers do not ask for any history.
type History struct {
	// Since requests entries newer than this time.
	Since time.Time
	// Len requests up to this many entries.
	Len int
}
151 |
152 | func (s *Subscriber) Subscribe(topic string) (*Subscription, error) {
153 | if len(topic) == 0 {
154 | return nil, errors.New("Cannot subscribe to an empty channel.")
155 | }
156 |
157 | url := s.Url + path.Join(v1Path, topic)
158 | fmt.Printf("URL: %s\n", url)
159 |
160 | t := &transport.Transport{InsecureTLSDial: true}
161 |
162 | req, err := http.NewRequest("GET", url, nil)
163 | if err != nil {
164 | return nil, err
165 | }
166 |
167 | s.setHeaders(req)
168 |
169 | _, listener, err := t.Listen(req)
170 | if err != nil {
171 | return nil, err
172 | }
173 |
174 | stream, err := listener.Stream()
175 | if err != nil {
176 | return nil, err
177 | }
178 |
179 | return &Subscription{C: stream, listener: listener}, nil
180 | }
181 |
182 | func (s *Subscriber) setHeaders(req *http.Request) {
183 | req.Header = s.Header
184 | if s.History.Len > 0 {
185 | req.Header.Add("X-History-Length", fmt.Sprintf("%d", s.History.Len))
186 | }
187 | if s.History.Since.After(time.Unix(0, 0)) {
188 | req.Header.Add("X-History-Since", fmt.Sprintf("%d", s.History.Since.Unix()))
189 | }
190 | }
191 |
--------------------------------------------------------------------------------
/client/example/client.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | //"io/ioutil"
7 | "os"
8 | //"net"
9 | "bytes"
10 | "crypto/tls"
11 | "net/http"
12 | "time"
13 |
14 | "github.com/bradfitz/http2"
15 | "github.com/bradfitz/http2/hpack"
16 | "github.com/technosophos/drift/client"
17 | "github.com/technosophos/drift/transport"
18 | )
19 |
20 | func main() {
21 |
22 | cmd := os.Args[1]
23 |
24 | /*
25 | //t := &http2.Transport{InsecureTLSDial: true}
26 | t := &transport.Transport{InsecureTLSDial: true}
27 |
28 | req, _ := http.NewRequest("GET", "https://localhost:5500/", nil)
29 |
30 | //res, err := t.RoundTrip(req)
31 | res, stream, err := t.Listen(req)
32 | if err != nil {
33 | fmt.Printf("Failed: %s\n", err)
34 | }
35 | data, _ := ioutil.ReadAll(res.Body)
36 | fmt.Printf("Response: %s %d, %q\n", res.Proto, res.StatusCode, data)
37 |
38 | fmt.Println("Waiting for the next message.")
39 | moredata := <-stream
40 | fmt.Printf("Final data: %s", moredata)
41 |
42 | // Next, hit the ticker
43 | getTicker()
44 | */
45 | switch cmd {
46 | case "publish":
47 | //publish()
48 | p := client.NewPublisher("https://localhost:5500")
49 | p.Publish("example", []byte("Hello World"))
50 | time.Sleep(200 * time.Millisecond)
51 | p.Publish("example", []byte("Hello again"))
52 | case "subscribe":
53 | //subscribe()
54 | s := client.NewSubscriber("https://localhost:5500")
55 | s.History = &client.History{Len: 5}
56 | stream, err := s.Subscribe("example")
57 | if err != nil {
58 | fmt.Printf("Failed subscription: %s", err)
59 | return
60 | }
61 | for msg := range stream {
62 | fmt.Printf("Received: %s\n", msg)
63 | }
64 | default:
65 | fmt.Printf("Unknown command: %s\n", cmd)
66 | }
67 | }
68 |
// publish POSTs the body "test" to the TEST topic over HTTP/1 and prints the
// response status.
func publish() {
	// HTTP2 does not currently send the body! So we have to go to HTTP1
	// rather than using transport.Transport{InsecureTLSDial: true}.
	t := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}

	var body bytes.Buffer
	body.Write([]byte("test"))

	req, err := http.NewRequest("POST", "https://localhost:5500/v1/t/TEST", &body)
	if err != nil {
		fmt.Printf("Error building request: %s", err)
		return
	}

	res, err := t.RoundTrip(req)
	if err != nil {
		// res is nil on error, so stop before touching it.
		fmt.Printf("Error during round trip: %s", err)
		return
	}
	defer res.Body.Close()

	fmt.Printf("Status: %s", res.Status)
}
89 |
90 | func subscribe() {
91 | t := &transport.Transport{InsecureTLSDial: true}
92 |
93 | req, _ := http.NewRequest("GET", "https://localhost:5500/v1/t/TEST", nil)
94 |
95 | //res, err := t.RoundTrip(req)
96 | res, stream, err := t.Listen(req)
97 | if err != nil {
98 | fmt.Printf("Failed: %s\n", err)
99 | }
100 | fmt.Printf("Headers: %s\n", res.Status)
101 | //data, _ := ioutil.ReadAll(res.Body)
102 | //fmt.Printf("Response: %s %d, %q\n", res.Proto, res.StatusCode, data)
103 |
104 | fmt.Println("Waiting for the next message.")
105 | for data := range stream {
106 | fmt.Printf("\nReceived: %s\n", data)
107 | }
108 | }
109 |
110 | func getTicker() *http.Response {
111 | t := &transport.Transport{InsecureTLSDial: true}
112 |
113 | req, _ := http.NewRequest("GET", "https://localhost:5500/tick", nil)
114 |
115 | //res, err := t.RoundTrip(req)
116 | res, stream, err := t.Listen(req)
117 | if err != nil {
118 | fmt.Printf("Failed: %s\n", err)
119 | }
120 |
121 | fmt.Printf("Status: %s\n", res.Status)
122 | fmt.Printf("Body: %s\n", res.Body)
123 |
124 | for data := range stream {
125 | fmt.Printf("\nReceived: %s\n", data)
126 | }
127 |
128 | //io.Copy(os.Stdout, res.Body)
129 | return res
130 | }
131 |
132 | func custom() {
133 | dest := "localhost"
134 | port := ":5500"
135 | // Create a new client.
136 |
137 | tlscfg := &tls.Config{
138 | ServerName: dest,
139 | NextProtos: []string{http2.NextProtoTLS},
140 | InsecureSkipVerify: true,
141 | }
142 |
143 | conn, err := tls.Dial("tcp", dest+port, tlscfg)
144 | if err != nil {
145 | panic(err)
146 | }
147 | defer conn.Close()
148 |
149 | conn.Handshake()
150 | state := conn.ConnectionState()
151 | fmt.Printf("Protocol is : %q\n", state.NegotiatedProtocol)
152 |
153 | if _, err := io.WriteString(conn, http2.ClientPreface); err != nil {
154 | fmt.Printf("Preface failed: %s", err)
155 | return
156 | }
157 |
158 | var hbuf bytes.Buffer
159 | framer := http2.NewFramer(conn, conn)
160 |
161 | enc := hpack.NewEncoder(&hbuf)
162 | writeHeader(enc, ":authority", "localhost")
163 | writeHeader(enc, ":method", "GET")
164 | writeHeader(enc, ":path", "/ping")
165 | writeHeader(enc, ":scheme", "https")
166 | writeHeader(enc, "Accept", "*/*")
167 |
168 | if len(hbuf.Bytes()) > 16<<10 {
169 | fmt.Printf("Need CONTINUATION\n")
170 | }
171 |
172 | headers := http2.HeadersFrameParam{
173 | StreamID: 1,
174 | EndStream: true,
175 | EndHeaders: true,
176 | BlockFragment: hbuf.Bytes(),
177 | }
178 |
179 | fmt.Printf("All the stuff: %q\n", headers.BlockFragment)
180 |
181 | go listen(framer)
182 |
183 | framer.WriteSettings()
184 | framer.WriteWindowUpdate(0, 1<<30)
185 | framer.WriteSettingsAck()
186 |
187 | time.Sleep(time.Second * 2)
188 | framer.WriteHeaders(headers)
189 | time.Sleep(time.Second * 2)
190 |
191 | /* A ping HTTP request
192 | var payload [8]byte
193 | copy(payload[:], "_c0ffee_")
194 | framer.WritePing(false, payload)
195 | rawpong, err := framer.ReadFrame()
196 | if err != nil {
197 | panic(err)
198 | }
199 |
200 | pong, ok := rawpong.(*http2.PingFrame)
201 | if !ok {
202 | fmt.Printf("Instead of a Ping, I got this: %v\n", pong)
203 | return
204 | }
205 |
206 | fmt.Printf("Pong: %q\n", pong.Data)
207 | */
208 |
209 | }
210 |
211 | func listen(framer *http2.Framer) {
212 | for {
213 | response, err := framer.ReadFrame()
214 | if err != nil {
215 | if err == io.EOF {
216 | return
217 | }
218 | fmt.Printf("Error: Got %q\n", err)
219 | }
220 | switch t := response.(type) {
221 | case *http2.SettingsFrame:
222 | t.ForeachSetting(func(s http2.Setting) error {
223 | fmt.Printf("Setting: %q\n", s)
224 | return nil
225 | })
226 | case *http2.GoAwayFrame:
227 | fmt.Printf("Go Away code = %q, stream ID = %d\n", t.ErrCode, t.StreamID)
228 | }
229 | //data := response.(*http2.DataFrame)
230 | fmt.Printf("Got %q\n", response)
231 | }
232 | }
233 |
234 | func writeHeader(enc *hpack.Encoder, name, value string) {
235 | enc.WriteField(hpack.HeaderField{Name: name, Value: value})
236 | }
237 |
--------------------------------------------------------------------------------
/pubsub/topic.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "math"
7 | "net/http"
8 | "sync"
9 | "sync/atomic"
10 | "time"
11 |
12 | "github.com/Masterminds/cookoo"
13 | )
14 |
// ResponseWriterFlusher handles both HTTP response writing and flushing.
//
// We use this simply to declare which interfaces we require support for:
// writing the response, flushing it, and being notified when the client
// connection closes.
type ResponseWriterFlusher interface {
	http.ResponseWriter
	http.Flusher
	http.CloseNotifier
}
23 |
// Topic is the main channel for sending messages to subscribers.
//
// A publisher is anything that sends a message to a Topic. All
// attached subscribers will receive that message.
type Topic interface {
	// Publish sends a message to all subscribers.
	Publish([]byte) error
	// Subscribe attaches a subscription to this topic.
	Subscribe(*Subscription)
	// Unsubscribe detaches a subscription from the topic.
	Unsubscribe(*Subscription)
	// Name returns the topic name.
	Name() string
	// Subscribers returns a list of subscriptions attached to this topic.
	Subscribers() []*Subscription
	// Close closes and destroys the topic.
	Close() error
}
42 |
// History provides access to the last N messages on a particular Topic.
type History interface {
	// Last provides access to up to N messages.
	Last(int) [][]byte
	// Since provides access to all messages in history since the given time.
	Since(time.Time) [][]byte
}
50 |
// HistoriedTopic is a topic that has an attached history.
//
// It combines the History and Topic interfaces.
type HistoriedTopic interface {
	History
	Topic
}
56 |
57 | // NewTopic creates a new Topic with no history capabilities.
58 | func NewTopic(name string) Topic {
59 | ct := &channeledTopic{
60 | name: name,
61 | subscribers: make(map[uint64]*Subscription, 512), // Sane default space?
62 | }
63 | return ct
64 | }
65 |
66 | // NewHistoriedTopic creates a new HistoriedTopic.
67 | //
68 | // This topic will retain `length` history items for the topic.
69 | func NewHistoriedTopic(name string, length int) HistoriedTopic {
70 | return TrackHistory(NewTopic(name), length)
71 | }
72 |
// channeledTopic is the Topic implementation returned by NewTopic. It
// delivers messages by sending them on each subscriber's Queue channel.
type channeledTopic struct {
	name string
	// subscribers maps Subscription.Id to the attached subscription.
	subscribers map[uint64]*Subscription
	// mx is held while subscribers is read or modified.
	mx sync.RWMutex
	// closed is set by Close.
	closed bool
}
79 |
80 | func (t *channeledTopic) Close() error {
81 | t.mx.Lock()
82 | t.closed = true
83 | for _, s := range t.subscribers {
84 | s.Close()
85 | }
86 | t.subscribers = map[uint64]*Subscription{}
87 | t.mx.Unlock()
88 | return nil
89 | }
90 |
91 | func (t *channeledTopic) Publish(msg []byte) error {
92 | if t.closed {
93 | return errors.New("Topic is being deleted.")
94 | }
95 | t.mx.Lock()
96 | defer func() {
97 | t.mx.Unlock()
98 | if err := recover(); err != nil {
99 | fmt.Printf("Recovered from failed publish. Some messages probably didn't get through. %s\n", err)
100 | }
101 | }()
102 |
103 | for _, s := range t.subscribers {
104 | if s.Queue == nil {
105 | fmt.Printf("Channel appears to be closed. Skipping.\n")
106 | continue
107 | }
108 | //fmt.Printf("Sending msg to subscriber %d: %s\n", s.Id, msg)
109 | s.Queue <- msg
110 | }
111 | //fmt.Printf("Message sent.\n")
112 | return nil
113 | }
114 |
115 | func (t *channeledTopic) Subscribe(s *Subscription) {
116 | if t.closed {
117 | return
118 | }
119 | t.mx.Lock()
120 | defer t.mx.Unlock()
121 | //t.subscribers = append(t.subscribers, s)
122 | if _, ok := t.subscribers[s.Id]; ok {
123 | fmt.Printf("Surprisingly got the same ID as an existing subscriber.")
124 | }
125 | t.subscribers[s.Id] = s
126 | //fmt.Printf("There are now %d subscribers", len(t.subscribers))
127 | }
128 |
129 | func (t *channeledTopic) Unsubscribe(s *Subscription) {
130 | if t.closed {
131 | return
132 | }
133 | t.mx.Lock()
134 | defer t.mx.Unlock()
135 | delete(t.subscribers, s.Id)
136 | }
137 |
// Name returns the name the topic was created with.
func (t *channeledTopic) Name() string {
	return t.name
}
141 |
142 | func (t *channeledTopic) Subscribers() []*Subscription {
143 | t.mx.RLock()
144 | defer t.mx.RUnlock()
145 | c := len(t.subscribers)
146 | s := make([]*Subscription, 0, c)
147 | for _, v := range t.subscribers {
148 | s = append(s, v)
149 | }
150 | return s
151 | }
152 |
// Subscription describes a subscriber.
//
// A subscription attaches to ONLY ONE Topic.
type Subscription struct {
	// Id identifies the subscription; NewSubscription fills it from newSubId.
	Id uint64
	// Writer is the response that Listen copies queued messages into.
	Writer ResponseWriterFlusher
	// Queue carries messages from the topic's Publish to Listen.
	Queue chan []byte
}
161 |
162 | // NewSubscription creates a new subscription.
163 | //
164 | //
165 | func NewSubscription(r ResponseWriterFlusher) *Subscription {
166 | // Queue depth should be revisited.
167 | q := make(chan []byte, 10)
168 | return &Subscription{
169 | Writer: r,
170 | Queue: q,
171 | Id: newSubId(),
172 | }
173 | }
174 |
175 | // Listen copies messages fromt the Queue into the Writer.
176 | //
177 | // It listens on the Queue unless the `stop` channel receives a message.
178 | func (s *Subscription) Listen(stop <-chan bool) {
179 | // s.Queue <- []byte("SUBSCRIBED")
180 | for {
181 | //for msg := range s.Queue {
182 | select {
183 | case msg := <-s.Queue:
184 | //fmt.Printf("Forwarding message.\n")
185 | // Queue is always serial, and this should be the only writer to the
186 | // RequestWriter, so we don't explicitly sync right now.
187 | s.Writer.Write(msg)
188 | s.Writer.Flush()
189 | case <-stop:
190 | //fmt.Printf("Subscription ended.\n")
191 | return
192 | default:
193 | }
194 | }
195 | }
196 |
// Close closes the subscription's Queue channel.
//
// NOTE(review): closing an already-closed channel panics, so Close must be
// called at most once per subscription. The topic's Close and the Subscribe
// command both call it — confirm they cannot both run for the same subscription.
func (s *Subscription) Close() {
	close(s.Queue)
}
201 |
202 | // getMedium fetches the Medium from the Datasources list.
203 | func getMedium(c cookoo.Context) (*Medium, error) {
204 | ds, ok := c.HasDatasource(MediumDS)
205 | if !ok {
206 | return nil, errors.New("Cannot find a Medium")
207 | }
208 | return ds.(*Medium), nil
209 | }
210 |
211 | // NewMedium creates and initializes a Medium.
212 | func NewMedium() *Medium {
213 | return &Medium{
214 | topics: make(map[string]Topic, 256), // Premature optimization...
215 | }
216 | }
217 |
// Medium handles channeling messages to topics.
//
// You should always create one with NewMedium or else you will not be able
// to add new topics.
type Medium struct {
	// topics maps a topic name to its Topic.
	topics map[string]Topic
	// mx is held while topics is read or modified.
	mx sync.RWMutex
}
226 |
227 | // Topic gets a Topic by name.
228 | //
229 | // If no topic is found, the ok flag will return false.
230 | func (m *Medium) Topic(name string) (Topic, bool) {
231 | m.mx.RLock()
232 | defer m.mx.RUnlock()
233 | t, ok := m.topics[name]
234 | return t, ok
235 | }
236 |
237 | // Add a new Topic to the Medium.
238 | func (m *Medium) Add(t Topic) {
239 | m.mx.Lock()
240 | m.topics[t.Name()] = t
241 | m.mx.Unlock()
242 | }
243 |
244 | // Delete closes a topic and removes it.
245 | func (m *Medium) Delete(name string) error {
246 | t, ok := m.topics[name]
247 | if !ok {
248 | return fmt.Errorf("Cannot delete. No topic named %s.", name)
249 | }
250 | t.Close()
251 | m.mx.Lock()
252 | delete(m.topics, name)
253 | m.mx.Unlock()
254 | return nil
255 | }
256 |
// lastSubId is the most recently issued subscription ID.
var lastSubId uint64

// newSubId returns an atomically incremented ID.
//
// When the counter reaches the maximum uint64 value it is reset to zero, so
// the next ID issued after the maximum is 1.
//
// This can probably be done better.
func newSubId() uint64 {
	id := atomic.AddUint64(&lastSubId, 1)
	// FIXME: And when we hit max? Rollover?
	if id != math.MaxUint64 {
		return id
	}
	atomic.StoreUint64(&lastSubId, 0)
	return id
}
270 |
--------------------------------------------------------------------------------
/pubsub/cmds.go:
--------------------------------------------------------------------------------
1 | /* Package pubsub provides publish/subscribe operations for HTTP/2.
2 |
3 | */
4 | package pubsub
5 |
6 | import (
7 | "errors"
8 | "fmt"
9 | "net/http"
10 | "strconv"
11 | "time"
12 |
13 | "github.com/Masterminds/cookoo"
14 | )
15 |
// MediumDS is the datasource name under which the Medium is looked up.
const MediumDS = "drift.Medium"

const (
	// XHistorySince is an HTTP header for the client to send a request for history since TIMESTAMP.
	XHistorySince = "x-history-since"
	// XHistoryLength is an HTTP header for the client to send a request for the last N records.
	XHistoryLength = "x-history-length"
	// XHistoryEnabled is a flag for the server to notify the client whether history is enabled.
	XHistoryEnabled = "x-history-enabled"
)
26 |
27 | // Publish sends a new message to a topic.
28 | //
29 | // Params:
30 | // - topic (string): The topic to send to.
31 | // - message ([]byte): The message to send.
32 | // - withHistory (bool): Turn on history. Default is true. This only takes
33 | // effect when the channel is created.
34 | //
35 | // Datasources:
36 | // - This uses the 'drift.Medium' datasource.
37 | //
38 | // Returns:
39 | //
40 | func Publish(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
41 | hist := p.Get("withHistory", true).(bool)
42 | topic := p.Get("topic", "").(string)
43 | if len(topic) == 0 {
44 | return nil, errors.New("No topic supplied.")
45 | }
46 |
47 | medium, _ := getMedium(c)
48 |
49 | // Is there any reason to disallow empty messages?
50 | msg := p.Get("message", []byte{}).([]byte)
51 | c.Logf("info", "Msg: %s", msg)
52 |
53 | t := fetchOrCreateTopic(medium, topic, hist, DefaultMaxHistory)
54 | return nil, t.Publish(msg)
55 |
56 | }
57 |
58 | // Subscribe allows an request to subscribe to topic updates.
59 | //
60 | // Params:
61 | // - topic (string): The topic to subscribe to.
62 | // -
63 | //
64 | // Returns:
65 | //
66 | func Subscribe(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
67 | medium, err := getMedium(c)
68 | if err != nil {
69 | return nil, &cookoo.FatalError{"No medium."}
70 | }
71 | topic := p.Get("topic", "").(string)
72 | if len(topic) == 0 {
73 | return nil, errors.New("No topic is set.")
74 | }
75 |
76 | rw := c.Get("http.ResponseWriter", nil).(ResponseWriterFlusher)
77 | clientGone := rw.(http.CloseNotifier).CloseNotify()
78 |
79 | sub := NewSubscription(rw)
80 | t := fetchOrCreateTopic(medium, topic, true, DefaultMaxHistory)
81 | t.Subscribe(sub)
82 |
83 | defer func() {
84 | t.Unsubscribe(sub)
85 | sub.Close()
86 | }()
87 |
88 | sub.Listen(clientGone)
89 |
90 | return nil, nil
91 | }
92 |
93 | // CreateTopic creates a new topic.
94 | //
95 | // Params:
96 | // - topic (string)
97 | // - history (bool): whether or not to track history
98 | // - historyLength (int): How much history to track. Default is DefaultMaxHistory.
99 | //
100 | // Returns:
101 | // Topic the new topic.
102 | func CreateTopic(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
103 | name := p.Get("topic", "").(string)
104 | if len(name) == 0 {
105 | return nil, &cookoo.FatalError{"Topic name required."}
106 | }
107 |
108 | hist := p.Get("history", true).(bool)
109 | histLen := p.Get("historyLength", DefaultMaxHistory).(int)
110 |
111 | m, err := getMedium(c)
112 | if err != nil {
113 | return nil, &cookoo.FatalError{"No medium."}
114 | }
115 |
116 | t := fetchOrCreateTopic(m, name, hist, histLen)
117 |
118 | return t, nil
119 |
120 | }
121 |
122 | // DeleteTopic deletes a topic and its history.
123 | //
124 | // Params:
125 | // - name (string)
126 | //
127 | // Returns:
128 | //
129 | func DeleteTopic(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
130 | name := p.Get("topic", "").(string)
131 | if len(name) == 0 {
132 | return nil, &cookoo.FatalError{"Topic name required."}
133 | }
134 |
135 | m, err := getMedium(c)
136 | if err != nil {
137 | return nil, &cookoo.FatalError{"No medium."}
138 | }
139 |
140 | err = m.Delete(name)
141 | if err != nil {
142 | c.Logf("warn", "Failed to delete topic: %s", err)
143 | }
144 |
145 | return nil, nil
146 | }
147 |
148 | // TopicExists tests whether a topic exists, and sends an HTTP 200 if yes, 404 if no.
149 | //
150 | // Params:
151 | // - topic (string): The topic to look up.
152 | // Returns:
153 | //
154 | func TopicExists(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
155 | res := c.Get("http.ResponseWriter", nil).(ResponseWriterFlusher)
156 | name := p.Get("topic", "").(string)
157 | if len(name) == 0 {
158 | res.WriteHeader(404)
159 | return nil, nil
160 | }
161 |
162 | medium, err := getMedium(c)
163 | if err != nil {
164 | res.WriteHeader(404)
165 | return nil, nil
166 | }
167 |
168 | if _, ok := medium.Topic(name); ok {
169 | res.WriteHeader(200)
170 | return nil, nil
171 | }
172 | res.WriteHeader(404)
173 | return nil, nil
174 | }
175 |
176 | // ReplayHistory sends back the history to a subscriber.
177 | //
178 | // This should be called before the client goes into active listening.
179 | //
180 | // Params:
181 | // - topic (string): The topic to fetch.
182 | //
183 | // Returns:
184 | // - int: The number of history messages sent to the client.
185 | func ReplayHistory(c cookoo.Context, p *cookoo.Params) (interface{}, cookoo.Interrupt) {
186 | req := c.Get("http.Request", nil).(*http.Request)
187 | res := c.Get("http.ResponseWriter", nil).(ResponseWriterFlusher)
188 | medium, _ := getMedium(c)
189 | name := p.Get("topic", "").(string)
190 |
191 | // This does not manage topics. If there is no topic set, we silently fail.
192 | if len(name) == 0 {
193 | c.Log("info", "No topic name given to ReplayHistory.")
194 | return 0, nil
195 | }
196 | top, ok := medium.Topic(name)
197 | if !ok {
198 | c.Logf("info", "No topic named %s exists yet. No history replayed.", name)
199 | return 0, nil
200 | }
201 |
202 | topic, ok := top.(HistoriedTopic)
203 | if !ok {
204 | c.Logf("info", "No history for topic %s.", name)
205 | res.Header().Add(XHistoryEnabled, "False")
206 | return 0, nil
207 | }
208 | res.Header().Add(XHistoryEnabled, "True")
209 |
210 | since := req.Header.Get(XHistorySince)
211 | max := req.Header.Get(XHistoryLength)
212 |
213 | // maxLen can be used either on its own or paired with X-History-Since.
214 | maxLen := 0
215 | if len(max) > 0 {
216 | m, err := parseHistLen(max)
217 | if err != nil {
218 | c.Logf("info", "failed to parse X-History-Length %s", max)
219 | } else {
220 | maxLen = m
221 | }
222 | }
223 | if len(since) > 0 {
224 | ts, err := parseSince(since)
225 | if err != nil {
226 | c.Logf("warn", "Failed to parse X-History-Since field %s: %s", since, err)
227 | return 0, nil
228 | }
229 | toSend := topic.Since(ts)
230 |
231 | // If maxLen is also set, we trim the list by sending the newest.
232 | ls := len(toSend)
233 | if maxLen > 0 && ls > maxLen {
234 | offset := ls - maxLen - 1
235 | toSend = toSend[offset:]
236 | }
237 | return sendHistory(c, res, toSend)
238 | } else if maxLen > 0 {
239 | toSend := topic.Last(maxLen)
240 | return sendHistory(c, res, toSend)
241 | }
242 |
243 | return 0, nil
244 | }
245 |
246 | // sendHistory sends the accumulated history to the writer.
247 | func sendHistory(c cookoo.Context, writer ResponseWriterFlusher, data [][]byte) (int, error) {
248 | c.Logf("info", "Sending history.")
249 | var i int
250 | var d []byte
251 | for i, d = range data {
252 | _, err := writer.Write(d)
253 | if err != nil {
254 | c.Logf("warn", "Failed to write history message: %s", err)
255 | return i + 1, nil
256 | }
257 | writer.Flush()
258 | }
259 | return i + 1, nil
260 | }
261 |
// parseSince parses the X-History-Since value.
//
// The value is a decimal Unix timestamp in seconds. It is always read as
// base 10, so a leading zero does not switch the parse to octal. On failure
// the Unix epoch and an error are returned.
func parseSince(s string) (time.Time, error) {
	secs, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return time.Unix(0, 0), fmt.Errorf("Could not parse as time: %s", s)
	}
	return time.Unix(secs, 0), nil
}
270 |
// parseHistLen parses the X-History-Length value as a decimal integer,
// passing through the result and error of strconv.Atoi.
func parseHistLen(s string) (int, error) {
	count, err := strconv.Atoi(s)
	return count, err
}
275 |
276 | // fetchOrCreateTopic gets a topic if it exists, and creates one if it doesn't.
277 | func fetchOrCreateTopic(m *Medium, name string, hist bool, l int) Topic {
278 | t, ok := m.Topic(name)
279 | if !ok {
280 | t = NewTopic(name)
281 | if hist && l > 0 {
282 | t = TrackHistory(t, l)
283 | }
284 | m.Add(t)
285 | }
286 | return t
287 | }
288 |
--------------------------------------------------------------------------------
/transport/transport.go:
--------------------------------------------------------------------------------
1 | // This is a fork of http2.Transport and the associated code.
2 | //
3 | // The purpose of this is to expose the data frames so that we can handle
4 | // streams directly. The RoundTripper implementation has thus changed.
5 | //
6 | // Copyright 2015 The Go Authors.
7 | // See https://go.googlesource.com/go/+/master/CONTRIBUTORS
8 | // Licensed under the same terms as Go itself:
9 | // https://go.googlesource.com/go/+/master/LICENSE
10 |
11 | package transport
12 |
13 | import (
14 | "bufio"
15 | "bytes"
16 | "crypto/tls"
17 | "errors"
18 | "fmt"
19 | "io"
20 | "log"
21 | "net"
22 | "net/http"
23 | "strconv"
24 | "strings"
25 | "sync"
26 | "time"
27 |
28 | "github.com/bradfitz/http2"
29 | "github.com/bradfitz/http2/hpack"
30 | )
31 |
// MPB: copied from http2

var (
	// clientPreface is the HTTP/2 client preface, written to the connection
	// once the TLS protocol negotiation has been checked.
	clientPreface = []byte(http2.ClientPreface)
	// initialHeaderTableSize is the table size passed to the HPACK decoder.
	initialHeaderTableSize uint32 = 4096
)
38 |
// streamEnder is implemented by values that can report StreamEnded.
// NOTE(review): presumably HTTP/2 frames carrying the END_STREAM flag — confirm.
type streamEnder interface {
	StreamEnded() bool
}
42 |
// headersEnder is implemented by values that can report HeadersEnded.
// NOTE(review): presumably HTTP/2 frames carrying the END_HEADERS flag — confirm.
type headersEnder interface {
	HeadersEnded() bool
}
46 |
47 | // MPB: end copy.
48 |
// Transport is an HTTP/2 client transport for https requests that can also
// hand a stream to the caller as a Listener (see Listen).
type Transport struct {
	// Fallback, if set, handles RoundTrip requests whose scheme is not https.
	Fallback http.RoundTripper

	// TODO: remove this and make more general with a TLS dial hook, like http
	// InsecureTLSDial disables certificate and hostname verification when dialing.
	InsecureTLSDial bool

	// connMu is held while conns is read or modified.
	connMu sync.Mutex
	conns  map[string][]*clientConn // key is host:port
}
58 |
// clientConn holds the state of one TLS connection to a server, including
// the streams opened on it.
type clientConn struct {
	t        *Transport
	tconn    *tls.Conn
	tlsState *tls.ConnectionState
	connKey  []string // key(s) this connection is cached in, in t.conns

	readerDone chan struct{} // closed on error
	readerErr  error         // set before readerDone is closed
	hdec       *hpack.Decoder
	nextRes    *http.Response

	// mu is held by the methods that read or change closed, goAway and streams.
	mu           sync.Mutex
	closed       bool
	goAway       *http2.GoAwayFrame // if non-nil, the GoAwayFrame we received
	streams      map[uint32]*clientStream
	nextStreamID uint32
	bw           *bufio.Writer
	werr         error // first write error that has occurred
	br           *bufio.Reader
	fr           *http2.Framer
	// Settings from peer:
	maxFrameSize         uint32
	maxConcurrentStreams uint32
	initialWindowSize    uint32
	hbuf                 bytes.Buffer // HPACK encoder writes into this
	henc                 *hpack.Encoder
}
86 |
// clientStream is the client-side state of a single stream on a clientConn.
// It implements Listener.
type clientStream struct {
	ID   uint32
	resc chan resAndError
	pw   *io.PipeWriter
	pr   *io.PipeReader

	// MPB: Allow option to send data frame over a channel.
	// dataToChan must be true for Stream to hand out the data channel.
	dataToChan bool
	// data is the channel returned by Stream.
	data chan []byte
	// cancel receives a value when Cancel is called.
	cancel chan bool
}
98 |
// Listener makes a stream into something that can be listened to.
type Listener interface {
	// Stream returns the channel on which the stream's data is delivered.
	Stream() (chan []byte, error)
	// Cancel signals that the caller is done listening to the stream.
	Cancel()
}
104 |
105 | func (c *clientStream) Stream() (chan []byte, error) {
106 | if !c.dataToChan {
107 | return nil, errors.New("Listener has no data.")
108 | }
109 | return c.data, nil
110 | }
111 |
// Cancel sends a value on the stream's cancel channel.
//
// NOTE(review): the send blocks until something receives from cancel;
// presumably clientConn.cancelLoop does — confirm.
func (c *clientStream) Cancel() {
	c.cancel <- true
}
115 |
// stickyErrWriter wraps an io.Writer and records the outcome of each write in
// *err. Once *err is non-nil, further writes are refused with that error
// without touching the underlying writer.
type stickyErrWriter struct {
	w   io.Writer
	err *error
}

// Write forwards p to the wrapped writer unless an earlier error is recorded,
// and stores the result of the write in *err.
func (sew stickyErrWriter) Write(p []byte) (int, error) {
	if prev := *sew.err; prev != nil {
		return 0, prev
	}
	written, werr := sew.w.Write(p)
	*sew.err = werr
	return written, werr
}
129 |
130 | func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
131 | if req.URL.Scheme != "https" {
132 | if t.Fallback == nil {
133 | return nil, errors.New("http2: unsupported scheme and no Fallback")
134 | }
135 | return t.Fallback.RoundTrip(req)
136 | }
137 |
138 | host, port, err := net.SplitHostPort(req.URL.Host)
139 | if err != nil {
140 | host = req.URL.Host
141 | port = "443"
142 | }
143 |
144 | for {
145 | cc, err := t.getClientConn(host, port)
146 | if err != nil {
147 | return nil, err
148 | }
149 | res, err := cc.roundTrip(req)
150 | if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
151 | continue
152 | }
153 | if err != nil {
154 | return nil, err
155 | }
156 | return res, nil
157 | }
158 | }
159 | func (t *Transport) Listen(req *http.Request) (*http.Response, Listener, error) {
160 | if req.URL.Scheme != "https" {
161 | return nil, nil, errors.New("http2: unsupported scheme and no Fallback")
162 | }
163 |
164 | host, port, err := net.SplitHostPort(req.URL.Host)
165 | if err != nil {
166 | host = req.URL.Host
167 | port = "443"
168 | }
169 |
170 | for {
171 | cc, err := t.getClientConn(host, port)
172 | if err != nil {
173 | return nil, nil, err
174 | }
175 | res, cs, err := cc.listen(req)
176 | if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
177 | continue
178 | }
179 | if err != nil {
180 | return nil, nil, err
181 | }
182 | return res, cs, nil
183 | }
184 | }
185 |
186 | // CloseIdleConnections closes any connections which were previously
187 | // connected from previous requests but are now sitting idle.
188 | // It does not interrupt any connections currently in use.
189 | func (t *Transport) CloseIdleConnections() {
190 | t.connMu.Lock()
191 | defer t.connMu.Unlock()
192 | for _, vv := range t.conns {
193 | for _, cc := range vv {
194 | cc.closeIfIdle()
195 | }
196 | }
197 | }
198 |
// errClientConnClosed is the error that marks a request as retryable.
var errClientConnClosed = errors.New("http2: client conn is closed")

// shouldRetryRequest reports whether err is exactly errClientConnClosed,
// which is the only error a request is currently retried for.
func shouldRetryRequest(err error) bool {
	// TODO: or GOAWAY graceful shutdown stuff
	switch err {
	case errClientConnClosed:
		return true
	default:
		return false
	}
}
205 |
206 | func (t *Transport) removeClientConn(cc *clientConn) {
207 | t.connMu.Lock()
208 | defer t.connMu.Unlock()
209 | for _, key := range cc.connKey {
210 | vv, ok := t.conns[key]
211 | if !ok {
212 | continue
213 | }
214 | newList := filterOutClientConn(vv, cc)
215 | if len(newList) > 0 {
216 | t.conns[key] = newList
217 | } else {
218 | delete(t.conns, key)
219 | }
220 | }
221 | }
222 |
223 | func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
224 | out := in[:0]
225 | for _, v := range in {
226 | if v != exclude {
227 | out = append(out, v)
228 | }
229 | }
230 | return out
231 | }
232 |
233 | func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
234 | t.connMu.Lock()
235 | defer t.connMu.Unlock()
236 |
237 | key := net.JoinHostPort(host, port)
238 |
239 | for _, cc := range t.conns[key] {
240 | if cc.canTakeNewRequest() {
241 | return cc, nil
242 | }
243 | }
244 | if t.conns == nil {
245 | t.conns = make(map[string][]*clientConn)
246 | }
247 | cc, err := t.newClientConn(host, port, key)
248 | if err != nil {
249 | return nil, err
250 | }
251 | t.conns[key] = append(t.conns[key], cc)
252 | return cc, nil
253 | }
254 |
255 | func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
256 | cfg := &tls.Config{
257 | ServerName: host,
258 | NextProtos: []string{http2.NextProtoTLS},
259 | InsecureSkipVerify: t.InsecureTLSDial,
260 | }
261 | tconn, err := tls.Dial("tcp", host+":"+port, cfg)
262 | if err != nil {
263 | return nil, err
264 | }
265 | if err := tconn.Handshake(); err != nil {
266 | return nil, err
267 | }
268 | if !t.InsecureTLSDial {
269 | if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
270 | return nil, err
271 | }
272 | }
273 | state := tconn.ConnectionState()
274 | if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
275 | // TODO(bradfitz): fall back to Fallback
276 | return nil, fmt.Errorf("bad protocol: %v", p)
277 | }
278 | if !state.NegotiatedProtocolIsMutual {
279 | return nil, errors.New("could not negotiate protocol mutually")
280 | }
281 | if _, err := tconn.Write(clientPreface); err != nil {
282 | return nil, err
283 | }
284 |
285 | cc := &clientConn{
286 | t: t,
287 | tconn: tconn,
288 | connKey: []string{key}, // TODO: cert's validated hostnames too
289 | tlsState: &state,
290 | readerDone: make(chan struct{}),
291 | nextStreamID: 1,
292 | maxFrameSize: 16 << 10, // spec default
293 | initialWindowSize: 65535, // spec default
294 | maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
295 | streams: make(map[uint32]*clientStream),
296 | }
297 | cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
298 | cc.br = bufio.NewReader(tconn)
299 | cc.fr = http2.NewFramer(cc.bw, cc.br)
300 | cc.henc = hpack.NewEncoder(&cc.hbuf)
301 |
302 | cc.fr.WriteSettings()
303 | // TODO: re-send more conn-level flow control tokens when server uses all these.
304 | cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
305 | cc.bw.Flush()
306 | if cc.werr != nil {
307 | return nil, cc.werr
308 | }
309 |
310 | // Read the obligatory SETTINGS frame
311 | f, err := cc.fr.ReadFrame()
312 | if err != nil {
313 | return nil, err
314 | }
315 | sf, ok := f.(*http2.SettingsFrame)
316 | if !ok {
317 | return nil, fmt.Errorf("expected settings frame, got: %T", f)
318 | }
319 | cc.fr.WriteSettingsAck()
320 | cc.bw.Flush()
321 |
322 | sf.ForeachSetting(func(s http2.Setting) error {
323 | switch s.ID {
324 | case http2.SettingMaxFrameSize:
325 | cc.maxFrameSize = s.Val
326 | case http2.SettingMaxConcurrentStreams:
327 | cc.maxConcurrentStreams = s.Val
328 | case http2.SettingInitialWindowSize:
329 | cc.initialWindowSize = s.Val
330 | default:
331 | // TODO(bradfitz): handle more
332 | log.Printf("Unhandled Setting: %v", s)
333 | }
334 | return nil
335 | })
336 | // TODO: figure out henc size
337 | cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
338 |
339 | go cc.readLoop()
340 | go cc.cancelLoop()
341 | return cc, nil
342 | }
343 |
344 | func (cc *clientConn) setGoAway(f *http2.GoAwayFrame) {
345 | cc.mu.Lock()
346 | defer cc.mu.Unlock()
347 | cc.goAway = f
348 | }
349 |
350 | func (cc *clientConn) canTakeNewRequest() bool {
351 | cc.mu.Lock()
352 | defer cc.mu.Unlock()
353 | return cc.goAway == nil &&
354 | int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
355 | cc.nextStreamID < 2147483647
356 | }
357 |
358 | func (cc *clientConn) closeIfIdle() {
359 | cc.mu.Lock()
360 | if len(cc.streams) > 0 {
361 | cc.mu.Unlock()
362 | return
363 | }
364 | cc.closed = true
365 | // TODO: do clients send GOAWAY too? maybe? Just Close:
366 | cc.mu.Unlock()
367 |
368 | cc.tconn.Close()
369 | }
370 |
371 | func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
372 | cc.mu.Lock()
373 |
374 | if cc.closed {
375 | cc.mu.Unlock()
376 | return nil, errClientConnClosed
377 | }
378 |
379 | cs := cc.newStream()
380 | hasBody := false // TODO
381 |
382 | // we send: HEADERS[+CONTINUATION] + (DATA?)
383 | hdrs := cc.encodeHeaders(req)
384 | first := true
385 | for len(hdrs) > 0 {
386 | chunk := hdrs
387 | if len(chunk) > int(cc.maxFrameSize) {
388 | chunk = chunk[:cc.maxFrameSize]
389 | }
390 | hdrs = hdrs[len(chunk):]
391 | endHeaders := len(hdrs) == 0
392 | if first {
393 | cc.fr.WriteHeaders(http2.HeadersFrameParam{
394 | StreamID: cs.ID,
395 | BlockFragment: chunk,
396 | EndStream: !hasBody,
397 | EndHeaders: endHeaders,
398 | })
399 | first = false
400 | } else {
401 | cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
402 | }
403 | }
404 | cc.bw.Flush()
405 | werr := cc.werr
406 | cc.mu.Unlock()
407 |
408 | if hasBody {
409 | // TODO: write data. and it should probably be interleaved:
410 | // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
411 | }
412 |
413 | if werr != nil {
414 | return nil, werr
415 | }
416 |
417 | re := <-cs.resc
418 | if re.err != nil {
419 | return nil, re.err
420 | }
421 | res := re.res
422 | res.Request = req
423 | res.TLS = cc.tlsState
424 | return res, nil
425 | }
426 |
427 | // listen sends a request, gets the headers of the response, and then listens
428 | // for additional data frames.
429 | //
430 | // Data frames are then shunted onto the data channel.
431 | //
432 | // When the remote indicates that the transaction is complete, the channel is
433 | // closed.
434 | func (cc *clientConn) listen(req *http.Request) (*http.Response, *clientStream, error) {
435 | cc.mu.Lock()
436 |
437 | if cc.closed {
438 | cc.mu.Unlock()
439 | return nil, nil, errClientConnClosed
440 | }
441 |
442 | cs := cc.newStream()
443 | hasBody := false // TODO
444 |
445 | cs.dataToChan = true
446 | cs.data = make(chan []byte, 1)
447 |
448 | // we send: HEADERS[+CONTINUATION] + (DATA?)
449 | hdrs := cc.encodeHeaders(req)
450 | first := true
451 | for len(hdrs) > 0 {
452 | chunk := hdrs
453 | if len(chunk) > int(cc.maxFrameSize) {
454 | chunk = chunk[:cc.maxFrameSize]
455 | }
456 | hdrs = hdrs[len(chunk):]
457 | endHeaders := len(hdrs) == 0
458 | if first {
459 | cc.fr.WriteHeaders(http2.HeadersFrameParam{
460 | StreamID: cs.ID,
461 | BlockFragment: chunk,
462 | EndStream: !hasBody,
463 | EndHeaders: endHeaders,
464 | })
465 | first = false
466 | } else {
467 | cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
468 | }
469 | }
470 | cc.bw.Flush()
471 | werr := cc.werr
472 | cc.mu.Unlock()
473 |
474 | if hasBody {
475 | // TODO: write data. and it should probably be interleaved:
476 | // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
477 | }
478 |
479 | if werr != nil {
480 | return nil, nil, werr
481 | }
482 |
483 | re := <-cs.resc
484 | if re.err != nil {
485 | //return nil, re.err
486 | fmt.Printf("Error closing client stream: %s", re.err)
487 | }
488 | res := re.res
489 | res.Request = req
490 | res.TLS = cc.tlsState
491 | return res, cs, nil
492 | }
493 |
494 | // requires cc.mu be held.
495 | func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
496 | cc.hbuf.Reset()
497 |
498 | // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
499 | host := req.Host
500 | if host == "" {
501 | host = req.URL.Host
502 | }
503 |
504 | path := req.URL.Path
505 | if path == "" {
506 | path = "/"
507 | }
508 |
509 | cc.writeHeader(":authority", host) // probably not right for all sites
510 | cc.writeHeader(":method", req.Method)
511 | cc.writeHeader(":path", path)
512 | cc.writeHeader(":scheme", "https")
513 |
514 | for k, vv := range req.Header {
515 | lowKey := strings.ToLower(k)
516 | if lowKey == "host" {
517 | continue
518 | }
519 | for _, v := range vv {
520 | cc.writeHeader(lowKey, v)
521 | }
522 | }
523 | return cc.hbuf.Bytes()
524 | }
525 |
526 | func (cc *clientConn) writeHeader(name, value string) {
527 | log.Printf("sending %q = %q", name, value)
528 | cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
529 | }
530 |
// resAndError pairs a response with an error so that both can be delivered
// over a single channel (a stream's resc channel).
type resAndError struct {
	// res is the response delivered to the waiting caller.
	res *http.Response
	// err is the error delivered to the waiting caller, if any.
	err error
}
535 |
536 | // requires cc.mu be held.
537 | func (cc *clientConn) newStream() *clientStream {
538 | cs := &clientStream{
539 | ID: cc.nextStreamID,
540 | resc: make(chan resAndError, 1),
541 | cancel: make(chan bool, 1),
542 | }
543 | cc.nextStreamID += 2
544 | cc.streams[cs.ID] = cs
545 | return cs
546 | }
547 |
548 | func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
549 | cc.mu.Lock()
550 | defer cc.mu.Unlock()
551 | cs := cc.streams[id]
552 | if andRemove {
553 | delete(cc.streams, id)
554 | }
555 | return cs
556 | }
557 |
558 | func (cc *clientConn) cancelLoop() {
559 | for {
560 | // Find anything that needs to be removed and schedule it for removal.
561 | remove := map[uint32]*clientStream{}
562 | for streamID, cs := range cc.streams {
563 | if len(cs.cancel) > 0 {
564 | remove[streamID] = cs
565 | }
566 | }
567 | // Run removals.
568 | for streamID, cs := range remove {
569 | log.Printf("Canceling %d\n", streamID)
570 | if cs.dataToChan {
571 | close(cs.data)
572 | }
573 | cs.pw.Close()
574 | delete(cc.streams, streamID)
575 | }
576 | // Yield to GC.
577 | time.Sleep(3 * time.Millisecond)
578 | }
579 | }
580 |
581 | // runs in its own goroutine.
582 | func (cc *clientConn) readLoop() {
583 | defer cc.t.removeClientConn(cc)
584 | defer close(cc.readerDone)
585 |
586 | activeRes := map[uint32]*clientStream{} // keyed by streamID
587 | // Close any response bodies if the server closes prematurely.
588 | // TODO: also do this if we've written the headers but not
589 | // gotten a response yet.
590 | defer func() {
591 | err := cc.readerErr
592 | if err == io.EOF {
593 | err = io.ErrUnexpectedEOF
594 | }
595 | for _, cs := range activeRes {
596 | if cs.dataToChan {
597 | close(cs.data)
598 | }
599 | cs.pw.CloseWithError(err)
600 | }
601 | }()
602 |
603 | // continueStreamID is the stream ID we're waiting for
604 | // continuation frames for.
605 | var continueStreamID uint32
606 |
607 | for {
608 | f, err := cc.fr.ReadFrame()
609 | if err != nil {
610 | cc.readerErr = err
611 | return
612 | }
613 | log.Printf("Transport received %v: %#v", f.Header(), f)
614 |
615 | streamID := f.Header().StreamID
616 |
617 | _, isContinue := f.(*http2.ContinuationFrame)
618 | if isContinue {
619 | if streamID != continueStreamID {
620 | log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
621 | cc.readerErr = http2.ConnectionError(http2.ErrCodeProtocol)
622 | return
623 | }
624 | } else if continueStreamID != 0 {
625 | // Continue frames need to be adjacent in the stream
626 | // and we were in the middle of headers.
627 | log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
628 | cc.readerErr = http2.ConnectionError(http2.ErrCodeProtocol)
629 | return
630 | }
631 |
632 | if streamID%2 == 0 {
633 | // Ignore streams pushed from the server for now.
634 | // These always have an even stream id.
635 | continue
636 | }
637 | streamEnded := false
638 | if ff, ok := f.(streamEnder); ok {
639 | streamEnded = ff.StreamEnded()
640 | }
641 |
642 | cs := cc.streamByID(streamID, streamEnded)
643 | if cs == nil {
644 | log.Printf("Received frame for untracked stream ID %d", streamID)
645 | continue
646 | }
647 |
648 | switch f := f.(type) {
649 | case *http2.HeadersFrame:
650 | cc.nextRes = &http.Response{
651 | Proto: "HTTP/2.0",
652 | ProtoMajor: 2,
653 | Header: make(http.Header),
654 | }
655 | cs.pr, cs.pw = io.Pipe()
656 | cc.hdec.Write(f.HeaderBlockFragment())
657 | case *http2.ContinuationFrame:
658 | cc.hdec.Write(f.HeaderBlockFragment())
659 | case *http2.DataFrame:
660 | log.Printf("DATA: %q", f.Data())
661 | if cs.dataToChan {
662 | cs.data <- f.Data()
663 | } else {
664 | cs.pw.Write(f.Data())
665 | }
666 | case *http2.GoAwayFrame:
667 | cc.t.removeClientConn(cc)
668 | if f.ErrCode != 0 {
669 | // TODO: deal with GOAWAY more. particularly the error code
670 | log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
671 | }
672 | cc.setGoAway(f)
673 | default:
674 | log.Printf("Transport: unhandled response frame type %T", f)
675 | }
676 | headersEnded := false
677 | if he, ok := f.(headersEnder); ok {
678 | headersEnded = he.HeadersEnded()
679 | if headersEnded {
680 | continueStreamID = 0
681 | } else {
682 | continueStreamID = streamID
683 | }
684 | }
685 |
686 | if streamEnded {
687 | if cs.dataToChan {
688 | close(cs.data)
689 | }
690 | cs.pw.Close()
691 | delete(activeRes, streamID)
692 | }
693 | if headersEnded {
694 | if cs == nil {
695 | panic("couldn't find stream") // TODO be graceful
696 | }
697 | // TODO: set the Body to one which notes the
698 | // Close and also sends the server a
699 | // RST_STREAM
700 | cc.nextRes.Body = cs.pr
701 | res := cc.nextRes
702 | activeRes[streamID] = cs
703 | cs.resc <- resAndError{res: res}
704 | }
705 | }
706 | }
707 |
708 | func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
709 | // TODO: verifiy pseudo headers come before non-pseudo headers
710 | // TODO: verifiy the status is set
711 | log.Printf("Header field: %+v", f)
712 | if f.Name == ":status" {
713 | code, err := strconv.Atoi(f.Value)
714 | if err != nil {
715 | panic("TODO: be graceful")
716 | }
717 | cc.nextRes.Status = f.Value + " " + http.StatusText(code)
718 | cc.nextRes.StatusCode = code
719 | return
720 | }
721 | if strings.HasPrefix(f.Name, ":") {
722 | // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
723 | // TODO: treat as invalid?
724 | return
725 | }
726 | cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
727 | }
728 |
--------------------------------------------------------------------------------