├── docs
├── _static
│ └── style.css
├── requirements.txt
├── _templates
│ ├── footer.html
│ └── layout.html
├── api.rst
├── Makefile
├── index.rst
├── conf.py
├── installation.rst
└── extras.rst
├── examples
├── cpp
│ └── cart_pole
│ │ ├── config_sim.yaml
│ │ ├── media
│ │ ├── cartpole.gif
│ │ └── cartpole.mp4
│ │ ├── py_env.cpp
│ │ ├── config.yaml
│ │ ├── env.h
│ │ ├── python
│ │ ├── initialize_models.py
│ │ └── models.py
│ │ ├── CMakeLists.txt
│ │ ├── env.cpp
│ │ └── README.md
└── fortran
│ ├── graph
│ ├── media
│ │ └── validation_results.gif
│ ├── config.yaml
│ ├── CMakeLists.txt
│ ├── generate_loss.py
│ ├── visualize.py
│ └── generate_model.py
│ └── simulation
│ ├── media
│ └── validation_results.gif
│ ├── config_fcn_torchscript.yaml
│ ├── config_mlp_native.yaml
│ ├── generate_fcn_model.py
│ ├── CMakeLists.txt
│ ├── simulation.f90
│ └── visualize.py
├── tests
├── supervised
│ ├── configs
│ │ ├── missing_opt.yaml
│ │ ├── missing_loss.yaml
│ │ ├── torchscript.yaml
│ │ ├── mlp2.yaml
│ │ ├── mlp3.yaml
│ │ ├── mlp2_gradacc.yaml
│ │ ├── torchscript_multiarg.yaml
│ │ ├── torchscript_multiarg_extra.yaml
│ │ └── mlp.yaml
│ ├── scripts
│ │ └── setup_tests.py
│ └── CMakeLists.txt
├── general
│ ├── configs
│ │ ├── l1.yaml
│ │ ├── mse.yaml
│ │ ├── l1_multiarg.yaml
│ │ ├── mse_multiarg.yaml
│ │ ├── torchscript.yaml
│ │ ├── torchscript_multiout.yaml
│ │ ├── torchscript_multiarg.yaml
│ │ └── torchscript_multiarg_extra.yaml
│ ├── scripts
│ │ └── setup_tests.py
│ └── CMakeLists.txt
├── rl
│ ├── configs
│ │ ├── ppo.yaml
│ │ ├── ddpg.yaml
│ │ ├── td3.yaml
│ │ └── sac.yaml
│ ├── test_distributions.cpp
│ └── CMakeLists.txt
└── test_utils.h
├── requirements.txt
├── .clang-format
├── .github
├── scripts
│ └── run_ci_tests.sh
└── workflows
│ ├── docs.yaml
│ ├── format.yaml
│ └── build.yml
├── src
└── csrc
│ ├── include
│ ├── internal
│ │ ├── rl
│ │ │ ├── rl.h
│ │ │ ├── distributions.h
│ │ │ └── utils.h
│ │ ├── base_model.h
│ │ ├── base_loss.h
│ │ ├── model_state.h
│ │ ├── tensor_list.h
│ │ ├── model_pack.h
│ │ ├── distributed.h
│ │ ├── model_wrapper.h
│ │ ├── nvtx.h
│ │ ├── setup.h
│ │ ├── base_lr_scheduler.h
│ │ ├── losses.h
│ │ ├── training.h
│ │ ├── logging.h
│ │ ├── lr_schedulers.h
│ │ ├── models.h
│ │ ├── utils.h
│ │ └── param_map.h
│ ├── torchfort_config.h.in
│ └── torchfort_enums.h
│ ├── param_map.cpp
│ ├── lr_schedulers
│ ├── step_lr.cpp
│ ├── polynomial_lr.cpp
│ ├── linear_lr.cpp
│ ├── multistep_lr.cpp
│ └── cosine_annealing_lr.cpp
│ ├── model_state.cpp
│ ├── losses
│ ├── l1_loss.cpp
│ ├── mse_loss.cpp
│ └── torchscript_loss.cpp
│ ├── models
│ ├── mlp_model.cpp
│ └── rl
│ │ ├── sac_model.cpp
│ │ └── common_models.cpp
│ ├── logging.cpp
│ ├── utils.cpp
│ ├── model_pack.cpp
│ ├── model_wrapper.cpp
│ └── rl
│ └── utils.cpp
├── README.md
├── SECURITY.md
├── docker
├── Dockerfile_gnu_cpuonly
├── Dockerfile_gnu
└── Dockerfile
└── CONTRIBUTING.md
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | .wy-nav-content {
2 | max-width: 1200px !important;
3 | }
4 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/config_sim.yaml:
--------------------------------------------------------------------------------
1 | num_episodes: 50000
2 | max_steps_per_episode: 2500
3 | eval_frequency: 25
4 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==8.2.3
2 | sphinx_rtd_theme==3.0.2
3 | breathe==4.36.0
4 | sphinx-tabs==3.4.7
5 | sphinx-fortran==1.1.1
6 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/media/cartpole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/cpp/cart_pole/media/cartpole.gif
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/media/cartpole.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/cpp/cart_pole/media/cartpole.mp4
--------------------------------------------------------------------------------
/examples/fortran/graph/media/validation_results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/fortran/graph/media/validation_results.gif
--------------------------------------------------------------------------------
/examples/fortran/simulation/media/validation_results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/HEAD/examples/fortran/simulation/media/validation_results.gif
--------------------------------------------------------------------------------
/tests/supervised/configs/missing_opt.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: mlp
3 | parameters:
4 | dropout: 0.0
5 | layer_sizes: [10, 10, 10]
6 |
7 | loss:
8 | type: MSE
9 |
--------------------------------------------------------------------------------
/tests/supervised/configs/missing_loss.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: mlp
3 | parameters:
4 | dropout: 0.0
5 | layer_sizes: [10, 10, 10]
6 |
7 | optimizer:
8 | type: adam
9 |
--------------------------------------------------------------------------------
/tests/general/configs/l1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model.pt
5 |
6 | loss:
7 | type: L1
8 |
9 | optimizer:
10 | type: adam
11 |
--------------------------------------------------------------------------------
/tests/general/configs/mse.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model.pt
5 |
6 | loss:
7 | type: MSE
8 |
9 | optimizer:
10 | type: adam
11 |
--------------------------------------------------------------------------------
/docs/_templates/footer.html:
--------------------------------------------------------------------------------
1 | {% extends "!footer.html" %}
2 | {%- block contentinfo %}
3 | {{ super() }}
 4 | Documentation built from commit {{ version }}.
 5 | {% endblock %}
6 |
--------------------------------------------------------------------------------
/tests/supervised/configs/torchscript.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: "model.pt"
5 |
6 | loss:
7 | type: MSE
8 |
9 | optimizer:
10 | type: adam
11 |
--------------------------------------------------------------------------------
/tests/general/configs/l1_multiarg.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model_multiarg.pt
5 |
6 | loss:
7 | type: L1
8 |
9 | optimizer:
10 | type: adam
11 |
--------------------------------------------------------------------------------
/tests/general/configs/mse_multiarg.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model_multiarg.pt
5 |
6 | loss:
7 | type: MSE
8 |
9 | optimizer:
10 | type: adam
11 |
--------------------------------------------------------------------------------
/tests/supervised/configs/mlp2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: mlp
3 | parameters:
4 | dropout: 0.0
5 | layer_sizes: [10, 10, 10]
6 |
7 | loss:
8 | type: MSE
9 |
10 | optimizer:
11 | type: adam
12 |
--------------------------------------------------------------------------------
/tests/general/configs/torchscript.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model.pt
5 |
6 | loss:
7 | type: torchscript
8 | parameters:
9 | filename: loss.pt
10 |
11 | optimizer:
12 | type: adam
13 |
--------------------------------------------------------------------------------
/tests/supervised/configs/mlp3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: mlp
3 | parameters:
4 | dropout: 0.0
5 | layer_sizes: [10, 10, 10]
6 | flatten_non_batch_dims: false
7 |
8 | loss:
9 | type: MSE
10 |
11 | optimizer:
12 | type: adam
13 |
--------------------------------------------------------------------------------
/tests/general/configs/torchscript_multiout.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model.pt
5 |
6 | loss:
7 | type: torchscript
8 | parameters:
9 | filename: loss_multiout.pt
10 |
11 | optimizer:
12 | type: adam
13 |
--------------------------------------------------------------------------------
/tests/general/configs/torchscript_multiarg.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model_multiarg.pt
5 |
6 | loss:
7 | type: torchscript
8 | parameters:
9 | filename: loss_multiarg.pt
10 |
11 | optimizer:
12 | type: adam
13 |
--------------------------------------------------------------------------------
/tests/supervised/configs/mlp2_gradacc.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: mlp
3 | parameters:
4 | dropout: 0.0
5 | layer_sizes: [10, 10, 10]
6 |
7 | loss:
8 | type: MSE
9 |
10 | optimizer:
11 | type: adam
12 | general:
13 | grad_accumulation_steps: 4
14 |
--------------------------------------------------------------------------------
/tests/supervised/configs/torchscript_multiarg.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: "model_multiarg.pt"
5 |
6 | loss:
7 | type: torchscript
8 | parameters:
9 | filename: "loss_multiarg.pt"
10 |
11 | optimizer:
12 | type: adam
13 |
--------------------------------------------------------------------------------
/tests/general/configs/torchscript_multiarg_extra.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: model_multiarg.pt
5 |
6 | loss:
7 | type: torchscript
8 | parameters:
9 | filename: loss_multiarg_extra.pt
10 |
11 | optimizer:
12 | type: adam
13 |
--------------------------------------------------------------------------------
/tests/supervised/configs/torchscript_multiarg_extra.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: torchscript
3 | parameters:
4 | filename: "model_multiarg.pt"
5 |
6 | loss:
7 | type: torchscript
8 | parameters:
9 | filename: "loss_multiarg_extra.pt"
10 |
11 | optimizer:
12 | type: adam
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # basic packages
2 | ruamel-yaml
3 |
4 | # pytorch and some dependencies
5 | torch==2.8.0
6 |
7 | # training monitoring
8 | wandb
9 |
10 | # RL example visualization related
11 | pygame
12 | moviepy
13 |
14 | # Supervised learning example visualization related
15 | matplotlib
16 |
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | .. _api-label:
2 |
3 | #############
4 | TorchFort API
5 | #############
6 |
7 | The following sections describe the types and functions available in the TorchFort library for C/C++ and Fortran programs and
 8 | also the configuration file structure and available options.
9 |
10 | .. toctree::
11 |
12 | api/config
13 | api/c_api
14 | api/f_api
15 |
16 |
--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | BasedOnStyle: LLVM
3 | ColumnLimit: 120
4 | CommentPragmas: '^\\.+'
5 | DerivePointerAlignment: false
6 | Language: Cpp
7 | PointerAlignment: Left
8 | UseTab: Never
9 | AlignAfterOpenBracket: Align
10 | AlignTrailingComments: true
11 | AllowShortBlocksOnASingleLine: false
12 | AllowShortCaseLabelsOnASingleLine : false
13 | AllowShortIfStatementsOnASingleLine: false
14 | AllowShortLoopsOnASingleLine: false
15 | ...
16 |
--------------------------------------------------------------------------------
/tests/supervised/configs/mlp.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | enable_wandb_hook: 0
3 | report_frequency: 100
4 |
5 | model:
6 | type: mlp
7 | parameters:
8 | dropout: 0.0
9 | layer_sizes: [32, 32, 1]
10 |
11 | loss:
12 | type: MSE
13 |
14 | optimizer:
15 | type: adam
16 | parameters:
17 | learning_rate: 1e-3
18 | beta1: 0.9
19 | beta2: 0.999
20 | weight_decay: 0
21 | eps: 1e-8
22 | amsgrad: 0
23 |
24 | lr_scheduler:
25 | type: cosine_annealing
26 | parameters:
27 | T_max: 100000
28 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/config_fcn_torchscript.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | enable_wandb_hook: 1
3 | report_frequency: 100
4 |
5 | model:
6 | type: torchscript
7 | parameters:
8 | filename: "fcn_torchscript.pt"
9 |
10 | loss:
11 | type: MSE
12 |
13 | optimizer:
14 | type: adam
15 | parameters:
16 | learning_rate: 1e-3
17 | beta1: 0.9
18 | beta2: 0.999
19 | weight_decay: 0
20 | eps: 1e-8
21 | amsgrad: 0
22 |
23 | lr_scheduler:
24 | type: cosine_annealing
25 | parameters:
26 | T_max: 100000
27 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/config_mlp_native.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | enable_wandb_hook: 1
3 | report_frequency: 100
4 |
5 | model:
6 | type: mlp
7 | parameters:
8 | dropout: 0.0
9 | layer_sizes: [1024, 1024]
10 |
11 | loss:
12 | type: MSE
13 |
14 | optimizer:
15 | type: adam
16 | parameters:
17 | learning_rate: 1e-3
18 | beta1: 0.9
19 | beta2: 0.999
20 | weight_decay: 0
21 | eps: 1e-8
22 | amsgrad: 0
23 |
24 | lr_scheduler:
25 | type: cosine_annealing
26 | parameters:
27 | T_max: 100000
28 |
--------------------------------------------------------------------------------
/.github/scripts/run_ci_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 |
4 | cd /opt/torchfort/bin/tests/general
5 | python scripts/setup_tests.py
6 | ./test_losses
7 |
8 | cd /opt/torchfort/bin/tests/supervised
9 | python scripts/setup_tests.py
10 | ./test_checkpoint
11 | ./test_training
12 | mpirun -np 2 --allow-run-as-root ./test_distributed_training
13 |
14 | cd /opt/torchfort/bin/tests/rl
15 | ./test_distributions
16 | ./test_replay_buffer
17 | ./test_rollout_buffer
18 | ./test_off_policy --gtest_filter=*L0*
19 | ./test_on_policy --gtest_filter=*L0*
20 |
--------------------------------------------------------------------------------
/examples/fortran/graph/config.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | report_frequency: 100
3 |
4 | model:
5 | type: torchscript
6 | parameters:
7 | filename: "model_torchscript.pt"
8 |
9 | loss:
10 | type: torchscript
11 | parameters:
12 | filename: "loss_torchscript.pt"
13 |
14 | optimizer:
15 | type: adam
16 | parameters:
17 | learning_rate: 1e-3
18 | beta1: 0.9
19 | beta2: 0.999
20 | weight_decay: 0
21 | eps: 1e-8
22 | amsgrad: 0
23 |
24 | lr_scheduler:
25 | type: cosine_annealing
26 | parameters:
27 | T_max: 100000
28 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | doxygen Doxyfile
21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
22 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/rl/rl.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include "internal/rl/off_policy.h"
21 | #include "internal/rl/on_policy.h"
22 |
--------------------------------------------------------------------------------
/src/csrc/include/torchfort_config.h.in:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | /* Define to 1 if TorchFort was built with GPU support */
21 | #cmakedefine01 TORCHFORT_ENABLE_GPU
22 |
23 |
--------------------------------------------------------------------------------
/src/csrc/param_map.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
 18 | #include <set>
 19 | #include <string>
