├── project
    ├── plugins.sbt
    └── build.properties
├── .gitignore
└── src
    ├── main
        ├── scala
        │   └── example
        │   │   ├── Main.scala
        │   │   ├── FutureUtils.scala
        │   │   ├── Log.scala
        │   │   ├── CurrentThreadExecutionContext.scala
        │   │   ├── ClientKey.scala
        │   │   ├── ExecutorServiceUtils.scala
        │   │   ├── SetHeadersHandler.scala
        │   │   ├── ClientsRegistry.scala
        │   │   ├── Initializer.scala
        │   │   └── MainHandler.scala
        └── resources
        │   └── logback.xml
    └── test
        └── scala
            └── example
                ├── InitializerSpec.scala
                ├── Http.scala
                ├── SetHeadersHandlerSpec.scala
                ├── ClientsRegistrySpec.scala
                └── MainHandlerSpec.scala

/project/plugins.sbt:
--------------------------------------------------------------------------------
logLevel := Level.Warn
--------------------------------------------------------------------------------

/project/build.properties:
--------------------------------------------------------------------------------
sbt.version = 0.13.8
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
target/*
project/target/*
project/project/*

--------------------------------------------------------------------------------

/src/main/scala/example/Main.scala:
--------------------------------------------------------------------------------
package example

/**
 * Application entry point.
 *
 * NOTE(review): the body is empty, so running the application does not start a server; the
 * server is only started programmatically through `Initializer.start()` (as InitializerSpec does).
 */
object Main extends App {

}

--------------------------------------------------------------------------------

/src/main/scala/example/FutureUtils.scala:
--------------------------------------------------------------------------------
package example

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}

/** Helpers for blocking on futures at the edges of the program (tests, entry points). */
object FutureUtils {

  /** Upper bound used by [[awaitFuture]] when the caller does not supply one. */
  val DefaultTimeout: FiniteDuration = 5.seconds

  /**
   * Blocks the calling thread until `future` completes and returns its value.
   *
   * @param future  the future to wait for
   * @param timeout maximum time to wait; defaults to [[DefaultTimeout]] (5 seconds), which is the
   *                value every existing caller relied on
   * @return the successful result of `future`
   * @throws java.util.concurrent.TimeoutException if `future` does not complete within `timeout`
   * @throws Throwable the failure of `future`, rethrown as-is, when it fails
   */
  def awaitFuture[T](future: Future[T], timeout: Duration = DefaultTimeout): T =
    Await.result(future, timeout)

}

--------------------------------------------------------------------------------
/src/main/scala/example/Log.scala:
--------------------------------------------------------------------------------
package example

import org.slf4j.{Logger, LoggerFactory}

/** Small factory for SLF4J loggers. */
object Log {

  /** Returns the logger named after the runtime class of `T`. */
  def get[T](implicit tag: reflect.ClassTag[T]): Logger =
    LoggerFactory.getLogger(tag.runtimeClass.getName)

  /** Returns the logger with the given `name`. */
  def getByName(name: String): Logger =
    LoggerFactory.getLogger(name)

}

--------------------------------------------------------------------------------

/src/main/scala/example/CurrentThreadExecutionContext.scala:
--------------------------------------------------------------------------------
package example

import scala.concurrent.ExecutionContext

/**
 * `ExecutionContext` that runs every task synchronously on the calling thread.
 *
 * Reported failures are printed to stderr. MainHandlerSpec uses it so that future callbacks run
 * before `writeInbound` returns.
 */
object CurrentThreadExecutionContext extends ExecutionContext {
  override def execute(runnable: Runnable): Unit = runnable.run()

  override def reportFailure(cause: Throwable): Unit = {
    cause.printStackTrace()
  }
}

--------------------------------------------------------------------------------

/src/main/resources/logback.xml:
--------------------------------------------------------------------------------



%date{ISO8601} [%thread] %-5level %logger{36} - %msg%n







--------------------------------------------------------------------------------

/src/main/scala/example/ClientKey.scala:
--------------------------------------------------------------------------------
package example

import java.util.Date
import java.util.concurrent.atomic.AtomicLong

import io.netty.channel.ChannelHandlerContext

/** Hands out a process-wide, strictly increasing number for every [[ClientKey]] instance. */
private[example] object ClientKeySequence {
  private val counter = new AtomicLong()

  def next(): Long = counter.incrementAndGet()
}

