├── project
├── build.properties
└── plugins.sbt
├── cmd
├── Dockerfile-native
├── Dockerfile
├── resources
│ └── logback.xml
└── src
│ └── main
│ └── scala
│ └── Main.scala
├── .gitignore
├── common
└── src
│ ├── main
│ └── scala
│ │ └── com
│ │ └── github
│ │ └── torrentdam
│ │ └── bittorrent
│ │ ├── PeerInfo.scala
│ │ ├── InfoHash.scala
│ │ ├── PeerId.scala
│ │ └── MagnetLink.scala
│ └── test
│ └── scala
│ └── MagnetLinkSuite.scala
├── .scalafmt.conf
├── README.md
├── bittorrent
├── shared
│ └── src
│ │ ├── test
│ │ ├── resources
│ │ │ └── bencode
│ │ │ │ └── ubuntu-18.10-live-server-amd64.iso.torrent
│ │ └── scala
│ │ │ └── com
│ │ │ └── github
│ │ │ └── torrentdam
│ │ │ └── bittorrent
│ │ │ ├── protocol
│ │ │ └── message
│ │ │ │ └── HandshakeSpec.scala
│ │ │ ├── FileMappingSpec.scala
│ │ │ ├── wire
│ │ │ ├── RequestDispatcherSpec.scala
│ │ │ └── WorkQueueSuite.scala
│ │ │ └── TorrentMetadataSpec.scala
│ │ └── main
│ │ └── scala
│ │ └── com
│ │ └── github
│ │ └── torrentdam
│ │ └── bittorrent
│ │ ├── protocol
│ │ ├── extensions
│ │ │ ├── Extensions.scala
│ │ │ ├── ExtensionHandshake.scala
│ │ │ └── metadata
│ │ │ │ └── UtMessage.scala
│ │ └── message.scala
│ │ ├── wire
│ │ ├── DownloadMetadata.scala
│ │ ├── Torrent.scala
│ │ ├── Swarm.scala
│ │ ├── MessageSocket.scala
│ │ ├── ExtensionHandler.scala
│ │ ├── Download.scala
│ │ ├── RequestDispatcher.scala
│ │ └── Connection.scala
│ │ ├── FileMapping.scala
│ │ └── TorrentMetadata.scala
├── jvm
│ └── src
│ │ └── main
│ │ └── scala
│ │ └── com
│ │ └── github
│ │ └── torrentdam
│ │ └── bittorrent
│ │ └── CrossPlatform.scala
└── native
│ └── src
│ └── main
│ └── scala
│ └── com
│ └── github
│ └── torrentdam
│ └── bittorrent
│ └── CrossPlatform.scala
├── .scala-build
└── .scalafmt.conf
├── dht
└── src
│ ├── test
│ └── scala
│ │ └── com
│ │ └── github
│ │ └── torrentdam
│ │ └── bittorrent
│ │ └── dht
│ │ ├── NoOpLogger.scala
│ │ ├── MessageFormatSpec.scala
│ │ └── PeerDiscoverySpec.scala
│ └── main
│ └── scala
│ └── com
│ └── github
│ └── torrentdam
│ └── bittorrent
│ └── dht
│ ├── NodeInfo.scala
│ ├── QueryHandler.scala
│ ├── RoutingTableRefresh.scala
│ ├── MessageSocket.scala
│ ├── RoutingTableBootstrap.scala
│ ├── RequestResponse.scala
│ ├── Client.scala
│ ├── Node.scala
│ ├── RoutingTable.scala
│ ├── PeerDiscovery.scala
│ └── message.scala
├── files
└── src
│ ├── main
│ └── scala
│ │ └── com
│ │ └── github
│ │ └── torrentdam
│ │ └── bittorrent
│ │ └── files
│ │ ├── package.scala
│ │ ├── Reader.scala
│ │ └── Writer.scala
│ └── test
│ └── scala
│ └── com
│ └── github
│ └── torrentdam
│ └── bittorrent
│ └── files
│ └── WriterSpec.scala
├── LICENSE
├── .github
└── workflows
│ └── build.yml
└── mill
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.10.1
2 |
--------------------------------------------------------------------------------
/cmd/Dockerfile-native:
--------------------------------------------------------------------------------
1 | FROM ubuntu:22.04
2 |
3 | COPY ./.native/target/scala-3.3.0/cmd-out /opt/torrentdam
4 |
5 | ENTRYPOINT ["/opt/torrentdam"]
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | out
3 |
4 | .idea
5 | .idea_modules
6 |
7 | .vscode
8 |
9 | metals.sbt
10 | .metals
11 | .bloop
12 | .bsp
13 |
14 | logs
15 |
--------------------------------------------------------------------------------
/common/src/main/scala/com/github/torrentdam/bittorrent/PeerInfo.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent

import com.comcast.ip4s.*

/** Identifies a remote peer by its resolved socket address (IP + port). */
final case class PeerInfo(address: SocketAddress[IpAddress])
6 |
--------------------------------------------------------------------------------
/.scalafmt.conf:
--------------------------------------------------------------------------------
1 | version = 3.5.3
2 | runner.dialect = scala3
3 | maxColumn = 120
4 | continuationIndent.defnSite = 2
5 | rewrite.rules = [SortImports]
6 | rewrite.imports.expand = true
7 | verticalAlignMultilineOperators = true
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BitTorrent Client
2 |
3 | ## Development
4 |
5 | Run tests:
6 | ```sh
7 | $ sbt test
8 | ```
9 |
10 | ## Releasing
11 |
12 | Pushing a git tag triggers publishing the jars to Sonatype, where they must then be released manually.
--------------------------------------------------------------------------------
/bittorrent/shared/src/test/resources/bencode/ubuntu-18.10-live-server-amd64.iso.torrent:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TorrentDam/bittorrent/HEAD/bittorrent/shared/src/test/resources/bencode/ubuntu-18.10-live-server-amd64.iso.torrent
--------------------------------------------------------------------------------
/.scala-build/.scalafmt.conf:
--------------------------------------------------------------------------------
1 | version = "3.5.3"
2 | runner.dialect = scala3
3 | maxColumn = 120
4 | continuationIndent.defnSite = 2
5 | rewrite.rules = [SortImports]
6 | rewrite.imports.expand = true
7 | verticalAlignMultilineOperators = true
8 |
--------------------------------------------------------------------------------
/bittorrent/jvm/src/main/scala/com/github/torrentdam/bittorrent/CrossPlatform.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent

import scodec.bits.ByteVector

/** JVM implementation of platform-specific primitives. */
object CrossPlatform {
  /** SHA-1 digest of the given bytes, delegating to scodec-bits' built-in support. */
  def sha1(bytes: ByteVector): ByteVector = bytes.sha1
}
--------------------------------------------------------------------------------
/cmd/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM eclipse-temurin:24-jre-noble
2 |
3 | COPY ./out/cmd/assembly.dest/out.jar /opt/torrentdam/assembly.jar
4 |
5 | ENTRYPOINT [ "java", "-Dcats.effect.tracing.mode=none", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseCompactObjectHeaders", "-jar", "/opt/torrentdam/assembly.jar"]
6 |
--------------------------------------------------------------------------------
/bittorrent/native/src/main/scala/com/github/torrentdam/bittorrent/CrossPlatform.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent

import scodec.bits.ByteVector

/** Scala Native implementation of platform-specific primitives. */
object CrossPlatform {

  /** SHA-1 digest of the given bytes via java.security.
    *
    * A fresh MessageDigest is created per call, so this is safe to call
    * concurrently. `MessageDigest.digest` already resets the instance, so
    * the explicit `reset()` the previous version performed was redundant
    * and has been removed.
    */
  def sha1(bytes: ByteVector): ByteVector = {
    val md = java.security.MessageDigest.getInstance("SHA-1")
    ByteVector(md.digest(bytes.toArray))
  }
}
--------------------------------------------------------------------------------
/cmd/resources/logback.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | %-4relative [%thread] %-5level %logger{35} - %msg %n
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1")
2 | addSbtPlugin("com.github.sbt" % "sbt-native-packager" % "1.9.9")
3 |
4 | // Multi-platform support
5 | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.16.0")
6 | addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "1.2.0")
7 | addSbtPlugin("org.portable-scala" % "sbt-scala-native-crossproject" % "1.2.0")
8 | addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.4.17")
9 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/extensions/Extensions.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.protocol.extensions

import com.github.torrentdam.bittorrent.protocol.message.Message

/** BEP 10 extension-protocol constants and the handshake this client advertises. */
object Extensions {

  /** Extended-message identifiers used on the wire. */
  object MessageId {
    /** Id 0 is reserved for the extension handshake itself. */
    val Handshake = 0L
    /** Local id advertised for ut_metadata (BEP 9) messages. */
    val Metadata = 1L
  }

  /** Handshake advertising ut_metadata support; no metadata size is known yet. */
  def handshake: ExtensionHandshake =
    ExtensionHandshake(
      extensions = Map("ut_metadata" -> MessageId.Metadata),
      metadataSize = None
    )
}
20 |
--------------------------------------------------------------------------------
/common/src/main/scala/com/github/torrentdam/bittorrent/InfoHash.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent
import scodec.bits.ByteVector

/** 20-byte SHA-1 info-hash identifying a torrent. */
final case class InfoHash(bytes: ByteVector) {
  /** Lower-case hexadecimal rendering of the hash. */
  def toHex: String = bytes.toHex
  override def toString: String = toHex
}

object InfoHash {

  /** Extractor parsing a 40-character hex string (case-insensitive) into an InfoHash. */
  val fromString: PartialFunction[String, InfoHash] =
    Function.unlift { input =>
      ByteVector
        .fromHexDescriptive(input.toLowerCase)
        .toOption
        .filter(_.length == 20)
        .map(InfoHash(_))
    }
}
19 |
--------------------------------------------------------------------------------
/dht/src/test/scala/com/github/torrentdam/bittorrent/dht/NoOpLogger.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.dht

import cats.effect.IO
import cats.syntax.all.given
import org.legogroup.woof.LogInfo
import org.legogroup.woof.LogLevel
import org.legogroup.woof.Logger
import org.legogroup.woof.Logger.StringLocal

/** Logger implementation that discards every message; used to silence logging in tests. */
class NoOpLogger extends Logger[IO] {
  val stringLocal: StringLocal[IO] = NoOpLocal()
  // Drop the message regardless of level.
  def doLog(level: LogLevel, message: String)(using LogInfo): IO[Unit] = IO.unit
}

/** StringLocal that carries no context: `ask` yields an empty list and `local` is a no-op. */
class NoOpLocal extends Logger.StringLocal[IO] {
  def ask = List.empty[(String, String)].pure[IO]
  def local[A](fa: IO[A])(f: List[(String, String)] => List[(String, String)]) = fa
}
19 |
--------------------------------------------------------------------------------
/files/src/main/scala/com/github/torrentdam/bittorrent/files/package.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.files

import com.github.torrentdam.bittorrent.TorrentMetadata
import scala.util.chaining.scalaUtilChainingOps

/** Builds the cumulative [start, end) byte range of each file within the torrent's data. */
private def createFileOffsets(files: List[TorrentMetadata.File]) =
  Array
    .ofDim[FileRange](files.length)
    .tap: ranges =>
      var offset: Long = 0
      for (file, index) <- files.iterator.zipWithIndex do
        ranges(index) = FileRange(file, offset, offset + file.length)
        offset += file.length

extension (fileOffsets: Array[FileRange]) {
  /** Files overlapping the byte interval [start, end); relies on offsets being sorted. */
  private def matchFiles(start: Long, end: Long): Iterator[FileRange] =
    fileOffsets.iterator
      .dropWhile(_.endOffset <= start)
      .takeWhile(_.startOffset < end)
}

/** A file together with its absolute byte offsets within the whole torrent. */
private case class FileRange(file: TorrentMetadata.File, startOffset: Long, endOffset: Long)
25 |
--------------------------------------------------------------------------------
/common/src/main/scala/com/github/torrentdam/bittorrent/PeerId.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent

import cats.effect.std.Random
import cats.syntax.all.*
import cats.Monad
import scodec.bits.ByteVector

/** 20-byte peer identifier sent in the BitTorrent handshake.
  *
  * Ids produced here follow the Azureus-style convention: the fixed prefix
  * "-qB0000-" followed by twelve hex characters derived from six bytes.
  */
final case class PeerId(bytes: ByteVector) {
  override def toString() = s"PeerId(${bytes.decodeUtf8.getOrElse(bytes.toHex)})"
}

object PeerId {

  private val Prefix = "-qB0000-"

  /** Builds a PeerId from six explicit bytes (useful in tests).
    *
    * Previously duplicated the prefix+hex construction; now delegates to the
    * shared private constructor below.
    */
  def apply(b0: Byte, b1: Byte, b2: Byte, b3: Byte, b4: Byte, b5: Byte): PeerId =
    apply(Array[Byte](b0, b1, b2, b3, b4, b5))

  /** Shared constructor: prefix + hex of the given bytes, UTF-8 encoded.
    * encodeUtf8 cannot fail for this ASCII-only string, so .get is safe.
    */
  private def apply(bytes: Array[Byte]): PeerId =
    new PeerId(ByteVector.encodeUtf8(Prefix + ByteVector(bytes).toHex).toOption.get)

  /** Generates a PeerId from six random bytes. */
  def generate[F[_]](using Random[F], Monad[F]): F[PeerId] =
    for bytes <- Random[F].nextBytes(6)
    yield PeerId(bytes)
}
29 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/protocol/message/HandshakeSpec.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.protocol.message

import com.github.torrentdam.bittorrent.InfoHash
import com.github.torrentdam.bittorrent.PeerId
import scala.util.chaining.*
import scodec.bits.ByteVector

class HandshakeSpec extends munit.FunSuite {

  test("read and write protocol extension bit") {
    // Handshake with the extension-protocol flag enabled (first constructor arg).
    val message =
      Handshake(
        true,
        InfoHash(ByteVector.fill(20)(0)),
        PeerId(0, 0, 0, 0, 0, 0)
      )
    assert(
      PartialFunction.cond(Handshake.HandshakeCodec.encode(message).toOption) {
        case Some(bits) =>
          bits
            // Skip the first 20 bytes of the encoded handshake preamble.
            .splitAt(20 * 8)
            .pipe { case (_, bits) =>
              // The following 64 bits are the reserved field.
              bits.splitAt(64)
            }
            .pipe { case (reserved, bits) =>
              // BEP 10 extension support is signalled by reserved bit 43
              // (0x10 in byte 5); the neighbouring bits must stay clear.
              assert(reserved.get(42) == false)
              assert(reserved.get(43) == true)
              assert(reserved.get(44) == false)
            }
          true
        case _ => false
      }
    )
  }
}
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <https://unlicense.org>
--------------------------------------------------------------------------------
/common/src/test/scala/MagnetLinkSuite.scala:
--------------------------------------------------------------------------------
import com.github.torrentdam.bittorrent.*

/** End-to-end check of MagnetLink.fromString on a real-world magnet URI. */
class MagnetLinkSuite extends munit.FunSuite {

  test("parses magnet link") {
    // Real magnet link: percent-encoded Cyrillic display name, one tracker, hex info-hash.
    val link =
      """magnet:?xt=urn:btih:C071AA6D06101FE3C1D8D3411343CFEB33D91E5F&tr=http%3A%2F%2Fbt.t-ru.org%2Fann%3Fmagnet&dn=%D0%9C%D0%B0%D1%82%D1%80%D0%B8%D1%86%D0%B0%3A%20%D0%92%D0%BE%D1%81%D0%BA%D1%80%D0%B5%D1%88%D0%B5%D0%BD%D0%B8%D0%B5%20%2F%20The%20Matrix%20Resurrections%20(%D0%9B%D0%B0%D0%BD%D0%B0%20%D0%92%D0%B0%D1%87%D0%BE%D0%B2%D1%81%D0%BA%D0%B8%20%2F%20Lana%20Wachowski)%20%5B2021%2C%20%D0%A1%D0%A8%D0%90%2C%20%D0%A4%D0%B0%D0%BD%D1%82%D0%B0%D1%81%D1%82%D0%B8%D0%BA%D0%B0%2C%20%D0%B1%D0%BE%D0%B5%D0%B2%D0%B8%D0%BA%2C%20WEB-DLRip%5D%20MVO%20(Jaskier)%20%2B%20Sub%20Rus%2C%20E"""
    assertEquals(
      MagnetLink.fromString(link),
      Some(
        MagnetLink(
          // InfoHash.fromString is a PartialFunction; applying it directly is
          // safe here because the literal is a valid 40-char hex string.
          infoHash = InfoHash.fromString("C071AA6D06101FE3C1D8D3411343CFEB33D91E5F"),
          displayName = Some(
            "Матрица: Воскрешение / The Matrix Resurrections (Лана Вачовски / Lana Wachowski) [2021, США, Фантастика, боевик, WEB-DLRip] MVO (Jaskier) + Sub Rus, E"
          ),
          trackers = List("http://bt.t-ru.org/ann?magnet")
        )
      )
    )
  }
}
22 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/DownloadMetadata.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.wire

import cats.effect.implicits.*
import cats.effect.IO
import cats.implicits.*
import com.github.torrentdam.bittorrent.TorrentMetadata.Lossless
import fs2.Stream
import org.legogroup.woof.given
import org.legogroup.woof.Logger
import scala.concurrent.duration.*

/** Downloads torrent metadata (the "info" dictionary) from swarm peers via ut_metadata. */
object DownloadMetadata {

  /** Attempts metadata downloads over up to 10 concurrent peer connections,
    * each attempt capped at one minute, and returns the first success.
    */
  def apply(swarm: Swarm)(using logger: Logger[IO]): IO[Lossless] =
    logger.info("Downloading metadata") >>
    Stream.unit.repeat
      .parEvalMapUnordered(10)(_ =>
        swarm.connect
          .use(connection => DownloadMetadata(connection).timeout(1.minute))
          // Keep failures as Left so one bad peer doesn't abort the stream.
          .attempt
      )
      .collectFirst { case Right(metadata) =>
        metadata
      }
      .compile
      .lastOrError
      .flatTap(_ => logger.info("Metadata downloaded"))

  /** Fetches metadata from a single connection; fails with
    * UtMetadataNotSupported if the peer did not advertise ut_metadata.
    */
  def apply(connection: Connection)(using logger: Logger[IO]): IO[Lossless] =
    connection.extensionApi
      .flatMap(_.utMetadata.liftTo[IO](UtMetadataNotSupported()))
      .flatMap(_.fetch)

  case class UtMetadataNotSupported() extends Throwable("UtMetadata is not supported")

}
37 |
--------------------------------------------------------------------------------
/files/src/main/scala/com/github/torrentdam/bittorrent/files/Reader.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.files

import cats.effect.IO
import com.github.torrentdam.bittorrent.TorrentMetadata
import scodec.bits.ByteVector

/** Translates a piece index into the byte ranges that must be read from each file. */
trait Reader {
  def read(pieceIndex: Long): List[Reader.ReadBytes]
}

object Reader {

  def fromTorrent(torrent: TorrentMetadata): Reader = apply(torrent.files, torrent.pieceLength)

  /** A read instruction: take bytes [offset, endOffset) from the given file. */
  case class ReadBytes(file: TorrentMetadata.File, offset: Long, endOffset: Long)

  private[files] def apply(files: List[TorrentMetadata.File], pieceLength: Long): Reader =
    val torrentLength = files.map(_.length).sum
    val offsets = createFileOffsets(files)
    new {
      def read(pieceIndex: Long): List[ReadBytes] =
        // Absolute byte interval covered by this piece; the last piece may be short.
        val start = pieceIndex * pieceLength
        val end = math.min(start + pieceLength, torrentLength)
        offsets
          .matchFiles(start, end)
          .map { range =>
            // Clamp the piece interval to the file and make the offsets file-relative.
            ReadBytes(
              range.file,
              math.max(range.startOffset, start) - range.startOffset,
              math.min(range.endOffset, end) - range.startOffset
            )
          }
          .toList
    }
}
34 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/extensions/ExtensionHandshake.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.protocol.extensions

import cats.syntax.all.*
import com.github.torrentdam.bencode
import com.github.torrentdam.bencode.format.*
import scodec.bits.ByteVector

/** BEP 10 extension handshake payload.
  *
  * @param extensions   map from extension name (e.g. "ut_metadata") to the local message id
  * @param metadataSize size of the metadata blob, if known (BEP 9)
  */
case class ExtensionHandshake(
  extensions: Map[String, Long],
  metadataSize: Option[Long]
)

object ExtensionHandshake {

  // Bencode dictionary layout: required "m" map plus optional "metadata_size".
  private val format =
    (
      field[Map[String, Long]]("m"),
      fieldOptional[Long]("metadata_size")
    ).imapN(ExtensionHandshake.apply)(v => (v.extensions, v.metadataSize))

  /** Bencodes a handshake.
    * NOTE(review): the .get assumes `format.write` cannot fail for any
    * ExtensionHandshake value — confirm against the bencode format library.
    */
  def encode(handshake: ExtensionHandshake): ByteVector =
    bencode
      .encode(format.write(handshake).toOption.get)
      .toByteVector

  /** Decodes a handshake, wrapping failures in the Error hierarchy below. */
  def decode(bytes: ByteVector): Either[Throwable, ExtensionHandshake] =
    for
      bc <-
        bencode
          .decode(bytes.bits)
          .leftMap(Error.BencodeError.apply)
      handshakeResponse <-
        ExtensionHandshake.format
          .read(bc)
          .leftMap(Error.HandshakeFormatError("Unable to parse handshake response", _))
    yield handshakeResponse

  object Error {
    case class BencodeError(cause: Throwable) extends Error(cause)
    case class HandshakeFormatError(message: String, cause: Throwable) extends Error(message, cause)
  }
}
43 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/FileMappingSpec.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent

import scodec.bits.ByteVector
import FileMapping.FileSpan
import TorrentMetadata.File

/** Checks how FileMapping assigns files to piece-relative spans.
  *
  * NOTE(review): FileSpan arguments appear to be
  * (fileIndex, length, beginPieceIndex, beginPieceOffset, endPieceIndex, endPieceOffset)
  * — confirm against the FileSpan definition.
  */
class FileMappingSpec extends munit.FunSuite {

  test("all files in one piece") {
    // Three files (5 + 3 + 2 bytes) fill exactly one 10-byte piece.
    val metadata = TorrentMetadata(
      name = "test",
      pieceLength = 10,
      pieces = ByteVector.empty,
      files = List(
        File(5, Nil),
        File(3, Nil),
        File(2, Nil)
      )
    )
    val result = FileMapping.fromMetadata(metadata)
    val expectation = FileMapping(
      List(
        FileSpan(0, 5, 0, 0, 0, 5),
        FileSpan(1, 3, 0, 5, 0, 8),
        // The last file ends exactly at the piece boundary: (piece 1, offset 0).
        FileSpan(2, 2, 0, 8, 1, 0)
      ),
      pieceLength = 10
    )
    assert(result == expectation)
  }

  test("file spans multiple pieces") {
    // The 13-byte middle file crosses from piece 0 into piece 1.
    val metadata = TorrentMetadata(
      name = "test",
      pieceLength = 10,
      pieces = ByteVector.empty,
      files = List(
        File(5, Nil),
        File(13, Nil),
        File(2, Nil)
      )
    )
    val result = FileMapping.fromMetadata(metadata)
    val expectation = FileMapping(
      List(
        FileSpan(0, 5, 0, 0, 0, 5),
        FileSpan(1, 13, 0, 5, 1, 8),
        FileSpan(2, 2, 1, 8, 2, 0)
      ),
      pieceLength = 10
    )
    assert(result == expectation)
  }
}
55 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/wire/RequestDispatcherSpec.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.wire

import com.github.torrentdam.bittorrent.wire.RequestDispatcher
import com.github.torrentdam.bittorrent.TorrentFile
import com.github.torrentdam.bittorrent.TorrentMetadata
import com.github.torrentdam.bencode
import com.github.torrentdam.bencode.format.BencodeFormat
import scodec.bits.BitVector

class RequestDispatcherSpec extends munit.FunSuite {

  test("build request queue from torrent metadata") {
    // Real single-file torrent used as a fixture (from test resources).
    val source = getClass.getClassLoader
      .getResourceAsStream("bencode/ubuntu-18.10-live-server-amd64.iso.torrent")
      .readAllBytes()
    val Right(result) = bencode.decode(BitVector(source)): @unchecked
    val torrentFile = summon[BencodeFormat[TorrentFile]].read(result).toOption.get
    assert(
      PartialFunction.cond(torrentFile.info) {
        case TorrentMetadata.Lossless(metadata @ TorrentMetadata(_, _, pieces, List(file)), _) =>
          val fileSize = file.length
          // `pieces` holds one 20-byte SHA-1 hash per piece.
          val piecesTotal = (pieces.length.toDouble / 20).ceil.toInt
          val workGenerator = RequestDispatcher.WorkGenerator(metadata)
          val queue =
            (0 until piecesTotal).map(pieceIndex => workGenerator.pieceWork(pieceIndex))
          // Piece work sizes and individual request lengths must both cover the file exactly.
          assert(queue.map(_.size).toList.sum == fileSize)
          assert(queue.toList.flatMap(_.requests.toList).map(_.length).sum == fileSize)
          true
      }
    )
  }

}
34 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/extensions/metadata/UtMessage.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.protocol.extensions.metadata

