├── LICENSE ├── README.md ├── _examples └── cmd │ └── wsechoserver │ ├── echoserver │ ├── Makefile │ ├── echoserver.pb.go │ ├── echoserver.pb.gw.go │ ├── echoserver.proto │ └── gen.go │ ├── main.go │ └── server.go ├── go.mod ├── go.sum └── wsproxy ├── doc.go └── websocket_proxy.go /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2016 Travis Cline 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # grpc-websocket-proxy 2 | 3 | [![GoDoc](https://godoc.org/github.com/tmc/grpc-websocket-proxy/wsproxy?status.svg)](http://godoc.org/github.com/tmc/grpc-websocket-proxy/wsproxy) 4 | 5 | Wrap your grpc-gateway mux with this helper to expose streaming endpoints over websockets. 6 | 7 | On the wire this uses newline-delimited json encoding of the messages. 
8 | 9 | Usage: 10 | ```diff 11 | mux := runtime.NewServeMux() 12 | opts := []grpc.DialOption{grpc.WithInsecure()} 13 | if err := echoserver.RegisterEchoServiceHandlerFromEndpoint(ctx, mux, *grpcAddr, opts); err != nil { 14 | return err 15 | } 16 | - http.ListenAndServe(*httpAddr, mux) 17 | + http.ListenAndServe(*httpAddr, wsproxy.WebsocketProxy(mux)) 18 | ``` 19 | 20 | 21 | # wsproxy 22 | import "github.com/tmc/grpc-websocket-proxy/wsproxy" 23 | 24 | Package wsproxy implements a websocket proxy for grpc-gateway backed services 25 | 26 | ## Usage 27 | 28 | ```go 29 | var ( 30 | MethodOverrideParam = "method" 31 | TokenCookieName = "token" 32 | ) 33 | ``` 34 | 35 | #### func WebsocketProxy 36 | 37 | ```go 38 | func WebsocketProxy(h http.Handler) http.HandlerFunc 39 | ``` 40 | WebsocketProxy attempts to expose the underlying handler as a bidi websocket 41 | stream with newline-delimited JSON as the content encoding. 42 | 43 | The HTTP Authorization header is either populated from the 44 | Sec-Websocket-Protocol field or by a cookie. The cookie name is specified by the 45 | TokenCookieName value. 46 | 47 | example: 48 | 49 | Sec-Websocket-Protocol: Bearer, foobar 50 | 51 | is converted to: 52 | 53 | Authorization: Bearer foobar 54 | 55 | The HTTP method can be overridden with the MethodOverrideParam query parameter 56 | in the requested URL. 57 | -------------------------------------------------------------------------------- /_examples/cmd/wsechoserver/echoserver/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | protoc -I/usr/local/include -I. \ 3 | -I${GOPATH}/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ 4 | --go_out=plugins=grpc:. \ 5 | --grpc-gateway_out=logtostderr=true:.
\ 6 | echoserver.proto 7 | -------------------------------------------------------------------------------- /_examples/cmd/wsechoserver/echoserver/echoserver.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. 2 | // source: echoserver.proto 3 | // DO NOT EDIT! 4 | 5 | /* 6 | Package echoserver is a generated protocol buffer package. 7 | 8 | It is generated from these files: 9 | echoserver.proto 10 | 11 | It has these top-level messages: 12 | EchoRequest 13 | EchoResponse 14 | Heartbeat 15 | Empty 16 | */ 17 | package echoserver 18 | 19 | import proto "github.com/golang/protobuf/proto" 20 | import fmt "fmt" 21 | import math "math" 22 | import _ "google.golang.org/genproto/googleapis/api/annotations" 23 | 24 | import ( 25 | context "golang.org/x/net/context" 26 | grpc "google.golang.org/grpc" 27 | ) 28 | 29 | // Reference imports to suppress errors if they are not otherwise used. 30 | var _ = proto.Marshal 31 | var _ = fmt.Errorf 32 | var _ = math.Inf 33 | 34 | // This is a compile-time assertion to ensure that this generated file 35 | // is compatible with the proto package it is being compiled against. 36 | // A compilation error at this line likely means your copy of the 37 | // proto package needs to be updated. 
38 | const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package 39 | 40 | type Heartbeat_Status int32 41 | 42 | const ( 43 | Heartbeat_UNKNOWN Heartbeat_Status = 0 44 | Heartbeat_OK Heartbeat_Status = 1 45 | ) 46 | 47 | var Heartbeat_Status_name = map[int32]string{ 48 | 0: "UNKNOWN", 49 | 1: "OK", 50 | } 51 | var Heartbeat_Status_value = map[string]int32{ 52 | "UNKNOWN": 0, 53 | "OK": 1, 54 | } 55 | 56 | func (x Heartbeat_Status) String() string { 57 | return proto.EnumName(Heartbeat_Status_name, int32(x)) 58 | } 59 | func (Heartbeat_Status) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{2, 0} } 60 | 61 | type EchoRequest struct { 62 | Message string `protobuf:"bytes,1,opt,name=message" json:"message,omitempty"` 63 | } 64 | 65 | func (m *EchoRequest) Reset() { *m = EchoRequest{} } 66 | func (m *EchoRequest) String() string { return proto.CompactTextString(m) } 67 | func (*EchoRequest) ProtoMessage() {} 68 | func (*EchoRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } 69 | 70 | func (m *EchoRequest) GetMessage() string { 71 | if m != nil { 72 | return m.Message 73 | } 74 | return "" 75 | } 76 | 77 | type EchoResponse struct { 78 | Message string `protobuf:"bytes,1,opt,name=message" json:"message,omitempty"` 79 | } 80 | 81 | func (m *EchoResponse) Reset() { *m = EchoResponse{} } 82 | func (m *EchoResponse) String() string { return proto.CompactTextString(m) } 83 | func (*EchoResponse) ProtoMessage() {} 84 | func (*EchoResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } 85 | 86 | func (m *EchoResponse) GetMessage() string { 87 | if m != nil { 88 | return m.Message 89 | } 90 | return "" 91 | } 92 | 93 | type Heartbeat struct { 94 | Status Heartbeat_Status `protobuf:"varint,1,opt,name=status,enum=echoserver.Heartbeat_Status" json:"status,omitempty"` 95 | } 96 | 97 | func (m *Heartbeat) Reset() { *m = Heartbeat{} } 98 | func (m *Heartbeat) String() string { return 
proto.CompactTextString(m) } 99 | func (*Heartbeat) ProtoMessage() {} 100 | func (*Heartbeat) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } 101 | 102 | func (m *Heartbeat) GetStatus() Heartbeat_Status { 103 | if m != nil { 104 | return m.Status 105 | } 106 | return Heartbeat_UNKNOWN 107 | } 108 | 109 | type Empty struct { 110 | } 111 | 112 | func (m *Empty) Reset() { *m = Empty{} } 113 | func (m *Empty) String() string { return proto.CompactTextString(m) } 114 | func (*Empty) ProtoMessage() {} 115 | func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } 116 | 117 | func init() { 118 | proto.RegisterType((*EchoRequest)(nil), "echoserver.EchoRequest") 119 | proto.RegisterType((*EchoResponse)(nil), "echoserver.EchoResponse") 120 | proto.RegisterType((*Heartbeat)(nil), "echoserver.Heartbeat") 121 | proto.RegisterType((*Empty)(nil), "echoserver.Empty") 122 | proto.RegisterEnum("echoserver.Heartbeat_Status", Heartbeat_Status_name, Heartbeat_Status_value) 123 | } 124 | 125 | // Reference imports to suppress errors if they are not otherwise used. 126 | var _ context.Context 127 | var _ grpc.ClientConn 128 | 129 | // This is a compile-time assertion to ensure that this generated file 130 | // is compatible with the grpc package it is being compiled against. 
131 | const _ = grpc.SupportPackageIsVersion4 132 | 133 | // Client API for EchoService service 134 | 135 | type EchoServiceClient interface { 136 | Echo(ctx context.Context, opts ...grpc.CallOption) (EchoService_EchoClient, error) 137 | Stream(ctx context.Context, in *Empty, opts ...grpc.CallOption) (EchoService_StreamClient, error) 138 | Heartbeats(ctx context.Context, opts ...grpc.CallOption) (EchoService_HeartbeatsClient, error) 139 | } 140 | 141 | type echoServiceClient struct { 142 | cc *grpc.ClientConn 143 | } 144 | 145 | func NewEchoServiceClient(cc *grpc.ClientConn) EchoServiceClient { 146 | return &echoServiceClient{cc} 147 | } 148 | 149 | func (c *echoServiceClient) Echo(ctx context.Context, opts ...grpc.CallOption) (EchoService_EchoClient, error) { 150 | stream, err := grpc.NewClientStream(ctx, &_EchoService_serviceDesc.Streams[0], c.cc, "/echoserver.EchoService/Echo", opts...) 151 | if err != nil { 152 | return nil, err 153 | } 154 | x := &echoServiceEchoClient{stream} 155 | return x, nil 156 | } 157 | 158 | type EchoService_EchoClient interface { 159 | Send(*EchoRequest) error 160 | Recv() (*EchoResponse, error) 161 | grpc.ClientStream 162 | } 163 | 164 | type echoServiceEchoClient struct { 165 | grpc.ClientStream 166 | } 167 | 168 | func (x *echoServiceEchoClient) Send(m *EchoRequest) error { 169 | return x.ClientStream.SendMsg(m) 170 | } 171 | 172 | func (x *echoServiceEchoClient) Recv() (*EchoResponse, error) { 173 | m := new(EchoResponse) 174 | if err := x.ClientStream.RecvMsg(m); err != nil { 175 | return nil, err 176 | } 177 | return m, nil 178 | } 179 | 180 | func (c *echoServiceClient) Stream(ctx context.Context, in *Empty, opts ...grpc.CallOption) (EchoService_StreamClient, error) { 181 | stream, err := grpc.NewClientStream(ctx, &_EchoService_serviceDesc.Streams[1], c.cc, "/echoserver.EchoService/Stream", opts...) 
182 | if err != nil { 183 | return nil, err 184 | } 185 | x := &echoServiceStreamClient{stream} 186 | if err := x.ClientStream.SendMsg(in); err != nil { 187 | return nil, err 188 | } 189 | if err := x.ClientStream.CloseSend(); err != nil { 190 | return nil, err 191 | } 192 | return x, nil 193 | } 194 | 195 | type EchoService_StreamClient interface { 196 | Recv() (*EchoResponse, error) 197 | grpc.ClientStream 198 | } 199 | 200 | type echoServiceStreamClient struct { 201 | grpc.ClientStream 202 | } 203 | 204 | func (x *echoServiceStreamClient) Recv() (*EchoResponse, error) { 205 | m := new(EchoResponse) 206 | if err := x.ClientStream.RecvMsg(m); err != nil { 207 | return nil, err 208 | } 209 | return m, nil 210 | } 211 | 212 | func (c *echoServiceClient) Heartbeats(ctx context.Context, opts ...grpc.CallOption) (EchoService_HeartbeatsClient, error) { 213 | stream, err := grpc.NewClientStream(ctx, &_EchoService_serviceDesc.Streams[2], c.cc, "/echoserver.EchoService/Heartbeats", opts...) 214 | if err != nil { 215 | return nil, err 216 | } 217 | x := &echoServiceHeartbeatsClient{stream} 218 | return x, nil 219 | } 220 | 221 | type EchoService_HeartbeatsClient interface { 222 | Send(*Empty) error 223 | Recv() (*Heartbeat, error) 224 | grpc.ClientStream 225 | } 226 | 227 | type echoServiceHeartbeatsClient struct { 228 | grpc.ClientStream 229 | } 230 | 231 | func (x *echoServiceHeartbeatsClient) Send(m *Empty) error { 232 | return x.ClientStream.SendMsg(m) 233 | } 234 | 235 | func (x *echoServiceHeartbeatsClient) Recv() (*Heartbeat, error) { 236 | m := new(Heartbeat) 237 | if err := x.ClientStream.RecvMsg(m); err != nil { 238 | return nil, err 239 | } 240 | return m, nil 241 | } 242 | 243 | // Server API for EchoService service 244 | 245 | type EchoServiceServer interface { 246 | Echo(EchoService_EchoServer) error 247 | Stream(*Empty, EchoService_StreamServer) error 248 | Heartbeats(EchoService_HeartbeatsServer) error 249 | } 250 | 251 | func RegisterEchoServiceServer(s 
*grpc.Server, srv EchoServiceServer) { 252 | s.RegisterService(&_EchoService_serviceDesc, srv) 253 | } 254 | 255 | func _EchoService_Echo_Handler(srv interface{}, stream grpc.ServerStream) error { 256 | return srv.(EchoServiceServer).Echo(&echoServiceEchoServer{stream}) 257 | } 258 | 259 | type EchoService_EchoServer interface { 260 | Send(*EchoResponse) error 261 | Recv() (*EchoRequest, error) 262 | grpc.ServerStream 263 | } 264 | 265 | type echoServiceEchoServer struct { 266 | grpc.ServerStream 267 | } 268 | 269 | func (x *echoServiceEchoServer) Send(m *EchoResponse) error { 270 | return x.ServerStream.SendMsg(m) 271 | } 272 | 273 | func (x *echoServiceEchoServer) Recv() (*EchoRequest, error) { 274 | m := new(EchoRequest) 275 | if err := x.ServerStream.RecvMsg(m); err != nil { 276 | return nil, err 277 | } 278 | return m, nil 279 | } 280 | 281 | func _EchoService_Stream_Handler(srv interface{}, stream grpc.ServerStream) error { 282 | m := new(Empty) 283 | if err := stream.RecvMsg(m); err != nil { 284 | return err 285 | } 286 | return srv.(EchoServiceServer).Stream(m, &echoServiceStreamServer{stream}) 287 | } 288 | 289 | type EchoService_StreamServer interface { 290 | Send(*EchoResponse) error 291 | grpc.ServerStream 292 | } 293 | 294 | type echoServiceStreamServer struct { 295 | grpc.ServerStream 296 | } 297 | 298 | func (x *echoServiceStreamServer) Send(m *EchoResponse) error { 299 | return x.ServerStream.SendMsg(m) 300 | } 301 | 302 | func _EchoService_Heartbeats_Handler(srv interface{}, stream grpc.ServerStream) error { 303 | return srv.(EchoServiceServer).Heartbeats(&echoServiceHeartbeatsServer{stream}) 304 | } 305 | 306 | type EchoService_HeartbeatsServer interface { 307 | Send(*Heartbeat) error 308 | Recv() (*Empty, error) 309 | grpc.ServerStream 310 | } 311 | 312 | type echoServiceHeartbeatsServer struct { 313 | grpc.ServerStream 314 | } 315 | 316 | func (x *echoServiceHeartbeatsServer) Send(m *Heartbeat) error { 317 | return x.ServerStream.SendMsg(m) 318 
| } 319 | 320 | func (x *echoServiceHeartbeatsServer) Recv() (*Empty, error) { 321 | m := new(Empty) 322 | if err := x.ServerStream.RecvMsg(m); err != nil { 323 | return nil, err 324 | } 325 | return m, nil 326 | } 327 | 328 | var _EchoService_serviceDesc = grpc.ServiceDesc{ 329 | ServiceName: "echoserver.EchoService", 330 | HandlerType: (*EchoServiceServer)(nil), 331 | Methods: []grpc.MethodDesc{}, 332 | Streams: []grpc.StreamDesc{ 333 | { 334 | StreamName: "Echo", 335 | Handler: _EchoService_Echo_Handler, 336 | ServerStreams: true, 337 | ClientStreams: true, 338 | }, 339 | { 340 | StreamName: "Stream", 341 | Handler: _EchoService_Stream_Handler, 342 | ServerStreams: true, 343 | }, 344 | { 345 | StreamName: "Heartbeats", 346 | Handler: _EchoService_Heartbeats_Handler, 347 | ServerStreams: true, 348 | ClientStreams: true, 349 | }, 350 | }, 351 | Metadata: "echoserver.proto", 352 | } 353 | 354 | func init() { proto.RegisterFile("echoserver.proto", fileDescriptor0) } 355 | 356 | var fileDescriptor0 = []byte{ 357 | // 300 bytes of a gzipped FileDescriptorProto 358 | 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x91, 0xc1, 0x4a, 0xc3, 0x40, 359 | 0x10, 0x86, 0xdd, 0x62, 0x13, 0x3a, 0xb5, 0x12, 0x57, 0xc4, 0x50, 0x2a, 0xc8, 0x5e, 0x0c, 0x1e, 360 | 0x92, 0x52, 0x3d, 0x79, 0xaf, 0x08, 0x85, 0x04, 0x52, 0xc4, 0xab, 0xdb, 0x30, 0x24, 0x01, 0x93, 361 | 0x8d, 0xbb, 0xdb, 0x82, 0x57, 0x5f, 0xc1, 0x47, 0xf3, 0x15, 0x7c, 0x07, 0xaf, 0x92, 0xac, 0xc6, 362 | 0x88, 0xa5, 0xc7, 0x99, 0xfd, 0xe6, 0x9f, 0xff, 0xdf, 0x01, 0x07, 0x93, 0x4c, 0x28, 0x94, 0x1b, 363 | 0x94, 0x7e, 0x25, 0x85, 0x16, 0x14, 0x7e, 0x3b, 0xe3, 0x49, 0x2a, 0x44, 0xfa, 0x84, 0x01, 0xaf, 364 | 0xf2, 0x80, 0x97, 0xa5, 0xd0, 0x5c, 0xe7, 0xa2, 0x54, 0x86, 0x64, 0x17, 0x30, 0x9c, 0x27, 0x99, 365 | 0x88, 0xf1, 0x79, 0x8d, 0x4a, 0x53, 0x17, 0xec, 0x02, 0x95, 0xe2, 0x29, 0xba, 0xe4, 0x9c, 0x78, 366 | 0x83, 0xf8, 0xa7, 0x64, 0x1e, 0x1c, 0x18, 0x50, 0x55, 0xa2, 0x54, 0xb8, 0x83, 0x7c, 0x84, 0xc1, 367 | 
0x1d, 0x72, 0xa9, 0x57, 0xc8, 0x35, 0xbd, 0x06, 0x4b, 0x69, 0xae, 0xd7, 0xaa, 0xa1, 0x0e, 0x67, 368 | 0x13, 0xbf, 0x63, 0xb6, 0xc5, 0xfc, 0x65, 0xc3, 0xc4, 0xdf, 0x2c, 0x3b, 0x03, 0xcb, 0x74, 0xe8, 369 | 0x10, 0xec, 0xfb, 0x70, 0x11, 0x46, 0x0f, 0xa1, 0xb3, 0x47, 0x2d, 0xe8, 0x45, 0x0b, 0x87, 0x30, 370 | 0x1b, 0xfa, 0xf3, 0xa2, 0xd2, 0x2f, 0xb3, 0x4f, 0x62, 0xec, 0x2f, 0x51, 0x6e, 0xf2, 0x04, 0x69, 371 | 0x04, 0xfb, 0x75, 0x49, 0x4f, 0xbb, 0x5b, 0x3a, 0xf9, 0xc6, 0xee, 0xff, 0x07, 0x93, 0x87, 0x39, 372 | 0xaf, 0xef, 0x1f, 0x6f, 0x3d, 0x60, 0xfd, 0xa0, 0x26, 0x6e, 0xc8, 0xa5, 0x47, 0xa6, 0x84, 0xde, 373 | 0xd6, 0x46, 0x24, 0xf2, 0x82, 0x1e, 0xfd, 0x99, 0xac, 0xb7, 0xef, 0x10, 0x1b, 0x35, 0x62, 0x36, 374 | 0x35, 0x62, 0x53, 0x42, 0x23, 0x80, 0x36, 0xac, 0xda, 0xa6, 0x75, 0xb2, 0xf5, 0x5f, 0xd8, 0x71, 375 | 0x23, 0x34, 0x62, 0xc3, 0x20, 0x6b, 0xc7, 0x6b, 0x63, 0x2b, 0xab, 0x39, 0xdf, 0xd5, 0x57, 0x00, 376 | 0x00, 0x00, 0xff, 0xff, 0xd4, 0x6a, 0xec, 0x7b, 0xfc, 0x01, 0x00, 0x00, 377 | } 378 | -------------------------------------------------------------------------------- /_examples/cmd/wsechoserver/echoserver/echoserver.pb.gw.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-grpc-gateway 2 | // source: echoserver.proto 3 | // DO NOT EDIT! 4 | 5 | /* 6 | Package echoserver is a reverse proxy. 7 | 8 | It translates gRPC into RESTful JSON APIs. 
9 | */ 10 | package echoserver 11 | 12 | import ( 13 | "io" 14 | "net/http" 15 | 16 | "github.com/golang/protobuf/proto" 17 | "github.com/grpc-ecosystem/grpc-gateway/runtime" 18 | "github.com/grpc-ecosystem/grpc-gateway/utilities" 19 | "golang.org/x/net/context" 20 | "google.golang.org/grpc" 21 | "google.golang.org/grpc/codes" 22 | "google.golang.org/grpc/grpclog" 23 | "google.golang.org/grpc/status" 24 | ) 25 | 26 | var _ codes.Code 27 | var _ io.Reader 28 | var _ status.Status 29 | var _ = runtime.String 30 | var _ = utilities.NewDoubleArray 31 | 32 | func request_EchoService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client EchoServiceClient, req *http.Request, pathParams map[string]string) (EchoService_EchoClient, runtime.ServerMetadata, error) { 33 | var metadata runtime.ServerMetadata 34 | stream, err := client.Echo(ctx) 35 | if err != nil { 36 | grpclog.Printf("Failed to start streaming: %v", err) 37 | return nil, metadata, err 38 | } 39 | dec := marshaler.NewDecoder(req.Body) 40 | handleSend := func() error { 41 | var protoReq EchoRequest 42 | err = dec.Decode(&protoReq) 43 | if err == io.EOF { 44 | return err 45 | } 46 | if err != nil { 47 | grpclog.Printf("Failed to decode request: %v", err) 48 | return err 49 | } 50 | if err = stream.Send(&protoReq); err != nil { 51 | grpclog.Printf("Failed to send request: %v", err) 52 | return err 53 | } 54 | return nil 55 | } 56 | if err := handleSend(); err != nil { 57 | if cerr := stream.CloseSend(); cerr != nil { 58 | grpclog.Printf("Failed to terminate client stream: %v", cerr) 59 | } 60 | if err == io.EOF { 61 | return stream, metadata, nil 62 | } 63 | return nil, metadata, err 64 | } 65 | go func() { 66 | for { 67 | if err := handleSend(); err != nil { 68 | break 69 | } 70 | } 71 | if err := stream.CloseSend(); err != nil { 72 | grpclog.Printf("Failed to terminate client stream: %v", err) 73 | } 74 | }() 75 | header, err := stream.Header() 76 | if err != nil { 77 | grpclog.Printf("Failed to get 
header from client: %v", err) 78 | return nil, metadata, err 79 | } 80 | metadata.HeaderMD = header 81 | return stream, metadata, nil 82 | } 83 | 84 | func request_EchoService_Stream_0(ctx context.Context, marshaler runtime.Marshaler, client EchoServiceClient, req *http.Request, pathParams map[string]string) (EchoService_StreamClient, runtime.ServerMetadata, error) { 85 | var protoReq Empty 86 | var metadata runtime.ServerMetadata 87 | 88 | stream, err := client.Stream(ctx, &protoReq) 89 | if err != nil { 90 | return nil, metadata, err 91 | } 92 | header, err := stream.Header() 93 | if err != nil { 94 | return nil, metadata, err 95 | } 96 | metadata.HeaderMD = header 97 | return stream, metadata, nil 98 | 99 | } 100 | 101 | func request_EchoService_Heartbeats_0(ctx context.Context, marshaler runtime.Marshaler, client EchoServiceClient, req *http.Request, pathParams map[string]string) (EchoService_HeartbeatsClient, runtime.ServerMetadata, error) { 102 | var metadata runtime.ServerMetadata 103 | stream, err := client.Heartbeats(ctx) 104 | if err != nil { 105 | grpclog.Printf("Failed to start streaming: %v", err) 106 | return nil, metadata, err 107 | } 108 | dec := marshaler.NewDecoder(req.Body) 109 | handleSend := func() error { 110 | var protoReq Empty 111 | err = dec.Decode(&protoReq) 112 | if err == io.EOF { 113 | return err 114 | } 115 | if err != nil { 116 | grpclog.Printf("Failed to decode request: %v", err) 117 | return err 118 | } 119 | if err = stream.Send(&protoReq); err != nil { 120 | grpclog.Printf("Failed to send request: %v", err) 121 | return err 122 | } 123 | return nil 124 | } 125 | if err := handleSend(); err != nil { 126 | if cerr := stream.CloseSend(); cerr != nil { 127 | grpclog.Printf("Failed to terminate client stream: %v", cerr) 128 | } 129 | if err == io.EOF { 130 | return stream, metadata, nil 131 | } 132 | return nil, metadata, err 133 | } 134 | go func() { 135 | for { 136 | if err := handleSend(); err != nil { 137 | break 138 | } 139 | } 
140 | if err := stream.CloseSend(); err != nil { 141 | grpclog.Printf("Failed to terminate client stream: %v", err) 142 | } 143 | }() 144 | header, err := stream.Header() 145 | if err != nil { 146 | grpclog.Printf("Failed to get header from client: %v", err) 147 | return nil, metadata, err 148 | } 149 | metadata.HeaderMD = header 150 | return stream, metadata, nil 151 | } 152 | 153 | // RegisterEchoServiceHandlerFromEndpoint is same as RegisterEchoServiceHandler but 154 | // automatically dials to "endpoint" and closes the connection when "ctx" gets done. 155 | func RegisterEchoServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { 156 | conn, err := grpc.Dial(endpoint, opts...) 157 | if err != nil { 158 | return err 159 | } 160 | defer func() { 161 | if err != nil { 162 | if cerr := conn.Close(); cerr != nil { 163 | grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) 164 | } 165 | return 166 | } 167 | go func() { 168 | <-ctx.Done() 169 | if cerr := conn.Close(); cerr != nil { 170 | grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) 171 | } 172 | }() 173 | }() 174 | 175 | return RegisterEchoServiceHandler(ctx, mux, conn) 176 | } 177 | 178 | // RegisterEchoServiceHandler registers the http handlers for service EchoService to "mux". 179 | // The handlers forward requests to the grpc endpoint over "conn". 
180 | func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { 181 | client := NewEchoServiceClient(conn) 182 | 183 | mux.Handle("POST", pattern_EchoService_Echo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 184 | ctx, cancel := context.WithCancel(ctx) 185 | defer cancel() 186 | if cn, ok := w.(http.CloseNotifier); ok { 187 | go func(done <-chan struct{}, closed <-chan bool) { 188 | select { 189 | case <-done: 190 | case <-closed: 191 | cancel() 192 | } 193 | }(ctx.Done(), cn.CloseNotify()) 194 | } 195 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 196 | rctx, err := runtime.AnnotateContext(ctx, mux, req) 197 | if err != nil { 198 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 199 | } 200 | resp, md, err := request_EchoService_Echo_0(rctx, inboundMarshaler, client, req, pathParams) 201 | ctx = runtime.NewServerMetadataContext(ctx, md) 202 | if err != nil { 203 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 204 | return 205 | } 206 | 207 | forward_EchoService_Echo_0(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) 
208 | 209 | }) 210 | 211 | mux.Handle("GET", pattern_EchoService_Stream_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 212 | ctx, cancel := context.WithCancel(ctx) 213 | defer cancel() 214 | if cn, ok := w.(http.CloseNotifier); ok { 215 | go func(done <-chan struct{}, closed <-chan bool) { 216 | select { 217 | case <-done: 218 | case <-closed: 219 | cancel() 220 | } 221 | }(ctx.Done(), cn.CloseNotify()) 222 | } 223 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 224 | rctx, err := runtime.AnnotateContext(ctx, mux, req) 225 | if err != nil { 226 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 227 | } 228 | resp, md, err := request_EchoService_Stream_0(rctx, inboundMarshaler, client, req, pathParams) 229 | ctx = runtime.NewServerMetadataContext(ctx, md) 230 | if err != nil { 231 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 232 | return 233 | } 234 | 235 | forward_EchoService_Stream_0(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) 
236 | 237 | }) 238 | 239 | mux.Handle("POST", pattern_EchoService_Heartbeats_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 240 | ctx, cancel := context.WithCancel(ctx) 241 | defer cancel() 242 | if cn, ok := w.(http.CloseNotifier); ok { 243 | go func(done <-chan struct{}, closed <-chan bool) { 244 | select { 245 | case <-done: 246 | case <-closed: 247 | cancel() 248 | } 249 | }(ctx.Done(), cn.CloseNotify()) 250 | } 251 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 252 | rctx, err := runtime.AnnotateContext(ctx, mux, req) 253 | if err != nil { 254 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 255 | } 256 | resp, md, err := request_EchoService_Heartbeats_0(rctx, inboundMarshaler, client, req, pathParams) 257 | ctx = runtime.NewServerMetadataContext(ctx, md) 258 | if err != nil { 259 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 260 | return 261 | } 262 | 263 | forward_EchoService_Heartbeats_0(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) 
264 | 265 | }) 266 | 267 | return nil 268 | } 269 | 270 | var ( 271 | pattern_EchoService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0}, []string{"echo"}, "")) 272 | 273 | pattern_EchoService_Stream_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0}, []string{"echo"}, "")) 274 | 275 | pattern_EchoService_Heartbeats_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0}, []string{"heartbeats"}, "")) 276 | ) 277 | 278 | var ( 279 | forward_EchoService_Echo_0 = runtime.ForwardResponseStream 280 | 281 | forward_EchoService_Stream_0 = runtime.ForwardResponseStream 282 | 283 | forward_EchoService_Heartbeats_0 = runtime.ForwardResponseStream 284 | ) 285 | -------------------------------------------------------------------------------- /_examples/cmd/wsechoserver/echoserver/echoserver.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package echoserver; 3 | 4 | import "google/api/annotations.proto"; 5 | 6 | message EchoRequest { 7 | string message = 1; 8 | } 9 | 10 | message EchoResponse { 11 | string message = 1; 12 | } 13 | 14 | message Heartbeat { 15 | enum Status { 16 | UNKNOWN = 0; 17 | OK = 1; 18 | } 19 | Status status = 1; 20 | } 21 | 22 | message Empty {} 23 | 24 | service EchoService { 25 | rpc Echo(stream EchoRequest) returns (stream EchoResponse) { 26 | option (google.api.http) = {post: "/echo", body: "*"}; 27 | } 28 | rpc Stream(Empty) returns (stream EchoResponse) { 29 | option (google.api.http) = {get: "/echo"}; 30 | } 31 | rpc Heartbeats(stream Empty) returns (stream Heartbeat) { 32 | option (google.api.http) = {post: "/heartbeats"}; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /_examples/cmd/wsechoserver/echoserver/gen.go: -------------------------------------------------------------------------------- 1 | //go:generate make 2 | 3 | package echoserver 4 | 
-------------------------------------------------------------------------------- /_examples/cmd/wsechoserver/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "net" 8 | "net/http" 9 | _ "net/http/pprof" 10 | 11 | _ "golang.org/x/net/trace" 12 | 13 | "github.com/golang/glog" 14 | "github.com/grpc-ecosystem/grpc-gateway/runtime" 15 | "github.com/tmc/grpc-websocket-proxy/examples/cmd/wsechoserver/echoserver" 16 | "github.com/tmc/grpc-websocket-proxy/wsproxy" 17 | "golang.org/x/net/context" 18 | "google.golang.org/grpc" 19 | ) 20 | 21 | var ( 22 | grpcAddr = flag.String("grpcaddr", ":8001", "listen grpc addr") 23 | httpAddr = flag.String("addr", ":8000", "listen http addr") 24 | debugAddr = flag.String("debugaddr", ":8002", "listen debug addr") 25 | ) 26 | 27 | func run() error { 28 | ctx := context.Background() 29 | ctx, cancel := context.WithCancel(ctx) 30 | defer cancel() 31 | 32 | if err := listenGRPC(*grpcAddr); err != nil { 33 | return err 34 | } 35 | 36 | mux := runtime.NewServeMux() 37 | opts := []grpc.DialOption{grpc.WithInsecure()} 38 | err := echoserver.RegisterEchoServiceHandlerFromEndpoint(ctx, mux, *grpcAddr, opts) 39 | if err != nil { 40 | return err 41 | } 42 | go http.ListenAndServe(*debugAddr, nil) 43 | fmt.Println("listening") 44 | http.ListenAndServe(*httpAddr, wsproxy.WebsocketProxy(mux)) 45 | return nil 46 | } 47 | 48 | func listenGRPC(listenAddr string) error { 49 | lis, err := net.Listen("tcp", listenAddr) 50 | if err != nil { 51 | return err 52 | } 53 | grpcServer := grpc.NewServer() 54 | echoserver.RegisterEchoServiceServer(grpcServer, &Server{}) 55 | go func() { 56 | if err := grpcServer.Serve(lis); err != nil { 57 | log.Println("serveGRPC err:", err) 58 | } 59 | }() 60 | return nil 61 | 62 | } 63 | 64 | func main() { 65 | flag.Parse() 66 | defer glog.Flush() 67 | 68 | if err := run(); err != nil { 69 | glog.Fatal(err) 70 | } 71 | } 72 
| -------------------------------------------------------------------------------- /_examples/cmd/wsechoserver/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/golang/protobuf/jsonpb" 9 | log "github.com/sirupsen/logrus" 10 | "github.com/tmc/grpc-websocket-proxy/examples/cmd/wsechoserver/echoserver" 11 | ) 12 | 13 | type Server struct{} 14 | 15 | func (s *Server) Stream(_ *echoserver.Empty, stream echoserver.EchoService_StreamServer) error { 16 | start := time.Now() 17 | for i := 0; i < 5; i++ { 18 | time.Sleep(time.Second) 19 | if err := stream.Send(&echoserver.EchoResponse{ 20 | Message: "hello there!" + fmt.Sprint(time.Now().Sub(start)), 21 | }); err != nil { 22 | return err 23 | } 24 | } 25 | return nil 26 | } 27 | 28 | func (s *Server) Echo(srv echoserver.EchoService_EchoServer) error { 29 | for { 30 | req, err := srv.Recv() 31 | if err != nil { 32 | return err 33 | } 34 | if err := srv.Send(&echoserver.EchoResponse{ 35 | Message: req.Message + "!", 36 | }); err != nil { 37 | return err 38 | } 39 | } 40 | } 41 | 42 | func (s *Server) Heartbeats(srv echoserver.EchoService_HeartbeatsServer) error { 43 | go func() { 44 | for { 45 | _, err := srv.Recv() 46 | if err != nil { 47 | log.Println("Recv() err:", err) 48 | return 49 | } 50 | log.Println("got hb from client") 51 | } 52 | }() 53 | t := time.NewTicker(time.Second * 1) 54 | for { 55 | log.Println("sending hb") 56 | hb := &echoserver.Heartbeat{ 57 | Status: echoserver.Heartbeat_OK, 58 | } 59 | b := new(bytes.Buffer) 60 | if err := (&jsonpb.Marshaler{}).Marshal(b, hb); err != nil { 61 | log.Println("marshal err:", err) 62 | } 63 | log.Println(string(b.Bytes())) 64 | if err := srv.Send(hb); err != nil { 65 | return err 66 | } 67 | <-t.C 68 | } 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/tmc/grpc-websocket-proxy 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/gorilla/websocket v1.4.2 7 | github.com/sirupsen/logrus v1.8.1 8 | github.com/stretchr/testify v1.7.0 // indirect 9 | golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 10 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007 // indirect 11 | ) 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= 5 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= 9 | github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= 10 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 11 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 12 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 13 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 14 | golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 h1:0qxwC5n+ttVOINCBeRHO0nq9X7uy8SDsPoi5OaCdIEI= 15 | golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9/go.mod 
h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 16 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 17 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 18 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 19 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE= 20 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 21 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 22 | golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= 23 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 24 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 25 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 26 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 27 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 28 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 29 | -------------------------------------------------------------------------------- /wsproxy/doc.go: -------------------------------------------------------------------------------- 1 | // Package wsproxy implements a websocket proxy for grpc-gateway backed services 2 | package wsproxy 3 | -------------------------------------------------------------------------------- /wsproxy/websocket_proxy.go: -------------------------------------------------------------------------------- 1 | package wsproxy 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | 
"strings"
"time"

"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"golang.org/x/net/context"
)

