├── README.md
├── client_test.go
├── conn.go
├── middleware
│   └── middleware.go
├── proxy.go
└── proxy_test.go

/README.md:
--------------------------------------------------------------------------------
# kafka-proxy

This is a layer 7/application level proxy for Kafka with pluggable middleware support. You can use
it to change client requests before they hit the broker, change responses from the broker before
they are sent to the client, or pass requests/responses through unchanged while performing side
effects such as tracking metrics or writing logs.

## License

MIT

---

- [travisjeffery.com](http://travisjeffery.com)
- GitHub [@travisjeffery](https://github.com/travisjeffery)
- Twitter [@travisjeffery](https://twitter.com/travisjeffery)
- Medium [@travisjeffery](https://medium.com/@travisjeffery)

--------------------------------------------------------------------------------
/client_test.go:
--------------------------------------------------------------------------------
// Code generated by mocker; DO NOT EDIT.
// github.com/travisjeffery/mocker
//
// Regenerate with the //go:generate directive in proxy_test.go rather than
// editing by hand.
package kafkaproxy_test

import (
	"context"
	"github.com/travisjeffery/jocko/protocol"
	"sync"
)

var (
	// lockMockClientRun guards the recorded Run calls (calls.Run). It is a
	// package-level lock, so every MockClient value shares it.
	lockMockClientRun sync.RWMutex
)

// MockClient is a mock implementation of Client.
//
//     func TestSomethingThatUsesClient(t *testing.T) {
//
//         // make and configure a mocked Client
//         mockedClient := &MockClient{
//             RunFunc: func(ctx context.Context, req *protocol.Request) (*protocol.Response, error) {
//                    panic("TODO: mock out the Run method")
//             },
//         }
//
//         // TODO: use mockedClient in code that requires Client
//         //       and then make assertions.
//
//     }
type MockClient struct {
	// RunFunc mocks the Run method.
	RunFunc func(ctx context.Context, req *protocol.Request) (*protocol.Response, error)

	// calls tracks calls to the methods.
	calls struct {
		// Run holds details about calls to the Run method.
		Run []struct {
			// Ctx is the ctx argument value.
			Ctx context.Context
			// Req is the req argument value.
			Req *protocol.Request
		}
	}
}

// Reset resets the calls made to the mocked APIs.
func (mock *MockClient) Reset() {
	lockMockClientRun.Lock()
	mock.calls.Run = nil
	lockMockClientRun.Unlock()
}

// Run records the call (ctx and req) and then delegates to RunFunc,
// returning whatever it returns. It panics if RunFunc is nil.
func (mock *MockClient) Run(ctx context.Context, req *protocol.Request) (*protocol.Response, error) {
	if mock.RunFunc == nil {
		panic("moq: MockClient.RunFunc is nil but Client.Run was just called")
	}
	callInfo := struct {
		Ctx context.Context
		Req *protocol.Request
	}{
		Ctx: ctx,
		Req: req,
	}
	lockMockClientRun.Lock()
	mock.calls.Run = append(mock.calls.Run, callInfo)
	lockMockClientRun.Unlock()
	return mock.RunFunc(ctx, req)
}

// RunCalled returns true if at least one call was made to Run.
func (mock *MockClient) RunCalled() bool {
	lockMockClientRun.RLock()
	defer lockMockClientRun.RUnlock()
	return len(mock.calls.Run) > 0
}

// RunCalls gets all the calls that were made to Run.
79 | // Check the length with: 80 | // len(mockedClient.RunCalls()) 81 | func (mock *MockClient) RunCalls() []struct { 82 | Ctx context.Context 83 | Req *protocol.Request 84 | } { 85 | var calls []struct { 86 | Ctx context.Context 87 | Req *protocol.Request 88 | } 89 | lockMockClientRun.RLock() 90 | calls = mock.calls.Run 91 | lockMockClientRun.RUnlock() 92 | return calls 93 | } 94 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package kafkaproxy 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/travisjeffery/jocko/jocko" 7 | "github.com/travisjeffery/jocko/protocol" 8 | ) 9 | 10 | type Conn struct { 11 | *jocko.Conn 12 | } 13 | 14 | func (c *Conn) Run(ctx context.Context, request *protocol.Request) (*protocol.Response, error) { 15 | response := &protocol.Response{ 16 | CorrelationID: request.CorrelationID, 17 | } 18 | var err error 19 | switch req := request.Body.(type) { 20 | case *protocol.ProduceRequest: 21 | response.Body, err = c.Produce(req) 22 | case *protocol.FetchRequest: 23 | response.Body, err = c.Fetch(req) 24 | case *protocol.OffsetsRequest: 25 | response.Body, err = c.Offsets(req) 26 | case *protocol.MetadataRequest: 27 | response.Body, err = c.Metadata(req) 28 | case *protocol.LeaderAndISRRequest: 29 | response.Body, err = c.LeaderAndISR(req) 30 | case *protocol.StopReplicaRequest: 31 | response.Body, err = c.StopReplica(req) 32 | case *protocol.UpdateMetadataRequest: 33 | response.Body, err = c.UpdateMetadata(req) 34 | case *protocol.ControlledShutdownRequest: 35 | response.Body, err = c.ControlledShutdown(req) 36 | case *protocol.OffsetCommitRequest: 37 | response.Body, err = c.OffsetCommit(req) 38 | case *protocol.OffsetFetchRequest: 39 | response.Body, err = c.OffsetFetch(req) 40 | case *protocol.FindCoordinatorRequest: 41 | response.Body, err = c.FindCoordinator(req) 42 | case *protocol.JoinGroupRequest: 43 | 
response.Body, err = c.JoinGroup(req) 44 | case *protocol.HeartbeatRequest: 45 | response.Body, err = c.Heartbeat(req) 46 | case *protocol.LeaveGroupRequest: 47 | response.Body, err = c.LeaveGroup(req) 48 | case *protocol.SyncGroupRequest: 49 | response.Body, err = c.SyncGroup(req) 50 | case *protocol.DescribeGroupsRequest: 51 | response.Body, err = c.DescribeGroups(req) 52 | case *protocol.ListGroupsRequest: 53 | response.Body, err = c.ListGroups(req) 54 | case *protocol.SaslHandshakeRequest: 55 | response.Body, err = c.SaslHandshake(req) 56 | case *protocol.APIVersionsRequest: 57 | response.Body, err = c.APIVersions(req) 58 | case *protocol.CreateTopicRequests: 59 | response.Body, err = c.CreateTopics(req) 60 | case *protocol.DeleteTopicsRequest: 61 | response.Body, err = c.DeleteTopics(req) 62 | } 63 | return response, err 64 | } 65 | -------------------------------------------------------------------------------- /middleware/middleware.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "time" 7 | 8 | "github.com/travisjeffery/jocko/protocol" 9 | kafkaproxy "github.com/travisjeffery/kafka-proxy" 10 | ) 11 | 12 | // Log is middleware that logs the request's api key and its duration. 
13 | func Log() kafkaproxy.Middleware { 14 | return func(next kafkaproxy.Endpoint) kafkaproxy.Endpoint { 15 | return func(ctx context.Context, request *protocol.Request) (*protocol.Response, error) { 16 | t := time.Now() 17 | res, err := next(ctx, request) 18 | req := request.Body 19 | log.Printf("api key: %d, duration: %v\n", req.Key(), time.Since(t)) 20 | return res, err 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | package kafkaproxy 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "sync" 7 | 8 | "gopkg.in/bufio.v1" 9 | 10 | "github.com/travisjeffery/jocko/protocol" 11 | ) 12 | 13 | // SizeLen is the number of bytes that make up the size field of the requet. 14 | const SizeLen = 4 15 | 16 | // Endpoint handles a API request and returns a response. 17 | type Endpoint func(ctx context.Context, request *protocol.Request) (response *protocol.Response, err error) 18 | 19 | // Middleware is a chainable behavior modifier for endpoints. 20 | type Middleware func(Endpoint) Endpoint 21 | 22 | // Client represents a Kafka client. 23 | type Client interface { 24 | Run(ctx context.Context, req *protocol.Request) (*protocol.Response, error) 25 | } 26 | 27 | // Proxy is a layer 7/application level proxy for Kafka with pluggable middleware support. You can 28 | // change client requests before being sent to the broker, or change responses from the broker 29 | // before being sent to the client. 30 | type Proxy struct { 31 | sync.Mutex 32 | err error 33 | donec chan struct{} 34 | ipPort string 35 | endpoint Endpoint 36 | ln net.Listener 37 | 38 | ListenFunc func(net, laddr string) (net.Listener, error) 39 | } 40 | 41 | // New creates a new kafka proxy that proxies requests made on the conn to the brokers connected to 42 | // the client. 
43 | func New(ipPort string, c Client) *Proxy { 44 | return &Proxy{ 45 | ipPort: ipPort, 46 | donec: make(chan struct{}), 47 | ListenFunc: net.Listen, 48 | endpoint: func(ctx context.Context, request *protocol.Request) (*protocol.Response, error) { 49 | return c.Run(ctx, request) 50 | }, 51 | } 52 | } 53 | 54 | // With adds a new middleware to handle requests and responses. 55 | func (p *Proxy) With(me Middleware) { 56 | p.Lock() 57 | defer p.Unlock() 58 | p.endpoint = me(p.endpoint) 59 | } 60 | 61 | // Run starts the listener and request proxy. 62 | func (p *Proxy) Run(ctx context.Context) (err error) { 63 | p.ln, err = p.netListen()("tcp", p.ipPort) 64 | if err != nil { 65 | return err 66 | } 67 | errc := make(chan error, 1) 68 | go p.serveListener(errc, p.ln) 69 | go p.awaitFirstError(errc) 70 | return nil 71 | } 72 | 73 | // Wait waits for the Proxy to finish running. Currently this can only happen if a Listener is closed, or Close is already called on the proxy. 74 | // 75 | // It is only valid to call Wait after a successful call to Run. 76 | func (p *Proxy) Wait() error { 77 | close(p.donec) 78 | return p.err 79 | } 80 | 81 | // Close closes the proxy's listener. 82 | func (p *Proxy) Close() error { 83 | return p.ln.Close() 84 | } 85 | 86 | func (p *Proxy) serveListener(errc chan error, ln net.Listener) { 87 | for { 88 | c, err := ln.Accept() 89 | if err != nil { 90 | errc <- err 91 | return 92 | } 93 | go p.serveConn(errc, c) 94 | } 95 | } 96 | 97 | func (p *Proxy) serveConn(errc chan error, c net.Conn) { 98 | ctx := context.Background() 99 | br := bufio.NewReader(c) 100 | 101 | for { 102 | sizePeek, err := br.Peek(SizeLen) 103 | if err != nil { 104 | errc <- err 105 | return 106 | } 107 | 108 | reqSize := SizeLen + int(protocol.Encoding.Uint32(sizePeek)) 109 | 110 | var reqb []byte 111 | for { 112 | b, err := br.ReadN(reqSize) 113 | if err == bufio.ErrBufferFull && reqSize != br.Buffered() { 114 | reqb = append(reqb, b...) 
115 | continue 116 | } 117 | if err != nil { 118 | errc <- err 119 | return 120 | } 121 | reqb = append(reqb, b...) 122 | break 123 | } 124 | 125 | // TODO: a decoder that takes a bufio.Reader would be nice 126 | d := protocol.NewDecoder(reqb) 127 | 128 | header := new(protocol.RequestHeader) 129 | if err := header.Decode(d); err != nil { 130 | errc <- err 131 | return 132 | } 133 | 134 | var req protocol.VersionedDecoder 135 | 136 | switch header.APIKey { 137 | case protocol.ProduceKey: 138 | req = &protocol.ProduceRequest{} 139 | case protocol.FetchKey: 140 | req = &protocol.FetchRequest{} 141 | case protocol.OffsetsKey: 142 | req = &protocol.OffsetsRequest{} 143 | case protocol.MetadataKey: 144 | req = &protocol.MetadataRequest{} 145 | case protocol.LeaderAndISRKey: 146 | req = &protocol.LeaderAndISRRequest{} 147 | case protocol.StopReplicaKey: 148 | req = &protocol.StopReplicaRequest{} 149 | case protocol.UpdateMetadataKey: 150 | req = &protocol.UpdateMetadataRequest{} 151 | case protocol.ControlledShutdownKey: 152 | req = &protocol.ControlledShutdownRequest{} 153 | case protocol.OffsetCommitKey: 154 | req = &protocol.OffsetCommitRequest{} 155 | case protocol.OffsetFetchKey: 156 | req = &protocol.OffsetFetchRequest{} 157 | case protocol.FindCoordinatorKey: 158 | req = &protocol.FindCoordinatorRequest{} 159 | case protocol.JoinGroupKey: 160 | req = &protocol.JoinGroupRequest{} 161 | case protocol.HeartbeatKey: 162 | req = &protocol.HeartbeatRequest{} 163 | case protocol.LeaveGroupKey: 164 | req = &protocol.LeaveGroupRequest{} 165 | case protocol.SyncGroupKey: 166 | req = &protocol.SyncGroupRequest{} 167 | case protocol.DescribeGroupsKey: 168 | req = &protocol.DescribeGroupsRequest{} 169 | case protocol.ListGroupsKey: 170 | req = &protocol.ListGroupsRequest{} 171 | case protocol.SaslHandshakeKey: 172 | req = &protocol.SaslHandshakeRequest{} 173 | case protocol.APIVersionsKey: 174 | req = &protocol.APIVersionsRequest{} 175 | case protocol.CreateTopicsKey: 176 | 
req = &protocol.CreateTopicRequests{} 177 | case protocol.DeleteTopicsKey: 178 | req = &protocol.DeleteTopicsRequest{} 179 | } 180 | 181 | if err := req.Decode(d, header.APIVersion); err != nil { 182 | errc <- err 183 | return 184 | } 185 | 186 | res, err := p.endpoint(ctx, &protocol.Request{ 187 | CorrelationID: header.CorrelationID, 188 | ClientID: header.ClientID, 189 | Body: req.(protocol.Body), 190 | }) 191 | if err != nil { 192 | errc <- err 193 | return 194 | } 195 | 196 | resb, err := protocol.Encode(res) 197 | if err != nil { 198 | errc <- err 199 | return 200 | } 201 | 202 | _, err = c.Write(resb) 203 | if err != nil { 204 | errc <- err 205 | return 206 | } 207 | } 208 | } 209 | 210 | func (p *Proxy) awaitFirstError(errc chan error) { 211 | p.err = <-errc 212 | close(p.donec) 213 | } 214 | 215 | func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { 216 | if p.ListenFunc != nil { 217 | return p.ListenFunc 218 | } 219 | return net.Listen 220 | } 221 | -------------------------------------------------------------------------------- /proxy_test.go: -------------------------------------------------------------------------------- 1 | package kafkaproxy_test 2 | 3 | //go:generate mocker --out client_test.go --pkg kafkaproxy_test . Client 4 | 5 | import ( 6 | "bufio" 7 | "context" 8 | "errors" 9 | "flag" 10 | "net" 11 | "reflect" 12 | "testing" 13 | "time" 14 | 15 | "github.com/travisjeffery/jocko/jocko" 16 | "github.com/travisjeffery/jocko/protocol" 17 | kafkaproxy "github.com/travisjeffery/kafka-proxy" 18 | "github.com/travisjeffery/kafka-proxy/middleware" 19 | ) 20 | 21 | func TestProxy(t *testing.T) { 22 | ctx, cancel := context.WithCancel(context.Background()) 23 | defer cancel() 24 | 25 | for _, test := range []struct { 26 | // name of the test. 27 | name string 28 | // request sent from client, pre mw. 29 | clientReq *protocol.Request 30 | // response sent back to client, post mw. 
31 | clientRes *protocol.Response 32 | // request server sees, post mw. 33 | serverReq *protocol.Request 34 | // response server sends back, pre mw. 35 | serverRes *protocol.Response 36 | // mw to setup the proxy with. 37 | mw []kafkaproxy.Middleware 38 | }{ 39 | { 40 | name: "passthrough with log", 41 | mw: []kafkaproxy.Middleware{middleware.Log()}, 42 | clientReq: &protocol.Request{ 43 | CorrelationID: 1, 44 | ClientID: "proxytest", 45 | Body: &protocol.ProduceRequest{ 46 | APIVersion: 1, 47 | Timeout: time.Second, 48 | TopicData: []*protocol.TopicData{}, 49 | }, 50 | }, 51 | serverReq: &protocol.Request{ 52 | CorrelationID: 1, 53 | ClientID: "proxytest", 54 | Body: &protocol.ProduceRequest{ 55 | APIVersion: 1, 56 | Timeout: time.Second, 57 | TopicData: []*protocol.TopicData{}, 58 | }, 59 | }, 60 | serverRes: &protocol.Response{ 61 | Body: &protocol.ProduceResponse{ 62 | APIVersion: 1, 63 | ThrottleTime: time.Second, 64 | Responses: []*protocol.ProduceTopicResponse{}, 65 | }, 66 | }, 67 | clientRes: &protocol.Response{ 68 | Body: &protocol.ProduceResponse{ 69 | APIVersion: 1, 70 | ThrottleTime: time.Second, 71 | Responses: []*protocol.ProduceTopicResponse{}, 72 | }, 73 | }, 74 | }, 75 | { 76 | name: "modify", 77 | mw: []kafkaproxy.Middleware{timeoutmw(time.Second * 3)}, 78 | clientReq: &protocol.Request{ 79 | CorrelationID: 1, 80 | ClientID: "proxytest", 81 | Body: &protocol.ProduceRequest{ 82 | APIVersion: 1, 83 | Timeout: time.Second, 84 | TopicData: []*protocol.TopicData{}, 85 | }, 86 | }, 87 | serverReq: &protocol.Request{ 88 | CorrelationID: 1, 89 | ClientID: "proxytest", 90 | Body: &protocol.ProduceRequest{ 91 | APIVersion: 1, 92 | Timeout: 3 * time.Second, 93 | TopicData: []*protocol.TopicData{}, 94 | }, 95 | }, 96 | serverRes: &protocol.Response{ 97 | Body: &protocol.ProduceResponse{ 98 | APIVersion: 1, 99 | ThrottleTime: time.Second, 100 | Responses: []*protocol.ProduceTopicResponse{}, 101 | }, 102 | }, 103 | clientRes: &protocol.Response{ 104 | Body: 
&protocol.ProduceResponse{ 105 | APIVersion: 1, 106 | ThrottleTime: time.Second * 3, 107 | Responses: []*protocol.ProduceTopicResponse{}, 108 | }, 109 | }, 110 | }, 111 | } { 112 | t.Run(test.name, func(t *testing.T) { 113 | c := &MockClient{ 114 | RunFunc: func(ctx context.Context, got *protocol.Request) (*protocol.Response, error) { 115 | if !reflect.DeepEqual(test.serverReq, got) { 116 | t.Fatalf("requests don't match: got: %v, want: %v", got, test.clientReq.Body) 117 | } 118 | return test.serverRes, nil 119 | }, 120 | } 121 | 122 | ln := newLocalListener(t) 123 | p := kafkaproxy.New(":0", c) 124 | p.ListenFunc = testListenFunc(t, ln) 125 | 126 | for _, mw := range test.mw { 127 | p.With(mw) 128 | } 129 | 130 | go func() { 131 | if err := p.Run(ctx); err != nil { 132 | t.Fatal(err) 133 | } 134 | }() 135 | 136 | client := testConn(t, ln) 137 | 138 | b, err := protocol.Encode(test.clientReq) 139 | if err != nil { 140 | t.Fatal(err) 141 | } 142 | _, err = client.Write(b) 143 | if err != nil { 144 | t.Fatal(err) 145 | } 146 | 147 | time.Sleep(100 * time.Millisecond) 148 | 149 | if !c.RunCalled() { 150 | t.Fatalf("run not called") 151 | } 152 | 153 | br := bufio.NewReader(client) 154 | sizePeek, err := br.Peek(kafkaproxy.SizeLen) 155 | if err != nil { 156 | t.Fatal(err) 157 | } 158 | size := int(protocol.Encoding.Uint32(sizePeek)) 159 | resPeek, err := br.Peek(kafkaproxy.SizeLen + size) 160 | if err != nil { 161 | t.Fatal(err) 162 | } 163 | d := protocol.NewDecoder(resPeek) 164 | clientRes := &protocol.Response{ 165 | Body: &protocol.ProduceResponse{}, 166 | } 167 | if err = clientRes.Decode(d, test.clientReq.Body.Version()); err != nil { 168 | t.Fatal(err) 169 | } 170 | if !reflect.DeepEqual(test.clientRes, clientRes) { 171 | t.Fatalf("client response: got: %v, want: %v", clientRes, test.clientRes) 172 | } 173 | }) 174 | } 175 | } 176 | 177 | func TestProxyLive(t *testing.T) { 178 | addr := flag.String("addr", "", "broker addr") 179 | flag.Parse() 180 | if *addr == 
"" { 181 | t.Skip() 182 | } 183 | 184 | ctx := context.Background() 185 | conn, err := jocko.Dial("tcp", *addr) 186 | if err != nil { 187 | t.Fatal(err) 188 | } 189 | c := &kafkaproxy.Conn{conn} 190 | 191 | for _, test := range []struct { 192 | // name of the test. 193 | name string 194 | // request sent from client, pre mw. 195 | clientReq *protocol.Request 196 | // response sent back to client, post mw. 197 | clientRes *protocol.Response 198 | // request server sees, post mw. 199 | serverReq *protocol.Request 200 | // response server sends back, pre mw. 201 | serverRes *protocol.Response 202 | // mw to setup the proxy with. 203 | mw []kafkaproxy.Middleware 204 | }{ 205 | { 206 | name: "passthrough with log", 207 | mw: []kafkaproxy.Middleware{middleware.Log()}, 208 | clientReq: &protocol.Request{ 209 | CorrelationID: 1, 210 | ClientID: "proxytest", 211 | Body: &protocol.ProduceRequest{ 212 | APIVersion: 1, 213 | Timeout: time.Second, 214 | TopicData: []*protocol.TopicData{}, 215 | }, 216 | }, 217 | serverReq: &protocol.Request{ 218 | CorrelationID: 1, 219 | ClientID: "proxytest", 220 | Body: &protocol.ProduceRequest{ 221 | APIVersion: 1, 222 | Timeout: time.Second, 223 | TopicData: []*protocol.TopicData{}, 224 | }, 225 | }, 226 | serverRes: &protocol.Response{ 227 | Body: &protocol.ProduceResponse{ 228 | APIVersion: 1, 229 | ThrottleTime: time.Second, 230 | Responses: []*protocol.ProduceTopicResponse{}, 231 | }, 232 | }, 233 | clientRes: &protocol.Response{ 234 | Body: &protocol.ProduceResponse{ 235 | APIVersion: 1, 236 | ThrottleTime: time.Second, 237 | Responses: []*protocol.ProduceTopicResponse{}, 238 | }, 239 | }, 240 | }, 241 | { 242 | name: "modify", 243 | mw: []kafkaproxy.Middleware{timeoutmw(time.Second * 3)}, 244 | clientReq: &protocol.Request{ 245 | CorrelationID: 1, 246 | ClientID: "proxytest", 247 | Body: &protocol.ProduceRequest{ 248 | APIVersion: 1, 249 | Timeout: time.Second, 250 | TopicData: []*protocol.TopicData{}, 251 | }, 252 | }, 253 | 
serverReq: &protocol.Request{ 254 | CorrelationID: 1, 255 | ClientID: "proxytest", 256 | Body: &protocol.ProduceRequest{ 257 | APIVersion: 1, 258 | Timeout: 3 * time.Second, 259 | TopicData: []*protocol.TopicData{}, 260 | }, 261 | }, 262 | serverRes: &protocol.Response{ 263 | Body: &protocol.ProduceResponse{ 264 | APIVersion: 1, 265 | ThrottleTime: time.Second, 266 | Responses: []*protocol.ProduceTopicResponse{}, 267 | }, 268 | }, 269 | clientRes: &protocol.Response{ 270 | Body: &protocol.ProduceResponse{ 271 | APIVersion: 1, 272 | ThrottleTime: time.Second * 3, 273 | Responses: []*protocol.ProduceTopicResponse{}, 274 | }, 275 | }, 276 | }, 277 | } { 278 | t.Run(test.name, func(t *testing.T) { 279 | ln := newLocalListener(t) 280 | p := kafkaproxy.New(":0", c) 281 | p.ListenFunc = testListenFunc(t, ln) 282 | 283 | for _, mw := range test.mw { 284 | p.With(mw) 285 | } 286 | 287 | go func() { 288 | if err := p.Run(ctx); err != nil { 289 | t.Fatal(err) 290 | } 291 | }() 292 | 293 | client := testConn(t, ln) 294 | 295 | b, err := protocol.Encode(test.clientReq) 296 | if err != nil { 297 | t.Fatal(err) 298 | } 299 | _, err = client.Write(b) 300 | if err != nil { 301 | t.Fatal(err) 302 | } 303 | 304 | time.Sleep(100 * time.Millisecond) 305 | 306 | br := bufio.NewReader(client) 307 | sizePeek, err := br.Peek(kafkaproxy.SizeLen) 308 | if err != nil { 309 | t.Fatal(err) 310 | } 311 | size := int(protocol.Encoding.Uint32(sizePeek)) 312 | resPeek, err := br.Peek(kafkaproxy.SizeLen + size) 313 | if err != nil { 314 | t.Fatal(err) 315 | } 316 | d := protocol.NewDecoder(resPeek) 317 | clientRes := &protocol.Response{ 318 | Body: &protocol.ProduceResponse{}, 319 | } 320 | if err = clientRes.Decode(d, test.clientReq.Body.Version()); err != nil { 321 | t.Fatal(err) 322 | } 323 | if !reflect.DeepEqual(test.clientRes, clientRes) { 324 | t.Fatalf("client response: got: %v, want: %v", clientRes, test.clientRes) 325 | } 326 | }) 327 | } 328 | } 329 | 330 | func newLocalListener(t 
*testing.T) net.Listener { 331 | ln, err := net.Listen("tcp", "127.0.0.1:0") 332 | if err != nil { 333 | ln, err = net.Listen("tcp", "[::1]:0") 334 | if err != nil { 335 | t.Fatal(err) 336 | } 337 | } 338 | return ln 339 | } 340 | 341 | func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) { 342 | return func(network, laddr string) (net.Listener, error) { 343 | if network != "tcp" { 344 | t.Errorf("got Listen call with network %q, not tcp", network) 345 | return nil, errors.New("invalid network") 346 | } 347 | return ln, nil 348 | } 349 | } 350 | 351 | func testConn(t *testing.T, ln net.Listener) net.Conn { 352 | t.Helper() 353 | client, err := net.Dial("tcp", ln.Addr().String()) 354 | if err != nil { 355 | t.Fatal(err) 356 | } 357 | return client 358 | } 359 | 360 | func timeoutmw(timeout time.Duration) kafkaproxy.Middleware { 361 | return func(next kafkaproxy.Endpoint) kafkaproxy.Endpoint { 362 | return func(ctx context.Context, request *protocol.Request) (*protocol.Response, error) { 363 | req := request.Body.(*protocol.ProduceRequest) 364 | req.Timeout = timeout 365 | response, err := next(ctx, request) 366 | res := response.Body.(*protocol.ProduceResponse) 367 | res.ThrottleTime = timeout 368 | return response, err 369 | } 370 | } 371 | } 372 | --------------------------------------------------------------------------------