├── logo.png ├── logo_small.png ├── scripts ├── bench_config.sh ├── bench_config_gpu.sh ├── bench_parallel.sh ├── debug_compile.sh ├── release_debug_compile.sh ├── test_runner.sh └── bench_runner.sh ├── .color_coded ├── .codedocs ├── test ├── include │ ├── test.hpp │ ├── test_light.hpp │ ├── compat.hpp │ └── template_test.hpp └── src │ ├── test.cpp │ ├── cross.cpp │ ├── tmp.cpp │ ├── special_cases.cpp │ ├── optimize_2.cpp │ ├── memory_slice.cpp │ ├── optimize_1.cpp │ ├── elt_logical.cpp │ ├── max_pool_upsample_deep.cpp │ ├── decomposition.cpp │ ├── avg_pool_upsample_deep.cpp │ ├── noise.cpp │ ├── conv_4d_valid_mixed.cpp │ ├── conv_4d_full_mixed.cpp │ ├── timed.cpp │ ├── assert.cpp │ ├── iterators.cpp │ ├── gemv_types.cpp │ ├── gevm_types.cpp │ ├── bias_add.cpp │ ├── transpose_front.cpp │ ├── dot.cpp │ ├── compare.cpp │ ├── serial.cpp │ ├── parallel.cpp │ └── alias.cpp ├── .gitignore ├── benchmark ├── src │ ├── benchmark_base.cpp │ ├── benchmark_trigo.cpp │ └── benchmark_batch_hint.cpp └── include │ └── benchmark_gemm.hpp ├── .gitmodules ├── sonar-project.properties ├── include └── etl │ ├── namespaces.hpp │ ├── restrict.hpp │ ├── random.hpp │ ├── sparse_storage.hpp │ ├── bce_impl.hpp │ ├── cce_impl.hpp │ ├── concepts_base.hpp │ ├── mse_impl.hpp │ ├── batch_softmax_impl.hpp │ ├── pool_impl.hpp │ ├── fft_impl.hpp │ ├── impl │ ├── blas │ │ ├── blas.hpp │ │ ├── dot.hpp │ │ └── outer.hpp │ ├── std │ │ ├── norm.hpp │ │ ├── dot.hpp │ │ ├── outer.hpp │ │ ├── bias_add.hpp │ │ ├── mse.hpp │ │ ├── cce.hpp │ │ ├── bce.hpp │ │ ├── det.hpp │ │ ├── sum.hpp │ │ └── convmtx2.hpp │ ├── inv.hpp │ ├── det.hpp │ ├── norm.hpp │ ├── conv.hpp │ ├── decomposition.hpp │ ├── egblas │ │ ├── or.hpp │ │ ├── and.hpp │ │ ├── xor.hpp │ │ ├── scalar_set.hpp │ │ ├── sigmoid.hpp │ │ ├── transpose_front.hpp │ │ ├── relu_der_out.hpp │ │ └── one_if_max_sub.hpp │ ├── conv_select.hpp │ └── cublas │ │ ├── dot.hpp │ │ ├── cuda.hpp │ │ └── outer.hpp │ ├── eval_visitors.hpp │ ├── sum_impl.hpp │ ├── 
dot_impl.hpp │ ├── transpose_impl.hpp │ ├── outer_impl.hpp │ ├── exit.hpp │ ├── bias_add_impl.hpp │ ├── op │ ├── generators.hpp │ ├── generators │ │ └── sequence.hpp │ ├── unary_op.hpp │ ├── binary │ │ ├── mod.hpp │ │ └── one_if.hpp │ └── binary_op.hpp │ ├── impl_enums.hpp │ ├── order.hpp │ ├── gemm_impl.hpp │ ├── adapters │ ├── diagonal_exception.hpp │ ├── lower_exception.hpp │ ├── upper_exception.hpp │ ├── uni_lower_exception.hpp │ ├── uni_upper_exception.hpp │ ├── strictly_lower_exception.hpp │ └── strictly_upper_exception.hpp │ ├── util │ ├── complex_cast.hpp │ └── variadic.hpp │ ├── builder │ ├── mul_expression_builder.hpp │ └── conv_expression_builder.hpp │ ├── traits_base.hpp │ ├── inline.hpp │ ├── duration.hpp │ ├── serializer.hpp │ ├── deserializer.hpp │ ├── stop.hpp │ ├── math.hpp │ ├── direct_fill.hpp │ ├── parallel_session.hpp │ ├── memory.hpp │ ├── conv_impl.hpp │ ├── expr_fwd.hpp │ ├── std.hpp │ ├── print.hpp │ └── crtp │ └── value_testable.hpp ├── workbench └── src │ ├── test_dim.cpp │ ├── test.cpp │ └── verify_cpm.cpp ├── .github └── workflows │ └── make.yml ├── LICENSE ├── Implementation.rst ├── .clang-format └── Jenkinsfile /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wichtounet/etl/HEAD/logo.png -------------------------------------------------------------------------------- /logo_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wichtounet/etl/HEAD/logo_small.png -------------------------------------------------------------------------------- /scripts/bench_config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export ETL_EXTENDED=true 3 | export ETL_MKL=true 4 | export ETL_PARALLEL=true 5 | -------------------------------------------------------------------------------- /.color_coded: 
-------------------------------------------------------------------------------- 1 | -std=c++1y 2 | -Iinclude 3 | -Ilib/include 4 | -ICatch/include 5 | -Itest/include 6 | -include 7 | etl/etl.hpp 8 | -------------------------------------------------------------------------------- /.codedocs: -------------------------------------------------------------------------------- 1 | PROJECT_NAME = "Expression Templates Library (ETL)" 2 | EXCLUDE = Catch, cpm, lib, make-utils 3 | -------------------------------------------------------------------------------- /scripts/bench_config_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export ETL_EXTENDED=true 3 | export ETL_MKL=true 4 | export ETL_CUFFT=true 5 | export ETL_CUBLAS=true 6 | export ETL_CUDNN=true 7 | export ETL_PARALLEL=true 8 | -------------------------------------------------------------------------------- /scripts/bench_parallel.sh: -------------------------------------------------------------------------------- 1 | make clean 2 | rm -rf results/*.cpm 3 | 4 | unset ETL_PARALLEL 5 | make -j9 release_debug/bin/benchmark && ./release_debug/bin/benchmark -c serial -t now 6 | 7 | rm -rf release_debug 8 | export ETL_PARALLEL=true 9 | make -j9 release_debug/bin/benchmark && ./release_debug/bin/benchmark -c parallel -t now 10 | 11 | mkdir reports 12 | ../cpm/release_debug/bin/cpm results 13 | -------------------------------------------------------------------------------- /test/include/test.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "etl/etl.hpp" 9 | #include "compat.hpp" 10 | -------------------------------------------------------------------------------- /test/include/test_light.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "etl/etl_light.hpp" 9 | #include "compat.hpp" 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Configurations outputs 2 | debug 3 | release 4 | release_debug 5 | 6 | # Generated for clang tools 7 | etl_file_list 8 | tidy_report_light 9 | tidy_report_all 10 | .cache 11 | *.plist 12 | 13 | # Profiling files 14 | perf.data 15 | perf.data.old 16 | 17 | # Generated by the tests 18 | *.tmp.etl 19 | coverage_report*.xml 20 | 21 | # Generated by the benchmark 22 | reports 23 | results 24 | 25 | # Doxygen 26 | html 27 | latex 28 | -------------------------------------------------------------------------------- /benchmark/src/benchmark_base.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #define CPM_BENCHMARK "Basic Benchmarks" 9 | #include "benchmark.hpp" 10 | -------------------------------------------------------------------------------- /test/src/test.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #define DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE 9 | #define DOCTEST_CONFIG_SUPER_FAST_ASSERTS 10 | #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN 11 | #include "doctest/doctest.h" 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "make-utils"] 2 | path = make-utils 3 | url = https://github.com/wichtounet/make-utils.git 4 | branch = master 5 | [submodule "lib/include/cpp_utils"] 6 | path = lib/include/cpp_utils 7 | url = https://github.com/wichtounet/cpp_utils.git 8 | [submodule "wiki"] 9 | path = wiki 10 | url = https://github.com/wichtounet/etl.wiki.git 11 | [submodule "cpm"] 12 | path = cpm 13 | url = https://github.com/wichtounet/cpm.git 14 | [submodule "doctest"] 15 | path = doctest 16 | url = https://github.com/onqtam/doctest.git 17 | branch = dev 18 | -------------------------------------------------------------------------------- /sonar-project.properties: -------------------------------------------------------------------------------- 1 | sonar.projectKey=wichtounet_etl 2 | sonar.organization=wichtounet-github 3 | 4 | # This is the name 
and version displayed in the SonarCloud UI. 5 | sonar.projectName=etl 6 | sonar.projectVersion=1.3.0 7 | 8 | # C++23 support in SonarCloud is experimental, so enable it 9 | sonar.cfamily.cpp23.enabled=true 10 | 11 | # Get coverage from gcov files 12 | sonar.cfamily.gcov.reportsPath=gcov-reports 13 | 14 | # Path is relative to the sonar-project.properties file. Replace "\" by "/" on Windows. 15 | #sonar.sources=. 16 | 17 | # Encoding of the source code. Default is default system encoding 18 | #sonar.sourceEncoding=UTF-8 19 | -------------------------------------------------------------------------------- /include/etl/namespaces.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #error Documentation only 9 | 10 | /*! 11 | * \namespace etl 12 | * \brief Root namespace for the ETL library. 13 | * 14 | * This namespace exposes all the public interface of the library, it is the 15 | * only one that should be used directly in client code. 16 | */ 17 | namespace etl {} 18 | -------------------------------------------------------------------------------- /include/etl/restrict.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Restrict macros. 
11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef __GNUC__ 16 | 17 | #define ETL_RESTRICT __restrict 18 | 19 | #elif defined(__clang__) 20 | 21 | #define ETL_RESTRICT __restrict__ 22 | 23 | #else 24 | 25 | #define ETL_RESTRICT 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /include/etl/random.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains utilities for random generation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 20 | * \brief The random engine used by the library 21 | */ 22 | using random_engine = std::mt19937_64; 23 | 24 | } //end of namespace etl 25 | -------------------------------------------------------------------------------- /include/etl/sparse_storage.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains the sparse_storage enum 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 
18 | * \brief Enumeration for sparse storage formats 19 | */ 20 | enum class sparse_storage { 21 | COO ///< Coordinate Format (COO) 22 | }; 23 | 24 | } //end of namespace etl 25 | -------------------------------------------------------------------------------- /include/etl/bce_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for BCE implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different implementations of BCE 19 | */ 20 | enum class bce_impl { 21 | STD, ///< Standard implementation 22 | EGBLAS ///< GPU implementation 23 | }; 24 | 25 | } //end of namespace etl 26 | -------------------------------------------------------------------------------- /include/etl/cce_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for CCE implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 
18 | * \brief Enumeration describing the different implementations of CCE 19 | */ 20 | enum class cce_impl { 21 | STD, ///< Standard implementation 22 | EGBLAS ///< GPU implementation 23 | }; 24 | 25 | } //end of namespace etl 26 | -------------------------------------------------------------------------------- /include/etl/concepts_base.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | 13 | namespace etl { 14 | 15 | template 16 | concept etl_expr = decay_traits::is_etl; 17 | 18 | template 19 | struct scalar; 20 | 21 | template 22 | concept expr_or_scalar = etl_expr || std::same_as>; 23 | 24 | } // namespace etl 25 | -------------------------------------------------------------------------------- /include/etl/mse_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for MSE implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 
18 | * \brief Enumeration describing the different implementations of MSE 19 | */ 20 | enum class mse_impl { 21 | STD, ///< Standard implementation 22 | EGBLAS ///< GPU implementation 23 | }; 24 | 25 | } //end of namespace etl 26 | -------------------------------------------------------------------------------- /include/etl/batch_softmax_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for batch_softmax implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different implementations of batch_softmax 19 | */ 20 | enum class batch_softmax_impl { 21 | STD, ///< Standard implementation 22 | CUDNN ///< GPU implementation 23 | }; 24 | 25 | } //end of namespace etl 26 | -------------------------------------------------------------------------------- /include/etl/pool_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for the pooling implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 
18 | * \brief Enumeration describing the different implementations of 19 | * pooling 20 | */ 21 | enum class pool_impl { 22 | STD, ///< Standard implementation 23 | CUDNN ///< CUDNN (GPU) implementation 24 | }; 25 | 26 | } //end of namespace etl 27 | -------------------------------------------------------------------------------- /include/etl/fft_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for the fft implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief The different FFT implementations 19 | */ 20 | enum class fft_impl { 21 | STD, ///< The standard implementation 22 | MKL, ///< The Intel MKL implementation 23 | CUFFT ///< The NVidia CuFFT implementation 24 | }; 25 | 26 | } //end of namespace etl 27 | -------------------------------------------------------------------------------- /include/etl/impl/blas/blas.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | #ifdef ETL_MKL_MODE 11 | 12 | #define disable_blas_threads() \ 13 | auto etl_mkl_threads = mkl_get_max_threads(); \ 14 | mkl_set_num_threads(1); 15 | 16 | #define restore_blas_threads() mkl_set_num_threads(etl_mkl_threads); 17 | 18 | #else 19 | 20 | #define disable_blas_threads() \ 21 | 22 | #define restore_blas_threads() 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /workbench/src/test_dim.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "etl/etl.hpp" 9 | 10 | etl::fast_matrix a = {-1.0, 2.0, 5.0, 2.0, 5.0, 1.2}; 11 | etl::fast_matrix b = {-1.0, 2.0, 5.0, 1.2, 2.5, 3.0}; 12 | 13 | /* 14 | * Simple source file to verify how the code is compiled. 
15 | */ 16 | 17 | int main(){ 18 | etl::fast_vector d(row(a,1) + 1.5 * (row(a,0) / 2.0 + row(b,1))); 19 | 20 | return static_cast(sum(d)); 21 | } 22 | -------------------------------------------------------------------------------- /.github/workflows/make.yml: -------------------------------------------------------------------------------- 1 | name: Linux Build 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | jobs: 10 | build: 11 | name: Compile and test on Linux 12 | runs-on: ubuntu-latest 13 | container: 14 | image: wichtounet/cpp:latest 15 | strategy: 16 | matrix: 17 | compiler: [gcc, clang] 18 | 19 | steps: 20 | - name: Checkout code 21 | uses: actions/checkout@v3 22 | with: 23 | submodules: recursive 24 | 25 | - name: Build binaries 26 | run: make -j5 debug compiler=${{ matrix.compiler }} 27 | 28 | - name: Build tests 29 | run: make -j5 debug_etl_test compiler=${{ matrix.compiler }} 30 | 31 | - name: Run tests 32 | run: make debug_test compiler=${{ matrix.compiler }} 33 | -------------------------------------------------------------------------------- /include/etl/impl/std/norm.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard implementation of the "norm" reduction 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 
18 | * \brief Compute the euclidean norm of a 19 | * \param a The expression 20 | * \return the euclidean norm 21 | */ 22 | template 23 | value_t norm(const A& a) { 24 | return std::sqrt(sum(scale(a, a))); 25 | } 26 | 27 | } //end of namespace etl::impl::standard 28 | -------------------------------------------------------------------------------- /include/etl/eval_visitors.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file eval_visitors.hpp 10 | * \brief Contains the visitors used by the evaluator to process the 11 | * expression trees. 12 | */ 13 | 14 | #pragma once 15 | 16 | namespace etl { 17 | 18 | namespace detail { 19 | 20 | /*! 21 | * \brief Visitor to perform local evaluation when necessary 22 | */ 23 | struct evaluator_visitor { 24 | // Nothing to configure 25 | }; 26 | 27 | } //end of namespace detail 28 | 29 | } //end of namespace etl 30 | -------------------------------------------------------------------------------- /include/etl/sum_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for sum implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 
18 | * \brief Enumeration describing the different implementations of sum 19 | */ 20 | enum class sum_impl { 21 | STD, ///< Standard implementation 22 | VEC, ///< Vectorized implementation 23 | BLAS, ///< BLAS implementation 24 | CUBLAS ///< CUBLAS implementation 25 | }; 26 | 27 | } //end of namespace etl 28 | -------------------------------------------------------------------------------- /include/etl/dot_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration of the dot implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different implementations of dot 19 | */ 20 | enum class dot_impl { 21 | STD, ///< Standard implementation 22 | VEC, ///< Uniform Vectorized implementation 23 | BLAS, ///< BLAS implementation 24 | CUBLAS ///< CUBLAS implementation 25 | }; 26 | 27 | } //end of namespace etl 28 | -------------------------------------------------------------------------------- /include/etl/transpose_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 
9 | * \file 10 | * \brief Enumeration of the transpose implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different implementations of transpose 19 | */ 20 | enum class transpose_impl { 21 | STD, ///< Standard implementation 22 | VEC, ///< Vectorized implementation 23 | MKL, ///< MKL implementation 24 | CUBLAS, ///< CUBLAS implementation 25 | }; 26 | 27 | } //end of namespace etl 28 | -------------------------------------------------------------------------------- /include/etl/impl/std/dot.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard implementation of the "dot" reduction 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 18 | * \brief Compute the dot product of a and b 19 | * \param a The lhs expression 20 | * \param b The rhs expression 21 | * \return the dot product of a and b 22 | */ 23 | template 24 | value_t dot(const A& a, const B& b) { 25 | return sum(scale(a, b)); 26 | } 27 | 28 | } //end of namespace etl::impl::standard 29 | -------------------------------------------------------------------------------- /include/etl/outer_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for the outer product implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different implementations of 19 | * outer product 20 | */ 21 | enum class outer_impl { 22 | STD, ///< Standard implementation 23 | BLAS, ///< BLAS implementation 24 | CUBLAS, ///< CUBLAS implementation 25 | VEC ///< VEC implementation 26 | }; 27 | 28 | } //end of namespace etl 29 | -------------------------------------------------------------------------------- /include/etl/exit.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Produces exit utility when necessary 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Exit from ETL, releasing any possible resource. 
19 | * 20 | * This function must be called if ETL_GPU_POOL is used 21 | */ 22 | inline void exit() { 23 | #ifdef ETL_CUDA 24 | #ifdef ETL_GPU_POOL 25 | etl::gpu_memory_allocator::clear(); 26 | #endif 27 | #endif 28 | } 29 | 30 | } // end of namespace etl 31 | 32 | #define ETL_PROLOGUE etl::exit(); 33 | -------------------------------------------------------------------------------- /include/etl/bias_add_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration for the bias_add implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different implementations of 19 | * bias_add 20 | */ 21 | enum class bias_add_impl { 22 | STD, ///< Standard implementation 23 | VEC, ///< VEC implementation 24 | EGBLAS, ///< ETL-GPU-BLAS (GPU) implementation 25 | CUDNN ///< CUDNN (GPU) implementation 26 | }; 27 | 28 | } //end of namespace etl 29 | -------------------------------------------------------------------------------- /include/etl/impl/inv.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | #include "etl/impl/std/inv.hpp" 11 | 12 | namespace etl::detail { 13 | 14 | /*! 
15 | * \brief Functor for Inverse 16 | */ 17 | struct inv_impl { 18 | /*! 19 | * \brief Apply the functor 20 | * \param a The input sub expression 21 | * \param c The output sub expression 22 | */ 23 | template 24 | static void apply(A&& a, C&& c) { 25 | etl::impl::standard::inv(a, c); 26 | } 27 | }; 28 | 29 | } //end of namespace etl::detail 30 | -------------------------------------------------------------------------------- /include/etl/op/generators.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains generators 11 | */ 12 | 13 | #pragma once 14 | 15 | #include "etl/op/generators/normal.hpp" 16 | #include "etl/op/generators/truncated_normal.hpp" 17 | #include "etl/op/generators/uniform.hpp" 18 | #include "etl/op/generators/sequence.hpp" 19 | #include "etl/op/generators/dropout_mask.hpp" 20 | #include "etl/op/generators/inverted_dropout_mask.hpp" 21 | #include "etl/op/generators/state_dropout_mask.hpp" 22 | #include "etl/op/generators/state_inverted_dropout_mask.hpp" 23 | -------------------------------------------------------------------------------- /include/etl/impl_enums.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 
9 | * \file 10 | * \brief Utility header including all enumerations headers 11 | */ 12 | 13 | #pragma once 14 | 15 | #include "sum_impl.hpp" 16 | #include "cce_impl.hpp" 17 | #include "bce_impl.hpp" 18 | #include "mse_impl.hpp" 19 | #include "transpose_impl.hpp" 20 | #include "dot_impl.hpp" 21 | #include "conv_impl.hpp" 22 | #include "gemm_impl.hpp" 23 | #include "outer_impl.hpp" 24 | #include "bias_add_impl.hpp" 25 | #include "fft_impl.hpp" 26 | #include "pool_impl.hpp" 27 | #include "batch_softmax_impl.hpp" 28 | -------------------------------------------------------------------------------- /workbench/src/test.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "etl/etl.hpp" 9 | 10 | etl::fast_vector a = {-1.0, 2.0, 5.0, 2.0, 5.0, 1.2, 2.5, 1.2, -3.0, 3.5, 1.0}; 11 | etl::fast_vector b = {-1.0, 2.0, 5.0, 1.2, 2.5, 3.0, 4.0, 1.2, -3.0, 3.5, 1.0}; 12 | etl::fast_vector c = {1.2, -3.0, 3.5, 1.2, -3.0, 3.5, 1.0}; 13 | 14 | /* 15 | * Simple source file to verify how the code is compiled. 16 | */ 17 | 18 | int main(){ 19 | etl::fast_vector d((1.5 * (a * b + c)) / a); 20 | 21 | return static_cast(sum(d)); 22 | } 23 | -------------------------------------------------------------------------------- /include/etl/order.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | /*! 13 | * \brief Storage order of a matrix 14 | */ 15 | enum class order { 16 | RowMajor, ///< Row-Major storage 17 | ColumnMajor ///< Column-Major storage 18 | }; 19 | 20 | /*! 21 | * \brief Reverse the given storage order. 22 | * \param o The order to reverse 23 | * \return the reversed equivalent storage order 24 | */ 25 | constexpr order reverse(order o) { 26 | return o == order::RowMajor ? order::ColumnMajor : order::RowMajor; 27 | } 28 | 29 | } //end of namespace etl 30 | -------------------------------------------------------------------------------- /include/etl/gemm_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration of the different matrix-matrix multiplication implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 
18 | * \brief Enumeration describing the different matrix-matrix 19 | * multiplication implementations 20 | */ 21 | enum class gemm_impl { 22 | STD, ///< Standard implementation 23 | VEC, ///< Vectorized BLAS implementation 24 | BLAS, ///< BLAS implementation 25 | CUBLAS ///< CUBLAS (GPU) implementation 26 | }; 27 | 28 | } //end of namespace etl 29 | -------------------------------------------------------------------------------- /include/etl/adapters/diagonal_exception.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains diagonal matrix exception implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 20 | * \brief Exception that is thrown when an operation is made to 21 | * a diagonal matrix that would render it non-diagonal. 22 | */ 23 | struct diagonal_exception : std::exception { 24 | /*! 25 | * \brief Returns a description of the exception 26 | */ 27 | const char* what() const noexcept override { 28 | return "Invalid assignment to a diagonal matrix"; 29 | } 30 | }; 31 | 32 | } //end of namespace etl 33 | -------------------------------------------------------------------------------- /include/etl/impl/det.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Selector for the determinant implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | //Include the implementations 16 | #include "etl/impl/std/det.hpp" 17 | 18 | namespace etl::detail { 19 | 20 | /*! 21 | * \brief Functor for determinant 22 | */ 23 | struct det_impl { 24 | /*! 25 | * \brief Apply the functor to A 26 | * \param A The input matrix 27 | * \return the determinant of the matrix 28 | */ 29 | template 30 | static value_t apply(const AT& A) { 31 | return etl::impl::standard::det(A); 32 | } 33 | }; 34 | 35 | } //end of namespace etl::detail 36 | -------------------------------------------------------------------------------- /include/etl/impl/norm.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file norm.hpp 10 | * \brief Selector for the euclidean norm operation 11 | */ 12 | 13 | #pragma once 14 | 15 | //Include the implementations 16 | #include "etl/impl/std/norm.hpp" 17 | 18 | namespace etl::detail { 19 | 20 | /*! 21 | * \brief Functor for euclidean norm 22 | */ 23 | struct norm_impl { 24 | /*! 
25 | * \brief Apply the functor to a 26 | * \param a the expression 27 | * \return the euclidean norm of a 28 | */ 29 | template 30 | static value_t apply(const A& a) { 31 | return etl::impl::standard::norm(a); 32 | } 33 | }; 34 | 35 | } //end of namespace etl::detail 36 | -------------------------------------------------------------------------------- /include/etl/adapters/lower_exception.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains lower triangular matrix exception implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 20 | * \brief Exception that is thrown when an operation is made to 21 | * a lower triangular matrix that would render it non-lower 22 | * triangular. 23 | */ 24 | struct lower_exception : std::exception { 25 | /*! 26 | * \brief Returns a description of the exception 27 | */ 28 | const char* what() const noexcept override { 29 | return "Invalid assignment to a lower triangular matrix"; 30 | } 31 | }; 32 | 33 | } //end of namespace etl 34 | -------------------------------------------------------------------------------- /include/etl/adapters/upper_exception.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains upper triangular matrix exception implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 20 | * \brief Exception that is thrown when an operation is made to 21 | * an upper triangular matrix that would render it non-upper 22 | * triangular. 23 | */ 24 | struct upper_exception : std::exception { 25 | /*! 26 | * \brief Returns a description of the exception 27 | */ 28 | const char* what() const noexcept override { 29 | return "Invalid assignment to an upper triangular matrix"; 30 | } 31 | }; 32 | 33 | } //end of namespace etl 34 | -------------------------------------------------------------------------------- /include/etl/util/complex_cast.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | inline cuComplex complex_cast(const std::complex& alpha) { 13 | return *reinterpret_cast(&alpha); 14 | } 15 | 16 | inline cuComplex complex_cast(const etl::complex& alpha) { 17 | return *reinterpret_cast(&alpha); 18 | } 19 | 20 | inline cuDoubleComplex complex_cast(const std::complex& alpha) { 21 | return *reinterpret_cast(&alpha); 22 | } 23 | 24 | inline cuDoubleComplex complex_cast(const etl::complex& alpha) { 25 | return *reinterpret_cast(&alpha); 26 | } 27 | 28 | } //end of namespace etl 29 | -------------------------------------------------------------------------------- /scripts/debug_compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | make clean 4 | 5 | max_memory="0.0" 6 | total="0.0" 7 | 8 | start=$(date "+%s.%N") 9 | 10 | for f in test/src/*.cpp 11 | do 12 | results=$(/usr/bin/time -v make debug/$f.o 2>&1) 13 | rss=$(echo "$results" | grep "Maximum resident set size" | rev | cut -d" " -f1 | rev) 14 | user_time=$(echo "$results" | grep "User time" | rev | cut -d" " -f1 | rev) 15 | elapsed=$(echo "$results" | grep "Elapsed" | rev | cut -d" " -f1 | rev) 16 | lines=$(cat $f | wc -l) 17 | memory=$(echo "scale=2; $rss/1024/4" | bc -l) 18 | relative=$(echo "1000.0 * ($user_time/$lines)" | bc -l) 19 | relative_fixed=$(echo "scale=2; $relative/1.0" | bc -l) 20 | echo "$f => $elapsed => ${memory}MB => ${relative_fixed} ms/l" 21 | 22 | if [ $(echo "$max_memory < $memory" | bc) -eq 1 ] 23 | then 24 | max_memory=$memory 25 | fi 26 | done 27 | 28 | end=$(date "+%s.%N") 29 | 30 | runtime=$(echo "scale=3; ($end - $start) / 1.0" | bc -l) 31 | 32 | echo "Max memory: $max_memory" 33 | echo "Total time: $runtime" 34 | 
-------------------------------------------------------------------------------- /include/etl/adapters/uni_lower_exception.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains uni lower triangular matrix exception implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 20 | * \brief Exception that is thrown when an operation is made to 21 | * a uni lower triangular matrix that would render it non-uni lower 22 | * triangular. 23 | */ 24 | struct uni_lower_exception : std::exception { 25 | /*! 26 | * \brief Returns a description of the exception 27 | */ 28 | const char* what() const noexcept override { 29 | return "Invalid assignment to a uni lower triangular matrix"; 30 | } 31 | }; 32 | 33 | } //end of namespace etl 34 | -------------------------------------------------------------------------------- /include/etl/adapters/uni_upper_exception.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains uni upper triangular matrix exception implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 
20 | * \brief Exception that is thrown when an operation is made to 21 | * a uni upper triangular matrix that would render it non-uni upper 22 | * triangular. 23 | */ 24 | struct uni_upper_exception : std::exception { 25 | /*! 26 | * \brief Returns a description of the exception 27 | */ 28 | const char* what() const noexcept override { 29 | return "Invalid assignment to a uni upper triangular matrix"; 30 | } 31 | }; 32 | 33 | } //end of namespace etl 34 | -------------------------------------------------------------------------------- /test/src/cross.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test_light.hpp" 9 | #include "dot_test.hpp" 10 | 11 | TEMPLATE_TEST_CASE_2("cross/1", "[cross]", T, float, double) { 12 | etl::fast_vector a = {1.0, 2.0, 3.0}; 13 | etl::fast_vector b = {4.0, 5.0, 6.0}; 14 | 15 | auto c = etl::cross(a, b); 16 | 17 | REQUIRE_EQUALS(c[0], -3.0); 18 | REQUIRE_EQUALS(c[1], 6.0); 19 | REQUIRE_EQUALS(c[2], -3.0); 20 | } 21 | 22 | TEMPLATE_TEST_CASE_2("cross/2", "[cross]", T, float, double) { 23 | etl::dyn_vector a{1.0, 2.0, 3.0}; 24 | etl::dyn_vector b{4.0, 5.0, 6.0}; 25 | 26 | auto c = etl::cross(a, b); 27 | 28 | REQUIRE_EQUALS(c[0], -3.0); 29 | REQUIRE_EQUALS(c[1], 6.0); 30 | REQUIRE_EQUALS(c[2], -3.0); 31 | } 32 | -------------------------------------------------------------------------------- /scripts/release_debug_compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | make clean 4 | 5 | max_memory="0.0" 6 | total="0.0" 7 | 8 | start=$(date "+%s.%N") 9 | 10 | for f in test/src/*.cpp 
11 | do 12 | results=$(/usr/bin/time -v make release_debug/$f.o 2>&1) 13 | rss=$(echo "$results" | grep "Maximum resident set size" | rev | cut -d" " -f1 | rev) 14 | user_time=$(echo "$results" | grep "User time" | rev | cut -d" " -f1 | rev) 15 | elapsed=$(echo "$results" | grep "Elapsed" | rev | cut -d" " -f1 | rev) 16 | lines=$(cat $f | wc -l) 17 | memory=$(echo "scale=2; $rss/1024/4" | bc -l) 18 | relative=$(echo "1000.0 * ($user_time/$lines)" | bc -l) 19 | relative_fixed=$(echo "scale=2; $relative/1.0" | bc -l) 20 | echo "$f => $elapsed => ${memory}MB => ${relative_fixed} ms/l" 21 | 22 | if [ $(echo "$max_memory < $memory" | bc) -eq 1 ] 23 | then 24 | max_memory=$memory 25 | fi 26 | done 27 | 28 | end=$(date "+%s.%N") 29 | 30 | runtime=$(echo "scale=3; ($end - $start) / 1.0" | bc -l) 31 | 32 | echo "Max memory: $max_memory" 33 | echo "Total time: $runtime" 34 | -------------------------------------------------------------------------------- /test/src/tmp.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "compat.hpp" 9 | 10 | #include "cpp_utils/tmp.hpp" 11 | #include "cpp_utils/assert.hpp" 12 | #include "etl/tmp.hpp" 13 | 14 | ETL_TEST_CASE("tmp/sequence_equal/1", "[tmp]") { 15 | REQUIRE_DIRECT((etl::sequence_equal, std::index_sequence<2>>::value)); 16 | REQUIRE_DIRECT((etl::sequence_equal, std::index_sequence<>>::value)); 17 | REQUIRE_DIRECT((etl::sequence_equal, std::index_sequence<1, 2>>::value)); 18 | REQUIRE_DIRECT(!(etl::sequence_equal, std::index_sequence<1, 2, 3>>::value)); 19 | REQUIRE_DIRECT(!(etl::sequence_equal, std::index_sequence<1, 2, 3>>::value)); 20 | } 21 | -------------------------------------------------------------------------------- /include/etl/adapters/strictly_lower_exception.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains strictly lower triangular matrix exception implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 20 | * \brief Exception that is thrown when an operation is made to 21 | * a strictly lower triangular matrix that would render it non-strictly lower 22 | * triangular. 23 | */ 24 | struct strictly_lower_exception : std::exception { 25 | /*! 
26 | * \brief Returns a description of the exception 27 | */ 28 | const char* what() const noexcept override { 29 | return "Invalid assignment to a strictly lower triangular matrix"; 30 | } 31 | }; 32 | 33 | } //end of namespace etl 34 | -------------------------------------------------------------------------------- /include/etl/adapters/strictly_upper_exception.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains strictly upper triangular matrix exception implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | namespace etl { 18 | 19 | /*! 20 | * \brief Exception that is thrown when an operation is made to 21 | * a strictly upper triangular matrix that would render it non-strictly-upper 22 | * triangular. 23 | */ 24 | struct strictly_upper_exception : std::exception { 25 | /*! 
26 | * \brief Returns a description of the exception 27 | */ 28 | const char* what() const noexcept override { 29 | return "Invalid assignment to a strictly upper triangular matrix"; 30 | } 31 | }; 32 | 33 | } //end of namespace etl 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014-2023 Baptiste Wicht 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /include/etl/builder/mul_expression_builder.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file mul_expression_builder.hpp 10 | * \brief Contains all the operators and functions to build multiplication expressions. 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Multiply two matrices together lazily (expression templates) 19 | * \param a The left hand side matrix 20 | * \param b The right hand side matrix 21 | * \return An expression representing the matrix-matrix multiplication of a and b 22 | */ 23 | template 24 | auto lazy_mul(A&& a, B&& b) -> detail::stable_transform_binary_helper { 25 | return detail::stable_transform_binary_helper{mm_mul_transformer, detail::build_type>(a, b)}; 26 | } 27 | 28 | } //end of namespace etl 29 | -------------------------------------------------------------------------------- /test/src/special_cases.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test_light.hpp" 9 | 10 | TEMPLATE_TEST_CASE_2("deep_assign/vec", "deep_assign", Z, float, double) { 11 | etl::fast_vector, 2> a; 12 | 13 | a = 0.0; 14 | 15 | for (auto& v : a) { 16 | for (auto& v2 : v) { 17 | REQUIRE_EQUALS(v2, 0.0); 18 | } 19 | } 20 | } 21 | 22 | TEMPLATE_TEST_CASE_2("deep_assign/mat", "deep_assign", Z, float, double) { 23 | etl::fast_matrix, 2, 3> a; 24 | 25 | a = 0.0; 26 | 27 | for (auto& v : a) { 28 | for (auto& v2 : v) { 29 | REQUIRE_EQUALS(v2, 0.0); 30 | } 31 | } 32 | } 33 | 34 | TEMPLATE_TEST_CASE_2("deep_assign/mat", "deep_assign", Z, float, double) { 35 | etl::fast_matrix, 2, 3> a; 36 | 37 | a = 0.0; 38 | 39 | for (auto& v : a) { 40 | for (auto& v2 : v) { 41 | REQUIRE_EQUALS(v2, 0.0); 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /include/etl/util/variadic.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains the variadic size function overloads. 11 | * 12 | * These functions help generic code compute the size of a matrix 13 | * from its dimensions, passed as a variadic list of 14 | * values. 15 | */ 16 | 17 | #pragma once 18 | 19 | namespace etl::util { 20 | 21 | /*! 22 | * \brief Returns the size of a matrix given its dimensions 23 | */ 24 | inline size_t size(size_t first) { 25 | return first; 26 | } 27 | 28 | /*! 
29 | * \brief Returns the size of a matrix given its dimensions 30 | */ 31 | template 32 | inline size_t size(size_t first, T... args) { 33 | return first * size(args...); 34 | } 35 | 36 | /*! 37 | * \brief Returns the size of a matrix given its dimensions 38 | */ 39 | template 40 | inline size_t size(const std::index_sequence& /*i*/, const T&... args) { 41 | return size((cpp::nth_value(args...))...); 42 | } 43 | 44 | } //end of namespace etl::util 45 | -------------------------------------------------------------------------------- /test/include/compat.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #define UNIQUE_NAME_LINE2( name, line ) name##line 9 | #define UNIQUE_NAME_LINE( name, line ) UNIQUE_NAME_LINE2( name, line ) 10 | #define UNIQUE_NAME( name ) UNIQUE_NAME_LINE( name, __LINE__ ) 11 | 12 | #define DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE 13 | #define DOCTEST_CONFIG_SUPER_FAST_ASSERTS 14 | #include 15 | #include "doctest/doctest.h" 16 | 17 | constexpr auto base_eps = std::numeric_limits::epsilon() * 100; 18 | constexpr auto base_eps_etl = 0.0000001f; 19 | constexpr auto base_eps_etl_large = 0.01f; 20 | 21 | #define ETL_TEST_CASE(name, description) TEST_CASE(name) 22 | #define ETL_SECTION(name) SUBCASE(name) 23 | 24 | #include "template_test.hpp" 25 | 26 | #define REQUIRE_DIRECT(value) FAST_CHECK_UNARY(value) 27 | #define REQUIRE_EQUALS(lhs, rhs) FAST_CHECK_EQ(lhs, rhs) 28 | #define REQUIRE_EQUALS_APPROX(lhs, rhs) FAST_CHECK_EQ(lhs, doctest::Approx(rhs)) 29 | #define REQUIRE_EQUALS_APPROX_E(lhs, rhs, eps) FAST_CHECK_EQ(lhs, doctest::Approx(rhs).epsilon(eps)) 30 | 
-------------------------------------------------------------------------------- /include/etl/impl/conv.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Selector for the convolution implementations. 11 | * 12 | * The functions are responsible for selecting the most efficient 13 | * implementation for each case, based on what is available. The selection of 14 | * parallel versus serial is also done at this level. The implementation 15 | * functions should never be used directly, only functions of this header can 16 | * be used directly. 17 | * 18 | * Ideas for improvements: 19 | * * Parallel dispatching for SSE/AVX implementation is not perfect, it should be done inside the micro kernel main loop 20 | */ 21 | 22 | #pragma once 23 | 24 | //Include the implementations 25 | #include "etl/impl/std/conv.hpp" 26 | #include "etl/impl/vec/conv.hpp" 27 | #include "etl/impl/cudnn/conv.hpp" 28 | #include "etl/impl/egblas/conv_1d.hpp" 29 | 30 | #include "etl/impl/conv_select.hpp" // The selection functions 31 | 32 | // All the descriptors 33 | #include "etl/impl/conv_2d.hpp" 34 | #include "etl/impl/conv_4d.hpp" 35 | #include "etl/impl/conv_multi.hpp" 36 | -------------------------------------------------------------------------------- /test/src/optimize_2.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test_light.hpp" 9 | 10 | #ifndef ETL_CUDA 11 | 12 | ETL_TEST_CASE("optimize/6", "[fast][optimizer]") { 13 | etl::fast_vector a({1.0, -2.0, 3.0}); 14 | etl::fast_vector b; 15 | 16 | b = opt(0.0 * a + 0.0 * a); 17 | 18 | REQUIRE_EQUALS(b[0], 0.0); 19 | } 20 | 21 | ETL_TEST_CASE("optimize/7", "[fast][optimizer]") { 22 | etl::fast_vector a({1.0, -2.0, 3.0}); 23 | etl::fast_vector b; 24 | 25 | b = opt(0.0 * a + 0.0 * a + 1.0 * a); 26 | 27 | REQUIRE_EQUALS(b[0], 1.0); 28 | } 29 | 30 | ETL_TEST_CASE("optimize/8", "[fast][optimizer]") { 31 | etl::fast_vector a({1.0, -2.0, 3.0}); 32 | etl::fast_vector b; 33 | 34 | b = opt(0.0 * a + 1.0 * a + 1.0 * (a - 0)); 35 | 36 | REQUIRE_EQUALS(b[0], 2.0); 37 | } 38 | 39 | ETL_TEST_CASE("optimize/10", "[fast][optimizer]") { 40 | etl::fast_vector a({1.0, -2.0, 3.0}); 41 | etl::fast_vector b; 42 | 43 | b = opt(+((-(a * 1.0)) * 1.0)); 44 | 45 | REQUIRE_EQUALS(b[0], -1.0); 46 | } 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /include/etl/traits_base.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | /*! 13 | * \brief Traits to get information about ETL types 14 | * 15 | * For non-ETL types, is_etl is false and in that case, no other fields should be used on the traits. 
16 | * 17 | * \tparam T the type to introspect 18 | */ 19 | template 20 | struct etl_traits { 21 | static constexpr bool is_etl = false; ///< Indicates if T is an ETL type 22 | static constexpr bool is_transformer = false; ///< Indicates if T is a transformer 23 | static constexpr bool is_view = false; ///< Indicates if T is a view 24 | static constexpr bool is_magic_view = false; ///< Indicates if T is a magic view 25 | static constexpr bool is_fast = false; ///< Indicates if T is a fast structure 26 | static constexpr bool is_generator = false; ///< Indicates if T is a generator expression 27 | 28 | /*! 29 | * \brief Return the number of dimensions of the expression 30 | */ 31 | static constexpr size_t dimensions() { 32 | return 0; 33 | } 34 | }; 35 | 36 | } //end of namespace etl 37 | -------------------------------------------------------------------------------- /include/etl/inline.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Inlining macros. 
11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef __clang__ 16 | #define ETL_INLINE_ATTR_VEC __attribute__((__always_inline__, __nodebug__)) 17 | 18 | #define ETL_INLINE(RRRR) inline RRRR __attribute__((__always_inline__, __nodebug__)) 19 | #define ETL_STRONG_INLINE(RRRR) inline RRRR __attribute__((__always_inline__)) 20 | #define ETL_STATIC_INLINE(RRRR) static ETL_INLINE(RRRR) 21 | #define ETL_TMP_INLINE(RRRR) static inline RRRR __attribute__((__always_inline__, __nodebug__)) 22 | #define ETL_OUT_INLINE(RRRR) inline RRRR __attribute__((__always_inline__, __nodebug__)) 23 | #else 24 | #define ETL_INLINE_ATTR_VEC __attribute__((__always_inline__, __artificial__)) 25 | 26 | #define ETL_INLINE(RRRR) inline RRRR __attribute__((__always_inline__, __gnu_inline__, __artificial__)) 27 | #define ETL_STRONG_INLINE(RRRR) inline RRRR __attribute__((__always_inline__, __gnu_inline__)) 28 | #define ETL_STATIC_INLINE(RRRR) static ETL_INLINE(RRRR) 29 | #define ETL_TMP_INLINE(RRRR) static inline RRRR __attribute__((__always_inline__, __artificial__)) 30 | #define ETL_OUT_INLINE(RRRR) inline RRRR __attribute__((__always_inline__, __artificial__)) 31 | #endif 32 | -------------------------------------------------------------------------------- /include/etl/duration.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | using seconds = std::chrono::seconds; ///< The seconds resolution 13 | using milliseconds = std::chrono::milliseconds; ///< The milliseconds resolution 14 | using microseconds = std::chrono::microseconds; ///< The microseconds resolution 15 | using nanoseconds = std::chrono::nanoseconds; ///< The nanoseconds resolution 16 | 17 | using timer_clock = std::chrono::steady_clock; ///< The chrono clock used by ETL 18 | using clock_resolution = nanoseconds; ///< The clock resolution used by ETL 19 | 20 | /*! 21 | * \brief Return the string representation of the given resolution 22 | * \return the string representation of the given resolution ("s", "ms", "us", "ns", or "?" for any other type) 23 | */ 24 | template 25 | std::string resolution_to_string() { 26 | if constexpr (std::same_as) { 27 | return "s"; 28 | } else if constexpr (std::same_as) { 29 | return "ms"; 30 | } else if constexpr (std::same_as) { 31 | return "us"; 32 | } else if constexpr (std::same_as) { 33 | return "ns"; 34 | } else { 35 | return "?"; 36 | } 37 | } 38 | 39 | } //end of namespace etl 40 | -------------------------------------------------------------------------------- /include/etl/serializer.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | /*!
13 | * \brief A serializer for ETL expressions 14 | */ 15 | template 16 | struct serializer { 17 | using stream_t = Stream; ///< The type of stream to use 18 | using char_t = typename stream_t::char_type; ///< The char type of the stream 19 | 20 | stream_t stream; ///< The stream 21 | 22 | /*! 23 | * \brief Construct the serializer by forwarding the arguments 24 | * to the stream 25 | * \param args The arguments to forward to the stream constructor 26 | */ 27 | template 28 | explicit serializer(Args&&... args) : stream(std::forward(args)...) {} 29 | 30 | /*! 31 | * \brief Outputs the given value to the stream 32 | * \param value The value to write to the stream 33 | * \return the serializer 34 | */ 35 | template 36 | serializer& operator<<(const T& value) { 37 | if constexpr (std::is_arithmetic_v) { 38 | stream.write(reinterpret_cast(&value), sizeof(T)); 39 | } else { 40 | serialize(*this, value); 41 | } 42 | 43 | return *this; 44 | } 45 | }; 46 | 47 | } //end of namespace etl 48 | -------------------------------------------------------------------------------- /test/src/memory_slice.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test_light.hpp" 9 | 10 | TEMPLATE_TEST_CASE_2("memory_slice/1", "[slice]", Z, float, double) { 11 | etl::fast_matrix a = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; 12 | 13 | auto s1 = etl::memory_slice(a, 0, 3); 14 | 15 | REQUIRE_EQUALS(etl::size(s1), 3UL); 16 | 17 | REQUIRE_EQUALS(etl::dim<0>(s1), 3UL); 18 | REQUIRE_EQUALS(etl::dimensions(s1), 1UL); 19 | 20 | REQUIRE_EQUALS(s1[0], 1); 21 | REQUIRE_EQUALS(s1[1], 2); 22 | REQUIRE_EQUALS(s1[2], 3); 23 | 24 | auto s2 = etl::memory_slice(a, 1, 3); 25 | 26 | REQUIRE_EQUALS(etl::size(s2), 2UL); 27 | 28 | REQUIRE_EQUALS(etl::dim<0>(s2), 2UL); 29 | REQUIRE_EQUALS(etl::dimensions(s2), 1UL); 30 | 31 | REQUIRE_EQUALS(s2[0], 2); 32 | REQUIRE_EQUALS(s2[1], 3); 33 | REQUIRE_EQUALS(s2[2], 4); 34 | 35 | auto s3 = etl::memory_slice(a, 0, 6); 36 | 37 | REQUIRE_EQUALS(etl::size(s3), 6UL); 38 | 39 | REQUIRE_EQUALS(etl::dim<0>(s3), 6UL); 40 | REQUIRE_EQUALS(etl::dimensions(s3), 1UL); 41 | 42 | REQUIRE_EQUALS(s3[0], 1); 43 | REQUIRE_EQUALS(s3[1], 2); 44 | REQUIRE_EQUALS(s3[2], 3); 45 | REQUIRE_EQUALS(s3[3], 4); 46 | REQUIRE_EQUALS(s3[4], 5); 47 | REQUIRE_EQUALS(s3[5], 6); 48 | } 49 | -------------------------------------------------------------------------------- /test/src/optimize_1.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test_light.hpp" 9 | 10 | #ifndef ETL_CUDA 11 | 12 | ETL_TEST_CASE("optimize/1", "[fast][optimizer]") { 13 | etl::fast_vector a({1.0, -2.0, 3.0}); 14 | etl::fast_vector b; 15 | 16 | b = opt(a + a); 17 | 18 | REQUIRE_EQUALS(b[0], 2.0); 19 | } 20 | 21 | ETL_TEST_CASE("optimize/2", "[fast][optimizer]") { 22 | etl::fast_vector a({1.0, -2.0, 3.0}); 23 | etl::fast_vector b; 24 | 25 | b = opt(a * 1.0); 26 | 27 | REQUIRE_EQUALS(b[0], 1.0); 28 | } 29 | 30 | ETL_TEST_CASE("optimize/3", "[fast][optimizer]") { 31 | etl::fast_vector a({1.0, -2.0, 3.0}); 32 | etl::fast_vector b; 33 | 34 | b = opt(a + a * 1.0); 35 | 36 | REQUIRE_EQUALS(b[0], 2.0); 37 | } 38 | 39 | ETL_TEST_CASE("optimize/4", "[fast][optimizer]") { 40 | etl::fast_vector a({1.0, -2.0, 3.0}); 41 | etl::fast_vector b; 42 | 43 | b = opt(a + 1.0 * a); 44 | 45 | REQUIRE_EQUALS(b[0], 2.0); 46 | } 47 | 48 | ETL_TEST_CASE("optimize/5", "[fast][optimizer]") { 49 | etl::fast_vector a({1.0, -2.0, 3.0}); 50 | etl::fast_vector b; 51 | 52 | b = opt(a * 1.0 + 1.0 * a); 53 | 54 | REQUIRE_EQUALS(b[0], 2.0); 55 | } 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /include/etl/deserializer.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | /*! 
13 | * \brief A deserializer for ETL expressions 14 | */ 15 | template 16 | struct deserializer { 17 | using stream_t = Stream; ///< The type of stream to use 18 | using char_t = typename stream_t::char_type; ///< The char type of the stream 19 | 20 | stream_t stream; ///< The stream 21 | 22 | /*! 23 | * \brief Construct the deserializer by forwarding the arguments 24 | * to the stream 25 | * \param args The arguments to forward to the stream constructor 26 | */ 27 | template 28 | explicit deserializer(Args&&... args) : stream(std::forward(args)...) {} 29 | 30 | /*! 31 | * \brief Reads a value of the given type from the stream 32 | * \param value Reference to the value where to write 33 | * \return the deserializer 34 | */ 35 | template 36 | deserializer& operator>>(T& value) { 37 | if constexpr (std::is_arithmetic_v) { 38 | stream.read(reinterpret_cast(&value), sizeof(T)); 39 | } else { 40 | deserialize(*this, value); 41 | } 42 | 43 | return *this; 44 | } 45 | }; 46 | 47 | } //end of namespace etl 48 | -------------------------------------------------------------------------------- /include/etl/impl/decomposition.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Selector for the decompositions implementation 11 | */ 12 | 13 | #pragma once 14 | 15 | //Include the implementations 16 | #include "etl/impl/std/decomposition.hpp" 17 | 18 | namespace etl::detail { 19 | 20 | /*! 21 | * \brief Functor for LU decomposition 22 | */ 23 | struct lu_impl { 24 | /*!
25 | * \brief Apply the functor to A, L, U, P 26 | * \param A The input matrix 27 | * \param L The L decomposition (output) 28 | * \param U The U decomposition (output) 29 | * \param P The P permutation matrix (output) 30 | */ 31 | template 32 | static void apply(const AT& A, LT& L, UT& U, PT& P) { 33 | etl::impl::standard::lu(A, L, U, P); 34 | } 35 | }; 36 | 37 | /*! 38 | * \brief Functor for QR decomposition 39 | */ 40 | struct qr_impl { 41 | /*! 42 | * \brief Apply the functor to A,Q,R 43 | * \param A The input matrix 44 | * \param Q The Q decomposition (output) 45 | * \param R The R decomposition (output) 46 | */ 47 | template 48 | static void apply(AT& A, QT& Q, RT& R) { 49 | etl::impl::standard::qr(A, Q, R); 50 | } 51 | }; 52 | 53 | } //end of namespace etl::detail 54 | -------------------------------------------------------------------------------- /include/etl/impl/std/outer.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard implementation of the outer product 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 18 | * \brief Compute the outer product of a and b and store the result in c 19 | * \param a The a expression 20 | * \param b The b expression 21 | * \param c The c expression 22 | */ 23 | template 24 | void outer(const A& a, const B& b, C&& c) { 25 | for (size_t i = 0; i < etl::dim<0>(c); ++i) { 26 | for (size_t j = 0; j < etl::dim<1>(c); ++j) { 27 | c(i, j) = a(i) * b(j); 28 | } 29 | } 30 | } 31 | 32 | /*! 
33 | * \brief Compute the batch outer product of a and b and store the result in c 34 | * \param lhs The a expression 35 | * \param rhs The b expression 36 | * \param c The c expression 37 | */ 38 | template 39 | void batch_outer(const A& lhs, const B& rhs, C&& c) { 40 | c = 0; 41 | 42 | for (size_t b = 0; b < etl::dim<0>(lhs); ++b) { 43 | for (size_t i = 0; i < etl::dim<0>(c); ++i) { 44 | for (size_t j = 0; j < etl::dim<1>(c); ++j) { 45 | c(i, j) += lhs(b, i) * rhs(b, j); 46 | } 47 | } 48 | } 49 | } 50 | 51 | } //end of namespace etl::impl::standard 52 | -------------------------------------------------------------------------------- /include/etl/stop.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | /*! 13 | * \brief Force the evaluation of the given expression 14 | * \param value The ETL expression 15 | * \return A value class with the values of the given expression 16 | */ 17 | template 18 | auto s(T&& value) { 19 | // Sizes will be directly propagated 20 | dyn_matrix, etl_traits::dimensions()> mat; 21 | mat = value; 22 | return mat; 23 | } 24 | 25 | /*! 26 | * \brief TMP struct to build fast matrix type from a fast expression type 27 | */ 28 | template 29 | struct build_matrix_type; 30 | 31 | /*! 32 | * \copydoc build_matrix_type 33 | */ 34 | template 35 | struct build_matrix_type> { 36 | using type = fast_dyn_matrix, etl_traits::template dim()...>; ///< The fast matrix type 37 | }; 38 | 39 | /*! 
40 | * \brief Force the evaluation of the given expression 41 | * \param value The ETL expression 42 | * \return A value class with the values of the given expression 43 | */ 44 | template 45 | auto s(T&& value) { 46 | typename build_matrix_type::dimensions()>>::type mat; 47 | mat = value; 48 | return mat; 49 | } 50 | 51 | } // end of namespace etl 52 | -------------------------------------------------------------------------------- /test/src/elt_logical.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | 10 | #include 11 | 12 | ETL_TEST_CASE("elt_logical/and/1", "[compare]") { 13 | etl::fast_matrix a{false, false, true, true}; 14 | etl::fast_matrix b{false, true, true, false}; 15 | etl::fast_matrix c; 16 | 17 | c = logical_and(a, b); 18 | 19 | REQUIRE_EQUALS(c(0, 0), false); 20 | REQUIRE_EQUALS(c(0, 1), false); 21 | REQUIRE_EQUALS(c(1, 0), true); 22 | REQUIRE_EQUALS(c(1, 1), false); 23 | } 24 | 25 | ETL_TEST_CASE("elt_logical/or/1", "[compare]") { 26 | etl::fast_matrix a{false, false, true, true}; 27 | etl::fast_matrix b{false, true, true, false}; 28 | etl::fast_matrix c; 29 | 30 | c = logical_or(a, b); 31 | 32 | REQUIRE_EQUALS(c(0, 0), false); 33 | REQUIRE_EQUALS(c(0, 1), true); 34 | REQUIRE_EQUALS(c(1, 0), true); 35 | REQUIRE_EQUALS(c(1, 1), true); 36 | } 37 | 38 | ETL_TEST_CASE("elt_logical/xor/1", "[compare]") { 39 | etl::fast_dyn_matrix a{false, false, true, true}; 40 | etl::fast_dyn_matrix b{false, true, true, false}; 41 | etl::fast_dyn_matrix c; 42 | 43 | c = logical_xor(a, b); 44 | 45 | REQUIRE_EQUALS(c(0, 0), false); 46 | REQUIRE_EQUALS(c(0, 1), 
true); 47 | REQUIRE_EQUALS(c(1, 0), false); 48 | REQUIRE_EQUALS(c(1, 1), true); 49 | } 50 | -------------------------------------------------------------------------------- /include/etl/builder/conv_expression_builder.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file conv_expression_builder.hpp 10 | * \brief Contains all the operators and functions to build convolution expressions. 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Construct a matrix to compute a convolution by matrix-matrix multiplication 19 | * \param a The vector to transform (the input of the convolution) 20 | * \param h The size of kernel 21 | * \return a matrix expression for convolution 22 | */ 23 | template 24 | auto convmtx(A&& a, size_t h) -> detail::stable_transform_helper { 25 | return detail::stable_transform_helper{dyn_convmtx_transformer>(a, h)}; 26 | } 27 | 28 | /*! 
29 | * \brief Construct a matrix to compute a 2D convolution by matrix-matrix multiplication 30 | * \param a The 2D matrix to transform (the input of the convolution) 31 | * \param k1 The first dimension of the kernel 32 | * \param k2 The second dimension of the kernel 33 | * \return a matrix expression for convolution 34 | */ 35 | template 36 | auto convmtx2(A&& a, size_t k1, size_t k2) -> detail::stable_transform_helper { 37 | return detail::stable_transform_helper{dyn_convmtx2_transformer>(a, k1, k2)}; 38 | } 39 | 40 | } //end of namespace etl 41 | -------------------------------------------------------------------------------- /include/etl/impl/blas/dot.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief BLAS implementation of the "dot" reduction 11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef ETL_BLAS_MODE 16 | #include "cblas.h" //For ddot/sdot 17 | #endif 18 | 19 | namespace etl::impl::blas { 20 | 21 | #ifdef ETL_BLAS_MODE 22 | 23 | /*! 
24 | * \brief Compute the dot product of a and b 25 | * \param a The lhs expression 26 | * \param b The rhs expression 27 | * \return the sum 28 | */ 29 | template 30 | value_t dot(const A& a, const B& b) { 31 | if constexpr (all_dma) { 32 | a.ensure_cpu_up_to_date(); 33 | b.ensure_cpu_up_to_date(); 34 | 35 | if constexpr (all_single_precision) { 36 | return cblas_sdot(etl::size(a), a.memory_start(), 1, b.memory_start(), 1); 37 | } else { 38 | return cblas_ddot(etl::size(a), a.memory_start(), 1, b.memory_start(), 1); 39 | } 40 | } else { 41 | cpp_unreachable("BLAS not enabled/available"); 42 | return 0.0; 43 | } 44 | } 45 | 46 | #else 47 | 48 | //COVERAGE_EXCLUDE_BEGIN 49 | 50 | /*! 51 | * \copydoc dot 52 | */ 53 | template 54 | value_t dot(const A& /*a*/, const B& /*b*/) { 55 | cpp_unreachable("BLAS not enabled/available"); 56 | return 0.0; 57 | } 58 | 59 | //COVERAGE_EXCLUDE_END 60 | 61 | #endif 62 | 63 | } //end of namespace etl::impl::blas 64 | -------------------------------------------------------------------------------- /include/etl/op/generators/sequence.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains sequence generators 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 
18 | * \brief Generator from a sequence 19 | */ 20 | template 21 | struct sequence_generator_op { 22 | using value_type = T; ///< The value type 23 | 24 | const value_type start; ///< The beginning of the sequence 25 | value_type current; ///< The current sequence element 26 | 27 | static constexpr bool gpu_computable = false; ///< Indicates if the operator is computable on GPU 28 | 29 | /*! 30 | * \brief Construct a new generator with the given sequence start 31 | * \param start The beginning of the sequence 32 | */ 33 | explicit sequence_generator_op(value_type start = 0) : start(start), current(start) {} 34 | 35 | /*! 36 | * \brief Generate a new value 37 | * \return the newly generated value 38 | */ 39 | value_type operator()() { 40 | return current++; 41 | } 42 | 43 | /*! 44 | * \brief Outputs the given generator to the given stream 45 | * \param os The output stream 46 | * \param s The generator 47 | * \return the output stream 48 | */ 49 | friend std::ostream& operator<<(std::ostream& os, const sequence_generator_op& s) { 50 | return os << "[" << s.start << ",...]"; 51 | } 52 | }; 53 | 54 | } //end of namespace etl 55 | -------------------------------------------------------------------------------- /include/etl/impl/std/bias_add.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard implementation of the bias_add computation 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 
18 | * \brief Compute the bias addition of a and b and store the result in c 19 | * \param lhs The a expression 20 | * \param rhs The b expression 21 | * \param c The c expression 22 | */ 23 | template 24 | void bias_add_4d(const A& lhs, const B& rhs, C&& c) { 25 | for (size_t i = 0; i < etl::dim<0>(lhs); ++i) { 26 | for (size_t j = 0; j < etl::dim<1>(lhs); ++j) { 27 | for (size_t k = 0; k < etl::dim<2>(lhs); ++k) { 28 | for (size_t l = 0; l < etl::dim<3>(lhs); ++l) { 29 | c(i, j, k, l) = lhs(i, j, k, l) + rhs(j); 30 | } 31 | } 32 | } 33 | } 34 | } 35 | 36 | /*! 37 | * \brief Compute the bias addition of a and b and store the result in c 38 | * \param lhs The a expression 39 | * \param rhs The b expression 40 | * \param c The c expression 41 | */ 42 | template 43 | void bias_add_2d(const A& lhs, const B& rhs, C&& c) { 44 | for (size_t i = 0; i < etl::dim<0>(lhs); ++i) { 45 | for (size_t j = 0; j < etl::dim<1>(lhs); ++j) { 46 | c(i, j) = lhs(i, j) + rhs(j); 47 | } 48 | } 49 | } 50 | 51 | } //end of namespace etl::impl::standard 52 | -------------------------------------------------------------------------------- /include/etl/impl/std/mse.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard implementation of the Mean Squared Error reduction 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 
18 | * \brief Returns the Mean Squared Error Loss 19 | * \param output The outputs 20 | * \param labels The labels 21 | * \return The MSE Loss of the output and labels 22 | */ 23 | template 24 | value_t mse_loss(const O& output, const L& labels, value_t scale) { 25 | return scale * sum((output - labels) >> (output - labels)); 26 | } 27 | 28 | /*! 29 | * \brief Returns the Mean Squared Error Error 30 | * \param output The outputs 31 | * \param labels The labels 32 | * \return The MSE Error of the output and labels 33 | */ 34 | template 35 | value_t mse_error(const O& output, const L& labels, value_t scale) { 36 | return scale * asum(labels - output); 37 | } 38 | 39 | /*! 40 | * \brief Returns the Mean Squared Error Loss and Error 41 | * \param output The outputs 42 | * \param labels The labels 43 | * \return The MSE Error of the output and labels 44 | */ 45 | template 46 | std::pair, value_t> mse(const O& output, const L& labels, value_t alpha, value_t beta) { 47 | return std::make_pair(mse_loss(output, labels, alpha), mse_error(output, labels, beta)); 48 | } 49 | 50 | } //end of namespace etl::impl::standard 51 | -------------------------------------------------------------------------------- /include/etl/impl/std/cce.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard implementation of the Categorical Cross Entropy reduction 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 
18 | * \brief Compute the Categorical Cross Entropy loss of the given outputs and labels 19 | * \param output The outputs 20 | * \return the sum of (log(output) >> labels), multiplied by scale 21 | */ 22 | template 23 | value_t cce_loss(const O& output, const L& labels, value_t scale) { 24 | return scale * etl::sum(log(output) >> labels); 25 | } 26 | 27 | /*! 28 | * \brief Compute the Categorical Cross Entropy error of the given outputs and labels 29 | * \param output The outputs 30 | * \return the sum of min(abs(argmax(labels) - argmax(output)), 1.0), multiplied by scale 31 | */ 32 | template 33 | value_t cce_error(const O& output, const L& labels, value_t scale) { 34 | return scale * sum(min(abs(argmax(labels) - argmax(output)), 1.0)); 35 | } 36 | 37 | /*! 38 | * \brief Returns the Categorical Cross Entropy Loss and Error 39 | * \param output The outputs 40 | * \param labels The labels 41 | * \return A pair holding the CCE Loss (scaled by alpha) and the CCE Error (scaled by beta) of the output and labels 42 | */ 43 | template 44 | std::pair, value_t> cce(const O& output, const L& labels, value_t alpha, value_t beta) { 45 | return std::make_pair(cce_loss(output, labels, alpha), cce_error(output, labels, beta)); 46 | } 47 | 48 | } //end of namespace etl::impl::standard 49 | -------------------------------------------------------------------------------- /benchmark/src/benchmark_trigo.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License.
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #define CPM_LIB 9 | #include "benchmark.hpp" 10 | 11 | //Bench trigonometric function 12 | CPM_BENCH() { 13 | CPM_TWO_PASS_NS( 14 | "r = cos(a) (s) [std][cos][s]", 15 | [](size_t d){ return std::make_tuple(svec(d), svec(d)); }, 16 | [](svec& a, svec& r){ r = cos(a); } 17 | ); 18 | 19 | CPM_TWO_PASS_NS( 20 | "r = sin(a) (s) [std][sin][s]", 21 | [](size_t d){ return std::make_tuple(svec(d), svec(d)); }, 22 | [](svec& a, svec& r){ r = sin(a); } 23 | ); 24 | 25 | CPM_TWO_PASS_NS( 26 | "r = tan(a) (s) [std][tan][s]", 27 | [](size_t d){ return std::make_tuple(svec(d), svec(d)); }, 28 | [](svec& a, svec& r){ r = tan(a); } 29 | ); 30 | } 31 | 32 | //Bench hyperbolic function 33 | CPM_BENCH() { 34 | CPM_TWO_PASS_NS( 35 | "r = cosh(a) (s) [std][cosh][s]", 36 | [](size_t d){ return std::make_tuple(svec(d), svec(d)); }, 37 | [](svec& a, svec& r){ r = cosh(a); } 38 | ); 39 | 40 | CPM_TWO_PASS_NS( 41 | "r = sinh(a) (s) [std][sinh][s]", 42 | [](size_t d){ return std::make_tuple(svec(d), svec(d)); }, 43 | [](svec& a, svec& r){ r = sinh(a); } 44 | ); 45 | 46 | CPM_TWO_PASS_NS( 47 | "r = tanh(a) (s) [std][tanh][s]", 48 | [](size_t d){ return std::make_tuple(svec(d), svec(d)); }, 49 | [](svec& a, svec& r){ r = tanh(a); } 50 | ); 51 | } 52 | -------------------------------------------------------------------------------- /include/etl/impl/std/bce.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 
9 | * \file 10 | * \brief Standard implementation of the Binary Cross Entropy reduction 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 18 | * \brief Returns the Binary Cross Entropy Loss 19 | * \param output The outputs 20 | * \param labels The labels 21 | * \return The BCE Loss of the output and labels 22 | */ 23 | template 24 | value_t bce_loss(const O& output, const L& labels, value_t scale) { 25 | return scale * sum((labels >> log(output)) + ((1.0 - labels) >> log(1.0 - output))); 26 | } 27 | 28 | /*! 29 | * \brief Returns the Binary Cross Entropy Error 30 | * \param output The outputs 31 | * \param labels The labels 32 | * \return The BCE Error of the output and labels 33 | */ 34 | template 35 | value_t bce_error(const O& output, const L& labels, value_t scale) { 36 | return scale * asum(labels - output); 37 | } 38 | 39 | /*! 40 | * \brief Returns the Binary Cross Entropy Loss and Error 41 | * \param output The outputs 42 | * \param labels The labels 43 | * \return A pair holding the BCE Loss (scaled by alpha) and the BCE Error (scaled by beta) of the output and labels 44 | */ 45 | template 46 | std::pair, value_t> bce(const O& output, const L& labels, value_t alpha, value_t beta) { 47 | return std::make_pair(bce_loss(output, labels, alpha), bce_error(output, labels, beta)); 48 | } 49 | 50 | } //end of namespace etl::impl::standard 51 | -------------------------------------------------------------------------------- /scripts/test_runner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | function etl_run { 5 | make clean 6 | time make $ETL_THREADS debug/bin/etl_test 7 | time ./debug/bin/etl_test 8 | } 9 | 10 | # Disable default options 11 | export ETL_NO_DEFAULT=true 12 | unset ETL_DEFAULTS 13 | unset ETL_MKL 14 | unset ETL_BLAS 15 | unset ETL_CUBLAS 16 | unset ETL_CUFFT 17 | unset ETL_CUDNN 18 | unset ETL_GPU 19 | unset ETL_EGBLAS 20 | 21 | # Use gcc 22 | export CXX=$ETL_GPP 23 | export LD=$ETL_GPP 24 | 25 | echo "Tests
are compiled using $CXX compiler:" 26 | $CXX --version 27 | 28 | echo "Test 1. GCC (debug default)" 29 | 30 | export ETL_DEFAULTS="-DETL_DEBUG_THRESHOLDS -DCPP_UTILS_ASSERT_EXCEPTION" 31 | 32 | etl_run 1 33 | 34 | echo "Test 2. GCC (debug vectorize avx)" 35 | 36 | export ETL_DEFAULTS="-DETL_DEBUG_THRESHOLDS -DETL_VECTORIZE_FULL -mavx" 37 | 38 | etl_run 2 39 | 40 | echo "Test 3. GCC (debug vectorize sse)" 41 | 42 | export ETL_DEFAULTS="-DETL_DEBUG_THRESHOLDS -DETL_VECTORIZE_FULL -msse3 -msse4" 43 | 44 | etl_run 3 45 | 46 | echo "Test 4. GCC (debug mkl)" 47 | 48 | export ETL_DEFAULTS="-DETL_DEBUG_THRESHOLDS" 49 | export ETL_MKL=true 50 | 51 | etl_run 4 52 | 53 | echo "Test 5. GCC (debug parallel)" 54 | 55 | unset ETL_MKL 56 | export ETL_DEFAULTS="-DETL_DEBUG_THRESHOLDS -DETL_PARALLEL" 57 | 58 | etl_run 5 59 | 60 | echo "Test 6. GCC (debug vectorize sse avx parallel)" 61 | 62 | unset ETL_MKL 63 | export ETL_DEFAULTS="-DETL_DEBUG_THRESHOLDS -DETL_PARALLEL -DETL_VECTORIZE_FULL -msse3 -msse4 -mavx" 64 | 65 | etl_run 6 66 | 67 | if [ "$ETL_NO_GPU" == "" ] 68 | then 69 | echo "Test 7. GCC (debug cublas cufft)" 70 | 71 | export ETL_DEFAULTS="-DETL_DEBUG_THRESHOLDS" 72 | unset ETL_MKL 73 | export ETL_CUBLAS=true 74 | export ETL_CUFFT=true 75 | export ETL_CUDNN=true 76 | 77 | etl_run 7 78 | fi 79 | -------------------------------------------------------------------------------- /benchmark/include/benchmark_gemm.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | using outer_policy = NARY_POLICY(VALUES_POLICY(10, 50, 100, 500, 1000, 2000, 3000), 11 | VALUES_POLICY(10, 50, 100, 500, 1000, 2000, 3000)); 12 | 13 | using bias_add_policy = NARY_POLICY(VALUES_POLICY(10, 20, 30, 40, 50, 60, 70, 80, 90, 100), 14 | VALUES_POLICY(128, 128, 128, 128, 128, 128, 128, 128, 128, 128), 15 | VALUES_POLICY(100, 100, 100, 100, 100, 100, 100, 100, 100, 100), 16 | VALUES_POLICY(100, 100, 100, 100, 100, 100, 100, 100, 100, 100)); 17 | 18 | using bias_add_2d_policy = NARY_POLICY(VALUES_POLICY(128, 128, 128, 128, 128, 128, 128, 128, 128, 128), 19 | VALUES_POLICY(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000)); 20 | 21 | using square_policy = NARY_POLICY(VALUES_POLICY(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000), 22 | VALUES_POLICY(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000)); 23 | 24 | using small_square_policy = NARY_POLICY(VALUES_POLICY(100, 150, 200, 250, 300, 350, 400, 450, 500), 25 | VALUES_POLICY(100, 150, 200, 250, 300, 350, 400, 450, 500)); 26 | 27 | using gemv_policy = NARY_POLICY(VALUES_POLICY(50, 100, 250, 500, 750, 1000, 2000, 3000, 4000, 5000, 6000), 28 | VALUES_POLICY(50, 100, 250, 500, 750, 1000, 2000, 3000, 4000, 5000, 6000)); 29 | -------------------------------------------------------------------------------- /include/etl/impl/egblas/or.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 
9 | * \file 10 | * \brief EGBLAS wrappers for the or operation. 11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef ETL_EGBLAS_MODE 16 | 17 | #include "etl/impl/cublas/cuda.hpp" 18 | 19 | #include 20 | 21 | #endif 22 | 23 | namespace etl::impl::egblas { 24 | 25 | /*! 26 | * \brief Indicates if EGBLAS has the boolean or operation (EGBLAS_HAS_BOR is defined). 27 | */ 28 | #ifdef EGBLAS_HAS_BOR 29 | static constexpr bool has_bor = true; 30 | #else 31 | static constexpr bool has_bor = false; 32 | #endif 33 | 34 | /*! 35 | * \brief Wrapper for the EGBLAS boolean or operation 36 | * \param n The size of the vector 37 | * \param A The memory of the vector a 38 | * \param lda The leading dimension of a 39 | * \param B The memory of the vector b 40 | * \param ldb The leading dimension of b 41 | * \param C The memory of the vector c 42 | * \param ldc The leading dimension of c 43 | */ 44 | inline void logical_or([[maybe_unused]] size_t n, 45 | [[maybe_unused]] const bool* A, 46 | [[maybe_unused]] size_t lda, 47 | [[maybe_unused]] const bool* B, 48 | [[maybe_unused]] size_t ldb, 49 | [[maybe_unused]] bool* C, 50 | [[maybe_unused]] size_t ldc) { 51 | #ifdef EGBLAS_HAS_BOR 52 | inc_counter("egblas"); 53 | egblas_bor(n, A, lda, B, ldb, C, ldc); 54 | #else 55 | cpp_unreachable("Invalid call to egblas::logical_or"); 56 | #endif 57 | } 58 | 59 | } //end of namespace etl::impl::egblas 60 | -------------------------------------------------------------------------------- /include/etl/impl/conv_select.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains selectors for convolution implementations.
11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different types of convolution 19 | */ 20 | enum class conv_type { 21 | VALID, ///< Valid convolution 22 | VALID_MULTI, ///< Valid convolution, with multiple kernels 23 | SAME, ///< Same convolution 24 | SAME_MULTI, ///< Same convolution, with multiple kernels 25 | FULL, ///< Full convolution 26 | FULL_MULTI ///< Full convolution, with multiple kernels 27 | }; 28 | 29 | namespace detail { 30 | 31 | /*! 32 | * \brief Test if ETL should run in parallel for the conv of I and K in C 33 | * \tparam I The input type 34 | * \tparam K The kernel type 35 | * \tparam C The conv type 36 | * \return true to run in parallel (parallel execution is selected and the sizes of conv and kernel both reach their thresholds), false otherwise 37 | */ 38 | template 39 | inline bool select_parallel(const I& /*input*/, const K& kernel, C&& conv) { 40 | if ((is_parallel && !local_context().serial) || (parallel_support && local_context().parallel)) { 41 | return etl::size(conv) >= conv1_parallel_threshold_conv && etl::size(kernel) >= conv1_parallel_threshold_kernel; 42 | } else { 43 | return false; 44 | } 45 | } 46 | 47 | } //end of namespace detail 48 | 49 | } //end of namespace etl 50 | 51 | #include "conv_normal_select.hpp" 52 | #include "conv_multi_select.hpp" 53 | #include "conv_4d_select.hpp" 54 | -------------------------------------------------------------------------------- /include/etl/impl/egblas/and.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief EGBLAS wrappers for the and operation.
11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef ETL_EGBLAS_MODE 16 | 17 | #include "etl/impl/cublas/cuda.hpp" 18 | 19 | #include 20 | 21 | #endif 22 | 23 | namespace etl::impl::egblas { 24 | 25 | /*! 26 | * \brief Indicates if EGBLAS has the boolean and operation (EGBLAS_HAS_BAND is defined). 27 | */ 28 | #ifdef EGBLAS_HAS_BAND 29 | static constexpr bool has_band = true; 30 | #else 31 | static constexpr bool has_band = false; 32 | #endif 33 | 34 | /*! 35 | * \brief Wrapper for the EGBLAS boolean and operation 36 | * \param n The size of the vector 37 | * \param A The memory of the vector a 38 | * \param lda The leading dimension of a 39 | * \param B The memory of the vector b 40 | * \param ldb The leading dimension of b 41 | * \param C The memory of the vector c 42 | * \param ldc The leading dimension of c 43 | */ 44 | inline void logical_and([[maybe_unused]] size_t n, 45 | [[maybe_unused]] const bool* A, 46 | [[maybe_unused]] size_t lda, 47 | [[maybe_unused]] const bool* B, 48 | [[maybe_unused]] size_t ldb, 49 | [[maybe_unused]] bool* C, 50 | [[maybe_unused]] size_t ldc) { 51 | #ifdef EGBLAS_HAS_BAND 52 | inc_counter("egblas"); 53 | egblas_band(n, A, lda, B, ldb, C, ldc); 54 | #else 55 | cpp_unreachable("Invalid call to egblas::logical_and"); 56 | #endif 57 | } 58 | 59 | } //end of namespace etl::impl::egblas 60 | -------------------------------------------------------------------------------- /include/etl/impl/egblas/xor.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief EGBLAS wrappers for the xor operation.
11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef ETL_EGBLAS_MODE 16 | 17 | #include "etl/impl/cublas/cuda.hpp" 18 | 19 | #include 20 | 21 | #endif 22 | 23 | namespace etl::impl::egblas { 24 | 25 | /*! 26 | * \brief Indicates if EGBLAS has the boolean xor operation (EGBLAS_HAS_BXOR is defined). 27 | */ 28 | #ifdef EGBLAS_HAS_BXOR 29 | static constexpr bool has_bxor = true; 30 | #else 31 | static constexpr bool has_bxor = false; 32 | #endif 33 | 34 | /*! 35 | * \brief Wrapper for the EGBLAS boolean xor operation 36 | * \param n The size of the vector 37 | * \param A The memory of the vector a 38 | * \param lda The leading dimension of a 39 | * \param B The memory of the vector b 40 | * \param ldb The leading dimension of b 41 | * \param C The memory of the vector c 42 | * \param ldc The leading dimension of c 43 | */ 44 | inline void logical_xor([[maybe_unused]] size_t n, 45 | [[maybe_unused]] const bool* A, 46 | [[maybe_unused]] size_t lda, 47 | [[maybe_unused]] const bool* B, 48 | [[maybe_unused]] size_t ldb, 49 | [[maybe_unused]] bool* C, 50 | [[maybe_unused]] size_t ldc) { 51 | #ifdef EGBLAS_HAS_BXOR 52 | inc_counter("egblas"); 53 | egblas_bxor(n, A, lda, B, ldb, C, ldc); 54 | #else 55 | cpp_unreachable("Invalid call to egblas::logical_xor"); 56 | #endif 57 | } 58 | 59 | } //end of namespace etl::impl::egblas 60 | -------------------------------------------------------------------------------- /include/etl/math.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | #include 11 | 12 | namespace etl::math { 13 | 14 | /*!
15 | * \brief Return the logistic sigmoid of x 16 | * \param x The value 17 | * \return The logistic sigmoid of x 18 | */ 19 | inline float logistic_sigmoid(float x) { 20 | return 1.0f / (1.0f + std::exp(-x)); 21 | } 22 | 23 | /*! 24 | * \brief Return the logistic sigmoid of x 25 | * \param x The value 26 | * \return The logistic sigmoid of x 27 | */ 28 | inline double logistic_sigmoid(double x) { 29 | return 1.0 / (1.0 + std::exp(-x)); 30 | } 31 | 32 | /*! 33 | * \brief Return the softplus of x 34 | * \param x The value 35 | * \return The softplus of x 36 | */ 37 | inline float softplus(float x) { 38 | return std::log1p(std::exp(x)); 39 | } 40 | 41 | /*! 42 | * \brief Return the softplus of x 43 | * \param x The value 44 | * \return The softplus of x 45 | */ 46 | inline double softplus(double x) { 47 | return std::log1p(std::exp(x)); 48 | } 49 | 50 | /*! 51 | * \brief Return the sign of v 52 | * \param v The value 53 | * \return 0 if v is zero, 1 if v is positive, -1 if v is negative 54 | */ 55 | template 56 | constexpr double sign(W v) noexcept { 57 | return v == W(0) ? W(0) : (v > W(0) ? W(1) : W(-1)); 58 | } 59 | 60 | /*! 61 | * \brief Test if the given number is a power of two 62 | * \param n The number to test 63 | * \return true if the number is a strictly positive power of two, false otherwise (including zero and negative numbers) 64 | */ 65 | constexpr bool is_power_of_two(int64_t n) { 66 | return n > 0 && (n & (n - 1)) == 0; 67 | } 68 | 69 | } //end of namespace etl::math 70 | -------------------------------------------------------------------------------- /include/etl/impl/std/det.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*!
9 | * \file 10 | * \brief Standard implementation of the determinant 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \copydoc etl::lu 19 | */ 20 | template 21 | bool lu(const AT& A, LT& L, UT& U, PT& P); 22 | 23 | namespace impl { 24 | 25 | namespace standard { 26 | 27 | /*! 28 | * \brief Compute the determinant of the given matrix 29 | * \return The determinant of the given matrix 30 | */ 31 | template 32 | value_t det(const AT& A) { 33 | using T = value_t; 34 | 35 | const auto n = etl::dim<0>(A); 36 | 37 | if (is_permutation_matrix(A)) { 38 | // The determinant of a permutation matrix is the sign of the permutation: +1 for an even number of inversions, -1 for an odd one 39 | auto column = [&A, n](size_t row) { size_t c = 0; while (c < n && A(row, c) == 0.0) { ++c; } return c; }; 40 | size_t inversions = 0; 41 | for (size_t i = 0; i < n; ++i) { 42 | const auto pivot = column(i); 43 | for (size_t k = i + 1; k < n; ++k) { 44 | inversions += column(k) < pivot ? 1 : 0; 45 | } 46 | } 47 | 48 | return (inversions % 2 == 0) ? T(1.0) : T(-1.0); 49 | } 50 | 51 | if (is_triangular(A)) { 52 | T det(1.0); 53 | 54 | for (size_t i = 0; i < n; ++i) { 55 | det *= A(i, i); 56 | } 57 | 58 | return det; 59 | } 60 | 61 | auto L = force_temporary_dim_only(A); 62 | auto U = force_temporary_dim_only(A); 63 | auto P = force_temporary_dim_only(A); 64 | 65 | etl::lu(A, L, U, P); 66 | 67 | return det(L) * det(U) * det(P); 68 | } 69 | 70 | } //end of namespace standard 71 | } //end of namespace impl 72 | } //end of namespace etl 73 | -------------------------------------------------------------------------------- /include/etl/direct_fill.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard memory utilities 11 | */ 12 | 13 | #pragma once 14 | 15 | #include "etl/impl/egblas/scalar_set.hpp" 16 | 17 | namespace etl { 18 | 19 | /*!
20 | * \brief Fill the given ETL value class with the given value 21 | * \param mat The ETL value class 22 | * \param value The value to set to each element of the matrix 23 | */ 24 | template 25 | void direct_fill(E&& mat, V value) { 26 | if constexpr (is_single_precision && egblas_enabled && impl::egblas::has_scalar_sset) { 27 | value_t value_conv = value; 28 | 29 | if (mat.gpu_memory()) { 30 | impl::egblas::scalar_set(mat.gpu_memory(), etl::size(mat), 1, value_conv); 31 | 32 | mat.validate_gpu(); 33 | } 34 | 35 | std::fill(mat.memory_start(), mat.memory_end(), value_conv); 36 | 37 | mat.validate_cpu(); 38 | } else if constexpr (is_double_precision && egblas_enabled && impl::egblas::has_scalar_dset) { 39 | value_t value_conv = value; 40 | 41 | if (mat.gpu_memory()) { 42 | impl::egblas::scalar_set(mat.gpu_memory(), etl::size(mat), 1, value_conv); 43 | 44 | mat.validate_gpu(); 45 | } 46 | 47 | std::fill(mat.memory_start(), mat.memory_end(), value_conv); 48 | 49 | mat.validate_cpu(); 50 | } else { 51 | std::fill(mat.memory_start(), mat.memory_end(), value); 52 | 53 | mat.validate_cpu(); 54 | mat.invalidate_gpu(); 55 | } 56 | } 57 | 58 | } //end of namespace etl 59 | -------------------------------------------------------------------------------- /test/src/max_pool_upsample_deep.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | 10 | #include 11 | 12 | TEMPLATE_TEST_CASE_2("pool_upsample/dyn/max2/deep/1", "[pooling]", Z, float, double) { 13 | etl::dyn_matrix input(5, 9, 9); 14 | input = etl::sequence_generator(1.0); 15 | 16 | etl::dyn_matrix errors(5, 3, 3); 17 | errors = 100.0 * etl::sequence_generator(1.0); 18 | 19 | etl::dyn_matrix output(5, 3, 3); 20 | output = etl::max_pool_2d(input, 3, 3); 21 | 22 | etl::dyn_matrix c1(5, 9, 9); 23 | etl::dyn_matrix result(5, 9, 9); 24 | 25 | c1 = etl::max_pool_derivative_2d(input, output, 3, 3) >> etl::upsample_2d(errors, 3, 3); 26 | result = etl::max_pool_upsample_2d(input, output, errors, 3, 3); 27 | 28 | REQUIRE_DIRECT(approx_equals(c1, result, base_eps_etl)); 29 | } 30 | 31 | TEMPLATE_TEST_CASE_2("pool_upsample/max2/deep/1", "[pooling]", Z, float, double) { 32 | etl::fast_matrix input; 33 | input = etl::sequence_generator(1.0); 34 | 35 | etl::fast_matrix errors; 36 | errors = 100.0 * etl::sequence_generator(1.0); 37 | 38 | etl::fast_matrix output; 39 | output = etl::max_pool_2d<2, 1>(input); 40 | 41 | etl::fast_matrix c1; 42 | etl::fast_matrix result; 43 | 44 | c1 = etl::max_pool_derivative_2d<2, 1>(input, output) >> etl::upsample_2d<2, 1>(errors); 45 | result = etl::max_pool_upsample_2d<2, 1>(input, output, errors); 46 | 47 | REQUIRE_DIRECT(approx_equals(c1, result, base_eps_etl)); 48 | } 49 | 50 | // TODO This should have more tests 51 | -------------------------------------------------------------------------------- /test/src/decomposition.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | 10 | /* LU */ 11 | 12 | TEMPLATE_TEST_CASE_2("globals/lu/1", "[globals][LU]", Z, float, double) { 13 | etl::fast_matrix A{1, 3, 5, 2, 4, 7, 1, 1, 0}; 14 | etl::fast_matrix L; 15 | etl::fast_matrix U; 16 | etl::fast_matrix P; 17 | 18 | etl::lu(A, L, U, P); 19 | 20 | etl::fast_matrix PA; 21 | etl::fast_matrix LU; 22 | PA = P * A; 23 | LU = L * U; 24 | 25 | REQUIRE_DIRECT(approx_equals(PA, LU, base_eps_etl)); 26 | } 27 | 28 | TEMPLATE_TEST_CASE_2("globals/lu/2", "[globals][LU]", Z, float, double) { 29 | etl::fast_matrix A{11, 9, 24, 2, 1, 5, 2, 6, 3, 17, 18, 1, 2, 5, 7, 1}; 30 | etl::fast_matrix L; 31 | etl::fast_matrix U; 32 | etl::fast_matrix P; 33 | 34 | etl::lu(A, L, U, P); 35 | 36 | etl::fast_matrix PA; 37 | etl::fast_matrix LU; 38 | PA = P * A; 39 | LU = L * U; 40 | 41 | REQUIRE_DIRECT(approx_equals(PA, LU, base_eps_etl)); 42 | } 43 | 44 | /* QR */ 45 | 46 | TEMPLATE_TEST_CASE_2("globals/qr/1", "[globals][QR]", Z, float, double) { 47 | etl::fast_matrix A{12, -51, 4, 6, 167, -68, -4, 24, -41, -1, 1, 44, 2, 11, 3}; 48 | etl::fast_matrix Q; 49 | etl::fast_matrix R; 50 | etl::fast_matrix QR; 51 | 52 | etl::qr(A, Q, R); 53 | 54 | QR = Q * R; 55 | 56 | // The epsilon need to be big because of the zero in the result 57 | // and the large difference in computation around zero 58 | REQUIRE_DIRECT(approx_equals(QR, A, 100 * base_eps_etl)); 59 | } 60 | -------------------------------------------------------------------------------- /include/etl/parallel_session.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | namespace detail { 13 | 14 | /*! 15 | * \brief RAII helper for running and validating a parallel session 16 | */ 17 | template 18 | struct parallel_session { 19 | /*! 20 | * \brief Default construct a parallel session 21 | * 22 | * This sets the parallel session as active and makes sure that no previous 23 | * parallel session was running. 24 | */ 25 | parallel_session() { 26 | cpp_assert(!active, "Parallel session cannot be nested"); 27 | 28 | active = true; 29 | } 30 | 31 | /*! 32 | * \brief Destruct a parallel session 33 | * 34 | * This disables the parallel session 35 | */ 36 | ~parallel_session() { 37 | active = false; 38 | } 39 | 40 | /*! 41 | * \brief Always returns true, so that the session can be declared in the condition of an if statement 42 | */ 43 | operator bool() { 44 | return true; 45 | } 46 | 47 | static bool active; ///< Indicates if the parallel session is active 48 | }; 49 | 50 | template 51 | bool parallel_session::active = false; 52 | 53 | } //end of namespace detail 54 | 55 | /*! 56 | * \brief Indicates if a parallel session is currently active 57 | * \return true if a parallel session is active, false otherwise 58 | */ 59 | inline bool is_parallel_session() { 60 | return detail::parallel_session::active; 61 | } 62 | 63 | /*!
64 | * \brief Define the start of an ETL parallel session 65 | */ 66 | #define ETL_PARALLEL_SESSION if (auto etl_parallel_session__ = etl::detail::parallel_session()) 67 | 68 | } //end of namespace etl 69 | -------------------------------------------------------------------------------- /include/etl/op/unary_op.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains the unary operators for the unary expression 11 | * 12 | * A unary operator is a simple class with a static function apply that 13 | * computes its result. If the operator is vectorizable, it also contains a 14 | * static function load that computes the result for several operands at a 15 | * time. 
16 | */ 17 | 18 | #pragma once 19 | 20 | #include 21 | #include 22 | 23 | #include "etl/math.hpp" 24 | #include "etl/temporary.hpp" 25 | 26 | #include "etl/op/unary/minus.hpp" 27 | #include "etl/op/unary/plus.hpp" 28 | #include "etl/op/unary/abs.hpp" 29 | #include "etl/op/unary/floor.hpp" 30 | #include "etl/op/unary/ceil.hpp" 31 | #include "etl/op/unary/log.hpp" 32 | #include "etl/op/unary/log2.hpp" 33 | #include "etl/op/unary/log10.hpp" 34 | #include "etl/op/unary/sqrt.hpp" 35 | #include "etl/op/unary/invsqrt.hpp" 36 | #include "etl/op/unary/cbrt.hpp" 37 | #include "etl/op/unary/invcbrt.hpp" 38 | #include "etl/op/unary/tan.hpp" 39 | #include "etl/op/unary/sin.hpp" 40 | #include "etl/op/unary/cos.hpp" 41 | #include "etl/op/unary/tanh.hpp" 42 | #include "etl/op/unary/sinh.hpp" 43 | #include "etl/op/unary/cosh.hpp" 44 | #include "etl/op/unary/exp.hpp" 45 | #include "etl/op/unary/sigmoid.hpp" 46 | #include "etl/op/unary/sign.hpp" 47 | #include "etl/op/unary/softplus.hpp" 48 | #include "etl/op/unary/real.hpp" 49 | #include "etl/op/unary/imag.hpp" 50 | #include "etl/op/unary/conj.hpp" 51 | #include "etl/op/unary/relu.hpp" 52 | #include "etl/op/unary/relu_derivative.hpp" 53 | #include "etl/op/unary/bernoulli.hpp" 54 | #include "etl/op/unary/noise.hpp" 55 | #include "etl/op/unary/clip.hpp" 56 | -------------------------------------------------------------------------------- /include/etl/impl/std/sum.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 
9 | * \file 10 | * \brief Standard implementation of the "sum" reduction 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl::impl::standard { 16 | 17 | /*! 18 | * \brief Compute the sum of the input in the given expression 19 | * \param input The input expression 20 | * \return the sum 21 | */ 22 | template 23 | value_t sum(const E& input) { 24 | using T = value_t; 25 | 26 | T acc(0); 27 | 28 | auto acc_functor = [&acc](T value) { acc += value; }; 29 | 30 | auto batch_fun = [](auto& sub) { 31 | T acc(0); 32 | 33 | for (size_t i = 0; i < etl::size(sub); ++i) { 34 | acc += sub[i]; 35 | } 36 | 37 | return acc; 38 | }; 39 | 40 | engine_dispatch_1d_acc_slice(input, batch_fun, acc_functor, sum_parallel_threshold); 41 | 42 | return acc; 43 | } 44 | 45 | /*! 46 | * \brief Compute the sum of the absolute values in the given expression 47 | * \param input The input expression 48 | * \return the absolute sum 49 | */ 50 | template 51 | value_t asum(const E& input) { 52 | using T = value_t; 53 | 54 | T acc(0); 55 | 56 | auto acc_functor = [&acc](T value) { acc += value; }; 57 | 58 | auto batch_fun = [](auto& sub) { 59 | T acc(0); 60 | 61 | for (size_t i = 0; i < etl::size(sub); ++i) { 62 | using std::abs; 63 | acc += abs(sub[i]); 64 | } 65 | 66 | return acc; 67 | }; 68 | 69 | engine_dispatch_1d_acc_slice(input, batch_fun, acc_functor, sum_parallel_threshold); 70 | 71 | return acc; 72 | } 73 | 74 | } //end of namespace etl::impl::standard 75 | -------------------------------------------------------------------------------- /include/etl/memory.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Standard memory utilities 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Performs a direct memory copy 19 | * \param first pointer to the first element to copy 20 | * \param last pointer to one past the last element to copy 21 | * \param target pointer to the first element of the result 22 | */ 23 | template 24 | void direct_copy(const S* first, const S* last, T* target) { 25 | std::copy(first, last, target); 26 | } 27 | 28 | /*! 29 | * \brief Performs a direct memory copy 30 | * \param source pointer to the first source element 31 | * \param target pointer to the first element of the result 32 | * \param n The number of elements to copy 33 | */ 34 | template 35 | void direct_copy_n(const S* source, T* target, size_t n) { 36 | std::copy_n(source, n, target); 37 | } 38 | 39 | /*! 40 | * \brief Fills the given memory with the given value 41 | * \param first pointer to the first element to fill 42 | * \param last pointer to one past the last element to fill 43 | * \param value The value to fill the memory with 44 | */ 45 | template 46 | void direct_fill(S* first, S* last, T value) { 47 | std::fill(first, last, value); 48 | } 49 | 50 | /*!
51 | * \brief Fills the given memory with the given value 52 | * \param first pointer to the first element to fill 53 | * \param n The number of elements to fill 54 | * \param value The value to fill the memory with 55 | */ 56 | template 57 | void direct_fill_n(S* first, size_t n, T value) { 58 | std::fill_n(first, n, value); 59 | } 60 | 61 | } //end of namespace etl 62 | -------------------------------------------------------------------------------- /include/etl/impl/cublas/dot.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief CUBLAS implementation of the dot product 11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef ETL_CUBLAS_MODE 16 | 17 | #include "etl/impl/cublas/cuda.hpp" 18 | #include "etl/impl/cublas/cublas.hpp" 19 | 20 | #endif 21 | 22 | namespace etl::impl::cublas { 23 | 24 | #ifdef ETL_CUBLAS_MODE 25 | 26 | /*! 27 | * \brief Compute the dot product of a and b on the GPU with CUBLAS 28 | * \param a The lhs expression 29 | * \param b The rhs expression 30 | * \return The dot product of a and b 31 | */ 32 | template 33 | float dot(const A& a, const B& b) { 34 | decltype(auto) handle = start_cublas(); 35 | 36 | a.ensure_gpu_up_to_date(); 37 | b.ensure_gpu_up_to_date(); 38 | 39 | float prod = 0.0; 40 | cublas_check(cublasSdot(handle.get(), etl::size(a), a.gpu_memory(), 1, b.gpu_memory(), 1, &prod)); 41 | return prod; 42 | } 43 | 44 | /*!
45 | * \copydoc batch_outer 46 | */ 47 | template 48 | double dot(const A& a, const B& b) { 49 | decltype(auto) handle = start_cublas(); 50 | 51 | a.ensure_gpu_up_to_date(); 52 | b.ensure_gpu_up_to_date(); 53 | 54 | double prod = 0.0; 55 | cublas_check(cublasDdot(handle.get(), etl::size(a), a.gpu_memory(), 1, b.gpu_memory(), 1, &prod)); 56 | return prod; 57 | } 58 | 59 | #else 60 | 61 | /*! 62 | * \copydoc batch_outer 63 | */ 64 | template 65 | value_t dot(const A& /*a*/, const B& /*b*/) { 66 | cpp_unreachable("CUBLAS not enabled/available"); 67 | return 0.0; 68 | } 69 | 70 | #endif 71 | 72 | } //end of namespace etl::impl::cublas 73 | -------------------------------------------------------------------------------- /test/src/avg_pool_upsample_deep.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | 10 | #include 11 | 12 | TEMPLATE_TEST_CASE_2("pool_upsample/dyn/avg2/deep/1", "[pooling]", Z, float, double) { 13 | std::random_device rd; 14 | etl::random_engine g(rd()); 15 | 16 | etl::dyn_matrix input(5, 9, 9); 17 | input = etl::uniform_generator(g, -1000.0, 1000.0); 18 | 19 | etl::dyn_matrix errors(5, 3, 3); 20 | errors = etl::uniform_generator(g, -1000.0, 1000.0); 21 | 22 | etl::dyn_matrix output(5, 3, 3); 23 | output = etl::avg_pool_2d(input, 3, 3); 24 | 25 | etl::dyn_matrix c1(5, 9, 9); 26 | etl::dyn_matrix c2(5, 9, 9); 27 | 28 | c1 = etl::avg_pool_derivative_2d(input, output, 3, 3) >> etl::upsample_2d(errors, 3, 3); 29 | c2 = etl::avg_pool_upsample_2d(input, output, errors, 3, 3); 30 | 31 | REQUIRE_DIRECT(approx_equals(c1, c2, base_eps_etl)); 32 | } 33 | 34 | TEMPLATE_TEST_CASE_2("pool_upsample/avg2/deep/1", "[pooling]", Z, float, double) { 35 | std::random_device rd; 36 | etl::random_engine g(rd()); 37 | 38 | etl::fast_matrix input; 39 | input = etl::uniform_generator(g, -1000.0, 1000.0); 40 | 41 | etl::fast_matrix errors; 42 | errors = etl::uniform_generator(g, -1000.0, 1000.0); 43 | 44 | etl::fast_matrix output; 45 | output = etl::avg_pool_2d<2, 1>(input); 46 | 47 | etl::fast_matrix c1; 48 | etl::fast_matrix c2; 49 | 50 | c1 = etl::avg_pool_derivative_2d<2, 1>(input, output) >> etl::upsample_2d<2, 1>(errors); 51 | c2 = etl::avg_pool_upsample_2d<2, 1>(input, output, errors); 52 | 53 | REQUIRE_DIRECT(approx_equals(c1, c2, base_eps_etl)); 54 | } 55 | -------------------------------------------------------------------------------- /include/etl/impl/std/convmtx2.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 
Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl::impl::standard { 11 | 12 | /*! 13 | * \brief Direct evaluation of convmtx2 14 | */ 15 | struct convmtx2_direct { 16 | /*! 17 | * \brief Apply the convmtx2 to sub into m 18 | * \param sub The sub expression (2D, of dimensions i1 x i2) 19 | * \param m The output matrix, which must be of dimensions ((i1 + K1 - 1) * (i2 + K2 - 1)) x (K1 * K2) 20 | */ 21 | template 22 | static void apply(A&& sub, M& m) { 23 | const size_t i1 = etl::dim<0>(sub); 24 | const size_t i2 = etl::dim<1>(sub); 25 | 26 | const size_t c_height = etl::dim<0>(m); 27 | constexpr size_t c_width = K1 * K2; 28 | 29 | cpp_assert(c_height == ((i1 + K1 - 1) * (i2 + K2 - 1)), "Invalid input height"); 30 | cpp_assert(c_width == etl::dim<1>(m), "Invalid input width"); 31 | 32 | const auto max_fill = c_height - ((i1 + K1 - 1) * ((c_width - 1) / K1) + (c_width - 1) % K1); 33 | const auto inner_paddings = max_fill - (i1 * i2); 34 | const auto inner_padding = i2 > 1 ? inner_paddings / (i2 - 1) : size_t(0); // A single-column input has no inner padding (avoids dividing by zero) 35 | 36 | m = 0; 37 | 38 | for (size_t j = 0; j < c_width; ++j) { 39 | auto top_padding = (i1 + K1 - 1) * (j / K1) + j % K1; 40 | auto bottom_padding = top_padding + (i1 * i2) + inner_paddings; 41 | 42 | for (size_t i = top_padding; i < bottom_padding; ++i) { 43 | auto inner = i - top_padding; 44 | auto block = inner / (i1 + inner_padding); 45 | auto col = inner % (i1 + inner_padding); 46 | 47 | if (col < i1) { 48 | m(i, j) = sub(col, block); 49 | } 50 | } 51 | } 52 | } 53 | }; 54 | 55 | } //end of namespace etl::impl::standard 56 | -------------------------------------------------------------------------------- /test/src/noise.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed
under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test_light.hpp" 9 | 10 | // These are simply compilation tests to avoid regression in noise functions 11 | 12 | TEMPLATE_TEST_CASE_2("logistic_noise/0", "[logistic_noise]", Z, float, double) { 13 | etl::fast_matrix a = {-1.0, 2.0, 5.0, 1.0}; 14 | etl::fast_matrix d; 15 | 16 | d = logistic_noise(a); 17 | } 18 | 19 | TEMPLATE_TEST_CASE_2("logistic_noise/1", "[logistic_noise]", Z, float, double) { 20 | etl::fast_matrix a = {-1.0, 2.0, 5.0, 1.0}; 21 | etl::fast_matrix d; 22 | 23 | etl::random_engine g(666); 24 | d = logistic_noise(g, a); 25 | } 26 | 27 | TEMPLATE_TEST_CASE_2("logistic_noise/2", "[logistic_noise]", Z, float, double) { 28 | etl::fast_matrix a = {-1.0, 2.0, 5.0, 1.0}; 29 | etl::fast_matrix d; 30 | 31 | d = state_logistic_noise(a); 32 | } 33 | 34 | TEMPLATE_TEST_CASE_2("logistic_noise/3", "[logistic_noise]", Z, float, double) { 35 | etl::fast_matrix a = {-1.0, 2.0, 5.0, 1.0}; 36 | etl::fast_matrix d; 37 | 38 | etl::random_engine g(666); 39 | d = state_logistic_noise(g, a); 40 | } 41 | 42 | TEMPLATE_TEST_CASE_2("logistic_noise/4", "[logistic_noise]", Z, float, double) { 43 | etl::fast_matrix a = {-1.0, 2.0, 5.0, 1.0}; 44 | etl::fast_matrix d; 45 | 46 | auto states = std::make_shared(); 47 | d = state_logistic_noise(a, states); 48 | } 49 | 50 | TEMPLATE_TEST_CASE_2("logistic_noise/5", "[logistic_noise]", Z, float, double) { 51 | etl::fast_matrix a = {-1.0, 2.0, 5.0, 1.0}; 52 | etl::fast_matrix d; 53 | 54 | auto states = std::make_shared(); 55 | etl::random_engine g(666); 56 | d = state_logistic_noise(g, a, states); 57 | } 58 | -------------------------------------------------------------------------------- /test/include/template_test.hpp: -------------------------------------------------------------------------------- 1 | 
//======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_DECL(name, description, T) \ 9 | template \ 10 | static void UNIQUE_NAME(____T_E_M_P_L_A_TE____T_E_S_T____)(); \ 11 | ETL_TEST_CASE(name, description) 12 | 13 | #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SECTION(Tn) \ 14 | ETL_SECTION(#Tn) { \ 15 | UNIQUE_NAME(____T_E_M_P_L_A_TE____T_E_S_T____)(); \ 16 | } 17 | 18 | #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_DEFN(T) \ 19 | template \ 20 | static void UNIQUE_NAME(____T_E_M_P_L_A_TE____T_E_S_T____)() 21 | 22 | #define TEMPLATE_TEST_CASE_2(name, description, T, T1, T2) \ 23 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_DECL(name, description, T) { \ 24 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_SECTION(T1) \ 25 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_SECTION(T2) \ 26 | } \ 27 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_DEFN(T) 28 | 29 | #define TEMPLATE_TEST_CASE_4(name, description, T, T1, T2, T3, T4) \ 30 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_DECL(name, description, T) { \ 31 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_SECTION(T1) \ 32 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_SECTION(T2) \ 33 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_SECTION(T3) \ 34 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_SECTION(T4) \ 35 | } \ 36 | INTERNAL_CATCH_TEMPLATE_TEST_CASE_DEFN(T) 37 | -------------------------------------------------------------------------------- /scripts/bench_runner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Disable default options 5 | export ETL_NO_DEFAULT=true 6 | unset ETL_DEFAULTS 7 | unset ETL_MKL 8 | unset ETL_BLAS 9 | 10 | # Start with clang 11 | export CXX=clang++ 12 | export LD=clang++ 13 | 14 | 
echo "Configuration 1. Clang (default)" 15 | 16 | make clean 17 | make -j6 release/bin/benchmark 18 | time ./release/bin/benchmark $BENCH_ARGS --tag=`git rev-list HEAD --count`-`git rev-parse HEAD` --configuration=default 19 | 20 | echo "Configuration 2. Clang (vectorize_full)" 21 | 22 | export ETL_DEFAULTS="-DETL_VECTORIZE_FULL" 23 | 24 | make clean 25 | make -j6 release/bin/benchmark 26 | time ./release/bin/benchmark $BENCH_ARGS --tag=`git rev-list HEAD --count`-`git rev-parse HEAD` --configuration=vectorize_full 27 | 28 | echo "Configuration 3. Clang (vectorize_full mkl_mode)" 29 | 30 | unset ETL_BLAS 31 | export ETL_DEFAULTS="-DETL_VECTORIZE_FULL" 32 | export ETL_MKL=true 33 | 34 | make clean 35 | make -j6 release/bin/benchmark 36 | time ./release/bin/benchmark $BENCH_ARGS --tag=`git rev-list HEAD --count`-`git rev-parse HEAD` --configuration="mkl_mode+vectorize_full" 37 | 38 | unset ETL_DEFAULTS 39 | unset ETL_MKL 40 | 41 | # Continue with gcc 42 | export CXX=$ETL_GPP 43 | export LD=$ETL_GPP 44 | 45 | echo "Configuration 1. GCC (default)" 46 | 47 | make clean 48 | make -j6 release/bin/benchmark 49 | time ./release/bin/benchmark $BENCH_ARGS --tag=`git rev-list HEAD --count`-`git rev-parse HEAD` --configuration=default 50 | 51 | echo "Configuration 2. GCC (vectorize_full)" 52 | 53 | export ETL_DEFAULTS="-DETL_VECTORIZE_FULL" 54 | 55 | make clean 56 | make -j6 release/bin/benchmark 57 | time ./release/bin/benchmark $BENCH_ARGS --tag=`git rev-list HEAD --count`-`git rev-parse HEAD` --configuration=vectorize_full 58 | 59 | echo "Configuration 3. 
GCC (vectorize_full mkl_mode)" 60 | 61 | unset ETL_BLAS 62 | export ETL_DEFAULTS="-DETL_VECTORIZE_FULL" 63 | export ETL_MKL=true 64 | 65 | make clean 66 | make -j6 release/bin/benchmark 67 | time ./release/bin/benchmark $BENCH_ARGS --tag=`git rev-list HEAD --count`-`git rev-parse HEAD` --configuration="mkl_mode+vectorize_full" 68 | -------------------------------------------------------------------------------- /include/etl/conv_impl.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Enumeration of the different convolution implementations 11 | */ 12 | 13 | #pragma once 14 | 15 | namespace etl { 16 | 17 | /*! 18 | * \brief Enumeration describing the different convolution implementations 19 | */ 20 | enum class conv_impl { 21 | STD, ///< Standard implementation 22 | VEC, ///< Uniform Vectorized Implementation with locality 23 | CUDNN, ///< CUDNN implementation 24 | FFT_STD, ///< FFT reduction (with STD impl) 25 | FFT_MKL, ///< FFT reduction (with MKL impl) 26 | FFT_CUFFT, ///< FFT reduction (with CUFFT impl) 27 | EGBLAS ///< EGBLAS implementation 28 | }; 29 | 30 | /*! 31 | * \brief Enumeration describing the different convolution implementations 32 | */ 33 | enum class conv4_impl { 34 | STD, ///< Standard implementation 35 | VEC, ///< VEC implementation 36 | CUDNN, ///< CUDNN implementation 37 | FFT_STD, ///< FFT reduction (with STD impl) 38 | FFT_MKL, ///< FFT reduction (with MKL impl) 39 | FFT_CUFFT, ///< FFT reduction (with CUFFT impl) 40 | BLAS_VEC, ///< BLAS reduction 41 | BLAS_MKL ///< BLAS reduction 42 | }; 43 | 44 | /*! 
45 | * \brief Enumeration describing the different multiple convolution implementations 46 | */ 47 | enum class conv_multi_impl { 48 | STD, ///< Standard implementation 49 | VEC, ///< VEC implementation 50 | VALID_FFT_MKL, ///< Reduction to FFT (valid) 51 | FFT_STD, ///< FFT reduction (with STD impl) 52 | FFT_MKL, ///< FFT reduction (with MKL impl) 53 | FFT_CUFFT, ///< FFT reduction (with CUFFT impl) 54 | BLAS_VEC, ///< Reduction to BLAS (GEMM) 55 | BLAS_MKL, ///< Reduction to BLAS (GEMM) 56 | CUDNN ///< GPU with CUDNN 57 | }; 58 | 59 | } //end of namespace etl 60 | -------------------------------------------------------------------------------- /Implementation.rst: -------------------------------------------------------------------------------- 1 | Implementation notes 2 | ==================== 3 | 4 | The code uses C++17 extensively. 5 | 6 | For now, the code is made to be compiled with: 7 | 8 | * >=g++-13 9 | * >=clang++-16 10 | 11 | Due to modern features that are being used, it is unlikely that everything works on 12 | Windows and icc. 13 | 14 | Compile-Time 15 | ------------ 16 | 17 | The time to compile expressions is currently not great. It is 18 | starting to take quite long. 19 | 20 | There are several possible solutions: 21 | 22 | * Type erasure of some sort. For instance, some algorithms only 23 | need some sizes and a data pointer; this would save some 24 | instantiations, but could mean major changes 25 | * Simplify the interface template by removing some enable_if and 26 | making some things not template. 27 | 28 | However, some things are harder than expected. For instance, it is 29 | not possible to remove SFINAE on the operator+ function because 30 | otherwise it would match etl::complex or iterators from dyn matrix. 31 | This is because forwarding is "too perfect" and because ADL is used. 32 | 33 | Coverage 34 | -------- 35 | 36 | Coverage is quite low (less than 50%).
The biggest problem now is that it is not 37 | possible to merge the coverage statistics of several runs correctly. I do 38 | believe that the real coverage is in fact higher than this. The merge process 39 | only takes the maximum coverage from each profile, but does not merge individual 40 | functions, meaning that a lot of information is lost. Another problem is that 41 | some files are polluting the results since they are not meant to be covered, for 42 | instance SFINAE selection is never meant to be "executed" and no_vectorization 43 | as well. 44 | 45 | How to improve coverage: 46 | * Improve the merge of multiple coverage profiles 47 | * Remove some files from the coverage analysis 48 | * Remove some lines from the coverage analysis (asserts) (not 49 | possible in sonar unfortunately) 50 | * Test more things, obviously 51 | 52 | Notes 53 | ----- 54 | 55 | There are too many corner cases in the evaluation of expressions: 56 | * direct_evaluate 57 | * compound expressions 58 | * forced expressions 59 | * reductions 60 | 61 | This should be improved. 62 | -------------------------------------------------------------------------------- /include/etl/op/binary/mod.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | /*!
13 | * \brief Binary operator for scalar modulo 14 | */ 15 | template 16 | struct mod_binary_op { 17 | static constexpr bool linear = true; ///< Indicates if the operator is linear or not 18 | static constexpr bool thread_safe = true; ///< Indicates if the operator is thread safe or not 19 | static constexpr bool desc_func = false; ///< Indicates if the description must be printed as function 20 | 21 | /*! 22 | * \brief Indicates if the expression is vectorizable using the 23 | * given vector mode 24 | * \tparam V The vector mode 25 | */ 26 | template 27 | static constexpr bool vectorizable = false; 28 | 29 | /*! 30 | * \brief Indicates if the operator can be computed on GPU 31 | */ 32 | template 33 | static constexpr bool gpu_computable = false; 34 | 35 | /*! 36 | * \brief Estimate the complexity of operator 37 | * \return An estimation of the complexity of the operator 38 | */ 39 | static constexpr int complexity() { 40 | return 2; 41 | } 42 | 43 | /*! 44 | * \brief Apply the binary operator on lhs and rhs 45 | * \param lhs The left hand side value on which to apply the operator 46 | * \param rhs The right hand side value on which to apply the operator 47 | * \return The result of applying the binary operator on lhs and rhs 48 | */ 49 | static constexpr T apply(const T& lhs, const T& rhs) noexcept { 50 | return lhs % rhs; 51 | } 52 | 53 | /*! 54 | * \brief Returns a textual representation of the operator 55 | * \return a string representing the operator 56 | */ 57 | static std::string desc() noexcept { 58 | return "%"; 59 | } 60 | }; 61 | 62 | } //end of namespace etl 63 | -------------------------------------------------------------------------------- /include/etl/impl/cublas/cuda.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License.
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | #ifdef ETL_CUDA 11 | 12 | #include "cuda.h" 13 | #include "cuda_runtime.h" 14 | #include "cuda_runtime_api.h" 15 | #include "cuComplex.h" 16 | 17 | #include "etl/util/complex_cast.hpp" 18 | 19 | #define cuda_check(call) \ 20 | { \ 21 | auto status = call; \ 22 | if (status != cudaSuccess) { \ 23 | std::cerr << "CUDA error: " << cudaGetErrorString(status) << " from " << #call << std::endl \ 24 | << "from " << __FILE__ << ":" << __LINE__ << std::endl; \ 25 | } \ 26 | } 27 | 28 | #define cuda_check_assert(call) \ 29 | { \ 30 | auto status = call; \ 31 | if (status != cudaSuccess) { \ 32 | std::cerr << "CUDA error: " << cudaGetErrorString(status) << " from " << #call << std::endl \ 33 | << "from " << __FILE__ << ":" << __LINE__ << std::endl; \ 34 | std::abort(); \ 35 | } \ 36 | } 37 | 38 | #endif 39 | -------------------------------------------------------------------------------- /include/etl/op/binary/one_if.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | /*! 
13 | * \brief Binary operator to get 1.0 if x equals the rhs value, 0 otherwise 14 | */ 15 | template 16 | struct one_if_binary_op { 17 | static constexpr bool linear = true; ///< Indicates if the operator is linear or not 18 | static constexpr bool thread_safe = true; ///< Indicates if the operator is thread safe or not 19 | static constexpr bool desc_func = true; ///< Indicates if the description must be printed as function 20 | 21 | /*! 22 | * \brief Indicates if the expression is vectorizable using the 23 | * given vector mode 24 | * \tparam V The vector mode 25 | */ 26 | template 27 | static constexpr bool vectorizable = false; 28 | 29 | /*! 30 | * \brief Indicates if the operator can be computed on GPU 31 | */ 32 | template 33 | static constexpr bool gpu_computable = false; 34 | 35 | /*! 36 | * \brief Estimate the complexity of operator 37 | * \return An estimation of the complexity of the operator 38 | */ 39 | static constexpr int complexity() { 40 | return 1; 41 | } 42 | 43 | /*! 44 | * \brief Apply the binary operator on x and value 45 | * \param x The left hand side value on which to apply the operator 46 | * \param value The right hand side value on which to apply the operator 47 | * \return The result of applying the binary operator on x and value 48 | */ 49 | static constexpr T apply(const T& x, E value) noexcept { 50 | return x == value ? 1.0 : 0.0; 51 | } 52 | 53 | /*!
54 | * \brief Returns a textual representation of the operator 55 | * \return a string representing the operator 56 | */ 57 | static std::string desc() noexcept { 58 | return "one_if"; 59 | } 60 | }; 61 | 62 | } //end of namespace etl 63 | -------------------------------------------------------------------------------- /test/src/conv_4d_valid_mixed.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | #include "conv_test.hpp" 10 | 11 | ETL_TEST_CASE("conv4/valid/mixed/0", "[conv][conv4][valid]") { 12 | etl::fast_matrix I; 13 | etl::fast_matrix K; 14 | 15 | I = etl::sequence_generator(10.0) * 4.0; 16 | K = etl::sequence_generator(2.0) * 0.3; 17 | 18 | etl::fast_matrix ref; 19 | etl::fast_matrix c; 20 | 21 | SELECTED_SECTION(etl::conv_impl::STD) { 22 | ref = 0.0; 23 | for (size_t i = 0; i < etl::dim<0>(I); ++i) { 24 | for (size_t c = 0; c < etl::dim<1>(K); ++c) { 25 | for (size_t k = 0; k < etl::dim<0>(K); ++k) { 26 | ref(i)(k) += conv_2d_valid(I(i)(c), K(k)(c)); 27 | } 28 | } 29 | } 30 | } 31 | 32 | c = etl::conv_4d_valid(I, K); 33 | 34 | for (size_t i = 0; i < ref.size(); ++i) { 35 | REQUIRE_EQUALS_APPROX(c[i], ref[i]); 36 | } 37 | } 38 | 39 | ETL_TEST_CASE("conv4/valid/mixed/1", "[conv][conv4][valid]") { 40 | etl::fast_matrix I; 41 | etl::fast_matrix_cm K; 42 | 43 | I = etl::sequence_generator(10.0) * 4.0; 44 | K = etl::sequence_generator(2.0) * 0.3; 45 | 46 | etl::fast_matrix ref; 47 | etl::fast_matrix c; 48 | 49 | SELECTED_SECTION(etl::conv_impl::STD) { 50 | ref = 0.0; 51 | for (size_t i = 0; i < etl::dim<0>(I); ++i) { 52 | for (size_t c = 0; c < 
etl::dim<1>(K); ++c) { 53 | for (size_t k = 0; k < etl::dim<0>(K); ++k) { 54 | ref(i)(k) += conv_2d_valid(I(i)(c), K(k)(c)); 55 | } 56 | } 57 | } 58 | } 59 | 60 | c = etl::conv_4d_valid(I, K); 61 | 62 | for (size_t i = 0; i < ref.size(); ++i) { 63 | REQUIRE_EQUALS_APPROX(c[i], ref[i]); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /test/src/conv_4d_full_mixed.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | #include "conv_test.hpp" 10 | 11 | ETL_TEST_CASE("conv/4d/full/mixed/0", "[conv][conv4][full]") { 12 | etl::fast_matrix I; 13 | etl::fast_matrix K; 14 | 15 | I = etl::sequence_generator(3.0) * 0.4; 16 | K = etl::sequence_generator(2.0) * 0.3; 17 | 18 | etl::fast_matrix ref; 19 | etl::fast_matrix c; 20 | 21 | SELECTED_SECTION(etl::conv_impl::STD) { 22 | ref = 0.0; 23 | for (size_t i = 0; i < etl::dim<0>(I); ++i) { 24 | for (size_t c = 0; c < etl::dim<1>(K); ++c) { 25 | for (size_t k = 0; k < etl::dim<0>(K); ++k) { 26 | ref(i)(c) += conv_2d_full(I(i)(k), K(k)(c)); 27 | } 28 | } 29 | } 30 | } 31 | 32 | c = conv_4d_full(I, K); 33 | 34 | for (size_t i = 0; i < ref.size(); ++i) { 35 | REQUIRE_EQUALS_APPROX_E(c[i], ref[i], base_eps * 100000); 36 | } 37 | } 38 | 39 | ETL_TEST_CASE("conv/4d/full/mixed/1", "[conv][conv4][full]") { 40 | etl::fast_matrix I; 41 | etl::fast_matrix_cm K; 42 | 43 | I = etl::sequence_generator(3.0) * 0.4; 44 | K = etl::sequence_generator(2.0) * 0.3; 45 | 46 | etl::fast_matrix ref; 47 | etl::fast_matrix c; 48 | 49 | SELECTED_SECTION(etl::conv_impl::STD) { 50 | ref = 0.0; 51 | for 
(size_t i = 0; i < etl::dim<0>(I); ++i) { 52 | for (size_t c = 0; c < etl::dim<1>(K); ++c) { 53 | for (size_t k = 0; k < etl::dim<0>(K); ++k) { 54 | ref(i)(c) += conv_2d_full(I(i)(k), K(k)(c)); 55 | } 56 | } 57 | } 58 | } 59 | 60 | c = conv_4d_full(I, K); 61 | 62 | for (size_t i = 0; i < ref.size(); ++i) { 63 | REQUIRE_EQUALS_APPROX_E(c[i], ref[i], base_eps * 100000); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /include/etl/op/binary_op.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief Contains binary operators 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | 17 | #include "etl/math.hpp" 18 | #include "etl/temporary.hpp" 19 | 20 | #ifdef ETL_CUBLAS_MODE 21 | #include "etl/impl/cublas/cuda.hpp" 22 | #include "etl/impl/cublas/cublas.hpp" 23 | #include "etl/impl/cublas/axpy.hpp" 24 | #include "etl/impl/cublas/scal.hpp" 25 | #endif 26 | 27 | #include "etl/impl/egblas/apxdbpy.hpp" 28 | #include "etl/impl/egblas/apxdbpy_3.hpp" 29 | #include "etl/impl/egblas/apxdby.hpp" 30 | #include "etl/impl/egblas/apxdby_3.hpp" 31 | #include "etl/impl/egblas/axdbpy.hpp" 32 | #include "etl/impl/egblas/axdbpy_3.hpp" 33 | #include "etl/impl/egblas/axdy.hpp" 34 | #include "etl/impl/egblas/axdy_3.hpp" 35 | #include "etl/impl/egblas/axmy.hpp" 36 | #include "etl/impl/egblas/axmy_3.hpp" 37 | #include "etl/impl/egblas/axpby.hpp" 38 | #include "etl/impl/egblas/axpby_3.hpp" 39 | #include "etl/impl/egblas/axpy.hpp" 40 | #include "etl/impl/egblas/axpy_3.hpp" 41 | 42 | #include "etl/impl/egblas/scalar_add.hpp" 43 | #include 
"etl/impl/egblas/scalar_div.hpp" 44 | #include "etl/impl/egblas/scalar_mul.hpp" 45 | 46 | #include "etl/op/binary/plus.hpp" 47 | #include "etl/op/binary/minus.hpp" 48 | #include "etl/op/binary/mul.hpp" 49 | #include "etl/op/binary/div.hpp" 50 | #include "etl/op/binary/mod.hpp" 51 | #include "etl/op/binary/equal.hpp" 52 | #include "etl/op/binary/not_equal.hpp" 53 | #include "etl/op/binary/less.hpp" 54 | #include "etl/op/binary/less_equal.hpp" 55 | #include "etl/op/binary/greater.hpp" 56 | #include "etl/op/binary/greater_equal.hpp" 57 | #include "etl/op/binary/logical_and.hpp" 58 | #include "etl/op/binary/logical_or.hpp" 59 | #include "etl/op/binary/logical_xor.hpp" 60 | #include "etl/op/binary/min.hpp" 61 | #include "etl/op/binary/max.hpp" 62 | #include "etl/op/binary/one_if.hpp" 63 | #include "etl/op/binary/ranged_noise.hpp" 64 | #include "etl/op/binary/sigmoid_derivative.hpp" 65 | #include "etl/op/binary/relu_derivative.hpp" 66 | #include "etl/op/binary/pow.hpp" 67 | -------------------------------------------------------------------------------- /include/etl/expr_fwd.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #pragma once 9 | 10 | namespace etl { 11 | 12 | template 13 | struct optimizable; 14 | 15 | template 16 | struct optimizer; 17 | 18 | template 19 | struct transformer; 20 | 21 | struct identity_op; 22 | 23 | struct transform_op; 24 | 25 | template 26 | struct stateful_op; 27 | 28 | template Expr, typename UnaryOp> 29 | struct unary_expr; 30 | 31 | template LeftExpr, typename BinaryOp, expr_or_scalar RightExpr> 32 | struct binary_expr; 33 | 34 | template 35 | class generator_expr; 36 | 37 | template 38 | struct optimized_expr; 39 | 40 | template 41 | struct serial_expr; 42 | 43 | template 44 | struct selected_expr; 45 | 46 | template 47 | struct parallel_expr; 48 | 49 | template 50 | struct timed_expr; 51 | 52 | template 53 | struct temporary_expr_bin; 54 | 55 | template 56 | struct dim_view; 57 | 58 | template 59 | struct sub_view; 60 | 61 | template 62 | struct sub_matrix_2d; 63 | 64 | template 65 | struct sub_matrix_3d; 66 | 67 | template 68 | struct sub_matrix_4d; 69 | 70 | template 71 | struct slice_view; 72 | 73 | template 74 | struct memory_slice_view; 75 | 76 | template 77 | struct fast_matrix_view; 78 | 79 | template 80 | struct dyn_matrix_view; 81 | 82 | template 83 | struct transpose_expr; 84 | 85 | template 86 | struct transpose_front_expr; 87 | 88 | template 89 | struct base_temporary_expr; 90 | 91 | template 92 | struct scalar; 93 | 94 | } //end of namespace etl 95 | -------------------------------------------------------------------------------- /test/src/timed.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test_light.hpp" 9 | 10 | namespace { 11 | // Returns true if str begins with search; safe even when str is shorter than search 12 | bool starts_with(const std::string& str, const std::string& search) { 13 | return str.compare(0, search.size(), search) == 0; 14 | } 15 | 16 | } //end of anonymous namespace 17 | 18 | TEMPLATE_TEST_CASE_2("timed/1", "[fast][serial]", Z, float, double) { 19 | std::stringstream buffer; 20 | auto* old = std::cout.rdbuf(buffer.rdbuf()); 21 | 22 | etl::fast_vector a({1.0, -2.0, 3.0}); 23 | etl::fast_vector b; 24 | 25 | b = timed(a + a); 26 | 27 | auto text = buffer.str(); 28 | std::cout.rdbuf(old); 29 | 30 | REQUIRE_DIRECT(starts_with(text, "timed(=): (V[3] + V[3]) took ")); 31 | REQUIRE_EQUALS(std::string(text.end() - 3, text.end() - 1), "ns"); 32 | 33 | REQUIRE_EQUALS(b[0], 2.0); 34 | } 35 | 36 | TEMPLATE_TEST_CASE_2("timed/2", "[dyn][serial]", Z, float, double) { 37 | std::stringstream buffer; 38 | auto* old = std::cout.rdbuf(buffer.rdbuf()); 39 | 40 | etl::dyn_vector a(10000); 41 | etl::dyn_vector b(10000); 42 | 43 | a = 1.0; 44 | b = 2.0; 45 | 46 | b = timed(a + b); 47 | 48 | auto text = buffer.str(); 49 | std::cout.rdbuf(old); 50 | 51 | REQUIRE_DIRECT(starts_with(text, "timed(=): (V[10000] + V[10000]) took ")); 52 | REQUIRE_EQUALS(std::string(text.end() - 3, text.end() - 1), "ns"); 53 | 54 | REQUIRE_EQUALS(b[0], 3.0); 55 | } 56 | 57 | TEMPLATE_TEST_CASE_2("timed/3", "[dyn][serial]", Z, float, double) { 58 | std::stringstream buffer; 59 | auto* old = std::cout.rdbuf(buffer.rdbuf()); 60 | 61 | etl::dyn_vector a(10000); 62 | etl::dyn_vector b(10000); 63 | 64 | a = 1.0; 65 | 66 | b = etl::timed_res(a + a); 67 | 68 | auto text = buffer.str(); 69 | std::cout.rdbuf(old); 70 | 71 | REQUIRE_DIRECT(starts_with(text, "timed(=): (V[10000] + V[10000]) took ")); 72 | REQUIRE_EQUALS(std::string(text.end() - 3,
text.end() - 1), "ms"); 73 | 74 | REQUIRE_EQUALS(b[0], 2.0); 75 | } 76 | -------------------------------------------------------------------------------- /test/src/assert.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #ifndef NDEBUG 9 | 10 | #include "test_light.hpp" 11 | 12 | ETL_TEST_CASE("assert/nothrow/1", "[assert]") { 13 | #ifdef CPP_UTILS_ASSERT_EXCEPTION 14 | REQUIRE_DIRECT(!etl::assert_nothrow); 15 | #else 16 | REQUIRE_DIRECT(etl::assert_nothrow); 17 | #endif 18 | } 19 | 20 | #ifdef CPP_UTILS_ASSERT_EXCEPTION 21 | ETL_TEST_CASE("assert/sizes/1", "[assert]") { 22 | etl::dyn_vector a = {-1.0, 2.0, 5.0}; 23 | etl::dyn_vector b = {2.5, 3.0, 4.0, 1.0}; 24 | 25 | REQUIRE_THROWS(a + b); 26 | } 27 | 28 | ETL_TEST_CASE("assert/dim/1", "[assert]") { 29 | etl::fast_matrix matrix; 30 | 31 | REQUIRE_NOTHROW(matrix(1, 1, 1)); 32 | REQUIRE_THROWS(matrix(3, 2, 1)); 33 | REQUIRE_THROWS(matrix(2, 2, 1)); 34 | REQUIRE_THROWS(matrix(1, 5, 1)); 35 | REQUIRE_THROWS(matrix(1, 1, 5)); 36 | REQUIRE_THROWS(matrix(1, 1, 4)); 37 | REQUIRE_THROWS(matrix(3, 3, 4)); 38 | } 39 | 40 | ETL_TEST_CASE("assert/dim/2", "[assert]") { 41 | etl::fast_dyn_matrix matrix; 42 | 43 | REQUIRE_NOTHROW(matrix(1, 1, 1)); 44 | REQUIRE_THROWS(matrix(3, 2, 1)); 45 | REQUIRE_THROWS(matrix(2, 2, 1)); 46 | REQUIRE_THROWS(matrix(1, 5, 1)); 47 | REQUIRE_THROWS(matrix(1, 1, 5)); 48 | REQUIRE_THROWS(matrix(1, 1, 4)); 49 | REQUIRE_THROWS(matrix(3, 3, 4)); 50 | } 51 | 52 | ETL_TEST_CASE("assert/dim/3", "[assert]") { 53 | etl::dyn_matrix matrix(2, 3, 4); 54 | 55 | REQUIRE_NOTHROW(matrix(1, 1, 1)); 56 | REQUIRE_THROWS(matrix(3, 2, 
1)); 57 | REQUIRE_THROWS(matrix(2, 2, 1)); 58 | REQUIRE_THROWS(matrix(1, 5, 1)); 59 | REQUIRE_THROWS(matrix(1, 1, 5)); 60 | REQUIRE_THROWS(matrix(1, 1, 4)); 61 | REQUIRE_THROWS(matrix(3, 3, 4)); 62 | } 63 | 64 | ETL_TEST_CASE("assert/dim/4", "[assert]") { 65 | etl::sparse_matrix matrix(2, 3); 66 | 67 | REQUIRE_NOTHROW(matrix(1, 1)); 68 | REQUIRE_THROWS(matrix(3, 2)); 69 | REQUIRE_THROWS(matrix(2, 2)); 70 | REQUIRE_THROWS(matrix(1, 5)); 71 | REQUIRE_THROWS(matrix(3, 5)); 72 | } 73 | #endif 74 | 75 | #endif 76 | -------------------------------------------------------------------------------- /test/src/iterators.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | 10 | TEMPLATE_TEST_CASE_2("iterator/1", "[iterator]", Z, float, double) { 11 | etl::fast_matrix M(5.5); 12 | 13 | REQUIRE((std::is_same_v)); 14 | REQUIRE((std::is_same_v)); 15 | 16 | REQUIRE((std::is_same_v)); 17 | REQUIRE((std::is_same_v)); 18 | } 19 | 20 | TEMPLATE_TEST_CASE_2("iterator/2", "[iterator]", Z, float, double) { 21 | etl::dyn_matrix M(2,3,4); 22 | 23 | REQUIRE((std::is_same_v)); 24 | REQUIRE((std::is_same_v)); 25 | 26 | REQUIRE((std::is_same_v)); 27 | REQUIRE((std::is_same_v)); 28 | } 29 | 30 | TEMPLATE_TEST_CASE_2("iterator/3", "[iterator]", Z, float, double) { 31 | etl::dyn_matrix M(2,3,4); 32 | 33 | auto A = M(0); 34 | 35 | REQUIRE((std::is_same_v)); 36 | REQUIRE((std::is_same_v)); 37 | 38 | REQUIRE((std::is_same_v)); 39 | REQUIRE((std::is_same_v)); 40 | } 41 | 42 | TEMPLATE_TEST_CASE_2("iterator/4", "[iterator]", Z, float, double) { 43 | etl::dyn_matrix M(3,3); 44 | 45 | auto A = M * 
// Macro to handle noexcept and cpp_assert