/**
 * One parked polling client: the `path` it waits on, the moment it `expiration`s, and the channel
 * context used to answer it.
 *
 * Keys are ordered by `expiration`, earliest first, which is what lets `ClientsRegistry` collect
 * expired clients from the front of its sorted set. Ties are broken by creation order. Without the
 * tie-break, two keys expiring in the same millisecond compare as equal, and a `TreeSet` would
 * silently keep only one of them (the other would then never time out). As a consequence the
 * ordering is deliberately finer than `equals`: two distinct instances with identical fields compare
 * as non-equal.
 */
case class ClientKey(path: String, expiration: Date, ctx : ChannelHandlerContext)
  extends Comparable[ClientKey] {

  // Not a constructor field, so case-class equality, hashCode and unapply are unaffected.
  private val sequence: Long = ClientKeySequence.next()

  override def compareTo(o: ClientKey): Int = {
    val byExpiration = expiration.compareTo(o.expiration)
    if (byExpiration != 0) byExpiration
    else java.lang.Long.compare(sequence, o.sequence)
  }

  /** True once the current time is strictly after `expiration`. */
  def isExpired : Boolean = new Date().after(expiration)

}

--------------------------------------------------------------------------------

/src/main/scala/example/ExecutorServiceUtils.scala:
--------------------------------------------------------------------------------
package example

import java.util.concurrent.{ExecutorService, Executors, ThreadFactory}

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}

/** Thread factory whose threads are daemons, so pooled workers never keep the JVM alive. */
object DaemonThreadFactory extends ThreadFactory {

  // One shared delegate, so the pool's threads share a single "pool-N" name prefix instead of each
  // thread coming from a freshly created default factory.
  private val delegate = Executors.defaultThreadFactory()

  def newThread(r: Runnable): Thread = {
    val thread = delegate.newThread(r)
    thread.setDaemon(true)
    thread
  }

}

/** Shared cached thread pool (daemon threads) and an `ExecutionContext` backed by it. */
object ExecutorServiceUtils {

  implicit val CachedThreadPool: ExecutorService = Executors.newCachedThreadPool(DaemonThreadFactory)
  implicit val CachedExecutionContext: ExecutionContextExecutor = ExecutionContext.fromExecutor(CachedThreadPool)

}

--------------------------------------------------------------------------------

/src/test/scala/example/InitializerSpec.scala:
--------------------------------------------------------------------------------
package example

import java.util.concurrent.atomic.AtomicInteger

import org.specs2.mutable.Specification
import FutureUtils.awaitFuture

class InitializerSpec extends Specification {

  "initializer" >> {

    "binds and handles requests correctly" >> {
      withInitializer {
        initializer =>
          val getFuture = get(initializer)

          Thread.sleep(500)

          val postResult = awaitFuture(post(initializer))

          postResult.getStatusCode must_==(200)

          val getResult = awaitFuture(getFuture)

          getResult.getStatusCode must_==(200)
          getResult.getResponseBody must_==(sampleBody)
      }
    }

  }

  val path = "/some-path"
  val sampleBody = "sample-body"
  val portCounter = new AtomicInteger(5000)

  def get(initializer : Initializer) =
    Http.get(s"http://localhost:${initializer.port}${path}")

  def post(initializer : Initializer) =
    Http.post(s"http://localhost:${initializer.port}${path}", sampleBody, Map("Content-Type" -> "text/plain"))

  /** Starts a server on a fresh port in a background thread, runs `fn` against it, then stops it. */
  def withInitializer[R]( fn : Initializer => R ) : R = {
    val initializer = new Initializer(10, portCounter.incrementAndGet())(ExecutorServiceUtils.CachedExecutionContext)

    // start() blocks until the server channel closes, so it needs its own thread.
    val t = new Thread(new Runnable {
      override def run(): Unit = initializer.start()
    })

    t.start()

    Thread.sleep(500)

    try {
      fn(initializer)
    } finally {
      initializer.stop()
    }
  }

}

--------------------------------------------------------------------------------

/src/test/scala/example/Http.scala:
--------------------------------------------------------------------------------
package example

import com.ning.http.client.{AsyncCompletionHandler, AsyncHttpClient, Response}

import scala.concurrent.{Future, Promise}

class ErrorResponseException(val response: Response)
  extends IllegalStateException(s"HTTP request failed with ${response.getStatusCode} - ${response.getHeaders} - ${response.getResponseBody}")