20 |
21 | #include "internal/param_map.h"
22 |
23 | namespace torchfort {
24 |
 25 | std::set<std::string> ParamMap::keys() const {
 26 |   std::set<std::string> keys;
27 | for (const auto& entry : params) {
28 | keys.insert(entry.first);
29 | }
30 | return keys;
31 | }
32 |
33 | } // namespace torchfort
34 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/base_model.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
 20 | #include <torch/torch.h>
21 |
22 | #include "internal/param_map.h"
23 |
24 | namespace torchfort {
25 |
26 | struct BaseModel : torch::nn::Module {
 27 |   virtual std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) = 0;
28 | virtual void setup(const ParamMap& params) = 0;
29 | };
30 |
31 | } // namespace torchfort
32 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/py_env.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
 18 | #include <pybind11/pybind11.h>
 19 | #include <pybind11/stl.h>
20 |
21 | #include "env.h"
22 |
23 | // pybind11 stuff
24 | namespace py = pybind11;
25 | PYBIND11_MODULE(PyEnvironments, m) {
 26 |   py::class_<CartPoleEnv>(m, "CartPoleEnv", py::dynamic_attr())
27 | .def(py::init<>())
28 | .def("step", &CartPoleEnv::step, py::arg("action"))
29 | .def("reset", &CartPoleEnv::reset);
30 | }
31 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/config.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | report_frequency: 1
3 | enable_wandb_hook: 1
4 | verbose: 1
5 |
6 | algorithm:
7 | type: td3
8 | parameters:
9 | batch_size: 512
10 | num_critics: 2
11 | policy_lag: 2
12 | nstep: 1
13 | nstep_reward_reduction: sum_no_skip
14 | gamma: 0.99
15 | rho: 0.99
16 |
17 | actor:
18 | type: space_noise
19 | parameters:
20 | a_low: -1.0
21 | a_high: 1.0
22 | clip: 0.3
23 | sigma_train: 0.1
24 | sigma_explore: 0.2
25 | adaptive: 0
26 |
27 | replay_buffer:
28 | type: uniform
29 | parameters:
30 | max_size: 50000
31 | min_size: 1024
32 |
33 | policy_model:
34 | type: torchscript
35 | parameters:
36 | filename: policy.pt
37 |
38 | critic_model:
39 | type: torchscript
40 | parameters:
41 | filename: value.pt
42 |
43 | optimizer:
44 | type: adam
45 | parameters:
46 | learning_rate: 0.001
47 | beta1: 0.9
48 | beta2: 0.999
49 | weight_decay: 0
50 | eps: 1e-6
51 | amsgrad: 0
52 |
53 | policy_lr_scheduler:
54 | type: cosine_annealing
55 | parameters:
56 | T_max: 500000000
57 |
58 | critic_lr_scheduler:
59 | type: cosine_annealing
60 | parameters:
61 | T_max: 500000000
62 |
--------------------------------------------------------------------------------
/tests/rl/configs/ppo.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | report_frequency: 1
3 | enable_wandb_hook: 0
4 | verbose: 1
5 |
6 | algorithm:
7 | type: ppo
8 | parameters:
9 | batch_size: 8
10 | gamma: 0.99
11 | gae_lambda: 0.95
12 | epsilon: 0.2
13 | clip_q: 0.
14 | target_kl_divergence: 0.02
15 | entropy_loss_coefficient: 0.
16 | value_loss_coefficient: 0.5
17 | normalize_advantage: True
18 |
19 | actor:
20 | type: gaussian_ac
21 | parameters:
22 | a_low: -1.0
23 | a_high: 1.0
24 |
25 | rollout_buffer:
26 | type: gae_lambda
27 | parameters:
28 | size: 64
29 |
30 | actor_critic_model:
31 | type: ActorCriticMLP
32 | parameters:
33 | dropout: 0.0
34 | encoder_layer_sizes: [1, 16, 8]
35 | actor_layer_sizes: [8, 1]
36 | value_layer_sizes: [8, 1]
37 | state_dependent_sigma: False
38 | log_sigma_init: 0.
39 |
40 | optimizer:
41 | type: adam
42 | parameters:
43 | learning_rate: 1e-4
44 | beta1: 0.9
45 | beta2: 0.999
46 | weight_decay: 0
47 | eps: 1e-6
48 | amsgrad: 0
49 | general:
50 | max_grad_norm: 0.5
51 |
52 | lr_scheduler:
53 | type: linear
54 | parameters:
55 | total_iters: 40000
56 | start_factor: 1.0
57 | end_factor: 0.01
58 |
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 | {% block sidebartitle %} {{ super() }}
3 |
4 |
32 | {% endblock %}
33 |
34 | {% block footer %} {{ super() }}
35 |
36 |
51 | {% endblock %}
52 |
--------------------------------------------------------------------------------
/tests/rl/configs/ddpg.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | report_frequency: 1
3 | enable_wandb_hook: 0
4 | verbose: 1
5 |
6 | algorithm:
7 | type: ddpg
8 | parameters:
9 | batch_size: 128
10 | nstep: 1
11 | nstep_reward_reduction: sum
12 | gamma: 0.95
13 | rho: 0.999
14 |
15 | actor:
16 | type: space_noise
17 | parameters:
18 | a_low: -1.
19 | a_high: 1.
20 | clip: 0.3
21 | sigma_train: 0.2
22 | sigma_explore: 1.0
23 |
24 | replay_buffer:
25 | type: uniform
26 | parameters:
27 | max_size: 4096
28 | min_size: 512
29 |
30 | policy_model:
31 | type: MLP
32 | parameters:
33 | dropout: 0.0
34 | layer_sizes: [1, 16, 1]
35 |
36 | critic_model:
37 | type: CriticMLP
38 | parameters:
39 | dropout: 0.0
40 | layer_sizes: [2, 16, 1]
41 |
42 | optimizer:
43 | type: adam
44 | parameters:
45 | learning_rate: 1e-4
46 | beta1: 0.9
47 | beta2: 0.999
48 | weight_decay: 0
49 | eps: 1e-6
50 | amsgrad: 0
51 |
52 | policy_lr_scheduler:
53 | type: linear
54 | parameters:
55 | total_iters: 20000
56 | start_factor: 1.0
57 | end_factor: 0.01
58 |
59 | critic_lr_scheduler:
60 | type: linear
61 | parameters:
62 | total_iters: 20000
63 | start_factor: 1.0
64 | end_factor: 0.01
--------------------------------------------------------------------------------
/tests/rl/configs/td3.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | report_frequency: 1
3 | enable_wandb_hook: 0
4 | verbose: 1
5 |
6 | algorithm:
7 | type: td3
8 | parameters:
9 | batch_size: 128
10 | num_critics: 2
11 | policy_lag: 2
12 | nstep: 1
13 | nstep_reward_reduction: sum
14 | gamma: 0.95
15 | rho: 0.999
16 |
17 | actor:
18 | type: space_noise
19 | parameters:
20 | a_low: -1.
21 | a_high: 1.
22 | clip: 0.3
23 | sigma_train: 0.2
24 | sigma_explore: 1.0
25 |
26 | replay_buffer:
27 | type: uniform
28 | parameters:
29 | max_size: 4096
30 | min_size: 512
31 |
32 | policy_model:
33 | type: MLP
34 | parameters:
35 | dropout: 0.0
36 | layer_sizes: [1, 16, 1]
37 |
38 | critic_model:
39 | type: CriticMLP
40 | parameters:
41 | dropout: 0.0
42 | layer_sizes: [2, 16, 1]
43 |
44 | optimizer:
45 | type: adam
46 | parameters:
47 | learning_rate: 1e-4
48 | beta1: 0.9
49 | beta2: 0.999
50 | weight_decay: 0
51 | eps: 1e-6
52 | amsgrad: 0
53 |
54 | policy_lr_scheduler:
55 | type: linear
56 | parameters:
57 | total_iters: 10000
58 | start_factor: 1.0
59 | end_factor: 0.01
60 |
61 | critic_lr_scheduler:
62 | type: linear
63 | parameters:
64 | total_iters: 20000
65 | start_factor: 1.0
66 | end_factor: 0.01
--------------------------------------------------------------------------------
/.github/workflows/docs.yaml:
--------------------------------------------------------------------------------
1 | name: "Documentation"
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | branches: [ master ]
8 |
9 | jobs:
10 | build-documentation:
11 | runs-on: ubuntu-latest
12 | name: "Build Documentation"
13 |
14 | steps:
15 | - name: "Checkout code"
16 | uses: actions/checkout@v4
17 |
18 | - name: "Install dependencies"
19 | run: |
20 | sudo apt-get update && sudo apt-get install -y doxygen
21 | pip install -r docs/requirements.txt
22 |
23 | - name: "Build documentation"
24 | run: |
25 | cd docs
26 | export TORCHFORT_GIT_SHA=${{ github.event.pull_request.head.sha || github.sha }}
27 | make html
28 |
29 | - name: "Upload documentation"
30 | uses: actions/upload-pages-artifact@v3
31 | with:
32 | path: docs/_build/html
33 |
34 |
35 | deploy-documentation:
36 | runs-on: ubuntu-latest
37 | if: github.event_name == 'push'
38 | needs: build-documentation
39 | permissions:
40 | contents: read
41 | pages: write
42 | id-token: write
43 | environment:
44 | name: github-pages
45 | url: ${{steps.deployment.outputs.page_url}}
46 | steps:
47 | - name: "Deploy documentation"
48 | id: deployment
49 | uses: actions/deploy-pages@v4
50 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/base_loss.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
 20 | #include <vector>
 21 |
 22 | #include <torch/torch.h>
23 |
24 | #include "internal/base_loss.h"
25 | #include "internal/param_map.h"
26 |
27 | namespace torchfort {
28 |
29 | struct BaseLoss : torch::nn::Module {
 30 |   virtual torch::Tensor forward(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& labels,
 31 |                                 const std::vector<torch::Tensor>& extra_args) = 0;
32 | virtual void setup(const ParamMap& params) = 0;
33 | };
34 |
35 | } // namespace torchfort
36 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/generate_fcn_model.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
class Net(torch.nn.Module):
    """Minimal fully convolutional network: one 3x3 conv with circular padding."""

    def __init__(self):
        super().__init__()
        # 1 input channel -> 1 output channel; circular padding preserves the
        # spatial size and matches periodic-boundary simulation domains.
        self.conv1 = torch.nn.Conv2d(1, 1, 3, padding=1, padding_mode="circular")

    def forward(self, x):
        """Apply the single convolution layer to x and return the result."""
        out = self.conv1(x)
        return out
25 |
26 |
def main():
    """Create the FCN model, move it to GPU when available, and save a
    TorchScript export to fcn_torchscript.pt."""
    # Create model
    model = Net()
    print("FCN model:", model)

    # Original code used a bare `except:` around model.to("cuda"), which also
    # swallows KeyboardInterrupt/SystemExit. Check availability explicitly instead.
    if torch.cuda.is_available():
        # Move model to GPU before scripting so the export targets CUDA
        model.to("cuda")
    else:
        print("PyTorch does not have CUDA support. Saving model on CPU.")

    # JIT-compile and save the model for loading from TorchFort
    model_jit = torch.jit.script(model)
    model_jit.save("fcn_torchscript.pt")

if __name__ == "__main__":
    main()
42 |
--------------------------------------------------------------------------------
/src/csrc/lr_schedulers/step_lr.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 |
21 | #include "internal/base_lr_scheduler.h"
22 | #include "internal/lr_schedulers.h"
23 |
namespace torchfort {

// Step decay scheduler: every `step_size` scheduler steps, multiply all
// parameter-group learning rates by `gamma`.
StepLR::StepLR(torch::optim::Optimizer& optimizer, const unsigned step_size, const double gamma)
    : BaseLRScheduler(optimizer), step_size_(step_size), gamma_(gamma) {}

// Return the learning rates to apply at the current step count.
// NOTE(review): the std::vector element type appears stripped in this copy
// (presumably std::vector<double>) -- confirm against the original file.
std::vector StepLR::get_lrs() {
  // No decay at step 0 or off a step_size boundary: keep current rates.
  if (step_count_ == 0 || step_count_ % step_size_ != 0)
    return get_current_lrs();
  else {
    std::vector lrs = get_current_lrs();
    // On a boundary step, scale every group's learning rate by gamma.
    std::transform(lrs.begin(), lrs.end(), lrs.begin(), [this](const double& v) { return this->gamma_ * v; });
    return lrs;
  }
}

} // namespace torchfort
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TorchFort
2 |
3 | An Online Deep Learning Interface for HPC programs on NVIDIA GPUs
4 |
5 | ## Introduction
6 | TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the [PyTorch](https://pytorch.org) framework.
7 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available
8 | within PyTorch.
9 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the
10 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a
11 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection,
12 | learning rate schedules, and much more.
13 |
14 | Please refer to the [documentation](https://nvidia.github.io/TorchFort/) for additional information on the library, build instructions, and usage details.
15 |
16 | Please refer to the [examples](examples) to see TorchFort in action.
17 |
18 | Contact us or open a GitHub issue if you are interested in using this library in your own solvers and have questions on usage and/or feature requests.
19 |
20 | ## License
21 | This library is released under an Apache 2.0 license, which can be found in [LICENSE](LICENSE).
22 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/model_state.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 | #include
20 | #include
21 |
22 | #include
23 |
namespace torchfort {

// Simple struct to store miscellaneous model state (e.g. iteration count)
struct ModelState {
  int64_t step_train;         // cumulative training steps (includes restored runs)
  int64_t step_inference;     // cumulative inference steps
  int64_t step_train_current; // training step of current run (ignoring restarted state)
  torch::Device device = torch::Device(torch::kCPU); // device the model runs on; defaults to CPU

  // General option settings
  int32_t report_frequency;          // reporting cadence in steps
  bool enable_wandb_hook;            // enable the Weights & Biases hook
  bool verbose;                      // extra logging when true
  std::filesystem::path report_file; // destination file for reports

  // Serialize this state to / restore it from the file `fname`.
  void save(const std::string& fname);
  void load(const std::string& fname);
};

} // namespace torchfort
44 |
--------------------------------------------------------------------------------
/examples/fortran/graph/CMakeLists.txt:
--------------------------------------------------------------------------------
set(fortran_example_targets
  train_graph
)

add_executable(train_graph)
target_sources(train_graph
  PRIVATE
  train.f90
)
set_target_properties(train_graph
  PROPERTIES OUTPUT_NAME train)

foreach(tgt ${fortran_example_targets})
  target_include_directories(${tgt}
    PRIVATE
    ${CMAKE_BINARY_DIR}/include
    ${MPI_Fortran_INCLUDE_DIRS}
  )
  target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran)
  target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort")
  target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
  # NOTE(review): the generator-expression conditions were garbled in this copy;
  # restored as $<COMPILE_LANGUAGE:Fortran> to scope the flags to Fortran sources
  # -- confirm against the original file.
  if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -acc -gpu=${CUF_GPU_ARG}>)
    target_link_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-acc -gpu=${CUF_GPU_ARG}>)
  elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -fbackslash>)
  endif()
