├── .gitignore
├── .travis.yml
├── README.md
├── finagle-websocket
└── .travis.yml
├── project
├── Build.scala
├── build.properties
└── plugins.sbt
├── sbt
└── src
├── main
└── scala
│ └── com
│ └── twitter
│ └── finagle
│ ├── WebSocket.scala
│ └── websocket
│ ├── ClientDispatcher.scala
│ ├── Frame.scala
│ ├── Netty3.scala
│ ├── Request.scala
│ ├── Response.scala
│ ├── ServerDispatcher.scala
│ └── WebSocketHandler.scala
└── test
└── scala
└── com
└── twitter
└── finagle
└── websocket
├── EndToEndTest.scala
└── ServerDispatcherTest.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | *.class
2 | *.log
3 |
4 | # sbt specific
5 | dist/*
6 | target/
7 | lib_managed/
8 | src_managed/
9 | project/boot/
10 | project/plugins/project/
11 |
12 | sbt-launch.jar
13 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 | dist: trusty
3 | sudo: false
4 |
5 | scala:
6 | - 2.11.11
7 | - 2.12.2
8 |
9 | jdk:
10 | - oraclejdk8
11 | - openjdk8
12 |
13 | script: sbt ++$TRAVIS_SCALA_VERSION coverage test it:test
14 | after_success: sbt ++$TRAVIS_SCALA_VERSION coveralls
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Finagle Websocket
2 |
3 | [](https://travis-ci.org/finagle/finagle-websocket)
4 | [](https://coveralls.io/r/finagle/finagle-websocket?branch=master)
5 | [](#status)
6 |
7 | Websockets support for Finagle
8 |
9 | ## Status
10 |
11 | This project is inactive. While we keep it up to date with new Finagle
12 | releases, it is not currently being actively used or developed. If you're using
13 | Finagle Websocket and would be interested in discussing co-ownership, please
14 | [file an issue](https://github.com/finagle/finagle-websocket/issues).
15 |
16 | ## Using websockets
17 |
18 | ### Adding dependencies
19 |
20 | Maven
21 |
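22 |     <repositories>
23 |       <repository>
24 |         <id>com.github.sprsquish</id>
25 |         <url>https://raw.github.com/sprsquish/mvn-repo/master/</url>
26 |         <layout>default</layout>
27 |       </repository>
28 |     </repositories>
29 |
30 |     <dependency>
31 |       <groupId>com.github.sprsquish</groupId>
32 |       <artifactId>finagle-websockets_2.9.2</artifactId>
33 |       <version>6.8.1</version>
34 |       <scope>compile</scope>
35 |     </dependency>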
36 |
37 | sbt
38 |
39 | resolvers += "com.github.sprsquish" at "https://raw.github.com/sprsquish/mvn-repo/master"
40 |
41 | "com.github.sprsquish" %% "finagle-websockets" % "6.8.1"
42 |
43 | ## Example
44 |
45 | The following client and server can be run by pasting the code into `sbt
46 | console` (start the server first). The client sends "1" to the server. The
47 | server responds by translating the numeral into a word (mostly). When the
48 | client receives this word, it sends back a number giving the count of
49 | characters in the received word.
50 |
51 | ### Client
52 |
53 | ```scala
54 | import com.twitter.concurrent.AsyncStream
55 | import com.twitter.conversions.time._
56 | import com.twitter.finagle.Websocket
57 | import com.twitter.finagle.util.DefaultTimer
58 | import com.twitter.finagle.websocket.{Frame, Request}
59 | import com.twitter.util.{Future, Promise}
60 | import java.net.URI
61 |
62 | implicit val timer = DefaultTimer.twitter
63 |
64 | // Responds to messages from the server.
65 | def handler(messages: AsyncStream[Frame]): AsyncStream[Frame] =
66 | messages.flatMap {
67 | case Frame.Text(message) =>
68 | // Print the received message.
69 | println(message)
70 |
71 | AsyncStream.fromFuture(
72 | // Sleep for a second...
73 | Future.sleep(1.second).map { _ =>
74 | // ... and then send a message to the server.
75 | Frame.Text(message.length.toString)
76 | })
77 |
78 | case _ => AsyncStream.of(Frame.Text("??"))
79 | }
80 |
81 | val incoming = new Promise[AsyncStream[Frame]]
82 | val outgoing =
83 | Frame.Text("1") +:: handler(
84 | AsyncStream.fromFuture(incoming).flatten)
85 |
86 | val client = Websocket.client.newService(":14000")
87 | val req = Request(new URI("/"), Map.empty, null, outgoing)
88 |
89 | // Take the messages of the response and fulfill `incoming`.
90 | client(req).map(_.messages).proxyTo(incoming)
91 | ```
92 |
93 | ### Server
94 |
95 | ```scala
96 | import com.twitter.concurrent.AsyncStream
97 | import com.twitter.finagle.{Service, Websocket}
98 | import com.twitter.finagle.websocket.{Frame, Request, Response}
99 | import com.twitter.util.Future
100 |
101 | // A server that when given a number, responds with a word (mostly).
102 | def handler(messages: AsyncStream[Frame]): AsyncStream[Frame] = {
103 | messages.map {
104 | case Frame.Text("1") => Frame.Text("one")
105 | case Frame.Text("2") => Frame.Text("two")
106 | case Frame.Text("3") => Frame.Text("three")
107 | case Frame.Text("4") => Frame.Text("cuatro")
108 | case Frame.Text("5") => Frame.Text("five")
109 | case Frame.Text("6") => Frame.Text("6")
110 | case _ => Frame.Text("??")
111 | }
112 | }
113 |
114 | Websocket.serve(":14000", new Service[Request, Response] {
115 | def apply(req: Request): Future[Response] =
116 | Future.value(Response(handler(req.messages)))
117 | })
118 | ```
119 |
--------------------------------------------------------------------------------
/finagle-websocket/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 | script: "sbt +test"
3 |
--------------------------------------------------------------------------------
/project/Build.scala:
--------------------------------------------------------------------------------
1 | import sbt._
2 | import Keys._
3 |
/**
 * sbt build definition for finagle-websocket.
 *
 * The artifact version tracks the Finagle version it is built against
 * (`libVersion`), and the project is cross-built for Scala 2.11 and 2.12.
 */
object FinagleWebsocket extends Build {
  // Finagle version depended upon; also used as this project's own version.
  val libVersion = "17.12.0"

  /** Core sbt settings plus the runtime (Finagle) and test dependencies. */
  val baseSettings = Defaults.coreDefaultSettings ++ Seq(
    libraryDependencies ++= Seq(
      "com.twitter" %% "finagle-core" % libVersion,
      "com.twitter" %% "finagle-netty3" % libVersion,
      "org.scalatest" %% "scalatest" % "3.0.1" % Test,
      "junit" % "junit" % "4.12" % Test
    )
  )

  /** Coordinates, Scala versions and compiler flags. */
  lazy val buildSettings = Seq(
    organization := "com.github.finagle",
    version := libVersion,
    scalaVersion := "2.12.2",
    crossScalaVersions := Seq("2.11.11", "2.12.2"),
    scalacOptions ++= Seq("-deprecation", "-feature", "-Xexperimental")
  )

  /**
   * Maven-style publishing to a local directory (~/workspace/mvn-repo),
   * with the POM metadata (license, homepage, SCM, developers).
   */
  lazy val publishSettings = Seq(
    publishMavenStyle := true,
    publishArtifact := true,
    publishTo := Some(Resolver.file("localDirectory", file(Path.userHome.absolutePath + "/workspace/mvn-repo"))),
    licenses := Seq("Apache 2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0")),
    homepage := Some(url("https://github.com/finagle/finagle-websocket")),
    // pomExtra takes an XML literal; the elements below are appended to the POM.
    pomExtra := (
      <scm>
        <url>git://github.com/finagle/finagle-websocket.git</url>
        <connection>scm:git://github.com/finagle/finagle-websocket.git</connection>
      </scm>
      <developers>
        <developer>
          <id>sprsquish</id>
          <name>Jeff Smick</name>
          <url>https://github.com/sprsquish</url>
        </developer>
      </developers>
    )
  )

  /** The single root project; `IntegrationTest` is enabled for `it:test`. */
  lazy val finagleWebsocket = Project(
    id = "finagle-websocket",
    base = file("."),
    settings =
      Defaults.itSettings ++
      baseSettings ++
      buildSettings ++
      publishSettings
  ).configs(IntegrationTest)
}
54 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=0.13.13
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2")
2 |
3 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")
4 |
5 | addSbtPlugin("org.scoverage" % "sbt-coveralls" % "1.1.0")
6 |
--------------------------------------------------------------------------------
/sbt:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Bootstraps sbt: downloads a pinned sbt launcher on first use, verifies its
# SHA-1 checksum, then runs it with the JVM options below. All arguments are
# forwarded to sbt.

sbtver=0.13.13
sbtjar=sbt-launch.jar
# Expected SHA-1 digest of $sbtjar (checked with `openssl dgst -sha1` below).
sbtsha1=57d0f04f4b48b11ef7e764f4cea58dee4e806ffd

# Download over HTTPS; the checksum is verified regardless.
sbtrepo=https://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch

if [ ! -f "$sbtjar" ]; then
  echo "downloading $sbtjar" 1>&2
  if ! curl --location --silent --fail --remote-name "$sbtrepo/$sbtver/$sbtjar"; then
    exit 1
  fi
fi

checksum=$(openssl dgst -sha1 "$sbtjar" | awk '{ print $2 }')
if [ "$checksum" != "$sbtsha1" ]; then
  echo "bad $sbtjar. delete $sbtjar and run $0 again." 1>&2
  exit 1
fi

# Optional per-user overrides (sourced, so it may set SBT_OPTS / JAVA_OPTS).
[ -f ~/.sbtconfig ] && . ~/.sbtconfig

# $SBT_OPTS and $JAVA_OPTS are deliberately unquoted so they split into
# separate JVM flags. -XX:MaxPermSize is no longer passed: CI runs on JDK 8
# (see .travis.yml), where PermGen does not exist and the flag was ignored.
# `exec` replaces this shell so signals reach the JVM directly.
exec java -ea \
  $SBT_OPTS \
  $JAVA_OPTS \
  -Djava.net.preferIPv4Stack=true \
  -XX:+AggressiveOpts \
  -XX:+UseParNewGC \
  -XX:+UseConcMarkSweepGC \
  -XX:+CMSParallelRemarkEnabled \
  -XX:+CMSClassUnloadingEnabled \
  -XX:ReservedCodeCacheSize=128m \
  -XX:SurvivorRatio=128 \
  -XX:MaxTenuringThreshold=0 \
  -Xss8M \
  -Xms512M \
  -Xmx2G \
  -server \
  -jar "$sbtjar" "$@"
42 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/WebSocket.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle
2 |
3 | import com.twitter.finagle.netty3._
4 | import com.twitter.finagle.param.{Label, ProtocolLibrary, Stats}
5 | import com.twitter.finagle.client.{StackClient, StdStackClient, Transporter}
6 | import com.twitter.finagle.server.{Listener, StackServer, StdStackServer}
7 | import com.twitter.finagle.websocket.{ClientDispatcher, Netty3, Request, Response, ServerDispatcher}
8 | import com.twitter.finagle.transport.{Transport, TransportContext}
9 | import com.twitter.util.Closable
10 | import java.net.SocketAddress
11 |
12 | import com.twitter.finagle.ssl.client.SslClientConfiguration
13 | import org.jboss.netty.channel.Channel
14 |
/**
 * Finagle client and server for WebSockets, built on the Netty 3 pipelines
 * in `com.twitter.finagle.websocket.Netty3`.
 */
object Websocket extends Server[Request, Response] {

  /**
   * A stack-based WebSocket client.
   *
   * @param stack  the client stack (defaults to `StackClient.newStack`)
   * @param params client parameters (defaults to `StackClient.defaultParams`
   *               with the protocol library set to "ws")
   */
  case class Client(
      stack: Stack[ServiceFactory[Request, Response]] = StackClient.newStack,
      params: Stack.Params = StackClient.defaultParams + ProtocolLibrary("ws"))
    extends StdStackClient[Request, Response, Client] {

    override protected type In = Any
    override protected type Out = Any
    override protected type Context = TransportContext

    override protected def copy1(
      stack: Stack[ServiceFactory[Request, Response]] = this.stack,
      params: Stack.Params = this.params
    ): Client = copy(stack, params)

    override protected def newTransporter(addr: SocketAddress): Transporter[In, Out, Context] =
      Netty3.newTransporter(addr, params)

    /** One dispatcher per transport, built without a stats receiver. */
    override protected def newDispatcher(
      transport: Transport[In, Out] { type Context <: TransportContext }
    ): Service[Request, Response] =
      new ClientDispatcher(transport)

    /** Enables TLS without validating the server's certificate. */
    def withTlsWithoutValidation: Client =
      withTransport.tlsWithoutValidation

    /** Enables TLS for the given `hostname`. */
    def withTls(hostname: String): Client =
      withTransport.tls(hostname)

    /** Enables TLS using the supplied client SSL configuration. */
    def withTls(cfg: SslClientConfiguration): Client =
      withTransport.tls.configured(Transport.ClientSsl(Some(cfg)))
  }

  /** The default WebSocket client. */
  val client: Client = Client()

  /**
   * A stack-based WebSocket server.
   *
   * @param stack  the server stack (defaults to `StackServer.newStack`)
   * @param params server parameters (defaults to `StackServer.defaultParams`
   *               with the protocol library set to "ws")
   */
  case class Server(
      stack: Stack[ServiceFactory[Request, Response]] = StackServer.newStack,
      params: Stack.Params = StackServer.defaultParams + ProtocolLibrary("ws"))
    extends StdStackServer[Request, Response, Server] {

    protected type In = Any
    protected type Out = Any
    protected type Context = TransportContext

    // Dispatcher stats live under the "websocket" scope of the configured receiver.
    private[this] val statsReceiver = params[Stats].statsReceiver.scope("websocket")

    protected def copy1(
      stack: Stack[ServiceFactory[Request, Response]] = this.stack,
      params: Stack.Params = this.params
    ): Server = copy(stack, params)

    protected def newListener(): Listener[In, Out, Context] =
      Netty3.newListener(params)

    protected def newDispatcher(
      transport: Transport[In, Out] { type Context <: Server.this.Context },
      service: Service[Request, Response]
    ): Closable =
      new ServerDispatcher(transport, service, statsReceiver)
  }

  /** The default WebSocket server. */
  val server: Websocket.Server = Server()

  /** Serves `factory` on `addr` using the default server. */
  def serve(
    addr: SocketAddress,
    factory: ServiceFactory[Request, Response]
  ): ListeningServer =
    server.serve(addr, factory)
}
82 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/websocket/ClientDispatcher.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.concurrent.AsyncStream
4 | import com.twitter.finagle.dispatch.GenSerialClientDispatcher
5 | import com.twitter.finagle.stats.{NullStatsReceiver, StatsReceiver}
6 | import com.twitter.finagle.transport.Transport
7 | import com.twitter.io.Buf
8 | import com.twitter.util.{Closable, Future, Promise, Time}
9 | import org.jboss.netty.handler.codec.http.HttpRequest
10 | import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame
11 |
12 | private[finagle] class ClientDispatcher(
13 | trans: Transport[Any, Any],
14 | statsReceiver: StatsReceiver)
15 | extends GenSerialClientDispatcher[Request, Response, Any, Any](
16 | trans,
17 | statsReceiver) {
18 |
19 | import Netty3.{fromNetty, newHandshaker, toNetty}
20 | import GenSerialClientDispatcher.wrapWriteException
21 |
22 | def this(trans: Transport[Any, Any]) =
23 | this(trans, NullStatsReceiver)
24 |
25 | private[this] def messages(): AsyncStream[Frame] =
26 | AsyncStream.fromFuture(trans.read()).flatMap {
27 | case _: CloseWebSocketFrame => AsyncStream.empty
28 | case frame => fromNetty(frame) +:: messages()
29 | }
30 |
31 | protected def dispatch(req: Request, p: Promise[Response]): Future[Unit] = {
32 | p.setValue(Response(messages))
33 |
34 | val handshake = newHandshaker(req.uri, req.headers)
35 | trans.write(handshake).rescue(wrapWriteException) before
36 | req.messages.foreachF(msg => trans.write(toNetty(msg))) before
37 | trans.write(new CloseWebSocketFrame)
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/websocket/Frame.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.io.Buf
4 |
/**
 * A WebSocket frame as seen by users of this library.
 *
 * This models a subset of the frame types from RFC 6455
 * (https://tools.ietf.org/html/rfc6455). Close frames have no variant: the
 * dispatchers and pipeline handlers deal with them and run the closing
 * handshake. Continuation frames have no variant of their own either; see
 * `Netty3.fromNetty` for the variant they are surfaced as. Because of that,
 * message fragmentation boundaries are not visible through this type.
 */
sealed trait Frame

/** The concrete [[Frame]] variants. */
object Frame {

  /** A text data frame. */
  case class Text(text: String) extends Frame

  /** A binary data frame. */
  case class Binary(buf: Buf) extends Frame

  /** A ping control frame and its payload. */
  case class Ping(buf: Buf) extends Frame

  /** A pong control frame and its payload. */
  case class Pong(buf: Buf) extends Frame
}
23 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/websocket/Netty3.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.finagle.Stack
4 | import com.twitter.finagle.client.Transporter
5 | import com.twitter.finagle.netty3._
6 | import com.twitter.finagle.server.Listener
7 | import com.twitter.finagle.transport.TransportContext
8 | import java.net.{SocketAddress, URI}
9 |
10 | import org.jboss.netty.channel.{ChannelPipelineFactory, Channels}
11 | import org.jboss.netty.handler.codec.http._
12 | import org.jboss.netty.handler.codec.http.websocketx._
13 |
14 | import scala.collection.JavaConverters._
15 |
/**
 * Netty 3 wiring for WebSocket clients and servers: channel pipelines and
 * conversions between Netty WebSocket frames and [[Frame]].
 */
private[finagle] object Netty3 {
  import Frame._

  /** Server pipeline: HTTP request decoder, HTTP response encoder, handshake handler. */
  private def serverPipeline = {
    val pipeline = Channels.pipeline()
    pipeline.addLast("decoder", new HttpRequestDecoder)
    pipeline.addLast("encoder", new HttpResponseEncoder)
    pipeline.addLast("handler", new WebSocketServerHandler)
    pipeline
  }

  /** Client pipeline: HTTP response decoder, HTTP request encoder, handshake handler. */
  private def clientPipeline() = {
    val pipeline = Channels.pipeline()
    pipeline.addLast("decoder", new HttpResponseDecoder)
    pipeline.addLast("encoder", new HttpRequestEncoder)
    pipeline.addLast("handler", new WebSocketClientHandler)
    pipeline
  }

  /** A Netty 3 listener whose channels use the server pipeline. */
  def newListener[In, Out](params: Stack.Params): Listener[In, Out, TransportContext] =
    Netty3Listener(() => serverPipeline, params)

  /** A Netty 3 transporter to `addr` whose channels use the client pipeline. */
  def newTransporter[In, Out](
    addr: SocketAddress,
    params: Stack.Params
  ): Transporter[In, Out, TransportContext] =
    Netty3Transporter[In, Out](() => clientPipeline(), addr, params)

  /**
   * Converts an inbound Netty message to a [[Frame]].
   *
   * Continuation frames become [[Frame.Binary]] carrying the raw payload.
   * Decoding each fragment independently as text is wrong for fragments of
   * binary messages and can split multi-byte characters across fragments.
   *
   * @throws IllegalStateException if `m` is not a supported WebSocket frame
   */
  def fromNetty(m: Any): Frame = m match {
    case text: TextWebSocketFrame =>
      Text(text.getText)

    case cont: ContinuationWebSocketFrame =>
      Binary(new ChannelBufferBuf(cont.getBinaryData))

    case bin: BinaryWebSocketFrame =>
      Binary(new ChannelBufferBuf(bin.getBinaryData))

    case ping: PingWebSocketFrame =>
      Ping(new ChannelBufferBuf(ping.getBinaryData))

    case pong: PongWebSocketFrame =>
      Pong(new ChannelBufferBuf(pong.getBinaryData))

    case frame =>
      throw new IllegalStateException(s"unknown frame: $frame")
  }

  /** Converts a [[Frame]] to the corresponding Netty frame for writing. */
  def toNetty(frame: Frame): WebSocketFrame = frame match {
    case Text(message) =>
      new TextWebSocketFrame(message)

    case Binary(buf) =>
      new BinaryWebSocketFrame(BufChannelBuffer(buf))

    case Ping(buf) =>
      new PingWebSocketFrame(BufChannelBuffer(buf))

    case Pong(buf) =>
      new PongWebSocketFrame(BufChannelBuffer(buf))
  }

  /**
   * Creates a version-13 client handshaker for `uri`.
   *
   * `headers` are sent as custom handshake headers. No subprotocol is
   * requested (`null`), and extensions are disabled (`false`).
   */
  def newHandshaker(uri: URI, headers: Map[String, String]): WebSocketClientHandshaker = {
    val factory = new WebSocketClientHandshakerFactory
    factory.newHandshaker(uri, WebSocketVersion.V13, null, false, headers.asJava)
  }
}
83 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/websocket/Request.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.concurrent.AsyncStream
4 | import java.net.{SocketAddress, URI}
5 | import scala.collection.immutable
6 |
/**
 * A WebSocket request.
 *
 * @param uri           the request URI; the client sends it in the opening handshake
 * @param headers       HTTP headers of the opening handshake (sent by the
 *                      client, collected from the upgrade request by the server)
 * @param remoteAddress the peer address; set by the server dispatcher and not
 *                      read by the client dispatcher
 * @param messages      the frames sent by the requester
 */
case class Request(
  uri: URI,
  headers: immutable.Map[String, String],
  remoteAddress: SocketAddress,
  messages: AsyncStream[Frame]
)
12 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/websocket/Response.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.concurrent.AsyncStream
4 |
/**
 * A WebSocket response.
 *
 * @param messages the frames sent back to the requester
 */
case class Response(
  messages: AsyncStream[Frame]
)
6 |
7 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/websocket/ServerDispatcher.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.concurrent.AsyncStream
4 | import com.twitter.finagle.Service
5 | import com.twitter.finagle.stats.StatsReceiver
6 | import com.twitter.finagle.transport.Transport
7 | import com.twitter.io.Buf
8 | import com.twitter.util.{Closable, Future, Time}
9 | import java.net.{SocketAddress, URI}
10 | import org.jboss.netty.handler.codec.http.HttpRequest
11 | import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame
12 | import scala.collection.JavaConverters._
13 |
/**
 * Server-side dispatcher for a single WebSocket connection.
 *
 * The first inbound message must be the `(HttpRequest, SocketAddress)` pair
 * fired by the server handler after the handshake. The service is called with
 * the remaining inbound frames as a lazy stream, and the frames of its response
 * are written back. The transport is closed in every outcome: response stream
 * finished, stream failure, service failure, or an unexpected first message.
 *
 * @param trans   the connection's transport
 * @param service the per-connection service; closed when the transport closes
 * @param stats   currently unused
 */
private[finagle] class ServerDispatcher(
    trans: Transport[Any, Any],
    service: Service[Request, Response],
    stats: StatsReceiver)
  extends Closable {

  import Netty3.{fromNetty, toNetty}

  /** Lazily reads frames from the transport until a close frame arrives. */
  private[this] def messages(): AsyncStream[Frame] =
    AsyncStream.fromFuture(trans.read()).flatMap {
      case _: CloseWebSocketFrame => AsyncStream.empty
      case frame => fromNetty(frame) +:: messages()
    }

  // Release the per-connection service once the connection is gone, as
  // Finagle's serial server dispatcher does.
  trans.onClose.ensure { service.close() }

  // The first item is the HttpRequest (with the peer address). Closing in a
  // single `ensure` covers the case where the service (or the first read)
  // fails, which previously left the connection open.
  trans.read().flatMap {
    case (req: HttpRequest, addr: SocketAddress) =>
      val uri = new URI(req.getUri)
      val headers = req.headers.asScala.map(e => e.getKey -> e.getValue).toMap
      service(Request(uri, headers, addr, messages())).flatMap { response =>
        response.messages
          .map(toNetty)
          .foreachF(trans.write)
      }

    case _ =>
      Future.Done
  }.ensure {
    trans.close()
  }

  def close(deadline: Time): Future[Unit] = trans.close(deadline)
}
46 |
--------------------------------------------------------------------------------
/src/main/scala/com/twitter/finagle/websocket/WebSocketHandler.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.concurrent.{Offer, Broker}
4 | import com.twitter.finagle.{CancelledRequestException, ChannelException}
5 | import com.twitter.finagle.util.DefaultTimer
6 | import com.twitter.util.{Future, Promise, Return, Throw, Try, TimerTask}
7 | import java.net.URI
8 | import org.jboss.netty.buffer.ChannelBuffers
9 | import org.jboss.netty.channel._
10 | import org.jboss.netty.handler.codec.http.websocketx._
11 | import org.jboss.netty.handler.codec.http.{HttpHeaders, HttpRequest, HttpResponse}
12 | import scala.collection.JavaConversions._
13 |
/**
 * Server-side Netty handler that performs the WebSocket opening handshake and
 * forwards frames upstream.
 *
 * - An `HttpRequest` starts the handshake. On success, a
 *   `(HttpRequest, SocketAddress)` tuple is fired upstream. For an
 *   unsupported version, an "unsupported version" response is sent instead.
 * - A `CloseWebSocketFrame` runs the closing handshake, closing the channel
 *   afterwards, and is also fired upstream.
 * - Other `WebSocketFrame`s pass through; anything else raises an exception event.
 */
private[finagle] class WebSocketServerHandler extends SimpleChannelUpstreamHandler {
  private[this] var handshaker: Option[WebSocketServerHandshaker] = None

  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit =
    e.getMessage match {
      case req: HttpRequest =>
        val wsFactory = new WebSocketServerHandshakerFactory(webSocketLocation(ctx, req), null, false)
        handshaker = Option(wsFactory.newHandshaker(req))
        handshaker match {
          case None =>
            wsFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel)
          case Some(ref) =>
            ref.handshake(ctx.getChannel, req)
            val addr = ctx.getChannel.getRemoteAddress
            Channels.fireMessageReceived(ctx, (req, addr))
        }

      case frame: CloseWebSocketFrame =>
        handshaker match {
          case Some(hs) =>
            hs.close(ctx.getChannel, frame).addListener(ChannelFutureListener.CLOSE)
            Channels.fireMessageReceived(ctx, frame)

          case None =>
            Channels.fireExceptionCaught(ctx,
              new IllegalArgumentException(s"Close received before handshake"))
        }

      case frame: WebSocketFrame =>
        Channels.fireMessageReceived(ctx, frame)

      case invalid =>
        Channels.fireExceptionCaught(ctx,
          new IllegalArgumentException(s"invalid message: $invalid"))
    }

  /**
   * Builds the absolute WebSocket URL handed to the handshaker factory.
   *
   * The scheme is `wss` when the pipeline contains an `SslHandler`. The request
   * URI is normally a path (e.g. "/chat"), so it cannot reveal TLS. The request
   * path is kept when it is one, and "/" is used otherwise.
   */
  private[this] def webSocketLocation(ctx: ChannelHandlerContext, req: HttpRequest): String = {
    val secure = ctx.getPipeline.get(classOf[org.jboss.netty.handler.ssl.SslHandler]) != null
    val scheme = if (secure) "wss" else "ws"
    val path = if (req.getUri.startsWith("/")) req.getUri else "/"
    scheme + "://" + req.headers.get(HttpHeaders.Names.HOST) + path
  }
}
52 |
/**
 * Client-side Netty handler for the WebSocket opening handshake.
 *
 * Writing a `WebSocketClientHandshaker` starts the handshake. The handshaker
 * is remembered so that the server's `HttpResponse` can complete it. After
 * that, `WebSocketFrame`s pass through in both directions. Unexpected
 * messages raise exception events, and unexpected writes are also failed.
 */
private[finagle] class WebSocketClientHandler extends SimpleChannelHandler {
  // Set on the outbound path when the handshake is written; read on the inbound path.
  @volatile private[this] var ref: Option[WebSocketClientHandshaker] = None

  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit =
    e.getMessage match {
      case res: HttpResponse =>
        ref match {
          case None =>
            throw new IllegalStateException("unexpected HTTP response before handshake")
          case Some(handshaker) if handshaker.isHandshakeComplete =>
            throw new IllegalStateException("unexpected HTTP response after handshake")
          case Some(handshaker) =>
            handshaker.finishHandshake(ctx.getChannel, res)
        }

      case frame: WebSocketFrame =>
        Channels.fireMessageReceived(ctx, frame)

      case invalid =>
        Channels.fireExceptionCaught(ctx,
          new IllegalArgumentException(s"invalid message: $invalid"))
    }

  override def writeRequested(ctx: ChannelHandlerContext, e: MessageEvent): Unit =
    e.getMessage match {
      case handshaker: WebSocketClientHandshaker =>
        ref = Some(handshaker)
        val future = handshaker.handshake(ctx.getChannel)
        future.addListener(ChannelFutureListener.CLOSE_ON_FAILURE)
        // Mirror the handshake's outcome onto the caller's write future.
        future.addListener(new ChannelFutureListener {
          override def operationComplete(f: ChannelFuture): Unit =
            if (f.isSuccess) e.getFuture.setSuccess()
            else if (f.isCancelled) e.getFuture.cancel()
            else e.getFuture.setFailure(f.getCause)
        })

      case frame: WebSocketFrame =>
        ctx.sendDownstream(e)

      // The upgrade request written by `handshaker.handshake(channel)` passes
      // back through this handler; allow it once a handshaker is registered.
      case req: HttpRequest if ref.isDefined =>
        ctx.sendDownstream(e)

      case invalid =>
        val cause = new IllegalArgumentException(s"invalid message: $invalid")
        // Fail the pending write too; otherwise the writer waits on it forever.
        e.getFuture.setFailure(cause)
        Channels.fireExceptionCaught(ctx, cause)
    }
}
100 |
--------------------------------------------------------------------------------
/src/test/scala/com/twitter/finagle/websocket/EndToEndTest.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.concurrent.AsyncStream
4 | import com.twitter.conversions.time._
5 | import com.twitter.finagle
6 | import com.twitter.finagle.Service
7 | import com.twitter.finagle.param.Stats
8 | import com.twitter.finagle.stats.{NullStatsReceiver, StatsReceiver}
9 | import com.twitter.io.Buf
10 | import com.twitter.util._
11 | import java.net.{InetSocketAddress, SocketAddress, URI}
12 | import org.junit.runner.RunWith
13 | import org.scalatest.FunSuite
14 | import org.scalatest.junit.JUnitRunner
15 | import scala.collection.mutable.ArrayBuffer
16 |
@RunWith(classOf[JUnitRunner])
class EndToEndTest extends FunSuite {
  import EndToEndTest._

  test("echo") {
    // A service that sends every received frame straight back.
    val echo = new Service[Request, Response] {
      def apply(req: Request): Future[Response] =
        Future.value(Response(req.messages))
    }

    connect(echo) { client =>
      val sent = texts("hello", "world")
      client(mkRequest("/", sent))
        .flatMap(_.messages.toSeq())
        .map(received => assert(received == sent))
    }
  }
}
36 |
private object EndToEndTest {

  /**
   * Serves `service` on an ephemeral localhost port, connects a client bound
   * directly to that address, and waits up to one second for `run(client)`.
   *
   * Client and server are closed however `run` ends, including when it throws
   * before returning a future.
   */
  def connect(
    service: Service[Request, Response],
    stats: StatsReceiver = NullStatsReceiver
  )(run: Service[Request, Response] => Future[Unit]): Unit = {
    val server = finagle.Websocket.server
      .withLabel("server")
      .configured(Stats(stats))
      .serve("localhost:*", service)

    val addr = server.boundAddress.asInstanceOf[InetSocketAddress]

    // Bind the client to the concrete bound address instead of formatting a
    // "host:port" string, which relied on `getHostName` (possibly a reverse lookup).
    val client = finagle.Websocket.client
      .configured(Stats(stats))
      .newService(finagle.Name.bound(finagle.Address(addr)), "client")

    // Future(...) captures a synchronous throw from `run`, so cleanup always runs.
    val result = Future(run(client)).flatten.ensure {
      Closable.all(client, server).close()
    }
    Await.result(result, 1.second)
  }

  /** Wraps each string in a [[Frame.Text]]. */
  def texts(messages: String*): Seq[Frame] =
    messages.map(Frame.Text(_))

  /** A request for `path` carrying `frames`, with no headers and a dummy address. */
  def mkRequest(path: String, frames: Seq[Frame]): Request =
    Request(new URI(path), Map.empty, new SocketAddress {}, AsyncStream.fromSeq(frames))
}
62 |
--------------------------------------------------------------------------------
/src/test/scala/com/twitter/finagle/websocket/ServerDispatcherTest.scala:
--------------------------------------------------------------------------------
1 | package com.twitter.finagle.websocket
2 |
3 | import com.twitter.concurrent.AsyncQueue
4 | import com.twitter.conversions.time._
5 | import com.twitter.finagle.{Service, Status}
6 | import com.twitter.finagle.stats.DefaultStatsReceiver
7 | import com.twitter.finagle.transport.{QueueTransport, Transport}
8 | import com.twitter.util.{Await, Future}
9 | import java.net.SocketAddress
10 | import org.junit.runner.RunWith
11 | import org.jboss.netty.handler.codec.http._
12 | import org.jboss.netty.handler.codec.http.websocketx._
13 | import org.scalatest.FunSuite
14 | import org.scalatest.junit.JUnitRunner
15 |
@RunWith(classOf[JUnitRunner])
class ServerDispatcherTest extends FunSuite {
  import ServerDispatcherTest._

  // Echoes every received frame back to the peer.
  val echo = new Service[Request, Response] {
    def apply(req: Request): Future[Response] = {
      Future.value(Response(req.messages))
    }
  }

  test("invalid message") {
    val (in, out) = mkPair[Any, Any]
    val disp = new ServerDispatcher(out, echo, DefaultStatsReceiver)
    // The first message must be an (HttpRequest, SocketAddress) pair.
    in.write("invalid")
    Await.ready(out.onClose, 1.second)
    assert(out.status == Status.Closed)
  }

  test("valid message then invalid") {
    val (in, out) = mkPair[Any, Any]
    val disp = new ServerDispatcher(out, echo, DefaultStatsReceiver)
    val req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")
    val addr = new SocketAddress {}
    in.write((req, addr))
    in.write(new TextWebSocketFrame("hello"))
    Await.result(in.read(), 1.second) match {
      case text: TextWebSocketFrame => assert(text.getText == "hello")
      case other => fail(s"expected a TextWebSocketFrame, got $other")
    }
    // An unsupported message makes fromNetty throw, failing the echoed stream,
    // which must close the transport.
    in.write("invalid")
    // Wait for the close instead of assuming callbacks ran synchronously.
    Await.ready(out.onClose, 1.second)
    assert(out.status == Status.Closed)
  }
}
47 |
object ServerDispatcherTest {

  /**
   * Two `QueueTransport`s wired back to back over a shared pair of queues in
   * opposite roles, so whatever one side writes, the other side reads.
   */
  def mkPair[A, B]: (QueueTransport[A, B], QueueTransport[B, A]) = {
    val as = new AsyncQueue[A]
    val bs = new AsyncQueue[B]
    val left = new QueueTransport[A, B](as, bs)
    val right = new QueueTransport[B, A](bs, as)
    (left, right)
  }
}
55 |
--------------------------------------------------------------------------------