├── .gitignore ├── .gitmodules ├── .sbtrc ├── LICENSE ├── README ├── core └── src │ ├── main │ └── scala │ │ └── net │ │ └── danielkza │ │ └── http2 │ │ ├── Coder.scala │ │ ├── Http2.scala │ │ ├── api │ │ └── Header.scala │ │ ├── hpack │ │ ├── DynamicTable.scala │ │ ├── HeaderError.scala │ │ ├── HeaderRepr.scala │ │ ├── StaticTable.scala │ │ ├── Table.scala │ │ └── coders │ │ │ ├── BytesCoder.scala │ │ │ ├── CompressedBytesCoder.scala │ │ │ ├── HeaderBlockCoder.scala │ │ │ ├── HeaderCoder.scala │ │ │ ├── HuffmanCoding.scala │ │ │ ├── IntCoder.scala │ │ │ └── LiteralCoder.scala │ │ ├── model │ │ ├── AkkaMessageAdapter.scala │ │ ├── Http2Response.scala │ │ └── headers │ │ │ └── Trailer.scala │ │ ├── protocol │ │ ├── Frame.scala │ │ ├── HTTP2Error.scala │ │ ├── Http2Stream.scala │ │ ├── Setting.scala │ │ ├── StreamManager.scala │ │ └── coders │ │ │ ├── FrameCoder.scala │ │ │ └── IntCoder.scala │ │ ├── ssl │ │ ├── ALPNSSLContext.scala │ │ └── WrappedSSLContext.scala │ │ ├── stream │ │ ├── ChunkedDataDecodeStage.scala │ │ ├── FrameDecoderStage.scala │ │ ├── FrameEncoderStage.scala │ │ ├── HeaderCollapseStage.scala │ │ ├── HeaderDecodeActor.scala │ │ ├── HeaderEncodeActor.scala │ │ ├── HeaderSplitStage.scala │ │ ├── HeaderTransformActorBase.scala │ │ ├── Http2Message.scala │ │ ├── InboundStreamDispatcher.scala │ │ ├── NormalDataDecodeStage.scala │ │ ├── ServerConnectionBlueprint.scala │ │ ├── StreamManagerActor.scala │ │ └── package.scala │ │ └── util │ │ ├── ArrayQueue.scala │ │ ├── Implicits.scala │ │ ├── package.scala │ │ └── stream │ │ ├── Concentrator.scala │ │ └── ConcentratorShape.scala │ └── test │ └── scala │ └── net │ └── danielkza │ └── http2 │ ├── AkkaStreamsTest.scala │ ├── AkkaTest.scala │ ├── TestHelpers.scala │ ├── hpack │ ├── DynamicTableTest.scala │ ├── HeaderBlockCoderTest.scala │ └── coders │ │ ├── BytesCoderTest.scala │ │ ├── CompressedBytesCoderTest.scala │ │ ├── HeaderCoderTest.scala │ │ └── IntCoderTest.scala │ ├── protocol │ └── coders │ │ └── 
FrameCoderTest.scala │ └── stream │ ├── FrameDecoderStageTest.scala │ ├── FrameEncoderStageTest.scala │ ├── HeaderCollapseStageTest.scala │ ├── HeaderSplitStageTest.scala │ └── HeaderStageTest.scala ├── examples └── src │ └── main │ └── scala │ └── net │ └── danielkza │ └── http2 │ └── examples │ └── ServerExample.scala ├── macros └── src │ └── main │ └── scala │ └── net │ └── danielkza │ └── http2 │ └── macros │ └── BitPatterns.scala └── project ├── Build.scala └── build.properties /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | *.class 4 | target/ 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "http2-frame-test-case"] 2 | path = http2-frame-test-case 3 | url = git@github.com:http2jp/http2-frame-test-case.git 4 | -------------------------------------------------------------------------------- /.sbtrc: -------------------------------------------------------------------------------- 1 | alias boot = ;reload ;project core ;iflast shell 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. 
For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 
49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. 
In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "{}" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. 
We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright {yyyy} {name of copyright owner} 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | 205 | 206 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | HTTP/2 Server implementation in Scala and Akka Streams 2 | 3 | This implements the basics of an HTTP/2 server using Akka Streams for all the data flow. 4 | 5 | All the build is handled by SBT, and using it should be enough to compile and start testing. 6 | 7 | Any actual programs that run the server must start with the Jetty ALPN JAR in its bootclasspath. 8 | Since there are no supported plain-text protocols, this will cause issues immediately if not taken care of. 9 | 10 | Also, it's necessary to have a keystore set-up, even if with a temporary replacement. Tutorials on how to achieve that 11 | are available in many places. 12 | 13 | One simple example application server is included that just serves files in the local directory, which can be run with 14 | `sbt examples/run` 15 | 16 | The default application instance can also be debugged using the :debug configuration.
// core/src/main/scala/net/danielkza/http2/Coder.scala
package net.danielkza.http2

import scalaz.{StateT, \/, -\/, \/-}
import shapeless._
import shapeless.ops.hlist._
import akka.util.{ByteString, ByteStringBuilder}

/** Base class for binary coders: objects that can encode a value of type `T` into a
  * `ByteStringBuilder` and decode one back from a `ByteString`, reporting failures
  * through scalaz's `\/` (disjunction) rather than exceptions.
  *
  * Encoding/decoding is also exposed in `StateT` form (`encodeS` / `decodeS`) so that
  * multiple coders can be sequenced monadically over a shared input/output stream.
  */
abstract class Coder[T] {
  final type Subject = T
  // Each concrete coder fixes its own error type (e.g. HeaderError, HTTP2Error).
  type Error

  // StateT over the error disjunction: a stateful computation of TT over state S that
  // may fail with E.
  protected final type StateTES[TT, E, S] = StateT[\/[E, ?], S, TT]

  final type EncodeStateE[E] = StateTES[Unit, E, ByteStringBuilder]
  final type EncodeState = EncodeStateE[Error]

  final type DecodeStateTE[TT, E] = StateTES[TT, E, ByteString]
  final type DecodeStateT[TT] = DecodeStateTE[TT, Error]
  final type DecodeState = DecodeStateT[T]

  protected implicit def stateMonad[S] = StateT.stateTMonadState[S, \/[Error, ?]]

  /** Encodes `value` into the given mutable builder; returns Unit on success or an Error. */
  def encode(value: T, stream: ByteStringBuilder): \/[Error, Unit]

  /** Convenience overload: encodes into a fresh builder and returns the resulting bytes. */
  def encode(value: T): \/[Error, ByteString] = {
    val builder = ByteString.newBuilder
    encode(value, builder).map { _ =>
      val res = builder.result()
      res
    }
  }

  /** Stateful form of `encode`: threads the builder through as the StateT state.
    * Note the builder is mutated in place; the "new" state is the same builder instance. */
  def encodeS(value: T): EncodeState = StateT[\/[Error, ?], ByteStringBuilder, Unit] { in =>
    encode(value, in).map(_ => (in, ()))
  }

  /** Decodes a value from the front of `bs`, returning the value and the number of
    * bytes consumed. */
  def decode(bs: ByteString): \/[Error, (T, Int)]

  /** Stateful form of `decode`: consumes the decoded prefix from the ByteString state. */
  def decodeS: DecodeStateT[T] = StateT[\/[Error, ?], ByteString, T] { in =>
    decode(in).map { case (value, bytesRead) => in.drop(bytesRead) -> value }
  }

  /** Takes (and consumes) `length` bytes from the stream state. If fewer bytes are
    * available, `splitAt` yields a shorter result — callers must validate length. */
  final def takeS(length: Int): DecodeStateT[ByteString] = StateT[\/[Error, ?], ByteString, ByteString] { in =>
    val (left, right) = in.splitAt(length)
    \/-(right -> left)
  }

  /** Fails the stateful computation with `error` unless `cond` holds. */
  final def ensureS[S](error: => Error)(cond: => Boolean): StateTES[Unit, Error, S] = {
    val SM = stateMonad[S]
    if(!cond) failS(error)
    else SM.pure(())
  }

  /** A stateful computation that always fails with `error`. */
  final def failS[TT, S](error: Error): StateTES[TT, Error, S] = StateT[\/[Error, ?], S, TT] { in =>
    -\/(error)
  }
}

object Coder {
  final type Aux[T, E] = Coder[T] {type Error = E}

  // Trait that witnesses that an HList is a mapping of instances of Coder[_] another one, all having the same Error
  // type
  trait CodersOf[C <: HList, L <: HList] {
    type Error
  }

  object CodersOf {
    type Aux[C <: HC :: HList, L <: HL :: HList, E, HC <: Coder.Aux[HL, E], HL] = CodersOf[C, L] {type Error = E}