namespace etl {

/*!
 * \brief Indicates if functions containing assertions can be marked noexcept.
 *
 * The flag is false only when assertions are enabled (NDEBUG is not defined)
 * and CPP_UTILS_ASSERT_EXCEPTION is defined. In every other configuration it
 * is true.
 */
#if !defined(NDEBUG) && defined(CPP_UTILS_ASSERT_EXCEPTION)
constexpr bool assert_nothrow = false;
#else
constexpr bool assert_nothrow = true;
#endif

/*!
 * \brief Alignment flag to aligned expressions
 *
 * This can be used to make expressions more clear.
 *
 * NOTE(review): this flag used to have the same value as etl::unaligned
 * (false), which made the two flags indistinguishable. Confirm that no call
 * site relied on the previous value.
 */
constexpr bool aligned = true;

/*!
 * \brief Alignment flag to unaligned expressions.
 *
 * This can be used to make expressions more clear.
 */
constexpr bool unaligned = false;

/*!
 * \brief The current major version number of the library
 */
constexpr size_t version_major = 1;

/*!
 * \brief The current minor version number of the library
 */
constexpr size_t version_minor = 3;

/*!
 * \brief The current revision version number of the library
 */
constexpr size_t version_revision = 0;

} //end of namespace etl

/*!
 * \brief String representation of the current version of the library.
 */
