├── .github └── workflows │ └── fpm.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── VERSION ├── example ├── bench1.f90 ├── demo1.f90 ├── demo2.f90 ├── demo3.f90 └── demo4.f90 ├── fpm.toml ├── src ├── ad_intrinsic.f90 ├── ad_kinds.f90 ├── ad_operator.f90 ├── ad_types.f90 ├── ad_usr_func.f90 └── auto_diff.f90 └── test ├── test_func.f90 ├── test_operator.f90 └── tester.f90 /.github/workflows/fpm.yml: -------------------------------------------------------------------------------- 1 | name: fpm 2 | 3 | on: 4 | push: 5 | paths: 6 | - "example/**.f90" 7 | - "src/**.f90" 8 | - "test/**.f90" 9 | - "fpm.toml" 10 | 11 | pull_request: 12 | branches: 13 | - main 14 | paths: 15 | - "example/**.f90" 16 | - "src/**.f90" 17 | - "test/**.f90" 18 | - "fpm.toml" 19 | 20 | jobs: 21 | build: 22 | runs-on: ${{ matrix.os }} 23 | strategy: 24 | fail-fast: false 25 | matrix: 26 | # gcc 11 in ubuntu-latest is currently unavailable 27 | os: [macos-latest] 28 | # os: [ubuntu-latest, macos-latest] 29 | gcc_v: [11] # Version of GFortran we want to use. 
30 | include: 31 | # - os: ubuntu-latest 32 | # os-arch: linux-x86_64 33 | 34 | - os: macos-latest 35 | os-arch: macos-x86_64 36 | 37 | env: 38 | FC: gfortran 39 | GCC_V: ${{ matrix.gcc_v }} 40 | 41 | steps: 42 | - name: Checkout code 43 | uses: actions/checkout@v2 44 | 45 | - name: Install GFortran macOS 46 | if: contains(matrix.os, 'macos') 47 | run: | 48 | ln -s /usr/local/bin/gfortran-${GCC_V} /usr/local/bin/gfortran 49 | which gfortran-${GCC_V} 50 | which gfortran 51 | 52 | - name: Install GFortran Linux 53 | if: contains(matrix.os, 'ubuntu') 54 | run: | 55 | sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-${GCC_V} 100 \ 56 | --slave /usr/bin/gfortran gfortran /usr/bin/gfortran-${GCC_V} \ 57 | --slave /usr/bin/gcov gcov /usr/bin/gcov-${GCC_V} 58 | 59 | - name: Install fpm 60 | uses: fortran-lang/setup-fpm@v3 61 | with: 62 | fpm-version: 'v0.4.0' 63 | 64 | - name: Build & Test 65 | run: | 66 | gfortran --version 67 | fpm build 68 | fpm test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v0.1.0 4 | 5 | - [x] Initialize the code 6 | - [x] Change the data structure, hide pointer 7 | - [x] Refactor the code 8 | - [x] Support arrays 9 | - [x] Add a benchmark 10 | - [x] Final procedure of `tree_t`: not sure whether this causes a memory leak 11 | - [x] Support `assignment(=)` and `tree_t` constructor. 12 | 13 | ## TO DO 14 | 15 | - [ ] Better support for arrays (?) 16 | - [ ] Destructor, avoid memory leaks (?)
17 | - [ ] Support `integer` 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 St-Maxwell, 左志华 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Auto-Diff 2 | 3 | `Auto-Diff` is a backward-mode (reverse-mode) automatic differentiation library written in Modern Fortran. 4 | 5 | *This project is still in the experimental stage; feedback is welcome!* 6 | 7 | *This library currently works only with the `gfortran` compiler and may leak memory, so it is still a toy project; I have not yet found a better way to address these limitations.
For a more mature reverse-mode automatic differentiation library, see [fazang](https://github.com/yizhang-yiz/fazang).* 8 | 9 | [![MIT](https://img.shields.io/github/license/zoziha/Auto-Diff?color=pink)](LICENSE) 10 | 11 | ## Getting Started 12 | 13 | ### Get the Code 14 | 15 | ```sh 16 | git clone https://github.com/zoziha/Auto-Diff.git 17 | cd Auto-Diff 18 | ``` 19 | 20 | ### Build with [Fortran-lang/fpm](https://github.com/fortran-lang/fpm) 21 | 22 | Fortran Package Manager (fpm) is a package manager and build system for Fortran. 23 | You can build `Auto-Diff` using the provided `fpm.toml`: 24 | 25 | ```sh 26 | fpm build 27 | fpm run --example --list 28 | ``` 29 | 30 | To use `Auto-Diff` within your `fpm` project, add the following lines to your `fpm.toml` file: 31 | 32 | ```toml 33 | [dependencies] 34 | Auto-Diff = { git="https://github.com/zoziha/Auto-Diff" } 35 | ``` 36 | 37 | ### Demo3 38 | 39 | ```fortran 40 | !> Staged solution, run this code: fpm run --example demo3 41 | program main 42 | 43 | use auto_diff 44 | implicit none 45 | type(tree_t) :: x1, x2 46 | type(tree_t) :: y 47 | 48 | x1 = 3.0_rk 49 | x2 = -4.0_rk 50 | 51 | print *, "staged demo: y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2)" 52 | y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2) 53 | 54 | print *, "y = ", y%get_value() 55 | call y%backward() 56 | 57 | print *, "dy/dx1 = ", x1%get_grad() 58 | print *, "dy/dx2 = ", x2%get_grad() 59 | 60 | end program main 61 | 62 | !> staged demo: y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2) 63 | !> y = 1.5456448841066441 64 | !> dy/dx1 = -1.1068039935182090 65 | !> dy/dx2 = -1.5741410376065648 66 | ``` 67 | 68 | ### Bench1 69 | 70 | ```sh 71 | $ fpm run --example bench1 --profile debug 72 | Forward 73 | Elapsed time (seconds): 1.7968750000000000 74 | Ordinary arithmetic 75 | Elapsed time (seconds): 0.14062500000000000 76 | Backward 77 | Elapsed time (seconds): 0.29687500000000000 78 | 79 | $ fpm run --example bench1 --profile release 80 | Forward 81 |
Elapsed time (seconds): 1.6718750000000000 82 | Ordinary arithmetic 83 | Elapsed time (seconds): 0.0000000000000000 84 | Backward 85 | Elapsed time (seconds): 0.20312500000000000 86 | ``` 87 | 88 | The `Bench1` code for arrays (1000*1000) is [here](./example/bench1.f90). 89 | 90 | ## Links 91 | 92 | - [St-Maxwell/backward(F90)](https://gist.github.com/St-Maxwell/0a936b03ecf99e284a05d10dd994516e) 93 | - [KT19/automatic_differentiation](https://github.com/KT19/automatic_differentiation) 94 | - [李理的博客/自动微分](http://fancyerii.github.io/books/autodiff/) 95 | - [Fortran-lang Discourse: Backward Mode Auto-Diff in Modern Fortran](https://fortran-lang.discourse.group/t/backward-mode-auto-diff-in-modern-fortran/2334) 96 | - [joddlehod/DNAD](https://github.com/joddlehod/dnad) 97 | - [SCM-NV/ftl](https://github.com/SCM-NV/ftl/blob/master/src/ftlList.F90_template) 98 | - [reverse-mode-automatic-differentiation](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation) 99 | - [fazang](https://github.com/yizhang-yiz/fazang) 100 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.0.4 -------------------------------------------------------------------------------- /example/bench1.f90: -------------------------------------------------------------------------------- 1 | program main 2 | 3 | use auto_diff 4 | implicit none 5 | 6 | integer, parameter :: n = 1000 7 | type(tree_t) :: x1(n, n), x2(n, n) 8 | type(tree_t) :: y(n, n) 9 | real(rk) :: r1(n, n), r2(n, n) 10 | real(rk) :: z(n, n) 11 | 12 | real(rk) :: t1, t2 13 | 14 | x1 = randu(n, n) 15 | x2 = randu(n, n) 16 | 17 | r1 = randu(n, n) 18 | r2 = randu(n, n) 19 | ! Forward pass: the tree_t operators are elemental, so one expression graph is built per array element. 20 | print *, "Forward" 21 | call cpu_time(t1) 22 | y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2) 23 | call cpu_time(t2) 24 | print *, "Elapsed time (seconds):", t2 - t1 25 | ! NOTE(review): z is never read after it is computed, so an optimizing (release) build may drop this statement; that would explain the 0.0 s release timing recorded at the end of this file. 26 | print *, "Ordinary arithmetic" 27 | call
cpu_time(t1) 28 | z = (r1 + sigmoid_local(r2))/(sigmoid_local(r1) + (r1 + r2)**2) 29 | call cpu_time(t2) 30 | print *, "Elapsed time (seconds):", t2 - t1 31 | 32 | print *, "Backward" 33 | call cpu_time(t1) 34 | call y%backward() 35 | call cpu_time(t2) 36 | print *, "Elapsed time (seconds):", t2 - t1 37 | 38 | contains 39 | !> Matrix of pseudo-random numbers from random_number (uniform in [0, 1)). 40 | function randu(m, n) result(out) 41 | integer, intent(in) :: m, n 42 | real(rk) :: out(m, n) 43 | 44 | call random_number(out) 45 | 46 | end function randu 47 | !> Plain-real sigmoid, used as the non-AD baseline for timing. 48 | elemental function sigmoid_local(x) result(y) 49 | real(rk), intent(in) :: x 50 | real(rk) :: y 51 | y = 1.0_rk/(1.0_rk + exp(-x)) 52 | end function sigmoid_local 53 | 54 | end program main 55 | 56 | ! fpm run --example bench1 --profile debug: 57 | ! Forward 58 | ! Elapsed time (seconds): 1.9218750000000000 59 | ! Ordinary arithmetic 60 | ! Elapsed time (seconds): 0.15625000000000000 61 | ! Backward 62 | ! Elapsed time (seconds): 0.29687500000000000 63 | 64 | ! fpm run --example bench1 --profile release: 65 | ! Forward 66 | ! Elapsed time (seconds): 1.6093750000000000 67 | ! Ordinary arithmetic 68 | ! Elapsed time (seconds): 0.0000000000000000 69 | ! Backward 70 | ! Elapsed time (seconds): 0.21875000000000000 71 | -------------------------------------------------------------------------------- /example/demo1.f90: -------------------------------------------------------------------------------- 1 | !> Backward auto diff 2 | program main 3 | 4 | use auto_diff, only: tree_t, operator(*), operator(+), exp, rk, assignment(=) 5 | implicit none 6 | type(tree_t) :: a, b, c 7 | type(tree_t) :: y 8 | 9 | a = 2.0_rk 10 | b = 1.0_rk 11 | c = 0.0_rk 12 | 13 | print *, "demo1: y = (a + b*b)*exp(c)" 14 | 15 | y = (a + b*b)*exp(c) 16 | 17 | print *, "y = ", y%get_value() 18 | call y%backward() 19 | 20 | print *, "dy/da = ", a%get_grad() 21 | print *, "dy/db = ", b%get_grad() 22 | print *, "dy/dc = ", c%get_grad() 23 | 24 | !
- - - 25 | ! Re-assigning a and b allocates fresh leaf nodes, so their gradients below start again from zero. 26 | a = 2.0_rk 27 | b = 1.0_rk 28 | 29 | print *, "" 30 | print *, "demo2: y = (a + b)*(b + 1.0)" 31 | 32 | y = (a + b)*(b + 1.0_rk) 33 | 34 | print *, "y = ", y%get_value() 35 | 36 | call y%backward() 37 | 38 | print *, "dy/da = ", a%get_grad() 39 | print *, "dy/db = ", b%get_grad() 40 | 41 | end program main 42 | 43 | !> demo1: y = (a + b*b)*exp(c) 44 | !> y = 3.0000000000000000 45 | !> dy/da = 1.0000000000000000 46 | !> dy/db = 2.0000000000000000 47 | !> dy/dc = 3.0000000000000000 48 | !> 49 | !> demo2: y = (a + b)*(b + 1.0) 50 | !> y = 6.0000000000000000 51 | !> dy/da = 2.0000000000000000 52 | !> dy/db = 5.0000000000000000 -------------------------------------------------------------------------------- /example/demo2.f90: -------------------------------------------------------------------------------- 1 | !> Sigmoid func & gate 2 | program main 3 | 4 | use auto_diff, only: sigmoid 5 | use auto_diff, only: tree_t, rk, assignment(=) 6 | use auto_diff, only: operator(*), operator(+) 7 | implicit none 8 | type(tree_t) :: w0, w1, w2, x0, x1 9 | type(tree_t) :: y 10 | 11 | w0 = 2.0_rk 12 | w1 = -3.0_rk 13 | w2 = -3.0_rk 14 | x0 = -1.0_rk 15 | x1 = -2.0_rk 16 | 17 | print *, "sigmoid demo: y = 1/(1 + exp(-z)), z = w0*x0 + w1*x1 + w2" 18 | y = sigmoid(w0*x0 + w1*x1 + w2) 19 | 20 | print *, "y = ", y%get_value() ! should be 0.73 21 | call y%backward() 22 | 23 | print *, "dy/dw0 = ", w0%get_grad() ! should be -0.20 24 | print *, "dy/dw1 = ", w1%get_grad() ! should be -0.39 25 | print *, "dy/dw2 = ", w2%get_grad() ! should be 0.20 26 | print *, "dy/dx0 = ", x0%get_grad() ! should be 0.39 27 | print *, "dy/dx1 = ", x1%get_grad() !
should be -0.59 28 | 29 | end program main 30 | 31 | !> sigmoid demo: y = 1/(1 + exp(-z)), z = w0*x0 + w1*x1 + w2 32 | !> y = 0.73105857863000490 33 | !> dy/dw0 = -0.19661193324148185 34 | !> dy/dw1 = -0.39322386648296370 35 | !> dy/dw2 = 0.19661193324148185 36 | !> dy/dx0 = 0.39322386648296370 37 | !> dy/dx1 = -0.58983579972444555 38 | -------------------------------------------------------------------------------- /example/demo3.f90: -------------------------------------------------------------------------------- 1 | !> Staged solution 2 | program main 3 | 4 | use auto_diff, only: sigmoid 5 | use auto_diff, only: tree_t, rk, assignment(=) 6 | use auto_diff, only: operator(*), operator(+), operator(/), operator(**) 7 | implicit none 8 | type(tree_t) :: x1, x2 9 | type(tree_t) :: y 10 | 11 | x1 = 3.0_rk 12 | x2 = -4.0_rk 13 | 14 | print *, "staged demo: y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2)" 15 | y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2) 16 | 17 | print *, "y = ", y%get_value() 18 | call y%backward() 19 | 20 | print *, "dy/dx1 = ", x1%get_grad() 21 | print *, "dy/dx2 = ", x2%get_grad() 22 | 23 | end program main 24 | ! NOTE(review): the dy/dx values below were produced with d(u**2)/du evaluated as 2*(u**2) = 2 instead of 2*u = -2 at u = x1 + x2 = -1 (a pow_ti using t**(i-1) rather than t1**(i-1)); the analytic gradients are about dy/dx1 = 2.0596 and dy/dx2 = 1.5922, so regenerate these lines. 25 | !> staged demo: y = (x1 + sigmoid(x2))/(sigmoid(x1) + (x1 + x2)**2) 26 | !> y = 1.5456448841066441 27 | !> dy/dx1 = -1.1068039935182090 28 | !> dy/dx2 = -1.5741410376065648 -------------------------------------------------------------------------------- /example/demo4.f90: -------------------------------------------------------------------------------- 1 | !> Matrix 2 | program main 3 | 4 | use auto_diff 5 | implicit none 6 | type(tree_t) :: a(2, 2), b(2, 2) 7 | type(tree_t) :: y(2, 2) 8 | ! Assigning a scalar to a tree_t array is elemental: every element gets its own leaf node. 9 | a = 1.0_rk 10 | b = 2.0_rk 11 | 12 | print *, "demo4: y = a + a * b" 13 | 14 | y = a + a*b 15 | 16 | print *, "y = ", y%get_value() 17 | call y%backward() 18 | 19 | print *, "dy/da = ", a%get_grad() 20 | print *, "dy/db = ", b%get_grad() 21 | 22 | end program main 23 | 24 | ! demo4: y = a + a * b 25 | !
y = 3.0000000000000000 3.0000000000000000 3.0000000000000000 3.0000000000000000 26 | ! dy/da = 3.0000000000000000 3.0000000000000000 3.0000000000000000 3.0000000000000000 27 | ! dy/db = 1.0000000000000000 1.0000000000000000 1.0000000000000000 1.0000000000000000 -------------------------------------------------------------------------------- /fpm.toml: -------------------------------------------------------------------------------- 1 | name = 'Auto-Diff' 2 | version = 'VERSION' 3 | license = 'MIT' 4 | maintainer = '左志华' 5 | copyright = 'Copyright 2021 St-Maxwell, 左志华' 6 | description = 'Auto-Diff' 7 | categories = ['Auto-Diff', 'Backward Mode'] 8 | 9 | [dev-dependencies] 10 | test-drive = { git="https://github.com/fortran-lang/test-drive", tag="v0.4.0" } 11 | -------------------------------------------------------------------------------- /src/ad_intrinsic.f90: -------------------------------------------------------------------------------- 1 | module ad_intrinsic 2 | 3 | use ad_kinds, only: rk 4 | use ad_types, only: tree_t, assignment(=) 5 | implicit none 6 | private 7 | 8 | public :: abs, max, min 9 | public :: sin, cos, tan 10 | public :: exp, sqrt, log, log10 11 | 12 | interface abs 13 | module procedure :: abs_t 14 | end interface abs 15 | 16 | interface max 17 | module procedure :: max_tt 18 | module procedure :: max_tr 19 | ! Ambiguous interfaces in generic interface 'max' for 'max_tr' at (1) and 'max_rt' here 20 | ! module procedure :: max_rt 21 | end interface max 22 | ! min_rt is left out for the same reason: its (r, t1) dummies are ambiguous with min_tr's (t1, r) when called by keyword. 23 | interface min 24 | module procedure :: min_tt 25 | module procedure :: min_tr 26 | !
module procedure :: min_rt 27 | end interface min 28 | 29 | interface sin 30 | module procedure :: sin_t 31 | end interface sin 32 | 33 | interface cos 34 | module procedure :: cos_t 35 | end interface cos 36 | 37 | interface tan 38 | module procedure :: tan_t 39 | end interface tan 40 | 41 | interface exp 42 | module procedure :: exp_t 43 | end interface exp 44 | 45 | interface sqrt 46 | module procedure :: sqrt_t 47 | end interface sqrt 48 | 49 | interface log 50 | module procedure :: log_t 51 | end interface log 52 | 53 | interface log10 54 | module procedure :: log10_t 55 | end interface log10 56 | 57 | contains 58 | !> abs(t1): d/dt1 = -1 for t1 < 0, otherwise +1 (t1 = 0 included). 59 | impure elemental function abs_t(t1) result(t) 60 | type(tree_t), intent(in) :: t1 61 | type(tree_t) :: t 62 | 63 | t = abs(t1%node%value) 64 | 65 | t%node%left => t1%node 66 | t%node%left_grad = merge(-1.0_rk, 1.0_rk, t1%node%value < 0.0_rk) 67 | 68 | end function abs_t 69 | !> max(t1, t2): the gradient flows only to the selected argument (t1 on ties). 70 | impure elemental function max_tt(t1, t2) result(t) 71 | type(tree_t), intent(in) :: t1, t2 72 | type(tree_t) :: t 73 | 74 | t = max(t1%node%value, t2%node%value) 75 | t%node%left => t1%node; t%node%right => t2%node 76 | t%node%left_grad = merge(1.0_rk, 0.0_rk, t1%node%value >= t2%node%value) 77 | t%node%right_grad = 1.0_rk - t%node%left_grad 78 | 79 | end function max_tt 80 | 81 | ! impure elemental function max_rt(r, t1) result(t) 82 | ! real(rk), intent(in) :: r 83 | ! type(tree_t), intent(in) :: t1 84 | ! type(tree_t) :: t 85 | 86 | ! allocate (t%node) 87 | ! t%node%value = max(r, t1%node%value) 88 | 89 | ! t%node%left => t1%node 90 | ! t%node%left_grad = 1.0_rk 91 | 92 | !
end function max_rt 93 | !> max(t1, r): d/dt1 is 1 when t1 is selected (t1 >= r), otherwise 0. 94 | impure elemental function max_tr(t1, r) result(t) 95 | type(tree_t), intent(in) :: t1 96 | real(rk), intent(in) :: r 97 | type(tree_t) :: t 98 | 99 | t = max(t1%node%value, r) 100 | 101 | t%node%left => t1%node 102 | t%node%left_grad = merge(1.0_rk, 0.0_rk, t1%node%value >= r) 103 | 104 | end function max_tr 105 | !> min(t1, t2): the gradient flows only to the selected argument (t1 on ties). 106 | impure elemental function min_tt(t1, t2) result(t) 107 | type(tree_t), intent(in) :: t1, t2 108 | type(tree_t) :: t 109 | 110 | t = min(t1%node%value, t2%node%value) 111 | t%node%left => t1%node; t%node%right => t2%node 112 | t%node%left_grad = merge(1.0_rk, 0.0_rk, t1%node%value <= t2%node%value) 113 | t%node%right_grad = 1.0_rk - t%node%left_grad 114 | 115 | end function min_tt 116 | 117 | ! impure elemental function min_rt(r, t1) result(t) 118 | ! real(rk), intent(in) :: r 119 | ! type(tree_t), intent(in) :: t1 120 | ! type(tree_t) :: t 121 | 122 | ! allocate (t%node) 123 | ! t%node%value = min(r, t1%node%value) 124 | 125 | ! t%node%left => t1%node 126 | ! t%node%left_grad = 1.0_rk 127 | 128 | ! end function min_rt 129 | !> min(t1, r): d/dt1 is 1 when t1 is selected (t1 <= r), otherwise 0. 130 | impure elemental function min_tr(t1, r) result(t) 131 | type(tree_t), intent(in) :: t1 132 | real(rk), intent(in) :: r 133 | type(tree_t) :: t 134 | 135 | t = min(t1%node%value, r) 136 | 137 | t%node%left => t1%node 138 | t%node%left_grad = merge(1.0_rk, 0.0_rk, t1%node%value <= r) 139 | 140 | end function min_tr 141 | !> sin(t1): d/dt1 = cos(t1). 142 | impure elemental function sin_t(t1) result(t) 143 | type(tree_t), intent(in) :: t1 144 | type(tree_t) :: t 145 | 146 | t = sin(t1%node%value) 147 | 148 | t%node%left => t1%node 149 | t%node%left_grad = cos(t1%node%value) 150 | 151 | end function sin_t 152 | !> cos(t1): d/dt1 = -sin(t1). 153 | impure elemental function cos_t(t1) result(t) 154 | type(tree_t), intent(in) :: t1 155 | type(tree_t) :: t 156 | 157 | t = cos(t1%node%value) 158 | 159 | t%node%left => t1%node 160 | t%node%left_grad = -sin(t1%node%value) 161 | 162 | end function cos_t 163 | !> tan(t1): d/dt1 = 1/cos(t1)**2. 164 | impure elemental function tan_t(t1) result(t) 165 | type(tree_t), intent(in) :: t1 166 | type(tree_t) :: t 167 | 168 | t = tan(t1%node%value) 169 | 170 | t%node%left => t1%node 171 | t%node%left_grad =
1.0_rk/(cos(t1%node%value)*cos(t1%node%value)) 172 | 173 | end function tan_t 174 | !> exp(t1): d/dt1 = exp(t1), reusing the forward value. 175 | impure elemental function exp_t(t1) result(t) 176 | type(tree_t), intent(in) :: t1 177 | type(tree_t) :: t 178 | 179 | t = exp(t1%node%value) 180 | 181 | t%node%left => t1%node 182 | t%node%left_grad = t%node%value 183 | 184 | end function exp_t 185 | !> sqrt(t1): d/dt1 = 0.5/sqrt(t1); a quiet NaN is stored as the partial when t1 < 0. 186 | impure elemental function sqrt_t(t1) result(t) 187 | use, intrinsic :: ieee_arithmetic, only: ieee_value, NAN => ieee_quiet_nan 188 | type(tree_t), intent(in) :: t1 189 | type(tree_t) :: t 190 | 191 | t = sqrt(t1%node%value) 192 | 193 | t%node%left => t1%node 194 | t%node%left_grad = merge(0.5_rk/t%node%value, ieee_value(1.0_rk, NAN), t1%node%value >= 0.0_rk) !TODO: NAN 195 | 196 | end function sqrt_t 197 | !> log(t1): d/dt1 = 1/t1. 198 | impure elemental function log_t(t1) result(t) 199 | type(tree_t), intent(in) :: t1 200 | type(tree_t) :: t 201 | 202 | t = log(t1%node%value) 203 | 204 | t%node%left => t1%node 205 | t%node%left_grad = 1.0_rk/t1%node%value 206 | 207 | end function log_t 208 | !> log10(t1): d/dt1 = 1/(t1*ln(10)). 209 | impure elemental function log10_t(t1) result(t) 210 | type(tree_t), intent(in) :: t1 211 | type(tree_t) :: t 212 | 213 | t = log10(t1%node%value) 214 | 215 | t%node%left => t1%node 216 | t%node%left_grad = 1.0_rk/t1%node%value/log(10.0_rk) 217 | 218 | end function log10_t 219 | 220 | end module ad_intrinsic 221 | -------------------------------------------------------------------------------- /src/ad_kinds.f90: -------------------------------------------------------------------------------- 1 | module ad_kinds 2 | 3 | use, intrinsic :: iso_fortran_env, only: real64, real32 4 | implicit none 5 | private 6 | 7 | public :: rk 8 | !> Working real kind for all values and gradients: double precision, as every documented example output assumes. 9 | integer, parameter :: rk = real64 10 | 11 | end module ad_kinds 12 | -------------------------------------------------------------------------------- /src/ad_operator.f90: -------------------------------------------------------------------------------- 1 | module ad_operator 2 | 3 | use ad_kinds, only: rk 4 | use ad_types,
only: tree_t, assignment(=) 5 | implicit none 6 | private 7 | 8 | public :: operator(+), operator(-), operator(*), operator(/), operator(**) 9 | ! Each operator returns a new node holding the forward value; left_grad/right_grad store the local partials that backward() multiplies by the incoming gradient. 10 | interface operator(+) 11 | module procedure :: add_tt 12 | module procedure :: add_tr 13 | module procedure :: add_rt 14 | end interface operator(+) 15 | 16 | interface operator(-) 17 | module procedure :: sub_tt 18 | module procedure :: sub_tr 19 | module procedure :: sub_rt 20 | end interface operator(-) 21 | 22 | interface operator(*) 23 | module procedure :: mul_tt 24 | module procedure :: mul_tr 25 | module procedure :: mul_rt 26 | end interface operator(*) 27 | 28 | interface operator(/) 29 | module procedure :: div_tt 30 | module procedure :: div_tr 31 | module procedure :: div_rt 32 | end interface operator(/) 33 | 34 | interface operator(**) 35 | module procedure :: pow_tt 36 | module procedure :: pow_tr 37 | module procedure :: pow_rt 38 | module procedure :: pow_ti 39 | end interface operator(**) 40 | 41 | contains 42 | !> t1 + t2: both local partials are 1. 43 | impure elemental function add_tt(t1, t2) result(t) 44 | type(tree_t), intent(in) :: t1, t2 45 | type(tree_t) :: t 46 | 47 | t = t1%node%value + t2%node%value 48 | 49 | t%node%left => t1%node 50 | t%node%right => t2%node 51 | 52 | t%node%left_grad = 1.0_rk 53 | t%node%right_grad = 1.0_rk 54 | 55 | end function add_tt 56 | !> t1 + r: delegates to add_rt (addition commutes). 57 | impure elemental function add_tr(t1, r) result(t) 58 | type(tree_t), intent(in) :: t1 59 | real(rk), intent(in) :: r 60 | type(tree_t) :: t 61 | 62 | t = add_rt(r, t1) 63 | 64 | end function add_tr 65 | !> r + t1: only t1 is a graph operand; d/dt1 = 1. 66 | impure elemental function add_rt(r, t1) result(t) 67 | real(rk), intent(in) :: r 68 | type(tree_t), intent(in) :: t1 69 | type(tree_t) :: t 70 | 71 | t = t1%node%value + r 72 | 73 | t%node%left => t1%node 74 | t%node%left_grad = 1.0_rk 75 | 76 | end function add_rt 77 | !> t1 - t2, built as t1 + (-1)*t2 (adds one intermediate node). 78 | impure elemental function sub_tt(t1, t2) result(t) 79 | type(tree_t), intent(in) :: t1, t2 80 | type(tree_t) :: t 81 | 82 | t = add_tt(t1, mul_rt(-1.0_rk, t2)) 83 | 84 | end function sub_tt
85 | !> t1 - r, built as t1 + (-r); d/dt1 = 1. 86 | impure elemental function sub_tr(t1, r) result(t) 87 | type(tree_t), intent(in) :: t1 88 | real(rk), intent(in) :: r 89 | type(tree_t) :: t 90 | 91 | t = add_tr(t1, -r) 92 | 93 | end function sub_tr 94 | !> r - t1, built as r + (-1)*t1. 95 | impure elemental function sub_rt(r, t1) result(t) 96 | real(rk), intent(in) :: r 97 | type(tree_t), intent(in) :: t1 98 | type(tree_t) :: t 99 | 100 | t = add_rt(r, mul_rt(-1.0_rk, t1)) 101 | 102 | end function sub_rt 103 | !> t1*t2: product rule, d/dt1 = t2 and d/dt2 = t1. 104 | impure elemental function mul_tt(t1, t2) result(t) 105 | type(tree_t), intent(in) :: t1, t2 106 | type(tree_t) :: t 107 | 108 | t = t1%node%value*t2%node%value 109 | 110 | t%node%left => t1%node 111 | t%node%right => t2%node 112 | 113 | t%node%left_grad = t2%node%value 114 | t%node%right_grad = t1%node%value 115 | 116 | end function mul_tt 117 | !> t1*r: delegates to mul_rt. 118 | impure elemental function mul_tr(t1, r) result(t) 119 | type(tree_t), intent(in) :: t1 120 | real(rk), intent(in) :: r 121 | type(tree_t) :: t 122 | 123 | t = mul_rt(r, t1) 124 | 125 | end function mul_tr 126 | !> r*t1: d/dt1 = r. 127 | impure elemental function mul_rt(r, t1) result(t) 128 | real(rk), intent(in) :: r 129 | type(tree_t), intent(in) :: t1 130 | type(tree_t) :: t 131 | 132 | t = t1%node%value*r 133 | 134 | t%node%left => t1%node 135 | t%node%left_grad = r 136 | 137 | end function mul_rt 138 | !> t1/t2, built as t1*(1/t2). 139 | impure elemental function div_tt(t1, t2) result(t) 140 | type(tree_t), intent(in) :: t1, t2 141 | type(tree_t) :: t 142 | 143 | t = mul_tt(t1, div_rt(1.0_rk, t2)) 144 | 145 | end function div_tt 146 | !> t1/r, built as t1*(1/r). 147 | impure elemental function div_tr(t1, r) result(t) 148 | type(tree_t), intent(in) :: t1 149 | real(rk), intent(in) :: r 150 | type(tree_t) :: t 151 | 152 | t = mul_tr(t1, 1.0_rk/r) 153 | 154 | end function div_tr 155 | !> r/t1: d/dt1 = -r/t1**2. 156 | impure elemental function div_rt(r, t1) result(t) 157 | real(rk), intent(in) :: r 158 | type(tree_t), intent(in) :: t1 159 | type(tree_t) :: t 160 | 161 | t = r/t1%node%value 162 | 163 | t%node%left => t1%node 164 | t%node%left_grad = -r/t1%node%value**2 165 | 166 |
end function div_rt 167 | !> t1**t2: d/dt1 = t2*t1**(t2 - 1) and d/dt2 = t1**t2*log(t1). 168 | impure elemental function pow_tt(t1, t2) result(t) 169 | type(tree_t), intent(in) :: t1, t2 170 | type(tree_t) :: t 171 | 172 | t = t1%node%value**t2%node%value 173 | 174 | t%node%left => t1%node 175 | t%node%right => t2%node 176 | 177 | ! Each operand gets its own partial derivative; the two are not equal in general. 178 | ! right_grad involves log(t1), so it is only meaningful for t1 > 0. 179 | t%node%left_grad = t2%node%value*t1%node%value**(t2%node%value - 1.0_rk) 180 | t%node%right_grad = t%node%value*log(t1%node%value) 181 | 182 | end function pow_tt 183 | !> t1**r: d/dt1 = r*t1**(r - 1); the base t1 is used, not the result. 184 | impure elemental function pow_tr(t1, r) result(t) 185 | type(tree_t), intent(in) :: t1 186 | real(rk), intent(in) :: r 187 | type(tree_t) :: t 188 | 189 | t = t1%node%value**r 190 | 191 | t%node%left => t1%node 192 | t%node%left_grad = r*t1%node%value**(r - 1.0_rk) 193 | 194 | end function pow_tr 195 | !> t1**i (integer exponent): d/dt1 = i*t1**(i - 1); the base t1 is used, not the result. 196 | impure elemental function pow_ti(t1, i) result(t) 197 | type(tree_t), intent(in) :: t1 198 | integer, intent(in) :: i 199 | type(tree_t) :: t 200 | 201 | t = t1%node%value**i 202 | 203 | t%node%left => t1%node 204 | t%node%left_grad = i*t1%node%value**(i - 1) 205 | 206 | end function pow_ti 207 | !> r**t1: d/dt1 = r**t1*log(r). 208 | impure elemental function pow_rt(r, t1) result(t) 209 | real(rk), intent(in) :: r 210 | type(tree_t), intent(in) :: t1 211 | type(tree_t) :: t 212 | 213 | t = r**t1%node%value 214 | 215 | t%node%left => t1%node 216 | t%node%left_grad = r**t1%node%value*log(r) 217 | 218 | end function pow_rt 219 | 220 | end module ad_operator 221 | -------------------------------------------------------------------------------- /src/ad_types.f90: -------------------------------------------------------------------------------- 1 | module ad_types 2 | 3 | use ad_kinds, only: rk 4 | implicit none 5 | private 6 | 7 | public :: tree_t, assignment(=) 8 | 9 | !> Expression-graph node: forward value, accumulated gradient and up to two operands 10 | type node_t 11 | 12 | real(rk) :: value !!
value of the node 13 | real(rk) :: grad = 0.0_rk !! gradient accumulated by backward() 14 | 15 | type(node_t), pointer :: left => null() 16 | real(rk), allocatable :: left_grad !! local partial d(value)/d(left%value) 17 | 18 | type(node_t), pointer :: right => null() 19 | real(rk), allocatable :: right_grad !! local partial d(value)/d(right%value) 20 | 21 | contains 22 | 23 | procedure :: backward => node_t_backward 24 | procedure :: destructor => node_t_destructor 25 | 26 | end type node_t 27 | 28 | !> Handle to a node of the expression graph (operators link to operand nodes without copying, so nodes can be shared) 29 | type tree_t 30 | type(node_t), pointer :: node => null() 31 | contains 32 | procedure :: backward => tree_t_backward 33 | procedure :: get_value => tree_t_get_value 34 | procedure :: get_grad => tree_t_get_grad 35 | procedure :: destructor => tree_t_destructor 36 | final :: final_destructor 37 | end type tree_t 38 | 39 | interface assignment(=) 40 | module procedure :: tree_t_assignment 41 | end interface assignment(=) 42 | 43 | interface tree_t 44 | module procedure :: tree_t_constructor 45 | end interface tree_t 46 | 47 | contains 48 | !> Detached node holding `value`; grad keeps its default 0 and both children are null. 49 | elemental function node_t_constructor(value) result(return_node) 50 | real(rk), intent(in) :: value 51 | type(node_t) :: return_node 52 | 53 | return_node%value = value 54 | 55 | end function node_t_constructor 56 | !> Defined assignment `tree = real`: always allocates a new leaf node; a previously referenced node is not deallocated. 57 | elemental subroutine tree_t_assignment(self, value) 58 | type(tree_t), intent(inout) :: self 59 | real(rk), intent(in) :: value 60 | 61 | allocate (self%node, source=node_t_constructor(value)) 62 | 63 | end subroutine tree_t_assignment 64 | !> Constructor `tree_t(value)`, equivalent to assigning a real to a tree_t. 65 | elemental function tree_t_constructor(value) result(return_tree) 66 | real(rk), intent(in) :: value 67 | type(tree_t) :: return_tree 68 | 69 | call tree_t_assignment(return_tree, value) 70 | 71 | end function tree_t_constructor 72 | !> Seed this tree's node with gradient 1 and push its local partials into the children. 73 | elemental subroutine tree_t_backward(self) 74 | class(tree_t), intent(inout) :: self 75 | 76 | associate (node => self%node) 77 | 78 | node%grad = 1.0_rk 79 | 80 | if (associated(node%left)) & 81 | call node%left%backward(node%left_grad) 82 | if
(associated(node%right)) & 83 | call node%right%backward(node%right_grad) 84 | 85 | end associate 86 | 87 | end subroutine tree_t_backward 88 | !> Add the incoming gradient `out` to this node, then pass out*local_partial to each child. Declared recursive because it re-enters itself through the children (same requirement as node_t_destructor). 89 | elemental recursive subroutine node_t_backward(self, out) 90 | class(node_t), intent(inout) :: self 91 | real(rk), intent(in) :: out 92 | 93 | self%grad = self%grad + out 94 | 95 | if (associated(self%left)) call self%left%backward(out*self%left_grad) 96 | if (associated(self%right)) call self%right%backward(out*self%right_grad) 97 | 98 | end subroutine node_t_backward 99 | !> Forward value of this tree's node (the tree must have been assigned first; self%node is not checked). 100 | elemental function tree_t_get_value(self) result(value) 101 | class(tree_t), intent(in) :: self 102 | real(rk) :: value 103 | 104 | value = self%node%value 105 | 106 | end function tree_t_get_value 107 | !> Gradient accumulated in this tree's node by backward() calls on trees that depend on it. 108 | elemental function tree_t_get_grad(self) result(grad) 109 | class(tree_t), intent(in) :: self 110 | real(rk) :: grad 111 | 112 | grad = self%node%grad 113 | 114 | end function tree_t_get_grad 115 | !> Recursively nullifies the child links of every node reachable from this tree, then detaches the tree; nothing is deallocated. NOTE(review): reachable nodes may be shared with other trees, whose graphs are cut as well. 116 | elemental subroutine tree_t_destructor(self) 117 | class(tree_t), intent(inout) :: self 118 | 119 | if (associated(self%node)) then 120 | call self%node%destructor() 121 | nullify (self%node) ! @note: clears the link only; the node is not deallocated 122 | end if 123 | 124 | end subroutine tree_t_destructor 125 | 126 | ! - GFortran >= 11.0 127 | ! - Intel Fortran >= 2019 128 | ! or report an error: ELEMENTAL attribute conflicts with RECURSIVE attribute. 129 | elemental recursive subroutine node_t_destructor(self) 130 | class(node_t), intent(inout) :: self 131 | 132 | if (associated(self%left)) then 133 | call self%left%destructor() 134 | nullify (self%left) ! @note: clears the link only; the node is not deallocated 135 | end if 136 | 137 | if (associated(self%right)) then 138 | call self%right%destructor() 139 | nullify (self%right) !
@note: nullified only; the node is not deallocated 140 | end if 141 | 142 | end subroutine node_t_destructor 143 | !> Delegates to destructor: nullifies links, deallocates nothing. 144 | elemental subroutine final_destructor(self) 145 | type(tree_t), intent(inout) :: self 146 | 147 | call self%destructor() 148 | 149 | end subroutine final_destructor 150 | 151 | end module ad_types 152 | -------------------------------------------------------------------------------- /src/ad_usr_func.f90: -------------------------------------------------------------------------------- 1 | module ad_usr_func 2 | 3 | use ad_kinds, only: rk 4 | use ad_types, only: tree_t, assignment(=) 5 | implicit none 6 | private 7 | 8 | public :: sigmoid 9 | 10 | contains 11 | !> Logistic sigmoid s = 1/(1 + exp(-x)) of t1, returned as a new node linked to t1. 12 | impure elemental function sigmoid(t1) result(t) 13 | type(tree_t), intent(in) :: t1 14 | type(tree_t) :: t 15 | ! Defined assignment: t gets a fresh node holding s. 16 | t = 1.0_rk/(1.0_rk + exp(-t1%node%value)) 17 | ! Presumably IMPURE because PURE forbids an INTENT(IN) dummy as a pointer target. 18 | t%node%left => t1%node 19 | t%node%left_grad = t%node%value*(1.0_rk - t%node%value) ! ds/dx = s*(1 - s) 20 | 21 | end function sigmoid 22 | 23 | end module ad_usr_func -------------------------------------------------------------------------------- /src/auto_diff.f90: -------------------------------------------------------------------------------- 1 | module auto_diff 2 | !! Umbrella module: re-exports the names used below (no PRIVATE statement). 3 | use ad_kinds, only: rk 4 | use ad_types, only: tree_t, assignment(=) 5 | use ad_operator, only: operator(+), operator(-), operator(*), operator(/), & 6 | operator(**) 7 | use ad_intrinsic, only: abs, max, min, & 8 | sin, cos, tan, & 9 | exp, sqrt, log, log10 10 | use ad_usr_func, only: sigmoid 11 | implicit none 12 | 13 | end module auto_diff 14 | -------------------------------------------------------------------------------- /test/test_func.f90: -------------------------------------------------------------------------------- 1 | module test_func 2 | 3 | use testdrive, only: new_unittest, unittest_type, error_type, check 4 | use auto_diff, only: abs, exp, sqrt, sin, cos, tan, log, log10 5 | use auto_diff, only: max, min, sigmoid 6 | use auto_diff, only: tree_t, rk,
assignment(=) 7 | implicit none 8 | private 9 | 10 | public :: collect_suite_func 11 | 12 | contains 13 | 14 | !> Collect all exported unit tests 15 | subroutine collect_suite_func(testsuite) 16 | !> Collection of tests 17 | type(unittest_type), allocatable, intent(out) :: testsuite(:) 18 | 19 | testsuite = [ & 20 | new_unittest("func abs valid", test_abs_valid), & 21 | new_unittest("func exp valid", test_exp_valid), & 22 | new_unittest("func sqrt valid", test_sqrt_valid), & 23 | new_unittest("func sin valid", test_sin_valid), & 24 | new_unittest("func cos valid", test_cos_valid), & 25 | new_unittest("func tan valid", test_tan_valid), & 26 | new_unittest("func log valid", test_log_valid), & 27 | new_unittest("func log10 valid", test_log10_valid), & 28 | new_unittest("func max valid", test_max_valid), & 29 | new_unittest("func min valid", test_min_valid), & 30 | new_unittest("func sigmoid valid", test_sigmoid_valid) & 31 | ] 32 | 33 | end subroutine collect_suite_func 34 | 35 | subroutine test_abs_valid(error) 36 | type(error_type), allocatable, intent(out) :: error 37 | 38 | type(tree_t) :: a 39 | type(tree_t) :: b 40 | 41 | a = 2.0_rk 42 | b = abs(a) 43 | call b%backward() 44 | 45 | call check(error, b%get_value(), 2.0_rk); if (allocated(error)) return 46 | call check(error, a%get_grad(), 1.0_rk) 47 | 48 | end subroutine test_abs_valid 49 | 50 | subroutine test_sqrt_valid(error) 51 | type(error_type), allocatable, intent(out) :: error 52 | 53 | type(tree_t) :: a 54 | type(tree_t) :: b 55 | 56 | a = 2.0_rk 57 | b = sqrt(a) 58 | call b%backward() 59 | 60 | call check(error, b%get_value(), sqrt(2.0_rk)); if (allocated(error)) return 61 | call check(error, a%get_grad(), 1.0_rk/(2.0_rk*sqrt(2.0_rk))) 62 | 63 | end subroutine test_sqrt_valid 64 | 65 | subroutine test_exp_valid(error) 66 | type(error_type), allocatable, intent(out) :: error 67 | 68 | type(tree_t) :: a 69 | type(tree_t) :: b 70 | 71 | a = 2.0_rk 72 | b = exp(a) 73 | call b%backward() 74 | 75 | call 
check(error, b%get_value(), exp(2.0_rk)); if (allocated(error)) return 76 | call check(error, a%get_grad(), exp(2.0_rk)) 77 | 78 | end subroutine test_exp_valid 79 | 80 | subroutine test_sin_valid(error) 81 | type(error_type), allocatable, intent(out) :: error 82 | 83 | type(tree_t) :: a 84 | type(tree_t) :: b 85 | 86 | a = 2.0_rk 87 | b = sin(a) 88 | call b%backward() 89 | 90 | call check(error, b%get_value(), sin(2.0_rk)); if (allocated(error)) return 91 | call check(error, a%get_grad(), cos(2.0_rk)) 92 | 93 | end subroutine test_sin_valid 94 | 95 | subroutine test_cos_valid(error) 96 | type(error_type), allocatable, intent(out) :: error 97 | 98 | type(tree_t) :: a 99 | type(tree_t) :: b 100 | 101 | a = 2.0_rk 102 | b = cos(a) 103 | call b%backward() 104 | 105 | call check(error, b%get_value(), cos(2.0_rk)); if (allocated(error)) return 106 | call check(error, a%get_grad(), -sin(2.0_rk)) 107 | 108 | end subroutine test_cos_valid 109 | 110 | subroutine test_tan_valid(error) 111 | type(error_type), allocatable, intent(out) :: error 112 | 113 | type(tree_t) :: a 114 | type(tree_t) :: b 115 | 116 | a = 2.0_rk 117 | b = tan(a) 118 | call b%backward() 119 | 120 | call check(error, b%get_value(), tan(2.0_rk)); if (allocated(error)) return 121 | call check(error, a%get_grad(), 1.0_rk/(cos(2.0_rk)**2)) 122 | 123 | end subroutine test_tan_valid 124 | 125 | subroutine test_log_valid(error) 126 | type(error_type), allocatable, intent(out) :: error 127 | 128 | type(tree_t) :: a 129 | type(tree_t) :: b 130 | 131 | a = 2.0_rk 132 | b = log(a) 133 | call b%backward() 134 | 135 | call check(error, b%get_value(), log(2.0_rk)); if (allocated(error)) return 136 | call check(error, a%get_grad(), 1.0_rk/2.0_rk) 137 | 138 | end subroutine test_log_valid 139 | 140 | subroutine test_log10_valid(error) 141 | type(error_type), allocatable, intent(out) :: error 142 | 143 | type(tree_t) :: a 144 | type(tree_t) :: b 145 | 146 | a = 2.0_rk 147 | b = log10(a) 148 | call b%backward() 149 | 150 | 
call check(error, b%get_value(), log10(2.0_rk)); if (allocated(error)) return 151 | call check(error, a%get_grad(), 1.0_rk/log(10.0_rk)/2.0_rk) 152 | 153 | end subroutine test_log10_valid 154 | 155 | subroutine test_max_valid(error) 156 | type(error_type), allocatable, intent(out) :: error 157 | 158 | type(tree_t) :: a 159 | type(tree_t) :: b 160 | 161 | a = 2.0_rk 162 | b = max(a, a) 163 | call b%backward() 164 | 165 | call check(error, b%get_value(), 2.0_rk); if (allocated(error)) return 166 | call check(error, a%get_grad(), 1.0_rk) 167 | 168 | end subroutine test_max_valid 169 | 170 | subroutine test_min_valid(error) 171 | type(error_type), allocatable, intent(out) :: error 172 | 173 | type(tree_t) :: a 174 | type(tree_t) :: b 175 | 176 | a = 2.0_rk 177 | b = min(a, a) 178 | call b%backward() 179 | 180 | call check(error, b%get_value(), 2.0_rk); if (allocated(error)) return 181 | call check(error, a%get_grad(), 1.0_rk) 182 | 183 | a = 1.0_rk 184 | b = min(a, 2.0_rk) 185 | call b%backward() 186 | 187 | call check(error, b%get_value(), 1.0_rk); if (allocated(error)) return 188 | call check(error, a%get_grad(), 1.0_rk) 189 | 190 | !> Not implemented yet: min(real, tree_t), i.e. the tree as second argument 191 | ! a = 1.0_rk 192 | ! b = min(2.0_rk, a) 193 | ! call b%backward() 194 | 195 | ! call check(error, b%get_value(), 1.0_rk) 196 | !
call check(error, a%get_grad(), 1.0_rk) 197 | 198 | end subroutine test_min_valid 199 | 200 | subroutine test_sigmoid_valid(error) 201 | type(error_type), allocatable, intent(out) :: error 202 | 203 | type(tree_t) :: a 204 | type(tree_t) :: b 205 | 206 | a = 2.0_rk 207 | b = sigmoid(a) 208 | call b%backward() 209 | 210 | call check(error, b%get_value(), 1.0_rk/(1.0_rk+exp(-2.0_rk))); if (allocated(error)) return 211 | call check(error, a%get_grad(), 0.10499358540350662_rk) 212 | 213 | end subroutine test_sigmoid_valid 214 | 215 | end module test_func 216 | -------------------------------------------------------------------------------- /test/test_operator.f90: -------------------------------------------------------------------------------- 1 | module test_operator 2 | use testdrive, only: new_unittest, unittest_type, error_type, check 3 | use auto_diff, only: operator(+), operator(-), operator(*), operator(/), operator(**) 4 | use auto_diff, only: rk, tree_t, assignment(=) 5 | implicit none 6 | private 7 | 8 | public :: collect_suite_operator 9 | 10 | contains 11 | 12 | !> Collect all exported unit tests 13 | subroutine collect_suite_operator(testsuite) 14 | !> Collection of tests 15 | type(unittest_type), allocatable, intent(out) :: testsuite(:) 16 | 17 | testsuite = [ & 18 | new_unittest("operator(+) valid", test_add_valid), & 19 | new_unittest("operator(-) valid", test_sub_valid), & 20 | new_unittest("operator(*) valid", test_mult_valid), & 21 | new_unittest("operator(/) valid", test_div_valid), & 22 | new_unittest("operator(**) valid", test_pow_valid) & 23 | ] 24 | 25 | end subroutine collect_suite_operator 26 | 27 | subroutine test_add_valid(error) 28 | type(error_type), allocatable, intent(out) :: error 29 | 30 | type(tree_t) :: a, b 31 | type(tree_t) :: c 32 | 33 | a = 1.0_rk 34 | b = 2.0_rk 35 | 36 | c = a + b 37 | call c%backward() 38 | 39 | call check(error, c%get_value(), 3.0_rk); if (allocated(error)) return 40 | call check(error, a%get_grad(), 1.0_rk); 
if (allocated(error)) return 41 | call check(error, b%get_grad(), 1.0_rk); if (allocated(error)) return 42 | 43 | a = 1.0_rk 44 | c = a + 1.0_rk 45 | call c%backward() 46 | 47 | call check(error, c%get_value(), 2.0_rk); if (allocated(error)) return 48 | call check(error, a%get_grad(), 1.0_rk); if (allocated(error)) return 49 | 50 | a = 1.0_rk 51 | c = 1.0_rk + a 52 | call c%backward() 53 | 54 | call check(error, c%get_value(), 2.0_rk); if (allocated(error)) return 55 | call check(error, a%get_grad(), 1.0_rk) 56 | 57 | end subroutine test_add_valid 58 | 59 | subroutine test_sub_valid(error) 60 | type(error_type), allocatable, intent(out) :: error 61 | 62 | type(tree_t) :: a, b 63 | type(tree_t) :: c 64 | 65 | a = 1.0_rk 66 | b = 2.0_rk 67 | 68 | c = a - b 69 | call c%backward() 70 | 71 | call check(error, c%get_value(), -1.0_rk); if (allocated(error)) return 72 | call check(error, a%get_grad(), 1.0_rk); if (allocated(error)) return 73 | call check(error, b%get_grad(), -1.0_rk); if (allocated(error)) return 74 | 75 | a = 1.0_rk 76 | c = a - 1.0_rk 77 | call c%backward() 78 | 79 | call check(error, c%get_value(), 0.0_rk); if (allocated(error)) return 80 | call check(error, a%get_grad(), 1.0_rk); if (allocated(error)) return 81 | 82 | a = 1.0_rk 83 | c = 1.0_rk - a 84 | call c%backward() 85 | 86 | call check(error, c%get_value(), 0.0_rk); if (allocated(error)) return 87 | call check(error, a%get_grad(), -1.0_rk) 88 | 89 | end subroutine test_sub_valid 90 | 91 | subroutine test_mult_valid(error) 92 | type(error_type), allocatable, intent(out) :: error 93 | 94 | type(tree_t) :: a, b 95 | type(tree_t) :: c 96 | 97 | a = 1.0_rk 98 | b = 2.0_rk 99 | 100 | c = a * b 101 | call c%backward() 102 | 103 | call check(error, c%get_value(), 2.0_rk); if (allocated(error)) return 104 | call check(error, a%get_grad(), 2.0_rk); if (allocated(error)) return 105 | call check(error, b%get_grad(), 1.0_rk); if (allocated(error)) return 106 | 107 | a = 1.0_rk 108 | c = a * 1.0_rk 109 | call 
c%backward() 110 | 111 | call check(error, c%get_value(), 1.0_rk); if (allocated(error)) return 112 | call check(error, a%get_grad(), 1.0_rk); if (allocated(error)) return 113 | 114 | a = 1.0_rk 115 | c = 1.0_rk * a 116 | call c%backward() 117 | 118 | call check(error, c%get_value(), 1.0_rk); if (allocated(error)) return 119 | call check(error, a%get_grad(), 1.0_rk) 120 | 121 | end subroutine test_mult_valid 122 | 123 | subroutine test_div_valid(error) 124 | type(error_type), allocatable, intent(out) :: error 125 | 126 | type(tree_t) :: a, b 127 | type(tree_t) :: c 128 | 129 | a = 1.0_rk 130 | b = 2.0_rk 131 | 132 | c = a / b 133 | call c%backward() 134 | 135 | call check(error, c%get_value(), 0.5_rk); if (allocated(error)) return 136 | call check(error, a%get_grad(), 0.5_rk); if (allocated(error)) return 137 | call check(error, b%get_grad(), -0.25_rk); if (allocated(error)) return 138 | 139 | a = 1.0_rk 140 | c = a / 1.0_rk 141 | call c%backward() 142 | 143 | call check(error, c%get_value(), 1.0_rk); if (allocated(error)) return 144 | call check(error, a%get_grad(), 1.0_rk); if (allocated(error)) return 145 | 146 | a = 1.0_rk 147 | c = 1.0_rk / a 148 | call c%backward() 149 | 150 | call check(error, c%get_value(), 1.0_rk); if (allocated(error)) return 151 | call check(error, a%get_grad(), -1.0_rk) 152 | 153 | end subroutine test_div_valid 154 | 155 | subroutine test_pow_valid(error) 156 | type(error_type), allocatable, intent(out) :: error 157 | 158 | type(tree_t) :: a, b 159 | type(tree_t) :: c 160 | 161 | a = 1.0_rk 162 | b = 2.0_rk 163 | 164 | c = a ** b 165 | call c%backward() 166 | 167 | call check(error, c%get_value(), 1.0_rk); if (allocated(error)) return 168 | call check(error, a%get_grad(), 2.0_rk); if (allocated(error)) return 169 | call check(error, b%get_grad(), 2.0_rk); if (allocated(error)) return 170 | ! NOTE(review): d(a**b)/db = log(a)*a**b = 0 at a = 1 (cf. `1.0_rk ** a` below); verify 2.0 171 | a = 1.0_rk 172 | c = a ** 1.0_rk 173 | call c%backward() 174 | 175 | call check(error, c%get_value(), 1.0_rk); if (allocated(error)) return 176
| call check(error, a%get_grad(), 1.0_rk); if (allocated(error)) return 177 | 178 | a = 1.0_rk 179 | c = 1.0_rk ** a 180 | call c%backward() 181 | 182 | call check(error, c%get_value(), 1.0_rk); if (allocated(error)) return 183 | call check(error, a%get_grad(), 0.0_rk) 184 | 185 | end subroutine test_pow_valid 186 | 187 | end module test_operator 188 | -------------------------------------------------------------------------------- /test/tester.f90: -------------------------------------------------------------------------------- 1 | program tester 2 | use, intrinsic :: iso_fortran_env, only: error_unit 3 | use testdrive, only: run_testsuite, new_testsuite, testsuite_type 4 | use test_operator, only: collect_suite_operator 5 | use test_func, only: collect_suite_func 6 | implicit none 7 | integer :: stat, is 8 | type(testsuite_type), allocatable :: testsuites(:) 9 | character(len=*), parameter :: fmt = '("#", *(1x, a))' 10 | 11 | stat = 0 12 | 13 | testsuites = [ & 14 | new_testsuite("operator", collect_suite_operator), & 15 | new_testsuite("func", collect_suite_func) & 16 | ] 17 | 18 | do is = 1, size(testsuites) 19 | write (error_unit, fmt) "Testing:", testsuites(is)%name 20 | call run_testsuite(testsuites(is)%collect, error_unit, stat) 21 | end do 22 | 23 | if (stat > 0) then 24 | write (error_unit, '(i0, 1x, a)') stat, "test(s) failed!" 25 | error stop 26 | end if 27 | 28 | end program tester 29 | --------------------------------------------------------------------------------