import cats.syntax.all.*
import com.github.torrentdam.bencode
import com.github.torrentdam.bencode.format.*
import scodec.bits.ByteVector

/** ut_metadata (BEP 9) extension messages. */
enum UtMessage:
  /** Ask the peer for metadata piece `piece`. */
  case Request(piece: Long)
  /** A metadata piece together with its raw payload bytes. */
  case Data(piece: Long, byteVector: ByteVector)
  /** The peer refuses to serve metadata piece `piece`. */
  case Reject(piece: Long)

object UtMessage {

  /** Bencoded header common to every message: msg_type plus piece index. */
  val MessageFormat: BencodeFormat[(Long, Long)] =
    (
      field[Long]("msg_type"),
      field[Long]("piece")
    ).tupled

  /** Encodes a message as its bencoded header; Data's payload is appended raw. */
  def encode(message: UtMessage): ByteVector = {
    val (bc, extraBytes) =
      message match {
        case Request(piece) => (MessageFormat.write((0, piece)).toOption.get, none)
        case Data(piece, bytes) => (MessageFormat.write((1, piece)).toOption.get, bytes.some)
        case Reject(piece) => (MessageFormat.write((2, piece)).toOption.get, none)
      }
    bencode.encode(bc).toByteVector ++ extraBytes.getOrElse(ByteVector.empty)
  }

  /** Decodes a message; any bytes after the bencoded header are Data's payload.
    *
    * Fix: an unknown msg_type now yields a Left instead of escaping as a
    * MatchError — the previous match on 0/1/2 was non-exhaustive even though
    * the signature promises errors in the Either.
    */
  def decode(bytes: ByteVector): Either[Throwable, UtMessage] = {
    bencode
      .decodeHead(bytes.toBitVector)
      .flatMap { case (remainder, result) =>
        MessageFormat.read(result).flatMap { case (msgType, piece) =>
          msgType match {
            case 0 => Right(Request(piece))
            case 1 => Right(Data(piece, remainder.toByteVector))
            case 2 => Right(Reject(piece))
            case other => Left(new IllegalArgumentException(s"Unknown ut_metadata msg_type: $other"))
          }
        }
      }
  }
}
45 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/NodeInfo.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.std.Random
4 | import cats.syntax.all.*
5 | import cats.*
6 | import com.comcast.ip4s.*
7 | import com.github.torrentdam.bittorrent.InfoHash
8 | import scodec.bits.ByteVector
9 |
10 | final case class NodeInfo(id: NodeId, address: SocketAddress[IpAddress])
11 |
12 | final case class NodeId(bytes: ByteVector) {
13 | val int: BigInt = BigInt(1, bytes.toArray)
14 | }
15 |
16 | object NodeId {
17 |
18 | private def distance(a: ByteVector, b: ByteVector): BigInt = BigInt(1, (a.xor(b)).toArray)
19 |
20 | def distance(a: NodeId, b: NodeId): BigInt = distance(a.bytes, b.bytes)
21 |
22 | def distance(a: NodeId, b: InfoHash): BigInt = distance(a.bytes, b.bytes)
23 |
24 | def random[F[_]](using Random[F], Monad[F]): F[NodeId] = {
25 | for bytes <- Random[F].nextBytes(20)
26 | yield NodeId(ByteVector.view(bytes))
27 | }
28 |
29 | def fromInt(int: BigInt): NodeId = NodeId(ByteVector.view(int.toByteArray).padTo(20))
30 |
31 | def randomInRange[F[_]](from: BigInt, until: BigInt)(using Random[F], Monad[F]): F[NodeId] =
32 | val difference = BigDecimal(until - from)
33 | for
34 | randomDouble <- Random[F].nextDouble
35 | integer = from + (difference * randomDouble).toBigInt
36 | bigIntBytes = ByteVector(integer.toByteArray)
37 | vector = if bigIntBytes(0) == 0 then bigIntBytes.tail else bigIntBytes
38 | yield NodeId(vector.padLeft(20))
39 |
40 | given Show[NodeId] = nodeId => s"NodeId(${nodeId.bytes.toHex})"
41 |
42 | val MaxValue: BigInt = BigInt(1, Array.fill(20)(0xff.toByte))
43 | }
44 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/QueryHandler.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.dht

import cats.implicits.*
import cats.Monad
import com.comcast.ip4s.*
import com.github.torrentdam.bittorrent.PeerInfo

/** Handles an incoming DHT query from a given address, producing an optional response. */
trait QueryHandler[F[_]] {
  def apply(address: SocketAddress[IpAddress], query: Query): F[Option[Response]]
}

object QueryHandler {

  /** Never responds to any query. */
  def noop[F[_]: Monad]: QueryHandler[F] = (_, _) => none.pure[F]

  /** Standard KRPC behaviour backed by a routing table. */
  def simple[F[_]: Monad](selfId: NodeId, routingTable: RoutingTable[F]): QueryHandler[F] = { (address, query) =>
    query match {
      case Query.Ping(_) =>
        Response.Ping(selfId).some.pure[F]
      case Query.FindNode(_, target) =>
        // Return up to 8 good nodes closest to the target id.
        routingTable.goodNodes(target).map { nodes =>
          Response.Nodes(selfId, nodes.take(8).toList).some
        }
      case Query.GetPeers(_, infoHash) =>
        // Known peers for the infohash if any; otherwise the closest nodes.
        routingTable.findPeers(infoHash).flatMap {
          case Some(peers) =>
            Response.Peers(selfId, peers.toList).some.pure[F]
          case None =>
            routingTable
              .goodNodes(NodeId(infoHash.bytes))
              .map { nodes =>
                Response.Nodes(selfId, nodes.take(8).toList).some
              }
        }
      case Query.AnnouncePeer(_, infoHash, port) =>
        // Record the announcing peer and acknowledge with a ping response.
        // NOTE(review): Port.fromInt(...).get throws on an out-of-range port —
        // confirm the port is validated upstream before reaching here.
        routingTable
          .addPeer(infoHash, PeerInfo(SocketAddress(address.host, Port.fromInt(port.toInt).get)))
          .as(
            Response.Ping(selfId).some
          )
      case Query.SampleInfoHashes(_, _) =>
        // BEP 51 sampling is not implemented; reply with an empty sample.
        Response.SampleInfoHashes(selfId, None, List.empty).some.pure[F]
    }
  }
}
46 |
--------------------------------------------------------------------------------
/files/src/main/scala/com/github/torrentdam/bittorrent/files/Writer.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent.files

import cats.effect.IO
import com.github.torrentdam.bittorrent.TorrentMetadata
import scodec.bits.ByteVector

import util.chaining.scalaUtilChainingOps

/** Translates a downloaded piece into per-file write instructions. */
trait Writer {
  def write(index: Long, bytes: ByteVector): List[Writer.WriteBytes]
}

object Writer {

  def fromTorrent(torrent: TorrentMetadata): Writer = apply(torrent.files, torrent.pieceLength)

  /** A write instruction: put `bytes` into `file` starting at `offset`. */
  case class WriteBytes(file: TorrentMetadata.File, offset: Long, bytes: ByteVector)

  private[files] def apply(files: List[TorrentMetadata.File], pieceLength: Long): Writer =
    val fileOffsets = createFileOffsets(files)

    // Splits `byteVector` (starting at absolute torrent offset `pieceOffset`)
    // across the files it spans. The vars are advanced as the iterator is
    // consumed, so the returned iterator must be traversed in order exactly
    // once — `write` below does so via .toList.
    def distribute(pieceOffset: Long, byteVector: ByteVector, files: Iterator[FileRange]): Iterator[WriteBytes] =
      var remainingBytes = byteVector
      var bytesOffset = pieceOffset
      files.map: fileRange =>
        // Position within the current file and how many bytes land in it.
        val offsetInFile = bytesOffset - fileRange.startOffset
        val lengthToWrite = math.min(remainingBytes.length, fileRange.file.length - offsetInFile)
        val (bytesToWrite, rem) = remainingBytes.splitAt(lengthToWrite)
        remainingBytes = rem
        bytesOffset += lengthToWrite
        WriteBytes(
          fileRange.file,
          offsetInFile,
          bytesToWrite
        )

    new {
      def write(pieceIndex: Long, bytes: ByteVector): List[WriteBytes] =
        val pieceOffset = pieceIndex * pieceLength
        val files = fileOffsets.matchFiles(pieceOffset, pieceOffset + bytes.length)
        val writes = distribute(pieceOffset, bytes, files)
        writes.toList
    }
}
42 |
--------------------------------------------------------------------------------
/common/src/main/scala/com/github/torrentdam/bittorrent/MagnetLink.scala:
--------------------------------------------------------------------------------
package com.github.torrentdam.bittorrent

import com.github.torrentdam.bittorrent.InfoHash
import java.net.URLDecoder
import java.nio.charset.Charset

/** Parsed magnet URI: the mandatory info-hash plus optional display name and trackers. */
case class MagnetLink(infoHash: InfoHash, displayName: Option[String], trackers: List[String])

object MagnetLink {

  /** Parses a "magnet:?..." URI; None when the scheme or info-hash is missing or invalid. */
  def fromString(source: String): Option[MagnetLink] =
    source match
      case s"magnet:?$query" => fromQueryString(query)
      case _ => None

  private def fromQueryString(str: String) =
    val params = parseQueryString(str)
    getInfoHash(params).map { infoHash =>
      MagnetLink(infoHash, getDisplayName(params), getTrackers(params))
    }

  private type Query = Map[String, List[String]]

  /** "xt" must be exactly one urn:btih:<40-hex-chars> value. */
  private def getInfoHash(query: Query): Option[InfoHash] =
    query.get("xt").flatMap {
      case List(s"urn:btih:${InfoHash.fromString(ih)}") => Some(ih)
      case _ => None
    }

  private def getDisplayName(query: Query): Option[String] =
    query.get("dn").flatMap(_.headOption)

  private def getTrackers(query: Query): List[String] =
    query.get("tr").toList.flatten

  /** Splits on '&', percent-decodes each pair, then separates key from value
    * on the first '='; repeated keys are collected into a list.
    */
  private def parseQueryString(str: String): Query =
    str
      .split('&')
      .toList
      .map(urlDecode)
      .map { pair =>
        val parts = pair.split('=')
        parts.head -> parts.tail.mkString("=")
      }
      .groupMap(_._1)(_._2)

  private def urlDecode(value: String) =
    URLDecoder.decode(value, Charset.forName("UTF-8"))
}
52 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Torrent.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.effect.implicits.*
4 | import cats.effect.kernel.syntax.all.*
5 | import cats.effect.kernel.Async
6 | import cats.effect.kernel.Resource
7 | import cats.effect.IO
8 | import cats.effect.Resource
9 | import cats.implicits.*
10 | import com.github.torrentdam.bittorrent.TorrentMetadata
11 | import com.github.torrentdam.bittorrent.TorrentMetadata.Lossless
12 | import org.legogroup.woof.given
13 | import org.legogroup.woof.Logger
14 | import scala.collection.immutable.BitSet
15 | import scodec.bits.ByteVector
16 |
/** A torrent being downloaded from a swarm of peers. */
trait Torrent {
  /** The metadata (with its original bencoded form) this torrent was built from. */
  def metadata: TorrentMetadata.Lossless
  /** Snapshot of current swarm statistics. */
  def stats: IO[Torrent.Stats]
  /** Downloads the piece with the given index and yields its bytes. */
  def downloadPiece(index: Long): IO[ByteVector]
}
22 |
object Torrent {

  /** Wires a [[Torrent]] on top of a [[Swarm]]: a request dispatcher is
    * created for the metadata and the download loop runs in the background
    * for the lifetime of the resource.
    */
  def make(
      metadata: TorrentMetadata.Lossless,
      swarm: Swarm
  )(using logger: Logger[IO]): Resource[IO, Torrent] =
    for
      requestDispatcher <- RequestDispatcher(metadata.parsed)
      _ <- Download(swarm, requestDispatcher).background
    yield
      // Rebind under a different name: inside the anonymous class `metadata`
      // refers to the member being defined, not to the method parameter.
      val metadata0 = metadata
      new Torrent {
        def metadata: TorrentMetadata.Lossless = metadata0
        def stats: IO[Stats] =
          for
            connected <- swarm.connected.list
            availability <- connected.traverse(_.availability.get)
            // Pure fold — a plain `=` binding instead of lifting the value
            // into IO with `.pure[IO]` just to bind it.
            combined = availability.foldMap(identity)
          yield Stats(connected.size, combined)
        def downloadPiece(index: Long): IO[ByteVector] =
          requestDispatcher.downloadPiece(index)
      }
    end for

  /** Swarm statistics: number of open connections and the union of piece
    * availability across them.
    */
  case class Stats(
      connected: Int,
      availability: BitSet
  )

  enum Error extends Exception:
    case EmptyMetadata()
}
55 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableRefresh.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.IO
4 | import cats.syntax.all.*
5 | import cats.effect.cps.{*, given}
6 | import cats.effect.std.Random
7 | import org.legogroup.woof.{Logger, given}
8 |
9 | import scala.concurrent.duration.{DurationInt, FiniteDuration}
10 |
/** Keeps the DHT routing table fresh: stale buckets are re-populated through
  * node discovery and nodes in fresh buckets are pinged so their goodness
  * flags stay current.
  */
class RoutingTableRefresh(table: RoutingTable[IO], client: Client, discovery: PeerDiscovery)(using logger: Logger[IO], random: Random[IO]):

  /** Performs a single refresh pass over all buckets. */
  def runOnce: IO[Unit] = async[IO]:
    val buckets = table.buckets.await
    // A bucket is "fresh" when it still contains at least one good node.
    val (fresh, stale) = buckets.toList.partition(_.nodes.values.exists(_.isGood))
    if stale.nonEmpty then
      refreshBuckets(stale).await
    val nodes = fresh.flatMap(_.nodes.values)
    pingNodes(nodes).await

  /** Runs [[runOnce]] forever with `period` between passes. The inner
    * `foreverM` loops until an error escapes; the error is then logged and
    * the outer `foreverM` restarts the whole loop, so a single failure never
    * terminates the routine.
    */
  def runEvery(period: FiniteDuration): IO[Unit] =
    IO
      .sleep(period)
      .productR(runOnce)
      .foreverM
      .handleErrorWith: e =>
        // NOTE(review): message says "PingRoutine" but this class is
        // RoutingTableRefresh — consider aligning the log text.
        logger.error(s"PingRoutine failed: $e")
      .foreverM

  // Pings every node concurrently (5 s timeout each) and reports which ids
  // responded so the table can update their goodness.
  private def pingNodes(nodes: List[RoutingTable.Node]) = async[IO]:
    logger.info(s"Pinging ${nodes.size} nodes").await
    val results = nodes
      .parTraverse { node =>
        // bimap to the node id on both sides: Left = unreachable, Right = alive.
        client.ping(node.address).timeout(5.seconds).attempt.map(_.bimap(_ => node.id, _ => node.id))
      }
      .await
    val (bad, good) = results.partitionMap(identity)
    logger.info(s"Got ${good.size} good nodes and ${bad.size} bad nodes").await
    table.updateGoodness(good.toSet, bad.toSet).await

  // For each stale bucket, searches for a random id inside the bucket's
  // range; presumably the discovery run inserts responding nodes into the
  // table as a side effect — verify against PeerDiscovery.
  private def refreshBuckets(buckets: List[RoutingTable.TreeNode.Bucket]) = async[IO]:
    logger.info(s"Found ${buckets.size} stale buckets").await
    buckets
      .parTraverse: bucket =>
        val randomId = NodeId.randomInRange(bucket.from, bucket.until).await
        discovery.findNodes(randomId).take(32).compile.drain
      .await
48 |
49 |
50 |
--------------------------------------------------------------------------------
/dht/src/test/scala/com/github/torrentdam/bittorrent/dht/MessageFormatSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import com.github.torrentdam.bittorrent.InfoHash
4 | import com.github.torrentdam.bencode.format.BencodeFormat
5 | import com.github.torrentdam.bencode.Bencode
6 | import scodec.bits.ByteVector
7 |
/** Round-trip checks for the DHT message bencode format. */
class MessageFormatSpec extends munit.FunSuite {