endforeach()

# install() DESTINATION paths are interpreted relative to the install prefix;
# prepending ${CMAKE_INSTALL_PREFIX} produces absolute paths that break
# DESTDIR-based staged installs and packaging.
install(
  TARGETS ${fortran_example_targets}
  RUNTIME DESTINATION bin/examples/fortran/graph
)

install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml
        ${CMAKE_CURRENT_SOURCE_DIR}/generate_model.py
        ${CMAKE_CURRENT_SOURCE_DIR}/generate_loss.py
        ${CMAKE_CURRENT_SOURCE_DIR}/nodes.txt
        ${CMAKE_CURRENT_SOURCE_DIR}/connectivity.txt
        ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py
  DESTINATION bin/examples/fortran/graph)
43 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/tensor_list.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include
21 |
22 | #include
23 |
24 | #include "internal/utils.h"
25 |
26 | namespace torchfort {
27 | struct TensorList {
28 | template void add_tensor(T* data, size_t dim, int64_t* shape) {
29 | auto tensor = get_tensor(data, dim, shape);
30 | tensors.push_back(tensor);
31 | tensors_original_.push_back(tensor);
32 | };
33 |
34 | void to(torch::Device device, bool non_blocking = false) {
35 | for (auto& t : tensors) {
36 | t = t.to(device, non_blocking);
37 | }
38 | };
39 |
40 | void reset() { tensors = tensors_original_; }
41 |
42 | std::vector tensors;
43 | // To preserve references to external data, we store the original tensor objects
44 | std::vector tensors_original_;
45 | };
46 | } // namespace torchfort
47 |
--------------------------------------------------------------------------------
/tests/rl/configs/sac.yaml:
--------------------------------------------------------------------------------
# Test configuration for the SAC (Soft Actor-Critic) training algorithm.

general:
  report_frequency: 1
  enable_wandb_hook: 0
  verbose: 1

# Core SAC hyperparameters (gamma: discount, rho: target-network smoothing,
# alpha: entropy coefficient -- presumed from standard SAC naming; confirm
# against the implementation).
algorithm:
  type: sac
  parameters:
    batch_size: 128
    num_critics: 2
    nstep: 1
    nstep_reward_reduction: sum
    gamma: 0.95
    rho: 0.999
    alpha: 0.1

# Exploration actor; actions clipped to [a_low, a_high].
actor:
  type: parameter_noise
  parameters:
    a_low: -1.0
    a_high: 1.0

# Uniform-sampling replay buffer; training starts once min_size is reached.
replay_buffer:
  type: uniform
  parameters:
    max_size: 4096
    min_size: 512

policy_model:
  type: SACMLP
  parameters:
    dropout: 0.0
    layer_sizes: [1, 16, 8, 1]
    state_dependent_sigma: False
    log_sigma_init: 0.

critic_model:
  type: CriticMLP
  parameters:
    dropout: 0.0
    layer_sizes: [2, 16, 1]

# Optimizer for the policy and critic networks.
optimizer:
  type: adam
  parameters:
    learning_rate: 1e-4
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0
    eps: 1e-6
    amsgrad: 0

# Separate optimizer for the entropy coefficient alpha.
alpha_optimizer:
  type: adam
  parameters:
    learning_rate: 1e-4
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0
    eps: 1e-6
    amsgrad: 0

# Linear LR decay schedules (factor 1.0 -> 0.01 over total_iters steps).
policy_lr_scheduler:
  type: linear
  parameters:
    total_iters: 20000
    start_factor: 1.0
    end_factor: 0.01

critic_lr_scheduler:
  type: linear
  parameters:
    total_iters: 20000
    start_factor: 1.0
    end_factor: 0.01

alpha_lr_scheduler:
  type: linear
  parameters:
    total_iters: 20000
    start_factor: 1.0
    end_factor: 0.01
--------------------------------------------------------------------------------
/examples/fortran/graph/generate_loss.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
class CustomLoss(torch.nn.Module):
    """Mean squared error over graph nodes, excluding boundary nodes
    (node_type == 0) from both the sum and the normalization."""

    def __init__(self):
        super().__init__()

    def forward(self, prediction, label, node_types):
        """Return the MSE computed over interior (non-boundary) nodes only.

        prediction/label: per-node value tensors; node_types: per-node integer
        codes where 0 marks a boundary node.
        """
        # Squared error for every node and feature column.
        squared_err = (label - prediction) ** 2

        # Interior-node mask, broadcast across the feature dimension.
        interior = node_types != 0
        masked_err = squared_err * interior.unsqueeze(-1)

        # Normalize by (number of interior nodes) x (number of feature columns).
        denom = torch.sum(interior) * masked_err.shape[1]
        return torch.sum(masked_err) / denom
35 |
def main():
    """Create the custom loss module, move it to GPU when available, and save
    a TorchScript export to loss_torchscript.pt."""
    # Create loss module
    loss = CustomLoss()
    print("loss module:", loss)

    # Original code used a bare `except:` around loss.to("cuda"), which also
    # swallows KeyboardInterrupt/SystemExit. Check availability explicitly instead.
    if torch.cuda.is_available():
        # Move the module to GPU before scripting so the export targets CUDA
        loss.to("cuda")
    else:
        print("PyTorch does not have CUDA support. Saving model on CPU.")

    # JIT-compile and save for loading from TorchFort
    loss_jit = torch.jit.script(loss)
    loss_jit.save("loss_torchscript.pt")

if __name__ == "__main__":
    main()
51 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization.
4 |
5 | If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub.**
6 |
7 | ## Reporting Potential Security Vulnerability in an NVIDIA Product
8 |
9 | To report a potential security vulnerability in any NVIDIA product:
10 | - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html)
11 | - E-Mail: psirt@nvidia.com
12 | - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key)
13 | - Please include the following information:
14 | - Product/Driver name and version/branch that contains the vulnerability
15 | - Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
16 | - Instructions to reproduce the vulnerability
17 | - Proof-of-concept or exploit code
18 | - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability
19 |
20 | While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information.
21 |
22 | ## NVIDIA Product Security
23 |
24 | For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security
25 |
--------------------------------------------------------------------------------
/src/csrc/lr_schedulers/polynomial_lr.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 | #include
21 |
22 | #include "internal/base_lr_scheduler.h"
23 | #include "internal/lr_schedulers.h"
24 |
namespace torchfort {

// Polynomial decay scheduler: over `total_iters` steps, the learning rate is
// decayed by the ratio of successive polynomial factors raised to `power`.
PolynomialLR::PolynomialLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double power)
    : BaseLRScheduler(optimizer), total_iters_(total_iters), power_(power) {}

// Return the learning rates to apply at the current step count.
// NOTE(review): the std::vector element type appears stripped in this copy
// (presumably std::vector<double>) -- confirm against the original file.
std::vector PolynomialLR::get_lrs() {
  std::vector lrs = get_current_lrs();
  // No decay at step 0 or once past the decay horizon.
  if (step_count_ == 0 || step_count_ > total_iters_)
    return lrs;
  else {
    // Incremental factor between step (k-1) and step k of the polynomial
    // schedule (1 - k/total)^power, applied multiplicatively to current rates.
    double decay_factor =
        (1. - double(step_count_) / double(total_iters_)) / (1. - double(step_count_ - 1) / double(total_iters_));
    decay_factor = std::pow(decay_factor, power_);

    std::transform(lrs.begin(), lrs.end(), lrs.begin(), [decay_factor](const double& v) { return decay_factor * v; });

    return lrs;
  }
}

} // namespace torchfort
46 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/model_pack.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 | #include
20 |
21 | #include
22 |
23 | #include "internal/base_loss.h"
24 | #include "internal/base_lr_scheduler.h"
25 | #include "internal/distributed.h"
26 | #include "internal/model_state.h"
27 | #include "internal/model_wrapper.h"
28 |
namespace torchfort {

// Simple struct to group model, optimizer, lr scheduler, state, and comm objects
struct ModelPack {
  std::shared_ptr model;        // wrapped model (native or TorchScript)
  std::shared_ptr optimizer;    // optimizer driving training
  std::shared_ptr lr_scheduler; // optional learning-rate schedule
  std::shared_ptr loss;         // loss function used for training
  std::shared_ptr comm;         // communicator for distributed training
  std::shared_ptr state;        // step counters and reporting options
  // NOTE(review): shared_ptr template arguments appear stripped in this copy
  // -- confirm element types against the original file.
  int grad_accumulation_steps = 1; // optimizer steps once per N backward passes
  float max_grad_norm = 0.0;       // gradient clipping threshold; 0 presumably disables it -- confirm
};

// Persist / restore a full model pack (optionally including optimizer state)
// to and from the file `fname`.
void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true);
void load_model_pack(ModelPack& model_pack, const std::string& fname, bool load_optimizer = true);

} // namespace torchfort
47 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/distributed.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #ifdef ENABLE_GPU
21 | #include
22 | #include
23 | #endif
24 | #include
25 |
26 | #include
27 |
namespace torchfort {

// Communicator wrapper around MPI (and, when built with GPU support, NCCL)
// used for collective operations in distributed training.
struct Comm {
  void initialize();
  void finalize();
  // All-reduce overloads for tensors and scalars; `average` presumably divides
  // the reduced sum by `size` -- confirm against the implementation.
  void allreduce(torch::Tensor& tensor, bool average = false) const;
  void allreduce(std::vector& tensors, bool average = false) const;
  void allreduce(double& val, bool average = false) const;
  void allreduce(float& val, bool average = false) const;
  // Broadcast `tensor` from rank `root` to all ranks.
  void broadcast(torch::Tensor& tensor, int root = 0) const;

  int rank;             // this process's rank in mpi_comm
  int size;             // number of ranks in mpi_comm
  torch::Device device; // device this communicator operates on
  MPI_Comm mpi_comm;
#ifdef ENABLE_GPU
  ncclComm_t nccl_comm = nullptr; // NCCL communicator (GPU builds only)
  cudaStream_t stream = nullptr;  // stream used for GPU collectives
  cudaEvent_t event = nullptr;    // event used for stream synchronization
#endif
  bool initialized = false; // set by initialize(); guards double-init

  Comm(MPI_Comm mpi_comm, torch::Device device) : mpi_comm(mpi_comm), device(device){};
};

} // namespace torchfort
54 |
--------------------------------------------------------------------------------
/.github/workflows/format.yaml:
--------------------------------------------------------------------------------
# CI workflow: verifies that all C++ sources are formatted with clang-format.
name: "Code Format"

on:
  pull_request:
    branches: [ master ]

jobs:
  clang-format:
    env:
      # Pin the clang-format major version so results match the repo's .clang-format.
      CLANG_FORMAT_VERSION: "15"

    runs-on: ubuntu-latest
    name: "clang-format check"

    steps:
      - name: "Checkout code"
        uses: actions/checkout@v4

      - name: "Install dependencies"
        run: |
          # Install clang-format
          sudo apt-get update && sudo apt-get install -y clang-format-${CLANG_FORMAT_VERSION}

          # Print clang-format version information
          clang-format-${CLANG_FORMAT_VERSION} --version

      # Dry-runs clang-format over src/examples/tests; on failure, prints the
      # fix command and a diff for each offending file, then exits non-zero.
      - name: "Run clang-format check"
        run: |
          # Collect names of files that are not properly formatted
          filelist=`find src examples tests -name "*.cpp" -o -name "*.h"`
          files_to_fix=()
          for file in $filelist; do
            if ! clang-format-${CLANG_FORMAT_VERSION} --dry-run --Werror "$file" 2>/dev/null; then
              files_to_fix+=("$file")
            fi
          done

          # If any file is not properly formatted, print diff and exit with error
          if [ ${#files_to_fix[@]} -gt 0 ]; then
            # Print the list of files that are not properly formatted
            echo "FAIL: Some files are not properly formatted. To resolve issues, run:"
            for file in "${files_to_fix[@]}"; do
              echo "clang-format-${CLANG_FORMAT_VERSION} -i $file"
            done
            echo

            for file in "${files_to_fix[@]}"; do
              echo "Diff for $file:"
              bash -c "clang-format-${CLANG_FORMAT_VERSION} $file | diff $file -; exit 0"
              echo
            done

            exit 1
          fi

          echo "PASS: All files are properly formatted."
--------------------------------------------------------------------------------
/docker/Dockerfile_gnu_cpuonly:
--------------------------------------------------------------------------------
FROM ubuntu:22.04

SHELL ["/bin/bash", "-c"]

# Install System Dependencies
# Use the ENV key=value form; the legacy space-separated form is deprecated.
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update -y && \
    apt install -y build-essential && \
    apt install -y wget cmake && \
    apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
    apt install -y git vim gfortran && \
    apt install -y libibverbs-dev ibverbs-utils numactl && \
    apt install -y openmpi-bin libopenmpi-dev

# Install PyTorch (CPU-only wheel)
RUN pip3 install torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu

# Install yaml-cpp as a static, position-independent library
RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
    cd yaml-cpp && \
    mkdir build && cd build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \
          -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \
          -DBUILD_SHARED_LIBS=OFF \
          -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \
    make -j$(nproc) && make install
ENV LD_LIBRARY_PATH=/opt/yaml-cpp/lib:${LD_LIBRARY_PATH}

# Install additional Python dependencies
RUN pip3 install wandb ruamel-yaml matplotlib pygame moviepy

# Install TorchFort without GPU support
ENV FC=gfortran
COPY . /torchfort
RUN cd /torchfort && mkdir build && cd build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
          -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
          -DTORCHFORT_ENABLE_GPU=0 \
          -DTORCHFORT_BUILD_EXAMPLES=1 \
          -DTORCHFORT_BUILD_TESTS=1 \
          -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
          .. && \
    make -j$(nproc) install && \
    cd / && rm -rf torchfort
ENV LD_LIBRARY_PATH=/opt/torchfort/lib:${LD_LIBRARY_PATH}
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
47 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/CMakeLists.txt:
--------------------------------------------------------------------------------
1 |
set(fortran_example_targets
  train
  train_distributed
)

add_executable(train)
target_sources(train
  PRIVATE
  train.f90
  simulation.f90
)
# Fortran .mod files are build artifacts: keep them in the binary tree.
# Writing them into ${CMAKE_CURRENT_SOURCE_DIR} pollutes the source checkout
# and breaks read-only / out-of-source builds.
set_target_properties(train PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mod/0 )

add_executable(train_distributed)
target_sources(train_distributed
  PRIVATE
  train_distributed.f90
  simulation.f90
)
set_target_properties(train_distributed PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mod/1 )

foreach(tgt ${fortran_example_targets})
  target_include_directories(${tgt}
    PRIVATE
    ${CMAKE_BINARY_DIR}/include
    ${MPI_Fortran_INCLUDE_DIRS}
  )
  target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran)
  target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort")
  target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
  # NOTE(review): the generator-expression conditions were garbled in this copy;
  # restored as $<COMPILE_LANGUAGE:Fortran> to scope the flags to Fortran sources
  # -- confirm against the original file.
  if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -acc -gpu=${CUF_GPU_ARG}>)
    target_link_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-acc -gpu=${CUF_GPU_ARG}>)
  elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -fbackslash>)
  endif()