/**
 * Minimal asynchronous HTTP client used by the tests.
 *
 * Responses with a status above 399 fail the returned future with [[ErrorResponseException]].
 */
object Http {

  private val log = Log.getByName(this.getClass.getCanonicalName)
  private val client = new AsyncHttpClient()

  private def perform(
    builder: AsyncHttpClient#BoundRequestBuilder,
    body: String,
    headers: Map[String, String] = Map.empty): Future[Response] = {

    headers.foreach {
      case (key, value) =>
        builder.addHeader(key, value)
    }

    builder.setBody(body)

    val promise = Promise[Response]()

    builder.execute(new AsyncCompletionHandler[Response]() {

      override def onCompleted(response: Response): Response = {
        log.info(s"Request finished with ${response.getStatusCode}")

        if (response.getStatusCode() > 399) {
          log.info(s"response was ${response.getStatusCode}\n${response.getResponseBody}")
          promise.tryFailure(new ErrorResponseException(response))
        } else {
          promise.trySuccess(response)
        }

        response
      }

      // Defensive: should the client ever report a failure after onCompleted has already settled
      // the promise, failure(t) would throw IllegalStateException inside the client's callback.
      override def onThrowable(t: Throwable): Unit =
        promise.tryFailure(t)

    })

    promise.future

  }

  def post(url: String, body: String, headers: Map[String, String] = Map.empty): Future[Response] =
    perform(client.preparePost(url), body, headers)

  def get(url: String, headers: Map[String, String] = Map.empty): Future[Response] =
    perform(client.prepareGet(url), "", headers)

}
--------------------------------------------------------------------------------

/src/main/scala/example/SetHeadersHandler.scala:
--------------------------------------------------------------------------------
package example

import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.{ChannelDuplexHandler, ChannelHandlerContext, ChannelPromise}
import io.netty.handler.codec.http.{FullHttpRequest, FullHttpResponse, HttpHeaders, HttpVersion}
import io.netty.util.AttributeKey

object SetHeadersHandler {

  /** Value written to the `Server` header of every full response. */
  val DefaultServerName: String = "long-polling-server-example"

  /** Channel attribute holding the `Connection` value derived from the last request read. */
  val ConnectionAttribute: AttributeKey[String] =
    AttributeKey.valueOf[String](s"${SetHeadersHandler.getClass.getName}.connection")

  /** Channel attribute holding the HTTP version of the last request read. */
  val HttpVersionAttribute: AttributeKey[HttpVersion] =
    AttributeKey.valueOf[HttpVersion](s"${SetHeadersHandler.getClass.getName}.version")

}

/**
 * Records each incoming request's keep-alive decision and HTTP version on the channel, and copies
 * them onto outgoing full responses together with the `Server` and `Content-Length` headers.
 */
@Sharable
class SetHeadersHandler extends ChannelDuplexHandler {

  import SetHeadersHandler._

  override def channelRead(ctx: ChannelHandlerContext, msg: scala.Any): Unit = {
    msg match {
      case request: FullHttpRequest =>
        val connection =
          if (HttpHeaders.isKeepAlive(request)) HttpHeaders.Values.KEEP_ALIVE
          else HttpHeaders.Values.CLOSE
        ctx.channel().attr(ConnectionAttribute).set(connection)
        ctx.channel().attr(HttpVersionAttribute).set(request.getProtocolVersion)
      case _ =>
    }

    super.channelRead(ctx, msg)
  }

  override def write(ctx: ChannelHandlerContext, msg: scala.Any, promise: ChannelPromise): Unit = {

    msg match {
      case response: FullHttpResponse =>
        // Both attributes exist only after a request has been read on this channel. A response
        // written before that must not fail with a NullPointerException, so the request-derived
        // values are applied only when present.
        Option(ctx.channel().attr(HttpVersionAttribute).get()).foreach {
          version => response.setProtocolVersion(version)
        }
        response.headers().set(HttpHeaders.Names.SERVER, DefaultServerName)
        Option(ctx.channel().attr(ConnectionAttribute).get()).foreach {
          connection => response.headers().set(HttpHeaders.Names.CONNECTION, connection)
        }
        response.headers().set(HttpHeaders.Names.CONTENT_LENGTH, response.content().readableBytes())
      case _ =>
    }

    super.write(ctx, msg, promise)
  }
}
--------------------------------------------------------------------------------