  // The dictionary has y = "q" and q = "ping", i.e. it is a ping *query*,
  // and the expectation is a QueryMessage — the test was previously
  // mislabelled "decode ping response".
  test("decode ping query") {
    val input = Bencode.BDictionary(
      // Extra non-spec key ("ip") must be ignored by the reader.
      "ip" -> Bencode.BString(ByteVector.fromValidHex("1f14bdfa9f21")),
      "y" -> Bencode.BString("q"),
      "t" -> Bencode.BString(ByteVector.fromValidHex("6a76679c")),
      "a" -> Bencode.BDictionary(
        "id" -> Bencode.BString(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267"))
      ),
      "q" -> Bencode.BString("ping")
    )

    val result = summon[BencodeFormat[Message]].read(input)
    val expectation = Right(
      Message.QueryMessage(
        ByteVector.fromValidHex("6a76679c"),
        Query.Ping(NodeId(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")))
      )
    )

    assert(result == expectation)
  }

  test("decode announce_peer query") {
    val input = Bencode.BDictionary(
      "t" -> Bencode.BString(ByteVector.fromValidHex("6a76679c")),
      "y" -> Bencode.BString("q"),
      "q" -> Bencode.BString("announce_peer"),
      "a" -> Bencode.BDictionary(
        "id" -> Bencode.BString(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")),
        "info_hash" -> Bencode.BString(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")),
        "port" -> Bencode.BInteger(9999)
      )
    )
    val result = summon[BencodeFormat[Message]].read(input)
    val expectation = Right(
      Message.QueryMessage(
        ByteVector.fromValidHex("6a76679c"),
        Query.AnnouncePeer(
          NodeId(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")),
          InfoHash(ByteVector.fromValidHex("32f54e697351ff4aec29cdbaabf2fbe3467cc267")),
          9999L
        )
      )
    )
    assert(result == expectation)
  }
}
57 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/MessageSocket.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.*
4 | import cats.effect.Async
5 | import cats.effect.Concurrent
6 | import cats.effect.IO
7 | import cats.effect.Resource
8 | import cats.syntax.all.*
9 | import com.comcast.ip4s.*
10 | import com.comcast.ip4s.ip
11 | import com.github.torrentdam.bencode.decode
12 | import com.github.torrentdam.bencode.encode
13 | import com.github.torrentdam.bencode.format.BencodeFormat
14 | import fs2.io.net.Datagram
15 | import fs2.io.net.DatagramSocket
16 | import fs2.io.net.DatagramSocketGroup
17 | import fs2.io.net.Network
18 | import fs2.Chunk
19 | import java.net.InetSocketAddress
20 | import org.legogroup.woof.given
21 | import org.legogroup.woof.Logger
22 |
/** Sends and receives bencoded DHT messages over a UDP socket. */
class MessageSocket(socket: DatagramSocket[IO], logger: Logger[IO]) {
  import MessageSocket.Error

  /** Reads a single datagram and decodes it into a [[Message]].
    *
    * Fails with [[Error.BecodeSerialization]] when the payload is not valid
    * bencode, and with [[Error.MessageFormat]] when the bencode does not
    * describe a known message.
    */
  def readMessage: IO[(SocketAddress[IpAddress], Message)] =
    for {
      datagram <- socket.read
      bc <- IO.fromEither(
        decode(datagram.bytes.toBitVector).leftMap(Error.BecodeSerialization.apply)
      )
      message <- IO.fromEither(
        summon[BencodeFormat[Message]]
          .read(bc)
          // Fixed typo in the error text: "Filed" -> "Failed".
          .leftMap(e => Error.MessageFormat(s"Failed to read message from bencode: $bc", e))
      )
      _ <- logger.trace(s"<<< ${datagram.remote} $message")
    } yield (datagram.remote, message)

  /** Encodes `message` and sends it to `address` as a single datagram. */
  def writeMessage(address: SocketAddress[IpAddress], message: Message): IO[Unit] = IO.defer {
    // Writing a Message is assumed to always succeed; a write failure
    // surfaces as an exception raised inside IO (via IO.defer).
    val bc = summon[BencodeFormat[Message]].write(message).toOption.get
    val bytes = encode(bc)
    val packet = Datagram(address, Chunk.byteVector(bytes.bytes))
    socket.write(packet) >> logger.trace(s">>> $address $message")
  }
}
47 |
object MessageSocket {

  /** Opens a UDP socket bound to all interfaces on the given (or an
    * ephemeral) port and wraps it in a [[MessageSocket]].
    */
  def apply(
    port: Option[Port]
  )(using logger: Logger[IO]): Resource[IO, MessageSocket] =
    for socket <- Network[IO].openDatagramSocket(Some(ip"0.0.0.0"), port)
    yield new MessageSocket(socket, logger)

  /** Failures raised while reading incoming datagrams. */
  object Error {
    case class BecodeSerialization(cause: Throwable) extends Throwable(cause)
    case class MessageFormat(message: String, cause: Throwable) extends Throwable(message, cause)
  }
}
62 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/FileMapping.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent
2 |
3 | import com.github.torrentdam.bittorrent.FileMapping.FileSpan
4 | import com.github.torrentdam.bittorrent.FileMapping.PieceInFile
5 | import com.github.torrentdam.bittorrent.TorrentMetadata
6 |
/** Maps torrent pieces onto the files they overlap.
  *
  * @param value       one span per file, in file order
  * @param pieceLength length of a piece in bytes
  */
case class FileMapping(value: List[FileSpan], pieceLength: Long) {
  /** Returns, for piece `pieceIndex` of the given `length`, the portion of
    * the piece that lands in each overlapping file.
    */
  def mapToFiles(pieceIndex: Long, length: Long): List[PieceInFile] =
    value
      .filter(span => span.beginIndex <= pieceIndex && span.endIndex >= pieceIndex)
      .map(span => pieceInFile(span, pieceIndex, length))

  // Translates a (piece, span) pair into a file-relative write.
  private def pieceInFile(span: FileSpan, pieceIndex: Long, length: Long): PieceInFile =
    val beginOffset = (pieceIndex - span.beginIndex) * pieceLength + span.beginOffset
    // NOTE(review): when a file ends exactly on a piece boundary its span's
    // endIndex still matches the following piece, which yields a zero-length
    // PieceInFile here — confirm callers tolerate (or filter out) length 0.
    val writeLength = Math.min(length, span.length - beginOffset)
    PieceInFile(span.fileIndex, beginOffset, writeLength)
}
18 |
object FileMapping {
  // A file occupies pieces beginIndex..endIndex: it starts at byte
  // beginOffset of piece beginIndex, and the *next* file starts at byte
  // endOffset of piece endIndex (see the fold in fromMetadata).
  case class FileSpan(
    fileIndex: Int,
    length: Long,
    beginIndex: Long,
    beginOffset: Long,
    endIndex: Long,
    endOffset: Long
  )
  // A chunk of a piece expressed in file coordinates.
  case class PieceInFile(
    fileIndex: Int,
    offset: Long,
    length: Long
  )
  /** Builds the mapping by walking the files in metadata order, carrying
    * forward where each file ends so the next one begins exactly there.
    */
  def fromMetadata(torrentMetadata: TorrentMetadata): FileMapping =
    // Computes the span of one file given the piece/offset where it begins.
    def forFile(fileIndex: Int, beginIndex: Long, beginOffset: Long, length: Long): FileSpan = {
      val spansPieces = length / torrentMetadata.pieceLength
      val remainder = length % torrentMetadata.pieceLength
      // The remainder may push the end into the next piece when added to the
      // starting offset within the first piece.
      val endIndex = beginIndex + spansPieces + (beginOffset + remainder) / torrentMetadata.pieceLength
      val endOffset = (beginOffset + remainder) % torrentMetadata.pieceLength
      FileSpan(fileIndex, length, beginIndex, beginOffset, endIndex, endOffset)
    }
    case class State(beginIndex: Long, beginOffset: Long, spans: List[FileSpan])
    val spans =
      torrentMetadata.files.zipWithIndex
        .foldLeft(State(0L, 0L, Nil)) { case (state, (file, index)) =>
          // Each file begins exactly where the previous one ended.
          val span = forFile(index, state.beginIndex, state.beginOffset, file.length)
          State(span.endIndex, span.endOffset, span :: state.spans)
        }
        .spans
        .reverse // spans were prepended during the fold; restore file order
    FileMapping(spans, torrentMetadata.pieceLength)
}
52 |
--------------------------------------------------------------------------------
/files/src/test/scala/com/github/torrentdam/bittorrent/files/WriterSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.files
2 |
3 | import com.github.torrentdam.bittorrent.TorrentMetadata
4 | import scodec.bits.ByteVector
5 |
/** Checks that Writer distributes piece bytes to the right files at the
  * right offsets.
  */
class WriterSpec extends munit.FunSuite {

  // Shorthand for a single-path file entry.
  private def fileEntry(length: Long, name: String): TorrentMetadata.File =
    TorrentMetadata.File(length, List(name))

  test("one piece one file") {
    val foo = fileEntry(1L, "foo.txt")
    val writer = Writer(List(foo), 1)
    assertEquals(
      writer.write(0, ByteVector(1)),
      List(Writer.WriteBytes(foo, 0, ByteVector(1)))
    )
  }

  test("write piece to second file") {
    val files = List(fileEntry(1L, "foo.txt"), fileEntry(1L, "bar.txt"))
    val writer = Writer(files, 1)
    // Piece 1 lies entirely within the second one-byte file.
    assertEquals(
      writer.write(1, ByteVector(2)),
      List(Writer.WriteBytes(files(1), 0, ByteVector(2)))
    )
  }

  test("write piece to both files") {
    val files = List(fileEntry(1L, "foo.txt"), fileEntry(1L, "bar.txt"))
    val writer = Writer(files, 2)
    // A two-byte piece straddles the boundary between the two files.
    assertEquals(
      writer.write(0, ByteVector(0, 1)),
      List(
        Writer.WriteBytes(files(0), 0, ByteVector(0)),
        Writer.WriteBytes(files(1), 0, ByteVector(1))
      )
    )
  }

  test("start writing in file with offset") {
    val files = List(fileEntry(3L, "foo.txt"), fileEntry(1L, "bar.txt"))
    val writer = Writer(files, 2)
    // Piece 1 starts at byte 2 of the first file and spills into the second.
    assertEquals(
      writer.write(1, ByteVector(0, 1)),
      List(
        Writer.WriteBytes(files(0), 2, ByteVector(0)),
        Writer.WriteBytes(files(1), 0, ByteVector(1))
      )
    )
  }
}
94 |
--------------------------------------------------------------------------------
/dht/src/test/scala/com/github/torrentdam/bittorrent/dht/PeerDiscoverySpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.kernel.Ref
4 | import cats.effect.IO
5 | import cats.effect.SyncIO
6 | import com.comcast.ip4s.*
7 | import com.github.torrentdam.bittorrent.InfoHash
8 | import com.github.torrentdam.bittorrent.PeerInfo
9 | import org.legogroup.woof.Logger
10 | import scodec.bits.ByteVector
11 |
class PeerDiscoverySpec extends munit.CatsEffectSuite {

  test("discover new peers") {

    val infoHash = InfoHash(ByteVector.encodeUtf8("c").toOption.get)

    def nodeId(id: String) = NodeId(ByteVector.encodeUtf8(id).toOption.get)

    given logger: Logger[IO] = NoOpLogger()

    // Scripted get_peers stub keyed on the target port:
    //   port 1 -> refers us to nodes "b" (port 2) and "c" (port 3)
    //   port 2 -> returns a peer at 2.2.2.2:2
    //   port 3 -> returns a peer at 2.2.2.2:3
    // Any other port raises a MatchError inside IO (never expected here).
    def getPeers(
      address: SocketAddress[IpAddress],
      infoHash: InfoHash
    ): IO[Either[Response.Nodes, Response.Peers]] = IO {
      address.port.value match {
        case 1 =>
          Left(
            Response.Nodes(
              nodeId("a"),
              List(
                NodeInfo(
                  nodeId("b"),
                  SocketAddress(ip"1.1.1.1", port"2")
                ),
                NodeInfo(
                  nodeId("c"),
                  SocketAddress(ip"1.1.1.1", port"3")
                )
              )
            )
          )
        case 2 =>
          Right(
            Response.Peers(
              nodeId("b"),
              List(
                PeerInfo(
                  SocketAddress(ip"2.2.2.2", port"2")
                )
              )
            )
          )
        case 3 =>
          Right(
            Response.Peers(
              nodeId("c"),
              List(
                PeerInfo(
                  SocketAddress(ip"2.2.2.2", port"3")
                )
              )
            )
          )
      }
    }

    for {
      // Discovery starts from the single bootstrap node "a" at port 1.
      state <- PeerDiscovery.DiscoveryState(
        initialNodes = List(
          NodeInfo(
            nodeId("a"),
            SocketAddress(ip"1.1.1.1", port"1")
          )
        ),
        infoHash = infoHash
      )
      list <- PeerDiscovery.start(infoHash, getPeers, state, 1).take(1).compile.toList
    } yield {
      // The first discovered peer is 2.2.2.2:3 — presumably node "c" ranks
      // closer to the target info-hash ("c"); verify against the discovery
      // ordering if this expectation changes.
      assertEquals(
        list,
        List(
          PeerInfo(SocketAddress(ip"2.2.2.2", port"3"))
        )
      )
    }
  }
}
89 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/wire/WorkQueueSuite.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.effect.std.CountDownLatch
4 | import cats.effect.syntax.temporal.genTemporalOps_
5 | import cats.effect.IO
6 | import cats.effect.Resource
7 | import com.github.torrentdam.bittorrent.wire.RequestDispatcher.WorkQueue
8 | import com.github.torrentdam.bittorrent.wire.RequestDispatcher.WorkQueue.EmptyQueue
9 | import scala.concurrent.duration.DurationInt
10 | import scodec.bits.ByteVector
11 |
12 | class WorkQueueSuite extends munit.CatsEffectSuite {
13 |
14 | test("return request")(
15 | for
16 | workQueue <- WorkQueue(Seq(1), _ => IO.unit)
17 | request <- workQueue.nextRequest.use((request, _) => IO.pure(request))
18 | yield assertEquals(request, 1)
19 | )
20 |
21 | test("put request back into queue if it was not completed")(
22 | for
23 | workQueue <- WorkQueue(Seq(1, 2), _ => IO.unit)
24 | request0 <- workQueue.nextRequest.use((request, _) => IO.pure(request))
25 | request1 <- workQueue.nextRequest.use((request, _) => IO.pure(request))
26 | yield
27 | assertEquals(request0, 1)
28 | assertEquals(request1, 1)
29 | )
30 | test("delete request from queue if it was completed")(
31 | for
32 | workQueue <- WorkQueue(Seq(1, 2), _ => IO.unit)
33 | request0 <- workQueue.nextRequest.use((request, promise) => promise.complete(()).as(request))
34 | request1 <- workQueue.nextRequest.use((request, _) => IO.pure(request))
35 | yield
36 | assertEquals(request0, 1)
37 | assertEquals(request1, 2)
38 | )
39 | test("throw PieceComplete when last request was fulfilled")(
40 | for
41 | workQueue <- WorkQueue(Seq(1), _ => IO.unit)
42 | request <- workQueue.nextRequest.use((request, promise) => promise.complete(ByteVector.empty).as(request))
43 | result <- workQueue.nextRequest.use((request, _) => IO.pure(request)).attempt
44 | yield
45 | assertEquals(request, 1)
46 | assertEquals(result, Left(WorkQueue.PieceComplete))
47 | )
48 | test("throw EmptyQueue when queue is empty")(
49 | for
50 | workQueue <- WorkQueue(Seq(1), _ => IO.unit)
51 | finishFirst <- CountDownLatch[IO](1)
52 | fiber0 <- workQueue.nextRequest.use((request, _) => finishFirst.await.as(request)).start
53 | fiber1 <- workQueue.nextRequest.use((request, _) => IO.pure(request)).start
54 | request1 <- fiber1.joinWithNever.attempt
55 | _ <- finishFirst.release
56 | request0 <- fiber0.joinWithNever
57 | yield
58 | assertEquals(request0, 1)
59 | assertEquals(request1, Left(EmptyQueue))
60 | )
61 | }
62 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/protocol/message.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.protocol.message
2 |
3 | import com.github.torrentdam.bittorrent.InfoHash
4 | import com.github.torrentdam.bittorrent.PeerId
5 | import scala.util.chaining.*
6 | import scodec.bits.ByteVector
7 | import scodec.codecs.*
8 | import scodec.Codec
9 |
/** BitTorrent wire-protocol handshake.
  *
  * @param extensionProtocol whether the peer advertises the extension protocol
  * @param infoHash          20-byte hash identifying the torrent
  * @param peerId            the peer's 20-byte identifier
  */
final case class Handshake(
  extensionProtocol: Boolean,
  infoHash: InfoHash,
  peerId: PeerId
)

object Handshake {
  // Length-prefixed protocol identifier: the byte 19 followed by the 19-byte
  // string "BitTorrent protocol".
  val ProtocolStringCodec: Codec[Unit] = uint8.unit(19) ~> fixedSizeBytes(
    19,
    utf8.unit("BitTorrent protocol")
  )
  // 8 reserved bytes; only bit 43 is read/set here and mapped to
  // `extensionProtocol` (presumably the BEP-10 extension bit — verify).
  val ReserveCodec: Codec[Boolean] = bits(8 * 8).xmap(
    bv => bv.get(43),
    supported =>
      ByteVector
        .fill(8)(0)
        .toBitVector
        .pipe(v => if supported then v.set(43) else v)
  )
  val InfoHashCodec: Codec[InfoHash] = bytes(20).xmap(InfoHash(_), _.bytes)
  val PeerIdCodec: Codec[PeerId] = bytes(20).xmap(PeerId.apply, _.bytes)
  // Full handshake: protocol string, reserved bits, info-hash, peer id.
  val HandshakeCodec: Codec[Handshake] =
    (ProtocolStringCodec ~> ReserveCodec :: InfoHashCodec :: PeerIdCodec).as
}
34 |
/** Messages of the BitTorrent peer wire protocol; the one-byte ids are
  * assigned in the companion's codec.
  */
enum Message:
  case KeepAlive
  case Choke
  case Unchoke
  case Interested
  case NotInterested
  case Have(pieceIndex: Long)
  case Bitfield(bytes: ByteVector)
  case Request(index: Long, begin: Long, length: Long)
  case Piece(index: Long, begin: Long, bytes: ByteVector)
  case Cancel(index: Long, begin: Long, length: Long)
  case Port(port: Int)
  case Extended(id: Long, payload: ByteVector)
48 |
object Message {

  // Every framed message starts with a 4-byte big-endian length prefix.
  val MessageSizeCodec: Codec[Long] = uint32

  val MessageBodyCodec: Codec[Message] = {
    // KeepAlive has an empty body: `.complete` makes this branch succeed only
    // when no bytes remain, so it cannot swallow other messages in `choice`.
    val KeepAliveCodec: Codec[KeepAlive.type] = provide(KeepAlive).complete

    // All other messages are discriminated by their one-byte message id.
    val OtherMessagesCodec: Codec[Message] =
      discriminated[Message]
        .by(uint8)
        .caseP(0) { case m @ Choke => m }(identity)(provide(Choke))
        .caseP(1) { case m @ Unchoke => m }(identity)(provide(Unchoke))
        .caseP(2) { case m @ Interested => m }(identity)(provide(Interested))
        .caseP(3) { case m @ NotInterested => m }(identity)(provide(NotInterested))
        .caseP(4) { case Have(index) => index }(Have.apply)(uint32)
        .caseP(5) { case Bitfield(bytes) => bytes }(Bitfield.apply)(bytes)
        .caseP(6) { case m: Request => m }(identity)((uint32 :: uint32 :: uint32).as)
        .caseP(7) { case m: Piece => m }(identity)((uint32 :: uint32 :: bytes).as)
        .caseP(8) { case m: Cancel => m }(identity)((uint32 :: uint32 :: uint32).as)
        .caseP(9) { case Port(port) => port }(Port.apply)(uint16)
        .caseP(20) { case m: Extended => m }(identity)((ulong(8) :: bytes).as)

    // KeepAlive is tried first; it only matches an empty body (see above).
    choice(
      KeepAliveCodec.upcast,
      OtherMessagesCodec
    )
  }

  // Complete frame: length prefix followed by the body.
  val MessageCodec: Codec[Message] = {
    variableSizeBytesLong(
      MessageSizeCodec,
      MessageBodyCodec
    )
  }
}
84 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
# CI pipeline: compile/test with mill, push the cmd Docker image on every
# push, and publish signed artifacts when a version tag is pushed.
name: build

on: [push]

jobs:

  # Compile, test, build the fat jar, and push the multi-arch cmd image.
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Compile
        run: ./mill _.compile
      - name: Test
        run: ./mill _.test
      - name: Build distributable
        run: ./mill cmd.assembly
      - name: Set up Docker Buildx
        id: buildx
        uses: docker/setup-buildx-action@master
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v1
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Push cmd images
        uses: docker/build-push-action@v2
        with:
          context: .
          file: cmd/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ghcr.io/torrentdam/cmd:latest

  # Native image build — currently disabled via `if: false`.
  build-native:
    if: false
    runs-on: ubuntu-latest
    needs: build
    steps:
      - uses: actions/checkout@v1
      - uses: actions/cache@v1
        with:
          path: ~/.cache/coursier/v1
          key: ${{ runner.os }}-coursier-${{ hashFiles('**/build.sbt') }}
      - name: Set up java
        uses: actions/setup-java@v2.1.0
        with:
          distribution: adopt
          java-version: 17
          java-package: jre
      - name: Build Native
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: sbt -J-Xmx4g cmdNative/nativeLink
      - name: Set up Docker Buildx
        id: buildx
        uses: docker/setup-buildx-action@master
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v1
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Push cmd images
        uses: docker/build-push-action@v2
        with:
          context: cmd
          file: cmd/Dockerfile-native
          platforms: linux/amd64
          push: true
          tags: ghcr.io/torrentdam/cmd-native:latest

  # Publishes signed artifacts when a tag matching refs/tags/v* is pushed.
  release:
    if: startsWith(github.ref, 'refs/tags/v')
    needs: build
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v1
      - name: Set up java
        uses: actions/setup-java@v2.1.0
        with:
          distribution: adopt
          java-version: 17
          java-package: jre
      - name: Release
        env:
          SONATYPE_CREDS: ${{ secrets.SONATYPE_CREDS }}
          PGP_SECRET_KEY: ${{ secrets.PGP_SECRET_KEY }}
        run: |
          echo ${PGP_SECRET_KEY} | base64 --decode | gpg --import
          gpg --list-secret-keys
          export VERSION=${GITHUB_REF#*/v}
          echo Publishing $VERSION
          sbt commonJVM/publishSigned
          sbt commonJS/publishSigned
          sbt bittorrentJVM/publishSigned
          sbt dhtJVM/publishSigned
          sbt cmdJVM/publishSigned
99 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/TorrentMetadata.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent
2 |
3 | import cats.implicits.*
4 | import com.github.torrentdam.bencode
5 | import com.github.torrentdam.bencode.format.*
6 | import com.github.torrentdam.bencode.Bencode
7 | import com.github.torrentdam.bencode.BencodeFormatException
8 | import java.time.Instant
9 | import scodec.bits.ByteVector
10 |
/** Parsed torrent info dictionary.
  *
  * @param name        name of the torrent
  * @param pieceLength number of bytes per piece
  * @param pieces      raw "pieces" value from the info dictionary
  * @param files       file entries; a single-file torrent is normalised into
  *                    one entry named after the torrent (see the format below)
  */
case class TorrentMetadata(
  name: String,
  pieceLength: Long,
  pieces: ByteVector,
  files: List[TorrentMetadata.File]
)
17 |
object TorrentMetadata {

  /** A single file entry: its size in bytes and path components. */
  case class File(
    length: Long,
    path: List[String]
  )

  given BencodeFormat[File] =
    (
      field[Long]("length"),
      field[List[String]]("path")
    ).imapN(File.apply)(v => (v.length, v.path))

  given BencodeFormat[TorrentMetadata] = {
    // Single-file torrents carry a top-level "length"; multi-file torrents a
    // "files" list. Both shapes are normalised into a list of File entries.
    def to(name: String, pieceLength: Long, pieces: ByteVector, length: Option[Long], filesOpt: Option[List[File]]) = {
      val files = length match {
        case Some(length) => List(File(length, List(name)))
        case None => filesOpt.combineAll
      }
      TorrentMetadata(name, pieceLength, pieces, files)
    }
    // Writing always emits the multi-file shape ("files"), never "length".
    def from(v: TorrentMetadata) =
      (v.name, v.pieceLength, v.pieces, Option.empty[Long], Some(v.files))
    (
      field[String]("name"),
      field[Long]("piece length"),
      field[ByteVector]("pieces"),
      fieldOptional[Long]("length"),
      fieldOptional[List[File]]("files")
    ).imapN(to)(from)
  }

  /** Parsed metadata paired with the exact bencode it was read from, so the
    * info-hash can be computed over the original encoding.
    */
  case class Lossless private (
    parsed: TorrentMetadata,
    raw: Bencode
  ) {
    // The hash is taken over the re-encoded *raw* bencode, not the parsed
    // model, which preserves unknown keys and the exact byte layout.
    def infoHash: InfoHash = InfoHash(CrossPlatform.sha1(bencode.encode(raw).toByteVector))
  }

  object Lossless {
    def fromBytes(bytes: ByteVector): Either[Throwable, Lossless] =
      bencode.decode(bytes.bits).flatMap(fromBencode)

    def fromBencode(bcode: Bencode): Either[BencodeFormatException, Lossless] =
      summon[BencodeFormat[TorrentMetadata]].read(bcode).map { metadata =>
        Lossless(metadata, bcode)
      }
    // Writing re-emits the stored raw bencode verbatim.
    given BencodeFormat[Lossless] = {
      BencodeFormat(
        read = BencodeReader(fromBencode),
        write = BencodeWriter(metadata => BencodeFormat.BencodeValueFormat.write(metadata.raw))
      )
    }
  }
}
73 |
/** A whole .torrent file: the info dictionary plus its optional creation
  * timestamp.
  */
case class TorrentFile(
  info: TorrentMetadata.Lossless,
  creationDate: Option[Instant]
)

object TorrentFile {

  // Per BEP-3 the optional key is "creation date" and holds seconds since
  // the Unix epoch; the previous "creationDate"/milliseconds encoding never
  // matched real .torrent files, so the date was always read as None.
  private given BencodeFormat[Instant] =
    BencodeFormat.LongFormat.imap(Instant.ofEpochSecond)(_.getEpochSecond)

  given torrentFileFormat: BencodeFormat[TorrentFile] = {
    (
      field[TorrentMetadata.Lossless]("info"),
      fieldOptional[Instant]("creation date")
    ).imapN(TorrentFile(_, _))(v => (v.info, v.creationDate))
  }

  def fromBencode(bcode: Bencode): Either[BencodeFormatException, TorrentFile] =
    torrentFileFormat.read(bcode)

  def fromBytes(bytes: ByteVector): Either[Throwable, TorrentFile] =
    bencode.decode(bytes.bits).flatMap(fromBencode)

  /** Serialises a torrent file; writing is assumed to always succeed.
    * `.toOption.get` replaces the deprecated `.right.get` projection while
    * keeping the same fail-fast behaviour.
    */
  def toBytes(torrentFile: TorrentFile): ByteVector =
    bencode.encode(torrentFileFormat.write(torrentFile).toOption.get).toByteVector
}
99 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableBootstrap.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.kernel.Temporal
4 | import cats.implicits.*
5 | import cats.MonadError
6 | import cats.effect.IO
7 | import cats.effect.implicits.*
8 | import cats.effect.cps.{given, *}
9 | import com.comcast.ip4s.*
10 | import com.github.torrentdam.bittorrent.InfoHash
11 | import org.legogroup.woof.given
12 | import org.legogroup.woof.Logger
13 | import fs2.Stream
14 |
15 | import scala.concurrent.duration.*
16 |
/** Fills an empty routing table: contacts well-known bootstrap nodes, then
  * repeatedly searches for the local node id until the table holds enough
  * nodes.
  */
object RoutingTableBootstrap {

  def apply(
    table: RoutingTable[IO],
    client: Client,
    discovery: PeerDiscovery,
    bootstrapNodeAddress: List[SocketAddress[Host]] = PublicBootstrapNodes
  )(using
    dns: Dns[IO],
    logger: Logger[IO]
  ): IO[Unit] =
    for
      _ <- logger.info("Bootstrapping")
      // Retry the whole resolve-and-ping pass until at least one bootstrap
      // node has responded.
      count <- resolveNodes(client, bootstrapNodeAddress).compile.count.iterateUntil(_ > 0)
      _ <- logger.info(s"Communicated with $count bootstrap nodes")
      _ <- selfDiscovery(table, client, discovery)
      nodeCount <- table.allNodes.map(_.size)
      _ <- logger.info(s"Bootstrapping finished with $nodeCount nodes")
    yield {}

  // Resolves every bootstrap host to all of its IP addresses and pings each
  // one (5 s timeout), emitting a NodeInfo for every address that answered.
  // Resolution and ping failures are logged and skipped, never raised.
  private def resolveNodes(
    client: Client,
    bootstrapNodeAddress: List[SocketAddress[Host]]
  )(using
    dns: Dns[IO],
    logger: Logger[IO]
  ): Stream[IO, NodeInfo] =
    def tryThis(hostname: SocketAddress[Host]): Stream[IO, NodeInfo] =
      Stream.eval(logger.info(s"Trying to reach $hostname")) >>
      Stream
        .evals(
          hostname.host.resolveAll[IO]
            .recoverWith: e =>
              logger.info(s"Failed to resolve $hostname $e").as(List.empty)
        )
        .evalMap: ipAddress =>
          val resolvedAddress = SocketAddress(ipAddress, hostname.port)
          logger.info(s"Resolved to $ipAddress") *>
          client
            .ping(resolvedAddress)
            .timeout(5.seconds)
            .map(pong => NodeInfo(pong.id, resolvedAddress))
            .flatTap: _ =>
              logger.info(s"Reached $resolvedAddress node")
            .map(_.some)
            .recoverWith: e =>
              logger.info(s"Failed to reach $resolvedAddress $e").as(none)
        .collect {
          case Some(node) => node
        }
    Stream
      .emits(bootstrapNodeAddress)
      .covary[IO]
      .flatMap(tryThis)

  // Searches for our own node id so that responding nodes populate the
  // table; repeats until the table holds at least 20 nodes.
  private def selfDiscovery(
    table: RoutingTable[IO],
    client: Client,
    discovery: PeerDiscovery
  )(using Logger[IO]) =
    def attempt(number: Int): IO[Unit] = async[IO]:
      Logger[IO].info(s"Discover self to fill up routing table (attempt $number)").await
      val count = discovery.findNodes(client.id).take(30).interruptAfter(30.seconds).compile.count.await
      Logger[IO].info(s"Communicated with $count nodes during self discovery").await
      val nodeCount = table.allNodes.await.size
      // Fixed: the former `else IO.unit` branch produced an IO *value* that
      // was silently discarded instead of being run; an `if` without `else`
      // already yields Unit, which is what this async block needs.
      if nodeCount < 20 then attempt(number + 1).await
    attempt(1)

  /** Well-known publicly operated DHT bootstrap nodes. */
  val PublicBootstrapNodes: List[SocketAddress[Host]] = List(
    SocketAddress(host"router.bittorrent.com", port"6881"),
    SocketAddress(host"router.utorrent.com", port"6881"),
    SocketAddress(host"dht.transmissionbt.com", port"6881"),
    SocketAddress(host"router.bitcomet.com", port"6881"),
    SocketAddress(host"dht.aelitis.com", port"6881"),
  )
}
93 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Swarm.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.*
4 | import cats.effect.implicits.*
5 | import cats.effect.std.Queue
6 | import cats.effect.IO
7 | import cats.effect.Outcome
8 | import cats.effect.Resource
9 | import cats.implicits.*
10 | import com.github.torrentdam.bittorrent.PeerInfo
11 | import fs2.concurrent.Signal
12 | import fs2.concurrent.SignallingRef
13 | import fs2.concurrent.Topic
14 | import fs2.Stream
15 | import org.legogroup.woof.*
16 | import org.legogroup.woof.given
17 | import org.legogroup.woof.Logger.withLogContext
18 | import scala.concurrent.duration.*
19 |
/** A dynamic pool of peer connections for one torrent. */
trait Swarm {
  /** Acquires the next available connection (newly discovered peer or a
    * scheduled reconnect); releasing it removes it from [[connected]].
    */
  def connect: Resource[IO, Connection]
  /** Live view of the currently established connections. */
  def connected: Connected
}
24 |
/** Read-only view of the swarm's established connections. */
trait Connected {
  /** Continuously updated count of live connections. */
  def count: Signal[IO, Int]
  /** Snapshot of the live connections at the moment of evaluation. */
  def list: IO[List[Connection]]
}
29 |
object Swarm {

  /** Builds a [[Swarm]] from a stream of candidate peers.
    *
    * A background fiber (tied to the Resource) drains `peers`, de-duplicates
    * them, wraps each in a reconnecting connection Resource and offers it to a
    * bounded queue. [[Swarm.connect]] takes from either that queue or the
    * reconnect queue, whichever produces first.
    *
    * @param peers   stream of discovered peers; duplicates are ignored
    * @param connect opens a connection to a single peer
    */
  def apply(
    peers: Stream[IO, PeerInfo],
    connect: PeerInfo => Resource[IO, Connection]
  )(using
    logger: Logger[IO]
  ): Resource[IO, Swarm] =
    for
      _ <- Resource.make(logger.info("Starting swarm"))(_ => logger.info("Swarm closed"))
      stateRef <- Resource.eval(SignallingRef[IO].of(Map.empty[PeerInfo, Connection]))
      // Bounded: back-pressures peer discovery when nobody is taking connections.
      newConnections <- Resource.eval(Queue.bounded[IO, Resource[IO, Connection]](10))
      reconnects <- Resource.eval(Queue.unbounded[IO, Resource[IO, Connection]])
      // Fire-and-forget: sleep `delay`, then re-offer the connection resource.
      scheduleReconnect = (delay: FiniteDuration) =>
        (reconnect: Resource[IO, Connection]) => (IO.sleep(delay) >> reconnects.offer(reconnect)).start.void
      allSeen <- Resource.eval(IO.ref(Set.empty[PeerInfo]))
      _ <- peers
        .evalMap(peerInfo =>
          // getAndUpdate returns the *previous* set, so a first-time peer
          // yields Some and every repeat yields None.
          allSeen.getAndUpdate(_ + peerInfo).flatMap { seen =>
            if seen(peerInfo)
            then IO.pure(None)
            else IO.pure(Some(peerInfo))
          }
        )
        .collect { case Some(peerInfo) => peerInfo }
        .map(peerInfo => newConnection(connect(peerInfo), scheduleReconnect))
        .evalTap(newConnections.offer)
        .compile
        .drain
        .background
      // Race the two sources; the losing take is cancelled without losing an element.
      connectOrReconnect =
        for
          resource <- Resource.eval(newConnections.take race reconnects.take)
          connection <- resource.merge
        yield connection
    yield new Impl(stateRef, connectOrReconnect)
    end for

  /** Swarm implementation that tracks acquired connections in `stateRef`. */
  private class Impl(
    stateRef: SignallingRef[IO, Map[PeerInfo, Connection]],
    connectOrReconnect: Resource[IO, Connection]
  ) extends Swarm {
    val connect: Resource[IO, Connection] =
      connectOrReconnect.flatTap(connection =>
        // Register while held, deregister on release.
        Resource.make {
          stateRef.update(_ + (connection.info -> connection))
        } { _ =>
          stateRef.update(_ - connection.info)
        }
      )
    val connected: Connected = new {
      val count: Signal[IO, Int] = stateRef.map(_.size)
      val list: IO[List[Connection]] = stateRef.get.map(_.values.toList)
    }
  }

  /** Wraps a connection so that, on release, it re-queues itself:
    *   - graceful close: retry after 10s with the attempt counter reset;
    *   - error: back off linearly (10s * attempt), giving up after `maxAttempts`;
    *   - cancellation: no reconnect.
    */
  private def newConnection(
    connect: Resource[IO, Connection],
    schedule: FiniteDuration => Resource[IO, Connection] => IO[Unit]
  ): Resource[IO, Connection] = {
    val maxAttempts = 24
    def connectWithRetry(attempt: Int): Resource[IO, Connection] =
      connect
        .onFinalizeCase {
          case Resource.ExitCase.Succeeded =>
            schedule(10.second)(connectWithRetry(1))
          case Resource.ExitCase.Errored(_) =>
            if attempt == maxAttempts
            then IO.unit
            else
              val duration = (10 * attempt).seconds
              schedule(duration)(connectWithRetry(attempt + 1))
          case _ =>
            IO.unit
        }
    connectWithRetry(1)
  }
}
108 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/test/scala/com/github/torrentdam/bittorrent/TorrentMetadataSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent
2 |
3 | import com.github.torrentdam.bencode.*
4 | import com.github.torrentdam.bencode.format.BencodeFormat
5 | import com.github.torrentdam.bittorrent.CrossPlatform
6 | import scodec.bits.Bases
7 | import scodec.bits.BitVector
8 | import scodec.bits.ByteVector
9 |
/** Round-trip and decoding tests for [[TorrentMetadata]] bencode formats. */
class TorrentMetadataSpec extends munit.FunSuite {

  /** Loads a classpath test resource as raw bytes. */
  private def resourceBytes(name: String): Array[Byte] =
    getClass.getClassLoader.getResourceAsStream(name).readAllBytes()

  /** Decodes an ASCII string as bencode. */
  private def decodeAscii(text: String) =
    decode(BitVector.encodeAscii(text).toOption.get)

  test("encode file class") {
    val encoded = summon[BencodeFormat[TorrentMetadata.File]].write(TorrentMetadata.File(77, "abc" :: Nil))
    val expected =
      Bencode.BDictionary(
        "length" -> Bencode.BInteger(77),
        "path" -> Bencode.BList(Bencode.BString("abc") :: Nil)
      )
    assert(encoded == Right(expected))
  }

  test("calculate info_hash") {
    val bytes = resourceBytes("bencode/ubuntu-18.10-live-server-amd64.iso.torrent")
    val Right(bencode) = decode(BitVector(bytes)): @unchecked
    // SHA-1 over the re-encoded raw `info` dictionary must match the known hash.
    val infoHash =
      summon[BencodeFormat[TorrentFile]]
        .read(bencode)
        .map(torrent => encode(torrent.info.raw).bytes)
        .map(CrossPlatform.sha1)
        .map(_.toHex(Bases.Alphabets.HexUppercase))
    assert(infoHash == Right("8C4ADBF9EBE66F1D804FB6A4FB9B74966C3AB609"))
  }

  test("decode either a or b") {
    // Single-file layout: a top-level "length" becomes one File entry named
    // after the torrent itself.
    val singleFile = Bencode.BDictionary(
      "name" -> Bencode.BString("file_name"),
      "piece length" -> Bencode.BInteger(10),
      "pieces" -> Bencode.BString.Empty,
      "length" -> Bencode.BInteger(10)
    )
    assert(
      summon[BencodeFormat[TorrentMetadata]].read(singleFile) == Right(
        TorrentMetadata("file_name", 10, ByteVector.empty, List(TorrentMetadata.File(10, List("file_name"))))
      )
    )

    // Multi-file layout: an explicit "files" list is decoded as-is.
    val multiFile = Bencode.BDictionary(
      "name" -> Bencode.BString("test"),
      "piece length" -> Bencode.BInteger(10),
      "pieces" -> Bencode.BString.Empty,
      "files" -> Bencode.BList(
        Bencode.BDictionary(
          "length" -> Bencode.BInteger(10),
          "path" -> Bencode.BList(Bencode.BString("/root") :: Nil)
        ) :: Nil
      )
    )
    assert(
      summon[BencodeFormat[TorrentMetadata]].read(multiFile) == Right(
        TorrentMetadata("test", 10, ByteVector.empty, TorrentMetadata.File(10, "/root" :: Nil) :: Nil)
      )
    )
  }

  test("decode dictionary") {
    val dictionary = Bencode.BDictionary(
      "name" -> Bencode.BString("file_name"),
      "piece length" -> Bencode.BInteger(10),
      "pieces" -> Bencode.BString(ByteVector(10)),
      "length" -> Bencode.BInteger(10)
    )
    val decoded = summon[BencodeFormat[TorrentMetadata]].read(dictionary)
    assert(
      decoded == Right(
        TorrentMetadata("file_name", 10, ByteVector(10), List(TorrentMetadata.File(10, List("file_name"))))
      )
    )
  }

  test("decode ubuntu torrent") {
    // Sanity-check the bencode primitives first.
    assert(decodeAscii("i56e") == Right(Bencode.BInteger(56L)))
    assert(decodeAscii("2:aa") == Right(Bencode.BString("aa")))
    assert(
      decodeAscii("l1:a2:bbe") == Right(
        Bencode.BList(Bencode.BString("a") :: Bencode.BString("bb") :: Nil)
      )
    )
    assert(
      decodeAscii("d1:ai6ee") == Right(
        Bencode.BDictionary("a" -> Bencode.BInteger(6))
      )
    )
    // Then the full real-world torrent file.
    val bytes = resourceBytes("bencode/ubuntu-18.10-live-server-amd64.iso.torrent")
    val Right(bencode) = decode(BitVector(bytes)): @unchecked
    assert(summon[BencodeFormat[TorrentFile]].read(bencode).isRight)
  }

}
108 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RequestResponse.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.*
4 | import cats.effect.kernel.Deferred
5 | import cats.effect.kernel.Ref
6 | import cats.effect.kernel.Temporal
7 | import cats.effect.syntax.all.*
8 | import cats.effect.Concurrent
9 | import cats.effect.Resource
10 | import cats.syntax.all.*
11 | import com.comcast.ip4s.*
12 | import com.github.torrentdam.bittorrent.dht.RequestResponse.Timeout
13 | import com.github.torrentdam.bencode.Bencode
14 | import scala.concurrent.duration.*
15 | import scodec.bits.ByteVector
16 |
/** Correlates outgoing DHT queries with their responses by transaction id. */
trait RequestResponse[F[_]] {
  /** Sends `query` to `address` and waits for the matching response message. */
  def sendQuery(address: SocketAddress[IpAddress], query: Query): F[Response]
}
20 |
object RequestResponse {

  /** How long to wait for a node to answer before failing with [[Timeout]].
    * DHT queries travel over UDP; a lost datagram would otherwise leave the
    * caller blocked on a transaction nobody ever completes.
    */
  val QueryTimeout: FiniteDuration = 10.seconds

  /** Creates a request/response multiplexer on top of a datagram socket.
    *
    * A background fiber drains `receiveMessage` and completes the callback
    * registered under each message's transaction id; the fiber is cancelled
    * when the Resource is released.
    *
    * @param generateTransactionId produces a fresh transaction id per query
    * @param sendQuery             writes one query message to the wire
    * @param receiveMessage        reads the next inbound response or error
    */
  def make[F[_]](
      generateTransactionId: F[ByteVector],
      sendQuery: (SocketAddress[IpAddress], Message.QueryMessage) => F[Unit],
      receiveMessage: F[
        (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage)
      ]
  )(using
      F: Temporal[F]
  ): Resource[F, RequestResponse[F]] =
    Resource {
      for {
        callbackRegistry <- CallbackRegistry.make[F]
        fiber <- receiveLoop(receiveMessage, callbackRegistry.complete).start
      } yield {
        new Impl(generateTransactionId, sendQuery, callbackRegistry.add) -> fiber.cancel
      }
    }

  private class Impl[F[_]](
      generateTransactionId: F[ByteVector],
      sendQueryMessage: (SocketAddress[IpAddress], Message.QueryMessage) => F[Unit],
      receive: ByteVector => F[Either[Throwable, Response]]
  )(using F: Temporal[F])
      extends RequestResponse[F] {
    def sendQuery(address: SocketAddress[IpAddress], query: Query): F[Response] = {
      generateTransactionId.flatMap { transactionId =>
        val send = sendQueryMessage(
          address,
          Message.QueryMessage(transactionId, query)
        )
        // Fix: previously the wait was unbounded, so a node that never replied
        // hung the caller forever (the Timeout error type existed but was never
        // raised). Cancelling the wait also deregisters the callback, because
        // CallbackRegistry.add guarantees cleanup on cancellation.
        send >> receive(transactionId)
          .timeoutTo(QueryTimeout, F.raiseError(Timeout()))
          .flatMap(F.fromEither)
      }
    }
  }

  /** Forever: take the next inbound message and complete the matching
    * callback. Messages with unknown transaction ids are dropped (`continue`
    * returns false for them).
    */
  private def receiveLoop[F[_]](
      receive: F[
        (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage)
      ],
      continue: (ByteVector, Either[Throwable, Response]) => F[Boolean]
  )(using
      F: Monad[F]
  ): F[Unit] = {
    val step = receive.map(_._2).flatMap {
      case Message.ResponseMessage(transactionId, response) =>
        continue(transactionId, response.asRight)
      case Message.ErrorMessage(transactionId, details) =>
        continue(transactionId, ErrorResponse(details).asLeft)
    }
    step.foreverM[Unit]
  }

  /** The remote node answered with a DHT error message. */
  case class ErrorResponse(details: Bencode) extends Throwable
  case class InvalidResponse() extends Throwable
  /** No response arrived within [[QueryTimeout]]. */
  case class Timeout() extends Throwable
}
79 |
/** Registry of in-flight transactions awaiting a response. */
trait CallbackRegistry[F[_]] {
  /** Registers a waiter for `transactionId` and blocks until it is completed. */
  def add(transactionId: ByteVector): F[Either[Throwable, Response]]

  /** Delivers `result` to the waiter for `transactionId`; false if none is registered. */
  def complete(transactionId: ByteVector, result: Either[Throwable, Response]): F[Boolean]
}
85 |
object CallbackRegistry {
  /** Creates an empty, Ref-backed registry. */
  def make[F[_]: Temporal]: F[CallbackRegistry[F]] = {
    for {
      ref <-
        Ref
          .of[F, Map[ByteVector, Either[Throwable, Response] => F[Boolean]]](
            Map.empty
          )
    } yield {
      new Impl(ref)
    }
  }

  private class Impl[F[_]](
      ref: Ref[F, Map[ByteVector, Either[Throwable, Response] => F[Boolean]]]
  )(using F: Temporal[F])
      extends CallbackRegistry[F] {
    def add(transactionId: ByteVector): F[Either[Throwable, Response]] = {
      // NOTE(review): a concurrent add with the same transaction id would
      // silently replace this callback, and the guarantee-d delete below would
      // then remove the replacement too — confirm ids cannot collide in flight.
      F.deferred[Either[Throwable, Response]].flatMap { deferred =>
        val update =
          ref.update { map =>
            map.updated(transactionId, deferred.complete)
          }
        val delete =
          ref.update { map =>
            map - transactionId
          }
        // Deregister on completion, error AND cancellation (e.g. caller timeout).
        (update *> deferred.get).guarantee(delete)
      }
    }

    def complete(transactionId: ByteVector, result: Either[Throwable, Response]): F[Boolean] =
      ref.get.flatMap { map =>
        map.get(transactionId) match {
          case Some(callback) => callback(result)
          case None => false.pure[F]
        }
      }
  }
}
126 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Client.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.kernel.Temporal
4 | import cats.effect.std.{Queue, Random}
5 | import cats.effect.{Concurrent, IO, Resource, Sync}
6 | import cats.syntax.all.*
7 | import com.comcast.ip4s.*
8 | import com.github.torrentdam.bittorrent.InfoHash
9 |
10 | import java.net.InetSocketAddress
11 | import org.legogroup.woof.given
12 | import org.legogroup.woof.Logger
13 | import scodec.bits.ByteVector
14 |
/** Outgoing DHT query interface (one method per query type in [[Query]]). */
trait Client {

  /** Node id this client identifies itself with in queries. */
  def id: NodeId

  /** get_peers: peers for `infoHash`, or closer nodes when the target has none. */
  def getPeers(address: SocketAddress[IpAddress], infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]]

  /** find_node: nodes close to `target` known by the node at `address`. */
  def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes]

  /** ping: liveness probe for the node at `address`. */
  def ping(address: SocketAddress[IpAddress]): IO[Response.Ping]

  /** sample_infohashes: a sample of stored info-hashes, or closer nodes. */
  def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]]
}
27 |
object Client {

  /** Generates a two-character alphanumeric transaction id for a query. */
  def generateTransactionId(using random: Random[IO]): IO[ByteVector] =
    val nextChar = random.nextAlphaNumeric
    (nextChar, nextChar).mapN((a, b) => ByteVector.encodeAscii(List(a, b).mkString).toOption.get)

  /** Builds a [[Client]] on top of `messageSocket`.
    *
    * A background fiber (tied to the Resource) reads every inbound message:
    * queries are answered via `queryHandler`, while responses and errors are
    * forwarded to the RequestResponse layer through an internal queue.
    * Read failures are logged at debug level and the loop continues.
    */
  def apply(
    selfId: NodeId,
    messageSocket: MessageSocket,
    queryHandler: QueryHandler[IO]
  )(using Logger[IO], Random[IO]): Resource[IO, Client] = {
    for
      responses <- Resource.eval {
        Queue.unbounded[IO, (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage)]
      }
      requestResponse <- RequestResponse.make(
        generateTransactionId,
        messageSocket.writeMessage,
        responses.take
      )
      _ <-
        messageSocket.readMessage
          .flatMap {
            case (a, m: Message.QueryMessage) =>
              // Inbound query: delegate; a None reply means "ignore".
              Logger[IO].debug(s"Received $m") >>
              queryHandler(a, m.query).flatMap {
                case Some(response) =>
                  val responseMessage = Message.ResponseMessage(m.transactionId, response)
                  Logger[IO].debug(s"Responding with $responseMessage") >>
                  messageSocket.writeMessage(a, responseMessage)
                case None =>
                  Logger[IO].debug(s"No response for $m")
              }
            case (a, m: Message.ResponseMessage) => responses.offer((a, m))
            case (a, m: Message.ErrorMessage) => responses.offer((a, m))
          }
          // Best-effort loop: a single malformed datagram must not kill it.
          .recoverWith { case e: Throwable =>
            Logger[IO].debug(s"Failed to read message: $e")
          }
          .foreverM
          .background
    yield new Client {

      def id: NodeId = selfId

      def getPeers(
        address: SocketAddress[IpAddress],
        infoHash: InfoHash
      ): IO[Either[Response.Nodes, Response.Peers]] =
        requestResponse.sendQuery(address, Query.GetPeers(selfId, infoHash)).flatMap {
          case nodes: Response.Nodes => nodes.asLeft.pure
          case peers: Response.Peers => peers.asRight.pure
          case _ => IO.raiseError(InvalidResponse())
        }

      def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes] =
        requestResponse.sendQuery(address, Query.FindNode(selfId, target)).flatMap {
          case nodes: Response.Nodes => nodes.pure
          case _ => IO.raiseError(InvalidResponse())
        }

      def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] =
        requestResponse.sendQuery(address, Query.Ping(selfId)).flatMap {
          case ping: Response.Ping => ping.pure
          case _ => IO.raiseError(InvalidResponse())
        }
      def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
        requestResponse.sendQuery(address, Query.SampleInfoHashes(selfId, target)).flatMap {
          case response: Response.SampleInfoHashes => response.asRight[Response.Nodes].pure
          case response: Response.Nodes => response.asLeft[Response.SampleInfoHashes].pure
          case _ => IO.raiseError(InvalidResponse())
        }
    }
  }

  /** A response arrived whose type does not match the query that was sent. */
  case class InvalidResponse() extends Throwable
}
105 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/MessageSocket.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.effect.std.Semaphore
4 | import cats.effect.syntax.all.*
5 | import cats.effect.Async
6 | import cats.effect.Resource
7 | import cats.effect.Temporal
8 | import cats.syntax.all.*
9 | import com.github.torrentdam.bittorrent.*
10 | import com.github.torrentdam.bittorrent.protocol.message.{Handshake, Message}
11 | import fs2.io.net.Network
12 | import fs2.io.net.Socket
13 | import fs2.io.net.SocketGroup
14 | import fs2.Chunk
15 | import org.legogroup.woof.given
16 | import org.legogroup.woof.Logger
17 |
18 | import scala.concurrent.duration.*
19 | import scodec.bits.ByteVector
20 |
/** Length-prefixed peer-wire message framing over an established TCP socket.
  * Created via [[MessageSocket.connect]], after the handshake has completed.
  */
class MessageSocket[F[_]](
  val handshake: Handshake,
  val peerInfo: PeerInfo,
  socket: Socket[F],
  logger: Logger[F]
)(using F: Temporal[F]) {

  // NOTE(review): readTimeout/writeTimeout are imported here but never applied
  // in send/receive — they are only used during the companion's handshake.
  // Confirm whether per-message timeouts are intentionally handled elsewhere.
  import MessageSocket.readTimeout
  import MessageSocket.writeTimeout
  import MessageSocket.MaxMessageSize
  import MessageSocket.OversizedMessage

  /** Encodes and writes one message, then trace-logs it. */
  def send(message: Message): F[Unit] =
    val bytes = Chunk.byteVector(Message.MessageCodec.encode(message).require.toByteVector)
    for
      _ <- socket.write(bytes)
      _ <- logger.trace(s">>> ${peerInfo.address} $message")
    yield ()

  /** Reads one message: 4-byte length prefix, then the body.
    * A zero length is the keep-alive message (no body). Fails with
    * [[MessageSocket.OversizedMessage]] when the declared size exceeds
    * [[MessageSocket.MaxMessageSize]].
    */
  def receive: F[Message] =
    for
      bytes <- readExactlyN(4)
      size <-
        Message.MessageSizeCodec
          .decodeValue(bytes.toBitVector)
          .toTry
          .liftTo[F]
      // Guard against hostile/corrupt length prefixes before allocating.
      _ <- F.whenA(size > MaxMessageSize)(
        logger.error(s"Oversized payload $size $MaxMessageSize") >>
        OversizedMessage(size, MaxMessageSize).raiseError
      )
      message <-
        if (size == 0) F.pure(Message.KeepAlive)
        else
          readExactlyN(size.toInt).flatMap(bytes =>
            F.fromTry(
              Message.MessageBodyCodec
                .decodeValue(bytes.toBitVector)
                .toTry
            )
          )
      _ <- logger.trace(s"<<< ${peerInfo.address} $message")
    yield message

  /** Reads exactly `numBytes`; fails if the peer closed the connection early. */
  private def readExactlyN(numBytes: Int): F[ByteVector] =
    for
      chunk <- socket.readN(numBytes)
      _ <- if chunk.size == numBytes then F.unit else F.raiseError(new Exception("Connection was interrupted by peer"))
    yield chunk.toByteVector

}
72 |
object MessageSocket {