endforeach()

# install() DESTINATION paths are interpreted relative to the install prefix;
# prepending ${CMAKE_INSTALL_PREFIX} produces absolute paths that break
# DESTDIR-based staged installs and packaging.
install(
  TARGETS ${fortran_example_targets}
  RUNTIME DESTINATION bin/examples/fortran/simulation
)

install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/config_mlp_native.yaml
        ${CMAKE_CURRENT_SOURCE_DIR}/config_fcn_torchscript.yaml
        ${CMAKE_CURRENT_SOURCE_DIR}/generate_fcn_model.py
        ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py
  DESTINATION bin/examples/fortran/simulation)
52 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/model_wrapper.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include
21 |
22 | #include "internal/base_model.h"
23 |
namespace torchfort {

// Uniform interface over a native TorchFort model (BaseModel) or a loaded
// TorchScript module, so the rest of the library can treat both the same way.
class ModelWrapper {
public:
  // Wrap a native model.
  // NOTE(review): shared_ptr template arguments appear stripped in this copy
  // -- confirm parameter types against the original file.
  ModelWrapper(const std::shared_ptr& model);

  // Wrap an already-loaded TorchScript module.
  ModelWrapper(const std::shared_ptr& model_jit);

  // Load a TorchScript module from the given file.
  ModelWrapper(const std::string& jit_model_fname);

  // Trainable parameters of the underlying model.
  std::vector parameters() const;

  // Parameters keyed by name (used e.g. for checkpointing).
  torch::OrderedDict named_parameters() const;

  // Move the model to `device`; tracked via device().
  void to(torch::Device device, bool non_blocking = false);

  // Switch to training mode (enables dropout etc.).
  void train();

  // Switch to evaluation mode.
  void eval();

  // Run a forward pass; multiple inputs/outputs are passed as vectors.
  std::vector forward(const std::vector& inputs) const;

  // Serialize / restore the model to and from `fname`.
  void save(const std::string& fname) const;

  void load(const std::string& fname);

  // Device the model currently resides on.
  torch::Device device() const;

private:
  bool jit = false; // true when wrapping a TorchScript module
  std::shared_ptr model;
  std::shared_ptr model_jit;
  torch::Device device_ = torch::Device(torch::kCPU);
};

} // namespace torchfort
60 |
--------------------------------------------------------------------------------
/src/csrc/lr_schedulers/linear_lr.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 |
21 | #include "internal/base_lr_scheduler.h"
22 | #include "internal/lr_schedulers.h"
23 |
namespace torchfort {

// Linear learning-rate scheduler: scales LRs from start_factor to end_factor
// linearly over total_iters steps, applied multiplicatively per step.
LinearLR::LinearLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double start_factor,
                   const double end_factor)
    : BaseLRScheduler(optimizer), total_iters_(total_iters), start_factor_(start_factor), end_factor_(end_factor) {}

// Return the learning rates to apply at the current step count.
// NOTE(review): the std::vector element type appears stripped in this copy
// (presumably std::vector<double>) -- confirm against the original file.
std::vector LinearLR::get_lrs() {

  double factor;
  if (step_count_ == 0) {
    // First step: apply the initial scaling directly.
    factor = start_factor_;
  } else if (step_count_ > total_iters_) {
    // Past the schedule horizon: keep rates unchanged.
    factor = 1.;
  } else {
    // Incremental multiplicative factor that advances the cumulative scale
    // from step (k-1) to step k along the linear interpolation.
    factor = (1. + (end_factor_ - start_factor_) /
                       double(total_iters_ * start_factor_ + (step_count_ - 1) * (end_factor_ - start_factor_)));
  }

  // get current lrs and modify
  std::vector lrs = get_current_lrs();
  std::transform(lrs.begin(), lrs.end(), lrs.begin(), [factor](const double& v) { return factor * v; });

  return lrs;
}

} // namespace torchfort
50 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/nvtx.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include
21 |
22 | #ifdef ENABLE_GPU
23 | #include
24 | #endif
25 |
namespace torchfort {

// Helper class for NVTX ranges
// In non-GPU builds, both methods compile to no-ops so call sites need no #ifdefs.
class nvtx {
public:
#ifdef ENABLE_GPU
  // Open a named NVTX range; the color is chosen deterministically by hashing
  // the range name into a fixed 8-color palette, so the same name always gets
  // the same color in the profiler timeline.
  static void rangePush(const std::string& range_name) {
    static constexpr int ncolors_ = 8;
    static constexpr int colors_[ncolors_] = {0x3366CC, 0xDC3912, 0xFF9900, 0x109618,
                                              0x990099, 0x3B3EAC, 0x0099C6, 0xDD4477};
    std::hash hash_fn;
    int color = colors_[hash_fn(range_name) % ncolors_];
    nvtxEventAttributes_t ev = {0};
    ev.version = NVTX_VERSION;
    ev.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
    ev.colorType = NVTX_COLOR_ARGB;
    ev.color = color;
    ev.messageType = NVTX_MESSAGE_TYPE_ASCII;
    ev.message.ascii = range_name.c_str();
    nvtxRangePushEx(&ev);
  }

  // Close the most recently opened NVTX range.
  static void rangePop() { nvtxRangePop(); }
#else
  static void rangePush(const std::string& range_name) {}
  static void rangePop() {}
#endif
};

} // namespace torchfort
56 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. TorchFort documentation master file, created by
2 | sphinx-quickstart on Wed Jun 1 13:44:41 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | ############################################################################
7 | TorchFort: An Online Deep Learning Interface for HPC programs on NVIDIA GPUs
8 | ############################################################################
9 | These pages contain the documentation for TorchFort, an online deep learning interface for HPC programs.
10 |
TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the `PyTorch <https://pytorch.org>`_ framework.
12 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available
13 | within PyTorch.
14 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the
15 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a
16 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection,
17 | learning rate schedules, and much more.
18 |
19 | Please contact us or open a GitHub issue if you are interested in using this library
20 | in your own solvers and have questions on usage and/or feature requests.
21 |
22 |
23 | Table of Contents
24 | =================
25 | .. toctree::
26 | :maxdepth: 4
27 |
28 | installation
29 | usage
30 | api
31 | extras
32 |
33 |
34 | Indices and tables
35 | ==================
36 |
37 | * :ref:`genindex`
38 | * :ref:`modindex`
39 | * :ref:`search`
40 |
--------------------------------------------------------------------------------
/src/csrc/include/torchfort_enums.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
/* Sentinel device index selecting the host CPU. */
#define TORCHFORT_DEVICE_CPU (-1)

/**
 * @brief This enum defines the data types supported.
 */
enum torchfort_datatype_t { TORCHFORT_FLOAT = -1, TORCHFORT_DOUBLE = -2, TORCHFORT_INT32 = -3, TORCHFORT_INT64 = -4 };
26 |
/**
 * @brief This enum defines the possible return values from TorchFort. Most functions in the TorchFort library
 * will return one of these values to indicate if an operation has completed successfully or an error occurred.
 */
enum torchfort_result_t {
  TORCHFORT_RESULT_SUCCESS = 0, ///< The operation completed successfully
  TORCHFORT_RESULT_INVALID_USAGE = 1, ///< A user error, typically an invalid argument
  TORCHFORT_RESULT_NOT_SUPPORTED = 2, ///< A user error, requesting an invalid or unsupported operation configuration
  TORCHFORT_RESULT_INTERNAL_ERROR = 3, ///< An internal library error, should be reported
  TORCHFORT_RESULT_CUDA_ERROR = 4, ///< An error occurred in the CUDA Runtime
  TORCHFORT_RESULT_MPI_ERROR = 5, ///< An error occurred in the MPI library
  TORCHFORT_RESULT_NCCL_ERROR = 6 ///< An error occurred in the NCCL library
};
40 |
--------------------------------------------------------------------------------
/tests/rl/test_distributions.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include "internal/rl/distributions.h"
19 | #include
20 | #include
21 |
22 | using namespace torchfort;
23 | using namespace torch::indexing;
24 |
// Smoke test for NormalDistribution::rsample: draws a reparametrized sample
// and a direct at::normal sample from the same (mu, sigma) and checks the two
// are shape-compatible (elementwise subtraction succeeds).
TEST(NormalDistribution, RandomSampling) {
  // rng: fixed seed for reproducibility
  torch::manual_seed(666);

  // no grad guard: sampling only, no autograd graph needed here
  torch::NoGradGuard no_grad;

  // create normal distribution with given shape
  torch::Tensor mutens = torch::empty({4, 8}, torch::kFloat32);
  torch::Tensor log_sigmatens = torch::empty({4, 8}, torch::kFloat32);

  // fill with random elements; exponentiate so sigmas are strictly positive
  mutens.normal_();
  log_sigmatens.normal_();
  torch::Tensor sigmatens = torch::exp(log_sigmatens);

  auto ndist = rl::NormalDistribution(mutens, sigmatens);
  torch::Tensor sample = ndist.rsample();

  // do direct sampling without reparametrization trick
  torch::Tensor sample_compare = at::normal(mutens, sigmatens);

  // expect that shapes match: the two draws use different RNG streams, so
  // only shape compatibility (not the values) can be checked here
  EXPECT_NO_THROW(torch::sum(sample - sample_compare).item());
}
50 |
// Test-suite entry point: initialize GoogleTest and run all registered cases.
int main(int argc, char* argv[]) {
  ::testing::InitGoogleTest(&argc, argv);

  return RUN_ALL_TESTS();
}
56 |
--------------------------------------------------------------------------------
/src/csrc/lr_schedulers/multistep_lr.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 |
21 | #include "internal/base_lr_scheduler.h"
22 | #include "internal/lr_schedulers.h"
23 |
24 | namespace torchfort {
25 |
26 | MultiStepLR::MultiStepLR(torch::optim::Optimizer& optimizer, const std::vector& milestones, const double gamma)
27 | : BaseLRScheduler(optimizer), milestones_(milestones), gamma_(gamma) {}
28 |
29 | std::vector MultiStepLR::get_lrs() {
30 | std::vector lrs = get_current_lrs();
31 | if (step_count_ == 0 || milestones_.size() == 0)
32 | return lrs;
33 | else {
34 | auto lower_old = std::lower_bound(milestones_.begin(), milestones_.end(), step_count_ - 1,
35 | [](const int& ms, int value) { return ms <= value; });
36 | auto lower = std::lower_bound(milestones_.begin(), milestones_.end(), step_count_,
37 | [](const int& ms, int value) { return ms <= value; });
38 |
39 | if (lower_old != lower) {
40 | // in this case we need to decay the LR:
41 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [this](const double& lr) { return this->gamma_ * lr; });
42 | }
43 |
44 | return lrs;
45 | }
46 | }
47 |
48 | } // namespace torchfort
49 |
--------------------------------------------------------------------------------
/src/csrc/model_state.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 |
20 | #include
21 |
22 | #include "internal/exceptions.h"
23 | #include "internal/model_state.h"
24 |
25 | namespace torchfort {
26 |
// Serialize the training/inference step counters and the device field to
// `fname` using torch's archive format (keys match ModelState::load).
void ModelState::save(const std::string& fname) {
  torch::serialize::OutputArchive archive;
  archive.write("step_train", torch::IValue(step_train));
  archive.write("step_inference", torch::IValue(step_inference));
  archive.write("device", torch::IValue(device));
  archive.save_to(fname);
}
34 |
35 | void ModelState::load(const std::string& fname) {
36 | if (!std::filesystem::exists(fname)) {
37 | THROW_INVALID_USAGE(fname + " does not exist.");
38 | }
39 |
40 | torch::serialize::InputArchive archive;
41 | archive.load_from(fname);
42 |
43 | torch::IValue ivalue;
44 | if (!archive.try_read("step_train", ivalue)) {
45 | THROW_INVALID_USAGE(fname + " is missing required data.");
46 | }
47 | step_train = ivalue.to();
48 |
49 | if (!archive.try_read("step_inference", ivalue)) {
50 | THROW_INVALID_USAGE(fname + " is missing required data.");
51 | }
52 | step_inference = ivalue.to();
53 |
54 | if (!archive.try_read("device", ivalue)) {
55 | THROW_INVALID_USAGE(fname + " is missing required data.");
56 | }
57 | device = ivalue.to();
58 | }
59 |
60 | } // namespace torchfort
61 |
--------------------------------------------------------------------------------
/src/csrc/losses/l1_loss.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | #include
24 | #include
25 |
26 | #include "internal/losses.h"
27 | #include "internal/param_map.h"
28 | #include "internal/setup.h"
29 | #include "internal/utils.h"
30 |
31 | namespace torchfort {
32 |
33 | void L1Loss::setup(const ParamMap& params) {
34 | std::set supported_params{"reduction"};
35 | check_params(supported_params, params.keys());
36 |
37 | auto options = torch::nn::L1LossOptions();
38 | try {
39 | std::string reduction = params.get_param("reduction")[0];
40 | options = options.reduction(get_torch_reduction(reduction));
41 | } catch (std::out_of_range) {
42 | // use default
43 | }
44 |
45 | module = torch::nn::L1Loss(options);
46 | }
47 |
48 | torch::Tensor L1Loss::forward(const std::vector& inputs, const std::vector& labels,
49 | const std::vector& extra_args) {
50 | if (inputs.size() != 1 || labels.size() != 1 || extra_args.size() != 0) {
51 | THROW_INVALID_USAGE("L1Loss only supports one input tensor, one label tensor, and no extra arguments.");
52 | }
53 | auto x = inputs[0];
54 | auto y = labels[0];
55 | return module(x.flatten(), y.flatten());
56 | }
57 |
58 | } // namespace torchfort
59 |
--------------------------------------------------------------------------------
/src/csrc/lr_schedulers/cosine_annealing_lr.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 |
20 | #include "internal/base_lr_scheduler.h"
21 | #include "internal/lr_schedulers.h"
22 |
23 | namespace torchfort {
24 |
// Cosine-annealing LR schedule: anneals each group's LR from its value at
// construction time toward eta_min over T_max steps (cf. PyTorch's
// CosineAnnealingLR).
CosineAnnealingLR::CosineAnnealingLR(torch::optim::Optimizer& optimizer, const unsigned T_max, const double eta_min)
    : BaseLRScheduler(optimizer), T_max_(T_max), eta_min_(eta_min) {
  // Snapshot the LRs at construction; these are the per-group base LRs.
  base_lrs_ = get_current_lrs();
}
29 |
// Advance one group's LR by a single step of the cosine-annealing recurrence.
// This mirrors PyTorch's chainable CosineAnnealingLR form: at a restart
// boundary ((step - 1 - T_max) % (2 * T_max) == 0) the LR is recomputed
// directly from base_lr; otherwise it is updated from the previous LR via
// the ratio of consecutive cosine factors.
double CosineAnnealingLR::update_lr(const double& last_lr, const double& base_lr) {
  double lr;
  if ((step_count_ - 1 - T_max_) % (2 * T_max_) == 0) {
    lr = eta_min_ + 0.5 * (base_lr - eta_min_) * (1. + cos(double(step_count_) * M_PI / double(T_max_)));
  } else {
    lr = (1. + cos(M_PI * double(step_count_) / double(T_max_))) /
         (1. + cos(M_PI * double(step_count_ - 1) / double(T_max_))) * (last_lr - eta_min_) +
         eta_min_;
  }

  return lr;
}
42 |
43 | std::vector CosineAnnealingLR::get_lrs() {
44 | std::vector lrs = get_current_lrs();
45 | if (step_count_ == 0 || T_max_ == 0)
46 | return lrs;
47 | else {
48 | std::vector lrs_new;
49 | std::transform(lrs.begin(), lrs.end(), base_lrs_.begin(), std::back_inserter(lrs_new),
50 | [this](const auto& current, const auto& base) { return update_lr(current, base); });
51 | return lrs_new;
52 | }
53 | }
54 |
55 | } // namespace torchfort
56 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/setup.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include
21 |
22 | #include
23 | #include
24 |
25 | #include "internal/base_loss.h"
26 | #include "internal/base_lr_scheduler.h"
27 | #include "internal/base_model.h"
28 | #include "internal/lr_schedulers.h"
29 | #include "internal/model_state.h"
30 | #include "internal/model_wrapper.h"
31 |
namespace torchfort {

// Validate that every provided parameter key is in the supported set.
// NOTE(review): set element types were lost in extraction -- presumably
// std::string keys; confirm against setup.cpp.
void check_params(const std::set& supported_params, const std::set& provided_params);

// Build a ParamMap from a YAML parameter node.
ParamMap get_params(const YAML::Node& params_node);

// Construct a model wrapper from its YAML description.
std::shared_ptr get_model(const YAML::Node& model_node);

// Construct a loss object (built-in or torchscript) from its YAML description.
std::shared_ptr get_loss(const YAML::Node& loss_node);

// Construct an optimizer over an explicit parameter list.
std::shared_ptr get_optimizer(const YAML::Node& optimizer_node,
                                              std::vector parameters);

// Construct an optimizer over all parameters of the given model.
std::shared_ptr get_optimizer(const YAML::Node& optimizer_node,
                                              const std::shared_ptr& model);

// Construct an LR scheduler attached to the given optimizer.
std::shared_ptr get_lr_scheduler(const YAML::Node& lr_scheduler_node,
                                                const std::shared_ptr& optimizer);

// Construct a named model-state object from its YAML node.
std::shared_ptr get_state(const char* name, const YAML::Node& state_node);
} // namespace torchfort
53 |
--------------------------------------------------------------------------------
/src/csrc/losses/mse_loss.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | #include
24 | #include
25 |
26 | #include "internal/exceptions.h"
27 | #include "internal/losses.h"
28 | #include "internal/param_map.h"
29 | #include "internal/setup.h"
30 | #include "internal/utils.h"
31 |
32 | namespace torchfort {
33 |
34 | void MSELoss::setup(const ParamMap& params) {
35 | std::set supported_params{"reduction"};
36 | check_params(supported_params, params.keys());
37 |
38 | auto options = torch::nn::MSELossOptions();
39 | try {
40 | std::string reduction = params.get_param("reduction")[0];
41 | options = options.reduction(get_torch_reduction(reduction));
42 | } catch (std::out_of_range) {
43 | // use default
44 | }
45 |
46 | module = torch::nn::MSELoss(options);
47 | }
48 |
49 | torch::Tensor MSELoss::forward(const std::vector& inputs, const std::vector& labels,
50 | const std::vector& extra_args) {
51 | if (inputs.size() != 1 || labels.size() != 1 || extra_args.size() != 0) {
52 | THROW_INVALID_USAGE("MSELoss only supports one input tensor, one label tensor, and no extra arguments.");
53 | }
54 | auto x = inputs[0];
55 | auto y = labels[0];
56 | return module(x.flatten(), y.flatten());
57 | }
58 |
59 | } // namespace torchfort
60 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/rl/distributions.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 | #include
20 | #include
21 |
22 | #include "internal/rl/rl.h"
23 |
24 | namespace torchfort {
25 |
26 | namespace rl {
27 |
28 | class Distribution {
29 |
30 | public:
31 | Distribution(const Distribution&) = delete;
32 |
33 | // constructor
34 | Distribution() {}
35 | virtual torch::Tensor rsample() = 0;
36 | virtual torch::Tensor log_prob(torch::Tensor value) = 0;
37 | virtual torch::Tensor entropy() = 0;
38 | };
39 |
40 | class NormalDistribution : public Distribution, public std::enable_shared_from_this {
41 | public:
42 | NormalDistribution(torch::Tensor mu, torch::Tensor sigma) : mu_(mu), sigma_(sigma) {}
43 |
44 | torch::Tensor rsample() {
45 | auto noise = torch::empty_like(mu_).normal_(0., 1.);
46 | return torch::Tensor(mu_ + sigma_ * noise).clone();
47 | }
48 |
49 | torch::Tensor log_prob(torch::Tensor value) {
50 | auto var = torch::square(sigma_);
51 | auto log_sigma = sigma_.log();
52 | auto result = -torch::square(value - mu_) / (2 * var) - log_sigma - std::log(std::sqrt(2. * M_PI));
53 |
54 | return result;
55 | }
56 |
57 | torch::Tensor entropy() {
58 | auto log_sigma = sigma_.log();
59 | auto result = log_sigma + 0.5 * (1. + std::log(2. * M_PI));
60 |
61 | return result;
62 | }
63 |
64 | protected:
65 | torch::Tensor mu_;
66 | torch::Tensor sigma_;
67 | };
68 |
69 | } // namespace rl
70 |
71 | } // namespace torchfort
72 |
--------------------------------------------------------------------------------
/tests/supervised/scripts/setup_tests.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
def save_jit_module(module, fname):
    """Script ``module`` with TorchScript and save it to ``fname``.

    Tries to move the module to the GPU first so the saved parameters are CUDA
    tensors; falls back to the CPU when CUDA is unavailable.
    """
    try:
        module.to("cuda")
    except Exception:
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit;
        # torch raises a regular Exception when CUDA is not available.
        print("PyTorch does not have CUDA support. Saving on CPU.")
    module_jit = torch.jit.script(module)

    module_jit.save(fname)