#define ETL_VERSION_STR "1.3.0"

/*!
 * \brief The current major version number of the library
 */
#define ETL_VERSION_MAJOR 1

/*!
 * \brief The current minor version number of the library
 */
#define ETL_VERSION_MINOR 3

/*!
 * \brief The current revision version number of the library
 */
#define ETL_VERSION_REVISION 0
/*!
 * \brief Construct a textual representation of the matrix contents
 *
 * Expressions with more than one dimension are printed one sub-expression
 * per line, enclosed in brackets. One-dimensional expressions are delegated
 * to to_octave.
 *
 * \param m The expression to transform
 * \return a string representing the contents of the expression
 */
template <typename T>
std::string to_string(T&& m) {
    if constexpr (decay_traits<T>::dimensions() > 1) {
        etl::force(m);

        const size_t rows = etl::dim<0>(m);

        std::string v = "[";
        for (size_t i = 0; i < rows; ++i) {
            v += to_string(sub(m, i));

            // Separate the sub-expressions, but do not add a trailing newline
            if (i + 1 < rows) {
                v += "\n";
            }
        }
        v += "]";
        return v;
    } else {
        return to_octave(m);
    }
}

/*!
 * \brief Construct a textual representation of the matrix contents, following the octave format
 *
 * Rows are separated by ';' and the elements of a row by ','. Floating point
 * values are printed with six decimals.
 *
 * \tparam Sub Indicates if this call prints a sub-expression of a larger
 *             expression, in which case the enclosing brackets are omitted.
 * \param m The expression to transform
 * \return a string representing the contents of the expression
 */