  /** Upper bound on a single message body; larger prefixes are rejected. */
  val MaxMessageSize: Long = 1024 * 1024 // 1MB
  // Timeouts applied during the handshake exchange below.
  val readTimeout = 1.minute
  val writeTimeout = 10.seconds

  /** Opens a TCP connection to `peerInfo` (5s connect timeout), performs the
    * BitTorrent handshake and yields a framed [[MessageSocket]].
    */
  def connect[F[_]](selfId: PeerId, peerInfo: PeerInfo, infoHash: InfoHash)(using
    F: Async[F],
    network: Network[F],
    logger: Logger[F]
  ): Resource[F, MessageSocket[F]] = {
    for
      socket <- network.client(to = peerInfo.address).timeout(5.seconds)
      // Log-only finalizer: fires after the socket resource is released.
      _ <- Resource.make(F.unit)(_ => logger.trace(s"Closed socket $peerInfo"))
      _ <- Resource.eval(logger.trace(s"Opened socket $peerInfo"))
      handshakeResponse <- Resource.eval(
        logger.trace(s"Initiate handshake with ${peerInfo.address}") *>
        handshake(selfId, infoHash, socket) <*
        logger.trace(s"Successful handshake with ${peerInfo.address}")
      )
    yield new MessageSocket(handshakeResponse, peerInfo, socket, logger)
  }

  /** Sends our handshake (with the extension-protocol bit set) and reads the
    * peer's fixed-size reply, enforcing [[writeTimeout]]/[[readTimeout]].
    *
    * Fails with [[Error]] on timeout, premature close or an undecodable reply.
    */
  def handshake[F[_]](
    selfId: PeerId,
    infoHash: InfoHash,
    socket: Socket[F]
  )(using F: Temporal[F]): F[Handshake] = {
    val message = Handshake(extensionProtocol = true, infoHash, selfId)
    for
      _ <- socket
        .write(
          bytes = Chunk.byteVector(
            Handshake.HandshakeCodec.encode(message).require.toByteVector
          )
        )
        .timeout(writeTimeout)
      // The handshake codec has an exact size bound, expressed in bits.
      handshakeMessageSize = Handshake.HandshakeCodec.sizeBound.exact.get.toInt / 8
      bytes <-
        socket
          .readN(handshakeMessageSize)
          .timeout(readTimeout)
          .adaptError(e =>
            Error("Unsuccessful handshake", e)
          )
      _ <-
        if bytes.size == handshakeMessageSize
        then F.unit
        else F.raiseError(Error("Unsuccessful handshake: connection prematurely closed"))
      response <- F.fromEither(
        Handshake.HandshakeCodec
          .decodeValue(bytes.toBitVector)
          .toEither
          .leftMap { e =>
            Error(s"Unable to decode handhshake reponse: ${e.message}")
          }
      )
    yield response
  }