12 |
13 | # Create simple models that just return input for testing
class Net1(torch.nn.Module):
    """Identity-like test model with a single input.

    The linear layer participates in the graph (so it has trainable
    parameters) but its output is scaled by zero, so the module always
    returns its input unchanged.
    """

    def __init__(self):
        super(Net1, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, input1):
        # Zero-weighted linear term keeps parameters in the autograd graph
        # without altering the returned values.
        return input1 + 0.0 * self.layer(input1)
22 |
class Net2(torch.nn.Module):
    """Identity-like test model with two inputs and two outputs.

    Both outputs equal the corresponding inputs; the shared linear layer only
    contributes a zero-scaled term so the module has trainable parameters.
    """

    def __init__(self):
        super(Net2, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, input1, input2):
        feat = self.layer(input1)
        return input1 + 0.0 * feat, input2 + 0.0 * feat
31 |
32 |
33 | # Create loss functions with various argument combinations
class Loss1(torch.nn.Module):
    """Toy scripted loss: scalar sum of all prediction and label entries."""

    def __init__(self):
        super(Loss1, self).__init__()

    def forward(self, prediction, label):
        return torch.sum(prediction) + torch.sum(label)
40 |
class Loss2(torch.nn.Module):
    """Toy scripted loss for multi-output models: sum over two predictions
    and two labels (left-to-right accumulation, matching the original)."""

    def __init__(self):
        super(Loss2, self).__init__()

    def forward(self, prediction1, prediction2, label1, label2):
        total = torch.sum(prediction1)
        total = total + torch.sum(prediction2)
        total = total + torch.sum(label1)
        total = total + torch.sum(label2)
        return total
47 |
class Loss2Extra(torch.nn.Module):
    """Toy scripted loss taking predictions, labels, AND extra arguments:
    scalar sum over all six tensors (left-to-right, matching the original)."""

    def __init__(self):
        super(Loss2Extra, self).__init__()

    def forward(self, prediction1, prediction2, label1, label2, extra_args1, extra_args2):
        total = torch.sum(prediction1)
        total = total + torch.sum(prediction2)
        total = total + torch.sum(label1)
        total = total + torch.sum(label2)
        total = total + torch.sum(extra_args1)
        total = total + torch.sum(extra_args2)
        return total
55 |
def main():
    """Instantiate every test network/loss and serialize each as TorchScript."""
    artifacts = {
        "model.pt": Net1(),
        "model_multiarg.pt": Net2(),
        "loss.pt": Loss1(),
        "loss_multiarg.pt": Loss2(),
        "loss_multiarg_extra.pt": Loss2Extra(),
    }
    # Insertion order preserves the original save order.
    for fname, module in artifacts.items():
        save_jit_module(module, fname)
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/env.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 | #include
21 | // #include
22 |
// Integration scheme used to advance the cart-pole dynamics.
enum IntegratorType { EXPLICIT_EULER, SEMI_IMPLICIT_EULER };
// Cart-pole state vector. The element type/extent were restored after
// extraction stripped the template arguments: the environment has 4 state
// features (the policy in initialize_models.py consumes (batch, 4) floats).
using StateVector = std::array<float, 4>;
25 |
// Classic cart-pole control environment (CartPole-style dynamics).
class CartPoleEnv {

public:
  CartPoleEnv();
  // Non-copyable: owns mutable RNG state.
  CartPoleEnv(const CartPoleEnv&) = delete;
  CartPoleEnv& operator=(const CartPoleEnv&) = delete;

  // Reset the simulation and return the initial state.
  StateVector reset();
  // NOTE(review): template arguments were lost in extraction; presumably
  // returns (lower, upper) bounds of the state space -- confirm in env.cpp.
  std::pair getStateBounds();

  // Advance the simulation one step with the given action.
  // NOTE(review): tuple element types were lost in extraction; likely
  // (next_state, reward, terminated) -- confirm in env.cpp.
  std::tuple step(float action);

private:
  // sim parameters
  bool terminated_;
  float gravity_;
  float masscart_;
  float masspole_;
  float total_mass_;
  float length_;
  float polemass_length_;
  float force_mag_;
  float dt_;
  float penalty_;
  IntegratorType kinematics_integrator_;
  int steps_beyond_terminated_;

  // threshold parameters (episode terminates when exceeded)
  float theta_threshold_radians_;
  float x_threshold_;

  // random stuff
  std::mt19937_64 rng_;
  std::uniform_real_distribution uniform_dist_;

  // state vector
  StateVector state_;
};
64 |
65 | // pybind11 stuff
66 | // namespace py = pybind11;
67 | // PYBIND11_MODULE(environments, m) {
68 | // py::class_(m, "CartPoleEnv", py::dynamic_attr())
69 | // .def(py::init<>())
70 | // .def("step", &CartPoleEnv::step, py::arg("action"))
71 | // .def("reset", &CartPoleEnv::reset);
72 | //}
73 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/python/initialize_models.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import argparse as ap
18 | import math
19 | import torch
20 | from functools import partial
21 | from torch import nn
22 | import torch.nn.functional as F
23 |
24 | from models import weight_init, PolicyFunc, ValueFunc
25 |
def main(args):
    """Initialize the cart-pole policy/value networks, script them with
    TorchScript, run a shape sanity check, and save them to disk."""

    # seed for reproducible weight initialization
    torch.manual_seed(666)

    # prefer the GPU when one is present
    if torch.cuda.is_available():
        torch.cuda.manual_seed(666)
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    batch_size = 64

    # policy network: maps the 4-feature state to an action
    policy = PolicyFunc(hidden_features=args.num_hidden_features).to(device)
    weight_init(policy)
    policy_scripted = torch.jit.script(policy)
    state = torch.ones((batch_size, 4), dtype=torch.float32, device=device)
    out = policy_scripted(state)
    print("Policy model:", policy)
    print("Policy model output shape:", out.shape)
    torch.jit.save(policy_scripted, "policy.pt")

    # value network: maps (state, action) to a scalar value estimate
    value = ValueFunc(hidden_features=args.num_hidden_features).to(device)
    weight_init(value)
    value_scripted = torch.jit.script(value)
    action = torch.ones((batch_size, 1), dtype=torch.float32, device=device)
    out = value_scripted(state, action)
    print("Value model:", value)
    print("Value model output shape:", out.shape)
    torch.jit.save(value_scripted, "value.pt")