/src/main/scala/example/ClientsRegistry.scala:
--------------------------------------------------------------------------------
package example

import java.util.concurrent.locks.ReentrantLock
import java.util.{Calendar, Date}

import io.netty.channel.ChannelHandlerContext

import scala.collection.mutable.ListBuffer
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Try

/**
 * Registry of polling clients waiting on a path.
 *
 * Clients are indexed twice. The first index is by path, used to answer them when something is
 * posted. The second is a set ordered by [[ClientKey]]'s `compareTo` (expiration first), so that
 * expired clients can be collected from its front. Every operation runs as a task on the supplied
 * `ExecutionContext` while holding a single lock, and reports its result through the returned
 * `Future`.
 *
 * @param timeoutInSeconds how long a registered client may wait before it counts as expired
 */
class ClientsRegistry(timeoutInSeconds: Int) {

  private val lock = new ReentrantLock()
  private val pathsToClients = scala.collection.mutable.Map[String, ListBuffer[ClientKey]]()
  private val orderedClients = scala.collection.mutable.TreeSet[ClientKey]()

  /** Registers `ctx` as waiting on `path` and returns the key created for it. */
  def registerClient(path: String, ctx: ChannelHandlerContext)(implicit executor: ExecutionContext): Future[ClientKey] =
    withLock {
      val client = ClientKey(path, calculateTimeout(), ctx)

      val clients = pathsToClients.getOrElseUpdate(path, ListBuffer[ClientKey]())
      clients += client
      orderedClients += client

      client
    }

  /** Removes and returns every client waiting on `path` (empty if there were none). */
  def complete(path: String)(implicit executor: ExecutionContext): Future[Iterable[ClientKey]] =
    withLock {
      pathsToClients.remove(path).map {
        clients =>
          orderedClients --= clients
          clients
      }.getOrElse(Iterable.empty)
    }

  /** Removes and returns every registered client whose expiration has passed. */
  def collectTimeouts()(implicit executor: ExecutionContext): Future[Iterable[ClientKey]] =
    withLock {
      // orderedClients iterates in ascending expiration order, so the expired clients form a
      // prefix. toList materialises that prefix before the set is modified.
      val timeouts = orderedClients.iterator.takeWhile(_.isExpired).toList

      orderedClients --= timeouts

      timeouts.foreach {
        timeout =>
          pathsToClients.get(timeout.path).foreach {
            clients =>
              clients -= timeout
              // Drop the path once nobody waits on it. Otherwise a path whose clients only ever
              // time out would keep an empty buffer in the map forever.
              if (clients.isEmpty) pathsToClients -= timeout.path
          }
      }

      timeouts
    }

  /** Expiration for a client registered now: the current time plus `timeoutInSeconds`. */
  def calculateTimeout(): Date = {
    val calendar = Calendar.getInstance
    calendar.add(Calendar.SECOND, timeoutInSeconds)

    calendar.getTime
  }

  /**
   * Runs `fn` on `executor` while holding the registry lock.
   *
   * A non-fatal exception thrown by `fn` fails the returned future. Fatal errors propagate to the
   * executor, as with `Future.apply`.
   */
  private def withLock[R](fn: => R)(implicit executor: ExecutionContext): Future[R] = {
    val promise = Promise[R]()

    executor.execute(new Runnable {
      override def run(): Unit = {
        lock.lock()
        try {
          promise.complete(Try(fn))
        } finally {
          lock.unlock()
        }
      }
    })

    promise.future
  }

}

--------------------------------------------------------------------------------

/src/test/scala/example/SetHeadersHandlerSpec.scala:
--------------------------------------------------------------------------------
package example

import io.netty.buffer.Unpooled
import io.netty.channel.embedded.EmbeddedChannel
import io.netty.handler.codec.http._
import io.netty.util.CharsetUtil
import org.specs2.mutable.Specification

class SetHeadersHandlerSpec extends Specification {