template <bool Sub = false, typename T>
std::string to_octave(T&& m) {
    // NOTE(review): the template arguments in this function (<bool Sub, typename T>,
    // decay_traits<T>, to_octave<true>, is_floating<T>) were missing from the
    // reviewed text and have been restored -- confirm against the repository.
    etl::force(m);

    std::string v;

    if constexpr (!Sub) {
        v = "[";
    }

    if constexpr (decay_traits<T>::dimensions() > 1) {
        const size_t rows = etl::dim<0>(m);

        for (size_t i = 0; i < rows; ++i) {
            v += to_octave<true>(sub(m, i));

            if (i + 1 < rows) {
                v += ";";
            }
        }
    } else {
        std::string comma;
        for (size_t j = 0; j < etl::dim<0>(m); ++j) {
            if constexpr (is_floating<T>) {
                v += std::format("{}{:.6f}", comma, m(j));
            } else {
                // The separator must be emitted for non-floating values as well
                v += comma + std::to_string(m(j));
            }
            comma = ",";
        }
    }

    if constexpr (!Sub) {
        v += "]";
    }

    return v;
}

} //end of namespace etl
SortIncludes: false 14 | 15 | # Don't fuck with my using declarations either 16 | SortUsingDeclarations: false 17 | 18 | # Tune some indentations 19 | AccessModifierOffset: -4 20 | ConstructorInitializerIndentWidth: 8 21 | ContinuationIndentWidth: 8 22 | 23 | # General Space Configuration 24 | SpaceBeforeParens: ControlStatements 25 | SpaceBeforeAssignmentOperators: true 26 | SpaceAfterTemplateKeyword: true 27 | SpaceAfterCStyleCast: true 28 | SpaceInEmptyParentheses: false 29 | SpacesInAngles: false 30 | SpacesInCStyleCastParentheses: false 31 | SpacesInContainerLiterals: false 32 | SpacesInParentheses: false 33 | SpacesInSquareBrackets: false 34 | 35 | # No block and its body should EVER be on a single line 36 | AllowShortFunctionsOnASingleLine: Empty 37 | AllowShortBlocksOnASingleLine: false 38 | AllowShortCaseLabelsOnASingleLine: false 39 | AllowShortIfStatementsOnASingleLine: false 40 | AllowShortLoopsOnASingleLine: false 41 | 42 | # Functions 43 | IndentWrappedFunctionNames: false 44 | AlwaysBreakAfterReturnType: None 45 | BreakConstructorInitializers: AfterColon 46 | 47 | # Switch 48 | IndentCaseLabels: false 49 | 50 | # Better C++11 support 51 | Cpp11BracedListStyle: true 52 | 53 | # Avoid too many empty lines 54 | MaxEmptyLinesToKeep: 1 55 | KeepEmptyLinesAtTheStartOfBlocks: false 56 | 57 | # Templates should always be on a separate line 58 | AlwaysBreakTemplateDeclarations: true 59 | 60 | # Nice alignment 61 | AlignConsecutiveAssignments: true 62 | AlignConsecutiveDeclarations: true 63 | AlignOperands: true 64 | AlignAfterOpenBracket: Align 65 | BinPackParameters: false 66 | BinPackArguments: false 67 | 68 | # Improve ternary operators alignment 69 | BreakBeforeTernaryOperators: true 70 | BreakBeforeBinaryOperators: NonAssignment 71 | 72 | # Tabs are bad news 73 | UseTab: Never 74 | 75 | # Configure comments 76 | AlignTrailingComments: true 77 | SpacesBeforeTrailingComments: 1 78 | 79 | # Avoid empty lines 80 | KeepEmptyLinesAtTheStartOfBlocks: false 81 
| 82 | # Force pointers to the type 83 | DerivePointerAlignment: false 84 | PointerAlignment: Left 85 | 86 | # Strings 87 | BreakStringLiterals: false 88 | 89 | # Namespaces 90 | NamespaceIndentation: None 91 | CompactNamespaces: false 92 | FixNamespaceComments: true 93 | -------------------------------------------------------------------------------- /test/src/gemv_types.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | #include "etl/stop.hpp" 10 | 11 | #include "mmul_test.hpp" 12 | 13 | // Matrix-Vector Multiplication with mixed types 14 | 15 | // GEMV 16 | 17 | ETL_TEST_CASE("gemv/types/0", "[gemv]") { 18 | etl::fast_matrix a = {1, 2, 3, 4, 5, 6}; 19 | etl::fast_vector b = {7, 8, 9}; 20 | etl::fast_matrix c; 21 | 22 | c = a * b; 23 | 24 | REQUIRE_EQUALS(c(0), float(50)); 25 | REQUIRE_EQUALS(c(1), float(122)); 26 | } 27 | 28 | ETL_TEST_CASE("gemv/types/1", "[gemv]") { 29 | etl::fast_matrix a = {1, 2, 3, 4, 5, 6}; 30 | etl::fast_vector b = {7, 8, 9}; 31 | etl::fast_matrix c; 32 | 33 | c = a * b; 34 | 35 | REQUIRE_EQUALS(c(0), double(50)); 36 | REQUIRE_EQUALS(c(1), double(122)); 37 | } 38 | 39 | ETL_TEST_CASE("gemv/types/2", "[gemv]") { 40 | etl::fast_matrix a = {1, 2, 3, 4, 5, 6}; 41 | etl::fast_vector b = {7, 8, 9}; 42 | etl::fast_matrix c; 43 | 44 | c = a * b; 45 | 46 | REQUIRE_EQUALS(c(0), float(50)); 47 | REQUIRE_EQUALS(c(1), float(122)); 48 | } 49 | 50 | // GEMV_T 51 | 52 | ETL_TEST_CASE("gemv_t/types/0", "[gemv]") { 53 | etl::fast_matrix a = {1, 4, 2, 5, 3, 6}; 54 | etl::fast_vector b = {7, 8, 9}; 55 | etl::fast_matrix c; 56 | 57 | c = trans(a) * b; 
58 | 59 | REQUIRE_EQUALS(c(0), float(50)); 60 | REQUIRE_EQUALS(c(1), float(122)); 61 | } 62 | 63 | ETL_TEST_CASE("gemv_t/types/1", "[gemv]") { 64 | etl::fast_matrix a = {1, 4, 2, 5, 3, 6}; 65 | etl::fast_vector b = {7, 8, 9}; 66 | etl::fast_matrix c; 67 | 68 | c = trans(a) * b; 69 | 70 | REQUIRE_EQUALS(c(0), double(50)); 71 | REQUIRE_EQUALS(c(1), double(122)); 72 | } 73 | 74 | ETL_TEST_CASE("gemv_t/types/2", "[gemv]") { 75 | etl::fast_matrix a = {1, 4, 2, 5, 3, 6}; 76 | etl::fast_vector b = {7, 8, 9}; 77 | etl::fast_matrix c; 78 | 79 | c = trans(a) * b; 80 | 81 | REQUIRE_EQUALS(c(0), float(50)); 82 | REQUIRE_EQUALS(c(1), float(122)); 83 | } 84 | -------------------------------------------------------------------------------- /include/etl/impl/egblas/scalar_set.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief EGBLAS wrappers for the scalar_set operation. 11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef ETL_EGBLAS_MODE 16 | 17 | #include "etl/impl/cublas/cuda.hpp" 18 | 19 | #include 20 | 21 | #endif 22 | 23 | namespace etl::impl::egblas { 24 | 25 | #ifdef EGBLAS_HAS_SCALAR_SSET 26 | 27 | static constexpr bool has_scalar_sset = true; 28 | 29 | /*! 
/*!
 * \brief sets the scalar beta to each element of the single-precision vector x
 *
 * This overload only exists when EGBLAS_HAS_SCALAR_SSET is defined. It
 * increments the "egblas" counter and then forwards its arguments unchanged
 * to egblas_scalar_sset.
 *
 * \param x The vector to set the scalar to (GPU pointer)
 * \param n The size of the vector
 * \param s The stride of the vector
 * \param beta The scalar to set
 */