60 |
61 | if __name__ == "__main__":
62 | parser = ap.ArgumentParser()
63 | parser.add_argument("--num_hidden_features", type=int, default=128, help="Number of hidden features")
64 | args = parser.parse_args()
65 |
66 | main(args)
67 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/base_lr_scheduler.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 | #include
20 | #include
21 | #include
22 |
23 | #include
24 |
25 | #include "internal/exceptions.h"
26 |
27 | namespace torchfort {
28 |
29 | class BaseLRScheduler : public torch::optim::LRScheduler {
30 | public:
31 | BaseLRScheduler(torch::optim::Optimizer& optimizer) : LRScheduler(optimizer) {}
32 |
33 | // Define generic save/load functionalities. Specialize in derived schedulers if
34 | // needed.
35 | void save(const std::string& fname) {
36 | torch::serialize::OutputArchive archive;
37 | archive.write("step_count", torch::IValue((int64_t)step_count_));
38 | archive.write("lrs", torch::IValue(get_current_lrs()));
39 | archive.save_to(fname);
40 | }
41 | void load(const std::string& fname, torch::optim::Optimizer& optimizer) {
42 | torch::serialize::InputArchive archive;
43 | archive.load_from(fname);
44 |
45 | torch::IValue ivalue;
46 | if (!archive.try_read("step_count", ivalue)) {
47 | THROW_INVALID_USAGE(fname + " is missing required data.");
48 | }
49 | int64_t step_count = ivalue.to();
50 | step_count_ = step_count;
51 |
52 | if (!archive.try_read("lrs", ivalue)) {
53 | THROW_INVALID_USAGE(fname + " is missing required data.");
54 | }
55 | auto lrs = ivalue.to>();
56 | // Can't use this method to set the LRs due to it being private in the base LR class.
57 | // set_optimizer_lrs(lrs);
58 | for (const auto i : c10::irange(optimizer.param_groups().size())) {
59 | optimizer.param_groups()[i].options().set_lr(lrs[i]);
60 | }
61 | }
62 | };
63 |
64 | } // namespace torchfort
65 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/losses.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include
21 | #include
22 |
23 | #include
24 | #include
25 | #include
26 |
27 | #include "internal/base_loss.h"
28 | #include "internal/defines.h"
29 | #include "internal/param_map.h"
30 |
31 | namespace torchfort {
32 |
// L1 (mean absolute error) loss, wrapping torch::nn::L1Loss.
struct L1Loss : BaseLoss {
  // Configure the loss from user-supplied config parameters.
  void setup(const ParamMap& params) override;

  // Compute the loss from model outputs and labels. extra_args is part of the
  // common loss interface shared by all loss types.
  torch::Tensor forward(const std::vector& inputs, const std::vector& labels,
                        const std::vector& extra_args) override;

  torch::nn::L1Loss module;
};
41 |
// Mean squared error loss, wrapping torch::nn::MSELoss.
struct MSELoss : BaseLoss {
  // Configure the loss from user-supplied config parameters.
  void setup(const ParamMap& params) override;

  // Compute the loss from model outputs and labels. extra_args is part of the
  // common loss interface shared by all loss types.
  torch::Tensor forward(const std::vector& inputs, const std::vector& labels,
                        const std::vector& extra_args) override;

  torch::nn::MSELoss module;
};
50 |
// User-defined loss loaded from a TorchScript module file.
struct TorchscriptLoss : BaseLoss {
  // Load the TorchScript module named by the "filename" parameter.
  void setup(const ParamMap& params) override;

  // Forward inputs, labels, and extra_args (in that order) to the scripted
  // module and return its (single-tensor) result.
  torch::Tensor forward(const std::vector& inputs, const std::vector& labels,
                        const std::vector& extra_args) override;

  // Loaded TorchScript module implementing the loss.
  std::shared_ptr module_jit;
};
59 |
60 | // Creating loss_registry.
61 | BEGIN_LOSS_REGISTRY
62 |
63 | // Add entries for new losses in this section. First argument to REGISTER_LOSS is
64 | // a string key and the second argument is the class name.
65 | REGISTER_LOSS(L1, L1Loss)
66 | REGISTER_LOSS(MSE, MSELoss)
67 | REGISTER_LOSS(torchscript, TorchscriptLoss)
68 |
69 | END_LOSS_REGISTRY
70 |
71 | } // namespace torchfort
72 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/training.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #ifdef ENABLE_GPU
21 | #include
22 | #endif
23 |
24 | #include
25 |
26 | #include
27 |
28 | namespace torchfort {
29 |
30 | void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in,
31 | cudaStream_t ext_stream = 0);
32 |
33 | void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t labels_in,
34 | float* loss_val, torchfort_tensor_list_t extra_loss_args_in, cudaStream_t ext_stream = 0);
35 |
// Run inference on the named model with a single input/output tensor pair.
// Thin wrapper that packs the raw pointers into TensorLists and forwards to
// inference_multiarg. ext_stream is an external CUDA stream to coordinate
// with (NOTE(review): presumably ignored/stubbed in CPU-only builds — confirm
// against the cudaStream_t definition used when ENABLE_GPU is off).
template
void inference(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* output, size_t output_dim,
               int64_t* output_shape, cudaStream_t ext_stream = 0) {
  TensorList inputs, outputs;

  inputs.add_tensor(input, input_dim, input_shape);
  outputs.add_tensor(output, output_dim, output_shape);

  inference_multiarg(name, &inputs, &outputs, ext_stream);
}
46 |
// Run one training step on the named model with a single input/label tensor
// pair. Packs the raw pointers into TensorLists and forwards to
// train_multiarg (with no extra loss arguments). The resulting loss is
// written through loss_val, converted from the float the multiarg API yields.
template
void train(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* label, size_t label_dim,
           int64_t* label_shape, T* loss_val, cudaStream_t ext_stream = 0) {
  TensorList inputs, labels;

  inputs.add_tensor(input, input_dim, input_shape);
  labels.add_tensor(label, label_dim, label_shape);

  // multiarg API expects float loss value, so use temporary here
  float loss_val_tmp;
  train_multiarg(name, &inputs, &labels, &loss_val_tmp, nullptr, ext_stream);
  *loss_val = loss_val_tmp;
}
60 |
61 | } // namespace torchfort
62 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## TorchFort Contribution Guide
2 | We welcome and appreciate external contributions to TorchFort! File a [pull request](https://github.com/NVIDIA/TorchFort/compare)
3 | to get the process started.
4 |
5 | ### Signing Your Work
6 |
7 | - We require that all contributors "sign-off" on their commits. This certifies that the
8 | contribution is your original work, or you have rights to submit it under the same
9 | license, or a compatible license.
10 |
11 | - Any contribution which contains commits that are not Signed-Off will not be accepted.
12 |
13 | - To sign off on a commit you simply use the `--signoff` (or `-s`) option when
14 | committing your changes:
15 |
16 | ```bash
17 | git commit -s -m "Add cool feature."
18 | ```
19 |
20 | This will append the following to your commit message:
21 |
22 | ```text
Signed-off-by: Your Name <your.email@example.com>
24 | ```
25 |
26 | - Full text of the DCO:
27 |
28 | ```text
29 | Developer Certificate of Origin
30 | Version 1.1
31 |
32 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
33 | 1 Letterman Drive
34 | Suite D4700
35 | San Francisco, CA, 94129
36 |
37 | Everyone is permitted to copy and distribute verbatim copies of this license
38 | document, but changing it is not allowed.
39 | ```
40 |
41 | ```text
42 | Developer's Certificate of Origin 1.1
43 |
44 | By making a contribution to this project, I certify that:
45 |
46 | (a) The contribution was created in whole or in part by me and I have the right to
47 | submit it under the open source license indicated in the file; or
48 |
49 | (b) The contribution is based upon previous work that, to the best of my knowledge,
50 | is covered under an appropriate open source license and I have the right under that
51 | license to submit that work with modifications, whether created in whole or in part
52 | by me, under the same open source license (unless I am permitted to submit under a
53 | different license), as indicated in the file; or
54 |
55 | (c) The contribution was provided directly to me by some other person who certified
56 | (a), (b) or (c) and I have not modified it.
57 |
58 | (d) I understand and agree that this project and the contribution are public and
59 | that a record of the contribution (including all personal information I submit with
60 | it, including my sign-off) is maintained indefinitely and may be redistributed
61 | consistent with this project or the open source license(s) involved.
62 |
63 | ```
64 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/CMakeLists.txt:
--------------------------------------------------------------------------------
# Example executables for the cart-pole reinforcement learning example.
set(cart_pole_example_targets
  train_cart_pole
)

# Environment implementation, shared by the C++ trainer and the Python bindings.
add_library(environments STATIC)
target_sources(environments
  PRIVATE
  env.cpp
)
# PIC is required because this static library is linked into a Python extension module.
set_property(TARGET environments PROPERTY POSITION_INDEPENDENT_CODE ON)

add_executable(train_cart_pole)
target_sources(train_cart_pole
  PRIVATE
  train.cpp
)

# Python bindings for the environment.
find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
find_package(pybind11 CONFIG REQUIRED)
pybind11_add_module(PyEnvironments py_env.cpp)
target_link_libraries(PyEnvironments PRIVATE environments)

foreach(tgt ${cart_pole_example_targets})
  target_include_directories(${tgt}
    PRIVATE
    ${YAML_CPP_INCLUDE_DIR}
    ${MPI_CXX_INCLUDE_DIRS}
    ${CUDAToolkit_INCLUDE_DIRS}
    ${Python_INCLUDE_DIRS}
    ${CMAKE_BINARY_DIR}/include
  )
  target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
  target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES})
  target_link_libraries(${tgt} PRIVATE ${Python_LIBRARIES})
  target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX)
  target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY})
  target_link_libraries(${tgt} PRIVATE environments)
  # Apply the Torch-required C++ flags only to C++ compilation/linking.
  target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
  target_link_options(${tgt} PRIVATE $<$<LINK_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
  if (TORCHFORT_ENABLE_GPU)
    target_include_directories(${tgt}
      PRIVATE
      ${CUDAToolkit_INCLUDE_DIRS}
    )
    target_link_libraries(${tgt} PRIVATE CUDA::cudart)
    target_compile_definitions(${tgt} PRIVATE ENABLE_GPU)
  endif()
endforeach()

# Installation. Destinations are relative so they compose with
# CMAKE_INSTALL_PREFIX, DESTDIR, and packaging tools (absolute destinations
# that embed CMAKE_INSTALL_PREFIX break staged installs).
# executable
install(
  TARGETS ${cart_pole_example_targets}
  RUNTIME DESTINATION bin/examples/cpp/cart_pole
)

# python env
install(
  TARGETS PyEnvironments
  DESTINATION bin/examples/cpp/cart_pole/python
)

# config files
install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml ${CMAKE_CURRENT_SOURCE_DIR}/config_sim.yaml
  DESTINATION bin/examples/cpp/cart_pole
)

# python files
install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/initialize_models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/visualize.py
  DESTINATION bin/examples/cpp/cart_pole/python
)
73 | )
74 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/logging.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include
21 | #include
22 |
23 | #include "internal/model_pack.h"
24 |
25 | namespace torchfort {
26 |
27 | namespace logging {
28 |
29 | enum level { info, warn, error, wandb };
30 |
31 | std::string log_level_prefix(level log_level);
32 | void write(const std::filesystem::path& filename, const std::string& message, level log_level);
33 | void print(const std::string& message, level log_level);
34 |
35 | } // namespace logging
36 |
37 | // Declaration of external global variables
38 | extern std::unordered_map models;
39 |
40 | // specialized logging routines
// Write a wandb-style metric line ("model: <name>, step: <step>, <metric>: <value>")
// to the state's report file, if the wandb hook is enabled on the given state.
// With a communicator present, only rank 0 writes.
template
void wandb_log(std::shared_ptr state, std::shared_ptr comm, const char* name, const char* metric_name,
               int64_t step, T value) {
  if (state->enable_wandb_hook) {
    std::stringstream os;
    os << "model: " << name << ", ";
    os << "step: " << step << ", ";
    os << metric_name << ": " << value;
    if (!comm || (comm && comm->rank == 0)) {
      torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb);
    }
  }
}
54 |
// Convenience overload: look up the model's state and communicator in the
// global `models` registry by name, then write the same wandb-style metric
// line (rank 0 only when a communicator is attached).
template void wandb_log(const char* name, const char* metric_name, int64_t step, T value) {
  auto state = models[name].state.get();
  if (state->enable_wandb_hook) {
    std::stringstream os;
    os << "model: " << name << ", ";
    os << "step: " << step << ", ";
    os << metric_name << ": " << value;
    if (!models[name].comm || (models[name].comm && models[name].comm->rank == 0)) {
      torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb);
    }
  }
}
67 |
68 | } // namespace torchfort
69 |
--------------------------------------------------------------------------------
/docker/Dockerfile_gnu:
--------------------------------------------------------------------------------
# GNU-toolchain build image: CUDA devel base + gfortran + HPCX (OpenMPI).
FROM nvcr.io/nvidia/cuda:12.9.1-devel-ubuntu22.04

SHELL ["/bin/bash", "-c"]

# NOTE(review): "ENV KEY value" is legacy syntax; "ENV KEY=value" is the
# preferred modern form — behavior is identical.
ENV CUDA_HOME /usr/local/cuda

# Install System Dependencies
ENV DEBIAN_FRONTEND noninteractive
RUN apt update -y && \
    apt install -y wget cmake && \
    apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
    apt install -y vim gfortran git && \
    apt install -y libibverbs-dev ibverbs-utils numactl

# Download HPCX (OpenMPI/UCX bundle); the archive is removed afterwards to
# keep the image layer small.
RUN cd /opt && \
    wget https://content.mellanox.com/hpc/hpc-x/v2.24.1_cuda12/hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \
    tar xjf hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \
    mv hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64 hpcx && \
    cd /opt && rm hpcx-v2.24.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz

ENV PATH /opt/hpcx/ompi/bin:$PATH
ENV LD_LIBRARY_PATH /opt/hpcx/ompi/lib:$LD_LIBRARY_PATH

# Load the HPCX environment automatically in interactive shells.
RUN echo "source /opt/hpcx/hpcx-init.sh; hpcx_load" >> /root/.bashrc

# Install PyTorch
RUN pip3 install --no-deps torch==2.8.0 --index-url https://download.pytorch.org/whl/cu129 && \
    pip3 install --no-deps nvidia-cudnn-cu12==9.10.2.21 nvidia-cusparselt-cu12==0.7.1 nvidia-cufile-cu12==1.14.1.1

# Install yaml-cpp (static, PIC, CXX11 ABI to match libtorch)
RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
    cd yaml-cpp && \
    mkdir build && cd build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \
    -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \
    -DBUILD_SHARED_LIBS=OFF \
    -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \
    make -j$(nproc) && make install
ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH}

# Install additional Python dependencies
RUN pip3 install wandb ruamel-yaml matplotlib pygame moviepy

# Install TorchFort
ENV FC=gfortran
COPY . /torchfort
RUN source /opt/hpcx/hpcx-init.sh && hpcx_load && \
    cd /torchfort && mkdir build && cd build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
    -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
    -DTORCHFORT_BUILD_EXAMPLES=1 \
    -DTORCHFORT_BUILD_TESTS=1 \
    -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
    .. && \
    make -j$(nproc) install && \
    cd / && rm -rf torchfort
ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH}
ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
60 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
# NVHPC-toolchain build image (nvfortran) for TorchFort.
FROM nvcr.io/nvidia/nvhpc:25.7-devel-cuda12.9-ubuntu22.04

SHELL ["/bin/bash", "-c"]

ENV CUDA_HOME /opt/nvidia/hpc_sdk/Linux_x86_64/25.7/cuda
ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.7/cuda/12.9/extras/CUPTI/lib64:$LD_LIBRARY_PATH

# Install System Dependencies
# git is required by the yaml-cpp clone step below (the GNU Dockerfile installs
# it explicitly as well; do not rely on the base image providing it).
ENV DEBIAN_FRONTEND noninteractive
RUN apt update -y && \
    apt install -y wget cmake git && \
    apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
    apt install -y vim gfortran

# Install PyTorch (and select dependencies)
RUN pip3 install --no-deps torch==2.8.0 --index-url https://download.pytorch.org/whl/cu129 && \
    pip3 install --no-deps nvidia-cudnn-cu12==9.10.2.21 nvidia-cusparselt-cu12==0.7.1 nvidia-cufile-cu12==1.14.1.1 nvidia-nccl-cu12==2.27.3

# Remove conflicting NCCL version from NVHPC SDK, add libnccl.so symlink to pip installed NCCL
RUN rm -rf /opt/nvidia/hpc_sdk/Linux_x86_64/25.7/comm_libs/nccl/ && \
    cd /usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib && \
    ln -s libnccl.so.2 libnccl.so
ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib:$LD_LIBRARY_PATH

# Install yaml-cpp (static, PIC, CXX11 ABI to match libtorch)
RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
    cd yaml-cpp && \
    mkdir build && cd build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \
    -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \
    -DBUILD_SHARED_LIBS=OFF \
    -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \
    make -j$(nproc) && make install
ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH}

# Install additional Python dependencies
RUN pip3 install wandb ruamel-yaml matplotlib pygame moviepy