  "handler" >> {

    "sets the necessary headers based in the incoming message" >> {
      val handler = new SetHeadersHandler
      val channel = new EmbeddedChannel(handler)

      val request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/some-path")
      request
        .headers()
        .set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.KEEP_ALIVE)

      val content = "some content".getBytes(CharsetUtil.UTF_8)
      val response = new DefaultFullHttpResponse(
        HttpVersion.HTTP_1_0,
        HttpResponseStatus.OK,
        Unpooled.wrappedBuffer(content))

      channel.writeInbound(request)
      channel.attr(SetHeadersHandler.ConnectionAttribute).get() must_== (HttpHeaders.Values.KEEP_ALIVE)

      channel.writeOutbound(response)

      response.getProtocolVersion must_== (HttpVersion.HTTP_1_1)
      response.headers().get(HttpHeaders.Names.CONNECTION) must_== (HttpHeaders.Values.KEEP_ALIVE)
      response.headers().get(HttpHeaders.Names.CONTENT_LENGTH) must_== (content.length.toString)
      response.headers().get(HttpHeaders.Names.SERVER) must_== (SetHeadersHandler.DefaultServerName)
    }

    "sets the content length to 0 if there was no content" >> {
      val handler = new SetHeadersHandler
      val channel = new EmbeddedChannel(handler)

      val request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/some-path")
      request
        .headers()
        .set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.CLOSE)

      val response = new DefaultFullHttpResponse(
        HttpVersion.HTTP_1_1,
        HttpResponseStatus.OK)

      channel.writeInbound(request)
      channel.attr(SetHeadersHandler.ConnectionAttribute).get() must_== (HttpHeaders.Values.CLOSE)

      channel.writeOutbound(response)

      response.getProtocolVersion must_== (HttpVersion.HTTP_1_1)
      response.headers().get(HttpHeaders.Names.CONNECTION) must_== (HttpHeaders.Values.CLOSE)
      response.headers().get(HttpHeaders.Names.CONTENT_LENGTH) must_== "0"
      response.headers().get(HttpHeaders.Names.SERVER) must_== (SetHeadersHandler.DefaultServerName)
    }

    "keeps the response version and still sets the static headers when no request was read" >> {
      val channel = new EmbeddedChannel(new SetHeadersHandler)

      val response = new DefaultFullHttpResponse(
        HttpVersion.HTTP_1_0,
        HttpResponseStatus.OK)

      channel.writeOutbound(response)

      response.getProtocolVersion must_== (HttpVersion.HTTP_1_0)
      response.headers().get(HttpHeaders.Names.CONNECTION) must beNull
      response.headers().get(HttpHeaders.Names.CONTENT_LENGTH) must_== "0"
      response.headers().get(HttpHeaders.Names.SERVER) must_== (SetHeadersHandler.DefaultServerName)
    }

  }

}

--------------------------------------------------------------------------------

/src/test/scala/example/ClientsRegistrySpec.scala:
--------------------------------------------------------------------------------
package example

import java.util.{Date, Calendar}

import io.netty.channel.ChannelHandlerContext
import org.specs2.mock.Mockito
import org.specs2.mutable.Specification
import FutureUtils.awaitFuture
import ExecutorServiceUtils.CachedExecutionContext

class ClientsRegistrySpec extends Specification with Mockito {

  val path = "sample-path"

  "registry" >> {

    "registers a client and produces it when completing" >> {
      val context = mock[ChannelHandlerContext]
      val registry = new ClientsRegistry(1)

      val client = awaitFuture(registry.registerClient(path, context))

      val result = awaitFuture(registry.complete(path)).toList

      result.length must_==(1)
      result(0) must_==(client)

      awaitFuture(registry.complete(path)).toList.length must_==(0)
    }

    "returns an empty collection if no clients were there to be collected" >> {
      val registry = new ClientsRegistry(1)
      awaitFuture(registry.complete(path)) must beEmpty
    }

    "removes from timeouts once completed" >> {
      val context = mock[ChannelHandlerContext]
      val registry = new ClientsRegistry(0)

      val client = awaitFuture(registry.registerClient(path, context))
      val result = awaitFuture(registry.complete(path)).toList

      result.length must_==(1)
      result(0) must_==(client)

      Thread.sleep(500)

      awaitFuture(registry.collectTimeouts()) must beEmpty
    }

    "timeouts clients and removes them from the collection" >> {
      var date = new Date()

      val registry = new ClientsRegistry(0) {
        override def calculateTimeout(): Date = date
      }

      val timeoutedChannel = mock[ChannelHandlerContext]
      val timeoutedClient = awaitFuture(registry.registerClient(path, timeoutedChannel))

      val futureChannel = mock[ChannelHandlerContext]
      val futureTime = Calendar.getInstance
      futureTime.add(Calendar.DATE, 1)

      date = futureTime.getTime

      val futureClient = awaitFuture(registry.registerClient(path, futureChannel))

      Thread.sleep(500)

      val timeouts = awaitFuture(registry.collectTimeouts()).toList

      timeouts.length must_==(1)
      timeouts(0) must_==(timeoutedClient)

      val result = awaitFuture(registry.complete(path)).toList

      result.length must_==(1)
      result(0) must_==(futureClient)
    }

  }

}