  case class Error(message: String, cause: Throwable = null) extends Exception(message, cause)
  case class OversizedMessage(size: Long, maxSize: Long) extends Throwable(s"Oversized message [$size > $maxSize]")
}
136 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.implicits.*
4 | import cats.effect.std.Queue
5 | import cats.effect.std.Random
6 | import cats.effect.Async
7 | import cats.effect.IO
8 | import cats.effect.Resource
9 | import cats.effect.Sync
10 | import cats.implicits.*
11 | import com.comcast.ip4s.*
12 | import fs2.Stream
13 | import com.github.torrentdam.bittorrent.InfoHash
14 |
15 | import java.net.InetSocketAddress
16 | import org.legogroup.woof.given
17 | import org.legogroup.woof.Logger
18 | import scodec.bits.ByteVector
19 |
20 | import scala.concurrent.duration.DurationInt
21 |
22 | class Node(val id: NodeId, val client: Client, val routingTable: RoutingTable[IO], val discovery: PeerDiscovery)
23 |
object Node {

  /** Starts a full DHT node.
    *
    * Wires together the UDP socket, routing table, query client (wrapped so
    * every successful exchange inserts the counterpart into the table),
    * bootstrap, a periodic table refresh and a candidate-pinging loop.
    *
    * @param port                 local UDP port; picked automatically if None
    * @param bootstrapNodeAddress overrides the public bootstrap node list
    */
  def apply(
    port: Option[Port] = None,
    bootstrapNodeAddress: Option[SocketAddress[Host]] = None
  )(using
    random: Random[IO],
    logger: Logger[IO]
  ): Resource[IO, Node] =
    for
      selfId <- Resource.eval(NodeId.random[IO])
      messageSocket <- MessageSocket(port)
      routingTable <- RoutingTable[IO](selfId).toResource
      queryingNodes <- Queue.unbounded[IO, NodeInfo].toResource
      // Every node that queries us becomes a candidate for the routing table.
      queryHandler = reportingQueryHandler(queryingNodes, QueryHandler.simple(selfId, routingTable))
      client <- Client(selfId, messageSocket, queryHandler)
      insertingClient = new InsertingClient(client, routingTable)
      bootstrapNodes = bootstrapNodeAddress.map(List(_)).getOrElse(RoutingTableBootstrap.PublicBootstrapNodes)
      discovery = PeerDiscovery(routingTable, insertingClient)
      _ <- RoutingTableBootstrap(routingTable, insertingClient, discovery, bootstrapNodes).toResource
      _ <- RoutingTableRefresh(routingTable, client, discovery).runEvery(15.minutes).background
      _ <- pingCandidates(queryingNodes, client, routingTable).background
    yield new Node(selfId, insertingClient, routingTable, discovery)

  /** Client decorator: inserts the responding node into the routing table
    * after every successful query.
    */
  private class InsertingClient(client: Client, routingTable: RoutingTable[IO]) extends Client {

    def id: NodeId = client.id

    def getPeers(address: SocketAddress[IpAddress], infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] =
      client.getPeers(address, infoHash).flatTap { response =>
        routingTable.insert(
          NodeInfo(
            response match
              case Left(response) => response.id
              case Right(response) => response.id,
            address
          )
        )
      }

    def findNodes(address: SocketAddress[IpAddress], target: NodeId): IO[Response.Nodes] =
      client.findNodes(address, target).flatTap { response =>
        routingTable.insert(NodeInfo(response.id, address))
      }

    def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] =
      client.ping(address).flatTap { response =>
        routingTable.insert(NodeInfo(response.id, address))
      }

    def sampleInfoHashes(address: SocketAddress[IpAddress], target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
      client.sampleInfoHashes(address, target).flatTap { response =>
        routingTable.insert(
          response match
            case Left(response) => NodeInfo(response.id, address)
            case Right(response) => NodeInfo(response.id, address)
        )
      }

    override def toString: String = s"InsertingClient($client)"
  }

  /** Pings one candidate that is not yet in the table; inserts it on pong. */
  private def pingCandidate(node: NodeInfo, client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) =
    routingTable.lookup(node.id).flatMap {
      case Some(_) => IO.unit
      case None =>
        Logger[IO].info(s"Pinging $node") *>
        client.ping(node.address).timeout(5.seconds).attempt.flatMap {
          case Right(_) =>
            Logger[IO].info(s"Got pong from $node -- insert as good") *>
            routingTable.insert(node)
          case Left(_) => IO.unit
        }
    }

  /** Once a minute, drains queued candidates and pings each distinct one.
    * Individual failures are swallowed so the loop never dies.
    */
  private def pingCandidates(nodes: Queue[IO, NodeInfo], client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) =
    nodes
      .tryTakeN(none)
      .flatMap(candidates =>
        candidates
          .distinct
          .traverse_(pingCandidate(_, client, routingTable).attempt.void)
      )
      .productR(IO.sleep(1.minute))
      .foreverM


  /** QueryHandler decorator: records the sender of every query as a
    * routing-table candidate before delegating to `next`.
    */
  private def reportingQueryHandler(queue: Queue[IO, NodeInfo], next: QueryHandler[IO]): QueryHandler[IO] = (address, query) =>
    val nodeInfo = query match
      case Query.Ping(id) => NodeInfo(id, address)
      case Query.FindNode(id, _) => NodeInfo(id, address)
      case Query.GetPeers(id, _) => NodeInfo(id, address)
      case Query.AnnouncePeer(id, _, _) => NodeInfo(id, address)
      case Query.SampleInfoHashes(id, _) => NodeInfo(id, address)
    queue.offer(nodeInfo) *> next(address, query)
}
120 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/ExtensionHandler.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.effect.kernel.Deferred
4 | import cats.effect.kernel.Ref
5 | import cats.effect.std.Queue
6 | import cats.effect.Async
7 | import cats.effect.Concurrent
8 | import cats.effect.Sync
9 | import cats.implicits.*
10 | import cats.Applicative
11 | import cats.Monad
12 | import cats.MonadError
13 | import com.github.torrentdam.bittorrent.protocol.extensions.metadata.UtMessage
14 | import com.github.torrentdam.bittorrent.protocol.extensions.ExtensionHandshake
15 | import com.github.torrentdam.bittorrent.protocol.extensions.Extensions
16 | import com.github.torrentdam.bittorrent.protocol.extensions.Extensions.MessageId
17 | import com.github.torrentdam.bittorrent.InfoHash
18 | import com.github.torrentdam.bittorrent.TorrentMetadata.Lossless
19 | import com.github.torrentdam.bittorrent.protocol.message.Message
20 | import fs2.Stream
21 | import scodec.bits.ByteVector
22 | import com.github.torrentdam.bittorrent.CrossPlatform
23 |
/** Consumes inbound extension-protocol messages from a single peer. */
trait ExtensionHandler[F[_]] {

  def apply(message: Message.Extended): F[Unit]
}
28 |
object ExtensionHandler {

  /** Handler that ignores every message. */
  def noop[F[_]](using F: Applicative[F]): ExtensionHandler[F] = _ => F.unit

  /** Handler that delegates to whichever handler `get` currently yields,
    * allowing the active handler to be swapped at runtime.
    */
  def dynamic[F[_]: Monad](get: F[ExtensionHandler[F]]): ExtensionHandler[F] = message => get.flatMap(_(message))

  /** Sends an extended message to the peer. */
  type Send[F[_]] = Message.Extended => F[Unit]

  trait InitExtension[F[_]] {

    /** Sends our extension handshake and waits for the negotiated API. */
    def init: F[ExtensionApi[F]]
  }

  object InitExtension {

    /** Creates the pre-handshake handler/initiator pair.
      *
      * The returned handler initially accepts only the peer's extension
      * handshake; once it arrives, the handler is swapped (via a Ref) for the
      * negotiated one and the API is published through a Deferred, unblocking
      * [[InitExtension.init]].
      */
    def apply[F[_]](
      infoHash: InfoHash,
      send: Send[F],
      utMetadata: UtMetadata.Create[F]
    )(using F: Concurrent[F]): F[(ExtensionHandler[F], InitExtension[F])] =
      for
        apiDeferred <- F.deferred[ExtensionApi[F]]
        handlerRef <- F.ref[ExtensionHandler[F]](ExtensionHandler.noop)
        _ <- handlerRef.set(
          {
            case Message.Extended(MessageId.Handshake, payload) =>
              for
                handshake <- F.fromEither(ExtensionHandshake.decode(payload))
                (handler, extensionApi) <- ExtensionApi[F](infoHash, send, utMetadata, handshake)
                _ <- handlerRef.set(handler)
                _ <- apiDeferred.complete(extensionApi)
              yield ()
            case message =>
              // Any non-handshake message before the handshake is a protocol error.
              F.raiseError(InvalidMessage(s"Expected Handshake but received ${message.getClass.getSimpleName}"))
          }
        )
      yield

        val handler = dynamic(handlerRef.get)

        val api = new InitExtension[F] {

          def init: F[ExtensionApi[F]] = {
            val message: Message.Extended =
              Message.Extended(
                MessageId.Handshake,
                ExtensionHandshake.encode(Extensions.handshake)
              )
            // Blocks until the inbound handshake completes apiDeferred.
            send(message) >> apiDeferred.get
          }
        }

        (handler, api)
      end for
  }

  /** The set of extensions negotiated with this peer. */
  trait ExtensionApi[F[_]] {

    /** ut_metadata support, if the peer advertised it in its handshake. */
    def utMetadata: Option[UtMetadata[F]]
  }

  object ExtensionApi {

    /** Builds the post-handshake handler and API from the peer's handshake. */
    def apply[F[_]](
      infoHash: InfoHash,
      send: Send[F],
      utMetadata: UtMetadata.Create[F],
      handshake: ExtensionHandshake
    )(using F: MonadError[F, Throwable]): F[(ExtensionHandler[F], ExtensionApi[F])] = {
      for (utHandler, utMetadata0) <- utMetadata(infoHash, handshake, send)
      yield

        val handler: ExtensionHandler[F] = {
          case Message.Extended(Extensions.MessageId.Metadata, messageBytes) =>
            F.fromEither(UtMessage.decode(messageBytes)) >>= utHandler.apply
          case Message.Extended(id, _) =>
            F.raiseError(InvalidMessage(s"Unsupported message id=$id"))
        }

        val api: ExtensionApi[F] = new ExtensionApi[F] {
          def utMetadata: Option[UtMetadata[F]] = utMetadata0
        }

        (handler, api)
      end for
    }

  }

  /** ut_metadata extension: fetches the torrent's metadata from the peer. */
  trait UtMetadata[F[_]] {

    def fetch: F[Lossless]
  }

  object UtMetadata {

    /** Consumes inbound ut_metadata messages. */
    trait Handler[F[_]] {

      def apply(message: UtMessage): F[Unit]
    }

    object Handler {

      def unit[F[_]](using F: Applicative[F]): Handler[F] = _ => F.unit
    }

    class Create[F[_]](using F: Async[F]) {

      /** Builds the ut_metadata handler and (if the peer supports it) the
        * fetch API. Returns a no-op handler and None when the handshake lacks
        * the "ut_metadata" id or the metadata size.
        */
      def apply(
        infoHash: InfoHash,
        handshake: ExtensionHandshake,
        send: Message.Extended => F[Unit]
      ): F[(Handler[F], Option[UtMetadata[F]])] = {

        (handshake.extensions.get("ut_metadata"), handshake.metadataSize).tupled match {
          case Some((messageId, size)) =>
            // Bounded(1): fetch is strictly request/response, one piece at a time.
            for receiveQueue <- Queue.bounded[F, UtMessage](1)
            yield
              def sendUtMessage(utMessage: UtMessage) = {
                val message: Message.Extended = Message.Extended(messageId, UtMessage.encode(utMessage))
                send(message)
              }

              def receiveUtMessage: F[UtMessage] = receiveQueue.take

              (receiveQueue.offer, (new Impl(sendUtMessage, receiveUtMessage, size, infoHash)).some)
            end for

          case None =>
            (Handler.unit[F], Option.empty[UtMetadata[F]]).pure[F]

        }

      }
    }

    private class Impl[F[_]](
      send: UtMessage => F[Unit],
      receive: F[UtMessage],
      size: Long,
      infoHash: InfoHash
    )(using F: Sync[F])
        extends UtMetadata[F] {

      /** Requests metadata pieces 0..99 in order, concatenates the returned
        * bytes until `size` is reached, then verifies the SHA-1 of the result
        * against the info-hash before decoding.
        * NOTE(review): the hard cap of 100 pieces bounds the metadata this can
        * fetch — confirm that is sufficient for the torrents in scope.
        */
      def fetch: F[Lossless] =
        Stream
          .range(0, 100)
          .evalMap { index =>
            send(UtMessage.Request(index)) *> receive.flatMap {
              case UtMessage.Data(`index`, bytes) => bytes.pure[F]
              case m =>
                F.raiseError[ByteVector](
                  InvalidMessage(s"Data message expected but received ${m.getClass.getSimpleName}")
                )
            }
          }
          .scan(ByteVector.empty)(_ ++ _)
          .find(_.size >= size)
          .compile
          .lastOrError
          // Reject metadata whose hash does not match the torrent's info-hash.
          .ensure(InvalidMetadata()) { metadata =>
            CrossPlatform.sha1(metadata) == infoHash.bytes
          }
          .flatMap { bytes =>
            Lossless.fromBytes(bytes).liftTo[F]
          }
    }
  }

  case class InvalidMessage(message: String) extends Throwable(message)
  case class InvalidMetadata() extends Throwable
}
201 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTable.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.kernel.Concurrent
4 | import cats.effect.kernel.Ref
5 | import cats.effect.Sync
6 | import cats.implicits.*
7 | import com.comcast.ip4s.*
8 | import com.github.torrentdam.bittorrent.InfoHash
9 | import com.github.torrentdam.bittorrent.PeerInfo
10 |
11 | import scala.collection.immutable.ListMap
12 | import scodec.bits.ByteVector
13 |
14 | import scala.annotation.tailrec
15 |
/** Kademlia-style routing table plus per-infohash peer storage. */
trait RoutingTable[F[_]] {

  /** Inserts or refreshes a node. */
  def insert(node: NodeInfo): F[Unit]

  /** Removes the node with `nodeId`, if present. */
  def remove(nodeId: NodeId): F[Unit]

  /** Good nodes ordered by usefulness for reaching `nodeId`. */
  def goodNodes(nodeId: NodeId): F[Iterable[NodeInfo]]

  /** Remembers that `peerInfo` is serving `infoHash`. */
  def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit]

  /** Known peers for `infoHash`; None when we have none stored. */
  def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]]

  /** Every node currently in the table. */
  def allNodes: F[Iterable[RoutingTable.Node]]

  /** The table's leaf buckets. */
  def buckets: F[Iterable[RoutingTable.TreeNode.Bucket]]

  /** Marks the given node ids as good/bad in bulk. */
  def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit]

  /** Looks up a single node by id. */
  def lookup(nodeId: NodeId): F[Option[RoutingTable.Node]]
}
36 |
object RoutingTable {

  /** Binary tree over the 160-bit id space: `Split` partitions at `center`;
    * leaves are `Bucket`s holding up to [[MaxNodes]] nodes for ids in [from, until).
    */
  enum TreeNode:
    case Split(center: BigInt, lower: TreeNode, higher: TreeNode)
    case Bucket(from: BigInt, until: BigInt, nodes: Map[NodeId, Node])

  /** A known DHT node with its liveness bookkeeping. */
  case class Node(id: NodeId, address: SocketAddress[IpAddress], isGood: Boolean, badCount: Int = 0):
    def toNodeInfo: NodeInfo = NodeInfo(id, address)

  object TreeNode {

    /** A single empty bucket spanning the whole id space (until = 2^160 - 1). */
    def empty: TreeNode =
      TreeNode.Bucket(
        from = BigInt(0),
        until = BigInt(1, ByteVector.fill(20)(-1: Byte).toArray),
        Map.empty
      )
  }

  /** Maximum nodes per bucket — the "k" constant of BEP 5. */
  val MaxNodes = 8

  import TreeNode.*

  extension (bucket: TreeNode)
    /** Inserts `node` into the subtree.
      *
      * A full bucket is split only when it covers `selfId`; otherwise one bad node
      * is evicted to make room, and if all residents are good the new node is dropped.
      * Re-inserting an existing id refreshes it as good.
      */
    def insert(node: NodeInfo, selfId: NodeId): TreeNode =
      bucket match
        case b @ Split(center, lower, higher) =>
          if (node.id.int < center)
            b.copy(lower = lower.insert(node, selfId))
          else
            b.copy(higher = higher.insert(node, selfId))
        case b @ Bucket(from, until, nodes) =>
          // NOTE(review): `nodes.contains(selfId)` can only hold if the self node was
          // itself inserted; nothing in this file does that — confirm whether the guard
          // is reachable.
          if nodes.size >= MaxNodes && !nodes.contains(selfId)
          then
            if selfId.int >= from && selfId.int < until
            then
              // split the bucket because it contains the self node
              val center = (from + until) / 2
              val splitNode =
                Split(
                  center,
                  lower = Bucket(from, center, nodes.view.filterKeys(_.int < center).to(ListMap)),
                  higher = Bucket(center, until, nodes.view.filterKeys(_.int >= center).to(ListMap))
                )
              splitNode.insert(node, selfId)
            else
              // drop one node from the bucket
              val badNode = nodes.values.find(!_.isGood)
              badNode match
                case Some(badNode) => Bucket(from, until, nodes.removed(badNode.id)).insert(node, selfId)
                case None => b
          else
            Bucket(from, until, nodes.updated(node.id, Node(node.id, node.address, isGood = true)))

    /** Removes the node; an emptied bucket is merged into its sibling bucket when possible. */
    def remove(nodeId: NodeId): TreeNode =
      bucket match
        case b @ Split(center, lower, higher) =>
          if (nodeId.int < center)
            (lower.remove(nodeId), higher) match {
              case (Bucket(lowerFrom, _, nodes), finalHigher: Bucket) if nodes.isEmpty =>
                // lower side emptied — widen the sibling to cover its range
                finalHigher.copy(from = lowerFrom)
              case (l, _) =>
                b.copy(lower = l)
            }
          else
            (higher.remove(nodeId), lower) match {
              case (Bucket(_, higherUntil, nodes), finalLower: Bucket) if nodes.isEmpty =>
                finalLower.copy(until = higherUntil)
              case (h, _) =>
                b.copy(higher = h)
            }
        case b @ Bucket(_, _, nodes) =>
          b.copy(nodes = nodes - nodeId)

    /** Leaf bucket whose range covers `nodeId`. */
    @tailrec
    def findBucket(nodeId: NodeId): Bucket =
      bucket match
        case Split(center, lower, higher) =>
          if (nodeId.int < center)
            lower.findBucket(nodeId)
          else
            higher.findBucket(nodeId)
        case b: Bucket => b

    /** All nodes, visiting the subtree containing `nodeId` first (closest-first order). */
    def findNodes(nodeId: NodeId): Iterable[Node] =
      bucket match
        case Split(center, lower, higher) =>
          if (nodeId.int < center)
            lower.findNodes(nodeId) ++ higher.findNodes(nodeId)
          else
            higher.findNodes(nodeId) ++ lower.findNodes(nodeId)
        case b: Bucket => b.nodes.values.to(LazyList)

    /** All leaf buckets, left to right. */
    def buckets: Iterable[Bucket] =
      bucket match
        case b: Bucket => Iterable(b)
        case Split(_, lower, higher) => lower.buckets ++ higher.buckets

    /** Applies `fn` to every stored node, preserving the tree shape. */
    def update(fn: Node => Node): TreeNode =
      bucket match
        case b @ Split(_, lower, higher) =>
          b.copy(lower = lower.update(fn), higher = higher.update(fn))
        case b @ Bucket(from, until, nodes) =>
          b.copy(nodes = nodes.view.mapValues(fn).to(ListMap))

  end extension

  /** Creates an empty, Ref-backed routing table for the node with id `selfId`. */
  def apply[F[_]: Concurrent](selfId: NodeId): F[RoutingTable[F]] =
    for {
      treeNodeRef <- Ref.of(TreeNode.empty)
      peers <- Ref.of(Map.empty[InfoHash, Set[PeerInfo]])
    } yield new RoutingTable[F] {

      def insert(node: NodeInfo): F[Unit] =
        treeNodeRef.update(_.insert(node, selfId))

      def remove(nodeId: NodeId): F[Unit] =
        treeNodeRef.update(_.remove(nodeId))

      def goodNodes(nodeId: NodeId): F[Iterable[NodeInfo]] =
        treeNodeRef.get.map(_.findNodes(nodeId).filter(_.isGood).map(_.toNodeInfo))

      def addPeer(infoHash: InfoHash, peerInfo: PeerInfo): F[Unit] =
        peers.update { map =>
          map.updatedWith(infoHash) {
            case Some(set) => Some(set + peerInfo)
            case None => Some(Set(peerInfo))
          }
        }

      def findPeers(infoHash: InfoHash): F[Option[Iterable[PeerInfo]]] =
        peers.get.map(_.get(infoHash))

      def allNodes: F[Iterable[Node]] =
        treeNodeRef.get.map(_.findNodes(selfId))

      def buckets: F[Iterable[TreeNode.Bucket]] =
        treeNodeRef.get.map(_.buckets)

      def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit] =
        treeNodeRef.update(
          _.update(node =>
            if good.contains(node.id) then node.copy(isGood = true, badCount = 0)
            else if bad.contains(node.id) then node.copy(isGood = false, badCount = node.badCount + 1)
            else node
          )
        )

      def lookup(nodeId: NodeId): F[Option[Node]] =
        treeNodeRef.get.map(_.findBucket(nodeId).nodes.get(nodeId))
    }
}
189 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Download.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.effect.implicits.*
4 | import cats.effect.kernel.Ref
5 | import cats.effect.std.Queue
6 | import cats.effect.std.Semaphore
7 | import cats.effect.std.Supervisor
8 | import cats.effect.Async
9 | import cats.effect.IO
10 | import cats.effect.Resource
11 | import cats.effect.Temporal
12 | import cats.implicits.*
13 | import cats.Show.Shown
14 | import com.github.torrentdam.bittorrent.TorrentMetadata
15 | import com.github.torrentdam.bittorrent.protocol.message.Message
16 | import fs2.concurrent.Signal
17 | import fs2.concurrent.SignallingRef
18 | import fs2.concurrent.Topic
19 | import fs2.Chunk
20 | import fs2.Stream
21 | import org.legogroup.woof.given
22 | import org.legogroup.woof.Logger
23 |
24 | import scala.collection.BitSet
25 | import scala.concurrent.duration.*
26 | import scala.util.chaining.*
27 | import scodec.bits.ByteVector
28 |
object Download {

  /** Runs the download side of a torrent session.
    *
    * Keeps up to 30 parallel peer connections drawn from the swarm; each slot
    * reconnects forever on failure. Connections are speed-classified so the
    * dispatcher can route work differently to fast and slow peers.
    *
    * @param swarm       source of peer connections
    * @param piecePicker dispatcher handing out block requests
    */
  def apply(
      swarm: Swarm,
      piecePicker: RequestDispatcher
  )(using
      logger: Logger[IO]
  ): IO[Unit] =
    import Logger.withLogContext
    Classifier().use(classifier =>
      swarm.connect
        .use(connection =>
          (
            for
              _ <- connection.interested
              _ <- classifier.create.use(speedInfo =>
                whenUnchoked(connection)(
                  download(connection, piecePicker, speedInfo)
                )
              )
            yield ()
          )
            .race(connection.disconnected)
            .withLogContext("address", connection.info.address.toString)
        )
        .attempt
        .foreverM
        .parReplicateA(30)
        .void
    )