    //implicit def hnilCodersOf = new CodersOf[HNil, HNil] {type Error = Nothing}
    implicit def hlist1CodersOf[HC <: Coder[HL], HL] =
      new CodersOf[HC :: HNil, HL :: HNil] {type Error = HC#Error}
    implicit def hlistCodersOf[HC <: Coder[HL], C <: HList, HL, L <: HList]
      (implicit co: CodersOf[C, L] {type Error = HC#Error}) =
    {
      new CodersOf[HC :: C, HL :: L] {type Error = HC#Error}
    }
  }

  /** Polymorphic fold function used to build a composite decoder: folds a coder HList
    * into a single StateT decoder that accumulates decoded values into an HList. */
  object dec extends Poly2 {
    final type R[T, E] = Coder[_]#DecodeStateTE[T, E]

    // Base case: decoding the first element of the product.
    implicit def first
      [Value, Err, CurCoder]
      (implicit ev: CurCoder <:< Coder.Aux[Value, Err]) =
      at[R[HNil, Err], CurCoder]
      { (in, coder) =>
        for {
          value <- coder.decodeS
        } yield value :: HNil
      }

    // Inductive case: run the accumulated decoder, then the current coder, and append.
    implicit def default
      [Value, Err, CurCoder, Values <: HList]
      (implicit ev: CurCoder <:< Coder.Aux[Value, Err], cons: IsHCons[Values], p: Prepend[Values, Value :: HNil]) =
      at[R[Values, Err], CurCoder]
      { (in, coder) =>
        for {
          prev <- in
          cur <- coder.decodeS
        } yield p(prev, cur :: HNil)
      }
  }

  /** Polymorphic fold function used to build a composite encoder: folds (value, coder)
    * pairs into a shared builder, short-circuiting on the first error. */
  object enc extends Poly2 {
    final type R[E] = (ByteStringBuilder, \/[E, Unit])

    implicit def default[Value, Err, CurCoder <: Coder.Aux[Value, Err]] =
      at[R[Err], (Value, CurCoder)]
      { (in, v) =>
        val (stream, prev) = in
        val (value, coder) = v
        // `prev` carries the result so far; a previous failure skips this encode.
        val res = for {
          _ <- prev
          _ <- coder.encode(value, stream)
        } yield ()
        stream -> res
      }
  }

  /** Derives a Coder for a product type `Prod` from a Generic representation and an
    * HList of coders for each field, all sharing the same error type `Err`.
    * Fields are encoded/decoded in declaration order. */
  def defineCompositeCoder
    [Prod, Err, Z <: HList, Values <: HList, Coders <: HList]
    (gen: Generic.Aux[Prod, Values])(coders: Coders)(implicit
      c: CodersOf[Coders, Values] {type Error = Err},
      zipVC: Zip.Aux[Values :: Coders :: HNil, Z],
      encodeF: LeftFolder.Aux[Z, enc.R[Err], enc.type, enc.R[Err]],
      decodeF: LeftFolder.Aux[Coders, dec.R[HNil, Err], dec.type, dec.R[Values, Err]]
    ) =
  {
    new Coder[Prod] {
      override type Error = Err

      override def encode(value: Prod, stream: ByteStringBuilder): \/[Error, Unit] = {
        val elems = gen.to(value)
        val zipped = elems.zip(coders)(zipVC)
        val init: enc.R[Error] = stream -> \/-(())
        zipped.foldLeft(init)(enc)(encodeF)._2
      }

      // Built once: the folded decoder for the whole product.
      private val decoder = {
        val init = StateT.stateT[\/[Error, ?], ByteString, HNil](HNil)
        coders.foldLeft(init)(dec)(decodeF)
      }

      override final def decode(bs: ByteString): \/[Error, (Prod, Int)] = {
        // Bytes consumed = original length minus what remains after running the decoder.
        decoder.run(bs).map { case (rem, result) => (gen.from(result), bs.length - rem.length) }
      }

      override final def decodeS: DecodeStateT[Prod] =
        decoder.map(gen.from)
    }
  }
}
// core/src/main/scala/net/danielkza/http2/Http2.scala
package net.danielkza.http2

import java.net.InetSocketAddress
import scala.collection.immutable
import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future, Await}
import scala.concurrent.duration._
import com.typesafe.config.{Config, ConfigFactory}
import akka.actor.{ExtendedActorSystem, ExtensionId, ExtensionIdProvider, ActorSystem}
import akka.event.LoggingAdapter
import akka.stream.Materializer
import akka.stream.io._
import akka.stream.scaladsl._
import akka.http.ServerSettings
import akka.http.scaladsl.{Http, HttpsContext}
import akka.http.scaladsl.Http.ServerBinding
import akka.http.scaladsl.model.{HttpResponse, HttpRequest}
import net.danielkza.http2.ssl.ALPNSSLContext
import net.danielkza.http2.model.Http2Response
import net.danielkza.http2.stream.ServerConnectionBlueprint

/** Akka extension implementing an HTTP/2 server endpoint.
  *
  * Only TLS connections are supported: ALPN is used to negotiate the "h2" protocol,
  * so a usable `HttpsContext` is effectively mandatory.
  */
class Http2Ext(config: Config)(implicit system: ActorSystem) extends akka.actor.Extension {
  import Http2._

  // Builds the TLS stage, forcing ALPN with the single "h2" protocol. With no context
  // a pass-through placebo stage is used (plain TCP).
  private def sslTlsStage(httpsContext: Option[HttpsContext], role: Role, hostInfo: Option[(String, Int)] = None) =
    httpsContext match {
      case Some(hctx) => SslTls(new ALPNSSLContext(hctx.sslContext, immutable.Seq("h2")), hctx.firstSession, role,
        hostInfo = hostInfo)
      case None => SslTlsPlacebo.forScala
    }

  /** Binds an HTTP/2 server on `interface`:`port` (defaults to 443 when `port` < 0).
    *
    * Falls back to `Http().defaultClientHttpsContext` when no `httpsContext` is given.
    * Returns a source of incoming connections; materializing it yields the binding.
    */
  def bind(
    interface: String, port: Int = -1,
    settings: ServerSettings = ServerSettings(system),
    httpsContext: Option[HttpsContext] = None,
    log: LoggingAdapter = system.log,
    http2Settings: Http2Settings = Http2Settings(system))
    (implicit fm: Materializer)
  : Source[IncomingConnection, Future[ServerBinding]] =
  {
    val effectiveHttpsContext = httpsContext.getOrElse(Http().defaultClientHttpsContext)
    val effectivePort = if (port >= 0) port else 443
    val tlsStage = sslTlsStage(Some(effectiveHttpsContext), Server)
    val connections: Source[Tcp.IncomingConnection, Future[Tcp.ServerBinding]] =
      Tcp().bind(interface, effectivePort, settings.backlog, settings.socketOptions, halfClose = false, settings.timeouts.idleTimeout)

    connections.map { case Tcp.IncomingConnection(localAddress, remoteAddress, flow) =>
      // Each TCP connection gets its own HTTP/2 connection blueprint joined over TLS.
      val layer = ServerConnectionBlueprint(settings, http2Settings)
      IncomingConnection(localAddress, remoteAddress) { handler =>
        layer(handler).join(tlsStage.join(flow))
      }
    }.mapMaterializedValue {
      _.map(tcpBinding => ServerBinding(tcpBinding.localAddress)(() => tcpBinding.unbind()))(fm.executionContext)
    }
  }
}

object Http2 extends ExtensionId[Http2Ext] with ExtensionIdProvider {
  /** Tunables for the HTTP/2 stream layer, read from the `akka.http2` config section. */
  case class Http2Settings(
    maxIncomingStreams: Int,
    requestConcurrency: Int,
    maxOutgoingStreams: Int,
    incomingStreamFrameBufferSize: Int)

  object Http2Settings {
    private val defaults = ConfigFactory.parseMap(Map[String, AnyRef](
      "akka.http2.max-incoming-streams" -> Int.box(8),
      "akka.http2.request-concurrency" -> Int.box(8),
      "akka.http2.max-outgoing-streams" -> Int.box(8),
      "akka.http2.incoming-stream-frame-buffer-size" -> Int.box(32)
    ).asJava)

    def apply(config: Config): Http2Settings = {
      val c = config.withFallback(defaults)

      apply(
        maxIncomingStreams = c.getInt("akka.http2.max-incoming-streams"),
        requestConcurrency = c.getInt("akka.http2.request-concurrency"),
        maxOutgoingStreams = c.getInt("akka.http2.max-outgoing-streams"),
        incomingStreamFrameBufferSize = c.getInt("akka.http2.incoming-stream-frame-buffer-size")
      )
    }

    def apply(system: ActorSystem): Http2Settings =
      apply(system.settings.config)
  }

  /** A single accepted connection. `flowGen` wires a user handler flow into the
    * connection's runnable graph; the various `handleWith*` overloads adapt plain
    * `HttpResponse` handlers (sync, async, flow-based) to `Http2Response`. */
  case class IncomingConnection
    (localAddress: InetSocketAddress, remoteAddress: InetSocketAddress)
    (private val flowGen: Flow[HttpRequest, Http2Response, Any] => RunnableGraph[Future[Unit]])
  {
    def handleWith(handler: Flow[HttpRequest, HttpResponse, Any])(implicit fm: Materializer): Future[Unit] =
      flowGen(handler.map(Http2Response.Simple)).run()

    def handleWith(handler: Flow[HttpRequest, Http2Response, Any])
                  (implicit fm: Materializer, dummyImplicit: DummyImplicit): Future[Unit] =
      flowGen(handler).run()

    def handleWithSyncHandler(handler: HttpRequest ⇒ HttpResponse)(implicit fm: Materializer): Unit =
      handleWith(Flow[HttpRequest].map { req => Http2Response.Simple(handler(req)) })

    def handleWithSyncHandler(handler: HttpRequest ⇒ Http2Response)
                             (implicit fm: Materializer, dummyImplicit: DummyImplicit): Unit =
      handleWith(Flow[HttpRequest].map(handler))

    def handleWithAsyncHandler(handler: HttpRequest ⇒ Future[HttpResponse])
                              (implicit fm: Materializer): Unit = {
      implicit val ec = fm.executionContext
      handleWith(Flow[HttpRequest].mapAsync(1) { req => handler(req).map(Http2Response.Simple) })
    }

    def handleWithAsyncHandler(handler: HttpRequest ⇒ Future[Http2Response])
                              (implicit fm: Materializer, dummyImplicit: DummyImplicit): Unit =
      handleWith(Flow[HttpRequest].mapAsync(1)(handler))
  }

  def apply()(implicit system: ActorSystem): Http2Ext = super.apply(system)

  def lookup() = Http2

  def createExtension(system: ExtendedActorSystem): Http2Ext = {
    val default = ConfigFactory.parseString("""akka.http2 {}""")
    val httpConfig = system.settings.config.getConfig("akka.http")
    val http2Config = system.settings.config.withFallback(default).withFallback(httpConfig).getConfig("akka.http2")
    // FIX: was `http2Config.withFallback(http2Config)` — falling back to itself is a
    // no-op typo. The intent is for unset `akka.http2` keys to inherit from `akka.http`.
    new Http2Ext(http2Config.withFallback(httpConfig))(system)
  }

  object Implicits {
    implicit class SourceWithRunServer[T](source: Source[T, Future[ServerBinding]]) {
      /** Runs the server until JVM shutdown, unbinding and terminating the actor
        * system from a shutdown hook. */
      def runServerIndefinitely(system: ActorSystem)(implicit mat: Materializer, ec: ExecutionContext) = {
        val binding = source.toMat(Sink.ignore)(Keep.left).run()

        Runtime.getRuntime.addShutdownHook(new Thread(new Runnable() {
          override def run(): Unit = {
            Await.ready(binding.flatMap(_.unbind()).flatMap(_ => system.terminate()), Duration.Inf)
          }
        }))
      }
    }
  }
}
// core/src/main/scala/net/danielkza/http2/api/Header.scala
package net.danielkza.http2.api

import akka.http.scaladsl.model.HttpHeader

import scala.language.implicitConversions
import akka.util.ByteString
import akka.http.scaladsl.{model => akkaModel}

/** An HTTP/2 header field: a name/value pair of raw bytes plus a `secure` flag
  * marking fields that must never be compressed/indexed by HPACK (RFC 7541 §6.2.3). */
sealed trait Header {
  def name: ByteString
  def value: ByteString
  def secure: Boolean
}

object Header {
  object Constants {
    final val METHOD = ByteString(":method")
    final val SCHEME = ByteString(":scheme")
    final val AUTHORITY = ByteString(":authority")
    final val STATUS = ByteString(":status")
    final val PATH = ByteString(":path")
    final val HOST = ByteString("Host")
  }

  import Constants._

  private def encode(s: String): ByteString =
    ByteString.fromString(s, "UTF-8")

  case class RawHeader(name: ByteString, value: ByteString, secure: Boolean = false)
    extends Header

  /** Adapts an Akka `HttpHeader`. HTTP/2 (RFC 7540 §8.1.2) requires field *names* to
    * be lowercase; values are case-sensitive and must be preserved as-is. */
  case class WrappedAkkaHeader(akkaHeader: akkaModel.HttpHeader, secure: Boolean = false) extends Header {
    override val name = encode(akkaHeader.name.toLowerCase)
    // FIX: previously the value was also lowercased, corrupting case-sensitive values
    // such as ETags, Authorization credentials or base64 payloads.
    override val value = encode(akkaHeader.value)
  }

  implicit def headerFromAkka(header: HttpHeader): WrappedAkkaHeader =
    WrappedAkkaHeader(header, false)

  object WrappedAkkaHeader {
    implicit def unwrapAkkaHeader(wrapped: WrappedAkkaHeader): akkaModel.HttpHeader =
      wrapped.akkaHeader
  }

  /** The `:status` pseudo-header derived from an Akka status code. */
  case class WrappedAkkaStatusCode(akkaStatusCode: akkaModel.StatusCode) extends Header {
    override val name = STATUS
    override val value = ByteString(Integer.toString(akkaStatusCode.intValue))
    override val secure = false
  }

  /** The `:method` pseudo-header derived from an Akka HTTP method. */
  case class WrappedAkkaMethod(akkaMethod: akkaModel.HttpMethod) extends Header {
    override val name = METHOD
    override val value = ByteString(akkaMethod.value)
    override val secure = false
  }

  def plain(name: ByteString, value: ByteString): RawHeader =
    RawHeader(name, value, secure = false)

  def plain(name: String, value: String): RawHeader =
    // FIX: lowercase only the name; values are case-sensitive per RFC 7540 §8.1.2.
    RawHeader(ByteString(name.toLowerCase), ByteString(value), secure = false)

  def secure(name: ByteString, value: ByteString): RawHeader =
    RawHeader(name, value, secure = true)

  def secure(name: String, value: String): RawHeader =
    RawHeader(ByteString(name.toLowerCase), ByteString(value), secure = true)
}
// core/src/main/scala/net/danielkza/http2/hpack/DynamicTable.scala
package net.danielkza.http2.hpack

import scalaz._
import scalaz.syntax.either._
import akka.util.ByteString

/** The HPACK dynamic table (RFC 7541 §2.3.2): a FIFO of header entries with a
  * byte-size capacity. Entries are indexed starting at `baseOffset + 1` (the static
  * table occupies the lower indices). The table is immutable: every mutation returns
  * a new instance.
  *
  * `entryMap` is a reverse index (entry -> position) maintained incrementally:
  * instead of rewriting the map on every insertion, stored positions are kept stale
  * and corrected by `entryMapOffset` at lookup time, with a periodic full rebuild.
  */
class DynamicTable private (
  val maxCapacity: Int,
  val curCapacity: Int,
  val baseOffset: Int,
  val entries: Vector[DynamicTable.Entry],
  val curSize: Int,
  protected val entryMap: Map[DynamicTable.Entry, Int],
  protected val entryMapOffset: Int)
  extends Table {
  final override type Entry = DynamicTable.Entry
  final override def Entry(name: ByteString, value: ByteString) = DynamicTable.Entry(name, value)

  def this(maxCapacity: Int, baseOffset: Int, entries: Traversable[DynamicTable.Entry] = Traversable.empty) = {
    this(maxCapacity, maxCapacity, baseOffset, entries.toVector, entries.foldLeft(0) { _ + _.size },
      Table.genEntryMap(entries).toMap, 0)
  }

  // HPACK indices are 1-based and offset past the static table.
  @inline override def get(index: Int) = entries.lift(index - baseOffset - 1)

  /** Reverse lookup: returns the current 1-based index of `entry`, or None if the
    * stored map position, once corrected by `entryMapOffset`, falls outside the live
    * range (i.e. the entry has been evicted). */
  override def lookupEntry(entry: DynamicTable.Entry): Option[Int] = {
    entryMap.get(entry).flatMap { index =>
      index + entryMapOffset match {
        case i if i > baseOffset && i <= baseOffset + entries.size => Some(i)
        case _ => None
      }
    }
  }

  /** Core update: optionally change capacities and prepend `addEntries`, evicting
    * from the back until the byte budget (`newCapacity`) fits.
    * NOTE(review): the offset bookkeeping below assumes lookups always correct by
    * `entryMapOffset`; the eviction order (new entries first, then old) follows
    * RFC 7541 §4.4 — entries are evicted from the end of the table. */
  private def updated(newCapacity: Int = this.curCapacity, maxCapacity: Int = this.maxCapacity,
                      addEntries: Iterable[Entry] = Iterable.empty): DynamicTable = {
    var newSize = 0
    var numEntries = 0

    // Counts how many entries from `it` fit in the remaining budget, accumulating
    // into newSize/numEntries; returns how many were taken from this iterator.
    def takeWhileNotFull(it: Iterator[Entry]) = {
      val prevNumEntries = numEntries
      it.takeWhile(newSize + _.size <= newCapacity).foreach { entry =>
        newSize += entry.size
        numEntries += 1
      }
      numEntries - prevNumEntries
    }

    // Calculate how many entries we'll need to take from the old and new list of entries until we fill the maximum size
    val numNewEntries = takeWhileNotFull(addEntries.iterator)
    val numOldEntries = takeWhileNotFull(entries.iterator)

    val newEntries = addEntries.take(numNewEntries).toVector
    val oldEntries = entries.take(numOldEntries)
    val mergedEntries = newEntries ++ oldEntries

    // Find out how much the old entries will be offset from their real indexes. Add new entries also with that same
    // offset, so we can reverse it uniformly while fetching.
    var newEntryMapOffset = entryMapOffset + numNewEntries

    val mergedMap = if(entryMap.size >= curCapacity / 2 || newEntryMapOffset >= curCapacity / 2) {
      // Rewrite the whole map without an offset if enough entries are outdated
      newEntryMapOffset = 0
      Table.genEntryMap(mergedEntries).toMap
    } else {
      // Apply the offsets to the new entries and merge
      val newMap = Table.genEntryMap(newEntries).map { case (entry, index) =>
        entry -> (index + baseOffset - newEntryMapOffset)
      }.toMap
      newMap ++ entryMap
    }

    new DynamicTable(maxCapacity, newCapacity, baseOffset, mergedEntries, newSize, mergedMap, newEntryMapOffset)
  }

  /** Shrinks/grows the working capacity (HPACK dynamic table size update). Fails if
    * the requested capacity exceeds the protocol-negotiated maximum. */
  def withCapacity(capacity: Int): \/[HeaderError, DynamicTable] = {
    if(capacity > maxCapacity)
      HeaderError.DynamicTableCapacityExceeded.left
    else
      updated(newCapacity = capacity).right
  }

  /** Changes the negotiated maximum, clamping the current capacity to it. */
  def withMaxCapacity(maxCapacity: Int): DynamicTable =
    updated(maxCapacity = maxCapacity, newCapacity = Math.min(curCapacity, maxCapacity))

  def +(nv: (ByteString, ByteString)): DynamicTable =
    this ++ DynamicTable.Entry(nv._1, nv._2)

  def +(nv: (String, String))(implicit d1: DummyImplicit): DynamicTable =
    this ++ DynamicTable.Entry(ByteString(nv._1), ByteString(nv._2))

  def +(entry: Entry): DynamicTable =
    this ++ entry

  def ++(entries: Entry*): DynamicTable =
    this ++ entries

  def ++(entries: Traversable[Entry]): DynamicTable =
    updated(addEntries = entries.toIterable)
}

object DynamicTable {
  // Per-entry bookkeeping overhead mandated by RFC 7541 §4.1.
  final val ENTRY_OVERHEAD: Int = 32

  case class Entry(name: ByteString, value: ByteString) extends Table.Entry {
    override type Self = Entry
    override def withEmptyValue: Entry = copy(value = ByteString.empty)
    // Entry size as defined by RFC 7541 §4.1: name + value octets + 32.
    @inline def size = name.length + value.length + DynamicTable.ENTRY_OVERHEAD
  }
}
--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/HeaderError.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack

/** Failures that can arise while encoding or decoding HPACK header data. */
sealed trait HeaderError

object HeaderError {
  /** More than seven padding bits were left over after Huffman decoding. */
  case class ExcessivePadding(pos: Int) extends HeaderError
  /** The leftover Huffman padding bits were not all ones. */
  case class InvalidPadding(pos: Int) extends HeaderError
  /** A bit sequence did not correspond to any known Huffman code. */
  case class UnknownCode(pos: Int) extends HeaderError
  /** The input ended before the declared length was available. */
  case class IncompleteInput(received: Int, expected: Int) extends HeaderError
  /** The EOS symbol appeared inside a Huffman-coded string. */
  case class EOSDecoded(pos: Int) extends HeaderError
  /** A size update requested more than the allowed dynamic table maximum. */
  case object DynamicTableCapacityExceeded extends HeaderError
  /** Catch-all for malformed header block input. */
  case object ParseError extends HeaderError
  /** An index pointed outside both the static and dynamic tables. */
  case class InvalidIndex(index: Int) extends HeaderError
}

--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/HeaderRepr.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack

import akka.util.ByteString

/** One decoded HPACK header-field representation (RFC 7541 §6). */
sealed trait HeaderRepr
/** Marker for representations that add an entry to the dynamic table. */
sealed trait Incremental extends HeaderRepr

object HeaderRepr {
  /** §6.1 — fully indexed header field. */
  case class Indexed(index: Int) extends HeaderRepr
  /** §6.2.1 — literal with incremental indexing, name taken from the tables. */
  case class IncrementalLiteralWithIndexedName(keyIndex: Int, value: ByteString) extends Incremental
  /** §6.2.1 — literal with incremental indexing, name given inline. */
  case class IncrementalLiteral(key: ByteString, value: ByteString) extends Incremental
  /** §6.2.2 — literal without indexing, name taken from the tables. */
  case class LiteralWithIndexedName(keyIndex: Int, value: ByteString) extends HeaderRepr
  /** §6.2.2 — literal without indexing, name given inline. */
  case class Literal(key: ByteString, value: ByteString) extends HeaderRepr
  /** §6.2.3 — never-indexed literal, name taken from the tables. */
  case class NeverIndexedWithIndexedName(keyIndex: Int, value: ByteString) extends HeaderRepr
  /** §6.2.3 — never-indexed literal, name given inline. */
  case class NeverIndexed(key: ByteString, value: ByteString) extends HeaderRepr
  /** §6.3 — dynamic table size update. */
  case class DynamicTableSizeUpdate(size: Int) extends HeaderRepr
}
--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/StaticTable.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack

import akka.util.ByteString

/** Immutable HPACK table built from a fixed list of (name, value) pairs.
  * The position of each pair in the list (1-based) is its wire index. */
class StaticTable(initialEntries: (String, String)*) extends Table {
  final override type Entry = StaticTable.Entry
  final override def Entry(name: ByteString, value: ByteString) = StaticTable.Entry(name, value)

  override val entries = initialEntries.map { case (name, value) =>
    Entry(ByteString(name), ByteString(value))
  }.toIndexedSeq

  override val entryMap = Table.genEntryMap[StaticTable.Entry](entries).toMap
}

object StaticTable {
  case class Entry(name: ByteString, value: ByteString) extends Table.Entry {
    override type Self = Entry
    override def withEmptyValue: Entry = copy(value = ByteString.empty)
  }

  // The HPACK static table, RFC 7541 Appendix A. Order is significant.
  lazy val default: StaticTable = new StaticTable(
    ":authority" -> "",
    ":method" -> "GET",
    ":method" -> "POST",
    ":path" -> "/",
    ":path" -> "/index.html",
    ":scheme" -> "http",
    ":scheme" -> "https",
    ":status" -> "200",
    ":status" -> "204",
    ":status" -> "206",
    ":status" -> "304",
    ":status" -> "400",
    ":status" -> "404",
    ":status" -> "500",
    "accept-charset" -> "",
    "accept-encoding" -> "gzip, deflate",
    "accept-language" -> "",
    "accept-ranges" -> "",
    "accept" -> "",
    "access-control-allow-origin" -> "",
    "age" -> "",
    "allow" -> "",
    "authorization" -> "",
    "cache-control" -> "",
    "content-disposition" -> "",
    "content-encoding" -> "",
    "content-language" -> "",
    "content-length" -> "",
    "content-location" -> "",
    "content-range" -> "",
    "content-type" -> "",
    "cookie" -> "",
    "date" -> "",
    "etag" -> "",
    "expect" -> "",
    "expires" -> "",
    "from" -> "",
    "host" -> "",
    "if-match" -> "",
    "if-modified-since" -> "",
    // BUG FIX: index 41 must be "if-none-match" per RFC 7541 Appendix A. It
    // was previously a duplicate "if-match", which made decoding of wire
    // index 41 produce the wrong header name and made "if-none-match"
    // unindexable on encode.
    "if-none-match" -> "",
    "if-range" -> "",
    "if-unmodified-since" -> "",
    "last-modified" -> "",
    "link" -> "",
    "location" -> "",
    "max-forwards" -> "",
    "proxy-authenticate" -> "",
    "proxy-authorization" -> "",
    "range" -> "",
    "referer" -> "",
    "refresh" -> "",
    "retry-after" -> "",
    "server" -> "",
    "set-cookie" -> "",
    "strict-transport-security" -> "",
    "transfer-encoding" -> "",
    "user-agent" -> "",
    "vary" -> "",
    "via" -> "",
    "www-authenticate" -> ""
  )
}

--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/Table.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack

import scala.collection.mutable
import akka.util.ByteString

// Common interface for the HPACK static and dynamic tables: indexed access
// plus reverse lookup of (name, value) and name-only matches.
trait Table extends {
  import Table.{FindResult, FoundName, FoundNameValue, NotFound}

  type Entry <: Table.Entry {type Self = Entry}
  def Entry(name: ByteString, value: ByteString): Entry

  def entries: collection.IndexedSeq[Entry]
  protected def entryMap: collection.Map[Entry, Int]
  def length: Int = entries.length

  // HPACK indexes are 1-based.
  @inline def get(index: Int): Option[Entry] = entries.lift(index - 1)

  def lookupEntry(entry: Entry): Option[Int] = entryMap.get(entry)

  // Searches for an exact (name, value) match first, then a name-only match.
  // Passing value = None (used for sensitive headers) skips the exact search.
  def find(name: ByteString, value: Option[ByteString]): FindResult = {
    value flatMap { v =>
      lookupEntry(Entry(name, v)) match {
        case Some(index) => Some(FoundNameValue(index))
        case _ if v.isEmpty => Some(NotFound) // Don't lookup empty values twice
        case _ => None
      }
    } orElse {
      lookupEntry(Entry(name, ByteString.empty)).map(FoundName)
    } getOrElse {
      NotFound
    }
  }
}

object Table {
  sealed trait FindResult
  case class FoundName(index: Int) extends FindResult
  case class FoundNameValue(index: Int) extends FindResult
  case object NotFound extends FindResult

  trait Entry {
    type Self <: Entry
    def name: ByteString
    def value: ByteString
    def withEmptyValue: Self
  }

  // Lazily yields (entry -> index) pairs, registering each entry both as-is
  // and with an empty value so name-only lookups work through the same map.
  // NOTE: when several entries share a name, `.toMap` keeps the *last*
  // (highest) index for the name-only key; the spec permits any matching index.
  def genEntryMap[E <: Entry {type Self = E}](entries: Traversable[E]): Traversable[(E, Int)] = {
    new Traversable[(E, Int)] {
      override def foreach[U](f: ((E, Int)) => U): Unit = {
        var i = 1
        entries.foreach { case entry =>
          f(entry -> i)
          f(entry.withEmptyValue -> i)
          i += 1
        }
      }
    }
  }
}

--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/coders/BytesCoder.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack.coders

import akka.util.{ByteStringBuilder, ByteString}
import scalaz.\/
import scalaz.syntax.either._

import net.danielkza.http2.Coder
import net.danielkza.http2.hpack.HeaderError

// Uncompressed string literal coder (RFC 7541 §5.2 with the Huffman bit
// clear): a 7-bit-prefixed length followed by the raw octets.
class BytesCoder(lengthPrefix: Int = 0) extends Coder[ByteString] {
  override final type Error = HeaderError

  val intCoder = new IntCoder(7, lengthPrefix)

  override def encode(value: ByteString, stream: ByteStringBuilder): \/[HeaderError, Unit] = {
    intCoder.encode(value.length, stream).map { _ =>
      stream ++= value
    }
  }

  // Returns the decoded octets and the total number of bytes consumed
  // (length header + payload); fails if fewer than `length` octets remain.
  override def decode(bs: ByteString): \/[HeaderError, (ByteString, Int)] = {
    for {
      lengthResult <- intCoder.decode(bs)
      (length, numReadBytes) = lengthResult
      content = bs.slice(numReadBytes, numReadBytes + length)
      _ <- if(content.length != length)
        HeaderError.IncompleteInput(content.length, length).left
      else
        ().right
    } yield (content, numReadBytes + length)
  }

}
--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/coders/CompressedBytesCoder.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack.coders

import scala.annotation.tailrec
import scalaz._
import scalaz.syntax.either._
import akka.util.{ByteString, ByteStringBuilder}
import net.danielkza.http2.Coder
import net.danielkza.http2.hpack.HeaderError

/** Huffman-compressed string literal coder (RFC 7541 §5.2, Appendix B). */
class CompressedBytesCoder extends Coder[ByteString] {
  final override type Error = HeaderError

  import CompressedBytesCoder._

  // Huffman-encodes `value` into `buffer`, returning the encoded length in
  // bytes. `hold` accumulates code bits left-aligned at the top of a 64-bit
  // word; whole bytes are flushed from the high end as they fill.
  private def encodeInternal(value: ByteString, buffer: ByteStringBuilder): Int = {
    import HuffmanCoding.{Code, coding}

    var hold: Long = 0
    var outstandingBits = 0
    var length: Int = 0

    def encodeSymbol(symbol: Byte): Unit = {
      // BUG FIX: index with the *unsigned* byte value. `symbol` is a signed
      // Byte, so input octets >= 0x80 widened to a negative index and threw
      // an out-of-bounds exception instead of encoding.
      val Code(code, codeBits) = coding.encodingTable(symbol & 0xFF)
      hold |= (code & 0xFFFFFFFFL) << (32 - outstandingBits)
      outstandingBits += codeBits

      while(outstandingBits >= 8) {
        buffer += (hold >>> 56).toByte
        outstandingBits -= 8
        hold <<= 8

        length += 1
      }
    }

    value.foreach(byte => encodeSymbol(byte))

    // Pad the final partial byte with the high bits of EOS (all ones), §5.2.
    if(outstandingBits > 0) {
      val lastByte = (hold >>> 56) | (0xFF >>> outstandingBits)
      buffer.putByte(lastByte.toByte)
      length += 1
    }

    length
  }

  // Encodes into a temporary buffer first because the length prefix must be
  // written before the (unknown) compressed length.
  override def encode(value: ByteString, stream: ByteStringBuilder): \/[HeaderError, Unit] = {
    val buffer = ByteString.newBuilder
    buffer.sizeHint((value.length * 2) / 3)

    val length = encodeInternal(value, buffer)
    lengthCoder.encode(length, stream).map { _ =>
      stream ++= buffer.result()
    }
  }

  override def encode(value: ByteString): \/[HeaderError, ByteString] = {
    val buffer = ByteString.newBuilder
    encode(value, buffer).map(_ => buffer.result())
  }

  // Valid padding consists exclusively of 1-bits (a prefix of EOS).
  private def validPadding(value: Int, bits: Int) = {
    if (bits == 0) true
    else (value.toByte >> (8 - bits)) == -1
  }

  // Decodes `length` compressed octets from `iter` into the original string.
  // Uses per-byte-prefix decoding tables (see HuffmanCoding); `hold` keeps
  // undecoded bits left-aligned, `prefixBits` tracks how many of them were
  // already matched against a multi-byte table prefix.
  private def decodeLiteral(iter: BufferedIterator[Byte], length: Int): \/[HeaderError, (ByteString, Int)] = {
    import HuffmanCoding.{EOS, coding, Symbol}
    import HeaderError._
    import coding.DecodingTable

    var outstandingBits = 0
    var hold: Long = 0
    var prefixBits: Int = 0
    var bytesRead = 0
    var error: HeaderError = null

    // Records the error and returns the -1 sentinel consumed by `decode`.
    @inline def raiseError(f: (Int) => HeaderError): Short = {
      error = f(bytesRead)
      -1
    }

    @tailrec
    def decodeSymbol(currentTable: DecodingTable = coding.initialTable): Short = {
      // No code is longer than 30 bits; anything more is malformed input.
      if(outstandingBits >= 32)
        return raiseError(UnknownCode)

      if(bytesRead < length && (outstandingBits - prefixBits) < 8) {
        if(!iter.hasNext)
          return raiseError(IncompleteInput(_, length))
        else {
          hold |= (iter.next() & 0xFFL) << (56 - outstandingBits)
          bytesRead += 1
          outstandingBits += 8
        }
      }

      val nextByte = ((hold >>> (56 - prefixBits)) & 0xFF).toInt
      currentTable.lift(nextByte) match {
        case Some(Symbol(symbol, codeLen)) if outstandingBits >= codeLen =>
          // Full code matched: consume its bits and restart at the root table.
          outstandingBits -= codeLen
          hold <<= codeLen
          prefixBits = 0

          if(symbol == EOS)
            raiseError(EOSDecoded)
          else
            symbol
        case _ if bytesRead == length =>
          // Input exhausted: the remaining bits must be valid EOS padding.
          if(outstandingBits > 7)
            raiseError(ExcessivePadding)
          else if(!validPadding(nextByte, outstandingBits))
            raiseError(InvalidPadding)
          else
            EOS
        case _ =>
          // Code longer than the bits matched so far: descend into the
          // decoding table keyed by the whole-byte prefix.
          prefixBits = 8 * (outstandingBits / 8) // observe truncation
          val prefix = (hold >>> 32).toInt & (-1 << (32 - prefixBits))
          coding.decodingTableForCode(prefix) match {
            case Some(nextTable) =>
              decodeSymbol(nextTable)
            case _ => raiseError(UnknownCode)
          }
      }
    }

    var res = new ByteStringBuilder
    res.sizeHint((length * 3) / 2)

    @tailrec
    def decode(): \/[HeaderError, (ByteString, Int)] = {
      decodeSymbol() match {
        case HuffmanCoding.EOS =>
          (res.result(), bytesRead).right
        case -1 =>
          -\/(error)
        case symbol =>
          res = res.putByte(symbol.toByte)
          decode()
      }
    }

    decode()
  }

  override def decode(bs: ByteString): \/[HeaderError, (ByteString, Int)] = {
    for {
      encodedBytesResult <- bytesCoder.decode(bs)
      (encodedBytes, numReadBytes) = encodedBytesResult
      decodeResult <- decodeLiteral(encodedBytes.iterator, encodedBytes.length)
    } yield (decodeResult._1, numReadBytes)
  }
}

object CompressedBytesCoder {
  // Length prefix carries the Huffman flag bit set to 1 (RFC 7541 §5.2).
  val lengthCoder = new IntCoder(7, prefix=1)
  val bytesCoder = new BytesCoder(lengthPrefix=1)
}

--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/coders/HeaderBlockCoder.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack.coders

import scalaz._
import scalaz.syntax.either._
import scalaz.syntax.std.option._
import akka.util.{ByteString, ByteStringBuilder}
import net.danielkza.http2.Coder
import net.danielkza.http2.api.Header
import net.danielkza.http2.hpack._

// Stateful whole-header-block coder: wires the static and dynamic tables to
// the per-representation HeaderCoder. NOTE: instances are mutable (the
// dynamic table advances as blocks are processed) and are not thread-safe.
class HeaderBlockCoder(maxCapacity: Int = 4096,
                       private val headerCoder: HeaderCoder = new HeaderCoder())
  extends Coder[Seq[Header]]
{
  import Header._
  import HeaderRepr._

  override type Error = HeaderError

  private val staticTable = StaticTable.default
  private var dynamicTable: DynamicTable = new DynamicTable(maxCapacity, baseOffset = staticTable.length)

  // Applies a dynamic table size update (RFC 7541 §6.3); fails above the max.
  def withCapacity(capacity: Int): \/[HeaderError, HeaderBlockCoder] =
    dynamicTable.withCapacity(capacity).map { dt => dynamicTable = dt; this }

  def withMaxCapacity(capacity: Int): HeaderBlockCoder = {
    dynamicTable = dynamicTable.withMaxCapacity(capacity)
    this
  }

  // Resolves a wire index against the static table first, then the dynamic one.
  private def indexToEntry(index: Int): \/[HeaderError, Table.Entry] = {
    staticTable.get(index) orElse dynamicTable.get(index) match {
      case Some(entry) => \/-(entry)
      case _ => -\/(HeaderError.InvalidIndex(index))
    }
  }

  private def entryToIndex(name: ByteString, value: Option[ByteString]): Table.FindResult = {
    import Table.NotFound

    staticTable.find(name, value) match {
      case NotFound => dynamicTable.find(name, value)
      case r => r
    }
  }

  // Chooses a wire representation for one outgoing header, returning it
  // together with the dynamic table that results from (possibly) indexing it.
  // Sensitive headers become never-indexed literals and are never indexed.
  private def processHeader(header: Header): (HeaderRepr, DynamicTable) = {
    import Table.{FoundName, FoundNameValue, NotFound}

    val name = header.name
    val value = header.value

    if(!header.secure) {
      entryToIndex(name, Some(value)) match {
        case FoundNameValue(index) =>
          Indexed(index) -> dynamicTable
        case FoundName(index) =>
          IncrementalLiteralWithIndexedName(index, value) -> (dynamicTable + (name -> value))
        case NotFound =>
          IncrementalLiteral(name, value) -> (dynamicTable + (name -> value))
      }
    } else {
      // A None value skips the exact search, so FoundNameValue is impossible.
      entryToIndex(name, None) match {
        case FoundName(index) => NeverIndexedWithIndexedName(index, value) -> dynamicTable
        case NotFound => NeverIndexed(name, value) -> dynamicTable
        case _ => throw new AssertionError("Secure header value should never be matched")
      }
    }
  }

  // Turns one decoded representation into an (optional) header plus the
  // updated dynamic table. Size updates produce no header.
  private def processHeaderRepr(headerRepr: HeaderRepr)
    : \/[HeaderError, (Option[Header], DynamicTable)] =
  {
    headerRepr match {
      case DynamicTableSizeUpdate(size) =>
        dynamicTable.withCapacity(size).map(None -> _)
      case Indexed(index) =>
        indexToEntry(index).map { e => plain(e.name, e.value).some -> dynamicTable }
      case h @ IncrementalLiteralWithIndexedName(index, value) =>
        indexToEntry(index).map { e => plain(e.name, value).some -> (dynamicTable + (e.name -> value)) }
      case h @ LiteralWithIndexedName(index, value) =>
        indexToEntry(index).map { e => plain(e.name, value).some -> dynamicTable}
      case h @ NeverIndexedWithIndexedName(index, value) =>
        indexToEntry(index).map { e => secure(e.name, value).some -> dynamicTable}
      case h @ IncrementalLiteral(name, value) =>
        (plain(name, value).some -> (dynamicTable + (name -> value))).right
      case h @ Literal(name, value) =>
        (plain(name, value).some -> dynamicTable).right
      case h @ NeverIndexed(name, value) =>
        (secure(name, value).some -> dynamicTable).right
    }
  }

  // Decodes a complete header block, threading the dynamic table through each
  // representation. Returns the headers and total bytes consumed.
  override def decode(bs: ByteString): \/[HeaderError, (Seq[Header], Int)] = {
    val headers = Seq.newBuilder[Header]
    var buffer = bs
    var totalBytes = 0
    // BUG FIX: check the guard *before* the first iteration. The previous
    // do/while attempted to decode at least one representation even for an
    // empty block, which threw an uncaught NoSuchElementException (HeaderCoder
    // reads stream.head) instead of returning an empty Seq.
    while(buffer.nonEmpty) {
      (for {
        reprDec <- headerCoder.decode(buffer)
        (repr, bytesRead) = reprDec
        processed <- processHeaderRepr(repr)
        (maybeHeader, newDynamicTable) = processed
      } yield {
        maybeHeader.foreach { headers += _ }
        dynamicTable = newDynamicTable
        totalBytes += bytesRead
        buffer = buffer.drop(bytesRead)
      }) match {
        // match on left so right is inferred from the return type
        case -\/(e) => return -\/(e)
        case _ => // pass
      }
    }

    \/-(headers.result() -> totalBytes)
  }

  // Encodes all headers in order; the dynamic table is only advanced after a
  // representation is successfully written, so a failure leaves it consistent.
  override def encode(headers: Seq[Header], stream: ByteStringBuilder): \/[HeaderError, Unit] = {
    headers.foreach { header =>
      val (repr, newDynamicTable) = processHeader(header)
      headerCoder.encode(repr, stream) match {
        case e @ -\/(_) => return e
        case _ => dynamicTable = newDynamicTable
      }
    }

    \/-(())
  }
}

--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/coders/HeaderCoder.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack.coders

import scalaz.{-\/, \/}
import akka.util.{ByteStringBuilder, ByteString}
import shapeless._
import net.danielkza.http2.Coder
import net.danielkza.http2.util._
import net.danielkza.http2.hpack.{HeaderError, HeaderRepr}

// Coder for a single HPACK header-field representation (RFC 7541 §6).
// Each representation is an IntCoder for its prefix-typed index/size field
// composed (via shapeless Generic) with literal string coders as needed.
// `compressionPredicate` decides, per string, whether to Huffman-compress.
class HeaderCoder(val compressionPredicate: ByteString => Boolean = HeaderCoder.compress.default)
  extends Coder[HeaderRepr]
{
  import HeaderRepr._
  import Coder.defineCompositeCoder

  override final type Error = HeaderError

  // One IntCoder per representation type: (prefix bits pattern, payload bits).
  private final val indexedIndex = new IntCoder(7, prefix=1)             // '1' + 7-bit index (§6.1)
  private final val incrementalLiteralIndex = new IntCoder(6, prefix=1)  // '01' + 6-bit index (§6.2.1)
  private final val literalIndex = new IntCoder(4, prefix=0)             // '0000' + 4-bit index (§6.2.2)
  private final val neverIndexedIndex = new IntCoder(4, prefix=1)        // '0001' + 4-bit index (§6.2.3)
  private final val tableUpdateSize = new IntCoder(5, prefix=1)          // '001' + 5-bit size (§6.3)
  private final val literalString = new LiteralCoder(compressionPredicate)

  private final val indexed = defineCompositeCoder(Generic[Indexed])
    { indexedIndex :: HNil }

  private final val incrLiteralIdxName = defineCompositeCoder(Generic[IncrementalLiteralWithIndexedName])
    { incrementalLiteralIndex :: literalString :: HNil }

  private final val incrLiteral = defineCompositeCoder(Generic[HeaderRepr.IncrementalLiteral])
    { literalString :: literalString :: HNil }

  private final val literalIdxName = defineCompositeCoder(Generic[HeaderRepr.LiteralWithIndexedName])
    { literalIndex :: literalString :: HNil }

  private final val literal = defineCompositeCoder(Generic[HeaderRepr.Literal])
    { literalString :: literalString :: HNil }

  private final val neverIndexedIdxName = defineCompositeCoder(Generic[HeaderRepr.NeverIndexedWithIndexedName])
    { neverIndexedIndex :: literalString :: HNil }

  private final val neverIndexed = defineCompositeCoder(Generic[HeaderRepr.NeverIndexed])
    { literalString :: literalString :: HNil }

  private final val dynamicTableSizeUpdate = defineCompositeCoder(Generic[DynamicTableSizeUpdate])
    { tableUpdateSize :: HNil }


  // The inline-name variants carry no index field, so their fixed type octet
  // is written explicitly before the composite coder runs.
  override def encode(value: HeaderRepr, stream: ByteStringBuilder): \/[HeaderError, Unit] = {
    value match {
      case h: Indexed => indexed.encode(h, stream)
      case h: IncrementalLiteral => stream.putByte(bin_b"01000000"); incrLiteral.encode(h, stream)
      case h: IncrementalLiteralWithIndexedName => incrLiteralIdxName.encode(h, stream)
      case h: Literal => stream.putByte(bin_b"00000000"); literal.encode(h, stream)
      case h: LiteralWithIndexedName => literalIdxName.encode(h, stream)
      case h: NeverIndexed => stream.putByte(bin_b"00010000"); neverIndexed.encode(h, stream)
      case h: NeverIndexedWithIndexedName => neverIndexedIdxName.encode(h, stream)
      case h: DynamicTableSizeUpdate => dynamicTableSizeUpdate.encode(h, stream)
      case _ => -\/(HeaderError.ParseError) // unreachable: HeaderRepr is sealed
    }
  }

  // Dispatches on the first octet's bit pattern. Case order matters: the
  // exact "index 0" patterns (inline-name forms) must precede the wildcard
  // patterns for the same prefix.
  // NOTE(review): assumes `stream` is non-empty — `stream.head` throws on an
  // empty ByteString; callers guard against this.
  override def decode(stream: ByteString): \/[HeaderError, (HeaderRepr, Int)] = {
    stream.head match {
      case bin"1-------" => indexed.decode(stream)
      case bin"01000000" => incrLiteral.decode(stream.drop(1)).map { case (v, l) => (v, l + 1) }
      case bin"01------" => incrLiteralIdxName.decode(stream)
      case bin"001-----" => dynamicTableSizeUpdate.decode(stream)
      case bin"00000000" => literal.decode(stream.drop(1)).map { case (v, l) => (v, l + 1) }
      case bin"0000----" => literalIdxName.decode(stream)
      case bin"00010000" => neverIndexed.decode(stream.drop(1)).map { case (v, l) => (v, l + 1) }
      case bin"0001----" => neverIndexedIdxName.decode(stream)
      case _ => -\/(HeaderError.ParseError)
    }
  }
}

object HeaderCoder {
  object compress {
    // Compress only strings of at least `minBytes` octets.
    case class Threshold(minBytes: Int) extends (ByteString => Boolean) {
      @inline override def apply(value: ByteString): Boolean =
        value.length >= minBytes
    }

    final val Never = (value: ByteString) => false
    final val Always = (value: ByteString) => true

    final val default = Threshold(512)
  }
}

--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/coders/HuffmanCoding.scala:
--------------------------------------------------------------------------------

package net.danielkza.http2.hpack.coders

import scala.language.experimental.macros
import scala.collection.immutable.{IntMap, HashMap}

import scalaz.ImmutableArray

// Canonical Huffman coding built from (codeLength, symbol) pairs that MUST be
// sorted by ascending code length (codes are assigned sequentially, as in
// RFC 7541 Appendix B). Decoding is table-driven: one table per whole-byte
// code prefix, each mapping the next input byte to a packed Symbol.
class HuffmanCoding(codes: (Byte, Short)*) {
  import HuffmanCoding.{Code, Symbol, EOS}

  final val (decodingTablesMap, encodingTable) = expandCodes(codes: _*)
  final val initialTable = decodingTablesMap(0)

  final type DecodingTable = IndexedSeq[Short]
  final type DecodingTableMap = Map[Int, DecodingTable]
  final type EncodingTable = IndexedSeq[Long]

  // Walks the code list once, assigning sequential canonical codes, filling
  // the symbol->Code encoding table and the per-prefix decoding tables.
  private def expandCodes(codes: (Byte, Short)*): (DecodingTableMap, EncodingTable) = {
    var decodingTableMap = IntMap.empty[DecodingTable]
    var curDecodingTable = ImmutableArray.newBuilder[Short]
    val encodingTable = new Array[Long](EOS + 1)

    var lastPrefix = 0
    var currentCode = 0
    var prevLength = 0
    var prevLengthBytes = 1

    codes.foreach { case (length, symbol) =>
      val lengthBytes = (length + 8 - 1) / 8

      // Only check for the need to create a new table if the code is larger than 8 bits. Otherwise it should always be
      // assigned to the initial table (since there's no previous byte to care about).
      if(lengthBytes > 1) {
        // Shift the code if we now use an extra byte
        var currentPrefix = currentCode

        if(lengthBytes > prevLengthBytes) {
          currentCode <<= 8
        } else {
          currentPrefix >>>= 8
        }

        if(currentPrefix != lastPrefix) {
          // The last byte has changed. Write out the just finished table to it's prefix and create a new one.
          decodingTableMap += Code.alignByBytes(lastPrefix, prevLengthBytes - 1) -> curDecodingTable.result()
          curDecodingTable = ImmutableArray.newBuilder

          // Start out a new prefix and adjust the current code to match the correct length
          lastPrefix = currentPrefix
        }
      }

      encodingTable(symbol) = Code(currentCode, length)

      curDecodingTable += Symbol(symbol, length)
      currentCode += 1

      // If the current code doesn't fill whole bytes, add extra indexes representing it with all the possible values
      // of the unfilled bytes. For example, if we have a code ending with '000011', it's entries will be:
      // '000011-00', '000011-01', '000011-10', '000011-11' to match any whole byte that has it as a prefix.
      val remBits = (lengthBytes * 8) - length
      if(remBits > 0) {
        val lastCode = currentCode + (1 << remBits) - 1
        do {
          curDecodingTable += Symbol(symbol, length)
          currentCode += 1
        } while(currentCode != lastCode)
      }

      prevLength = length
      prevLengthBytes = lengthBytes
    }

    // The last table won't be added inside the loop since it breaks after the values run out
    decodingTableMap += lastPrefix -> curDecodingTable.result()

    (decodingTableMap, ImmutableArray.make(encodingTable))
  }

  // Looks up the decoding table for a left-aligned whole-byte code prefix.
  def decodingTableForCode(currentCode: Int): Option[DecodingTable] = {
    decodingTablesMap.get(currentCode)
  }
}


object HuffmanCoding {
  // Sentinel symbol (one past the byte range) marking end-of-string padding.
  final val EOS: Short = 256

  // A Code packs a left-aligned 32-bit code word and its bit length into a
  // Long: low 32 bits = aligned code, next bits = length.
  object Code {
    @inline def align(code: Int, length: Int): Int = {
      alignByBytes(code, (length + 8 - 1) / 8)
    }

    // Left-aligns a code occupying `bytes` bytes into the top of a 32-bit word.
    @inline def alignByBytes(code: Int, bytes: Int): Int = {
      bytes match {
        case 1 => code << 24
        case 2 => code << 16
        case 3 => code << 8
        case _ => code
      }
    }

    @inline def apply(code: Int, length: Int): Long =
      (align(code, length) & 0xFFFFFFFFL) | (length & 0xFFFFFFFFL) << 32

    def unapply(l: Long): Option[(Int, Int)] =
      Some((l & 0xFFFFFFFFL).toInt, (l >>> 32).toInt & 0xFF)
  }

  // A Symbol packs a 9-bit symbol (0..256, EOS included) and its code length
  // into a Short: low 9 bits = symbol, high bits = length.
  object Symbol {
    @inline def apply(symbol: Short, length: Byte): Short = (symbol | (length << 9)).toShort
    def unapply(s: Short): Option[(Short, Byte)] =
      Some((s & 511).toShort, (s >>> 9).toByte)
  }

  // The HPACK Huffman code, RFC 7541 Appendix B, as (bitLength, symbol) pairs
  // in canonical (ascending length) order.
  final lazy val coding = new HuffmanCoding(
    (5, 48), (5, 49), (5, 50), (5, 97), (5, 99), (5, 101), (5, 105), (5, 111), (5, 115), (5, 116),
    (6, 32), (6, 37), (6, 45), (6, 46), (6, 47), (6, 51), (6, 52), (6, 53), (6, 54), (6, 55),
    (6, 56), (6, 57), (6, 61), (6, 65), (6, 95), (6, 98), (6, 100), (6, 102), (6, 103), (6, 104),
    (6, 108), (6, 109), (6, 110), (6, 112), (6, 114), (6, 117),
    (7, 58), (7, 66), (7, 67), (7, 68), (7, 69), (7, 70), (7, 71), (7, 72), (7, 73), (7, 74),
    (7, 75), (7, 76), (7, 77), (7, 78), (7, 79), (7, 80), (7, 81), (7, 82), (7, 83), (7, 84),
    (7, 85), (7, 86), (7, 87), (7, 89), (7, 106), (7, 107), (7, 113), (7, 118), (7, 119), (7, 120),
    (7, 121), (7, 122),
    (8, 38), (8, 42), (8, 44), (8, 59), (8, 88), (8, 90),
    (10, 33), (10, 34), (10, 40), (10, 41), (10, 63),
    (11, 39), (11, 43), (11, 124),
    (12, 35), (12, 62),
    (13, 0), (13, 36), (13, 64), (13, 91), (13, 93), (13, 126),
    (14, 94), (14, 125),
    (15, 60), (15, 96), (15, 123),
    (19, 92), (19, 195), (19, 208),
    (20, 128), (20, 130), (20, 131), (20, 162), (20, 184), (20, 194), (20, 224), (20, 226),
    (21, 153), (21, 161), (21, 167), (21, 172), (21, 176), (21, 177), (21, 179), (21, 209),
    (21, 216), (21, 217), (21, 227), (21, 229), (21, 230),
    (22, 129), (22, 132), (22, 133), (22, 134), (22, 136), (22, 146), (22, 154), (22, 156),
    (22, 160), (22, 163), (22, 164), (22, 169), (22, 170), (22, 173), (22, 178), (22, 181),
    (22, 185), (22, 186), (22, 187), (22, 189), (22, 190), (22, 196), (22, 198), (22, 228),
    (22, 232), (22, 233),
    (23, 1), (23, 135), (23, 137), (23, 138), (23, 139), (23, 140), (23, 141), (23, 143),
    (23, 147), (23, 149), (23, 150), (23, 151), (23, 152), (23, 155), (23, 157), (23, 158),
    (23, 165), (23, 166), (23, 168), (23, 174), (23, 175), (23, 180), (23, 182), (23, 183),
    (23, 188), (23, 191), (23, 197), (23, 231), (23, 239),
    (24, 9), (24, 142), (24, 144), (24, 145), (24, 148), (24, 159), (24, 171), (24, 206),
    (24, 215), (24, 225), (24, 236), (24, 237),
    (25, 199), (25, 207), (25, 234), (25, 235),
    (26, 192), (26, 193), (26, 200), (26, 201), (26, 202), (26, 205), (26, 210), (26, 213),
    (26, 218), (26, 219), (26, 238), (26, 240), (26, 242), (26, 243), (26, 255),
    (27, 203), (27, 204), (27, 211), (27, 212), (27, 214), (27, 221), (27, 222), (27, 223),
    (27, 241), (27, 244), (27, 245), (27, 246), (27, 247), (27, 248), (27, 250), (27, 251),
    (27, 252), (27, 253), (27, 254),
    (28, 2), (28, 3), (28, 4), (28, 5), (28, 6), (28, 7), (28, 8), (28, 11), (28, 12), (28, 14),
    (28, 15), (28, 16), (28, 17), (28, 18), (28, 19), (28, 20), (28, 21), (28, 23), (28, 24),
    (28, 25), (28, 26), (28, 27), (28, 28), (28, 29), (28, 30), (28, 31), (28, 127), (28, 220),
    (28, 249),
    (30, 10), (30, 13), (30, 22), (30, 256)
  )
}

--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/hpack/coders/IntCoder.scala:
package net.danielkza.http2.hpack.coders

import java.util.NoSuchElementException
import scalaz.\/
import scalaz.syntax.either._
import akka.util.{ByteString, ByteStringBuilder}
import net.danielkza.http2.Coder
import net.danielkza.http2.hpack.HeaderError

/** HPACK primitive-integer coder (RFC 7541, section 5.1).
  *
  * A value is stored in an `elementSize`-bit prefix of the first byte; values that
  * do not fit saturate the prefix and continue in 7-bit groups, least-significant
  * first, with the high bit of each byte marking continuation.
  *
  * @param elementSize number of value bits in the first byte (1 to 8)
  * @param prefix      fixed bit pattern placed above the value bits of the first byte
  */
class IntCoder(elementSize: Int, prefix: Int = 0) extends Coder[Int] {
  override final type Error = HeaderError

  val elementMax = (1 << elementSize) - 1
  val elementMask = elementMax

  val prefixSize = 8 - elementSize
  val prefixMax = (1 << prefixSize) - 1
  val alignedPrefix = (prefix & prefixMax) << elementSize

  /** Encodes `value` (assumed non-negative) into `stream`. Never fails. */
  override def encode(value: Int, stream: ByteStringBuilder): \/[HeaderError, Unit] = {
    if(value < elementMax) {
      // Fits entirely inside the prefix byte.
      stream.putByte(((value & elementMax) | alignedPrefix).toByte)
    } else {
      // Saturate the prefix byte, then emit the remainder in 7-bit groups.
      stream.putByte((elementMax | alignedPrefix).toByte)

      var tmpValue = value - elementMax
      while(tmpValue >= 128) {
        stream.putByte(((tmpValue & 127) | 128).toByte)
        tmpValue >>>= 7
      }

      stream.putByte(tmpValue.toByte)
    }

    ().right
  }

  /** Decodes one integer, returning (value, bytesConsumed).
    *
    * FIX: the previous version accumulated continuation groups into an `Int` with
    * no bounds check, so a malicious run of continuation bytes could silently
    * overflow to a wrong (possibly negative) value. RFC 7541 section 5.1 requires
    * excessively large values to be treated as a decoding error, so we now
    * accumulate into a `Long` and fail with `ParseError` on overflow.
    */
  private def decode(stream: BufferedIterator[Byte]): \/[HeaderError, (Int, Int)] = try {
    val firstElm = stream.next() & elementMask
    if(firstElm < elementMax)
      (firstElm, 1).right
    else {
      var bitPos: Int = 0
      var value: Long = elementMax
      var byte: Byte = 0
      var overflow = false
      do {
        byte = stream.next()
        value += (byte & 127).toLong << bitPos
        bitPos += 7
        // More than 5 continuation groups, or a value past Int.MaxValue, can
        // never be a valid Int: reject instead of wrapping around.
        overflow = bitPos > 35 || value > Int.MaxValue
      } while(!overflow && (byte & 128) == 128)

      if(overflow) HeaderError.ParseError.left
      else (value.toInt, bitPos / 7 + 1).right
    }
  } catch { case _: NoSuchElementException =>
    // Input ended in the middle of an integer.
    HeaderError.ParseError.left
  }

  override def decode(bs: ByteString): \/[HeaderError, (Int, Int)] =
    decode(bs.iterator)
}
package net.danielkza.http2.hpack.coders

import scalaz.\/
import akka.util.{ByteStringBuilder, ByteString}
import net.danielkza.http2.util._
import net.danielkza.http2.Coder
import net.danielkza.http2.hpack.HeaderError

/** HPACK string-literal coder that chooses between the plain and the
  * Huffman-compressed representation: `compressionPredicate` decides per value
  * whether compression should be applied on encode, while decode inspects the
  * H flag (the most significant bit of the first length byte).
  */
class LiteralCoder(val compressionPredicate: ByteString => Boolean) extends Coder[ByteString] {
  override final type Error = HeaderError

  private val huffmanCoder = new CompressedBytesCoder
  private val plainCoder = new BytesCoder

  /** Encodes `value` as a standalone ByteString, compressing when the predicate approves. */
  override def encode(value: ByteString): \/[HeaderError, ByteString] =
    if(compressionPredicate(value)) huffmanCoder.encode(value)
    else plainCoder.encode(value)

  /** Streaming variant of encode; appends the literal to `stream`. */
  override def encode(value: ByteString, stream: ByteStringBuilder): \/[HeaderError, Unit] =
    if(compressionPredicate(value)) huffmanCoder.encode(value, stream)
    else plainCoder.encode(value, stream)

  /** Decodes one literal; a set MSB on the first byte selects the Huffman decoder. */
  override def decode(bs: ByteString): \/[HeaderError, (ByteString, Int)] =
    if((bs.head & 0x80) != 0) huffmanCoder.decode(bs)
    else plainCoder.decode(bs)
}

// ---- next file: net/danielkza/http2/model/AkkaMessageAdapter.scala ----
package net.danielkza.http2.model

import net.danielkza.http2.model.headers.Trailer

import scala.collection.immutable
import scalaz._
import scalaz.syntax.either._
import scalaz.syntax.std.option._
import akka.util.ByteString
import akka.http.ParserSettings
import akka.http.scaladsl.{model => akkaModel}
import net.danielkza.http2.api.Header
import 
net.danielkza.http2.protocol.HTTP2Error

// Translates between this library's HTTP/2 header representation (pseudo-headers
// included) and Akka HTTP's model types, in both directions.
class AkkaMessageAdapter(parserSettings: ParserSettings) {
  import Header._
  import Header.Constants._
  import akkaModel.{headers => h, _}
  import HTTP2Error._

  // Wraps an Akka ErrorInfo into this library's HeaderError.
  private def headerErr(info: ErrorInfo): HeaderError =
    HeaderError(errorInfo = Some(info))

  private def headerErr(summary: String): HeaderError =
    headerErr(ErrorInfo(summary))

  // Builds the :scheme/:path (and optionally :authority) pseudo-headers from an
  // Akka Uri. Fails if the URI carries userinfo, which is not allowed in :authority.
  def headersFromAkkaUri(akkaUri: Uri): \/[HTTP2Error, immutable.Seq[Header]] = {
    val scheme = akkaUri.scheme
    val authority = if(!akkaUri.authority.isEmpty) {
      if(!akkaUri.authority.userinfo.isEmpty)
        return headerErr("Userinfo not allowed in :authority").left
      else
        ByteString(akkaUri.authority.toString)
    } else {
      ByteString.empty
    }

    // :path carries the origin-form request target (path + query).
    val target = ByteString(akkaUri.toHttpRequestTargetOriginForm.toString)
    var headers = immutable.Seq(RawHeader(SCHEME, ByteString(scheme)), RawHeader(PATH, target))
    if(authority.nonEmpty) headers = headers :+ RawHeader(AUTHORITY, authority)

    headers.right
  }

  // Thin wrapper: every Akka header becomes a WrappedAkkaHeader.
  def headersFromAkka(akkaHeaders: Seq[HttpHeader]): immutable.Seq[Header] =
    akkaHeaders.map(WrappedAkkaHeader(_)).toList

  // Converts a full Akka request or response into this library's header list,
  // synthesizing the required pseudo-headers and basic entity headers.
  def headersFromAkkaMessage(message: HttpMessage): \/[HTTP2Error, immutable.Seq[Header]] = {
    message match {
      case req: HttpRequest =>
        for {
          uriHeaders <- headersFromAkkaUri(req.uri)
          method = WrappedAkkaMethod(req.method)
          otherHeaders = req.headers.map(WrappedAkkaHeader(_))
        } yield (uriHeaders :+ method) ++ otherHeaders
      case resp: HttpResponse =>
        var headers = immutable.Seq.newBuilder[Header]

        headers += WrappedAkkaStatusCode(resp.status)
        headers += Header.plain("Content-Type", resp.entity.contentType.value)

        // Entity-kind restrictions: HTTP/2 has no close-delimited framing, and a
        // Trailer header only makes sense with a chunked entity.
        resp.entity match {
          case _: HttpEntity.CloseDelimited =>
            return InternalError("CloseDelimited response is not supported in HTTP/2").left
          case _: HttpEntity.Chunked =>
          case _ if resp.headers.exists(_.isInstanceOf[Trailer]) =>
            return InternalError("Trailer header only supported with Chunked response entity").left
          case _ =>
        }

        resp.entity.contentLengthOption.foreach { len =>
          headers += Header.plain("Content-Length", len.toString)
        }

        // Add a Date header unless the caller already supplied one.
        if(!resp.headers.exists(_.isInstanceOf[h.Date]))
          headers += h.Date(DateTime.now)

        headers ++= resp.headers.map(h => h: WrappedAkkaHeader)
        headers.result().right
    }
  }

  private def decodeBytes(bs: ByteString): String =
    bs.decodeString("UTF-8")

  // Parses an authority (":authority" pseudo-header or "Host" header value) by
  // delegating to Akka's Host header parser; `name` is only used in error text.
  def parseHostAuthority(name: String, value: String) = {
    try {
      HttpHeader.parse("Host", value) match {
        case HttpHeader.ParsingResult.Ok(akkaModel.headers.Host(uriHost, port), _) =>
          Uri.Authority(uriHost, port).right
        case HttpHeader.ParsingResult.Error(error) =>
          throw IllegalUriException(error)
        case _ =>
          throw IllegalUriException("Bad host value")

      }
    } catch { case e: IllegalUriException =>
      headerErr(e.info.withSummaryPrepended(s"Invalid $name")).left
    }
  }

  // Parses one raw header into an Akka HttpHeader, mapping parse failures to HeaderError.
  private def parseSingleHeader(header: Header): \/[HTTP2Error, HttpHeader] = {
    HttpHeader.parse(decodeBytes(header.name), decodeBytes(header.value)) match {
      case HttpHeader.ParsingResult.Error(error) =>
        headerErr(error).left
      case HttpHeader.ParsingResult.Ok(akkaHeader, _) =>
        akkaHeader.right
    }
  }

  // Parses a list of raw headers, failing fast on the first invalid one.
  def headersToAkka(headers: Seq[Header]): \/[HTTP2Error, immutable.Seq[HttpHeader]] = {
    // It could be nice to use Scalaz's traverse here, but we want to take any kind of Seq and return an immutable.Seq,
    // but it can only return the same collection that it takes
    val akkaHeaders = immutable.Seq.newBuilder[HttpHeader]

    headers.foreach { header =>
      parseSingleHeader(header) match {
        case \/-(akkaHeader) => akkaHeaders += akkaHeader
        case e @ -\/(_) => return e
      }
    }

    akkaHeaders.result().right
  }

  // Reassembles an Akka HttpRequest from HTTP/2 headers: collects the pseudo-headers
  // (rejecting duplicates and response-only ones), then validates the mandatory set.
  def headersToAkkaRequest(headers: Seq[Header]): \/[HTTP2Error, HttpRequest] = {
    var scheme: Option[String] = None
    var authority: Option[String] = None
    var path: Option[String] = None
    var host: Option[String] = None
    var method: Option[HttpMethod] = None
    val akkaHeaders = immutable.Seq.newBuilder[HttpHeader]

    headers.foreach { header =>
      header.name match {
        case STATUS =>
          return headerErr("Status not allowed in request").left
        case SCHEME if scheme.isDefined =>
          return headerErr("Scheme redefined").left
        case SCHEME =>
          scheme = decodeBytes(header.value).some
        case METHOD if method.isDefined =>
          return headerErr("Method redefined").left
        case METHOD =>
          // Fall back to a custom method for verbs Akka doesn't know.
          val stringValue = decodeBytes(header.value)
          method = HttpMethods.getForKey(stringValue).orElse {
            HttpMethod.custom(stringValue).some
          }
        case PATH if path.isDefined =>
          return headerErr("Path redefined").left
        case PATH =>
          path = decodeBytes(header.value).some
        case AUTHORITY if authority.isDefined =>
          return headerErr("Authority redefined").left
        case AUTHORITY =>
          authority = decodeBytes(header.value).some
        case HOST if host.isDefined =>
          return headerErr("Host redefined").left
        case HOST =>
          host = decodeBytes(header.value).some
        case _ =>
          parseSingleHeader(header) match {
            case \/-(akkaHeader) => akkaHeaders += akkaHeader
            case err @ -\/(_) => return err
          }
      }
    }

    // :scheme, :path and :method are mandatory; :authority and Host are mutually exclusive.
    for {
      schemeVal <- scheme map(_.right) getOrElse
        headerErr("Scheme must not be empty in request").left
      pathVal <- path map(_.right) getOrElse
        headerErr("Path must not be empty in request").left
      methodVal <- method map(_.right) getOrElse
        headerErr("Method must not be empty in request").left
      authorityVal <- (authority, host) match {
        case (Some(_), Some(_)) =>
          headerErr("Cannot have :authority and Host headers simultaneously").left
        case (Some(auth), _) =>
          parseHostAuthority(":authority", auth)
        case (_, Some(hostHeader)) =>
          parseHostAuthority("Host", hostHeader)
        case _ =>
          Uri.Authority.Empty.right
      }
      uri <- try {
        val pathUri = Uri.parseHttpRequestTarget(pathVal, mode = parserSettings.uriParsingMode)
        pathUri.withScheme(schemeVal).withAuthority(authorityVal).right
      } catch { case IllegalUriException(error) =>
        headerErr(error.withSummaryPrepended("Invalid request target")).left
      }
    } yield HttpRequest(methodVal, uri, akkaHeaders.result())
  }

  // Parses a numeric :status value, consulting the configured custom status codes.
  def parseStatus(value: ByteString): \/[HTTP2Error, StatusCode] = {
    val status = try {
      val code = decodeBytes(value).toInt
      StatusCodes.getForKey(code).orElse(parserSettings.customStatusCodes(code))
    } catch { case _: NumberFormatException =>
      None
    }

    status.map(_.right).getOrElse(-\/(headerErr("Invalid status code")))
  }

  // Reassembles an Akka HttpResponse: :status is required, request pseudo-headers
  // and Host are rejected.
  def headersToAkkaResponse(headers: Seq[Header]): \/[HTTP2Error, HttpResponse] = {
    var status: Option[StatusCode] = None
    val akkaHeaders = immutable.Seq.newBuilder[HttpHeader]

    headers.foreach { header =>
      header.name match {
        case STATUS if status.isDefined =>
          return headerErr("Status redefined").left
        case STATUS =>
          status = parseStatus(header.value) match {
            case -\/(error) => return error.left
            case \/-(statusCode) => statusCode.some
          }
        case SCHEME =>
          return headerErr(":status not allowed in response").left
        case METHOD =>
          return headerErr(":method not allowed in response").left
        case PATH =>
          return headerErr(":path not allowed in response").left
        case AUTHORITY =>
          return headerErr(":authority not allowed in response").left
        case HOST =>
          return headerErr("Host not allowed in response").left
        case _ =>
          HttpHeader.parse(decodeBytes(header.name), decodeBytes(header.value)) match {
            case HttpHeader.ParsingResult.Ok(akkaHeader, _) => akkaHeaders += akkaHeader
            case HttpHeader.ParsingResult.Error(error) => return headerErr(error).left
          }
      }
    }

    status map { status =>
      HttpResponse(status, akkaHeaders.result()).right
    } getOrElse {
      headerErr("Status must not be empty in response").left
    }
  }
}

// ---- next file: net/danielkza/http2/model/Http2Response.scala ----
package net.danielkza.http2.model

import scala.language.implicitConversions
import akka.stream.scaladsl.Source
import akka.http.scaladsl.model.{HttpRequest, HttpResponse, StatusCodes}

// A response produced by a handler, optionally carrying server-push promises.
sealed trait Http2Response {
  def response: HttpResponse
}

object Http2Response {
  // Plain response with no pushed resources.
  case class Simple(response: HttpResponse) extends Http2Response
  // Response plus a source of (promised request, response) pairs for server push.
  case class Promised(response: HttpResponse, promises: Source[(HttpRequest, HttpResponse), Any]) extends Http2Response
  object Promised {
    def apply(response: HttpResponse, promises: Iterable[(HttpRequest, HttpResponse)]): Promised =
      Promised(response, Source(() => promises.iterator))
  }


  // Sentinel used when the handler produced nothing; surfaces as a 500.
  case object NoResponse extends Http2Response {
    override val response = HttpResponse(StatusCodes.InternalServerError)
  }
}

// ---- next file: net/danielkza/http2/model/headers/Trailer.scala ----
package net.danielkza.http2.model.headers

import akka.http.scaladsl.model.headers.CustomHeader

import 
scala.collection.immutable

// Custom "Trailer" header listing the field names sent in the trailing header block.
final case class Trailer(fields: immutable.Seq[String]) extends CustomHeader {
  override def name(): String = "Trailer"

  // Comma-separated list of trailing field names.
  override def value(): String = fields.mkString(", ")
}

// ---- next file: net/danielkza/http2/protocol/Frame.scala ----
package net.danielkza.http2.protocol

import akka.util.ByteString

// One HTTP/2 frame (RFC 7540, section 4). `withFlags` re-derives the boolean
// convenience fields from a raw flags byte.
sealed trait Frame {
  def tpe: Byte
  def flags: Byte
  def stream: Int
  def withFlags(flags: Byte): Frame
}

// Frames that carry a (possibly partial) header block fragment.
sealed trait HeaderFrame extends Frame {
  def headerFragment: ByteString
}

object Frame {
  final val HEADER_LENGTH = 9
  final val DEFAULT_MAX_FRAME_SIZE = 16384

  // Frame type codes from RFC 7540, section 6.
  object Types {
    final val DATA: Byte = 0x0
    final val HEADERS: Byte = 0x1
    final val PRIORITY: Byte = 0x2
    final val RST_STREAM: Byte = 0x3
    final val SETTINGS: Byte = 0x4
    final val PUSH_PROMISE: Byte = 0x5
    final val PING: Byte = 0x6
    final val GOAWAY: Byte = 0x7
    final val WINDOW_UPDATE: Byte = 0x8
    final val CONTINUATION: Byte = 0x9
  }

  // Per-type flag bit masks.
  object Flags {
    object DATA {
      final val END_STREAM: Byte = 0x1
      final val PADDED: Byte = 0x8
    }

    object HEADERS {
      final val END_STREAM: Byte = 0x1
      final val END_HEADERS: Byte = 0x4
      final val PADDED: Byte = 0x8
      final val PRIORITY: Byte = 0x20
    }

    object PUSH_PROMISE {
      final val END_HEADERS: Byte = 0x4
      final val PADDED: Byte = 0x8
    }

    object PING {
      final val ACK: Byte = 0x1
    }

    object SETTINGS {
      final val ACK: Byte = 0x1
    }

    object CONTINUATION {
      final val END_HEADERS: Byte = 0x4
    }
  }

  // Extension frame with an unrecognized type; payload is kept opaque.
  case class NonStandard(
    override val stream: Int,
    override val tpe: Byte,
    override val flags: Byte,
    payload: ByteString
  ) extends Frame {
    override def withFlags(flags: Byte): NonStandard = copy(flags = flags)
  }

  sealed abstract class Standard(override val tpe: Byte) extends Frame {
    def flags: Byte = 0
  }

  // PRIORITY information: exclusive bit, parent stream and weight.
  case class StreamDependency(exclusive: Boolean, stream: Int, weight: Int)

  case class Data(
    override val stream: Int,
    data: ByteString,
    endStream: Boolean = false,
    padding: Option[ByteString] = None
  ) extends Standard(Types.DATA) {
    // Flags are derived from the convenience fields, never stored.
    override def flags: Byte = {
      var flags = if(endStream) Flags.DATA.END_STREAM else 0
      padding.foreach { _ => flags |= Flags.DATA.PADDED }
      flags.toByte
    }

    override def withFlags(flags: Byte): Data =
      copy(endStream = (flags & Flags.DATA.END_STREAM) != 0)
  }

  case class Headers(
    override val stream: Int,
    streamDependency: Option[StreamDependency],
    headerFragment: ByteString,
    endStream: Boolean = false,
    endHeaders: Boolean = true,
    padding: Option[ByteString] = None
  ) extends Standard(Types.HEADERS) with HeaderFrame {
    override def flags: Byte = {
      var flags = if(streamDependency.isDefined) Flags.HEADERS.PRIORITY else 0
      padding.foreach { _ => flags |= Flags.HEADERS.PADDED }
      if(endStream) flags |= Flags.HEADERS.END_STREAM
      if(endHeaders) flags |= Flags.HEADERS.END_HEADERS

      flags.toByte
    }

    override def withFlags(flags: Byte): Headers =
      copy(endStream = (flags & Flags.HEADERS.END_STREAM) != 0,
           endHeaders = (flags & Flags.HEADERS.END_HEADERS) != 0)
  }

  case class Priority(
    override val stream: Int,
    streamDependency: StreamDependency
  ) extends Standard(Types.PRIORITY) {
    // PRIORITY defines no flags.
    override def withFlags(flags: Byte): Priority = this
  }

  case class ResetStream(
    override val stream: Int,
    errorCode: Int
  ) extends Standard(Types.RST_STREAM) {
    override def withFlags(flags: Byte): ResetStream = this
  }
  case class PushPromise(
    override val stream: Int,
    promisedStream: Int,
    headerFragment: ByteString,
    endHeaders: Boolean = true,
    padding: Option[ByteString] = None
  ) extends Standard(Types.PUSH_PROMISE) with HeaderFrame {
    override def flags: Byte = {
      var flags = if(endHeaders) Flags.PUSH_PROMISE.END_HEADERS else 0
      padding.foreach { _ => flags |= Flags.PUSH_PROMISE.PADDED }
      flags.toByte
    }

    override def withFlags(flags: Byte): PushPromise =
      copy(endHeaders = (flags & Flags.PUSH_PROMISE.END_HEADERS) != 0)
  }

  // Connection-level frames (PING, SETTINGS, GOAWAY) are pinned to stream 0.
  case class Ping(
    data: ByteString,
    ack: Boolean = false
  ) extends Standard(Types.PING) {
    override def stream: Int = 0

    override def flags: Byte =
      if(ack) Flags.PING.ACK else 0

    override def withFlags(flags: Byte): Ping =
      copy(ack = (flags & Flags.PING.ACK) != 0)
  }

  case class Settings(
    settings: List[Setting],
    ack: Boolean = false
  ) extends Standard(Types.SETTINGS) {
    override def stream: Int = 0

    override def flags: Byte =
      if(ack) Flags.SETTINGS.ACK else 0

    override def withFlags(flags: Byte): Settings =
      copy(ack = (flags & Flags.SETTINGS.ACK) != 0)
  }

  object Settings {
    def apply(settings: Traversable[(Short, Int)]): Settings =
      Settings(settings.map { case (id, value) => Setting(id, value) }.toList)

    // A pre-built, empty SETTINGS acknowledgement.
    val ack: Settings = Settings(List(), ack = true)
  }

  case class GoAway(
    lastStream: Int,
    errorCode: Int,
    debugData: ByteString
  ) extends Standard(Types.GOAWAY) {
    override def stream: Int = 0

    override def withFlags(flags: Byte): GoAway = this
  }
  object GoAway {
    def apply(lastStream: Int, error: HTTP2Error = HTTP2Error.NoError()): GoAway = {
      GoAway(lastStream, error.code, error.debugData.getOrElse(ByteString.empty))
    }
  }

  case class WindowUpdate(
    override val stream: Int,
    windowIncrement: Int
  ) extends Standard(Types.WINDOW_UPDATE) {
    override def withFlags(flags: Byte): WindowUpdate = this
  }

  case class Continuation(
    override val stream: Int,
    headerFragment: ByteString,
    endHeaders: Boolean = true
  ) extends Standard(Types.CONTINUATION) {
    override def withFlags(flags: Byte): Continuation = {
      copy(endHeaders = (flags & Flags.CONTINUATION.END_HEADERS) != 0)
    }
  }

}

// ---- next file: net/danielkza/http2/protocol/HTTP2Error.scala ----
package net.danielkza.http2.protocol

import scala.language.implicitConversions
import akka.util.ByteString
import akka.http.scaladsl.model.ErrorInfo

// A protocol-level HTTP/2 error: a wire error code plus optional human-readable
// info and optional raw debug data (sent in GOAWAY frames).
trait HTTP2Error {
  def code: Int
  def debugData: Option[ByteString]
  def errorInfo: Option[ErrorInfo]

  final def toException(message: String = null, cause: Throwable = null): HTTP2Exception = {
    val formattedMessage = (errorInfo, Option(message)) match {
      case (Some(error), Some(msg)) => s"$msg: ${error.summary}"
      case (Some(error), None) => error.summary
      case (None, Some(msg)) => msg
      case (None, None) => "Unknown error"
    }

    // throw/catch immediately so the exception captures a stack trace here
    // without actually propagating.
    try {
      throw new HTTP2Exception(this)(getClass.getSimpleName + ": " + formattedMessage, cause)
    } catch { case e: HTTP2Exception => e }
  }

  final def toException: HTTP2Exception = toException()
}

case class HTTP2Exception(error: HTTP2Error)(message: String = null, cause: Throwable = null)
  extends Exception(message, cause)

object HTTP2Error {
  object Codes 
{
    // Error codes from RFC 7540, section 7.
    final val NO_ERROR = 0x0
    final val PROTOCOL_ERROR = 0x1
    final val INTERNAL_ERROR = 0x2
    final val FLOW_CONTROL_ERROR = 0x3
    final val SETTINGS_TIMEOUT = 0x4
    final val STREAM_CLOSED = 0x5
    final val FRAME_SIZE_ERROR = 0x6
    final val REFUSED_STREAM = 0x7
    final val CANCEL = 0x8
    final val COMPRESSION_ERROR = 0x9
    final val CONNECT_ERROR = 0xa
    final val ENHANCE_YOUR_CALM = 0xb
    final val INADEQUATE_SECURITY = 0xc
    final val HTTP_1_1_REQUIRED = 0xd
  }
  import Codes._

  implicit def toException(error: HTTP2Error): HTTP2Exception = error.toException

  // Base for the concrete error case classes below; `Self` lets the `with*`
  // builders return the concrete type.
  sealed abstract class Standard(override val code: Int) extends HTTP2Error {
    type Self <: HTTP2Error
    def withDebugData(debugData: Option[ByteString]): Self
    def withErrorInfo(errorInfo: Option[ErrorInfo]): Self
    def withErrorMessage(message: String): Self = withErrorInfo(Some(ErrorInfo(message)))
  }

  // Catch-all for codes without a dedicated case class (see Standard.fromCode).
  private case class GenericStandard(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    (errorCode: Int) extends Standard(errorCode) {
    override type Self = GenericStandard

    override def withDebugData(debugData: Option[ByteString]): Self = copy(debugData = debugData)(errorCode)
    override def withErrorInfo(errorInfo: Option[ErrorInfo]): Self = copy(errorInfo = errorInfo)(errorCode)
  }

  object Standard {
    def unapply(e: HTTP2Error): Option[(Int, Option[ErrorInfo], Option[ByteString])] = {
      e match {
        case s: Standard => Some(s.code, s.errorInfo, s.debugData)
        case _ => None
      }
    }

    // Maps a wire error code back to a typed error instance.
    def fromCode(errorCode: Int): HTTP2Error = {
      errorCode match {
        case NO_ERROR => NoError()
        case PROTOCOL_ERROR => ProtocolError()
        case INTERNAL_ERROR => InternalError()
        case FLOW_CONTROL_ERROR => FlowControlError()
        case STREAM_CLOSED => StreamClosedError()
        case FRAME_SIZE_ERROR => InvalidFrameSize()
        case REFUSED_STREAM => RefusedStream()
        case COMPRESSION_ERROR => CompressionError()
        case SETTINGS_TIMEOUT => SettingsTimeout()
        case e => GenericStandard()(e)
      }
    }
  }

  implicit def stringToErrorInfo(message: String): Option[ErrorInfo]
    = Some(ErrorInfo(message))

  // The case classes below all follow the same shape: a distinct type per error
  // condition, mapped onto one of the standard wire codes.
  case class NoError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(NO_ERROR)
  {
    final type Self = NoError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class ProtocolError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = ProtocolError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class InvalidStream(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = InvalidStream
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class InvalidFrameSize(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(FRAME_SIZE_ERROR)
  {
    final type Self = InvalidFrameSize
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class InvalidWindowUpdate(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = InvalidWindowUpdate
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class InvalidPadding(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = InvalidPadding
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class ContinuationError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = ContinuationError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class CompressionError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(COMPRESSION_ERROR)
  {
    final type Self = CompressionError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class HeaderError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = HeaderError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class SettingsError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = SettingsError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class SettingsTimeout(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(SETTINGS_TIMEOUT)
  {
    final type Self = SettingsTimeout
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class StreamClosedError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(STREAM_CLOSED)
  {
    final type Self = StreamClosedError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class UnacceptableFrameError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(PROTOCOL_ERROR)
  {
    final type Self = UnacceptableFrameError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class ExhaustedStreams(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(NO_ERROR)
  {
    final type Self = ExhaustedStreams
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class RefusedStream(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(REFUSED_STREAM)
  {
    final type Self = RefusedStream
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class 
InternalError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(INTERNAL_ERROR)
  {
    final type Self = InternalError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  case class FlowControlError(errorInfo: Option[ErrorInfo] = None, debugData: Option[ByteString] = None)
    extends Standard(FLOW_CONTROL_ERROR)
  {
    final type Self = FlowControlError
    def withDebugData(debugData: Option[ByteString]) = copy(debugData = debugData)
    def withErrorInfo(errorInfo: Option[ErrorInfo]) = copy(errorInfo = errorInfo)
  }

  // Matches only errors that are NOT instances of Standard.
  object NonStandard {
    def unapply(e: HTTP2Error): Option[(Int, Option[ErrorInfo], Option[ByteString])] = {
      Standard.unapply(e) match {
        case Some(_) => None
        case None => Some(e.code, e.errorInfo, e.debugData)
      }
    }
  }
}

// ---- next file: net/danielkza/http2/protocol/Http2Stream.scala ----
package net.danielkza.http2.protocol

import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}
import scala.util.Success
import net.danielkza.http2.util.ArrayQueue
import net.danielkza.http2.protocol.HTTP2Error._
import scalaz.\/
import scalaz.syntax.either._

/** State machine and flow-control bookkeeping for a single HTTP/2 stream.
  * Concrete subclasses are [[ControlStream]] (stream 0) and [[DataStream]].
  */
abstract class Http2Stream(initialState: Http2Stream.State) {
  import Http2Stream._
  import Frame._

  protected var _state = initialState
  // Both windows start at the RFC 7540 default of 65535 bytes.
  protected var _inFlowWindow = 65535
  protected var _outFlowWindow = 65535
  protected var _inWeight = 16
  protected var _outWeight = 16
  protected var _closeError: Option[HTTP2Error] = None

  def id: Int

  final def state: State = _state
  protected[http2] def state_=(s: State): Unit = _state = s

  final def inWeight: Int = _inWeight
  protected[http2] def inWeight_=(w: Int): Unit = _inWeight = w
  final def outWeight: Int = _outWeight
  protected[http2] def outWeight_=(w: Int): Unit = _outWeight = w

  final def inFlowWindow: Int = _inFlowWindow
  protected[http2] final def inFlowWindow_=(w: Int): Unit = _inFlowWindow = w
  protected[http2] def incrementInFlowWindow(inc: Int): Unit = _inFlowWindow += inc
  protected[http2] def decrementInFlowWindow(dec: Int): Unit = _inFlowWindow -= dec

  final def outFlowWindow: Int = _outFlowWindow
  protected[http2] final def outFlowWindow_=(w: Int) = _outFlowWindow = w
  protected[http2] def incrementOutFlowWindow(inc: Int): Unit = _outFlowWindow += inc
  protected[http2] def decrementOutFlowWindow(dec: Int): Unit = _outFlowWindow -= dec

  protected[http2] def inWindowAvailable(size: Int): Boolean
  protected[http2] def outWindowAvailable(size: Int): Boolean

  final def isClosed: Boolean = state.isClosed

  /** The error the stream was closed with, if it is closed.
    *
    * FIX: this previously read `closeError.filter(...)` — the method calling
    * itself — which is unconditional infinite recursion and throws
    * StackOverflowError on first use. It must read the backing field.
    */
  def closeError: Option[HTTP2Error] = _closeError.filter(_ => this.isClosed)

  // Validates a proposed transition; None means the frame was not acceptable
  // in the current state.
  private def updateState(newState: Option[State]) = {
    (state, newState) match {
      case (_, Some(s)) =>
        s.right
      case (Closed | HalfClosedByLocal | HalfClosedByRemote, _) =>
        StreamClosedError("Frame received on closed stream").left
      case _ =>
        UnacceptableFrameError().left
    }
  }

  /** Applies an incoming frame to the state machine, returning the new state
    * or the protocol error it provoked.
    */
  def receive(frame: Frame): \/[HTTP2Error, State] = {
    import HTTP2Error.Codes._

    state.onReceive.lift(frame).map { newState =>
      (newState, frame) match {
        case (Closed, ResetStream(_, errorCode)) =>
          close(HTTP2Error.Standard.fromCode(errorCode))
          if(errorCode == CANCEL || errorCode == REFUSED_STREAM)
            Closed.right // Don't fail the current call for a simple stream reset
          else
            // NOTE(review): relies on `close` having left the stream in a closed
            // state so `closeError` is non-empty here — verify `_state` is in
            // fact updated on the close path.
            closeError.get.left
71 | case (_, d: Data) if !inWindowAvailable(d.data.length) => 72 | FlowControlError().left 73 | case (_, _) => 74 | updateState(Some(newState)) 75 | } 76 | }.getOrElse(updateState(None)) 77 | } 78 | 79 | def send(frame: Frame): \/[HTTP2Error, State] = { 80 | val newState = state.onSend.lift(frame) 81 | updateState(newState) 82 | } 83 | 84 | def close(error: HTTP2Error): Closed.type = { 85 | if(isClosed) throw new IllegalStateException("Stream already closed") 86 | _closeError = Some(error) 87 | updateState(Some(Closed)) 88 | Closed 89 | } 90 | } 91 | 92 | class ControlStream private[http2] extends Http2Stream(Http2Stream.Control) { 93 | override val id = 0 94 | 95 | val streams = new mutable.LongMap[DataStream] 96 | val delayedStreamsQueue = mutable.Queue.empty[DataStream] 97 | protected var _initialInFlowWindow = 65535 98 | protected var _initialOutFlowWindow = 65535 99 | 100 | protected[http2] def inWindowAvailable(size: Int): Boolean = 101 | inFlowWindow >= size 102 | 103 | protected[http2] def outWindowAvailable(size: Int): Boolean = 104 | outFlowWindow >= size 105 | 106 | protected[http2] override def incrementOutFlowWindow(inc: Int): Unit = { 107 | super.incrementOutFlowWindow(inc) 108 | runOutQueue() 109 | } 110 | 111 | protected[http2] def runOutQueue(): Unit = { 112 | while(outFlowWindow > 0) { 113 | val nextStream = delayedStreamsQueue.dequeueFirst(_.possibleProgressForWindow(outFlowWindow) > 0) 114 | nextStream match { 115 | case Some(s) => s.runOutQueue() 116 | case None => return 117 | } 118 | } 119 | } 120 | 121 | final def initialInFlowWindow: Int = _initialInFlowWindow 122 | protected[http2] def initialInFlowWindow_=(w: Int): Unit = _initialInFlowWindow = w 123 | final def initialOutFlowWindow: Int = _initialOutFlowWindow 124 | protected[http2] def initialOutFlowWindow_=(w: Int): Unit = _initialOutFlowWindow = w 125 | 126 | protected[http2] def addStream(dataStream: DataStream): Unit = { 127 | streams(dataStream.id) = dataStream 128 | 
dataStream.inFlowWindow = initialInFlowWindow
    dataStream.outFlowWindow = initialOutFlowWindow
  }
}

/** A regular (non-zero id) HTTP/2 stream. Flow control is enforced both on the
  * stream's own windows and on the connection-wide windows of `controlStream`.
  * Outgoing DATA frames that do not fit the current window are parked in a
  * bounded queue until WINDOW_UPDATEs free up space.
  */
class DataStream private[http2] (
  val id: Int,
  initialState: Http2Stream.State,
  controlStream: ControlStream,
  flowControlQueueSize: Int = 8)
  extends Http2Stream(initialState) {
  import Http2Stream._
  import Frame._
  import HTTP2Error._

  protected var _parentStream = 0

  // Outgoing DATA frames waiting for window space, paired with the promise
  // that releases them to the writer once they may be sent.
  protected var delayedDataQueue = new ArrayQueue[(Promise[Data], Data)](flowControlQueueSize)

  def parentStream: Int = _parentStream
  protected[http2] def parentStream_=(p: Int) = _parentStream = p

  protected[http2] override def inWindowAvailable(size: Int): Boolean =
    controlStream.inWindowAvailable(size) && ownInWindowAvailable(size)

  protected[http2] def ownInWindowAvailable(size: Int): Boolean =
    inFlowWindow >= size

  protected[http2] override def outWindowAvailable(size: Int): Boolean =
    controlStream.outWindowAvailable(size) && ownOutWindowAvailable(size)

  protected[http2] def ownOutWindowAvailable(size: Int): Boolean =
    outFlowWindow >= size

  override protected[http2] def incrementOutFlowWindow(inc: Int): Unit = {
    super.incrementOutFlowWindow(inc)
    runOutQueue()
  }

  // Total size of queued DATA frames that would fit into `window` bytes,
  // stopping (in FIFO order) at the first frame that does not fit.
  protected[http2] def possibleProgressForWindow(window: Int): Int = {
    var accumulated = 0
    delayedDataQueue.iterator().asScala.foreach { case (_, data) =>
      if(accumulated + data.data.length > window)
        return accumulated

      accumulated += data.data.length
    }

    accumulated
  }

  // Completes as many queued DATA frames as the stream and connection windows
  // currently allow, charging both windows for each frame released.
  protected[http2] def runOutQueue(): Unit = {
    while(!delayedDataQueue.isEmpty) {
      val (promise, data) = delayedDataQueue.peek()
      if(!outWindowAvailable(data.data.length))
        return

      decrementOutFlowWindow(data.data.length)
      controlStream.decrementOutFlowWindow(data.data.length)
      delayedDataQueue.remove()

      promise.success(data)
    }
  }

  // Accepts an outgoing DATA frame: passes it through immediately when the
  // queue is empty and window space is available, otherwise parks it (failing
  // with a flow-control error when the queue is full).
  protected[http2] def acceptOutData(data: Data, closing: Boolean = false): Future[Data] = {
    state match {
      case Open | ReservedForRemote | HalfClosedByRemote =>
        if(delayedDataQueue.isFull) {
          Future.failed(FlowControlError())
        } else if(delayedDataQueue.isEmpty && outWindowAvailable(data.data.length)) {
          decrementOutFlowWindow(data.data.length)
          controlStream.decrementOutFlowWindow(data.data.length)
          Future.successful(data)
        } else {
          val promise = Promise[Data]
          delayedDataQueue.add((promise, data))
          runOutQueue()
          promise.future
        }
      case _ =>
        Future.failed(UnacceptableFrameError())
    }
  }

  // Closing also fails every still-pending outgoing DATA frame.
  override def close(error: HTTP2Error = NoError()) = {
    val closedState = super.close(error)
    while(!delayedDataQueue.isEmpty) {
      val (promise, _) = delayedDataQueue.remove()
      promise.failure(error)
    }
    closedState
  }
}

object Http2Stream {
  import Frame._

  // Instructions returned to the frame pump about how to proceed with a frame.
  sealed trait ReceiveAction
  sealed trait SendAction
  case class Continue(frame: Frame) extends ReceiveAction with SendAction
  case class Finish(frame: Frame) extends ReceiveAction with SendAction
  case class Delay(f: Future[Data]) extends SendAction
  case object Stop extends ReceiveAction with SendAction
  case object Skip extends ReceiveAction

  // A state in the stream state machine. `onReceive`/`onSend` yield the
  // successor state for an acceptable frame; frames they do not define are
  // rejected by Http2Stream.updateState.
  sealed trait State {
    def onReceive: PartialFunction[Frame, State] = defaults
    def onSend: PartialFunction[Frame, State] = defaults
    def isClosed: Boolean = false

    // Frames acceptable in (almost) every state.
    def defaults: PartialFunction[Frame, State] = {
      case rst: ResetStream => Closed
      case p: Priority => this
      case w: WindowUpdate => this
    }

    def withDefaults(f: PartialFunction[Frame, State]): PartialFunction[Frame, State] =
      f orElse defaults
  }

  case
object Control extends State {
    // The connection stream only ever carries connection-level frames.
    override val onReceive: PartialFunction[Frame, State] = {
      case _: Settings => this
      case _: Ping => this
      case _: GoAway => this
      case _: WindowUpdate => this
    }

    override val onSend: PartialFunction[Frame, State] = {
      case _: Settings => this
      case _: Ping => this
      case _: GoAway => this
      case _: WindowUpdate => this
    }
  }

  case object Idle extends State {
    override val onReceive: PartialFunction[Frame, State] = {
      case h: Headers if h.endStream => HalfClosedByRemote
      case _: Headers => Open
      case _: Priority => this
    }

    override val onSend: PartialFunction[Frame, State] = {
      case h: Headers if h.endStream => HalfClosedByLocal
      case _: Headers => Open
      case _: Priority => this
    }
  }

  case object ReservedForRemote extends State {
    override val onReceive: PartialFunction[Frame, State] = withDefaults {
      case h: Headers => HalfClosedByRemote
    }
  }

  case object ReservedForLocal extends State {
    override val onSend: PartialFunction[Frame, State] = withDefaults {
      case h: Headers => HalfClosedByLocal
    }
  }

  case object Open extends State {
    override val onReceive: PartialFunction[Frame, State] = {
      case d: Data if d.endStream => HalfClosedByRemote
      case h: Headers if h.endStream => HalfClosedByRemote
      case _: ResetStream => Closed
      case _ => this
    }

    override val onSend: PartialFunction[Frame, State] = {
      case d: Data if d.endStream => HalfClosedByLocal
      case h: Headers if h.endStream => HalfClosedByLocal
      case _: ResetStream => Closed
      case _ => this
    }
  }

  case object HalfClosedByLocal extends State {
    override val onReceive: PartialFunction[Frame, State] = {
      case d: Data if d.endStream => Closed
      case h: Headers if h.endStream => Closed
      case _: ResetStream => Closed
      case _ => this
    }
  }

  case object HalfClosedByRemote extends State {
    override val onSend: PartialFunction[Frame, State] = {
      case d: Data if d.endStream => Closed
      case h: Headers if h.endStream => Closed
      case _: ResetStream => Closed
      case _ => this
    }
  }

  case object Closed extends State {
    override def isClosed = true
  }

  // Like Closed, but accepts nothing at all (not even the default frames).
  case object Dead extends State {
    override def isClosed = true

    override val onReceive = PartialFunction.empty[Frame, State]
    override val onSend = PartialFunction.empty[Frame, State]
  }

//  sealed trait PriorityTree {
//    def weight: Int
//  }
//
//  object PriorityTree {
//    case class Leaf(stream: Int, weight: Int) extends PriorityTree
//    case class Node(children: PriorityTree, weight: Int) extends PriorityTree
//  }
}
--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/protocol/Setting.scala:
--------------------------------------------------------------------------------
package net.danielkza.http2.protocol

// A single entry of a SETTINGS frame: a 16-bit identifier and a 32-bit value.
final case class Setting(identifier: Short, value: Int)

object Setting {
  object Identifiers {
    final val SETTINGS_HEADER_TABLE_SIZE: Short = 0x1
    final val SETTINGS_ENABLE_PUSH: Short = 0x2
    final val SETTINGS_MAX_CONCURRENT_STREAMS: Short = 0x3
    final val SETTINGS_INITIAL_WINDOW_SIZE: Short = 0x4
    final val SETTINGS_MAX_FRAME_SIZE: Short = 0x5
    final val SETTINGS_MAX_HEADER_LIST_SIZE: Short = 0x6
  }

  import Identifiers._

  // Matches settings whose identifier falls in the standard range.
  object Standard {
    def unapply(setting: Setting): Option[(Short, Int)] = {
      if(setting.identifier >= SETTINGS_HEADER_TABLE_SIZE && setting.identifier <= SETTINGS_MAX_HEADER_LIST_SIZE)
        Some(setting.identifier -> setting.value)
      else
        None
    }
}

  // Matches settings outside the standard identifier range.
  object NonStandard {
    def unapply(setting: Setting): Option[(Short, Int)] = Standard.unapply(setting) match {
      case Some(_) => None
      case None => Some(setting.identifier -> setting.value)
    }
  }

  // Extractor that pulls a specific setting's value out of a settings list.
  case class Extractor(identifier: Short) {
    def unapply(settings: List[Setting]): Option[Int] = {
      settings.find(_.identifier == identifier).map(_.value)
    }
  }

  final val ExtractHeaderTableSize = Extractor(SETTINGS_HEADER_TABLE_SIZE)
  final val ExtractEnablePush = Extractor(SETTINGS_ENABLE_PUSH)
  final val ExtractMaxConcurrentStreams = Extractor(SETTINGS_MAX_CONCURRENT_STREAMS)
  final val ExtractInitialWindowSize = Extractor(SETTINGS_INITIAL_WINDOW_SIZE)
  final val ExtractMaxFrameSize = Extractor(SETTINGS_MAX_FRAME_SIZE)
  final val ExtractMaxHeaderListSize = Extractor(SETTINGS_MAX_HEADER_LIST_SIZE)
}
--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/protocol/StreamManager.scala:
--------------------------------------------------------------------------------
package net.danielkza.http2.protocol

import scala.collection.mutable
import scalaz.\/
import scalaz.syntax.either._

/** Tracks all streams of a connection: id allocation, open-stream accounting
  * and the per-connection/per-stream flow-control windows.
  */
class StreamManager(val isClient: Boolean, val maxInStreams: Int, val ourMaxOutStreams: Int)
{
  import Frame._
  import HTTP2Error._
  import Http2Stream._
  import StreamManager._
  import Setting._

  private val controlStream = new ControlStream

  private var _lastProcessedId: Int = 0
  private var _lastOutId: Int = 0
  private var outPeerMaxStreams = Int.MaxValue

  private var openInStreams: Int = 0
  private var openOutStreams: Int = 0

  def lastProcessedId: Int = _lastProcessedId
  def lastOutId: Int = _lastOutId

  // Records the highest fully-processed inbound stream id (used for GOAWAY).
  def markProcessed(streamId: Int): Unit = {
    if(checkInId(streamId) && controlStream.streams.contains(streamId) && streamId > _lastProcessedId)
      _lastProcessedId = streamId
  }

  def maxOutStreams: Int = Math.min(ourMaxOutStreams, outPeerMaxStreams)

  // Peer-initiated ids: even (non-zero) when we are the client, odd otherwise.
  private def checkInId(id: Int): Boolean = {
    if(isClient) id != 0 && id % 2 == 0
    else id % 2 != 0
  }

  private def checkOutId(id: Int): Boolean = {
    id != 0 && !checkInId(id)
  }

  // Per RFC 7540 §5.1.1, skipped inbound stream ids below a newly-opened one
  // are implicitly closed; materialize them as Closed streams.
  private def closeUnusedInStreamsBelow(topId: Int): Unit = {
    var id = topId - 2

    while(id > 0) {
      if(controlStream.streams.contains(id))
        return

      controlStream.addStream(new DataStream(id, Closed, controlStream))
      id -= 2
    }
  }

  private def createInStream(id: Int, parent: Option[Int], dependency: Option[StreamDependency])
    : \/[HTTP2Error, DataStream] =
  {
    if(!checkInId(id) || controlStream.streams.contains(id))
      return InvalidStream().left

    if(parent.isEmpty && openInStreams >= maxInStreams)
      return RefusedStream("Available streams exhausted, must reconnect").left

    closeUnusedInStreamsBelow(id)
    // NOTE(review): this looks inverted — a HEADERS-created stream (no parent)
    // would normally start Idle, while a PUSH_PROMISE-created one (with a
    // parent) should be ReservedForRemote (RFC 7540 §5.1). Left as-is pending
    // confirmation, as the state tables partially compensate.
    val state = if(parent.isEmpty) {
      ReservedForRemote
    } else {
      openInStreams += 1
      Idle
    }

    val stream = new DataStream(id, state, controlStream)
    controlStream.addStream(stream)
    parent.foreach { p => stream.parentStream = p }

    stream.right
  }

  // Applies a (possibly negative) window delta to the given stream.
  // FIX: previously this always adjusted `controlStream`, ignoring the
  // `stream` argument, so per-stream windows were never actually updated
  // when SETTINGS_INITIAL_WINDOW_SIZE changed.
  protected def adjustStreamWindow(delta: Int, stream: Http2Stream) = {
    if(delta >= 0)
      stream.incrementOutFlowWindow(delta)
    else
      stream.decrementOutFlowWindow(-delta)
  }

  // Routes an incoming frame to its stream (creating it when the frame opens
  // one) and computes the action the frame pump should take.
  def receive(frame: Frame): \/[HTTP2Error, ReceiveReply] = {
    def getStream(id: Int) =
      controlStream.streams.get(id).map(_.right).getOrElse(InvalidStream().left)

    def getOrCreateStream = frame match {
      case h: Headers =>
        createInStream(h.stream, None, h.streamDependency)
      case pp: PushPromise =>
        getStream(pp.stream).leftMap(_.withErrorMessage("Invalid parent stream
in PUSH_PROMISE frame"))
      case f if f.stream == 0 =>
        controlStream.right
      case f =>
        getStream(f.stream)
    }

    for {
      stream <- getOrCreateStream
      newState <- stream.receive(frame)
      action <- frame match {
        case pp: PushPromise =>
          createInStream(pp.promisedStream, Some(pp.stream), None).map(_ => Continue(pp))
        case Priority(id, StreamDependency(exclusive, parent, weight)) if checkInId(id) =>
          stream match {
            case ds: DataStream =>
              ds.parentStream = parent
              ds.inWeight = weight
              // TODO: Handle exclusive
              Skip.right
            case _ =>
              // FIX: grammar in error message ("be send" -> "be sent")
              InvalidStream("Priority cannot be sent to control stream").left
          }
        case _: Priority =>
          InvalidStream("Invalid Priority stream").left
        case WindowUpdate(id, increment) =>
          stream.incrementOutFlowWindow(increment)
          Skip.right
        case _: ResetStream =>
          Stop.right
        case g: GoAway =>
          closeAll()
          Finish(g).right
        case s @ Settings(ExtractInitialWindowSize(window), _) =>
          // The peer changed its initial window: apply the delta to the
          // connection and to every stream that still sends data.
          val delta = window - controlStream.initialOutFlowWindow
          adjustStreamWindow(delta, controlStream)
          controlStream.streams.foreach {
            // FIX: the guard previously tested `stream.state` (the stream the
            // SETTINGS frame arrived on, i.e. the control stream, whose state
            // is Control and never Open/HalfClosedByLocal), so no data stream
            // was ever adjusted. It must inspect each iterated stream.
            case (_, adjStream) if adjStream.state == Open || adjStream.state == HalfClosedByLocal =>
              adjustStreamWindow(delta, adjStream)
            case _ =>
          }
          controlStream.initialOutFlowWindow = window
          Skip.right
        case s @ Settings(ExtractMaxConcurrentStreams(numStreams), _) =>
          outPeerMaxStreams = numStreams
          Skip.right
        case s: Settings =>
          Skip.right
        case f if stream.isClosed =>
          Finish(f).right
        case f =>
          Continue(f).right
      }
    } yield {
      stream.state = newState
      ReceiveReply(stream, action)
    }
  }

  private def allocateOut(f: Int => DataStream): Option[Http2Stream] = {
    // The stream ID will wrap back to negative if it goes past 2^31, which is also the maximum acceptable value.
    // In this case, there's nothing to do and we just fail the allocation, and the peer should open a new connection
    // if desired (RFC-7540, Section 5.1.1)
    // NOTE(review): ids start at lastOutId + 2 = 2, which is only valid for a
    // server (even ids); a client would need odd ids starting at 1 — confirm
    // whether client-side allocation is ever used.
    val newId = lastOutId + 2
    if(newId <= 0)
      None
    else {
      val stream = f(newId)
      controlStream.addStream(stream)
      _lastOutId = newId

      Some(stream)
    }
  }

  // Reserves a locally-initiated stream (for PUSH_PROMISE).
  def reserveOut(): Option[Http2Stream] = {
    allocateOut { newId =>
      new DataStream(newId, ReservedForLocal, controlStream)
    }
  }

  private def checkPromisedStreamOut(id: Int): Boolean = {
    checkOutId(id) && controlStream.streams.get(id).exists { stream =>
      stream.state == ReservedForLocal
    }
  }

  // Validates and applies an outgoing frame, computing the pump action.
  def send(frame: Frame): \/[HTTP2Error, SendReply] = {
    for {
      stream <- controlStream.streams.get(frame.stream).map(_.right).getOrElse(InvalidStream().left)
      newState <- stream.send(frame)
      action <- frame match {
        case pp: PushPromise if !checkPromisedStreamOut(pp.promisedStream) =>
          // At this point the frame sender already has to have allocated the stream for the PushPromise. So we just
          // check if that is the case, without creating anything
          InvalidStream("Invalid promised stream in PUSH_PROMISE frame").left
        case w @ WindowUpdate(_, increment) =>
          stream.incrementInFlowWindow(increment)
          Continue(w).right
        case p @ Priority(id, StreamDependency(exclusive, parent, weight)) if checkOutId(id) =>
          stream.parentStream = parent
          stream.outWeight = weight
          // TODO: Handle exclusive
          Continue(p).right
        case _: Priority =>
          InvalidStream("Invalid stream in PRIORITY frame").left
        case rst: ResetStream =>
          Finish(rst).right
        case d: Data =>
          Delay(stream.acceptOutData(d)).right
        case g: GoAway =>
          closeAll()
          Finish(g).right
        case s @ Settings(ExtractInitialWindowSize(window), _) =>
          // FIX: a SETTINGS frame *we* send declares *our* receive window, so
          // it must update the initial *in* window (the receive() path already
          // handles the peer's setting by updating the out window).
          controlStream.initialInFlowWindow = window
          Continue(s).right
        case s @ Settings(ExtractMaxConcurrentStreams(numStreams), _) if numStreams > maxInStreams =>
          // FIX: typo in error message ("configuratin")
          SettingsError("Maximum in streams exceeds configuration").left
        case f if stream.isClosed =>
          Finish(f).right
        case f =>
          Continue(f).right
      }
    } yield {
      stream.state = newState
      SendReply(stream, action)
    }
  }

  // Closes a stream and updates the open-stream counters; returns false when
  // the stream was already closed.
  private def closeStream(stream: DataStream): Boolean = {
    if(!stream.isClosed) {
      val oldState = stream.state
      stream.close()

      if(checkInId(stream.id) && (oldState == HalfClosedByLocal || oldState == Open))
        openInStreams -= 1
      else if(checkOutId(stream.id) && (oldState == HalfClosedByRemote || oldState == Open))
        openOutStreams -= 1

      true
    } else {
      false
    }
  }

  private[http2] def closeAll() = controlStream.streams.values.foreach(closeStream)

  def close(streamId: Int): \/[HTTP2Error, Unit] = {
    controlStream.streams.get(streamId).map { stream =>
      if(!closeStream(stream))
        StreamClosedError().left
      else
().right
    } getOrElse InvalidStream().left
  }
}

object StreamManager {
  import Http2Stream._

  // Result of StreamManager.receive / StreamManager.send: the resolved stream
  // plus the action the frame pump should take.
  case class ReceiveReply(stream: Http2Stream, action: ReceiveAction)
  case class SendReply(stream: Http2Stream, action: SendAction)
}
--------------------------------------------------------------------------------
/core/src/main/scala/net/danielkza/http2/protocol/coders/FrameCoder.scala:
--------------------------------------------------------------------------------
package net.danielkza.http2.protocol.coders

import scalaz._
import scalaz.std.list._
import scalaz.syntax.either._
import scalaz.syntax.traverse._
import akka.util.{ByteStringBuilder, ByteString}
import net.danielkza.http2.Coder
import net.danielkza.http2.protocol.{HTTP2Error, Frame, Setting}

/** Binary codec for HTTP/2 frames (RFC 7540 §4 and §6). */
class FrameCoder extends Coder[Frame] {
  import Frame._
  import Frame.Flags._
  import IntCoder._
  import HTTP2Error._

  override final type Error = HTTP2Error

  // Reads the 5-byte priority block: E-bit + 31-bit stream + weight (stored
  // as weight - 1 on the wire).
  protected def decodeStreamDependency: DecodeStateT[StreamDependency] = {
    for {
      stream <- int.decodeS
      weight <- byte.decodeS
    } yield {
      val exclusive = (stream >>> 31) != 0
      val streamNum = stream & 0x7FFFFFFF
      StreamDependency(exclusive, streamNum, (weight & 0xFF) + 1)
    }
  }

  protected def decodeUnpaddedBytes(length: Int): DecodeStateT[ByteString] = {
    takeS(length)
  }

  // Reads a possibly-padded payload, returning (data, optional padding).
  protected def decodeBytes(length: Int, padded: Boolean): DecodeStateT[(ByteString, Option[ByteString])] = {
    val SM = stateMonad[ByteString]; import SM._

    if(!padded) {
      decodeUnpaddedBytes(length).map(bs => bs -> None)
    } else {
      for {
        paddingLen <- if(padded) byte.decodeS
                      else pure(0: Byte)
        dataLen = length - paddingLen - 1
        // FIX: RFC 7540 §6.1 only requires the padding to be strictly shorter
        // than the frame payload, i.e. dataLen >= 0. The previous check
        // (dataLen > paddingLen) wrongly rejected valid frames whose padding
        // is longer than their data.
        _ <- ensureS(new InvalidPadding) { dataLen >= 0 }
        data <- takeS(dataLen)
        padding <- takeS(paddingLen)
        _ <-
ensureS(new InvalidFrameSize) { data.length == dataLen && padding.length == paddingLen }
      } yield data -> (Some(padding): Option[ByteString])
    }
  }

  protected def decodeData(length: Int, stream: Int, padded: Boolean, endStream: Boolean): DecodeStateT[Data] = {
    decodeBytes(length, padded).map { case (content, padding) => Data(stream, content, endStream, padding) }
  }

  // Decodes a HEADERS payload: optional padding, optional priority block,
  // then the header block fragment.
  protected def decodeHeaders(length: Int, stream: Int, padded: Boolean, streamDependency: Boolean, endStream: Boolean,
                              endHeaders: Boolean): DecodeStateT[Headers] = {
    val SM = stateMonad[ByteString]; import SM._

    for {
      bytes <- decodeBytes(length, padded)
      (content, padding) = bytes
      rem <- get
      headers <- for {
        _ <- SM.put(content)
        streamDependency <- if(streamDependency) decodeStreamDependency.map(Some(_))
                            else SM.pure(None)
        data <- get
      } yield Headers(stream, streamDependency, data, endStream = endStream, endHeaders = endHeaders, padding = padding)
      _ <- put(rem)
    } yield headers
  }

  protected def decodePriority(stream: Int): DecodeStateT[Priority] = {
    decodeStreamDependency.map(Priority(stream, _))
  }

  protected def decodeResetStream(stream: Int): DecodeStateT[ResetStream] = {
    int.decodeS.map(ResetStream(stream, _))
  }

  // A single 6-byte settings entry: 16-bit identifier + 32-bit value.
  protected def decodeSingleSetting: DecodeStateT[Setting] = {
    for {
      identifier <- short.decodeS
      value <- int.decodeS
    } yield Setting(identifier, value)
  }

  protected def decodeSettings(num: Int, ack: Boolean): DecodeStateT[Settings] = {
    stateMonad[ByteString].replicateM(num, decodeSingleSetting).map(Settings(_, ack))
  }

  // Decodes a PUSH_PROMISE payload: promised (even, non-zero) stream id
  // followed by a header block fragment.
  protected def decodePushPromise(length: Int, stream: Int, padded: Boolean, endHeaders: Boolean)
    : DecodeStateT[PushPromise] =
  {
    val SM = stateMonad[ByteString]; import SM._

    for {
      bytes <- decodeBytes(length, padded)
      (content, padding) = bytes
      rem <- get
      _ <- put(content)
      promisedStream <- int.decodeS
      _ <- ensureS(new InvalidStream) { promisedStream > 0 && promisedStream % 2 == 0 }
      data <- get
      _ <- put(rem)
    } yield PushPromise(stream, promisedStream, data, endHeaders, padding)
  }

  protected def decodePing(ack: Boolean): DecodeStateT[Ping] = {
    for {
      content <- decodeUnpaddedBytes(8)
    } yield Ping(content, ack)
  }

  // GOAWAY: last-stream id, error code, then opaque debug data.
  protected def decodeGoAway(length: Int): DecodeStateT[GoAway] = {
    for {
      stream <- int.decodeS
      _ <- ensureS(new InvalidStream) { stream >= 0 }
      errorCode <- int.decodeS
      debugData <- takeS(length - 8)
      _ <- ensureS(new InvalidFrameSize) { debugData.length == length - 8 }
    } yield GoAway(stream, errorCode, debugData)
  }

  // WINDOW_UPDATE: 31-bit increment, which must be non-zero.
  protected def decodeWindowUpdate(stream: Int): DecodeStateT[WindowUpdate] = {
    for {
      window <- int.decodeS
      windowVal = window & 0x7FFFFFFF
      _ <- ensureS(new InvalidWindowUpdate) { windowVal != 0 }
    } yield WindowUpdate(stream, windowVal)
  }

  // Unknown frame types are carried through opaquely.
  protected def decodePassthrough(tpe: Byte, length: Int, stream: Int, flags: Byte) : DecodeStateT[NonStandard] = {
    for {
      content <- decodeUnpaddedBytes(length)
    } yield Frame.NonStandard(stream, tpe, flags, content)
  }

  protected def decodeContinuation(length: Int, stream: Int, endHeaders: Boolean): DecodeStateT[Continuation] = {
    decodeUnpaddedBytes(length).map(Continuation(stream, _, endHeaders))
  }

  // Validates that the frame type is allowed on the given stream id:
  // connection-level frames require id 0, stream-level frames forbid it.
  protected def checkStream[S](stream: Int, tpe: Byte) = {
    import Frame.Types._
    ensureS[S](new InvalidStream) {
      if(stream != 0 && (tpe == SETTINGS || tpe == PING || tpe == GOAWAY))
        false
      else if(stream == 0 && (tpe == DATA || tpe == HEADERS || tpe == RST_STREAM || tpe == PRIORITY ||
              tpe == CONTINUATION))
        false
      else
        true
    }
  }

157 | def payloadDecoder(tpe: Byte, length: Int, flags: Byte, stream: Int): \/[HTTP2Error, DecodeStateT[Frame]] = { 158 | def err = (new InvalidFrameSize).left 159 | 160 | val maybeHandler = tpe match { 161 | case Types.DATA => 162 | val padded = (flags & DATA.PADDED) != 0 163 | val endStream = (flags & DATA.END_STREAM) != 0 164 | if (padded && length < 1) err 165 | else decodeData(length, stream, padded, endStream).right 166 | 167 | case Types.HEADERS => 168 | val padded = (flags & HEADERS.PADDED) != 0 169 | val priority = (flags & HEADERS.PRIORITY) != 0 170 | val endStream = (flags & HEADERS.END_STREAM) != 0 171 | val endHeaders = (flags & HEADERS.END_HEADERS) != 0 172 | 173 | if(padded && priority && length < 6) err 174 | else if(priority && length < 5) err 175 | else if(padded && length < 1) err 176 | else decodeHeaders(length, stream, padded, priority, endStream, endHeaders).right 177 | 178 | case Types.PRIORITY => 179 | if (length != 5) err 180 | else decodePriority(stream).right 181 | 182 | case Types.RST_STREAM => 183 | if (length != 4) err 184 | else decodeResetStream(stream).right 185 | 186 | case Types.SETTINGS => 187 | if (length > 0 && (flags & SETTINGS.ACK) != 0) err 188 | else if (length % 6 != 0) err 189 | else decodeSettings(length / 6, (flags & SETTINGS.ACK) != 0).right 190 | 191 | case Types.PUSH_PROMISE => 192 | val padded = (flags & PUSH_PROMISE.PADDED) != 0 193 | if (padded && length < 5) err 194 | else if(length < 4) err 195 | else decodePushPromise(length, stream, padded, (flags & PUSH_PROMISE.END_HEADERS) != 0).right 196 | 197 | case Types.PING => 198 | if (length != 8) err 199 | else decodePing((flags & PING.ACK) != 0).right 200 | 201 | case Types.GOAWAY => 202 | if (length < 8) err 203 | else decodeGoAway(length).right 204 | 205 | case Types.WINDOW_UPDATE => 206 | if (length != 4) err 207 | else decodeWindowUpdate(stream).right 208 | 209 | case Types.CONTINUATION => 210 | decodeContinuation(length, stream, (flags & HEADERS.END_HEADERS) != 
0).right 211 | 212 | case _ => 213 | decodePassthrough(tpe, length, stream, flags).right 214 | } 215 | 216 | // Convert from an invariant StateT of a subtype of Frame to one for Frame 217 | maybeHandler.map { handler => 218 | checkStream(stream, tpe).flatMap(_ => handler.map(f => f: Frame)) 219 | } 220 | } 221 | 222 | override def decode(bs: ByteString): \/[HTTP2Error, (Frame, Int)] = { 223 | decodeS.run(bs).map { case (rem, frame) => (frame, bs.length - rem.length) } 224 | } 225 | 226 | def decodeHeader: DecodeStateT[(DecodeStateT[Frame], Int)] = { 227 | val SMT = StateT.StateMonadTrans[ByteString]; import SMT._ 228 | 229 | for { 230 | length <- int24.decodeS 231 | tpe <- byte.decodeS 232 | flags <- byte.decodeS 233 | stream <- int.decodeS 234 | handler <- liftMU(payloadDecoder(tpe, length, flags, stream)) 235 | } yield handler -> length 236 | } 237 | 238 | override def decodeS: DecodeStateT[Frame] = { 239 | val SM = stateMonad[ByteString]; import SM._ 240 | 241 | for { 242 | partialResult <- decodeHeader 243 | (payloadHandler, remLength) = partialResult 244 | remInput <- get 245 | _ <- ensureS(new InvalidFrameSize) { remInput.length >= remLength } 246 | result <- payloadHandler 247 | } yield result 248 | } 249 | 250 | protected def encodeStreamDependency(streamDependency: StreamDependency): EncodeState = { 251 | val exclusiveBit = if(streamDependency.exclusive) 0x80000000 else 0 252 | for { 253 | _ <- int.encodeS(streamDependency.stream | exclusiveBit) 254 | _ <- byte.encodeS((streamDependency.weight - 1).toByte) 255 | } yield () 256 | } 257 | 258 | protected def encodeBytes(padding: Option[ByteString])(f: EncodeState): EncodeState = { 259 | val SM = stateMonad[ByteStringBuilder]; import SM._ 260 | 261 | padding.map { padding => 262 | for { 263 | _ <- ensureS(new InvalidPadding) { padding.length < 256 } 264 | _ <- byte.encodeS(padding.length.toByte) 265 | _ <- f 266 | _ <- modify { _ ++= padding } 267 | } yield () 268 | }.getOrElse { 269 | f 270 | } 271 | } 272 | 
273 | protected def encodeData(data: Data): EncodeState = { 274 | val SM = stateMonad[ByteStringBuilder]; import SM._ 275 | 276 | encodeBytes(data.padding) { modify { _ ++= data.data } } 277 | } 278 | 279 | protected def encodeHeaders(headers: Headers): EncodeState = { 280 | val SM = stateMonad[ByteStringBuilder]; import SM._ 281 | 282 | encodeBytes(headers.padding) { 283 | for { 284 | _ <- headers.streamDependency.map(encodeStreamDependency).getOrElse(point(())) 285 | _ <- modify { _ ++= headers.headerFragment } 286 | } yield () 287 | } 288 | } 289 | 290 | protected def encodePriority(priority: Priority): EncodeState = { 291 | encodeStreamDependency(priority.streamDependency) 292 | } 293 | 294 | protected def encodeResetStream(resetStream: ResetStream): EncodeState = { 295 | int.encodeS(resetStream.errorCode) 296 | } 297 | 298 | protected def encodeSettings(settings: Settings): EncodeState = { 299 | type S[T] = StateTES[T, Error, ByteStringBuilder] 300 | settings.settings.traverse_[S] { case Setting(identifier, value) => 301 | for { 302 | _ <- short.encodeS(identifier) 303 | _ <- int.encodeS(value) 304 | } yield () 305 | } 306 | } 307 | 308 | protected def encodePushPromise(pushPromise: PushPromise): EncodeState = { 309 | val SM = stateMonad[ByteStringBuilder]; import SM._ 310 | 311 | encodeBytes(pushPromise.padding) { 312 | for { 313 | _ <- ensureS(new InvalidStream) { pushPromise.promisedStream >= 0 } 314 | _ <- int.encodeS(pushPromise.promisedStream) 315 | _ <- modify { _ ++= pushPromise.headerFragment } 316 | } yield () 317 | } 318 | } 319 | 320 | protected def encodePing(ping: Ping): EncodeState = { 321 | val SM = stateMonad[ByteStringBuilder]; import SM._ 322 | 323 | for { 324 | _ <- ensureS(new InvalidFrameSize) { ping.data.length == 8 } 325 | _ <- encodeBytes(None) { modify { _ ++= ping.data } } 326 | } yield () 327 | } 328 | 329 | protected def encodeGoAway(goAway: GoAway): EncodeState = { 330 | val SM = stateMonad[ByteStringBuilder]; import SM._ 331 | 
332 | for { 333 | _ <- ensureS(new InvalidStream) { goAway.lastStream >= 0 } 334 | _ <- int.encodeS(goAway.lastStream) 335 | _ <- int.encodeS(goAway.errorCode) 336 | _ <- modify { _ ++= goAway.debugData } 337 | } yield () 338 | } 339 | 340 | protected def encodeWindowUpdate(windowUpdate: WindowUpdate): EncodeState = { 341 | int.encodeS(windowUpdate.windowIncrement & 0x7FFFFFFF) 342 | } 343 | 344 | protected def encodeContinuation(continuation: Continuation): EncodeState = { 345 | val SM = stateMonad[ByteStringBuilder]; import SM._ 346 | 347 | encodeBytes(None) { modify { _ ++= continuation.headerFragment } } 348 | } 349 | 350 | 351 | protected def encodePassthrough(unknown: NonStandard): EncodeState = { 352 | val SM = stateMonad[ByteStringBuilder]; import SM._ 353 | 354 | modify { _ ++= unknown.payload } 355 | } 356 | 357 | override def encode(frame: Frame, stream: ByteStringBuilder): \/[HTTP2Error, Unit] = { 358 | encodeS(frame).eval(stream) 359 | } 360 | 361 | override def encodeS(frame: Frame): EncodeState = { 362 | val SM = stateMonad[ByteStringBuilder]; import SM._ 363 | 364 | for { 365 | _ <- checkStream(frame.stream, frame.tpe) 366 | buffer <- get 367 | _ <- put(ByteString.newBuilder) 368 | payload <- get 369 | _ <- frame match { 370 | case f: Data => encodeData(f) 371 | case f: Headers => encodeHeaders(f) 372 | case f: Priority => encodePriority(f) 373 | case f: ResetStream => encodeResetStream(f) 374 | case f: Settings => encodeSettings(f) 375 | case f: PushPromise => encodePushPromise(f) 376 | case f: Ping => encodePing(f) 377 | case f: GoAway => encodeGoAway(f) 378 | case f: WindowUpdate => encodeWindowUpdate(f) 379 | case f: Continuation => encodeContinuation(f) 380 | case f: NonStandard => encodePassthrough(f) 381 | } 382 | _ <- put(buffer) 383 | _ <- int24.encodeS(payload.length) 384 | _ <- byte.encodeS(frame.tpe) 385 | _ <- byte.encodeS(frame.flags) 386 | _ <- int.encodeS(frame.stream) 387 | _ <- modify { _ ++= payload.result() } 388 | } yield () 389 
| } 390 | } 391 | 392 | object FrameCoder { 393 | sealed trait PartialDecodeResult 394 | object PartialDecodeResult { 395 | case class MoreData(length: Int) extends PartialDecodeResult 396 | case class Result(result: \/[HTTP2Error, Frame]) extends PartialDecodeResult 397 | } 398 | } 399 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/protocol/coders/IntCoder.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.protocol.coders 2 | 3 | import java.nio.ByteOrder 4 | import scalaz.\/ 5 | import scalaz.syntax.either._ 6 | import akka.util.{ByteStringBuilder, ByteString} 7 | import net.danielkza.http2.Coder 8 | import net.danielkza.http2.protocol.HTTP2Error 9 | 10 | abstract class IntCoder[T] extends Coder[T] { 11 | @inline def checkDecode(bs: ByteString)(f: ByteString => \/[HTTP2Error, (T, Int)]): \/[HTTP2Error, (T, Int)] = { 12 | try { 13 | f(bs) 14 | } catch { case e: IndexOutOfBoundsException => 15 | (new HTTP2Error.InvalidFrameSize).left 16 | } 17 | } 18 | } 19 | 20 | object IntCoder { 21 | object byte extends IntCoder[Byte] { 22 | override final type Error = HTTP2Error 23 | 24 | override final def decode(bs: ByteString) = checkDecode(bs) { bs => 25 | (bs.head, 1).right 26 | } 27 | 28 | override final def encode(value: Byte, stream: ByteStringBuilder): \/[HTTP2Error, Unit] = { 29 | stream.putByte(value) 30 | ().right 31 | } 32 | 33 | override final def encode(value: Byte): \/[HTTP2Error, ByteString] = 34 | ByteString(value).right 35 | } 36 | 37 | object short extends IntCoder[Short] { 38 | override final type Error = HTTP2Error 39 | 40 | override final def decode(bs: ByteString): \/[HTTP2Error, (Short, Int)] = checkDecode(bs) { bs => 41 | (( 42 | (bs(0) & 0xFF) << 8 | 43 | (bs(1) & 0xFF) << 0 44 | ).toShort -> 2).right 45 | } 46 | 47 | override final def encode(value: Short, stream: ByteStringBuilder): \/[HTTP2Error, Unit] = 
{ 48 | stream.putShort(value)(ByteOrder.BIG_ENDIAN) 49 | ().right 50 | } 51 | 52 | override final def encode(value: Short): \/[HTTP2Error, ByteString] = 53 | ByteString((value >>> 8).toByte, 54 | (value >>> 0).toByte).right 55 | } 56 | 57 | object int extends IntCoder[Int] { 58 | override final type Error = HTTP2Error 59 | 60 | override final def decode(bs: ByteString): \/[HTTP2Error, (Int, Int)] = checkDecode(bs) { bs => 61 | (( 62 | (bs(0) & 0xFF) << 24 | 63 | (bs(1) & 0xFF) << 16 | 64 | (bs(2) & 0xFF) << 8 | 65 | (bs(3) & 0xFF) << 0 66 | ) -> 4).right 67 | } 68 | 69 | override final def encode(value: Int, stream: ByteStringBuilder): \/[HTTP2Error, Unit] = { 70 | stream.putInt(value)(ByteOrder.BIG_ENDIAN) 71 | ().right 72 | } 73 | 74 | override final def encode(value: Int): \/[HTTP2Error, ByteString] = 75 | ByteString((value >>> 24).toByte, 76 | (value >>> 16).toByte, 77 | (value >>> 8).toByte, 78 | (value >>> 0).toByte).right 79 | } 80 | 81 | object int24 extends IntCoder[Int] { 82 | override final type Error = HTTP2Error 83 | 84 | override final def decode(bs: ByteString): \/[HTTP2Error, (Int, Int)] = checkDecode(bs) { bs => 85 | (( 86 | (bs(0) & 0xFF) << 16 | 87 | (bs(1) & 0xFF) << 8 | 88 | (bs(2) & 0xFF) << 0 89 | ) -> 3).right 90 | } 91 | 92 | override final def encode(value: Int, stream: ByteStringBuilder): \/[HTTP2Error, Unit] = { 93 | stream.putByte((value >>> 16).toByte) 94 | stream.putByte((value >>> 8).toByte) 95 | stream.putByte((value >>> 0).toByte) 96 | ().right 97 | } 98 | 99 | override final def encode(value: Int): \/[HTTP2Error, ByteString] = 100 | ByteString((value >>> 16).toByte, 101 | (value >>> 8).toByte, 102 | (value >>> 0).toByte).right 103 | } 104 | 105 | object long extends IntCoder[Long] { 106 | override final type Error = HTTP2Error 107 | 108 | override final def decode(bs: ByteString): \/[HTTP2Error, (Long, Int)] = checkDecode(bs) { bs => 109 | (( 110 | (bs(0) & 0xFFL) << 56 | 111 | (bs(1) & 0xFFL) << 48 | 112 | (bs(2) & 0xFFL) << 40 | 
113 | (bs(3) & 0xFFL) << 32 | 114 | (bs(4) & 0xFFL) << 24 | 115 | (bs(5) & 0xFFL) << 16 | 116 | (bs(6) & 0xFFL) << 8 | 117 | (bs(7) & 0xFFL) << 0 118 | ) -> 8).right 119 | } 120 | 121 | override final def encode(value: Long, stream: ByteStringBuilder): \/[HTTP2Error, Unit] = { 122 | stream.putLong(value)(ByteOrder.BIG_ENDIAN) 123 | ().right 124 | } 125 | 126 | override final def encode(value: Long): \/[HTTP2Error, ByteString] = 127 | ByteString((value >>> 56).toByte, 128 | (value >>> 48).toByte, 129 | (value >>> 40).toByte, 130 | (value >>> 32).toByte, 131 | (value >>> 24).toByte, 132 | (value >>> 16).toByte, 133 | (value >>> 8).toByte, 134 | (value >>> 0).toByte).right 135 | } 136 | 137 | } 138 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/ssl/ALPNSSLContext.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.ssl 2 | 3 | import java.util 4 | import javax.net.ssl.{SSLContext, SSLEngine, SSLException} 5 | import scala.collection.JavaConversions._ 6 | import scala.collection.immutable 7 | import org.eclipse.jetty.alpn.ALPN 8 | 9 | class ALPNSSLContext(context: SSLContext, orderedProtocols: immutable.Seq[String]) 10 | extends WrappedSSLContext(context) 11 | { 12 | private class Provider(engine: SSLEngine) extends ALPN.ClientProvider with ALPN.ServerProvider { 13 | override def protocols(): util.List[String] = 14 | orderedProtocols 15 | 16 | override def selected(protocol: String): Unit = { 17 | if(!orderedProtocols.contains(protocol)) 18 | throw new SSLException(s"ALPN: Unsupported protocol $protocol") 19 | } 20 | 21 | override def select(serverProtocols: util.List[String]): String = { 22 | serverProtocols.find(orderedProtocols.contains(_)).getOrElse { 23 | throw new SSLException(s"ALPN: No common supported protocols with server") 24 | } 25 | } 26 | 27 | override def unsupported(): Unit = 28 | throw new SSLException(s"ALPN: 
Unsupported by peer") 29 | } 30 | 31 | override def mapEngine(engine: SSLEngine): SSLEngine = { 32 | ALPN.put(engine, new Provider(engine)) 33 | ALPN.debug = true 34 | engine 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/ssl/WrappedSSLContext.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.ssl 2 | 3 | import java.lang.reflect.Method 4 | import java.security.SecureRandom 5 | import javax.net.ssl._ 6 | 7 | import scala.annotation.tailrec 8 | import scala.collection.mutable 9 | 10 | class WrappedSSLContext(context: SSLContext) 11 | extends SSLContext(new WrappedSSLContext.SpiWrapper(WrappedSSLContext.getSpi(context)), context.getProvider, 12 | context.getProtocol) 13 | { 14 | import WrappedSSLContext._ 15 | 16 | getSpi(this).asInstanceOf[SpiWrapper].parent = this 17 | 18 | def mapEngine(engine: SSLEngine): SSLEngine = 19 | engine 20 | 21 | def mapSocketFactory(socketFactory: SSLSocketFactory): SSLSocketFactory = 22 | socketFactory 23 | 24 | def mapServerSocketFactory(serverSocketFactory: SSLServerSocketFactory): SSLServerSocketFactory = 25 | serverSocketFactory 26 | 27 | def mapSessionContext(context: SSLSessionContext): SSLSessionContext = 28 | context 29 | } 30 | 31 | object WrappedSSLContext { 32 | private[WrappedSSLContext] class SpiWrapper(val wrapped: SSLContextSpi, var parent: WrappedSSLContext = null) 33 | extends SSLContextSpi 34 | { 35 | private val methods = mutable.Map.empty[String, Method] 36 | 37 | @tailrec 38 | private def getDeclaredMethod(cls: Class[_], name: String, params: Class[_]*): Method = { 39 | try { 40 | return cls.getDeclaredMethod(name, params: _*) 41 | } catch { case e: NoSuchMethodException if cls.getSuperclass != null => 42 | // pass 43 | } 44 | 45 | getDeclaredMethod(cls.getSuperclass, name, params: _*) 46 | } 47 | 48 | private def callMethod[T](name: String, params: 
Class[_]*)(actualParams: AnyRef*) = { 49 | methods.getOrElseUpdate(s"$name(${params.map(_.toString).mkString(",")})", { 50 | val m = getDeclaredMethod(wrapped.getClass, name, params: _*) 51 | m.setAccessible(true) 52 | m 53 | }).invoke(wrapped, actualParams: _*).asInstanceOf[T] 54 | } 55 | 56 | override def engineCreateSSLEngine(): SSLEngine = parent.mapEngine( 57 | callMethod[SSLEngine]("engineCreateSSLEngine")() 58 | ) 59 | 60 | override def engineGetSocketFactory(): SSLSocketFactory = parent.mapSocketFactory( 61 | callMethod[SSLSocketFactory]("engineGetSocketFactory")() 62 | ) 63 | 64 | override def engineInit(keyManagers: Array[KeyManager], trustManagers: Array[TrustManager], 65 | secureRandom: SecureRandom): Unit = 66 | { 67 | callMethod[Unit]("engineInit", classOf[Array[KeyManager]], classOf[Array[TrustManager]], classOf[SecureRandom])( 68 | keyManagers, trustManagers, secureRandom) 69 | } 70 | 71 | override def engineCreateSSLEngine(peerHost: String, peerPort: Int): SSLEngine = parent.mapEngine( 72 | callMethod[SSLEngine]("engineCreateSSLEngine", classOf[String], classOf[Int])(peerHost, Int.box(peerPort)) 73 | ) 74 | 75 | override def engineGetClientSessionContext(): SSLSessionContext = parent.mapSessionContext( 76 | callMethod[SSLSessionContext]("engineGetClientSessionContext")() 77 | ) 78 | 79 | override def engineGetServerSessionContext(): SSLSessionContext = parent.mapSessionContext( 80 | callMethod[SSLSessionContext]("engineGetServerSessionContext")() 81 | ) 82 | 83 | override def engineGetServerSocketFactory(): SSLServerSocketFactory = parent.mapServerSocketFactory( 84 | callMethod[SSLServerSocketFactory]("engineGetServerSocketFactory")() 85 | ) 86 | } 87 | 88 | private val spiField = { 89 | val field = classOf[SSLContext].getDeclaredField("contextSpi") 90 | field.setAccessible(true) 91 | field 92 | } 93 | 94 | private [WrappedSSLContext] def getSpi(context: SSLContext): SSLContextSpi = 95 | spiField.get(context).asInstanceOf[SSLContextSpi] 96 | } 97 | 
-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/ChunkedDataDecodeStage.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import scala.collection.immutable
import akka.stream._
import akka.stream.stage._
import akka.http.scaladsl.model.HttpEntity.{LastChunk, ChunkStreamPart, Chunk}
import net.danielkza.http2.protocol.{HTTP2Error, Frame}
import net.danielkza.http2.protocol.Frame.{Data, Headers}
import net.danielkza.http2.protocol.HTTP2Error.UnacceptableFrameError

/** Turns a stream of HTTP/2 frames into HTTP chunked-entity parts.
  *
  * DATA frames come out of `out0` as `Chunk`s (plus a `LastChunk` on
  * END_STREAM). When `trailers` is set, the stream must instead end with a
  * HEADERS frame carrying END_STREAM, which is emitted on `out1`.
  */
class ChunkedDataDecodeStage(val trailers: Boolean = false)
  extends GraphStage[FanOutShape2[Frame, ChunkStreamPart, Headers]]
{
  val in: Inlet[Frame] = Inlet[Frame]("ChunkedDataDecodeStage.in")
  val out0: Outlet[ChunkStreamPart] = Outlet[ChunkStreamPart]("ChunkedDataDecodeStage.out0")
  val out1: Outlet[Headers] = Outlet[Headers]("ChunkedDataDecodeStage.out1")
  override val shape = new FanOutShape2(in, out0, out1)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { logic =>
    // Set once we have seen a valid end of stream (END_STREAM data or trailers).
    private var completed = false

    setHandler(in, new InHandler {
      override def onPush(): Unit = {
        grab(in) match {
          case d: Data if trailers && d.endStream =>
            // With trailers expected, the stream must end on HEADERS, not DATA.
            failStage(UnacceptableFrameError())
          case d: Data if !trailers && d.endStream =>
            if(d.data.isEmpty) emit(out0, LastChunk, () => completeStage())
            else emitMultiple(out0, immutable.Seq(Chunk(d.data), LastChunk), () => completeStage())
            completed = true
          case d: Data =>
            emit(out0, Chunk(d.data))
          case h: Headers if trailers && h.endStream =>
            complete(out0)
            emit(out1, h, () => completeStage())
            completed = true
          case h: Headers =>
            // Mid-stream HEADERS (or trailers when none were expected) are invalid.
            failStage(UnacceptableFrameError())
        }
      }

      override def onUpstreamFinish(): Unit = {
        // Upstream ending before END_STREAM / trailers is a protocol error.
        if(!completed)
          failStage(HTTP2Error.HeaderError())
        else
          super.onUpstreamFinish()
      }
    })

    setHandler(out0, new OutHandler {
      override def onPull(): Unit = pull(in)
    })

    setHandler(out1, new OutHandler {
      override def onPull(): Unit = {
        // Do nothing by default. Only forward the demand when emitting the single header frame, after we have already
        // grabbed all the Data frames (the stage set up by `emit` will take care of it)
      }
    })
  }
}

-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/FrameDecoderStage.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import akka.stream.stage._
import akka.util.ByteString
import net.danielkza.http2.util.Implicits._
import net.danielkza.http2.protocol.Frame.Settings
import net.danielkza.http2.protocol.Frame
import net.danielkza.http2.protocol.coders.FrameCoder
import net.danielkza.http2.stream.FrameDecoderStage.NotHTTP2Exception

import scala.annotation.tailrec
import scalaz.{\/-, -\/}

/** Decodes raw bytes into HTTP/2 frames.
  *
  * When `waitForClientPreface` is set, the stage first requires the HTTP/2
  * client connection preface, then requires the first frame to be SETTINGS
  * (per RFC 7540 §3.5); otherwise only the initial SETTINGS frame is required.
  */
class FrameDecoderStage(val waitForClientPreface: Boolean) extends PushPullStage[ByteString, Frame] {
  import FrameDecoderStage.CONNECTION_PREFACE

  private val coder = new FrameCoder
  // Continuation for the frame whose header has been decoded but whose payload
  // hasn't fully arrived yet.
  private var nextFrameCoder: Option[coder.DecodeState] = None
  private var stash = ByteString.empty
  // Bytes still required before we can make progress; -1 means "need a header".
  private var needed = if(!waitForClientPreface) -1 else CONNECTION_PREFACE.length
  private var clientPrefaceNeeded = waitForClientPreface
  private var settingsFrameNeeded = !waitForClientPreface

  override def onPush(bytes: ByteString, ctx: Context[Frame]) = {
    stash ++= bytes
    run(ctx)
  }

  override def onPull(ctx: Context[Frame]) = run(ctx)

  override def onUpstreamFinish(ctx: Context[Frame]) =
    if (stash.isEmpty) ctx.finish()
    else ctx.absorbTermination()

  @tailrec
  private def run(ctx: Context[Frame]): SyncDirective = {
    if (needed == -1) {
      if (stash.length < Frame.HEADER_LENGTH) {
        pullOrFinish(ctx)
      } else {
        coder.decodeHeader.run(stash) match {
          case -\/(error) =>
            ctx.fail(error)
          case \/-((remaining, (nextFrameCoder, remainingFrameLen))) =>
            stash = remaining
            needed = remainingFrameLen
            this.nextFrameCoder = Some(nextFrameCoder)
            run(ctx)
        }
      }
    } else if (stash.length < needed) {
      pullOrFinish(ctx)
    } else if(clientPrefaceNeeded) {
      val (preface, rest) = stash.splitAt(needed)
      if(preface != CONNECTION_PREFACE)
        ctx.fail(NotHTTP2Exception())
      else {
        clientPrefaceNeeded = false
        settingsFrameNeeded = true
        needed = -1
        stash = rest
        run(ctx)
      }
    } else {
      nextFrameCoder.get.run(stash) match {
        case -\/(error) =>
          ctx.fail(error)
        case \/-((remaining, frame)) =>
          if(settingsFrameNeeded && !frame.isInstanceOf[Settings])
            ctx.fail(NotHTTP2Exception())
          else {
            settingsFrameNeeded = false
            stash = remaining
            needed = -1
            this.nextFrameCoder = None
            ctx.push(frame)
          }
      }
    }
  }

  private def pullOrFinish(ctx: Context[Frame]): SyncDirective = {
    if (ctx.isFinishing) ctx.finish()
    else ctx.pull()
  }
}

object FrameDecoderStage {
  case class NotHTTP2Exception(message: String = null, cause: Throwable = null)
    extends RuntimeException(message, cause)

  // "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" — the RFC 7540 client connection preface.
  final val CONNECTION_PREFACE = hex_bs"505249202a20485454502f322e300d0a0d0a534d0d0a0d0a"
}

-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/FrameEncoderStage.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import akka.stream.stage._
import akka.util.ByteString

import net.danielkza.http2.protocol.Frame
import net.danielkza.http2.protocol.coders.FrameCoder

/** Encodes HTTP/2 frames into their wire representation, failing the stage on
  * any encoding error. */
class FrameEncoderStage extends PushStage[Frame, ByteString] {
  private val coder = new FrameCoder

  override def onPush(frame: Frame, ctx: Context[ByteString]): SyncDirective = {
    coder.encodeS(frame).exec(ByteString.newBuilder).leftMap { error =>
      ctx.fail(error.toException())
    }.map { data =>
      ctx.push(data.result())
    }.merge
  }
}

-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/HeaderCollapseStage.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import akka.stream.stage._
import net.danielkza.http2.protocol.Frame
import net.danielkza.http2.protocol.Frame.{Headers, PushPromise, Continuation}
import net.danielkza.http2.protocol.HTTP2Error.ContinuationError

/** Collapses HEADERS/PUSH_PROMISE + CONTINUATION sequences into a single frame
  * with END_HEADERS set. A CONTINUATION outside such a sequence, or any other
  * frame inside one, is a protocol error. */
class HeaderCollapseStage extends StatefulStage[Frame, Frame] {
  object Passthrough extends State {
    override def onPush(frame: Frame, ctx: Context[Frame]): SyncDirective = {
      frame match {
        case h: Headers if !h.endHeaders =>
          become(Continue(Left(h)))
          ctx.pull()
        case p: PushPromise if !p.endHeaders =>
          become(Continue(Right(p)))
          ctx.pull()
        case c: Continuation =>
          // CONTINUATION without a preceding open header sequence.
          ctx.fail(ContinuationError().toException)
        case _ =>
          ctx.push(frame)
      }
    }
  }

  /** Accumulating state: gathers CONTINUATION fragments for `initialFrame`. */
  case class Continue(initialFrame: Either[Headers, PushPromise]) extends State {
    var headerBlock = initialFrame.left.map(_.headerFragment).right.map(_.headerFragment).merge
    val stream = initialFrame.merge.stream

    // Rebuilds the initial frame with the full header block and END_HEADERS set.
    def collapse: Frame = {
      initialFrame
        .left.map(_.copy(headerFragment = headerBlock, endHeaders = true))
        .right.map(_.copy(headerFragment = headerBlock, endHeaders = true))
        .merge
    }

    override def onPush(frame: Frame, ctx: Context[Frame]): SyncDirective = {
      frame match {
        case Continuation(`stream`, fragment, endHeaders) =>
          headerBlock ++= fragment
          if(endHeaders) {
            become(Passthrough)
            ctx.push(collapse)
          } else {
            ctx.pull()
          }
        case _ =>
          // Only CONTINUATION frames on the same stream may interleave here.
          ctx.fail(ContinuationError().toException)
      }
    }
  }

  override def initial = Passthrough
}

-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/HeaderDecodeActor.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import scala.language.implicitConversions
import scalaz.{-\/, \/-}
import akka.actor._
import net.danielkza.http2.protocol.HTTP2Error

/** Actor that HPACK-decodes header block fragments, replying with `Headers` on
  * success or `Failure` on any decoding problem. */
class HeaderDecodeActor(initialMaxTableSize: Int)
  extends HeaderTransformActorBase(initialMaxTableSize)
{
  import HeaderDecodeActor._
  import HeaderTransformActorBase._

  override def receive: Receive = {
    case Fragment(block) =>
      headerBlockCoder.decode(block) match {
        case \/-((headers, readLen)) if readLen == block.length =>
          sender ! Headers(headers)
        case \/-((_, readLen)) =>
          // BUG FIX: a successful decode that did not consume the whole block
          // previously fell through this match and crashed with a MatchError.
          // Treat the trailing garbage as a compression error instead.
          sender ! Failure(compressionError(s"only $readLen of ${block.length} bytes consumed"))
        case -\/(error) =>
          sender ! Failure(compressionError(error.toString))
      }
    case SetTableMaxSize(size) =>
      if(size < 0) {
        sender ! Failure(settingsError)
      } else {
        updateHeaderBlockCoder(_.withMaxCapacity(size))
        sender !
OK
      }
  }
}

object HeaderDecodeActor {
  def props(initialMaxTableSize: Int): Props = Props(new HeaderDecodeActor(initialMaxTableSize))

  private [HeaderDecodeActor] def compressionError(message: String) =
    HTTP2Error.CompressionError().withErrorMessage(s"Header compression failure in decoding: $message")

  private [HeaderDecodeActor] lazy val settingsError =
    HTTP2Error.SettingsError().withErrorMessage("Invalid SETTINGS_MAX_HEADER_TABLE_SIZE")
}

-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/HeaderEncodeActor.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import akka.actor.Props
import net.danielkza.http2.protocol.HTTP2Error

import scala.language.implicitConversions
import scalaz.{-\/, \/-}

/** Actor that HPACK-encodes header lists, replying with a `Fragment` on
  * success or `Failure` on any encoding problem.
  *
  * FIX: errors are now reported with the protocol's `Failure(error)` message,
  * consistent with [[HeaderDecodeActor]]; previously raw `Throwable`s were
  * sent to the sender, which no `Message` handler would match.
  */
class HeaderEncodeActor(initialMaxTableSize: Int)
  extends HeaderTransformActorBase(initialMaxTableSize)
{
  import HeaderEncodeActor._
  import HeaderTransformActorBase._

  override def receive: Receive = {
    case Headers(headers) =>
      headerBlockCoder.encode(headers) match {
        case -\/(error) => sender ! Failure(compressionError)
        case \/-(block) => sender ! Fragment(block)
      }
    case SetTableMaxSize(size) =>
      if(size < 0) {
        sender ! Failure(settingsError)
      } else {
        updateHeaderBlockCoder(_.withMaxCapacity(size))
        sender ! OK
      }
  }
}

object HeaderEncodeActor {
  def props(initialMaxTableSize: Int): Props = Props(new HeaderEncodeActor(initialMaxTableSize))

  private [HeaderEncodeActor] lazy val compressionError =
    HTTP2Error.CompressionError().withErrorMessage("Header compression failure in encoding")

  private [HeaderEncodeActor] lazy val settingsError =
    HTTP2Error.SettingsError().withErrorMessage("Invalid SETTINGS_MAX_HEADER_TABLE_SIZE")
}

-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/HeaderSplitStage.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import akka.stream.stage._
import net.danielkza.http2.protocol.Frame
import net.danielkza.http2.protocol.Frame.{Headers, PushPromise, Continuation}

/** Splits HEADERS/PUSH_PROMISE frames whose encoded size exceeds
  * `maxFrameSize` into an initial frame plus CONTINUATION frames, the last of
  * which carries END_HEADERS. Frames that already fit pass through untouched. */
class HeaderSplitStage(var maxFrameSize: Int = Frame.DEFAULT_MAX_FRAME_SIZE) extends StatefulStage[Frame, Frame] {
  case object Splitting extends State {
    override def onPush(frame: Frame, ctx: Context[Frame]): SyncDirective = {
      val frames = frame match {
        case h: Headers if h.endHeaders =>
          splitFrame(Left(h))
        case p: PushPromise if p.endHeaders =>
          splitFrame(Right(p))
        case _ =>
          return ctx.push(frame)
      }

      emit(frames.iterator, ctx)
    }
  }

  def initial = Splitting

  /** Payload budget of a CONTINUATION frame (frame header excluded). */
  def maxContinuationFragmentSize = maxFrameSize - Frame.HEADER_LENGTH

  /** Computes the on-wire size of `frame` and, if it exceeds `maxFrameSize`,
    * splits its header block across the initial frame and CONTINUATIONs. */
  def splitFrame(frame: Either[Headers, PushPromise]): Seq[Frame] = {
    // Account for the frame header plus per-type extras (padding, priority,
    // promised-stream id) to find how much header block fits in frame one.
    var size = Frame.HEADER_LENGTH
    val block = frame.left.map { h =>
      h.padding.foreach { size += _.length + 1 }
      h.streamDependency.foreach { _ => size += 5 }

      h.headerFragment
    }.right.map { p =>
      p.padding.foreach { size += _.length + 1 }
      size += 4

      p.headerFragment
    }.merge

    size += block.length
    if(size <= maxFrameSize)
      return Seq(frame.merge)

    // NOTE(review): if the fixed overhead alone exceeds maxFrameSize this is
    // negative and splitAt clamps it to 0 (empty first fragment) — confirm
    // that is the intended behavior.
    val headFragSize = block.length - (size - maxFrameSize)
    val split = block.splitAt(headFragSize)

    val headFrag = split._1
    val headFrame = frame.left.map { h =>
      h.copy(headerFragment = headFrag, endHeaders = false)
    }.right.map { p =>
      p.copy(headerFragment = headFrag, endHeaders = false)
    }.merge

    var rest = split._2
    val fragSize = maxContinuationFragmentSize

    val frames = Seq.newBuilder[Frame]
    frames.sizeHint(1 + (rest.length / fragSize))
    frames += headFrame

    while(rest.length > fragSize) {
      val split = rest.splitAt(fragSize)
      frames += Continuation(headFrame.stream, split._1, endHeaders = false)
      rest = split._2
    }

    frames += Continuation(headFrame.stream, rest, endHeaders = true)
    frames.result()
  }
}

-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/HeaderTransformActorBase.scala: --------------------------------------------------------------------------------

package net.danielkza.http2.stream

import akka.actor.Actor
import akka.util.ByteString
import net.danielkza.http2.api.Header
import net.danielkza.http2.hpack.coders.HeaderBlockCoder
import net.danielkza.http2.protocol.HTTP2Error

/** Base for the header encode/decode actors: owns the HPACK header block coder
  * lifecycle (created in preStart, released in postStop). */
abstract class HeaderTransformActorBase(val initialMaxTableSize: Int) extends Actor {
  // Initialized in preStart; cleared in postStop so the table can be collected.
  private var _headerBlockCoder: HeaderBlockCoder = null

  override def postStop(): Unit = {
    _headerBlockCoder = null
    super.postStop()
  }

  override def preStart(): Unit = {
    super.preStart()
    _headerBlockCoder = new HeaderBlockCoder(initialMaxTableSize)
  }

  /** The current coder; fails if the actor has not been started yet. */
  def headerBlockCoder: HeaderBlockCoder = {
    if(_headerBlockCoder == null)
      throw new IllegalStateException("Header block coder uninitialized")

_headerBlockCoder 27 | } 28 | 29 | def updateHeaderBlockCoder(f: HeaderBlockCoder => HeaderBlockCoder): Unit = { 30 | val newCoder = f(headerBlockCoder) 31 | if(newCoder == null) 32 | throw new IllegalArgumentException("Header block coder cannot be made null") 33 | 34 | _headerBlockCoder = newCoder 35 | } 36 | } 37 | 38 | object HeaderTransformActorBase { 39 | sealed trait Message 40 | case class Fragment(block: ByteString) extends Message 41 | case class Headers(headers: Seq[Header]) extends Message 42 | case class SetTableMaxSize(size: Int) extends Message 43 | case class Failure(error: HTTP2Error) extends Message 44 | case object OK extends Message 45 | } 46 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/Http2Message.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.stream 2 | 3 | import akka.util.ByteString 4 | import akka.http.scaladsl.model.{HttpRequest, HttpEntity} 5 | import net.danielkza.http2.api.Header 6 | 7 | private[http2] case class Http2Message(dataStream: Int, responseStream: Int, body: HttpEntity, 8 | headers: Seq[Header], promise: Option[Http2Message]= None, 9 | trailers: Seq[String] = Seq.empty) 10 | 11 | private[http2] object Http2Message { 12 | sealed trait Headers 13 | object Headers { 14 | case class Unencoded(headers: Seq[Header]) extends Headers 15 | case class Encoded(fragment: ByteString) extends Headers 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/InboundStreamDispatcher.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.stream 2 | 3 | import scala.collection.JavaConverters._ 4 | import scala.concurrent.ExecutionContext 5 | import scala.concurrent.duration._ 6 | import scala.util.Try 7 | import scalaz.\/ 8 | import 
scalaz.syntax.either._ 9 | import akka.util.Timeout 10 | import akka.actor._ 11 | import akka.pattern.ask 12 | import akka.stream._ 13 | import akka.stream.stage._ 14 | import net.danielkza.http2.util.ArrayQueue 15 | import net.danielkza.http2.protocol.{HTTP2Error, Http2Stream, Frame} 16 | 17 | class InboundStreamDispatcher(val outputPorts: Int, val bufferSize: Int, manager: => ActorRef) 18 | (implicit ec: ExecutionContext, timeout: Timeout = 5.seconds) 19 | extends GraphStage[UniformFanOutShape[Frame, Frame]] 20 | { 21 | import HTTP2Error._ 22 | import StreamManagerActor._ 23 | import Http2Stream._ 24 | 25 | val in: Inlet[Frame] = Inlet[Frame]("InboundStreamDispatcher.in") 26 | val outs: Array[Outlet[Frame]] = Array.tabulate(outputPorts) { i => 27 | Outlet[Frame]("InboundFrameDispatcher.out" + i) 28 | } 29 | 30 | override val shape = UniformFanOutShape[Frame, Frame](in, outs: _*) 31 | 32 | override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { logic => 33 | private val queues = Array.fill(outputPorts)(new ArrayQueue[Frame](bufferSize)) 34 | private val demands = Array.fill(outputPorts)(false) 35 | private val assignedStreams = Array.fill(outputPorts)(0) 36 | private val streamManager = manager 37 | private var finishing = -1 38 | 39 | private val resultCallback = getAsyncCallback[Try[OutMessage]] { 40 | case scala.util.Success(reply) => reply match { 41 | case Failure(error) => 42 | failStage(error) 43 | case IncomingAction(stream, Skip) => 44 | // skip 45 | case IncomingAction(stream, Continue(frame)) => 46 | getOrAssignPort(stream).map { port => 47 | queues(port).add(frame) 48 | dispatchAll() 49 | }.valueOr { error => failStage(error) } 50 | case IncomingAction(stream, Finish(frame)) => 51 | getOrAssignPort(stream).map { port => 52 | queues(port).add(frame) 53 | dispatchAll() 54 | assignedStreams(port) = 0 55 | }.valueOr { error => failStage(error) } 56 | case IncomingAction(stream, Stop) => 57 | // pass 58 | 
case _ => 59 | failStage(new RuntimeException("Unexpected response from StreamManagerActor")) 60 | } 61 | case scala.util.Failure(error) => 62 | failStage(error) 63 | } 64 | 65 | private val errorCallback = getAsyncCallback[Throwable](failStage) 66 | 67 | private def getOrAssignPort(stream: Int): \/[HTTP2Error, Int] = { 68 | var i: Int = 0 69 | var free: Int = -1 70 | while(i < outputPorts) { 71 | if(assignedStreams(i) == stream) return i.right 72 | if(free == -1 && assignedStreams(i) == 0) free = i 73 | 74 | i += 1 75 | } 76 | 77 | if(free != -1) { 78 | assignedStreams(free) = stream 79 | free.right 80 | } else { 81 | RefusedStream().left 82 | } 83 | } 84 | 85 | private def getAssignedPort(stream: Int): \/[HTTP2Error, Int] = { 86 | assignedStreams.indexOf(stream) match { 87 | case -1 => InvalidStream().left 88 | case i => i.right 89 | } 90 | } 91 | 92 | private def backed: Boolean = 93 | queues.exists(_.isFull) 94 | 95 | private def enqueueIn(): Unit = { 96 | val frame = grab(in) 97 | println(s"InboundStreamDispatcher: enqueueIn $frame"); System.out.flush() 98 | 99 | (streamManager ? 
IncomingFrame(frame)).mapTo[OutMessage].onComplete(resultCallback.invoke) 100 | } 101 | 102 | private def dispatchAll(): Unit = { 103 | var port: Int = 0 104 | while(port < outputPorts) { 105 | val out = outs(port) 106 | 107 | val queue = queues(port) 108 | if(demands(port) && !queue.isEmpty) { 109 | val frame = queue.remove() 110 | println(s"InboundStreamDispatcher: pushing $frame to $port"); System.out.flush() 111 | push(out, frame) 112 | demands(port) = false 113 | } 114 | 115 | port += 1 116 | } 117 | } 118 | 119 | def finishIn(): Unit = { 120 | println(s"InboundStreamDispatcher: finishing"); System.out.flush() 121 | finishing = outputPorts 122 | 123 | var port: Int = 0 124 | while(port < outputPorts) { 125 | val out = outs(port) 126 | val queue = queues(port) 127 | 128 | if(!queue.isEmpty) 129 | emitMultiple(out, queue.iterator().asScala, () => completeOut(out)) 130 | else 131 | completeOut(out) 132 | 133 | port += 1 134 | } 135 | } 136 | 137 | private def completeOut(out: Outlet[Frame]): Unit = { 138 | complete(out) 139 | 140 | finishing -= 1 141 | if(finishing == 0) { 142 | completeStage() 143 | finishing = -1 144 | } 145 | } 146 | 147 | setHandler(in, new InHandler { 148 | override def onPush(): Unit = { 149 | enqueueIn() 150 | if(!isClosed(in) && !hasBeenPulled(in) && demands.exists(b => b) && !backed) { 151 | println(s"InboundStreamDispatcher: pulling inlet after push"); System.out.flush() 152 | pull(in) 153 | } 154 | } 155 | 156 | override def onUpstreamFinish(): Unit = 157 | finishIn() 158 | }) 159 | 160 | for(port <- 0 until outputPorts) { 161 | val out = outs(port) 162 | setHandler(out, new OutHandler { 163 | override def onPull(): Unit = { 164 | println(s"InboundStreamDispatcher: $port pulled"); System.out.flush() 165 | queues(port).poll() match { 166 | case null => 167 | demands(port) = true 168 | if(!hasBeenPulled(in) && !backed) { 169 | println(s"InboundStreamDispatcher: pulling inlet"); System.out.flush() 170 | pull(in) 171 | } 172 | case elm => 
173 | push(out, elm)
174 | }
175 | }
176 | 
// An outlet completing before finishIn() has begun shutdown (finishing < 0) is
// treated as a stage-level failure; afterwards it is normal completion.
177 | override def onDownstreamFinish(): Unit = {
178 | if(finishing < 0) {
179 | failStage(new RuntimeException(s"Outlet finished too early: $out"))
180 | } else {
181 | super.onDownstreamFinish()
182 | }
183 | }
184 | })
185 | }
186 | 
// NOTE(review): the eager initial pull is deliberately commented out; the first
// downstream onPull drives the inlet instead. Confirm this is intentional.
187 | override def preStart(): Unit = () //pull(in)
188 | }
189 | 
190 | override def toString = "InboundStreamDispatcher"
191 | }
192 | 
-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/NormalDataDecodeStage.scala: --------------------------------------------------------------------------------
1 | package net.danielkza.http2.stream
2 | 
3 | import akka.util.ByteString
4 | import akka.stream.stage.{SyncDirective, Context, PushStage}
5 | import net.danielkza.http2.protocol.Frame
6 | import net.danielkza.http2.protocol.Frame.Data
7 | 
// Push-through stage that reduces an HTTP/2 DATA-frame stream to its payload
// bytes, finishing the stage when a frame carries the END_STREAM flag.
// NOTE(review): the match below handles only Frame.Data — any other frame type
// raises a MatchError at runtime, and the `ignoreNonData` flag is declared but
// never consulted. Confirm whether non-DATA frames can reach this stage; if so
// they should be skipped (ctx.pull()) when ignoreNonData is set, or rejected
// with a proper protocol error instead of a MatchError.
8 | private class NormalDataDecodeStage(val ignoreNonData: Boolean = false) extends PushStage[Frame, ByteString] {
9 | override def onPush(frame: Frame, ctx: Context[ByteString]): SyncDirective = {
10 | frame match {
11 | case d: Data if d.endStream => ctx.pushAndFinish(d.data)
12 | case d: Data => ctx.push(d.data)
13 | }
14 | }
15 | }
16 | 
-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/StreamManagerActor.scala: --------------------------------------------------------------------------------
1 | package net.danielkza.http2.stream
2 | 
3 | import scalaz.{\/-, -\/}
4 | import akka.actor._
5 | import akka.pattern.pipe
6 | import net.danielkza.http2.protocol._
7 | import net.danielkza.http2.protocol.HTTP2Error._
8 | 
// Actor wrapper around a StreamManager: all stream-state mutation is serialized
// through the actor mailbox. `managerF` is by-name, so a fresh manager is built
// in preStart on the actor's own thread (and rebuilt after a restart).
9 | class StreamManagerActor(managerF: => StreamManager) extends Actor {
10 | import StreamManagerActor._
11 | import StreamManager._
12 | import Http2Stream._
13 | 
// Initialized in preStart, cleared in postStop.
14 | private var manager: StreamManager = null
15 | 
16 | override def preStart(): Unit = {
17 | super.preStart()
18 | manager = 
managerF
19 | }
20 | 
// Close every managed stream and drop the reference so a restart starts clean.
21 | override def postStop(): Unit = {
22 | manager.closeAll()
23 | manager = null
24 | super.postStop()
25 | }
26 | 
// Message protocol: each InMessage is answered with one OutMessage to the
// sender, except StreamProcessingStarted which sends no reply.
27 | def receive: Receive = { case i: InMessage => i match {
28 | case IncomingFrame(frame) =>
29 | manager.receive(frame) match {
30 | case -\/(error) =>
31 | sender ! Failure(error)
32 | case \/-(ReceiveReply(stream, action)) =>
33 | sender ! IncomingAction(stream.id, action)
34 | }
35 | case OutgoingFrame(frame) =>
36 | manager.send(frame) match {
37 | case -\/(error) =>
38 | sender ! Failure(error)
// A Delay reply means the payload is not ready yet: map the future to the
// eventual action and pipe it back. `sender` is captured into a val first,
// since evaluating it inside the future's callbacks would be unsafe.
39 | case \/-(SendReply(stream, Delay(dataFuture))) =>
40 | implicit val ec = context.dispatcher
41 | val s = sender
42 | dataFuture.map {
43 | case d if d.endStream =>
44 | OutgoingAction(stream.id, Finish(d))
45 | case d =>
46 | OutgoingAction(stream.id, Continue(d))
47 | }.recover {
48 | case e: HTTP2Exception =>
49 | Failure(e.error)
// NOTE(review): this arm catches every Throwable, including fatal errors;
// consider restricting it with scala.util.control.NonFatal.
50 | case e =>
51 | Failure(InternalError(s"Unexpected exception: $e"))
52 | } to s
53 | case \/-(SendReply(stream, action))=>
54 | sender ! OutgoingAction(stream.id, action)
55 | }
56 | case ReserveStream =>
57 | manager.reserveOut() match {
58 | case Some(stream) => sender ! Reserved(stream.id)
59 | case None => sender ! Failure(RefusedStream())
60 | }
// Fire-and-forget: records progress, no reply is sent.
61 | case StreamProcessingStarted(streamId) =>
62 | manager.markProcessed(streamId)
63 | case GetLastProcessedStream =>
64 | sender ! 
LastProcessedStream(manager.lastProcessedId)
65 | }}
66 | }
67 | 
// Companion: the request (InMessage) / reply (OutMessage) protocol spoken by
// StreamManagerActor, plus its Props factory.
68 | object StreamManagerActor {
69 | import Http2Stream._
70 | 
71 | sealed trait InMessage
72 | case class IncomingFrame(frame: Frame) extends InMessage
73 | case class OutgoingFrame(frame: Frame) extends InMessage
74 | case object ReserveStream extends InMessage
75 | case class StreamProcessingStarted(stream: Int) extends InMessage
76 | case object GetLastProcessedStream extends InMessage
77 | 
78 | sealed trait OutMessage
79 | case class Failure(error: HTTP2Error) extends OutMessage
80 | case class Reserved(stream: Int) extends OutMessage
81 | case class IncomingAction(stream: Int, action: ReceiveAction) extends OutMessage
82 | case class OutgoingAction(stream: Int, action: SendAction) extends OutMessage
83 | case class LastProcessedStream(stream: Int) extends OutMessage
84 | 
// `manager` is by-name so construction is deferred to the actor instance.
85 | def props(manager: => StreamManager): Props =
86 | Props(new StreamManagerActor(manager))
87 | }
88 | 
-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/stream/package.scala: --------------------------------------------------------------------------------
1 | package net.danielkza.http2
2 | 
3 | import akka.stream._
4 | import akka.stream.scaladsl._
5 | import akka.util.ByteString
6 | import net.danielkza.http2.protocol.Frame
7 | 
8 | package object stream {
// Splits each incoming sub-source into its first element plus the remaining
// tail; sub-sources that complete without emitting anything are dropped.
9 | private[stream] def headAndTailFlow[T]: Flow[Source[T, Any], (T, Source[T, Unit]), Unit] =
10 | Flow[Source[T, Any]]
11 | .flatMapConcat {
12 | _.prefixAndTail(1)
13 | .filter(_._1.nonEmpty)
14 | .map { case (prefix, tail) ⇒ (prefix.head, tail) }
15 | }
16 | }
17 | 
-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/util/ArrayQueue.scala: --------------------------------------------------------------------------------
1 | package net.danielkza.http2.util
2 | 
3 | import java.util
4 | import scala.reflect.ClassTag
5 | 
6 | class 
// Fixed-capacity circular FIFO backed by a plain array. Not thread-safe.
// Follows java.util.Queue semantics: offer returns false when full, peek/poll
// return null when empty.
ArrayQueue[T <: AnyRef : ClassTag](val maxCapacity: Int) extends util.AbstractQueue[T] { self =>
7 | private val backing = Array.ofDim[T](maxCapacity)
8 | private var firstIndex = 0
9 | private var curSize = 0
10 | 
// Slot where the next offered element lands (wraps around the array).
11 | private def insertionIndex: Int =
12 | (firstIndex + curSize) % maxCapacity
13 | 
14 | override def offer(e: T): Boolean = {
15 | if(size == maxCapacity)
16 | false
17 | else {
18 | backing(insertionIndex) = e
19 | curSize += 1
20 | true
21 | }
22 | }
23 | 
24 | override def peek(): T = {
25 | if(size == 0)
26 | null.asInstanceOf[T]
27 | else
28 | backing(firstIndex)
29 | }
30 | 
31 | override def poll(): T = {
32 | if(size == 0) {
33 | null.asInstanceOf[T]
34 | } else {
35 | val elm = backing(firstIndex)
// Clear the vacated slot so the queue does not pin dequeued elements for GC.
36 | backing(firstIndex) = null.asInstanceOf[T]
37 | firstIndex = (firstIndex + 1) % maxCapacity
38 | curSize -= 1
39 | elm
40 | }
41 | }
42 | 
43 | override def size(): Int = curSize
44 | 
45 | def isFull: Boolean = size() == maxCapacity
46 | 
// FIFO-order iterator over the current contents. Counts down the remaining
// element count instead of comparing against insertionIndex: when the queue is
// full, insertionIndex == firstIndex, so the previous end-index comparison
// made a *full* queue iterate as if it were empty.
47 | override def iterator(): util.Iterator[T] = new util.Iterator[T] {
48 | private var remaining = curSize
49 | private var cur = firstIndex
50 | 
51 | override def hasNext: Boolean = remaining > 0
52 | 
53 | override def next(): T = {
54 | val elm = self.backing(cur)
55 | cur = (cur + 1) % maxCapacity
56 | remaining -= 1
57 | elm
58 | }
59 | }
60 | }
61 | 
-------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/util/Implicits.scala: --------------------------------------------------------------------------------
1 | package net.danielkza.http2.util
2 | 
3 | import scala.language.experimental.macros
4 | import java.io.{IOException, OutputStream}
5 | import java.nio.ByteBuffer
6 | import scalaz.\/
7 | import akka.util.ByteString
8 | import net.danielkza.http2.macros.{BitPatterns => BitPatternsMacros}
9 | 
10 | trait Implicits {
11 | implicit class BytePattern(sc: StringContext) {
12 | object bin extends {
13 | def unapply(x: Byte): Boolean = macro BitPatternsMacros.binaryLiteralExtractorImpl[Byte]
14 | def unapply(x: 
Short): Boolean = macro BitPatternsMacros.binaryLiteralExtractorImpl[Short] 15 | def unapply(x: Int): Boolean = macro BitPatternsMacros.binaryLiteralExtractorImpl[Int] 16 | def unapply(x: Long): Boolean = macro BitPatternsMacros.binaryLiteralExtractorImpl[Long] 17 | 18 | def apply(args: Any*) = macro BitPatternsMacros.binaryLiteralImpl[Int] 19 | } 20 | 21 | def bin_b(args: Any*): Byte = macro BitPatternsMacros.binaryLiteralImpl[Byte] 22 | def bin_s(args: Any*): Short = macro BitPatternsMacros.binaryLiteralImpl[Short] 23 | def bin_i(args: Any*): Int = macro BitPatternsMacros.binaryLiteralImpl[Int] 24 | def bin_l(args: Any*): Long = macro BitPatternsMacros.binaryLiteralImpl[Long] 25 | } 26 | 27 | private def byteStringFromIntegers(str: String, groupSize: Int, radix: Int): ByteString = { 28 | val cleanStr = str.replaceAll("\\s", "") 29 | val builder = ByteString.newBuilder 30 | builder.sizeHint(cleanStr.length / groupSize) 31 | 32 | cleanStr.grouped(groupSize).map(s => Integer.parseUnsignedInt(s, radix)).foreach { byte => 33 | builder.putByte(byte.toByte) 34 | } 35 | 36 | builder.result() 37 | } 38 | 39 | implicit class ByteStringConversion(sc: StringContext) { 40 | def u8_bs(args: Any*): ByteString = 41 | ByteString(sc.s(args: _*), "UTF-8") 42 | 43 | def bs(args: Any*): ByteString = u8_bs(args: _*) 44 | 45 | 46 | def bits_bs(args: Any*): ByteString = 47 | byteStringFromIntegers(sc.s(args: _*), groupSize=8, radix=2) 48 | 49 | def hex_bs(args: Any*): ByteString = 50 | byteStringFromIntegers(sc.s(args: _*), groupSize=2, radix=16) 51 | } 52 | 53 | implicit class StringByteStringOps(s: String) { 54 | def byteStringFromBits: ByteString = 55 | byteStringFromIntegers(s, groupSize=8, radix=2) 56 | 57 | def byteStringFromHex: ByteString = 58 | byteStringFromIntegers(s, groupSize=2, radix=16) 59 | } 60 | 61 | implicit class RichOutputStream(stream: OutputStream) { 62 | @throws[IOException] def write(buffer: ByteBuffer): Unit = { 63 | val pos = buffer.position() 64 | 
if(buffer.hasArray) { 65 | stream.write(buffer.array(), buffer.arrayOffset() + pos, buffer.limit() - pos) 66 | } else { 67 | while(buffer.hasRemaining) { 68 | stream.write(buffer.get()) 69 | } 70 | 71 | buffer.position(pos) 72 | } 73 | } 74 | 75 | @throws[IOException] def write(string: ByteString): Unit = { 76 | string.asByteBuffers.foreach(write) 77 | } 78 | } 79 | 80 | implicit class DisjunctionWithThrow[E, V](self: \/[E, V])(implicit ev: E => Throwable) { 81 | def orThrow: V = self.leftMap { e => throw ev(e) }.merge 82 | } 83 | } 84 | 85 | object Implicits extends Implicits 86 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/util/package.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2 2 | 3 | package object util extends Implicits 4 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/util/stream/Concentrator.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.util.stream 2 | 3 | import akka.stream._ 4 | import akka.stream.scaladsl._ 5 | import akka.stream.stage.Stage 6 | 7 | object Concentrator { 8 | def fromFlow[I, O, M](n: Int, flow: Flow[I, O, M], bufferSize: Int = 1, 9 | overflowStrategy: OverflowStrategy = OverflowStrategy.backpressure) 10 | : Graph[ConcentratorShape[I, O], M] = 11 | { 12 | FlowGraph.create(flow) { implicit b => flow => 13 | import FlowGraph.Implicits._ 14 | 15 | val inputs = Vector.tabulate(n) { i => 16 | b.add(Flow[I].map { in => (in, i) }) 17 | } 18 | 19 | val outputs = Vector.tabulate(n) { i => 20 | b.add(Flow[(O, Int)].collect { case (out, `i`) => out }.buffer(bufferSize, overflowStrategy)) 21 | } 22 | 23 | val inMerge = b.add(Merge[(I, Int)](n)) 24 | val outBroadcast = b.add(Broadcast[(O, Int)](n)) 25 | 26 | val indexBypassIn = b.add(Unzip[I, Int]) 27 | 
val indexBypassOut = b.add(Zip[O, Int]) 28 | 29 | for(i <- 0 until n) { 30 | inputs(i).outlet ~> inMerge.in(i) 31 | } 32 | 33 | inMerge.out ~> indexBypassIn.in 34 | indexBypassIn.out0 ~> flow ~> indexBypassOut.in0 35 | indexBypassIn.out1 ~> indexBypassOut.in1 36 | indexBypassOut.out ~> outBroadcast 37 | 38 | for(i <- 0 until n) { 39 | outBroadcast.out(i) ~> outputs(i).inlet 40 | } 41 | 42 | new ConcentratorShape(inputs.map(_.inlet), outputs.map(_.outlet)) 43 | } 44 | } 45 | 46 | def fromStage[I, O](n: Int, stage: () => Stage[I, O]): Graph[ConcentratorShape[I, O], Unit] = 47 | fromFlow(n, Flow[I].transform(stage)) 48 | } 49 | -------------------------------------------------------------------------------- /core/src/main/scala/net/danielkza/http2/util/stream/ConcentratorShape.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.util.stream 2 | 3 | import akka.stream._ 4 | 5 | import scala.annotation.unchecked.uncheckedVariance 6 | import scala.collection.immutable 7 | 8 | class ConcentratorShape[-I, +O] private[http2] ( 9 | override val inlets: immutable.Seq[Inlet[I @uncheckedVariance]], 10 | override val outlets: immutable.Seq[Outlet[O @uncheckedVariance]]) 11 | extends Shape 12 | { 13 | def this(n: Int) = 14 | this(Vector.tabulate(n) { i => Inlet("Concentrator.in" + i)}, 15 | Vector.tabulate(n) { i => Outlet("Concentrator.out" + i)}) 16 | 17 | override def deepCopy(): ConcentratorShape[I, O] = 18 | new ConcentratorShape(inlets.map(_.carbonCopy()), outlets.map(_.carbonCopy())) 19 | 20 | override def copyFromPorts(inlets: immutable.Seq[Inlet[_]], outlets: immutable.Seq[Outlet[_]]): Shape = { 21 | require(inlets.nonEmpty, s"Empty inlets or outlets") 22 | require(inlets.size == outlets.size, 23 | s"Non-matching count of inlets [${inlets.mkString(", ")}] and outlets [${outlets.mkString(", ")}]") 24 | new ConcentratorShape(inlets.asInstanceOf[immutable.Seq[Inlet[I]]], 25 | 
outlets.asInstanceOf[immutable.Seq[Outlet[O]]]) 26 | } 27 | 28 | def in(i: Int): Inlet[I @uncheckedVariance] = inlets(i) 29 | def out(i: Int): Outlet[O @uncheckedVariance]= outlets(i) 30 | def flow(i: Int): FlowShape[I, O] = FlowShape(inlets(i), outlets(i)) 31 | } 32 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/AkkaStreamsTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2 2 | 3 | import akka.actor.ActorSystem 4 | import akka.stream.{ActorMaterializerSettings, ActorMaterializer} 5 | import com.typesafe.config.ConfigFactory 6 | import org.specs2.mutable.SpecificationLike 7 | 8 | abstract class AkkaStreamsTest( 9 | _system: ActorSystem = ActorSystem("AkkaStreamTest", ConfigFactory.parseString(AkkaStreamsTest.config)) 10 | ) extends AkkaTest(_system) 11 | { self: SpecificationLike => 12 | implicit val actorMaterializer = ActorMaterializer(ActorMaterializerSettings(system)) 13 | } 14 | 15 | object AkkaStreamsTest { 16 | val config = """ 17 | akka { 18 | test { 19 | single-expect-default = 10 seconds 20 | } 21 | } 22 | """ 23 | } 24 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/AkkaTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2 2 | 3 | import org.specs2.specification.core._ 4 | import org.specs2.mutable.{After, SpecificationLike} 5 | import akka.actor.ActorSystem 6 | import akka.testkit.{ImplicitSender, TestKit} 7 | 8 | abstract class AkkaTest(_system: ActorSystem = ActorSystem("AkkaTest")) extends TestKit(_system) with ImplicitSender 9 | with SpecificationLike 10 | { 11 | override def map(fs: => Fragments) = super.map(fs) ^ step(system.shutdown, global = true) 12 | } 13 | -------------------------------------------------------------------------------- 
/core/src/test/scala/net/danielkza/http2/TestHelpers.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2 2 | 3 | import scala.language.implicitConversions 4 | 5 | import java.io.{InputStream, OutputStream} 6 | import akka.util.{ByteString, ByteStringBuilder} 7 | 8 | import scalaz.\/ 9 | 10 | trait TestHelpers extends util.Implicits { 11 | def inputStream(s: ByteString): InputStream = 12 | s.iterator.asInputStream 13 | 14 | def withOutputStream[T](f: OutputStream => T): ByteString = { 15 | val bytes = new ByteStringBuilder 16 | f(bytes.asOutputStream) 17 | bytes.result() 18 | } 19 | 20 | implicit def stringToByteString(s: String): ByteString = ByteString.fromString(s) 21 | 22 | implicit class DisjunctionWithThrow[A, B](self: \/[A, B]) { 23 | def getOrThrow(): B = self.getOrElse(throw new NoSuchElementException) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/hpack/DynamicTableTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.hpack 2 | 3 | import scalaz._ 4 | 5 | import org.specs2.matcher.DisjunctionMatchers 6 | import org.specs2.mutable.Specification 7 | import org.specs2.specification.Scope 8 | 9 | import net.danielkza.http2.TestHelpers 10 | 11 | class DynamicTableTest extends Specification with TestHelpers { 12 | trait Context extends Scope { 13 | var table = new DynamicTable(512, baseOffset=0) 14 | } 15 | 16 | "DynamicTableTest" should { 17 | "withCapacity" in { 18 | "should resize correctly" >> new Context { 19 | while(table.curSize <= 100) 20 | table += "Test" -> "Test" 21 | 22 | table.withCapacity(100) must beLike { case \/-(t) => t.curSize must be_<=(100) } 23 | } 24 | "should not resize over the max capacity" >> new Context { 25 | table.withCapacity(1000) must beLike { case -\/(e: HeaderError) => ok } 26 | } 27 | } 28 | 29 | 
"append" in { 30 | "should resize before insertion if necessary" >> new Context { 31 | val entry = DynamicTable.Entry(bs"Test1", bs"Test1") 32 | table.withCapacity(entry.size).map { table => 33 | table + entry + ("Test2" -> "Test2") 34 | } must beLike { case \/-(t: DynamicTable) => 35 | (t.entries must have size 1) and (t.entries.head.name === bs"Test2") 36 | } 37 | } 38 | "should leave the table empty if the new entry doesn't fit" >> new Context { 39 | table.withCapacity(1).map { table => 40 | table + ("Test1" -> "Test1") 41 | } must beLike { case \/-(t: DynamicTable) => 42 | t.entries must beEmpty 43 | } 44 | } 45 | } 46 | 47 | "find" in new Context { 48 | table += ("Test1" -> "Test1") 49 | table += ("Test2" -> "") 50 | table += ("Test1" -> "") 51 | 52 | "should find an exact match for name and value" >> { 53 | table.find(bs"Test1", Some(bs"")) === Table.FoundNameValue(3) 54 | } 55 | "should find a name match for name and value" >> { 56 | table.find(bs"Test1", Some(bs"Other")) === Table.FoundName(1) 57 | } 58 | "should find a match for name only" >> { 59 | table.find(bs"Test2", None) === Table.FoundName(2) 60 | } 61 | "should not find a match for name and value" >> { 62 | table.find(bs"Test3", Some(bs"")) === Table.NotFound 63 | } 64 | "should not find a match for name only" >> { 65 | table.find(bs"Test4", None) === Table.NotFound 66 | } 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/hpack/HeaderBlockCoderTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.hpack 2 | 3 | 4 | import scalaz._ 5 | import scalaz.std.AllInstances._ 6 | import scalaz.syntax.traverse._ 7 | import org.specs2.matcher.DataTables 8 | import org.specs2.mutable.Specification 9 | import org.specs2.specification.Scope 10 | import akka.util.ByteString 11 | import net.danielkza.http2.TestHelpers 12 | import 
net.danielkza.http2.api.Header 13 | import net.danielkza.http2.hpack.coders.{HeaderBlockCoder, HeaderCoder} 14 | 15 | class HeaderBlockCoderTest extends Specification with DataTables with TestHelpers { 16 | import HeaderRepr._ 17 | import Header.{plain, secure} 18 | 19 | def dynTablePos(x: Int) = StaticTable.default.length + x 20 | 21 | val (headers, reprs) = List( 22 | // fully indexed from static table 23 | plain (":status", "200" ) -> Indexed(8), 24 | // name indexed from static table, dt-size = 1 25 | plain (":status", "999" ) -> IncrementalLiteralWithIndexedName(14, bs"999"), 26 | // new literal, dt-size = 2 27 | plain ("fruit", "banana") -> IncrementalLiteral(bs"fruit", bs"banana"), 28 | // new literal, dt-size = 3 29 | plain ("color", "yellow") -> IncrementalLiteral(bs"color", bs"yellow"), 30 | // repeat, fully indexed from dynamic table 31 | plain ("fruit", "banana") -> Indexed(dynTablePos(2)), 32 | // name indexed from dynamic table, dt-size = 4 33 | plain ("fruit", "apple" ) -> IncrementalLiteralWithIndexedName(dynTablePos(2), bs"apple"), 34 | // repeat, fully indexed from dynamic table 35 | plain ("fruit", "apple" ) -> Indexed(dynTablePos(1)), 36 | // repeat, fully indexed from dynamic table 37 | plain ("color", "yellow") -> Indexed(dynTablePos(2)), 38 | // literal never indexed 39 | secure("drink", "soda" ) -> NeverIndexed(bs"drink", bs"soda"), 40 | // repeat literal never indexed, must not be in dynamic table 41 | secure("drink", "soda" ) -> NeverIndexed(bs"drink", bs"soda"), 42 | // literal never indexed, name indexed from dynamic table 43 | secure("color", "blue" ) -> NeverIndexedWithIndexedName(dynTablePos(2), bs"blue") 44 | ).unzip 45 | 46 | val headerCoder = new HeaderCoder(HeaderCoder.compress.Never) 47 | 48 | val encoded = { 49 | val parts: \/[HeaderError, List[ByteString]] = reprs.map(headerCoder.encode).sequenceU 50 | parts.map(_.reduce(_ ++ _)).getOrElse(throw new AssertionError) 51 | } 52 | 53 | trait Context extends Scope { 54 | val coder = 
new HeaderBlockCoder(headerCoder = headerCoder) 55 | } 56 | 57 | "HeaderBlockCoderTest" should { 58 | "encode" in { 59 | "a sequence of headers correctly" >> new Context { 60 | coder.encode(headers) must_== \/-(encoded) 61 | } 62 | } 63 | 64 | "decode" in { 65 | "a sequence of headers correctly" >> new Context { 66 | coder.decode(encoded) must_== \/-((headers, encoded.length)) 67 | } 68 | } 69 | 70 | "withDynamicTableCapacity" in { 71 | ok 72 | } 73 | 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/hpack/coders/BytesCoderTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.hpack.coders 2 | 3 | import scalaz.\/- 4 | 5 | import org.specs2.matcher.DataTables 6 | import org.specs2.mutable.Specification 7 | import net.danielkza.http2.TestHelpers 8 | 9 | class BytesCoderTest extends Specification with DataTables with TestHelpers { 10 | val cases = { 11 | "result" | "input" |> 12 | bs"custom-key" ! hex_bs"0a 6375 7374 6f6d 2d6b 6579" | 13 | bs"custom-header" ! 
hex_bs"0d 6375 7374 6f6d 2d68 6561 6465 72" 14 | } 15 | 16 | "BytesCoder" should { 17 | "decode" in { 18 | cases | { (result, input) => 19 | new BytesCoder(0).decode(input) must_== \/-((result, input.length)) 20 | } 21 | } 22 | 23 | "encode" in { 24 | cases | { (result, input) => 25 | new BytesCoder(0).encode(result) must_== \/-(input) 26 | } 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/hpack/coders/CompressedBytesCoderTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.hpack.coders 2 | 3 | import akka.util.ByteString 4 | 5 | import scalaz.\/- 6 | 7 | import org.specs2.matcher.DataTables 8 | import org.specs2.mutable.Specification 9 | import net.danielkza.http2.TestHelpers 10 | 11 | class CompressedBytesCoderTest extends Specification with DataTables with TestHelpers { 12 | val cases = { 13 | "plain" | "encoded" |> 14 | u8_bs"www.example.com" ! hex_bs"8c f1e3 c2e5 f23a 6ba0 ab90 f4ff" | 15 | u8_bs"no-cache" ! hex_bs"86 a8eb 1064 9cbf" | 16 | u8_bs"custom-key" ! hex_bs"88 25a8 49e9 5ba9 7d7f" | 17 | u8_bs"custom-value" ! hex_bs"89 25a8 49e9 5bb8 e8b4 bf" | 18 | u8_bs"text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8" ! 
19 | ByteString(-72, 73, 124, -91, -119, -45, 77, 31, 67, -82, -70, 12, 65, -92, -57, -87, -113, 51, -90, -102, 63, -33, -102, 104, -6, 29, 117, -48, 98, 13, 38, 61, 76, 121, -90, -113, -66, -48, 1, 119, -2, -115, 72, -26, 43, 30, 11, 29, 127, 95, 44, 124, -3, -10, -128, 11, -67) 20 | } 21 | 22 | "CompressedBytesCoder" should { 23 | "decode" in { 24 | cases | { (plain, encoded) => 25 | (new CompressedBytesCoder).decode(encoded) must_== \/-((plain, encoded.length)) 26 | } 27 | } 28 | 29 | "encode" in { 30 | cases | { (plain, encoded) => 31 | (new CompressedBytesCoder).encode(plain) must_== \/-(encoded) 32 | } 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/hpack/coders/HeaderCoderTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.hpack.coders 2 | 3 | import scalaz.\/- 4 | import org.specs2.matcher.DataTables 5 | import org.specs2.mutable.Specification 6 | import net.danielkza.http2.TestHelpers 7 | import net.danielkza.http2.hpack.HeaderRepr._ 8 | 9 | class HeaderCoderTest extends Specification with DataTables with TestHelpers { 10 | val plainCoder = new HeaderCoder(HeaderCoder.compress.Never) 11 | val compressedCoder = new HeaderCoder(HeaderCoder.compress.Always) 12 | 13 | val plainCases = { "header" | "encoded" |> 14 | IncrementalLiteral("custom-key", "custom-header") ! 15 | hex_bs"40 0a 6375 7374 6f6d 2d6b 6579 0d 6375 7374 6f6d 2d68 6561 6465 72" | 16 | LiteralWithIndexedName(4, "/sample/path") ! 17 | hex_bs"04 0c 2f73 616d 706c 652f 7061 7468 " | 18 | NeverIndexed("password", "secret") ! 19 | hex_bs"10 08 7061 7373 776f 7264 06 7365 6372 6574" | 20 | Indexed(2) ! 21 | hex_bs"82" | 22 | DynamicTableSizeUpdate(256) ! 23 | hex_bs"3f e1 01" 24 | } 25 | 26 | val compressedCases = { "header" | "encoded" |> 27 | IncrementalLiteralWithIndexedName(1, "www.example.com") ! 
28 | hex_bs"41 8c f1e3 c2e5 f23a 6ba0 ab90 f4ff" | 29 | IncrementalLiteral("custom-key", "custom-value") ! 30 | hex_bs"40 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf" | 31 | IncrementalLiteralWithIndexedName(33, "Mon, 21 Oct 2013 20:13:21 GMT") ! 32 | hex_bs"61 96 d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff" | 33 | IncrementalLiteralWithIndexedName(55, "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1") ! 34 | hex_bs""" 35 | 77 ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 36 | 07""" 37 | } 38 | 39 | "HeaderCoder" should { 40 | "decode" in { 41 | "plain-text" in { 42 | plainCases | { (header, encoded) => 43 | plainCoder.decode(encoded) must_== \/-((header, encoded.length)) 44 | } 45 | } 46 | "compressed" in { 47 | compressedCases | { (header, encoded) => 48 | val res = compressedCoder.decode(encoded) 49 | res must_== \/-((header, encoded.length)) 50 | } 51 | } 52 | } 53 | "encode" in { 54 | "plain-text" in { 55 | plainCases | { (header, encoded) => 56 | plainCoder.encode(header) must_== \/-(encoded) 57 | } 58 | } 59 | "compressed" in { 60 | compressedCases | { (header, encoded) => 61 | val res = compressedCoder.encode(header) 62 | res must_== \/-(encoded) 63 | } 64 | } 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/hpack/coders/IntCoderTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.hpack.coders 2 | 3 | import scalaz.\/- 4 | 5 | import org.specs2.matcher.DataTables 6 | import org.specs2.mutable.Specification 7 | import net.danielkza.http2.TestHelpers 8 | 9 | class IntCoderTest extends Specification with DataTables with TestHelpers { 10 | val cases = { 11 | "elemSize"| "result" | "input" |> 12 | 5 ! 10 ! bits_bs"00001010" | 13 | 5 ! 1337 ! bits_bs"00011111 10011010 00001010" | 14 | 8 ! 42 ! 
bits_bs"00101010" 15 | } 16 | 17 | "IntCoder" should { 18 | "decode" in { 19 | cases | { (elmSize, result, input) => 20 | new IntCoder(elmSize).decode(input) must_== \/-((result, input.length)) 21 | } 22 | } 23 | 24 | "encode" in { 25 | cases | { (elmSize, result, input) => 26 | new IntCoder(elmSize).encode(result) must_== \/-(input) 27 | } 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/protocol/coders/FrameCoderTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.protocol.coders 2 | 3 | import scalaz._ 4 | import scalaz.std.AllInstances._ 5 | import scalaz.syntax.traverse._ 6 | import akka.util.ByteString 7 | import argonaut._ 8 | import Argonaut._ 9 | import better.files._ 10 | import org.specs2.mutable.Specification 11 | import org.specs2.specification.core.Fragments 12 | import net.danielkza.http2.TestHelpers 13 | import net.danielkza.http2.protocol.{Setting, Frame} 14 | 15 | class FrameCoderTest extends Specification with TestHelpers { 16 | sealed trait Case { 17 | def wire: ByteString 18 | def description: String 19 | } 20 | case class OkCase(wire: ByteString, description: String, length: Int, flags: Byte, stream: Int, tpe: Byte, 21 | payload: Frame) extends Case 22 | case class ErrorCase(wire: ByteString, description: String, errors: List[Int]) extends Case 23 | 24 | implicit def bsJson: DecodeJson[ByteString] = 25 | DecodeJson(c => c.as[String].map(ByteString(_))) 26 | 27 | def frameJson(tpe: Byte, stream: Int): DecodeJson[Frame] = DecodeJson(c => { 28 | import Frame._ 29 | tpe match { 30 | case Types.DATA => for { 31 | padLen <- c.get[Option[Int]]("padding_length") 32 | data <- c.get[ByteString]("data") 33 | padding <- c.get[Option[ByteString]]("padding") 34 | } yield Data(stream, data, padding = padding) 35 | 36 | case Types.HEADERS => for { 37 | padLen <- 
c.get[Option[Int]]("padding_length") 38 | depStream <- c.get[Option[Int]]("stream_dependency") 39 | exclusive <- c.get[Option[Boolean]]("exclusive") 40 | weight <- c.get[Option[Int]]("weight") 41 | frag <- c.get[ByteString]("header_block_fragment") 42 | padding <- c.get[Option[ByteString]]("padding") 43 | } yield { 44 | val streamDependency = depStream.map { s => StreamDependency(exclusive.get, s, weight.get.toByte) } 45 | Headers(stream, streamDependency, frag, padding = padding) 46 | } 47 | 48 | case Types.PRIORITY => for { 49 | targetStream <- c.get[Int]("stream_dependency") 50 | weight <- c.get[Int]("weight") 51 | exclusive <- c.get[Boolean]("exclusive") 52 | } yield Priority(stream, StreamDependency(exclusive, targetStream, weight.toByte)) 53 | 54 | case Types.RST_STREAM => for { 55 | error <- c.get[Int]("error_code") 56 | } yield ResetStream(stream, error) 57 | 58 | case Types.SETTINGS => for { 59 | settings <- c.get[List[(Int, Int)]]("settings") 60 | } yield Settings(settings.map(t => Setting(t._1.toShort, t._2))) 61 | 62 | case Types.PUSH_PROMISE => for { 63 | padLen <- c.get[Option[Int]]("padding_length") 64 | promisedStream <- c.get[Int]("promised_stream_id") 65 | frag <- c.get[ByteString]("header_block_fragment") 66 | padding <- c.get[Option[ByteString]]("padding") 67 | } yield PushPromise(stream, promisedStream, frag, padding = padding) 68 | 69 | case Types.PING => for { 70 | data <- c.get[ByteString]("opaque_data") 71 | } yield Ping(data) 72 | 73 | case Types.GOAWAY => for { 74 | stream <- c.get[Int]("last_stream_id") 75 | error <- c.get[Int]("error_code") 76 | debugData <- c.get[ByteString]("additional_debug_data") 77 | } yield GoAway(stream ,error, debugData) 78 | 79 | case Types.CONTINUATION => for { 80 | frag <- c.get[ByteString]("header_block_fragment") 81 | } yield Continuation(stream, frag) 82 | 83 | case Types.WINDOW_UPDATE => for { 84 | increment <- c.get[Int]("window_size_increment") 85 | } yield WindowUpdate(stream, increment) 86 | 87 | case 
_ => 88 | DecodeResult.fail(s"Unknown frame type $tpe", c.history) 89 | } 90 | }) 91 | 92 | implicit def caseJson: DecodeJson[Case] = DecodeJson(c => for { 93 | wire <- (c --\ "wire").as[String].map(_.byteStringFromHex) 94 | description <- (c --\ "description").as[String] 95 | result <- (c --\ "error").as[List[Int]].map { errorList => 96 | ErrorCase(wire, description, errorList): Case 97 | } ||| { 98 | for { 99 | length <- (c --\ "frame" --\ "length").as[Int] 100 | flags <- (c --\ "frame" --\ "flags").as[Int].map(_.toByte) 101 | stream <- (c --\ "frame" --\ "stream_identifier").as[Int] 102 | tpe <- (c --\ "frame" --\ "type").as[Int].map(_.toByte) 103 | payload <- frameJson(tpe.toByte, stream).tryDecode(c --\ "frame" --\ "frame_payload").map(_.withFlags(flags)) 104 | } yield OkCase(wire, description, length, flags.toByte, stream, tpe, payload): Case 105 | } 106 | } yield result) 107 | 108 | def readCases(): \/[String, List[Case]] = { 109 | for { 110 | caseDirectory <- sys.props.get("http2.frame_tests_dir").map(\/-(_)).getOrElse { 111 | -\/("Failed to find HTTP2 Frame Test Cases. 
Make sure the `http2.frame_tests_dir` system property is correct") 112 | } 113 | files = caseDirectory.toFile.listRecursively.filter(_.name.endsWith(".json")).toList 114 | results <- files.map { file => 115 | Parse.decodeValidation[Case](file.contentAsString).disjunction.leftMap { error => 116 | file.fullPath + ": " + error 117 | } 118 | }.sequence[\/[String, ?], Case] 119 | } yield results 120 | } 121 | 122 | "FrameCoder" should { 123 | val cases = readCases() 124 | val coder = new FrameCoder 125 | 126 | cases map { cases => 127 | val okCases = cases.collect { case c: OkCase => c } 128 | val errCases = cases.collect { case c: ErrorCase => c } 129 | 130 | "encode" in Fragments.foreach(okCases) { c => c.description >> { 131 | coder.encode(c.payload) must_== \/-(c.wire) 132 | }} 133 | 134 | "decode" in { 135 | "with success" in Fragments.foreach(okCases) { c => c.description >> { 136 | coder.decodeS.run(c.wire) must beLike { case \/-((rem, frame)) => 137 | (rem must beEmpty) and (frame must_== c.payload) and (frame.flags must_== c.flags) 138 | } 139 | }} 140 | "with failure" in Fragments.foreach(errCases) { c => c.description >> { 141 | coder.decodeS.eval(c.wire) must beLike { case -\/(error) => 142 | c.errors must contain(error.code) 143 | } 144 | }} 145 | } 146 | } valueOr { error => 147 | Fragments( 148 | "encode" in skipped("Error: " + error), 149 | "decode" in skipped("Error: " + error) 150 | ) 151 | } 152 | } 153 | } 154 | 155 | -------------------------------------------------------------------------------- /core/src/test/scala/net/danielkza/http2/stream/FrameDecoderStageTest.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.stream 2 | 3 | import net.danielkza.http2.api.Header 4 | 5 | import scala.collection.immutable 6 | import akka.stream.scaladsl._ 7 | import akka.stream.testkit.scaladsl._ 8 | import akka.util.ByteString 9 | import net.danielkza.http2.{AkkaStreamsTest, TestHelpers} 10 | 
import net.danielkza.http2.protocol.{Frame, HTTP2Error}
import net.danielkza.http2.protocol.coders.FrameCoder
import net.danielkza.http2.hpack.coders.HeaderBlockCoder

/** Feeds pre-encoded frame bytes through a FrameDecoderStage and checks that
  * every frame decodes back to the value it was encoded from.
  */
class FrameDecoderStageTest extends AkkaStreamsTest with TestHelpers {
  import Frame._

  val headerCoder = new HeaderBlockCoder
  val frameCoder = new FrameCoder

  // Minimal request header list, HPACK-encoded once and reused below.
  val headers = immutable.Seq(
    ":method" -> "GET",
    ":path" -> "/",
    "host" -> "example.com"
  ).map { case (name, value) => Header.plain(name, value) }
  val headerBlock = headerCoder.encode(headers).getOrThrow()

  // One frame of each interesting shape: headers, plain data, padded data
  // closing the stream, and connection shutdown.
  val okFrames = immutable.Seq(
    Headers(1, None, headerBlock, endHeaders = true),
    Data(1, "Line 1\n"),
    Data(1, "Line 2\n\n", padding = Some("Padding"), endStream = true),
    GoAway(1)
  )

  val framesBytes = okFrames.map(frame => frameCoder.encode(frame).getOrThrow())

  "FrameDecoderStage" should {
    val decoderFlow = Flow[ByteString].transform(() => new FrameDecoderStage(false))
    val (upstream, downstream) = TestSource.probe[ByteString]
      .via(decoderFlow)
      .toMat(TestSink.probe[Frame])(Keep.both)
      .run()

    "decode frames correctly" in {
      // Push each encoded frame in and expect the original frame back out.
      for ((expected, encoded) <- okFrames.zip(framesBytes)) {
        downstream.request(1)
        upstream.sendNext(encoded)
        downstream.expectNextOrError() must_== Right(expected)
      }

      ok
    }
  }
}

// --------------------------------------------------------------------------
// core/src/test/scala/net/danielkza/http2/stream/FrameEncoderStageTest.scala
// --------------------------------------------------------------------------
package net.danielkza.http2.stream

import scala.collection.immutable
import akka.stream.scaladsl._
import akka.stream.testkit.scaladsl._
import akka.util.ByteString
import net.danielkza.http2.{AkkaStreamsTest, TestHelpers}
import net.danielkza.http2.api.Header
import net.danielkza.http2.protocol.{Frame, HTTP2Error}
import
net.danielkza.http2.protocol.coders.FrameCoder
import net.danielkza.http2.hpack.coders.HeaderBlockCoder

/** Pushes frames through a FrameEncoderStage and checks that the emitted
  * bytes match what FrameCoder produces directly.
  */
class FrameEncoderStageTest extends AkkaStreamsTest with TestHelpers {
  import Frame._

  val headerCoder = new HeaderBlockCoder
  val frameCoder = new FrameCoder

  val headers = immutable.Seq(
    ":method" -> "GET",
    ":path" -> "/",
    "host" -> "example.com"
  ).map { case (name, value) => Header.plain(name, value) }

  val okFrames = immutable.Seq(
    Headers(1, None, headerCoder.encode(headers).getOrThrow(), endHeaders = true),
    Data(1, "Line 1\n"),
    Data(1, "Line 2\n", padding = Some("Padding"), endStream = true),
    GoAway(1)
  )

  // A GOAWAY carrying an error payload still encodes like any other frame.
  val errorFrames = immutable.Seq(
    GoAway(1, new HTTP2Error.CompressionError)
  )

  "FrameEncoderStage" should {
    val encoderFlow = Flow[Frame].transform(() => new FrameEncoderStage)
    val (upstream, downstream) = TestSource.probe[Frame]
      .via(encoderFlow)
      .toMat(TestSink.probe[ByteString])(Keep.both)
      .run()

    "encode frames correctly" in {
      // Sends a batch of frames and expects exactly their FrameCoder encodings.
      def expectEncoded(frames: immutable.Seq[Frame]) = {
        downstream.request(frames.length)
        frames.foreach(upstream.sendNext)
        downstream.expectNextN(frames.map(f => frameCoder.encode(f).getOrThrow()))
      }

      expectEncoded(okFrames)
      expectEncoded(errorFrames)

      ok
    }
  }
}

// --------------------------------------------------------------------------
// core/src/test/scala/net/danielkza/http2/stream/HeaderCollapseStageTest.scala
// --------------------------------------------------------------------------
package net.danielkza.http2.stream

import akka.util.ByteString
import net.danielkza.http2.protocol.HTTP2Error.ContinuationError

import scala.collection.immutable.Seq
import akka.stream.scaladsl._
import akka.stream.testkit.scaladsl._
import org.specs2.matcher.MatchResult
import net.danielkza.http2.protocol.{Frame, HTTP2Error}
/** Exercises HeaderCollapseStage: HEADERS/CONTINUATION sequences must be
  * merged back into one logical frame, and malformed sequences must surface a
  * CONTINUATION protocol error.
  */
class HeaderCollapseStageTest extends HeaderStageTest {
  import Frame._

  val testMaxSize = 16384

  val flow = Flow[Frame].transform(() => new HeaderCollapseStage)
  val graph = TestSource.probe[Frame]
    .via(flow)
    .toMat(TestSink.probe[Frame])(Keep.both)

  lazy val (pub, sub) = graph.run()

  override def runCase(testCase: (Frame, Seq[Frame])): MatchResult[Any] = {
    val (expectedCombined, splitFrames) = testCase

    // Feed all fragments, then expect the single collapsed frame.
    sub.request(1)
    splitFrames.foreach(pub.sendNext)
    sub.expectNextOrError() must_=== Right(expectedCombined)
  }

  /** Feeds an invalid fragment sequence and expects a ContinuationError. */
  def runFailCase(split: Frame*): MatchResult[Any] = {
    sub.request(1)
    split.foreach(pub.sendNext)
    sub.expectNextOrError() must_=== Left(ContinuationError().toException)
  }

  "HeaderCollapseStage" should {
    "collapse" in testHeaders

    "passthrough" in testPassthrough

    "report an error for " in {
      isolated

      "non-Continuation frame" >> runFailCase(
        Headers(1, None, zeroes(100), endHeaders = false),
        Headers(1, None, zeroes(100))
      )

      "unfinished Continuation" >> runFailCase(
        Headers(1, None, zeroes(100), endHeaders = false),
        Continuation(1, zeroes(100), endHeaders = false),
        Headers(1, None, zeroes(100))
      )

      "Continuation with invalid stream" >> runFailCase(
        Headers(1, None, zeroes(100), endHeaders = false),
        Continuation(2, zeroes(100), endHeaders = false),
        Continuation(1, ByteString.empty)
      )

      "unsolicited Continuation" >> runFailCase(
        Continuation(1, zeroes(100))
      )
    }
  }
}

// --------------------------------------------------------------------------
// core/src/test/scala/net/danielkza/http2/stream/HeaderSplitStageTest.scala
// --------------------------------------------------------------------------
package net.danielkza.http2.stream

import scala.collection.immutable.Seq
import akka.stream.scaladsl._
import akka.stream.testkit.scaladsl._
import org.specs2.matcher.MatchResult
import net.danielkza.http2.protocol.Frame

/** Exercises HeaderSplitStage: frames whose header block exceeds the maximum
  * frame size must be split into HEADERS/PUSH_PROMISE plus CONTINUATION frames.
  */
class HeaderSplitStageTest extends HeaderStageTest {
  val testMaxSize = 16384

  val flow = Flow[Frame].transform(() => new HeaderSplitStage(testMaxSize))
  val graph = TestSource.probe[Frame]
    .via(flow)
    .toMat(TestSink.probe[Frame])(Keep.both)

  lazy val (pub, sub) = graph.run()

  override def runCase(testCase: (Frame, Seq[Frame])): MatchResult[Any] = {
    val (original, split) = testCase
    // Feed the whole frame, expect all the split fragments.
    sub.request(split.length)
    pub.sendNext(original)

    sub.expectNextN(split.length) must containTheSameElementsAs(split)
  }

  "HeaderSplitStage" should {
    "split" in testHeaders
    "passthrough" in testPassthrough
  }
}

// --------------------------------------------------------------------------
// core/src/test/scala/net/danielkza/http2/stream/HeaderStageTest.scala
// --------------------------------------------------------------------------
package net.danielkza.http2.stream

import scala.collection.immutable.Seq
import akka.util.ByteString
import org.specs2.mutable.SpecificationLike
import org.specs2.matcher.MatchResult
import net.danielkza.http2.{AkkaStreamsTest, TestHelpers}
import net.danielkza.http2.protocol.{Frame, HTTP2Error}

/** Shared fixture for the header split/collapse stage tests. Subclasses supply
  * `testMaxSize` and `runCase`; depending on the stage under test, `runCase`
  * either splits the whole frame into the fragment list or collapses the
  * fragment list back into the whole frame.
  */
abstract class HeaderStageTest extends AkkaStreamsTest with SpecificationLike with TestHelpers {
  import Frame._

  sequential

  val testMaxSize: Int

  /** Runs one (whole frame, split fragments) pair against the stage. */
  def runCase(testCase: (Frame, Seq[Frame])): MatchResult[Any]

  /** A ByteString of `n` zero bytes, used as dummy header-block data. */
  def zeroes(n: Int) = ByteString.fromArray(Array.fill(n)(0: Byte))

  /** Frames carrying no header block must pass through both stages untouched. */
  def testPassthrough = {
    import HTTP2Error.Codes._

    def run(f: Frame) = runCase(f -> Seq(f))

    "DATA" >> run(Data(1, ByteString.empty))
    "PRIORITY" >> run(Priority(1, StreamDependency(false, 1, 1)))
    "RST_STREAM" >> run(ResetStream(1, PROTOCOL_ERROR))
    "PING" >> run(Ping(ByteString.empty))
    "SETTINGS" >> run(Settings(List()))
    "GOAWAY" >> run(GoAway(1))
    "Non-standard" >> run(NonStandard(1, 0xFF.toByte, 0xFF.toByte, ByteString.empty))
  }

  def testHeaders = {
    // Sizes of the promised-stream-id and pad-length fields, which both eat
    // into the space available for the first header fragment.
    val streamLen = 4
    val padLen = 1

    "small Headers frame" >> runCase(
      Headers(1, None, zeroes(1000)) -> Seq(
        Headers(1, None, zeroes(1000))
      )
    )

    "small PushPromise frame" >> runCase(
      PushPromise(1, 2, zeroes(1000)) -> Seq(
        PushPromise(1, 2, zeroes(1000))
      )
    )

    "large Headers frame" >> runCase {
      val firstLen = testMaxSize - Frame.HEADER_LENGTH
      Headers(1, None, zeroes(20000)) -> Seq(
        Headers(1, None, zeroes(firstLen), endHeaders = false),
        Continuation(1, zeroes(20000 - firstLen))
      )
    }

    "large Headers frame with padding" >> runCase {
      val firstLen = testMaxSize - Frame.HEADER_LENGTH - padLen - 20
      Headers(1, None, zeroes(20000), padding = Some(zeroes(20))) -> Seq(
        Headers(1, None, zeroes(firstLen), endHeaders = false, padding = Some(zeroes(20))),
        Continuation(1, zeroes(20000 - firstLen))
      )
    }

    "large PushPromise frame" >> runCase {
      val firstLen = testMaxSize - Frame.HEADER_LENGTH - streamLen
      PushPromise(1, 2, zeroes(20000)) -> Seq(
        PushPromise(1, 2, zeroes(firstLen), endHeaders = false),
        Continuation(1, zeroes(20000 - firstLen))
      )
    }

    // FIX: this example was previously also titled "large PushPromise frame",
    // duplicating the unpadded case above; duplicate example descriptions make
    // specs2 reports ambiguous. Renamed to mirror the padded Headers case.
    "large PushPromise frame with padding" >> runCase {
      val firstLen = testMaxSize - Frame.HEADER_LENGTH - streamLen - padLen - 100
      PushPromise(1, 2, zeroes(20000), padding = Some(zeroes(100))) -> Seq(
        PushPromise(1, 2, zeroes(firstLen), endHeaders = false, padding = Some(zeroes(100))),
        Continuation(1, zeroes(20000 - firstLen))
      )
    }

    "very large Headers frame" >> runCase {
      val fragLen = testMaxSize - Frame.HEADER_LENGTH
      Headers(1, None, zeroes(40000)) -> Seq(
        Headers(1, None, zeroes(fragLen), endHeaders = false),
        Continuation(1, zeroes(fragLen), endHeaders = false),
        Continuation(1, zeroes(40000 - 2 * fragLen))
      )
    }
  }
}

// --------------------------------------------------------------------------
// examples/src/main/scala/net/danielkza/http2/examples/ServerExample.scala
// --------------------------------------------------------------------------
package net.danielkza.http2.examples

import java.io.{FileInputStream, File}
import java.nio.file.Files.probeContentType
import java.security.KeyStore
import javax.net.ssl.{KeyManagerFactory, SSLContext}
import com.typesafe.config.ConfigFactory
import akka.actor.ActorSystem
import akka.stream._
import akka.stream.io._
import akka.http.scaladsl.model._
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.Directives._
import net.danielkza.http2.Http2
import net.danielkza.http2.Http2.Implicits._

import scala.concurrent.ExecutionContext

object ServerExample extends App {
  // Builds a TLSv1.2 context from the JKS keystore configured through the
  // standard javax.net.ssl system properties (set in project/Build.scala).
  def createSSLContext: SSLContext = {
    val pass = System.getProperty("javax.net.ssl.keyStorePassword").toCharArray
    val keyStore = KeyStore.getInstance("jks")
    keyStore.load(new FileInputStream(System.getProperty("javax.net.ssl.keyStore")), pass)

    val keyManagerFactory = KeyManagerFactory.getInstance("SunX509")
    keyManagerFactory.init(keyStore, pass)

    val ctx = SSLContext.getInstance("TLSv1.2")
    ctx.init(keyManagerFactory.getKeyManagers, null, null)
    ctx
  }

  val config = ConfigFactory.parseString(
    """akka {
      | stdout-loglevel = "DEBUG"
      | loglevel = "DEBUG"
      |}
    """.stripMargin)

  implicit val actorSystem: ActorSystem = ActorSystem("ServerExample", config)
  val matSettings = ActorMaterializerSettings(actorSystem)
    .withDebugLogging(true)
.withSupervisionStrategy { e: Throwable => println(e); Supervision.Stop } 44 | implicit val materializer: ActorMaterializer = ActorMaterializer(matSettings, "http2") 45 | implicit val ec: ExecutionContext = actorSystem.dispatcher 46 | 47 | val sslContext = createSSLContext 48 | 49 | val routes: Route = 50 | path(RestPath) { path => 51 | get { 52 | val file = new File("./" + path) 53 | if(file.isFile && file.canRead) { 54 | complete { 55 | val contentType = ContentType.parse(probeContentType(file.toPath)).right.get 56 | val source = HttpEntity(contentType, SynchronousFileSource(file)) 57 | HttpResponse(StatusCodes.OK, entity = source) 58 | } 59 | } else { 60 | complete { 61 | (StatusCodes.NotFound, s"File `$path` not found") 62 | } 63 | } 64 | } 65 | } 66 | 67 | val binding = Http2().bind("0.0.0.0", port = 8080).map( 68 | _.handleWith(routes) 69 | ).runServerIndefinitely(actorSystem) 70 | } 71 | -------------------------------------------------------------------------------- /macros/src/main/scala/net/danielkza/http2/macros/BitPatterns.scala: -------------------------------------------------------------------------------- 1 | package net.danielkza.http2.macros 2 | 3 | import scala.reflect.macros.whitebox 4 | import java.nio.ByteOrder 5 | 6 | object BitPatterns { 7 | private def extractLiteral(c: whitebox.Context): String = { 8 | import c.universe._ 9 | 10 | try { 11 | val (args, prefix) = c.prefix.tree match { 12 | case q"""$wrapper($stringContext.apply(..$args)).${prefix: TermName}""" 13 | => (args, Some(prefix)) 14 | case q"""$wrapper($stringContext.apply(..$args))""" 15 | => (args, None) 16 | } 17 | 18 | (args, prefix) match { 19 | case (List(Literal(Constant(str: String))), _) => str 20 | } 21 | } catch { case e: MatchError => 22 | c.abort(c.enclosingPosition, "Invalid binary literal, must be a string literal without interpolation") 23 | } 24 | } 25 | 26 | private def parseBinaryLiteral(c: whitebox.Context)(value: String, allowPlaceholder: Boolean = false) 27 | 
: IndexedSeq[(Byte, Byte)] = 28 | { 29 | val clean = value.replaceAll("\\s", "") 30 | clean.foreach { 31 | case '0' | '1' => 32 | case '-' if allowPlaceholder => 33 | case _ => 34 | c.abort(c.enclosingPosition, "Invalid binary literal, must only contain 1, 0, whitespaces, dashes (only if pattern)") 35 | } 36 | 37 | clean.grouped(8).map { s => 38 | val maskStr = clean.map { case '-' => '0' case _ => '1' } 39 | val mask = Integer.parseUnsignedInt(maskStr, 2) 40 | val valueStr = clean.replace('-', '0') 41 | val value = Integer.parseUnsignedInt(valueStr, 2) 42 | (value.toByte, mask.toByte) 43 | }.toIndexedSeq 44 | } 45 | 46 | def binaryLiteralImpl[T : c.WeakTypeTag](c: whitebox.Context)(args: c.Tree*): c.Expr[T] = { 47 | import c.universe._ 48 | 49 | val bytes = parseBinaryLiteral(c)(extractLiteral(c)) 50 | var l: Long = 0 51 | bytes.indices.foreach { idx => 52 | val (byte, mask) = bytes(idx) 53 | l = (l << 8) + (byte & mask) 54 | } 55 | 56 | val (tByte, tShort, tInt, tLong) = (c.typeOf[Byte], c.typeOf[Short], c.typeOf[Int], c.typeOf[Long]) 57 | 58 | implicitly[c.WeakTypeTag[T]].tpe match { 59 | case `tByte` => 60 | val b = l.toByte 61 | c.Expr[Byte](q"""$b: Byte""").asInstanceOf[c.Expr[T]] 62 | case `tShort` => 63 | val s = l.toShort 64 | c.Expr[Short](q"""$s: Short""").asInstanceOf[c.Expr[T]] 65 | case `tInt` => 66 | val i = l.toInt 67 | c.Expr[Int](q"""$i: Int""").asInstanceOf[c.Expr[T]] 68 | case `tLong` => 69 | c.Expr[Long](q"""$l: Long""").asInstanceOf[c.Expr[T]] 70 | case tpe => 71 | c.abort(c.enclosingPosition, s"Unsupported type ${tpe.toString}") 72 | } 73 | } 74 | 75 | def binaryLiteralExtractorImpl[T](c: whitebox.Context)(x: c.Tree)(implicit tt: c.WeakTypeTag[T]) = { 76 | import c.universe._ 77 | import definitions._ 78 | 79 | val bytes = parseBinaryLiteral(c)(extractLiteral(c), allowPlaceholder = true) 80 | 81 | def expr(trees: c.Tree*) = { 82 | val tree = trees.reduceLeft((a, b) => q"""$a && $b""") 83 | 84 | q""" 85 | new { 86 | @inline final def 
unapply(b: ${tt.tpe}): Boolean = $tree 87 | }.unapply($x) 88 | """ 89 | } 90 | def byteComparison(byte: Byte, mask: Byte, shift: Int = 0) = { 91 | q"""((b >>> $shift) & $mask) == $byte""" 92 | } 93 | 94 | def bytesComparisons = { 95 | val exprs = bytes.zipWithIndex.map { case ((byte, mask), index) => byteComparison(byte, mask, index * 8) } 96 | expr(exprs: _*) 97 | } 98 | 99 | def literalError = c.abort(x.pos, "Binary literal to big for primitive type") 100 | 101 | val (byteClass, shortClass, intClass, longClass) = (ByteClass, ShortClass, IntClass, LongClass) 102 | tt.tpe.typeSymbol.asClass match { 103 | case `byteClass` if bytes.length == 1 => bytesComparisons 104 | case `byteClass` => literalError 105 | case `shortClass` if bytes.length <= 2 => bytesComparisons 106 | case `shortClass` => literalError 107 | case `intClass` if bytes.length <= 4 => bytesComparisons 108 | case `intClass` => literalError 109 | case `longClass` if bytes.length <= 8 => bytesComparisons 110 | case `longClass` => literalError 111 | case _ => 112 | c.abort(c.enclosingPosition, "Unexpected non-primitive type in pattern") 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /project/Build.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | 4 | object BuildSettings { 5 | val Debug = config("debug").extend(Runtime) 6 | 7 | val buildSettings = Defaults.defaultSettings ++ Seq( 8 | version := "0.1-SNAPSHOT", 9 | scalaVersion := "2.11.7", 10 | 11 | javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), 12 | scalacOptions ++= Seq("-target:jvm-1.8"), 13 | 14 | resolvers ++= Seq( 15 | Resolver.sonatypeRepo("releases"), 16 | Resolver.sonatypeRepo("snapshots") 17 | ), 18 | libraryDependencies ++= Seq( 19 | "org.scala-lang" % "scala-reflect" % scalaVersion.value, 20 | "org.scalaz" %% "scalaz-core" % "7.1.4", 21 | "com.chuusai" %% "shapeless" % "2.2.5", 22 | 
      // Akka 2.4 core plus the 2.0-M1 experimental streams/http modules.
      "com.typesafe.akka" %% "akka-actor" % "2.4.0",
      "com.typesafe.akka" %% "akka-stream-experimental" % "2.0-M1",
      "com.typesafe.akka" %% "akka-http-core-experimental" % "2.0-M1",
      // Jetty's ALPN bootstrap jar; prepended to the boot classpath below.
      "org.mortbay.jetty.alpn" % "alpn-boot" % "8.1.6.v20151105" % "provided",
      "com.github.pathikrit" %% "better-files" % "2.13.0" % "test",
      "org.specs2" %% "specs2-core" % "3.6.4" % "test",
      "io.argonaut" %% "argonaut" % "6.1-M4" % "test",
      "com.typesafe.akka" %% "akka-stream-testkit-experimental" % "2.0-M1" % "test"

    ),
    addCompilerPlugin("org.spire-math" %% "kind-projector" % "0.7.1"),

    fork in (run in Compile) := true,
    // ALPN must sit on the JVM boot classpath for TLS protocol negotiation on
    // Java 8; fail fast if the jar is missing from the resolved classpath.
    // Also point the JVM at the repository's test keystore.
    javaOptions ++= {
      val jars = (fullClasspath in Compile).value.files
      val alpnJar = jars.find(_.getName.contains("alpn-boot")).getOrElse {
        sys.error("No Jetty alpn-boot JAR found in classpath, cannot continue")
      }

      Seq(
        "-Xbootclasspath/p:" + alpnJar.absolutePath,
        "-Djavax.net.ssl.keyStore=" + (baseDirectory in ThisBuild).value / "keystore.jks",
        "-Djavax.net.ssl.keyStorePassword=test"
      )
    },
    // Extra flags for the `debug` configuration: remote debugger on port 9999
    // and TLS handshake tracing.
    javaOptions in Debug ++= Seq(
      "-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=9999",
      "-Djavax.net.debug=handshake"
    ),

    scalacOptions in Test ++= Seq("-Yrangepos"), // for Specs2
    // Points the FrameCoder tests at the http2-frame-test-case git submodule.
    testOptions in Test += Tests.Setup { _ =>
      val casesPath = ((baseDirectory in ThisBuild) / "http2-frame-test-case").value.absolutePath
      sys.props += "http2.frame_tests_dir" -> casesPath
    }
  )
}

object MyBuild extends Build {
  import BuildSettings._

  lazy val core: Project =
    Project("core", file("core")).dependsOn(macros).settings(buildSettings)

  // The examples project adds the high-level akka-http routing DSL and a
  // `debug` configuration for remote debugging.
  lazy val examples: Project =
    Project("examples", file("examples")).dependsOn(core)
      .configs(Debug)
      .settings(inConfig(Debug)(Defaults.configTasks):_*)
      .settings(buildSettings)
      .settings(
        libraryDependencies += "com.typesafe.akka" %% "akka-http-experimental" % "2.0-M1"
      )

  lazy val root: Project =
    Project("http2-server", file(".")).aggregate(macros, core).settings(buildSettings).settings(
      run <<= run in Compile in core
    )

  lazy val macros: Project =
    Project("macros", file("macros")).settings(buildSettings).settings(
      libraryDependencies ++= Seq(
        "org.scala-lang" % "scala-reflect" % scalaVersion.value,
        "org.scala-lang" % "scala-compiler" % scalaVersion.value
      )
    )

}

// --------------------------------------------------------------------------
// project/build.properties
// --------------------------------------------------------------------------
sbt.version=0.13.9

# NOTE(review): the [log] section below is sbt 0.7-era syntax; sbt 0.13 only
# reads sbt.version from this file — presumably a leftover, confirm.
[log]
level: debug