├── .clang-format ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── tests ├── CMakeLists.txt ├── cmake │ └── sanitizers │ │ ├── FindASan.cmake │ │ ├── FindMSan.cmake │ │ ├── FindSanitizers.cmake │ │ ├── FindTSan.cmake │ │ ├── FindUBSan.cmake │ │ ├── asan-wrapper │ │ └── sanitize-helpers.cmake ├── test_common.h ├── test_dot_major.cpp ├── test_main.cpp ├── test_profiling.cpp ├── test_separated_1.cpp ├── test_separated_2.cpp ├── test_time_consumption.cpp ├── test_total.cpp └── timer.h └── tinyndarray.h /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | IndentWidth: 4 3 | TabWidth: 4 4 | ContinuationIndentWidth: 8 5 | UseTab: Never 6 | BreakBeforeBraces: Attach 7 | AccessModifierOffset: -4 8 | Standard: Cpp11 9 | AlignOperands: true 10 | BreakBeforeTernaryOperators: true 11 | AllowShortIfStatementsOnASingleLine: true 12 | AllowShortCaseLabelsOnASingleLine: true 13 | AllowShortLoopsOnASingleLine: true 14 | PointerAlignment: Left 15 | AlignAfterOpenBracket: Align 16 | ColumnLimit: 80 17 | AllowShortFunctionsOnASingleLine: Empty 18 | BasedOnStyle: 'Google' 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.swo 3 | *.swn 4 | *.swm 5 | *.plist 6 | 7 | build 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tests/Catch2"] 2 | path = tests/Catch2 3 | url = https://github.com/catchorg/Catch2 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 David Pilger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software 
and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TinyNdArray 2 | 3 | Single Header C++ Implementation of NumPy NdArray. 4 | 5 | I look forward to your pull-request. 6 | 7 | ## Requirement 8 | 9 | * C++14 compiler 10 | 11 | ## Sample Code 12 |
 13 | #define TINYNDARRAY_IMPLEMENTATION
 14 | #include "tinyndarray.h"
 15 | 
 16 | using tinyndarray::NdArray;
 17 | 
 18 | int main(int argc, char const* argv[]) {
 19 |     auto m1 = NdArray::Arange(12).reshape(2, 2, 3) - 2.f;
 20 |     auto m2 = NdArray::Ones(3, 1) * 10.f;
 21 |     auto m12 = m1.dot(m2);
 22 |     std::cout << m12 << std::endl;
 23 | 
 24 |     NdArray m3 = {{{-0.4f, 0.3f}, {-0.2f, 0.1f}},
 25 |                   {{-0.1f, 0.2f}, {-0.3f, 0.4f}}};
 26 |     m3 = Sin(std::move(m3));
 27 |     std::cout << m3 << std::endl;
 28 | 
 29 |     auto sum_abs = Abs(m3).sum();
 30 |     std::cout << sum_abs << std::endl;
 31 | 
 32 |     auto m4 = Where(0.f < m3, -100.f, 100.f);
 33 |     bool all_m4 = All(m4);
 34 |     bool any_m4 = Any(m4);
 35 |     std::cout << m4 << std::endl;
 36 |     std::cout << all_m4 << " " << any_m4 << std::endl;
 37 | 
 38 |     return 0;
 39 | }
 40 | 
41 | 42 | ## Quick Guide 43 | 44 | TinyNdArray supports only float arrays. 45 | 46 | In the following Python code, `dtype=float32` is omitted, and the C++ code assumes that `using namespace tinyndarray;` has been declared. 47 | 48 | For more details, please see the declarations at the top of the header file. 49 | 50 | ### Copy behavior 51 | 52 | Copying an NdArray is shallow (the copy shares its data), the same as assignment in NumPy. 53 | 54 | 55 | 56 | 57 | 58 | 59 | 65 | 71 | 72 |
Numpy (Python) TinyNdArray (C++)
 60 |     a = np.ones((2, 3))
 61 |     b = a
 62 |     b[0, 0] = -1
 63 |     print(a[0, 0])  # -1
 64 | 
 66 |     auto a = NdArray::Ones(2, 3);
 67 |     auto b = a;
 68 |     b[{0, 0}] = -1;
 69 |     std::cout << a[{0, 0}] << std::endl;  // -1
 70 | 
73 | 74 | ### Basic Constructing 75 | 76 | | **Numpy (Python)** | **TinyNdArray (C++)** | 77 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 78 | | ```a = np.array([2, 3])``` | ```NdArray a({2, 3});``` or | 79 | | | ```NdArray a(Shape{2, 3});``` | 80 | 81 | ### Float Initializer 82 | 83 | Supports up to 10 dimensions. 84 | 85 | | **Numpy (Python)** | **TinyNdArray (C++)** | 86 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 87 | | ```a = np.array([1.0, 2.0])``` | ```NdArray a = {1.f, 2.f};``` | 88 | | ```a = np.array([[1.0, 2.0]])``` | ```NdArray a = {{1.f, 2.f}};``` | 89 | | ```a = np.array([[1.0, 2.0], [3.0, 4.0]])``` | ```NdArray a = {{1.f, 2.f}, {3.f, 4.f}};``` | 90 | | ```a = np.array([[[[[[[[[[1.0, 2.0]]]]]]]]]])``` | ```NdArray a = {{{{{{{{{{1.f, 2.f}}}}}}}}}};``` | 91 | 92 | ### Static Initializer 93 | 94 | | **Numpy (Python)** | **TinyNdArray (C++)** | 95 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 96 | | ```a = np.empty((2, 3))``` | ```auto a = NdArray::Empty(2, 3);``` or | 97 | | | ```auto a = NdArray::Empty({2, 3});``` or | 98 | | | ```auto a = NdArray::Empty(Shape{2, 3});``` | 99 | | ```a = np.zeros((2, 3))``` | ```auto a = NdArray::Zeros(2, 3);``` or | 100 | | | ```auto a = NdArray::Zeros({2, 3});``` or | 101 | | | ```auto a = NdArray::Zeros(Shape{2, 3});``` | 102 | | ```a = np.ones((2, 3))``` | ```auto a = NdArray::Ones(2, 3);``` or | 103 | | | ```auto a = NdArray::Ones({2, 3});``` or | 104 | | | ```auto a = NdArray::Ones(Shape{2, 3});``` | 105 | | ```a = np.arange(10)``` | ```auto a = NdArray::Arange(10);``` | 106 | | ```a = np.arange(0, 100, 10)``` | ```auto a = NdArray::Arange(0, 100, 10);``` | 107 | 108 | ### Random Initializer 109 | 110 | | **Numpy (Python)** | **TinyNdArray (C++)** | 111 | 
|:--------------------------------------------------------:|:--------------------------------------------------------:| 112 | | ```a = np.random.uniform(size=10)``` | ```auto a = NdArray::Uniform(10);``` | 113 | | ```a = np.random.uniform(size=(2, 10))``` | ```auto a = NdArray::Uniform({2, 10});``` or | 114 | | | ```auto a = NdArray::Uniform(Shape{2, 10});``` | 115 | | ```a = np.random.uniform(low=0.0, high=1.0, size=10)``` | ```auto a = NdArray::Uniform(0.0, 1.0, {10});``` or | 116 | | | ```auto a = NdArray::Uniform(0.0, 1.0, Shape{10});``` | 117 | | ```a = np.random.normal(size=10)``` | ```auto a = NdArray::Normal(10);``` | 118 | | ```a = np.random.normal(size=(2, 10))``` | ```auto a = NdArray::Normal({2, 10});``` or | 119 | | | ```auto a = NdArray::Normal(Shape{2, 10});``` | 120 | | ```a = np.random.normal(loc=0.0, scale=1.0, size=10)``` | ```auto a = NdArray::Normal(0.0, 1.0, {10});``` or | 121 | | | ```auto a = NdArray::Normal(0.0, 1.0, Shape{10});``` | 122 | 123 | ### Random Seed 124 | 125 | | **Numpy (Python)** | **TinyNdArray (C++)** | 126 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 127 | | ```np.random.seed()``` | ```NdArray::Seed();``` | 128 | | ```np.random.seed(0)``` | ```NdArray::Seed(0);``` | 129 | 130 | ### Basic Embedded Methods 131 | 132 | | **Numpy (Python)** | **TinyNdArray (C++)** | 133 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 134 | | ```id(a)``` | ```a.id()``` | 135 | | ```a.size``` | ```a.size()``` | 136 | | ```a.shape``` | ```a.shape()``` | 137 | | ```a.ndim``` | ```a.ndim()``` | 138 | | ```a.fill(2.0)``` | ```a.fill(2.f)``` | 139 | | ```a.copy()``` | ```a.copy()``` | 140 | 141 | ### Original Embedded Methods 142 | 143 | | **Numpy (Python)** | **TinyNdArray (C++)** | 144 |
|:--------------------------------------------------------:|:--------------------------------------------------------:| 145 | | ```---``` | ```a.empty()``` | 146 | | ```---``` | ```a.data()``` | 147 | | ```---``` | ```a.begin()``` | 148 | | ```---``` | ```a.end()``` | 149 | 150 | ### Single Element Casting 151 | 152 | | **Numpy (Python)** | **TinyNdArray (C++)** | 153 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 154 | | ```float(a)``` | ```static_cast<float>(a)``` | 155 | 156 | ### Index Access 157 | 158 | | **Numpy (Python)** | **TinyNdArray (C++)** | 159 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 160 | | ```a[2, -3]``` | ```a[{2, -3}]``` or | 161 | | | ```a[Index{2, -3}]``` or | 162 | | | ```a(2, -3)``` | 163 | 164 | ### Reshape methods 165 | 166 | | **Numpy (Python)** | **TinyNdArray (C++)** | 167 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 168 | | ```a.reshape(-1, 2, 1)``` | ```a.reshape({-1, 2, 1})``` or | 169 | | | ```a.reshape(Shape{-1, 2, 1})``` or | 170 | | | ```a.reshape(-1, 2, 1)``` | 171 | | ```a.flatten()``` | ```a.flatten()``` | 172 | | ```a.ravel()``` | ```a.ravel()``` | 173 | 174 | ### Reshape functions 175 | 176 | | **Numpy (Python)** | **TinyNdArray (C++)** | 177 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 178 | | ```np.reshape(a, (-1, 2, 1))``` | ```Reshape(a, {-1, 2, 1})``` | 179 | | ```np.squeeze(a)``` | ```Squeeze(a)``` | 180 | | ```np.squeeze(a, (0, -2))``` | ```Squeeze(a, {0, -2})``` | 181 | | ```np.expand_dims(a, 1)``` | ```ExpandDims(a, 1)``` | 182 | 183 | ### Slice 184 | 185 | Slice methods create a copy of the array, not a reference.
186 | 187 | | **Numpy (Python)** | **TinyNdArray (C++)** | 188 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 189 | | ```a[1:5, -4:-1]``` | ```a.slice({{1, 5}, {-4, -1}})``` or | 190 | | | ```a.slice(SliceIndex{{1, 5}, {-4, -1}})``` or | 191 | | | ```a.slice({1, 5}, {-4, -1})``` | 192 | 193 | ### Print 194 | 195 | | **Numpy (Python)** | **TinyNdArray (C++)** | 196 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 197 | | ```print(a)``` | ```std::cout << a << std::endl;``` | 198 | 199 | ### Unary Operators 200 | | **Numpy (Python)** | **TinyNdArray (C++)** | 201 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 202 | | ```+np.ones((2, 3))``` | ```+NdArray::Ones(2, 3)``` | 203 | | ```-np.ones((2, 3))``` | ```-NdArray::Ones(2, 3)``` | 204 | 205 | ### Arithmetic Operators 206 | 207 | All operators support broadcasting.
208 | 209 | | **Numpy (Python)** | **TinyNdArray (C++)** | 210 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 211 | | ```np.ones((2, 1, 3)) + np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) + NdArray::Ones(4, 1)``` | 212 | | ```np.ones((2, 1, 3)) - np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) - NdArray::Ones(4, 1)``` | 213 | | ```np.ones((2, 1, 3)) * np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) * NdArray::Ones(4, 1)``` | 214 | | ```np.ones((2, 1, 3)) / np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) / NdArray::Ones(4, 1)``` | 215 | | ```np.ones((2, 1, 3)) + 2.0``` | ```NdArray::Ones(2, 1, 3) + 2.f``` | 216 | | ```np.ones((2, 1, 3)) - 2.0``` | ```NdArray::Ones(2, 1, 3) - 2.f``` | 217 | | ```np.ones((2, 1, 3)) * 2.0``` | ```NdArray::Ones(2, 1, 3) * 2.f``` | 218 | | ```np.ones((2, 1, 3)) / 2.0``` | ```NdArray::Ones(2, 1, 3) / 2.f``` | 219 | | ```2.0 + np.ones((2, 1, 3))``` | ```2.f + NdArray::Ones(2, 1, 3)``` | 220 | | ```2.0 - np.ones((2, 1, 3))``` | ```2.f - NdArray::Ones(2, 1, 3)``` | 221 | | ```2.0 * np.ones((2, 1, 3))``` | ```2.f * NdArray::Ones(2, 1, 3)``` | 222 | | ```2.0 / np.ones((2, 1, 3))``` | ```2.f / NdArray::Ones(2, 1, 3)``` | 223 | 224 | ### Comparison Operators 225 | 226 | All operators support broadcasting.
227 | 228 | | **Numpy (Python)** | **TinyNdArray (C++)** | 229 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 230 | | ```np.ones((2, 1, 3)) == np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) == NdArray::Ones(4, 1)``` | 231 | | ```np.ones((2, 1, 3)) != np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) != NdArray::Ones(4, 1)``` | 232 | | ```np.ones((2, 1, 3)) > np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) > NdArray::Ones(4, 1)``` | 233 | | ```np.ones((2, 1, 3)) >= np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) >= NdArray::Ones(4, 1)``` | 234 | | ```np.ones((2, 1, 3)) < np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) < NdArray::Ones(4, 1)``` | 235 | | ```np.ones((2, 1, 3)) <= np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) <= NdArray::Ones(4, 1)``` | 236 | | ```np.ones((2, 1, 3)) == 1.0``` | ```NdArray::Ones(2, 1, 3) == 1.f``` | 237 | | ```np.ones((2, 1, 3)) != 1.0``` | ```NdArray::Ones(2, 1, 3) != 1.f``` | 238 | | ```np.ones((2, 1, 3)) > 1.0``` | ```NdArray::Ones(2, 1, 3) > 1.f``` | 239 | | ```np.ones((2, 1, 3)) >= 1.0``` | ```NdArray::Ones(2, 1, 3) >= 1.f``` | 240 | | ```np.ones((2, 1, 3)) < 1.0``` | ```NdArray::Ones(2, 1, 3) < 1.f``` | 241 | | ```np.ones((2, 1, 3)) <= 1.0``` | ```NdArray::Ones(2, 1, 3) <= 1.f``` | 242 | | ```1.0 == np.ones((4, 1))``` | ```1.f == NdArray::Ones(4, 1)``` | 243 | | ```1.0 != np.ones((4, 1))``` | ```1.f != NdArray::Ones(4, 1)``` | 244 | | ```1.0 > np.ones((4, 1))``` | ```1.f > NdArray::Ones(4, 1)``` | 245 | | ```1.0 >= np.ones((4, 1))``` | ```1.f >= NdArray::Ones(4, 1)``` | 246 | | ```1.0 < np.ones((4, 1))``` | ```1.f < NdArray::Ones(4, 1)``` | 247 | | ```1.0 <= np.ones((4, 1))``` | ```1.f <= NdArray::Ones(4, 1)``` | 248 | 249 | ### Compound Assignment Operators 250 | 251 | All operators support broadcasting. However, the left-hand operand keeps its shape.
252 | 253 | | **Numpy (Python)** | **TinyNdArray (C++)** | 254 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 255 | | ```np.ones((2, 1, 3)) += np.ones(3)``` | ```NdArray::Ones(2, 1, 3) += NdArray::Ones(3)``` | 256 | | ```np.ones((2, 1, 3)) -= np.ones(3)``` | ```NdArray::Ones(2, 1, 3) -= NdArray::Ones(3)``` | 257 | | ```np.ones((2, 1, 3)) *= np.ones(3)``` | ```NdArray::Ones(2, 1, 3) *= NdArray::Ones(3)``` | 258 | | ```np.ones((2, 1, 3)) /= np.ones(3)``` | ```NdArray::Ones(2, 1, 3) /= NdArray::Ones(3)``` | 259 | | ```np.ones((2, 1, 3)) += 2.0``` | ```NdArray::Ones(2, 1, 3) += 2.f``` | 260 | | ```np.ones((2, 1, 3)) -= 2.0``` | ```NdArray::Ones(2, 1, 3) -= 2.f``` | 261 | | ```np.ones((2, 1, 3)) *= 2.0``` | ```NdArray::Ones(2, 1, 3) *= 2.f``` | 262 | | ```np.ones((2, 1, 3)) /= 2.0``` | ```NdArray::Ones(2, 1, 3) /= 2.f``` | 263 | 264 | ### Math Functions 265 | 266 | Functions that take two arguments support broadcasting.
267 | 268 | | **Numpy (Python)** | **TinyNdArray (C++)** | 269 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 270 | | ```np.abs(a)``` | ```Abs(a)``` | 271 | | ```np.sign(a)``` | ```Sign(a)``` | 272 | | ```np.ceil(a)``` | ```Ceil(a)``` | 273 | | ```np.floor(a)``` | ```Floor(a)``` | 274 | | ```np.clip(a, x_min, x_max)``` | ```Clip(a, x_min, x_max)``` | 275 | | ```np.sqrt(a)``` | ```Sqrt(a)``` | 276 | | ```np.exp(a)``` | ```Exp(a)``` | 277 | | ```np.log(a)``` | ```Log(a)``` | 278 | | ```np.square(a)``` | ```Square(a)``` | 279 | | ```np.power(a, b)``` | ```Power(a, b)``` | 280 | | ```np.power(a, 2.0)``` | ```Power(a, 2.f)``` | 281 | | ```np.power(2.0, a)``` | ```Power(2.f, a)``` | 282 | | ```np.sin(a)``` | ```Sin(a)``` | 283 | | ```np.cos(a)``` | ```Cos(a)``` | 284 | | ```np.tan(a)``` | ```Tan(a)``` | 285 | | ```np.arcsin(a)``` | ```ArcSin(a)``` | 286 | | ```np.arccos(a)``` | ```ArcCos(a)``` | 287 | | ```np.arctan(a)``` | ```ArcTan(a)``` | 288 | | ```np.arctan2(a, b)``` | ```ArcTan2(a, b)``` | 289 | | ```np.arctan2(a, 10.0)``` | ```ArcTan2(a, 10.f)``` | 290 | | ```np.arctan2(10.0, a)``` | ```ArcTan2(10.f, a)``` | 291 | 292 | ### Axis Functions 293 | 294 | | **Numpy (Python)** | **TinyNdArray (C++)** | 295 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 296 | | ```np.sum(a)``` | ```Sum(a)``` | 297 | | ```np.sum(a, axis=0)``` | ```Sum(a, {0})``` or | 298 | | | ```Sum(a, Axis{0})``` | 299 | | ```np.sum(a, axis=(0, 2))``` | ```Sum(a, {0, 2})``` or | 300 | | | ```Sum(a, Axis{0, 2})``` | 301 | | ```np.mean(a, axis=0)``` | ```Mean(a, {0})``` | 302 | | ```np.min(a, axis=0)``` | ```Min(a, {0})``` | 303 | | ```np.max(a, axis=0)``` | ```Max(a, {0})``` | 304 | 305 | ### Logical Functions 306 | 307 | | **Numpy (Python)** | **TinyNdArray (C++)** | 308 |
|:--------------------------------------------------------:|:--------------------------------------------------------:| 309 | | ```np.all(a)``` | ```All(a)``` | 310 | | ```np.all(a, axis=0)``` | ```All(a, {0})``` | 311 | | ```np.any(a)``` | ```Any(a)``` | 312 | | ```np.any(a, axis=0)``` | ```Any(a, {0})``` | 313 | | ```np.where(condition, x, y)``` | ```Where(condition, x, y)``` | 314 | 315 | ### Axis Methods 316 | 317 | | **Numpy (Python)** | **TinyNdArray (C++)** | 318 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 319 | | ```a.sum()``` | ```a.sum()``` | 320 | | ```a.sum(axis=0)``` | ```a.sum({0})``` or | 321 | | | ```a.sum(Axis{0})``` | 322 | | ```a.sum(axis=(0, 2))``` | ```a.sum({0, 2})``` or | 323 | | | ```a.sum(Axis{0, 2})``` | 324 | | ```a.mean(axis=0)``` | ```a.mean({0})``` | 325 | | ```a.min(axis=0)``` | ```a.min({0})``` | 326 | | ```a.max(axis=0)``` | ```a.max({0})``` | 327 | 328 | ### Grouping Functions 329 | 330 | | **Numpy (Python)** | **TinyNdArray (C++)** | 331 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 332 | | ```np.stack((a, b, ...), axis=0)``` | ```Stack({a, b, ...}, 0)``` | 333 | | ```np.concatenate((a, b, ...), axis=0)``` | ```Concatenate({a, b, ...}, 0)``` | 334 | | ```np.split(a, 2, axis=0)``` | ```Split(a, 2, 0)``` | 335 | | ```np.split(a, [1, 3], axis=0)``` | ```Split(a, {1, 3}, 0)``` | 336 | | | ```Separate(a, 0)``` An inverse of Stack(a, 0) | 337 | 338 | ### View Changing Functions 339 | 340 | View changing functions create a copy of the array, not a reference.
341 | 342 | | **Numpy (Python)** | **TinyNdArray (C++)** | 343 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 344 | | ```np.transpose(x)``` | ```Transpose(x)``` | 345 | | ```np.swapaxes(x, 0, 2)``` | ```Swapaxes(x, 0, 2)``` | 346 | | ```np.broadcast_to(x, (3, 2))``` | ```BroadcastTo(x, {3, 2})``` | 347 | | | ```SumTo(x, {3, 2})``` An inverse of BroadcastTo(x, {3, 2}) | 348 | 349 | ### Matrix Products 350 | 351 | All of NumPy's dimension rules are implemented. 352 | 353 | | **Numpy (Python)** | **TinyNdArray (C++)** | 354 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 355 | | ```np.dot(a, b)``` | ```Dot(a, b)``` | 356 | | ```a.dot(b)``` | ```a.dot(b)``` | 357 | | ```np.matmul(a, b)``` | ```Matmul(a, b)``` | 358 | | ```np.cross(a, b)``` | ```Cross(a, b)``` | 359 | | ```a.cross(b)``` | ```a.cross(b)``` | 360 | 361 | ### Inverse 362 | 363 | All of NumPy's dimension rules are implemented. 364 | 365 | | **Numpy (Python)** | **TinyNdArray (C++)** | 366 | |:--------------------------------------------------------:|:--------------------------------------------------------:| 367 | | ```np.linalg.inv(a)``` | ```Inv(a)``` | 368 | 369 | 370 | ## In-place Operation 371 | 372 | In NumPy, in-place and not-in-place operations are written as follows. 373 | 374 | 375 | 376 | 377 | 378 | 379 | 385 | 391 | 392 |
Numpy (Python) In-place Numpy (Python) Not in-place
380 |     a = np.ones((2, 3))
381 |     a_id = id(a)
382 |     np.exp(a, out=a)  # in-place
383 |     print(id(a) == a_id)  # True
384 | 
386 |     a = np.ones((2, 3))
387 |     a_id = id(a)
388 |     a = np.exp(a)  # not in-place
389 |     print(id(a) == a_id)  # False
390 | 
393 | 394 | In TinyNdArray, when rvalue references are passed (e.g. via `std::move`), no new array is created and the operation is performed in-place. 395 | 396 | 397 | 398 | 399 | 400 | 401 | 408 | 415 | 416 |
TinyNdArray (C++) In-place TinyNdArray (C++) Not in-place
402 |     auto a = NdArray::Ones(2, 3);
403 |     auto a_id = a.id();
 404 |     a = Exp(std::move(a));  // in-place