var (
	// MethodOverrideParam names the URL query parameter whose non-empty value
	// replaces the HTTP method of the proxied streaming request. Its value is
	// read when WebsocketProxy is called.
	//
	// Deprecated: pass WithMethodParamOverride to WebsocketProxy instead.
	MethodOverrideParam = "method"

	// TokenCookieName names the cookie whose value is forwarded upstream as an
	// 'Authorization: Bearer <value>' header. Its value is read when
	// WebsocketProxy is called.
	//
	// Deprecated: pass WithTokenCookieName to WebsocketProxy instead.
	TokenCookieName = "token"
)

// RequestMutatorFunc is given the incoming websocket upgrade request and the
// outgoing request prepared for the wrapped handler. It returns the request
// that is actually handed to that handler, which may be outgoing itself or a
// replacement.
type RequestMutatorFunc func(incoming *http.Request, outgoing *http.Request) *http.Request

// Proxy is an http.Handler that forwards plain requests to a wrapped handler.
// For websocket upgrade requests, it bridges the connection to that handler
// as a newline-delimited stream.
type Proxy struct {
	// h serves ordinary requests and the proxied streaming requests.
	h http.Handler
	// logger receives diagnostic messages.
	logger Logger
	// maxRespBodyBufferBytes is passed as the max argument to
	// bufio.Scanner.Buffer when positive.
	maxRespBodyBufferBytes int
	// methodOverrideParam names the query parameter that overrides the
	// upstream method.
	methodOverrideParam string
	// tokenCookieName names the cookie forwarded as a Bearer Authorization
	// header.
	tokenCookieName string
	// requestMutator, when set, gets the final say on the upstream request.
	requestMutator RequestMutatorFunc
	// headerForwarder reports whether an incoming header is copied upstream.
	headerForwarder func(header string) bool
	// pingInterval, pingWait and pongWait control keep-alive pings; pinging
	// is active only when all three are positive.
	pingInterval time.Duration
	pingWait     time.Duration
	pongWait     time.Duration
}

