├── .cargo └── config.toml ├── .github ├── actions │ └── setup │ │ └── action.yaml └── workflows │ └── ci.yaml ├── .gitignore ├── .gitmodules ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── README.md ├── oneline_json.bash ├── test_data_json ├── conv2d.json ├── conv_bias.json ├── gru_small.json ├── lstm_bidirectional.json ├── lstm_fwd_bkwd_medium.json ├── lstm_fwd_bkwd_small.json ├── rnn_fwd_bkwd_bidirectional_2_layers.json ├── rnn_fwd_bkwd_single.json ├── rnn_fwd_bkwd_single_seq_len_1.json └── rnn_fwd_bkwd_single_seq_len_3.json ├── zenu-autograd ├── Cargo.toml ├── README.md ├── benches │ └── conv2d.rs ├── src │ ├── activation │ │ ├── mod.rs │ │ ├── relu.rs │ │ └── sigmoid.rs │ ├── concat.rs │ ├── creator │ │ ├── alloc.rs │ │ ├── from_vec.rs │ │ ├── mod.rs │ │ ├── ones.rs │ │ ├── rand.rs │ │ └── zeros.rs │ ├── functions │ │ ├── add.rs │ │ ├── broadcast.rs │ │ ├── clip.rs │ │ ├── cosh.rs │ │ ├── div.rs │ │ ├── exp.rs │ │ ├── flatten.rs │ │ ├── index_axis.rs │ │ ├── log.rs │ │ ├── matmul.rs │ │ ├── mod.rs │ │ ├── mul.rs │ │ ├── powf.rs │ │ ├── reshape.rs │ │ ├── sinh.rs │ │ ├── slice.rs │ │ ├── stack.rs │ │ ├── sub.rs │ │ ├── sum.rs │ │ ├── sum_to.rs │ │ ├── tanh.rs │ │ └── transpose.rs │ ├── lib.rs │ ├── loss │ │ ├── cross_entropy.rs │ │ ├── mod.rs │ │ ├── mse.rs │ │ └── softmax.rs │ └── nn │ │ ├── batch_norm.rs │ │ ├── conv │ │ ├── conv_with_bias.rs │ │ ├── conv_without_bias.rs │ │ └── mod.rs │ │ ├── dropout.rs │ │ ├── mod.rs │ │ ├── pool2d.rs │ │ └── rnns │ │ ├── gru │ │ ├── cudnn.rs │ │ ├── mod.rs │ │ └── naive.rs │ │ ├── lstm │ │ ├── cudnn.rs │ │ ├── mod.rs │ │ └── naive.rs │ │ ├── mod.rs │ │ ├── rnn │ │ ├── cudnn.rs │ │ ├── mod.rs │ │ └── naive.rs │ │ └── weights.rs └── tests │ ├── combined.rs │ └── get_all_test_variable.rs ├── zenu-cublas-sys ├── Cargo.toml ├── build.rs ├── src │ └── lib.rs └── wrapper.h ├── zenu-cuda-config ├── Cargo.toml └── src │ └── lib.rs ├── zenu-cuda-driver-sys ├── Cargo.toml ├── build.rs ├── src │ └── lib.rs └── wrapper.h ├── zenu-cuda-kernel-sys ├── Cargo.toml ├── build.rs ├── kernel │ ├── activations.cu │ ├── activations.h │ ├── array_array.cu │ ├── array_array.h │ ├── array_scalar.cu │ ├── array_scalar.h │ ├── conv2d_bkwd_data.cu │ ├── conv2d_bkwd_data.h │ ├── element_wise.cu │ ├── element_wise.h │ ├── kernel.h │ ├── memory_access.cu │ └── memory_access.h └── src │ └── lib.rs ├── zenu-cuda-runtime-sys ├── Cargo.toml ├── build.rs ├── src │ └── lib.rs └── wrapper.h ├── zenu-cuda ├── Cargo.toml └── src │ ├── cublas │ ├── cublas_error.rs │ └── mod.rs │ ├── cudnn │ ├── batch_norm.rs │ ├── dropout.rs │ ├── error.rs │ ├── graph_batchnorm.rs │ ├── graph_conv.rs │ ├── graph_utils.rs │ ├── mod.rs │ ├── pooling.rs │ └── rnn │ │ ├── descriptor.rs │ │ ├── helper.rs │ │ ├── mod.rs │ │ └── test.rs │ ├── kernel │ ├── activation.rs │ └── mod.rs │ ├── lib.rs │ └── runtime │ ├── mod.rs │ └── runtime_error.rs ├── zenu-cudnn-frontend-wrapper-sys ├── Cargo.toml ├── build.rs ├── cudnn_frontend_wrapper │ ├── .gitignore │ ├── .gitmodules │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── cudnn_frontend_wrapper.h │ ├── src │ │ ├── batchnorm.cpp │ │ ├── batchnorm.h │ │ ├── conv.cpp │ │ ├── conv.h │ │ ├── cudnn_frontend_wrapper.cpp │ │ ├── i_graph_desc.h │ │ ├── macros.h │ │ ├── utils.cpp │ │ └── utils.h │ └── tests │ │ ├── batchnorm.cpp │ │ ├── conv.cpp │ │ └── helpers.h └── src │ └── lib.rs ├── zenu-cudnn-sys ├── Cargo.toml ├── build.rs ├── src │ └── lib.rs └── wrapper.h ├── zenu-layer ├── Cargo.toml ├── README.md └── src │ ├── layers │ ├── batch_norm_2d.rs │ ├── conv2d.rs │ 
├── dropout.rs │ ├── linear.rs │ ├── max_pool_2d.rs │ ├── mod.rs │ └── rnn │ │ ├── builder.rs │ │ ├── gru.rs │ │ ├── inner.rs │ │ ├── lstm.rs │ │ ├── mod.rs │ │ └── rnn.rs │ └── lib.rs ├── zenu-macros ├── Cargo.toml ├── src │ └── lib.rs └── tests │ ├── include_vector_map_model.rs │ ├── multi_parameter_struct.rs │ └── small_case.rs ├── zenu-matrix ├── Cargo.toml ├── README.md ├── benches │ ├── copy_from_all_matrix.rs │ ├── copy_from_im2col_way.rs │ ├── im2col_function.rs │ └── transpose_reshape_im2col.rs ├── src │ ├── concat.rs │ ├── constructor │ ├── device │ │ ├── cpu │ │ │ └── mod.rs │ │ ├── mod.rs │ │ └── nvidia │ │ │ └── mod.rs │ ├── dim │ │ ├── dim_dyn.rs │ │ ├── dim_static.rs │ │ └── mod.rs │ ├── impl_ops.rs │ ├── impl_serde.rs │ ├── index │ │ ├── index_dyn_impl.rs │ │ ├── index_impl.rs │ │ └── mod.rs │ ├── lib.rs │ ├── matrix.rs │ ├── matrix_blas │ │ └── mod.rs │ ├── matrix_format.rs │ ├── matrix_iter.rs │ ├── memory_pool │ │ ├── data_ptr.rs │ │ ├── dynamic_buffer.rs │ │ ├── dynamic_pool.rs │ │ ├── mod.rs │ │ ├── static_buffer.rs │ │ ├── static_mem_pool.rs │ │ └── test.rs │ ├── nn │ │ ├── batch_norm.rs │ │ ├── col2im.rs │ │ ├── conv │ │ │ ├── cpu │ │ │ │ ├── bias.rs │ │ │ │ ├── col2im.rs │ │ │ │ ├── conv_bkwd_data.rs │ │ │ │ ├── conv_bkwd_filter.rs │ │ │ │ ├── conv_fwd.rs │ │ │ │ ├── im2col.rs │ │ │ │ └── mod.rs │ │ │ ├── interface.rs │ │ │ ├── mod.rs │ │ │ ├── nvidia.rs │ │ │ ├── shape_check.rs │ │ │ └── utils.rs │ │ ├── dropout.rs │ │ ├── im2col.rs │ │ ├── mod.rs │ │ ├── pool2d.rs │ │ └── rnn │ │ │ ├── descriptor.rs │ │ │ ├── gru.rs │ │ │ ├── gru_params.rs │ │ │ ├── lstm.rs │ │ │ ├── lstm_params.rs │ │ │ ├── mod.rs │ │ │ ├── rnn.rs │ │ │ └── rnn_params.rs │ ├── num.rs │ ├── operation │ │ ├── add_axis.rs │ │ ├── asum.rs │ │ ├── basic_operations.rs │ │ ├── broadcast.rs │ │ ├── clip.rs │ │ ├── copy_from.rs │ │ ├── max.rs │ │ ├── mean.rs │ │ ├── mod.rs │ │ ├── mul.rs │ │ ├── norm2.rs │ │ ├── relu.rs │ │ ├── reshape.rs │ │ ├── softmax.rs │ │ ├── split.rs │ │ ├── stack.rs │ │ ├── sum.rs │ │ ├── to_default_stride.rs │ │ ├── transpose.rs │ │ └── var.rs │ ├── shape_stride.rs │ ├── slice │ │ ├── dynamic.rs │ │ ├── macro.rs │ │ ├── mod.rs │ │ ├── slice_dim.rs │ │ └── static_dim_slice.rs │ └── with_clousers.rs └── tests │ └── rnn.rs ├── zenu-optimizer ├── Cargo.toml ├── README.md ├── src │ ├── adam.rs │ ├── adamw.rs │ ├── lib.rs │ └── sgd.rs └── tests │ └── net_test.rs ├── zenu-test ├── Cargo.toml └── src │ └── lib.rs └── zenu ├── Cargo.toml ├── README.md ├── examples ├── cifar10.rs ├── install_mnist.py ├── mnist.py ├── mnist.rs └── resnet.rs └── src ├── dataset.rs ├── dataset_loader.rs └── lib.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [env] 2 | RUST_TEST_THREADS = "1" 3 | -------------------------------------------------------------------------------- /.github/actions/setup/action.yaml: -------------------------------------------------------------------------------- 1 | name: Setup 2 | description: Setup 3 | 4 | runs: 5 | using: "composite" 6 | steps: 7 | - uses: actions/checkout@v4 8 | 9 | - uses: chetan/git-restore-mtime-action@v2 # for Rust incremental build 10 | 11 | - name: Rust cache 12 | uses: actions/cache@v4 13 | with: 14 | path: | 15 | ~/.cargo/bin/ 16 | ~/.cargo/registry/index/ 17 | ~/.cargo/registry/cache/ 18 | ~/.cargo/git/db 19 | target/ 20 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} 21 | 22 | - name: Install openblas 23 | shell: bash 24 | run: sudo apt-get install 
libopenblas-dev -y 25 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | clippy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Setup 16 | uses: ./.github/actions/setup 17 | 18 | - name: Run clippy on zenu-matrix 19 | run: cd zenu-matrix && cargo clippy --verbose 20 | 21 | - name: Run clippy on zenu-autograd 22 | run: cd zenu-autograd && cargo clippy --verbose 23 | 24 | - name: Run clippy on zenu-layer 25 | run: cd zenu-layer && cargo clippy --verbose 26 | 27 | - name: Run clippy on zenu-optimizer 28 | run: cd zenu-optimizer && cargo clippy --verbose 29 | 30 | - name: Run clippy on zenu-macros 31 | run: cd zenu-macros && cargo clippy --verbose 32 | 33 | - name: Run clippy on zenu 34 | run: cd zenu && cargo clippy --verbose 35 | 36 | fmt: 37 | runs-on: ubuntu-latest 38 | steps: 39 | - uses: actions/checkout@v4 40 | 41 | - name: Setup 42 | uses: ./.github/actions/setup 43 | 44 | - name: Run fmt 45 | run: cargo fmt -- --check 46 | 47 | test: 48 | runs-on: ubuntu-latest 49 | steps: 50 | - uses: actions/checkout@v4 51 | 52 | - name: Setup 53 | uses: ./.github/actions/setup 54 | 55 | - name: Run tests on zenu-matrix 56 | run: cd zenu-matrix && cargo test --verbose 57 | 58 | - name: Run tests on zenu-autograd 59 | run: cd zenu-autograd && cargo test --verbose 60 | 61 | - name: Run tests on zenu-layer 62 | run: cd zenu-layer && cargo test --verbose 63 | 64 | - name: Run tests on zenu-optimizer 65 | run: cd zenu-optimizer && cargo test --verbose 66 | 67 | - name: Run tests on zenu-macros 68 | run: cd zenu-macros && cargo test --verbose 69 | 70 | - name: Run tests on zenu 71 | run: cd zenu && cargo test --verbose 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | *ccls* 4 | *bindings.rs 5 | .DS_Store 6 | .bash_history 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/cudnn-frontend"] 2 | path = zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/cudnn-frontend 3 | url = git@github.com:nvidia/cudnn-frontend.git 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | members = [ 4 | "zenu", 5 | "zenu-autograd", 6 | "zenu-cublas-sys", 7 | "zenu-cuda", 8 | "zenu-cuda-config", 9 | "zenu-cuda-driver-sys", 10 | "zenu-cuda-kernel-sys", 11 | "zenu-cuda-runtime-sys", 12 | "zenu-cudnn-sys", 13 | "zenu-layer", 14 | "zenu-matrix", 15 | "zenu-test", 16 | "zenu-optimizer", 17 | "zenu-macros", 18 | "zenu-cudnn-frontend-wrapper-sys" 19 | ] 20 | 21 | [workspace.lints.clippy] 22 | pedantic = "warn" 23 | 24 | [profile.bench] 25 | debug = true 26 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM 
nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | RUN apt-get update && \ 6 | echo "tzdata tzdata/Areas select Etc" > /tmp/preseed.txt && \ 7 | echo "tzdata tzdata/Zones/Etc select UTC" >> /tmp/preseed.txt && \ 8 | debconf-set-selections /tmp/preseed.txt && \ 9 | apt-get install -y --no-install-recommends \ 10 | tzdata \ 11 | build-essential \ 12 | curl \ 13 | ca-certificates \ 14 | gnupg \ 15 | pkg-config \ 16 | libssl-dev \ 17 | lsb-release \ 18 | software-properties-common \ 19 | clang \ 20 | libopenblas-dev && \ 21 | rm -rf /var/lib/apt/lists/* 22 | 23 | RUN curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | apt-key add - && \ 24 | curl -s -L https://nvidia.github.io/nvidia-docker/ubuntu20.04/nvidia-docker.list | tee /etc/apt/sources.list.d/nvidia-docker.list && \ 25 | apt-get update 26 | 27 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 28 | 29 | ENV PATH=/root/.cargo/bin:$PATH 30 | 31 | # home directory 32 | WORKDIR /home 33 | 34 | CMD ["bash"] 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 冒頓単于 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /oneline_json.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | git ls-files '*.json' | while read file; do 4 | jq -c . 
"$file" > "$file.tmp" && mv "$file.tmp" "$file" 5 | done 6 | -------------------------------------------------------------------------------- /test_data_json/lstm_fwd_bkwd_small.json: -------------------------------------------------------------------------------- 1 | {"rnn.weight_ih_l0":{"shape":[8,2],"stride":[2,1],"data":[-0.005293981172144413,0.37932288646698,-0.5819807648658752,-0.5203874707221985,-0.27234524488449097,0.1896158903837204,-0.014010033570230007,0.5606575012207031,-0.06275151669979095,0.1871093362569809,-0.2136969119310379,-0.13899271190166473,-0.6755334138870239,-0.4683041572570801,-0.2914857566356659,0.026193760335445404],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0":{"shape":[8,2],"stride":[2,1],"data":[0.2795442044734955,0.42428016662597656,-0.47937673330307007,-0.3079187273979187,0.2568332850933075,0.5871729254722595,-0.14552269876003265,0.5291363000869751,-0.11397384107112885,0.07482186704874039,0.6402683854103088,-0.6559619903564453,-0.4451505243778229,-0.17901486158370972,-0.2756301760673523,0.6109408140182495],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0":{"shape":[8],"stride":[1],"data":[-0.45833221077919006,-0.32550448179244995,-0.49401339888572693,-0.6622486114501953,-0.41276684403419495,0.6078276038169861,0.31552404165267944,0.34271523356437683],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0":{"shape":[8],"stride":[1],"data":[0.0371878482401371,-0.3625219762325287,0.11963163316249847,-0.6602218747138977,-0.5109314918518066,-0.36453473567962646,0.44614049792289734,0.4145916998386383],"data_type":"f32","ptr_offset":0},"input":{"shape":[3,1,2],"stride":[2,2,1],"data":[0.19186942279338837,1.2637947797775269,-1.2904350757598877,-0.7911027073860168,-0.02087947353720665,-0.7184800505638123],"data_type":"f32","ptr_offset":0},"output":{"shape":[3,1,2],"stride":[2,2,1],"data":[-0.1535118669271469,0.0033985055051743984,-0.32110878825187683,0.22685329616069794,-0.2945723235607147,0.0365699902176857],"data_type":"f32","ptr_offset":0},"input_grad":{"shape":[3,1,2],"stride":[2,2,1],"data":[0.010557422414422035,0.026116354390978813,-0.07728055119514465,-0.11851117014884949,-0.015096936374902725,-0.0308743417263031],"data_type":"f32","ptr_offset":0},"rnn.weight_ih_l0_grad":{"shape":[8,2],"stride":[2,1],"data":[0.21715931594371796,0.03572092950344086,-0.09307316690683365,-0.05358332023024559,0.11629969626665115,0.11310197412967682,-0.0012049516662955284,-0.023297177627682686,-0.1474316269159317,0.22660601139068604,-0.38990604877471924,-0.34543612599372864,0.04021185263991356,-0.02074984833598137,-0.05248009040951729,-0.038393534719944],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0_grad":{"shape":[8,2],"stride":[2,1],"data":[0.06499463319778442,-0.02602470852434635,-0.010737811215221882,-2.880929969251156e-05,0.03271043300628662,-0.013741767033934593,-0.01032883208245039,0.007253238931298256,-0.05449433624744415,0.021447159349918365,-0.15156187117099762,0.07422848790884018,0.028118863701820374,-0.015438945032656193,-0.009459368884563446,0.0024156884755939245],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0_grad":{"shape":[8],"stride":[1],"data":[-0.4551074504852295,0.07355727255344391,-0.14840565621852875,0.03238354250788689,0.5876979231834412,0.7409195899963379,-0.19068323075771332,0.05141274631023407],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0_grad":{"shape":[8],"stride":[1],"data":[-0.4551074504852295,0.07355727255344391,-0.14840565621852875,0.03238354250788689,0.5876979231834412,0.7409195899963379,-0.19068323075771332,0.05141274631023407],"data_
type":"f32","ptr_offset":0}} 2 | -------------------------------------------------------------------------------- /test_data_json/rnn_fwd_bkwd_single.json: -------------------------------------------------------------------------------- 1 | {"rnn.weight_ih_l0":{"shape":[4,2],"stride":[2,1],"data":[-0.003743410110473633,0.26822179555892944,-0.4115225672721863,-0.3679695129394531,-0.19257718324661255,0.13407868146896362,-0.009906589984893799,0.39644473791122437],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0":{"shape":[4,4],"stride":[4,1],"data":[-0.04437202215194702,0.1323062777519226,-0.15110653638839722,-0.09828269481658936,-0.4776742458343506,-0.33114105463027954,-0.20611155033111572,0.018521785736083984,0.19766759872436523,0.30001139640808105,-0.3389705419540405,-0.21773141622543335,0.18160855770111084,0.41519397497177124,-0.10290008783340454,0.3741558790206909],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0":{"shape":[4],"stride":[1],"data":[-0.08059167861938477,0.05290704965591431,0.45273810625076294,-0.4638351798057556],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0":{"shape":[4],"stride":[1],"data":[-0.3147689700126648,-0.12658262252807617,-0.19489997625350952,0.43200039863586426],"data_type":"f32","ptr_offset":0},"input":{"shape":[5,1,2],"stride":[2,2,1],"data":[-0.5663174986839294,0.3731146454811096,-0.8919953107833862,-1.5091077089309692,0.3703935444355011,1.4565025568008423,0.9398099184036255,0.7748488187789917,0.19186942279338837,1.2637947797775269],"data_type":"f32","ptr_offset":0},"output":{"shape":[5,1,4],"stride":[4,4,1],"data":[0,0.022082045674324036,0.4169246554374695,0.12169483304023743,0,0.7577149271965027,0.06607982516288757,0,0.08418391644954681,0,0.586718738079071,0.8497177958488464,0,0,0,0.5388816595077515,0,0,0.2730049192905426,0.6689149141311646],"data_type":"f32","ptr_offset":0},"input_grad":{"shape":[5,1,2],"stride":[2,2,1],"data":[-0.27985307574272156,0.2777567505836487,-0.902430534362793,-0.6932638883590698,-0.1883838176727295,1.0106563568115234,-0.011456223204731941,0.45845842361450195,-0.20248377323150635,0.530523419380188],"data_type":"f32","ptr_offset":0},"rnn.weight_ih_l0_grad":{"shape":[4,2],"stride":[2,1],"data":[0.4481823444366455,1.762392282485962,-2.051854372024536,-2.8820881843566895,-0.07221640646457672,2.1578407287597656,1.2676061391830444,4.6034770011901855],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0_grad":{"shape":[4,4],"stride":[4,1],"data":[0,0.9168476462364197,0.07995768636465073,0,0,0.04458906501531601,0.8418731093406677,0.24573169648647308,0,0.6757444143295288,0.21293775737285614,0.5840427279472351,0.09735234826803207,1.0855653285980225,0.7731673717498779,1.5215160846710205],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0_grad":{"shape":[4],"stride":[1],"data":[1.2100166082382202,2.4619247913360596,2.71012282371521,4.5457072257995605],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0_grad":{"shape":[4],"stride":[1],"data":[1.2100166082382202,2.4619247913360596,2.71012282371521,4.5457072257995605],"data_type":"f32","ptr_offset":0}} 2 | -------------------------------------------------------------------------------- /test_data_json/rnn_fwd_bkwd_single_seq_len_1.json: -------------------------------------------------------------------------------- 1 | 
{"rnn.weight_ih_l0":{"shape":[4,2],"stride":[2,1],"data":[-0.003743410110473633,0.26822179555892944,-0.4115225672721863,-0.3679695129394531,-0.19257718324661255,0.13407868146896362,-0.009906589984893799,0.39644473791122437],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0":{"shape":[4,4],"stride":[4,1],"data":[-0.04437202215194702,0.1323062777519226,-0.15110653638839722,-0.09828269481658936,-0.4776742458343506,-0.33114105463027954,-0.20611155033111572,0.018521785736083984,0.19766759872436523,0.30001139640808105,-0.3389705419540405,-0.21773141622543335,0.18160855770111084,0.41519397497177124,-0.10290008783340454,0.3741558790206909],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0":{"shape":[4],"stride":[1],"data":[-0.08059167861938477,0.05290704965591431,0.45273810625076294,-0.4638351798057556],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0":{"shape":[4],"stride":[1],"data":[-0.3147689700126648,-0.12658262252807617,-0.19489997625350952,0.43200039863586426],"data_type":"f32","ptr_offset":0},"input":{"shape":[1,1,2],"stride":[2,2,1],"data":[-0.5663174986839294,0.3731146454811096],"data_type":"f32","ptr_offset":0},"output":{"shape":[1,1,4],"stride":[4,4,1],"data":[0,0.022082045674324036,0.4169246554374695,0.12169483304023743],"data_type":"f32","ptr_offset":0},"input_grad":{"shape":[1,1,2],"stride":[2,2,1],"data":[-0.6140063405036926,0.16255390644073486],"data_type":"f32","ptr_offset":0},"rnn.weight_ih_l0_grad":{"shape":[4,2],"stride":[2,1],"data":[-0,0,-0.5663174986839294,0.3731146454811096,-0.5663174986839294,0.3731146454811096,-0.5663174986839294,0.3731146454811096],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0_grad":{"shape":[4,4],"stride":[4,1],"data":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0_grad":{"shape":[4],"stride":[1],"data":[0,1,1,1],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0_grad":{"shape":[4],"stride":[1],"data":[0,1,1,1],"data_type":"f32","ptr_offset":0}} 2 | -------------------------------------------------------------------------------- /test_data_json/rnn_fwd_bkwd_single_seq_len_3.json: -------------------------------------------------------------------------------- 1 | 
{"rnn.weight_ih_l0":{"shape":[4,2],"stride":[2,1],"data":[-0.003743410110473633,0.26822179555892944,-0.4115225672721863,-0.3679695129394531,-0.19257718324661255,0.13407868146896362,-0.009906589984893799,0.39644473791122437],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0":{"shape":[4,4],"stride":[4,1],"data":[-0.04437202215194702,0.1323062777519226,-0.15110653638839722,-0.09828269481658936,-0.4776742458343506,-0.33114105463027954,-0.20611155033111572,0.018521785736083984,0.19766759872436523,0.30001139640808105,-0.3389705419540405,-0.21773141622543335,0.18160855770111084,0.41519397497177124,-0.10290008783340454,0.3741558790206909],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0":{"shape":[4],"stride":[1],"data":[-0.08059167861938477,0.05290704965591431,0.45273810625076294,-0.4638351798057556],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0":{"shape":[4],"stride":[1],"data":[-0.3147689700126648,-0.12658262252807617,-0.19489997625350952,0.43200039863586426],"data_type":"f32","ptr_offset":0},"input":{"shape":[3,1,2],"stride":[2,2,1],"data":[-0.5663174986839294,0.3731146454811096,-0.8919953107833862,-1.5091077089309692,0.3703935444355011,1.4565025568008423],"data_type":"f32","ptr_offset":0},"output":{"shape":[3,1,4],"stride":[4,4,1],"data":[0,0.022082045674324036,0.4169246554374695,0.12169483304023743,0,0.7577149271965027,0.06607982516288757,0,0.08418391644954681,0,0.586718738079071,0.8497177958488464],"data_type":"f32","ptr_offset":0},"input_grad":{"shape":[3,1,2],"stride":[2,2,1],"data":[-0.3120531439781189,0.2516170144081116,-0.8386760354042053,-0.6252549290657043,-0.20622718334197998,0.7987452149391174],"data_type":"f32","ptr_offset":0},"rnn.weight_ih_l0_grad":{"shape":[4,2],"stride":[2,1],"data":[0.3703935444355011,1.4565025568008423,-1.936977505683899,-2.597684383392334,-0.26520225405693054,1.0218181610107422,-0.16511490941047668,1.8093187808990479],"data_type":"f32","ptr_offset":0},"rnn.weight_hh_l0_grad":{"shape":[4,4],"stride":[4,1],"data":[0,0.7577149271965027,0.06607982516288757,0,0,0.04079683497548103,0.7702731490135193,0.22483262419700623,0,0.7667028307914734,0.23577767610549927,0.04953257739543915,0,0.7577149271965027,0.06607982516288757,0],"data_type":"f32","ptr_offset":0},"rnn.bias_ih_l0_grad":{"shape":[4],"stride":[1],"data":[1,2.3578362464904785,1.8882606029510498,1.945597529411316],"data_type":"f32","ptr_offset":0},"rnn.bias_hh_l0_grad":{"shape":[4],"stride":[1],"data":[1,2.3578362464904785,1.8882606029510498,1.945597529411316],"data_type":"f32","ptr_offset":0}} 2 | -------------------------------------------------------------------------------- /zenu-autograd/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-autograd" 3 | version = "0.1.2" 4 | edition = "2021" 5 | description = "A simple autograd library for learning purposes" 6 | license = "MIT" 7 | repository = "https://github.com/bokutotu/zenu" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | zenu-matrix = { path = "../zenu-matrix", version = "0.1.2" } 13 | rand = "0.8.5" 14 | rand_distr = "0.4.3" 15 | lazy_static = "1.4.0" 16 | serde = { version = "1.0.197", features = ["derive"] } 17 | once_cell = "1.19.0" 18 | 19 | [features] 20 | nvidia = ["zenu-matrix/nvidia"] 21 | 22 | [dev-dependencies] 23 | criterion = "0.5.1" 24 | zenu-test = { path = "../zenu-test"} 25 | serde_json = "1.0.114" 26 | 27 | [[bench]] 28 | name = "conv2d" 29 | harness = false 30 | 31 | 
[profile.bench] 32 | debug = true 33 | 34 | 35 | -------------------------------------------------------------------------------- /zenu-autograd/README.md: -------------------------------------------------------------------------------- 1 | # ZeNu Autograd 2 | 3 | ZeNu Autograd is an automatic differentiation library for Rust. It provides the foundation for building and training neural networks by automatically computing gradients of mathematical expressions. 4 | 5 | ## Features 6 | 7 | - Define and manipulate mathematical expressions using Variables 8 | - Automatically compute gradients through reverse-mode automatic differentiation 9 | - Support for various mathematical operations and functions 10 | - Integration with ZeNu deep learning library 11 | 12 | ## Getting Started 13 | 14 | To use ZeNu Autograd in your Rust project, add the following to your `Cargo.toml` file: 15 | 16 | ```toml 17 | [dependencies] 18 | zenu-autograd = "0.1.0" 19 | ``` 20 | 21 | Here's a simple example of using ZeNu Autograd: 22 | 23 | ```rust 24 | use zenu_autograd::{Variable, creator::from_vec::from_vec}; 25 | 26 | fn main() { 27 | let x = from_vec(vec![1., 2., 3., 4., 5., 6.], [3, 2]); 28 | let y = from_vec(vec![7., 8., 9., 10., 11., 12.], [3, 2]); 29 | let z = x.clone() * y.clone() + y.clone(); 30 | 31 | z.backward(); 32 | 33 | let x_grad = x.get_grad().unwrap(); 34 | let y_grad = y.get_grad().unwrap(); 35 | 36 | // Perform further computations with the gradients 37 | } 38 | ``` 39 | 40 | For more details and examples, please refer to the [documentation](https://docs.rs/zenu-autograd). 41 | 42 | ## License 43 | 44 | ZeNu Autograd is licensed under the [MIT License](LICENSE). 45 | -------------------------------------------------------------------------------- /zenu-autograd/benches/conv2d.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate criterion; 3 | 4 | use criterion::black_box; 5 | use criterion::Criterion; 6 | use zenu_autograd::creator::ones::ones; 7 | use zenu_autograd::functions::conv2d::conv2d; 8 | use zenu_autograd::Variable; 9 | 10 | fn conv2d_bench(kernel: Variable, input: Variable) { 11 | let _ = conv2d(input, kernel, None, (1, 1), (0, 0)); 12 | } 13 | 14 | fn conv2d_bench_no_bias(c: &mut Criterion) { 15 | let kernel = black_box(ones([32, 16, 3, 3])); 16 | let input = black_box(ones([32, 16, 128, 128])); 17 | 18 | c.bench_function("conv2d_bech_no_bias", |b| { 19 | b.iter(|| conv2d_bench(kernel.clone(), input.clone())) 20 | }); 21 | } 22 | 23 | criterion_group!(benches, conv2d_bench_no_bias); 24 | criterion_main!(benches); 25 | -------------------------------------------------------------------------------- /zenu-autograd/src/activation/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod relu; 2 | pub mod sigmoid; 3 | -------------------------------------------------------------------------------- /zenu-autograd/src/activation/relu.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{ 4 | device::Device, 5 | dim::DimDyn, 6 | matrix::{Matrix, Owned}, 7 | num::Num, 8 | }; 9 | 10 | use crate::{creator::alloc::alloc, Function, Variable, VariableWeak}; 11 | 12 | struct Relu { 13 | input: Variable, 14 | output: VariableWeak, 15 | } 16 | 17 | impl Relu { 18 | pub fn new(input: Variable, output: Variable) -> Self { 19 | let output = output.downgrade(); 20 | Self { input, output } 21 | } 22 | } 
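// Relu mirrors the pattern every operator in this crate follows: the struct owns its input Variable and keeps only a VariableWeak to its output, so the strong references run output -> creator -> input and no Rc cycle is formed. forward() fills the preallocated output in place; backward() composes the gradient out of ordinary Variable ops (here, output_grad times a 0/1 mask), so the gradient computation is itself recorded in the autograd graph.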
23 | 24 | impl<T: Num, D: Device> Function<T, D> for Relu<T, D> { 25 | fn forward(&self) { 26 | let input = self.input.get_data(); 27 | let output = self.output.upgrade().unwrap(); 28 | let mut output = output.get_data_mut(); 29 | let output_view_mut = output.to_ref_mut(); 30 | output_view_mut.relu(&input.to_ref(), T::zero()); 31 | } 32 | 33 | fn backward(&self) { 34 | // A separate scope is needed here because of reference counting: the borrows must be dropped before set_grad runs 35 | // TODO: support higher-order differentiation 36 | let input_grad = { 37 | let input = self.input.get_data(); 38 | let output = self.output.upgrade().unwrap(); 39 | let output_grad = output.get_grad().unwrap(); 40 | let mut mask: Matrix<Owned<T>, DimDyn, D> = Matrix::alloc(input.shape()); 41 | mask.to_ref_mut() 42 | .relu_backward_mask(&input.to_ref(), T::zero()); 43 | let mask = Variable::from(mask); 44 | output_grad * mask 45 | }; 46 | self.input.set_grad(input_grad); 47 | } 48 | 49 | fn get_inputs(&self) -> Vec<Variable<T, D>> { 50 | vec![self.input.clone()] 51 | } 52 | } 53 | 54 | #[must_use] 55 | pub fn relu<T: Num, D: Device>(input: Variable<T, D>) -> Variable<T, D> { 56 | let output = alloc(input.get_shape()); 57 | let relu = Relu::new(input, output.clone()); 58 | relu.forward(); 59 | output.set_creator(Rc::new(RefCell::new(Box::new(relu)))); 60 | output 61 | } 62 | 63 | #[cfg(test)] 64 | mod relu_test { 65 | 66 | use zenu_matrix::{ 67 | device::Device, 68 | dim::DimDyn, 69 | matrix::{Matrix, Owned}, 70 | }; 71 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 72 | 73 | use crate::Variable; 74 | 75 | use super::relu; 76 | 77 | fn relu_1d<D: Device>() { 78 | let x: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![-1., 0., 2., 3.], [2, 2]); 79 | let x_v = Variable::from(x); 80 | let y = relu(x_v.clone()); 81 | y.backward(); 82 | let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![0., 0., 2., 3.], [2, 2]); 83 | let x_grad = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![0., 0., 1., 1.], [2, 2]); 84 | assert_val_eq!(y, ans, 1.0e-6); 85 | assert_val_eq_grad!(x_v, x_grad, 1.0e-6); 86 | } 87 | run_test!(relu_1d, relu_1d_cpu, relu_1d_nvidia); 88 | } 89 | -------------------------------------------------------------------------------- /zenu-autograd/src/activation/sigmoid.rs: -------------------------------------------------------------------------------- 1 | use zenu_matrix::{device::Device, num::Num}; 2 | 3 | use crate::{functions::tanh::tanh, Variable}; 4 | 5 | #[expect(clippy::needless_pass_by_value)] 6 | #[must_use] 7 | pub fn sigmoid<T: Num, D: Device>(x: Variable<T, D>) -> Variable<T, D> { 8 | let one = T::one(); 9 | let two = one + one; 10 | let half = one / two; 11 | let half = Variable::from(half); 12 | let x_half = x.clone() * half.clone(); 13 | let x_half_tanh = tanh(x_half); 14 | half.clone() + half.clone() * x_half_tanh 15 | } 16 | 17 | #[expect(clippy::unreadable_literal, clippy::excessive_precision)] 18 | #[cfg(test)] 19 | mod sigmoid { 20 | use zenu_test::run_test; 21 | 22 | use crate::creator::from_vec::from_vec; 23 | 24 | use super::*; 25 | 26 | fn test_sigmoid<D: Device>() { 27 | let x = Variable::<f32, D>::from(0.0); 28 | let y = sigmoid(x); 29 | assert!((y.get_data().index_item([]) - 0.5).abs() < 1e-6); 30 | } 31 | run_test!(test_sigmoid, test_sigmoid_cpu, test_sigmoid_nvidia); 32 | 33 | fn test_sigmoid_05<D: Device>() { 34 | let x = Variable::<f32, D>::from(0.5); 35 | let y = sigmoid(x); 36 | assert!((y.get_data().index_item([]) - 0.62245935).abs() < 1e-6); 37 | } 38 | run_test!(test_sigmoid_05, test_sigmoid_05_cpu, test_sigmoid_05_nvidia); 39 | 40 | fn test_sigmoid_01<D: Device>() { 41 | let x = Variable::<f32, D>::from(0.1); 42 | let y = sigmoid(x); 43 | assert!((y.get_data().index_item([]) - 0.52497919).abs() < 1e-6); 44 | } 45 | run_test!(test_sigmoid_01, test_sigmoid_01_cpu, test_sigmoid_01_nvidia); 
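// The expected constants in these tests follow from the identity the implementation is built on, sigmoid(x) = 0.5 + 0.5 * tanh(x / 2): sigmoid(0.5) = 0.62245935 and sigmoid(0.1) = 0.52497919 to f32 precision.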
46 | 47 | fn sigmoid_1d<D: Device>() { 48 | let x: Variable<f32, D> = from_vec(vec![0.0, 0.5, 0.1], [3]); 49 | let y = sigmoid(x); 50 | assert!((y.get_data().index_item([0]) - 0.5).abs() < 1e-6); 51 | assert!((y.get_data().index_item([1]) - 0.62245935).abs() < 1e-6); 52 | assert!((y.get_data().index_item([2]) - 0.52497919).abs() < 1e-6); 53 | } 54 | run_test!(sigmoid_1d, sigmoid_1d_cpu, sigmoid_1d_nvidia); 55 | 56 | fn sigmoid_2d<D: Device>() { 57 | let x: Variable<f32, D> = from_vec(vec![0.0, 0.5, 0.1, 0.2, 0.3, 0.4], [2, 3]); 58 | let y = sigmoid(x); 59 | assert!((y.get_data().index_item([0, 0]) - 0.5).abs() < 1e-6); 60 | assert!((y.get_data().index_item([0, 1]) - 0.62245935).abs() < 1e-6); 61 | assert!((y.get_data().index_item([0, 2]) - 0.52497919).abs() < 1e-6); 62 | assert!((y.get_data().index_item([1, 0]) - 0.54983399).abs() < 1e-6); 63 | assert!((y.get_data().index_item([1, 1]) - 0.57444252).abs() < 1e-6); 64 | assert!((y.get_data().index_item([1, 2]) - 0.59868766).abs() < 1e-6); 65 | } 66 | run_test!(sigmoid_2d, sigmoid_2d_cpu, sigmoid_2d_nvidia); 67 | } 68 | -------------------------------------------------------------------------------- /zenu-autograd/src/creator/alloc.rs: -------------------------------------------------------------------------------- 1 | use zenu_matrix::{ 2 | device::Device, 3 | dim::DimDyn, 4 | matrix::{Matrix, Owned}, 5 | num::Num, 6 | }; 7 | 8 | use crate::Variable; 9 | 10 | pub fn alloc<T: Num, I: Into<DimDyn>, D: Device>(shape: I) -> Variable<T, D> { 11 | let matrix = Matrix::<Owned<T>, DimDyn, D>::alloc(shape); 12 | Variable::new(matrix) 13 | } 14 | 15 | #[must_use] 16 | #[expect(clippy::module_name_repetitions)] 17 | pub fn alloc_like<T: Num, D: Device>(a: &Variable<T, D>) -> Variable<T, D> { 18 | alloc(a.get_data().shape()) 19 | } 20 | -------------------------------------------------------------------------------- /zenu-autograd/src/creator/from_vec.rs: -------------------------------------------------------------------------------- 1 | use zenu_matrix::{ 2 | device::Device, 3 | dim::DimDyn, 4 | matrix::{Matrix, Owned}, 5 | num::Num, 6 | }; 7 | 8 | use crate::Variable; 9 | 10 | pub fn from_vec<T: Num, I: Into<DimDyn>, D: Device>(vec: Vec<T>, dim: I) -> Variable<T, D> { 11 | let matrix = Matrix::<Owned<T>, DimDyn, D>::from_vec(vec, dim.into()); 12 | Variable::new(matrix) 13 | } 14 | -------------------------------------------------------------------------------- /zenu-autograd/src/creator/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod alloc; 2 | pub mod from_vec; 3 | pub mod ones; 4 | pub mod rand; 5 | pub mod zeros; 6 | -------------------------------------------------------------------------------- /zenu-autograd/src/creator/ones.rs: -------------------------------------------------------------------------------- 1 | use zenu_matrix::{ 2 | device::Device, 3 | dim::DimDyn, 4 | matrix::{Matrix, Owned}, 5 | num::Num, 6 | }; 7 | 8 | use crate::Variable; 9 | 10 | pub fn ones<T: Num, I: Into<DimDyn>, D: Device>(dim: I) -> Variable<T, D> { 11 | let matrix = Matrix::<Owned<T>, DimDyn, D>::ones(dim.into()); 12 | Variable::new(matrix) 13 | } 14 | 15 | #[expect(clippy::module_name_repetitions)] 16 | #[must_use] 17 | pub fn ones_like<T: Num, D: Device>(a: &Variable<T, D>) -> Variable<T, D> { 18 | ones(a.get_data().shape()) 19 | } 20 | -------------------------------------------------------------------------------- /zenu-autograd/src/creator/rand.rs: -------------------------------------------------------------------------------- 1 | use rand_distr::{Distribution, StandardNormal}; 2 | use zenu_matrix::{ 3 | constructor::rand::{NormalBuilder, UniformBuilder}, 4 | device::Device, 5 | dim::{DimDyn, DimTrait}, 6 | num::Num, 7 | }; 8 | 9 | use crate::Variable; 10 | 11 | pub fn uniform( 12 | low: 
T, 13 | high: T, 14 | seed: Option, 15 | shape: S, 16 | ) -> Variable { 17 | let mut builder = UniformBuilder::new().low(low).high(high).shape(shape); 18 | if let Some(seed) = seed { 19 | builder = builder.seed(seed); 20 | } 21 | let matrix = builder.build(); 22 | let matrix = matrix.into_dyn_dim(); 23 | Variable::from(matrix) 24 | } 25 | 26 | pub fn uniform_like( 27 | a: &Variable, 28 | low: T, 29 | high: T, 30 | seed: Option, 31 | ) -> Variable { 32 | uniform(low, high, seed, a.get_data().shape()) 33 | } 34 | 35 | pub fn normal(mean: T, std_dev: T, seed: Option, shape: I) -> Variable 36 | where 37 | T: Num, 38 | I: Into, 39 | StandardNormal: Distribution, 40 | D: Device, 41 | { 42 | let mut builder = NormalBuilder::new() 43 | .std_dev(std_dev) 44 | .mean(mean) 45 | .shape(shape.into()); 46 | if let Some(seed) = seed { 47 | builder = builder.seed(seed); 48 | } 49 | let matrix = builder.build(); 50 | let matrix = matrix.into_dyn_dim(); 51 | Variable::from(matrix) 52 | } 53 | 54 | pub fn normal_like( 55 | a: &Variable, 56 | mean: T, 57 | std_dev: T, 58 | seed: Option, 59 | ) -> Variable 60 | where 61 | StandardNormal: Distribution, 62 | { 63 | normal(mean, std_dev, seed, a.get_data().shape()) 64 | } 65 | -------------------------------------------------------------------------------- /zenu-autograd/src/creator/zeros.rs: -------------------------------------------------------------------------------- 1 | use zenu_matrix::{ 2 | device::Device, 3 | dim::DimDyn, 4 | matrix::{Matrix, Owned}, 5 | num::Num, 6 | }; 7 | 8 | use crate::Variable; 9 | 10 | pub fn zeros, D: Device>(dim: I) -> Variable { 11 | let matrix = Matrix::, DimDyn, D>::zeros(dim.into()); 12 | Variable::new(matrix) 13 | } 14 | 15 | #[expect(clippy::module_name_repetitions)] 16 | #[must_use] 17 | pub fn zeros_like(a: &Variable) -> Variable { 18 | zeros(a.get_data().shape()) 19 | } 20 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/broadcast.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, dim::DimDyn, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc, Function, Variable, VariableWeak}; 6 | 7 | use super::sum_to::sum_to; 8 | 9 | struct Broadcast { 10 | x: Variable, 11 | output: VariableWeak, 12 | } 13 | 14 | impl Broadcast { 15 | pub fn new(x: Variable, output: Variable) -> Self { 16 | let output = output.downgrade(); 17 | Self { x, output } 18 | } 19 | } 20 | 21 | impl Function for Broadcast { 22 | fn forward(&self) { 23 | let output = self.output.upgrade().unwrap(); 24 | let mut output = output.get_data_mut(); 25 | let x = self.x.get_data(); 26 | let x = x.to_ref(); 27 | let output = output.to_ref_mut(); 28 | output.broadcast(&x); 29 | } 30 | 31 | fn backward(&self) { 32 | let x_shape = self.x.get_data().shape(); 33 | let output = self.output.upgrade().unwrap(); 34 | let output_grad = output.get_grad().unwrap(); 35 | let x_grad = sum_to(output_grad, x_shape); 36 | self.x.set_grad(x_grad); 37 | } 38 | 39 | fn get_inputs(&self) -> Vec> { 40 | vec![self.x.clone()] 41 | } 42 | } 43 | 44 | #[must_use] 45 | pub fn broadcast(x: Variable, shape: DimDyn) -> Variable { 46 | let output = alloc(shape); 47 | let broadcast = Broadcast::new(x, output.clone()); 48 | broadcast.forward(); 49 | output.set_creator(Rc::new(RefCell::new(Box::new(broadcast)))); 50 | output 51 | } 52 | 53 | #[cfg(test)] 54 | mod broadcast { 55 | use zenu_matrix::{ 56 | device::Device, 57 | 
dim::DimDyn, 58 | matrix::{Matrix, Owned}, 59 | }; 60 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 61 | 62 | use crate::Variable; 63 | 64 | use super::broadcast; 65 | 66 | fn broadcast_2d_1d() { 67 | let x: Matrix, DimDyn, D> = Matrix::from_vec(vec![1.0, 2.0, 3.0], [3]); 68 | let x = Variable::from(x); 69 | let y = broadcast(x.clone(), DimDyn::new(&[3, 3])); 70 | let forward_ans: Matrix, DimDyn, D> = 71 | Matrix::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], [3, 3]); 72 | y.backward(); 73 | assert_val_eq!(y, forward_ans, 1e-6); 74 | assert_val_eq_grad!( 75 | x, 76 | Matrix::<_, DimDyn, _>::from_vec(vec![3.0, 3.0, 3.0], [3]), 77 | 1e-6 78 | ); 79 | } 80 | run_test!(broadcast_2d_1d, broadcast_2d_1d_cpu, broadcast_2d_1d_nvidia); 81 | 82 | fn broadcast_4d_2d() { 83 | let x: Matrix, DimDyn, D> = Matrix::from_vec(vec![1.0, 2.0], [1, 2]); 84 | let x = Variable::from(x); 85 | let y = broadcast(x.clone(), DimDyn::new(&[2, 3, 1, 2])); 86 | let forward_ans: Matrix, DimDyn, D> = Matrix::from_vec( 87 | vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0], 88 | [2, 3, 1, 2], 89 | ); 90 | 91 | y.backward(); 92 | let backward_ans: Matrix, DimDyn, D> = Matrix::from_vec(vec![6.0, 6.0], [1, 2]); 93 | assert_val_eq!(y, forward_ans, 1e-6); 94 | assert_val_eq_grad!(x, backward_ans, 1e-6); 95 | } 96 | run_test!(broadcast_4d_2d, broadcast_4d_2d_cpu, broadcast_4d_2d_nvidia); 97 | } 98 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/clip.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc, Function, Variable, VariableWeak}; 6 | 7 | struct Clip { 8 | min: T, 9 | max: T, 10 | input: Variable, 11 | output: VariableWeak, 12 | } 13 | 14 | impl Clip { 15 | pub fn new(min: T, max: T, input: Variable, output: Variable) -> Self { 16 | assert_eq!( 17 | input.get_data().shape(), 18 | output.get_data().shape(), 19 | "input.shape() != output.shape()" 20 | ); 21 | let output = output.downgrade(); 22 | Self { 23 | min, 24 | max, 25 | input, 26 | output, 27 | } 28 | } 29 | } 30 | 31 | impl Function for Clip { 32 | fn backward(&self) { 33 | let output_grad = self.output.upgrade().unwrap().get_grad().clone().unwrap(); 34 | let clip_filter = self 35 | .input 36 | .get_data() 37 | .to_ref() 38 | .clip_backward_mask(self.min, self.max); 39 | let clip_filter = Variable::from(clip_filter); 40 | let input_grad = output_grad * clip_filter; 41 | self.input.set_grad(input_grad); 42 | } 43 | 44 | fn forward(&self) { 45 | let input = self.input.get_data(); 46 | let output = input.clip(self.min, self.max); 47 | self.output 48 | .upgrade() 49 | .unwrap() 50 | .get_data_mut() 51 | .to_ref_mut() 52 | .copy_from(&output.to_ref()); 53 | } 54 | 55 | fn get_inputs(&self) -> Vec> { 56 | vec![self.input.clone()] 57 | } 58 | } 59 | 60 | pub fn clip(input: Variable, min: T, max: T) -> Variable { 61 | let output = alloc(input.get_shape()); 62 | let clip = Clip::new(min, max, input, output.clone()); 63 | clip.forward(); 64 | output.set_creator(Rc::new(RefCell::new(Box::new(clip)))); 65 | output 66 | } 67 | 68 | #[cfg(test)] 69 | mod clip { 70 | use zenu_matrix::{ 71 | device::Device, 72 | dim::DimDyn, 73 | matrix::{Matrix, Owned}, 74 | }; 75 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 76 | 77 | use crate::creator::from_vec::from_vec; 78 | 79 | fn clip_1d() { 80 | let 
input = from_vec(vec![1., 2., 3., 4., 5., 6.], [6]); 81 | let output = super::clip(input.clone(), 2.0, 4.0); 82 | output.backward(); 83 | let ans: Matrix, DimDyn, D> = 84 | Matrix::from_vec(vec![2., 2., 3., 4., 4., 4.], [6]); 85 | let grad_ans: Matrix, DimDyn, D> = 86 | Matrix::from_vec(vec![0., 1., 1., 1., 0., 0.], [6]); 87 | assert_val_eq!(output, ans, 1.0e-6); 88 | assert_val_eq_grad!(input, grad_ans, 1.0e-6); 89 | } 90 | run_test!(clip_1d, clip_1d_cpu, clip_1d_nvidia); 91 | } 92 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/div.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, ops::Div, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc, Function, Variable, VariableWeak}; 6 | 7 | struct DivFunc { 8 | x: Variable, 9 | y: Variable, 10 | output: VariableWeak, 11 | } 12 | 13 | impl DivFunc { 14 | pub fn new(x: Variable, y: Variable, output: Variable) -> Self { 15 | let output = output.downgrade(); 16 | Self { x, y, output } 17 | } 18 | } 19 | 20 | impl Function for DivFunc { 21 | fn forward(&self) { 22 | let x = self.x.get_data(); 23 | let y = self.y.get_data(); 24 | let output = x.to_ref() / y.to_ref(); 25 | self.output 26 | .upgrade() 27 | .unwrap() 28 | .get_data_mut() 29 | .to_ref_mut() 30 | .copy_from(&output.to_ref()); 31 | } 32 | 33 | fn backward(&self) { 34 | let output_grad = self.output.upgrade().unwrap().get_grad().clone().unwrap(); 35 | let x_grad = output_grad.clone() / self.y.clone(); 36 | let y_grad = output_grad * self.x.clone() / (self.y.clone() * self.y.clone()) 37 | * Variable::from(T::minus_one()); 38 | self.x.set_grad(x_grad); 39 | self.y.set_grad(y_grad); 40 | } 41 | 42 | fn get_inputs(&self) -> Vec> { 43 | vec![self.x.clone(), self.y.clone()] 44 | } 45 | } 46 | 47 | pub fn div(x: Variable, y: Variable) -> Variable { 48 | let output_shape = super::output_shape(&x, &y); 49 | let output = alloc(output_shape); 50 | let div = DivFunc::new(x, y, output.clone()); 51 | div.forward(); 52 | output.set_creator(Rc::new(RefCell::new(Box::new(div)))); 53 | output 54 | } 55 | 56 | impl Div> for Variable { 57 | type Output = Variable; 58 | 59 | fn div(self, rhs: Variable) -> Self::Output { 60 | div(self, rhs) 61 | } 62 | } 63 | 64 | #[cfg(test)] 65 | mod div { 66 | 67 | use zenu_matrix::{ 68 | device::Device, 69 | dim::DimDyn, 70 | matrix::{Matrix, Owned}, 71 | }; 72 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 73 | 74 | use crate::{creator::from_vec::from_vec, Variable}; 75 | 76 | fn div_2d() { 77 | let a: Variable = from_vec(vec![1f64, 2., 3., 4., 5., 6.], [2, 3]); 78 | let b: Variable = from_vec(vec![6., 7., 8., 9., 10., 11.], [2, 3]); 79 | let c = a.clone() / b.clone(); 80 | c.backward(); 81 | let ans = Matrix::, DimDyn, D>::from_vec( 82 | vec![1. / 6., 2. / 7., 3. / 8., 4. / 9., 5. / 10., 6. 
/ 11.], 83 | [2, 3], 84 | ); 85 | assert_val_eq!(c, ans, 1e-6); 86 | let a_grad: Matrix, DimDyn, D> = 87 | Matrix::from_vec(vec![0.1667, 0.1429, 0.1250, 0.1111, 0.1000, 0.0909], [2, 3]); 88 | let b_grad: Matrix, DimDyn, D> = Matrix::from_vec( 89 | vec![-0.0278, -0.0408, -0.0469, -0.0494, -0.0500, -0.0496], 90 | [2, 3], 91 | ); 92 | assert_val_eq_grad!(a, a_grad, 1e-4); 93 | assert_val_eq_grad!(b, b_grad, 1e-4); 94 | } 95 | run_test!(div_2d, div_2d_cpu, div_2d_nvidia); 96 | } 97 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/exp.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc_like, Function, Variable, VariableWeak}; 6 | 7 | struct Exp { 8 | input: Variable, 9 | output: VariableWeak, 10 | } 11 | 12 | impl Exp { 13 | pub fn new(input: Variable, output: Variable) -> Self { 14 | let output = output.downgrade(); 15 | Self { input, output } 16 | } 17 | } 18 | 19 | impl Function for Exp { 20 | fn forward(&self) { 21 | let input = self.input.get_data(); 22 | let output = self.output.upgrade().unwrap(); 23 | output.get_data_mut().to_ref_mut().exp_array(&input); 24 | } 25 | 26 | fn backward(&self) { 27 | let output = self.output.upgrade().unwrap(); 28 | let output_grad = output.get_grad().unwrap(); 29 | self.input.set_grad(output * output_grad); 30 | } 31 | 32 | fn get_inputs(&self) -> Vec> { 33 | vec![self.input.clone()] 34 | } 35 | } 36 | 37 | #[must_use] 38 | pub fn exp(input: Variable) -> Variable { 39 | let output = alloc_like(&input); 40 | let exp = Exp::new(input, output.clone()); 41 | exp.forward(); 42 | output.set_creator(Rc::new(RefCell::new(Box::new(exp)))); 43 | output 44 | } 45 | 46 | #[cfg(test)] 47 | mod exp { 48 | use zenu_matrix::{ 49 | device::Device, 50 | dim::DimDyn, 51 | matrix::{Matrix, Owned}, 52 | }; 53 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 54 | 55 | use crate::creator::from_vec::from_vec; 56 | 57 | use super::exp; 58 | 59 | #[expect(clippy::unreadable_literal)] 60 | fn exp_1d() { 61 | let x = from_vec(vec![1., 2., 3.], [3]); 62 | let exp = exp(x.clone()); 63 | exp.backward(); 64 | let exp_ans: Matrix, DimDyn, D> = 65 | Matrix::from_vec(vec![2.7182817, 7.389056, 20.085537], [3]); 66 | let x_grad: Matrix, DimDyn, D> = 67 | Matrix::from_vec(vec![1_f64.exp(), 2_f64.exp(), 3_f64.exp()], [3]); 68 | assert_val_eq!(exp, exp_ans, 1e-6); 69 | assert_val_eq_grad!(x, x_grad, 1e-6); 70 | } 71 | run_test!(exp_1d, exp_1d_cpu, exp_1d_nvidia); 72 | } 73 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/flatten.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, dim::DimTrait, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc, Function, Variable, VariableWeak}; 6 | 7 | use super::reshape::reshape; 8 | 9 | struct Flatten { 10 | input: Variable, 11 | output: VariableWeak, 12 | } 13 | 14 | impl Flatten { 15 | fn new(input: Variable, output: Variable) -> Self { 16 | let output = output.downgrade(); 17 | Self { input, output } 18 | } 19 | } 20 | 21 | impl Function for Flatten { 22 | fn forward(&self) { 23 | let output_shape = self.output.upgrade().unwrap().get_data().shape(); 24 | let input_mat = self.input.get_data(); 25 | self.output 26 | .upgrade() 27 | .unwrap() 
28 | .get_data_mut() 29 | .to_ref_mut() 30 | .copy_from(&input_mat.reshape(output_shape.slice())); 31 | } 32 | 33 | fn backward(&self) { 34 | let output_grad = self.output.upgrade().unwrap().get_grad().unwrap(); 35 | let input_shape = self.input.get_shape(); 36 | let input_grad = reshape(output_grad, input_shape.slice()); 37 | self.input.set_grad(input_grad); 38 | } 39 | 40 | fn get_inputs(&self) -> Vec> { 41 | vec![self.input.clone()] 42 | } 43 | } 44 | 45 | #[must_use] 46 | pub fn flatten(input: Variable) -> Variable { 47 | let input_shape = input.get_data().shape(); 48 | let batch_size = input_shape[0]; 49 | let num_elm = input_shape.num_elm(); 50 | let output_shape = [batch_size, num_elm / batch_size]; 51 | let output = alloc(output_shape); 52 | let flatten = Flatten::new(input, output.clone()); 53 | flatten.forward(); 54 | output.set_creator(Rc::new(RefCell::new(Box::new(flatten)))); 55 | output 56 | } 57 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/log.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc_like, Function, Variable, VariableWeak}; 6 | 7 | struct Log { 8 | input: Variable, 9 | output: VariableWeak, 10 | } 11 | 12 | impl Log { 13 | pub fn new(input: Variable, output: Variable) -> Self { 14 | assert_eq!( 15 | input.get_data().shape(), 16 | output.get_data().shape(), 17 | "input and output shape must be same" 18 | ); 19 | let output = output.downgrade(); 20 | Self { input, output } 21 | } 22 | } 23 | 24 | impl Function for Log { 25 | fn forward(&self) { 26 | let input = self.input.get_data(); 27 | self.output 28 | .upgrade() 29 | .unwrap() 30 | .get_data_mut() 31 | .to_ref_mut() 32 | .log_array(&input.to_ref()); 33 | } 34 | 35 | fn backward(&self) { 36 | let output = self.output.upgrade().unwrap().get_grad().unwrap(); 37 | self.input.set_grad(output / self.input.clone()); 38 | } 39 | 40 | fn get_inputs(&self) -> Vec> { 41 | vec![self.input.clone()] 42 | } 43 | } 44 | 45 | #[must_use] 46 | pub fn log(x: Variable) -> Variable { 47 | let output = alloc_like(&x); 48 | let log = Log::new(x, output.clone()); 49 | log.forward(); 50 | output.set_creator(Rc::new(RefCell::new(Box::new(log)))); 51 | output 52 | } 53 | 54 | #[cfg(test)] 55 | mod log { 56 | 57 | use zenu_matrix::{ 58 | device::Device, 59 | dim::DimDyn, 60 | matrix::{Matrix, Owned}, 61 | }; 62 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 63 | 64 | use crate::creator::from_vec::from_vec; 65 | 66 | use super::log; 67 | 68 | #[expect(clippy::unreadable_literal, clippy::approx_constant)] 69 | fn log_1d() { 70 | let x = from_vec(vec![1., 2., 3., 4.], [4]); 71 | let y = log(x.clone()); 72 | y.backward(); 73 | let forward_ans: Matrix, DimDyn, D> = Matrix::from_vec( 74 | vec![ 75 | 0., 76 | 0.6931471805599453, 77 | 1.0986122886681098, 78 | 1.3862943611198906, 79 | ], 80 | [4], 81 | ); 82 | let x_grad: Matrix, DimDyn, D> = 83 | Matrix::from_vec(vec![1., 0.5, 1. 
/ 3., 0.25], [4]); 84 | assert_val_eq!(y, forward_ans, 1e-7); 85 | assert_val_eq_grad!(x, x_grad, 1e-7); 86 | } 87 | run_test!(log_1d, log_1d_cpu, log_1d_gpu); 88 | } 89 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/mod.rs: -------------------------------------------------------------------------------- 1 | use zenu_matrix::{ 2 | device::Device, 3 | dim::{larger_shape, DimDyn}, 4 | num::Num, 5 | }; 6 | 7 | use crate::Variable; 8 | 9 | mod add; 10 | mod div; 11 | mod mul; 12 | mod sub; 13 | 14 | pub mod broadcast; 15 | pub mod clip; 16 | pub mod cosh; 17 | pub mod exp; 18 | pub mod flatten; 19 | pub mod index_axis; 20 | pub mod log; 21 | pub mod matmul; 22 | pub mod powf; 23 | pub mod reshape; 24 | pub mod sinh; 25 | pub mod slice; 26 | pub mod stack; 27 | pub mod sum; 28 | pub mod sum_to; 29 | pub mod tanh; 30 | pub mod transpose; 31 | 32 | pub(crate) fn output_shape(x: &Variable, y: &Variable) -> DimDyn { 33 | let x_shape = x.get_data().shape(); 34 | let y_shape = y.get_data().shape(); 35 | larger_shape(x_shape, y_shape) 36 | } 37 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/mul.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, ops::Mul, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc, Function, Variable, VariableWeak}; 6 | 7 | use super::{output_shape, sum_to::sum_to}; 8 | 9 | struct Multiply { 10 | x: Variable, 11 | y: Variable, 12 | output: VariableWeak, 13 | } 14 | 15 | impl Multiply { 16 | pub fn new(x: Variable, y: Variable, output: Variable) -> Self { 17 | let output = output.downgrade(); 18 | Self { x, y, output } 19 | } 20 | } 21 | 22 | impl Function for Multiply { 23 | fn forward(&self) { 24 | let x = self.x.get_data(); 25 | let y = self.y.get_data(); 26 | let x = x.to_ref(); 27 | let y = y.to_ref(); 28 | let output = self.output.upgrade().unwrap(); 29 | let mut output = output.get_data_mut(); 30 | output.to_ref_mut().mul_array(&x, &y); 31 | } 32 | 33 | fn backward(&self) { 34 | let x_shape = self.x.get_data().shape(); 35 | let y_shape = self.y.get_data().shape(); 36 | let output = self.output.upgrade().unwrap(); 37 | let grad = output.get_grad().clone().unwrap(); 38 | let x_grad = grad.clone() * self.y.clone(); 39 | let y_grad = self.x.clone() * grad; 40 | self.x.set_grad(sum_to(x_grad, x_shape)); 41 | self.y.set_grad(sum_to(y_grad, y_shape)); 42 | } 43 | 44 | fn get_inputs(&self) -> Vec> { 45 | vec![self.x.clone(), self.y.clone()] 46 | } 47 | } 48 | 49 | fn mul(x: Variable, y: Variable) -> Variable { 50 | let output_shape = output_shape(&x, &y); 51 | let output = alloc(output_shape); 52 | let mul = Multiply::new(x, y, output.clone()); 53 | mul.forward(); 54 | output.set_creator(Rc::new(RefCell::new(Box::new(mul)))); 55 | output 56 | } 57 | 58 | impl Mul> for Variable { 59 | type Output = Variable; 60 | 61 | fn mul(self, rhs: Variable) -> Self::Output { 62 | mul(self, rhs) 63 | } 64 | } 65 | 66 | #[cfg(test)] 67 | mod mul { 68 | use zenu_matrix::{ 69 | device::Device, 70 | dim::DimDyn, 71 | matrix::{Matrix, Owned}, 72 | }; 73 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 74 | 75 | use crate::Variable; 76 | 77 | fn mul_2d_1d() { 78 | let a_mat: Matrix, DimDyn, D> = 79 | Matrix::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]); 80 | let b_mat: Matrix, DimDyn, D> = Matrix::from_vec(vec![1., 2., 3.], [3]); 81 | let a 
= Variable::new(a_mat); 82 | let b = Variable::new(b_mat); 83 | let c = a.clone() * b.clone(); 84 | c.backward(); 85 | let c_ans: Matrix, DimDyn, D> = 86 | Matrix::from_vec(vec![1., 4., 9., 4., 10., 18.], [2, 3]); 87 | let a_grad_ans: Matrix, DimDyn, D> = 88 | Matrix::from_vec(vec![1., 2., 3., 1., 2., 3.], [2, 3]); 89 | let b_grad_ans: Matrix, DimDyn, D> = Matrix::from_vec(vec![5., 7., 9.], [3]); 90 | assert_val_eq!(c, c_ans, 1e-6); 91 | assert_val_eq_grad!(a, a_grad_ans, 1e-6); 92 | assert_val_eq_grad!(b, b_grad_ans, 1e-6); 93 | } 94 | run_test!(mul_2d_1d, mul_2d_1d_cpu, mul_2d_1d_gpu); 95 | } 96 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/powf.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc_like, Function, Variable, VariableWeak}; 6 | 7 | struct Powf { 8 | input: Variable, 9 | factor: T, 10 | output: VariableWeak, 11 | } 12 | 13 | impl Function for Powf { 14 | fn get_inputs(&self) -> Vec> { 15 | vec![self.input.clone()] 16 | } 17 | 18 | fn forward(&self) { 19 | let x = self.input.get_data(); 20 | let output = self.output.upgrade().unwrap(); 21 | let mut y = output.get_data_mut(); 22 | y.to_ref_mut().powf(&x, self.factor); 23 | } 24 | 25 | fn backward(&self) { 26 | let dx = powf(self.input.clone(), self.factor - T::one()) 27 | * Variable::from(self.factor) 28 | * self.output.upgrade().unwrap().get_grad().unwrap(); 29 | self.input.set_grad(dx); 30 | } 31 | } 32 | 33 | pub fn powf(x: Variable, factor: T) -> Variable { 34 | let output = alloc_like(&x); 35 | let output_weak = output.clone().downgrade(); 36 | let powf = Powf { 37 | input: x, 38 | factor, 39 | output: output_weak, 40 | }; 41 | powf.forward(); 42 | output.set_creator(Rc::new(RefCell::new(Box::new(powf)))); 43 | output 44 | } 45 | 46 | #[cfg(test)] 47 | mod powf { 48 | use zenu_matrix::{ 49 | device::Device, 50 | dim::DimDyn, 51 | matrix::{Matrix, Owned}, 52 | }; 53 | use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test}; 54 | 55 | use crate::{creator::from_vec::from_vec, Variable}; 56 | 57 | use super::powf; 58 | 59 | fn powf_() { 60 | let input: Variable = from_vec(vec![1.0, 2.0, 3.0], [3]); 61 | let output = powf(input.clone(), 2.0); 62 | output.backward(); 63 | let expected = Matrix::, DimDyn, D>::from_vec(vec![1.0, 4.0, 9.0], [3]); 64 | assert_val_eq!(output, expected, 1e-6); 65 | let expected = Matrix::, DimDyn, D>::from_vec(vec![2.0, 4.0, 6.0], [3]); 66 | assert_val_eq_grad!(input, expected, 1e-6); 67 | } 68 | run_test!(powf_, powf_cpu, pow_nvidia); 69 | } 70 | -------------------------------------------------------------------------------- /zenu-autograd/src/functions/reshape.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, rc::Rc}; 2 | 3 | use zenu_matrix::{device::Device, dim::DimTrait, num::Num}; 4 | 5 | use crate::{creator::alloc::alloc, Function, Variable, VariableWeak}; 6 | 7 | struct Reshape { 8 | input: Variable, 9 | output: VariableWeak, 10 | } 11 | 12 | impl Reshape { 13 | fn new(input: Variable, output: Variable) -> Self { 14 | let output = output.downgrade(); 15 | Self { input, output } 16 | } 17 | } 18 | 19 | impl Function for Reshape { 20 | fn forward(&self) { 21 | let output_shape = self.output.upgrade().unwrap().get_data().shape(); 22 | let input_mat = self.input.get_data(); 23 | self.output 24 | 
-------------------------------------------------------------------------------- /zenu-autograd/src/functions/reshape.rs: --------------------------------------------------------------------------------
use std::{cell::RefCell, rc::Rc};

use zenu_matrix::{device::Device, dim::DimTrait, num::Num};

use crate::{creator::alloc::alloc, Function, Variable, VariableWeak};

struct Reshape<T: Num, D: Device> {
    input: Variable<T, D>,
    output: VariableWeak<T, D>,
}

impl<T: Num, D: Device> Reshape<T, D> {
    fn new(input: Variable<T, D>, output: Variable<T, D>) -> Self {
        let output = output.downgrade();
        Self { input, output }
    }
}

impl<T: Num, D: Device> Function<T, D> for Reshape<T, D> {
    fn forward(&self) {
        let output_shape = self.output.upgrade().unwrap().get_data().shape();
        let input_mat = self.input.get_data();
        self.output
            .upgrade()
            .unwrap()
            .get_data_mut()
            .to_ref_mut()
            .copy_from(&input_mat.reshape(output_shape.slice()));
    }

    fn backward(&self) {
        let output_grad = self.output.upgrade().unwrap().get_grad().unwrap();
        self.input.set_grad(reshape(
            output_grad.clone(),
            self.input.get_data().shape().slice(),
        ));
    }

    fn get_inputs(&self) -> Vec<Variable<T, D>> {
        vec![self.input.clone()]
    }
}

#[must_use]
pub fn reshape<T: Num, D: Device>(input: Variable<T, D>, output_shape: &[usize]) -> Variable<T, D> {
    let output = alloc(output_shape);
    let reshape = Reshape::new(input, output.clone());
    reshape.forward();
    output.set_creator(Rc::new(RefCell::new(Box::new(reshape))));
    output
}
-------------------------------------------------------------------------------- /zenu-autograd/src/functions/sub.rs: --------------------------------------------------------------------------------
use std::{cell::RefCell, ops::Sub, rc::Rc};

use zenu_matrix::{device::Device, num::Num};

use crate::{creator::alloc::alloc, Function, Variable, VariableWeak};

use super::output_shape;

pub struct SubFunc<T: Num, D: Device> {
    x: Variable<T, D>,
    y: Variable<T, D>,
    output: VariableWeak<T, D>,
}

impl<T: Num, D: Device> Function<T, D> for SubFunc<T, D> {
    fn forward(&self) {
        let x = self.x.get_data();
        let y = self.y.get_data();
        let output = x.to_ref() - y.to_ref();
        self.output
            .upgrade()
            .unwrap()
            .get_data_mut()
            .to_ref_mut()
            .copy_from(&output.to_ref());
    }

    fn backward(&self) {
        // let output_grad = self.output.get_grad().clone().unwrap();
        let output_grad = self.output.upgrade().unwrap().get_grad().clone().unwrap();
        let x_grad = output_grad.clone();
        let y_grad = output_grad.clone() * Variable::from(T::minus_one());
        self.x.set_grad(x_grad);
        self.y.set_grad(y_grad);
    }

    fn get_inputs(&self) -> Vec<Variable<T, D>> {
        vec![self.x.clone(), self.y.clone()]
    }
}

pub fn sub<T: Num, D: Device>(x: Variable<T, D>, y: Variable<T, D>) -> Variable<T, D> {
    let output_shape = output_shape(&x, &y);
    let output = alloc(output_shape);
    let sub = SubFunc {
        x,
        y,
        output: output.clone().downgrade(),
    };
    sub.forward();
    output.set_creator(Rc::new(RefCell::new(Box::new(sub))));
    output
}

impl<T: Num, D: Device> Sub<Variable<T, D>> for Variable<T, D> {
    type Output = Variable<T, D>;

    fn sub(self, rhs: Variable<T, D>) -> Self::Output {
        sub(self, rhs)
    }
}

#[cfg(test)]
mod tests {
    use zenu_matrix::{
        device::cpu::Cpu,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_val_eq, assert_val_eq_grad};

    use super::*;

    #[test]
    fn sub() {
        let x = Variable::<f32, Cpu>::new(Matrix::from_vec(vec![1., 2., 3.], [3]));
        let y = Variable::new(Matrix::from_vec(vec![1., 2., 3.], [3]));
        let z = x.clone() - y.clone();
        let ans = Matrix::<Owned<f32>, DimDyn, _>::zeros([3]);
        let ones = Matrix::<_, DimDyn, _>::ones([3]);
        let minus_ones = Matrix::<_, DimDyn, _>::from_vec(vec![-1., -1., -1.], [3]);
        assert_val_eq!(z.clone(), ans, 1e-4);
        z.backward();
        assert_val_eq_grad!(x, ones, 1e-4);
        assert_val_eq_grad!(y, minus_ones, 1e-4);
    }
}
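Subtraction has constant local derivatives, $\partial z/\partial x = 1$ and $\partial z/\partial y = -1$, so `SubFunc::backward` forwards the upstream gradient unchanged to `x` and negates it (multiplication by `Variable::from(T::minus_one())`) for `y` — hence the `ones`/`minus_ones` expectations in the test above.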
-------------------------------------------------------------------------------- /zenu-autograd/src/functions/sum_to.rs: --------------------------------------------------------------------------------
use std::{cell::RefCell, rc::Rc};

use zenu_matrix::{device::Device, dim::DimDyn, num::Num, operation::sum::sum_to as sum_to_func};

use crate::{creator::zeros::zeros, Function, Variable, VariableWeak};

use super::broadcast::broadcast;

struct SumTo<T: Num, D: Device> {
    x: Variable<T, D>,
    output: VariableWeak<T, D>,
}

impl<T: Num, D: Device> SumTo<T, D> {
    pub fn new(x: Variable<T, D>, output: Variable<T, D>) -> Self {
        let output = output.downgrade();
        Self { x, output }
    }
}

impl<T: Num, D: Device> Function<T, D> for SumTo<T, D> {
    fn forward(&self) {
        sum_to_func(
            self.x.get_data().to_ref(),
            self.output.upgrade().unwrap().get_data_mut().to_ref_mut(),
        );
    }

    fn backward(&self) {
        let output = self.output.upgrade().unwrap();
        let output_grad = output.get_grad().clone().unwrap();
        let x_grad = broadcast(output_grad.clone(), self.x.get_data().shape());
        self.x.set_grad(x_grad);
    }

    fn get_inputs(&self) -> Vec<Variable<T, D>> {
        vec![self.x.clone()]
    }
}

pub fn sum_to<T: Num, I: Into<DimDyn>, D: Device>(x: Variable<T, D>, shape: I) -> Variable<T, D> {
    let shape = shape.into();
    let output = zeros(shape);
    let sum_to = SumTo::new(x, output.clone());
    sum_to.forward();
    output.set_creator(Rc::new(RefCell::new(Box::new(sum_to))));
    output
}

#[cfg(test)]
mod sum_to_test {
    use zenu_matrix::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test};

    use crate::Variable;

    use super::sum_to;

    fn sum_to_2d_1d<D: Device>() {
        let x: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], [2, 3]);
        let x = Variable::from(x);
        let y = sum_to(x.clone(), DimDyn::new(&[3]));
        let forward_ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![2.0, 4.0, 6.0], [3]);
        y.backward();
        let x_grad: Matrix<Owned<f32>, DimDyn, D> = Matrix::ones([2, 3]);
        assert_val_eq!(y, forward_ans, 1e-6);
        assert_val_eq_grad!(x, x_grad, 1e-6);
    }
    run_test!(sum_to_2d_1d, sum_to_2d_1d_cpu, sum_to_2d_1d_nvidia);
}
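`sum_to` and `broadcast` are adjoint linear maps — summing along broadcast axes in the forward pass corresponds to repeating the gradient along those axes in the backward pass,

$$\langle \mathrm{broadcast}(v),\ w \rangle = \langle v,\ \mathrm{sum\_to}(w) \rangle,$$

which is why `SumTo::backward` is a single `broadcast` call and the test's input gradient is all ones.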
-------------------------------------------------------------------------------- /zenu-autograd/src/functions/tanh.rs: --------------------------------------------------------------------------------
use std::{cell::RefCell, rc::Rc};

use zenu_matrix::{device::Device, num::Num};

use crate::{creator::alloc::alloc_like, Function, Variable, VariableWeak};

use super::cosh::cosh;

struct Tanh<T: Num, D: Device> {
    input: Variable<T, D>,
    output: VariableWeak<T, D>,
}

impl<T: Num, D: Device> Tanh<T, D> {
    pub fn new(input: Variable<T, D>, output: Variable<T, D>) -> Self {
        let output = output.downgrade();
        Self { input, output }
    }
}

impl<T: Num, D: Device> Function<T, D> for Tanh<T, D> {
    fn forward(&self) {
        let output = self.output.upgrade().unwrap();
        output
            .get_data_mut()
            .to_ref_mut()
            .tanh_array(&self.input.get_data().to_ref());
    }

    fn backward(&self) {
        let output = self.output.upgrade().unwrap();
        let output_grad = output.get_grad().unwrap();
        let input_cosh = cosh(self.input.clone());
        let input_cosh_2 = input_cosh.clone() * input_cosh.clone();
        let grad = output_grad / input_cosh_2;
        self.input.set_grad(grad);
    }

    fn get_inputs(&self) -> Vec<Variable<T, D>> {
        vec![self.input.clone()]
    }
}

#[must_use]
pub fn tanh<T: Num, D: Device>(input: Variable<T, D>) -> Variable<T, D> {
    let output = alloc_like(&input);
    output.set_name(&format!("tanh({})", input.get_name().unwrap_or_default()));
    let tanh = Tanh::new(input, output.clone());
    tanh.forward();
    output.set_creator(Rc::new(RefCell::new(Box::new(tanh))));
    output
}

#[cfg(test)]
mod tanh {

    use zenu_matrix::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test};

    use crate::creator::from_vec::from_vec;

    use super::tanh;

    fn tanh_1d<D: Device>() {
        let x = from_vec(vec![1f32, 2., 3., 4., 5., 6.], [6]);
        let y = tanh(x.clone());
        y.backward();
        let y_ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(
            vec![
                1_f32.tanh(),
                2_f32.tanh(),
                3_f32.tanh(),
                4_f32.tanh(),
                5_f32.tanh(),
                6_f32.tanh(),
            ],
            [6],
        );
        let x_grad_ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(
            vec![
                1. / (1_f32.cosh() * 1_f32.cosh()),
                1. / (2_f32.cosh() * 2_f32.cosh()),
                1. / (3_f32.cosh() * 3_f32.cosh()),
                1. / (4_f32.cosh() * 4_f32.cosh()),
                1. / (5_f32.cosh() * 5_f32.cosh()),
                1. / (6_f32.cosh() * 6_f32.cosh()),
            ],
            [6],
        );
        assert_val_eq!(y, y_ans, 1e-6);
        assert_val_eq_grad!(x, x_grad_ans, 1e-6);
    }
    run_test!(tanh_1d, tanh_1d_cpu, tanh_1d_nvidia);
}
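The backward pass relies on

$$\frac{d}{dx}\tanh x = \frac{1}{\cosh^2 x} = 1 - \tanh^2 x,$$

implemented in the $1/\cosh^2$ form: `cosh(input)` is squared and the upstream gradient divided by it, which is exactly the quantity the test checks elementwise.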
-------------------------------------------------------------------------------- /zenu-autograd/src/loss/cross_entropy.rs: --------------------------------------------------------------------------------
use zenu_matrix::{device::Device, num::Num};

use crate::{
    functions::{log::log, sum_to::sum_to},
    Variable,
};

use super::softmax::softmax;

#[expect(clippy::needless_pass_by_value)]
#[must_use]
pub fn cross_entropy<T: Num, D: Device>(
    pred: Variable<T, D>,
    ans: Variable<T, D>,
) -> Variable<T, D> {
    let pred = softmax(pred, 1);
    let log = log(pred.clone());
    let y_log_pred = ans.clone() * log;
    let sum = sum_to(y_log_pred, &[] as &[usize]);
    let n = T::from_usize(pred.get_data().shape()[0]);
    let n = Variable::from(-n);
    sum / n
}

#[cfg(test)]
mod cross_entropy_test {
    use zenu_matrix::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test};

    use crate::{creator::from_vec::from_vec, Variable};

    use super::cross_entropy;

    fn cross_entropy_batch_size_1<D: Device>() {
        let pred = from_vec(vec![0.1, 0.9, 0.1, 0.1], [1, 4]);
        let ans = from_vec(vec![0.0, 1.0, 0.0, 0.0], [1, 4]);
        let loss = super::cross_entropy(pred.clone(), ans);
        loss.backward();
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![0.8536], []);
        assert_val_eq!(loss, ans, 1e-4);
        let pred_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![0.1914, -0.5741, 0.1914, 0.1914],
            [1, 4],
        );
        assert_val_eq_grad!(pred, pred_ans, 1e-4);
    }
    run_test!(
        cross_entropy_batch_size_1,
        cross_entropy_batch_size_1_cpu,
        cross_entropy_batch_size_1_nvidia
    );

    fn cross_entropy_batch_size_2<D: Device>() {
        let pred: Variable<f32, D> =
            from_vec(vec![0.1, 0.9, 0.1, 0.1, 0.01, 0.9, 0.2, 0.05], [2, 4]);
        let ans = from_vec(vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], [2, 4]);
        let loss = cross_entropy(pred.clone(), ans.clone());
        loss.backward();
        let ans_ = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.2957], []);
        assert_val_eq!(loss, ans_, 2e-5);
        let pred_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                0.0957, -0.2871, 0.0957, 0.0957, -0.4121, 0.2142, 0.1064, 0.0915,
            ],
            [2, 4],
        );
        assert_val_eq_grad!(pred, pred_grad_ans, 3e-4);
        let ans_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                0.8268, 0.4268, 0.8268, 0.8268, 0.8689, 0.4239, 0.7739, 0.8489,
            ],
            [2, 4],
        );
        assert_val_eq_grad!(ans, ans_grad_ans, 3e-4);
    }
    run_test!(
        cross_entropy_batch_size_2,
        cross_entropy_batch_size_2_cpu,
        cross_entropy_batch_size_2_nvidia
    );
}
-------------------------------------------------------------------------------- /zenu-autograd/src/loss/mod.rs: --------------------------------------------------------------------------------
pub mod cross_entropy;
pub mod mse;
pub mod softmax;
-------------------------------------------------------------------------------- /zenu-autograd/src/loss/mse.rs: --------------------------------------------------------------------------------
use zenu_matrix::{device::Device, num::Num};

use crate::{functions::sum_to::sum_to, Variable};

#[must_use]
pub fn mean_squared_error<T: Num, D: Device>(
    y_true: Variable<T, D>,
    y_pred: Variable<T, D>,
) -> Variable<T, D> {
    let batch_size = y_true.get_data().shape()[0];
    let diff = y_true - y_pred;
    let diff_squared = diff.clone() * diff;
    sum_to(diff_squared, []) / Variable::from(T::from_usize(batch_size))
}

#[cfg(test)]
mod mse_test {

    use zenu_matrix::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_val_eq, assert_val_eq_grad, run_test};

    use super::mean_squared_error;

    fn batch_1<D: Device>() {
        let y_true = crate::creator::from_vec::from_vec(vec![1., 2., 3.], [1, 3]);
        let y_pred = crate::creator::from_vec::from_vec(vec![2., 3., 4.], [1, 3]);
        let mse = mean_squared_error(y_true.clone(), y_pred.clone());
        mse.backward();
        let mse_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3.], []);
        assert_val_eq!(mse, mse_ans, 1e-6);

        let y_true_grad_ans =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![-2., -2., -2.], [1, 3]);
        let y_pred_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 2., 2.], [1, 3]);
        assert_val_eq_grad!(y_true, y_true_grad_ans, 1e-6);
        assert_val_eq_grad!(y_pred, y_pred_grad_ans, 1e-6);
    }
    run_test!(batch_1, batch_1_cpu, batch_1_nvidia);
}
-------------------------------------------------------------------------------- /zenu-autograd/src/nn/conv/mod.rs: --------------------------------------------------------------------------------
mod conv_with_bias;
mod conv_without_bias;

pub use conv_with_bias::conv;
pub use conv_without_bias::ConvConfigs;
-------------------------------------------------------------------------------- /zenu-autograd/src/nn/mod.rs: --------------------------------------------------------------------------------
pub mod batch_norm;
pub mod conv;
pub mod dropout;
pub mod pool2d;
pub mod rnns;
-------------------------------------------------------------------------------- /zenu-autograd/src/nn/rnns/gru/mod.rs: --------------------------------------------------------------------------------
use crate::Variable;

use zenu_matrix::{device::Device, num::Num};

#[cfg(feature = "nvidia")]
pub mod cudnn;

pub mod naive;

pub struct GRUOutput<T: Num, D: Device> {
    pub y: Variable<T, D>,
    pub hy: Variable<T, D>,
}
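For orientation, a minimal sketch of consuming the struct above — `y` holds the per-timestep outputs and `hy` the final hidden state (the field roles are inferred from the names, and the module path is assumed from the source layout; this is illustrative, not an API guarantee):

```rust
// Illustrative only: the import path is assumed from the file tree shown
// above, and this helper is not part of the crate itself.
use zenu_autograd::{nn::rnns::gru::GRUOutput, Variable};
use zenu_matrix::{device::Device, num::Num};

// Keep just the final hidden state of a GRU run, discarding the sequence.
fn final_state<T: Num, D: Device>(out: GRUOutput<T, D>) -> Variable<T, D> {
    let GRUOutput { y: _sequence, hy } = out;
    hy
}
```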
-------------------------------------------------------------------------------- /zenu-autograd/src/nn/rnns/lstm/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "nvidia")] 2 | pub mod cudnn; 3 | 4 | pub mod naive; 5 | -------------------------------------------------------------------------------- /zenu-autograd/src/nn/rnns/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod gru; 2 | pub mod lstm; 3 | pub mod rnn; 4 | pub mod weights; 5 | -------------------------------------------------------------------------------- /zenu-autograd/src/nn/rnns/rnn/mod.rs: -------------------------------------------------------------------------------- 1 | use zenu_matrix::{device::Device, num::Num}; 2 | 3 | use crate::Variable; 4 | 5 | #[cfg(feature = "nvidia")] 6 | pub mod cudnn; 7 | 8 | pub mod naive; 9 | 10 | pub struct RNNOutput { 11 | pub y: Variable, 12 | pub hy: Variable, 13 | } 14 | -------------------------------------------------------------------------------- /zenu-autograd/tests/combined.rs: -------------------------------------------------------------------------------- 1 | // use zenu_autograd::Variable; 2 | // use zenu_matrix::matrix::{AsPtr, ToViewMatrix}; 3 | // 4 | // fn formula(x: Variable, y: Variable, z: Variable) -> Variable { 5 | // let a = x * y; 6 | // let b = a + z; 7 | // b 8 | // } 9 | // 10 | // #[test] 11 | // fn test_formula() { 12 | // let x = Variable::from(3.0); 13 | // let y = Variable::from(4.0); 14 | // let z = Variable::from(5.0); 15 | // let result = formula(x.clone(), y.clone(), z.clone()); 16 | // assert_eq!(unsafe { *result.get_data().to_view().as_ptr() }, 17.0); 17 | // result.backward(); 18 | // x.with_grad_data(|grad| { 19 | // assert_eq!(unsafe { *grad.to_view().as_ptr() }, 4.0); 20 | // }); 21 | // y.with_grad_data(|grad| { 22 | // assert_eq!(unsafe { *grad.to_view().as_ptr() }, 3.0); 23 | // }); 24 | // z.with_grad_data(|grad| { 25 | // assert_eq!(unsafe { *grad.to_view().as_ptr() }, 1.0); 26 | // }); 27 | // } 28 | // 29 | // fn use_twice(x: Variable) -> Variable { 30 | // let a = x.clone() * x; 31 | // let b = a.clone() + Variable::from(3.0); 32 | // b 33 | // } 34 | // 35 | // #[test] 36 | // fn test_use_twice() { 37 | // let x = Variable::from(3.0); 38 | // let result = use_twice(x.clone()); 39 | // assert_eq!(unsafe { *result.get_data().to_view().as_ptr() }, 12.0); 40 | // result.backward(); 41 | // x.with_grad_data(|grad| { 42 | // assert_eq!(unsafe { *grad.to_view().as_ptr() }, 6.0); 43 | // }); 44 | // } 45 | -------------------------------------------------------------------------------- /zenu-autograd/tests/get_all_test_variable.rs: -------------------------------------------------------------------------------- 1 | // use zenu_autograd::Variable; 2 | // 3 | // fn test_function() -> Variable { 4 | // let a = Variable::from(1.0); 5 | // let b = Variable::from(2.0); 6 | // 7 | // b.set_is_train(true); 8 | // b.set_name("b"); 9 | // 10 | // let c = Variable::from(3.0); 11 | // c.set_name("c"); 12 | // c.set_is_train(true); 13 | // 14 | // a * b + c 15 | // } 16 | // 17 | // #[test] 18 | // fn all_trainable() { 19 | // let a = test_function(); 20 | // let variables = a.get_all_trainable_variables(); 21 | // assert_eq!(variables.len(), 2); 22 | // let mut names = variables 23 | // .iter() 24 | // .map(|v| v.get_name().unwrap()) 25 | // .collect::>(); 26 | // names.sort(); 27 | // assert_eq!(names, vec!["b", "c"]); 28 | // } 29 | // 30 | // fn 
more_complicated(a: Variable, b: Variable, c: Variable) -> Variable { 31 | // let d = a * b; 32 | // let e = d.clone() + c; 33 | // let f = e * d; 34 | // f 35 | // } 36 | // 37 | // #[test] 38 | // fn copilicated() { 39 | // let a = Variable::from(1.0); 40 | // a.set_name("a"); 41 | // a.set_is_train(true); 42 | // 43 | // let c = more_complicated(a.clone(), a.clone(), a.clone()); 44 | // let d = more_complicated(c.clone(), a.clone(), a.clone()); 45 | // let variables = d.get_all_trainable_variables(); 46 | // assert_eq!(variables.len(), 1); 47 | // assert_eq!(variables[0].get_name().unwrap(), "a"); 48 | // } 49 | // 50 | // #[test] 51 | // fn ultra_large_copilicated() { 52 | // let a = Variable::from(1.0); 53 | // a.set_name("a"); 54 | // a.set_is_train(true); 55 | // 56 | // let mut c = more_complicated(a.clone(), a.clone(), a.clone()); 57 | // for _ in 0..500 { 58 | // c = more_complicated(c.clone(), a.clone(), a.clone()); 59 | // } 60 | // let variables = c.get_all_trainable_variables(); 61 | // assert_eq!(variables.len(), 1); 62 | // } 63 | -------------------------------------------------------------------------------- /zenu-cublas-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-cublas-sys" 3 | version = "0.1.0" 4 | edition = "2021" 5 | description = "Rust bindings for cuBLAS" 6 | license = "MIT" 7 | repository = "https://github.com/bokutotu/zenu" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | libc="0.2.153" 13 | 14 | [build-dependencies] 15 | bindgen = "0.69.4" 16 | zenu-cuda-config = { path = "../zenu-cuda-config", version = "0.1.0"} 17 | 18 | [lints] 19 | workspace = true 20 | -------------------------------------------------------------------------------- /zenu-cublas-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | 3 | use zenu_cuda_config::find_cuda; 4 | 5 | fn main() { 6 | for path in find_cuda() { 7 | println!("cargo:rustc-link-search=native={}", path.display()); 8 | } 9 | 10 | println!("cargo:rustc-link-lib=dylib=cudart"); 11 | println!("cargo:rerun-if-changed=build.rs"); 12 | println!("cargo:rustc-link-lib=dylib=cublas"); 13 | 14 | let bindings = bindgen::Builder::default() 15 | .ctypes_prefix("::libc") 16 | .size_t_is_usize(true) 17 | .clang_arg("-I") 18 | .clang_arg("/usr/local/cuda/include".to_string()) 19 | .header("wrapper.h") 20 | .rustified_non_exhaustive_enum("cublas[A-Za-z]+_t") 21 | .rustified_non_exhaustive_enum("cuda.*") 22 | .default_alias_style(bindgen::AliasVariation::TypeAlias) 23 | .allowlist_type("^cublas.*") 24 | .allowlist_type("cublas.*") 25 | .allowlist_function("^cublas.*") 26 | .allowlist_type("[Cc][Uu].*") 27 | .allowlist_var("CUBLAS.*") 28 | .derive_default(true) 29 | .derive_eq(true) 30 | .derive_hash(true) 31 | .derive_ord(true) 32 | .generate() 33 | .expect("Unable to generate bindings"); 34 | 35 | bindings 36 | .write_to_file("./src/bindings.rs") 37 | .expect("Unable to write"); 38 | } 39 | -------------------------------------------------------------------------------- /zenu-cublas-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![expect(non_upper_case_globals)] 2 | #![expect(non_camel_case_types)] 3 | #![expect(non_snake_case)] 4 | #![expect(clippy::unreadable_literal)] 5 | #![expect(clippy::pub_underscore_fields)] 6 | 7 | include!("bindings.rs"); 8 
| -------------------------------------------------------------------------------- /zenu-cublas-sys/wrapper.h: -------------------------------------------------------------------------------- 1 | #include "cublas_v2.h" 2 | -------------------------------------------------------------------------------- /zenu-cuda-config/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-cuda-config" 3 | version = "0.1.0" 4 | edition = "2021" 5 | repository = "https://github.com/bokutotu/zenu" 6 | license = "MIT" 7 | description = "CUDA configuration for Zenu" 8 | 9 | [dependencies] 10 | glob = "0.3" 11 | 12 | [lints] 13 | workspace = true 14 | -------------------------------------------------------------------------------- /zenu-cuda-driver-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-cuda-driver-sys" 3 | version = "0.1.0" 4 | edition = "2021" 5 | repository = "https://github.com/bokutotu/zenu" 6 | license = "MIT" 7 | description = "Rust bindings for CUDA Driver API" 8 | 9 | [dependencies] 10 | libc = "0.2.153" 11 | 12 | [build-dependencies] 13 | bindgen="0.69.4" 14 | zenu-cuda-config = { path = "../zenu-cuda-config", version = "0.1.0"} 15 | 16 | [lints] 17 | workspace = true 18 | -------------------------------------------------------------------------------- /zenu-cuda-driver-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | 3 | use zenu_cuda_config::find_cuda; 4 | 5 | fn main() { 6 | for path in find_cuda() { 7 | println!("cargo:rustc-link-search=native={}", path.display()); 8 | } 9 | 10 | println!("cargo:rustc-link-lib=dylib=cudart"); 11 | println!("cargo:rerun-if-changed=build.rs"); 12 | 13 | let bindings = bindgen::Builder::default() 14 | .ctypes_prefix("::libc") 15 | .size_t_is_usize(true) 16 | .clang_arg("-I") 17 | .clang_arg("/usr/local/cuda/include".to_string()) 18 | .header("wrapper.h") 19 | .rustified_non_exhaustive_enum("cuda.*") 20 | .allowlist_type("^cuda.*") 21 | .allowlist_type("^surfaceReference") 22 | .allowlist_type("^textureReference") 23 | .allowlist_var("^cuda.*") 24 | .allowlist_function("^cuda.*") 25 | .default_alias_style(bindgen::AliasVariation::TypeAlias) 26 | .derive_default(true) 27 | .derive_eq(true) 28 | .derive_hash(true) 29 | .derive_ord(true) 30 | .generate() 31 | .expect("Unable to generate bindings"); 32 | 33 | bindings 34 | .write_to_file("./src/bindings.rs") 35 | .expect("Unable to write"); 36 | } 37 | -------------------------------------------------------------------------------- /zenu-cuda-driver-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![expect(non_camel_case_types)] 2 | 3 | include!("bindings.rs"); 4 | -------------------------------------------------------------------------------- /zenu-cuda-driver-sys/wrapper.h: -------------------------------------------------------------------------------- 1 | #include "cuda.h" 2 | -------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-cuda-kernel-sys" 3 | version = "0.1.0" 4 | edition = "2021" 5 | repository = "https://github.com/bokutotu/zenu" 6 | license = "MIT" 7 | description = "CUDA kernel bindings for Rust" 8 | 9 | # See more keys and their definitions at 
https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
libc = "0.2.153"

[build-dependencies]
bindgen = "0.69.4"
cc = "1.2.1"

[lints]
workspace = true

[dev-dependencies]
zenu-cuda-runtime-sys = { path = "../zenu-cuda-runtime-sys" }
-------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/build.rs: --------------------------------------------------------------------------------
extern crate bindgen;
extern crate cc;

fn main() {
    let cuda_files = vec![
        "kernel/array_scalar.cu",
        "kernel/element_wise.cu",
        "kernel/memory_access.cu",
        "kernel/array_array.cu",
        "kernel/activations.cu",
        "kernel/conv2d_bkwd_data.cu",
    ];

    for cuda_file in &cuda_files {
        println!("cargo:rerun-if-changed={cuda_file}");
    }
    println!("cargo:rerun-if-changed=kernel/kernel.h");
    println!("cargo:rerun-if-changed=build.rs");

    cc::Build::new()
        .cuda(true)
        .cpp(true)
        .flag("-std=c++11")
        .flag("-cudart=shared")
        .flag("--expt-relaxed-constexpr")
        .files(cuda_files)
        .include("kernel/")
        .compile("libkernel.a");

    println!("cargo:rustc-link-lib=kernel");
    println!("cargo:rustc-link-lib=static=kernel");
    println!(
        "cargo:rustc-link-search=native={}",
        std::env::var("OUT_DIR").unwrap()
    );

    let bindings = bindgen::Builder::default()
        .ctypes_prefix("::libc")
        .size_t_is_usize(true)
        .clang_arg("-I")
        .clang_arg("/usr/local/cuda/include".to_string())
        .header("kernel/kernel.h")
        .default_alias_style(bindgen::AliasVariation::TypeAlias)
        .derive_default(true)
        .derive_eq(true)
        .derive_hash(true)
        .derive_ord(true)
        .generate()
        .expect("Unable to generate bindings");

    bindings
        .write_to_file("src/bindings.rs")
        .expect("Couldn't write bindings!");
}
-------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/activations.cu: --------------------------------------------------------------------------------
#include "activations.h"

template <typename T>
__global__ void relu_kernel(T *input, T *output, T alpha, int size, int input_stride, int output_stride) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        output[idx * output_stride] = input[idx * input_stride] > 0 ? input[idx * input_stride] : alpha * input[idx * input_stride];
    }
}
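// Launch-configuration note: the wrappers below compute
// grid_size = (size + block_size - 1) / block_size, the integer form of
// ceil(size / 128). For example, size = 1000 gives (1000 + 127) / 128 = 8
// blocks, i.e. 1024 threads; the `idx < size` guard in each kernel
// discards the 24 surplus threads.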
template <typename T>
__global__ void relu_background_mask(T *input, T *mask, T alpha, int size, int input_stride, int mask_stride) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        mask[idx * mask_stride] = input[idx * input_stride] > 0 ? 1 : alpha * -1;
    }
}

void relu_float(float *input, float *output, float alpha, int size, int input_stride, int output_stride) {
    int block_size = 128;
    int grid_size = (size + block_size - 1) / block_size;
    relu_kernel<<<grid_size, block_size>>>(input, output, alpha, size, input_stride, output_stride);
}

void relu_double(double *input, double *output, double alpha, int size, int input_stride, int output_stride) {
    int block_size = 128;
    int grid_size = (size + block_size - 1) / block_size;
    relu_kernel<<<grid_size, block_size>>>(input, output, alpha, size, input_stride, output_stride);
}

void relu_backward_mask_float(float *input, float *mask, float alpha, int size, int input_stride, int mask_stride) {
    int block_size = 128;
    int grid_size = (size + block_size - 1) / block_size;
    relu_background_mask<<<grid_size, block_size>>>(input, mask, alpha, size, input_stride, mask_stride);
}

void relu_backward_mask_double(double *input, double *mask, double alpha, int size, int input_stride, int mask_stride) {
    int block_size = 128;
    int grid_size = (size + block_size - 1) / block_size;
    relu_background_mask<<<grid_size, block_size>>>(input, mask, alpha, size, input_stride, mask_stride);
}
-------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/activations.h: --------------------------------------------------------------------------------
#pragma once

#ifdef __cplusplus
extern "C" {
#endif

void relu_float(float *input, float *output, float alpha, int size, int input_stride, int output_stride);
void relu_double(double *input, double *output, double alpha, int size, int input_stride, int output_stride);

void relu_backward_mask_float(float *input, float *mask, float alpha, int size, int input_stride, int mask_stride);
void relu_backward_mask_double(double *input, double *mask, double alpha, int size, int input_stride, int mask_stride);

#ifdef __cplusplus
}
#endif
-------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/array_array.h: --------------------------------------------------------------------------------
#pragma once

#ifdef __cplusplus
extern "C" {
#endif

void array_array_add_float(float *a, int stride_a, float *b, int stride_b, float *c, int stride_c, int n);
void array_array_sub_float(float *a, int stride_a, float *b, int stride_b, float *c, int stride_c, int n);
void array_array_mul_float(float *a, int stride_a, float *b, int stride_b, float *c, int stride_c, int n);
void array_array_div_float(float *a, int stride_a, float *b, int stride_b, float *c, int stride_c, int n);

void array_array_add_double(double *a, int stride_a, double *b, int stride_b, double *c, int stride_c, int n);
void array_array_sub_double(double *a, int stride_a, double *b, int stride_b, double *c, int stride_c, int n);
void array_array_mul_double(double *a, int stride_a, double *b, int stride_b, double *c, int stride_c, int n);
void array_array_div_double(double *a, int stride_a, double *b, int stride_b, double *c, int stride_c, int n);

void array_array_add_assign_float(float *a, int stride_a, float *b, int stride_b, int n);
void array_array_sub_assign_float(float *a, int stride_a, float *b, int stride_b, int n);
void array_array_mul_assign_float(float *a, int stride_a, float *b, int stride_b, int n);
void
array_array_div_assign_float(float *a, int stride_a, float *b, int stride_b, int n); 21 | 22 | void array_array_add_assign_double(double *a, int stride_a, double *b, int stride_b, int n); 23 | void array_array_sub_assign_double(double *a, int stride_a, double *b, int stride_b, int n); 24 | void array_array_mul_assign_double(double *a, int stride_a, double *b, int stride_b, int n); 25 | void array_array_div_assign_double(double *a, int stride_a, double *b, int stride_b, int n); 26 | 27 | void conv_bias_add_float(const float *input, float *output, int channel_stride, const float *bias, int bias_size, int total_elements); 28 | void conv_bias_add_double(const double *input, double *output, int channel_stride, const double *bias, int bias_size, int total_elements); 29 | 30 | #ifdef __cplusplus 31 | } 32 | #endif 33 | -------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/conv2d_bkwd_data.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | void conv2d_bias_bkwd_float( 8 | const float* dOut, // device ptr 9 | float* dbias, // device ptr 10 | int N, int C, int H, int W 11 | ); 12 | 13 | void conv2d_bias_bkwd_double( 14 | const double* dOut, // device ptr 15 | double* dbias, // device ptr 16 | int N, int C, int H, int W 17 | ); 18 | 19 | #ifdef __cplusplus 20 | } 21 | #endif 22 | -------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/element_wise.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | void array_max_idx_float(float *a, int size, int stride, int *out); 8 | 9 | void array_max_idx_double(double *a, int size, int stride, int *out); 10 | 11 | #ifdef __cplusplus 12 | } 13 | #endif 14 | -------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/kernel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | #include 8 | 9 | #include "array_array.h" 10 | #include "array_scalar.h" 11 | #include "element_wise.h" 12 | #include "memory_access.h" 13 | #include "activations.h" 14 | #include "conv2d_bkwd_data.h" 15 | 16 | #ifdef __cplusplus 17 | } 18 | #endif 19 | -------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/memory_access.cu: -------------------------------------------------------------------------------- 1 | #include "memory_access.h" 2 | #include 3 | 4 | void memory_access_float(float *array, int offset, float *result) { 5 | cudaMemcpy(result, array + offset, sizeof(float), cudaMemcpyDeviceToHost); 6 | } 7 | 8 | void memory_access_double(double *array, int offset, double *result) { 9 | cudaMemcpy(result, array + offset, sizeof(double), cudaMemcpyDeviceToHost); 10 | } 11 | 12 | void memory_set_float(float *array, int offset, float value) { 13 | cudaMemcpy(array + offset, &value, sizeof(float), cudaMemcpyHostToDevice); 14 | } 15 | 16 | void memory_set_double(double *array, int offset, double value) { 17 | cudaMemcpy(array + offset, &value, sizeof(double), cudaMemcpyHostToDevice); 18 | } 19 | -------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/kernel/memory_access.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | void memory_access_float(float *array, int offset, float *result); 8 | void memory_access_double(double *array, int offset, double *result); 9 | 10 | void memory_set_float(float *array, int offset, float value); 11 | void memory_set_double(double *array, int offset, double value); 12 | 13 | #ifdef __cplusplus 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /zenu-cuda-kernel-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![expect(non_upper_case_globals)] 2 | #![expect(non_camel_case_types)] 3 | #![expect(non_snake_case)] 4 | #![expect(clippy::unreadable_literal)] 5 | #![expect(clippy::pub_underscore_fields)] 6 | 7 | include!("./bindings.rs"); 8 | 9 | #[cfg(test)] 10 | mod kernel { 11 | 12 | use zenu_cuda_runtime_sys::{cudaMalloc, cudaMemcpy, cudaMemcpyKind}; 13 | 14 | use crate::array_scalar_add_float; 15 | 16 | #[test] 17 | fn test_add() { 18 | let a: Vec = vec![1.0, 2.0, 3.0]; 19 | // let b: Vec = vec![0., 0., 0.]; 20 | 21 | let mut a_gpu: *mut f32 = std::ptr::null_mut(); 22 | // let a_gpu_gpu = &a_gpu as *const *mut f32 as *mut *mut libc::c_void; 23 | let a_gpu_gpu = std::ptr::from_mut(&mut a_gpu); 24 | unsafe { cudaMalloc(a_gpu_gpu.cast(), 3 * std::mem::size_of::()) }; 25 | let a_gpu = unsafe { *a_gpu_gpu }; 26 | 27 | let mut b_gpu: *mut f32 = std::ptr::null_mut(); 28 | // let b_gpu_gpu = &b_gpu as *const *mut f32 as *mut *mut libc::c_void; 29 | let b_gpu_gpu = std::ptr::from_mut(&mut b_gpu); 30 | unsafe { cudaMalloc(b_gpu_gpu.cast(), 3 * std::mem::size_of::()) }; 31 | let b_gpu = unsafe { *b_gpu_gpu }; 32 | 33 | unsafe { 34 | cudaMemcpy( 35 | a_gpu.cast(), 36 | a.as_ptr().cast(), 37 | 3 * std::mem::size_of::(), 38 | cudaMemcpyKind::cudaMemcpyHostToDevice, 39 | ) 40 | }; 41 | 42 | unsafe { array_scalar_add_float(a_gpu.cast(), 3, 1, 1., b_gpu.cast(), 1) }; 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /zenu-cuda-runtime-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-cuda-runtime-sys" 3 | version = "0.1.0" 4 | edition = "2021" 5 | repository = "https://github.com/bokutotu/zenu" 6 | license = "MIT" 7 | description = "CUDA runtime bindings for Rust" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | libc="0.2.151" 13 | 14 | [build-dependencies] 15 | bindgen = "0.69.4" 16 | zenu-cuda-config = { path = "../zenu-cuda-config", version = "0.1.0"} 17 | 18 | [lints] 19 | workspace = true 20 | -------------------------------------------------------------------------------- /zenu-cuda-runtime-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | use zenu_cuda_config::find_cuda; 3 | 4 | fn main() { 5 | for path in find_cuda() { 6 | println!("cargo:rustc-link-search=native={}", path.display()); 7 | } 8 | 9 | println!("cargo:rustc-link-lib=dylib=cudart"); 10 | println!("cargo:rerun-if-changed=build.rs"); 11 | 12 | let bindings = bindgen::Builder::default() 13 | .ctypes_prefix("::libc") 14 | .size_t_is_usize(true) 15 | .clang_arg("-I") 16 | .clang_arg("/usr/local/cuda/include".to_string()) 17 | .header("wrapper.h") 18 | 
.rustified_non_exhaustive_enum("cuda.*")
        .allowlist_type("^cuda.*")
        .allowlist_type("^surfaceReference")
        .allowlist_type("^textureReference")
        .allowlist_var("^cuda.*")
        .allowlist_function("^cuda.*")
        .default_alias_style(bindgen::AliasVariation::TypeAlias)
        .derive_default(true)
        .derive_eq(true)
        .derive_hash(true)
        .derive_ord(true)
        // .parse_callbacks(Box::new(bindgen::CargoCallbacks))
        // .rustfmt_bindings(true)
        .generate()
        .expect("Unable to generate bindings");

    bindings
        .write_to_file("./src/bindings.rs")
        .expect("Unable to write");
}
-------------------------------------------------------------------------------- /zenu-cuda-runtime-sys/src/lib.rs: --------------------------------------------------------------------------------
#![expect(non_upper_case_globals)]
#![expect(non_camel_case_types)]
#![expect(non_snake_case)]
#![expect(clippy::unreadable_literal)]
#![expect(clippy::pub_underscore_fields)]

include!("bindings.rs");
-------------------------------------------------------------------------------- /zenu-cuda-runtime-sys/wrapper.h: --------------------------------------------------------------------------------
#include "cuda_runtime.h"
-------------------------------------------------------------------------------- /zenu-cuda/Cargo.toml: --------------------------------------------------------------------------------
[package]
name = "zenu-cuda"
version = "0.1.0"
edition = "2021"
repository = "https://github.com/bokutotu/zenu"
license = "MIT"
description = "CUDA bindings for Rust"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
once_cell = "1.19.0"
libc = "0.2.153"

zenu-cuda-driver-sys = { path = "../zenu-cuda-driver-sys", version="0.1.0" }
zenu-cuda-runtime-sys = { path = "../zenu-cuda-runtime-sys", version="0.1.0" }
zenu-cudnn-sys = { path = "../zenu-cudnn-sys", version="0.1.0" }
zenu-cublas-sys = { path = "../zenu-cublas-sys", version="0.1.0" }
zenu-cuda-kernel-sys = { path = "../zenu-cuda-kernel-sys", version="0.1.0" }
zenu-cudnn-frontend-wrapper-sys = { path = "../zenu-cudnn-frontend-wrapper-sys", version="0.1.0" }
-------------------------------------------------------------------------------- /zenu-cuda/src/cublas/cublas_error.rs: --------------------------------------------------------------------------------
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ZenuCublasError {
    CublasStatusSuccess,
    CublasStatusNotInitialized,
    CublasStatusAllocFailed,
    CublasStatusInvalidValue,
    CublasStatusArchMismatch,
    CublasStatusMappingError,
    CublasStatusExecutionFailed,
    CublasStatusInternalError,
    CublasStatusNotSupported,
    CublasStatusLicenseError,
}
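// Usage example: raw cuBLAS status codes round-trip through the conversions
// below, e.g. `ZenuCublasError::from(0u32)` yields `CublasStatusSuccess` and
// `u32::from(ZenuCublasError::CublasStatusSuccess)` yields 0 again; codes
// outside the listed set panic in the `From<u32>` match.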
impl From<u32> for ZenuCublasError {
    fn from(value: u32) -> Self {
        match value {
            0 => ZenuCublasError::CublasStatusSuccess,
            1 => ZenuCublasError::CublasStatusNotInitialized,
            3 => ZenuCublasError::CublasStatusAllocFailed,
            7 => ZenuCublasError::CublasStatusInvalidValue,
            8 => ZenuCublasError::CublasStatusArchMismatch,
            11 => ZenuCublasError::CublasStatusMappingError,
            13 => ZenuCublasError::CublasStatusExecutionFailed,
            14 => ZenuCublasError::CublasStatusInternalError,
            15 => ZenuCublasError::CublasStatusNotSupported,
            16 => ZenuCublasError::CublasStatusLicenseError,
            _ => panic!("Invalid cublas error code: {value}"),
        }
    }
}

impl From<ZenuCublasError> for u32 {
    fn from(value: ZenuCublasError) -> Self {
        match value {
            ZenuCublasError::CublasStatusSuccess => 0,
            ZenuCublasError::CublasStatusNotInitialized => 1,
            ZenuCublasError::CublasStatusAllocFailed => 3,
            ZenuCublasError::CublasStatusInvalidValue => 7,
            ZenuCublasError::CublasStatusArchMismatch => 8,
            ZenuCublasError::CublasStatusMappingError => 11,
            ZenuCublasError::CublasStatusExecutionFailed => 13,
            ZenuCublasError::CublasStatusInternalError => 14,
            ZenuCublasError::CublasStatusNotSupported => 15,
            ZenuCublasError::CublasStatusLicenseError => 16,
        }
    }
}
-------------------------------------------------------------------------------- /zenu-cuda/src/cudnn/error.rs: --------------------------------------------------------------------------------
use zenu_cudnn_sys::{cudnnGetErrorString, cudnnStatus_t};

#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum ZenuCudnnError {
    NotInitialized = 1001,
    LicenseError = 1005,
    RuntimeInProgress = 1006,
    RuntimeFpOverflow = 1007,
    BadParam = 2000,
    NotSupported = 3000,
    InternalError = 4000,
    ExecutionFailed = 5000,
    InvalidValue = 2001,
    Other = 9999,
}

impl From<cudnnStatus_t> for ZenuCudnnError {
    fn from(status: cudnnStatus_t) -> Self {
        match status {
            cudnnStatus_t::CUDNN_STATUS_NOT_INITIALIZED => ZenuCudnnError::NotInitialized,
            cudnnStatus_t::CUDNN_STATUS_LICENSE_ERROR => ZenuCudnnError::LicenseError,
            cudnnStatus_t::CUDNN_STATUS_RUNTIME_IN_PROGRESS => ZenuCudnnError::RuntimeInProgress,
            cudnnStatus_t::CUDNN_STATUS_RUNTIME_FP_OVERFLOW => ZenuCudnnError::RuntimeFpOverflow,
            cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => ZenuCudnnError::BadParam,
            cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => ZenuCudnnError::NotSupported,
            cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR => ZenuCudnnError::InternalError,
            cudnnStatus_t::CUDNN_STATUS_EXECUTION_FAILED => ZenuCudnnError::ExecutionFailed,
            cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE => ZenuCudnnError::InvalidValue,
            _ => unreachable!(),
        }
    }
}

impl From<ZenuCudnnError> for cudnnStatus_t {
    fn from(error: ZenuCudnnError) -> Self {
        match error {
            ZenuCudnnError::NotInitialized => cudnnStatus_t::CUDNN_STATUS_NOT_INITIALIZED,
            ZenuCudnnError::LicenseError => cudnnStatus_t::CUDNN_STATUS_LICENSE_ERROR,
            ZenuCudnnError::RuntimeInProgress => cudnnStatus_t::CUDNN_STATUS_RUNTIME_IN_PROGRESS,
            ZenuCudnnError::RuntimeFpOverflow => cudnnStatus_t::CUDNN_STATUS_RUNTIME_FP_OVERFLOW,
            ZenuCudnnError::BadParam => cudnnStatus_t::CUDNN_STATUS_BAD_PARAM,
            ZenuCudnnError::NotSupported => cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
            ZenuCudnnError::InternalError => cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
            ZenuCudnnError::ExecutionFailed => cudnnStatus_t::CUDNN_STATUS_EXECUTION_FAILED,
            ZenuCudnnError::InvalidValue => cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE,
            ZenuCudnnError::Other => unimplemented!(),
        }
    }
}

impl std::fmt::Debug for ZenuCudnnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let error: cudnnStatus_t = (*self).into();
        let error_char_ptr = unsafe { cudnnGetErrorString(error) };
        let error_str = unsafe { std::ffi::CStr::from_ptr(error_char_ptr) };
        write!(f, "{}",
error_str.to_str().unwrap())
    }
}

impl std::fmt::Display for ZenuCudnnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let error: cudnnStatus_t = (*self).into();
        let error_char_ptr = unsafe { cudnnGetErrorString(error) };
        let error_str = unsafe { std::ffi::CStr::from_ptr(error_char_ptr) };
        write!(f, "{}", error_str.to_str().unwrap())
    }
}
-------------------------------------------------------------------------------- /zenu-cuda/src/cudnn/graph_utils.rs: --------------------------------------------------------------------------------
use zenu_cudnn_frontend_wrapper_sys::{
    CudnnFrontendDataType_t, CudnnFrontendError_t, CudnnTensorShapeStride,
};

#[expect(clippy::cast_possible_wrap)]
pub fn shape_stride_to_cudnn(shape: &[usize], stride: &[usize]) -> CudnnTensorShapeStride {
    let num_dims = shape.len();
    let mut dims = [0_i64; 8];
    let mut strides = [0_i64; 8];
    for i in 0..num_dims {
        dims[i] = shape[i] as i64;
        strides[i] = stride[i] as i64;
    }
    CudnnTensorShapeStride {
        num_dims,
        dims,
        strides,
    }
}

pub fn get_cudnn_frontend_type<T>() -> CudnnFrontendDataType_t {
    if std::any::type_name::<T>() == "f32" {
        CudnnFrontendDataType_t::DATA_TYPE_FLOAT
    } else if std::any::type_name::<T>() == "f64" {
        CudnnFrontendDataType_t::DATA_TYPE_DOUBLE
    } else {
        panic!("Unsupported data type");
    }
}

pub fn success_or_panic(status: CudnnFrontendError_t) {
    assert!(
        status == CudnnFrontendError_t::SUCCESS,
        "Cudnn frontend error: {status:?}"
    );
}
-------------------------------------------------------------------------------- /zenu-cuda/src/cudnn/rnn/mod.rs: --------------------------------------------------------------------------------
mod descriptor;
mod helper;
mod test;

pub use helper::{RNNAlgo, RNNBias, RNNCell, RNNDataLayout, RNNMathType};

pub use descriptor::{RNNContext, RNNDescriptor, RNNParams};
-------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/Cargo.toml: --------------------------------------------------------------------------------
[package]
name = "zenu-cudnn-frontend-wrapper-sys"
version = "0.1.0"
edition = "2021"
license = "MIT"
description = "A wrapper for the CUDNN frontend API"
repository = "https://github.com/bokutotu/zenu"

[dependencies]

[build-dependencies]
bindgen = "0.71.0"
cmake = "0.1.52"

[lints]
workspace = true
-------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/build.rs: --------------------------------------------------------------------------------
extern crate bindgen;
extern crate cmake;

use std::env;
use std::path::PathBuf;

fn main() {
    let out_dir = env::var("OUT_DIR").unwrap();

    let dst = cmake::Config::new("cudnn_frontend_wrapper")
        .define("CMAKE_INSTALL_PREFIX", &out_dir)
        .build();

    println!("cargo:rustc-link-search=native={}/lib", dst.display());

    println!("cargo:rustc-link-lib=static=cudnn_frontend_wrapper");

    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let header_path = PathBuf::from(&manifest_dir)
        .join("cudnn_frontend_wrapper")
        .join("include")
        .join("cudnn_frontend_wrapper.h");

    let
bindings = bindgen::Builder::default() 25 | .header(header_path.to_string_lossy()) 26 | .clang_arg(format!( 27 | "-I{}", 28 | PathBuf::from(&manifest_dir) 29 | .join("cudnn_frontend_wrapper/include") 30 | .display() 31 | )) 32 | .clang_arg("-I/usr/local/cuda/include") 33 | .rustified_enum(".*") 34 | .generate() 35 | .expect("Unable to generate bindings"); 36 | 37 | let out_path = PathBuf::from("src").join("bindings.rs"); 38 | bindings 39 | .write_to_file(&out_path) 40 | .expect("Couldn't write bindings!"); 41 | } 42 | -------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/.gitignore: -------------------------------------------------------------------------------- 1 | CMakeLists.txt.user 2 | CMakeCache.txt 3 | CMakeFiles 4 | CMakeScripts 5 | Testing 6 | Makefile 7 | cmake_install.cmake 8 | install_manifest.txt 9 | compile_commands.json 10 | CTestTestfile.cmake 11 | _deps 12 | CMakeUserPresets.json 13 | 14 | # Prerequisites 15 | *.d 16 | 17 | # Compiled Object files 18 | *.slo 19 | *.lo 20 | *.o 21 | *.obj 22 | 23 | # Precompiled Headers 24 | *.gch 25 | *.pch 26 | 27 | # Compiled Dynamic libraries 28 | *.so 29 | *.dylib 30 | *.dll 31 | 32 | # Fortran module files 33 | *.mod 34 | *.smod 35 | 36 | # Compiled Static libraries 37 | *.lai 38 | *.la 39 | *.a 40 | *.lib 41 | 42 | # Executables 43 | *.exe 44 | *.out 45 | *.app 46 | 47 | # Prerequisites 48 | *.d 49 | 50 | # Object files 51 | *.o 52 | *.ko 53 | *.obj 54 | *.elf 55 | 56 | # Linker output 57 | *.ilk 58 | *.map 59 | *.exp 60 | 61 | # Precompiled Headers 62 | *.gch 63 | *.pch 64 | 65 | # Libraries 66 | *.lib 67 | *.a 68 | *.la 69 | *.lo 70 | 71 | # Shared objects (inc. Windows DLLs) 72 | *.dll 73 | *.so 74 | *.so.* 75 | *.dylib 76 | 77 | # Executables 78 | *.exe 79 | *.out 80 | *.app 81 | *.i*86 82 | *.x86_64 83 | *.hex 84 | 85 | # Debug files 86 | *.dSYM/ 87 | *.su 88 | *.idb 89 | *.pdb 90 | 91 | # Kernel Module Compile Results 92 | *.mod* 93 | *.cmd 94 | .tmp_versions/ 95 | modules.order 96 | Module.symvers 97 | Mkfile.old 98 | dkms.conf 99 | 100 | .ccls-cache/* 101 | build 102 | -------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cudnn-frontend"] 2 | path = cudnn-frontend 3 | url = git@github.com:NVIDIA/cudnn-frontend.git 4 | -------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18) 2 | project(cudnn_frontend_wrapper VERSION 1.8.0) 3 | 4 | option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." 
OFF)
option(CUDNN_FRONTEND_BUILD_TESTS "Build tests for cudnn_frontend_wrapper" OFF)

find_package(Threads REQUIRED)
find_package(CUDAToolkit REQUIRED)

# Enable C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Import cuDNN as the CUDNN::cudnn target
include(${PROJECT_SOURCE_DIR}/cudnn-frontend/cmake/cuDNN.cmake)

set(SRC_FILES
    ${PROJECT_SOURCE_DIR}/src/batchnorm.cpp
    ${PROJECT_SOURCE_DIR}/src/conv.cpp
    ${PROJECT_SOURCE_DIR}/src/cudnn_frontend_wrapper.cpp
    ${PROJECT_SOURCE_DIR}/src/utils.cpp
)

# Build the static library
add_library(cudnn_frontend_wrapper STATIC ${SRC_FILES})

# Set compile-definition flags on the target
target_compile_definitions(cudnn_frontend_wrapper PUBLIC
    $<$<BOOL:${CUDNN_FRONTEND_SKIP_JSON_LIB}>:CUDNN_FRONTEND_SKIP_JSON_LIB>
)

# Configure include directories
target_include_directories(cudnn_frontend_wrapper PUBLIC
    ${PROJECT_SOURCE_DIR}/include
    ${PROJECT_SOURCE_DIR}/src
    ${PROJECT_SOURCE_DIR}/cudnn-frontend/include
    ${CUDAToolkit_INCLUDE_DIRS}
)

# Add the -g flag for debug builds
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
    target_compile_options(cudnn_frontend_wrapper PUBLIC -g)
endif()

# Configure link libraries
target_link_libraries(cudnn_frontend_wrapper
    Threads::Threads
    CUDA::cudart
    CUDA::cuda_driver
    CUDNN::cudnn
)

if(CUDNN_FRONTEND_BUILD_TESTS)
    include(FetchContent)
    FetchContent_Declare(
        Catch2
        GIT_REPOSITORY https://github.com/catchorg/Catch2.git
        GIT_TAG v2.13.7
    )
    FetchContent_MakeAvailable(Catch2)

    set(TEST_SRC_FILES
        ${PROJECT_SOURCE_DIR}/tests/batchnorm.cpp
        ${PROJECT_SOURCE_DIR}/tests/conv.cpp
    )

    add_executable(run_tests ${TEST_SRC_FILES})

    target_link_libraries(run_tests
        cudnn_frontend_wrapper
        Catch2::Catch2
    )

    # Register the tests
    include(CTest)
    add_test(NAME run_tests COMMAND run_tests)
endif()

# Emit a config file for ccls (optional)
file(WRITE ${CMAKE_CURRENT_SOURCE_DIR}/.ccls "
%compile_commands.json=build/compile_commands.json
%clang
-std=c++17
-I${CMAKE_CURRENT_SOURCE_DIR}/cudnn-frontend/include
")

install(TARGETS cudnn_frontend_wrapper
    ARCHIVE DESTINATION lib
    LIBRARY DESTINATION lib
    RUNTIME DESTINATION bin
)
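# Example invocation: configure with the optional Catch2 test suite enabled,
#   cmake -DCUDNN_FRONTEND_BUILD_TESTS=ON ..
# which builds the `run_tests` binary registered with CTest above.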
-------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/README.md: --------------------------------------------------------------------------------
Seems to require cuDNN 9 or later.

# build
```bash
mkdir build
cd build
cmake ..
cmake --build . -j32
```

# debug
```bash
cmake -DCMAKE_BUILD_TYPE=Debug ..
```
-------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/src/i_graph_desc.h: --------------------------------------------------------------------------------
#pragma once

#include <vector>

#include "cudnn_frontend.h"
#include "../include/cudnn_frontend_wrapper.h"

namespace fe = cudnn_frontend;

class IGraphDescriptor {
protected:
    fe::graph::Graph graph;
    std::vector<fe::HeurMode_t> heur_mode = {fe::HeurMode_t::A};

public:
    virtual ~IGraphDescriptor() = default;

    virtual CudnnFrontendError_t build_and_check_graph(cudnnHandle_t* handle, bool build_all_plans) {
        auto err = graph.validate();
        if (!err.is_good()) { return handle_error("Graph validation", err); }

        err = graph.build_operation_graph(*handle);
        if (!err.is_good()) { return handle_error("Graph build operation graph", err); }

        err = graph.create_execution_plans(heur_mode);
        if (!err.is_good()) { return handle_error("Graph create execution plans", err); }

        err = graph.check_support(*handle);
        if (!err.is_good()) { return handle_error("Graph check support", err); }

        if (build_all_plans) {
            err = graph.build_plans(*handle, fe::BuildPlanPolicy_t::ALL);
        } else {
            err = graph.build_plans(*handle);
        }
        if (!err.is_good()) { return handle_error("Graph build plans", err); }

        return CudnnFrontendError_t::SUCCESS;
    }

    virtual CudnnFrontendError_t get_workspace_size(int64_t* workspace_size) {
        auto err = graph.get_workspace_size(*workspace_size);
        if (!err.is_good()) {
            return handle_error("Graph get workspace size", err);
        }
        return CudnnFrontendError_t::SUCCESS;
    }

    virtual CudnnFrontendError_t execute_graph(cudnnHandle_t* handle,
                                               std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> &variant_pack,
                                               void* workspace) {
        auto err = graph.execute(*handle, variant_pack, workspace);
        if (!err.is_good()) {
            return handle_error("Graph execute failed", err);
        }
        return CudnnFrontendError_t::SUCCESS;
    }

protected:
    CudnnFrontendError_t handle_error(const std::string &msg, const fe::error_t &err) {
        std::cout << msg << std::endl;
        std::cout << const_cast<fe::error_t&>(err).get_message() << std::endl;
        return CudnnFrontendError_t::FAILURE;
    }

};
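// Intended call sequence for IGraphDescriptor subclasses: call
// build_and_check_graph() once per graph, then get_workspace_size() to size a
// device scratch allocation, then execute_graph() per invocation with a
// variant_pack mapping each graph tensor's Tensor_attributes handle to its
// device pointer.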
cudnn_frontend::DataType_t::HALF; 20 | case DATA_TYPE_FLOAT: 21 | return cudnn_frontend::DataType_t::FLOAT; 22 | case DATA_TYPE_DOUBLE: 23 | return cudnn_frontend::DataType_t::DOUBLE; 24 | default: 25 | std::stringstream err_msg; 26 | err_msg << "Invalid data type: " << data_type; 27 | throw std::runtime_error(err_msg.str()); 28 | } 29 | } 30 | 31 | cudnn_frontend::graph::Tensor_attributes get_tensor_attributes(std::vector<int64_t> shape, 32 | std::vector<int64_t> strides, 33 | CudnnFrontendDataType_t data_type) { 34 | auto type = get_data_type(data_type); 35 | return cudnn_frontend::graph::Tensor_attributes() 36 | .set_dim(shape) 37 | .set_stride(strides) 38 | .set_data_type(type); 39 | } 40 | 41 | cudnn_frontend::graph::Tensor_attributes get_tensor_attributes_without_type(std::vector<int64_t> shape, 42 | std::vector<int64_t> strides) { 43 | return cudnn_frontend::graph::Tensor_attributes() 44 | .set_dim(shape) 45 | .set_stride(strides); 46 | } 47 | -------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/cudnn_frontend_wrapper/src/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cstdint> 4 | #include <vector> 5 | 6 | #include "cudnn_frontend.h" 7 | #include "cudnn_frontend_wrapper.h" 8 | 9 | std::vector<int64_t> from_shape(size_t n, int64_t dims[8]); 10 | 11 | cudnn_frontend::DataType_t get_data_type(CudnnFrontendDataType_t data_type); 12 | 13 | cudnn_frontend::graph::Tensor_attributes get_tensor_attributes(std::vector<int64_t> shape, 14 | std::vector<int64_t> strides, 15 | CudnnFrontendDataType_t data_type); 16 | 17 | cudnn_frontend::graph::Tensor_attributes get_tensor_attributes_without_type(std::vector<int64_t> shape, 18 | std::vector<int64_t> strides); 19 | -------------------------------------------------------------------------------- /zenu-cudnn-frontend-wrapper-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![expect(non_upper_case_globals)] 2 | #![expect(non_camel_case_types)] 3 | #![expect(non_snake_case)] 4 | #![expect(clippy::unreadable_literal)] 5 | #![expect(clippy::pub_underscore_fields)] 6 | 7 | include!("bindings.rs"); 8 | -------------------------------------------------------------------------------- /zenu-cudnn-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-cudnn-sys" 3 | version = "0.1.0" 4 | edition = "2021" 5 | repository = "https://github.com/bokutotu/zenu" 6 | license = "MIT" 7 | description = "Rust bindings for cuDNN" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | libc = "0.2.151" 13 | 14 | [build-dependencies] 15 | bindgen = "0.69.4" 16 | zenu-cuda-config = { path = "../zenu-cuda-config", version = "0.1.0"} 17 | 18 | [lints] 19 | workspace = true 20 | -------------------------------------------------------------------------------- /zenu-cudnn-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | 3 | use zenu_cuda_config::find_cuda; 4 | 5 | fn main() { 6 | for path in find_cuda() { 7 | println!("cargo:rustc-link-search=native={}", path.display()); 8 | } 9 | 10 | println!("cargo:rustc-link-lib=dylib=cudart"); 11 | println!("cargo:rerun-if-changed=build.rs"); 12 | println!("cargo:rustc-link-lib=dylib=cudnn"); 13 | 14 | let bindings = bindgen::Builder::default() 15 | .ctypes_prefix("::libc") 16 | .allowlist_function("cu.*") 17 |
.allowlist_var("CUDNN.*") 18 | .allowlist_type("[Cc][Uu].*") 19 | .default_alias_style(bindgen::AliasVariation::TypeAlias) 20 | .rustified_non_exhaustive_enum("cudnn[A-Za-z]+_t") 21 | .rustified_non_exhaustive_enum("cuda.*") 22 | .derive_default(true) 23 | .derive_eq(true) 24 | .derive_hash(true) 25 | .derive_ord(true) 26 | // .parse_callbacks(Box::new(bindgen::CargoCallbacks)) 27 | // .rustfmt_bindings(true) 28 | .clang_arg("-I") 29 | .clang_arg("/usr/local/cuda/include".to_string()) 30 | .header("wrapper.h") 31 | .generate() 32 | .expect("Unable to generate bindings"); 33 | 34 | bindings 35 | .write_to_file("./src/bindings.rs") 36 | .expect("Unable to write"); 37 | } 38 | -------------------------------------------------------------------------------- /zenu-cudnn-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![expect(non_upper_case_globals)] 2 | #![expect(non_camel_case_types)] 3 | #![expect(non_snake_case)] 4 | #![expect(clippy::unreadable_literal)] 5 | #![expect(clippy::pub_underscore_fields)] 6 | 7 | // include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 8 | include!("bindings.rs"); 9 | -------------------------------------------------------------------------------- /zenu-cudnn-sys/wrapper.h: -------------------------------------------------------------------------------- 1 | #include "cudnn.h" 2 | -------------------------------------------------------------------------------- /zenu-layer/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-layer" 3 | version = "0.1.1" 4 | edition = "2021" 5 | description = "A simple neural network layer library." 6 | license = "MIT" 7 | repository = "https://github.com/bokutotu/zenu" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | zenu-matrix = { path = "../zenu-matrix", version = "0.1.2" } 13 | zenu-autograd = { path = "../zenu-autograd", version = "0.1.2" } 14 | 15 | rand = "0.8.5" 16 | rand_distr = "0.4.3" 17 | 18 | [dev-dependencies] 19 | zenu-test = { path = "../zenu-test" } 20 | 21 | serde_json = "1" 22 | 23 | [lints] 24 | workspace = true 25 | 26 | [profile.bench] 27 | debug = true 28 | 29 | [features] 30 | nvidia = ["zenu-matrix/nvidia", "zenu-autograd/nvidia"] 31 | -------------------------------------------------------------------------------- /zenu-layer/README.md: -------------------------------------------------------------------------------- 1 | # ZeNu Layer 2 | 3 | ZeNu Layer is a collection of neural network layers implemented in Rust. It provides building blocks for constructing neural networks and integrates with the ZeNu deep learning library. 
4 | 5 | ## Features 6 | 7 | - Various layer types, including fully connected (linear) layers 8 | - Layer parameter initialization 9 | - Forward pass computation 10 | - Integration with ZeNu Autograd for automatic differentiation 11 | 12 | ## Getting Started 13 | 14 | To use ZeNu Layer in your Rust project, add the following to your `Cargo.toml` file: 15 | 16 | ```toml 17 | [dependencies] 18 | zenu-layer = "0.1.0" 19 | ``` 20 | 21 | Here's a simple example of using a linear layer from ZeNu Layer: 22 | 23 | ```rust 24 | use zenu_autograd::creator::from_vec::from_vec; 25 | use zenu_layer::layers::linear::Linear; 26 | use zenu_layer::Layer; 27 | 28 | fn main() { 29 | // Create a new linear layer with input dimension 3 and output dimension 2 30 | let mut linear_layer = Linear::new(3, 2); 31 | 32 | // Initialize the layer parameters with a random seed 33 | linear_layer.init_parameters(Some(42)); 34 | 35 | // Create input data as a Variable 36 | let input = from_vec(vec![1., 2., 3.], [1, 3]); 37 | 38 | // Perform a forward pass through the layer 39 | let output = linear_layer.call(input); 40 | 41 | // Access the layer parameters 42 | let parameters = linear_layer.parameters(); 43 | } 44 | ``` 45 | 46 | For more details and examples, please refer to the [documentation](https://docs.rs/zenu-layer). 47 | 48 | ## License 49 | 50 | ZeNu Layer is licensed under the [MIT License](LICENSE). 51 | -------------------------------------------------------------------------------- /zenu-layer/src/layers/dropout.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, collections::HashMap, rc::Rc}; 2 | 3 | use zenu_autograd::{ 4 | nn::dropout::{dropout, DropoutConfig}, 5 | Variable, 6 | }; 7 | use zenu_matrix::{ 8 | device::Device, 9 | dim::{DimDyn, DimTrait}, 10 | num::Num, 11 | }; 12 | 13 | use crate::{Module, Parameters}; 14 | 15 | pub struct Dropout<T: Num> { 16 | config: DropoutConfig<T>, 17 | input_shape: Option<Rc<RefCell<DimDyn>>>, 18 | ratio: f32, 19 | } 20 | 21 | impl<T: Num, D: Device> Module<T, D> for Dropout<T> { 22 | type Input = Variable<T, D>; 23 | type Output = Variable<T, D>; 24 | fn call(&self, input: Variable<T, D>) -> Variable<T, D> { 25 | if self.input_shape.as_ref().unwrap().borrow().slice() != input.get_shape().slice() { 26 | todo!(); 27 | } 28 | dropout(input, self.ratio, Some(self.config.clone())) 29 | } 30 | } 31 | 32 | impl<T: Num> Dropout<T> { 33 | #[must_use] 34 | pub fn new(rate: f32) -> Self { 35 | let config = DropoutConfig::new(rate); 36 | Self { 37 | config, 38 | input_shape: None, 39 | ratio: rate, 40 | } 41 | } 42 | 43 | pub fn gpu_init(&self, shape: DimDyn) { 44 | self.config.gpu_init(shape); 45 | } 46 | } 47 | 48 | impl<T: Num, D: Device> Parameters<T, D> for Dropout<T> { 49 | fn weights(&self) -> HashMap<String, Variable<T, D>> { 50 | HashMap::new() 51 | } 52 | 53 | fn biases(&self) -> HashMap<String, Variable<T, D>> { 54 | HashMap::new() 55 | } 56 | } 57 |
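// Note (added comment, not in the original file): `call` above unwraps `input_shape`,
// which `new` initializes to `None`, and the shape-mismatch branch is still a `todo!()`;
// `gpu_init` only initializes the dropout config, so the shape-tracking path is
// evidently unfinished and the first forward pass relies on the shape being set elsewhere.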
-------------------------------------------------------------------------------- /zenu-layer/src/layers/max_pool_2d.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use zenu_autograd::{ 4 | nn::pool2d::{max_pool_2d, MaxPool2dConfig}, 5 | Variable, 6 | }; 7 | use zenu_matrix::{device::Device, num::Num}; 8 | 9 | use crate::{Module, Parameters}; 10 | 11 | pub struct MaxPool2d<T: Num> { 12 | stride: (usize, usize), 13 | kernel_size: (usize, usize), 14 | pad: (usize, usize), 15 | config: MaxPool2dConfig<T>, 16 | } 17 | 18 | impl<T: Num, D: Device> Parameters<T, D> for MaxPool2d<T> { 19 | fn weights(&self) -> HashMap<String, Variable<T, D>> { 20 | HashMap::new() 21 | } 22 | 23 | fn biases(&self) -> HashMap<String, Variable<T, D>> { 24 | HashMap::new() 25 | } 26 | } 27 | 28 | impl<T: Num> MaxPool2d<T> { 29 | #[must_use] 30 | pub fn new(kernel_size: (usize, usize), stride: (usize, usize), pad: (usize, usize)) -> Self { 31 | Self { 32 | stride, 33 | kernel_size, 34 | pad, 35 | config: MaxPool2dConfig::default(), 36 | } 37 | } 38 | } 39 | 40 | impl<T: Num, D: Device> Module<T, D> for MaxPool2d<T> { 41 | type Input = Variable<T, D>; 42 | type Output = Variable<T, D>; 43 | fn call(&self, input: Variable<T, D>) -> Variable<T, D> { 44 | max_pool_2d( 45 | input, 46 | self.kernel_size, 47 | self.stride, 48 | self.pad, 49 | self.config.clone(), 50 | ) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /zenu-layer/src/layers/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod batch_norm_2d; 2 | pub mod conv2d; 3 | pub mod dropout; 4 | pub mod linear; 5 | pub mod max_pool_2d; 6 | pub mod rnn; 7 | -------------------------------------------------------------------------------- /zenu-layer/src/layers/rnn/gru.rs: -------------------------------------------------------------------------------- 1 | use rand_distr::{Distribution, StandardNormal}; 2 | use zenu_autograd::{ 3 | nn::rnns::{gru::naive::gru_naive, weights::GRUCell}, 4 | Variable, 5 | }; 6 | 7 | #[cfg(feature = "nvidia")] 8 | use zenu_autograd::nn::rnns::gru::cudnn::gru_cudnn; 9 | 10 | use zenu_matrix::{device::Device, num::Num}; 11 | 12 | use crate::{Module, ModuleParameters, Parameters}; 13 | 14 | use super::{builder::RNNSLayerBuilder, inner::RNNInner}; 15 | 16 | pub struct GRUInput<T: Num, D: Device> { 17 | pub x: Variable<T, D>, 18 | pub hx: Variable<T, D>, 19 | } 20 | 21 | impl<T: Num, D: Device> ModuleParameters<T, D> for GRUInput<T, D> {} 22 | 23 | impl<T: Num, D: Device> RNNInner<T, D, GRUCell> { 24 | fn forward(&self, input: GRUInput<T, D>) -> Variable<T, D> { 25 | #[cfg(feature = "nvidia")] 26 | if self.is_cudnn { 27 | let desc = self.desc.as_ref().unwrap(); 28 | let weights = self.cudnn_weights.as_ref().unwrap(); 29 | 30 | let out = gru_cudnn( 31 | desc.clone(), 32 | input.x.to(), 33 | Some(input.hx.to()), 34 | weights.to(), 35 | self.is_training, 36 | ); 37 | 38 | return out.y.to(); 39 | } 40 | 41 | gru_naive( 42 | input.x, 43 | input.hx, 44 | self.weights.as_ref().unwrap(), 45 | self.is_bidirectional, 46 | ) 47 | } 48 | } 49 | 50 | pub struct GRU<T: Num, D: Device>(RNNInner<T, D, GRUCell>); 51 | 52 | impl<T: Num, D: Device> Parameters<T, D> for GRU<T, D> { 53 | fn weights(&self) -> std::collections::HashMap<String, Variable<T, D>> { 54 | self.0.weights() 55 | } 56 | 57 | fn biases(&self) -> std::collections::HashMap<String, Variable<T, D>> { 58 | self.0.biases() 59 | } 60 | 61 | fn load_parameters(&mut self, parameters: std::collections::HashMap<String, Variable<T, D>>) { 62 | self.0.load_parameters(parameters); 63 | } 64 | } 65 | 66 | impl<T: Num, D: Device> Module<T, D> for GRU<T, D> { 67 | type Input = GRUInput<T, D>; 68 | type Output = Variable<T, D>; 69 | 70 | fn call(&self, input: Self::Input) -> Self::Output { 71 | self.0.forward(input) 72 | } 73 | } 74 | 75 | pub type GRUBuilder<T, D> = RNNSLayerBuilder<T, D, GRUCell>; 76 | 77 | impl<T: Num, D: Device> RNNSLayerBuilder<T, D, GRUCell> 78 | where 79 | StandardNormal: Distribution<T>, 80 | { 81 | pub fn build_gru(self) -> GRU<T, D> { 82 | GRU(self.build_inner()) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /zenu-layer/src/layers/rnn/lstm.rs: -------------------------------------------------------------------------------- 1 | use rand_distr::{Distribution, StandardNormal}; 2 | use zenu_autograd::{ 3 | nn::rnns::{lstm::naive::lstm_naive, weights::LSTMCell}, 4 | Variable, 5 | }; 6 | 7 | #[cfg(feature = "nvidia")] 8 | use zenu_autograd::nn::rnns::lstm::cudnn::lstm_cudnn; 9 | 10 | use zenu_matrix::{device::Device, num::Num}; 11 | 12 | use crate::{Module, ModuleParameters,
Parameters}; 13 | 14 | use super::{builder::RNNSLayerBuilder, inner::RNNInner}; 15 | 16 | pub struct LSTMInput<T: Num, D: Device> { 17 | pub x: Variable<T, D>, 18 | pub hx: Variable<T, D>, 19 | pub cx: Variable<T, D>, 20 | } 21 | 22 | impl<T: Num, D: Device> ModuleParameters<T, D> for LSTMInput<T, D> {} 23 | 24 | impl<T: Num, D: Device> RNNInner<T, D, LSTMCell> { 25 | fn forward(&self, input: LSTMInput<T, D>) -> Variable<T, D> { 26 | #[cfg(feature = "nvidia")] 27 | if self.is_cudnn { 28 | let desc = self.desc.as_ref().unwrap(); 29 | let weights = self.cudnn_weights.as_ref().unwrap(); 30 | 31 | let out = lstm_cudnn( 32 | desc.clone(), 33 | input.x.to(), 34 | Some(input.hx.to()), 35 | Some(input.cx.to()), 36 | weights.to(), 37 | self.is_training, 38 | ); 39 | 40 | return out.to(); 41 | } 42 | 43 | lstm_naive( 44 | input.x, 45 | input.hx, 46 | input.cx, 47 | self.weights.as_ref().unwrap(), 48 | self.is_bidirectional, 49 | ) 50 | } 51 | } 52 | 53 | pub struct LSTM<T: Num, D: Device>(RNNInner<T, D, LSTMCell>); 54 | 55 | impl<T: Num, D: Device> Parameters<T, D> for LSTM<T, D> { 56 | fn weights(&self) -> std::collections::HashMap<String, Variable<T, D>> { 57 | self.0.weights() 58 | } 59 | 60 | fn biases(&self) -> std::collections::HashMap<String, Variable<T, D>> { 61 | self.0.biases() 62 | } 63 | 64 | fn load_parameters(&mut self, parameters: std::collections::HashMap<String, Variable<T, D>>) { 65 | self.0.load_parameters(parameters); 66 | } 67 | } 68 | 69 | impl<T: Num, D: Device> Module<T, D> for LSTM<T, D> { 70 | type Input = LSTMInput<T, D>; 71 | type Output = Variable<T, D>; 72 | 73 | fn call(&self, input: Self::Input) -> Self::Output { 74 | self.0.forward(input) 75 | } 76 | } 77 | 78 | pub type LSTMBuilder<T, D> = RNNSLayerBuilder<T, D, LSTMCell>; 79 | 80 | impl<T: Num, D: Device> RNNSLayerBuilder<T, D, LSTMCell> 81 | where 82 | StandardNormal: Distribution<T>, 83 | { 84 | pub fn build_lstm(self) -> LSTM<T, D> { 85 | LSTM(self.build_inner()) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /zenu-layer/src/layers/rnn/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod gru; 3 | mod inner; 4 | mod lstm; 5 | #[expect(clippy::module_inception)] 6 | mod rnn; 7 | 8 | pub use gru::*; 9 | pub use inner::Activation; 10 | pub use lstm::*; 11 | pub use rnn::*; 12 | -------------------------------------------------------------------------------- /zenu-macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-macros" 3 | version = "0.1.0" 4 | edition = "2021" 5 | repository = "https://github.com/bokutotu/zenu" 6 | license = "MIT" 7 | description = "Procedural macros for Zenu" 8 | 9 | [dependencies] 10 | proc-macro2 = "1" 11 | quote = "1" 12 | syn = { version = "1", features = ["full"] } 13 | 14 | [dev-dependencies] 15 | zenu-test = { path = "../zenu-test" } 16 | zenu = { path = "../zenu"} 17 | serde = { version = "1.0.197", features = ["derive"] } 18 | serde_json = "1.0" 19 | rand_distr = "0.4.3" 20 | 21 | 22 | [lib] 23 | proc-macro = true 24 | 25 | [lints] 26 | workspace = true 27 | -------------------------------------------------------------------------------- /zenu-macros/tests/multi_parameter_struct.rs: -------------------------------------------------------------------------------- 1 | use rand_distr::{Distribution, StandardNormal}; 2 | use zenu::layer::layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d}; 3 | use zenu::macros::Parameters; 4 | use zenu::matrix::{ 5 | device::{cpu::Cpu, Device}, 6 | num::Num, 7 | }; 8 | use zenu_test::assert_val_eq; 9 | 10 | #[derive(Parameters)] 11 | #[parameters(num = F, device = De)] 12 | struct ConvBlock<F: Num, De: Device> { 13 | pub conv2d: Conv2d<F, De>, 14 | pub max_pool: MaxPool2d<F>, 15 | } 16 | 17 | #[derive(Parameters)] 18 | #[parameters(num
= T, device = D)] 19 | struct LinearBlock<T: Num, D: Device> { 20 | pub linear: Linear<T, D>, 21 | } 22 | 23 | #[derive(Parameters)] 24 | #[parameters(num = T, device = D)] 25 | struct ConvNet<T: Num, D: Device> { 26 | pub conv_block: ConvBlock<T, D>, 27 | pub linear_block: LinearBlock<T, D>, 28 | } 29 | 30 | impl<T: Num, D: Device> ConvNet<T, D> { 31 | fn new() -> Self 32 | where 33 | StandardNormal: Distribution<T>, 34 | { 35 | Self { 36 | conv_block: ConvBlock { 37 | conv2d: Conv2d::new(3, 3, (1, 1), (1, 1), (1, 1), (1, 1), true), 38 | max_pool: MaxPool2d::new((2, 2), (2, 2), (0, 0)), 39 | }, 40 | linear_block: LinearBlock { 41 | linear: Linear::new(2, 2, true), 42 | }, 43 | } 44 | } 45 | } 46 | 47 | #[test] 48 | fn multi_params() { 49 | use zenu::layer::Parameters; 50 | let model = ConvNet::<f32, Cpu>::new(); 51 | let conv_filter = model.conv_block.conv2d.filter.clone(); 52 | let conv_bias = model.conv_block.conv2d.bias.clone(); 53 | let conv_bias = conv_bias.unwrap(); 54 | let linear_weight = model.linear_block.linear.weight.clone(); 55 | let linear_bias = model.linear_block.linear.bias.clone(); 56 | let linear_bias = linear_bias.unwrap(); 57 | 58 | let weights = model.weights(); 59 | let biases = model.biases(); 60 | let parameters = model.parameters(); 61 | assert_eq!(weights.len(), 2); 62 | assert_eq!(biases.len(), 2); 63 | assert_eq!(parameters.len(), 4); 64 | 65 | assert_val_eq!( 66 | weights["conv_block.conv2d.conv2d.filter"].clone(), 67 | conv_filter.get_data(), 68 | 1e-6 69 | ); 70 | assert_val_eq!( 71 | biases["conv_block.conv2d.conv2d.bias"].clone(), 72 | conv_bias.get_data(), 73 | 1e-6 74 | ); 75 | assert_val_eq!( 76 | weights["linear_block.linear.linear.weight"].clone(), 77 | linear_weight.get_data(), 78 | 1e-6 79 | ); 80 | assert_val_eq!( 81 | biases["linear_block.linear.linear.bias"].clone(), 82 | linear_bias.get_data(), 83 | 1e-6 84 | ); 85 | } 86 | 87 | #[test] 88 | fn test_load_parameters_convnet() { 89 | use zenu::layer::Parameters; 90 | let model = ConvNet::<f32, Cpu>::new(); 91 | let parameters = model.parameters(); 92 | 93 | let mut new_model = ConvNet::<f32, Cpu>::new(); 94 | 95 | new_model.load_parameters(parameters.clone()); 96 | 97 | let new_parameters = new_model.parameters(); 98 | 99 | for (key, value) in &parameters { 100 | assert_val_eq!( 101 | value.clone(), 102 | new_parameters[key].clone().get_as_ref(), 103 | 1e-6 104 | ); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /zenu-macros/tests/small_case.rs: -------------------------------------------------------------------------------- 1 | use zenu::layer::layers::linear::Linear; 2 | use zenu::macros::Parameters as ParametersDerive; 3 | use zenu::matrix::{ 4 | device::{cpu::Cpu, Device}, 5 | matrix::Matrix, 6 | }; 7 | use zenu_test::assert_val_eq; 8 | 9 | #[derive(ParametersDerive)] 10 | #[parameters(num = f32, device = D)] 11 | pub struct Hoge<D> 12 | where 13 | D: Device, 14 | { 15 | pub linear: Linear<f32, D>, 16 | } 17 | 18 | #[test] 19 | fn small_net() { 20 | use zenu::layer::Parameters; 21 | let hoge = Hoge::<Cpu> { 22 | linear: Linear::new(2, 2, true), 23 | }; 24 | 25 | let weights = hoge.weights(); 26 | let biases = hoge.biases(); 27 | let parameters = hoge.parameters(); 28 | 29 | assert_eq!(weights.len(), 1); 30 | assert_eq!(biases.len(), 1); 31 | assert_eq!(parameters.len(), 2); 32 | 33 | assert_val_eq!( 34 | weights["linear.linear.weight"].clone(), 35 | hoge.linear.weight.get_data(), 36 | 1e-4 37 | ); 38 | 39 | let linear_bias = hoge.linear.bias.clone().unwrap(); 40 | let linear_bias = linear_bias.get_data(); 41 | 42 | assert_val_eq!( 43 |
biases["linear.linear.bias"].clone(), 44 | linear_bias.clone(), 45 | 1e-4 46 | ); 47 | 48 | assert_val_eq!( 49 | parameters["linear.linear.weight"].clone(), 50 | hoge.linear.weight.get_data(), 51 | 1e-4 52 | ); 53 | assert_val_eq!(parameters["linear.linear.bias"].clone(), linear_bias, 1e-4); 54 | } 55 | 56 | #[test] 57 | fn test_load_parameters() { 58 | use zenu::layer::Parameters; 59 | let base_model = Hoge:: { 60 | linear: Linear::new(2, 2, true), 61 | }; 62 | 63 | let base_model_parameters = base_model.parameters(); 64 | 65 | let mut new_model = Hoge:: { 66 | linear: Linear::new(2, 2, true), 67 | }; 68 | 69 | let new_model_weight = new_model.linear.weight.get_as_ref(); 70 | 71 | println!("{:?}", base_model_parameters.keys()); 72 | 73 | println!( 74 | "new_model.parameters().keys(): {:?}", 75 | new_model.parameters().keys() 76 | ); 77 | 78 | new_model 79 | .linear 80 | .weight 81 | .get_as_mut() 82 | .copy_from(&Matrix::zeros_like(&new_model_weight)); 83 | 84 | let new_model_bias = new_model.linear.bias.clone().unwrap().get_as_ref(); 85 | new_model 86 | .linear 87 | .bias 88 | .clone() 89 | .unwrap() 90 | .get_as_mut() 91 | .copy_from(&Matrix::zeros_like(&new_model_bias)); 92 | 93 | new_model.load_parameters(base_model_parameters); 94 | } 95 | -------------------------------------------------------------------------------- /zenu-matrix/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu-matrix" 3 | version = "0.1.2" 4 | edition = "2021" 5 | description = "Matrix library for ZeNu" 6 | license = "MIT" 7 | repository = "https://github.com/bokutotu/zenu" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | cblas = "0.4.0" 13 | openblas-src = { version = "0.10.8", features = ["system", "cblas"] } 14 | rand = "0.8.5" 15 | rand_distr = "0.4.3" 16 | serde = { version = "1.0.197", features = ["derive"] } 17 | libc = "0.2" 18 | once_cell = "1.19.0" 19 | num-traits = "0.2.19" 20 | 21 | zenu-cuda = { path = "../zenu-cuda", optional = true, version = "0.1.0" } 22 | 23 | [features] 24 | nvidia = ["zenu-cuda"] 25 | 26 | [dev-dependencies] 27 | itertools = { version = "0.10.0", default-features = false, features = ["use_std"] } 28 | criterion = "0.5.1" 29 | serde_json = "1.0.114" 30 | zenu-test = { path = "../zenu-test" } 31 | 32 | [[bench]] 33 | name = "copy_from_im2col_way" 34 | harness = false 35 | 36 | [[bench]] 37 | name = "copy_from_all_matrix" 38 | harness = false 39 | 40 | [[bench]] 41 | name = "transpose_reshape_im2col" 42 | harness = false 43 | 44 | [[bench]] 45 | name = "im2col_function" 46 | harness = false 47 | 48 | [profile.bench] 49 | debug = true 50 | 51 | # TODO: integrate to workspace level 52 | [lints.clippy] 53 | print_stdout = "deny" 54 | dbg_macro = "deny" 55 | -------------------------------------------------------------------------------- /zenu-matrix/README.md: -------------------------------------------------------------------------------- 1 | # ZeNu Matrix 2 | 3 | ZeNu Matrix is a high-performance linear algebra library for Rust, designed to provide efficient matrix operations and various utilities for working with matrices. Whether you are building complex machine learning models or performing scientific computations, ZeNu Matrix offers the tools you need. 4 | 5 | ## Features 6 | 7 | - **Comprehensive Matrix Operations**: Create, index, and slice matrices with ease. 
8 | - **Element-wise Operations**: Perform operations on individual elements or entire matrices. 9 | - **Efficient Matrix Multiplication (GEMM)**: Utilize optimized routines for matrix multiplication. 10 | - **Matrix Transposition**: Quickly transpose matrices as needed. 11 | - **Broadcasting**: Seamlessly broadcast operations over matrices of different shapes. 12 | - **Random Matrix Generation**: Generate matrices with random values for testing and initialization. 13 | - **BLAS Integration**: Leverage BLAS for optimized performance on supported hardware. 14 | - **CUDA Support**: Accelerate computations using NVIDIA GPUs with CUDA integration. 15 | 16 | ## Getting Started 17 | 18 | To start using ZeNu Matrix, add it to your `Cargo.toml`: 19 | 20 | ```toml 21 | [dependencies] 22 | zenu-matrix = "0.1.1" 23 | ``` 24 | 25 | ### Example 26 | 27 | Here's a simple example of using ZeNu Matrix: 28 | 29 | ```rust 30 | use zenu_matrix::{ 31 | matrix::{IndexItem, OwnedMatrix}, 32 | matrix_impl::OwnedMatrixDyn, 33 | operation::asum::Asum, 34 | }; 35 | 36 | fn main() { 37 | let a = OwnedMatrixDyn::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]); 38 | let b = OwnedMatrixDyn::from_vec(vec![7., 8., 9., 10., 11., 12.], [3, 2]); 39 | let c = a.clone() * b.clone(); 40 | 41 | assert_eq!(c.index_item([0, 0]), 58.); 42 | assert_eq!(c.index_item([0, 1]), 64.); 43 | assert_eq!(c.index_item([1, 0]), 139.); 44 | assert_eq!(c.index_item([1, 1]), 154.); 45 | } 46 | ``` 47 | 48 | For more details and examples, please refer to the [documentation](https://docs.rs/zenu-matrix). 49 | 50 | ## License 51 | 52 | ZeNu Matrix is licensed under the [MIT License](LICENSE). 53 | -------------------------------------------------------------------------------- /zenu-matrix/benches/copy_from_all_matrix.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 2 | use zenu_matrix::{ 3 | constructor::zeros::Zeros, 4 | dim::DimDyn, 5 | matrix::{ToViewMatrix, ToViewMutMatrix}, 6 | matrix_impl::{Matrix, OwnedMatrixDyn}, 7 | memory_impl::{ViewMem, ViewMutMem}, 8 | operation::copy_from::CopyFrom, 9 | }; 10 | 11 | fn copy_from_all_matrix(mut a: Matrix<ViewMutMem<f32>, DimDyn>, b: Matrix<ViewMem<f32>, DimDyn>) { 12 | a.copy_from(&b); 13 | } 14 | 15 | fn copy_from_all_matrix_(c: &mut Criterion) { 16 | let b = black_box(OwnedMatrixDyn::zeros([32, 16, 128, 128])); 17 | let mut a = black_box(OwnedMatrixDyn::zeros([32, 16, 128, 128])); 18 | 19 | c.bench_function("copy_from_all_matrix", |b_| { 20 | b_.iter(|| copy_from_all_matrix(a.to_view_mut(), b.to_view())) 21 | }); 22 | } 23 | 24 | criterion_group!(benches, copy_from_all_matrix_); 25 | criterion_main!(benches); 26 | -------------------------------------------------------------------------------- /zenu-matrix/benches/copy_from_im2col_way.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 2 | use zenu_matrix::{ 3 | constructor::zeros::Zeros, 4 | dim::DimDyn, 5 | matrix::{AsPtr, MatrixSliceDyn, MatrixSliceMutDyn, ToViewMatrix, ToViewMutMatrix}, 6 | matrix_impl::{Matrix, OwnedMatrixDyn}, 7 | memory_impl::{ViewMem, ViewMutMem}, 8 | operation::copy_from::CopyFrom, 9 | slice_dynamic, 10 | }; 11 | 12 | fn copy_from_( 13 | mut a: Matrix<ViewMutMem<f32>, DimDyn>, 14 | b: Matrix<ViewMem<f32>, DimDyn>, 15 | kh: usize, 16 | kw: usize, 17 | oh: usize, 18 | ow: usize, 19 | sh: usize, 20 | sw: usize, 21 | ) -> *const f32 { 22 | for j in 0..kh { 23 | let
j_lim = j + sh * oh; 24 | for i in 0..kw { 25 | let i_lim = i + sw * ow; 26 | let mut a = a.slice_mut_dyn(slice_dynamic!(.., .., j, i, .., ..)); 27 | let b = b.slice_dyn(slice_dynamic!(.., .., j..j_lim;sh, i..i_lim;sw)); 28 | a.copy_from(&b); 29 | } 30 | } 31 | 32 | a.as_ptr() 33 | } 34 | 35 | fn copy_from_im2col_way(c: &mut Criterion) { 36 | let b = black_box(OwnedMatrixDyn::zeros([32, 16, 128, 128])); 37 | let mut a = black_box(OwnedMatrixDyn::zeros([32, 16, 3, 3, 126, 126])); 38 | 39 | let kh = 3; 40 | let kw = 3; 41 | let sh = 1; 42 | let sw = 1; 43 | let oh = 126; 44 | let ow = 126; 45 | 46 | c.bench_function("copy_from_im2col_way", |b_| { 47 | b_.iter(|| copy_from_(a.to_view_mut(), b.to_view(), kh, kw, oh, ow, sh, sw)) 48 | }); 49 | } 50 | 51 | criterion_group!(benches, copy_from_im2col_way); 52 | criterion_main!(benches); 53 | -------------------------------------------------------------------------------- /zenu-matrix/benches/transpose_reshape_im2col.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 2 | use zenu_matrix::{ 3 | constructor::zeros::Zeros, 4 | dim::DimDyn, 5 | matrix::ToViewMatrix, 6 | matrix_impl::{Matrix, OwnedMatrixDyn}, 7 | memory_impl::ViewMem, 8 | operation::{reshape::Reshape, transpose::TransposeInplace}, 9 | }; 10 | 11 | fn transpose_reshape(a: Matrix<ViewMem<f32>, DimDyn>) { 12 | let a = a.reshape([32, 3 * 3 * 16, 126 * 126]); 13 | let a = a.transpose_by_index_inplace(&[1, 0, 2]); 14 | let _ = a.reshape([16 * 3 * 3, 32 * 126 * 126]); 15 | } 16 | 17 | fn bench(c: &mut Criterion) { 18 | let a = black_box(OwnedMatrixDyn::zeros([32, 16, 3, 3, 126, 126])); 19 | // let a = OwnedMatrixDyn::zeros([32, 16 * 3 * 3, 126 * 126]); 20 | c.bench_function("transpose_reshape_im2col", |b| { 21 | b.iter(|| transpose_reshape(a.to_view())) 22 | }); 23 | } 24 | 25 | criterion_group!(benches, bench); 26 | criterion_main!(benches); 27 | -------------------------------------------------------------------------------- /zenu-matrix/src/concat.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::Device, 3 | dim::{DimDyn, DimTrait}, 4 | index::Index0D, 5 | matrix::{Matrix, Owned, Repr}, 6 | num::Num, 7 | }; 8 | 9 | /// Matrix concatenation 10 | /// # Arguments 11 | /// * `matrix` - A slice of matrices to concatenate 12 | /// # Panics 13 | /// * If the matrices do not have the same shape 14 | /// * If the matrices are 4D 15 | pub fn concat<T: Num, R: Repr<Item = T>, S: DimTrait, D: Device>( 16 | matrix: &[Matrix<R, S, D>], 17 | ) -> Matrix<Owned<T>, DimDyn, D> { 18 | let first_shape = matrix[0].shape(); 19 | for m in matrix.iter().skip(1) { 20 | assert!( 21 | m.shape() == first_shape, 22 | "All matrices must have the same shape" 23 | ); 24 | } 25 | assert!( 26 | first_shape.len() != 4, 27 | "Concatenation of 4D matrices is not supported" 28 | ); 29 | 30 | let mut shape = DimDyn::default(); 31 | shape.push_dim(matrix.len()); 32 | for d in first_shape { 33 | shape.push_dim(d); 34 | } 35 | 36 | let mut result = Matrix::alloc(shape); 37 | 38 | for (i, m) in matrix.iter().enumerate() { 39 | let view = m.to_ref().into_dyn_dim(); 40 | result 41 | .to_ref_mut() 42 | .index_axis_mut_dyn(Index0D::new(i)) 43 | .copy_from(&view); 44 | } 45 | 46 | result 47 | } 48 |
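// Worked example (added comment, not in the original file): stacking three [2, 2]
// inputs yields a [3, 2, 2] result — `shape.push_dim(matrix.len())` prepends the new
// axis, then each input is copied into one `index_axis` slab of `result`; the
// `concat_test` module below exercises exactly this 1-D and 2-D behavior.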
D>::from_vec(vec![1., 2., 3.], [3]); 60 | let b = Matrix::, DimDyn, D>::from_vec(vec![4., 5., 6.], [3]); 61 | let c = Matrix::, DimDyn, D>::from_vec(vec![7., 8., 9.], [3]); 62 | 63 | let result = super::concat(&[a, b, c]); 64 | 65 | let ans = Matrix::, DimDyn, D>::from_vec( 66 | vec![1., 2., 3., 4., 5., 6., 7., 8., 9.], 67 | [3, 3], 68 | ); 69 | 70 | let diff = result - ans; 71 | assert_eq!(diff.asum(), 0.); 72 | } 73 | #[test] 74 | fn cal_1d_cpu() { 75 | cat_1d::(); 76 | } 77 | #[cfg(feature = "nvidia")] 78 | #[test] 79 | fn cal_1d_gpu() { 80 | cat_1d::(); 81 | } 82 | 83 | fn cal_2d() { 84 | let a = Matrix::, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [2, 2]); 85 | let b = Matrix::, DimDyn, D>::from_vec(vec![5., 6., 7., 8.], [2, 2]); 86 | let c = Matrix::, DimDyn, D>::from_vec(vec![9., 10., 11., 12.], [2, 2]); 87 | let result = super::concat(&[a, b, c]); 88 | 89 | let ans = Matrix::, DimDyn, D>::from_vec( 90 | vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.], 91 | [3, 2, 2], 92 | ); 93 | 94 | let diff = result - ans; 95 | assert_eq!(diff.asum(), 0.); 96 | } 97 | #[test] 98 | fn cal_2d_cpu() { 99 | cal_2d::(); 100 | } 101 | #[cfg(feature = "nvidia")] 102 | #[test] 103 | fn cal_2d_gpu() { 104 | cal_2d::(); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /zenu-matrix/src/constructor/alloc.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::DeviceBase, 3 | dim::{default_stride, DimTrait}, 4 | matrix::{Matrix, Owned, Ptr, Repr}, 5 | num::Num, 6 | }; 7 | 8 | impl Matrix, S, D> 9 | where 10 | T: Num, 11 | D: DeviceBase, 12 | S: DimTrait, 13 | { 14 | #[expect(clippy::missing_panics_doc)] 15 | pub fn alloc>(shape: I) -> Self { 16 | let shape = shape.into(); 17 | let num_elm = shape.num_elm(); 18 | let bytes = num_elm * std::mem::size_of::(); 19 | 20 | let ptr = Ptr::new(D::alloc(bytes).unwrap().cast(), num_elm, 0); 21 | 22 | let stride = default_stride(shape); 23 | Matrix::new(ptr, shape, stride) 24 | } 25 | 26 | pub fn alloc_like>(mat: &Matrix) -> Self { 27 | let shape = mat.shape(); 28 | Self::alloc(shape) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /zenu-matrix/src/constructor/from_vec.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::DeviceBase, 3 | dim::{default_stride, DimTrait}, 4 | matrix::{Matrix, Owned, Ptr}, 5 | num::Num, 6 | }; 7 | 8 | impl Matrix, S, D> 9 | where 10 | T: Num, 11 | D: DeviceBase, 12 | S: DimTrait, 13 | { 14 | #[expect(clippy::missing_panics_doc)] 15 | pub fn from_vec>(vec: Vec, shape: I) -> Self { 16 | let shape = shape.into(); 17 | assert!( 18 | vec.len() == shape.num_elm(), 19 | "Invalid Shape, vec.len() = {}, shape.num_elm() = {}", 20 | vec.len(), 21 | shape.num_elm() 22 | ); 23 | 24 | let len = vec.len(); 25 | 26 | let ptr = Ptr::new(D::from_vec(vec), len, 0); 27 | 28 | let stride = default_stride(shape); 29 | Matrix::new(ptr, shape, stride) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /zenu-matrix/src/constructor/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod alloc; 2 | pub mod from_vec; 3 | pub mod ones; 4 | pub mod rand; 5 | pub mod zeros; 6 | -------------------------------------------------------------------------------- /zenu-matrix/src/constructor/ones.rs: 
-------------------------------------------------------------------------------- /zenu-matrix/src/constructor/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod alloc; 2 | pub mod from_vec; 3 | pub mod ones; 4 | pub mod rand; 5 | pub mod zeros; 6 | -------------------------------------------------------------------------------- /zenu-matrix/src/constructor/ones.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::DeviceBase, 3 | dim::DimTrait, 4 | matrix::{Matrix, Owned, Repr}, 5 | num::Num, 6 | }; 7 | 8 | impl<T, S, D> Matrix<Owned<T>, S, D> 9 | where 10 | T: Num, 11 | S: DimTrait, 12 | D: DeviceBase, 13 | { 14 | pub fn ones<I: Into<S>>(dim: I) -> Self { 15 | let dim = dim.into(); 16 | let data = vec![T::one(); dim.num_elm()]; 17 | Self::from_vec(data, dim) 18 | } 19 | 20 | pub fn ones_like<R: Repr<Item = T>>(m: &Matrix<R, S, D>) -> Self { 21 | Self::ones(m.shape()) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /zenu-matrix/src/constructor/zeros.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::DeviceBase, 3 | dim::{default_stride, DimTrait}, 4 | matrix::{Matrix, Owned, Ptr, Repr}, 5 | num::Num, 6 | }; 7 | 8 | impl<T, S, D> Matrix<Owned<T>, S, D> 9 | where 10 | T: Num, 11 | S: DimTrait, 12 | D: DeviceBase, 13 | { 14 | pub fn zeros<I: Into<S>>(dim: I) -> Self { 15 | let dim = dim.into(); 16 | let num_elm = dim.num_elm(); 17 | let ptr = D::zeros(num_elm); 18 | let ptr = Ptr::new(ptr, num_elm, 0); 19 | Matrix::new(ptr, dim, default_stride(dim)) 20 | } 21 | 22 | pub fn zeros_like<R: Repr<Item = T>>(m: &Matrix<R, S, D>) -> Self { 23 | Self::zeros(m.shape()) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /zenu-matrix/src/device/cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::{memory_pool::MemPoolError, num::Num, ZENU_MATRIX_STATE}; 4 | 5 | use super::{Device, DeviceBase}; 6 | 7 | #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)] 8 | pub struct Cpu; 9 | 10 | impl DeviceBase for Cpu { 11 | fn raw_drop_ptr<T>(ptr: *mut T) { 12 | unsafe { libc::free(ptr.cast::<::libc::c_void>()) } 13 | } 14 | 15 | fn mem_pool_drop_ptr(ptr: *mut u8) -> Result<(), MemPoolError> { 16 | let state = &ZENU_MATRIX_STATE; 17 | state.cpu.try_free(ptr) 18 | } 19 | 20 | #[expect(clippy::not_unsafe_ptr_arg_deref)] 21 | fn clone_ptr<T>(ptr: *const T, len: usize) -> *mut T { 22 | let mut vec = Vec::with_capacity(len); 23 | for i in 0..len { 24 | vec.push(unsafe { ptr.add(i).read() }); 25 | } 26 | let ptr = vec.as_mut_ptr(); 27 | std::mem::forget(vec); 28 | ptr 29 | } 30 | 31 | #[expect(clippy::not_unsafe_ptr_arg_deref)] 32 | fn assign_item<T>(ptr: *mut T, offset: usize, value: T) { 33 | unsafe { 34 | ptr.add(offset).write(value); 35 | } 36 | } 37 | 38 | #[expect(clippy::not_unsafe_ptr_arg_deref)] 39 | fn get_item<T>(ptr: *const T, offset: usize) -> T { 40 | unsafe { ptr.add(offset).read() } 41 | } 42 | 43 | fn from_vec<T>(mut vec: Vec<T>) -> *mut T { 44 | let ptr = vec.as_mut_ptr().cast::<T>(); 45 | std::mem::forget(vec); 46 | ptr 47 | } 48 | 49 | fn zeros<T: Num>(len: usize) -> *mut T { 50 | use cblas::{dscal, sscal}; 51 | let ptr = Self::alloc(len * std::mem::size_of::<T>()) 52 | .unwrap() 53 | .cast::<T>(); 54 | if T::is_f32() { 55 | let slice = unsafe { std::slice::from_raw_parts_mut(ptr.cast(), 1) }; 56 | unsafe { sscal(i32::try_from(len).unwrap(), 0.0, slice, 1) }; 57 | } else { 58 | let slice = unsafe { std::slice::from_raw_parts_mut(ptr.cast(), 1) }; 59 | unsafe { dscal(i32::try_from(len).unwrap(), 0.0, slice, 1) }; 60 | } 61 | ptr 62 | } 63 | 64 | fn raw_alloc(num_bytes: usize) -> Result<*mut u8, String> { 65 | let ptr = unsafe { libc::malloc(num_bytes) }; 66 | if ptr.is_null() { 67 | Err("null pointer".to_string()) 68 | } else { 69 | Ok(ptr.cast()) 70 | } 71 | } 72 | 73 | fn
mem_pool_alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError> { 74 | let state = &ZENU_MATRIX_STATE; 75 | state.cpu.try_alloc(num_bytes) 76 | } 77 | } 78 | 79 | impl Device for Cpu {} 80 | -------------------------------------------------------------------------------- /zenu-matrix/src/device/mod.rs: -------------------------------------------------------------------------------- 1 | use serde::Serialize; 2 | 3 | use crate::{ 4 | memory_pool::MemPoolError, 5 | nn::{ 6 | batch_norm::BatchNormalization, 7 | conv::interface::{ConvBias, ConvBkwdData, ConvBkwdFilter, ConvFwd}, 8 | dropout::Dropout, 9 | pool2d::Pool2dImpl, 10 | }, 11 | num::Num, 12 | operation::{ 13 | asum::Asum, 14 | basic_operations::{ 15 | AbsOps, AcosOps, AddOps, AsinOps, AtanOps, CosOps, CoshOps, DivOps, ExpOps, LogOps, 16 | MulOps, PowOws, SinOps, SinhOps, SqrtOps, SubOps, TanOps, TanhOps, 17 | }, 18 | clip::ClipOps, 19 | copy_from::CopyBlas, 20 | max::MaxIdx, 21 | mul::Gemm, 22 | relu::ReluOps, 23 | }, 24 | ZENU_MATRIX_STATE, 25 | }; 26 | 27 | pub mod cpu; 28 | 29 | #[cfg(feature = "nvidia")] 30 | pub mod nvidia; 31 | 32 | #[expect(clippy::module_name_repetitions)] 33 | pub trait DeviceBase: Copy + Default + Serialize + 'static { 34 | fn drop_ptr<T>(ptr: *mut T) { 35 | let state = &ZENU_MATRIX_STATE; 36 | if state.is_mem_pool_used { 37 | let result = Self::mem_pool_drop_ptr(ptr.cast()); 38 | if result.is_err() { 39 | Self::raw_drop_ptr(ptr); 40 | } 41 | } else { 42 | Self::raw_drop_ptr(ptr); 43 | } 44 | } 45 | #[expect(clippy::missing_errors_doc)] 46 | fn mem_pool_drop_ptr(ptr: *mut u8) -> Result<(), MemPoolError>; 47 | fn raw_drop_ptr<T>(ptr: *mut T); 48 | fn clone_ptr<T>(ptr: *const T, len: usize) -> *mut T; 49 | fn assign_item<T>(ptr: *mut T, offset: usize, value: T); 50 | fn get_item<T>(ptr: *const T, offset: usize) -> T; 51 | fn from_vec<T>(vec: Vec<T>) -> *mut T; 52 | fn zeros<T: Num>(len: usize) -> *mut T; 53 | #[expect(clippy::missing_errors_doc)] 54 | fn alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError> { 55 | if num_bytes == 0 { 56 | return Ok(std::ptr::null_mut()); 57 | } 58 | let state = &ZENU_MATRIX_STATE; 59 | if state.is_mem_pool_used { 60 | Self::mem_pool_alloc(num_bytes) 61 | } else { 62 | Self::raw_alloc(num_bytes).map_err(|_| MemPoolError::DeviceMallocError) 63 | } 64 | } 65 | #[expect(clippy::missing_errors_doc)] 66 | fn mem_pool_alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError>; 67 | #[expect(clippy::missing_errors_doc)] 68 | fn raw_alloc(num_bytes: usize) -> Result<*mut u8, String>; 69 | } 70 | 71 | pub trait Device: 72 | DeviceBase 73 | + CopyBlas 74 | + AddOps 75 | + SubOps 76 | + MulOps 77 | + DivOps 78 | + Asum 79 | + ClipOps 80 | + SinOps 81 | + CosOps 82 | + TanOps 83 | + AsinOps 84 | + AcosOps 85 | + AtanOps 86 | + SinhOps 87 | + CoshOps 88 | + TanhOps 89 | + AbsOps 90 | + SqrtOps 91 | + ExpOps 92 | + LogOps 93 | + MaxIdx 94 | + ReluOps 95 | + Gemm 96 | + PowOws 97 | + BatchNormalization 98 | + ConvFwd 99 | + ConvBkwdData 100 | + ConvBkwdFilter 101 | + ConvBias 102 | + Sized 103 | + Pool2dImpl 104 | + Dropout 105 | + Send 106 | + Sync 107 | + 'static 108 | { 109 | } 110 |
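// Note (added comment, not in the original file): `DeviceBase::alloc` above returns a
// null pointer for zero-byte requests, and routes through the memory pool only when
// `ZENU_MATRIX_STATE.is_mem_pool_used` is set (env var `ZENU_USE_MEMPOOL`, default "1",
// read in zenu-matrix/src/lib.rs); otherwise it falls back to `raw_alloc`.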
-------------------------------------------------------------------------------- /zenu-matrix/src/device/nvidia/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{Device, DeviceBase}; 2 | use crate::{memory_pool::MemPoolError, num::Num, ZENU_MATRIX_STATE}; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)] 6 | pub struct Nvidia; 7 | 8 | impl DeviceBase for Nvidia { 9 | fn raw_drop_ptr<T>(ptr: *mut T) { 10 | zenu_cuda::runtime::cuda_free(ptr.cast::<::libc::c_void>()).unwrap(); 11 | } 12 | 13 | fn mem_pool_drop_ptr(ptr: *mut u8) -> Result<(), MemPoolError> { 14 | let state = &ZENU_MATRIX_STATE; 15 | state.nvidia.try_free(ptr) 16 | } 17 | 18 | fn clone_ptr<T>(src: *const T, len: usize) -> *mut T { 19 | let bytes = len * std::mem::size_of::<T>(); 20 | let dst = Self::alloc(bytes).unwrap().cast::<T>(); 21 | zenu_cuda::runtime::cuda_copy( 22 | dst, 23 | src, 24 | len, 25 | zenu_cuda::runtime::ZenuCudaMemCopyKind::HostToHost, 26 | ) 27 | .unwrap(); 28 | dst 29 | } 30 | 31 | fn assign_item<T>(ptr: *mut T, offset: usize, value: T) { 32 | zenu_cuda::kernel::set_memory(ptr, offset, value); 33 | } 34 | 35 | fn get_item<T>(ptr: *const T, offset: usize) -> T { 36 | zenu_cuda::kernel::get_memory(ptr, offset) 37 | } 38 | 39 | fn from_vec<T>(mut vec: Vec<T>) -> *mut T { 40 | let ptr = Self::alloc(vec.len() * std::mem::size_of::<T>()) 41 | .unwrap() 42 | .cast::<T>(); 43 | zenu_cuda::runtime::cuda_copy( 44 | ptr, 45 | vec.as_mut_ptr(), 46 | vec.len(), 47 | zenu_cuda::runtime::ZenuCudaMemCopyKind::HostToDevice, 48 | ) 49 | .unwrap(); 50 | ptr 51 | } 52 | 53 | fn zeros<T: Num>(len: usize) -> *mut T { 54 | let bytes = len * std::mem::size_of::<T>(); 55 | let ptr = Self::alloc(bytes).unwrap().cast::<T>(); 56 | zenu_cuda::cublas::cublas_scal(len, T::zero(), ptr, 1).unwrap(); 57 | ptr 58 | } 59 | 60 | fn raw_alloc(num_bytes: usize) -> Result<*mut u8, String> { 61 | zenu_cuda::runtime::cuda_malloc_bytes(num_bytes) 62 | .map_err(|_| "cudaMalloc failed".to_string()) 63 | } 64 | 65 | fn mem_pool_alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError> { 66 | let state = &ZENU_MATRIX_STATE; 67 | state.nvidia.try_alloc(num_bytes) 68 | } 69 | } 70 | 71 | impl Device for Nvidia {} 72 | -------------------------------------------------------------------------------- /zenu-matrix/src/dim/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod dim_dyn; 2 | pub mod dim_static; 3 | 4 | pub use dim_dyn::larger_shape; 5 | pub use dim_dyn::DimDyn; 6 | pub(crate) use dim_dyn::{into_dyn, smaller_shape}; 7 | pub use dim_static::{Dim0, Dim1, Dim2, Dim3, Dim4}; 8 | 9 | use std::{ 10 | fmt::Debug, 11 | ops::{Index, IndexMut}, 12 | }; 13 | 14 | pub trait DimTrait: 15 | Index<usize, Output = usize> 16 | + IndexMut<usize> 17 | + IntoIterator<Item = usize> 18 | + Clone 19 | + Copy 20 | + Default 21 | + PartialEq 22 | + Debug 23 | + for<'a> From<&'a [usize]> 24 | + for<'a> From<&'a Self> 25 | + 'static 26 | { 27 | fn len(&self) -> usize; 28 | fn is_empty(&self) -> bool; 29 | fn is_overflow<D: DimTrait>(&self, index: D) -> bool { 30 | assert!(self.len() >= index.len(), "Dimension mismatch"); 31 | 32 | index.into_iter().zip(*self).any(|(x, y)| x >= y) 33 | } 34 | fn num_elm(&self) -> usize { 35 | self.into_iter().product() 36 | } 37 | 38 | fn slice(&self) -> &[usize]; 39 | 40 | fn is_scalar(&self) -> bool { 41 | self.len() == 0 || self.num_elm() == 1 42 | } 43 | } 44 | 45 | pub trait LessDimTrait: DimTrait { 46 | type LessDim: DimTrait; 47 | 48 | fn remove_axis(&self, axis: usize) -> Self::LessDim { 49 | let mut default = DimDyn::default(); 50 | for i in 0..self.len() { 51 | if i == axis { 52 | continue; 53 | } 54 | default.push_dim(self[i]); 55 | } 56 | Self::LessDim::from(default.slice()) 57 | } 58 | } 59 | 60 | pub trait GreaterDimTrait: DimTrait { 61 | type GreaterDim: DimTrait; 62 | } 63 | 64 | #[expect(clippy::missing_panics_doc)] 65 | pub fn cal_offset<D1: DimTrait, D2: DimTrait>(shape: D1, stride: D2) -> usize { 66 |
assert!(shape.len() == stride.len(), "Dimension mismatch"); 67 | shape.into_iter().zip(stride).map(|(x, y)| x * y).sum() 68 | } 69 | 70 | pub fn default_stride<D: DimTrait>(shape: D) -> D { 71 | let mut stride = shape; 72 | let n = shape.len(); 73 | 74 | if n == 0 { 75 | return stride; 76 | } 77 | 78 | if n == 1 { 79 | stride[0] = 1; 80 | return stride; 81 | } 82 | 83 | // The stride of the last dimension is always 1 84 | stride[n - 1] = 1; 85 | 86 | // Compute the remaining dimensions from back to front (e.g. shape [2, 3, 4] -> stride [12, 4, 1]) 87 | for i in (0..n - 1).rev() { 88 | stride[i] = stride[i + 1] * shape[i + 1]; 89 | } 90 | 91 | stride 92 | } 93 | -------------------------------------------------------------------------------- /zenu-matrix/src/index/index_dyn_impl.rs: -------------------------------------------------------------------------------- 1 | use crate::{dim::DimTrait, shape_stride::ShapeStride}; 2 | 3 | use super::IndexAxisTrait; 4 | 5 | #[derive(Clone, Copy, Debug)] 6 | pub struct Index { 7 | axis: usize, 8 | index: usize, 9 | } 10 | 11 | impl Index { 12 | #[must_use] 13 | pub fn new(axis: usize, index: usize) -> Self { 14 | Self { axis, index } 15 | } 16 | 17 | #[must_use] 18 | pub fn axis(&self) -> usize { 19 | self.axis 20 | } 21 | 22 | #[must_use] 23 | pub fn index(&self) -> usize { 24 | self.index 25 | } 26 | } 27 | 28 | impl IndexAxisTrait for Index { 29 | fn get_shape_stride<Din: DimTrait, Dout: DimTrait>( 30 | &self, 31 | shape: Din, 32 | stride: Din, 33 | ) -> ShapeStride<Dout> { 34 | let mut shape_v = Vec::new(); 35 | let mut stride_v = Vec::new(); 36 | for i in 0..shape.len() { 37 | if i == self.axis { 38 | continue; 39 | } 40 | shape_v.push(shape[i]); 41 | stride_v.push(stride[i]); 42 | } 43 | 44 | let new_shape = Dout::from(&shape_v as &[usize]); 45 | let new_stride = Dout::from(&stride_v as &[usize]); 46 | ShapeStride::new(new_shape, new_stride) 47 | } 48 | fn offset<Din: DimTrait>(&self, stride: Din) -> usize { 49 | stride[self.axis] * self.index 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /zenu-matrix/src/index/mod.rs: -------------------------------------------------------------------------------- 1 | #![expect(clippy::module_name_repetitions)] 2 | 3 | pub mod index_dyn_impl; 4 | pub mod index_impl; 5 | 6 | pub use index_impl::{Index0D, Index1D, Index2D, Index3D}; 7 | 8 | use crate::{dim::DimTrait, shape_stride::ShapeStride}; 9 | 10 | pub trait SliceTrait: Copy { 11 | type Dim: DimTrait; 12 | fn sliced_shape_stride(&self, shape: Self::Dim, stride: Self::Dim) -> ShapeStride<Self::Dim>; 13 | fn sliced_offset(&self, stride: Self::Dim) -> usize; 14 | } 15 | 16 | pub trait IndexAxisTrait: Copy { 17 | fn get_shape_stride<Din: DimTrait, Dout: DimTrait>( 18 | &self, 19 | shape: Din, 20 | stride: Din, 21 | ) -> ShapeStride<Dout>; 22 | fn offset<Din: DimTrait>(&self, stride: Din) -> usize; 23 | } 24 | -------------------------------------------------------------------------------- /zenu-matrix/src/lib.rs: -------------------------------------------------------------------------------- 1 | use device::cpu::Cpu; 2 | use memory_pool::MemPool; 3 | 4 | pub mod concat; 5 | pub mod constructor; 6 | pub mod device; 7 | pub mod dim; 8 | pub mod index; 9 | pub mod matrix; 10 | pub mod matrix_blas; 11 | pub mod matrix_iter; 12 | pub mod nn; 13 | pub mod num; 14 | pub mod operation; 15 | pub mod shape_stride; 16 | pub mod slice; 17 | 18 | mod impl_ops; 19 | mod impl_serde; 20 | mod matrix_format; 21 | mod memory_pool; 22 | mod with_clousers; 23 | 24 | #[cfg(feature = "nvidia")] 25 | use device::nvidia::Nvidia; 26 | 27 | pub(crate) struct ZenuMatrixState { 28 | pub(crate) is_mem_pool_used: bool, 29 | pub(crate) cpu: MemPool<Cpu>, 30 | #[cfg(feature =
"nvidia")] 31 | pub(crate) nvidia: MemPool, 32 | } 33 | 34 | impl Default for ZenuMatrixState { 35 | fn default() -> Self { 36 | let use_mem_pool = std::env::var("ZENU_USE_MEMPOOL").unwrap_or("1".to_string()) == "1"; 37 | ZenuMatrixState { 38 | is_mem_pool_used: use_mem_pool, 39 | cpu: MemPool::default(), 40 | #[cfg(feature = "nvidia")] 41 | nvidia: MemPool::default(), 42 | } 43 | } 44 | } 45 | 46 | pub(crate) static ZENU_MATRIX_STATE: once_cell::sync::Lazy = 47 | once_cell::sync::Lazy::new(ZenuMatrixState::default); 48 | -------------------------------------------------------------------------------- /zenu-matrix/src/matrix_blas/mod.rs: -------------------------------------------------------------------------------- 1 | #[derive(Copy, Clone, Debug, PartialEq, Eq)] 2 | pub enum BlasTrans { 3 | None, 4 | Ordinary, 5 | Conjugate, 6 | } 7 | 8 | #[derive(Copy, Clone, Debug, PartialEq, Eq)] 9 | pub enum BlasLayout { 10 | RowMajor, 11 | ColMajor, 12 | } 13 | -------------------------------------------------------------------------------- /zenu-matrix/src/memory_pool/data_ptr.rs: -------------------------------------------------------------------------------- 1 | use crate::device::DeviceBase; 2 | 3 | use super::MemPoolError; 4 | 5 | pub(super) struct DataPtr { 6 | pub ptr: *mut u8, 7 | pub bytes: usize, 8 | _marker: std::marker::PhantomData, 9 | } 10 | 11 | impl DataPtr { 12 | pub(super) fn new(bytes: usize) -> Result { 13 | let ptr = D::raw_alloc(bytes).map_err(|_| MemPoolError::DeviceMallocError)?; 14 | Ok(DataPtr { 15 | ptr, 16 | bytes, 17 | _marker: std::marker::PhantomData, 18 | }) 19 | } 20 | } 21 | 22 | impl Drop for DataPtr { 23 | fn drop(&mut self) { 24 | D::raw_drop_ptr(self.ptr); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /zenu-matrix/src/memory_pool/dynamic_buffer.rs: -------------------------------------------------------------------------------- 1 | use crate::device::DeviceBase; 2 | 3 | use super::{data_ptr::DataPtr, MemPoolError}; 4 | 5 | pub struct DynBuffer { 6 | data_ptr: DataPtr, 7 | } 8 | 9 | impl DynBuffer { 10 | pub fn new(bytes: usize) -> Result { 11 | Ok(DynBuffer { 12 | data_ptr: DataPtr::new(bytes)?, 13 | }) 14 | } 15 | 16 | pub fn start_ptr(&self) -> *mut u8 { 17 | self.data_ptr.ptr 18 | } 19 | 20 | pub fn bytes(&self) -> usize { 21 | self.data_ptr.bytes 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /zenu-matrix/src/memory_pool/dynamic_pool.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{BTreeMap, HashMap}, 3 | ops::Bound::{Included, Unbounded}, 4 | sync::{Arc, Mutex}, 5 | }; 6 | 7 | use crate::device::DeviceBase; 8 | 9 | use super::{dynamic_buffer::DynBuffer, MemPoolError}; 10 | 11 | #[derive(Default)] 12 | pub struct DynMemPool { 13 | used_buffers: HashMap<*mut u8, Arc>>>, 14 | unused_buffers: BTreeMap>>>>, 15 | } 16 | 17 | impl DynMemPool { 18 | pub fn try_alloc(&mut self, bytes: usize) -> Result<*mut u8, MemPoolError> { 19 | if let Some(smallest_unused_bytes_over_request) = 20 | self.smallest_unused_bytes_over_request(bytes) 21 | { 22 | let buffers = self 23 | .unused_buffers 24 | .get_mut(&smallest_unused_bytes_over_request) 25 | .unwrap(); 26 | let buffer = buffers.pop().unwrap(); 27 | let ptr = buffer.lock().unwrap().start_ptr(); 28 | self.used_buffers.insert(ptr, buffer); 29 | if buffers.is_empty() { 30 | self.unused_buffers 31 | .remove(&smallest_unused_bytes_over_request); 32 | } 
-------------------------------------------------------------------------------- /zenu-matrix/src/memory_pool/dynamic_pool.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{BTreeMap, HashMap}, 3 | ops::Bound::{Included, Unbounded}, 4 | sync::{Arc, Mutex}, 5 | }; 6 | 7 | use crate::device::DeviceBase; 8 | 9 | use super::{dynamic_buffer::DynBuffer, MemPoolError}; 10 | 11 | #[derive(Default)] 12 | pub struct DynMemPool<D: DeviceBase> { 13 | used_buffers: HashMap<*mut u8, Arc<Mutex<DynBuffer<D>>>>, 14 | unused_buffers: BTreeMap<usize, Vec<Arc<Mutex<DynBuffer<D>>>>>, 15 | } 16 | 17 | impl<D: DeviceBase> DynMemPool<D> { 18 | pub fn try_alloc(&mut self, bytes: usize) -> Result<*mut u8, MemPoolError> { 19 | if let Some(smallest_unused_bytes_over_request) = 20 | self.smallest_unused_bytes_over_request(bytes) 21 | { 22 | let buffers = self 23 | .unused_buffers 24 | .get_mut(&smallest_unused_bytes_over_request) 25 | .unwrap(); 26 | let buffer = buffers.pop().unwrap(); 27 | let ptr = buffer.lock().unwrap().start_ptr(); 28 | self.used_buffers.insert(ptr, buffer); 29 | if buffers.is_empty() { 30 | self.unused_buffers 31 | .remove(&smallest_unused_bytes_over_request); 32 | } 33 | Ok(ptr) 34 | } else { 35 | let buffer = Arc::new(Mutex::new(DynBuffer::new(bytes)?)); 36 | let ptr = buffer.lock().unwrap().start_ptr(); 37 | self.used_buffers.insert(ptr, buffer); 38 | Ok(ptr) 39 | } 40 | } 41 | 42 | pub fn try_free(&mut self, ptr: *mut u8) -> Result<(), MemPoolError> { 43 | let buffer = self 44 | .used_buffers 45 | .remove(&ptr) 46 | .ok_or(MemPoolError::DynMemPoolFreeError)?; 47 | let bytes = buffer.lock().unwrap().bytes(); 48 | self.unused_buffers.entry(bytes).or_default().push(buffer); 49 | Ok(()) 50 | } 51 | 52 | pub fn smallest_unused_bytes_over_request(&self, bytes: usize) -> Option<usize> { 53 | self.unused_buffers 54 | .range((Included(&bytes), Unbounded)) 55 | .next() 56 | .map(|(unused_bytes, _)| *unused_bytes) 57 | } 58 | 59 | pub fn contains(&self, ptr: *mut u8) -> bool { 60 | self.used_buffers.contains_key(&ptr) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /zenu-matrix/src/nn/col2im.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::Device, 3 | dim::DimDyn, 4 | matrix::{Matrix, Owned, Ref}, 5 | num::Num, 6 | slice_dynamic, 7 | }; 8 | 9 | #[expect(clippy::needless_pass_by_value)] 10 | pub(super) fn col2im<T: Num, D: Device>( 11 | col: Matrix<Ref<&T>, DimDyn, D>, 12 | img_shape: [usize; 4], 13 | kernel_size: (usize, usize), 14 | stride: (usize, usize), 15 | pad: (usize, usize), 16 | ) -> Matrix<Owned<T>, DimDyn, D> { 17 | let (batch_size, c, h, w) = (img_shape[0], img_shape[1], img_shape[2], img_shape[3]); 18 | let (kh, kw) = kernel_size; 19 | let (sh, sw) = stride; 20 | let (ph, pw) = pad; 21 | let (oh, ow) = ((h + 2 * ph - kh) / sh + 1, (w + 2 * pw - kw) / sw + 1); 22 | 23 | let mut img = 24 | Matrix::<_, DimDyn, _>::zeros([batch_size, c, h + 2 * ph + sh - 1, w + 2 * pw + sw - 1]); 25 | 26 | for j in 0..kh { 27 | let j_lim = j + sh * oh; 28 | for i in 0..kw { 29 | let i_lim = i + sw * ow; 30 | let col_ref = col.to_ref(); 31 | let col_ref = col_ref.slice_dyn(slice_dynamic!(.., .., j, i, .., ..)); 32 | 33 | let mut img_slice = img.to_ref_mut().slice_mut_dyn(slice_dynamic!( 34 | .., 35 | .., 36 | j..j_lim;sh, 37 | i..i_lim;sw 38 | )); 39 | img_slice += col_ref; 40 | } 41 | } 42 | 43 | let img = img.slice_dyn(slice_dynamic!(.., .., ph..ph + h, pw..pw + w)); 44 | img.new_matrix() 45 | } 46 | 47 | #[cfg(test)] 48 | mod col2im { 49 | use crate::{ 50 | device::cpu::Cpu, 51 | dim::DimDyn, 52 | matrix::{Matrix, Owned}, 53 | }; 54 | 55 | use super::col2im; 56 | 57 | #[expect(clippy::cast_precision_loss)] 58 | #[test] 59 | fn col2im_small() { 60 | let col = (1..=1350).map(|x| x as f32).collect::<Vec<f32>>(); 61 | let col = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(col, [2, 3, 3, 3, 5, 5]); 62 | let img_shape = [2, 3, 5, 5]; 63 | let kernel_shape = (3, 3); 64 | let stride = (1, 1); 65 | let pad = (1, 1); 66 | let img = col2im(col.to_ref(), img_shape, kernel_shape, stride, pad); 67 | let ans = vec![ 68 | 216, 402, 408, 414, 328, 564, 963, 972, 981, 732, 594, 1008, 1017, 1026, 762, 624, 69 | 1053, 1062, 1071, 792, 576, 942, 948, 954, 688, 1116, 1752, 1758, 1764, 1228, 1914, 70 | 2988, 2997, 3006, 2082, 1944, 3033, 3042, 3051, 2112, 1974, 3078, 3087, 3096, 2142, 71 | 1476, 2292, 2298, 2304, 1588, 2016, 3102, 3108, 3114, 2128, 3264, 5013, 5022, 5031, 72 | 3432, 3294, 5058, 5067, 5076, 3462, 3324, 5103, 5112, 5121, 3492, 2376, 3642, 3648, 73 | 3654, 2488, 2916, 4452, 4458, 4464, 3028, 4614, 7038, 7047, 7056, 4782, 4644, 7083, 74 | 7092, 7101, 4812, 4674, 7128, 7137, 7146, 4842, 3276, 4992, 4998, 5004, 3388, 3816,
75 | 5802, 5808, 5814, 3928, 5964, 9063, 9072, 9081, 6132, 5994, 9108, 9117, 9126, 6162, 76 | 6024, 9153, 9162, 9171, 6192, 4176, 6342, 6348, 6354, 4288, 4716, 7152, 7158, 7164, 77 | 4828, 7314, 11088, 11097, 11106, 7482, 7344, 11133, 11142, 11151, 7512, 7374, 11178, 78 | 11187, 11196, 7542, 5076, 7692, 7698, 7704, 5188, 79 | ] 80 | .iter() 81 | .map(|&x| x as f32) 82 | .collect::<Vec<f32>>(); 83 | let ans = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(ans, [2, 3, 5, 5]); 84 | assert!((img - ans).asum() < 1e-6); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /zenu-matrix/src/nn/conv/cpu/bias.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::cpu::Cpu, 3 | dim::DimDyn, 4 | matrix::{Matrix, Ref}, 5 | nn::conv::interface::ConvBias, 6 | num::Num, 7 | }; 8 | 9 | impl ConvBias for Cpu { 10 | fn conv2d_bias<T: Num>( 11 | input: Matrix<Ref<&T>, DimDyn, Self>, 12 | bias: Matrix<Ref<&T>, DimDyn, Self>, 13 | mut output: Matrix<Ref<&mut T>, DimDyn, Self>, 14 | ) { 15 | output.add_array(&input, &bias); 16 | } 17 | 18 | fn conv2d_bias_bkwd<T: Num>( 19 | d_output: Matrix<Ref<&T>, DimDyn, Self>, 20 | d_bias: Matrix<Ref<&mut T>, DimDyn, Self>, 21 | ) { 22 | let dy_0 = d_output.sum(0, true); 23 | let dy_0_2 = dy_0.to_ref().sum(2, true); 24 | d_bias.copy_from(&dy_0_2.to_ref().sum(3, true)); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /zenu-matrix/src/nn/conv/cpu/col2im.rs: -------------------------------------------------------------------------------- 1 | use crate::num::Num; 2 | 3 | use super::super::utils::conv_dim_out_size; 4 | 5 | #[expect( 6 | clippy::too_many_arguments, 7 | clippy::similar_names, 8 | clippy::cast_possible_wrap, 9 | clippy::cast_sign_loss 10 | )] 11 | pub fn col2im<T: Num>( 12 | col: &[T], 13 | n: usize, 14 | c_in: usize, 15 | h_in: usize, 16 | w_in: usize, 17 | kh: usize, 18 | kw: usize, 19 | pad_h: usize, 20 | pad_w: usize, 21 | stride_h: usize, 22 | stride_w: usize, 23 | dilation_h: usize, 24 | dilation_w: usize, 25 | dx: &mut [T], 26 | ) { 27 | let out_h = conv_dim_out_size(h_in, kh, pad_h, stride_h, dilation_h); 28 | let out_w = conv_dim_out_size(w_in, kw, pad_w, stride_w, dilation_w); 29 | 30 | let rows = c_in * kh * kw; 31 | let cols = n * out_h * out_w; 32 | assert_eq!(col.len(), rows * cols); 33 | 34 | for v in dx.iter_mut() { 35 | *v = T::zero(); 36 | } 37 | 38 | for c in 0..c_in { 39 | for k_h in 0..kh { 40 | let ih_base = k_h * dilation_h; 41 | for k_w in 0..kw { 42 | let iw_base = k_w * dilation_w; 43 | let row = c * kh * kw + k_h * kw + k_w; 44 | for ni in 0..n { 45 | for oh in 0..out_h { 46 | let ih_ = (oh * stride_h) as isize + ih_base as isize - pad_h as isize; 47 | if ih_ < 0 || ih_ >= h_in as isize { 48 | continue; 49 | } 50 | let ih = ih_ as usize; 51 | 52 | for ow in 0..out_w { 53 | let iw_ = (ow * stride_w) as isize + iw_base as isize - pad_w as isize; 54 | if iw_ < 0 || iw_ >= w_in as isize { 55 | continue; 56 | } 57 | let iw = iw_ as usize; 58 | 59 | let col_idx = ni * (out_h * out_w) + oh * out_w + ow; 60 | dx[ni * (c_in * h_in * w_in) + c * (h_in * w_in) + ih * w_in + iw] += 61 | col[row * cols + col_idx]; 62 | } 63 | } 64 | } 65 | } 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /zenu-matrix/src/nn/conv/cpu/conv_bkwd_data.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | device::cpu::Cpu, matrix_blas::BlasTrans, nn::conv::utils::conv_dim_out_size, num::Num, 3 |
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/conv/cpu/conv_bkwd_data.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::cpu::Cpu, matrix_blas::BlasTrans, nn::conv::utils::conv_dim_out_size, num::Num,
3 |     operation::mul::Gemm,
4 | };
5 | 
6 | use super::col2im::col2im;
7 | 
8 | #[expect(clippy::too_many_arguments)]
9 | pub fn conv_bkwd_data<T: Num>(
10 |     dy: &[T],
11 |     filter: &[T],
12 |     n: usize,
13 |     c_in: usize,
14 |     c_out: usize,
15 |     h_in: usize,
16 |     w_in: usize,
17 |     kh: usize,
18 |     kw: usize,
19 |     pad_h: usize,
20 |     pad_w: usize,
21 |     stride_h: usize,
22 |     stride_w: usize,
23 |     dilation_h: usize,
24 |     dilation_w: usize,
25 |     dx: &mut [T],
26 | ) {
27 |     let out_h = conv_dim_out_size(h_in, kh, pad_h, stride_h, dilation_h);
28 |     let out_w = conv_dim_out_size(w_in, kw, pad_w, stride_w, dilation_w);
29 | 
30 |     let nn = n * out_h * out_w;
31 |     let mut dy_mat = vec![T::zero(); c_out * nn];
32 |     for ni in 0..n {
33 |         for co in 0..c_out {
34 |             for oh in 0..out_h {
35 |                 for ow in 0..out_w {
36 |                     let dy_idx = ni * c_out * out_h * out_w + co * out_h * out_w + oh * out_w + ow;
37 |                     let mat_idx = co * nn + ni * (out_h * out_w) + oh * out_w + ow;
38 |                     dy_mat[mat_idx] = dy[dy_idx];
39 |                 }
40 |             }
41 |         }
42 |     }
43 | 
44 |     let m = c_in * kh * kw;
45 |     let k = c_out;
46 |     let ncol = n * out_h * out_w;
47 |     let mut col = vec![T::zero(); m * ncol];
48 | 
49 |     Cpu::gemm_unchecked(
50 |         BlasTrans::Ordinary,
51 |         BlasTrans::None,
52 |         m,
53 |         ncol,
54 |         k,
55 |         T::one(),
56 |         filter.as_ptr(),
57 |         c_in * kh * kw,
58 |         dy_mat.as_ptr(),
59 |         ncol,
60 |         T::zero(),
61 |         col.as_mut_ptr(),
62 |         ncol,
63 |     );
64 | 
65 |     col2im(
66 |         &col, n, c_in, h_in, w_in, kh, kw, pad_h, pad_w, stride_h, stride_w, dilation_h,
67 |         dilation_w, dx,
68 |     );
69 | }
70 | 
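The backward-data pass above is one GEMM, `col = filterᵀ · dy_mat`, followed by `col2im`. The dimension bookkeeping is the error-prone part, so here is a standalone check of the shapes involved (illustrative sizes; `out_size` repeats the `conv_dim_out_size` formula):

```rust
/// Same formula as `conv_dim_out_size`.
fn out_size(i: usize, k: usize, pad: usize, stride: usize, dil: usize) -> usize {
    (i + 2 * pad - dil * (k - 1) - 1) / stride + 1
}

fn main() {
    let (n, c_in, c_out, h, w, kh, kw) = (2, 3, 8, 5, 5, 3, 3);
    let (oh, ow) = (out_size(h, kh, 1, 1, 1), out_size(w, kw, 1, 1, 1));

    // filter is stored as (c_out, c_in*kh*kw); transposed it is (c_in*kh*kw, c_out).
    let filter_t = (c_in * kh * kw, c_out);
    // dy is permuted from NCHW into (c_out, n*oh*ow).
    let dy_mat = (c_out, n * oh * ow);

    // Inner dimensions of the GEMM agree, and the product has exactly the
    // (c_in*kh*kw, n*oh*ow) shape that col2im expects.
    assert_eq!(filter_t.1, dy_mat.0);
    assert_eq!((filter_t.0, dy_mat.1), (27, 50));
}
```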
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/conv/cpu/conv_bkwd_filter.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::cpu::Cpu, matrix_blas::BlasTrans, nn::conv::utils::conv_dim_out_size, num::Num,
3 |     operation::mul::Gemm,
4 | };
5 | 
6 | use super::im2col::im2col;
7 | 
8 | #[expect(clippy::many_single_char_names, clippy::too_many_arguments)]
9 | pub fn conv_bkwd_filter<T: Num>(
10 |     dy: &[T],
11 |     x: &[T],
12 |     n: usize,
13 |     c_in: usize,
14 |     c_out: usize,
15 |     h_in: usize,
16 |     w_in: usize,
17 |     kh: usize,
18 |     kw: usize,
19 |     pad_h: usize,
20 |     pad_w: usize,
21 |     stride_h: usize,
22 |     stride_w: usize,
23 |     dilation_h: usize,
24 |     dilation_w: usize,
25 |     dw: &mut [T],
26 | ) {
27 |     let out_h = conv_dim_out_size(h_in, kh, pad_h, stride_h, dilation_h);
28 |     let out_w = conv_dim_out_size(w_in, kw, pad_w, stride_w, dilation_w);
29 | 
30 |     let col_size = c_in * kh * kw * n * out_h * out_w;
31 |     let mut col = vec![T::zero(); col_size];
32 |     im2col(
33 |         x, n, c_in, h_in, w_in, kh, kw, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
34 |         &mut col,
35 |     );
36 | 
37 |     let mut dy_mat = vec![T::zero(); c_out * n * out_h * out_w];
38 |     for ni in 0..n {
39 |         for co in 0..c_out {
40 |             for oh in 0..out_h {
41 |                 for ow in 0..out_w {
42 |                     let dy_idx = ni * c_out * out_h * out_w + co * out_h * out_w + oh * out_w + ow;
43 |                     let mat_idx = co * (n * out_h * out_w) + ni * (out_h * out_w) + oh * out_w + ow;
44 |                     dy_mat[mat_idx] = dy[dy_idx];
45 |                 }
46 |             }
47 |         }
48 |     }
49 | 
50 |     // dw = dy_mat * col^T
51 |     // Passing transb = BlasTrans::Ordinary tells the GEMM to transpose internally, so no explicit transpose loop is needed.
52 |     let m = c_out;
53 |     let k = n * out_h * out_w;
54 |     let p = c_in * kh * kw;
55 | 
56 |     let mut dw_temp = vec![T::zero(); c_out * c_in * kh * kw];
57 | 
58 |     Cpu::gemm_unchecked(
59 |         BlasTrans::None,     // dy_mat is used as-is
60 |         BlasTrans::Ordinary, // col is treated as transposed
61 |         m,
62 |         p,
63 |         k,
64 |         T::one(),
65 |         dy_mat.as_ptr(),
66 |         k,
67 |         col.as_ptr(),
68 |         k, // col is (C_in*KH*KW, N*out_h*out_w) => transposed it is (N*out_h*out_w, C_in*KH*KW)
69 |         T::zero(),
70 |         dw_temp.as_mut_ptr(),
71 |         p,
72 |     );
73 | 
74 |     dw.copy_from_slice(&dw_temp);
75 | }
76 | 
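As the comment in `conv_bkwd_filter` notes, the transpose is delegated to the GEMM's `transb` flag instead of materializing `colᵀ`. A standalone, row-major illustration that reading `b[j * k + l]` is all the "transpose" amounts to (plain Rust, no BLAS):

```rust
/// C (m x p) = A (m x k) . B^T, where B is stored row-major as (p x k).
/// Indexing b[j * k + l] plays the role of the transpose flag: no
/// transposed copy of B is ever built.
fn matmul_abt(a: &[f64], b: &[f64], m: usize, k: usize, p: usize) -> Vec<f64> {
    let mut c = vec![0.0; m * p];
    for i in 0..m {
        for j in 0..p {
            for l in 0..k {
                c[i * p + j] += a[i * k + l] * b[j * k + l];
            }
        }
    }
    c
}

fn main() {
    // A = [1 2; 3 4], B = [5 6; 7 8]  =>  A . B^T = [17 23; 39 53]
    let c = matmul_abt(&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0], 2, 2, 2);
    assert_eq!(c, vec![17.0, 23.0, 39.0, 53.0]);
}
```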
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/conv/cpu/conv_fwd.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::cpu::Cpu, matrix_blas::BlasTrans, nn::conv::utils::conv_dim_out_size, num::Num,
3 |     operation::mul::Gemm,
4 | };
5 | 
6 | use super::im2col::im2col;
7 | 
8 | #[expect(clippy::too_many_arguments)]
9 | pub fn conv_fwd<T: Num>(
10 |     input: &[T],
11 |     filter: &[T],
12 |     n: usize,
13 |     c_in: usize,
14 |     h_in: usize,
15 |     w_in: usize,
16 |     c_out: usize,
17 |     kh: usize,
18 |     kw: usize,
19 |     pad_h: usize,
20 |     pad_w: usize,
21 |     stride_h: usize,
22 |     stride_w: usize,
23 |     dilation_h: usize,
24 |     dilation_w: usize,
25 |     output: &mut [T],
26 | ) {
27 |     let out_h = conv_dim_out_size(h_in, kh, pad_h, stride_h, dilation_h);
28 |     let out_w = conv_dim_out_size(w_in, kw, pad_w, stride_w, dilation_w);
29 | 
30 |     let col_size = c_in * kh * kw * n * out_h * out_w;
31 |     let mut col = vec![T::zero(); col_size];
32 |     im2col(
33 |         input, n, c_in, h_in, w_in, kh, kw, pad_h, pad_w, stride_h, stride_w, dilation_h,
34 |         dilation_w, &mut col,
35 |     );
36 | 
37 |     // out_mat = filter * col
38 |     let mut out_mat = vec![T::zero(); c_out * n * out_h * out_w];
39 | 
40 |     Cpu::gemm_unchecked(
41 |         BlasTrans::None,
42 |         BlasTrans::None,
43 |         c_out,
44 |         n * out_h * out_w,
45 |         c_in * kh * kw,
46 |         T::one(),
47 |         filter.as_ptr(),
48 |         c_in * kh * kw,
49 |         col.as_ptr(),
50 |         n * out_h * out_w,
51 |         T::zero(),
52 |         out_mat.as_mut_ptr(),
53 |         n * out_h * out_w,
54 |     );
55 | 
56 |     // reshape out_mat to NCHW
57 |     for ni in 0..n {
58 |         for co in 0..c_out {
59 |             for oh in 0..out_h {
60 |                 for ow in 0..out_w {
61 |                     let out_idx = ni * c_out * out_h * out_w + co * out_h * out_w + oh * out_w + ow;
62 |                     let mat_idx = co * (n * out_h * out_w) + ni * (out_h * out_w) + oh * out_w + ow;
63 |                     output[out_idx] = out_mat[mat_idx];
64 |                 }
65 |             }
66 |         }
67 |     }
68 | }
69 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/conv/cpu/im2col.rs:
--------------------------------------------------------------------------------
1 | use crate::num::Num;
2 | 
3 | use super::super::utils::conv_dim_out_size;
4 | 
5 | #[expect(
6 |     clippy::too_many_arguments,
7 |     clippy::similar_names,
8 |     clippy::cast_possible_wrap,
9 |     clippy::cast_sign_loss
10 | )]
11 | pub fn im2col<T: Num>(
12 |     input: &[T],
13 |     n: usize,
14 |     c_in: usize,
15 |     h_in: usize,
16 |     w_in: usize,
17 |     kh: usize,
18 |     kw: usize,
19 |     pad_h: usize,
20 |     pad_w: usize,
21 |     stride_h: usize,
22 |     stride_w: usize,
23 |     dilation_h: usize,
24 |     dilation_w: usize,
25 |     col: &mut [T],
26 | ) {
27 |     let out_h = conv_dim_out_size(h_in, kh, pad_h, stride_h, dilation_h);
28 |     let out_w = conv_dim_out_size(w_in, kw, pad_w, stride_w, dilation_w);
29 | 
30 |     let n_out = n * out_h * out_w;
31 |     let c_kk = c_in * kh * kw;
32 |     assert_eq!(col.len(), c_kk * n_out);
33 | 
34 |     for v in col.iter_mut() {
35 |         *v = T::zero();
36 |     }
37 | 
38 |     for ni in 0..n {
39 |         for c in 0..c_in {
40 |             for k_h in 0..kh {
41 |                 let ih_base = k_h * dilation_h;
42 |                 for k_w in 0..kw {
43 |                     let iw_base = k_w * dilation_w;
44 |                     let row = c * kh * kw + k_h * kw + k_w;
45 |                     for oh in 0..out_h {
46 |                         let ih_ = (oh * stride_h) as isize + (ih_base as isize) - (pad_h as isize);
47 |                         if ih_ < 0 || ih_ >= h_in as isize {
48 |                             // ih is outside the image, so skip it
49 |                             continue;
50 |                         }
51 |                         let ih = ih_ as usize;
52 | 
53 |                         for ow in 0..out_w {
54 |                             let iw_ =
55 |                                 (ow * stride_w) as isize + (iw_base as isize) - (pad_w as isize);
56 |                             if iw_ < 0 || iw_ >= w_in as isize {
57 |                                 // iw is outside the image, so skip it
58 |                                 continue;
59 |                             }
60 |                             let iw = iw_ as usize;
61 | 
62 |                             let col_idx = ni * (out_h * out_w) + oh * out_w + ow;
63 |                             col[row * n_out + col_idx] = input
64 |                                 [ni * (c_in * h_in * w_in) + c * (h_in * w_in) + ih * w_in + iw];
65 |                         }
66 |                     }
67 |                 }
68 |             }
69 |         }
70 |     }
71 | }
72 | 
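To make the `col` layout concrete: with `n = 1`, `c_in = 1`, a 3×3 input, a 2×2 kernel, stride 1 and no padding, im2col yields a 4×4 matrix whose rows are kernel taps and whose columns are output positions. A miniature standalone version that builds exactly that layout (not the crate's function):

```rust
/// Miniature im2col for n = 1, c = 1, no padding, no dilation.
fn im2col_1ch(x: &[f32], h: usize, w: usize, k: usize, stride: usize) -> Vec<Vec<f32>> {
    let (oh, ow) = ((h - k) / stride + 1, (w - k) / stride + 1);
    let mut col = vec![vec![0.0; oh * ow]; k * k];
    for kh in 0..k {
        for kw in 0..k {
            for o_h in 0..oh {
                for o_w in 0..ow {
                    // row = kernel tap, column = output position
                    col[kh * k + kw][o_h * ow + o_w] =
                        x[(o_h * stride + kh) * w + (o_w * stride + kw)];
                }
            }
        }
    }
    col
}

fn main() {
    // 3x3 input holding 1..9, 2x2 kernel, stride 1 -> 4 taps x 4 windows.
    let x: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let col = im2col_1ch(&x, 3, 3, 2, 1);
    assert_eq!(col[0], vec![1.0, 2.0, 4.0, 5.0]); // top-left tap of each window
    assert_eq!(col[3], vec![5.0, 6.0, 8.0, 9.0]); // bottom-right tap of each window
}
```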
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/conv/shape_check.rs:
--------------------------------------------------------------------------------
1 | use crate::dim::{DimDyn, DimTrait};
2 | 
3 | #[allow(clippy::module_name_repetitions)]
4 | pub fn shape_check_2d(
5 |     input_shape: DimDyn,
6 |     filter_shape: DimDyn,
7 |     output_shape: DimDyn,
8 |     stride: &[usize],
9 |     padding: &[usize],
10 |     dilation: &[usize],
11 | ) {
12 |     assert_eq!(input_shape.len(), 4, "Input shape must have 4 dimensions.");
13 |     assert_eq!(
14 |         filter_shape.len(),
15 |         4,
16 |         "Filter shape must have 4 dimensions."
17 |     );
18 |     assert_eq!(stride.len(), 2, "Stride must have 2 dimensions.");
19 |     assert_eq!(padding.len(), 2, "Padding must have 2 dimensions.");
20 |     assert_eq!(dilation.len(), 2, "Dilation must have 2 dimensions.");
21 |     assert_eq!(
22 |         dilation.len(),
23 |         stride.len(),
24 |         "Dilation length must match the number of spatial dimensions."
25 |     );
26 | 
27 |     let h_out_expected =
28 |         (input_shape[2] + 2 * padding[0] - dilation[0] * (filter_shape[2] - 1) - 1) / stride[0] + 1;
29 |     let w_out_expected =
30 |         (input_shape[3] + 2 * padding[1] - dilation[1] * (filter_shape[3] - 1) - 1) / stride[1] + 1;
31 | 
32 |     assert_eq!(
33 |         h_out_expected, output_shape[2],
34 |         "Output height mismatch: expected {}, got {}",
35 |         h_out_expected, output_shape[2]
36 |     );
37 |     assert_eq!(
38 |         w_out_expected, output_shape[3],
39 |         "Output width mismatch: expected {}, got {}",
40 |         w_out_expected, output_shape[3]
41 |     );
42 | 
43 |     // the input and filter channel counts must match
44 |     assert_eq!(
45 |         input_shape[1], filter_shape[1],
46 |         "Input and filter channel count must match."
47 |     );
48 |     // the output channel count must match the filter count
49 |     assert_eq!(
50 |         filter_shape[0], output_shape[1],
51 |         "Filter and output channel count must match."
52 |     );
53 | }
54 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/conv/utils.rs:
--------------------------------------------------------------------------------
1 | use crate::dim::{DimDyn, DimTrait};
2 | 
3 | pub(super) fn conv_output_shape(
4 |     input: DimDyn,
5 |     filter: DimDyn,
6 |     stride: &[usize],
7 |     padding: &[usize],
8 |     dilation: &[usize],
9 | ) -> DimDyn {
10 |     // Check the number of input dimensions.
11 |     // conv1d: input = [N, C_in, L], filter = [C_out, C_in, K]
12 |     // conv2d: input = [N, C_in, H, W], filter = [C_out, C_in, K_h, K_w]
13 |     // Higher ranks are conceivable, but only 1D/2D are assumed here.
14 |     assert!(
15 |         input.len() == filter.len(),
16 |         "Input and filter must have the same number of dimensions."
17 |     );
18 |     assert!(
19 |         input.len() >= 3 && input.len() <= 4,
20 |         "This function currently supports conv1d or conv2d only."
21 |     );
22 |     assert!(
23 |         stride.len() == input.len() - 2,
24 |         "Stride length must match the number of spatial dimensions."
25 |     );
26 |     assert!(
27 |         padding.len() == input.len() - 2,
28 |         "Padding length must match the number of spatial dimensions."
29 |     );
30 |     assert!(
31 |         dilation.len() == input.len() - 2,
32 |         "Dilation length must match the number of spatial dimensions."
33 |     );
34 | 
35 |     let mut output = DimDyn::default();
36 |     output.push_dim(input[0]);
37 |     output.push_dim(filter[0]);
38 | 
39 |     for i in 2..input.len() {
40 |         let in_size = input[i];
41 |         let kernel_size = filter[i];
42 |         let (strd, pad, dil) = (stride[i - 2], padding[i - 2], dilation[i - 2]);
43 |         let out_size = conv_dim_out_size(in_size, kernel_size, pad, strd, dil);
44 |         output.push_dim(out_size);
45 |     }
46 | 
47 |     output
48 | }
49 | 
50 | /// Helper that computes the output size along a single spatial dimension.
51 | /// `out_size` = ((`in_size` + 2*`pad` - `dil`*(`kernel_size`-1) - 1) / `stride`) + 1
52 | pub(super) fn conv_dim_out_size(
53 |     in_size: usize,
54 |     kernel_size: usize,
55 |     pad: usize,
56 |     stride: usize,
57 |     dilation: usize,
58 | ) -> usize {
59 |     ((in_size + 2 * pad - dilation * (kernel_size - 1) - 1) / stride) + 1
60 | }
61 | 
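A few worked values of the `conv_dim_out_size` formula, as a standalone sanity check:

```rust
fn conv_dim_out_size(i: usize, k: usize, pad: usize, stride: usize, dil: usize) -> usize {
    (i + 2 * pad - dil * (k - 1) - 1) / stride + 1
}

fn main() {
    assert_eq!(conv_dim_out_size(5, 3, 1, 1, 1), 5); // "same" convolution
    assert_eq!(conv_dim_out_size(5, 3, 0, 1, 1), 3); // "valid" convolution
    assert_eq!(conv_dim_out_size(32, 3, 1, 2, 1), 16); // strided downsampling
    assert_eq!(conv_dim_out_size(5, 3, 2, 1, 2), 5); // dilation 2 widens the 3-tap kernel to 5
}
```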
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/mod.rs:
--------------------------------------------------------------------------------
1 | use crate::device::DeviceBase;
2 | 
3 | pub mod batch_norm;
4 | pub mod col2im;
5 | pub mod conv;
6 | pub mod dropout;
7 | pub mod im2col;
8 | pub mod pool2d;
9 | 
10 | #[cfg(feature = "nvidia")]
11 | pub mod rnn;
12 | 
13 | /// A struct that puts blobs whose memory is not managed by `Matrix` under Rust's memory management.
14 | /// Libraries such as `cudnn` require a workspace allocation when computing.
15 | #[expect(unused)]
16 | pub(crate) struct NNCache<D: DeviceBase> {
17 |     pub(crate) bytes: usize,
18 |     pub(crate) ptr: *mut u8,
19 |     _device: std::marker::PhantomData<D>,
20 | }
21 | 
22 | impl<D: DeviceBase> NNCache<D> {
23 |     #[allow(unused)]
24 |     pub(crate) fn new(bytes: usize) -> Self {
25 |         let ptr = D::alloc(bytes).unwrap();
26 |         Self {
27 |             bytes,
28 |             ptr,
29 |             _device: std::marker::PhantomData,
30 |         }
31 |     }
32 | }
33 | 
34 | impl<D: DeviceBase> Drop for NNCache<D> {
35 |     fn drop(&mut self) {
36 |         D::drop_ptr(self.ptr);
37 |     }
38 | }
39 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/rnn/gru_params.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::nvidia::Nvidia,
3 |     dim::DimDyn,
4 |     matrix::{Matrix, Owned},
5 |     num::Num,
6 | };
7 | 
8 | #[derive(Debug, Clone)]
9 | pub struct GRUOutput<T: Num> {
10 |     pub y: Matrix<Owned<T>, DimDyn, Nvidia>,
11 |     pub hy: Matrix<Owned<T>, DimDyn, Nvidia>,
12 | }
13 | 
14 | #[derive(Debug, Clone)]
15 | pub struct GRUGrad<T: Num> {
16 |     pub dx: Matrix<Owned<T>, DimDyn, Nvidia>,
17 |     pub dhx: Matrix<Owned<T>, DimDyn, Nvidia>,
18 | }
19 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/rnn/lstm_params.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::nvidia::Nvidia,
3 |     dim::DimDyn,
4 |     matrix::{Matrix, Owned},
5 |     num::Num,
6 | };
7 | 
8 | #[derive(Debug, Clone)]
9 | pub struct LSTMOutput<T: Num> {
10 |     pub y: Matrix<Owned<T>, DimDyn, Nvidia>,
11 |     pub hy: Matrix<Owned<T>, DimDyn, Nvidia>,
12 |     pub cy: Matrix<Owned<T>, DimDyn, Nvidia>,
13 | }
14 | 
15 | #[derive(Debug, Clone)]
16 | pub struct LSTMGrad<T: Num> {
17 |     pub dx: Matrix<Owned<T>, DimDyn, Nvidia>,
18 |     pub dhx: Matrix<Owned<T>, DimDyn, Nvidia>,
19 |     pub dcx: Matrix<Owned<T>, DimDyn, Nvidia>,
20 | }
21 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/nn/rnn/mod.rs:
--------------------------------------------------------------------------------
1 | pub(super) mod descriptor;
2 | mod gru;
3 | mod gru_params;
4 | mod lstm;
5 | mod lstm_params;
6 | mod rnn;
7 | mod rnn_params;
8 | 
9 | pub use descriptor::RNNDescriptor;
10 | pub use rnn_params::{RNNBkwdDataOutput, RNNOutput, RNNWeights};
11 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/num.rs:
--------------------------------------------------------------------------------
1 | #![expect(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
2 | 
3 | use std::{
4 |     fmt::{Debug, Display},
5 |     ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
6 | };
7 | 
8 | use num_traits::Float;
9 | 
10 | use rand_distr::uniform::SampleUniform;
11 | use serde::Serialize;
12 | 
13 | pub trait Num:
14 |     Default
15 |     + Clone
16 |     + Copy
17 |     + Debug
18 |     + Display
19 |     + Add
20 |     + PartialOrd
21 |     + Mul
22 |     + Div
23 |     + Sub
24 |     + SubAssign
25 |     + DivAssign
26 |     + AddAssign
27 |     + MulAssign
28 |     + Float
29 |     + SampleUniform
30 |     + Serialize
31 |     + 'static
32 | {
33 |     fn is_f32() -> bool;
34 |     fn minus_one() -> Self;
35 |     fn from_usize(n: usize) -> Self;
36 |     #[must_use]
37 |     fn size() -> usize {
38 |         std::mem::size_of::<Self>()
39 |     }
40 |     fn from_f32(f: f32) -> Self;
41 |     fn from_f64(f: f64) -> Self;
42 | }
43 | 
44 | impl Num for f32 {
45 |     fn is_f32() -> bool {
46 |         true
47 |     }
48 | 
49 |     fn minus_one() -> f32 {
50 |         -1.0
51 |     }
52 | 
53 |     fn from_usize(n: usize) -> f32 {
54 |         n as f32
55 |     }
56 | 
57 |     fn from_f32(f: f32) -> Self {
58 |         f
59 |     }
60 | 
61 |     fn from_f64(f: f64) -> Self {
62 |         f as f32
63 |     }
64 | }
65 | 
66 | impl Num for f64 {
67 |     fn is_f32() -> bool {
68 |         false
69 |     }
70 | 
71 |     fn minus_one() -> f64 {
72 |         -1.0
73 |     }
74 | 
75 |     fn from_usize(n: usize) -> Self {
76 |         n as f64
77 |     }
78 | 
79 |     fn from_f32(f: f32) -> Self {
80 |         f64::from(f)
81 |     }
82 | 
83 |     fn from_f64(f: f64) -> Self {
84 |         f
85 |     }
86 | }
87 | 
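`Num::from_usize` exists because generic numeric code constantly needs to turn a count into the element type, as `operation/mean.rs` below does when dividing a sum by the number of elements. A standalone sketch of the same pattern written directly against `num_traits` (using `NumCast::from` where the crate uses `Num::from_usize`):

```rust
use num_traits::Float;

/// Mean over a slice in generic float code: the usize length has to be
/// converted into T before dividing -- the job `Num::from_usize` does.
fn mean<T: Float>(xs: &[T]) -> T {
    let sum = xs.iter().fold(T::zero(), |acc, &x| acc + x);
    sum / T::from(xs.len()).unwrap()
}

fn main() {
    assert!((mean(&[1.0f32, 2.0, 3.0]) - 2.0).abs() < 1e-6);
    assert!((mean(&[1.0f64, 2.0, 4.0]) - 7.0 / 3.0).abs() < 1e-12);
}
```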
--------------------------------------------------------------------------------
/zenu-matrix/src/operation/add_axis.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::Device,
3 |     dim::DimDyn,
4 |     matrix::{Matrix, Repr},
5 | };
6 | 
7 | impl<R: Repr, D: Device> Matrix<R, DimDyn, D> {
8 |     pub fn add_axis(&mut self, axis: usize) {
9 |         let shape_stride = self.shape_stride();
10 |         let shape_stride = shape_stride.add_axis(axis);
11 |         self.update_shape(shape_stride.shape());
12 |         self.update_stride(shape_stride.stride());
13 |     }
14 | }
15 | 
16 | #[cfg(test)]
17 | mod add_axis_test {
18 |     #![expect(clippy::float_cmp)]
19 |     use crate::{
20 |         device::Device,
21 |         dim::{DimDyn, DimTrait},
22 |         matrix::{Matrix, Owned},
23 |     };
24 | 
25 |     fn test<D: Device>() {
26 |         let mut a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1., 2., 3., 4.], [2, 2]);
27 |         a.add_axis(0);
28 |         assert_eq!(a.shape().slice(), [1, 2, 2]);
29 |         let ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1., 2., 3., 4.], [1, 2, 2]);
30 |         let diff = a.to_ref() - ans.to_ref();
31 |         let diff = diff.asum();
32 |         assert_eq!(diff, 0.);
33 |     }
34 |     #[test]
35 |     fn cpu() {
36 |         test::<crate::device::cpu::Cpu>();
37 |     }
38 |     #[cfg(feature = "nvidia")]
39 |     #[test]
40 |     fn nvidia() {
41 |         test::<crate::device::nvidia::Nvidia>();
42 |     }
43 | }
44 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/operation/mean.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::Device,
3 |     dim::{DimDyn, DimTrait},
4 |     matrix::{Matrix, Owned, Repr},
5 |     num::Num,
6 | };
7 | 
8 | impl<T: Num, R: Repr<Item = T>, S: DimTrait, D: Device> Matrix<R, S, D> {
9 |     pub fn mean(&self, axis: Option<usize>, keep_dim: bool) -> Matrix<Owned<T>, DimDyn, D> {
10 |         if let Some(axis) = axis {
11 |             let sum_axis_num_elm = self.shape()[axis];
12 |             let sum = self.to_ref().into_dyn_dim().sum(axis, keep_dim);
13 |             sum / T::from_usize(sum_axis_num_elm)
14 |         } else {
15 |             let asum = self.to_ref().asum();
16 |             let num_elm = self.shape().num_elm();
17 |             let mean = asum / T::from_usize(num_elm);
18 |             Matrix::from_vec(vec![mean], [])
19 |         }
20 |     }
21 | }
22 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/operation/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod add_axis;
2 | pub mod asum;
3 | pub mod basic_operations;
4 | pub mod broadcast;
5 | pub mod clip;
6 | pub mod copy_from;
7 | pub mod max;
8 | pub mod mean;
9 | pub mod mul;
10 | pub mod norm2;
11 | pub mod relu;
12 | pub mod reshape;
13 | pub mod softmax;
14 | pub mod split;
15 | pub mod stack;
16 | pub mod sum;
17 | pub mod to_default_stride;
18 | pub mod transpose;
19 | pub mod var;
20 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/operation/norm2.rs:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/operation/softmax.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::Device,
3 |     dim::{DimDyn, DimTrait},
4 |     matrix::{Matrix, Ref, Repr},
5 |     num::Num,
6 | };
7 | 
8 | impl<T: Num, D: Device> Matrix<Ref<&mut T>, DimDyn, D> {
9 |     #[expect(clippy::missing_panics_doc)]
10 |     pub fn softmax_assign<R: Repr<Item = T>>(&self, source: &Matrix<R, DimDyn, D>, axis: usize) {
11 |         assert!(
12 |             axis < self.shape().len(),
13 |             "axis must be less than the number of dimensions"
14 |         );
15 |         assert!(
16 |             self.shape().slice() == source.shape().slice(),
17 |             "softmax shape mismatch"
18 |         );
19 | 
20 |         let max_diff = source.to_ref() - source.max_axis(axis, true);
21 |         let mut output = max_diff.exp();
22 |         let sum = output.to_ref().sum(axis, true);
23 |         output /= sum;
24 |         self.copy_from(&output.to_ref());
25 |     }
26 | }
27 | 
28 | #[cfg(test)]
29 | mod softmax {
30 |     #![expect(clippy::unreadable_literal)]
31 |     use crate::{
32 |         device::Device,
33 |         dim::DimDyn,
34 |         matrix::{Matrix, Owned},
35 |     };
36 | 
37 |     fn softmax_1d<D: Device>() {
38 |         let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [4]);
39 |         let mut b = Matrix::<Owned<f32>, DimDyn, D>::zeros([4]);
40 |         b.to_ref_mut().softmax_assign(&a, 0);
41 |         let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
42 |             vec![0.0320586, 0.08714432, 0.23688284, 0.64391428],
43 |             [4],
44 |         );
45 |         let diff = b - ans;
46 |         assert!(diff.asum() < 1e-6);
47 |     }
48 |     #[test]
49 |     fn softmax_1d_cpu() {
50 |         softmax_1d::<crate::device::cpu::Cpu>();
51 |     }
52 |     #[cfg(feature = "nvidia")]
53 |     #[test]
54 |     fn softmax_1d_cuda() {
55 |         softmax_1d::<crate::device::nvidia::Nvidia>();
56 |     }
57 | 
58 |     fn softmax_2d<D: Device>() {
59 |         let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
60 |         let mut b = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 3]);
61 |         b.to_ref_mut().softmax_assign(&a, 1);
62 |         let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
63 |             vec![
64 |                 0.09003057, 0.24472847, 0.66524096, 0.09003057, 0.24472847, 0.66524096,
65 |             ],
66 |             [2, 3],
67 |         );
68 |         let diff = b - ans;
69 |         assert!(diff.asum() < 1e-6);
70 | 
71 |         let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
72 |         let mut b = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 3]);
73 |         b.to_ref_mut().softmax_assign(&a, 0);
74 |         let ans_2 = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
75 |             vec![
76 |                 0.04742587, 0.04742587, 0.04742587, 0.95257413, 0.95257413, 0.95257413,
77 |             ],
78 |             [2, 3],
79 |         );
80 |         let diff = b - ans_2;
81 |         assert!(diff.asum() < 1e-6);
82 |     }
83 |     #[test]
84 |     fn softmax_2d_cpu() {
85 |         softmax_2d::<crate::device::cpu::Cpu>();
86 |     }
87 |     #[cfg(feature = "nvidia")]
88 |     #[test]
89 |     fn softmax_2d_cuda() {
90 |         softmax_2d::<crate::device::nvidia::Nvidia>();
91 |     }
92 | }
93 | 
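`softmax_assign` subtracts the per-axis maximum before exponentiating. Softmax is shift-invariant, so this changes nothing mathematically, but it keeps every exponent at or below zero and therefore finite. A standalone scalar demonstration of why the shift matters (f32):

```rust
fn softmax_stable(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let xs = [1000.0f32, 1001.0, 1002.0];
    // Naively, exp(1000) already overflows f32 to infinity...
    assert!(xs[0].exp().is_infinite());
    // ...while the max-shifted probabilities are finite and sum to 1.
    let p = softmax_stable(&xs);
    assert!(p.iter().all(|v| v.is_finite()));
    assert!((p.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}
```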
--------------------------------------------------------------------------------
/zenu-matrix/src/operation/to_default_stride.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     device::Device,
3 |     dim::DimTrait,
4 |     matrix::{Matrix, Owned, Repr},
5 | };
6 | 
7 | impl<R: Repr, S: DimTrait, D: Device> Matrix<R, S, D> {
8 |     pub fn to_default_stride(&self) -> Matrix<Owned<R::Item>, S, D> {
9 |         let mut output: Matrix<Owned<R::Item>, S, D> = Matrix::alloc_like(self);
10 |         {
11 |             let output_view_mut = output.to_ref_mut();
12 |             output_view_mut.copy_from(self);
13 |         }
14 |         output
15 |     }
16 | }
17 | 
18 | #[expect(clippy::float_cmp)]
19 | #[cfg(test)]
20 | mod to_default_stride {
21 |     use crate::{
22 |         dim::{default_stride, DimDyn},
23 |         slice_dynamic,
24 |     };
25 | 
26 |     use super::*;
27 | 
28 |     fn test_1d<D: Device>() {
29 |         // 0 to 16 f32 vec
30 |         let v = vec![
31 |             0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,
32 |         ];
33 | 
34 |         let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(v.clone(), [16]);
35 |         let sliced = m.slice(slice_dynamic!(..;2));
36 |         // let default_strided: OwnedMatrixDyn = ToDefaultStride::to_default_stride(&sliced);
37 |         let default_strided = sliced.to_default_stride();
38 | 
39 |         assert_eq!(
40 |             default_strided.shape_stride().stride(),
41 |             default_stride(default_strided.shape_stride().shape())
42 |         );
43 | 
44 |         assert_eq!(default_strided.index_item([0]), 0.);
45 |         assert_eq!(default_strided.index_item([1]), 2.);
46 |         assert_eq!(default_strided.index_item([2]), 4.);
47 |         assert_eq!(default_strided.index_item([3]), 6.);
48 |         assert_eq!(default_strided.index_item([4]), 8.);
49 |     }
50 |     #[test]
51 |     fn test_1d_cpu() {
52 |         test_1d::<crate::device::cpu::Cpu>();
53 |     }
54 |     #[cfg(feature = "nvidia")]
55 |     #[test]
56 |     fn test_1d_gpu() {
57 |         test_1d::<crate::device::nvidia::Nvidia>();
58 |     }
59 | 
60 |     fn test_2d<D: Device>() {
61 |         // 0 to 16 f32 vec
62 |         let v = vec![
63 |             0., 1., 2., 3., 4., 5., 6., 7., //
64 |             8., 9., 10., 11., 12., 13., 14., 15.,
65 |         ];
66 | 
67 |         let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(v.clone(), [4, 4]);
68 |         let sliced = m.slice(slice_dynamic!(..;2, ..;2));
69 |         let default_strided = sliced.to_default_stride();
70 | 
71 |         assert_eq!(
72 |             default_strided.shape_stride().stride(),
73 |             default_stride(default_strided.shape_stride().shape())
74 |         );
75 | 
76 |         assert_eq!(default_strided.index_item([0, 0]), 0.);
77 |         assert_eq!(default_strided.index_item([0, 1]), 2.);
78 |         assert_eq!(default_strided.index_item([1, 0]), 8.);
79 |         assert_eq!(default_strided.index_item([1, 1]), 10.);
80 |     }
81 |     #[test]
82 |     fn test_2d_cpu() {
83 |         test_2d::<crate::device::cpu::Cpu>();
84 |     }
85 |     #[cfg(feature = "nvidia")]
86 |     #[test]
87 |     fn test_2d_gpu() {
88 |         test_2d::<crate::device::nvidia::Nvidia>();
89 |     }
90 | }
91 | 
--------------------------------------------------------------------------------
/zenu-matrix/src/slice/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod dynamic;
2 | pub mod r#macro;
3 | pub mod slice_dim;
4 | pub mod static_dim_slice;
5 | 
6 | pub use dynamic::*;
7 | pub use slice_dim::*;
8 | pub use static_dim_slice::*;
9 | 
--------------------------------------------------------------------------------
/zenu-optimizer/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "zenu-optimizer"
3 | version = "0.1.1"
"0.1.1" 4 | edition = "2021" 5 | description = "A simple optimizer for neural networks" 6 | license = "MIT" 7 | repository = "https://github.com/bokutotu/zenu" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | zenu-matrix = { path = "../zenu-matrix", version = "0.1.2" } 13 | zenu-autograd = { path = "../zenu-autograd", version = "0.1.2" } 14 | zenu-layer = { path = "../zenu-layer", version = "0.1.1" } 15 | 16 | [dev-dependencies] 17 | zenu-test = { path="../zenu-test/" } 18 | zenu = { path="../zenu/"} 19 | rand = { version = "0.8.5", features = ["small_rng"] } 20 | rand_distr = "0.4.2" 21 | 22 | [lints] 23 | workspace = true 24 | 25 | [profile.bench] 26 | debug = true 27 | 28 | [features] 29 | nvidia = ["zenu-matrix/nvidia", "zenu-autograd/nvidia", "zenu-layer/nvidia"] 30 | -------------------------------------------------------------------------------- /zenu-optimizer/README.md: -------------------------------------------------------------------------------- 1 | # ZeNu Optimizer 2 | 3 | ZeNu Optimizer is a collection of optimization algorithms for training neural networks. It provides various optimizers that can be used with the ZeNu deep learning library. 4 | 5 | ## Features 6 | 7 | - Stochastic Gradient Descent (SGD) optimizer 8 | - Integration with ZeNu Autograd for gradient computation 9 | - Easy integration with ZeNu models and layers 10 | 11 | ## Getting Started 12 | 13 | To use ZeNu Optimizer in your Rust project, add the following to your `Cargo.toml` file: 14 | 15 | ```toml 16 | [dependencies] 17 | zenu-optimizer = "0.1.0" 18 | ``` 19 | 20 | Here's a simple example of using the SGD optimizer from ZeNu Optimizer: 21 | 22 | ```rust 23 | use zenu_autograd::{creator::from_vec::from_vec, Variable}; 24 | use zenu_optimizer::sgd::SGD; 25 | 26 | fn main() { 27 | let variable = from_vec(vec![1., 2., 3., 4., 5., 6.], [3, 2]); 28 | variable.set_grad(from_vec(vec![1., 2., 3., 4., 5., 6.], [3, 2])); 29 | 30 | let sgd = SGD::new(0.01); 31 | sgd.update(&[variable.clone()]); 32 | 33 | // The variable has been updated by the optimizer 34 | // Perform further computations with the updated variable 35 | } 36 | ``` 37 | 38 | For more details and examples, please refer to the [documentation](https://docs.rs/zenu-optimizer). 39 | 40 | ## License 41 | 42 | ZeNu Optimizer is licensed under the [MIT License](LICENSE). 
--------------------------------------------------------------------------------
/zenu-optimizer/src/adam.rs:
--------------------------------------------------------------------------------
1 | use std::{cell::RefCell, collections::HashMap, rc::Rc};
2 | 
3 | use zenu_autograd::{creator::zeros::zeros_like, Variable};
4 | use zenu_layer::Parameters;
5 | use zenu_matrix::{device::Device, num::Num};
6 | 
7 | use crate::Optimizer;
8 | 
9 | pub struct Adam<T: Num, D: Device> {
10 |     learning_rate: T,
11 |     beta1: T,
12 |     beta2: T,
13 |     epsilon: T,
14 |     step: Rc<RefCell<usize>>,
15 |     pub m: HashMap<String, Variable<T, D>>,
16 |     pub v: HashMap<String, Variable<T, D>>,
17 | }
18 | 
19 | impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for Adam<T, D> {
20 |     fn update(&self, parameters: &P) {
21 |         *self.step.borrow_mut() += 1;
22 |         let step = T::from_usize(*self.step.borrow());
23 | 
24 |         let beta1_t = self.beta1.powf(step);
25 |         let beta2_t = self.beta2.powf(step);
26 | 
27 |         let parameters = parameters
28 |             .parameters()
29 |             .iter()
30 |             .filter_map(|(key, value)| {
31 |                 value
32 |                     .get_grad()
33 |                     .map(|grad| (key.clone(), (value.clone(), grad.clone())))
34 |             })
35 |             .collect::<HashMap<_, _>>();
36 | 
37 |         for (k, (data, grad)) in &parameters {
38 |             let v = self.v.get(k).unwrap();
39 |             let m = self.m.get(k).unwrap();
40 |             let mut v = v.get_as_mut();
41 |             let mut m = m.get_as_mut();
42 |             let grad = grad.get_as_ref();
43 | 
44 |             m *= self.beta1;
45 |             m += grad.to_ref() * (T::one() - self.beta1);
46 | 
47 |             v *= self.beta2;
48 |             v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);
49 | 
50 |             let m_hat = m.clone() / (T::one() - beta1_t);
51 |             let v_hat = v.clone() / (T::one() - beta2_t);
52 | 
53 |             let m_v_hat = m_hat / (v_hat.sqrt() + self.epsilon);
54 |             let lr_mv_hat = m_v_hat * self.learning_rate;
55 | 
56 |             data.get_as_mut().sub_assign(&lr_mv_hat.to_ref());
57 |         }
58 |     }
59 | }
60 | 
61 | impl<T: Num, D: Device> Adam<T, D> {
62 |     pub fn new(
63 |         learning_rate: T,
64 |         beta1: T,
65 |         beta2: T,
66 |         epsilon: T,
67 |         model: &impl Parameters<T, D>,
68 |     ) -> Self {
69 |         let m = model
70 |             .parameters()
71 |             .iter()
72 |             .map(|(key, value)| (key.clone(), zeros_like(value)))
73 |             .collect();
74 |         let v = model
75 |             .parameters()
76 |             .iter()
77 |             .map(|(key, value)| (key.clone(), zeros_like(value)))
78 |             .collect();
79 |         Self {
80 |             learning_rate,
81 |             beta1,
82 |             beta2,
83 |             epsilon,
84 |             step: Rc::new(RefCell::new(0)),
85 |             m,
86 |             v,
87 |         }
88 |     }
89 | }
90 | 
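The update above is easier to follow with concrete numbers. A standalone scalar Adam step (f64, illustrative hyperparameters) showing the bias-corrected moments that the code calls `m_hat` and `v_hat`:

```rust
fn main() {
    let (lr, beta1, beta2, eps) = (0.001f64, 0.9, 0.999, 1e-8);
    let (mut w, mut m, mut v) = (1.0f64, 0.0, 0.0);
    let grad = 0.5;

    // One step at t = 1, mirroring the loop body of `Adam::update`.
    m = beta1 * m + (1.0 - beta1) * grad; // 0.05
    v = beta2 * v + (1.0 - beta2) * grad * grad; // 0.00025
    let m_hat = m / (1.0 - beta1.powi(1)); // 0.05 / 0.1     = 0.5
    let v_hat = v / (1.0 - beta2.powi(1)); // 0.00025 / 0.001 = 0.25
    w -= lr * m_hat / (v_hat.sqrt() + eps);

    // At t = 1 the corrected step is ~lr * sign(grad), independent of gradient scale.
    assert!((w - 0.999).abs() < 1e-6);
    println!("w after one Adam step = {w}");
}
```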
--------------------------------------------------------------------------------
/zenu-optimizer/src/adamw.rs:
--------------------------------------------------------------------------------
1 | use std::{cell::RefCell, collections::HashMap, rc::Rc};
2 | 
3 | use zenu_autograd::{creator::zeros::zeros_like, Variable};
4 | use zenu_layer::Parameters;
5 | use zenu_matrix::{device::Device, num::Num};
6 | 
7 | use crate::Optimizer;
8 | 
9 | pub struct AdamW<T: Num, D: Device> {
10 |     learning_rate: T,
11 |     beta1: T,
12 |     beta2: T,
13 |     epsilon: T,
14 |     weight_decay: T,
15 |     step: Rc<RefCell<T>>,
16 |     m: HashMap<String, Variable<T, D>>,
17 |     v: HashMap<String, Variable<T, D>>,
18 | }
19 | 
20 | impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for AdamW<T, D> {
21 |     fn update(&self, parameters: &P) {
22 |         let step = *self.step.borrow() + T::one();
23 |         *self.step.borrow_mut() = step;
24 | 
25 |         let beta1_t = self.beta1.powf(step);
26 |         let beta2_t = self.beta2.powf(step);
27 | 
28 |         let weight_keys: Vec<_> = parameters.weights().keys().cloned().collect();
29 | 
30 |         let params = parameters
31 |             .parameters()
32 |             .iter()
33 |             .filter_map(|(key, value)| {
34 |                 value
35 |                     .get_grad()
36 |                     .map(|grad| (key.clone(), (value.clone(), grad.clone())))
37 |             })
38 |             .collect::<HashMap<_, _>>();
39 | 
40 |         for (k, (data, grad)) in params {
41 |             let m = self.m.get(&k).unwrap();
42 |             let v = self.v.get(&k).unwrap();
43 |             let mut m = m.get_as_mut();
44 |             let mut v = v.get_as_mut();
45 |             let grad = grad.get_as_ref();
46 | 
47 |             // Update m and v
48 |             m *= self.beta1;
49 |             m += grad.to_ref() * (T::one() - self.beta1);
50 | 
51 |             v *= self.beta2;
52 |             v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);
53 | 
54 |             let m_hat = m.clone() / (T::one() - beta1_t);
55 |             let v_hat = v.clone() / (T::one() - beta2_t);
56 | 
57 |             let denom = v_hat.sqrt() + self.epsilon;
58 |             let step_size = self.learning_rate;
59 |             let update = m_hat / denom;
60 | 
61 |             if weight_keys.contains(&k) {
62 |                 data.get_as_mut().sub_assign(
63 |                     &(data.get_as_ref() * self.learning_rate * self.weight_decay).to_ref(),
64 |                 );
65 |             }
66 | 
67 |             data.get_as_mut().sub_assign(&(update * step_size).to_ref());
68 |         }
69 |     }
70 | }
71 | impl<T: Num, D: Device> AdamW<T, D> {
72 |     pub fn new(
73 |         learning_rate: T,
74 |         beta1: T,
75 |         beta2: T,
76 |         epsilon: T,
77 |         weight_decay: T,
78 |         model: &impl Parameters<T, D>,
79 |     ) -> Self {
80 |         let m = model
81 |             .parameters()
82 |             .iter()
83 |             .map(|(key, value)| (key.clone(), zeros_like(value)))
84 |             .collect();
85 |         let v = model
86 |             .parameters()
87 |             .iter()
88 |             .map(|(key, value)| (key.clone(), zeros_like(value)))
89 |             .collect();
90 |         Self {
91 |             learning_rate,
92 |             beta1,
93 |             beta2,
94 |             epsilon,
95 |             weight_decay,
96 |             step: Rc::new(RefCell::new(T::zero())),
97 |             m,
98 |             v,
99 |         }
100 |     }
101 | }
102 | 
--------------------------------------------------------------------------------
/zenu-optimizer/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod adam;
2 | pub mod adamw;
3 | pub mod sgd;
4 | 
5 | use zenu_layer::Parameters;
6 | use zenu_matrix::{device::Device, num::Num};
7 | 
8 | pub trait Optimizer<T: Num, D: Device, P: Parameters<T, D>> {
9 |     fn update(&self, parameters: &P);
10 | }
11 | 
--------------------------------------------------------------------------------
/zenu-optimizer/src/sgd.rs:
--------------------------------------------------------------------------------
1 | use zenu_layer::Parameters;
2 | use zenu_matrix::{device::Device, num::Num};
3 | 
4 | use crate::Optimizer;
5 | 
6 | pub struct SGD<T: Num, D: Device> {
7 |     pub learning_rate: T,
8 |     _device: std::marker::PhantomData<D>,
9 | }
10 | 
11 | impl<T: Num, D: Device> SGD<T, D> {
12 |     pub fn new(learning_rate: T) -> Self {
13 |         Self {
14 |             learning_rate,
15 |             _device: std::marker::PhantomData,
16 |         }
17 |     }
18 | }
19 | 
20 | impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for SGD<T, D> {
21 |     fn update(&self, parameters: &P) {
22 |         for data in parameters.parameters().values() {
23 |             if let Some(grad) = data.get_grad() {
24 |                 let update_data = grad.get_data().to_ref() * self.learning_rate;
25 |                 let mut data = data.get_data_mut();
26 |                 let mut data = data.to_ref_mut();
27 |                 data -= update_data;
28 |             }
29 |         }
30 |     }
31 | }
32 | 
--------------------------------------------------------------------------------
/zenu-test/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "zenu-test"
3 | version = "0.1.0"
4 | edition = "2021"
5 | repository = "https://github.com/bokutotu/zenu"
6 | license = "MIT"
7 | description = "Testing framework for Zenu"
8 | 
9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
10 | 
11 | [dependencies]
12 | 
13 | zenu-matrix = { version= "0.1.2", path = "../zenu-matrix" }
14 | zenu-autograd = { version="0.1.1", path = "../zenu-autograd" }
15 | 
16 | serde = { version = "1.0", features = ["derive"] }
17 | serde_json = "1.0"
18 | 
19 |
[lints] 20 | workspace = true 21 | -------------------------------------------------------------------------------- /zenu/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zenu" 3 | version = "0.1.2" 4 | edition = "2021" 5 | description = "A simple Deep Learning library for Rust" 6 | license = "MIT" 7 | repository = "https://github.com/bokutotu/zenu" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | zenu-matrix = { path = "../zenu-matrix", version = "0.1.2" } 13 | zenu-autograd = { path = "../zenu-autograd", version = "0.1.2" } 14 | zenu-layer = { path = "../zenu-layer", version = "0.1.1" } 15 | zenu-optimizer = { path = "../zenu-optimizer", version = "0.1.1" } 16 | zenu-macros = { path = "../zenu-macros", version = "0.1.0" } 17 | 18 | reqwest = { version = "0.12", features = ["json", "blocking"] } 19 | flate2 = { version = "1.0", features = ["zlib"] } 20 | rand = { version = "0.8.5", features = ["small_rng"] } 21 | serde = { version = "1.0.114", features = ["derive"] } 22 | bincode = "1.3.3" 23 | tar = "0.4.40" 24 | 25 | [dev-dependencies] 26 | zenu-test = { path = "../zenu-test" } 27 | 28 | [profile.bench] 29 | debug = true 30 | 31 | [features] 32 | nvidia = ["zenu-matrix/nvidia", "zenu-autograd/nvidia", "zenu-layer/nvidia", "zenu-optimizer/nvidia"] 33 | -------------------------------------------------------------------------------- /zenu/examples/install_mnist.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pandas as pd 3 | 4 | mnist = tf.keras.datasets.mnist 5 | (train_images, train_labels), (test_images, test_labels) = mnist.load_data() 6 | 7 | train_images_flattened = train_images.reshape(train_images.shape[0], -1) 8 | test_images_flattened = test_images.reshape(test_images.shape[0], -1) 9 | 10 | train_df = pd.DataFrame(train_images_flattened) 11 | train_df.to_csv('mnist_train_flattened.txt', index=False) 12 | 13 | test_df = pd.DataFrame(test_images_flattened) 14 | test_df.to_csv('mnist_test_flattened.txt', index=False) 15 | 16 | train_labels_df = pd.DataFrame(train_labels) 17 | train_labels_df.to_csv('mnist_train_labels.txt', index=False) 18 | 19 | test_labels_df = pd.DataFrame(test_labels) 20 | test_labels_df.to_csv('mnist_test_labels.txt', index=False) 21 | -------------------------------------------------------------------------------- /zenu/examples/resnet.rs: -------------------------------------------------------------------------------- 1 | // use zenu_autograd::{ 2 | // functions::{activation::relu::relu, flatten::flatten, pool2d::max_pool_2d}, 3 | // Variable, 4 | // }; 5 | // use zenu_layer::{ 6 | // layers::{batch_norm_2d::BatchNorm2d, conv2d::Conv2d, linear::Linear}, 7 | // Module, 8 | // }; 9 | // use zenu_matrix::device::Device; 10 | // 11 | // struct ResBlock { 12 | // conv1: Conv2d, 13 | // batch_norm1: BatchNorm2d, 14 | // conv2: Conv2d, 15 | // batch_norm2: BatchNorm2d, 16 | // } 17 | // 18 | // impl Module for ResBlock { 19 | // fn call(&self, inputs: Variable) -> Variable { 20 | // let x = inputs.clone(); 21 | // let y = self.conv1.call(x.clone()); 22 | // let y = self.batch_norm1.call(y); 23 | // let y = relu(y); 24 | // let y = self.conv2.call(y); 25 | // let y = self.batch_norm2.call(y); 26 | // let y = y + x; 27 | // relu(y) 28 | // } 29 | // } 30 | // 31 | // impl ResBlock { 32 | // fn new( 33 | // in_channels: usize, 34 | // 
out_channels: usize, 35 | // kernel_size: (usize, usize), 36 | // padding: (usize, usize), 37 | // stride: (usize, usize), 38 | // ) -> Self { 39 | // let conv1 = Conv2d::new( 40 | // in_channels, 41 | // out_channels, 42 | // kernel_size, 43 | // padding, 44 | // stride, 45 | // true, 46 | // ); 47 | // let batch_norm1 = BatchNorm2d::new(out_channels, 0.9); 48 | // let conv2 = Conv2d::new( 49 | // out_channels, 50 | // out_channels, 51 | // kernel_size, 52 | // padding, 53 | // stride, 54 | // true, 55 | // ); 56 | // let batch_norm2 = BatchNorm2d::new(out_channels, 0.9); 57 | // Self { 58 | // conv1, 59 | // batch_norm1, 60 | // conv2, 61 | // batch_norm2, 62 | // } 63 | // } 64 | // } 65 | // 66 | // struct ResNet { 67 | // conv1: Conv2d, 68 | // res_block1: ResBlock, 69 | // res_block2: ResBlock, 70 | // linear: Linear, 71 | // } 72 | // 73 | // impl Module for ResNet { 74 | // fn call(&self, inputs: Variable) -> Variable { 75 | // let x = self.conv1.call(inputs.clone()); 76 | // // let x = max_pool_2d(x, (3, 3), (2, 2), (1, 1), config) 77 | // let x = relu(x); 78 | // let x = self.res_block1.call(x); 79 | // let x = self.res_block2.call(x); 80 | // let x = flatten(x); 81 | // self.linear.call(x) 82 | // } 83 | // } 84 | // 85 | fn main() { 86 | println!("Hello, world!"); 87 | } 88 | --------------------------------------------------------------------------------