├── .clang-format ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake └── test_mpi_f2c.f90 ├── docker ├── Dockerfile ├── Dockerfile_gnu └── Dockerfile_gnu_cpuonly ├── docs ├── Doxyfile ├── Makefile ├── _static │ └── style.css ├── _templates │ └── layout.html ├── api.rst ├── api │ ├── c_api.rst │ ├── config.rst │ └── f_api.rst ├── conf.py ├── extras.rst ├── index.rst ├── installation.rst ├── requirements.txt └── usage.rst ├── examples ├── cpp │ └── cart_pole │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── config.yaml │ │ ├── config_sim.yaml │ │ ├── env.cpp │ │ ├── env.h │ │ ├── media │ │ ├── cartpole.gif │ │ └── cartpole.mp4 │ │ ├── py_env.cpp │ │ ├── python │ │ ├── initialize_models.py │ │ ├── models.py │ │ └── visualize.py │ │ └── train.cpp └── fortran │ ├── graph │ ├── CMakeLists.txt │ ├── README.md │ ├── config.yaml │ ├── connectivity.txt │ ├── generate_loss.py │ ├── generate_model.py │ ├── media │ │ └── validation_results.gif │ ├── nodes.txt │ ├── train.f90 │ └── visualize.py │ └── simulation │ ├── CMakeLists.txt │ ├── README.md │ ├── config_fcn_torchscript.yaml │ ├── config_mlp_native.yaml │ ├── generate_fcn_model.py │ ├── media │ └── validation_results.gif │ ├── simulation.f90 │ ├── train.f90 │ ├── train_distributed.f90 │ └── visualize.py ├── requirements.txt ├── src ├── csrc │ ├── distributed.cpp │ ├── include │ │ ├── internal │ │ │ ├── base_loss.h │ │ │ ├── base_lr_scheduler.h │ │ │ ├── base_model.h │ │ │ ├── defines.h │ │ │ ├── distributed.h │ │ │ ├── exceptions.h │ │ │ ├── logging.h │ │ │ ├── losses.h │ │ │ ├── lr_schedulers.h │ │ │ ├── model_pack.h │ │ │ ├── model_state.h │ │ │ ├── model_wrapper.h │ │ │ ├── models.h │ │ │ ├── nvtx.h │ │ │ ├── param_map.h │ │ │ ├── rl │ │ │ │ ├── distributions.h │ │ │ │ ├── noise_actor.h │ │ │ │ ├── off_policy.h │ │ │ │ ├── off_policy │ │ │ │ │ ├── ddpg.h │ │ │ │ │ ├── sac.h │ │ │ │ │ └── td3.h │ │ │ │ ├── on_policy.h │ │ │ │ ├── on_policy │ │ │ │ │ └── ppo.h │ │ │ │ ├── policy.h │ │ │ │ ├── replay_buffer.h │ │ │ │ ├── rl.h │ │ │ │ ├── rollout_buffer.h │ │ │ │ └── utils.h │ │ │ ├── setup.h │ │ │ ├── tensor_list.h │ │ │ ├── training.h │ │ │ └── utils.h │ │ ├── torchfort.h │ │ ├── torchfort_enums.h │ │ └── torchfort_rl.h │ ├── logging.cpp │ ├── losses │ │ ├── l1_loss.cpp │ │ ├── mse_loss.cpp │ │ └── torchscript_loss.cpp │ ├── lr_schedulers │ │ ├── cosine_annealing_lr.cpp │ │ ├── linear_lr.cpp │ │ ├── multistep_lr.cpp │ │ ├── polynomial_lr.cpp │ │ ├── scheduler_setup.cpp │ │ └── step_lr.cpp │ ├── model_pack.cpp │ ├── model_state.cpp │ ├── model_wrapper.cpp │ ├── models │ │ ├── actor_critic_model.cpp │ │ ├── mlp_model.cpp │ │ └── sac_model.cpp │ ├── param_map.cpp │ ├── rl │ │ ├── off_policy │ │ │ ├── ddpg.cpp │ │ │ ├── interface.cpp │ │ │ ├── sac.cpp │ │ │ └── td3.cpp │ │ ├── on_policy │ │ │ ├── interface.cpp │ │ │ └── ppo.cpp │ │ ├── policy.cpp │ │ └── utils.cpp │ ├── setup.cpp │ ├── torchfort.cpp │ ├── training.cpp │ └── utils.cpp ├── fsrc │ └── torchfort_m.F90 └── python │ └── wandb_helper.py └── tests ├── general ├── CMakeLists.txt ├── configs │ ├── l1.yaml │ ├── l1_multiarg.yaml │ ├── mse.yaml │ ├── mse_multiarg.yaml │ ├── torchscript.yaml │ ├── torchscript_multiarg.yaml │ ├── torchscript_multiarg_extra.yaml │ └── torchscript_multiout.yaml ├── scripts │ └── setup_tests.py └── test_losses.cpp ├── rl ├── CMakeLists.txt ├── configs │ ├── ddpg.yaml │ ├── ppo.yaml │ ├── sac.yaml │ └── td3.yaml ├── environments.h ├── test_distributions.cpp ├── test_off_policy.cpp ├── test_on_policy.cpp ├── test_replay_buffer.cpp └── test_rollout_buffer.cpp ├── 
supervised ├── CMakeLists.txt ├── configs │ ├── missing_loss.yaml │ ├── missing_opt.yaml │ ├── mlp.yaml │ ├── mlp2.yaml │ ├── mlp2_gradacc.yaml │ ├── torchscript.yaml │ ├── torchscript_multiarg.yaml │ └── torchscript_multiarg_extra.yaml ├── scripts │ └── setup_tests.py ├── test_checkpoint.cpp └── test_training.cpp └── test_utils.h /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: LLVM 3 | ColumnLimit: 120 4 | CommentPragmas: '^\\.+' 5 | DerivePointerAlignment: false 6 | Language: Cpp 7 | PointerAlignment: Left 8 | UseTab: Never 9 | AlignAfterOpenBracket: Align 10 | AlignTrailingComments: true 11 | AllowShortBlocksOnASingleLine: false 12 | AllowShortCaseLabelsOnASingleLine : false 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | ... 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: BSD-3-Clause 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchFort 2 | 3 | An Online Deep Learning Interface for HPC programs on NVIDIA GPUs 4 | 5 | ## Introduction 6 | TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the [PyTorch](https://pytorch.org) framework. 7 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available 8 | within PyTorch. 
9 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the 10 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a 11 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection, 12 | learning rate schedules, and much more. 13 | 14 | Please refer to the [documentation](https://nvidia.github.io/TorchFort/) for additional information on the library, build instructions, and usage details. 15 | 16 | Please refer to the [examples](examples) to see TorchFort in action. 17 | 18 | Contact us or open a GitHub issue if you are interested in using this library in your own solvers and have questions on usage and/or feature requests. 19 | 20 | ## License 21 | This library is released under a BSD 3-clause license, which can be found in [LICENSE](LICENSE). 22 | -------------------------------------------------------------------------------- /cmake/test_mpi_f2c.f90: -------------------------------------------------------------------------------- 1 | ! SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | ! SPDX-License-Identifier: BSD-3-Clause 3 | ! 4 | ! Redistribution and use in source and binary forms, with or without 5 | ! modification, are permitted provided that the following conditions are met: 6 | ! 7 | ! 1. Redistributions of source code must retain the above copyright notice, this 8 | ! list of conditions and the following disclaimer. 9 | ! 10 | ! 2. Redistributions in binary form must reproduce the above copyright notice, 11 | ! this list of conditions and the following disclaimer in the documentation 12 | ! and/or other materials provided with the distribution. 13 | ! 14 | ! 3. Neither the name of the copyright holder nor the names of its 15 | ! contributors may be used to endorse or promote products derived from 16 | ! this software without specific prior written permission. 17 | ! 18 | ! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | ! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | ! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | ! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | ! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | ! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | ! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | ! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | ! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | ! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | 29 | module test_f2c 30 | use iso_c_binding 31 | implicit none 32 | 33 | type, bind(c) :: MPI_C_Comm 34 | integer(c_int64_t) :: comm 35 | end type MPI_C_Comm 36 | 37 | type, bind(c) :: MPI_F_Comm 38 | integer(c_int) :: comm 39 | end type MPI_F_Comm 40 | 41 | interface 42 | function MPI_Comm_f2c(fcomm) bind(C,name='MPI_Comm_f2c') result(res) 43 | import 44 | type(MPI_F_Comm), value :: fcomm 45 | type(MPI_C_Comm) :: res 46 | end function MPI_Comm_f2c 47 | end interface 48 | end module 49 | 50 | program main 51 | use mpi 52 | use test_f2c 53 | implicit none 54 | 55 | type(MPI_F_Comm) :: fcomm 56 | type(MPI_C_Comm) :: ccomm 57 | 58 | fcomm%comm = MPI_COMM_WORLD 59 | 60 | ccomm = MPI_Comm_f2c(fcomm) 61 | end program 62 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:12.8.1-devel-ubuntu22.04 2 | 3 | # Install System Dependencies 4 | ENV DEBIAN_FRONTEND noninteractive 5 | RUN apt update -y && \ 6 | apt install -y curl unzip wget cmake && \ 7 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \ 8 | apt install -y git vim gfortran doxygen && \ 9 | apt install -y libibverbs-dev ibverbs-utils numactl 10 | 11 | # Install NVHPC SDK 12 | RUN wget https://developer.download.nvidia.com/hpc-sdk/25.3/nvhpc_2025_253_Linux_x86_64_cuda_12.8.tar.gz && \ 13 | tar xpzf nvhpc_2025_253_Linux_x86_64_cuda_12.8.tar.gz && \ 14 | nvhpc_2025_253_Linux_x86_64_cuda_12.8/install --quiet && \ 15 | rm -rf nvhpc_2025_253_Linux_x86_64_cuda_12.8 nvhpc_2025_253_Linux_x86_64_cuda_12.8.tar.gz 16 | 17 | ENV PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/compilers/bin:$PATH 18 | ENV PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/mpi/bin:$PATH 19 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/cuda/lib64:$LD_LIBRARY_PATH 20 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/mpi/lib:$LD_LIBRARY_PATH 21 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/nvshmem/lib:$LD_LIBRARY_PATH 22 | ENV LD_LIBRARY_PATH /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/math_libs/lib64:$LD_LIBRARY_PATH 23 | ENV CUDA_HOME /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/cuda 24 | 25 | RUN echo "source /opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/12.8/hpcx/latest/hpcx-init.sh; hpcx_load" >> /root/.bashrc 26 | 27 | # Install PyTorch 28 | RUN pip3 install torch==2.7.0 29 | 30 | # Install yaml-cpp 31 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \ 32 | cd yaml-cpp && \ 33 | mkdir build && cd build && \ 34 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \ 35 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \ 36 | -DBUILD_SHARED_LIBS=OFF \ 37 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \ 38 | make -j$(nproc) && make install 39 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH} 40 | 41 | # Install HDF5 42 | RUN wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_14_3.tar.gz && \ 43 | tar xzf hdf5-1_14_3.tar.gz && \ 44 | cd hdf5-hdf5-1_14_3 && \ 45 | CC=mpicc FC=mpifort FCFLAGS=-fPIC CFLAGS=-fPIC \ 46 | ./configure --enable-parallel \ 47 | --enable-fortran \ 48 | --prefix=/opt/hdf5 && \ 49 | make -j$(nproc) install && \ 50 | cd .. 
&& \ 51 | rm -rf hdf5-hdf5-1_14_3 hdf5-1_14_3.tar.gz 52 | ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH 53 | 54 | # Install additional Python dependencies 55 | RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy 56 | 57 | # Install TorchFort 58 | ENV FC=nvfortran 59 | ENV HDF5_ROOT=/opt/hdf5 60 | COPY . /torchfort 61 | RUN cd /torchfort && mkdir build && cd build && \ 62 | CUDA_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/25.3/cuda \ 63 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \ 64 | -DCMAKE_CXX_COMPILER=`which g++` \ 65 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \ 66 | -DTORCHFORT_NCCL_ROOT=/opt/nccl/build \ 67 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 68 | -DTORCHFORT_BUILD_TESTS=1 \ 69 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 70 | .. && \ 71 | make -j$(nproc) install && \ 72 | cd / && rm -rf torchfort 73 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH} 74 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH} 75 | 76 | ENTRYPOINT bash 77 | -------------------------------------------------------------------------------- /docker/Dockerfile_gnu: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:12.8.1-devel-ubuntu22.04 2 | 3 | # Install System Dependencies 4 | ENV DEBIAN_FRONTEND noninteractive 5 | RUN apt update -y && \ 6 | apt install -y curl unzip wget cmake && \ 7 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \ 8 | apt install -y git vim gfortran doxygen && \ 9 | apt install -y libibverbs-dev ibverbs-utils numactl 10 | 11 | # Download HPCX and compile with Fortran support 12 | RUN cd /opt && \ 13 | wget http://content.mellanox.com/hpc/hpc-x/v2.22.1rc4/hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \ 14 | tar xjf hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz && \ 15 | mv hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64 hpcx && \ 16 | rm -rf hpcx/ompi && \ 17 | cd hpcx/sources && \ 18 | tar xzf openmpi-gitclone.tar.gz && \ 19 | cd openmpi-gitclone && \ 20 | LD_LIBRARY_PATH=/opt/hpcx/hcoll/lib:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib:$LD_LIBRARY_PATH \ 21 | FC=gfortran CC=gcc CXX=g++ ./configure --prefix=/opt/hpcx/ompi \ 22 | --with-libevent=internal \ 23 | --enable-mpi1-compatibility \ 24 | --without-xpmem \ 25 | --with-cuda=/usr/local/cuda \ 26 | --with-slurm \ 27 | --with-platform=contrib/platform/mellanox/optimized \ 28 | --with-hcoll=/opt/hpcx/hcoll \ 29 | --with-ucx=/opt/hpcx/ucx \ 30 | --with-ucc=/opt/hpcx/ucc && \ 31 | make -j$(nproc) install && \ 32 | cd /opt && rm -rf /opt/hpcx/sources/openmpi-gitclone && rm hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz 33 | 34 | ENV PATH /opt/hpcx/ompi/bin:$PATH 35 | ENV LD_LIBRARY_PATH /opt/hpcx/ompi/lib:$LD_LIBRARY_PATH 36 | 37 | RUN echo "source /opt/hpcx/hpcx-init.sh; hpcx_load" >> /root/.bashrc 38 | 39 | # Install PyTorch 40 | RUN pip3 install torch==2.7.0 41 | 42 | # Install yaml-cpp 43 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \ 44 | cd yaml-cpp && \ 45 | mkdir build && cd build && \ 46 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \ 47 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \ 48 | -DBUILD_SHARED_LIBS=OFF \ 49 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. 
&& \ 50 | make -j$(nproc) && make install 51 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH} 52 | 53 | # Install HDF5 54 | RUN wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_14_3.tar.gz && \ 55 | tar xzf hdf5-1_14_3.tar.gz && \ 56 | cd hdf5-hdf5-1_14_3 && \ 57 | CC=mpicc FC=mpifort \ 58 | ./configure --enable-parallel \ 59 | --enable-fortran \ 60 | --prefix=/opt/hdf5 && \ 61 | make -j$(nproc) install && \ 62 | cd .. && \ 63 | rm -rf hdf5-hdf5-1_14_3 hdf5-1_14_3.tar.gz 64 | ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH 65 | 66 | # Install additional Python dependencies 67 | RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy 68 | 69 | # Install TorchFort 70 | ENV FC=gfortran 71 | ENV HDF5_ROOT=/opt/hdf5 72 | COPY . /torchfort 73 | RUN cd /torchfort && mkdir build && cd build && \ 74 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \ 75 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \ 76 | -DTORCHFORT_NCCL_ROOT=/opt/nccl/build \ 77 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 78 | -DTORCHFORT_BUILD_TESTS=1 \ 79 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 80 | .. && \ 81 | make -j$(nproc) install && \ 82 | cd / && rm -rf torchfort 83 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH} 84 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH} 85 | 86 | ENTRYPOINT bash 87 | -------------------------------------------------------------------------------- /docker/Dockerfile_gnu_cpuonly: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | # Install System Dependencies 4 | ENV DEBIAN_FRONTEND noninteractive 5 | RUN apt update -y && \ 6 | apt install -y build-essential && \ 7 | apt install -y curl unzip wget cmake && \ 8 | apt install -y python3 python-is-python3 python3-pip python3-pybind11 && \ 9 | apt install -y git vim gfortran doxygen && \ 10 | apt install -y libibverbs-dev ibverbs-utils numactl 11 | 12 | # Download OpenMPI and compile with Fortran support 13 | RUN cd /opt && \ 14 | wget https://download.open-mpi.org/release/open-mpi/v5.0/openmpi-5.0.5.tar.gz && \ 15 | tar xzf openmpi-5.0.5.tar.gz && \ 16 | cd openmpi-5.0.5 && \ 17 | FC=gfortran CC=gcc CXX=g++ ./configure --prefix=/opt/openmpi \ 18 | --with-libevent=internal \ 19 | --enable-mpi1-compatibility \ 20 | --without-xpmem \ 21 | --with-slurm && \ 22 | make -j$(nproc) install && \ 23 | cd /opt && rm -rf openmpi-5.0.5 && rm openmpi-5.0.5.tar.gz 24 | 25 | ENV PATH /opt/openmpi/bin:$PATH 26 | ENV LD_LIBRARY_PATH /opt/openmpi/lib:$LD_LIBRARY_PATH 27 | 28 | # Install PyTorch 29 | RUN pip3 install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu 30 | 31 | # Install yaml-cpp 32 | RUN git clone https://github.com/jbeder/yaml-cpp.git --branch 0.8.0 && \ 33 | cd yaml-cpp && \ 34 | mkdir build && cd build && \ 35 | cmake -DCMAKE_INSTALL_PREFIX=/opt/yaml-cpp \ 36 | -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" \ 37 | -DBUILD_SHARED_LIBS=OFF \ 38 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \ 39 | make -j$(nproc) && make install 40 | ENV LD_LIBRARY_PATH /opt/yaml-cpp/lib:${LD_LIBRARY_PATH} 41 | 42 | # Install HDF5 43 | RUN wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_14_3.tar.gz && \ 44 | tar xzf hdf5-1_14_3.tar.gz && \ 45 | cd hdf5-hdf5-1_14_3 && \ 46 | CC=mpicc FC=mpifort \ 47 | ./configure --enable-parallel \ 48 | --enable-fortran \ 49 | --prefix=/opt/hdf5 && \ 50 | make -j$(nproc) install && \ 51 | cd .. 
&& \ 52 | rm -rf hdf5-hdf5-1_14_3 hdf5-1_14_3.tar.gz 53 | ENV LD_LIBRARY_PATH /opt/hdf5/lib:$LD_LIBRARY_PATH 54 | 55 | # Install additional Python dependencies 56 | RUN pip3 install wandb ruamel-yaml h5py matplotlib pygame moviepy 57 | 58 | # Install TorchFort without GPU support 59 | ENV FC=gfortran 60 | ENV HDF5_ROOT=/opt/hdf5 61 | COPY . /torchfort 62 | RUN cd /torchfort && mkdir build && cd build && \ 63 | cmake -DCMAKE_INSTALL_PREFIX=/opt/torchfort \ 64 | -DTORCHFORT_YAML_CPP_ROOT=/opt/yaml-cpp \ 65 | -DTORCHFORT_ENABLE_GPU=0 \ 66 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 67 | -DTORCHFORT_BUILD_TESTS=1 \ 68 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 69 | .. && \ 70 | make -j$(nproc) install && \ 71 | cd / && rm -rf torchfort 72 | ENV LD_LIBRARY_PATH /opt/torchfort/lib:${LD_LIBRARY_PATH} 73 | ENV LD_LIBRARY_PATH /usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH} 74 | 75 | ENTRYPOINT bash 76 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | doxygen Doxyfile 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: 1200px !important; 3 | } 4 | -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% block sidebartitle %} {{ super() }} 3 | 4 | 32 | {% endblock %} 33 | 34 | {% block footer %} {{ super() }} 35 | 36 | 51 | {% endblock %} 52 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. _api-label: 2 | 3 | ############# 4 | TorchFort API 5 | ############# 6 | 7 | The following sections describe the types and functions available in the TorchFort library for C/C++ and Fortran programs and 8 | the also the configuration file structure and available options. 9 | 10 | .. toctree:: 11 | 12 | api/config 13 | api/c_api 14 | api/f_api 15 | 16 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | import sphinx_rtd_theme 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | # import os 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'torchfort' 23 | copyright = '2023, NVIDIA Corporation' 24 | author = 'NVIDIA Corporation' 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = '2023' 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = [ 36 | 'breathe', 37 | 'sphinx.ext.mathjax', 38 | 'sphinx_tabs.tabs', 39 | 'sphinxfortran.fortran_domain', 40 | ] 41 | 42 | # Add any paths that contain templates here, relative to this directory. 43 | templates_path = ['_templates'] 44 | 45 | # List of patterns, relative to source directory, that match files and 46 | # directories to ignore when looking for source files. 47 | # This pattern also affects html_static_path and html_extra_path. 48 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 49 | 50 | # The name of the Pygments (syntax highlighting) style to use. 51 | #pygments_style = 'sphinx' 52 | 53 | highlight_language = 'cpp' 54 | 55 | def setup(app): 56 | app.add_css_file('style.css') 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | 60 | # The theme to use for HTML and HTML Help pages. See the documentation for 61 | # a list of builtin themes. 62 | # 63 | html_theme = 'sphinx_rtd_theme' 64 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 65 | 66 | html_theme_options = { 67 | 'navigation_depth': 6 68 | } 69 | 70 | # Add any paths that contain custom static files (such as style sheets) here, 71 | # relative to this directory. They are copied after the builtin static files, 72 | # so a file named "default.css" will overwrite the builtin "default.css". 73 | html_static_path = ['_static'] 74 | 75 | breathe_projects = { "torchfort": "xml/" } 76 | breathe_default_project = "torchfort" 77 | -------------------------------------------------------------------------------- /docs/extras.rst: -------------------------------------------------------------------------------- 1 | ###### 2 | Extras 3 | ###### 4 | 5 | .. _wandb_support-ref: 6 | 7 | Weights and Biases Support 8 | ========================== 9 | 10 | `Weights and Biases `_ (wandb) is a popular tool for monitoring machine learning training workflows. Wandb supports plotting and comparing loss curves and other relevant deep learning metrics, system utilization (including GPU, CPU and memory utilization) and other advanced logging functionalities like uploading images and/or videos, network weights and data artifacts. 11 | 12 | Since wandb does currently not offer a C++ interface and thus cannot be called from Fortran or C/C++ directly, we've implemented a wandb daemon written in Python instead. 
This daemon runs as a background process and waits for changes in a log file generated by the TorchFort application. In order to enable wandb support for your TorchFort application, the following steps have to be performed. 13 | 14 | Add Custom Metrics Reporting to Application 15 | ------------------------------------------- 16 | 17 | The TorchFort training routines already provide logging of training step, loss values as well as learning rate which are captured by the wandb daemon. Additional custom metrics can be added manually by the user. For this purpose, the user may add calls of ``torchfort_wandb_log`` or ``torchfort_rl_off_policy_wandb_log`` for traditional and reinforcement learning applications respectively (see :ref:`supervised_learning-ref` and :ref:`reinforcement_learning-ref` for details about why we provide different implementations for these two cases). For more information, see :ref:`torchfort_api_c-ref` for C/C++ and :ref:`torchfort_api_f-ref` for Fortran applications. 18 | 19 | Set up Environment 20 | ------------------ 21 | 22 | You need to specify your `wandb api token `_ via the environment variable ``WANDB_API_KEY`` (see the `wandb documentation on available environment variables `_ for details). 23 | Furthermore, the daemon needs to know where the wandb logging data from the TorchFort application will be stored. This can be done by defining the environment variable ``TORCHFORT_LOGDIR``. Lastly, a user-defined wandb logging directory ``WANDB_LOGGING_DIR`` can be created to gather all wandb information as well as the config file in a place specific to the run. 24 | 25 | .. note:: 26 | The logging directory ``TORCHFORT_LOGDIR`` needs to be specified before the daemon and TorchFort application are launched. 27 | 28 | Start Background Watcher Process 29 | -------------------------------- 30 | 31 | Now, the wandb daemon process needs to be started. Assuming TorchFort was installed in ``TORCHFORT_INSTALL_DIR``, we can run 32 | 33 | .. code-block:: bash 34 | 35 | python ${TORCHFORT_INSTALL_DIR}/bin/python/wandb_helper.py \ 36 | --wandb_dir=${WANDB_LOGGING_DIR} \ 37 | --wandb_group= \ 38 | --wandb_project= \ 39 | --wandb_entity= \ 40 | --run_tag= \ 41 | --timeout=2400 & 42 | 43 | The wandb group, project as well as entity name correspond to the wandb project you are logging to. Those correspond to the respective arguments of ``wandb.init`` documented `here `_. Note that the group does not need to exist and will be created during initialization. The run tag can be any alphanumeric string and can be used to identify the specific run on the wandb dashboard. Lastly, the timeout (measured in seconds) determines for how long the background process will wait for changes to appear in ``${TORCHFORT_LOGDIR}/torchfort.log`` before wrapping up the monitoring. 44 | 45 | .. note:: 46 | Do not forget to launch the daemon into the background. 47 | 48 | Start Your TorchFort Application 49 | -------------------------------- 50 | 51 | In the configuration file for your TorchFort application, make sure to enable wandb logging in the ``general`` section by adding or modifying the line ``enable_wandb_hook: 1``. Lastly, start the TorchFort application as usual, e.g.: 52 | 53 | .. code-block:: bash 54 | 55 | ./my_torchfort_app arg1 arg2 arg3 56 | 57 | The daemon process will pick up the log lines from ``${TORCHFORT_LOGDIR}/torchfort.log`` and display the data on the corresponding job dashboard. 58 | 59 | .. 
note:: 60 | The daemon can finalize the monitoring while the TorchFort application is still running if the timeout is not set sufficiently large, especially for long running applications with very sparse logging. 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TorchFort documentation master file, created by 2 | sphinx-quickstart on Wed Jun 1 13:44:41 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | ############################################################################ 7 | TorchFort: An Online Deep Learning Interface for HPC programs on NVIDIA GPUs 8 | ############################################################################ 9 | These pages contain the documentation for TorchFort, an online deep learning interface for HPC programs. 10 | 11 | TorchFort is a DL training and inference interface for HPC programs implemented using LibTorch, the C++ backend used by the `PyTorch `_ framework. 12 | The goal of this library is to help practitioners and domain scientists to seamlessly combine their simulation codes with Deep Learning functionalities available 13 | within PyTorch. 14 | This library can be invoked directly from Fortran or C/C++ programs, enabling transparent sharing of data arrays to and from the DL framework all contained within the 15 | simulation process (i.e., no external glue/data-sharing code required). The library can directly load PyTorch model definitions exported to TorchScript and implements a 16 | configurable training process that users can control via a simple YAML configuration file format. The configuration files enable users to specify optimizer and loss selection, 17 | learning rate schedules, and much more. 18 | 19 | Please contact us or open a GitHub issue if you are interested in using this library 20 | in your own solvers and have questions on usage and/or feature requests. 21 | 22 | 23 | Table of Contents 24 | ================= 25 | .. toctree:: 26 | :maxdepth: 4 27 | 28 | installation 29 | usage 30 | api 31 | extras 32 | 33 | 34 | Indices and tables 35 | ================== 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | * :ref:`search` 40 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | ############ 2 | Installation 3 | ############ 4 | 5 | TorchFort can be installed in multiple ways but we highly recommend building and using a Docker container. 6 | 7 | Docker Installation 8 | ------------------- 9 | 10 | We provide a `Dockerfile `_ which contains all relevant dependencies and builds using the `NVIDIA HPC SDK `_ software libraries and compilers, which is our recommended way to build TorchFort. In order to build TorchFort using Docker, simply clone the repo and call: 11 | 12 | .. code-block:: bash 13 | 14 | docker build -t torchfort:latest -f docker/Dockerfile . 15 | 16 | from the top level directory of the repo. Inside the container, TorchFort will be installed in ``/opt/torchfort``. 17 | 18 | We provide an alternative docker file `Dockerfile_gnu `_ which can be used to build TorchFort using GNU compilers. Additionally, we provide a docker file `Dockerfile_gnu_cpuonly `_ which can be used to build TorchFort using GNU compilers without GPU support enabled. 
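The alternative images are built with the same command, only swapping the Dockerfile; for example (the image tags below are arbitrary):

.. code-block:: bash

   # build with GNU compilers
   docker build -t torchfort:gnu -f docker/Dockerfile_gnu .

   # build with GNU compilers, without GPU support
   docker build -t torchfort:cpu -f docker/Dockerfile_gnu_cpuonly .
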
19 | 20 | CMake Installation 21 | ------------------ 22 | 23 | For a native installation, TorchFort provides a `CMakeLists.txt `_ file. Please make sure that the following required packages are installed on your system before installing TorchFort: 24 | 25 | * Requirements for core functionality and examples: 26 | 27 | - CUDA 12.1 or newer 28 | - ``python`` version 3.6 or higher 29 | - ``pybind11`` 30 | - ``yaml-cpp`` from https://github.com/jbeder/yaml-cpp.git 31 | - MPI 32 | - NVIDIA Collective Communication Library (``NCCL``) 33 | - ``HDF5`` 34 | - the Python modules specified in `requirements.txt `_ 35 | - GNU or `NVHPC `_ compilers. NVHPC compilers are **required** if CUDA Fortran device array support is desired. 36 | 37 | * Additional requirements for building this documentation: 38 | 39 | - Doxygen 40 | - the Python modules specified in `docs/requirements.txt `_ 41 | 42 | For CPU-only builds, CUDA and NCCL are not required. 43 | 44 | 45 | To build TorchFort, clone the repo, then call the following from the root directory: 46 | 47 | .. code-block:: bash 48 | 49 | mkdir build && cd build 50 | cmake -DCMAKE_INSTALL_PREFIX= \ 51 | -DTORCHFORT_YAML_CPP_ROOT= \ 52 | -DTORCHFORT_BUILD_EXAMPLES=1 \ 53 | -DCMAKE_PREFIX_PATH="`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`" \ 54 | .. 55 | make -j install 56 | 57 | See the top level `CMakeLists.txt `_ file for additional CMake configuration options. 58 | 59 | Build Documentation 60 | ------------------- 61 | 62 | The documentation can be built with the corresponding ``Makefile`` in the ``docs`` directory. Make sure that the requirements are installed and call: 63 | 64 | .. code-block:: bash 65 | 66 | cd docs && make html 67 | 68 | The docs will be located in ``docs/_build/html`` and can be viewed locally in your web browser. 69 | 70 | Directory Structure 71 | ------------------- 72 | 73 | Independent of how you decide to install TorchFort, the directory structure will be as follows:: 74 | 75 | 76 | |--- bin 77 | |--- examples 78 | |--- cpp 79 | |--- fortran 80 | |--- python 81 | |--- include 82 | |--- lib 83 | 84 | The ``bin`` folder contains the examples written in C++ or Fortran located in the corresponding subdirectories. The ``python`` subfolder contains the Python wrappers for :ref:`wandb_support-ref`. 85 | 86 | The Fortran module ``torchfort.mod`` as well as the C headers can be found inside the ``include`` folder and the dynamic libraries inside the ``lib`` folder. 
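As a rough sketch, a standalone Fortran application could be compiled against an installed tree along the following lines. The library names ``torchfort`` and ``torchfort_fort`` are inferred from the example CMake files, ``my_app.f90`` is a placeholder, and additional dependencies (e.g. MPI) may be required depending on your application:

.. code-block:: bash

   TORCHFORT_DIR=/opt/torchfort
   # torchfort.mod is found in include/, the shared libraries in lib/
   gfortran -I${TORCHFORT_DIR}/include my_app.f90 \
            -L${TORCHFORT_DIR}/lib -ltorchfort_fort -ltorchfort -o my_app
   # lib/ (and the PyTorch libraries) must be on LD_LIBRARY_PATH at runtime
   export LD_LIBRARY_PATH=${TORCHFORT_DIR}/lib:${LD_LIBRARY_PATH}
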
87 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | breathe 2 | sphinx-rtd-theme==1.1.1 3 | sphinx-tabs==3.4.0 4 | sphinx-fortran 5 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(cart_pole_example_targets 2 | train_cart_pole 3 | ) 4 | 5 | add_library(environments STATIC) 6 | target_sources(environments 7 | PRIVATE 8 | env.cpp 9 | ) 10 | set_property(TARGET environments PROPERTY POSITION_INDEPENDENT_CODE ON) 11 | 12 | add_executable(train_cart_pole) 13 | target_sources(train_cart_pole 14 | PRIVATE 15 | train.cpp 16 | ) 17 | 18 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 19 | find_package(pybind11 CONFIG REQUIRED) 20 | pybind11_add_module(PyEnvironments py_env.cpp) 21 | target_link_libraries(PyEnvironments PRIVATE environments) 22 | 23 | foreach(tgt ${cart_pole_example_targets}) 24 | target_include_directories(${tgt} 25 | PRIVATE 26 | ${YAML_CPP_INCLUDE_DIR} 27 | ${MPI_CXX_INCLUDE_DIRS} 28 | ${CUDAToolkit_INCLUDE_DIRS} 29 | ${Python_INCLUDE_DIRS} 30 | ${CMAKE_BINARY_DIR}/include 31 | ) 32 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 33 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 34 | target_link_libraries(${tgt} PRIVATE ${Python_LIBRARIES}) 35 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX) 36 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 37 | target_link_libraries(${tgt} PRIVATE environments) 38 | target_compile_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 39 | target_link_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 40 | if (TORCHFORT_ENABLE_GPU) 41 | target_include_directories(${tgt} 42 | PRIVATE 43 | ${CUDAToolkit_INCLUDE_DIRS} 44 | ) 45 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 46 | endif() 47 | endforeach() 48 | 49 | # installation 50 | # executable 51 | install( 52 | TARGETS ${cart_pole_example_targets} 53 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole 54 | ) 55 | 56 | # python env 57 | install( 58 | TARGETS PyEnvironments 59 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole/python 60 | ) 61 | 62 | # config files 63 | install( 64 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml ${CMAKE_CURRENT_SOURCE_DIR}/config_sim.yaml 65 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole 66 | ) 67 | 68 | # python files 69 | install( 70 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/initialize_models.py ${CMAKE_CURRENT_SOURCE_DIR}/python/visualize.py 71 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/cpp/cart_pole/python 72 | ) 73 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 1 4 | verbose: 1 5 | 6 | algorithm: 7 | type: td3 8 | parameters: 9 | batch_size: 512 10 | num_critics: 2 11 | policy_lag: 2 12 | nstep: 1 13 | nstep_reward_reduction: sum_no_skip 14 | gamma: 0.99 15 | rho: 0.99 16 | 17 | actor: 18 | type: space_noise 19 | parameters: 20 | a_low: -1.0 21 | a_high: 1.0 22 | clip: 0.3 23 | sigma_train: 0.1 24 | sigma_explore: 0.2 25 | adaptive: 0 26 | 27 | replay_buffer: 28 | type: uniform 29 | 
parameters: 30 | max_size: 50000 31 | min_size: 1024 32 | 33 | policy_model: 34 | type: torchscript 35 | parameters: 36 | filename: policy.pt 37 | 38 | critic_model: 39 | type: torchscript 40 | parameters: 41 | filename: value.pt 42 | 43 | optimizer: 44 | type: adam 45 | parameters: 46 | learning_rate: 0.001 47 | beta1: 0.9 48 | beta2: 0.999 49 | weight_decay: 0 50 | eps: 1e-6 51 | amsgrad: 0 52 | 53 | policy_lr_scheduler: 54 | type: cosine_annealing 55 | parameters: 56 | T_max: 500000000 57 | 58 | critic_lr_scheduler: 59 | type: cosine_annealing 60 | parameters: 61 | T_max: 500000000 62 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/config_sim.yaml: -------------------------------------------------------------------------------- 1 | num_episodes: 50000 2 | max_steps_per_episode: 2500 3 | eval_frequency: 25 4 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/env.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | */ 30 | 31 | #include 32 | #include 33 | #include 34 | // #include 35 | 36 | enum IntegratorType { EXPLICIT_EULER, SEMI_IMPLICIT_EULER }; 37 | using StateVector = std::array; 38 | 39 | class CartPoleEnv { 40 | 41 | public: 42 | CartPoleEnv(); 43 | CartPoleEnv(const CartPoleEnv&) = delete; 44 | CartPoleEnv& operator=(const CartPoleEnv&) = delete; 45 | 46 | StateVector reset(); 47 | std::pair getStateBounds(); 48 | 49 | std::tuple step(float action); 50 | 51 | private: 52 | // sim parameters 53 | bool terminated_; 54 | float gravity_; 55 | float masscart_; 56 | float masspole_; 57 | float total_mass_; 58 | float length_; 59 | float polemass_length_; 60 | float force_mag_; 61 | float dt_; 62 | float penalty_; 63 | IntegratorType kinematics_integrator_; 64 | int steps_beyond_terminated_; 65 | 66 | // threshold parameters 67 | float theta_threshold_radians_; 68 | float x_threshold_; 69 | 70 | // random stuff 71 | std::mt19937_64 rng_; 72 | std::uniform_real_distribution uniform_dist_; 73 | 74 | // state vector 75 | StateVector state_; 76 | }; 77 | 78 | // pybind11 stuff 79 | // namespace py = pybind11; 80 | // PYBIND11_MODULE(environments, m) { 81 | // py::class_(m, "CartPoleEnv", py::dynamic_attr()) 82 | // .def(py::init<>()) 83 | // .def("step", &CartPoleEnv::step, py::arg("action")) 84 | // .def("reset", &CartPoleEnv::reset); 85 | //} 86 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/media/cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/cpp/cart_pole/media/cartpole.gif -------------------------------------------------------------------------------- /examples/cpp/cart_pole/media/cartpole.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/cpp/cart_pole/media/cartpole.mp4 -------------------------------------------------------------------------------- /examples/cpp/cart_pole/py_env.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | 34 | #include "env.h" 35 | 36 | // pybind11 stuff 37 | namespace py = pybind11; 38 | PYBIND11_MODULE(PyEnvironments, m) { 39 | py::class_(m, "CartPoleEnv", py::dynamic_attr()) 40 | .def(py::init<>()) 41 | .def("step", &CartPoleEnv::step, py::arg("action")) 42 | .def("reset", &CartPoleEnv::reset); 43 | } 44 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/python/initialize_models.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | 29 | 30 | import argparse as ap 31 | import math 32 | import torch 33 | from functools import partial 34 | from torch import nn 35 | import torch.nn.functional as F 36 | 37 | from models import weight_init, PolicyFunc, ValueFunc 38 | 39 | def main(args): 40 | 41 | # set seed 42 | torch.manual_seed(666) 43 | 44 | # CUDA check 45 | if torch.cuda.is_available(): 46 | torch.cuda.manual_seed(666) 47 | device = torch.device("cuda:0") 48 | else: 49 | device = torch.device("cpu") 50 | 51 | # parameters 52 | batch_size = 64 53 | 54 | # policy model 55 | pmodel = PolicyFunc(hidden_features=args.num_hidden_features).to(device) 56 | weight_init(pmodel) 57 | jpmodel = torch.jit.script(pmodel) 58 | inp = torch.ones((batch_size, 4), dtype=torch.float32, device=device) 59 | out = jpmodel(inp) 60 | print("Policy model:", pmodel) 61 | print("Policy model output shape:", out.shape) 62 | torch.jit.save(jpmodel, "policy.pt") 63 | 64 | # value model 65 | qmodel = ValueFunc(hidden_features=args.num_hidden_features).to(device) 66 | weight_init(qmodel) 67 | jqmodel = torch.jit.script(qmodel) 68 | inp_a = torch.ones((batch_size, 1), dtype=torch.float32, device=device) 69 | out = jqmodel(inp, inp_a) 70 | print("Value model:", qmodel) 71 | print("Value model output shape:", out.shape) 72 | torch.jit.save(jqmodel, "value.pt") 73 | 74 | if __name__ == "__main__": 75 | parser = ap.ArgumentParser() 76 | parser.add_argument("--num_hidden_features", type=int, default=128, help="Number of hidden features") 77 | args = parser.parse_args() 78 | 79 | main(args) 80 | -------------------------------------------------------------------------------- /examples/cpp/cart_pole/python/models.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | 29 | 30 | import math 31 | import torch 32 | from torch import nn 33 | import torch.nn.functional as F 34 | 35 | def weight_init(model, scale=0.02): 36 | with torch.no_grad(): 37 | for m in model.modules(): 38 | if isinstance(m, nn.Linear): 39 | sqrtk = math.sqrt(1./float(m.weight.shape[1])) 40 | nn.init.uniform_(m.weight, a=-sqrtk, b=sqrtk) 41 | if m.bias is not None: 42 | m.bias.data.zero_() 43 | 44 | class PolicyFunc(nn.Module): 45 | def __init__(self, hidden_features=128): 46 | super(PolicyFunc, self).__init__() 47 | 48 | layers = [nn.Linear(in_features = 4, 49 | out_features = hidden_features, 50 | bias=True), 51 | nn.ReLU(), 52 | nn.Linear(in_features = hidden_features, 53 | out_features = hidden_features // 2, 54 | bias=True), 55 | nn.ReLU(), 56 | nn.Linear(in_features = hidden_features // 2, 57 | out_features = 1, 58 | bias=True), 59 | nn.Tanh()] 60 | 61 | self.fwd = nn.Sequential(*layers) 62 | 63 | def forward(self, x: torch.Tensor) -> torch.Tensor: 64 | return self.fwd(x) 65 | 66 | class ValueFunc(nn.Module): 67 | def __init__(self, hidden_features=128): 68 | super(ValueFunc, self).__init__() 69 | 70 | layers = [nn.Linear(in_features = 5, 71 | out_features = hidden_features, 72 | bias=True), 73 | nn.ReLU(), 74 | nn.Linear(in_features = hidden_features, 75 | out_features = hidden_features // 2, 76 | bias=True), 77 | nn.ReLU(), 78 | nn.Linear(in_features = hidden_features // 2, 79 | out_features = 1, 80 | bias=True)] 81 | 82 | self.fwd = nn.Sequential(*layers) 83 | 84 | def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor: 85 | x = torch.cat([s, a], dim=1) 86 | return self.fwd(x) 87 | -------------------------------------------------------------------------------- /examples/fortran/graph/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(fortran_example_targets 2 | train_graph 3 | ) 4 | 5 | add_executable(train_graph) 6 | target_sources(train_graph 7 | PRIVATE 8 | train.f90 9 | ) 10 | set_target_properties(train_graph 11 | PROPERTIES OUTPUT_NAME train) 12 | 13 | foreach(tgt ${fortran_example_targets}) 14 | target_include_directories(${tgt} 15 | PRIVATE 16 | ${CMAKE_BINARY_DIR}/include 17 | ${MPI_Fortran_INCLUDE_DIRS} 18 | ${HDF5_Fortran_INCLUDE_DIRS} 19 | ) 20 | target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran) 21 | target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort") 22 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 23 | if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC") 24 | target_compile_options(${tgt} PRIVATE $<$:-cpp -acc -gpu=${CUF_GPU_ARG}>) 25 | target_link_options(${tgt} PRIVATE $<$: -acc -gpu=${CUF_GPU_ARG}>) 26 | elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU") 27 | target_compile_options(${tgt} PRIVATE $<$:-cpp -fbackslash>) 28 | endif() 29 | endforeach() 30 | 31 | install( 32 | TARGETS ${fortran_example_targets} 33 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/graph 34 | ) 35 | 36 | install( 37 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/config.yaml 38 | ${CMAKE_CURRENT_SOURCE_DIR}/generate_model.py 39 | ${CMAKE_CURRENT_SOURCE_DIR}/generate_loss.py 40 | ${CMAKE_CURRENT_SOURCE_DIR}/nodes.txt 41 | ${CMAKE_CURRENT_SOURCE_DIR}/connectivity.txt 42 | ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py 43 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/graph) 44 | -------------------------------------------------------------------------------- /examples/fortran/graph/config.yaml: -------------------------------------------------------------------------------- 1 | 
general: 2 | report_frequency: 100 3 | 4 | model: 5 | type: torchscript 6 | parameters: 7 | filename: "model_torchscript.pt" 8 | 9 | loss: 10 | type: torchscript 11 | parameters: 12 | filename: "loss_torchscript.pt" 13 | 14 | optimizer: 15 | type: adam 16 | parameters: 17 | learning_rate: 1e-3 18 | beta1: 0.9 19 | beta2: 0.999 20 | weight_decay: 0 21 | eps: 1e-8 22 | amsgrad: 0 23 | 24 | lr_scheduler: 25 | type: cosine_annealing 26 | parameters: 27 | T_max: 100000 28 | -------------------------------------------------------------------------------- /examples/fortran/graph/generate_loss.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | import torch 30 | 31 | class CustomLoss(torch.nn.Module): 32 | def __init__(self): 33 | super(CustomLoss, self).__init__() 34 | 35 | def forward(self, prediction, label, node_types): 36 | 37 | # Compute MSE over all nodes 38 | err = (label - prediction)**2 39 | 40 | # Zero out error for boundary nodes 41 | mask = node_types != 0 42 | err *= mask.unsqueeze(-1) 43 | 44 | # Compute mean over non-boundary nodes 45 | mse = torch.sum(err) / (torch.sum(mask) * err.shape[1]) 46 | 47 | return mse 48 | 49 | def main(): 50 | # Create loss module 51 | loss = CustomLoss() 52 | print("loss module:", loss) 53 | 54 | try: 55 | # Move model to GPU, JIT, and save 56 | loss.to("cuda") 57 | except: 58 | print("PyTorch does not have CUDA support. 
Saving model on CPU.") 59 | loss_jit = torch.jit.script(loss) 60 | loss_jit.save("loss_torchscript.pt") 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /examples/fortran/graph/generate_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
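# This script builds the TorchScript model used by the Fortran graph example: a small
# encode-process-decode graph network with MLP node/edge encoders, a stack of residual
# message-passing blocks, and an MLP decoder. The scripted module is saved as
# "model_torchscript.pt", the filename referenced by the model section of config.yaml above.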
28 | 29 | import torch 30 | 31 | class MessagePassing(torch.nn.Module): 32 | def __init__(self, hidden_dim): 33 | super(MessagePassing, self).__init__() 34 | 35 | self.mlp_edge = torch.nn.Sequential(torch.nn.Linear(3*hidden_dim, hidden_dim), 36 | torch.nn.ReLU(), 37 | torch.nn.Linear(hidden_dim, hidden_dim), 38 | torch.nn.LayerNorm(hidden_dim)) 39 | self.mlp_node = torch.nn.Sequential(torch.nn.Linear(2*hidden_dim, hidden_dim), 40 | torch.nn.ReLU(), 41 | torch.nn.Linear(hidden_dim, hidden_dim), 42 | torch.nn.LayerNorm(hidden_dim)) 43 | 44 | def forward(self, edge_idx, node_feats, edge_feats): 45 | senders = edge_idx[:,0] 46 | receivers = edge_idx[:,1] 47 | 48 | edge_update = torch.cat([node_feats[senders], node_feats[receivers], edge_feats], dim=-1) 49 | edge_update = self.mlp_edge(edge_update) 50 | 51 | accumulate_edges = torch.zeros([node_feats.shape[0], edge_feats.shape[1]], dtype=edge_feats.dtype, device=edge_feats.device) 52 | receivers = receivers.unsqueeze(-1).expand(-1, edge_feats.shape[1]) 53 | accumulate_edges = torch.scatter_add(accumulate_edges, src=edge_feats, index=receivers, dim=0) 54 | node_update = torch.cat([node_feats, accumulate_edges], dim=-1) 55 | node_update = self.mlp_node(node_update) 56 | 57 | edge_feats = edge_feats + edge_update 58 | node_feats = node_feats + node_update 59 | 60 | return node_feats, edge_feats 61 | 62 | 63 | class Net(torch.nn.Module): 64 | def __init__(self, in_node_features, in_edge_features, hidden_dim, n_message_passing_steps): 65 | super(Net, self).__init__() 66 | self.encoder_node = torch.nn.Sequential(torch.nn.Linear(in_node_features, hidden_dim), 67 | torch.nn.ReLU(), 68 | torch.nn.Linear(hidden_dim, hidden_dim), 69 | torch.nn.LayerNorm(hidden_dim)) 70 | self.encoder_edge = torch.nn.Sequential(torch.nn.Linear(in_edge_features, hidden_dim), 71 | torch.nn.ReLU(), 72 | torch.nn.Linear(hidden_dim, hidden_dim), 73 | torch.nn.LayerNorm(hidden_dim)) 74 | 75 | self.mp_layers = torch.nn.ModuleList() 76 | for _ in range(n_message_passing_steps): 77 | self.mp_layers.append(MessagePassing(hidden_dim)) 78 | 79 | self.decoder = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), 80 | torch.nn.ReLU(), 81 | torch.nn.Linear(hidden_dim, in_node_features)) 82 | 83 | def forward(self, edge_idx, node_feats, edge_feats): 84 | # Encode node and edge features 85 | node_feats = self.encoder_node(node_feats) 86 | edge_feats = self.encoder_edge(edge_feats) 87 | 88 | # Message passing 89 | for mp in self.mp_layers: 90 | node_feats, edge_feats = mp(edge_idx, node_feats, edge_feats) 91 | 92 | # Decode node featues 93 | node_feats = self.decoder(node_feats) 94 | 95 | return node_feats 96 | 97 | 98 | def main(): 99 | # Create model 100 | model = Net(1, 3, 128, 8) 101 | print("graph model:", model) 102 | 103 | try: 104 | # Move model to GPU, JIT, and save 105 | model.to("cuda") 106 | except: 107 | print("PyTorch does not have CUDA support. 
Saving model on CPU.") 108 | model_jit = torch.jit.script(model) 109 | model_jit.save("model_torchscript.pt") 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /examples/fortran/graph/media/validation_results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/fortran/graph/media/validation_results.gif -------------------------------------------------------------------------------- /examples/fortran/graph/visualize.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
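# This script visualizes the graph example results: it loads the mesh from nodes.txt and
# connectivity.txt, reads the reference_*.txt and prediction_*.txt snapshots from
# --input_path, and animates ground truth vs. prediction as filled contour plots on the
# triangulated mesh, writing validation_results.gif to --output_path.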
28 | 29 | import argparse as ap 30 | import glob 31 | import matplotlib.pyplot as plt 32 | import matplotlib.tri as tri 33 | from matplotlib.animation import FuncAnimation, PillowWriter 34 | import numpy as np 35 | import os 36 | import time 37 | 38 | def main(args): 39 | 40 | global reffiles, predfiles, artists, triangulation 41 | print(f"processing files in {args.input_path}...") 42 | 43 | reffiles = sorted(glob.glob(os.path.join(args.input_path, "reference_*.txt"))) 44 | predfiles = sorted(glob.glob(os.path.join(args.input_path, "prediction_*.txt"))) 45 | 46 | # Read mesh data 47 | nodes = np.loadtxt("nodes.txt", skiprows=1) 48 | triangles = np.loadtxt("connectivity.txt", skiprows=1) 49 | triangulation = tri.Triangulation(nodes[:,0], nodes[:,1], triangles) 50 | 51 | artists = [] 52 | 53 | fig, ((ax1), (ax2)) = plt.subplots(2, 1) 54 | ax1.set_title("Ground Truth") 55 | ax1.set_xlabel(r"$x$") 56 | ax1.set_ylabel(r"$y$") 57 | ax2.set_title("Prediction") 58 | ax2.set_xlabel(r"$x$") 59 | ax2.set_ylabel(r"$y$") 60 | 61 | c = ax1.tricontourf(triangulation, np.loadtxt(reffiles[0]), levels=np.linspace(-0.1, 1.0, 15)) 62 | artists += c.collections 63 | c = ax1.triplot(triangulation, linewidth=0.3, color='black') 64 | artists.append(c) 65 | c = ax2.tricontourf(triangulation, np.loadtxt(predfiles[0]), levels=np.linspace(-0.1, 1.0, 15)) 66 | artists += c.collections 67 | c = ax2.triplot(triangulation, linewidth=0.3, color='black') 68 | artists.append(c) 69 | 70 | fig.tight_layout() 71 | 72 | def animate(i): 73 | global reffiles, predfiles, artists, triangulation 74 | for c in artists: 75 | try: 76 | c.remove() 77 | except: 78 | pass 79 | artists.clear() 80 | 81 | c = ax1.tricontourf(triangulation, np.loadtxt(reffiles[i]), levels=np.linspace(-0.1, 1.0, 15)) 82 | artists += c.collections 83 | c = ax1.triplot(triangulation, linewidth=0.3, color='black') 84 | artists.append(c) 85 | c = ax2.tricontourf(triangulation, np.loadtxt(predfiles[i]), levels=np.linspace(-0.1, 1.0, 15)) 86 | artists += c.collections 87 | c = ax2.triplot(triangulation, linewidth=0.3, color='black') 88 | artists.append(c) 89 | 90 | 91 | 92 | ani = FuncAnimation(fig, animate, frames=len(reffiles), repeat=False, interval=1) 93 | 94 | os.makedirs(args.output_path, exist_ok=True) 95 | 96 | def log(i, n): 97 | print(f"processed {i+1} of {n} frames..." 
) 98 | ani.save(os.path.join(args.output_path, "validation_results.gif"), writer=PillowWriter(fps=5), progress_callback=lambda i, n: log(i,n)) 99 | print(f"video written to {os.path.join(args.output_path, 'validation_results.gif')}...") 100 | 101 | if __name__ == "__main__": 102 | parser = ap.ArgumentParser() 103 | parser.add_argument("--input_path", type=str, help="Directory containing result text files", required=True) 104 | parser.add_argument("--output_path", type=str, help="Directory to store the generated videos", required=True) 105 | args = parser.parse_args() 106 | 107 | main(args) 108 | 109 | -------------------------------------------------------------------------------- /examples/fortran/simulation/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(HDF5 COMPONENTS Fortran REQUIRED) 2 | 3 | set(fortran_example_targets 4 | train 5 | train_distributed 6 | ) 7 | 8 | add_executable(train) 9 | target_sources(train 10 | PRIVATE 11 | train.f90 12 | simulation.f90 13 | ) 14 | set_target_properties(train PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mod/0 ) 15 | 16 | add_executable(train_distributed) 17 | target_sources(train_distributed 18 | PRIVATE 19 | train_distributed.f90 20 | simulation.f90 21 | ) 22 | set_target_properties(train_distributed PROPERTIES Fortran_MODULE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mod/1 ) 23 | 24 | foreach(tgt ${fortran_example_targets}) 25 | target_include_directories(${tgt} 26 | PRIVATE 27 | ${CMAKE_BINARY_DIR}/include 28 | ${MPI_Fortran_INCLUDE_DIRS} 29 | ${HDF5_Fortran_INCLUDE_DIRS} 30 | ) 31 | target_link_libraries(${tgt} PRIVATE MPI::MPI_Fortran) 32 | target_link_libraries(${tgt} PRIVATE hdf5::hdf5_fortran) 33 | target_link_libraries(${tgt} PRIVATE "${PROJECT_NAME}_fort") 34 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 35 | if (CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC") 36 | target_compile_options(${tgt} PRIVATE $<$:-cpp -acc -gpu=${CUF_GPU_ARG}>) 37 | target_link_options(${tgt} PRIVATE $<$: -acc -gpu=${CUF_GPU_ARG}>) 38 | elseif (CMAKE_Fortran_COMPILER_ID STREQUAL "GNU") 39 | target_compile_options(${tgt} PRIVATE $<$:-cpp -fbackslash>) 40 | endif() 41 | endforeach() 42 | 43 | install( 44 | TARGETS ${fortran_example_targets} 45 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/simulation 46 | ) 47 | 48 | 49 | install( 50 | FILES ${CMAKE_CURRENT_SOURCE_DIR}/config_mlp_native.yaml 51 | ${CMAKE_CURRENT_SOURCE_DIR}/config_fcn_torchscript.yaml 52 | ${CMAKE_CURRENT_SOURCE_DIR}/generate_fcn_model.py 53 | ${CMAKE_CURRENT_SOURCE_DIR}/visualize.py 54 | DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/examples/fortran/simulation) 55 | -------------------------------------------------------------------------------- /examples/fortran/simulation/config_fcn_torchscript.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | enable_wandb_hook: 1 3 | report_frequency: 100 4 | 5 | model: 6 | type: torchscript 7 | parameters: 8 | filename: "fcn_torchscript.pt" 9 | 10 | loss: 11 | type: MSE 12 | 13 | optimizer: 14 | type: adam 15 | parameters: 16 | learning_rate: 1e-3 17 | beta1: 0.9 18 | beta2: 0.999 19 | weight_decay: 0 20 | eps: 1e-8 21 | amsgrad: 0 22 | 23 | lr_scheduler: 24 | type: cosine_annealing 25 | parameters: 26 | T_max: 100000 27 | -------------------------------------------------------------------------------- /examples/fortran/simulation/config_mlp_native.yaml: 
-------------------------------------------------------------------------------- 1 | general: 2 | enable_wandb_hook: 1 3 | report_frequency: 100 4 | 5 | model: 6 | type: mlp 7 | parameters: 8 | dropout: 0.0 9 | layer_sizes: [1024, 1024] 10 | 11 | loss: 12 | type: MSE 13 | 14 | optimizer: 15 | type: adam 16 | parameters: 17 | learning_rate: 1e-3 18 | beta1: 0.9 19 | beta2: 0.999 20 | weight_decay: 0 21 | eps: 1e-8 22 | amsgrad: 0 23 | 24 | lr_scheduler: 25 | type: cosine_annealing 26 | parameters: 27 | T_max: 100000 28 | -------------------------------------------------------------------------------- /examples/fortran/simulation/generate_fcn_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | import torch 30 | 31 | class Net(torch.nn.Module): 32 | def __init__(self): 33 | super(Net, self).__init__() 34 | self.conv1 = torch.nn.Conv2d(1, 1, 3, padding=1, padding_mode="circular") 35 | 36 | def forward(self, x): 37 | return self.conv1(x) 38 | 39 | 40 | def main(): 41 | # Create model 42 | model = Net() 43 | print("FCN model:", model) 44 | 45 | try: 46 | # Move model to GPU, JIT, and save 47 | model.to("cuda") 48 | except: 49 | print("PyTorch does not have CUDA support. 
Saving model on CPU.") 50 | model_jit = torch.jit.script(model) 51 | model_jit.save("fcn_torchscript.pt") 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /examples/fortran/simulation/media/validation_results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TorchFort/3ac715ffabe6f0e3fe7e6ee8276e0998d1513310/examples/fortran/simulation/media/validation_results.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # basic packages 2 | ruamel-yaml 3 | 4 | # pytorch and some dependencies 5 | torch==2.7.0 6 | torchvision==0.22.0 7 | torchaudio==2.7.0 8 | 9 | # training monitoring 10 | wandb 11 | 12 | # RL example visualization related 13 | pygame 14 | moviepy 15 | 16 | # Supervised learning example visualization related 17 | matplotlib 18 | h5py 19 | -------------------------------------------------------------------------------- /src/csrc/include/internal/base_loss.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | 35 | #include 36 | 37 | #include "internal/base_loss.h" 38 | #include "internal/param_map.h" 39 | 40 | namespace torchfort { 41 | 42 | struct BaseLoss : torch::nn::Module { 43 | virtual torch::Tensor forward(const std::vector& inputs, 44 | const std::vector& labels, 45 | const std::vector& extra_args) = 0; 46 | virtual void setup(const ParamMap& params) = 0; 47 | }; 48 | 49 | } // namespace torchfort 50 | -------------------------------------------------------------------------------- /src/csrc/include/internal/base_lr_scheduler.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | #include 33 | #include 34 | #include 35 | 36 | #include 37 | 38 | #include "internal/exceptions.h" 39 | 40 | namespace torchfort { 41 | 42 | class BaseLRScheduler : public torch::optim::LRScheduler { 43 | public: 44 | BaseLRScheduler(torch::optim::Optimizer& optimizer) : LRScheduler(optimizer) {} 45 | 46 | // Define generic save/load functionalities. Specialize in derived schedulers if 47 | // needed. 
48 | void save(const std::string& fname) { 49 | torch::serialize::OutputArchive archive; 50 | archive.write("step_count", torch::IValue((int64_t)step_count_)); 51 | archive.write("lrs", torch::IValue(get_current_lrs())); 52 | archive.save_to(fname); 53 | } 54 | void load(const std::string& fname, torch::optim::Optimizer& optimizer) { 55 | torch::serialize::InputArchive archive; 56 | archive.load_from(fname); 57 | 58 | torch::IValue ivalue; 59 | if (!archive.try_read("step_count", ivalue)) { 60 | THROW_INVALID_USAGE(fname + " is missing required data."); 61 | } 62 | int64_t step_count = ivalue.to(); 63 | step_count_ = step_count; 64 | 65 | if (!archive.try_read("lrs", ivalue)) { 66 | THROW_INVALID_USAGE(fname + " is missing required data."); 67 | } 68 | auto lrs = ivalue.to>(); 69 | // Can't use this method to set the LRs due to it being private in the base LR class. 70 | // set_optimizer_lrs(lrs); 71 | for (const auto i : c10::irange(optimizer.param_groups().size())) { 72 | optimizer.param_groups()[i].options().set_lr(lrs[i]); 73 | } 74 | } 75 | }; 76 | 77 | } // namespace torchfort 78 | -------------------------------------------------------------------------------- /src/csrc/include/internal/base_model.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | 35 | #include "internal/param_map.h" 36 | 37 | namespace torchfort { 38 | 39 | struct BaseModel : torch::nn::Module { 40 | virtual std::vector forward(const std::vector& inputs) = 0; 41 | virtual void setup(const ParamMap& params) = 0; 42 | }; 43 | 44 | } // namespace torchfort 45 | -------------------------------------------------------------------------------- /src/csrc/include/internal/distributed.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #ifdef ENABLE_GPU 34 | #include 35 | #include 36 | #endif 37 | #include 38 | 39 | #include 40 | 41 | namespace torchfort { 42 | 43 | struct Comm { 44 | void initialize(bool initialize_nccl = false); 45 | void finalize(); 46 | void allreduce(torch::Tensor& tensor, bool average = false) const; 47 | void allreduce(std::vector& tensors, bool average = false) const; 48 | void allreduce(double& val, bool average = false) const; 49 | void allreduce(float& val, bool average = false) const; 50 | void broadcast(torch::Tensor& tensor, int root = 0) const; 51 | 52 | int rank; 53 | int size; 54 | MPI_Comm mpi_comm; 55 | #ifdef ENABLE_GPU 56 | ncclComm_t nccl_comm = nullptr; 57 | cudaStream_t stream = nullptr; 58 | cudaEvent_t event = nullptr; 59 | #endif 60 | bool initialized = false; 61 | 62 | Comm(MPI_Comm mpi_comm) : mpi_comm(mpi_comm){}; 63 | }; 64 | 65 | } // namespace torchfort 66 | -------------------------------------------------------------------------------- /src/csrc/include/internal/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | #include 35 | 36 | #include "internal/model_pack.h" 37 | 38 | namespace torchfort { 39 | 40 | namespace logging { 41 | 42 | enum level { info, warn, error, wandb }; 43 | 44 | std::string log_level_prefix(level log_level); 45 | void write(const std::filesystem::path& filename, const std::string& message, level log_level); 46 | void print(const std::string& message, level log_level); 47 | 48 | } // namespace logging 49 | 50 | // Declaration of external global variables 51 | extern std::unordered_map models; 52 | 53 | // specialized logging routines 54 | template 55 | void wandb_log(std::shared_ptr state, std::shared_ptr comm, const char* name, const char* metric_name, 56 | int64_t step, T value) { 57 | if (state->enable_wandb_hook) { 58 | std::stringstream os; 59 | os << "model: " << name << ", "; 60 | os << "step: " << step << ", "; 61 | os << metric_name << ": " << value; 62 | if (!comm || (comm && comm->rank == 0)) { 63 | torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb); 64 | } 65 | } 66 | } 67 | 68 | template void wandb_log(const char* name, const char* metric_name, int64_t step, T value) { 69 | auto state = models[name].state.get(); 70 | if (state->enable_wandb_hook) { 71 | std::stringstream os; 72 | os << "model: " << name << ", "; 73 | os << "step: " << step << ", "; 74 | os << metric_name << ": " << value; 75 | if (!models[name].comm || (models[name].comm && models[name].comm->rank == 0)) { 76 | torchfort::logging::write(state->report_file, os.str(), torchfort::logging::wandb); 77 | } 78 | } 79 | } 80 | 81 | } // namespace torchfort 82 | -------------------------------------------------------------------------------- /src/csrc/include/internal/losses.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | #include 35 | 36 | #include 37 | #include 38 | #include 39 | 40 | #include "internal/base_loss.h" 41 | #include "internal/defines.h" 42 | #include "internal/param_map.h" 43 | 44 | namespace torchfort { 45 | 46 | struct L1Loss : BaseLoss { 47 | void setup(const ParamMap& params) override; 48 | 49 | torch::Tensor forward(const std::vector& inputs, 50 | const std::vector& labels, 51 | const std::vector& extra_args) override; 52 | 53 | torch::nn::L1Loss module; 54 | }; 55 | 56 | struct MSELoss : BaseLoss { 57 | void setup(const ParamMap& params) override; 58 | 59 | torch::Tensor forward(const std::vector& inputs, 60 | const std::vector& labels, 61 | const std::vector& extra_args) override; 62 | 63 | torch::nn::MSELoss module; 64 | }; 65 | 66 | struct TorchscriptLoss : BaseLoss { 67 | void setup(const ParamMap& params) override; 68 | 69 | torch::Tensor forward(const std::vector& inputs, 70 | const std::vector& labels, 71 | const std::vector& extra_args) override; 72 | 73 | std::shared_ptr module_jit; 74 | }; 75 | 76 | // Creating loss_registry. 77 | BEGIN_LOSS_REGISTRY 78 | 79 | // Add entries for new losses in this section. First argument to REGISTER_LOSS is 80 | // a string key and the second argument is the class name. 81 | REGISTER_LOSS(L1, L1Loss) 82 | REGISTER_LOSS(MSE, MSELoss) 83 | REGISTER_LOSS(torchscript, TorchscriptLoss) 84 | 85 | END_LOSS_REGISTRY 86 | 87 | } // namespace torchfort 88 | -------------------------------------------------------------------------------- /src/csrc/include/internal/lr_schedulers.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | #include 33 | 34 | #include 35 | 36 | #include "internal/base_lr_scheduler.h" 37 | 38 | namespace torchfort { 39 | 40 | class CosineAnnealingLR : public BaseLRScheduler { 41 | public: 42 | CosineAnnealingLR(torch::optim::Optimizer& optimizer, const unsigned T_max, const double eta_min = 0.0); 43 | 44 | private: 45 | std::vector get_lrs() override; 46 | double update_lr(const double& last_lr, const double& base_lr); 47 | 48 | const unsigned T_max_; 49 | const double eta_min_; 50 | std::vector base_lrs_; 51 | }; 52 | 53 | class MultiStepLR : public BaseLRScheduler { 54 | public: 55 | MultiStepLR(torch::optim::Optimizer& optimizer, const std::vector& milestones, const double gamma = 0.1); 56 | 57 | private: 58 | std::vector get_lrs() override; 59 | 60 | const std::vector milestones_; 61 | const double gamma_; 62 | }; 63 | 64 | class PolynomialLR : public BaseLRScheduler { 65 | public: 66 | PolynomialLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double power = 1.0); 67 | 68 | private: 69 | std::vector get_lrs() override; 70 | 71 | const unsigned total_iters_; 72 | const double power_; 73 | }; 74 | 75 | class StepLR : public BaseLRScheduler { 76 | public: 77 | StepLR(torch::optim::Optimizer& optimizer, const unsigned step_size, const double gamma = 0.1); 78 | 79 | private: 80 | std::vector get_lrs() override; 81 | 82 | const unsigned step_size_; 83 | const double gamma_; 84 | }; 85 | 86 | class LinearLR : public BaseLRScheduler { 87 | public: 88 | LinearLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double start_factor = 0.333, 89 | const double end_factor = 1.0); 90 | 91 | private: 92 | std::vector get_lrs() override; 93 | 94 | const unsigned total_iters_; 95 | const double start_factor_, end_factor_; 96 | }; 97 | 98 | } // namespace torchfort 99 | 
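The schedulers declared above are named and parameterized like their PyTorch counterparts (T_max, milestones, gamma, total_iters, power, start_factor/end_factor). For reference, the cosine_annealing type used by the example configs follows the standard closed-form cosine schedule; the short Python sketch below illustrates that update rule under the usual assumptions (monotonic stepping, eta_min defaulting to 0.0 as in the constructor above). It is an illustrative helper, not code taken from the library source.

import math

def cosine_annealing_lr(base_lr, step, T_max, eta_min=0.0):
    # Illustrative helper (not part of TorchFort): standard closed-form cosine annealing,
    # giving base_lr at step 0 and eta_min at step T_max.
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / T_max)) / 2.0

# With learning_rate: 1e-3 and T_max: 100000 (the values used in the example configs),
# this yields 1e-3 at step 0, 5e-4 at step 50000, and 0.0 at step 100000.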
-------------------------------------------------------------------------------- /src/csrc/include/internal/model_pack.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | #include 33 | 34 | #include 35 | 36 | #include "internal/base_loss.h" 37 | #include "internal/base_lr_scheduler.h" 38 | #include "internal/distributed.h" 39 | #include "internal/model_state.h" 40 | #include "internal/model_wrapper.h" 41 | 42 | namespace torchfort { 43 | 44 | // Simple struct to group model, optimizer, lr scheduler, state, and comm objects 45 | struct ModelPack { 46 | std::shared_ptr model; 47 | std::shared_ptr optimizer; 48 | std::shared_ptr lr_scheduler; 49 | std::shared_ptr loss; 50 | std::shared_ptr comm; 51 | std::shared_ptr state; 52 | int grad_accumulation_steps = 1; 53 | }; 54 | 55 | void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true); 56 | void load_model_pack(ModelPack& model_pack, const std::string& fname, bool load_optimizer = true); 57 | 58 | } // namespace torchfort 59 | -------------------------------------------------------------------------------- /src/csrc/include/internal/model_state.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. 
Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | #include 33 | #include 34 | 35 | #include 36 | 37 | namespace torchfort { 38 | 39 | // Simple struct to store miscellaneous model state (e.g. iteration count) 40 | struct ModelState { 41 | int64_t step_train; 42 | int64_t step_inference; 43 | int64_t step_train_current; // training step of current run (ignoring restarted state) 44 | torch::Device device = torch::Device(torch::kCPU); 45 | 46 | // General option settings 47 | int32_t report_frequency; 48 | bool enable_wandb_hook; 49 | bool verbose; 50 | std::filesystem::path report_file; 51 | 52 | void save(const std::string& fname); 53 | void load(const std::string& fname); 54 | }; 55 | 56 | } // namespace torchfort 57 | -------------------------------------------------------------------------------- /src/csrc/include/internal/model_wrapper.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | 35 | #include "internal/base_model.h" 36 | 37 | namespace torchfort { 38 | 39 | class ModelWrapper { 40 | public: 41 | ModelWrapper(const std::shared_ptr& model); 42 | 43 | ModelWrapper(const std::shared_ptr& model_jit); 44 | 45 | ModelWrapper(const std::string& jit_model_fname); 46 | 47 | std::vector parameters() const; 48 | 49 | torch::OrderedDict named_parameters() const; 50 | 51 | void to(torch::Device device, bool non_blocking = false); 52 | 53 | void train(); 54 | 55 | void eval(); 56 | 57 | std::vector forward(const std::vector& inputs) const; 58 | 59 | void save(const std::string& fname) const; 60 | 61 | void load(const std::string& fname); 62 | 63 | torch::Device device() const; 64 | 65 | private: 66 | bool jit = false; 67 | std::shared_ptr model; 68 | std::shared_ptr model_jit; 69 | torch::Device device_ = torch::Device(torch::kCPU); 70 | }; 71 | 72 | } // namespace torchfort 73 | -------------------------------------------------------------------------------- /src/csrc/include/internal/models.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | 35 | #include 36 | 37 | #include "internal/base_model.h" 38 | #include "internal/defines.h" 39 | #include "internal/param_map.h" 40 | 41 | namespace torchfort { 42 | 43 | // MLP model in C++ using libtorch 44 | struct MLPModel : BaseModel, public std::enable_shared_from_this { 45 | void setup(const ParamMap& params) override; 46 | std::vector forward(const std::vector& inputs) override; 47 | 48 | double dropout; 49 | std::vector layer_sizes; 50 | 51 | // Use one of many "standard library" modules. 52 | std::vector fc_layers; 53 | std::vector biases; 54 | }; 55 | 56 | struct SACMLPModel : BaseModel, public std::enable_shared_from_this { 57 | void setup(const ParamMap& params) override; 58 | std::vector forward(const std::vector& inputs) override; 59 | 60 | double dropout; 61 | std::vector layer_sizes; 62 | bool state_dependent_sigma; 63 | 64 | // A SAC Model has a common encoder and two output layers for mu and log-sigma 65 | std::vector encoder_layers; 66 | std::vector out_layers; 67 | std::vector biases; 68 | std::vector out_biases; 69 | }; 70 | 71 | struct ActorCriticMLPModel : BaseModel, public std::enable_shared_from_this { 72 | void setup(const ParamMap& params) override; 73 | std::vector forward(const std::vector& inputs) override; 74 | 75 | double dropout; 76 | std::vector encoder_layer_sizes, actor_layer_sizes, value_layer_sizes; 77 | bool state_dependent_sigma; 78 | 79 | // An AC Model has a common encoder and then an MLP for actor and one for value 80 | std::vector encoder_layers, actor_layers, value_layers; 81 | std::vector encoder_biases, actor_biases, value_biases; 82 | }; 83 | 84 | // Creating model_registry. 85 | BEGIN_MODEL_REGISTRY 86 | 87 | // Add entries for new models in this section. 88 | REGISTER_MODEL(MLP, MLPModel) 89 | REGISTER_MODEL(SACMLP, SACMLPModel) 90 | REGISTER_MODEL(ActorCriticMLP, ActorCriticMLPModel) 91 | 92 | END_MODEL_REGISTRY 93 | 94 | } // namespace torchfort 95 | -------------------------------------------------------------------------------- /src/csrc/include/internal/nvtx.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | 35 | #ifdef ENABLE_GPU 36 | #include 37 | #endif 38 | 39 | namespace torchfort { 40 | 41 | // Helper class for NVTX ranges 42 | class nvtx { 43 | public: 44 | #ifdef ENABLE_GPU 45 | static void rangePush(const std::string& range_name) { 46 | static constexpr int ncolors_ = 8; 47 | static constexpr int colors_[ncolors_] = {0x3366CC, 0xDC3912, 0xFF9900, 0x109618, 48 | 0x990099, 0x3B3EAC, 0x0099C6, 0xDD4477}; 49 | std::hash hash_fn; 50 | int color = colors_[hash_fn(range_name) % ncolors_]; 51 | nvtxEventAttributes_t ev = {0}; 52 | ev.version = NVTX_VERSION; 53 | ev.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; 54 | ev.colorType = NVTX_COLOR_ARGB; 55 | ev.color = color; 56 | ev.messageType = NVTX_MESSAGE_TYPE_ASCII; 57 | ev.message.ascii = range_name.c_str(); 58 | nvtxRangePushEx(&ev); 59 | } 60 | 61 | static void rangePop() { nvtxRangePop(); } 62 | #else 63 | static void rangePush(const std::string& range_name) {} 64 | static void rangePop() {} 65 | #endif 66 | }; 67 | 68 | } // namespace torchfort 69 | -------------------------------------------------------------------------------- /src/csrc/include/internal/param_map.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | 40 | #include "internal/exceptions.h" 41 | #include "internal/utils.h" 42 | 43 | namespace torchfort { 44 | 45 | // Helper function to get type as string 46 | template std::string type_string() { 47 | if (std::is_same::value) { 48 | return "int"; 49 | } else if (std::is_same::value) { 50 | return "float"; 51 | } else if (std::is_same::value) { 52 | return "double"; 53 | } else if (std::is_same::value) { 54 | return "bool"; 55 | } 56 | return "UNKNOWN"; 57 | }; 58 | 59 | // Conversion functor 60 | template struct ParamMapConverter { 61 | T operator()(const std::string& s) { 62 | try { 63 | if constexpr (std::is_same::value) { 64 | return std::stoi(sanitize(s)); 65 | } 66 | if constexpr (std::is_same::value) { 67 | return std::stof(sanitize(s)); 68 | } 69 | if constexpr (std::is_same::value) { 70 | return std::stod(sanitize(s)); 71 | } 72 | if constexpr (std::is_same::value) { 73 | std::string s_ = sanitize(s); 74 | bool val; 75 | if (s_ == "true") { 76 | val = true; 77 | } else if (s_ == "false") { 78 | val = false; 79 | } else { 80 | val = std::stoi(s_); 81 | } 82 | return val; 83 | } 84 | if constexpr (std::is_same::value) { 85 | return s; 86 | } 87 | } catch (std::invalid_argument) { 88 | THROW_INVALID_USAGE("Could not convert provided parameter value " + s + " to required type."); 89 | } 90 | 91 | THROW_INTERNAL_ERROR("Unknown conversion type."); 92 | } 93 | }; 94 | 95 | class ParamMap { 96 | public: 97 | template void add_param(const std::string& key, const std::vector& value); 98 | 99 | template std::vector get_param(const std::string& key) const; 100 | 101 | template std::vector get_param(const std::string& key, const T& defval) const; 102 | 103 | std::set keys() const; 104 | 105 | private: 106 | std::unordered_map> params; 107 | }; 108 | 109 | template void ParamMap::add_param(const std::string& key, const std::vector& value) { 110 | params[sanitize(key)] = value; 111 | } 112 | 113 | template std::vector ParamMap::get_param(const std::string& key) const { 114 | const auto& entry = params.at(sanitize(key)); 115 | std::vector values; 116 | std::transform(entry.begin(), entry.end(), std::back_inserter(values), ParamMapConverter()); 117 | return values; 118 | } 119 | 120 | // parameter with default value 121 | template std::vector ParamMap::get_param(const std::string& key, const T& defval) const { 122 | try { 123 | const auto& entry = params.at(sanitize(key)); 124 | std::vector values; 125 | std::transform(entry.begin(), entry.end(), std::back_inserter(values), ParamMapConverter()); 126 | return values; 127 | } catch (std::out_of_range) { 128 | return {defval}; 129 | } 130 | } 131 | 132 | } // namespace torchfort 133 | -------------------------------------------------------------------------------- /src/csrc/include/internal/rl/distributions.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | #include 33 | #include 34 | 35 | #include "internal/rl/rl.h" 36 | 37 | namespace torchfort { 38 | 39 | namespace rl { 40 | 41 | class Distribution { 42 | 43 | public: 44 | Distribution(const Distribution&) = delete; 45 | 46 | // constructor 47 | Distribution() {} 48 | virtual torch::Tensor rsample() = 0; 49 | virtual torch::Tensor log_prob(torch::Tensor value) = 0; 50 | virtual torch::Tensor entropy() = 0; 51 | }; 52 | 53 | class NormalDistribution : public Distribution, public std::enable_shared_from_this { 54 | public: 55 | NormalDistribution(torch::Tensor mu, torch::Tensor sigma) : mu_(mu), sigma_(sigma) {} 56 | 57 | torch::Tensor rsample() { 58 | auto noise = torch::empty_like(mu_).normal_(0., 1.); 59 | return torch::Tensor(mu_ + sigma_ * noise).clone(); 60 | } 61 | 62 | torch::Tensor log_prob(torch::Tensor value) { 63 | auto var = torch::square(sigma_); 64 | auto log_sigma = sigma_.log(); 65 | auto result = -torch::square(value - mu_) / (2 * var) - log_sigma - std::log(std::sqrt(2. * M_PI)); 66 | 67 | return result; 68 | } 69 | 70 | torch::Tensor entropy() { 71 | auto log_sigma = sigma_.log(); 72 | auto result = log_sigma + 0.5 * (1. + std::log(2. * M_PI)); 73 | 74 | return result; 75 | } 76 | 77 | protected: 78 | torch::Tensor mu_; 79 | torch::Tensor sigma_; 80 | }; 81 | 82 | } // namespace rl 83 | 84 | } // namespace torchfort 85 | -------------------------------------------------------------------------------- /src/csrc/include/internal/rl/rl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
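A short sketch exercising `NormalDistribution` from distributions.h above. The tensor shapes are arbitrary; per element, `log_prob` evaluates log N(x | mu, sigma^2) = -(x - mu)^2 / (2 sigma^2) - log(sigma) - log(sqrt(2 pi)), and `entropy` evaluates log(sigma) + (1 + log(2 pi)) / 2.

```cpp
#include <torch/torch.h>

#include "internal/rl/distributions.h"

void normal_distribution_demo() {
  auto mu = torch::zeros({4});
  auto sigma = torch::full({4}, 0.5);

  torchfort::rl::NormalDistribution dist(mu, sigma);

  // Reparameterized sample: mu + sigma * eps with eps ~ N(0, 1), so gradients
  // can flow back into mu and sigma.
  auto sample = dist.rsample();

  // Element-wise log-density and entropy of the diagonal Gaussian.
  auto logp = dist.log_prob(sample);
  auto ent = dist.entropy();
}
```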
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #include "internal/rl/off_policy.h" 34 | #include "internal/rl/on_policy.h" 35 | -------------------------------------------------------------------------------- /src/csrc/include/internal/rl/utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | #include 33 | 34 | #ifdef ENABLE_GPU 35 | #include 36 | 37 | #include 38 | #include 39 | #endif 40 | #include 41 | 42 | #include "internal/defines.h" 43 | #include "internal/logging.h" 44 | #include "internal/lr_schedulers.h" 45 | #include "internal/model_pack.h" 46 | #include "internal/rl/rl.h" 47 | #include "internal/setup.h" 48 | 49 | namespace torchfort { 50 | 51 | namespace rl { 52 | 53 | // helpers for sanitizing devices 54 | // check if both devices are different, one device has to be a cpu device. 55 | bool validate_devices(int device1, int device2); 56 | 57 | // helpers for extracting LRS from optimizer 58 | std::vector get_current_lrs(std::shared_ptr optimizer); 59 | 60 | // helpers for manipulating weights and grads 61 | void init_parameters(std::shared_ptr model); 62 | void copy_parameters(std::shared_ptr target, std::shared_ptr source); 63 | void set_grad_state(std::shared_ptr model, const bool requires_grad); 64 | 65 | // polyak update for model averaging: 66 | // computes: target = rho * target + (1-rho) * src 67 | // note that target here denotes the updated model parameters, and src the previous ones 68 | template 69 | void polyak_update(std::shared_ptr target, std::shared_ptr source, const T rho) { 70 | 71 | // add no grad guard 72 | torch::NoGradGuard no_grad; 73 | 74 | // get models 75 | auto tar = target->parameters(); 76 | auto src = source->parameters(); 77 | 78 | // some simple asserts here 79 | assert(tar.size() == src.size()); 80 | 81 | // do in-place update: I don't know a good way of doing that with std::transform: 82 | for (size_t i = 0; i < tar.size(); ++i) { 83 | const auto& t = tar[i]; 84 | const auto& s = src[i]; 85 | t.mul_(rho); 86 | t.add_((1. - rho) * s); 87 | // t.copy_(torch::Tensor(rho * t + (1.-rho) * s)); 88 | } 89 | 90 | return; 91 | } 92 | 93 | // Rescale the action from [a_low, a_high] to [-1, 1] 94 | template torch::Tensor scale_action(torch::Tensor unscaled_action, const T& a_low, const T& a_high) { 95 | auto scaled_action = static_cast(2.0) * ((unscaled_action - a_low) / (a_high - a_low)) - static_cast(1.0); 96 | scaled_action.to(unscaled_action.dtype()); 97 | 98 | return scaled_action; 99 | } 100 | 101 | // Unscale the action from [-1., 1.] 
to [a_low, a_high] 102 | template torch::Tensor unscale_action(torch::Tensor scaled_action, const T& a_low, const T& a_high) { 103 | auto unscaled_action = 0.5 * (a_high - a_low) * (scaled_action + static_cast(1.)) + a_low; 104 | unscaled_action.to(scaled_action.dtype()); 105 | 106 | return unscaled_action; 107 | } 108 | 109 | // explained variance 110 | torch::Tensor explained_variance(torch::Tensor q_pred, torch::Tensor q_true, std::shared_ptr comm); 111 | 112 | } // namespace rl 113 | 114 | } // namespace torchfort 115 | -------------------------------------------------------------------------------- /src/csrc/include/internal/setup.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
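Two small sketches for the helpers in rl/utils.h above. The `scale_action`/`unscale_action` calls use the declared signatures directly; because the smart-pointer element type in `polyak_update`'s signature is not visible in this listing, the polyak step is written out against two plain torch modules to illustrate the same update, target <- rho * target + (1 - rho) * source.

```cpp
#include <torch/torch.h>

#include "internal/rl/utils.h"

void rl_utils_demo() {
  // Round trip: scale_action maps [a_low, a_high] -> [-1, 1], unscale_action maps back.
  auto a = torch::tensor({0.0f, 2.0f, 4.0f});
  auto scaled = torchfort::rl::scale_action(a, 0.0f, 4.0f);          // {-1, 0, 1}
  auto restored = torchfort::rl::unscale_action(scaled, 0.0f, 4.0f); // {0, 2, 4}

  // Polyak averaging written directly against two torch modules; this mirrors
  // what polyak_update does for the wrapped models.
  torch::nn::Linear target(8, 8), source(8, 8);
  const double rho = 0.995;
  torch::NoGradGuard no_grad;
  auto tp = target->parameters();
  auto sp = source->parameters();
  for (size_t i = 0; i < tp.size(); ++i) {
    tp[i].mul_(rho);
    tp[i].add_((1.0 - rho) * sp[i]);
  }
}
```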
29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | 35 | #include 36 | #include 37 | 38 | #include "internal/base_loss.h" 39 | #include "internal/base_lr_scheduler.h" 40 | #include "internal/base_model.h" 41 | #include "internal/lr_schedulers.h" 42 | #include "internal/model_state.h" 43 | #include "internal/model_wrapper.h" 44 | 45 | namespace torchfort { 46 | 47 | void check_params(const std::set& supported_params, const std::set& provided_params); 48 | 49 | ParamMap get_params(const YAML::Node& params_node); 50 | 51 | std::shared_ptr get_model(const YAML::Node& model_node); 52 | 53 | std::shared_ptr get_loss(const YAML::Node& loss_node); 54 | 55 | std::shared_ptr get_optimizer(const YAML::Node& optimizer_node, 56 | std::vector parameters); 57 | 58 | std::shared_ptr get_optimizer(const YAML::Node& optimizer_node, 59 | const std::shared_ptr& model); 60 | 61 | std::shared_ptr get_lr_scheduler(const YAML::Node& lr_scheduler_node, 62 | const std::shared_ptr& optimizer); 63 | 64 | std::shared_ptr get_state(const char* name, const YAML::Node& state_node); 65 | } // namespace torchfort 66 | -------------------------------------------------------------------------------- /src/csrc/include/internal/tensor_list.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
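A sketch of how the factory helpers declared in setup.h above are typically chained when parsing a model configuration. The YAML key names follow the example configs shipped in this repository (e.g. tests/supervised/configs/mlp.yaml) but are assumptions here, as are the exact element types of the returned shared pointers, which were stripped from this listing.

```cpp
#include <string>

#include <yaml-cpp/yaml.h>

#include "internal/setup.h"

void build_from_config(const std::string& fname) {
  YAML::Node config = YAML::LoadFile(fname);

  // Each helper reads one sub-node of the config and constructs the corresponding object.
  auto model = torchfort::get_model(config["model"]);
  auto loss = torchfort::get_loss(config["loss"]);
  auto optimizer = torchfort::get_optimizer(config["optimizer"], model);
  auto lr_scheduler = torchfort::get_lr_scheduler(config["lr_scheduler"], optimizer);
}
```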
29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | 35 | #include 36 | 37 | #include "internal/utils.h" 38 | 39 | namespace torchfort { 40 | struct TensorList { 41 | template 42 | void add_tensor(T* data, size_t dim, int64_t* shape) { 43 | auto tensor = get_tensor(data, dim, shape); 44 | tensors.push_back(tensor); 45 | tensors_original_.push_back(tensor); 46 | }; 47 | 48 | void to(torch::Device device, bool non_blocking = false) { 49 | for (auto &t : tensors) { 50 | t = t.to(device, non_blocking); 51 | } 52 | }; 53 | 54 | void reset() { 55 | tensors = tensors_original_; 56 | } 57 | 58 | std::vector tensors; 59 | // To preserve references to external data, we store the original tensor objects 60 | std::vector tensors_original_; 61 | }; 62 | } // namespace torchfort 63 | -------------------------------------------------------------------------------- /src/csrc/include/internal/training.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
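A usage sketch for `TensorList` from tensor_list.h above. `add_tensor` wraps a caller-owned buffer without copying (via `get_tensor`/`from_blob`), `to()` replaces the working tensors with copies on the requested device, and `reset()` restores the original zero-copy views; the buffer and shape are arbitrary.

```cpp
#include <cstdint>
#include <vector>

#include <torch/torch.h>

#include "internal/tensor_list.h"

void tensor_list_demo() {
  std::vector<float> data(32 * 8, 1.0f);
  int64_t shape[2] = {32, 8};

  torchfort::TensorList inputs;
  inputs.add_tensor(data.data(), 2, shape); // zero-copy view over the caller's buffer

  inputs.to(torch::Device(torch::kCPU));    // or a CUDA device in ENABLE_GPU builds
  inputs.reset();                           // back to the original views of the host buffer
}
```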
29 | */ 30 | 31 | #pragma once 32 | 33 | #ifdef ENABLE_GPU 34 | #include 35 | #endif 36 | 37 | #include 38 | 39 | #include 40 | 41 | namespace torchfort { 42 | 43 | void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in, 44 | cudaStream_t ext_stream = 0); 45 | 46 | void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t labels_in, 47 | float* loss_val, torchfort_tensor_list_t extra_loss_args_in, cudaStream_t ext_stream = 0); 48 | 49 | template 50 | void inference(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* output, size_t output_dim, 51 | int64_t* output_shape, cudaStream_t ext_stream = 0) { 52 | TensorList inputs, outputs; 53 | 54 | inputs.add_tensor(input, input_dim, input_shape); 55 | outputs.add_tensor(output, output_dim, output_shape); 56 | 57 | inference_multiarg(name, &inputs, &outputs, ext_stream); 58 | } 59 | 60 | template 61 | void train(const char* name, T* input, size_t input_dim, int64_t* input_shape, T* label, size_t label_dim, 62 | int64_t* label_shape, T* loss_val, cudaStream_t ext_stream = 0) { 63 | TensorList inputs, labels; 64 | 65 | inputs.add_tensor(input, input_dim, input_shape); 66 | labels.add_tensor(label, label_dim, label_shape); 67 | 68 | // multiarg API expects float loss value, so use temporary here 69 | float loss_val_tmp; 70 | train_multiarg(name, &inputs, &labels, &loss_val_tmp, nullptr, ext_stream); 71 | *loss_val = loss_val_tmp; 72 | } 73 | 74 | } // namespace torchfort 75 | -------------------------------------------------------------------------------- /src/csrc/include/internal/utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
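A sketch of the single-tensor helpers in training.h above, which wrap their raw pointers in TensorLists and defer to the `*_multiarg` implementations. The model name "mymodel" is hypothetical and would have to be registered beforehand through the public TorchFort API; the default stream argument is used.

```cpp
#include <cstdint>
#include <vector>

#include "internal/training.h"

void training_demo() {
  std::vector<float> input(16 * 4), label(16 * 1), output(16 * 1);
  int64_t in_shape[2] = {16, 4};
  int64_t out_shape[2] = {16, 1};

  // One training step on a single input/label pair, returning the scalar loss.
  float loss_val;
  torchfort::train("mymodel", input.data(), 2, in_shape, label.data(), 2, out_shape, &loss_val);

  // One inference call writing the model output into the caller's buffer.
  torchfort::inference("mymodel", input.data(), 2, in_shape, output.data(), 2, out_shape);
}
```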
29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | #include 35 | #include 36 | 37 | #include 38 | #include 39 | 40 | #ifdef ENABLE_GPU 41 | #include 42 | #endif 43 | 44 | #include "internal/exceptions.h" 45 | #include "internal/nvtx.h" 46 | 47 | namespace torchfort { 48 | 49 | // Function to convert string to lowercase and remove whitespace 50 | std::string sanitize(std::string s); 51 | 52 | // Function to convert a string to a filename 53 | std::string filename_sanitize(std::string s); 54 | 55 | // Function to return torch device from integer device value 56 | torch::Device get_device(int device); 57 | 58 | // Function to return torch device from pointer 59 | torch::Device get_device(const void* ptr); 60 | 61 | template torch::Dtype make_type() { 62 | if (std::is_same::value) { 63 | return torch::kFloat32; 64 | } else if (std::is_same::value) { 65 | return torch::kInt32; 66 | } else if (std::is_same::value) { 67 | return torch::kInt64; 68 | } else if (std::is_same::value) { 69 | return torch::kFloat64; 70 | } else { 71 | THROW_INVALID_USAGE("datatype not implemented"); 72 | } 73 | } 74 | 75 | enum MemoryLayout { RowMajor = 0, ColMajor = 1 }; 76 | 77 | template torch::Tensor get_tensor(T* tensor_ptr, size_t dim, int64_t* shape) { 78 | torchfort::nvtx::rangePush("get_tensor"); 79 | // Set tensor options 80 | auto dev = get_device(tensor_ptr); 81 | torch::TensorOptions options = torch::TensorOptions().device(dev); 82 | 83 | // Get type 84 | auto type = make_type(); 85 | options = options.dtype(type); 86 | 87 | // Create shape 88 | std::vector sizes(dim); 89 | switch (L) { 90 | case RowMajor: 91 | for (size_t i = 0; i < dim; ++i) { 92 | sizes[i] = shape[i]; 93 | } 94 | break; 95 | case ColMajor: 96 | // For column major input data, reverse the shape order 97 | for (size_t i = 0; i < dim; ++i) { 98 | sizes[i] = shape[dim - i - 1]; 99 | } 100 | break; 101 | } 102 | torch::IntArrayRef size = c10::makeArrayRef(sizes); 103 | 104 | // Create tensor 105 | auto tensor = torch::from_blob( 106 | tensor_ptr, sizes, [](void* ptr) {}, options); 107 | torchfort::nvtx::rangePop(); 108 | return tensor; 109 | } 110 | 111 | // Helper function to convert string reduction names to torch enums. 112 | template T get_torch_reduction(const std::string& s) { 113 | if (s == "mean") { 114 | return torch::kMean; 115 | } else if (s == "sum") { 116 | return torch::kSum; 117 | } else if (s == "none") { 118 | return torch::kNone; 119 | } else { 120 | THROW_INVALID_USAGE("Unknown reduction type encountered."); 121 | } 122 | } 123 | 124 | // Helper function for printing tensor shapes 125 | std::string print_tensor_shape(torch::Tensor tensor); 126 | 127 | // Helper function to get the lrs 128 | std::vector get_current_lrs(const char* name); 129 | 130 | } // namespace torchfort 131 | -------------------------------------------------------------------------------- /src/csrc/include/torchfort_enums.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. 
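A sketch of the memory-layout handling in `get_tensor` above: for column-major (Fortran-ordered) data the shape is reversed so the resulting torch view matches the memory order. The template parameters were stripped from this listing, so the `<type, layout>` ordering used here is an assumption.

```cpp
#include <cstdint>
#include <vector>

#include <torch/torch.h>

#include "internal/utils.h"

void layout_demo() {
  std::vector<double> buf(6);
  int64_t shape[2] = {2, 3}; // logical shape as the caller describes it

  // Row-major (C) data: sizes stay {2, 3}. Column-major (Fortran) data: sizes become {3, 2}.
  auto t_c = torchfort::get_tensor<double, torchfort::RowMajor>(buf.data(), 2, shape);
  auto t_f = torchfort::get_tensor<double, torchfort::ColMajor>(buf.data(), 2, shape);
}
```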
Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #define TORCHFORT_DEVICE_CPU (-1) 34 | 35 | /** 36 | * @brief This enum defines the data types supported. 37 | */ 38 | enum torchfort_datatype_t { TORCHFORT_FLOAT = -1, TORCHFORT_DOUBLE = -2, TORCHFORT_INT32 = -3, TORCHFORT_INT64 = -4 }; 39 | 40 | /** 41 | * @brief This enum defines the possible values return values from TorchFort. Most functions in the TorchFort library 42 | * will return one of these values to indicate if an operation has completed successfully or an error occured. 43 | */ 44 | enum torchfort_result_t { 45 | TORCHFORT_RESULT_SUCCESS = 0, ///< The operation completed successfully 46 | TORCHFORT_RESULT_INVALID_USAGE = 1, ///< A user error, typically an invalid argument 47 | TORCHFORT_RESULT_NOT_SUPPORTED = 2, ///< A user error, requesting an invalid or unsupported operation configuration 48 | TORCHFORT_RESULT_INTERNAL_ERROR = 3, ///< An internal library error, should be reported 49 | TORCHFORT_RESULT_CUDA_ERROR = 4, ///< An error occured in the CUDA Runtime 50 | TORCHFORT_RESULT_MPI_ERROR = 5, ///< An error occured in the MPI library 51 | TORCHFORT_RESULT_NCCL_ERROR = 6 ///< An error occured in the NCCL library 52 | }; 53 | -------------------------------------------------------------------------------- /src/csrc/logging.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. 
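A small sketch of the error-handling convention implied by `torchfort_result_t` above: every library call reports one of these codes, with `TORCHFORT_RESULT_SUCCESS` meaning the operation completed. The checking helper itself is hypothetical.

```cpp
#include <cstdio>
#include <cstdlib>

#include "torchfort_enums.h"

// Hypothetical helper: abort with a message when a TorchFort call does not succeed.
static void check_torchfort(torchfort_result_t status, const char* what) {
  if (status != TORCHFORT_RESULT_SUCCESS) {
    std::fprintf(stderr, "TorchFort call '%s' failed with code %d\n", what, static_cast<int>(status));
    std::exit(EXIT_FAILURE);
  }
}
```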
Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | 38 | #include "internal/exceptions.h" 39 | #include "internal/logging.h" 40 | 41 | namespace torchfort { 42 | namespace logging { 43 | 44 | std::mutex logging_mutex; 45 | static std::unique_ptr logfile; 46 | 47 | std::string log_level_prefix(level log_level) { 48 | if (log_level == level::info) { 49 | return "TORCHFORT::INFO:"; 50 | } else if (log_level == level::warn) { 51 | return "TORCHFORT::WARN:"; 52 | } else if (log_level == level::error) { 53 | return "TORCHFORT::ERROR:"; 54 | } else if (log_level == level::wandb) { 55 | return "TORCHFORT::WANDB:"; 56 | } else { 57 | THROW_INVALID_USAGE("Unknown log level encountered."); 58 | } 59 | } 60 | 61 | void print(const std::string& message, level log_level) { 62 | std::cout << log_level_prefix(log_level) << " "; 63 | std::cout << message << std::endl; 64 | } 65 | 66 | bool open_logfile(const std::filesystem::path& filename) { 67 | 68 | // check if filename is empty, meaning we do not want to log 69 | if (filename.empty()) { 70 | return false; 71 | } 72 | 73 | // check if path exists 74 | if (filename.has_parent_path()) { 75 | auto path = filename.parent_path(); 76 | std::filesystem::create_directories(path); 77 | } 78 | 79 | logfile = std::make_unique(filename, std::ofstream::out | std::ofstream::app); 80 | 81 | return true; 82 | } 83 | 84 | void write(const std::filesystem::path& filename, const std::string& message, level log_level) { 85 | std::lock_guard guard(logging_mutex); 86 | 87 | // check of logfile if already open 88 | if (logfile == nullptr) { 89 | if (!open_logfile(filename)) 90 | return; 91 | } 92 | auto line = log_level_prefix(log_level) + " " + message + "\n"; 93 | logfile->write(line.c_str(), line.size()); 94 | logfile->flush(); 95 | } 96 | 97 | } // namespace logging 98 | } // namespace torchfort 99 | -------------------------------------------------------------------------------- /src/csrc/losses/l1_loss.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. 
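A usage sketch for the logging helpers in logging.cpp above. `print` writes to stdout with a level prefix, while `write` lazily opens the given log file on first use (creating parent directories) and appends to it; the file path and messages are hypothetical, and the `level` enum is assumed to be declared in internal/logging.h.

```cpp
#include <filesystem>

#include "internal/logging.h"

void logging_demo() {
  // Console output with a "TORCHFORT::INFO:" prefix.
  torchfort::logging::print("starting training loop", torchfort::logging::level::info);

  // File output; the file is opened on the first call and appended to afterwards.
  std::filesystem::path logfile = "logs/torchfort.log"; // hypothetical path
  torchfort::logging::write(logfile, "step 100, loss 0.42", torchfort::logging::level::info);
}
```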
Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | 36 | #include 37 | #include 38 | 39 | #include "internal/losses.h" 40 | #include "internal/param_map.h" 41 | #include "internal/setup.h" 42 | #include "internal/utils.h" 43 | 44 | namespace torchfort { 45 | 46 | void L1Loss::setup(const ParamMap& params) { 47 | std::set supported_params{"reduction"}; 48 | check_params(supported_params, params.keys()); 49 | 50 | auto options = torch::nn::L1LossOptions(); 51 | try { 52 | std::string reduction = params.get_param("reduction")[0]; 53 | options = options.reduction(get_torch_reduction(reduction)); 54 | } catch (std::out_of_range) { 55 | // use default 56 | } 57 | 58 | module = torch::nn::L1Loss(options); 59 | } 60 | 61 | torch::Tensor L1Loss::forward(const std::vector& inputs, 62 | const std::vector& labels, 63 | const std::vector& extra_args) { 64 | if (inputs.size() != 1 || labels.size() != 1 || extra_args.size() != 0) { 65 | THROW_INVALID_USAGE("L1Loss only supports one input tensor, one label tensor, and no extra arguments."); 66 | } 67 | auto x = inputs[0]; 68 | auto y = labels[0]; 69 | return module(x.flatten(), y.flatten()); 70 | } 71 | 72 | } // namespace torchfort 73 | -------------------------------------------------------------------------------- /src/csrc/losses/mse_loss.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. 
Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | 36 | #include 37 | #include 38 | 39 | #include "internal/exceptions.h" 40 | #include "internal/losses.h" 41 | #include "internal/param_map.h" 42 | #include "internal/setup.h" 43 | #include "internal/utils.h" 44 | 45 | namespace torchfort { 46 | 47 | void MSELoss::setup(const ParamMap& params) { 48 | std::set supported_params{"reduction"}; 49 | check_params(supported_params, params.keys()); 50 | 51 | auto options = torch::nn::MSELossOptions(); 52 | try { 53 | std::string reduction = params.get_param("reduction")[0]; 54 | options = options.reduction(get_torch_reduction(reduction)); 55 | } catch (std::out_of_range) { 56 | // use default 57 | } 58 | 59 | module = torch::nn::MSELoss(options); 60 | } 61 | 62 | torch::Tensor MSELoss::forward(const std::vector& inputs, 63 | const std::vector& labels, 64 | const std::vector& extra_args) { 65 | if (inputs.size() != 1 || labels.size() != 1 || extra_args.size() != 0) { 66 | THROW_INVALID_USAGE("MSELoss only supports one input tensor, one label tensor, and no extra arguments."); 67 | } 68 | auto x = inputs[0]; 69 | auto y = labels[0]; 70 | return module(x.flatten(), y.flatten()); 71 | } 72 | 73 | } // namespace torchfort 74 | -------------------------------------------------------------------------------- /src/csrc/losses/torchscript_loss.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 
18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | 38 | #include 39 | #include 40 | #include 41 | 42 | #include "internal/exceptions.h" 43 | #include "internal/losses.h" 44 | #include "internal/param_map.h" 45 | #include "internal/setup.h" 46 | #include "internal/utils.h" 47 | 48 | namespace torchfort { 49 | 50 | void TorchscriptLoss::setup(const ParamMap& params) { 51 | std::string jit_loss_fname; 52 | try { 53 | jit_loss_fname = params.get_param("filename")[0]; 54 | } catch (std::out_of_range) { 55 | THROW_INVALID_USAGE("filename parameter is required for torchscript loss type."); 56 | } 57 | 58 | if (!std::filesystem::exists(jit_loss_fname)) { 59 | THROW_INVALID_USAGE(jit_loss_fname + " does not exist."); 60 | } 61 | 62 | module_jit = std::shared_ptr(new torch::jit::Module); 63 | *module_jit = torch::jit::load(jit_loss_fname); 64 | } 65 | 66 | torch::Tensor TorchscriptLoss::forward(const std::vector& inputs, 67 | const std::vector& labels, 68 | const std::vector& extra_args) { 69 | std::vector inputs_jit; 70 | inputs_jit.insert(inputs_jit.end(), inputs.begin(), inputs.end()); 71 | inputs_jit.insert(inputs_jit.end(), labels.begin(), labels.end()); 72 | inputs_jit.insert(inputs_jit.end(), extra_args.begin(), extra_args.end()); 73 | 74 | auto result = module_jit->forward(inputs_jit); 75 | if (!result.isTensor()) { 76 | THROW_INVALID_USAGE("TorchscriptLoss only supports returning a single loss tensor."); 77 | } 78 | return result.toTensor(); 79 | } 80 | 81 | } // namespace torchfort 82 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/cosine_annealing_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 
18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | 33 | #include "internal/base_lr_scheduler.h" 34 | #include "internal/lr_schedulers.h" 35 | 36 | namespace torchfort { 37 | 38 | CosineAnnealingLR::CosineAnnealingLR(torch::optim::Optimizer& optimizer, const unsigned T_max, const double eta_min) 39 | : BaseLRScheduler(optimizer), T_max_(T_max), eta_min_(eta_min) { 40 | base_lrs_ = get_current_lrs(); 41 | } 42 | 43 | double CosineAnnealingLR::update_lr(const double& last_lr, const double& base_lr) { 44 | double lr; 45 | if ((step_count_ - 1 - T_max_) % (2 * T_max_) == 0) { 46 | lr = eta_min_ + 0.5 * (base_lr - eta_min_) * (1. + cos(double(step_count_) * M_PI / double(T_max_))); 47 | } else { 48 | lr = (1. + cos(M_PI * double(step_count_) / double(T_max_))) / 49 | (1. + cos(M_PI * double(step_count_ - 1) / double(T_max_))) * (last_lr - eta_min_) + 50 | eta_min_; 51 | } 52 | 53 | return lr; 54 | } 55 | 56 | std::vector CosineAnnealingLR::get_lrs() { 57 | std::vector lrs = get_current_lrs(); 58 | if (step_count_ == 0 || T_max_ == 0) 59 | return lrs; 60 | else { 61 | std::vector lrs_new; 62 | std::transform(lrs.begin(), lrs.end(), base_lrs_.begin(), std::back_inserter(lrs_new), 63 | [this](const auto& current, const auto& base) { return update_lr(current, base); }); 64 | return lrs_new; 65 | } 66 | } 67 | 68 | } // namespace torchfort 69 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/linear_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 
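For reference, the recursion in `CosineAnnealingLR::update_lr` above is the chained form of cosine annealing; for step count t without restarts it is equivalent to the closed form

```latex
\eta_t \;=\; \eta_{\min} \;+\; \tfrac{1}{2}\bigl(\eta_{\mathrm{base}} - \eta_{\min}\bigr)\Bigl(1 + \cos\frac{t\,\pi}{T_{\max}}\Bigr)
```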
18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | 34 | #include "internal/base_lr_scheduler.h" 35 | #include "internal/lr_schedulers.h" 36 | 37 | namespace torchfort { 38 | 39 | LinearLR::LinearLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double start_factor, 40 | const double end_factor) 41 | : BaseLRScheduler(optimizer), total_iters_(total_iters), start_factor_(start_factor), end_factor_(end_factor) {} 42 | 43 | std::vector LinearLR::get_lrs() { 44 | 45 | double factor; 46 | if (step_count_ == 0) { 47 | factor = start_factor_; 48 | } else if (step_count_ > total_iters_) { 49 | factor = 1.; 50 | } else { 51 | factor = (1. + (end_factor_ - start_factor_) / 52 | double(total_iters_ * start_factor_ + (step_count_ - 1) * (end_factor_ - start_factor_))); 53 | } 54 | 55 | // get current lrs and modify 56 | std::vector lrs = get_current_lrs(); 57 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [factor](const double& v) { return factor * v; }); 58 | 59 | return lrs; 60 | } 61 | 62 | } // namespace torchfort 63 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/multistep_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | 34 | #include "internal/base_lr_scheduler.h" 35 | #include "internal/lr_schedulers.h" 36 | 37 | namespace torchfort { 38 | 39 | MultiStepLR::MultiStepLR(torch::optim::Optimizer& optimizer, const std::vector& milestones, const double gamma) 40 | : BaseLRScheduler(optimizer), milestones_(milestones), gamma_(gamma) {} 41 | 42 | std::vector MultiStepLR::get_lrs() { 43 | std::vector lrs = get_current_lrs(); 44 | if (step_count_ == 0 || milestones_.size() == 0) 45 | return lrs; 46 | else { 47 | auto lower_old = std::lower_bound(milestones_.begin(), milestones_.end(), step_count_ - 1, 48 | [](const int& ms, int value) { return ms <= value; }); 49 | auto lower = std::lower_bound(milestones_.begin(), milestones_.end(), step_count_, 50 | [](const int& ms, int value) { return ms <= value; }); 51 | 52 | if (lower_old != lower) { 53 | // in this case we need to decay the LR: 54 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [this](const double& lr) { return this->gamma_ * lr; }); 55 | } 56 | 57 | return lrs; 58 | } 59 | } 60 | 61 | } // namespace torchfort 62 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/polynomial_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | #include "internal/base_lr_scheduler.h" 36 | #include "internal/lr_schedulers.h" 37 | 38 | namespace torchfort { 39 | 40 | PolynomialLR::PolynomialLR(torch::optim::Optimizer& optimizer, const unsigned total_iters, const double power) 41 | : BaseLRScheduler(optimizer), total_iters_(total_iters), power_(power) {} 42 | 43 | std::vector PolynomialLR::get_lrs() { 44 | std::vector lrs = get_current_lrs(); 45 | if (step_count_ == 0 || step_count_ > total_iters_) 46 | return lrs; 47 | else { 48 | double decay_factor = 49 | (1. - double(step_count_) / double(total_iters_)) / (1. - double(step_count_ - 1) / double(total_iters_)); 50 | decay_factor = std::pow(decay_factor, power_); 51 | 52 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [decay_factor](const double& v) { return decay_factor * v; }); 53 | 54 | return lrs; 55 | } 56 | } 57 | 58 | } // namespace torchfort 59 | -------------------------------------------------------------------------------- /src/csrc/lr_schedulers/step_lr.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
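The per-step decay factor in `PolynomialLR::get_lrs` above telescopes across steps, so for 1 <= t <= total_iters the schedule is equivalent to the closed form

```latex
\eta_t \;=\; \eta_{\mathrm{base}}\,\Bigl(1 - \frac{t}{\text{total\_iters}}\Bigr)^{\text{power}}
```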
29 | */ 30 | 31 | #include 32 | #include 33 | 34 | #include "internal/base_lr_scheduler.h" 35 | #include "internal/lr_schedulers.h" 36 | 37 | namespace torchfort { 38 | 39 | StepLR::StepLR(torch::optim::Optimizer& optimizer, const unsigned step_size, const double gamma) 40 | : BaseLRScheduler(optimizer), step_size_(step_size), gamma_(gamma) {} 41 | 42 | std::vector StepLR::get_lrs() { 43 | if (step_count_ == 0 || step_count_ % step_size_ != 0) 44 | return get_current_lrs(); 45 | else { 46 | std::vector lrs = get_current_lrs(); 47 | std::transform(lrs.begin(), lrs.end(), lrs.begin(), [this](const double& v) { return this->gamma_ * v; }); 48 | return lrs; 49 | } 50 | } 51 | 52 | } // namespace torchfort 53 | -------------------------------------------------------------------------------- /src/csrc/model_pack.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | */ 30 | 31 | #include 32 | 33 | #include "internal/defines.h" 34 | #include "internal/model_pack.h" 35 | #include "internal/model_wrapper.h" 36 | #include "internal/utils.h" 37 | 38 | namespace torchfort { 39 | 40 | void save_model_pack(const ModelPack& model_pack, const std::string& dir, bool save_optimizer) { 41 | std::filesystem::path root_dir(dir); 42 | 43 | if (!std::filesystem::exists(root_dir)) { 44 | bool rv = std::filesystem::create_directory(root_dir); 45 | if (!rv) { 46 | THROW_INVALID_USAGE("Could not create directory " + root_dir.native() + "."); 47 | } 48 | } 49 | 50 | model_pack.state->device = model_pack.model->device(); 51 | 52 | auto model_path = root_dir / "model.pt"; 53 | model_pack.model->save(model_path.native()); 54 | 55 | if (save_optimizer) { 56 | auto optimizer_path = root_dir / "optimizer.pt"; 57 | if (!model_pack.optimizer) { 58 | THROW_INVALID_USAGE("Cannot save checkpoint. Missing optimizer."); 59 | } 60 | torch::save(*(model_pack.optimizer), optimizer_path.native()); 61 | 62 | auto lr_path = root_dir / "lr.pt"; 63 | if (model_pack.lr_scheduler) { 64 | model_pack.lr_scheduler->save(lr_path.native()); 65 | } 66 | } 67 | 68 | auto state_path = root_dir / "state.pt"; 69 | model_pack.state->save(state_path.native()); 70 | } 71 | 72 | void load_model_pack(ModelPack& model_pack, const std::string& dir, bool load_optimizer) { 73 | std::filesystem::path root_dir(dir); 74 | 75 | auto state_path = root_dir / "state.pt"; 76 | if (!std::filesystem::exists(state_path)) { 77 | THROW_INVALID_USAGE("Could not find " + state_path.native() + "."); 78 | } 79 | model_pack.state->load(state_path.native()); 80 | 81 | auto model_path = root_dir / "model.pt"; 82 | if (!std::filesystem::exists(model_path)) { 83 | THROW_INVALID_USAGE("Could not find " + model_path.native() + "."); 84 | } 85 | model_pack.model->load(model_path.native()); 86 | 87 | // Assign optimizer to parameters of loaded model: 88 | // we need to check if the optimizer is initialized before doing so 89 | // (some RL models do not have an optimizer attached to them): 90 | if (model_pack.optimizer) { 91 | model_pack.optimizer->parameters() = model_pack.model->parameters(); 92 | } 93 | 94 | if (load_optimizer) { 95 | auto optimizer_path = root_dir / "optimizer.pt"; 96 | if (!std::filesystem::exists(optimizer_path)) { 97 | THROW_INVALID_USAGE("Could not find " + optimizer_path.native() + "."); 98 | } 99 | torch::load(*(model_pack.optimizer), optimizer_path.native(), model_pack.model->device()); 100 | 101 | auto lr_path = root_dir / "lr.pt"; 102 | if (std::filesystem::exists(lr_path)) { 103 | model_pack.lr_scheduler->load(lr_path.native(), *(model_pack.optimizer)); 104 | } else { 105 | // No LR in checkpoint, disable LR scheduler 106 | model_pack.lr_scheduler = nullptr; 107 | } 108 | } 109 | } 110 | 111 | } // namespace torchfort 112 | -------------------------------------------------------------------------------- /src/csrc/model_state.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. 
Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | 33 | #include 34 | 35 | #include "internal/exceptions.h" 36 | #include "internal/model_state.h" 37 | 38 | namespace torchfort { 39 | 40 | void ModelState::save(const std::string& fname) { 41 | torch::serialize::OutputArchive archive; 42 | archive.write("step_train", torch::IValue(step_train)); 43 | archive.write("step_inference", torch::IValue(step_inference)); 44 | archive.write("device", torch::IValue(device)); 45 | archive.save_to(fname); 46 | } 47 | 48 | void ModelState::load(const std::string& fname) { 49 | if (!std::filesystem::exists(fname)) { 50 | THROW_INVALID_USAGE(fname + " does not exist."); 51 | } 52 | 53 | torch::serialize::InputArchive archive; 54 | archive.load_from(fname); 55 | 56 | torch::IValue ivalue; 57 | if (!archive.try_read("step_train", ivalue)) { 58 | THROW_INVALID_USAGE(fname + " is missing required data."); 59 | } 60 | step_train = ivalue.to(); 61 | 62 | if (!archive.try_read("step_inference", ivalue)) { 63 | THROW_INVALID_USAGE(fname + " is missing required data."); 64 | } 65 | step_inference = ivalue.to(); 66 | 67 | if (!archive.try_read("device", ivalue)) { 68 | THROW_INVALID_USAGE(fname + " is missing required data."); 69 | } 70 | device = ivalue.to(); 71 | } 72 | 73 | } // namespace torchfort 74 | -------------------------------------------------------------------------------- /src/csrc/models/mlp_model.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. 
Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include <vector> 32 | 33 | #include <torch/torch.h> 34 | 35 | #include "internal/models.h" 36 | #include "internal/param_map.h" 37 | #include "internal/setup.h" 38 | 39 | namespace torchfort { 40 | 41 | // MLP model in C++ using libtorch 42 | void MLPModel::setup(const ParamMap& params) { 43 | // Extract params from input map. 44 | std::set<std::string> supported_params{"dropout", "layer_sizes"}; 45 | check_params(supported_params, params.keys()); 46 | 47 | dropout = params.get_param<double>("dropout", 0.0)[0]; 48 | layer_sizes = params.get_param<int>("layer_sizes"); 49 | 50 | // Construct and register submodules. 51 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 52 | fc_layers.push_back( 53 | register_module("fc" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 54 | if (i < layer_sizes.size() - 2) { 55 | biases.push_back(register_parameter("b" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 56 | } 57 | } 58 | } 59 | 60 | // Implement the forward function. 61 | std::vector<torch::Tensor> MLPModel::forward(const std::vector<torch::Tensor>& inputs) { 62 | // concatenate inputs 63 | auto x = torch::cat(inputs, 1); 64 | x = x.reshape({x.size(0), -1}); 65 | 66 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 67 | if (i < layer_sizes.size() - 2) { 68 | x = torch::relu(fc_layers[i]->forward(x) + biases[i]); 69 | x = torch::dropout(x, dropout, is_training()); 70 | } else { 71 | x = fc_layers[i]->forward(x); 72 | } 73 | } 74 | return std::vector<torch::Tensor>{x}; 75 | } 76 | 77 | } // namespace torchfort 78 | -------------------------------------------------------------------------------- /src/csrc/models/sac_model.cpp: -------------------------------------------------------------------------------- 1 | #include <vector> 2 | 3 | #include <torch/torch.h> 4 | 5 | #include "internal/models.h" 6 | #include "internal/param_map.h" 7 | #include "internal/setup.h" 8 | 9 | namespace torchfort { 10 | 11 | // SACMLP model in C++ using libtorch 12 | void SACMLPModel::setup(const ParamMap& params) { 13 | // Extract params from input map. 14 | std::set<std::string> supported_params{"dropout", "layer_sizes", "state_dependent_sigma", "log_sigma_init"}; 15 | check_params(supported_params, params.keys()); 16 | 17 | dropout = params.get_param<double>("dropout", 0.0)[0]; 18 | layer_sizes = params.get_param<int>("layer_sizes"); 19 | state_dependent_sigma = params.get_param<bool>("state_dependent_sigma", true)[0]; 20 | double log_sigma_init = params.get_param<double>("log_sigma_init", 0.)[0]; 21 | 22 | // Construct and register submodules.
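// The loop below builds an MLP encoder from every layer-size transition except the last. The final
// transition creates the output heads: a linear layer producing mu and, if state_dependent_sigma is true,
// a second linear layer producing a state-dependent log_sigma; otherwise log_sigma is a single learnable
// bias vector initialized to log_sigma_init. forward() then returns {mu, log_sigma} (named y and z below).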
23 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 24 | if (i < layer_sizes.size() - 2) { 25 | encoder_layers.push_back( 26 | register_module("encoder_fc_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 27 | biases.push_back(register_parameter("encoder_b_" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 28 | } else { 29 | // first output: mu 30 | out_layers.push_back( 31 | register_module("out_fc_1_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 32 | out_biases.push_back(register_parameter("out_b_1_" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 33 | // second output: log_sigma 34 | if (state_dependent_sigma) { 35 | out_layers.push_back( 36 | register_module("out_fc_2_" + std::to_string(i), torch::nn::Linear(layer_sizes[i], layer_sizes[i + 1]))); 37 | out_biases.push_back(register_parameter("out_b_2_" + std::to_string(i), torch::zeros(layer_sizes[i + 1]))); 38 | } else { 39 | out_biases.push_back( 40 | register_parameter("out_b_2_" + std::to_string(i), torch::ones(layer_sizes[i + 1]) * log_sigma_init)); 41 | } 42 | } 43 | } 44 | } 45 | 46 | // Implement the forward function. 47 | std::vector SACMLPModel::forward(const std::vector& inputs) { 48 | // concatenate inputs 49 | auto x = torch::cat(inputs, 1); 50 | x = x.reshape({x.size(0), -1}); 51 | torch::Tensor y, z; 52 | 53 | for (int i = 0; i < layer_sizes.size() - 1; ++i) { 54 | if (i < layer_sizes.size() - 2) { 55 | // encoder part 56 | x = torch::relu(encoder_layers[i]->forward(x) + biases[i]); 57 | x = torch::dropout(x, dropout, is_training()); 58 | } else { 59 | // y 60 | y = out_layers[0]->forward(x) + out_biases[0]; 61 | // z 62 | if (state_dependent_sigma) { 63 | z = out_layers[1]->forward(x) + out_biases[1]; 64 | } else { 65 | z = out_biases[1]; 66 | } 67 | } 68 | } 69 | return std::vector{y, z}; 70 | } 71 | 72 | } // namespace torchfort 73 | -------------------------------------------------------------------------------- /src/csrc/param_map.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | 34 | #include "internal/param_map.h" 35 | 36 | namespace torchfort { 37 | 38 | std::set ParamMap::keys() const { 39 | std::set keys; 40 | for (const auto& entry : params) { 41 | keys.insert(entry.first); 42 | } 43 | return keys; 44 | } 45 | 46 | } // namespace torchfort 47 | -------------------------------------------------------------------------------- /src/csrc/utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | */ 30 | 31 | #include <algorithm> 32 | #include <cctype> 33 | #include <regex> 34 | #include <string> 35 | 36 | #include <torch/torch.h> 37 | 38 | #include "internal/defines.h" 39 | #include "internal/model_pack.h" 40 | #include "internal/utils.h" 41 | 42 | namespace torchfort { 43 | 44 | // Declaration of external global variables 45 | extern std::unordered_map<std::string, ModelPack> models; 46 | 47 | std::string sanitize(std::string s) { 48 | s.erase(std::remove(s.begin(), s.end(), ' '), s.end()); 49 | std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); 50 | return s; 51 | } 52 | 53 | std::string filename_sanitize(std::string s) { 54 | // remove trailing whitespace 55 | s.erase(std::remove(s.begin(), s.end(), ' '), s.end()); 56 | 57 | // replace intermediate whitespace 58 | s = std::regex_replace(s, std::regex(" "), "_"); 59 | 60 | // replace all / with -: 61 | s = std::regex_replace(s, std::regex("/"), "-"); 62 | 63 | return s; 64 | } 65 | 66 | torch::Device get_device(int device) { 67 | torch::Device device_torch(torch::kCPU); 68 | if (device != TORCHFORT_DEVICE_CPU) { 69 | #ifdef ENABLE_GPU 70 | device_torch = torch::Device(torch::kCUDA, device); 71 | #else 72 | THROW_NOT_SUPPORTED( 73 | "Attempted to place a model or other component on GPU but TorchFort was built without GPU support."); 74 | #endif 75 | } 76 | return device_torch; 77 | } 78 | 79 | torch::Device get_device(const void* ptr) { 80 | torch::Device device = torch::Device(torch::kCPU); 81 | #ifdef ENABLE_GPU 82 | cudaPointerAttributes attr; 83 | CHECK_CUDA(cudaPointerGetAttributes(&attr, ptr)); 84 | switch (attr.type) { 85 | case cudaMemoryTypeHost: 86 | case cudaMemoryTypeUnregistered: 87 | device = torch::Device(torch::kCPU); 88 | break; 89 | case cudaMemoryTypeManaged: 90 | case cudaMemoryTypeDevice: 91 | device = torch::Device(torch::kCUDA); 92 | break; 93 | } 94 | #endif 95 | return device; 96 | } 97 | 98 | std::string print_tensor_shape(torch::Tensor tensor) { 99 | std::string shapestr = "("; 100 | for (int i = 0; i < tensor.dim(); ++i) 101 | shapestr += std::to_string(tensor.size(i)) + ","; 102 | shapestr.pop_back(); 103 | shapestr += ")"; 104 | return shapestr; 105 | } 106 | 107 | std::vector<double> get_current_lrs(const char* name) { 108 | auto optimizer = models[name].optimizer; 109 | std::vector<double> learning_rates(optimizer->param_groups().size()); 110 | if (learning_rates.size() > 0) { 111 | for (const auto i : c10::irange(optimizer->param_groups().size())) { 112 | learning_rates[i] = optimizer->param_groups()[i].options().get_lr(); 113 | } 114 | } 115 | return learning_rates; 116 | } 117 | 118 | 119 | } // namespace torchfort 120 | -------------------------------------------------------------------------------- /tests/general/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(test_targets 4 | test_losses 5 | ) 6 | 7 | add_executable(test_losses) 8 | target_sources(test_losses 9 | PRIVATE 10 | test_losses.cpp 11 | ) 12 | 13 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 14 | 15 | foreach(tgt ${test_targets}) 16 | target_include_directories(${tgt} 17 | PRIVATE 18 | ${YAML_CPP_INCLUDE_DIR} 19 | ${MPI_CXX_INCLUDE_DIRS} 20 | ${CMAKE_BINARY_DIR}/include 21 | ${CMAKE_CURRENT_SOURCE_DIR}/../ 22 | ) 23 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 24 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 25 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 26 | target_link_libraries(${tgt} PRIVATE
MPI::MPI_CXX) 27 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main) 28 | target_compile_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 29 | target_link_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 30 | if (TORCHFORT_ENABLE_GPU) 31 | target_include_directories(${tgt} 32 | PRIVATE 33 | ${CUDAToolkit_INCLUDE_DIRS} 34 | ) 35 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 36 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU) 37 | endif() 38 | 39 | # discover tests: we have an issue with the work dir of gtest so disable that for now 40 | #gtest_discover_tests(${tgt}) 41 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) 42 | endforeach() 43 | 44 | # installation 45 | # executable 46 | install( 47 | TARGETS ${test_targets} 48 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general 49 | ) 50 | 51 | # copy files 52 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mse.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 53 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mse_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 54 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/l1.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 55 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/l1_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 56 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 57 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 58 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 59 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiout.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/configs) 60 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup_tests.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/general/scripts) 61 | -------------------------------------------------------------------------------- /tests/general/configs/l1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: L1 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /tests/general/configs/l1_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: L1 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /tests/general/configs/mse.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: MSE 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /tests/general/configs/mse_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: MSE 8 | 9 | optimizer: 10 | type: adam 11 | 
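# Note: the model.pt / model_multiarg.pt and loss*.pt files referenced by the torchscript-type entries in
# these configs are not part of the repository tree; they are presumably generated by scripts/setup_tests.py
# (see below), which scripts small Net/Loss modules with torch.jit.script and saves them before the tests run.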
-------------------------------------------------------------------------------- /tests/general/configs/torchscript.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/general/configs/torchscript_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss_multiarg.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/general/configs/torchscript_multiarg_extra.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model_multiarg.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss_multiarg_extra.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/general/configs/torchscript_multiout.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: model.pt 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: loss_multiout.pt 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/general/scripts/setup_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | def save_jit_module(module, fname): 5 | try: 6 | module.to("cuda") 7 | except: 8 | print("PyTorch does not have CUDA support. 
Saving on CPU.") 9 | module_jit = torch.jit.script(module) 10 | 11 | module_jit.save(fname) 12 | 13 | # Create simple models that just return input for testing 14 | class Net1(torch.nn.Module): 15 | def __init__(self): 16 | super(Net1, self).__init__() 17 | self.layer = torch.nn.Linear(10, 10) 18 | 19 | def forward(self, input1): 20 | x = self.layer(input1) 21 | return input1 + 0.0 * x 22 | 23 | class Net2(torch.nn.Module): 24 | def __init__(self): 25 | super(Net2, self).__init__() 26 | self.layer = torch.nn.Linear(10, 10) 27 | 28 | def forward(self, input1, input2): 29 | x = self.layer(input1) 30 | return input1 + 0.0 * x, input2 + 0.0 * x 31 | 32 | 33 | # Create loss functions with various argument combinations 34 | class Loss1(torch.nn.Module): 35 | def __init__(self): 36 | super(Loss1, self).__init__() 37 | 38 | def forward(self, prediction, label): 39 | return (torch.sum(prediction) + torch.sum(label)) / (2 * prediction.numel()) 40 | 41 | class Loss2(torch.nn.Module): 42 | def __init__(self): 43 | super(Loss2, self).__init__() 44 | 45 | def forward(self, prediction1, prediction2, label1, label2): 46 | return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2)) / (4 * prediction1.numel()) 47 | 48 | class Loss2Extra(torch.nn.Module): 49 | def __init__(self): 50 | super(Loss2Extra, self).__init__() 51 | 52 | def forward(self, prediction1, prediction2, label1, label2, extra_args1, extra_args2): 53 | return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2) + 54 | torch.sum(extra_args1) + torch.sum(extra_args2)) / (6 * prediction1.numel()) 55 | 56 | class Loss3(torch.nn.Module): 57 | def __init__(self): 58 | super(Loss3, self).__init__() 59 | 60 | def forward(self, prediction, label): 61 | return torch.sum(prediction), torch.sum(label) 62 | 63 | def main(): 64 | model1 = Net1() 65 | model2 = Net2() 66 | loss1 = Loss1() 67 | loss2 = Loss2() 68 | loss2_extra = Loss2Extra() 69 | loss3 = Loss3() 70 | 71 | save_jit_module(model1, "model.pt") 72 | save_jit_module(model2, "model_multiarg.pt") 73 | save_jit_module(loss1, "loss.pt") 74 | save_jit_module(loss2, "loss_multiarg.pt") 75 | save_jit_module(loss2_extra, "loss_multiarg_extra.pt") 76 | save_jit_module(loss3, "loss_multiout.pt") 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /tests/rl/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(test_targets 4 | test_replay_buffer 5 | test_rollout_buffer 6 | test_distributions 7 | test_off_policy 8 | test_on_policy 9 | ) 10 | 11 | add_executable(test_replay_buffer) 12 | target_sources(test_replay_buffer 13 | PRIVATE 14 | test_replay_buffer.cpp 15 | ) 16 | 17 | add_executable(test_rollout_buffer) 18 | target_sources(test_rollout_buffer 19 | PRIVATE 20 | test_rollout_buffer.cpp 21 | ) 22 | 23 | add_executable(test_distributions) 24 | target_sources(test_distributions 25 | PRIVATE 26 | test_distributions.cpp 27 | ) 28 | 29 | add_executable(test_off_policy) 30 | target_sources(test_off_policy 31 | PRIVATE 32 | test_off_policy.cpp 33 | ) 34 | 35 | add_executable(test_on_policy) 36 | target_sources(test_on_policy 37 | PRIVATE 38 | test_on_policy.cpp 39 | ) 40 | 41 | 42 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 43 | 44 | foreach(tgt ${test_targets}) 45 | target_include_directories(${tgt} 46 | PRIVATE 47 | 
${YAML_CPP_INCLUDE_DIR} 48 | ${MPI_CXX_INCLUDE_DIRS} 49 | ${CMAKE_BINARY_DIR}/include 50 | ) 51 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 52 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 53 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 54 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX) 55 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main) 56 | target_compile_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 57 | target_link_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 58 | if (TORCHFORT_ENABLE_GPU) 59 | target_include_directories(${tgt} 60 | PRIVATE 61 | ${CUDAToolkit_INCLUDE_DIRS} 62 | ) 63 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 64 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU) 65 | endif() 66 | 67 | # discover tests: we have an issue with the work dir of gtest so disable that for now 68 | #gtest_discover_tests(${tgt}) 69 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) 70 | endforeach() 71 | 72 | # installation 73 | # executable 74 | install( 75 | TARGETS ${test_targets} 76 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl 77 | ) 78 | 79 | # copy files 80 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/td3.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 81 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/ddpg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 82 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/sac.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 83 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/ppo.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/rl/configs) 84 | -------------------------------------------------------------------------------- /tests/rl/configs/ddpg.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: ddpg 8 | parameters: 9 | batch_size: 128 10 | nstep: 1 11 | nstep_reward_reduction: sum 12 | gamma: 0.95 13 | rho: 0.999 14 | 15 | actor: 16 | type: space_noise 17 | parameters: 18 | a_low: -1. 19 | a_high: 1. 20 | clip: 0.3 21 | sigma_train: 0.2 22 | sigma_explore: 1.0 23 | 24 | replay_buffer: 25 | type: uniform 26 | parameters: 27 | max_size: 4096 28 | min_size: 512 29 | 30 | policy_model: 31 | type: MLP 32 | parameters: 33 | dropout: 0.0 34 | layer_sizes: [1, 16, 1] 35 | 36 | critic_model: 37 | type: MLP 38 | parameters: 39 | dropout: 0.0 40 | layer_sizes: [2, 16, 1] 41 | 42 | optimizer: 43 | type: adam 44 | parameters: 45 | learning_rate: 1e-4 46 | beta1: 0.9 47 | beta2: 0.999 48 | weight_decay: 0 49 | eps: 1e-6 50 | amsgrad: 0 51 | 52 | policy_lr_scheduler: 53 | type: linear 54 | parameters: 55 | total_iters: 20000 56 | start_factor: 1.0 57 | end_factor: 0.01 58 | 59 | critic_lr_scheduler: 60 | type: linear 61 | parameters: 62 | total_iters: 20000 63 | start_factor: 1.0 64 | end_factor: 0.01 -------------------------------------------------------------------------------- /tests/rl/configs/ppo.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: ppo 8 | parameters: 9 | batch_size: 8 10 | gamma: 0.99 11 | gae_lambda: 0.95 12 | epsilon: 0.2 13 | clip_q: 0. 14 | target_kl_divergence: 0.02 15 | entropy_loss_coefficient: 0. 
16 | value_loss_coefficient: 0.5 17 | max_grad_norm: 0.5 18 | normalize_advantage: True 19 | 20 | actor: 21 | type: gaussian_ac 22 | parameters: 23 | a_low: -1.0 24 | a_high: 1.0 25 | 26 | rollout_buffer: 27 | type: gae_lambda 28 | parameters: 29 | size: 64 30 | 31 | actor_critic_model: 32 | type: ActorCriticMLP 33 | parameters: 34 | dropout: 0.0 35 | encoder_layer_sizes: [1, 16, 8] 36 | actor_layer_sizes: [8, 1] 37 | value_layer_sizes: [8, 1] 38 | state_dependent_sigma: False 39 | log_sigma_init: 0. 40 | 41 | optimizer: 42 | type: adam 43 | parameters: 44 | learning_rate: 1e-4 45 | beta1: 0.9 46 | beta2: 0.999 47 | weight_decay: 0 48 | eps: 1e-6 49 | amsgrad: 0 50 | 51 | lr_scheduler: 52 | type: linear 53 | parameters: 54 | total_iters: 40000 55 | start_factor: 1.0 56 | end_factor: 0.01 57 | -------------------------------------------------------------------------------- /tests/rl/configs/sac.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: sac 8 | parameters: 9 | batch_size: 128 10 | num_critics: 2 11 | nstep: 1 12 | nstep_reward_reduction: sum 13 | gamma: 0.95 14 | rho: 0.999 15 | alpha: 0.1 16 | 17 | actor: 18 | type: parameter_noise 19 | parameters: 20 | a_low: -1.0 21 | a_high: 1.0 22 | 23 | replay_buffer: 24 | type: uniform 25 | parameters: 26 | max_size: 4096 27 | min_size: 512 28 | 29 | policy_model: 30 | type: SACMLP 31 | parameters: 32 | dropout: 0.0 33 | layer_sizes: [1, 16, 8, 1] 34 | state_dependent_sigma: False 35 | log_sigma_init: 0. 36 | 37 | critic_model: 38 | type: MLP 39 | parameters: 40 | dropout: 0.0 41 | layer_sizes: [2, 16, 1] 42 | 43 | optimizer: 44 | type: adam 45 | parameters: 46 | learning_rate: 1e-4 47 | beta1: 0.9 48 | beta2: 0.999 49 | weight_decay: 0 50 | eps: 1e-6 51 | amsgrad: 0 52 | 53 | alpha_optimizer: 54 | type: adam 55 | parameters: 56 | learning_rate: 1e-4 57 | beta1: 0.9 58 | beta2: 0.999 59 | weight_decay: 0 60 | eps: 1e-6 61 | amsgrad: 0 62 | 63 | policy_lr_scheduler: 64 | type: linear 65 | parameters: 66 | total_iters: 20000 67 | start_factor: 1.0 68 | end_factor: 0.01 69 | 70 | critic_lr_scheduler: 71 | type: linear 72 | parameters: 73 | total_iters: 20000 74 | start_factor: 1.0 75 | end_factor: 0.01 76 | 77 | alpha_lr_scheduler: 78 | type: linear 79 | parameters: 80 | total_iters: 20000 81 | start_factor: 1.0 82 | end_factor: 0.01 83 | -------------------------------------------------------------------------------- /tests/rl/configs/td3.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | report_frequency: 1 3 | enable_wandb_hook: 0 4 | verbose: 1 5 | 6 | algorithm: 7 | type: td3 8 | parameters: 9 | batch_size: 128 10 | num_critics: 2 11 | policy_lag: 2 12 | nstep: 1 13 | nstep_reward_reduction: sum 14 | gamma: 0.95 15 | rho: 0.999 16 | 17 | actor: 18 | type: space_noise 19 | parameters: 20 | a_low: -1. 21 | a_high: 1. 
22 | clip: 0.3 23 | sigma_train: 0.2 24 | sigma_explore: 1.0 25 | 26 | replay_buffer: 27 | type: uniform 28 | parameters: 29 | max_size: 4096 30 | min_size: 512 31 | 32 | policy_model: 33 | type: MLP 34 | parameters: 35 | dropout: 0.0 36 | layer_sizes: [1, 16, 1] 37 | 38 | critic_model: 39 | type: MLP 40 | parameters: 41 | dropout: 0.0 42 | layer_sizes: [2, 16, 1] 43 | 44 | optimizer: 45 | type: adam 46 | parameters: 47 | learning_rate: 1e-4 48 | beta1: 0.9 49 | beta2: 0.999 50 | weight_decay: 0 51 | eps: 1e-6 52 | amsgrad: 0 53 | 54 | policy_lr_scheduler: 55 | type: linear 56 | parameters: 57 | total_iters: 10000 58 | start_factor: 1.0 59 | end_factor: 0.01 60 | 61 | critic_lr_scheduler: 62 | type: linear 63 | parameters: 64 | total_iters: 20000 65 | start_factor: 1.0 66 | end_factor: 0.01 -------------------------------------------------------------------------------- /tests/rl/test_distributions.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | */ 30 | 31 | #include "internal/rl/distributions.h" 32 | #include 33 | #include 34 | 35 | using namespace torchfort; 36 | using namespace torch::indexing; 37 | 38 | TEST(NormalDistribution, RandomSampling) { 39 | // rng 40 | torch::manual_seed(666); 41 | 42 | // no grad guard 43 | torch::NoGradGuard no_grad; 44 | 45 | // create normal distribution with given shape 46 | torch::Tensor mutens = torch::empty({4, 8}, torch::kFloat32); 47 | torch::Tensor log_sigmatens = torch::empty({4, 8}, torch::kFloat32); 48 | 49 | // fill with random elements 50 | mutens.normal_(); 51 | log_sigmatens.normal_(); 52 | torch::Tensor sigmatens = torch::exp(log_sigmatens); 53 | 54 | auto ndist = rl::NormalDistribution(mutens, sigmatens); 55 | torch::Tensor sample = ndist.rsample(); 56 | 57 | // do direct sampling without reparametrization trick 58 | torch::Tensor sample_compare = at::normal(mutens, sigmatens); 59 | 60 | // expect that shapes match: I am not sure how to compare the values as well 61 | EXPECT_NO_THROW(torch::sum(sample - sample_compare).item()); 62 | } 63 | 64 | int main(int argc, char* argv[]) { 65 | ::testing::InitGoogleTest(&argc, argv); 66 | 67 | return RUN_ALL_TESTS(); 68 | } 69 | -------------------------------------------------------------------------------- /tests/supervised/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(test_targets 4 | test_checkpoint 5 | test_training 6 | ) 7 | 8 | add_executable(test_checkpoint) 9 | target_sources(test_checkpoint 10 | PRIVATE 11 | test_checkpoint.cpp 12 | ) 13 | 14 | add_executable(test_training) 15 | target_sources(test_training 16 | PRIVATE 17 | test_training.cpp 18 | ) 19 | 20 | find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED) 21 | 22 | foreach(tgt ${test_targets}) 23 | target_include_directories(${tgt} 24 | PRIVATE 25 | ${YAML_CPP_INCLUDE_DIR} 26 | ${MPI_CXX_INCLUDE_DIRS} 27 | ${CMAKE_BINARY_DIR}/include 28 | ${CMAKE_CURRENT_SOURCE_DIR}/../ 29 | ) 30 | target_link_libraries(${tgt} PRIVATE ${PROJECT_NAME}) 31 | target_link_libraries(${tgt} PRIVATE ${TORCH_LIBRARIES}) 32 | target_link_libraries(${tgt} PRIVATE ${YAML_CPP_LIBRARY}) 33 | target_link_libraries(${tgt} PRIVATE MPI::MPI_CXX) 34 | target_link_libraries(${tgt} PRIVATE GTest::gtest_main) 35 | target_compile_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 36 | target_link_options(${tgt} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 37 | if (TORCHFORT_ENABLE_GPU) 38 | target_include_directories(${tgt} 39 | PRIVATE 40 | ${CUDAToolkit_INCLUDE_DIRS} 41 | ) 42 | target_link_libraries(${tgt} PRIVATE CUDA::cudart) 43 | target_compile_definitions(${tgt} PRIVATE ENABLE_GPU) 44 | endif() 45 | 46 | # discover tests: we have an issue with the work dir of gtest so disable that for now 47 | #gtest_discover_tests(${tgt}) 48 | add_test(NAME ${tgt} COMMAND ${tgt} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) 49 | endforeach() 50 | 51 | # installation 52 | # executable 53 | install( 54 | TARGETS ${test_targets} 55 | RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised 56 | ) 57 | 58 | # copy files 59 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 60 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp2.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 61 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp2_gradacc.yaml DESTINATION 
${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 62 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_opt.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 63 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_loss.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 64 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 65 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 66 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) 67 | install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup_tests.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/scripts) 68 | -------------------------------------------------------------------------------- /tests/supervised/configs/missing_loss.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | optimizer: 8 | type: adam 9 | -------------------------------------------------------------------------------- /tests/supervised/configs/missing_opt.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | loss: 8 | type: MSE 9 | -------------------------------------------------------------------------------- /tests/supervised/configs/mlp.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | enable_wandb_hook: 0 3 | report_frequency: 100 4 | 5 | model: 6 | type: mlp 7 | parameters: 8 | dropout: 0.0 9 | layer_sizes: [32, 32, 1] 10 | 11 | loss: 12 | type: MSE 13 | 14 | optimizer: 15 | type: adam 16 | parameters: 17 | learning_rate: 1e-3 18 | beta1: 0.9 19 | beta2: 0.999 20 | weight_decay: 0 21 | eps: 1e-8 22 | amsgrad: 0 23 | 24 | lr_scheduler: 25 | type: cosine_annealing 26 | parameters: 27 | T_max: 100000 28 | -------------------------------------------------------------------------------- /tests/supervised/configs/mlp2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | loss: 8 | type: MSE 9 | 10 | optimizer: 11 | type: adam 12 | -------------------------------------------------------------------------------- /tests/supervised/configs/mlp2_gradacc.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: mlp 3 | parameters: 4 | dropout: 0.0 5 | layer_sizes: [10, 10, 10] 6 | 7 | loss: 8 | type: MSE 9 | 10 | optimizer: 11 | type: adam 12 | general: 13 | grad_accumulation_steps: 4 14 | -------------------------------------------------------------------------------- /tests/supervised/configs/torchscript.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: "model.pt" 5 | 6 | loss: 7 | type: MSE 8 | 9 | optimizer: 10 | type: adam 11 | -------------------------------------------------------------------------------- /tests/supervised/configs/torchscript_multiarg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | 
type: torchscript 3 | parameters: 4 | filename: "model_multiarg.pt" 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: "loss_multiarg.pt" 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/supervised/configs/torchscript_multiarg_extra.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: torchscript 3 | parameters: 4 | filename: "model_multiarg.pt" 5 | 6 | loss: 7 | type: torchscript 8 | parameters: 9 | filename: "loss_multiarg_extra.pt" 10 | 11 | optimizer: 12 | type: adam 13 | -------------------------------------------------------------------------------- /tests/supervised/scripts/setup_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | def save_jit_module(module, fname): 5 | try: 6 | module.to("cuda") 7 | except: 8 | print("PyTorch does not have CUDA support. Saving on CPU.") 9 | module_jit = torch.jit.script(module) 10 | 11 | module_jit.save(fname) 12 | 13 | # Create simple models that just return input for testing 14 | class Net1(torch.nn.Module): 15 | def __init__(self): 16 | super(Net1, self).__init__() 17 | self.layer = torch.nn.Linear(10, 10) 18 | 19 | def forward(self, input1): 20 | x = self.layer(input1) 21 | return input1 + 0.0 * x 22 | 23 | class Net2(torch.nn.Module): 24 | def __init__(self): 25 | super(Net2, self).__init__() 26 | self.layer = torch.nn.Linear(10, 10) 27 | 28 | def forward(self, input1, input2): 29 | x = self.layer(input1) 30 | return input1 + 0.0 * x, input2 + 0.0 * x 31 | 32 | 33 | # Create loss functions with various argument combinations 34 | class Loss1(torch.nn.Module): 35 | def __init__(self): 36 | super(Loss1, self).__init__() 37 | 38 | def forward(self, prediction, label): 39 | return (torch.sum(prediction) + torch.sum(label)) / (2 * prediction.numel()) 40 | 41 | class Loss2(torch.nn.Module): 42 | def __init__(self): 43 | super(Loss2, self).__init__() 44 | 45 | def forward(self, prediction1, prediction2, label1, label2): 46 | return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2)) / (4 * prediction1.numel()) 47 | 48 | class Loss2Extra(torch.nn.Module): 49 | def __init__(self): 50 | super(Loss2Extra, self).__init__() 51 | 52 | def forward(self, prediction1, prediction2, label1, label2, extra_args1, extra_args2): 53 | return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2) + 54 | torch.sum(extra_args1) + torch.sum(extra_args2)) / (6 * prediction1.numel()) 55 | 56 | def main(): 57 | model1 = Net1() 58 | model2 = Net2() 59 | loss1 = Loss1() 60 | loss2 = Loss2() 61 | loss2_extra = Loss2Extra() 62 | 63 | save_jit_module(model1, "model.pt") 64 | save_jit_module(model2, "model_multiarg.pt") 65 | save_jit_module(loss1, "loss.pt") 66 | save_jit_module(loss2, "loss_multiarg.pt") 67 | save_jit_module(loss2_extra, "loss_multiarg_extra.pt") 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /tests/test_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | #pragma once 31 | 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | 39 | #ifdef ENABLE_GPU 40 | #include 41 | #endif 42 | 43 | // Generate random vector data for testing 44 | template 45 | std::vector generate_random(const std::vector& shape) { 46 | 47 | int64_t num_values = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); 48 | std::vector data(num_values); 49 | 50 | std::mt19937 generator; 51 | unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); 52 | generator.seed(seed); 53 | std::uniform_real_distribution dist((T)0, (T)1); 54 | 55 | auto r = [&]() { 56 | return dist(generator); 57 | }; 58 | 59 | std::generate(data.begin(), data.end(), r); 60 | 61 | return data; 62 | 63 | } 64 | 65 | // Generate random names to use as model keys to avoid conflicts between tests 66 | std::string generate_random_name(int length) { 67 | 68 | const std::string character_set = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; 69 | std::mt19937 generator; 70 | unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); 71 | generator.seed(seed); 72 | std::uniform_int_distribution<> dist(0, character_set.size() - 1); 73 | 74 | std::string name; 75 | for (int i = 0; i < length; ++i) { 76 | name += character_set[dist(generator)]; 77 | } 78 | return name; 79 | } 80 | 81 | // Get raw data pointer from vector. If dev is GPU, this routine will allocate GPU memory and copy. 
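// Illustrative usage sketch (assumed typical flow, not taken from a specific test): the helpers in this
// header are meant to be combined roughly as follows:
//   auto data = generate_random<float>({8, 10});  // random host-side buffer
//   float* ptr = get_data_ptr(data, dev);         // host pointer, or a device copy when dev is a GPU index
//   /* ... hand ptr to the TorchFort call under test ... */
//   free_data_ptr(ptr, dev);                      // frees the device allocation; no-op for CPU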
82 | template <typename T> 83 | T* get_data_ptr(std::vector<T>& data, int dev) { 84 | T* data_ptr; 85 | #ifdef ENABLE_GPU 86 | if (dev == TORCHFORT_DEVICE_CPU) { 87 | data_ptr = data.data(); 88 | } else { 89 | CHECK_CUDA(cudaMalloc(&data_ptr, data.size() * sizeof(T(0)))); 90 | CHECK_CUDA(cudaMemcpy(data_ptr, data.data(), data.size() * sizeof(T(0)), cudaMemcpyHostToDevice)); 91 | } 92 | #else 93 | data_ptr = data.data(); 94 | #endif 95 | 96 | return data_ptr; 97 | } 98 | 99 | // Free raw data pointer. If dev is GPU, this routine will free GPU memory. 100 | template <typename T> 101 | void free_data_ptr(T* data_ptr, int dev) { 102 | #ifdef ENABLE_GPU 103 | if (dev != TORCHFORT_DEVICE_CPU) { 104 | CHECK_CUDA(cudaFree(data_ptr)); 105 | } 106 | #endif 107 | } 108 | 109 | // Routines to copy vector data to and from GPU. 110 | #ifdef ENABLE_GPU 111 | template <typename T> 112 | void copy_to_host_vector(std::vector<T>& data, T* data_ptr) { 113 | CHECK_CUDA(cudaMemcpy(data.data(), data_ptr, data.size()*sizeof(T(0)), cudaMemcpyDeviceToHost)); 114 | } 115 | template <typename T> 116 | void copy_from_host_vector(T* data_ptr, std::vector<T>& data) { 117 | CHECK_CUDA(cudaMemcpy(data_ptr, data.data(), data.size()*sizeof(T(0)), cudaMemcpyHostToDevice)); 118 | } 119 | #endif 120 | 121 | --------------------------------------------------------------------------------