# Install TorchFort
ENV FC=nvfortran
COPY . /torchfort
RUN cd /torchfort && mkdir build && cd build && \
    CUDA_PATH=$CUDA_HOME \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
    -DCMAKE_CXX_COMPILER=`which g++` \
    -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
    -DTORCHFORT_NCCL_ROOT=/usr/local/lib/python3.10/dist-packages/nvidia/nccl \
    -DTORCHFORT_BUILD_EXAMPLES=1 \
    -DTORCHFORT_BUILD_TESTS=1 \
    -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
    .. && \
    make -j$(nproc) install && \
    cd / && rm -rf torchfort
ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH}
ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
56 |
--------------------------------------------------------------------------------
/tests/general/scripts/setup_tests.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
def save_jit_module(module, fname):
    """Script `module` with TorchScript and save it to `fname`.

    Moves the module to the GPU first when CUDA is available, so the saved
    parameters are CUDA tensors; otherwise the module is saved with CPU
    parameters.

    Args:
        module: the torch.nn.Module to script and save.
        fname: destination path for the serialized TorchScript module.
    """
    # Explicit capability check instead of a bare `except:` around .to("cuda"):
    # the bare except also swallowed KeyboardInterrupt/SystemExit and any
    # unrelated error raised by the move.
    if torch.cuda.is_available():
        module.to("cuda")
    else:
        print("PyTorch does not have CUDA support. Saving on CPU.")
    module_jit = torch.jit.script(module)

    module_jit.save(fname)
12 |
13 | # Create simple models that just return input for testing
class Net1(torch.nn.Module):
    """Single-input identity-like test model.

    Returns its input unchanged, plus a zeroed linear-layer term so the model
    still has trainable parameters in the graph.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, input1):
        # Multiply by 0.0 so the output equals input1 while keeping the
        # linear layer connected to the graph.
        zeroed = self.layer(input1) * 0.0
        return input1 + zeroed
22 |
class Net2(torch.nn.Module):
    """Two-input identity-like test model.

    Returns both inputs unchanged, each with a zeroed linear-layer term added
    so the parameters stay in the graph.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, input1, input2):
        zeroed = 0.0 * self.layer(input1)
        return input1 + zeroed, input2 + zeroed
31 |
32 |
33 | # Create loss functions with various argument combinations
class Loss1(torch.nn.Module):
    """Scalar test loss: mean of all prediction and label entries combined."""

    def forward(self, prediction, label):
        combined = torch.sum(prediction) + torch.sum(label)
        return combined / (2 * prediction.numel())
40 |
class Loss2(torch.nn.Module):
    """Scalar test loss over two predictions and two labels (combined mean)."""

    def forward(self, prediction1, prediction2, label1, label2):
        pred_sum = torch.sum(prediction1) + torch.sum(prediction2)
        label_sum = torch.sum(label1) + torch.sum(label2)
        return (pred_sum + label_sum) / (4 * prediction1.numel())
47 |
class Loss2Extra(torch.nn.Module):
    """Scalar test loss over two predictions, two labels, and two extra tensors."""

    def forward(self, prediction1, prediction2, label1, label2, extra_args1, extra_args2):
        pred_sum = torch.sum(prediction1) + torch.sum(prediction2)
        label_sum = torch.sum(label1) + torch.sum(label2)
        extra_sum = torch.sum(extra_args1) + torch.sum(extra_args2)
        return (pred_sum + label_sum + extra_sum) / (6 * prediction1.numel())
55 |
class Loss3(torch.nn.Module):
    """Multi-output test 'loss': returns the raw sums of prediction and label."""

    def forward(self, prediction, label):
        pred_sum = torch.sum(prediction)
        label_sum = torch.sum(label)
        return pred_sum, label_sum
62 |
def main():
    """Generate and save all TorchScript test modules used by the tests."""
    # (module, output filename) pairs to export, in the original order.
    exports = [
        (Net1(), "model.pt"),
        (Net2(), "model_multiarg.pt"),
        (Loss1(), "loss.pt"),
        (Loss2(), "loss_multiarg.pt"),
        (Loss2Extra(), "loss_multiarg_extra.pt"),
        (Loss3(), "loss_multiout.pt"),
    ]
    for module, fname in exports:
        save_jit_module(module, fname)

if __name__ == "__main__":
    main()
80 |
--------------------------------------------------------------------------------
/src/csrc/losses/torchscript_loss.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 |
25 | #include
26 | #include
27 | #include
28 |
29 | #include "internal/exceptions.h"
30 | #include "internal/losses.h"
31 | #include "internal/param_map.h"
32 | #include "internal/setup.h"
33 | #include "internal/utils.h"
34 |
35 | namespace torchfort {
36 |
37 | void TorchscriptLoss::setup(const ParamMap& params) {
38 | std::string jit_loss_fname;
39 | try {
40 | jit_loss_fname = params.get_param("filename")[0];
41 | } catch (std::out_of_range) {
42 | THROW_INVALID_USAGE("filename parameter is required for torchscript loss type.");
43 | }
44 |
45 | if (!std::filesystem::exists(jit_loss_fname)) {
46 | THROW_INVALID_USAGE(jit_loss_fname + " does not exist.");
47 | }
48 |
49 | module_jit = std::shared_ptr(new torch::jit::Module);
50 | *module_jit = torch::jit::load(jit_loss_fname);
51 | }
52 |
// Evaluate the TorchScript loss. Inputs, labels, and extra args are
// concatenated (in that order) into the positional argument list of the
// scripted module's forward. The module must return a single tensor.
torch::Tensor TorchscriptLoss::forward(const std::vector& inputs,
                                       const std::vector& labels,
                                       const std::vector& extra_args) {
  std::vector inputs_jit;
  inputs_jit.insert(inputs_jit.end(), inputs.begin(), inputs.end());
  inputs_jit.insert(inputs_jit.end(), labels.begin(), labels.end());
  inputs_jit.insert(inputs_jit.end(), extra_args.begin(), extra_args.end());

  auto result = module_jit->forward(inputs_jit);
  // Multi-output TorchScript losses are not supported through this path.
  if (!result.isTensor()) {
    THROW_INVALID_USAGE("TorchscriptLoss only supports returning a single loss tensor.");
  }
  return result.toTensor();
}
67 |
68 | } // namespace torchfort
69 |
--------------------------------------------------------------------------------
/src/csrc/models/mlp_model.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include
19 |
20 | #include
21 |
22 | #include "internal/models.h"
23 | #include "internal/param_map.h"
24 | #include "internal/setup.h"
25 |
26 | namespace torchfort {
27 |
// MLP model in C++ using libtorch
// Build the MLP from configuration parameters:
//   - layer_sizes: layer widths; N entries produce N-1 linear layers
//   - dropout: dropout probability for hidden layers (default 0.0)
//   - flatten_non_batch_dims: flatten inputs to [batch, features] (default true)
void MLPModel::setup(const ParamMap& params) {
  // Extract params from input map.
  std::set supported_params{"dropout", "flatten_non_batch_dims", "layer_sizes"};
  check_params(supported_params, params.keys());

  dropout = params.get_param("dropout", 0.0)[0];
  flatten_non_batch_dims = params.get_param("flatten_non_batch_dims", true)[0];
  layer_sizes = params.get_param("layer_sizes");

  // Construct and register submodules.
  // NOTE(review): assumes layer_sizes has at least two entries; an empty list
  // would make the unsigned size() - 1 comparison wrap — confirm validation upstream.
  for (int i = 0; i < layer_sizes.size() - 1; ++i) {
    fc_layers.push_back(
        register_module("fc" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1])));
    // Hidden layers get an extra registered bias parameter in addition to the
    // Linear module's own bias (matches the forward pass below).
    if (i < layer_sizes.size() - 2) {
      biases.push_back(register_parameter("b" + std::to_string(i), torch::zeros(layer_sizes[i + 1])));
    }
  }
}
47 |
// Implement the forward function.
// Forward pass: optional flatten, hidden layers (linear + extra bias, relu,
// dropout), and a final linear layer with no activation.
std::vector MLPModel::forward(const std::vector& inputs) {
  if (inputs.size() > 1)
    THROW_INVALID_USAGE("Built-in MLP model does not support multiple input tensors.");

  auto x = inputs[0];

  // Collapse all non-batch dimensions into a single feature dimension.
  if (flatten_non_batch_dims) {
    x = x.reshape({x.size(0), -1});
  }

  for (int i = 0; i < layer_sizes.size() - 1; ++i) {
    if (i < layer_sizes.size() - 2) {
      // Hidden layer: dropout is active only while the module is in training mode.
      x = torch::relu(fc_layers[i]->forward(x) + biases[i]);
      x = torch::dropout(x, dropout, is_training());
    } else {
      // Output layer: plain linear, no activation.
      x = fc_layers[i]->forward(x);
    }
  }
  return std::vector{x};
}
69 |
70 | } // namespace torchfort
71 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/lr_schedulers.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 | #include
20 |
21 | #include
22 |
23 | #include "internal/base_lr_scheduler.h"
24 |
25 | namespace torchfort {
26 |
27 | class CosineAnnealingLR : public BaseLRScheduler {
28 | public:
29 | CosineAnnealingLR(torch::optim::Optimizer& optimizer, const unsigned T_max, const double eta_min = 0.0);
30 |
31 | private:
32 | std::vector get_lrs() override;
33 | double update_lr(const double& last_lr, const double& base_lr);
34 |
35 | const unsigned T_max_;
36 | const double eta_min_;
37 | std::vector base_lrs_;
38 | };
39 |
40 | class MultiStepLR : public BaseLRScheduler {
41 | public:
42 | MultiStepLR(torch::optim::Optimizer& optimizer, const std::vector& milestones, const double gamma = 0.1);
43 |
44 | private:
45 | std::vector get_lrs() override;
46 |
47 | const std::vector milestones_;
48 | const double gamma_;
49 | };
50 |
51 | class PolynomialLR : public BaseLRScheduler {
52 | public:
53 | PolynomialLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double power = 1.0);
54 |
55 | private:
56 | std::vector get_lrs() override;
57 |
58 | const unsigned total_iters_;
59 | const double power_;
60 | };
61 |
62 | class StepLR : public BaseLRScheduler {
63 | public:
64 | StepLR(torch::optim::Optimizer& optimizer, const unsigned step_size, const double gamma = 0.1);
65 |
66 | private:
67 | std::vector get_lrs() override;
68 |
69 | const unsigned step_size_;
70 | const double gamma_;
71 | };
72 |
73 | class LinearLR : public BaseLRScheduler {
74 | public:
75 | LinearLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double start_factor = 0.333,
76 | const double end_factor = 1.0);
77 |
78 | private:
79 | std::vector get_lrs() override;
80 |
81 | const unsigned total_iters_;
82 | const double start_factor_, end_factor_;
83 | };
84 |
85 | } // namespace torchfort
86 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# NOTE(review): only referenced in the commented-out html_theme_path below;
# presumably imported to fail fast when the theme package is missing — confirm
# before removing.
import sphinx_rtd_theme

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))


# -- Project information -----------------------------------------------------

project = 'TorchFort'
copyright = '2023-2025, NVIDIA Corporation'
author = 'NVIDIA Corporation'

# Version string comes from the CI environment (git SHA), not a release tag.
version = os.getenv("TORCHFORT_GIT_SHA", default="N/A")
#release = version


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'breathe',
    'sphinx.ext.mathjax',
    'sphinx_tabs.tabs',
    'sphinxfortran.fortran_domain',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
#pygments_style = 'sphinx'

highlight_language = 'cpp'

def setup(app):
    # Register the project stylesheet on top of the theme defaults.
    app.add_css_file('style.css')

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
#html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

