├── .github ├── actions │ └── prepare-build │ │ └── action.yaml └── workflows │ └── main.yaml ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── cpp ├── .clang-format ├── .gitignore ├── CMakeLists.txt ├── Makefile ├── README.md ├── afl-driver │ ├── input │ │ ├── test-input-1391d5ec4099.bin │ │ ├── test-input-15011afa9422.bin │ │ ├── test-input-196b5e49ee87.bin │ │ ├── test-input-1b33ca8199f4.bin │ │ ├── test-input-235ec4531793.bin │ │ ├── test-input-27832f299ed2.bin │ │ ├── test-input-2fee9654c30c.bin │ │ ├── test-input-34ad70c6cca5.bin │ │ ├── test-input-369d0c4bf7a9.bin │ │ ├── test-input-38af39dc2d29.bin │ │ ├── test-input-4a64a107f0cb.bin │ │ ├── test-input-4aee27bef7b5.bin │ │ ├── test-input-4e289323f5d6.bin │ │ ├── test-input-54027a4fe751.bin │ │ ├── test-input-58b2bde5f8af.bin │ │ ├── test-input-5bc9e95a145a.bin │ │ ├── test-input-5feceb66ffc8.bin │ │ ├── test-input-65465e652d86.bin │ │ ├── test-input-6c36230ff10e.bin │ │ ├── test-input-6ce6cf70fffc.bin │ │ ├── test-input-6ea581a0e757.bin │ │ ├── test-input-802c20fe2e20.bin │ │ ├── test-input-88d385a01d77.bin │ │ ├── test-input-9af15b336e6a.bin │ │ ├── test-input-9f3d0f837f35.bin │ │ ├── test-input-a1bbf2e554bb.bin │ │ ├── test-input-ab801f7b5f13.bin │ │ ├── test-input-bb5639198655.bin │ │ ├── test-input-c037db48e6e1.bin │ │ ├── test-input-d355cded3e49.bin │ │ ├── test-input-de009b6db0fd.bin │ │ ├── test-input-de15ac7d6d86.bin │ │ ├── test-input-dfa882933f94.bin │ │ ├── test-input-f456f8617776.bin │ │ ├── test-input-f91405489d5d.bin │ │ ├── test-input-fcdb4b423f4e.bin │ │ └── test-input-fe9bdea66eab.bin │ └── src │ │ └── main.cpp ├── src │ └── dave │ │ ├── boringssl_cryptor.cpp │ │ ├── boringssl_cryptor.h │ │ ├── codec_utils.cpp │ │ ├── codec_utils.h │ │ ├── common.h │ │ ├── cryptor.cpp │ │ ├── cryptor.h │ │ ├── cryptor_manager.cpp │ │ ├── cryptor_manager.h │ │ ├── decryptor.cpp │ │ ├── decryptor.h │ │ ├── encryptor.cpp │ │ ├── encryptor.h │ │ ├── frame_processors.cpp │ │ ├── frame_processors.h │ │ 
├── key_ratchet.h │ │ ├── logger.cpp │ │ ├── logger.h │ │ ├── mls │ │ ├── detail │ │ │ ├── persisted_key_pair.h │ │ │ ├── persisted_key_pair_apple.cpp │ │ │ ├── persisted_key_pair_generic.cpp │ │ │ ├── persisted_key_pair_null.cpp │ │ │ └── persisted_key_pair_win.cpp │ │ ├── parameters.cpp │ │ ├── parameters.h │ │ ├── persisted_key_pair.cpp │ │ ├── persisted_key_pair.h │ │ ├── persisted_key_pair_null.cpp │ │ ├── session.cpp │ │ ├── session.h │ │ ├── user_credential.cpp │ │ ├── user_credential.h │ │ ├── util.cpp │ │ └── util.h │ │ ├── mls_key_ratchet.cpp │ │ ├── mls_key_ratchet.h │ │ ├── utils │ │ ├── array_view.h │ │ ├── clock.h │ │ ├── leb128.cpp │ │ ├── leb128.h │ │ └── scope_exit.h │ │ ├── version.cpp │ │ └── version.h ├── test │ ├── CMakeLists.txt │ ├── boringssl_cryptor_tests.cpp │ ├── codec_utils_tests.cpp │ ├── cryptor_manager_tests.cpp │ ├── cryptor_tests.cpp │ ├── dave_test.cpp │ ├── dave_test.h │ ├── static_key_ratchet.cpp │ └── static_key_ratchet.h └── vcpkg-alts │ └── boringssl │ ├── overlay-ports │ └── mlspp │ │ ├── portfile.cmake │ │ └── vcpkg.json │ └── vcpkg.json └── js ├── .gitignore ├── DisplayableCode.ts ├── KeyFingerprint.ts ├── KeySerialization.ts ├── PairwiseFingerprint.ts ├── README.md ├── __tests__ ├── DisplayableCode-test.ts ├── KeyFingerprint-test.ts ├── KeySerialization-test.ts └── PairwiseFingerprint-test.ts ├── jest-setup.js ├── jest.config.js ├── libdave.ts ├── package.json ├── pnpm-lock.yaml └── tsconfig.json /.github/actions/prepare-build/action.yaml: -------------------------------------------------------------------------------- 1 | name: Install build prerequisites 2 | 3 | inputs: 4 | runner: 5 | description: The runner on which the action is being run 6 | required: true 7 | crypto: 8 | description: The crypto library being used 9 | required: true 10 | cache-dir: 11 | description: Where to put vcpkg cache 12 | required: true 13 | 14 | runs: 15 | using: "composite" 16 | steps: 17 | - name: Capture vcpkg revision for use in cache key 
18 | shell: bash 19 | run: | 20 | git -C cpp/vcpkg rev-parse HEAD > cpp/vcpkg_commit.txt 21 | # hashFiles() patterns resolve from the workspace root: name the revision file written above by its full path, and use ** so the manifests nested under cpp/vcpkg-alts/ are hashed too (a bare * only matches the top-level directories) 22 | - name: Restore cache 23 | uses: actions/cache@v4 24 | with: 25 | path: ${{ inputs.cache-dir }} 26 | key: v01-vcpkg-${{ inputs.runner }}-${{ inputs.crypto }}-${{ hashFiles('cpp/vcpkg_commit.txt', 'cpp/vcpkg-alts/**') }} 27 | restore-keys: | 28 | v01-vcpkg-${{ inputs.runner }}-${{ inputs.crypto }} 29 | 30 | - name: Install dependencies (macOS) 31 | if: ${{ runner.os == 'macOS' }} 32 | shell: bash 33 | run: | 34 | brew install ninja go nasm 35 | 36 | - name: Install dependencies (Ubuntu) 37 | if: ${{ runner.os == 'Linux' }} 38 | shell: bash 39 | run: | 40 | sudo apt-get update && sudo apt-get install -y nasm 41 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["**"] 8 | 9 | env: 10 | CMAKE_BUILD_PARALLEL_LEVEL: 3 11 | CMAKE_TOOLCHAIN_FILE: ${{ github.workspace }}/cpp/vcpkg/scripts/buildsystems/vcpkg.cmake 12 | VCPKG_BINARY_SOURCES: files,${{ github.workspace }}/vcpkg_cache,readwrite 13 | 14 | jobs: 15 | build: 16 | strategy: 17 | matrix: 18 | runner: [ubuntu-latest, macos-latest] 19 | crypto: [boringssl] 20 | persistent_keys: [ON, OFF] 21 | fail-fast: false 22 | 23 | env: 24 | CRYPTO_ALT_DIR: "./vcpkg-alts/${{ matrix.crypto }}" 25 | 26 | runs-on: ${{matrix.runner}} 27 | 28 | defaults: 29 | run: 30 | working-directory: ./cpp 31 | 32 | steps: 33 | - uses: actions/checkout@v4 34 | with: 35 | submodules: "recursive" 36 | fetch-depth: 0 37 | 38 | - uses: ./.github/actions/prepare-build 39 | with: 40 | runner: ${{ matrix.runner }} 41 | crypto: ${{ matrix.crypto }} 42 | cache-dir: ${{ github.workspace }}/vcpkg_cache 43 | 44 | - name: configure 45 | run: >- 46 | cmake -B "${{ runner.temp }}/build_${{ matrix.crypto }}" 47 | -DVCPKG_MANIFEST_DIR="${{ env.CRYPTO_ALT_DIR
}}" 48 | -DPERSISTENT_KEYS=${{ matrix.persistent_keys }} 49 | -DTESTING=ON 50 | 51 | - name: build 52 | run: cmake --build "${{ runner.temp }}/build_${{ matrix.crypto }}" 53 | 54 | - name: test 55 | run: | 56 | ${{ runner.temp }}/build_${{ matrix.crypto }}/test/libdave_test 57 | 58 | - name: Archive mlspp 59 | uses: actions/upload-artifact@v4 60 | if: always() 61 | with: 62 | name: mlspp-${{ matrix.runner }}-${{ matrix.crypto }}-${{ matrix.persistent_keys }} 63 | path: | 64 | ${{ github.workspace }}/cpp/vcpkg/buildtrees/mlspp 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cpp/vcpkg"] 2 | path = cpp/vcpkg 3 | url = https://github.com/microsoft/vcpkg.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Discord 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## libdave 2 | 3 | This repository contains the JS and C++ libraries which together implement Discord's Audio & Video End-to-End Encryption (DAVE) protocol. These libraries are leveraged by Discord's native clients to support the DAVE protocol. 4 | 5 | The DAVE protocol is described in detail in the [protocol whitepaper](https://github.com/discord/dave-protocol). 
-------------------------------------------------------------------------------- /cpp/.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | AccessModifierOffset: -4 3 | AlignAfterOpenBracket: true 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlines: Left 7 | AlignOperands: false 8 | AlignTrailingComments: true 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: false 12 | AllowShortFunctionsOnASingleLine: InlineOnly 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AlwaysBreakAfterReturnType: None 16 | AlwaysBreakBeforeMultilineStrings: false 17 | AlwaysBreakTemplateDeclarations: true 18 | BinPackArguments: false 19 | BinPackParameters: false 20 | BreakBeforeBinaryOperators: None 21 | BreakBeforeBraces: Stroustrup 22 | BreakBeforeInheritanceComma: true 23 | BreakBeforeTernaryOperators: true 24 | BreakConstructorInitializers: BeforeComma 25 | BreakStringLiterals: true 26 | ColumnLimit: 100 27 | CommentPragmas: "" 28 | CompactNamespaces: false 29 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 30 | ConstructorInitializerIndentWidth: 2 31 | ContinuationIndentWidth: 2 32 | Cpp11BracedListStyle: true 33 | DerivePointerAlignment: false 34 | DisableFormat: false 35 | FixNamespaceComments: true 36 | ForEachMacros: [] 37 | IndentCaseLabels: false 38 | IncludeBlocks: Preserve 39 | IncludeCategories: 40 | - Regex: "^<(W|w)indows.h>" 41 | Priority: 1 42 | - Regex: "^<" 43 | Priority: 2 44 | - Regex: ".*" 45 | Priority: 3 46 | IncludeIsMainRegex: "(_test|_win|_linux|_mac|_ios|_osx|_null)?$" 47 | IndentPPDirectives: None 48 | IndentWidth: 4 49 | IndentWrappedFunctionNames: false 50 | KeepEmptyLinesAtTheStartOfBlocks: false 51 | MacroBlockBegin: "" 52 | MacroBlockEnd: "" 53 | MaxEmptyLinesToKeep: 1 54 | NamespaceIndentation: None 55 | 
PenaltyBreakAssignment: 0 56 | PenaltyBreakBeforeFirstCallParameter: 1 57 | PenaltyBreakComment: 300 58 | PenaltyBreakFirstLessLess: 120 59 | PenaltyBreakString: 1000 60 | PenaltyExcessCharacter: 1000000 61 | PenaltyReturnTypeOnItsOwnLine: 9999999 62 | PointerAlignment: Left 63 | ReflowComments: true 64 | SortIncludes: true 65 | SortUsingDeclarations: true 66 | SpaceAfterCStyleCast: false 67 | SpaceAfterTemplateKeyword: true 68 | SpaceBeforeAssignmentOperators: true 69 | SpaceBeforeParens: ControlStatements 70 | SpaceInEmptyParentheses: false 71 | SpacesBeforeTrailingComments: 1 72 | SpacesInAngles: false 73 | SpacesInCStyleCastParentheses: false 74 | SpacesInContainerLiterals: true 75 | SpacesInParentheses: false 76 | SpacesInSquareBrackets: false 77 | Standard: Cpp11 78 | TabWidth: 4 79 | UseTab: Never 80 | -------------------------------------------------------------------------------- /cpp/.gitignore: -------------------------------------------------------------------------------- 1 | build/ -------------------------------------------------------------------------------- /cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.20) 2 | 3 | project( 4 | libdave 5 | VERSION 1.0 6 | LANGUAGES CXX 7 | ) 8 | 9 | option(REQUIRE_BORINGSSL "Require BoringSSL instead of OpenSSL" ON) 10 | option(TESTING "Build tests" OFF) 11 | option(PERSISTENT_KEYS "Enable storage of persistent signature keys" OFF) 12 | 13 | include(CheckCXXCompilerFlag) 14 | 15 | set(CMAKE_CXX_STANDARD 17) 16 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 17 | 18 | if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU") 19 | add_compile_options(-Wall -pedantic -Wextra -Werror) 20 | elseif(MSVC) 21 | add_compile_options(/W4 /WX) 22 | add_definitions(-DWINDOWS) 23 | 24 | # MSVC helpfully recommends safer equivalents for things like 25 | # getenv, but they are not portable. 
26 | add_definitions(-D_CRT_SECURE_NO_WARNINGS) 27 | endif() 28 | 29 | find_package(OpenSSL REQUIRED) 30 | if (OPENSSL_FOUND) 31 | find_path(BORINGSSL_INCLUDE_DIR openssl/is_boringssl.h HINTS ${OPENSSL_INCLUDE_DIR} NO_DEFAULT_PATH) 32 | 33 | if (BORINGSSL_INCLUDE_DIR) 34 | message(STATUS "Found OpenSSL includes are for BoringSSL") 35 | 36 | add_compile_definitions(WITH_BORINGSSL) 37 | 38 | if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU") 39 | add_compile_options(-Wno-gnu-anonymous-struct -Wno-nested-anon-types) 40 | endif () 41 | # Extract the OpenSSL-compatible "X.Y.Z" from the OPENSSL_VERSION_TEXT define; if no matching define is found, keep the OPENSSL_VERSION already set by find_package(OpenSSL) instead of clobbering it with an empty string 42 | file(STRINGS "${OPENSSL_INCLUDE_DIR}/openssl/crypto.h" boringssl_version_str 43 | REGEX "^#[\t ]*define[\t ]+OPENSSL_VERSION_TEXT[\t ]+\"OpenSSL [0-9]+\\.[0-9]+\\.[0-9]+ .+") 44 | if (boringssl_version_str) 45 | string(REGEX REPLACE "^.*OPENSSL_VERSION_TEXT[\t ]+\"OpenSSL ([0-9]+\\.[0-9]+\\.[0-9]+) .+$" 46 | "\\1" OPENSSL_VERSION "${boringssl_version_str}") 47 | endif () 48 | elseif (REQUIRE_BORINGSSL) 49 | message(FATAL_ERROR "BoringSSL required but not found") 50 | endif () 51 | # Reference OPENSSL_VERSION by name (not ${...}) so if() stays well-formed even when the value is empty 52 | if (OPENSSL_VERSION VERSION_GREATER_EQUAL 3) 53 | message(FATAL_ERROR "OpenSSL 3.x is not supported") 54 | elseif(OPENSSL_VERSION VERSION_LESS 1.1.1) 55 | message(FATAL_ERROR "OpenSSL 1.1.1 or greater is required") 56 | endif() 57 | 58 | message(STATUS "OpenSSL Found: ${OPENSSL_VERSION}") 59 | message(STATUS "OpenSSL Include: ${OPENSSL_INCLUDE_DIR}") 60 | message(STATUS "OpenSSL Libraries: ${OPENSSL_LIBRARIES}") 61 | else() 62 | message(FATAL_ERROR "No OpenSSL library found") 63 | endif() 64 | 65 | find_package(MLSPP CONFIG REQUIRED) 66 | 67 | set(CMAKE_STATIC_LIBRARY_PREFIX "") 68 | 69 | SET(LIB_NAME ${PROJECT_NAME}) 70 | file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.h") 71 | file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") 72 | 73 | # remove all of the persistent key files 74 | list(FILTER LIB_SOURCES EXCLUDE REGEX ".*persisted_key.*") 75 | 76 | if (PERSISTENT_KEYS)
77 | # persistent keys enabled 78 | list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/dave/mls/persisted_key_pair.cpp") 79 | 80 | if (APPLE) 81 | # Apple has its own native and generic implementation, we just add the _apple.cpp file 82 | list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/dave/mls/detail/persisted_key_pair_apple.cpp") 83 | else () 84 | # Other platforms share the generic implementation 85 | list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/dave/mls/detail/persisted_key_pair_generic.cpp") 86 | 87 | if (WIN32) 88 | # Windows has a native implementation 89 | list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/dave/mls/detail/persisted_key_pair_win.cpp") 90 | else () 91 | # We don't have a native implementation, so we include the nullified native 92 | list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/dave/mls/detail/persisted_key_pair_null.cpp") 93 | endif () 94 | endif () 95 | 96 | else () 97 | # not using persistent keys, so we just need to add the null implementation 98 | list (APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/dave/mls/persisted_key_pair_null.cpp") 99 | endif () 100 | 101 | if (NOT WIN32) 102 | list(FILTER LIB_SOURCES EXCLUDE REGEX ".*_win.cpp") 103 | endif () 104 | 105 | if (NOT APPLE) 106 | list(FILTER LIB_SOURCES EXCLUDE REGEX ".*_apple.cpp") 107 | endif () 108 | 109 | add_library(${LIB_NAME} ${LIB_HEADERS} ${LIB_SOURCES}) 110 | 111 | target_include_directories( 112 | ${LIB_NAME} 113 | PUBLIC 114 | $ 115 | $ 116 | ${OPENSSL_INCLUDE_DIR} 117 | ) 118 | 119 | target_link_libraries(${LIB_NAME} PUBLIC OpenSSL::Crypto) 120 | target_link_libraries(${LIB_NAME} PUBLIC MLSPP::mlspp) 121 | 122 | if (TESTING) 123 | add_subdirectory(test) 124 | endif() -------------------------------------------------------------------------------- /cpp/Makefile: -------------------------------------------------------------------------------- 1 | BUILD_DIR=build 2 | TEST_DIR=build/test 3 | CLANG_FORMAT=clang-format -i 
-style=file:.clang-format 4 | BORINGSSL_MANIFEST=vcpkg-alts/boringssl 5 | TOOLCHAIN_FILE=vcpkg/scripts/buildsystems/vcpkg.cmake 6 | 7 | all: ${BUILD_DIR} 8 | cmake --build ${BUILD_DIR} --target libdave 9 | 10 | ${BUILD_DIR}: CMakeLists.txt test/CMakeLists.txt 11 | cmake -B${BUILD_DIR} \ 12 | -DVCPKG_MANIFEST_DIR=${BORINGSSL_MANIFEST} \ 13 | -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} 14 | 15 | dev: ${TOOLCHAIN_FILE} 16 | cmake -B${BUILD_DIR} -DTESTING=ON -DCMAKE_BUILD_TYPE=Debug \ 17 | -DVCPKG_MANIFEST_DIR=${BORINGSSL_MANIFEST} \ 18 | -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} 19 | 20 | test: dev test/* 21 | cmake --build ${BUILD_DIR} --target libdave_test 22 | 23 | dtest: test 24 | ${TEST_DIR}/libdave_test 25 | 26 | dbtest: test 27 | lldb ${TEST_DIR}/libdave_test 28 | 29 | ctest: test 30 | cmake --build ${BUILD_DIR} --target test 31 | 32 | clean: 33 | cmake --build ${BUILD_DIR} --target clean 34 | 35 | cclean: 36 | rm -rf ${BUILD_DIR} 37 | 38 | format: 39 | find src -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT} 40 | find test -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT} 41 | -------------------------------------------------------------------------------- /cpp/README.md: -------------------------------------------------------------------------------- 1 | ## libdave C++ 2 | 3 | Contains the libdave C++ library, which handles the bulk of the DAVE protocol implementation for Discord's native clients. 
4 | 5 | ### Dependencies 6 | 7 | - [mlspp](https://github.com/cisco/mlspp) 8 | - Configured with `-DMLS_CXX_NAMESPACE="mlspp"` and `-DDISABLE_GREASE=ON` 9 | - [boringssl](https://boringssl.googlesource.com/boringssl) 10 | 11 | #### Testing 12 | 13 | - [googletest](https://github.com/google/googletest) 14 | - [AFLplusplus](https://github.com/AFLplusplus/AFLplusplus) 15 | -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-1391d5ec4099.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-1391d5ec4099.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-15011afa9422.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-15011afa9422.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-196b5e49ee87.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-196b5e49ee87.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-1b33ca8199f4.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-1b33ca8199f4.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-235ec4531793.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-235ec4531793.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-27832f299ed2.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-27832f299ed2.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-2fee9654c30c.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-2fee9654c30c.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-34ad70c6cca5.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-34ad70c6cca5.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-369d0c4bf7a9.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-369d0c4bf7a9.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-38af39dc2d29.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-38af39dc2d29.bin -------------------------------------------------------------------------------- 
/cpp/afl-driver/input/test-input-4a64a107f0cb.bin: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-4aee27bef7b5.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-4aee27bef7b5.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-4e289323f5d6.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-4e289323f5d6.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-54027a4fe751.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-54027a4fe751.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-58b2bde5f8af.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-58b2bde5f8af.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-5bc9e95a145a.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-5bc9e95a145a.bin -------------------------------------------------------------------------------- 
/cpp/afl-driver/input/test-input-5feceb66ffc8.bin: -------------------------------------------------------------------------------- 1 | 0 -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-65465e652d86.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-65465e652d86.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-6c36230ff10e.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-6c36230ff10e.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-6ce6cf70fffc.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-6ce6cf70fffc.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-6ea581a0e757.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-6ea581a0e757.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-802c20fe2e20.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-802c20fe2e20.bin -------------------------------------------------------------------------------- 
/cpp/afl-driver/input/test-input-88d385a01d77.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-88d385a01d77.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-9af15b336e6a.bin: -------------------------------------------------------------------------------- 1 | 0000 -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-9f3d0f837f35.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-9f3d0f837f35.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-a1bbf2e554bb.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-a1bbf2e554bb.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-ab801f7b5f13.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-ab801f7b5f13.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-bb5639198655.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-bb5639198655.bin -------------------------------------------------------------------------------- 
/cpp/afl-driver/input/test-input-c037db48e6e1.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-c037db48e6e1.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-d355cded3e49.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-d355cded3e49.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-de009b6db0fd.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-de009b6db0fd.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-de15ac7d6d86.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-de15ac7d6d86.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-dfa882933f94.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-dfa882933f94.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-f456f8617776.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-f456f8617776.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-f91405489d5d.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-f91405489d5d.bin -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-fcdb4b423f4e.bin: -------------------------------------------------------------------------------- 1 | 0000000000000000 -------------------------------------------------------------------------------- /cpp/afl-driver/input/test-input-fe9bdea66eab.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord/libdave/6e5ffbc1cb4eef6be96e8115c4626be598b7e501/cpp/afl-driver/input/test-input-fe9bdea66eab.bin -------------------------------------------------------------------------------- /cpp/afl-driver/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "dave/common.h" 10 | #include "dave/utils/array_view.h" 11 | #include "dave/decryptor.h" 12 | 13 | using namespace discord::dave; 14 | 15 | extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) 16 | { 17 | FuzzedDataProvider provider(data, size); 18 | MediaType mediaType = static_cast(provider.ConsumeIntegralInRange(0, 2)); 19 | const auto InFrame = provider.ConsumeRemainingBytes(); 20 | 21 | Decryptor decryptor; 22 | const auto OutFrameSize = decryptor.GetMaxPlaintextByteSize(mediaType, InFrame.size()); 23 | auto outFrame = std::make_unique(OutFrameSize); 24 | [[maybe_unused]] auto res = 
decryptor.Decrypt(mediaType, 25 | MakeArrayView(InFrame.data(), InFrame.size()), 26 | MakeArrayView(outFrame.get(), OutFrameSize)); 27 | return 0; 28 | } 29 | -------------------------------------------------------------------------------- /cpp/src/dave/boringssl_cryptor.cpp: -------------------------------------------------------------------------------- 1 | #include "boringssl_cryptor.h" 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "dave/common.h" 8 | #include "dave/logger.h" 9 | 10 | namespace discord { 11 | namespace dave { 12 | 13 | void PrintSSLErrors() 14 | { 15 | ERR_print_errors_cb( 16 | [](const char* str, size_t len, [[maybe_unused]] void* ctx) { 17 | DISCORD_LOG(LS_ERROR) << std::string(str, len); 18 | return 1; 19 | }, 20 | nullptr); 21 | } 22 | 23 | BoringSSLCryptor::BoringSSLCryptor(const EncryptionKey& encryptionKey) 24 | { 25 | EVP_AEAD_CTX_zero(&cipherCtx_); 26 | 27 | auto initResult = EVP_AEAD_CTX_init(&cipherCtx_, 28 | EVP_aead_aes_128_gcm(), 29 | encryptionKey.data(), 30 | encryptionKey.size(), 31 | kAesGcm128TruncatedTagBytes, 32 | nullptr); 33 | 34 | if (initResult != 1) { 35 | DISCORD_LOG(LS_ERROR) << "Failed to initialize AEAD context"; 36 | PrintSSLErrors(); 37 | } 38 | } 39 | 40 | BoringSSLCryptor::~BoringSSLCryptor() 41 | { 42 | EVP_AEAD_CTX_cleanup(&cipherCtx_); 43 | } 44 | 45 | bool BoringSSLCryptor::Encrypt(ArrayView ciphertextBufferOut, 46 | ArrayView plaintextBuffer, 47 | ArrayView nonceBuffer, 48 | ArrayView additionalData, 49 | ArrayView tagBufferOut) 50 | { 51 | if (cipherCtx_.aead == nullptr) { 52 | DISCORD_LOG(LS_ERROR) << "Encrypt: AEAD context is not initialized"; 53 | return false; 54 | } 55 | 56 | size_t tagSizeOut; 57 | auto encryptResult = EVP_AEAD_CTX_seal_scatter(&cipherCtx_, 58 | ciphertextBufferOut.data(), 59 | tagBufferOut.data(), 60 | &tagSizeOut, 61 | kAesGcm128TruncatedTagBytes, 62 | nonceBuffer.data(), 63 | kAesGcm128NonceBytes, 64 | plaintextBuffer.data(), 65 | plaintextBuffer.size(), 66 | nullptr, 67 | 0, 
68 | additionalData.data(), 69 | additionalData.size()); 70 | if (encryptResult != 1) { 71 | DISCORD_LOG(LS_ERROR) << "Failed to encrypt data"; 72 | PrintSSLErrors(); 73 | } 74 | 75 | return encryptResult == 1; 76 | } 77 | 78 | bool BoringSSLCryptor::Decrypt(ArrayView plaintextBufferOut, 79 | ArrayView ciphertextBuffer, 80 | ArrayView tagBuffer, 81 | ArrayView nonceBuffer, 82 | ArrayView additionalData) 83 | { 84 | if (cipherCtx_.aead == nullptr) { 85 | DISCORD_LOG(LS_ERROR) << "Decrypt: AEAD context is not initialized"; 86 | return false; 87 | } 88 | 89 | auto decryptResult = EVP_AEAD_CTX_open_gather(&cipherCtx_, 90 | plaintextBufferOut.data(), 91 | nonceBuffer.data(), 92 | kAesGcm128NonceBytes, 93 | ciphertextBuffer.data(), 94 | ciphertextBuffer.size(), 95 | tagBuffer.data(), 96 | kAesGcm128TruncatedTagBytes, 97 | additionalData.data(), 98 | additionalData.size()); 99 | 100 | return decryptResult == 1; 101 | } 102 | 103 | } // namespace dave 104 | } // namespace discord 105 | -------------------------------------------------------------------------------- /cpp/src/dave/boringssl_cryptor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "dave/cryptor.h" 6 | 7 | namespace discord { 8 | namespace dave { 9 | 10 | class BoringSSLCryptor : public ICryptor { 11 | public: 12 | BoringSSLCryptor(const EncryptionKey& encryptionKey); 13 | ~BoringSSLCryptor(); 14 | 15 | bool IsValid() const { return cipherCtx_.aead != nullptr; } 16 | 17 | bool Encrypt(ArrayView ciphertextBufferOut, 18 | ArrayView plaintextBuffer, 19 | ArrayView nonceBuffer, 20 | ArrayView additionalData, 21 | ArrayView tagBufferOut) override; 22 | bool Decrypt(ArrayView plaintextBufferOut, 23 | ArrayView ciphertextBuffer, 24 | ArrayView tagBuffer, 25 | ArrayView nonceBuffer, 26 | ArrayView additionalData) override; 27 | 28 | private: 29 | EVP_AEAD_CTX cipherCtx_; 30 | }; 31 | 32 | } // namespace dave 33 | } // namespace 
discord 34 | -------------------------------------------------------------------------------- /cpp/src/dave/codec_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "common.h" 4 | #include "dave/frame_processors.h" 5 | #include "utils/array_view.h" 6 | 7 | namespace discord { 8 | namespace dave { 9 | namespace codec_utils { 10 | 11 | bool ProcessFrameOpus(OutboundFrameProcessor& processor, ArrayView frame); 12 | bool ProcessFrameVp8(OutboundFrameProcessor& processor, ArrayView frame); 13 | bool ProcessFrameVp9(OutboundFrameProcessor& processor, ArrayView frame); 14 | bool ProcessFrameH264(OutboundFrameProcessor& processor, ArrayView frame); 15 | bool ProcessFrameH265(OutboundFrameProcessor& processor, ArrayView frame); 16 | bool ProcessFrameAv1(OutboundFrameProcessor& processor, ArrayView frame); 17 | 18 | bool ValidateEncryptedFrame(OutboundFrameProcessor& processor, ArrayView frame); 19 | 20 | } // namespace codec_utils 21 | } // namespace dave 22 | } // namespace discord 23 | -------------------------------------------------------------------------------- /cpp/src/dave/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "version.h" 12 | 13 | namespace mlspp::bytes_ns { 14 | struct bytes; 15 | }; 16 | 17 | namespace discord { 18 | namespace dave { 19 | 20 | using UnencryptedFrameHeaderSize = uint16_t; 21 | using TruncatedSyncNonce = uint32_t; 22 | using MagicMarker = uint16_t; 23 | using EncryptionKey = ::mlspp::bytes_ns::bytes; 24 | using TransitionId = uint16_t; 25 | using SupplementalBytesSize = uint8_t; 26 | 27 | enum MediaType : uint8_t { Audio, Video }; 28 | enum Codec : uint8_t { Unknown, Opus, VP8, VP9, H264, H265, AV1 }; 29 | 30 | // Returned in std::variant when a message is hard-rejected and should trigger a reset 
31 | struct failed_t {}; 32 | 33 | // Returned in std::variant when a message is soft-rejected and should not trigger a reset 34 | struct ignored_t {}; 35 | 36 | // Map of ID-key pairs. 37 | // In ProcessCommit, this lists IDs whose keys have been added, changed, or removed; 38 | // an empty value value means a key was removed. 39 | using RosterMap = std::map>; 40 | 41 | // Return type for functions producing RosterMap or hard or soft failures 42 | using RosterVariant = std::variant; 43 | 44 | constexpr MagicMarker kMarkerBytes = 0xFAFA; 45 | 46 | // Layout constants 47 | constexpr size_t kAesGcm128KeyBytes = 16; 48 | constexpr size_t kAesGcm128NonceBytes = 12; 49 | constexpr size_t kAesGcm128TruncatedSyncNonceBytes = 4; 50 | constexpr size_t kAesGcm128TruncatedSyncNonceOffset = 51 | kAesGcm128NonceBytes - kAesGcm128TruncatedSyncNonceBytes; 52 | constexpr size_t kAesGcm128TruncatedTagBytes = 8; 53 | constexpr size_t kRatchetGenerationBytes = 1; 54 | constexpr size_t kRatchetGenerationShiftBits = 55 | 8 * (kAesGcm128TruncatedSyncNonceBytes - kRatchetGenerationBytes); 56 | constexpr size_t kSupplementalBytes = 57 | kAesGcm128TruncatedTagBytes + sizeof(SupplementalBytesSize) + sizeof(MagicMarker); 58 | constexpr size_t kTransformPaddingBytes = 64; 59 | 60 | // Timing constants 61 | constexpr auto kDefaultTransitionDuration = std::chrono::seconds(10); 62 | constexpr auto kCryptorExpiry = std::chrono::seconds(10); 63 | 64 | // Behavior constants 65 | constexpr auto kInitTransitionId = 0; 66 | constexpr auto kDisabledVersion = 0; 67 | constexpr auto kMaxGenerationGap = 250; 68 | constexpr auto kMaxMissingNonces = 1000; 69 | constexpr auto kGenerationWrap = 1 << (8 * kRatchetGenerationBytes); 70 | constexpr auto kMaxFramesPerSecond = 50 + 2 * 60; // 50 audio frames + 2 * 60fps video streams 71 | constexpr std::array kOpusSilencePacket = {0xF8, 0xFF, 0xFE}; 72 | 73 | // Utility routine for variant return types 74 | template 75 | inline std::optional GetOptional(V&& 
variant) 76 | { 77 | if (auto map = std::get_if(&variant)) { 78 | if constexpr (std::is_rvalue_reference_v) { 79 | return std::move(*map); 80 | } 81 | else { 82 | return *map; 83 | } 84 | } 85 | else { 86 | return std::nullopt; 87 | } 88 | } 89 | 90 | } // namespace dave 91 | } // namespace discord 92 | -------------------------------------------------------------------------------- /cpp/src/dave/cryptor.cpp: -------------------------------------------------------------------------------- 1 | #include "cryptor.h" 2 | 3 | #include "boringssl_cryptor.h" 4 | 5 | namespace discord { 6 | namespace dave { 7 | 8 | std::unique_ptr CreateCryptor(const EncryptionKey& encryptionKey) 9 | { 10 | auto cryptor = std::make_unique(encryptionKey); 11 | return cryptor->IsValid() ? std::move(cryptor) : nullptr; 12 | } 13 | 14 | } // namespace dave 15 | } // namespace discord 16 | -------------------------------------------------------------------------------- /cpp/src/dave/cryptor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "common.h" 6 | #include "utils/array_view.h" 7 | 8 | namespace discord { 9 | namespace dave { 10 | 11 | class ICryptor { 12 | public: 13 | virtual ~ICryptor() = default; 14 | 15 | virtual bool Encrypt(ArrayView ciphertextBufferOut, 16 | ArrayView plaintextBuffer, 17 | ArrayView nonceBuffer, 18 | ArrayView additionalData, 19 | ArrayView tagBufferOut) = 0; 20 | virtual bool Decrypt(ArrayView plaintextBufferOut, 21 | ArrayView ciphertextBuffer, 22 | ArrayView tagBuffer, 23 | ArrayView nonceBuffer, 24 | ArrayView additionalData) = 0; 25 | }; 26 | 27 | std::unique_ptr CreateCryptor(const EncryptionKey& encryptionKey); 28 | 29 | } // namespace dave 30 | } // namespace discord 31 | -------------------------------------------------------------------------------- /cpp/src/dave/cryptor_manager.cpp: -------------------------------------------------------------------------------- 1 | 
#include "cryptor_manager.h" 2 | 3 | #include 4 | 5 | #include "key_ratchet.h" 6 | #include "logger.h" 7 | 8 | #include 9 | 10 | using namespace std::chrono_literals; 11 | 12 | namespace discord { 13 | namespace dave { 14 | 15 | KeyGeneration ComputeWrappedGeneration(KeyGeneration oldest, KeyGeneration generation) 16 | { 17 | // Assume generation is greater than or equal to oldest, this may be wrong in a few cases but 18 | // will be caught by the max generation gap check. 19 | auto remainder = oldest % kGenerationWrap; 20 | auto factor = oldest / kGenerationWrap + (generation < remainder ? 1 : 0); 21 | return factor * kGenerationWrap + generation; 22 | } 23 | 24 | BigNonce ComputeWrappedBigNonce(KeyGeneration generation, TruncatedSyncNonce nonce) 25 | { 26 | // Remove the generation bits from the nonce 27 | auto maskedNonce = nonce & ((1 << kRatchetGenerationShiftBits) - 1); 28 | // Add the wrapped generation bits back in 29 | return static_cast(generation) << kRatchetGenerationShiftBits | maskedNonce; 30 | } 31 | 32 | CryptorManager::CryptorManager(const IClock& clock, std::unique_ptr keyRatchet) 33 | : clock_(clock) 34 | , keyRatchet_(std::move(keyRatchet)) 35 | , ratchetCreation_(clock.Now()) 36 | , ratchetExpiry_(TimePoint::max()) 37 | { 38 | } 39 | 40 | bool CryptorManager::CanProcessNonce(KeyGeneration generation, TruncatedSyncNonce nonce) const 41 | { 42 | if (!newestProcessedNonce_) { 43 | return true; 44 | } 45 | 46 | auto bigNonce = ComputeWrappedBigNonce(generation, nonce); 47 | return bigNonce > *newestProcessedNonce_ || 48 | std::find(missingNonces_.rbegin(), missingNonces_.rend(), bigNonce) != missingNonces_.rend(); 49 | } 50 | 51 | ICryptor* CryptorManager::GetCryptor(KeyGeneration generation) 52 | { 53 | CleanupExpiredCryptors(); 54 | 55 | if (generation < oldestGeneration_) { 56 | DISCORD_LOG(LS_INFO) << "Received frame with old generation: " << generation 57 | << ", oldest generation: " << oldestGeneration_; 58 | return nullptr; 59 | } 60 | 61 | 
if (generation > newestGeneration_ + kMaxGenerationGap) { 62 | DISCORD_LOG(LS_INFO) << "Received frame with future generation: " << generation 63 | << ", newest generation: " << newestGeneration_; 64 | return nullptr; 65 | } 66 | 67 | auto ratchetLifetimeSec = 68 | std::chrono::duration_cast(clock_.Now() - ratchetCreation_).count(); 69 | auto maxLifetimeFrames = kMaxFramesPerSecond * ratchetLifetimeSec; 70 | auto maxLifetimeGenerations = maxLifetimeFrames >> kRatchetGenerationShiftBits; 71 | if (generation > maxLifetimeGenerations) { 72 | DISCORD_LOG(LS_INFO) << "Received frame with generation " << generation 73 | << " beyond ratchet max lifetime generations: " 74 | << maxLifetimeGenerations 75 | << ", ratchet lifetime: " << ratchetLifetimeSec << "s"; 76 | return nullptr; 77 | } 78 | 79 | auto it = cryptors_.find(generation); 80 | if (it == cryptors_.end()) { 81 | // We don't have a cryptor for this generation, create one 82 | std::tie(it, std::ignore) = cryptors_.emplace(generation, MakeExpiringCryptor(generation)); 83 | } 84 | 85 | // Return a non-owning pointer to the cryptor 86 | auto& [cryptor, expiry] = it->second; 87 | return cryptor.get(); 88 | } 89 | 90 | void CryptorManager::ReportCryptorSuccess(KeyGeneration generation, TruncatedSyncNonce nonce) 91 | { 92 | auto bigNonce = ComputeWrappedBigNonce(generation, nonce); 93 | 94 | // Add any missing nonces to the queue 95 | if (!newestProcessedNonce_) { 96 | newestProcessedNonce_ = bigNonce; 97 | } 98 | else if (bigNonce > *newestProcessedNonce_) { 99 | auto oldestMissingNonce = bigNonce > kMaxMissingNonces ? 
bigNonce - kMaxMissingNonces : 0; 100 | 101 | while (!missingNonces_.empty() && missingNonces_.front() < oldestMissingNonce) { 102 | missingNonces_.pop_front(); 103 | } 104 | 105 | // If we're missing a lot, we don't want to add everything since newestProcessedNonce_ 106 | auto missingRangeStart = std::max(oldestMissingNonce, *newestProcessedNonce_ + 1); 107 | for (auto i = missingRangeStart; i < bigNonce; ++i) { 108 | missingNonces_.push_back(i); 109 | } 110 | 111 | // Update the newest processed nonce 112 | newestProcessedNonce_ = bigNonce; 113 | } 114 | else { 115 | auto it = std::find(missingNonces_.begin(), missingNonces_.end(), bigNonce); 116 | if (it != missingNonces_.end()) { 117 | missingNonces_.erase(it); 118 | } 119 | } 120 | 121 | if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) { 122 | return; 123 | } 124 | DISCORD_LOG(LS_INFO) << "Reporting cryptor success, generation: " << generation; 125 | newestGeneration_ = generation; 126 | 127 | // Update the expiry time for all old cryptors 128 | const auto expiryTime = clock_.Now() + kCryptorExpiry; 129 | for (auto& [gen, cryptor] : cryptors_) { 130 | if (gen < newestGeneration_) { 131 | DISCORD_LOG(LS_INFO) << "Updating expiry for cryptor, generation: " << gen; 132 | cryptor.expiry = std::min(cryptor.expiry, expiryTime); 133 | } 134 | } 135 | } 136 | 137 | KeyGeneration CryptorManager::ComputeWrappedGeneration(KeyGeneration generation) 138 | { 139 | return ::discord::dave::ComputeWrappedGeneration(oldestGeneration_, generation); 140 | } 141 | 142 | CryptorManager::ExpiringCryptor CryptorManager::MakeExpiringCryptor(KeyGeneration generation) 143 | { 144 | // Get the new key from the ratchet 145 | auto encryptionKey = keyRatchet_->GetKey(generation); 146 | auto expiryTime = TimePoint::max(); 147 | 148 | // If we got frames out of order, we might have to create a cryptor for an old generation 149 | // In that case, create it with a non-infinite expiry time as we have already 
transitioned 150 | // to a newer generation 151 | if (generation < newestGeneration_) { 152 | DISCORD_LOG(LS_INFO) << "Creating cryptor for old generation: " << generation; 153 | expiryTime = clock_.Now() + kCryptorExpiry; 154 | } 155 | else { 156 | DISCORD_LOG(LS_INFO) << "Creating cryptor for new generation: " << generation; 157 | } 158 | 159 | return {CreateCryptor(encryptionKey), expiryTime}; 160 | } 161 | 162 | void CryptorManager::CleanupExpiredCryptors() 163 | { 164 | for (auto it = cryptors_.begin(); it != cryptors_.end();) { 165 | auto& [generation, cryptor] = *it; 166 | 167 | bool expired = cryptor.expiry < clock_.Now(); 168 | if (expired) { 169 | DISCORD_LOG(LS_INFO) << "Removing expired cryptor, generation: " << generation; 170 | } 171 | 172 | it = expired ? cryptors_.erase(it) : ++it; 173 | } 174 | 175 | while (oldestGeneration_ < newestGeneration_ && 176 | cryptors_.find(oldestGeneration_) == cryptors_.end()) { 177 | DISCORD_LOG(LS_INFO) << "Deleting key for old generation: " << oldestGeneration_; 178 | keyRatchet_->DeleteKey(oldestGeneration_); 179 | ++oldestGeneration_; 180 | } 181 | } 182 | 183 | } // namespace dave 184 | } // namespace discord 185 | -------------------------------------------------------------------------------- /cpp/src/dave/cryptor_manager.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cryptor.h" 9 | #include "dave/common.h" 10 | #include "key_ratchet.h" 11 | #include "utils/clock.h" 12 | 13 | namespace discord { 14 | namespace dave { 15 | 16 | KeyGeneration ComputeWrappedGeneration(KeyGeneration oldest, KeyGeneration generation); 17 | 18 | using BigNonce = uint64_t; 19 | BigNonce ComputeWrappedBigNonce(KeyGeneration generation, TruncatedSyncNonce nonce); 20 | 21 | class CryptorManager { 22 | public: 23 | using TimePoint = typename IClock::TimePoint; 24 | 25 | CryptorManager(const IClock& clock, 
std::unique_ptr keyRatchet); 26 | 27 | void UpdateExpiry(TimePoint expiry) { ratchetExpiry_ = expiry; } 28 | bool IsExpired() const { return clock_.Now() > ratchetExpiry_; } 29 | 30 | bool CanProcessNonce(KeyGeneration generation, TruncatedSyncNonce nonce) const; 31 | KeyGeneration ComputeWrappedGeneration(KeyGeneration generation); 32 | 33 | ICryptor* GetCryptor(KeyGeneration generation); 34 | void ReportCryptorSuccess(KeyGeneration generation, TruncatedSyncNonce nonce); 35 | 36 | private: 37 | struct ExpiringCryptor { 38 | std::unique_ptr cryptor; 39 | TimePoint expiry; 40 | }; 41 | 42 | ExpiringCryptor MakeExpiringCryptor(KeyGeneration generation); 43 | void CleanupExpiredCryptors(); 44 | 45 | const IClock& clock_; 46 | std::unique_ptr keyRatchet_; 47 | std::unordered_map cryptors_; 48 | 49 | TimePoint ratchetCreation_; 50 | TimePoint ratchetExpiry_; 51 | KeyGeneration oldestGeneration_{0}; 52 | KeyGeneration newestGeneration_{0}; 53 | 54 | std::optional newestProcessedNonce_; 55 | std::deque missingNonces_; 56 | }; 57 | 58 | } // namespace dave 59 | } // namespace discord 60 | -------------------------------------------------------------------------------- /cpp/src/dave/decryptor.cpp: -------------------------------------------------------------------------------- 1 | #include "decryptor.h" 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "common.h" 8 | #include "logger.h" 9 | #include "utils/leb128.h" 10 | #include "utils/scope_exit.h" 11 | 12 | using namespace std::chrono_literals; 13 | 14 | namespace discord { 15 | namespace dave { 16 | 17 | constexpr auto kStatsInterval = 10s; 18 | 19 | void Decryptor::TransitionToKeyRatchet(std::unique_ptr keyRatchet, 20 | Duration transitionExpiry) 21 | { 22 | DISCORD_LOG(LS_INFO) << "Transitioning to new key ratchet: " << keyRatchet.get() 23 | << ", expiry: " << transitionExpiry.count(); 24 | 25 | // Update the expiry time for all existing cryptor managers 26 | UpdateCryptorManagerExpiry(transitionExpiry); 27 | 28 | if 
(keyRatchet) { 29 | cryptorManagers_.emplace_back(clock_, std::move(keyRatchet)); 30 | } 31 | } 32 | 33 | void Decryptor::TransitionToPassthroughMode(bool passthroughMode, Duration transitionExpiry) 34 | { 35 | if (passthroughMode) { 36 | allowPassThroughUntil_ = TimePoint::max(); 37 | } 38 | else { 39 | // Update the pass through mode expiry 40 | auto maxExpiry = clock_.Now() + transitionExpiry; 41 | allowPassThroughUntil_ = std::min(allowPassThroughUntil_, maxExpiry); 42 | } 43 | } 44 | 45 | size_t Decryptor::Decrypt(MediaType mediaType, 46 | ArrayView encryptedFrame, 47 | ArrayView frame) 48 | { 49 | if (mediaType != Audio && mediaType != Video) { 50 | DISCORD_LOG(LS_WARNING) << "Decrypt failed, invalid media type: " 51 | << static_cast(mediaType); 52 | return 0; 53 | } 54 | 55 | auto start = clock_.Now(); 56 | 57 | auto localFrame = GetOrCreateFrameProcessor(); 58 | ScopeExit cleanup([&] { ReturnFrameProcessor(std::move(localFrame)); }); 59 | 60 | // Skip decrypting for silence frames 61 | if (mediaType == Audio && encryptedFrame.size() == kOpusSilencePacket.size() && 62 | memcmp(encryptedFrame.data(), kOpusSilencePacket.data(), kOpusSilencePacket.size()) == 0) { 63 | DISCORD_LOG(LS_VERBOSE) << "Decrypt skipping silence of size: " << encryptedFrame.size(); 64 | if (encryptedFrame.data() != frame.data()) { 65 | memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); 66 | } 67 | return encryptedFrame.size(); 68 | } 69 | 70 | // Remove any expired cryptor manager 71 | CleanupExpiredCryptorManagers(); 72 | 73 | // Process the incoming frame 74 | // This will check whether it looks like a valid encrypted frame 75 | // and if so it will parse it into its different components 76 | localFrame->ParseFrame(encryptedFrame); 77 | 78 | // If the frame is not encrypted and we can pass it through, do it 79 | bool canUsePassThrough = allowPassThroughUntil_ > start; 80 | if (!localFrame->IsEncrypted() && canUsePassThrough) { 81 | if (encryptedFrame.data() != 
frame.data()) { 82 | memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); 83 | } 84 | stats_[mediaType].passthroughCount++; 85 | return encryptedFrame.size(); 86 | } 87 | 88 | // If the frame is not encrypted and we can't pass it through, fail 89 | if (!localFrame->IsEncrypted()) { 90 | DISCORD_LOG(LS_INFO) 91 | << "Decrypt failed, frame is not encrypted and pass through is disabled"; 92 | stats_[mediaType].decryptFailureCount++; 93 | return 0; 94 | } 95 | 96 | // Try and decrypt with each valid cryptor 97 | // reverse iterate to try the newest cryptors first 98 | bool success = false; 99 | for (auto it = cryptorManagers_.rbegin(); it != cryptorManagers_.rend(); ++it) { 100 | auto& cryptorManager = *it; 101 | success = DecryptImpl(cryptorManager, mediaType, *localFrame, frame); 102 | if (success) { 103 | break; 104 | } 105 | } 106 | 107 | size_t bytesWritten = 0; 108 | if (success) { 109 | stats_[mediaType].decryptSuccessCount++; 110 | bytesWritten = localFrame->ReconstructFrame(frame); 111 | } 112 | else { 113 | stats_[mediaType].decryptFailureCount++; 114 | DISCORD_LOG(LS_WARNING) << "Decrypt failed, no valid cryptor found, type: " 115 | << (mediaType ? "video" : "audio") 116 | << ", encrypted frame size: " << encryptedFrame.size() 117 | << ", plaintext frame size: " << frame.size() 118 | << ", number of cryptor managers: " << cryptorManagers_.size() 119 | << ", pass through enabled: " << (canUsePassThrough ? "yes" : "no"); 120 | } 121 | 122 | auto end = clock_.Now(); 123 | if (end > lastStatsTime_ + kStatsInterval) { 124 | lastStatsTime_ = end; 125 | DISCORD_LOG(LS_INFO) << "Decrypted audio: " << stats_[Audio].decryptSuccessCount 126 | << ", video: " << stats_[Video].decryptSuccessCount 127 | << ". 
Failed audio: " << stats_[Audio].decryptFailureCount 128 | << ", video: " << stats_[Video].decryptFailureCount; 129 | } 130 | stats_[mediaType].decryptDuration += 131 | std::chrono::duration_cast(end - start).count(); 132 | 133 | return bytesWritten; 134 | } 135 | 136 | bool Decryptor::DecryptImpl(CryptorManager& cryptorManager, 137 | MediaType mediaType, 138 | InboundFrameProcessor& encryptedFrame, 139 | [[maybe_unused]] ArrayView frame) 140 | { 141 | auto tag = encryptedFrame.GetTag(); 142 | auto truncatedNonce = encryptedFrame.GetTruncatedNonce(); 143 | 144 | auto authenticatedData = encryptedFrame.GetAuthenticatedData(); 145 | auto ciphertext = encryptedFrame.GetCiphertext(); 146 | auto plaintext = encryptedFrame.GetPlaintext(); 147 | 148 | // expand the truncated nonce to the full sized one needed for decryption 149 | auto nonceBuffer = std::array(); 150 | memcpy(nonceBuffer.data() + kAesGcm128TruncatedSyncNonceOffset, 151 | &truncatedNonce, 152 | kAesGcm128TruncatedSyncNonceBytes); 153 | 154 | auto nonceBufferView = MakeArrayView(nonceBuffer.data(), nonceBuffer.size()); 155 | 156 | auto generation = 157 | cryptorManager.ComputeWrappedGeneration(truncatedNonce >> kRatchetGenerationShiftBits); 158 | 159 | if (!cryptorManager.CanProcessNonce(generation, truncatedNonce)) { 160 | DISCORD_LOG(LS_INFO) << "Decrypt failed, cannot process nonce: " << truncatedNonce; 161 | return false; 162 | } 163 | 164 | // Get the cryptor for this generation 165 | ICryptor* cryptor = cryptorManager.GetCryptor(generation); 166 | 167 | if (!cryptor) { 168 | DISCORD_LOG(LS_INFO) << "Decrypt failed, no cryptor found for generation: " << generation; 169 | return false; 170 | } 171 | 172 | // perform the decryption 173 | bool success = cryptor->Decrypt(plaintext, ciphertext, tag, nonceBufferView, authenticatedData); 174 | stats_[mediaType].decryptAttempts++; 175 | 176 | if (success) { 177 | cryptorManager.ReportCryptorSuccess(generation, truncatedNonce); 178 | } 179 | 180 | return 
success; 181 | } 182 | 183 | size_t Decryptor::GetMaxPlaintextByteSize([[maybe_unused]] MediaType mediaType, 184 | size_t encryptedFrameSize) 185 | { 186 | return encryptedFrameSize; 187 | } 188 | 189 | void Decryptor::UpdateCryptorManagerExpiry(Duration expiry) 190 | { 191 | auto maxExpiryTime = clock_.Now() + expiry; 192 | for (auto& cryptorManager : cryptorManagers_) { 193 | cryptorManager.UpdateExpiry(maxExpiryTime); 194 | } 195 | } 196 | 197 | void Decryptor::CleanupExpiredCryptorManagers() 198 | { 199 | while (!cryptorManagers_.empty() && cryptorManagers_.front().IsExpired()) { 200 | DISCORD_LOG(LS_INFO) << "Removing expired cryptor manager."; 201 | cryptorManagers_.pop_front(); 202 | } 203 | } 204 | 205 | std::unique_ptr Decryptor::GetOrCreateFrameProcessor() 206 | { 207 | std::lock_guard lock(frameProcessorsMutex_); 208 | if (frameProcessors_.empty()) { 209 | return std::make_unique(); 210 | } 211 | auto frameProcessor = std::move(frameProcessors_.back()); 212 | frameProcessors_.pop_back(); 213 | return frameProcessor; 214 | } 215 | 216 | void Decryptor::ReturnFrameProcessor(std::unique_ptr frameProcessor) 217 | { 218 | std::lock_guard lock(frameProcessorsMutex_); 219 | frameProcessors_.push_back(std::move(frameProcessor)); 220 | } 221 | 222 | } // namespace dave 223 | } // namespace discord 224 | -------------------------------------------------------------------------------- /cpp/src/dave/decryptor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "codec_utils.h" 11 | #include "common.h" 12 | #include "cryptor.h" 13 | #include "cryptor_manager.h" 14 | #include "dave/version.h" 15 | #include "frame_processors.h" 16 | #include "utils/clock.h" 17 | 18 | namespace discord { 19 | namespace dave { 20 | 21 | class IKeyRatchet; 22 | 23 | struct DecryptorStats { 24 | uint64_t passthroughCount = 0; 25 | uint64_t 
decryptSuccessCount = 0; 26 | uint64_t decryptFailureCount = 0; 27 | uint64_t decryptDuration = 0; 28 | uint64_t decryptAttempts = 0; 29 | }; 30 | 31 | class Decryptor { 32 | public: 33 | using Duration = std::chrono::seconds; 34 | 35 | void TransitionToKeyRatchet(std::unique_ptr keyRatchet, 36 | Duration transitionExpiry = kDefaultTransitionDuration); 37 | void TransitionToPassthroughMode(bool passthroughMode, 38 | Duration transitionExpiry = kDefaultTransitionDuration); 39 | 40 | size_t Decrypt(MediaType mediaType, 41 | ArrayView encryptedFrame, 42 | ArrayView frame); 43 | 44 | size_t GetMaxPlaintextByteSize(MediaType mediaType, size_t encryptedFrameSize); 45 | DecryptorStats GetStats(MediaType mediaType) const { return stats_[mediaType]; } 46 | 47 | private: 48 | using TimePoint = IClock::TimePoint; 49 | 50 | bool DecryptImpl(CryptorManager& cryptor, 51 | MediaType mediaType, 52 | InboundFrameProcessor& encryptedFrame, 53 | ArrayView frame); 54 | 55 | void UpdateCryptorManagerExpiry(Duration expiry); 56 | void CleanupExpiredCryptorManagers(); 57 | 58 | std::unique_ptr GetOrCreateFrameProcessor(); 59 | void ReturnFrameProcessor(std::unique_ptr frameProcessor); 60 | 61 | Clock clock_; 62 | std::deque cryptorManagers_; 63 | 64 | std::mutex frameProcessorsMutex_; 65 | std::vector> frameProcessors_; 66 | 67 | TimePoint allowPassThroughUntil_{TimePoint::min()}; 68 | 69 | TimePoint lastStatsTime_{TimePoint::min()}; 70 | std::array stats_; 71 | }; 72 | 73 | } // namespace dave 74 | } // namespace discord 75 | -------------------------------------------------------------------------------- /cpp/src/dave/encryptor.cpp: -------------------------------------------------------------------------------- 1 | #include "encryptor.h" 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "common.h" 9 | #include "dave/codec_utils.h" 10 | #include "dave/common.h" 11 | #include "dave/cryptor_manager.h" 12 | #include "dave/logger.h" 13 | #include "dave/utils/array_view.h" 14 
| #include "dave/utils/leb128.h" 15 | #include "dave/utils/scope_exit.h" 16 | 17 | using namespace std::chrono_literals; 18 | 19 | namespace discord { 20 | namespace dave { 21 | 22 | constexpr auto kStatsInterval = 10s; 23 | 24 | void Encryptor::SetKeyRatchet(std::unique_ptr keyRatchet) 25 | { 26 | std::lock_guard lock(keyGenMutex_); 27 | keyRatchet_ = std::move(keyRatchet); 28 | cryptor_ = nullptr; 29 | currentKeyGeneration_ = 0; 30 | truncatedNonce_ = 0; 31 | } 32 | 33 | void Encryptor::SetPassthroughMode(bool passthroughMode) 34 | { 35 | passthroughMode_ = passthroughMode; 36 | UpdateCurrentProtocolVersion(passthroughMode ? 0 : MaxSupportedProtocolVersion()); 37 | } 38 | 39 | int Encryptor::Encrypt(MediaType mediaType, 40 | uint32_t ssrc, 41 | ArrayView frame, 42 | ArrayView encryptedFrame, 43 | size_t* bytesWritten) 44 | { 45 | if (mediaType != Audio && mediaType != Video) { 46 | DISCORD_LOG(LS_WARNING) << "Encrypt failed, invalid media type: " 47 | << static_cast(mediaType); 48 | return 0; 49 | } 50 | 51 | if (passthroughMode_) { 52 | // Pass frame through without encrypting 53 | memcpy(encryptedFrame.data(), frame.data(), frame.size()); 54 | *bytesWritten = frame.size(); 55 | stats_[mediaType].passthroughCount++; 56 | return ResultCode::Success; 57 | } 58 | 59 | { 60 | std::lock_guard lock(keyGenMutex_); 61 | if (!keyRatchet_) { 62 | stats_[mediaType].encryptFailureCount++; 63 | return ResultCode::EncryptionFailure; 64 | } 65 | } 66 | 67 | auto start = std::chrono::steady_clock::now(); 68 | auto result = ResultCode::Success; 69 | 70 | // write the codec identifier 71 | auto codec = CodecForSsrc(ssrc); 72 | 73 | auto frameProcessor = GetOrCreateFrameProcessor(); 74 | ScopeExit cleanup([&] { ReturnFrameProcessor(std::move(frameProcessor)); }); 75 | 76 | frameProcessor->ProcessFrame(frame, codec); 77 | 78 | const auto& unencryptedBytes = frameProcessor->GetUnencryptedBytes(); 79 | const auto& encryptedBytes = frameProcessor->GetEncryptedBytes(); 80 | auto& 
ciphertextBytes = frameProcessor->GetCiphertextBytes(); 81 | 82 | const auto& unencryptedRanges = frameProcessor->GetUnencryptedRanges(); 83 | auto unencryptedRangesSize = UnencryptedRangesSize(unencryptedRanges); 84 | 85 | auto additionalData = MakeArrayView(unencryptedBytes.data(), unencryptedBytes.size()); 86 | auto plaintextBuffer = MakeArrayView(encryptedBytes.data(), encryptedBytes.size()); 87 | auto ciphertextBuffer = MakeArrayView(ciphertextBytes.data(), ciphertextBytes.size()); 88 | 89 | auto frameSize = encryptedBytes.size() + unencryptedBytes.size(); 90 | auto tagBuffer = MakeArrayView(encryptedFrame.data() + frameSize, kAesGcm128TruncatedTagBytes); 91 | 92 | auto nonceBuffer = std::array(); 93 | auto nonceBufferView = MakeArrayView(nonceBuffer.data(), nonceBuffer.size()); 94 | 95 | constexpr auto MAX_CIPHERTEXT_VALIDATION_RETRIES = 10; 96 | 97 | // some codecs (e.g. H26X) have packetizers that cannot handle specific byte sequences 98 | // so we attempt up to MAX_CIPHERTEXT_VALIDATION_RETRIES to encrypt the frame 99 | // calling into codec utils to validate the ciphertext + supplemental section 100 | // and re-rolling the truncated nonce if it fails 101 | 102 | // the nonce increment will definitely change the ciphertext and the tag 103 | // incrementing the nonce will also change the appropriate bytes 104 | // in the tail end of the nonce 105 | // which can remove start codes from the last 1 or 2 bytes of the nonce 106 | // and the two bytes of the unencrypted header bytes 107 | for (auto attempt = 1; attempt <= MAX_CIPHERTEXT_VALIDATION_RETRIES; ++attempt) { 108 | auto [cryptor, truncatedNonce] = GetNextCryptorAndNonce(); 109 | 110 | if (!cryptor) { 111 | result = ResultCode::EncryptionFailure; 112 | break; 113 | } 114 | 115 | // write the truncated nonce to our temporary full nonce array 116 | // (since the encryption call expects a full size nonce) 117 | memcpy(nonceBuffer.data() + kAesGcm128TruncatedSyncNonceOffset, 118 | &truncatedNonce, 119 | 
kAesGcm128TruncatedSyncNonceBytes); 120 | 121 | // encrypt the plaintext, adding the unencrypted header to the tag 122 | bool success = cryptor->Encrypt( 123 | ciphertextBuffer, plaintextBuffer, nonceBufferView, additionalData, tagBuffer); 124 | 125 | stats_[mediaType].encryptAttempts++; 126 | stats_[mediaType].encryptMaxAttempts = 127 | std::max(stats_[mediaType].encryptMaxAttempts, (uint64_t)attempt); 128 | 129 | if (!success) { 130 | assert(false && "Failed to encrypt frame"); 131 | result = ResultCode::EncryptionFailure; 132 | break; 133 | } 134 | 135 | auto reconstructedFrameSize = frameProcessor->ReconstructFrame(encryptedFrame); 136 | assert(reconstructedFrameSize == frameSize && "Failed to reconstruct frame"); 137 | 138 | auto nonceSize = Leb128Size(truncatedNonce); 139 | 140 | auto truncatedNonceBuffer = MakeArrayView(tagBuffer.end(), nonceSize); 141 | auto unencryptedRangesBuffer = 142 | MakeArrayView(truncatedNonceBuffer.end(), unencryptedRangesSize); 143 | auto supplementalBytesBuffer = 144 | MakeArrayView(unencryptedRangesBuffer.end(), sizeof(SupplementalBytesSize)); 145 | auto markerBytesBuffer = MakeArrayView(supplementalBytesBuffer.end(), sizeof(MagicMarker)); 146 | 147 | // write the nonce 148 | auto res = WriteLeb128(truncatedNonce, truncatedNonceBuffer.begin()); 149 | if (res != nonceSize) { 150 | assert(false && "Failed to write truncated nonce"); 151 | result = ResultCode::EncryptionFailure; 152 | break; 153 | } 154 | 155 | // write the unencrypted ranges 156 | res = SerializeUnencryptedRanges( 157 | unencryptedRanges, unencryptedRangesBuffer.begin(), unencryptedRangesBuffer.size()); 158 | if (res != unencryptedRangesSize) { 159 | assert(false && "Failed to write unencrypted ranges"); 160 | result = ResultCode::EncryptionFailure; 161 | break; 162 | } 163 | 164 | // write the supplemental bytes size 165 | uint64_t supplementalBytesLarge = kSupplementalBytes + nonceSize + unencryptedRangesSize; 166 | 167 | if (supplementalBytesLarge > 
std::numeric_limits::max()) { 168 | assert(false && "Supplemental bytes size too large"); 169 | result = ResultCode::EncryptionFailure; 170 | break; 171 | } 172 | 173 | SupplementalBytesSize supplementalBytes = 174 | static_cast(supplementalBytesLarge); 175 | memcpy(supplementalBytesBuffer.data(), &supplementalBytes, sizeof(SupplementalBytesSize)); 176 | 177 | // write the marker bytes, ends the frame 178 | memcpy(markerBytesBuffer.data(), &kMarkerBytes, sizeof(MagicMarker)); 179 | 180 | auto encryptedFrameBytes = reconstructedFrameSize + kAesGcm128TruncatedTagBytes + 181 | nonceSize + unencryptedRangesSize + sizeof(SupplementalBytesSize) + sizeof(MagicMarker); 182 | 183 | if (codec_utils::ValidateEncryptedFrame( 184 | *frameProcessor, MakeArrayView(encryptedFrame.data(), encryptedFrameBytes))) { 185 | *bytesWritten = encryptedFrameBytes; 186 | break; 187 | } 188 | else if (attempt >= MAX_CIPHERTEXT_VALIDATION_RETRIES) { 189 | assert(false && "Failed to validate encrypted section for codec"); 190 | result = ResultCode::EncryptionFailure; 191 | break; 192 | } 193 | } 194 | 195 | auto now = std::chrono::steady_clock::now(); 196 | stats_[mediaType].encryptDuration += 197 | std::chrono::duration_cast(now - start).count(); 198 | if (result == ResultCode::Success) { 199 | stats_[mediaType].encryptSuccessCount++; 200 | } 201 | else { 202 | stats_[mediaType].encryptFailureCount++; 203 | } 204 | 205 | if (now > lastStatsTime_ + kStatsInterval) { 206 | lastStatsTime_ = now; 207 | DISCORD_LOG(LS_INFO) << "Encrypted audio: " << stats_[Audio].encryptSuccessCount 208 | << ", video: " << stats_[Video].encryptSuccessCount 209 | << ". Failed audio: " << stats_[Audio].encryptFailureCount 210 | << ", video: " << stats_[Video].encryptFailureCount; 211 | DISCORD_LOG(LS_INFO) << "Last encrypted frame, type: " 212 | << (mediaType == Audio ? 
"audio" : "video") << ", ssrc: " << ssrc 213 | << ", size: " << frame.size(); 214 | } 215 | 216 | return result; 217 | } 218 | 219 | size_t Encryptor::GetMaxCiphertextByteSize([[maybe_unused]] MediaType mediaType, size_t frameSize) 220 | { 221 | return frameSize + kSupplementalBytes + kTransformPaddingBytes; 222 | } 223 | 224 | void Encryptor::AssignSsrcToCodec(uint32_t ssrc, Codec codecType) 225 | { 226 | auto existingCodecIt = std::find_if( 227 | ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { 228 | return pair.first == ssrc; 229 | }); 230 | 231 | if (existingCodecIt == ssrcCodecPairs_.end()) { 232 | ssrcCodecPairs_.emplace_back(ssrc, codecType); 233 | } 234 | else { 235 | existingCodecIt->second = codecType; 236 | } 237 | } 238 | 239 | Codec Encryptor::CodecForSsrc(uint32_t ssrc) 240 | { 241 | auto existingCodecIt = std::find_if( 242 | ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { 243 | return pair.first == ssrc; 244 | }); 245 | 246 | if (existingCodecIt != ssrcCodecPairs_.end()) { 247 | return existingCodecIt->second; 248 | } 249 | else { 250 | return Codec::Unknown; 251 | } 252 | } 253 | 254 | std::unique_ptr Encryptor::GetOrCreateFrameProcessor() 255 | { 256 | std::lock_guard lock(frameProcessorsMutex_); 257 | if (frameProcessors_.empty()) { 258 | return std::make_unique(); 259 | } 260 | auto frameProcessor = std::move(frameProcessors_.back()); 261 | frameProcessors_.pop_back(); 262 | return frameProcessor; 263 | } 264 | 265 | void Encryptor::ReturnFrameProcessor(std::unique_ptr frameProcessor) 266 | { 267 | std::lock_guard lock(frameProcessorsMutex_); 268 | frameProcessors_.push_back(std::move(frameProcessor)); 269 | } 270 | 271 | Encryptor::CryptorAndNonce Encryptor::GetNextCryptorAndNonce() 272 | { 273 | std::lock_guard lock(keyGenMutex_); 274 | if (!keyRatchet_) { 275 | return {nullptr, 0}; 276 | } 277 | 278 | auto generation = ComputeWrappedGeneration(currentKeyGeneration_, 279 | 
++truncatedNonce_ >> kRatchetGenerationShiftBits); 280 | 281 | if (generation != currentKeyGeneration_ || !cryptor_) { 282 | currentKeyGeneration_ = generation; 283 | 284 | auto encryptionKey = keyRatchet_->GetKey(currentKeyGeneration_); 285 | cryptor_ = CreateCryptor(encryptionKey); 286 | } 287 | 288 | return {cryptor_, truncatedNonce_}; 289 | } 290 | 291 | void Encryptor::UpdateCurrentProtocolVersion(ProtocolVersion version) 292 | { 293 | if (version == currentProtocolVersion_) { 294 | return; 295 | } 296 | 297 | currentProtocolVersion_ = version; 298 | if (protocolVersionChangedCallback_) { 299 | protocolVersionChangedCallback_(); 300 | } 301 | } 302 | 303 | } // namespace dave 304 | } // namespace discord 305 | -------------------------------------------------------------------------------- /cpp/src/dave/encryptor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "dave/codec_utils.h" 12 | #include "dave/common.h" 13 | #include "dave/cryptor.h" 14 | #include "dave/frame_processors.h" 15 | #include "dave/key_ratchet.h" 16 | #include "dave/version.h" 17 | 18 | namespace discord { 19 | namespace dave { 20 | 21 | struct EncryptorStats { 22 | uint64_t passthroughCount = 0; 23 | uint64_t encryptSuccessCount = 0; 24 | uint64_t encryptFailureCount = 0; 25 | uint64_t encryptDuration = 0; 26 | uint64_t encryptAttempts = 0; 27 | uint64_t encryptMaxAttempts = 0; 28 | }; 29 | 30 | class Encryptor { 31 | public: 32 | void SetKeyRatchet(std::unique_ptr keyRatchet); 33 | void SetPassthroughMode(bool passthroughMode); 34 | 35 | bool HasKeyRatchet() const { return keyRatchet_ != nullptr; } 36 | bool IsPassthroughMode() const { return passthroughMode_; } 37 | 38 | void AssignSsrcToCodec(uint32_t ssrc, Codec codecType); 39 | Codec CodecForSsrc(uint32_t ssrc); 40 | 41 | int Encrypt(MediaType mediaType, 42 | uint32_t 
ssrc, 43 | ArrayView frame, 44 | ArrayView encryptedFrame, 45 | size_t* bytesWritten); 46 | 47 | size_t GetMaxCiphertextByteSize(MediaType mediaType, size_t frameSize); 48 | EncryptorStats GetStats(MediaType mediaType) const { return stats_[mediaType]; } 49 | 50 | using ProtocolVersionChangedCallback = std::function; 51 | void SetProtocolVersionChangedCallback(ProtocolVersionChangedCallback callback) 52 | { 53 | protocolVersionChangedCallback_ = std::move(callback); 54 | } 55 | ProtocolVersion GetProtocolVersion() const { return currentProtocolVersion_; } 56 | 57 | private: 58 | std::unique_ptr GetOrCreateFrameProcessor(); 59 | void ReturnFrameProcessor(std::unique_ptr frameProcessor); 60 | 61 | using CryptorAndNonce = std::pair, TruncatedSyncNonce>; 62 | CryptorAndNonce GetNextCryptorAndNonce(); 63 | 64 | void UpdateCurrentProtocolVersion(ProtocolVersion version); 65 | 66 | enum ResultCode { 67 | Success, 68 | UninitializedContext, 69 | InitializationFailure, 70 | UnsupportedCodec, 71 | EncryptionFailure, 72 | FinalizationFailure, 73 | TagAppendFailure 74 | }; 75 | 76 | std::atomic_bool passthroughMode_{false}; 77 | 78 | std::mutex keyGenMutex_; 79 | std::unique_ptr keyRatchet_; 80 | std::shared_ptr cryptor_; 81 | KeyGeneration currentKeyGeneration_{0}; 82 | TruncatedSyncNonce truncatedNonce_{0}; 83 | 84 | std::mutex frameProcessorsMutex_; 85 | std::vector> frameProcessors_; 86 | 87 | using SsrcCodecPair = std::pair; 88 | std::vector ssrcCodecPairs_; 89 | 90 | using TimePoint = std::chrono::time_point; 91 | TimePoint lastStatsTime_{TimePoint::min()}; 92 | std::array stats_; 93 | 94 | ProtocolVersionChangedCallback protocolVersionChangedCallback_; 95 | ProtocolVersion currentProtocolVersion_{MaxSupportedProtocolVersion()}; 96 | }; 97 | 98 | } // namespace dave 99 | } // namespace discord 100 | -------------------------------------------------------------------------------- /cpp/src/dave/frame_processors.cpp: 
-------------------------------------------------------------------------------- 1 | #include "frame_processors.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "codec_utils.h" 9 | #include "logger.h" 10 | #include "utils/array_view.h" 11 | #include "utils/leb128.h" 12 | 13 | #if defined(_MSC_VER) 14 | #include 15 | #endif 16 | 17 | namespace discord { 18 | namespace dave { 19 | 20 | std::pair OverflowAdd(size_t a, size_t b) 21 | { 22 | size_t res; 23 | #if defined(_MSC_VER) && defined(_M_X64) 24 | bool didOverflow = _addcarry_u64(0, a, b, &res); 25 | #elif defined(_MSC_VER) && defined(_M_IX86) 26 | bool didOverflow = _addcarry_u32(0, a, b, &res); 27 | #else 28 | bool didOverflow = __builtin_add_overflow(a, b, &res); 29 | #endif 30 | return {didOverflow, res}; 31 | } 32 | 33 | uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges) 34 | { 35 | size_t size = 0; 36 | for (const auto& range : unencryptedRanges) { 37 | size += Leb128Size(range.offset); 38 | size += Leb128Size(range.size); 39 | } 40 | assert(size <= std::numeric_limits::max() && 41 | "Unencrypted ranges size exceeds 255 bytes"); 42 | return static_cast(size); 43 | } 44 | 45 | uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, 46 | uint8_t* buffer, 47 | size_t bufferSize) 48 | { 49 | auto writeAt = buffer; 50 | auto end = buffer + bufferSize; 51 | for (const auto& range : unencryptedRanges) { 52 | auto rangeSize = Leb128Size(range.offset) + Leb128Size(range.size); 53 | if (rangeSize > static_cast(end - writeAt)) { 54 | assert(false && "Buffer is too small to serialize unencrypted ranges"); 55 | break; 56 | } 57 | 58 | writeAt += WriteLeb128(range.offset, writeAt); 59 | writeAt += WriteLeb128(range.size, writeAt); 60 | } 61 | 62 | assert(writeAt <= buffer); 63 | return static_cast(writeAt - buffer); 64 | } 65 | 66 | uint8_t DeserializeUnencryptedRanges(const uint8_t*& readAt, 67 | const uint8_t bufferSize, 68 | Ranges& unencryptedRanges) 69 | { 70 | auto 
start = readAt; 71 | auto end = readAt + bufferSize; 72 | while (readAt < end) { 73 | size_t offset = ReadLeb128(readAt, end); 74 | if (readAt == nullptr) { 75 | break; 76 | } 77 | 78 | size_t size = ReadLeb128(readAt, end); 79 | if (readAt == nullptr) { 80 | break; 81 | } 82 | unencryptedRanges.push_back({offset, size}); 83 | } 84 | 85 | if (readAt != end) { 86 | DISCORD_LOG(LS_WARNING) << "Failed to deserialize unencrypted ranges"; 87 | unencryptedRanges.clear(); 88 | readAt = nullptr; 89 | return 0; 90 | } 91 | 92 | return static_cast(readAt - start); 93 | } 94 | 95 | bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize) 96 | { 97 | if (unencryptedRanges.empty()) { 98 | return true; 99 | } 100 | 101 | // validate that the ranges are in order and don't overlap 102 | for (auto i = 0u; i < unencryptedRanges.size(); ++i) { 103 | auto current = unencryptedRanges[i]; 104 | // The current range should not overflow into the next range 105 | // or if it is the last range, the end of the frame 106 | auto maxEnd = 107 | i + 1 < unencryptedRanges.size() ? 
unencryptedRanges[i + 1].offset : frameSize; 108 | 109 | auto [didOverflow, currentEnd] = OverflowAdd(current.offset, current.size); 110 | if (didOverflow || currentEnd > maxEnd) { 111 | DISCORD_LOG(LS_WARNING) 112 | << "Unencrypted range may overlap or be out of order: current offset: " 113 | << current.offset << ", current size: " << current.size << ", maximum end: " << maxEnd 114 | << ", frame size: " << frameSize; 115 | return false; 116 | } 117 | } 118 | 119 | return true; 120 | } 121 | 122 | size_t Reconstruct(Ranges ranges, 123 | const std::vector& rangeBytes, 124 | const std::vector& otherBytes, 125 | const ArrayView& output) 126 | { 127 | size_t frameIndex = 0; 128 | size_t rangeBytesIndex = 0; 129 | size_t otherBytesIndex = 0; 130 | 131 | const auto CopyRangeBytes = [&](size_t size) { 132 | assert(rangeBytesIndex + size <= rangeBytes.size()); 133 | assert(frameIndex + size <= output.size()); 134 | memcpy(output.data() + frameIndex, rangeBytes.data() + rangeBytesIndex, size); 135 | rangeBytesIndex += size; 136 | frameIndex += size; 137 | }; 138 | 139 | const auto CopyOtherBytes = [&](size_t size) { 140 | assert(otherBytesIndex + size <= otherBytes.size()); 141 | assert(frameIndex + size <= output.size()); 142 | memcpy(output.data() + frameIndex, otherBytes.data() + otherBytesIndex, size); 143 | otherBytesIndex += size; 144 | frameIndex += size; 145 | }; 146 | 147 | for (const auto& range : ranges) { 148 | if (range.offset > frameIndex) { 149 | CopyOtherBytes(range.offset - frameIndex); 150 | } 151 | 152 | CopyRangeBytes(range.size); 153 | } 154 | 155 | if (otherBytesIndex < otherBytes.size()) { 156 | CopyOtherBytes(otherBytes.size() - otherBytesIndex); 157 | } 158 | 159 | assert(rangeBytesIndex == rangeBytes.size()); 160 | assert(otherBytesIndex == otherBytes.size()); 161 | assert(frameIndex <= output.size()); 162 | 163 | return frameIndex; 164 | } 165 | 166 | void InboundFrameProcessor::Clear() 167 | { 168 | isEncrypted_ = false; 169 | originalSize_ = 0; 
170 | truncatedNonce_ = std::numeric_limits::max(); 171 | unencryptedRanges_.clear(); 172 | authenticated_.clear(); 173 | ciphertext_.clear(); 174 | plaintext_.clear(); 175 | } 176 | 177 | void InboundFrameProcessor::ParseFrame(ArrayView frame) 178 | { 179 | Clear(); 180 | 181 | constexpr auto MinSupplementalBytesSize = 182 | kAesGcm128TruncatedTagBytes + sizeof(SupplementalBytesSize) + sizeof(MagicMarker); 183 | if (frame.size() < MinSupplementalBytesSize) { 184 | DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain min supplemental bytes"; 185 | return; 186 | } 187 | 188 | // Check the frame ends with the magic marker 189 | auto magicMarkerBuffer = frame.end() - sizeof(MagicMarker); 190 | if (memcmp(magicMarkerBuffer, &kMarkerBytes, sizeof(MagicMarker)) != 0) { 191 | return; 192 | } 193 | 194 | // Read the supplemental bytes size 195 | SupplementalBytesSize supplementalBytesSize; 196 | auto supplementalBytesSizeBuffer = magicMarkerBuffer - sizeof(SupplementalBytesSize); 197 | assert(frame.begin() <= supplementalBytesSizeBuffer && 198 | supplementalBytesSizeBuffer <= frame.end()); 199 | memcpy(&supplementalBytesSize, supplementalBytesSizeBuffer, sizeof(SupplementalBytesSize)); 200 | 201 | // Check the frame is large enough to contain the supplemental bytes 202 | if (frame.size() < supplementalBytesSize) { 203 | DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain supplemental bytes"; 204 | return; 205 | } 206 | 207 | // Check that supplemental bytes size is large enough to contain the supplemental bytes 208 | if (supplementalBytesSize < MinSupplementalBytesSize) { 209 | DISCORD_LOG(LS_WARNING) 210 | << "Supplemental bytes size is too small to contain supplemental bytes"; 211 | return; 212 | } 213 | 214 | auto supplementalBytesBuffer = frame.end() - supplementalBytesSize; 215 | assert(frame.begin() <= supplementalBytesBuffer && supplementalBytesBuffer <= frame.end()); 216 | 217 | // Read the tag 218 | tag_ = 
MakeArrayView(supplementalBytesBuffer, kAesGcm128TruncatedTagBytes); 219 | 220 | // Read the nonce 221 | auto nonceBuffer = supplementalBytesBuffer + kAesGcm128TruncatedTagBytes; 222 | assert(frame.begin() <= nonceBuffer && nonceBuffer <= frame.end()); 223 | auto readAt = nonceBuffer; 224 | auto end = supplementalBytesSizeBuffer; 225 | truncatedNonce_ = static_cast(ReadLeb128(readAt, end)); 226 | if (readAt == nullptr) { 227 | DISCORD_LOG(LS_WARNING) << "Failed to read truncated nonce"; 228 | return; 229 | } 230 | 231 | // Read the unencrypted ranges 232 | assert(nonceBuffer <= readAt && readAt <= end && 233 | end - readAt <= std::numeric_limits::max()); 234 | auto unencryptedRangesSize = static_cast(end - readAt); 235 | 236 | DeserializeUnencryptedRanges(readAt, unencryptedRangesSize, unencryptedRanges_); 237 | if (readAt == nullptr) { 238 | DISCORD_LOG(LS_WARNING) << "Failed to read unencrypted ranges"; 239 | return; 240 | } 241 | 242 | if (!ValidateUnencryptedRanges(unencryptedRanges_, frame.size())) { 243 | DISCORD_LOG(LS_WARNING) << "Invalid unencrypted ranges"; 244 | return; 245 | } 246 | 247 | // This is overly aggressive but will keep reallocations to a minimum 248 | authenticated_.reserve(frame.size()); 249 | ciphertext_.reserve(frame.size()); 250 | plaintext_.reserve(frame.size()); 251 | 252 | originalSize_ = frame.size(); 253 | 254 | // Split the frame into authenticated and ciphertext bytes 255 | size_t frameIndex = 0; 256 | for (const auto& range : unencryptedRanges_) { 257 | auto encryptedBytes = range.offset - frameIndex; 258 | if (encryptedBytes > 0) { 259 | assert(frameIndex + encryptedBytes <= frame.size()); 260 | AddCiphertextBytes(frame.data() + frameIndex, encryptedBytes); 261 | } 262 | 263 | assert(range.offset + range.size <= frame.size()); 264 | AddAuthenticatedBytes(frame.data() + range.offset, range.size); 265 | frameIndex = range.offset + range.size; 266 | } 267 | auto actualFrameSize = frame.size() - supplementalBytesSize; 268 | if 
(frameIndex < actualFrameSize) { 269 | AddCiphertextBytes(frame.data() + frameIndex, actualFrameSize - frameIndex); 270 | } 271 | 272 | // Make sure the plaintext buffer is the same size as the ciphertext buffer 273 | plaintext_.resize(ciphertext_.size()); 274 | 275 | // We've successfully parsed the frame 276 | // Mark the frame as encrypted 277 | isEncrypted_ = true; 278 | } 279 | 280 | size_t InboundFrameProcessor::ReconstructFrame(ArrayView frame) const 281 | { 282 | if (!isEncrypted_) { 283 | DISCORD_LOG(LS_WARNING) << "Cannot reconstruct an invalid encrypted frame"; 284 | return 0; 285 | } 286 | 287 | if (authenticated_.size() + plaintext_.size() > frame.size()) { 288 | DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the decrypted frame"; 289 | return 0; 290 | } 291 | 292 | return Reconstruct(unencryptedRanges_, authenticated_, plaintext_, frame); 293 | } 294 | 295 | void InboundFrameProcessor::AddAuthenticatedBytes(const uint8_t* data, size_t size) 296 | { 297 | authenticated_.resize(authenticated_.size() + size); 298 | memcpy(authenticated_.data() + authenticated_.size() - size, data, size); 299 | } 300 | 301 | void InboundFrameProcessor::AddCiphertextBytes(const uint8_t* data, size_t size) 302 | { 303 | ciphertext_.resize(ciphertext_.size() + size); 304 | memcpy(ciphertext_.data() + ciphertext_.size() - size, data, size); 305 | } 306 | 307 | void OutboundFrameProcessor::Reset() 308 | { 309 | codec_ = Codec::Unknown; 310 | frameIndex_ = 0; 311 | unencryptedBytes_.clear(); 312 | encryptedBytes_.clear(); 313 | unencryptedRanges_.clear(); 314 | } 315 | 316 | void OutboundFrameProcessor::ProcessFrame(ArrayView frame, Codec codec) 317 | { 318 | Reset(); 319 | 320 | codec_ = codec; 321 | unencryptedBytes_.reserve(frame.size()); 322 | encryptedBytes_.reserve(frame.size()); 323 | 324 | bool success = false; 325 | switch (codec) { 326 | case Codec::Opus: 327 | success = codec_utils::ProcessFrameOpus(*this, frame); 328 | break; 329 | case Codec::VP8: 330 | 
success = codec_utils::ProcessFrameVp8(*this, frame); 331 | break; 332 | case Codec::VP9: 333 | success = codec_utils::ProcessFrameVp9(*this, frame); 334 | break; 335 | case Codec::H264: 336 | success = codec_utils::ProcessFrameH264(*this, frame); 337 | break; 338 | case Codec::H265: 339 | success = codec_utils::ProcessFrameH265(*this, frame); 340 | break; 341 | case Codec::AV1: 342 | success = codec_utils::ProcessFrameAv1(*this, frame); 343 | break; 344 | default: 345 | assert(false && "Unsupported codec for frame encryption"); 346 | break; 347 | } 348 | 349 | if (!success) { 350 | frameIndex_ = 0; 351 | unencryptedBytes_.clear(); 352 | encryptedBytes_.clear(); 353 | unencryptedRanges_.clear(); 354 | AddEncryptedBytes(frame.data(), frame.size()); 355 | } 356 | 357 | ciphertextBytes_.resize(encryptedBytes_.size()); 358 | } 359 | 360 | size_t OutboundFrameProcessor::ReconstructFrame(ArrayView frame) 361 | { 362 | if (unencryptedBytes_.size() + ciphertextBytes_.size() > frame.size()) { 363 | DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the encrypted frame"; 364 | return 0; 365 | } 366 | 367 | return Reconstruct(unencryptedRanges_, unencryptedBytes_, ciphertextBytes_, frame); 368 | } 369 | 370 | void OutboundFrameProcessor::AddUnencryptedBytes(const uint8_t* bytes, size_t size) 371 | { 372 | if (!unencryptedRanges_.empty() && 373 | unencryptedRanges_.back().offset + unencryptedRanges_.back().size == frameIndex_) { 374 | // extend the last range 375 | unencryptedRanges_.back().size += size; 376 | } 377 | else { 378 | // add a new range (offset, size) 379 | unencryptedRanges_.push_back({frameIndex_, size}); 380 | } 381 | 382 | unencryptedBytes_.resize(unencryptedBytes_.size() + size); 383 | memcpy(unencryptedBytes_.data() + unencryptedBytes_.size() - size, bytes, size); 384 | frameIndex_ += size; 385 | } 386 | 387 | void OutboundFrameProcessor::AddEncryptedBytes(const uint8_t* bytes, size_t size) 388 | { 389 | encryptedBytes_.resize(encryptedBytes_.size() + 
size); 390 | memcpy(encryptedBytes_.data() + encryptedBytes_.size() - size, bytes, size); 391 | frameIndex_ += size; 392 | } 393 | 394 | } // namespace dave 395 | } // namespace discord 396 | -------------------------------------------------------------------------------- /cpp/src/dave/frame_processors.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "common.h" 9 | #include "utils/array_view.h" 10 | 11 | namespace discord { 12 | namespace dave { 13 | 14 | struct Range { 15 | size_t offset; 16 | size_t size; 17 | }; 18 | using Ranges = std::vector; 19 | 20 | uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges); 21 | uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, 22 | uint8_t* buffer, 23 | size_t bufferSize); 24 | uint8_t DeserializeUnencryptedRanges(const uint8_t*& buffer, 25 | const uint8_t bufferSize, 26 | Ranges& unencryptedRanges); 27 | bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize); 28 | 29 | class InboundFrameProcessor { 30 | public: 31 | void ParseFrame(ArrayView frame); 32 | size_t ReconstructFrame(ArrayView frame) const; 33 | 34 | bool IsEncrypted() const { return isEncrypted_; } 35 | size_t Size() const { return originalSize_; } 36 | void Clear(); 37 | 38 | ArrayView GetTag() const { return tag_; } 39 | TruncatedSyncNonce GetTruncatedNonce() const { return truncatedNonce_; } 40 | ArrayView GetAuthenticatedData() const 41 | { 42 | return MakeArrayView(authenticated_.data(), authenticated_.size()); 43 | } 44 | ArrayView GetCiphertext() const 45 | { 46 | return MakeArrayView(ciphertext_.data(), ciphertext_.size()); 47 | } 48 | ArrayView GetPlaintext() { return MakeArrayView(plaintext_); } 49 | 50 | private: 51 | void AddAuthenticatedBytes(const uint8_t* data, size_t size); 52 | void AddCiphertextBytes(const uint8_t* data, size_t size); 53 | 54 | bool isEncrypted_{false}; 
55 | size_t originalSize_{0}; 56 | ArrayView tag_; 57 | TruncatedSyncNonce truncatedNonce_; 58 | Ranges unencryptedRanges_; 59 | std::vector authenticated_; 60 | std::vector ciphertext_; 61 | std::vector plaintext_; 62 | }; 63 | 64 | class OutboundFrameProcessor { 65 | public: 66 | void ProcessFrame(ArrayView frame, Codec codec); 67 | size_t ReconstructFrame(ArrayView frame); 68 | 69 | Codec GetCodec() const { return codec_; } 70 | const std::vector& GetUnencryptedBytes() const { return unencryptedBytes_; } 71 | const std::vector& GetEncryptedBytes() const { return encryptedBytes_; } 72 | std::vector& GetCiphertextBytes() { return ciphertextBytes_; } 73 | const Ranges& GetUnencryptedRanges() const { return unencryptedRanges_; } 74 | 75 | void Reset(); 76 | void AddUnencryptedBytes(const uint8_t* bytes, size_t size); 77 | void AddEncryptedBytes(const uint8_t* bytes, size_t size); 78 | 79 | private: 80 | Codec codec_{Codec::Unknown}; 81 | size_t frameIndex_{0}; 82 | std::vector unencryptedBytes_; 83 | std::vector encryptedBytes_; 84 | std::vector ciphertextBytes_; 85 | Ranges unencryptedRanges_; 86 | }; 87 | 88 | } // namespace dave 89 | } // namespace discord 90 | -------------------------------------------------------------------------------- /cpp/src/dave/key_ratchet.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "common.h" 6 | 7 | namespace discord { 8 | namespace dave { 9 | 10 | using KeyGeneration = uint32_t; 11 | 12 | class IKeyRatchet { 13 | public: 14 | virtual ~IKeyRatchet() noexcept = default; 15 | virtual EncryptionKey GetKey(KeyGeneration generation) noexcept = 0; 16 | virtual void DeleteKey(KeyGeneration generation) noexcept = 0; 17 | }; 18 | 19 | } // namespace dave 20 | } // namespace discord 21 | -------------------------------------------------------------------------------- /cpp/src/dave/logger.cpp: 
--------------------------------------------------------------------------------
#include "logger.h"

#include 
#include 
#include 

namespace discord {
namespace dave {

// Optional process-wide log sink. When null, LogStreamer prints to std::cout.
std::atomic gLogSink = nullptr;

// Installs `sink`; log lines emitted after this call are forwarded to it.
void SetLogSink(LogSink sink)
{
    gLogSink = sink;
}

LogStreamer::LogStreamer(LoggingSeverity severity, const char* file, int line)
  : severity_(severity)
  , file_(file)
  , line_(line)
{
}

// Emits the accumulated message, if it is non-empty. If a sink is installed, the
// message goes only to the sink. Otherwise every severity except LS_NONE is printed to
// std::cout as "(basename:line) message".
LogStreamer::~LogStreamer()
{
    std::string logLine = stream_.str();
    if (logLine.empty()) {
        return;
    }

    auto sink = gLogSink.load();
    if (sink) {
        sink(severity_, file_, line_, logLine);
        return;
    }

    switch (severity_) {
    case LS_VERBOSE:
    case LS_INFO:
    case LS_WARNING:
    case LS_ERROR: {
        const char* file = file_;
        // Strip directories for readability. Only '/' is handled, so backslash-separated
        // paths are printed in full.
        if (auto separator = strrchr(file, '/')) {
            file = separator + 1;
        }
        std::cout << "(" << file << ":" << line_ << ") " << logLine << std::endl;
        break;
    }
    case LS_NONE:
        break;
    }
}

} // namespace dave
} // namespace discord

--------------------------------------------------------------------------------
/cpp/src/dave/logger.h:
--------------------------------------------------------------------------------
#pragma once

#include 

// DISCORD_LOG(sev) builds a temporary LogStreamer tagged with the call site's
// __FILE__/__LINE__; the message is emitted when that temporary is destroyed.
// These definitions are skipped if DISCORD_LOG is already defined (presumably so an
// embedder can route logging elsewhere).
#if !defined(DISCORD_LOG)
#define DISCORD_LOG_FILE_LINE(sev, file, line) ::discord::dave::LogStreamer(sev, file, line)
#define DISCORD_LOG(sev) DISCORD_LOG_FILE_LINE(::discord::dave::sev, __FILE__, __LINE__)
#endif
namespace discord {
namespace dave {

// Log severities. LS_NONE is never printed by the default (no-sink) path.
enum LoggingSeverity {
    LS_VERBOSE,
    LS_INFO,
    LS_WARNING,
    LS_ERROR,
    LS_NONE,
};

// Callback that receives each completed log line in place of the std::cout output.
using LogSink = void (*)(LoggingSeverity severity,
                         const char* file,
                         int line,
                         const std::string& message);
void SetLogSink(LogSink sink);

// Accumulates one log line via operator<< and emits it from its destructor.
class LogStreamer {
public:
    LogStreamer(LoggingSeverity severity, const char* file, int line);
    ~LogStreamer();

    template
    LogStreamer& operator<<(const T& value)
    {
        stream_ << value;
        return *this;
    }

private:
    LoggingSeverity severity_;
    const char* file_;
    int line_;
    std::ostringstream stream_;
};

} // namespace dave
} // namespace discord

--------------------------------------------------------------------------------
/cpp/src/dave/mls/detail/persisted_key_pair.h:
--------------------------------------------------------------------------------
#pragma once

#include 
#include 

#include 

#include "dave/mls/persisted_key_pair.h"

namespace discord {
namespace dave {
namespace mls {
namespace detail {

// Platform-specific lookup/creation of the persisted signature key `keyID` for `suite`.
// `supported` is an out-parameter; presumably it reports whether the native backend
// handles `suite` (the Apple implementation sets it true for the P-256/384/521 suites).
std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair(KeyPairContextType ctx,
                                                                        const std::string& keyID,
                                                                        ::mlspp::CipherSuite suite,
                                                                        bool& supported);
// Non-native counterpart of GetNativePersistedKeyPair (presumably a portable fallback;
// confirm against its implementation).
std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(
  KeyPairContextType ctx,
  const std::string& keyID,
  ::mlspp::CipherSuite suite);

// Delete the persisted key `keyID` from the respective store; the bool result
// presumably signals success — confirm against the implementations.
bool DeleteNativePersistedKeyPair(KeyPairContextType ctx, const std::string& keyID);
bool DeleteGenericPersistedKeyPair(KeyPairContextType ctx, const std::string& keyID);

} // namespace detail
} // namespace mls
} // namespace dave
} // namespace discord

--------------------------------------------------------------------------------
/cpp/src/dave/mls/detail/persisted_key_pair_apple.cpp:
--------------------------------------------------------------------------------
#include "dave/mls/detail/persisted_key_pair.h"

#include 
#include 
#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 

#include 
#include 

#include "dave/logger.h"
#include
"dave/mls/parameters.h" 20 | 21 | static const CFStringRef KeyServiceLabel = CFSTR("Discord Secure Frames Key"); 22 | static const std::string KeyLabelPrefix = "Discord Secure Frames Key: "; 23 | static const std::string KeyTagPrefix = "discord-secure-frames-key-"; 24 | 25 | static void AddAccessGroup([[maybe_unused]] CFMutableDictionaryRef dict) 26 | { 27 | #if TARGET_OS_IPHONE 28 | CFDictionaryAddValue(dict, kSecAttrAccessGroup, CFSTR("group.com.hammerandchisel.discord")); 29 | #endif 30 | } 31 | 32 | template 33 | struct ScopedCFTypeRef { 34 | ScopedCFTypeRef() = default; 35 | ScopedCFTypeRef(T ref) 36 | : ref_(ref) 37 | { 38 | } 39 | ScopedCFTypeRef(ScopedCFTypeRef& other) 40 | : ref_(other.ref_) 41 | { 42 | if (ref_) { 43 | CFRetain(ref_); 44 | } 45 | } 46 | ScopedCFTypeRef(ScopedCFTypeRef&& other) 47 | : ref_(std::exchange(other.ref_, nullptr)) 48 | { 49 | } 50 | 51 | ~ScopedCFTypeRef() { release(); } 52 | 53 | ScopedCFTypeRef& operator=(T ref) 54 | { 55 | release(); 56 | ref_ = ref; 57 | return *this; 58 | } 59 | 60 | void release() 61 | { 62 | if (ref_) { 63 | CFRelease(ref_); 64 | } 65 | ref_ = nullptr; 66 | } 67 | 68 | T& get() { return ref_; } 69 | 70 | T* getPtr() { return &ref_; } 71 | CFTypeRef* getGenericPtr() { return (CFTypeRef*)getPtr(); } 72 | 73 | operator T&() { return get(); } 74 | 75 | explicit operator bool() { return ref_ != nullptr; } 76 | 77 | T ref_ = nullptr; 78 | }; 79 | 80 | static std::string ConvertCFString(CFStringRef string) 81 | { 82 | if (const char* str = CFStringGetCStringPtr(string, kCFStringEncodingUTF8)) { 83 | return str; 84 | } 85 | 86 | CFIndex len = CFStringGetLength(string); 87 | std::string ret(CFStringGetMaximumSizeForEncoding(len, kCFStringEncodingUTF8), 0); 88 | 89 | CFStringGetBytes(string, 90 | CFRangeMake(0, len), 91 | kCFStringEncodingUTF8, 92 | '?', 93 | false, 94 | (UInt8*)ret.data(), 95 | ret.size(), 96 | &len); 97 | 98 | ret.resize(len); 99 | 100 | return ret; 101 | } 102 | 103 | static std::string 
SecStatusToString(OSStatus status) 104 | { 105 | std::string ret = std::to_string(status); 106 | 107 | if (__builtin_available(macOS 10.3, iOS 11.3, *)) { 108 | ScopedCFTypeRef string = SecCopyErrorMessageString(status, NULL); 109 | if (string) { 110 | ret += " ("; 111 | ret += ConvertCFString(string); 112 | ret += ")"; 113 | } 114 | } 115 | 116 | return ret; 117 | } 118 | 119 | static std::string ErrorToString(CFErrorRef error) 120 | { 121 | if (!error) { 122 | return "(null)"; 123 | } 124 | 125 | if (__builtin_available(macOS 10.3, iOS 11.3, *)) { 126 | OSStatus status = CFErrorGetCode(error); 127 | ScopedCFTypeRef string = SecCopyErrorMessageString(status, NULL); 128 | if (string) { 129 | std::string ret = std::to_string(status); 130 | 131 | ret += " ("; 132 | ret += ConvertCFString(string); 133 | ret += ")"; 134 | 135 | return ret; 136 | } 137 | } 138 | 139 | if (ScopedCFTypeRef string = CFErrorCopyDescription(error)) { 140 | return ConvertCFString(string); 141 | } 142 | 143 | return "(unknown)"; 144 | } 145 | 146 | namespace discord { 147 | namespace dave { 148 | namespace mls { 149 | namespace detail { 150 | 151 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair( 152 | [[maybe_unused]] KeyPairContextType ctx, 153 | const std::string& id, 154 | ::mlspp::CipherSuite suite, 155 | bool& supported) 156 | { 157 | std::shared_ptr<::mlspp::SignaturePrivateKey> ret; 158 | 159 | CFStringRef keyType = nullptr; 160 | int keySize = 0; 161 | std::function convertKey; 162 | 163 | ScopedCFTypeRef query = CFDictionaryCreateMutable( 164 | NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); 165 | 166 | CFDictionaryAddValue(query, kSecReturnRef, kCFBooleanTrue); 167 | CFDictionaryAddValue(query, kSecUseAuthenticationUI, kSecUseAuthenticationUISkip); 168 | AddAccessGroup(query); 169 | 170 | auto suiteId = suite.cipher_suite(); 171 | switch (suiteId) { 172 | case ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256: 173 | case 
::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384: 174 | case ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521: 175 | supported = true; 176 | keyType = kSecAttrKeyTypeECSECPrimeRandom; 177 | if (suiteId == ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521) { 178 | keySize = 521; 179 | } 180 | else if (suiteId == ::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384) { 181 | keySize = 384; 182 | } 183 | else { 184 | keySize = 256; 185 | } 186 | convertKey = [keySize](CFDataRef data) { 187 | // https://developer.apple.com/documentation/security/1643698-seckeycopyexternalrepresentation 188 | // Input has a 1-byte header (always 0x04, per ANSI X9.63), followed by 3 189 | // keySize-bit left-padded byte-aligned big-endian integers: X, Y, and K. 190 | // X and Y are the public key (represented as the coordinates); 191 | // K is the private key. 192 | bytes ret; 193 | constexpr size_t HeaderSize = 1; 194 | constexpr size_t ValueCount = 3; 195 | constexpr size_t PublicValues = 2; 196 | constexpr uint8_t HeaderByte = 0x04; 197 | 198 | // Convert keySize from bits to bytes (rounding up) 199 | CFIndex byteLen = (keySize + 7) / 8; 200 | 201 | CFIndex len = CFDataGetLength(data); 202 | if (len < 0 || (size_t)len < HeaderSize + ValueCount * byteLen) { 203 | DISCORD_LOG(LS_ERROR) 204 | << "Exported key blob too small in GetPersistedKeyPair/convertKey: " << len; 205 | return ret; 206 | } 207 | 208 | const uint8_t* ptr = CFDataGetBytePtr(data); 209 | if (ptr[0] != HeaderByte) { 210 | DISCORD_LOG(LS_ERROR) 211 | << "Exported key blob has unexpected format in GetPersistedKeyPair/convertKey: " 212 | << ptr[0]; 213 | return ret; 214 | } 215 | 216 | // Skip header, X, and Y, and extract K. 
217 | ptr += HeaderSize + PublicValues * byteLen; 218 | ret.as_vec().assign(ptr, ptr + byteLen); 219 | 220 | return ret; 221 | }; 222 | break; 223 | default: 224 | // Other suites will need to store keys as generic data items 225 | return nullptr; 226 | } 227 | 228 | assert(keyType && keySize && convertKey); 229 | 230 | ScopedCFTypeRef sizeRef = CFNumberCreate(NULL, kCFNumberIntType, &keySize); 231 | 232 | std::string labelString = KeyLabelPrefix + id; 233 | std::string tagString = KeyTagPrefix + id; 234 | ScopedCFTypeRef labelStringRef = 235 | CFStringCreateWithCString(NULL, labelString.c_str(), kCFStringEncodingUTF8); 236 | ScopedCFTypeRef tagDataRef = 237 | CFDataCreate(NULL, (const UInt8*)tagString.c_str(), tagString.size()); 238 | 239 | CFDictionaryAddValue(query, kSecClass, kSecClassKey); 240 | CFDictionaryAddValue(query, kSecAttrKeyType, keyType); 241 | CFDictionaryAddValue(query, kSecAttrApplicationTag, tagDataRef); 242 | CFDictionaryAddValue(query, kSecAttrCanSign, kCFBooleanTrue); 243 | 244 | ScopedCFTypeRef cfError; 245 | ScopedCFTypeRef key; 246 | 247 | // If we get errSecMissingEntitlement, try again with the file-based keychain 248 | constexpr int AttemptCount = 2; 249 | for (int attempt = 0; attempt < AttemptCount && !key; attempt++) { 250 | cfError.release(); 251 | 252 | CFBooleanRef useDataProtection = attempt == 0 ? 
kCFBooleanTrue : kCFBooleanFalse; 253 | if (__builtin_available(macOS 10.15, *)) { 254 | CFDictionarySetValue(query, kSecUseDataProtectionKeychain, useDataProtection); 255 | } 256 | else if (attempt == 1) { 257 | return nullptr; 258 | } 259 | 260 | OSStatus status = SecItemCopyMatching(query, key.getGenericPtr()); 261 | 262 | if (status == errSecItemNotFound) { 263 | DISCORD_LOG(LS_INFO) << "Item not found in GetPersistedKeyPair; generating new: " 264 | << SecStatusToString(status); 265 | 266 | ScopedCFTypeRef params = CFDictionaryCreateMutable( 267 | NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); 268 | AddAccessGroup(params); 269 | CFDictionaryAddValue(params, kSecAttrKeyType, keyType); 270 | CFDictionaryAddValue(params, kSecAttrKeySizeInBits, sizeRef); 271 | CFDictionaryAddValue(params, kSecAttrCanEncrypt, kCFBooleanFalse); 272 | CFDictionaryAddValue(params, kSecAttrCanDecrypt, kCFBooleanFalse); 273 | CFDictionaryAddValue(params, kSecAttrCanWrap, kCFBooleanFalse); 274 | CFDictionaryAddValue(params, kSecAttrCanUnwrap, kCFBooleanFalse); 275 | if (__builtin_available(macOS 10.15, *)) { 276 | CFDictionaryAddValue(params, kSecUseDataProtectionKeychain, useDataProtection); 277 | } 278 | 279 | ScopedCFTypeRef privParams = CFDictionaryCreateMutable( 280 | NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); 281 | CFDictionaryAddValue(privParams, kSecAttrIsPermanent, kCFBooleanTrue); 282 | CFDictionaryAddValue(privParams, kSecAttrLabel, labelStringRef); 283 | CFDictionaryAddValue(privParams, kSecAttrApplicationTag, tagDataRef); 284 | 285 | CFDictionaryAddValue(params, kSecPrivateKeyAttrs, privParams); 286 | 287 | key = SecKeyCreateRandomKey(params, cfError.getPtr()); 288 | 289 | if (!key || cfError) { 290 | DISCORD_LOG(LS_WARNING) 291 | << "Failed to create key in GetPersistedKeyPair: " << ErrorToString(cfError); 292 | 293 | if (!cfError || CFErrorGetCode(cfError) != errSecMissingEntitlement) { 294 | return nullptr; 295 
| } 296 | 297 | key.release(); 298 | } 299 | } 300 | else if (status != 0 || !key) { 301 | DISCORD_LOG(LS_WARNING) 302 | << "Item not found GetPersistedKeyPair: " << SecStatusToString(status); 303 | if (status != errSecMissingEntitlement) { 304 | return nullptr; 305 | } 306 | } 307 | } 308 | 309 | if (!key) { 310 | return nullptr; 311 | } 312 | 313 | ScopedCFTypeRef data = SecKeyCopyExternalRepresentation(key, cfError.getPtr()); 314 | if (!data) { 315 | DISCORD_LOG(LS_ERROR) << "Failed to export key in GetPersistedKeyPair: " 316 | << ErrorToString(cfError); 317 | return nullptr; 318 | } 319 | 320 | bytes converted = convertKey(data); 321 | if (converted.empty()) { 322 | DISCORD_LOG(LS_ERROR) << "Failed to convert key in GetPersistedKeyPair"; 323 | return nullptr; 324 | } 325 | 326 | return std::make_shared<::mlspp::SignaturePrivateKey>( 327 | ::mlspp::SignaturePrivateKey::parse(suite, converted)); 328 | } 329 | 330 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair( 331 | [[maybe_unused]] KeyPairContextType ctx, 332 | const std::string& id, 333 | ::mlspp::CipherSuite suite) 334 | { 335 | ::mlspp::SignaturePrivateKey ret; 336 | 337 | ScopedCFTypeRef query = CFDictionaryCreateMutable( 338 | NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); 339 | 340 | ScopedCFTypeRef accountString = 341 | CFStringCreateWithCString(NULL, id.c_str(), kCFStringEncodingUTF8); 342 | CFDictionaryAddValue(query, kSecReturnData, kCFBooleanTrue); 343 | CFDictionaryAddValue(query, kSecUseAuthenticationUI, kSecUseAuthenticationUISkip); 344 | CFDictionaryAddValue(query, kSecAttrService, KeyServiceLabel); 345 | CFDictionaryAddValue(query, kSecAttrAccount, accountString); 346 | CFDictionaryAddValue(query, kSecClass, kSecClassGenericPassword); 347 | AddAccessGroup(query); 348 | 349 | // If we get errSecMissingEntitlement, try again with the file-based keychain 350 | constexpr int AttemptCount = 2; 351 | for (int attempt = 0; attempt < AttemptCount 
&& ret.public_key.data.empty(); attempt++) { 352 | if (__builtin_available(macOS 10.15, *)) { 353 | CFDictionarySetValue(query, 354 | kSecUseDataProtectionKeychain, 355 | attempt == 0 ? kCFBooleanTrue : kCFBooleanFalse); 356 | } 357 | else if (attempt == 1) { 358 | return nullptr; 359 | } 360 | 361 | ScopedCFTypeRef result; 362 | OSStatus status = SecItemCopyMatching(query, result.getGenericPtr()); 363 | 364 | std::string curstr; 365 | if (status == 0 && result) { 366 | curstr.assign((char*)CFDataGetBytePtr(result), CFDataGetLength(result)); 367 | 368 | try { 369 | ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr); 370 | } 371 | catch (std::exception& ex) { 372 | DISCORD_LOG(LS_WARNING) 373 | << "Failed to parse key in GetPersistedKeyPair: " << ex.what(); 374 | return nullptr; 375 | } 376 | } 377 | else if (status == errSecItemNotFound) { 378 | DISCORD_LOG(LS_INFO) << "Did not receive item in GetPersistedKeyPair; generating new: " 379 | << SecStatusToString(status); 380 | 381 | ret = ::mlspp::SignaturePrivateKey::generate(suite); 382 | 383 | std::string newstr = ret.to_jwk(suite); 384 | 385 | ScopedCFTypeRef data = 386 | CFDataCreate(NULL, (const UInt8*)newstr.c_str(), newstr.length()); 387 | 388 | CFDictionaryRemoveValue(query, kSecReturnData); 389 | CFDictionaryAddValue(query, kSecValueData, data); 390 | 391 | status = SecItemAdd(query, nullptr); 392 | if (status) { 393 | DISCORD_LOG(LS_WARNING) << "Failed to create keychain item in GetPersistedKeyPair: " 394 | << SecStatusToString(status); 395 | 396 | if (status != errSecMissingEntitlement) { 397 | return nullptr; 398 | } 399 | 400 | ret = ::mlspp::SignaturePrivateKey(); 401 | } 402 | } 403 | else { 404 | DISCORD_LOG(LS_WARNING) 405 | << "Failed to retrieve item in GetPersistedKeyPair: " << SecStatusToString(status); 406 | if (status != errSecMissingEntitlement) { 407 | return nullptr; 408 | } 409 | } 410 | } 411 | 412 | if (!ret.public_key.data.empty()) { 413 | return 
std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret)); 414 | } 415 | else { 416 | return nullptr; 417 | } 418 | } 419 | 420 | static bool DeleteWithQuery(CFMutableDictionaryRef query) 421 | { 422 | #if !TARGET_OS_IPHONE 423 | if (__builtin_available(macOS 10.15, *)) { 424 | CFDictionarySetValue(query, kSecUseDataProtectionKeychain, kCFBooleanTrue); 425 | } 426 | #endif 427 | 428 | auto ret = SecItemDelete(query); 429 | 430 | #if !TARGET_OS_IPHONE 431 | if (__builtin_available(macOS 10.15, *)) { 432 | if (ret == errSecMissingEntitlement) { 433 | CFDictionarySetValue(query, kSecUseDataProtectionKeychain, kCFBooleanFalse); 434 | ret = SecItemDelete(query); 435 | } 436 | } 437 | #endif 438 | 439 | return ret == errSecSuccess; 440 | } 441 | 442 | bool DeleteNativePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id) 443 | { 444 | std::string tagString = KeyTagPrefix + id; 445 | ScopedCFTypeRef tagDataRef = 446 | CFDataCreate(NULL, (const UInt8*)tagString.c_str(), tagString.size()); 447 | 448 | ScopedCFTypeRef query = CFDictionaryCreateMutable( 449 | NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); 450 | 451 | CFDictionaryAddValue(query, kSecClass, kSecClassKey); 452 | CFDictionaryAddValue(query, kSecAttrApplicationTag, tagDataRef); 453 | AddAccessGroup(query); 454 | 455 | return DeleteWithQuery(query); 456 | } 457 | 458 | bool DeleteGenericPersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id) 459 | { 460 | ScopedCFTypeRef accountString = 461 | CFStringCreateWithCString(NULL, id.c_str(), kCFStringEncodingUTF8); 462 | 463 | ScopedCFTypeRef query = CFDictionaryCreateMutable( 464 | NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); 465 | 466 | CFDictionaryAddValue(query, kSecAttrService, KeyServiceLabel); 467 | CFDictionaryAddValue(query, kSecAttrAccount, accountString); 468 | CFDictionaryAddValue(query, kSecClass, kSecClassGenericPassword); 469 | 
AddAccessGroup(query); 470 | 471 | return DeleteWithQuery(query); 472 | } 473 | 474 | } // namespace detail 475 | } // namespace mls 476 | } // namespace dave 477 | } // namespace discord -------------------------------------------------------------------------------- /cpp/src/dave/mls/detail/persisted_key_pair_generic.cpp: -------------------------------------------------------------------------------- 1 | #include "dave/mls/detail/persisted_key_pair.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #ifdef _WIN32 12 | #include 13 | #else 14 | #include 15 | #endif 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | #include "dave/logger.h" 22 | #include "dave/mls/parameters.h" 23 | 24 | static const std::string_view KeyStorageDir = "Discord Key Storage"; 25 | 26 | static std::filesystem::path GetKeyStorageDirectory() 27 | { 28 | std::filesystem::path dir; 29 | 30 | #if defined(__ANDROID__) 31 | dir = std::filesystem::path("/data/data"); 32 | 33 | { 34 | std::ifstream idFile("/proc/self/cmdline", std::ios_base::in); 35 | std::string appId; 36 | std::getline(idFile, appId, '\0'); 37 | dir /= appId; 38 | } 39 | #else // __ANDROID__ 40 | #if defined(_WIN32) 41 | if (const wchar_t* appdata = _wgetenv(L"LOCALAPPDATA")) { 42 | dir = std::filesystem::path(appdata); 43 | } 44 | #else // _WIN32 45 | if (const char* xdg = getenv("XDG_CONFIG_HOME")) { 46 | dir = std::filesystem::path(xdg); 47 | } 48 | else if (const char* home = getenv("HOME")) { 49 | dir = std::filesystem::path(home); 50 | dir /= ".config"; 51 | } 52 | #endif // !_WIN32 53 | else { 54 | return dir; 55 | } 56 | #endif // !__ANDROID__ 57 | 58 | return dir / KeyStorageDir; 59 | } 60 | 61 | namespace discord { 62 | namespace dave { 63 | namespace mls { 64 | namespace detail { 65 | 66 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair( 67 | [[maybe_unused]] KeyPairContextType ctx, 68 | const std::string& id, 69 | 
::mlspp::CipherSuite suite) 70 | { 71 | ::mlspp::SignaturePrivateKey ret; 72 | std::string curstr; 73 | std::filesystem::path dir = GetKeyStorageDirectory(); 74 | 75 | if (dir.empty()) { 76 | DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; 77 | return nullptr; 78 | } 79 | 80 | std::error_code errc; 81 | std::filesystem::create_directories(dir, errc); 82 | if (errc) { 83 | DISCORD_LOG(LS_ERROR) << "Failed to create key storage directory in GetPersistedKeyPair: " 84 | << errc; 85 | return nullptr; 86 | } 87 | 88 | std::filesystem::path file = dir / (id + ".key"); 89 | 90 | if (std::filesystem::exists(file)) { 91 | std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary); 92 | if (!ifs) { 93 | DISCORD_LOG(LS_ERROR) << "Failed to open key in GetPersistedKeyPair"; 94 | return nullptr; 95 | } 96 | 97 | curstr = (std::stringstream() << ifs.rdbuf()).str(); 98 | if (!ifs) { 99 | DISCORD_LOG(LS_ERROR) << "Failed to read key in GetPersistedKeyPair"; 100 | return nullptr; 101 | } 102 | 103 | try { 104 | ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr); 105 | } 106 | catch (std::exception& ex) { 107 | DISCORD_LOG(LS_ERROR) << "Failed to parse key in GetPersistedKeyPair: " << ex.what(); 108 | return nullptr; 109 | } 110 | } 111 | else { 112 | ret = ::mlspp::SignaturePrivateKey::generate(suite); 113 | 114 | std::string newstr = ret.to_jwk(suite); 115 | 116 | std::filesystem::path tmpfile = file; 117 | tmpfile += ".tmp"; 118 | 119 | #ifdef _WIN32 120 | int fd = _wopen(tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC, _S_IREAD | _S_IWRITE); 121 | #else 122 | int fd = open(tmpfile.c_str(), 123 | O_WRONLY | O_CLOEXEC | O_NOFOLLOW | O_CREAT | O_TRUNC, 124 | S_IRUSR | S_IWUSR); 125 | #endif 126 | if (fd < 0) { 127 | DISCORD_LOG(LS_ERROR) << "Failed to open output file in GetPersistedKeyPair: " << errno 128 | << "(" << tmpfile << ")"; 129 | return nullptr; 130 | } 131 | 132 | #ifdef _WIN32 133 | int wret = _write(fd, 
newstr.c_str(), static_cast(newstr.size())); 134 | _close(fd); 135 | #else 136 | ssize_t wret = write(fd, newstr.c_str(), newstr.size()); 137 | close(fd); 138 | #endif 139 | if (wret < 0 || (size_t)wret != newstr.size()) { 140 | DISCORD_LOG(LS_ERROR) << "Failed to write output file in GetPersistedKeyPair: " 141 | << errno; 142 | return nullptr; 143 | } 144 | 145 | std::filesystem::rename(tmpfile, file, errc); 146 | if (errc) { 147 | DISCORD_LOG(LS_ERROR) << "Failed to rename output file in GetPersistedKeyPair: " 148 | << errc; 149 | return nullptr; 150 | } 151 | } 152 | 153 | if (!ret.public_key.data.empty()) { 154 | return std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret)); 155 | } 156 | else { 157 | return nullptr; 158 | } 159 | } 160 | 161 | bool DeleteGenericPersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id) 162 | { 163 | std::error_code errc; 164 | std::filesystem::path dir = GetKeyStorageDirectory(); 165 | if (dir.empty()) { 166 | DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; 167 | return false; 168 | } 169 | 170 | std::filesystem::path file = dir / (id + ".key"); 171 | 172 | return std::filesystem::remove(file, errc); 173 | } 174 | 175 | } // namespace detail 176 | } // namespace mls 177 | } // namespace dave 178 | } // namespace discord 179 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/detail/persisted_key_pair_null.cpp: -------------------------------------------------------------------------------- 1 | #include "dave/mls/detail/persisted_key_pair.h" 2 | 3 | namespace discord { 4 | namespace dave { 5 | namespace mls { 6 | namespace detail { 7 | 8 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair( 9 | [[maybe_unused]] KeyPairContextType ctx, 10 | [[maybe_unused]] const std::string& keyID, 11 | [[maybe_unused]] ::mlspp::CipherSuite suite, 12 | bool& supported) 13 | { 14 | supported = false; 15 
| return nullptr; 16 | } 17 | 18 | bool DeleteNativePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, 19 | [[maybe_unused]] const std::string& keyID) 20 | { 21 | return false; 22 | } 23 | 24 | } // namespace detail 25 | } // namespace mls 26 | } // namespace dave 27 | } // namespace discord 28 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/detail/persisted_key_pair_win.cpp: -------------------------------------------------------------------------------- 1 | #include "dave/mls/detail/persisted_key_pair.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #ifndef SECURITY_WIN32 12 | #define SECURITY_WIN32 1 13 | #endif 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #include 21 | #include 22 | 23 | #include "dave/logger.h" 24 | #include "dave/mls/parameters.h" 25 | 26 | static const std::string KeyTagPrefix = "discord-secure-frames-key-"; 27 | 28 | template 29 | struct ScopedNCryptHandle { 30 | ScopedNCryptHandle() = default; 31 | ScopedNCryptHandle(T handle) 32 | : handle_(handle) 33 | { 34 | } 35 | ScopedNCryptHandle(const ScopedNCryptHandle& other) = delete; 36 | ScopedNCryptHandle(ScopedNCryptHandle&& other) 37 | : handle_(std::exchange(other.handle_, T())) 38 | { 39 | } 40 | 41 | ~ScopedNCryptHandle() { finalize(); } 42 | 43 | ScopedNCryptHandle& operator=(T handle) 44 | { 45 | finalize(); 46 | handle_ = handle; 47 | return *this; 48 | } 49 | 50 | T release() { return std::exchange(handle_, T()); } 51 | 52 | void finalize() 53 | { 54 | if (auto handle = release()) { 55 | NCryptFreeObject(handle); 56 | } 57 | } 58 | 59 | T& get() { return handle_; } 60 | 61 | T* getPtr() { return &handle_; } 62 | 63 | operator T&() { return get(); } 64 | 65 | explicit operator bool() { return handle_ != T(); } 66 | 67 | T handle_ = T(); 68 | }; 69 | 70 | namespace discord { 71 | namespace dave { 72 | namespace mls { 73 | namespace 
detail { 74 | 75 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair( 76 | [[maybe_unused]] KeyPairContextType ctx, 77 | const std::string& id, 78 | ::mlspp::CipherSuite suite, 79 | bool& supported) 80 | { 81 | LPCWSTR keyType = nullptr; 82 | ULONG keyBlobMagic = 0; 83 | std::function convertBlob; 84 | 85 | auto suiteId = suite.cipher_suite(); 86 | switch (suiteId) { 87 | case ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256: 88 | case ::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384: 89 | case ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521: 90 | supported = true; 91 | if (suiteId == ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521) { 92 | keyType = BCRYPT_ECDSA_P521_ALGORITHM; 93 | keyBlobMagic = BCRYPT_ECDSA_PRIVATE_P521_MAGIC; 94 | } 95 | else if (suiteId == ::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384) { 96 | keyType = BCRYPT_ECDSA_P384_ALGORITHM; 97 | keyBlobMagic = BCRYPT_ECDSA_PRIVATE_P384_MAGIC; 98 | } 99 | else { 100 | keyType = BCRYPT_ECDSA_P256_ALGORITHM; 101 | keyBlobMagic = BCRYPT_ECDSA_PRIVATE_P256_MAGIC; 102 | } 103 | 104 | convertBlob = [](bytes& blob) { 105 | // https://learn.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_ecckey_blob 106 | // Input has an PBCRYPT_ECCKEY_BLOB header, followed by 3 cbKey-byte big-endian 107 | // integers: X, Y, and d. X and Y are the public key (represented as the coordinates); 108 | // d is the private key. 
109 | constexpr size_t ValueCount = 3; 110 | constexpr size_t PublicValues = 2; 111 | 112 | if (blob.size() < sizeof(BCRYPT_ECCKEY_BLOB)) { 113 | DISCORD_LOG(LS_ERROR) 114 | << "Exported key blob too small in GetPersistedKeyPair/convertBlob: " 115 | << blob.size(); 116 | return false; 117 | } 118 | 119 | PBCRYPT_ECCKEY_BLOB header = (PBCRYPT_ECCKEY_BLOB)blob.data(); 120 | ULONG keySize = header->cbKey; 121 | if (blob.size() < sizeof(BCRYPT_ECCKEY_BLOB) + keySize * ValueCount) { 122 | DISCORD_LOG(LS_ERROR) 123 | << "Exported key blob too small in GetPersistedKeyPair/convertBlob: " 124 | << blob.size(); 125 | return false; 126 | } 127 | blob.resize(sizeof(BCRYPT_ECCKEY_BLOB) + keySize * ValueCount); 128 | blob.as_vec().erase(blob.begin(), 129 | blob.begin() + sizeof(BCRYPT_ECCKEY_BLOB) + keySize * PublicValues); 130 | return true; 131 | }; 132 | break; 133 | default: 134 | // Other suites will need to store keys as JWK files on disk 135 | return nullptr; 136 | } 137 | 138 | assert(keyType && keyBlobMagic && convertBlob); 139 | 140 | ScopedNCryptHandle provider; 141 | SECURITY_STATUS status = 142 | NCryptOpenStorageProvider(provider.getPtr(), MS_KEY_STORAGE_PROVIDER, 0); 143 | if (status != ERROR_SUCCESS) { 144 | DISCORD_LOG(LS_ERROR) << "Failed to open storage provider in GetPersistedKeyPair: " 145 | << status; 146 | return nullptr; 147 | } 148 | 149 | std::filesystem::path keyName = KeyTagPrefix + id; 150 | 151 | ScopedNCryptHandle key; 152 | status = 153 | NCryptOpenKey(provider, key.getPtr(), keyName.c_str(), AT_SIGNATURE, NCRYPT_SILENT_FLAG); 154 | 155 | if (status == NTE_BAD_KEYSET) { 156 | DISCORD_LOG(LS_INFO) << "No key found in GetPersistedKeyPair; generating new"; 157 | 158 | status = NCryptCreatePersistedKey( 159 | provider, key.getPtr(), keyType, keyName.c_str(), AT_SIGNATURE, 0); 160 | if (status != ERROR_SUCCESS) { 161 | DISCORD_LOG(LS_ERROR) << "Failed to create key in GetPersistedKeyPair: " << status; 162 | return nullptr; 163 | } 164 | 165 | DWORD 
exportPolicyValue = NCRYPT_ALLOW_EXPORT_FLAG | NCRYPT_ALLOW_PLAINTEXT_EXPORT_FLAG; 166 | status = NCryptSetProperty(key, 167 | NCRYPT_EXPORT_POLICY_PROPERTY, 168 | (PBYTE)&exportPolicyValue, 169 | sizeof(exportPolicyValue), 170 | NCRYPT_PERSIST_FLAG | NCRYPT_SILENT_FLAG); 171 | if (status != ERROR_SUCCESS) { 172 | DISCORD_LOG(LS_ERROR) 173 | << "Failed to configure key export policy in GetPersistedKeyPair: " << status; 174 | return nullptr; 175 | } 176 | 177 | // struct { 178 | // DWORD dwVersion; 179 | // DWORD dwFlags; 180 | // LPCWSTR pszCreationTitle; 181 | // LPCWSTR pszFriendlyName; 182 | // LPCWSTR pszDescription; 183 | // } NCRYPT_UI_POLICY; 184 | 185 | NCRYPT_UI_POLICY uiPolicyValue = {1, 0, nullptr, nullptr, nullptr}; 186 | status = NCryptSetProperty(key, 187 | NCRYPT_UI_POLICY_PROPERTY, 188 | (PBYTE)&uiPolicyValue, 189 | sizeof(uiPolicyValue), 190 | NCRYPT_PERSIST_FLAG | NCRYPT_SILENT_FLAG); 191 | if (status != ERROR_SUCCESS) { 192 | DISCORD_LOG(LS_ERROR) << "Failed to configure key UI policy in GetPersistedKeyPair: " 193 | << status; 194 | return nullptr; 195 | } 196 | 197 | status = NCryptFinalizeKey(key, NCRYPT_SILENT_FLAG); 198 | if (status != ERROR_SUCCESS) { 199 | DISCORD_LOG(LS_ERROR) << "Failed to finalize key in GetPersistedKeyPair: " << status; 200 | return nullptr; 201 | } 202 | } 203 | else if (status != ERROR_SUCCESS) { 204 | DISCORD_LOG(LS_ERROR) << "Failed to open key in GetPersistedKeyPair: " << status; 205 | return nullptr; 206 | } 207 | 208 | DWORD keySize = 0; 209 | status = NCryptExportKey( 210 | key, NULL, BCRYPT_PRIVATE_KEY_BLOB, NULL, NULL, 0, &keySize, NCRYPT_SILENT_FLAG); 211 | if (status != ERROR_SUCCESS) { 212 | DISCORD_LOG(LS_ERROR) << "Failed to size key in GetPersistedKeyPair: " << status; 213 | return nullptr; 214 | } 215 | 216 | bytes keyData(keySize); 217 | 218 | status = NCryptExportKey(key, 219 | NULL, 220 | BCRYPT_PRIVATE_KEY_BLOB, 221 | NULL, 222 | keyData.data(), 223 | keySize, 224 | &keySize, 225 | 
NCRYPT_SILENT_FLAG); 226 | if (status != ERROR_SUCCESS) { 227 | DISCORD_LOG(LS_ERROR) << "Failed to export key in GetPersistedKeyPair: " << status; 228 | return nullptr; 229 | } 230 | 231 | if (keyData.size() < sizeof(BCRYPT_KEY_BLOB)) { 232 | DISCORD_LOG(LS_ERROR) << "Exported key blob too small in GetPersistedKeyPair/convertBlob: " 233 | << keyData.size(); 234 | return nullptr; 235 | } 236 | 237 | BCRYPT_KEY_BLOB* header = (BCRYPT_KEY_BLOB*)keyData.data(); 238 | if (header->Magic != keyBlobMagic) { 239 | DISCORD_LOG(LS_ERROR) << "Exported key blob has unexpected magic in GetPersistedKeyPair: " 240 | << header->Magic; 241 | return nullptr; 242 | } 243 | 244 | if (!convertBlob(keyData)) { 245 | DISCORD_LOG(LS_ERROR) << "Failed to convert key in GetPersistedKeyPair"; 246 | return nullptr; 247 | } 248 | 249 | return std::make_shared<::mlspp::SignaturePrivateKey>( 250 | ::mlspp::SignaturePrivateKey::parse(suite, keyData)); 251 | } 252 | 253 | bool DeleteNativePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id) 254 | { 255 | ScopedNCryptHandle provider; 256 | SECURITY_STATUS status = 257 | NCryptOpenStorageProvider(provider.getPtr(), MS_KEY_STORAGE_PROVIDER, 0); 258 | if (status != ERROR_SUCCESS) { 259 | DISCORD_LOG(LS_ERROR) << "Failed to open storage provider in DeletePersistedKeyPair: " 260 | << status; 261 | return false; 262 | } 263 | 264 | std::filesystem::path keyName = KeyTagPrefix + id; 265 | 266 | ScopedNCryptHandle key; 267 | status = 268 | NCryptOpenKey(provider, key.getPtr(), keyName.c_str(), AT_SIGNATURE, NCRYPT_SILENT_FLAG); 269 | if (status != ERROR_SUCCESS) { 270 | return false; 271 | } 272 | 273 | auto ret = NCryptDeleteKey(key, NCRYPT_SILENT_FLAG); 274 | if (ret == ERROR_SUCCESS) { 275 | // If NCryptDeleteKey succeeds, it frees the handle, so our wrapper shouldn't also do so. 
276 | key.release(); 277 | return true; 278 | } 279 | else { 280 | return false; 281 | } 282 | } 283 | 284 | } // namespace detail 285 | } // namespace mls 286 | } // namespace dave 287 | } // namespace discord 288 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/parameters.cpp: -------------------------------------------------------------------------------- 1 | #include "parameters.h" 2 | 3 | namespace discord { 4 | namespace dave { 5 | namespace mls { 6 | 7 | ::mlspp::CipherSuite::ID CiphersuiteIDForProtocolVersion( 8 | [[maybe_unused]] ProtocolVersion version) noexcept 9 | { 10 | return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; 11 | } 12 | 13 | ::mlspp::CipherSuite CiphersuiteForProtocolVersion(ProtocolVersion version) noexcept 14 | { 15 | return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; 16 | } 17 | 18 | ::mlspp::CipherSuite::ID CiphersuiteIDForSignatureVersion( 19 | [[maybe_unused]] SignatureVersion version) noexcept 20 | { 21 | return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; 22 | } 23 | 24 | ::mlspp::CipherSuite CiphersuiteForSignatureVersion(SignatureVersion version) noexcept 25 | { 26 | return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; 27 | } 28 | 29 | ::mlspp::Capabilities LeafNodeCapabilitiesForProtocolVersion(ProtocolVersion version) noexcept 30 | { 31 | auto capabilities = ::mlspp::Capabilities::create_default(); 32 | 33 | capabilities.cipher_suites = {CiphersuiteIDForProtocolVersion(version)}; 34 | capabilities.credentials = {::mlspp::CredentialType::basic}; 35 | 36 | return capabilities; 37 | } 38 | 39 | ::mlspp::ExtensionList LeafNodeExtensionsForProtocolVersion( 40 | [[maybe_unused]] ProtocolVersion version) noexcept 41 | { 42 | return ::mlspp::ExtensionList{}; 43 | } 44 | 45 | ::mlspp::ExtensionList GroupExtensionsForProtocolVersion( 46 | [[maybe_unused]] ProtocolVersion version, 47 | const ::mlspp::ExternalSender& externalSender) 
noexcept 48 | { 49 | auto extensionList = ::mlspp::ExtensionList{}; 50 | 51 | extensionList.add(::mlspp::ExternalSendersExtension{{ 52 | {externalSender.signature_key, externalSender.credential}, 53 | }}); 54 | 55 | return extensionList; 56 | } 57 | 58 | } // namespace mls 59 | } // namespace dave 60 | } // namespace discord 61 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/parameters.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "dave/version.h" 8 | 9 | namespace discord { 10 | namespace dave { 11 | namespace mls { 12 | 13 | ::mlspp::CipherSuite::ID CiphersuiteIDForProtocolVersion(ProtocolVersion version) noexcept; 14 | ::mlspp::CipherSuite CiphersuiteForProtocolVersion(ProtocolVersion version) noexcept; 15 | ::mlspp::CipherSuite::ID CiphersuiteIDForSignatureVersion(SignatureVersion version) noexcept; 16 | ::mlspp::CipherSuite CiphersuiteForSignatureVersion(SignatureVersion version) noexcept; 17 | ::mlspp::Capabilities LeafNodeCapabilitiesForProtocolVersion(ProtocolVersion version) noexcept; 18 | ::mlspp::ExtensionList LeafNodeExtensionsForProtocolVersion(ProtocolVersion version) noexcept; 19 | ::mlspp::ExtensionList GroupExtensionsForProtocolVersion( 20 | ProtocolVersion version, 21 | const ::mlspp::ExternalSender& externalSender) noexcept; 22 | 23 | } // namespace mls 24 | } // namespace dave 25 | } // namespace discord 26 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/persisted_key_pair.cpp: -------------------------------------------------------------------------------- 1 | #include "dave/mls/detail/persisted_key_pair.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include "dave/logger.h" 15 | #include "dave/mls/parameters.h" 16 | 17 | static const 
std::string SelfSignatureLabel = "DiscordSelfSignature"; 18 | 19 | static std::string MakeKeyID(const std::string& sessionID, ::mlspp::CipherSuite suite) 20 | { 21 | return sessionID + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + 22 | std::to_string(discord::dave::mls::KeyVersion); 23 | } 24 | 25 | static std::mutex mtx; 26 | static std::map> map; 27 | 28 | namespace discord { 29 | namespace dave { 30 | namespace mls { 31 | 32 | static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( 33 | [[maybe_unused]] KeyPairContextType ctx, 34 | const std::string& sessionID, 35 | ::mlspp::CipherSuite suite) 36 | { 37 | std::lock_guard lk(mtx); 38 | 39 | std::string id = MakeKeyID(sessionID, suite); 40 | 41 | if (auto it = map.find(id); it != map.end()) { 42 | return it->second; 43 | } 44 | 45 | std::shared_ptr<::mlspp::SignaturePrivateKey> ret; 46 | 47 | bool supported = false; 48 | ret = ::discord::dave::mls::detail::GetNativePersistedKeyPair(ctx, id, suite, supported); 49 | 50 | if (!ret && supported) { 51 | // Do not fall back on the generic route if we error here 52 | DISCORD_LOG(LS_ERROR) << "Encountered error in native key handling in GetPersistedKeyPair"; 53 | return nullptr; 54 | } 55 | 56 | if (!ret) { 57 | ret = ::discord::dave::mls::detail::GetGenericPersistedKeyPair(ctx, id, suite); 58 | } 59 | 60 | if (!ret) { 61 | DISCORD_LOG(LS_ERROR) << "Failed to get key in GetPersistedKeyPair"; 62 | return nullptr; 63 | } 64 | 65 | map.emplace(id, ret); 66 | 67 | return ret; 68 | } 69 | 70 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx, 71 | const std::string& sessionID, 72 | ProtocolVersion version) 73 | { 74 | return GetPersistedKeyPair(ctx, sessionID, CiphersuiteForProtocolVersion(version)); 75 | } 76 | 77 | KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx, 78 | const std::string& sessionID, 79 | SignatureVersion version) 80 | { 81 | auto suite = 
CiphersuiteForSignatureVersion(version); 82 | 83 | auto pair = GetPersistedKeyPair(ctx, sessionID, suite); 84 | 85 | if (!pair) { 86 | return {}; 87 | } 88 | 89 | bytes sign_data = from_ascii(sessionID + ":") + pair->public_key.data; 90 | 91 | return { 92 | pair->public_key.data.as_vec(), 93 | std::move(pair->sign(suite, SelfSignatureLabel, sign_data).as_vec()), 94 | }; 95 | } 96 | 97 | bool DeletePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, 98 | const std::string& sessionID, 99 | SignatureVersion version) 100 | { 101 | std::string id = MakeKeyID(sessionID, CiphersuiteForSignatureVersion(version)); 102 | 103 | std::lock_guard lk(mtx); 104 | 105 | map.erase(id); 106 | 107 | bool native = ::discord::dave::mls::detail::DeleteNativePersistedKeyPair(ctx, id); 108 | bool generic = ::discord::dave::mls::detail::DeleteGenericPersistedKeyPair(ctx, id); 109 | 110 | return native || generic; 111 | } 112 | 113 | } // namespace mls 114 | } // namespace dave 115 | } // namespace discord 116 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/persisted_key_pair.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #ifdef __ANDROID__ 8 | #include 9 | #endif 10 | 11 | #include "dave/version.h" 12 | 13 | namespace mlspp { 14 | struct SignaturePrivateKey; 15 | }; 16 | 17 | namespace discord { 18 | namespace dave { 19 | namespace mls { 20 | 21 | #if defined(__ANDROID__) 22 | typedef JNIEnv* KeyPairContextType; 23 | #else 24 | typedef const char* KeyPairContextType; 25 | #endif 26 | 27 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx, 28 | const std::string& sessionID, 29 | ProtocolVersion version); 30 | 31 | struct KeyAndSelfSignature { 32 | std::vector key; 33 | std::vector signature; 34 | }; 35 | 36 | KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx, 37 | const 
std::string& sessionID, 38 | SignatureVersion version); 39 | 40 | bool DeletePersistedKeyPair(KeyPairContextType ctx, 41 | const std::string& sessionID, 42 | SignatureVersion version); 43 | 44 | constexpr unsigned KeyVersion = 1; 45 | 46 | } // namespace mls 47 | } // namespace dave 48 | } // namespace discord 49 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/persisted_key_pair_null.cpp: -------------------------------------------------------------------------------- 1 | #include "persisted_key_pair.h" 2 | 3 | namespace discord { 4 | namespace dave { 5 | namespace mls { 6 | 7 | std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( 8 | [[maybe_unused]] KeyPairContextType, 9 | const std::string&, 10 | ProtocolVersion) 11 | { 12 | return nullptr; 13 | } 14 | 15 | bool DeletePersistedKeyPair([[maybe_unused]] KeyPairContextType, 16 | [[maybe_unused]] const std::string&, 17 | [[maybe_unused]] SignatureVersion version) 18 | { 19 | return false; 20 | } 21 | 22 | } // namespace mls 23 | } // namespace dave 24 | } // namespace discord 25 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/session.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "dave/key_ratchet.h" 14 | #include "dave/mls/persisted_key_pair.h" 15 | #include "dave/version.h" 16 | 17 | namespace mlspp { 18 | struct AuthenticatedContent; 19 | struct Credential; 20 | struct ExternalSender; 21 | struct HPKEPrivateKey; 22 | struct KeyPackage; 23 | struct LeafNode; 24 | struct MLSMessage; 25 | struct SignaturePrivateKey; 26 | class State; 27 | } // namespace mlspp 28 | 29 | namespace discord { 30 | namespace dave { 31 | namespace mls { 32 | 33 | struct QueuedProposal; 34 | 35 | class Session { 36 | public: 37 | 
using MLSFailureCallback = std::function; 38 | 39 | Session(KeyPairContextType context, 40 | std::string authSessionId, 41 | MLSFailureCallback callback) noexcept; 42 | 43 | ~Session() noexcept; 44 | 45 | void Init(ProtocolVersion version, 46 | uint64_t groupId, 47 | std::string const& selfUserId, 48 | std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; 49 | void Reset() noexcept; 50 | 51 | void SetProtocolVersion(ProtocolVersion version) noexcept; 52 | ProtocolVersion GetProtocolVersion() const noexcept { return protocolVersion_; } 53 | 54 | std::vector GetLastEpochAuthenticator() const noexcept; 55 | 56 | void SetExternalSender(std::vector const& externalSenderPackage) noexcept; 57 | 58 | std::optional> ProcessProposals( 59 | std::vector proposals, 60 | std::set const& recognizedUserIDs) noexcept; 61 | 62 | RosterVariant ProcessCommit(std::vector commit) noexcept; 63 | 64 | std::optional ProcessWelcome( 65 | std::vector welcome, 66 | std::set const& recognizedUserIDs) noexcept; 67 | 68 | std::vector GetMarshalledKeyPackage() noexcept; 69 | 70 | std::unique_ptr GetKeyRatchet(std::string const& userId) const noexcept; 71 | 72 | using PairwiseFingerprintCallback = std::function const&)>; 73 | 74 | void GetPairwiseFingerprint(uint16_t version, 75 | std::string const& userId, 76 | PairwiseFingerprintCallback callback) const noexcept; 77 | 78 | private: 79 | void InitLeafNode(std::string const& selfUserId, 80 | std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; 81 | void ResetJoinKeyPackage() noexcept; 82 | 83 | void CreatePendingGroup() noexcept; 84 | 85 | bool HasCryptographicStateForWelcome() const noexcept; 86 | 87 | bool IsRecognizedUserID(const ::mlspp::Credential& cred, 88 | std::set const& recognizedUserIDs) const; 89 | bool ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, 90 | ::mlspp::State const& targetState, 91 | std::set const& recognizedUserIDs) const; 92 | bool VerifyWelcomeState(::mlspp::State 
const& state, 93 | std::set const& recognizedUserIDs) const; 94 | 95 | bool CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept; 96 | 97 | RosterMap ReplaceState(std::unique_ptr<::mlspp::State>&& state); 98 | 99 | void ClearPendingState(); 100 | 101 | inline static const std::string USER_MEDIA_KEY_BASE_LABEL = "Discord Secure Frames v0"; 102 | 103 | ProtocolVersion protocolVersion_; 104 | std::vector groupId_; 105 | std::string signingKeyId_; 106 | std::string selfUserId_; 107 | KeyPairContextType keyPairContext_{nullptr}; 108 | 109 | std::unique_ptr<::mlspp::LeafNode> selfLeafNode_; 110 | std::shared_ptr<::mlspp::SignaturePrivateKey> selfSigPrivateKey_; 111 | std::unique_ptr<::mlspp::HPKEPrivateKey> selfHPKEPrivateKey_; 112 | 113 | std::unique_ptr<::mlspp::HPKEPrivateKey> joinInitPrivateKey_; 114 | std::unique_ptr<::mlspp::KeyPackage> joinKeyPackage_; 115 | 116 | std::unique_ptr<::mlspp::ExternalSender> externalSender_; 117 | 118 | std::unique_ptr<::mlspp::State> pendingGroupState_; 119 | std::unique_ptr<::mlspp::MLSMessage> pendingGroupCommit_; 120 | 121 | std::unique_ptr<::mlspp::State> outboundCachedGroupState_; 122 | 123 | std::unique_ptr<::mlspp::State> currentState_; 124 | RosterMap roster_; 125 | 126 | std::unique_ptr<::mlspp::State> stateWithProposals_; 127 | std::list proposalQueue_; 128 | 129 | MLSFailureCallback onMLSFailureCallback_{}; 130 | }; 131 | 132 | } // namespace mls 133 | } // namespace dave 134 | } // namespace discord 135 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/user_credential.cpp: -------------------------------------------------------------------------------- 1 | #include "user_credential.h" 2 | 3 | #include 4 | 5 | #include "dave/mls/util.h" 6 | 7 | namespace discord { 8 | namespace dave { 9 | namespace mls { 10 | 11 | ::mlspp::Credential CreateUserCredential(const std::string& userId, 12 | [[maybe_unused]] ProtocolVersion version) 13 | { 14 | // convert the string user 
ID to a big endian uint64_t 15 | auto userID = std::stoull(userId); 16 | auto credentialBytes = BigEndianBytesFrom(userID); 17 | 18 | return ::mlspp::Credential::basic(credentialBytes); 19 | } 20 | 21 | std::string UserCredentialToString(const ::mlspp::Credential& cred, 22 | [[maybe_unused]] ProtocolVersion version) 23 | { 24 | if (cred.type() != ::mlspp::CredentialType::basic) { 25 | return ""; 26 | } 27 | 28 | const auto& basic = cred.template get<::mlspp::BasicCredential>(); 29 | 30 | auto uidVal = FromBigEndianBytes(basic.identity); 31 | 32 | return std::to_string(uidVal); 33 | } 34 | 35 | } // namespace mls 36 | } // namespace dave 37 | } // namespace discord 38 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/user_credential.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "dave/version.h" 8 | 9 | namespace discord { 10 | namespace dave { 11 | namespace mls { 12 | 13 | ::mlspp::Credential CreateUserCredential(const std::string& userId, ProtocolVersion version); 14 | std::string UserCredentialToString(const ::mlspp::Credential& cred, ProtocolVersion version); 15 | 16 | } // namespace mls 17 | } // namespace dave 18 | } // namespace discord 19 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/util.cpp: -------------------------------------------------------------------------------- 1 | #include "util.h" 2 | 3 | namespace discord { 4 | namespace dave { 5 | namespace mls { 6 | 7 | ::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept 8 | { 9 | auto buffer = ::mlspp::bytes_ns::bytes(); 10 | buffer.reserve(sizeof(value)); 11 | 12 | for (int i = sizeof(value) - 1; i >= 0; --i) { 13 | buffer.push_back(static_cast(value >> (i * 8))); 14 | } 15 | 16 | return buffer; 17 | } 18 | 19 | uint64_t FromBigEndianBytes(const ::mlspp::bytes_ns::bytes& buffer) 
noexcept 20 | { 21 | uint64_t val = 0; 22 | 23 | if (buffer.size() <= sizeof(val)) { 24 | for (uint8_t byte : buffer) { 25 | val = (val << 8) | byte; 26 | } 27 | } 28 | 29 | return val; 30 | } 31 | 32 | } // namespace mls 33 | } // namespace dave 34 | } // namespace discord 35 | -------------------------------------------------------------------------------- /cpp/src/dave/mls/util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace discord { 8 | namespace dave { 9 | namespace mls { 10 | 11 | ::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept; 12 | uint64_t FromBigEndianBytes(const ::mlspp::bytes_ns::bytes& value) noexcept; 13 | 14 | } // namespace mls 15 | } // namespace dave 16 | } // namespace discord 17 | -------------------------------------------------------------------------------- /cpp/src/dave/mls_key_ratchet.cpp: -------------------------------------------------------------------------------- 1 | #include "mls_key_ratchet.h" 2 | 3 | #include 4 | 5 | #include "dave/logger.h" 6 | 7 | namespace discord { 8 | namespace dave { 9 | 10 | MlsKeyRatchet::MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept 11 | : hashRatchet_(suite, std::move(baseSecret)) 12 | { 13 | } 14 | 15 | MlsKeyRatchet::~MlsKeyRatchet() noexcept = default; 16 | 17 | EncryptionKey MlsKeyRatchet::GetKey(KeyGeneration generation) noexcept 18 | { 19 | DISCORD_LOG(LS_INFO) << "Retrieving key for generation " << generation << " from HashRatchet"; 20 | 21 | try { 22 | auto keyAndNonce = hashRatchet_.get(generation); 23 | assert(keyAndNonce.key.size() >= kAesGcm128KeyBytes); 24 | return std::move(keyAndNonce.key.as_vec()); 25 | } 26 | catch (const std::exception& e) { 27 | DISCORD_LOG(LS_ERROR) << "Failed to retrieve key for generation " << generation << ": " 28 | << e.what(); 29 | return {}; 30 | } 31 | } 32 | 33 | void MlsKeyRatchet::DeleteKey(KeyGeneration 
generation) noexcept 34 | { 35 | hashRatchet_.erase(generation); 36 | } 37 | 38 | } // namespace dave 39 | } // namespace discord 40 | -------------------------------------------------------------------------------- /cpp/src/dave/mls_key_ratchet.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "key_ratchet.h" 6 | 7 | namespace discord { 8 | namespace dave { 9 | 10 | class MlsKeyRatchet : public IKeyRatchet { 11 | public: 12 | MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept; 13 | ~MlsKeyRatchet() noexcept override; 14 | 15 | EncryptionKey GetKey(KeyGeneration generation) noexcept override; 16 | void DeleteKey(KeyGeneration generation) noexcept override; 17 | 18 | private: 19 | ::mlspp::HashRatchet hashRatchet_; 20 | }; 21 | 22 | } // namespace dave 23 | } // namespace discord 24 | -------------------------------------------------------------------------------- /cpp/src/dave/utils/array_view.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace discord { 7 | namespace dave { 8 | 9 | template 10 | class ArrayView { 11 | public: 12 | ArrayView() = default; 13 | ArrayView(T* data, size_t size) 14 | : data_(data) 15 | , size_(size) 16 | { 17 | } 18 | 19 | size_t size() const { return size_; } 20 | T* data() const { return data_; } 21 | 22 | T* begin() const { return data_; } 23 | T* end() const { return data_ + size_; } 24 | 25 | private: 26 | T* data_ = nullptr; 27 | size_t size_ = 0; 28 | }; 29 | 30 | template 31 | inline ArrayView MakeArrayView(T* data, size_t size) 32 | { 33 | return ArrayView(data, size); 34 | } 35 | 36 | template 37 | inline ArrayView MakeArrayView(std::vector& data) 38 | { 39 | return ArrayView(data.data(), data.size()); 40 | } 41 | 42 | } // namespace dave 43 | } // namespace discord 44 | 
-------------------------------------------------------------------------------- /cpp/src/dave/utils/clock.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace discord { 6 | namespace dave { 7 | 8 | class IClock { 9 | public: 10 | using BaseClock = std::chrono::steady_clock; 11 | using TimePoint = BaseClock::time_point; 12 | using Duration = BaseClock::duration; 13 | 14 | virtual ~IClock() = default; 15 | virtual TimePoint Now() const = 0; 16 | }; 17 | 18 | class Clock : public IClock { 19 | public: 20 | TimePoint Now() const override { return BaseClock::now(); } 21 | }; 22 | 23 | } // namespace dave 24 | } // namespace discord 25 | -------------------------------------------------------------------------------- /cpp/src/dave/utils/leb128.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "leb128.h" 3 | 4 | // The following code was copied from the webrtc source code: 5 | // https://webrtc.googlesource.com/src/+/refs/heads/main/modules/rtp_rtcp/source/leb128.cc 6 | 7 | namespace discord { 8 | namespace dave { 9 | 10 | size_t Leb128Size(uint64_t value) 11 | { 12 | int size = 0; 13 | while (value >= 0x80) { 14 | ++size; 15 | value >>= 7; 16 | } 17 | return size + 1; 18 | } 19 | 20 | uint64_t ReadLeb128(const uint8_t*& readAt, const uint8_t* end) 21 | { 22 | uint64_t value = 0; 23 | int fillBits = 0; 24 | while (readAt != end && fillBits < 64 - 7) { 25 | uint8_t leb128Byte = *readAt; 26 | value |= uint64_t{leb128Byte & 0x7Fu} << fillBits; 27 | ++readAt; 28 | fillBits += 7; 29 | if ((leb128Byte & 0x80) == 0) { 30 | return value; 31 | } 32 | } 33 | // Read 9 bytes and didn't find the terminator byte. Check if 10th byte 34 | // is that terminator, however to fit result into uint64_t it may carry only 35 | // single bit. 
36 | if (readAt != end && *readAt <= 1) { 37 | value |= uint64_t{*readAt} << fillBits; 38 | ++readAt; 39 | return value; 40 | } 41 | // Failed to find terminator leb128 byte. 42 | readAt = nullptr; 43 | return 0; 44 | } 45 | 46 | size_t WriteLeb128(uint64_t value, uint8_t* buffer) 47 | { 48 | int size = 0; 49 | while (value >= 0x80) { 50 | buffer[size] = 0x80 | (value & 0x7F); 51 | ++size; 52 | value >>= 7; 53 | } 54 | buffer[size] = static_cast(value); 55 | ++size; 56 | return size; 57 | } 58 | 59 | } // namespace dave 60 | } // namespace discord 61 | -------------------------------------------------------------------------------- /cpp/src/dave/utils/leb128.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace discord { 7 | namespace dave { 8 | 9 | constexpr size_t Leb128MaxSize = 10; 10 | 11 | // Returns number of bytes needed to store `value` in leb128 format. 12 | size_t Leb128Size(uint64_t value); 13 | 14 | // Reads leb128 encoded value and advance read_at by number of bytes consumed. 15 | // Sets read_at to nullptr on error. 16 | uint64_t ReadLeb128(const uint8_t*& readAt, const uint8_t* end); 17 | 18 | // Encodes `value` in leb128 format. Assumes buffer has size of at least 19 | // Leb128Size(value). Returns number of bytes consumed. 
20 | size_t WriteLeb128(uint64_t value, uint8_t* buffer); 21 | 22 | } // namespace dave 23 | } // namespace discord 24 | -------------------------------------------------------------------------------- /cpp/src/dave/utils/scope_exit.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace discord { 8 | namespace dave { 9 | 10 | class [[nodiscard]] ScopeExit final { 11 | public: 12 | template 13 | explicit ScopeExit(Cleanup&& cleanup) 14 | : cleanup_{std::forward(cleanup)} 15 | { 16 | } 17 | 18 | ScopeExit(ScopeExit&& rhs) 19 | : cleanup_{std::move(rhs.cleanup_)} 20 | { 21 | rhs.cleanup_ = nullptr; 22 | } 23 | 24 | ~ScopeExit() 25 | { 26 | if (cleanup_) { 27 | cleanup_(); 28 | } 29 | } 30 | 31 | ScopeExit& operator=(ScopeExit&& rhs) 32 | { 33 | cleanup_ = std::move(rhs.cleanup_); 34 | rhs.cleanup_ = nullptr; 35 | return *this; 36 | } 37 | 38 | void Dismiss() { cleanup_ = nullptr; } 39 | 40 | private: 41 | ScopeExit(ScopeExit const&) = delete; 42 | ScopeExit& operator=(ScopeExit const&) = delete; 43 | 44 | std::function cleanup_; 45 | }; 46 | 47 | } // namespace dave 48 | } // namespace discord 49 | -------------------------------------------------------------------------------- /cpp/src/dave/version.cpp: -------------------------------------------------------------------------------- 1 | #include "version.h" 2 | 3 | namespace discord { 4 | namespace dave { 5 | 6 | constexpr ProtocolVersion CurrentDaveProtocolVersion = 1; 7 | 8 | ProtocolVersion MaxSupportedProtocolVersion() 9 | { 10 | return CurrentDaveProtocolVersion; 11 | } 12 | 13 | } // namespace dave 14 | } // namespace discord 15 | -------------------------------------------------------------------------------- /cpp/src/dave/version.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace discord { 6 | namespace dave { 7 | 8 | 
using ProtocolVersion = uint16_t; 9 | using SignatureVersion = uint8_t; 10 | 11 | ProtocolVersion MaxSupportedProtocolVersion(); 12 | 13 | } // namespace dave 14 | } // namespace discord 15 | -------------------------------------------------------------------------------- /cpp/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | enable_testing() 2 | 3 | find_package(GTest CONFIG REQUIRED) 4 | 5 | SET(TEST_APP_NAME "libdave_test") 6 | 7 | file(GLOB_RECURSE TEST_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 8 | file(GLOB_RECURSE TEST_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") 9 | 10 | add_executable(${TEST_APP_NAME} ${TEST_HEADERS} ${TEST_SOURCES}) 11 | add_dependencies(${TEST_APP_NAME} ${LIB_NAME}) 12 | target_include_directories(${TEST_APP_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/src) 13 | 14 | target_link_libraries(libdave_test PRIVATE ${LIB_NAME} GTest::gtest_main GTest::gmock MLSPP::bytes) -------------------------------------------------------------------------------- /cpp/test/boringssl_cryptor_tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | 3 | #include 4 | 5 | #include "dave/boringssl_cryptor.h" 6 | 7 | #include "dave_test.h" 8 | #include "static_key_ratchet.h" 9 | 10 | namespace discord { 11 | namespace dave { 12 | namespace test { 13 | 14 | TEST_F(DaveTests, BoringSSLEncryptDecrypt) 15 | { 16 | constexpr size_t PLAINTEXT_SIZE = 1024; 17 | auto plaintextBufferIn = std::vector(PLAINTEXT_SIZE, 0); 18 | auto additionalDataBuffer = std::vector(PLAINTEXT_SIZE, 0); 19 | auto plaintextBufferOut = std::vector(PLAINTEXT_SIZE, 0); 20 | auto ciphertextBuffer = std::vector(PLAINTEXT_SIZE, 0); 21 | auto nonceBuffer = std::vector(kAesGcm128NonceBytes, 0); 22 | auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0); 23 | 24 | auto plaintextIn = 25 | MakeArrayView(plaintextBufferIn.data(), 
plaintextBufferIn.size()); 26 | auto additionalData = 27 | MakeArrayView(additionalDataBuffer.data(), additionalDataBuffer.size()); 28 | auto plaintextOut = 29 | MakeArrayView(plaintextBufferOut.data(), plaintextBufferOut.size()); 30 | auto ciphertextOut = MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size()); 31 | auto ciphertextIn = 32 | MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size()); 33 | auto nonce = MakeArrayView(nonceBuffer.data(), nonceBuffer.size()); 34 | auto tagOut = MakeArrayView(tagBuffer.data(), tagBuffer.size()); 35 | auto tagIn = MakeArrayView(tagBuffer.data(), tagBuffer.size()); 36 | 37 | BoringSSLCryptor cryptor(MakeStaticSenderKey("12345678901234567890")); 38 | 39 | EXPECT_TRUE(cryptor.Encrypt(ciphertextOut, plaintextIn, nonce, additionalData, tagOut)); 40 | 41 | // The ciphertext should not be the same as the plaintext 42 | EXPECT_FALSE(memcmp(plaintextBufferIn.data(), ciphertextBuffer.data(), PLAINTEXT_SIZE) == 0); 43 | 44 | EXPECT_TRUE(cryptor.Decrypt(plaintextOut, ciphertextIn, tagIn, nonce, additionalData)); 45 | 46 | // The plaintext should be the same as the original plaintext 47 | EXPECT_TRUE(memcmp(plaintextBufferIn.data(), plaintextBufferOut.data(), PLAINTEXT_SIZE) == 0); 48 | } 49 | 50 | TEST_F(DaveTests, BoringSSLAdditionalDataAuth) 51 | { 52 | constexpr size_t PLAINTEXT_SIZE = 1024; 53 | auto plaintextBufferIn = std::vector(PLAINTEXT_SIZE, 0); 54 | auto additionalDataBuffer = std::vector(PLAINTEXT_SIZE, 0); 55 | auto plaintextBufferOut = std::vector(PLAINTEXT_SIZE, 0); 56 | auto ciphertextBuffer = std::vector(PLAINTEXT_SIZE, 0); 57 | auto nonceBuffer = std::vector(kAesGcm128NonceBytes, 0); 58 | auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0); 59 | 60 | auto plaintextIn = 61 | MakeArrayView(plaintextBufferIn.data(), plaintextBufferIn.size()); 62 | auto additionalData = 63 | MakeArrayView(additionalDataBuffer.data(), additionalDataBuffer.size()); 64 | auto plaintextOut = 65 | 
MakeArrayView(plaintextBufferOut.data(), plaintextBufferOut.size()); 66 | auto ciphertextOut = MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size()); 67 | auto ciphertextIn = 68 | MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size()); 69 | auto nonce = MakeArrayView(nonceBuffer.data(), nonceBuffer.size()); 70 | auto tagOut = MakeArrayView(tagBuffer.data(), tagBuffer.size()); 71 | auto tagIn = MakeArrayView(tagBuffer.data(), tagBuffer.size()); 72 | 73 | BoringSSLCryptor cryptor(MakeStaticSenderKey("12345678901234567890")); 74 | 75 | EXPECT_TRUE(cryptor.Encrypt(ciphertextOut, plaintextIn, nonce, additionalData, tagOut)); 76 | 77 | // We modify the additional data before decryption 78 | additionalDataBuffer[0] = 1; 79 | 80 | EXPECT_FALSE(cryptor.Decrypt(plaintextOut, ciphertextIn, tagIn, nonce, additionalData)); 81 | } 82 | 83 | TEST_F(DaveTests, BoringSSLKeyDiff) 84 | { 85 | constexpr size_t PLAINTEXT_SIZE = 1024; 86 | auto plaintextBuffer1 = std::vector(PLAINTEXT_SIZE, 0); 87 | auto additionalDataBuffer1 = std::vector(PLAINTEXT_SIZE, 0); 88 | auto plaintextBuffer2 = std::vector(PLAINTEXT_SIZE, 0); 89 | auto additionalDataBuffer2 = std::vector(PLAINTEXT_SIZE, 0); 90 | auto ciphertextBuffer1 = std::vector(PLAINTEXT_SIZE, 0); 91 | auto ciphertextBuffer2 = std::vector(PLAINTEXT_SIZE, 0); 92 | auto nonceBuffer = std::vector(kAesGcm128NonceBytes, 0); 93 | auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0); 94 | 95 | auto plaintext1 = 96 | MakeArrayView(plaintextBuffer1.data(), plaintextBuffer1.size()); 97 | auto additionalData1 = 98 | MakeArrayView(additionalDataBuffer1.data(), additionalDataBuffer1.size()); 99 | auto plaintext2 = 100 | MakeArrayView(plaintextBuffer2.data(), plaintextBuffer2.size()); 101 | auto additionalData2 = 102 | MakeArrayView(additionalDataBuffer2.data(), additionalDataBuffer2.size()); 103 | auto ciphertext1 = MakeArrayView(ciphertextBuffer1.data(), ciphertextBuffer1.size()); 104 | auto ciphertext2 = 
MakeArrayView(ciphertextBuffer2.data(), ciphertextBuffer2.size()); 105 | auto nonce = MakeArrayView(nonceBuffer.data(), nonceBuffer.size()); 106 | auto tag = MakeArrayView(tagBuffer.data(), tagBuffer.size()); 107 | 108 | BoringSSLCryptor cryptor1(MakeStaticSenderKey("12345678901234567890")); 109 | BoringSSLCryptor cryptor2(MakeStaticSenderKey("09876543210987654321")); 110 | 111 | EXPECT_TRUE(cryptor1.Encrypt(ciphertext1, plaintext1, nonce, additionalData1, tag)); 112 | EXPECT_TRUE(cryptor2.Encrypt(ciphertext2, plaintext2, nonce, additionalData2, tag)); 113 | 114 | EXPECT_FALSE(memcmp(ciphertextBuffer1.data(), ciphertextBuffer2.data(), PLAINTEXT_SIZE) == 0); 115 | } 116 | 117 | TEST_F(DaveTests, BoringSSLNonceDiff) 118 | { 119 | constexpr size_t PLAINTEXT_SIZE = 1024; 120 | auto plaintextBuffer1 = std::vector(PLAINTEXT_SIZE, 0); 121 | auto additionalDataBuffer1 = std::vector(PLAINTEXT_SIZE, 0); 122 | auto plaintextBuffer2 = std::vector(PLAINTEXT_SIZE, 0); 123 | auto additionalDataBuffer2 = std::vector(PLAINTEXT_SIZE, 0); 124 | auto ciphertextBuffer1 = std::vector(PLAINTEXT_SIZE, 0); 125 | auto ciphertextBuffer2 = std::vector(PLAINTEXT_SIZE, 0); 126 | auto nonceBuffer1 = std::vector(kAesGcm128NonceBytes, 0); 127 | auto nonceBuffer2 = std::vector(kAesGcm128NonceBytes, 1); 128 | auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0); 129 | 130 | auto plaintext1 = 131 | MakeArrayView(plaintextBuffer1.data(), plaintextBuffer1.size()); 132 | auto additionalData1 = 133 | MakeArrayView(additionalDataBuffer1.data(), additionalDataBuffer1.size()); 134 | auto plaintext2 = 135 | MakeArrayView(plaintextBuffer2.data(), plaintextBuffer2.size()); 136 | auto additionalData2 = 137 | MakeArrayView(additionalDataBuffer2.data(), additionalDataBuffer2.size()); 138 | auto ciphertext1 = MakeArrayView(ciphertextBuffer1.data(), ciphertextBuffer1.size()); 139 | auto ciphertext2 = MakeArrayView(ciphertextBuffer2.data(), ciphertextBuffer2.size()); 140 | auto nonce1 = 
MakeArrayView(nonceBuffer1.data(), nonceBuffer1.size()); 141 | auto nonce2 = MakeArrayView(nonceBuffer2.data(), nonceBuffer2.size()); 142 | auto tag = MakeArrayView(tagBuffer.data(), tagBuffer.size()); 143 | 144 | BoringSSLCryptor cryptor(MakeStaticSenderKey("12345678901234567890")); 145 | 146 | EXPECT_TRUE(cryptor.Encrypt(ciphertext1, plaintext1, nonce1, additionalData1, tag)); 147 | EXPECT_TRUE(cryptor.Encrypt(ciphertext2, plaintext2, nonce2, additionalData2, tag)); 148 | 149 | EXPECT_FALSE(memcmp(ciphertextBuffer1.data(), ciphertextBuffer2.data(), PLAINTEXT_SIZE) == 0); 150 | } 151 | 152 | } // namespace test 153 | } // namespace dave 154 | } // namespace discord 155 | -------------------------------------------------------------------------------- /cpp/test/codec_utils_tests.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "gtest/gtest.h" 3 | 4 | #include "dave/codec_utils.h" 5 | #include "dave/decryptor.h" 6 | #include "dave/encryptor.h" 7 | #include "dave/frame_processors.h" 8 | #include "dave/utils/array_view.h" 9 | 10 | #include "dave_test.h" 11 | #include "static_key_ratchet.h" 12 | 13 | namespace discord { 14 | namespace dave { 15 | namespace test { 16 | 17 | TEST_F(DaveTests, RandomOpusFrame) 18 | { 19 | constexpr std::string_view randomBytes = 20 | "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4" 21 | "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9"; 22 | 23 | // load the hex encoded sample frame to a buffer 24 | auto incomingFrame = GetBufferFromHex(randomBytes); 25 | 26 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 27 | 28 | OutboundFrameProcessor frameProcessor; 29 | 30 | frameProcessor.ProcessFrame( 31 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::Opus); 32 | auto& unencryptedBytes = frameProcessor.GetUnencryptedBytes(); 33 | auto& encryptedBytes = frameProcessor.GetEncryptedBytes(); 
34 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 35 | 36 | EXPECT_EQ(incomingFrame.size(), 76u); 37 | EXPECT_EQ(unencryptedBytes.size(), 0u); 38 | EXPECT_EQ(encryptedBytes.size(), incomingFrame.size()); 39 | EXPECT_EQ(unencryptedRanges.size(), 0u); 40 | } 41 | 42 | TEST_F(DaveTests, SplitReconstruct) 43 | { 44 | std::string randomBytes = 45 | "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4" 46 | "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9" 47 | "0000000000000000 00 000a 140a 280a 3c0a 14 fafa"; 48 | randomBytes.erase(std::remove(randomBytes.begin(), randomBytes.end(), ' '), randomBytes.end()); 49 | 50 | // load the hex encoded sample frame to a buffer 51 | auto incomingFrame = GetBufferFromHex(randomBytes); 52 | 53 | auto reconstructedFrame = std::make_unique(incomingFrame.size()); 54 | 55 | InboundFrameProcessor frameProcessor; 56 | 57 | frameProcessor.ParseFrame( 58 | MakeArrayView(incomingFrame.data(), incomingFrame.size())); 59 | memcpy(frameProcessor.GetPlaintext().data(), 60 | frameProcessor.GetCiphertext().data(), 61 | frameProcessor.GetCiphertext().size()); 62 | auto bytesWritten = frameProcessor.ReconstructFrame( 63 | MakeArrayView(reconstructedFrame.get(), incomingFrame.size())); 64 | 65 | EXPECT_EQ(bytesWritten, 76u); 66 | EXPECT_EQ(memcmp(incomingFrame.data(), reconstructedFrame.get(), bytesWritten), 0); 67 | } 68 | 69 | TEST_F(DaveTests, H264SliceOneByteExpGolomb) 70 | { 71 | // start code, nal unit header 72 | // 3 exponential golomb values (first_mb_in_slice, slice_type, pic_parameter_set_id) 73 | // then slice payloads 74 | constexpr std::string_view kH264SliceHex = "0000000161e0fafafa"; 75 | 76 | // load the hex encoded sample frame to a buffer 77 | auto incomingFrame = GetBufferFromHex(kH264SliceHex); 78 | 79 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 80 | 81 | OutboundFrameProcessor frameProcessor; 82 | frameProcessor.ProcessFrame( 83 | 
MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 84 | 85 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 86 | 87 | EXPECT_EQ(unencryptedRanges.size(), 1u); 88 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 89 | EXPECT_EQ(unencryptedRanges.front().size, 6u); 90 | } 91 | 92 | TEST_F(DaveTests, H264ShortIDROneByteExpGolomb) 93 | { 94 | // SPS NAL UNIT, PPS NAL UNIT, then IDR NAL Unit 95 | // for IDR: nal unit header, then 3 exponential golomb values (first_mb_in_slice, slice_type, 96 | // pic_parameter_set_id) then IDR payloads 97 | constexpr std::string_view kH264ShortIDR = 98 | "000000016742c00d8c8d40d0fbc900f08846a00000000168ce3c800000000165b8fafafa"; 99 | 100 | // load the hex encoded sample frame to a buffer 101 | auto incomingFrame = GetBufferFromHex(kH264ShortIDR); 102 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 103 | 104 | OutboundFrameProcessor frameProcessor; 105 | frameProcessor.ProcessFrame( 106 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 107 | 108 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 109 | 110 | EXPECT_EQ(unencryptedRanges.size(), 1u); 111 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 112 | EXPECT_EQ(unencryptedRanges.front().size, 33u); 113 | } 114 | 115 | TEST_F(DaveTests, H264ShortIDRTwoByteExpGolomb) 116 | { 117 | // SPS NAL UNIT, PPS NAL UNIT, then IDR NAL Unit 118 | // for IDR: nal unit header, then 3 exponential golomb values (first_mb_in_slice, slice_type, 119 | // pic_parameter_set_id) then IDR payloads 120 | constexpr std::string_view kH264ShortIDR = 121 | "000000016742c00d8c8d40d0fbc900f08846a00000000168ce3c8000000001654760fafafa"; 122 | 123 | // load the hex encoded sample frame to a buffer 124 | auto incomingFrame = GetBufferFromHex(kH264ShortIDR); 125 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 126 | 127 | OutboundFrameProcessor frameProcessor; 128 | frameProcessor.ProcessFrame( 129 | 
MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 130 | 131 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 132 | 133 | EXPECT_EQ(unencryptedRanges.size(), 1u); 134 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 135 | EXPECT_EQ(unencryptedRanges.front().size, 34u); 136 | } 137 | 138 | TEST_F(DaveTests, H264LongIDROneByteExpGolomb) 139 | { 140 | // SPS NAL UNIT, PPS NAL UNIT, SEI NAL unit, then IDR NAL Unit 141 | // which has nal unit header, 142 | // then 3 exponential golomb values (first_mb_in_slice, slice_type, pic_parameter_set_id) 143 | // then IDR payloads 144 | constexpr std::string_view kH264LongIDR = 145 | "00000001274d0033ab402802dd00da08846a000000000128ee3c800000000106051a47564adc5c4c433f94efc511" 146 | "3cd143a801ffccccff020004ca90800000000125b8fafafa"; 147 | 148 | // load the hex encoded sample frame to a buffer 149 | auto incomingFrame = GetBufferFromHex(kH264LongIDR); 150 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 151 | 152 | OutboundFrameProcessor frameProcessor; 153 | frameProcessor.ProcessFrame( 154 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 155 | 156 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 157 | 158 | EXPECT_EQ(unencryptedRanges.size(), 1u); 159 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 160 | EXPECT_EQ(unencryptedRanges.front().size, 67u); 161 | } 162 | 163 | TEST_F(DaveTests, H264LongIDRTwoByteExpGolomb) 164 | { 165 | // SPS NAL UNIT, PPS NAL UNIT, SEI NAL unit, then IDR NAL Unit 166 | // which has nal unit header, then 3 exponential golomb values 167 | // (first_mb_in_slice, slice_type, pic_parameter_set_id) then IDR payloads 168 | constexpr std::string_view kH264LongIDR = 169 | "00000001274d0033ab402802dd00da08846a000000000128ee3c800000000106051a47564adc5c4c433f94efc5" 170 | "11" 171 | "3cd143a801ffccccff020004ca908000000001254760fafafa"; 172 | 173 | // load the hex encoded sample frame to a buffer 174 | auto 
incomingFrame = GetBufferFromHex(kH264LongIDR); 175 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 176 | 177 | OutboundFrameProcessor frameProcessor; 178 | frameProcessor.ProcessFrame( 179 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 180 | 181 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 182 | 183 | EXPECT_EQ(unencryptedRanges.size(), 1u); 184 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 185 | EXPECT_EQ(unencryptedRanges.front().size, 68u); 186 | } 187 | 188 | TEST_F(DaveTests, H264EmulationPreventionInEarlyExpGolomb) 189 | { 190 | constexpr std::string_view kH264SliceHex = "00000001610000038000e0fafafa"; 191 | 192 | // load the hex encoded sample frame to a buffer 193 | auto incomingFrame = GetBufferFromHex(kH264SliceHex); 194 | 195 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 196 | 197 | OutboundFrameProcessor frameProcessor; 198 | frameProcessor.ProcessFrame( 199 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 200 | 201 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 202 | 203 | EXPECT_EQ(unencryptedRanges.size(), 1u); 204 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 205 | EXPECT_EQ(unencryptedRanges.front().size, 11u); 206 | } 207 | 208 | TEST_F(DaveTests, H264ThreeByteShortCodeExtension) 209 | { 210 | constexpr std::string_view kH264MixedShortCodes = 211 | "000000012764001fac2b602802dd8088000003000800000301b46d0e1970" 212 | "00000128ee3cb0000001258880ababab"; 213 | 214 | // load the hex encoded sample frame to a buffer 215 | auto incomingFrame = GetBufferFromHex(kH264MixedShortCodes); 216 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 217 | 218 | OutboundFrameProcessor frameProcessor; 219 | frameProcessor.ProcessFrame( 220 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 221 | 222 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 223 | 224 | 
EXPECT_EQ(unencryptedRanges.size(), 1u); 225 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 226 | EXPECT_EQ(unencryptedRanges.front().size, 45u); 227 | 228 | auto bytesToEncrypt = frameProcessor.GetEncryptedBytes(); 229 | auto encryptedBytes = frameProcessor.GetCiphertextBytes(); 230 | EXPECT_EQ(bytesToEncrypt.size(), encryptedBytes.size()); 231 | memcpy(encryptedFrame.get(), bytesToEncrypt.data(), bytesToEncrypt.size()); 232 | 233 | frameProcessor.ReconstructFrame(MakeArrayView( 234 | encryptedFrame.get(), bytesToEncrypt.size() + frameProcessor.GetUnencryptedBytes().size())); 235 | 236 | constexpr std::string_view kExpectedUnencryptedHeaderHex = 237 | "000000012764001fac2b602802dd8088000003000800000301b46d0e19700000000128ee3cb000000001258880"; 238 | auto expectedUnencryptedHeader = GetBufferFromHex(kExpectedUnencryptedHeaderHex); 239 | 240 | auto compareResultExpected = memcmp( 241 | encryptedFrame.get(), expectedUnencryptedHeader.data(), expectedUnencryptedHeader.size()); 242 | 243 | EXPECT_EQ(compareResultExpected, 0); 244 | } 245 | 246 | TEST_F(DaveTests, H264TwoSliceTest) 247 | { 248 | // start code, nal unit header 249 | // 3 exponential golomb values (first_mb_in_slice, slice_type, pic_parameter_set_id) 250 | // then slice payload 251 | // and repeated again 252 | constexpr std::string_view kH264TwoSliceHex = "0000000161e0fafafa0000000161e0fafafa"; 253 | 254 | // load the hex encoded sample frame to a buffer 255 | auto incomingFrame = GetBufferFromHex(kH264TwoSliceHex); 256 | 257 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 258 | 259 | OutboundFrameProcessor frameProcessor; 260 | frameProcessor.ProcessFrame( 261 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H264); 262 | 263 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 264 | 265 | EXPECT_EQ(unencryptedRanges.size(), 2u); 266 | EXPECT_EQ(unencryptedRanges[0].offset, 0u); 267 | EXPECT_EQ(unencryptedRanges[0].size, 6u); 268 | 
EXPECT_EQ(unencryptedRanges[1].offset, 9u); 269 | EXPECT_EQ(unencryptedRanges[1].size, 6u); 270 | } 271 | 272 | TEST_F(DaveTests, H265IdrSlice) 273 | { 274 | constexpr std::string_view kH265IdrSliceHex = 275 | "0000000140010c01ffff016000000300b0000003000003005d17024" 276 | "000000001420101016000000300b0000003000003005da00280802d16205ee45914bff2e7f13fa2" 277 | "000000014401c072f05324000000014e01051a47564adc5c4c433f94efc5113cd143a803ee0000ee02001fc8b88" 278 | "0000000012801abab"; 279 | 280 | // load the hex encoded sample frame to a buffer 281 | auto incomingFrame = GetBufferFromHex(kH265IdrSliceHex); 282 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 283 | 284 | OutboundFrameProcessor frameProcessor; 285 | frameProcessor.ProcessFrame( 286 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H265); 287 | 288 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 289 | 290 | EXPECT_EQ(unencryptedRanges.size(), 1u); 291 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 292 | EXPECT_EQ(unencryptedRanges.front().size, 119u); 293 | } 294 | 295 | TEST_F(DaveTests, H265TsaSlice) 296 | { 297 | constexpr std::string_view kH265TsaSliceHex = "000000010201abab"; 298 | 299 | // load the hex encoded sample frame to a buffer 300 | auto incomingFrame = GetBufferFromHex(kH265TsaSliceHex); 301 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 302 | 303 | OutboundFrameProcessor frameProcessor; 304 | frameProcessor.ProcessFrame( 305 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H265); 306 | 307 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 308 | 309 | EXPECT_EQ(unencryptedRanges.size(), 1u); 310 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 311 | EXPECT_EQ(unencryptedRanges.front().size, 6u); 312 | } 313 | 314 | TEST_F(DaveTests, H265SimpleThreeByteCodeExtension) 315 | { 316 | constexpr std::string_view kH265TsaSliceHexShort = "0000010201abab"; 317 | 318 | // load the hex 
encoded sample frame to a buffer 319 | auto incomingFrame = GetBufferFromHex(kH265TsaSliceHexShort); 320 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 321 | 322 | OutboundFrameProcessor frameProcessor; 323 | frameProcessor.ProcessFrame( 324 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H265); 325 | 326 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 327 | 328 | EXPECT_EQ(unencryptedRanges.size(), 1u); 329 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 330 | EXPECT_EQ(unencryptedRanges.front().size, 6u); 331 | } 332 | 333 | TEST_F(DaveTests, H265MultipleThreeByteCodeExtensions) 334 | { 335 | constexpr std::string_view kH265IdrSliceHex = 336 | "00000140010c01ffff016000000300b0000003000003005d17024" 337 | "0000001420101016000000300b0000003000003005da00280802d16205ee45914bff2e7f13fa2" 338 | "000000014401c072f05324000000014e01051a47564adc5c4c433f94efc5113cd143a803ee0000ee02001fc8b88" 339 | "00000012801abab"; 340 | 341 | // load the hex encoded sample frame to a buffer 342 | auto incomingFrame = GetBufferFromHex(kH265IdrSliceHex); 343 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 344 | 345 | OutboundFrameProcessor frameProcessor; 346 | frameProcessor.ProcessFrame( 347 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H265); 348 | 349 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 350 | 351 | EXPECT_EQ(unencryptedRanges.size(), 1u); 352 | EXPECT_EQ(unencryptedRanges.front().offset, 0u); 353 | EXPECT_EQ(unencryptedRanges.front().size, 119u); 354 | } 355 | 356 | TEST_F(DaveTests, H265TwoIdrSlice) 357 | { 358 | constexpr std::string_view kH265TwoIdrSliceHex = "0000010201abab0000010201abab"; 359 | 360 | // load the hex encoded sample frame to a buffer 361 | auto incomingFrame = GetBufferFromHex(kH265TwoIdrSliceHex); 362 | auto encryptedFrame = std::make_unique(incomingFrame.size() * 2); 363 | 364 | OutboundFrameProcessor frameProcessor; 365 | 
frameProcessor.ProcessFrame( 366 | MakeArrayView(incomingFrame.data(), incomingFrame.size()), Codec::H265); 367 | 368 | auto unencryptedRanges = frameProcessor.GetUnencryptedRanges(); 369 | 370 | EXPECT_EQ(unencryptedRanges.size(), 2u); 371 | EXPECT_EQ(unencryptedRanges[0].offset, 0u); 372 | EXPECT_EQ(unencryptedRanges[0].size, 6u); 373 | EXPECT_EQ(unencryptedRanges[1].offset, 8u); 374 | EXPECT_EQ(unencryptedRanges[1].size, 6u); 375 | } 376 | 377 | } // namespace test 378 | } // namespace dave 379 | } // namespace discord 380 | -------------------------------------------------------------------------------- /cpp/test/cryptor_manager_tests.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include "dave/common.h" 8 | #include "dave/cryptor_manager.h" 9 | #include "dave/utils/clock.h" 10 | 11 | #include "dave_test.h" 12 | #include "static_key_ratchet.h" 13 | 14 | using namespace testing; 15 | using namespace std::chrono_literals; 16 | 17 | namespace discord { 18 | namespace dave { 19 | namespace test { 20 | 21 | // Gap can't be larger than the amount of bits allocated for it if we want to handle wraparound 22 | // correctly 23 | static_assert(kMaxGenerationGap < kGenerationWrap, "Gap can't be larger than wraparound value"); 24 | 25 | class MockKeyRatchet : public IKeyRatchet { 26 | public: 27 | MockKeyRatchet() 28 | { 29 | ON_CALL(*this, GetKey).WillByDefault([](KeyGeneration generation) { 30 | auto userId = std::string("12345678901234567890"); 31 | return MakeStaticSenderKey(userId + std::to_string(generation)); 32 | }); 33 | } 34 | MOCK_METHOD(EncryptionKey, GetKey, (KeyGeneration generation), (override, noexcept)); 35 | MOCK_METHOD(void, DeleteKey, (KeyGeneration generation), (override, noexcept)); 36 | }; 37 | 38 | class MockClock : public IClock { 39 | public: 40 | TimePoint Now() const override { return now_; } 41 | 42 | void SetNow(TimePoint now) { now_ = now; } 
43 | void Advance(Duration duration) { now_ += duration; } 44 | 45 | private: 46 | TimePoint now_{std::chrono::steady_clock::now()}; 47 | }; 48 | 49 | TEST_F(DaveTests, CryptorManagerCheckMaxGap) 50 | { 51 | auto mockKeyRatchet = std::make_unique(); 52 | EXPECT_CALL(*mockKeyRatchet, GetKey(0)); 53 | EXPECT_CALL(*mockKeyRatchet, GetKey(kMaxGenerationGap)); 54 | EXPECT_CALL(*mockKeyRatchet, GetKey(kMaxGenerationGap + 1)); 55 | 56 | MockClock clock; 57 | CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)}; 58 | // Give plenty of room to not trigger the max lifetime generations check 59 | clock.Advance(kMaxGenerationGap * 48h); 60 | 61 | auto cryptor = cryptorManager.GetCryptor(0); 62 | EXPECT_NE(cryptor, nullptr); 63 | EXPECT_EQ(cryptorManager.GetCryptor(0), cryptor); 64 | EXPECT_NE(cryptorManager.GetCryptor(kMaxGenerationGap), nullptr); 65 | EXPECT_EQ(cryptorManager.GetCryptor(kMaxGenerationGap + 1), nullptr); 66 | cryptorManager.ReportCryptorSuccess( 67 | kMaxGenerationGap, 68 | static_cast(kMaxGenerationGap << kRatchetGenerationShiftBits)); 69 | EXPECT_NE(cryptorManager.GetCryptor(kMaxGenerationGap + 1), nullptr); 70 | } 71 | 72 | TEST_F(DaveTests, CryptorManagerCheckExpiry) 73 | { 74 | auto mockKeyRatchet = std::make_unique(); 75 | EXPECT_CALL(*mockKeyRatchet, GetKey(0)); 76 | EXPECT_CALL(*mockKeyRatchet, GetKey(1)); 77 | EXPECT_CALL(*mockKeyRatchet, DeleteKey(0)); 78 | 79 | MockClock clock; 80 | CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)}; 81 | EXPECT_NE(cryptorManager.GetCryptor(0), nullptr); 82 | clock.Advance(1000000h); 83 | EXPECT_NE(cryptorManager.GetCryptor(0), nullptr); 84 | EXPECT_NE(cryptorManager.GetCryptor(1), nullptr); 85 | cryptorManager.ReportCryptorSuccess(1, 1 << kRatchetGenerationShiftBits); 86 | clock.Advance(kCryptorExpiry - 1us); 87 | EXPECT_NE(cryptorManager.GetCryptor(0), nullptr); 88 | clock.Advance(2us); 89 | EXPECT_EQ(cryptorManager.GetCryptor(0), nullptr); 90 | } 91 | 92 | TEST_F(DaveTests, 
CryptorManagerDeleteOldKeys) 93 | { 94 | auto mockKeyRatchet = std::make_unique(); 95 | EXPECT_CALL(*mockKeyRatchet, GetKey(0)); 96 | EXPECT_CALL(*mockKeyRatchet, GetKey(5)); 97 | EXPECT_CALL(*mockKeyRatchet, DeleteKey(0)); 98 | EXPECT_CALL(*mockKeyRatchet, DeleteKey(1)); 99 | EXPECT_CALL(*mockKeyRatchet, DeleteKey(2)); 100 | EXPECT_CALL(*mockKeyRatchet, DeleteKey(3)); 101 | EXPECT_CALL(*mockKeyRatchet, DeleteKey(4)); 102 | 103 | MockClock clock; 104 | CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)}; 105 | // Give plenty of room to not trigger the max lifetime generations check 106 | clock.Advance(kMaxGenerationGap * 48h); 107 | 108 | EXPECT_NE(cryptorManager.GetCryptor(0), nullptr); 109 | EXPECT_NE(cryptorManager.GetCryptor(5), nullptr); 110 | cryptorManager.ReportCryptorSuccess(5, 5 << kRatchetGenerationShiftBits); 111 | clock.Advance(kCryptorExpiry + 1us); 112 | EXPECT_NE(cryptorManager.GetCryptor(5), nullptr); 113 | } 114 | 115 | TEST_F(DaveTests, CryptorManagerGenerationWrap) 116 | { 117 | EXPECT_EQ(ComputeWrappedGeneration(0, 0), KeyGeneration{0}); 118 | EXPECT_EQ(ComputeWrappedGeneration(0, 1), KeyGeneration{1}); 119 | EXPECT_EQ(ComputeWrappedGeneration(0, 250), KeyGeneration{250}); 120 | 121 | EXPECT_EQ(ComputeWrappedGeneration(11 * kGenerationWrap + 42, 42), 122 | KeyGeneration{11 * kGenerationWrap + 42}); 123 | EXPECT_EQ(ComputeWrappedGeneration(11 * kGenerationWrap + 42, 50), 124 | KeyGeneration{11 * kGenerationWrap + 50}); 125 | EXPECT_EQ(ComputeWrappedGeneration(11 * kGenerationWrap + 42, 10), 126 | KeyGeneration{12 * kGenerationWrap + 10}); 127 | } 128 | 129 | TEST_F(DaveTests, CryptorManagerBigNonce) 130 | { 131 | EXPECT_EQ(ComputeWrappedBigNonce(0, 0), 0u); 132 | EXPECT_EQ(ComputeWrappedBigNonce(0, 1), 1u); 133 | EXPECT_EQ(ComputeWrappedBigNonce(0, 250), 250u); 134 | 135 | EXPECT_EQ(ComputeWrappedBigNonce(11, 10), 11 << kRatchetGenerationShiftBits | 10u); 136 | EXPECT_EQ(ComputeWrappedBigNonce(11, 42), 11 << 
kRatchetGenerationShiftBits | 42u); 137 | EXPECT_EQ(ComputeWrappedBigNonce(11, 50), 11 << kRatchetGenerationShiftBits | 50u); 138 | 139 | EXPECT_EQ(ComputeWrappedBigNonce(11, 2 << kRatchetGenerationShiftBits | 34), 140 | 11 << kRatchetGenerationShiftBits | 34u); 141 | EXPECT_EQ(ComputeWrappedBigNonce(11, 37 << kRatchetGenerationShiftBits | 139), 142 | 11 << kRatchetGenerationShiftBits | 139u); 143 | EXPECT_EQ(ComputeWrappedBigNonce(11, 89 << kRatchetGenerationShiftBits | 294), 144 | 11 << kRatchetGenerationShiftBits | 294u); 145 | } 146 | 147 | TEST_F(DaveTests, CryptorManagerNoReprocess) 148 | { 149 | auto mockKeyRatchet = std::make_unique(); 150 | EXPECT_CALL(*mockKeyRatchet, GetKey(0)); 151 | 152 | MockClock clock; 153 | CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)}; 154 | // Give plenty of room to not trigger the max lifetime generations check 155 | clock.Advance(kMaxGenerationGap * 48h); 156 | 157 | auto cryptor = cryptorManager.GetCryptor(0); 158 | EXPECT_NE(cryptor, nullptr); 159 | 160 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 0)); 161 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 1)); 162 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 2)); 163 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3)); 164 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, std::numeric_limits::max())); 165 | cryptorManager.ReportCryptorSuccess(0, 0); 166 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 0)); 167 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 1)); 168 | cryptorManager.ReportCryptorSuccess(0, 1); 169 | cryptorManager.ReportCryptorSuccess(0, 2); 170 | cryptorManager.ReportCryptorSuccess(0, 5); 171 | cryptorManager.ReportCryptorSuccess(0, 7); 172 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 0)); 173 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 1)); 174 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 2)); 175 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 5)); 176 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 7)); 177 | 
EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3)); 178 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 4)); 179 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 6)); 180 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 8)); 181 | cryptorManager.ReportCryptorSuccess(0, 4); 182 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3)); 183 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 4)); 184 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 6)); 185 | cryptorManager.ReportCryptorSuccess(0, 6); 186 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3)); 187 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 6)); 188 | cryptorManager.ReportCryptorSuccess(0, 10 + kMaxMissingNonces); 189 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 3)); 190 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 7)); 191 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 8)); 192 | EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 9)); 193 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 10)); 194 | EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 11)); 195 | } 196 | 197 | } // namespace test 198 | } // namespace dave 199 | } // namespace discord 200 | -------------------------------------------------------------------------------- /cpp/test/cryptor_tests.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "dave/decryptor.h" 4 | #include "dave/encryptor.h" 5 | #include "dave/frame_processors.h" 6 | 7 | #include "dave_test.h" 8 | #include "static_key_ratchet.h" 9 | 10 | using namespace testing; 11 | using namespace std::chrono_literals; 12 | 13 | namespace discord { 14 | namespace dave { 15 | namespace test { 16 | 17 | constexpr std::string_view RandomBytes = 18 | "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4" 19 | "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9"; 20 | 21 | TEST_F(DaveTests, PassthroughInOutBuffer) 22 | { 23 | auto incomingFrame = GetBufferFromHex(RandomBytes); 24 | 
auto frameCopy = incomingFrame; 25 | 26 | auto frameViewIn = MakeArrayView(incomingFrame.data(), incomingFrame.size()); 27 | auto frameViewOut = MakeArrayView(incomingFrame.data(), incomingFrame.size()); 28 | 29 | EXPECT_NE(incomingFrame.data(), frameCopy.data()); 30 | 31 | Encryptor encryptor; 32 | encryptor.AssignSsrcToCodec(0, Codec::Opus); 33 | encryptor.SetPassthroughMode(true); 34 | 35 | size_t bytesWritten = 0; 36 | auto encryptResult = 37 | encryptor.Encrypt(MediaType::Audio, 0, frameViewIn, frameViewOut, &bytesWritten); 38 | 39 | EXPECT_EQ(encryptResult, 0); 40 | EXPECT_EQ(bytesWritten, frameCopy.size()); 41 | EXPECT_EQ(memcmp(incomingFrame.data(), frameCopy.data(), bytesWritten), 0); 42 | 43 | Decryptor decryptor; 44 | decryptor.TransitionToPassthroughMode(true, 0s); 45 | 46 | auto decryptResult = decryptor.Decrypt(MediaType::Audio, frameViewIn, frameViewOut); 47 | 48 | EXPECT_EQ(decryptResult, frameCopy.size()); 49 | EXPECT_EQ(memcmp(incomingFrame.data(), frameCopy.data(), bytesWritten), 0); 50 | } 51 | 52 | TEST_F(DaveTests, PassthroughTwoBuffers) 53 | { 54 | auto incomingFrame = GetBufferFromHex(RandomBytes); 55 | auto encryptedFrame = std::vector(incomingFrame.size() * 2); 56 | auto decryptedFrame = std::vector(incomingFrame.size()); 57 | 58 | Encryptor encryptor; 59 | encryptor.AssignSsrcToCodec(0, Codec::Opus); 60 | encryptor.SetPassthroughMode(true); 61 | 62 | size_t bytesWritten = 0; 63 | auto encryptResult = encryptor.Encrypt(MediaType::Audio, 64 | 0, 65 | {incomingFrame.data(), incomingFrame.size()}, 66 | {encryptedFrame.data(), encryptedFrame.size()}, 67 | &bytesWritten); 68 | 69 | EXPECT_EQ(encryptResult, 0); 70 | EXPECT_EQ(bytesWritten, incomingFrame.size()); 71 | EXPECT_EQ(memcmp(incomingFrame.data(), encryptedFrame.data(), bytesWritten), 0); 72 | 73 | Decryptor decryptor; 74 | decryptor.TransitionToPassthroughMode(true, 0s); 75 | 76 | auto decryptResult = decryptor.Decrypt(MediaType::Audio, 77 | {encryptedFrame.data(), bytesWritten}, 78 | 
{decryptedFrame.data(), decryptedFrame.size()}); 79 | 80 | EXPECT_EQ(decryptResult, incomingFrame.size()); 81 | EXPECT_EQ(memcmp(encryptedFrame.data(), decryptedFrame.data(), decryptResult), 0); 82 | } 83 | 84 | TEST_F(DaveTests, SilencePacketPassthrough) 85 | { 86 | const std::vector WorkerSilencePacket = {248, 255, 254}; 87 | 88 | Decryptor decryptor; 89 | decryptor.TransitionToKeyRatchet(std::make_unique("0123456789876543210"), 0s); 90 | 91 | auto decryptedFrame = std::vector(WorkerSilencePacket.size()); 92 | auto decryptResult = decryptor.Decrypt(MediaType::Audio, 93 | {WorkerSilencePacket.data(), WorkerSilencePacket.size()}, 94 | {decryptedFrame.data(), decryptedFrame.size()}); 95 | 96 | EXPECT_EQ(decryptResult, WorkerSilencePacket.size()); 97 | EXPECT_EQ(memcmp(WorkerSilencePacket.data(), decryptedFrame.data(), decryptResult), 0); 98 | } 99 | 100 | TEST_F(DaveTests, RandomOpusFrameEncryptDecrypt) 101 | { 102 | Encryptor encryptor; 103 | Decryptor decryptor; 104 | 105 | // set static key ratchet for testing 106 | encryptor.SetKeyRatchet(std::make_unique("0123456789876543210")); 107 | decryptor.TransitionToKeyRatchet(std::make_unique("0123456789876543210"), 0s); 108 | 109 | // load the hex encoded sample frame to a buffer 110 | auto incomingFrame = GetBufferFromHex(RandomBytes); 111 | auto encryptedFrame = std::vector(incomingFrame.size() * 2); 112 | auto decryptedFrame = std::vector(incomingFrame.size()); 113 | 114 | for (size_t i = 0; i < 1; i++) { 115 | // encrypt frame 116 | size_t bytesWritten = 0; 117 | encryptor.AssignSsrcToCodec(0, Codec::Opus); 118 | auto encryptResult = encryptor.Encrypt(MediaType::Audio, 119 | 0, 120 | {incomingFrame.data(), incomingFrame.size()}, 121 | {encryptedFrame.data(), encryptedFrame.size()}, 122 | &bytesWritten); 123 | 124 | EXPECT_EQ(encryptResult, 0); 125 | EXPECT_GE(bytesWritten, incomingFrame.size()); 126 | 127 | // decrypt frame 128 | auto decryptResult = decryptor.Decrypt(MediaType::Audio, 129 | {encryptedFrame.data(), 
bytesWritten}, 130 | {decryptedFrame.data(), decryptedFrame.size()}); 131 | EXPECT_EQ(decryptResult, incomingFrame.size()); 132 | EXPECT_EQ(memcmp(incomingFrame.data(), decryptedFrame.data(), incomingFrame.size()), 0); 133 | } 134 | } 135 | 136 | } // namespace test 137 | } // namespace dave 138 | } // namespace discord 139 | -------------------------------------------------------------------------------- /cpp/test/dave_test.cpp: -------------------------------------------------------------------------------- 1 | #include "dave_test.h" 2 | 3 | namespace discord { 4 | namespace dave { 5 | namespace test { 6 | 7 | std::vector GetBufferFromHex(const std::string_view& hex) 8 | { 9 | auto hexLength = hex.length(); 10 | 11 | if (hexLength % 2 != 0) { 12 | return {}; 13 | } 14 | 15 | auto buffer = std::vector(hexLength / 2); 16 | 17 | for (unsigned int i = 0; i < hexLength; i += 2) { 18 | auto byte = std::string(hex.substr(i, 2)); 19 | buffer[i / 2] = static_cast(std::stoi(byte, nullptr, 16)); 20 | } 21 | 22 | return buffer; 23 | } 24 | 25 | } // namespace test 26 | } // namespace dave 27 | } // namespace discord 28 | -------------------------------------------------------------------------------- /cpp/test/dave_test.h: -------------------------------------------------------------------------------- 1 | 2 | #include "gtest/gtest.h" 3 | 4 | #include "dave/common.h" 5 | 6 | namespace discord { 7 | namespace dave { 8 | namespace test { 9 | 10 | std::vector GetBufferFromHex(const std::string_view& hex); 11 | 12 | class DaveTests : public ::testing::Test { 13 | public: 14 | void SetUp() override {} 15 | 16 | void TearDown() override {} 17 | }; 18 | 19 | } // namespace test 20 | } // namespace dave 21 | } // namespace discord 22 | -------------------------------------------------------------------------------- /cpp/test/static_key_ratchet.cpp: -------------------------------------------------------------------------------- 1 | #include "static_key_ratchet.h" 2 | 3 | #include 4 
| 5 | #include 6 | 7 | #include "dave/common.h" 8 | #include "dave/logger.h" 9 | 10 | namespace discord { 11 | namespace dave { 12 | namespace test { 13 | 14 | EncryptionKey MakeStaticSenderKey(const std::string& userID) 15 | { 16 | auto u64userID = strtoull(userID.c_str(), nullptr, 10); 17 | return MakeStaticSenderKey(u64userID); 18 | } 19 | 20 | EncryptionKey MakeStaticSenderKey(uint64_t u64userID) 21 | { 22 | static_assert(kAesGcm128KeyBytes == 2 * sizeof(u64userID)); 23 | EncryptionKey senderKey(kAesGcm128KeyBytes); 24 | const uint8_t* bytePtr = reinterpret_cast(&u64userID); 25 | std::copy_n(bytePtr, sizeof(u64userID), senderKey.begin()); 26 | std::copy_n(bytePtr, sizeof(u64userID), senderKey.begin() + sizeof(u64userID)); 27 | return senderKey; 28 | } 29 | 30 | StaticKeyRatchet::StaticKeyRatchet(const std::string& userId) noexcept 31 | : u64userID_(strtoull(userId.c_str(), nullptr, 10)) 32 | { 33 | } 34 | 35 | EncryptionKey StaticKeyRatchet::GetKey(KeyGeneration generation) noexcept 36 | { 37 | DISCORD_LOG(LS_INFO) << "Retrieving static key for generation " << generation << " for user " 38 | << u64userID_; 39 | return MakeStaticSenderKey(u64userID_); 40 | } 41 | 42 | void StaticKeyRatchet::DeleteKey([[maybe_unused]] KeyGeneration generation) noexcept 43 | { 44 | // noop 45 | } 46 | 47 | } // namespace test 48 | } // namespace dave 49 | } // namespace discord 50 | -------------------------------------------------------------------------------- /cpp/test/static_key_ratchet.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "dave/key_ratchet.h" 6 | 7 | namespace discord { 8 | namespace dave { 9 | namespace test { 10 | 11 | EncryptionKey MakeStaticSenderKey(const std::string& userID); 12 | EncryptionKey MakeStaticSenderKey(uint64_t u64userID); 13 | 14 | class StaticKeyRatchet : public IKeyRatchet { 15 | public: 16 | StaticKeyRatchet(const std::string& userId) noexcept; 17 | 
~StaticKeyRatchet() noexcept override = default; 18 | 19 | EncryptionKey GetKey(KeyGeneration generation) noexcept override; 20 | void DeleteKey(KeyGeneration generation) noexcept override; 21 | 22 | private: 23 | uint64_t u64userID_; 24 | }; 25 | 26 | } // namespace test 27 | } // namespace dave 28 | } // namespace discord 29 | -------------------------------------------------------------------------------- /cpp/vcpkg-alts/boringssl/overlay-ports/mlspp/portfile.cmake: -------------------------------------------------------------------------------- 1 | vcpkg_from_github( 2 | OUT_SOURCE_PATH SOURCE_PATH 3 | REPO cisco/mlspp 4 | REF "${VERSION}" 5 | SHA512 ca2a7e9cb512f38c49d84e351ca304d7aca176b2686a7ad1326d72dbb6f4b4063dabdf36c57336674b71b1b74b5135abd274adbc79f30d46f792e7862ef5306c 6 | ) 7 | 8 | vcpkg_cmake_configure( 9 | SOURCE_PATH "${SOURCE_PATH}" 10 | OPTIONS 11 | -DDISABLE_GREASE=ON 12 | -DVCPKG_MANIFEST_DIR="alternatives/boringssl" 13 | -DMLS_CXX_NAMESPACE="mlspp" 14 | ) 15 | 16 | vcpkg_cmake_install() 17 | 18 | file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/debug/include") 19 | file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/debug/share") -------------------------------------------------------------------------------- /cpp/vcpkg-alts/boringssl/overlay-ports/mlspp/vcpkg.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mlspp", 3 | "version-string": "49f1c412691e2dc394e24322a84e95ddabf4bf4b", 4 | "description": "Cisco MLS C++ library", 5 | "dependencies": [ 6 | { 7 | "name": "boringssl", 8 | "version>=": "2023-10-13" 9 | }, 10 | "nlohmann-json", 11 | "vcpkg-cmake" 12 | ], 13 | "builtin-baseline": "eb33d2f7583405fca184bcdf7fdd5828ec88ac05" 14 | } -------------------------------------------------------------------------------- /cpp/vcpkg-alts/boringssl/vcpkg.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "libdave", 3 | "license": "MIT", 4 | "dependencies": 
[ 5 | { 6 | "name": "boringssl", 7 | "version>=": "2023-10-13" 8 | }, 9 | "gtest", 10 | "mlspp" 11 | ], 12 | "builtin-baseline": "7adc2e4d49e8d0efc07a369079faa6bc3dbb90f3", 13 | "vcpkg-configuration": { 14 | "overlay-ports": [ 15 | "./overlay-ports" 16 | ] 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /js/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | lerna-debug.log* 8 | .pnpm-debug.log* 9 | 10 | # Diagnostic reports (https://nodejs.org/api/report.html) 11 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 12 | 13 | # Runtime data 14 | pids 15 | *.pid 16 | *.seed 17 | *.pid.lock 18 | 19 | # Directory for instrumented libs generated by jscoverage/JSCover 20 | lib-cov 21 | 22 | # Coverage directory used by tools like istanbul 23 | coverage 24 | *.lcov 25 | 26 | # nyc test coverage 27 | .nyc_output 28 | 29 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 30 | .grunt 31 | 32 | # Bower dependency directory (https://bower.io/) 33 | bower_components 34 | 35 | # node-waf configuration 36 | .lock-wscript 37 | 38 | # Compiled binary addons (https://nodejs.org/api/addons.html) 39 | build/Release 40 | 41 | # Dependency directories 42 | node_modules/ 43 | jspm_packages/ 44 | 45 | # Snowpack dependency directory (https://snowpack.dev/) 46 | web_modules/ 47 | 48 | # TypeScript cache 49 | *.tsbuildinfo 50 | 51 | # Optional npm cache directory 52 | .npm 53 | 54 | # Optional eslint cache 55 | .eslintcache 56 | 57 | # Optional stylelint cache 58 | .stylelintcache 59 | 60 | # Microbundle cache 61 | .rpt2_cache/ 62 | .rts2_cache_cjs/ 63 | .rts2_cache_es/ 64 | .rts2_cache_umd/ 65 | 66 | # Optional REPL history 67 | .node_repl_history 68 | 69 | # Output of 'npm pack' 70 | *.tgz 71 | 72 | # Yarn Integrity file 73 | .yarn-integrity 74 | 75 | # dotenv environment 
variable files 76 | .env 77 | .env.development.local 78 | .env.test.local 79 | .env.production.local 80 | .env.local 81 | 82 | # parcel-bundler cache (https://parceljs.org/) 83 | .cache 84 | .parcel-cache 85 | 86 | # Next.js build output 87 | .next 88 | out 89 | 90 | # Nuxt.js build / generate output 91 | .nuxt 92 | dist 93 | 94 | # Gatsby files 95 | .cache/ 96 | # Comment in the public line in if your project uses Gatsby and not Next.js 97 | # https://nextjs.org/blog/next-9-1#public-directory-support 98 | # public 99 | 100 | # vuepress build output 101 | .vuepress/dist 102 | 103 | # vuepress v2.x temp and cache directory 104 | .temp 105 | .cache 106 | 107 | # Docusaurus cache and generated files 108 | .docusaurus 109 | 110 | # Serverless directories 111 | .serverless/ 112 | 113 | # FuseBox cache 114 | .fusebox/ 115 | 116 | # DynamoDB Local files 117 | .dynamodb/ 118 | 119 | # TernJS port file 120 | .tern-port 121 | 122 | # Stores VSCode versions used for testing VSCode extensions 123 | .vscode-test 124 | 125 | # yarn v2 126 | .yarn/cache 127 | .yarn/unplugged 128 | .yarn/build-state.yml 129 | .yarn/install-state.gz 130 | .pnp.* -------------------------------------------------------------------------------- /js/DisplayableCode.ts: -------------------------------------------------------------------------------- 1 | const MAX_GROUP_SIZE = 8; 2 | /** Encodes the first desiredLength bytes of data as decimal digits, groupSize digits per groupSize-byte chunk: each chunk is read as a big-endian integer, reduced modulo 10^groupSize, and zero-padded to groupSize digits. Throws if data is shorter than desiredLength, desiredLength is not a multiple of groupSize, or groupSize exceeds MAX_GROUP_SIZE. */ 3 | export function generateDisplayableCode(data: Uint8Array, desiredLength: number, groupSize: number): string { 4 | if (data.byteLength < desiredLength) { 5 | throw new Error('data.byteLength must be greater than or equal to desiredLength'); 6 | } 7 | 8 | if (desiredLength % groupSize !== 0) { 9 | throw new Error('desiredLength must be a multiple of groupSize'); 10 | } 11 | 12 | if (groupSize > MAX_GROUP_SIZE) { 13 | throw new Error(`groupSize must be less than or equal to ${MAX_GROUP_SIZE}`); 14 | } 15 | 16 | const groupModulus = BigInt(10 ** groupSize); 17 | 18 | let result = ''; 19 | 20 | for (let i = 0; i <
desiredLength; i += groupSize) { 21 | let groupValue = BigInt(0); 22 | 23 | for (let j = groupSize; j > 0; --j) { 24 | // Append the next byte to the big-endian group value; the undefined check narrows the indexed read under noUncheckedIndexedAccess. 25 | const nextByte = data[i + (groupSize - j)]; 26 | if (nextByte === undefined) { 27 | throw new Error('Out of bounds access from data array'); 28 | } 29 | 30 | groupValue = (groupValue << 8n) | BigInt(nextByte); 31 | } 32 | 33 | groupValue %= groupModulus; 34 | 35 | result += groupValue.toString().padStart(groupSize, '0'); 36 | } 37 | 38 | return result; 39 | } 40 | -------------------------------------------------------------------------------- /js/KeyFingerprint.ts: -------------------------------------------------------------------------------- 1 | const VERSION_LEN = 2; 2 | const UID_LEN = 8; 3 | /** Builds the version-0 key fingerprint: [version as u16 big-endian][key bytes][userId as u64 big-endian]. The returned promise rejects for an unsupported version, an empty key, or a user ID that is empty, not parseable by BigInt, or outside the unsigned 64-bit range. */ 4 | export async function generateKeyFingerprint(version: number, key: Uint8Array, userId: string): Promise<Uint8Array> { 5 | if (version !== 0) { 6 | throw new Error('unsupported fingerprint format version'); 7 | } 8 | 9 | if (key.byteLength === 0) { 10 | throw new Error('zero-length key'); 11 | } 12 | 13 | if (userId.length === 0) { 14 | throw new Error('zero-length user ID'); 15 | } 16 | // NOTE(review): BigInt() also accepts surrounding whitespace and 0x/0o/0b prefixes, so such user IDs are not rejected here; confirm that is intended. 17 | const userIdInt = BigInt(userId); 18 | if (userIdInt < 0n || userIdInt >= 2n ** 64n) { 19 | throw new Error('user ID out of range'); 20 | } 21 | 22 | const lbuf = new Uint8Array(VERSION_LEN + key.byteLength + UID_LEN); 23 | lbuf.set(key, VERSION_LEN); 24 | 25 | const dv = new DataView(lbuf.buffer); 26 | dv.setUint16(0, version); 27 | dv.setBigUint64(VERSION_LEN + key.byteLength, userIdInt); 28 | 29 | return lbuf; 30 | } 31 | -------------------------------------------------------------------------------- /js/KeySerialization.ts: -------------------------------------------------------------------------------- 1 | import base64 from 'base64-js'; 2 | /** Encodes key bytes as standard, padded base64 text using base64-js. */ 3 | export function serializeKey(data: Uint8Array): string { 4 | return base64.fromByteArray(data); 5 | } 6 | -------------------------------------------------------------------------------- /js/PairwiseFingerprint.ts:
-------------------------------------------------------------------------------- 1 | import {generateKeyFingerprint} from './KeyFingerprint'; 2 | import {scryptAsync} from '@noble/hashes/scrypt'; 3 | // Fixed scrypt salt and cost parameters used for every pairwise fingerprint; changing any of them changes the derived output. 4 | const salt = Uint8Array.of( 5 | 0x24, 6 | 0xca, 7 | 0xb1, 8 | 0x7a, 9 | 0x7a, 10 | 0xf8, 11 | 0xec, 12 | 0x2b, 13 | 0x82, 14 | 0xb4, 15 | 0x12, 16 | 0xb9, 17 | 0x2d, 18 | 0xab, 19 | 0x19, 20 | 0x2e, 21 | ); 22 | const scryptParams = { 23 | N: 16384, 24 | r: 8, 25 | p: 2, 26 | dkLen: 64, 27 | }; 28 | /** Lexicographic byte-wise comparator for Array.prototype.sort; when one array is a prefix of the other, the shorter one sorts first. */ 29 | function compareArrays(a: Uint8Array, b: Uint8Array): number { 30 | for (let i = 0; i < a.length && i < b.length; i++) { 31 | if (a[i] !== b[i]) return a[i]! - b[i]!; 32 | } 33 | 34 | return a.length - b.length; 35 | } 36 | /** Derives a 64-byte pairwise fingerprint from two (key, userId) pairs. Both key fingerprints are generated, sorted with compareArrays so the result does not depend on which participant is passed as A or B, concatenated, and run through scrypt with the fixed salt and parameters above. Rejects if either key fingerprint cannot be generated. */ 37 | export async function generatePairwiseFingerprint( 38 | version: number, 39 | keyA: Uint8Array, 40 | userIdA: string, 41 | keyB: Uint8Array, 42 | userIdB: string, 43 | ): Promise<Uint8Array> { 44 | const fingerprints = await Promise.all([ 45 | generateKeyFingerprint(version, keyA, userIdA), 46 | generateKeyFingerprint(version, keyB, userIdB), 47 | ]); 48 | 49 | fingerprints.sort(compareArrays); 50 | 51 | const input = new Uint8Array(fingerprints[0].byteLength + fingerprints[1].byteLength); 52 | input.set(fingerprints[0], 0); 53 | input.set(fingerprints[1], fingerprints[0].byteLength); 54 | 55 | const ret = await scryptAsync(input, salt, scryptParams); 56 | 57 | return new Uint8Array(ret); 58 | } 59 | -------------------------------------------------------------------------------- /js/README.md: -------------------------------------------------------------------------------- 1 | ## libdave JS 2 | 3 | Contains the package @discordapp/libdave. This is leveraged by Discord clients to enable out-of-band verifications of DAVE protocol call members and the MLS epoch authenticator. 4 | 5 | ### Testing 6 | 7 | Testing uses [Jest](https://jestjs.io/). You can run the tests with `pnpm jest`.
8 | 9 | ### Dependencies 10 | 11 | - [@noble/hashes](https://github.com/paulmillr/noble-hashes) 12 | - [base64-js](https://www.npmjs.com/package/base64-js) -------------------------------------------------------------------------------- /js/__tests__/DisplayableCode-test.ts: -------------------------------------------------------------------------------- 1 | import {describe, expect, test} from '@jest/globals'; 2 | 3 | const DAVE = require('../libdave'); 4 | 5 | describe('DisplayableCode', () => { 6 | test('expectedOutput', () => { 7 | const shortData = new Uint8Array([0xaa, 0xbb, 0xcc, 0xdd, 0xee]); 8 | expect(DAVE.generateDisplayableCode(shortData, 5, 5)).toBe('05870'); 9 | 10 | const longDataBuffer = Buffer.from('aabbccddeebbccddeeffccddeeffaaddeeffaabbeeffaabbccffaabbccdd', 'hex'); 11 | const longData = Uint8Array.from(longDataBuffer); 12 | expect(DAVE.generateDisplayableCode(longData, 30, 5)).toBe('058708105556138052119572494877'); 13 | }); 14 | 15 | test('expectedFailure', () => { 16 | const tooShortData = new Uint8Array([0xaa, 0xbb, 0xcc, 0xdd]); 17 | expect(() => { 18 | DAVE.generateDisplayableCode(tooShortData, 5, 5); 19 | }).toThrow(); 20 | 21 | const goodData = new Uint8Array([0xaa, 0xbb, 0xcc, 0xdd]); 22 | expect(() => { 23 | DAVE.generateDisplayableCode(goodData, 4, 3); 24 | }).toThrow(); 25 | 26 | const randomData = new Uint8Array(1024); 27 | globalThis.crypto.getRandomValues(randomData); 28 | expect(() => { 29 | DAVE.generateDisplayableCode(randomData, 1024, 11); 30 | }).toThrow(); 31 | }); 32 | }); 33 | -------------------------------------------------------------------------------- /js/__tests__/KeyFingerprint-test.ts: -------------------------------------------------------------------------------- 1 | import {describe, expect, test} from '@jest/globals'; 2 | 3 | const DAVE = require('../libdave'); 4 | 5 | describe('KeyFingerprint', () => { 6 | test('expectedOutput', async () => { 7 | const shortData = new Uint8Array(33); 8 | expect((await
DAVE.generateKeyFingerprint(0, shortData, '1234')).join('')).toBe( 9 | '000000000000000000000000000000000000000004210', 10 | ); 11 | 12 | const longData = new Uint8Array(65); 13 | expect((await DAVE.generateKeyFingerprint(0, longData, '12345678')).join('')).toBe( 14 | '0000000000000000000000000000000000000000000000000000000000000000000000001889778', 15 | ); 16 | }); 17 | 18 | test('expectedFailure', async () => { 19 | const data = new Uint8Array(33); 20 | await expect(DAVE.generateKeyFingerprint(1, data, '1234')).rejects.toThrow(); 21 | 22 | await expect(DAVE.generateKeyFingerprint(0, data, 'abcd')).rejects.toThrow(); 23 | 24 | await expect(DAVE.generateKeyFingerprint(0, new Uint8Array(0), '1234')).rejects.toThrow(); 25 | }); 26 | }); 27 | -------------------------------------------------------------------------------- /js/__tests__/KeySerialization-test.ts: -------------------------------------------------------------------------------- 1 | import {describe, expect, test} from '@jest/globals'; 2 | 3 | const DAVE = require('../libdave'); 4 | 5 | describe('KeySerialization', () => { 6 | test('expectedOutput', async () => { 7 | const zeroData = new Uint8Array(6); 8 | expect(DAVE.serializeKey(zeroData)).toBe('AAAAAAAA'); 9 | 10 | const moreData = new Uint8Array([0, 1, 0xff, 0x7f, 0x80]); 11 | expect(DAVE.serializeKey(moreData)).toBe('AAH/f4A='); 12 | }); 13 | }); 14 | -------------------------------------------------------------------------------- /js/__tests__/PairwiseFingerprint-test.ts: -------------------------------------------------------------------------------- 1 | import {describe, expect, test} from '@jest/globals'; 2 | 3 | const DAVE = require('../libdave'); 4 | 5 | describe('PairwiseFingerprint', () => { 6 | test('expectedOutput', async () => { 7 | const data1 = new Uint8Array(33); 8 | const data2 = new Uint8Array(65); 9 | await expect(DAVE.generatePairwiseFingerprint(0, data1, '1234', data2, '5678')).resolves.toEqual( 10 | new Uint8Array([ 11 | 133, 129, 241,
44, 36, 135, 79, 195, 27, 28, 151, 69, 124, 197, 189, 41, 192, 7, 16, 45, 79, 247, 138, 58, 126, 12 | 161, 178, 136, 12, 109, 96, 164, 169, 92, 2, 232, 136, 174, 74, 156, 173, 144, 191, 184, 34, 45, 242, 136, 41, 13 | 133, 14, 158, 119, 79, 204, 48, 6, 220, 121, 6, 242, 11, 164, 60, 14 | ]), 15 | ); 16 | }); 17 | 18 | test('badSort', async () => { 19 | const data1 = new Uint8Array([0, 100]); 20 | const data2 = new Uint8Array([0, 20]); 21 | await expect(DAVE.generatePairwiseFingerprint(0, data1, '1', data2, '2')).resolves.toEqual( 22 | new Uint8Array([ 23 | 141, 169, 194, 143, 22, 72, 22, 245, 13, 140, 66, 228, 159, 195, 101, 106, 119, 240, 69, 191, 178, 227, 194, 24 | 126, 162, 255, 222, 148, 138, 5, 33, 215, 240, 167, 234, 245, 149, 182, 46, 20, 4, 83, 191, 31, 165, 74, 253, 25 | 165, 199, 16, 29, 71, 193, 205, 169, 154, 255, 154, 34, 30, 94, 171, 247, 43, 26 | ]), 27 | ); 28 | }); 29 | 30 | test('expectedFailure', async () => { 31 | const data = new Uint8Array(33); 32 | await expect(DAVE.generatePairwiseFingerprint(1, data, '1234', data, '5678')).rejects.toThrow(); 33 | 34 | await expect(DAVE.generatePairwiseFingerprint(0, data, 'abcd', data, '5678')).rejects.toThrow(); 35 | 36 | await expect(DAVE.generatePairwiseFingerprint(0, new Uint8Array(0), '1234', data, '5678')).rejects.toThrow(); 37 | }); 38 | }); 39 | -------------------------------------------------------------------------------- /js/jest-setup.js: -------------------------------------------------------------------------------- 1 | const crypto = require('crypto'); 2 | 3 | // Maps a WebCrypto digest algorithm (a name string or an object with a name) to Node's hash name; unrecognized names pass through unchanged. 4 | function convertAlgorithm(algorithm) { 5 | const name = typeof algorithm === 'string' ? algorithm : algorithm.name; 6 | switch (name) { 7 | case 'SHA-1': 8 | return 'sha1'; 9 | case 'SHA-256': 10 | return 'sha256'; 11 | case 'SHA-384': 12 | return 'sha384'; 13 | case 'SHA-512': 14 | return 'sha512'; 15 | default: 16 | return name; 17 | } 18 | } 19 | 20 | // Minimal WebCrypto stand-in installed as globalThis.crypto. 21 | Object.defineProperty(globalThis, 'crypto', { 22 | value: { 23 | // Like Crypto.getRandomValues: fills the given typed array in place and returns that same array. 24 | getRandomValues: (arr) => crypto.randomFillSync(arr), 25 | subtle: { 26 | // Like SubtleCrypto.digest: resolves (rather than returns synchronously) to an ArrayBuffer holding the digest. 27 | digest: async (algorithm, data) => { 28 | const out = crypto.createHash(convertAlgorithm(algorithm)).update(data).digest(); 29 | // Copy only this Buffer's bytes: Buffer.buffer is not guaranteed to be exactly the digest's extent. 30 | return out.buffer.slice(out.byteOffset, out.byteOffset + out.byteLength); 31 | }, 32 | }, 33 | }, 34 | }); 35 |
-------------------------------------------------------------------------------- /js/jest.config.js: -------------------------------------------------------------------------------- 1 | /** 2 | * For a detailed explanation regarding each configuration property, visit: 3 | * https://jestjs.io/docs/configuration 4 | */ 5 | 6 | /** @type {import('jest').Config} */ 7 | const config = { 8 | preset: "ts-jest", 9 | 10 | // Automatically clear mock calls, instances, contexts and results before every test 11 | clearMocks: true, 12 | 13 | // Indicates whether the coverage information should be collected while executing the test 14 | collectCoverage: true, 15 | 16 | // The directory where Jest should output its coverage files 17 | coverageDirectory: "coverage", 18 | 19 | // Indicates which provider should be used to instrument code for coverage 20 | coverageProvider: "v8", 21 | 22 | // Array of regexp pattern strings for visible paths to module loader 23 | modulePathIgnorePatterns: ["dist/"] 24 | }; 25 | 26 | module.exports = config; 27 | -------------------------------------------------------------------------------- /js/libdave.ts: -------------------------------------------------------------------------------- 1 | export {generateDisplayableCode} from './DisplayableCode'; 2 | export {generateKeyFingerprint} from './KeyFingerprint'; 3 | export {generatePairwiseFingerprint} from './PairwiseFingerprint'; 4 | export {serializeKey} from './KeySerialization'; 5 | -------------------------------------------------------------------------------- /js/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@discordapp/libdave", 3 | "license": "MIT", 4 | "version": "1.0.0", 5 | "main": "libdave.tsx", 6 | "scripts": { 7 | "test": "jest", 8 | "build:watch": "tsc --watch", 9 | "build": "tsc" 10 | }, 11 | "devDependencies": { 12 | "@jest/globals": "^29.7.0", 13 | "jest": "^29.7.0", 14 | "ts-jest": "^29.2.5", 15 | "typescript": 
"^5.6.2" 16 | }, 17 | "dependencies": { 18 | "@noble/hashes": "1.5.0", 19 | "base64-js": "1.5.1" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /js/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "esModuleInterop": true, 4 | "skipLibCheck": true, 5 | "target": "es2022", 6 | "allowJs": true, 7 | "resolveJsonModule": true, 8 | "moduleDetection": "force", 9 | "isolatedModules": true, 10 | "strict": true, 11 | "noUncheckedIndexedAccess": true, 12 | "noImplicitOverride": true, 13 | "module": "NodeNext", 14 | "outDir": "dist", 15 | "sourceMap": true, 16 | "declaration": true, 17 | "lib": ["es2022", "dom", "dom.iterable"], 18 | } 19 | } --------------------------------------------------------------------------------