├── .gitignore ├── .travis.yml ├── README.md ├── finagle-websocket └── .travis.yml ├── project ├── Build.scala ├── build.properties └── plugins.sbt ├── sbt └── src ├── main └── scala │ └── com │ └── twitter │ └── finagle │ ├── WebSocket.scala │ └── websocket │ ├── ClientDispatcher.scala │ ├── Frame.scala │ ├── Netty3.scala │ ├── Request.scala │ ├── Response.scala │ ├── ServerDispatcher.scala │ └── WebSocketHandler.scala └── test └── scala └── com └── twitter └── finagle └── websocket ├── EndToEndTest.scala └── ServerDispatcherTest.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | 4 | # sbt specific 5 | dist/* 6 | target/ 7 | lib_managed/ 8 | src_managed/ 9 | project/boot/ 10 | project/plugins/project/ 11 | 12 | sbt-launch.jar 13 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | dist: trusty 3 | sudo: false 4 | 5 | scala: 6 | - 2.11.11 7 | - 2.12.2 8 | 9 | jdk: 10 | - oraclejdk8 11 | - openjdk8 12 | 13 | script: sbt ++$TRAVIS_SCALA_VERSION coverage test it:test 14 | after_success: sbt ++$TRAVIS_SCALA_VERSION coveralls 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Finagle Websocket 2 | 3 | [![Build status](https://travis-ci.org/finagle/finagle-websocket.svg?branch=master)](https://travis-ci.org/finagle/finagle-websocket) 4 | [![Coverage status](https://img.shields.io/coveralls/finagle/finagle-websocket/master.svg)](https://coveralls.io/r/finagle/finagle-websocket?branch=master) 5 | [![Project status](https://img.shields.io/badge/status-inactive-yellow.svg)](#status) 6 | 7 | Websockets support for Finagle 8 | 9 | ## Status 10 | 11 | This project is inactive. 
While we keep it up to date with new Finagle 12 | releases, it is not currently being actively used or developed. If you're using 13 | Finagle Websocket and would be interested in discussing co-ownership, please 14 | [file an issue](https://github.com/finagle/finagle-websocket/issues). 15 | 16 | ## Using websockets 17 | 18 | ### Adding dependencies 19 | 20 | Maven 21 | 22 | <repositories> 23 | <repository> 24 | <id>com.github.sprsquish</id> 25 | <url>https://raw.github.com/sprsquish/mvn-repo/master/</url> 26 | <layout>default</layout> 27 | </repository> 28 | </repositories> 29 | 30 | <dependency> 31 | <groupId>com.github.sprsquish</groupId> 32 | <artifactId>finagle-websockets_2.9.2</artifactId> 33 | <version>6.8.1</version> 34 | <scope>compile</scope> 35 | </dependency> 36 | 37 | sbt 38 | 39 | resolvers += "com.github.sprsquish" at "https://raw.github.com/sprsquish/mvn-repo/master" 40 | 41 | "com.github.sprsquish" %% "finagle-websockets" % "6.8.1" 42 | 43 | ## Example 44 | 45 | The following client and server can be run by pasting the code into `sbt 46 | console`. The client sends "1" over to the server. The server responds by 47 | translating the numeral into a word (mostly). When the client receives this 48 | word, it sends back the number of characters in the received 49 | word. 50 | 51 | ### Client 52 | 53 | ```scala 54 | import com.twitter.concurrent.AsyncStream 55 | import com.twitter.conversions.time._ 56 | import com.twitter.finagle.Websocket 57 | import com.twitter.finagle.util.DefaultTimer 58 | import com.twitter.finagle.websocket.{Frame, Request} 59 | import com.twitter.util.{Future, Promise} 60 | import java.net.URI 61 | 62 | implicit val timer = DefaultTimer.twitter 63 | 64 | // Responds to messages from the server. 65 | def handler(messages: AsyncStream[Frame]): AsyncStream[Frame] = 66 | messages.flatMap { 67 | case Frame.Text(message) => 68 | // Print the received message. 69 | println(message) 70 | 71 | AsyncStream.fromFuture( 72 | // Sleep for a second... 73 | Future.sleep(1.second).map { _ => 74 | // ... and then send a message to the server.
75 | Frame.Text(message.length.toString) 76 | }) 77 | 78 | case _ => AsyncStream.of(Frame.Text("??")) 79 | } 80 | 81 | val incoming = new Promise[AsyncStream[Frame]] 82 | val outgoing = 83 | Frame.Text("1") +:: handler( 84 | AsyncStream.fromFuture(incoming).flatten) 85 | 86 | val client = Websocket.client.newService(":14000") 87 | val req = Request(new URI("/"), Map.empty, null, outgoing) 88 | 89 | // Take the messages of the response and fulfill `incoming`. 90 | client(req).map(_.messages).proxyTo(incoming) 91 | ``` 92 | 93 | ### Server 94 | 95 | ```scala 96 | import com.twitter.concurrent.AsyncStream 97 | import com.twitter.finagle.{Service, Websocket} 98 | import com.twitter.finagle.websocket.{Frame, Request, Response} 99 | import com.twitter.util.Future 100 | 101 | // A server that when given a number, responds with a word (mostly). 102 | def handler(messages: AsyncStream[Frame]): AsyncStream[Frame] = { 103 | messages.map { 104 | case Frame.Text("1") => Frame.Text("one") 105 | case Frame.Text("2") => Frame.Text("two") 106 | case Frame.Text("3") => Frame.Text("three") 107 | case Frame.Text("4") => Frame.Text("cuatro") 108 | case Frame.Text("5") => Frame.Text("five") 109 | case Frame.Text("6") => Frame.Text("6") 110 | case _ => Frame.Text("??") 111 | } 112 | } 113 | 114 | Websocket.serve(":14000", new Service[Request, Response] { 115 | def apply(req: Request): Future[Response] = 116 | Future.value(Response(handler(req.messages))) 117 | }) 118 | ``` 119 | -------------------------------------------------------------------------------- /finagle-websocket/.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | script: "sbt +test" 3 | -------------------------------------------------------------------------------- /project/Build.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | 4 | object FinagleWebsocket extends Build { 5 | val
libVersion = "17.12.0" 6 | // The artifact version tracks the Finagle release it is built against (see finagle-* dependencies below). 7 | val baseSettings = Defaults.coreDefaultSettings ++ Seq( 8 | libraryDependencies ++= Seq( 9 | "com.twitter" %% "finagle-core" % libVersion, 10 | "com.twitter" %% "finagle-netty3" % libVersion, 11 | "org.scalatest" %% "scalatest" % "3.0.1" % Test, 12 | "junit" % "junit" % "4.12" % Test 13 | ) 14 | ) 15 | 16 | lazy val buildSettings = Seq( 17 | organization := "com.github.finagle", 18 | version := libVersion, 19 | scalaVersion := "2.12.2", 20 | crossScalaVersions := Seq("2.11.11", "2.12.2"), 21 | scalacOptions ++= Seq("-deprecation", "-feature", "-Xexperimental") 22 | ) 23 | // Publishes Maven-style artifacts into a local directory (~/workspace/mvn-repo), with the SCM/developer POM metadata below. 24 | lazy val publishSettings = Seq( 25 | publishMavenStyle := true, 26 | publishArtifact := true, 27 | publishTo := Some(Resolver.file("localDirectory", file(Path.userHome.absolutePath + "/workspace/mvn-repo"))), 28 | licenses := Seq("Apache 2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0")), 29 | homepage := Some(url("https://github.com/finagle/finagle-websocket")), 30 | pomExtra := ( 31 | <scm> 32 | <url>git://github.com/finagle/finagle-websocket.git</url> 33 | <connection>scm:git://github.com/finagle/finagle-websocket.git</connection> 34 | </scm> 35 | <developers> 36 | <developer> 37 | <id>sprsquish</id> 38 | <name>Jeff Smick</name> 39 | <url>https://github.com/sprsquish</url> 40 | </developer> 41 | </developers>) 42 | ) 43 | 44 | lazy val finagleWebsocket = Project( 45 | id = "finagle-websocket", 46 | base = file("."), 47 | settings = 48 | Defaults.itSettings ++ 49 | baseSettings ++ 50 | buildSettings ++ 51 | publishSettings 52 | ).configs(IntegrationTest) 53 | } 54 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.13 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") 2 | 3 | addSbtPlugin("org.scoverage" %
"sbt-scoverage" % "1.5.0") 4 | 5 | addSbtPlugin("org.scoverage" % "sbt-coveralls" % "1.1.0") 6 | -------------------------------------------------------------------------------- /sbt: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sbtver=0.13.13 4 | sbtjar=sbt-launch.jar 5 | sbtsha128=57d0f04f4b48b11ef7e764f4cea58dee4e806ffd 6 | # Despite its name, sbtsha128 holds the SHA-1 digest of the launcher jar (verified below with openssl dgst -sha1). 7 | sbtrepo=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch 8 | 9 | if [ ! -f $sbtjar ]; then 10 | echo "downloading $sbtjar" 1>&2 11 | if ! curl --location --silent --fail --remote-name $sbtrepo/$sbtver/$sbtjar; then 12 | exit 1 13 | fi 14 | fi 15 | 16 | checksum=`openssl dgst -sha1 $sbtjar | awk '{ print $2 }'` 17 | if [ "$checksum" != $sbtsha128 ]; then 18 | echo "bad $sbtjar. delete $sbtjar and run $0 again." 19 | exit 1 20 | fi 21 | 22 | [ -f ~/.sbtconfig ] && . ~/.sbtconfig 23 | # NOTE(review): MaxPermSize is ignored on JDK 8+, and the ParNew/CMS flags are gone in newer JDKs; these options assume the JDK 8 builds used in CI. 24 | java -ea \ 25 | $SBT_OPTS \ 26 | $JAVA_OPTS \ 27 | -Djava.net.preferIPv4Stack=true \ 28 | -XX:+AggressiveOpts \ 29 | -XX:+UseParNewGC \ 30 | -XX:+UseConcMarkSweepGC \ 31 | -XX:+CMSParallelRemarkEnabled \ 32 | -XX:+CMSClassUnloadingEnabled \ 33 | -XX:ReservedCodeCacheSize=128m \ 34 | -XX:MaxPermSize=1024m \ 35 | -XX:SurvivorRatio=128 \ 36 | -XX:MaxTenuringThreshold=0 \ 37 | -Xss8M \ 38 | -Xms512M \ 39 | -Xmx2G \ 40 | -server \ 41 | -jar $sbtjar "$@" 42 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/WebSocket.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle 2 | 3 | import com.twitter.finagle.netty3._ 4 | import com.twitter.finagle.param.{Label, ProtocolLibrary, Stats} 5 | import com.twitter.finagle.client.{StackClient, StdStackClient, Transporter} 6 | import com.twitter.finagle.server.{Listener, StackServer, StdStackServer} 7 | import com.twitter.finagle.websocket.{ClientDispatcher, Netty3, Request, Response, ServerDispatcher} 8 |
import com.twitter.finagle.transport.{Transport, TransportContext} 9 | import com.twitter.util.Closable 10 | import java.net.SocketAddress 11 | 12 | import com.twitter.finagle.ssl.client.SslClientConfiguration 13 | import org.jboss.netty.channel.Channel 14 | /** Finagle WebSocket protocol built on Netty 3: `client` and `server` exchange `websocket.Request` / `websocket.Response` values whose `messages` are streams of frames. */ 15 | object Websocket extends Server[Request, Response] { 16 | case class Client(stack: Stack[ServiceFactory[Request, Response]] = StackClient.newStack, 17 | params: Stack.Params = StackClient.defaultParams + ProtocolLibrary("ws")) 18 | extends StdStackClient[Request, Response, Client] { 19 | override protected type In = Any 20 | override protected type Out = Any 21 | override protected type Context = TransportContext 22 | 23 | override protected def newTransporter(addr: SocketAddress): Transporter[In, Out, Context] = 24 | Netty3.newTransporter(addr, params) 25 | 26 | override protected def copy1( 27 | stack: Stack[ServiceFactory[Request, Response]] = this.stack, 28 | params: Stack.Params = this.params 29 | ): Client = copy(stack, params) 30 | // NOTE(review): uses the single-argument ClientDispatcher constructor, so the configured Stats param is not passed to the client dispatcher. 31 | override protected def newDispatcher(transport: Transport[In, Out] { 32 | type Context <: TransportContext 33 | }): Service[Request, Response] = 34 | new ClientDispatcher(transport) 35 | // TLS convenience helpers; each delegates to `withTransport`. 36 | def withTlsWithoutValidation: Client = withTransport.tlsWithoutValidation 37 | 38 | def withTls(hostname: String): Client = withTransport.tls(hostname) 39 | 40 | def withTls(cfg: SslClientConfiguration): Client = 41 | withTransport.tls.configured(Transport.ClientSsl(Some(cfg))) 42 | } 43 | 44 | val client: Client = Client() 45 | 46 | case class Server( 47 | stack: Stack[ServiceFactory[Request, Response]] = StackServer.newStack, 48 | params: Stack.Params = StackServer.defaultParams + ProtocolLibrary("ws")) 49 | extends StdStackServer[Request, Response, Server] { 50 | 51 | protected type In = Any 52 | protected type Out = Any 53 | protected type Context = TransportContext 54 | 55 | protected def newListener(): Listener[In, Out, Context] = 56 | Netty3.newListener(params) 57 | // Stats receiver scoped under "websocket" and handed to every ServerDispatcher. 58 |
private[this] val statsReceiver = { 59 | val Stats(sr) = params[Stats] 60 | sr.scope("websocket") 61 | } 62 | 63 | protected def newDispatcher( 64 | transport: Transport[In, Out] { type Context <: Server.this.Context }, 65 | service: Service[Request, Response] 66 | ): Closable = 67 | new ServerDispatcher(transport, service, statsReceiver) 68 | 69 | protected def copy1( 70 | stack: Stack[ServiceFactory[Request, Response]] = this.stack, 71 | params: Stack.Params = this.params 72 | ): Server = copy(stack, params) 73 | } 74 | 75 | val server: Websocket.Server = Server() 76 | // Serves `factory` on `addr` using the default `server` configuration. 77 | def serve( 78 | addr: SocketAddress, 79 | factory: ServiceFactory[Request, Response] 80 | ): ListeningServer = server.serve(addr, factory) 81 | } 82 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/websocket/ClientDispatcher.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.concurrent.AsyncStream 4 | import com.twitter.finagle.dispatch.GenSerialClientDispatcher 5 | import com.twitter.finagle.stats.{NullStatsReceiver, StatsReceiver} 6 | import com.twitter.finagle.transport.Transport 7 | import com.twitter.io.Buf 8 | import com.twitter.util.{Closable, Future, Promise, Time} 9 | import org.jboss.netty.handler.codec.http.HttpRequest 10 | import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame 11 | /** Client dispatcher for a single transport: writes the WebSocket handshake, the request's frames and a final Close frame to `trans`, and exposes frames read from it as the response stream. */ 12 | private[finagle] class ClientDispatcher( 13 | trans: Transport[Any, Any], 14 | statsReceiver: StatsReceiver) 15 | extends GenSerialClientDispatcher[Request, Response, Any, Any]( 16 | trans, 17 | statsReceiver) { 18 | 19 | import Netty3.{fromNetty, newHandshaker, toNetty} 20 | import GenSerialClientDispatcher.wrapWriteException 21 | 22 | def this(trans: Transport[Any, Any]) = 23 | this(trans, NullStatsReceiver) 24 | // Lazily reads frames from the transport; a Close frame ends the stream. 25 | private[this] def messages(): AsyncStream[Frame] = 26 |
AsyncStream.fromFuture(trans.read()).flatMap { 27 | case _: CloseWebSocketFrame => AsyncStream.empty 28 | case frame => fromNetty(frame) +:: messages() 29 | } 30 | // Completes `p` immediately with a lazily-read frame stream, then writes the handshake, each request frame, and finally a Close frame; the returned Future tracks those writes. 31 | protected def dispatch(req: Request, p: Promise[Response]): Future[Unit] = { 32 | p.setValue(Response(messages)) 33 | 34 | val handshake = newHandshaker(req.uri, req.headers) 35 | trans.write(handshake).rescue(wrapWriteException) before 36 | req.messages.foreachF(msg => trans.write(toNetty(msg))) before 37 | trans.write(new CloseWebSocketFrame) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/websocket/Frame.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.io.Buf 4 | 5 | /** 6 | * Represents various WebSocket frames. 7 | * 8 | * This is a simplification of the frame types described in RFC6455[1]. Notably 9 | * absent are Continuation and Close. Close is handled directly in the 10 | * pipeline, initiating the close handshake. Continuations are treated as 11 | * Binary frames, which means we lose the ability to determine fragmentation.
12 | * 13 | * [1]: https://tools.ietf.org/html/rfc6455 14 | */ 15 | sealed trait Frame 16 | 17 | object Frame { 18 | case class Text(text: String) extends Frame 19 | case class Binary(buf: Buf) extends Frame 20 | case class Ping(buf: Buf) extends Frame 21 | case class Pong(buf: Buf) extends Frame 22 | } 23 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/websocket/Netty3.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.finagle.Stack 4 | import com.twitter.finagle.client.Transporter 5 | import com.twitter.finagle.netty3._ 6 | import com.twitter.finagle.server.Listener 7 | import com.twitter.finagle.transport.TransportContext 8 | import java.net.{SocketAddress, URI} 9 | 10 | import org.jboss.netty.channel.{ChannelPipelineFactory, Channels} 11 | import org.jboss.netty.handler.codec.http._ 12 | import org.jboss.netty.handler.codec.http.websocketx._ 13 | 14 | import scala.collection.JavaConverters._ 15 | 16 | private[finagle] object Netty3 { 17 | import Frame._ 18 | // Server pipeline: HTTP request decoder / response encoder for the upgrade, then WebSocketServerHandler. 19 | private def serverPipeline = { 20 | val pipeline = Channels.pipeline() 21 | pipeline.addLast("decoder", new HttpRequestDecoder) 22 | pipeline.addLast("encoder", new HttpResponseEncoder) 23 | pipeline.addLast("handler", new WebSocketServerHandler) 24 | pipeline 25 | } 26 | // Client pipeline: HTTP response decoder / request encoder for the upgrade, then WebSocketClientHandler. 27 | private def clientPipeline() = { 28 | val pipeline = Channels.pipeline() 29 | pipeline.addLast("decoder", new HttpResponseDecoder) 30 | pipeline.addLast("encoder", new HttpRequestEncoder) 31 | pipeline.addLast("handler", new WebSocketClientHandler) 32 | pipeline 33 | } 34 | 35 | def newListener[In, Out](params: Stack.Params): Listener[In, Out, TransportContext] = 36 | Netty3Listener(() => serverPipeline, params) 37 | 38 | def newTransporter[In, Out]( 39 | addr: SocketAddress, 40 | params: Stack.Params 41 | ): Transporter[In, Out, TransportContext] = 42 |
Netty3Transporter[In, Out](() => clientPipeline(), addr, params) 43 | // Converts a decoded Netty frame into a Frame. Close frames are not expected here (the dispatchers stop reading at CloseWebSocketFrame); any other unknown frame throws. 44 | def fromNetty(m: Any): Frame = m match { 45 | case text: TextWebSocketFrame => 46 | Text(text.getText) 47 | // Continuations are surfaced as Binary, as documented on Frame; decoding them with getText would mangle fragments of binary messages. 48 | case cont: ContinuationWebSocketFrame => 49 | Binary(new ChannelBufferBuf(cont.getBinaryData)) 50 | 51 | case bin: BinaryWebSocketFrame => 52 | Binary(new ChannelBufferBuf(bin.getBinaryData)) 53 | 54 | case ping: PingWebSocketFrame => 55 | Ping(new ChannelBufferBuf(ping.getBinaryData)) 56 | 57 | case pong: PongWebSocketFrame => 58 | Pong(new ChannelBufferBuf(pong.getBinaryData)) 59 | 60 | case frame => 61 | throw new IllegalStateException(s"unknown frame: $frame") 62 | } 63 | // Inverse of fromNetty for the user-visible frame types. 64 | def toNetty(frame: Frame): WebSocketFrame = frame match { 65 | case Text(message) => 66 | new TextWebSocketFrame(message) 67 | 68 | case Binary(buf) => 69 | new BinaryWebSocketFrame(BufChannelBuffer(buf)) 70 | 71 | case Ping(buf) => 72 | new PingWebSocketFrame(BufChannelBuffer(buf)) 73 | 74 | case Pong(buf) => 75 | new PongWebSocketFrame(BufChannelBuffer(buf)) 76 | } 77 | // Version-13 client handshaker with no subprotocol and extensions disallowed; `headers` are passed as custom handshake headers. 78 | def newHandshaker(uri: URI, headers: Map[String, String]): WebSocketClientHandshaker = { 79 | val factory = new WebSocketClientHandshakerFactory 80 | factory.newHandshaker(uri, WebSocketVersion.V13, null, false, headers.asJava) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/websocket/Request.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.concurrent.AsyncStream 4 | import java.net.{SocketAddress, URI} 5 | import scala.collection.immutable 6 | /** A WebSocket exchange request: target URI, handshake headers, remote address, and the stream of frames sent over the connection. */ 7 | case class Request( 8 | uri: URI, 9 | headers: immutable.Map[String, String], 10 | remoteAddress: SocketAddress, 11 | messages: AsyncStream[Frame]) 12 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/websocket/Response.scala:
-------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.concurrent.AsyncStream 4 | /** Frames for one side of a WebSocket exchange: written to the peer by the server dispatcher, read from the peer by the client dispatcher. */ 5 | case class Response(messages: AsyncStream[Frame]) 6 | 7 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/websocket/ServerDispatcher.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.concurrent.AsyncStream 4 | import com.twitter.finagle.Service 5 | import com.twitter.finagle.stats.StatsReceiver 6 | import com.twitter.finagle.transport.Transport 7 | import com.twitter.io.Buf 8 | import com.twitter.util.{Closable, Future, Time} 9 | import java.net.{SocketAddress, URI} 10 | import org.jboss.netty.handler.codec.http.HttpRequest 11 | import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame 12 | import scala.collection.JavaConverters._ 13 | /** Server dispatcher for one transport: reads the upgrade request, calls `service`, writes the response frames back, and closes the transport when that finishes or fails. `stats` is currently unused. */ 14 | private[finagle] class ServerDispatcher( 15 | trans: Transport[Any, Any], 16 | service: Service[Request, Response], 17 | stats: StatsReceiver) 18 | extends Closable { 19 | 20 | import Netty3.{fromNetty, toNetty} 21 | 22 | private[this] def messages(): AsyncStream[Frame] = 23 | AsyncStream.fromFuture(trans.read()).flatMap { 24 | case _: CloseWebSocketFrame => AsyncStream.empty 25 | case frame => fromNetty(frame) +:: messages() 26 | } 27 | 28 | // The first item is the (HttpRequest, remote address) pair fired by WebSocketServerHandler; anything else ends the connection.
29 | trans.read().flatMap { 30 | case (req: HttpRequest, addr: SocketAddress) => 31 | val uri = new URI(req.getUri) 32 | val headers = req.headers.asScala.map(e => e.getKey -> e.getValue).toMap 33 | service(Request(uri, headers, addr, messages)).flatMap { response => 34 | response.messages 35 | .map(toNetty) 36 | .foreachF(trans.write) 37 | } 38 | 39 | case _ => 40 | Future.Done 41 | }.ensure(trans.close()) // Close on every outcome: stream end, service failure, read failure, or an unexpected first item. 42 | 43 | def close(deadline: Time): Future[Unit] = trans.close() 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/com/twitter/finagle/websocket/WebSocketHandler.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.concurrent.{Offer, Broker} 4 | import com.twitter.finagle.{CancelledRequestException, ChannelException} 5 | import com.twitter.finagle.util.DefaultTimer 6 | import com.twitter.util.{Future, Promise, Return, Throw, Try, TimerTask} 7 | import java.net.URI 8 | import org.jboss.netty.buffer.ChannelBuffers 9 | import org.jboss.netty.channel._ 10 | import org.jboss.netty.handler.codec.http.websocketx._ 11 | import org.jboss.netty.handler.codec.http.{HttpHeaders, HttpRequest, HttpResponse} 12 | import scala.collection.JavaConversions._ 13 | /** Server pipeline handler: answers the upgrade handshake, then passes the (request, remote address) pair and subsequent WebSocket frames upstream. */ 14 | private[finagle] class WebSocketServerHandler extends SimpleChannelUpstreamHandler { 15 | private[this] var handshaker: Option[WebSocketServerHandshaker] = None 16 | // NOTE(review): getUri is normally a request path, so `scheme` below likely always resolves to "ws"; confirm how TLS should be detected. 17 | override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) = 18 | e.getMessage match { 19 | case req: HttpRequest => 20 | val scheme = if(req.getUri.startsWith("wss")) "wss" else "ws" 21 | val location = scheme + "://" + req.headers.get(HttpHeaders.Names.HOST) + "/" 22 | val wsFactory = new WebSocketServerHandshakerFactory(location, null, false) 23 | handshaker = Option(wsFactory.newHandshaker(req)) 24 | handshaker match { 25 | case None => 26 |
wsFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel) 27 | case Some(ref) => 28 | ref.handshake(ctx.getChannel, req) 29 | val addr = ctx.getChannel.getRemoteAddress 30 | Channels.fireMessageReceived(ctx, (req, addr)) 31 | } 32 | 33 | case frame: CloseWebSocketFrame => 34 | handshaker match { 35 | case Some(hs) => 36 | hs.close(ctx.getChannel, frame).addListener(ChannelFutureListener.CLOSE) 37 | Channels.fireMessageReceived(ctx, frame) 38 | 39 | case None => 40 | Channels.fireExceptionCaught(ctx, 41 | new IllegalArgumentException(s"Close received before handshake")) 42 | } 43 | 44 | case frame: WebSocketFrame => 45 | Channels.fireMessageReceived(ctx, frame) 46 | 47 | case invalid => 48 | Channels.fireExceptionCaught(ctx, 49 | new IllegalArgumentException(s"invalid message: $invalid")) 50 | } 51 | } 52 | /** Client pipeline handler: starts the handshake when a WebSocketClientHandshaker is written, finishes it on the HTTP response, and passes WebSocket frames through. */ 53 | private[finagle] class WebSocketClientHandler extends SimpleChannelHandler { 54 | @volatile private[this] var ref: Option[WebSocketClientHandshaker] = None 55 | 56 | override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = 57 | e.getMessage match { 58 | case res: HttpResponse => 59 | ref match { 60 | case None => 61 | throw new IllegalStateException("unexpected HTTP response before handshake") 62 | case Some(handshaker) if handshaker.isHandshakeComplete => 63 | throw new IllegalStateException("unexpected HTTP response after handshake") 64 | case Some(handshaker) => 65 | handshaker.finishHandshake(ctx.getChannel, res) 66 | } 67 | 68 | case frame: WebSocketFrame => 69 | Channels.fireMessageReceived(ctx, frame) 70 | 71 | case invalid => 72 | Channels.fireExceptionCaught(ctx, 73 | new IllegalArgumentException(s"invalid message: $invalid")) 74 | } 75 | // Downstream: a handshaker triggers the upgrade; frames, and HttpRequests once a handshaker is registered (presumably its upgrade request), pass through. 76 | override def writeRequested(ctx: ChannelHandlerContext, e: MessageEvent): Unit = 77 | e.getMessage match { 78 | case handshaker: WebSocketClientHandshaker => 79 | ref = Some(handshaker) 80 | val future = handshaker.handshake(ctx.getChannel) 81 |
future.addListener(ChannelFutureListener.CLOSE_ON_FAILURE) 82 | future.addListener(new ChannelFutureListener { 83 | override def operationComplete(f: ChannelFuture): Unit = 84 | if (f.isSuccess) e.getFuture.setSuccess() 85 | else if (f.isCancelled) e.getFuture.cancel() 86 | else e.getFuture.setFailure(f.getCause) 87 | }) 88 | 89 | case frame: WebSocketFrame => 90 | ctx.sendDownstream(e) 91 | 92 | case req: HttpRequest if !ref.isEmpty => 93 | ctx.sendDownstream(e) 94 | 95 | case invalid => 96 | // Fail the pending write as well as the pipeline; otherwise the caller's write Future never completes. 97 | val exc = new IllegalArgumentException(s"invalid message: $invalid") 98 | e.getFuture.setFailure(exc) 99 | Channels.fireExceptionCaught(ctx, exc) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/test/scala/com/twitter/finagle/websocket/EndToEndTest.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.concurrent.AsyncStream 4 | import com.twitter.conversions.time._ 5 | import com.twitter.finagle 6 | import com.twitter.finagle.Service 7 | import com.twitter.finagle.param.Stats 8 | import com.twitter.finagle.stats.{NullStatsReceiver, StatsReceiver} 9 | import com.twitter.io.Buf 10 | import com.twitter.util._ 11 | import java.net.{InetSocketAddress, SocketAddress, URI} 12 | import org.junit.runner.RunWith 13 | import org.scalatest.FunSuite 14 | import org.scalatest.junit.JUnitRunner 15 | import scala.collection.mutable.ArrayBuffer 16 | 17 | @RunWith(classOf[JUnitRunner]) 18 | class EndToEndTest extends FunSuite { 19 | import Frame._ 20 | import EndToEndTest._ 21 | test("echo") { 22 | val echo = new Service[Request, Response] { 23 | def apply(req: Request): Future[Response] = 24 | Future.value(Response(req.messages)) 25 | } 26 | 27 | connect(echo) { client => 28 | val frames = texts("hello", "world") 29 | for { 30 | response <- client(mkRequest("/", frames)) 31 | messages <- response.messages.toSeq() 32 | } yield assert(messages == frames) 33 | } 34 | } 35 | } 36 | 37
| private object EndToEndTest { 38 | def connect( 39 | service: Service[Request, Response], 40 | stats: StatsReceiver = NullStatsReceiver 41 | )(run: Service[Request, Response] => Future[Unit]): Unit = { 42 | val server = finagle.Websocket.server 43 | .withLabel("server") 44 | .configured(Stats(stats)) 45 | .serve("localhost:*", service) 46 | 47 | val addr = server.boundAddress.asInstanceOf[InetSocketAddress] 48 | 49 | val client = finagle.Websocket.client 50 | .configured(Stats(stats)) 51 | .newService(s"${addr.getHostName}:${addr.getPort}", "client") 52 | // Close both ends however `run` finishes, waiting at most one second for the result. 53 | Await.result(run(client).ensure(Closable.all(client, server).close()), 1.second) 54 | } 55 | 56 | def texts(messages: String*): Seq[Frame] = 57 | messages.map(Frame.Text(_)) 58 | 59 | def mkRequest(path: String, frames: Seq[Frame]): Request = 60 | Request(new URI(path), Map.empty, new SocketAddress{}, AsyncStream.fromSeq(frames)) 61 | } 62 | -------------------------------------------------------------------------------- /src/test/scala/com/twitter/finagle/websocket/ServerDispatcherTest.scala: -------------------------------------------------------------------------------- 1 | package com.twitter.finagle.websocket 2 | 3 | import com.twitter.concurrent.AsyncQueue 4 | import com.twitter.conversions.time._ 5 | import com.twitter.finagle.{Service, Status} 6 | import com.twitter.finagle.stats.DefaultStatsReceiver 7 | import com.twitter.finagle.transport.{QueueTransport, Transport} 8 | import com.twitter.util.{Await, Future} 9 | import java.net.SocketAddress 10 | import org.junit.runner.RunWith 11 | import org.jboss.netty.handler.codec.http._ 12 | import org.jboss.netty.handler.codec.http.websocketx._ 13 | import org.scalatest.FunSuite 14 | import org.scalatest.junit.JUnitRunner 15 | 16 | @RunWith(classOf[JUnitRunner]) 17 | class ServerDispatcherTest extends FunSuite { 18 | import ServerDispatcherTest._ 19 | 20 | val echo = new Service[Request, Response] { 21 | def apply(req: Request): Future[Response] = { 22 |
Future.value(Response(req.messages)) 23 | } 24 | } 25 | 26 | test("invalid message") { 27 | val (in, out) = mkPair[Any, Any] 28 | val disp = new ServerDispatcher(out, echo, DefaultStatsReceiver) 29 | in.write("invalid") 30 | Await.ready(out.onClose, 1.second) 31 | assert(out.status == Status.Closed) 32 | } 33 | 34 | test("valid message then invalid") { 35 | val (in, out) = mkPair[Any, Any] 36 | val disp = new ServerDispatcher(out, echo, DefaultStatsReceiver) 37 | val req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/") 38 | val addr = new SocketAddress{} 39 | in.write((req, addr)) 40 | in.write(new TextWebSocketFrame("hello")) 41 | val frame = Await.result(in.read(), 1.second) 42 | assert(frame.asInstanceOf[TextWebSocketFrame].getText == "hello") 43 | in.write("invalid") 44 | // The close may complete asynchronously after the invalid frame fails the stream; wait for it like the test above. 45 | Await.ready(out.onClose, 1.second) 46 | assert(out.status == Status.Closed) 47 | } 48 | } 49 | // Connected pair of in-memory transports: each one reads what the other writes. 50 | object ServerDispatcherTest { 51 | def mkPair[A,B] = { 52 | val inq = new AsyncQueue[A] 53 | val outq = new AsyncQueue[B] 54 | (new QueueTransport[A, B](inq, outq), new QueueTransport[B, A](outq, inq)) 55 | } 56 | } 57 |