html_theme_options = {
    "collapse_navigation" : False,
    "navigation_depth" : 6,
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Breathe bridges Doxygen XML output into Sphinx; the C API docs come from here.
breathe_projects = { "torchfort": "xml/" }
breathe_default_project = "torchfort"
78 |
--------------------------------------------------------------------------------
/src/csrc/logging.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include <filesystem>
19 | #include <fstream>
20 | #include <iostream>
21 | #include <memory>
22 | #include <mutex>
23 | #include <string>
24 |
25 | #include "internal/exceptions.h"
26 | #include "internal/logging.h"
27 |
28 | namespace torchfort {
29 | namespace logging {
30 |
31 | std::mutex logging_mutex;
32 | static std::unique_ptr<std::ofstream> logfile;
33 |
34 | std::string log_level_prefix(level log_level) {
35 | if (log_level == level::info) {
36 | return "TORCHFORT::INFO:";
37 | } else if (log_level == level::warn) {
38 | return "TORCHFORT::WARN:";
39 | } else if (log_level == level::error) {
40 | return "TORCHFORT::ERROR:";
41 | } else if (log_level == level::wandb) {
42 | return "TORCHFORT::WANDB:";
43 | } else {
44 | THROW_INVALID_USAGE("Unknown log level encountered.");
45 | }
46 | }
47 |
48 | void print(const std::string& message, level log_level) {
49 | std::cout << log_level_prefix(log_level) << " ";
50 | std::cout << message << std::endl;
51 | }
52 |
53 | bool open_logfile(const std::filesystem::path& filename) {
54 |
55 | // check if filename is empty, meaning we do not want to log
56 | if (filename.empty()) {
57 | return false;
58 | }
59 |
60 | // check if path exists
61 | if (filename.has_parent_path()) {
62 | auto path = filename.parent_path();
63 | std::filesystem::create_directories(path);
64 | }
65 |
66 | logfile = std::make_unique<std::ofstream>(filename, std::ofstream::out | std::ofstream::app);
67 |
68 | return true;
69 | }
70 |
71 | void write(const std::filesystem::path& filename, const std::string& message, level log_level) {
72 | std::lock_guard<std::mutex> guard(logging_mutex);
73 |
74 | // check if the logfile is already open
75 | if (logfile == nullptr) {
76 | if (!open_logfile(filename))
77 | return;
78 | }
79 | auto line = log_level_prefix(log_level) + " " + message + "\n";
80 | logfile->write(line.c_str(), line.size());
81 | logfile->flush();
82 | }
83 |
84 | } // namespace logging
85 | } // namespace torchfort
86 |
--------------------------------------------------------------------------------
/tests/general/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.14)
2 |
3 | set(test_targets
4 | test_losses
5 | )
6 |
7 | add_executable(test_losses)
8 | target_sources(test_losses
9 | PRIVATE
10 | test_losses.cpp
11 | )
12 |
13 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
14 |
15 | foreach(tgt ${test_targets})
16 | target_include_directories(${tgt}
17 | PRIVATE
18 | ${YAML_CPP_INCLUDE_DIR}
19 | ${MPI_CXX_INCLUDE_DIRS}
20 | ${CMAKE_BINARY_DIR}/include
21 | ${CMAKE_CURRENT_SOURCE_DIR}/../
22 | )
23 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
24 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES})
25 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY})
26 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX)
27 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main)
28 | target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
29 | target_link_options(${tgt} PRIVATE $<$<LINK_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
30 | if (TORCHFORT_ENABLE_GPU)
31 | target_include_directories(${tgt}
32 | PRIVATE
33 | ${CUDAToolkit_INCLUDE_DIRS}
34 | ${NCCL_INCLUDE_DIR}
35 | )
36 | target_link_libraries(${tgt} PRIVATE CUDA::cudart)
37 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU)
38 | endif()
39 |
40 | # discover tests: we have an issue with the work dir of gtest so disable that for now
41 | #gtest_discover_tests(${tgt})
42 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
43 | endforeach()
44 |
45 | # installation
46 | # executable
47 | install(
48 | TARGETS ${test_targets}
49 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general
50 | )
51 |
52 | # copy files
53 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mse.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
54 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mse_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
55 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/l1.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
56 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/l1_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
57 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
58 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
59 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
60 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiout.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs)
61 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup_tests.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/scripts)
62 |
--------------------------------------------------------------------------------
/tests/rl/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.14)
2 |
3 | set(test_targets
4 | test_replay_buffer
5 | test_rollout_buffer
6 | test_distributions
7 | test_interface
8 | test_off_policy
9 | test_on_policy
10 | )
11 |
12 | add_executable(test_replay_buffer)
13 | target_sources(test_replay_buffer
14 | PRIVATE
15 | test_replay_buffer.cpp
16 | )
17 |
18 | add_executable(test_rollout_buffer)
19 | target_sources(test_rollout_buffer
20 | PRIVATE
21 | test_rollout_buffer.cpp
22 | )
23 |
24 | add_executable(test_distributions)
25 | target_sources(test_distributions
26 | PRIVATE
27 | test_distributions.cpp
28 | )
29 |
30 | add_executable(test_interface)
31 | target_sources(test_interface
32 | PRIVATE
33 | test_interface.cpp
34 | )
35 |
36 | add_executable(test_off_policy)
37 | target_sources(test_off_policy
38 | PRIVATE
39 | test_off_policy.cpp
40 | )
41 |
42 | add_executable(test_on_policy)
43 | target_sources(test_on_policy
44 | PRIVATE
45 | test_on_policy.cpp
46 | )
47 |
48 |
49 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
50 |
51 | foreach(tgt ${test_targets})
52 | target_include_directories(${tgt}
53 | PRIVATE
54 | ${YAML_CPP_INCLUDE_DIR}
55 | ${MPI_CXX_INCLUDE_DIRS}
56 | ${CMAKE_BINARY_DIR}/include
57 | )
58 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
59 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES})
60 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY})
61 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX)
62 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main)
63 | target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
64 | target_link_options(${tgt} PRIVATE $<$<LINK_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
65 | if (TORCHFORT_ENABLE_GPU)
66 | target_include_directories(${tgt}
67 | PRIVATE
68 | ${CUDAToolkit_INCLUDE_DIRS}
69 | ${NCCL_INCLUDE_DIR}
70 | )
71 | target_link_libraries(${tgt} PRIVATE CUDA::cudart)
72 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU)
73 | endif()
74 |
75 | # discover tests: we have an issue with the work dir of gtest so disable that for now
76 | #gtest_discover_tests(${tgt})
77 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
78 | endforeach()
79 |
80 | # installation
81 | # executable
82 | install(
83 | TARGETS ${test_targets}
84 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl
85 | )
86 |
87 | # copy files
88 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/td3.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs)
89 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/ddpg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs)
90 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/sac.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs)
91 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/ppo.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs)
92 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/python/models.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import math
18 | import torch
19 | from torch import nn
20 | import torch.nn.functional as F
21 |
22 | def weight_init(model, scale=0.02):
23 | with torch.no_grad():
24 | for m in model.modules():
25 | if isinstance(m, nn.Linear):
26 | sqrtk = math.sqrt(1./float(m.weight.shape[1]))
27 | nn.init.uniform_(m.weight, a=-sqrtk, b=sqrtk)
28 | if m.bias is not None:
29 | m.bias.data.zero_()
30 |
31 | class PolicyFunc(nn.Module):
32 | def __init__(self, hidden_features=128):
33 | super(PolicyFunc, self).__init__()
34 |
35 | layers = [nn.Linear(in_features = 4,
36 | out_features = hidden_features,
37 | bias=True),
38 | nn.ReLU(),
39 | nn.Linear(in_features = hidden_features,
40 | out_features = hidden_features // 2,
41 | bias=True),
42 | nn.ReLU(),
43 | nn.Linear(in_features = hidden_features // 2,
44 | out_features = 1,
45 | bias=True),
46 | nn.Tanh()]
47 |
48 | self.fwd = nn.Sequential(*layers)
49 |
50 | def forward(self, x: torch.Tensor) -> torch.Tensor:
51 | return self.fwd(x)
52 |
53 | class ValueFunc(nn.Module):
54 | def __init__(self, hidden_features=128):
55 | super(ValueFunc, self).__init__()
56 |
57 | layers = [nn.Linear(in_features = 5,
58 | out_features = hidden_features,
59 | bias=True),
60 | nn.ReLU(),
61 | nn.Linear(in_features = hidden_features,
62 | out_features = hidden_features // 2,
63 | bias=True),
64 | nn.ReLU(),
65 | nn.Linear(in_features = hidden_features // 2,
66 | out_features = 1,
67 | bias=True)]
68 |
69 | self.fwd = nn.Sequential(*layers)
70 |
71 | def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
72 | x = torch.cat([s, a], dim=1)
73 | return self.fwd(x)
74 |
--------------------------------------------------------------------------------
/src/csrc/models/rl/sac_model.cpp:
--------------------------------------------------------------------------------
1 | #include <set>
2 |
3 | #include <torch/torch.h>
4 |
5 | #include "internal/models.h"
6 | #include "internal/param_map.h"
7 | #include "internal/setup.h"
8 |
9 | namespace torchfort {
10 |
11 | // SACMLP model in C++ using libtorch
12 | void SACMLPModel::setup(const ParamMap& params) {
13 | // Extract params from input map.
14 | std::set<std::string> supported_params{"dropout", "layer_sizes", "state_dependent_sigma", "log_sigma_init"};
15 | check_params(supported_params, params.keys());
16 |
17 | dropout = params.get_param<double>("dropout", 0.0)[0];
18 | layer_sizes = params.get_param<int>("layer_sizes");
19 | state_dependent_sigma = params.get_param<bool>("state_dependent_sigma", true)[0];
20 | double log_sigma_init = params.get_param<double>("log_sigma_init", 0.)[0];
21 |
22 | // Construct and register submodules.
23 | for (int i = 0; i < layer_sizes.size() - 1; ++i) {
24 | if (i < layer_sizes.size() - 2) {
25 | encoder_layers.push_back(
26 | register_module("encoder_fc_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1])));
27 | biases.push_back(register_parameter("encoder_b_" + std::to_string(i), torch::zeros(layer_sizes[i + 1])));
28 | } else {
29 | // first output: mu
30 | out_layers.push_back(
31 | register_module("out_fc_1_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1])));
32 | out_biases.push_back(register_parameter("out_b_1_" + std::to_string(i), torch::zeros(layer_sizes[i + 1])));
33 | // second output: log_sigma
34 | if (state_dependent_sigma) {
35 | out_layers.push_back(
36 | register_module("out_fc_2_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1])));
37 | out_biases.push_back(register_parameter("out_b_2_" + std::to_string(i), torch::zeros(layer_sizes[i + 1])));
38 | } else {
39 | out_biases.push_back(
40 | register_parameter("out_b_2_" + std::to_string(i), torch::ones(layer_sizes[i + 1]) * log_sigma_init));
41 | }
42 | }
43 | }
44 | }
45 |
46 | // Implement the forward function.
47 | std::vector<torch::Tensor> SACMLPModel::forward(const std::vector<torch::Tensor>& inputs) {
48 |
49 | // make sure that exactly one tensor is fed (the state)
50 | if (inputs.size() != 1) {
51 | THROW_INVALID_USAGE("You have to provide exactly one tensor (state) to the SACMLPModel");
52 | }
53 |
54 | // unpack
55 | auto state = inputs[0];
56 |
57 | // expand dims if necessary
58 | if (state.dim() == 1) {
59 | state = state.unsqueeze(0);
60 | }
61 |
62 | // flatten everything beyond dim 0:
63 | auto x = state.reshape({state.size(0), -1});
64 | torch::Tensor y, z;
65 |
66 | for (int i = 0; i < layer_sizes.size() - 1; ++i) {
67 | if (i < layer_sizes.size() - 2) {
68 | // encoder part
69 | x = torch::relu(encoder_layers[i]->forward(x) + biases[i]);
70 | x = torch::dropout(x, dropout, is_training());
71 | } else {
72 | // y
73 | y = out_layers[0]->forward(x) + out_biases[0];
74 | // z
75 | if (state_dependent_sigma) {
76 | z = out_layers[1]->forward(x) + out_biases[1];
77 | } else {
78 | z = out_biases[1];
79 | }
80 | }
81 | }
82 | return std::vector<torch::Tensor>{y, z};
83 | }
84 |
85 | } // namespace torchfort
86 |
--------------------------------------------------------------------------------
/src/csrc/models/rl/common_models.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #include <set>
19 |
20 | #include <torch/torch.h>
21 |
22 | #include "internal/models.h"
23 | #include "internal/param_map.h"
24 | #include "internal/setup.h"
25 |
26 | namespace torchfort {
27 |
28 | // MLP model in C++ using libtorch
29 | void CriticMLPModel::setup(const ParamMap& params) {
30 | // Extract params from input map.
31 | std::set<std::string> supported_params{"dropout", "layer_sizes"};
32 | check_params(supported_params, params.keys());
33 |
34 | dropout = params.get_param<double>("dropout", 0.0)[0];
35 | layer_sizes = params.get_param<int>("layer_sizes");
36 |
37 | // sanity checks
38 | // make sure that value function is emitting a scalar.
39 | if (layer_sizes[layer_sizes.size() - 1] != 1) {
40 | THROW_INVALID_USAGE(
41 | "CriticMLPModel::setup: error, the value of the last element of layer_sizes has to be equal to one.");
42 | }
43 |
44 | // Construct and register submodules.
45 | for (int i = 0; i < layer_sizes.size() - 1; ++i) {
46 | fc_layers.push_back(
47 | register_module("fc" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1])));
48 | if (i < layer_sizes.size() - 2) {
49 | biases.push_back(register_parameter("b" + std::to_string(i), torch::zeros(layer_sizes[i + 1])));
50 | }
51 | }
52 | }
53 |
54 | // Implement the forward function.
55 | std::vector<torch::Tensor> CriticMLPModel::forward(const std::vector<torch::Tensor>& inputs) {
56 |
57 | // make sure that exactly two tensors are fed, state and action:
58 | if (inputs.size() != 2) {
59 | THROW_INVALID_USAGE("You have to provide exactly two tensors (state, action) to the CriticMLPModel");
60 | }
61 |
62 | // unpack
63 | auto state = inputs[0];
64 | auto action = inputs[1];
65 |
66 | // expand dims if necessary
67 | if (state.dim() == 1) {
68 | state = state.unsqueeze(0);
69 | }
70 | if (action.dim() == 1) {
71 | action = action.unsqueeze(0);
72 | }
73 |
74 | // flatten everything beyond dim 0:
75 | state = state.reshape({state.size(0), -1});
76 | action = action.reshape({action.size(0), -1});
77 |
78 | // concatenate inputs along feature dimension
79 | auto x = torch::cat({state, action}, 1);
80 |
81 | // forward pass
82 | for (int i = 0; i < layer_sizes.size() - 1; ++i) {
83 | if (i < layer_sizes.size() - 2) {
84 | x = torch::relu(fc_layers[i]->forward(x) + biases[i]);
85 | x = torch::dropout(x, dropout, is_training());
86 | } else {
87 | x = fc_layers[i]->forward(x);
88 | }
89 | }
90 | return std::vector<torch::Tensor>{x};
91 | }
92 |
93 | } // namespace torchfort
94 |
--------------------------------------------------------------------------------
/tests/supervised/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.14)
2 |
3 | set(test_targets
4 | test_checkpoint
5 | test_training
6 | test_distributed_training
7 | )
8 |
9 | add_executable(test_checkpoint)
10 | target_sources(test_checkpoint
11 | PRIVATE
12 | test_checkpoint.cpp
13 | )
14 |
15 | add_executable(test_training)
16 | target_sources(test_training
17 | PRIVATE
18 | test_training.cpp
19 | )
20 |
21 | add_executable(test_distributed_training)
22 | target_sources(test_distributed_training
23 | PRIVATE
24 | test_distributed_training.cpp
25 | )
26 |
27 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
28 |
29 | foreach(tgt ${test_targets})
30 | target_include_directories(${tgt}
31 | PRIVATE
32 | ${YAML_CPP_INCLUDE_DIR}
33 | ${MPI_CXX_INCLUDE_DIRS}
34 | ${CMAKE_BINARY_DIR}/include
35 | ${CMAKE_CURRENT_SOURCE_DIR}/../
36 | )
37 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
38 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES})
39 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY})
40 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX)
41 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main)
42 | target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
43 | target_link_options(${tgt} PRIVATE $<$<LINK_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
44 | if (TORCHFORT_ENABLE_GPU)
45 | target_include_directories(${tgt}
46 | PRIVATE
47 | ${CUDAToolkit_INCLUDE_DIRS}
48 | ${NCCL_INCLUDE_DIR}
49 | )
50 | target_link_libraries(${tgt} PRIVATE CUDA::cudart)
51 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU)
52 | endif()
53 |
54 | # discover tests: we have an issue with the work dir of gtest so disable that for now
55 | #gtest_discover_tests(${tgt})
56 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
57 | endforeach()
58 |
59 | # installation
60 | # executable
61 | install(
62 | TARGETS ${test_targets}
63 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised
64 | )
65 |
66 | # copy files
67 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
68 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp2.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
69 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp3.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
70 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp2_gradacc.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
71 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_opt.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
72 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_loss.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
73 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
74 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
75 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs)
76 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup_tests.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/scripts)
77 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/models.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include <vector>
21 |
22 | #include <torch/torch.h>
23 |
24 | #include "internal/base_model.h"
25 | #include "internal/defines.h"
26 | #include "internal/param_map.h"
27 |
28 | namespace torchfort {
29 |
30 | // MLP model in C++ using libtorch
31 | struct MLPModel : BaseModel, public std::enable_shared_from_this<MLPModel> {
32 | void setup(const ParamMap& params) override;
33 | std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override;
34 |
35 | double dropout;
36 | bool flatten_non_batch_dims;
37 | std::vector<int> layer_sizes;
38 |
39 | // Use one of many "standard library" modules.
40 | std::vector<torch::nn::Linear> fc_layers;
41 | std::vector<torch::Tensor> biases;
42 | };
43 |
44 | struct CriticMLPModel : BaseModel, public std::enable_shared_from_this<CriticMLPModel> {
45 | void setup(const ParamMap& params) override;
46 | std::vector<torch::Tensor> forward(const std::vector<torch::Tensor>& inputs) override;
47 |
48 | double dropout;
49 | std::vector