inline void scalar_set(float* x, size_t n, size_t s, const float beta) {
    inc_counter("egblas");
    egblas_scalar_sset(x, n, s, beta);
}

#else

// No single-precision scalar_set kernel is available
static constexpr bool has_scalar_sset = false;

#endif

#ifdef EGBLAS_HAS_SCALAR_DSET

// A double-precision scalar_set kernel is available
static constexpr bool has_scalar_dset = true;

/*!
 * \brief sets the scalar beta to each element of the double-precision vector x
 *
 * This overload only exists when EGBLAS_HAS_SCALAR_DSET is defined. It
 * increments the "egblas" counter and then forwards its arguments unchanged
 * to egblas_scalar_dset.
 *
 * \param x The vector to set the scalar to (GPU pointer)
 * \param n The size of the vector
 * \param s The stride of the vector
 * \param beta The scalar to set
 */
inline void scalar_set(double* x, size_t n, size_t s, const double beta) {
    inc_counter("egblas");
    egblas_scalar_dset(x, n, s, beta);
}

#else

// No double-precision scalar_set kernel is available
static constexpr bool has_scalar_dset = false;

#endif

#ifndef ETL_EGBLAS_MODE
72 | * \brief sets the scalar beta to each element of the single-precision vector x 73 | * \param x The vector to set the scalar to (GPU pointer) 74 | * \param n The size of the vector 75 | * \param s The stride of the vector 76 | * \param beta The scalar to set 77 | */ 78 | template 79 | inline void scalar_set([[maybe_unused]] T* x, [[maybe_unused]] size_t n, [[maybe_unused]] size_t s, [[maybe_unused]] const T beta) { 80 | cpp_unreachable("Invalid call to egblas::scalar_set"); 81 | } 82 | 83 | #endif 84 | 85 | } //end of namespace etl::impl::egblas 86 | -------------------------------------------------------------------------------- /test/src/gevm_types.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | #include "etl/stop.hpp" 10 | 11 | #include "mmul_test.hpp" 12 | 13 | // Vector-Matrix Multiplication with mixed types 14 | 15 | // GEVM 16 | 17 | ETL_TEST_CASE("gevm/types/0", "[gevm]") { 18 | etl::fast_matrix a = {1, 2, 3, 4, 5, 6}; 19 | etl::fast_vector b = {7, 8, 9}; 20 | etl::fast_matrix c; 21 | 22 | c = b * a; 23 | 24 | REQUIRE_EQUALS(c(0), float(76)); 25 | REQUIRE_EQUALS(c(1), float(100)); 26 | } 27 | 28 | ETL_TEST_CASE("gevm/types/1", "[gevm]") { 29 | etl::fast_matrix a = {1, 2, 3, 4, 5, 6}; 30 | etl::fast_vector b = {7, 8, 9}; 31 | etl::fast_matrix c; 32 | 33 | c = b * a; 34 | 35 | REQUIRE_EQUALS(c(0), float(76)); 36 | REQUIRE_EQUALS(c(1), float(100)); 37 | } 38 | 39 | ETL_TEST_CASE("gevm/types/2", "[gevm]") { 40 | etl::fast_matrix a = {1, 2, 3, 4, 5, 6}; 41 | etl::fast_vector b = {7, 8, 9}; 42 | etl::fast_matrix c; 43 | 44 | c = b * a; 45 | 46 | 
REQUIRE_EQUALS(c(0), double(76)); 47 | REQUIRE_EQUALS(c(1), double(100)); 48 | } 49 | 50 | // GEVM_T 51 | 52 | ETL_TEST_CASE("gevm_t/types/0", "[gevm]") { 53 | etl::fast_matrix a = {1, 3, 5, 2, 4, 6}; 54 | etl::fast_vector b = {7, 8, 9}; 55 | etl::fast_matrix c; 56 | 57 | c = b * trans(a); 58 | 59 | REQUIRE_EQUALS(c(0), float(76)); 60 | REQUIRE_EQUALS(c(1), float(100)); 61 | } 62 | 63 | ETL_TEST_CASE("gevm_t/types/1", "[gevm]") { 64 | etl::fast_matrix a = {1, 3, 5, 2, 4, 6}; 65 | etl::fast_vector b = {7, 8, 9}; 66 | etl::fast_matrix c; 67 | 68 | c = b * trans(a); 69 | 70 | REQUIRE_EQUALS(c(0), float(76)); 71 | REQUIRE_EQUALS(c(1), float(100)); 72 | } 73 | 74 | ETL_TEST_CASE("gevm_t/types/2", "[gevm]") { 75 | etl::fast_matrix a = {1, 3, 5, 2, 4, 6}; 76 | etl::fast_vector b = {7, 8, 9}; 77 | etl::fast_matrix c; 78 | 79 | c = b * trans(a); 80 | 81 | REQUIRE_EQUALS(c(0), double(76)); 82 | REQUIRE_EQUALS(c(1), double(100)); 83 | } 84 | -------------------------------------------------------------------------------- /benchmark/src/benchmark_batch_hint.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 
4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #define CPM_LIB 9 | #include "benchmark.hpp" 10 | 11 | // 2D batch_hint 12 | CPM_BENCH() { 13 | CPM_TWO_PASS_NS_P( 14 | batch_hint_2d_policy, 15 | "R = batch_hint_2d(gamma >> hint) (s) [batch_hint][s]", 16 | [](auto B, auto I){ return std::make_tuple(smat(B, I), svec(I), smat(B, I)); }, 17 | [](smat& lhs, svec& gamma, smat& input){ lhs = batch_hint(gamma >> input); } 18 | ); 19 | 20 | CPM_TWO_PASS_NS_P( 21 | batch_hint_2d_policy, 22 | "R = batch_hint_2d((gamma >> hint) + beta) (s) [batch_hint][s]", 23 | [](auto B, auto I){ return std::make_tuple(smat(B, I), svec(I), svec(I), smat(B, I)); }, 24 | [](smat& lhs, svec& gamma, svec& beta, smat& input){ lhs = batch_hint((gamma >> input) + beta); } 25 | ); 26 | 27 | CPM_TWO_PASS_NS_P( 28 | batch_hint_2d_policy, 29 | "R = batch_hint_2d(gamma >> (hint - beta) (s) [batch_hint][s]", 30 | [](auto B, auto I){ return std::make_tuple(smat(B, I), svec(I), svec(I), smat(B, I)); }, 31 | [](smat& lhs, svec& gamma, svec& beta, smat& input){ lhs = batch_hint(gamma >> (input - beta)); } 32 | ); 33 | } 34 | 35 | // 4D batch_hint 36 | CPM_BENCH() { 37 | CPM_TWO_PASS_NS_P( 38 | batch_hint_4d_policy, 39 | "R = batch_hint_4d(gamma >> hint) (s) [batch_hint][s]", 40 | [](auto B, auto K, auto M, auto N){ return std::make_tuple(smat4(B, K, M, N), svec(K), smat4(B, K, M, N)); }, 41 | [](smat4& lhs, svec& gamma, smat4& input){ lhs = batch_hint(gamma >> input); } 42 | ); 43 | 44 | CPM_TWO_PASS_NS_P( 45 | batch_hint_4d_policy, 46 | "R = batch_hint_4d((gamma >> hint) + beta) (s) [batch_hint][s]", 47 | [](auto B, auto K, auto M, auto N){ return std::make_tuple(smat4(B, K, M, N), svec(K), svec(K), smat4(B, K, M, N)); }, 48 | [](smat4& lhs, svec& gamma, svec& beta, smat4& input){ lhs = batch_hint((gamma >> input) + beta); } 49 | ); 50 | 51 | CPM_TWO_PASS_NS_P( 52 | 
batch_hint_4d_policy, 53 | "R = batch_hint_4d(gamma >> (hint - beta) (s) [batch_hint][s]", 54 | [](auto B, auto K, auto M, auto N){ return std::make_tuple(smat4(B, K, M, N), svec(K), svec(K), smat4(B, K, M, N)); }, 55 | [](smat4& lhs, svec& gamma, svec& beta, smat4& input){ lhs = batch_hint(gamma >> (input - beta)); } 56 | ); 57 | } 58 | -------------------------------------------------------------------------------- /include/etl/impl/egblas/sigmoid.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief EGBLAS wrappers for the sigmoid operation. 11 | */ 12 | 13 | #pragma once 14 | 15 | #ifdef ETL_EGBLAS_MODE 16 | 17 | #include "etl/impl/cublas/cuda.hpp" 18 | 19 | #include 20 | 21 | #endif 22 | 23 | namespace etl::impl::egblas { 24 | 25 | #ifdef EGBLAS_HAS_SSIGMOID 26 | static constexpr bool has_ssigmoid = true; 27 | #else 28 | static constexpr bool has_ssigmoid = false; 29 | #endif 30 | 31 | /*! 
/*!
 * \brief Wrappers for single-precision egblas sigmoid operation
 *
 * When EGBLAS_HAS_SSIGMOID is defined, the call increments the "egblas"
 * counter and forwards all arguments unchanged to egblas_ssigmoid. Otherwise
 * the function must not be called: it ends in cpp_unreachable.
 *
 * \param n The size of the vector
 * \param alpha The scaling factor alpha
 * \param A The memory of the vector a
 * \param lda The leading dimension of a
 * \param B The memory of the vector b
 * \param ldb The leading dimension of b
 */
