├── .clang-format
├── CMakeLists.txt
├── LICENSE
├── README.md
├── cmake
└── test_mpi_f2c.f90
├── docker
├── Dockerfile
├── Dockerfile_gnu
└── Dockerfile_gnu_cpuonly
├── docs
├── Doxyfile
├── Makefile
├── _static
│ └── style.css
├── _templates
│ └── layout.html
├── api.rst
├── api
│ ├── c_api.rst
│ ├── config.rst
│ └── f_api.rst
├── conf.py
├── extras.rst
├── index.rst
├── installation.rst
├── requirements.txt
└── usage.rst
├── examples
├── cpp
│ └── cart_pole
│ │ ├── CMakeLists.txt
│ │ ├── README.md
│ │ ├── config.yaml
│ │ ├── config_sim.yaml
│ │ ├── env.cpp
│ │ ├── env.h
│ │ ├── media
│ │ ├── cartpole.gif
│ │ └── cartpole.mp4
│ │ ├── py_env.cpp
│ │ ├── python
│ │ ├── initialize_models.py
│ │ ├── models.py
│ │ └── visualize.py
│ │ └── train.cpp
└── fortran
│ ├── graph
│ ├── CMakeLists.txt
│ ├── README.md
│ ├── config.yaml
│ ├── connectivity.txt
│ ├── generate_loss.py
│ ├── generate_model.py
│ ├── media
│ │ └── validation_results.gif
│ ├── nodes.txt
│ ├── train.f90
│ └── visualize.py
│ └── simulation
│ ├── CMakeLists.txt
│ ├── README.md
│ ├── config_fcn_torchscript.yaml
│ ├── config_mlp_native.yaml
│ ├── generate_fcn_model.py
│ ├── media
│ └── validation_results.gif
│ ├── simulation.f90
│ ├── train.f90
│ ├── train_distributed.f90
│ └── visualize.py
├── requirements.txt
├── src
├── csrc
│ ├── distributed.cpp
│ ├── include
│ │ ├── internal
│ │ │ ├── base_loss.h
│ │ │ ├── base_lr_scheduler.h
│ │ │ ├── base_model.h
│ │ │ ├── defines.h
│ │ │ ├── distributed.h
│ │ │ ├── exceptions.h
│ │ │ ├── logging.h
│ │ │ ├── losses.h
│ │ │ ├── lr_schedulers.h
│ │ │ ├── model_pack.h
│ │ │ ├── model_state.h
│ │ │ ├── model_wrapper.h
│ │ │ ├── models.h
│ │ │ ├── nvtx.h
│ │ │ ├── param_map.h
│ │ │ ├── rl
│ │ │ │ ├── distributions.h
│ │ │ │ ├── noise_actor.h
│ │ │ │ ├── off_policy.h
│ │ │ │ ├── off_policy
│ │ │ │ │ ├── ddpg.h
│ │ │ │ │ ├── sac.h
│ │ │ │ │ └── td3.h
│ │ │ │ ├── on_policy.h
│ │ │ │ ├── on_policy
│ │ │ │ │ └── ppo.h
│ │ │ │ ├── policy.h
│ │ │ │ ├── replay_buffer.h
│ │ │ │ ├── rl.h
│ │ │ │ ├── rollout_buffer.h
│ │ │ │ └── utils.h
│ │ │ ├── setup.h
│ │ │ ├── tensor_list.h
│ │ │ ├── training.h
│ │ │ └── utils.h
│ │ ├── torchfort.h
│ │ ├── torchfort_enums.h
│ │ └── torchfort_rl.h
│ ├── logging.cpp
│ ├── losses
│ │ ├── l1_loss.cpp
│ │ ├── mse_loss.cpp
│ │ └── torchscript_loss.cpp
│ ├── lr_schedulers
│ │ ├── cosine_annealing_lr.cpp
│ │ ├── linear_lr.cpp
│ │ ├── multistep_lr.cpp
│ │ ├── polynomial_lr.cpp
│ │ ├── scheduler_setup.cpp
│ │ └── step_lr.cpp
│ ├── model_pack.cpp
│ ├── model_state.cpp
│ ├── model_wrapper.cpp
│ ├── models
│ │ ├── actor_critic_model.cpp
│ │ ├── mlp_model.cpp
│ │ └── sac_model.cpp
│ ├── param_map.cpp
│ ├── rl
│ │ ├── off_policy
│ │ │ ├── ddpg.cpp
│ │ │ ├── interface.cpp
│ │ │ ├── sac.cpp
│ │ │ └── td3.cpp
│ │ ├── on_policy
│ │ │ ├── interface.cpp
│ │ │ └── ppo.cpp
│ │ ├── policy.cpp
│ │ └── utils.cpp
│ ├── setup.cpp
│ ├── torchfort.cpp
│ ├── training.cpp
│ └── utils.cpp
├── fsrc
│ └── torchfort_m.F90
└── python
│ └── wandb_helper.py
└── tests
├── general
├── CMakeLists.txt
├── configs
│ ├── l1.yaml
│ ├── l1_multiarg.yaml
│ ├── mse.yaml
│ ├── mse_multiarg.yaml
│ ├── torchscript.yaml
│ ├── torchscript_multiarg.yaml
│ ├── torchscript_multiarg_extra.yaml
│ └── torchscript_multiout.yaml
├── scripts
│ └── setup_tests.py
└── test_losses.cpp
├── rl
├── CMakeLists.txt
├── configs
│ ├── ddpg.yaml
│ ├── ppo.yaml
│ ├── sac.yaml
│ └── td3.yaml
├── environments.h
├── test_distributions.cpp
├── test_off_policy.cpp
├── test_on_policy.cpp
├── test_replay_buffer.cpp
└── test_rollout_buffer.cpp
├── supervised
├── CMakeLists.txt
├── configs
│ ├── missing_loss.yaml
│ ├── missing_opt.yaml
│ ├── mlp.yaml
│ ├── mlp2.yaml
│ ├── mlp2_gradacc.yaml
│ ├── torchscript.yaml
│ ├── torchscript_multiarg.yaml
│ └── torchscript_multiarg_extra.yaml
├── scripts
│ └── setup_tests.py
├── test_checkpoint.cpp
└── test_training.cpp
└── test_utils.h
/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | BasedOnStyle: LLVM
3 | ColumnLimit: 120
4 | CommentPragmas: '^\\.+'
5 | DerivePointerAlignment: false
6 | Language: Cpp
7 | PointerAlignment: Left
8 | UseTab: Never
9 | AlignAfterOpenBracket: Align
10 | AlignTrailingComments: true
11 | AllowShortBlocksOnASingleLine: false
12 | AllowShortCaseLabelsOnASingleLine : false
13 | AllowShortIfStatementsOnASingleLine: false
14 | AllowShortLoopsOnASingleLine: false
15 | ...
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | SPDX-License-Identifier: BSD-3-Clause
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | 3. Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TorchFort
2 |
3 | An Online Deep Learning Interface for HPC programs on NVIDIA GPUs
4 |
5 | ## Introduction
6 | TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the [PyTorch](https://pytorch.org) framework.
7 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available
8 | within PyTorch.
9 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the
10 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a
11 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection,
12 | learning rate schedules, and much more.
13 |
14 | Please refer to the [documentation](https://nvidia.github.io/TorchFort/) for additional information on the library, build instructions, and usage details.
15 |
16 | Please refer to the [examples](examples) to see TorchFort in action.
17 |
18 | Contact us or open a GitHub issue if you are interested in using this library in your own solvers and have questions on usage and/or feature requests.
19 |
20 | ## License
21 | This library is released under a BSD 3-clause license, which can be found in [LICENSE](LICENSE).
22 |
--------------------------------------------------------------------------------
/cmake/test_mpi_f2c.f90:
--------------------------------------------------------------------------------
1 | ! SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | ! SPDX-License-Identifier: BSD-3-Clause
3 | !
4 | ! Redistribution and use in source and binary forms, with or without
5 | ! modification, are permitted provided that the following conditions are met:
6 | !
7 | ! 1. Redistributions of source code must retain the above copyright notice, this
8 | ! list of conditions and the following disclaimer.
9 | !
10 | ! 2. Redistributions in binary form must reproduce the above copyright notice,
11 | ! this list of conditions and the following disclaimer in the documentation
12 | ! and/or other materials provided with the distribution.
13 | !
14 | ! 3. Neither the name of the copyright holder nor the names of its
15 | ! contributors may be used to endorse or promote products derived from
16 | ! this software without specific prior written permission.
17 | !
18 | ! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | ! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | ! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | ! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | ! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | ! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | ! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | ! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | ! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | ! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | module test_f2c
30 | use iso_c_binding
31 | implicit none
32 | ! Interoperable single-field struct holding a 64-bit integer; used below as the return type of MPI_Comm_f2c. NOTE(review): assumes the C MPI_Comm of the target MPI library matches this layout - confirm.
33 | type, bind(c) :: MPI_C_Comm
34 | integer(c_int64_t) :: comm
35 | end type MPI_C_Comm
36 | ! Interoperable single-field struct holding a C int; used below to pass the Fortran communicator handle by value.
37 | type, bind(c) :: MPI_F_Comm
38 | integer(c_int) :: comm
39 | end type MPI_F_Comm
40 | ! Explicit interface bound to the C symbol 'MPI_Comm_f2c'; the bare `import` makes the two types above visible inside the interface body.
41 | interface
42 | function MPI_Comm_f2c(fcomm) bind(C,name='MPI_Comm_f2c') result(res)
43 | import
44 | type(MPI_F_Comm), value :: fcomm
45 | type(MPI_C_Comm) :: res
46 | end function MPI_Comm_f2c
47 | end interface
48 | end module
49 |
50 | program main
51 | use mpi
52 | use test_f2c
53 | implicit none
54 | ! Minimal driver that references MPI_Comm_f2c through the test_f2c binding. NOTE(review): MPI_Init is never called, so this looks intended as a compile/link check rather than something to execute - confirm against the CMake logic that uses it.
55 | type(MPI_F_Comm) :: fcomm
56 | type(MPI_C_Comm) :: ccomm
57 | ! Wrap the Fortran integer handle of the world communicator.
58 | fcomm%comm = MPI_COMM_WORLD
59 | ! Convert the handle via the direct C binding; the result is not used further.
60 | ccomm = MPI_Comm_f2c(fcomm)
61 | end program
62 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/cuda:12.8.1-devel-ubuntu22.04
2 |
3 | # Install System Dependencies
4 | ENV DEBIAN_FRONTEND noninteractive
5 | RUN apt update -y && \
6 | apt install -y curl unzip wget cmake && \
7 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
8 | apt install -y git vim gfortran doxygen && \
9 | apt install -y libibverbs-dev ibverbs-utils numactl
10 |
11 | # Install NVHPC SDK
12 | RUN wget https://developer.download.nvidia.com/hpc-sdk/25.3/nvhpc_2025_253_Linux_x86_64_cuda_12.8.tar.gz && \
13 | tar xpzf nvhpc_2025_253_Linux_x86_64_cuda_12.8.tar.gz && \
14 | nvhpc_2025_253_Linux_x86_64_cuda_12.8/install --quiet && \
15 | rm -rf nvhpc_2025_253_Linux_x86_64_cuda_12.8 nvhpc_2025_253_Linux_x86_64_cuda_12.8.tar.gz
16 |
17 | ENV PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/compilers/bin:$PATH
18 | ENV PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/mpi/bin:$PATH
19 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/cuda/lib64:$LD_LIBRARY_PATH
20 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/mpi/lib:$LD_LIBRARY_PATH
21 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/nvshmem/lib:$LD_LIBRARY_PATH
22 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/math_libs/lib64:$LD_LIBRARY_PATH
23 | ENV CUDA_HOME /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/cuda
24 |
25 | RUN echo "source /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/12.8/hpcx/latest/hpcx-init.sh; hpcx_load" >> /root/.bashrc
26 |
27 | # Install PyTorch
28 | RUN pip3 install torch==2.7.0
29 |
30 | # Install yaml-cpp as a static, position-independent library into /opt/yaml-cpp, then delete the clone and build tree so they are not kept in the image layer (same clean-up as the HDF5 step)
31 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
32 | cd yaml-cpp && \
33 | mkdir build && cd build && \
34 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \
35 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \
36 | -DBUILD_SHARED_LIBS=OFF \
37 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \
38 | make -j$(nproc) && make install && cd ../.. && rm -rf yaml-cpp
39 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH}
40 |
41 | # Install HDF5
42 | RUN wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_14_3.tar.gz && \
43 | tar xzf hdf5-1_14_3.tar.gz && \
44 | cd hdf5-hdf5-1_14_3 && \
45 | CC=mpicc FC=mpifort FCFLAGS=-fPIC CFLAGS=-fPIC \
46 | ./configure --enable-parallel \
47 | --enable-fortran \
48 | --prefix=/opt/hdf5 && \
49 | make -j$(nproc) install && \
50 | cd .. && \
51 | rm -rf hdf5-hdf5-1_14_3 hdf5-1_14_3.tar.gz
52 | ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH
53 |
54 | # Install additional Python dependencies
55 | RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy
56 |
57 | # Install TorchFort
58 | ENV FC=nvfortran
59 | ENV HDF5_ROOT=/opt/hdf5
60 | COPY . /torchfort
61 | RUN cd /torchfort && mkdir build && cd build && \
62 | CUDA_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/25.3/cuda \
63 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
64 | -DCMAKE_CXX_COMPILER=`which g++` \
65 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
66 | -DTORCHFORT_NCCL_ROOT=/opt/nccl/build \
67 | -DTORCHFORT_BUILD_EXAMPLES=1 \
68 | -DTORCHFORT_BUILD_TESTS=1 \
69 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
70 | .. && \
71 | make -j$(nproc) install && \
72 | cd / && rm -rf torchfort
73 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH}
74 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
75 |
76 | ENTRYPOINT bash
77 |
--------------------------------------------------------------------------------
/docker/Dockerfile_gnu:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/cuda:12.8.1-devel-ubuntu22.04
2 |
3 | # Install System Dependencies
4 | ENV DEBIAN_FRONTEND noninteractive
5 | RUN apt update -y && \
6 | apt install -y curl unzip wget cmake && \
7 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
8 | apt install -y git vim gfortran doxygen && \
9 | apt install -y libibverbs-dev ibverbs-utils numactl
10 |
11 | # Download HPCX and compile with Fortran support
12 | RUN cd /opt && \
13 | wget http://content.mellanox.com/hpc/hpc-x/v2.22.1rc4/hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \
14 | tar xjf hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \
15 | mv hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64 hpcx && \
16 | rm -rf hpcx/ompi && \
17 | cd hpcx/sources && \
18 | tar xzf openmpi-gitclone.tar.gz && \
19 | cd openmpi-gitclone && \
20 | LD_LIBRARY_PATH=/opt/hpcx/hcoll/lib:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib:$LD_LIBRARY_PATH \
21 | FC=gfortran CC=gcc CXX=g++ ./configure --prefix=/opt/hpcx/ompi \
22 | --with-libevent=internal \
23 | --enable-mpi1-compatibility \
24 | --without-xpmem \
25 | --with-cuda=/usr/local/cuda \
26 | --with-slurm \
27 | --with-platform=contrib/platform/mellanox/optimized \
28 | --with-hcoll=/opt/hpcx/hcoll \
29 | --with-ucx=/opt/hpcx/ucx \
30 | --with-ucc=/opt/hpcx/ucc && \
31 | make -j$(nproc) install && \
32 | cd /opt && rm -rf /opt/hpcx/sources/openmpi-gitclone && rm hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz
33 |
34 | ENV PATH /opt/hpcx/ompi/bin:$PATH
35 | ENV LD_LIBRARY_PATH /opt/hpcx/ompi/lib:$LD_LIBRARY_PATH
36 |
37 | RUN echo "source /opt/hpcx/hpcx-init.sh; hpcx_load" >> /root/.bashrc
38 |
39 | # Install PyTorch
40 | RUN pip3 install torch==2.7.0
41 |
42 | # Install yaml-cpp as a static, position-independent library into /opt/yaml-cpp, then delete the clone and build tree so they are not kept in the image layer (same clean-up as the HDF5 step)
43 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
44 | cd yaml-cpp && \
45 | mkdir build && cd build && \
46 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \
47 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \
48 | -DBUILD_SHARED_LIBS=OFF \
49 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \
50 | make -j$(nproc) && make install && cd ../.. && rm -rf yaml-cpp
51 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH}
52 |
53 | # Install HDF5
54 | RUN wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_14_3.tar.gz && \
55 | tar xzf hdf5-1_14_3.tar.gz && \
56 | cd hdf5-hdf5-1_14_3 && \
57 | CC=mpicc FC=mpifort \
58 | ./configure --enable-parallel \
59 | --enable-fortran \
60 | --prefix=/opt/hdf5 && \
61 | make -j$(nproc) install && \
62 | cd .. && \
63 | rm -rf hdf5-hdf5-1_14_3 hdf5-1_14_3.tar.gz
64 | ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH
65 |
66 | # Install additional Python dependencies
67 | RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy
68 |
69 | # Install TorchFort
70 | ENV FC=gfortran
71 | ENV HDF5_ROOT=/opt/hdf5
72 | COPY . /torchfort
73 | RUN cd /torchfort && mkdir build && cd build && \
74 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
75 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
76 | -DTORCHFORT_NCCL_ROOT=/opt/nccl/build \
77 | -DTORCHFORT_BUILD_EXAMPLES=1 \
78 | -DTORCHFORT_BUILD_TESTS=1 \
79 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
80 | .. && \
81 | make -j$(nproc) install && \
82 | cd / && rm -rf torchfort
83 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH}
84 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
85 |
86 | ENTRYPOINT bash
87 |
--------------------------------------------------------------------------------
/docker/Dockerfile_gnu_cpuonly:
--------------------------------------------------------------------------------
1 | FROM ubuntu:22.04
2 |
3 | # Install System Dependencies
4 | ENV DEBIAN_FRONTEND noninteractive
5 | RUN apt update -y && \
6 | apt install -y build-essential && \
7 | apt install -y curl unzip wget cmake && \
8 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \
9 | apt install -y git vim gfortran doxygen && \
10 | apt install -y libibverbs-dev ibverbs-utils numactl
11 |
12 | # Download OpenMPI and compile with Fortran support
13 | RUN cd /opt && \
14 | wget https://download.open-mpi.org/release/open-mpi/v5.0/openmpi-5.0.5.tar.gz && \
15 | tar xzf openmpi-5.0.5.tar.gz && \
16 | cd openmpi-5.0.5 && \
17 | FC=gfortran CC=gcc CXX=g++ ./configure --prefix=/opt/openmpi \
18 | --with-libevent=internal \
19 | --enable-mpi1-compatibility \
20 | --without-xpmem \
21 | --with-slurm && \
22 | make -j$(nproc) install && \
23 | cd /opt && rm -rf openmpi-5.0.5 && rm openmpi-5.0.5.tar.gz
24 | # Expose the freshly built OpenMPI. The ubuntu:22.04 base does not define LD_LIBRARY_PATH, so only append ":<old value>" when one exists; a trailing ':' would add an empty entry, which the dynamic loader treats as the current directory.
25 | ENV PATH=/opt/openmpi/bin:$PATH
26 | ENV LD_LIBRARY_PATH=/opt/openmpi/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
27 |
28 | # Install PyTorch
29 | RUN pip3 install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu
30 |
31 | # Install yaml-cpp as a static, position-independent library into /opt/yaml-cpp, then delete the clone and build tree so they are not kept in the image layer (same clean-up as the HDF5 step)
32 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \
33 | cd yaml-cpp && \
34 | mkdir build && cd build && \
35 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \
36 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \
37 | -DBUILD_SHARED_LIBS=OFF \
38 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \
39 | make -j$(nproc) && make install && cd ../.. && rm -rf yaml-cpp
40 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH}
41 |
42 | # Install HDF5
43 | RUN wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_14_3.tar.gz && \
44 | tar xzf hdf5-1_14_3.tar.gz && \
45 | cd hdf5-hdf5-1_14_3 && \
46 | CC=mpicc FC=mpifort \
47 | ./configure --enable-parallel \
48 | --enable-fortran \
49 | --prefix=/opt/hdf5 && \
50 | make -j$(nproc) install && \
51 | cd .. && \
52 | rm -rf hdf5-hdf5-1_14_3 hdf5-1_14_3.tar.gz
53 | ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH
54 |
55 | # Install additional Python dependencies
56 | RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy
57 |
58 | # Install TorchFort without GPU support
59 | ENV FC=gfortran
60 | ENV HDF5_ROOT=/opt/hdf5
61 | COPY . /torchfort
62 | RUN cd /torchfort && mkdir build && cd build && \
63 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \
64 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \
65 | -DTORCHFORT_ENABLE_GPU=0 \
66 | -DTORCHFORT_BUILD_EXAMPLES=1 \
67 | -DTORCHFORT_BUILD_TESTS=1 \
68 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
69 | .. && \
70 | make -j$(nproc) install && \
71 | cd / && rm -rf torchfort
72 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH}
73 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
74 |
75 | ENTRYPOINT bash
76 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | doxygen Doxyfile
21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
22 |
--------------------------------------------------------------------------------
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | .wy-nav-content { /* Cap the documentation content column at 1200px; !important lets this win over the theme's own max-width rule. */
2 | max-width: 1200px !important;
3 | }
4 |
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 | {% block sidebartitle %} {{ super() }}
3 |
4 |
32 | {% endblock %}
33 |
34 | {% block footer %} {{ super() }}
35 |
36 |
51 | {% endblock %}
52 |
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | .. _api-label:
2 |
3 | #############
4 | TorchFort API
5 | #############
6 |
7 | The following sections describe the types and functions available in the TorchFort library for C/C++ and Fortran programs and
8 | also the configuration file structure and available options.
9 |
10 | .. toctree::
11 |
12 | api/config
13 | api/c_api
14 | api/f_api
15 |
16 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | import sphinx_rtd_theme
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | # import os
16 | # import sys
17 | # sys.path.insert(0, os.path.abspath('.'))
18 |
19 |
20 | # -- Project information -----------------------------------------------------
21 |
22 | project = 'torchfort'
23 | copyright = '2023, NVIDIA Corporation'
24 | author = 'NVIDIA Corporation'
25 |
26 | # The full version, including alpha/beta/rc tags
27 | release = '2023'
28 |
29 |
30 | # -- General configuration ---------------------------------------------------
31 |
32 | # Add any Sphinx extension module names here, as strings. They can be
33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
34 | # ones.
35 | extensions = [
36 | 'breathe',
37 | 'sphinx.ext.mathjax',
38 | 'sphinx_tabs.tabs',
39 | 'sphinxfortran.fortran_domain',
40 | ]
41 |
42 | # Add any paths that contain templates here, relative to this directory.
43 | templates_path = ['_templates']
44 |
45 | # List of patterns, relative to source directory, that match files and
46 | # directories to ignore when looking for source files.
47 | # This pattern also affects html_static_path and html_extra_path.
48 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
49 |
50 | # The name of the Pygments (syntax highlighting) style to use.
51 | #pygments_style = 'sphinx'
52 |
53 | highlight_language = 'cpp'
54 |
55 | def setup(app):  # Sphinx extension hook; receives the application object
56 | app.add_css_file('style.css')  # register the custom stylesheet (served from html_static_path, i.e. _static/style.css)
57 |
58 | # -- Options for HTML output -------------------------------------------------
59 |
60 | # The theme to use for HTML and HTML Help pages. See the documentation for
61 | # a list of builtin themes.
62 | #
63 | html_theme = 'sphinx_rtd_theme'
64 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
65 |
66 | html_theme_options = {
67 | 'navigation_depth': 6
68 | }
69 |
70 | # Add any paths that contain custom static files (such as style sheets) here,
71 | # relative to this directory. They are copied after the builtin static files,
72 | # so a file named "default.css" will overwrite the builtin "default.css".
73 | html_static_path = ['_static']
74 |
75 | breathe_projects = { "torchfort": "xml/" }
76 | breathe_default_project = "torchfort"
77 |
--------------------------------------------------------------------------------
/docs/extras.rst:
--------------------------------------------------------------------------------
1 | ######
2 | Extras
3 | ######
4 |
5 | .. _wandb_support-ref:
6 |
7 | Weights and Biases Support
8 | ==========================
9 |
10 | `Weights and Biases `_ (wandb) is a popular tool for monitoring machine learning training workflows. Wandb supports plotting and comparing loss curves and other relevant deep learning metrics, system utilization (including GPU, CPU and memory utilization) and other advanced logging functionalities like uploading images and/or videos, network weights and data artifacts.
11 |
12 | Since wandb does not currently offer a C++ interface and thus cannot be called from Fortran or C/C++ directly, we've implemented a wandb daemon written in Python instead. This daemon runs as a background process and waits for changes in a log file generated by the TorchFort application. In order to enable wandb support for your TorchFort application, the following steps have to be performed.
13 |
14 | Add Custom Metrics Reporting to Application
15 | -------------------------------------------
16 |
17 | The TorchFort training routines already provide logging of training step, loss values as well as learning rate which are captured by the wandb daemon. Additional custom metrics can be added manually by the user. For this purpose, the user may add calls of ``torchfort_wandb_log`` or ``torchfort_rl_off_policy_wandb_log`` for traditional and reinforcement learning applications respectively (see :ref:`supervised_learning-ref` and :ref:`reinforcement_learning-ref` for details about why we provide different implementations for these two cases). For more information, see :ref:`torchfort_api_c-ref` for C/C++ and :ref:`torchfort_api_f-ref` for Fortran applications.
18 |
19 | Set up Environment
20 | ------------------
21 |
22 | You need to specify your `wandb api token `_ via the environment variable ``WANDB_API_KEY`` (see the `wandb documentation on available environment variables `_ for details).
23 | Furthermore, the daemon needs to know where the wandb logging data from the TorchFort application will be stored. This can be done by defining the environment variable ``TORCHFORT_LOGDIR``. Lastly, a user-defined wandb logging directory ``WANDB_LOGGING_DIR`` can be created to gather all wandb information as well as the config file in a place specific to the run.
24 |
25 | .. note::
26 | The logging directory ``TORCHFORT_LOGDIR`` needs to be specified before the daemon and TorchFort application are launched.
27 |
28 | Start Background Watcher Process
29 | --------------------------------
30 |
31 | Now, the wandb daemon process needs to be started. Assuming TorchFort was installed in ``TORCHFORT_INSTALL_DIR``, we can run
32 |
33 | .. code-block:: bash
34 |
35 | python ${TORCHFORT_INSTALL_DIR}/bin/python/wandb_helper.py \
36 | --wandb_dir=${WANDB_LOGGING_DIR} \
37 | --wandb_group=<wandb_group> \
38 | --wandb_project=<wandb_project> \
39 | --wandb_entity=<wandb_entity> \
40 | --run_tag=<run_tag> \
41 | --timeout=2400 &
42 |
43 | The wandb group, project as well as entity name correspond to the wandb project you are logging to. Those correspond to the respective arguments of ``wandb.init`` documented `here `_. Note that the group does not need to exist and will be created during initialization. The run tag can be any alphanumeric string and can be used to identify the specific run on the wandb dashboard. Lastly, the timeout (measured in seconds) determines for how long the background process will wait for changes to appear in ``${TORCHFORT_LOGDIR}/torchfort.log`` before wrapping up the monitoring.
44 |
45 | .. note::
46 | Do not forget to launch the daemon into the background.
47 |
48 | Start Your TorchFort Application
49 | --------------------------------
50 |
51 | In the configuration file for your TorchFort application, make sure to enable wandb logging in the ``general`` section by adding or modifying the line ``enable_wandb_hook: 1``. Lastly, start the TorchFort application as usual, e.g.:
52 |
53 | .. code-block:: bash
54 |
55 | ./my_torchfort_app arg1 arg2 arg3
56 |
57 | The daemon process will pick up the log lines from ``${TORCHFORT_LOGDIR}/torchfort.log`` and display the data on the corresponding job dashboard.
58 |
59 | .. note::
60 | The daemon can finalize the monitoring while the TorchFort application is still running if the timeout is not set sufficiently large, especially for long running applications with very sparse logging.
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. TorchFort documentation master file, created by
2 | sphinx-quickstart on Wed Jun 1 13:44:41 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | ############################################################################
7 | TorchFort: An Online Deep Learning Interface for HPC programs on NVIDIA GPUs
8 | ############################################################################
9 | These pages contain the documentation for TorchFort, an online deep learning interface for HPC programs.
10 |
11 | TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the `PyTorch <https://pytorch.org>`_ framework.
12 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available
13 | within PyTorch.
14 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the
15 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a
16 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection,
17 | learning rate schedules, and much more.
18 |
19 | Please contact us or open a GitHub issue if you are interested in using this library
20 | in your own solvers and have questions on usage and/or feature requests.
21 |
22 |
23 | Table of Contents
24 | =================
25 | .. toctree::
26 | :maxdepth: 4
27 |
28 | installation
29 | usage
30 | api
31 | extras
32 |
33 |
34 | Indices and tables
35 | ==================
36 |
37 | * :ref:`genindex`
38 | * :ref:`modindex`
39 | * :ref:`search`
40 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | ############
2 | Installation
3 | ############
4 |
5 | TorchFort can be installed in multiple ways but we highly recommend building and using a Docker container.
6 |
7 | Docker Installation
8 | -------------------
9 |
10 | We provide a `Dockerfile <https://github.com/NVIDIA/TorchFort/blob/master/docker/Dockerfile>`_ which contains all relevant dependencies and builds using the `NVIDIA HPC SDK <https://developer.nvidia.com/hpc-sdk>`_ software libraries and compilers, which is our recommended way to build TorchFort. In order to build TorchFort using Docker, simply clone the repo and call:
11 |
12 | .. code-block:: bash
13 |
14 | docker build -t torchfort:latest -f docker/Dockerfile .
15 |
16 | from the top level directory of the repo. Inside the container, TorchFort will be installed in ``/opt/torchfort``.
17 |
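18 | We provide an alternative docker file `Dockerfile_gnu <https://github.com/NVIDIA/TorchFort/blob/master/docker/Dockerfile_gnu>`_ which can be used to build TorchFort using GNU compilers. Additionally, we provide a docker file `Dockerfile_gnu_cpuonly <https://github.com/NVIDIA/TorchFort/blob/master/docker/Dockerfile_gnu_cpuonly>`_ which can be used to build TorchFort using GNU compilers without GPU support enabled.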
19 |
20 | CMake Installation
21 | ------------------
22 |
23 | For a native installation TorchFort provides a `CMakeLists.txt <https://github.com/NVIDIA/TorchFort/blob/master/CMakeLists.txt>`_ file. Please make sure that the following required packages are installed on your system before installing TorchFort:
24 |
25 | * Requirements for core functionality and examples:
26 |
27 | - CUDA 12.1 or newer
28 | - ``python`` version 3.6 or higher
29 | - ``pybind11``
30 | - ``yaml-cpp`` from https://github.com/jbeder/yaml-cpp.git
31 | - MPI
32 | - NVIDIA Collective Communication Library (``NCCL``)
33 | - ``HDF5``
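34 | - the Python modules specified in `requirements.txt <https://github.com/NVIDIA/TorchFort/blob/master/requirements.txt>`_
35 | - GNU or `NVHPC <https://developer.nvidia.com/hpc-sdk>`_ compilers. NVHPC compilers are **required** if CUDA Fortran device array support is desired.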
36 |
37 | * Additional requirements for building this documentation:
38 |
39 | - Doxygen
40 | - the Python modules specified in `docs/requirements.txt <https://github.com/NVIDIA/TorchFort/blob/master/docs/requirements.txt>`_
41 |
42 | For CPU-only builds, CUDA and NCCL are not required.
43 |
44 |
45 | To build TorchFort, clone the repo then call the following from the root directory:
46 |
47 | .. code-block:: bash
48 |
49 | mkdir build && cd build
50 | cmake -DCMAKE_INSTALL_PREFIX=<path/to/your/installation/dir> \
51 | -DTORCHFORT_YAML_CPP_ROOT=<path/to/your/yaml-cpp/installation/dir> \
52 | -DTORCHFORT_BUILD_EXAMPLES=1 \
53 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \
54 | ..
55 | make -j install
56 |
57 | See the top level `CMakeLists.txt <https://github.com/NVIDIA/TorchFort/blob/master/CMakeLists.txt>`_ file for additional CMake configuration options.
58 |
59 | Build Documentation
60 | -------------------
61 |
62 | The documentation can be built with the corresponding ``Makefile`` in the ``docs`` directory. Make sure that the requirements are installed and call:
63 |
64 | .. code-block:: bash
65 |
66 | cd docs && make html
67 |
68 | The docs will be located in ``docs/_build/html`` and can be viewed locally in your web browser.
69 |
70 | Directory Structure
71 | -------------------
72 |
73 | Independent of how you decide to install TorchFort, the directory structure will be as follows::
74 |
75 |
76 | |--- bin
77 | |--- examples
78 | |--- cpp
79 | |--- fortran
80 | |--- python
81 | |--- include
82 | |--- lib
83 |
84 | The ``bin`` folder contains the examples written in C++ or Fortran located in the corresponding subdirectories. The ``python`` subfolder contains the Python wrappers for :ref:`wandb_support-ref`.
85 |
86 | The Fortran module ``torchfort.mod`` as well as the C headers can be found inside the ``include`` folder and the dynamic libraries inside the ``lib`` folder.
87 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | breathe
2 | sphinx-rtd-theme==1.1.1
3 | sphinx-tabs==3.4.0
4 | sphinx-fortran
5 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/CMakeLists.txt:
--------------------------------------------------------------------------------
# Build script for the cart-pole example.
#
# Targets defined here:
#   environments     - static library built from env.cpp
#   train_cart_pole  - training executable built from train.cpp
#   PyEnvironments   - pybind11 module built from py_env.cpp, linked against environments
set(cart_pole_example_targets
  train_cart_pole
)

# The static library is linked into the PyEnvironments module below, hence it is
# compiled as position-independent code.
add_library(environments STATIC)
target_sources(environments
  PRIVATE
  env.cpp
)
set_property(TARGET environments PROPERTY POSITION_INDEPENDENT_CODE ON)

add_executable(train_cart_pole)
target_sources(train_cart_pole
  PRIVATE
  train.cpp
)

find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
find_package(pybind11 CONFIG REQUIRED)
pybind11_add_module(PyEnvironments py_env.cpp)
target_link_libraries(PyEnvironments PRIVATE environments)

foreach(tgt ${cart_pole_example_targets})
  target_include_directories(${tgt}
    PRIVATE
    ${YAML_CPP_INCLUDE_DIR}
    ${MPI_CXX_INCLUDE_DIRS}
    ${CUDAToolkit_INCLUDE_DIRS}
    ${Python_INCLUDE_DIRS}
    ${CMAKE_BINARY_DIR}/include
  )
  target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
  target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES})
  target_link_libraries(${tgt} PRIVATE ${Python_LIBRARIES})
  target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX)
  target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY})
  target_link_libraries(${tgt} PRIVATE environments)
  # Apply the Torch C++ flags to C++ sources only. The inner condition must be a
  # complete generator expression; "$<$:...>" is not valid CMake.
  # NOTE(review): the inner expression was reconstructed as COMPILE_LANGUAGE:CXX - confirm
  # against the upstream repository.
  target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
  target_link_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${TORCH_CXX_FLAGS}>)
  if (TORCHFORT_ENABLE_GPU)
    target_include_directories(${tgt}
      PRIVATE
      ${CUDAToolkit_INCLUDE_DIRS}
    )
    target_link_libraries(${tgt} PRIVATE CUDA::cudart)
  endif()
endforeach()

# installation
# executable
install(
  TARGETS ${cart_pole_example_targets}
  RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole
)

# python env
install(
  TARGETS PyEnvironments
  DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole/python
)

# config files
install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml ${CMAKE_CURRENT_SOURCE_DIR}/config_sim.yaml
  DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole
)

# python files
install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/initialize_models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/visualize.py
  DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole/python
)
73 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/config.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | report_frequency: 1
3 | enable_wandb_hook: 1
4 | verbose: 1
5 |
6 | algorithm:
7 | type: td3
8 | parameters:
9 | batch_size: 512
10 | num_critics: 2
11 | policy_lag: 2
12 | nstep: 1
13 | nstep_reward_reduction: sum_no_skip
14 | gamma: 0.99
15 | rho: 0.99
16 |
17 | actor:
18 | type: space_noise
19 | parameters:
20 | a_low: -1.0
21 | a_high: 1.0
22 | clip: 0.3
23 | sigma_train: 0.1
24 | sigma_explore: 0.2
25 | adaptive: 0
26 |
27 | replay_buffer:
28 | type: uniform
29 | parameters:
30 | max_size: 50000
31 | min_size: 1024
32 |
33 | policy_model:
34 | type: torchscript
35 | parameters:
36 | filename: policy.pt
37 |
38 | critic_model:
39 | type: torchscript
40 | parameters:
41 | filename: value.pt
42 |
43 | optimizer:
44 | type: adam
45 | parameters:
46 | learning_rate: 0.001
47 | beta1: 0.9
48 | beta2: 0.999
49 | weight_decay: 0
50 | eps: 1e-6
51 | amsgrad: 0
52 |
53 | policy_lr_scheduler:
54 | type: cosine_annealing
55 | parameters:
56 | T_max: 500000000
57 |
58 | critic_lr_scheduler:
59 | type: cosine_annealing
60 | parameters:
61 | T_max: 500000000
62 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/config_sim.yaml:
--------------------------------------------------------------------------------
1 | num_episodes: 50000
2 | max_steps_per_episode: 2500
3 | eval_frequency: 25
4 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/env.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
#ifndef TORCHFORT_EXAMPLES_CART_POLE_ENV_H
#define TORCHFORT_EXAMPLES_CART_POLE_ENV_H

#include <array>
#include <random>
#include <tuple>
#include <utility>
// #include <pybind11/pybind11.h>

// NOTE(review): the include targets and every template argument in this header were
// reconstructed from usage (float members, float action, 4 state features consumed by
// the policy model in python/models.py). Confirm them against env.cpp before merging.

// Integration schemes selectable through CartPoleEnv::kinematics_integrator_.
enum IntegratorType { EXPLICIT_EULER, SEMI_IMPLICIT_EULER };

// Fixed-size container holding one environment state.
using StateVector = std::array<float, 4>;

// Cart-pole simulation environment. All member functions are defined in env.cpp.
// Instances cannot be copied (copy constructor and copy assignment are deleted).
class CartPoleEnv {

public:
  CartPoleEnv();
  CartPoleEnv(const CartPoleEnv&) = delete;
  CartPoleEnv& operator=(const CartPoleEnv&) = delete;

  // Resets the environment and returns a state vector.
  StateVector reset();

  // Returns a pair of state vectors; presumably (lower, upper) bounds - TODO confirm order.
  std::pair<StateVector, StateVector> getStateBounds();

  // Advances the simulation using the given action.
  // NOTE(review): tuple layout assumed to be (state, reward, terminated) - confirm in env.cpp.
  std::tuple<StateVector, float, bool> step(float action);

private:
  // sim parameters
  bool terminated_;
  float gravity_;
  float masscart_;
  float masspole_;
  float total_mass_;
  float length_;
  float polemass_length_;
  float force_mag_;
  float dt_;
  float penalty_;
  IntegratorType kinematics_integrator_;
  int steps_beyond_terminated_;

  // threshold parameters
  float theta_threshold_radians_;
  float x_threshold_;

  // random number generation
  std::mt19937_64 rng_;
  std::uniform_real_distribution<float> uniform_dist_;

  // state vector
  StateVector state_;
};

#endif // TORCHFORT_EXAMPLES_CART_POLE_ENV_H
77 |
78 | // pybind11 stuff
79 | // namespace py = pybind11;
80 | // PYBIND11_MODULE(environments, m) {
81 | // py::class_(m, "CartPoleEnv", py::dynamic_attr())
82 | // .def(py::init<>())
83 | // .def("step", &CartPoleEnv::step, py::arg("action"))
84 | // .def("reset", &CartPoleEnv::reset);
85 | //}
86 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/media/cartpole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/cpp/cart_pole/media/cartpole.gif
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/media/cartpole.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/cpp/cart_pole/media/cartpole.mp4
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/py_env.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #include
32 | #include
33 |
34 | #include "env.h"
35 |
36 | // pybind11 stuff
37 | namespace py = pybind11;
38 | PYBIND11_MODULE(PyEnvironments, m) {
39 | py::class_(m, "CartPoleEnv", py::dynamic_attr())
40 | .def(py::init<>())
41 | .def("step", &CartPoleEnv::step, py::arg("action"))
42 | .def("reset", &CartPoleEnv::reset);
43 | }
44 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/python/initialize_models.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 |
30 | import argparse as ap
31 | import math
32 | import torch
33 | from functools import partial
34 | from torch import nn
35 | import torch.nn.functional as F
36 |
37 | from models import weight_init, PolicyFunc, ValueFunc
38 |
def main(args):
    """Build the policy and value networks, script them and write them to disk.

    Both networks are created with ``args.num_hidden_features`` hidden features,
    re-initialized with ``weight_init``, compiled with ``torch.jit.script`` and saved
    as ``policy.pt`` and ``value.pt`` in the current working directory. One forward
    pass with a dummy batch is run for each scripted model and the output shape is
    printed.

    Args:
        args: parsed command line arguments; only ``num_hidden_features`` is read.
    """
    # fixed seed for the initialization
    seed = 666
    torch.manual_seed(seed)

    # pick the device, seeding CUDA as well when it is usable
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        torch.cuda.manual_seed(seed)
    device = torch.device("cuda:0") if cuda_ok else torch.device("cpu")

    # size of the dummy batch used for the test forward passes
    batch_size = 64
    n_hidden = args.num_hidden_features

    # policy network: 4 state features in
    policy = PolicyFunc(hidden_features=n_hidden).to(device)
    weight_init(policy)
    policy_scripted = torch.jit.script(policy)
    dummy_state = torch.ones((batch_size, 4), dtype=torch.float32, device=device)
    policy_out = policy_scripted(dummy_state)
    print("Policy model:", policy)
    print("Policy model output shape:", policy_out.shape)
    torch.jit.save(policy_scripted, "policy.pt")

    # value network: takes the state and a single action feature
    critic = ValueFunc(hidden_features=n_hidden).to(device)
    weight_init(critic)
    critic_scripted = torch.jit.script(critic)
    dummy_action = torch.ones((batch_size, 1), dtype=torch.float32, device=device)
    critic_out = critic_scripted(dummy_state, dummy_action)
    print("Value model:", critic)
    print("Value model output shape:", critic_out.shape)
    torch.jit.save(critic_scripted, "value.pt")


if __name__ == "__main__":
    parser = ap.ArgumentParser()
    parser.add_argument("--num_hidden_features", type=int, default=128, help="Number of hidden features")
    args = parser.parse_args()

    main(args)
80 |
--------------------------------------------------------------------------------
/examples/cpp/cart_pole/python/models.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 |
30 | import math
31 | import torch
32 | from torch import nn
33 | import torch.nn.functional as F
34 |
def weight_init(model, scale=0.02):
    """Re-initialize every ``nn.Linear`` layer found in ``model``, in place.

    Weights are drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` and biases,
    when a layer has one, are set to zero. Modules of any other type are left as is.

    Args:
        model: module whose sub-modules (``model.modules()``) are visited.
        scale: not used by the implementation; kept so that callers passing it
            keep working.
    """
    with torch.no_grad():
        for module in model.modules():
            if not isinstance(module, nn.Linear):
                continue
            # nn.Linear stores its weight as (out_features, in_features): shape[1] is the fan-in
            bound = math.sqrt(1.0 / float(module.weight.shape[1]))
            nn.init.uniform_(module.weight, a=-bound, b=bound)
            if module.bias is not None:
                # initializer instead of poking at ``.data``
                nn.init.zeros_(module.bias)
43 |
class PolicyFunc(nn.Module):
    """Policy network: 4 input features -> 1 output squashed by ``tanh``.

    The network is a three-layer perceptron with ReLU activations whose hidden
    widths are ``hidden_features`` and ``hidden_features // 2``.

    Args:
        hidden_features: width of the first hidden layer.
    """

    def __init__(self, hidden_features=128):
        super().__init__()

        half_width = hidden_features // 2
        # Layer order and the attribute name ``fwd`` determine the state-dict keys
        # (fwd.0.*, fwd.2.*, fwd.4.*), so they are kept exactly as they were.
        self.fwd = nn.Sequential(
            nn.Linear(in_features=4, out_features=hidden_features, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=half_width, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=half_width, out_features=1, bias=True),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x`` (last dimension must have 4 features)."""
        return self.fwd(x)
65 |
class ValueFunc(nn.Module):
    """Value network: concatenated (state, action) with 5 features -> 1 output.

    The network is a three-layer perceptron with ReLU activations whose hidden
    widths are ``hidden_features`` and ``hidden_features // 2``; the output layer
    has no activation.

    Args:
        hidden_features: width of the first hidden layer.
    """

    def __init__(self, hidden_features=128):
        super().__init__()

        half_width = hidden_features // 2
        # Layer order and the attribute name ``fwd`` determine the state-dict keys,
        # so they are kept exactly as they were.
        self.fwd = nn.Sequential(
            nn.Linear(in_features=5, out_features=hidden_features, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=half_width, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=half_width, out_features=1, bias=True),
        )

    def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Concatenate ``s`` and ``a`` along dimension 1 and apply the network.

        The concatenated tensor must have 5 features in dimension 1.
        """
        state_action = torch.cat([s, a], dim=1)
        return self.fwd(state_action)
87 |
--------------------------------------------------------------------------------
/examples/fortran/graph/CMakeLists.txt:
--------------------------------------------------------------------------------
# Build script for the Fortran graph example.
#
# Builds the train_graph target from train.f90; the produced executable is named "train".
set(fortran_example_targets
  train_graph
)

add_executable(train_graph)
target_sources(train_graph
  PRIVATE
  train.f90
)
set_target_properties(train_graph
                      PROPERTIES OUTPUT_NAME train)

foreach(tgt ${fortran_example_targets})
  target_include_directories(${tgt}
    PRIVATE
    ${CMAKE_BINARY_DIR}/include
    ${MPI_Fortran_INCLUDE_DIRS}
    ${HDF5_Fortran_INCLUDE_DIRS}
  )
  target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran)
  target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort")
  target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
  # Compiler-specific flags, applied to Fortran sources only. The inner condition must be
  # a complete generator expression; "$<$:...>" is not valid CMake.
  # NOTE(review): the inner expression was reconstructed as COMPILE_LANGUAGE:Fortran -
  # confirm against the upstream repository.
  if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -acc -gpu=${CUF_GPU_ARG}>)
    target_link_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>: -acc -gpu=${CUF_GPU_ARG}>)
  elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -fbackslash>)
  endif()
endforeach()

# executable
install(
  TARGETS ${fortran_example_targets}
  RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/graph
)

# configuration, helper scripts and input data
install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml
        ${CMAKE_CURRENT_SOURCE_DIR}/generate_model.py
        ${CMAKE_CURRENT_SOURCE_DIR}/generate_loss.py
        ${CMAKE_CURRENT_SOURCE_DIR}/nodes.txt
        ${CMAKE_CURRENT_SOURCE_DIR}/connectivity.txt
        ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py
  DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/graph)
44 |
--------------------------------------------------------------------------------
/examples/fortran/graph/config.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | report_frequency: 100
3 |
4 | model:
5 | type: torchscript
6 | parameters:
7 | filename: "model_torchscript.pt"
8 |
9 | loss:
10 | type: torchscript
11 | parameters:
12 | filename: "loss_torchscript.pt"
13 |
14 | optimizer:
15 | type: adam
16 | parameters:
17 | learning_rate: 1e-3
18 | beta1: 0.9
19 | beta2: 0.999
20 | weight_decay: 0
21 | eps: 1e-8
22 | amsgrad: 0
23 |
24 | lr_scheduler:
25 | type: cosine_annealing
26 | parameters:
27 | T_max: 100000
28 |
--------------------------------------------------------------------------------
/examples/fortran/graph/generate_loss.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | import torch
30 |
class CustomLoss(torch.nn.Module):
    """Mean squared error evaluated only on nodes whose type is non-zero.

    Nodes with ``node_types == 0`` (the boundary nodes) are masked out and do not
    contribute to the loss.
    """

    def __init__(self):
        super().__init__()

    def forward(self, prediction: torch.Tensor, label: torch.Tensor, node_types: torch.Tensor) -> torch.Tensor:
        """Return the masked mean squared error between ``label`` and ``prediction``.

        Args:
            prediction: model output; indexed as (node, feature).
            label: target values, same shape as ``prediction``.
            node_types: one entry per node; entries equal to 0 are excluded.

        Returns:
            Scalar tensor: the summed squared error of the kept nodes divided by
            (number of kept nodes * number of features). When no node is kept the
            result is 0 instead of NaN.
        """
        # Squared error over all nodes
        sq_err = (label - prediction) ** 2

        # Zero out error for boundary nodes (out of place, so the autograd
        # intermediate is never modified)
        keep = node_types != 0
        sq_err = sq_err * keep.unsqueeze(-1)

        # Mean over non-boundary nodes; clamp the count to avoid 0/0 when every
        # node is a boundary node
        n_terms = torch.clamp(torch.sum(keep), min=1) * sq_err.shape[1]
        return torch.sum(sq_err) / n_terms
48 |
def main():
    """Create the custom loss module, script it and save it as ``loss_torchscript.pt``.

    The module is moved to the GPU before scripting when CUDA is usable; otherwise
    a message is printed and the module is saved from the CPU.
    """
    # Create loss module
    loss = CustomLoss()
    print("loss module:", loss)

    # Explicit availability check instead of a bare ``except:`` around ``.to("cuda")``,
    # which would also swallow unrelated errors such as KeyboardInterrupt.
    if torch.cuda.is_available():
        loss.to("cuda")
    else:
        print("PyTorch does not have CUDA support. Saving model on CPU.")

    # JIT and save
    loss_jit = torch.jit.script(loss)
    loss_jit.save("loss_torchscript.pt")


if __name__ == "__main__":
    main()
64 |
--------------------------------------------------------------------------------
/examples/fortran/graph/generate_model.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | import torch
30 |
class MessagePassing(torch.nn.Module):
    """One message-passing step with residual updates of node and edge features.

    Args:
        hidden_dim: feature width of both node and edge features.
    """

    def __init__(self, hidden_dim):
        super().__init__()

        # Edge MLP input: sender features, receiver features and edge features (3 * hidden_dim)
        self.mlp_edge = torch.nn.Sequential(
            torch.nn.Linear(3 * hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
        )
        # Node MLP input: node features and aggregated edge features (2 * hidden_dim)
        self.mlp_node = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
        )

    def forward(self, edge_idx: torch.Tensor, node_feats: torch.Tensor, edge_feats: torch.Tensor):
        """Run one message-passing step.

        Args:
            edge_idx: integer tensor indexed as (edge, 2); column 0 holds the sender
                node index and column 1 the receiver node index of each edge.
            node_feats: node features indexed as (node, feature).
            edge_feats: edge features indexed as (edge, feature).

        Returns:
            Tuple ``(node_feats, edge_feats)`` with the residual-updated features.
        """
        senders = edge_idx[:, 0]
        receivers = edge_idx[:, 1]

        # Edge update from the features of both end points and the edge itself
        edge_input = torch.cat([node_feats[senders], node_feats[receivers], edge_feats], dim=-1)
        edge_update = self.mlp_edge(edge_input)

        # Sum the edge features arriving at each receiver node
        # NOTE(review): the aggregation uses the incoming ``edge_feats``, not ``edge_update``;
        # kept as is - confirm this is intended.
        n_nodes = node_feats.shape[0]
        edge_dim = edge_feats.shape[1]
        aggregated = torch.zeros([n_nodes, edge_dim], dtype=edge_feats.dtype, device=edge_feats.device)
        scatter_index = receivers.unsqueeze(-1).expand(-1, edge_dim)
        aggregated = torch.scatter_add(aggregated, src=edge_feats, index=scatter_index, dim=0)

        # Node update from the node's own features and the aggregated edge features
        node_input = torch.cat([node_feats, aggregated], dim=-1)
        node_update = self.mlp_node(node_input)

        # Residual connections
        edge_feats = edge_feats + edge_update
        node_feats = node_feats + node_update

        return node_feats, edge_feats
61 |
62 |
class Net(torch.nn.Module):
    """Encode - message passing - decode graph network.

    Args:
        in_node_features: number of input features per node; also the number of
            output features per node produced by the decoder.
        in_edge_features: number of input features per edge.
        hidden_dim: width of the encoded node and edge features.
        n_message_passing_steps: number of ``MessagePassing`` layers applied in sequence.
    """

    def __init__(self, in_node_features, in_edge_features, hidden_dim, n_message_passing_steps):
        super(Net, self).__init__()

        node_encoder_layers = [
            torch.nn.Linear(in_node_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
        ]
        self.encoder_node = torch.nn.Sequential(*node_encoder_layers)

        edge_encoder_layers = [
            torch.nn.Linear(in_edge_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
        ]
        self.encoder_edge = torch.nn.Sequential(*edge_encoder_layers)

        self.mp_layers = torch.nn.ModuleList(
            [MessagePassing(hidden_dim) for _ in range(n_message_passing_steps)]
        )

        decoder_layers = [
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, in_node_features),
        ]
        self.decoder = torch.nn.Sequential(*decoder_layers)

    def forward(self, edge_idx, node_feats, edge_feats):
        """Return decoded node features after all message-passing steps."""
        # Encode node and edge features
        nodes = self.encoder_node(node_feats)
        edges = self.encoder_edge(edge_feats)

        # Message passing
        for layer in self.mp_layers:
            nodes, edges = layer(edge_idx, nodes, edges)

        # Decode node features
        return self.decoder(nodes)
96 |
97 |
def main():
    """Create the graph model, script it and save it as ``model_torchscript.pt``.

    The model is built as ``Net(1, 3, 128, 8)``: 1 node feature, 3 edge features,
    hidden width 128 and 8 message-passing steps. It is moved to the GPU before
    scripting when CUDA is usable; otherwise a message is printed and the model is
    saved from the CPU.
    """
    # Create model
    model = Net(1, 3, 128, 8)
    print("graph model:", model)

    # Explicit availability check instead of a bare ``except:`` around ``.to("cuda")``,
    # which would also swallow unrelated errors such as KeyboardInterrupt.
    if torch.cuda.is_available():
        model.to("cuda")
    else:
        print("PyTorch does not have CUDA support. Saving model on CPU.")

    # JIT and save
    model_jit = torch.jit.script(model)
    model_jit.save("model_torchscript.pt")


if __name__ == "__main__":
    main()
113 |
--------------------------------------------------------------------------------
/examples/fortran/graph/media/validation_results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/fortran/graph/media/validation_results.gif
--------------------------------------------------------------------------------
/examples/fortran/graph/visualize.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | import argparse as ap
30 | import glob
31 | import matplotlib.pyplot as plt
32 | import matplotlib.tri as tri
33 | from matplotlib.animation import FuncAnimation, PillowWriter
34 | import numpy as np
35 | import os
36 | import time
37 |
def main(args):
    """Animate reference vs. predicted fields on the mesh and save them as a GIF.

    Reads ``reference_*.txt`` and ``prediction_*.txt`` from ``args.input_path``
    (paired by sorted order), the mesh from ``nodes.txt`` / ``connectivity.txt``
    in the current working directory, and writes ``validation_results.gif``
    into ``args.output_path``.

    Raises:
        FileNotFoundError: if no reference or no prediction files are found.
    """
    # Kept as module globals for backward compatibility with the original script.
    global reffiles, predfiles, artists, triangulation
    print(f"processing files in {args.input_path}...")

    reffiles = sorted(glob.glob(os.path.join(args.input_path, "reference_*.txt")))
    predfiles = sorted(glob.glob(os.path.join(args.input_path, "prediction_*.txt")))
    if not reffiles or not predfiles:
        raise FileNotFoundError(
            f"no reference_*.txt / prediction_*.txt files found in {args.input_path}"
        )
    # Only animate frames for which both a reference and a prediction exist.
    n_frames = min(len(reffiles), len(predfiles))

    # Read mesh data
    nodes = np.loadtxt("nodes.txt", skiprows=1)
    triangles = np.loadtxt("connectivity.txt", skiprows=1)
    triangulation = tri.Triangulation(nodes[:, 0], nodes[:, 1], triangles)

    # Contour artists of the current frame; removed before the next frame is drawn.
    artists = []
    levels = np.linspace(-0.1, 1.0, 15)

    fig, (ax1, ax2) = plt.subplots(2, 1)
    for ax, title in ((ax1, "Ground Truth"), (ax2, "Prediction")):
        ax.set_title(title)
        ax.set_xlabel(r"$x$")
        ax.set_ylabel(r"$y$")
        # The mesh never changes, so draw it once. The original redrew it every
        # frame and never removed the old lines (triplot returns a list, whose
        # remove() call failed silently), so they accumulated.
        ax.triplot(triangulation, linewidth=0.3, color='black')

    def draw_contours(ax, filename):
        contour = ax.tricontourf(triangulation, np.loadtxt(filename), levels=levels)
        if callable(getattr(contour, "remove", None)):
            # Newer Matplotlib: the ContourSet is itself a removable artist.
            artists.append(contour)
        else:
            # Older Matplotlib: remove the per-level collections individually.
            artists.extend(contour.collections)

    def draw_frame(i):
        draw_contours(ax1, reffiles[i])
        draw_contours(ax2, predfiles[i])

    draw_frame(0)
    fig.tight_layout()

    def animate(i):
        for artist in artists:
            artist.remove()
        artists.clear()
        draw_frame(i)

    ani = FuncAnimation(fig, animate, frames=n_frames, repeat=False, interval=1)

    os.makedirs(args.output_path, exist_ok=True)

    def log(i, n):
        print(f"processed {i+1} of {n} frames..." )

    output_file = os.path.join(args.output_path, "validation_results.gif")
    ani.save(output_file, writer=PillowWriter(fps=5), progress_callback=log)
    print(f"video written to {os.path.join(args.output_path, 'validation_results.gif')}...")


if __name__ == "__main__":
    parser = ap.ArgumentParser()
    parser.add_argument("--input_path", type=str, help="Directory containing result text files", required=True)
    parser.add_argument("--output_path", type=str, help="Directory to store the generated videos", required=True)
    args = parser.parse_args()

    main(args)
108 |
109 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/CMakeLists.txt:
--------------------------------------------------------------------------------
# Build and install the Fortran "simulation" example (serial and distributed training drivers).
find_package(HDF5 COMPONENTS Fortran REQUIRED)

set(fortran_example_targets
  train
  train_distributed
)

add_executable(train)
target_sources(train
  PRIVATE
  train.f90
  simulation.f90
)
# Each executable gets its own module directory because both compile simulation.f90.
set_target_properties(train PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mod/0 )

add_executable(train_distributed)
target_sources(train_distributed
  PRIVATE
  train_distributed.f90
  simulation.f90
)
set_target_properties(train_distributed PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mod/1 )

foreach(tgt ${fortran_example_targets})
  target_include_directories(${tgt}
    PRIVATE
    ${CMAKE_BINARY_DIR}/include
    ${MPI_Fortran_INCLUDE_DIRS}
    ${HDF5_Fortran_INCLUDE_DIRS}
  )
  target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran)
  target_link_libraries(${tgt} PRIVATE hdf5::hdf5_fortran)
  target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort")
  target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME})
  # Compiler-specific flags, applied only to Fortran sources through a
  # COMPILE_LANGUAGE generator expression.
  if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -acc -gpu=${CUF_GPU_ARG}>)
    target_link_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>: -acc -gpu=${CUF_GPU_ARG}>)
  elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU")
    target_compile_options(${tgt} PRIVATE $<$<COMPILE_LANGUAGE:Fortran>:-cpp -fbackslash>)
  endif()
endforeach()

install(
  TARGETS ${fortran_example_targets}
  RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/simulation
)


# Ship the example configuration files and helper scripts next to the executables.
install(
  FILES ${CMAKE_CURRENT_SOURCE_DIR}/config_mlp_native.yaml
        ${CMAKE_CURRENT_SOURCE_DIR}/config_fcn_torchscript.yaml
        ${CMAKE_CURRENT_SOURCE_DIR}/generate_fcn_model.py
        ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py
  DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/simulation)
55 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/config_fcn_torchscript.yaml:
--------------------------------------------------------------------------------
# Configuration for the simulation example using a TorchScript model.
general:
  # NOTE(review): presumably enables the wandb-style metric log lines; confirm against the config docs.
  enable_wandb_hook: 1
  report_frequency: 100

model:
  type: torchscript
  parameters:
    # generate_fcn_model.py in this directory saves a scripted model under this file name.
    filename: "fcn_torchscript.pt"

loss:
  type: MSE

optimizer:
  type: adam
  parameters:
    learning_rate: 1e-3
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0
    eps: 1e-8
    amsgrad: 0

lr_scheduler:
  type: cosine_annealing
  parameters:
    # NOTE(review): presumably the T_max argument of the cosine annealing scheduler; confirm.
    T_max: 100000
27 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/config_mlp_native.yaml:
--------------------------------------------------------------------------------
# Configuration for the simulation example using the built-in "mlp" model type.
general:
  # NOTE(review): presumably enables the wandb-style metric log lines; confirm against the config docs.
  enable_wandb_hook: 1
  report_frequency: 100

model:
  type: mlp
  parameters:
    dropout: 0.0
    # NOTE(review): presumably the widths of consecutive MLP layers; confirm against the model docs.
    layer_sizes: [1024, 1024]

loss:
  type: MSE

optimizer:
  type: adam
  parameters:
    learning_rate: 1e-3
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0
    eps: 1e-8
    amsgrad: 0

lr_scheduler:
  type: cosine_annealing
  parameters:
    # NOTE(review): presumably the T_max argument of the cosine annealing scheduler; confirm.
    T_max: 100000
28 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/generate_fcn_model.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: BSD-3-Clause
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | import torch
30 |
class Net(torch.nn.Module):
    """Single-layer convolutional network.

    Holds one 3x3 convolution with one input and one output channel that uses
    circular padding of width 1.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
        )

    def forward(self, x):
        # Apply the single convolution layer to the input.
        out = self.conv1(x)
        return out
38 |
39 |
def main():
    """Create the convolutional model, compile it with TorchScript and save it.

    The scripted model is written to ``fcn_torchscript.pt`` in the current
    working directory.
    """
    # Create model
    model = Net()
    print("FCN model:", model)

    try:
        # Move model to GPU, JIT, and save
        model.to("cuda")
    except Exception:
        # Catch only regular errors (e.g. a CPU-only PyTorch build or no usable
        # device). A bare ``except`` would also swallow KeyboardInterrupt and
        # SystemExit.
        print("PyTorch does not have CUDA support. Saving model on CPU.")
    model_jit = torch.jit.script(model)
    model_jit.save("fcn_torchscript.pt")


if __name__ == "__main__":
    main()
55 |
--------------------------------------------------------------------------------
/examples/fortran/simulation/media/validation_results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/fortran/simulation/media/validation_results.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # basic packages
2 | ruamel-yaml
3 |
4 | # pytorch and some dependencies
5 | torch==2.7.0
6 | torchvision==0.22.0
7 | torchaudio==2.7.0
8 |
9 | # training monitoring
10 | wandb
11 |
12 | # RL example visualization related
13 | pygame
14 | moviepy
15 |
16 | # Supervised learning example visualization related
17 | matplotlib
18 | h5py
19 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/base_loss.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 |
35 | #include
36 |
37 | #include "internal/base_loss.h"
38 | #include "internal/param_map.h"
39 |
40 | namespace torchfort {
41 |
42 | struct BaseLoss : torch::nn::Module {
43 | virtual torch::Tensor forward(const std::vector& inputs,
44 | const std::vector& labels,
45 | const std::vector& extra_args) = 0;
46 | virtual void setup(const ParamMap& params) = 0;
47 | };
48 |
49 | } // namespace torchfort
50 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/base_lr_scheduler.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 | #include
33 | #include
34 | #include
35 |
36 | #include
37 |
38 | #include "internal/exceptions.h"
39 |
40 | namespace torchfort {
41 |
42 | class BaseLRScheduler : public torch::optim::LRScheduler {
43 | public:
44 | BaseLRScheduler(torch::optim::Optimizer& optimizer) : LRScheduler(optimizer) {}
45 |
46 | // Define generic save/load functionalities. Specialize in derived schedulers if
47 | // needed.
48 | void save(const std::string& fname) {
49 | torch::serialize::OutputArchive archive;
50 | archive.write("step_count", torch::IValue((int64_t)step_count_));
51 | archive.write("lrs", torch::IValue(get_current_lrs()));
52 | archive.save_to(fname);
53 | }
54 | void load(const std::string& fname, torch::optim::Optimizer& optimizer) {
55 | torch::serialize::InputArchive archive;
56 | archive.load_from(fname);
57 |
58 | torch::IValue ivalue;
59 | if (!archive.try_read("step_count", ivalue)) {
60 | THROW_INVALID_USAGE(fname + " is missing required data.");
61 | }
62 | int64_t step_count = ivalue.to();
63 | step_count_ = step_count;
64 |
65 | if (!archive.try_read("lrs", ivalue)) {
66 | THROW_INVALID_USAGE(fname + " is missing required data.");
67 | }
68 | auto lrs = ivalue.to>();
69 | // Can't use this method to set the LRs due to it being private in the base LR class.
70 | // set_optimizer_lrs(lrs);
71 | for (const auto i : c10::irange(optimizer.param_groups().size())) {
72 | optimizer.param_groups()[i].options().set_lr(lrs[i]);
73 | }
74 | }
75 | };
76 |
77 | } // namespace torchfort
78 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/base_model.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 |
35 | #include "internal/param_map.h"
36 |
37 | namespace torchfort {
38 |
39 | struct BaseModel : torch::nn::Module {
40 | virtual std::vector forward(const std::vector& inputs) = 0;
41 | virtual void setup(const ParamMap& params) = 0;
42 | };
43 |
44 | } // namespace torchfort
45 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/distributed.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #ifdef ENABLE_GPU
34 | #include
35 | #include
36 | #endif
37 | #include
38 |
39 | #include
40 |
41 | namespace torchfort {
42 |
43 | struct Comm {
44 | void initialize(bool initialize_nccl = false);
45 | void finalize();
46 | void allreduce(torch::Tensor& tensor, bool average = false) const;
47 | void allreduce(std::vector& tensors, bool average = false) const;
48 | void allreduce(double& val, bool average = false) const;
49 | void allreduce(float& val, bool average = false) const;
50 | void broadcast(torch::Tensor& tensor, int root = 0) const;
51 |
52 | int rank;
53 | int size;
54 | MPI_Comm mpi_comm;
55 | #ifdef ENABLE_GPU
56 | ncclComm_t nccl_comm = nullptr;
57 | cudaStream_t stream = nullptr;
58 | cudaEvent_t event = nullptr;
59 | #endif
60 | bool initialized = false;
61 |
62 | Comm(MPI_Comm mpi_comm) : mpi_comm(mpi_comm){};
63 | };
64 |
65 | } // namespace torchfort
66 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/logging.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 | #include
35 |
36 | #include "internal/model_pack.h"
37 |
38 | namespace torchfort {
39 |
namespace logging {

// Message categories accepted by write() and print().
enum level { info, warn, error, wandb };

// Returns a string for the given level (presumably the prefix prepended to
// log lines; confirm in the implementation).
std::string log_level_prefix(level log_level);
// Emits `message` with the given level to the file `filename`.
void write(const std::filesystem::path& filename, const std::string& message, level log_level);
// Emits `message` with the given level (destination defined in the implementation).
void print(const std::string& message, level log_level);

} // namespace logging
49 |
50 | // Declaration of external global variables
51 | extern std::unordered_map models;
52 |
53 | // specialized logging routines
54 | template
55 | void wandb_log(std::shared_ptr state, std::shared_ptr comm, const char* name, const char* metric_name,
56 | int64_t step, T value) {
57 | if (state->enable_wandb_hook) {
58 | std::stringstream os;
59 | os << "model: " << name << ", ";
60 | os << "step: " << step << ", ";
61 | os << metric_name << ": " << value;
62 | if (!comm || (comm && comm->rank == 0)) {
63 | torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb);
64 | }
65 | }
66 | }
67 |
68 | template void wandb_log(const char* name, const char* metric_name, int64_t step, T value) {
69 | auto state = models[name].state.get();
70 | if (state->enable_wandb_hook) {
71 | std::stringstream os;
72 | os << "model: " << name << ", ";
73 | os << "step: " << step << ", ";
74 | os << metric_name << ": " << value;
75 | if (!models[name].comm || (models[name].comm && models[name].comm->rank == 0)) {
76 | torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb);
77 | }
78 | }
79 | }
80 |
81 | } // namespace torchfort
82 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/losses.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 | #include
35 |
36 | #include
37 | #include
38 | #include
39 |
40 | #include "internal/base_loss.h"
41 | #include "internal/defines.h"
42 | #include "internal/param_map.h"
43 |
44 | namespace torchfort {
45 |
46 | struct L1Loss : BaseLoss {
47 | void setup(const ParamMap& params) override;
48 |
49 | torch::Tensor forward(const std::vector& inputs,
50 | const std::vector& labels,
51 | const std::vector& extra_args) override;
52 |
53 | torch::nn::L1Loss module;
54 | };
55 |
56 | struct MSELoss : BaseLoss {
57 | void setup(const ParamMap& params) override;
58 |
59 | torch::Tensor forward(const std::vector& inputs,
60 | const std::vector& labels,
61 | const std::vector& extra_args) override;
62 |
63 | torch::nn::MSELoss module;
64 | };
65 |
66 | struct TorchscriptLoss : BaseLoss {
67 | void setup(const ParamMap& params) override;
68 |
69 | torch::Tensor forward(const std::vector& inputs,
70 | const std::vector& labels,
71 | const std::vector& extra_args) override;
72 |
73 | std::shared_ptr module_jit;
74 | };
75 |
76 | // Creating loss_registry.
77 | BEGIN_LOSS_REGISTRY
78 |
79 | // Add entries for new losses in this section. First argument to REGISTER_LOSS is
80 | // a string key and the second argument is the class name.
81 | REGISTER_LOSS(L1, L1Loss)
82 | REGISTER_LOSS(MSE, MSELoss)
83 | REGISTER_LOSS(torchscript, TorchscriptLoss)
84 |
85 | END_LOSS_REGISTRY
86 |
87 | } // namespace torchfort
88 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/lr_schedulers.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 | #include
33 |
34 | #include
35 |
36 | #include "internal/base_lr_scheduler.h"
37 |
38 | namespace torchfort {
39 | // LR scheduler built from an optimizer, a T_max value and an eta_min value (default 0.0); the schedule itself is in the out-of-view get_lrs()/update_lr() definitions.
40 | class CosineAnnealingLR : public BaseLRScheduler {
41 | public:
42 | CosineAnnealingLR(torch::optim::Optimizer& optimizer, const unsigned T_max, const double eta_min = 0.0);
43 |
44 | private:
45 | std::vector get_lrs() override; // overrides the BaseLRScheduler hook
46 | double update_lr(const double& last_lr, const double& base_lr); // private helper taking a previous and a base learning rate
47 |
48 | const unsigned T_max_; // NOTE(review): presumably stores the constructor's T_max — confirm in the .cpp
49 | const double eta_min_; // NOTE(review): presumably stores the constructor's eta_min — confirm in the .cpp
50 | std::vector base_lrs_; // NOTE(review): presumably the optimizer's initial learning rates — confirm in the .cpp
51 | };
52 | // LR scheduler built from an optimizer, a list of milestones and a factor gamma (default 0.1); the schedule itself is in the out-of-view get_lrs() definition.
53 | class MultiStepLR : public BaseLRScheduler {
54 | public:
55 | MultiStepLR(torch::optim::Optimizer& optimizer, const std::vector& milestones, const double gamma = 0.1);
56 |
57 | private:
58 | std::vector get_lrs() override; // overrides the BaseLRScheduler hook
59 |
60 | const std::vector milestones_; // NOTE(review): presumably a copy of the constructor's milestones — confirm in the .cpp
61 | const double gamma_; // NOTE(review): presumably the multiplicative decay applied at each milestone — confirm in the .cpp
62 | };
63 | // LR scheduler built from an optimizer, a total_iters count and an exponent power (default 1.0); the schedule itself is in the out-of-view get_lrs() definition.
64 | class PolynomialLR : public BaseLRScheduler {
65 | public:
66 | PolynomialLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double power = 1.0);
67 |
68 | private:
69 | std::vector get_lrs() override; // overrides the BaseLRScheduler hook
70 |
71 | const unsigned total_iters_; // NOTE(review): presumably stores the constructor's total_iters — confirm in the .cpp
72 | const double power_; // NOTE(review): presumably stores the constructor's power — confirm in the .cpp
73 | };
74 | // LR scheduler built from an optimizer, a step_size and a factor gamma (default 0.1); the schedule itself is in the out-of-view get_lrs() definition.
75 | class StepLR : public BaseLRScheduler {
76 | public:
77 | StepLR(torch::optim::Optimizer& optimizer, const unsigned step_size, const double gamma = 0.1);
78 |
79 | private:
80 | std::vector get_lrs() override; // overrides the BaseLRScheduler hook
81 |
82 | const unsigned step_size_; // NOTE(review): presumably the number of steps between decays — confirm in the .cpp
83 | const double gamma_; // NOTE(review): presumably the multiplicative decay factor — confirm in the .cpp
84 | };
85 | // LR scheduler built from an optimizer, a total_iters count and two factors (start_factor default 0.333, end_factor default 1.0); the schedule itself is in the out-of-view get_lrs() definition.
86 | class LinearLR : public BaseLRScheduler {
87 | public:
88 | LinearLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double start_factor = 0.333,
89 | const double end_factor = 1.0);
90 |
91 | private:
92 | std::vector get_lrs() override; // overrides the BaseLRScheduler hook
93 |
94 | const unsigned total_iters_; // NOTE(review): presumably the number of steps over which the factor changes — confirm in the .cpp
95 | const double start_factor_, end_factor_; // NOTE(review): presumably the LR multipliers at the start and end of the ramp — confirm in the .cpp
96 | };
97 |
98 | } // namespace torchfort
99 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/model_pack.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 | #include
33 |
34 | #include
35 |
36 | #include "internal/base_loss.h"
37 | #include "internal/base_lr_scheduler.h"
38 | #include "internal/distributed.h"
39 | #include "internal/model_state.h"
40 | #include "internal/model_wrapper.h"
41 |
42 | namespace torchfort {
43 |
44 | // Simple struct to group model, optimizer, lr scheduler, loss, comm, and state objects (all held by shared_ptr)
45 | struct ModelPack {
46 | std::shared_ptr model;
47 | std::shared_ptr optimizer;
48 | std::shared_ptr lr_scheduler;
49 | std::shared_ptr loss;
50 | std::shared_ptr comm;
51 | std::shared_ptr state;
52 | int grad_accumulation_steps = 1; // defaults to 1
53 | };
54 | // Save/load entry points for a ModelPack, defined outside this header; each takes a path argument `fname` and a flag that defaults to true (NOTE(review): presumably controls whether optimizer data is included — confirm).
55 | void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true);
56 | void load_model_pack(ModelPack& model_pack, const std::string& fname, bool load_optimizer = true);
57 |
58 | } // namespace torchfort
59 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/model_state.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 | #include
33 | #include
34 |
35 | #include
36 |
37 | namespace torchfort {
38 |
39 | // Simple struct to store miscellaneous model state (e.g. iteration count); scalar members are zero/false-initialized so a fresh instance never holds indeterminate values
40 | struct ModelState {
41 | int64_t step_train = 0;
42 | int64_t step_inference = 0;
43 | int64_t step_train_current = 0; // training step of current run (ignoring restarted state)
44 | torch::Device device = torch::Device(torch::kCPU); // defaults to the CPU device
45 |
46 | // General option settings
47 | int32_t report_frequency = 0;
48 | bool enable_wandb_hook = false;
49 | bool verbose = false;
50 | std::filesystem::path report_file; // default-constructed (empty) path
51 |
52 | void save(const std::string& fname); // defined outside this header
53 | void load(const std::string& fname); // defined outside this header
54 | };
55 |
56 | } // namespace torchfort
57 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/model_wrapper.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 |
35 | #include "internal/base_model.h"
36 |
37 | namespace torchfort {
38 | // Wrapper exposing one interface (parameters, device moves, train/eval, forward, save/load) over either `model` or `model_jit`; NOTE(review): the private `jit` flag presumably selects which one is active — confirm in the .cpp.
39 | class ModelWrapper {
40 | public:
41 | ModelWrapper(const std::shared_ptr& model); // construct from an existing model object
42 |
43 | ModelWrapper(const std::shared_ptr& model_jit); // construct from an existing JIT module object
44 |
45 | ModelWrapper(const std::string& jit_model_fname); // construct from a file name (NOTE(review): presumably loads a TorchScript module — confirm)
46 |
47 | std::vector parameters() const;
48 |
49 | torch::OrderedDict named_parameters() const;
50 |
51 | void to(torch::Device device, bool non_blocking = false); // non_blocking defaults to false
52 |
53 | void train();
54 |
55 | void eval();
56 |
57 | std::vector forward(const std::vector& inputs) const;
58 |
59 | void save(const std::string& fname) const;
60 |
61 | void load(const std::string& fname);
62 |
63 | torch::Device device() const; // NOTE(review): presumably returns device_ — confirm in the .cpp
64 |
65 | private:
66 | bool jit = false; // defaults to false
67 | std::shared_ptr model;
68 | std::shared_ptr model_jit;
69 | torch::Device device_ = torch::Device(torch::kCPU); // defaults to the CPU device
70 | };
71 |
72 | } // namespace torchfort
73 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/models.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 |
35 | #include
36 |
37 | #include "internal/base_model.h"
38 | #include "internal/defines.h"
39 | #include "internal/param_map.h"
40 |
41 | namespace torchfort {
42 |
43 | // MLP model in C++ using libtorch; derives from BaseModel and overrides its setup()/forward() hooks (definitions are outside this header)
44 | struct MLPModel : BaseModel, public std::enable_shared_from_this {
45 | void setup(const ParamMap& params) override; // receives the model options as a ParamMap
46 | std::vector forward(const std::vector& inputs) override;
47 |
48 | double dropout; // NOTE(review): no initializer here; presumably assigned in setup() — confirm
49 | std::vector layer_sizes;
50 |
51 | // Use one of many "standard library" modules.
52 | std::vector fc_layers;
53 | std::vector biases;
54 | };
55 | // MLP variant for SAC; derives from BaseModel and overrides its setup()/forward() hooks (definitions are outside this header)
56 | struct SACMLPModel : BaseModel, public std::enable_shared_from_this {
57 | void setup(const ParamMap& params) override; // receives the model options as a ParamMap
58 | std::vector forward(const std::vector& inputs) override;
59 |
60 | double dropout; // NOTE(review): no initializer here; presumably assigned in setup() — confirm
61 | std::vector layer_sizes;
62 | bool state_dependent_sigma; // NOTE(review): no initializer here; presumably assigned in setup() — confirm
63 |
64 | // A SAC Model has a common encoder and two output layers for mu and log-sigma
65 | std::vector encoder_layers;
66 | std::vector out_layers;
67 | std::vector biases;
68 | std::vector out_biases;
69 | };
70 | // Actor-critic MLP; derives from BaseModel and overrides its setup()/forward() hooks (definitions are outside this header)
71 | struct ActorCriticMLPModel : BaseModel, public std::enable_shared_from_this {
72 | void setup(const ParamMap& params) override; // receives the model options as a ParamMap
73 | std::vector forward(const std::vector& inputs) override;
74 |
75 | double dropout; // NOTE(review): no initializer here; presumably assigned in setup() — confirm
76 | std::vector encoder_layer_sizes, actor_layer_sizes, value_layer_sizes; // one size list per sub-network
77 | bool state_dependent_sigma; // NOTE(review): no initializer here; presumably assigned in setup() — confirm
78 |
79 | // An AC Model has a common encoder and then an MLP for actor and one for value
80 | std::vector encoder_layers, actor_layers, value_layers;
81 | std::vector encoder_biases, actor_biases, value_biases;
82 | };
83 |
84 | // Creating model_registry: the BEGIN/END macros (defined outside this header) bracket one REGISTER_MODEL entry per model.
85 | BEGIN_MODEL_REGISTRY
86 |
87 | // Add entries for new models in this section (first argument: key, second argument: class name).
88 | REGISTER_MODEL(MLP, MLPModel)
89 | REGISTER_MODEL(SACMLP, SACMLPModel)
90 | REGISTER_MODEL(ActorCriticMLP, ActorCriticMLPModel)
91 |
92 | END_MODEL_REGISTRY
93 |
94 | } // namespace torchfort
95 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/nvtx.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 |
35 | #ifdef ENABLE_GPU
36 | #include
37 | #endif
38 |
39 | namespace torchfort {
40 |
41 | // Helper class for NVTX ranges: static rangePush/rangePop wrappers that compile to no-ops when ENABLE_GPU is not defined
42 | class nvtx {
43 | public:
44 | #ifdef ENABLE_GPU
45 | static void rangePush(const std::string& range_name) { // opens a range labelled range_name
46 | static constexpr int ncolors_ = 8;
47 | static constexpr int colors_[ncolors_] = {0x3366CC, 0xDC3912, 0xFF9900, 0x109618,
48 | 0x990099, 0x3B3EAC, 0x0099C6, 0xDD4477};
49 | std::hash hash_fn;
50 | int color = colors_[hash_fn(range_name) % ncolors_]; // hash of the name picks the palette entry, so equal names get equal colors
51 | nvtxEventAttributes_t ev = {0};
52 | ev.version = NVTX_VERSION;
53 | ev.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
54 | ev.colorType = NVTX_COLOR_ARGB;
55 | ev.color = color;
56 | ev.messageType = NVTX_MESSAGE_TYPE_ASCII;
57 | ev.message.ascii = range_name.c_str();
58 | nvtxRangePushEx(&ev);
59 | }
60 |
61 | static void rangePop() { nvtxRangePop(); } // closes the most recently pushed range
62 | #else
63 | static void rangePush(const std::string& range_name) {} // no-op without ENABLE_GPU
64 | static void rangePop() {} // no-op without ENABLE_GPU
65 | #endif
66 | };
67 |
68 | } // namespace torchfort
69 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/param_map.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 | #include
35 | #include
36 | #include
37 | #include
38 | #include
39 |
40 | #include "internal/exceptions.h"
41 | #include "internal/utils.h"
42 |
43 | namespace torchfort {
44 |
45 | // Helper function to get type as string: returns "int", "float", "double" or "bool" for the matching std::is_same check, "UNKNOWN" otherwise
46 | template std::string type_string() {
47 | if (std::is_same::value) {
48 | return "int";
49 | } else if (std::is_same::value) {
50 | return "float";
51 | } else if (std::is_same::value) {
52 | return "double";
53 | } else if (std::is_same::value) {
54 | return "bool";
55 | }
56 | return "UNKNOWN"; // fallback when none of the checks above matched
57 | };
58 |
59 | // Conversion functor: converts one stored parameter string to T; any std::stoi/stof/stod failure is reported through THROW_INVALID_USAGE
60 | template struct ParamMapConverter {
61 | T operator()(const std::string& s) {
62 | try {
63 | if constexpr (std::is_same::value) {
64 | return std::stoi(sanitize(s));
65 | }
66 | if constexpr (std::is_same::value) {
67 | return std::stof(sanitize(s));
68 | }
69 | if constexpr (std::is_same::value) {
70 | return std::stod(sanitize(s));
71 | }
72 | if constexpr (std::is_same::value) {
73 | std::string s_ = sanitize(s);
74 | bool val;
75 | if (s_ == "true") {
76 | val = true;
77 | } else if (s_ == "false") {
78 | val = false;
79 | } else {
80 | val = std::stoi(s_); // any other text is parsed as an integer and converted to bool
81 | }
82 | return val;
83 | }
84 | if constexpr (std::is_same::value) {
85 | return s; // returned as-is, without sanitize()
86 | }
87 | } catch (const std::logic_error&) { // common base of std::invalid_argument (not a number) and std::out_of_range (overflow)
88 | THROW_INVALID_USAGE("Could not convert provided parameter value " + s + " to required type.");
89 | }
90 |
91 | THROW_INTERNAL_ERROR("Unknown conversion type."); // reached only when no if-constexpr branch above returned
92 | }
93 | };
94 | // Keyed parameter store: add_param() writes a list of values under a key, get_param() reads a list back, converting each element with ParamMapConverter.
95 | class ParamMap {
96 | public:
97 | template void add_param(const std::string& key, const std::vector& value); // stores value under the sanitized key, replacing any existing entry
98 |
99 | template std::vector get_param(const std::string& key) const; // lookup uses params.at(), so a missing key throws std::out_of_range
100 |
101 | template std::vector get_param(const std::string& key, const T& defval) const; // variant that can return {defval} instead of throwing
102 |
103 | std::set keys() const; // defined outside this header
104 |
105 | private:
106 | std::unordered_map> params; // keyed by sanitize(key)
107 | };
108 | // Stores `value` under sanitize(key); operator[] assignment means an existing entry for that key is overwritten.
109 | template void ParamMap::add_param(const std::string& key, const std::vector& value) {
110 | params[sanitize(key)] = value;
111 | }
112 | // Returns the entry stored under sanitize(key), with every element converted by ParamMapConverter; params.at() throws std::out_of_range when the key is absent.
113 | template std::vector ParamMap::get_param(const std::string& key) const {
114 | const auto& entry = params.at(sanitize(key));
115 | std::vector values;
116 | std::transform(entry.begin(), entry.end(), std::back_inserter(values), ParamMapConverter()); // element-wise conversion, order preserved
117 | return values;
118 | }
119 |
120 | // parameter with default value: returns {defval} only when the key is absent; conversion errors for an existing key are no longer masked by the default
121 | template std::vector ParamMap::get_param(const std::string& key, const T& defval) const {
122 | const auto it = params.find(sanitize(key)); // explicit lookup instead of catching std::out_of_range from at()
123 | if (it == params.end()) {
124 | return {defval};
125 | }
126 | const auto& entry = it->second;
127 | std::vector values;
128 | std::transform(entry.begin(), entry.end(), std::back_inserter(values), ParamMapConverter()); // element-wise conversion, order preserved
129 | return values;
130 | }
131 |
132 | } // namespace torchfort
133 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/rl/distributions.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 | #include
33 | #include
34 |
35 | #include "internal/rl/rl.h"
36 |
37 | namespace torchfort {
38 |
39 | namespace rl {
40 |
41 | class Distribution { // abstract interface: rsample(), log_prob() and entropy() are pure virtual
42 |
43 | public:
44 | Distribution(const Distribution&) = delete; // non-copyable
45 | virtual ~Distribution() = default; // polymorphic base: makes destruction through a Distribution pointer well-defined
46 | // constructor
47 | Distribution() {}
48 | virtual torch::Tensor rsample() = 0; // returns a sample as a tensor
49 | virtual torch::Tensor log_prob(torch::Tensor value) = 0; // returns the log-probability of value as a tensor
50 | virtual torch::Tensor entropy() = 0; // returns the entropy as a tensor
51 | };
52 |
53 | class NormalDistribution : public Distribution, public std::enable_shared_from_this {
54 | public:
55 | NormalDistribution(torch::Tensor mu, torch::Tensor sigma) : mu_(mu), sigma_(sigma) {}
56 | // Sample computed as mu + sigma * eps with eps drawn from N(0, 1); shaped like mu.
57 | torch::Tensor rsample() override {
58 | auto noise = torch::empty_like(mu_).normal_(0., 1.);
59 | return torch::Tensor(mu_ + sigma_ * noise).clone();
60 | }
61 | // Element-wise log-density of value: -(value - mu)^2 / (2 sigma^2) - log(sigma) - log(sqrt(2 pi)).
62 | torch::Tensor log_prob(torch::Tensor value) override {
63 | auto var = torch::square(sigma_);
64 | auto log_sigma = sigma_.log();
65 | auto result = -torch::square(value - mu_) / (2 * var) - log_sigma - std::log(std::sqrt(2. * M_PI));
66 |
67 | return result;
68 | }
69 | // Element-wise entropy: log(sigma) + 0.5 * (1 + log(2 pi)).
70 | torch::Tensor entropy() override {
71 | auto log_sigma = sigma_.log();
72 | auto result = log_sigma + 0.5 * (1. + std::log(2. * M_PI));
73 |
74 | return result;
75 | }
76 |
77 | protected:
78 | torch::Tensor mu_; // mean tensor passed to the constructor
79 | torch::Tensor sigma_; // standard-deviation tensor passed to the constructor
80 | };
81 |
82 | } // namespace rl
83 |
84 | } // namespace torchfort
85 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/rl/rl.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include "internal/rl/off_policy.h"
34 | #include "internal/rl/on_policy.h"
35 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/rl/utils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 | #include
33 |
34 | #ifdef ENABLE_GPU
35 | #include
36 |
37 | #include
38 | #include
39 | #endif
40 | #include
41 |
42 | #include "internal/defines.h"
43 | #include "internal/logging.h"
44 | #include "internal/lr_schedulers.h"
45 | #include "internal/model_pack.h"
46 | #include "internal/rl/rl.h"
47 | #include "internal/setup.h"
48 |
49 | namespace torchfort {
50 |
51 | namespace rl {
52 |
53 | // helpers for sanitizing devices
54 | // checks a pair of device ids: if the two devices are different, one of them has to be a cpu device.
55 | bool validate_devices(int device1, int device2);
56 |
57 | // helper for extracting the current learning rates (LRs) from an optimizer
58 | std::vector get_current_lrs(std::shared_ptr optimizer);
59 |
60 | // helpers for manipulating weights and grads (all defined outside this header)
61 | void init_parameters(std::shared_ptr model);
62 | void copy_parameters(std::shared_ptr target, std::shared_ptr source); // NOTE(review): presumably copies source's parameters into target — confirm in the .cpp
63 | void set_grad_state(std::shared_ptr model, const bool requires_grad); // NOTE(review): presumably sets requires_grad on every parameter of model — confirm in the .cpp
64 |
65 | // polyak update for model averaging:
66 | // computes: target = rho * target + (1-rho) * src, in place on target's parameters
67 | // note that target here denotes the updated model parameters, and src the previous ones
68 | template
69 | void polyak_update(std::shared_ptr target, std::shared_ptr source, const T rho) {
70 |
71 | // add no grad guard so the in-place writes below are not recorded by autograd
72 | torch::NoGradGuard no_grad;
73 |
74 | // get the parameter lists of both models
75 | auto tar = target->parameters();
76 | auto src = source->parameters();
77 |
78 | // some simple asserts here (assert is compiled out when NDEBUG is defined, so release builds do not check the sizes)
79 | assert(tar.size() == src.size());
80 |
81 | // do in-place update: I don't know a good way of doing that with std::transform:
82 | for (size_t i = 0; i < tar.size(); ++i) {
83 | const auto& t = tar[i];
84 | const auto& s = src[i];
85 | t.mul_(rho); // t <- rho * t
86 | t.add_((1. - rho) * s); // t <- t + (1 - rho) * s
87 | // t.copy_(torch::Tensor(rho * t + (1.-rho) * s));
88 | }
89 |
90 | return;
91 | }
92 |
93 | // Rescale the action from [a_low, a_high] to [-1, 1]; the result is cast back to the dtype of unscaled_action
94 | template torch::Tensor scale_action(torch::Tensor unscaled_action, const T& a_low, const T& a_high) {
95 | auto scaled_action = static_cast(2.0) * ((unscaled_action - a_low) / (a_high - a_low)) - static_cast(1.0);
96 | scaled_action = scaled_action.to(unscaled_action.dtype()); // Tensor::to() returns the converted tensor; it does not modify in place
97 |
98 | return scaled_action;
99 | }
100 |
101 | // Unscale the action from [-1., 1.] to [a_low, a_high]; the result has the dtype of scaled_action.
102 | template <typename T> torch::Tensor unscale_action(torch::Tensor scaled_action, const T& a_low, const T& a_high) {
103 |   auto unscaled_action = 0.5 * (a_high - a_low) * (scaled_action + static_cast<T>(1.)) + a_low;
104 |   // Tensor::to() returns a new tensor instead of converting in place, so its result
105 |   // (not the intermediate value) must be returned for the dtype conversion to take effect.
106 |   return unscaled_action.to(scaled_action.dtype());
107 | }
108 |
109 | // explained variance of the predictions q_pred with respect to the reference values q_true
110 | torch::Tensor explained_variance(torch::Tensor q_pred, torch::Tensor q_true, std::shared_ptr comm); // NOTE(review): comm is presumably used to combine the result across ranks -- confirm
111 |
112 | } // namespace rl
113 |
114 | } // namespace torchfort
115 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/setup.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 |
35 | #include
36 | #include
37 |
38 | #include "internal/base_loss.h"
39 | #include "internal/base_lr_scheduler.h"
40 | #include "internal/base_model.h"
41 | #include "internal/lr_schedulers.h"
42 | #include "internal/model_state.h"
43 | #include "internal/model_wrapper.h"
44 |
45 | namespace torchfort {
46 | // Helpers that set up TorchFort objects from YAML configuration nodes (declarations only; defined elsewhere).
47 | void check_params(const std::set& supported_params, const std::set& provided_params); // NOTE(review): presumably rejects provided entries that are not in supported_params -- confirm
48 | // Returns a ParamMap for the YAML node params_node.
49 | ParamMap get_params(const YAML::Node& params_node);
50 | // Returns a shared pointer for the YAML node model_node (presumably the configured model -- confirm).
51 | std::shared_ptr get_model(const YAML::Node& model_node);
52 | // Returns a shared pointer for the YAML node loss_node (presumably the configured loss -- confirm).
53 | std::shared_ptr get_loss(const YAML::Node& loss_node);
54 | // Optimizer overload that takes the parameter list explicitly.
55 | std::shared_ptr get_optimizer(const YAML::Node& optimizer_node,
56 | std::vector parameters);
57 | // Optimizer overload that takes a model instead of a parameter list.
58 | std::shared_ptr get_optimizer(const YAML::Node& optimizer_node,
59 | const std::shared_ptr& model);
60 | // Returns a shared pointer for lr_scheduler_node; the optimizer is passed in by reference to shared_ptr.
61 | std::shared_ptr get_lr_scheduler(const YAML::Node& lr_scheduler_node,
62 | const std::shared_ptr& optimizer);
63 | // Returns a shared pointer for state_node; NOTE(review): the role of name is not visible here -- confirm.
64 | std::shared_ptr get_state(const char* name, const YAML::Node& state_node);
65 | } // namespace torchfort
66 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/tensor_list.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 |
35 | #include
36 |
37 | #include "internal/utils.h"
38 |
39 | namespace torchfort {
40 | struct TensorList { // list of tensors created from caller-provided buffers via get_tensor()
41 | template
42 | void add_tensor(T* data, size_t dim, int64_t* shape) {
43 | auto tensor = get_tensor(data, dim, shape); // tensor for the buffer data with dim dimensions taken from shape
44 | tensors.push_back(tensor); // working entry, may later be replaced by to()
45 | tensors_original_.push_back(tensor); // entry that reset() restores
46 | };
47 | // Replaces every entry of tensors by the result of Tensor::to(device, non_blocking); tensors_original_ is not touched.
48 | void to(torch::Device device, bool non_blocking = false) {
49 | for (auto &t : tensors) {
50 | t = t.to(device, non_blocking);
51 | }
52 | };
53 | // Restores tensors to the entries recorded by add_tensor(), dropping whatever to() put there.
54 | void reset() {
55 | tensors = tensors_original_;
56 | }
57 | // Working list; its entries are the ones modified by to() and reset().
58 | std::vector tensors;
59 | // To preserve references to external data, we store the original tensor objects
60 | std::vector tensors_original_;
61 | };
62 | } // namespace torchfort
63 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/training.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #ifdef ENABLE_GPU
34 | #include
35 | #endif
36 |
37 | #include
38 |
39 | #include
40 |
41 | namespace torchfort {
42 | // Multi-tensor entry points (declarations only). The wrappers in this header pass pointers to TensorList objects as the list handles.
43 | void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in,
44 | cudaStream_t ext_stream = 0);
45 | // The loss is reported through the float pointer loss_val; the train() wrapper in this header passes nullptr for extra_loss_args_in.
46 | void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t labels_in,
47 | float* loss_val, torchfort_tensor_list_t extra_loss_args_in, cudaStream_t ext_stream = 0);
48 | // Convenience wrapper for one input and one output buffer: wraps both in TensorLists and calls inference_multiarg.
49 | template
50 | void inference(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* output, size_t output_dim,
51 | int64_t* output_shape, cudaStream_t ext_stream = 0) {
52 | TensorList inputs, outputs;
53 | // each list receives exactly one tensor built from the caller's buffer and shape
54 | inputs.add_tensor(input, input_dim, input_shape);
55 | outputs.add_tensor(output, output_dim, output_shape);
56 | // hand the lists over by address together with the caller's stream
57 | inference_multiarg(name, &inputs, &outputs, ext_stream);
58 | }
59 | // Convenience wrapper for one input and one label buffer: calls train_multiarg and stores the loss in *loss_val.
60 | template
61 | void train(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* label, size_t label_dim,
62 | int64_t* label_shape, T* loss_val, cudaStream_t ext_stream = 0) {
63 | TensorList inputs, labels;
64 | // each list receives exactly one tensor built from the caller's buffer and shape
65 | inputs.add_tensor(input, input_dim, input_shape);
66 | labels.add_tensor(label, label_dim, label_shape);
67 |
68 | // multiarg API expects float loss value, so use temporary here
69 | float loss_val_tmp;
70 | train_multiarg(name, &inputs, &labels, &loss_val_tmp, nullptr, ext_stream); // nullptr: no extra loss arguments
71 | *loss_val = loss_val_tmp; // converts the float loss to T; loss_val is dereferenced unconditionally and must not be null
72 | }
73 |
74 | } // namespace torchfort
75 |
--------------------------------------------------------------------------------
/src/csrc/include/internal/utils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #include
34 | #include
35 | #include
36 |
37 | #include
38 | #include
39 |
40 | #ifdef ENABLE_GPU
41 | #include
42 | #endif
43 |
44 | #include "internal/exceptions.h"
45 | #include "internal/nvtx.h"
46 |
47 | namespace torchfort {
48 |
49 | // Returns the string converted to lowercase and with whitespace removed
50 | std::string sanitize(std::string s);
51 |
52 | // Returns a version of the string that can be used as a filename
53 | std::string filename_sanitize(std::string s);
54 |
55 | // Returns the torch device for an integer device value
56 | torch::Device get_device(int device);
57 |
58 | // Returns the torch device for a pointer; get_tensor() calls get_device() with the buffer pointer it wraps
59 | torch::Device get_device(const void* ptr);
60 | // Maps the template type onto a torch dtype (kFloat32, kInt32, kInt64 or kFloat64); other types end in THROW_INVALID_USAGE.
61 | template torch::Dtype make_type() {
62 | if (std::is_same::value) { // NOTE(review): the type arguments of the std::is_same checks are not visible in this listing
63 | return torch::kFloat32;
64 | } else if (std::is_same::value) {
65 | return torch::kInt32;
66 | } else if (std::is_same::value) {
67 | return torch::kInt64;
68 | } else if (std::is_same::value) {
69 | return torch::kFloat64;
70 | } else {
71 | THROW_INVALID_USAGE("datatype not implemented");
72 | }
73 | }
74 | // Layout selector used by get_tensor(): RowMajor copies the shape as given, ColMajor reverses the order of the dimensions.
75 | enum MemoryLayout { RowMajor = 0, ColMajor = 1 };
76 | // Wraps the buffer tensor_ptr in a tensor via torch::from_blob; the deleter passed in is empty, so the buffer is never freed by the tensor.
77 | template torch::Tensor get_tensor(T* tensor_ptr, size_t dim, int64_t* shape) {
78 | torchfort::nvtx::rangePush("get_tensor");
79 | // Set tensor options: the device is derived from the buffer pointer itself
80 | auto dev = get_device(tensor_ptr);
81 | torch::TensorOptions options = torch::TensorOptions().device(dev);
82 | // NOTE(review): rangePop() is only reached on the normal path; if a call in between throws, the range pushed above stays open
83 | // Get type
84 | auto type = make_type();
85 | options = options.dtype(type);
86 |
87 | // Create shape: dim entries are read from shape
88 | std::vector sizes(dim);
89 | switch (L) {
90 | case RowMajor:
91 | for (size_t i = 0; i < dim; ++i) {
92 | sizes[i] = shape[i];
93 | }
94 | break;
95 | case ColMajor:
96 | // For column major input data, reverse the shape order
97 | for (size_t i = 0; i < dim; ++i) {
98 | sizes[i] = shape[dim - i - 1];
99 | }
100 | break;
101 | }
102 | torch::IntArrayRef size = c10::makeArrayRef(sizes); // NOTE(review): size is not used afterwards; from_blob below is given sizes
103 |
104 | // Create tensor on top of the existing buffer
105 | auto tensor = torch::from_blob(
106 | tensor_ptr, sizes, [](void* ptr) {}, options);
107 | torchfort::nvtx::rangePop();
108 | return tensor;
109 | }
110 | // Accepted names are exactly "mean", "sum" and "none"; any other string ends in THROW_INVALID_USAGE.
111 | // Helper function to convert string reduction names to torch enums.
112 | template T get_torch_reduction(const std::string& s) {
113 | if (s == "mean") {
114 | return torch::kMean;
115 | } else if (s == "sum") {
116 | return torch::kSum;
117 | } else if (s == "none") {
118 | return torch::kNone;
119 | } else {
120 | THROW_INVALID_USAGE("Unknown reduction type encountered.");
121 | }
122 | }
123 |
124 | // Helper function returning a printable string for the shape of a tensor
125 | std::string print_tensor_shape(torch::Tensor tensor);
126 |
127 | // Helper function to get the current learning rates (lrs); NOTE(review): name presumably identifies a model -- confirm
128 | std::vector get_current_lrs(const char* name);
129 |
130 | } // namespace torchfort
131 |
--------------------------------------------------------------------------------
/src/csrc/include/torchfort_enums.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #pragma once
32 |
33 | #define TORCHFORT_DEVICE_CPU (-1) // NOTE(review): presumably the integer device value that selects the CPU -- confirm
34 |
35 | /**
36 | * @brief This enum defines the supported data types. All enumerators have negative values (-1 to -4).
37 | */
38 | enum torchfort_datatype_t { TORCHFORT_FLOAT = -1, TORCHFORT_DOUBLE = -2, TORCHFORT_INT32 = -3, TORCHFORT_INT64 = -4 };
39 |
40 | /**
41 | * @brief This enum defines the possible return values from TorchFort. Most functions in the TorchFort library
42 | * will return one of these values to indicate if an operation has completed successfully or an error occurred.
43 | */
44 | enum torchfort_result_t {
45 | TORCHFORT_RESULT_SUCCESS = 0, ///< The operation completed successfully
46 | TORCHFORT_RESULT_INVALID_USAGE = 1, ///< A user error, typically an invalid argument
47 | TORCHFORT_RESULT_NOT_SUPPORTED = 2, ///< A user error, requesting an invalid or unsupported operation configuration
48 | TORCHFORT_RESULT_INTERNAL_ERROR = 3, ///< An internal library error, should be reported
49 | TORCHFORT_RESULT_CUDA_ERROR = 4, ///< An error occurred in the CUDA Runtime
50 | TORCHFORT_RESULT_MPI_ERROR = 5, ///< An error occurred in the MPI library
51 | TORCHFORT_RESULT_NCCL_ERROR = 6 ///< An error occurred in the NCCL library
52 | };
53 |
--------------------------------------------------------------------------------
/src/csrc/logging.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #include
32 | #include
33 | #include
34 | #include
35 | #include
36 | #include
37 |
38 | #include "internal/exceptions.h"
39 | #include "internal/logging.h"
40 |
41 | namespace torchfort {
42 | namespace logging {
43 | // State shared by the logging functions in this file.
44 | std::mutex logging_mutex; // locked by write() for the whole duration of a log-file access
45 | static std::unique_ptr logfile; // log stream; assigned in open_logfile(), checked against nullptr in write()
46 | // Translate a log level into the fixed tag that starts every log line; unknown levels are rejected.
47 | std::string log_level_prefix(level log_level) {
48 |   switch (log_level) {
49 |   case level::info:
50 |     return "TORCHFORT::INFO:";
51 |   case level::warn:
52 |     return "TORCHFORT::WARN:";
53 |   case level::error:
54 |     return "TORCHFORT::ERROR:";
55 |   case level::wandb:
56 |     return "TORCHFORT::WANDB:";
57 |   default: THROW_INVALID_USAGE("Unknown log level encountered.");
58 |   }
59 | }
60 | // Print a message to stdout, preceded by the tag of its log level and a single space.
61 | void print(const std::string& message, level log_level) {
62 |   const std::string tag = log_level_prefix(log_level);
63 |   std::cout << tag << " " << message << std::endl;
64 | }
65 | // Open filename for appending, creating missing parent directories. Returns false if filename is empty or cannot be opened.
66 | bool open_logfile(const std::filesystem::path& filename) {
67 |   // an empty filename means that we do not want to log
68 |   if (filename.empty()) {
69 |     return false;
70 |   }
71 |   // make sure the directory that is going to hold the file exists
72 |   if (filename.has_parent_path()) {
73 |     std::filesystem::create_directories(filename.parent_path());
74 |   }
75 |   logfile = std::make_unique<std::ofstream>(filename, std::ofstream::out | std::ofstream::app);
76 |   // a stream that failed to open would drop every message and block later open attempts, so discard it
77 |   if (!logfile->is_open()) {
78 |     logfile.reset();
79 |     return false;
80 |   }
81 |   return true;
82 | }
83 | // Append one line to the log file: the tag of log_level, a space, the message and a newline.
84 | void write(const std::filesystem::path& filename, const std::string& message, level log_level) {
85 |   // only one caller at a time may touch the shared log stream
86 |   std::lock_guard guard(logging_mutex);
87 |   // open the log file on first use; leave if open_logfile() reports that there is none
88 |   if (!logfile && !open_logfile(filename)) {
89 |     return;
90 |   }
91 |   const std::string entry = log_level_prefix(log_level) + " " + message + "\n";
92 |   logfile->write(entry.data(), entry.size());
93 |   // push the entry out to the file right away
94 |   logfile->flush();
95 | }
96 |
97 | } // namespace logging
98 | } // namespace torchfort
99 |
--------------------------------------------------------------------------------
/src/csrc/losses/l1_loss.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: BSD-3-Clause
4 | *
5 | * Redistribution and use in source and binary forms, with or without
6 | * modification, are permitted provided that the following conditions are met:
7 | *
8 | * 1. Redistributions of source code must retain the above copyright notice, this
9 | * list of conditions and the following disclaimer.
10 | *
11 | * 2. Redistributions in binary form must reproduce the above copyright notice,
12 | * this list of conditions and the following disclaimer in the documentation
13 | * and/or other materials provided with the distribution.
14 | *
15 | * 3. Neither the name of the copyright holder nor the names of its
16 | * contributors may be used to endorse or promote products derived from
17 | * this software without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | */
30 |
31 | #include
32 | #include
33 | #include
34 | #include
35 |
36 | #include