├── .bazelrc ├── .bazelversion ├── .github └── workflows │ ├── documentation.yaml │ └── ubuntu-bazel.yaml ├── .gitignore ├── .swift-format.json ├── BUILD ├── LICENSE ├── README.md ├── WORKLOG.md ├── WORKSPACE ├── bazel └── setup_clang.sh ├── deps.bzl ├── examples ├── BUILD.bazel ├── cifar-10 │ └── main.swift ├── ddpg │ └── main.swift ├── dqn │ └── main.swift ├── imdb │ └── main.swift ├── minrf │ └── main.swift ├── pandas │ └── main.swift ├── ppo │ ├── main.swift │ └── python.swift ├── py_sf_ppo │ └── main.swift ├── random │ └── main.swift ├── sac │ └── main.swift └── td3 │ └── main.swift ├── external ├── PythonKit.BUILD ├── fpzip.BUILD ├── swift-algorithms.BUILD ├── swift-argument-parser.BUILD ├── swift-atomics.BUILD ├── swift-format.BUILD ├── swift-nio.BUILD ├── swift-numerics.BUILD ├── swift-protobuf.BUILD ├── swift-syntax.BUILD ├── swift-system.BUILD └── swift-tools-support-core.BUILD ├── gym ├── BUILD.bazel ├── Collector.swift ├── Gym.swift ├── NumericalStatistics.swift ├── RunningMeanStd.swift ├── assets │ ├── ant.xml │ ├── half_cheetah.xml │ ├── hopper.xml │ ├── humanoid.xml │ ├── humanoidstandup.xml │ ├── inverted_double_pendulum.xml │ ├── inverted_pendulum.xml │ ├── reacher.xml │ ├── swimmer.xml │ └── walker2d.xml ├── envs │ ├── Ant.swift │ ├── Env.swift │ ├── HalfCheetah.swift │ ├── Hopper.swift │ ├── Humanoid.swift │ ├── InvertedDoublePendulum.swift │ ├── InvertedPendulum.swift │ ├── Noise.swift │ ├── SFMT.swift │ ├── Swimmer.swift │ ├── TimeLimit.swift │ ├── VecEnv.swift │ └── Walker2D.swift ├── http │ └── HTTP.swift ├── policies │ └── PPO.swift └── renders │ ├── MuJoCoVideo.swift │ ├── MuJoCoViewer.swift │ ├── Renderable.swift │ ├── ffmpeg_shim.c │ └── ffmpeg_shim.h ├── nnc ├── AnyModel.swift ├── AutoGrad.swift ├── BUILD.bazel ├── C_zlib │ └── shim.h ├── CoreMLConversion.swift ├── DataFrame.swift ├── DataFrameAddons.swift ├── DataFrameCore.swift ├── DynamicGraph.swift ├── Functional.swift ├── FunctionalAddons.swift ├── GradScaler.swift ├── Group.swift 
├── Hint.swift ├── Loss.swift ├── Model.swift ├── ModelAddons.swift ├── ModelBuilder.swift ├── ModelCore.swift ├── ModelIOAddons.swift ├── MuJoCoConversion.swift ├── Operators.swift ├── Optimizer.swift ├── OptimizerAddons.swift ├── PythonConversion.swift ├── Store.swift ├── StreamContext.swift ├── Tensor.swift ├── TensorGroup.swift └── Wrapped.swift ├── scripts ├── BUILD.bazel ├── buildifier │ └── pre-commit ├── compdb │ ├── BUILD │ └── compdb.py ├── docc.sh ├── install.sh ├── swift-format │ └── pre-commit ├── vendors │ └── dispatch └── vscode │ └── build.sh ├── tensorboard ├── BUILD.bazel ├── EventLogger.swift ├── README.md ├── SummaryWriter+NNC.swift ├── SummaryWriter.swift ├── Support.swift ├── compat │ └── proto │ │ ├── README.md │ │ ├── __init__.py │ │ ├── allocation_description.proto │ │ ├── api_def.proto │ │ ├── attr_value.proto │ │ ├── cluster.proto │ │ ├── config.proto │ │ ├── coordination_config.proto │ │ ├── cost_graph.proto │ │ ├── cpp_shape_inference.proto │ │ ├── debug.proto │ │ ├── event.proto │ │ ├── full_type.proto │ │ ├── function.proto │ │ ├── graph.proto │ │ ├── meta_graph.proto │ │ ├── node_def.proto │ │ ├── op_def.proto │ │ ├── proto_test.py │ │ ├── resource_handle.proto │ │ ├── rewriter_config.proto │ │ ├── saved_object_graph.proto │ │ ├── saver.proto │ │ ├── step_stats.proto │ │ ├── struct.proto │ │ ├── summary.proto │ │ ├── tensor.proto │ │ ├── tensor_description.proto │ │ ├── tensor_shape.proto │ │ ├── tfprof_log.proto │ │ ├── trackable_object_graph.proto │ │ ├── types.proto │ │ ├── update.sh │ │ ├── variable.proto │ │ ├── verifier_config.proto │ │ └── versions.proto └── proto-generated │ ├── allocation_description.pb.swift │ ├── api_def.pb.swift │ ├── attr_value.pb.swift │ ├── cluster.pb.swift │ ├── config.pb.swift │ ├── coordination_config.pb.swift │ ├── cost_graph.pb.swift │ ├── cpp_shape_inference.pb.swift │ ├── debug.pb.swift │ ├── event.pb.swift │ ├── full_type.pb.swift │ ├── function.pb.swift │ ├── graph.pb.swift │ ├── 
meta_graph.pb.swift │ ├── node_def.pb.swift │ ├── op_def.pb.swift │ ├── resource_handle.pb.swift │ ├── rewriter_config.pb.swift │ ├── saved_object_graph.pb.swift │ ├── saver.pb.swift │ ├── step_stats.pb.swift │ ├── struct.pb.swift │ ├── summary.pb.swift │ ├── tensor.pb.swift │ ├── tensor_description.pb.swift │ ├── tensor_shape.pb.swift │ ├── tfprof_log.pb.swift │ ├── trackable_object_graph.pb.swift │ ├── types.pb.swift │ ├── variable.pb.swift │ ├── verifier_config.pb.swift │ └── versions.pb.swift └── test ├── BUILD.bazel ├── coreml ├── main.swift └── mlshapedarray.swift ├── dataframe.swift ├── graph.swift ├── loss.swift ├── main.swift ├── model.swift ├── ops.swift ├── optimizer.swift ├── python ├── main.swift └── numpy.swift ├── scaled_data.csv ├── some_variables.db ├── store.swift └── tensor.swift /.bazelrc: -------------------------------------------------------------------------------- 1 | common:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain 2 | common:cuda --define=using_cuda=true --define=using_cuda_nvcc=true 3 | 4 | common:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain 5 | common:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true 6 | 7 | common:win-cuda --define=using_cuda=true --define=using_cuda_nvcc=true 8 | 9 | common --disk_cache=.cache 10 | 11 | build --cxxopt='-std=c++17' 12 | 13 | try-import %workspace%/clang.bazelrc 14 | try-import %workspace%/.bazelrc.local 15 | -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 6.4.0 2 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yaml: -------------------------------------------------------------------------------- 1 | name: documentation 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | build: 8 | runs-on: ubuntu-22.04 9 | steps: 10 | - uses: 
actions/checkout@v4 11 | with: 12 | fetch-depth: 0 13 | 14 | - name: Install bazelisk 15 | run: | 16 | curl -LO "https://github.com/bazelbuild/bazelisk/releases/download/v1.21.0/bazelisk-linux-amd64" 17 | mkdir -p "${GITHUB_WORKSPACE}/bin/" 18 | mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel" 19 | chmod +x "${GITHUB_WORKSPACE}/bin/bazel" 20 | 21 | - name: Install Swift dependencies 22 | run: | 23 | sudo apt update 24 | sudo apt -y install clang libicu-dev 25 | wget https://download.swift.org/swift-5.10.1-release/ubuntu2204/swift-5.10.1-RELEASE/swift-5.10.1-RELEASE-ubuntu22.04.tar.gz 26 | tar xzf swift-5.10.1-RELEASE-ubuntu22.04.tar.gz 27 | echo "$(pwd)/swift-5.10.1-RELEASE-ubuntu22.04/usr/bin" >> $GITHUB_PATH 28 | 29 | - name: Setup clang 30 | run: | 31 | sudo apt -y install libpng-dev libjpeg-dev libatlas-base-dev libblas-dev libgsl-dev clang llvm libdispatch-dev libomp-dev liblinear-dev libfftw3-dev libtesseract-dev libglfw3-dev 32 | ./bazel/setup_clang.sh 33 | echo "build --config=clang" >> "${GITHUB_WORKSPACE}/.bazelrc.local" 34 | 35 | - name: Clean up documentation branch 36 | run: | 37 | git branch -D documentation || true 38 | git checkout -b documentation 39 | 40 | - name: Run docc 41 | run: | 42 | cd "${GITHUB_WORKSPACE}" && ./scripts/docc.sh 43 | 44 | - name: Add and commit documentation 45 | run: | 46 | git config --global user.email "docbot@github.com" 47 | git config --global user.name "docbot" 48 | cd "${GITHUB_WORKSPACE}" && git add "docs/*" && git commit -m "Update docs." 
49 | 50 | - name: Push the new branch 51 | run: | 52 | cd "${GITHUB_WORKSPACE}" && git push --force origin documentation:documentation 53 | 54 | -------------------------------------------------------------------------------- /.github/workflows/ubuntu-bazel.yaml: -------------------------------------------------------------------------------- 1 | name: ubuntu-bazel 2 | on: [push] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-22.04 6 | steps: 7 | - uses: actions/checkout@v4 8 | 9 | - name: Install bazelisk 10 | run: | 11 | curl -LO "https://github.com/bazelbuild/bazelisk/releases/download/v1.21.0/bazelisk-linux-amd64" 12 | mkdir -p "${GITHUB_WORKSPACE}/bin/" 13 | mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel" 14 | chmod +x "${GITHUB_WORKSPACE}/bin/bazel" 15 | 16 | - name: Install Swift dependencies 17 | run: | 18 | sudo apt update 19 | sudo apt -y install clang libicu-dev 20 | wget https://download.swift.org/swift-5.10.1-release/ubuntu2204/swift-5.10.1-RELEASE/swift-5.10.1-RELEASE-ubuntu22.04.tar.gz 21 | tar xzf swift-5.10.1-RELEASE-ubuntu22.04.tar.gz 22 | echo "$(pwd)/swift-5.10.1-RELEASE-ubuntu22.04/usr/bin" >> $GITHUB_PATH 23 | 24 | - name: Setup clang 25 | run: | 26 | sudo apt -y install libpng-dev libjpeg-dev libatlas-base-dev libblas-dev libgsl-dev clang llvm libdispatch-dev libomp-dev liblinear-dev libfftw3-dev libtesseract-dev 27 | ./bazel/setup_clang.sh 28 | echo "build --config=clang" >> "${GITHUB_WORKSPACE}/.bazelrc.local" 29 | 30 | - name: Run tests 31 | run: | 32 | "${GITHUB_WORKSPACE}/bin/bazel" test //test:nnc 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | # 3 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 4 | 5 | ## User settings 6 | xcuserdata/ 7 | 8 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) 9 
| *.xcscmblueprint 10 | *.xccheckout 11 | 12 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) 13 | build/ 14 | DerivedData/ 15 | *.moved-aside 16 | *.pbxuser 17 | !default.pbxuser 18 | *.mode1v3 19 | !default.mode1v3 20 | *.mode2v3 21 | !default.mode2v3 22 | *.perspectivev3 23 | !default.perspectivev3 24 | 25 | ## Obj-C/Swift specific 26 | *.hmap 27 | 28 | ## App packaging 29 | *.ipa 30 | *.dSYM.zip 31 | *.dSYM 32 | 33 | ## Playgrounds 34 | timeline.xctimeline 35 | playground.xcworkspace 36 | 37 | # Swift Package Manager 38 | # 39 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 40 | # Packages/ 41 | # Package.pins 42 | # Package.resolved 43 | # *.xcodeproj 44 | # 45 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 46 | # hence it is not needed unless you have added a package configuration file to your project 47 | # .swiftpm 48 | 49 | .build/ 50 | 51 | # CocoaPods 52 | # 53 | # We recommend against adding the Pods directory to your .gitignore. However 54 | # you should judge for yourself, the pros and cons are mentioned at: 55 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 56 | # 57 | # Pods/ 58 | # 59 | # Add this line if you want to avoid checking in source code from the Xcode workspace 60 | # *.xcworkspace 61 | 62 | # Carthage 63 | # 64 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 65 | # Carthage/Checkouts 66 | 67 | Carthage/Build/ 68 | 69 | # Accio dependency management 70 | Dependencies/ 71 | .accio/ 72 | 73 | # fastlane 74 | # 75 | # It is recommended to not store the screenshots in the git repo. 76 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 
77 | # For more information about the recommended setup visit: 78 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 79 | 80 | fastlane/report.xml 81 | fastlane/Preview.html 82 | fastlane/screenshots/**/*.png 83 | fastlane/test_output 84 | 85 | # Code Injection 86 | # 87 | # After new code Injection tools there's a generated folder /iOSInjectionProject 88 | # https://github.com/johnno1962/injectionforxcode 89 | 90 | iOSInjectionProject/ 91 | 92 | bazel-* 93 | 94 | .bazelrc.local 95 | clang.bazelrc 96 | .cache 97 | 98 | *.[oda] 99 | *.~ 100 | *.swp 101 | *.gcno 102 | *.out 103 | 104 | .ipynb_checkpoints/ 105 | .ibzlnb/ 106 | .index/ 107 | _env/ 108 | compile_commands.json 109 | 110 | runs/ 111 | *.mp4 112 | -------------------------------------------------------------------------------- /.swift-format.json: -------------------------------------------------------------------------------- 1 | { 2 | "fileScopedDeclarationPrivacy" : { 3 | "accessLevel" : "private" 4 | }, 5 | "indentation" : { 6 | "spaces" : 2 7 | }, 8 | "indentConditionalCompilationBlocks" : true, 9 | "lineBreakAroundMultilineExpressionChainComponents" : false, 10 | "lineBreakBeforeControlFlowKeywords" : false, 11 | "lineBreakBeforeEachArgument" : false, 12 | "lineBreakBeforeEachGenericRequirement" : false, 13 | "lineLength" : 100, 14 | "maximumBlankLines" : 1, 15 | "prioritizeKeepingFunctionOutputTogether" : false, 16 | "respectsExistingLineBreaks" : true, 17 | "rules" : { 18 | "AllPublicDeclarationsHaveDocumentation" : true, 19 | "AlwaysUseLowerCamelCase" : true, 20 | "AmbiguousTrailingClosureOverload" : true, 21 | "BeginDocumentationCommentWithOneLineSummary" : true, 22 | "DoNotUseSemicolons" : true, 23 | "DontRepeatTypeInStaticProperties" : true, 24 | "FileprivateAtFileScope" : true, 25 | "FullyIndirectEnum" : true, 26 | "GroupNumericLiterals" : true, 27 | "IdentifiersMustBeASCII" : true, 28 | "NeverForceUnwrap" : true, 29 | "NeverUseForceTry" : true, 30 | 
"NeverUseImplicitlyUnwrappedOptionals" : true, 31 | "NoAccessLevelOnExtensionDeclaration" : true, 32 | "NoBlockComments" : true, 33 | "NoCasesWithOnlyFallthrough" : true, 34 | "NoEmptyTrailingClosureParentheses" : true, 35 | "NoLabelsInCasePatterns" : true, 36 | "NoLeadingUnderscores" : true, 37 | "NoParensAroundConditions" : true, 38 | "NoVoidReturnOnFunctionSignature" : true, 39 | "OneCasePerLine" : true, 40 | "OneVariableDeclarationPerLine" : true, 41 | "OnlyOneTrailingClosureArgument" : true, 42 | "OrderedImports" : true, 43 | "ReturnVoidInsteadOfEmptyTuple" : true, 44 | "UseLetInEveryBoundCaseVariable" : true, 45 | "UseShorthandTypeNames" : true, 46 | "UseSingleLinePropertyGetter" : true, 47 | "UseSynthesizedInitializer" : true, 48 | "UseTripleSlashForDocumentationComments" : true, 49 | "ValidateDocumentationComments" : true 50 | }, 51 | "tabWidth" : 2, 52 | "version" : 1 53 | } 54 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuliu/s4nnc/76508762da1553344430d1be580bdd740fa674e9/BUILD -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Liu Liu 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /WORKLOG.md: -------------------------------------------------------------------------------- 1 | 2022-08-26 2 | ---------- 3 | Write down some more implementation details about PPO: 4 | 5 | 1. Handle termination properly. If terminated, shouldn't use critic to estimate value. Just assign 0; 6 | 2. action should be clipped to -1...1 and then scaled to the action range; 7 | 3. Should use the pre-clipped action for probabilistic distribution computation, otherwise it is more likely to get into nan / inf; 8 | 4. observations, if normalized, are better clipped to a range; the magic value is -10...10; 9 | 5. reward normalization after each batch or after all steps seems to have minimal impact on training. 10 | 11 | 12 | 2022-08-24 13 | ---------- 14 | One interesting thing I discovered in the Swift compiler is how unstable tuples are.
Besides the fact that it is not an existential type, `withUnsafePointer(to: &atuple.0)` will have a different pointer than `withUnsafePointer(to: &atuple)`. It is not an issue if you only ever use Swift, but because static arrays in C are interoperated as tuples, the former will give you the correct type but an incorrect pointer, while the latter will have the correct pointer but an incorrect type. 15 | 16 | 17 | 2021-08-06 18 | ---------- 19 | Swift is a memory-safe language in general. But for frameworks talking to C, even if the C side has no memory issues (really?!), it is a good idea to run valgrind against the Swift binary. 20 | 21 | When upgraded to Swift 5.4.2, the unit tests in this repo now fail in `DataFrameTests.testToGPU()`. The error is because a pointer cannot be unregistered from the CUDA side. When we move memory from CPU to GPU, we first pin the memory such that there will be a marginal speed-up when doing the copy from CPU to GPU memory. It can be fixed by just ignoring the error, because at unregistering time, I will deallocate that memory anyway. However, I do want to solve it; it bothers me. Digging deeper, the pointer changed between when it was allocated and when it was freed. That's why we pinned one memory region but tried to unpin another one. This hints at a memory issue. 22 | 23 | Probing through the codebase doesn't yield much. The strong suspicion would always be on the C side of things. But we haven't had this issue for a long time; the pointer looks like it simply changed without any interaction on that side. After probing for an hour, I started to try valgrind. It crashed on the tests almost immediately. At first I dismissed it, as if something were wrong with valgrind. But looking back, it suggested I freed some data that I shouldn't have. This is interesting. 24 | 25 | In DataFrame, I wrap data from an array in two ways. One is as a tensor, which I wrap as the C tensor such that it can be passed to a C processor (such as `to_gpu`) for transformations.
Another is an object, which we create temporarily and release when iteration is done. Due to an implementation error, I called `Unmanaged.fromOpaque().release()` in the case where I wrap the C tensor, causing modifications to the said pointer from the Swift runtime. This error was only made when we create a DataFrame from an array, which is not the majority use-case. 26 | 27 | I guess it is also time to see whether `rules_swift` can support the address sanitizer, which Swift itself supports. These tools, even for a memory-safe language, are quite useful, it turns out. 28 | 29 | 30 | 2020-12-22 31 | ---------- 32 | Fixed the previously mentioned issue. This is possible by having a new method in libnnc: `ccv_nnc_dynamic_graph_has_effect_to_tensor_variables`, which will answer questions such as whether a variable A can have an effect on variable B, therefore finding the ones with requiresGrad = true that have an effect on the object that you call `backward(to:)` on. 33 | 34 | Added a unit test that simulates implementing the `Dense` model by using tensor variables directly, and validated that the test works. 35 | 36 | 37 | 2020-12-20 38 | ---------- 39 | One part that is not exposed is how requiresGrad works in coordination with `backward(to:)`. The underlying `ccv_nnc_dynamic_graph_backward` method takes input tensors and gives out gradients for the inputs. Thus, if we want to compute grads for parameters that are not in `backward(to:)`, we need to somehow get those tensors from DynamicGraph. 40 | 41 | Moreover, because we also need to know whether that tracked tensor is relevant at all, we need to expose a new method to query such information from `ccv_nnc_dynamic_graph_t`. This requires us to design a new API. 42 | 43 | This is not a big problem for my current use, but it will be if I want to release this.
44 | -------------------------------------------------------------------------------- /bazel/setup_clang.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BAZELRC_FILE="${BAZELRC_FILE:-$(bazel info workspace)/clang.bazelrc}" 4 | 5 | LLVM_PREFIX=$1 6 | 7 | if [[ ! -e "${LLVM_PREFIX}/bin/llvm-config" ]]; then 8 | echo "Error: cannot find llvm-config in ${LLVM_PREFIX}." 9 | exit 1 10 | fi 11 | 12 | export PATH="$(${LLVM_PREFIX}/bin/llvm-config --bindir):${PATH}" 13 | 14 | RT_LIBRARY_PATH="$(dirname $(find $(llvm-config --libdir) -name libclang_rt.ubsan_standalone_cxx-x86_64.a | head -1))" 15 | 16 | echo "# Generated file, do not edit. If you want to disable clang, just delete this file. 17 | build:clang --action_env='PATH=${PATH}' 18 | build:clang --action_env=CC=clang 19 | build:clang --action_env=CXX=clang++ 20 | build:clang --action_env='LLVM_CONFIG=${LLVM_PREFIX}/bin/llvm-config' 21 | build:clang --repo_env='LLVM_CONFIG=${LLVM_PREFIX}/bin/llvm-config' 22 | build:clang --linkopt='-L$(llvm-config --libdir)' 23 | build:clang --linkopt='-Wl,-rpath,$(llvm-config --libdir)' 24 | 25 | build:clang-asan --action_env=ENVOY_UBSAN_VPTR=1 26 | build:clang-asan --copt=-fsanitize=vptr,function 27 | build:clang-asan --linkopt=-fsanitize=vptr,function 28 | build:clang-asan --linkopt='-L${RT_LIBRARY_PATH}' 29 | build:clang-asan --linkopt=-l:libclang_rt.ubsan_standalone-x86_64.a 30 | build:clang-asan --linkopt=-l:libclang_rt.ubsan_standalone_cxx-x86_64.a 31 | " > ${BAZELRC_FILE} 32 | 33 | -------------------------------------------------------------------------------- /deps.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository") 2 | 3 | def _maybe(repo_rule, name, **kwargs): 4 | """Executes the given repository rule if it hasn't been executed already. 
5 | Args: 6 | repo_rule: The repository rule to be executed (e.g., `http_archive`.) 7 | name: The name of the repository to be defined by the rule. 8 | **kwargs: Additional arguments passed directly to the repository rule. 9 | """ 10 | if not native.existing_rule(name): 11 | repo_rule(name = name, **kwargs) 12 | 13 | def s4nnc_deps(): 14 | """Loads common dependencies needed to compile the s4nnc library.""" 15 | 16 | _maybe( 17 | git_repository, 18 | name = "ccv", 19 | remote = "https://github.com/liuliu/ccv.git", 20 | commit = "52e47ec6ba0afbc3be5a4b315f722cd43d290423", 21 | shallow_since = "1748755843 -0400", 22 | ) 23 | 24 | _maybe( 25 | new_git_repository, 26 | name = "PythonKit", 27 | remote = "https://github.com/liuliu/PythonKit.git", 28 | commit = "fbf22756c91d89b0f2e39a89b690aaa538cf9b03", 29 | shallow_since = "1664547636 -0400", 30 | build_file = "@s4nnc//:external/PythonKit.BUILD", 31 | ) 32 | 33 | _maybe( 34 | new_git_repository, 35 | name = "fpzip", 36 | commit = "79aa1b1bd5a0b9497b8ad4352d8561ab17113cdf", 37 | remote = "https://github.com/LLNL/fpzip.git", 38 | shallow_since = "1591380432 -0700", 39 | build_file = "@s4nnc//:external/fpzip.BUILD", 40 | ) 41 | 42 | def s4nnc_extra_deps(): 43 | """Loads common dependencies needed to compile gym and tensorboard.""" 44 | 45 | _maybe( 46 | new_git_repository, 47 | name = "SwiftNumerics", 48 | remote = "https://github.com/apple/swift-numerics.git", 49 | commit = "4a2cbc186b1f8cbbc1ace12cef43d65784b2559e", 50 | shallow_since = "1605460976 -0500", 51 | build_file = "@s4nnc//:external/swift-numerics.BUILD", 52 | ) 53 | 54 | _maybe( 55 | new_git_repository, 56 | name = "SwiftAlgorithms", 57 | remote = "https://github.com/apple/swift-algorithms.git", 58 | commit = "195e0316d7ba71e134d0f6c677f64b4db6160c46", 59 | shallow_since = "1645643239 -0600", 60 | build_file = "@s4nnc//:external/swift-algorithms.BUILD", 61 | ) 62 | 63 | _maybe( 64 | new_git_repository, 65 | name = "SwiftSystem", 66 | build_file = 
"@s4nnc//:external/swift-system.BUILD", 67 | commit = "fbd61a676d79cbde05cd4fda3cc46e94d6b8f0eb", 68 | remote = "https://github.com/apple/swift-system.git", 69 | shallow_since = "1729316385 -0700", 70 | ) 71 | 72 | _maybe( 73 | new_git_repository, 74 | name = "SwiftProtobuf", 75 | build_file = "@s4nnc//:external/swift-protobuf.BUILD", 76 | commit = "d57a5aecf24a25b32ec4a74be2f5d0a995a47c4b", 77 | remote = "https://github.com/apple/swift-protobuf.git", 78 | shallow_since = "1720448759 -0400", 79 | ) 80 | 81 | _maybe( 82 | git_repository, 83 | name = "swift-jupyter", 84 | commit = "22bdd9758c9070a1de38c8538b34b4cc9ec559c0", 85 | remote = "https://github.com/liuliu/swift-jupyter.git", 86 | shallow_since = "1659044971 -0400", 87 | ) 88 | 89 | _maybe( 90 | new_git_repository, 91 | name = "swift-atomics", 92 | build_file = "@s4nnc//:external/swift-atomics.BUILD", 93 | commit = "088df27f0683f2b458021ebf04873174b91ae597", 94 | remote = "https://github.com/apple/swift-atomics.git", 95 | shallow_since = "1649274362 -0700", 96 | ) 97 | 98 | _maybe( 99 | new_git_repository, 100 | name = "SwiftNIO", 101 | build_file = "@s4nnc//:external/swift-nio.BUILD", 102 | commit = "48916a49afedec69275b70893c773261fdd2cfde", 103 | remote = "https://github.com/apple/swift-nio.git", 104 | shallow_since = "1657195654 +0100", 105 | ) 106 | -------------------------------------------------------------------------------- /examples/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_binary") 2 | 3 | swift_binary( 4 | name = "imdb", 5 | srcs = ["imdb/main.swift"], 6 | deps = [ 7 | "//nnc", 8 | ], 9 | ) 10 | 11 | swift_binary( 12 | name = "minrf", 13 | srcs = ["minrf/main.swift"], 14 | deps = [ 15 | "//nnc", 16 | "//tensorboard", 17 | ], 18 | ) 19 | 20 | swift_binary( 21 | name = "cifar-10", 22 | srcs = ["cifar-10/main.swift"], 23 | deps = [ 24 | "//nnc", 25 | ], 26 | ) 27 | 28 | swift_binary( 29 | 
name = "pandas", 30 | srcs = ["pandas/main.swift"], 31 | deps = [ 32 | "//nnc", 33 | "@PythonKit", 34 | ], 35 | ) 36 | 37 | swift_binary( 38 | name = "dqn", 39 | srcs = ["dqn/main.swift"], 40 | deps = [ 41 | "//nnc", 42 | "//nnc:nnc_python", 43 | "@PythonKit", 44 | "@SwiftAlgorithms//:Algorithms", 45 | ], 46 | ) 47 | 48 | swift_binary( 49 | name = "ddpg", 50 | srcs = ["ddpg/main.swift"], 51 | deps = [ 52 | "//nnc", 53 | "//nnc:nnc_python", 54 | "@PythonKit", 55 | "@SwiftAlgorithms//:Algorithms", 56 | "@SwiftNumerics//:Numerics", 57 | ], 58 | ) 59 | 60 | swift_binary( 61 | name = "td3", 62 | srcs = ["td3/main.swift"], 63 | deps = [ 64 | "//nnc", 65 | "//nnc:nnc_python", 66 | "@PythonKit", 67 | "@SwiftAlgorithms//:Algorithms", 68 | "@SwiftNumerics//:Numerics", 69 | ], 70 | ) 71 | 72 | swift_binary( 73 | name = "ppo", 74 | srcs = ["ppo/main.swift"], 75 | deps = [ 76 | "//gym", 77 | "//nnc", 78 | "//nnc:nnc_python", 79 | "//tensorboard", 80 | "@PythonKit", 81 | "@SwiftAlgorithms//:Algorithms", 82 | "@SwiftNumerics//:Numerics", 83 | ], 84 | ) 85 | 86 | swift_binary( 87 | name = "py_sf_ppo", # Python, Single-File, PPO. This is kept as is to check against any changes in PPO. 
88 | srcs = ["py_sf_ppo/main.swift"], 89 | deps = [ 90 | "//gym", 91 | "//nnc", 92 | "//nnc:nnc_python", 93 | "@PythonKit", 94 | "@SwiftAlgorithms//:Algorithms", 95 | "@SwiftNumerics//:Numerics", 96 | ], 97 | ) 98 | 99 | swift_binary( 100 | name = "sac", 101 | srcs = ["sac/main.swift"], 102 | deps = [ 103 | "//gym", 104 | "//nnc", 105 | "//nnc:nnc_python", 106 | "//tensorboard", 107 | "@PythonKit", 108 | "@SwiftAlgorithms//:Algorithms", 109 | "@SwiftNumerics//:Numerics", 110 | ], 111 | ) 112 | 113 | swift_binary( 114 | name = "random", 115 | srcs = ["random/main.swift"], 116 | deps = [ 117 | "//gym", 118 | "//gym:gym_video", 119 | "//nnc", 120 | "//nnc:nnc_python", 121 | "//tensorboard", 122 | "@PythonKit", 123 | "@SwiftAlgorithms//:Algorithms", 124 | "@SwiftNumerics//:Numerics", 125 | ], 126 | ) 127 | -------------------------------------------------------------------------------- /examples/pandas/main.swift: -------------------------------------------------------------------------------- 1 | import NNC 2 | import PythonKit 3 | 4 | let sys = Python.import("sys") 5 | let result: PythonObject = (sys.version_info.major == 3) 6 | if result == true { 7 | print(sys.version_info) 8 | } 9 | 10 | let gc = Python.import("gc") 11 | 12 | let y = PythonObject(10) 13 | let lambda1 = PythonFunction { x in x * y } 14 | let lambda2 = PythonFunction { (x: PythonObject) -> PythonConvertible in x + y } 15 | print(Python.list(Python.map(lambda1, [10, 12, 14]))) 16 | print(Python.list(Python.map(lambda1, [2, 3, 4]))) 17 | print(Python.list(Python.map(lambda2, [2, 3, 4]))) 18 | print(gc.get_stats()) 19 | -------------------------------------------------------------------------------- /examples/ppo/python.swift: -------------------------------------------------------------------------------- 1 | import PythonKit 2 | 3 | struct PythonEnv: PythonConvertible { 4 | public private(set) var pythonObject: PythonObject 5 | public init(pythonObject: PythonObject) { 6 | self.pythonObject = 
pythonObject 7 | } 8 | } 9 | 10 | extension PythonEnv: Env { 11 | public typealias ActType = Tensor 12 | public typealias ObsType = Tensor 13 | public typealias RewardType = Float 14 | public typealias DoneType = Bool 15 | public func step(action: ActType) -> (ObsType, RewardType, DoneType, [String: Any]) { 16 | let (obs, reward, done, info) = pythonObject.step(action).tuple4 17 | var newInfo = [String: Any]() 18 | newInfo["TimeLimit.truncated"] = info.checking["TimeLimit.truncated"].flatMap { Bool($0) } 19 | return (try! Tensor(numpy: obs), Float(reward)!, Bool(done)!, newInfo) 20 | } 21 | public func reset(seed: Int?) -> (ObsType, [String: Any]) { 22 | let obs = pythonObject.reset(seed: seed) 23 | return (try! Tensor(numpy: obs), [:]) 24 | } 25 | public static var rewardThreshold: Float { 6_000 } 26 | public static var actionSpace: [ClosedRange] { fatalError() } 27 | public static var stateSize: Int { fatalError() } 28 | } 29 | -------------------------------------------------------------------------------- /examples/random/main.swift: -------------------------------------------------------------------------------- 1 | import Algorithms 2 | import Foundation 3 | import Gym 4 | import GymVideo 5 | import NNC 6 | import NNCPythonConversion 7 | import Numerics 8 | import TensorBoard 9 | 10 | typealias TargetEnv = Walker2D 11 | 12 | let output_dim = TargetEnv.actionSpace.count 13 | let action_range: Float = TargetEnv.actionSpace[0].upperBound 14 | 15 | let graph = DynamicGraph() 16 | var sfmt = SFMT(seed: 10) 17 | 18 | DynamicGraph.setSeed(0) 19 | var testEnv = TimeLimit(env: try TargetEnv(), maxEpisodeSteps: 1_000) 20 | let _ = testEnv.reset(seed: 180) 21 | let video = MuJoCoVideo( 22 | env: testEnv, filePath: "/home/liu/workspace/s4nnc/examples/random/random.mp4") 23 | var episodes = 0 24 | while episodes < 10 { 25 | let act = graph.variable(Tensor(.GPU(0), .C(output_dim))) 26 | act.randn(std: 1, mean: 0) 27 | act.clamp(-1...1) 28 | let act_v = (action_range * 
act).rawValue.toCPU() 29 | let (_, _, done, _) = testEnv.step(action: Tensor(from: act_v)) 30 | if done { 31 | let _ = testEnv.reset() 32 | episodes += 1 33 | } 34 | video.render() 35 | } 36 | 37 | video.close() 38 | -------------------------------------------------------------------------------- /external/PythonKit.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | swift_library( 8 | name = "PythonKit", 9 | srcs = glob([ 10 | "PythonKit/**/*.swift", 11 | ]), 12 | module_name = "PythonKit", 13 | ) 14 | -------------------------------------------------------------------------------- /external/fpzip.BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | ) 4 | 5 | cc_library( 6 | name = "fpzip", 7 | srcs = glob([ 8 | "src/**/*.cpp", 9 | "src/**/*.h", 10 | "src/**/*.inl", 11 | ]), 12 | hdrs = ["include/fpzip.h"], 13 | defines = ["FPZIP_FP=FPZIP_FP_FAST"], 14 | includes = ["include"], 15 | local_defines = ["FPZIP_BLOCK_SIZE=0x1000"], 16 | tags = ["swift_module=C_fpzip"], 17 | ) 18 | -------------------------------------------------------------------------------- /external/swift-algorithms.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | swift_library( 4 | name = "Algorithms", 5 | srcs = glob([ 6 | "Sources/Algorithms/**/*.swift", 7 | ]), 8 | module_name = "Algorithms", 9 | visibility = ["//visibility:public"], 10 | deps = [ 11 | "@SwiftNumerics//:RealModule", 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /external/swift-argument-parser.BUILD: 
-------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | swift_library( 4 | name = "ArgumentParserToolInfo", 5 | srcs = glob([ 6 | "Sources/ArgumentParserToolInfo/**/*.swift", 7 | ]), 8 | module_name = "ArgumentParserToolInfo", 9 | ) 10 | 11 | swift_library( 12 | name = "ArgumentParser", 13 | srcs = glob([ 14 | "Sources/ArgumentParser/**/*.swift", 15 | ]), 16 | module_name = "ArgumentParser", 17 | visibility = ["//visibility:public"], 18 | deps = [ 19 | ":ArgumentParserToolInfo", 20 | ], 21 | ) 22 | -------------------------------------------------------------------------------- /external/swift-atomics.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | cc_library( 4 | name = "_AtomicsShims", 5 | srcs = ["Sources/_AtomicsShims/src/_AtomicsShims.c"], 6 | hdrs = ["Sources/_AtomicsShims/include/_AtomicsShims.h"], 7 | includes = [ 8 | "Sources/_AtomicsShims/include/", 9 | ], 10 | tags = ["swift_module=_AtomicsShims"], 11 | ) 12 | 13 | swift_library( 14 | name = "SwiftAtomics", 15 | srcs = glob([ 16 | "Sources/Atomics/**/*.swift", 17 | ]), 18 | module_name = "Atomics", 19 | visibility = ["//visibility:public"], 20 | deps = [ 21 | ":_AtomicsShims", 22 | ], 23 | alwayslink = True, 24 | ) 25 | -------------------------------------------------------------------------------- /external/swift-format.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_binary", "swift_library") 2 | 3 | swift_library( 4 | name = "SwiftFormatConfiguration", 5 | srcs = glob([ 6 | "Sources/SwiftFormatConfiguration/**/*.swift", 7 | ]), 8 | module_name = "SwiftFormatConfiguration", 9 | ) 10 | 11 | swift_library( 12 | name = "SwiftFormatCore", 13 | srcs = glob([ 14 | 
"Sources/SwiftFormatCore/**/*.swift", 15 | ]), 16 | module_name = "SwiftFormatCore", 17 | deps = [ 18 | ":SwiftFormatConfiguration", 19 | "@SwiftSyntax", 20 | "@SwiftSyntax//:SwiftOperators", 21 | ], 22 | ) 23 | 24 | swift_library( 25 | name = "SwiftFormatRules", 26 | srcs = glob([ 27 | "Sources/SwiftFormatRules/**/*.swift", 28 | ]), 29 | module_name = "SwiftFormatRules", 30 | deps = [ 31 | ":SwiftFormatCore", 32 | ], 33 | ) 34 | 35 | swift_library( 36 | name = "SwiftFormatPrettyPrint", 37 | srcs = glob([ 38 | "Sources/SwiftFormatPrettyPrint/**/*.swift", 39 | ]), 40 | module_name = "SwiftFormatPrettyPrint", 41 | deps = [ 42 | ":SwiftFormatCore", 43 | ], 44 | ) 45 | 46 | swift_library( 47 | name = "SwiftFormatWhitespaceLinter", 48 | srcs = glob([ 49 | "Sources/SwiftFormatWhitespaceLinter/**/*.swift", 50 | ]), 51 | module_name = "SwiftFormatWhitespaceLinter", 52 | deps = [ 53 | ":SwiftFormatCore", 54 | ], 55 | ) 56 | 57 | swift_library( 58 | name = "SwiftFormat", 59 | srcs = glob([ 60 | "Sources/SwiftFormat/**/*.swift", 61 | ]), 62 | module_name = "SwiftFormat", 63 | deps = [ 64 | ":SwiftFormatCore", 65 | ":SwiftFormatPrettyPrint", 66 | ":SwiftFormatRules", 67 | ":SwiftFormatWhitespaceLinter", 68 | "@SwiftSyntax//:SwiftParserDiagnostics", 69 | "@SwiftSyntax//:SwiftSyntaxParser", 70 | ], 71 | ) 72 | 73 | swift_binary( 74 | name = "swift-format", 75 | srcs = glob([ 76 | "Sources/swift-format/**/*.swift", 77 | ]), 78 | visibility = ["//visibility:public"], 79 | deps = [ 80 | ":SwiftFormat", 81 | "@SwiftArgumentParser//:ArgumentParser", 82 | "@SwiftToolsSupportCore//:TSCBasic", 83 | ], 84 | ) 85 | -------------------------------------------------------------------------------- /external/swift-nio.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_binary", "swift_library") 2 | 3 | cc_library( 4 | name = "CNIOAtomics", 5 | srcs = glob([ 6 | "Sources/CNIOAtomics/src/*.c", 7 | 
"Sources/CNIOAtomics/src/*.h", 8 | ]), 9 | hdrs = glob(["Sources/CNIOAtomics/include/*.h"]), 10 | includes = [ 11 | "Sources/CNIOAtomics/include/", 12 | ], 13 | tags = ["swift_module=CNIOAtomics"], 14 | ) 15 | 16 | cc_library( 17 | name = "CNIOSHA1", 18 | srcs = glob([ 19 | "Sources/CNIOSHA1/*.c", 20 | "Sources/CNIOSHA1/*.h", 21 | ]), 22 | hdrs = glob(["Sources/CNIOSHA1/include/*.h"]), 23 | includes = [ 24 | "Sources/CNIOSHA1/include/", 25 | ], 26 | tags = ["swift_module=CNIOSHA1"], 27 | ) 28 | 29 | cc_library( 30 | name = "CNIOLinux", 31 | srcs = glob([ 32 | "Sources/CNIOLinux/*.c", 33 | "Sources/CNIOLinux/*.h", 34 | ]), 35 | hdrs = glob(["Sources/CNIOLinux/include/*.h"]), 36 | includes = [ 37 | "Sources/CNIOLinux/include/", 38 | ], 39 | tags = ["swift_module=CNIOLinux"], 40 | ) 41 | 42 | cc_library( 43 | name = "CNIODarwin", 44 | srcs = glob([ 45 | "Sources/CNIODarwin/*.c", 46 | "Sources/CNIODarwin/*.h", 47 | ]), 48 | hdrs = glob(["Sources/CNIODarwin/include/*.h"]), 49 | defines = ["__APPLE_USE_RFC_3542"], 50 | includes = [ 51 | "Sources/CNIODarwin/include/", 52 | ], 53 | tags = ["swift_module=CNIODarwin"], 54 | ) 55 | 56 | cc_library( 57 | name = "CNIOWindows", 58 | srcs = glob([ 59 | "Sources/CNIOWindows/*.c", 60 | "Sources/CNIOWindows/*.h", 61 | ]), 62 | hdrs = glob(["Sources/CNIOWindows/include/*.h"]), 63 | includes = [ 64 | "Sources/CNIOWindows/include/", 65 | ], 66 | tags = ["swift_module=CNIOWindows"], 67 | ) 68 | 69 | swift_library( 70 | name = "NIOConcurrencyHelpers", 71 | srcs = glob([ 72 | "Sources/NIOConcurrencyHelpers/**/*.swift", 73 | ]), 74 | module_name = "NIOConcurrencyHelpers", 75 | visibility = ["//visibility:public"], 76 | deps = [":CNIOAtomics"], 77 | ) 78 | 79 | swift_library( 80 | name = "NIOCore", 81 | srcs = glob([ 82 | "Sources/NIOCore/**/*.swift", 83 | ]), 84 | module_name = "NIOCore", 85 | visibility = ["//visibility:public"], 86 | deps = [ 87 | ":CNIOLinux", 88 | ":CNIOWindows", 89 | ":NIOConcurrencyHelpers", 90 | ], 91 | ) 92 | 93 | 
swift_library( 94 | name = "_NIODataStructures", 95 | srcs = glob([ 96 | "Sources/_NIODataStructures/**/*.swift", 97 | ]), 98 | module_name = "_NIODataStructures", 99 | ) 100 | 101 | swift_library( 102 | name = "NIOEmbedded", 103 | srcs = glob([ 104 | "Sources/NIOEmbedded/**/*.swift", 105 | ]), 106 | module_name = "NIOEmbedded", 107 | visibility = ["//visibility:public"], 108 | deps = [ 109 | ":NIOConcurrencyHelpers", 110 | ":NIOCore", 111 | ":_NIODataStructures", 112 | "@swift-atomics//:SwiftAtomics", 113 | ], 114 | ) 115 | 116 | swift_library( 117 | name = "NIOPosix", 118 | srcs = glob([ 119 | "Sources/NIOPosix/**/*.swift", 120 | ]), 121 | module_name = "NIOPosix", 122 | visibility = ["//visibility:public"], 123 | deps = [ 124 | ":CNIODarwin", 125 | ":CNIOLinux", 126 | ":CNIOWindows", 127 | ":NIOConcurrencyHelpers", 128 | ":NIOCore", 129 | ":_NIODataStructures", 130 | "@swift-atomics//:SwiftAtomics", 131 | ], 132 | ) 133 | 134 | swift_library( 135 | name = "NIO", 136 | srcs = glob([ 137 | "Sources/NIO/**/*.swift", 138 | ]), 139 | module_name = "NIO", 140 | visibility = ["//visibility:public"], 141 | deps = [ 142 | ":NIOCore", 143 | ":NIOEmbedded", 144 | ":NIOPosix", 145 | ], 146 | ) 147 | 148 | swift_library( 149 | name = "NIOFoundationCompat", 150 | srcs = glob([ 151 | "Sources/NIOFoundationCompat/**/*.swift", 152 | ]), 153 | module_name = "NIOFoundationCompat", 154 | visibility = ["//visibility:public"], 155 | deps = [ 156 | ":NIO", 157 | ":NIOCore", 158 | ], 159 | ) 160 | 161 | cc_library( 162 | name = "CNIOHTTPParser", 163 | srcs = glob([ 164 | "Sources/CNIOHTTPParser/*.c", 165 | "Sources/CNIOHTTPParser/*.h", 166 | ]), 167 | hdrs = glob(["Sources/CNIOHTTPParser/include/*.h"]), 168 | includes = [ 169 | "Sources/CNIOHTTPParser/include/",  # review fix: was "Sources/CNIOWindows/include/", a copy/paste from the CNIOWindows target above; hdrs already globs this directory 170 | ], 171 | tags = ["swift_module=CNIOHTTPParser"], 172 | ) 173 | 174 | swift_library( 175 | name = "NIOHTTP1", 176 | srcs = glob([ 177 | "Sources/NIOHTTP1/**/*.swift", 178 | ]), 179 | module_name = "NIOHTTP1", 180 | visibility
= ["//visibility:public"], 181 | deps = [ 182 | ":CNIOHTTPParser", 183 | ":NIO", 184 | ":NIOConcurrencyHelpers", 185 | ":NIOCore", 186 | ], 187 | ) 188 | 189 | swift_library( 190 | name = "NIOTLS", 191 | srcs = glob([ 192 | "Sources/NIOTLS/**/*.swift", 193 | ]), 194 | module_name = "NIOTLS", 195 | visibility = ["//visibility:public"], 196 | deps = [ 197 | ":NIO", 198 | ":NIOCore", 199 | ], 200 | ) 201 | 202 | swift_library( 203 | name = "NIOWebSocket", 204 | srcs = glob([ 205 | "Sources/NIOWebSocket/**/*.swift", 206 | ]), 207 | module_name = "NIOWebSocket", 208 | visibility = ["//visibility:public"], 209 | deps = [ 210 | ":CNIOSHA1", 211 | ":NIO", 212 | ":NIOCore", 213 | ":NIOHTTP1", 214 | ], 215 | ) 216 | 217 | swift_binary( 218 | name = "NIOHTTP1Server", 219 | srcs = glob([ 220 | "Sources/NIOHTTP1Server/**/*.swift", 221 | ]), 222 | deps = [ 223 | ":NIOConcurrencyHelpers", 224 | ":NIOCore", 225 | ":NIOHTTP1", 226 | ":NIOPosix", 227 | ], 228 | ) 229 | -------------------------------------------------------------------------------- /external/swift-numerics.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | cc_library( 4 | name = "_NumericsShims", 5 | srcs = ["Sources/_NumericsShims/_NumericsShims.c"], 6 | hdrs = ["Sources/_NumericsShims/include/_NumericsShims.h"], 7 | includes = [ 8 | "Sources/_NumericsShims/include/", 9 | ], 10 | tags = ["swift_module=_NumericsShims"], 11 | ) 12 | 13 | swift_library( 14 | name = "RealModule", 15 | srcs = glob([ 16 | "Sources/RealModule/**/*.swift", 17 | ]), 18 | module_name = "RealModule", 19 | visibility = ["//visibility:public"], 20 | deps = [ 21 | ":_NumericsShims", 22 | ], 23 | ) 24 | 25 | swift_library( 26 | name = "ComplexModule", 27 | srcs = glob([ 28 | "Sources/ComplexModule/**/*.swift", 29 | ]), 30 | module_name = "ComplexModule", 31 | visibility = ["//visibility:public"], 32 | deps = [ 33 | ":RealModule", 34 
| ], 35 | ) 36 | 37 | swift_library( 38 | name = "Numerics", 39 | srcs = glob([ 40 | "Sources/Numerics/**/*.swift", 41 | ]), 42 | module_name = "Numerics", 43 | visibility = ["//visibility:public"], 44 | deps = [ 45 | ":ComplexModule", 46 | ":RealModule", 47 | ], 48 | ) 49 | -------------------------------------------------------------------------------- /external/swift-protobuf.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_binary", "swift_library") 2 | 3 | swift_library( 4 | name = "SwiftProtobuf", 5 | srcs = glob([ 6 | "Sources/SwiftProtobuf/**/*.swift", 7 | ]), 8 | module_name = "SwiftProtobuf", 9 | visibility = ["//visibility:public"], 10 | deps = [], 11 | ) 12 | 13 | swift_library( 14 | name = "SwiftProtobufPluginLibrary", 15 | srcs = glob([ 16 | "Sources/SwiftProtobufPluginLibrary/**/*.swift", 17 | ]), 18 | module_name = "SwiftProtobufPluginLibrary", 19 | visibility = ["//visibility:public"], 20 | deps = [":SwiftProtobuf"], 21 | ) 22 | 23 | swift_binary( 24 | name = "protoc-gen-swift", 25 | srcs = glob([ 26 | "Sources/protoc-gen-swift/**/*.swift", 27 | ]), 28 | visibility = ["//visibility:public"], 29 | deps = [ 30 | ":SwiftProtobuf", 31 | ":SwiftProtobufPluginLibrary", 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /external/swift-syntax.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")  # review fix: removed a second, byte-identical load of "swift_library" that bound the same symbol twice 2 | 3 | 4 | swift_library( 5 | name = "SwiftSyntax", 6 | srcs = glob([ 7 | "Sources/SwiftSyntax/**/*.swift", 8 | ]), 9 | module_name = "SwiftSyntax", 10 | visibility = ["//visibility:public"], 11 | deps = [], 12 | ) 13 | 14 | swift_library( 15 | name = "SwiftBasicFormat", 16 | srcs = glob([ 17 | "Sources/SwiftBasicFormat/**/*.swift", 18 | ]), 19 |
module_name = "SwiftBasicFormat", 20 | visibility = ["//visibility:public"], 21 | deps = [ 22 | ":SwiftSyntax", 23 | ], 24 | ) 25 | 26 | swift_library( 27 | name = "SwiftDiagnostics", 28 | srcs = glob([ 29 | "Sources/SwiftDiagnostics/**/*.swift", 30 | ]), 31 | module_name = "SwiftDiagnostics", 32 | visibility = ["//visibility:public"], 33 | deps = [ 34 | ":SwiftSyntax", 35 | ], 36 | ) 37 | 38 | swift_library( 39 | name = "SwiftParser", 40 | srcs = glob([ 41 | "Sources/SwiftParser/**/*.swift", 42 | ]), 43 | module_name = "SwiftParser", 44 | visibility = ["//visibility:public"], 45 | deps = [ 46 | ":SwiftDiagnostics", 47 | ":SwiftSyntax", 48 | ], 49 | ) 50 | 51 | swift_library( 52 | name = "SwiftParserDiagnostics", 53 | srcs = glob([ 54 | "Sources/SwiftParserDiagnostics/**/*.swift", 55 | ]), 56 | module_name = "SwiftParserDiagnostics", 57 | visibility = ["//visibility:public"], 58 | deps = [ 59 | ":SwiftBasicFormat", 60 | ":SwiftDiagnostics", 61 | ":SwiftParser", 62 | ":SwiftSyntax", 63 | ], 64 | ) 65 | 66 | swift_library( 67 | name = "SwiftSyntaxBuilder", 68 | srcs = glob([ 69 | "Sources/SwiftSyntaxBuilder/**/*.swift", 70 | ]), 71 | module_name = "SwiftSyntaxBuilder", 72 | visibility = ["//visibility:public"], 73 | deps = [ 74 | ":SwiftBasicFormat", 75 | ":SwiftParser", 76 | ":SwiftParserDiagnostics", 77 | ":SwiftSyntax", 78 | ], 79 | ) 80 | 81 | swift_library( 82 | name = "SwiftSyntaxParser", 83 | srcs = glob([ 84 | "Sources/SwiftSyntaxParser/**/*.swift", 85 | ]), 86 | module_name = "SwiftSyntaxParser", 87 | visibility = ["//visibility:public"], 88 | deps = [ 89 | ":SwiftParser", 90 | ":SwiftSyntax", 91 | ], 92 | ) 93 | 94 | swift_library( 95 | name = "SwiftOperators", 96 | srcs = glob([ 97 | "Sources/SwiftOperators/**/*.swift", 98 | ]), 99 | module_name = "SwiftOperators", 100 | visibility = ["//visibility:public"], 101 | deps = [ 102 | ":SwiftDiagnostics", 103 | ":SwiftParser", 104 | ":SwiftSyntax", 105 | ], 106 | ) 107 | 
-------------------------------------------------------------------------------- /external/swift-system.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | load("@bazel_skylib//lib:selects.bzl", "selects") 3 | 4 | cc_library( 5 | name = "CSystem", 6 | srcs = ["Sources/CSystem/shims.c"], 7 | hdrs = glob([ 8 | "Sources/CSystem/include/*.h", 9 | ]), 10 | includes = [ 11 | "Sources/CSystem/include/", 12 | ], 13 | tags = ["swift_module=CSystem"], 14 | ) 15 | 16 | config_setting( 17 | name = "macos_build", 18 | constraint_values = [ 19 | "@platforms//os:osx", 20 | ], 21 | ) 22 | 23 | config_setting( 24 | name = "ios_build", 25 | constraint_values = [ 26 | "@platforms//os:ios", 27 | ], 28 | ) 29 | 30 | selects.config_setting_group( 31 | name = "ios_or_macos_build", 32 | match_any = [ 33 | ":macos_build", 34 | ":ios_build", 35 | ], 36 | ) 37 | 38 | swift_library( 39 | name = "SystemPackage", 40 | srcs = glob([ 41 | "Sources/System/**/*.swift", 42 | ]), 43 | defines = [ 44 | "_CRT_SECURE_NO_WARNINGS", 45 | "SYSTEM_PACKAGE", 46 | ] + select({ 47 | ":ios_or_macos_build": ["SYSTEM_PACKAGE_DARWIN"], 48 | "//conditions:default": [], 49 | }), 50 | module_name = "SystemPackage", 51 | visibility = ["//visibility:public"], 52 | deps = [ 53 | ":CSystem", 54 | ], 55 | ) 56 | -------------------------------------------------------------------------------- /external/swift-tools-support-core.BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | config_setting( 4 | name = "linux_build", 5 | constraint_values = [ 6 | "@platforms//os:linux", 7 | ], 8 | ) 9 | 10 | cc_library( 11 | name = "TSCclibc", 12 | srcs = glob(["Sources/TSCclibc/*.c"]), 13 | hdrs = glob([ 14 | "Sources/TSCclibc/include/*.h", 15 | ]), 16 | includes = [ 17 | "Sources/TSCclibc/include/", 18 | ], 19 | 
local_defines = select({ 20 | ":linux_build": ["_GNU_SOURCE"], 21 | "//conditions:default": [], 22 | }), 23 | tags = ["swift_module=TSCclibc"], 24 | ) 25 | 26 | swift_library( 27 | name = "TSCLibc", 28 | srcs = glob([ 29 | "Sources/TSCLibc/**/*.swift", 30 | ]), 31 | module_name = "TSCLibc", 32 | deps = [], 33 | ) 34 | 35 | swift_library( 36 | name = "TSCBasic", 37 | srcs = glob([ 38 | "Sources/TSCBasic/**/*.swift", 39 | ]), 40 | module_name = "TSCBasic", 41 | visibility = ["//visibility:public"], 42 | deps = [ 43 | ":TSCLibc", 44 | ":TSCclibc", 45 | "@SwiftSystem//:SystemPackage", 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /gym/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | swift_library( 4 | name = "gym", 5 | srcs = glob( 6 | ["**/*.swift"], 7 | exclude = ["renders/MuJoCoVideo.swift"], 8 | ), 9 | data = glob(["assets/*.xml"]), 10 | module_name = "Gym", 11 | visibility = ["//visibility:public"], 12 | deps = [ 13 | "//nnc", 14 | "//nnc:nnc_mujoco", 15 | "@SwiftAlgorithms//:Algorithms", 16 | "@SwiftNIO//:NIOConcurrencyHelpers", 17 | "@SwiftNIO//:NIOCore", 18 | "@SwiftNIO//:NIOHTTP1", 19 | "@SwiftNIO//:NIOPosix", 20 | "@SwiftNIO//:NIOWebSocket", 21 | "@SwiftNumerics//:Numerics", 22 | "@ccv//lib:SFMT", 23 | "@ccv//lib:ccv", 24 | "@swift-jupyter//:JupyterDisplay", 25 | "@swift-mujoco//:swift-mujoco", 26 | ], 27 | ) 28 | 29 | cc_library( 30 | name = "C_ffmpeg", 31 | srcs = ["renders/ffmpeg_shim.c"], 32 | hdrs = ["renders/ffmpeg_shim.h"], 33 | linkopts = [ 34 | "-lavcodec", 35 | "-lswscale", 36 | "-lavformat", 37 | "-lavutil", 38 | ], 39 | tags = ["swift_module=C_ffmpeg"], 40 | ) 41 | 42 | swift_library( 43 | name = "gym_video", 44 | srcs = ["renders/MuJoCoVideo.swift"], 45 | module_name = "GymVideo", 46 | visibility = ["//visibility:public"], 47 | deps = [ 48 | ":C_ffmpeg", 49 | ":gym", 50 
| ], 51 | ) 52 | -------------------------------------------------------------------------------- /gym/Collector.swift: -------------------------------------------------------------------------------- 1 | import NNC 2 | 3 | /// The helper to collect data from the given Envs. 4 | /// 5 | /// The policy is intentional to be flexible as a simple closure. We collect both the action and state 6 | /// data the closure provides. There may be some type conversions for what the Env expects and what 7 | /// the policy provides, thus, need to specify both. To make this easier to use, the placeholder 8 | /// type would be useful: `Collector(envs: envs) { ... }` 9 | public struct Collector 10 | where EnvType.ObsType == Tensor, EnvType.ActType == Tensor { 11 | public typealias ObsType = Tensor 12 | public typealias ActType = Tensor 13 | var envs: [EnvType] 14 | var batch: [CollectedData] 15 | var finalizedBatch: [CollectedData] 16 | let policy: (_: ObsType) -> (ActType, StateType) 17 | public init(envs: [EnvType], policy: @escaping (_: ObsType) -> (ActType, StateType)) { 18 | self.envs = envs 19 | self.policy = policy 20 | batch = [] 21 | for i in 0.. 
{ 30 | public enum EnvState { 31 | case ready 32 | case terminated 33 | case truncated 34 | } 35 | public typealias ObsType = Tensor 36 | public typealias ActType = Tensor 37 | public var lastObservation: ObsType 38 | public var rewards: [Float] 39 | public var states: [StateType] 40 | public var episodeReward: Float 41 | public var episodeLength: Int 42 | public var envState: EnvState 43 | public init(lastObservation: ObsType) { 44 | self.lastObservation = lastObservation 45 | rewards = [] 46 | states = [] 47 | episodeReward = 0 48 | episodeLength = 0 49 | envState = .ready 50 | } 51 | mutating func reset(keepLastN: Int = 0) { 52 | guard keepLastN > 0 else { 53 | rewards.removeAll() 54 | states.removeAll() 55 | return 56 | } 57 | guard keepLastN < rewards.count else { return } 58 | rewards.removeFirst(rewards.count - keepLastN) 59 | states.removeFirst(states.count - keepLastN) 60 | } 61 | } 62 | 63 | extension Collector { 64 | public struct Statistics { 65 | public var episodeCount: Int 66 | public var stepCount: Int 67 | public var episodeReward: NumericalStatistics 68 | public var episodeLength: NumericalStatistics 69 | init( 70 | episodeCount: Int, stepCount: Int, episodeReward: NumericalStatistics, 71 | episodeLength: NumericalStatistics 72 | ) { 73 | self.episodeCount = episodeCount 74 | self.stepCount = stepCount 75 | self.episodeReward = episodeReward 76 | self.episodeLength = episodeLength 77 | } 78 | } 79 | 80 | public mutating func resetData(keepLastN: Int = 0) { 81 | for i in 0..] 
{ 99 | finalizedBatch + batch.filter { $0.rewards.count > 0 } 100 | } 101 | } 102 | 103 | extension Collector where EnvType.TerminatedType == Bool, EnvType.RewardType == Float { 104 | public mutating func collect(nStep: Int) -> Statistics { 105 | var episodeCount = 0 106 | var stepCount = 0 107 | var episodeRewards = [Float]() 108 | var episodeLengths = [Float]() 109 | while stepCount < nStep { 110 | for i in 0..(_ array: C) where C.Element == Float { 11 | if array.count > 0 { 12 | let mean = (array.reduce(0) { $0 + $1 }) / Float(array.count) 13 | self.mean = mean 14 | self.std = ((array.reduce(0) { $0 + ($1 - mean) * ($1 - mean) }) / Float(array.count)) 15 | .squareRoot() 16 | } else { 17 | mean = 0 18 | std = 0 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /gym/RunningMeanStd.swift: -------------------------------------------------------------------------------- 1 | import NNC 2 | 3 | public struct RunningMeanStd { 4 | public var mean: TensorElement 5 | public var variance: TensorElement 6 | public var count: Int 7 | public init(mean: TensorElement, variance: TensorElement) { 8 | self.mean = mean 9 | self.variance = variance 10 | count = 0 11 | } 12 | public mutating func update(_ data: [TensorElement]) { 13 | let graph = mean.graph 14 | precondition(data.count >= 1) 15 | graph.withNoGrad { 16 | let batchMean: TensorElement 17 | let batchVar: TensorElement 18 | if data.count > 1 { 19 | batchMean = 1 / Float(data.count) * Functional.sum(data) 20 | batchVar = 21 | 1 / Float(data.count) * Functional.sum(data.map { ($0 - batchMean) .* ($0 - batchMean) }) 22 | } else { 23 | batchMean = data[0] 24 | batchVar = graph.variable(like: batchMean) 25 | batchVar.full(0) 26 | } 27 | let delta = batchMean - mean 28 | let totalCount = count + data.count 29 | mean = mean + Float(data.count) / Float(totalCount) * delta 30 | let mA = Float(count) * variance 31 | let mB = Float(data.count) * batchVar 32 | let m2 = 
Functional.sum( 33 | mA, mB, Float(count) * Float(data.count) / Float(totalCount) * (delta .* delta)) 34 | variance = 1.0 / Float(totalCount) * m2 35 | count = totalCount 36 | } 37 | } 38 | public func norm(_ input: TensorElement) -> TensorElement { 39 | return (input - mean) ./ Functional.squareRoot(variance).clamped(1e-5...) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /gym/assets/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 82 | -------------------------------------------------------------------------------- /gym/assets/half_cheetah.xml: -------------------------------------------------------------------------------- 1 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 97 | -------------------------------------------------------------------------------- /gym/assets/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 49 | -------------------------------------------------------------------------------- /gym/assets/inverted_double_pendulum.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /gym/assets/inverted_pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /gym/assets/reacher.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 40 | -------------------------------------------------------------------------------- /gym/assets/swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 40 | 
-------------------------------------------------------------------------------- /gym/assets/walker2d.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 63 | -------------------------------------------------------------------------------- /gym/envs/Env.swift: -------------------------------------------------------------------------------- 1 | public protocol Env { 2 | associatedtype ObsType 3 | associatedtype ActType 4 | associatedtype RewardType 5 | associatedtype TerminatedType 6 | mutating func step(action: ActType) -> (ObsType, RewardType, TerminatedType, [String: Any]) 7 | mutating func reset(seed: Int?) -> (ObsType, [String: Any]) 8 | static var rewardThreshold: Float { get } 9 | static var stateSize: Int { get } 10 | static var actionSpace: [ClosedRange] { get } 11 | } 12 | 13 | extension Env { 14 | public mutating func reset() -> (ObsType, [String: Any]) { 15 | return reset(seed: nil) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /gym/envs/Noise.swift: -------------------------------------------------------------------------------- 1 | func noise(_ std: Double, using: inout T) -> Double { 2 | let u1 = Double.random(in: 0...1, using: &using) 3 | let u2 = Double.random(in: 0...1, using: &using) 4 | let mag = std * (-2.0 * .log(u1)).squareRoot() 5 | return mag * .cos(.pi * 2 * u2) 6 | } 7 | -------------------------------------------------------------------------------- /gym/envs/SFMT.swift: -------------------------------------------------------------------------------- 1 | import C_sfmt 2 | 3 | public struct SFMT: RandomNumberGenerator { 4 | private var state: sfmt_t 5 | public init(seed: UInt64) { 6 | state = sfmt_t() 7 | sfmt_init_gen_rand(&state, UInt32(truncatingIfNeeded: seed)) 8 | } 9 | public mutating func next() -> UInt64 { 10 | return sfmt_genrand_uint64(&state) 11 | } 12 | } 13 | 
-------------------------------------------------------------------------------- /gym/envs/TimeLimit.swift: -------------------------------------------------------------------------------- 1 | import MuJoCo 2 | 3 | public final class TimeLimit { 4 | private var env: EnvType 5 | private let maxEpisodeSteps: Int 6 | private var elapsedSteps: Int 7 | public init(env: EnvType, maxEpisodeSteps: Int) { 8 | self.env = env 9 | self.maxEpisodeSteps = maxEpisodeSteps 10 | self.elapsedSteps = 0 11 | } 12 | } 13 | 14 | extension TimeLimit: Env where EnvType.TerminatedType == Bool { 15 | public typealias ActType = EnvType.ActType 16 | public typealias ObsType = EnvType.ObsType 17 | public typealias RewardType = EnvType.RewardType 18 | public typealias TerminatedType = EnvType.TerminatedType 19 | 20 | public func step(action: ActType) -> (ObsType, RewardType, TerminatedType, [String: Any]) { 21 | let result = env.step(action: action) 22 | var (_, _, terminated, info) = result 23 | elapsedSteps += 1 24 | if elapsedSteps >= maxEpisodeSteps { 25 | // TimeLimit.truncated key may have been already set by the environment 26 | // do not overwrite it 27 | let episodeTruncated = !terminated || (info["TimeLimit.truncated", default: false] as! Bool) 28 | info["TimeLimit.truncated"] = episodeTruncated 29 | terminated = true 30 | } 31 | return (result.0, result.1, terminated, info) 32 | } 33 | 34 | public func reset(seed: Int?) 
-> (ObsType, [String: Any]) { 35 | elapsedSteps = 0 36 | return env.reset(seed: seed) 37 | } 38 | 39 | public static var rewardThreshold: Float { EnvType.rewardThreshold } 40 | public static var actionSpace: [ClosedRange] { EnvType.actionSpace } 41 | public static var stateSize: Int { EnvType.stateSize } 42 | } 43 | 44 | extension TimeLimit: MuJoCoEnv where EnvType: MuJoCoEnv { 45 | public var model: MjModel { env.model } 46 | public var data: MjData { 47 | get { env.data } 48 | set { env.data = newValue } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /gym/envs/VecEnv.swift: -------------------------------------------------------------------------------- 1 | import Dispatch 2 | import NNC 3 | 4 | public final class VecEnv 5 | where EnvType.ActType == Tensor, EnvType.ObsType == Tensor { 6 | private var envs = [EnvType]() 7 | private var terminated = [Bool]() 8 | private var obs = [EnvType.ObsType]() 9 | private var rewards = [EnvType.RewardType]() 10 | public init(count: Int, _ closure: (_: Int) throws -> EnvType) rethrows { 11 | precondition(count > 0) 12 | envs = [] 13 | terminated = [] 14 | for i in 0.. (ObsType, RewardType, TerminatedType, [String: Any]) { 27 | if obs.count == 0 || rewards.count == 0 { // If we never done obs, we need to build up the array, do it serially. The reason because I cannot construct the array with optional types easily. 28 | obs = [] 29 | rewards = [] 30 | for i in 0..( 46 | self.obs[0].kind, format: self.obs[0].format, 47 | shape: [envs.count, self.obs[0].shape[0]]) 48 | for i in 0.. (ObsType, [String: Any]) { 55 | if let seed = seed { 56 | var sfmt = SFMT(seed: UInt64(bitPattern: Int64(seed))) 57 | if obs.count == 0 { 58 | for i in 0..( 83 | self.obs[0].kind, format: self.obs[0].format, 84 | shape: [envs.count, self.obs[0].shape[0]]) 85 | for i in 0..] 
{ EnvType.actionSpace } 94 | public static var stateSize: Int { EnvType.stateSize } 95 | } 96 | -------------------------------------------------------------------------------- /gym/renders/MuJoCoViewer.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import JupyterDisplay 3 | import MuJoCo 4 | import NIOCore 5 | 6 | public protocol MuJoCoEnv { 7 | var model: MjModel { get } 8 | var data: MjData { get set } 9 | } 10 | 11 | public final class MuJoCoViewer { 12 | var env: EnvType 13 | var cpugenesis: Double = 0 14 | var simgenesis: Double = 0 15 | var simulate: Simulate 16 | let renderServer: HTTPRenderServer 17 | var httpChannel: Channel? = nil 18 | var renderCell: Int? = nil 19 | public init(env: EnvType, width: Int = 1280, height: Int = 720, title: String = "viewer") { 20 | self.env = env 21 | simulate = Simulate(width: width, height: height, title: title) 22 | simulate.use(model: self.env.model, data: &self.env.data) 23 | simulate.ui0 = false 24 | simulate.ui1 = false 25 | renderServer = HTTPRenderServer(simulate, maxWidth: width, maxHeight: height, canResize: false) 26 | } 27 | } 28 | 29 | extension MuJoCoViewer: Renderable { 30 | public func render() { 31 | if JupyterDisplay.isEnabled { 32 | // Check to see if we launched the render server yet. 33 | if httpChannel == nil { 34 | httpChannel = try? renderServer.bind(host: "0.0.0.0", port: .random(in: 10_000..<20_000)) 35 | .wait() 36 | } 37 | if JupyterDisplay.executionCount != renderCell { 38 | JupyterDisplay.display(html: renderServer.html) 39 | JupyterDisplay.flush() 40 | renderCell = JupyterDisplay.executionCount 41 | } 42 | } 43 | if cpugenesis == 0 { 44 | cpugenesis = GLContext.time 45 | simgenesis = env.data.time 46 | } 47 | let simsync = env.data.time 48 | simulate.yield() 49 | var cpusync = GLContext.time 50 | while simsync - simgenesis >= cpusync - cpugenesis { // wait until reality catches up with simulation. 
51 | simulate.yield() 52 | cpusync = GLContext.time 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /gym/renders/Renderable.swift: -------------------------------------------------------------------------------- 1 | public protocol Renderable { 2 | func render() 3 | } 4 | -------------------------------------------------------------------------------- /gym/renders/ffmpeg_shim.c: -------------------------------------------------------------------------------- 1 | #include "ffmpeg_shim.h" 2 | 3 | int averror_is_eagain_or_eof(int ret) 4 | { 5 | return ret == AVERROR(EAGAIN) || ret == AVERROR_EOF; 6 | } 7 | -------------------------------------------------------------------------------- /gym/renders/ffmpeg_shim.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | extern int averror_is_eagain_or_eof(int ret); 8 | -------------------------------------------------------------------------------- /nnc/AnyModel.swift: -------------------------------------------------------------------------------- 1 | public protocol AnyModel { 2 | /** 3 | * Whether the existing model is for testing or training. 4 | */ 5 | var testing: Bool { get set } 6 | /** 7 | * Whether to enable memory reduction for this model. The current supported memory reduction 8 | * technique is to redo datatype conversion during backward pass if needed. 9 | */ 10 | var memoryReduction: Bool { get set } 11 | /** 12 | * Specify the maximum number of streams we need to allocate to run this model. 13 | */ 14 | var maxConcurrency: StreamContext.Concurrency { get set } 15 | /** 16 | * Abstract representation of the stateful components from the model builder. 17 | */ 18 | var parameters: Model.Parameters { get } 19 | /** 20 | * Shortcut for weight parameter. 21 | */ 22 | var weight: Model.Parameters { get } 23 | /** 24 | * Shortcut for bias parameter. 
25 | */ 26 | var bias: Model.Parameters { get } 27 | /** 28 | * The size of scratch memory allocated for this model. 29 | */ 30 | var runtimeMemorySize: UInt64 { get } 31 | /** 32 | * Broadly speaking, you can have two types of parameters, weight and bias. 33 | * You can get them in abstract fashion with this method. 34 | * 35 | * - Parameter type: Whether it is weight or bias. 36 | * - Returns: An abstract representation of parameters. 37 | */ 38 | func parameters(for type: Model.ParametersType) -> Model.Parameters 39 | /** 40 | * Cancel current evaluation of this model. It only cancels the model that you know is currently 41 | * in evaluation, if you didn't get the execution order right, it won't have effect (you need 42 | * to make sure this method, if it is called, is strictly after call to callAsFunction and before 43 | * it returns). 44 | */ 45 | func cancel() 46 | } 47 | -------------------------------------------------------------------------------- /nnc/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | cc_library( 4 | name = "C_zlib", 5 | hdrs = ["C_zlib/shim.h"], 6 | defines = ["_GNU_SOURCE"], 7 | linkopts = ["-lz"], 8 | tags = ["swift_module=C_zlib"], 9 | ) 10 | 11 | swift_library( 12 | name = "nnc", 13 | srcs = [ 14 | "AnyModel.swift", 15 | "AutoGrad.swift", 16 | "DataFrame.swift", 17 | "DataFrameAddons.swift", 18 | "DataFrameCore.swift", 19 | "DynamicGraph.swift", 20 | "Functional.swift", 21 | "FunctionalAddons.swift", 22 | "GradScaler.swift", 23 | "Group.swift", 24 | "Hint.swift", 25 | "Loss.swift", 26 | "Model.swift", 27 | "ModelAddons.swift", 28 | "ModelBuilder.swift", 29 | "ModelCore.swift", 30 | "ModelIOAddons.swift", 31 | "Operators.swift", 32 | "Optimizer.swift", 33 | "OptimizerAddons.swift", 34 | "Store.swift", 35 | "StreamContext.swift", 36 | "Tensor.swift", 37 | "TensorGroup.swift", 38 | "Wrapped.swift", 39 | ], 40 | 
module_name = "NNC", 41 | visibility = ["//visibility:public"], 42 | deps = [ 43 | ":C_zlib", 44 | "@ccv//lib:ccv", 45 | "@ccv//lib:nnc", 46 | "@fpzip", 47 | ], 48 | ) 49 | 50 | swift_library( 51 | name = "nnc_python", 52 | srcs = [ 53 | "PythonConversion.swift", 54 | ], 55 | module_name = "NNCPythonConversion", 56 | visibility = ["//visibility:public"], 57 | deps = [ 58 | ":nnc", 59 | "@PythonKit", 60 | ], 61 | ) 62 | 63 | swift_library( 64 | name = "nnc_mujoco", 65 | srcs = [ 66 | "MuJoCoConversion.swift", 67 | ], 68 | module_name = "NNCMuJoCoConversion", 69 | visibility = ["//visibility:public"], 70 | deps = [ 71 | ":nnc", 72 | "@swift-mujoco", 73 | ], 74 | ) 75 | 76 | swift_library( 77 | name = "nnc_coreml", 78 | srcs = [ 79 | "CoreMLConversion.swift", 80 | ], 81 | module_name = "NNCCoreMLConversion", 82 | visibility = ["//visibility:public"], 83 | deps = [ 84 | ":nnc", 85 | ], 86 | ) 87 | -------------------------------------------------------------------------------- /nnc/C_zlib/shim.h: -------------------------------------------------------------------------------- 1 | // 2 | // shim.h 3 | // ZIPFoundation 4 | // 5 | // Copyright © 2017-2023 Thomas Zoechling, https://www.peakstep.com and the ZIP Foundation project authors. 6 | // Released under the MIT License. 7 | // 8 | // See https://github.com/weichsel/ZIPFoundation/blob/master/LICENSE for license information. 
9 | // 10 | 11 | #ifndef zlib_shim_h 12 | #define zlib_shim_h 13 | 14 | #import 15 | #import 16 | 17 | // [zlib] provide 64-bit offset functions if _LARGEFILE64_SOURCE defined 18 | #ifndef _LARGEFILE64_SOURCE 19 | # define _LARGEFILE64_SOURCE 1 20 | #endif 21 | // [zlib] change the regular functions to 64 bits if _FILE_OFFSET_BITS is 64 22 | #ifndef _FILE_OFFSET_BITS 23 | # define _FILE_OFFSET_BITS 64 24 | #endif 25 | // [zlib] on systems without large file support, _LFS64_LARGEFILE must also be true 26 | #ifndef _LFS64_LARGEFILE 27 | # define _LFS64_LARGEFILE 1 28 | #endif 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /nnc/CoreMLConversion.swift: -------------------------------------------------------------------------------- 1 | import C_nnc 2 | import NNC 3 | #if canImport(lib_nnc_mps_compat) && canImport(CoreML) 4 | import lib_nnc_mps_compat 5 | import CoreML 6 | 7 | extension Tensor where Element: MLShapedArrayScalar { 8 | public init(_ shapedArray: MLShapedArray) { 9 | var shapedArray = shapedArray 10 | let (pointer, shape) = shapedArray.withUnsafeMutableShapedBufferPointer { pointer, shape, _ in 11 | return (pointer.baseAddress!, shape) 12 | } 13 | // All these are fragile, since at this point, there is no guarantee that the pointer 14 | // will be valid. However, if you are doing everything right, it should be. 
15 | self.init( 16 | .CPU, format: .NCHW, shape: TensorShape(shape), unsafeMutablePointer: pointer, 17 | bindLifetimeOf: shapedArray) 18 | } 19 | } 20 | 21 | extension MLShapedArray where Scalar: TensorNumeric { 22 | public init(_ tensor: Tensor) { 23 | let cTensor = tensor.cTensor 24 | switch tensor.kind { 25 | case .CPU: 26 | let storage = Unmanaged.passRetained(tensor.storage) 27 | if tensor.isTensorView { 28 | self.init(bytesNoCopy: cTensor.pointee.data.u8, shape: Array(tensor.shape), strides: Array(tensor.strides), deallocator: .custom({ _, _ in 29 | storage.release() 30 | })) 31 | } else { 32 | self.init(bytesNoCopy: cTensor.pointee.data.u8, shape: Array(tensor.shape), deallocator: .custom({ _, _ in 33 | storage.release() 34 | })) 35 | } 36 | case .GPU(_): 37 | let buffer = mpgetbuffer(cTensor)! 38 | let contents = buffer.contents().assumingMemoryBound(to: UInt8.self) 39 | let unmanaged = Unmanaged.passRetained(buffer) 40 | if tensor.isTensorView { 41 | self.init(bytesNoCopy: contents + Int(cTensor.pointee.dataof), shape: Array(tensor.shape), strides: Array(tensor.strides), deallocator: .custom({ _, _ in 42 | unmanaged.release() 43 | })) 44 | } else { 45 | self.init(bytesNoCopy: contents + Int(cTensor.pointee.dataof), shape: Array(tensor.shape), deallocator: .custom({ _, _ in 46 | unmanaged.release() 47 | })) 48 | } 49 | } 50 | } 51 | } 52 | #endif 53 | -------------------------------------------------------------------------------- /nnc/Hint.swift: -------------------------------------------------------------------------------- 1 | import C_nnc 2 | 3 | /// Hint are parameters to these operations that changes shape from input 4 | /// to output. It given proper stride / padding parameters for these operations. 
public struct Hint {
  /// Per-dimension stride for the operation. Empty means "unspecified"
  /// (the C side keeps its zero-initialized default).
  public var stride: [Int]

  /// Padding on both ends of each dimension.
  public struct Border {
    /// Padding added before each dimension.
    public var begin: [Int]
    /// Padding added after each dimension.
    public var end: [Int]

    /// Creates an empty (no padding) border.
    public init() {
      begin = []
      end = []
    }

    /// Creates a symmetric border: the same padding before and after
    /// each dimension.
    public init(_ border: [Int]) {
      begin = border
      end = border
    }

    /// Creates a border with independent leading / trailing padding.
    public init(begin: [Int], end: [Int]) {
      self.begin = begin
      self.end = end
    }
  }

  /// The padding for the operation.
  public var border: Border

  /// Creates an empty hint (no stride, no padding).
  public init() {
    stride = []
    border = Border()
  }

  /// Creates a hint with the given stride and, optionally, a border.
  public init(stride: [Int], border: Border = Border()) {
    self.stride = stride
    self.border = border
  }
}

extension Hint {
  /// Bridges this hint into the C `ccv_nnc_hint_t` representation.
  /// Only non-empty components are copied; everything else stays at the
  /// zero-initialized defaults of the C struct.
  func toCHint() -> ccv_nnc_hint_t {
    var hint = ccv_nnc_hint_t()
    if stride.count > 0 {
      hint.stride.dim = toCDimensions(stride)
    }
    if border.begin.count > 0 {
      hint.border.begin = toCDimensions(border.begin)
    }
    if border.end.count > 0 {
      hint.border.end = toCDimensions(border.end)
    }
    return hint
  }
}
--------------------------------------------------------------------------------
/nnc/ModelCore.swift:
--------------------------------------------------------------------------------
/// Result builder that assembles a sequence of `Model`s into one `Model`
/// (see `buildFinalResult`, which wraps the collected components).
@resultBuilder
public struct Sequential {

  public typealias Expression = Model

  public typealias Component = [Model]

  public typealias FinalResult = Model

  public static func buildExpression(_ expression: Expression) -> Component {
    return [expression]
  }

  public static func buildBlock(_ children: Component...) -> Component {
    return children.flatMap { $0 }
  }

  public static func buildArray(_ components: [Component]) -> Component {
    return components.flatMap { $0 }
  }

  public static func buildBlock(_ component: Component) -> Component {
    return component
  }

  // An absent `if` branch contributes no models.
  public static func buildOptional(_ children: Component?) -> Component {
    return children ??
[] 28 | } 29 | 30 | public static func buildEither(first child: Component) -> Component { 31 | return child 32 | } 33 | 34 | public static func buildEither(second child: Component) -> Component { 35 | return child 36 | } 37 | 38 | public static func buildFinalResult(_ component: Component) -> FinalResult { 39 | return Model(component) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /nnc/MuJoCoConversion.swift: -------------------------------------------------------------------------------- 1 | import C_nnc 2 | import MuJoCo 3 | import NNC 4 | 5 | extension Tensor { 6 | /** 7 | * Initialize a tensor from MjArray. This doesn't copy the data over, rather, we simply keep the 8 | * original MjArray alive. That's also why we don't support any data conversion. If you want to 9 | * keep the resulting Tensor for later usage (rather than do the computation right now), you need 10 | * to make your own copies because MjArray **CANNOT** be immutable and can tied to underlying 11 | * MjData updates (through `MjModel.forward` or `MjModel.step`). */ 12 | @inlinable 13 | public init(mjArray: MjArray) { 14 | // MjArray is one dimension. Treat this as a C dimension. 
15 | self.init( 16 | .CPU, format: .NCHW, shape: [mjArray.count], unsafeMutablePointer: mjArray + 0, 17 | bindLifetimeOf: mjArray) 18 | } 19 | } 20 | 21 | extension MjArray where Element: TensorNumeric { 22 | @inlinable 23 | public subscript(bounds: Range) -> Tensor { 24 | get { Tensor(mjArray: self[bounds]) } 25 | set { 26 | newValue.withUnsafeBytes { 27 | precondition( 28 | MemoryLayout.size * (bounds.upperBound - bounds.lowerBound) == $0.count) 29 | guard let source = $0.baseAddress else { return } 30 | UnsafeMutableRawPointer(self + bounds.lowerBound).copyMemory( 31 | from: source, byteCount: $0.count) 32 | } 33 | } 34 | } 35 | @inlinable 36 | public subscript(bounds: ClosedRange) -> Tensor { 37 | get { 38 | return self[bounds.lowerBound..<(bounds.upperBound + 1)] 39 | } 40 | set { 41 | self[bounds.lowerBound..<(bounds.upperBound + 1)] = newValue 42 | } 43 | } 44 | @inlinable 45 | public subscript(bounds: PartialRangeUpTo) -> Tensor { 46 | get { 47 | return self[0..) -> Tensor { 55 | get { 56 | return self[0..<(bounds.upperBound + 1)] 57 | } 58 | set { 59 | self[0..<(bounds.upperBound + 1)] = newValue 60 | } 61 | } 62 | @inlinable 63 | public subscript(bounds: PartialRangeFrom) -> Tensor { 64 | get { 65 | return self[bounds.lowerBound.. 
Void) -> Tensor { 73 | get { 74 | return Tensor(mjArray: self) 75 | } 76 | set { 77 | newValue.withUnsafeBytes { 78 | precondition(MemoryLayout.size * count == $0.count) 79 | guard let source = $0.baseAddress else { return } 80 | UnsafeMutableRawPointer(self + 0).copyMemory(from: source, byteCount: $0.count) 81 | } 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /nnc/PythonConversion.swift: -------------------------------------------------------------------------------- 1 | import C_nnc 2 | import NNC 3 | import PythonKit 4 | 5 | private let np = Python.import("numpy") 6 | private let ctypes = Python.import("ctypes") 7 | 8 | extension Tensor where Element: NumpyScalarCompatible { 9 | 10 | /// Cannot create a tensor from numpy array. 11 | public enum NumpyScalarCompatibleError: Error { 12 | /// The PythonObject is not a numpy array. 13 | case notNumpy 14 | /// The numpy array type doesn't match the expected output type. 15 | case noDataConversion(PythonObject, Any.Type) 16 | /// Cannot find shape information from numpy array. 17 | case noShape 18 | /// Cannot find data pointer from the numpy array. 19 | case noPointer 20 | } 21 | /** 22 | * Initialize a tensor from numpy object. This doesn't copy the data over, rather, we simply 23 | * keep the original numpyArray alive. That's also why we don't support any data conversion. 24 | */ 25 | public init(numpy numpyArray: PythonObject) throws { 26 | // Check if input is a `numpy.ndarray` instance. 27 | guard Python.isinstance(numpyArray, np.ndarray) == true else { 28 | throw NumpyScalarCompatibleError.notNumpy 29 | } 30 | // Check if the dtype of the `ndarray` is compatible with the `Element` 31 | // type. 
32 | guard Element.numpyScalarTypes.contains(numpyArray.dtype) else { 33 | throw NumpyScalarCompatibleError.noDataConversion(numpyArray.dtype, Element.self) 34 | } 35 | let pyShape = numpyArray.__array_interface__["shape"] 36 | guard let shape = [Int](pyShape) else { 37 | throw NumpyScalarCompatibleError.noShape 38 | } 39 | precondition(shape.count <= CCV_NNC_MAX_DIM_ALLOC) 40 | // Make sure that the array is contiguous in memory. This does a copy if 41 | // the array is not already contiguous in memory. 42 | let contiguousNumpyArray = np.ascontiguousarray(numpyArray) 43 | guard 44 | let ptrVal = 45 | UInt(contiguousNumpyArray.__array_interface__["data"].tuple2.0) 46 | else { 47 | throw NumpyScalarCompatibleError.noPointer 48 | } 49 | guard let pointer = UnsafeMutablePointer(bitPattern: ptrVal) else { 50 | fatalError("numpy.ndarray data pointer was nil") 51 | } 52 | self.init( 53 | .CPU, format: .NCHW, shape: TensorShape(shape), unsafeMutablePointer: pointer, 54 | bindLifetimeOf: contiguousNumpyArray) 55 | } 56 | } 57 | 58 | extension Tensor where Element: NumpyScalarCompatible { 59 | /** 60 | * Make a numpy object from a typed tensor. 61 | */ 62 | public func makeNumpyArray() -> PythonObject { 63 | precondition(!isTensorView) 64 | return withUnsafeBytes { bytes in 65 | let data = ctypes.cast(Int(bitPattern: bytes.baseAddress), ctypes.POINTER(Element.ctype)) 66 | let ndarray = np.ctypeslib.as_array(data, shape: PythonObject(tupleContentsOf: Array(shape))) 67 | return np.copy(ndarray) 68 | } 69 | } 70 | } 71 | 72 | extension Tensor: PythonConvertible where Element: NumpyScalarCompatible { 73 | public var pythonObject: PythonObject { 74 | makeNumpyArray() 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /nnc/StreamContext.swift: -------------------------------------------------------------------------------- 1 | import C_nnc 2 | 3 | /// A stream context is an object that an execution can be performed upon. 
public final class StreamContext {
  /// Maximum number of parallel streams an execution may fan out to.
  public enum Concurrency {
    case noLimit
    case limit(Int)
    // 0 is the C-side sentinel for "no limit".
    init(rawValue: Int) {
      switch rawValue {
      case 0:
        self = .noLimit
      default:
        self = .limit(rawValue)
      }
    }
    var rawValue: Int {
      switch self {
      case .noLimit:
        return 0
      case .limit(let value):
        return value
      }
    }
  }

  // Whether `deinit` should free `_stream` (false when wrapping a stream we
  // do not own).
  let selfOwned: Bool
  // The underlying ccv_nnc stream context handle.
  let _stream: OpaquePointer

  // Wraps an existing C stream; `selfOwned` controls whether we free it.
  init(stream: OpaquePointer, selfOwned: Bool) {
    _stream = stream
    self.selfOwned = selfOwned
  }

  /**
   * Create a new stream context.
   *
   * - Parameter kind: Whether this stream context is on CPU or GPU.
   */
  public init(_ kind: DeviceKind) {
    let type: Int32
    switch kind {
    case .CPU:
      type = Int32(CCV_STREAM_CONTEXT_CPU)
    case .GPU(let ordinal):
      // The device ordinal is encoded in the bits above the stream-type flag.
      type = Int32((ordinal << 8) | CCV_STREAM_CONTEXT_GPU)
    }
    _stream = ccv_nnc_stream_context_new(type)!
    selfOwned = true
  }

  /**
   * Wait until all executions on this stream context to finish.
   */
  public func joined() {
    ccv_nnc_stream_context_wait(_stream)
  }

  /**
   * Dispatch a block to be executed when all previous executions prior to
   * this method call are done.
   */
  public func async(_ closure: @escaping () -> Void) {
    // Box the closure as AnyObject and hand a +1 reference to the C callback;
    // `takeRetainedValue` consumes that reference exactly once, balancing the
    // `passRetained` below.
    ccv_nnc_stream_context_add_callback(
      _stream,
      { context in
        let closure = Unmanaged<AnyObject>.fromOpaque(context!).takeRetainedValue() as! (() -> Void)
        closure()
      }, Unmanaged.passRetained(closure as AnyObject).toOpaque())
  }

  /**
   * Set seed for this particular stream context. If not set, it inherits from the global context.
   */
  public func setSeed(_ seed: UInt32) {
    ccv_nnc_stream_context_set_seed(_stream, seed)
  }

  deinit {
    // Only free streams this instance created or explicitly owns.
    guard selfOwned else { return }
    ccv_nnc_stream_context_free(_stream)
  }
}
--------------------------------------------------------------------------------
/nnc/Wrapped.swift:
--------------------------------------------------------------------------------
/// Minimal box that gives any value reference semantics.
final class Wrapped<T> {
  let value: T
  init(_ value: T) {
    self.value = value
  }
}
--------------------------------------------------------------------------------
/scripts/BUILD.bazel:
--------------------------------------------------------------------------------
load("@com_github_bazelbuild_buildtools//buildifier:def.bzl", "buildifier")

buildifier(
    name = "buildifier",
)
--------------------------------------------------------------------------------
/scripts/buildifier/pre-commit:
--------------------------------------------------------------------------------
#!/bin/sh
# Collect staged BUILD files (added/copied/modified/renamed), escaping spaces.
FILES=$(git diff --cached --name-only --diff-filter=ACMR "*BUILD*" | sed 's| |\\ |g')
[ -z "$FILES" ] && exit 0

# Bazel invocation may git clone some repositories, and override these env vars.
6 | 7 | _GIT_INDEX_FILE=$GIT_INDEX_FILE 8 | 9 | unset GIT_INDEX_FILE 10 | 11 | # Prettify all selected files 12 | echo "$FILES" | xargs -I {} bazel run --compilation_mode=opt @com_github_bazelbuild_buildtools//buildifier:buildifier -- -r `realpath {}` 13 | 14 | export GIT_INDEX_FILE=$_GIT_INDEX_FILE 15 | 16 | # Add back the modified/prettified files to staging 17 | echo "$FILES" | xargs git add 18 | 19 | exit 0 20 | -------------------------------------------------------------------------------- /scripts/compdb/BUILD: -------------------------------------------------------------------------------- 1 | py_binary( 2 | name = "compdb", 3 | srcs = ["compdb.py"], 4 | deps = [], 5 | ) 6 | -------------------------------------------------------------------------------- /scripts/compdb/compdb.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import os 4 | import json 5 | 6 | 7 | def main(root): 8 | execroot = ( 9 | subprocess.check_output(["bazel", "info", "execution_root"]) 10 | .decode("utf-8") 11 | .strip() 12 | ) 13 | aquery = json.loads( 14 | subprocess.check_output( 15 | [ 16 | "bazel", 17 | "aquery", 18 | "--compilation_mode=dbg", 19 | "mnemonic(SwiftCompile, //...)", 20 | "--output=jsonproto", 21 | "--include_artifacts=false", 22 | ] 23 | ) 24 | ) 25 | actions = aquery["actions"] 26 | 27 | def command(action): 28 | arguments = list( 29 | filter( 30 | lambda x: "worker/worker" not in x and "-Xwrapped-swift" not in x, 31 | action["arguments"], 32 | ) 33 | ) 34 | arguments[-1] = os.path.join(root, arguments[-1]) 35 | return {"directory": execroot, "arguments": arguments, "file": arguments[-1]} 36 | 37 | compile_commands = map(command, actions) 38 | with open("compile_commands.json", "w+") as f: 39 | json.dump(list(compile_commands), f, sort_keys=True, indent=2) 40 | 41 | 42 | if __name__ == "__main__": 43 | os.chdir(sys.argv[1]) 44 | main(sys.argv[1]) 45 | 
-------------------------------------------------------------------------------- /scripts/docc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | GIT_ROOT=$(git rev-parse --show-toplevel) 6 | 7 | cd $GIT_ROOT 8 | 9 | # Generate symbol graph 10 | bazel build tensorboard:tensorboard gym:gym nnc:nnc nnc:nnc_python nnc:nnc_mujoco --features=swift.emit_symbol_graph 11 | # Copy it into a valid bundle 12 | mkdir -p s4nnc.docc 13 | cp bazel-bin/nnc/nnc.symbolgraph/*.json s4nnc.docc/ 14 | cp bazel-bin/nnc/nnc_python.symbolgraph/*.json s4nnc.docc/ 15 | cp bazel-bin/nnc/nnc_mujoco.symbolgraph/*.json s4nnc.docc/ 16 | cp bazel-bin/gym/gym.symbolgraph/*.json s4nnc.docc/ 17 | cp bazel-bin/tensorboard/tensorboard.symbolgraph/*.json s4nnc.docc/ 18 | # Remove all docs 19 | rm -rf docs 20 | # Convert into static hosting document 21 | docc convert s4nnc.docc --fallback-display-name="Swift for NNC" --fallback-bundle-identifier org.liuliu.s4nnc --fallback-bundle-version 0.0.1 --output-path docs --hosting-base-path /s4nnc --index --transform-for-static-hosting 22 | # Adding auto-redirect index.html 23 | echo '' > docs/index.html 24 | rm -rf s4nnc.docc 25 | -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | GIT_CONFIG=$(git rev-parse --git-dir) 6 | GIT_ROOT=$(git rev-parse --show-toplevel) 7 | 8 | mkdir -p $GIT_CONFIG/hooks/pre-commit.d 9 | 10 | rm -f $GIT_CONFIG/hooks/pre-commit 11 | ln -s $GIT_ROOT/scripts/vendors/dispatch $GIT_CONFIG/hooks/pre-commit 12 | 13 | rm -f $GIT_CONFIG/hooks/pre-commit.d/swift-format 14 | ln -s $GIT_ROOT/scripts/swift-format/pre-commit $GIT_CONFIG/hooks/pre-commit.d/swift-format 15 | 16 | rm -f $GIT_CONFIG/hooks/pre-commit.d/buildifier 17 | ln -s 
$GIT_ROOT/scripts/buildifier/pre-commit $GIT_CONFIG/hooks/pre-commit.d/buildifier 18 | -------------------------------------------------------------------------------- /scripts/swift-format/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | FILES=$(git diff --cached --name-only --diff-filter=ACMR "*.swift" | sed 's| |\\ |g') 3 | [ -z "$FILES" ] && exit 0 4 | 5 | GIT_ROOT=$(git rev-parse --show-toplevel) 6 | 7 | # Bazel invocation may git clone some repositories, and override these env vars. 8 | 9 | _GIT_INDEX_FILE=$GIT_INDEX_FILE 10 | 11 | unset GIT_INDEX_FILE 12 | 13 | # Prettify all selected files 14 | echo "$FILES" | xargs -I {} bazel run --compilation_mode=opt @SwiftFormat//:swift-format -- format --configuration "$GIT_ROOT/.swift-format.json" -i `realpath {}` 15 | 16 | export GIT_INDEX_FILE=$_GIT_INDEX_FILE 17 | 18 | # Add back the modified/prettified files to staging 19 | echo "$FILES" | xargs git add 20 | 21 | exit 0 22 | 23 | -------------------------------------------------------------------------------- /scripts/vscode/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | if [ $# -eq 0 ] ; then 6 | WORKSPACE=$(git rev-parse --show-toplevel) 7 | else 8 | WORKSPACE=$1 9 | fi 10 | 11 | bazel run //scripts/compdb:compdb -- ${WORKSPACE} 12 | TARGETS=`bazel query "kind(swift_binary,//...)"` 13 | bazel build --compilation_mode=dbg ${TARGETS} 14 | mkdir -p ${WORKSPACE}/.index/store && mkdir -p ${WORKSPACE}/.index/db 15 | find ${WORKSPACE}/bazel-out/k8-dbg/bin -name "*.indexstore" | xargs -I {} rsync -a {}/v5 ${WORKSPACE}/.index/store 16 | -------------------------------------------------------------------------------- /tensorboard/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") 2 | 3 | 
swift_library( 4 | name = "tensorboard", 5 | srcs = glob(["**/*.swift"]), 6 | module_name = "TensorBoard", 7 | visibility = ["//visibility:public"], 8 | deps = [ 9 | "//nnc", 10 | "@SwiftProtobuf", 11 | "@SwiftSystem//:SystemPackage", 12 | "@ccv//lib:ccv", 13 | ], 14 | ) 15 | -------------------------------------------------------------------------------- /tensorboard/EventLogger.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import SystemPackage 3 | 4 | /// A logger that writes protobuf events into a tensorboard-readable file. 5 | struct EventLogger { 6 | /// File descriptor. 7 | private let fd: FileDescriptor 8 | 9 | /// Creates an instance with log located at `logDirectory`; creates 10 | /// the log file and add an initial event as well. 11 | init(logDirectory: String) throws { 12 | // Create the directory if it is missing. 13 | try FileManager.default.createDirectory(atPath: logDirectory, withIntermediateDirectories: true) 14 | 15 | // Create the file. 16 | let timestamp = Date().timeIntervalSince1970 17 | let filePath = URL(fileURLWithPath: logDirectory, isDirectory: true).appendingPathComponent( 18 | "events.out.tfevents." + String(timestamp).split(separator: ".")[0] 19 | ).path 20 | 21 | fd = try FileDescriptor.open( 22 | filePath, .writeOnly, options: [.create, .truncate], permissions: .ownerReadWrite) 23 | // Add an initial event. 24 | var initialEvent = Tensorboard_Event() 25 | initialEvent.wallTime = timestamp 26 | initialEvent.fileVersion = "brain.Event:2" 27 | try add(initialEvent) 28 | } 29 | 30 | func close() throws { 31 | try fd.close() 32 | } 33 | 34 | /// Add an event to the log. 
35 | func add(_ event: Tensorboard_Event) throws { 36 | let data: Data = try event.serializedData() 37 | var header: Data = Data() 38 | header.append(contentsOf: UInt64(data.count).littleEndianBuffer) 39 | var payload = header 40 | payload.append(contentsOf: header.maskedCRC32C().littleEndianBuffer) 41 | payload.append(contentsOf: data) 42 | payload.append(contentsOf: data.maskedCRC32C().littleEndianBuffer) 43 | try fd.writeAll(payload) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /tensorboard/README.md: -------------------------------------------------------------------------------- 1 | This is forked from https://github.com/tensorflow/swift-models/tree/main/TensorBoard. 2 | 3 | Plan to add addImage, addHistogram and addGraph support. 4 | 5 | Removed dependency to swift-models, API names matching tensorboardx better. 6 | 7 | Use proto generated from tensorflow/tensorboard directly. 8 | -------------------------------------------------------------------------------- /tensorboard/SummaryWriter.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// A writer for writing model execution summaries to a tensorboard-readable file; the 4 | /// summaries include scalars for logging statistics, graphs for visualizing model etc. 5 | public struct SummaryWriter { 6 | /// Logger for writing the summaries as protobuf events to the file. 7 | let eventLogger: EventLogger 8 | 9 | static let dateFormatter: DateFormatter = { 10 | let dateFormatter = DateFormatter() 11 | dateFormatter.dateFormat = "MMMd-yyyy-HH-mm-ss" 12 | return dateFormatter 13 | }() 14 | 15 | let logDirectory: String 16 | 17 | /// Creates an instance with log located at `logDirectory`. 18 | public init(logDirectory: String, comment: String = "") { 19 | // Properly construct the folder name. 
It should be runs/commentcurrentdatetime/ 20 | self.logDirectory = logDirectory + "/runs/\(comment)\(Self.dateFormatter.string(from: Date()))" 21 | eventLogger = try! EventLogger(logDirectory: self.logDirectory) 22 | } 23 | 24 | public func close() throws { 25 | try eventLogger.close() 26 | } 27 | } 28 | 29 | extension SummaryWriter { 30 | /// Add training and validation statistics for tensorboard scalars dashboard. 31 | public func addScalar( 32 | _ tag: String, _ value: T, step: Int, 33 | wallTime: Double = Date().timeIntervalSince1970, displayName: String? = nil, 34 | description: String? = nil 35 | ) { 36 | var summaryMetadata = Tensorboard_SummaryMetadata() 37 | summaryMetadata.displayName = displayName ?? tag 38 | summaryMetadata.summaryDescription = description ?? "" 39 | 40 | var summaryValue = Tensorboard_Summary.Value() 41 | summaryValue.tag = tag 42 | summaryValue.simpleValue = Float(value) 43 | summaryValue.metadata = summaryMetadata 44 | 45 | var summary = Tensorboard_Summary() 46 | summary.value = [summaryValue] 47 | 48 | var event = Tensorboard_Event() 49 | event.summary = summary 50 | event.wallTime = wallTime 51 | event.step = Int64(step) 52 | do { 53 | try eventLogger.add(event) 54 | } catch { 55 | fatalError("Could not add \(event) to log: \(error)") 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /tensorboard/Support.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Convert an integer represented in its `endian` who takes `count` 4 | /// bytes in length to an array of bytes. 
5 | func toByteArr(endian: T, count: Int) -> [UInt8] { 6 | var _endian = endian 7 | let bytePtr = withUnsafePointer(to: &_endian) { 8 | $0.withMemoryRebound(to: UInt8.self, capacity: count) { 9 | UnsafeBufferPointer(start: $0, count: count) 10 | } 11 | } 12 | return [UInt8](bytePtr) 13 | } 14 | 15 | extension UInt64 { 16 | var littleEndianBuffer: [UInt8] { 17 | toByteArr(endian: self.littleEndian, count: 8) 18 | } 19 | 20 | var bigEndianBuffer: [UInt8] { 21 | toByteArr(endian: self.bigEndian, count: 8) 22 | } 23 | } 24 | 25 | extension UInt32 { 26 | var littleEndianBuffer: [UInt8] { 27 | toByteArr(endian: self.littleEndian, count: 4) 28 | } 29 | 30 | var bigEndianBuffer: [UInt8] { 31 | toByteArr(endian: self.bigEndian, count: 4) 32 | } 33 | } 34 | 35 | /// This is taken from https://github.com/tensorflow/swift-models/blob/master/Checkpoints/CheckpointReader.swift#L299 36 | /// Need to take to a common place later. 37 | extension Data { 38 | static var crc32CLookupTable: [UInt32] = { 39 | (0...255).map { index -> UInt32 in 40 | var lookupValue = UInt32(index) 41 | for _ in 0..<8 { 42 | lookupValue = 43 | (lookupValue % 2 == 0) 44 | ? 
(lookupValue >> 1) : (0x82F6_3B78 ^ (lookupValue >> 1)) 45 | } 46 | return lookupValue 47 | } 48 | }() 49 | 50 | func crc32C() -> UInt32 { 51 | var crc32: UInt32 = 0xFFFF_FFFF 52 | 53 | self.withUnsafeBytes { buffer in 54 | let totalBytes = self.count 55 | var index = 0 56 | while index < totalBytes { 57 | let byte = buffer[index] 58 | let lookupIndex = Int((crc32 ^ (UInt32(byte) & 0xFF)) & 0xFF) 59 | crc32 = (crc32 >> 8) ^ Data.crc32CLookupTable[lookupIndex] 60 | index = index &+ 1 61 | } 62 | } 63 | 64 | return crc32 ^ 0xFFFF_FFFF 65 | } 66 | 67 | // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/hash/crc32c.h 68 | func maskedCRC32C() -> UInt32 { 69 | let crc32 = self.crc32C() 70 | let maskDelta: UInt32 = 0xA282_EAD8 71 | return ((crc32 &>> 15) | (crc32 &<< 17)) &+ maskDelta 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Protos 2 | 3 | Protobuf files copied from the main TensorFlow repository and used in the case of tensorboard-notf, which builds without a TensorFlow dependency. 4 | 5 | ## Update process 6 | 7 | Copy the proto files from TensorFlow to TensorBoard using the following process: 8 | 9 | * git clone tensorflow/tensorflow and tensorflow/tensorboard in ~ 10 | * cd ~/tensorboard 11 | * ./tensorboard/compat/proto/update.sh 12 | * git add . 13 | * git commit -m "Update TensorFlow protos to xxxx" 14 | 15 | These were taken from TensorFlow version 1.12.0-dev20181012 16 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/allocation_description.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "AllocationDescriptionProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.framework"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/allocation_description_go_proto"; 10 | 11 | message AllocationDescription { 12 | // Total number of bytes requested 13 | int64 requested_bytes = 1; 14 | 15 | // Total number of bytes allocated if known 16 | int64 allocated_bytes = 2; 17 | 18 | // Name of the allocator used 19 | string allocator_name = 3; 20 | 21 | // Identifier of the allocated buffer if known 22 | int64 allocation_id = 4; 23 | 24 | // Set if this tensor only has one remaining reference 25 | bool has_single_reference = 5; 26 | 27 | // Address of the allocation. 
28 | uint64 ptr = 6; 29 | } 30 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/attr_value.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/tensor.proto"; 6 | import "tensorboard/compat/proto/tensor_shape.proto"; 7 | import "tensorboard/compat/proto/types.proto"; 8 | 9 | option cc_enable_arenas = true; 10 | option java_outer_classname = "AttrValueProtos"; 11 | option java_multiple_files = true; 12 | option java_package = "org.tensorflow.framework"; 13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/attr_value_go_proto"; 14 | 15 | // Protocol buffer representing the value for an attr used to configure an Op. 16 | // Comment indicates the corresponding attr type. Only the field matching the 17 | // attr type may be filled. 18 | message AttrValue { 19 | // DISABLED.IfChange 20 | message ListValue { 21 | repeated bytes s = 2; // "list(string)" 22 | repeated int64 i = 3 [packed = true]; // "list(int)" 23 | repeated float f = 4 [packed = true]; // "list(float)" 24 | repeated bool b = 5 [packed = true]; // "list(bool)" 25 | repeated DataType type = 6 [packed = true]; // "list(type)" 26 | repeated TensorShapeProto shape = 7; // "list(shape)" 27 | repeated TensorProto tensor = 8; // "list(tensor)" 28 | repeated NameAttrList func = 9; // "list(attr)" 29 | } 30 | // DISABLED.ThenChange(//tensorflow/c/c_api.cc) 31 | 32 | oneof value { 33 | bytes s = 2; // "string" 34 | int64 i = 3; // "int" 35 | float f = 4; // "float" 36 | bool b = 5; // "bool" 37 | DataType type = 6; // "type" 38 | TensorShapeProto shape = 7; // "shape" 39 | TensorProto tensor = 8; // "tensor" 40 | ListValue list = 1; // any "list(...)" 41 | 42 | // "func" represents a function. func.name is a function's name or 43 | // a primitive op's name. 
func.attr.first is the name of an attr 44 | // defined for that function. func.attr.second is the value for 45 | // that attr in the instantiation. 46 | NameAttrList func = 10; 47 | 48 | // This is a placeholder only used in nodes defined inside a 49 | // function. It indicates the attr value will be supplied when 50 | // the function is instantiated. For example, let us suppose a 51 | // node "N" in function "FN". "N" has an attr "A" with value 52 | // placeholder = "foo". When FN is instantiated with attr "foo" 53 | // set to "bar", the instantiated node N's attr A will have been 54 | // given the value "bar". 55 | string placeholder = 9; 56 | } 57 | } 58 | 59 | // A list of attr names and their values. The whole list is attached 60 | // with a string name. E.g., MatMul[T=float]. 61 | message NameAttrList { 62 | string name = 1; 63 | map<string, AttrValue> attr = 2; 64 | } 65 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/cluster.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorboard; 19 | 20 | option cc_enable_arenas = true; 21 | option java_outer_classname = "ClusterProtos"; 22 | option java_multiple_files = true; 23 | option java_package = "org.tensorflow.distruntime"; 24 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; 25 | 26 | // This file contains protos to be used when defining a TensorFlow 27 | // cluster. 28 | // 29 | // EXAMPLES 30 | // -------- 31 | // 32 | // 1. A single-process cluster, containing "/job:local/task:0". 33 | // 34 | // Cluster: 35 | // job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } 36 | // 37 | // Server: 38 | // cluster { $CLUSTER } job_name: 'local' task_index: 0 39 | // 40 | // 2. A two-process cluster, containing "/job:local/task:{0,1}". 41 | // 42 | // Cluster: 43 | // job { name: 'local' tasks { key: 0 value: 'localhost:2222' } 44 | // tasks { key: 1 value: 'localhost:2223' } } 45 | // 46 | // Servers: 47 | // cluster { $CLUSTER } job_name: 'local' task_index: 0 48 | // cluster { $CLUSTER } job_name: 'local' task_index: 1 49 | // 50 | // 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and 51 | // "/job:ps/task:{0,1}". 
52 | // 53 | // Cluster: 54 | // job { name: 'worker' tasks { key: 0 value: 'worker1:2222' } 55 | // tasks { key: 1 value: 'worker2:2222' } 56 | // tasks { key: 2 value: 'worker3:2222' } } 57 | // job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } 58 | // tasks { key: 1 value: 'ps1:2222' } } 59 | // 60 | // Servers: 61 | // cluster { $CLUSTER } job_name: 'worker' task_index: 0 62 | // cluster { $CLUSTER } job_name: 'worker' task_index: 1 63 | // cluster { $CLUSTER } job_name: 'worker' task_index: 2 64 | // cluster { $CLUSTER } job_name: 'ps' task_index: 0 65 | // cluster { $CLUSTER } job_name: 'ps' task_index: 1 66 | 67 | // Defines a single job in a TensorFlow cluster. 68 | message JobDef { 69 | // The name of this job. 70 | string name = 1; 71 | 72 | // Mapping from task ID to "hostname:port" string. 73 | // 74 | // If the `name` field contains "worker", and the `tasks` map contains a 75 | // mapping from 7 to "example.org:2222", then the device prefix 76 | // "/job:worker/task:7" will be assigned to "example.org:2222". 77 | map<int32, string> tasks = 2; 78 | } 79 | 80 | // Defines a TensorFlow cluster as a set of jobs. 81 | message ClusterDef { 82 | // The jobs that comprise the cluster. 83 | repeated JobDef job = 1; 84 | } 85 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/coordination_config.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; 6 | 7 | // Coordination service configuration parameters. 8 | // The system picks appropriate values for fields that are not set. 9 | message CoordinationServiceConfig { 10 | // Type of coordination service implementation to enable. 
11 | // For example, setting the service type as "standalone" starts a service 12 | // instance on the leader task to provide the coordination services such as 13 | // heartbeats and consistent key-value store. 14 | string service_type = 1; 15 | 16 | // Address where the coordination service instance is hosted. 17 | string service_leader = 2; 18 | 19 | // Whether to enable the health check mechanism. 20 | bool enable_health_check = 3; 21 | 22 | // Maximum wait time for all members in the cluster to be registered. 23 | int64 cluster_register_timeout_in_ms = 4; 24 | 25 | // Heartbeat timeout, if a task does not record heartbeat in this time 26 | // window, it will be considered disconnected. 27 | // Note: This is also used as a grace period to accept any heartbeats after 28 | // the agent has disconnected, to account for the lag time between the service 29 | // recording the state change and the agent stopping heartbeats. 30 | int64 heartbeat_timeout_in_ms = 5; 31 | 32 | // The list of jobs that participate in the coordination service. If empty, all 33 | // jobs will be included in the coordination service by default. 34 | repeated string coordinated_jobs = 6; 35 | 36 | // Denotes how long to wait for all coordination agents to reach the barriers 37 | // (after the first shutdown request) before disconnecting together. If 38 | // set to 0, no barrier is imposed upon shutdown and each worker can 39 | // disconnect individually. 40 | int64 shutdown_barrier_timeout_in_ms = 7; 41 | 42 | // If set, agents do not make an explicit Shutdown() call. Service will only 43 | // find out about the disconnected agent via stale heartbeats. Used for 44 | // testing. 
45 | bool agent_destruction_without_shutdown = 8; 46 | } 47 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/cost_graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/tensor_shape.proto"; 6 | import "tensorboard/compat/proto/types.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | option java_outer_classname = "CostGraphProtos"; 10 | option java_multiple_files = true; 11 | option java_package = "org.tensorflow.framework"; 12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/cost_graph_go_proto"; 13 | 14 | message CostGraphDef { 15 | message Node { 16 | // The name of the node. Names are globally unique. 17 | string name = 1; 18 | 19 | // The device of the node. Can be empty if the node is mapped to the 20 | // default partition or partitioning hasn't been run yet. 21 | string device = 2; 22 | 23 | // The id of the node. Node ids are only unique inside a partition. 24 | int32 id = 3; 25 | 26 | // Inputs of this node. They must be executed before this node can be 27 | // executed. An input is a particular output of another node, specified 28 | // by the node id and the output index. 29 | message InputInfo { 30 | int32 preceding_node = 1; 31 | int32 preceding_port = 2; 32 | } 33 | repeated InputInfo input_info = 4; 34 | 35 | // Outputs of this node. 36 | message OutputInfo { 37 | int64 size = 1; 38 | // If >= 0, the output is an alias of an input. Note that an alias input 39 | // may itself be an alias. The algorithm will therefore need to follow 40 | // those pointers. 41 | int64 alias_input_port = 2; 42 | TensorShapeProto shape = 3; 43 | DataType dtype = 4; 44 | } 45 | repeated OutputInfo output_info = 5; 46 | 47 | // Temporary memory used by this node. 48 | int64 temporary_memory_size = 6; 49 | 50 | // Persistent memory used by this node. 
51 | int64 persistent_memory_size = 12; 52 | 53 | int64 host_temp_memory_size = 10 [deprecated = true]; 54 | int64 device_temp_memory_size = 11 [deprecated = true]; 55 | int64 device_persistent_memory_size = 16 [deprecated = true]; 56 | 57 | // Estimate of the computational cost of this node, in microseconds. 58 | int64 compute_cost = 9; 59 | 60 | // Analytical estimate of the computational cost of this node, in 61 | // microseconds. 62 | int64 compute_time = 14; 63 | 64 | // Analytical estimate of the memory access cost of this node, in 65 | // microseconds. 66 | int64 memory_time = 15; 67 | 68 | // If true, the output is permanent: it can't be discarded, because this 69 | // node is part of the "final output". Nodes may depend on final nodes. 70 | bool is_final = 7; 71 | 72 | // Ids of the control inputs for this node. 73 | repeated int32 control_input = 8; 74 | 75 | // Are the costs inaccurate? 76 | bool inaccurate = 17; 77 | } 78 | repeated Node node = 1; 79 | 80 | // Total cost of this graph, typically used for balancing decisions. 81 | message AggregatedCost { 82 | // Aggregated cost value. 83 | float cost = 1; 84 | 85 | // Aggregated cost dimension (e.g. 'memory', 'compute', 'network'). 
86 | string dimension = 2; 87 | } 88 | repeated AggregatedCost cost = 2; 89 | } 90 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/cpp_shape_inference.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/full_type.proto"; 6 | import "tensorboard/compat/proto/tensor_shape.proto"; 7 | import "tensorboard/compat/proto/types.proto"; 8 | 9 | option cc_enable_arenas = true; 10 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto"; 11 | 12 | message CppShapeInferenceResult { 13 | message HandleShapeAndType { 14 | reserved 3; 15 | 16 | TensorShapeProto shape = 1; 17 | DataType dtype = 2; 18 | FullTypeDef type = 4; 19 | } 20 | message HandleData { 21 | bool is_set = 1; 22 | 23 | // Only valid if <is_set>. 24 | repeated HandleShapeAndType shape_and_type = 2; 25 | } 26 | TensorShapeProto shape = 1; 27 | 28 | reserved 2; // was handle_shape 29 | reserved 3; // was handle_dtype 30 | HandleData handle_data = 4; 31 | } 32 | 33 | message CppShapeInferenceInputsNeeded { 34 | repeated int32 input_tensors_needed = 1; 35 | repeated int32 input_tensors_as_shapes_needed = 2; 36 | } 37 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/debug.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "DebugProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.framework"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; 10 | 11 | // Option for watching a node in TensorFlow Debugger (tfdbg). 
12 | message DebugTensorWatch { 13 | // Name of the node to watch. 14 | // Use "*" for wildcard. But note: currently, regex is not supported in 15 | // general. 16 | string node_name = 1; 17 | 18 | // Output slot to watch. 19 | // The semantics of output_slot == -1 is that all outputs of the node 20 | // will be watched (i.e., a wildcard). 21 | // Other negative values of output_slot are invalid and will lead to 22 | // errors currently. 23 | int32 output_slot = 2; 24 | 25 | // Name(s) of the debugging op(s). 26 | // One or more than one probes on a tensor. 27 | // e.g., {"DebugIdentity", "DebugNanCount"} 28 | repeated string debug_ops = 3; 29 | 30 | // URL(s) for debug targets(s). 31 | // 32 | // Supported URL formats are: 33 | // - file:///foo/tfdbg_dump: Writes out Event content to file 34 | // /foo/tfdbg_dump. Assumes all directories can be created if they don't 35 | // already exist. 36 | // - grpc://localhost:11011: Sends an RPC request to an EventListener 37 | // service running at localhost:11011 with the event. 38 | // - memcbk:///event_key: Routes tensors to clients using the 39 | // callback registered with the DebugCallbackRegistry for event_key. 40 | // 41 | // Each debug op listed in debug_ops will publish its output tensor (debug 42 | // signal) to all URLs in debug_urls. 43 | // 44 | // N.B. Session::Run() supports concurrent invocations of the same inputs 45 | // (feed keys), outputs and target nodes. If such concurrent invocations 46 | // are to be debugged, the callers of Session::Run() must use distinct 47 | // debug_urls to make sure that the streamed or dumped events do not overlap 48 | // among the invocations. 49 | // TODO(cais): More visible documentation of this in g3docs. 50 | repeated string debug_urls = 4; 51 | 52 | // Do not error out if debug op creation fails (e.g., due to dtype 53 | // incompatibility). Instead, just log the failure. 
54 | bool tolerate_debug_op_creation_failures = 5; 55 | } 56 | 57 | // Options for initializing DebuggerState in TensorFlow Debugger (tfdbg). 58 | message DebugOptions { 59 | // Debugging options 60 | repeated DebugTensorWatch debug_tensor_watch_opts = 4; 61 | 62 | // Caller-specified global step count. 63 | // Note that this is distinct from the session run count and the executor 64 | // step count. 65 | int64 global_step = 10; 66 | 67 | // Whether the total disk usage of tfdbg is to be reset to zero 68 | // in this Session.run call. This is used by wrappers and hooks 69 | // such as the local CLI ones to indicate that the dumped tensors 70 | // are cleaned up from the disk after each Session.run. 71 | bool reset_disk_byte_usage = 11; 72 | } 73 | 74 | message DebuggedSourceFile { 75 | // The host name on which a source code file is located. 76 | string host = 1; 77 | 78 | // Path to the source code file. 79 | string file_path = 2; 80 | 81 | // The timestamp at which the source code file is last modified. 82 | int64 last_modified = 3; 83 | 84 | // Byte size of the file. 85 | int64 bytes = 4; 86 | 87 | // Line-by-line content of the source code file. 88 | repeated string lines = 5; 89 | } 90 | 91 | message DebuggedSourceFiles { 92 | // A collection of source code files. 
93 | repeated DebuggedSourceFile source_files = 1; 94 | } 95 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/event.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/summary.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "EventProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.tensorflow.util"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/util/event_go_proto"; 12 | 13 | // Protocol buffer representing an event that happened during 14 | // the execution of a Brain model. 15 | message Event { 16 | // Timestamp of the event. 17 | double wall_time = 1; 18 | 19 | // Global step of the event. 20 | int64 step = 2; 21 | 22 | oneof what { 23 | // An event file was started, with the specified version. 24 | // This is use to identify the contents of the record IO files 25 | // easily. Current version is "brain.Event:2". All versions 26 | // start with "brain.Event:". 27 | string file_version = 3; 28 | // An encoded version of a GraphDef. 29 | bytes graph_def = 4; 30 | // A summary was generated. 31 | Summary summary = 5; 32 | // The user output a log message. This was theoretically used by the defunct 33 | // tensorboard_logging module, which has since been removed; this field is 34 | // now deprecated and should not be used. 35 | LogMessage log_message = 6 [deprecated = true]; 36 | // The state of the session which can be used for restarting after crashes. 37 | SessionLog session_log = 7; 38 | // The metadata returned by running a session.run() call. 39 | TaggedRunMetadata tagged_run_metadata = 8; 40 | // An encoded version of a MetaGraphDef. 41 | bytes meta_graph_def = 9; 42 | } 43 | } 44 | 45 | // Protocol buffer used for logging messages to the events file. 
46 | // 47 | // This was theoretically used by the defunct tensorboard_logging module, which 48 | // has been removed; this message is now deprecated and should not be used. 49 | message LogMessage { 50 | option deprecated = true; 51 | enum Level { 52 | option deprecated = true; 53 | UNKNOWN = 0; 54 | // Note: The logging level 10 cannot be named DEBUG. Some software 55 | // projects compile their C/C++ code with -DDEBUG in debug builds. So the 56 | // C++ code generated from this file should not have an identifier named 57 | // DEBUG. 58 | DEBUGGING = 10; 59 | INFO = 20; 60 | WARN = 30; 61 | ERROR = 40; 62 | FATAL = 50; 63 | } 64 | Level level = 1; 65 | string message = 2; 66 | } 67 | 68 | // Protocol buffer used for logging session state. 69 | message SessionLog { 70 | enum SessionStatus { 71 | STATUS_UNSPECIFIED = 0; 72 | START = 1; 73 | STOP = 2; 74 | CHECKPOINT = 3; 75 | } 76 | 77 | SessionStatus status = 1; 78 | // This checkpoint_path contains both the path and filename. 79 | string checkpoint_path = 2; 80 | string msg = 3; 81 | } 82 | 83 | // For logging the metadata output for a single session.run() call. 84 | message TaggedRunMetadata { 85 | // Tag name associated with this metadata. 86 | string tag = 1; 87 | // Byte-encoded version of the `RunMetadata` proto in order to allow lazy 88 | // deserialization. 89 | bytes run_metadata = 2; 90 | } 91 | 92 | // Worker heartbeat messages. Support for these operations is currently 93 | // internal and expected to change. 94 | 95 | // Current health status of a worker. 96 | enum WorkerHealth { 97 | OK = 0; // By default a worker is healthy. 98 | RECEIVED_SHUTDOWN_SIGNAL = 1; 99 | INTERNAL_ERROR = 2; 100 | SHUTTING_DOWN = 3; // Worker has been instructed to shutdown after a timeout. 101 | } 102 | 103 | // Indicates the behavior of the worker when an internal error or shutdown 104 | // signal is received. 
105 | enum WorkerShutdownMode { 106 | DEFAULT = 0; 107 | NOT_CONFIGURED = 1; 108 | WAIT_FOR_COORDINATOR = 2; 109 | SHUTDOWN_AFTER_TIMEOUT = 3; 110 | } 111 | 112 | message WatchdogConfig { 113 | int64 timeout_ms = 1; 114 | } 115 | 116 | message RequestedExitCode { 117 | int32 exit_code = 1; 118 | } 119 | 120 | message WorkerHeartbeatRequest { 121 | WorkerShutdownMode shutdown_mode = 1; 122 | WatchdogConfig watchdog_config = 2; 123 | RequestedExitCode exit_code = 3; 124 | } 125 | 126 | message WorkerHeartbeatResponse { 127 | WorkerHealth health_status = 1; 128 | repeated Event worker_log = 2; 129 | string hostname = 3; 130 | } 131 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/function.proto"; 6 | import "tensorboard/compat/proto/node_def.proto"; 7 | import "tensorboard/compat/proto/versions.proto"; 8 | 9 | option cc_enable_arenas = true; 10 | option java_outer_classname = "GraphProtos"; 11 | option java_multiple_files = true; 12 | option java_package = "org.tensorflow.framework"; 13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/graph_go_proto"; 14 | 15 | // Represents the graph of operations 16 | message GraphDef { 17 | repeated NodeDef node = 1; 18 | 19 | // Compatibility versions of the graph. See core/public/version.h for version 20 | // history. The GraphDef version is distinct from the TensorFlow version, and 21 | // each release of TensorFlow will support a range of GraphDef versions. 22 | VersionDef versions = 4; 23 | 24 | // Deprecated single version field; use versions above instead. Since all 25 | // GraphDef changes before "versions" was introduced were forward 26 | // compatible, this field is entirely ignored. 
27 | int32 version = 3 [deprecated = true]; 28 | 29 | // "library" provides user-defined functions. 30 | // 31 | // Naming: 32 | // * library.function.name are in a flat namespace. 33 | // NOTE: We may need to change it to be hierarchical to support 34 | // different orgs. E.g., 35 | // { "/google/nn", { ... }}, 36 | // { "/google/vision", { ... }} 37 | // { "/org_foo/module_bar", { ... }} 38 | // map<string, FunctionDefLibrary> named_lib; 39 | // * If node[i].op is the name of one function in "library", 40 | // node[i] is deemed as a function call. Otherwise, node[i].op 41 | // must be a primitive operation supported by the runtime. 42 | // 43 | // 44 | // Function call semantics: 45 | // 46 | // * The callee may start execution as soon as some of its inputs 47 | // are ready. The caller may want to use Tuple() mechanism to 48 | // ensure all inputs are ready in the same time. 49 | // 50 | // * The consumer of return values may start executing as soon as 51 | // the return values the consumer depends on are ready. The 52 | // consumer may want to use Tuple() mechanism to ensure the 53 | // consumer does not start until all return values of the callee 54 | // function are ready. 55 | FunctionDefLibrary library = 2; 56 | } 57 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/node_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/attr_value.proto"; 6 | import "tensorboard/compat/proto/full_type.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | option java_outer_classname = "NodeProto"; 10 | option java_multiple_files = true; 11 | option java_package = "org.tensorflow.framework"; 12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/node_def_go_proto"; 13 | 14 | message NodeDef { 15 | // The name given to this operator. 
Used for naming inputs, 16 | // logging, visualization, etc. Unique within a single GraphDef. 17 | // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_>./]*". 18 | string name = 1; 19 | 20 | // The operation name. There may be custom parameters in attrs. 21 | // Op names starting with an underscore are reserved for internal use. 22 | string op = 2; 23 | 24 | // Each input is "node:src_output" with "node" being a string name and 25 | // "src_output" indicating which output tensor to use from "node". If 26 | // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs 27 | // may optionally be followed by control inputs that have the format 28 | // "^node". 29 | repeated string input = 3; 30 | 31 | // A (possibly partial) specification for the device on which this 32 | // node should be placed. 33 | // The expected syntax for this string is as follows: 34 | // 35 | // DEVICE_SPEC ::= PARTIAL_SPEC 36 | // 37 | // PARTIAL_SPEC ::= ("/" CONSTRAINT) * 38 | // CONSTRAINT ::= ("job:" JOB_NAME) 39 | // | ("replica:" [1-9][0-9]*) 40 | // | ("task:" [1-9][0-9]*) 41 | // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) 42 | // 43 | // Valid values for this string include: 44 | // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) 45 | // * "/job:worker/device:GPU:3" (partial specification) 46 | // * "" (no specification) 47 | // 48 | // If the constraints do not resolve to a single device (or if this 49 | // field is empty or not present), the runtime will attempt to 50 | // choose a device automatically. 51 | string device = 4; 52 | 53 | // Operation-specific graph-construction-time configuration. 54 | // Note that this should include all attrs defined in the 55 | // corresponding OpDef, including those with a value matching 56 | // the default -- this allows the default to change and makes 57 | // NodeDefs easier to interpret on their own. However, if 58 | // an attr with a default is not specified in this list, the 59 | // default will be used. 
60 | // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and 61 | // one of the names from the corresponding OpDef's attr field). 62 | // The values must have a type matching the corresponding OpDef 63 | // attr's type field. 64 | // TODO(josh11b): Add some examples here showing best practices. 65 | map<string, AttrValue> attr = 5; 66 | 67 | message ExperimentalDebugInfo { 68 | // Opaque string inserted into error messages created by the runtime. 69 | // 70 | // This is intended to store the list of names of the nodes from the 71 | // original graph that this node was derived. For example if this node, say 72 | // C, was result of a fusion of 2 nodes A and B, then 'original_node' would 73 | // be {A, B}. This information can be used to map errors originating at the 74 | // current node to some top level source code. 75 | repeated string original_node_names = 1; 76 | 77 | // This is intended to store the list of names of the functions from the 78 | // original graph that this node was derived. For example if this node, say 79 | // C, was result of a fusion of node A in function FA and node B in function 80 | // FB, then `original_funcs` would be {FA, FB}. If the node is in the top 81 | // level graph, the `original_func` is empty. This information, with the 82 | // `original_node_names` can be used to map errors originating at the 83 | // current node to some top level source code. 84 | repeated string original_func_names = 2; 85 | } 86 | 87 | // This stores debug information associated with the node. 88 | ExperimentalDebugInfo experimental_debug_info = 6; 89 | 90 | // The complete type of this node. Experimental and subject to change. 91 | // Currently, the field only contains the return types of the node. That will 92 | // extend in the future to contain the entire signature of the node, as a 93 | // function type. 
94 | FullTypeDef experimental_type = 7; 95 | } 96 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/tensor_shape.proto"; 6 | import "tensorboard/compat/proto/types.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | option java_outer_classname = "ResourceHandle"; 10 | option java_multiple_files = true; 11 | option java_package = "org.tensorflow.framework"; 12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/resource_handle_go_proto"; 13 | 14 | // Protocol buffer representing a handle to a tensorflow resource. Handles are 15 | // not valid across executions, but can be serialized back and forth from within 16 | // a single run. 17 | message ResourceHandleProto { 18 | // Unique name for the device containing the resource. 19 | string device = 1; 20 | 21 | // Container in which this resource is placed. 22 | string container = 2; 23 | 24 | // Unique name of this resource. 25 | string name = 3; 26 | 27 | // Hash code for the type of the resource. Is only valid in the same device 28 | // and in the same execution. 29 | uint64 hash_code = 4; 30 | 31 | // For debug-only, the name of the type pointed to by this handle, if 32 | // available. 33 | string maybe_type_name = 5; 34 | 35 | // Protocol buffer representing a pair of (data type, tensor shape). 36 | message DtypeAndShape { 37 | DataType dtype = 1; 38 | TensorShapeProto shape = 2; 39 | } 40 | 41 | // Data types and shapes for the underlying resource. 
42 | repeated DtypeAndShape dtypes_and_shapes = 6; 43 | 44 | reserved 7; 45 | } 46 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/saver.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "SaverProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.util"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; 10 | 11 | // Protocol buffer representing the configuration of a Saver. 12 | message SaverDef { 13 | // The name of the tensor in which to specify the filename when saving or 14 | // restoring a model checkpoint. 15 | string filename_tensor_name = 1; 16 | 17 | // The operation to run when saving a model checkpoint. 18 | string save_tensor_name = 2; 19 | 20 | // The operation to run when restoring a model checkpoint. 21 | string restore_op_name = 3; 22 | 23 | // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. 24 | int32 max_to_keep = 4; 25 | 26 | // Shard the save files, one per device that has Variable nodes. 27 | bool sharded = 5; 28 | 29 | // How often to keep an additional checkpoint. If not specified, only the last 30 | // "max_to_keep" checkpoints are kept; if specified, in addition to keeping 31 | // the last "max_to_keep" checkpoints, an additional checkpoint will be kept 32 | // for every n hours of training. 33 | float keep_checkpoint_every_n_hours = 6; 34 | 35 | // A version number that identifies a different on-disk checkpoint format. 36 | // Usually, each subclass of BaseSaverBuilder works with a particular 37 | // version/format. However, it is possible that the same builder may be 38 | // upgraded to support a newer checkpoint format in the future. 
39 | enum CheckpointFormatVersion { 40 | // Internal legacy format. 41 | LEGACY = 0; 42 | // Deprecated format: tf.Saver() which works with tensorflow::table::Table. 43 | V1 = 1; 44 | // Current format: more efficient. 45 | V2 = 2; 46 | } 47 | CheckpointFormatVersion version = 7; 48 | } 49 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/step_stats.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/allocation_description.proto"; 6 | import "tensorboard/compat/proto/tensor_description.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | option java_outer_classname = "StepStatsProtos"; 10 | option java_multiple_files = true; 11 | option java_package = "org.tensorflow.framework"; 12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/step_stats_go_proto"; 13 | 14 | // An allocation/de-allocation operation performed by the allocator. 15 | message AllocationRecord { 16 | // The timestamp of the operation. 17 | int64 alloc_micros = 1; 18 | // Number of bytes allocated, or de-allocated if negative. 19 | int64 alloc_bytes = 2; 20 | } 21 | 22 | message AllocatorMemoryUsed { 23 | string allocator_name = 1; 24 | // These are per-node allocator memory stats. 25 | int64 total_bytes = 2; 26 | int64 peak_bytes = 3; 27 | // The bytes that are not deallocated. 28 | int64 live_bytes = 4; 29 | // The allocation and deallocation timeline. 30 | repeated AllocationRecord allocation_records = 6; 31 | 32 | // These are snapshots of the overall allocator memory stats. 33 | // The number of live bytes currently allocated by the allocator. 34 | int64 allocator_bytes_in_use = 5; 35 | } 36 | 37 | // Output sizes recorded for a single execution of a graph node. 
38 | message NodeOutput { 39 | int32 slot = 1; 40 | TensorDescription tensor_description = 3; 41 | } 42 | 43 | // For memory tracking. 44 | message MemoryStats { 45 | int64 temp_memory_size = 1; 46 | int64 persistent_memory_size = 3; 47 | repeated int64 persistent_tensor_alloc_ids = 5; 48 | 49 | int64 device_temp_memory_size = 2 [deprecated = true]; 50 | int64 device_persistent_memory_size = 4 [deprecated = true]; 51 | repeated int64 device_persistent_tensor_alloc_ids = 6 [deprecated = true]; 52 | } 53 | 54 | // Time/size stats recorded for a single execution of a graph node. 55 | message NodeExecStats { 56 | // TODO(tucker): Use some more compact form of node identity than 57 | // the full string name. Either all processes should agree on a 58 | // global id (cost_id?) for each node, or we should use a hash of 59 | // the name. 60 | string node_name = 1; 61 | int64 all_start_micros = 2; 62 | int64 op_start_rel_micros = 3; 63 | int64 op_end_rel_micros = 4; 64 | int64 all_end_rel_micros = 5; 65 | repeated AllocatorMemoryUsed memory = 6; 66 | repeated NodeOutput output = 7; 67 | string timeline_label = 8; 68 | int64 scheduled_micros = 9; 69 | uint32 thread_id = 10; 70 | repeated AllocationDescription referenced_tensor = 11; 71 | MemoryStats memory_stats = 12; 72 | int64 all_start_nanos = 13; 73 | int64 op_start_rel_nanos = 14; 74 | int64 op_end_rel_nanos = 15; 75 | int64 all_end_rel_nanos = 16; 76 | int64 scheduled_nanos = 17; 77 | } 78 | 79 | message DeviceStepStats { 80 | string device = 1; 81 | repeated NodeExecStats node_stats = 2; 82 | // Its key is thread id. 
83 | map<uint32, string> thread_names = 3; 84 | } 85 | 86 | message StepStats { 87 | repeated DeviceStepStats dev_stats = 1; 88 | } 89 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/summary.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/tensor.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "SummaryProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.tensorflow.framework"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/summary_go_proto"; 12 | 13 | // Metadata associated with a series of Summary data 14 | message SummaryDescription { 15 | // Hint on how plugins should process the data in this series. 16 | // Supported values include "scalar", "histogram", "image", "audio" 17 | string type_hint = 1; 18 | } 19 | 20 | // Serialization format for histogram module in 21 | // core/lib/histogram/histogram.h 22 | message HistogramProto { 23 | double min = 1; 24 | double max = 2; 25 | double num = 3; 26 | double sum = 4; 27 | double sum_squares = 5; 28 | 29 | // Parallel arrays encoding the bucket boundaries and the bucket values. 30 | // bucket(i) is the count for the bucket i. The range for 31 | // a bucket is: 32 | // i == 0: -DBL_MAX .. bucket_limit(0) 33 | // i != 0: bucket_limit(i-1) .. bucket_limit(i) 34 | repeated double bucket_limit = 6 [packed = true]; 35 | repeated double bucket = 7 [packed = true]; 36 | } 37 | 38 | // A SummaryMetadata encapsulates information on which plugins are able to make 39 | // use of a certain summary value. 40 | message SummaryMetadata { 41 | message PluginData { 42 | // The name of the plugin this data pertains to. 43 | string plugin_name = 1; 44 | 45 | // The content to store for the plugin. 
The best practice is for this to be 46 | // a binary serialized protocol buffer. 47 | bytes content = 2; 48 | } 49 | 50 | // Data that associates a summary with a certain plugin. 51 | PluginData plugin_data = 1; 52 | 53 | // Display name for viewing in TensorBoard. 54 | string display_name = 2; 55 | 56 | // Longform readable description of the summary sequence. Markdown supported. 57 | string summary_description = 3; 58 | 59 | // Class of data stored in this time series. Required for compatibility with 60 | // TensorBoard's generic data facilities (`DataProvider`, et al.). This value 61 | // imposes constraints on the dtype and shape of the corresponding tensor 62 | // values. See `DataClass` docs for details. 63 | DataClass data_class = 4; 64 | } 65 | 66 | enum DataClass { 67 | // Unknown data class, used (implicitly) for legacy data. Will not be 68 | // processed by data ingestion pipelines. 69 | DATA_CLASS_UNKNOWN = 0; 70 | // Scalar time series. Each `Value` for the corresponding tag must have 71 | // `tensor` set to a rank-0 tensor of type `DT_FLOAT` (float32). 72 | DATA_CLASS_SCALAR = 1; 73 | // Tensor time series. Each `Value` for the corresponding tag must have 74 | // `tensor` set. The tensor value is arbitrary, but should be small to 75 | // accommodate direct storage in database backends: an upper bound of a few 76 | // kilobytes is a reasonable rule of thumb. 77 | DATA_CLASS_TENSOR = 2; 78 | // Blob sequence time series. Each `Value` for the corresponding tag must 79 | // have `tensor` set to a rank-1 tensor of bytestring dtype. 80 | DATA_CLASS_BLOB_SEQUENCE = 3; 81 | } 82 | 83 | // A Summary is a set of named values to be displayed by the 84 | // visualizer. 85 | // 86 | // Summaries are produced regularly during training, as controlled by 87 | // the "summary_interval_secs" attribute of the training operation. 88 | // Summaries are also produced at the end of an evaluation. 89 | message Summary { 90 | message Image { 91 | // Dimensions of the image. 
92 | int32 height = 1; 93 | int32 width = 2; 94 | // Valid colorspace values are 95 | // 1 - grayscale 96 | // 2 - grayscale + alpha 97 | // 3 - RGB 98 | // 4 - RGBA 99 | // 5 - DIGITAL_YUV 100 | // 6 - BGRA 101 | int32 colorspace = 3; 102 | // Image data in encoded format. All image formats supported by 103 | // image_codec::CoderUtil can be stored here. 104 | bytes encoded_image_string = 4; 105 | } 106 | 107 | message Audio { 108 | // Sample rate of the audio in Hz. 109 | float sample_rate = 1; 110 | // Number of channels of audio. 111 | int64 num_channels = 2; 112 | // Length of the audio in frames (samples per channel). 113 | int64 length_frames = 3; 114 | // Encoded audio data and its associated RFC 2045 content type (e.g. 115 | // "audio/wav"). 116 | bytes encoded_audio_string = 4; 117 | string content_type = 5; 118 | } 119 | 120 | message Value { 121 | // This field is deprecated and will not be set. 122 | string node_name = 7; 123 | 124 | // Tag name for the data. Used by TensorBoard plugins to organize data. Tags 125 | // are often organized by scope (which contains slashes to convey 126 | // hierarchy). For example: foo/bar/0 127 | string tag = 1; 128 | 129 | // Contains metadata on the summary value such as which plugins may use it. 130 | // Take note that many summary values may lack a metadata field. This is 131 | // because the FileWriter only keeps a metadata object on the first summary 132 | // value with a certain tag for each tag. TensorBoard then remembers which 133 | // tags are associated with which plugins. This saves space. 134 | SummaryMetadata metadata = 9; 135 | 136 | // Value associated with the tag. 137 | oneof value { 138 | float simple_value = 2; 139 | bytes obsolete_old_style_histogram = 3; 140 | Image image = 4; 141 | HistogramProto histo = 5; 142 | Audio audio = 6; 143 | TensorProto tensor = 8; 144 | } 145 | } 146 | 147 | // Set of values for the summary. 
148 | repeated Value value = 1; 149 | } 150 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/resource_handle.proto"; 6 | import "tensorboard/compat/proto/tensor_shape.proto"; 7 | import "tensorboard/compat/proto/types.proto"; 8 | 9 | option cc_enable_arenas = true; 10 | option java_outer_classname = "TensorProtos"; 11 | option java_multiple_files = true; 12 | option java_package = "org.tensorflow.framework"; 13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto"; 14 | 15 | // Protocol buffer representing a tensor. 16 | message TensorProto { 17 | DataType dtype = 1; 18 | 19 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues. 20 | TensorShapeProto tensor_shape = 2; 21 | 22 | // Only one of the representations below is set, one of "tensor_contents" and 23 | // the "xxx_val" attributes. We are not using oneof because as oneofs cannot 24 | // contain repeated fields it would require another extra set of messages. 25 | 26 | // Version number. 27 | // 28 | // In version 0, if the "repeated xxx" representations contain only one 29 | // element, that element is repeated to fill the shape. This makes it easy 30 | // to represent a constant Tensor with a single value. 31 | int32 version_number = 3; 32 | 33 | // Serialized raw tensor content from either Tensor::AsProtoTensorContent or 34 | // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation 35 | // can be used for all tensor types. The purpose of this representation is to 36 | // reduce serialization overhead during RPC call by avoiding serialization of 37 | // many repeated small items. 
38 | bytes tensor_content = 4; 39 | 40 | // Type specific representations that make it easy to create tensor protos in 41 | // all languages. Only the representation corresponding to "dtype" can 42 | // be set. The values hold the flattened representation of the tensor in 43 | // row major order. 44 | 45 | // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll 46 | // have some pointless zero padding for each value here. 47 | repeated int32 half_val = 13 [packed = true]; 48 | 49 | // DT_FLOAT. 50 | repeated float float_val = 5 [packed = true]; 51 | 52 | // DT_DOUBLE. 53 | repeated double double_val = 6 [packed = true]; 54 | 55 | // DT_INT32, DT_INT16, DT_UINT16, DT_INT8, DT_UINT8. 56 | repeated int32 int_val = 7 [packed = true]; 57 | 58 | // DT_STRING 59 | repeated bytes string_val = 8; 60 | 61 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real 62 | // and imaginary parts of i-th single precision complex. 63 | repeated float scomplex_val = 9 [packed = true]; 64 | 65 | // DT_INT64 66 | repeated int64 int64_val = 10 [packed = true]; 67 | 68 | // DT_BOOL 69 | repeated bool bool_val = 11 [packed = true]; 70 | 71 | // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real 72 | // and imaginary parts of i-th double precision complex. 73 | repeated double dcomplex_val = 12 [packed = true]; 74 | 75 | // DT_RESOURCE 76 | repeated ResourceHandleProto resource_handle_val = 14; 77 | 78 | // DT_VARIANT 79 | repeated VariantTensorDataProto variant_val = 15; 80 | 81 | // DT_UINT32 82 | repeated uint32 uint32_val = 16 [packed = true]; 83 | 84 | // DT_UINT64 85 | repeated uint64 uint64_val = 17 [packed = true]; 86 | } 87 | 88 | // Protocol buffer representing the serialization format of DT_VARIANT tensors. 89 | message VariantTensorDataProto { 90 | // Name of the type of objects being serialized. 91 | string type_name = 1; 92 | // Portions of the object that are not Tensors. 
93 | bytes metadata = 2; 94 | // Tensors contained within objects being serialized. 95 | repeated TensorProto tensors = 3; 96 | } 97 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/tensor_description.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "tensorboard/compat/proto/allocation_description.proto"; 6 | import "tensorboard/compat/proto/tensor_shape.proto"; 7 | import "tensorboard/compat/proto/types.proto"; 8 | 9 | option cc_enable_arenas = true; 10 | option java_outer_classname = "TensorDescriptionProtos"; 11 | option java_multiple_files = true; 12 | option java_package = "org.tensorflow.framework"; 13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_description_go_proto"; 14 | 15 | message TensorDescription { 16 | // Data type of tensor elements 17 | DataType dtype = 1; 18 | 19 | // Shape of the tensor. 20 | TensorShapeProto shape = 2; 21 | 22 | // Information about the size and allocator used for the data 23 | AllocationDescription allocation_description = 4; 24 | } 25 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorShapeProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"; 9 | 10 | package tensorboard; 11 | 12 | // Dimensions of a tensor. 13 | message TensorShapeProto { 14 | // One dimension of the tensor. 15 | message Dim { 16 | // Size of the tensor in that dimension. 
17 | // This value must be >= -1, but values of -1 are reserved for "unknown" 18 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers 19 | // that work with TensorShapeProto may fail at runtime when deserializing 20 | // a TensorShapeProto containing a dim value of -1. 21 | int64 size = 1; 22 | 23 | // Optional name of the tensor dimension. 24 | string name = 2; 25 | }; 26 | 27 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 28 | // for a 30 x 40 2D tensor. If an entry has size -1, this 29 | // corresponds to a dimension of unknown size. The names are 30 | // optional. 31 | // 32 | // The order of entries in "dim" matters: It indicates the layout of the 33 | // values in the tensor in-memory representation. 34 | // 35 | // The first entry in "dim" is the outermost dimension used to layout the 36 | // values, the last entry is the innermost dimension. This matches the 37 | // in-memory layout of RowMajor Eigen tensors. 38 | // 39 | // If "dim.size()" > 0, "unknown_rank" must be false. 40 | repeated Dim dim = 2; 41 | 42 | // If true, the number of dimensions in the shape is unknown. 43 | // 44 | // If true, "dim.size()" must be 0. 45 | bool unknown_rank = 3; 46 | }; 47 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/tfprof_log.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/profiler/tfprof_log_go_proto"; 5 | 6 | import "tensorboard/compat/proto/attr_value.proto"; 7 | import "tensorboard/compat/proto/step_stats.proto"; 8 | 9 | // It specifies the Python callstack that creates an op. 10 | message CodeDef { 11 | repeated Trace traces = 1; 12 | message Trace { 13 | string file = 1 [deprecated = true]; // deprecated by file_id. 
14 | int64 file_id = 6; 15 | 16 | int32 lineno = 2; 17 | 18 | string function = 3 [deprecated = true]; // deprecated by function_id. 19 | int64 function_id = 7; 20 | 21 | string line = 4 [deprecated = true]; // deprecated line_id. 22 | int64 line_id = 8; 23 | 24 | int32 func_start_line = 5; 25 | } 26 | } 27 | 28 | message OpLogEntry { 29 | // op name. 30 | string name = 1; 31 | // float_ops is filled by tfprof Python API when called. It requires the 32 | // op has RegisterStatistics defined. Currently, Conv2D, MatMul, etc, are 33 | // implemented. 34 | int64 float_ops = 2; 35 | // User can define extra op type information for an op. This allows the user 36 | // to select a group of ops precisely using op_type as a key. 37 | repeated string types = 3; 38 | // Used to support tfprof "code" view. 39 | CodeDef code_def = 4; 40 | } 41 | 42 | message OpLogProto { 43 | repeated OpLogEntry log_entries = 1; 44 | 45 | // Maps from id of CodeDef file,function,line to its string 46 | // In the future can also map other id of other fields to string. 47 | map<int64, string> id_to_string = 2; 48 | } 49 | 50 | // A proto representation of the profiler's profile. 51 | // It allows serialization, shipping around and deserialization of the profiles. 52 | // 53 | // Please don't depend on the internals of the profile proto. 54 | message ProfileProto { 55 | map<int64, ProfileNode> nodes = 1; 56 | // Whether or not has code traces. 57 | bool has_trace = 2; 58 | // Whether or not the TF device tracer fails to return accelerator 59 | // information (which could lead to 0 accelerator execution time). 60 | bool miss_accelerator_stream = 5; 61 | // Traced steps. 62 | repeated int64 steps = 3; 63 | 64 | // Maps from id of CodeDef file,function,line to its string 65 | // In the future can also map other id of other fields to string. 66 | map<int64, string> id_to_string = 4; 67 | } 68 | 69 | message ProfileNode { 70 | // graph node name. 71 | string name = 1; 72 | // graph operation type. 73 | string op = 9; 74 | // A unique id for the node. 
75 | int64 id = 13; 76 | 77 | map<int32, int64> inputs = 2; 78 | map<int32, Tuple> input_shapes = 16; 79 | map<int32, int64> outputs = 3; 80 | map<int32, Tuple> output_shapes = 15; 81 | // A map from source node id to its output index to current node. 82 | map<int64, int32> src_output_index = 14; 83 | 84 | repeated int64 shape = 4; 85 | repeated string op_types = 5; 86 | string canonical_device = 6; 87 | string host_device = 7; 88 | 89 | int64 float_ops = 8; 90 | 91 | CodeDef trace = 10; 92 | map<string, AttrValue> attrs = 11; 93 | 94 | map<int64, ExecProfile> execs = 12; 95 | } 96 | 97 | message ExecProfile { 98 | // Can be larger than 1 if run multiple times in loop. 99 | int64 run_count = 1; 100 | // The earliest/latest time including scheduling and execution. 101 | int64 all_start_micros = 2; 102 | int64 latest_end_micros = 3; 103 | 104 | // device -> vector of {op_start_micros, op_exec_micros} pairs. 105 | // accelerator_execs: gpu:id/stream:all -> {op_start_micros, op_exec_micros} 106 | // For accelerator, vector size can be larger than 1, multiple kernel fires 107 | // or in tf.while_loop. 108 | map<string, ExecTime> accelerator_execs = 4; 109 | // cpu_execs: cpu/gpu:id -> {op_start_micros, op_exec_micros} 110 | // For cpu, vector size can be larger than 1 if in tf.while_loop. 111 | map<string, ExecTime> cpu_execs = 5; 112 | 113 | // Each entry to memory information of a scheduling of the node. 114 | // Normally, there will be multiple entries in while_loop. 115 | repeated ExecMemory memory_execs = 7; 116 | // The allocation and deallocation times and sizes throughout execution. 117 | repeated AllocationRecord allocations = 11; 118 | // The devices related to this execution. 119 | repeated string devices = 6; 120 | } 121 | 122 | message ExecTime { 123 | repeated Tuple times = 1; 124 | } 125 | 126 | message ExecMemory { 127 | // This is the timestamp when the memory information was tracked. 128 | int64 memory_micros = 1; 129 | // NOTE: Please don't depend on the following 4 fields yet. Due to 130 | // TensorFlow internal tracing issues, the numbers can be quite wrong. 
131 | // TODO(xpan): Fix the TensorFlow internal tracing. 132 | int64 host_temp_bytes = 2; 133 | int64 host_persistent_bytes = 3; 134 | int64 accelerator_temp_bytes = 4; 135 | int64 accelerator_persistent_bytes = 5; 136 | 137 | // Total bytes requested by the op. 138 | int64 requested_bytes = 6; 139 | // Total bytes requested by the op and released before op end. 140 | int64 peak_bytes = 7; 141 | // Total bytes requested by the op and not released after op end. 142 | int64 residual_bytes = 8; 143 | // Total bytes output by the op (not necessarily requested by the op). 144 | int64 output_bytes = 9; 145 | // The total number of bytes currently allocated by the allocator if >0. 146 | int64 allocator_bytes_in_use = 10; 147 | // The memory of each output of the operation. 148 | map<int32, Memory> output_memory = 11; 149 | } 150 | 151 | message Tuple { 152 | repeated int64 int64_values = 1; 153 | } 154 | 155 | message Memory { 156 | int64 bytes = 1; 157 | uint64 ptr = 2; 158 | } 159 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/trackable_object_graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | import "google/protobuf/wrappers.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; 9 | 10 | // A TensorBundle addition which saves extra information about the objects which 11 | // own variables, allowing for more robust checkpoint loading into modified 12 | // programs. 13 | 14 | message TrackableObjectGraph { 15 | message TrackableObject { 16 | message ObjectReference { 17 | // An index into `TrackableObjectGraph.nodes`, indicating the object 18 | // being referenced. 19 | int32 node_id = 1; 20 | // A user-provided name for the edge. 
21 | string local_name = 2; 22 | } 23 | 24 | message SerializedTensor { 25 | // A name for the Tensor. Simple variables have only one 26 | // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may 27 | // be restored on object creation as an optimization. 28 | string name = 1; 29 | // The full name of the variable/tensor, if applicable. Used to allow 30 | // name-based loading of checkpoints which were saved using an 31 | // object-based API. Should match the checkpoint key which would have been 32 | // assigned by tf.train.Saver. 33 | string full_name = 2; 34 | // The generated name of the Tensor in the checkpoint. 35 | string checkpoint_key = 3; 36 | // Deprecated bool field for optional restore. This field has never been 37 | // set to True. 38 | reserved "optional_restore"; 39 | reserved 4; 40 | } 41 | 42 | message SlotVariableReference { 43 | // An index into `TrackableObjectGraph.nodes`, indicating the 44 | // variable object this slot was created for. 45 | int32 original_variable_node_id = 1; 46 | // The name of the slot (e.g. "m"/"v"). 47 | string slot_name = 2; 48 | // An index into `TrackableObjectGraph.nodes`, indicating the 49 | // `Object` with the value of the slot variable. 50 | int32 slot_variable_node_id = 3; 51 | } 52 | 53 | // Objects which this object depends on. 54 | repeated ObjectReference children = 1; 55 | // Serialized data specific to this object. 56 | repeated SerializedTensor attributes = 2; 57 | // Slot variables owned by this object. 58 | repeated SlotVariableReference slot_variables = 3; 59 | 60 | // The registered saver used to save this object. If this saver is not 61 | // present when loading the checkpoint, then loading will fail. 62 | RegisteredSaver registered_saver = 4; 63 | 64 | // Whether this object has checkpoint values or descendants with checkpoint 65 | // values. 
This is computed at save time to avoid traversing the entire 66 | // object graph proto when restoring (which also has to traverse the live 67 | // object graph). 68 | google.protobuf.BoolValue has_checkpoint_values = 5; 69 | } 70 | 71 | repeated TrackableObject nodes = 1; 72 | } 73 | 74 | message RegisteredSaver { 75 | // The name of the registered saver/restore function. 76 | string name = 1; 77 | 78 | // Unique auto-generated name of the object. 79 | string object_name = 2; 80 | } 81 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/types.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "TypesProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.framework"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"; 10 | 11 | // (== suppress_warning documentation-presence ==) 12 | // DISABLED.IfChange 13 | enum DataType { 14 | // Not a legal value for DataType. Used to indicate a DataType field 15 | // has not been set. 16 | DT_INVALID = 0; 17 | 18 | // Data types that all computation devices are expected to be 19 | // capable to support. 20 | DT_FLOAT = 1; 21 | DT_DOUBLE = 2; 22 | DT_INT32 = 3; 23 | DT_UINT8 = 4; 24 | DT_INT16 = 5; 25 | DT_INT8 = 6; 26 | DT_STRING = 7; 27 | DT_COMPLEX64 = 8; // Single-precision complex 28 | DT_INT64 = 9; 29 | DT_BOOL = 10; 30 | DT_QINT8 = 11; // Quantized int8 31 | DT_QUINT8 = 12; // Quantized uint8 32 | DT_QINT32 = 13; // Quantized int32 33 | DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. 
34 | DT_QINT16 = 15; // Quantized int16 35 | DT_QUINT16 = 16; // Quantized uint16 36 | DT_UINT16 = 17; 37 | DT_COMPLEX128 = 18; // Double-precision complex 38 | DT_HALF = 19; 39 | DT_RESOURCE = 20; 40 | DT_VARIANT = 21; // Arbitrary C++ data types 41 | DT_UINT32 = 22; 42 | DT_UINT64 = 23; 43 | 44 | // Do not use! These are only for parameters. Every enum above 45 | // should have a corresponding value below (verified by types_test). 46 | DT_FLOAT_REF = 101; 47 | DT_DOUBLE_REF = 102; 48 | DT_INT32_REF = 103; 49 | DT_UINT8_REF = 104; 50 | DT_INT16_REF = 105; 51 | DT_INT8_REF = 106; 52 | DT_STRING_REF = 107; 53 | DT_COMPLEX64_REF = 108; 54 | DT_INT64_REF = 109; 55 | DT_BOOL_REF = 110; 56 | DT_QINT8_REF = 111; 57 | DT_QUINT8_REF = 112; 58 | DT_QINT32_REF = 113; 59 | DT_BFLOAT16_REF = 114; 60 | DT_QINT16_REF = 115; 61 | DT_QUINT16_REF = 116; 62 | DT_UINT16_REF = 117; 63 | DT_COMPLEX128_REF = 118; 64 | DT_HALF_REF = 119; 65 | DT_RESOURCE_REF = 120; 66 | DT_VARIANT_REF = 121; 67 | DT_UINT32_REF = 122; 68 | DT_UINT64_REF = 123; 69 | } 70 | // DISABLED.ThenChange( 71 | // https://www.tensorflow.org/code/tensorflow/c/tf_datatype.h, 72 | // https://www.tensorflow.org/code/tensorflow/go/tensor.go, 73 | // https://www.tensorflow.org/code/tensorboard/compat/proto/tensor.cc, 74 | // https://www.tensorflow.org/code/tensorboard/compat/proto/types.h, 75 | // https://www.tensorflow.org/code/tensorboard/compat/proto/types.cc, 76 | // https://www.tensorflow.org/code/tensorboard/compat/proto/dtypes.py, 77 | // https://www.tensorflow.org/code/tensorboard/compat/proto/function.py) 78 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/update.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | set -e 18 | 19 | if [ $# -ne 1 ]; then 20 | echo "usage: $0 <tensorflow-root>" >&2 21 | exit 1 22 | fi 23 | 24 | rsync --existing "$1"/tensorflow/core/framework/*.proto tensorboard/compat/proto/ 25 | rsync --existing "$1"/tensorflow/core/protobuf/*.proto tensorboard/compat/proto/ 26 | rsync --existing "$1"/tensorflow/core/profiler/*.proto tensorboard/compat/proto/ 27 | rsync --existing "$1"/tensorflow/core/util/*.proto tensorboard/compat/proto/ 28 | rsync --existing "$1"/tensorflow/python/framework/*.proto tensorboard/compat/proto/ 29 | 30 | # Rewrite file paths and package names and disable LINT checks. 
31 | find tensorboard/compat/proto/ -type f -name '*.proto' -exec perl -pi \ 32 | -e 's|tensorflow/core/framework|tensorboard/compat/proto|g;' \ 33 | -e 's|tensorflow/core/protobuf|tensorboard/compat/proto|g;' \ 34 | -e 's|tensorflow/core/profiler|tensorboard/compat/proto|g;' \ 35 | -e 's|tensorflow/core/util|tensorboard/compat/proto|g;' \ 36 | -e 's|tensorflow/python/framework|tensorboard/compat/proto|g;' \ 37 | -e 's|package tensorflow.tfprof;|package tensorboard;|g;' \ 38 | -e 's|package tensorflow;|package tensorboard;|g;' \ 39 | -e 's|tensorflow\.DataType|tensorboard.DataType|g;' \ 40 | -e 's|tensorflow\.TensorProto|tensorboard.TensorProto|g;' \ 41 | -e 's|tensorflow\.TensorShapeProto|tensorboard.TensorShapeProto|g;' \ 42 | -e 's|\/\/ LINT\.|// DISABLED.|g;' \ 43 | {} + 44 | 45 | 46 | # Update dependency graph. 47 | ( 48 | cd tensorboard/compat/proto/ 49 | 50 | { 51 | # Keep all organic content from the build file... 52 | sed -n '/AUTOMATICALLY GENERATED/q;p' BUILD 53 | printf '%s\n' \ 54 | '# DO NOT EDIT: This line and rest of file are AUTOMATICALLY GENERATED' \ 55 | '# by tensorboard/compat/proto/update.sh.' \ 56 | '' 57 | 58 | # ...then regenerate the individual proto targets... 59 | for f in *.proto; do 60 | printf 'tb_proto_library(\n' 61 | printf ' name = "%s",\n' "${f%.proto}" 62 | printf ' srcs = ["%s"],\n' "$f" 63 | if grep -q '^import "tensorboard/' "$f"; then 64 | printf ' deps = [\n' 65 | grep '^import "tensorboard/' "$f" | sort | 66 | sed -e 's#.*compat/proto/\([^.]*\).*# ":\1",#' 67 | printf ' ],\n' 68 | fi 69 | printf ')\n\n' 70 | done 71 | 72 | # ...as well as `protos_all`. 73 | printf '%s\n' \ 74 | '# Protobuf files copied from the main TensorFlow repository.' 
\ 75 | '# Keep this list synced with proto_test.py' \ 76 | ; 77 | printf 'tb_proto_library(\n' 78 | printf ' name = "protos_all",\n' 79 | printf ' srcs = [],\n' 80 | printf ' visibility = ["//visibility:public"],\n' 81 | printf ' deps = [\n' 82 | for f in *.proto; do 83 | printf ' ":%s",\n' "${f%.proto}" 84 | done | sort 85 | printf ' ],\n' 86 | printf ')\n' 87 | } | expand -t4 >BUILD.new 88 | mv BUILD.new BUILD 89 | 90 | # We made an effort to be style-compliant above, but try to run buildifier if 91 | # available, just in case. 92 | if command -v buildifier >/dev/null 2>/dev/null; then 93 | buildifier BUILD 94 | else 95 | printf >&2 'warning: buildifier(1) not found; tensorboard/compat/proto/BUILD may have lint errors\n' 96 | fi 97 | ) 98 | 99 | echo "Protos in tensorboard/compat/proto/ updated! You can now add and commit them." 100 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/variable.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "VariableProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.framework"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/variable_go_proto"; 10 | 11 | // Indicates when a distributed variable will be synced. 12 | enum VariableSynchronization { 13 | // `AUTO`: Indicates that the synchronization will be determined by the 14 | // current `DistributionStrategy` (eg. With `MirroredStrategy` this would be 15 | // `ON_WRITE`). 16 | VARIABLE_SYNCHRONIZATION_AUTO = 0; 17 | // `NONE`: Indicates that there will only be one copy of the variable, so 18 | // there is no need to sync. 19 | VARIABLE_SYNCHRONIZATION_NONE = 1; 20 | // `ON_WRITE`: Indicates that the variable will be updated across devices 21 | // every time it is written. 
22 | VARIABLE_SYNCHRONIZATION_ON_WRITE = 2; 23 | // `ON_READ`: Indicates that the variable will be aggregated across devices 24 | // when it is read (eg. when checkpointing or when evaluating an op that uses 25 | // the variable). 26 | VARIABLE_SYNCHRONIZATION_ON_READ = 3; 27 | } 28 | 29 | // Indicates how a distributed variable will be aggregated. 30 | enum VariableAggregation { 31 | // `NONE`: This is the default, giving an error if you use a 32 | // variable-update operation with multiple replicas. 33 | VARIABLE_AGGREGATION_NONE = 0; 34 | // `SUM`: Add the updates across replicas. 35 | VARIABLE_AGGREGATION_SUM = 1; 36 | // `MEAN`: Take the arithmetic mean ("average") of the updates across 37 | // replicas. 38 | VARIABLE_AGGREGATION_MEAN = 2; 39 | // `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same 40 | // update, but we only want to perform the update once. Used, e.g., for the 41 | // global step counter. 42 | VARIABLE_AGGREGATION_ONLY_FIRST_REPLICA = 3; 43 | } 44 | 45 | // Protocol buffer representing a Variable. 46 | message VariableDef { 47 | // Name of the variable tensor. 48 | string variable_name = 1; 49 | 50 | // Name of the tensor holding the variable's initial value. 51 | string initial_value_name = 6; 52 | 53 | // Name of the initializer op. 54 | string initializer_name = 2; 55 | 56 | // Name of the snapshot tensor. 57 | string snapshot_name = 3; 58 | 59 | // Support for saving variables as slices of a larger variable. 60 | SaveSliceInfoDef save_slice_info_def = 4; 61 | 62 | // Whether to represent this as a ResourceVariable. 63 | bool is_resource = 5; 64 | 65 | // Whether this variable should be trained. 66 | bool trainable = 7; 67 | 68 | // Indicates when a distributed variable will be synced. 69 | VariableSynchronization synchronization = 8; 70 | 71 | // Indicates how a distributed variable will be aggregated. 
72 | VariableAggregation aggregation = 9; 73 | } 74 | 75 | message SaveSliceInfoDef { 76 | // Name of the full variable of which this is a slice. 77 | string full_name = 1; 78 | // Shape of the full variable. 79 | repeated int64 full_shape = 2; 80 | // Offset of this variable into the full variable. 81 | repeated int64 var_offset = 3; 82 | // Shape of this variable. 83 | repeated int64 var_shape = 4; 84 | } 85 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/verifier_config.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "VerifierConfigProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.framework"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; 10 | 11 | // The config for graph verifiers. 12 | message VerifierConfig { 13 | enum Toggle { 14 | DEFAULT = 0; 15 | ON = 1; 16 | OFF = 2; 17 | } 18 | 19 | // Deadline for completion of all verification i.e. all the Toggle ON 20 | // verifiers must complete execution within this time. 21 | int64 verification_timeout_in_ms = 1; 22 | 23 | // Perform structural validation on a tensorflow graph. Default is OFF. 
24 | Toggle structure_verifier = 2; 25 | 26 | // Next tag: 3 27 | } 28 | -------------------------------------------------------------------------------- /tensorboard/compat/proto/versions.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "VersionsProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.framework"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/versions_go_proto"; 10 | 11 | // Version information for a piece of serialized data 12 | // 13 | // There are different types of versions for each type of data 14 | // (GraphDef, etc.), but they all have the same common shape 15 | // described here. 16 | // 17 | // Each consumer has "consumer" and "min_producer" versions (specified 18 | // elsewhere). A consumer is allowed to consume this data if 19 | // 20 | // producer >= min_producer 21 | // consumer >= min_consumer 22 | // consumer not in bad_consumers 23 | // 24 | message VersionDef { 25 | // The version of the code that produced this data. 26 | int32 producer = 1; 27 | 28 | // Any consumer below this version is not allowed to consume this data. 29 | int32 min_consumer = 2; 30 | 31 | // Specific consumer versions which are disallowed (e.g. due to bugs). 32 | repeated int32 bad_consumers = 3; 33 | } 34 | -------------------------------------------------------------------------------- /tensorboard/proto-generated/allocation_description.pb.swift: -------------------------------------------------------------------------------- 1 | // DO NOT EDIT. 2 | // swift-format-ignore-file 3 | // 4 | // Generated by the Swift generator plugin for the protocol buffer compiler. 
5 | // Source: tensorboard/compat/proto/allocation_description.proto 6 | // 7 | // For information on using the generated types, please see the documentation: 8 | // https://github.com/apple/swift-protobuf/ 9 | 10 | import Foundation 11 | import SwiftProtobuf 12 | 13 | // If the compiler emits an error on this type, it is because this file 14 | // was generated by a version of the `protoc` Swift plug-in that is 15 | // incompatible with the version of SwiftProtobuf to which you are linking. 16 | // Please ensure that you are building against the same version of the API 17 | // that was used to generate this file. 18 | fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAPIVersionCheck { 19 | struct _2: SwiftProtobuf.ProtobufAPIVersion_2 {} 20 | typealias Version = _2 21 | } 22 | 23 | struct Tensorboard_AllocationDescription: Sendable { 24 | // SwiftProtobuf.Message conformance is added in an extension below. See the 25 | // `Message` and `Message+*Additions` files in the SwiftProtobuf library for 26 | // methods supported on all messages. 27 | 28 | /// Total number of bytes requested 29 | var requestedBytes: Int64 = 0 30 | 31 | /// Total number of bytes allocated if known 32 | var allocatedBytes: Int64 = 0 33 | 34 | /// Name of the allocator used 35 | var allocatorName: String = String() 36 | 37 | /// Identifier of the allocated buffer if known 38 | var allocationID: Int64 = 0 39 | 40 | /// Set if this tensor only has one remaining reference 41 | var hasSingleReference_p: Bool = false 42 | 43 | /// Address of the allocation. 44 | var ptr: UInt64 = 0 45 | 46 | var unknownFields = SwiftProtobuf.UnknownStorage() 47 | 48 | init() {} 49 | } 50 | 51 | // MARK: - Code below here is support for the SwiftProtobuf runtime. 
52 | 53 | fileprivate let _protobuf_package = "tensorboard" 54 | 55 | extension Tensorboard_AllocationDescription: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { 56 | static let protoMessageName: String = _protobuf_package + ".AllocationDescription" 57 | static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 58 | 1: .standard(proto: "requested_bytes"), 59 | 2: .standard(proto: "allocated_bytes"), 60 | 3: .standard(proto: "allocator_name"), 61 | 4: .standard(proto: "allocation_id"), 62 | 5: .standard(proto: "has_single_reference"), 63 | 6: .same(proto: "ptr"), 64 | ] 65 | 66 | mutating func decodeMessage(decoder: inout D) throws { 67 | while let fieldNumber = try decoder.nextFieldNumber() { 68 | // The use of inline closures is to circumvent an issue where the compiler 69 | // allocates stack space for every case branch when no optimizations are 70 | // enabled. https://github.com/apple/swift-protobuf/issues/1034 71 | switch fieldNumber { 72 | case 1: try { try decoder.decodeSingularInt64Field(value: &self.requestedBytes) }() 73 | case 2: try { try decoder.decodeSingularInt64Field(value: &self.allocatedBytes) }() 74 | case 3: try { try decoder.decodeSingularStringField(value: &self.allocatorName) }() 75 | case 4: try { try decoder.decodeSingularInt64Field(value: &self.allocationID) }() 76 | case 5: try { try decoder.decodeSingularBoolField(value: &self.hasSingleReference_p) }() 77 | case 6: try { try decoder.decodeSingularUInt64Field(value: &self.ptr) }() 78 | default: break 79 | } 80 | } 81 | } 82 | 83 | func traverse(visitor: inout V) throws { 84 | if self.requestedBytes != 0 { 85 | try visitor.visitSingularInt64Field(value: self.requestedBytes, fieldNumber: 1) 86 | } 87 | if self.allocatedBytes != 0 { 88 | try visitor.visitSingularInt64Field(value: self.allocatedBytes, fieldNumber: 2) 89 | } 90 | if !self.allocatorName.isEmpty { 91 | try visitor.visitSingularStringField(value: self.allocatorName, 
fieldNumber: 3) 92 | } 93 | if self.allocationID != 0 { 94 | try visitor.visitSingularInt64Field(value: self.allocationID, fieldNumber: 4) 95 | } 96 | if self.hasSingleReference_p != false { 97 | try visitor.visitSingularBoolField(value: self.hasSingleReference_p, fieldNumber: 5) 98 | } 99 | if self.ptr != 0 { 100 | try visitor.visitSingularUInt64Field(value: self.ptr, fieldNumber: 6) 101 | } 102 | try unknownFields.traverse(visitor: &visitor) 103 | } 104 | 105 | static func ==(lhs: Tensorboard_AllocationDescription, rhs: Tensorboard_AllocationDescription) -> Bool { 106 | if lhs.requestedBytes != rhs.requestedBytes {return false} 107 | if lhs.allocatedBytes != rhs.allocatedBytes {return false} 108 | if lhs.allocatorName != rhs.allocatorName {return false} 109 | if lhs.allocationID != rhs.allocationID {return false} 110 | if lhs.hasSingleReference_p != rhs.hasSingleReference_p {return false} 111 | if lhs.ptr != rhs.ptr {return false} 112 | if lhs.unknownFields != rhs.unknownFields {return false} 113 | return true 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /tensorboard/proto-generated/tensor_description.pb.swift: -------------------------------------------------------------------------------- 1 | // DO NOT EDIT. 2 | // swift-format-ignore-file 3 | // 4 | // Generated by the Swift generator plugin for the protocol buffer compiler. 5 | // Source: tensorboard/compat/proto/tensor_description.proto 6 | // 7 | // For information on using the generated types, please see the documentation: 8 | // https://github.com/apple/swift-protobuf/ 9 | 10 | import Foundation 11 | import SwiftProtobuf 12 | 13 | // If the compiler emits an error on this type, it is because this file 14 | // was generated by a version of the `protoc` Swift plug-in that is 15 | // incompatible with the version of SwiftProtobuf to which you are linking. 
16 | // Please ensure that you are building against the same version of the API 17 | // that was used to generate this file. 18 | fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAPIVersionCheck { 19 | struct _2: SwiftProtobuf.ProtobufAPIVersion_2 {} 20 | typealias Version = _2 21 | } 22 | 23 | struct Tensorboard_TensorDescription: Sendable { 24 | // SwiftProtobuf.Message conformance is added in an extension below. See the 25 | // `Message` and `Message+*Additions` files in the SwiftProtobuf library for 26 | // methods supported on all messages. 27 | 28 | /// Data type of tensor elements 29 | var dtype: Tensorboard_DataType = .dtInvalid 30 | 31 | /// Shape of the tensor. 32 | var shape: Tensorboard_TensorShapeProto { 33 | get {return _shape ?? Tensorboard_TensorShapeProto()} 34 | set {_shape = newValue} 35 | } 36 | /// Returns true if `shape` has been explicitly set. 37 | var hasShape: Bool {return self._shape != nil} 38 | /// Clears the value of `shape`. Subsequent reads from it will return its default value. 39 | mutating func clearShape() {self._shape = nil} 40 | 41 | /// Information about the size and allocator used for the data 42 | var allocationDescription: Tensorboard_AllocationDescription { 43 | get {return _allocationDescription ?? Tensorboard_AllocationDescription()} 44 | set {_allocationDescription = newValue} 45 | } 46 | /// Returns true if `allocationDescription` has been explicitly set. 47 | var hasAllocationDescription: Bool {return self._allocationDescription != nil} 48 | /// Clears the value of `allocationDescription`. Subsequent reads from it will return its default value. 49 | mutating func clearAllocationDescription() {self._allocationDescription = nil} 50 | 51 | var unknownFields = SwiftProtobuf.UnknownStorage() 52 | 53 | init() {} 54 | 55 | fileprivate var _shape: Tensorboard_TensorShapeProto? = nil 56 | fileprivate var _allocationDescription: Tensorboard_AllocationDescription? 
= nil 57 | } 58 | 59 | // MARK: - Code below here is support for the SwiftProtobuf runtime. 60 | 61 | fileprivate let _protobuf_package = "tensorboard" 62 | 63 | extension Tensorboard_TensorDescription: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { 64 | static let protoMessageName: String = _protobuf_package + ".TensorDescription" 65 | static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 66 | 1: .same(proto: "dtype"), 67 | 2: .same(proto: "shape"), 68 | 4: .standard(proto: "allocation_description"), 69 | ] 70 | 71 | mutating func decodeMessage(decoder: inout D) throws { 72 | while let fieldNumber = try decoder.nextFieldNumber() { 73 | // The use of inline closures is to circumvent an issue where the compiler 74 | // allocates stack space for every case branch when no optimizations are 75 | // enabled. https://github.com/apple/swift-protobuf/issues/1034 76 | switch fieldNumber { 77 | case 1: try { try decoder.decodeSingularEnumField(value: &self.dtype) }() 78 | case 2: try { try decoder.decodeSingularMessageField(value: &self._shape) }() 79 | case 4: try { try decoder.decodeSingularMessageField(value: &self._allocationDescription) }() 80 | default: break 81 | } 82 | } 83 | } 84 | 85 | func traverse(visitor: inout V) throws { 86 | // The use of inline closures is to circumvent an issue where the compiler 87 | // allocates stack space for every if/case branch local when no optimizations 88 | // are enabled. 
https://github.com/apple/swift-protobuf/issues/1034 and 89 | // https://github.com/apple/swift-protobuf/issues/1182 90 | if self.dtype != .dtInvalid { 91 | try visitor.visitSingularEnumField(value: self.dtype, fieldNumber: 1) 92 | } 93 | try { if let v = self._shape { 94 | try visitor.visitSingularMessageField(value: v, fieldNumber: 2) 95 | } }() 96 | try { if let v = self._allocationDescription { 97 | try visitor.visitSingularMessageField(value: v, fieldNumber: 4) 98 | } }() 99 | try unknownFields.traverse(visitor: &visitor) 100 | } 101 | 102 | static func ==(lhs: Tensorboard_TensorDescription, rhs: Tensorboard_TensorDescription) -> Bool { 103 | if lhs.dtype != rhs.dtype {return false} 104 | if lhs._shape != rhs._shape {return false} 105 | if lhs._allocationDescription != rhs._allocationDescription {return false} 106 | if lhs.unknownFields != rhs.unknownFields {return false} 107 | return true 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /tensorboard/proto-generated/verifier_config.pb.swift: -------------------------------------------------------------------------------- 1 | // DO NOT EDIT. 2 | // swift-format-ignore-file 3 | // 4 | // Generated by the Swift generator plugin for the protocol buffer compiler. 5 | // Source: tensorboard/compat/proto/verifier_config.proto 6 | // 7 | // For information on using the generated types, please see the documentation: 8 | // https://github.com/apple/swift-protobuf/ 9 | 10 | import Foundation 11 | import SwiftProtobuf 12 | 13 | // If the compiler emits an error on this type, it is because this file 14 | // was generated by a version of the `protoc` Swift plug-in that is 15 | // incompatible with the version of SwiftProtobuf to which you are linking. 16 | // Please ensure that you are building against the same version of the API 17 | // that was used to generate this file. 
18 | fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAPIVersionCheck { 19 | struct _2: SwiftProtobuf.ProtobufAPIVersion_2 {} 20 | typealias Version = _2 21 | } 22 | 23 | /// The config for graph verifiers. 24 | struct Tensorboard_VerifierConfig: Sendable { 25 | // SwiftProtobuf.Message conformance is added in an extension below. See the 26 | // `Message` and `Message+*Additions` files in the SwiftProtobuf library for 27 | // methods supported on all messages. 28 | 29 | /// Deadline for completion of all verification i.e. all the Toggle ON 30 | /// verifiers must complete execution within this time. 31 | var verificationTimeoutInMs: Int64 = 0 32 | 33 | /// Perform structural validation on a tensorflow graph. Default is OFF. 34 | var structureVerifier: Tensorboard_VerifierConfig.Toggle = .default 35 | 36 | var unknownFields = SwiftProtobuf.UnknownStorage() 37 | 38 | enum Toggle: SwiftProtobuf.Enum, Swift.CaseIterable { 39 | typealias RawValue = Int 40 | case `default` // = 0 41 | case on // = 1 42 | case off // = 2 43 | case UNRECOGNIZED(Int) 44 | 45 | init() { 46 | self = .default 47 | } 48 | 49 | init?(rawValue: Int) { 50 | switch rawValue { 51 | case 0: self = .default 52 | case 1: self = .on 53 | case 2: self = .off 54 | default: self = .UNRECOGNIZED(rawValue) 55 | } 56 | } 57 | 58 | var rawValue: Int { 59 | switch self { 60 | case .default: return 0 61 | case .on: return 1 62 | case .off: return 2 63 | case .UNRECOGNIZED(let i): return i 64 | } 65 | } 66 | 67 | // The compiler won't synthesize support with the UNRECOGNIZED case. 68 | static let allCases: [Tensorboard_VerifierConfig.Toggle] = [ 69 | .default, 70 | .on, 71 | .off, 72 | ] 73 | 74 | } 75 | 76 | init() {} 77 | } 78 | 79 | // MARK: - Code below here is support for the SwiftProtobuf runtime. 
80 | 81 | fileprivate let _protobuf_package = "tensorboard" 82 | 83 | extension Tensorboard_VerifierConfig: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { 84 | static let protoMessageName: String = _protobuf_package + ".VerifierConfig" 85 | static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 86 | 1: .standard(proto: "verification_timeout_in_ms"), 87 | 2: .standard(proto: "structure_verifier"), 88 | ] 89 | 90 | mutating func decodeMessage(decoder: inout D) throws { 91 | while let fieldNumber = try decoder.nextFieldNumber() { 92 | // The use of inline closures is to circumvent an issue where the compiler 93 | // allocates stack space for every case branch when no optimizations are 94 | // enabled. https://github.com/apple/swift-protobuf/issues/1034 95 | switch fieldNumber { 96 | case 1: try { try decoder.decodeSingularInt64Field(value: &self.verificationTimeoutInMs) }() 97 | case 2: try { try decoder.decodeSingularEnumField(value: &self.structureVerifier) }() 98 | default: break 99 | } 100 | } 101 | } 102 | 103 | func traverse(visitor: inout V) throws { 104 | if self.verificationTimeoutInMs != 0 { 105 | try visitor.visitSingularInt64Field(value: self.verificationTimeoutInMs, fieldNumber: 1) 106 | } 107 | if self.structureVerifier != .default { 108 | try visitor.visitSingularEnumField(value: self.structureVerifier, fieldNumber: 2) 109 | } 110 | try unknownFields.traverse(visitor: &visitor) 111 | } 112 | 113 | static func ==(lhs: Tensorboard_VerifierConfig, rhs: Tensorboard_VerifierConfig) -> Bool { 114 | if lhs.verificationTimeoutInMs != rhs.verificationTimeoutInMs {return false} 115 | if lhs.structureVerifier != rhs.structureVerifier {return false} 116 | if lhs.unknownFields != rhs.unknownFields {return false} 117 | return true 118 | } 119 | } 120 | 121 | extension Tensorboard_VerifierConfig.Toggle: SwiftProtobuf._ProtoNameProviding { 122 | static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 123 | 0: 
.same(proto: "DEFAULT"), 124 | 1: .same(proto: "ON"), 125 | 2: .same(proto: "OFF"), 126 | ] 127 | } 128 | -------------------------------------------------------------------------------- /tensorboard/proto-generated/versions.pb.swift: -------------------------------------------------------------------------------- 1 | // DO NOT EDIT. 2 | // swift-format-ignore-file 3 | // 4 | // Generated by the Swift generator plugin for the protocol buffer compiler. 5 | // Source: tensorboard/compat/proto/versions.proto 6 | // 7 | // For information on using the generated types, please see the documentation: 8 | // https://github.com/apple/swift-protobuf/ 9 | 10 | import Foundation 11 | import SwiftProtobuf 12 | 13 | // If the compiler emits an error on this type, it is because this file 14 | // was generated by a version of the `protoc` Swift plug-in that is 15 | // incompatible with the version of SwiftProtobuf to which you are linking. 16 | // Please ensure that you are building against the same version of the API 17 | // that was used to generate this file. 18 | fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAPIVersionCheck { 19 | struct _2: SwiftProtobuf.ProtobufAPIVersion_2 {} 20 | typealias Version = _2 21 | } 22 | 23 | /// Version information for a piece of serialized data 24 | /// 25 | /// There are different types of versions for each type of data 26 | /// (GraphDef, etc.), but they all have the same common shape 27 | /// described here. 28 | /// 29 | /// Each consumer has "consumer" and "min_producer" versions (specified 30 | /// elsewhere). A consumer is allowed to consume this data if 31 | /// 32 | /// producer >= min_producer 33 | /// consumer >= min_consumer 34 | /// consumer not in bad_consumers 35 | struct Tensorboard_VersionDef: Sendable { 36 | // SwiftProtobuf.Message conformance is added in an extension below. 
See the 37 | // `Message` and `Message+*Additions` files in the SwiftProtobuf library for 38 | // methods supported on all messages. 39 | 40 | /// The version of the code that produced this data. 41 | var producer: Int32 = 0 42 | 43 | /// Any consumer below this version is not allowed to consume this data. 44 | var minConsumer: Int32 = 0 45 | 46 | /// Specific consumer versions which are disallowed (e.g. due to bugs). 47 | var badConsumers: [Int32] = [] 48 | 49 | var unknownFields = SwiftProtobuf.UnknownStorage() 50 | 51 | init() {} 52 | } 53 | 54 | // MARK: - Code below here is support for the SwiftProtobuf runtime. 55 | 56 | fileprivate let _protobuf_package = "tensorboard" 57 | 58 | extension Tensorboard_VersionDef: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { 59 | static let protoMessageName: String = _protobuf_package + ".VersionDef" 60 | static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 61 | 1: .same(proto: "producer"), 62 | 2: .standard(proto: "min_consumer"), 63 | 3: .standard(proto: "bad_consumers"), 64 | ] 65 | 66 | mutating func decodeMessage(decoder: inout D) throws { 67 | while let fieldNumber = try decoder.nextFieldNumber() { 68 | // The use of inline closures is to circumvent an issue where the compiler 69 | // allocates stack space for every case branch when no optimizations are 70 | // enabled. 
https://github.com/apple/swift-protobuf/issues/1034 71 | switch fieldNumber { 72 | case 1: try { try decoder.decodeSingularInt32Field(value: &self.producer) }() 73 | case 2: try { try decoder.decodeSingularInt32Field(value: &self.minConsumer) }() 74 | case 3: try { try decoder.decodeRepeatedInt32Field(value: &self.badConsumers) }() 75 | default: break 76 | } 77 | } 78 | } 79 | 80 | func traverse(visitor: inout V) throws { 81 | if self.producer != 0 { 82 | try visitor.visitSingularInt32Field(value: self.producer, fieldNumber: 1) 83 | } 84 | if self.minConsumer != 0 { 85 | try visitor.visitSingularInt32Field(value: self.minConsumer, fieldNumber: 2) 86 | } 87 | if !self.badConsumers.isEmpty { 88 | try visitor.visitPackedInt32Field(value: self.badConsumers, fieldNumber: 3) 89 | } 90 | try unknownFields.traverse(visitor: &visitor) 91 | } 92 | 93 | static func ==(lhs: Tensorboard_VersionDef, rhs: Tensorboard_VersionDef) -> Bool { 94 | if lhs.producer != rhs.producer {return false} 95 | if lhs.minConsumer != rhs.minConsumer {return false} 96 | if lhs.badConsumers != rhs.badConsumers {return false} 97 | if lhs.unknownFields != rhs.unknownFields {return false} 98 | return true 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /test/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_swift//swift:swift.bzl", "swift_test") 2 | 3 | swift_test( 4 | name = "nnc", 5 | srcs = [ 6 | "dataframe.swift", 7 | "graph.swift", 8 | "loss.swift", 9 | "main.swift", 10 | "model.swift", 11 | "ops.swift", 12 | "optimizer.swift", 13 | "store.swift", 14 | "tensor.swift", 15 | ], 16 | data = [ 17 | "scaled_data.csv", 18 | "some_variables.db", 19 | ], 20 | deps = [ 21 | "//nnc", 22 | ], 23 | ) 24 | 25 | swift_test( 26 | name = "nnc_python", 27 | srcs = [ 28 | "python/main.swift", 29 | "python/numpy.swift", 30 | ], 31 | deps = [ 32 | "//nnc:nnc_python", 33 | ], 34 | ) 35 | 36 | 
swift_test( 37 | name = "nnc_coreml", 38 | srcs = [ 39 | "coreml/main.swift", 40 | "coreml/mlshapedarray.swift", 41 | ], 42 | deps = [ 43 | "//nnc:nnc_coreml", 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /test/coreml/main.swift: -------------------------------------------------------------------------------- 1 | #if os(Linux) 2 | import XCTest 3 | 4 | XCTMain([ 5 | testCase(MLShapedArrayTests.allTests) 6 | ]) 7 | 8 | #endif 9 | -------------------------------------------------------------------------------- /test/coreml/mlshapedarray.swift: -------------------------------------------------------------------------------- 1 | import NNC 2 | import NNCCoreMLConversion 3 | #if canImport(CoreML) 4 | import CoreML 5 | import XCTest 6 | 7 | final class MLShapedArrayTests: XCTestCase { 8 | 9 | func testMakeArray() throws { 10 | var tensor = Tensor(.CPU, .NC(2, 3)) 11 | tensor[0, 0] = 1 12 | tensor[0, 1] = 2 13 | tensor[0, 2] = 3 14 | tensor[1, 0] = 4 15 | tensor[1, 1] = 5 16 | tensor[1, 2] = 6 17 | let array = MLShapedArray(tensor) 18 | XCTAssertEqual(2.0, array[scalarAt: 0, 1]) 19 | XCTAssertEqual(6.0, array[scalarAt: 1, 2]) 20 | } 21 | 22 | func testReadArray() throws { 23 | let array = MLShapedArray(scalars: [1, 2, 3, 4, 5, 6], shape: [2, 3]) 24 | let tensor = Tensor(array) 25 | XCTAssertEqual(1.0, tensor[0, 0]) 26 | XCTAssertEqual(6.0, tensor[1, 2]) 27 | } 28 | 29 | func testMakeArrayFromGPU() throws { 30 | var tensor = Tensor(.CPU, .NC(2, 3)) 31 | tensor[0, 0] = 1 32 | tensor[0, 1] = 2 33 | tensor[0, 2] = 3 34 | tensor[1, 0] = 4 35 | tensor[1, 1] = 5 36 | tensor[1, 2] = 6 37 | let array = MLShapedArray(tensor.toGPU(0)) 38 | XCTAssertEqual(2.0, array[scalarAt: 0, 1]) 39 | XCTAssertEqual(6.0, array[scalarAt: 1, 2]) 40 | } 41 | 42 | static let allTests = [ 43 | ("testMakeArray", testMakeArray), 44 | ("testReadArray", testReadArray), 45 | ("testMakeArrayFromGPU", testMakeArrayFromGPU), 46 | ] 47 | } 48 | #endif 49 | 
--------------------------------------------------------------------------------
/test/loss.swift:
--------------------------------------------------------------------------------
import NNC
import XCTest

/// Tests for `SoftmaxCrossEntropyLoss` with both dense (2x2) and index-style
/// (2x1) targets.
final class LossTests: XCTestCase {

  /// Dense targets with the same shape as the logits. The loss call returns a
  /// pair: [0] has shape 2x1 (per-example loss) and [1] has shape 2x2.
  /// NOTE(review): the meaning of the second element is inferred from its
  /// shape only — confirm against the SoftmaxCrossEntropyLoss implementation.
  func testLoss() throws {
    let dynamicGraph = DynamicGraph()
    let tv0 = dynamicGraph.variable(Tensor([0.5, 0.5, 0.2, 0.8], .CPU, .NC(2, 2)))
    let tv1 = dynamicGraph.variable(Tensor([0, 1, 1, 0], .CPU, .NC(2, 2)))
    let loss = SoftmaxCrossEntropyLoss()
    let tv2 = loss(tv0, target: tv1)
    XCTAssertEqual([2, 1], tv2[0].shape)
    XCTAssertEqual([2, 2], tv2[1].shape)
  }

  /// Class-index targets (one index per example, shape 2x1) produce the same
  /// output shapes as dense targets.
  func testTargetLoss() throws {
    let dynamicGraph = DynamicGraph()
    let tv0 = dynamicGraph.variable(Tensor([0.5, 0.5, 0.2, 0.8], .CPU, .NC(2, 2)))
    let tv1 = dynamicGraph.variable(Tensor([0, 1], .CPU, .NC(2, 1)))
    let loss = SoftmaxCrossEntropyLoss()
    let tv2 = loss(tv0, target: tv1)
    XCTAssertEqual([2, 1], tv2[0].shape)
    XCTAssertEqual([2, 2], tv2[1].shape)
  }

  // Explicit test list for XCTMain on platforms without runtime discovery.
  static let allTests = [
    ("testLoss", testLoss),
    ("testTargetLoss", testTargetLoss),
  ]
}
--------------------------------------------------------------------------------
/test/main.swift:
--------------------------------------------------------------------------------
// Explicit test entry point for Linux, where swift-corelibs-xctest cannot
// discover test cases at runtime. Each listed suite is defined in a sibling
// test file of this target.
#if os(Linux)
import XCTest

XCTMain([
  testCase(DataFrameTests.allTests),
  testCase(GraphTests.allTests),
  testCase(LossTests.allTests),
  testCase(ModelTests.allTests),
  testCase(OpsTests.allTests),
  testCase(OptimizerTests.allTests),
  testCase(StoreTests.allTests),
  testCase(TensorTests.allTests),
])

#endif
--------------------------------------------------------------------------------
/test/optimizer.swift:
--------------------------------------------------------------------------------
import NNC
import XCTest

/// Tests for `SGDOptimizer`, fitting `y = w * log(x) + b` toward a target of
/// 5 at x = 10, both through a `Dense` model and through raw graph variables.
final class OptimizerTests: XCTestCase {

  /// Runs 100 SGD steps training `linear` so that `linear(log(10))` converges
  /// to 5. Alternates between a batched 2x1 input and a single-element input
  /// to exercise both shapes. Shared by the model-based tests below; the
  /// statements are unchanged from the original inline loops.
  private func trainLinear(_ dynamicGraph: DynamicGraph, linear: Dense) {
    let z = dynamicGraph.variable(Tensor([5], .CPU, .C(1)))
    var sgd = SGDOptimizer(
      dynamicGraph, nesterov: false, rate: 0.01, scale: 1, decay: 0.01, momentum: 0, dampening: 0)
    sgd.parameters = [linear.parameters]
    for i in 0..<100 {
      let x: DynamicGraph.Tensor
      if i % 2 == 1 {
        x = dynamicGraph.variable(.CPU, .NC(2, 1))
        x[0, 0] = 10
        x[1, 0] = 10
      } else {
        x = dynamicGraph.variable(.CPU, .C(1))
        x[0] = 10
      }
      var y = Functional.log(x)
      y = linear(y)
      y = y - z
      // Squared error; backward through to x populates the parameter grads.
      let f = y .* y
      f.backward(to: x)
      sgd.step()
    }
  }

  /// SGD over a `Dense` model's parameters converges to the target output.
  func testSGDOnModel() throws {
    let dynamicGraph = DynamicGraph()
    let linear = Dense(count: 1)
    trainLinear(dynamicGraph, linear: linear)
    let x: DynamicGraph.Tensor = dynamicGraph.variable(.CPU, .C(1))
    x[0] = 10
    var y = Functional.log(x)
    y = linear(y)
    XCTAssertEqual(y[0], 5, accuracy: 1e-2)
  }

  /// Same convergence, but with the weight and bias held as raw graph
  /// variables instead of a model, and the weight randomly initialized.
  func testSGDOnGraph() throws {
    let dynamicGraph = DynamicGraph()
    let weight: DynamicGraph.Tensor = dynamicGraph.variable(.CPU, .C(1))
    let bias = dynamicGraph.variable(Tensor([0], .CPU, .C(1)))
    weight.rand(-1...1)
    let z = dynamicGraph.variable(Tensor([5], .CPU, .C(1)))
    var sgd = SGDOptimizer(
      dynamicGraph, nesterov: false, rate: 0.01, scale: 1, decay: 0.01, momentum: 0, dampening: 0)
    sgd.parameters = [weight, bias]
    for i in 0..<100 {
      let x: DynamicGraph.Tensor
      if i % 2 == 1 {
        x = dynamicGraph.variable(.CPU, .NC(2, 1))
        x[0, 0] = 10
        x[1, 0] = 10
      } else {
        x = dynamicGraph.variable(.CPU, .C(1))
        x[0] = 10
      }
      var y = Functional.log(x)
      y = y * weight + bias
      y = y - z
      let f = y .* y
      f.backward(to: x)
      sgd.step()
    }
    let x: DynamicGraph.Tensor = dynamicGraph.variable(.CPU, .C(1))
    x[0] = 10
    var y = Functional.log(x)
    y = y * weight + bias
    XCTAssertEqual(y[0], 5, accuracy: 1e-2)
  }

  /// A structural copy plus `parameters.copy(from:)` reproduces the trained
  /// model's output.
  func testCopyTrainedModel() throws {
    let dynamicGraph = DynamicGraph()
    let linear = Dense(count: 1)
    trainLinear(dynamicGraph, linear: linear)
    let x: DynamicGraph.Tensor = dynamicGraph.variable(.CPU, .C(1))
    x[0] = 10
    var y = Functional.log(x)
    y = linear(y)
    XCTAssertEqual(y[0], 5, accuracy: 1e-2)
    let copy = linear.copied()
    copy.parameters.copy(from: linear.parameters)
    var yCopy = Functional.log(x)
    yCopy = copy(yCopy)
    XCTAssertEqual(yCopy[0], 5, accuracy: 1e-2)
  }

  // Explicit test list for XCTMain on platforms without runtime discovery.
  static let allTests = [
    ("testSGDOnModel", testSGDOnModel),
    ("testSGDOnGraph", testSGDOnGraph),
    ("testCopyTrainedModel", testCopyTrainedModel),
  ]
}
--------------------------------------------------------------------------------
/test/python/main.swift:
--------------------------------------------------------------------------------
// Explicit test entry point for Linux, where swift-corelibs-xctest cannot
// discover test cases at runtime.
#if os(Linux)
import XCTest

XCTMain([
  testCase(NumpyTests.allTests)
])

#endif
--------------------------------------------------------------------------------
/test/python/numpy.swift:
--------------------------------------------------------------------------------
import NNC
import NNCPythonConversion
import PythonKit
import XCTest

/// Round-trip conversion tests between NNC `Tensor` and numpy arrays via
/// PythonKit. Requires a Python runtime with numpy (and torch for the
/// transposed test) available at run time.
final class NumpyTests: XCTestCase {

  /// Tensor -> numpy conversion preserves the 2x3 layout and values.
  func testMakeNumpyArray() throws {
    var tensor = Tensor(.CPU, .NC(2, 3))
    tensor[0, 0] = 1
    tensor[0, 1] = 2
    tensor[0, 2] = 3
    tensor[1, 0] = 4
    tensor[1, 1] = 5
    tensor[1, 2] = 6
    let array = tensor.makeNumpyArray()
    XCTAssertEqual(2.0, array[0, 1])
    XCTAssertEqual(6.0, array[1, 2])
  }

  /// numpy -> Tensor conversion of a contiguous array preserves values.
  func testReadNumpyArray() throws {
    let np = Python.import("numpy")
    let array = np.ones(PythonObject(tupleOf: 2, 3))
    let tensor = try Tensor(numpy: array)
    XCTAssertEqual(1.0, tensor[0, 0])
    XCTAssertEqual(1.0, tensor[1, 2])
  }

  /// A transposed (non-contiguous) numpy view converts element-for-element.
  /// Uses `XCTUnwrap` instead of force-unwrapping the Python->Float bridge so
  /// a bridging failure fails the test rather than crashing the process.
  /// NOTE(review): 768x768 comparisons through the Python bridge are slow;
  /// consider a smaller matrix if this test dominates run time.
  func testReadNumpyArrayTransposed() throws {
    let torch = Python.import("torch")
    let array = torch.randn(768, 768).type(torch.float).numpy()
    let np = Python.import("numpy")
    let nparray = np.transpose(array)
    let tensor = try Tensor(numpy: nparray)
    for i in 0..<768 {
      for j in 0..<768 {
        XCTAssertEqual(try XCTUnwrap(Float(nparray[i, j])), tensor[i, j])
      }
    }
  }

  // Explicit test list for XCTMain on platforms without runtime discovery.
  static let allTests = [
    ("testMakeNumpyArray", testMakeNumpyArray),
    ("testReadNumpyArray", testReadNumpyArray),
    ("testReadNumpyArrayTransposed", testReadNumpyArrayTransposed),
  ]
}
--------------------------------------------------------------------------------
/test/some_variables.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuliu/s4nnc/76508762da1553344430d1be580bdd740fa674e9/test/some_variables.db
--------------------------------------------------------------------------------