  /** Per-connection download loop: pipelines block requests, adapting the number of
    * outstanding requests to the throughput measured over 10-second windows.
    */
  private def download(
      connection: Connection,
      pieces: RequestDispatcher,
      speedInfo: SpeedData
  )(using logger: Logger[IO]): IO[Unit] = {

    // Clamp n into [min, max].
    def bounded(min: Int, max: Int)(n: Int): Int = math.min(math.max(n, min), max)

    // Every 10s: derive the pipeline depth from the bytes downloaded in that window
    // (clamped to 1..100 chunk-equivalents per second) and publish it.
    def computeOutstanding(
        downloadedBytes: Topic[IO, Long],
        downloadedTotal: Ref[IO, Long],
        maxOutstanding: SignallingRef[IO, Int]
    ) =
      downloadedBytes
        .subscribe(10)
        .evalTap(size => downloadedTotal.update(_ + size))
        .groupWithin(Int.MaxValue, 10.seconds)
        .map(chunks =>
          bounded(1, 100)(
            (chunks.foldLeft(0L)(_ + _) / 10 / RequestDispatcher.ChunkSize).toInt
          )
        )
        .evalTap(maxOutstanding.set)
        // NOTE(review): despite the name, `bytes` receives this clamped per-second
        // chunk figure, not a raw byte count — the classifier only needs the ordering.
        .evalTap(speedInfo.bytes.set)
        .compile
        .drain

    // Mirror changes of `maxOutstanding` into the semaphore's permit count.
    def updateSemaphore(semaphore: Semaphore[IO], maxOutstanding: SignallingRef[IO, Int]) =
      maxOutstanding.discrete
        .sliding(2)
        .evalMap { chunk =>
          val (prev, next) = (chunk(0), chunk(1))
          if prev < next then semaphore.releaseN(next - prev)
          else semaphore.acquireN(prev - next)
        }
        .compile
        .drain

    // Send one block request with a hard 5-second deadline.
    def sendRequest(request: Message.Request): IO[ByteVector] =
      logger.trace(s"Request $request") >>
        connection
          .request(request)
          .timeout(5.seconds)

    // Acquire a pipeline permit, then lease the next request matching this peer's
    // availability and speed class.
    def nextRequest(semaphore: Semaphore[IO]) =
      semaphore.permit >> pieces.stream(connection.availability.get, speedInfo.cls.get)

    // Issue leased requests concurrently; give up after 10 consecutive failures or
    // after a minute with no work available.
    def fireRequests(
        semaphore: Semaphore[IO],
        failureCounter: Ref[IO, Int],
        downloadedBytes: Topic[IO, Long]
    ) =
      Stream
        .resource(nextRequest(semaphore))
        .interruptWhen(
          IO.sleep(1.minute).as(Left(Error.TimeoutWaitingForPiece(1.minute)))
        )
        .repeat
        .map { (request, promise) =>
          Stream.eval(
            sendRequest(request).attempt.flatMap {
              case Right(bytes) =>
                failureCounter.set(0) >> downloadedBytes.publish1(bytes.size) >> promise.complete(bytes)
              case Left(_) =>
                // Failed lease is requeued by the dispatcher on release; just count it.
                failureCounter
                  .updateAndGet(_ + 1)
                  .flatMap(count =>
                    if count >= 10
                    then IO.raiseError(Error.PeerDoesNotRespond())
                    else IO.unit
                  )
            }
          )
        }
        .parJoinUnbounded
        .compile
        .drain

    for
      failureCounter <- IO.ref(0)
      downloadedBytes <- Topic[IO, Long]
      downloadedTotal <- IO.ref(0L)
      maxOutstanding <- SignallingRef[IO, Int](5)
      semaphore <- Semaphore[IO](5)
      _ <- (
        computeOutstanding(downloadedBytes, downloadedTotal, maxOutstanding),
        updateSemaphore(semaphore, maxOutstanding),
        fireRequests(semaphore, failureCounter, downloadedBytes)
      ).parTupled
        .handleErrorWith(e =>
          // Propagate the failure only if this peer never delivered any data.
          downloadedTotal.get.flatMap {
            case 0 => IO.raiseError(e)
            case _ => IO.unit
          }
        )
    yield ()
  }

  // Run `f` only while unchoked; when choked again, wait and restart the cycle.
  // Fails if the peer never unchokes within 30 seconds.
  private def whenUnchoked(connection: Connection)(f: IO[Unit])(using
      logger: Logger[IO]
  ): IO[Unit] = {
    def waitChoked = connection.choked.waitUntil(identity)
    def waitUnchoked =
      connection.choked
        .waitUntil(choked => !choked)
        .timeoutTo(30.seconds, IO.raiseError(Error.TimeoutWaitingForUnchoke(30.seconds)))

    (waitUnchoked >> (f race waitChoked)).foreverM
  }

  // Per-connection speed metric (see NOTE in computeOutstanding) and derived class.
  private case class SpeedData(bytes: Ref[IO, Int], cls: Ref[IO, SpeedClass])

  // Registry of live connections' speed data, keyed by an internal counter id.
  private class Classifier(counter: Ref[IO, Long], state: Ref[IO, Map[Long, SpeedData]]) {
    // Registers a new connection; deregistration happens on resource release.
    def create: Resource[IO, SpeedData] = Resource(
      for
        id <- counter.getAndUpdate(_ + 1)
        bytes <- IO.ref(0)
        cls <- IO.ref(SpeedClass.Slow)
        _ <- state.update(_.updated(id, SpeedData(bytes, cls)))
      yield (SpeedData(bytes, cls), state.update(_ - id))
    )
  }
  private object Classifier {
    // Re-classifies all tracked connections every 10 seconds in the background.
    def apply(): Resource[IO, Classifier] =
      for
        counter <- Resource.eval(IO.ref(0L))
        state <- Resource.eval(IO.ref(Map.empty[Long, SpeedData]))
        _ <- (IO.sleep(10.seconds) >> updateClass(state)).foreverM.background
      yield new Classifier(counter, state)

    // Top 70% of connections by the speed metric become Fast, the rest Slow.
    private def updateClass(state: Ref[IO, Map[Long, SpeedData]]): IO[Unit] =
      state.get
        .flatMap(_.values.toList.traverse { info =>
          info.bytes.get.tupleRight(info)
        })
        .flatMap(values =>
          val sorted = values.sortBy(_._1)(using Ordering[Int].reverse).map(_._2)
          val fastCount = (values.size.toDouble * 0.7).ceil.toInt
          val (fast, slow) = sorted.splitAt(fastCount)
          fast.traverse(_.cls.set(SpeedClass.Fast)) >> slow.traverse(_.cls.set(SpeedClass.Slow))
        )
        .void
  }

  enum Error(message: String) extends Throwable(message):
    case TimeoutWaitingForUnchoke(duration: FiniteDuration) extends Error(s"Unchoke timeout $duration")
    case TimeoutWaitingForPiece(duration: FiniteDuration) extends Error(s"Block request timeout $duration")
    case InvalidChecksum() extends Error("Invalid checksum")
    case PeerDoesNotRespond() extends Error("Peer does not respond")
}
210 |
/** Relative download-speed category assigned to a peer connection. */
enum SpeedClass:
  case Slow
  case Fast
214 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/RequestDispatcher.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
import cats.effect.cps.*
import cats.data.Chain
import cats.effect.kernel.Deferred
import cats.effect.std.Dequeue
import cats.effect.std.Semaphore
import cats.effect.IO
import cats.effect.Ref
import cats.effect.Resource
import cats.implicits.*
import com.comcast.ip4s.*
import com.github.torrentdam.bittorrent.protocol.message.Message.Request
import com.github.torrentdam.bittorrent.PeerInfo
import com.github.torrentdam.bittorrent.TorrentMetadata
import com.github.torrentdam.bittorrent.protocol.message.Message
import com.github.torrentdam.bittorrent.CrossPlatform
import fs2.concurrent.Signal
import fs2.concurrent.SignallingRef
import fs2.Stream
import java.util.UUID
import org.legogroup.woof.given
import org.legogroup.woof.Logger

import scala.annotation.tailrec
import scala.collection.immutable.BitSet
import scala.collection.immutable.TreeMap
import scala.concurrent.duration.DurationInt
import scodec.bits.ByteVector
29 |
/** Hands out block requests for in-progress pieces to peer connections. */
trait RequestDispatcher {
  /** Downloads and checksum-verifies the piece with the given index. */
  def downloadPiece(index: Long): IO[ByteVector]

  /** Leases the next block request suitable for a peer with the given piece
    * availability and speed class. The Deferred must be completed with the block
    * bytes; releasing the resource without completing it returns the request
    * to the queue.
    */
  def stream(
      availability: IO[BitSet],
      speedClass: IO[SpeedClass]
  ): Resource[IO, (Message.Request, Deferred[IO, ByteVector])]
}
38 |
object RequestDispatcher {
  /** Size of a single block request: 16 KiB, the de-facto standard block size. */
  val ChunkSize: Int = 16 * 1024

  private type RequestQueue = WorkQueue[Message.Request, ByteVector]

  /** Creates a dispatcher for the given torrent.
    *
    * In-progress pieces are indexed twice — ascending and descending by piece index —
    * so fast and slow peers pull work from opposite ends and rarely contend for the
    * same piece.
    */
  def apply(metadata: TorrentMetadata)(using logger: Logger[IO]): Resource[IO, RequestDispatcher] =
    for
      queue <- Resource.eval(IO.ref(TreeMap.empty[Long, RequestQueue]))
      queueReverse <- Resource.eval(IO.ref(TreeMap.empty[Long, RequestQueue](using Ordering[Long].reverse)))
      workGenerator = WorkGenerator(metadata)
    yield Impl(workGenerator, queue, queueReverse)

  private class Impl(
      workGenerator: WorkGenerator,
      queue: Ref[IO, TreeMap[Long, RequestQueue]],
      queueReverse: Ref[IO, TreeMap[Long, RequestQueue]]
  )(using Logger[IO]) extends RequestDispatcher {

    /** Downloads one piece: enqueues its block requests, waits for every block,
      * reassembles them in offset order and verifies the SHA-1 checksum.
      *
      * Bug fix: previously a checksum failure triggered exactly one retry whose
      * result was returned WITHOUT verification; now every attempt is verified and
      * the piece is re-downloaded until it validates.
      */
    def downloadPiece(index: Long): IO[ByteVector] =
      val pieceWork = workGenerator.pieceWork(index)
      val attempt = async[IO] {
        try
          val result = IO.deferred[Map[Request, ByteVector]].await
          val requestQueue = WorkQueue(pieceWork.requests.toList, result.complete).await
          queue.update(_.updated(index, requestQueue)).await
          queueReverse.update(_.updated(index, requestQueue)).await
          val completedBlocks = result.get.await
          // Blocks complete out of order — sort by offset before concatenating.
          val pieceBytes = completedBlocks.toList.sortBy(_._1.begin).map(_._2).foldLeft(ByteVector.empty)(_ ++ _)
          pieceBytes
        finally
          // Always unregister the piece, even on failure or cancellation.
          queue.update(_ - index).await
          queueReverse.update(_ - index).await
      }
      def verifiedAttempt: IO[ByteVector] =
        attempt.flatMap(bytes =>
          if CrossPlatform.sha1(bytes) == pieceWork.checksum then IO.pure(bytes)
          else Logger[IO].warn(s"Piece $index failed checksum") >> verifiedAttempt
        )
      verifiedAttempt

    /** Leases the next block request matching the peer's availability and speed
      * class, polling once per second until one becomes available.
      */
    def stream(
        availability: IO[BitSet],
        speedClass: IO[SpeedClass]
    ): Resource[IO, (Message.Request, Deferred[IO, ByteVector])] =
      // Try each piece queue in order until one yields a request.
      def pickFrom(trackers: List[RequestQueue]): Resource[IO, (Message.Request, Deferred[IO, ByteVector])] =
        trackers match
          case Nil =>
            Resource.eval(IO.raiseError(NoPieceAvailable))
          case tracker :: rest =>
            tracker.nextRequest.recoverWith(_ => pickFrom(rest))

      def singlePass =
        for
          speedClass <- Resource.eval(speedClass)
          inProgress <- Resource.eval(
            speedClass match
              case SpeedClass.Fast => queue.get
              case SpeedClass.Slow => queueReverse.get // take piece with the highest index first
          )
          availability <- Resource.eval(availability)
          matched = inProgress.collect { case (index, tracker) if availability(index.toInt) => tracker }.toList
          result <- pickFrom(matched)
        yield result

      def polling: Resource[IO, (Message.Request, Deferred[IO, ByteVector])] =
        singlePass.recoverWith(_ => Resource.eval(IO.sleep(1.seconds)) >> polling)

      polling
  }

  /** Computes per-piece work units (length, checksum, block requests) for a torrent. */
  class WorkGenerator(pieceLength: Long, totalLength: Long, pieces: ByteVector) {
    def this(metadata: TorrentMetadata) =
      this(
        metadata.pieceLength,
        metadata.files.map(_.length).sum,
        metadata.pieces
      )

    /** Work for piece `index`; the final piece may be shorter than `pieceLength`. */
    def pieceWork(index: Long): PieceWork =
      val thisPieceLength = math.min(pieceLength, totalLength - index * pieceLength)
      PieceWork(
        thisPieceLength,
        pieces.drop(index * 20).take(20), // SHA-1 digests are 20 bytes each
        genRequests(index, thisPieceLength)
      )

    /** Splits a piece into ChunkSize-d block requests; the last block may be shorter.
      * Rewritten with a tail-recursive accumulator instead of a mutable `var`.
      */
    def genRequests(pieceIndex: Long, pieceLength: Long): Chain[Message.Request] =
      @tailrec
      def loop(requestIndex: Long, acc: Chain[Message.Request]): Chain[Message.Request] =
        val thisChunkSize = math.min(ChunkSize, pieceLength - requestIndex * ChunkSize)
        if thisChunkSize > 0 then
          val begin = requestIndex * ChunkSize
          loop(requestIndex + 1, acc.append(Message.Request(pieceIndex, begin, thisChunkSize)))
        else acc
      loop(0, Chain.empty)
  }

  /** A piece's total size, expected SHA-1 checksum and covering block requests. */
  case class PieceWork(
      size: Long,
      checksum: ByteVector,
      requests: Chain[Message.Request]
  )

  case object NoPieceAvailable extends Throwable("No piece available")

  /** Queue of work items; each item is leased together with a Deferred for its
    * result and is automatically re-queued if the lease is released unfulfilled.
    */
  trait WorkQueue[Work, Result] {
    def nextRequest: Resource[IO, (Work, Deferred[IO, Result])]
  }

  object WorkQueue {

    /** @param requests   non-empty list of work items
      * @param onComplete invoked once with all results when the last item completes
      */
    def apply[Request, Response](
        requests: Seq[Request],
        onComplete: Map[Request, Response] => IO[Any]
    ): IO[WorkQueue[Request, Response]] =
      require(requests.nonEmpty)
      for
        requestQueue <- Dequeue.unbounded[IO, Request]
        _ <- requests.traverse(requestQueue.offer)
        responses <- IO.ref(Map.empty[Request, Response])
        outstandingCount <- IO.ref(requests.size)
      yield new {
        override def nextRequest: Resource[IO, (Request, Deferred[IO, Response])] =
          Resource(
            for
              _ <- outstandingCount.get.flatMap(n => IO.raiseWhen(n == 0)(PieceComplete))
              request <- requestQueue.tryTake
              request <- request match
                case Some(request) => IO.pure(request)
                case None => IO.raiseError(EmptyQueue)
              promise <- IO.deferred[Response]
            yield (
              (request, promise),
              // Release action: record a fulfilled promise, or return the request
              // to the front of the queue if the lease was abandoned.
              for
                bytes <- promise.tryGet
                _ <- bytes match
                  case Some(bytes) =>
                    for
                      _ <- outstandingCount.update(_ - 1)
                      result <- responses.updateAndGet(_.updated(request, bytes))
                      _ <- outstandingCount.get.flatMap(n => IO.whenA(n == 0)(onComplete(result).void))
                    yield ()
                  case None =>
                    requestQueue.offerFront(request)
              yield ()
            )
          )
      }

    case object EmptyQueue extends Throwable
    case object PieceComplete extends Throwable
  }
}
200 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PeerDiscovery.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.effect.kernel.Deferred
4 | import cats.effect.kernel.Ref
5 | import cats.effect.Concurrent
6 | import cats.effect.IO
7 | import cats.effect.Resource
8 | import cats.instances.all.*
9 | import cats.syntax.all.*
10 | import cats.effect.cps.{given, *}
11 | import cats.Show.Shown
12 | import com.github.torrentdam.bittorrent.InfoHash
13 | import com.github.torrentdam.bittorrent.PeerInfo
14 | import fs2.Stream
15 | import org.legogroup.woof.given
16 | import org.legogroup.woof.Logger
17 |
18 | import scala.concurrent.duration.DurationInt
19 | import Logger.withLogContext
20 | import com.comcast.ip4s.{IpAddress, SocketAddress}
21 |
/** High-level DHT lookups: peers for an info-hash and nodes close to an id. */
trait PeerDiscovery {

  /** Emits peers announcing the given info-hash as they are discovered (deduplicated). */
  def discover(infoHash: InfoHash): Stream[IO, PeerInfo]

  // NOTE(review): parameter is named `NodeId`, shadowing the type and violating
  // lowerCamelCase; the implementation names it `nodeId`. Renaming here would change
  // the named-argument interface, so flagging only.
  def findNodes(NodeId: NodeId): Stream[IO, NodeInfo]
}
28 |
object PeerDiscovery {

  /** Creates a [[PeerDiscovery]] backed by the given routing table and DHT client. */
  def apply(
      routingTable: RoutingTable[IO],
      dhtClient: Client
  )(using
      logger: Logger[IO]
  ): PeerDiscovery = new {
    /** Iterative get_peers crawl: seeded with up to 16 good nodes from our routing
      * table, then walks the DHT emitting newly seen peers as they are found.
      */
    def discover(infoHash: InfoHash): Stream[IO, PeerInfo] = {
      Stream
        .eval {
          for {
            _ <- logger.info("Start discovery")
            initialNodes <- routingTable.goodNodes(NodeId(infoHash.bytes))
            initialNodes <- initialNodes.take(16).toList.pure[IO]
            _ <- logger.info(s"Received ${initialNodes.size} from own routing table")
            state <- DiscoveryState(initialNodes, infoHash)
          } yield {
            start(
              infoHash,
              dhtClient.getPeers,
              state
            )
          }
        }
        .flatten
        .onFinalizeCase {
          case Resource.ExitCase.Errored(e) => logger.error(s"Discovery failed with ${e.getMessage}")
          case _ => IO.unit
        }
    }

    /** Iterative find_node lookup converging towards `nodeId`, closest nodes first. */
    def findNodes(nodeId: NodeId): Stream[IO, NodeInfo] =
      Stream
        .eval(
          for
            _ <- logger.info(s"Start finding nodes for $nodeId")
            initialNodes <- routingTable.goodNodes(nodeId)
            initialNodes <- initialNodes
              .take(16)
              .toList
              .sortBy(nodeInfo => NodeId.distance(nodeInfo.id, dhtClient.id))
              .pure[IO]
          yield
            FindNodesState(nodeId, initialNodes)
        )
        .flatMap { state =>
          Stream
            .unfoldEval(state)(_.next)
            .flatMap(Stream.emits)
        }

    /** One round of a find_node crawl.
      *
      * @param targetId       id the crawl converges on
      * @param nodesToQuery   frontier to contact in the next round
      * @param usedNodes      nodes already contacted (deduplication)
      * @param respondedCount total responders so far
      */
    case class FindNodesState(
      targetId: NodeId,
      nodesToQuery: List[NodeInfo],
      usedNodes: Set[NodeInfo] = Set.empty,
      respondedCount: Int = 0
    ):
      // Queries the frontier in parallel (5s timeout each, failures treated as no
      // answer) and emits responders; the next frontier is the 10 closest new nodes.
      def next: IO[Option[(List[NodeInfo], FindNodesState)]] = async[IO]:
        if nodesToQuery.isEmpty then
          none
        else
          val responses = nodesToQuery
            .parTraverse(nodeInfo =>
              dhtClient
                .findNodes(nodeInfo.address, targetId)
                .map(_.nodes.some)
                .timeout(5.seconds)
                .orElse(none.pure[IO])
                .tupleLeft(nodeInfo)
            )
            .await
          val respondedNodes = responses.collect { case (nodeInfo, Some(_)) => nodeInfo }
          val foundNodes = responses.collect { case (_, Some(nodes)) => nodes }.flatten
          // After enough responses, only chase nodes strictly closer than the
          // current best candidate.
          val threshold =
            if respondedCount > 10
            then NodeId.distance(nodesToQuery.head.id, targetId)
            else NodeId.MaxValue
          val closeNodes = foundNodes
            .filterNot(usedNodes)
            .distinct
            .filter(nodeInfo => NodeId.distance(nodeInfo.id, targetId) < threshold)
            .sortBy(nodeInfo => NodeId.distance(nodeInfo.id, targetId))
            .take(10)
          (
            respondedNodes,
            copy(
              nodesToQuery = closeNodes,
              usedNodes = usedNodes ++ respondedNodes,
              respondedCount = respondedCount + respondedNodes.size)
          ).some
  }

  /** Drives a get_peers crawl: queries nodes from `state` with bounded parallelism,
    * feeding returned nodes back into the state and emitting newly seen peers.
    */
  private[dht] def start(
    infoHash: InfoHash,
    getPeers: (SocketAddress[IpAddress], InfoHash) => IO[Either[Response.Nodes, Response.Peers]],
    state: DiscoveryState,
    parallelism: Int = 10
  )(using
    logger: Logger[IO]
  ): Stream[IO, PeerInfo] = {

    Stream
      .repeatEval(state.next)
      .parEvalMapUnordered(parallelism) { nodeInfo =>
        getPeers(nodeInfo.address, infoHash).timeout(5.seconds).attempt <* logger.trace(s"Get peers $nodeInfo")
      }
      .flatMap {
        case Right(response) =>
          response match {
            case Left(Response.Nodes(_, nodes)) =>
              // Closer nodes returned — feed them back into the crawl, emit nothing.
              Stream
                .eval(state.addNodes(nodes)) >> Stream.empty
            case Right(Response.Peers(_, peers)) =>
              Stream
                .eval(state.addPeers(peers))
                .flatMap(newPeers => Stream.emits(newPeers))
          }
        case Left(_) =>
          // Unreachable or failing node — keep crawling.
          Stream.empty
      }
  }

  /** Shared crawl state: closest-first frontier, dedup sets and fibers waiting for work. */
  class DiscoveryState(ref: Ref[IO, DiscoveryState.Data], infoHash: InfoHash) {

    // Pops the closest untried node, or suspends on a Deferred until addNodes supplies one.
    def next: IO[NodeInfo] =
      IO.deferred[NodeInfo]
        .flatMap { deferred =>
          ref.modify { state =>
            state.nodesToTry match {
              case x :: xs => (state.copy(nodesToTry = xs), x.pure[IO])
              case _ =>
                (state.copy(waiters = deferred :: state.waiters), deferred.get)
            }
          }.flatten
        }

    // Merges newly discovered nodes into the frontier (sorted by distance to the
    // info-hash), handing as many as possible directly to waiting fibers.
    def addNodes(nodes: List[NodeInfo]): IO[Unit] = {
      ref.modify { state =>
        val newNodes = nodes.filterNot(state.seenNodes)
        val seenNodes = state.seenNodes ++ newNodes
        val nodesToTry = (newNodes ++ state.nodesToTry).sortBy(n => NodeId.distance(n.id, infoHash))
        val waiters = state.waiters.drop(nodesToTry.size)
        val newState =
          state.copy(
            nodesToTry = nodesToTry.drop(state.waiters.size),
            seenNodes = seenNodes,
            waiters = waiters
          )
        val io =
          state.waiters.zip(nodesToTry).map { case (deferred, nodeInfo) => deferred.complete(nodeInfo) }.sequence_
        (newState, io)
      }.flatten
    }

    type NewPeers = List[PeerInfo]

    // Records peers, returning only those not seen before.
    def addPeers(peers: List[PeerInfo]): IO[NewPeers] = {
      ref
        .modify { state =>
          val newPeers = peers.filterNot(state.seenPeers)
          val newState = state.copy(
            seenPeers = state.seenPeers ++ newPeers
          )
          (newState, newPeers)
        }
    }
  }

  object DiscoveryState {

    /** Immutable snapshot of the crawl state held in a Ref. */
    case class Data(
      nodesToTry: List[NodeInfo],
      seenNodes: Set[NodeInfo],
      seenPeers: Set[PeerInfo] = Set.empty,
      waiters: List[Deferred[IO, NodeInfo]] = Nil
    )

    def apply(initialNodes: List[NodeInfo], infoHash: InfoHash): IO[DiscoveryState] =
      for {
        ref <- IO.ref(Data(initialNodes, initialNodes.toSet))
      } yield {
        new DiscoveryState(ref, infoHash)
      }
  }