--------------------------------------------------------------------------------

/src/main/scala/example/Initializer.scala:
--------------------------------------------------------------------------------
package example

import java.util.concurrent.TimeUnit

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.{ChannelFuture, Channel, ChannelOption, ChannelInitializer}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.handler.codec.http.{HttpObjectAggregator, HttpServerCodec}

import scala.concurrent.ExecutionContext

object Initializer {

  val log = Log.get[Initializer]

}

/**
 * Netty `ChannelInitializer` that also owns the server bootstrap for the long-polling example.
 *
 * Each accepted connection gets an HTTP codec, an aggregator (so handlers see `FullHttpRequest`s),
 * the shared [[SetHeadersHandler]] and the shared [[MainHandler]].
 *
 * @param timeoutInSeconds how long a polling client may wait before it receives a timeout error;
 *                         also the period of the expiry sweep (clamped to at least one second)
 * @param port             TCP port to bind
 * @param executor         context on which registry operations run
 */
class Initializer (timeoutInSeconds : Int, val port: Int) (implicit executor: ExecutionContext)
  extends ChannelInitializer[SocketChannel] {

  import Initializer.log

  private val bossGroup = new NioEventLoopGroup(1)
  private val workerGroup = new NioEventLoopGroup()

  private val serverBootstrap = new ServerBootstrap()
  serverBootstrap.option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(1024))
  serverBootstrap.group(bossGroup, workerGroup)
    .channel(classOf[NioServerSocketChannel])
    .childHandler(this)

  // Written by the thread running start() and read by whichever thread calls stop().
  @volatile private var serverChannel: Channel = null
  private val setHeadersHandler = new SetHeadersHandler
  private val mainHandler = new MainHandler(new ClientsRegistry(timeoutInSeconds))

  override def initChannel(ch: SocketChannel): Unit = {
    val p = ch.pipeline()

    p.addLast("http-codec", new HttpServerCodec())
    // NOTE(review): Int.MaxValue lets a single request buffer an almost unbounded body in memory.
    p.addLast("aggregator", new HttpObjectAggregator(Int.MaxValue))
    p.addLast("set-headers-handler", setHeadersHandler)
    p.addLast("handler", mainHandler)
  }

  /**
   * Binds `port`, schedules the periodic expiry sweep, and blocks until the server channel closes.
   *
   * The event-loop groups are always shut down on exit. Failures are logged, not rethrown.
   */
  def start(): Unit = {
    try {
      val channel = serverBootstrap.bind(port).sync().channel()
      serverChannel = channel

      // scheduleAtFixedRate rejects a non-positive period, so a zero (or negative) timeout would
      // otherwise make start() fail and shut the server down immediately.
      val sweepIntervalInSeconds = math.max(1, timeoutInSeconds).toLong
      channel.eventLoop().scheduleAtFixedRate(new Runnable {
        override def run(): Unit =
          mainHandler.evaluateTimeouts()
      },
        sweepIntervalInSeconds,
        sweepIntervalInSeconds,
        TimeUnit.SECONDS
      )

      log.info(s"Starting server ${channel}")
      channel.closeFuture().sync()
    } catch {
      case e: Exception =>
        log.error(s"Server channel failed with ${e.getMessage}", e)
    }
    finally {
      bossGroup.shutdownGracefully()
      workerGroup.shutdownGracefully()
    }
  }

  /**
   * Closes the server channel and waits for the close to finish, which makes `start()` return.
   *
   * @throws IllegalStateException if the server channel was never bound
   */
  def stop(): ChannelFuture = {
    val channel = serverChannel
    if (channel == null) {
      throw new IllegalStateException("server has not been started")
    }
    log.info(s"Stopping server ${channel}")
    val channelFuture = channel.close().awaitUninterruptibly()
    log.info(s"Closed server channel ${channel}")
    channelFuture
  }

}

--------------------------------------------------------------------------------

/src/main/scala/example/MainHandler.scala:
--------------------------------------------------------------------------------
package example

import java.util.concurrent.TimeoutException

import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http._
import io.netty.util.{CharsetUtil, ReferenceCountUtil}

