├── LICENSE
├── .gitmodules
├── .clang-format
├── samples
│   ├── CMakeLists.txt
│   ├── add.cpp
│   └── mandelbrot.cu
├── .gitignore
├── CMakeLists.txt
├── README-zh.md
├── README.md
└── mtensor.hpp

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "third_party/stb"]
	path = third_party/stb
	url = https://github.com/nothings/stb.git

--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
---
# We'll use defaults from the Google style, but with 4 columns indentation.
BasedOnStyle: Google
IndentWidth: 4
ColumnLimit: 100
---
Language: Cpp

---
Language: JavaScript

---
Language: Proto
# Don't format .proto files.
DisableFormat: true

...

--------------------------------------------------------------------------------
/samples/CMakeLists.txt:
--------------------------------------------------------------------------------
add_library(image_helper INTERFACE)
target_include_directories(image_helper INTERFACE ${PROJECT_SOURCE_DIR}/third_party/stb)

add_executable(add add.cpp)
target_link_libraries(add mtensor image_helper)

if (WITH_CUDA)
    add_executable(mandelbrot mandelbrot.cu)
    target_link_libraries(mandelbrot mtensor image_helper)
endif ()

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app
*.exp

# Visual Studio
*.db
*.pdb
*.opendb
.vs
CMakeSettings.json

vs
build*
tmp

bin
log
x64

.vscode
core*

*.png
*.jpg
*.bmp

lenna.png

*.raw

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.14)

project(MTensor CXX)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD})
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

include(CheckLanguage)
check_language(CUDA)

if (CMAKE_CUDA_COMPILER)
    enable_language(CUDA)
    set(WITH_CUDA ON)
    set(CMAKE_CUDA_ARCHITECTURES OFF)
    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")
endif()

add_library(mtensor INTERFACE)
target_include_directories(mtensor
    INTERFACE $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}>
    INTERFACE $<INSTALL_INTERFACE:include>
)

add_subdirectory(samples)

--------------------------------------------------------------------------------
/samples/add.cpp:
--------------------------------------------------------------------------------
#include <iostream>

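// add.cpp: fill two host tensors with for_index, compose a lazy element-wise
// sum via make_lambda_tensor and the lazy operator+, then persist and print it.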
#include "mtensor.hpp" 4 | 5 | using namespace matazure; 6 | 7 | int main(int argc, char* argv[]) { 8 | pointi<2> shape = {2, 4}; 9 | tensor ts0(shape, runtime::host); 10 | tensor ts1(shape, runtime::host); 11 | auto set_fun = [=](pointi<2> idx) { 12 | ts0(idx) = 1; 13 | ts1(idx) = 2; 14 | }; 15 | // for_index will call set_fun(idx), for idx in grid [(0, 0), shape); 16 | for_index(shape, set_fun); 17 | // make a lambda tensor which fun(idx) = ts0(idx) + ts1(idx); 18 | auto lts_add = make_lambda_tensor(shape, [=](pointi<2> idx) { return ts0(idx) + ts1(idx); }); 19 | // mtensor has lazy evaluating binary operators +-*/ 20 | auto lts_add2 = lts_add + ts1; 21 | // return a host tensor(default runtime is host) ts_re that ts_re(idx) = lts_add2(x) for idx in 22 | // a grid of shape 23 | auto ts_re = lts_add2.persist(); 24 | for (int row = 0; row < ts_re.shape()[0]; ++row) { 25 | for (int col = 0; col < ts_re.shape()[1]; ++col) { 26 | std::cout << ts_re(row, col) << ", "; 27 | } 28 | std::cout << std::endl; 29 | } 30 | 31 | return 0; 32 | } 33 | -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 | # mtensor 2 | 3 | mtensor是一个c++/cuda模板库, 其支持tensor的延迟计算. 和Eigen不同的是, mtensor以张量为核心数据结构, 并且支持cuda的延迟计算. mtensor不提供线性代数和数值计算的功能. 4 | 5 | ## 如何使用 6 | 7 | 8 | ### tensor 9 | 10 | mtensor由模板类**tensor**来管理基本的数据结构, 可以这样使用它 11 | 12 | ```c++ 13 | pointi<2> shape = {10, 20}; 14 | tensor cts(shape, runtime::cuda); 15 | tensor ts(shape, runtime::host); 16 | ``` 17 | 18 | tensor可以通过模板参数来定义value_type和维度, 通过cuda/host来可以指定tensor的内存类型. 19 | 20 | ### lambda_tensor 21 | 22 | lambda_tensor是一个函数式的tensor结构, 其有着很强的表达能力. 23 | 我们可以通过shape和"index->value"的函数来定义lambda_tensor. 24 | 25 | ```c++ 26 | pointi<2> shape = {10, 10}; 27 | tensor ts(shape); 28 | auto fun = [=](pointi<2> idx) { 29 | return ts(idx) + ts(idx); 30 | }; 31 | lambda_tensor lts(shape, fun); 32 | // will evaluate fun(pointi<2>{2,2}), and return it; 33 | auto value = lts(2, 2); 34 | ``` 35 | 36 | ### for_index 37 | 38 | for_index是一个函数式的for循环, 原则上我们认为其是并行的. for_index支持host和cuda两种执行方式, 我们需要注意在cuda运行时里, 我们只能访问对应运行时的tensor. 39 | 40 | ```c++ 41 | pointi<2> shape = {10, 10}; 42 | tensor cts(shape, runtime::cuda); 43 | auto lts = make_lambda_tensor(shape, [](pointi<2> idx)->int{ 44 | return idx[0] * 10 + idx[1]; 45 | }); 46 | 47 | for_index(shape, [=](pointi<2> idx){ 48 | cts(idx) = lts(idx); 49 | }, cts.runtime()); // in cuda runtime, we can only access cuda runtime tensor 50 | ``` 51 | 52 | 我们可以拓展for_index实现不同的加速引擎, 比如我们可以拓展一个omp的for_index来使用cpu的多核. 53 | 54 | 上述三个功能是mtensor库的核心构成, 我们可以为绕它们快速实现很多功能. 详细可以参考代码用例[samples](samples) 55 | 56 | ## 用例编译 57 | 58 | ```bash 59 | git submodule update --init . 60 | cmake -B build 61 | ``` 62 | 63 | ## 如何使用 64 | 65 | 下载[mtensor.hpp](mtensor.hpp)单一头文件, 将其包含在项目中. 66 | 67 | ``` 68 | #include "mtensor.hpp" 69 | ``` 70 | 71 | 在使用mtensor中遇到问题, 可以自行修复下然后提一个pull request. 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mtensor 2 | 3 | mtensor is a C++/CUDA template library that supports lazy evaluation of Tensor. Unlike Eigen, mtensor focus the tensor structure and supports cuda lazy evalution. mtensor does not provide linear algebra and numerical features. 

## How to use

### tensor

We use the template class **tensor** to manage the basic data structure, and use it as follows:

```c++
pointi<2> shape = {10, 20};
tensor<int, 2> cts(shape, runtime::cuda);
tensor<int, 2> ts(shape, runtime::host);
```

**tensor** takes its value_type and rank as template parameters, and the memory type is selected by the runtime argument (host or cuda).

### lambda_tensor

**lambda_tensor** is a functional tensor which has strong expressive power.
We can define a lambda_tensor by an "index -> value" function together with a shape (its domain).

```c++
pointi<2> shape = {10, 10};
tensor<int, 2> ts(shape);
auto fun = [=](pointi<2> idx) {
    return ts(idx) + ts(idx);
};
lambda_tensor<decltype(fun)> lts(shape, fun);
// will evaluate fun(pointi<2>{2, 2}) and return the result;
auto value = lts(2, 2);
```

### for_index

**for_index** is a functional loop, which supports host and cuda execution modes.
Note that in the cuda runtime, we can only access cuda tensors.

```c++
pointi<2> shape = {10, 10};
tensor<int, 2> cts(shape, runtime::cuda);
auto lts = make_lambda_tensor(shape, [] __general__(pointi<2> idx) -> int {
    return idx[0] * 10 + idx[1];
});

for_index(shape, [=] __general__(pointi<2> idx) {
    cts(idx) = lts(idx);
}, cts.runtime());  // in the cuda runtime, we can only access cuda runtime tensors
```

We can extend for_index with different execution policies; for example, an OpenMP for_index could use the CPU's multiple cores.

## Build the samples

```bash
git submodule update --init .
cmake -B build
```

## How to integrate

Download the single header file [mtensor.hpp](mtensor.hpp) and include it in your project.

```c++
#include "mtensor.hpp"
```

The mtensor project is lightweight; if you find a bug, please fix it and open a pull request.

--------------------------------------------------------------------------------
/samples/mandelbrot.cu:
--------------------------------------------------------------------------------
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include <chrono>
#include <iostream>

#include "mtensor.hpp"
#include "stb_image_write.h"

using namespace matazure;
using rgb = point<uint8_t, 3>;

int main(int argc, char* argv[]) {
    pointi<2> shape = {2048, 2048};
    int_t max_iteration = 256 * 16;
    // make a lambda tensor to evaluate the mandelbrot set.
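    // escape-time algorithm: each pixel index is mapped to a point c in the complex
    // plane, then z = z^2 + c is iterated from z = 0 until |z|^2 > 4 or the
    // iteration budget is exhausted; the iteration count becomes the pixel value.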
    auto lts_mandelbrot = make_lambda_tensor(shape, [=] __general__(pointi<2> idx) -> float {
        point<float, 2> idxf;
        idxf[0] = idx[0];
        idxf[1] = idx[1];
        point<float, 2> shapef;
        shapef[0] = shape[0];
        shapef[1] = shape[1];
        point<float, 2> c =
            idxf / shapef * point<float, 2>{3.25f, 2.5f} - point<float, 2>{2.0f, 1.25f};
        auto z = point<float, 2>::all(0.0f);
        auto norm = 0.0f;
        int_t value = 0;
        while (norm <= 4.0f && value < max_iteration) {
            float tmp = z[0] * z[0] - z[1] * z[1] + c[0];
            z[1] = 2 * z[0] * z[1] + c[1];
            z[0] = tmp;
            ++value;
            norm = z[0] * z[0] + z[1] * z[1];
        }

        return value;
    });

    // convert the mandelbrot value to an rgb pixel
    auto lts_rgb_mandelbrot =
        make_lambda_tensor(lts_mandelbrot.shape(), [=] __general__(pointi<2> idx) {
            float t = lts_mandelbrot(idx) / max_iteration;
            auto r = static_cast<uint8_t>(36 * (1 - t) * t * t * t * 255);
            auto g = static_cast<uint8_t>(60 * (1 - t) * (1 - t) * t * t * 255);
            auto b = static_cast<uint8_t>(38 * (1 - t) * (1 - t) * (1 - t) * t * 255);
            return rgb{r, g, b};
        });

    // select the runtime
    runtime rt = runtime::cuda;
    if (argc > 1 && std::string(argv[1]) == "host") {
        rt = runtime::host;
    }

    auto t0 = std::chrono::high_resolution_clock::now();
    // persist the lambda tensor on the cuda/host runtime
    auto ts_rgb_mandelbrot = lts_rgb_mandelbrot.persist(rt);
    auto t1 = std::chrono::high_resolution_clock::now();
    std::cout << "render mandelbrot cost time: "
              << std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count() << " ms"
              << std::endl;

    // sync ts_rgb_mandelbrot to the host; if it is a host tensor already, sync returns itself
    ts_rgb_mandelbrot = ts_rgb_mandelbrot.sync(runtime::host);
    stbi_write_png("mandelbrot.png", ts_rgb_mandelbrot.shape()[1], ts_rgb_mandelbrot.shape()[0], 3,
                   ts_rgb_mandelbrot.data(), ts_rgb_mandelbrot.shape()[1] * 3);

    return 0;
}

--------------------------------------------------------------------------------
/mtensor.hpp:
--------------------------------------------------------------------------------
#pragma once

/*****************************************************************************
MIT License

Copyright (c) 2017 Zhang Zhimin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*****************************************************************************/

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <exception>
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>

// for cuda
#if defined(__CUDACC__) && !defined(MATAZURE_DISABLE_CUDA)
#ifdef __clang__
#if __clang_major__ < 9
#error clang minimum version is 9 for cuda
#endif
#else
#if __CUDACC_VER_MAJOR__ < 9
#error CUDA minimum version is 9.0
#endif
#endif

#define MATAZURE_CUDA
#endif

#ifdef MATAZURE_CUDA
#define MATAZURE_GENERAL __host__ __device__
#ifndef __clang__
#define MATAZURE_NV_EXE_CHECK_DISABLE _Pragma("nv_exec_check_disable")
#else
#define MATAZURE_NV_EXE_CHECK_DISABLE
#endif
#else
#define MATAZURE_GENERAL
#define MATAZURE_NV_EXE_CHECK_DISABLE
#endif

#define __general__ MATAZURE_GENERAL

#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1900)
#else
#error "use c++11 at least"
#endif

namespace matazure {

class assert_failed : public std::exception {
   public:
    assert_failed(const std::string& expr, const std::string& file, size_t line,
                  const std::string& msg = " ")
        : _expr(expr), _file(file), _line(line), _msg(msg) {
        _what_str = _expr + ", " + _file + ", " + std::to_string(_line) + ", " + _msg;
    }

    virtual const char* what() const noexcept override { return _what_str.c_str(); }

   private:
    std::string _expr;
    std::string _file;
    size_t _line;
    std::string _msg;
    std::string _what_str;
};

inline void raise_assert_failed(const std::string& expr, const std::string& file, long line,
                                const std::string& msg = " ") {
    throw assert_failed(expr, file, line, msg);
}

inline void raise_verify_failed(const std::string& expr, const std::string& file, long line,
                                const std::string& msg = " ") {
    throw assert_failed(expr, file, line, msg);
}

}  // namespace matazure

#if defined(MATAZURE_DISABLE_ASSERTS)
#define MATAZURE_ASSERT(expr, msg) ((void)0)
#else
#define MATAZURE_ASSERT(expr, ...) \
    ((!!(expr)) ? ((void)0)        \
                : ::matazure::raise_assert_failed(#expr, __FILE__, __LINE__, ##__VA_ARGS__))
#endif

#define MATAZURE_VERIFY(expr, ...) \
    ((!!(expr)) ? ((void)0)        \
                : ::matazure::raise_verify_failed(#expr, __FILE__, __LINE__, ##__VA_ARGS__))

#ifdef MATAZURE_CUDA
#include "cuda_occupancy.h"
#include "cuda_runtime.h"
#endif

namespace matazure {

typedef int int_t;

enum struct runtime { host, cuda };

template <typename _ValueType, int_t _Rank>
class point {
   public:
    static const int_t rank = _Rank;
    typedef _ValueType value_type;
    typedef value_type& reference;
    typedef const value_type& const_reference;

    MATAZURE_GENERAL constexpr const_reference operator[](int_t i) const { return elements_[i]; }

    MATAZURE_GENERAL reference operator[](int_t i) { return elements_[i]; }

    MATAZURE_GENERAL constexpr int_t size() const { return rank; }

    MATAZURE_GENERAL static point all(value_type v) {
        point re{};
        for (int_t i = 0; i < re.size(); ++i) {
            re[i] = v;
        }
        return re;
    }

   public:
    value_type elements_[rank];
};

// binary operators
#define MATAZURE_POINT_BINARY_OPERATOR(op)                                \
    template <typename _T, int_t _Rank>                                   \
    inline MATAZURE_GENERAL auto operator op(const point<_T, _Rank>& lhs, \
                                             const point<_T, _Rank>& rhs) \
        -> point<_T, _Rank> {                                             \
        point<_T, _Rank> re;                                              \
        for (int_t i = 0; i < _Rank; ++i) {                               \
            re[i] = lhs[i] op rhs[i];                                     \
        }                                                                 \
        return re;                                                        \
    }

// assignment operators
#define MATAZURE_POINT_ASSIGNMENT_OPERATOR(op)                                                \
    template <typename _T, int_t _Rank>                                                      \
    inline MATAZURE_GENERAL auto operator op(point<_T, _Rank>& lhs, const point<_T, _Rank>& rhs) \
        -> point<_T, _Rank> {                                                                 \
        for (int_t i = 0; i < _Rank; ++i) {                                                   \
            lhs[i] op rhs[i];                                                                 \
        }                                                                                     \
        return lhs;                                                                           \
    }

// Arithmetic
MATAZURE_POINT_BINARY_OPERATOR(+)
MATAZURE_POINT_BINARY_OPERATOR(-)
MATAZURE_POINT_BINARY_OPERATOR(*)
MATAZURE_POINT_BINARY_OPERATOR(/)
MATAZURE_POINT_ASSIGNMENT_OPERATOR(+=)
MATAZURE_POINT_ASSIGNMENT_OPERATOR(-=)
MATAZURE_POINT_ASSIGNMENT_OPERATOR(*=)
MATAZURE_POINT_ASSIGNMENT_OPERATOR(/=)

template <typename _T, int_t _Rank>
inline MATAZURE_GENERAL point<_T, _Rank> operator+(const point<_T, _Rank>& p) {
    return p;
}

template <typename _T, int_t _Rank>
inline MATAZURE_GENERAL point<_T, _Rank> operator-(const point<_T, _Rank>& p) {
    point<_T, _Rank> temp;
    for (int_t i = 0; i < _Rank; ++i) {
        temp[i] = -p[i];
    }

    return temp;
}

template <int_t _Rank>
using pointi = point<int_t, _Rank>;

template <int_t _Rank>
class row_major_layout {
   public:
    const static int_t rank = _Rank;

    MATAZURE_GENERAL row_major_layout() : row_major_layout(pointi<rank>{0}) {}

    MATAZURE_GENERAL row_major_layout(const pointi<rank>& shape) : shape_(shape) {
        stride_[rank - 1] = 1;
        for (int_t i = rank - 2; i >= 0; --i) {
            stride_[i] = shape[i + 1] * stride_[i + 1];
        }
        size_ = stride_[0] * shape[0];
    }

    MATAZURE_GENERAL row_major_layout(const row_major_layout& rhs)
        : row_major_layout(rhs.shape()) {}

    MATAZURE_GENERAL row_major_layout& operator=(const row_major_layout& rhs) {
        shape_ = rhs.shape();
        stride_ = rhs.stride();
        size_ = rhs.size();
        return *this;
    }

    MATAZURE_GENERAL int_t index2offset(const pointi<rank>& id) const {
        typename pointi<rank>::value_type offset = 0;
        for (int_t i = rank - 1; i >= 0; --i) {
            offset += id[i] * stride_[i];
        }
        return offset;
    }
    MATAZURE_GENERAL pointi<rank> offset2index(int_t offset) const {
        pointi<rank> id;
        for (int_t i = 0; i < rank; ++i) {
            id[i] = offset / stride_[i];
            offset = offset % stride_[i];
        }
        return id;
    }

    MATAZURE_GENERAL int_t size() const { return size_; }

    MATAZURE_GENERAL pointi<rank> shape() const { return shape_; }

    MATAZURE_GENERAL pointi<rank> stride() const { return stride_; }

    MATAZURE_GENERAL ~row_major_layout() {}

   private:
    pointi<rank> shape_;
    pointi<rank> stride_;
    int_t size_;
};

}  // namespace matazure

#ifdef MATAZURE_CUDA
namespace matazure {
namespace cuda {

class runtime_error : public std::runtime_error {
   public:
    runtime_error(cudaError_t error_code)
        : std::runtime_error(cudaGetErrorString(error_code)), error_code_(error_code) {}

   private:
    cudaError_t error_code_;
};

inline void verify_runtime_success(cudaError_t result) {
    if (result != cudaSuccess) {
        throw runtime_error(result);
    }
}

class occupancy_error : public std::runtime_error {
   public:
    occupancy_error(cudaOccError error_code)
        : std::runtime_error("cuda occupancy error"), error_code_(error_code) {}

   private:
    cudaOccError error_code_;
};

inline void verify_occupancy_success(cudaOccError result) {
    if (result != CUDA_OCC_SUCCESS) {
        throw occupancy_error(result);
    }
}

namespace internal {

class device_properties_cache {
   public:
    static cudaDeviceProp& get() {
        static device_properties_cache instance;

        int dev_id;
        verify_runtime_success(cudaGetDevice(&dev_id));

        std::lock_guard<std::mutex> guard(instance.mtx_);

        if (instance.device_prop_cache_.find(dev_id) == instance.device_prop_cache_.end()) {
            instance.device_prop_cache_[dev_id] = cudaDeviceProp();
            verify_runtime_success(
                cudaGetDeviceProperties(&instance.device_prop_cache_[dev_id], dev_id));
        }
        return instance.device_prop_cache_[dev_id];
    }

   private:
    std::map<int, cudaDeviceProp> device_prop_cache_;
    std::mutex mtx_;
};

inline size_t availableSharedBytesPerBlock(size_t sharedMemPerMultiprocessor,
                                           size_t sharedSizeBytesStatic, int blocksPerSM,
                                           int smemAllocationUnit) {
    size_t bytes = __occRoundUp(sharedMemPerMultiprocessor / blocksPerSM, smemAllocationUnit) -
                   smemAllocationUnit;
    return bytes - sharedSizeBytesStatic;
}

inline MATAZURE_GENERAL uint3 pointi_to_uint3(pointi<1> p) {
    return {static_cast<unsigned int>(p[0]), 0, 0};
}

inline MATAZURE_GENERAL uint3 pointi_to_uint3(pointi<2> p) {
    return {static_cast<unsigned int>(p[0]), static_cast<unsigned int>(p[1]), 0};
}

inline MATAZURE_GENERAL uint3 pointi_to_uint3(pointi<3> p) {
    return {static_cast<unsigned int>(p[0]), static_cast<unsigned int>(p[1]),
            static_cast<unsigned int>(p[2])};
}

template <int_t _Rank>
inline MATAZURE_GENERAL pointi<_Rank> uint3_to_pointi(uint3 u);

template <>
inline MATAZURE_GENERAL pointi<1> uint3_to_pointi(uint3 u) {
    return {static_cast<int_t>(u.x)};
}

template <>
inline MATAZURE_GENERAL pointi<2> uint3_to_pointi(uint3 u) {
    return {static_cast<int_t>(u.x), static_cast<int_t>(u.y)};
}

template <>
inline MATAZURE_GENERAL pointi<3> uint3_to_pointi(uint3 u) {
    return {static_cast<int_t>(u.x), static_cast<int_t>(u.y), static_cast<int_t>(u.z)};
}

inline MATAZURE_GENERAL dim3 pointi_to_dim3(pointi<1> p) {
    return {static_cast<unsigned int>(p[0]), 1, 1};
}

inline MATAZURE_GENERAL dim3 pointi_to_dim3(pointi<2> p) {
    return {static_cast<unsigned int>(p[0]), static_cast<unsigned int>(p[1]), 1};
}

inline MATAZURE_GENERAL dim3 pointi_to_dim3(pointi<3> p) {
    return {static_cast<unsigned int>(p[0]), static_cast<unsigned int>(p[1]),
            static_cast<unsigned int>(p[2])};
}

template <int_t _Rank>
inline MATAZURE_GENERAL pointi<_Rank> dim3_to_pointi(dim3 u);

template <>
inline MATAZURE_GENERAL pointi<1> dim3_to_pointi(dim3 u) {
    return {static_cast<int_t>(u.x)};
}

template <>
inline MATAZURE_GENERAL pointi<2> dim3_to_pointi(dim3 u) {
    return {static_cast<int_t>(u.x), static_cast<int_t>(u.y)};
}

template <>
inline MATAZURE_GENERAL pointi<3> dim3_to_pointi(dim3 u) {
    return {static_cast<int_t>(u.x), static_cast<int_t>(u.y), static_cast<int_t>(u.z)};
}

}  // namespace internal

class execution_policy {
   public:
    execution_policy(
        pointi<3> grid_dim = {{0, 1, 1}}, pointi<3> block_dim = {{0, 1, 1}},
        size_t shared_mem_bytes = 0,
        std::shared_ptr<cudaStream_t> sp_stream = std::make_shared<cudaStream_t>(nullptr))
        : grid_dim_(grid_dim),
          block_dim_(block_dim),
          shared_mem_bytes_(shared_mem_bytes),
          sp_stream_(sp_stream) {
        if (*sp_stream_ == 0) {
            cudaStream_t stream;
            verify_runtime_success(cudaStreamCreate(&stream));
            sp_stream_.reset(new cudaStream_t(stream), [](cudaStream_t* p) {
                verify_runtime_success(cudaStreamSynchronize(*p));
                verify_runtime_success(cudaStreamDestroy(*p));
                delete p;
            });

            // TODO: has a bug, refactor it
            // verify_runtime_success(cudaStreamCreate(&stream_));
        }
    }

    pointi<3> grid_dim() const { return grid_dim_; }
    pointi<3> block_dim() const { return block_dim_; }
    size_t shared_mem_bytes() const { return shared_mem_bytes_; }
    cudaStream_t stream() const { return *sp_stream_; }

    void grid_dim(pointi<3> arg) { grid_dim_ = arg; }
    void block_dim(pointi<3> arg) { block_dim_ = arg; }
    void shared_mem_bytes(size_t arg) { shared_mem_bytes_ = arg; }

    void synchronize() { verify_runtime_success(cudaStreamSynchronize(stream())); }

   protected:
    pointi<3> grid_dim_ = {{0, 1, 1}};
    pointi<3> block_dim_ = {{0, 1, 1}};
    // 0 means dynamic shared memory is not used
    size_t shared_mem_bytes_ = 0;
    std::shared_ptr<cudaStream_t> sp_stream_ = nullptr;
};

class default_execution_policy : public execution_policy {};

template <typename _ExePolicy, typename _KernelFunc>
inline void configure_grid(_ExePolicy& exe_policy, _KernelFunc kernel) {
    /// do nothing
}

template <typename __KernelFunc>
inline void configure_grid(default_execution_policy& exe_policy, __KernelFunc k) {
    cudaDeviceProp* props;
    props = &internal::device_properties_cache::get();

    cudaFuncAttributes attribs;
    cudaOccDeviceProp occProp(*props);

    verify_runtime_success(cudaFuncGetAttributes(&attribs, k));
    cudaOccFuncAttributes occAttrib(attribs);

    cudaFuncCache cacheConfig;
    verify_runtime_success(cudaDeviceGetCacheConfig(&cacheConfig));
    cudaOccDeviceState occState;
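    // mirror the device's current L1/shared-memory cache preference in the occupancy query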
    occState.cacheConfig = (cudaOccCacheConfig)cacheConfig;

    int numSMs = props->multiProcessorCount;

    int bsize = 0, minGridSize = 0;
    verify_occupancy_success(cudaOccMaxPotentialOccupancyBlockSize(
        &minGridSize, &bsize, &occProp, &occAttrib, &occState, exe_policy.shared_mem_bytes()));
    exe_policy.block_dim({bsize, 1, 1});

    cudaOccResult result;
    verify_occupancy_success(cudaOccMaxActiveBlocksPerMultiprocessor(
        &result, &occProp, &occAttrib, &occState, exe_policy.block_dim()[0],
        exe_policy.shared_mem_bytes()));
    exe_policy.grid_dim({result.activeBlocksPerMultiprocessor * numSMs, 1, 1});

    int smemGranularity = 0;
    verify_occupancy_success(cudaOccSMemAllocationGranularity(&smemGranularity, &occProp));
    size_t sbytes = internal::availableSharedBytesPerBlock(
        props->sharedMemPerBlock, attribs.sharedSizeBytes,
        __occDivideRoundUp(exe_policy.grid_dim()[0], numSMs), smemGranularity);

    exe_policy.shared_mem_bytes(sbytes);
}

class for_index_execution_policy : public execution_policy {
   public:
    int_t total_size() const { return total_size_; }
    void total_size(int_t size) { total_size_ = size; }

   protected:
    int_t total_size_ = 0;
};

template <typename __KernelFunc>
inline void configure_grid(for_index_execution_policy& exe_policy, __KernelFunc k) {
    cudaDeviceProp* props;
    props = &internal::device_properties_cache::get();

    cudaFuncAttributes attribs;
    cudaOccDeviceProp occProp(*props);

    verify_runtime_success(cudaFuncGetAttributes(&attribs, k));
    cudaOccFuncAttributes occAttrib(attribs);

    cudaFuncCache cacheConfig;
    verify_runtime_success(cudaDeviceGetCacheConfig(&cacheConfig));
    cudaOccDeviceState occState;
    occState.cacheConfig = (cudaOccCacheConfig)cacheConfig;

    int numSMs = props->multiProcessorCount;

    int bsize = 0, minGridSize = 0;
    verify_occupancy_success(cudaOccMaxPotentialOccupancyBlockSize(
        &minGridSize, &bsize, &occProp, &occAttrib, &occState, exe_policy.shared_mem_bytes()));
    exe_policy.block_dim({bsize, 1, 1});

    cudaOccResult result;
    verify_occupancy_success(cudaOccMaxActiveBlocksPerMultiprocessor(
        &result, &occProp, &occAttrib, &occState, exe_policy.block_dim()[0],
        exe_policy.shared_mem_bytes()));
    exe_policy.grid_dim({result.activeBlocksPerMultiprocessor * numSMs, 1, 1});

    auto pre_block_size = exe_policy.block_dim()[0];
    auto tmp_block_size = __occDivideRoundUp(exe_policy.total_size(), exe_policy.grid_dim()[0]);
    tmp_block_size = __occRoundUp(tmp_block_size, 128);
    exe_policy.block_dim({std::min(tmp_block_size, pre_block_size), 1, 1});

    int smemGranularity = 0;
    verify_occupancy_success(cudaOccSMemAllocationGranularity(&smemGranularity, &occProp));
    size_t sbytes = internal::availableSharedBytesPerBlock(
        props->sharedMemPerBlock, attribs.sharedSizeBytes,
        __occDivideRoundUp(exe_policy.grid_dim()[0], numSMs), smemGranularity);

    exe_policy.shared_mem_bytes(sbytes);
}

template <typename Function, typename... Arguments>
__global__ void kernel(Function f, Arguments... args) {
    f(args...);
}

template <typename _ExecutionPolicy, typename _Fun, typename... _Args>
inline void launch(_ExecutionPolicy exe_policy, _Fun f, _Args... args) {
    configure_grid(exe_policy, kernel<_Fun, _Args...>);
    kernel<<<internal::pointi_to_dim3(exe_policy.grid_dim()),
             internal::pointi_to_dim3(exe_policy.block_dim()), exe_policy.shared_mem_bytes(),
             exe_policy.stream()>>>(f, args...);
    verify_runtime_success(cudaGetLastError());
}

template <typename _Fun, typename... _Args>
inline void launch(_Fun f, _Args... args) {
    default_execution_policy exe_policy;
    launch(exe_policy, f, args...);
}

template <typename _Fun>
struct linear_index_functor_kernel {
    int last;
    _Fun fun;

    __device__ void operator()() {
        for (int_t i = threadIdx.x + blockIdx.x * blockDim.x; i < last;
             i += blockDim.x * gridDim.x) {
            fun(i);
        }
    }
};

template <typename _ExecutionPolicy, typename _Fun>
inline void for_linear_index(_ExecutionPolicy policy, int_t last, _Fun fun) {
    linear_index_functor_kernel<_Fun> func{last, fun};
    launch(policy, func);
}

template <typename _ExecutionPolicy, int_t _Rank, typename _Fun>
inline void for_index(_ExecutionPolicy policy, pointi<_Rank> end, _Fun fun) {
    row_major_layout<_Rank> layout(end);
    // + 1 so that the last element is included
    auto max_size = layout.index2offset(end - pointi<_Rank>::all(1)) + 1;

    cuda::for_linear_index(policy, max_size,
                           [=] __device__(int_t i) { fun(layout.offset2index(i)); });
}

template <int_t _Rank, typename _Fun>
inline void for_index(pointi<_Rank> end, _Fun fun) {
    default_execution_policy p;
    cuda::for_index(p, end, fun);
}

}  // namespace cuda
}  // namespace matazure
#endif

namespace matazure {

template <typename _Fun>
struct function_traits : public function_traits<decltype(&_Fun::operator())> {};

/// implementation for const call operators (covers lambdas and functors)
template <typename _ReturnType, typename _ClassType, typename... _Args>
struct function_traits<_ReturnType (_ClassType::*)(_Args...) const> {
    enum { arguments_size = sizeof...(_Args) };

    typedef _ReturnType result_type;

    template <int_t _index>
    struct arguments {
        typedef typename std::tuple_element<_index, std::tuple<_Args...>>::type type;
    };
};

template <typename _Tensor>
class tensor_expression {
   public:
    typedef _Tensor tensor_type;

    const tensor_type& operator()() const { return *static_cast<const tensor_type*>(this); }

    tensor_type& operator()() { return *static_cast<tensor_type*>(this); }

   protected:
    MATAZURE_GENERAL tensor_expression() {}
    MATAZURE_GENERAL ~tensor_expression() {}
};

template <typename _Type, int_t _Rank>
class tensor : public tensor_expression<tensor<_Type, _Rank>> {
   public:
    static const int_t rank = _Rank;
    typedef _Type value_type;
    typedef value_type& reference;
    typedef value_type* pointer;

    tensor() : tensor(pointi<rank>::all(0)) {}

    explicit tensor(pointi<rank> ext, runtime rt = runtime::host)
        : shape_(ext), runtime_(rt), layout_(ext) {
        if (rt == runtime::host) {
            auto p = new value_type[layout_.size()];
            this->sp_data_ = std::shared_ptr<value_type>(p, [](value_type* p) { delete[] p; });
        } else {
            value_type* p = nullptr;
#ifdef MATAZURE_CUDA
            cuda::verify_runtime_success(cudaMalloc(&p, layout_.size() * sizeof(value_type)));
            this->sp_data_ = std::shared_ptr<value_type>(
                p, [](value_type* p) { cuda::verify_runtime_success(cudaFree(p)); });
#else
            MATAZURE_ASSERT(false, "not in cuda runtime");
#endif
        }

        this->data_ = sp_data_.get();
    }

    explicit tensor(pointi<rank> ext, runtime rt, std::shared_ptr<value_type> sp_data)
        : shape_(ext), runtime_(rt), layout_(ext), sp_data_(sp_data), data_(sp_data_.get()) {}

    template <typename _VT>
    tensor(const tensor<_VT, _Rank>& ts)
        : shape_(ts.shape()),
          runtime_(ts.runtime()),
          layout_(ts.shape()),
          sp_data_(ts.shared_data()),
          data_(ts.data()) {}

    MATAZURE_GENERAL
    std::shared_ptr<value_type> shared_data() const { return sp_data_; }

    MATAZURE_GENERAL reference operator()(const pointi<rank>& index) const {
        return (*this)[layout_.index2offset(index)];
    }

    template <typename... _Idx>
    MATAZURE_GENERAL reference operator()(_Idx... idx) const {
        return (*this)(pointi<rank>{idx...});
    }

    MATAZURE_GENERAL reference operator[](int_t i) const { return data_[i]; }

    MATAZURE_GENERAL pointi<rank> shape() const { return shape_; }

    MATAZURE_GENERAL int_t shape(int_t i) const { return shape_[i]; }

    MATAZURE_GENERAL pointi<rank> stride() const { return layout_.stride(); }

    MATAZURE_GENERAL int_t size() const { return layout_.size(); }

    MATAZURE_GENERAL pointer data() const { return data_; }

    MATAZURE_GENERAL enum runtime runtime() const { return runtime_; }

    tensor clone() {
        tensor ts_re(shape(), runtime());
        if (runtime_ == runtime::host) {
            memcpy(ts_re.data(), this->data(), sizeof(value_type) * ts_re.size());
        } else {
#ifdef MATAZURE_CUDA
            cuda::verify_runtime_success(cudaMemcpy(
                ts_re.data(), this->data(), sizeof(value_type) * ts_re.size(), cudaMemcpyDefault));
#else
            MATAZURE_ASSERT(false, "not in cuda runtime");
#endif
        }
        return ts_re;
    }

    tensor sync(enum runtime rt) {
        if (runtime() == rt) {
            return *this;
        } else {
#ifdef MATAZURE_CUDA
            tensor ts_re(shape(), rt);
            cuda::verify_runtime_success(cudaMemcpy(
                ts_re.data(), this->data(), sizeof(value_type) * ts_re.size(), cudaMemcpyDefault));
            return ts_re;
#else
            MATAZURE_ASSERT(false, "not in cuda runtime");
            return *this;
#endif
        }
    }

    MATAZURE_GENERAL ~tensor() {}

   private:
    pointi<rank> shape_;
    enum runtime runtime_;
    row_major_layout<rank> layout_;
    std::shared_ptr<value_type> sp_data_;

    pointer data_;
};

// nvcc workaround: sometimes the tensor type needs to be declared before use
using tensor1b = tensor<uint8_t, 1>;
using tensor2b = tensor<uint8_t, 2>;
using tensor3b = tensor<uint8_t, 3>;
using tensor4b = tensor<uint8_t, 4>;
using tensor1s = tensor<int16_t, 1>;
using tensor2s = tensor<int16_t, 2>;
using tensor3s = tensor<int16_t, 3>;
using tensor4s = tensor<int16_t, 4>;
using tensor1i = tensor<int_t, 1>;
using tensor2i = tensor<int_t, 2>;
using tensor3i = tensor<int_t, 3>;
using tensor4i = tensor<int_t, 4>;
using tensor1f = tensor<float, 1>;
using tensor2f = tensor<float, 2>;
using tensor3f = tensor<float, 3>;
using tensor4f = tensor<float, 4>;
using tensor1d = tensor<double, 1>;
using tensor2d = tensor<double, 2>;
using tensor3d = tensor<double, 3>;
using tensor4d = tensor<double, 4>;

struct sequence_policy {};

MATAZURE_NV_EXE_CHECK_DISABLE
template <typename _Fun>
MATAZURE_GENERAL inline void for_index(sequence_policy, pointi<1> end, _Fun fun) {
    for (int_t i = 0; i < end[0]; ++i) {
        fun(pointi<1>{{i}});
    }
}

MATAZURE_NV_EXE_CHECK_DISABLE
template <typename _Fun>
MATAZURE_GENERAL inline void for_index(sequence_policy, pointi<2> end, _Fun fun) {
    for (int_t i = 0; i < end[0]; ++i) {
        for (int_t j = 0; j < end[1]; ++j) {
            fun(pointi<2>{{i, j}});
        }
    }
}

MATAZURE_NV_EXE_CHECK_DISABLE
template <typename _Fun>
MATAZURE_GENERAL inline void for_index(sequence_policy, pointi<3> end, _Fun fun) {
    for (int_t i = 0; i < end[0]; ++i) {
        for (int_t j = 0; j < end[1]; ++j) {
            for (int_t k = 0; k < end[2]; ++k) {
                fun(pointi<3>{{i, j, k}});
            }
        }
    }
}

MATAZURE_NV_EXE_CHECK_DISABLE
template <typename _Fun>
MATAZURE_GENERAL inline void for_index(sequence_policy, pointi<4> end, _Fun fun) {
    for (int_t i = 0; i < end[0]; ++i) {
        for (int_t j = 0; j < end[1]; ++j) {
            for (int_t k = 0; k < end[2]; ++k) {
                for (int_t l = 0; l < end[3]; ++l) {
                    fun(pointi<4>{{i, j, k, l}});
                }
            }
        }
    }
}

MATAZURE_NV_EXE_CHECK_DISABLE
template <int_t _Rank, typename _Fun>
MATAZURE_GENERAL inline void for_index(pointi<_Rank> end, _Fun fun, runtime rt = runtime::host) {
    if (rt == runtime::host) {
        sequence_policy policy{};
        for_index(policy, end, fun);
    } else {
#ifdef MATAZURE_CUDA
        cuda::for_index(end, fun);
#else
        MATAZURE_ASSERT(false, "not in cuda runtime");
#endif
    }
}

template <typename _Fun>
class lambda_tensor : public tensor_expression<lambda_tensor<_Fun>> {
    typedef function_traits<_Fun> functor_traits;

   public:
    static_assert(functor_traits::arguments_size == 1,
                  "the functor should take exactly one parameter");
    typedef std::decay_t<typename functor_traits::template arguments<0>::type> index_type;
    static const int_t rank = index_type::rank;  // the functor parameter should be a point
    typedef typename functor_traits::result_type reference;
    typedef std::remove_reference_t<reference> value_type;

   public:
    lambda_tensor(const pointi<rank>& ext, _Fun fun) : shape_(ext), layout_(ext), functor_(fun) {}

    MATAZURE_GENERAL reference operator()(const pointi<rank>& idx) const { return functor_(idx); }

    template <typename... _Idx>
    MATAZURE_GENERAL reference operator()(_Idx... idx) const {
        return (*this)(pointi<rank>{idx...});
    }

    tensor<std::decay_t<value_type>, rank> persist(runtime rt = runtime::host) const {
        tensor<std::decay_t<value_type>, rank> re(this->shape(), rt);
        auto functor = this->functor_;
        for_index(
            re.shape(), [=] __general__(pointi<rank> idx) { re(idx) = functor(idx); }, rt);
        return re;
    }

    MATAZURE_GENERAL pointi<rank> shape() const { return shape_; }

    MATAZURE_GENERAL int_t shape(int_t i) const { return shape()[i]; }

    MATAZURE_GENERAL int_t size() const { return layout_.size(); }

   private:
    pointi<rank> shape_;
    row_major_layout<rank> layout_;
    _Fun functor_;
};

template <int_t _Rank, typename _Fun>
inline auto make_lambda_tensor(pointi<_Rank> extent, _Fun fun) -> lambda_tensor<_Fun> {
    static_assert(lambda_tensor<_Fun>::rank == _Rank, "the rank of _Fun is not matched with _Rank");
    return lambda_tensor<_Fun>(extent, fun);
}

#define MATAZURE_STATIC_ASSERT_DIM_MATCHED(T1, T2) \
    static_assert(T1::rank == T2::rank, "the rank is not matched")

#define MATAZURE_STATIC_ASSERT_VALUE_TYPE_MATCHED(T1, T2)                                \
    static_assert(std::is_same<typename T1::value_type, typename T2::value_type>::value, \
                  "the value type is not matched")

#define __MATAZURE_ARRAY_INDEX_TENSOR_BINARY_OPERATOR(name, op)              \
    template <typename _T1, typename _T2>                                    \
    struct name {                                                            \
       private:                                                              \
        _T1 x1_;                                                             \
        _T2 x2_;                                                             \
                                                                             \
       public:                                                               \
        MATAZURE_STATIC_ASSERT_DIM_MATCHED(_T1, _T2);                        \
        MATAZURE_STATIC_ASSERT_VALUE_TYPE_MATCHED(_T1, _T2);                 \
        MATAZURE_GENERAL name(_T1 x1, _T2 x2) : x1_(x1), x2_(x2) {}          \
                                                                             \
        MATAZURE_GENERAL auto operator()(const pointi<_T1::rank>& idx) const \
            -> decltype(this->x1_(idx) op this->x2_(idx)) {                  \
            return x1_(idx) op x2_(idx);                                     \
        }                                                                    \
    };

#define TENSOR_BINARY_OPERATOR(name, op)                                              \
    __MATAZURE_ARRAY_INDEX_TENSOR_BINARY_OPERATOR(__##name##_functor__, op)           \
    template <typename _TS1, typename _TS2>                                           \
    inline lambda_tensor<__##name##_functor__<_TS1, _TS2>> operator op(               \
        const tensor_expression<_TS1>& e_lhs, const tensor_expression<_TS2>& e_rhs) { \
        return make_lambda_tensor(e_lhs().shape(),                                    \
                                  __##name##_functor__<_TS1, _TS2>(e_lhs(), e_rhs())); \
    }

TENSOR_BINARY_OPERATOR(add, +)
TENSOR_BINARY_OPERATOR(sub, -)
TENSOR_BINARY_OPERATOR(mul, *)
TENSOR_BINARY_OPERATOR(div, /)

}  // namespace matazure

--------------------------------------------------------------------------------