inline void sigmoid([[maybe_unused]] size_t n,
                    [[maybe_unused]] float alpha,
                    [[maybe_unused]] float* A,
                    [[maybe_unused]] size_t lda,
                    [[maybe_unused]] float* B,
                    [[maybe_unused]] size_t ldb) {
#ifdef EGBLAS_HAS_SSIGMOID
    inc_counter("egblas");
    egblas_ssigmoid(n, alpha, A, lda, B, ldb);
#else
    cpp_unreachable("Invalid call to egblas::sigmoid");
#endif
}

// Compile-time flag telling whether the double-precision kernel is available
#ifdef EGBLAS_HAS_DSIGMOID
static constexpr bool has_dsigmoid = true;
#else
static constexpr bool has_dsigmoid = false;
#endif

/*!
 * \brief Wrappers for double-precision egblas sigmoid operation
 *
 * When EGBLAS_HAS_DSIGMOID is defined, the call increments the "egblas"
 * counter and forwards all arguments unchanged to egblas_dsigmoid. Otherwise
 * the function must not be called: it ends in cpp_unreachable.
 *
 * \param n The size of the vector
 * \param alpha The scaling factor alpha
 * \param A The memory of the vector a
 * \param lda The leading dimension of a
 * \param B The memory of the vector b
 * \param ldb The leading dimension of b
 */