import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success}

object MainHandler {

  val log = Log.get[MainHandler]

}

/**
 * Long-polling request handler.
 *
 *  - `GET <uri>` parks the caller in the [[ClientsRegistry]] until someone posts to the same URI
 *    or the registration times out. If registration fails, the caller gets an error response.
 *  - `POST <uri>` answers every parked `GET` on that URI with a copy of the posted body (and its
 *    `Content-Type`), then answers the poster with an empty `200 OK`.
 *  - Any other method gets `404 Not Found`.
 *
 * Registry calls complete asynchronously on `executor`; responses are written from their callbacks.
 */
@Sharable
class MainHandler( registry : ClientsRegistry )(implicit executor: ExecutionContext)
  extends SimpleChannelInboundHandler[FullHttpRequest] {

  import MainHandler.log

  override def channelRead0(ctx: ChannelHandlerContext, msg: FullHttpRequest): Unit = {

    log.info(s"Received request ${msg}")

    msg.getMethod match {
      case HttpMethod.GET =>
        registry.registerClient(msg.getUri, ctx).onFailure {
          case e => writeError(ctx, e)
        }
      case HttpMethod.POST =>
        // SimpleChannelInboundHandler releases msg when channelRead0 returns. The body is read
        // later, in the completion callback, so hold an extra reference until then.
        ReferenceCountUtil.retain(msg)
        registry.complete(msg.getUri).onComplete {
          result =>
            try {
              result match {
                case Success(clients) =>
                  clients.foreach {
                    client =>
                      client.ctx.writeAndFlush(buildResponse(msg))
                  }
                  ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))
                case Failure(e) =>
                  writeError(ctx, e)
              }
            } finally {
              ReferenceCountUtil.release(msg)
            }
        }
      case _ =>
        ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND))
    }
  }

  /** Answers every expired registration with a 500 response carrying a timeout message. */
  def evaluateTimeouts(): Unit = {
    registry.collectTimeouts().onComplete {
      case Success(clients) =>
        clients.foreach {
          client =>
            writeError(client.ctx, new TimeoutException("channel timeouted without a response"))
        }
      case Failure(e) =>
        // Without this branch a failed sweep would disappear without a trace.
        log.error(s"Failed to collect timed out clients: ${e.getMessage}", e)
    }
  }

  /** Writes a `500 Internal Server Error` whose plain-text body describes `e`. */
  def writeError(ctx : ChannelHandlerContext, e : Throwable): Unit = {
    // getMessage is null for many exceptions. Calling getBytes on it would throw inside the future
    // callback, and the client would never receive a response.
    val message = Option(e.getMessage).getOrElse(e.getClass.getName)

    val response = new DefaultFullHttpResponse(
      HttpVersion.HTTP_1_1,
      HttpResponseStatus.INTERNAL_SERVER_ERROR,
      Unpooled.wrappedBuffer(message.getBytes(CharsetUtil.UTF_8))
    )

    response.headers().add(HttpHeaders.Names.CONTENT_TYPE, "text/plain")

    ctx.writeAndFlush(response)
  }

  /**
   * Builds a `200 OK` response from a copy of `request`'s body (a copy, so the request can be
   * released independently), forwarding its `Content-Type` if present.
   */
  def buildResponse( request : FullHttpRequest ) : FullHttpResponse = {
    val response = new DefaultFullHttpResponse(
      HttpVersion.HTTP_1_1,
      HttpResponseStatus.OK,
      Unpooled.copiedBuffer(request.content())
    )

    if ( request.headers().contains(HttpHeaders.Names.CONTENT_TYPE) ) {
      response.headers().add(HttpHeaders.Names.CONTENT_TYPE, request.headers().get(HttpHeaders.Names.CONTENT_TYPE))
    }

    response
  }

}

--------------------------------------------------------------------------------

/src/test/scala/example/MainHandlerSpec.scala:
--------------------------------------------------------------------------------
package example

import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.embedded.EmbeddedChannel
import io.netty.handler.codec.http._
import io.netty.util.CharsetUtil
import org.specs2.mutable.Specification

import scala.concurrent.{Future, ExecutionContext}

class MainHandlerSpec extends Specification {

  val path = "/some-path"
  val contents = "some-contents"
  val contentBytes = contents.getBytes(CharsetUtil.UTF_8)