405 |     std::cout << (a.id() == a_id)
406 |               << std::endl;  // true
407 | 
409 |     auto a = NdArray::Ones(2, 3);
410 |     auto a_id = a.id();
 411 |     a = Exp(a);  // not in-place
412 |     std::cout << (a.id() == a_id)
413 |               << std::endl;  // false
414 | 
417 | 418 | However, even when rvalue references are passed, a new array is created if broadcasting changes the size. 419 | 420 | 421 | 422 | 423 | 424 | 425 | 435 | 445 | 446 |
TinyNdArray (C++) In-place TinyNdArray (C++) Not in-place
426 |    auto a = NdArray::Ones(2, 1, 3);
427 |    auto b = NdArray::Ones(3);
428 |    auto a_id = a.id();
 429 |    a = std::move(a) + b;  // in-place
430 |    std::cout << a.shape()
431 |              << std::endl;  // [2, 1, 3]
432 |    std::cout << (a.id() == a_id)
433 |              << std::endl;  // true
434 | 
436 |    auto a = NdArray::Ones(2, 1, 3);
 437 |    auto b = NdArray::Ones(3, 1);
 438 |    auto a_id = a.id();
 439 |    a = std::move(a) + b;  // looks like in-place
440 |    std::cout << a.shape()
441 |              << std::endl;  // [2, 3, 3]
442 |    std::cout << (a.id() == a_id)
443 |              << std::endl;  // false
444 | 
447 | 448 | ## Parallel Execution 449 | 450 | By default, most operations run in parallel using threads. 451 | 452 | To change the number of workers, use `NdArray::SetNumWorkers()`. 453 | 454 |
455 | // Default setting. Use all cores.
456 | NdArray::SetNumWorkers(-1);
457 | // Set no parallel.
458 | NdArray::SetNumWorkers(1);
459 | // Use 4 cores
460 | NdArray::SetNumWorkers(4);
461 | 
462 | 463 | ## Memory profiling 464 | 465 | When the macro `TINYNDARRAY_PROFILE_MEMORY` is defined, the memory profiler is activated. 466 | 467 | The following methods can be used to get the number of instances and the total size of allocated memory. 468 | 469 | ```NdArray::GetNumInstance()``` 470 | ```NdArray::GetTotalMemory()``` 471 | 472 | ## Macros 473 | 474 | * TINYNDARRAY_H_ONCE 475 | * TINYNDARRAY_NO_INCLUDE 476 | * TINYNDARRAY_NO_NAMESPACE 477 | * TINYNDARRAY_NO_DECLARATION 478 | * TINYNDARRAY_IMPLEMENTATION 479 | * TINYNDARRAY_PROFILE_MEMORY 480 | 481 | ## TODO 482 | 483 | * [x] Replace axis reduction function with a more efficient algorithm. 484 | * [ ] Implement more efficient algorithms. 485 | * [x] Replace slice method's recursive call with a loop for speed-up. 486 | * [x] Make parallel 487 | * [ ] Improve inverse function with LU decomposition. 488 | * [ ] Implement reference slice which does not affect the current performance. 489 | * [ ] Introduce SIMD instructions. 490 | 491 | Everything in the list above is a difficult challenge. If you have any ideas, please let me know.
492 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8.2) 2 | 3 | message(STATUS "Build tinyndarray tests") 4 | 5 | # ------------------------------------------------------------------------------ 6 | # ----------------------------------- Common ----------------------------------- 7 | # ------------------------------------------------------------------------------ 8 | project(tinyndarray_tests CXX C) 9 | 10 | set(CMAKE_CXX_STANDARD 14) # C++ 14 11 | 12 | if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") 13 | set(LINK_TYPE STATIC) 14 | else() 15 | set(LINK_TYPE SHARED) 16 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 17 | endif() 18 | 19 | # Print make commands for debug 20 | # set(CMAKE_VERBOSE_MAKEFILE 1) 21 | 22 | # Set default build type 23 | if (NOT CMAKE_BUILD_TYPE) 24 | set(CMAKE_BUILD_TYPE Release) 25 | endif() 26 | message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") 27 | 28 | # Output `compile_commands.json` 29 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 30 | 31 | # cmake modules 32 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) 33 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/sanitizers) 34 | find_package(Sanitizers) # Address sanitizer (-DSANITIZE_ADDRESS=ON) 35 | 36 | # Set output directories 37 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 38 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 39 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 40 | 41 | # Warning options 42 | if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") 43 | set(warning_options "-Wall -Wextra -Wconversion") 44 | elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") 45 | set(warning_options "-Wall -Wextra -Wcast-align -Wcast-qual \ 46 | -Wctor-dtor-privacy -Wdisabled-optimization \ 47 | -Wformat=2 -Winit-self \ 48 | -Wmissing-declarations
-Wmissing-include-dirs \ 49 | -Wold-style-cast -Woverloaded-virtual \ 50 | -Wredundant-decls -Wshadow -Wsign-conversion \ 51 | -Wsign-promo -Wno-old-style-cast\ 52 | -Wstrict-overflow=5 -Wundef -Wno-unknown-pragmas \ 53 | -Wreturn-std-move") 54 | elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") 55 | set(warning_options "--pedantic -Wall -Wextra -Wcast-align -Wcast-qual \ 56 | -Wctor-dtor-privacy -Wdisabled-optimization \ 57 | -Wformat=2 -Winit-self -Wlogical-op \ 58 | -Wmissing-declarations -Wmissing-include-dirs \ 59 | -Wnoexcept -Wold-style-cast -Woverloaded-virtual \ 60 | -Wredundant-decls -Wshadow -Wsign-conversion \ 61 | -Wsign-promo -Wstrict-null-sentinel \ 62 | -Wstrict-overflow=5 -Wundef -Wno-unknown-pragmas") 63 | elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") 64 | set(warning_options "/W4") 65 | else() 66 | message(WARNING "Unsupported compiler for warning options") 67 | message("CMAKE_CXX_COMPILER_ID is ${CMAKE_CXX_COMPILER_ID}") 68 | endif() 69 | 70 | # Utility function to setup a target (include, link, warning, sanitizer) 71 | function(setup_target target includes libs) 72 | target_include_directories(${target} PUBLIC ${includes}) 73 | target_link_libraries(${target} ${libs}) 74 | set_target_properties(${target} PROPERTIES COMPILE_FLAGS "${warning_options}") # quoted: warning_options is unset for unsupported compilers, and an unquoted empty value drops the argument 75 | add_sanitizers(${target}) 76 | endfunction(setup_target) 77 | 78 | # Utility function to setup a target simply (include, link) 79 | function(setup_target_simple target includes libs) 80 | target_include_directories(${target} PUBLIC ${includes}) 81 | target_link_libraries(${target} ${libs}) 82 | endfunction(setup_target_simple) 83 | 84 | 85 | # ------------------------------------------------------------------------------ 86 | # ----------------------------- Internal Libraries ----------------------------- 87 | # ------------------------------------------------------------------------------ 88 | 89 | find_package(Threads) 90 | set(tinyndarray_header
${CMAKE_CURRENT_SOURCE_DIR}/../tinyndarray.h) 91 | 92 | # Executable file for tests (one c++ file) 93 | add_executable(tinyndarray_tests 94 | ${tinyndarray_header} 95 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp 96 | ${CMAKE_CURRENT_SOURCE_DIR}/test_total.cpp 97 | ) 98 | setup_target(tinyndarray_tests "" "${CMAKE_THREAD_LIBS_INIT}") 99 | 100 | # Executable file for tests (separated c++ files) 101 | add_executable(tinyndarray_tests_separated 102 | ${tinyndarray_header} 103 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp 104 | ${CMAKE_CURRENT_SOURCE_DIR}/test_separated_1.cpp 105 | ${CMAKE_CURRENT_SOURCE_DIR}/test_separated_2.cpp 106 | ) 107 | setup_target(tinyndarray_tests_separated "" "${CMAKE_THREAD_LIBS_INIT}") 108 | 109 | # Executable file for tests (time consumption) 110 | add_executable(tinyndarray_tests_time_consumption 111 | ${tinyndarray_header} 112 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp 113 | ${CMAKE_CURRENT_SOURCE_DIR}/test_time_consumption.cpp 114 | ) 115 | setup_target(tinyndarray_tests_time_consumption "" "${CMAKE_THREAD_LIBS_INIT}") 116 | 117 | # Executable file for tests (profiling test) 118 | add_executable(tinyndarray_tests_profiling 119 | ${tinyndarray_header} 120 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp 121 | ${CMAKE_CURRENT_SOURCE_DIR}/test_profiling.cpp 122 | ) 123 | setup_target(tinyndarray_tests_profiling "" "${CMAKE_THREAD_LIBS_INIT}") 124 | 125 | # Executable file for tests (dot majoring test) 126 | add_executable(tinyndarray_tests_dot_major 127 | ${tinyndarray_header} 128 | ${CMAKE_CURRENT_SOURCE_DIR}/test_dot_major.cpp 129 | ) 130 | setup_target(tinyndarray_tests_dot_major "" "${CMAKE_THREAD_LIBS_INIT}") 131 | -------------------------------------------------------------------------------- /tests/cmake/sanitizers/FindASan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal
Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | # Opt-in switch for AddressSanitizer; Off by default. 25 | option(SANITIZE_ADDRESS "Enable AddressSanitizer for sanitized targets." Off) 26 | # Candidate flag sets, passed to sanitizer_check_compiler_flags below. 27 | set(FLAG_CANDIDATES 28 | # Clang 3.2+ use this version. The no-omit-frame-pointer option is optional.
29 | "-g -fsanitize=address -fno-omit-frame-pointer" 30 | "-g -fsanitize=address" 31 | 32 | # Older deprecated flag for ASan 33 | "-g -faddress-sanitizer" 34 | ) 35 | 36 | # Refuse configurations that combine ASan with TSan or MSan. 37 | if (SANITIZE_ADDRESS AND (SANITIZE_THREAD OR SANITIZE_MEMORY)) 38 | message(FATAL_ERROR "AddressSanitizer is not compatible with " 39 | "ThreadSanitizer or MemorySanitizer.") 40 | endif () 41 | 42 | # NOTE(review): sanitizer_check_compiler_flags and saitizer_add_flags are presumably defined in sanitize-helpers.cmake (not visible here). 43 | include(sanitize-helpers) 44 | 45 | if (SANITIZE_ADDRESS) 46 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "AddressSanitizer" 47 | "ASan") 48 | # Locate the asan-wrapper script on CMAKE_MODULE_PATH and hide the cache entry as advanced. 49 | find_program(ASan_WRAPPER "asan-wrapper" PATHS ${CMAKE_MODULE_PATH}) 50 | mark_as_advanced(ASan_WRAPPER) 51 | endif () 52 | # Apply the ASan flags to TARGET; a no-op unless SANITIZE_ADDRESS is ON. 53 | function (add_sanitize_address TARGET) 54 | if (NOT SANITIZE_ADDRESS) 55 | return() 56 | endif () 57 | # 'saitizer' (sic): the call must match the helper's defined name, so do not rename it here alone. 58 | saitizer_add_flags(${TARGET} "AddressSanitizer" "ASan") 59 | endfunction () 60 | -------------------------------------------------------------------------------- /tests/cmake/sanitizers/FindMSan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software.
16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | option(SANITIZE_MEMORY "Enable MemorySanitizer for sanitized targets." Off) 26 | 27 | set(FLAG_CANDIDATES 28 | "-g -fsanitize=memory" 29 | ) 30 | 31 | 32 | include(sanitize-helpers) 33 | # MSan is only supported on 64-bit Linux; elsewhere the cached option is forced Off. This runs at module scope, before any target exists, so the warnings must not mention the TARGET variable. 34 | if (SANITIZE_MEMORY) 35 | if (NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") 36 | message(WARNING "MemorySanitizer disabled because " 37 | "MemorySanitizer is supported for Linux systems only.") 38 | set(SANITIZE_MEMORY Off CACHE BOOL 39 | "Enable MemorySanitizer for sanitized targets." FORCE) 40 | elseif (NOT ${CMAKE_SIZEOF_VOID_P} EQUAL 8) 41 | message(WARNING "MemorySanitizer disabled because " 42 | "MemorySanitizer is supported for 64bit systems only.") 43 | set(SANITIZE_MEMORY Off CACHE BOOL 44 | "Enable MemorySanitizer for sanitized targets."
FORCE) 45 | else () 46 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "MemorySanitizer" 47 | "MSan") 48 | endif () 49 | endif () 50 | # Adds the detected MSan flags to TARGET; does nothing unless SANITIZE_MEMORY is on. 51 | function (add_sanitize_memory TARGET) 52 | if (NOT SANITIZE_MEMORY) 53 | return() 54 | endif () 55 | 56 | saitizer_add_flags(${TARGET} "MemorySanitizer" "MSan") 57 | endfunction () 58 | -------------------------------------------------------------------------------- /tests/cmake/sanitizers/FindSanitizers.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | # If any of the used compiler is a GNU compiler, add a second option to static 26 | # link against the sanitizers.
27 | option(SANITIZE_LINK_STATIC "Try to link static against sanitizers." Off) 28 | 29 | 30 | 31 | # Load every sanitizer module; each one defines its SANITIZE_* option and its add_sanitize_* function. 32 | set(FIND_QUIETLY_FLAG "") 33 | if (DEFINED Sanitizers_FIND_QUIETLY) 34 | set(FIND_QUIETLY_FLAG "QUIET") 35 | endif () 36 | 37 | find_package(ASan ${FIND_QUIETLY_FLAG}) 38 | find_package(TSan ${FIND_QUIETLY_FLAG}) 39 | find_package(MSan ${FIND_QUIETLY_FLAG}) 40 | find_package(UBSan ${FIND_QUIETLY_FLAG}) 41 | 42 | 43 | 44 | # Checks whether the compilers accept -fsanitize-blacklist=FILE and caches it as SanBlist_<COMPILER>_FLAGS; a relative FILE is resolved against CMAKE_CURRENT_SOURCE_DIR. 45 | function(sanitizer_add_blacklist_file FILE) 46 | if(NOT IS_ABSOLUTE ${FILE}) 47 | set(FILE "${CMAKE_CURRENT_SOURCE_DIR}/${FILE}") 48 | endif() 49 | get_filename_component(FILE "${FILE}" REALPATH) 50 | 51 | sanitizer_check_compiler_flags("-fsanitize-blacklist=${FILE}" 52 | "SanitizerBlacklist" "SanBlist") 53 | endfunction() 54 | # Adds every enabled sanitizer to each target passed as an argument. 55 | function(add_sanitizers ...) 56 | # If no sanitizer is enabled, return immediately. 57 | if (NOT (SANITIZE_ADDRESS OR SANITIZE_MEMORY OR SANITIZE_THREAD OR 58 | SANITIZE_UNDEFINED)) 59 | return() 60 | endif () 61 | 62 | foreach (TARGET ${ARGV}) 63 | # Check if this target will be compiled by exactly one compiler. Other- 64 | # wise sanitizers can't be used and a warning should be printed once. 65 | sanitizer_target_compilers(${TARGET} TARGET_COMPILER) 66 | list(LENGTH TARGET_COMPILER NUM_COMPILERS) 67 | if (NUM_COMPILERS GREATER 1) 68 | message(WARNING "Can't use any sanitizers for target ${TARGET}, " 69 | "because it will be compiled by incompatible compilers. " 70 | "Target will be compiled without sanitizers.") 71 | 72 | # If the target is compiled by no known compiler, ignore it. 73 | elseif (NUM_COMPILERS EQUAL 0) 74 | message(WARNING "Can't use any sanitizers for target ${TARGET}, " 75 | "because it uses an unknown compiler. Target will be " 76 | "compiled without sanitizers.") 77 | 78 | # Exactly one compiler: add sanitizers for target. The branches above skip 79 | # only the current target instead of returning, so later targets are still handled. 80 | else ()
81 | add_sanitize_address(${TARGET}) 82 | add_sanitize_thread(${TARGET}) 83 | add_sanitize_memory(${TARGET}) 84 | add_sanitize_undefined(${TARGET}) 85 | endif () 86 | endforeach () 87 | endfunction(add_sanitizers) 88 | -------------------------------------------------------------------------------- /tests/cmake/sanitizers/FindTSan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | option(SANITIZE_THREAD "Enable ThreadSanitizer for sanitized targets." Off) 26 | 27 | set(FLAG_CANDIDATES 28 | "-g -fsanitize=thread" 29 | ) 30 | 31 | 32 | # ThreadSanitizer is not compatible with MemorySanitizer.
33 | if (SANITIZE_THREAD AND SANITIZE_MEMORY) 34 | message(FATAL_ERROR "ThreadSanitizer is not compatible with " 35 | "MemorySanitizer.") 36 | endif () 37 | 38 | 39 | include(sanitize-helpers) 40 | # TSan is only supported on 64-bit Linux; elsewhere the cached option is forced Off. This runs at module scope, before any target exists, so the warnings must not mention the TARGET variable. 41 | if (SANITIZE_THREAD) 42 | if (NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") 43 | message(WARNING "ThreadSanitizer disabled because " 44 | "ThreadSanitizer is supported for Linux systems only.") 45 | set(SANITIZE_THREAD Off CACHE BOOL 46 | "Enable ThreadSanitizer for sanitized targets." FORCE) 47 | elseif (NOT ${CMAKE_SIZEOF_VOID_P} EQUAL 8) 48 | message(WARNING "ThreadSanitizer disabled because " 49 | "ThreadSanitizer is supported for 64bit systems only.") 50 | set(SANITIZE_THREAD Off CACHE BOOL 51 | "Enable ThreadSanitizer for sanitized targets." FORCE) 52 | else () 53 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "ThreadSanitizer" 54 | "TSan") 55 | endif () 56 | endif () 57 | # Adds the detected TSan flags to TARGET; does nothing unless SANITIZE_THREAD is on. 58 | function (add_sanitize_thread TARGET) 59 | if (NOT SANITIZE_THREAD) 60 | return() 61 | endif () 62 | 63 | saitizer_add_flags(${TARGET} "ThreadSanitizer" "TSan") 64 | endfunction () 65 | -------------------------------------------------------------------------------- /tests/cmake/sanitizers/FindUBSan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright
notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | option(SANITIZE_UNDEFINED 26 | "Enable UndefinedBehaviorSanitizer for sanitized targets." Off) 27 | 28 | set(FLAG_CANDIDATES 29 | "-g -fsanitize=undefined" 30 | ) 31 | 32 | 33 | include(sanitize-helpers) 34 | 35 | if (SANITIZE_UNDEFINED) 36 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" 37 | "UndefinedBehaviorSanitizer" "UBSan") 38 | endif () 39 | # Adds the detected UBSan flags to TARGET; does nothing unless SANITIZE_UNDEFINED is on. 40 | function (add_sanitize_undefined TARGET) 41 | if (NOT SANITIZE_UNDEFINED) 42 | return() 43 | endif () 44 | 45 | saitizer_add_flags(${TARGET} "UndefinedBehaviorSanitizer" "UBSan") 46 | endfunction () 47 | -------------------------------------------------------------------------------- /tests/cmake/sanitizers/asan-wrapper: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 6 | # 2013 Matthew Arsenault 7 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | #
furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 26 | 27 | # This script is a wrapper for AddressSanitizer. In some special cases you need 28 | # to preload AddressSanitizer to avoid error messages - e.g. if you're 29 | # preloading another library to your application. At the moment this script will 30 | # only do something, if we're running on a Linux platform. OSX might not be 31 | # affected. 32 | 33 | # "$@" and "$1" are always quoted so program paths and arguments containing spaces or glob characters are passed through unchanged. 34 | # Exit immediately, if platform is not Linux. 35 | if [ "$(uname)" != "Linux" ] 36 | then 37 | exec "$@" 38 | fi 39 | 40 | 41 | # Get the used libasan of the application ($1). If a libasan was found, it will 42 | # be prepended to LD_PRELOAD. 43 | libasan=$(ldd "$1" | grep libasan | sed "s/^[[:space:]]//" | cut -d' ' -f1) 44 | if [ -n "$libasan" ] 45 | then 46 | if [ -n "$LD_PRELOAD" ] 47 | then 48 | export LD_PRELOAD="$libasan:$LD_PRELOAD" 49 | else 50 | export LD_PRELOAD="$libasan" 51 | fi 52 | fi 53 | 54 | # Execute the application.
55 | exec "$@" 56 | -------------------------------------------------------------------------------- /tests/cmake/sanitizers/sanitize-helpers.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | # Helper function to get the language of a source file.
26 | function (sanitizer_lang_of_source FILE RETURN_VAR) # Sets RETURN_VAR to the first enabled language whose CMAKE_<LANG>_SOURCE_FILE_EXTENSIONS lists FILE's extension, or to "" if none does. 27 | get_filename_component(FILE_EXT "${FILE}" EXT) 28 | string(TOLOWER "${FILE_EXT}" FILE_EXT) 29 | string(REGEX REPLACE "^.*\\." "" FILE_EXT "${FILE_EXT}") # keep only the text after the last dot 30 | # EXT is the longest extension (".in.cpp" for "a.in.cpp") and is empty for extension-less files; string(SUBSTRING ... 1 -1) mis-handled the former and raised an error for the latter. 31 | get_property(ENABLED_LANGUAGES GLOBAL PROPERTY ENABLED_LANGUAGES) 32 | foreach (LANG ${ENABLED_LANGUAGES}) 33 | list(FIND CMAKE_${LANG}_SOURCE_FILE_EXTENSIONS "${FILE_EXT}" TEMP) 34 | if (NOT ${TEMP} EQUAL -1) 35 | set(${RETURN_VAR} "${LANG}" PARENT_SCOPE) 36 | return() 37 | endif () 38 | endforeach() 39 | 40 | set(${RETURN_VAR} "" PARENT_SCOPE) 41 | endfunction () 42 | 43 | 44 | # Helper function to get compilers used by a target. Returns (via RETURN_VAR) the de-duplicated compiler IDs of TARGET's sources; $<TARGET_OBJECTS:...> entries are skipped. 45 | function (sanitizer_target_compilers TARGET RETURN_VAR) 46 | # Check if all sources for target use the same compiler. If a target uses 47 | # e.g. C and Fortran mixed and uses different compilers (e.g. clang and 48 | # gfortran) this can trigger huge problems, because different compilers may 49 | # use different implementations for sanitizers. 50 | set(BUFFER "") 51 | get_target_property(TSOURCES ${TARGET} SOURCES) 52 | foreach (FILE ${TSOURCES}) 53 | # If expression was found, FILE is a generator-expression for an object 54 | # library. Object libraries will be ignored. 55 | string(REGEX MATCH "TARGET_OBJECTS:([^ >]+)" _file ${FILE}) 56 | if ("${_file}" STREQUAL "") 57 | sanitizer_lang_of_source(${FILE} LANG) 58 | if (LANG) 59 | list(APPEND BUFFER ${CMAKE_${LANG}_COMPILER_ID}) 60 | endif () 61 | endif () 62 | endforeach () 63 | 64 | list(REMOVE_DUPLICATES BUFFER) 65 | set(${RETURN_VAR} "${BUFFER}" PARENT_SCOPE) 66 | endfunction () 67 | 68 | 69 | # Helper function to check compiler flags for language compiler.
70 | function (sanitizer_check_compiler_flag FLAG LANG VARIABLE) 71 | if (${LANG} STREQUAL "C") 72 | include(CheckCCompilerFlag) 73 | check_c_compiler_flag("${FLAG}" ${VARIABLE}) 74 | 75 | elseif (${LANG} STREQUAL "CXX") 76 | include(CheckCXXCompilerFlag) 77 | check_cxx_compiler_flag("${FLAG}" ${VARIABLE}) 78 | 79 | elseif (${LANG} STREQUAL "Fortran") 80 | # CheckFortranCompilerFlag was introduced in CMake 3.x. To be compatible 81 | # with older Cmake versions, we will check if this module is present 82 | # before we use it. Otherwise we will report Fortran sanitizer support as 83 | # not available. 84 | include(CheckFortranCompilerFlag OPTIONAL RESULT_VARIABLE INCLUDED) 85 | if (INCLUDED) 86 | check_fortran_compiler_flag("${FLAG}" ${VARIABLE}) 87 | elseif (NOT CMAKE_REQUIRED_QUIET) 88 | message(STATUS "Performing Test ${VARIABLE}") 89 | message(STATUS "Performing Test ${VARIABLE}" 90 | " - Failed (Check not supported)") 91 | endif () 92 | endif() 93 | endfunction () 94 | 95 | 96 | # Helper function to test compiler flags. For each enabled language's compiler, tries FLAG_CANDIDATES in order and caches the first accepted one as <PREFIX>_<COMPILER>_FLAGS (an empty value plus a warning if none is accepted). 97 | function (sanitizer_check_compiler_flags FLAG_CANDIDATES NAME PREFIX) 98 | set(CMAKE_REQUIRED_QUIET ${${PREFIX}_FIND_QUIETLY}) 99 | 100 | get_property(ENABLED_LANGUAGES GLOBAL PROPERTY ENABLED_LANGUAGES) 101 | foreach (LANG ${ENABLED_LANGUAGES}) 102 | # Sanitizer flags are not dependent on language, but the used compiler. 103 | # So instead of searching flags foreach language, search flags foreach 104 | # compiler used.
105 | set(COMPILER ${CMAKE_${LANG}_COMPILER_ID}) 106 | if (NOT DEFINED ${PREFIX}_${COMPILER}_FLAGS) 107 | foreach (FLAG ${FLAG_CANDIDATES}) 108 | if(NOT CMAKE_REQUIRED_QUIET) 109 | message(STATUS "Try ${COMPILER} ${NAME} flag = [${FLAG}]") 110 | endif() 111 | 112 | set(CMAKE_REQUIRED_FLAGS "${FLAG}") 113 | unset(${PREFIX}_FLAG_DETECTED CACHE) 114 | sanitizer_check_compiler_flag("${FLAG}" ${LANG} 115 | ${PREFIX}_FLAG_DETECTED) 116 | 117 | if (${PREFIX}_FLAG_DETECTED) 118 | # If compiler is a GNU compiler, search for static flag, if 119 | # SANITIZE_LINK_STATIC is enabled. 120 | if (SANITIZE_LINK_STATIC AND (${COMPILER} STREQUAL "GNU")) 121 | string(TOLOWER ${PREFIX} PREFIX_lower) 122 | sanitizer_check_compiler_flag( 123 | "-static-lib${PREFIX_lower}" ${LANG} 124 | ${PREFIX}_STATIC_FLAG_DETECTED) 125 | 126 | if (${PREFIX}_STATIC_FLAG_DETECTED) 127 | set(FLAG "-static-lib${PREFIX_lower} ${FLAG}") 128 | endif () 129 | endif () 130 | 131 | set(${PREFIX}_${COMPILER}_FLAGS "${FLAG}" CACHE STRING 132 | "${NAME} flags for ${COMPILER} compiler.") 133 | mark_as_advanced(${PREFIX}_${COMPILER}_FLAGS) 134 | break() 135 | endif () 136 | endforeach () 137 | 138 | if (NOT ${PREFIX}_FLAG_DETECTED) 139 | set(${PREFIX}_${COMPILER}_FLAGS "" CACHE STRING 140 | "${NAME} flags for ${COMPILER} compiler.") 141 | mark_as_advanced(${PREFIX}_${COMPILER}_FLAGS) 142 | 143 | message(WARNING "${NAME} is not available for ${COMPILER} " 144 | "compiler. Targets using this compiler will be " 145 | "compiled without ${NAME}.") 146 | endif () 147 | endif () 148 | endforeach () 149 | endfunction () 150 | 151 | 152 | # Helper to assign sanitizer flags for TARGET: appends the cached <PREFIX>_<COMPILER>_FLAGS and SanBlist_<COMPILER>_FLAGS to COMPILE_FLAGS and the sanitizer flags to LINK_FLAGS. The misspelled name is kept because all add_sanitize_* functions call it. 153 | function (saitizer_add_flags TARGET NAME PREFIX) 154 | # Get list of compilers used by target and check, if sanitizer is available 155 | # for this target. Other compiler checks like check for conflicting 156 | # compilers will be done in add_sanitizers function.
157 | sanitizer_target_compilers(${TARGET} TARGET_COMPILER) 158 | list(LENGTH TARGET_COMPILER NUM_COMPILERS) 159 | if ("${${PREFIX}_${TARGET_COMPILER}_FLAGS}" STREQUAL "") 160 | return() 161 | endif() 162 | 163 | # Set compile- and link-flags for target. 164 | set_property(TARGET ${TARGET} APPEND_STRING 165 | PROPERTY COMPILE_FLAGS " ${${PREFIX}_${TARGET_COMPILER}_FLAGS}") 166 | set_property(TARGET ${TARGET} APPEND_STRING 167 | PROPERTY COMPILE_FLAGS " ${SanBlist_${TARGET_COMPILER}_FLAGS}") 168 | set_property(TARGET ${TARGET} APPEND_STRING 169 | PROPERTY LINK_FLAGS " ${${PREFIX}_${TARGET_COMPILER}_FLAGS}") 170 | endfunction () 171 | -------------------------------------------------------------------------------- /tests/test_common.h: -------------------------------------------------------------------------------- 1 | #include "Catch2/single_include/catch2/catch.hpp" 2 | 3 | #include "../tinyndarray.h" 4 | 5 | #include 6 | 7 | using namespace tinyndarray; 8 | // Prints m (using std::setprecision(precision) when precision > 0) and CHECKs that the text equals str. 9 | static void CheckNdArray(const NdArray& m, const std::string& str, 10 | int precision = -1) { 11 | std::stringstream ss; 12 | if (0 < precision) { 13 | ss << std::setprecision(precision); // honor the caller's precision instead of a hard-coded 4 14 | } 15 | ss << m; 16 | CHECK(ss.str() == str); 17 | } 18 | // The *Inplace helpers CHECK that f returns an array whose id() equals that of an rvalue input, then compare its printed form via CheckNdArray. 19 | static void CheckNdArrayInplace(NdArray&& x, const std::string& str, 20 | std::function f, 21 | int precision = -1) { 22 | uintptr_t x_id = x.id(); 23 | const NdArray& y = f(std::move(x)); 24 | CHECK(y.id() == x_id); 25 | CheckNdArray(y, str, precision); 26 | } 27 | 28 | static void CheckNdArrayInplace(NdArray&& lhs, NdArray&& rhs, 29 | const std::string& str, 30 | std::function f, 31 | int precision = -1) { 32 | uintptr_t l_id = lhs.id(); 33 | uintptr_t r_id = rhs.id(); 34 | const NdArray& ret = f(std::move(lhs), std::move(rhs)); 35 | CHECK((ret.id() == l_id || ret.id() == r_id)); 36 | CheckNdArray(ret, str, precision); 37 | } 38 | 39 | static void CheckNdArrayInplace( 40 | const NdArray& lhs, NdArray&& rhs, const std::string& str, 41 | std::function f, 42 | int
precision = -1) { 43 | uintptr_t l_id = lhs.id(); 44 | uintptr_t r_id = rhs.id(); 45 | const NdArray& ret = f(lhs, std::move(rhs)); 46 | CHECK((ret.id() == l_id || ret.id() == r_id)); 47 | CheckNdArray(ret, str, precision); 48 | } 49 | 50 | static void CheckNdArrayInplace( 51 | NdArray&& lhs, const NdArray& rhs, const std::string& str, 52 | std::function f, 53 | int precision = -1) { 54 | uintptr_t l_id = lhs.id(); 55 | uintptr_t r_id = rhs.id(); 56 | const NdArray& ret = f(std::move(lhs), rhs); 57 | CHECK((ret.id() == l_id || ret.id() == r_id)); 58 | CheckNdArray(ret, str, precision); 59 | } 60 | 61 | static void CheckNdArrayInplace(NdArray&& lhs, float rhs, 62 | const std::string& str, 63 | std::function f, 64 | int precision = -1) { 65 | uintptr_t l_id = lhs.id(); 66 | const NdArray& ret = f(std::move(lhs), rhs); 67 | CHECK((ret.id() == l_id)); 68 | CheckNdArray(ret, str, precision); 69 | } 70 | 71 | static void CheckNdArrayInplace(float lhs, NdArray&& rhs, 72 | const std::string& str, 73 | std::function f, 74 | int precision = -1) { 75 | uintptr_t r_id = rhs.id(); 76 | const NdArray& ret = f(lhs, std::move(rhs)); 77 | CHECK((ret.id() == r_id)); 78 | CheckNdArray(ret, str, precision); 79 | } 80 | 81 | static void CheckNdArrayNotInplace(NdArray&& x, const std::string& str, 82 | std::function f, 83 | int precision = -1) { 84 | uintptr_t x_id = x.id(); 85 | const NdArray& y = f(x); 86 | CHECK(y.id() != x_id); 87 | CheckNdArray(y, str, precision); 88 | } 89 | // True iff the shapes match and every element compares equal with == (so arrays containing NaN never match). 90 | static bool IsSameNdArray(const NdArray& m1, const NdArray& m2) { 91 | if (m1.shape() != m2.shape()) { 92 | return false; 93 | } 94 | auto&& data1 = m1.data(); 95 | auto&& data2 = m2.data(); 96 | for (int i = 0; i < static_cast(m1.size()); i++) { 97 | if (data1[i] != data2[i]) { 98 | return false; 99 | } 100 | } 101 | return true; 102 | } 103 | // Replaces each NaN with std::abs(NaN) and each zero (v == -0.f also holds for +0) with +0.f; presumably so printed output is sign-stable -- TODO confirm intent. 104 | static void ResolveAmbiguous(NdArray& x) { 105 | for (auto&& v : x) { 106 | if (std::isnan(v)) { 107 | v = std::abs(v); 108 | } 109 | if (v == -0.f) { 110 | v
= 0.f; 111 | } 112 | } 113 | } 114 | 115 | TEST_CASE("NdArray") { 116 | // -------------------------- Basic construction --------------------------- 117 | SECTION("Empty") { 118 | const NdArray m1; 119 | CHECK(m1.empty()); 120 | CHECK(m1.size() == 0); 121 | CHECK(m1.shape() == Shape{0}); 122 | CHECK(m1.ndim() == 1); 123 | } 124 | 125 | // --------------------------- Float initializer --------------------------- 126 | SECTION("Float initializer") { 127 | const NdArray m1 = {1.f, 2.f, 3.f}; 128 | const NdArray m2 = {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}}; 129 | const NdArray m3 = {{{1.f, 2.f}}, {{3.f, 4.f}}, {{2.f, 3.f}}}; 130 | CHECK(m1.shape() == Shape{3}); 131 | CHECK(m2.shape() == Shape{2, 3}); 132 | CHECK(m3.shape() == Shape{3, 1, 2}); 133 | CheckNdArray(m1, "[1, 2, 3]"); 134 | CheckNdArray(m2, 135 | "[[1, 2, 3],\n" 136 | " [4, 5, 6]]"); 137 | CheckNdArray(m3, 138 | "[[[1, 2]],\n" 139 | " [[3, 4]],\n" 140 | " [[2, 3]]]"); 141 | } 142 | 143 | SECTION("Float initializer invalid") { 144 | CHECK_NOTHROW(NdArray{{{1.f, 2.f}}, {{3.f, 4.f}}, {{1.f, 2.f}}}); 145 | CHECK_THROWS(NdArray{{{1, 2}}, {}}); 146 | CHECK_THROWS(NdArray{{{1.f, 2.f}}, {{3.f, 4.f}}, {{1.f, 2.f, 3.f}}}); 147 | } 148 | 149 | SECTION("Confusable initializers") { 150 | const NdArray m1 = {1.f, 2.f, 3.f}; // Float initializer 151 | const NdArray m2 = {1, 2, 3}; // Shape (int) initalizer 152 | const NdArray m3 = {{1, 2, 3}}; // Float initializer due to nest 153 | CHECK(m1.shape() == Shape{3}); 154 | CHECK(m2.shape() == Shape{1, 2, 3}); 155 | CHECK(m3.shape() == Shape{1, 3}); 156 | } 157 | 158 | // --------------------------- Static initializer -------------------------- 159 | SECTION("Empty/Ones/Zeros") { 160 | const NdArray m1({2, 5}); // Same as Empty 161 | const auto m2 = NdArray::Empty({2, 5}); 162 | const auto m3 = NdArray::Zeros({2, 5}); 163 | const auto m4 = NdArray::Ones({2, 5}); 164 | CHECK(m1.shape() == Shape{2, 5}); 165 | CHECK(m2.shape() == Shape{2, 5}); 166 | CHECK(m3.shape() == Shape{2, 5}); 167
| CHECK(m4.shape() == Shape{2, 5}); 168 | CheckNdArray(m3, 169 | "[[0, 0, 0, 0, 0],\n" 170 | " [0, 0, 0, 0, 0]]"); 171 | CheckNdArray(m4, 172 | "[[1, 1, 1, 1, 1],\n" 173 | " [1, 1, 1, 1, 1]]"); 174 | } 175 | 176 | SECTION("Empty/Ones/Zeros by template") { 177 | const NdArray m1({2, 5}); // No template support 178 | const auto m2 = NdArray::Empty(2, 5); 179 | const auto m3 = NdArray::Zeros(2, 5); 180 | const auto m4 = NdArray::Ones(2, 5); 181 | CHECK(m1.shape() == Shape{2, 5}); 182 | CHECK(m2.shape() == Shape{2, 5}); 183 | CHECK(m3.shape() == Shape{2, 5}); 184 | CHECK(m4.shape() == Shape{2, 5}); 185 | CheckNdArray(m3, 186 | "[[0, 0, 0, 0, 0],\n" 187 | " [0, 0, 0, 0, 0]]"); 188 | CheckNdArray(m4, 189 | "[[1, 1, 1, 1, 1],\n" 190 | " [1, 1, 1, 1, 1]]"); 191 | } 192 | 193 | SECTION("Arange") { 194 | const auto m1 = NdArray::Arange(10.f); 195 | const auto m2 = NdArray::Arange(0.f, 10.f, 1.f); 196 | const auto m3 = NdArray::Arange(5.f, 5.5f, 0.1f); 197 | CHECK(m1.shape() == Shape{10}); 198 | CHECK(m2.shape() == Shape{10}); 199 | CHECK(m3.shape() == Shape{5}); 200 | CheckNdArray(m1, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"); 201 | CheckNdArray(m2, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"); 202 | CheckNdArray(m3, "[5, 5.1, 5.2, 5.3, 5.4]"); 203 | } 204 | 205 | // --------------------------------- Random -------------------------------- 206 | SECTION("Random uniform") { 207 | NdArray::Seed(0); 208 | const auto m1 = NdArray::Uniform({2, 3}); 209 | NdArray::Seed(0); 210 | const auto m2 = NdArray::Uniform({2, 3}); 211 | NdArray::Seed(1); 212 | const auto m3 = NdArray::Uniform({2, 3}); 213 | CHECK(IsSameNdArray(m1, m2)); 214 | CHECK(!IsSameNdArray(m1, m3)); 215 | } 216 | 217 | SECTION("Random normal") { 218 | NdArray::Seed(0); 219 | const auto m1 = NdArray::Normal({2, 3}); 220 | NdArray::Seed(0); 221 | const auto m2 = NdArray::Normal({2, 3}); 222 | NdArray::Seed(1); 223 | const auto m3 = NdArray::Normal({2, 3}); 224 | CHECK(IsSameNdArray(m1, m2)); 225 | CHECK(!IsSameNdArray(m1, m3)); 226 | } 
227 | 228 | // ------------------------------ Basic method ----------------------------- 229 | SECTION("Basic method") { 230 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 231 | auto m2 = m1.copy(); 232 | CheckNdArray(m2, 233 | "[[0, 1, 2],\n" 234 | " [3, 4, 5]]"); 235 | m2.fill(-1); 236 | CheckNdArray(m1, 237 | "[[0, 1, 2],\n" 238 | " [3, 4, 5]]"); 239 | CheckNdArray(m2, 240 | "[[-1, -1, -1],\n" 241 | " [-1, -1, -1]]"); 242 | CHECK(m1.ndim() == 2); 243 | CHECK(m1.flatten().ndim() == 1); 244 | CHECK(m1.flatten().id() != m1.id()); // Copy 245 | CHECK(m1.ravel().ndim() == 1); 246 | CHECK(m1.ravel().id() == m1.id()); // Same instance 247 | 248 | m1.resize({2, 2}); 249 | CheckNdArray(m1, 250 | "[[0, 1],\n" 251 | " [2, 3]]"); 252 | m1.resize({2, 4}); 253 | CheckNdArray(m1, 254 | "[[0, 1, 2, 3],\n" 255 | " [0, 0, 0, 0]]"); 256 | } 257 | 258 | // ------------------------------- Begin/End ------------------------------- 259 | SECTION("Begin/End") { 260 | auto m1 = NdArray::Arange(1.f, 10.01f); 261 | // C++11 for-loop 262 | float sum1 = 0.f; 263 | for (auto&& v : m1) { 264 | sum1 += v; 265 | } 266 | CHECK(sum1 == Approx(55.f)); 267 | // std library 268 | float sum2 = std::accumulate(m1.begin(), m1.end(), 0.f); 269 | CHECK(sum2 == Approx(55.f)); 270 | } 271 | 272 | // ------------------------------- Float cast ------------------------------ 273 | SECTION("Float cast") { 274 | auto m1 = NdArray::Ones({1, 1}); 275 | auto m2 = NdArray::Ones({1, 2}); 276 | CHECK(static_cast(m1) == 1); 277 | CHECK_THROWS(static_cast(m2)); 278 | } 279 | 280 | // ------------------------------ Index access ----------------------------- 281 | SECTION("Index access by []") { 282 | auto m1 = NdArray::Arange(12.f); 283 | auto m2 = NdArray::Arange(12.f).reshape({3, 4}); 284 | auto m3 = NdArray::Arange(12.f).reshape({2, 2, -1}); 285 | m1[3] = -1.f; 286 | m1[-2] = -2.f; 287 | m2[{1, 1}] = -1.f; 288 | m2[{-1, 3}] = -2.f; 289 | m3[{1, 1, 2}] = -1.f; 290 | m3[{0, 1, -2}] = -2.f; 291 | CheckNdArray(m1, "[0, 
1, 2, -1, 4, 5, 6, 7, 8, 9, -2, 11]"); 292 | CheckNdArray(m2, 293 | "[[0, 1, 2, 3],\n" 294 | " [4, -1, 6, 7],\n" 295 | " [8, 9, 10, -2]]"); 296 | CheckNdArray(m3, 297 | "[[[0, 1, 2],\n" 298 | " [3, -2, 5]],\n" 299 | " [[6, 7, 8],\n" 300 | " [9, 10, -1]]]"); 301 | } 302 | 303 | SECTION("Index access by ()") { 304 | auto m1 = NdArray::Arange(12.f); 305 | auto m2 = NdArray::Arange(12.f).reshape({3, 4}); 306 | auto m3 = NdArray::Arange(12.f).reshape({2, 2, -1}); 307 | m1(3) = -1.f; 308 | m1(-2) = -2.f; 309 | m2(1, 1) = -1.f; 310 | m2(-1, 3) = -2.f; 311 | m3(1, 1, 2) = -1.f; 312 | m3(0, 1, -2) = -2.f; 313 | CheckNdArray(m1, "[0, 1, 2, -1, 4, 5, 6, 7, 8, 9, -2, 11]"); 314 | CheckNdArray(m2, 315 | "[[0, 1, 2, 3],\n" 316 | " [4, -1, 6, 7],\n" 317 | " [8, 9, 10, -2]]"); 318 | CheckNdArray(m3, 319 | "[[[0, 1, 2],\n" 320 | " [3, -2, 5]],\n" 321 | " [[6, 7, 8],\n" 322 | " [9, 10, -1]]]"); 323 | } 324 | 325 | SECTION("Index access by [] (const)") { 326 | const auto m1 = NdArray::Arange(12.f).reshape({1, 4, 3}); 327 | CHECK(m1[{0, 2, 1}] == 7); 328 | CHECK(m1[{0, -1, 1}] == 10); 329 | } 330 | 331 | SECTION("Index access by () (const)") { 332 | const auto m1 = NdArray::Arange(12.f).reshape({1, 4, 3}); 333 | CHECK(m1(0, 2, 1) == 7); 334 | CHECK(m1(0, -1, 1) == 10); 335 | } 336 | 337 | // -------------------------------- Reshape -------------------------------- 338 | SECTION("Reshape") { 339 | auto m1 = NdArray::Arange(12.f); 340 | auto m2 = m1.reshape({3, 4}); 341 | auto m3 = m2.reshape({2, -1}); 342 | auto m4 = m3.reshape({2, 2, -1}); 343 | CheckNdArray(m1, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]"); 344 | CheckNdArray(m2, 345 | "[[0, 1, 2, 3],\n" 346 | " [4, 5, 6, 7],\n" 347 | " [8, 9, 10, 11]]"); 348 | CheckNdArray(m3, 349 | "[[0, 1, 2, 3, 4, 5],\n" 350 | " [6, 7, 8, 9, 10, 11]]"); 351 | CheckNdArray(m4, 352 | "[[[0, 1, 2],\n" 353 | " [3, 4, 5]],\n" 354 | " [[6, 7, 8],\n" 355 | " [9, 10, 11]]]"); 356 | } 357 | 358 | SECTION("Reshape by template") { 359 | auto m1 = 
NdArray::Arange(12.f); 360 | auto m2 = m1.reshape(3, 4); 361 | auto m3 = m2.reshape(2, -1); 362 | auto m4 = m3.reshape(2, 2, -1); 363 | CheckNdArray(m1, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]"); 364 | CheckNdArray(m2, 365 | "[[0, 1, 2, 3],\n" 366 | " [4, 5, 6, 7],\n" 367 | " [8, 9, 10, 11]]"); 368 | CheckNdArray(m3, 369 | "[[0, 1, 2, 3, 4, 5],\n" 370 | " [6, 7, 8, 9, 10, 11]]"); 371 | CheckNdArray(m4, 372 | "[[[0, 1, 2],\n" 373 | " [3, 4, 5]],\n" 374 | " [[6, 7, 8],\n" 375 | " [9, 10, 11]]]"); 376 | } 377 | 378 | SECTION("Reshape with value change") { 379 | auto m1 = NdArray::Arange(12.f); 380 | auto m2 = m1.reshape({3, 4}); 381 | auto m3 = m2.reshape({2, -1}); 382 | auto m4 = m3.reshape({2, 2, -1}); 383 | m1.data()[0] = -1.f; 384 | CheckNdArray(m1, "[-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]"); 385 | CheckNdArray(m2, 386 | "[[-1, 1, 2, 3],\n" 387 | " [4, 5, 6, 7],\n" 388 | " [8, 9, 10, 11]]"); 389 | CheckNdArray(m3, 390 | "[[-1, 1, 2, 3, 4, 5],\n" 391 | " [6, 7, 8, 9, 10, 11]]"); 392 | CheckNdArray(m4, 393 | "[[[-1, 1, 2],\n" 394 | " [3, 4, 5]],\n" 395 | " [[6, 7, 8],\n" 396 | " [9, 10, 11]]]"); 397 | } 398 | 399 | SECTION("Reshape invalid") { 400 | auto m1 = NdArray::Arange(12.f); 401 | CHECK_THROWS(m1.reshape({5, 2})); 402 | CHECK_THROWS(m1.reshape({-1, -1})); 403 | } 404 | 405 | SECTION("Flatten") { 406 | auto m1 = NdArray::Arange(12.f).reshape(2, 2, 3); 407 | auto m2 = m1.flatten(); 408 | CheckNdArray(m2, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]"); 409 | } 410 | 411 | // --------------------------------- Slice --------------------------------- 412 | SECTION("Slice 2-dim") { 413 | auto m1 = NdArray::Arange(16.f).reshape(4, 4); 414 | auto m2 = m1.slice({{1, 3}, {1, 3}}); 415 | auto m3 = m1.slice({1, 3}, {0, 4}); 416 | auto m4 = m1.slice({0, 4}, {1, 3}); 417 | auto m5 = m1.slice({1, -1}, {0, 100000}); 418 | CHECK(m1.shape() == Shape{4, 4}); 419 | CHECK(m2.shape() == Shape{2, 2}); 420 | CHECK(m3.shape() == Shape{2, 4}); 421 | CHECK(m4.shape() == Shape{4, 2}); 422 | 
CHECK(m5.shape() == Shape{2, 4}); 423 | CheckNdArray(m1, 424 | "[[0, 1, 2, 3],\n" 425 | " [4, 5, 6, 7],\n" 426 | " [8, 9, 10, 11],\n" 427 | " [12, 13, 14, 15]]"); 428 | CheckNdArray(m2, 429 | "[[5, 6],\n" 430 | " [9, 10]]"); 431 | CheckNdArray(m3, 432 | "[[4, 5, 6, 7],\n" 433 | " [8, 9, 10, 11]]"); 434 | CheckNdArray(m4, 435 | "[[1, 2],\n" 436 | " [5, 6],\n" 437 | " [9, 10],\n" 438 | " [13, 14]]"); 439 | CheckNdArray(m5, 440 | "[[4, 5, 6, 7],\n" 441 | " [8, 9, 10, 11]]"); 442 | } 443 | 444 | SECTION("Slice high-dim") { 445 | auto m1 = NdArray::Arange(256.f).reshape(4, 4, 4, 4); 446 | auto m2 = m1.slice({{1, 3}, {1, 3}, {1, 3}, {1, 3}}); 447 | auto m3 = m1.slice({1, 3}, {1, 3}, {1, 3}, {1, 3}); 448 | CHECK(m1.shape() == Shape{4, 4, 4, 4}); 449 | CHECK(m2.shape() == Shape{2, 2, 2, 2}); 450 | CHECK(m3.shape() == Shape{2, 2, 2, 2}); 451 | CheckNdArray(m2, 452 | "[[[[85, 86],\n" 453 | " [89, 90]],\n" 454 | " [[101, 102],\n" 455 | " [105, 106]]],\n" 456 | " [[[149, 150],\n" 457 | " [153, 154]],\n" 458 | " [[165, 166],\n" 459 | " [169, 170]]]]"); 460 | CheckNdArray(m3, 461 | "[[[[85, 86],\n" 462 | " [89, 90]],\n" 463 | " [[101, 102],\n" 464 | " [105, 106]]],\n" 465 | " [[[149, 150],\n" 466 | " [153, 154]],\n" 467 | " [[165, 166],\n" 468 | " [169, 170]]]]"); 469 | } 470 | 471 | // ------------------------------ Dot product ------------------------------ 472 | SECTION("Dot (empty)") { 473 | // Empty array 474 | auto m1 = NdArray::Arange(0.f); 475 | CHECK_THROWS(m1.dot(m1)); 476 | } 477 | 478 | SECTION("Dot (scalar)") { 479 | // Scalar multiply 480 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 481 | auto m1_a = m1.dot(NdArray{2.f}); 482 | auto m1_b = NdArray({2.f}).dot(m1); 483 | CheckNdArray(m1_a, 484 | "[[0, 2, 4],\n" 485 | " [6, 8, 10]]"); 486 | CheckNdArray(m1_b, 487 | "[[0, 2, 4],\n" 488 | " [6, 8, 10]]"); 489 | } 490 | 491 | SECTION("Dot (1D, 1D)") { 492 | // Inner product of vectors 493 | auto m1 = NdArray::Arange(3.f); 494 | auto m2 = NdArray::Ones(3); 495 | float 
m11 = m1.dot(m1); 496 | float m12 = m1.dot(m2); 497 | CHECK(m11 == Approx(5.f)); 498 | CHECK(m12 == Approx(3.f)); 499 | // Shape mismatch 500 | auto m3 = NdArray::Arange(4.f); 501 | CHECK_THROWS(m1.dot(m3)); 502 | } 503 | 504 | SECTION("Dot (2D, 2D)") { 505 | // Inner product of 2D matrix 506 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 507 | auto m2 = NdArray::Arange(6.f).reshape(3, 2); 508 | auto m12 = m1.dot(m2); 509 | CheckNdArray(m12, 510 | "[[10, 13],\n" 511 | " [28, 40]]"); 512 | // Shape mismatch 513 | CHECK_THROWS(m1.dot(m1)); 514 | } 515 | 516 | SECTION("Dot (2D, 1D)") { 517 | // Inner product of 2D matrix and vector (2D, 1D) 518 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 519 | auto m2 = NdArray::Arange(3.f); 520 | auto m12 = m1.dot(m2); 521 | CheckNdArray(m12, "[5, 14]"); 522 | // Shape mismatch 523 | auto m3 = NdArray::Arange(2.f); 524 | CHECK_THROWS(m1.dot(m3)); 525 | } 526 | 527 | SECTION("Dot (ND, 1D)") { 528 | // Inner product of ND matrix and vector (ND, 1D) 529 | auto m1 = NdArray::Arange(12.f).reshape(2, 2, 3); 530 | auto m2 = NdArray::Arange(3.f); 531 | auto m12 = m1.dot(m2); 532 | CheckNdArray(m12, 533 | "[[5, 14],\n" 534 | " [23, 32]]"); 535 | } 536 | 537 | SECTION("Dot (ND, MD)") { 538 | // Inner product of ND matrix and MD matrix 539 | auto m1 = NdArray::Arange(12.f).reshape(2, 3, 2); 540 | auto m2 = NdArray::Arange(6.f).reshape(2, 3); 541 | auto m3 = NdArray::Arange(12.f).reshape(3, 2, 2); 542 | auto m12 = m1.dot(m2); 543 | auto m13 = m1.dot(m3); 544 | CHECK(m12.shape() == Shape{2, 3, 3}); 545 | CHECK(m13.shape() == Shape{2, 3, 3, 2}); 546 | CheckNdArray(m12, 547 | "[[[3, 4, 5],\n" 548 | " [9, 14, 19],\n" 549 | " [15, 24, 33]],\n" 550 | " [[21, 34, 47],\n" 551 | " [27, 44, 61],\n" 552 | " [33, 54, 75]]]"); 553 | CheckNdArray(m13, 554 | "[[[[2, 3],\n" 555 | " [6, 7],\n" 556 | " [10, 11]],\n" 557 | " [[6, 11],\n" 558 | " [26, 31],\n" 559 | " [46, 51]],\n" 560 | " [[10, 19],\n" 561 | " [46, 55],\n" 562 | " [82, 91]]],\n" 563 | " [[[14, 
27],\n" 564 | " [66, 79],\n" 565 | " [118, 131]],\n" 566 | " [[18, 35],\n" 567 | " [86, 103],\n" 568 | " [154, 171]],\n" 569 | " [[22, 43],\n" 570 | " [106, 127],\n" 571 | " [190, 211]]]]"); 572 | } 573 | 574 | // ----------------------------- Matmul product ---------------------------- 575 | SECTION("Matmul (2D, 2D)") { 576 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 577 | auto m2 = NdArray::Arange(6.f).reshape(3, 2); 578 | auto m12 = Matmul(m1, m2); 579 | CheckNdArray(m12, 580 | "[[10, 13],\n" 581 | " [28, 40]]"); 582 | } 583 | 584 | SECTION("Matmul (2D, 3D)") { 585 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 586 | auto m2 = NdArray::Arange(6.f).reshape(1, 3, 2); 587 | auto m12 = Matmul(m1, m2); 588 | CheckNdArray(m12, 589 | "[[[10, 13],\n" 590 | " [28, 40]]]"); 591 | } 592 | 593 | SECTION("Matmul (3D, 2D)") { 594 | auto m1 = NdArray::Arange(6.f).reshape(1, 2, 3); 595 | auto m2 = NdArray::Arange(6.f).reshape(3, 2); 596 | auto m12 = Matmul(m1, m2); 597 | CheckNdArray(m12, 598 | "[[[10, 13],\n" 599 | " [28, 40]]]"); 600 | } 601 | 602 | SECTION("Matmul (1D, 2D)") { 603 | auto m1 = NdArray::Arange(3.f); 604 | auto m2 = NdArray::Arange(6.f).reshape(3, 2); 605 | auto m12 = Matmul(m1, m2); 606 | CheckNdArray(m12, "[10, 13]"); 607 | } 608 | 609 | SECTION("Matmul (2D, 1D)") { 610 | auto m1 = NdArray::Arange(6.f).reshape(3, 2); 611 | auto m2 = NdArray::Arange(2.f); 612 | auto m12 = Matmul(m1, m2); 613 | CheckNdArray(m12, "[1, 3, 5]"); 614 | } 615 | 616 | SECTION("Matmul (1D, ND)") { 617 | auto m1 = NdArray::Arange(3.f); 618 | auto m2 = NdArray::Arange(12.f).reshape(2, 1, 3, 2); 619 | auto m12 = Matmul(m1, m2); 620 | CheckNdArray(m12, 621 | "[[[10, 13]],\n" 622 | " [[28, 31]]]"); 623 | } 624 | 625 | SECTION("Matmul (2D, ND)") { 626 | auto m1 = NdArray::Arange(12.f).reshape(2, 1, 3, 2); 627 | auto m2 = NdArray::Arange(2.f); 628 | auto m12 = Matmul(m1, m2); 629 | CheckNdArray(m12, 630 | "[[[1, 3, 5]],\n" 631 | " [[7, 9, 11]]]"); 632 | } 633 | 634 | SECTION("Matmul 
(ND, MD)") { 635 | auto m1 = NdArray::Arange(36.f).reshape(2, 3, 1, 2, 3); 636 | auto m2 = NdArray::Arange(36.f).reshape(3, 3, 4); 637 | auto m12 = Matmul(m1, m2); 638 | CHECK(m12.shape() == Shape{2, 3, 3, 2, 4}); 639 | CheckNdArray(m12, 640 | "[[[[[20, 23, 26, 29],\n" 641 | " [56, 68, 80, 92]],\n" 642 | " [[56, 59, 62, 65],\n" 643 | " [200, 212, 224, 236]],\n" 644 | " [[92, 95, 98, 101],\n" 645 | " [344, 356, 368, 380]]],\n" 646 | " [[[92, 113, 134, 155],\n" 647 | " [128, 158, 188, 218]],\n" 648 | " [[344, 365, 386, 407],\n" 649 | " [488, 518, 548, 578]],\n" 650 | " [[596, 617, 638, 659],\n" 651 | " [848, 878, 908, 938]]],\n" 652 | " [[[164, 203, 242, 281],\n" 653 | " [200, 248, 296, 344]],\n" 654 | " [[632, 671, 710, 749],\n" 655 | " [776, 824, 872, 920]],\n" 656 | " [[1100, 1139, 1178, 1217],\n" 657 | " [1352, 1400, 1448, 1496]]]],\n" 658 | " [[[[236, 293, 350, 407],\n" 659 | " [272, 338, 404, 470]],\n" 660 | " [[920, 977, 1034, 1091],\n" 661 | " [1064, 1130, 1196, 1262]],\n" 662 | " [[1604, 1661, 1718, 1775],\n" 663 | " [1856, 1922, 1988, 2054]]],\n" 664 | " [[[308, 383, 458, 533],\n" 665 | " [344, 428, 512, 596]],\n" 666 | " [[1208, 1283, 1358, 1433],\n" 667 | " [1352, 1436, 1520, 1604]],\n" 668 | " [[2108, 2183, 2258, 2333],\n" 669 | " [2360, 2444, 2528, 2612]]],\n" 670 | " [[[380, 473, 566, 659],\n" 671 | " [416, 518, 620, 722]],\n" 672 | " [[1496, 1589, 1682, 1775],\n" 673 | " [1640, 1742, 1844, 1946]],\n" 674 | " [[2612, 2705, 2798, 2891],\n" 675 | " [2864, 2966, 3068, 3170]]]]]"); 676 | } 677 | 678 | // ----------------------------- Cross product ----------------------------- 679 | SECTION("Cross (1D, 1D), (3, 3 elem)") { 680 | NdArray m1 = {1.f, 2.f, 3.f}; 681 | NdArray m2 = {4.f, 5.f, 6.f}; 682 | auto m12 = m1.cross(m2); 683 | CheckNdArray(m12, "[-3, 6, -3]"); 684 | } 685 | 686 | SECTION("Cross (1D, 1D), (3, 2 elem)") { 687 | NdArray m1 = {1.f, 2.f, 3.f}; 688 | NdArray m2 = {4.f, 5.f}; 689 | auto m12 = m1.cross(m2); 690 | CheckNdArray(m12, "[-15, 12, 
-3]"); 691 | } 692 | 693 | SECTION("Cross (1D, 1D), (2, 3 elem)") { 694 | NdArray m1 = {1.f, 2.f}; 695 | NdArray m2 = {4.f, 5.f, 6.f}; 696 | auto m12 = m1.cross(m2); 697 | CheckNdArray(m12, "[12, -6, -3]"); 698 | } 699 | 700 | SECTION("Cross (1D, 1D), (2, 2 elem)") { 701 | NdArray m1 = {1.f, 2.f}; 702 | NdArray m2 = {4.f, 5.f}; 703 | float m12 = m1.cross(m2); 704 | CHECK(m12 == Approx(-3.f)); 705 | } 706 | 707 | SECTION("Cross (1D, 1D), (mismatch)") { 708 | NdArray m1 = {1.f}; 709 | NdArray m2 = {4.f, 5.f}; 710 | NdArray m3 = {4.f, 5.f, 6.f, 7.f}; 711 | CHECK_THROWS(m1.cross(m2)); 712 | CHECK_THROWS(m2.cross(m3)); 713 | } 714 | 715 | SECTION("Cross (ND, MD), (3, 3 elem)") { 716 | auto m1 = NdArray::Arange(18.f).reshape(3, 2, 3); 717 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) + 1.f; 718 | auto m12 = m1.cross(m2); 719 | CheckNdArray(m12, 720 | "[[[-1, 2, -1],\n" 721 | " [-1, 2, -1]],\n" 722 | " [[5, -10, 5],\n" 723 | " [5, -10, 5]],\n" 724 | " [[11, -22, 11],\n" 725 | " [11, -22, 11]]]"); 726 | } 727 | 728 | SECTION("Cross (ND, MD), (3, 2 elem)") { 729 | auto m1 = NdArray::Arange(18.f).reshape(2, 3, 3); 730 | auto m2 = NdArray::Arange(6.f).reshape(3, 2) + 1.f; 731 | auto m12 = m1.cross(m2); 732 | CheckNdArray(m12, 733 | "[[[-4, 2, -1],\n" 734 | " [-20, 15, 0],\n" 735 | " [-48, 40, 1]],\n" 736 | " [[-22, 11, 8],\n" 737 | " [-56, 42, 9],\n" 738 | " [-102, 85, 10]]]"); 739 | } 740 | 741 | SECTION("Cross (ND, MD), (2, 3 elem)") { 742 | auto m1 = NdArray::Arange(12.f).reshape(3, 2, 2); 743 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) + 1.f; 744 | auto m12 = m1.cross(m2); 745 | CheckNdArray(m12, 746 | "[[[3, -0, -1],\n" 747 | " [18, -12, -2]],\n" 748 | " [[15, -12, 3],\n" 749 | " [42, -36, 2]],\n" 750 | " [[27, -24, 7],\n" 751 | " [66, -60, 6]]]"); 752 | } 753 | 754 | SECTION("Cross (ND, MD), (2, 2 elem)") { 755 | auto m1 = NdArray::Arange(12.f).reshape(3, 2, 2); 756 | auto m2 = NdArray::Arange(4.f).reshape(2, 2) + 1.f; 757 | auto m12 = m1.cross(m2); 758 | 
CheckNdArray(m12, 759 | "[[-1, -1],\n" 760 | " [3, 3],\n" 761 | " [7, 7]]"); 762 | } 763 | 764 | // ----------------------------- Axis operation ---------------------------- 765 | SECTION("Sum") { 766 | NdArray m0; 767 | auto m1 = NdArray::Arange(6.f); 768 | auto m2 = NdArray::Arange(36.f).reshape(2, 3, 2, 3); 769 | CheckNdArray(m0.sum(), "[0]"); 770 | CheckNdArray(m1.sum(), "[15]"); 771 | CheckNdArray(m2.sum({0}), 772 | "[[[18, 20, 22],\n" 773 | " [24, 26, 28]],\n" 774 | " [[30, 32, 34],\n" 775 | " [36, 38, 40]],\n" 776 | " [[42, 44, 46],\n" 777 | " [48, 50, 52]]]"); 778 | CheckNdArray(m2.sum({2}), 779 | "[[[3, 5, 7],\n" 780 | " [15, 17, 19],\n" 781 | " [27, 29, 31]],\n" 782 | " [[39, 41, 43],\n" 783 | " [51, 53, 55],\n" 784 | " [63, 65, 67]]]"); 785 | CheckNdArray(m2.sum({3}), 786 | "[[[3, 12],\n" 787 | " [21, 30],\n" 788 | " [39, 48]],\n" 789 | " [[57, 66],\n" 790 | " [75, 84],\n" 791 | " [93, 102]]]"); 792 | CheckNdArray(m2.sum({1, 2}), 793 | "[[45, 51, 57],\n" 794 | " [153, 159, 165]]"); 795 | CheckNdArray(m2.sum({1, 3}), 796 | "[[63, 90],\n" 797 | " [225, 252]]"); 798 | CheckNdArray(m2.sum({0, 1, 2, 3}), "[630]"); 799 | CheckNdArray(m1.reshape(1, 2, 3).sum({0, 1, 2}), "[15]"); 800 | } 801 | 802 | SECTION("Min") { 803 | NdArray m0; 804 | auto m1 = NdArray::Arange(6.f) - 3.f; 805 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f; 806 | CHECK_THROWS(m0.min()); 807 | CheckNdArray(m1.min(), "[-3]"); 808 | CheckNdArray(m2.min({0}), 809 | "[[-6, -5],\n" 810 | " [-4, -3],\n" 811 | " [-2, -1]]"); 812 | CheckNdArray(m2.min({2}), 813 | "[[-6, -4, -2],\n" 814 | " [0, 2, 4]]"); 815 | CheckNdArray(m2.min({2, 1}), "[-6, 0]"); 816 | } 817 | 818 | SECTION("Max") { 819 | NdArray m0; 820 | auto m1 = NdArray::Arange(6.f) - 3.f; 821 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f; 822 | CHECK_THROWS(m0.max()); 823 | CheckNdArray(m1.max(), "[2]"); 824 | CheckNdArray(m2.max({0}), 825 | "[[0, 1],\n" 826 | " [2, 3],\n" 827 | " [4, 5]]"); 828 | 
CheckNdArray(m2.max({2}), 829 | "[[-5, -3, -1],\n" 830 | " [1, 3, 5]]"); 831 | CheckNdArray(m2.max({2, 1}), "[-1, 5]"); 832 | } 833 | 834 | SECTION("Mean") { 835 | NdArray m0; 836 | auto m1 = NdArray::Arange(6.f) - 3.f; 837 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f; 838 | CheckNdArray(m0.mean(), "[nan]"); 839 | CheckNdArray(m1.mean(), "[-0.5]"); 840 | CheckNdArray(m2.mean({0}), 841 | "[[-3, -2],\n" 842 | " [-1, 0],\n" 843 | " [1, 2]]"); 844 | CheckNdArray(m2.mean({2}), 845 | "[[-5.5, -3.5, -1.5],\n" 846 | " [0.5, 2.5, 4.5]]"); 847 | CheckNdArray(m2.mean({2, 1}), "[-3.5, 2.5]"); 848 | } 849 | 850 | SECTION("Sum keepdims") { 851 | NdArray m0; 852 | auto m1 = NdArray::Arange(36.f).reshape(2, 3, 2, 3); 853 | CheckNdArray(m1.sum({1, 2}, true), 854 | "[[[[45, 51, 57]]],\n" 855 | " [[[153, 159, 165]]]]"); 856 | CheckNdArray(m1.sum({1, 3}, true), 857 | "[[[[63],\n" 858 | " [90]]],\n" 859 | " [[[225],\n" 860 | " [252]]]]"); 861 | CheckNdArray(m1.sum({0, 1, 2, 3}, true), "[[[[630]]]]"); 862 | CheckNdArray(m1.sum({}, true), "[[[[630]]]]"); 863 | } 864 | 865 | // --------------------------- Logistic operation -------------------------- 866 | SECTION("All (no axis)") { 867 | NdArray m0; 868 | CHECK(All(m0)); 869 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 870 | auto m2 = m1.copy(); 871 | CHECK(All(m1 == m2)); 872 | m2(1, 2) = -1.f; 873 | CHECK(!All(m1 == m2)); 874 | } 875 | 876 | SECTION("Any (no axis)") { 877 | NdArray m0; 878 | CHECK(!Any(m0)); 879 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 880 | auto m2 = m1.copy() + 1.f; 881 | CHECK(!Any(m1 == m2)); 882 | m2(0, 0) = 0.f; 883 | CHECK(Any(m1 == m2)); 884 | } 885 | 886 | SECTION("All (with axis)") { 887 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 888 | CheckNdArray(All(m1, {0}), "[0, 1, 1]"); 889 | CheckNdArray(All(m1, {1}), "[0, 1]"); 890 | } 891 | 892 | SECTION("Any (with axis)") { 893 | auto m1 = NdArray::Zeros(2, 3); 894 | m1(0, 0) = -1.f; 895 | CheckNdArray(Any(m1, {0}), "[1, 0, 0]"); 896 | 
CheckNdArray(Any(m1, {1}), "[1, 0]"); 897 | } 898 | 899 | SECTION("Where") { 900 | auto m1 = 2.f < NdArray::Arange(6.f).reshape(2, 3); 901 | CheckNdArray(Where(m1, NdArray::Ones(2, 1), NdArray::Arange(3)), 902 | "[[0, 1, 2],\n" 903 | " [1, 1, 1]]"); 904 | CheckNdArray(Where(m1, 1.f, NdArray::Arange(3)), 905 | "[[0, 1, 2],\n" 906 | " [1, 1, 1]]"); 907 | CheckNdArray(Where(m1, NdArray::Arange(3), 0.f), 908 | "[[0, 0, 0],\n" 909 | " [0, 1, 2]]"); 910 | CheckNdArray(Where(m1, 1.f, 0.f), 911 | "[[0, 0, 0],\n" 912 | " [1, 1, 1]]"); 913 | CheckNdArray(Where(m1, 0.f, 1.f), 914 | "[[1, 1, 1],\n" 915 | " [0, 0, 0]]"); 916 | } 917 | 918 | // ------------------------------- Operator -------------------------------- 919 | SECTION("Single +- operators") { 920 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 921 | auto m_p = +m1; 922 | auto m_n = -m1; 923 | CHECK(m_p.id() != m1.id()); 924 | CHECK(m_n.id() != m1.id()); 925 | CheckNdArray(m_p, 926 | "[[0, 1, 2],\n" 927 | " [3, 4, 5]]"); 928 | CheckNdArray(m_n, 929 | "[[-0, -1, -2],\n" 930 | " [-3, -4, -5]]"); 931 | } 932 | 933 | SECTION("Add same shape") { 934 | auto m1 = NdArray::Arange(12.f).reshape(2, 3, 2); 935 | auto m2 = NdArray::Ones({2, 3, 2}); 936 | auto m3 = m1 + m2; 937 | CHECK(m1.shape() == m2.shape()); 938 | CHECK(m1.shape() == m3.shape()); 939 | CheckNdArray(m3, 940 | "[[[1, 2],\n" 941 | " [3, 4],\n" 942 | " [5, 6]],\n" 943 | " [[7, 8],\n" 944 | " [9, 10],\n" 945 | " [11, 12]]]"); 946 | } 947 | 948 | SECTION("Add broadcast 2-dim") { 949 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 950 | auto m2 = NdArray::Arange(2.f).reshape(2, 1); 951 | auto m3 = NdArray::Arange(3.f).reshape(1, 3); 952 | auto m12 = m1 + m2; 953 | auto m13 = m1 + m3; 954 | auto m23 = m2 + m3; 955 | CHECK(m12.shape() == Shape{2, 3}); 956 | CHECK(m13.shape() == Shape{2, 3}); 957 | CHECK(m23.shape() == Shape{2, 3}); 958 | CheckNdArray(m12, 959 | "[[0, 1, 2],\n" 960 | " [4, 5, 6]]"); 961 | CheckNdArray(m13, 962 | "[[0, 2, 4],\n" 963 | " [3, 5, 7]]"); 
964 | CheckNdArray(m23, 965 | "[[0, 1, 2],\n" 966 | " [1, 2, 3]]"); 967 | } 968 | 969 | SECTION("Add broadcast high-dim") { 970 | auto m1 = NdArray::Arange(6.f).reshape(1, 2, 1, 1, 3); 971 | auto m2 = NdArray::Arange(2.f).reshape(2, 1); 972 | auto m3 = NdArray::Arange(3.f).reshape(1, 3); 973 | auto m12 = m1 + m2; 974 | auto m13 = m1 + m3; 975 | CHECK(m12.shape() == Shape{1, 2, 1, 2, 3}); 976 | CHECK(m13.shape() == Shape{1, 2, 1, 1, 3}); 977 | CheckNdArray(m12, 978 | "[[[[[0, 1, 2],\n" 979 | " [1, 2, 3]]],\n" 980 | " [[[3, 4, 5],\n" 981 | " [4, 5, 6]]]]]"); 982 | CheckNdArray(m13, 983 | "[[[[[0, 2, 4]]],\n" 984 | " [[[3, 5, 7]]]]]"); 985 | } 986 | 987 | SECTION("Add empty") { 988 | NdArray m1; 989 | auto m2 = m1 + 1.f; 990 | CHECK(m1.shape() == Shape{0}); 991 | CHECK(m2.shape() == Shape{0}); 992 | CHECK(m2.size() == 0); 993 | } 994 | 995 | SECTION("Sub/Mul/Div") { 996 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 997 | auto m2 = NdArray::Arange(3.f).reshape(3); 998 | auto m_sub = m1 - m2; 999 | auto m_mul = m1 * m2; 1000 | auto m_div = m1 / m2; 1001 | CHECK(m_sub.shape() == Shape{2, 3}); 1002 | CHECK(m_mul.shape() == Shape{2, 3}); 1003 | CHECK(m_div.shape() == Shape{2, 3}); 1004 | CheckNdArray(m_sub, 1005 | "[[0, 0, 0],\n" 1006 | " [3, 3, 3]]"); 1007 | CheckNdArray(m_mul, 1008 | "[[0, 1, 4],\n" 1009 | " [0, 4, 10]]"); 1010 | ResolveAmbiguous(m_div); // -nan -> nan 1011 | CheckNdArray(m_div, 1012 | "[[nan, 1, 1],\n" 1013 | " [inf, 4, 2.5]]"); 1014 | } 1015 | 1016 | SECTION("Arithmetic operators (NdArray, float)") { 1017 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1018 | auto m_add = m1 + 10.f; 1019 | auto m_sub = m1 - 10.f; 1020 | auto m_mul = m1 * 10.f; 1021 | auto m_div = m1 / 10.f; 1022 | CheckNdArray(m_add, 1023 | "[[10, 11, 12],\n" 1024 | " [13, 14, 15]]"); 1025 | CheckNdArray(m_sub, 1026 | "[[-10, -9, -8],\n" 1027 | " [-7, -6, -5]]"); 1028 | CheckNdArray(m_mul, 1029 | "[[0, 10, 20],\n" 1030 | " [30, 40, 50]]"); 1031 | CheckNdArray(m_div, 1032 | "[[0, 0.1, 
0.2],\n" 1033 | " [0.3, 0.4, 0.5]]"); 1034 | } 1035 | 1036 | SECTION("Arithmetic operators (float, NdArray)") { 1037 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1038 | auto m_add = 10.f + m1; 1039 | auto m_sub = 10.f - m1; 1040 | auto m_mul = 10.f * m1; 1041 | auto m_div = 10.f / m1; 1042 | CheckNdArray(m_add, 1043 | "[[10, 11, 12],\n" 1044 | " [13, 14, 15]]"); 1045 | CheckNdArray(m_sub, 1046 | "[[10, 9, 8],\n" 1047 | " [7, 6, 5]]"); 1048 | CheckNdArray(m_mul, 1049 | "[[0, 10, 20],\n" 1050 | " [30, 40, 50]]"); 1051 | CheckNdArray(m_div, 1052 | "[[inf, 10, 5],\n" 1053 | " [3.33333, 2.5, 2]]"); 1054 | } 1055 | 1056 | SECTION("Comparison operators (NdArray, NdArray)") { 1057 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1058 | auto m2 = NdArray::Arange(2.f).reshape(2, 1) + 3.f; 1059 | CheckNdArray(m1 == m2, 1060 | "[[0, 0, 0],\n" 1061 | " [0, 1, 0]]"); 1062 | CheckNdArray(m1 != m2, 1063 | "[[1, 1, 1],\n" 1064 | " [1, 0, 1]]"); 1065 | CheckNdArray(m1 > m2, 1066 | "[[0, 0, 0],\n" 1067 | " [0, 0, 1]]"); 1068 | CheckNdArray(m1 >= m2, 1069 | "[[0, 0, 0],\n" 1070 | " [0, 1, 1]]"); 1071 | CheckNdArray(m1 < m2, 1072 | "[[1, 1, 1],\n" 1073 | " [1, 0, 0]]"); 1074 | CheckNdArray(m1 <= m2, 1075 | "[[1, 1, 1],\n" 1076 | " [1, 1, 0]]"); 1077 | } 1078 | 1079 | SECTION("Comparison operators (NdArray, float)") { 1080 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1081 | CheckNdArray(m1 == 1.f, 1082 | "[[0, 1, 0],\n" 1083 | " [0, 0, 0]]"); 1084 | CheckNdArray(m1 != 1.f, 1085 | "[[1, 0, 1],\n" 1086 | " [1, 1, 1]]"); 1087 | CheckNdArray(m1 > 1.f, 1088 | "[[0, 0, 1],\n" 1089 | " [1, 1, 1]]"); 1090 | CheckNdArray(m1 >= 1.f, 1091 | "[[0, 1, 1],\n" 1092 | " [1, 1, 1]]"); 1093 | CheckNdArray(m1 < 1.f, 1094 | "[[1, 0, 0],\n" 1095 | " [0, 0, 0]]"); 1096 | CheckNdArray(m1 <= 1.f, 1097 | "[[1, 1, 0],\n" 1098 | " [0, 0, 0]]"); 1099 | } 1100 | 1101 | SECTION("Comparison operators (float, NdArray)") { 1102 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1103 | CheckNdArray(1.f == m1, 1104 | 
"[[0, 1, 0],\n" 1105 | " [0, 0, 0]]"); 1106 | CheckNdArray(1.f != m1, 1107 | "[[1, 0, 1],\n" 1108 | " [1, 1, 1]]"); 1109 | CheckNdArray(1.f > m1, 1110 | "[[1, 0, 0],\n" 1111 | " [0, 0, 0]]"); 1112 | CheckNdArray(1.f >= m1, 1113 | "[[1, 1, 0],\n" 1114 | " [0, 0, 0]]"); 1115 | CheckNdArray(1.f < m1, 1116 | "[[0, 0, 1],\n" 1117 | " [1, 1, 1]]"); 1118 | CheckNdArray(1.f <= m1, 1119 | "[[0, 1, 1],\n" 1120 | " [1, 1, 1]]"); 1121 | } 1122 | 1123 | // --------------------------- In-place Operator --------------------------- 1124 | SECTION("Arithmetic operators (NdArray, NdArray) (in-place both)") { 1125 | CheckNdArrayInplace( 1126 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f), 1127 | "[[0, 2, 4],\n" 1128 | " [3, 5, 7]]", 1129 | static_cast(operator+)); 1130 | CheckNdArrayInplace( 1131 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f), 1132 | "[[0, 0, 0],\n" 1133 | " [3, 3, 3]]", 1134 | static_cast(operator-)); 1135 | CheckNdArrayInplace( 1136 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f), 1137 | "[[0, 1, 4],\n" 1138 | " [0, 4, 10]]", 1139 | static_cast(operator*)); 1140 | CheckNdArrayInplace( 1141 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f) + 1.f, 1142 | "[[0, 0.5, 0.666667],\n" 1143 | " [3, 2, 1.66667]]", 1144 | static_cast(operator/)); 1145 | } 1146 | 1147 | SECTION("Arithmetic operators (NdArray, NdArray) (in-place right)") { 1148 | auto m1 = NdArray::Arange(3.f); 1149 | auto m2 = NdArray::Arange(3.f) + 1.f; 1150 | CheckNdArrayInplace( 1151 | m1, NdArray::Arange(6.f).reshape(2, 3), 1152 | "[[0, 2, 4],\n" 1153 | " [3, 5, 7]]", 1154 | static_cast(operator+)); 1155 | CheckNdArrayInplace( 1156 | m1, NdArray::Arange(6.f).reshape(2, 3), 1157 | "[[0, 0, 0],\n" 1158 | " [-3, -3, -3]]", 1159 | static_cast(operator-)); 1160 | CheckNdArrayInplace( 1161 | m1, NdArray::Arange(6.f).reshape(2, 3), 1162 | "[[0, 1, 4],\n" 1163 | " [0, 4, 10]]", 1164 | static_cast(operator*)); 1165 | CheckNdArrayInplace( 1166 | m2, 
NdArray::Arange(6.f).reshape(2, 3), 1167 | "[[inf, 2, 1.5],\n" 1168 | " [0.333333, 0.5, 0.6]]", 1169 | static_cast(operator/)); 1170 | } 1171 | 1172 | SECTION("Arithmetic operators (NdArray, NdArray) (in-place left)") { 1173 | auto m2 = NdArray::Arange(3.f); 1174 | auto m3 = NdArray::Arange(3.f) + 1.f; 1175 | CheckNdArrayInplace( 1176 | NdArray::Arange(6.f).reshape(2, 3), m2, 1177 | "[[0, 2, 4],\n" 1178 | " [3, 5, 7]]", 1179 | static_cast(operator+)); 1180 | CheckNdArrayInplace( 1181 | NdArray::Arange(6.f).reshape(2, 3), m2, 1182 | "[[0, 0, 0],\n" 1183 | " [3, 3, 3]]", 1184 | static_cast(operator-)); 1185 | CheckNdArrayInplace( 1186 | NdArray::Arange(6.f).reshape(2, 3), m2, 1187 | "[[0, 1, 4],\n" 1188 | " [0, 4, 10]]", 1189 | static_cast(operator*)); 1190 | CheckNdArrayInplace( 1191 | NdArray::Arange(6.f).reshape(2, 3), m3, 1192 | "[[0, 0.5, 0.666667],\n" 1193 | " [3, 2, 1.66667]]", 1194 | static_cast(operator/)); 1195 | } 1196 | 1197 | SECTION("Arithmetic operators (NdArray, float) (inplace)") { 1198 | CheckNdArrayInplace( 1199 | NdArray::Arange(3.f), 2.f, "[2, 3, 4]", 1200 | static_cast(operator+)); 1201 | CheckNdArrayInplace( 1202 | NdArray::Arange(3.f), 2.f, "[-2, -1, 0]", 1203 | static_cast(operator-)); 1204 | CheckNdArrayInplace( 1205 | NdArray::Arange(3.f), 2.f, "[0, 2, 4]", 1206 | static_cast(operator*)); 1207 | CheckNdArrayInplace( 1208 | NdArray::Arange(3.f), 2.f, "[0, 0.5, 1]", 1209 | static_cast(operator/)); 1210 | } 1211 | 1212 | SECTION("Arithmetic operators (float, NdArray) (inplace)") { 1213 | CheckNdArrayInplace( 1214 | 2.f, NdArray::Arange(3.f), "[2, 3, 4]", 1215 | static_cast(operator+)); 1216 | CheckNdArrayInplace( 1217 | 2.f, NdArray::Arange(3.f), "[2, 1, 0]", 1218 | static_cast(operator-)); 1219 | CheckNdArrayInplace( 1220 | 2.f, NdArray::Arange(3.f), "[0, 2, 4]", 1221 | static_cast(operator*)); 1222 | CheckNdArrayInplace( 1223 | 2.f, NdArray::Arange(3.f), "[inf, 2, 1]", 1224 | static_cast(operator/)); 1225 | } 1226 | 1227 | 
SECTION("Comparison operators (NdArray, NdArray) (inplace both)") { 1228 | CheckNdArrayInplace( 1229 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 0]", 1230 | static_cast(operator==)); 1231 | CheckNdArrayInplace( 1232 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 1]", 1233 | static_cast(operator!=)); 1234 | CheckNdArrayInplace( 1235 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 0, 1]", 1236 | static_cast(operator>)); 1237 | CheckNdArrayInplace( 1238 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 1]", 1239 | static_cast(operator>=)); 1240 | CheckNdArrayInplace( 1241 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 0]", 1242 | static_cast(operator<)); 1243 | CheckNdArrayInplace( 1244 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 1, 0]", 1245 | static_cast(operator<=)); 1246 | } 1247 | 1248 | SECTION("Comparison operators (NdArray, NdArray) (inplace right)") { 1249 | auto m2 = NdArray::Zeros(1) + 1.f; 1250 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[0, 1, 0]", 1251 | static_cast( 1252 | operator==)); 1253 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[1, 0, 1]", 1254 | static_cast( 1255 | operator!=)); 1256 | CheckNdArrayInplace( 1257 | NdArray::Arange(3.f), m2, "[0, 0, 1]", 1258 | static_cast(operator>)); 1259 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[0, 1, 1]", 1260 | static_cast( 1261 | operator>=)); 1262 | CheckNdArrayInplace( 1263 | NdArray::Arange(3.f), m2, "[1, 0, 0]", 1264 | static_cast(operator<)); 1265 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[1, 1, 0]", 1266 | static_cast( 1267 | operator<=)); 1268 | } 1269 | 1270 | SECTION("Comparison operators (NdArray, NdArray) (inplace left)") { 1271 | auto m1 = NdArray::Zeros(1) + 1.f; 1272 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[0, 1, 0]", 1273 | static_cast( 1274 | operator==)); 1275 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[1, 0, 1]", 1276 | static_cast( 1277 | operator!=)); 1278 | CheckNdArrayInplace( 1279 | m1, 
NdArray::Arange(3.f), "[1, 0, 0]", 1280 | static_cast(operator>)); 1281 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[1, 1, 0]", 1282 | static_cast( 1283 | operator>=)); 1284 | CheckNdArrayInplace( 1285 | m1, NdArray::Arange(3.f), "[0, 0, 1]", 1286 | static_cast(operator<)); 1287 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[0, 1, 1]", 1288 | static_cast( 1289 | operator<=)); 1290 | } 1291 | 1292 | SECTION("Comparison operators (NdArray, float) (inplace)") { 1293 | CheckNdArrayInplace( 1294 | NdArray::Arange(3.f), 1.f, "[0, 1, 0]", 1295 | static_cast(operator==)); 1296 | CheckNdArrayInplace( 1297 | NdArray::Arange(3.f), 1.f, "[1, 0, 1]", 1298 | static_cast(operator!=)); 1299 | CheckNdArrayInplace( 1300 | NdArray::Arange(3.f), 1.f, "[0, 0, 1]", 1301 | static_cast(operator>)); 1302 | CheckNdArrayInplace( 1303 | NdArray::Arange(3.f), 1.f, "[0, 1, 1]", 1304 | static_cast(operator>=)); 1305 | CheckNdArrayInplace( 1306 | NdArray::Arange(3.f), 1.f, "[1, 0, 0]", 1307 | static_cast(operator<)); 1308 | CheckNdArrayInplace( 1309 | NdArray::Arange(3.f), 1.f, "[1, 1, 0]", 1310 | static_cast(operator<=)); 1311 | } 1312 | 1313 | SECTION("Comparison operators (NdArray, float) (inplace)") { 1314 | CheckNdArrayInplace( 1315 | 1.f, NdArray::Arange(3.f), "[0, 1, 0]", 1316 | static_cast(operator==)); 1317 | CheckNdArrayInplace( 1318 | 1.f, NdArray::Arange(3.f), "[1, 0, 1]", 1319 | static_cast(operator!=)); 1320 | CheckNdArrayInplace( 1321 | 1.f, NdArray::Arange(3.f), "[1, 0, 0]", 1322 | static_cast(operator>)); 1323 | CheckNdArrayInplace( 1324 | 1.f, NdArray::Arange(3.f), "[1, 1, 0]", 1325 | static_cast(operator>=)); 1326 | CheckNdArrayInplace( 1327 | 1.f, NdArray::Arange(3.f), "[0, 0, 1]", 1328 | static_cast(operator<)); 1329 | CheckNdArrayInplace( 1330 | 1.f, NdArray::Arange(3.f), "[0, 1, 1]", 1331 | static_cast(operator<=)); 1332 | } 1333 | 1334 | SECTION("Compound assignment operators (NdArray, NdArray) (in-plafce, &)") { 1335 | auto m0 = NdArray::Arange(6.f).reshape(2, 3); 
1336 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1337 | auto m2 = NdArray::Arange(6.f).reshape(2, 3); 1338 | auto m3 = NdArray::Arange(6.f).reshape(2, 3); 1339 | auto m4 = NdArray::Arange(6.f).reshape(2, 3); 1340 | auto m5 = NdArray::Arange(3.f); 1341 | auto m0_id = m0.id(); 1342 | auto m1_id = m1.id(); 1343 | auto m2_id = m2.id(); 1344 | auto m3_id = m3.id(); 1345 | auto m4_id = m4.id(); 1346 | m0 += m0; 1347 | m1 += NdArray::Arange(3.f); 1348 | m2 -= NdArray::Arange(3.f); 1349 | m3 *= NdArray::Arange(3.f); 1350 | m4 /= NdArray::Arange(3.f); 1351 | CheckNdArray(m0, 1352 | "[[0, 2, 4],\n" 1353 | " [6, 8, 10]]"); 1354 | CheckNdArray(m1, 1355 | "[[0, 2, 4],\n" 1356 | " [3, 5, 7]]"); 1357 | CheckNdArray(m2, 1358 | "[[0, 0, 0],\n" 1359 | " [3, 3, 3]]"); 1360 | CheckNdArray(m3, 1361 | "[[0, 1, 4],\n" 1362 | " [0, 4, 10]]"); 1363 | ResolveAmbiguous(m4); // -nan -> nan 1364 | CheckNdArray(m4, 1365 | "[[nan, 1, 1],\n" 1366 | " [inf, 4, 2.5]]"); 1367 | CHECK(m0.id() == m0_id); // in-place 1368 | CHECK(m1.id() == m1_id); 1369 | CHECK(m2.id() == m2_id); 1370 | CHECK(m3.id() == m3_id); 1371 | CHECK(m4.id() == m4_id); 1372 | // size change is not allowed 1373 | CHECK_THROWS(m5 += m0); 1374 | auto m6 = m0.reshape(2, 1, 3); 1375 | CHECK_THROWS(m6 *= m5.reshape(3, 1)); 1376 | } 1377 | 1378 | SECTION("Compound assignment operators (NdArray, float) (in-plafce, &)") { 1379 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1380 | auto m2 = NdArray::Arange(6.f).reshape(2, 3); 1381 | auto m3 = NdArray::Arange(6.f).reshape(2, 3); 1382 | auto m4 = NdArray::Arange(6.f).reshape(2, 3); 1383 | auto m1_id = m1.id(); 1384 | auto m2_id = m2.id(); 1385 | auto m3_id = m3.id(); 1386 | auto m4_id = m4.id(); 1387 | m1 += 10.f; 1388 | m2 -= 10.f; 1389 | m3 *= 10.f; 1390 | m4 /= 10.f; 1391 | CheckNdArray(m1, 1392 | "[[10, 11, 12],\n" 1393 | " [13, 14, 15]]"); 1394 | CheckNdArray(m2, 1395 | "[[-10, -9, -8],\n" 1396 | " [-7, -6, -5]]"); 1397 | CheckNdArray(m3, 1398 | "[[0, 10, 20],\n" 1399 | " [30, 40, 
50]]"); 1400 | CheckNdArray(m4, 1401 | "[[0, 0.1, 0.2],\n" 1402 | " [0.3, 0.4, 0.5]]"); 1403 | CHECK(m1.id() == m1_id); 1404 | CHECK(m2.id() == m2_id); 1405 | CHECK(m3.id() == m3_id); 1406 | CHECK(m4.id() == m4_id); 1407 | } 1408 | 1409 | SECTION("Compound assignment operators (in-plafce, &&)") { 1410 | // (NdArray, NdArray) 1411 | auto m1 = NdArray::Arange(6.f).reshape(2, 3) += NdArray::Arange(3.f); 1412 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) -= NdArray::Arange(3.f); 1413 | auto m3 = NdArray::Arange(6.f).reshape(2, 3) *= NdArray::Arange(3.f); 1414 | auto m4 = NdArray::Arange(6.f).reshape(2, 3) /= NdArray::Arange(3.f); 1415 | CheckNdArray(m1, 1416 | "[[0, 2, 4],\n" 1417 | " [3, 5, 7]]"); 1418 | CheckNdArray(m2, 1419 | "[[0, 0, 0],\n" 1420 | " [3, 3, 3]]"); 1421 | CheckNdArray(m3, 1422 | "[[0, 1, 4],\n" 1423 | " [0, 4, 10]]"); 1424 | ResolveAmbiguous(m4); // -nan -> nan 1425 | CheckNdArray(m4, 1426 | "[[nan, 1, 1],\n" 1427 | " [inf, 4, 2.5]]"); 1428 | // (NdArray, float) 1429 | auto m5 = NdArray::Arange(6.f).reshape(2, 3) += 10.f; 1430 | auto m6 = NdArray::Arange(6.f).reshape(2, 3) -= 10.f; 1431 | auto m7 = NdArray::Arange(6.f).reshape(2, 3) *= 10.f; 1432 | auto m8 = NdArray::Arange(6.f).reshape(2, 3) /= 10.f; 1433 | CheckNdArray(m5, 1434 | "[[10, 11, 12],\n" 1435 | " [13, 14, 15]]"); 1436 | CheckNdArray(m6, 1437 | "[[-10, -9, -8],\n" 1438 | " [-7, -6, -5]]"); 1439 | CheckNdArray(m7, 1440 | "[[0, 10, 20],\n" 1441 | " [30, 40, 50]]"); 1442 | CheckNdArray(m8, 1443 | "[[0, 0.1, 0.2],\n" 1444 | " [0.3, 0.4, 0.5]]"); 1445 | } 1446 | 1447 | // --------------------------- Operator function --------------------------- 1448 | SECTION("Single") { 1449 | auto m1 = NdArray::Arange(3.f); 1450 | auto m2 = Positive(m1); 1451 | auto m3 = Negative(m1); 1452 | m1[0] = -1.f; 1453 | CheckNdArray(m1, "[-1, 1, 2]"); 1454 | CheckNdArray(m2, "[0, 1, 2]"); 1455 | CheckNdArray(m3, "[-0, -1, -2]"); 1456 | } 1457 | 1458 | SECTION("Function Arithmetic (NdArray, NdArray)") { 1459 | auto 
m1 = NdArray::Arange(6.f).reshape(2, 3); 1460 | auto m2 = NdArray::Arange(3.f); 1461 | auto m_add = Add(m1, m2); 1462 | auto m_sub = Subtract(m1, m2); 1463 | auto m_mul = Multiply(m1, m2); 1464 | auto m_div = Divide(m1, m2); 1465 | CheckNdArray(m_add, 1466 | "[[0, 2, 4],\n" 1467 | " [3, 5, 7]]"); 1468 | CheckNdArray(m_sub, 1469 | "[[0, 0, 0],\n" 1470 | " [3, 3, 3]]"); 1471 | CheckNdArray(m_mul, 1472 | "[[0, 1, 4],\n" 1473 | " [0, 4, 10]]"); 1474 | ResolveAmbiguous(m_div); // -nan -> nan 1475 | CheckNdArray(m_div, 1476 | "[[nan, 1, 1],\n" 1477 | " [inf, 4, 2.5]]"); 1478 | } 1479 | 1480 | SECTION("Function Arithmetic (NdArray, float)") { 1481 | auto m1 = NdArray::Arange(3.f); 1482 | auto m_add = Add(m1, 2.f); 1483 | auto m_sub = Subtract(m1, 2.f); 1484 | auto m_mul = Multiply(m1, 2.f); 1485 | auto m_div = Divide(m1, 2.f); 1486 | CheckNdArray(m_add, "[2, 3, 4]"); 1487 | CheckNdArray(m_sub, "[-2, -1, 0]"); 1488 | CheckNdArray(m_mul, "[0, 2, 4]"); 1489 | CheckNdArray(m_div, "[0, 0.5, 1]"); 1490 | } 1491 | 1492 | SECTION("Function Arithmetic (float, NdArray)") { 1493 | auto m1 = NdArray::Arange(3.f); 1494 | auto m_add = Add(2.f, m1); 1495 | auto m_sub = Subtract(2.f, m1); 1496 | auto m_mul = Multiply(2.f, m1); 1497 | auto m_div = Divide(2.f, m1); 1498 | CheckNdArray(m_add, "[2, 3, 4]"); 1499 | CheckNdArray(m_sub, "[2, 1, 0]"); 1500 | CheckNdArray(m_mul, "[0, 2, 4]"); 1501 | CheckNdArray(m_div, "[inf, 2, 1]"); 1502 | } 1503 | 1504 | SECTION("Function Comparison") { 1505 | auto m1 = NdArray::Arange(3.f); 1506 | auto m2 = NdArray::Zeros(1) + 1.f; 1507 | CheckNdArray(Equal(m1, m2), "[0, 1, 0]"); 1508 | CheckNdArray(NotEqual(m1, m2), "[1, 0, 1]"); 1509 | CheckNdArray(Greater(m1, m2), "[0, 0, 1]"); 1510 | CheckNdArray(GreaterEqual(m1, m2), "[0, 1, 1]"); 1511 | CheckNdArray(Less(m1, m2), "[1, 0, 0]"); 1512 | CheckNdArray(LessEqual(m1, m2), "[1, 1, 0]"); 1513 | CheckNdArray(Equal(m1, 1.f), "[0, 1, 0]"); 1514 | CheckNdArray(NotEqual(m1, 1.f), "[1, 0, 1]"); 1515 | 
CheckNdArray(Greater(m1, 1.f), "[0, 0, 1]"); 1516 | CheckNdArray(GreaterEqual(m1, 1.f), "[0, 1, 1]"); 1517 | CheckNdArray(Less(m1, 1.f), "[1, 0, 0]"); 1518 | CheckNdArray(LessEqual(m1, 1.f), "[1, 1, 0]"); 1519 | CheckNdArray(Equal(1.f, m1), "[0, 1, 0]"); 1520 | CheckNdArray(NotEqual(1.f, m1), "[1, 0, 1]"); 1521 | CheckNdArray(Greater(1.f, m1), "[1, 0, 0]"); 1522 | CheckNdArray(GreaterEqual(1.f, m1), "[1, 1, 0]"); 1523 | CheckNdArray(Less(1.f, m1), "[0, 0, 1]"); 1524 | CheckNdArray(LessEqual(1.f, m1), "[0, 1, 1]"); 1525 | } 1526 | 1527 | SECTION("Function Dot") { 1528 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 1529 | auto m2 = NdArray::Arange(3.f); 1530 | auto m12 = Dot(m1, m2); 1531 | auto m_a = Dot(m2, {2.f}); 1532 | auto m_b = Dot({2.f}, m2); 1533 | CheckNdArray(m12, "[5, 14]"); 1534 | CheckNdArray(m_a, "[0, 2, 4]"); 1535 | CheckNdArray(m_b, "[0, 2, 4]"); 1536 | } 1537 | 1538 | SECTION("Function Cross") { 1539 | NdArray m1 = {1.f, 2.f, 3.f}; 1540 | NdArray m2 = {4.f, 5.f, 6.f}; 1541 | auto m12 = Cross(m1, m2); 1542 | CheckNdArray(m12, "[-3, 6, -3]"); 1543 | } 1544 | 1545 | SECTION("Function Basic Math") { 1546 | auto m1 = NdArray::Arange(3.f); 1547 | auto m2 = NdArray::Arange(7.f) / 3.f - 1.f; 1548 | CheckNdArray(-m1, "[-0, -1, -2]"); 1549 | CheckNdArray(Abs(-m1), "[0, 1, 2]"); 1550 | CheckNdArray(Sign(m2), "[-1, -1, -1, 0, 1, 1, 1]"); 1551 | CheckNdArray(Ceil(m2), "[-1, -0, -0, 0, 1, 1, 1]"); 1552 | CheckNdArray(Floor(m2), "[-1, -1, -1, 0, 0, 0, 1]"); 1553 | CheckNdArray(Clip(m2, -0.5f, 0.4f), 1554 | "[-0.5, -0.5, -0.333333, 0, 0.333333, 0.4, 0.4]"); 1555 | CheckNdArray(Sqrt(m1), "[0, 1, 1.41421]"); 1556 | CheckNdArray(Exp(m1), "[1, 2.71828, 7.38906]"); 1557 | CheckNdArray(Log(m1), "[-inf, 0, 0.693147]"); 1558 | CheckNdArray(Square(m1), "[0, 1, 4]"); 1559 | CheckNdArray(Power(m1, m1 + 1.f), "[0, 1, 8]"); 1560 | CheckNdArray(Power(m1, 4.f), "[0, 1, 16]"); 1561 | CheckNdArray(Power(4.f, m1), "[1, 4, 16]"); 1562 | } 1563 | 1564 | SECTION("Function 
Trigonometric") { 1565 | auto m1 = NdArray::Arange(3.f); 1566 | CheckNdArray(Sin(m1), "[0, 0.841471, 0.909297]"); 1567 | CheckNdArray(Cos(m1), "[1, 0.540302, -0.416147]"); 1568 | CheckNdArray(Tan(m1), "[0, 1.55741, -2.18504]"); 1569 | } 1570 | 1571 | SECTION("Function Inverse-Trigonometric") { 1572 | auto m1 = NdArray::Arange(3) - 1.f; 1573 | auto m2 = NdArray::Arange(3) * 100.f; 1574 | CheckNdArray(ArcSin(m1), "[-1.5708, 0, 1.5708]"); 1575 | CheckNdArray(ArcCos(m1), "[3.14159, 1.5708, 0]"); 1576 | CheckNdArray(ArcTan(m2), "[0, 1.5608, 1.5658]"); 1577 | CheckNdArray(ArcTan2(m1, m2), "[-1.5708, 0, 0.00499996]"); 1578 | CheckNdArray(ArcTan2(m1, 2.f), "[-0.463648, 0, 0.463648]"); 1579 | CheckNdArray(ArcTan2(2.f, m1), "[2.03444, 1.5708, 1.10715]"); 1580 | } 1581 | 1582 | SECTION("Function Axis") { 1583 | auto m1 = NdArray::Arange(36.f).reshape(2, 3, 2, 3); 1584 | CheckNdArray(Sum(m1, {1, 3}), 1585 | "[[63, 90],\n" 1586 | " [225, 252]]"); 1587 | CheckNdArray(Sum(m1, {-1, -3}), // negative axis 1588 | "[[63, 90],\n" 1589 | " [225, 252]]"); 1590 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f; 1591 | CheckNdArray(Min(m2, {2, 1}), "[-6, 0]"); 1592 | CheckNdArray(Min(m2, {-2, -1}), "[-6, 0]"); // negative axis 1593 | CheckNdArray(Max(m2, {2, 1}), "[-1, 5]"); 1594 | CheckNdArray(Max(m2, {-2, -1}), "[-1, 5]"); // negative axis 1595 | CheckNdArray(Mean(m2, {2, 1}), "[-3.5, 2.5]"); 1596 | CheckNdArray(Mean(m2, {-2, -1}), "[-3.5, 2.5]"); // negative axis 1597 | } 1598 | 1599 | SECTION("Function Shape") { 1600 | auto m1 = NdArray::Arange(6.f).reshape(1, 2, 1, 3, 1); 1601 | CHECK(Reshape(m1, {2, 1, 3}).shape() == Shape{2, 1, 3}); 1602 | CHECK(Squeeze(m1).shape() == Shape{2, 3}); 1603 | CHECK(Squeeze(m1, {0, 2}).shape() == Shape{2, 3, 1}); 1604 | CHECK(Squeeze(m1, {-1, -3}).shape() == Shape{1, 2, 3}); 1605 | CHECK(ExpandDims(m1, 0).shape() == Shape{1, 1, 2, 1, 3, 1}); 1606 | CHECK(ExpandDims(m1, 2).shape() == Shape{1, 2, 1, 1, 3, 1}); 1607 | CHECK(ExpandDims(m1, 5).shape() 
== Shape{1, 2, 1, 3, 1, 1}); 1608 | CHECK(ExpandDims(m1, -1).shape() == Shape{1, 2, 1, 3, 1, 1}); 1609 | CHECK(ExpandDims(m1.ravel(), -1).shape() == Shape{6, 1}); 1610 | CHECK_THROWS(Squeeze(m1, {-2})); 1611 | CHECK_THROWS(Squeeze(m1, {5})); 1612 | CHECK_THROWS(ExpandDims(m1, 6)); 1613 | } 1614 | 1615 | SECTION("Function Stack") { 1616 | auto m1 = NdArray::Arange(12.f).reshape(4, 3); 1617 | auto m2 = NdArray::Arange(12.f).reshape(4, 3) + 1.f; 1618 | CheckNdArray(Stack({m1, m2}, 0), 1619 | "[[[0, 1, 2],\n" 1620 | " [3, 4, 5],\n" 1621 | " [6, 7, 8],\n" 1622 | " [9, 10, 11]],\n" 1623 | " [[1, 2, 3],\n" 1624 | " [4, 5, 6],\n" 1625 | " [7, 8, 9],\n" 1626 | " [10, 11, 12]]]"); 1627 | CheckNdArray(Stack({m1, m2}, 1), 1628 | "[[[0, 1, 2],\n" 1629 | " [1, 2, 3]],\n" 1630 | " [[3, 4, 5],\n" 1631 | " [4, 5, 6]],\n" 1632 | " [[6, 7, 8],\n" 1633 | " [7, 8, 9]],\n" 1634 | " [[9, 10, 11],\n" 1635 | " [10, 11, 12]]]"); 1636 | CheckNdArray(Stack({m1, m2}, 2), 1637 | "[[[0, 1],\n" 1638 | " [1, 2],\n" 1639 | " [2, 3]],\n" 1640 | " [[3, 4],\n" 1641 | " [4, 5],\n" 1642 | " [5, 6]],\n" 1643 | " [[6, 7],\n" 1644 | " [7, 8],\n" 1645 | " [8, 9]],\n" 1646 | " [[9, 10],\n" 1647 | " [10, 11],\n" 1648 | " [11, 12]]]"); 1649 | CheckNdArray(Stack({m1, m2}, -3), // negative axis 1650 | "[[[0, 1, 2],\n" 1651 | " [3, 4, 5],\n" 1652 | " [6, 7, 8],\n" 1653 | " [9, 10, 11]],\n" 1654 | " [[1, 2, 3],\n" 1655 | " [4, 5, 6],\n" 1656 | " [7, 8, 9],\n" 1657 | " [10, 11, 12]]]"); 1658 | CHECK_THROWS(Stack({m1, m2}, 3)); 1659 | CHECK_THROWS(Stack({m1, m2}, -4)); // negative axis 1660 | } 1661 | 1662 | SECTION("Function Concatenate") { 1663 | auto m1 = NdArray::Arange(12.f).reshape(4, 3); 1664 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) + 1.f; 1665 | CheckNdArray(Concatenate({m1, m2}, 0), 1666 | "[[0, 1, 2],\n" 1667 | " [3, 4, 5],\n" 1668 | " [6, 7, 8],\n" 1669 | " [9, 10, 11],\n" 1670 | " [1, 2, 3],\n" 1671 | " [4, 5, 6]]"); 1672 | CHECK_THROWS(Concatenate({m1, m2}, 1)); 1673 | 
CHECK_THROWS(Concatenate({m1, m2}, 2)); 1674 | CHECK_THROWS(Concatenate({m1, m2}, -1)); 1675 | CHECK_THROWS(Concatenate({m1, m2}, -3)); 1676 | auto m3 = NdArray::Arange(8.f).reshape(4, 2) + 1.f; 1677 | CheckNdArray(Concatenate({m1, m3}, 1), 1678 | "[[0, 1, 2, 1, 2],\n" 1679 | " [3, 4, 5, 3, 4],\n" 1680 | " [6, 7, 8, 5, 6],\n" 1681 | " [9, 10, 11, 7, 8]]"); 1682 | CHECK_THROWS(Concatenate({m1, m3}, 0)); 1683 | CHECK_THROWS(Concatenate({m1, m3}, 2)); 1684 | CHECK_THROWS(Concatenate({m1, m3}, -2)); 1685 | CHECK_THROWS(Concatenate({m1, m3}, -3)); 1686 | 1687 | auto m4 = NdArray::Arange(12.f).reshape(4, 1, 3); 1688 | auto m5 = NdArray::Arange(6.f).reshape(2, 1, 3) + 1.f; 1689 | CheckNdArray(Concatenate({m4, m5}, 0), 1690 | "[[[0, 1, 2]],\n" 1691 | " [[3, 4, 5]],\n" 1692 | " [[6, 7, 8]],\n" 1693 | " [[9, 10, 11]],\n" 1694 | " [[1, 2, 3]],\n" 1695 | " [[4, 5, 6]]]"); 1696 | CHECK_THROWS(Concatenate({m4, m5}, 1)); 1697 | CHECK_THROWS(Concatenate({m4, m5}, 2)); 1698 | CHECK_THROWS(Concatenate({m4, m5}, -1)); 1699 | CHECK_THROWS(Concatenate({m4, m5}, -2)); 1700 | CHECK_THROWS(Concatenate({m4, m5}, -4)); 1701 | auto m6 = NdArray::Arange(8.f).reshape(4, 1, 2) + 1.f; 1702 | CheckNdArray(Concatenate({m4, m6}, 2), 1703 | "[[[0, 1, 2, 1, 2]],\n" 1704 | " [[3, 4, 5, 3, 4]],\n" 1705 | " [[6, 7, 8, 5, 6]],\n" 1706 | " [[9, 10, 11, 7, 8]]]"); 1707 | CHECK_THROWS(Concatenate({m4, m6}, 0)); 1708 | CHECK_THROWS(Concatenate({m4, m6}, 1)); 1709 | CHECK_THROWS(Concatenate({m4, m6}, -2)); 1710 | CHECK_THROWS(Concatenate({m4, m6}, -3)); 1711 | CHECK_THROWS(Concatenate({m4, m6}, -4)); 1712 | } 1713 | 1714 | SECTION("Function Split by indices") { 1715 | auto m1 = NdArray::Arange(16.f).reshape(2, 4, 2); 1716 | 1717 | auto r0 = Split(m1, {1, 1}, 1); 1718 | CHECK(r0.size() == 3); 1719 | CheckNdArray(r0[0], 1720 | "[[[0, 1]],\n" 1721 | " [[8, 9]]]"); 1722 | CheckNdArray(r0[1], "[]"); 1723 | CheckNdArray(r0[2], 1724 | "[[[2, 3],\n" 1725 | " [4, 5],\n" 1726 | " [6, 7]],\n" 1727 | " [[10, 11],\n" 1728 
| " [12, 13],\n" 1729 | " [14, 15]]]"); 1730 | 1731 | auto r1 = Split(m1, {2, 0}, 1); 1732 | CHECK(r1.size() == 3); 1733 | CheckNdArray(r1[0], 1734 | "[[[0, 1],\n" 1735 | " [2, 3]],\n" 1736 | " [[8, 9],\n" 1737 | " [10, 11]]]"); 1738 | CheckNdArray(r1[1], "[]"); 1739 | CheckNdArray(r1[2], 1740 | "[[[0, 1],\n" 1741 | " [2, 3],\n" 1742 | " [4, 5],\n" 1743 | " [6, 7]],\n" 1744 | " [[8, 9],\n" 1745 | " [10, 11],\n" 1746 | " [12, 13],\n" 1747 | " [14, 15]]]"); 1748 | 1749 | auto r2 = Split(m1, {0, 2, 3}, 1); 1750 | CHECK(r2.size() == 4); 1751 | CheckNdArray(r2[0], "[]"); 1752 | CheckNdArray(r2[1], 1753 | "[[[0, 1],\n" 1754 | " [2, 3]],\n" 1755 | " [[8, 9],\n" 1756 | " [10, 11]]]"); 1757 | CheckNdArray(r2[2], 1758 | "[[[4, 5]],\n" 1759 | " [[12, 13]]]"); 1760 | CheckNdArray(r2[3], 1761 | "[[[6, 7]],\n" 1762 | " [[14, 15]]]"); 1763 | 1764 | auto r3 = Split(m1, {2, 4}, 1); 1765 | CHECK(r3.size() == 3); 1766 | CheckNdArray(r3[0], 1767 | "[[[0, 1],\n" 1768 | " [2, 3]],\n" 1769 | " [[8, 9],\n" 1770 | " [10, 11]]]"); 1771 | CheckNdArray(r3[1], 1772 | "[[[4, 5],\n" 1773 | " [6, 7]],\n" 1774 | " [[12, 13],\n" 1775 | " [14, 15]]]"); 1776 | CheckNdArray(r3[2], "[]"); 1777 | 1778 | auto r4 = Split(m1, {2, 4}, -2); // negative axis 1779 | CHECK(r4.size() == 3); 1780 | CheckNdArray(r4[0], 1781 | "[[[0, 1],\n" 1782 | " [2, 3]],\n" 1783 | " [[8, 9],\n" 1784 | " [10, 11]]]"); 1785 | CheckNdArray(r4[1], 1786 | "[[[4, 5],\n" 1787 | " [6, 7]],\n" 1788 | " [[12, 13],\n" 1789 | " [14, 15]]]"); 1790 | CheckNdArray(r4[2], "[]"); 1791 | } 1792 | 1793 | SECTION("Function Split by n_section") { 1794 | auto m1 = NdArray::Arange(16.f).reshape(2, 4, 2); 1795 | auto r0 = Split(m1, 2, 1); 1796 | CHECK(r0.size() == 2); 1797 | CheckNdArray(r0[0], 1798 | "[[[0, 1],\n" 1799 | " [2, 3]],\n" 1800 | " [[8, 9],\n" 1801 | " [10, 11]]]"); 1802 | CheckNdArray(r0[1], 1803 | "[[[4, 5],\n" 1804 | " [6, 7]],\n" 1805 | " [[12, 13],\n" 1806 | " [14, 15]]]"); 1807 | 1808 | auto r1 = Split(m1, 4, 1); 1809 | 
CHECK(r1.size() == 4); 1810 | CheckNdArray(r1[0], 1811 | "[[[0, 1]],\n" 1812 | " [[8, 9]]]"); 1813 | CheckNdArray(r1[1], 1814 | "[[[2, 3]],\n" 1815 | " [[10, 11]]]"); 1816 | CheckNdArray(r1[2], 1817 | "[[[4, 5]],\n" 1818 | " [[12, 13]]]"); 1819 | CheckNdArray(r1[3], 1820 | "[[[6, 7]],\n" 1821 | " [[14, 15]]]"); 1822 | 1823 | auto r2 = Split(m1, 2, 2); 1824 | CHECK(r2.size() == 2); 1825 | CheckNdArray(r2[0], 1826 | "[[[0],\n" 1827 | " [2],\n" 1828 | " [4],\n" 1829 | " [6]],\n" 1830 | " [[8],\n" 1831 | " [10],\n" 1832 | " [12],\n" 1833 | " [14]]]"); 1834 | CheckNdArray(r2[1], 1835 | "[[[1],\n" 1836 | " [3],\n" 1837 | " [5],\n" 1838 | " [7]],\n" 1839 | " [[9],\n" 1840 | " [11],\n" 1841 | " [13],\n" 1842 | " [15]]]"); 1843 | 1844 | auto r3 = Split(m1, 4, -2); 1845 | CHECK(r3.size() == 4); 1846 | CheckNdArray(r3[0], 1847 | "[[[0, 1]],\n" 1848 | " [[8, 9]]]"); 1849 | CheckNdArray(r3[1], 1850 | "[[[2, 3]],\n" 1851 | " [[10, 11]]]"); 1852 | CheckNdArray(r3[2], 1853 | "[[[4, 5]],\n" 1854 | " [[12, 13]]]"); 1855 | CheckNdArray(r3[3], 1856 | "[[[6, 7]],\n" 1857 | " [[14, 15]]]"); 1858 | } 1859 | 1860 | SECTION("Function Separate") { 1861 | auto m = NdArray::Arange(16.f).reshape(2, 4, 2); 1862 | // Axis 0 1863 | auto r0 = Separate(m, 0); 1864 | CHECK(r0.size() == 2); 1865 | auto m0 = Stack(r0, 0); 1866 | CHECK(All(m == m0)); 1867 | // Axis 1 1868 | auto r1 = Separate(m, 1); 1869 | CHECK(r1.size() == 4); 1870 | auto m1 = Stack(r1, 1); 1871 | CHECK(All(m == m1)); 1872 | // Axis 2 1873 | auto r2 = Separate(m, 2); 1874 | CHECK(r2.size() == 2); 1875 | auto m2 = Stack(r2, 2); 1876 | CHECK(All(m == m2)); 1877 | // Negative axis 1878 | auto r3 = Separate(m, -1); 1879 | CHECK(r3.size() == 2); 1880 | auto m3 = Stack(r3, -1); 1881 | CHECK(All(m == m3)); 1882 | } 1883 | 1884 | SECTION("Function Transpose") { 1885 | auto m1 = NdArray::Arange(6.f).reshape(3, 2); 1886 | CheckNdArray(Transpose(m1), 1887 | "[[0, 2, 4],\n" 1888 | " [1, 3, 5]]"); 1889 | auto m2 = NdArray::Arange(8).reshape(2, 2, 
2); 1890 | CheckNdArray(Transpose(m2), 1891 | "[[[0, 4],\n" 1892 | " [2, 6]],\n" 1893 | " [[1, 5],\n" 1894 | " [3, 7]]]"); 1895 | } 1896 | 1897 | SECTION("Function Swapaxes") { 1898 | auto m1 = NdArray::Arange(6.f).reshape(3, 2); 1899 | CheckNdArray(Swapaxes(m1, -1, -2), 1900 | "[[0, 2, 4],\n" 1901 | " [1, 3, 5]]"); 1902 | auto m2 = NdArray::Arange(8).reshape(2, 2, 2); 1903 | CheckNdArray(Swapaxes(m2, 0, 2), 1904 | "[[[0, 4],\n" 1905 | " [2, 6]],\n" 1906 | " [[1, 5],\n" 1907 | " [3, 7]]]"); 1908 | } 1909 | 1910 | SECTION("Function BroadcastTo") { 1911 | CheckNdArray(BroadcastTo(NdArray::Arange(3.f), {2, 3}), 1912 | "[[0, 1, 2],\n" 1913 | " [0, 1, 2]]"); 1914 | CheckNdArray(BroadcastTo(NdArray::Arange(3.f), {2, 2, 3}), 1915 | "[[[0, 1, 2],\n" 1916 | " [0, 1, 2]],\n" 1917 | " [[0, 1, 2],\n" 1918 | " [0, 1, 2]]]"); 1919 | } 1920 | 1921 | SECTION("Function Inverse (2d)") { 1922 | auto m1 = NdArray::Arange(4).reshape(2, 2) + 1.f; 1923 | auto m2 = Inv(m1); 1924 | CheckNdArray(m2, 1925 | "[[-2, 1],\n" 1926 | " [1.5, -0.5]]", 1927 | 4); // Low precision 1928 | CheckNdArray(m1.dot(m2), 1929 | "[[1, 0],\n" 1930 | " [0, 1]]", 1931 | 4); // Low precision 1932 | } 1933 | 1934 | SECTION("Function Inverse (high-dim)") { 1935 | auto m1 = NdArray::Arange(24).reshape(2, 3, 2, 2) + 1.f; 1936 | auto m2 = Inv(m1); 1937 | CheckNdArray(m2, 1938 | "[[[[-2, 1],\n" 1939 | " [1.5, -0.5]],\n" 1940 | " [[-4, 3],\n" 1941 | " [3.5, -2.5]],\n" 1942 | " [[-6, 5],\n" 1943 | " [5.5, -4.5]]],\n" 1944 | " [[[-8, 7],\n" 1945 | " [7.5, -6.5]],\n" 1946 | " [[-10, 9],\n" 1947 | " [9.5, -8.5]],\n" 1948 | " [[-12, 11],\n" 1949 | " [11.5, -10.5]]]]", 1950 | 4); // Low precision 1951 | } 1952 | 1953 | // ----------------------- In-place Operator function ---------------------- 1954 | SECTION("Function in-place basic") { 1955 | // In-place 1956 | NdArray m1 = NdArray::Arange(3); 1957 | NdArray m2 = NdArray::Arange(3); 1958 | uintptr_t m1_id = m1.id(); 1959 | uintptr_t m2_id = m2.id(); 1960 | NdArray m3 = 
Power(std::move(m1), std::move(m2)); 1961 | CHECK((m3.id() == m1_id || m3.id() == m2_id)); // m3 is not new array 1962 | // Not in-place 1963 | CheckNdArrayNotInplace(-NdArray::Arange(3.f), "[0, 1, 2]", 1964 | static_cast(Abs)); 1965 | // No matching broadcast 1966 | NdArray m4 = NdArray::Arange(6).reshape(2, 1, 3); 1967 | NdArray m5 = NdArray::Arange(2).reshape(2, 1); 1968 | uintptr_t m4_id = m4.id(); 1969 | uintptr_t m5_id = m5.id(); 1970 | NdArray m6 = Power(std::move(m4), std::move(m5)); 1971 | CHECK((m6.id() != m4_id && m6.id() != m5_id)); // m6 is new array 1972 | } 1973 | 1974 | SECTION("Single (inplace)") { 1975 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1, 2]", 1976 | static_cast(Positive)); 1977 | CheckNdArrayInplace(NdArray::Arange(3.f), "[-0, -1, -2]", 1978 | static_cast(Negative)); 1979 | } 1980 | 1981 | SECTION("Function Arithmetic (NdArray, NdArray) (in-place both)") { 1982 | CheckNdArrayInplace( 1983 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f), 1984 | "[[0, 2, 4],\n" 1985 | " [3, 5, 7]]", 1986 | static_cast(Add)); 1987 | CheckNdArrayInplace( 1988 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f), 1989 | "[[0, 0, 0],\n" 1990 | " [3, 3, 3]]", 1991 | static_cast(Subtract)); 1992 | CheckNdArrayInplace( 1993 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f), 1994 | "[[0, 1, 4],\n" 1995 | " [0, 4, 10]]", 1996 | static_cast(Multiply)); 1997 | CheckNdArrayInplace( 1998 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f) + 1.f, 1999 | "[[0, 0.5, 0.666667],\n" 2000 | " [3, 2, 1.66667]]", 2001 | static_cast(Divide)); 2002 | } 2003 | 2004 | SECTION("Function Arithmetic (NdArray, NdArray) (in-place right)") { 2005 | auto m1 = NdArray::Arange(3.f); 2006 | auto m2 = NdArray::Arange(3.f) + 1.f; 2007 | CheckNdArrayInplace( 2008 | m1, NdArray::Arange(6.f).reshape(2, 3), 2009 | "[[0, 2, 4],\n" 2010 | " [3, 5, 7]]", 2011 | static_cast(Add)); 2012 | CheckNdArrayInplace( 2013 | m1, NdArray::Arange(6.f).reshape(2, 3), 2014 | 
"[[0, 0, 0],\n" 2015 | " [-3, -3, -3]]", 2016 | static_cast(Subtract)); 2017 | CheckNdArrayInplace( 2018 | m1, NdArray::Arange(6.f).reshape(2, 3), 2019 | "[[0, 1, 4],\n" 2020 | " [0, 4, 10]]", 2021 | static_cast(Multiply)); 2022 | CheckNdArrayInplace( 2023 | m2, NdArray::Arange(6.f).reshape(2, 3), 2024 | "[[inf, 2, 1.5],\n" 2025 | " [0.333333, 0.5, 0.6]]", 2026 | static_cast(Divide)); 2027 | } 2028 | 2029 | SECTION("Function Arithmetic (NdArray, NdArray) (in-place left)") { 2030 | auto m2 = NdArray::Arange(3.f); 2031 | auto m3 = NdArray::Arange(3.f) + 1.f; 2032 | CheckNdArrayInplace( 2033 | NdArray::Arange(6.f).reshape(2, 3), m2, 2034 | "[[0, 2, 4],\n" 2035 | " [3, 5, 7]]", 2036 | static_cast(Add)); 2037 | CheckNdArrayInplace( 2038 | NdArray::Arange(6.f).reshape(2, 3), m2, 2039 | "[[0, 0, 0],\n" 2040 | " [3, 3, 3]]", 2041 | static_cast(Subtract)); 2042 | CheckNdArrayInplace( 2043 | NdArray::Arange(6.f).reshape(2, 3), m2, 2044 | "[[0, 1, 4],\n" 2045 | " [0, 4, 10]]", 2046 | static_cast(Multiply)); 2047 | CheckNdArrayInplace( 2048 | NdArray::Arange(6.f).reshape(2, 3), m3, 2049 | "[[0, 0.5, 0.666667],\n" 2050 | " [3, 2, 1.66667]]", 2051 | static_cast(Divide)); 2052 | } 2053 | 2054 | SECTION("Function Arithmetic (NdArray, float) (inplace)") { 2055 | CheckNdArrayInplace(NdArray::Arange(3.f), 2.f, "[2, 3, 4]", 2056 | static_cast(Add)); 2057 | CheckNdArrayInplace( 2058 | NdArray::Arange(3.f), 2.f, "[-2, -1, 0]", 2059 | static_cast(Subtract)); 2060 | CheckNdArrayInplace( 2061 | NdArray::Arange(3.f), 2.f, "[0, 2, 4]", 2062 | static_cast(Multiply)); 2063 | CheckNdArrayInplace(NdArray::Arange(3.f), 2.f, "[0, 0.5, 1]", 2064 | static_cast(Divide)); 2065 | } 2066 | 2067 | SECTION("Function Arithmetic (float, NdArray) (inplace)") { 2068 | CheckNdArrayInplace(2.f, NdArray::Arange(3.f), "[2, 3, 4]", 2069 | static_cast(Add)); 2070 | CheckNdArrayInplace( 2071 | 2.f, NdArray::Arange(3.f), "[2, 1, 0]", 2072 | static_cast(Subtract)); 2073 | CheckNdArrayInplace( 2074 | 2.f, 
NdArray::Arange(3.f), "[0, 2, 4]", 2075 | static_cast(Multiply)); 2076 | CheckNdArrayInplace(2.f, NdArray::Arange(3.f), "[inf, 2, 1]", 2077 | static_cast(Divide)); 2078 | } 2079 | 2080 | SECTION("Function Comparison (NdArray, NdArray) (inplace both)") { 2081 | CheckNdArrayInplace( 2082 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 0]", 2083 | static_cast(Equal)); 2084 | CheckNdArrayInplace( 2085 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 1]", 2086 | static_cast(NotEqual)); 2087 | CheckNdArrayInplace( 2088 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 0, 1]", 2089 | static_cast(Greater)); 2090 | CheckNdArrayInplace( 2091 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 1]", 2092 | static_cast(GreaterEqual)); 2093 | CheckNdArrayInplace( 2094 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 0]", 2095 | static_cast(Less)); 2096 | CheckNdArrayInplace( 2097 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 1, 0]", 2098 | static_cast(LessEqual)); 2099 | } 2100 | 2101 | SECTION("Function Comparison (NdArray, NdArray) (inplace right)") { 2102 | auto m2 = NdArray::Zeros(1) + 1.f; 2103 | CheckNdArrayInplace( 2104 | NdArray::Arange(3.f), m2, "[0, 1, 0]", 2105 | static_cast(Equal)); 2106 | CheckNdArrayInplace( 2107 | NdArray::Arange(3.f), m2, "[1, 0, 1]", 2108 | static_cast(NotEqual)); 2109 | CheckNdArrayInplace( 2110 | NdArray::Arange(3.f), m2, "[0, 0, 1]", 2111 | static_cast(Greater)); 2112 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[0, 1, 1]", 2113 | static_cast( 2114 | GreaterEqual)); 2115 | CheckNdArrayInplace( 2116 | NdArray::Arange(3.f), m2, "[1, 0, 0]", 2117 | static_cast(Less)); 2118 | CheckNdArrayInplace( 2119 | NdArray::Arange(3.f), m2, "[1, 1, 0]", 2120 | static_cast(LessEqual)); 2121 | } 2122 | 2123 | SECTION("Function Comparison (NdArray, NdArray) (inplace left)") { 2124 | auto m1 = NdArray::Zeros(1) + 1.f; 2125 | CheckNdArrayInplace( 2126 | m1, NdArray::Arange(3.f), "[0, 1, 0]", 2127 | 
static_cast(Equal)); 2128 | CheckNdArrayInplace( 2129 | m1, NdArray::Arange(3.f), "[1, 0, 1]", 2130 | static_cast(NotEqual)); 2131 | CheckNdArrayInplace( 2132 | m1, NdArray::Arange(3.f), "[1, 0, 0]", 2133 | static_cast(Greater)); 2134 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[1, 1, 0]", 2135 | static_cast( 2136 | GreaterEqual)); 2137 | CheckNdArrayInplace( 2138 | m1, NdArray::Arange(3.f), "[0, 0, 1]", 2139 | static_cast(Less)); 2140 | CheckNdArrayInplace( 2141 | m1, NdArray::Arange(3.f), "[0, 1, 1]", 2142 | static_cast(LessEqual)); 2143 | } 2144 | 2145 | SECTION("Function Comparison (NdArray, float) (inplace)") { 2146 | CheckNdArrayInplace(NdArray::Arange(3.f), 1.f, "[0, 1, 0]", 2147 | static_cast(Equal)); 2148 | CheckNdArrayInplace( 2149 | NdArray::Arange(3.f), 1.f, "[1, 0, 1]", 2150 | static_cast(NotEqual)); 2151 | CheckNdArrayInplace( 2152 | NdArray::Arange(3.f), 1.f, "[0, 0, 1]", 2153 | static_cast(Greater)); 2154 | CheckNdArrayInplace( 2155 | NdArray::Arange(3.f), 1.f, "[0, 1, 1]", 2156 | static_cast(GreaterEqual)); 2157 | CheckNdArrayInplace(NdArray::Arange(3.f), 1.f, "[1, 0, 0]", 2158 | static_cast(Less)); 2159 | CheckNdArrayInplace( 2160 | NdArray::Arange(3.f), 1.f, "[1, 1, 0]", 2161 | static_cast(LessEqual)); 2162 | } 2163 | 2164 | SECTION("Function Comparison (NdArray, float) (inplace)") { 2165 | CheckNdArrayInplace(1.f, NdArray::Arange(3.f), "[0, 1, 0]", 2166 | static_cast(Equal)); 2167 | CheckNdArrayInplace( 2168 | 1.f, NdArray::Arange(3.f), "[1, 0, 1]", 2169 | static_cast(NotEqual)); 2170 | CheckNdArrayInplace( 2171 | 1.f, NdArray::Arange(3.f), "[1, 0, 0]", 2172 | static_cast(Greater)); 2173 | CheckNdArrayInplace( 2174 | 1.f, NdArray::Arange(3.f), "[1, 1, 0]", 2175 | static_cast(GreaterEqual)); 2176 | CheckNdArrayInplace(1.f, NdArray::Arange(3.f), "[0, 0, 1]", 2177 | static_cast(Less)); 2178 | CheckNdArrayInplace( 2179 | 1.f, NdArray::Arange(3.f), "[0, 1, 1]", 2180 | static_cast(LessEqual)); 2181 | } 2182 | 2183 | SECTION("Function Basic Math 
(in-place)") { 2184 | CheckNdArrayInplace(-NdArray::Arange(3.f), "[0, 1, 2]", 2185 | static_cast(Abs)); 2186 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f, 2187 | "[-1, -1, -1, 0, 1, 1, 1]", 2188 | static_cast(Sign)); 2189 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f, 2190 | "[-1, -0, -0, 0, 1, 1, 1]", 2191 | static_cast(Ceil)); 2192 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f, 2193 | "[-1, -1, -1, 0, 0, 0, 1]", 2194 | static_cast(Floor)); 2195 | auto clip_bind = std::bind( 2196 | static_cast(Clip), 2197 | std::placeholders::_1, -0.5f, 0.4f); 2198 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f, 2199 | "[-0.5, -0.5, -0.333333, 0, 0.333333, 0.4, 0.4]", 2200 | clip_bind); 2201 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1, 1.41421]", 2202 | static_cast(Sqrt)); 2203 | CheckNdArrayInplace(NdArray::Arange(3.f), "[1, 2.71828, 7.38906]", 2204 | static_cast(Exp)); 2205 | CheckNdArrayInplace(NdArray::Arange(3.f), "[-inf, 0, 0.693147]", 2206 | static_cast(Log)); 2207 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1, 4]", 2208 | static_cast(Square)); 2209 | CheckNdArrayInplace( 2210 | NdArray::Arange(3.f), NdArray::Arange(3.f) + 1.f, "[0, 1, 8]", 2211 | static_cast(Power)); 2212 | const auto m1 = NdArray::Arange(3.f) + 1.f; 2213 | const auto m2 = NdArray::Arange(3.f); 2214 | CheckNdArrayInplace( 2215 | m2, NdArray::Arange(3.f) + 1.f, "[0, 1, 8]", 2216 | static_cast(Power)); 2217 | CheckNdArrayInplace( 2218 | NdArray::Arange(3.f), m1, "[0, 1, 8]", 2219 | static_cast(Power)); 2220 | CheckNdArrayInplace(NdArray::Arange(3.f), 4.f, "[0, 1, 16]", 2221 | static_cast(Power)); 2222 | CheckNdArrayInplace(4.f, NdArray::Arange(3.f), "[1, 4, 16]", 2223 | static_cast(Power)); 2224 | } 2225 | 2226 | SECTION("Function Trigonometric (in-place)") { 2227 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 0.841471, 0.909297]", 2228 | static_cast(Sin)); 2229 | CheckNdArrayInplace(NdArray::Arange(3.f), "[1, 0.540302, -0.416147]", 2230 | static_cast(Cos)); 
2231 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1.55741, -2.18504]", 2232 | static_cast(Tan)); 2233 | } 2234 | 2235 | SECTION("Function Inverse-Trigonometric (in-place)") { 2236 | auto m1 = NdArray::Arange(3) - 1.f; 2237 | auto m2 = NdArray::Arange(3) * 100.f; 2238 | CheckNdArrayInplace(NdArray::Arange(3.f) - 1.f, "[-1.5708, 0, 1.5708]", 2239 | static_cast(ArcSin)); 2240 | CheckNdArrayInplace(NdArray::Arange(3.f) - 1.f, "[3.14159, 1.5708, 0]", 2241 | static_cast(ArcCos)); 2242 | CheckNdArrayInplace(NdArray::Arange(3.f) * 100.f, "[0, 1.5608, 1.5658]", 2243 | static_cast(ArcTan)); 2244 | CheckNdArrayInplace( 2245 | NdArray::Arange(3.f) - 1.f, NdArray::Arange(3.f) * 100.f, 2246 | "[-1.5708, 0, 0.00499996]", 2247 | static_cast(ArcTan2)); 2248 | CheckNdArrayInplace( 2249 | m1, NdArray::Arange(3.f) * 100.f, "[-1.5708, 0, 0.00499996]", 2250 | static_cast(ArcTan2)); 2251 | CheckNdArrayInplace( 2252 | NdArray::Arange(3.f) - 1.f, m2, "[-1.5708, 0, 0.00499996]", 2253 | static_cast(ArcTan2)); 2254 | CheckNdArrayInplace( 2255 | NdArray::Arange(3.f) - 1.f, 2.f, "[-0.463648, 0, 0.463648]", 2256 | static_cast(ArcTan2)); 2257 | CheckNdArrayInplace( 2258 | 2.f, NdArray::Arange(3.f) - 1.f, "[2.03444, 1.5708, 1.10715]", 2259 | static_cast(ArcTan2)); 2260 | } 2261 | 2262 | SECTION("Where (in-place)") { 2263 | auto m1 = NdArray::Arange(6.f).reshape(2, 3); 2264 | CheckNdArray(Where(2.f < m1, NdArray::Ones(2, 1), NdArray::Arange(3)), 2265 | "[[0, 1, 2],\n" 2266 | " [1, 1, 1]]"); 2267 | CheckNdArray(Where(2.f < m1, 1.f, NdArray::Arange(3)), 2268 | "[[0, 1, 2],\n" 2269 | " [1, 1, 1]]"); 2270 | CheckNdArray(Where(2.f < m1, NdArray::Arange(3), 0.f), 2271 | "[[0, 0, 0],\n" 2272 | " [0, 1, 2]]"); 2273 | CheckNdArray(Where(2.f < m1, 1.f, 0.f), 2274 | "[[0, 0, 0],\n" 2275 | " [1, 1, 1]]"); 2276 | } 2277 | 2278 | SECTION("Function Inverse (in-place)") { 2279 | auto m1 = NdArray::Arange(4).reshape(2, 2) + 1.f; 2280 | auto m2 = Inv(m1); 2281 | CheckNdArrayInplace(NdArray::Arange(4).reshape(2, 
2) + 1.f, 2282 | "[[-2, 1],\n" 2283 | " [1.5, -0.5]]", 2284 | static_cast(Inv), 2285 | 4); // Low precision 2286 | } 2287 | } 2288 | -------------------------------------------------------------------------------- /tests/test_dot_major.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | /* NOTE(review): the header names of the #include directives on lines 1-15 and 17 are missing in this dump (apparently stripped along with all other text inside angle brackets); restore them from the repository before compiling. Because TINYNDARRAY_NO_INCLUDE is defined below, tinyndarray.h presumably relies on these headers being included here, outside the namespaces -- confirm. */ 17 | #include 18 | 19 | #include "timer.h" // g_timer 20 | 21 | #define TINYNDARRAY_IMPLEMENTATION 22 | #define TINYNDARRAY_NO_INCLUDE 23 | /* The library is expanded three times into separate namespaces (tnd_auto, tnd_col, tnd_row). TINYNDARRAY_H_ONCE is undefined before each #include so the header is re-expanded, and the TINYNDARRAY_FORCE_DOT_* macro defined around each copy presumably forces that copy's dot-product ordering (none = automatic choice) -- confirm against tinyndarray.h. */ 24 | // ----------------------------------------------------------------------------- 25 | // Automatic 26 | namespace tnd_auto { 27 | #undef TINYNDARRAY_H_ONCE 28 | #include "../tinyndarray.h" 29 | } // namespace tnd_auto 30 | 31 | // ----------------------------------------------------------------------------- 32 | // Col-major 33 | namespace tnd_col { 34 | #undef TINYNDARRAY_H_ONCE 35 | #define TINYNDARRAY_FORCE_DOT_COLMAJOR 36 | #include "../tinyndarray.h" 37 | #undef TINYNDARRAY_FORCE_DOT_COLMAJOR 38 | } // namespace tnd_col 39 | 40 | // ----------------------------------------------------------------------------- 41 | // Row major /* NOTE(review): the macro below is spelled RAWMAJOR, not ROWMAJOR; it presumably has to match the exact spelling tinyndarray.h tests for -- confirm. */ 42 | namespace tnd_row { 43 | #undef TINYNDARRAY_H_ONCE 44 | #define TINYNDARRAY_FORCE_DOT_RAWMAJOR 45 | #include "../tinyndarray.h" 46 | #undef TINYNDARRAY_FORCE_DOT_RAWMAJOR 47 | } // namespace tnd_row 48 | // ----------------------------------------------------------------------------- 49 | 50 | using NdArrayAuto = tnd_auto::tinyndarray::NdArray; 51 | using NdArrayCol = tnd_col::tinyndarray::NdArray; 52 | using NdArrayRow = tnd_row::tinyndarray::NdArray; 53 | using namespace tnd_auto::tinyndarray; 54 | 55 | // ----------------------------------------------------------------------------- 56 | /* TestDotTimeOne: times `T::Ones(l_shape).dot(T::Ones(r_shape))` repeatedly with g_timer until the accumulated elapsed time (from getElapsedMsec, presumably milliseconds) reaches 500, then returns the mean time per run. The timed region also covers constructing both Ones operands, not only the dot product. Any exception (for example incompatible shapes -- assumption) is caught, and the numeric_limits maximum is returned as a sentinel, which TestDotTime then prints as a huge time. NOTE(review): the template parameter list here and the numeric_limits argument below appear stripped in this dump; presumably `template <typename T>` with T one of NdArrayAuto / NdArrayCol / NdArrayRow, and numeric_limits of float -- confirm against the repository. */ 57 | template 58 | float TestDotTimeOne(const Shape& l_shape, const Shape& r_shape)
{ 59 | try { 60 | float elapsed = 0.f; 61 | float count = 0.f; 62 | while (elapsed < 500.f) { 63 | // Run 64 | g_timer.start(); 65 | T::Ones(l_shape).dot(T::Ones(r_shape)); 66 | g_timer.end(); 67 | elapsed += g_timer.getElapsedMsec(); 68 | count += 1.f; 69 | } 70 | return elapsed / count; 71 | } catch (...) { 72 | return std::numeric_limits::max(); 73 | } 74 | } 75 | /* SetFmt: sets the stream's floating-point precision to 3 and returns the same stream so it can be chained with further output. The precision setting persists on the stream (std::cout at every call site here) after the call. */ 76 | static std::ostream& SetFmt(std::ostream& os) { 77 | os << std::setprecision(3); 78 | return os; 79 | } 80 | /* TestDotTime: prints "<l_shape> @ <r_shape>: " padded with spaces to 30 characters, then the mean dot time of the automatic (a), column-major (c) and row-major (r) builds. It guesses which path the automatic build used by checking whether its time is closer to the col or the row time. It reports which path was actually faster and appends "-> NG" when the two disagree; the "-> OK" output is commented out. NOTE(review): the template arguments of the three TestDotTimeOne calls and of static_cast below appear stripped in this dump; from the a/c/r labels they are presumably NdArrayAuto, NdArrayCol and NdArrayRow (and int for the static_cast) -- confirm. */ 81 | static void TestDotTime(const Shape& l_shape, const Shape& r_shape) { 82 | // Print header 83 | std::stringstream ss; 84 | ss << l_shape << " @ " << r_shape << ": "; 85 | const std::string& head_str = ss.str(); 86 | std::cout << head_str; 87 | const int lack = 30 - static_cast(head_str.size()); 88 | for (int i = 0; i < lack; i++) { 89 | std::cout << " "; 90 | } 91 | 92 | // Measure 93 | const float time_auto = TestDotTimeOne(l_shape, r_shape); 94 | SetFmt(std::cout) << "a:" << time_auto << "ms, "; 95 | const float time_col = TestDotTimeOne(l_shape, r_shape); 96 | SetFmt(std::cout) << "c:" << time_col << "ms, "; 97 | const float time_row = TestDotTimeOne(l_shape, r_shape); 98 | SetFmt(std::cout) << "r:" << time_row << "ms, "; 99 | 100 | // Analyze 101 | const bool should_col = (time_col < time_row); 102 | const bool used_col = 103 | (std::abs(time_col - time_auto) < std::abs(time_row - time_auto)); 104 | if (used_col) { 105 | std::cout << "(now: col) "; 106 | } else { 107 | std::cout << "(now: row) "; 108 | } 109 | if (should_col) { 110 | std::cout << "(should be col) "; 111 | } else { 112 | std::cout << "(should be row) "; 113 | } 114 | if (used_col == should_col) { 115 | // std::cout << "-> OK"; 116 | } else { 117 | std::cout << "-> NG"; 118 | } 119 | std::cout << std::endl; 120 | } 121 | /* main: benchmark driver; its arguments are ignored. It runs TestDotTime over a fixed list of 2-D @ 2-D shape pairs (W = 20000, W2 = 2000), then over 2-D @ 3-D pairs whose right operand has a leading dimension S (currently 1; 200 is commented out). NOTE(review): several pairs are measured more than once (e.g. {10, W} @ {W, 100} and {100, W} @ {W, 20}); presumably intentional for run-to-run comparison -- confirm. */ 122 | int main(int argc, char const* argv[]) { 123 | (void)argc; 124 | (void)argv; 125 | 126 | constexpr int W = 20000; 127 | constexpr int W2 = 2000; 128 | // constexpr int S = 200; 129 | constexpr
int S = 1; 130 | 131 | TestDotTime({1, W}, {W, 100}); 132 | TestDotTime({10, W}, {W, 100}); 133 | TestDotTime({100, W}, {W, 100}); 134 | TestDotTime({1000, W}, {W, 100}); 135 | 136 | TestDotTime({1, W}, {W, 1}); 137 | TestDotTime({10, W}, {W, 10}); 138 | TestDotTime({100, W}, {W, 100}); 139 | TestDotTime({1000, W}, {W, 1000}); 140 | 141 | TestDotTime({1, W}, {W, 1000}); 142 | TestDotTime({10, W}, {W, 100}); 143 | TestDotTime({100, W}, {W, 10}); 144 | TestDotTime({1000, W}, {W, 1}); 145 | 146 | TestDotTime({100, W}, {W, 1}); 147 | TestDotTime({100, W}, {W, 2}); 148 | TestDotTime({100, W}, {W, 5}); 149 | TestDotTime({100, W}, {W, 7}); 150 | TestDotTime({100, W}, {W, 10}); 151 | TestDotTime({100, W}, {W, 12}); 152 | TestDotTime({100, W}, {W, 15}); 153 | TestDotTime({100, W}, {W, 17}); 154 | TestDotTime({100, W}, {W, 20}); 155 | TestDotTime({100, W}, {W, 50}); 156 | TestDotTime({100, W}, {W, 70}); 157 | TestDotTime({100, W}, {W, 20}); 158 | TestDotTime({100, W}, {W, 50}); 159 | TestDotTime({100, W}, {W, 70}); 160 | TestDotTime({100, W}, {W, 100}); 161 | TestDotTime({100, W}, {W, 1000}); 162 | 163 | TestDotTime({W, 1}, {1, W}); 164 | TestDotTime({W, 10}, {10, W}); 165 | TestDotTime({W, 100}, {100, W}); 166 | 167 | TestDotTime({10 * W2, 1}, {1, W2}); 168 | TestDotTime({10 * W2, 10}, {10, W2}); 169 | TestDotTime({10 * W2, 100}, {100, W2}); 170 | 171 | std::cout << "------------------------------------------" << std::endl; 172 | 173 | TestDotTime({10 * W2, 1}, {1, 100}); 174 | TestDotTime({10 * W2, 10}, {10, 100}); 175 | TestDotTime({10 * W2, 100}, {100, 100}); 176 | 177 | TestDotTime({10 * 100, 1}, {1, W2}); 178 | TestDotTime({10 * 100, 10}, {10, W2}); 179 | TestDotTime({10 * 100, 100}, {100, W2}); 180 | 181 | TestDotTime({100, W2}, {W2, 20000}); 182 | TestDotTime({20000, W2}, {W2, 100}); 183 | 184 | std::cout << "------------------------------------------" << std::endl; 185 | 186 | TestDotTime({1, W}, {S, W, 100}); 187 | TestDotTime({10, W}, {S, W, 100}); 188 |
TestDotTime({100, W}, {S, W, 100}); 189 | TestDotTime({1000, W}, {S, W, 100}); 190 | 191 | TestDotTime({1, W}, {S, W, 1}); 192 | TestDotTime({10, W}, {S, W, 10}); 193 | TestDotTime({100, W}, {S, W, 100}); 194 | 195 | TestDotTime({1, W}, {S, W, 1000}); 196 | TestDotTime({10, W}, {S, W, 100}); 197 | TestDotTime({100, W}, {S, W, 10}); 198 | TestDotTime({1000, W}, {S, W, 1}); 199 | 200 | TestDotTime({100, W}, {S, W, 1}); 201 | TestDotTime({100, W}, {S, W, 10}); 202 | TestDotTime({100, W}, {S, W, 100}); 203 | 204 | return 0; 205 | } 206 | -------------------------------------------------------------------------------- /tests/test_main.cpp: -------------------------------------------------------------------------------- 1 | #define CATCH_CONFIG_MAIN // Define main() 2 | #include "Catch2/single_include/catch2/catch.hpp" 3 | -------------------------------------------------------------------------------- /tests/test_profiling.cpp: -------------------------------------------------------------------------------- 1 | #include "Catch2/single_include/catch2/catch.hpp" 2 | 3 | #define TINYNDARRAY_IMPLEMENTATION 4 | #define TINYNDARRAY_PROFILE_MEMORY // With profiling 5 | #include "../tinyndarray.h" 6 | 7 | using tinyndarray::NdArray; 8 | 9 | TEST_CASE("NdArray Profiling") { 10 | SECTION("Basic") { 11 | auto m1 = NdArray::Ones(10, 10); 12 | auto m2 = NdArray::Zeros(20, 20); 13 | CHECK(NdArray::GetNumInstance() == 2); 14 | CHECK(NdArray::GetTotalMemory() == (10 * 10) + (20 * 20)); 15 | } 16 | 17 | SECTION("Unregister") { 18 | auto m1 = NdArray::Ones(10, 10); 19 | auto m2 = NdArray::Zeros(20, 20); 20 | { 21 | auto m3 = NdArray::Zeros(3, 3); 22 | CHECK(NdArray::GetNumInstance() == 3); 23 | CHECK(NdArray::GetTotalMemory() == (10 * 10) + (20 * 20) + (3 * 3)); 24 | } 25 | 26 | CHECK(NdArray::GetNumInstance() == 2); 27 | CHECK(NdArray::GetTotalMemory() == (20 * 20) + (10 * 10)); 28 | } 29 | 30 | SECTION("Overwrite") { 31 | auto m1 = NdArray::Ones(10, 10); 32 | auto m2 =
NdArray::Zeros(20, 20); 33 | { 34 | auto m3 = NdArray::Zeros(3, 3); 35 | CHECK(NdArray::GetNumInstance() == 3); 36 | CHECK(NdArray::GetTotalMemory() == (10 * 10) + (20 * 20) + (3 * 3)); 37 | 38 | m1 = m3; 39 | CHECK(NdArray::GetNumInstance() == 2); 40 | CHECK(NdArray::GetTotalMemory() == (20 * 20) + (3 * 3)); 41 | } 42 | 43 | CHECK(NdArray::GetNumInstance() == 2); 44 | CHECK(NdArray::GetTotalMemory() == (20 * 20) + (3 * 3)); 45 | } 46 | 47 | SECTION("Unregister of substance copy") { 48 | auto m1 = NdArray::Ones(10, 10); 49 | auto m2 = NdArray::Zeros(20, 20); 50 | { 51 | auto m3 = NdArray::Zeros(3, 3); 52 | CHECK(NdArray::GetNumInstance() == 3); 53 | CHECK(NdArray::GetTotalMemory() == (10 * 10) + (20 * 20) + (3 * 3)); 54 | 55 | m3 = m1.reshape(-1); 56 | CHECK(NdArray::GetNumInstance() == 2); 57 | CHECK(NdArray::GetTotalMemory() == (10 * 10) + (20 * 20)); 58 | } 59 | 60 | CHECK(NdArray::GetNumInstance() == 2); 61 | CHECK(NdArray::GetTotalMemory() == (10 * 10) + (20 * 20)); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /tests/test_separated_1.cpp: -------------------------------------------------------------------------------- 1 | #include "test_common.h" 2 | -------------------------------------------------------------------------------- /tests/test_separated_2.cpp: -------------------------------------------------------------------------------- 1 | #define TINYNDARRAY_IMPLEMENTATION 2 | #include "../tinyndarray.h" 3 | -------------------------------------------------------------------------------- /tests/test_time_consumption.cpp: -------------------------------------------------------------------------------- 1 | #include "Catch2/single_include/catch2/catch.hpp" 2 | 3 | #define TINYNDARRAY_IMPLEMENTATION 4 | #include "../tinyndarray.h" 5 | 6 | #include "timer.h" // g_timer 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | using namespace tinyndarray; 13 | 14 | constexpr int W = 20000; 15 | constexpr int H = 
20000; 16 | constexpr int WH = W * H; 17 | constexpr int N_WORKERS = -1; 18 | 19 | static auto MeasureOpTime(std::function op) { 20 | g_timer.start(); 21 | auto&& ret = op(); 22 | g_timer.end(); 23 | return std::tuple(g_timer.getElapsedMsec(), std::move(ret)); 24 | } 25 | 26 | template 27 | auto SplitRetItems(T... items) { 28 | return std::tuple, std::vector>( 29 | {std::get<0>(items)...}, {std::get<1>(items)...}); 30 | } 31 | 32 | static void PrintTimeResult(const std::string& tag, 33 | const std::vector& times, 34 | const int time_width) { 35 | std::cout << " * " << tag << " : "; 36 | std::cout << std::setw(time_width) << times[0] << " ms"; 37 | for (size_t i = 1; i < times.size(); i++) { 38 | std::cout << ", " << std::setw(time_width) << times[i] << " ms"; 39 | } 40 | std::cout << std::endl; 41 | } 42 | 43 | static void CheckSameNdArray(const NdArray& m1, const NdArray& m2) { 44 | CHECK(m1.shape() == m2.shape()); 45 | CHECK(All(m1 == m2)); 46 | } 47 | 48 | template 49 | void TestSingleMultiThread(const std::string& tag, F prep_func, OP... 
ops) { 50 | static_assert(0 < sizeof...(ops), "No operation"); 51 | 52 | // Print title 53 | std::cout << "* " << tag << std::endl; 54 | 55 | // Single thread 56 | NdArray::SetNumWorkers(1); 57 | const auto& single_rets = SplitRetItems(MeasureOpTime(ops)...); 58 | const auto& single_times = std::get<0>(single_rets); 59 | const auto& single_arrays = std::get<1>(single_rets); 60 | // Print result 61 | PrintTimeResult("Single", single_times, 5); 62 | 63 | // Prepare for next task 64 | prep_func(); 65 | 66 | // Multi thread 67 | NdArray::SetNumWorkers(N_WORKERS); 68 | const auto& multi_rets = SplitRetItems(MeasureOpTime(ops)...); 69 | const auto& multi_times = std::get<0>(multi_rets); 70 | const auto& multi_arrays = std::get<1>(multi_rets); 71 | // Print result 72 | PrintTimeResult("Multi ", multi_times, 5); 73 | 74 | // Time check 75 | for (size_t i = 0; i < sizeof...(ops); i++) { 76 | CHECK(multi_times[i] < single_times[i]); 77 | } 78 | 79 | // Check array content 80 | for (size_t i = 0; i < sizeof...(ops); i++) { 81 | CheckSameNdArray(single_arrays[i], multi_arrays[i]); 82 | } 83 | } 84 | 85 | // ---------------------------- Fill Initialization ---------------------------- 86 | TEST_CASE("NdArray Fill-initialization") { 87 | SECTION("(Zeros/Ones/Arange)") { 88 | TestSingleMultiThread( 89 | "Fill-initialization (Zeros/Ones/Arange)", [&]() {}, 90 | [&]() { return NdArray::Zeros(WH); }, 91 | [&]() { return NdArray::Ones(WH); }, 92 | [&]() { return NdArray::Arange(WH); }); 93 | } 94 | } 95 | 96 | // -------------------------- Element-wise Operation --------------------------- 97 | TEST_CASE("NdArray Element-wise") { 98 | SECTION("(NdArray, float)") { 99 | auto m1 = NdArray::Arange(WH); 100 | auto m1_move = NdArray::Arange(WH); 101 | auto m1_move_sub = NdArray::Arange(WH); 102 | auto m1_cao = NdArray::Arange(WH); 103 | auto m1_cao_sub = NdArray::Arange(WH); 104 | TestSingleMultiThread( 105 | "Element-wise (NdArray, float)", 106 | [&]() { 107 | m1_move = 
std::move(m1_move_sub); // Preparation for multi 108 | m1_cao = std::move(m1_cao_sub); 109 | }, 110 | [&]() { return m1 + 1.f; }, // Basic 111 | [&]() { return std::move(m1_move) + 1.f; }, // Inplace 112 | [&]() { return m1_cao += 1.f; }); // Compound Assignment 113 | } 114 | 115 | SECTION("(NdArray, NdArray) (same-size)") { 116 | auto m1 = NdArray::Arange(WH); 117 | auto m1_move = NdArray::Arange(WH); 118 | auto m1_move_sub = NdArray::Arange(WH); 119 | auto m1_cao = NdArray::Arange(WH); 120 | auto m1_cao_sub = NdArray::Arange(WH); 121 | auto m2 = NdArray::Ones(WH); 122 | 123 | TestSingleMultiThread( 124 | "Element-wise (NdArray, NdArray) (same-size)", 125 | [&]() { 126 | m1_move = std::move(m1_move_sub); 127 | m1_cao = std::move(m1_cao_sub); 128 | }, 129 | [&]() { return m1 + m2; }, // Basic 130 | [&]() { return std::move(m1_move) + m2; }, // Inplace 131 | [&]() { return m1_cao += m2; }); // Compound Assignment 132 | } 133 | 134 | SECTION("(NdArray, NdArray) (broadcast) (left-big)") { 135 | auto m1 = NdArray::Arange(WH).reshape(H, W); 136 | auto m1_move = NdArray::Arange(WH).reshape(H, W); 137 | auto m1_move_sub = NdArray::Arange(WH).reshape(H, W); 138 | auto m1_cao = NdArray::Arange(WH).reshape(H, W); 139 | auto m1_cao_sub = NdArray::Arange(WH).reshape(H, W); 140 | auto m2 = NdArray::Ones(W); 141 | TestSingleMultiThread( 142 | "Element-wise (NdArray, NdArray) (broadcast) (left-big)", 143 | [&]() { 144 | m1_move = std::move(m1_move_sub); 145 | m1_cao = std::move(m1_cao_sub); 146 | }, 147 | [&]() { return m1 + m2; }, // Basic 148 | [&]() { return std::move(m1_move) + m2; }, // Inplace 149 | [&]() { return m1_cao += m2; }); // Compound Assignment 150 | } 151 | 152 | SECTION("(NdArray, NdArray) (broadcast) (right-big)") { 153 | auto m1 = NdArray::Arange(WH).reshape(H, W); 154 | auto m1_move = NdArray::Arange(WH).reshape(H, W); 155 | auto m1_move_sub = NdArray::Arange(WH).reshape(H, W); 156 | auto m2 = NdArray::Ones(W); 157 | TestSingleMultiThread( 158 | "Element-wise 
(NdArray, NdArray) (broadcast) (right-big)", 159 | [&]() { m1_move = std::move(m1_move_sub); }, 160 | [&]() { return m2 + m1; }, // Basic 161 | [&]() { return m2 + std::move(m1_move); }); // Inplace 162 | } 163 | } 164 | 165 | // -------------------------------- Dot product -------------------------------- 166 | TEST_CASE("NdArray Dot") { 167 | SECTION("(1d1d)") { 168 | auto m1 = NdArray::Ones(16000000); // 16777216 is limit of float 169 | auto m2 = NdArray::Ones(16000000); 170 | TestSingleMultiThread( 171 | "Dot (1d1d)", [&]() {}, [&]() { return m1.dot(m2); }); 172 | } 173 | 174 | SECTION("(2d2d)") { 175 | auto m1 = NdArray::Arange(200 * W).reshape(200, W); 176 | auto m2 = NdArray::Ones(W, 200); 177 | TestSingleMultiThread( 178 | "Dot (2d2d)", [&]() {}, [&]() { return m1.dot(m2); }); 179 | } 180 | 181 | SECTION("(NdMd) (left-big)") { 182 | auto m1 = NdArray::Arange(WH).reshape(H, 1, W); 183 | auto m2 = NdArray::Ones(W, 1); 184 | TestSingleMultiThread( 185 | "Dot (NdMd) (left-big)", [&]() {}, 186 | [&]() { return m1.dot(m2); }); 187 | } 188 | 189 | SECTION("(NdMd) (right-big)") { 190 | auto m1 = NdArray::Ones(1, H); 191 | auto m2 = NdArray::Arange(WH).reshape(W, H, 1); 192 | TestSingleMultiThread( 193 | "Dot (NdMd) (right-big)", [&]() {}, 194 | [&]() { return m1.dot(m2); }); 195 | } 196 | 197 | /* 198 | // This type cannot be operated in parallel. 
199 | SECTION("(NdMd) (right-big 2)") { 200 | auto m1 = NdArray::Ones(1, H); 201 | auto m2 = NdArray::Arange(WH).reshape(1, H, W); 202 | TestSingleMultiThread( 203 | "Dot (NdMd) (right-big 2)", [&]() {}, 204 | [&]() { return m1.dot(m2); }); 205 | } 206 | */ 207 | } 208 | 209 | // ------------------------------- Cross product ------------------------------- 210 | TEST_CASE("NdArray Cross") { 211 | SECTION("(NdMd)") { 212 | auto m1_a = NdArray::Arange(WH * 3).reshape(H, W, 3); 213 | auto m1_b = NdArray::Arange(WH * 2).reshape(H, W, 2); 214 | auto m2_a = NdArray::Ones(W, 3); 215 | auto m2_b = NdArray::Ones(W, 2); 216 | TestSingleMultiThread( 217 | "Cross (NdMd)", [&]() {}, 218 | [&]() { return m1_a.cross(m2_a); }, // 3x3 219 | [&]() { return m1_a.cross(m2_b); }, // 3x2 220 | [&]() { return m1_b.cross(m2_b); }); // 2x2 221 | } 222 | } 223 | 224 | // ------------------------------- Axis operation ------------------------------ 225 | TEST_CASE("NdArray Axis") { 226 | SECTION("Sum") { 227 | const int N = 4000; 228 | auto m1 = NdArray::Ones(N * N).reshape(N, N); // 16777216 is limit 229 | TestSingleMultiThread( 230 | "Sum", [&]() {}, [&]() { return m1.sum(); }, 231 | [&]() { return m1.sum(Axis{1}); }); 232 | } 233 | SECTION("Max") { 234 | auto m1 = NdArray::Arange(WH).reshape(W, H); 235 | TestSingleMultiThread( 236 | "Max", [&]() {}, [&]() { return m1.max(); }, 237 | [&]() { return m1.max(Axis{1}); }); 238 | } 239 | } 240 | 241 | // ----------------------------------- Slice ----------------------------------- 242 | TEST_CASE("NdArray Slice") { 243 | SECTION("Basic") { 244 | auto m1 = NdArray::Arange(WH * 2).reshape(H, W, 2); 245 | SliceIndex idx = {{W / 4, W / 4 * 3}, {0, W / 2}, {0, 1}}; 246 | TestSingleMultiThread( 247 | "Slice", [&]() {}, [&]() { return m1.slice(idx); }); 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /tests/test_total.cpp: 
-------------------------------------------------------------------------------- 1 | #define TINYNDARRAY_IMPLEMENTATION 2 | #include "test_common.h" 3 | -------------------------------------------------------------------------------- /tests/timer.h: -------------------------------------------------------------------------------- 1 | #ifndef TINYNDARRAY_TIMER_H 2 | #define TINYNDARRAY_TIMER_H 3 | 4 | #include <chrono> 5 | 6 | // Minimal stopwatch shared by the benchmark tests. 7 | // steady_clock is monotonic, so measured intervals are immune to wall-clock 8 | // adjustments (system_clock may jump forwards or backwards, e.g. via NTP). 9 | class Timer { 10 | public: 11 | Timer() {} 12 | 13 | // Records the start of the measured interval. 14 | void start() { 15 | m_start = Clock::now(); 16 | } 17 | 18 | // Records the end of the measured interval. 19 | void end() { 20 | m_end = Clock::now(); 21 | } 22 | 23 | // Returns the interval between the last start() and end() calls, in 24 | // milliseconds. The fractional part is kept so that sub-millisecond 25 | // intervals are not truncated to zero (callers accumulate these values). 26 | float getElapsedMsec() const { 27 | return std::chrono::duration<float, std::milli>(m_end - m_start) 28 | .count(); 29 | } 30 | 31 | private: 32 | using Clock = std::chrono::steady_clock; 33 | Clock::time_point m_start, m_end; 34 | }; 35 | 36 | static Timer g_timer; 37 | 38 | #endif /* end of include guard */ 39 | --------------------------------------------------------------------------------