├── project ├── build.properties └── plugins.sbt ├── cmd ├── Dockerfile-native ├── Dockerfile ├── resources │ └── logback.xml └── src │ └── main │ └── scala │ └── Main.scala ├── .gitignore ├── common └── src │ ├── main │ └── scala │ │ └── com │ │ └── github │ │ └── torrentdam │ │ └── bittorrent │ │ ├── PeerInfo.scala │ │ ├── InfoHash.scala │ │ ├── PeerId.scala │ │ └── MagnetLink.scala │ └── test │ └── scala │ └── MagnetLinkSuite.scala ├── .scalafmt.conf ├── README.md ├── bittorrent ├── shared │ └── src │ │ ├── test │ │ ├── resources │ │ │ └── bencode │ │ │ │ └── ubuntu-18.10-live-server-amd64.iso.torrent │ │ └── scala │ │ │ └── com │ │ │ └── github │ │ │ └── torrentdam │ │ │ └── bittorrent │ │ │ ├── protocol │ │ │ └── message │ │ │ │ └── HandshakeSpec.scala │ │ │ ├── FileMappingSpec.scala │ │ │ ├── wire │ │ │ ├── RequestDispatcherSpec.scala │ │ │ └── WorkQueueSuite.scala │ │ │ └── TorrentMetadataSpec.scala │ │ └── main │ │ └── scala │ │ └── com │ │ └── github │ │ └── torrentdam │ │ └── bittorrent │ │ ├── protocol │ │ ├── extensions │ │ │ ├── Extensions.scala │ │ │ ├── ExtensionHandshake.scala │ │ │ └── metadata │ │ │ │ └── UtMessage.scala │ │ └── message.scala │ │ ├── wire │ │ ├── DownloadMetadata.scala │ │ ├── Torrent.scala │ │ ├── Swarm.scala │ │ ├── MessageSocket.scala │ │ ├── ExtensionHandler.scala │ │ ├── Download.scala │ │ ├── RequestDispatcher.scala │ │ └── Connection.scala │ │ ├── FileMapping.scala │ │ └── TorrentMetadata.scala ├── jvm │ └── src │ │ └── main │ │ └── scala │ │ └── com │ │ └── github │ │ └── torrentdam │ │ └── bittorrent │ │ └── CrossPlatform.scala └── native │ └── src │ └── main │ └── scala │ └── com │ └── github │ └── torrentdam │ └── bittorrent │ └── CrossPlatform.scala ├── .scala-build └── .scalafmt.conf ├── dht └── src │ ├── test │ └── scala │ │ └── com │ │ └── github │ │ └── torrentdam │ │ └── bittorrent │ │ └── dht │ │ ├── NoOpLogger.scala │ │ ├── MessageFormatSpec.scala │ │ └── PeerDiscoverySpec.scala │ └── main │ └── scala │ └── com 
│ └── github │ └── torrentdam │ └── bittorrent │ └── dht │ ├── NodeInfo.scala │ ├── QueryHandler.scala │ ├── RoutingTableRefresh.scala │ ├── MessageSocket.scala │ ├── RoutingTableBootstrap.scala │ ├── RequestResponse.scala │ ├── Client.scala │ ├── Node.scala │ ├── RoutingTable.scala │ ├── PeerDiscovery.scala │ └── message.scala ├── files └── src │ ├── main │ └── scala │ │ └── com │ │ └── github │ │ └── torrentdam │ │ └── bittorrent │ │ └── files │ │ ├── package.scala │ │ ├── Reader.scala │ │ └── Writer.scala │ └── test │ └── scala │ └── com │ └── github │ └── torrentdam │ └── bittorrent │ └── files │ └── WriterSpec.scala ├── LICENSE ├── .github └── workflows │ └── build.yml └── mill /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.10.1 2 | -------------------------------------------------------------------------------- /cmd/Dockerfile-native: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | COPY ./.native/target/scala-3.3.0/cmd-out /opt/torrentdam 4 | 5 | ENTRYPOINT ["/opt/torrentdam"] 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | out 3 | 4 | .idea 5 | .idea_modules 6 | 7 | .vscode 8 | 9 | metals.sbt 10 | .metals 11 | .bloop 12 | .bsp 13 | 14 | logs 15 | -------------------------------------------------------------------------------- /common/src/main/scala/com/github/torrentdam/bittorrent/PeerInfo.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import com.comcast.ip4s.* 4 | /** A BitTorrent peer, identified by its socket address (IP address and port). */ 5 | final case class PeerInfo(address: SocketAddress[IpAddress]) 6 | -------------------------------------------------------------------------------- /.scalafmt.conf: 
-------------------------------------------------------------------------------- 1 | version = 3.5.3 2 | runner.dialect = scala3 3 | maxColumn = 120 4 | continuationIndent.defnSite = 2 5 | rewrite.rules = [SortImports] 6 | rewrite.imports.expand = true 7 | verticalAlignMultilineOperators = true 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BitTorrent Client 2 | 3 | ## Development 4 | 5 | Run tests: 6 | ```sh 7 | $ sbt test 8 | ``` 9 | 10 | ## Releasing 11 | 12 | Tagging triggers uploading jars to Sonatype where they have to be manually released. -------------------------------------------------------------------------------- /bittorrent/shared/src/test/resources/bencode/ubuntu-18.10-live-server-amd64.iso.torrent: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorrentDam/bittorrent/HEAD/bittorrent/shared/src/test/resources/bencode/ubuntu-18.10-live-server-amd64.iso.torrent -------------------------------------------------------------------------------- /.scala-build/.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version = "3.5.3" 2 | runner.dialect = scala3 3 | maxColumn = 120 4 | continuationIndent.defnSite = 2 5 | rewrite.rules = [SortImports] 6 | rewrite.imports.expand = true 7 | verticalAlignMultilineOperators = true 8 | -------------------------------------------------------------------------------- /bittorrent/jvm/src/main/scala/com/github/torrentdam/bittorrent/CrossPlatform.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import scodec.bits.ByteVector 4 | // JVM implementation: SHA-1 is delegated to scodec-bits `ByteVector.sha1`. 5 | object CrossPlatform { 6 | def sha1(bytes: ByteVector): ByteVector = bytes.sha1 7 | } 
-------------------------------------------------------------------------------- /cmd/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM eclipse-temurin:24-jre-noble 2 | 3 | COPY ./out/cmd/assembly.dest/out.jar /opt/torrentdam/assembly.jar 4 | 5 | ENTRYPOINT [ "java", "-Dcats.effect.tracing.mode=none", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseCompactObjectHeaders", "-jar", "/opt/torrentdam/assembly.jar"] 6 | -------------------------------------------------------------------------------- /bittorrent/native/src/main/scala/com/github/torrentdam/bittorrent/CrossPlatform.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import scodec.bits.ByteVector 4 | 5 | object CrossPlatform { 6 | // SHA-1 via the JDK MessageDigest. A fresh instance is created per call and `digest` resets it on completion, so no explicit reset is needed. 7 | def sha1(bytes: ByteVector): ByteVector = { 8 | val sha1Digest = java.security.MessageDigest.getInstance("SHA-1") 9 | ByteVector(sha1Digest.digest(bytes.toArray)) 10 | } 11 | } -------------------------------------------------------------------------------- /cmd/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | %-4relative [%thread] %-5level %logger{35} - %msg %n 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1") 2 | addSbtPlugin("com.github.sbt" % "sbt-native-packager" % "1.9.9") 3 | 4 | // Multi-platform support 5 | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.16.0") 6 | addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "1.2.0") 7 | addSbtPlugin("org.portable-scala" % "sbt-scala-native-crossproject" % "1.2.0") 8 | addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.4.17") 9 | 
-------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/extensions/Extensions.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.protocol.extensions 2 | 3 | import com.github.torrentdam.bittorrent.protocol.message.Message 4 | // Extension-protocol constants; `handshake` maps `ut_metadata` to MessageId.Metadata (1) and leaves `metadata_size` unset. 5 | object Extensions { 6 | 7 | object MessageId { 8 | val Handshake = 0L 9 | val Metadata = 1L 10 | } 11 | 12 | def handshake: ExtensionHandshake = 13 | ExtensionHandshake( 14 | Map( 15 | ("ut_metadata", MessageId.Metadata) 16 | ), 17 | None 18 | ) 19 | } 20 | -------------------------------------------------------------------------------- /common/src/main/scala/com/github/torrentdam/bittorrent/InfoHash.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | import scodec.bits.ByteVector 3 | /** Torrent info-hash bytes; `toString` renders them as hex. */ 4 | final case class InfoHash(bytes: ByteVector) { 5 | def toHex: String = bytes.toHex 6 | override def toString: String = toHex 7 | } 8 | 9 | object InfoHash { 10 | // Defined exactly for strings that hex-decode (case-insensitively, via toLowerCase) to 20 bytes. 11 | val fromString: PartialFunction[String, InfoHash] = 12 | Function.unlift { s => 13 | for { 14 | b <- ByteVector.fromHexDescriptive(s.toLowerCase).toOption 15 | _ <- if (b.length == 20) Some(()) else None 16 | } yield InfoHash(b) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /dht/src/test/scala/com/github/torrentdam/bittorrent/dht/NoOpLogger.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.IO 4 | import cats.syntax.all.given 5 | import org.legogroup.woof.LogInfo 6 | import org.legogroup.woof.LogLevel 7 | import org.legogroup.woof.Logger 8 | import org.legogroup.woof.Logger.StringLocal 9 | // Test Logger that drops every message. 10 | class NoOpLogger extends Logger[IO] { 11 | val 
stringLocal: StringLocal[IO] = NoOpLocal() 12 | def doLog(level: LogLevel, message: String)(using LogInfo): IO[Unit] = IO.unit 13 | } 14 | // Logging context that is always empty; `local` ignores its modifier and returns `fa` unchanged. 15 | class NoOpLocal extends Logger.StringLocal[IO] { 16 | def ask = List.empty[(String, String)].pure[IO] 17 | def local[A](fa: IO[A])(f: List[(String, String)] => List[(String, String)]) = fa 18 | } 19 | -------------------------------------------------------------------------------- /files/src/main/scala/com/github/torrentdam/bittorrent/files/package.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.files 2 | 3 | import com.github.torrentdam.bittorrent.TorrentMetadata 4 | import scala.util.chaining.scalaUtilChainingOps 5 | // Lays the files out back to back from offset 0: entry i spans [sum of earlier lengths, that sum + length of file i). 6 | private def createFileOffsets(files: List[TorrentMetadata.File]) = 7 | Array 8 | .ofDim[FileRange](files.length) 9 | .tap: array => 10 | var currentOffset: Long = 0 11 | files.iterator.zipWithIndex.foreach( 12 | (file, index) => 13 | array(index) = FileRange(file, currentOffset, currentOffset + file.length) 14 | currentOffset += file.length 15 | ) 16 | // Ranges overlapping [start, end), in array order; dropWhile/takeWhile rely on the ranges being in ascending offset order, as createFileOffsets builds them. 17 | extension (fileOffsets: Array[FileRange]) { 18 | private def matchFiles(start: Long, end: Long): Iterator[FileRange] = 19 | fileOffsets.iterator 20 | .dropWhile(_.endOffset <= start) 21 | .takeWhile(_.startOffset < end) 22 | } 23 | // Byte span [startOffset, endOffset) of `file` within the back-to-back layout of all files. 24 | private case class FileRange(file: TorrentMetadata.File, startOffset: Long, endOffset: Long) 25 | -------------------------------------------------------------------------------- /common/src/main/scala/com/github/torrentdam/bittorrent/PeerId.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import cats.effect.std.Random 4 | import cats.syntax.all.* 5 | import cats.Monad 6 | import scodec.bits.ByteVector 7 | /** Peer id bytes; `toString` shows them as UTF-8 text when they decode, otherwise as hex. */ 8 | final case class PeerId(bytes: ByteVector) { 9 | override def toString() = s"PeerId(${bytes.decodeUtf8.getOrElse(bytes.toHex)})" 10 | } 11 | 12 | 
object PeerId { 13 | 14 | // Fixed 8-character client prefix; followed by the 12 hex characters of six bytes it gives a 20-byte id. 15 | private val Prefix = "-qB0000-" 16 | 17 | /** Peer id made of the client prefix followed by the lowercase hex of the six given bytes. */ 18 | def apply(b0: Byte, b1: Byte, b2: Byte, b3: Byte, b4: Byte, b5: Byte): PeerId = 19 | apply(Array[Byte](b0, b1, b2, b3, b4, b5)) 20 | 21 | private def apply(bytes: Array[Byte]): PeerId = { 22 | val encoded = ByteVector.encodeUtf8(Prefix + ByteVector(bytes).toHex) 23 | new PeerId(encoded.toOption.get) 24 | } 25 | 26 | /** Peer id whose six suffix bytes are drawn from `Random[F]`. */ 27 | def generate[F[_]](using Random[F], Monad[F]): F[PeerId] = 28 | Random[F].nextBytes(6).map(randomBytes => PeerId(randomBytes)) 29 | } 30 | -------------------------------------------------------------------------------- /bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/protocol/message/HandshakeSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.protocol.message 2 | 3 | import com.github.torrentdam.bittorrent.InfoHash 4 | import com.github.torrentdam.bittorrent.PeerId 5 | import scala.util.chaining.* 6 | import scodec.bits.ByteVector 7 | 8 | class HandshakeSpec extends munit.FunSuite { 9 | // Encoding a handshake whose first argument is `true` must set reserved bit 43 and leave bits 42 and 44 clear. 10 | test("read and write protocol extension bit") { 11 | val message = 12 | Handshake( 13 | true, 14 | InfoHash(ByteVector.fill(20)(0)), 15 | PeerId(0, 0, 0, 0, 0, 0) 16 | ) 17 | assert( 18 | PartialFunction.cond(Handshake.HandshakeCodec.encode(message).toOption) { 19 | case Some(bits) => 20 | bits 21 | .splitAt(20 * 8) 22 | .pipe { case (_, bits) => 23 | bits.splitAt(64) 24 | } 25 | .pipe { case (reserved, bits) => 26 | assert(reserved.get(42) == false) 27 | assert(reserved.get(43) == true) 28 | assert(reserved.get(44) == false) 29 | } 30 | true 31 | case _ => false 32 | } 33 | ) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software
released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to -------------------------------------------------------------------------------- /common/src/test/scala/MagnetLinkSuite.scala: -------------------------------------------------------------------------------- 1 | import com.github.torrentdam.bittorrent.* 2 | 3 | class MagnetLinkSuite extends munit.FunSuite { 4 | 5 | test("parses magnet link") { 6 | // Upper-case hex btih, percent-encoded tracker URL and percent-encoded UTF-8 (Cyrillic) display name. 7 | val link = 8 | """magnet:?xt=urn:btih:C071AA6D06101FE3C1D8D3411343CFEB33D91E5F&tr=http%3A%2F%2Fbt.t-ru.org%2Fann%3Fmagnet&dn=%D0%9C%D0%B0%D1%82%D1%80%D0%B8%D1%86%D0%B0%3A%20%D0%92%D0%BE%D1%81%D0%BA%D1%80%D0%B5%D1%88%D0%B5%D0%BD%D0%B8%D0%B5%20%2F%20The%20Matrix%20Resurrections%20(%D0%9B%D0%B0%D0%BD%D0%B0%20%D0%92%D0%B0%D1%87%D0%BE%D0%B2%D1%81%D0%BA%D0%B8%20%2F%20Lana%20Wachowski)%20%5B2021%2C%20%D0%A1%D0%A8%D0%90%2C%20%D0%A4%D0%B0%D0%BD%D1%82%D0%B0%D1%81%D1%82%D0%B8%D0%BA%D0%B0%2C%20%D0%B1%D0%BE%D0%B5%D0%B2%D0%B8%D0%BA%2C%20WEB-DLRip%5D%20MVO%20(Jaskier)%20%2B%20Sub%20Rus%2C%20E""" 9 | val expected = MagnetLink( 10 | infoHash = InfoHash.fromString("C071AA6D06101FE3C1D8D3411343CFEB33D91E5F"), 11 | displayName = Some( 12 | "Матрица: Воскрешение / The Matrix Resurrections (Лана Вачовски / Lana Wachowski) [2021, США, Фантастика, боевик, WEB-DLRip] MVO (Jaskier) + Sub Rus, E" 13 | ), 14 | trackers = List("http://bt.t-ru.org/ann?magnet") 15 | ) 16 | assertEquals(MagnetLink.fromString(link), Some(expected)) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/DownloadMetadata.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.effect.implicits.* 4 | import cats.effect.IO 5 | import cats.implicits.* 6 | import com.github.torrentdam.bittorrent.TorrentMetadata.Lossless 7 | import fs2.Stream 8 | import org.legogroup.woof.given 9 | import org.legogroup.woof.Logger 10
| import scala.concurrent.duration.* 11 | 12 | object DownloadMetadata { 13 | // Keeps up to 10 swarm connections in flight, gives each one minute to deliver the metadata and returns the first success; failed attempts are discarded and new connections are tried indefinitely. 14 | def apply(swarm: Swarm)(using logger: Logger[IO]): IO[Lossless] = 15 | logger.info("Downloading metadata") >> 16 | Stream.unit.repeat 17 | .parEvalMapUnordered(10)(_ => 18 | swarm.connect 19 | .use(connection => DownloadMetadata(connection).timeout(1.minute)) 20 | .attempt 21 | ) 22 | .collectFirst { case Right(metadata) => 23 | metadata 24 | } 25 | .compile 26 | .lastOrError 27 | .flatTap(_ => logger.info("Metadata downloaded")) 28 | // Fetches the metadata over one connection through the peer's ut_metadata extension; fails with UtMetadataNotSupported when `utMetadata` is absent. 29 | def apply(connection: Connection)(using logger: Logger[IO]): IO[Lossless] = 30 | connection.extensionApi 31 | .flatMap(_.utMetadata.liftTo[IO](UtMetadataNotSupported())) 32 | .flatMap(_.fetch) 33 | 34 | case class UtMetadataNotSupported() extends Exception("UtMetadata is not supported") 35 | 36 | } 37 | -------------------------------------------------------------------------------- /files/src/main/scala/com/github/torrentdam/bittorrent/files/Reader.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.files 2 | 3 | import cats.effect.IO 4 | import com.github.torrentdam.bittorrent.TorrentMetadata 5 | import scodec.bits.ByteVector 6 | /** Maps a piece index to the byte ranges of the files that the piece covers. */ 7 | trait Reader { 8 | def read(pieceIndex: Long): List[Reader.ReadBytes] 9 | } 10 | 11 | object Reader { 12 | // The last piece is clipped to the total length; ReadBytes offsets are relative to the start of their file, endOffset exclusive. 13 | def fromTorrent(torrent: TorrentMetadata): Reader = apply(torrent.files, torrent.pieceLength) 14 | case class ReadBytes(file: TorrentMetadata.File, offset: Long, endOffset: Long) 15 | private[files] def apply(files: List[TorrentMetadata.File], pieceLength: Long): Reader = 16 | val totalLength = files.map(_.length).sum 17 | val fileOffsets = createFileOffsets(files) 18 | new { 19 | def read(pieceIndex: Long): List[ReadBytes] = 20 | val pieceStartOffset = pieceIndex * pieceLength 21 | val pieceEndOffset = math.min(pieceStartOffset + pieceLength, totalLength) 22 | val fileRanges = fileOffsets.matchFiles(pieceStartOffset, 
pieceEndOffset) 23 | fileRanges 24 | .map(range => 25 | ReadBytes( 26 | range.file, 27 | math.max(range.startOffset, pieceStartOffset) - range.startOffset, 28 | math.min(range.endOffset, pieceEndOffset) - range.startOffset 29 | ) 30 | ) 31 | .toList 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/extensions/ExtensionHandshake.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.protocol.extensions 2 | 3 | import cats.syntax.all.* 4 | import com.github.torrentdam.bencode 5 | import com.github.torrentdam.bencode.format.* 6 | import scodec.bits.ByteVector 7 | /** Extension-protocol handshake: `m` (extension name to message id) and optional `metadata_size`. */ 8 | case class ExtensionHandshake( 9 | extensions: Map[String, Long], 10 | metadataSize: Option[Long] 11 | ) 12 | 13 | object ExtensionHandshake { 14 | 15 | private val format = 16 | ( 17 | field[Map[String, Long]]("m"), 18 | fieldOptional[Long]("metadata_size") 19 | ).imapN(ExtensionHandshake.apply)(v => (v.extensions, v.metadataSize)) 20 | // Bencodes the handshake; the `.get` assumes `format.write` cannot fail for a well-typed value. 21 | def encode(handshake: ExtensionHandshake): ByteVector = 22 | bencode 23 | .encode(format.write(handshake).toOption.get) 24 | .toByteVector 25 | // Bencode decoding failures become BencodeError; a `format.read` failure becomes HandshakeFormatError. 26 | def decode(bytes: ByteVector): Either[Throwable, ExtensionHandshake] = 27 | for 28 | bc <- 29 | bencode 30 | .decode(bytes.bits) 31 | .leftMap(Error.BencodeError.apply) 32 | handshakeResponse <- 33 | ExtensionHandshake.format 34 | .read(bc) 35 | .leftMap(Error.HandshakeFormatError("Unable to parse handshake response", _)) 36 | yield handshakeResponse 37 | // Recoverable decoding failures: they extend Exception, not java.lang.Error, which is reserved for serious JVM-level problems. 38 | object Error { 39 | case class BencodeError(cause: Throwable) extends Exception(cause) 40 | case class HandshakeFormatError(message: String, cause: Throwable) extends Exception(message, cause) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- 
/bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/FileMappingSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import scodec.bits.ByteVector 4 | import FileMapping.FileSpan 5 | import TorrentMetadata.File 6 | 7 | class FileMappingSpec extends munit.FunSuite { 8 | 9 | // Fixture torrent named "test" with 10-byte pieces, no piece hashes and one file per given length. 10 | private def metadataOf(fileLengths: Int*): TorrentMetadata = 11 | TorrentMetadata( 12 | name = "test", 13 | pieceLength = 10, 14 | pieces = ByteVector.empty, 15 | files = fileLengths.toList.map(length => File(length, Nil)) 16 | ) 17 | 18 | // NOTE(review): FileSpan arguments read as (fileIndex, length, startPiece, startOffset, endPiece, endOffset) with an exclusive end; confirm against FileMapping. 19 | test("all files in one piece") { 20 | val expectation = FileMapping( 21 | List( 22 | FileSpan(0, 5, 0, 0, 0, 5), 23 | FileSpan(1, 3, 0, 5, 0, 8), 24 | FileSpan(2, 2, 0, 8, 1, 0) 25 | ), 26 | pieceLength = 10 27 | ) 28 | assert(FileMapping.fromMetadata(metadataOf(5, 3, 2)) == expectation) 29 | } 30 | 31 | test("file spans multiple pieces") { 32 | val expectation = FileMapping( 33 | List( 34 | FileSpan(0, 5, 0, 0, 0, 5), 35 | FileSpan(1, 13, 0, 5, 1, 8), 36 | FileSpan(2, 2, 1, 8, 2, 0) 37 | ), 38 | pieceLength = 10 39 | ) 40 | assert(FileMapping.fromMetadata(metadataOf(5, 13, 2)) == expectation) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/wire/RequestDispatcherSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import com.github.torrentdam.bittorrent.wire.RequestDispatcher 4 | import com.github.torrentdam.bittorrent.TorrentFile 5 | import com.github.torrentdam.bittorrent.TorrentMetadata 6 | import com.github.torrentdam.bencode 7 | import 
com.github.torrentdam.bencode.format.BencodeFormat 8 | import scodec.bits.BitVector 9 | 10 | class RequestDispatcherSpec extends munit.FunSuite { 11 | // For the single-file fixture torrent, the per-piece work must add up to the file size, both by work size and by summed request lengths. 12 | test("build request queue from torrent metadata") { 13 | val source = getClass.getClassLoader 14 | .getResourceAsStream("bencode/ubuntu-18.10-live-server-amd64.iso.torrent") 15 | .readAllBytes() 16 | val Right(result) = bencode.decode(BitVector(source)): @unchecked 17 | val torrentFile = summon[BencodeFormat[TorrentFile]].read(result).toOption.get 18 | assert( 19 | PartialFunction.cond(torrentFile.info) { 20 | case TorrentMetadata.Lossless(metadata @ TorrentMetadata(_, _, pieces, List(file)), _) => 21 | val fileSize = file.length 22 | val piecesTotal = (pieces.length.toDouble / 20).ceil.toInt 23 | val workGenerator = RequestDispatcher.WorkGenerator(metadata) 24 | val queue = 25 | (0 until piecesTotal).map(pieceIndex => workGenerator.pieceWork(pieceIndex)) 26 | assert(queue.map(_.size).toList.sum == fileSize) 27 | assert(queue.toList.flatMap(_.requests.toList).map(_.length).sum == fileSize) 28 | true 29 | } 30 | ) 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/extensions/metadata/UtMessage.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.protocol.extensions.metadata 2 | 3 | import cats.syntax.all.* 4 | import com.github.torrentdam.bencode 5 | import com.github.torrentdam.bencode.format.* 6 | import scodec.bits.ByteVector 7 | /** ut_metadata extension messages, encoded with msg_type 0 (Request), 1 (Data) or 2 (Reject). */ 8 | enum UtMessage: 9 | case Request(piece: Long) 10 | case Data(piece: Long, byteVector: ByteVector) 11 | case Reject(piece: Long) 12 | 13 | object UtMessage { 14 | 15 | val MessageFormat: BencodeFormat[(Long, Long)] = 16 | ( 17 | field[Long]("msg_type"), 18 | field[Long]("piece") 19 | ).tupled 20 | // Bencoded {msg_type, piece} header; a Data message appends its payload bytes right after it. 21 | def encode(message: UtMessage): ByteVector = { 22 | val (bc, extraBytes) =
23 | message match { 24 | case Request(piece) => (MessageFormat.write((0, piece)).toOption.get, none) 25 | case Data(piece, bytes) => (MessageFormat.write((1, piece)).toOption.get, bytes.some) 26 | case Reject(piece) => (MessageFormat.write((2, piece)).toOption.get, none) 27 | } 28 | bencode.encode(bc).toByteVector ++ extraBytes.getOrElse(ByteVector.empty) 29 | } 30 | // Inverse of `encode`: the header is read with decodeHead and the remaining bytes become the Data payload. An unrecognised msg_type is returned as a Left instead of escaping as a MatchError. 31 | def decode(bytes: ByteVector): Either[Throwable, UtMessage] = { 32 | bencode 33 | .decodeHead(bytes.toBitVector) 34 | .flatMap { case (remainder, result) => 35 | MessageFormat.read(result).flatMap { case (msgType, piece) => 36 | msgType match { 37 | case 0 => Right(Request(piece)) 38 | case 1 => Right(Data(piece, remainder.toByteVector)) 39 | case 2 => Right(Reject(piece)) 40 | case other => Left(new IllegalArgumentException(s"Unknown ut_metadata msg_type: $other")) 41 | } 42 | } 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/NodeInfo.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.std.Random 4 | import cats.syntax.all.* 5 | import cats.* 6 | import com.comcast.ip4s.* 7 | import com.github.torrentdam.bittorrent.InfoHash 8 | import scodec.bits.ByteVector 9 | 10 | final case class NodeInfo(id: NodeId, address: SocketAddress[IpAddress]) 11 | /** DHT node id; `int` is its value as an unsigned big-endian integer. */ 12 | final case class NodeId(bytes: ByteVector) { 13 | val int: BigInt = BigInt(1, bytes.toArray) 14 | } 15 | 16 | object NodeId { 17 | // XOR distance, read as an unsigned big-endian integer. 18 | private def distance(a: ByteVector, b: ByteVector): BigInt = BigInt(1, (a.xor(b)).toArray) 19 | 20 | def distance(a: NodeId, b: NodeId): BigInt = distance(a.bytes, b.bytes) 21 | 22 | def distance(a: NodeId, b: InfoHash): BigInt = distance(a.bytes, b.bytes) 23 | 24 | def random[F[_]](using Random[F], Monad[F]): F[NodeId] = { 25 | for bytes <- Random[F].nextBytes(20) 26 | yield NodeId(ByteVector.view(bytes)) 27 | } 28 | // 20-byte big-endian id for 0 <= int <= MaxValue, so that fromInt(n).int == n. BigInt.toByteArray is two's complement and may start with a 0x00 sign byte, which is dropped. The value is padded on the left: right-padding would multiply it by a power of 256. padLeft throws for values wider than 20 bytes. 29 | def fromInt(int: BigInt): NodeId = 30 | val raw = ByteVector.view(int.toByteArray) 31 | val unsigned = if raw.size > 1 && raw(0) == 0 then raw.tail else raw 32 | NodeId(unsigned.padLeft(20))
33 | // Random id in [from, until) (assuming nextDouble yields values in [0, 1)): `from` plus (until - from) scaled by that Double and truncated. 34 | def randomInRange[F[_]](from: BigInt, until: BigInt)(using Random[F], Monad[F]): F[NodeId] = 35 | val difference = BigDecimal(until - from) 36 | for randomDouble <- Random[F].nextDouble 37 | yield fromInt(from + (difference * randomDouble).toBigInt) 38 | 39 | given Show[NodeId] = nodeId => s"NodeId(${nodeId.bytes.toHex})" 40 | 41 | val MaxValue: BigInt = BigInt(1, Array.fill(20)(0xff.toByte)) 42 | } 43 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/QueryHandler.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.implicits.* 4 | import cats.Monad 5 | import com.comcast.ip4s.* 6 | import com.github.torrentdam.bittorrent.PeerInfo 7 | // Handles one incoming DHT query from `address`; returning None presumably means no reply is sent (`noop` always returns None). 8 | trait QueryHandler[F[_]] { 9 | def apply(address: SocketAddress[IpAddress], query: Query): F[Option[Response]] 10 | } 11 | 12 | object QueryHandler { 13 | 14 | def noop[F[_]: Monad]: QueryHandler[F] = (_, _) => none.pure[F] 15 | // Answers from `routingTable`: up to 8 good nodes for find_node (and for get_peers when no peers are stored), stored peers for get_peers; announce_peer records the sender's IP with the announced port. 16 | def simple[F[_]: Monad](selfId: NodeId, routingTable: RoutingTable[F]): QueryHandler[F] = { (address, query) => 17 | query match {
18 | case Query.Ping(_) => 19 | Response.Ping(selfId).some.pure[F] 20 | case Query.FindNode(_, target) => 21 | routingTable.goodNodes(target).map { nodes => 22 | Response.Nodes(selfId, nodes.take(8).toList).some 23 | } 24 | case Query.GetPeers(_, infoHash) => 25 | routingTable.findPeers(infoHash).flatMap { 26 | case Some(peers) => 27 | Response.Peers(selfId, peers.toList).some.pure[F] 28 | case None => 29 | routingTable 30 | .goodNodes(NodeId(infoHash.bytes)) 31 | .map { nodes => 32 | Response.Nodes(selfId, nodes.take(8).toList).some 33 | } 34 | } 35 | case Query.AnnouncePeer(_, infoHash, port) => 36 | Option.when(port.isValidInt)(port.toInt).flatMap(Port.fromInt) match { 37 | case Some(validPort) => 38 | routingTable 39 | .addPeer(infoHash, PeerInfo(SocketAddress(address.host, validPort))) 40 | .as(Response.Ping(selfId).some) 41 | case None => 42 | // The port is taken from the query itself; an out-of-range value is ignored (no peer stored, None returned) instead of throwing from Option.get. 43 | none[Response].pure[F] 44 | } 45 | case Query.SampleInfoHashes(_, _) => 46 | Response.SampleInfoHashes(selfId, None, List.empty).some.pure[F] 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /files/src/main/scala/com/github/torrentdam/bittorrent/files/Writer.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.files 2 | 3 | import cats.effect.IO 4 | import com.github.torrentdam.bittorrent.TorrentMetadata 5 | import scodec.bits.ByteVector 6 | 7 | import util.chaining.scalaUtilChainingOps 8 | /** Splits a piece's bytes into per-file writes. */ 9 | trait Writer { 10 | def write(index: Long, bytes: ByteVector): List[Writer.WriteBytes] 11 | } 12 | 13 | object Writer { 14 | // Each WriteBytes targets `file` at byte `offset` from the start of that file. 15 | def fromTorrent(torrent: TorrentMetadata): Writer = apply(torrent.files, torrent.pieceLength) 16 | case class WriteBytes(file: TorrentMetadata.File, offset: Long, bytes: ByteVector) 17 | private[files] def apply(files: List[TorrentMetadata.File], pieceLength: Long): Writer = 18 | val fileOffsets = createFileOffsets(files) 19 | // Walks the overlapping files in order, giving each the next chunk of the piece that fits in it; the mutable state advances as the returned iterator is consumed. 20 | def distribute(pieceOffset: Long, byteVector: ByteVector, files: Iterator[FileRange]): Iterator[WriteBytes] = 21 | var remainingBytes = byteVector 22 | var bytesOffset = pieceOffset 23 | files.map: fileRange => 24 | val offsetInFile = bytesOffset - fileRange.startOffset 25 | val lengthToWrite = math.min(remainingBytes.length, fileRange.file.length - offsetInFile) 26 | val (bytesToWrite, rem) = remainingBytes.splitAt(lengthToWrite) 27 | remainingBytes = rem 28 | bytesOffset += lengthToWrite 29 | WriteBytes( 30 | fileRange.file, 31 | offsetInFile, 32 | bytesToWrite 33 | ) 34 | new { 35 | def write(pieceIndex: Long, bytes: ByteVector): List[WriteBytes] = 36 | val pieceOffset = pieceIndex * pieceLength 37 | val files = fileOffsets.matchFiles(pieceOffset, pieceOffset + 
bytes.length) 38 | val writes = distribute(pieceOffset, bytes, files) 39 | writes.toList 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /common/src/main/scala/com/github/torrentdam/bittorrent/MagnetLink.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import com.github.torrentdam.bittorrent.InfoHash 4 | import java.net.URLDecoder 5 | import java.nio.charset.Charset 6 | 7 | case class MagnetLink(infoHash: InfoHash, displayName: Option[String], trackers: List[String]) 8 | 9 | object MagnetLink { 10 | // Parses a `magnet:?` URI. Returns None if the prefix is missing, the query is not valid percent-encoding, or there is no `urn:btih:` topic holding a valid info-hash. 11 | def fromString(source: String): Option[MagnetLink] = 12 | source match 13 | case s"magnet:?$query" => fromQueryString(query) 14 | case _ => None 15 | 16 | private def fromQueryString(str: String) = 17 | for 18 | params <- parseQueryString(str) 19 | infoHash <- getInfoHash(params) 20 | displayName = getDisplayName(params) 21 | trackers = getTrackers(params) 22 | yield MagnetLink(infoHash, displayName, trackers) 23 | 24 | private type Query = Map[String, List[String]] 25 | 26 | // First `xt` value of the form urn:btih:<info-hash>; other `xt` values in the same link are skipped rather than invalidating the whole link. 27 | private def getInfoHash(query: Query): Option[InfoHash] = 28 | query.get("xt").flatMap(_.collectFirst { case s"urn:btih:${InfoHash.fromString(ih)}" => ih }) 29 | 30 | private def getDisplayName(query: Query): Option[String] = 31 | query.get("dn").flatMap(_.headOption) 32 | 33 | private def getTrackers(query: Query): List[String] = 34 | query.get("tr").toList.flatten 35 | 36 | // Splits on '&', then on the first '=' of each parameter, and percent-decodes key and value separately, so '=' inside a value (including a trailing one) is kept. URLDecoder throws IllegalArgumentException on malformed escapes such as a dangling '%'; that case maps to None. 37 | private def parseQueryString(str: String): Option[Query] = 38 | try 39 | Some( 40 | str 41 | .split('&') 42 | .toList 43 | .map { param => 44 | param.indexOf('=') match 45 | case -1 => (urlDecode(param), "") 46 | case separator => (urlDecode(param.substring(0, separator)), urlDecode(param.substring(separator + 1))) 47 | } 48 | .groupMap(_._1)(_._2) 49 | ) 50 | catch case _: IllegalArgumentException => None 51 | 52 | private def urlDecode(value: String) = 53 | URLDecoder.decode(value, Charset.forName("UTF-8")) 54 | } 55 | -------------------------------------------------------------------------------- 
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Torrent.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.effect.implicits.* 4 | import cats.effect.kernel.syntax.all.* 5 | import cats.effect.kernel.Async 6 | import cats.effect.kernel.Resource 7 | import cats.effect.IO 8 | import cats.effect.Resource 9 | import cats.implicits.* 10 | import com.github.torrentdam.bittorrent.TorrentMetadata 11 | import com.github.torrentdam.bittorrent.TorrentMetadata.Lossless 12 | import org.legogroup.woof.given 13 | import org.legogroup.woof.Logger 14 | import scala.collection.immutable.BitSet 15 | import scodec.bits.ByteVector 16 | 17 | trait Torrent { 18 | def metadata: TorrentMetadata.Lossless 19 | def stats: IO[Torrent.Stats] 20 | def downloadPiece(index: Long): IO[ByteVector] 21 | } 22 | 23 | object Torrent { 24 | 25 | def make( 26 | metadata: TorrentMetadata.Lossless, 27 | swarm: Swarm 28 | )(using logger: Logger[IO]): Resource[IO, Torrent] = 29 | for 30 | requestDispatcher <- RequestDispatcher(metadata.parsed) 31 | _ <- Download(swarm, requestDispatcher).background 32 | yield 33 | val metadata0 = metadata 34 | new Torrent { 35 | def metadata: TorrentMetadata.Lossless = metadata0 36 | def stats: IO[Stats] = 37 | for 38 | connected <- swarm.connected.list 39 | availability <- connected.traverse(_.availability.get) 40 | availability <- availability.foldMap(identity).pure[IO] 41 | yield Stats(connected.size, availability) 42 | def downloadPiece(index: Long): IO[ByteVector] = 43 | requestDispatcher.downloadPiece(index) 44 | } 45 | end for 46 | 47 | case class Stats( 48 | connected: Int, 49 | availability: BitSet 50 | ) 51 | 52 | enum Error extends Exception: 53 | case EmptyMetadata() 54 | } 55 | -------------------------------------------------------------------------------- 
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableRefresh.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.IO 4 | import cats.syntax.all.* 5 | import cats.effect.cps.{*, given} 6 | import cats.effect.std.Random 7 | import org.legogroup.woof.{Logger, given} 8 | 9 | import scala.concurrent.duration.{DurationInt, FiniteDuration} 10 | 11 | class RoutingTableRefresh(table: RoutingTable[IO], client: Client, discovery: PeerDiscovery)(using logger: Logger[IO], random: Random[IO]): 12 | 13 | def runOnce: IO[Unit] = async[IO]: 14 | val buckets = table.buckets.await 15 | val (fresh, stale) = buckets.toList.partition(_.nodes.values.exists(_.isGood)) 16 | if stale.nonEmpty then 17 | refreshBuckets(stale).await 18 | val nodes = fresh.flatMap(_.nodes.values) 19 | pingNodes(nodes).await 20 | 21 | def runEvery(period: FiniteDuration): IO[Unit] = 22 | IO 23 | .sleep(period) 24 | .productR(runOnce) 25 | .foreverM 26 | .handleErrorWith: e => 27 | logger.error(s"PingRoutine failed: $e") 28 | .foreverM 29 | 30 | private def pingNodes(nodes: List[RoutingTable.Node]) = async[IO]: 31 | logger.info(s"Pinging ${nodes.size} nodes").await 32 | val results = nodes 33 | .parTraverse { node => 34 | client.ping(node.address).timeout(5.seconds).attempt.map(_.bimap(_ => node.id, _ => node.id)) 35 | } 36 | .await 37 | val (bad, good) = results.partitionMap(identity) 38 | logger.info(s"Got ${good.size} good nodes and ${bad.size} bad nodes").await 39 | table.updateGoodness(good.toSet, bad.toSet).await 40 | 41 | private def refreshBuckets(buckets: List[RoutingTable.TreeNode.Bucket]) = async[IO]: 42 | logger.info(s"Found ${buckets.size} stale buckets").await 43 | buckets 44 | .parTraverse: bucket => 45 | val randomId = NodeId.randomInRange(bucket.from, bucket.until).await 46 | discovery.findNodes(randomId).take(32).compile.drain 47 | .await 48 | 49 | 50 | 
-------------------------------------------------------------------------------- /dht/src/test/scala/com/github/torrentdam/bittorrent/dht/MessageFormatSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import com.github.torrentdam.bittorrent.InfoHash 4 | import com.github.torrentdam.bencode.format.BencodeFormat 5 | import com.github.torrentdam.bencode.Bencode 6 | import scodec.bits.ByteVector 7 | 8 | class MessageFormatSpec extends munit.FunSuite { 9 | 10 | test("decode ping response") { 11 | val input = Bencode.BDictionary( 12 | "ip" -> Bencode.BString(ByteVector.fromValidHex("1f14bdfa9f21")), 13 | "y" -> Bencode.BString("q"), 14 | "t" -> Bencode.BString(ByteVector.fromValidHex("6a76679c")), 15 | "a" -> Bencode.BDictionary( 16 | "id" -> Bencode.BString(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")) 17 | ), 18 | "q" -> Bencode.BString("ping") 19 | ) 20 | 21 | val result = summon[BencodeFormat[Message]].read(input) 22 | val expectation = Right( 23 | Message.QueryMessage( 24 | ByteVector.fromValidHex("6a76679c"), 25 | Query.Ping(NodeId(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267"))) 26 | ) 27 | ) 28 | 29 | assert(result == expectation) 30 | } 31 | 32 | test("decode announce_peer query") { 33 | val input = Bencode.BDictionary( 34 | "t" -> Bencode.BString(ByteVector.fromValidHex("6a76679c")), 35 | "y" -> Bencode.BString("q"), 36 | "q" -> Bencode.BString("announce_peer"), 37 | "a" -> Bencode.BDictionary( 38 | "id" -> Bencode.BString(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")), 39 | "info_hash" -> Bencode.BString(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")), 40 | "port" -> Bencode.BInteger(9999) 41 | ) 42 | ) 43 | val result = summon[BencodeFormat[Message]].read(input) 44 | val expectation = Right( 45 | Message.QueryMessage( 46 | ByteVector.fromValidHex("6a76679c"), 47 | 
Query.AnnouncePeer( 48 | NodeId(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")), 49 | InfoHash(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")), 50 | 9999L 51 | ) 52 | ) 53 | ) 54 | assert(result == expectation) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/MessageSocket.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.* 4 | import cats.effect.Async 5 | import cats.effect.Concurrent 6 | import cats.effect.IO 7 | import cats.effect.Resource 8 | import cats.syntax.all.* 9 | import com.comcast.ip4s.* 10 | import com.comcast.ip4s.ip 11 | import com.github.torrentdam.bencode.decode 12 | import com.github.torrentdam.bencode.encode 13 | import com.github.torrentdam.bencode.format.BencodeFormat 14 | import fs2.io.net.Datagram 15 | import fs2.io.net.DatagramSocket 16 | import fs2.io.net.DatagramSocketGroup 17 | import fs2.io.net.Network 18 | import fs2.Chunk 19 | import java.net.InetSocketAddress 20 | import org.legogroup.woof.given 21 | import org.legogroup.woof.Logger 22 | 23 | class MessageSocket(socket: DatagramSocket[IO], logger: Logger[IO]) { 24 | import MessageSocket.Error 25 | 26 | def readMessage: IO[(SocketAddress[IpAddress], Message)] = 27 | for { 28 | datagram <- socket.read 29 | bc <- IO.fromEither( 30 | decode(datagram.bytes.toBitVector).leftMap(Error.BecodeSerialization.apply) 31 | ) 32 | message <- IO.fromEither( 33 | summon[BencodeFormat[Message]] 34 | .read(bc) 35 | .leftMap(e => Error.MessageFormat(s"Filed to read message from bencode: $bc", e)) 36 | ) 37 | _ <- logger.trace(s"<<< ${datagram.remote} $message") 38 | } yield (datagram.remote, message) 39 | 40 | def writeMessage(address: SocketAddress[IpAddress], message: Message): IO[Unit] = IO.defer { 41 | val bc = 
summon[BencodeFormat[Message]].write(message).toOption.get 42 | val bytes = encode(bc) 43 | val packet = Datagram(address, Chunk.byteVector(bytes.bytes)) 44 | socket.write(packet) >> logger.trace(s">>> $address $message") 45 | } 46 | } 47 | 48 | object MessageSocket { 49 | 50 | def apply( 51 | port: Option[Port] 52 | )(using logger: Logger[IO]): Resource[IO, MessageSocket] = 53 | Network[IO] 54 | .openDatagramSocket(Some(ip"0.0.0.0"), port) 55 | .map(socket => new MessageSocket(socket, logger)) 56 | 57 | object Error { 58 | case class BecodeSerialization(cause: Throwable) extends Throwable(cause) 59 | case class MessageFormat(message: String, cause: Throwable) extends Throwable(message, cause) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/FileMapping.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import com.github.torrentdam.bittorrent.FileMapping.FileSpan 4 | import com.github.torrentdam.bittorrent.FileMapping.PieceInFile 5 | import com.github.torrentdam.bittorrent.TorrentMetadata 6 | 7 | case class FileMapping(value: List[FileSpan], pieceLength: Long) { 8 | def mapToFiles(pieceIndex: Long, length: Long): List[PieceInFile] = 9 | value 10 | .filter(span => span.beginIndex <= pieceIndex && span.endIndex >= pieceIndex) 11 | .map(span => pieceInFile(span, pieceIndex, length)) 12 | 13 | private def pieceInFile(span: FileSpan, pieceIndex: Long, length: Long): PieceInFile = 14 | val beginOffset = (pieceIndex - span.beginIndex) * pieceLength + span.beginOffset 15 | val writeLength = Math.min(length, span.length - beginOffset) 16 | PieceInFile(span.fileIndex, beginOffset, writeLength) 17 | } 18 | 19 | object FileMapping { 20 | case class FileSpan( 21 | fileIndex: Int, 22 | length: Long, 23 | beginIndex: Long, 24 | beginOffset: Long, 25 | endIndex: Long, 26 | 
endOffset: Long 27 | ) 28 | case class PieceInFile( 29 | fileIndex: Int, 30 | offset: Long, 31 | length: Long 32 | ) 33 | def fromMetadata(torrentMetadata: TorrentMetadata): FileMapping = 34 | def forFile(fileIndex: Int, beginIndex: Long, beginOffset: Long, length: Long): FileSpan = { 35 | val spansPieces = length / torrentMetadata.pieceLength 36 | val remainder = length % torrentMetadata.pieceLength 37 | val endIndex = beginIndex + spansPieces + (beginOffset + remainder) / torrentMetadata.pieceLength 38 | val endOffset = (beginOffset + remainder) % torrentMetadata.pieceLength 39 | FileSpan(fileIndex, length, beginIndex, beginOffset, endIndex, endOffset) 40 | } 41 | case class State(beginIndex: Long, beginOffset: Long, spans: List[FileSpan]) 42 | val spans = 43 | torrentMetadata.files.zipWithIndex 44 | .foldLeft(State(0L, 0L, Nil)) { case (state, (file, index)) => 45 | val span = forFile(index, state.beginIndex, state.beginOffset, file.length) 46 | State(span.endIndex, span.endOffset, span :: state.spans) 47 | } 48 | .spans 49 | .reverse 50 | FileMapping(spans, torrentMetadata.pieceLength) 51 | } 52 | -------------------------------------------------------------------------------- /files/src/test/scala/com/github/torrentdam/bittorrent/files/WriterSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.files 2 | 3 | import com.github.torrentdam.bittorrent.TorrentMetadata 4 | import scodec.bits.ByteVector 5 | 6 | class WriterSpec extends munit.FunSuite { 7 | 8 | test("one piece one file") { 9 | val file = 10 | TorrentMetadata.File( 11 | 1L, 12 | List("foo.txt") 13 | ) 14 | val writer = Writer( 15 | List(file), 16 | 1 17 | ) 18 | val writes = writer.write(0, ByteVector(1)) 19 | assertEquals(writes, List(Writer.WriteBytes(file, 0, ByteVector(1)))) 20 | } 21 | 22 | test("write piece to second file") { 23 | val files = 24 | List( 25 | TorrentMetadata.File( 26 | 1L, 27 | List("foo.txt") 28 | 
), 29 | TorrentMetadata.File( 30 | 1L, 31 | List("bar.txt") 32 | ) 33 | ) 34 | val writer = Writer( 35 | files, 36 | 1 37 | ) 38 | val writes = writer.write(1, ByteVector(2)) 39 | assertEquals(writes, List(Writer.WriteBytes(files(1), 0, ByteVector(2)))) 40 | } 41 | 42 | test("write piece to both files") { 43 | val files = 44 | List( 45 | TorrentMetadata.File( 46 | 1L, 47 | List("foo.txt") 48 | ), 49 | TorrentMetadata.File( 50 | 1L, 51 | List("bar.txt") 52 | ) 53 | ) 54 | val writer = Writer( 55 | files, 56 | 2 57 | ) 58 | val writes = writer.write(0, ByteVector(0, 1)) 59 | assertEquals( 60 | writes, 61 | List( 62 | Writer.WriteBytes(files(0), 0, ByteVector(0)), 63 | Writer.WriteBytes(files(1), 0, ByteVector(1)) 64 | ) 65 | ) 66 | } 67 | 68 | test("start writing in file with offset") { 69 | val files = 70 | List( 71 | TorrentMetadata.File( 72 | 3L, 73 | List("foo.txt") 74 | ), 75 | TorrentMetadata.File( 76 | 1L, 77 | List("bar.txt") 78 | ) 79 | ) 80 | val writer = Writer( 81 | files, 82 | 2 83 | ) 84 | val writes = writer.write(1, ByteVector(0, 1)) 85 | assertEquals( 86 | writes, 87 | List( 88 | Writer.WriteBytes(files(0), 2, ByteVector(0)), 89 | Writer.WriteBytes(files(1), 0, ByteVector(1)) 90 | ) 91 | ) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /dht/src/test/scala/com/github/torrentdam/bittorrent/dht/PeerDiscoverySpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.kernel.Ref 4 | import cats.effect.IO 5 | import cats.effect.SyncIO 6 | import com.comcast.ip4s.* 7 | import com.github.torrentdam.bittorrent.InfoHash 8 | import com.github.torrentdam.bittorrent.PeerInfo 9 | import org.legogroup.woof.Logger 10 | import scodec.bits.ByteVector 11 | 12 | class PeerDiscoverySpec extends munit.CatsEffectSuite { 13 | 14 | test("discover new peers") { 15 | 16 | val infoHash = 
InfoHash(ByteVector.encodeUtf8("c").toOption.get) 17 | 18 | def nodeId(id: String) = NodeId(ByteVector.encodeUtf8(id).toOption.get) 19 | 20 | given logger: Logger[IO] = NoOpLogger() 21 | 22 | def getPeers( 23 | address: SocketAddress[IpAddress], 24 | infoHash: InfoHash 25 | ): IO[Either[Response.Nodes, Response.Peers]] = IO { 26 | address.port.value match { 27 | case 1 => 28 | Left( 29 | Response.Nodes( 30 | nodeId("a"), 31 | List( 32 | NodeInfo( 33 | nodeId("b"), 34 | SocketAddress(ip"1.1.1.1", port"2") 35 | ), 36 | NodeInfo( 37 | nodeId("c"), 38 | SocketAddress(ip"1.1.1.1", port"3") 39 | ) 40 | ) 41 | ) 42 | ) 43 | case 2 => 44 | Right( 45 | Response.Peers( 46 | nodeId("b"), 47 | List( 48 | PeerInfo( 49 | SocketAddress(ip"2.2.2.2", port"2") 50 | ) 51 | ) 52 | ) 53 | ) 54 | case 3 => 55 | Right( 56 | Response.Peers( 57 | nodeId("c"), 58 | List( 59 | PeerInfo( 60 | SocketAddress(ip"2.2.2.2", port"3") 61 | ) 62 | ) 63 | ) 64 | ) 65 | } 66 | } 67 | 68 | for { 69 | state <- PeerDiscovery.DiscoveryState( 70 | initialNodes = List( 71 | NodeInfo( 72 | nodeId("a"), 73 | SocketAddress(ip"1.1.1.1", port"1") 74 | ) 75 | ), 76 | infoHash = infoHash 77 | ) 78 | list <- PeerDiscovery.start(infoHash, getPeers, state, 1).take(1).compile.toList 79 | } yield { 80 | assertEquals( 81 | list, 82 | List( 83 | PeerInfo(SocketAddress(ip"2.2.2.2", port"3")) 84 | ) 85 | ) 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/wire/WorkQueueSuite.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.effect.std.CountDownLatch 4 | import cats.effect.syntax.temporal.genTemporalOps_ 5 | import cats.effect.IO 6 | import cats.effect.Resource 7 | import com.github.torrentdam.bittorrent.wire.RequestDispatcher.WorkQueue 8 | import 
com.github.torrentdam.bittorrent.wire.RequestDispatcher.WorkQueue.EmptyQueue 9 | import scala.concurrent.duration.DurationInt 10 | import scodec.bits.ByteVector 11 | 12 | class WorkQueueSuite extends munit.CatsEffectSuite { 13 | 14 | test("return request")( 15 | for 16 | workQueue <- WorkQueue(Seq(1), _ => IO.unit) 17 | request <- workQueue.nextRequest.use((request, _) => IO.pure(request)) 18 | yield assertEquals(request, 1) 19 | ) 20 | 21 | test("put request back into queue if it was not completed")( 22 | for 23 | workQueue <- WorkQueue(Seq(1, 2), _ => IO.unit) 24 | request0 <- workQueue.nextRequest.use((request, _) => IO.pure(request)) 25 | request1 <- workQueue.nextRequest.use((request, _) => IO.pure(request)) 26 | yield 27 | assertEquals(request0, 1) 28 | assertEquals(request1, 1) 29 | ) 30 | test("delete request from queue if it was completed")( 31 | for 32 | workQueue <- WorkQueue(Seq(1, 2), _ => IO.unit) 33 | request0 <- workQueue.nextRequest.use((request, promise) => promise.complete(()).as(request)) 34 | request1 <- workQueue.nextRequest.use((request, _) => IO.pure(request)) 35 | yield 36 | assertEquals(request0, 1) 37 | assertEquals(request1, 2) 38 | ) 39 | test("throw PieceComplete when last request was fulfilled")( 40 | for 41 | workQueue <- WorkQueue(Seq(1), _ => IO.unit) 42 | request <- workQueue.nextRequest.use((request, promise) => promise.complete(ByteVector.empty).as(request)) 43 | result <- workQueue.nextRequest.use((request, _) => IO.pure(request)).attempt 44 | yield 45 | assertEquals(request, 1) 46 | assertEquals(result, Left(WorkQueue.PieceComplete)) 47 | ) 48 | test("throw EmptyQueue when queue is empty")( 49 | for 50 | workQueue <- WorkQueue(Seq(1), _ => IO.unit) 51 | finishFirst <- CountDownLatch[IO](1) 52 | fiber0 <- workQueue.nextRequest.use((request, _) => finishFirst.await.as(request)).start 53 | fiber1 <- workQueue.nextRequest.use((request, _) => IO.pure(request)).start 54 | request1 <- fiber1.joinWithNever.attempt 55 | _ <- 
finishFirst.release 56 | request0 <- fiber0.joinWithNever 57 | yield 58 | assertEquals(request0, 1) 59 | assertEquals(request1, Left(EmptyQueue)) 60 | ) 61 | } 62 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/message.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.protocol.message 2 | 3 | import com.github.torrentdam.bittorrent.InfoHash 4 | import com.github.torrentdam.bittorrent.PeerId 5 | import scala.util.chaining.* 6 | import scodec.bits.ByteVector 7 | import scodec.codecs.* 8 | import scodec.Codec 9 | 10 | final case class Handshake( 11 | extensionProtocol: Boolean, 12 | infoHash: InfoHash, 13 | peerId: PeerId 14 | ) 15 | 16 | object Handshake { 17 | val ProtocolStringCodec: Codec[Unit] = uint8.unit(19) ~> fixedSizeBytes( 18 | 19, 19 | utf8.unit("BitTorrent protocol") 20 | ) 21 | val ReserveCodec: Codec[Boolean] = bits(8 * 8).xmap( 22 | bv => bv.get(43), 23 | supported => 24 | ByteVector 25 | .fill(8)(0) 26 | .toBitVector 27 | .pipe(v => if supported then v.set(43) else v) 28 | ) 29 | val InfoHashCodec: Codec[InfoHash] = bytes(20).xmap(InfoHash(_), _.bytes) 30 | val PeerIdCodec: Codec[PeerId] = bytes(20).xmap(PeerId.apply, _.bytes) 31 | val HandshakeCodec: Codec[Handshake] = 32 | (ProtocolStringCodec ~> ReserveCodec :: InfoHashCodec :: PeerIdCodec).as 33 | } 34 | 35 | enum Message: 36 | case KeepAlive 37 | case Choke 38 | case Unchoke 39 | case Interested 40 | case NotInterested 41 | case Have(pieceIndex: Long) 42 | case Bitfield(bytes: ByteVector) 43 | case Request(index: Long, begin: Long, length: Long) 44 | case Piece(index: Long, begin: Long, bytes: ByteVector) 45 | case Cancel(index: Long, begin: Long, length: Long) 46 | case Port(port: Int) 47 | case Extended(id: Long, payload: ByteVector) 48 | 49 | object Message { 50 | 51 | val MessageSizeCodec: Codec[Long] 
= uint32 52 | 53 | val MessageBodyCodec: Codec[Message] = { 54 | val KeepAliveCodec: Codec[KeepAlive.type] = provide(KeepAlive).complete 55 | 56 | val OtherMessagesCodec: Codec[Message] = 57 | discriminated[Message] 58 | .by(uint8) 59 | .caseP(0) { case m @ Choke => m }(identity)(provide(Choke)) 60 | .caseP(1) { case m @ Unchoke => m }(identity)(provide(Unchoke)) 61 | .caseP(2) { case m @ Interested => m }(identity)(provide(Interested)) 62 | .caseP(3) { case m @ NotInterested => m }(identity)(provide(NotInterested)) 63 | .caseP(4) { case Have(index) => index }(Have.apply)(uint32) 64 | .caseP(5) { case Bitfield(bytes) => bytes }(Bitfield.apply)(bytes) 65 | .caseP(6) { case m: Request => m }(identity)((uint32 :: uint32 :: uint32).as) 66 | .caseP(7) { case m: Piece => m }(identity)((uint32 :: uint32 :: bytes).as) 67 | .caseP(8) { case m: Cancel => m }(identity)((uint32 :: uint32 :: uint32).as) 68 | .caseP(9) { case Port(port) => port }(Port.apply)(uint16) 69 | .caseP(20) { case m: Extended => m }(identity)((ulong(8) :: bytes).as) 70 | 71 | choice( 72 | KeepAliveCodec.upcast, 73 | OtherMessagesCodec 74 | ) 75 | } 76 | 77 | val MessageCodec: Codec[Message] = { 78 | variableSizeBytesLong( 79 | MessageSizeCodec, 80 | MessageBodyCodec 81 | ) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: [push] 4 | 5 | jobs: 6 | 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: Compile 12 | run: ./mill _.compile 13 | - name: Test 14 | run: ./mill _.test 15 | - name: Build distributable 16 | run: ./mill cmd.assembly 17 | - name: Set up Docker Buildx 18 | id: buildx 19 | uses: docker/setup-buildx-action@master 20 | - name: Login to GitHub Container Registry 21 | uses: docker/login-action@v1 22 | with: 23 | registry: ghcr.io 24 | username: ${{ github.actor }} 
25 | password: ${{ secrets.GITHUB_TOKEN }} 26 | - name: Push cmd images 27 | uses: docker/build-push-action@v2 28 | with: 29 | context: . 30 | file: cmd/Dockerfile 31 | platforms: linux/amd64,linux/arm64 32 | push: true 33 | tags: ghcr.io/torrentdam/cmd:latest 34 | 35 | build-native: 36 | if: false 37 | runs-on: ubuntu-latest 38 | needs: build 39 | steps: 40 | - uses: actions/checkout@v1 41 | - uses: actions/cache@v1 42 | with: 43 | path: ~/.cache/coursier/v1 44 | key: ${{ runner.os }}-coursier-${{ hashFiles('**/build.sbt') }} 45 | - name: Set up java 46 | uses: actions/setup-java@v2.1.0 47 | with: 48 | distribution: adopt 49 | java-version: 17 50 | java-package: jre 51 | - name: Build Native 52 | env: 53 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 54 | run: sbt -J-Xmx4g cmdNative/nativeLink 55 | - name: Set up Docker Buildx 56 | id: buildx 57 | uses: docker/setup-buildx-action@master 58 | - name: Login to GitHub Container Registry 59 | uses: docker/login-action@v1 60 | with: 61 | registry: ghcr.io 62 | username: ${{ github.actor }} 63 | password: ${{ secrets.GITHUB_TOKEN }} 64 | - name: Push cmd images 65 | uses: docker/build-push-action@v2 66 | with: 67 | context: cmd 68 | file: cmd/Dockerfile-native 69 | platforms: linux/amd64 70 | push: true 71 | tags: ghcr.io/torrentdam/cmd-native:latest 72 | 73 | release: 74 | if: startsWith(github.ref, 'refs/tags/v') 75 | needs: build 76 | runs-on: ubuntu-latest 77 | steps: 78 | - uses: actions/checkout@v1 79 | - name: Set up java 80 | uses: actions/setup-java@v2.1.0 81 | with: 82 | distribution: adopt 83 | java-version: 17 84 | java-package: jre 85 | - name: Release 86 | env: 87 | SONATYPE_CREDS: ${{ secrets.SONATYPE_CREDS }} 88 | PGP_SECRET_KEY: ${{ secrets.PGP_SECRET_KEY }} 89 | run: | 90 | echo ${PGP_SECRET_KEY} | base64 --decode | gpg --import 91 | gpg --list-secret-keys 92 | export VERSION=${GITHUB_REF#*/v} 93 | echo Publishing $VERSION 94 | sbt commonJVM/publishSigned 95 | sbt commonJS/publishSigned 96 | sbt 
bittorrentJVM/publishSigned 97 | sbt dhtJVM/publishSigned 98 | sbt cmdJVM/publishSigned 99 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/TorrentMetadata.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import cats.implicits.* 4 | import com.github.torrentdam.bencode 5 | import com.github.torrentdam.bencode.format.* 6 | import com.github.torrentdam.bencode.Bencode 7 | import com.github.torrentdam.bencode.BencodeFormatException 8 | import java.time.Instant 9 | import scodec.bits.ByteVector 10 | 11 | case class TorrentMetadata( 12 | name: String, 13 | pieceLength: Long, 14 | pieces: ByteVector, 15 | files: List[TorrentMetadata.File] 16 | ) 17 | 18 | object TorrentMetadata { 19 | 20 | case class File( 21 | length: Long, 22 | path: List[String] 23 | ) 24 | 25 | given BencodeFormat[File] = 26 | ( 27 | field[Long]("length"), 28 | field[List[String]]("path") 29 | ).imapN(File.apply)(v => (v.length, v.path)) 30 | 31 | given BencodeFormat[TorrentMetadata] = { 32 | def to(name: String, pieceLength: Long, pieces: ByteVector, length: Option[Long], filesOpt: Option[List[File]]) = { 33 | val files = length match { 34 | case Some(length) => List(File(length, List(name))) 35 | case None => filesOpt.combineAll 36 | } 37 | TorrentMetadata(name, pieceLength, pieces, files) 38 | } 39 | def from(v: TorrentMetadata) = 40 | (v.name, v.pieceLength, v.pieces, Option.empty[Long], Some(v.files)) 41 | ( 42 | field[String]("name"), 43 | field[Long]("piece length"), 44 | field[ByteVector]("pieces"), 45 | fieldOptional[Long]("length"), 46 | fieldOptional[List[File]]("files") 47 | ).imapN(to)(from) 48 | } 49 | 50 | case class Lossless private ( 51 | parsed: TorrentMetadata, 52 | raw: Bencode 53 | ) { 54 | def infoHash: InfoHash = InfoHash(CrossPlatform.sha1(bencode.encode(raw).toByteVector)) 55 | } 56 | 
57 | object Lossless { 58 | def fromBytes(bytes: ByteVector): Either[Throwable, Lossless] = 59 | bencode.decode(bytes.bits).flatMap(fromBencode) 60 | 61 | def fromBencode(bcode: Bencode): Either[BencodeFormatException, Lossless] = 62 | summon[BencodeFormat[TorrentMetadata]].read(bcode).map { metadata => 63 | Lossless(metadata, bcode) 64 | } 65 | given BencodeFormat[Lossless] = { 66 | BencodeFormat( 67 | read = BencodeReader(fromBencode), 68 | write = BencodeWriter(metadata => BencodeFormat.BencodeValueFormat.write(metadata.raw)) 69 | ) 70 | } 71 | } 72 | } 73 | 74 | case class TorrentFile( 75 | info: TorrentMetadata.Lossless, 76 | creationDate: Option[Instant] 77 | ) 78 | 79 | object TorrentFile { 80 | 81 | private given BencodeFormat[Instant] = 82 | BencodeFormat.LongFormat.imap(Instant.ofEpochMilli)(_.toEpochMilli) 83 | 84 | given torrentFileFormat: BencodeFormat[TorrentFile] = { 85 | ( 86 | field[TorrentMetadata.Lossless]("info"), 87 | fieldOptional[Instant]("creationDate") 88 | ).imapN(TorrentFile(_, _))(v => (v.info, v.creationDate)) 89 | } 90 | 91 | def fromBencode(bcode: Bencode): Either[BencodeFormatException, TorrentFile] = 92 | torrentFileFormat.read(bcode) 93 | def fromBytes(bytes: ByteVector): Either[Throwable, TorrentFile] = 94 | bencode.decode(bytes.bits).flatMap(fromBencode) 95 | 96 | def toBytes(torrentFile: TorrentFile): ByteVector = 97 | bencode.encode(torrentFileFormat.write(torrentFile).right.get).toByteVector 98 | } 99 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableBootstrap.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.kernel.Temporal 4 | import cats.implicits.* 5 | import cats.MonadError 6 | import cats.effect.IO 7 | import cats.effect.implicits.* 8 | import cats.effect.cps.{given, *} 9 | import com.comcast.ip4s.* 10 | 
import com.github.torrentdam.bittorrent.InfoHash
11 | import org.legogroup.woof.given
12 | import org.legogroup.woof.Logger
13 | import fs2.Stream
14 |
15 | import scala.concurrent.duration.*
16 |
17 | /** Fills an empty routing table: first by pinging bootstrap nodes, then by looking up our own id. */
18 | object RoutingTableBootstrap {
19 |
20 |   // Pause before retrying when no bootstrap node answered in a round. Without it, an offline host
21 |   // re-resolves and re-pings the bootstrap list in a tight loop, flooding the log.
22 |   private val BootstrapRetryDelay: FiniteDuration = 10.seconds
23 |
24 |   // Pause before repeating self discovery while the table is still under-populated; avoids a busy
25 |   // loop when a lookup completes immediately (e.g. in a network with very few nodes).
26 |   private val SelfDiscoveryRetryDelay: FiniteDuration = 1.second
27 |
28 |   // Self discovery is repeated until the routing table holds at least this many nodes.
29 |   private val MinNodeCount = 20
30 |
31 |   /** Bootstraps `table`.
32 |     *
33 |     * Pings every address in `bootstrapNodeAddress`, repeating the whole round after a pause until
34 |     * at least one node answers, then runs self discovery until the table is sufficiently filled.
35 |     *
36 |     * @param table routing table to fill
37 |     * @param client client used to ping the bootstrap nodes
38 |     * @param discovery used to look up nodes close to `client.id`
39 |     * @param bootstrapNodeAddress entry points into the DHT; defaults to `PublicBootstrapNodes`
40 |     */
41 |   def apply(
42 |     table: RoutingTable[IO],
43 |     client: Client,
44 |     discovery: PeerDiscovery,
45 |     bootstrapNodeAddress: List[SocketAddress[Host]] = PublicBootstrapNodes
46 |   )(using
47 |     dns: Dns[IO],
48 |     logger: Logger[IO]
49 |   ): IO[Unit] =
50 |     for
51 |       _ <- logger.info("Bootstrapping")
52 |       count <- reachBootstrapNodes(client, bootstrapNodeAddress)
53 |       _ <- logger.info(s"Communicated with $count bootstrap nodes")
54 |       _ <- selfDiscovery(table, client, discovery)
55 |       nodeCount <- table.allNodes.map(_.size)
56 |       _ <- logger.info(s"Bootstrapping finished with $nodeCount nodes")
57 |     yield {}
58 |
59 |   // Runs bootstrap rounds until at least one bootstrap node answers; returns how many answered.
60 |   private def reachBootstrapNodes(
61 |     client: Client,
62 |     bootstrapNodeAddress: List[SocketAddress[Host]]
63 |   )(using
64 |     dns: Dns[IO],
65 |     logger: Logger[IO]
66 |   ): IO[Long] =
67 |     resolveNodes(client, bootstrapNodeAddress).compile.count.flatMap { count =>
68 |       if count > 0 then IO.pure(count)
69 |       else
70 |         logger.info(s"No bootstrap node answered, retrying in $BootstrapRetryDelay") >>
71 |           IO.sleep(BootstrapRetryDelay) >>
72 |           reachBootstrapNodes(client, bootstrapNodeAddress)
73 |     }
74 |
75 |   // Resolves each bootstrap host and pings every resolved address with a 5 second timeout.
76 |   // Emits the nodes that answered; resolution and ping failures are logged and skipped.
77 |   private def resolveNodes(
78 |     client: Client,
79 |     bootstrapNodeAddress: List[SocketAddress[Host]]
80 |   )(using
81 |     dns: Dns[IO],
82 |     logger: Logger[IO]
83 |   ): Stream[IO, NodeInfo] =
84 |     def tryThis(hostname: SocketAddress[Host]): Stream[IO, NodeInfo] =
85 |       Stream.eval(logger.info(s"Trying to reach $hostname")) >>
86 |       Stream
87 |         .evals(
88 |           hostname.host.resolveAll[IO]
89 |             .recoverWith: e =>
90 |               logger.info(s"Failed to resolve $hostname $e").as(List.empty)
91 |         )
92 |         .evalMap: ipAddress =>
93 |           val resolvedAddress = SocketAddress(ipAddress, hostname.port)
94 |           logger.info(s"Resolved to $ipAddress") *>
95 |           client
96 |             .ping(resolvedAddress)
97 |             .timeout(5.seconds)
98 |             .map(pong => NodeInfo(pong.id, resolvedAddress))
99 |             .flatTap: _ =>
100 |               logger.info(s"Reached $resolvedAddress node")
101 |             .map(_.some)
102 |             .recoverWith: e =>
103 |               logger.info(s"Failed to reach $resolvedAddress $e").as(none)
104 |         .collect {
105 |           case Some(node) => node
106 |         }
107 |     Stream
108 |       .emits(bootstrapNodeAddress)
109 |       .covary[IO]
110 |       .flatMap(tryThis)
111 |
112 |   // Looks up our own id to fill the routing table, repeating after a short pause until it holds at
113 |   // least `MinNodeCount` nodes. Each attempt stops after 30 results or 30 seconds, whichever is first.
114 |   private def selfDiscovery(
115 |     table: RoutingTable[IO],
116 |     client: Client,
117 |     discovery: PeerDiscovery
118 |   )(using Logger[IO]): IO[Unit] =
119 |     def attempt(number: Int): IO[Unit] = async[IO]:
120 |       Logger[IO].info(s"Discover self to fill up routing table (attempt $number)").await
121 |       val count = discovery.findNodes(client.id).take(30).interruptAfter(30.seconds).compile.count.await
122 |       Logger[IO].info(s"Communicated with $count nodes during self discovery").await
123 |       val nodeCount = table.allNodes.await.size
124 |       if nodeCount < MinNodeCount then
125 |         IO.sleep(SelfDiscoveryRetryDelay).await
126 |         attempt(number + 1).await
127 |       else ()
128 |     attempt(1)
129 |
130 |   /** Default bootstrap addresses, used when `apply` is called without `bootstrapNodeAddress`. */
131 |   val PublicBootstrapNodes: List[SocketAddress[Host]] = List(
132 |     SocketAddress(host"router.bittorrent.com", port"6881"),
133 |     SocketAddress(host"router.utorrent.com", port"6881"),
134 |     SocketAddress(host"dht.transmissionbt.com", port"6881"),
135 |     SocketAddress(host"router.bitcomet.com", port"6881"),
136 |     SocketAddress(host"dht.aelitis.com", port"6881"),
137 |   )
138 | }
139 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Swarm.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.*
4 | import cats.effect.implicits.*
5 | import cats.effect.std.Queue
6 | import cats.effect.IO
7 | import cats.effect.Outcome
8 | import cats.effect.Resource
9 | import cats.implicits.*
10 | import com.github.torrentdam.bittorrent.PeerInfo
11 | import fs2.concurrent.Signal
12 | import fs2.concurrent.SignallingRef
13 | import fs2.concurrent.Topic
14 | import fs2.Stream
15 | import org.legogroup.woof.*
16 | import org.legogroup.woof.given
17 | import org.legogroup.woof.Logger.withLogContext
18 | import scala.concurrent.duration.*
19 |
20 | trait Swarm {
21 |   def connect: Resource[IO, Connection]
22 |   def connected: Connected
23 | }
24 |
25 | trait Connected {
26 |   def count: Signal[IO, Int]
27 |   def list: IO[List[Connection]]
28 | }
29 | 30 | object Swarm { 31 | 32 | def apply( 33 | peers: Stream[IO, PeerInfo], 34 | connect: PeerInfo => Resource[IO, Connection] 35 | )(using 36 | logger: Logger[IO] 37 | ): Resource[IO, Swarm] = 38 | for 39 | _ <- Resource.make(logger.info("Starting swarm"))(_ => logger.info("Swarm closed")) 40 | stateRef <- Resource.eval(SignallingRef[IO].of(Map.empty[PeerInfo, Connection])) 41 | newConnections <- Resource.eval(Queue.bounded[IO, Resource[IO, Connection]](10)) 42 | reconnects <- Resource.eval(Queue.unbounded[IO, Resource[IO, Connection]]) 43 | scheduleReconnect = (delay: FiniteDuration) => 44 | (reconnect: Resource[IO, Connection]) => (IO.sleep(delay) >> reconnects.offer(reconnect)).start.void 45 | allSeen <- Resource.eval(IO.ref(Set.empty[PeerInfo])) 46 | _ <- peers 47 | .evalMap(peerInfo => 48 | allSeen.getAndUpdate(_ + peerInfo).flatMap { seen => 49 | if seen(peerInfo) 50 | then IO.pure(None) 51 | else IO.pure(Some(peerInfo)) 52 | } 53 | ) 54 | .collect { case Some(peerInfo) => peerInfo } 55 | .map(peerInfo => newConnection(connect(peerInfo), scheduleReconnect)) 56 | .evalTap(newConnections.offer) 57 | .compile 58 | .drain 59 | .background 60 | connectOrReconnect = 61 | for 62 | resource <- Resource.eval(newConnections.take race reconnects.take) 63 | connection <- resource.merge 64 | yield connection 65 | yield new Impl(stateRef, connectOrReconnect) 66 | end for 67 | 68 | private class Impl( 69 | stateRef: SignallingRef[IO, Map[PeerInfo, Connection]], 70 | connectOrReconnect: Resource[IO, Connection] 71 | ) extends Swarm { 72 | val connect: Resource[IO, Connection] = 73 | connectOrReconnect.flatTap(connection => 74 | Resource.make { 75 | stateRef.update(_ + (connection.info -> connection)) 76 | } { _ => 77 | stateRef.update(_ - connection.info) 78 | } 79 | ) 80 | val connected: Connected = new { 81 | val count: Signal[IO, Int] = stateRef.map(_.size) 82 | val list: IO[List[Connection]] = stateRef.get.map(_.values.toList) 83 | } 84 | } 85 | 86 | private def 
newConnection( 87 | connect: Resource[IO, Connection], 88 | schedule: FiniteDuration => Resource[IO, Connection] => IO[Unit] 89 | ): Resource[IO, Connection] = { 90 | val maxAttempts = 24 91 | def connectWithRetry(attempt: Int): Resource[IO, Connection] = 92 | connect 93 | .onFinalizeCase { 94 | case Resource.ExitCase.Succeeded => 95 | schedule(10.second)(connectWithRetry(1)) 96 | case Resource.ExitCase.Errored(_) => 97 | if attempt == maxAttempts 98 | then IO.unit 99 | else 100 | val duration = (10 * attempt).seconds 101 | schedule(duration)(connectWithRetry(attempt + 1)) 102 | case _ => 103 | IO.unit 104 | } 105 | connectWithRetry(1) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/TorrentMetadataSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent 2 | 3 | import com.github.torrentdam.bencode.* 4 | import com.github.torrentdam.bencode.format.BencodeFormat 5 | import com.github.torrentdam.bittorrent.CrossPlatform 6 | import scodec.bits.Bases 7 | import scodec.bits.BitVector 8 | import scodec.bits.ByteVector 9 | 10 | class TorrentMetadataSpec extends munit.FunSuite { 11 | 12 | test("encode file class") { 13 | val result = summon[BencodeFormat[TorrentMetadata.File]].write(TorrentMetadata.File(77, "abc" :: Nil)) 14 | val expectation = Right( 15 | Bencode.BDictionary( 16 | "length" -> Bencode.BInteger(77), 17 | "path" -> Bencode.BList(Bencode.BString("abc") :: Nil) 18 | ) 19 | ) 20 | assert(result == expectation) 21 | } 22 | 23 | test("calculate info_hash") { 24 | val source = getClass.getClassLoader 25 | .getResourceAsStream("bencode/ubuntu-18.10-live-server-amd64.iso.torrent") 26 | .readAllBytes() 27 | val Right(bc) = decode(BitVector(source)): @unchecked 28 | val decodedResult = summon[BencodeFormat[TorrentFile]].read(bc) 29 | val result = decodedResult 30 | 
.map(_.info.raw) 31 | .map(encode(_).bytes) 32 | .map(CrossPlatform.sha1) 33 | .map(_.toHex(Bases.Alphabets.HexUppercase)) 34 | val expectation = Right("8C4ADBF9EBE66F1D804FB6A4FB9B74966C3AB609") 35 | assert(result == expectation) 36 | } 37 | 38 | test("decode either a or b") { 39 | val input = Bencode.BDictionary( 40 | "name" -> Bencode.BString("file_name"), 41 | "piece length" -> Bencode.BInteger(10), 42 | "pieces" -> Bencode.BString.Empty, 43 | "length" -> Bencode.BInteger(10) 44 | ) 45 | 46 | assert( 47 | summon[BencodeFormat[TorrentMetadata]].read(input) == Right( 48 | TorrentMetadata("file_name", 10, ByteVector.empty, List(TorrentMetadata.File(10, List("file_name")))) 49 | ) 50 | ) 51 | 52 | val input1 = Bencode.BDictionary( 53 | "name" -> Bencode.BString("test"), 54 | "piece length" -> Bencode.BInteger(10), 55 | "pieces" -> Bencode.BString.Empty, 56 | "files" -> Bencode.BList( 57 | Bencode.BDictionary( 58 | "length" -> Bencode.BInteger(10), 59 | "path" -> Bencode.BList(Bencode.BString("/root") :: Nil) 60 | ) :: Nil 61 | ) 62 | ) 63 | 64 | assert( 65 | summon[BencodeFormat[TorrentMetadata]].read(input1) == Right( 66 | TorrentMetadata("test", 10, ByteVector.empty, TorrentMetadata.File(10, "/root" :: Nil) :: Nil) 67 | ) 68 | ) 69 | } 70 | 71 | test("decode dictionary") { 72 | val input = Bencode.BDictionary( 73 | "name" -> Bencode.BString("file_name"), 74 | "piece length" -> Bencode.BInteger(10), 75 | "pieces" -> Bencode.BString(ByteVector(10)), 76 | "length" -> Bencode.BInteger(10) 77 | ) 78 | 79 | assert( 80 | summon[BencodeFormat[TorrentMetadata]].read(input) == Right( 81 | TorrentMetadata("file_name", 10, ByteVector(10), List(TorrentMetadata.File(10, List("file_name")))) 82 | ) 83 | ) 84 | } 85 | 86 | test("decode ubuntu torrent") { 87 | assert(decode(BitVector.encodeAscii("i56e").toOption.get) == Right(Bencode.BInteger(56L))) 88 | assert(decode(BitVector.encodeAscii("2:aa").toOption.get) == Right(Bencode.BString("aa"))) 89 | assert( 90 | 
decode(BitVector.encodeAscii("l1:a2:bbe").toOption.get) == Right( 91 | Bencode.BList(Bencode.BString("a") :: Bencode.BString("bb") :: Nil) 92 | ) 93 | ) 94 | assert( 95 | decode(BitVector.encodeAscii("d1:ai6ee").toOption.get) == Right( 96 | Bencode.BDictionary("a" -> Bencode.BInteger(6)) 97 | ) 98 | ) 99 | val source = getClass.getClassLoader 100 | .getResourceAsStream("bencode/ubuntu-18.10-live-server-amd64.iso.torrent") 101 | .readAllBytes() 102 | val Right(result) = decode(BitVector(source)): @unchecked 103 | val decodeResult = summon[BencodeFormat[TorrentFile]].read(result) 104 | assert(decodeResult.isRight) 105 | } 106 | 107 | } 108 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RequestResponse.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.* 4 | import cats.effect.kernel.Deferred 5 | import cats.effect.kernel.Ref 6 | import cats.effect.kernel.Temporal 7 | import cats.effect.syntax.all.* 8 | import cats.effect.Concurrent 9 | import cats.effect.Resource 10 | import cats.syntax.all.* 11 | import com.comcast.ip4s.* 12 | import com.github.torrentdam.bittorrent.dht.RequestResponse.Timeout 13 | import com.github.torrentdam.bencode.Bencode 14 | import scala.concurrent.duration.* 15 | import scodec.bits.ByteVector 16 | 17 | trait RequestResponse[F[_]] { 18 | def sendQuery(address: SocketAddress[IpAddress], query: Query): F[Response] 19 | } 20 | 21 | object RequestResponse { 22 | 23 | def make[F[_]]( 24 | generateTransactionId: F[ByteVector], 25 | sendQuery: (SocketAddress[IpAddress], Message.QueryMessage) => F[Unit], 26 | receiveMessage: F[ 27 | (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage) 28 | ] 29 | )(using 30 | F: Temporal[F] 31 | ): Resource[F, RequestResponse[F]] = 32 | Resource { 33 | for { 34 | callbackRegistry <- 
CallbackRegistry.make[F] 35 | fiber <- receiveLoop(receiveMessage, callbackRegistry.complete).start 36 | } yield { 37 | new Impl(generateTransactionId, sendQuery, callbackRegistry.add) -> fiber.cancel 38 | } 39 | } 40 | 41 | private class Impl[F[_]]( 42 | generateTransactionId: F[ByteVector], 43 | sendQueryMessage: (SocketAddress[IpAddress], Message.QueryMessage) => F[Unit], 44 | receive: ByteVector => F[Either[Throwable, Response]] 45 | )(using F: MonadError[F, Throwable]) 46 | extends RequestResponse[F] { 47 | def sendQuery(address: SocketAddress[IpAddress], query: Query): F[Response] = { 48 | generateTransactionId.flatMap { transactionId => 49 | val send = sendQueryMessage( 50 | address, 51 | Message.QueryMessage(transactionId, query) 52 | ) 53 | send >> receive(transactionId).flatMap(F.fromEither) 54 | } 55 | } 56 | } 57 | 58 | private def receiveLoop[F[_]]( 59 | receive: F[ 60 | (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage) 61 | ], 62 | continue: (ByteVector, Either[Throwable, Response]) => F[Boolean] 63 | )(using 64 | F: Monad[F] 65 | ): F[Unit] = { 66 | val step = receive.map(_._2).flatMap { 67 | case Message.ResponseMessage(transactionId, response) => 68 | continue(transactionId, response.asRight) 69 | case Message.ErrorMessage(transactionId, details) => 70 | continue(transactionId, ErrorResponse(details).asLeft) 71 | } 72 | step.foreverM[Unit] 73 | } 74 | 75 | case class ErrorResponse(details: Bencode) extends Throwable 76 | case class InvalidResponse() extends Throwable 77 | case class Timeout() extends Throwable 78 | } 79 | 80 | trait CallbackRegistry[F[_]] { 81 | def add(transactionId: ByteVector): F[Either[Throwable, Response]] 82 | 83 | def complete(transactionId: ByteVector, result: Either[Throwable, Response]): F[Boolean] 84 | } 85 | 86 | object CallbackRegistry { 87 | def make[F[_]: Temporal]: F[CallbackRegistry[F]] = { 88 | for { 89 | ref <- 90 | Ref 91 | .of[F, Map[ByteVector, Either[Throwable, Response] => 
F[Boolean]]]( 92 | Map.empty 93 | ) 94 | } yield { 95 | new Impl(ref) 96 | } 97 | } 98 | 99 | private class Impl[F[_]]( 100 | ref: Ref[F, Map[ByteVector, Either[Throwable, Response] => F[Boolean]]] 101 | )(using F: Temporal[F]) 102 | extends CallbackRegistry[F] { 103 | def add(transactionId: ByteVector): F[Either[Throwable, Response]] = { 104 | F.deferred[Either[Throwable, Response]].flatMap { deferred => 105 | val update = 106 | ref.update { map => 107 | map.updated(transactionId, deferred.complete) 108 | } 109 | val delete = 110 | ref.update { map => 111 | map - transactionId 112 | } 113 | (update *> deferred.get).guarantee(delete) 114 | } 115 | } 116 | 117 | def complete(transactionId: ByteVector, result: Either[Throwable, Response]): F[Boolean] = 118 | ref.get.flatMap { map => 119 | map.get(transactionId) match { 120 | case Some(callback) => callback(result) 121 | case None => false.pure[F] 122 | } 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Client.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.kernel.Temporal 4 | import cats.effect.std.{Queue, Random} 5 | import cats.effect.{Concurrent, IO, Resource, Sync} 6 | import cats.syntax.all.* 7 | import com.comcast.ip4s.* 8 | import com.github.torrentdam.bittorrent.InfoHash 9 | 10 | import java.net.InetSocketAddress 11 | import org.legogroup.woof.given 12 | import org.legogroup.woof.Logger 13 | import scodec.bits.ByteVector 14 | 15 | trait Client { 16 | 17 | def id: NodeId 18 | 19 | def getPeers(address: SocketAddress[IpAddress], infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] 20 | 21 | def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes] 22 | 23 | def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] 24 | 25 | def 
sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] 26 | } 27 | 28 | object Client { 29 | 30 | def generateTransactionId(using random: Random[IO]): IO[ByteVector] = 31 | val nextChar = random.nextAlphaNumeric 32 | (nextChar, nextChar).mapN((a, b) => ByteVector.encodeAscii(List(a, b).mkString).toOption.get) 33 | 34 | def apply( 35 | selfId: NodeId, 36 | messageSocket: MessageSocket, 37 | queryHandler: QueryHandler[IO] 38 | )(using Logger[IO], Random[IO]): Resource[IO, Client] = { 39 | for 40 | responses <- Resource.eval { 41 | Queue.unbounded[IO, (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage)] 42 | } 43 | requestResponse <- RequestResponse.make( 44 | generateTransactionId, 45 | messageSocket.writeMessage, 46 | responses.take 47 | ) 48 | _ <- 49 | messageSocket.readMessage 50 | .flatMap { 51 | case (a, m: Message.QueryMessage) => 52 | Logger[IO].debug(s"Received $m") >> 53 | queryHandler(a, m.query).flatMap { 54 | case Some(response) => 55 | val responseMessage = Message.ResponseMessage(m.transactionId, response) 56 | Logger[IO].debug(s"Responding with $responseMessage") >> 57 | messageSocket.writeMessage(a, responseMessage) 58 | case None => 59 | Logger[IO].debug(s"No response for $m") 60 | } 61 | case (a, m: Message.ResponseMessage) => responses.offer((a, m)) 62 | case (a, m: Message.ErrorMessage) => responses.offer((a, m)) 63 | } 64 | .recoverWith { case e: Throwable => 65 | Logger[IO].debug(s"Failed to read message: $e") 66 | } 67 | .foreverM 68 | .background 69 | yield new Client { 70 | 71 | def id: NodeId = selfId 72 | 73 | def getPeers( 74 | address: SocketAddress[IpAddress], 75 | infoHash: InfoHash 76 | ): IO[Either[Response.Nodes, Response.Peers]] = 77 | requestResponse.sendQuery(address, Query.GetPeers(selfId, infoHash)).flatMap { 78 | case nodes: Response.Nodes => nodes.asLeft.pure 79 | case peers: Response.Peers => peers.asRight.pure 80 | case _ => 
IO.raiseError(InvalidResponse()) 81 | } 82 | 83 | def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes] = 84 | requestResponse.sendQuery(address, Query.FindNode(selfId, target)).flatMap { 85 | case nodes: Response.Nodes => nodes.pure 86 | case _ => IO.raiseError(InvalidResponse()) 87 | } 88 | 89 | def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] = 90 | requestResponse.sendQuery(address, Query.Ping(selfId)).flatMap { 91 | case ping: Response.Ping => ping.pure 92 | case _ => IO.raiseError(InvalidResponse()) 93 | } 94 | def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] = 95 | requestResponse.sendQuery(address, Query.SampleInfoHashes(selfId, target)).flatMap { 96 | case response: Response.SampleInfoHashes => response.asRight[Response.Nodes].pure 97 | case response: Response.Nodes => response.asLeft[Response.SampleInfoHashes].pure 98 | case _ => IO.raiseError(InvalidResponse()) 99 | } 100 | } 101 | } 102 | 103 | case class InvalidResponse() extends Throwable 104 | } 105 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/MessageSocket.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.effect.std.Semaphore 4 | import cats.effect.syntax.all.* 5 | import cats.effect.Async 6 | import cats.effect.Resource 7 | import cats.effect.Temporal 8 | import cats.syntax.all.* 9 | import com.github.torrentdam.bittorrent.* 10 | import com.github.torrentdam.bittorrent.protocol.message.{Handshake, Message} 11 | import fs2.io.net.Network 12 | import fs2.io.net.Socket 13 | import fs2.io.net.SocketGroup 14 | import fs2.Chunk 15 | import org.legogroup.woof.given 16 | import org.legogroup.woof.Logger 17 | 18 | import scala.concurrent.duration.* 19 | import 
scodec.bits.ByteVector 20 | 21 | class MessageSocket[F[_]]( 22 | val handshake: Handshake, 23 | val peerInfo: PeerInfo, 24 | socket: Socket[F], 25 | logger: Logger[F] 26 | )(using F: Temporal[F]) { 27 | 28 | import MessageSocket.readTimeout 29 | import MessageSocket.writeTimeout 30 | import MessageSocket.MaxMessageSize 31 | import MessageSocket.OversizedMessage 32 | 33 | def send(message: Message): F[Unit] = 34 | val bytes = Chunk.byteVector(Message.MessageCodec.encode(message).require.toByteVector) 35 | for 36 | _ <- socket.write(bytes) 37 | _ <- logger.trace(s">>> ${peerInfo.address} $message") 38 | yield () 39 | 40 | def receive: F[Message] = 41 | for 42 | bytes <- readExactlyN(4) 43 | size <- 44 | Message.MessageSizeCodec 45 | .decodeValue(bytes.toBitVector) 46 | .toTry 47 | .liftTo[F] 48 | _ <- F.whenA(size > MaxMessageSize)( 49 | logger.error(s"Oversized payload $size $MaxMessageSize") >> 50 | OversizedMessage(size, MaxMessageSize).raiseError 51 | ) 52 | message <- 53 | if (size == 0) F.pure(Message.KeepAlive) 54 | else 55 | readExactlyN(size.toInt).flatMap(bytes => 56 | F.fromTry( 57 | Message.MessageBodyCodec 58 | .decodeValue(bytes.toBitVector) 59 | .toTry 60 | ) 61 | ) 62 | _ <- logger.trace(s"<<< ${peerInfo.address} $message") 63 | yield message 64 | 65 | private def readExactlyN(numBytes: Int): F[ByteVector] = 66 | for 67 | chunk <- socket.readN(numBytes) 68 | _ <- if chunk.size == numBytes then F.unit else F.raiseError(new Exception("Connection was interrupted by peer")) 69 | yield chunk.toByteVector 70 | 71 | } 72 | 73 | object MessageSocket { 74 | 75 | val MaxMessageSize: Long = 1024 * 1024 // 1MB 76 | val readTimeout = 1.minute 77 | val writeTimeout = 10.seconds 78 | 79 | def connect[F[_]](selfId: PeerId, peerInfo: PeerInfo, infoHash: InfoHash)(using 80 | F: Async[F], 81 | network: Network[F], 82 | logger: Logger[F] 83 | ): Resource[F, MessageSocket[F]] = { 84 | for 85 | socket <- network.client(to = peerInfo.address).timeout(5.seconds) 86 | _ <- 
Resource.make(F.unit)(_ => logger.trace(s"Closed socket $peerInfo")) 87 | _ <- Resource.eval(logger.trace(s"Opened socket $peerInfo")) 88 | handshakeResponse <- Resource.eval( 89 | logger.trace(s"Initiate handshake with ${peerInfo.address}") *> 90 | handshake(selfId, infoHash, socket) <* 91 | logger.trace(s"Successful handshake with ${peerInfo.address}") 92 | ) 93 | yield new MessageSocket(handshakeResponse, peerInfo, socket, logger) 94 | } 95 | 96 | def handshake[F[_]]( 97 | selfId: PeerId, 98 | infoHash: InfoHash, 99 | socket: Socket[F] 100 | )(using F: Temporal[F]): F[Handshake] = { 101 | val message = Handshake(extensionProtocol = true, infoHash, selfId) 102 | for 103 | _ <- socket 104 | .write( 105 | bytes = Chunk.byteVector( 106 | Handshake.HandshakeCodec.encode(message).require.toByteVector 107 | ) 108 | ) 109 | .timeout(writeTimeout) 110 | handshakeMessageSize = Handshake.HandshakeCodec.sizeBound.exact.get.toInt / 8 111 | bytes <- 112 | socket 113 | .readN(handshakeMessageSize) 114 | .timeout(readTimeout) 115 | .adaptError(e => 116 | Error("Unsuccessful handshake", e) 117 | ) 118 | _ <- 119 | if bytes.size == handshakeMessageSize 120 | then F.unit 121 | else F.raiseError(Error("Unsuccessful handshake: connection prematurely closed")) 122 | response <- F.fromEither( 123 | Handshake.HandshakeCodec 124 | .decodeValue(bytes.toBitVector) 125 | .toEither 126 | .leftMap { e => 127 | Error(s"Unable to decode handhshake reponse: ${e.message}") 128 | } 129 | ) 130 | yield response 131 | } 132 | 133 | case class Error(message: String, cause: Throwable = null) extends Exception(message, cause) 134 | case class OversizedMessage(size: Long, maxSize: Long) extends Throwable(s"Oversized message [$size > $maxSize]") 135 | } 136 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala: -------------------------------------------------------------------------------- 1 | package 
com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.implicits.* 4 | import cats.effect.std.Queue 5 | import cats.effect.std.Random 6 | import cats.effect.Async 7 | import cats.effect.IO 8 | import cats.effect.Resource 9 | import cats.effect.Sync 10 | import cats.implicits.* 11 | import com.comcast.ip4s.* 12 | import fs2.Stream 13 | import com.github.torrentdam.bittorrent.InfoHash 14 | 15 | import java.net.InetSocketAddress 16 | import org.legogroup.woof.given 17 | import org.legogroup.woof.Logger 18 | import scodec.bits.ByteVector 19 | 20 | import scala.concurrent.duration.DurationInt 21 | 22 | class Node(val id: NodeId, val client: Client, val routingTable: RoutingTable[IO], val discovery: PeerDiscovery) 23 | 24 | object Node { 25 | 26 | def apply( 27 | port: Option[Port] = None, 28 | bootstrapNodeAddress: Option[SocketAddress[Host]] = None 29 | )(using 30 | random: Random[IO], 31 | logger: Logger[IO] 32 | ): Resource[IO, Node] = 33 | for 34 | selfId <- Resource.eval(NodeId.random[IO]) 35 | messageSocket <- MessageSocket(port) 36 | routingTable <- RoutingTable[IO](selfId).toResource 37 | queryingNodes <- Queue.unbounded[IO, NodeInfo].toResource 38 | queryHandler = reportingQueryHandler(queryingNodes, QueryHandler.simple(selfId, routingTable)) 39 | client <- Client(selfId, messageSocket, queryHandler) 40 | insertingClient = new InsertingClient(client, routingTable) 41 | bootstrapNodes = bootstrapNodeAddress.map(List(_)).getOrElse(RoutingTableBootstrap.PublicBootstrapNodes) 42 | discovery = PeerDiscovery(routingTable, insertingClient) 43 | _ <- RoutingTableBootstrap(routingTable, insertingClient, discovery, bootstrapNodes).toResource 44 | _ <- RoutingTableRefresh(routingTable, client, discovery).runEvery(15.minutes).background 45 | _ <- pingCandidates(queryingNodes, client, routingTable).background 46 | yield new Node(selfId, insertingClient, routingTable, discovery) 47 | 48 | private class InsertingClient(client: Client, routingTable: RoutingTable[IO]) 
extends Client { 49 | 50 | def id: NodeId = client.id 51 | 52 | def getPeers(address: SocketAddress[IpAddress], infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] = 53 | client.getPeers(address, infoHash).flatTap { response => 54 | routingTable.insert( 55 | NodeInfo( 56 | response match 57 | case Left(response) => response.id 58 | case Right(response) => response.id, 59 | address 60 | ) 61 | ) 62 | } 63 | 64 | def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes] = 65 | client.findNodes(address, target).flatTap { response => 66 | routingTable.insert(NodeInfo(response.id, address)) 67 | } 68 | 69 | def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] = 70 | client.ping(address).flatTap { response => 71 | routingTable.insert(NodeInfo(response.id, address)) 72 | } 73 | 74 | def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] = 75 | client.sampleInfoHashes(address, target).flatTap { response => 76 | routingTable.insert( 77 | response match 78 | case Left(response) => NodeInfo(response.id, address) 79 | case Right(response) => NodeInfo(response.id, address) 80 | ) 81 | } 82 | 83 | override def toString: String = s"InsertingClient($client)" 84 | } 85 | 86 | private def pingCandidate(node: NodeInfo, client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) = 87 | routingTable.lookup(node.id).flatMap { 88 | case Some(_) => IO.unit 89 | case None => 90 | Logger[IO].info(s"Pinging $node") *> 91 | client.ping(node.address).timeout(5.seconds).attempt.flatMap { 92 | case Right(_) => 93 | Logger[IO].info(s"Got pong from $node -- insert as good") *> 94 | routingTable.insert(node) 95 | case Left(_) => IO.unit 96 | } 97 | } 98 | 99 | private def pingCandidates(nodes: Queue[IO, NodeInfo], client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) = 100 | nodes 101 | .tryTakeN(none) 102 | .flatMap(candidates => 103 | candidates 104 | 
.distinct 105 | .traverse_(pingCandidate(_, client, routingTable).attempt.void) 106 | ) 107 | .productR(IO.sleep(1.minute)) 108 | .foreverM 109 | 110 | 111 | private def reportingQueryHandler(queue: Queue[IO, NodeInfo], next: QueryHandler[IO]): QueryHandler[IO] = (address, query) => 112 | val nodeInfo = query match 113 | case Query.Ping(id) => NodeInfo(id, address) 114 | case Query.FindNode(id, _) => NodeInfo(id, address) 115 | case Query.GetPeers(id, _) => NodeInfo(id, address) 116 | case Query.AnnouncePeer(id, _, _) => NodeInfo(id, address) 117 | case Query.SampleInfoHashes(id, _) => NodeInfo(id, address) 118 | queue.offer(nodeInfo) *> next(address, query) 119 | } 120 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/ExtensionHandler.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.effect.kernel.Deferred 4 | import cats.effect.kernel.Ref 5 | import cats.effect.std.Queue 6 | import cats.effect.Async 7 | import cats.effect.Concurrent 8 | import cats.effect.Sync 9 | import cats.implicits.* 10 | import cats.Applicative 11 | import cats.Monad 12 | import cats.MonadError 13 | import com.github.torrentdam.bittorrent.protocol.extensions.metadata.UtMessage 14 | import com.github.torrentdam.bittorrent.protocol.extensions.ExtensionHandshake 15 | import com.github.torrentdam.bittorrent.protocol.extensions.Extensions 16 | import com.github.torrentdam.bittorrent.protocol.extensions.Extensions.MessageId 17 | import com.github.torrentdam.bittorrent.InfoHash 18 | import com.github.torrentdam.bittorrent.TorrentMetadata.Lossless 19 | import com.github.torrentdam.bittorrent.protocol.message.Message 20 | import fs2.Stream 21 | import scodec.bits.ByteVector 22 | import com.github.torrentdam.bittorrent.CrossPlatform 23 | 24 | trait ExtensionHandler[F[_]] { 25 | 26 | def 
apply(message: Message.Extended): F[Unit] 27 | } 28 | 29 | object ExtensionHandler { 30 | 31 | def noop[F[_]](using F: Applicative[F]): ExtensionHandler[F] = _ => F.unit 32 | 33 | def dynamic[F[_]: Monad](get: F[ExtensionHandler[F]]): ExtensionHandler[F] = message => get.flatMap(_(message)) 34 | 35 | type Send[F[_]] = Message.Extended => F[Unit] 36 | 37 | trait InitExtension[F[_]] { 38 | 39 | def init: F[ExtensionApi[F]] 40 | } 41 | 42 | object InitExtension { 43 | 44 | def apply[F[_]]( 45 | infoHash: InfoHash, 46 | send: Send[F], 47 | utMetadata: UtMetadata.Create[F] 48 | )(using F: Concurrent[F]): F[(ExtensionHandler[F], InitExtension[F])] = 49 | for 50 | apiDeferred <- F.deferred[ExtensionApi[F]] 51 | handlerRef <- F.ref[ExtensionHandler[F]](ExtensionHandler.noop) 52 | _ <- handlerRef.set( 53 | { 54 | case Message.Extended(MessageId.Handshake, payload) => 55 | for 56 | handshake <- F.fromEither(ExtensionHandshake.decode(payload)) 57 | (handler, extensionApi) <- ExtensionApi[F](infoHash, send, utMetadata, handshake) 58 | _ <- handlerRef.set(handler) 59 | _ <- apiDeferred.complete(extensionApi) 60 | yield () 61 | case message => 62 | F.raiseError(InvalidMessage(s"Expected Handshake but received ${message.getClass.getSimpleName}")) 63 | } 64 | ) 65 | yield 66 | 67 | val handler = dynamic(handlerRef.get) 68 | 69 | val api = new InitExtension[F] { 70 | 71 | def init: F[ExtensionApi[F]] = { 72 | val message: Message.Extended = 73 | Message.Extended( 74 | MessageId.Handshake, 75 | ExtensionHandshake.encode(Extensions.handshake) 76 | ) 77 | send(message) >> apiDeferred.get 78 | } 79 | } 80 | 81 | (handler, api) 82 | end for 83 | } 84 | 85 | trait ExtensionApi[F[_]] { 86 | 87 | def utMetadata: Option[UtMetadata[F]] 88 | } 89 | 90 | object ExtensionApi { 91 | 92 | def apply[F[_]]( 93 | infoHash: InfoHash, 94 | send: Send[F], 95 | utMetadata: UtMetadata.Create[F], 96 | handshake: ExtensionHandshake 97 | )(using F: MonadError[F, Throwable]): F[(ExtensionHandler[F], 
ExtensionApi[F])] = { 98 | for (utHandler, utMetadata0) <- utMetadata(infoHash, handshake, send) 99 | yield 100 | 101 | val handler: ExtensionHandler[F] = { 102 | case Message.Extended(Extensions.MessageId.Metadata, messageBytes) => 103 | F.fromEither(UtMessage.decode(messageBytes)) >>= utHandler.apply 104 | case Message.Extended(id, _) => 105 | F.raiseError(InvalidMessage(s"Unsupported message id=$id")) 106 | } 107 | 108 | val api: ExtensionApi[F] = new ExtensionApi[F] { 109 | def utMetadata: Option[UtMetadata[F]] = utMetadata0 110 | } 111 | 112 | (handler, api) 113 | end for 114 | } 115 | 116 | } 117 | 118 | trait UtMetadata[F[_]] { 119 | 120 | def fetch: F[Lossless] 121 | } 122 | 123 | object UtMetadata { 124 | 125 | trait Handler[F[_]] { 126 | 127 | def apply(message: UtMessage): F[Unit] 128 | } 129 | 130 | object Handler { 131 | 132 | def unit[F[_]](using F: Applicative[F]): Handler[F] = _ => F.unit 133 | } 134 | 135 | class Create[F[_]](using F: Async[F]) { 136 | 137 | def apply( 138 | infoHash: InfoHash, 139 | handshake: ExtensionHandshake, 140 | send: Message.Extended => F[Unit] 141 | ): F[(Handler[F], Option[UtMetadata[F]])] = { 142 | 143 | (handshake.extensions.get("ut_metadata"), handshake.metadataSize).tupled match { 144 | case Some((messageId, size)) => 145 | for receiveQueue <- Queue.bounded[F, UtMessage](1) 146 | yield 147 | def sendUtMessage(utMessage: UtMessage) = { 148 | val message: Message.Extended = Message.Extended(messageId, UtMessage.encode(utMessage)) 149 | send(message) 150 | } 151 | 152 | def receiveUtMessage: F[UtMessage] = receiveQueue.take 153 | 154 | (receiveQueue.offer, (new Impl(sendUtMessage, receiveUtMessage, size, infoHash)).some) 155 | end for 156 | 157 | case None => 158 | (Handler.unit[F], Option.empty[UtMetadata[F]]).pure[F] 159 | 160 | } 161 | 162 | } 163 | } 164 | 165 | private class Impl[F[_]]( 166 | send: UtMessage => F[Unit], 167 | receive: F[UtMessage], 168 | size: Long, 169 | infoHash: InfoHash 170 | )(using F: 
Sync[F]) 171 | extends UtMetadata[F] { 172 | 173 | def fetch: F[Lossless] = 174 | Stream 175 | .range(0, 100) 176 | .evalMap { index => 177 | send(UtMessage.Request(index)) *> receive.flatMap { 178 | case UtMessage.Data(`index`, bytes) => bytes.pure[F] 179 | case m => 180 | F.raiseError[ByteVector]( 181 | InvalidMessage(s"Data message expected but received ${m.getClass.getSimpleName}") 182 | ) 183 | } 184 | } 185 | .scan(ByteVector.empty)(_ ++ _) 186 | .find(_.size >= size) 187 | .compile 188 | .lastOrError 189 | .ensure(InvalidMetadata()) { metadata => 190 | CrossPlatform.sha1(metadata) == infoHash.bytes 191 | } 192 | .flatMap { bytes => 193 | Lossless.fromBytes(bytes).liftTo[F] 194 | } 195 | } 196 | } 197 | 198 | case class InvalidMessage(message: String) extends Throwable(message) 199 | case class InvalidMetadata() extends Throwable 200 | } 201 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTable.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.kernel.Concurrent 4 | import cats.effect.kernel.Ref 5 | import cats.effect.Sync 6 | import cats.implicits.* 7 | import com.comcast.ip4s.* 8 | import com.github.torrentdam.bittorrent.InfoHash 9 | import com.github.torrentdam.bittorrent.PeerInfo 10 | 11 | import scala.collection.immutable.ListMap 12 | import scodec.bits.ByteVector 13 | 14 | import scala.annotation.tailrec 15 | 16 | trait RoutingTable[F[_]] { 17 | 18 | def insert(node: NodeInfo): F[Unit] 19 | 20 | def remove(nodeId: NodeId): F[Unit] 21 | 22 | def goodNodes(nodeId: NodeId): F[Iterable[NodeInfo]] 23 | 24 | def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit] 25 | 26 | def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]] 27 | 28 | def allNodes: F[Iterable[RoutingTable.Node]] 29 | 30 | def buckets: 
F[Iterable[RoutingTable.TreeNode.Bucket]] 31 | 32 | def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit] 33 | 34 | def lookup(nodeId: NodeId): F[Option[RoutingTable.Node]] 35 | } 36 | 37 | object RoutingTable { 38 | 39 | enum TreeNode: 40 | case Split(center: BigInt, lower: TreeNode, higher: TreeNode) 41 | case Bucket(from: BigInt, until: BigInt, nodes: Map[NodeId, Node]) 42 | 43 | case class Node(id: NodeId, address: SocketAddress[IpAddress], isGood: Boolean, badCount: Int = 0): 44 | def toNodeInfo: NodeInfo = NodeInfo(id, address) 45 | 46 | object TreeNode { 47 | 48 | def empty: TreeNode = 49 | TreeNode.Bucket( 50 | from = BigInt(0), 51 | until = BigInt(1, ByteVector.fill(20)(-1: Byte).toArray), 52 | Map.empty 53 | ) 54 | } 55 | 56 | val MaxNodes = 8 57 | 58 | import TreeNode.* 59 | 60 | extension (bucket: TreeNode) 61 | def insert(node: NodeInfo, selfId: NodeId): TreeNode = 62 | bucket match 63 | case b @ Split(center, lower, higher) => 64 | if (node.id.int < center) 65 | b.copy(lower = lower.insert(node, selfId)) 66 | else 67 | b.copy(higher = higher.insert(node, selfId)) 68 | case b @ Bucket(from, until, nodes) => 69 | if nodes.size >= MaxNodes && !nodes.contains(selfId) 70 | then 71 | if selfId.int >= from && selfId.int < until 72 | then 73 | // split the bucket because it contains the self node 74 | val center = (from + until) / 2 75 | val splitNode = 76 | Split( 77 | center, 78 | lower = Bucket(from, center, nodes.view.filterKeys(_.int < center).to(ListMap)), 79 | higher = Bucket(center, until, nodes.view.filterKeys(_.int >= center).to(ListMap)) 80 | ) 81 | splitNode.insert(node, selfId) 82 | else 83 | // drop one node from the bucket 84 | val badNode = nodes.values.find(!_.isGood) 85 | badNode match 86 | case Some(badNode) => Bucket(from, until, nodes.removed(badNode.id)).insert(node, selfId) 87 | case None => b 88 | else 89 | Bucket(from, until, nodes.updated(node.id, Node(node.id, node.address, isGood = true))) 90 | 91 | def remove(nodeId: 
NodeId): TreeNode = 92 | bucket match 93 | case b @ Split(center, lower, higher) => 94 | if (nodeId.int < center) 95 | (lower.remove(nodeId), higher) match { 96 | case (Bucket(lowerFrom, _, nodes), finalHigher: Bucket) if nodes.isEmpty => 97 | finalHigher.copy(from = lowerFrom) 98 | case (l, _) => 99 | b.copy(lower = l) 100 | } 101 | else 102 | (higher.remove(nodeId), lower) match { 103 | case (Bucket(_, higherUntil, nodes), finalLower: Bucket) if nodes.isEmpty => 104 | finalLower.copy(until = higherUntil) 105 | case (h, _) => 106 | b.copy(higher = h) 107 | } 108 | case b @ Bucket(_, _, nodes) => 109 | b.copy(nodes = nodes - nodeId) 110 | 111 | @tailrec 112 | def findBucket(nodeId: NodeId): Bucket = 113 | bucket match 114 | case Split(center, lower, higher) => 115 | if (nodeId.int < center) 116 | lower.findBucket(nodeId) 117 | else 118 | higher.findBucket(nodeId) 119 | case b: Bucket => b 120 | 121 | def findNodes(nodeId: NodeId): Iterable[Node] = 122 | bucket match 123 | case Split(center, lower, higher) => 124 | if (nodeId.int < center) 125 | lower.findNodes(nodeId) ++ higher.findNodes(nodeId) 126 | else 127 | higher.findNodes(nodeId) ++ lower.findNodes(nodeId) 128 | case b: Bucket => b.nodes.values.to(LazyList) 129 | 130 | def buckets: Iterable[Bucket] = 131 | bucket match 132 | case b: Bucket => Iterable(b) 133 | case Split(_, lower, higher) => lower.buckets ++ higher.buckets 134 | 135 | def update(fn: Node => Node): TreeNode = 136 | bucket match 137 | case b @ Split(_, lower, higher) => 138 | b.copy(lower = lower.update(fn), higher = higher.update(fn)) 139 | case b @ Bucket(from, until, nodes) => 140 | b.copy(nodes = nodes.view.mapValues(fn).to(ListMap)) 141 | 142 | end extension 143 | 144 | def apply[F[_]: Concurrent](selfId: NodeId): F[RoutingTable[F]] = 145 | for { 146 | treeNodeRef <- Ref.of(TreeNode.empty) 147 | peers <- Ref.of(Map.empty[InfoHash, Set[PeerInfo]]) 148 | } yield new RoutingTable[F] { 149 | 150 | def insert(node: NodeInfo): F[Unit] = 151 | 
treeNodeRef.update(_.insert(node, selfId)) 152 | 153 | def remove(nodeId: NodeId): F[Unit] = 154 | treeNodeRef.update(_.remove(nodeId)) 155 | 156 | def goodNodes(nodeId: NodeId): F[Iterable[NodeInfo]] = 157 | treeNodeRef.get.map(_.findNodes(nodeId).filter(_.isGood).map(_.toNodeInfo)) 158 | 159 | def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit] = 160 | peers.update { map => 161 | map.updatedWith(infoHash) { 162 | case Some(set) => Some(set + peerInfo) 163 | case None => Some(Set(peerInfo)) 164 | } 165 | } 166 | 167 | def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]] = 168 | peers.get.map(_.get(infoHash)) 169 | 170 | def allNodes: F[Iterable[Node]] = 171 | treeNodeRef.get.map(_.findNodes(selfId)) 172 | 173 | def buckets: F[Iterable[TreeNode.Bucket]] = 174 | treeNodeRef.get.map(_.buckets) 175 | 176 | def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit] = 177 | treeNodeRef.update( 178 | _.update(node => 179 | if good.contains(node.id) then node.copy(isGood = true, badCount = 0) 180 | else if bad.contains(node.id) then node.copy(isGood = false, badCount = node.badCount + 1) 181 | else node 182 | ) 183 | ) 184 | 185 | def lookup(nodeId: NodeId): F[Option[Node]] = 186 | treeNodeRef.get.map(_.findBucket(nodeId).nodes.get(nodeId)) 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Download.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.effect.implicits.* 4 | import cats.effect.kernel.Ref 5 | import cats.effect.std.Queue 6 | import cats.effect.std.Semaphore 7 | import cats.effect.std.Supervisor 8 | import cats.effect.Async 9 | import cats.effect.IO 10 | import cats.effect.Resource 11 | import cats.effect.Temporal 12 | import cats.implicits.* 13 | import cats.Show.Shown 14 | import 
com.github.torrentdam.bittorrent.TorrentMetadata 15 | import com.github.torrentdam.bittorrent.protocol.message.Message 16 | import fs2.concurrent.Signal 17 | import fs2.concurrent.SignallingRef 18 | import fs2.concurrent.Topic 19 | import fs2.Chunk 20 | import fs2.Stream 21 | import org.legogroup.woof.given 22 | import org.legogroup.woof.Logger 23 | 24 | import scala.collection.BitSet 25 | import scala.concurrent.duration.* 26 | import scala.util.chaining.* 27 | import scodec.bits.ByteVector 28 | 29 | object Download { 30 | 31 | def apply( 32 | swarm: Swarm, 33 | piecePicker: RequestDispatcher 34 | )(using 35 | logger: Logger[IO] 36 | ): IO[Unit] = 37 | import Logger.withLogContext 38 | Classifier().use(classifier => 39 | swarm.connect 40 | .use(connection => 41 | ( 42 | for 43 | _ <- connection.interested 44 | _ <- classifier.create.use(speedInfo => 45 | whenUnchoked(connection)( 46 | download(connection, piecePicker, speedInfo) 47 | ) 48 | ) 49 | yield () 50 | ) 51 | .race(connection.disconnected) 52 | .withLogContext("address", connection.info.address.toString) 53 | ) 54 | .attempt 55 | .foreverM 56 | .parReplicateA(30) 57 | .void 58 | ) 59 | 60 | private def download( 61 | connection: Connection, 62 | pieces: RequestDispatcher, 63 | speedInfo: SpeedData 64 | )(using logger: Logger[IO]): IO[Unit] = { 65 | 66 | def bounded(min: Int, max: Int)(n: Int): Int = math.min(math.max(n, min), max) 67 | 68 | def computeOutstanding( 69 | downloadedBytes: Topic[IO, Long], 70 | downloadedTotal: Ref[IO, Long], 71 | maxOutstanding: SignallingRef[IO, Int] 72 | ) = 73 | downloadedBytes 74 | .subscribe(10) 75 | .evalTap(size => downloadedTotal.update(_ + size)) 76 | .groupWithin(Int.MaxValue, 10.seconds) 77 | .map(chunks => 78 | bounded(1, 100)( 79 | (chunks.foldLeft(0L)(_ + _) / 10 / RequestDispatcher.ChunkSize).toInt 80 | ) 81 | ) 82 | .evalTap(maxOutstanding.set) 83 | .evalTap(speedInfo.bytes.set) 84 | .compile 85 | .drain 86 | 87 | def updateSemaphore(semaphore: Semaphore[IO], 
maxOutstanding: SignallingRef[IO, Int]) = 88 | maxOutstanding.discrete 89 | .sliding(2) 90 | .evalMap { chunk => 91 | val (prev, next) = (chunk(0), chunk(1)) 92 | if prev < next then semaphore.releaseN(next - prev) 93 | else semaphore.acquireN(prev - next) 94 | } 95 | .compile 96 | .drain 97 | 98 | def sendRequest(request: Message.Request): IO[ByteVector] = 99 | logger.trace(s"Request $request") >> 100 | connection 101 | .request(request) 102 | .timeout(5.seconds) 103 | 104 | def nextRequest(semaphore: Semaphore[IO]) = 105 | semaphore.permit >> pieces.stream(connection.availability.get, speedInfo.cls.get) 106 | 107 | def fireRequests( 108 | semaphore: Semaphore[IO], 109 | failureCounter: Ref[IO, Int], 110 | downloadedBytes: Topic[IO, Long] 111 | ) = 112 | Stream 113 | .resource(nextRequest(semaphore)) 114 | .interruptWhen( 115 | IO.sleep(1.minute).as(Left(Error.TimeoutWaitingForPiece(1.minute))) 116 | ) 117 | .repeat 118 | .map { (request, promise) => 119 | Stream.eval( 120 | sendRequest(request).attempt.flatMap { 121 | case Right(bytes) => 122 | failureCounter.set(0) >> downloadedBytes.publish1(bytes.size) >> promise.complete(bytes) 123 | case Left(_) => 124 | failureCounter 125 | .updateAndGet(_ + 1) 126 | .flatMap(count => 127 | if count >= 10 128 | then IO.raiseError(Error.PeerDoesNotRespond()) 129 | else IO.unit 130 | ) 131 | } 132 | ) 133 | } 134 | .parJoinUnbounded 135 | .compile 136 | .drain 137 | 138 | for 139 | failureCounter <- IO.ref(0) 140 | downloadedBytes <- Topic[IO, Long] 141 | downloadedTotal <- IO.ref(0L) 142 | maxOutstanding <- SignallingRef[IO, Int](5) 143 | semaphore <- Semaphore[IO](5) 144 | _ <- ( 145 | computeOutstanding(downloadedBytes, downloadedTotal, maxOutstanding), 146 | updateSemaphore(semaphore, maxOutstanding), 147 | fireRequests(semaphore, failureCounter, downloadedBytes) 148 | ).parTupled 149 | .handleErrorWith(e => 150 | downloadedTotal.get.flatMap { 151 | case 0 => IO.raiseError(e) 152 | case _ => IO.unit 153 | } 154 | ) 155 | 
yield () 156 | } 157 | 158 | private def whenUnchoked(connection: Connection)(f: IO[Unit])(using 159 | logger: Logger[IO] 160 | ): IO[Unit] = { 161 | def waitChoked = connection.choked.waitUntil(identity) 162 | def waitUnchoked = 163 | connection.choked 164 | .waitUntil(choked => !choked) 165 | .timeoutTo(30.seconds, IO.raiseError(Error.TimeoutWaitingForUnchoke(30.seconds))) 166 | 167 | (waitUnchoked >> (f race waitChoked)).foreverM 168 | } 169 | 170 | private case class SpeedData(bytes: Ref[IO, Int], cls: Ref[IO, SpeedClass]) 171 | 172 | private class Classifier(counter: Ref[IO, Long], state: Ref[IO, Map[Long, SpeedData]]) { 173 | def create: Resource[IO, SpeedData] = Resource( 174 | for 175 | id <- counter.getAndUpdate(_ + 1) 176 | bytes <- IO.ref(0) 177 | cls <- IO.ref(SpeedClass.Slow) 178 | _ <- state.update(_.updated(id, SpeedData(bytes, cls))) 179 | yield (SpeedData(bytes, cls), state.update(_ - id)) 180 | ) 181 | } 182 | private object Classifier { 183 | def apply(): Resource[IO, Classifier] = 184 | for 185 | counter <- Resource.eval(IO.ref(0L)) 186 | state <- Resource.eval(IO.ref(Map.empty[Long, SpeedData])) 187 | _ <- (IO.sleep(10.seconds) >> updateClass(state)).foreverM.background 188 | yield new Classifier(counter, state) 189 | 190 | private def updateClass(state: Ref[IO, Map[Long, SpeedData]]): IO[Unit] = 191 | state.get 192 | .flatMap(_.values.toList.traverse { info => 193 | info.bytes.get.tupleRight(info) 194 | }) 195 | .flatMap(values => 196 | val sorted = values.sortBy(_._1)(using Ordering[Int].reverse).map(_._2) 197 | val fastCount = (values.size.toDouble * 0.7).ceil.toInt 198 | val (fast, slow) = sorted.splitAt(fastCount) 199 | fast.traverse(_.cls.set(SpeedClass.Fast)) >> slow.traverse(_.cls.set(SpeedClass.Slow)) 200 | ) 201 | .void 202 | } 203 | 204 | enum Error(message: String) extends Throwable(message): 205 | case TimeoutWaitingForUnchoke(duration: FiniteDuration) extends Error(s"Unchoke timeout $duration") 206 | case 
TimeoutWaitingForPiece(duration: FiniteDuration) extends Error(s"Block request timeout $duration") 207 | case InvalidChecksum() extends Error("Invalid checksum") 208 | case PeerDoesNotRespond() extends Error("Peer does not respond") 209 | } 210 | 211 | enum SpeedClass { 212 | case Slow, Fast 213 | } 214 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/RequestDispatcher.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.effect.cps.* 4 | import cats.data.Chain 5 | import cats.effect.kernel.Deferred 6 | import cats.effect.std.Dequeue 7 | import cats.effect.std.Semaphore 8 | import cats.effect.IO 9 | import cats.effect.Ref 10 | import cats.effect.Resource 11 | import cats.implicits.* 12 | import com.comcast.ip4s.* 13 | import com.github.torrentdam.bittorrent.protocol.message.Message.Request 14 | import com.github.torrentdam.bittorrent.PeerInfo 15 | import com.github.torrentdam.bittorrent.TorrentMetadata 16 | import com.github.torrentdam.bittorrent.protocol.message.Message 17 | import com.github.torrentdam.bittorrent.CrossPlatform 18 | import fs2.concurrent.Signal 19 | import fs2.concurrent.SignallingRef 20 | import fs2.Stream 21 | import java.util.UUID 22 | import org.legogroup.woof.given 23 | import org.legogroup.woof.Logger 24 | 25 | import scala.collection.immutable.BitSet 26 | import scala.collection.immutable.TreeMap 27 | import scala.concurrent.duration.DurationInt 28 | import scodec.bits.ByteVector 29 | 30 | trait RequestDispatcher { 31 | def downloadPiece(index: Long): IO[ByteVector] 32 | 33 | def stream( 34 | availability: IO[BitSet], 35 | speedClass: IO[SpeedClass] 36 | ): Resource[IO, (Message.Request, Deferred[IO, ByteVector])] 37 | } 38 | 39 | object RequestDispatcher { 40 | val ChunkSize: Int = 16 * 1024 41 | 42 | private type RequestQueue = 
WorkQueue[Message.Request, ByteVector] 43 | 44 | def apply(metadata: TorrentMetadata)(using logger: Logger[IO]): Resource[IO, RequestDispatcher] = 45 | for 46 | queue <- Resource.eval(IO.ref(TreeMap.empty[Long, RequestQueue])) 47 | queueReverse <- Resource.eval(IO.ref(TreeMap.empty[Long, RequestQueue](using Ordering[Long].reverse))) 48 | workGenerator = WorkGenerator(metadata) 49 | yield Impl(workGenerator, queue, queueReverse) 50 | 51 | private class Impl( 52 | workGenerator: WorkGenerator, 53 | queue: Ref[IO, TreeMap[Long, RequestQueue]], 54 | queueReverse: Ref[IO, TreeMap[Long, RequestQueue]] 55 | )(using Logger[IO]) extends RequestDispatcher { 56 | def downloadPiece(index: Long): IO[ByteVector] = 57 | val pieceWork = workGenerator.pieceWork(index) 58 | val attempt = async[IO] { 59 | try 60 | val result = IO.deferred[Map[Request, ByteVector]].await 61 | val requestQueue = WorkQueue(pieceWork.requests.toList, result.complete).await 62 | queue.update(_.updated(index, requestQueue)).await 63 | queueReverse.update(_.updated(index, requestQueue)).await 64 | val completedBlocks = result.get.await 65 | val pieceBytes = completedBlocks.toList.sortBy(_._1.begin).map(_._2).foldLeft(ByteVector.empty)(_ ++ _) 66 | pieceBytes 67 | finally 68 | queue.update(_ - index).await 69 | queueReverse.update(_ - index).await 70 | } 71 | attempt.flatMap(bytes => 72 | if CrossPlatform.sha1(bytes) == pieceWork.checksum then IO.pure(bytes) 73 | else 74 | Logger[IO].warn(s"Piece $index failed checksum") >> attempt 75 | ) 76 | 77 | def stream( 78 | availability: IO[BitSet], 79 | speedClass: IO[SpeedClass] 80 | ): Resource[IO, (Message.Request, Deferred[IO, ByteVector])] = 81 | def pickFrom(trackers: List[RequestQueue]): Resource[IO, (Message.Request, Deferred[IO, ByteVector])] = 82 | trackers match 83 | case Nil => 84 | Resource.eval(IO.raiseError(NoPieceAvailable)) 85 | case tracker :: rest => 86 | tracker.nextRequest.recoverWith(_ => pickFrom(rest)) 87 | 88 | def singlePass = 89 | for 90 | 
speedClass <- Resource.eval(speedClass) 91 | inProgress <- Resource.eval( 92 | speedClass match 93 | case SpeedClass.Fast => queue.get 94 | case SpeedClass.Slow => queueReverse.get // take piece with the highest index first 95 | ) 96 | availability <- Resource.eval(availability) 97 | matched = inProgress.collect { case (index, tracker) if availability(index.toInt) => tracker }.toList 98 | result <- pickFrom(matched) 99 | yield result 100 | 101 | def polling: Resource[IO, (Message.Request, Deferred[IO, ByteVector])] = 102 | singlePass.recoverWith(_ => Resource.eval(IO.sleep(1.seconds)) >> polling) 103 | 104 | polling 105 | } 106 | 107 | class WorkGenerator(pieceLength: Long, totalLength: Long, pieces: ByteVector) { 108 | def this(metadata: TorrentMetadata) = 109 | this( 110 | metadata.pieceLength, 111 | metadata.files.map(_.length).sum, 112 | metadata.pieces 113 | ) 114 | 115 | def pieceWork(index: Long): PieceWork = 116 | val thisPieceLength = math.min(pieceLength, totalLength - index * pieceLength) 117 | PieceWork( 118 | thisPieceLength, 119 | pieces.drop(index * 20).take(20), 120 | genRequests(index, thisPieceLength) 121 | ) 122 | 123 | def genRequests(pieceIndex: Long, pieceLength: Long): Chain[Message.Request] = 124 | var result = Chain.empty[Message.Request] 125 | 126 | def loop(requestIndex: Long): Unit = { 127 | val thisChunkSize = math.min(ChunkSize, pieceLength - requestIndex * ChunkSize) 128 | if thisChunkSize > 0 then 129 | val begin = requestIndex * ChunkSize 130 | result = result.append( 131 | Message.Request( 132 | pieceIndex, 133 | begin, 134 | thisChunkSize 135 | ) 136 | ) 137 | loop(requestIndex + 1) 138 | } 139 | 140 | loop(0) 141 | result 142 | } 143 | 144 | case class PieceWork( 145 | size: Long, 146 | checksum: ByteVector, 147 | requests: Chain[Message.Request] 148 | ) 149 | 150 | case object NoPieceAvailable extends Throwable("No piece available") 151 | 152 | trait WorkQueue[Work, Result] { 153 | def nextRequest: Resource[IO, (Work, 
Deferred[IO, Result])] 154 | } 155 | 156 | object WorkQueue { 157 | 158 | def apply[Request, Response]( 159 | requests: Seq[Request], 160 | onComplete: Map[Request, Response] => IO[Any] 161 | ): IO[WorkQueue[Request, Response]] = 162 | require(requests.nonEmpty) 163 | for 164 | requestQueue <- Dequeue.unbounded[IO, Request] 165 | _ <- requests.traverse(requestQueue.offer) 166 | responses <- IO.ref(Map.empty[Request, Response]) 167 | outstandingCount <- IO.ref(requests.size) 168 | yield new { 169 | override def nextRequest: Resource[IO, (Request, Deferred[IO, Response])] = 170 | Resource( 171 | for 172 | _ <- outstandingCount.get.flatMap(n => IO.raiseWhen(n == 0)(PieceComplete)) 173 | request <- requestQueue.tryTake 174 | request <- request match 175 | case Some(request) => IO.pure(request) 176 | case None => IO.raiseError(EmptyQueue) 177 | promise <- IO.deferred[Response] 178 | yield ( 179 | (request, promise), 180 | for 181 | bytes <- promise.tryGet 182 | _ <- bytes match 183 | case Some(bytes) => 184 | for 185 | _ <- outstandingCount.update(_ - 1) 186 | result <- responses.updateAndGet(_.updated(request, bytes)) 187 | _ <- outstandingCount.get.flatMap(n => IO.whenA(n == 0)(onComplete(result).void)) 188 | yield () 189 | case None => 190 | requestQueue.offerFront(request) 191 | yield () 192 | ) 193 | ) 194 | } 195 | 196 | case object EmptyQueue extends Throwable 197 | case object PieceComplete extends Throwable 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PeerDiscovery.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.effect.kernel.Deferred 4 | import cats.effect.kernel.Ref 5 | import cats.effect.Concurrent 6 | import cats.effect.IO 7 | import cats.effect.Resource 8 | import cats.instances.all.* 9 | import cats.syntax.all.* 10 | import 
cats.effect.cps.{given, *} 11 | import cats.Show.Shown 12 | import com.github.torrentdam.bittorrent.InfoHash 13 | import com.github.torrentdam.bittorrent.PeerInfo 14 | import fs2.Stream 15 | import org.legogroup.woof.given 16 | import org.legogroup.woof.Logger 17 | 18 | import scala.concurrent.duration.DurationInt 19 | import Logger.withLogContext 20 | import com.comcast.ip4s.{IpAddress, SocketAddress} 21 | 22 | trait PeerDiscovery { 23 | 24 | def discover(infoHash: InfoHash): Stream[IO, PeerInfo] 25 | 26 | def findNodes(NodeId: NodeId): Stream[IO, NodeInfo] 27 | } 28 | 29 | object PeerDiscovery { 30 | 31 | def apply( 32 | routingTable: RoutingTable[IO], 33 | dhtClient: Client 34 | )(using 35 | logger: Logger[IO] 36 | ): PeerDiscovery = new { 37 | def discover(infoHash: InfoHash): Stream[IO, PeerInfo] = { 38 | Stream 39 | .eval { 40 | for { 41 | _ <- logger.info("Start discovery") 42 | initialNodes <- routingTable.goodNodes(NodeId(infoHash.bytes)) 43 | initialNodes <- initialNodes.take(16).toList.pure[IO] 44 | _ <- logger.info(s"Received ${initialNodes.size} from own routing table") 45 | state <- DiscoveryState(initialNodes, infoHash) 46 | } yield { 47 | start( 48 | infoHash, 49 | dhtClient.getPeers, 50 | state 51 | ) 52 | } 53 | } 54 | .flatten 55 | .onFinalizeCase { 56 | case Resource.ExitCase.Errored(e) => logger.error(s"Discovery failed with ${e.getMessage}") 57 | case _ => IO.unit 58 | } 59 | } 60 | 61 | def findNodes(nodeId: NodeId): Stream[IO, NodeInfo] = 62 | Stream 63 | .eval( 64 | for 65 | _ <- logger.info(s"Start finding nodes for $nodeId") 66 | initialNodes <- routingTable.goodNodes(nodeId) 67 | initialNodes <- initialNodes 68 | .take(16) 69 | .toList 70 | .sortBy(nodeInfo => NodeId.distance(nodeInfo.id, dhtClient.id)) 71 | .pure[IO] 72 | yield 73 | FindNodesState(nodeId, initialNodes) 74 | ) 75 | .flatMap { state => 76 | Stream 77 | .unfoldEval(state)(_.next) 78 | .flatMap(Stream.emits) 79 | } 80 | 81 | case class FindNodesState( 82 | targetId: NodeId, 
83 | nodesToQuery: List[NodeInfo], 84 | usedNodes: Set[NodeInfo] = Set.empty, 85 | respondedCount: Int = 0 86 | ): 87 | def next: IO[Option[(List[NodeInfo], FindNodesState)]] = async[IO]: 88 | if nodesToQuery.isEmpty then 89 | none 90 | else 91 | val responses = nodesToQuery 92 | .parTraverse(nodeInfo => 93 | dhtClient 94 | .findNodes(nodeInfo.address, targetId) 95 | .map(_.nodes.some) 96 | .timeout(5.seconds) 97 | .orElse(none.pure[IO]) 98 | .tupleLeft(nodeInfo) 99 | ) 100 | .await 101 | val respondedNodes = responses.collect { case (nodeInfo, Some(_)) => nodeInfo } 102 | val foundNodes = responses.collect { case (_, Some(nodes)) => nodes }.flatten 103 | val threshold = 104 | if respondedCount > 10 105 | then NodeId.distance(nodesToQuery.head.id, targetId) 106 | else NodeId.MaxValue 107 | val closeNodes = foundNodes 108 | .filterNot(usedNodes) 109 | .distinct 110 | .filter(nodeInfo => NodeId.distance(nodeInfo.id, targetId) < threshold) 111 | .sortBy(nodeInfo => NodeId.distance(nodeInfo.id, targetId)) 112 | .take(10) 113 | ( 114 | respondedNodes, 115 | copy( 116 | nodesToQuery = closeNodes, 117 | usedNodes = usedNodes ++ respondedNodes, 118 | respondedCount = respondedCount + respondedNodes.size) 119 | ).some 120 | } 121 | 122 | private[dht] def start( 123 | infoHash: InfoHash, 124 | getPeers: (SocketAddress[IpAddress], InfoHash) => IO[Either[Response.Nodes, Response.Peers]], 125 | state: DiscoveryState, 126 | parallelism: Int = 10 127 | )(using 128 | logger: Logger[IO] 129 | ): Stream[IO, PeerInfo] = { 130 | 131 | Stream 132 | .repeatEval(state.next) 133 | .parEvalMapUnordered(parallelism) { nodeInfo => 134 | getPeers(nodeInfo.address, infoHash).timeout(5.seconds).attempt <* logger.trace(s"Get peers $nodeInfo") 135 | } 136 | .flatMap { 137 | case Right(response) => 138 | response match { 139 | case Left(Response.Nodes(_, nodes)) => 140 | Stream 141 | .eval(state.addNodes(nodes)) >> Stream.empty 142 | case Right(Response.Peers(_, peers)) => 143 | Stream 144 | 
.eval(state.addPeers(peers)) 145 | .flatMap(newPeers => Stream.emits(newPeers)) 146 | } 147 | case Left(_) => 148 | Stream.empty 149 | } 150 | } 151 | 152 | class DiscoveryState(ref: Ref[IO, DiscoveryState.Data], infoHash: InfoHash) { 153 | 154 | def next: IO[NodeInfo] = 155 | IO.deferred[NodeInfo] 156 | .flatMap { deferred => 157 | ref.modify { state => 158 | state.nodesToTry match { 159 | case x :: xs => (state.copy(nodesToTry = xs), x.pure[IO]) 160 | case _ => 161 | (state.copy(waiters = deferred :: state.waiters), deferred.get) 162 | } 163 | }.flatten 164 | } 165 | 166 | def addNodes(nodes: List[NodeInfo]): IO[Unit] = { 167 | ref.modify { state => 168 | val newNodes = nodes.filterNot(state.seenNodes) 169 | val seenNodes = state.seenNodes ++ newNodes 170 | val nodesToTry = (newNodes ++ state.nodesToTry).sortBy(n => NodeId.distance(n.id, infoHash)) 171 | val waiters = state.waiters.drop(nodesToTry.size) 172 | val newState = 173 | state.copy( 174 | nodesToTry = nodesToTry.drop(state.waiters.size), 175 | seenNodes = seenNodes, 176 | waiters = waiters 177 | ) 178 | val io = 179 | state.waiters.zip(nodesToTry).map { case (deferred, nodeInfo) => deferred.complete(nodeInfo) }.sequence_ 180 | (newState, io) 181 | }.flatten 182 | } 183 | 184 | type NewPeers = List[PeerInfo] 185 | 186 | def addPeers(peers: List[PeerInfo]): IO[NewPeers] = { 187 | ref 188 | .modify { state => 189 | val newPeers = peers.filterNot(state.seenPeers) 190 | val newState = state.copy( 191 | seenPeers = state.seenPeers ++ newPeers 192 | ) 193 | (newState, newPeers) 194 | } 195 | } 196 | } 197 | 198 | object DiscoveryState { 199 | 200 | case class Data( 201 | nodesToTry: List[NodeInfo], 202 | seenNodes: Set[NodeInfo], 203 | seenPeers: Set[PeerInfo] = Set.empty, 204 | waiters: List[Deferred[IO, NodeInfo]] = Nil 205 | ) 206 | 207 | def apply(initialNodes: List[NodeInfo], infoHash: InfoHash): IO[DiscoveryState] = 208 | for { 209 | ref <- IO.ref(Data(initialNodes, initialNodes.toSet)) 210 | } yield { 
211 | new DiscoveryState(ref, infoHash) 212 | } 213 | } 214 | 215 | case class ExhaustedNodeList() extends Exception 216 | } 217 | -------------------------------------------------------------------------------- /bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Connection.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.wire 2 | 3 | import cats.* 4 | import cats.effect.kernel.Deferred 5 | import cats.effect.kernel.Ref 6 | import cats.effect.std.Queue 7 | import cats.effect.syntax.all.* 8 | import cats.effect.IO 9 | import cats.effect.Outcome 10 | import cats.effect.Resource 11 | import cats.implicits.* 12 | import com.github.torrentdam.bittorrent.wire.ExtensionHandler.ExtensionApi 13 | import com.github.torrentdam.bittorrent.InfoHash 14 | import com.github.torrentdam.bittorrent.PeerId 15 | import com.github.torrentdam.bittorrent.PeerInfo 16 | import com.github.torrentdam.bittorrent.TorrentMetadata 17 | import com.github.torrentdam.bittorrent.protocol.message.Message 18 | import fs2.concurrent.Signal 19 | import fs2.concurrent.SignallingRef 20 | import fs2.io.net.Network 21 | import fs2.io.net.SocketGroup 22 | import monocle.macros.GenLens 23 | import monocle.Lens 24 | import org.legogroup.woof.given 25 | import org.legogroup.woof.Logger 26 | 27 | import scala.collection.immutable.BitSet 28 | import scala.concurrent.duration.* 29 | import scodec.bits.ByteVector 30 | 31 | trait Connection { 32 | def info: PeerInfo 33 | def extensionProtocol: Boolean 34 | def interested: IO[Unit] 35 | def request(request: Message.Request): IO[ByteVector] 36 | def choked: Signal[IO, Boolean] 37 | def availability: Signal[IO, BitSet] 38 | def disconnected: IO[Unit] 39 | def extensionApi: IO[ExtensionApi[IO]] 40 | } 41 | 42 | object Connection { 43 | 44 | case class State(lastMessageAt: Long = 0, interested: Boolean = false) 45 | object State { 46 | val lastMessageAt: Lens[State, 
Long] = GenLens[State](_.lastMessageAt) 47 | val interested: Lens[State, Boolean] = GenLens[State](_.interested) 48 | } 49 | 50 | trait RequestRegistry { 51 | def register(request: Message.Request): IO[ByteVector] 52 | def complete(request: Message.Request, bytes: ByteVector): IO[Unit] 53 | } 54 | object RequestRegistry { 55 | def apply(): Resource[IO, RequestRegistry] = 56 | for stateRef <- Resource.eval( 57 | IO.ref(Map.empty[Message.Request, Either[Throwable, ByteVector] => IO[Boolean]]) 58 | ) 59 | // _ <- Resource.onFinalize( 60 | // for 61 | // state <- stateRef.get 62 | // _ <- state.values.toList.traverse { cb => 63 | // cb(ConnectionClosed().asLeft) 64 | // } 65 | // yield () 66 | // ) 67 | yield new RequestRegistry { 68 | 69 | def register(request: Message.Request): IO[ByteVector] = 70 | IO.deferred[Either[Throwable, ByteVector]] 71 | .flatMap { deferred => 72 | val update = stateRef.update(_.updated(request, deferred.complete)) 73 | val delete = stateRef.update(_ - request) 74 | (update >> deferred.get).guarantee(delete) 75 | } 76 | .flatMap(IO.fromEither) 77 | 78 | def complete(request: Message.Request, bytes: ByteVector): IO[Unit] = 79 | for 80 | callback <- stateRef.get.map(_.get(request)) 81 | _ <- callback.traverse(cb => cb(bytes.asRight)) 82 | yield () 83 | } 84 | } 85 | 86 | def connect(selfId: PeerId, peerInfo: PeerInfo, infoHash: InfoHash)(using 87 | network: Network[IO], 88 | logger: Logger[IO] 89 | ): Resource[IO, Connection] = 90 | for 91 | requestRegistry <- RequestRegistry() 92 | socket <- MessageSocket.connect[IO](selfId, peerInfo, infoHash) 93 | stateRef <- Resource.eval(IO.ref(State())) 94 | chokedStatusRef <- Resource.eval(SignallingRef[IO].of(true)) 95 | bitfieldRef <- Resource.eval(SignallingRef[IO].of(BitSet.empty)) 96 | sendQueue <- Resource.eval(Queue.bounded[IO, Message](10)) 97 | (extensionHandler, initExtension) <- Resource.eval( 98 | ExtensionHandler.InitExtension( 99 | infoHash, 100 | sendQueue.offer, 101 | new 
ExtensionHandler.UtMetadata.Create[IO] 102 | ) 103 | ) 104 | updateLastMessageTime = (l: Long) => stateRef.update(State.lastMessageAt.replace(l)) 105 | closed <- 106 | ( 107 | receiveLoop( 108 | requestRegistry, 109 | bitfieldRef.update, 110 | chokedStatusRef.set, 111 | updateLastMessageTime, 112 | socket, 113 | extensionHandler 114 | ), 115 | sendLoop(sendQueue, socket), 116 | keepAliveLoop(stateRef, sendQueue.offer) 117 | ).parTupled.background 118 | yield new Connection { 119 | def info: PeerInfo = peerInfo 120 | def extensionProtocol: Boolean = socket.handshake.extensionProtocol 121 | 122 | def interested: IO[Unit] = 123 | for 124 | interested <- stateRef.modify(s => (State.interested.replace(true)(s), s.interested)) 125 | _ <- IO.whenA(!interested)(sendQueue.offer(Message.Interested)) 126 | yield () 127 | 128 | def request(request: Message.Request): IO[ByteVector] = 129 | sendQueue.offer(request) >> 130 | requestRegistry.register(request).flatMap { bytes => 131 | if bytes.length == request.length 132 | then bytes.pure[IO] 133 | else Error.InvalidBlockLength(request, bytes.length).raiseError[IO, ByteVector] 134 | } 135 | 136 | def choked: Signal[IO, Boolean] = chokedStatusRef 137 | 138 | def availability: Signal[IO, BitSet] = bitfieldRef 139 | 140 | def disconnected: IO[Unit] = closed.void 141 | 142 | def extensionApi: IO[ExtensionApi[IO]] = initExtension.init 143 | } 144 | end for 145 | 146 | case class ConnectionClosed() extends Throwable 147 | 148 | private def receiveLoop( 149 | requestRegistry: RequestRegistry, 150 | updateBitfield: (BitSet => BitSet) => IO[Unit], 151 | updateChokeStatus: Boolean => IO[Unit], 152 | updateLastMessageAt: Long => IO[Unit], 153 | socket: MessageSocket[IO], 154 | extensionHandler: ExtensionHandler[IO] 155 | ): IO[Nothing] = 156 | socket.receive 157 | .flatMap { 158 | case Message.Unchoke => 159 | updateChokeStatus(false) 160 | case Message.Choke => 161 | updateChokeStatus(true) 162 | case Message.Piece(index: Long, begin: Long, 
bytes: ByteVector) => 163 | val request: Message.Request = Message.Request(index, begin, bytes.length) 164 | requestRegistry.complete(request, bytes) 165 | case Message.Have(index) => 166 | updateBitfield(_ incl index.toInt) 167 | case Message.Bitfield(bytes) => 168 | val indices = bytes.toBitVector.toIndexedSeq.zipWithIndex.collect { case (true, i) => 169 | i 170 | } 171 | updateBitfield(_ => BitSet(indices*)) 172 | case m: Message.Extended => 173 | extensionHandler(m) 174 | case _ => 175 | IO.unit 176 | } 177 | .flatTap { _ => 178 | IO.realTime.flatMap { currentTime => 179 | updateLastMessageAt(currentTime.toMillis) 180 | } 181 | } 182 | .foreverM 183 | 184 | private def keepAliveLoop( 185 | stateRef: Ref[IO, State], 186 | send: Message => IO[Unit] 187 | ): IO[Nothing] = 188 | IO 189 | .sleep(10.seconds) 190 | .flatMap { _ => 191 | for 192 | currentTime <- IO.realTime 193 | timedOut <- stateRef.get.map(s => (currentTime - s.lastMessageAt.millis) > 30.seconds) 194 | _ <- IO.whenA(timedOut) { 195 | IO.raiseError(Error.ConnectionTimeout()) 196 | } 197 | _ <- send(Message.KeepAlive) 198 | yield () 199 | } 200 | .foreverM 201 | 202 | private def sendLoop(queue: Queue[IO, Message], socket: MessageSocket[IO]): IO[Nothing] = 203 | queue.take.flatMap(socket.send).foreverM 204 | 205 | enum Error(message: String) extends Exception(message): 206 | case ConnectionTimeout() extends Error("Connection timed out") 207 | case InvalidBlockLength(request: Message.Request, responseLength: Long) extends Error("Invalid block length") 208 | } 209 | -------------------------------------------------------------------------------- /dht/src/main/scala/com/github/torrentdam/bittorrent/dht/message.scala: -------------------------------------------------------------------------------- 1 | package com.github.torrentdam.bittorrent.dht 2 | 3 | import cats.implicits.* 4 | import com.comcast.ip4s.* 5 | import com.github.torrentdam.bittorrent.InfoHash 6 | import 
com.github.torrentdam.bittorrent.PeerInfo 7 | import com.github.torrentdam.bencode.format.* 8 | import com.github.torrentdam.bencode.Bencode 9 | import scodec.bits.ByteVector 10 | import scodec.Codec 11 | /** A DHT message: a query, a response or an error, each carrying a transaction id. */ 12 | enum Message: 13 | case QueryMessage(transactionId: ByteVector, query: Query) 14 | case ResponseMessage(transactionId: ByteVector, response: Response) 15 | case ErrorMessage(transactionId: ByteVector, details: Bencode) 16 | 17 | def transactionId: ByteVector 18 | end Message 19 | /** Bencode formats for [[Message]] plus the compact binary codecs used inside responses. */ 20 | object Message { 21 | /** A NodeId is (de)serialized as its raw byte string. */ 22 | given BencodeFormat[NodeId] = 23 | BencodeFormat.ByteVectorFormat.imap(NodeId.apply)(_.bytes) 24 | /** "ping" query arguments: only the querying node id, stored as "id" inside the "a" dictionary. */ 25 | val PingQueryFormat: BencodeFormat[Query.Ping] = ( 26 | field[NodeId]("a")(using field[NodeId]("id")) 27 | ).imap[Query.Ping](qni => Query.Ping(qni))(v => v.queryingNodeId) 28 | /** "find_node" query arguments: "id" and "target" inside "a". */ 29 | val FindNodeQueryFormat: BencodeFormat[Query.FindNode] = ( 30 | field[(NodeId, NodeId)]("a")( 31 | using (field[NodeId]("id"), field[NodeId]("target")).tupled 32 | ) 33 | ).imap[Query.FindNode](Query.FindNode.apply)(v => (v.queryingNodeId, v.target)) 34 | /** An InfoHash is (de)serialized as its raw byte string. */ 35 | given BencodeFormat[InfoHash] = BencodeFormat.ByteVectorFormat.imap(InfoHash(_))(_.bytes) 36 | /** "get_peers" query arguments: "id" and "info_hash" inside "a". */ 37 | val GetPeersQueryFormat: BencodeFormat[Query.GetPeers] = ( 38 | field[(NodeId, InfoHash)]("a")( 39 | using (field[NodeId]("id"), field[InfoHash]("info_hash")).tupled 40 | ) 41 | ).imap[Query.GetPeers](Query.GetPeers.apply)(v => (v.queryingNodeId, v.infoHash)) 42 | /** "announce_peer" query arguments: "id", "info_hash" and "port" inside "a". NOTE(review): no "token" or "implied_port" field is read or written here although BEP 5 defines both for announce_peer - confirm this is intended. */ 43 | val AnnouncePeerQueryFormat: BencodeFormat[Query.AnnouncePeer] = ( 44 | field[(NodeId, InfoHash, Long)]("a")( 45 | using (field[NodeId]("id"), field[InfoHash]("info_hash"), field[Long]("port")).tupled 46 | ) 47 | ).imap[Query.AnnouncePeer](Query.AnnouncePeer.apply)(v => (v.queryingNodeId, v.infoHash, v.port)) 48 | /** "sample_infohashes" query arguments: "id" and "target" inside "a". */ 49 | val SampleInfoHashesQueryFormat: BencodeFormat[Query.SampleInfoHashes] = ( 50 | field[(NodeId, NodeId)]("a")( 51 | using (field[NodeId]("id"), field[NodeId]("target")).tupled 52 | ) 53 |
).imap[Query.SampleInfoHashes](Query.SampleInfoHashes.apply)(v => (v.queryingNodeId, v.target)) 54 | /** Reading selects the argument format from the "q" method name; writing emits the method name matching the Query case. What happens for an unrecognised "q" value depends on BencodeFormat.choose - TODO confirm it is a decode error rather than a MatchError. */ 55 | val QueryFormat: BencodeFormat[Query] = 56 | field[String]("q").choose( 57 | { 58 | case "ping" => PingQueryFormat.upcast 59 | case "find_node" => FindNodeQueryFormat.upcast 60 | case "get_peers" => GetPeersQueryFormat.upcast 61 | case "announce_peer" => AnnouncePeerQueryFormat.upcast 62 | case "sample_infohashes" => SampleInfoHashesQueryFormat.upcast 63 | }, 64 | { 65 | case _: Query.Ping => "ping" 66 | case _: Query.FindNode => "find_node" 67 | case _: Query.GetPeers => "get_peers" 68 | case _: Query.AnnouncePeer => "announce_peer" 69 | case _: Query.SampleInfoHashes => "sample_infohashes" 70 | } 71 | ) 72 | /** A query message: transaction id "t" plus the query fields ("q" and "a"). */ 73 | val QueryMessageFormat: BencodeFormat[Message.QueryMessage] = ( 74 | field[ByteVector]("t"), 75 | QueryFormat 76 | ).imapN[QueryMessage]((tid, q) => QueryMessage(tid, q))(v => (v.transactionId, v.query)) 77 | /** Compact address: 4 address bytes followed by a 2-byte port read as unsigned. NOTE(review): an IPv6 host presumably yields 16 bytes from toBytes, which cannot fit bytes(4), so encoding such an address likely fails - confirm. */ 78 | val InetSocketAddressCodec: Codec[SocketAddress[IpAddress]] = { 79 | import scodec.codecs.* 80 | (bytes(4) :: bytes(2)).xmap( 81 | { case (address, port) => 82 | SocketAddress( 83 | IpAddress.fromBytes(address.toArray).get, 84 | Port.fromInt(port.toInt(signed = false)).get 85 | ) 86 | }, 87 | v => (ByteVector(v.host.toBytes), ByteVector.fromInt(v.port.value, 2)) 88 | ) 89 | } 90 | /** Concatenated compact node infos: each entry is a 20-byte node id followed by a compact address. */ 91 | val CompactNodeInfoCodec: Codec[List[NodeInfo]] = { 92 | import scodec.codecs.* 93 | list( 94 | (bytes(20) :: InetSocketAddressCodec).xmap( 95 | { case (id, address) => 96 | NodeInfo(NodeId(id), address) 97 | }, 98 | v => (v.id.bytes, v.address) 99 | ) 100 | ) 101 | } 102 | /** A single compact peer address. */ 103 | val CompactPeerInfoCodec: Codec[PeerInfo] = InetSocketAddressCodec.xmap(PeerInfo.apply, _.address) 104 | /** Concatenated 20-byte info-hashes. */ 105 | val CompactInfoHashCodec: Codec[List[InfoHash]] = { 106 | import scodec.codecs.* 107 | list( 108 | (bytes(20)).xmap(InfoHash.apply, _.bytes) 109 | ) 110 | } 111 | /** Response carrying only an "id" field. */ 112 | val PingResponseFormat: BencodeFormat[Response.Ping] = 113 |
field[NodeId]("id").imap[Response.Ping](Response.Ping.apply)(_.id) 114 | /** Response with "id" and "nodes" (a compact node-info byte string). */ 115 | val NodesResponseFormat: BencodeFormat[Response.Nodes] = ( 116 | field[NodeId]("id"), 117 | field[List[NodeInfo]]("nodes")(using encodedString(CompactNodeInfoCodec)) 118 | ).imapN[Response.Nodes](Response.Nodes.apply)(v => (v.id, v.nodes)) 119 | /** Response with "id" and "values" (a list of compact peer-address strings). */ 120 | val PeersResponseFormat: BencodeFormat[Response.Peers] = ( 121 | field[NodeId]("id"), 122 | field[List[PeerInfo]]("values")(using BencodeFormat.listFormat(using encodedString(CompactPeerInfoCodec))) 123 | ).imapN[Response.Peers](Response.Peers.apply)(v => (v.id, v.peers)) 124 | /** Response with "id", optional compact "nodes", and "samples" (concatenated info-hashes). */ 125 | val SampleInfoHashesResponseFormat: BencodeFormat[Response.SampleInfoHashes] = ( 126 | field[NodeId]("id"), 127 | fieldOptional[List[NodeInfo]]("nodes")(using encodedString(CompactNodeInfoCodec)), 128 | field[List[InfoHash]]("samples")(using encodedString(CompactInfoHashCodec)) 129 | ).imapN[Response.SampleInfoHashes](Response.SampleInfoHashes.apply)(v => (v.id, v.nodes, v.samples)) 130 | /** Responses have no type tag, so reading picks the variant from the keys present, checked in this order: "values" (Peers), "samples" (SampleInfoHashes), "nodes" (Nodes), otherwise Ping. Writing dispatches on the Response case. */ 131 | val ResponseFormat: BencodeFormat[Response] = 132 | BencodeFormat( 133 | BencodeFormat.dictionaryFormat.read.flatMap { 134 | case Bencode.BDictionary(dictionary) if dictionary.contains("values") => PeersResponseFormat.read.widen 135 | case Bencode.BDictionary(dictionary) if dictionary.contains("samples") => 136 | SampleInfoHashesResponseFormat.read.widen 137 | case Bencode.BDictionary(dictionary) if dictionary.contains("nodes") => NodesResponseFormat.read.widen 138 | case _ => PingResponseFormat.read.widen 139 | }, 140 | BencodeWriter { 141 | case value: Response.Peers => PeersResponseFormat.write(value) 142 | case value: Response.Nodes => NodesResponseFormat.write(value) 143 | case value: Response.Ping => PingResponseFormat.write(value) 144 | case value: Response.SampleInfoHashes => SampleInfoHashesResponseFormat.write(value) 145 | } 146 | ) 147 | /** A response message: transaction id "t" and the response dictionary under "r". */ 148 | val ResponseMessageFormat: BencodeFormat[Message.ResponseMessage] = ( 149 | field[ByteVector]("t"), 150 |
field[Response]("r")(using ResponseFormat) 151 | ).imapN[ResponseMessage]((tid, r) => ResponseMessage(tid, r))(v => (v.transactionId, v.response)) 152 | /** An error message: "t" is optional when reading (an empty transaction id is used if it is absent) and always written; "e" is kept as raw Bencode. */ 153 | val ErrorMessageFormat: BencodeFormat[Message.ErrorMessage] = ( 154 | fieldOptional[ByteVector]("t"), 155 | field[Bencode]("e") 156 | ).imapN[ErrorMessage]((tid, details) => ErrorMessage(tid.getOrElse(ByteVector.empty), details))(v => 157 | (v.transactionId.some, v.details) 158 | ) 159 | /** Top-level message format: dispatches on the "y" field ("q" query, "r" response, "e" error). */ 160 | given BencodeFormat[Message] = 161 | field[String]("y").choose( 162 | { 163 | case "q" => QueryMessageFormat.upcast 164 | case "r" => ResponseMessageFormat.upcast 165 | case "e" => ErrorMessageFormat.upcast 166 | }, 167 | { 168 | case _: Message.QueryMessage => "q" 169 | case _: Message.ResponseMessage => "r" 170 | case _: Message.ErrorMessage => "e" 171 | } 172 | ) 173 | } 174 | /** DHT query variants; every query carries the querying node's id. */ 175 | enum Query: 176 | case Ping(queryingNodeId: NodeId) 177 | case FindNode(queryingNodeId: NodeId, target: NodeId) 178 | case GetPeers(queryingNodeId: NodeId, infoHash: InfoHash) 179 | case AnnouncePeer(queryingNodeId: NodeId, infoHash: InfoHash, port: Long) 180 | case SampleInfoHashes(queryingNodeId: NodeId, target: NodeId) 181 | 182 | def queryingNodeId: NodeId 183 | end Query 184 | /** DHT response variants; on the wire they are told apart by which keys are present (see ResponseFormat). */ 185 | enum Response: 186 | case Ping(id: NodeId) 187 | case Nodes(id: NodeId, nodes: List[NodeInfo]) 188 | case Peers(id: NodeId, peers: List[PeerInfo]) 189 | case SampleInfoHashes(id: NodeId, nodes: Option[List[NodeInfo]], samples: List[InfoHash]) 190 | -------------------------------------------------------------------------------- /mill: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | # This is a wrapper script, that automatically selects or downloads Mill from Maven Central or GitHub release pages.
4 | # 5 | # This script determines the Mill version to use by trying these sources, in order: 6 | # - env-variable `MILL_VERSION` 7 | # - local file `.mill-version` 8 | # - local file `.config/mill-version` 9 | # - `mill-version` from the YAML frontmatter of the current build file 10 | # - if accessible, the latest release from the GitHub releases page (https://github.com/com-lihaoyi/mill/releases/latest), cached for one hour 11 | # - env-variable `DEFAULT_MILL_VERSION` (a built-in default is used when it is unset) 12 | # 13 | # If a version has the suffix '-native' a native binary will be used. 14 | # If a version has the suffix '-jvm' an executable jar file will be used, requiring an already installed Java runtime. 15 | # If no such suffix is found, the script will pick a default based on version and platform. 16 | # 17 | # Once a version is determined, it tries to use either 18 | # - a system-installed mill, if found and its version matches (checked on Linux only) 19 | # - an already downloaded version under ~/.cache/mill/download 20 | # 21 | # If no working mill version was found on the system, 22 | # this script downloads a binary file from Maven Central or the GitHub release pages (depending on the version) 23 | # into a cache location (~/.cache/mill/download). 24 | # 25 | # Mill Project URL: https://github.com/com-lihaoyi/mill 26 | # Script Version: 1.0.0-M1-21-7b6fae-DIRTY892b63e8 27 | # 28 | # If you want to improve this script, please also contribute your changes back!
29 | # This script was generated from: dist/scripts/src/mill.sh 30 | # 31 | # Licensed under the Apache License, Version 2.0 32 | 33 | set -e 34 | MILL_FIRST_ARG="" # an option that must stay first on mill's command line; assigned at most once below 35 | if [ "$1" = "--setup-completions" ] ; then 36 | # Need to preserve the first position of those listed options 37 | MILL_FIRST_ARG=$1 38 | shift 39 | fi 40 | 41 | if [ -z "${DEFAULT_MILL_VERSION}" ] ; then 42 | DEFAULT_MILL_VERSION=1.0.1 43 | fi 44 | 45 | 46 | if [ -z "${GITHUB_RELEASE_CDN}" ] ; then 47 | GITHUB_RELEASE_CDN="" 48 | fi 49 | 50 | 51 | MILL_REPO_URL="https://github.com/com-lihaoyi/mill" 52 | 53 | if [ -z "${CURL_CMD}" ] ; then 54 | CURL_CMD=curl 55 | fi 56 | 57 | # --mill-version is no longer supported: warn, but leave the arguments untouched 58 | if [ "$1" = "--mill-version" ] ; then 59 | echo "The --mill-version option is no longer supported." 1>&2 60 | fi 61 | 62 | MILL_BUILD_SCRIPT="" 63 | 64 | if [ -f "build.mill" ] ; then 65 | MILL_BUILD_SCRIPT="build.mill" 66 | elif [ -f "build.mill.scala" ] ; then 67 | MILL_BUILD_SCRIPT="build.mill.scala" 68 | elif [ -f "build.sc" ] ; then 69 | MILL_BUILD_SCRIPT="build.sc" 70 | fi 71 | 72 | # Please note that if MILL_VERSION is already set in the environment, 73 | # we reuse its value and skip searching for a value.
74 | 75 | # If not already set, read .mill-version file 76 | if [ -z "${MILL_VERSION}" ] ; then 77 | if [ -f ".mill-version" ] ; then 78 | MILL_VERSION="$(tr '\r' '\n' < .mill-version | head -n 1 2> /dev/null)" 79 | elif [ -f ".config/mill-version" ] ; then 80 | MILL_VERSION="$(tr '\r' '\n' < .config/mill-version | head -n 1 2> /dev/null)" 81 | elif [ -n "${MILL_BUILD_SCRIPT}" ] ; then 82 | MILL_VERSION="$(grep '//[|] *mill-version: *' "${MILL_BUILD_SCRIPT}" | head -n 1 | sed 's;//| *mill-version: *;;')" 83 | fi 84 | fi 85 | 86 | MILL_USER_CACHE_DIR="${XDG_CACHE_HOME:-${HOME}/.cache}/mill" 87 | 88 | if [ -z "${MILL_DOWNLOAD_PATH}" ] ; then 89 | MILL_DOWNLOAD_PATH="${MILL_USER_CACHE_DIR}/download" 90 | fi 91 | 92 | # If not already set, try to fetch newest from GitHub 93 | if [ -z "${MILL_VERSION}" ] ; then 94 | # Warn, then use the cached or freshly fetched latest release, falling back to DEFAULT_MILL_VERSION 95 | echo "No mill version specified." 1>&2 96 | echo "You should provide a version via a '//| mill-version: ' comment or a '.mill-version' file."
1>&2 97 | 98 | mkdir -p "${MILL_DOWNLOAD_PATH}" 99 | LANG=C touch -d '1 hour ago' "${MILL_DOWNLOAD_PATH}/.expire_latest" 2>/dev/null || ( 100 | # we might be on OSX or BSD which don't have -d option for touch 101 | # but probably a -A [-][[hh]mm]SS 102 | touch "${MILL_DOWNLOAD_PATH}/.expire_latest"; touch -A -010000 "${MILL_DOWNLOAD_PATH}/.expire_latest" 103 | ) || ( 104 | # in case we still failed, we retry the first touch command with the intention 105 | # to show the (previously suppressed) error message 106 | LANG=C touch -d '1 hour ago' "${MILL_DOWNLOAD_PATH}/.expire_latest" 107 | ) 108 | 109 | # POSIX shell variant of bash's -nt operator, see https://unix.stackexchange.com/a/449744/6993 110 | # if [ "${MILL_DOWNLOAD_PATH}/.latest" -nt "${MILL_DOWNLOAD_PATH}/.expire_latest" ] ; then 111 | if [ -n "$(find -L "${MILL_DOWNLOAD_PATH}/.latest" -prune -newer "${MILL_DOWNLOAD_PATH}/.expire_latest" 2>/dev/null)" ]; then 112 | # we know a current latest version (a missing .latest on the first run is silently treated as unknown) 113 | MILL_VERSION=$(head -n 1 "${MILL_DOWNLOAD_PATH}"/.latest 2> /dev/null) 114 | fi 115 | 116 | if [ -z "${MILL_VERSION}" ] ; then 117 | # we don't know a current latest version 118 | echo "Retrieving latest mill version ..."
1>&2 119 | LANG=C ${CURL_CMD} -s -i -f -I ${MILL_REPO_URL}/releases/latest 2> /dev/null | grep --ignore-case Location: | sed s'/^.*tag\///' | tr -d '\r\n' > "${MILL_DOWNLOAD_PATH}/.latest" 120 | MILL_VERSION=$(head -n 1 "${MILL_DOWNLOAD_PATH}"/.latest 2> /dev/null) 121 | fi 122 | 123 | if [ -z "${MILL_VERSION}" ] ; then 124 | # Last resort 125 | MILL_VERSION="${DEFAULT_MILL_VERSION}" 126 | echo "Falling back to hardcoded mill version ${MILL_VERSION}" 1>&2 127 | else 128 | echo "Using mill version ${MILL_VERSION}" 1>&2 129 | fi 130 | fi 131 | 132 | MILL_NATIVE_SUFFIX="-native" 133 | MILL_JVM_SUFFIX="-jvm" 134 | FULL_MILL_VERSION=$MILL_VERSION 135 | ARTIFACT_SUFFIX="" 136 | set_artifact_suffix(){ 137 | if [ "$(expr substr $(uname -s) 1 5 2>/dev/null)" = "Linux" ]; then 138 | if [ "$(uname -m)" = "aarch64" ]; then 139 | ARTIFACT_SUFFIX="-native-linux-aarch64" 140 | else 141 | ARTIFACT_SUFFIX="-native-linux-amd64" 142 | fi 143 | elif [ "$(uname)" = "Darwin" ]; then 144 | if [ "$(uname -m)" = "arm64" ]; then 145 | ARTIFACT_SUFFIX="-native-mac-aarch64" 146 | else 147 | ARTIFACT_SUFFIX="-native-mac-amd64" 148 | fi 149 | else 150 | echo "This native mill launcher supports only Linux and macOS."
1>&2 151 | exit 1 152 | fi 153 | } 154 | 155 | case "$MILL_VERSION" in 156 | *"$MILL_NATIVE_SUFFIX") 157 | MILL_VERSION=${MILL_VERSION%"$MILL_NATIVE_SUFFIX"} 158 | set_artifact_suffix 159 | ;; 160 | 161 | *"$MILL_JVM_SUFFIX") 162 | MILL_VERSION=${MILL_VERSION%"$MILL_JVM_SUFFIX"} 163 | ;; 164 | 165 | *) 166 | case "$MILL_VERSION" in 167 | 0.1.*) ;; 168 | 0.2.*) ;; 169 | 0.3.*) ;; 170 | 0.4.*) ;; 171 | 0.5.*) ;; 172 | 0.6.*) ;; 173 | 0.7.*) ;; 174 | 0.8.*) ;; 175 | 0.9.*) ;; 176 | 0.10.*) ;; 177 | 0.11.*) ;; 178 | 0.12.*) ;; 179 | *) 180 | set_artifact_suffix 181 | esac 182 | ;; 183 | esac 184 | 185 | MILL="${MILL_DOWNLOAD_PATH}/$MILL_VERSION$ARTIFACT_SUFFIX" 186 | 187 | try_to_use_system_mill() { 188 | if [ "$(uname)" != "Linux" ]; then 189 | return 0 190 | fi 191 | 192 | MILL_IN_PATH="$(command -v mill || true)" 193 | 194 | if [ -z "${MILL_IN_PATH}" ]; then 195 | return 0 196 | fi 197 | 198 | SYSTEM_MILL_FIRST_TWO_BYTES=$(head --bytes=2 "${MILL_IN_PATH}") 199 | if [ "${SYSTEM_MILL_FIRST_TWO_BYTES}" = "#!" ]; then 200 | # MILL_IN_PATH is (very likely) a shell script and not the mill 201 | # executable, ignore it. 202 | return 0 203 | fi 204 | 205 | SYSTEM_MILL_PATH=$(readlink -e "${MILL_IN_PATH}") 206 | SYSTEM_MILL_SIZE=$(stat --format=%s "${SYSTEM_MILL_PATH}") 207 | SYSTEM_MILL_MTIME=$(stat --format=%y "${SYSTEM_MILL_PATH}") 208 | 209 | if [ ! -d "${MILL_USER_CACHE_DIR}" ]; then 210 | mkdir -p "${MILL_USER_CACHE_DIR}" 211 | fi 212 | 213 | SYSTEM_MILL_INFO_FILE="${MILL_USER_CACHE_DIR}/system-mill-info" 214 | if [ -f "${SYSTEM_MILL_INFO_FILE}" ]; then 215 | parseSystemMillInfo() { 216 | LINE_NUMBER="${1}" 217 | # Select the line number of the SYSTEM_MILL_INFO_FILE, cut the 218 | # variable definition in that line in two halves and return 219 | # the value, and finally remove the quotes.
220 | sed -n "${LINE_NUMBER}p" "${SYSTEM_MILL_INFO_FILE}" |\ 221 | cut -d= -f2 |\ 222 | sed 's/"\(.*\)"/\1/' 223 | } 224 | 225 | CACHED_SYSTEM_MILL_PATH=$(parseSystemMillInfo 1) 226 | CACHED_SYSTEM_MILL_VERSION=$(parseSystemMillInfo 2) 227 | CACHED_SYSTEM_MILL_SIZE=$(parseSystemMillInfo 3) 228 | CACHED_SYSTEM_MILL_MTIME=$(parseSystemMillInfo 4) 229 | 230 | if [ "${SYSTEM_MILL_PATH}" = "${CACHED_SYSTEM_MILL_PATH}" ] \ 231 | && [ "${SYSTEM_MILL_SIZE}" = "${CACHED_SYSTEM_MILL_SIZE}" ] \ 232 | && [ "${SYSTEM_MILL_MTIME}" = "${CACHED_SYSTEM_MILL_MTIME}" ]; then 233 | if [ "${CACHED_SYSTEM_MILL_VERSION}" = "${MILL_VERSION}" ]; then 234 | MILL="${SYSTEM_MILL_PATH}" 235 | return 0 236 | else 237 | return 0 238 | fi 239 | fi 240 | fi 241 | 242 | SYSTEM_MILL_VERSION=$(${SYSTEM_MILL_PATH} --version | head -n1 | sed -n 's/^Mill.*version \(.*\)/\1/p') 243 | # Cache path, version, size and mtime (read back by parseSystemMillInfo) so later runs skip `mill --version` 244 | cat > "${SYSTEM_MILL_INFO_FILE}" << EOF 245 | CACHED_SYSTEM_MILL_PATH="${SYSTEM_MILL_PATH}" 246 | CACHED_SYSTEM_MILL_VERSION="${SYSTEM_MILL_VERSION}" 247 | CACHED_SYSTEM_MILL_SIZE="${SYSTEM_MILL_SIZE}" 248 | CACHED_SYSTEM_MILL_MTIME="${SYSTEM_MILL_MTIME}" 249 | EOF 250 | 251 | if [ "${SYSTEM_MILL_VERSION}" = "${MILL_VERSION}" ]; then 252 | MILL="${SYSTEM_MILL_PATH}" 253 | fi 254 | } 255 | try_to_use_system_mill 256 | 257 | # If not already downloaded, download it 258 | if [ !
-s "${MILL}" ] || [ "$MILL_TEST_DRY_RUN_LAUNCHER_SCRIPT" = "1" ] ; then 259 | case $MILL_VERSION in 260 | 0.0.* | 0.1.* | 0.2.* | 0.3.* | 0.4.* ) 261 | DOWNLOAD_SUFFIX="" 262 | DOWNLOAD_FROM_MAVEN=0 263 | ;; 264 | 0.5.* | 0.6.* | 0.7.* | 0.8.* | 0.9.* | 0.10.* | 0.11.0-M* ) 265 | DOWNLOAD_SUFFIX="-assembly" 266 | DOWNLOAD_FROM_MAVEN=0 267 | ;; 268 | *) 269 | DOWNLOAD_SUFFIX="-assembly" 270 | DOWNLOAD_FROM_MAVEN=1 271 | ;; 272 | esac 273 | case $MILL_VERSION in 274 | 0.12.0 | 0.12.1 | 0.12.2 | 0.12.3 | 0.12.4 | 0.12.5 | 0.12.6 | 0.12.7 | 0.12.8 | 0.12.9 | 0.12.10 | 0.12.11 ) 275 | DOWNLOAD_EXT="jar" 276 | ;; 277 | 0.12.* ) 278 | DOWNLOAD_EXT="exe" 279 | ;; 280 | 0.* ) 281 | DOWNLOAD_EXT="jar" 282 | ;; 283 | *) 284 | DOWNLOAD_EXT="exe" 285 | ;; 286 | esac 287 | 288 | # DOWNLOAD_FILE is created only after the dry-run exit below, so a dry run leaves no temp file behind 289 | if [ "$DOWNLOAD_FROM_MAVEN" = "1" ] ; then 290 | DOWNLOAD_URL="https://repo1.maven.org/maven2/com/lihaoyi/mill-dist${ARTIFACT_SUFFIX}/${MILL_VERSION}/mill-dist${ARTIFACT_SUFFIX}-${MILL_VERSION}.${DOWNLOAD_EXT}" 291 | else 292 | MILL_VERSION_TAG=$(echo "$MILL_VERSION" | sed -E 's/([^-]+)(-M[0-9]+)?(-.*)?/\1\2/') 293 | DOWNLOAD_URL="${GITHUB_RELEASE_CDN}${MILL_REPO_URL}/releases/download/${MILL_VERSION_TAG}/${MILL_VERSION}${DOWNLOAD_SUFFIX}" 294 | unset MILL_VERSION_TAG 295 | fi 296 | 297 | if [ "$MILL_TEST_DRY_RUN_LAUNCHER_SCRIPT" = "1" ] ; then 298 | echo "$DOWNLOAD_URL" 299 | echo "$MILL" 300 | exit 0 301 | fi 302 | DOWNLOAD_FILE=$(mktemp mill.XXXXXX) # TODO: handle command not found 303 | echo "Downloading mill ${MILL_VERSION} from ${DOWNLOAD_URL} ..."
1>&2 304 | ${CURL_CMD} -f -L -o "${DOWNLOAD_FILE}" "${DOWNLOAD_URL}" || { CURL_STATUS=$?; rm -f "${DOWNLOAD_FILE}"; exit "${CURL_STATUS}"; } 305 | chmod +x "${DOWNLOAD_FILE}" 306 | mkdir -p "${MILL_DOWNLOAD_PATH}" 307 | mv "${DOWNLOAD_FILE}" "${MILL}" 308 | 309 | unset DOWNLOAD_FILE 310 | unset DOWNLOAD_SUFFIX 311 | fi 312 | 313 | if [ -z "$MILL_MAIN_CLI" ] ; then 314 | MILL_MAIN_CLI="${0}" 315 | fi 316 | 317 | # Only claim the first-position slot if --setup-completions has not already taken it 318 | if [ -z "${MILL_FIRST_ARG}" ] && { [ "$1" = "--bsp" ] || [ "${1#"-i"}" != "$1" ] || [ "$1" = "--interactive" ] || [ "$1" = "--no-server" ] || [ "$1" = "--no-daemon" ] || [ "$1" = "--repl" ] || [ "$1" = "--help" ] ; } ; then 319 | # Need to preserve the first position of those listed options 320 | MILL_FIRST_ARG=$1 321 | shift 322 | fi 323 | 324 | unset MILL_DOWNLOAD_PATH 325 | unset MILL_OLD_DOWNLOAD_PATH 326 | unset OLD_MILL 327 | unset MILL_VERSION 328 | unset MILL_REPO_URL 329 | 330 | # -D mill.main.cli is for compatibility with Mill 0.10.9 - 0.13.0-M2 331 | # We don't quote MILL_FIRST_ARG on purpose, so we can expand the empty value without quotes 332 | # shellcheck disable=SC2086 333 | exec "${MILL}" $MILL_FIRST_ARG -D "mill.main.cli=${MILL_MAIN_CLI}" "$@" 334 | -------------------------------------------------------------------------------- /cmd/src/main/scala/Main.scala: -------------------------------------------------------------------------------- 1 | import cats.effect.cps.* 2 | import cats.effect.cps.given 3 | import cats.effect.std.Random 4 | import cats.effect.syntax.all.* 5 | import cats.effect.ExitCode 6 | import cats.effect.IO 7 | import cats.effect.Resource 8 | import cats.effect.ResourceIO 9 | import cats.syntax.all.* 10 | import com.comcast.ip4s.Port 11 | import com.comcast.ip4s.SocketAddress 12 | import com.github.torrentdam.bencode 13 | import com.github.torrentdam.bittorrent.dht.* 14 | import com.github.torrentdam.bittorrent.files.Reader 15 | import com.github.torrentdam.bittorrent.files.Writer 16 | import com.github.torrentdam.bittorrent.wire.Connection 17 | import com.github.torrentdam.bittorrent.wire.Download 18 |
import com.github.torrentdam.bittorrent.wire.DownloadMetadata 19 | import com.github.torrentdam.bittorrent.wire.RequestDispatcher 20 | import com.github.torrentdam.bittorrent.wire.Swarm 21 | import com.github.torrentdam.bittorrent.wire.Torrent 22 | import com.github.torrentdam.bittorrent.CrossPlatform 23 | import com.github.torrentdam.bittorrent.InfoHash 24 | import com.github.torrentdam.bittorrent.PeerId 25 | import com.github.torrentdam.bittorrent.PeerInfo 26 | import com.github.torrentdam.bittorrent.TorrentFile 27 | import com.github.torrentdam.bittorrent.TorrentMetadata 28 | import com.monovore.decline.effect.CommandIOApp 29 | import com.monovore.decline.Opts 30 | import cps.syntax.* 31 | import fs2.io.file.Files 32 | import fs2.io.file.Flag 33 | import fs2.io.file.Flags 34 | import fs2.io.file.Path 35 | import fs2.io.file.WriteCursor 36 | import fs2.Chunk 37 | import fs2.Stream 38 | import java.util.concurrent.Executors 39 | import java.util.concurrent.ThreadFactory 40 | import org.legogroup.woof.* 41 | import org.legogroup.woof.given 42 | import scala.concurrent.duration.DurationInt 43 | import scodec.bits.ByteVector 44 | /** TorrentDam command-line entry point: the `torrent` and `dht` subcommand groups, parsed with decline. */ 45 | object Main 46 | extends CommandIOApp( 47 | name = "torrentdam", 48 | header = "TorrentDam" 49 | ) { 50 | /** Every top-level subcommand. */ 51 | def main: Opts[IO[ExitCode]] = 52 | torrentCommand <+> dhtCommand 53 | /** The `torrent` group: fetch-file, download and verify. */ 54 | def torrentCommand: Opts[IO[ExitCode]] = 55 | Opts.subcommand("torrent", "torrent client")( 56 | fetchFileCommand <+> downloadCommand <+> verifyCommand 57 | ) 58 | /** Discovers peers for --info-hash through a DHT node, downloads the torrent metadata from that swarm and writes it to --save as a torrent file. */ 59 | def fetchFileCommand = 60 | Opts.subcommand("fetch-file", "download torrent file") { 61 | ( 62 | Opts.option[String]("info-hash", "Info-hash"), 63 | Opts.option[String]("save", "Save as a torrent file") 64 | ) 65 | .mapN { (infoHashOption, targetFilePath) => 66 | withLogger { 67 | async[ResourceIO] { 68 | given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await 69 | 70 | val selfPeerId = Resource.eval(PeerId.generate[IO]).await 71 | val infoHash =
Resource.eval(infoHashFromString(infoHashOption)).await 72 | val node = Node().await 73 | 74 | val swarm = Swarm( 75 | node.discovery.discover(infoHash), 76 | Connection.connect(selfPeerId, _, infoHash) 77 | ).await 78 | val metadata = DownloadMetadata(swarm).toResource.await 79 | val torrentFile = TorrentFile(metadata, None) 80 | Files[IO] 81 | .writeAll(Path(targetFilePath), Flags.Write)( 82 | Stream.chunk(Chunk.byteVector(TorrentFile.toBytes(torrentFile))) 83 | ) 84 | .compile 85 | .drain 86 | .as(ExitCode.Success) 87 | }.useEval 88 | } 89 | } 90 | } 91 | /** Downloads torrent data into the current directory. Metadata comes from --torrent, otherwise from the swarm for --info-hash; peers come from --peer, otherwise from DHT discovery (with --dht-node passed to the DHT node). */ 92 | def downloadCommand = 93 | Opts.subcommand("download", "download torrent data") { 94 | val options = ( 95 | Opts.option[String]("info-hash", "Info-hash").orNone, 96 | Opts.option[String]("torrent", "Torrent file").orNone, 97 | Opts.option[String]("peer", "Peer address").orNone, 98 | Opts.option[String]("dht-node", "DHT node address").orNone 99 | ).tupled 100 | options.map { case (infoHashOption, torrentFileOption, peerAddressOption, dhtNodeAddressOption) => 101 | withLogger { 102 | async[ResourceIO] { 103 | val torrentFile: Option[TorrentFile] = torrentFileOption 104 | .traverse[IO, TorrentFile](torrentFileOption => 105 | async[IO] { 106 | val torrentFileBytes = Files[IO] 107 | .readAll(Path(torrentFileOption)) 108 | .compile 109 | .to(ByteVector) 110 | .await 111 | TorrentFile 112 | .fromBytes(torrentFileBytes) 113 | .liftTo[IO] 114 | .await 115 | } 116 | ) 117 | .toResource 118 | .await 119 | val infoHash: InfoHash = 120 | torrentFile match 121 | case Some(torrentFile) => 122 | torrentFile.infoHash 123 | case None => 124 | infoHashOption match 125 | case Some(infoHashOption) => 126 | infoHashFromString(infoHashOption).toResource.await 127 | case None => 128 | throw new Exception("Missing info-hash") 129 | 130 | given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await 131 | val selfPeerId = Resource.eval(PeerId.generate[IO]).await 132 | val peerAddress =
/* Reject an unparsable --peer instead of silently falling back to DHT discovery. */ peerAddressOption.traverse(address => SocketAddress.fromStringIp(address).liftTo[IO](new Exception(s"Invalid peer address: $address"))).toResource.await 133 | val peers: Stream[IO, PeerInfo] = 134 | peerAddress match 135 | case Some(peerAddress) => 136 | Stream.emit(PeerInfo(peerAddress)).covary[IO] 137 | case None => 138 | val bootstrapNodeAddress = /* Likewise, reject an unparsable --dht-node instead of ignoring it. */ dhtNodeAddressOption.traverse(address => SocketAddress.fromString(address).liftTo[IO](new Exception(s"Invalid DHT node address: $address"))).toResource.await 139 | val node = Node(none, bootstrapNodeAddress).await 140 | node.discovery.discover(infoHash) 141 | val swarm = Swarm(peers, peerInfo => Connection.connect(selfPeerId, peerInfo, infoHash)).await 142 | val metadata = 143 | torrentFile match 144 | case Some(torrentFile) => 145 | torrentFile.info 146 | case None => 147 | Resource.eval(DownloadMetadata(swarm)).await 148 | val torrent = Torrent.make(metadata, swarm).await 149 | val total = (metadata.parsed.pieces.length.toDouble / 20).ceil.toLong 150 | val counter = Resource.eval(IO.ref(0)).await 151 | val writer = Writer.fromTorrent(metadata.parsed) 152 | val createDirectories = metadata.parsed.files 153 | .filter(_.path.length > 1) 154 | .map(_.path.init) 155 | .distinct 156 | .traverse { path => 157 | val dir = path.foldLeft(Path("."))(_ / _) 158 | Files[IO].createDirectories(dir) 159 | } 160 | Resource.eval(createDirectories).await 161 | val openFiles: Map[TorrentMetadata.File, WriteCursor[IO]] = 162 | metadata.parsed.files 163 | .traverse { file => 164 | val path = file.path.foldLeft(Path("."))(_ / _) 165 | val flags = Flags(Flag.Create, Flag.Write) 166 | val cursor = Files[IO].writeCursor(path, flags) 167 | cursor.tupleLeft(file) 168 | } 169 | .await 170 | .toMap 171 | Stream 172 | .range(0L, total) 173 | .parEvalMap(10)(index => 174 | async[IO] { 175 | val piece = !torrent.downloadPiece(index) 176 | val count = !counter.updateAndGet(_ + 1) 177 | val percent = ((count.toDouble / total) * 100).toInt 178 | !Logger[IO].info(s"Downloaded piece $count/$total ($percent%)") 179 | Chunk.iterable(writer.write(index, piece)) 180 | } 181 | ) 182 | .unchunks 183 | .evalMap(write =>
openFiles(write.file).seek(write.offset).write(Chunk.byteVector(write.bytes))) 184 | .compile 185 | .drain 186 | .as(ExitCode.Success) 187 | }.useEval 188 | } 189 | } 190 | } 191 | /** Verifies the data under --target against the SHA-1 piece hashes in --torrent, stopping at the first mismatching piece. */ 192 | def verifyCommand = 193 | Opts.subcommand("verify", "verify torrent data") { 194 | val options: Opts[(String, String)] = 195 | ( 196 | Opts.option[String]("torrent", "Torrent file"), 197 | Opts.option[String]("target", "Torrent data directory") 198 | ).tupled 199 | options.map { (torrentFileName, targetDirName) => 200 | withLogger { 201 | async[IO] { 202 | try 203 | val bytes = Files[IO].readAll(Path(torrentFileName)).compile.to(Array).map(ByteVector(_)).await 204 | val torrentFile = IO.fromEither(TorrentFile.fromBytes(bytes)).await 205 | val infoHash = InfoHash(CrossPlatform.sha1(bencode.encode(torrentFile.info.raw).bytes)) 206 | Logger[IO].info(s"Info-hash: $infoHash").await 207 | 208 | val reader = Reader.fromTorrent(torrentFile.info.parsed) 209 | 210 | def readPiece(index: Long): IO[ByteVector] = 211 | val reads = Stream.emits(reader.read(index)) 212 | reads 213 | .covary[IO] 214 | .evalMap { read => 215 | val path = read.file.path.foldLeft(Path(targetDirName))(_ / _) 216 | Files[IO] 217 | .readRange(path, 1024 * 1024, read.offset, read.endOffset) 218 | .chunks 219 | .map(_.toByteVector) 220 | .compile 221 | .fold(ByteVector.empty)(_ ++ _) 222 | } 223 | .compile 224 | .fold(ByteVector.empty)(_ ++ _) 225 | 226 | val readByteCount = IO.ref(0L).await 227 | 228 | Stream 229 | .unfold(torrentFile.info.parsed.pieces)(bytes => 230 | if bytes.isEmpty then None 231 | else 232 | val (checksum, rest) = bytes.splitAt(20) 233 | Some((checksum, rest)) 234 | ) 235 | .zipWithIndex 236 | .evalMap { (checksum, index) => 237 | readPiece(index).map((checksum, index, _)) 238 | } 239 | .evalTap { (checksum, index, bytes) => 240 | if CrossPlatform.sha1(bytes) == checksum then readByteCount.update(_ + bytes.length) 241 | else /* The message travels with the error; the catch below logs it exactly once. */ IO.raiseError(new Exception(s"Piece $index failed")) 242
| } 243 | .compile 244 | .drain 245 | .await 246 | val totalBytes = readByteCount.get.await 247 | Logger[IO].info(s"Read $totalBytes bytes").await 248 | Logger[IO].info("All pieces verified").await 249 | ExitCode.Success 250 | catch 251 | /* NonFatal lets fatal JVM errors propagate; toString covers exceptions that carry no message. */ case scala.util.control.NonFatal(e) => 252 | Logger[IO].error(Option(e.getMessage).getOrElse(e.toString)).await 253 | ExitCode.Error 254 | } 255 | } 256 | } 257 | } 258 | /** The `dht` group: start and get-peers. */ 259 | def dhtCommand: Opts[IO[ExitCode]] = 260 | Opts.subcommand("dht", "DHT client")( 261 | startCommand <+> getPeers 262 | ) 263 | /** Runs a DHT node on --port until the program is cancelled. */ 264 | def startCommand = 265 | Opts.subcommand("start", "start DHT node") { 266 | Opts.option[Int]("port", "UDP port").map { portParam => 267 | withLogger { 268 | async[ResourceIO] { 269 | val port = Port.fromInt(portParam).liftTo[ResourceIO](new Exception("Invalid port")).await 270 | given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await 271 | Node(Some(port)).await 272 | }.useForever 273 | } 274 | } 275 | } 276 | /** Sends one get_peers query for --info-hash to the DHT node at --host and prints the response. */ 277 | def getPeers = 278 | Opts.subcommand("get-peers", "send single get_peers query") { 279 | ( 280 | Opts.option[String]("host", "DHT node address"), 281 | Opts.option[String]("info-hash", "Info-hash"), 282 | ).tupled.map { (nodeAddressParam, infoHashParam) => 283 | withLogger { 284 | async[ResourceIO] { 285 | val nodeAddress = SocketAddress.fromString(nodeAddressParam).liftTo[ResourceIO](new Exception("Invalid address")).await 286 | val nodeIpAddress = nodeAddress.resolve[IO].toResource.await 287 | given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await 288 | val selfId = Resource.eval(NodeId.random[IO]).await 289 | val infoHash = infoHashFromString(infoHashParam).toResource.await 290 | val messageSocket = MessageSocket(none).await 291 | val client = Client(selfId, messageSocket, QueryHandler.noop).await 292 | async[IO]: 293 | val response = client.getPeers(nodeIpAddress, infoHash).await 294 | IO.println(response).await 295 | ExitCode.Success 296 | }.useEval 297 | } 298 | } 299 | } 300 | /** Info-hash of a torrent file: SHA-1 of the bencoded `info` value. */ 301 | extension (torrentFile: TorrentFile) { 302 | def infoHash:
InfoHash = InfoHash(CrossPlatform.sha1(bencode.encode(torrentFile.info.raw).bytes)) 303 | } 304 | /** Parses a textual info-hash, failing with "Malformed info-hash" when it is not recognised. */ 305 | def infoHashFromString(value: String): IO[InfoHash] = 306 | InfoHash.fromString 307 | .unapply(value) 308 | .liftTo[IO](new Exception("Malformed info-hash")) 309 | /** Runs `body` with a colored console logger that emits Info level and above. */ 310 | def withLogger[A](body: Logger[IO] ?=> IO[A]): IO[A] = 311 | given Filter = Filter.atLeastLevel(LogLevel.Info) 312 | given Printer = ColorPrinter() 313 | DefaultLogger.makeIo(Output.fromConsole[IO]).flatMap(body(using _)) 314 | } 315 | --------------------------------------------------------------------------------