inline void sigmoid([[maybe_unused]] size_t n,
                    [[maybe_unused]] double alpha,
                    [[maybe_unused]] double* A,
                    [[maybe_unused]] size_t lda,
                    [[maybe_unused]] double* B,
                    [[maybe_unused]] size_t ldb) {
#ifdef EGBLAS_HAS_DSIGMOID
    inc_counter("egblas");
    egblas_dsigmoid(n, alpha, A, lda, B, ldb);
#else
    cpp_unreachable("Invalid call to egblas::sigmoid");
#endif
}

} //end of namespace etl::impl::egblas
| // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | #include "bias_test.hpp" 10 | 11 | // Tests for bias_add 12 | 13 | BIAS_ADD_4D_TEST_CASE("bias_add/0", "[bias_add]") { 14 | etl::fast_matrix a({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); 15 | etl::fast_matrix b{1, 2, 3}; 16 | etl::fast_matrix c; 17 | 18 | Impl::apply(a, b, c); 19 | 20 | REQUIRE_EQUALS(c(0, 0, 0, 0), T(a(0, 0, 0, 0) + 1)); 21 | REQUIRE_EQUALS(c(0, 0, 0, 1), T(a(0, 0, 0, 1) + 1)); 22 | REQUIRE_EQUALS(c(0, 0, 1, 0), T(a(0, 0, 1, 0) + 1)); 23 | REQUIRE_EQUALS(c(0, 0, 1, 1), T(a(0, 0, 1, 1) + 1)); 24 | 25 | REQUIRE_EQUALS(c(0, 1, 0, 0), T(a(0, 1, 0, 0) + 2)); 26 | REQUIRE_EQUALS(c(0, 1, 0, 1), T(a(0, 1, 0, 1) + 2)); 27 | REQUIRE_EQUALS(c(0, 1, 1, 0), T(a(0, 1, 1, 0) + 2)); 28 | REQUIRE_EQUALS(c(0, 1, 1, 1), T(a(0, 1, 1, 1) + 2)); 29 | 30 | REQUIRE_EQUALS(c(0, 2, 0, 0), T(a(0, 2, 0, 0) + 3)); 31 | REQUIRE_EQUALS(c(0, 2, 0, 1), T(a(0, 2, 0, 1) + 3)); 32 | REQUIRE_EQUALS(c(0, 2, 1, 0), T(a(0, 2, 1, 0) + 3)); 33 | REQUIRE_EQUALS(c(0, 2, 1, 1), T(a(0, 2, 1, 1) + 3)); 34 | 35 | REQUIRE_EQUALS(c(1, 0, 0, 0), T(a(1, 0, 0, 0) + 1)); 36 | REQUIRE_EQUALS(c(1, 0, 0, 1), T(a(1, 0, 0, 1) + 1)); 37 | REQUIRE_EQUALS(c(1, 0, 1, 0), T(a(1, 0, 1, 0) + 1)); 38 | REQUIRE_EQUALS(c(1, 0, 1, 1), T(a(1, 0, 1, 1) + 1)); 39 | 40 | REQUIRE_EQUALS(c(1, 1, 0, 0), T(a(1, 1, 0, 0) + 2)); 41 | REQUIRE_EQUALS(c(1, 1, 0, 1), T(a(1, 1, 0, 1) + 2)); 42 | REQUIRE_EQUALS(c(1, 1, 1, 0), T(a(1, 1, 1, 0) + 2)); 43 | REQUIRE_EQUALS(c(1, 1, 1, 1), T(a(1, 1, 1, 1) + 2)); 44 | 45 | REQUIRE_EQUALS(c(1, 2, 0, 0), T(a(1, 2, 0, 0) + 3)); 46 | REQUIRE_EQUALS(c(1, 2, 0, 1), T(a(1, 2, 0, 1) + 3)); 47 | REQUIRE_EQUALS(c(1, 2, 1, 0), T(a(1, 2, 1, 0) + 3)); 48 | 
REQUIRE_EQUALS(c(1, 2, 1, 1), T(a(1, 2, 1, 1) + 3)); 49 | } 50 | 51 | BIAS_ADD_2D_TEST_CASE("bias_add/1", "[bias_add]") { 52 | etl::fast_matrix a({1, 2, 3, 4, 5, 6}); 53 | etl::fast_matrix b{1, 2, 3}; 54 | etl::fast_matrix c; 55 | 56 | Impl::apply(a, b, c); 57 | 58 | REQUIRE_EQUALS(c(0, 0), T(a(0, 0) + 1)); 59 | REQUIRE_EQUALS(c(0, 1), T(a(0, 1) + 2)); 60 | REQUIRE_EQUALS(c(0, 2), T(a(0, 2) + 3)); 61 | 62 | REQUIRE_EQUALS(c(1, 0), T(a(1, 0) + 1)); 63 | REQUIRE_EQUALS(c(1, 1), T(a(1, 1) + 2)); 64 | REQUIRE_EQUALS(c(1, 2), T(a(1, 2) + 3)); 65 | } 66 | -------------------------------------------------------------------------------- /test/src/transpose_front.cpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | #include "test.hpp" 9 | 10 | TEMPLATE_TEST_CASE_2("transpose_front/0", "[transpose_front]", Z, float, double) { 11 | etl::fast_matrix a; 12 | etl::fast_matrix b; 13 | etl::fast_matrix ref; 14 | 15 | a = etl::sequence_generator(5) * 0.2; 16 | 17 | b = transpose_front(a); 18 | 19 | for (size_t i = 0; i < etl::dim<0>(a); ++i) { 20 | for (size_t j = 0; j < etl::dim<1>(a); ++j) { 21 | ref(j)(i) = a(i)(j); 22 | } 23 | } 24 | 25 | for (size_t i = 0; i < etl::size(a); ++i) { 26 | REQUIRE(ref[i] == b[i]); 27 | } 28 | } 29 | 30 | TEMPLATE_TEST_CASE_2("transpose_front/1", "[transpose_front]", Z, float, double) { 31 | etl::dyn_matrix a(3, 2, 5); 32 | etl::dyn_matrix b(2, 3, 5); 33 | etl::dyn_matrix ref(2, 3, 5); 34 | 35 | a = etl::sequence_generator(3) * 0.1; 36 | 37 | b = transpose_front(a); 38 | 39 | for (size_t i = 0; i < etl::dim<0>(a); ++i) { 40 | for (size_t j = 0; j < etl::dim<1>(a); 
++j) { 41 | ref(j)(i) = a(i)(j); 42 | } 43 | } 44 | 45 | for (size_t i = 0; i < etl::size(a); ++i) { 46 | REQUIRE(ref[i] == b[i]); 47 | } 48 | } 49 | 50 | TEMPLATE_TEST_CASE_2("transpose_front/2", "[transpose_front]", Z, float, double) { 51 | etl::fast_matrix a; 52 | etl::fast_matrix b; 53 | etl::fast_matrix ref; 54 | 55 | a = etl::sequence_generator(5) * 0.2; 56 | 57 | b = transpose_front(a); 58 | 59 | for (size_t i = 0; i < etl::dim<0>(a); ++i) { 60 | for (size_t j = 0; j < etl::dim<1>(a); ++j) { 61 | ref(j)(i) = a(i)(j); 62 | } 63 | } 64 | 65 | for (size_t i = 0; i < etl::size(a); ++i) { 66 | REQUIRE(ref[i] == b[i]); 67 | } 68 | } 69 | 70 | TEMPLATE_TEST_CASE_2("transpose_front/3", "[transpose_front]", Z, float, double) { 71 | etl::dyn_matrix a(7, 5, 15); 72 | etl::dyn_matrix b(5, 7, 15); 73 | etl::dyn_matrix ref(5, 7, 15); 74 | 75 | a = etl::sequence_generator(3) * 0.1; 76 | 77 | b = transpose_front(a); 78 | 79 | for (size_t i = 0; i < etl::dim<0>(a); ++i) { 80 | for (size_t j = 0; j < etl::dim<1>(a); ++j) { 81 | ref(j)(i) = a(i)(j); 82 | } 83 | } 84 | 85 | for (size_t i = 0; i < etl::size(a); ++i) { 86 | REQUIRE(ref[i] == b[i]); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /include/etl/impl/egblas/transpose_front.hpp: -------------------------------------------------------------------------------- 1 | //======================================================================= 2 | // Copyright (c) 2014-2023 Baptiste Wicht 3 | // Distributed under the terms of the MIT License. 4 | // (See accompanying file LICENSE or copy at 5 | // http://opensource.org/licenses/MIT) 6 | //======================================================================= 7 | 8 | /*! 9 | * \file 10 | * \brief EGBLAS wrappers for the transpose_front operation. 
namespace etl::impl::egblas {

// Compile-time flag telling whether the single-precision kernel is available
#ifdef EGBLAS_HAS_STRANSPOSE_FRONT
static constexpr bool has_stranspose_front = true;
#else
static constexpr bool has_stranspose_front = false;
#endif

/*!
 * \brief Wrappers for single-precision egblas transpose_front operation
 *
 * When EGBLAS_HAS_STRANSPOSE_FRONT is defined, the call increments the
 * "egblas" counter and forwards all arguments unchanged to
 * egblas_stranspose_front. Otherwise the function must not be called: it
 * ends in cpp_unreachable.
 *
 * \param m The first size passed to the kernel (presumably the first of the two transposed dimensions of A -- TODO confirm)
 * \param n The second size passed to the kernel (presumably the second of the two transposed dimensions of A -- TODO confirm)
 * \param k The third size passed to the kernel (presumably the size of the trailing, untransposed part -- TODO confirm)
 * \param A The memory of the input a
 * \param B The memory of the output b
 */
inline void transpose_front([[maybe_unused]] size_t m,
                            [[maybe_unused]] size_t n,
                            [[maybe_unused]] size_t k,
                            [[maybe_unused]] float* A,
                            [[maybe_unused]] float* B) {
#ifdef EGBLAS_HAS_STRANSPOSE_FRONT
    inc_counter("egblas");
    egblas_stranspose_front(m, n, k, A, B);
#else
    cpp_unreachable("Invalid call to egblas::transpose_front");
#endif
}

// Compile-time flag telling whether the double-precision kernel is available
#ifdef EGBLAS_HAS_DTRANSPOSE_FRONT
static constexpr bool has_dtranspose_front = true;
#else
static constexpr bool has_dtranspose_front = false;
#endif

/*!
 * \brief Wrappers for double-precision egblas transpose_front operation
 *
 * When EGBLAS_HAS_DTRANSPOSE_FRONT is defined, the call increments the
 * "egblas" counter and forwards all arguments unchanged to
 * egblas_dtranspose_front. Otherwise the function must not be called: it
 * ends in cpp_unreachable.
 *
 * \param m The first size passed to the kernel (presumably the first of the two transposed dimensions of A -- TODO confirm)
 * \param n The second size passed to the kernel (presumably the second of the two transposed dimensions of A -- TODO confirm)
 * \param k The third size passed to the kernel (presumably the size of the trailing, untransposed part -- TODO confirm)
 * \param A The memory of the input a
 * \param B The memory of the output b
 */
inline void transpose_front([[maybe_unused]] size_t m,
                            [[maybe_unused]] size_t n,
                            [[maybe_unused]] size_t k,
                            [[maybe_unused]] double* A,
                            [[maybe_unused]] double* B) {
#ifdef EGBLAS_HAS_DTRANSPOSE_FRONT
    inc_counter("egblas");
    egblas_dtranspose_front(m, n, k, A, B);
#else
    cpp_unreachable("Invalid call to egblas::transpose_front");
#endif
}

} //end of namespace etl::impl::egblas
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

#include "test_light.hpp"
#include "dot_test.hpp"

// NOTE(review): the template argument lists of the vector types and of
// sequence_generator below appear to have been lost during extraction --
// confirm against the repository before relying on these declarations.

// Fixed-size operands: -1*2 + 2*3 + 8.5*2 = 21.
DOT_TEST_CASE("dot/1", "[dot]") {
    etl::fast_vector a = {-1.0, 2.0, 8.5};
    etl::fast_vector b = {2.0, 3.0, 2.0};

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS(value, 21.0);
}

// Fixed-size operands, negated rhs: the result changes sign.
DOT_TEST_CASE("dot/2", "[dot]") {
    etl::fast_vector a = {-1.0, 2.0, 8.5};
    etl::fast_vector b = {-2.0, -3.0, -2.0};

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS(value, -21.0);
}

// Same values as dot/1, with dynamically-sized vectors.
DOT_TEST_CASE("dot/3", "[dot]") {
    etl::dyn_vector a({-1.0, 2.0, 8.5});
    etl::dyn_vector b({2.0, 3.0, 2.0});

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS(value, 21.0);
}

// Same values as dot/2, with dynamically-sized vectors.
DOT_TEST_CASE("dot/4", "[dot]") {
    etl::dyn_vector a({-1.0, 2.0, 8.5});
    etl::dyn_vector b({-2.0, -3.0, -2.0});

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS(value, -21.0);
}

// 15 elements. The expected value equals sum(i * (i + 1)) for i = 1..15,
// i.e. it is consistent with a = 1, 2, ... and b = 2, 3, ...
DOT_TEST_CASE("dot/5", "[dot]") {
    etl::dyn_vector a(15);
    etl::dyn_vector b(15);

    a = etl::sequence_generator(1);
    b = etl::sequence_generator(2);

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS(value, 1360.0);
}

// 33 elements (an odd, non power-of-two size): sum(i * (i + 1)), i = 1..33.
DOT_TEST_CASE("dot/6", "[dot]") {
    etl::dyn_vector a(33);
    etl::dyn_vector b(33);

    a = etl::sequence_generator(1);
    b = etl::sequence_generator(2);

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS(value, 13090.0);
}

// 57 elements: sum(i * (i + 1)), i = 1..57.
DOT_TEST_CASE("dot/7", "[dot]") {
    etl::dyn_vector a(57);
    etl::dyn_vector b(57);

    a = etl::sequence_generator(1);
    b = etl::sequence_generator(2);

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS(value, 65018.0);
}

// 1017 scaled elements: 0.01 * 0.02 * sum(i * (i + 1)), i = 1..1017.
// Compared approximately because the operands are no longer exact integers.
DOT_TEST_CASE("dot/8", "[dot]") {
    etl::dyn_vector a(1024 - 7);
    etl::dyn_vector b(1024 - 7);

    a = T(0.01) * etl::sequence_generator(1);
    b = T(0.02) * etl::sequence_generator(2);

    T value = 0;
    Impl::apply(a, b, value);

    REQUIRE_EQUALS_APPROX(value, 70331.7876);
}

--------------------------------------------------------------------------------
/include/etl/crtp/value_testable.hpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

/*!
 * \file value_testable.hpp
 * \brief Use the CRTP technique to inject functions that test the values of the expressions or the value classes.
 */

#pragma once

namespace etl {

// Forward declaration, needed by value_testable::is_diagonal() below.
// NOTE(review): the template parameter list (here and on value_testable, and
// the targets of the static_cast expressions) appears to have been stripped
// during extraction -- confirm against the repository.
template
bool is_diagonal(E&& expr);

/*!
 * \brief CRTP class to inject functions testing values of the expressions.
 *
 * This CRTP class injects the tests is_finite, is_zero, is_diagonal and is_uniform.
 * The value tests iterate over the derived object through its begin() and end().
 */
template
struct value_testable {
    using derived_t = D; ///< The derived type

    /*!
     * \brief Returns a reference to the derived object, i.e. the object using the CRTP injector.
     * \return a reference to the derived object.
     */
    derived_t& as_derived() noexcept {
        return *static_cast(this);
    }

    /*!
     * \brief Returns a const reference to the derived object, i.e. the object using the CRTP injector.
     * \return a const reference to the derived object.
     */
    const derived_t& as_derived() const noexcept {
        return *static_cast(this);
    }

    /*!
     * \brief Indicates if the expression contains only finite values.
     *
     * Applies std::isfinite to every element of the derived sequence.
     *
     * \return true if the sequence only contains finite values, false otherwise.
     */
    bool is_finite() const noexcept {
        return std::all_of(as_derived().begin(), as_derived().end(), static_cast)>(std::isfinite));
    }

    /*!
     * \brief Indicates if the expression contains only zero values.
     *
     * Every element is compared for exact equality with value_t(0).
     *
     * \return true if the sequence only contains zero values, false otherwise.
     */
    bool is_zero() const noexcept {
        return std::all_of(as_derived().begin(), as_derived().end(), [](value_t v) { return v == value_t(0); });
    }

    /*!
     * \brief Indicates if the expression is diagonal.
     *
     * Forwards to the free function etl::is_diagonal.
     *
     * \return true if the expression is diagonal, false otherwise.
     */
    bool is_diagonal() const noexcept {
        return etl::is_diagonal(as_derived());
    }

    /*!
     * \brief Indicates if the expression is uniform, i.e. all elements are of the same value.
     *
     * Forwards to the free function etl::is_uniform.
     *
     * \return true if the expression is uniform, false otherwise.
     */
    bool is_uniform() const noexcept {
        return etl::is_uniform(as_derived());
    }
};

} //end of namespace etl

--------------------------------------------------------------------------------
/include/etl/impl/cublas/outer.hpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

/*!
 * \file
 * \brief CUBLAS implementation of the outer product
 */

#pragma once

#ifdef ETL_CUBLAS_MODE

#include "etl/impl/cublas/cuda.hpp"
#include "etl/impl/cublas/cublas.hpp"

#endif

namespace etl::impl::cublas {

#ifdef ETL_CUBLAS_MODE

/*!
 * \brief Compute the batch_outer product of a and b and store the result in c
 *
 * Single-precision overload: uses float scalars and cublasSgemm.
 * NOTE(review): the template header, which presumably constrains this
 * overload to single-precision operands, was stripped during extraction --
 * confirm against the repository.
 *
 * \param a The lhs expression
 * \param b The rhs expression
 * \param c The output expression
 */
template
void batch_outer(const A& a, const B& b, C&& c) {
    decltype(auto) handle = start_cublas();

    // With alpha = 1 and beta = 0, GEMM overwrites c with the product
    float alpha = 1.0;
    float beta = 0.0;

    // This is brain-killing :s
    // CUBLAS need matrices in column-major order. By switching both
    // matrices, this is achieved. However, since one of the matrix
    // needs to be transposed, it must be changed again

    // The inputs must be current on the GPU, the output only needs storage
    a.ensure_gpu_up_to_date();
    b.ensure_gpu_up_to_date();
    c.ensure_gpu_allocated();

    cublas_check(cublasSgemm(handle.get(), CUBLAS_OP_N, CUBLAS_OP_T, etl::columns(c), etl::rows(c), etl::rows(b), &alpha, b.gpu_memory(), etl::columns(b),
                             a.gpu_memory(), etl::columns(a), &beta, c.gpu_memory(), etl::columns(b)));

    // The result was only written on the GPU side
    c.validate_gpu();
    c.invalidate_cpu();
}

/*!
 * \copydoc batch_outer
 *
 * Double-precision overload: uses double scalars and cublasDgemm, with the
 * same operand ordering as the single-precision overload.
 */
template
void batch_outer(const A& a, const B& b, C&& c) {
    decltype(auto) handle = start_cublas();

    // With alpha = 1 and beta = 0, GEMM overwrites c with the product
    double alpha = 1.0;
    double beta = 0.0;

    // The inputs must be current on the GPU, the output only needs storage
    a.ensure_gpu_up_to_date();
    b.ensure_gpu_up_to_date();
    c.ensure_gpu_allocated();

    cublas_check(cublasDgemm(handle.get(), CUBLAS_OP_N, CUBLAS_OP_T, etl::columns(c), etl::rows(c), etl::rows(b), &alpha, b.gpu_memory(), etl::columns(b),
                             a.gpu_memory(), etl::columns(a), &beta, c.gpu_memory(), etl::columns(b)));

    // The result was only written on the GPU side
    c.validate_gpu();
    c.invalidate_cpu();
}

#else

/*!
 * \copydoc batch_outer
 *
 * Fallback used when ETL_CUBLAS_MODE is not defined: it must never be reached.
 */
template
void batch_outer(const A& /*a*/, const B& /*b*/, C&& /*c*/) {
    cpp_unreachable("CUBLAS not enabled/available");
}

#endif

} //end of namespace etl::impl::cublas

--------------------------------------------------------------------------------
/test/src/compare.cpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

#include "test.hpp"

// NOTE(review): the target of the include below and the template argument
// lists of the matrix types in this file appear to have been stripped during
// extraction -- confirm against the repository.
#include

// Fixed-size matrices: equal fill values compare equal, different fill values
// compare different, and operator!= is the exact negation of operator==.
TEMPLATE_TEST_CASE_2("compare/1", "[compare]", Z, float, double) {
    etl::fast_matrix a(3.3);
    etl::fast_matrix b(3.3);
    etl::fast_matrix c(33.3);

    REQUIRE_EQUALS(a, a);
    REQUIRE_EQUALS(a, b);
    REQUIRE_EQUALS(b, a);
    REQUIRE_EQUALS(b, b);
    REQUIRE_DIRECT(!(a == c));
    REQUIRE_DIRECT(!(b == c));

    REQUIRE_DIRECT(!(a != a));
    REQUIRE_DIRECT(!(a != b));
    REQUIRE_DIRECT(!(b != a));
    REQUIRE_DIRECT(!(b != b));
    REQUIRE_DIRECT(a != c);
    REQUIRE_DIRECT(b != c);
}

// Same checks as compare/1, with dynamically-sized 2x2 matrices.
TEMPLATE_TEST_CASE_2("compare/2", "[compare]", Z, float, double) {
    etl::dyn_matrix a(2, 2, 3.3);
    etl::dyn_matrix b(2, 2, 3.3);
    etl::dyn_matrix c(2, 2, 33.3);

    REQUIRE_EQUALS(a, a);
    REQUIRE_EQUALS(a, b);
    REQUIRE_EQUALS(b, a);
    REQUIRE_EQUALS(b, b);
    REQUIRE_DIRECT(!(a == c));
    REQUIRE_DIRECT(!(b == c));

    REQUIRE_DIRECT(!(a != a));
    REQUIRE_DIRECT(!(a != b));
    REQUIRE_DIRECT(!(b != a));
    REQUIRE_DIRECT(!(b != b));
    REQUIRE_DIRECT(a != c);
    REQUIRE_DIRECT(b != c);
}

// Mixed comparison between fixed-size and dynamically-sized matrices,
// in both operand orders.
TEMPLATE_TEST_CASE_2("compare/3", "[compare]", Z, float, double) {
    etl::fast_matrix fa(3.3);
    etl::fast_matrix fc(33.3);
    etl::dyn_matrix da(2, 2, 3.3);
    etl::dyn_matrix dc(2, 2, 33.3);

    REQUIRE_EQUALS(da, fa);
    REQUIRE_EQUALS(fa, da);

    REQUIRE_DIRECT(da != fc);
    REQUIRE_DIRECT(fc != da);
}

// Comparison of expressions (binary, scalar, product and unary) rather than
// of plain containers.
TEMPLATE_TEST_CASE_2("compare/4", "[compare]", Z, float, double) {
    etl::fast_matrix a(3.3);
    etl::dyn_matrix b(2, 2, 3.3);

    REQUIRE_EQUALS((a + b), (b + a));
    REQUIRE_EQUALS((2 * a), (a * 2));
    REQUIRE_EQUALS(*(a * a), *(a * b));

    REQUIRE_EQUALS(log(a + b), log(b + a));
    REQUIRE_DIRECT(log(a + b) != exp(b + a));
}

// Containers and expressions that must compare different although they hold
// the same fill value: d is 3x2 while b is 2x2.
// NOTE(review): c presumably differs from a in its (stripped) dimensions --
// confirm against the repository.
TEMPLATE_TEST_CASE_2("compare/5", "[compare]", Z, float, double) {
    etl::fast_matrix a(3.3);
    etl::dyn_matrix b(2, 2, 3.3);

    etl::fast_matrix c(3.3);
    etl::dyn_matrix d(3, 2, 3.3);

    REQUIRE_DIRECT(a != c);
    REQUIRE_DIRECT(b != d);

    REQUIRE_DIRECT((a + b) != (c + d));
    REQUIRE_DIRECT((2 * a) != (c * 2));

    REQUIRE_DIRECT(log(a + b) != log(c + d));
}

--------------------------------------------------------------------------------
/include/etl/impl/blas/outer.hpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

/*!
 * \file
 * \brief BLAS implementation of the outer product
 */

#pragma once

// The CBLAS interface is only needed (and only available) in BLAS mode
#ifdef ETL_BLAS_MODE
#include "cblas.h"
#endif

namespace etl::impl::blas {

#ifdef ETL_BLAS_MODE

/*!
 * \brief Compute the outer product of a and b and store the result in c
 *
 * Dispatches to cblas_sger when all_single_precision holds, and to
 * cblas_dger otherwise.
 * NOTE(review): the template header and the template arguments of
 * all_single_precision were stripped during extraction -- confirm against
 * the repository.
 *
 * \param a The lhs expression
 * \param b The rhs expression
 * \param c The output expression
 */
template
void outer(const A& a, const B& b, C&& c) {
    // The BLAS ger routines accumulate into their output (c += alpha * a * b^T),
    // so c must be cleared first to obtain the plain outer product
    c = 0;

    a.ensure_cpu_up_to_date();
    b.ensure_cpu_up_to_date();
    c.ensure_cpu_up_to_date();

    if constexpr (all_single_precision) {
        cblas_sger(CblasRowMajor, etl::dim<0>(a), etl::dim<0>(b), 1.0, a.memory_start(), 1, b.memory_start(), 1, c.memory_start(), etl::dim<0>(b));
    } else {
        cblas_dger(CblasRowMajor, etl::dim<0>(a), etl::dim<0>(b), 1.0, a.memory_start(), 1, b.memory_start(), 1, c.memory_start(), etl::dim<0>(b));
    }

    // The result was only written on the CPU side
    c.invalidate_gpu();
}

/*!
 * \brief Compute the batch_outer product of a and b and store the result in c
 *
 * Implemented as a single row-major GEMM computing c = a^T * b, where a is
 * read as a (k x m) matrix and b as a (k x n) matrix.
 *
 * \param a The lhs expression
 * \param b The rhs expression
 * \param c The output expression
 */
template
void batch_outer(const A& a, const B& b, C&& c) {
    const size_t m = etl::rows(c);    // Rows of the output
    const size_t n = etl::columns(c); // Columns of the output
    const size_t k = etl::rows(a);    // Dimension that is summed over

    a.ensure_cpu_up_to_date();
    b.ensure_cpu_up_to_date();

    // beta is zero: the previous content of c does not contribute to the result
    if constexpr (all_single_precision) {
        cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0f, a.memory_start(), m, b.memory_start(), n, 0.0f, c.memory_start(), n);
    } else {
        cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0, a.memory_start(), m, b.memory_start(), n, 0.0, c.memory_start(), n);
    }

    // The result was only written on the CPU side
    c.invalidate_gpu();
}