// Logger receives the proxy's diagnostic messages. The default, a
// *logrus.Logger, satisfies it.
type Logger interface {
	Warnln(...interface{})
	Debugln(...interface{})
}

// ServeHTTP proxies websocket upgrade requests as streams to the wrapped
// handler and passes every other request to that handler unchanged.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if websocket.IsWebSocketUpgrade(r) {
		p.proxy(w, r)
		return
	}
	p.h.ServeHTTP(w, r)
}

// Option configures a Proxy; options are passed to WebsocketProxy.
type Option func(*Proxy)

// WithMaxRespBodyBufferSize sets the max argument passed to
// bufio.Scanner.Buffer for the scanner that splits the wrapped handler's
// response body into lines. When it is not set, or nBytes <= 0, the scanner
// keeps its default limit of bufio.MaxScanTokenSize.
63 | func WithMaxRespBodyBufferSize(nBytes int) Option { 64 | return func(p *Proxy) { 65 | p.maxRespBodyBufferBytes = nBytes 66 | } 67 | } 68 | 69 | // WithMethodParamOverride allows specification of the special http parameter that is used in the proxied streaming request. 70 | func WithMethodParamOverride(param string) Option { 71 | return func(p *Proxy) { 72 | p.methodOverrideParam = param 73 | } 74 | } 75 | 76 | // WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header. 77 | func WithTokenCookieName(param string) Option { 78 | return func(p *Proxy) { 79 | p.tokenCookieName = param 80 | } 81 | } 82 | 83 | // WithRequestMutator allows a custom RequestMutatorFunc to be supplied. 84 | func WithRequestMutator(fn RequestMutatorFunc) Option { 85 | return func(p *Proxy) { 86 | p.requestMutator = fn 87 | } 88 | } 89 | 90 | // WithForwardedHeaders allows controlling which headers are forwarded. 91 | func WithForwardedHeaders(fn func(header string) bool) Option { 92 | return func(p *Proxy) { 93 | p.headerForwarder = fn 94 | } 95 | } 96 | 97 | // WithLogger allows a custom FieldLogger to be supplied 98 | func WithLogger(logger Logger) Option { 99 | return func(p *Proxy) { 100 | p.logger = logger 101 | } 102 | } 103 | 104 | // WithPingControl allows specification of ping pong control. The interval 105 | // parameter specifies the pingInterval between pings. The allowed wait time 106 | // for a pong response is (pingInterval * 10) / 9. 
107 | func WithPingControl(interval time.Duration) Option { 108 | return func(proxy *Proxy) { 109 | proxy.pingInterval = interval 110 | proxy.pongWait = (interval * 10) / 9 111 | proxy.pingWait = proxy.pongWait / 6 112 | } 113 | } 114 | 115 | var defaultHeadersToForward = map[string]bool{ 116 | "Origin": true, 117 | "origin": true, 118 | "Referer": true, 119 | "referer": true, 120 | } 121 | 122 | func defaultHeaderForwarder(header string) bool { 123 | return defaultHeadersToForward[header] 124 | } 125 | 126 | // WebsocketProxy attempts to expose the underlying handler as a bidi websocket stream with newline-delimited 127 | // JSON as the content encoding. 128 | // 129 | // The HTTP Authorization header is either populated from the Sec-Websocket-Protocol field or by a cookie. 130 | // The cookie name is specified by the TokenCookieName value. 131 | // 132 | // example: 133 | // Sec-Websocket-Protocol: Bearer, foobar 134 | // is converted to: 135 | // Authorization: Bearer foobar 136 | // 137 | // Method can be overwritten with the MethodOverrideParam get parameter in the requested URL 138 | func WebsocketProxy(h http.Handler, opts ...Option) http.Handler { 139 | p := &Proxy{ 140 | h: h, 141 | logger: logrus.New(), 142 | methodOverrideParam: MethodOverrideParam, 143 | tokenCookieName: TokenCookieName, 144 | headerForwarder: defaultHeaderForwarder, 145 | } 146 | for _, o := range opts { 147 | o(p) 148 | } 149 | return p 150 | } 151 | 152 | // TODO(tmc): allow modification of upgrader settings? 
153 | var upgrader = websocket.Upgrader{ 154 | ReadBufferSize: 1024, 155 | WriteBufferSize: 1024, 156 | CheckOrigin: func(r *http.Request) bool { return true }, 157 | } 158 | 159 | func isClosedConnError(err error) bool { 160 | str := err.Error() 161 | if strings.Contains(str, "use of closed network connection") { 162 | return true 163 | } 164 | return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) 165 | } 166 | 167 | func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { 168 | var responseHeader http.Header 169 | // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind. 170 | // TODO(tmc): consider customizability/extension point here. 171 | if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") { 172 | responseHeader = http.Header{ 173 | "Sec-WebSocket-Protocol": []string{"Bearer"}, 174 | } 175 | } 176 | conn, err := upgrader.Upgrade(w, r, responseHeader) 177 | if err != nil { 178 | p.logger.Warnln("error upgrading websocket:", err) 179 | return 180 | } 181 | defer conn.Close() 182 | 183 | ctx, cancelFn := context.WithCancel(context.Background()) 184 | defer cancelFn() 185 | 186 | requestBodyR, requestBodyW := io.Pipe() 187 | request, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), requestBodyR) 188 | if err != nil { 189 | p.logger.Warnln("error preparing request:", err) 190 | return 191 | } 192 | if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" { 193 | request.Header.Set("Authorization", transformSubProtocolHeader(swsp)) 194 | } 195 | for header := range r.Header { 196 | if p.headerForwarder(header) { 197 | request.Header.Set(header, r.Header.Get(header)) 198 | } 199 | } 200 | // If token cookie is present, populate Authorization header from the cookie instead. 
201 | if cookie, err := r.Cookie(p.tokenCookieName); err == nil { 202 | request.Header.Set("Authorization", "Bearer "+cookie.Value) 203 | } 204 | if m := r.URL.Query().Get(p.methodOverrideParam); m != "" { 205 | request.Method = m 206 | } 207 | 208 | if p.requestMutator != nil { 209 | request = p.requestMutator(r, request) 210 | } 211 | 212 | responseBodyR, responseBodyW := io.Pipe() 213 | response := newInMemoryResponseWriter(responseBodyW) 214 | go func() { 215 | <-ctx.Done() 216 | p.logger.Debugln("closing pipes") 217 | requestBodyW.CloseWithError(io.EOF) 218 | responseBodyW.CloseWithError(io.EOF) 219 | response.closed <- true 220 | }() 221 | 222 | go func() { 223 | defer cancelFn() 224 | p.h.ServeHTTP(response, request) 225 | }() 226 | 227 | // read loop -- take messages from websocket and write to http request 228 | go func() { 229 | if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { 230 | conn.SetReadDeadline(time.Now().Add(p.pongWait)) 231 | conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil }) 232 | } 233 | defer func() { 234 | cancelFn() 235 | }() 236 | for { 237 | select { 238 | case <-ctx.Done(): 239 | p.logger.Debugln("read loop done") 240 | return 241 | default: 242 | } 243 | p.logger.Debugln("[read] reading from socket.") 244 | _, payload, err := conn.ReadMessage() 245 | if err != nil { 246 | if isClosedConnError(err) { 247 | p.logger.Debugln("[read] websocket closed:", err) 248 | return 249 | } 250 | p.logger.Warnln("error reading websocket message:", err) 251 | return 252 | } 253 | p.logger.Debugln("[read] read payload:", string(payload)) 254 | p.logger.Debugln("[read] writing to requestBody:") 255 | n, err := requestBodyW.Write(payload) 256 | requestBodyW.Write([]byte("\n")) 257 | p.logger.Debugln("[read] wrote to requestBody", n) 258 | if err != nil { 259 | p.logger.Warnln("[read] error writing message to upstream http server:", err) 260 | return 261 | } 262 | } 263 | }() 264 | // ping 
write loop 265 | if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { 266 | go func() { 267 | ticker := time.NewTicker(p.pingInterval) 268 | defer func() { 269 | ticker.Stop() 270 | conn.Close() 271 | }() 272 | for { 273 | select { 274 | case <-ctx.Done(): 275 | p.logger.Debugln("ping loop done") 276 | return 277 | case <-ticker.C: 278 | conn.SetWriteDeadline(time.Now().Add(p.pingWait)) 279 | if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { 280 | return 281 | } 282 | } 283 | } 284 | }() 285 | } 286 | // write loop -- take messages from response and write to websocket 287 | scanner := bufio.NewScanner(responseBodyR) 288 | 289 | // if maxRespBodyBufferSize has been specified, use custom buffer for scanner 290 | var scannerBuf []byte 291 | if p.maxRespBodyBufferBytes > 0 { 292 | scannerBuf = make([]byte, 0, 64*1024) 293 | scanner.Buffer(scannerBuf, p.maxRespBodyBufferBytes) 294 | } 295 | 296 | for scanner.Scan() { 297 | if len(scanner.Bytes()) == 0 { 298 | p.logger.Warnln("[write] empty scan", scanner.Err()) 299 | continue 300 | } 301 | p.logger.Debugln("[write] scanned", scanner.Text()) 302 | if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil { 303 | p.logger.Warnln("[write] error writing websocket message:", err) 304 | return 305 | } 306 | } 307 | if err := scanner.Err(); err != nil { 308 | p.logger.Warnln("scanner err:", err) 309 | } 310 | } 311 | 312 | type inMemoryResponseWriter struct { 313 | io.Writer 314 | header http.Header 315 | code int 316 | closed chan bool 317 | } 318 | 319 | func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter { 320 | return &inMemoryResponseWriter{ 321 | Writer: w, 322 | header: http.Header{}, 323 | closed: make(chan bool, 1), 324 | } 325 | } 326 | 327 | // IE and Edge do not delimit Sec-WebSocket-Protocol strings with spaces 328 | func transformSubProtocolHeader(header string) string { 329 | tokens := strings.SplitN(header, "Bearer,", 2) 330 | 331 | if len(tokens) 
< 2 { 332 | return "" 333 | } 334 | 335 | return fmt.Sprintf("Bearer %v", strings.Trim(tokens[1], " ")) 336 | } 337 | 338 | func (w *inMemoryResponseWriter) Write(b []byte) (int, error) { 339 | return w.Writer.Write(b) 340 | } 341 | func (w *inMemoryResponseWriter) Header() http.Header { 342 | return w.header 343 | } 344 | func (w *inMemoryResponseWriter) WriteHeader(code int) { 345 | w.code = code 346 | } 347 | func (w *inMemoryResponseWriter) CloseNotify() <-chan bool { 348 | return w.closed 349 | } 350 | func (w *inMemoryResponseWriter) Flush() {} 351 | --------------------------------------------------------------------------------