├── docs ├── _static │ └── style.css ├── requirements.txt ├── _templates │ ├── footer.html │ └── layout.html ├── api.rst ├── Makefile ├── index.rst ├── conf.py ├── installation.rst └── extras.rst ├── examples ├── cpp │ └── cart_pole │ │ ├── config_sim.yaml │ │ ├── media │ │ ├── cartpole.gif │ │ └── cartpole.mp4 │ │ ├── py_env.cpp │ │ ├── config.yaml │ │ ├── env.h │ │ ├── python │ │ ├── initialize_models.py │ │ └── models.py │ │ ├── CMakeLists.txt │ │ ├── env.cpp │ │ └── README.md └── fortran │ ├── graph │ ├── media │ │ └── validation_results.gif │ ├── config.yaml │ ├── CMakeLists.txt │ ├── generate_loss.py │ ├── visualize.py │ └── generate_model.py │ └── simulation │ ├── media │ └── validation_results.gif │ ├── config_fcn_torchscript.yaml │ ├── config_mlp_native.yaml │ ├── generate_fcn_model.py │ ├── CMakeLists.txt │ ├── simulation.f90 │ └── visualize.py ├── tests ├── supervised │ ├── configs │ │ ├── missing_opt.yaml │ │ ├── missing_loss.yaml │ │ ├── torchscript.yaml │ │ ├── mlp2.yaml │ │ ├── mlp3.yaml │ │ ├── mlp2_gradacc.yaml │ │ ├── torchscript_multiarg.yaml │ │ ├── torchscript_multiarg_extra.yaml │ │ └── mlp.yaml │ ├── scripts │ │ └── setup_tests.py │ └── CMakeLists.txt ├── general │ ├── configs │ │ ├── l1.yaml │ │ ├── mse.yaml │ │ ├── l1_multiarg.yaml │ │ ├── mse_multiarg.yaml │ │ ├── torchscript.yaml │ │ ├── torchscript_multiout.yaml │ │ ├── torchscript_multiarg.yaml │ │ └── torchscript_multiarg_extra.yaml │ ├── scripts │ │ └── setup_tests.py │ └── CMakeLists.txt ├── rl │ ├── configs │ │ ├── ppo.yaml │ │ ├── ddpg.yaml │ │ ├── td3.yaml │ │ └── sac.yaml │ ├── test_distributions.cpp │ └── CMakeLists.txt └── test_utils.h ├── requirements.txt ├── .clang-format ├── .github ├── scripts │ └── run_ci_tests.sh └── workflows │ ├── docs.yaml │ ├── format.yaml │ └── build.yml ├── src └── csrc │ ├── include │ ├── internal │ │ ├── rl │ │ │ ├── rl.h │ │ │ ├── distributions.h │ │ │ └── utils.h │ │ ├── base_model.h │ │ ├── base_loss.h │ │ ├── model_state.h │ │ ├── 
tensor_list.h │ │ ├── model_pack.h │ │ ├── distributed.h │ │ ├── model_wrapper.h │ │ ├── nvtx.h │ │ ├── setup.h │ │ ├── base_lr_scheduler.h │ │ ├── losses.h │ │ ├── training.h │ │ ├── logging.h │ │ ├── lr_schedulers.h │ │ ├── models.h │ │ ├── utils.h │ │ └── param_map.h │ ├── torchfort_config.h.in │ └── torchfort_enums.h │ ├── param_map.cpp │ ├── lr_schedulers │ ├── step_lr.cpp │ ├── polynomial_lr.cpp │ ├── linear_lr.cpp │ ├── multistep_lr.cpp │ └── cosine_annealing_lr.cpp │ ├── model_state.cpp │ ├── losses │ ├── l1_loss.cpp │ ├── mse_loss.cpp │ └── torchscript_loss.cpp │ ├── models │ ├── mlp_model.cpp │ └── rl │ │ ├── sac_model.cpp │ │ └── common_models.cpp │ ├── logging.cpp │ ├── utils.cpp │ ├── model_pack.cpp │ ├── model_wrapper.cpp │ └── rl │ └── utils.cpp ├── README.md ├── SECURITY.md ├── docker ├── Dockerfile_gnu_cpuonly ├── Dockerfile_gnu └── Dockerfile └── CONTRIBUTING.md /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: 1200px !important; 3 | } 4 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/config_sim.yaml: -------------------------------------------------------------------------------- 1 | num_episodes: 50000 2 | max_steps_per_episode: 2500 3 | eval_frequency: 25 4 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==8.2.3 2 | sphinx_rtd_theme==3.0.2 3 | breathe==4.36.0 4 | sphinx-tabs==3.4.7 5 | sphinx-fortran==1.1.1 6 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/media/cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/cpp/cart_pole/media/cartpole.gif 
-------------------------------------------------------------------------------- /examples/cpp/cart_pole/media/cartpole.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/cpp/cart_pole/media/cartpole.mp4 -------------------------------------------------------------------------------- /examples/fortran/graph/media/validation_results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/fortran/graph/media/validation_results.gif -------------------------------------------------------------------------------- /examples/fortran/simulation/media/validation_results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/fortran/simulation/media/validation_results.gif -------------------------------------------------------------------------------- /tests/supervised/configs/missing_opt.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | loss: 8 | type: MSE 9 | -------------------------------------------------------------------------------- /tests/supervised/configs/missing_loss.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | optimizer: 8 | type: adam 9 | -------------------------------------------------------------------------------- /tests/general/configs/l1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: L1 8 | 9 | optimizer: 10 | type: adam 11 | 
-------------------------------------------------------------------------------- /tests/general/configs/mse.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: MSE 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /docs/_templates/footer.html: -------------------------------------------------------------------------------- 1 | {% extends "!footer.html" %} 2 | {%- block contentinfo %} 3 | {{ super() }} 4 |
Documentation built from commit {{ version }}. 5 | {% endblock %} 6 | -------------------------------------------------------------------------------- /tests/supervised/configs/torchscript.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: "model.pt" 5 | 6 | loss: 7 | type: MSE 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /tests/general/configs/l1_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: L1 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /tests/general/configs/mse_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: MSE 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /tests/supervised/configs/mlp2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | loss: 8 | type: MSE 9 | 10 | optimizer: 11 | type: adam 12 | -------------------------------------------------------------------------------- /tests/general/configs/torchscript.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/supervised/configs/mlp3.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | flatten_non_batch_dims: false 7 | 8 | loss: 9 | type: MSE 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/general/configs/torchscript_multiout.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss_multiout.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/general/configs/torchscript_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss_multiarg.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/supervised/configs/mlp2_gradacc.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | loss: 8 | type: MSE 9 | 10 | optimizer: 11 | type: adam 12 | general: 13 | grad_accumulation_steps: 4 14 | -------------------------------------------------------------------------------- /tests/supervised/configs/torchscript_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: "model_multiarg.pt" 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: "loss_multiarg.pt" 10 | 11 | optimizer: 12 | type: adam 13 | 
-------------------------------------------------------------------------------- /tests/general/configs/torchscript_multiarg_extra.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss_multiarg_extra.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/supervised/configs/torchscript_multiarg_extra.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: "model_multiarg.pt" 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: "loss_multiarg_extra.pt" 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # basic packages 2 | ruamel-yaml 3 | 4 | # pytorch and some dependencies 5 | torch==2.8.0 6 | 7 | # training monitoring 8 | wandb 9 | 10 | # RL example visualization related 11 | pygame 12 | moviepy 13 | 14 | # Supervised learning example visualization related 15 | matplotlib 16 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. _api-label: 2 | 3 | ############# 4 | TorchFort API 5 | ############# 6 | 7 | The following sections describe the types and functions available in the TorchFort library for C/C++ and Fortran programs and 8 | also the configuration file structure and available options. 9 | 10 | ..
toctree:: 11 | 12 | api/config 13 | api/c_api 14 | api/f_api 15 | 16 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: LLVM 3 | ColumnLimit: 120 4 | CommentPragmas: '^\\.+' 5 | DerivePointerAlignment: false 6 | Language: Cpp 7 | PointerAlignment: Left 8 | UseTab: Never 9 | AlignAfterOpenBracket: Align 10 | AlignTrailingComments: true 11 | AllowShortBlocksOnASingleLine: false 12 | AllowShortCaseLabelsOnASingleLine : false 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | ... 16 | -------------------------------------------------------------------------------- /tests/supervised/configs/mlp.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | enable_wandb_hook: 0 3 | report_frequency: 100 4 | 5 | model: 6 | type: mlp 7 | parameters: 8 | dropout: 0.0 9 | layer_sizes: [32, 32, 1] 10 | 11 | loss: 12 | type: MSE 13 | 14 | optimizer: 15 | type: adam 16 | parameters: 17 | learning_rate: 1e-3 18 | beta1: 0.9 19 | beta2: 0.999 20 | weight_decay: 0 21 | eps: 1e-8 22 | amsgrad: 0 23 | 24 | lr_scheduler: 25 | type: cosine_annealing 26 | parameters: 27 | T_max: 100000 28 | -------------------------------------------------------------------------------- /examples/fortran/simulation/config_fcn_torchscript.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | enable_wandb_hook: 1 3 | report_frequency: 100 4 | 5 | model: 6 | type: torchscript 7 | parameters: 8 | filename: "fcn_torchscript.pt" 9 | 10 | loss: 11 | type: MSE 12 | 13 | optimizer: 14 | type: adam 15 | parameters: 16 | learning_rate: 1e-3 17 | beta1: 0.9 18 | beta2: 0.999 19 | weight_decay: 0 20 | eps: 1e-8 21 | amsgrad: 0 22 | 23 | lr_scheduler: 24 | type: cosine_annealing 25 | parameters: 26 | T_max: 100000 27 | 
-------------------------------------------------------------------------------- /examples/fortran/simulation/config_mlp_native.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | enable_wandb_hook: 1 3 | report_frequency: 100 4 | 5 | model: 6 | type: mlp 7 | parameters: 8 | dropout: 0.0 9 | layer_sizes: [1024, 1024] 10 | 11 | loss: 12 | type: MSE 13 | 14 | optimizer: 15 | type: adam 16 | parameters: 17 | learning_rate: 1e-3 18 | beta1: 0.9 19 | beta2: 0.999 20 | weight_decay: 0 21 | eps: 1e-8 22 | amsgrad: 0 23 | 24 | lr_scheduler: 25 | type: cosine_annealing 26 | parameters: 27 | T_max: 100000 28 | -------------------------------------------------------------------------------- /.github/scripts/run_ci_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | cd /opt/torchfort/bin/tests/general 5 | python scripts/setup_tests.py 6 | ./test_losses 7 | 8 | cd /opt/torchfort/bin/tests/supervised 9 | python scripts/setup_tests.py 10 | ./test_checkpoint 11 | ./test_training 12 | mpirun -np 2 --allow-run-as-root ./test_distributed_training 13 | 14 | cd /opt/torchfort/bin/tests/rl 15 | ./test_distributions 16 | ./test_replay_buffer 17 | ./test_rollout_buffer 18 | ./test_off_policy --gtest_filter=*L0* 19 | ./test_on_policy --gtest_filter=*L0* 20 | -------------------------------------------------------------------------------- /examples/fortran/graph/config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 100 3 | 4 | model: 5 | type: torchscript 6 | parameters: 7 | filename: "model_torchscript.pt" 8 | 9 | loss: 10 | type: torchscript 11 | parameters: 12 | filename: "loss_torchscript.pt" 13 | 14 | optimizer: 15 | type: adam 16 | parameters: 17 | learning_rate: 1e-3 18 | beta1: 0.9 19 | beta2: 0.999 20 | weight_decay: 0 21 | eps: 1e-8 22 | amsgrad: 0 23 | 24 | lr_scheduler: 
25 | type: cosine_annealing 26 | parameters: 27 | T_max: 100000 28 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | doxygen Doxyfile 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /src/csrc/include/internal/rl/rl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #include "internal/rl/off_policy.h" 21 | #include "internal/rl/on_policy.h" 22 | -------------------------------------------------------------------------------- /src/csrc/include/torchfort_config.h.in: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | /* Define to 1 if TorchFort was built with GPU support */ 21 | #cmakedefine01 TORCHFORT_ENABLE_GPU 22 | 23 | -------------------------------------------------------------------------------- /src/csrc/param_map.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <set> 19 | #include <string> 20 | 21 | #include "internal/param_map.h" 22 | 23 | namespace torchfort { 24 | 25 | std::set<std::string> ParamMap::keys() const { 26 | std::set<std::string> keys; 27 | for (const auto& entry : params) { 28 | keys.insert(entry.first); 29 | } 30 | return keys; 31 | } 32 | 33 | } // namespace torchfort 34 | -------------------------------------------------------------------------------- /src/csrc/include/internal/base_model.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #pragma once 19 | 20 | #include <torch/torch.h> 21 | 22 | #include "internal/param_map.h" 23 | 24 | namespace torchfort { 25 | 26 | struct BaseModel : torch::nn::Module { 27 | virtual std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) = 0; 28 | virtual void setup(const ParamMap& params) = 0; 29 | }; 30 | 31 | } // namespace torchfort 32 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/py_env.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #include <pybind11/pybind11.h> 19 | #include <pybind11/stl.h> 20 | 21 | #include "env.h" 22 | 23 | // pybind11 stuff 24 | namespace py = pybind11; 25 | PYBIND11_MODULE(PyEnvironments, m) { 26 | py::class_<CartPoleEnv>(m, "CartPoleEnv", py::dynamic_attr()) 27 | .def(py::init<>()) 28 | .def("step", &CartPoleEnv::step, py::arg("action")) 29 | .def("reset", &CartPoleEnv::reset); 30 | } 31 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 1 4 | verbose: 1 5 | 6 | algorithm: 7 | type: td3 8 | parameters: 9 | batch_size: 512 10 | num_critics: 2 11 | policy_lag: 2 12 | nstep: 1 13 | nstep_reward_reduction: sum_no_skip 14 | gamma: 0.99 15 | rho: 0.99 16 | 17 | actor: 18 | type: space_noise 19 | parameters: 20 | a_low: -1.0 21 | a_high: 1.0 22 | clip: 0.3 23 | sigma_train: 0.1 24 | sigma_explore: 0.2 25 | adaptive: 0 26 | 27 | replay_buffer: 28 | type: uniform 29 | parameters: 30 | max_size: 50000 31 | min_size: 1024 32 | 33 | policy_model: 34 | type: torchscript 35 | parameters: 36 | filename: policy.pt 37 | 38 | critic_model: 39 | type: torchscript 40 | parameters: 41 | filename: value.pt 42 | 43 | optimizer: 44 | type: adam 45 | parameters: 46 | learning_rate: 0.001 47 | beta1: 0.9 48 | beta2: 0.999 49 | weight_decay: 0 50 | eps: 1e-6 51 | amsgrad: 0 52 | 53 | policy_lr_scheduler: 54 | type: cosine_annealing 55 | parameters: 56 | T_max: 500000000 57 | 58 | critic_lr_scheduler: 59 | type: cosine_annealing 60 | parameters: 61 | T_max: 500000000 62 | -------------------------------------------------------------------------------- /tests/rl/configs/ppo.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: ppo 8 | parameters: 9 | batch_size: 8 10 | gamma: 0.99 11 |
gae_lambda: 0.95 12 | epsilon: 0.2 13 | clip_q: 0. 14 | target_kl_divergence: 0.02 15 | entropy_loss_coefficient: 0. 16 | value_loss_coefficient: 0.5 17 | normalize_advantage: True 18 | 19 | actor: 20 | type: gaussian_ac 21 | parameters: 22 | a_low: -1.0 23 | a_high: 1.0 24 | 25 | rollout_buffer: 26 | type: gae_lambda 27 | parameters: 28 | size: 64 29 | 30 | actor_critic_model: 31 | type: ActorCriticMLP 32 | parameters: 33 | dropout: 0.0 34 | encoder_layer_sizes: [1, 16, 8] 35 | actor_layer_sizes: [8, 1] 36 | value_layer_sizes: [8, 1] 37 | state_dependent_sigma: False 38 | log_sigma_init: 0. 39 | 40 | optimizer: 41 | type: adam 42 | parameters: 43 | learning_rate: 1e-4 44 | beta1: 0.9 45 | beta2: 0.999 46 | weight_decay: 0 47 | eps: 1e-6 48 | amsgrad: 0 49 | general: 50 | max_grad_norm: 0.5 51 | 52 | lr_scheduler: 53 | type: linear 54 | parameters: 55 | total_iters: 40000 56 | start_factor: 1.0 57 | end_factor: 0.01 58 | -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% block sidebartitle %} {{ super() }} 3 | 4 | 32 | {% endblock %} 33 | 34 | {% block footer %} {{ super() }} 35 | 36 | 51 | {% endblock %} 52 | -------------------------------------------------------------------------------- /tests/rl/configs/ddpg.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: ddpg 8 | parameters: 9 | batch_size: 128 10 | nstep: 1 11 | nstep_reward_reduction: sum 12 | gamma: 0.95 13 | rho: 0.999 14 | 15 | actor: 16 | type: space_noise 17 | parameters: 18 | a_low: -1. 19 | a_high: 1. 
20 | clip: 0.3 21 | sigma_train: 0.2 22 | sigma_explore: 1.0 23 | 24 | replay_buffer: 25 | type: uniform 26 | parameters: 27 | max_size: 4096 28 | min_size: 512 29 | 30 | policy_model: 31 | type: MLP 32 | parameters: 33 | dropout: 0.0 34 | layer_sizes: [1, 16, 1] 35 | 36 | critic_model: 37 | type: CriticMLP 38 | parameters: 39 | dropout: 0.0 40 | layer_sizes: [2, 16, 1] 41 | 42 | optimizer: 43 | type: adam 44 | parameters: 45 | learning_rate: 1e-4 46 | beta1: 0.9 47 | beta2: 0.999 48 | weight_decay: 0 49 | eps: 1e-6 50 | amsgrad: 0 51 | 52 | policy_lr_scheduler: 53 | type: linear 54 | parameters: 55 | total_iters: 20000 56 | start_factor: 1.0 57 | end_factor: 0.01 58 | 59 | critic_lr_scheduler: 60 | type: linear 61 | parameters: 62 | total_iters: 20000 63 | start_factor: 1.0 64 | end_factor: 0.01 -------------------------------------------------------------------------------- /tests/rl/configs/td3.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: td3 8 | parameters: 9 | batch_size: 128 10 | num_critics: 2 11 | policy_lag: 2 12 | nstep: 1 13 | nstep_reward_reduction: sum 14 | gamma: 0.95 15 | rho: 0.999 16 | 17 | actor: 18 | type: space_noise 19 | parameters: 20 | a_low: -1. 21 | a_high: 1. 
22 | clip: 0.3 23 | sigma_train: 0.2 24 | sigma_explore: 1.0 25 | 26 | replay_buffer: 27 | type: uniform 28 | parameters: 29 | max_size: 4096 30 | min_size: 512 31 | 32 | policy_model: 33 | type: MLP 34 | parameters: 35 | dropout: 0.0 36 | layer_sizes: [1, 16, 1] 37 | 38 | critic_model: 39 | type: CriticMLP 40 | parameters: 41 | dropout: 0.0 42 | layer_sizes: [2, 16, 1] 43 | 44 | optimizer: 45 | type: adam 46 | parameters: 47 | learning_rate: 1e-4 48 | beta1: 0.9 49 | beta2: 0.999 50 | weight_decay: 0 51 | eps: 1e-6 52 | amsgrad: 0 53 | 54 | policy_lr_scheduler: 55 | type: linear 56 | parameters: 57 | total_iters: 10000 58 | start_factor: 1.0 59 | end_factor: 0.01 60 | 61 | critic_lr_scheduler: 62 | type: linear 63 | parameters: 64 | total_iters: 20000 65 | start_factor: 1.0 66 | end_factor: 0.01 -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: "Documentation" 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build-documentation: 11 | runs-on: ubuntu-latest 12 | name: "Build Documentation" 13 | 14 | steps: 15 | - name: "Checkout code" 16 | uses: actions/checkout@v4 17 | 18 | - name: "Install dependencies" 19 | run: | 20 | sudo apt-get update && sudo apt-get install -y doxygen 21 | pip install -r docs/requirements.txt 22 | 23 | - name: "Build documentation" 24 | run: | 25 | cd docs 26 | export TORCHFORT_GIT_SHA=${{ github.event.pull_request.head.sha || github.sha }} 27 | make html 28 | 29 | - name: "Upload documentation" 30 | uses: actions/upload-pages-artifact@v3 31 | with: 32 | path: docs/_build/html 33 | 34 | 35 | deploy-documentation: 36 | runs-on: ubuntu-latest 37 | if: github.event_name == 'push' 38 | needs: build-documentation 39 | permissions: 40 | contents: read 41 | pages: write 42 | id-token: write 43 | environment: 44 | name: github-pages 45 | 
url: ${{steps.deployment.outputs.page_url}} 46 | steps: 47 | - name: "Deploy documentation" 48 | id: deployment 49 | uses: actions/deploy-pages@v4 50 | -------------------------------------------------------------------------------- /src/csrc/include/internal/base_loss.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | #include <vector> 21 | 22 | #include <torch/torch.h> 23 | 24 | #include "internal/base_loss.h" 25 | #include "internal/param_map.h" 26 | 27 | namespace torchfort { 28 | 29 | struct BaseLoss : torch::nn::Module { 30 | virtual torch::Tensor forward(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& labels, 31 | const std::vector<torch::Tensor>& extra_args) = 0; 32 | virtual void setup(const ParamMap& params) = 0; 33 | }; 34 | 35 | } // namespace torchfort 36 | -------------------------------------------------------------------------------- /examples/fortran/simulation/generate_fcn_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | class Net(torch.nn.Module): 19 | def __init__(self): 20 | super(Net, self).__init__() 21 | self.conv1 = torch.nn.Conv2d(1, 1, 3, padding=1, padding_mode="circular") 22 | 23 | def forward(self, x): 24 | return self.conv1(x) 25 | 26 | 27 | def main(): 28 | # Create model 29 | model = Net() 30 | print("FCN model:", model) 31 | 32 | # Move model to GPU if available, then JIT and save 33 | if torch.cuda.is_available(): 34 | model.to("cuda") 35 | else: 36 | print("PyTorch does not have CUDA support. Saving model on CPU.") 37 | model_jit = torch.jit.script(model) 38 | model_jit.save("fcn_torchscript.pt") 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/step_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include 19 | #include 20 | 21 | #include "internal/base_lr_scheduler.h" 22 | #include "internal/lr_schedulers.h" 23 | 24 | namespace torchfort { 25 | 26 | StepLR::StepLR(torch::optim::Optimizer& optimizer, const unsigned step_size, const double gamma) 27 | : BaseLRScheduler(optimizer), step_size_(step_size), gamma_(gamma) {} 28 | 29 | std::vector StepLR::get_lrs() { 30 | if (step_count_ == 0 || step_count_ % step_size_ != 0) 31 | return get_current_lrs(); 32 | else { 33 | std::vector lrs = get_current_lrs(); 34 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [this](const double& v) { return this->gamma_ * v; }); 35 | return lrs; 36 | } 37 | } 38 | 39 | } // namespace torchfort 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchFort 2 | 3 | An Online Deep Learning Interface for HPC programs on NVIDIA GPUs 4 | 5 | ## Introduction 6 | TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the [PyTorch](https://pytorch.org) framework. 7 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available 8 | within PyTorch. 
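The `StepLR::get_lrs` implementation above leaves the current learning rates untouched unless the step counter is a nonzero multiple of `step_size`, in which case every rate is multiplied by `gamma`. A dependency-free Python sketch of the same schedule (the parameter values below are hypothetical, chosen only for illustration):

```python
def step_lr(base_lr, gamma, step_size, num_steps):
    # Mirror StepLR::get_lrs: decay only when the step counter is a
    # nonzero multiple of step_size; otherwise keep the current rate.
    lr, history = base_lr, []
    for step in range(num_steps):
        if step > 0 and step % step_size == 0:
            lr *= gamma
        history.append(lr)
    return history

lrs = step_lr(base_lr=0.1, gamma=0.5, step_size=3, num_steps=7)
# halves after steps 3 and 6: [0.1, 0.1, 0.1, 0.05, 0.05, 0.05, 0.025]
```

Note the multiplicative form matches the C++ code, which rescales whatever rates the optimizer currently holds rather than recomputing them from the base rate.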
9 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the 10 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a 11 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection, 12 | learning rate schedules, and much more. 13 | 14 | Please refer to the [documentation](https://nvidia.github.io/TorchFort/) for additional information on the library, build instructions, and usage details. 15 | 16 | Please refer to the [examples](examples) to see TorchFort in action. 17 | 18 | Contact us or open a GitHub issue if you are interested in using this library in your own solvers and have questions on usage and/or feature requests. 19 | 20 | ## License 21 | This library is released under an Apache 2.0 license, which can be found in [LICENSE](LICENSE). 22 | -------------------------------------------------------------------------------- /src/csrc/include/internal/model_state.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | #include 20 | #include 21 | 22 | #include 23 | 24 | namespace torchfort { 25 | 26 | // Simple struct to store miscellaneous model state (e.g. iteration count) 27 | struct ModelState { 28 | int64_t step_train; 29 | int64_t step_inference; 30 | int64_t step_train_current; // training step of current run (ignoring restarted state) 31 | torch::Device device = torch::Device(torch::kCPU); 32 | 33 | // General option settings 34 | int32_t report_frequency; 35 | bool enable_wandb_hook; 36 | bool verbose; 37 | std::filesystem::path report_file; 38 | 39 | void save(const std::string& fname); 40 | void load(const std::string& fname); 41 | }; 42 | 43 | } // namespace torchfort 44 | -------------------------------------------------------------------------------- /examples/fortran/graph/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(fortran_example_targets 2 | train_graph 3 | ) 4 | 5 | add_executable(train_graph) 6 | target_sources(train_graph 7 | PRIVATE 8 | train.f90 9 | ) 10 | set_target_properties(train_graph 11 | PROPERTIES OUTPUT_NAME train) 12 | 13 | foreach(tgt ${fortran_example_targets}) 14 | target_include_directories(${tgt} 15 | PRIVATE 16 | ${CMAKE_BINARY_DIR}/include 17 | ${MPI_Fortran_INCLUDE_DIRS} 18 | ) 19 | target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran) 20 | target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort") 21 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 22 | if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC") 23 | target_compile_options(${tgt} PRIVATE $<$:-cpp -acc -gpu=${CUF_GPU_ARG}>) 24 | target_link_options(${tgt} PRIVATE $<$: -acc -gpu=${CUF_GPU_ARG}>) 25 | elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU") 26 | target_compile_options(${tgt} PRIVATE $<$:-cpp -fbackslash>) 27 | endif() 28 | endforeach() 29 | 30 | install( 31 | 
TARGETS ${fortran_example_targets} 32 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/graph 33 | ) 34 | 35 | install( 36 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml 37 | ${CMAKE_CURRENT_SOURCE_DIR}/generate_model.py 38 | ${CMAKE_CURRENT_SOURCE_DIR}/generate_loss.py 39 | ${CMAKE_CURRENT_SOURCE_DIR}/nodes.txt 40 | ${CMAKE_CURRENT_SOURCE_DIR}/connectivity.txt 41 | ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py 42 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/graph) 43 | -------------------------------------------------------------------------------- /src/csrc/include/internal/tensor_list.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #include <vector> 21 | 22 | #include <torch/torch.h> 23 | 24 | #include "internal/utils.h" 25 | 26 | namespace torchfort { 27 | struct TensorList { 28 | template <typename T> void add_tensor(T* data, size_t dim, int64_t* shape) { 29 | auto tensor = get_tensor(data, dim, shape); 30 | tensors.push_back(tensor); 31 | tensors_original_.push_back(tensor); 32 | }; 33 | 34 | void to(torch::Device device, bool non_blocking = false) { 35 | for (auto& t : tensors) { 36 | t = t.to(device, non_blocking); 37 | } 38 | }; 39 | 40 | void reset() { tensors = tensors_original_; } 41 | 42 | std::vector<torch::Tensor> tensors; 43 | // To preserve references to external data, we store the original tensor objects 44 | std::vector<torch::Tensor> tensors_original_; 45 | }; 46 | } // namespace torchfort 47 | -------------------------------------------------------------------------------- /tests/rl/configs/sac.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: sac 8 | parameters: 9 | batch_size: 128 10 | num_critics: 2 11 | nstep: 1 12 | nstep_reward_reduction: sum 13 | gamma: 0.95 14 | rho: 0.999 15 | alpha: 0.1 16 | 17 | actor: 18 | type: parameter_noise 19 | parameters: 20 | a_low: -1.0 21 | a_high: 1.0 22 | 23 | replay_buffer: 24 | type: uniform 25 | parameters: 26 | max_size: 4096 27 | min_size: 512 28 | 29 | policy_model: 30 | type: SACMLP 31 | parameters: 32 | dropout: 0.0 33 | layer_sizes: [1, 16, 8, 1] 34 | state_dependent_sigma: False 35 | log_sigma_init: 0.
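The `TensorList` helper above keeps two vectors: `tensors`, whose entries may be replaced by device copies in `to()`, and `tensors_original_`, which retains the tensors wrapping the caller's external arrays so `reset()` can restore them. A pure-Python sketch of that ownership pattern (no LibTorch involved; `to_device` just copies lists to stand in for a device transfer, and all names are illustrative):

```python
class TensorListSketch:
    """Working copies in `tensors`, original references in `_originals`."""

    def __init__(self):
        self.tensors = []
        self._originals = []

    def add_tensor(self, t):
        # Keep both a working reference and the original external object.
        self.tensors.append(t)
        self._originals.append(t)

    def to_device(self):
        # Stand-in for TensorList::to: replace working entries with copies.
        self.tensors = [list(t) for t in self.tensors]

    def reset(self):
        # Restore references to the external data, as TensorList::reset does.
        self.tensors = list(self._originals)

tl = TensorListSketch()
external = [1.0, 2.0, 3.0]
tl.add_tensor(external)
tl.to_device()
moved_is_copy = tl.tensors[0] is not external  # working entry was replaced
tl.reset()
restored = tl.tensors[0] is external           # original reference recovered
```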
36 | 37 | critic_model: 38 | type: CriticMLP 39 | parameters: 40 | dropout: 0.0 41 | layer_sizes: [2, 16, 1] 42 | 43 | optimizer: 44 | type: adam 45 | parameters: 46 | learning_rate: 1e-4 47 | beta1: 0.9 48 | beta2: 0.999 49 | weight_decay: 0 50 | eps: 1e-6 51 | amsgrad: 0 52 | 53 | alpha_optimizer: 54 | type: adam 55 | parameters: 56 | learning_rate: 1e-4 57 | beta1: 0.9 58 | beta2: 0.999 59 | weight_decay: 0 60 | eps: 1e-6 61 | amsgrad: 0 62 | 63 | policy_lr_scheduler: 64 | type: linear 65 | parameters: 66 | total_iters: 20000 67 | start_factor: 1.0 68 | end_factor: 0.01 69 | 70 | critic_lr_scheduler: 71 | type: linear 72 | parameters: 73 | total_iters: 20000 74 | start_factor: 1.0 75 | end_factor: 0.01 76 | 77 | alpha_lr_scheduler: 78 | type: linear 79 | parameters: 80 | total_iters: 20000 81 | start_factor: 1.0 82 | end_factor: 0.01 83 | -------------------------------------------------------------------------------- /examples/fortran/graph/generate_loss.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | 18 | class CustomLoss(torch.nn.Module): 19 | def __init__(self): 20 | super(CustomLoss, self).__init__() 21 | 22 | def forward(self, prediction, label, node_types): 23 | 24 | # Compute MSE over all nodes 25 | err = (label - prediction)**2 26 | 27 | # Zero out error for boundary nodes 28 | mask = node_types != 0 29 | err *= mask.unsqueeze(-1) 30 | 31 | # Compute mean over non-boundary nodes 32 | mse = torch.sum(err) / (torch.sum(mask) * err.shape[1]) 33 | 34 | return mse 35 | 36 | def main(): 37 | # Create loss module 38 | loss = CustomLoss() 39 | print("loss module:", loss) 40 | 41 | try: 42 | # Move loss module to GPU if available 43 | loss.to("cuda") 44 | except Exception: 45 | print("CUDA not available. Saving loss module on CPU.") 46 | loss_jit = torch.jit.script(loss) 47 | loss_jit.save("loss_torchscript.pt") 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. 4 | 5 | If you need to report a security issue, please use the appropriate contact points outlined below.
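The custom loss above averages the squared error over interior nodes only: errors at boundary nodes (`node_types == 0`) are masked out, and the sum is divided by the number of unmasked nodes times the feature count. The same arithmetic in plain Python, using made-up values for a 3-node, 2-feature graph:

```python
prediction = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 nodes, 2 features each
label      = [[1.5, 2.0], [3.0, 3.0], [9.0, 9.0]]
node_types = [1, 1, 0]                             # node 2 is a boundary node

total, interior = 0.0, 0
for node, ntype in enumerate(node_types):
    if ntype == 0:
        continue  # boundary nodes are masked out of the loss
    interior += 1
    for p, l in zip(prediction[node], label[node]):
        total += (l - p) ** 2

# Divide by (number of interior nodes) * (features per node), as in the module
mse = total / (interior * len(prediction[0]))
# (1.5-1)^2 + 0 + 0 + (3-4)^2 = 1.25 over 4 unmasked entries -> 0.3125
```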
**Please do not report security vulnerabilities through GitHub.** 6 | 7 | ## Reporting Potential Security Vulnerability in an NVIDIA Product 8 | 9 | To report a potential security vulnerability in any NVIDIA product: 10 | - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) 11 | - E-Mail: psirt@nvidia.com 12 | - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) 13 | - Please include the following information: 14 | - Product/Driver name and version/branch that contains the vulnerability 15 | - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) 16 | - Instructions to reproduce the vulnerability 17 | - Proof-of-concept or exploit code 18 | - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability 19 | 20 | While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. 21 | 22 | ## NVIDIA Product Security 23 | 24 | For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security 25 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/polynomial_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "internal/base_lr_scheduler.h" 23 | #include "internal/lr_schedulers.h" 24 | 25 | namespace torchfort { 26 | 27 | PolynomialLR::PolynomialLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double power) 28 | : BaseLRScheduler(optimizer), total_iters_(total_iters), power_(power) {} 29 | 30 | std::vector PolynomialLR::get_lrs() { 31 | std::vector lrs = get_current_lrs(); 32 | if (step_count_ == 0 || step_count_ > total_iters_) 33 | return lrs; 34 | else { 35 | double decay_factor = 36 | (1. - double(step_count_) / double(total_iters_)) / (1. - double(step_count_ - 1) / double(total_iters_)); 37 | decay_factor = std::pow(decay_factor, power_); 38 | 39 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [decay_factor](const double& v) { return decay_factor * v; }); 40 | 41 | return lrs; 42 | } 43 | } 44 | 45 | } // namespace torchfort 46 | -------------------------------------------------------------------------------- /src/csrc/include/internal/model_pack.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
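`PolynomialLR::get_lrs` above applies a per-step ratio `((1 - t/T)/(1 - (t-1)/T))^power` to the current rates; the ratios telescope, so after `t` steps the learning rate equals `base_lr * (1 - t/T)^power`. A quick numeric check in Python (the parameter values are hypothetical):

```python
def polynomial_decay_factor(t, total_iters, power):
    # Per-step multiplicative factor from PolynomialLR::get_lrs
    return ((1.0 - t / total_iters) / (1.0 - (t - 1) / total_iters)) ** power

T, power, base_lr = 10, 2.0, 1.0
lr = base_lr
for t in range(1, 6):                           # apply the factors for steps 1..5
    lr *= polynomial_decay_factor(t, T, power)

closed_form = base_lr * (1.0 - 5 / T) ** power  # (1 - t/T)^power at t = 5
# lr and closed_form agree: the per-step ratios telescope to the closed form
```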
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | #include 20 | 21 | #include 22 | 23 | #include "internal/base_loss.h" 24 | #include "internal/base_lr_scheduler.h" 25 | #include "internal/distributed.h" 26 | #include "internal/model_state.h" 27 | #include "internal/model_wrapper.h" 28 | 29 | namespace torchfort { 30 | 31 | // Simple struct to group model, optimizer, lr scheduler, state, and comm objects 32 | struct ModelPack { 33 | std::shared_ptr model; 34 | std::shared_ptr optimizer; 35 | std::shared_ptr lr_scheduler; 36 | std::shared_ptr loss; 37 | std::shared_ptr comm; 38 | std::shared_ptr state; 39 | int grad_accumulation_steps = 1; 40 | float max_grad_norm = 0.0; 41 | }; 42 | 43 | void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true); 44 | void load_model_pack(ModelPack& model_pack, const std::string& fname, bool load_optimizer = true); 45 | 46 | } // namespace torchfort 47 | -------------------------------------------------------------------------------- /src/csrc/include/internal/distributed.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | #ifdef ENABLE_GPU 21 | #include 22 | #include 23 | #endif 24 | #include 25 | 26 | #include 27 | 28 | namespace torchfort { 29 | 30 | struct Comm { 31 | void initialize(); 32 | void finalize(); 33 | void allreduce(torch::Tensor& tensor, bool average = false) const; 34 | void allreduce(std::vector& tensors, bool average = false) const; 35 | void allreduce(double& val, bool average = false) const; 36 | void allreduce(float& val, bool average = false) const; 37 | void broadcast(torch::Tensor& tensor, int root = 0) const; 38 | 39 | int rank; 40 | int size; 41 | torch::Device device; 42 | MPI_Comm mpi_comm; 43 | #ifdef ENABLE_GPU 44 | ncclComm_t nccl_comm = nullptr; 45 | cudaStream_t stream = nullptr; 46 | cudaEvent_t event = nullptr; 47 | #endif 48 | bool initialized = false; 49 | 50 | Comm(MPI_Comm mpi_comm, torch::Device device) : mpi_comm(mpi_comm), device(device){}; 51 | }; 52 | 53 | } // namespace torchfort 54 | -------------------------------------------------------------------------------- /.github/workflows/format.yaml: -------------------------------------------------------------------------------- 1 | name: "Code Format" 2 | 3 | on: 4 | pull_request: 5 | branches: [ master ] 6 | 7 | jobs: 8 | clang-format: 9 | env: 10 | CLANG_FORMAT_VERSION: "15" 11 | 12 | runs-on: ubuntu-latest 13 | name: 
"clang-format check" 14 | 15 | steps: 16 | - name: "Checkout code" 17 | uses: actions/checkout@v4 18 | 19 | - name: "Install dependencies" 20 | run: | 21 | # Install clang-format 22 | sudo apt-get update && sudo apt-get install -y clang-format-${CLANG_FORMAT_VERSION} 23 | 24 | # Print clang-format version information 25 | clang-format-${CLANG_FORMAT_VERSION} --version 26 | 27 | - name: "Run clang-format check" 28 | run: | 29 | # Collect names of files that are not properly formatted 30 | filelist=`find src examples tests -name "*.cpp" -o -name "*.h"` 31 | files_to_fix=() 32 | for file in $filelist; do 33 | if ! clang-format-${CLANG_FORMAT_VERSION} --dry-run --Werror "$file" 2>/dev/null; then 34 | files_to_fix+=("$file") 35 | fi 36 | done 37 | 38 | # If any file is not properly formatted, print diff and exit with error 39 | if [ ${#files_to_fix[@]} -gt 0 ]; then 40 | # Print the list of files that are not properly formatted 41 | echo "FAIL: Some files are not properly formatted. To resolve issues, run:" 42 | for file in "${files_to_fix[@]}"; do 43 | echo "clang-format-${CLANG_FORMAT_VERSION} -i $file" 44 | done 45 | echo 46 | 47 | for file in "${files_to_fix[@]}"; do 48 | echo "Diff for $file:" 49 | bash -c "clang-format-${CLANG_FORMAT_VERSION} $file | diff $file -; exit 0" 50 | echo 51 | done 52 | 53 | exit 1 54 | fi 55 | 56 | echo "PASS: All files are properly formatted." 
57 | -------------------------------------------------------------------------------- /docker/Dockerfile_gnu_cpuonly: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | SHELL ["/bin/bash", "-c"] 4 | 5 | # Install System Dependencies 6 | ENV DEBIAN_FRONTEND noninteractive 7 | RUN apt update -y && \ 8 | apt install -y build-essential && \ 9 | apt install -y wget cmake && \ 10 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \ 11 | apt install -y git vim gfortran && \ 12 | apt install -y libibverbs-dev ibverbs-utils numactl && \ 13 | apt install -y openmpi-bin libopenmpi-dev 14 | 15 | # Install PyTorch 16 | RUN pip3 install torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu 17 | 18 | # Install yaml-cpp 19 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \ 20 | cd yaml-cpp && \ 21 | mkdir build && cd build && \ 22 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \ 23 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \ 24 | -DBUILD_SHARED_LIBS=OFF \ 25 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \ 26 | make -j$(nproc) && make install 27 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH} 28 | 29 | # Install additional Python dependencies 30 | RUN pip3 install wandb ruamel-yaml matplotlib pygame moviepy 31 | 32 | # Install TorchFort without GPU support 33 | ENV FC=gfortran 34 | COPY . /torchfort 35 | RUN cd /torchfort && mkdir build && cd build && \ 36 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \ 37 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \ 38 | -DTORCHFORT_ENABLE_GPU=0 \ 39 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 40 | -DTORCHFORT_BUILD_TESTS=1 \ 41 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 42 | .. 
&& \ 43 | make -j$(nproc) install && \ 44 | cd / && rm -rf torchfort 45 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH} 46 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH} 47 | -------------------------------------------------------------------------------- /examples/fortran/simulation/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | set(fortran_example_targets 3 | train 4 | train_distributed 5 | ) 6 | 7 | add_executable(train) 8 | target_sources(train 9 | PRIVATE 10 | train.f90 11 | simulation.f90 12 | ) 13 | set_target_properties(train PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mod/0 ) 14 | 15 | add_executable(train_distributed) 16 | target_sources(train_distributed 17 | PRIVATE 18 | train_distributed.f90 19 | simulation.f90 20 | ) 21 | set_target_properties(train_distributed PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mod/1 ) 22 | 23 | foreach(tgt ${fortran_example_targets}) 24 | target_include_directories(${tgt} 25 | PRIVATE 26 | ${CMAKE_BINARY_DIR}/include 27 | ${MPI_Fortran_INCLUDE_DIRS} 28 | ) 29 | target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran) 30 | target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort") 31 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 32 | if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC") 33 | target_compile_options(${tgt} PRIVATE $<$:-cpp -acc -gpu=${CUF_GPU_ARG}>) 34 | target_link_options(${tgt} PRIVATE $<$: -acc -gpu=${CUF_GPU_ARG}>) 35 | elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU") 36 | target_compile_options(${tgt} PRIVATE $<$:-cpp -fbackslash>) 37 | endif() 38 | endforeach() 39 | 40 | install( 41 | TARGETS ${fortran_example_targets} 42 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/simulation 43 | ) 44 | 45 | 46 | install( 47 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/config_mlp_native.yaml 48 | 
${CMAKE_CURRENT_SOURCE_DIR}/config_fcn_torchscript.yaml 49 | ${CMAKE_CURRENT_SOURCE_DIR}/generate_fcn_model.py 50 | ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py 51 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/simulation) 52 | -------------------------------------------------------------------------------- /src/csrc/include/internal/model_wrapper.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #include 21 | 22 | #include "internal/base_model.h" 23 | 24 | namespace torchfort { 25 | 26 | class ModelWrapper { 27 | public: 28 | ModelWrapper(const std::shared_ptr& model); 29 | 30 | ModelWrapper(const std::shared_ptr& model_jit); 31 | 32 | ModelWrapper(const std::string& jit_model_fname); 33 | 34 | std::vector parameters() const; 35 | 36 | torch::OrderedDict named_parameters() const; 37 | 38 | void to(torch::Device device, bool non_blocking = false); 39 | 40 | void train(); 41 | 42 | void eval(); 43 | 44 | std::vector forward(const std::vector& inputs) const; 45 | 46 | void save(const std::string& fname) const; 47 | 48 | void load(const std::string& fname); 49 | 50 | torch::Device device() const; 51 | 52 | private: 53 | bool jit = false; 54 | std::shared_ptr model; 55 | std::shared_ptr model_jit; 56 | torch::Device device_ = torch::Device(torch::kCPU); 57 | }; 58 | 59 | } // namespace torchfort 60 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/linear_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #include 19 | #include 20 | 21 | #include "internal/base_lr_scheduler.h" 22 | #include "internal/lr_schedulers.h" 23 | 24 | namespace torchfort { 25 | 26 | LinearLR::LinearLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double start_factor, 27 | const double end_factor) 28 | : BaseLRScheduler(optimizer), total_iters_(total_iters), start_factor_(start_factor), end_factor_(end_factor) {} 29 | 30 | std::vector LinearLR::get_lrs() { 31 | 32 | double factor; 33 | if (step_count_ == 0) { 34 | factor = start_factor_; 35 | } else if (step_count_ > total_iters_) { 36 | factor = 1.; 37 | } else { 38 | factor = (1. + (end_factor_ - start_factor_) / 39 | double(total_iters_ * start_factor_ + (step_count_ - 1) * (end_factor_ - start_factor_))); 40 | } 41 | 42 | // get current lrs and modify 43 | std::vector lrs = get_current_lrs(); 44 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [factor](const double& v) { return factor * v; }); 45 | 46 | return lrs; 47 | } 48 | 49 | } // namespace torchfort 50 | -------------------------------------------------------------------------------- /src/csrc/include/internal/nvtx.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
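The multiplicative factor in `LinearLR::get_lrs` above looks opaque, but it is exactly the ratio of two consecutive points on the line from `start_factor` to `end_factor`, so the cumulative product reproduces linear interpolation: after step `t <= total_iters` the rate is `base_lr * (start_factor + (end_factor - start_factor) * t / total_iters)`. A Python sketch verifying this (values are hypothetical):

```python
def linear_lr_schedule(base_lr, start_factor, end_factor, total_iters, num_steps):
    # Multiplicative updates exactly as in LinearLR::get_lrs
    lr, lrs = base_lr, []
    for t in range(num_steps):
        if t == 0:
            lr *= start_factor
        elif t <= total_iters:
            lr *= 1.0 + (end_factor - start_factor) / (
                total_iters * start_factor + (t - 1) * (end_factor - start_factor)
            )
        # for t > total_iters the rate stays fixed at base_lr * end_factor
        lrs.append(lr)
    return lrs

lrs = linear_lr_schedule(1.0, start_factor=0.5, end_factor=1.0,
                         total_iters=4, num_steps=6)
# linear ramp: [0.5, 0.625, 0.75, 0.875, 1.0, 1.0]
```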
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | #include 21 | 22 | #ifdef ENABLE_GPU 23 | #include 24 | #endif 25 | 26 | namespace torchfort { 27 | 28 | // Helper class for NVTX ranges 29 | class nvtx { 30 | public: 31 | #ifdef ENABLE_GPU 32 | static void rangePush(const std::string& range_name) { 33 | static constexpr int ncolors_ = 8; 34 | static constexpr int colors_[ncolors_] = {0x3366CC, 0xDC3912, 0xFF9900, 0x109618, 35 | 0x990099, 0x3B3EAC, 0x0099C6, 0xDD4477}; 36 | std::hash hash_fn; 37 | int color = colors_[hash_fn(range_name) % ncolors_]; 38 | nvtxEventAttributes_t ev = {0}; 39 | ev.version = NVTX_VERSION; 40 | ev.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; 41 | ev.colorType = NVTX_COLOR_ARGB; 42 | ev.color = color; 43 | ev.messageType = NVTX_MESSAGE_TYPE_ASCII; 44 | ev.message.ascii = range_name.c_str(); 45 | nvtxRangePushEx(&ev); 46 | } 47 | 48 | static void rangePop() { nvtxRangePop(); } 49 | #else 50 | static void rangePush(const std::string& range_name) {} 51 | static void rangePop() {} 52 | #endif 53 | }; 54 | 55 | } // namespace torchfort 56 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TorchFort documentation master file, created by 2 | sphinx-quickstart on Wed Jun 1 13:44:41 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | ############################################################################ 7 | TorchFort: An Online Deep Learning Interface for HPC programs on NVIDIA GPUs 8 | ############################################################################ 9 | These pages contain the documentation for TorchFort, an online deep learning interface for HPC programs. 
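The `nvtx::rangePush` helper above picks a range color deterministically by hashing the range name into a fixed palette of eight ARGB colors, so the same range name always gets the same color in a profile. The same idea in Python (note one difference: Python's built-in `hash` for strings is salted per process, unlike `std::hash`, so colors here are only stable within a single run):

```python
COLORS = [0x3366CC, 0xDC3912, 0xFF9900, 0x109618,
          0x990099, 0x3B3EAC, 0x0099C6, 0xDD4477]

def range_color(range_name):
    # Hash the range name into one of the eight palette entries
    return COLORS[hash(range_name) % len(COLORS)]

# Identical names map to identical colors within a single process
stable = range_color("forward") == range_color("forward")
```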
10 | 11 | TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the `PyTorch `_ framework. 12 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available 13 | within PyTorch. 14 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the 15 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a 16 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection, 17 | learning rate schedules, and much more. 18 | 19 | Please contact us or open a GitHub issue if you are interested in using this library 20 | in your own solvers and have questions on usage and/or feature requests. 21 | 22 | 23 | Table of Contents 24 | ================= 25 | .. toctree:: 26 | :maxdepth: 4 27 | 28 | installation 29 | usage 30 | api 31 | extras 32 | 33 | 34 | Indices and tables 35 | ================== 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | * :ref:`search` 40 | -------------------------------------------------------------------------------- /src/csrc/include/torchfort_enums.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | #define TORCHFORT_DEVICE_CPU (-1) 21 | 22 | /** 23 | * @brief This enum defines the data types supported. 24 | */ 25 | enum torchfort_datatype_t { TORCHFORT_FLOAT = -1, TORCHFORT_DOUBLE = -2, TORCHFORT_INT32 = -3, TORCHFORT_INT64 = -4 }; 26 | 27 | /** 28 | * @brief This enum defines the possible return values from TorchFort. Most functions in the TorchFort library 29 | * will return one of these values to indicate if an operation has completed successfully or an error occurred. 30 | */ 31 | enum torchfort_result_t { 32 | TORCHFORT_RESULT_SUCCESS = 0, ///< The operation completed successfully 33 | TORCHFORT_RESULT_INVALID_USAGE = 1, ///< A user error, typically an invalid argument 34 | TORCHFORT_RESULT_NOT_SUPPORTED = 2, ///< A user error, requesting an invalid or unsupported operation configuration 35 | TORCHFORT_RESULT_INTERNAL_ERROR = 3, ///< An internal library error, should be reported 36 | TORCHFORT_RESULT_CUDA_ERROR = 4, ///< An error occurred in the CUDA Runtime 37 | TORCHFORT_RESULT_MPI_ERROR = 5, ///< An error occurred in the MPI library 38 | TORCHFORT_RESULT_NCCL_ERROR = 6 ///< An error occurred in the NCCL library 39 | }; 40 | -------------------------------------------------------------------------------- /tests/rl/test_distributions.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include "internal/rl/distributions.h" 19 | #include <gtest/gtest.h> 20 | #include <torch/torch.h> 21 | 22 | using namespace torchfort; 23 | using namespace torch::indexing; 24 | 25 | TEST(NormalDistribution, RandomSampling) { 26 | // rng 27 | torch::manual_seed(666); 28 | 29 | // no grad guard 30 | torch::NoGradGuard no_grad; 31 | 32 | // create normal distribution with given shape 33 | torch::Tensor mutens = torch::empty({4, 8}, torch::kFloat32); 34 | torch::Tensor log_sigmatens = torch::empty({4, 8}, torch::kFloat32); 35 | 36 | // fill with random elements 37 | mutens.normal_(); 38 | log_sigmatens.normal_(); 39 | torch::Tensor sigmatens = torch::exp(log_sigmatens); 40 | 41 | auto ndist = rl::NormalDistribution(mutens, sigmatens); 42 | torch::Tensor sample = ndist.rsample(); 43 | 44 | // do direct sampling without reparametrization trick 45 | torch::Tensor sample_compare = at::normal(mutens, sigmatens); 46 | 47 | // expect that the shapes match; the two samples are independent random draws, so their values cannot be compared directly 48 | EXPECT_NO_THROW(torch::sum(sample - sample_compare).item()); 49 | } 50 | 51 | int main(int argc, char* argv[]) { 52 | ::testing::InitGoogleTest(&argc, argv); 53 | 54 | return RUN_ALL_TESTS(); 55 | } 56 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/multistep_lr.cpp:
-------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <algorithm> 19 | #include <vector> 20 | 21 | #include "internal/base_lr_scheduler.h" 22 | #include "internal/lr_schedulers.h" 23 | 24 | namespace torchfort { 25 | 26 | MultiStepLR::MultiStepLR(torch::optim::Optimizer& optimizer, const std::vector<int>& milestones, const double gamma) 27 | : BaseLRScheduler(optimizer), milestones_(milestones), gamma_(gamma) {} 28 | 29 | std::vector<double> MultiStepLR::get_lrs() { 30 | std::vector<double> lrs = get_current_lrs(); 31 | if (step_count_ == 0 || milestones_.size() == 0) 32 | return lrs; 33 | else { 34 | auto lower_old = std::lower_bound(milestones_.begin(), milestones_.end(), step_count_ - 1, 35 | [](const int& ms, int value) { return ms <= value; }); 36 | auto lower = std::lower_bound(milestones_.begin(), milestones_.end(), step_count_, 37 | [](const int& ms, int value) { return ms <= value; }); 38 | 39 | if (lower_old != lower) { 40 | // in this case we need to decay the LR: 41 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [this](const double& lr) { return this->gamma_ * lr; }); 42 | } 43 | 44 | return lrs; 45 | } 46 | } 47 | 48 | } // namespace torchfort 49 |
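For readers following the scheduler logic, the milestone comparison in `MultiStepLR::get_lrs` above can be sketched in Python. This is an illustrative re-implementation, not part of the library: `multistep_lrs` and its signature are names chosen here, and unlike the real class it only tracks a single learning rate rather than the optimizer's parameter groups.

```python
from bisect import bisect_right

def multistep_lrs(base_lr, milestones, gamma, num_steps):
    # Mirrors MultiStepLR::get_lrs: bisect_right plays the role of
    # std::lower_bound with the "ms <= value" comparator, counting how
    # many milestones are <= the step counter. The LR is multiplied by
    # gamma whenever that count changes, i.e. a milestone is crossed.
    lrs, lr = [], base_lr
    for step in range(num_steps):
        if step > 0 and bisect_right(milestones, step) != bisect_right(milestones, step - 1):
            lr *= gamma
        lrs.append(lr)
    return lrs
```

With `base_lr=1.0`, `milestones=[3, 6]`, and `gamma=0.1`, the learning rate drops by a factor of 10 at steps 3 and 6, matching the decay points of the C++ scheduler.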
-------------------------------------------------------------------------------- /src/csrc/model_state.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <filesystem> 19 | 20 | #include <torch/torch.h> 21 | 22 | #include "internal/exceptions.h" 23 | #include "internal/model_state.h" 24 | 25 | namespace torchfort { 26 | 27 | void ModelState::save(const std::string& fname) { 28 | torch::serialize::OutputArchive archive; 29 | archive.write("step_train", torch::IValue(step_train)); 30 | archive.write("step_inference", torch::IValue(step_inference)); 31 | archive.write("device", torch::IValue(device)); 32 | archive.save_to(fname); 33 | } 34 | 35 | void ModelState::load(const std::string& fname) { 36 | if (!std::filesystem::exists(fname)) { 37 | THROW_INVALID_USAGE(fname + " does not exist."); 38 | } 39 | 40 | torch::serialize::InputArchive archive; 41 | archive.load_from(fname); 42 | 43 | torch::IValue ivalue; 44 | if (!archive.try_read("step_train", ivalue)) { 45 | THROW_INVALID_USAGE(fname + " is missing required data."); 46 | } 47 | step_train = ivalue.to<int64_t>(); 48 | 49 | if (!archive.try_read("step_inference", ivalue)) { 50 | THROW_INVALID_USAGE(fname + " is missing required data."); 51 | } 52 |
step_inference = ivalue.to<int64_t>(); 53 | 54 | if (!archive.try_read("device", ivalue)) { 55 | THROW_INVALID_USAGE(fname + " is missing required data."); 56 | } 57 | device = ivalue.to<int64_t>(); 58 | } 59 | 60 | } // namespace torchfort 61 | -------------------------------------------------------------------------------- /src/csrc/losses/l1_loss.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #include <memory> 19 | #include <set> 20 | #include <string> 21 | #include <vector> 22 | 23 | #include <torch/torch.h> 24 | #include <yaml-cpp/yaml.h> 25 | 26 | #include "internal/losses.h" 27 | #include "internal/param_map.h" 28 | #include "internal/setup.h" 29 | #include "internal/utils.h" 30 | 31 | namespace torchfort { 32 | 33 | void L1Loss::setup(const ParamMap& params) { 34 | std::set<std::string> supported_params{"reduction"}; 35 | check_params(supported_params, params.keys()); 36 | 37 | auto options = torch::nn::L1LossOptions(); 38 | try { 39 | std::string reduction = params.get_param<std::string>("reduction")[0]; 40 | options = options.reduction(get_torch_reduction(reduction)); 41 | } catch (std::out_of_range) { 42 | // use default 43 | } 44 | 45 | module = torch::nn::L1Loss(options); 46 | } 47 | 48 | torch::Tensor L1Loss::forward(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& labels, 49 | const std::vector<torch::Tensor>& extra_args) { 50 | if (inputs.size() != 1 || labels.size() != 1 || extra_args.size() != 0) { 51 | THROW_INVALID_USAGE("L1Loss only supports one input tensor, one label tensor, and no extra arguments."); 52 | } 53 | auto x = inputs[0]; 54 | auto y = labels[0]; 55 | return module(x.flatten(), y.flatten()); 56 | } 57 | 58 | } // namespace torchfort 59 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/cosine_annealing_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <cmath> 19 | 20 | #include "internal/base_lr_scheduler.h" 21 | #include "internal/lr_schedulers.h" 22 | 23 | namespace torchfort { 24 | 25 | CosineAnnealingLR::CosineAnnealingLR(torch::optim::Optimizer& optimizer, const unsigned T_max, const double eta_min) 26 | : BaseLRScheduler(optimizer), T_max_(T_max), eta_min_(eta_min) { 27 | base_lrs_ = get_current_lrs(); 28 | } 29 | 30 | double CosineAnnealingLR::update_lr(const double& last_lr, const double& base_lr) { 31 | double lr; 32 | if ((step_count_ - 1 - T_max_) % (2 * T_max_) == 0) { 33 | lr = eta_min_ + 0.5 * (base_lr - eta_min_) * (1. + cos(double(step_count_) * M_PI / double(T_max_))); 34 | } else { 35 | lr = (1. + cos(M_PI * double(step_count_) / double(T_max_))) / 36 | (1.
+ cos(M_PI * double(step_count_ - 1) / double(T_max_))) * (last_lr - eta_min_) + 37 | eta_min_; 38 | } 39 | 40 | return lr; 41 | } 42 | 43 | std::vector<double> CosineAnnealingLR::get_lrs() { 44 | std::vector<double> lrs = get_current_lrs(); 45 | if (step_count_ == 0 || T_max_ == 0) 46 | return lrs; 47 | else { 48 | std::vector<double> lrs_new; 49 | std::transform(lrs.begin(), lrs.end(), base_lrs_.begin(), std::back_inserter(lrs_new), 50 | [this](const auto& current, const auto& base) { return update_lr(current, base); }); 51 | return lrs_new; 52 | } 53 | } 54 | 55 | } // namespace torchfort 56 | -------------------------------------------------------------------------------- /src/csrc/include/internal/setup.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #pragma once 19 | 20 | #include <set> 21 | 22 | #include <torch/torch.h> 23 | #include <yaml-cpp/yaml.h> 24 | 25 | #include "internal/base_loss.h" 26 | #include "internal/base_lr_scheduler.h" 27 | #include "internal/base_model.h" 28 | #include "internal/lr_schedulers.h" 29 | #include "internal/model_state.h" 30 | #include "internal/model_wrapper.h" 31 | 32 | namespace torchfort { 33 | 34 | void check_params(const std::set<std::string>& supported_params, const std::set<std::string>& provided_params); 35 | 36 | ParamMap get_params(const YAML::Node& params_node); 37 | 38 | std::shared_ptr<BaseModel> get_model(const YAML::Node& model_node); 39 | 40 | std::shared_ptr<BaseLoss> get_loss(const YAML::Node& loss_node); 41 | 42 | std::shared_ptr<torch::optim::Optimizer> get_optimizer(const YAML::Node& optimizer_node, 43 | std::vector<torch::Tensor> parameters); 44 | 45 | std::shared_ptr<torch::optim::Optimizer> get_optimizer(const YAML::Node& optimizer_node, 46 | const std::shared_ptr<ModelWrapper>& model); 47 | 48 | std::shared_ptr<BaseLRScheduler> get_lr_scheduler(const YAML::Node& lr_scheduler_node, 49 | const std::shared_ptr<torch::optim::Optimizer>& optimizer); 50 | 51 | std::shared_ptr<ModelState> get_state(const char* name, const YAML::Node& state_node); 52 | } // namespace torchfort 53 | -------------------------------------------------------------------------------- /src/csrc/losses/mse_loss.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <memory> 19 | #include <set> 20 | #include <string> 21 | #include <vector> 22 | 23 | #include <torch/torch.h> 24 | #include <yaml-cpp/yaml.h> 25 | 26 | #include "internal/exceptions.h" 27 | #include "internal/losses.h" 28 | #include "internal/param_map.h" 29 | #include "internal/setup.h" 30 | #include "internal/utils.h" 31 | 32 | namespace torchfort { 33 | 34 | void MSELoss::setup(const ParamMap& params) { 35 | std::set<std::string> supported_params{"reduction"}; 36 | check_params(supported_params, params.keys()); 37 | 38 | auto options = torch::nn::MSELossOptions(); 39 | try { 40 | std::string reduction = params.get_param<std::string>("reduction")[0]; 41 | options = options.reduction(get_torch_reduction(reduction)); 42 | } catch (std::out_of_range) { 43 | // use default 44 | } 45 | 46 | module = torch::nn::MSELoss(options); 47 | } 48 | 49 | torch::Tensor MSELoss::forward(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& labels, 50 | const std::vector<torch::Tensor>& extra_args) { 51 | if (inputs.size() != 1 || labels.size() != 1 || extra_args.size() != 0) { 52 | THROW_INVALID_USAGE("MSELoss only supports one input tensor, one label tensor, and no extra arguments."); 53 | } 54 | auto x = inputs[0]; 55 | auto y = labels[0]; 56 | return module(x.flatten(), y.flatten()); 57 | } 58 | 59 | } // namespace torchfort 60 | -------------------------------------------------------------------------------- /src/csrc/include/internal/rl/distributions.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | #include <cmath> 20 | #include <torch/torch.h> 21 | 22 | #include "internal/rl/rl.h" 23 | 24 | namespace torchfort { 25 | 26 | namespace rl { 27 | 28 | class Distribution { 29 | 30 | public: 31 | Distribution(const Distribution&) = delete; 32 | 33 | // constructor 34 | Distribution() {} 35 | virtual torch::Tensor rsample() = 0; 36 | virtual torch::Tensor log_prob(torch::Tensor value) = 0; 37 | virtual torch::Tensor entropy() = 0; 38 | }; 39 | 40 | class NormalDistribution : public Distribution, public std::enable_shared_from_this<NormalDistribution> { 41 | public: 42 | NormalDistribution(torch::Tensor mu, torch::Tensor sigma) : mu_(mu), sigma_(sigma) {} 43 | 44 | torch::Tensor rsample() { 45 | auto noise = torch::empty_like(mu_).normal_(0., 1.); 46 | return torch::Tensor(mu_ + sigma_ * noise).clone(); 47 | } 48 | 49 | torch::Tensor log_prob(torch::Tensor value) { 50 | auto var = torch::square(sigma_); 51 | auto log_sigma = sigma_.log(); 52 | auto result = -torch::square(value - mu_) / (2 * var) - log_sigma - std::log(std::sqrt(2. * M_PI)); 53 | 54 | return result; 55 | } 56 | 57 | torch::Tensor entropy() { 58 | auto log_sigma = sigma_.log(); 59 | auto result = log_sigma + 0.5 * (1. + std::log(2.
* M_PI)); 60 | 61 | return result; 62 | } 63 | 64 | protected: 65 | torch::Tensor mu_; 66 | torch::Tensor sigma_; 67 | }; 68 | 69 | } // namespace rl 70 | 71 | } // namespace torchfort 72 | -------------------------------------------------------------------------------- /tests/supervised/scripts/setup_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | def save_jit_module(module, fname): 5 | try: 6 | module.to("cuda") 7 | except: 8 | print("PyTorch does not have CUDA support. Saving on CPU.") 9 | module_jit = torch.jit.script(module) 10 | 11 | module_jit.save(fname) 12 | 13 | # Create simple models that just return input for testing 14 | class Net1(torch.nn.Module): 15 | def __init__(self): 16 | super(Net1, self).__init__() 17 | self.layer = torch.nn.Linear(10, 10) 18 | 19 | def forward(self, input1): 20 | x = self.layer(input1) 21 | return input1 + 0.0 * x 22 | 23 | class Net2(torch.nn.Module): 24 | def __init__(self): 25 | super(Net2, self).__init__() 26 | self.layer = torch.nn.Linear(10, 10) 27 | 28 | def forward(self, input1, input2): 29 | x = self.layer(input1) 30 | return input1 + 0.0 * x, input2 + 0.0 * x 31 | 32 | 33 | # Create loss functions with various argument combinations 34 | class Loss1(torch.nn.Module): 35 | def __init__(self): 36 | super(Loss1, self).__init__() 37 | 38 | def forward(self, prediction, label): 39 | return (torch.sum(prediction) + torch.sum(label)) 40 | 41 | class Loss2(torch.nn.Module): 42 | def __init__(self): 43 | super(Loss2, self).__init__() 44 | 45 | def forward(self, prediction1, prediction2, label1, label2): 46 | return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2)) 47 | 48 | class Loss2Extra(torch.nn.Module): 49 | def __init__(self): 50 | super(Loss2Extra, self).__init__() 51 | 52 | def forward(self, prediction1, prediction2, label1, label2, extra_args1, extra_args2): 53 | return (torch.sum(prediction1) + 
torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2) + 54 | torch.sum(extra_args1) + torch.sum(extra_args2)) 55 | 56 | def main(): 57 | model1 = Net1() 58 | model2 = Net2() 59 | loss1 = Loss1() 60 | loss2 = Loss2() 61 | loss2_extra = Loss2Extra() 62 | 63 | save_jit_module(model1, "model.pt") 64 | save_jit_module(model2, "model_multiarg.pt") 65 | save_jit_module(loss1, "loss.pt") 66 | save_jit_module(loss2, "loss_multiarg.pt") 67 | save_jit_module(loss2_extra, "loss_multiarg_extra.pt") 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/env.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #include <array> 19 | #include <random> 20 | #include <tuple> 21 | // #include <pybind11/pybind11.h> 22 | 23 | enum IntegratorType { EXPLICIT_EULER, SEMI_IMPLICIT_EULER }; 24 | using StateVector = std::array<float, 4>; 25 | 26 | class CartPoleEnv { 27 | 28 | public: 29 | CartPoleEnv(); 30 | CartPoleEnv(const CartPoleEnv&) = delete; 31 | CartPoleEnv& operator=(const CartPoleEnv&) = delete; 32 | 33 | StateVector reset(); 34 | std::pair<StateVector, StateVector> getStateBounds(); 35 | 36 | std::tuple<StateVector, float, bool> step(float action); 37 | 38 | private: 39 | // sim parameters 40 | bool terminated_; 41 | float gravity_; 42 | float masscart_; 43 | float masspole_; 44 | float total_mass_; 45 | float length_; 46 | float polemass_length_; 47 | float force_mag_; 48 | float dt_; 49 | float penalty_; 50 | IntegratorType kinematics_integrator_; 51 | int steps_beyond_terminated_; 52 | 53 | // threshold parameters 54 | float theta_threshold_radians_; 55 | float x_threshold_; 56 | 57 | // random stuff 58 | std::mt19937_64 rng_; 59 | std::uniform_real_distribution<float> uniform_dist_; 60 | 61 | // state vector 62 | StateVector state_; 63 | }; 64 | 65 | // pybind11 stuff 66 | // namespace py = pybind11; 67 | // PYBIND11_MODULE(environments, m) { 68 | // py::class_<CartPoleEnv>(m, "CartPoleEnv", py::dynamic_attr()) 69 | // .def(py::init<>()) 70 | // .def("step", &CartPoleEnv::step, py::arg("action")) 71 | // .def("reset", &CartPoleEnv::reset); 72 | //} 73 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/python/initialize_models.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import argparse as ap 18 | import math 19 | import torch 20 | from functools import partial 21 | from torch import nn 22 | import torch.nn.functional as F 23 | 24 | from models import weight_init, PolicyFunc, ValueFunc 25 | 26 | def main(args): 27 | 28 | # set seed 29 | torch.manual_seed(666) 30 | 31 | # CUDA check 32 | if torch.cuda.is_available(): 33 | torch.cuda.manual_seed(666) 34 | device = torch.device("cuda:0") 35 | else: 36 | device = torch.device("cpu") 37 | 38 | # parameters 39 | batch_size = 64 40 | 41 | # policy model 42 | pmodel = PolicyFunc(hidden_features=args.num_hidden_features).to(device) 43 | weight_init(pmodel) 44 | jpmodel = torch.jit.script(pmodel) 45 | inp = torch.ones((batch_size, 4), dtype=torch.float32, device=device) 46 | out = jpmodel(inp) 47 | print("Policy model:", pmodel) 48 | print("Policy model output shape:", out.shape) 49 | torch.jit.save(jpmodel, "policy.pt") 50 | 51 | # value model 52 | qmodel = ValueFunc(hidden_features=args.num_hidden_features).to(device) 53 | weight_init(qmodel) 54 | jqmodel = torch.jit.script(qmodel) 55 | inp_a = torch.ones((batch_size, 1), dtype=torch.float32, device=device) 56 | out = jqmodel(inp, inp_a) 57 | print("Value model:", qmodel) 58 | print("Value model output shape:", out.shape) 59 | torch.jit.save(jqmodel, "value.pt") 60 | 61 | if __name__ == "__main__": 62 | parser = ap.ArgumentParser() 63 | parser.add_argument("--num_hidden_features", type=int, default=128, help="Number of hidden features") 64 | args = parser.parse_args() 65 | 66 | 
main(args) 67 | -------------------------------------------------------------------------------- /src/csrc/include/internal/base_lr_scheduler.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | 25 | #include "internal/exceptions.h" 26 | 27 | namespace torchfort { 28 | 29 | class BaseLRScheduler : public torch::optim::LRScheduler { 30 | public: 31 | BaseLRScheduler(torch::optim::Optimizer& optimizer) : LRScheduler(optimizer) {} 32 | 33 | // Define generic save/load functionalities. Specialize in derived schedulers if 34 | // needed. 
35 | void save(const std::string& fname) { 36 | torch::serialize::OutputArchive archive; 37 | archive.write("step_count", torch::IValue((int64_t)step_count_)); 38 | archive.write("lrs", torch::IValue(get_current_lrs())); 39 | archive.save_to(fname); 40 | } 41 | void load(const std::string& fname, torch::optim::Optimizer& optimizer) { 42 | torch::serialize::InputArchive archive; 43 | archive.load_from(fname); 44 | 45 | torch::IValue ivalue; 46 | if (!archive.try_read("step_count", ivalue)) { 47 | THROW_INVALID_USAGE(fname + " is missing required data."); 48 | } 49 | int64_t step_count = ivalue.to<int64_t>(); 50 | step_count_ = step_count; 51 | 52 | if (!archive.try_read("lrs", ivalue)) { 53 | THROW_INVALID_USAGE(fname + " is missing required data."); 54 | } 55 | auto lrs = ivalue.to<std::vector<double>>(); 56 | // Can't use this method to set the LRs due to it being private in the base LR class. 57 | // set_optimizer_lrs(lrs); 58 | for (const auto i : c10::irange(optimizer.param_groups().size())) { 59 | optimizer.param_groups()[i].options().set_lr(lrs[i]); 60 | } 61 | } 62 | }; 63 | 64 | } // namespace torchfort 65 | -------------------------------------------------------------------------------- /src/csrc/include/internal/losses.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | #include <memory> 21 | #include <vector> 22 | 23 | #include <torch/script.h> 24 | #include <torch/torch.h> 25 | #include <yaml-cpp/yaml.h> 26 | 27 | #include "internal/base_loss.h" 28 | #include "internal/defines.h" 29 | #include "internal/param_map.h" 30 | 31 | namespace torchfort { 32 | 33 | struct L1Loss : BaseLoss { 34 | void setup(const ParamMap& params) override; 35 | 36 | torch::Tensor forward(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& labels, 37 | const std::vector<torch::Tensor>& extra_args) override; 38 | 39 | torch::nn::L1Loss module; 40 | }; 41 | 42 | struct MSELoss : BaseLoss { 43 | void setup(const ParamMap& params) override; 44 | 45 | torch::Tensor forward(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& labels, 46 | const std::vector<torch::Tensor>& extra_args) override; 47 | 48 | torch::nn::MSELoss module; 49 | }; 50 | 51 | struct TorchscriptLoss : BaseLoss { 52 | void setup(const ParamMap& params) override; 53 | 54 | torch::Tensor forward(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& labels, 55 | const std::vector<torch::Tensor>& extra_args) override; 56 | 57 | std::shared_ptr<torch::jit::Module> module_jit; 58 | }; 59 | 60 | // Creating loss_registry. 61 | BEGIN_LOSS_REGISTRY 62 | 63 | // Add entries for new losses in this section. First argument to REGISTER_LOSS is 64 | // a string key and the second argument is the class name. 65 | REGISTER_LOSS(L1, L1Loss) 66 | REGISTER_LOSS(MSE, MSELoss) 67 | REGISTER_LOSS(torchscript, TorchscriptLoss) 68 | 69 | END_LOSS_REGISTRY 70 | 71 | } // namespace torchfort 72 | -------------------------------------------------------------------------------- /src/csrc/include/internal/training.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #pragma once 19 | 20 | #ifdef ENABLE_GPU 21 | #include <cuda_runtime.h> 22 | #endif 23 | 24 | #include <torchfort.h> 25 | 26 | #include <internal/tensor_list.h> 27 | 28 | namespace torchfort { 29 | 30 | void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in, 31 | cudaStream_t ext_stream = 0); 32 | 33 | void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t labels_in, 34 | float* loss_val, torchfort_tensor_list_t extra_loss_args_in, cudaStream_t ext_stream = 0); 35 | 36 | template <typename T> 37 | void inference(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* output, size_t output_dim, 38 | int64_t* output_shape, cudaStream_t ext_stream = 0) { 39 | TensorList inputs, outputs; 40 | 41 | inputs.add_tensor(input, input_dim, input_shape); 42 | outputs.add_tensor(output, output_dim, output_shape); 43 | 44 | inference_multiarg(name, &inputs, &outputs, ext_stream); 45 | } 46 | 47 | template <typename T> 48 | void train(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* label, size_t label_dim, 49 | int64_t* label_shape, T* loss_val, cudaStream_t ext_stream = 0) { 50 | TensorList inputs, labels; 51 | 52 | inputs.add_tensor(input, input_dim, input_shape); 53 | labels.add_tensor(label, label_dim, label_shape); 54 | 55 | // multiarg API expects float loss value, so use
temporary here 56 | float loss_val_tmp; 57 | train_multiarg(name, &inputs, &labels, &loss_val_tmp, nullptr, ext_stream); 58 | *loss_val = loss_val_tmp; 59 | } 60 | 61 | } // namespace torchfort 62 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## TorchFort Contribution Guide 2 | We welcome and appreciate external contributions to TorchFort! File a [pull request](https://github.com/NVIDIA/TorchFort/compare) 3 | to get the process started. 4 | 5 | ### Signing Your Work 6 | 7 | - We require that all contributors "sign-off" on their commits. This certifies that the 8 | contribution is your original work, or you have rights to submit it under the same 9 | license, or a compatible license. 10 | 11 | - Any contribution which contains commits that are not Signed-Off will not be accepted. 12 | 13 | - To sign off on a commit you simply use the `--signoff` (or `-s`) option when 14 | committing your changes: 15 | 16 | ```bash 17 | git commit -s -m "Add cool feature." 18 | ``` 19 | 20 | This will append the following to your commit message: 21 | 22 | ```text 23 | Signed-off-by: Your Name 24 | ``` 25 | 26 | - Full text of the DCO: 27 | 28 | ```text 29 | Developer Certificate of Origin 30 | Version 1.1 31 | 32 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 33 | 1 Letterman Drive 34 | Suite D4700 35 | San Francisco, CA, 94129 36 | 37 | Everyone is permitted to copy and distribute verbatim copies of this license 38 | document, but changing it is not allowed. 
39 | ``` 40 | 41 | ```text 42 | Developer's Certificate of Origin 1.1 43 | 44 | By making a contribution to this project, I certify that: 45 | 46 | (a) The contribution was created in whole or in part by me and I have the right to 47 | submit it under the open source license indicated in the file; or 48 | 49 | (b) The contribution is based upon previous work that, to the best of my knowledge, 50 | is covered under an appropriate open source license and I have the right under that 51 | license to submit that work with modifications, whether created in whole or in part 52 | by me, under the same open source license (unless I am permitted to submit under a 53 | different license), as indicated in the file; or 54 | 55 | (c) The contribution was provided directly to me by some other person who certified 56 | (a), (b) or (c) and I have not modified it. 57 | 58 | (d) I understand and agree that this project and the contribution are public and 59 | that a record of the contribution (including all personal information I submit with 60 | it, including my sign-off) is maintained indefinitely and may be redistributed 61 | consistent with this project or the open source license(s) involved. 
62 | 63 | ``` 64 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(cart_pole_example_targets 2 | train_cart_pole 3 | ) 4 | 5 | add_library(environments STATIC) 6 | target_sources(environments 7 | PRIVATE 8 | env.cpp 9 | ) 10 | set_property(TARGET environments PROPERTY POSITION_INDEPENDENT_CODE ON) 11 | 12 | add_executable(train_cart_pole) 13 | target_sources(train_cart_pole 14 | PRIVATE 15 | train.cpp 16 | ) 17 | 18 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 19 | find_package(pybind11 CONFIG REQUIRED) 20 | pybind11_add_module(PyEnvironments py_env.cpp) 21 | target_link_libraries(PyEnvironments PRIVATE environments) 22 | 23 | foreach(tgt ${cart_pole_example_targets}) 24 | target_include_directories(${tgt} 25 | PRIVATE 26 | ${YAML_CPP_INCLUDE_DIR} 27 | ${MPI_CXX_INCLUDE_DIRS} 28 | ${CUDAToolkit_INCLUDE_DIRS} 29 | ${Python_INCLUDE_DIRS} 30 | ${CMAKE_BINARY_DIR}/include 31 | ) 32 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 33 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 34 | target_link_libraries(${tgt} PRIVATE ${Python_LIBRARIES}) 35 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX) 36 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 37 | target_link_libraries(${tgt} PRIVATE environments) 38 | target_compile_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 39 | target_link_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 40 | if (TORCHFORT_ENABLE_GPU) 41 | target_include_directories(${tgt} 42 | PRIVATE 43 | ${CUDAToolkit_INCLUDE_DIRS} 44 | ) 45 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 46 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU) 47 | endif() 48 | endforeach() 49 | 50 | # installation 51 | # executable 52 | install( 53 | TARGETS ${cart_pole_example_targets} 54 | RUNTIME DESTINATION 
${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole 55 | ) 56 | 57 | # python env 58 | install( 59 | TARGETS PyEnvironments 60 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole/python 61 | ) 62 | 63 | # config files 64 | install( 65 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml ${CMAKE_CURRENT_SOURCE_DIR}/config_sim.yaml 66 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole 67 | ) 68 | 69 | # python files 70 | install( 71 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/initialize_models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/visualize.py 72 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole/python 73 | ) 74 | -------------------------------------------------------------------------------- /src/csrc/include/internal/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #include <filesystem> 21 | #include <sstream> 22 | 23 | #include "internal/model_pack.h" 24 | 25 | namespace torchfort { 26 | 27 | namespace logging { 28 | 29 | enum level { info, warn, error, wandb }; 30 | 31 | std::string log_level_prefix(level log_level); 32 | void write(const std::filesystem::path& filename, const std::string& message, level log_level); 33 | void print(const std::string& message, level log_level); 34 | 35 | } // namespace logging 36 | 37 | // Declaration of external global variables 38 | extern std::unordered_map<std::string, ModelPack> models; 39 | 40 | // specialized logging routines 41 | template <typename T> 42 | void wandb_log(std::shared_ptr<ModelState> state, std::shared_ptr<Comm> comm, const char* name, const char* metric_name, 43 | int64_t step, T value) { 44 | if (state->enable_wandb_hook) { 45 | std::stringstream os; 46 | os << "model: " << name << ", "; 47 | os << "step: " << step << ", "; 48 | os << metric_name << ": " << value; 49 | if (!comm || (comm && comm->rank == 0)) { 50 | torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb); 51 | } 52 | } 53 | } 54 | 55 | template <typename T> void wandb_log(const char* name, const char* metric_name, int64_t step, T value) { 56 | auto state = models[name].state.get(); 57 | if (state->enable_wandb_hook) { 58 | std::stringstream os; 59 | os << "model: " << name << ", "; 60 | os << "step: " << step << ", "; 61 | os << metric_name << ": " << value; 62 | if (!models[name].comm || (models[name].comm && models[name].comm->rank == 0)) { 63 | torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb); 64 | } 65 | } 66 | } 67 | 68 | } // namespace torchfort 69 | -------------------------------------------------------------------------------- /docker/Dockerfile_gnu: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:12.9.1-devel-ubuntu22.04 2 | 3 | SHELL ["/bin/bash", "-c"] 4 | 5 | ENV CUDA_HOME /usr/local/cuda 6 | 7 | # Install System
Dependencies 8 | ENV DEBIAN_FRONTEND noninteractive 9 | RUN apt update -y && \ 10 | apt install -y wget cmake && \ 11 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \ 12 | apt install -y vim gfortran git && \ 13 | apt install -y libibverbs-dev ibverbs-utils numactl 14 | 15 | # Download HPCX 16 | RUN cd /opt && \ 17 | wget https://content.mellanox.com/hpc/hpc-x/v2.24.1_cuda12/hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \ 18 | tar xjf hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \ 19 | mv hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64 hpcx && \ 20 | cd /opt && rm hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz 21 | 22 | ENV PATH /opt/hpcx/ompi/bin:$PATH 23 | ENV LD_LIBRARY_PATH /opt/hpcx/ompi/lib:$LD_LIBRARY_PATH 24 | 25 | RUN echo "source /opt/hpcx/hpcx-init.sh; hpcx_load" >> /root/.bashrc 26 | 27 | # Install PyTorch 28 | RUN pip3 install --no-deps torch==2.8.0 --index-url https://download.pytorch.org/whl/cu129 && \ 29 | pip3 install --no-deps nvidia-cudnn-cu12==9.10.2.21 nvidia-cusparselt-cu12==0.7.1 nvidia-cufile-cu12==1.14.1.1 30 | 31 | # Install yaml-cpp 32 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \ 33 | cd yaml-cpp && \ 34 | mkdir build && cd build && \ 35 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \ 36 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \ 37 | -DBUILD_SHARED_LIBS=OFF \ 38 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \ 39 | make -j$(nproc) && make install 40 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH} 41 | 42 | # Install additional Python dependencies 43 | RUN pip3 install wandb ruamel-yaml matplotlib pygame moviepy 44 | 45 | # Install TorchFort 46 | ENV FC=gfortran 47 | COPY . 
/torchfort 48 | RUN source /opt/hpcx/hpcx-init.sh && hpcx_load && \ 49 | cd /torchfort && mkdir build && cd build && \ 50 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \ 51 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \ 52 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 53 | -DTORCHFORT_BUILD_TESTS=1 \ 54 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 55 | .. && \ 56 | make -j$(nproc) install && \ 57 | cd / && rm -rf torchfort 58 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH} 59 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH} 60 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/nvhpc:25.7-devel-cuda12.9-ubuntu22.04 2 | 3 | SHELL ["/bin/bash", "-c"] 4 | 5 | ENV CUDA_HOME /opt/nvidia/hpc_sdk/Linux_x86_64/25.7/cuda 6 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.7/cuda/12.9/extras/CUPTI/lib64:$LD_LIBRARY_PATH 7 | 8 | # Install System Dependencies 9 | ENV DEBIAN_FRONTEND noninteractive 10 | RUN apt update -y && \ 11 | apt install -y wget cmake && \ 12 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \ 13 | apt install -y vim gfortran 14 | 15 | # Install PyTorch (and select dependencies) 16 | RUN pip3 install --no-deps torch==2.8.0 --index-url https://download.pytorch.org/whl/cu129 && \ 17 | pip3 install --no-deps nvidia-cudnn-cu12==9.10.2.21 nvidia-cusparselt-cu12==0.7.1 nvidia-cufile-cu12==1.14.1.1 nvidia-nccl-cu12==2.27.3 18 | 19 | # Remove conflicting NCCL version from NVHPC SDK, add libnccl.so symlink to pip installed NCCL 20 | RUN rm -rf /opt/nvidia/hpc_sdk/Linux_x86_64/25.7/comm_libs/nccl/ && \ 21 | cd /usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib && \ 22 | ln -s libnccl.so.2 libnccl.so 23 | ENV LD_LIBRARY_PATH 
/usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib:$LD_LIBRARY_PATH 24 | 25 | # Install yaml-cpp 26 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \ 27 | cd yaml-cpp && \ 28 | mkdir build && cd build && \ 29 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \ 30 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \ 31 | -DBUILD_SHARED_LIBS=OFF \ 32 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \ 33 | make -j$(nproc) && make install 34 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH} 35 | 36 | # Install additional Python dependencies 37 | RUN pip3 install wandb ruamel-yaml matplotlib pygame moviepy 38 | 39 | # Install TorchFort 40 | ENV FC=nvfortran 41 | COPY . /torchfort 42 | RUN cd /torchfort && mkdir build && cd build && \ 43 | CUDA_PATH=$CUDA_HOME \ 44 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \ 45 | -DCMAKE_CXX_COMPILER=`which g++` \ 46 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \ 47 | -DTORCHFORT_NCCL_ROOT=/usr/local/lib/python3.10/dist-packages/nvidia/nccl \ 48 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 49 | -DTORCHFORT_BUILD_TESTS=1 \ 50 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 51 | .. && \ 52 | make -j$(nproc) install && \ 53 | cd / && rm -rf torchfort 54 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH} 55 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH} 56 | -------------------------------------------------------------------------------- /tests/general/scripts/setup_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | def save_jit_module(module, fname): 5 | try: 6 | module.to("cuda") 7 | except: 8 | print("PyTorch does not have CUDA support. 
Saving on CPU.") 9 | module_jit = torch.jit.script(module) 10 | 11 | module_jit.save(fname) 12 | 13 | # Create simple models that just return input for testing 14 | class Net1(torch.nn.Module): 15 | def __init__(self): 16 | super(Net1, self).__init__() 17 | self.layer = torch.nn.Linear(10, 10) 18 | 19 | def forward(self, input1): 20 | x = self.layer(input1) 21 | return input1 + 0.0 * x 22 | 23 | class Net2(torch.nn.Module): 24 | def __init__(self): 25 | super(Net2, self).__init__() 26 | self.layer = torch.nn.Linear(10, 10) 27 | 28 | def forward(self, input1, input2): 29 | x = self.layer(input1) 30 | return input1 + 0.0 * x, input2 + 0.0 * x 31 | 32 | 33 | # Create loss functions with various argument combinations 34 | class Loss1(torch.nn.Module): 35 | def __init__(self): 36 | super(Loss1, self).__init__() 37 | 38 | def forward(self, prediction, label): 39 | return (torch.sum(prediction) + torch.sum(label)) / (2 * prediction.numel()) 40 | 41 | class Loss2(torch.nn.Module): 42 | def __init__(self): 43 | super(Loss2, self).__init__() 44 | 45 | def forward(self, prediction1, prediction2, label1, label2): 46 | return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2)) / (4 * prediction1.numel()) 47 | 48 | class Loss2Extra(torch.nn.Module): 49 | def __init__(self): 50 | super(Loss2Extra, self).__init__() 51 | 52 | def forward(self, prediction1, prediction2, label1, label2, extra_args1, extra_args2): 53 | return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2) + 54 | torch.sum(extra_args1) + torch.sum(extra_args2)) / (6 * prediction1.numel()) 55 | 56 | class Loss3(torch.nn.Module): 57 | def __init__(self): 58 | super(Loss3, self).__init__() 59 | 60 | def forward(self, prediction, label): 61 | return torch.sum(prediction), torch.sum(label) 62 | 63 | def main(): 64 | model1 = Net1() 65 | model2 = Net2() 66 | loss1 = Loss1() 67 | loss2 = Loss2() 68 | loss2_extra = Loss2Extra() 69 | loss3 = Loss3() 
70 | 71 | save_jit_module(model1, "model.pt") 72 | save_jit_module(model2, "model_multiarg.pt") 73 | save_jit_module(loss1, "loss.pt") 74 | save_jit_module(loss2, "loss_multiarg.pt") 75 | save_jit_module(loss2_extra, "loss_multiarg_extra.pt") 76 | save_jit_module(loss3, "loss_multiout.pt") 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /src/csrc/losses/torchscript_loss.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #include <filesystem> 19 | #include <iostream> 20 | #include <memory> 21 | #include <stdexcept> 22 | #include <string> 23 | #include <vector> 24 | 25 | #include <torch/script.h> 26 | #include <torch/torch.h> 27 | #include <yaml-cpp/yaml.h> 28 | 29 | #include "internal/exceptions.h" 30 | #include "internal/losses.h" 31 | #include "internal/param_map.h" 32 | #include "internal/setup.h" 33 | #include "internal/utils.h" 34 | 35 | namespace torchfort { 36 | 37 | void TorchscriptLoss::setup(const ParamMap& params) { 38 | std::string jit_loss_fname; 39 | try { 40 | jit_loss_fname = params.get_param("filename")[0]; 41 | } catch (const std::out_of_range&) { 42 | THROW_INVALID_USAGE("filename parameter is required for torchscript loss type."); 43 | } 44 | 45 | if (!std::filesystem::exists(jit_loss_fname)) { 46 | THROW_INVALID_USAGE(jit_loss_fname + " does not exist."); 47 | } 48 | 49 | module_jit = std::shared_ptr<torch::jit::Module>(new torch::jit::Module); 50 | *module_jit = torch::jit::load(jit_loss_fname); 51 | } 52 | 53 | torch::Tensor TorchscriptLoss::forward(const std::vector<torch::Tensor>& inputs, 54 | const std::vector<torch::Tensor>& labels, 55 | const std::vector<torch::Tensor>& extra_args) { 56 | std::vector<torch::jit::IValue> inputs_jit; 57 | inputs_jit.insert(inputs_jit.end(), inputs.begin(), inputs.end()); 58 | inputs_jit.insert(inputs_jit.end(), labels.begin(), labels.end()); 59 | inputs_jit.insert(inputs_jit.end(), extra_args.begin(), extra_args.end()); 60 | 61 | auto result = module_jit->forward(inputs_jit); 62 | if (!result.isTensor()) { 63 | THROW_INVALID_USAGE("TorchscriptLoss only supports returning a single loss tensor."); 64 | } 65 | return result.toTensor(); 66 | } 67 | 68 | } // namespace torchfort 69 | -------------------------------------------------------------------------------- /src/csrc/models/mlp_model.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <set> 19 | 20 | #include <torch/torch.h> 21 | 22 | #include "internal/models.h" 23 | #include "internal/param_map.h" 24 | #include "internal/setup.h" 25 | 26 | namespace torchfort { 27 | 28 | // MLP model in C++ using libtorch 29 | void MLPModel::setup(const ParamMap& params) { 30 | // Extract params from input map. 31 | std::set<std::string> supported_params{"dropout", "flatten_non_batch_dims", "layer_sizes"}; 32 | check_params(supported_params, params.keys()); 33 | 34 | dropout = params.get_param("dropout", 0.0)[0]; 35 | flatten_non_batch_dims = params.get_param("flatten_non_batch_dims", true)[0]; 36 | layer_sizes = params.get_param("layer_sizes"); 37 | 38 | // Construct and register submodules. 39 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 40 | fc_layers.push_back( 41 | register_module("fc" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 42 | if (i < layer_sizes.size() - 2) { 43 | biases.push_back(register_parameter("b" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 44 | } 45 | } 46 | } 47 | 48 | // Implement the forward function.
49 | std::vector<torch::Tensor> MLPModel::forward(const std::vector<torch::Tensor>& inputs) { 50 | if (inputs.size() > 1) 51 | THROW_INVALID_USAGE("Built-in MLP model does not support multiple input tensors."); 52 | 53 | auto x = inputs[0]; 54 | 55 | if (flatten_non_batch_dims) { 56 | x = x.reshape({x.size(0), -1}); 57 | } 58 | 59 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 60 | if (i < layer_sizes.size() - 2) { 61 | x = torch::relu(fc_layers[i]->forward(x) + biases[i]); 62 | x = torch::dropout(x, dropout, is_training()); 63 | } else { 64 | x = fc_layers[i]->forward(x); 65 | } 66 | } 67 | return std::vector<torch::Tensor>{x}; 68 | } 69 | 70 | } // namespace torchfort 71 | -------------------------------------------------------------------------------- /src/csrc/include/internal/lr_schedulers.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #pragma once 19 | #include <vector> 20 | 21 | #include <torch/torch.h> 22 | 23 | #include "internal/base_lr_scheduler.h" 24 | 25 | namespace torchfort { 26 | 27 | class CosineAnnealingLR : public BaseLRScheduler { 28 | public: 29 | CosineAnnealingLR(torch::optim::Optimizer& optimizer, const unsigned T_max, const double eta_min = 0.0); 30 | 31 | private: 32 | std::vector<double> get_lrs() override; 33 | double update_lr(const double& last_lr, const double& base_lr); 34 | 35 | const unsigned T_max_; 36 | const double eta_min_; 37 | std::vector<double> base_lrs_; 38 | }; 39 | 40 | class MultiStepLR : public BaseLRScheduler { 41 | public: 42 | MultiStepLR(torch::optim::Optimizer& optimizer, const std::vector<int>& milestones, const double gamma = 0.1); 43 | 44 | private: 45 | std::vector<double> get_lrs() override; 46 | 47 | const std::vector<int> milestones_; 48 | const double gamma_; 49 | }; 50 | 51 | class PolynomialLR : public BaseLRScheduler { 52 | public: 53 | PolynomialLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double power = 1.0); 54 | 55 | private: 56 | std::vector<double> get_lrs() override; 57 | 58 | const unsigned total_iters_; 59 | const double power_; 60 | }; 61 | 62 | class StepLR : public BaseLRScheduler { 63 | public: 64 | StepLR(torch::optim::Optimizer& optimizer, const unsigned step_size, const double gamma = 0.1); 65 | 66 | private: 67 | std::vector<double> get_lrs() override; 68 | 69 | const unsigned step_size_; 70 | const double gamma_; 71 | }; 72 | 73 | class LinearLR : public BaseLRScheduler { 74 | public: 75 | LinearLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double start_factor = 0.333, 76 | const double end_factor = 1.0); 77 | 78 | private: 79 | std::vector<double> get_lrs() override; 80 | 81 | const unsigned total_iters_; 82 | const double start_factor_, end_factor_; 83 | }; 84 | 85 | } // namespace torchfort 86 | -------------------------------------------------------------------------------- /docs/conf.py:
-------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | import sphinx_rtd_theme 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'TorchFort' 23 | copyright = '2023-2025, NVIDIA Corporation' 24 | author = 'NVIDIA Corporation' 25 | 26 | version = os.getenv("TORCHFORT_GIT_SHA", default="N/A") 27 | #release = version 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = [ 36 | 'breathe', 37 | 'sphinx.ext.mathjax', 38 | 'sphinx_tabs.tabs', 39 | 'sphinxfortran.fortran_domain', 40 | ] 41 | 42 | # Add any paths that contain templates here, relative to this directory. 43 | templates_path = ['_templates'] 44 | 45 | # List of patterns, relative to source directory, that match files and 46 | # directories to ignore when looking for source files. 47 | # This pattern also affects html_static_path and html_extra_path. 48 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 49 | 50 | # The name of the Pygments (syntax highlighting) style to use. 
51 | #pygments_style = 'sphinx' 52 | 53 | highlight_language = 'cpp' 54 | 55 | def setup(app): 56 | app.add_css_file('style.css') 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | 60 | # The theme to use for HTML and HTML Help pages. See the documentation for 61 | # a list of builtin themes. 62 | # 63 | html_theme = 'sphinx_rtd_theme' 64 | #html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 65 | 66 | html_theme_options = { 67 | "collapse_navigation" : False, 68 | "navigation_depth" : 6, 69 | } 70 | 71 | # Add any paths that contain custom static files (such as style sheets) here, 72 | # relative to this directory. They are copied after the builtin static files, 73 | # so a file named "default.css" will overwrite the builtin "default.css". 74 | html_static_path = ['_static'] 75 | 76 | breathe_projects = { "torchfort": "xml/" } 77 | breathe_default_project = "torchfort" 78 | -------------------------------------------------------------------------------- /src/csrc/logging.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #include <filesystem> 19 | #include <fstream> 20 | #include <iostream> 21 | #include <memory> 22 | #include <mutex> 23 | #include <string> 24 | 25 | #include "internal/exceptions.h" 26 | #include "internal/logging.h" 27 | 28 | namespace torchfort { 29 | namespace logging { 30 | 31 | std::mutex logging_mutex; 32 | static std::unique_ptr<std::ofstream> logfile; 33 | 34 | std::string log_level_prefix(level log_level) { 35 | if (log_level == level::info) { 36 | return "TORCHFORT::INFO:"; 37 | } else if (log_level == level::warn) { 38 | return "TORCHFORT::WARN:"; 39 | } else if (log_level == level::error) { 40 | return "TORCHFORT::ERROR:"; 41 | } else if (log_level == level::wandb) { 42 | return "TORCHFORT::WANDB:"; 43 | } else { 44 | THROW_INVALID_USAGE("Unknown log level encountered."); 45 | } 46 | } 47 | 48 | void print(const std::string& message, level log_level) { 49 | std::cout << log_level_prefix(log_level) << " "; 50 | std::cout << message << std::endl; 51 | } 52 | 53 | bool open_logfile(const std::filesystem::path& filename) { 54 | 55 | // check if filename is empty, meaning we do not want to log 56 | if (filename.empty()) { 57 | return false; 58 | } 59 | 60 | // check if path exists 61 | if (filename.has_parent_path()) { 62 | auto path = filename.parent_path(); 63 | std::filesystem::create_directories(path); 64 | } 65 | 66 | logfile = std::make_unique<std::ofstream>(filename, std::ofstream::out | std::ofstream::app); 67 | 68 | return true; 69 | } 70 | 71 | void write(const std::filesystem::path& filename, const std::string& message, level log_level) { 72 | std::lock_guard<std::mutex> guard(logging_mutex); 73 | 74 | // check if the logfile is already open 75 | if (logfile == nullptr) { 76 | if (!open_logfile(filename)) 77 | return; 78 | } 79 | auto line = log_level_prefix(log_level) + " " + message + "\n"; 80 | logfile->write(line.c_str(), line.size()); 81 | logfile->flush(); 82 | } 83 | 84 | } // namespace logging 85 | } // namespace torchfort 86 | --------------------------------------------------------------------------------
/tests/general/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(test_targets 4 | test_losses 5 | ) 6 | 7 | add_executable(test_losses) 8 | target_sources(test_losses 9 | PRIVATE 10 | test_losses.cpp 11 | ) 12 | 13 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 14 | 15 | foreach(tgt ${test_targets}) 16 | target_include_directories(${tgt} 17 | PRIVATE 18 | ${YAML_CPP_INCLUDE_DIR} 19 | ${MPI_CXX_INCLUDE_DIRS} 20 | ${CMAKE_BINARY_DIR}/include 21 | ${CMAKE_CURRENT_SOURCE_DIR}/../ 22 | ) 23 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 24 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 25 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 26 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX) 27 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main) 28 | target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>) 29 | target_link_options(${tgt} PRIVATE $<$<LINK_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>) 30 | if (TORCHFORT_ENABLE_GPU) 31 | target_include_directories(${tgt} 32 | PRIVATE 33 | ${CUDAToolkit_INCLUDE_DIRS} 34 | ${NCCL_INCLUDE_DIR} 35 | ) 36 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 37 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU) 38 | endif() 39 | 40 | # discover tests: we have an issue with the work dir of gtest so disable that for now 41 | #gtest_discover_tests(${tgt}) 42 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) 43 | endforeach() 44 | 45 | # installation 46 | # executable 47 | install( 48 | TARGETS ${test_targets} 49 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general 50 | ) 51 | 52 | # copy files 53 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mse.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 54 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mse_multiarg.yaml DESTINATION
${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 55 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/l1.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 56 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/l1_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 57 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 58 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 59 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 60 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiout.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 61 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup_tests.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/scripts) 62 | -------------------------------------------------------------------------------- /tests/rl/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(test_targets 4 | test_replay_buffer 5 | test_rollout_buffer 6 | test_distributions 7 | test_interface 8 | test_off_policy 9 | test_on_policy 10 | ) 11 | 12 | add_executable(test_replay_buffer) 13 | target_sources(test_replay_buffer 14 | PRIVATE 15 | test_replay_buffer.cpp 16 | ) 17 | 18 | add_executable(test_rollout_buffer) 19 | target_sources(test_rollout_buffer 20 | PRIVATE 21 | test_rollout_buffer.cpp 22 | ) 23 | 24 | add_executable(test_distributions) 25 | target_sources(test_distributions 26 | PRIVATE 27 | test_distributions.cpp 28 | ) 29 | 30 | add_executable(test_interface) 31 | target_sources(test_interface 32 | PRIVATE 33 | test_interface.cpp 34 | ) 35 | 36 | add_executable(test_off_policy) 37 | 
target_sources(test_off_policy 38 | PRIVATE 39 | test_off_policy.cpp 40 | ) 41 | 42 | add_executable(test_on_policy) 43 | target_sources(test_on_policy 44 | PRIVATE 45 | test_on_policy.cpp 46 | ) 47 | 48 | 49 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 50 | 51 | foreach(tgt ${test_targets}) 52 | target_include_directories(${tgt} 53 | PRIVATE 54 | ${YAML_CPP_INCLUDE_DIR} 55 | ${MPI_CXX_INCLUDE_DIRS} 56 | ${CMAKE_BINARY_DIR}/include 57 | ) 58 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 59 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 60 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 61 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX) 62 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main) 63 | target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>) 64 | target_link_options(${tgt} PRIVATE $<$<LINK_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>) 65 | if (TORCHFORT_ENABLE_GPU) 66 | target_include_directories(${tgt} 67 | PRIVATE 68 | ${CUDAToolkit_INCLUDE_DIRS} 69 | ${NCCL_INCLUDE_DIR} 70 | ) 71 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 72 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU) 73 | endif() 74 | 75 | # discover tests: we have an issue with the work dir of gtest so disable that for now 76 | #gtest_discover_tests(${tgt}) 77 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) 78 | endforeach() 79 | 80 | # installation 81 | # executable 82 | install( 83 | TARGETS ${test_targets} 84 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl 85 | ) 86 | 87 | # copy files 88 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/td3.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 89 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/ddpg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 90 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/sac.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 91 | install(FILES
${CMAKE_CURRENT_SOURCE_DIR}/configs/ppo.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 92 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/python/models.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import math 18 | import torch 19 | from torch import nn 20 | import torch.nn.functional as F 21 | 22 | def weight_init(model, scale=0.02): 23 | with torch.no_grad(): 24 | for m in model.modules(): 25 | if isinstance(m, nn.Linear): 26 | sqrtk = math.sqrt(1./float(m.weight.shape[1])) 27 | nn.init.uniform_(m.weight, a=-sqrtk, b=sqrtk) 28 | if m.bias is not None: 29 | m.bias.data.zero_() 30 | 31 | class PolicyFunc(nn.Module): 32 | def __init__(self, hidden_features=128): 33 | super(PolicyFunc, self).__init__() 34 | 35 | layers = [nn.Linear(in_features = 4, 36 | out_features = hidden_features, 37 | bias=True), 38 | nn.ReLU(), 39 | nn.Linear(in_features = hidden_features, 40 | out_features = hidden_features // 2, 41 | bias=True), 42 | nn.ReLU(), 43 | nn.Linear(in_features = hidden_features // 2, 44 | out_features = 1, 45 | bias=True), 46 | nn.Tanh()] 47 | 48 | self.fwd = nn.Sequential(*layers) 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | return self.fwd(x) 52 | 53 | class ValueFunc(nn.Module): 54 | def __init__(self, hidden_features=128): 55 | super(ValueFunc, self).__init__() 56 | 57 | layers = [nn.Linear(in_features = 5, 58 | out_features = hidden_features, 59 | bias=True), 60 | nn.ReLU(), 61 | nn.Linear(in_features = hidden_features, 62 | out_features = hidden_features // 2, 63 | bias=True), 64 | nn.ReLU(), 65 | nn.Linear(in_features = hidden_features // 2, 66 | out_features = 1, 67 | bias=True)] 68 | 69 | self.fwd = nn.Sequential(*layers) 70 | 71 | def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor: 72 | x = torch.cat([s, a], dim=1) 73 | return self.fwd(x) 74 | -------------------------------------------------------------------------------- /src/csrc/models/rl/sac_model.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "internal/models.h" 6 | #include "internal/param_map.h" 7 | #include "internal/setup.h" 8 | 9 | namespace torchfort 
{ 10 | 11 | // SACMLP model in C++ using libtorch 12 | void SACMLPModel::setup(const ParamMap& params) { 13 | // Extract params from input map. 14 | std::set<std::string> supported_params{"dropout", "layer_sizes", "state_dependent_sigma", "log_sigma_init"}; 15 | check_params(supported_params, params.keys()); 16 | 17 | dropout = params.get_param<double>("dropout", 0.0)[0]; 18 | layer_sizes = params.get_param<int>("layer_sizes"); 19 | state_dependent_sigma = params.get_param<bool>("state_dependent_sigma", true)[0]; 20 | double log_sigma_init = params.get_param<double>("log_sigma_init", 0.)[0]; 21 | 22 | // Construct and register submodules. 23 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 24 | if (i < layer_sizes.size() - 2) { 25 | encoder_layers.push_back( 26 | register_module("encoder_fc_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 27 | biases.push_back(register_parameter("encoder_b_" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 28 | } else { 29 | // first output: mu 30 | out_layers.push_back( 31 | register_module("out_fc_1_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 32 | out_biases.push_back(register_parameter("out_b_1_" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 33 | // second output: log_sigma 34 | if (state_dependent_sigma) { 35 | out_layers.push_back( 36 | register_module("out_fc_2_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 37 | out_biases.push_back(register_parameter("out_b_2_" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 38 | } else { 39 | out_biases.push_back( 40 | register_parameter("out_b_2_" + std::to_string(i), torch::ones(layer_sizes[i + 1]) * log_sigma_init)); 41 | } 42 | } 43 | } 44 | } 45 | 46 | // Implement the forward function. 47 | std::vector<torch::Tensor> SACMLPModel::forward(const std::vector<torch::Tensor>& inputs) { 48 | 49 | // make sure that exactly one tensor (the state) is fed 50 | if (inputs.size() != 1) { 51 | THROW_INVALID_USAGE("You have to provide exactly one tensor (state) to the SACMLPModel"); 52 | } 53 | 54 | // unpack 55 | auto state = inputs[0]; 56 | 57 | // expand dims if necessary 58 | if (state.dim() == 1) { 59 | state = state.unsqueeze(0); 60 | } 61 | 62 | // flatten everything beyond dim 0: 63 | auto x = state.reshape({state.size(0), -1}); 64 | torch::Tensor y, z; 65 | 66 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 67 | if (i < layer_sizes.size() - 2) { 68 | // encoder part 69 | x = torch::relu(encoder_layers[i]->forward(x) + biases[i]); 70 | x = torch::dropout(x, dropout, is_training()); 71 | } else { 72 | // y 73 | y = out_layers[0]->forward(x) + out_biases[0]; 74 | // z 75 | if (state_dependent_sigma) { 76 | z = out_layers[1]->forward(x) + out_biases[1]; 77 | } else { 78 | z = out_biases[1]; 79 | } 80 | } 81 | } 82 | return std::vector<torch::Tensor>{y, z}; 83 | } 84 | 85 | } // namespace torchfort 86 | -------------------------------------------------------------------------------- /src/csrc/models/rl/common_models.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <set> 19 | 20 | #include <torch/torch.h> 21 | 22 | #include "internal/models.h" 23 | #include "internal/param_map.h" 24 | #include "internal/setup.h" 25 | 26 | namespace torchfort { 27 | 28 | // MLP model in C++ using libtorch 29 | void CriticMLPModel::setup(const ParamMap& params) { 30 | // Extract params from input map. 31 | std::set<std::string> supported_params{"dropout", "layer_sizes"}; 32 | check_params(supported_params, params.keys()); 33 | 34 | dropout = params.get_param<double>("dropout", 0.0)[0]; 35 | layer_sizes = params.get_param<int>("layer_sizes"); 36 | 37 | // sanity checks 38 | // make sure that the value function is emitting a scalar. 39 | if (layer_sizes[layer_sizes.size() - 1] != 1) { 40 | THROW_INVALID_USAGE( 41 | "CriticMLPModel::setup: error, the value of the last element of layer_sizes has to be equal to one."); 42 | } 43 | 44 | // Construct and register submodules. 45 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 46 | fc_layers.push_back( 47 | register_module("fc" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 48 | if (i < layer_sizes.size() - 2) { 49 | biases.push_back(register_parameter("b" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 50 | } 51 | } 52 | } 53 | 54 | // Implement the forward function. 55 | std::vector<torch::Tensor> CriticMLPModel::forward(const std::vector<torch::Tensor>& inputs) { 56 | 57 | // make sure that exactly two tensors are fed, state and action: 58 | if (inputs.size() != 2) { 59 | THROW_INVALID_USAGE("You have to provide exactly two tensors (state, action) to the CriticMLPModel"); 60 | } 61 | 62 | // unpack 63 | auto state = inputs[0]; 64 | auto action = inputs[1]; 65 | 66 | // expand dims if necessary 67 | if (state.dim() == 1) { 68 | state = state.unsqueeze(0); 69 | } 70 | if (action.dim() == 1) { 71 | action = action.unsqueeze(0); 72 | } 73 | 74 | // flatten everything beyond dim 0: 75 | state = state.reshape({state.size(0), -1}); 76 | action = action.reshape({action.size(0), -1}); 77 | 78 | // concatenate inputs along feature dimension 79 | auto x = torch::cat({state, action}, 1); 80 | 81 | // forward pass 82 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 83 | if (i < layer_sizes.size() - 2) { 84 | x = torch::relu(fc_layers[i]->forward(x) + biases[i]); 85 | x = torch::dropout(x, dropout, is_training()); 86 | } else { 87 | x = fc_layers[i]->forward(x); 88 | } 89 | } 90 | return std::vector<torch::Tensor>{x}; 91 | } 92 | 93 | } // namespace torchfort 94 | -------------------------------------------------------------------------------- /tests/supervised/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(test_targets 4 | test_checkpoint 5 | test_training 6 | test_distributed_training 7 | ) 8 | 9 | add_executable(test_checkpoint) 10 | target_sources(test_checkpoint 11 | PRIVATE 12 | test_checkpoint.cpp 13 | ) 14 | 15 | add_executable(test_training) 16 | target_sources(test_training 17 | PRIVATE 18 | test_training.cpp 19 | ) 20 | 21 | add_executable(test_distributed_training) 22 | target_sources(test_distributed_training 23 | PRIVATE 24 | test_distributed_training.cpp 25 | ) 26 | 27 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 28 | 29 | foreach(tgt
${test_targets}) 30 | target_include_directories(${tgt} 31 | PRIVATE 32 | ${YAML_CPP_INCLUDE_DIR} 33 | ${MPI_CXX_INCLUDE_DIRS} 34 | ${CMAKE_BINARY_DIR}/include 35 | ${CMAKE_CURRENT_SOURCE_DIR}/../ 36 | ) 37 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 38 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 39 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 40 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX) 41 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main) 42 | target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>) 43 | target_link_options(${tgt} PRIVATE $<$<LINK_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>) 44 | if (TORCHFORT_ENABLE_GPU) 45 | target_include_directories(${tgt} 46 | PRIVATE 47 | ${CUDAToolkit_INCLUDE_DIRS} 48 | ${NCCL_INCLUDE_DIR} 49 | ) 50 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 51 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU) 52 | endif() 53 | 54 | # discover tests: we have an issue with the work dir of gtest so disable that for now 55 | #gtest_discover_tests(${tgt}) 56 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) 57 | endforeach() 58 | 59 | # installation 60 | # executable 61 | install( 62 | TARGETS ${test_targets} 63 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised 64 | ) 65 | 66 | # copy files 67 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 68 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp2.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 69 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp3.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 70 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp2_gradacc.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 71 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_opt.yaml DESTINATION
${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 72 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_loss.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 73 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 74 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 75 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 76 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup_tests.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/scripts) 77 | -------------------------------------------------------------------------------- /src/csrc/include/internal/models.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #include <vector> 21 | 22 | #include <torch/torch.h> 23 | 24 | #include "internal/base_model.h" 25 | #include "internal/defines.h" 26 | #include "internal/param_map.h" 27 | 28 | namespace torchfort { 29 | 30 | // MLP model in C++ using libtorch 31 | struct MLPModel : BaseModel, public std::enable_shared_from_this<MLPModel> { 32 | void setup(const ParamMap& params) override; 33 | std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override; 34 | 35 | double dropout; 36 | bool flatten_non_batch_dims; 37 | std::vector<int> layer_sizes; 38 | 39 | // Use one of many "standard library" modules. 40 | std::vector<torch::nn::Linear> fc_layers; 41 | std::vector<torch::Tensor> biases; 42 | }; 43 | 44 | struct CriticMLPModel : BaseModel, public std::enable_shared_from_this<CriticMLPModel> { 45 | void setup(const ParamMap& params) override; 46 | std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override; 47 | 48 | double dropout; 49 | std::vector<int> layer_sizes; 50 | 51 | // Use one of many "standard library" modules. 52 | std::vector<torch::nn::Linear> fc_layers; 53 | std::vector<torch::Tensor> biases; 54 | }; 55 | 56 | struct SACMLPModel : BaseModel, public std::enable_shared_from_this<SACMLPModel> { 57 | void setup(const ParamMap& params) override; 58 | std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override; 59 | 60 | double dropout; 61 | std::vector<int> layer_sizes; 62 | bool state_dependent_sigma; 63 | 64 | // A SAC Model has a common encoder and two output layers for mu and log-sigma 65 | std::vector<torch::nn::Linear> encoder_layers; 66 | std::vector<torch::nn::Linear> out_layers; 67 | std::vector<torch::Tensor> biases; 68 | std::vector<torch::Tensor> out_biases; 69 | }; 70 | 71 | struct ActorCriticMLPModel : BaseModel, public std::enable_shared_from_this<ActorCriticMLPModel> { 72 | void setup(const ParamMap& params) override; 73 | std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override; 74 | 75 | double dropout; 76 | std::vector<int> encoder_layer_sizes, actor_layer_sizes, value_layer_sizes; 77 | bool state_dependent_sigma; 78 | 79 | // An AC Model has a common encoder and then one MLP for the actor and one for the value function 80 | std::vector<torch::nn::Linear> encoder_layers, actor_layers, value_layers; 81 | std::vector<torch::Tensor> encoder_biases, actor_biases, value_biases; 82 | }; 83 | 84 | // Creating model_registry. 85 | BEGIN_MODEL_REGISTRY 86 | 87 | // Add entries for new models in this section. 88 | REGISTER_MODEL(MLP, MLPModel) 89 | REGISTER_MODEL(CriticMLP, CriticMLPModel) 90 | REGISTER_MODEL(SACMLP, SACMLPModel) 91 | REGISTER_MODEL(ActorCriticMLP, ActorCriticMLPModel) 92 | 93 | END_MODEL_REGISTRY 94 | 95 | } // namespace torchfort 96 | -------------------------------------------------------------------------------- /src/csrc/utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #include <algorithm> 19 | #include <cctype> 20 | #include <regex> 21 | #include <string> 22 | 23 | #include <torch/torch.h> 24 | 25 | #include "internal/defines.h" 26 | #include "internal/model_pack.h" 27 | #include "internal/utils.h" 28 | 29 | namespace torchfort { 30 | 31 | // Declaration of external global variables 32 | extern std::unordered_map<std::string, ModelPack> models; 33 | 34 | std::string sanitize(std::string s) { 35 | s.erase(std::remove(s.begin(), s.end(), ' '), s.end()); 36 | std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); 37 | return s; 38 | } 39 | 40 | std::string filename_sanitize(std::string s) { 41 | // remove trailing whitespace 42 | s.erase(std::remove(s.begin(), s.end(), ' '), s.end()); 43 | 44 | // replace intermediate whitespace 45 | s = std::regex_replace(s, std::regex(" "), "_"); 46 | 47 | // replace all / with -: 48 | s = std::regex_replace(s, std::regex("/"), "-"); 49 | 50 | return s; 51 | } 52 | 53 | torch::Device get_device(int device) { 54 | torch::Device device_torch(torch::kCPU); 55 | if (device != TORCHFORT_DEVICE_CPU) { 56 | #ifdef ENABLE_GPU 57 | device_torch = torch::Device(torch::kCUDA, device); 58 | #else 59 | THROW_NOT_SUPPORTED( 60 | "Attempted to place a model or other component on GPU but TorchFort was built without GPU support."); 61 | #endif 62 | } 63 | return device_torch; 64 | } 65 | 66 | torch::Device get_device(const void* ptr) { 67 | torch::Device device = torch::Device(torch::kCPU); 68 | #ifdef ENABLE_GPU 69 | cudaPointerAttributes attr; 70 | CHECK_CUDA(cudaPointerGetAttributes(&attr, ptr)); 71 | switch (attr.type) { 72 | case cudaMemoryTypeHost: 73 | case cudaMemoryTypeUnregistered: 74 | device = torch::Device(torch::kCPU); 75 | break; 76 | case cudaMemoryTypeManaged: 77 | case cudaMemoryTypeDevice: 78 | device = torch::Device(torch::kCUDA); 79 | break; 80 | } 81 | #endif 82 | return device; 83 | } 84 | 85 | std::string print_tensor_shape(torch::Tensor tensor) { 86 | std::string shapestr = "("; 87 | for (int i = 0; i < tensor.dim(); ++i) 88 | shapestr += std::to_string(tensor.size(i)) + ","; 89 | shapestr.pop_back(); 90 | shapestr += ")"; 91 | return shapestr; 92 | } 93 | 94 | std::vector<double> get_current_lrs(const char* name) { 95 | auto optimizer = models[name].optimizer; 96 | std::vector<double> learnings_rates(optimizer->param_groups().size()); 97 | if (learnings_rates.size() > 0) { 98 | for (const auto i : c10::irange(optimizer->param_groups().size())) { 99 | learnings_rates[i] = optimizer->param_groups()[i].options().get_lr(); 100 | } 101 | } 102 | return learnings_rates; 103 | } 104 | 105 | } // namespace torchfort 106 | -------------------------------------------------------------------------------- /src/csrc/model_pack.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #include <filesystem> 19 | 20 | #include "internal/defines.h" 21 | #include "internal/model_pack.h" 22 | #include "internal/model_wrapper.h" 23 | #include "internal/utils.h" 24 | 25 | namespace torchfort { 26 | 27 | void save_model_pack(const ModelPack& model_pack, const std::string& dir, bool save_optimizer) { 28 | std::filesystem::path root_dir(dir); 29 | 30 | if (!std::filesystem::exists(root_dir)) { 31 | bool rv = std::filesystem::create_directory(root_dir); 32 | if (!rv) { 33 | THROW_INVALID_USAGE("Could not create directory " + root_dir.native() + "."); 34 | } 35 | } 36 | 37 | model_pack.state->device = model_pack.model->device(); 38 | 39 | auto model_path = root_dir / "model.pt"; 40 | model_pack.model->save(model_path.native()); 41 | 42 | if (save_optimizer) { 43 | auto optimizer_path = root_dir / "optimizer.pt"; 44 | if (!model_pack.optimizer) { 45 | THROW_INVALID_USAGE("Cannot save checkpoint. Missing optimizer."); 46 | } 47 | torch::save(*(model_pack.optimizer), optimizer_path.native()); 48 | 49 | auto lr_path = root_dir / "lr.pt"; 50 | if (model_pack.lr_scheduler) { 51 | model_pack.lr_scheduler->save(lr_path.native()); 52 | } 53 | } 54 | 55 | auto state_path = root_dir / "state.pt"; 56 | model_pack.state->save(state_path.native()); 57 | } 58 | 59 | void load_model_pack(ModelPack& model_pack, const std::string& dir, bool load_optimizer) { 60 | std::filesystem::path root_dir(dir); 61 | 62 | auto state_path = root_dir / "state.pt"; 63 | if (!std::filesystem::exists(state_path)) { 64 | THROW_INVALID_USAGE("Could not find " + state_path.native() + "."); 65 | } 66 | model_pack.state->load(state_path.native()); 67 | 68 | auto model_path = root_dir / "model.pt"; 69 | if (!std::filesystem::exists(model_path)) { 70 | THROW_INVALID_USAGE("Could not find " + model_path.native() + "."); 71 | } 72 | model_pack.model->load(model_path.native()); 73 | 74 | // Assign optimizer to parameters of loaded model: 75 | // we need to check if the optimizer is initialized
before doing so 76 | // (some RL models do not have an optimizer attached to them): 77 | if (model_pack.optimizer) { 78 | model_pack.optimizer->parameters() = model_pack.model->parameters(); 79 | } 80 | 81 | if (load_optimizer) { 82 | auto optimizer_path = root_dir / "optimizer.pt"; 83 | if (!std::filesystem::exists(optimizer_path)) { 84 | THROW_INVALID_USAGE("Could not find " + optimizer_path.native() + "."); 85 | } 86 | torch::load(*(model_pack.optimizer), optimizer_path.native(), model_pack.model->device()); 87 | 88 | auto lr_path = root_dir / "lr.pt"; 89 | if (std::filesystem::exists(lr_path)) { 90 | model_pack.lr_scheduler->load(lr_path.native(), *(model_pack.optimizer)); 91 | } else { 92 | // No LR in checkpoint, disable LR scheduler 93 | model_pack.lr_scheduler = nullptr; 94 | } 95 | } 96 | } 97 | 98 | } // namespace torchfort 99 | -------------------------------------------------------------------------------- /src/csrc/include/internal/rl/utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | #include <memory> 20 | 21 | #ifdef ENABLE_GPU 22 | #include <cuda_runtime.h> 23 | 24 | #include 25 | #include 26 | #endif 27 | #include <torch/torch.h> 28 | 29 | #include "internal/defines.h" 30 | #include "internal/logging.h" 31 | #include "internal/lr_schedulers.h" 32 | #include "internal/model_pack.h" 33 | #include "internal/rl/rl.h" 34 | #include "internal/setup.h" 35 | 36 | namespace torchfort { 37 | 38 | namespace rl { 39 | 40 | // helpers for sanitizing devices 41 | // if the two devices differ, one of them has to be a CPU device. 42 | bool validate_devices(int device1, int device2); 43 | 44 | // helpers for extracting LRs from the optimizer 45 | std::vector<double> get_current_lrs(std::shared_ptr<torch::optim::Optimizer> optimizer); 46 | 47 | // helpers for manipulating weights and grads 48 | void init_parameters(std::shared_ptr<ModelWrapper> model); 49 | void copy_parameters(std::shared_ptr<ModelWrapper> target, std::shared_ptr<ModelWrapper> source); 50 | void set_grad_state(std::shared_ptr<ModelWrapper> model, const bool requires_grad); 51 | 52 | // polyak update for model averaging: 53 | // computes: target = rho * target + (1-rho) * src 54 | // note that target here denotes the updated model parameters, and src the previous ones 55 | template <typename T> 56 | void polyak_update(std::shared_ptr<ModelWrapper> target, std::shared_ptr<ModelWrapper> source, const T rho) { 57 | 58 | // add no grad guard 59 | torch::NoGradGuard no_grad; 60 | 61 | // get models 62 | auto tar = target->parameters(); 63 | auto src = source->parameters(); 64 | 65 | // some simple asserts here 66 | assert(tar.size() == src.size()); 67 | 68 | // do in-place update: I don't know a good way of doing that with std::transform: 69 | for (size_t i = 0; i < tar.size(); ++i) { 70 | const auto& t = tar[i]; 71 | const auto& s = src[i]; 72 | t.mul_(rho); 73 | t.add_((1. - rho) * s); 74 | // t.copy_(torch::Tensor(rho * t + (1.-rho) * s)); 75 | } 76 | 77 | return; 78 | } 79 | 80 | // Rescale the action from [a_low, a_high] to [-1, 1] 81 | template <typename T> torch::Tensor scale_action(torch::Tensor unscaled_action, const T& a_low, const T& a_high) { 82 | auto scaled_action = static_cast<T>(2.0) * ((unscaled_action - a_low) / (a_high - a_low)) - static_cast<T>(1.0); 83 | scaled_action = scaled_action.to(unscaled_action.dtype()); 84 | 85 | return scaled_action; 86 | } 87 | 88 | // Unscale the action from [-1., 1.] to [a_low, a_high] 89 | template <typename T> torch::Tensor unscale_action(torch::Tensor scaled_action, const T& a_low, const T& a_high) { 90 | auto unscaled_action = 0.5 * (a_high - a_low) * (scaled_action + static_cast<T>(1.)) + a_low; 91 | unscaled_action = unscaled_action.to(scaled_action.dtype()); 92 | 93 | return unscaled_action; 94 | } 95 | 96 | // explained variance 97 | torch::Tensor explained_variance(torch::Tensor q_pred, torch::Tensor q_true, std::shared_ptr<Comm> comm); 98 | 99 | } // namespace rl 100 | 101 | } // namespace torchfort 102 | -------------------------------------------------------------------------------- /src/csrc/include/internal/utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | #pragma once 19 | 20 | #include <string> 21 | #include <type_traits> 22 | #include <vector> 23 | 24 | #include <c10/util/ArrayRef.h> 25 | #include <torch/torch.h> 26 | 27 | #ifdef ENABLE_GPU 28 | #include <cuda_runtime.h> 29 | #endif 30 | 31 | #include "internal/exceptions.h" 32 | #include "internal/nvtx.h" 33 | 34 | namespace torchfort { 35 | 36 | // Function to convert string to lowercase and remove whitespace 37 | std::string sanitize(std::string s); 38 | 39 | // Function to convert a string to a filename 40 | std::string filename_sanitize(std::string s); 41 | 42 | // Function to return torch device from integer device value 43 | torch::Device get_device(int device); 44 | 45 | // Function to return torch device from pointer 46 | torch::Device get_device(const void* ptr); 47 | 48 | template <typename T> torch::Dtype make_type() { 49 | if (std::is_same<T, float>::value) { 50 | return torch::kFloat32; 51 | } else if (std::is_same<T, int32_t>::value) { 52 | return torch::kInt32; 53 | } else if (std::is_same<T, int64_t>::value) { 54 | return torch::kInt64; 55 | } else if (std::is_same<T, double>::value) { 56 | return torch::kFloat64; 57 | } else { 58 | THROW_INVALID_USAGE("datatype not implemented"); 59 | } 60 | } 61 | 62 | enum MemoryLayout { RowMajor = 0, ColMajor = 1 }; 63 | 64 | template <MemoryLayout L, typename T> torch::Tensor get_tensor(T* tensor_ptr, size_t dim, int64_t* shape) { 65 | torchfort::nvtx::rangePush("get_tensor"); 66 | // Set tensor options 67 | auto dev = get_device(tensor_ptr); 68 | torch::TensorOptions options = torch::TensorOptions().device(dev); 69 | 70 | // Get type 71 | auto type = make_type<T>(); 72 | options = options.dtype(type); 73 | 74 | // Create shape 75 | std::vector<int64_t> sizes(dim); 76 | switch (L) { 77 | case RowMajor: 78 | for (size_t i = 0; i < dim; ++i) { 79 | sizes[i] = shape[i]; 80 | } 81 | break; 82 | case ColMajor: 83 | // For column major input data, reverse the shape order 84 | for (size_t i = 0; i < dim; ++i) { 85 | sizes[i] = shape[dim - i - 1]; 86 | } 87 | break; 88 | } 89 | torch::IntArrayRef size = c10::makeArrayRef(sizes); 90 | 91 | // Create tensor 92 | auto tensor =
torch::from_blob( 93 | tensor_ptr, sizes, [](void* ptr) {}, options); 94 | torchfort::nvtx::rangePop(); 95 | return tensor; 96 | } 97 | 98 | // Helper function to convert string reduction names to torch enums. 99 | template <typename T> T get_torch_reduction(const std::string& s) { 100 | if (s == "mean") { 101 | return torch::kMean; 102 | } else if (s == "sum") { 103 | return torch::kSum; 104 | } else if (s == "none") { 105 | return torch::kNone; 106 | } else { 107 | THROW_INVALID_USAGE("Unknown reduction type encountered."); 108 | } 109 | } 110 | 111 | // Helper function for printing tensor shapes 112 | std::string print_tensor_shape(torch::Tensor tensor); 113 | 114 | // Helper function to get the lrs 115 | std::vector<double> get_current_lrs(const char* name); 116 | 117 | } // namespace torchfort 118 | -------------------------------------------------------------------------------- /tests/test_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | #pragma once 18 | 19 | #include <algorithm> 20 | #include <chrono> 21 | #include <numeric> 22 | #include <random> 23 | #include <string> 24 | #include <vector> 25 | 26 | #ifdef ENABLE_GPU 27 | #include <cuda_runtime.h> 28 | #endif 29 | 30 | // Generate random vector data for testing 31 | template <typename T> std::vector<T> generate_random(const std::vector<int64_t>& shape) { 32 | 33 | int64_t num_values = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()); 34 | std::vector<T> data(num_values); 35 | 36 | std::mt19937 generator; 37 | unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); 38 | generator.seed(seed); 39 | std::uniform_real_distribution<T> dist((T)0, (T)1); 40 | 41 | auto r = [&]() { return dist(generator); }; 42 | 43 | std::generate(data.begin(), data.end(), r); 44 | 45 | return data; 46 | } 47 | 48 | // Generate constant vector data for testing 49 | template <typename T> std::vector<T> generate_constant(const std::vector<int64_t>& shape, T value) { 50 | 51 | int64_t num_values = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()); 52 | std::vector<T> data(num_values, value); 53 | 54 | return data; 55 | } 56 | 57 | // Generate random names to use as model keys to avoid conflicts between tests 58 | std::string generate_random_name(int length) { 59 | 60 | const std::string character_set = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; 61 | std::mt19937 generator; 62 | unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); 63 | generator.seed(seed); 64 | std::uniform_int_distribution<> dist(0, character_set.size() - 1); 65 | 66 | std::string name; 67 | for (int i = 0; i < length; ++i) { 68 | name += character_set[dist(generator)]; 69 | } 70 | return name; 71 | } 72 | 73 | // Get raw data pointer from vector. If dev is GPU, this routine will allocate GPU memory and copy.
74 | template <typename T> T* get_data_ptr(std::vector<T>& data, int dev) { 75 | T* data_ptr; 76 | #ifdef ENABLE_GPU 77 | if (dev == TORCHFORT_DEVICE_CPU) { 78 | data_ptr = data.data(); 79 | } else { 80 | CHECK_CUDA(cudaMalloc(&data_ptr, data.size() * sizeof(T))); 81 | CHECK_CUDA(cudaMemcpy(data_ptr, data.data(), data.size() * sizeof(T), cudaMemcpyHostToDevice)); 82 | } 83 | #else 84 | data_ptr = data.data(); 85 | #endif 86 | 87 | return data_ptr; 88 | } 89 | 90 | // Free raw data pointer. If dev is GPU, this routine will free GPU memory. 91 | template <typename T> void free_data_ptr(T* data_ptr, int dev) { 92 | #ifdef ENABLE_GPU 93 | if (dev != TORCHFORT_DEVICE_CPU) { 94 | CHECK_CUDA(cudaFree(data_ptr)); 95 | } 96 | #endif 97 | } 98 | 99 | // Routines to copy vector data to and from GPU. 100 | #ifdef ENABLE_GPU 101 | template <typename T> void copy_to_host_vector(std::vector<T>& data, T* data_ptr) { 102 | CHECK_CUDA(cudaMemcpy(data.data(), data_ptr, data.size() * sizeof(T), cudaMemcpyDeviceToHost)); 103 | } 104 | template <typename T> void copy_from_host_vector(T* data_ptr, std::vector<T>& data) { 105 | CHECK_CUDA(cudaMemcpy(data_ptr, data.data(), data.size() * sizeof(T), cudaMemcpyHostToDevice)); 106 | } 107 | #endif 108 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | ############ 2 | Installation 3 | ############ 4 | 5 | TorchFort can be installed in multiple ways, but we highly recommend building and using a Docker container. 6 | 7 | Docker Installation 8 | ------------------- 9 | 10 | We provide a `Dockerfile `_ which contains all relevant dependencies and builds using the `NVIDIA HPC SDK `_ software libraries and compilers, which is our recommended way to build TorchFort. In order to build TorchFort using Docker, simply clone the repo and call: 11 | 12 | .. code-block:: bash 13 | 14 | docker build -t torchfort:latest -f docker/Dockerfile .
15 | 16 | from the top level directory of the repo. Inside the container, TorchFort will be installed in ``/opt/torchfort``. 17 | 18 | We provide an alternative Dockerfile, `Dockerfile_gnu `_, which can be used to build TorchFort using GNU compilers. Additionally, we provide a Dockerfile, `Dockerfile_gnu_cpuonly `_, which can be used to build TorchFort using GNU compilers without GPU support enabled. 19 | 20 | CMake Installation 21 | ------------------ 22 | 23 | For a native installation, TorchFort provides a `CMakeLists.txt `_ file. Please make sure that the following required packages are installed on your system before installing TorchFort: 24 | 25 | * Requirements for core functionality and examples: 26 | 27 | - CUDA 12.1 or newer 28 | - ``python`` version 3.6 or higher 29 | - ``pybind11`` 30 | - ``yaml-cpp`` from https://github.com/jbeder/yaml-cpp.git 31 | - MPI 32 | - NVIDIA Collective Communication Library (``NCCL``) 33 | - the Python modules specified in `requirements.txt `_ 34 | - GNU or `NVHPC `_ compilers. NVHPC compilers are **required** if CUDA Fortran device array support is desired. 35 | 36 | * Additional requirements for building this documentation: 37 | 38 | - Doxygen 39 | - the Python modules specified in `docs/requirements.txt `_ 40 | 41 | For CPU-only builds, CUDA and NCCL are not required. 42 | 43 | 44 | To build TorchFort, clone the repo, then call the following from the root directory: 45 | 46 | .. code-block:: bash 47 | 48 | mkdir build && cd build 49 | cmake -DCMAKE_INSTALL_PREFIX= \ 50 | -DTORCHFORT_YAML_CPP_ROOT= \ 51 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 52 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 53 | .. 54 | make -j install 55 | 56 | See the top level `CMakeLists.txt `_ file for additional CMake configuration options. 57 | 58 | Build Documentation 59 | ------------------- 60 | 61 | The documentation can be built with the corresponding ``Makefile`` in the ``docs`` directory.
Make sure that the requirements are installed and call: 62 | 63 | .. code-block:: bash 64 | 65 | cd docs && make html 66 | 67 | The docs will be located in ``docs/_build/html`` and can be viewed locally in your web browser. 68 | 69 | Directory Structure 70 | ------------------- 71 | 72 | Independent of how you decide to install TorchFort, the directory structure will be as follows:: 73 | 74 | 75 | |--- bin 76 | |--- examples 77 | |--- cpp 78 | |--- fortran 79 | |--- python 80 | |--- include 81 | |--- lib 82 | 83 | The ``bin`` folder contains the examples written in C++ or Fortran located in the corresponding subdirectories. The ``python`` subfolder contains the Python wrappers for :ref:`wandb_support-ref`. 84 | 85 | The Fortran module ``torchfort.mod`` as well as the C headers can be found inside the ``include`` folder and the dynamic libraries inside the ``lib`` folder. 86 | -------------------------------------------------------------------------------- /src/csrc/include/internal/param_map.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #include <algorithm> 21 | #include <set> 22 | #include <string> 23 | #include <type_traits> 24 | #include <unordered_map> 25 | #include <vector> 26 | 27 | #include "internal/exceptions.h" 28 | #include "internal/utils.h" 29 | 30 | namespace torchfort { 31 | 32 | // Helper function to get type as string 33 | template <typename T> std::string type_string() { 34 | if (std::is_same<T, int>::value) { 35 | return "int"; 36 | } else if (std::is_same<T, float>::value) { 37 | return "float"; 38 | } else if (std::is_same<T, double>::value) { 39 | return "double"; 40 | } else if (std::is_same<T, bool>::value) { 41 | return "bool"; 42 | } 43 | return "UNKNOWN"; 44 | } 45 | 46 | // Conversion functor 47 | template <typename T> struct ParamMapConverter { 48 | T operator()(const std::string& s) { 49 | try { 50 | if constexpr (std::is_same<T, int>::value) { 51 | return std::stoi(sanitize(s)); 52 | } 53 | if constexpr (std::is_same<T, float>::value) { 54 | return std::stof(sanitize(s)); 55 | } 56 | if constexpr (std::is_same<T, double>::value) { 57 | return std::stod(sanitize(s)); 58 | } 59 | if constexpr (std::is_same<T, bool>::value) { 60 | std::string s_ = sanitize(s); 61 | bool val; 62 | if (s_ == "true") { 63 | val = true; 64 | } else if (s_ == "false") { 65 | val = false; 66 | } else { 67 | val = std::stoi(s_); 68 | } 69 | return val; 70 | } 71 | if constexpr (std::is_same<T, std::string>::value) { 72 | return s; 73 | } 74 | } catch (const std::invalid_argument&) { 75 | THROW_INVALID_USAGE("Could not convert provided parameter value " + s + " to required type."); 76 | } 77 | 78 | THROW_INTERNAL_ERROR("Unknown conversion type."); 79 | } 80 | }; 81 | 82 | class ParamMap { 83 | public: 84 | template <typename T> void add_param(const std::string& key, const std::vector<T>& value); 85 | 86 | template <typename T> std::vector<T> get_param(const std::string& key) const; 87 | 88 | template <typename T> std::vector<T> get_param(const std::string& key, const T& defval) const; 89 | 90 | std::set<std::string> keys() const; 91 | 92 | private: 93 | std::unordered_map<std::string, std::vector<std::string>> params; 94 | }; 95 | 96 | template <typename T> void ParamMap::add_param(const std::string& key, const std::vector<T>& value) { 97 |
params[sanitize(key)] = value; 98 | } 99 | 100 | template <typename T> std::vector<T> ParamMap::get_param(const std::string& key) const { 101 | const auto& entry = params.at(sanitize(key)); 102 | std::vector<T> values; 103 | std::transform(entry.begin(), entry.end(), std::back_inserter(values), ParamMapConverter<T>()); 104 | return values; 105 | } 106 | 107 | // parameter with default value 108 | template <typename T> std::vector<T> ParamMap::get_param(const std::string& key, const T& defval) const { 109 | try { 110 | const auto& entry = params.at(sanitize(key)); 111 | std::vector<T> values; 112 | std::transform(entry.begin(), entry.end(), std::back_inserter(values), ParamMapConverter<T>()); 113 | return values; 114 | } catch (const std::out_of_range&) { 115 | return {defval}; 116 | } 117 | } 118 | 119 | } // namespace torchfort 120 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: "Build" 2 | 3 | on: 4 | issue_comment: 5 | types: [ created ] 6 | 7 | jobs: 8 | trigger-comment: 9 | if: github.event.issue.pull_request && contains(github.event.comment.body, '/build_and_test') && (github.actor == 'romerojosh' || github.actor == 'azrael417') 10 | runs-on: ubuntu-latest 11 | permissions: 12 | issues: write 13 | pull-requests: write 14 | steps: 15 | - name: "Post trigger comment" 16 | uses: actions/github-script@v7 17 | with: 18 | script: | 19 | const workflowUrl = `${context.payload.repository.html_url}/actions/runs/${context.runId}`; 20 | github.rest.issues.createComment({ 21 | issue_number: context.issue.number, 22 | owner: context.repo.owner, 23 | repo: context.repo.repo, 24 | body: `🚀 Build workflow triggered!
[View run](${workflowUrl})` 25 | }); 26 | 27 | build: 28 | needs: trigger-comment 29 | if: github.event.issue.pull_request && contains(github.event.comment.body, '/build_and_test') && (github.actor == 'romerojosh' || github.actor == 'azrael417') 30 | strategy: 31 | matrix: 32 | include: 33 | - name: "GPU build (NVHPC SDK 25.7, Ubuntu 22.04)" 34 | dockerfile: "Dockerfile" 35 | free_space: true 36 | run_tests: false 37 | - name: "GPU build (GNU, Ubuntu 22.04)" 38 | dockerfile: "Dockerfile_gnu" 39 | free_space: true 40 | run_tests: false 41 | - name: "CPU build (GNU, Ubuntu 22.04)" 42 | dockerfile: "Dockerfile_gnu_cpuonly" 43 | free_space: false 44 | run_tests: true 45 | 46 | name: ${{ matrix.name }} 47 | runs-on: ubuntu-latest 48 | 49 | steps: 50 | - name: "Free disk space" 51 | if: ${{ matrix.free_space }} 52 | run: | 53 | sudo rm -rf /usr/local/lib/android || true 54 | sudo rm -rf /usr/share/dotnet || true 55 | 56 | - name: "Retrieve PR info" 57 | uses: actions/github-script@v7 58 | id: pr-info 59 | with: 60 | script: | 61 | const pr = await github.rest.pulls.get({ 62 | owner: context.repo.owner, 63 | repo: context.repo.repo, 64 | pull_number: context.issue.number 65 | }); 66 | core.setOutput('sha', pr.data.head.sha); 67 | 68 | - name: "Checkout PR code" 69 | uses: actions/checkout@v4 70 | with: 71 | persist-credentials: false 72 | ref: ${{ steps.pr-info.outputs.sha }} 73 | 74 | - name: "Set up Docker Buildx" 75 | uses: docker/setup-buildx-action@v3 76 | 77 | - name: "Build docker container" 78 | run: | 79 | docker build -t torchfort -f docker/${{ matrix.dockerfile }} . 
80 | 81 | - name: "Run tests" 82 | if: ${{ matrix.run_tests }} 83 | run: | 84 | docker run -v ${PWD}/.github/scripts:/scripts -w /scripts --rm torchfort ./run_ci_tests.sh 85 | 86 | result-comment: 87 | needs: build 88 | if: always() && github.event.issue.pull_request && contains(github.event.comment.body, '/build_and_test') && (github.actor == 'romerojosh' || github.actor == 'azrael417') 89 | runs-on: ubuntu-latest 90 | permissions: 91 | issues: write 92 | pull-requests: write 93 | steps: 94 | - name: "Post result comment" 95 | uses: actions/github-script@v7 96 | with: 97 | script: | 98 | const workflowUrl = `${context.payload.repository.html_url}/actions/runs/${context.runId}`; 99 | const success = '${{ needs.build.result }}' === 'success'; 100 | 101 | const message = success 102 | ? `✅ Build workflow passed! [View run](${workflowUrl})` 103 | : `❌ Build workflow failed! [View run](${workflowUrl})`; 104 | 105 | github.rest.issues.createComment({ 106 | issue_number: context.issue.number, 107 | owner: context.repo.owner, 108 | repo: context.repo.repo, 109 | body: message 110 | }); 111 | -------------------------------------------------------------------------------- /examples/fortran/graph/visualize.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse as ap 17 | import glob 18 | import matplotlib.pyplot as plt 19 | import matplotlib.tri as tri 20 | from matplotlib.animation import FuncAnimation, PillowWriter 21 | import numpy as np 22 | import os 23 | import time 24 | 25 | def main(args): 26 | 27 | global reffiles, predfiles, artists, triangulation 28 | print(f"processing files in {args.input_path}...") 29 | 30 | reffiles = sorted(glob.glob(os.path.join(args.input_path, "reference_*.txt"))) 31 | predfiles = sorted(glob.glob(os.path.join(args.input_path, "prediction_*.txt"))) 32 | 33 | # Read mesh data 34 | nodes = np.loadtxt("nodes.txt", skiprows=1) 35 | triangles = np.loadtxt("connectivity.txt", skiprows=1) 36 | triangulation = tri.Triangulation(nodes[:,0], nodes[:,1], triangles) 37 | 38 | artists = [] 39 | 40 | fig, ((ax1), (ax2)) = plt.subplots(2, 1) 41 | ax1.set_title("Ground Truth") 42 | ax1.set_xlabel(r"$x$") 43 | ax1.set_ylabel(r"$y$") 44 | ax2.set_title("Prediction") 45 | ax2.set_xlabel(r"$x$") 46 | ax2.set_ylabel(r"$y$") 47 | 48 | c = ax1.tricontourf(triangulation, np.loadtxt(reffiles[0]), levels=np.linspace(-0.1, 1.0, 15)) 49 | try: 50 | artists += c.collections 51 | except: 52 | artists.append(c) 53 | c = ax1.triplot(triangulation, linewidth=0.3, color='black') 54 | artists.append(c) 55 | c = ax2.tricontourf(triangulation, np.loadtxt(predfiles[0]), levels=np.linspace(-0.1, 1.0, 15)) 56 | try: 57 | artists += c.collections 58 | except: 59 | artists.append(c) 60 | c = ax2.triplot(triangulation, linewidth=0.3, color='black') 61 | artists.append(c) 62 | 63 | fig.tight_layout() 64 | 65 | def animate(i): 66 | global reffiles, predfiles, artists, triangulation 67 | for c in artists: 68 | try: 69 | c.remove() 70 | except: 71 | pass 72 | artists.clear() 73 | 74 | c = ax1.tricontourf(triangulation, np.loadtxt(reffiles[i]), levels=np.linspace(-0.1, 1.0, 15)) 75 | try: 76 | 
artists += c.collections 77 | except: 78 | artists.append(c) 79 | c = ax1.triplot(triangulation, linewidth=0.3, color='black') 80 | artists.append(c) 81 | c = ax2.tricontourf(triangulation, np.loadtxt(predfiles[i]), levels=np.linspace(-0.1, 1.0, 15)) 82 | try: 83 | artists += c.collections 84 | except: 85 | artists.append(c) 86 | c = ax2.triplot(triangulation, linewidth=0.3, color='black') 87 | artists.append(c) 88 | 89 | 90 | 91 | ani = FuncAnimation(fig, animate, frames=len(reffiles), repeat=False, interval=1) 92 | 93 | os.makedirs(args.output_path, exist_ok=True) 94 | 95 | def log(i, n): 96 | print(f"processed {i+1} of {n} frames..." ) 97 | ani.save(os.path.join(args.output_path, "validation_results.gif"), writer=PillowWriter(fps=5), progress_callback=lambda i, n: log(i,n)) 98 | print(f"video written to {os.path.join(args.output_path, 'validation_results.gif')}...") 99 | 100 | if __name__ == "__main__": 101 | parser = ap.ArgumentParser() 102 | parser.add_argument("--input_path", type=str, help="Directory containing result text files", required=True) 103 | parser.add_argument("--output_path", type=str, help="Directory to store the generated videos", required=True) 104 | args = parser.parse_args() 105 | 106 | main(args) 107 | 108 | -------------------------------------------------------------------------------- /examples/fortran/graph/generate_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | class MessagePassing(torch.nn.Module): 19 | def __init__(self, hidden_dim): 20 | super(MessagePassing, self).__init__() 21 | 22 | self.mlp_edge = torch.nn.Sequential(torch.nn.Linear(3*hidden_dim, hidden_dim), 23 | torch.nn.ReLU(), 24 | torch.nn.Linear(hidden_dim, hidden_dim), 25 | torch.nn.LayerNorm(hidden_dim)) 26 | self.mlp_node = torch.nn.Sequential(torch.nn.Linear(2*hidden_dim, hidden_dim), 27 | torch.nn.ReLU(), 28 | torch.nn.Linear(hidden_dim, hidden_dim), 29 | torch.nn.LayerNorm(hidden_dim)) 30 | 31 | def forward(self, edge_idx, node_feats, edge_feats): 32 | senders = edge_idx[:,0] 33 | receivers = edge_idx[:,1] 34 | 35 | edge_update = torch.cat([node_feats[senders], node_feats[receivers], edge_feats], dim=-1) 36 | edge_update = self.mlp_edge(edge_update) 37 | 38 | accumulate_edges = torch.zeros([node_feats.shape[0], edge_feats.shape[1]], dtype=edge_feats.dtype, device=edge_feats.device) 39 | receivers = receivers.unsqueeze(-1).expand(-1, edge_feats.shape[1]) 40 | accumulate_edges = torch.scatter_add(accumulate_edges, src=edge_feats, index=receivers, dim=0) 41 | node_update = torch.cat([node_feats, accumulate_edges], dim=-1) 42 | node_update = self.mlp_node(node_update) 43 | 44 | edge_feats = edge_feats + edge_update 45 | node_feats = node_feats + node_update 46 | 47 | return node_feats, edge_feats 48 | 49 | 50 | class Net(torch.nn.Module): 51 | def __init__(self, in_node_features, in_edge_features, hidden_dim, n_message_passing_steps): 52 | super(Net, self).__init__() 53 | 
self.encoder_node = torch.nn.Sequential(torch.nn.Linear(in_node_features, hidden_dim), 54 | torch.nn.ReLU(), 55 | torch.nn.Linear(hidden_dim, hidden_dim), 56 | torch.nn.LayerNorm(hidden_dim)) 57 | self.encoder_edge = torch.nn.Sequential(torch.nn.Linear(in_edge_features, hidden_dim), 58 | torch.nn.ReLU(), 59 | torch.nn.Linear(hidden_dim, hidden_dim), 60 | torch.nn.LayerNorm(hidden_dim)) 61 | 62 | self.mp_layers = torch.nn.ModuleList() 63 | for _ in range(n_message_passing_steps): 64 | self.mp_layers.append(MessagePassing(hidden_dim)) 65 | 66 | self.decoder = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), 67 | torch.nn.ReLU(), 68 | torch.nn.Linear(hidden_dim, in_node_features)) 69 | 70 | def forward(self, edge_idx, node_feats, edge_feats): 71 | # Encode node and edge features 72 | node_feats = self.encoder_node(node_feats) 73 | edge_feats = self.encoder_edge(edge_feats) 74 | 75 | # Message passing 76 | for mp in self.mp_layers: 77 | node_feats, edge_feats = mp(edge_idx, node_feats, edge_feats) 78 | 79 | # Decode node features 80 | node_feats = self.decoder(node_feats) 81 | 82 | return node_feats 83 | 84 | 85 | def main(): 86 | # Create model 87 | model = Net(1, 3, 128, 8) 88 | print("graph model:", model) 89 | 90 | # Move model to GPU (if available), JIT, and save 91 | if torch.cuda.is_available(): 92 | model.to("cuda") 93 | else: 94 | print("PyTorch does not have CUDA support. Saving model on CPU.") 95 | model_jit = torch.jit.script(model) 96 | model_jit.save("model_torchscript.pt") 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /src/csrc/model_wrapper.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <filesystem> 19 | #include <memory> 20 | #include <string> 21 | #include <vector> 22 | 23 | #ifdef ENABLE_GPU 24 | #include <cuda_runtime.h> 25 | #endif 26 | #include <torch/script.h> 27 | #include <torch/torch.h> 28 | 29 | #include "internal/base_model.h" 30 | #include "internal/defines.h" 31 | #include "internal/model_wrapper.h" 32 | #include "internal/utils.h" 33 | 34 | namespace torchfort { 35 | 36 | ModelWrapper::ModelWrapper(const std::shared_ptr<BaseModel>& model) : model{model} {} 37 | 38 | ModelWrapper::ModelWrapper(const std::shared_ptr<torch::jit::Module>& model_jit) : model_jit{model_jit}, jit{true} {} 39 | 40 | ModelWrapper::ModelWrapper(const std::string& jit_model_fname) : jit{true} { 41 | 42 | if (!std::filesystem::exists(jit_model_fname)) { 43 | THROW_INVALID_USAGE(jit_model_fname + " does not exist."); 44 | } 45 | 46 | model_jit = std::shared_ptr<torch::jit::Module>(new torch::jit::Module); 47 | *model_jit = torch::jit::load(jit_model_fname, device_); 48 | } 49 | 50 | std::vector<torch::Tensor> ModelWrapper::parameters() const { 51 | if (jit) { 52 | std::vector<torch::Tensor> parameters; 53 | for (const auto& params : model_jit->parameters()) { 54 | parameters.push_back(params); 55 | } 56 | return parameters; 57 | } 58 | 59 | return model->parameters(); 60 | } 61 | 62 | torch::OrderedDict<std::string, torch::Tensor> ModelWrapper::named_parameters() const { 63 | if (jit) { 64 | torch::OrderedDict<std::string, torch::Tensor> parameters; 65 | for (const auto& params : model_jit->named_parameters()) { 66 |
parameters.insert(params.name, params.value); 67 | } 68 | return parameters; 69 | } 70 | 71 | return model->named_parameters(); 72 | } 73 | 74 | void ModelWrapper::to(torch::Device device, bool non_blocking) { 75 | if (jit) { 76 | model_jit->to(device, non_blocking); 77 | } else { 78 | model->to(device, non_blocking); 79 | } 80 | 81 | this->device_ = device; 82 | } 83 | 84 | void ModelWrapper::train() { 85 | if (jit) { 86 | model_jit->train(); 87 | } else { 88 | model->train(); 89 | } 90 | } 91 | 92 | void ModelWrapper::eval() { 93 | if (jit) { 94 | model_jit->eval(); 95 | } else { 96 | model->eval(); 97 | } 98 | } 99 | 100 | std::vector<torch::Tensor> ModelWrapper::forward(const std::vector<torch::Tensor>& inputs) const { 101 | if (jit) { 102 | std::vector<torch::jit::IValue> inputs_jit; 103 | inputs_jit.assign(inputs.begin(), inputs.end()); 104 | auto result = model_jit->forward(inputs_jit); 105 | if (result.isTensor()) { 106 | return std::vector<torch::Tensor>{result.toTensor()}; 107 | } else if (result.isTuple()) { 108 | std::vector<torch::Tensor> tensors; 109 | for (const auto& x : result.toTuple()->elements()) { 110 | tensors.push_back(x.toTensor()); 111 | } 112 | return tensors; 113 | } else { 114 | THROW_INVALID_USAGE("Unexpected model output type encountered."); 115 | } 116 | } 117 | return model->forward(inputs); 118 | } 119 | 120 | void ModelWrapper::save(const std::string& fname) const { 121 | if (jit) { 122 | model_jit->save(fname); 123 | } else { 124 | torch::save(model, fname); 125 | } 126 | } 127 | 128 | void ModelWrapper::load(const std::string& fname) { 129 | if (!std::filesystem::exists(fname)) { 130 | THROW_INVALID_USAGE(fname + " does not exist."); 131 | } 132 | if (jit) { 133 | model_jit.reset(); 134 | model_jit = std::shared_ptr<torch::jit::Module>(new torch::jit::Module); 135 | 136 | *model_jit = torch::jit::load(fname, device_); 137 | } else { 138 | torch::load(model, fname, device_); 139 | } 140 | } 141 | 142 | torch::Device ModelWrapper::device() const { return device_; } 143 | 144 | } // namespace torchfort 145 |
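The `polyak_update`, `scale_action`, and `unscale_action` templates declared in `internal/rl/utils.h` above boil down to a few lines of arithmetic. As a quick sanity check of that math, here is a minimal pure-Python sketch using plain lists and floats instead of torch tensors (all names here are illustrative, not part of the TorchFort API):

```python
def polyak_update(target, source, rho):
    # In-place elementwise average: target <- rho * target + (1 - rho) * source
    assert len(target) == len(source)
    for i in range(len(target)):
        target[i] = rho * target[i] + (1.0 - rho) * source[i]

def scale_action(a, a_low, a_high):
    # Map an action from [a_low, a_high] to [-1, 1]
    return 2.0 * (a - a_low) / (a_high - a_low) - 1.0

def unscale_action(s, a_low, a_high):
    # Map a scaled action from [-1, 1] back to [a_low, a_high]
    return 0.5 * (a_high - a_low) * (s + 1.0) + a_low
```

Note that `scale_action` followed by `unscale_action` with the same bounds is an exact round trip, which mirrors how the C++ templates move RL actions into and out of the normalized [-1, 1] range.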
-------------------------------------------------------------------------------- /examples/fortran/simulation/simulation.f90: -------------------------------------------------------------------------------- 1 | ! SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | ! SPDX-License-Identifier: Apache-2.0 3 | ! 4 | ! Licensed under the Apache License, Version 2.0 (the "License"); 5 | ! you may not use this file except in compliance with the License. 6 | ! You may obtain a copy of the License at 7 | ! 8 | ! http://www.apache.org/licenses/LICENSE-2.0 9 | ! 10 | ! Unless required by applicable law or agreed to in writing, software 11 | ! distributed under the License is distributed on an "AS IS" BASIS, 12 | ! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ! See the License for the specific language governing permissions and 14 | ! limitations under the License. 15 | 16 | module simulation 17 | use, intrinsic :: iso_fortran_env, only: real32, real64 18 | implicit none 19 | 20 | real(real32), private, parameter :: PI = 4 * atan(1.0) 21 | 22 | integer, private :: n, js, je 23 | integer, private :: simulation_device 24 | real(real32), private :: dx, t_total, dt, a 25 | real(real32), private :: ax, ay 26 | real(real32), private :: k = 1 27 | 28 | contains 29 | subroutine init_simulation(n_in, dt_in, a_in, t_start, rank, nranks, simulation_device_in) 30 | implicit none 31 | integer, intent(in) :: n_in, simulation_device_in 32 | real(real32), intent(in) :: dt_in, t_start 33 | real(real32), intent(in) :: a_in(2) 34 | integer :: rank, nranks 35 | n = n_in 36 | dt = dt_in 37 | ax = a_in(1) 38 | ay = a_in(2) 39 | simulation_device = simulation_device_in 40 | 41 | dx = 2.0/(n) 42 | t_total = t_start 43 | 44 | js = 1 45 | je = n 46 | 47 | ! 
if running in parallel, split domain into slabs 48 | if (nranks > 1) then 49 | js = rank * n/nranks + 1 50 | je = (rank + 1) * n/nranks 51 | endif 52 | 53 | end subroutine init_simulation 54 | 55 | subroutine f_u(u, t) 56 | implicit none 57 | real(real32), intent(out) :: u(n, js:je) 58 | real(real32), intent(in) :: t 59 | integer :: i, j 60 | real(real32) :: x, y 61 | 62 | !$acc parallel loop collapse(2) default(present) async if(simulation_device >= 0) 63 | do j = js, je 64 | do i = 1, n 65 | x = -1.0 + dx * (i-1) - mod(ax*t, 2.0) 66 | y = -1.0 + dx * (j-1) - mod(ay*t, 2.0) 67 | if (x < -1.0) x = x + 2.0 68 | if (y < -1.0) y = y + 2.0 69 | u(i, j) = sin(k*PI*x) * sin(k*PI*y) 70 | end do 71 | end do 72 | 73 | end subroutine f_u 74 | 75 | subroutine f_u_div(u_div, t) 76 | implicit none 77 | real(real32), intent(out) :: u_div(n, js:je) 78 | real(real32), intent(in) :: t 79 | integer :: i, j 80 | real(real32) :: x, y 81 | 82 | !$acc parallel loop collapse(2) default(present) async if(simulation_device >= 0) 83 | do j = js, je 84 | do i = 1, n 85 | x = -1.0 + dx * (i-1) - mod(ax*t, 2.0) 86 | y = -1.0 + dx * (j-1) - mod(ay*t, 2.0) 87 | if (x < -1.0) x = x + 2.0 88 | if (y < -1.0) y = y + 2.0 89 | u_div(i, j) = k*PI * cos(k*PI*x) * sin(k*PI*y) + & 90 | k*PI * sin(k*PI*x) * cos(k*PI*y) 91 | end do 92 | end do 93 | 94 | end subroutine f_u_div 95 | 96 | subroutine run_simulation_step(u, u_div) 97 | implicit none 98 | real(real32), intent(out) :: u(n, js:je) 99 | real(real32), intent(out) :: u_div(n, js:je) 100 | 101 | call f_u(u, t_total) 102 | call f_u_div(u_div, t_total) 103 | t_total = t_total + dt 104 | 105 | end subroutine run_simulation_step 106 | 107 | subroutine write_sample(sample, fname) 108 | character(len=*) :: fname 109 | real(real32), intent(in) :: sample(n, n) 110 | integer :: unit, i, j, err 111 | 112 | !$acc update host(sample) if(simulation_device >= 0) 113 | 114 | open(newunit=unit, file=fname, status='replace', action='write', iostat=err) 115 | if (err 
/= 0) then 116 | write(*,*) 'Error opening file: ', fname 117 | return 118 | endif 119 | 120 | do j = 1, n 121 | do i = 1, n-1 122 | write(unit, '(ES14.6E2)', advance='no') sample(i, j) 123 | write(unit, '(A)', advance='no') ' ' 124 | end do 125 | write(unit, '(ES14.6E2)') sample(n, j) 126 | end do 127 | 128 | close(unit) 129 | end subroutine write_sample 130 | 131 | end module 132 | -------------------------------------------------------------------------------- /src/csrc/rl/utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #include <cassert> 19 | #include <vector> 20 | 21 | #include "internal/rl/utils.h" 22 | 23 | namespace torchfort { 24 | 25 | namespace rl { 26 | 27 | bool validate_devices(int device1, int device2) { 28 | if ((device1 != device2) && (device1 * device2 > 0)) { 29 | return false; 30 | } else { 31 | return true; 32 | } 33 | } 34 | 35 | std::vector<double> get_current_lrs(std::shared_ptr<torch::optim::Optimizer> optimizer) { 36 | std::vector<double> learning_rates(optimizer->param_groups().size()); 37 | if (learning_rates.size() > 0) { 38 | for (const auto i : c10::irange(optimizer->param_groups().size())) { 39 | learning_rates[i] = optimizer->param_groups()[i].options().get_lr(); 40 | } 41 | } 42 | return learning_rates; 43 | } 44 | 45 | void init_parameters(std::shared_ptr<Model> model) { 46 | 47 | // no grad guard 48 | torch::NoGradGuard no_grad; 49 | 50 | for (const auto& p : model->named_parameters()) { 51 | std::string key = p.key(); 52 | auto val = p.value(); 53 | auto dim = val.dim(); 54 | 55 | if (key.find("weight") != std::string::npos) { 56 | // likely a conv or linear layer: use orthogonal init 57 | if (dim >= 2) { 58 | torch::nn::init::orthogonal_(val, sqrt(2.)); 59 | } else { 60 | // likely a normalization layer: constant init 61 | torch::nn::init::constant_(val, 1.); 62 | } 63 | } else if (key.find("bias") != std::string::npos) { 64 | torch::nn::init::constant_(val, 0.); 65 | } 66 | } 67 | return; 68 | } 69 | 70 | void copy_parameters(std::shared_ptr<Model> target, std::shared_ptr<Model> source) { 71 | 72 | // create handles 73 | auto ptar = target->parameters(); 74 | auto psrc = source->parameters(); 75 | 76 | // sanity checks 77 | assert(ptar.size() == psrc.size()); 78 | 79 | // important: apply no grad 80 | torch::NoGradGuard no_grad; 81 | 82 | // copy loop 83 | for (size_t i = 0; i < ptar.size(); ++i) { 84 | auto& t = ptar[i]; 85 | auto& s = psrc[i]; 86 | t.copy_(s); 87 | } 88 | 89 | return; 90 | } 91 | 92 | void set_grad_state(std::shared_ptr<Model> model, const bool requires_grad) { 93 | // create handle for parameters 94 | auto pars = model->parameters(); 95 | 96 | // set grad state 97 | for (const auto& par : pars) { 98 | par.requires_grad_(requires_grad); 99 | } 100 | return; 101 | } 102 | 103 | torch::Tensor explained_variance(torch::Tensor q_pred, torch::Tensor q_true, std::shared_ptr<Comm> comm) { 104 | // Computes the fraction of variance that ypred explains about y. 105 | // Returns 1 - Var[y-ypred] / Var[y] 106 | // see https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/utils.py 107 | // the communicator is required for distributed training 108 | torch::Tensor result; 109 | if (!comm) { 110 | torch::Tensor var_q_true = torch::var(q_true); 111 | result = 1. - torch::var(q_true - q_pred) / var_q_true; 112 | } else { 113 | // compute variance of q_true 114 | std::vector<torch::Tensor> mean_vec = {torch::mean(q_true)}; 115 | comm->allreduce(mean_vec, true); 116 | torch::Tensor var_q_true = torch::mean(torch::square(q_true - mean_vec[0])); 117 | mean_vec = {var_q_true}; 118 | comm->allreduce(mean_vec, true); 119 | var_q_true = mean_vec[0]; 120 | // compute variance of difference: 121 | mean_vec = {torch::mean(q_true - q_pred)}; 122 | comm->allreduce(mean_vec, true); 123 | torch::Tensor var_q_diff = torch::mean(torch::square(q_true - q_pred - mean_vec[0])); 124 | mean_vec = {var_q_diff}; 125 | comm->allreduce(mean_vec, true); 126 | var_q_diff = mean_vec[0]; 127 | result = 1. - var_q_diff / var_q_true; 128 | } 129 | return result; 130 | } 131 | 132 | } // namespace rl 133 | 134 | } // namespace torchfort 135 | -------------------------------------------------------------------------------- /docs/extras.rst: -------------------------------------------------------------------------------- 1 | ###### 2 | Extras 3 | ###### 4 | 5 | .. _wandb_support-ref: 6 | 7 | Weights and Biases Support 8 | ========================== 9 | 10 | `Weights and Biases `_ (wandb) is a popular tool for monitoring machine learning training workflows.
Wandb supports plotting and comparing loss curves and other relevant deep learning metrics, system utilization (including GPU, CPU and memory utilization) and other advanced logging functionality such as uploading images and/or videos, network weights and data artifacts. 11 | 12 | Since wandb currently does not offer a C++ interface and thus cannot be called from Fortran or C/C++ directly, we've implemented a wandb daemon written in Python instead. This daemon runs as a background process and waits for changes in a log file generated by the TorchFort application. In order to enable wandb support for your TorchFort application, the following steps have to be performed. 13 | 14 | Add Custom Metrics Reporting to Application 15 | ------------------------------------------- 16 | 17 | The TorchFort training routines already provide logging of the training step, loss values, and learning rate, which are captured by the wandb daemon. Additional custom metrics can be added manually by the user. For this purpose, the user may add calls to ``torchfort_wandb_log`` or ``torchfort_rl_off_policy_wandb_log`` for supervised and reinforcement learning applications respectively (see :ref:`supervised_learning-ref` and :ref:`reinforcement_learning-ref` for details about why we provide different implementations for these two cases). For more information, see :ref:`torchfort_api_c-ref` for C/C++ and :ref:`torchfort_api_f-ref` for Fortran applications. 18 | 19 | Set up Environment 20 | ------------------ 21 | 22 | You need to specify your `wandb api token `_ via the environment variable ``WANDB_API_KEY`` (see the `wandb documentation on available environment variables `_ for details). 23 | Furthermore, the daemon needs to know where the wandb logging data from the TorchFort application will be stored. This can be done by defining the environment variable ``TORCHFORT_LOGDIR``.
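As a concrete sketch, the environment setup could look as follows; the key and directory paths below are placeholders for illustration, not values prescribed by TorchFort:

```shell
# Placeholder values -- substitute your own wandb API key and paths.
export WANDB_API_KEY=0123456789abcdef             # your personal wandb token
export TORCHFORT_LOGDIR=/tmp/torchfort_logs       # where torchfort.log will be written
export WANDB_LOGGING_DIR=/tmp/torchfort_wandb     # per-run wandb output directory

# create the directories up front so both the daemon and the app can use them
mkdir -p "${TORCHFORT_LOGDIR}" "${WANDB_LOGGING_DIR}"
```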
Lastly, a user-defined wandb logging directory ``WANDB_LOGGING_DIR`` can be created to gather all wandb information as well as the config file in a place specific to the run. 24 | 25 | .. note:: 26 | The logging directory ``TORCHFORT_LOGDIR`` needs to be specified before the daemon and TorchFort application are launched. 27 | 28 | Start Background Watcher Process 29 | -------------------------------- 30 | 31 | Now, the wandb daemon process needs to be started. Assuming TorchFort was installed in ``TORCHFORT_INSTALL_DIR``, we can run 32 | 33 | .. code-block:: bash 34 | 35 | python ${TORCHFORT_INSTALL_DIR}/bin/python/wandb_helper.py \ 36 | --wandb_dir=${WANDB_LOGGING_DIR} \ 37 | --wandb_group= \ 38 | --wandb_project= \ 39 | --wandb_entity= \ 40 | --run_tag= \ 41 | --timeout=2400 & 42 | 43 | The wandb group, project, and entity name correspond to the wandb project you are logging to. Those correspond to the respective arguments of ``wandb.init`` documented `here `_. Note that the group does not need to exist and will be created during initialization. The run tag can be any alphanumeric string and can be used to identify the specific run on the wandb dashboard. Lastly, the timeout (measured in seconds) determines how long the background process will wait for changes to appear in ``${TORCHFORT_LOGDIR}/torchfort.log`` before wrapping up the monitoring. 44 | 45 | .. note:: 46 | Do not forget to launch the daemon into the background. 47 | 48 | Start Your TorchFort Application 49 | -------------------------------- 50 | 51 | In the configuration file for your TorchFort application, make sure to enable wandb logging in the ``general`` section by adding or modifying the line ``enable_wandb_hook: 1``. Lastly, start the TorchFort application as usual, e.g.: 52 | 53 | ..
code-block:: bash 54 | 55 | ./my_torchfort_app arg1 arg2 arg3 56 | 57 | The daemon process will pick up the log lines from ``${TORCHFORT_LOGDIR}/torchfort.log`` and display the data on the corresponding job dashboard. 58 | 59 | .. note:: 60 | The daemon can finalize the monitoring while the TorchFort application is still running if the timeout is not set sufficiently large, especially for long-running applications with very sparse logging. 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /examples/fortran/simulation/visualize.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | import argparse as ap 17 | import glob 18 | import matplotlib.pyplot as plt 19 | from matplotlib.animation import FuncAnimation, PillowWriter 20 | import numpy as np 21 | import os 22 | import time 23 | 24 | def main(args): 25 | 26 | global infiles, labelfiles, outfiles, artists 27 | print(f"processing files in {args.input_path}...") 28 | 29 | infiles = sorted(glob.glob(os.path.join(args.input_path, "input_0*"))) 30 | labelfiles = sorted(glob.glob(os.path.join(args.input_path, "label_0*"))) 31 | outfiles = sorted(glob.glob(os.path.join(args.input_path, "output_0*"))) 32 | 33 | artists = [] 34 | 35 | fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) 36 | ax1.set_title(r"$u$") 37 | ax1.set_xlabel(r"$x$") 38 | ax1.set_ylabel(r"$y$") 39 | ax2.set_title(r"$\nabla \cdot \mathbf{a}u$ (true)") 40 | ax2.set_xlabel(r"$x$") 41 | ax2.set_ylabel(r"$y$") 42 | ax3.set_title(r"$\nabla \cdot \mathbf{a}u$ (prediction)") 43 | ax3.set_xlabel(r"$x$") 44 | ax3.set_ylabel(r"$y$") 45 | ax4.set_title(r"1D sample along dotted line") 46 | ax4.set_xlabel(r"$x$") 47 | 48 | idata = np.loadtxt(infiles[0]) 49 | ldata = np.loadtxt(labelfiles[0]) 50 | odata = np.loadtxt(outfiles[0]) 51 | 52 | c = ax1.contourf(idata) 53 | try: 54 | artists += c.collections 55 | except: 56 | artists.append(c) 57 | c = ax1.hlines(idata.shape[0]//2 + 1, 0, idata.shape[1]-1, colors="black", linestyles="dashed") 58 | artists.append(c) 59 | c = ax2.contourf(ldata) 60 | try: 61 | artists += c.collections 62 | except: 63 | artists.append(c) 64 | c = ax3.contourf(odata) 65 | try: 66 | artists += c.collections 67 | except: 68 | artists.append(c) 69 | c, = ax4.plot(idata[idata.shape[0]//2 + 1,:], 'k') 70 | artists.append(c) 71 | c, = ax4.plot(ldata[idata.shape[0]//2 + 1,:], 'b') 72 | artists.append(c) 73 | c, = ax4.plot(odata[idata.shape[0]//2 + 1,:], 'g.') 74 | artists.append(c) 75 | 76 | fig.tight_layout() 77 | 78 | def animate(i): 79 | global infiles, labelfiles, outfiles, artists 80 | for c in artists: 81 | 
c.remove() 82 | artists.clear() 83 | 84 | idata = np.loadtxt(infiles[i]) 85 | ldata = np.loadtxt(labelfiles[i]) 86 | odata = np.loadtxt(outfiles[i]) 87 | c = ax1.contourf(idata) 88 | try: 89 | artists += c.collections 90 | except: 91 | artists.append(c) 92 | c = ax1.hlines(idata.shape[0]//2 + 1, 0, idata.shape[1]-1, colors="black", linestyles="dashed") 93 | artists.append(c) 94 | c = ax2.contourf(ldata) 95 | try: 96 | artists += c.collections 97 | except: 98 | artists.append(c) 99 | c = ax3.contourf(odata) 100 | try: 101 | artists += c.collections 102 | except: 103 | artists.append(c) 104 | c, = ax4.plot(idata[idata.shape[0]//2 + 1,:], 'k') 105 | artists.append(c) 106 | c, = ax4.plot(ldata[idata.shape[0]//2 + 1,:], 'b') 107 | artists.append(c) 108 | c, = ax4.plot(odata[idata.shape[0]//2 + 1,:], 'g.') 109 | artists.append(c) 110 | 111 | ani = FuncAnimation(fig, animate, frames=len(infiles), repeat=False, interval=1) 112 | 113 | os.makedirs(args.output_path, exist_ok=True) 114 | 115 | def log(i, n): 116 | print(f"processed {i+1} of {n} frames..." ) 117 | ani.save(os.path.join(args.output_path, "validation_results.gif"), writer=PillowWriter(fps=5), progress_callback=lambda i, n: log(i,n)) 118 | print(f"video written to {os.path.join(args.output_path, 'validation_results.gif')}...") 119 | 120 | if __name__ == "__main__": 121 | parser = ap.ArgumentParser() 122 | parser.add_argument("--input_path", type=str, help="Directory containing validation text files", required=True) 123 | parser.add_argument("--output_path", type=str, help="Directory to store the generated videos", required=True) 124 | args = parser.parse_args() 125 | 126 | main(args) 127 | 128 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/env.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | #include <algorithm> 19 | #include <cmath> 20 | #include <tuple> 21 | 22 | #include "env.h" 23 | 24 | static float cost_continuous_fn(float r1, float r2, float e1, float e2, float x, float x_dot, float theta, 25 | float theta_dot) { 26 | float sign_r2 = (std::signbit(r2) ? -1.f : 1.f); 27 | return (sign_r2 * 100.f * r2 * r2 - 4.f * x * x) / 1000.f; 28 | } 29 | 30 | CartPoleEnv::CartPoleEnv() : uniform_dist_(-1.f, 1.f) { 31 | // important parameters 32 | terminated_ = true; 33 | gravity_ = 9.81f; 34 | masscart_ = 1.f; 35 | masspole_ = 0.1f; 36 | total_mass_ = masspole_ + masscart_; 37 | length_ = 0.5f; // actually half the pole's length 38 | polemass_length_ = masspole_ * length_; // CMS of pole 39 | force_mag_ = 20.f; 40 | dt_ = 0.02f; // seconds between state updates 41 | kinematics_integrator_ = EXPLICIT_EULER; 42 | penalty_ = 20.f; 43 | 44 | // threshold parameters 45 | theta_threshold_radians_ = 20.f * 2.f * M_PI / 360.f; 46 | x_threshold_ = 5.f; 47 | } 48 | 49 | std::pair<StateVector, StateVector> CartPoleEnv::getStateBounds() { 50 | return std::make_pair(StateVector{-2.f * x_threshold_, -1e7f, -2.f * theta_threshold_radians_, -1e7f}, 51 | StateVector{2.f * x_threshold_, 1e7f, 2.f * theta_threshold_radians_, 1e7f}); 52 | } 53 | 54 | StateVector CartPoleEnv::reset() { 55 | 56 | // reset state vector 57 | auto gen = [&udist = this->uniform_dist_, &rng = this->rng_]() { return udist(rng); }; 58 | 59 | std::generate(state_.begin(), state_.end(), gen); 60 | 61 | // now compress/expand the distribution to make sure we are inside the correct bounds 62 | state_[0] *= 0.3 * x_threshold_; 63 | state_[1] *= 0.2; 64 | state_[2] *= 0.3 * theta_threshold_radians_; 65 | state_[3] *= 0.2; 66 | 67 | // env info 68 | steps_beyond_terminated_ = -1; 69 | terminated_ = false; 70 | 71 | return state_; 72 | } 73 | 74 | std::tuple<StateVector, float, bool> CartPoleEnv::step(float action) { 75 | 76 | // extract state vectors and set force term based on action 77 | float x = state_[0]; 78 | float x_dot = state_[1]; 79 | float theta = state_[2]; 80 | float theta_dot = state_[3]; 81 | 82 | // action is in [-1, 1], so we multiply with force mag 83 | float force = action * force_mag_; 84 | 85 | // derived parameters 86 | float ctheta = std::cos(theta); 87 | float stheta = std::sin(theta); 88 | 89 | float tmp = (force + polemass_length_ * theta_dot * theta_dot * stheta) / total_mass_; 90 | float thetaacc = 91 | (gravity_ * stheta - ctheta * tmp) / (length_ * (4.0 / 3.0 - masspole_ * ctheta * ctheta / total_mass_)); 92 | float xacc = tmp - polemass_length_ * thetaacc * ctheta / total_mass_; 93 | 94 | switch (kinematics_integrator_) { 95 | case EXPLICIT_EULER: 96 | x = x + dt_ * x_dot; 97 | x_dot = x_dot + dt_ * xacc; 98 | theta = theta + dt_ * theta_dot; 99 | theta_dot = theta_dot + dt_ * thetaacc; 100 | break; 101 | case SEMI_IMPLICIT_EULER: 102 | x_dot = x_dot + dt_ * xacc; 103 | x = x + dt_ * x_dot; 104 | theta_dot = theta_dot + dt_ * thetaacc; 105 | theta = theta + dt_ * theta_dot; 106 | break; 107 | } 108 | 109 | // update state 110 | state_ = {x, x_dot, theta, theta_dot}; 111 | 112 | // decide if sim was terminated 113 | bool terminated = false; 114 | if ((x < -x_threshold_) || (x > x_threshold_) || (theta < -theta_threshold_radians_) || 115 | (theta > theta_threshold_radians_)) { 116 | terminated = true; 117 | } 118 | 119 | // compute reward 120 | float r1 = (x_threshold_ - std::abs(x)) / x_threshold_;
121 | float r2 = (theta_threshold_radians_ / 4. - std::abs(theta)) / (theta_threshold_radians_ / 4.); 122 | float e1 = (std::abs(x)) / x_threshold_; 123 | float e2 = (std::abs(theta)) / theta_threshold_radians_; 124 | float reward = cost_continuous_fn(r1, r2, e1, e2, x, x_dot, theta, theta_dot); 125 | 126 | // add reward penalty for failure 127 | if (terminated) { 128 | reward -= penalty_; 129 | } 130 | 131 | return std::make_tuple(state_, reward, terminated); 132 | } 133 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/README.md: -------------------------------------------------------------------------------- 1 | # Cartpole Balancing with Deep Reinforcement Learning 2 | 3 | ## Introduction 4 | 5 | ![Cartpole Balancing Illustration](media/cartpole.gif) 6 | 7 | The cartpole problem is a classical deep reinforcement learning (DRL) problem. It is an inverted pendulum with a center of gravity above its pivot point. 8 | Although the system is unstable, it can be balanced by a movable frictionless cart by moving its pivot point under the center of mass of the system. 9 | The goal is to train an agent which balances the pole using this technique for as long as possible. 10 | 11 | ## Description 12 | A more detailed description of the cartpole problem can be found in the [Gym Documentation](https://www.gymlibrary.dev/environments/classic_control/cart_pole/). 13 | Our implementation differs slightly from this implementation as we allow for a continuous action space instead of a discrete one. This means, instead of allowing the cart to be moved only to the left or right with a constant force, we allow the cart to be moved with a force value $a \in [-1, 1]$, where the boundaries correspond to a left and right move of the original problem respectively. Values in between represent left moves (for a<0) or right moves (a>0) with force magnitudes $10\cdot|a|$ respectively.
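For illustration, one explicit-Euler step of the dynamics implemented in `env.cpp` can be sketched in Python. This is a minimal re-implementation for clarity (not part of the repository), with the default parameter values copied from the `CartPoleEnv` constructor:

```python
import math

def cartpole_step(state, action, force_mag=20.0, dt=0.02,
                  gravity=9.81, masscart=1.0, masspole=0.1, length=0.5):
    """One explicit-Euler update of the cartpole dynamics (sketch of env.cpp)."""
    x, x_dot, theta, theta_dot = state
    total_mass = masscart + masspole
    polemass_length = masspole * length
    force = action * force_mag  # action in [-1, 1] scales the maximum force
    ct, st = math.cos(theta), math.sin(theta)
    tmp = (force + polemass_length * theta_dot**2 * st) / total_mass
    thetaacc = (gravity * st - ct * tmp) / (
        length * (4.0 / 3.0 - masspole * ct * ct / total_mass))
    xacc = tmp - polemass_length * thetaacc * ct / total_mass
    # explicit Euler: positions advance with the old velocities
    return (x + dt * x_dot, x_dot + dt * xacc,
            theta + dt * theta_dot, theta_dot + dt * thetaacc)

# with zero action and a small positive pole angle, the pole starts to fall
state = cartpole_step((0.0, 0.0, 0.05, 0.0), action=0.0)
```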
Note that the problem formulation with a discrete action space is impossible to solve with TD3 and therefore we chose the continuous representation (see below). 14 | The state space consists of a 4-tuple containing the position and velocity of the cart as well as the pole angle and its angular velocity. We denote those values as $(x, \dot{x}, \theta, \dot{\theta})$. 15 | We furthermore define bounds for the position and pole angle and stop the simulation as soon as those bounds are violated. 16 | 17 | The reward function penalizes the agent proportionally to the cart's distance from the center as well as the angle's distance from the angular threshold $\theta_t$. If the cart leaves the allowed bounds, we add an additional penalty of 20. More specifically, we use the following cost function: 18 | $\mathrm{cost}(x, r, d) = (\mathrm{sgn}(r) \cdot 100 \cdot r^2 - 4\cdot x^2) / 1000 - 20\cdot d$, where $r = (\theta_t/4 - |\theta|) / (\theta_t/4)$ and $d=1$ if the system goes out of bounds and $d=0$ otherwise. 19 | 20 | The cartpole problem can traditionally be efficiently solved with on-policy methods such as actor-critic (e.g. [A2C](https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html)). In our approach, we are using the off-policy method TD3 (cf. the [stable baselines documentation](https://stable-baselines3.readthedocs.io/en/master/modules/td3.html) for details). 21 | 22 | ## Directory Structure 23 | 24 | The cartpole environment is implemented in C++ in file `env.cpp` and a corresponding pybind11 wrapper is defined in `py_env.cpp`. With this wrapper, the environment can also be used from Python-based training infrastructures, but we only use it for visualization. 25 | Since the problem is very small, we will run the environment on the CPU and use the GPU only for training and inference. The training code is implemented in `train.cpp`. Before training can be started, the models have to be initialized from PyTorch.
This can be done by executing `python/initialize_models.py` from this directory. 26 | 27 | The configuration file for the sim environment is `config_sim.yaml` and the one for the TD3 deep reinforcement learning system is in `config.yaml`. 28 | 29 | After training, the trained agent can be visualized using the visualization code in `python/visualize.py`. 30 | 31 | ## Installation 32 | 33 | We highly recommend building the example together with the torchfort library using `cmake`. The example can also be built independently, but the cmake specification file `CMakeLists.txt` might need to be adjusted. 34 | 35 | ## How To Run 36 | 37 | In order to run training, create a directory called `checkpoint` in the location of the `train_cart_pole` executable and place the `python` directory as well as both config files there. Execute `python python/initialize_models.py` which should create the two model files `policy.pt` and `value.pt` respectively. Those are PyTorch-compatible model files which can be read by torchfort. 38 | Once everything is in place, the training can then be started by executing `train_cart_pole`. This file will read the two configuration files for the simulation (`config_sim.yaml`) and the TD3 DRL system (`config.yaml`). 39 | 40 | The training logs action and state data as well as loss and reward values during the training process. Training will automatically end successfully once an evaluation episode was able to balance the pole for at least 500 environment steps without falling down. Otherwise it will terminate unsuccessfully after 50000 episodes (note that this number can be changed in the `config_sim.yaml` file). 41 | 42 | During or after training, the cart performance can be visualized using the `python/visualize.py` script. It uses [pygame](https://www.pygame.org) to render frames and [MoviePy](https://moviepy.readthedocs.io/en/latest/) to write the rendered file as a gif as well as an mp4 video.
The script can be run via 43 | `python python/visualize.py --policy_checkpoint= --num_steps= --output_path=`. Note that in TD3, the target policy model which should be used can be found under `checkpoints/policy_target/model.pt`. The policy model under `checkpoints/policy/model.pt` is the active network and might not perform as well as expected. For a detailed explanation of the difference, see the [TD3 documentation](https://stable-baselines3.readthedocs.io/en/master/modules/td3.html). 44 | --------------------------------------------------------------------------------
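The cost function described above can be written out as a short Python sketch (a re-statement of the README formula and of `cost_continuous_fn` in `env.cpp`, not code from the repository; the threshold default mirrors `theta_threshold_radians_`):

```python
import math

def cartpole_cost(x, theta, out_of_bounds,
                  theta_threshold=20.0 * 2.0 * math.pi / 360.0):
    """cost(x, r, d) = (sgn(r) * 100 * r^2 - 4 * x^2) / 1000 - 20 * d."""
    # r measures how far the pole angle is from a quarter of the threshold
    r = (theta_threshold / 4.0 - abs(theta)) / (theta_threshold / 4.0)
    sign_r = -1.0 if r < 0.0 else 1.0
    # d = 1 adds the fixed out-of-bounds penalty of 20
    d = 1.0 if out_of_bounds else 0.0
    return (sign_r * 100.0 * r * r - 4.0 * x * x) / 1000.0 - 20.0 * d

# a perfectly centered, upright pole earns the maximum per-step reward of 0.1
best = cartpole_cost(0.0, 0.0, False)   # -> 0.1
```
The maximum per-step reward is therefore 0.1, and leaving the bounds drops the reward by the penalty of 20, which matches `penalty_` in `env.cpp`.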