#else

/*!
 * \copydoc outer
 *
 * Fallback used when ETL_BLAS_MODE is not defined: it must never be reached.
 */
template
void outer(const A& /*a*/, const B& /*b*/, C&& /*c*/) {
    cpp_unreachable("BLAS not enabled/available");
}

/*!
 * \copydoc batch_outer
 *
 * Fallback used when ETL_BLAS_MODE is not defined: it must never be reached.
 */
template
void batch_outer(const A& /*a*/, const B& /*b*/, C&& /*c*/) {
    cpp_unreachable("BLAS not enabled/available");
}

#endif

} //end of namespace etl::impl::blas

--------------------------------------------------------------------------------
/test/src/serial.cpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

#include "test_light.hpp"

// NOTE(review): the template argument lists of the vector/matrix types in
// this file appear to have been stripped during extraction -- confirm against
// the repository.

// Assigning an expression wrapped in serial(...) yields the expected values.
TEMPLATE_TEST_CASE_2("serial/1", "[fast][serial]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    b = serial(a + a);

    REQUIRE_EQUALS(b[0], 2.0);
}

// A reduction computed inside a SERIAL_SECTION yields the expected sum.
TEMPLATE_TEST_CASE_2("serial/2", "[dyn][serial][sum]", Z, float, double) {
    etl::dyn_matrix a(500, 500);

    a = 12.0;

    Z sum = 0.0;

    SERIAL_SECTION {
        sum = etl::sum(a);
    }

    REQUIRE_EQUALS(sum, 12.0 * etl::size(a));
}

// serial(...) also works on the right-hand side of a compound assignment.
TEMPLATE_TEST_CASE_2("serial/3", "[fast][serial]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    b = serial(a + a);
    b += serial(a + a);

    REQUIRE_EQUALS(b[0], 4.0);
}

// The serial flag of the local context is set only while inside the section
// and is cleared again once the section is left.
TEMPLATE_TEST_CASE_2("serial_section/1", "[fast][serial]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    REQUIRE_DIRECT(!etl::local_context().serial);

    SERIAL_SECTION {
        REQUIRE_DIRECT(etl::local_context().serial);
        b = a + a;
        b += a + a;
    }

    REQUIRE_DIRECT(!etl::local_context().serial);

    REQUIRE_EQUALS(b[0], 4.0);
}

// The section restores the flag value that was active when it was entered
// (here: true), even if the body of the section cleared the flag.
TEMPLATE_TEST_CASE_2("serial_section/2", "[fast][serial]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    REQUIRE_DIRECT(!etl::local_context().serial);

    etl::local_context().serial = true;

    SERIAL_SECTION {
        REQUIRE_DIRECT(etl::local_context().serial);
        b = a + a;
        b += a + a;

        etl::local_context().serial = false;
    }

    REQUIRE_DIRECT(etl::local_context().serial);

    // Reset the flag so that the following tests start from a clean context
    etl::local_context().serial = false;

    REQUIRE_EQUALS(b[0], 4.0);
}

// Nested sections: each section restores the value observed at its own entry.
// The inner section restores the false set just before it; the outer section
// restores the initial false.
TEMPLATE_TEST_CASE_2("serial_section/3", "[fast][serial]", Z, float, double) {
    REQUIRE_DIRECT(!etl::local_context().serial);

    SERIAL_SECTION {
        REQUIRE_DIRECT(etl::local_context().serial);

        etl::local_context().serial = false;

        SERIAL_SECTION {
            REQUIRE_DIRECT(etl::local_context().serial);
        }

        REQUIRE_DIRECT(!etl::local_context().serial);
    }

    REQUIRE_DIRECT(!etl::local_context().serial);
}

--------------------------------------------------------------------------------
/Jenkinsfile:
--------------------------------------------------------------------------------
#!groovy

// CI pipeline: checkout (with submodules), static analysis, release build,
// tests, Sonar analysis, then asynchronous trigger of the benchmark job.
pipeline {
    agent any

    stages {
        // Checkout of the sources, including a recursive submodule update
        stage ('git'){
            steps {
                // TODO This horrible Jenkins mess should be cleaned
                checkout([
                    $class: 'GitSCM',
                    branches: scm.branches,
                    doGenerateSubmoduleConfigurations: false,
                    extensions: scm.extensions + [[$class: 'SubmoduleOption', disableSubmodules: false, recursiveSubmodules: true, reference: '', trackingSubmodules: false]],
                    submoduleCfg: [],
                    userRemoteConfigs: scm.userRemoteConfigs])
            }
        }

        // Static analysis and metrics, run before the build
        stage ('pre-analysis') {
            steps {
                sh 'cppcheck --xml-version=2 -j3 --enable=all --std=c++11 `git ls-files "*.hpp" "*.cpp"` 2> cppcheck_report.xml'
                sh 'sloccount --duplicates --wide --details include/etl test workbench > sloccount.sc'
                // The test and workbench sources live in their src/ sub-directories.
                // Failures of cccc are deliberately ignored (best-effort metrics).
                sh 'cccc include/etl/*.hpp test/src/*.cpp workbench/src/*.cpp || true'
            }
        }

        // Release build with MKL and coverage enabled
        stage ('build'){
            environment {
                CXX = "g++-6.4.0"
                LD = "g++-6.4.0"
                ETL_MKL = 'true'
                ETL_COVERAGE = 'true'
            }

            steps {
                sh 'make clean'
                sh 'make -j6 release'
            }
        }

        // Run the test suite and publish its report
        stage ('test'){
            environment {
                ETL_THREADS = "-j6"
                ETL_GPP = "g++-6.4.0"
            }

            steps {
                sh './scripts/test_runner_ci.sh'
                archive 'catch_report.xml'
                junit 'catch_report.xml'
            }
        }

        // Sonar analysis of the master branch
        stage ('sonar-master'){
            when {
                branch 'master'
            }
            steps {
                sh "/opt/sonar-runner/bin/sonar-runner"
            }
        }

        // Sonar analysis of any other branch, reported under the branch name
        stage ('sonar-branch'){
            when {
                not {
                    branch 'master'
                }
            }
            steps {
                sh "/opt/sonar-runner/bin/sonar-runner -Dsonar.branch=${env.BRANCH_NAME}"
            }
        }

        // Trigger the benchmark job without waiting for its completion
        stage ('bench'){
            steps {
                build job: 'etl - benchmark', wait: false
            }
        }
    }

    post {
        always {
            script {
                // A build that has no explicit result at this point is a success
                if (currentBuild.result == null) {
                    currentBuild.result = 'SUCCESS'
                }
            }

            step([$class: 'Mailer',
                notifyEveryUnstableBuild: true,
                recipients: "baptiste.wicht@gmail.com",
                sendToIndividuals: true])
        }
    }
}

--------------------------------------------------------------------------------
/test/src/parallel.cpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

#include "test_light.hpp"

// NOTE(review): the template argument lists of the vector/matrix types in
// this file appear to have been stripped during extraction -- confirm against
// the repository.

// Assigning an expression wrapped in parallel(...) yields the expected values.
TEMPLATE_TEST_CASE_2("parallel/1", "[fast][parallel]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    b = parallel(a + a);

    REQUIRE_EQUALS(b[0], 2.0);
}

// A reduction computed inside a PARALLEL_SECTION yields the expected sum.
TEMPLATE_TEST_CASE_2("parallel/2", "[dyn][parallel][sum]", Z, float, double) {
    etl::dyn_matrix a(500, 500);

    a = 12.0;

    Z sum = 0.0;

    PARALLEL_SECTION {
        sum = etl::sum(a);
    }

    REQUIRE_EQUALS(sum, 12.0 * etl::size(a));
}

// parallel(...) also works on the right-hand side of a compound assignment.
TEMPLATE_TEST_CASE_2("parallel/3", "[fast][parallel]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    b = parallel(a + a);
    b += parallel(a + a);

    REQUIRE_EQUALS(b[0], 4.0);
}

// The parallel flag of the local context is set only while inside the section
// and is cleared again once the section is left.
TEMPLATE_TEST_CASE_2("parallel_section/1", "[fast][parallel]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    REQUIRE_DIRECT(!etl::local_context().parallel);

    PARALLEL_SECTION {
        REQUIRE_DIRECT(etl::local_context().parallel);
        b = a + a;
        b += a + a;
    }

    REQUIRE_DIRECT(!etl::local_context().parallel);

    REQUIRE_EQUALS(b[0], 4.0);
}

// The section restores the flag value that was active when it was entered
// (here: true), even if the body of the section cleared the flag.
TEMPLATE_TEST_CASE_2("parallel_section/2", "[fast][parallel]", Z, float, double) {
    etl::fast_vector a({1.0, -2.0, 3.0});
    etl::fast_vector b;

    REQUIRE_DIRECT(!etl::local_context().parallel);

    etl::local_context().parallel = true;

    PARALLEL_SECTION {
        REQUIRE_DIRECT(etl::local_context().parallel);
        b = a + a;
        b += a + a;

        etl::local_context().parallel = false;
    }

    REQUIRE_DIRECT(etl::local_context().parallel);

    // Reset the flag so that the following tests start from a clean context
    etl::local_context().parallel = false;

    REQUIRE_EQUALS(b[0], 4.0);
}

// Nested sections: each section restores the value observed at its own entry.
// The inner section restores the false set just before it; the outer section
// restores the initial false.
TEMPLATE_TEST_CASE_2("parallel_section/3", "[fast][parallel]", Z, float, double) {
    REQUIRE_DIRECT(!etl::local_context().parallel);

    PARALLEL_SECTION {
        REQUIRE_DIRECT(etl::local_context().parallel);

        etl::local_context().parallel = false;

        PARALLEL_SECTION {
            REQUIRE_DIRECT(etl::local_context().parallel);
        }

        REQUIRE_DIRECT(!etl::local_context().parallel);
    }

    REQUIRE_DIRECT(!etl::local_context().parallel);
}

--------------------------------------------------------------------------------
/include/etl/impl/egblas/relu_der_out.hpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

/*!
 * \file
 * \brief EGBLAS wrappers for the relu_der_out operation.
 */

#pragma once

#ifdef ETL_EGBLAS_MODE

#include "etl/impl/cublas/cuda.hpp"

// NOTE(review): the target of this include was stripped during extraction
// (presumably the egblas header) -- confirm against the repository.
#include

#endif

namespace etl::impl::egblas {

/*!
 * \brief Indicates if EGBLAS has single-precision relu_der_out.
 *
 * The flag is true if and only if EGBLAS_HAS_SRELU_DER_OUT is defined.
 */
#ifdef EGBLAS_HAS_SRELU_DER_OUT
static constexpr bool has_srelu_der_out = true;
#else
static constexpr bool has_srelu_der_out = false;
#endif

/*!
 * \brief Wrappers for single-precision egblas relu_der_out operation
 *
 * Forwards to egblas_srelu_der_out when EGBLAS_HAS_SRELU_DER_OUT is defined;
 * otherwise any call is a programming error (cpp_unreachable). The parameters
 * are marked [[maybe_unused]] because they are not used in that second case.
 *
 * \param n The size of the vector
 * \param alpha The scaling factor alpha
 * \param A The memory of the vector a
 * \param lda The leading dimension of a
 * \param B The memory of the vector b
 * \param ldb The leading dimension of b
 */
inline void relu_der_out([[maybe_unused]] size_t n,
                         [[maybe_unused]] float alpha,
                         [[maybe_unused]] float* A,
                         [[maybe_unused]] size_t lda,
                         [[maybe_unused]] float* B,
                         [[maybe_unused]] size_t ldb) {
#ifdef EGBLAS_HAS_SRELU_DER_OUT
    // Record the call under the "egblas" counter before forwarding
    inc_counter("egblas");
    egblas_srelu_der_out(n, alpha, A, lda, B, ldb);
#else
    cpp_unreachable("Invalid call to egblas::relu_der_out");
#endif
}

/*!
 * \brief Indicates if EGBLAS has double-precision relu_der_out.
 *
 * The flag is true if and only if EGBLAS_HAS_DRELU_DER_OUT is defined.
 */
#ifdef EGBLAS_HAS_DRELU_DER_OUT
static constexpr bool has_drelu_der_out = true;
#else
static constexpr bool has_drelu_der_out = false;
#endif

/*!
 * \brief Wrappers for double-precision egblas relu_der_out operation
 *
 * Forwards to egblas_drelu_der_out when EGBLAS_HAS_DRELU_DER_OUT is defined;
 * otherwise any call is a programming error (cpp_unreachable). The parameters
 * are marked [[maybe_unused]] because they are not used in that second case.
 *
 * \param n The size of the vector
 * \param alpha The scaling factor alpha
 * \param A The memory of the vector a
 * \param lda The leading dimension of a
 * \param B The memory of the vector b
 * \param ldb The leading dimension of b
 */
inline void relu_der_out([[maybe_unused]] size_t n,
                         [[maybe_unused]] double alpha,
                         [[maybe_unused]] double* A,
                         [[maybe_unused]] size_t lda,
                         [[maybe_unused]] double* B,
                         [[maybe_unused]] size_t ldb) {
#ifdef EGBLAS_HAS_DRELU_DER_OUT
    // Record the call under the "egblas" counter before forwarding
    inc_counter("egblas");
    egblas_drelu_der_out(n, alpha, A, lda, B, ldb);
#else
    cpp_unreachable("Invalid call to egblas::relu_der_out");
#endif
}

} //end of namespace etl::impl::egblas

--------------------------------------------------------------------------------
/workbench/src/verify_cpm.cpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

// NOTE(review): the targets of the three includes below, the template
// argument lists of the matrix types and the target type of duration_cast
// appear to have been stripped during extraction -- confirm against the
// repository.
#include
#include
#include

#include "etl/etl.hpp"

typedef std::chrono::high_resolution_clock timer_clock; // Clock used for all measurements
typedef std::chrono::milliseconds milliseconds;         // Unit of the reported durations

// Times etl::bias_add_4d with the default implementation selection:
// 5 untimed warm-up evaluations followed by 10 timed evaluations.
// Prints the total duration and the mean per evaluation.
void default_case(){
    etl::dyn_matrix A(128, 64, 100, 100);
    etl::dyn_matrix B(64);
    etl::dyn_matrix C(128, 64, 100, 100);

    // Warm-up, not measured
    for(size_t i = 0; i < 5; ++i){
        C = etl::bias_add_4d(A, B);
    }

    auto start_time = timer_clock::now();

    for(size_t i = 0; i < 10; ++i){
        C = etl::bias_add_4d(A, B);
    }

    auto end_time = timer_clock::now();
    auto duration = std::chrono::duration_cast(end_time - start_time);

    std::cout << "default: " << duration.count() << "ms" << std::endl;
    // Mean over the 10 timed evaluations
    std::cout << " mean: " << (duration.count() / 10.0) << "ms" << std::endl;
}

// Same measurement as default_case, with the implementation forced to
// etl::bias_add_impl::STD through selected_helper.
void std_case(){
    etl::dyn_matrix A(128, 64, 100, 100);
    etl::dyn_matrix B(64);
    etl::dyn_matrix C(128, 64, 100, 100);

    // Warm-up, not measured
    for(size_t i = 0; i < 5; ++i){
        C = selected_helper(etl::bias_add_impl::STD, etl::bias_add_4d(A, B));
    }

    auto start_time = timer_clock::now();

    for(size_t i = 0; i < 10; ++i){
        C = selected_helper(etl::bias_add_impl::STD, etl::bias_add_4d(A, B));
    }

    auto end_time = timer_clock::now();
    auto duration = std::chrono::duration_cast(end_time - start_time);

    std::cout << " std: " << duration.count() << "ms" << std::endl;
    // Mean over the 10 timed evaluations
    std::cout << " mean: " << (duration.count() / 10.0) << "ms" << std::endl;
}

// Same measurement as default_case, with the implementation forced to
// etl::bias_add_impl::VEC through selected_helper.
void vec_case(){
    etl::dyn_matrix A(128, 64, 100, 100);
    etl::dyn_matrix B(64);
    etl::dyn_matrix C(128, 64, 100, 100);

    // Warm-up, not measured
    for(size_t i = 0; i < 5; ++i){
        C = selected_helper(etl::bias_add_impl::VEC, etl::bias_add_4d(A, B));
    }

    auto start_time = timer_clock::now();

    for(size_t i = 0; i < 10; ++i){
        C = selected_helper(etl::bias_add_impl::VEC, etl::bias_add_4d(A, B));
    }

    auto end_time = timer_clock::now();
    auto duration = std::chrono::duration_cast(end_time - start_time);

    std::cout << " vec: " << duration.count() << "ms" << std::endl;
    // Mean over the 10 timed evaluations
    std::cout << " mean: " << (duration.count() / 10.0) << "ms" << std::endl;
}

// Runs the three measurements one after the other.
int main(){
    default_case();
    std_case();
    vec_case();

    return 0;
}

--------------------------------------------------------------------------------
/include/etl/impl/egblas/one_if_max_sub.hpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

/*!
 * \file
 * \brief EGBLAS wrappers for the one_if_max_sub operation.
 */

#pragma once

#ifdef ETL_EGBLAS_MODE

#include "etl/impl/cublas/cuda.hpp"

#include

#endif

namespace etl::impl::egblas {

// Indicates if EGBLAS has single-precision one_if_max_sub: the flag is true
// if and only if EGBLAS_HAS_SONE_IF_MAX_SUB is defined.
#ifdef EGBLAS_HAS_SONE_IF_MAX_SUB
static constexpr bool has_sone_if_max_sub = true;
#else
static constexpr bool has_sone_if_max_sub = false;
#endif

/*!
 * \brief Wrappers for single-precision egblas one_if_max_sub operation
 *
 * Forwards to egblas_sone_if_max_sub when EGBLAS_HAS_SONE_IF_MAX_SUB is
 * defined; otherwise any call is a programming error (cpp_unreachable). The
 * parameters are marked [[maybe_unused]] because they are not used in that
 * second case.
 *
 * \param b The batch dimension of the matrix
 * \param n The size of the output vector
 * \param alpha Scalar forwarded unchanged to egblas (presumably a scaling factor -- TODO confirm)
 * \param A The memory of the vector a
 * \param lda The leading dimension of a
 * \param B The memory of the vector b
 * \param ldb The leading dimension of b
 */
inline void one_if_max_sub([[maybe_unused]] size_t b,
                           [[maybe_unused]] size_t n,
                           [[maybe_unused]] float alpha,
                           [[maybe_unused]] float* A,
                           [[maybe_unused]] size_t lda,
                           [[maybe_unused]] float* B,
                           [[maybe_unused]] size_t ldb) {
#ifdef EGBLAS_HAS_SONE_IF_MAX_SUB
    // Record the call under the "egblas" counter before forwarding
    inc_counter("egblas");
    egblas_sone_if_max_sub(b, n, alpha, A, lda, B, ldb);
#else
    cpp_unreachable("Invalid call to egblas::one_if_max_sub");
#endif
}

/*!
 * \brief Indicates if EGBLAS has double-precision one_if_max_sub.
 *
 * The flag is true if and only if EGBLAS_HAS_DONE_IF_MAX_SUB is defined.
 */
#ifdef EGBLAS_HAS_DONE_IF_MAX_SUB
static constexpr bool has_done_if_max_sub = true;
#else
static constexpr bool has_done_if_max_sub = false;
#endif

/*!
 * \brief Wrappers for double-precision egblas one_if_max_sub operation
 *
 * Forwards to egblas_done_if_max_sub when EGBLAS_HAS_DONE_IF_MAX_SUB is
 * defined; otherwise any call is a programming error (cpp_unreachable). The
 * parameters are marked [[maybe_unused]] because they are not used in that
 * second case.
 *
 * \param b The batch dimension of the matrix
 * \param n The size of the output vector
 * \param alpha Scalar forwarded unchanged to egblas (presumably a scaling factor -- TODO confirm)
 * \param A The memory of the vector a
 * \param lda The leading dimension of a
 * \param B The memory of the vector b
 * \param ldb The leading dimension of b
 */
inline void one_if_max_sub([[maybe_unused]] size_t b,
                           [[maybe_unused]] size_t n,
                           [[maybe_unused]] double alpha,
                           [[maybe_unused]] double* A,
                           [[maybe_unused]] size_t lda,
                           [[maybe_unused]] double* B,
                           [[maybe_unused]] size_t ldb) {
#ifdef EGBLAS_HAS_DONE_IF_MAX_SUB
    // Record the call under the "egblas" counter before forwarding
    inc_counter("egblas");
    egblas_done_if_max_sub(b, n, alpha, A, lda, B, ldb);
#else
    cpp_unreachable("Invalid call to egblas::one_if_max_sub");
#endif
}

} //end of namespace etl::impl::egblas

--------------------------------------------------------------------------------
/test/src/alias.cpp:
--------------------------------------------------------------------------------
//=======================================================================
// Copyright (c) 2014-2023 Baptiste Wicht
// Distributed under the terms of the MIT License.
// (See accompanying file LICENSE or copy at
//  http://opensource.org/licenses/MIT)
//=======================================================================

#include "test.hpp"

// NOTE(review): the template argument lists of the matrix types and of the
// decay_traits instantiations in this file appear to have been stripped
// during extraction -- confirm against the repository.

// Alias detection: a container aliases itself, expressions built from it and
// sub-views of it, but not a distinct container nor expressions built only
// from distinct containers. Two different sub-views do not alias each other.
TEMPLATE_TEST_CASE_2("alias/1", "[alias]", Z, float, double) {
    etl::fast_matrix a;
    etl::dyn_matrix b(3, 3);

    REQUIRE_DIRECT(a.alias(a));
    REQUIRE_DIRECT(b.alias(b));
    REQUIRE_DIRECT(!b.alias(a));
    REQUIRE_DIRECT(!a.alias(b));

    REQUIRE_DIRECT(a.alias(a + a));
    REQUIRE_DIRECT(a.alias(a + 1));
    REQUIRE_DIRECT(!a.alias(b + 1));
    REQUIRE_DIRECT((a + a).alias(a + a));
    REQUIRE_DIRECT(!a.alias(b >> b));

    REQUIRE_DIRECT(a.alias(a(0)));
    REQUIRE_DIRECT(a(0).alias(a(0)));
    REQUIRE_DIRECT(!a(0).alias(a(1)));
    REQUIRE_DIRECT(a.alias(a(0) + a(1)));
}

// Checks the is_linear trait on a set of linear and non linear expressions.
// The expressions under test are not recoverable from this text (see the
// note at the top of the file).
TEMPLATE_TEST_CASE_2("alias/traits/1", "[alias][traits]", Z, float, double) {
    etl::fast_matrix a({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0});

    //Test linear operations
    REQUIRE_DIRECT(etl::decay_traits::is_linear);
    REQUIRE_DIRECT(!etl::decay_traits::is_linear);
    REQUIRE_DIRECT(etl::decay_traits> a) + a - a / a)>::is_linear);
    REQUIRE_DIRECT(etl::decay_traits::is_linear);
    REQUIRE_DIRECT(etl::decay_traits::is_linear);
    REQUIRE_DIRECT(etl::decay_traits::is_linear);

    //Test non linear operations
    REQUIRE_DIRECT(!etl::decay_traits::is_linear);
    REQUIRE_DIRECT(!etl::decay_traits::is_linear);
}

// Assigning the transpose of a matrix to itself must behave as if the
// right-hand side had been evaluated before the destination was written.
TEMPLATE_TEST_CASE_2("alias/transpose/1", "[alias][transpose]", Z, float, double) {
    etl::fast_matrix a({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0});

    a = transpose(a);

    REQUIRE_EQUALS(a(0, 0), 1.0);
    REQUIRE_EQUALS(a(0, 1), 4.0);
    REQUIRE_EQUALS(a(0, 2), 7.0);
    REQUIRE_EQUALS(a(1, 0), 2.0);
    REQUIRE_EQUALS(a(1, 1), 5.0);
    REQUIRE_EQUALS(a(1, 2), 8.0);
    REQUIRE_EQUALS(a(2, 0), 3.0);
    REQUIRE_EQUALS(a(2, 1), 6.0);
    REQUIRE_EQUALS(a(2, 2), 9.0);
}

// Same aliasing situation with the transpose nested inside a larger
// expression: every expected value is 5 times the transposed element
// (2 * a^T + 3 * a^T).
TEMPLATE_TEST_CASE_2("alias/transpose/2", "[alias][transpose]", Z, float, double) {
    etl::fast_matrix a({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0});

    a = (transpose(a) >> 2.0) + (transpose(a) >> 3.0);

    REQUIRE_EQUALS(a(0, 0), 5.0);
    REQUIRE_EQUALS(a(0, 1), 20.0);
    REQUIRE_EQUALS(a(0, 2), 35.0);
    REQUIRE_EQUALS(a(1, 0), 10.0);
    REQUIRE_EQUALS(a(1, 1), 25.0);
    REQUIRE_EQUALS(a(1, 2), 40.0);
    REQUIRE_EQUALS(a(2, 0), 15.0);
    REQUIRE_EQUALS(a(2, 1), 30.0);
    REQUIRE_EQUALS(a(2, 2), 45.0);
}

--------------------------------------------------------------------------------