  "handler" >> {

    "registers the client and sends the response back" >> {
      val registry = new ClientsRegistry(1)
      val handler = new MainHandler(registry)(CurrentThreadExecutionContext)
      val requesterChannel = new EmbeddedChannel(handler)
      val notifierChannel = new EmbeddedChannel(handler)

      val pollRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path)
      requesterChannel.writeInbound(pollRequest)

      val notificationRequest = new DefaultFullHttpRequest(
        HttpVersion.HTTP_1_1, HttpMethod.POST, path, Unpooled.wrappedBuffer(contentBytes))
      notifierChannel.writeInbound(notificationRequest)

      val pollResponse = requesterChannel.readOutbound().asInstanceOf[FullHttpResponse]

      pollResponse.getStatus must_==(HttpResponseStatus.OK)
      pollResponse.content().toString(CharsetUtil.UTF_8) must_==(contents)

      val notificationResponse = notifierChannel.readOutbound().asInstanceOf[FullHttpResponse]
      notificationResponse.getStatus must_==(HttpResponseStatus.OK)
    }

    "timeouts clients and sends them a response" >> {
      val registry = new ClientsRegistry(0)
      val handler = new MainHandler(registry)(CurrentThreadExecutionContext)
      val channel = new EmbeddedChannel(handler)

      val request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path)
      channel.writeInbound(request)

      Thread.sleep(500)

      handler.evaluateTimeouts()

      val response = channel.readOutbound().asInstanceOf[FullHttpResponse]

      response.getStatus must_==(HttpResponseStatus.INTERNAL_SERVER_ERROR)
      response.content().toString(CharsetUtil.UTF_8) must_==("channel timeouted without a response")
    }

    "returns a 404 for other requests" >> {
      val registry = new ClientsRegistry(1)
      val handler = new MainHandler(registry)(CurrentThreadExecutionContext)
      val channel = new EmbeddedChannel(handler)

      val request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.DELETE, path)
      channel.writeInbound(request)

      val response = channel.readOutbound().asInstanceOf[FullHttpResponse]

      response.getStatus must_==(HttpResponseStatus.NOT_FOUND)
    }

    "returns an error if it can't register the client" >> {
      val exception = new IllegalStateException("can't register clients right now, sorry")
      val registry = new ClientsRegistry(1) {
        override def registerClient(path: String, ctx: ChannelHandlerContext)(implicit executor: ExecutionContext):
        Future[ClientKey] = Future.failed(exception)
      }

      val handler = new MainHandler(registry)(CurrentThreadExecutionContext)
      val channel = new EmbeddedChannel(handler)

      val request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path)
      channel.writeInbound(request)

      val response = channel.readOutbound().asInstanceOf[FullHttpResponse]
      response.getStatus must_==(HttpResponseStatus.INTERNAL_SERVER_ERROR)
      response.content().toString(CharsetUtil.UTF_8) must_==(exception.getMessage)
    }

    "returns an error naming the exception class if the failure has no message" >> {
      val registry = new ClientsRegistry(1) {
        override def registerClient(path: String, ctx: ChannelHandlerContext)(implicit executor: ExecutionContext):
        Future[ClientKey] = Future.failed(new IllegalStateException())
      }

      val handler = new MainHandler(registry)(CurrentThreadExecutionContext)
      val channel = new EmbeddedChannel(handler)

      val request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path)
      channel.writeInbound(request)

      val response = channel.readOutbound().asInstanceOf[FullHttpResponse]
      response.getStatus must_==(HttpResponseStatus.INTERNAL_SERVER_ERROR)
      response.content().toString(CharsetUtil.UTF_8) must_==(classOf[IllegalStateException].getName)
    }

    "returns an error if it can't notify clients" >> {
      val exception = new IllegalStateException("can't notify clients right now, sorry")
      val registry = new ClientsRegistry(1) {
        override def complete(path: String)(implicit executor: ExecutionContext):
        Future[Iterable[ClientKey]] = Future.failed(exception)
      }

      val handler = new MainHandler(registry)(CurrentThreadExecutionContext)
      val channel = new EmbeddedChannel(handler)

      val request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, path)
      channel.writeInbound(request)

      val response = channel.readOutbound().asInstanceOf[FullHttpResponse]
      response.getStatus must_==(HttpResponseStatus.INTERNAL_SERVER_ERROR)
      response.content().toString(CharsetUtil.UTF_8) must_==(exception.getMessage)
    }

  }


}

--------------------------------------------------------------------------------