  case class ExhaustedNodeList() extends Exception
}
217 |
--------------------------------------------------------------------------------
/bittorrent/shared/src/main/scala/com/github/torrentdam/bittorrent/wire/Connection.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.wire
2 |
3 | import cats.*
4 | import cats.effect.kernel.Deferred
5 | import cats.effect.kernel.Ref
6 | import cats.effect.std.Queue
7 | import cats.effect.syntax.all.*
8 | import cats.effect.IO
9 | import cats.effect.Outcome
10 | import cats.effect.Resource
11 | import cats.implicits.*
12 | import com.github.torrentdam.bittorrent.wire.ExtensionHandler.ExtensionApi
13 | import com.github.torrentdam.bittorrent.InfoHash
14 | import com.github.torrentdam.bittorrent.PeerId
15 | import com.github.torrentdam.bittorrent.PeerInfo
16 | import com.github.torrentdam.bittorrent.TorrentMetadata
17 | import com.github.torrentdam.bittorrent.protocol.message.Message
18 | import fs2.concurrent.Signal
19 | import fs2.concurrent.SignallingRef
20 | import fs2.io.net.Network
21 | import fs2.io.net.SocketGroup
22 | import monocle.macros.GenLens
23 | import monocle.Lens
24 | import org.legogroup.woof.given
25 | import org.legogroup.woof.Logger
26 |
27 | import scala.collection.immutable.BitSet
28 | import scala.concurrent.duration.*
29 | import scodec.bits.ByteVector
30 |
/** A live wire connection to a single peer. */
trait Connection {
  /** Peer this connection is attached to. */
  def info: PeerInfo
  /** Whether the peer advertised the extension protocol in its handshake. */
  def extensionProtocol: Boolean
  /** Declares interest to the peer (the message is sent at most once). */
  def interested: IO[Unit]
  /** Requests one block and returns its bytes once the peer answers. */
  def request(request: Message.Request): IO[ByteVector]
  /** Current choke status as last reported by the peer. */
  def choked: Signal[IO, Boolean]
  /** Set of piece indices the peer claims to have. */
  def availability: Signal[IO, BitSet]
  /** Completes when the background connection loops terminate. */
  def disconnected: IO[Unit]
  /** Initializes and returns the extension-protocol API for this peer. */
  def extensionApi: IO[ExtensionApi[IO]]
}
41 |
42 | object Connection {
43 |
44 | case class State(lastMessageAt: Long = 0, interested: Boolean = false)
45 | object State {
46 | val lastMessageAt: Lens[State, Long] = GenLens[State](_.lastMessageAt)
47 | val interested: Lens[State, Boolean] = GenLens[State](_.interested)
48 | }
49 |
50 | trait RequestRegistry {
51 | def register(request: Message.Request): IO[ByteVector]
52 | def complete(request: Message.Request, bytes: ByteVector): IO[Unit]
53 | }
54 | object RequestRegistry {
55 | def apply(): Resource[IO, RequestRegistry] =
56 | for stateRef <- Resource.eval(
57 | IO.ref(Map.empty[Message.Request, Either[Throwable, ByteVector] => IO[Boolean]])
58 | )
59 | // _ <- Resource.onFinalize(
60 | // for
61 | // state <- stateRef.get
62 | // _ <- state.values.toList.traverse { cb =>
63 | // cb(ConnectionClosed().asLeft)
64 | // }
65 | // yield ()
66 | // )
67 | yield new RequestRegistry {
68 |
69 | def register(request: Message.Request): IO[ByteVector] =
70 | IO.deferred[Either[Throwable, ByteVector]]
71 | .flatMap { deferred =>
72 | val update = stateRef.update(_.updated(request, deferred.complete))
73 | val delete = stateRef.update(_ - request)
74 | (update >> deferred.get).guarantee(delete)
75 | }
76 | .flatMap(IO.fromEither)
77 |
78 | def complete(request: Message.Request, bytes: ByteVector): IO[Unit] =
79 | for
80 | callback <- stateRef.get.map(_.get(request))
81 | _ <- callback.traverse(cb => cb(bytes.asRight))
82 | yield ()
83 | }
84 | }
85 |
86 | def connect(selfId: PeerId, peerInfo: PeerInfo, infoHash: InfoHash)(using
87 | network: Network[IO],
88 | logger: Logger[IO]
89 | ): Resource[IO, Connection] =
90 | for
91 | requestRegistry <- RequestRegistry()
92 | socket <- MessageSocket.connect[IO](selfId, peerInfo, infoHash)
93 | stateRef <- Resource.eval(IO.ref(State()))
94 | chokedStatusRef <- Resource.eval(SignallingRef[IO].of(true))
95 | bitfieldRef <- Resource.eval(SignallingRef[IO].of(BitSet.empty))
96 | sendQueue <- Resource.eval(Queue.bounded[IO, Message](10))
97 | (extensionHandler, initExtension) <- Resource.eval(
98 | ExtensionHandler.InitExtension(
99 | infoHash,
100 | sendQueue.offer,
101 | new ExtensionHandler.UtMetadata.Create[IO]
102 | )
103 | )
104 | updateLastMessageTime = (l: Long) => stateRef.update(State.lastMessageAt.replace(l))
105 | closed <-
106 | (
107 | receiveLoop(
108 | requestRegistry,
109 | bitfieldRef.update,
110 | chokedStatusRef.set,
111 | updateLastMessageTime,
112 | socket,
113 | extensionHandler
114 | ),
115 | sendLoop(sendQueue, socket),
116 | keepAliveLoop(stateRef, sendQueue.offer)
117 | ).parTupled.background
118 | yield new Connection {
119 | def info: PeerInfo = peerInfo
120 | def extensionProtocol: Boolean = socket.handshake.extensionProtocol
121 |
122 | def interested: IO[Unit] =
123 | for
124 | interested <- stateRef.modify(s => (State.interested.replace(true)(s), s.interested))
125 | _ <- IO.whenA(!interested)(sendQueue.offer(Message.Interested))
126 | yield ()
127 |
128 | def request(request: Message.Request): IO[ByteVector] =
129 | sendQueue.offer(request) >>
130 | requestRegistry.register(request).flatMap { bytes =>
131 | if bytes.length == request.length
132 | then bytes.pure[IO]
133 | else Error.InvalidBlockLength(request, bytes.length).raiseError[IO, ByteVector]
134 | }
135 |
136 | def choked: Signal[IO, Boolean] = chokedStatusRef
137 |
138 | def availability: Signal[IO, BitSet] = bitfieldRef
139 |
140 | def disconnected: IO[Unit] = closed.void
141 |
142 | def extensionApi: IO[ExtensionApi[IO]] = initExtension.init
143 | }
144 | end for
145 |
146 | case class ConnectionClosed() extends Throwable
147 |
148 | private def receiveLoop(
149 | requestRegistry: RequestRegistry,
150 | updateBitfield: (BitSet => BitSet) => IO[Unit],
151 | updateChokeStatus: Boolean => IO[Unit],
152 | updateLastMessageAt: Long => IO[Unit],
153 | socket: MessageSocket[IO],
154 | extensionHandler: ExtensionHandler[IO]
155 | ): IO[Nothing] =
156 | socket.receive
157 | .flatMap {
158 | case Message.Unchoke =>
159 | updateChokeStatus(false)
160 | case Message.Choke =>
161 | updateChokeStatus(true)
162 | case Message.Piece(index: Long, begin: Long, bytes: ByteVector) =>
163 | val request: Message.Request = Message.Request(index, begin, bytes.length)
164 | requestRegistry.complete(request, bytes)
165 | case Message.Have(index) =>
166 | updateBitfield(_ incl index.toInt)
167 | case Message.Bitfield(bytes) =>
168 | val indices = bytes.toBitVector.toIndexedSeq.zipWithIndex.collect { case (true, i) =>
169 | i
170 | }
171 | updateBitfield(_ => BitSet(indices*))
172 | case m: Message.Extended =>
173 | extensionHandler(m)
174 | case _ =>
175 | IO.unit
176 | }
177 | .flatTap { _ =>
178 | IO.realTime.flatMap { currentTime =>
179 | updateLastMessageAt(currentTime.toMillis)
180 | }
181 | }
182 | .foreverM
183 |
184 | private def keepAliveLoop(
185 | stateRef: Ref[IO, State],
186 | send: Message => IO[Unit]
187 | ): IO[Nothing] =
188 | IO
189 | .sleep(10.seconds)
190 | .flatMap { _ =>
191 | for
192 | currentTime <- IO.realTime
193 | timedOut <- stateRef.get.map(s => (currentTime - s.lastMessageAt.millis) > 30.seconds)
194 | _ <- IO.whenA(timedOut) {
195 | IO.raiseError(Error.ConnectionTimeout())
196 | }
197 | _ <- send(Message.KeepAlive)
198 | yield ()
199 | }
200 | .foreverM
201 |
202 | private def sendLoop(queue: Queue[IO, Message], socket: MessageSocket[IO]): IO[Nothing] =
203 | queue.take.flatMap(socket.send).foreverM
204 |
  /** Failures specific to a peer connection. */
  enum Error(message: String) extends Exception(message):
    /** Nothing was received from the peer within the allowed window. */
    case ConnectionTimeout() extends Error("Connection timed out")
    /** The peer answered a block request with a payload of the wrong size. */
    case InvalidBlockLength(request: Message.Request, responseLength: Long) extends Error("Invalid block length")
208 | }
209 |
--------------------------------------------------------------------------------
/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/message.scala:
--------------------------------------------------------------------------------
1 | package com.github.torrentdam.bittorrent.dht
2 |
3 | import cats.implicits.*
4 | import com.comcast.ip4s.*
5 | import com.github.torrentdam.bittorrent.InfoHash
6 | import com.github.torrentdam.bittorrent.PeerInfo
7 | import com.github.torrentdam.bencode.format.*
8 | import com.github.torrentdam.bencode.Bencode
9 | import scodec.bits.ByteVector
10 | import scodec.Codec
11 |
/** A DHT (KRPC) message: a query, a response, or an error.
  * Wire representation is bencode; see the formats in the companion object.
  */
enum Message:
  case QueryMessage(transactionId: ByteVector, query: Query)
  case ResponseMessage(transactionId: ByteVector, response: Response)
  case ErrorMessage(transactionId: ByteVector, details: Bencode)

  // The "t" transaction id correlates a response with its originating query.
  def transactionId: ByteVector
end Message
19 |
object Message {

  // NodeId travels on the wire as a raw byte string.
  given BencodeFormat[NodeId] =
    BencodeFormat.ByteVectorFormat.imap(NodeId.apply)(_.bytes)

  // Query payloads live under the "a" (arguments) key; the querying node's
  // own id is always present as "id".

  val PingQueryFormat: BencodeFormat[Query.Ping] = (
    field[NodeId]("a")(using field[NodeId]("id"))
  ).imap[Query.Ping](qni => Query.Ping(qni))(v => v.queryingNodeId)

  val FindNodeQueryFormat: BencodeFormat[Query.FindNode] = (
    field[(NodeId, NodeId)]("a")(
      using (field[NodeId]("id"), field[NodeId]("target")).tupled
    )
  ).imap[Query.FindNode](Query.FindNode.apply)(v => (v.queryingNodeId, v.target))

  // InfoHash is likewise encoded as a raw byte string.
  given BencodeFormat[InfoHash] = BencodeFormat.ByteVectorFormat.imap(InfoHash(_))(_.bytes)

  val GetPeersQueryFormat: BencodeFormat[Query.GetPeers] = (
    field[(NodeId, InfoHash)]("a")(
      using (field[NodeId]("id"), field[InfoHash]("info_hash")).tupled
    )
  ).imap[Query.GetPeers](Query.GetPeers.apply)(v => (v.queryingNodeId, v.infoHash))

  val AnnouncePeerQueryFormat: BencodeFormat[Query.AnnouncePeer] = (
    field[(NodeId, InfoHash, Long)]("a")(
      using (field[NodeId]("id"), field[InfoHash]("info_hash"), field[Long]("port")).tupled
    )
  ).imap[Query.AnnouncePeer](Query.AnnouncePeer.apply)(v => (v.queryingNodeId, v.infoHash, v.port))

  val SampleInfoHashesQueryFormat: BencodeFormat[Query.SampleInfoHashes] = (
    field[(NodeId, NodeId)]("a")(
      using (field[NodeId]("id"), field[NodeId]("target")).tupled
    )
  ).imap[Query.SampleInfoHashes](Query.SampleInfoHashes.apply)(v => (v.queryingNodeId, v.target))

  // Dispatch on the "q" key to pick the concrete query format (and back).
  val QueryFormat: BencodeFormat[Query] =
    field[String]("q").choose(
      {
        case "ping" => PingQueryFormat.upcast
        case "find_node" => FindNodeQueryFormat.upcast
        case "get_peers" => GetPeersQueryFormat.upcast
        case "announce_peer" => AnnouncePeerQueryFormat.upcast
        case "sample_infohashes" => SampleInfoHashesQueryFormat.upcast
      },
      {
        case _: Query.Ping => "ping"
        case _: Query.FindNode => "find_node"
        case _: Query.GetPeers => "get_peers"
        case _: Query.AnnouncePeer => "announce_peer"
        case _: Query.SampleInfoHashes => "sample_infohashes"
      }
    )

  val QueryMessageFormat: BencodeFormat[Message.QueryMessage] = (
    field[ByteVector]("t"),
    QueryFormat
  ).imapN[QueryMessage]((tid, q) => QueryMessage(tid, q))(v => (v.transactionId, v.query))

  // Compact address encoding: 4 bytes IPv4 address + 2 bytes big-endian port.
  val InetSocketAddressCodec: Codec[SocketAddress[IpAddress]] = {
    import scodec.codecs.*
    (bytes(4) :: bytes(2)).xmap(
      { case (address, port) =>
        SocketAddress(
          // .get is safe for well-formed input: exactly 4 bytes and a
          // 16-bit unsigned port; malformed peers would make this throw.
          IpAddress.fromBytes(address.toArray).get,
          Port.fromInt(port.toInt(signed = false)).get
        )
      },
      v => (ByteVector(v.host.toBytes), ByteVector.fromInt(v.port.value, 2))
    )
  }

  // Compact node info: 20-byte node id followed by a compact address.
  val CompactNodeInfoCodec: Codec[List[NodeInfo]] = {
    import scodec.codecs.*
    list(
      (bytes(20) :: InetSocketAddressCodec).xmap(
        { case (id, address) =>
          NodeInfo(NodeId(id), address)
        },
        v => (v.id.bytes, v.address)
      )
    )
  }

  val CompactPeerInfoCodec: Codec[PeerInfo] = InetSocketAddressCodec.xmap(PeerInfo.apply, _.address)

  // A flat run of 20-byte info-hashes.
  val CompactInfoHashCodec: Codec[List[InfoHash]] = {
    import scodec.codecs.*
    list(
      (bytes(20)).xmap(InfoHash.apply, _.bytes)
    )
  }

  val PingResponseFormat: BencodeFormat[Response.Ping] =
    field[NodeId]("id").imap[Response.Ping](Response.Ping.apply)(_.id)

  val NodesResponseFormat: BencodeFormat[Response.Nodes] = (
    field[NodeId]("id"),
    field[List[NodeInfo]]("nodes")(using encodedString(CompactNodeInfoCodec))
  ).imapN[Response.Nodes](Response.Nodes.apply)(v => (v.id, v.nodes))

  val PeersResponseFormat: BencodeFormat[Response.Peers] = (
    field[NodeId]("id"),
    // "values" is a bencode list of compact peer strings, one peer each.
    field[List[PeerInfo]]("values")(using BencodeFormat.listFormat(using encodedString(CompactPeerInfoCodec)))
  ).imapN[Response.Peers](Response.Peers.apply)(v => (v.id, v.peers))

  val SampleInfoHashesResponseFormat: BencodeFormat[Response.SampleInfoHashes] = (
    field[NodeId]("id"),
    fieldOptional[List[NodeInfo]]("nodes")(using encodedString(CompactNodeInfoCodec)),
    field[List[InfoHash]]("samples")(using encodedString(CompactInfoHashCodec))
  ).imapN[Response.SampleInfoHashes](Response.SampleInfoHashes.apply)(v => (v.id, v.nodes, v.samples))

  // Responses carry no discriminator key, so the reader sniffs the present
  // fields. Order matters: "values" and "samples" are checked before
  // "nodes" because those responses may contain a "nodes" field as well.
  val ResponseFormat: BencodeFormat[Response] =
    BencodeFormat(
      BencodeFormat.dictionaryFormat.read.flatMap {
        case Bencode.BDictionary(dictionary) if dictionary.contains("values") => PeersResponseFormat.read.widen
        case Bencode.BDictionary(dictionary) if dictionary.contains("samples") =>
          SampleInfoHashesResponseFormat.read.widen
        case Bencode.BDictionary(dictionary) if dictionary.contains("nodes") => NodesResponseFormat.read.widen
        case _ => PingResponseFormat.read.widen
      },
      BencodeWriter {
        case value: Response.Peers => PeersResponseFormat.write(value)
        case value: Response.Nodes => NodesResponseFormat.write(value)
        case value: Response.Ping => PingResponseFormat.write(value)
        case value: Response.SampleInfoHashes => SampleInfoHashesResponseFormat.write(value)
      }
    )

  val ResponseMessageFormat: BencodeFormat[Message.ResponseMessage] = (
    field[ByteVector]("t"),
    field[Response]("r")(using ResponseFormat)
  ).imapN[ResponseMessage]((tid, r) => ResponseMessage(tid, r))(v => (v.transactionId, v.response))

  // Some implementations omit "t" on errors; a missing id decodes to empty.
  val ErrorMessageFormat: BencodeFormat[Message.ErrorMessage] = (
    fieldOptional[ByteVector]("t"),
    field[Bencode]("e")
  ).imapN[ErrorMessage]((tid, details) => ErrorMessage(tid.getOrElse(ByteVector.empty), details))(v =>
    (v.transactionId.some, v.details)
  )

  // Top-level envelope: the "y" key selects query/response/error.
  given BencodeFormat[Message] =
    field[String]("y").choose(
      {
        case "q" => QueryMessageFormat.upcast
        case "r" => ResponseMessageFormat.upcast
        case "e" => ErrorMessageFormat.upcast
      },
      {
        case _: Message.QueryMessage => "q"
        case _: Message.ResponseMessage => "r"
        case _: Message.ErrorMessage => "e"
      }
    )
}
174 |
/** DHT queries this client can send or decode; each carries the sender's node id. */
enum Query:
  case Ping(queryingNodeId: NodeId)
  case FindNode(queryingNodeId: NodeId, target: NodeId)
  case GetPeers(queryingNodeId: NodeId, infoHash: InfoHash)
  case AnnouncePeer(queryingNodeId: NodeId, infoHash: InfoHash, port: Long)
  case SampleInfoHashes(queryingNodeId: NodeId, target: NodeId)

  // Present on every variant: the id of the node issuing the query.
  def queryingNodeId: NodeId
end Query
184 |
/** DHT responses, mirrored from the wire formats in [[Message]]. */
enum Response:
  case Ping(id: NodeId)
  case Nodes(id: NodeId, nodes: List[NodeInfo])
  case Peers(id: NodeId, peers: List[PeerInfo])
  case SampleInfoHashes(id: NodeId, nodes: Option[List[NodeInfo]], samples: List[InfoHash])
190 |
--------------------------------------------------------------------------------
/mill:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | # This is a wrapper script, that automatically selects or downloads Mill from Maven Central or GitHub release pages.
4 | #
5 | # This script determines the Mill version to use by trying these sources
6 | # - env-variable `MILL_VERSION`
7 | # - local file `.mill-version`
8 | # - local file `.config/mill-version`
#  - `mill-version` from YAML frontmatter of current buildfile
10 | # - if accessible, find the latest stable version available on Maven Central (https://repo1.maven.org/maven2)
11 | # - env-variable `DEFAULT_MILL_VERSION`
12 | #
13 | # If a version has the suffix '-native' a native binary will be used.
14 | # If a version has the suffix '-jvm' an executable jar file will be used, requiring an already installed Java runtime.
15 | # If no such suffix is found, the script will pick a default based on version and platform.
16 | #
17 | # Once a version was determined, it tries to use either
18 | # - a system-installed mill, if found and it's version matches
19 | # - an already downloaded version under ~/.cache/mill/download
20 | #
21 | # If no working mill version was found on the system,
# this script downloads a binary file from Maven Central or GitHub release pages (this is version dependent)
23 | # into a cache location (~/.cache/mill/download).
24 | #
25 | # Mill Project URL: https://github.com/com-lihaoyi/mill
26 | # Script Version: 1.0.0-M1-21-7b6fae-DIRTY892b63e8
27 | #
28 | # If you want to improve this script, please also contribute your changes back!
29 | # This script was generated from: dist/scripts/src/mill.sh
30 | #
31 | # Licensed under the Apache License, Version 2.0
32 |
set -e

# Capture --setup-completions early so it can later be forwarded to mill in
# the first argument position.
if [ "$1" = "--setup-completions" ] ; then
  # Need to preserve the first position of those listed options
  MILL_FIRST_ARG=$1
  shift
fi

# Fallback version, used only when no other source yields one.
if [ -z "${DEFAULT_MILL_VERSION}" ] ; then
  DEFAULT_MILL_VERSION=1.0.1
fi


# Optional mirror/CDN prefix for GitHub release downloads.
if [ -z "${GITHUB_RELEASE_CDN}" ] ; then
  GITHUB_RELEASE_CDN=""
fi


MILL_REPO_URL="https://github.com/com-lihaoyi/mill"

# Allow overriding the curl binary (e.g. for proxies or testing).
if [ -z "${CURL_CMD}" ] ; then
  CURL_CMD=curl
fi

# Explicit commandline argument takes precedence over all other methods
if [ "$1" = "--mill-version" ] ; then
  echo "The --mill-version option is no longer supported." 1>&2
fi

# Locate the build file; its frontmatter may pin a mill version below.
MILL_BUILD_SCRIPT=""

if [ -f "build.mill" ] ; then
  MILL_BUILD_SCRIPT="build.mill"
elif [ -f "build.mill.scala" ] ; then
  MILL_BUILD_SCRIPT="build.mill.scala"
elif [ -f "build.sc" ] ; then
  MILL_BUILD_SCRIPT="build.sc"
fi

# Please note that if a MILL_VERSION is already set in the environment,
# we reuse its value and skip searching for a value.

# If not already set, read .mill-version file
if [ -z "${MILL_VERSION}" ] ; then
  if [ -f ".mill-version" ] ; then
    # Normalize CRLF line endings and keep only the first line.
    MILL_VERSION="$(tr '\r' '\n' < .mill-version | head -n 1 2> /dev/null)"
  elif [ -f ".config/mill-version" ] ; then
    MILL_VERSION="$(tr '\r' '\n' < .config/mill-version | head -n 1 2> /dev/null)"
  elif [ -n "${MILL_BUILD_SCRIPT}" ] ; then
    # Extract `//| mill-version: <v>` from the buildfile frontmatter.
    MILL_VERSION="$(cat ${MILL_BUILD_SCRIPT} | grep '//[|] *mill-version: *' | sed 's;//| *mill-version: *;;')"
  fi
fi
85 |
MILL_USER_CACHE_DIR="${XDG_CACHE_HOME:-${HOME}/.cache}/mill"

if [ -z "${MILL_DOWNLOAD_PATH}" ] ; then
  MILL_DOWNLOAD_PATH="${MILL_USER_CACHE_DIR}/download"
fi

# If not already set, try to fetch newest from Github
if [ -z "${MILL_VERSION}" ] ; then
  # TODO: try to load latest version from release page
  echo "No mill version specified." 1>&2
  echo "You should provide a version via a '//| mill-version: ' comment or a '.mill-version' file." 1>&2

  mkdir -p "${MILL_DOWNLOAD_PATH}"
  # Mark a one-hour freshness window for the cached ".latest" lookup below.
  LANG=C touch -d '1 hour ago' "${MILL_DOWNLOAD_PATH}/.expire_latest" 2>/dev/null || (
    # we might be on OSX or BSD which don't have -d option for touch
    # but probably a -A [-][[hh]mm]SS
    touch "${MILL_DOWNLOAD_PATH}/.expire_latest"; touch -A -010000 "${MILL_DOWNLOAD_PATH}/.expire_latest"
  ) || (
    # in case we still failed, we retry the first touch command with the intention
    # to show the (previously suppressed) error message
    LANG=C touch -d '1 hour ago' "${MILL_DOWNLOAD_PATH}/.expire_latest"
  )

  # POSIX shell variant of bash's -nt operator, see https://unix.stackexchange.com/a/449744/6993
  # if [ "${MILL_DOWNLOAD_PATH}/.latest" -nt "${MILL_DOWNLOAD_PATH}/.expire_latest" ] ; then
  if [ -n "$(find -L "${MILL_DOWNLOAD_PATH}/.latest" -prune -newer "${MILL_DOWNLOAD_PATH}/.expire_latest")" ]; then
    # we know a current latest version
    MILL_VERSION=$(head -n 1 "${MILL_DOWNLOAD_PATH}"/.latest 2> /dev/null)
  fi

  if [ -z "${MILL_VERSION}" ] ; then
    # we don't know a current latest version
    echo "Retrieving latest mill version ..." 1>&2
    # Follow the /releases/latest redirect and keep just the tag name.
    LANG=C ${CURL_CMD} -s -i -f -I ${MILL_REPO_URL}/releases/latest 2> /dev/null | grep --ignore-case Location: | sed s'/^.*tag\///' | tr -d '\r\n' > "${MILL_DOWNLOAD_PATH}/.latest"
    MILL_VERSION=$(head -n 1 "${MILL_DOWNLOAD_PATH}"/.latest 2> /dev/null)
  fi

  if [ -z "${MILL_VERSION}" ] ; then
    # Last resort
    MILL_VERSION="${DEFAULT_MILL_VERSION}"
    echo "Falling back to hardcoded mill version ${MILL_VERSION}" 1>&2
  else
    echo "Using mill version ${MILL_VERSION}" 1>&2
  fi
