├── .travis.yml ├── LICENSE ├── README.md ├── client.go ├── clientstate_string.go ├── control.go ├── helper_test.go ├── httpproxy.go ├── proto ├── control_msg.go └── proto.go ├── proxy.go ├── server.go ├── spec.md ├── tcpproxy.go ├── tunnel_test.go ├── tunneltest ├── state_recorder.go └── tunneltest.go ├── util.go ├── virtualaddr.go ├── virtualhost.go └── websocket_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | sudo: false 4 | 5 | addons: 6 | apt: 7 | packages: 8 | - moreutils 9 | 10 | go: 11 | - 1.4.3 12 | - 1.6.3 13 | - 1.7 14 | 15 | script: 16 | - export GOMAXPROCS=$(nproc) 17 | - gofmt -s -l . | ifne false 18 | - go build ./... 19 | - go test -race ./... 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 The Koding Authors. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of Koding Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tunnel [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/koding/tunnel) [![Go Report Card](https://goreportcard.com/badge/github.com/koding/tunnel)](https://goreportcard.com/report/github.com/koding/tunnel) [![Build Status](http://img.shields.io/travis/koding/tunnel.svg?style=flat-square)](https://travis-ci.org/koding/tunnel) 2 | 3 | Tunnel is a server/client package that lets you proxy public connections to 4 | your local machine over a tunnel connection from the local machine to the 5 | public server. This means you can share your localhost even if it 6 | doesn't have a public IP or if it's not reachable from outside. 7 | 8 | It uses the excellent [yamux](https://github.com/hashicorp/yamux) package to 9 | multiplex connections between server and client. 10 | 11 | The project is under active development; please vendor it if you want to use it. 12 | 13 | # Usage 14 | 15 | The tunnel package consists of two parts: the `server` and the `client`. 16 | 17 | The server is the public-facing part. It's a type that satisfies the `http.Handler` 18 | interface, so it's easily pluggable into existing servers.
19 | 20 | 21 | Let's assume that you set up your DNS service so all `*.example.com` domains route 22 | to your server at the public IP `203.0.113.0`. Let us first create the server 23 | part: 24 | 25 | ```go 26 | package main 27 | 28 | import ( 29 | "net/http" 30 | 31 | "github.com/koding/tunnel" 32 | ) 33 | 34 | func main() { 35 | cfg := &tunnel.ServerConfig{} 36 | server, _ := tunnel.NewServer(cfg) 37 | server.AddHost("sub.example.com", "1234") 38 | http.ListenAndServe(":80", server) 39 | } 40 | ``` 41 | 42 | Once you create the `server`, you just plug it into your HTTP server. The only 43 | detail here is to map a virtual host to a secret token. The secret token is the 44 | only part that needs to be known on the client side. 45 | 46 | Let us now create the client-side part: 47 | 48 | ```go 49 | package main 50 | 51 | import "github.com/koding/tunnel" 52 | 53 | func main() { 54 | cfg := &tunnel.ClientConfig{ 55 | Identifier: "1234", 56 | ServerAddr: "203.0.113.0:80", 57 | } 58 | 59 | client, err := tunnel.NewClient(cfg) 60 | if err != nil { 61 | panic(err) 62 | } 63 | 64 | client.Start() 65 | } 66 | ``` 67 | 68 | The `Start()` method blocks. As you can see, we just passed the 69 | server address and the secret token. 70 | 71 | Now whenever someone hits `sub.example.com`, the request will be proxied to the 72 | machine where the client is running and hit the local server running at `127.0.0.1:80` 73 | (assuming there is one). If someone hits `sub.example.com:3000` (assuming your 74 | server is running on this port), it'll be routed to `127.0.0.1:3000`. 75 | 76 | That's it. 77 | 78 | There are many options that can be changed, such as a static local address for 79 | your client. Have a look at the 80 | [documentation](http://godoc.org/github.com/koding/tunnel). 81 | 82 | 83 | # Protocol 84 | 85 | The server/client protocol is described in the [spec.md](spec.md) file. Please 86 | have a look for more details.
87 | 88 | 89 | ## License 90 | 91 | The BSD 3-Clause License - see LICENSE for more details 92 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "io/ioutil" 8 | "net" 9 | "net/http" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | "github.com/koding/logging" 15 | "github.com/koding/tunnel/proto" 16 | 17 | "github.com/hashicorp/yamux" 18 | ) 19 | 20 | //go:generate stringer -type ClientState 21 | 22 | // ErrRedialAborted is emitted on ClientClosed event, when backoff policy 23 | // used by a client decided no more reconnection attempts must be made. 24 | var ErrRedialAborted = errors.New("unable to restore the connection, aborting") 25 | 26 | // ClientState represents client connection state to tunnel server. 27 | type ClientState uint32 28 | 29 | // ClientState enumeration. 30 | const ( 31 | ClientUnknown ClientState = iota 32 | ClientStarted 33 | ClientConnecting 34 | ClientConnected 35 | ClientDisconnected 36 | ClientClosed // keep it always last 37 | ) 38 | 39 | // ClientStateChange represents single client state transition. 40 | type ClientStateChange struct { 41 | Identifier string 42 | Previous ClientState 43 | Current ClientState 44 | Error error 45 | } 46 | 47 | // Strings implements the fmt.Stringer interface. 48 | func (cs *ClientStateChange) String() string { 49 | if cs.Error != nil { 50 | return fmt.Sprintf("[%s] %s->%s (%s)", cs.Identifier, cs.Previous, cs.Current, cs.Error) 51 | } 52 | return fmt.Sprintf("[%s] %s->%s", cs.Identifier, cs.Previous, cs.Current) 53 | } 54 | 55 | // Backoff defines behavior of staggering reconnection retries. 56 | type Backoff interface { 57 | // Next returns the duration to sleep before retrying reconnections. 58 | // If the returned value is negative, the retry is aborted. 
59 | NextBackOff() time.Duration 60 | 61 | // Reset is used to signal a reconnection was successful and next 62 | // call to Next should return desired time duration for 1st reconnection 63 | // attempt. 64 | Reset() 65 | } 66 | 67 | // Client is responsible for creating a control connection to a tunnel server, 68 | // creating new tunnels and proxy them to tunnel server. 69 | type Client struct { 70 | // underlying yamux session 71 | session *yamux.Session 72 | 73 | // config holds the ClientConfig 74 | config *ClientConfig 75 | 76 | // yamuxConfig is passed to new yamux.Session's 77 | yamuxConfig *yamux.Config 78 | 79 | // proxy handles local server communication. 80 | proxy ProxyFunc 81 | 82 | // startNotify is a chanel user can get to be notified when client is 83 | // connected to the server. The preferred way of doing this however, 84 | // would be using StateChanges in ClientConfig where user can provide 85 | // his own channel. 86 | startNotify chan bool 87 | // closed is a flag set when client calls Close() and quits. 88 | closed bool 89 | // closedMu guards both closed flag and startNotify channel. Since library 90 | // owns the channel it's cleared when trying to reconnect. 91 | closedMu sync.RWMutex 92 | 93 | reqWg sync.WaitGroup 94 | ctrlWg sync.WaitGroup 95 | 96 | state ClientState 97 | 98 | // redialBackoff is used to reconnect in exponential backoff intervals 99 | redialBackoff Backoff 100 | 101 | log logging.Logger 102 | } 103 | 104 | // ClientConfig defines the configuration for the Client 105 | type ClientConfig struct { 106 | // Identifier is the secret token that needs to be passed to the server. 107 | // Required if FetchIdentifier is not set. 108 | Identifier string 109 | 110 | // FetchIdentifier can be used to fetch identifier. Required if Identifier 111 | // is not set. 112 | FetchIdentifier func() (string, error) 113 | 114 | // ServerAddr defines the TCP address of the tunnel server to be connected. 
115 | // Required if FetchServerAddr is not set. 116 | ServerAddr string 117 | 118 | // FetchServerAddr can be used to fetch tunnel server address. 119 | // Required if ServerAddress is not set. 120 | FetchServerAddr func() (string, error) 121 | 122 | // Dial provides custom transport layer for client server communication. 123 | // 124 | // If nil, default implementation is to return net.Dial("tcp", address). 125 | // 126 | // It can be used for connection monitoring, setting different timeouts or 127 | // securing the connection. 128 | Dial func(network, address string) (net.Conn, error) 129 | 130 | // Proxy defines custom proxing logic. This is optional extension point 131 | // where you can provide your local server selection or communication rules. 132 | Proxy ProxyFunc 133 | 134 | // StateChanges receives state transition details each time client 135 | // connection state changes. The channel is expected to be sufficiently 136 | // buffered to keep up with event pace. 137 | // 138 | // If nil, no information about state transitions are dispatched 139 | // by the library. 140 | StateChanges chan<- *ClientStateChange 141 | 142 | // Backoff is used to control behavior of staggering reconnection loop. 143 | // 144 | // If nil, default backoff policy is used which makes a client to never 145 | // give up on reconnection. 146 | // 147 | // If custom backoff is used, client will emit ErrRedialAborted set 148 | // with ClientClosed event when no more reconnection atttemps should 149 | // be made. 150 | Backoff Backoff 151 | 152 | // YamuxConfig defines the config which passed to every new yamux.Session. If nil 153 | // yamux.DefaultConfig() is used. 154 | YamuxConfig *yamux.Config 155 | 156 | // Log defines the logger. If nil a default logging.Logger is used. 157 | Log logging.Logger 158 | 159 | // Debug enables debug mode, enable only if you want to debug the server. 
160 | Debug bool 161 | 162 | // DEPRECATED: 163 | 164 | // LocalAddr is DEPRECATED please use ProxyHTTP.LocalAddr, see ProxyOverwrite for more details. 165 | LocalAddr string 166 | 167 | // FetchLocalAddr is DEPRECATED please use ProxyTCP.FetchLocalAddr, see ProxyOverwrite for more details. 168 | FetchLocalAddr func(port int) (string, error) 169 | } 170 | 171 | // verify is used to verify the ClientConfig 172 | func (c *ClientConfig) verify() error { 173 | if c.ServerAddr == "" && c.FetchServerAddr == nil { 174 | return errors.New("neither ServerAddr nor FetchServerAddr is set") 175 | } 176 | 177 | if c.Identifier == "" && c.FetchIdentifier == nil { 178 | return errors.New("neither Identifier nor FetchIdentifier is set") 179 | } 180 | 181 | if c.YamuxConfig != nil { 182 | if err := yamux.VerifyConfig(c.YamuxConfig); err != nil { 183 | return err 184 | } 185 | } 186 | 187 | if c.Proxy != nil && (c.LocalAddr != "" || c.FetchLocalAddr != nil) { 188 | return errors.New("both Proxy and LocalAddr or FetchLocalAddr are set") 189 | } 190 | 191 | return nil 192 | } 193 | 194 | // NewClient creates a new tunnel that is established between the serverAddr 195 | // and localAddr. It exits if it can't create a new control connection to the 196 | // server. If localAddr is empty client will always try to proxy to a local 197 | // port. 
198 | func NewClient(cfg *ClientConfig) (*Client, error) { 199 | if err := cfg.verify(); err != nil { 200 | return nil, err 201 | } 202 | 203 | yamuxConfig := yamux.DefaultConfig() 204 | if cfg.YamuxConfig != nil { 205 | yamuxConfig = cfg.YamuxConfig 206 | } 207 | 208 | var proxy = DefaultProxy 209 | if cfg.Proxy != nil { 210 | proxy = cfg.Proxy 211 | } 212 | // DEPRECATED API SUPPORT 213 | if cfg.LocalAddr != "" || cfg.FetchLocalAddr != nil { 214 | var f ProxyFuncs 215 | if cfg.LocalAddr != "" { 216 | f.HTTP = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy 217 | f.WS = (&HTTPProxy{LocalAddr: cfg.LocalAddr}).Proxy 218 | } 219 | if cfg.FetchLocalAddr != nil { 220 | f.TCP = (&TCPProxy{FetchLocalAddr: cfg.FetchLocalAddr}).Proxy 221 | } 222 | proxy = Proxy(f) 223 | } 224 | 225 | var bo Backoff = newForeverBackoff() 226 | if cfg.Backoff != nil { 227 | bo = cfg.Backoff 228 | } 229 | 230 | log := newLogger("tunnel-client", cfg.Debug) 231 | if cfg.Log != nil { 232 | log = cfg.Log 233 | } 234 | 235 | client := &Client{ 236 | config: cfg, 237 | yamuxConfig: yamuxConfig, 238 | proxy: proxy, 239 | startNotify: make(chan bool, 1), 240 | redialBackoff: bo, 241 | log: log, 242 | } 243 | 244 | return client, nil 245 | } 246 | 247 | // Start starts the client and connects to the server with the identifier. 248 | // client.FetchIdentifier() will be used if it's not nil. It's supports 249 | // reconnecting with exponential backoff intervals when the connection to the 250 | // server disconnects. Call client.Close() to shutdown the client completely. A 251 | // successful connection will cause StartNotify() to receive a value. 
252 | func (c *Client) Start() { 253 | fetchIdent := func() (string, error) { 254 | if c.config.FetchIdentifier != nil { 255 | return c.config.FetchIdentifier() 256 | } 257 | 258 | return c.config.Identifier, nil 259 | } 260 | 261 | fetchServerAddr := func() (string, error) { 262 | if c.config.FetchServerAddr != nil { 263 | return c.config.FetchServerAddr() 264 | } 265 | 266 | return c.config.ServerAddr, nil 267 | } 268 | 269 | c.changeState(ClientStarted, nil) 270 | 271 | c.redialBackoff.Reset() 272 | var lastErr error 273 | for { 274 | prev := c.changeState(ClientConnecting, lastErr) 275 | 276 | if c.isRetry(prev) { 277 | dur := c.redialBackoff.NextBackOff() 278 | if dur < 0 { 279 | c.setClosed(true) 280 | c.changeState(ClientClosed, ErrRedialAborted) 281 | return 282 | } 283 | 284 | time.Sleep(dur) 285 | 286 | // exit if closed 287 | if c.isClosed() { 288 | c.changeState(ClientClosed, lastErr) 289 | return 290 | } 291 | } 292 | 293 | identifier, err := fetchIdent() 294 | if err != nil { 295 | lastErr = err 296 | c.log.Critical("client fetch identifier error: %s", err) 297 | continue 298 | } 299 | 300 | serverAddr, err := fetchServerAddr() 301 | if err != nil { 302 | lastErr = err 303 | c.log.Critical("client fetch server address error: %s", err) 304 | continue 305 | } 306 | 307 | c.setClosed(false) 308 | 309 | if err := c.connect(identifier, serverAddr); err != nil { 310 | lastErr = err 311 | c.log.Debug("client connect error: %s", err) 312 | } 313 | 314 | // exit if closed 315 | if c.isClosed() { 316 | c.changeState(ClientClosed, lastErr) 317 | return 318 | } 319 | } 320 | } 321 | 322 | // Close closes the client and shutdowns the connection to the tunnel server 323 | func (c *Client) Close() error { 324 | defer c.setClosed(true) 325 | 326 | if c.session == nil { 327 | return errors.New("session is not initialized") 328 | } 329 | 330 | // wait until all connections are finished 331 | waitCh := make(chan struct{}) 332 | go func() { 333 | if err := 
c.session.GoAway(); err != nil { 334 | c.log.Debug("Session go away failed: %s", err) 335 | } 336 | 337 | c.reqWg.Wait() 338 | close(waitCh) 339 | }() 340 | select { 341 | case <-waitCh: 342 | // ok 343 | case <-time.After(time.Second * 10): 344 | c.log.Info("Timeout waiting for connections to finish") 345 | } 346 | 347 | if err := c.session.Close(); err != nil { 348 | return err 349 | } 350 | 351 | return nil 352 | } 353 | 354 | // isClosed securely checks if client is marked as closed. 355 | func (c *Client) isClosed() bool { 356 | c.closedMu.RLock() 357 | defer c.closedMu.RUnlock() 358 | return c.closed 359 | } 360 | 361 | // setClosed securely marks client as closed (or not closed). If not closed 362 | // also empty the value inside the startNotify channel by retrieving it (if any), 363 | // so it doesn't block during connect, when the client was closed and started again, 364 | // and startNotify was never listened to. 365 | func (c *Client) setClosed(closed bool) { 366 | c.closedMu.Lock() 367 | defer c.closedMu.Unlock() 368 | c.closed = closed 369 | 370 | if !closed { 371 | // clear channel 372 | select { 373 | case <-c.startNotify: 374 | default: 375 | } 376 | } 377 | } 378 | 379 | // startNotifyIfNeeded sends ok to startNotify channel if it's listened to. 380 | // This function is called by connect when connection was successful. 381 | func (c *Client) startNotifyIfNeeded() { 382 | c.closedMu.RLock() 383 | if !c.closed { 384 | c.log.Debug("sending ok to startNotify chan") 385 | select { 386 | case c.startNotify <- true: 387 | default: 388 | // reaching here means the client never read the signal via 389 | // StartNotify(). This is OK, we shouldn't except it the consumer 390 | // to read from this channel. It's optional, so we just drop the 391 | // signal. 
392 | c.log.Debug("startNotify message was dropped") 393 | } 394 | } 395 | c.closedMu.RUnlock() 396 | } 397 | 398 | // StartNotify returns a channel that receives a single value when the client 399 | // established a successful connection to the server. 400 | func (c *Client) StartNotify() <-chan bool { 401 | return c.startNotify 402 | } 403 | 404 | func (c *Client) changeState(state ClientState, err error) (prev ClientState) { 405 | prev = ClientState(atomic.LoadUint32((*uint32)(&c.state))) 406 | 407 | if c.config.StateChanges != nil { 408 | change := &ClientStateChange{ 409 | Identifier: c.config.Identifier, 410 | Previous: ClientState(prev), 411 | Current: state, 412 | Error: err, 413 | } 414 | 415 | select { 416 | case c.config.StateChanges <- change: 417 | default: 418 | c.log.Warning("Dropping state change due to slow reader: %s", change) 419 | } 420 | } 421 | 422 | atomic.CompareAndSwapUint32((*uint32)(&c.state), uint32(prev), uint32(state)) 423 | 424 | return prev 425 | } 426 | 427 | func (c *Client) isRetry(state ClientState) bool { 428 | return state != ClientStarted && state != ClientClosed 429 | } 430 | 431 | func (c *Client) connect(identifier, serverAddr string) error { 432 | c.log.Debug("Trying to connect to %q with identifier %q", serverAddr, identifier) 433 | 434 | conn, err := c.dial(serverAddr) 435 | if err != nil { 436 | return err 437 | } 438 | 439 | remoteURL := controlURL(conn) 440 | c.log.Debug("CONNECT to %q", remoteURL) 441 | req, err := http.NewRequest("CONNECT", remoteURL, nil) 442 | if err != nil { 443 | return fmt.Errorf("error creating request to %s: %s", remoteURL, err) 444 | } 445 | 446 | req.Header.Set(proto.ClientIdentifierHeader, identifier) 447 | 448 | c.log.Debug("Writing request to TCP: %+v", req) 449 | 450 | if err := req.Write(conn); err != nil { 451 | return fmt.Errorf("writing CONNECT request to %s failed: %s", req.URL, err) 452 | } 453 | 454 | c.log.Debug("Reading response from TCP") 455 | 456 | resp, err := 
http.ReadResponse(bufio.NewReader(conn), req) 457 | if err != nil { 458 | return fmt.Errorf("reading CONNECT response from %s failed: %s", req.URL, err) 459 | } 460 | defer resp.Body.Close() 461 | 462 | if resp.StatusCode != http.StatusOK || resp.Status != proto.Connected { 463 | out, err := ioutil.ReadAll(resp.Body) 464 | if err != nil { 465 | return fmt.Errorf("tunnel server error: status=%d, error=%s", resp.StatusCode, err) 466 | } 467 | 468 | return fmt.Errorf("tunnel server error: status=%d, body=%s", resp.StatusCode, string(out)) 469 | } 470 | 471 | c.ctrlWg.Wait() // wait until previous listenControl observes disconnection 472 | 473 | c.session, err = yamux.Client(conn, c.yamuxConfig) 474 | if err != nil { 475 | return fmt.Errorf("session initialization failed: %s", err) 476 | } 477 | 478 | var stream net.Conn 479 | openStream := func() error { 480 | // this is blocking until client opens a session to us 481 | stream, err = c.session.Open() 482 | return err 483 | } 484 | 485 | // if we don't receive anything from the server, we'll timeout 486 | select { 487 | case err := <-async(openStream): 488 | if err != nil { 489 | return fmt.Errorf("waiting for session to open failed: %s", err) 490 | } 491 | case <-time.After(time.Second * 10): 492 | if stream != nil { 493 | stream.Close() 494 | } 495 | return errors.New("timeout opening session") 496 | } 497 | 498 | if _, err := stream.Write([]byte(proto.HandshakeRequest)); err != nil { 499 | return fmt.Errorf("writing handshake request failed: %s", err) 500 | } 501 | 502 | buf := make([]byte, len(proto.HandshakeResponse)) 503 | if _, err := stream.Read(buf); err != nil { 504 | return fmt.Errorf("reading handshake response failed: %s", err) 505 | } 506 | 507 | if string(buf) != proto.HandshakeResponse { 508 | return fmt.Errorf("invalid handshake response, received: %s", string(buf)) 509 | } 510 | 511 | ct := newControl(stream) 512 | c.log.Debug("client has started successfully") 513 | c.redialBackoff.Reset() // we 
successfully connected, so we can reset the backoff 514 | 515 | c.startNotifyIfNeeded() 516 | 517 | return c.listenControl(ct) 518 | } 519 | 520 | func (c *Client) dial(serverAddr string) (net.Conn, error) { 521 | if c.config.Dial != nil { 522 | return c.config.Dial("tcp", serverAddr) 523 | } 524 | 525 | return net.Dial("tcp", serverAddr) 526 | } 527 | 528 | func (c *Client) listenControl(ct *control) error { 529 | c.ctrlWg.Add(1) 530 | defer c.ctrlWg.Done() 531 | 532 | c.changeState(ClientConnected, nil) 533 | 534 | for { 535 | var msg proto.ControlMessage 536 | if err := ct.dec.Decode(&msg); err != nil { 537 | c.reqWg.Wait() // wait until all requests are finished 538 | c.session.GoAway() 539 | c.session.Close() 540 | c.changeState(ClientDisconnected, err) 541 | 542 | return fmt.Errorf("failure decoding control message: %s", err) 543 | } 544 | 545 | c.log.Debug("Received control msg %+v", msg) 546 | c.log.Debug("Opening a new stream from server session") 547 | 548 | remote, err := c.session.Open() 549 | if err != nil { 550 | return err 551 | } 552 | 553 | isHTTP := msg.Protocol == proto.HTTP 554 | if isHTTP { 555 | c.reqWg.Add(1) 556 | } 557 | go func() { 558 | c.proxy(remote, &msg) 559 | if isHTTP { 560 | c.reqWg.Done() 561 | } 562 | remote.Close() 563 | }() 564 | } 565 | } 566 | -------------------------------------------------------------------------------- /clientstate_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type ClientState"; DO NOT EDIT 2 | 3 | package tunnel 4 | 5 | import "fmt" 6 | 7 | const _ClientState_name = "ClientUnknownClientStartedClientConnectingClientConnectedClientDisconnectedClientClosed" 8 | 9 | var _ClientState_index = [...]uint8{0, 13, 26, 42, 57, 75, 87} 10 | 11 | func (i ClientState) String() string { 12 | if i >= ClientState(len(_ClientState_index)-1) { 13 | return fmt.Sprintf("ClientState(%d)", i) 14 | } 15 | return 
_ClientState_name[_ClientState_index[i]:_ClientState_index[i+1]] 16 | } 17 | -------------------------------------------------------------------------------- /control.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "net" 7 | "sync" 8 | ) 9 | 10 | var errControlClosed = errors.New("control connection is closed") 11 | 12 | type control struct { 13 | // enc and dec are responsible for encoding and decoding json values forth 14 | // and back 15 | enc *json.Encoder 16 | dec *json.Decoder 17 | 18 | // underlying connection responsible for encoder and decoder 19 | nc net.Conn 20 | 21 | // identifier associated with this control 22 | identifier string 23 | 24 | mu sync.Mutex // guards the following 25 | closed bool // if Close() and quits 26 | } 27 | 28 | func newControl(nc net.Conn) *control { 29 | c := &control{ 30 | enc: json.NewEncoder(nc), 31 | dec: json.NewDecoder(nc), 32 | nc: nc, 33 | } 34 | 35 | return c 36 | } 37 | 38 | func (c *control) send(v interface{}) error { 39 | if c.enc == nil { 40 | return errors.New("encoder is not initialized") 41 | } 42 | 43 | c.mu.Lock() 44 | if c.closed { 45 | c.mu.Unlock() 46 | return errControlClosed 47 | } 48 | c.mu.Unlock() 49 | 50 | return c.enc.Encode(v) 51 | } 52 | 53 | func (c *control) recv(v interface{}) error { 54 | if c.dec == nil { 55 | return errors.New("decoder is not initialized") 56 | } 57 | 58 | c.mu.Lock() 59 | if c.closed { 60 | c.mu.Unlock() 61 | return errControlClosed 62 | } 63 | c.mu.Unlock() 64 | 65 | return c.dec.Decode(v) 66 | } 67 | 68 | func (c *control) Close() error { 69 | if c.nc == nil { 70 | return nil 71 | } 72 | 73 | c.mu.Lock() 74 | c.closed = true 75 | c.mu.Unlock() 76 | 77 | return c.nc.Close() 78 | } 79 | 80 | type controls struct { 81 | sync.Mutex 82 | controls map[string]*control 83 | } 84 | 85 | func newControls() *controls { 86 | return &controls{ 87 | controls: 
make(map[string]*control), 88 | } 89 | } 90 | 91 | func (c *controls) getControl(identifier string) (*control, bool) { 92 | c.Lock() 93 | control, ok := c.controls[identifier] 94 | c.Unlock() 95 | return control, ok 96 | } 97 | 98 | func (c *controls) addControl(identifier string, control *control) { 99 | control.identifier = identifier 100 | 101 | c.Lock() 102 | c.controls[identifier] = control 103 | c.Unlock() 104 | } 105 | 106 | func (c *controls) deleteControl(identifier string) { 107 | c.Lock() 108 | delete(c.controls, identifier) 109 | c.Unlock() 110 | } 111 | -------------------------------------------------------------------------------- /helper_test.go: -------------------------------------------------------------------------------- 1 | package tunnel_test 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "log" 10 | "math/rand" 11 | "net" 12 | "net/http" 13 | "net/url" 14 | "os" 15 | "time" 16 | 17 | "github.com/koding/tunnel" 18 | "github.com/koding/tunnel/tunneltest" 19 | 20 | "github.com/gorilla/websocket" 21 | ) 22 | 23 | func init() { 24 | rand.Seed(time.Now().UnixNano() + int64(os.Getpid())) 25 | } 26 | 27 | var upgrader = websocket.Upgrader{ 28 | ReadBufferSize: 1024, 29 | WriteBufferSize: 1024, 30 | } 31 | 32 | type EchoMessage struct { 33 | Value string `json:"value,omitempty"` 34 | Close bool `json:"close,omitempty"` 35 | } 36 | 37 | var timeout = 10 * time.Second 38 | 39 | var dialer = &websocket.Dialer{ 40 | ReadBufferSize: 1024, 41 | WriteBufferSize: 1024, 42 | HandshakeTimeout: timeout, 43 | NetDial: func(_, addr string) (net.Conn, error) { 44 | return net.DialTimeout("tcp4", addr, timeout) 45 | }, 46 | } 47 | 48 | func echoHTTP(tt *tunneltest.TunnelTest, echo string) (string, error) { 49 | req := tt.Request("http", url.Values{"echo": []string{echo}}) 50 | if req == nil { 51 | return "", fmt.Errorf(`tunnel "http" does not exist`) 52 | } 53 | 54 | req.Close = rand.Int()%2 == 0 55 | 56 | resp, err := 
http.DefaultClient.Do(req) 57 | if err != nil { 58 | return "", err 59 | } 60 | defer resp.Body.Close() 61 | 62 | p, err := ioutil.ReadAll(resp.Body) 63 | if err != nil { 64 | return "", err 65 | } 66 | 67 | return string(bytes.TrimSpace(p)), nil 68 | } 69 | 70 | func echoTCP(tt *tunneltest.TunnelTest, echo string) (string, error) { 71 | return echoTCPIdent(tt, echo, "tcp") 72 | } 73 | 74 | func echoTCPIdent(tt *tunneltest.TunnelTest, echo, ident string) (string, error) { 75 | addr := tt.Addr(ident) 76 | if addr == nil { 77 | return "", fmt.Errorf("tunnel %q does not exist", ident) 78 | } 79 | s := addr.String() 80 | ip := tt.Tunnels[ident].IP 81 | if ip != nil { 82 | _, port, err := net.SplitHostPort(s) 83 | if err != nil { 84 | return "", err 85 | } 86 | s = net.JoinHostPort(ip.String(), port) 87 | } 88 | 89 | c, err := dialTCP(s) 90 | if err != nil { 91 | return "", err 92 | } 93 | 94 | c.out <- echo 95 | 96 | select { 97 | case reply := <-c.in: 98 | return reply, nil 99 | case <-time.After(tcpTimeout): 100 | return "", fmt.Errorf("timed out waiting for reply from %s (%s) after %v", s, addr, tcpTimeout) 101 | } 102 | } 103 | 104 | func websocketDial(tt *tunneltest.TunnelTest, ident string) (*websocket.Conn, error) { 105 | req := tt.Request(ident, nil) 106 | if req == nil { 107 | return nil, fmt.Errorf("no client found for ident %q", ident) 108 | } 109 | 110 | h := http.Header{"Host": {req.Host}} 111 | wsurl := fmt.Sprintf("ws://%s", tt.ServerAddr()) 112 | 113 | conn, _, err := dialer.Dial(wsurl, h) 114 | return conn, err 115 | } 116 | 117 | func sleep() { 118 | time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond) 119 | } 120 | 121 | func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error { 122 | return func(w http.ResponseWriter, r *http.Request) (e error) { 123 | conn, err := upgrader.Upgrade(w, r, nil) 124 | if err != nil { 125 | http.Error(w, err.Error(), http.StatusInternalServerError) 126 | return err 127 | } 128 | 
defer func() { 129 | err := conn.Close() 130 | if e == nil { 131 | e = err 132 | } 133 | }() 134 | 135 | if sleepFn != nil { 136 | sleepFn() 137 | } 138 | 139 | for { 140 | var msg EchoMessage 141 | err := conn.ReadJSON(&msg) 142 | if err != nil { 143 | return fmt.Errorf("ReadJSON error: %s", err) 144 | } 145 | 146 | if sleepFn != nil { 147 | sleepFn() 148 | } 149 | 150 | err = conn.WriteJSON(&msg) 151 | if err != nil { 152 | return fmt.Errorf("WriteJSON error: %s", err) 153 | } 154 | 155 | if msg.Close { 156 | return nil 157 | } 158 | } 159 | } 160 | } 161 | 162 | func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) { 163 | io.WriteString(w, r.URL.Query().Get("echo")) 164 | } 165 | 166 | func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) { 167 | sleep() 168 | handlerEchoHTTP(w, r) 169 | } 170 | 171 | func handlerEchoTCP(conn net.Conn) { 172 | io.Copy(conn, conn) 173 | } 174 | 175 | func handlerLatencyEchoTCP(conn net.Conn) { 176 | sleep() 177 | handlerEchoTCP(conn) 178 | } 179 | 180 | var tcpTimeout = 10 * time.Second 181 | 182 | type tcpClient struct { 183 | conn net.Conn 184 | scanner *bufio.Scanner 185 | in chan string 186 | out chan string 187 | } 188 | 189 | func (c *tcpClient) loop() { 190 | for out := range c.out { 191 | if _, err := fmt.Fprintln(c.conn, out); err != nil { 192 | log.Printf("[tunnelclient] error writing %q to %q: %s", out, c.conn.RemoteAddr(), err) 193 | return 194 | } 195 | 196 | if !c.scanner.Scan() { 197 | log.Printf("[tunnelclient] error reading from %q: %v", c.conn.RemoteAddr(), c.scanner.Err()) 198 | return 199 | } 200 | 201 | c.in <- c.scanner.Text() 202 | } 203 | } 204 | 205 | func (c *tcpClient) Close() error { 206 | close(c.out) 207 | return c.conn.Close() 208 | } 209 | 210 | func dialTCP(addr string) (*tcpClient, error) { 211 | conn, err := net.DialTimeout("tcp", addr, tcpTimeout) 212 | if err != nil { 213 | return nil, err 214 | } 215 | 216 | c := &tcpClient{ 217 | conn: conn, 218 | scanner: 
bufio.NewScanner(conn), 219 | in: make(chan string, 1), 220 | out: make(chan string, 1), 221 | } 222 | 223 | go c.loop() 224 | 225 | return c, nil 226 | } 227 | 228 | func singleHTTP(handler interface{}) map[string]*tunneltest.Tunnel { 229 | return singleRecHTTP(handler, nil) 230 | } 231 | 232 | func singleRecHTTP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel { 233 | return map[string]*tunneltest.Tunnel{ 234 | "http": { 235 | Type: tunneltest.TypeHTTP, 236 | LocalAddr: "127.0.0.1:0", 237 | Handler: handler, 238 | StateChanges: stateChanges, 239 | }, 240 | } 241 | } 242 | 243 | func singleTCP(handler interface{}) map[string]*tunneltest.Tunnel { 244 | return singleRecTCP(handler, nil) 245 | } 246 | 247 | func singleRecTCP(handler interface{}, stateChanges chan<- *tunnel.ClientStateChange) map[string]*tunneltest.Tunnel { 248 | return map[string]*tunneltest.Tunnel{ 249 | "http": { 250 | Type: tunneltest.TypeHTTP, 251 | LocalAddr: "127.0.0.1:0", 252 | Handler: handlerEchoHTTP, 253 | StateChanges: stateChanges, 254 | }, 255 | "tcp": { 256 | Type: tunneltest.TypeTCP, 257 | ClientIdent: "http", 258 | LocalAddr: "127.0.0.1:0", 259 | RemoteAddr: "127.0.0.1:0", 260 | Handler: handler, 261 | }, 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /httpproxy.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "net" 9 | "net/http" 10 | 11 | "github.com/koding/logging" 12 | "github.com/koding/tunnel/proto" 13 | ) 14 | 15 | var ( 16 | httpLog = logging.NewLogger("http") 17 | ) 18 | 19 | // HTTPProxy forwards HTTP traffic. 20 | // 21 | // When tunnel server requests a connection it's proxied to 127.0.0.1:incomingPort 22 | // where incomingPort is control message LocalPort. 23 | // Usually this is tunnel server's public exposed Port. 
24 | // This behaviour can be changed by setting LocalAddr or FetchLocalAddr. 25 | // FetchLocalAddr takes precedence over LocalAddr. 26 | // 27 | // When connection to local server cannot be established proxy responds with http error message. 28 | type HTTPProxy struct { 29 | // LocalAddr defines the TCP address of the local server. 30 | // This is optional if you want to specify a single TCP address. 31 | LocalAddr string 32 | // FetchLocalAddr is used for looking up TCP address of the server. 33 | // This is optional if you want to specify a dynamic TCP address based on incommig port. 34 | FetchLocalAddr func(port int) (string, error) 35 | // ErrorResp is custom response send to tunnel server when client cannot 36 | // establish connection to local server. If not set a default "no local server" 37 | // response is sent. 38 | ErrorResp *http.Response 39 | // Log is a custom logger that can be used for the proxy. 40 | // If not set a "http" logger is used. 41 | Log logging.Logger 42 | } 43 | 44 | // Proxy is a ProxyFunc. 
45 | func (p *HTTPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) { 46 | if msg.Protocol != proto.HTTP && msg.Protocol != proto.WS { 47 | panic("Proxy mismatch") 48 | } 49 | 50 | var log = p.log() 51 | 52 | var port = msg.LocalPort 53 | if port == 0 { 54 | port = 80 55 | } 56 | 57 | var localAddr = fmt.Sprintf("127.0.0.1:%d", port) 58 | if p.LocalAddr != "" { 59 | localAddr = p.LocalAddr 60 | } else if p.FetchLocalAddr != nil { 61 | l, err := p.FetchLocalAddr(msg.LocalPort) 62 | if err != nil { 63 | log.Warning("Failed to get custom local address: %s", err) 64 | p.sendError(remote) 65 | return 66 | } 67 | localAddr = l 68 | } 69 | 70 | log.Debug("Dialing local server %q", localAddr) 71 | local, err := net.DialTimeout("tcp", localAddr, defaultTimeout) 72 | if err != nil { 73 | log.Error("Dialing local server %q failed: %s", localAddr, err) 74 | p.sendError(remote) 75 | return 76 | } 77 | 78 | Join(local, remote, log) 79 | } 80 | 81 | func (p *HTTPProxy) sendError(remote net.Conn) { 82 | var w = noLocalServer() 83 | if p.ErrorResp != nil { 84 | w = p.ErrorResp 85 | } 86 | 87 | buf := new(bytes.Buffer) 88 | w.Write(buf) 89 | if _, err := io.Copy(remote, buf); err != nil { 90 | var log = p.log() 91 | log.Debug("Copy in-mem response error: %s", err) 92 | } 93 | 94 | remote.Close() 95 | } 96 | 97 | func noLocalServer() *http.Response { 98 | body := bytes.NewBufferString("no local server") 99 | return &http.Response{ 100 | Status: http.StatusText(http.StatusServiceUnavailable), 101 | StatusCode: http.StatusServiceUnavailable, 102 | Proto: "HTTP/1.1", 103 | ProtoMajor: 1, 104 | ProtoMinor: 1, 105 | Body: ioutil.NopCloser(body), 106 | ContentLength: int64(body.Len()), 107 | } 108 | } 109 | 110 | func (p *HTTPProxy) log() logging.Logger { 111 | if p.Log != nil { 112 | return p.Log 113 | } 114 | return httpLog 115 | } 116 | -------------------------------------------------------------------------------- /proto/control_msg.go: 
-------------------------------------------------------------------------------- 1 | package proto 2 | 3 | // ControlMessage is sent from server to client to establish tunneled connection. 4 | type ControlMessage struct { 5 | Action Action `json:"action"` 6 | Protocol Type `json:"transportProtocol"` 7 | LocalPort int `json:"localPort"` 8 | } 9 | 10 | // Action represents type of ControlMsg request. 11 | type Action int 12 | 13 | // ControlMessage actions. 14 | const ( 15 | RequestClientSession Action = iota + 1 16 | ) 17 | 18 | // Type represents tunneled connection type. 19 | type Type int 20 | 21 | // ControlMessage protocols. 22 | const ( 23 | HTTP Type = iota + 1 24 | TCP 25 | WS 26 | ) 27 | -------------------------------------------------------------------------------- /proto/proto.go: -------------------------------------------------------------------------------- 1 | // Package proto defines tunnel client server communication protocol. 2 | package proto 3 | 4 | const ( 5 | // ControlPath is http.Handler url path for control connection. 6 | ControlPath = "/_controlPath/" 7 | 8 | // ClientIdentifierHeader is header carrying information about tunnel identifier. 9 | ClientIdentifierHeader = "X-KTunnel-Identifier" 10 | 11 | // control messages 12 | 13 | // Connected is message sent by server to client when control connection was established. 14 | Connected = "200 Connected to Tunnel" 15 | // HandshakeRequest is hello message sent by client to server. 16 | HandshakeRequest = "controlHandshake" 17 | // HandshakeResponse is response to HandshakeRequest sent by server to client. 
18 | HandshakeResponse = "controlOk" 19 | ) 20 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "sync" 7 | 8 | "github.com/koding/logging" 9 | "github.com/koding/tunnel/proto" 10 | ) 11 | 12 | // ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back. 13 | type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage) 14 | 15 | var ( 16 | // DefaultProxyFuncs holds global default proxy functions for all transport protocols. 17 | DefaultProxyFuncs = ProxyFuncs{ 18 | HTTP: new(HTTPProxy).Proxy, 19 | TCP: new(TCPProxy).Proxy, 20 | WS: new(HTTPProxy).Proxy, 21 | } 22 | // DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs. 23 | DefaultProxy = Proxy(ProxyFuncs{}) 24 | ) 25 | 26 | // ProxyFuncs is a collection of ProxyFunc. 27 | type ProxyFuncs struct { 28 | // HTTP is custom implementation of HTTP proxing. 29 | HTTP ProxyFunc 30 | // TCP is custom implementation of TCP proxing. 31 | TCP ProxyFunc 32 | // WS is custom implementation of web socket proxing. 33 | WS ProxyFunc 34 | } 35 | 36 | // Proxy returns a ProxyFunc that uses custom function if provided, otherwise falls back to DefaultProxyFuncs. 
37 | func Proxy(p ProxyFuncs) ProxyFunc { 38 | return func(remote net.Conn, msg *proto.ControlMessage) { 39 | var f ProxyFunc 40 | switch msg.Protocol { 41 | case proto.HTTP: 42 | f = DefaultProxyFuncs.HTTP 43 | if p.HTTP != nil { 44 | f = p.HTTP 45 | } 46 | case proto.TCP: 47 | f = DefaultProxyFuncs.TCP 48 | if p.TCP != nil { 49 | f = p.TCP 50 | } 51 | case proto.WS: 52 | f = DefaultProxyFuncs.WS 53 | if p.WS != nil { 54 | f = p.WS 55 | } 56 | } 57 | 58 | if f == nil { 59 | logging.Error("Could not determine proxy function for %v", msg) 60 | remote.Close() 61 | } 62 | 63 | f(remote, msg) 64 | } 65 | } 66 | 67 | // Join copies data between local and remote connections. 68 | // It reads from one connection and writes to the other. 69 | // It's a building block for ProxyFunc implementations. 70 | func Join(local, remote net.Conn, log logging.Logger) { 71 | var wg sync.WaitGroup 72 | wg.Add(2) 73 | 74 | transfer := func(side string, dst, src net.Conn) { 75 | log.Debug("proxing %s -> %s", src.RemoteAddr(), dst.RemoteAddr()) 76 | 77 | n, err := io.Copy(dst, src) 78 | if err != nil { 79 | log.Error("%s: copy error: %s", side, err) 80 | } 81 | 82 | if err := src.Close(); err != nil { 83 | log.Debug("%s: close error: %s", side, err) 84 | } 85 | 86 | // not for yamux streams, but for client to local server connections 87 | if d, ok := dst.(*net.TCPConn); ok { 88 | if err := d.CloseWrite(); err != nil { 89 | log.Debug("%s: closeWrite error: %s", side, err) 90 | } 91 | 92 | } 93 | wg.Done() 94 | log.Debug("done proxing %s -> %s: %d bytes", src.RemoteAddr(), dst.RemoteAddr(), n) 95 | } 96 | 97 | go transfer("remote to local", local, remote) 98 | go transfer("local to remote", remote, local) 99 | 100 | wg.Wait() 101 | } 102 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // Package tunnel is a server/client package that enables to proxy public 2 | 
// connections to your local machine over a tunnel connection from the local 3 | // machine to the public server. 4 | package tunnel 5 | 6 | import ( 7 | "bufio" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net" 12 | "net/http" 13 | "os" 14 | "path" 15 | "strconv" 16 | "strings" 17 | "sync" 18 | "time" 19 | 20 | "github.com/koding/logging" 21 | "github.com/koding/tunnel/proto" 22 | 23 | "github.com/hashicorp/yamux" 24 | ) 25 | 26 | var ( 27 | errNoClientSession = errors.New("no client session established") 28 | defaultTimeout = 10 * time.Second 29 | ) 30 | 31 | // Server is responsible for proxying public connections to the client over a 32 | // tunnel connection. It also listens to control messages from the client. 33 | type Server struct { 34 | // pending contains the channel that is associated with each new tunnel request. 35 | pending map[string]chan net.Conn 36 | // pendingMu protects the pending map. 37 | pendingMu sync.Mutex 38 | 39 | // sessions contains a session per virtual host. 40 | // Sessions provides multiplexing over one connection. 41 | sessions map[string]*yamux.Session 42 | // sessionsMu protects sessions. 43 | sessionsMu sync.Mutex 44 | 45 | // controls contains the control connection from the client to the server. 46 | controls *controls 47 | 48 | // virtualHosts is used to map public hosts to remote clients. 49 | virtualHosts vhostStorage 50 | 51 | // virtualAddrs. 52 | virtualAddrs *vaddrStorage 53 | 54 | // connCh is used to publish accepted connections for tcp tunnels. 55 | connCh chan net.Conn 56 | 57 | // onConnectCallbacks contains client callbacks called when control 58 | // session is established for a client with given identifier. 59 | onConnectCallbacks *callbacks 60 | 61 | // onDisconnectCallbacks contains client callbacks called when control 62 | // session is closed for a client with given identifier. 63 | onDisconnectCallbacks *callbacks 64 | 65 | // states represents current clients' connections state. 
66 | states map[string]ClientState 67 | // statesMu protects states. 68 | statesMu sync.RWMutex 69 | // stateCh notifies receiver about client state changes. 70 | stateCh chan<- *ClientStateChange 71 | 72 | // httpDirector is provided by ServerConfig, if not nil decorates http requests 73 | // before forwarding them to client. 74 | httpDirector func(*http.Request) 75 | 76 | // yamuxConfig is passed to new yamux.Session's 77 | yamuxConfig *yamux.Config 78 | 79 | log logging.Logger 80 | } 81 | 82 | // ServerConfig defines the configuration for the Server 83 | type ServerConfig struct { 84 | // StateChanges receives state transition details each time client 85 | // connection state changes. The channel is expected to be sufficiently 86 | // buffered to keep up with event pace. 87 | // 88 | // If nil, no information about state transitions are dispatched 89 | // by the library. 90 | StateChanges chan<- *ClientStateChange 91 | 92 | // Director is a function that modifies HTTP request into a new HTTP request 93 | // before sending to client. If nil no modifications are done. 94 | Director func(*http.Request) 95 | 96 | // Debug enables debug mode, enable only if you want to debug the server 97 | Debug bool 98 | 99 | // Log defines the logger. If nil a default logging.Logger is used. 100 | Log logging.Logger 101 | 102 | // YamuxConfig defines the config which passed to every new yamux.Session. If nil 103 | // yamux.DefaultConfig() is used. 104 | YamuxConfig *yamux.Config 105 | } 106 | 107 | // NewServer creates a new Server. The defaults are used if config is nil. 
108 | func NewServer(cfg *ServerConfig) (*Server, error) { 109 | yamuxConfig := yamux.DefaultConfig() 110 | if cfg.YamuxConfig != nil { 111 | if err := yamux.VerifyConfig(cfg.YamuxConfig); err != nil { 112 | return nil, err 113 | } 114 | 115 | yamuxConfig = cfg.YamuxConfig 116 | } 117 | 118 | log := newLogger("tunnel-server", cfg.Debug) 119 | if cfg.Log != nil { 120 | log = cfg.Log 121 | } 122 | 123 | connCh := make(chan net.Conn) 124 | 125 | opts := &vaddrOptions{ 126 | connCh: connCh, 127 | log: log, 128 | } 129 | 130 | s := &Server{ 131 | pending: make(map[string]chan net.Conn), 132 | sessions: make(map[string]*yamux.Session), 133 | onConnectCallbacks: newCallbacks("OnConnect"), 134 | onDisconnectCallbacks: newCallbacks("OnDisconnect"), 135 | virtualHosts: newVirtualHosts(), 136 | virtualAddrs: newVirtualAddrs(opts), 137 | controls: newControls(), 138 | states: make(map[string]ClientState), 139 | stateCh: cfg.StateChanges, 140 | httpDirector: cfg.Director, 141 | yamuxConfig: yamuxConfig, 142 | connCh: connCh, 143 | log: log, 144 | } 145 | 146 | go s.serveTCP() 147 | 148 | return s, nil 149 | } 150 | 151 | // ServeHTTP is a tunnel that creates an http/websocket tunnel between a 152 | // public connection and the client connection. 153 | func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { 154 | // if the user didn't add the control and tunnel handler manually, we'll 155 | // going to infer and call the respective path handlers. 
156 | switch path.Clean(r.URL.Path) + "/" { 157 | case proto.ControlPath: 158 | s.checkConnect(s.controlHandler).ServeHTTP(w, r) 159 | return 160 | } 161 | 162 | if err := s.handleHTTP(w, r); err != nil { 163 | if !strings.Contains(err.Error(), "no virtual host available") { // this one is outputted too much, unnecessarily 164 | s.log.Error("remote %s (%s): %s", r.RemoteAddr, r.RequestURI, err) 165 | } 166 | http.Error(w, err.Error(), http.StatusBadGateway) 167 | } 168 | } 169 | 170 | // handleHTTP handles a single HTTP request 171 | func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) error { 172 | s.log.Debug("HandleHTTP request:") 173 | s.log.Debug("%v", r) 174 | 175 | if s.httpDirector != nil { 176 | s.httpDirector(r) 177 | } 178 | 179 | hostPort := strings.ToLower(r.Host) 180 | if hostPort == "" { 181 | return errors.New("request host is empty") 182 | } 183 | 184 | // if someone hits foo.example.com:8080, this should be proxied to 185 | // localhost:8080, so send the port to the client so it knows how to proxy 186 | // correctly. 
If no port is available, it's up to client how to interpret it 187 | host, port, err := parseHostPort(hostPort) 188 | if err != nil { 189 | // no need to return, just continue lazily, port will be 0, which in 190 | // our case will be proxied to client's local servers port 80 191 | s.log.Debug("No port available for %q, sending port 80 to client", hostPort) 192 | } 193 | 194 | // get the identifier associated with this host 195 | identifier, ok := s.getIdentifier(hostPort) 196 | if !ok { 197 | // fallback to host 198 | identifier, ok = s.getIdentifier(host) 199 | if !ok { 200 | return fmt.Errorf("no virtual host available for %q", hostPort) 201 | } 202 | } 203 | 204 | if isWebsocketConn(r) { 205 | s.log.Debug("handling websocket connection") 206 | 207 | return s.handleWSConn(w, r, identifier, port) 208 | } 209 | 210 | stream, err := s.dial(identifier, proto.HTTP, port) 211 | if err != nil { 212 | return err 213 | } 214 | defer func() { 215 | s.log.Debug("Closing stream") 216 | stream.Close() 217 | }() 218 | 219 | if err := r.Write(stream); err != nil { 220 | return err 221 | } 222 | 223 | s.log.Debug("Session opened to client, writing request to client") 224 | resp, err := http.ReadResponse(bufio.NewReader(stream), r) 225 | if err != nil { 226 | return fmt.Errorf("read from tunnel: %s", err.Error()) 227 | } 228 | 229 | defer func() { 230 | if resp.Body != nil { 231 | if err := resp.Body.Close(); err != nil && err != io.ErrUnexpectedEOF { 232 | s.log.Error("resp.Body Close error: %s", err.Error()) 233 | } 234 | } 235 | }() 236 | 237 | s.log.Debug("Response received, writing back to public connection: %+v", resp) 238 | 239 | copyHeader(w.Header(), resp.Header) 240 | w.WriteHeader(resp.StatusCode) 241 | 242 | if _, err := io.Copy(w, resp.Body); err != nil { 243 | if err == io.ErrUnexpectedEOF { 244 | s.log.Debug("Client closed the connection, couldn't copy response") 245 | } else { 246 | s.log.Error("copy err: %s", err) // do not return, because we might write multipe 
headers 247 | } 248 | } 249 | 250 | return nil 251 | } 252 | 253 | func (s *Server) serveTCP() { 254 | for conn := range s.connCh { 255 | go s.serveTCPConn(conn) 256 | } 257 | } 258 | 259 | func (s *Server) serveTCPConn(conn net.Conn) { 260 | err := s.handleTCPConn(conn) 261 | if err != nil { 262 | s.log.Warning("failed to serve %q: %s", conn.RemoteAddr(), err) 263 | conn.Close() 264 | } 265 | } 266 | 267 | func (s *Server) handleWSConn(w http.ResponseWriter, r *http.Request, ident string, port int) error { 268 | hj, ok := w.(http.Hijacker) 269 | if !ok { 270 | return fmt.Errorf("webserver doesn't support hijacking: %T", w) 271 | } 272 | 273 | conn, _, err := hj.Hijack() 274 | if err != nil { 275 | return fmt.Errorf("hijack not possible: %s", err) 276 | } 277 | 278 | stream, err := s.dial(ident, proto.WS, port) 279 | if err != nil { 280 | return err 281 | } 282 | 283 | if err := r.Write(stream); err != nil { 284 | err = errors.New("unable to write upgrade request: " + err.Error()) 285 | return nonil(err, stream.Close()) 286 | } 287 | 288 | resp, err := http.ReadResponse(bufio.NewReader(stream), r) 289 | if err != nil { 290 | err = errors.New("unable to read upgrade response: " + err.Error()) 291 | return nonil(err, stream.Close()) 292 | } 293 | 294 | if err := resp.Write(conn); err != nil { 295 | err = errors.New("unable to write upgrade response: " + err.Error()) 296 | return nonil(err, stream.Close()) 297 | } 298 | 299 | var wg sync.WaitGroup 300 | wg.Add(2) 301 | 302 | go s.proxy(&wg, conn, stream) 303 | go s.proxy(&wg, stream, conn) 304 | 305 | wg.Wait() 306 | 307 | return nonil(stream.Close(), conn.Close()) 308 | } 309 | 310 | func (s *Server) handleTCPConn(conn net.Conn) error { 311 | ident, ok := s.virtualAddrs.getIdent(conn) 312 | if !ok { 313 | return fmt.Errorf("no virtual address available for %s", conn.LocalAddr()) 314 | } 315 | 316 | _, port, err := parseHostPort(conn.LocalAddr().String()) 317 | if err != nil { 318 | return err 319 | } 320 | 321 | 
stream, err := s.dial(ident, proto.TCP, port) 322 | if err != nil { 323 | return err 324 | } 325 | 326 | var wg sync.WaitGroup 327 | wg.Add(2) 328 | 329 | go s.proxy(&wg, conn, stream) 330 | go s.proxy(&wg, stream, conn) 331 | 332 | wg.Wait() 333 | 334 | return nonil(stream.Close(), conn.Close()) 335 | } 336 | 337 | func (s *Server) proxy(wg *sync.WaitGroup, dst, src net.Conn) { 338 | defer wg.Done() 339 | 340 | s.log.Debug("tunneling %s -> %s", src.RemoteAddr(), dst.RemoteAddr()) 341 | n, err := io.Copy(dst, src) 342 | s.log.Debug("tunneled %d bytes %s -> %s: %v", n, src.RemoteAddr(), dst.RemoteAddr(), err) 343 | } 344 | 345 | func (s *Server) dial(identifier string, p proto.Type, port int) (net.Conn, error) { 346 | control, ok := s.getControl(identifier) 347 | if !ok { 348 | return nil, errNoClientSession 349 | } 350 | 351 | session, err := s.getSession(identifier) 352 | if err != nil { 353 | return nil, err 354 | } 355 | 356 | msg := proto.ControlMessage{ 357 | Action: proto.RequestClientSession, 358 | Protocol: p, 359 | LocalPort: port, 360 | } 361 | 362 | s.log.Debug("Sending control msg %+v", msg) 363 | 364 | // ask client to open a session to us, so we can accept it 365 | if err := control.send(msg); err != nil { 366 | // we might have several issues here, either the stream is closed, or 367 | // the session is going be shut down, the underlying connection might 368 | // be broken. In all cases, it's not reliable anymore having a client 369 | // session. 
370 | control.Close() 371 | s.deleteControl(identifier) 372 | return nil, errNoClientSession 373 | } 374 | 375 | var stream net.Conn 376 | acceptStream := func() error { 377 | stream, err = session.Accept() 378 | return err 379 | } 380 | 381 | // if we don't receive anything from the client, we'll timeout 382 | s.log.Debug("Waiting for session accept") 383 | 384 | select { 385 | case err := <-async(acceptStream): 386 | return stream, err 387 | case <-time.After(defaultTimeout): 388 | return nil, errors.New("timeout getting session") 389 | } 390 | } 391 | 392 | // controlHandler is used to capture incoming tunnel connect requests into raw 393 | // tunnel TCP connections. 394 | func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) { 395 | identifier := r.Header.Get(proto.ClientIdentifierHeader) 396 | _, ok := s.getHost(identifier) 397 | if !ok { 398 | return fmt.Errorf("no host associated for identifier %s. please use server.AddHost()", identifier) 399 | } 400 | 401 | ct, ok := s.getControl(identifier) 402 | if ok { 403 | ct.Close() 404 | s.deleteControl(identifier) 405 | s.deleteSession(identifier) 406 | s.log.Warning("Control connection for %q already exists. This is a race condition and needs to be fixed on client implementation", identifier) 407 | return fmt.Errorf("control conn for %s already exist. 
\n", identifier) 408 | } 409 | 410 | s.log.Debug("Tunnel with identifier %s", identifier) 411 | 412 | hj, ok := w.(http.Hijacker) 413 | if !ok { 414 | return fmt.Errorf("webserver doesn't support hijacking: %T", w) 415 | } 416 | 417 | conn, _, err := hj.Hijack() 418 | if err != nil { 419 | return fmt.Errorf("hijack not possible: %s", err) 420 | } 421 | 422 | if _, err := io.WriteString(conn, "HTTP/1.1 "+proto.Connected+"\n\n"); err != nil { 423 | return fmt.Errorf("error writing response: %s", err) 424 | } 425 | 426 | if err := conn.SetDeadline(time.Time{}); err != nil { 427 | return fmt.Errorf("error setting connection deadline: %s", err) 428 | } 429 | 430 | s.log.Debug("Creating control session") 431 | session, err := yamux.Server(conn, s.yamuxConfig) 432 | if err != nil { 433 | return err 434 | } 435 | s.addSession(identifier, session) 436 | 437 | var stream net.Conn 438 | 439 | // close and delete the session/stream if something goes wrong 440 | defer func() { 441 | if ctErr != nil { 442 | if stream != nil { 443 | stream.Close() 444 | } 445 | s.deleteSession(identifier) 446 | } 447 | }() 448 | 449 | acceptStream := func() error { 450 | stream, err = session.Accept() 451 | return err 452 | } 453 | 454 | // if we don't receive anything from the client, we'll timeout 455 | select { 456 | case err := <-async(acceptStream): 457 | if err != nil { 458 | return err 459 | } 460 | case <-time.After(time.Second * 10): 461 | return errors.New("timeout getting session") 462 | } 463 | 464 | s.log.Debug("Initiating handshake protocol") 465 | buf := make([]byte, len(proto.HandshakeRequest)) 466 | if _, err := stream.Read(buf); err != nil { 467 | return err 468 | } 469 | 470 | if string(buf) != proto.HandshakeRequest { 471 | return fmt.Errorf("handshake aborted. 
got: %s", string(buf)) 472 | } 473 | 474 | if _, err := stream.Write([]byte(proto.HandshakeResponse)); err != nil { 475 | return err 476 | } 477 | 478 | // setup control stream and start to listen to messages 479 | ct = newControl(stream) 480 | s.addControl(identifier, ct) 481 | go s.listenControl(ct) 482 | 483 | s.log.Debug("Control connection is setup") 484 | return nil 485 | } 486 | 487 | // listenControl listens to messages coming from the client. 488 | func (s *Server) listenControl(ct *control) { 489 | s.onConnect(ct.identifier) 490 | 491 | for { 492 | var msg map[string]interface{} 493 | err := ct.dec.Decode(&msg) 494 | if err != nil { 495 | host, _ := s.getHost(ct.identifier) 496 | s.log.Debug("Closing client connection: '%s', %s'", host, ct.identifier) 497 | 498 | // close client connection so it reconnects again 499 | ct.Close() 500 | 501 | // don't forget to cleanup anything 502 | s.deleteControl(ct.identifier) 503 | s.deleteSession(ct.identifier) 504 | 505 | s.onDisconnect(ct.identifier, err) 506 | 507 | if err != io.EOF { 508 | s.log.Error("decode err: %s", err) 509 | } 510 | return 511 | } 512 | 513 | // right now we don't do anything with the messages, but because the 514 | // underlying connection needs to establihsed, we know when we have 515 | // disconnection(above), so we can cleanup the connection. 516 | s.log.Debug("msg: %s", msg) 517 | } 518 | } 519 | 520 | // OnConnect invokes a callback for client with given identifier, 521 | // when it establishes a control session. 522 | // After a client is connected, the associated function 523 | // is also removed and needs to be added again. 524 | func (s *Server) OnConnect(identifier string, fn func() error) { 525 | s.onConnectCallbacks.add(identifier, fn) 526 | } 527 | 528 | // onConnect sends notifications to listeners (registered in onConnectCallbacks 529 | // or stateChanges chanel readers) when client connects. 
530 | func (s *Server) onConnect(identifier string) { 531 | if err := s.onConnectCallbacks.call(identifier); err != nil { 532 | s.log.Error("OnConnect: error calling callback for %q: %s", identifier, err) 533 | } 534 | 535 | s.changeState(identifier, ClientConnected, nil) 536 | } 537 | 538 | // OnDisconnect calls the function when the client connected with the 539 | // associated identifier disconnects from the server. 540 | // After a client is disconnected, the associated function 541 | // is also removed and needs to be added again. 542 | func (s *Server) OnDisconnect(identifier string, fn func() error) { 543 | s.onDisconnectCallbacks.add(identifier, fn) 544 | } 545 | 546 | // onDisconnect sends notifications to listeners (registered in onDisconnectCallbacks 547 | // or stateChanges chanel readers) when client disconnects. 548 | func (s *Server) onDisconnect(identifier string, err error) { 549 | if err := s.onDisconnectCallbacks.call(identifier); err != nil { 550 | s.log.Error("OnDisconnect: error calling callback for %q: %s", identifier, err) 551 | } 552 | 553 | s.changeState(identifier, ClientClosed, err) 554 | } 555 | 556 | func (s *Server) changeState(identifier string, state ClientState, err error) (prev ClientState) { 557 | s.statesMu.Lock() 558 | defer s.statesMu.Unlock() 559 | 560 | prev = s.states[identifier] 561 | s.states[identifier] = state 562 | 563 | if s.stateCh != nil { 564 | change := &ClientStateChange{ 565 | Identifier: identifier, 566 | Previous: prev, 567 | Current: state, 568 | Error: err, 569 | } 570 | 571 | select { 572 | case s.stateCh <- change: 573 | default: 574 | s.log.Warning("Dropping state change due to slow reader: %s", change) 575 | } 576 | } 577 | 578 | return prev 579 | } 580 | 581 | // AddHost adds the given virtual host and maps it to the identifier. 582 | func (s *Server) AddHost(host, identifier string) { 583 | s.virtualHosts.AddHost(host, identifier) 584 | } 585 | 586 | // DeleteHost deletes the given virtual host. 
Once removed any request to this 587 | // host is denied. 588 | func (s *Server) DeleteHost(host string) { 589 | s.virtualHosts.DeleteHost(host) 590 | } 591 | 592 | // AddAddr starts accepting connections on listener l, routing every connection 593 | // to a tunnel client given by the identifier. 594 | // 595 | // When ip parameter is nil, all connections accepted from the listener are 596 | // routed to the tunnel client specified by the identifier (port-based routing). 597 | // 598 | // When ip parameter is non-nil, only those connections are routed whose local 599 | // address matches the specified ip (ip-based routing). 600 | // 601 | // If l listens on multiple interfaces it's desirable to call AddAddr multiple 602 | // times with the same l value but different ip one. 603 | func (s *Server) AddAddr(l net.Listener, ip net.IP, identifier string) { 604 | s.virtualAddrs.Add(l, ip, identifier) 605 | } 606 | 607 | // DeleteAddr stops listening for connections on the given listener. 608 | // 609 | // Upon return no more connections will be tunneled, but as the method does not 610 | // close the listener, so any ongoing connection won't get interrupted. 
611 | func (s *Server) DeleteAddr(l net.Listener, ip net.IP) { 612 | s.virtualAddrs.Delete(l, ip) 613 | } 614 | 615 | func (s *Server) getIdentifier(host string) (string, bool) { 616 | identifier, ok := s.virtualHosts.GetIdentifier(host) 617 | return identifier, ok 618 | } 619 | 620 | func (s *Server) getHost(identifier string) (string, bool) { 621 | host, ok := s.virtualHosts.GetHost(identifier) 622 | return host, ok 623 | } 624 | 625 | func (s *Server) addControl(identifier string, conn *control) { 626 | s.controls.addControl(identifier, conn) 627 | } 628 | 629 | func (s *Server) getControl(identifier string) (*control, bool) { 630 | return s.controls.getControl(identifier) 631 | } 632 | 633 | func (s *Server) deleteControl(identifier string) { 634 | s.controls.deleteControl(identifier) 635 | } 636 | 637 | func (s *Server) getSession(identifier string) (*yamux.Session, error) { 638 | s.sessionsMu.Lock() 639 | session, ok := s.sessions[identifier] 640 | s.sessionsMu.Unlock() 641 | 642 | if !ok { 643 | return nil, fmt.Errorf("no session available for identifier: '%s'", identifier) 644 | } 645 | 646 | return session, nil 647 | } 648 | 649 | func (s *Server) addSession(identifier string, session *yamux.Session) { 650 | s.sessionsMu.Lock() 651 | s.sessions[identifier] = session 652 | s.sessionsMu.Unlock() 653 | } 654 | 655 | func (s *Server) deleteSession(identifier string) { 656 | s.sessionsMu.Lock() 657 | defer s.sessionsMu.Unlock() 658 | 659 | session, ok := s.sessions[identifier] 660 | 661 | if !ok { 662 | return // nothing to delete 663 | } 664 | 665 | if session != nil { 666 | session.GoAway() // don't accept any new connection 667 | session.Close() 668 | } 669 | 670 | delete(s.sessions, identifier) 671 | } 672 | 673 | func copyHeader(dst, src http.Header) { 674 | for k, v := range src { 675 | vv := make([]string, len(v)) 676 | copy(vv, v) 677 | dst[k] = vv 678 | } 679 | } 680 | 681 | // checkConnect checks whether the incoming request is HTTP CONNECT method. 
682 | func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) error) http.Handler { 683 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 684 | if r.Method != "CONNECT" { 685 | http.Error(w, "405 must CONNECT\n", http.StatusMethodNotAllowed) 686 | return 687 | } 688 | 689 | if err := fn(w, r); err != nil { 690 | s.log.Error("Handler err: %v", err.Error()) 691 | 692 | if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" { 693 | s.onDisconnect(identifier, err) 694 | } 695 | 696 | http.Error(w, err.Error(), 502) 697 | } 698 | }) 699 | } 700 | 701 | func parseHostPort(addr string) (string, int, error) { 702 | host, port, err := net.SplitHostPort(addr) 703 | if err != nil { 704 | return "", 0, err 705 | } 706 | 707 | n, err := strconv.ParseUint(port, 10, 16) 708 | if err != nil { 709 | return "", 0, err 710 | } 711 | 712 | return host, int(n), nil 713 | } 714 | 715 | func isWebsocketConn(r *http.Request) bool { 716 | return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") && 717 | headerContains(r.Header["Upgrade"], "websocket") 718 | } 719 | 720 | // headerContains is a copy of tokenListContainsValue from gorilla/websocket/util.go 721 | func headerContains(header []string, value string) bool { 722 | for _, h := range header { 723 | for _, v := range strings.Split(h, ",") { 724 | if strings.EqualFold(strings.TrimSpace(v), value) { 725 | return true 726 | } 727 | } 728 | } 729 | 730 | return false 731 | } 732 | 733 | func nonil(err ...error) error { 734 | for _, e := range err { 735 | if e != nil { 736 | return e 737 | } 738 | } 739 | 740 | return nil 741 | } 742 | 743 | func newLogger(name string, debug bool) logging.Logger { 744 | log := logging.NewLogger(name) 745 | logHandler := logging.NewWriterHandler(os.Stderr) 746 | logHandler.Colorize = true 747 | log.SetHandler(logHandler) 748 | 749 | if debug { 750 | log.SetLevel(logging.DEBUG) 751 | logHandler.SetLevel(logging.DEBUG) 
752 | } 753 | 754 | return log 755 | } 756 | -------------------------------------------------------------------------------- /spec.md: -------------------------------------------------------------------------------- 1 | # Specification 2 | 3 | # Naming conventions 4 | 5 | * `server` is listening to public connection and is responsible of routing 6 | public HTTP requests to clients. 7 | * `client` is a long running process, connected to a server and running on a local machine. 8 | * `virtualHost` is a virtual domain that maps a domain to a single client. i.e: 9 | `arslan.koding.io` is a virtualhost which is mapped to my `client` running on 10 | my local machine. 11 | * `identifier` is a secret token, which is not meant to be shared with others. 12 | An identifier is responsible of mapping a virtualhost to a client. 13 | * `session` is a single TCP connection which uses the library `yamux`. A 14 | session can be created either via `yamux.Server()` or `yamux.Client` 15 | * `stream` is a `net.Conn` compatible `virtual` connection that is multiplexed 16 | over the `session`. A session can have hundreds of thousands streams 17 | * `control connection` is a single `stream` which is used to communicate and 18 | handle messaging between server and client. It uses a custom protocol which 19 | is JSON encoded. 20 | * `tunnel connection` is a single `stream` which is used to proxy public HTTP 21 | requests from the `server` to the `client` and vice versa. A single `tunnel` 22 | connection is created for every single HTTP requests. 23 | * `public connection` is a connection from a remote machine to the `server` 24 | * `ControlHandler` is a http.Handler which listens to requests coming to 25 | `/_controlPath_/`. It's used to setup the initial `session` connection from 26 | `client` to `server`. And creates the `control connection` from this session. 27 | server and client, and also for all additional new tunnel. 
It literally 28 | captures the incoming HTTP request, hijacks it and converts it into a raw TCP connection, 29 | which is then used as the foundation for all yamux `sessions`. 30 | 31 | 32 | # Server 33 | 1. Server is created with `NewServer()` which returns `*Server`, an `http.Handler` 34 | compatible type. Plug it into any HTTP server you want. The root path `"/"` is 35 | recommended for listening to and proxying any tunnels. It also listens to any request 36 | coming to `ControlHandler`. 37 | 2. Tunneling is based on virtual hosts. A virtual host is identified by a 38 | unique identifier. This identifier is the only piece that both client and 39 | server need to know ahead of time. Think of it as a secret token. 40 | 3. To add a virtual host, call `server.AddHost(virtualHost, identifier)`. This 41 | step needs to be done from the server itself. This can be done manually or 42 | via custom auth based HTTP handlers, such as "/addhost", which adds 43 | virtualhosts and returns the `identifier` to the requester (in our case `client`). 44 | 4. A DNS record and its subdomains need to point to a `server`, so it can 45 | handle virtual hosts, e.g. `*.example.com` is routed to a server, which can 46 | handle `foo.example.com`, `bar.example.com`, etc. 47 | 48 | 49 | # Client 50 | 51 | 1. Client is created with `NewClient(serverAddr, localAddr)` which returns a 52 | `*Client`. Here `serverAddr` is the TCP address of the server. `localAddr` 53 | is the local server to which all public requests are forwarded. It is optional if 54 | you want this to be done dynamically. 55 | 2. Once a client is created, it starts with `client.Start(identifier)`. Here 56 | `identifier` is needed upfront. This method creates the initial TCP 57 | connection to the server and sends the identifier to the server. This 58 | TCP connection is used as the foundation for `yamux.Client()`.
Once a yamux 59 | session is established, we are able to use this single connection to have 60 | multiple streams, which are multiplexed over this one connection. A `control 61 | connection` is created and the client starts to listen on it. `client.Start` is 62 | blocking. 63 | 64 | # Control Handshake 65 | 66 | 1. Client sends a `handshakeRequest` over the `control connection` stream. 67 | 2. The server sends back a `handshakeResponse` to the client over the `control connection` stream. 68 | 3. Once the client receives the `handshakeResponse` from the server, it starts 69 | to listen on the `control connection` stream. 70 | 4. A `control connection` uses json.Encoder/Decoder on both the server and the client. 71 | 72 | 73 | # Tunnel creation 74 | 1. When the server receives a public connection, it checks the HTTP Host 75 | header and retrieves the corresponding identifier for the given host. 76 | 2. The server retrieves the `control connection` which was associated with this 77 | `identifier` and sends a `ControlMsg` message with the action 78 | `RequestClientSession`. This message is in the form of: 79 | 80 | type ControlMsg struct { 81 | Action Action `json:"action"` 82 | Protocol TransportProtocol `json:"transportProtocol"` 83 | LocalPort string `json:"localPort"` 84 | } 85 | 86 | Here the `LocalPort` is read from the HTTP Host header. If absent, a zero 87 | port is sent and the client maps it to the local server running at port 80, unless 88 | the `localAddr` is specified in the `client.Start()` method. `Protocol` is 89 | reserved for future features. 90 | 3. The server immediately starts to listen for (accept) a new `stream`. This is 91 | blocking and it waits there. 92 | 4. When the client receives the `RequestClientSession` message, it opens a new 93 | `virtual` TCP connection, a `stream`, to the server. 94 | 5. The server, which was waiting for a new stream in step 3, establishes the stream. 95 | 6. The server copies the request over the stream to the client. 96 | 7.
The client copies the request coming from the server to the local server and 97 | copies back the result to the server 98 | 8. The server reads the response coming from the client and returns back it to 99 | the public connection requester 100 | 101 | -------------------------------------------------------------------------------- /tcpproxy.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/koding/logging" 8 | "github.com/koding/tunnel/proto" 9 | ) 10 | 11 | var ( 12 | tpcLog = logging.NewLogger("tcp") 13 | ) 14 | 15 | // TCPProxy forwards TCP streams. 16 | // 17 | // If port-based routing is used, LocalAddr or FetchLocalAddr field is required 18 | // for tunneling to function properly. 19 | // Otherwise you'll be forwarding traffic to random ports and this is usually not desired. 20 | // 21 | // If IP-based routing is used then tunnel server connection request is 22 | // proxied to 127.0.0.1:incomingPort where incomingPort is control message LocalPort. 23 | // Usually this is tunnel server's public exposed Port. 24 | // This behaviour can be changed by setting LocalAddr or FetchLocalAddr. 25 | // FetchLocalAddr takes precedence over LocalAddr. 26 | type TCPProxy struct { 27 | // LocalAddr defines the TCP address of the local server. 28 | // This is optional if you want to specify a single TCP address. 29 | LocalAddr string 30 | // FetchLocalAddr is used for looking up TCP address of the server. 31 | // This is optional if you want to specify a dynamic TCP address based on incommig port. 32 | FetchLocalAddr func(port int) (string, error) 33 | // Log is a custom logger that can be used for the proxy. 34 | // If not set a "tcp" logger is used. 35 | Log logging.Logger 36 | } 37 | 38 | // Proxy is a ProxyFunc. 
39 | func (p *TCPProxy) Proxy(remote net.Conn, msg *proto.ControlMessage) { 40 | if msg.Protocol != proto.TCP { 41 | panic("Proxy mismatch") 42 | } 43 | 44 | var log = p.log() 45 | 46 | var port = msg.LocalPort 47 | if port == 0 { 48 | log.Warning("TCP proxy to port 0") 49 | } 50 | 51 | var localAddr = fmt.Sprintf("127.0.0.1:%d", port) 52 | if p.LocalAddr != "" { 53 | localAddr = p.LocalAddr 54 | } else if p.FetchLocalAddr != nil { 55 | l, err := p.FetchLocalAddr(msg.LocalPort) 56 | if err != nil { 57 | log.Warning("Failed to get custom local address: %s", err) 58 | return 59 | } 60 | localAddr = l 61 | } 62 | 63 | log.Debug("Dialing local server: %q", localAddr) 64 | local, err := net.DialTimeout("tcp", localAddr, defaultTimeout) 65 | if err != nil { 66 | log.Error("Dialing local server %q failed: %s", localAddr, err) 67 | return 68 | } 69 | 70 | Join(local, remote, log) 71 | } 72 | 73 | func (p *TCPProxy) log() logging.Logger { 74 | if p.Log != nil { 75 | return p.Log 76 | } 77 | return tpcLog 78 | } 79 | -------------------------------------------------------------------------------- /tunnel_test.go: -------------------------------------------------------------------------------- 1 | package tunnel_test 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/koding/tunnel" 11 | "github.com/koding/tunnel/tunneltest" 12 | 13 | "github.com/cenkalti/backoff" 14 | ) 15 | 16 | func TestMultipleRequest(t *testing.T) { 17 | tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | defer tt.Close() 22 | 23 | // make a request to tunnelserver, this should be tunneled to local server 24 | var wg sync.WaitGroup 25 | for i := 0; i < 100; i++ { 26 | wg.Add(1) 27 | 28 | go func(i int) { 29 | defer wg.Done() 30 | msg := "hello" + strconv.Itoa(i) 31 | res, err := echoHTTP(tt, msg) 32 | if err != nil { 33 | t.Fatalf("echoHTTP error: %s", err) 34 | } 35 | 36 | if res != msg { 37 | 
t.Errorf("got %q, want %q", res, msg) 38 | } 39 | }(i) 40 | } 41 | 42 | wg.Wait() 43 | } 44 | 45 | func TestMultipleLatencyRequest(t *testing.T) { 46 | tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP)) 47 | if err != nil { 48 | t.Fatal(err) 49 | } 50 | defer tt.Close() 51 | 52 | // make a request to tunnelserver, this should be tunneled to local server 53 | var wg sync.WaitGroup 54 | for i := 0; i < 100; i++ { 55 | wg.Add(1) 56 | 57 | go func(i int) { 58 | defer wg.Done() 59 | msg := "hello" + strconv.Itoa(i) 60 | res, err := echoHTTP(tt, msg) 61 | if err != nil { 62 | t.Fatalf("echoHTTP error: %s", err) 63 | } 64 | 65 | if res != msg { 66 | t.Errorf("got %q, want %q", res, msg) 67 | } 68 | }(i) 69 | } 70 | 71 | wg.Wait() 72 | } 73 | 74 | func TestReconnectClient(t *testing.T) { 75 | tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | defer tt.Close() 80 | 81 | msg := "hello" 82 | res, err := echoHTTP(tt, msg) 83 | if err != nil { 84 | t.Fatalf("echoHTTP error: %s", err) 85 | } 86 | 87 | if res != msg { 88 | t.Errorf("got %q, want %q", res, msg) 89 | } 90 | 91 | client := tt.Clients["http"] 92 | 93 | // close client, and start it again 94 | client.Close() 95 | 96 | go client.Start() 97 | <-client.StartNotify() 98 | 99 | msg = "helloagain" 100 | res, err = echoHTTP(tt, msg) 101 | if err != nil { 102 | t.Fatalf("echoHTTP error: %s", err) 103 | } 104 | 105 | if res != msg { 106 | t.Errorf("got %q, want %q", res, msg) 107 | } 108 | } 109 | 110 | func TestNoClient(t *testing.T) { 111 | const expectedErr = "no client session established" 112 | 113 | rec := tunneltest.NewStateRecorder() 114 | 115 | tt, err := tunneltest.Serve(singleRecHTTP(handlerEchoHTTP, rec.C())) 116 | if err != nil { 117 | t.Fatal(err) 118 | } 119 | defer tt.Close() 120 | 121 | if err := rec.WaitTransitions( 122 | tunnel.ClientStarted, 123 | tunnel.ClientConnecting, 124 | tunnel.ClientConnected, 125 | ); err != nil { 126 | 
t.Fatal(err) 127 | } 128 | 129 | if err := tt.ServerStateRecorder.WaitTransition( 130 | tunnel.ClientUnknown, 131 | tunnel.ClientConnected, 132 | ); err != nil { 133 | t.Fatal(err) 134 | } 135 | 136 | // close client, this is the main point of the test 137 | if err := tt.Clients["http"].Close(); err != nil { 138 | t.Fatal(err) 139 | } 140 | 141 | if err := rec.WaitTransitions( 142 | tunnel.ClientConnected, 143 | tunnel.ClientDisconnected, 144 | tunnel.ClientClosed, 145 | ); err != nil { 146 | t.Fatal(err) 147 | } 148 | 149 | if err := tt.ServerStateRecorder.WaitTransition( 150 | tunnel.ClientConnected, 151 | tunnel.ClientClosed, 152 | ); err != nil { 153 | t.Fatal(err) 154 | } 155 | 156 | msg := "hello" 157 | res, err := echoHTTP(tt, msg) 158 | if err != nil { 159 | t.Fatalf("echoHTTP error: %s", err) 160 | } 161 | 162 | if res != expectedErr { 163 | t.Errorf("got %q, want %q", res, msg) 164 | } 165 | } 166 | 167 | func TestNoHost(t *testing.T) { 168 | tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | defer tt.Close() 173 | 174 | noBackoff := backoff.NewConstantBackOff(time.Duration(-1)) 175 | 176 | unknown, err := tunnel.NewClient(&tunnel.ClientConfig{ 177 | Identifier: "unknown", 178 | ServerAddr: tt.ServerAddr().String(), 179 | Backoff: noBackoff, 180 | Debug: testing.Verbose(), 181 | }) 182 | if err != nil { 183 | t.Fatalf("client error: %s", err) 184 | } 185 | unknown.Start() 186 | defer unknown.Close() 187 | 188 | if err := tt.ServerStateRecorder.WaitTransition( 189 | tunnel.ClientUnknown, 190 | tunnel.ClientClosed, 191 | ); err != nil { 192 | t.Fatal(err) 193 | } 194 | 195 | unknown.Start() 196 | if err := tt.ServerStateRecorder.WaitTransition( 197 | tunnel.ClientClosed, 198 | tunnel.ClientClosed, 199 | ); err != nil { 200 | t.Fatal(err) 201 | } 202 | } 203 | 204 | func TestNoLocalServer(t *testing.T) { 205 | const expectedErr = "no local server" 206 | 207 | tt, err := 
tunneltest.Serve(singleHTTP(handlerEchoHTTP)) 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | defer tt.Close() 212 | 213 | // close local listener, this is the main point of the test 214 | tt.Listeners["http"][0].Close() 215 | 216 | msg := "hello" 217 | res, err := echoHTTP(tt, msg) 218 | if err != nil { 219 | t.Fatalf("echoHTTP error: %s", err) 220 | } 221 | 222 | if res != expectedErr { 223 | t.Errorf("got %q, want %q", res, msg) 224 | } 225 | } 226 | 227 | func TestSingleRequest(t *testing.T) { 228 | tt, err := tunneltest.Serve(singleHTTP(handlerEchoHTTP)) 229 | if err != nil { 230 | t.Fatal(err) 231 | } 232 | defer tt.Close() 233 | 234 | msg := "hello" 235 | res, err := echoHTTP(tt, msg) 236 | if err != nil { 237 | t.Fatalf("echoHTTP error: %s", err) 238 | } 239 | 240 | if res != msg { 241 | t.Errorf("got %q, want %q", res, msg) 242 | } 243 | } 244 | 245 | func TestSingleLatencyRequest(t *testing.T) { 246 | tt, err := tunneltest.Serve(singleHTTP(handlerLatencyEchoHTTP)) 247 | if err != nil { 248 | t.Fatal(err) 249 | } 250 | defer tt.Close() 251 | 252 | msg := "hello" 253 | res, err := echoHTTP(tt, msg) 254 | if err != nil { 255 | t.Fatalf("echoHTTP error: %s", err) 256 | } 257 | 258 | if res != msg { 259 | t.Errorf("got %q, want %q", res, msg) 260 | } 261 | } 262 | 263 | func TestSingleTCP(t *testing.T) { 264 | tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP)) 265 | if err != nil { 266 | t.Fatal(err) 267 | } 268 | defer tt.Close() 269 | 270 | msg := "hello" 271 | res, err := echoTCP(tt, msg) 272 | if err != nil { 273 | t.Fatalf("echoTCP error: %s", err) 274 | } 275 | 276 | if msg != res { 277 | t.Errorf("got %q, want %q", res, msg) 278 | } 279 | } 280 | 281 | func TestMultipleTCP(t *testing.T) { 282 | tt, err := tunneltest.Serve(singleTCP(handlerEchoTCP)) 283 | if err != nil { 284 | t.Fatal(err) 285 | } 286 | defer tt.Close() 287 | 288 | var wg sync.WaitGroup 289 | for i := 0; i < 100; i++ { 290 | wg.Add(1) 291 | 292 | go func(i int) { 293 | defer 
wg.Done() 294 | msg := "hello" + strconv.Itoa(i) 295 | res, err := echoTCP(tt, msg) 296 | if err != nil { 297 | t.Errorf("echoTCP: %s", err) 298 | } 299 | 300 | if res != msg { 301 | t.Errorf("got %q, want %q", res, msg) 302 | } 303 | }(i) 304 | } 305 | 306 | wg.Wait() 307 | } 308 | 309 | func TestMultipleLatencyTCP(t *testing.T) { 310 | tt, err := tunneltest.Serve(singleTCP(handlerLatencyEchoTCP)) 311 | if err != nil { 312 | t.Fatal(err) 313 | } 314 | defer tt.Close() 315 | 316 | var wg sync.WaitGroup 317 | for i := 0; i < 100; i++ { 318 | wg.Add(1) 319 | 320 | go func(i int) { 321 | defer wg.Done() 322 | msg := "hello" + strconv.Itoa(i) 323 | res, err := echoTCP(tt, msg) 324 | if err != nil { 325 | t.Errorf("echoTCP: %s", err) 326 | } 327 | 328 | if res != msg { 329 | t.Errorf("got %q, want %q", res, msg) 330 | } 331 | }(i) 332 | } 333 | 334 | wg.Wait() 335 | } 336 | 337 | func TestMultipleStreamTCP(t *testing.T) { 338 | tunnels := map[string]*tunneltest.Tunnel{ 339 | "http": { 340 | Type: tunneltest.TypeHTTP, 341 | LocalAddr: "127.0.0.1:0", 342 | Handler: handlerEchoHTTP, 343 | }, 344 | "tcp": { 345 | Type: tunneltest.TypeTCP, 346 | ClientIdent: "http", 347 | LocalAddr: "127.0.0.1:0", 348 | RemoteAddr: "127.0.0.1:0", 349 | Handler: handlerEchoTCP, 350 | }, 351 | "tcp_all": { 352 | Type: tunneltest.TypeTCP, 353 | ClientIdent: "http", 354 | LocalAddr: "127.0.0.1:0", 355 | RemoteAddr: "0.0.0.0:0", 356 | Handler: handlerEchoTCP, 357 | }, 358 | } 359 | 360 | addrs, err := tunneltest.UsableAddrs() 361 | if err != nil { 362 | t.Fatal(err) 363 | } 364 | 365 | clients := []string{"tcp"} 366 | for i, addr := range addrs { 367 | if addr.IP.IsLoopback() { 368 | continue 369 | } 370 | 371 | client := fmt.Sprintf("tcp_%d", i) 372 | 373 | tunnels[client] = &tunneltest.Tunnel{ 374 | Type: tunneltest.TypeTCP, 375 | ClientIdent: "http", 376 | LocalAddr: "127.0.0.1:0", 377 | RemoteAddrIdent: "tcp_all", 378 | IP: addr.IP, 379 | Handler: handlerEchoTCP, 380 | } 381 | 382 | clients = 
append(clients, client) 383 | } 384 | 385 | tt, err := tunneltest.Serve(tunnels) 386 | if err != nil { 387 | t.Fatal(err) 388 | } 389 | defer tt.Close() 390 | 391 | var wg sync.WaitGroup 392 | for i := 0; i < 100/len(clients); i++ { 393 | wg.Add(len(clients)) 394 | 395 | for j, ident := range clients { 396 | go func(ident string, i, j int) { 397 | defer wg.Done() 398 | msg := fmt.Sprintf("hello_%d_client_%d", j, i) 399 | res, err := echoTCPIdent(tt, msg, ident) 400 | if err != nil { 401 | t.Errorf("echoTCP: %s", err) 402 | } 403 | 404 | if res != msg { 405 | t.Errorf("got %q, want %q", res, msg) 406 | } 407 | }(ident, i, j) 408 | } 409 | } 410 | 411 | wg.Wait() 412 | } 413 | -------------------------------------------------------------------------------- /tunneltest/state_recorder.go: -------------------------------------------------------------------------------- 1 | package tunneltest 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | "github.com/koding/tunnel" 10 | ) 11 | 12 | var ( 13 | recWaitTimeout = 5 * time.Second 14 | recBuffer = 32 15 | ) 16 | 17 | // States is a sequence of client state changes. 18 | type States []*tunnel.ClientStateChange 19 | 20 | func (s States) String() string { 21 | if len(s) == 0 { 22 | return "" 23 | } 24 | 25 | var buf bytes.Buffer 26 | 27 | fmt.Fprintf(&buf, "[%s", s[0].String()) 28 | 29 | for _, s := range s[1:] { 30 | fmt.Fprintf(&buf, ",%s", s.String()) 31 | } 32 | 33 | buf.WriteRune(']') 34 | 35 | return buf.String() 36 | } 37 | 38 | // StateRecorder saves state changes pushed to StateRecorder.C(). 
39 | type StateRecorder struct { 40 | mu sync.Mutex 41 | ch chan *tunnel.ClientStateChange 42 | recorded []*tunnel.ClientStateChange 43 | offset int 44 | } 45 | 46 | func NewStateRecorder() *StateRecorder { 47 | rec := &StateRecorder{ 48 | ch: make(chan *tunnel.ClientStateChange, recBuffer), 49 | } 50 | 51 | go rec.record() 52 | 53 | return rec 54 | } 55 | 56 | func (rec *StateRecorder) record() { 57 | for state := range rec.ch { 58 | rec.mu.Lock() 59 | rec.recorded = append(rec.recorded, state) 60 | rec.mu.Unlock() 61 | } 62 | } 63 | 64 | func (rec *StateRecorder) C() chan<- *tunnel.ClientStateChange { 65 | return rec.ch 66 | } 67 | 68 | func (rec *StateRecorder) WaitTransitions(states ...tunnel.ClientState) error { 69 | from := states[0] 70 | for _, to := range states[1:] { 71 | if err := rec.WaitTransition(from, to); err != nil { 72 | return err 73 | } 74 | 75 | from = to 76 | } 77 | 78 | return nil 79 | } 80 | 81 | func (rec *StateRecorder) WaitTransition(from, to tunnel.ClientState) error { 82 | timeout := time.After(recWaitTimeout) 83 | 84 | var lastStates []*tunnel.ClientStateChange 85 | for { 86 | select { 87 | case <-timeout: 88 | return fmt.Errorf("timed out waiting for %s->%s transition: %v", from, to, States(lastStates)) 89 | default: 90 | time.Sleep(50 * time.Millisecond) 91 | 92 | lastStates = rec.States()[rec.offset:] 93 | 94 | for i, state := range lastStates { 95 | if from != 0 && state.Previous != from { 96 | continue 97 | } 98 | 99 | if to != 0 && state.Current != to { 100 | continue 101 | } 102 | 103 | rec.offset += i 104 | 105 | return nil 106 | } 107 | } 108 | } 109 | } 110 | 111 | func (rec *StateRecorder) States() []*tunnel.ClientStateChange { 112 | rec.mu.Lock() 113 | defer rec.mu.Unlock() 114 | 115 | states := make([]*tunnel.ClientStateChange, len(rec.recorded)) 116 | copy(states, rec.recorded) 117 | return states 118 | } 119 | -------------------------------------------------------------------------------- /tunneltest/tunneltest.go: 
-------------------------------------------------------------------------------- 1 | package tunneltest 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "sort" 12 | "strconv" 13 | "sync" 14 | "testing" 15 | "time" 16 | 17 | "github.com/koding/tunnel" 18 | ) 19 | 20 | var debugNet = os.Getenv("DEBUGNET") == "1" 21 | 22 | type dbgListener struct { 23 | net.Listener 24 | } 25 | 26 | func (l dbgListener) Accept() (net.Conn, error) { 27 | conn, err := l.Listener.Accept() 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | return dbgConn{conn}, nil 33 | } 34 | 35 | type dbgConn struct { 36 | net.Conn 37 | } 38 | 39 | func (c dbgConn) Read(p []byte) (int, error) { 40 | n, err := c.Conn.Read(p) 41 | os.Stderr.Write(p) 42 | return n, err 43 | } 44 | 45 | func (c dbgConn) Write(p []byte) (int, error) { 46 | n, err := c.Conn.Write(p) 47 | os.Stderr.Write(p) 48 | return n, err 49 | } 50 | 51 | func logf(format string, args ...interface{}) { 52 | if testing.Verbose() { 53 | log.Printf("[tunneltest] "+format, args...) 54 | } 55 | } 56 | 57 | func nonil(err ...error) error { 58 | for _, e := range err { 59 | if e != nil { 60 | return e 61 | } 62 | } 63 | return nil 64 | } 65 | 66 | func parseHostPort(addr string) (string, int, error) { 67 | host, port, err := net.SplitHostPort(addr) 68 | if err != nil { 69 | return "", 0, err 70 | } 71 | 72 | n, err := strconv.ParseUint(port, 10, 16) 73 | if err != nil { 74 | return "", 0, err 75 | } 76 | 77 | return host, int(n), nil 78 | } 79 | 80 | // UsableAddrs returns all tcp addresses that we can bind a listener to. 
81 | func UsableAddrs() ([]*net.TCPAddr, error) { 82 | addrs, err := net.InterfaceAddrs() 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | var usable []*net.TCPAddr 88 | for _, addr := range addrs { 89 | if ipNet, ok := addr.(*net.IPNet); ok { 90 | if !ipNet.IP.IsLinkLocalUnicast() { 91 | usable = append(usable, &net.TCPAddr{IP: ipNet.IP}) 92 | } 93 | } 94 | } 95 | 96 | if len(usable) == 0 { 97 | return nil, errors.New("no usable addresses found") 98 | } 99 | 100 | return usable, nil 101 | } 102 | 103 | const ( 104 | TypeHTTP = iota 105 | TypeTCP 106 | ) 107 | 108 | // Tunnel represents a single HTTP or TCP tunnel that can be served 109 | // by TunnelTest. 110 | type Tunnel struct { 111 | // Type specifies a tunnel type - either TypeHTTP (default) or TypeTCP. 112 | Type int 113 | 114 | // Handler is a handler to use for serving tunneled connections on 115 | // local server. The value of this field is required to be of type: 116 | // 117 | // - http.Handler or http.HandlerFunc for HTTP tunnels 118 | // - func(net.Conn) for TCP tunnels 119 | // 120 | // Required field. 121 | Handler interface{} 122 | 123 | // LocalAddr is a network address of local server that handles 124 | // connections/requests with Handler. 125 | // 126 | // Optional field, takes value of "127.0.0.1:0" when empty. 127 | LocalAddr string 128 | 129 | // ClientIdent is an identifier of a client that have already 130 | // registered a HTTP tunnel and have established control connection. 131 | // 132 | // If the Type is TypeTCP, instead of creating new client 133 | // for this TCP tunnel, we add it to an existing client 134 | // specified by the field. 135 | // 136 | // Optional field for TCP tunnels. 137 | // Ignored field for HTTP tunnels. 138 | ClientIdent string 139 | 140 | // RemoteAddr is a network address of remote server, which accepts 141 | // connections on a tunnel server side. 142 | // 143 | // Required field for TCP tunnels. 144 | // Ignored field for HTTP tunnels. 
145 | RemoteAddr string 146 | 147 | // RemoteAddrIdent an identifier of an already existing listener, 148 | // that listens on multiple interfaces; if the RemoteAddrIdent is valid 149 | // identifier the IP field is required to be non-nil and RemoteAddr 150 | // is ignored. 151 | // 152 | // Optional field for TCP tunnels. 153 | // Ignored field for HTTP tunnels. 154 | RemoteAddrIdent string 155 | 156 | // IP specifies an IP address value for IP-based routing for TCP tunnels. 157 | // For more details see inline documentation for (*tunnel.Server).AddAddr. 158 | // 159 | // Optional field for TCP tunnels. 160 | // Ignored field for HTTP tunnels. 161 | IP net.IP 162 | 163 | // StateChanges listens on state transitions. 164 | // 165 | // If ClientIdent field is empty, the StateChanges will receive 166 | // state transition events for the newly created client. 167 | // Otherwise setting this field is a nop. 168 | StateChanges chan<- *tunnel.ClientStateChange 169 | } 170 | 171 | type TunnelTest struct { 172 | Server *tunnel.Server 173 | ServerStateRecorder *StateRecorder 174 | Clients map[string]*tunnel.Client 175 | Listeners map[string][2]net.Listener // [0] is local listener, [1] is remote one (for TCP tunnels) 176 | Addrs []*net.TCPAddr 177 | Tunnels map[string]*Tunnel 178 | DebugNet bool // for debugging network communication 179 | 180 | mu sync.Mutex // protects Listeners 181 | } 182 | 183 | func NewTunnelTest() (*TunnelTest, error) { 184 | rec := NewStateRecorder() 185 | 186 | cfg := &tunnel.ServerConfig{ 187 | StateChanges: rec.C(), 188 | Debug: testing.Verbose(), 189 | } 190 | s, err := tunnel.NewServer(cfg) 191 | if err != nil { 192 | return nil, err 193 | } 194 | 195 | l, err := net.Listen("tcp", "127.0.0.1:0") 196 | if err != nil { 197 | return nil, err 198 | } 199 | 200 | if debugNet { 201 | l = dbgListener{l} 202 | } 203 | 204 | addrs, err := UsableAddrs() 205 | if err != nil { 206 | return nil, err 207 | } 208 | 209 | go (&http.Server{Handler: s}).Serve(l) 
210 | 211 | return &TunnelTest{ 212 | Server: s, 213 | ServerStateRecorder: rec, 214 | Clients: make(map[string]*tunnel.Client), 215 | Listeners: map[string][2]net.Listener{"": {l, nil}}, 216 | Addrs: addrs, 217 | Tunnels: make(map[string]*Tunnel), 218 | DebugNet: debugNet, 219 | }, nil 220 | } 221 | 222 | // Serve creates new TunnelTest that serves the given tunnels. 223 | // 224 | // If tunnels is nil, DefaultTunnels() are used instead. 225 | func Serve(tunnels map[string]*Tunnel) (*TunnelTest, error) { 226 | tt, err := NewTunnelTest() 227 | if err != nil { 228 | return nil, err 229 | } 230 | 231 | if err = tt.Serve(tunnels); err != nil { 232 | return nil, err 233 | } 234 | 235 | return tt, nil 236 | } 237 | 238 | func (tt *TunnelTest) serveSingle(ident string, t *Tunnel) (bool, error) { 239 | // Verify tunnel dependencies for TCP tunnels. 240 | if t.Type == TypeTCP { 241 | // If tunnel specified by t.Client was not already started, 242 | // skip and move on. 243 | if _, ok := tt.Clients[t.ClientIdent]; t.ClientIdent != "" && !ok { 244 | return false, nil 245 | } 246 | 247 | // Verify the TCP tunnel whose remote endpoint listens on multiple 248 | // interfaces is already served. 
249 | if t.RemoteAddrIdent != "" { 250 | if _, ok := tt.Listeners[t.RemoteAddrIdent]; !ok { 251 | return false, nil 252 | } 253 | 254 | if tt.Tunnels[t.RemoteAddrIdent].Type != TypeTCP { 255 | return false, fmt.Errorf("expected tunnel %q to be of TCP type", t.RemoteAddrIdent) 256 | } 257 | } 258 | } 259 | 260 | l, err := net.Listen("tcp", t.LocalAddr) 261 | if err != nil { 262 | return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.LocalAddr, ident, err) 263 | } 264 | 265 | if tt.DebugNet { 266 | l = dbgListener{l} 267 | } 268 | 269 | localAddr := l.Addr().String() 270 | httpProxy := &tunnel.HTTPProxy{LocalAddr: localAddr} 271 | tcpProxy := &tunnel.TCPProxy{FetchLocalAddr: tt.fetchLocalAddr} 272 | 273 | cfg := &tunnel.ClientConfig{ 274 | Identifier: ident, 275 | ServerAddr: tt.ServerAddr().String(), 276 | Proxy: tunnel.Proxy(tunnel.ProxyFuncs{ 277 | HTTP: httpProxy.Proxy, 278 | TCP: tcpProxy.Proxy, 279 | }), 280 | StateChanges: t.StateChanges, 281 | Debug: testing.Verbose(), 282 | } 283 | 284 | // Register tunnel: 285 | // 286 | // - start tunnel.Client (tt.Clients[ident]) or reuse existing one (tt.Clients[t.ExistingClient]) 287 | // - listen on local address and start local server (tt.Listeners[ident][0]) 288 | // - register tunnel on tunnel.Server 289 | // 290 | switch t.Type { 291 | case TypeHTTP: 292 | // TODO(rjeczalik): refactor to separate method 293 | 294 | h, ok := t.Handler.(http.Handler) 295 | if !ok { 296 | h, ok = t.Handler.(http.HandlerFunc) 297 | if !ok { 298 | fn, ok := t.Handler.(func(http.ResponseWriter, *http.Request)) 299 | if !ok { 300 | return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler) 301 | } 302 | 303 | h = http.HandlerFunc(fn) 304 | } 305 | 306 | } 307 | 308 | logf("serving on local %s for HTTP tunnel %q", l.Addr(), ident) 309 | 310 | go (&http.Server{Handler: h}).Serve(l) 311 | 312 | tt.Server.AddHost(localAddr, ident) 313 | 314 | tt.mu.Lock() 315 | tt.Listeners[ident] = 
[2]net.Listener{l, nil} 316 | tt.mu.Unlock() 317 | 318 | if err := tt.addClient(ident, cfg); err != nil { 319 | return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err) 320 | } 321 | 322 | logf("registered HTTP tunnel: host=%s, ident=%s", localAddr, ident) 323 | 324 | case TypeTCP: 325 | // TODO(rjeczalik): refactor to separate method 326 | 327 | h, ok := t.Handler.(func(net.Conn)) 328 | if !ok { 329 | return false, fmt.Errorf("invalid handler type for %q tunnel: %T", ident, t.Handler) 330 | } 331 | 332 | logf("serving on local %s for TCP tunnel %q", l.Addr(), ident) 333 | 334 | go func() { 335 | for { 336 | conn, err := l.Accept() 337 | if err != nil { 338 | log.Printf("failed accepting conn for %q tunnel: %s", ident, err) 339 | return 340 | } 341 | 342 | go h(conn) 343 | } 344 | }() 345 | 346 | var remote net.Listener 347 | 348 | if t.RemoteAddrIdent != "" { 349 | tt.mu.Lock() 350 | remote = tt.Listeners[t.RemoteAddrIdent][1] 351 | tt.mu.Unlock() 352 | } else { 353 | remote, err = net.Listen("tcp", t.RemoteAddr) 354 | if err != nil { 355 | return false, fmt.Errorf("failed to listen on %q for %q tunnel: %s", t.RemoteAddr, ident, err) 356 | } 357 | } 358 | 359 | // addrIdent holds identifier of client which is going to have registered 360 | // tunnel via (*tunnel.Server).AddAddr 361 | addrIdent := ident 362 | if t.ClientIdent != "" { 363 | tt.Clients[ident] = tt.Clients[t.ClientIdent] 364 | addrIdent = t.ClientIdent 365 | } 366 | 367 | tt.Server.AddAddr(remote, t.IP, addrIdent) 368 | 369 | tt.mu.Lock() 370 | tt.Listeners[ident] = [2]net.Listener{l, remote} 371 | tt.mu.Unlock() 372 | 373 | if _, ok := tt.Clients[ident]; !ok { 374 | if err := tt.addClient(ident, cfg); err != nil { 375 | return false, fmt.Errorf("error creating client for %q tunnel: %s", ident, err) 376 | } 377 | } 378 | 379 | logf("registered TCP tunnel: listener=%s, ip=%v, ident=%s", remote.Addr(), t.IP, addrIdent) 380 | 381 | default: 382 | return false, fmt.Errorf("unknown 
%q tunnel type: %d", ident, t.Type) 383 | } 384 | 385 | return true, nil 386 | } 387 | 388 | func (tt *TunnelTest) addClient(ident string, cfg *tunnel.ClientConfig) error { 389 | if _, ok := tt.Clients[ident]; ok { 390 | return fmt.Errorf("tunnel %q is already being served", ident) 391 | } 392 | 393 | c, err := tunnel.NewClient(cfg) 394 | if err != nil { 395 | return err 396 | } 397 | 398 | done := make(chan struct{}) 399 | 400 | tt.Server.OnConnect(ident, func() error { 401 | close(done) 402 | return nil 403 | }) 404 | 405 | go c.Start() 406 | <-c.StartNotify() 407 | 408 | select { 409 | case <-time.After(10 * time.Second): 410 | return errors.New("timed out after 10s waiting on client to establish control conn") 411 | case <-done: 412 | } 413 | 414 | tt.Clients[ident] = c 415 | return nil 416 | } 417 | 418 | func (tt *TunnelTest) Serve(tunnels map[string]*Tunnel) error { 419 | if len(tunnels) == 0 { 420 | return errors.New("no tunnels to serve") 421 | } 422 | 423 | // Since one tunnels depends on others do 3 passes to start them 424 | // all, each started tunnel is removed from the tunnels map. 
425 | // After 3 passes all of them must be started, otherwise the 426 | // configuration is bad: 427 | // 428 | // - first pass starts HTTP tunnels as new client tunnels 429 | // - second pass starts TCP tunnels that rely on on already existing client tunnels (t.ClientIdent) 430 | // - third pass starts TCP tunnels that rely on on already existing TCP tunnels (t.RemoteAddrIdent) 431 | // 432 | for i := 0; i < 3; i++ { 433 | if err := tt.popServedDeps(tunnels); err != nil { 434 | return err 435 | } 436 | } 437 | 438 | if len(tunnels) != 0 { 439 | unresolved := make([]string, len(tunnels)) 440 | for ident := range tunnels { 441 | unresolved = append(unresolved, ident) 442 | } 443 | sort.Strings(unresolved) 444 | 445 | return fmt.Errorf("unable to start tunnels due to unresolved dependencies: %v", unresolved) 446 | } 447 | 448 | return nil 449 | } 450 | 451 | func (tt *TunnelTest) popServedDeps(tunnels map[string]*Tunnel) error { 452 | for ident, t := range tunnels { 453 | ok, err := tt.serveSingle(ident, t) 454 | if err != nil { 455 | return err 456 | } 457 | 458 | if ok { 459 | // Remove already started tunnels so they won't get started again. 
460 | delete(tunnels, ident) 461 | tt.Tunnels[ident] = t 462 | } 463 | } 464 | 465 | return nil 466 | } 467 | 468 | func (tt *TunnelTest) fetchLocalAddr(port int) (string, error) { 469 | tt.mu.Lock() 470 | defer tt.mu.Unlock() 471 | 472 | for _, l := range tt.Listeners { 473 | if l[1] == nil { 474 | // this listener does not belong to a TCP tunnel 475 | continue 476 | } 477 | 478 | _, remotePort, err := parseHostPort(l[1].Addr().String()) 479 | if err != nil { 480 | return "", err 481 | } 482 | 483 | if port == remotePort { 484 | return l[0].Addr().String(), nil 485 | } 486 | } 487 | 488 | return "", fmt.Errorf("no route for %d port", port) 489 | } 490 | 491 | func (tt *TunnelTest) ServerAddr() net.Addr { 492 | return tt.Listeners[""][0].Addr() 493 | } 494 | 495 | // Addr gives server endpoint of the TCP tunnel for the given ident. 496 | // 497 | // If the tunnel does not exist or is a HTTP one, TunnelAddr return nil. 498 | func (tt *TunnelTest) Addr(ident string) net.Addr { 499 | l, ok := tt.Listeners[ident] 500 | if !ok { 501 | return nil 502 | } 503 | 504 | return l[1].Addr() 505 | } 506 | 507 | // Request creates a HTTP request to a server endpoint of the HTTP tunnel 508 | // for the given ident. 509 | // 510 | // If the tunnel does not exist, Request returns nil. 511 | func (tt *TunnelTest) Request(ident string, query url.Values) *http.Request { 512 | l, ok := tt.Listeners[ident] 513 | if !ok { 514 | return nil 515 | } 516 | 517 | var raw string 518 | if query != nil { 519 | raw = query.Encode() 520 | } 521 | 522 | return &http.Request{ 523 | Method: "GET", 524 | URL: &url.URL{ 525 | Scheme: "http", 526 | Host: tt.ServerAddr().String(), 527 | Path: "/", 528 | RawQuery: raw, 529 | }, 530 | Proto: "HTTP/1.1", 531 | ProtoMajor: 1, 532 | ProtoMinor: 1, 533 | Host: l[0].Addr().String(), 534 | } 535 | } 536 | 537 | func (tt *TunnelTest) Close() (err error) { 538 | // Close tunnel.Clients. 
539 | clients := make(map[*tunnel.Client]struct{}) 540 | for _, c := range tt.Clients { 541 | clients[c] = struct{}{} 542 | } 543 | for c := range clients { 544 | err = nonil(err, c.Close()) 545 | } 546 | 547 | // Stop all TCP/HTTP servers. 548 | listeners := make(map[net.Listener]struct{}) 549 | for _, l := range tt.Listeners { 550 | for _, l := range l { 551 | if l != nil { 552 | listeners[l] = struct{}{} 553 | } 554 | } 555 | } 556 | for l := range listeners { 557 | err = nonil(err, l.Close()) 558 | } 559 | 560 | return err 561 | } 562 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | "github.com/koding/tunnel/proto" 11 | 12 | "github.com/cenkalti/backoff" 13 | ) 14 | 15 | // async is a helper function to convert a blocking function to a function 16 | // returning an error. 
Useful for plugging function closures into select and co 17 | func async(fn func() error) <-chan error { 18 | errChan := make(chan error, 0) 19 | go func() { 20 | select { 21 | case errChan <- fn(): 22 | default: 23 | } 24 | 25 | close(errChan) 26 | }() 27 | 28 | return errChan 29 | } 30 | 31 | type expBackoff struct { 32 | mu sync.Mutex 33 | bk *backoff.ExponentialBackOff 34 | } 35 | 36 | func newForeverBackoff() *expBackoff { 37 | eb := &expBackoff{ 38 | bk: backoff.NewExponentialBackOff(), 39 | } 40 | eb.bk.MaxElapsedTime = 0 // never stops 41 | return eb 42 | } 43 | 44 | func (eb *expBackoff) NextBackOff() time.Duration { 45 | eb.mu.Lock() 46 | defer eb.mu.Unlock() 47 | 48 | return eb.bk.NextBackOff() 49 | } 50 | 51 | func (eb *expBackoff) Reset() { 52 | eb.mu.Lock() 53 | eb.bk.Reset() 54 | eb.mu.Unlock() 55 | } 56 | 57 | type callbacks struct { 58 | mu sync.Mutex 59 | name string 60 | funcs map[string]func() error 61 | } 62 | 63 | func newCallbacks(name string) *callbacks { 64 | return &callbacks{ 65 | name: name, 66 | funcs: make(map[string]func() error), 67 | } 68 | } 69 | 70 | func (c *callbacks) add(identifier string, fn func() error) { 71 | c.mu.Lock() 72 | c.funcs[identifier] = fn 73 | c.mu.Unlock() 74 | } 75 | 76 | func (c *callbacks) pop(identifier string) (func() error, error) { 77 | c.mu.Lock() 78 | defer c.mu.Unlock() 79 | 80 | fn, ok := c.funcs[identifier] 81 | if !ok { 82 | return nil, nil // nop 83 | } 84 | 85 | delete(c.funcs, identifier) 86 | 87 | if fn == nil { 88 | return nil, fmt.Errorf("nil callback set for %q client", identifier) 89 | } 90 | 91 | return fn, nil 92 | } 93 | 94 | func (c *callbacks) call(identifier string) error { 95 | fn, err := c.pop(identifier) 96 | if err != nil { 97 | return err 98 | } 99 | 100 | if fn == nil { 101 | return nil // nop 102 | } 103 | 104 | return fn() 105 | } 106 | 107 | // Returns server control url as a string. Reads scheme and remote address from connection. 
108 | func controlURL(conn net.Conn) string { 109 | return fmt.Sprint(scheme(conn), "://", conn.RemoteAddr(), proto.ControlPath) 110 | } 111 | 112 | func scheme(conn net.Conn) (scheme string) { 113 | switch conn.(type) { 114 | case *tls.Conn: 115 | scheme = "https" 116 | default: 117 | scheme = "http" 118 | } 119 | 120 | return 121 | } 122 | -------------------------------------------------------------------------------- /virtualaddr.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "net" 5 | "strconv" 6 | "sync" 7 | "sync/atomic" 8 | 9 | "github.com/koding/logging" 10 | ) 11 | 12 | type listener struct { 13 | net.Listener 14 | *vaddrOptions 15 | 16 | done int32 17 | 18 | // ips keeps track of registered clients for ip-based routing; 19 | // when last client is deleted from the ip routing map, we stop 20 | // listening on connections 21 | ips map[string]struct{} 22 | } 23 | 24 | type vaddrOptions struct { 25 | connCh chan<- net.Conn 26 | log logging.Logger 27 | } 28 | 29 | type vaddrStorage struct { 30 | *vaddrOptions 31 | 32 | listeners map[net.Listener]*listener 33 | ports map[int]string // port-based routing: maps port number to identifier 34 | ips map[string]string // ip-based routing: maps ip address to identifier 35 | 36 | mu sync.RWMutex 37 | } 38 | 39 | func newVirtualAddrs(opts *vaddrOptions) *vaddrStorage { 40 | return &vaddrStorage{ 41 | vaddrOptions: opts, 42 | listeners: make(map[net.Listener]*listener), 43 | ports: make(map[int]string), 44 | ips: make(map[string]string), 45 | } 46 | } 47 | 48 | func (l *listener) serve() { 49 | for { 50 | conn, err := l.Accept() 51 | if err != nil { 52 | l.log.Error("failue listening on %q: %s", l.Addr(), err) 53 | return 54 | } 55 | 56 | if atomic.LoadInt32(&l.done) != 0 { 57 | l.log.Debug("stopped serving %q", l.Addr()) 58 | conn.Close() 59 | return 60 | } 61 | 62 | l.connCh <- conn 63 | } 64 | } 65 | 66 | func (l *listener) localAddr() 
string { 67 | if addr, ok := l.Addr().(*net.TCPAddr); ok { 68 | if addr.IP.Equal(net.IPv4zero) { 69 | return net.JoinHostPort("127.0.0.1", strconv.Itoa(addr.Port)) 70 | } 71 | } 72 | return l.Addr().String() 73 | } 74 | 75 | func (l *listener) stop() { 76 | if atomic.CompareAndSwapInt32(&l.done, 0, 1) { 77 | // stop is called when no more connections should be accepted by 78 | // the user-provided listener; as we can't simple close the listener 79 | // to not break the guarantee given by the (*Server).DeleteAddr 80 | // method, we make a dummy connection to break out of serve loop. 81 | // It is safe to make a dummy connection, as either the following 82 | // dial will time out when the listener is busy accepting connections, 83 | // or will get closed immadiately after idle listeners accepts connection 84 | // and returns from the serve loop. 85 | conn, err := net.DialTimeout("tcp", l.localAddr(), defaultTimeout) 86 | if err == nil { 87 | conn.Close() 88 | } 89 | } 90 | } 91 | 92 | func (vaddr *vaddrStorage) Add(l net.Listener, ip net.IP, ident string) { 93 | vaddr.mu.Lock() 94 | defer vaddr.mu.Unlock() 95 | 96 | lis, ok := vaddr.listeners[l] 97 | if !ok { 98 | lis = vaddr.newListener(l) 99 | vaddr.listeners[l] = lis 100 | go lis.serve() 101 | } 102 | 103 | if ip != nil { 104 | lis.ips[ip.String()] = struct{}{} 105 | vaddr.ips[ip.String()] = ident 106 | } else { 107 | vaddr.ports[mustPort(l)] = ident 108 | } 109 | } 110 | 111 | func (vaddr *vaddrStorage) Delete(l net.Listener, ip net.IP) { 112 | vaddr.mu.Lock() 113 | defer vaddr.mu.Unlock() 114 | 115 | lis, ok := vaddr.listeners[l] 116 | if !ok { 117 | return 118 | } 119 | 120 | var stop bool 121 | 122 | if ip != nil { 123 | delete(lis.ips, ip.String()) 124 | delete(vaddr.ips, ip.String()) 125 | 126 | stop = len(lis.ips) == 0 127 | } else { 128 | delete(vaddr.ports, mustPort(l)) 129 | 130 | stop = true 131 | } 132 | 133 | // Only stop listening for connections when listener has clients 134 | // registered to 
tunnel the connections to. 135 | if stop { 136 | lis.stop() 137 | delete(vaddr.listeners, l) 138 | } 139 | } 140 | 141 | func (vaddr *vaddrStorage) newListener(l net.Listener) *listener { 142 | return &listener{ 143 | Listener: l, 144 | vaddrOptions: vaddr.vaddrOptions, 145 | ips: make(map[string]struct{}), 146 | } 147 | } 148 | 149 | func (vaddr *vaddrStorage) getIdent(conn net.Conn) (string, bool) { 150 | vaddr.mu.Lock() 151 | defer vaddr.mu.Unlock() 152 | 153 | ip, port, err := parseHostPort(conn.LocalAddr().String()) 154 | if err != nil { 155 | vaddr.log.Debug("failed to get identifier for connection %q: %s", conn.LocalAddr(), err) 156 | return "", false 157 | } 158 | 159 | // First lookup if there's a ip-based route, then try port-base one. 160 | 161 | if ident, ok := vaddr.ips[ip]; ok { 162 | return ident, true 163 | } 164 | 165 | ident, ok := vaddr.ports[port] 166 | return ident, ok 167 | } 168 | 169 | func mustPort(l net.Listener) int { 170 | _, port, err := parseHostPort(l.Addr().String()) 171 | if err != nil { 172 | // This can happened when user passed custom type that 173 | // implements net.Listener, which returns ill-formed 174 | // net.Addr value. 
175 | panic("ill-formed net.Addr: " + err.Error()) 176 | } 177 | 178 | return port 179 | } 180 | -------------------------------------------------------------------------------- /virtualhost.go: -------------------------------------------------------------------------------- 1 | package tunnel 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type vhostStorage interface { 8 | // AddHost adds the given host and identifier to the storage 9 | AddHost(host, identifier string) 10 | 11 | // DeleteHost deletes the given host 12 | DeleteHost(host string) 13 | 14 | // GetHost returns the host name for the given identifier 15 | GetHost(identifier string) (string, bool) 16 | 17 | // GetIdentifier returns the identifier for the given host 18 | GetIdentifier(host string) (string, bool) 19 | } 20 | 21 | type virtualHost struct { 22 | identifier string 23 | } 24 | 25 | // virtualHosts is used for mapping host to users example: host 26 | // "fs-1-fatih.kd.io" belongs to user "arslan" 27 | type virtualHosts struct { 28 | mapping map[string]*virtualHost 29 | sync.Mutex 30 | } 31 | 32 | // newVirtualHosts provides an in memory virtual host storage for mapping 33 | // virtual hosts to identifiers. 
34 | func newVirtualHosts() *virtualHosts { 35 | return &virtualHosts{ 36 | mapping: make(map[string]*virtualHost), 37 | } 38 | } 39 | 40 | func (v *virtualHosts) AddHost(host, identifier string) { 41 | v.Lock() 42 | v.mapping[host] = &virtualHost{identifier: identifier} 43 | v.Unlock() 44 | } 45 | 46 | func (v *virtualHosts) DeleteHost(host string) { 47 | v.Lock() 48 | delete(v.mapping, host) 49 | v.Unlock() 50 | } 51 | 52 | // GetIdentifier returns the identifier associated with the given host 53 | func (v *virtualHosts) GetIdentifier(host string) (string, bool) { 54 | v.Lock() 55 | ht, ok := v.mapping[host] 56 | v.Unlock() 57 | 58 | if !ok { 59 | return "", false 60 | } 61 | 62 | return ht.identifier, true 63 | } 64 | 65 | // GetHost returns the host associated with the given identifier 66 | func (v *virtualHosts) GetHost(identifier string) (string, bool) { 67 | v.Lock() 68 | defer v.Unlock() 69 | 70 | for hostname, hst := range v.mapping { 71 | if hst.identifier == identifier { 72 | return hostname, true 73 | } 74 | } 75 | 76 | return "", false 77 | } 78 | -------------------------------------------------------------------------------- /websocket_test.go: -------------------------------------------------------------------------------- 1 | package tunnel_test 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/koding/tunnel/tunneltest" 10 | ) 11 | 12 | func testWebsocket(name string, n int, t *testing.T, tt *tunneltest.TunnelTest) { 13 | conn, err := websocketDial(tt, "http") 14 | if err != nil { 15 | t.Fatalf("Dial()=%s", err) 16 | } 17 | defer conn.Close() 18 | 19 | for i := 0; i < n; i++ { 20 | want := &EchoMessage{ 21 | Value: fmt.Sprintf("message #%d", i), 22 | Close: i == (n - 1), 23 | } 24 | 25 | err := conn.WriteJSON(want) 26 | if err != nil { 27 | t.Errorf("(test %s) %d: failed sending %q: %s", name, i, want, err) 28 | continue 29 | } 30 | 31 | got := &EchoMessage{} 32 | 33 | err = conn.ReadJSON(got) 34 | if err != 
nil { 35 | t.Errorf("(test %s) %d: failed reading: %s", name, i, err) 36 | continue 37 | } 38 | 39 | if !reflect.DeepEqual(got, want) { 40 | t.Errorf("(test %s) %d: got %+v, want %+v", name, i, got, want) 41 | } 42 | } 43 | } 44 | 45 | func testHandler(t *testing.T, fn func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc { 46 | return func(w http.ResponseWriter, r *http.Request) { 47 | if err := fn(w, r); err != nil { 48 | t.Errorf("handler func error: %s", err) 49 | } 50 | } 51 | } 52 | 53 | func TestWebsocket(t *testing.T) { 54 | tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(nil)))) 55 | if err != nil { 56 | t.Fatal(err) 57 | } 58 | 59 | testWebsocket("handlerEchoWS", 100, t, tt) 60 | } 61 | 62 | func TestLatencyWebsocket(t *testing.T) { 63 | tt, err := tunneltest.Serve(singleHTTP(testHandler(t, handlerEchoWS(sleep)))) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | 68 | testWebsocket("handlerLatencyEchoWS", 20, t, tt) 69 | } 70 | --------------------------------------------------------------------------------