fi
131 |
MILL_NATIVE_SUFFIX="-native"
MILL_JVM_SUFFIX="-jvm"
FULL_MILL_VERSION=$MILL_VERSION
ARTIFACT_SUFFIX=""
# Pick the platform-specific artifact suffix for the native launcher,
# based on OS and CPU architecture. Exits for unsupported platforms.
set_artifact_suffix(){
  if [ "$(expr substr $(uname -s) 1 5 2>/dev/null)" = "Linux" ]; then
    if [ "$(uname -m)" = "aarch64" ]; then
      ARTIFACT_SUFFIX="-native-linux-aarch64"
    else
      ARTIFACT_SUFFIX="-native-linux-amd64"
    fi
  elif [ "$(uname)" = "Darwin" ]; then
    if [ "$(uname -m)" = "arm64" ]; then
      ARTIFACT_SUFFIX="-native-mac-aarch64"
    else
      ARTIFACT_SUFFIX="-native-mac-amd64"
    fi
  else
    echo "This native mill launcher supports only Linux and macOS." 1>&2
    exit 1
  fi
}

# Strip an explicit -native/-jvm suffix from the requested version; without
# one, versions newer than the 0.x series default to the native launcher.
case "$MILL_VERSION" in
  *"$MILL_NATIVE_SUFFIX")
    MILL_VERSION=${MILL_VERSION%"$MILL_NATIVE_SUFFIX"}
    set_artifact_suffix
    ;;

  *"$MILL_JVM_SUFFIX")
    MILL_VERSION=${MILL_VERSION%"$MILL_JVM_SUFFIX"}
    ;;

  *)
    case "$MILL_VERSION" in
      0.1.*) ;;
      0.2.*) ;;
      0.3.*) ;;
      0.4.*) ;;
      0.5.*) ;;
      0.6.*) ;;
      0.7.*) ;;
      0.8.*) ;;
      0.9.*) ;;
      0.10.*) ;;
      0.11.*) ;;
      0.12.*) ;;
      *)
        set_artifact_suffix
    esac
    ;;
esac

MILL="${MILL_DOWNLOAD_PATH}/$MILL_VERSION$ARTIFACT_SUFFIX"
186 |
# Use a `mill` executable found on the PATH when its version matches the
# requested one; sets MILL on success. Linux-only: the probes below rely on
# GNU options of head/stat/readlink.
try_to_use_system_mill() {
  if [ "$(uname)" != "Linux" ]; then
    return 0
  fi

  MILL_IN_PATH="$(command -v mill || true)"

  if [ -z "${MILL_IN_PATH}" ]; then
    return 0
  fi

  # A leading shebang means this is (very likely) a wrapper script and not
  # the real mill executable; ignore it.
  SYSTEM_MILL_FIRST_TWO_BYTES=$(head --bytes=2 "${MILL_IN_PATH}")
  if [ "${SYSTEM_MILL_FIRST_TWO_BYTES}" = "#!" ]; then
    return 0
  fi

  SYSTEM_MILL_PATH=$(readlink -e "${MILL_IN_PATH}")
  SYSTEM_MILL_SIZE=$(stat --format=%s "${SYSTEM_MILL_PATH}")
  SYSTEM_MILL_MTIME=$(stat --format=%y "${SYSTEM_MILL_PATH}")

  if [ ! -d "${MILL_USER_CACHE_DIR}" ]; then
    mkdir -p "${MILL_USER_CACHE_DIR}"
  fi

  # Cache path/version/size/mtime of the system mill so the (slow)
  # `mill --version` probe can be skipped while the binary is unchanged.
  SYSTEM_MILL_INFO_FILE="${MILL_USER_CACHE_DIR}/system-mill-info"
  if [ -f "${SYSTEM_MILL_INFO_FILE}" ]; then
    parseSystemMillInfo() {
      LINE_NUMBER="${1}"
      # Select the line number of the SYSTEM_MILL_INFO_FILE, cut the
      # variable definition in that line in two halves and return
      # the value, and finally remove the quotes.
      sed -n "${LINE_NUMBER}p" "${SYSTEM_MILL_INFO_FILE}" |\
        cut -d= -f2 |\
        sed 's/"\(.*\)"/\1/'
    }

    CACHED_SYSTEM_MILL_PATH=$(parseSystemMillInfo 1)
    CACHED_SYSTEM_MILL_VERSION=$(parseSystemMillInfo 2)
    CACHED_SYSTEM_MILL_SIZE=$(parseSystemMillInfo 3)
    CACHED_SYSTEM_MILL_MTIME=$(parseSystemMillInfo 4)

    if [ "${SYSTEM_MILL_PATH}" = "${CACHED_SYSTEM_MILL_PATH}" ] \
         && [ "${SYSTEM_MILL_SIZE}" = "${CACHED_SYSTEM_MILL_SIZE}" ] \
         && [ "${SYSTEM_MILL_MTIME}" = "${CACHED_SYSTEM_MILL_MTIME}" ]; then
      if [ "${CACHED_SYSTEM_MILL_VERSION}" = "${MILL_VERSION}" ]; then
        MILL="${SYSTEM_MILL_PATH}"
        return 0
      else
        return 0
      fi
    fi
  fi

  SYSTEM_MILL_VERSION=$(${SYSTEM_MILL_PATH} --version | head -n1 | sed -n 's/^Mill.*version \(.*\)/\1/p')

  # Fix: this previously read `cat < "${SYSTEM_MILL_INFO_FILE}"`, which READ
  # the file instead of writing it — the intended here-doc body then ran as
  # shell commands and the stray `EOF` line aborted the script under `set -e`.
  # Write the cache file with a proper here-document instead.
  cat <<EOF > "${SYSTEM_MILL_INFO_FILE}"
CACHED_SYSTEM_MILL_PATH="${SYSTEM_MILL_PATH}"
CACHED_SYSTEM_MILL_VERSION="${SYSTEM_MILL_VERSION}"
CACHED_SYSTEM_MILL_SIZE="${SYSTEM_MILL_SIZE}"
CACHED_SYSTEM_MILL_MTIME="${SYSTEM_MILL_MTIME}"
EOF

  if [ "${SYSTEM_MILL_VERSION}" = "${MILL_VERSION}" ]; then
    MILL="${SYSTEM_MILL_PATH}"
  fi
}
try_to_use_system_mill
256 |
# If not already downloaded, download it
if [ ! -s "${MILL}" ] || [ "$MILL_TEST_DRY_RUN_LAUNCHER_SCRIPT" = "1" ] ; then
  # Old versions come from GitHub releases; newer ones from Maven Central.
  case $MILL_VERSION in
    0.0.* | 0.1.* | 0.2.* | 0.3.* | 0.4.* )
      DOWNLOAD_SUFFIX=""
      DOWNLOAD_FROM_MAVEN=0
      ;;
    0.5.* | 0.6.* | 0.7.* | 0.8.* | 0.9.* | 0.10.* | 0.11.0-M* )
      DOWNLOAD_SUFFIX="-assembly"
      DOWNLOAD_FROM_MAVEN=0
      ;;
    *)
      DOWNLOAD_SUFFIX="-assembly"
      DOWNLOAD_FROM_MAVEN=1
      ;;
  esac
  # Pick the artifact extension: executable jar vs. native/exe launcher.
  case $MILL_VERSION in
    0.12.0 | 0.12.1 | 0.12.2 | 0.12.3 | 0.12.4 | 0.12.5 | 0.12.6 | 0.12.7 | 0.12.8 | 0.12.9 | 0.12.10 | 0.12.11 )
      DOWNLOAD_EXT="jar"
      ;;
    0.12.* )
      DOWNLOAD_EXT="exe"
      ;;
    0.* )
      DOWNLOAD_EXT="jar"
      ;;
    *)
      DOWNLOAD_EXT="exe"
      ;;
  esac

  DOWNLOAD_FILE=$(mktemp mill.XXXXXX)
  if [ "$DOWNLOAD_FROM_MAVEN" = "1" ] ; then
    DOWNLOAD_URL="https://repo1.maven.org/maven2/com/lihaoyi/mill-dist${ARTIFACT_SUFFIX}/${MILL_VERSION}/mill-dist${ARTIFACT_SUFFIX}-${MILL_VERSION}.${DOWNLOAD_EXT}"
  else
    # Release tags drop any trailing qualifier after the milestone suffix.
    MILL_VERSION_TAG=$(echo "$MILL_VERSION" | sed -E 's/([^-]+)(-M[0-9]+)?(-.*)?/\1\2/')
    DOWNLOAD_URL="${GITHUB_RELEASE_CDN}${MILL_REPO_URL}/releases/download/${MILL_VERSION_TAG}/${MILL_VERSION}${DOWNLOAD_SUFFIX}"
    unset MILL_VERSION_TAG
  fi

  # Test hook: print what would be downloaded and where, then stop.
  if [ "$MILL_TEST_DRY_RUN_LAUNCHER_SCRIPT" = "1" ] ; then
    echo $DOWNLOAD_URL
    echo $MILL
    exit 0
  fi
  # TODO: handle command not found
  echo "Downloading mill ${MILL_VERSION} from ${DOWNLOAD_URL} ..." 1>&2
  ${CURL_CMD} -f -L -o "${DOWNLOAD_FILE}" "${DOWNLOAD_URL}"
  chmod +x "${DOWNLOAD_FILE}"
  mkdir -p "${MILL_DOWNLOAD_PATH}"
  mv "${DOWNLOAD_FILE}" "${MILL}"

  unset DOWNLOAD_FILE
  unset DOWNLOAD_SUFFIX
fi

if [ -z "$MILL_MAIN_CLI" ] ; then
  MILL_MAIN_CLI="${0}"
fi

# Fix: the unconditional `MILL_FIRST_ARG=""` here used to clobber the
# --setup-completions value captured at the top of the script, so that flag
# was silently dropped. Preserve it; otherwise detect options that must stay
# in the first argument position.
if [ "${MILL_FIRST_ARG}" != "--setup-completions" ] ; then
  MILL_FIRST_ARG=""
  if [ "$1" = "--bsp" ] || [ "${1#"-i"}" != "$1" ] || [ "$1" = "--interactive" ] || [ "$1" = "--no-server" ] || [ "$1" = "--no-daemon" ] || [ "$1" = "--repl" ] || [ "$1" = "--help" ] ; then
    # Need to preserve the first position of those listed options
    MILL_FIRST_ARG=$1
    shift
  fi
fi

unset MILL_DOWNLOAD_PATH
unset MILL_OLD_DOWNLOAD_PATH
unset OLD_MILL
unset MILL_VERSION
unset MILL_REPO_URL

# -D mill.main.cli is for compatibility with Mill 0.10.9 - 0.13.0-M2
# We don't quote MILL_FIRST_ARG on purpose, so we can expand the empty value without quotes
# shellcheck disable=SC2086
exec "${MILL}" $MILL_FIRST_ARG -D "mill.main.cli=${MILL_MAIN_CLI}" "$@"
334 |
--------------------------------------------------------------------------------
/cmd/src/main/scala/Main.scala:
--------------------------------------------------------------------------------
1 | import cats.effect.cps.*
2 | import cats.effect.cps.given
3 | import cats.effect.std.Random
4 | import cats.effect.syntax.all.*
5 | import cats.effect.ExitCode
6 | import cats.effect.IO
7 | import cats.effect.Resource
8 | import cats.effect.ResourceIO
9 | import cats.syntax.all.*
10 | import com.comcast.ip4s.Port
11 | import com.comcast.ip4s.SocketAddress
12 | import com.github.torrentdam.bencode
13 | import com.github.torrentdam.bittorrent.dht.*
14 | import com.github.torrentdam.bittorrent.files.Reader
15 | import com.github.torrentdam.bittorrent.files.Writer
16 | import com.github.torrentdam.bittorrent.wire.Connection
17 | import com.github.torrentdam.bittorrent.wire.Download
18 | import com.github.torrentdam.bittorrent.wire.DownloadMetadata
19 | import com.github.torrentdam.bittorrent.wire.RequestDispatcher
20 | import com.github.torrentdam.bittorrent.wire.Swarm
21 | import com.github.torrentdam.bittorrent.wire.Torrent
22 | import com.github.torrentdam.bittorrent.CrossPlatform
23 | import com.github.torrentdam.bittorrent.InfoHash
24 | import com.github.torrentdam.bittorrent.PeerId
25 | import com.github.torrentdam.bittorrent.PeerInfo
26 | import com.github.torrentdam.bittorrent.TorrentFile
27 | import com.github.torrentdam.bittorrent.TorrentMetadata
28 | import com.monovore.decline.effect.CommandIOApp
29 | import com.monovore.decline.Opts
30 | import cps.syntax.*
31 | import fs2.io.file.Files
32 | import fs2.io.file.Flag
33 | import fs2.io.file.Flags
34 | import fs2.io.file.Path
35 | import fs2.io.file.WriteCursor
36 | import fs2.Chunk
37 | import fs2.Stream
38 | import java.util.concurrent.Executors
39 | import java.util.concurrent.ThreadFactory
40 | import org.legogroup.woof.*
41 | import org.legogroup.woof.given
42 | import scala.concurrent.duration.DurationInt
43 | import scodec.bits.ByteVector
44 |
/** Command-line entry point for the torrent client.
  *
  * Sub-commands: `torrent fetch-file|download|verify` and `dht start|get-peers`.
  * Resource wiring uses cats-effect-cps `async`/`.await`; acquisition order of
  * the `.await`ed resources is significant.
  */
object Main
    extends CommandIOApp(
      name = "torrentdam",
      header = "TorrentDam"
    ) {

  /** Top-level command: either the `torrent` or the `dht` group. */
  def main: Opts[IO[ExitCode]] =
    torrentCommand <+> dhtCommand

  def torrentCommand: Opts[IO[ExitCode]] =
    Opts.subcommand("torrent", "torrent client")(
      fetchFileCommand <+> downloadCommand <+> verifyCommand
    )

  /** Discovers peers via DHT, downloads only the metadata for the given
    * info-hash, and saves it as a .torrent file at --save.
    */
  def fetchFileCommand =
    Opts.subcommand("fetch-file", "download torrent file") {
      (
        Opts.option[String]("info-hash", "Info-hash"),
        Opts.option[String]("save", "Save as a torrent file")
      )
        .mapN { (infoHashOption, targetFilePath) =>
          withLogger {
            async[ResourceIO] {
              given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await

              val selfPeerId = Resource.eval(PeerId.generate[IO]).await
              val infoHash = Resource.eval(infoHashFromString(infoHashOption)).await
              val node = Node().await

              // Swarm connects to peers discovered through the DHT node.
              val swarm = Swarm(
                node.discovery.discover(infoHash),
                Connection.connect(selfPeerId, _, infoHash)
              ).await
              val metadata = DownloadMetadata(swarm).toResource.await
              val torrentFile = TorrentFile(metadata, None)
              Files[IO]
                .writeAll(Path(targetFilePath), Flags.Write)(
                  Stream.chunk(Chunk.byteVector(TorrentFile.toBytes(torrentFile)))
                )
                .compile
                .drain
                .as(ExitCode.Success)
            }.useEval
          }
        }
    }

  /** Downloads torrent data into the current directory.
    *
    * The info-hash comes from --torrent (preferred) or --info-hash; peers come
    * from --peer (single peer) or DHT discovery (optionally seeded with
    * --dht-node); metadata comes from the torrent file or is fetched from the
    * swarm. Pieces are fetched with parallelism 10 and written through
    * per-file cursors.
    */
  def downloadCommand =
    Opts.subcommand("download", "download torrent data") {
      val options = (
        Opts.option[String]("info-hash", "Info-hash").orNone,
        Opts.option[String]("torrent", "Torrent file").orNone,
        Opts.option[String]("peer", "Peer address").orNone,
        Opts.option[String]("dht-node", "DHT node address").orNone
      ).tupled
      options.map { case (infoHashOption, torrentFileOption, peerAddressOption, dhtNodeAddressOption) =>
        withLogger {
          async[ResourceIO] {
            // Read and parse the .torrent file when one was given.
            val torrentFile: Option[TorrentFile] = torrentFileOption
              .traverse[IO, TorrentFile](torrentFileOption =>
                async[IO] {
                  val torrentFileBytes = Files[IO]
                    .readAll(Path(torrentFileOption))
                    .compile
                    .to(ByteVector)
                    .await
                  TorrentFile
                    .fromBytes(torrentFileBytes)
                    .liftTo[IO]
                    .await
                }
              )
              .toResource
              .await
            val infoHash: InfoHash =
              torrentFile match
                case Some(torrentFile) =>
                  torrentFile.infoHash
                case None =>
                  infoHashOption match
                    case Some(infoHashOption) =>
                      infoHashFromString(infoHashOption).toResource.await
                    case None =>
                      // Neither --torrent nor --info-hash: fail during wiring.
                      throw new Exception("Missing info-hash")

            given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await
            val selfPeerId = Resource.eval(PeerId.generate[IO]).await
            val peerAddress = peerAddressOption.flatMap(SocketAddress.fromStringIp)
            val peers: Stream[IO, PeerInfo] =
              peerAddress match
                case Some(peerAddress) =>
                  Stream.emit(PeerInfo(peerAddress)).covary[IO]
                case None =>
                  val bootstrapNodeAddress = dhtNodeAddressOption.flatMap(SocketAddress.fromString)
                  val node = Node(none, bootstrapNodeAddress).await
                  node.discovery.discover(infoHash)
            val swarm = Swarm(peers, peerInfo => Connection.connect(selfPeerId, peerInfo, infoHash)).await
            val metadata =
              torrentFile match
                case Some(torrentFile) =>
                  torrentFile.info
                case None =>
                  Resource.eval(DownloadMetadata(swarm)).await
            val torrent = Torrent.make(metadata, swarm).await
            // `pieces` is a flat run of 20-byte SHA-1 digests, one per piece.
            val total = (metadata.parsed.pieces.length.toDouble / 20).ceil.toLong
            val counter = Resource.eval(IO.ref(0)).await
            val writer = Writer.fromTorrent(metadata.parsed)
            // Pre-create the directory tree for multi-file torrents.
            val createDirectories = metadata.parsed.files
              .filter(_.path.length > 1)
              .map(_.path.init)
              .distinct
              .traverse { path =>
                val dir = path.foldLeft(Path("."))(_ / _)
                Files[IO].createDirectories(dir)
              }
            Resource.eval(createDirectories).await
            // One write cursor per target file, kept open for the whole run.
            val openFiles: Map[TorrentMetadata.File, WriteCursor[IO]] =
              metadata.parsed.files
                .traverse { file =>
                  val path = file.path.foldLeft(Path("."))(_ / _)
                  val flags = Flags(Flag.Create, Flag.Write)
                  val cursor = Files[IO].writeCursor(path, flags)
                  cursor.tupleLeft(file)
                }
                .await
                .toMap
            Stream
              .range(0L, total)
              .parEvalMap(10)(index =>
                async[IO] {
                  val piece = !torrent.downloadPiece(index)
                  val count = !counter.updateAndGet(_ + 1)
                  val percent = ((count.toDouble / total) * 100).toInt
                  !Logger[IO].info(s"Downloaded piece $count/$total ($percent%)")
                  Chunk.iterable(writer.write(index, piece))
                }
              )
              .unchunks
              .evalMap(write => openFiles(write.file).seek(write.offset).write(Chunk.byteVector(write.bytes)))
              .compile
              .drain
              .as(ExitCode.Success)
          }.useEval
        }
      }
    }

  /** Re-reads downloaded data from --target and checks every piece against
    * the SHA-1 checksums embedded in the --torrent file.
    */
  def verifyCommand =
    Opts.subcommand("verify", "verify torrent data") {
      val options: Opts[(String, String)] =
        (
          Opts.option[String]("torrent", "Torrent file"),
          Opts.option[String]("target", "Torrent data directory")
        ).tupled
      options.map { (torrentFileName, targetDirName) =>
        withLogger {
          async[IO] {
            try
              val bytes = Files[IO].readAll(Path(torrentFileName)).compile.to(Array).map(ByteVector(_)).await
              val torrentFile = IO.fromEither(TorrentFile.fromBytes(bytes)).await
              val infoHash = InfoHash(CrossPlatform.sha1(bencode.encode(torrentFile.info.raw).bytes))
              Logger[IO].info(s"Info-hash: $infoHash").await

              val reader = Reader.fromTorrent(torrentFile.info.parsed)

              // Reassemble one piece by concatenating its spans from the
              // on-disk files (reads in 1 MiB chunks).
              def readPiece(index: Long): IO[ByteVector] =
                val reads = Stream.emits(reader.read(index))
                reads
                  .covary[IO]
                  .evalMap { read =>
                    val path = read.file.path.foldLeft(Path(targetDirName))(_ / _)
                    Files[IO]
                      .readRange(path, 1024 * 1024, read.offset, read.endOffset)
                      .chunks
                      .map(_.toByteVector)
                      .compile
                      .fold(ByteVector.empty)(_ ++ _)
                  }
                  .compile
                  .fold(ByteVector.empty)(_ ++ _)

              val readByteCount = IO.ref(0L).await

              // Walk the checksum list in 20-byte steps, verifying each piece.
              Stream
                .unfold(torrentFile.info.parsed.pieces)(bytes =>
                  if bytes.isEmpty then None
                  else
                    val (checksum, rest) = bytes.splitAt(20)
                    Some((checksum, rest))
                )
                .zipWithIndex
                .evalMap { (checksum, index) =>
                  readPiece(index).map((checksum, index, _))
                }
                .evalTap { (checksum, index, bytes) =>
                  if CrossPlatform.sha1(bytes) == checksum then readByteCount.update(_ + bytes.length)
                  else Logger[IO].error(s"Piece $index failed") >> IO.raiseError(new Exception)
                }
                .compile
                .drain
                .await
              val totalBytes = readByteCount.get.await
              Logger[IO].info(s"Read $totalBytes bytes").await
              Logger[IO].info("All pieces verified").await
              ExitCode.Success
            // NOTE(review): `case e =>` catches every Throwable, including
            // fatal errors — consider scala.util.control.NonFatal; confirm intent.
            catch
              case e =>
                Logger[IO].error(e.getMessage).await
                ExitCode.Error
          }
        }
      }
    }

  def dhtCommand: Opts[IO[ExitCode]] =
    Opts.subcommand("dht", "DHT client")(
      startCommand <+> getPeers
    )

  /** Runs a standalone DHT node on the given UDP port until interrupted. */
  def startCommand =
    Opts.subcommand("start", "start DHT node") {
      Opts.option[Int]("port", "UDP port").map { portParam =>
        withLogger {
          async[ResourceIO] {
            val port = Port.fromInt(portParam).liftTo[ResourceIO](new Exception("Invalid port")).await
            given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await
            Node(Some(port)).await
          }.useForever
        }
      }
    }

  /** Sends one get_peers query to the given node and prints the response. */
  def getPeers =
    Opts.subcommand("get-peers", "send single get_peers query") {
      (
        Opts.option[String]("host", "DHT node address"),
        Opts.option[String]("info-hash", "Info-hash"),
      ).tupled.map { (nodeAddressParam, infoHashParam) =>
        withLogger {
          async[ResourceIO] {
            val nodeAddress = SocketAddress.fromString(nodeAddressParam).liftTo[ResourceIO](new Exception("Invalid address")).await
            val nodeIpAddress = nodeAddress.resolve[IO].toResource.await
            given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await
            val selfId = Resource.eval(NodeId.random[IO]).await
            val infoHash = infoHashFromString(infoHashParam).toResource.await
            val messageSocket = MessageSocket(none).await
            val client = Client(selfId, messageSocket, QueryHandler.noop).await
            async[IO]:
              val response = client.getPeers(nodeIpAddress, infoHash).await
              IO.println(response).await
              ExitCode.Success
          }.useEval
        }
      }
    }

  extension (torrentFile: TorrentFile) {
    // Info-hash is the SHA-1 of the bencoded "info" dictionary.
    def infoHash: InfoHash = InfoHash(CrossPlatform.sha1(bencode.encode(torrentFile.info.raw).bytes))
  }

  /** Parses a textual info-hash, failing the effect on malformed input. */
  def infoHashFromString(value: String): IO[InfoHash] =
    InfoHash.fromString
      .unapply(value)
      .liftTo[IO](new Exception("Malformed info-hash"))

  /** Builds a console logger (Info level, colored output) and runs `body` with it. */
  def withLogger[A](body: Logger[IO] ?=> IO[A]): IO[A] =
    given Filter = Filter.atLeastLevel(LogLevel.Info)
    given Printer = ColorPrinter()
    DefaultLogger.makeIo(Output.fromConsole[IO]).flatMap(body(using _))
}
315 |
--------------------------------------------------------------------------------