├── .clang-format
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── tests
│   ├── CMakeLists.txt
│   ├── cmake
│   │   └── sanitizers
│   │       ├── FindASan.cmake
│   │       ├── FindMSan.cmake
│   │       ├── FindSanitizers.cmake
│   │       ├── FindTSan.cmake
│   │       ├── FindUBSan.cmake
│   │       ├── asan-wrapper
│   │       └── sanitize-helpers.cmake
│   ├── test_common.h
│   ├── test_dot_major.cpp
│   ├── test_main.cpp
│   ├── test_profiling.cpp
│   ├── test_separated_1.cpp
│   ├── test_separated_2.cpp
│   ├── test_time_consumption.cpp
│   ├── test_total.cpp
│   └── timer.h
└── tinyndarray.h
/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | IndentWidth: 4
3 | TabWidth: 4
4 | ContinuationIndentWidth: 8
5 | UseTab: Never
6 | BreakBeforeBraces: Attach
7 | AccessModifierOffset: -4
8 | Standard: Cpp11
9 | AlignOperands: true
10 | BreakBeforeTernaryOperators: true
11 | AllowShortIfStatementsOnASingleLine: true
12 | AllowShortCaseLabelsOnASingleLine: true
13 | AllowShortLoopsOnASingleLine: true
14 | PointerAlignment: Left
15 | AlignAfterOpenBracket: Align
16 | ColumnLimit: 80
17 | AllowShortFunctionsOnASingleLine: Empty
18 | BasedOnStyle: 'Google'
19 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.swo
3 | *.swn
4 | *.swm
5 | *.plist
6 |
7 | build
8 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "tests/Catch2"]
2 | path = tests/Catch2
3 | url = https://github.com/catchorg/Catch2
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 David Pilger
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TinyNdArray
2 |
3 | Single Header C++ Implementation of NumPy NdArray.
4 |
5 | I look forward to your pull requests.
6 |
7 | ## Requirement
8 |
9 | * C++14 compiler
10 |
11 | ## Sample Code
12 |
13 | #define TINYNDARRAY_IMPLEMENTATION
14 | #include "tinyndarray.h"
15 |
16 | using tinyndarray::NdArray;
17 |
18 | int main(int argc, char const* argv[]) {
19 | auto m1 = NdArray::Arange(12).reshape(2, 2, 3) - 2.f;
20 | auto m2 = NdArray::Ones(3, 1) * 10.f;
21 | auto m12 = m1.dot(m2);
22 | std::cout << m12 << std::endl;
23 |
24 | NdArray m3 = {{{-0.4f, 0.3f}, {-0.2f, 0.1f}},
25 | {{-0.1f, 0.2f}, {-0.3f, 0.4f}}};
26 | m3 = Sin(std::move(m3));
27 | std::cout << m3 << std::endl;
28 |
29 | auto sum_abs = Abs(m3).sum();
30 | std::cout << sum_abs << std::endl;
31 |
32 | auto m4 = Where(0.f < m3, -100.f, 100.f);
33 | bool all_m4 = All(m4);
34 | bool any_m4 = Any(m4);
35 | std::cout << m4 << std::endl;
36 | std::cout << all_m4 << " " << any_m4 << std::endl;
37 |
38 | return 0;
39 | }
40 |
41 |
42 | ## Quick Guide
43 |
44 | TinyNdArray supports only float arrays.
45 |
46 | In the following Python code, `dtype=float32` is omitted; the C++ code assumes `using namespace tinyndarray;` has been declared.
47 |
48 | For more details, please see the declarations at the top of the header file.
49 |
50 | ### Copy behavior
51 |
52 | Copying an NdArray is a shallow copy (the data is shared), the same as assignment in NumPy.
53 |
54 |
55 |
56 | Numpy (Python) |
57 | TinyNdArray (C++) |
58 |
59 |
60 | a = np.ones((2, 3))
61 | b = a
62 | b[0, 0] = -1
63 | print(a[0, 0]) # -1
64 | |
65 |
66 | auto a = NdArray::Ones(2, 3);
67 | auto b = a;
68 | b[{0, 0}] = -1;
69 | std::cout << a[{0, 0}] << std::endl; // -1
70 | |
71 |
72 |
73 |
74 | ### Basic Constructing
75 |
76 | | **Numpy (Python)** | **TinyNdArray (C++)** |
77 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
78 | | ```a = np.array([2, 3])``` | ```NdArray a({2, 3});``` or |
79 | | | ```NdArray a(Shape{2, 3});``` |
80 |
81 | ### Float Initializer
82 |
83 | Supports up to 10 dimensions.
84 |
85 | | **Numpy (Python)** | **TinyNdArray (C++)** |
86 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
87 | | ```a = np.array([1.0, 2.0])``` | ```NdArray a = {1.f, 2.f};``` |
88 | | ```a = np.array([[1.0, 2.0]])``` | ```NdArray a = {{1.f, 2.f}};``` |
89 | | ```a = np.array([[1.0, 2.0], [3.0, 4.0]])``` | ```NdArray a = {{1.f, 2.f}, {3.f, 4.f}};``` |
90 | | ```a = np.array([[[[[[[[[[1.0, 2.0]]]]]]]]]])``` | ```NdArray a = {{{{{{{{{{1.f, 2.f}}}}}}}}}};``` |
91 |
92 | ### Static Initializer
93 |
94 | | **Numpy (Python)** | **TinyNdArray (C++)** |
95 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
96 | | ```a = np.empty((2, 3))``` | ```auto a = NdArray::Empty(2, 3);``` or |
97 | | | ```auto a = NdArray::Empty({2, 3});``` or |
98 | | | ```auto a = NdArray::Empty(Shape{2, 3});``` |
99 | | ```a = np.zeros((2, 3))``` | ```auto a = NdArray::Zeros(2, 3);``` or |
100 | | | ```auto a = NdArray::Zeros({2, 3});``` or |
101 | | | ```auto a = NdArray::Zeros(Shape{2, 3});``` |
102 | | ```a = np.ones((2, 3))``` | ```auto a = NdArray::Ones(2, 3);``` or |
103 | | | ```auto a = NdArray::Ones({2, 3});``` or |
104 | | | ```auto a = NdArray::Ones(Shape{2, 3});``` |
105 | | ```a = np.arange(10)``` | ```auto a = NdArray::Arange(10);``` |
106 | | ```a = np.arange(0, 100, 10)``` | ```auto a = NdArray::Arange(0, 100, 10);``` |
107 |
108 | ### Random Initializer
109 |
110 | | **Numpy (Python)** | **TinyNdArray (C++)** |
111 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
112 | | ```a = np.random.uniform(size=10)``` | ```auto a = NdArray::Uniform(10);``` |
113 | | ```a = np.random.uniform(size=(2, 10))``` | ```auto a = NdArray::Uniform({2, 10});``` or |
114 | | | ```auto a = NdArray::Uniform(Shape{2, 10});``` |
115 | | ```a = np.random.uniform(low=0.0, high=1.0, size=10)``` | ```auto a = NdArray::Uniform(0.0, 1.0, {10});``` or |
116 | | | ```auto a = NdArray::Uniform(0.0, 1.0, Shape{10});``` |
117 | | ```a = np.random.normal(size=10)``` | ```auto a = NdArray::Normal(10);``` |
118 | | ```a = np.random.normal(size=(2, 10))``` | ```auto a = NdArray::Normal({2, 10});``` or |
119 | | | ```auto a = NdArray::Normal(Shape{2, 10});``` |
120 | | ```a = np.random.normal(loc=0.0, scale=1.0, size=10)``` | ```auto a = NdArray::Normal(0.0, 1.0, {10}); ``` or |
121 | | | ```auto a = NdArray::Normal(0.0, 1.0, Shape{10});``` |
122 |
123 | ### Random Seed
124 |
125 | | **Numpy (Python)** | **TinyNdArray (C++)** |
126 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
127 | | ```np.random.seed()``` | ```NdArray::Seed();``` |
128 | | ```np.random.seed(0)``` | ```NdArray::Seed(0);``` |
129 |
130 | ### Basic Embedded Methods
131 |
132 | | **Numpy (Python)** | **TinyNdArray (C++)** |
133 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
134 | | ```id(a)``` | ```a.id()``` |
135 | | ```a.size``` | ```a.size()``` |
136 | | ```a.shape``` | ```a.shape()``` |
137 | | ```a.ndim``` | ```a.ndim()``` |
138 | | ```a.fill(2.0)``` | ```a.fill(2.f)``` |
139 | | ```a.copy()``` | ```a.copy()``` |
140 |
141 | ### Original Embedded Methods
142 |
143 | | **Numpy (Python)** | **TinyNdArray (C++)** |
144 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
145 | | ```---``` | ```a.empty()``` |
146 | | ```---``` | ```a.data()``` |
147 | | ```---``` | ```a.begin()``` |
148 | | ```---``` | ```a.end()``` |
149 |
150 | ### Single Element Casting
151 |
152 | | **Numpy (Python)** | **TinyNdArray (C++)** |
153 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
154 | | ```float(a)``` | ```static_cast<float>(a)``` |
155 |
156 | ### Index Access
157 |
158 | | **Numpy (Python)** | **TinyNdArray (C++)** |
159 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
160 | | ```a[2, -3]``` | ```a[{2, -3}]``` or |
161 | | | ```a[Index{2, -3}]``` or |
162 | | | ```a(2, -3)``` |
163 |
164 | ### Reshape methods
165 |
166 | | **Numpy (Python)** | **TinyNdArray (C++)** |
167 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
168 | | ```a.reshape(-1, 2, 1)``` | ```a.reshape({-1, 2, 1})``` or |
169 | | | ```a.reshape(Shape{-1, 2, 1})``` or |
170 | | | ```a.reshape(-1, 2, 1)``` |
171 | | ```a.flatten()``` | ```a.flatten()``` |
172 | | ```a.ravel()``` | ```a.ravel()``` |
173 |
174 | ### Reshape functions
175 |
176 | | **Numpy (Python)** | **TinyNdArray (C++)** |
177 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
178 | | ```np.reshape(a, (-1, 2, 1))``` | ```Reshape(a, {-1, 2, 1})``` |
179 | | ```np.squeeze(a)``` | ```Squeeze(a)``` |
180 | | ```np.squeeze(a, [0, -2])``` | ```Squeeze(a, {0, -2})``` |
181 | | ```np.expand_dims(a, 1)``` | ```ExpandDims(a, 1)``` |
182 |
183 | ### Slice
184 |
185 | Slice methods create a copy of the array, not a reference.
186 |
187 | | **Numpy (Python)** | **TinyNdArray (C++)** |
188 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
189 | | ```a[1:5, -4:-1]``` | ```a.slice({{1, 5}, {-4, -1}})``` or |
190 | | | ```a.slice(SliceIndex{{1, 5}, {-4, -1}})``` or |
191 | | | ```a.slice({1, 5}, {-4, -1})``` |
192 |
193 | ### Print
194 |
195 | | **Numpy (Python)** | **TinyNdArray (C++)** |
196 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
197 | | ```print(a)``` | ```std::cout << a << std::endl;``` |
198 |
199 | ### Single Operators
200 | | **Numpy (Python)** | **TinyNdArray (C++)** |
201 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
202 | | ```+np.ones((2, 3))``` | ```+NdArray::Ones(2, 3)``` |
203 | | ```-np.ones((2, 3))``` | ```-NdArray::Ones(2, 3)``` |
204 |
205 | ### Arithmetic Operators
206 |
207 | All operators support broadcasting.
208 |
209 | | **Numpy (Python)** | **TinyNdArray (C++)** |
210 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
211 | | ```np.ones((2, 1, 3)) + np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) + NdArray::Ones(4, 1)``` |
212 | | ```np.ones((2, 1, 3)) - np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) - NdArray::Ones(4, 1)``` |
213 | | ```np.ones((2, 1, 3)) * np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) * NdArray::Ones(4, 1)``` |
214 | | ```np.ones((2, 1, 3)) / np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) / NdArray::Ones(4, 1)``` |
215 | | ```np.ones((2, 1, 3)) + 2.0``` | ```NdArray::Ones(2, 1, 3) + 2.f``` |
216 | | ```np.ones((2, 1, 3)) - 2.0``` | ```NdArray::Ones(2, 1, 3) - 2.f``` |
217 | | ```np.ones((2, 1, 3)) * 2.0``` | ```NdArray::Ones(2, 1, 3) * 2.f``` |
218 | | ```np.ones((2, 1, 3)) / 2.0``` | ```NdArray::Ones(2, 1, 3) / 2.f``` |
219 | | ```2.0 + np.ones((2, 1, 3))``` | ```2.f + NdArray::Ones(2, 1, 3)``` |
220 | | ```2.0 - np.ones((2, 1, 3))``` | ```2.f - NdArray::Ones(2, 1, 3)``` |
221 | | ```2.0 * np.ones((2, 1, 3))``` | ```2.f * NdArray::Ones(2, 1, 3)``` |
222 | | ```2.0 / np.ones((2, 1, 3))``` | ```2.f / NdArray::Ones(2, 1, 3)``` |
223 |
224 | ### Comparison Operators
225 |
226 | All operators support broadcasting.
227 |
228 | | **Numpy (Python)** | **TinyNdArray (C++)** |
229 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
230 | | ```np.ones((2, 1, 3)) == np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) == NdArray::Ones(4, 1)``` |
231 | | ```np.ones((2, 1, 3)) != np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) != NdArray::Ones(4, 1)``` |
232 | | ```np.ones((2, 1, 3)) > np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) > NdArray::Ones(4, 1)``` |
233 | | ```np.ones((2, 1, 3)) >= np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) >= NdArray::Ones(4, 1)``` |
234 | | ```np.ones((2, 1, 3)) < np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) < NdArray::Ones(4, 1)``` |
235 | | ```np.ones((2, 1, 3)) <= np.ones((4, 1))``` | ```NdArray::Ones(2, 1, 3) <= NdArray::Ones(4, 1)``` |
236 | | ```np.ones((2, 1, 3)) == 1.0``` | ```NdArray::Ones(2, 1, 3) == 1.f``` |
237 | | ```np.ones((2, 1, 3)) != 1.0``` | ```NdArray::Ones(2, 1, 3) != 1.f``` |
238 | | ```np.ones((2, 1, 3)) > 1.0``` | ```NdArray::Ones(2, 1, 3) > 1.f``` |
239 | | ```np.ones((2, 1, 3)) >= 1.0``` | ```NdArray::Ones(2, 1, 3) >= 1.f``` |
240 | | ```np.ones((2, 1, 3)) < 1.0``` | ```NdArray::Ones(2, 1, 3) < 1.f``` |
241 | | ```np.ones((2, 1, 3)) <= 1.0``` | ```NdArray::Ones(2, 1, 3) <= 1.f``` |
242 | | ```1.0 == np.ones((4, 1))``` | ```1.f == NdArray::Ones(4, 1)``` |
243 | | ```1.0 != np.ones((4, 1))``` | ```1.f != NdArray::Ones(4, 1)``` |
244 | | ```1.0 > np.ones((4, 1))``` | ```1.f > NdArray::Ones(4, 1)``` |
245 | | ```1.0 >= np.ones((4, 1))``` | ```1.f >= NdArray::Ones(4, 1)``` |
246 | | ```1.0 < np.ones((4, 1))``` | ```1.f < NdArray::Ones(4, 1)``` |
247 | | ```1.0 <= np.ones((4, 1))``` | ```1.f <= NdArray::Ones(4, 1)``` |
248 |
249 | ### Compound Assignment Operators
250 |
251 | All operators support broadcasting. However, the left-hand variable keeps its shape.
252 |
253 | | **Numpy (Python)** | **TinyNdArray (C++)** |
254 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
255 | | ```np.ones((2, 1, 3)) += np.ones(3)``` | ```NdArray::Ones(2, 1, 3) += NdArray::Ones(3)``` |
256 | | ```np.ones((2, 1, 3)) -= np.ones(3)``` | ```NdArray::Ones(2, 1, 3) -= NdArray::Ones(3)``` |
257 | | ```np.ones((2, 1, 3)) *= np.ones(3)``` | ```NdArray::Ones(2, 1, 3) *= NdArray::Ones(3)``` |
258 | | ```np.ones((2, 1, 3)) /= np.ones(3)``` | ```NdArray::Ones(2, 1, 3) /= NdArray::Ones(3)``` |
259 | | ```np.ones((2, 1, 3)) += 2.f``` | ```NdArray::Ones(2, 1, 3) += 2.f``` |
260 | | ```np.ones((2, 1, 3)) -= 2.f``` | ```NdArray::Ones(2, 1, 3) -= 2.f``` |
261 | | ```np.ones((2, 1, 3)) *= 2.f``` | ```NdArray::Ones(2, 1, 3) *= 2.f``` |
262 | | ```np.ones((2, 1, 3)) /= 2.f``` | ```NdArray::Ones(2, 1, 3) /= 2.f``` |
263 |
264 | ### Math Functions
265 |
266 | Functions which take two arguments support broadcasting.
267 |
268 | | **Numpy (Python)** | **TinyNdArray (C++)** |
269 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
270 | | ```np.abs(a)``` | ```Abs(a)``` |
271 | | ```np.sign(a)``` | ```Sign(a)``` |
272 | | ```np.ceil(a)``` | ```Ceil(a)``` |
273 | | ```np.floor(a)``` | ```Floor(a)``` |
274 | | ```np.clip(a, x_min, x_max)``` | ```Clip(a, x_min, x_max)``` |
275 | | ```np.sqrt(a)``` | ```Sqrt(a)``` |
276 | | ```np.exp(a)``` | ```Exp(a)``` |
277 | | ```np.log(a)``` | ```Log(a)``` |
278 | | ```np.square(a)``` | ```Square(a)``` |
279 | | ```np.power(a, b)``` | ```Power(a, b)``` |
280 | | ```np.power(a, 2.0)``` | ```Power(a, 2.f)``` |
281 | | ```np.power(2.0, a)``` | ```Power(2.f, a)``` |
282 | | ```np.sin(a)``` | ```Sin(a)``` |
283 | | ```np.cos(a)``` | ```Cos(a)``` |
284 | | ```np.tan(a)``` | ```Tan(a)``` |
285 | | ```np.arcsin(a)``` | ```ArcSin(a)``` |
286 | | ```np.arccos(a)``` | ```ArcCos(a)``` |
287 | | ```np.arctan(a)``` | ```ArcTan(a)``` |
288 | | ```np.arctan2(a, b)``` | ```ArcTan2(a, b)``` |
289 | | ```np.arctan2(a, 10.0)``` | ```ArcTan2(a, 10.f)``` |
290 | | ```np.arctan2(10.0, a)``` | ```ArcTan2(10.f, a)``` |
291 |
292 | ### Axis Functions
293 |
294 | | **Numpy (Python)** | **TinyNdArray (C++)** |
295 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
296 | | ```np.sum(a)``` | ```Sum(a)``` |
297 | | ```np.sum(a, axis=0)``` | ```Sum(a, {0})``` or |
298 | | | ```Sum(a, Axis{0})``` |
299 | | ```np.sum(a, axis=(0, 2))``` | ```Sum(a, {0, 2})``` or |
300 | | | ```Sum(a, Axis{0, 2})``` |
301 | | ```np.mean(a, axis=0)``` | ```Mean(a, {0})``` |
302 | | ```np.min(a, axis=0)``` | ```Min(a, {0})``` |
303 | | ```np.max(a, axis=0)``` | ```Max(a, {0})``` |
304 |
305 | ### Logical Functions
306 |
307 | | **Numpy (Python)** | **TinyNdArray (C++)** |
308 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
309 | | ```np.all(a)``` | ```All(a)``` |
310 | | ```np.all(a, axis=0)``` | ```All(a, {0})``` |
311 | | ```np.any(a)``` | ```Any(a)``` |
312 | | ```np.any(a, axis=0)``` | ```Any(a, {0})``` |
313 | | ```np.where(condition, x, y)``` | ```Where(condition, x, y)``` |
314 |
315 | ### Axis Method
316 |
317 | | **Numpy (Python)** | **TinyNdArray (C++)** |
318 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
319 | | ```a.sum()``` | ```a.sum()``` |
320 | | ```a.sum(axis=0)``` | ```a.sum({0})``` or |
321 | | | ```a.sum(Axis{0})``` |
322 | | ```a.sum(axis=(0, 2))``` | ```a.sum({0, 2})``` or |
323 | | | ```a.sum(Axis{0, 2})``` |
324 | | ```a.mean(axis=0)``` | ```a.mean({0})``` |
325 | | ```a.min(axis=0)``` | ```a.min({0})``` |
326 | | ```a.max(axis=0)``` | ```a.max({0})``` |
327 |
328 | ### Grouping Functions
329 |
330 | | **Numpy (Python)** | **TinyNdArray (C++)** |
331 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
332 | | ```np.stack((a, b, ...), axis=0)``` | ```Stack({a, b, ...}, 0)``` |
333 | | ```np.concatenate((a, b, ...), axis=0)``` | ```Concatenate({a, b, ...}, 0)``` |
334 | | ```np.split(a, 2, axis=0)``` | ```Split(a, 2, 0)``` |
335 | | ```np.split(a, [1, 3], axis=0)``` | ```Split(a, {1, 3}, 0)``` |
336 | | | ```Separate(a, 0)``` An inverse of Stack(a, 0) |
337 |
338 | ### View Changing Functions
339 |
340 | View changing functions create a copy of the array, not a reference.
341 |
342 | | **Numpy (Python)** | **TinyNdArray (C++)** |
343 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
344 | | ```np.transpose(x)``` | ```Transpose(x)``` |
345 | | ```np.swapaxes(x, 0, 2)``` | ```Swapaxes(x, 0, 2)``` |
346 | | ```np.broadcast_to(x, (3, 2))``` | ```BroadcastTo(x, {3, 2})``` |
347 | | | ```SumTo(x, {3, 2})``` An inverse of BroadcastTo(x, {3, 2}) |
348 |
349 | ### Matrix Products
350 |
351 | All dimension rules of numpy are implemented.
352 |
353 | | **Numpy (Python)** | **TinyNdArray (C++)** |
354 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
355 | | ```np.dot(a, b)``` | ```Dot(a, b)``` |
356 | | ```a.dot(b)``` | ```a.dot(b)``` |
357 | | ```np.matmul(a, b)``` | ```Matmul(a, b)``` |
358 | | ```np.cross(a, b)``` | ```Cross(a, b)``` |
359 | | ```a.cross(b)``` | ```a.cross(b)``` |
360 |
361 | ### Inverse
362 |
363 | All dimension rules of numpy are implemented.
364 |
365 | | **Numpy (Python)** | **TinyNdArray (C++)** |
366 | |:--------------------------------------------------------:|:--------------------------------------------------------:|
367 | | ```np.linalg.inv(a)``` | ```Inv(a)``` |
368 |
369 |
370 | ## In-place Operation
371 |
372 | In NumPy, in-place and not-in-place operations are written as follows.
373 |
374 |
375 |
376 | Numpy (Python) In-place |
377 | Numpy (Python) Not in-place |
378 |
379 |
380 | a = np.ones((2, 3))
381 | a_id = id(a)
382 | np.exp(a, out=a) # in-place
383 | print(id(a) == a_id) # True
384 | |
385 |
386 | a = np.ones((2, 3))
387 | a_id = id(a)
388 | a = np.exp(a) # not in-place
389 | print(id(a) == a_id) # False
390 | |
391 |
392 |
393 |
394 | In TinyNdArray, when rvalue references are passed (e.g. via `std::move`), no new arrays are created and the operation is performed in-place.
395 |
396 |
397 |
398 | TinyNdArray (C++) In-place |
399 | TinyNdArray (C++) Not in-place |
400 |
401 |
402 | auto a = NdArray::Ones(2, 3);
403 | auto a_id = a.id();
404 | a = Exp(std::move(a)); // in-place
405 | std::cout << (a.id() == a_id)
406 | << std::endl; // true
407 | |
408 |
409 | auto a = NdArray::Ones(2, 3);
410 | auto a_id = a.id();
411 | a = Exp(a); // not in-place
412 | std::cout << (a.id() == a_id)
413 | << std::endl; // false
414 | |
415 |
416 |
417 |
418 | However, even when rvalue references are passed, a new array is created if broadcasting changes the size.
419 |
420 |
421 |
422 | TinyNdArray (C++) In-place |
423 | TinyNdArray (C++) Not in-place |
424 |
425 |
426 | auto a = NdArray::Ones(2, 1, 3);
427 | auto b = NdArray::Ones(3);
428 | auto a_id = a.id();
429 | a = std::move(a) + b; // in-place
430 | std::cout << a.shape()
431 | << std::endl; // [2, 1, 3]
432 | std::cout << (a.id() == a_id)
433 | << std::endl; // true
434 | |
435 |
436 | auto a = NdArray::Ones(2, 1, 3);
437 | auto b = NdArray::Ones(3, 1);
438 | auto a_id = a.id();
439 | a = std::move(a) + b; // looks like in-place
440 | std::cout << a.shape()
441 | << std::endl; // [2, 3, 3]
442 | std::cout << (a.id() == a_id)
443 | << std::endl; // false
444 | |
445 |
446 |
447 |
448 | ## Parallel Execution
449 |
450 | By default, most operations run in parallel using threads.
451 |
452 | To change the number of workers, use `NdArray::SetNumWorkers()`.
453 |
454 |
455 | // Default setting. Use all of cores.
456 | NdArray::SetNumWorkers(-1);
457 | // Set no parallel.
458 | NdArray::SetNumWorkers(1);
459 | // Use 4 cores
460 | NdArray::SetNumWorkers(4);
461 |
462 |
463 | ## Memory profiling
464 |
465 | When the macro `TINYNDARRAY_PROFILE_MEMORY` is defined, the memory profiler is activated.
466 |
467 | The following methods can be used to get the number of instances and the total size of allocated memory.
468 |
469 | ```NdArray::GetNumInstance()```
470 | ```NdArray::GetTotalMemory()```
471 |
472 | ## Macros
473 |
474 | * TINYNDARRAY_H_ONCE
475 | * TINYNDARRAY_NO_INCLUDE
476 | * TINYNDARRAY_NO_NAMESPACE
477 | * TINYNDARRAY_NO_DECLARATION
478 | * TINYNDARRAY_IMPLEMENTATION
479 | * TINYNDARRAY_PROFILE_MEMORY
480 |
481 | ## TODO
482 |
483 | * [x] Replace axis reduction function with more effective algorithm.
484 | * [ ] Implement more effective algorithm.
485 | * [x] Replace slice method's recursive call with loop for speed up.
486 | * [x] Make parallel
487 | * [ ] Improve inverse function with LU decomposition.
488 | * [ ] Implement a reference slice which does not affect the current performance.
489 | * [ ] Introduce SIMD instructions.
490 |
491 | Everything in the list above is a difficult challenge. If you have any ideas, please let me know.
492 |
--------------------------------------------------------------------------------
/tests/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.8.2)
2 |
3 | message(STATUS "Build tinyndarray tests")
4 |
5 | # ------------------------------------------------------------------------------
6 | # ----------------------------------- Common -----------------------------------
7 | # ------------------------------------------------------------------------------
8 | project(tinyndarray_tests CXX C)
9 |
10 | set(CMAKE_CXX_STANDARD 14) # C++14, matching the README's stated requirement
11 |
12 | if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # NOTE(review): LINK_TYPE is not used by any target in this file
13 | set(LINK_TYPE STATIC)
14 | else()
15 | set(LINK_TYPE SHARED)
16 | set(CMAKE_POSITION_INDEPENDENT_CODE ON)
17 | endif()
18 |
19 | # Uncomment to print the full make command lines (for debugging the build)
20 | # set(CMAKE_VERBOSE_MAKEFILE 1)
21 |
22 | # Default to a Release build when CMAKE_BUILD_TYPE was not given
23 | if (NOT CMAKE_BUILD_TYPE)
24 | set(CMAKE_BUILD_TYPE Release)
25 | endif()
26 | message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
27 |
28 | # Output `compile_commands.json` (consumed by clang-based tools)
29 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
30 |
31 | # Make the bundled find-modules (cmake/ and cmake/sanitizers/) visible to find_package()
32 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
33 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/sanitizers)
34 | find_package(Sanitizers) # presumably provides add_sanitizers() used by setup_target(); enable e.g. with -DSANITIZE_ADDRESS=ON
35 |
36 | # Libraries go to <build>/lib, executables to <build>/bin
37 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
38 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
39 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
40 |
41 | # Compiler-specific warning flags (one space-separated string) applied by setup_target(); left undefined for unsupported compilers
42 | if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
43 | set(warning_options "-Wall -Wextra -Wconversion")
44 | elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") # old-style casts are deliberately not warned about here (-Wno-old-style-cast)
45 | set(warning_options "-Wall -Wextra -Wcast-align -Wcast-qual \
46 | -Wctor-dtor-privacy -Wdisabled-optimization \
47 | -Wformat=2 -Winit-self \
48 | -Wmissing-declarations -Wmissing-include-dirs \
49 | -Woverloaded-virtual \
50 | -Wredundant-decls -Wshadow -Wsign-conversion \
51 | -Wsign-promo -Wno-old-style-cast \
52 | -Wstrict-overflow=5 -Wundef -Wno-unknown-pragmas \
53 | -Wreturn-std-move")
54 | elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
55 | set(warning_options "--pedantic -Wall -Wextra -Wcast-align -Wcast-qual \
56 | -Wctor-dtor-privacy -Wdisabled-optimization \
57 | -Wformat=2 -Winit-self -Wlogical-op \
58 | -Wmissing-declarations -Wmissing-include-dirs \
59 | -Wnoexcept -Wold-style-cast -Woverloaded-virtual \
60 | -Wredundant-decls -Wshadow -Wsign-conversion \
61 | -Wsign-promo -Wstrict-null-sentinel \
62 | -Wstrict-overflow=5 -Wundef -Wno-unknown-pragmas")
63 | elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
64 | set(warning_options "/W4")
65 | else()
66 | message(WARNING "Unsupported compiler for warning options")
67 | message("CMAKE_CXX_COMPILER_ID is ${CMAKE_CXX_COMPILER_ID}")
68 | endif()
69 |
70 | # setup_target(<target> <includes> <libs>): add include dirs, link libs, apply warning_options as COMPILE_FLAGS, and attach sanitizers
71 | function(setup_target target includes libs)
72 | target_include_directories(${target} PUBLIC ${includes})
73 | target_link_libraries(${target} ${libs})
74 | set_target_properties(${target} PROPERTIES COMPILE_FLAGS "${warning_options}") # quoted: warning_options is empty/undefined for unsupported compilers
75 | add_sanitizers(${target})
76 | endfunction(setup_target)
77 |
78 | # setup_target_simple(<target> <includes> <libs>): like setup_target, but without warning flags or sanitizers
79 | function(setup_target_simple tgt include_dirs link_libs)
80 | target_link_libraries(${tgt} ${link_libs})
81 | target_include_directories(${tgt} PUBLIC ${include_dirs})
82 | endfunction()
83 |
84 |
85 | # ------------------------------------------------------------------------------
86 | # ------------------------------ Test Executables ------------------------------
87 | # ------------------------------------------------------------------------------
88 |
89 | find_package(Threads) # sets CMAKE_THREAD_LIBS_INIT, passed as the link libs below
90 | set(tinyndarray_header ${CMAKE_CURRENT_SOURCE_DIR}/../tinyndarray.h) # header-only library under test, listed in every target's sources
91 |
92 | # Test executable built from a single test source file (test_total.cpp) plus test_main.cpp
93 | add_executable(tinyndarray_tests
94 | ${tinyndarray_header}
95 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp
96 | ${CMAKE_CURRENT_SOURCE_DIR}/test_total.cpp
97 | )
98 | setup_target(tinyndarray_tests "" "${CMAKE_THREAD_LIBS_INIT}")
99 |
100 | # Test executable whose tests are split across two source files (test_separated_1/2.cpp)
101 | add_executable(tinyndarray_tests_separated
102 | ${tinyndarray_header}
103 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp
104 | ${CMAKE_CURRENT_SOURCE_DIR}/test_separated_1.cpp
105 | ${CMAKE_CURRENT_SOURCE_DIR}/test_separated_2.cpp
106 | )
107 | setup_target(tinyndarray_tests_separated "" "${CMAKE_THREAD_LIBS_INIT}")
108 |
109 | # Test executable for the time-consumption tests
110 | add_executable(tinyndarray_tests_time_consumption
111 | ${tinyndarray_header}
112 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp
113 | ${CMAKE_CURRENT_SOURCE_DIR}/test_time_consumption.cpp
114 | )
115 | setup_target(tinyndarray_tests_time_consumption "" "${CMAKE_THREAD_LIBS_INIT}")
116 |
117 | # Test executable for the profiling tests
118 | add_executable(tinyndarray_tests_profiling
119 | ${tinyndarray_header}
120 | ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp
121 | ${CMAKE_CURRENT_SOURCE_DIR}/test_profiling.cpp
122 | )
123 | setup_target(tinyndarray_tests_profiling "" "${CMAKE_THREAD_LIBS_INIT}")
124 |
125 | # Dot-major test executable; no test_main.cpp here, so test_dot_major.cpp presumably defines its own main() - confirm
126 | add_executable(tinyndarray_tests_dot_major
127 | ${tinyndarray_header}
128 | ${CMAKE_CURRENT_SOURCE_DIR}/test_dot_major.cpp
129 | )
130 | setup_target(tinyndarray_tests_dot_major "" "${CMAKE_THREAD_LIBS_INIT}")
131 |
--------------------------------------------------------------------------------
/tests/cmake/sanitizers/FindASan.cmake:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c)
4 | # 2013 Matthew Arsenault
5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | option(SANITIZE_ADDRESS "Enable AddressSanitizer for sanitized targets." Off)
26 |
27 | set(FLAG_CANDIDATES
28 | # Candidate ASan flag sets, passed to sanitizer_check_compiler_flags below. Clang 3.2+ uses this form; -fno-omit-frame-pointer is optional.
29 | "-g -fsanitize=address -fno-omit-frame-pointer"
30 | "-g -fsanitize=address"
31 |
32 | # Older, deprecated flag for ASan
33 | "-g -faddress-sanitizer"
34 | )
35 |
36 | # ASan cannot be combined with TSan or MSan: abort configuration early if both are requested
37 | if (SANITIZE_ADDRESS AND (SANITIZE_THREAD OR SANITIZE_MEMORY))
38 | message(FATAL_ERROR "AddressSanitizer is not compatible with "
39 | "ThreadSanitizer or MemorySanitizer.")
40 | endif ()
41 |
42 | # Shared helpers used in this file (sanitizer_check_compiler_flags, saitizer_add_flags) - presumably defined there
43 | include(sanitize-helpers)
44 |
45 | if (SANITIZE_ADDRESS)
46 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "AddressSanitizer"
47 | "ASan")
48 | # Look for the asan-wrapper script on CMAKE_MODULE_PATH
49 | find_program(ASan_WRAPPER "asan-wrapper" PATHS ${CMAKE_MODULE_PATH})
50 | mark_as_advanced(ASan_WRAPPER)
51 | endif ()
52 |
# Append the cached AddressSanitizer flags to TARGET. This is a no-op unless
# SANITIZE_ADDRESS is enabled.
function (add_sanitize_address TARGET)
    if (SANITIZE_ADDRESS)
        saitizer_add_flags(${TARGET} "AddressSanitizer" "ASan")
    endif ()
endfunction ()
60 |
--------------------------------------------------------------------------------
/tests/cmake/sanitizers/FindMSan.cmake:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c)
4 | # 2013 Matthew Arsenault
5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
option(SANITIZE_MEMORY "Enable MemorySanitizer for sanitized targets." Off)

set(FLAG_CANDIDATES
    "-g -fsanitize=memory"
)


include(sanitize-helpers)

# MemorySanitizer is only probed on 64-bit Linux. On any other platform the
# option is forced back to Off in the cache and a warning is printed.
# This runs at module scope where no target is known, so the warnings do not
# name one. Variables are tested by name (not by unquoted ${...} expansion)
# so an empty CMAKE_SIZEOF_VOID_P cannot turn the condition into a syntax
# error.
if (SANITIZE_MEMORY)
    if (NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
        message(WARNING "MemorySanitizer disabled because "
            "MemorySanitizer is supported for Linux systems only.")
        set(SANITIZE_MEMORY Off CACHE BOOL
            "Enable MemorySanitizer for sanitized targets." FORCE)
    elseif (NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
        message(WARNING "MemorySanitizer disabled because "
            "MemorySanitizer is supported for 64bit systems only.")
        set(SANITIZE_MEMORY Off CACHE BOOL
            "Enable MemorySanitizer for sanitized targets." FORCE)
    else ()
        sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "MemorySanitizer"
            "MSan")
    endif ()
endif ()
50 |
# Append the cached MemorySanitizer flags to TARGET. This is a no-op unless
# SANITIZE_MEMORY is enabled.
function (add_sanitize_memory TARGET)
    if (SANITIZE_MEMORY)
        saitizer_add_flags(${TARGET} "MemorySanitizer" "MSan")
    endif ()
endfunction ()
58 |
--------------------------------------------------------------------------------
/tests/cmake/sanitizers/FindSanitizers.cmake:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c)
4 | # 2013 Matthew Arsenault
5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
# For GNU compilers, optionally try to link the sanitizer runtimes statically
# (see the -static-lib<name> probe in sanitizer_check_compiler_flags).
option(SANITIZE_LINK_STATIC "Try to link static against sanitizers." Off)

# Forward a QUIET request for this package to every sanitizer sub-module.
if (DEFINED Sanitizers_FIND_QUIETLY)
    set(FIND_QUIETLY_FLAG "QUIET")
else ()
    set(FIND_QUIETLY_FLAG "")
endif ()

# Load the individual sanitizer modules (order preserved: ASan, TSan, MSan,
# UBSan).
foreach (SANITIZER_MODULE ASan TSan MSan UBSan)
    find_package(${SANITIZER_MODULE} ${FIND_QUIETLY_FLAG})
endforeach ()
41 |
42 |
43 |
44 |
# Register a sanitizer blacklist file. FILE may be relative to the current
# source directory; it is resolved to a real absolute path, and the resulting
# "-fsanitize-blacklist=<path>" flag is probed via
# sanitizer_check_compiler_flags with prefix SanBlist. That cached result is
# what saitizer_add_flags appends to targets. Note that compilers that already
# have a cached SanBlist_<compiler>_FLAGS entry are not probed again.
function(sanitizer_add_blacklist_file FILE)
    set(BLACKLIST_PATH ${FILE})
    if (NOT IS_ABSOLUTE ${BLACKLIST_PATH})
        set(BLACKLIST_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${BLACKLIST_PATH}")
    endif ()
    get_filename_component(BLACKLIST_PATH "${BLACKLIST_PATH}" REALPATH)

    sanitizer_check_compiler_flags("-fsanitize-blacklist=${BLACKLIST_PATH}"
        "SanitizerBlacklist" "SanBlist")
endfunction()
54 |
# Add all enabled sanitizers (Address, Thread, Memory, UndefinedBehavior) to
# each target given as argument. A target compiled by more than one compiler,
# or by no known compiler, is skipped with a warning; the remaining targets in
# the argument list are still processed.
function(add_sanitizers ...)
    # If no sanitizer is enabled, return immediately.
    if (NOT (SANITIZE_ADDRESS OR SANITIZE_MEMORY OR SANITIZE_THREAD OR
        SANITIZE_UNDEFINED))
        return()
    endif ()

    foreach (TARGET ${ARGV})
        # Check if this target will be compiled by exactly one compiler.
        # Otherwise sanitizers can't be used for it and a warning is printed.
        # Only this target is skipped: a return() here would silently drop
        # every target that follows it in the argument list.
        sanitizer_target_compilers(${TARGET} TARGET_COMPILER)
        list(LENGTH TARGET_COMPILER NUM_COMPILERS)
        if (NUM_COMPILERS GREATER 1)
            message(WARNING "Can't use any sanitizers for target ${TARGET}, "
                    "because it will be compiled by incompatible compilers. "
                    "Target will be compiled without sanitizers.")

        # If the target is compiled by no known compiler, ignore it.
        elseif (NUM_COMPILERS EQUAL 0)
            message(WARNING "Can't use any sanitizers for target ${TARGET}, "
                    "because it uses an unknown compiler. Target will be "
                    "compiled without sanitizers.")

        else ()
            # Add sanitizers for target.
            add_sanitize_address(${TARGET})
            add_sanitize_thread(${TARGET})
            add_sanitize_memory(${TARGET})
            add_sanitize_undefined(${TARGET})
        endif ()
    endforeach ()
endfunction(add_sanitizers)
88 |
--------------------------------------------------------------------------------
/tests/cmake/sanitizers/FindTSan.cmake:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c)
4 | # 2013 Matthew Arsenault
5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
option(SANITIZE_THREAD "Enable ThreadSanitizer for sanitized targets." Off)

set(FLAG_CANDIDATES
    "-g -fsanitize=thread"
)


# ThreadSanitizer is not compatible with MemorySanitizer.
if (SANITIZE_THREAD AND SANITIZE_MEMORY)
    message(FATAL_ERROR "ThreadSanitizer is not compatible with "
        "MemorySanitizer.")
endif ()


include(sanitize-helpers)

# ThreadSanitizer is only probed on 64-bit Linux. On any other platform the
# option is forced back to Off in the cache and a warning is printed.
# This runs at module scope where no target is known, so the warnings do not
# name one. Variables are tested by name (not by unquoted ${...} expansion)
# so an empty CMAKE_SIZEOF_VOID_P cannot turn the condition into a syntax
# error.
if (SANITIZE_THREAD)
    if (NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
        message(WARNING "ThreadSanitizer disabled because "
            "ThreadSanitizer is supported for Linux systems only.")
        set(SANITIZE_THREAD Off CACHE BOOL
            "Enable ThreadSanitizer for sanitized targets." FORCE)
    elseif (NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
        message(WARNING "ThreadSanitizer disabled because "
            "ThreadSanitizer is supported for 64bit systems only.")
        set(SANITIZE_THREAD Off CACHE BOOL
            "Enable ThreadSanitizer for sanitized targets." FORCE)
    else ()
        sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "ThreadSanitizer"
            "TSan")
    endif ()
endif ()
57 |
# Append the cached ThreadSanitizer flags to TARGET. This is a no-op unless
# SANITIZE_THREAD is enabled.
function (add_sanitize_thread TARGET)
    if (SANITIZE_THREAD)
        saitizer_add_flags(${TARGET} "ThreadSanitizer" "TSan")
    endif ()
endfunction ()
65 |
--------------------------------------------------------------------------------
/tests/cmake/sanitizers/FindUBSan.cmake:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c)
4 | # 2013 Matthew Arsenault
5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
option(SANITIZE_UNDEFINED
    "Enable UndefinedBehaviorSanitizer for sanitized targets." Off)

# The single flag set to probe for UndefinedBehaviorSanitizer.
set(FLAG_CANDIDATES "-g -fsanitize=undefined")

include(sanitize-helpers)

# Probe (and cache) the compiler flags only when UBSan has been requested.
if (SANITIZE_UNDEFINED)
    sanitizer_check_compiler_flags("${FLAG_CANDIDATES}"
        "UndefinedBehaviorSanitizer" "UBSan")
endif ()
39 |
# Append the cached UndefinedBehaviorSanitizer flags to TARGET. This is a
# no-op unless SANITIZE_UNDEFINED is enabled.
function (add_sanitize_undefined TARGET)
    if (SANITIZE_UNDEFINED)
        saitizer_add_flags(${TARGET} "UndefinedBehaviorSanitizer" "UBSan")
    endif ()
endfunction ()
47 |
--------------------------------------------------------------------------------
/tests/cmake/sanitizers/asan-wrapper:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # The MIT License (MIT)
4 | #
5 | # Copyright (c)
6 | # 2013 Matthew Arsenault
7 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | # SOFTWARE.
26 |
27 | # This script is a wrapper for AddressSanitizer. In some special cases you need
28 | # to preload AddressSanitizer to avoid error messages - e.g. if you're
29 | # preloading another library to your application. At the moment this script will
30 | # only do something, if we're running on a Linux platform. OSX might not be
31 | # affected.
32 |
33 |
# On platforms other than Linux there is nothing to preload: run the
# application unchanged. "$@" is quoted so arguments containing spaces or
# glob characters are passed through intact.
if [ "$(uname)" != "Linux" ]
then
    exec "$@"
fi


# Get the libasan the application ($1) is linked against, as printed in the
# first column of ldd's output. If one is found, prepend it to LD_PRELOAD.
# All leading whitespace is stripped, and only the first match is used, so
# LD_PRELOAD never receives an embedded newline.
libasan=$(ldd "$1" | grep libasan | sed "s/^[[:space:]]*//" | cut -d' ' -f1 | head -n 1)
if [ -n "$libasan" ]
then
    if [ -n "$LD_PRELOAD" ]
    then
        export LD_PRELOAD="$libasan:$LD_PRELOAD"
    else
        export LD_PRELOAD="$libasan"
    fi
fi

# Execute the application with its original, unsplit arguments.
exec "$@"
56 |
--------------------------------------------------------------------------------
/tests/cmake/sanitizers/sanitize-helpers.cmake:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c)
4 | # 2013 Matthew Arsenault
5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
# Helper function to get the language of a source file.
#
# Stores in RETURN_VAR (caller's scope) the first enabled language whose
# CMAKE_<LANG>_SOURCE_FILE_EXTENSIONS contains the lower-cased LAST extension
# of FILE, or an empty string if no language matches. Files without an
# extension (e.g. extension-less headers) yield an empty string instead of
# aborting the configure step.
function (sanitizer_lang_of_source FILE RETURN_VAR)
    # get_filename_component(... EXT) returns everything from the FIRST dot
    # (".test.cpp" for "foo.test.cpp"), so extract the last extension here.
    get_filename_component(FILE_NAME "${FILE}" NAME)
    string(REGEX MATCH "\\.[^.]+$" FILE_EXT "${FILE_NAME}")
    if (FILE_EXT STREQUAL "")
        # string(SUBSTRING "" 1 -1 ...) below would be a fatal error.
        set(${RETURN_VAR} "" PARENT_SCOPE)
        return()
    endif ()
    string(TOLOWER "${FILE_EXT}" FILE_EXT)
    # Drop the leading dot.
    string(SUBSTRING "${FILE_EXT}" 1 -1 FILE_EXT)

    get_property(ENABLED_LANGUAGES GLOBAL PROPERTY ENABLED_LANGUAGES)
    foreach (LANG ${ENABLED_LANGUAGES})
        list(FIND CMAKE_${LANG}_SOURCE_FILE_EXTENSIONS "${FILE_EXT}" TEMP)
        if (NOT ${TEMP} EQUAL -1)
            set(${RETURN_VAR} "${LANG}" PARENT_SCOPE)
            return()
        endif ()
    endforeach ()

    set(${RETURN_VAR} "" PARENT_SCOPE)
endfunction ()
42 |
43 |
# Helper function to get compilers used by a target.
#
# Stores in RETURN_VAR (caller's scope) the de-duplicated list of compiler IDs
# (CMAKE_<LANG>_COMPILER_ID) used by the sources of TARGET whose language
# sanitizer_lang_of_source can determine. Entries that reference an object
# library through a $<TARGET_OBJECTS:...> generator expression are skipped.
function (sanitizer_target_compilers TARGET RETURN_VAR)
    # Callers need the complete set: mixing compilers for one target (e.g.
    # clang for C and gfortran for Fortran) is a problem because each compiler
    # may ship its own sanitizer implementation.
    set(COMPILER_IDS "")
    get_target_property(SOURCE_LIST ${TARGET} SOURCES)
    foreach (SOURCE_FILE ${SOURCE_LIST})
        if (NOT SOURCE_FILE MATCHES "TARGET_OBJECTS:([^ >]+)")
            sanitizer_lang_of_source(${SOURCE_FILE} SOURCE_LANG)
            if (SOURCE_LANG)
                list(APPEND COMPILER_IDS ${CMAKE_${SOURCE_LANG}_COMPILER_ID})
            endif ()
        endif ()
    endforeach ()

    list(REMOVE_DUPLICATES COMPILER_IDS)
    set(${RETURN_VAR} "${COMPILER_IDS}" PARENT_SCOPE)
endfunction ()
67 |
68 |
# Helper function to check compiler flags for language compiler.
#
# Runs check_c_compiler_flag / check_cxx_compiler_flag /
# check_fortran_compiler_flag for FLAG, depending on LANG, storing the result
# in the variable named by VARIABLE. Other languages are ignored. If the
# Fortran check module is unavailable (older CMake), the check is only
# reported as failed and VARIABLE is left unset.
function (sanitizer_check_compiler_flag FLAG LANG VARIABLE)
    if (${LANG} STREQUAL "C")
        include(CheckCCompilerFlag)
        check_c_compiler_flag("${FLAG}" ${VARIABLE})
        return()
    endif ()

    if (${LANG} STREQUAL "CXX")
        include(CheckCXXCompilerFlag)
        check_cxx_compiler_flag("${FLAG}" ${VARIABLE})
        return()
    endif ()

    if (NOT ${LANG} STREQUAL "Fortran")
        return()
    endif ()

    # CheckFortranCompilerFlag only exists in CMake 3.x, so include it
    # optionally and treat its absence as "flag not available".
    include(CheckFortranCompilerFlag OPTIONAL RESULT_VARIABLE INCLUDED)
    if (INCLUDED)
        check_fortran_compiler_flag("${FLAG}" ${VARIABLE})
    elseif (NOT CMAKE_REQUIRED_QUIET)
        message(STATUS "Performing Test ${VARIABLE}")
        message(STATUS "Performing Test ${VARIABLE}"
            " - Failed (Check not supported)")
    endif ()
endfunction ()
94 |
95 |
# Helper function to test compiler flags.
#
# For the compiler of every enabled language, tries each entry of
# FLAG_CANDIDATES in order and caches the first one the compiler accepts as
# the advanced cache STRING <PREFIX>_<COMPILER_ID>_FLAGS.
# - With SANITIZE_LINK_STATIC and a GNU compiler, "-static-lib<prefix>" is
#   prepended when that flag is accepted too.
# - A compiler accepting none of the candidates gets an empty cache entry and a
#   warning.
# - Compilers that already have an entry are not probed again.
# NAME is the human-readable sanitizer name used in messages.
function (sanitizer_check_compiler_flags FLAG_CANDIDATES NAME PREFIX)
    # Follow the QUIET setting of the calling Find module.
    set(CMAKE_REQUIRED_QUIET ${${PREFIX}_FIND_QUIETLY})
    set(DETECTED_VAR ${PREFIX}_FLAG_DETECTED)
    string(TOLOWER ${PREFIX} PREFIX_lower)

    get_property(ENABLED_LANGUAGES GLOBAL PROPERTY ENABLED_LANGUAGES)
    foreach (LANG ${ENABLED_LANGUAGES})
        # Sanitizer flags depend on the compiler rather than on the language,
        # so results are keyed (and cached) by compiler ID.
        set(COMPILER ${CMAKE_${LANG}_COMPILER_ID})
        set(RESULT_CACHE_VAR ${PREFIX}_${COMPILER}_FLAGS)

        if (NOT DEFINED ${RESULT_CACHE_VAR})
            foreach (FLAG ${FLAG_CANDIDATES})
                if (NOT CMAKE_REQUIRED_QUIET)
                    message(STATUS "Try ${COMPILER} ${NAME} flag = [${FLAG}]")
                endif ()

                # Drop the cached result of the previous candidate so the
                # check really runs for this one.
                set(CMAKE_REQUIRED_FLAGS "${FLAG}")
                unset(${DETECTED_VAR} CACHE)
                sanitizer_check_compiler_flag("${FLAG}" ${LANG}
                    ${DETECTED_VAR})

                if (${DETECTED_VAR})
                    # GNU: optionally prefer a statically linked runtime.
                    if (SANITIZE_LINK_STATIC AND (${COMPILER} STREQUAL "GNU"))
                        sanitizer_check_compiler_flag(
                            "-static-lib${PREFIX_lower}" ${LANG}
                            ${PREFIX}_STATIC_FLAG_DETECTED)

                        if (${PREFIX}_STATIC_FLAG_DETECTED)
                            set(FLAG "-static-lib${PREFIX_lower} ${FLAG}")
                        endif ()
                    endif ()

                    set(${RESULT_CACHE_VAR} "${FLAG}" CACHE STRING
                        "${NAME} flags for ${COMPILER} compiler.")
                    mark_as_advanced(${RESULT_CACHE_VAR})
                    break()
                endif ()
            endforeach ()

            if (NOT ${DETECTED_VAR})
                set(${RESULT_CACHE_VAR} "" CACHE STRING
                    "${NAME} flags for ${COMPILER} compiler.")
                mark_as_advanced(${RESULT_CACHE_VAR})

                message(WARNING "${NAME} is not available for ${COMPILER} "
                    "compiler. Targets using this compiler will be "
                    "compiled without ${NAME}.")
            endif ()
        endif ()
    endforeach ()
endfunction ()
150 |
151 |
# Helper to assign sanitizer flags for TARGET.
#
# Appends the cached <PREFIX>_<COMPILER_ID>_FLAGS to the COMPILE_FLAGS and
# LINK_FLAGS of TARGET. It also appends any cached SanBlist_<COMPILER_ID>_FLAGS
# (set by sanitizer_add_blacklist_file) to COMPILE_FLAGS. Nothing is done when
# the target is not built by exactly one known compiler, or when no flags were
# found for that compiler. Warnings about conflicting compilers are emitted by
# add_sanitizers, not here. NAME is accepted for symmetry with the callers.
function (sanitizer_add_flags TARGET NAME PREFIX)
    sanitizer_target_compilers(${TARGET} TARGET_COMPILER)
    list(LENGTH TARGET_COMPILER NUM_COMPILERS)
    if (NOT NUM_COMPILERS EQUAL 1)
        return()
    endif ()

    set(SANITIZER_FLAGS "${${PREFIX}_${TARGET_COMPILER}_FLAGS}")
    if (SANITIZER_FLAGS STREQUAL "")
        return()
    endif ()

    # Set compile- and link-flags for target.
    set_property(TARGET ${TARGET} APPEND_STRING
        PROPERTY COMPILE_FLAGS " ${SANITIZER_FLAGS}")
    set(BLACKLIST_FLAGS "${SanBlist_${TARGET_COMPILER}_FLAGS}")
    if (NOT BLACKLIST_FLAGS STREQUAL "")
        set_property(TARGET ${TARGET} APPEND_STRING
            PROPERTY COMPILE_FLAGS " ${BLACKLIST_FLAGS}")
    endif ()
    set_property(TARGET ${TARGET} APPEND_STRING
        PROPERTY LINK_FLAGS " ${SANITIZER_FLAGS}")
endfunction ()

# Backward-compatible wrapper keeping the historical, misspelled helper name
# that the Find*San modules call.
function (saitizer_add_flags TARGET NAME PREFIX)
    sanitizer_add_flags(${TARGET} "${NAME}" "${PREFIX}")
endfunction ()
171 |
--------------------------------------------------------------------------------
/tests/test_common.h:
--------------------------------------------------------------------------------
1 | #include "Catch2/single_include/catch2/catch.hpp"
2 |
3 | #include "../tinyndarray.h"
4 |
5 | #include
6 |
7 | using namespace tinyndarray;
8 |
9 | static void CheckNdArray(const NdArray& m, const std::string& str,
10 | int precision = -1) {
11 | std::stringstream ss;
12 | if (0 < precision) {
13 | ss << std::setprecision(4);
14 | }
15 | ss << m;
16 | CHECK(ss.str() == str);
17 | }
18 |
19 | static void CheckNdArrayInplace(NdArray&& x, const std::string& str,
20 | std::function f,
21 | int precision = -1) {
22 | uintptr_t x_id = x.id();
23 | const NdArray& y = f(std::move(x));
24 | CHECK(y.id() == x_id);
25 | CheckNdArray(y, str, precision);
26 | }
27 |
28 | static void CheckNdArrayInplace(NdArray&& lhs, NdArray&& rhs,
29 | const std::string& str,
30 | std::function f,
31 | int precision = -1) {
32 | uintptr_t l_id = lhs.id();
33 | uintptr_t r_id = rhs.id();
34 | const NdArray& ret = f(std::move(lhs), std::move(rhs));
35 | CHECK((ret.id() == l_id || ret.id() == r_id));
36 | CheckNdArray(ret, str, precision);
37 | }
38 |
39 | static void CheckNdArrayInplace(
40 | const NdArray& lhs, NdArray&& rhs, const std::string& str,
41 | std::function f,
42 | int precision = -1) {
43 | uintptr_t l_id = lhs.id();
44 | uintptr_t r_id = rhs.id();
45 | const NdArray& ret = f(lhs, std::move(rhs));
46 | CHECK((ret.id() == l_id || ret.id() == r_id));
47 | CheckNdArray(ret, str, precision);
48 | }
49 |
50 | static void CheckNdArrayInplace(
51 | NdArray&& lhs, const NdArray& rhs, const std::string& str,
52 | std::function f,
53 | int precision = -1) {
54 | uintptr_t l_id = lhs.id();
55 | uintptr_t r_id = rhs.id();
56 | const NdArray& ret = f(std::move(lhs), rhs);
57 | CHECK((ret.id() == l_id || ret.id() == r_id));
58 | CheckNdArray(ret, str, precision);
59 | }
60 |
61 | static void CheckNdArrayInplace(NdArray&& lhs, float rhs,
62 | const std::string& str,
63 | std::function f,
64 | int precision = -1) {
65 | uintptr_t l_id = lhs.id();
66 | const NdArray& ret = f(std::move(lhs), rhs);
67 | CHECK((ret.id() == l_id));
68 | CheckNdArray(ret, str, precision);
69 | }
70 |
71 | static void CheckNdArrayInplace(float lhs, NdArray&& rhs,
72 | const std::string& str,
73 | std::function f,
74 | int precision = -1) {
75 | uintptr_t r_id = rhs.id();
76 | const NdArray& ret = f(lhs, std::move(rhs));
77 | CHECK((ret.id() == r_id));
78 | CheckNdArray(ret, str, precision);
79 | }
80 |
81 | static void CheckNdArrayNotInplace(NdArray&& x, const std::string& str,
82 | std::function f,
83 | int precision = -1) {
84 | uintptr_t x_id = x.id();
85 | const NdArray& y = f(x);
86 | CHECK(y.id() != x_id);
87 | CheckNdArray(y, str, precision);
88 | }
89 |
90 | static bool IsSameNdArray(const NdArray& m1, const NdArray& m2) {
91 | if (m1.shape() != m2.shape()) {
92 | return false;
93 | }
94 | auto&& data1 = m1.data();
95 | auto&& data2 = m2.data();
96 | for (int i = 0; i < static_cast(m1.size()); i++) {
97 | if (data1[i] != data2[i]) {
98 | return false;
99 | }
100 | }
101 | return true;
102 | }
103 |
104 | static void ResolveAmbiguous(NdArray& x) {
105 | for (auto&& v : x) {
106 | if (std::isnan(v)) {
107 | v = std::abs(v);
108 | }
109 | if (v == -0.f) {
110 | v = 0.f;
111 | }
112 | }
113 | }
114 |
115 | TEST_CASE("NdArray") {
116 | // -------------------------- Basic construction ---------------------------
117 | SECTION("Empty") {
118 | const NdArray m1;
119 | CHECK(m1.empty());
120 | CHECK(m1.size() == 0);
121 | CHECK(m1.shape() == Shape{0});
122 | CHECK(m1.ndim() == 1);
123 | }
124 |
125 | // --------------------------- Float initializer ---------------------------
126 | SECTION("Float initializer") {
127 | const NdArray m1 = {1.f, 2.f, 3.f};
128 | const NdArray m2 = {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}};
129 | const NdArray m3 = {{{1.f, 2.f}}, {{3.f, 4.f}}, {{2.f, 3.f}}};
130 | CHECK(m1.shape() == Shape{3});
131 | CHECK(m2.shape() == Shape{2, 3});
132 | CHECK(m3.shape() == Shape{3, 1, 2});
133 | CheckNdArray(m1, "[1, 2, 3]");
134 | CheckNdArray(m2,
135 | "[[1, 2, 3],\n"
136 | " [4, 5, 6]]");
137 | CheckNdArray(m3,
138 | "[[[1, 2]],\n"
139 | " [[3, 4]],\n"
140 | " [[2, 3]]]");
141 | }
142 |
143 | SECTION("Float initializer invalid") {
144 | CHECK_NOTHROW(NdArray{{{1.f, 2.f}}, {{3.f, 4.f}}, {{1.f, 2.f}}});
145 | CHECK_THROWS(NdArray{{{1, 2}}, {}});
146 | CHECK_THROWS(NdArray{{{1.f, 2.f}}, {{3.f, 4.f}}, {{1.f, 2.f, 3.f}}});
147 | }
148 |
149 | SECTION("Confusable initializers") {
150 | const NdArray m1 = {1.f, 2.f, 3.f}; // Float initializer
151 | const NdArray m2 = {1, 2, 3}; // Shape (int) initalizer
152 | const NdArray m3 = {{1, 2, 3}}; // Float initializer due to nest
153 | CHECK(m1.shape() == Shape{3});
154 | CHECK(m2.shape() == Shape{1, 2, 3});
155 | CHECK(m3.shape() == Shape{1, 3});
156 | }
157 |
158 | // --------------------------- Static initializer --------------------------
159 | SECTION("Empty/Ones/Zeros") {
160 | const NdArray m1({2, 5}); // Same as Empty
161 | const auto m2 = NdArray::Empty({2, 5});
162 | const auto m3 = NdArray::Zeros({2, 5});
163 | const auto m4 = NdArray::Ones({2, 5});
164 | CHECK(m1.shape() == Shape{2, 5});
165 | CHECK(m2.shape() == Shape{2, 5});
166 | CHECK(m3.shape() == Shape{2, 5});
167 | CHECK(m4.shape() == Shape{2, 5});
168 | CheckNdArray(m3,
169 | "[[0, 0, 0, 0, 0],\n"
170 | " [0, 0, 0, 0, 0]]");
171 | CheckNdArray(m4,
172 | "[[1, 1, 1, 1, 1],\n"
173 | " [1, 1, 1, 1, 1]]");
174 | }
175 |
176 | SECTION("Empty/Ones/Zeros by template") {
177 | const NdArray m1({2, 5}); // No template support
178 | const auto m2 = NdArray::Empty(2, 5);
179 | const auto m3 = NdArray::Zeros(2, 5);
180 | const auto m4 = NdArray::Ones(2, 5);
181 | CHECK(m1.shape() == Shape{2, 5});
182 | CHECK(m2.shape() == Shape{2, 5});
183 | CHECK(m3.shape() == Shape{2, 5});
184 | CHECK(m4.shape() == Shape{2, 5});
185 | CheckNdArray(m3,
186 | "[[0, 0, 0, 0, 0],\n"
187 | " [0, 0, 0, 0, 0]]");
188 | CheckNdArray(m4,
189 | "[[1, 1, 1, 1, 1],\n"
190 | " [1, 1, 1, 1, 1]]");
191 | }
192 |
193 | SECTION("Arange") {
194 | const auto m1 = NdArray::Arange(10.f);
195 | const auto m2 = NdArray::Arange(0.f, 10.f, 1.f);
196 | const auto m3 = NdArray::Arange(5.f, 5.5f, 0.1f);
197 | CHECK(m1.shape() == Shape{10});
198 | CHECK(m2.shape() == Shape{10});
199 | CHECK(m3.shape() == Shape{5});
200 | CheckNdArray(m1, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]");
201 | CheckNdArray(m2, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]");
202 | CheckNdArray(m3, "[5, 5.1, 5.2, 5.3, 5.4]");
203 | }
204 |
205 | // --------------------------------- Random --------------------------------
206 | SECTION("Random uniform") {
207 | NdArray::Seed(0);
208 | const auto m1 = NdArray::Uniform({2, 3});
209 | NdArray::Seed(0);
210 | const auto m2 = NdArray::Uniform({2, 3});
211 | NdArray::Seed(1);
212 | const auto m3 = NdArray::Uniform({2, 3});
213 | CHECK(IsSameNdArray(m1, m2));
214 | CHECK(!IsSameNdArray(m1, m3));
215 | }
216 |
217 | SECTION("Random normal") {
218 | NdArray::Seed(0);
219 | const auto m1 = NdArray::Normal({2, 3});
220 | NdArray::Seed(0);
221 | const auto m2 = NdArray::Normal({2, 3});
222 | NdArray::Seed(1);
223 | const auto m3 = NdArray::Normal({2, 3});
224 | CHECK(IsSameNdArray(m1, m2));
225 | CHECK(!IsSameNdArray(m1, m3));
226 | }
227 |
228 | // ------------------------------ Basic method -----------------------------
229 | SECTION("Basic method") {
230 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
231 | auto m2 = m1.copy();
232 | CheckNdArray(m2,
233 | "[[0, 1, 2],\n"
234 | " [3, 4, 5]]");
235 | m2.fill(-1);
236 | CheckNdArray(m1,
237 | "[[0, 1, 2],\n"
238 | " [3, 4, 5]]");
239 | CheckNdArray(m2,
240 | "[[-1, -1, -1],\n"
241 | " [-1, -1, -1]]");
242 | CHECK(m1.ndim() == 2);
243 | CHECK(m1.flatten().ndim() == 1);
244 | CHECK(m1.flatten().id() != m1.id()); // Copy
245 | CHECK(m1.ravel().ndim() == 1);
246 | CHECK(m1.ravel().id() == m1.id()); // Same instance
247 |
248 | m1.resize({2, 2});
249 | CheckNdArray(m1,
250 | "[[0, 1],\n"
251 | " [2, 3]]");
252 | m1.resize({2, 4});
253 | CheckNdArray(m1,
254 | "[[0, 1, 2, 3],\n"
255 | " [0, 0, 0, 0]]");
256 | }
257 |
258 | // ------------------------------- Begin/End -------------------------------
259 | SECTION("Begin/End") {
260 | auto m1 = NdArray::Arange(1.f, 10.01f);
261 | // C++11 for-loop
262 | float sum1 = 0.f;
263 | for (auto&& v : m1) {
264 | sum1 += v;
265 | }
266 | CHECK(sum1 == Approx(55.f));
267 | // std library
268 | float sum2 = std::accumulate(m1.begin(), m1.end(), 0.f);
269 | CHECK(sum2 == Approx(55.f));
270 | }
271 |
272 | // ------------------------------- Float cast ------------------------------
273 | SECTION("Float cast") {
274 | auto m1 = NdArray::Ones({1, 1});
275 | auto m2 = NdArray::Ones({1, 2});
276 | CHECK(static_cast(m1) == 1);
277 | CHECK_THROWS(static_cast(m2));
278 | }
279 |
280 | // ------------------------------ Index access -----------------------------
281 | SECTION("Index access by []") {
282 | auto m1 = NdArray::Arange(12.f);
283 | auto m2 = NdArray::Arange(12.f).reshape({3, 4});
284 | auto m3 = NdArray::Arange(12.f).reshape({2, 2, -1});
285 | m1[3] = -1.f;
286 | m1[-2] = -2.f;
287 | m2[{1, 1}] = -1.f;
288 | m2[{-1, 3}] = -2.f;
289 | m3[{1, 1, 2}] = -1.f;
290 | m3[{0, 1, -2}] = -2.f;
291 | CheckNdArray(m1, "[0, 1, 2, -1, 4, 5, 6, 7, 8, 9, -2, 11]");
292 | CheckNdArray(m2,
293 | "[[0, 1, 2, 3],\n"
294 | " [4, -1, 6, 7],\n"
295 | " [8, 9, 10, -2]]");
296 | CheckNdArray(m3,
297 | "[[[0, 1, 2],\n"
298 | " [3, -2, 5]],\n"
299 | " [[6, 7, 8],\n"
300 | " [9, 10, -1]]]");
301 | }
302 |
303 | SECTION("Index access by ()") {
304 | auto m1 = NdArray::Arange(12.f);
305 | auto m2 = NdArray::Arange(12.f).reshape({3, 4});
306 | auto m3 = NdArray::Arange(12.f).reshape({2, 2, -1});
307 | m1(3) = -1.f;
308 | m1(-2) = -2.f;
309 | m2(1, 1) = -1.f;
310 | m2(-1, 3) = -2.f;
311 | m3(1, 1, 2) = -1.f;
312 | m3(0, 1, -2) = -2.f;
313 | CheckNdArray(m1, "[0, 1, 2, -1, 4, 5, 6, 7, 8, 9, -2, 11]");
314 | CheckNdArray(m2,
315 | "[[0, 1, 2, 3],\n"
316 | " [4, -1, 6, 7],\n"
317 | " [8, 9, 10, -2]]");
318 | CheckNdArray(m3,
319 | "[[[0, 1, 2],\n"
320 | " [3, -2, 5]],\n"
321 | " [[6, 7, 8],\n"
322 | " [9, 10, -1]]]");
323 | }
324 |
325 | SECTION("Index access by [] (const)") {
326 | const auto m1 = NdArray::Arange(12.f).reshape({1, 4, 3});
327 | CHECK(m1[{0, 2, 1}] == 7);
328 | CHECK(m1[{0, -1, 1}] == 10);
329 | }
330 |
331 | SECTION("Index access by () (const)") {
332 | const auto m1 = NdArray::Arange(12.f).reshape({1, 4, 3});
333 | CHECK(m1(0, 2, 1) == 7);
334 | CHECK(m1(0, -1, 1) == 10);
335 | }
336 |
337 | // -------------------------------- Reshape --------------------------------
338 | SECTION("Reshape") {
339 | auto m1 = NdArray::Arange(12.f);
340 | auto m2 = m1.reshape({3, 4});
341 | auto m3 = m2.reshape({2, -1});
342 | auto m4 = m3.reshape({2, 2, -1});
343 | CheckNdArray(m1, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]");
344 | CheckNdArray(m2,
345 | "[[0, 1, 2, 3],\n"
346 | " [4, 5, 6, 7],\n"
347 | " [8, 9, 10, 11]]");
348 | CheckNdArray(m3,
349 | "[[0, 1, 2, 3, 4, 5],\n"
350 | " [6, 7, 8, 9, 10, 11]]");
351 | CheckNdArray(m4,
352 | "[[[0, 1, 2],\n"
353 | " [3, 4, 5]],\n"
354 | " [[6, 7, 8],\n"
355 | " [9, 10, 11]]]");
356 | }
357 |
358 | SECTION("Reshape by template") {
359 | auto m1 = NdArray::Arange(12.f);
360 | auto m2 = m1.reshape(3, 4);
361 | auto m3 = m2.reshape(2, -1);
362 | auto m4 = m3.reshape(2, 2, -1);
363 | CheckNdArray(m1, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]");
364 | CheckNdArray(m2,
365 | "[[0, 1, 2, 3],\n"
366 | " [4, 5, 6, 7],\n"
367 | " [8, 9, 10, 11]]");
368 | CheckNdArray(m3,
369 | "[[0, 1, 2, 3, 4, 5],\n"
370 | " [6, 7, 8, 9, 10, 11]]");
371 | CheckNdArray(m4,
372 | "[[[0, 1, 2],\n"
373 | " [3, 4, 5]],\n"
374 | " [[6, 7, 8],\n"
375 | " [9, 10, 11]]]");
376 | }
377 |
378 | SECTION("Reshape with value change") {
379 | auto m1 = NdArray::Arange(12.f);
380 | auto m2 = m1.reshape({3, 4});
381 | auto m3 = m2.reshape({2, -1});
382 | auto m4 = m3.reshape({2, 2, -1});
383 | m1.data()[0] = -1.f;
384 | CheckNdArray(m1, "[-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]");
385 | CheckNdArray(m2,
386 | "[[-1, 1, 2, 3],\n"
387 | " [4, 5, 6, 7],\n"
388 | " [8, 9, 10, 11]]");
389 | CheckNdArray(m3,
390 | "[[-1, 1, 2, 3, 4, 5],\n"
391 | " [6, 7, 8, 9, 10, 11]]");
392 | CheckNdArray(m4,
393 | "[[[-1, 1, 2],\n"
394 | " [3, 4, 5]],\n"
395 | " [[6, 7, 8],\n"
396 | " [9, 10, 11]]]");
397 | }
398 |
399 | SECTION("Reshape invalid") {
400 | auto m1 = NdArray::Arange(12.f);
401 | CHECK_THROWS(m1.reshape({5, 2}));
402 | CHECK_THROWS(m1.reshape({-1, -1}));
403 | }
404 |
405 | SECTION("Flatten") {
406 | auto m1 = NdArray::Arange(12.f).reshape(2, 2, 3);
407 | auto m2 = m1.flatten();
408 | CheckNdArray(m2, "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]");
409 | }
410 |
411 | // --------------------------------- Slice ---------------------------------
412 | SECTION("Slice 2-dim") {
413 | auto m1 = NdArray::Arange(16.f).reshape(4, 4);
414 | auto m2 = m1.slice({{1, 3}, {1, 3}});
415 | auto m3 = m1.slice({1, 3}, {0, 4});
416 | auto m4 = m1.slice({0, 4}, {1, 3});
417 | auto m5 = m1.slice({1, -1}, {0, 100000});
418 | CHECK(m1.shape() == Shape{4, 4});
419 | CHECK(m2.shape() == Shape{2, 2});
420 | CHECK(m3.shape() == Shape{2, 4});
421 | CHECK(m4.shape() == Shape{4, 2});
422 | CHECK(m5.shape() == Shape{2, 4});
423 | CheckNdArray(m1,
424 | "[[0, 1, 2, 3],\n"
425 | " [4, 5, 6, 7],\n"
426 | " [8, 9, 10, 11],\n"
427 | " [12, 13, 14, 15]]");
428 | CheckNdArray(m2,
429 | "[[5, 6],\n"
430 | " [9, 10]]");
431 | CheckNdArray(m3,
432 | "[[4, 5, 6, 7],\n"
433 | " [8, 9, 10, 11]]");
434 | CheckNdArray(m4,
435 | "[[1, 2],\n"
436 | " [5, 6],\n"
437 | " [9, 10],\n"
438 | " [13, 14]]");
439 | CheckNdArray(m5,
440 | "[[4, 5, 6, 7],\n"
441 | " [8, 9, 10, 11]]");
442 | }
443 |
444 | SECTION("Slice high-dim") {
445 | auto m1 = NdArray::Arange(256.f).reshape(4, 4, 4, 4);
446 | auto m2 = m1.slice({{1, 3}, {1, 3}, {1, 3}, {1, 3}});
447 | auto m3 = m1.slice({1, 3}, {1, 3}, {1, 3}, {1, 3});
448 | CHECK(m1.shape() == Shape{4, 4, 4, 4});
449 | CHECK(m2.shape() == Shape{2, 2, 2, 2});
450 | CHECK(m3.shape() == Shape{2, 2, 2, 2});
451 | CheckNdArray(m2,
452 | "[[[[85, 86],\n"
453 | " [89, 90]],\n"
454 | " [[101, 102],\n"
455 | " [105, 106]]],\n"
456 | " [[[149, 150],\n"
457 | " [153, 154]],\n"
458 | " [[165, 166],\n"
459 | " [169, 170]]]]");
460 | CheckNdArray(m3,
461 | "[[[[85, 86],\n"
462 | " [89, 90]],\n"
463 | " [[101, 102],\n"
464 | " [105, 106]]],\n"
465 | " [[[149, 150],\n"
466 | " [153, 154]],\n"
467 | " [[165, 166],\n"
468 | " [169, 170]]]]");
469 | }
470 |
471 | // ------------------------------ Dot product ------------------------------
472 | SECTION("Dot (empty)") {
473 | // Empty array
474 | auto m1 = NdArray::Arange(0.f);
475 | CHECK_THROWS(m1.dot(m1));
476 | }
477 |
478 | SECTION("Dot (scalar)") {
479 | // Scalar multiply
480 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
481 | auto m1_a = m1.dot(NdArray{2.f});
482 | auto m1_b = NdArray({2.f}).dot(m1);
483 | CheckNdArray(m1_a,
484 | "[[0, 2, 4],\n"
485 | " [6, 8, 10]]");
486 | CheckNdArray(m1_b,
487 | "[[0, 2, 4],\n"
488 | " [6, 8, 10]]");
489 | }
490 |
491 | SECTION("Dot (1D, 1D)") {
492 | // Inner product of vectors
493 | auto m1 = NdArray::Arange(3.f);
494 | auto m2 = NdArray::Ones(3);
495 | float m11 = m1.dot(m1);
496 | float m12 = m1.dot(m2);
497 | CHECK(m11 == Approx(5.f));
498 | CHECK(m12 == Approx(3.f));
499 | // Shape mismatch
500 | auto m3 = NdArray::Arange(4.f);
501 | CHECK_THROWS(m1.dot(m3));
502 | }
503 |
504 | SECTION("Dot (2D, 2D)") {
505 | // Inner product of 2D matrix
506 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
507 | auto m2 = NdArray::Arange(6.f).reshape(3, 2);
508 | auto m12 = m1.dot(m2);
509 | CheckNdArray(m12,
510 | "[[10, 13],\n"
511 | " [28, 40]]");
512 | // Shape mismatch
513 | CHECK_THROWS(m1.dot(m1));
514 | }
515 |
516 | SECTION("Dot (2D, 1D)") {
517 | // Inner product of 2D matrix and vector (2D, 1D)
518 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
519 | auto m2 = NdArray::Arange(3.f);
520 | auto m12 = m1.dot(m2);
521 | CheckNdArray(m12, "[5, 14]");
522 | // Shape mismatch
523 | auto m3 = NdArray::Arange(2.f);
524 | CHECK_THROWS(m1.dot(m3));
525 | }
526 |
527 | SECTION("Dot (ND, 1D)") {
528 | // Inner product of ND matrix and vector (ND, 1D)
529 | auto m1 = NdArray::Arange(12.f).reshape(2, 2, 3);
530 | auto m2 = NdArray::Arange(3.f);
531 | auto m12 = m1.dot(m2);
532 | CheckNdArray(m12,
533 | "[[5, 14],\n"
534 | " [23, 32]]");
535 | }
536 |
537 | SECTION("Dot (ND, MD)") {
538 | // Inner product of ND matrix and MD matrix
539 | auto m1 = NdArray::Arange(12.f).reshape(2, 3, 2);
540 | auto m2 = NdArray::Arange(6.f).reshape(2, 3);
541 | auto m3 = NdArray::Arange(12.f).reshape(3, 2, 2);
542 | auto m12 = m1.dot(m2);
543 | auto m13 = m1.dot(m3);
544 | CHECK(m12.shape() == Shape{2, 3, 3});
545 | CHECK(m13.shape() == Shape{2, 3, 3, 2});
546 | CheckNdArray(m12,
547 | "[[[3, 4, 5],\n"
548 | " [9, 14, 19],\n"
549 | " [15, 24, 33]],\n"
550 | " [[21, 34, 47],\n"
551 | " [27, 44, 61],\n"
552 | " [33, 54, 75]]]");
553 | CheckNdArray(m13,
554 | "[[[[2, 3],\n"
555 | " [6, 7],\n"
556 | " [10, 11]],\n"
557 | " [[6, 11],\n"
558 | " [26, 31],\n"
559 | " [46, 51]],\n"
560 | " [[10, 19],\n"
561 | " [46, 55],\n"
562 | " [82, 91]]],\n"
563 | " [[[14, 27],\n"
564 | " [66, 79],\n"
565 | " [118, 131]],\n"
566 | " [[18, 35],\n"
567 | " [86, 103],\n"
568 | " [154, 171]],\n"
569 | " [[22, 43],\n"
570 | " [106, 127],\n"
571 | " [190, 211]]]]");
572 | }
573 |
574 | // ----------------------------- Matmul product ----------------------------
575 | SECTION("Matmul (2D, 2D)") {
576 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
577 | auto m2 = NdArray::Arange(6.f).reshape(3, 2);
578 | auto m12 = Matmul(m1, m2);
579 | CheckNdArray(m12,
580 | "[[10, 13],\n"
581 | " [28, 40]]");
582 | }
583 |
584 | SECTION("Matmul (2D, 3D)") {
585 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
586 | auto m2 = NdArray::Arange(6.f).reshape(1, 3, 2);
587 | auto m12 = Matmul(m1, m2);
588 | CheckNdArray(m12,
589 | "[[[10, 13],\n"
590 | " [28, 40]]]");
591 | }
592 |
593 | SECTION("Matmul (3D, 2D)") {
594 | auto m1 = NdArray::Arange(6.f).reshape(1, 2, 3);
595 | auto m2 = NdArray::Arange(6.f).reshape(3, 2);
596 | auto m12 = Matmul(m1, m2);
597 | CheckNdArray(m12,
598 | "[[[10, 13],\n"
599 | " [28, 40]]]");
600 | }
601 |
602 | SECTION("Matmul (1D, 2D)") {
603 | auto m1 = NdArray::Arange(3.f);
604 | auto m2 = NdArray::Arange(6.f).reshape(3, 2);
605 | auto m12 = Matmul(m1, m2);
606 | CheckNdArray(m12, "[10, 13]");
607 | }
608 |
609 | SECTION("Matmul (2D, 1D)") {
610 | auto m1 = NdArray::Arange(6.f).reshape(3, 2);
611 | auto m2 = NdArray::Arange(2.f);
612 | auto m12 = Matmul(m1, m2);
613 | CheckNdArray(m12, "[1, 3, 5]");
614 | }
615 |
616 | SECTION("Matmul (1D, ND)") {
617 | auto m1 = NdArray::Arange(3.f);
618 | auto m2 = NdArray::Arange(12.f).reshape(2, 1, 3, 2);
619 | auto m12 = Matmul(m1, m2);
620 | CheckNdArray(m12,
621 | "[[[10, 13]],\n"
622 | " [[28, 31]]]");
623 | }
624 |
625 | SECTION("Matmul (2D, ND)") {
626 | auto m1 = NdArray::Arange(12.f).reshape(2, 1, 3, 2);
627 | auto m2 = NdArray::Arange(2.f);
628 | auto m12 = Matmul(m1, m2);
629 | CheckNdArray(m12,
630 | "[[[1, 3, 5]],\n"
631 | " [[7, 9, 11]]]");
632 | }
633 |
634 | SECTION("Matmul (ND, MD)") {
635 | auto m1 = NdArray::Arange(36.f).reshape(2, 3, 1, 2, 3);
636 | auto m2 = NdArray::Arange(36.f).reshape(3, 3, 4);
637 | auto m12 = Matmul(m1, m2);
638 | CHECK(m12.shape() == Shape{2, 3, 3, 2, 4});
639 | CheckNdArray(m12,
640 | "[[[[[20, 23, 26, 29],\n"
641 | " [56, 68, 80, 92]],\n"
642 | " [[56, 59, 62, 65],\n"
643 | " [200, 212, 224, 236]],\n"
644 | " [[92, 95, 98, 101],\n"
645 | " [344, 356, 368, 380]]],\n"
646 | " [[[92, 113, 134, 155],\n"
647 | " [128, 158, 188, 218]],\n"
648 | " [[344, 365, 386, 407],\n"
649 | " [488, 518, 548, 578]],\n"
650 | " [[596, 617, 638, 659],\n"
651 | " [848, 878, 908, 938]]],\n"
652 | " [[[164, 203, 242, 281],\n"
653 | " [200, 248, 296, 344]],\n"
654 | " [[632, 671, 710, 749],\n"
655 | " [776, 824, 872, 920]],\n"
656 | " [[1100, 1139, 1178, 1217],\n"
657 | " [1352, 1400, 1448, 1496]]]],\n"
658 | " [[[[236, 293, 350, 407],\n"
659 | " [272, 338, 404, 470]],\n"
660 | " [[920, 977, 1034, 1091],\n"
661 | " [1064, 1130, 1196, 1262]],\n"
662 | " [[1604, 1661, 1718, 1775],\n"
663 | " [1856, 1922, 1988, 2054]]],\n"
664 | " [[[308, 383, 458, 533],\n"
665 | " [344, 428, 512, 596]],\n"
666 | " [[1208, 1283, 1358, 1433],\n"
667 | " [1352, 1436, 1520, 1604]],\n"
668 | " [[2108, 2183, 2258, 2333],\n"
669 | " [2360, 2444, 2528, 2612]]],\n"
670 | " [[[380, 473, 566, 659],\n"
671 | " [416, 518, 620, 722]],\n"
672 | " [[1496, 1589, 1682, 1775],\n"
673 | " [1640, 1742, 1844, 1946]],\n"
674 | " [[2612, 2705, 2798, 2891],\n"
675 | " [2864, 2966, 3068, 3170]]]]]");
676 | }
677 |
678 | // ----------------------------- Cross product -----------------------------
679 | SECTION("Cross (1D, 1D), (3, 3 elem)") {
680 | NdArray m1 = {1.f, 2.f, 3.f};
681 | NdArray m2 = {4.f, 5.f, 6.f};
682 | auto m12 = m1.cross(m2);
683 | CheckNdArray(m12, "[-3, 6, -3]");
684 | }
685 |
686 | SECTION("Cross (1D, 1D), (3, 2 elem)") {
687 | NdArray m1 = {1.f, 2.f, 3.f};
688 | NdArray m2 = {4.f, 5.f};
689 | auto m12 = m1.cross(m2);
690 | CheckNdArray(m12, "[-15, 12, -3]");
691 | }
692 |
693 | SECTION("Cross (1D, 1D), (2, 3 elem)") {
694 | NdArray m1 = {1.f, 2.f};
695 | NdArray m2 = {4.f, 5.f, 6.f};
696 | auto m12 = m1.cross(m2);
697 | CheckNdArray(m12, "[12, -6, -3]");
698 | }
699 |
700 | SECTION("Cross (1D, 1D), (2, 2 elem)") {
701 | NdArray m1 = {1.f, 2.f};
702 | NdArray m2 = {4.f, 5.f};
703 | float m12 = m1.cross(m2);
704 | CHECK(m12 == Approx(-3.f));
705 | }
706 |
707 | SECTION("Cross (1D, 1D), (mismatch)") {
708 | NdArray m1 = {1.f};
709 | NdArray m2 = {4.f, 5.f};
710 | NdArray m3 = {4.f, 5.f, 6.f, 7.f};
711 | CHECK_THROWS(m1.cross(m2));
712 | CHECK_THROWS(m2.cross(m3));
713 | }
714 |
715 | SECTION("Cross (ND, MD), (3, 3 elem)") {
716 | auto m1 = NdArray::Arange(18.f).reshape(3, 2, 3);
717 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) + 1.f;
718 | auto m12 = m1.cross(m2);
719 | CheckNdArray(m12,
720 | "[[[-1, 2, -1],\n"
721 | " [-1, 2, -1]],\n"
722 | " [[5, -10, 5],\n"
723 | " [5, -10, 5]],\n"
724 | " [[11, -22, 11],\n"
725 | " [11, -22, 11]]]");
726 | }
727 |
728 | SECTION("Cross (ND, MD), (3, 2 elem)") {
729 | auto m1 = NdArray::Arange(18.f).reshape(2, 3, 3);
730 | auto m2 = NdArray::Arange(6.f).reshape(3, 2) + 1.f;
731 | auto m12 = m1.cross(m2);
732 | CheckNdArray(m12,
733 | "[[[-4, 2, -1],\n"
734 | " [-20, 15, 0],\n"
735 | " [-48, 40, 1]],\n"
736 | " [[-22, 11, 8],\n"
737 | " [-56, 42, 9],\n"
738 | " [-102, 85, 10]]]");
739 | }
740 |
741 | SECTION("Cross (ND, MD), (2, 3 elem)") {
742 | auto m1 = NdArray::Arange(12.f).reshape(3, 2, 2);
743 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) + 1.f;
744 | auto m12 = m1.cross(m2);
745 | CheckNdArray(m12,
746 | "[[[3, -0, -1],\n"
747 | " [18, -12, -2]],\n"
748 | " [[15, -12, 3],\n"
749 | " [42, -36, 2]],\n"
750 | " [[27, -24, 7],\n"
751 | " [66, -60, 6]]]");
752 | }
753 |
754 | SECTION("Cross (ND, MD), (2, 2 elem)") {
755 | auto m1 = NdArray::Arange(12.f).reshape(3, 2, 2);
756 | auto m2 = NdArray::Arange(4.f).reshape(2, 2) + 1.f;
757 | auto m12 = m1.cross(m2);
758 | CheckNdArray(m12,
759 | "[[-1, -1],\n"
760 | " [3, 3],\n"
761 | " [7, 7]]");
762 | }
763 |
764 | // ----------------------------- Axis operation ----------------------------
765 | SECTION("Sum") {
766 | NdArray m0;
767 | auto m1 = NdArray::Arange(6.f);
768 | auto m2 = NdArray::Arange(36.f).reshape(2, 3, 2, 3);
769 | CheckNdArray(m0.sum(), "[0]");
770 | CheckNdArray(m1.sum(), "[15]");
771 | CheckNdArray(m2.sum({0}),
772 | "[[[18, 20, 22],\n"
773 | " [24, 26, 28]],\n"
774 | " [[30, 32, 34],\n"
775 | " [36, 38, 40]],\n"
776 | " [[42, 44, 46],\n"
777 | " [48, 50, 52]]]");
778 | CheckNdArray(m2.sum({2}),
779 | "[[[3, 5, 7],\n"
780 | " [15, 17, 19],\n"
781 | " [27, 29, 31]],\n"
782 | " [[39, 41, 43],\n"
783 | " [51, 53, 55],\n"
784 | " [63, 65, 67]]]");
785 | CheckNdArray(m2.sum({3}),
786 | "[[[3, 12],\n"
787 | " [21, 30],\n"
788 | " [39, 48]],\n"
789 | " [[57, 66],\n"
790 | " [75, 84],\n"
791 | " [93, 102]]]");
792 | CheckNdArray(m2.sum({1, 2}),
793 | "[[45, 51, 57],\n"
794 | " [153, 159, 165]]");
795 | CheckNdArray(m2.sum({1, 3}),
796 | "[[63, 90],\n"
797 | " [225, 252]]");
798 | CheckNdArray(m2.sum({0, 1, 2, 3}), "[630]");
799 | CheckNdArray(m1.reshape(1, 2, 3).sum({0, 1, 2}), "[15]");
800 | }
801 |
802 | SECTION("Min") {
803 | NdArray m0;
804 | auto m1 = NdArray::Arange(6.f) - 3.f;
805 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f;
806 | CHECK_THROWS(m0.min());
807 | CheckNdArray(m1.min(), "[-3]");
808 | CheckNdArray(m2.min({0}),
809 | "[[-6, -5],\n"
810 | " [-4, -3],\n"
811 | " [-2, -1]]");
812 | CheckNdArray(m2.min({2}),
813 | "[[-6, -4, -2],\n"
814 | " [0, 2, 4]]");
815 | CheckNdArray(m2.min({2, 1}), "[-6, 0]");
816 | }
817 |
818 | SECTION("Max") {
819 | NdArray m0;
820 | auto m1 = NdArray::Arange(6.f) - 3.f;
821 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f;
822 | CHECK_THROWS(m0.max());
823 | CheckNdArray(m1.max(), "[2]");
824 | CheckNdArray(m2.max({0}),
825 | "[[0, 1],\n"
826 | " [2, 3],\n"
827 | " [4, 5]]");
828 | CheckNdArray(m2.max({2}),
829 | "[[-5, -3, -1],\n"
830 | " [1, 3, 5]]");
831 | CheckNdArray(m2.max({2, 1}), "[-1, 5]");
832 | }
833 |
834 | SECTION("Mean") {
835 | NdArray m0;
836 | auto m1 = NdArray::Arange(6.f) - 3.f;
837 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f;
838 | CheckNdArray(m0.mean(), "[nan]");
839 | CheckNdArray(m1.mean(), "[-0.5]");
840 | CheckNdArray(m2.mean({0}),
841 | "[[-3, -2],\n"
842 | " [-1, 0],\n"
843 | " [1, 2]]");
844 | CheckNdArray(m2.mean({2}),
845 | "[[-5.5, -3.5, -1.5],\n"
846 | " [0.5, 2.5, 4.5]]");
847 | CheckNdArray(m2.mean({2, 1}), "[-3.5, 2.5]");
848 | }
849 |
850 | SECTION("Sum keepdims") {
851 | NdArray m0;
852 | auto m1 = NdArray::Arange(36.f).reshape(2, 3, 2, 3);
853 | CheckNdArray(m1.sum({1, 2}, true),
854 | "[[[[45, 51, 57]]],\n"
855 | " [[[153, 159, 165]]]]");
856 | CheckNdArray(m1.sum({1, 3}, true),
857 | "[[[[63],\n"
858 | " [90]]],\n"
859 | " [[[225],\n"
860 | " [252]]]]");
861 | CheckNdArray(m1.sum({0, 1, 2, 3}, true), "[[[[630]]]]");
862 | CheckNdArray(m1.sum({}, true), "[[[[630]]]]");
863 | }
864 |
865 | // --------------------------- Logistic operation --------------------------
866 | SECTION("All (no axis)") {
867 | NdArray m0;
868 | CHECK(All(m0));
869 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
870 | auto m2 = m1.copy();
871 | CHECK(All(m1 == m2));
872 | m2(1, 2) = -1.f;
873 | CHECK(!All(m1 == m2));
874 | }
875 |
876 | SECTION("Any (no axis)") {
877 | NdArray m0;
878 | CHECK(!Any(m0));
879 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
880 | auto m2 = m1.copy() + 1.f;
881 | CHECK(!Any(m1 == m2));
882 | m2(0, 0) = 0.f;
883 | CHECK(Any(m1 == m2));
884 | }
885 |
886 | SECTION("All (with axis)") {
887 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
888 | CheckNdArray(All(m1, {0}), "[0, 1, 1]");
889 | CheckNdArray(All(m1, {1}), "[0, 1]");
890 | }
891 |
892 | SECTION("Any (with axis)") {
893 | auto m1 = NdArray::Zeros(2, 3);
894 | m1(0, 0) = -1.f;
895 | CheckNdArray(Any(m1, {0}), "[1, 0, 0]");
896 | CheckNdArray(Any(m1, {1}), "[1, 0]");
897 | }
898 |
899 | SECTION("Where") {
900 | auto m1 = 2.f < NdArray::Arange(6.f).reshape(2, 3);
901 | CheckNdArray(Where(m1, NdArray::Ones(2, 1), NdArray::Arange(3)),
902 | "[[0, 1, 2],\n"
903 | " [1, 1, 1]]");
904 | CheckNdArray(Where(m1, 1.f, NdArray::Arange(3)),
905 | "[[0, 1, 2],\n"
906 | " [1, 1, 1]]");
907 | CheckNdArray(Where(m1, NdArray::Arange(3), 0.f),
908 | "[[0, 0, 0],\n"
909 | " [0, 1, 2]]");
910 | CheckNdArray(Where(m1, 1.f, 0.f),
911 | "[[0, 0, 0],\n"
912 | " [1, 1, 1]]");
913 | CheckNdArray(Where(m1, 0.f, 1.f),
914 | "[[1, 1, 1],\n"
915 | " [0, 0, 0]]");
916 | }
917 |
918 | // ------------------------------- Operator --------------------------------
919 | SECTION("Single +- operators") {
920 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
921 | auto m_p = +m1;
922 | auto m_n = -m1;
923 | CHECK(m_p.id() != m1.id());
924 | CHECK(m_n.id() != m1.id());
925 | CheckNdArray(m_p,
926 | "[[0, 1, 2],\n"
927 | " [3, 4, 5]]");
928 | CheckNdArray(m_n,
929 | "[[-0, -1, -2],\n"
930 | " [-3, -4, -5]]");
931 | }
932 |
933 | SECTION("Add same shape") {
934 | auto m1 = NdArray::Arange(12.f).reshape(2, 3, 2);
935 | auto m2 = NdArray::Ones({2, 3, 2});
936 | auto m3 = m1 + m2;
937 | CHECK(m1.shape() == m2.shape());
938 | CHECK(m1.shape() == m3.shape());
939 | CheckNdArray(m3,
940 | "[[[1, 2],\n"
941 | " [3, 4],\n"
942 | " [5, 6]],\n"
943 | " [[7, 8],\n"
944 | " [9, 10],\n"
945 | " [11, 12]]]");
946 | }
947 |
948 | SECTION("Add broadcast 2-dim") {
949 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
950 | auto m2 = NdArray::Arange(2.f).reshape(2, 1);
951 | auto m3 = NdArray::Arange(3.f).reshape(1, 3);
952 | auto m12 = m1 + m2;
953 | auto m13 = m1 + m3;
954 | auto m23 = m2 + m3;
955 | CHECK(m12.shape() == Shape{2, 3});
956 | CHECK(m13.shape() == Shape{2, 3});
957 | CHECK(m23.shape() == Shape{2, 3});
958 | CheckNdArray(m12,
959 | "[[0, 1, 2],\n"
960 | " [4, 5, 6]]");
961 | CheckNdArray(m13,
962 | "[[0, 2, 4],\n"
963 | " [3, 5, 7]]");
964 | CheckNdArray(m23,
965 | "[[0, 1, 2],\n"
966 | " [1, 2, 3]]");
967 | }
968 |
969 | SECTION("Add broadcast high-dim") {
970 | auto m1 = NdArray::Arange(6.f).reshape(1, 2, 1, 1, 3);
971 | auto m2 = NdArray::Arange(2.f).reshape(2, 1);
972 | auto m3 = NdArray::Arange(3.f).reshape(1, 3);
973 | auto m12 = m1 + m2;
974 | auto m13 = m1 + m3;
975 | CHECK(m12.shape() == Shape{1, 2, 1, 2, 3});
976 | CHECK(m13.shape() == Shape{1, 2, 1, 1, 3});
977 | CheckNdArray(m12,
978 | "[[[[[0, 1, 2],\n"
979 | " [1, 2, 3]]],\n"
980 | " [[[3, 4, 5],\n"
981 | " [4, 5, 6]]]]]");
982 | CheckNdArray(m13,
983 | "[[[[[0, 2, 4]]],\n"
984 | " [[[3, 5, 7]]]]]");
985 | }
986 |
987 | SECTION("Add empty") {
988 | NdArray m1;
989 | auto m2 = m1 + 1.f;
990 | CHECK(m1.shape() == Shape{0});
991 | CHECK(m2.shape() == Shape{0});
992 | CHECK(m2.size() == 0);
993 | }
994 |
995 | SECTION("Sub/Mul/Div") {
996 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
997 | auto m2 = NdArray::Arange(3.f).reshape(3);
998 | auto m_sub = m1 - m2;
999 | auto m_mul = m1 * m2;
1000 | auto m_div = m1 / m2;
1001 | CHECK(m_sub.shape() == Shape{2, 3});
1002 | CHECK(m_mul.shape() == Shape{2, 3});
1003 | CHECK(m_div.shape() == Shape{2, 3});
1004 | CheckNdArray(m_sub,
1005 | "[[0, 0, 0],\n"
1006 | " [3, 3, 3]]");
1007 | CheckNdArray(m_mul,
1008 | "[[0, 1, 4],\n"
1009 | " [0, 4, 10]]");
1010 | ResolveAmbiguous(m_div); // -nan -> nan
1011 | CheckNdArray(m_div,
1012 | "[[nan, 1, 1],\n"
1013 | " [inf, 4, 2.5]]");
1014 | }
1015 |
1016 | SECTION("Arithmetic operators (NdArray, float)") {
1017 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1018 | auto m_add = m1 + 10.f;
1019 | auto m_sub = m1 - 10.f;
1020 | auto m_mul = m1 * 10.f;
1021 | auto m_div = m1 / 10.f;
1022 | CheckNdArray(m_add,
1023 | "[[10, 11, 12],\n"
1024 | " [13, 14, 15]]");
1025 | CheckNdArray(m_sub,
1026 | "[[-10, -9, -8],\n"
1027 | " [-7, -6, -5]]");
1028 | CheckNdArray(m_mul,
1029 | "[[0, 10, 20],\n"
1030 | " [30, 40, 50]]");
1031 | CheckNdArray(m_div,
1032 | "[[0, 0.1, 0.2],\n"
1033 | " [0.3, 0.4, 0.5]]");
1034 | }
1035 |
1036 | SECTION("Arithmetic operators (float, NdArray)") {
1037 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1038 | auto m_add = 10.f + m1;
1039 | auto m_sub = 10.f - m1;
1040 | auto m_mul = 10.f * m1;
1041 | auto m_div = 10.f / m1;
1042 | CheckNdArray(m_add,
1043 | "[[10, 11, 12],\n"
1044 | " [13, 14, 15]]");
1045 | CheckNdArray(m_sub,
1046 | "[[10, 9, 8],\n"
1047 | " [7, 6, 5]]");
1048 | CheckNdArray(m_mul,
1049 | "[[0, 10, 20],\n"
1050 | " [30, 40, 50]]");
1051 | CheckNdArray(m_div,
1052 | "[[inf, 10, 5],\n"
1053 | " [3.33333, 2.5, 2]]");
1054 | }
1055 |
1056 | SECTION("Comparison operators (NdArray, NdArray)") {
1057 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1058 | auto m2 = NdArray::Arange(2.f).reshape(2, 1) + 3.f;
1059 | CheckNdArray(m1 == m2,
1060 | "[[0, 0, 0],\n"
1061 | " [0, 1, 0]]");
1062 | CheckNdArray(m1 != m2,
1063 | "[[1, 1, 1],\n"
1064 | " [1, 0, 1]]");
1065 | CheckNdArray(m1 > m2,
1066 | "[[0, 0, 0],\n"
1067 | " [0, 0, 1]]");
1068 | CheckNdArray(m1 >= m2,
1069 | "[[0, 0, 0],\n"
1070 | " [0, 1, 1]]");
1071 | CheckNdArray(m1 < m2,
1072 | "[[1, 1, 1],\n"
1073 | " [1, 0, 0]]");
1074 | CheckNdArray(m1 <= m2,
1075 | "[[1, 1, 1],\n"
1076 | " [1, 1, 0]]");
1077 | }
1078 |
1079 | SECTION("Comparison operators (NdArray, float)") {
1080 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1081 | CheckNdArray(m1 == 1.f,
1082 | "[[0, 1, 0],\n"
1083 | " [0, 0, 0]]");
1084 | CheckNdArray(m1 != 1.f,
1085 | "[[1, 0, 1],\n"
1086 | " [1, 1, 1]]");
1087 | CheckNdArray(m1 > 1.f,
1088 | "[[0, 0, 1],\n"
1089 | " [1, 1, 1]]");
1090 | CheckNdArray(m1 >= 1.f,
1091 | "[[0, 1, 1],\n"
1092 | " [1, 1, 1]]");
1093 | CheckNdArray(m1 < 1.f,
1094 | "[[1, 0, 0],\n"
1095 | " [0, 0, 0]]");
1096 | CheckNdArray(m1 <= 1.f,
1097 | "[[1, 1, 0],\n"
1098 | " [0, 0, 0]]");
1099 | }
1100 |
1101 | SECTION("Comparison operators (float, NdArray)") {
1102 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1103 | CheckNdArray(1.f == m1,
1104 | "[[0, 1, 0],\n"
1105 | " [0, 0, 0]]");
1106 | CheckNdArray(1.f != m1,
1107 | "[[1, 0, 1],\n"
1108 | " [1, 1, 1]]");
1109 | CheckNdArray(1.f > m1,
1110 | "[[1, 0, 0],\n"
1111 | " [0, 0, 0]]");
1112 | CheckNdArray(1.f >= m1,
1113 | "[[1, 1, 0],\n"
1114 | " [0, 0, 0]]");
1115 | CheckNdArray(1.f < m1,
1116 | "[[0, 0, 1],\n"
1117 | " [1, 1, 1]]");
1118 | CheckNdArray(1.f <= m1,
1119 | "[[0, 1, 1],\n"
1120 | " [1, 1, 1]]");
1121 | }
1122 |
1123 | // --------------------------- In-place Operator ---------------------------
1124 | SECTION("Arithmetic operators (NdArray, NdArray) (in-place both)") {
1125 | CheckNdArrayInplace(
1126 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f),
1127 | "[[0, 2, 4],\n"
1128 | " [3, 5, 7]]",
1129 | static_cast(operator+));
1130 | CheckNdArrayInplace(
1131 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f),
1132 | "[[0, 0, 0],\n"
1133 | " [3, 3, 3]]",
1134 | static_cast(operator-));
1135 | CheckNdArrayInplace(
1136 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f),
1137 | "[[0, 1, 4],\n"
1138 | " [0, 4, 10]]",
1139 | static_cast(operator*));
1140 | CheckNdArrayInplace(
1141 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f) + 1.f,
1142 | "[[0, 0.5, 0.666667],\n"
1143 | " [3, 2, 1.66667]]",
1144 | static_cast(operator/));
1145 | }
1146 |
1147 | SECTION("Arithmetic operators (NdArray, NdArray) (in-place right)") {
1148 | auto m1 = NdArray::Arange(3.f);
1149 | auto m2 = NdArray::Arange(3.f) + 1.f;
1150 | CheckNdArrayInplace(
1151 | m1, NdArray::Arange(6.f).reshape(2, 3),
1152 | "[[0, 2, 4],\n"
1153 | " [3, 5, 7]]",
1154 | static_cast(operator+));
1155 | CheckNdArrayInplace(
1156 | m1, NdArray::Arange(6.f).reshape(2, 3),
1157 | "[[0, 0, 0],\n"
1158 | " [-3, -3, -3]]",
1159 | static_cast(operator-));
1160 | CheckNdArrayInplace(
1161 | m1, NdArray::Arange(6.f).reshape(2, 3),
1162 | "[[0, 1, 4],\n"
1163 | " [0, 4, 10]]",
1164 | static_cast(operator*));
1165 | CheckNdArrayInplace(
1166 | m2, NdArray::Arange(6.f).reshape(2, 3),
1167 | "[[inf, 2, 1.5],\n"
1168 | " [0.333333, 0.5, 0.6]]",
1169 | static_cast(operator/));
1170 | }
1171 |
1172 | SECTION("Arithmetic operators (NdArray, NdArray) (in-place left)") {
1173 | auto m2 = NdArray::Arange(3.f);
1174 | auto m3 = NdArray::Arange(3.f) + 1.f;
1175 | CheckNdArrayInplace(
1176 | NdArray::Arange(6.f).reshape(2, 3), m2,
1177 | "[[0, 2, 4],\n"
1178 | " [3, 5, 7]]",
1179 | static_cast(operator+));
1180 | CheckNdArrayInplace(
1181 | NdArray::Arange(6.f).reshape(2, 3), m2,
1182 | "[[0, 0, 0],\n"
1183 | " [3, 3, 3]]",
1184 | static_cast(operator-));
1185 | CheckNdArrayInplace(
1186 | NdArray::Arange(6.f).reshape(2, 3), m2,
1187 | "[[0, 1, 4],\n"
1188 | " [0, 4, 10]]",
1189 | static_cast(operator*));
1190 | CheckNdArrayInplace(
1191 | NdArray::Arange(6.f).reshape(2, 3), m3,
1192 | "[[0, 0.5, 0.666667],\n"
1193 | " [3, 2, 1.66667]]",
1194 | static_cast(operator/));
1195 | }
1196 |
1197 | SECTION("Arithmetic operators (NdArray, float) (inplace)") {
1198 | CheckNdArrayInplace(
1199 | NdArray::Arange(3.f), 2.f, "[2, 3, 4]",
1200 | static_cast(operator+));
1201 | CheckNdArrayInplace(
1202 | NdArray::Arange(3.f), 2.f, "[-2, -1, 0]",
1203 | static_cast(operator-));
1204 | CheckNdArrayInplace(
1205 | NdArray::Arange(3.f), 2.f, "[0, 2, 4]",
1206 | static_cast(operator*));
1207 | CheckNdArrayInplace(
1208 | NdArray::Arange(3.f), 2.f, "[0, 0.5, 1]",
1209 | static_cast(operator/));
1210 | }
1211 |
1212 | SECTION("Arithmetic operators (float, NdArray) (inplace)") {
1213 | CheckNdArrayInplace(
1214 | 2.f, NdArray::Arange(3.f), "[2, 3, 4]",
1215 | static_cast(operator+));
1216 | CheckNdArrayInplace(
1217 | 2.f, NdArray::Arange(3.f), "[2, 1, 0]",
1218 | static_cast(operator-));
1219 | CheckNdArrayInplace(
1220 | 2.f, NdArray::Arange(3.f), "[0, 2, 4]",
1221 | static_cast(operator*));
1222 | CheckNdArrayInplace(
1223 | 2.f, NdArray::Arange(3.f), "[inf, 2, 1]",
1224 | static_cast(operator/));
1225 | }
1226 |
1227 | SECTION("Comparison operators (NdArray, NdArray) (inplace both)") {
1228 | CheckNdArrayInplace(
1229 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 0]",
1230 | static_cast(operator==));
1231 | CheckNdArrayInplace(
1232 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 1]",
1233 | static_cast(operator!=));
1234 | CheckNdArrayInplace(
1235 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 0, 1]",
1236 | static_cast(operator>));
1237 | CheckNdArrayInplace(
1238 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 1]",
1239 | static_cast(operator>=));
1240 | CheckNdArrayInplace(
1241 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 0]",
1242 | static_cast(operator<));
1243 | CheckNdArrayInplace(
1244 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 1, 0]",
1245 | static_cast(operator<=));
1246 | }
1247 |
1248 | SECTION("Comparison operators (NdArray, NdArray) (inplace right)") {
1249 | auto m2 = NdArray::Zeros(1) + 1.f;
1250 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[0, 1, 0]",
1251 | static_cast(
1252 | operator==));
1253 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[1, 0, 1]",
1254 | static_cast(
1255 | operator!=));
1256 | CheckNdArrayInplace(
1257 | NdArray::Arange(3.f), m2, "[0, 0, 1]",
1258 | static_cast(operator>));
1259 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[0, 1, 1]",
1260 | static_cast(
1261 | operator>=));
1262 | CheckNdArrayInplace(
1263 | NdArray::Arange(3.f), m2, "[1, 0, 0]",
1264 | static_cast(operator<));
1265 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[1, 1, 0]",
1266 | static_cast(
1267 | operator<=));
1268 | }
1269 |
1270 | SECTION("Comparison operators (NdArray, NdArray) (inplace left)") {
1271 | auto m1 = NdArray::Zeros(1) + 1.f;
1272 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[0, 1, 0]",
1273 | static_cast(
1274 | operator==));
1275 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[1, 0, 1]",
1276 | static_cast(
1277 | operator!=));
1278 | CheckNdArrayInplace(
1279 | m1, NdArray::Arange(3.f), "[1, 0, 0]",
1280 | static_cast(operator>));
1281 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[1, 1, 0]",
1282 | static_cast(
1283 | operator>=));
1284 | CheckNdArrayInplace(
1285 | m1, NdArray::Arange(3.f), "[0, 0, 1]",
1286 | static_cast(operator<));
1287 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[0, 1, 1]",
1288 | static_cast(
1289 | operator<=));
1290 | }
1291 |
1292 | SECTION("Comparison operators (NdArray, float) (inplace)") {
1293 | CheckNdArrayInplace(
1294 | NdArray::Arange(3.f), 1.f, "[0, 1, 0]",
1295 | static_cast(operator==));
1296 | CheckNdArrayInplace(
1297 | NdArray::Arange(3.f), 1.f, "[1, 0, 1]",
1298 | static_cast(operator!=));
1299 | CheckNdArrayInplace(
1300 | NdArray::Arange(3.f), 1.f, "[0, 0, 1]",
1301 | static_cast(operator>));
1302 | CheckNdArrayInplace(
1303 | NdArray::Arange(3.f), 1.f, "[0, 1, 1]",
1304 | static_cast(operator>=));
1305 | CheckNdArrayInplace(
1306 | NdArray::Arange(3.f), 1.f, "[1, 0, 0]",
1307 | static_cast(operator<));
1308 | CheckNdArrayInplace(
1309 | NdArray::Arange(3.f), 1.f, "[1, 1, 0]",
1310 | static_cast(operator<=));
1311 | }
1312 |
1313 | SECTION("Comparison operators (NdArray, float) (inplace)") {
1314 | CheckNdArrayInplace(
1315 | 1.f, NdArray::Arange(3.f), "[0, 1, 0]",
1316 | static_cast(operator==));
1317 | CheckNdArrayInplace(
1318 | 1.f, NdArray::Arange(3.f), "[1, 0, 1]",
1319 | static_cast(operator!=));
1320 | CheckNdArrayInplace(
1321 | 1.f, NdArray::Arange(3.f), "[1, 0, 0]",
1322 | static_cast(operator>));
1323 | CheckNdArrayInplace(
1324 | 1.f, NdArray::Arange(3.f), "[1, 1, 0]",
1325 | static_cast(operator>=));
1326 | CheckNdArrayInplace(
1327 | 1.f, NdArray::Arange(3.f), "[0, 0, 1]",
1328 | static_cast(operator<));
1329 | CheckNdArrayInplace(
1330 | 1.f, NdArray::Arange(3.f), "[0, 1, 1]",
1331 | static_cast(operator<=));
1332 | }
1333 |
1334 | SECTION("Compound assignment operators (NdArray, NdArray) (in-plafce, &)") {
1335 | auto m0 = NdArray::Arange(6.f).reshape(2, 3);
1336 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1337 | auto m2 = NdArray::Arange(6.f).reshape(2, 3);
1338 | auto m3 = NdArray::Arange(6.f).reshape(2, 3);
1339 | auto m4 = NdArray::Arange(6.f).reshape(2, 3);
1340 | auto m5 = NdArray::Arange(3.f);
1341 | auto m0_id = m0.id();
1342 | auto m1_id = m1.id();
1343 | auto m2_id = m2.id();
1344 | auto m3_id = m3.id();
1345 | auto m4_id = m4.id();
1346 | m0 += m0;
1347 | m1 += NdArray::Arange(3.f);
1348 | m2 -= NdArray::Arange(3.f);
1349 | m3 *= NdArray::Arange(3.f);
1350 | m4 /= NdArray::Arange(3.f);
1351 | CheckNdArray(m0,
1352 | "[[0, 2, 4],\n"
1353 | " [6, 8, 10]]");
1354 | CheckNdArray(m1,
1355 | "[[0, 2, 4],\n"
1356 | " [3, 5, 7]]");
1357 | CheckNdArray(m2,
1358 | "[[0, 0, 0],\n"
1359 | " [3, 3, 3]]");
1360 | CheckNdArray(m3,
1361 | "[[0, 1, 4],\n"
1362 | " [0, 4, 10]]");
1363 | ResolveAmbiguous(m4); // -nan -> nan
1364 | CheckNdArray(m4,
1365 | "[[nan, 1, 1],\n"
1366 | " [inf, 4, 2.5]]");
1367 | CHECK(m0.id() == m0_id); // in-place
1368 | CHECK(m1.id() == m1_id);
1369 | CHECK(m2.id() == m2_id);
1370 | CHECK(m3.id() == m3_id);
1371 | CHECK(m4.id() == m4_id);
1372 | // size change is not allowed
1373 | CHECK_THROWS(m5 += m0);
1374 | auto m6 = m0.reshape(2, 1, 3);
1375 | CHECK_THROWS(m6 *= m5.reshape(3, 1));
1376 | }
1377 |
1378 | SECTION("Compound assignment operators (NdArray, float) (in-place, &)") {  // scalar compound assignment on lvalues must keep each target's id()
1379 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1380 | auto m2 = NdArray::Arange(6.f).reshape(2, 3);
1381 | auto m3 = NdArray::Arange(6.f).reshape(2, 3);
1382 | auto m4 = NdArray::Arange(6.f).reshape(2, 3);
1383 | auto m1_id = m1.id();
1384 | auto m2_id = m2.id();
1385 | auto m3_id = m3.id();
1386 | auto m4_id = m4.id();
1387 | m1 += 10.f;
1388 | m2 -= 10.f;
1389 | m3 *= 10.f;
1390 | m4 /= 10.f;
1391 | CheckNdArray(m1,
1392 | "[[10, 11, 12],\n"
1393 | " [13, 14, 15]]");
1394 | CheckNdArray(m2,
1395 | "[[-10, -9, -8],\n"
1396 | " [-7, -6, -5]]");
1397 | CheckNdArray(m3,
1398 | "[[0, 10, 20],\n"
1399 | " [30, 40, 50]]");
1400 | CheckNdArray(m4,
1401 | "[[0, 0.1, 0.2],\n"
1402 | " [0.3, 0.4, 0.5]]");
1403 | CHECK(m1.id() == m1_id);
1404 | CHECK(m2.id() == m2_id);
1405 | CHECK(m3.id() == m3_id);
1406 | CHECK(m4.id() == m4_id);
1407 | }
1408 |
1409 | SECTION("Compound assignment operators (in-place, &&)") {  // compound assignment applied directly to temporaries; the result is kept
1410 | // (NdArray, NdArray)
1411 | auto m1 = NdArray::Arange(6.f).reshape(2, 3) += NdArray::Arange(3.f);
1412 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) -= NdArray::Arange(3.f);
1413 | auto m3 = NdArray::Arange(6.f).reshape(2, 3) *= NdArray::Arange(3.f);
1414 | auto m4 = NdArray::Arange(6.f).reshape(2, 3) /= NdArray::Arange(3.f);
1415 | CheckNdArray(m1,
1416 | "[[0, 2, 4],\n"
1417 | " [3, 5, 7]]");
1418 | CheckNdArray(m2,
1419 | "[[0, 0, 0],\n"
1420 | " [3, 3, 3]]");
1421 | CheckNdArray(m3,
1422 | "[[0, 1, 4],\n"
1423 | " [0, 4, 10]]");
1424 | ResolveAmbiguous(m4); // -nan -> nan
1425 | CheckNdArray(m4,  // column 0 divides by 0: 0/0 -> nan, 3/0 -> inf
1426 | "[[nan, 1, 1],\n"
1427 | " [inf, 4, 2.5]]");
1428 | // (NdArray, float)
1429 | auto m5 = NdArray::Arange(6.f).reshape(2, 3) += 10.f;
1430 | auto m6 = NdArray::Arange(6.f).reshape(2, 3) -= 10.f;
1431 | auto m7 = NdArray::Arange(6.f).reshape(2, 3) *= 10.f;
1432 | auto m8 = NdArray::Arange(6.f).reshape(2, 3) /= 10.f;
1433 | CheckNdArray(m5,
1434 | "[[10, 11, 12],\n"
1435 | " [13, 14, 15]]");
1436 | CheckNdArray(m6,
1437 | "[[-10, -9, -8],\n"
1438 | " [-7, -6, -5]]");
1439 | CheckNdArray(m7,
1440 | "[[0, 10, 20],\n"
1441 | " [30, 40, 50]]");
1442 | CheckNdArray(m8,
1443 | "[[0, 0.1, 0.2],\n"
1444 | " [0.3, 0.4, 0.5]]");
1445 | }
1446 |
1447 | // --------------------------- Operator function ---------------------------
1448 | SECTION("Single") {  // Positive/Negative return new arrays: writing m1[0] afterwards leaves m2 and m3 unchanged
1449 | auto m1 = NdArray::Arange(3.f);
1450 | auto m2 = Positive(m1);
1451 | auto m3 = Negative(m1);
1452 | m1[0] = -1.f;
1453 | CheckNdArray(m1, "[-1, 1, 2]");
1454 | CheckNdArray(m2, "[0, 1, 2]");
1455 | CheckNdArray(m3, "[-0, -1, -2]");  // negating 0 yields -0
1456 | }
1457 |
1458 | SECTION("Function Arithmetic (NdArray, NdArray)") {  // (2,3) op (3): m2 is broadcast across each row of m1
1459 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1460 | auto m2 = NdArray::Arange(3.f);
1461 | auto m_add = Add(m1, m2);
1462 | auto m_sub = Subtract(m1, m2);
1463 | auto m_mul = Multiply(m1, m2);
1464 | auto m_div = Divide(m1, m2);
1465 | CheckNdArray(m_add,
1466 | "[[0, 2, 4],\n"
1467 | " [3, 5, 7]]");
1468 | CheckNdArray(m_sub,
1469 | "[[0, 0, 0],\n"
1470 | " [3, 3, 3]]");
1471 | CheckNdArray(m_mul,
1472 | "[[0, 1, 4],\n"
1473 | " [0, 4, 10]]");
1474 | ResolveAmbiguous(m_div); // -nan -> nan
1475 | CheckNdArray(m_div,  // column 0 divides by 0: 0/0 -> nan, 3/0 -> inf
1476 | "[[nan, 1, 1],\n"
1477 | " [inf, 4, 2.5]]");
1478 | }
1479 |
1480 | SECTION("Function Arithmetic (NdArray, float)") {  // scalar right operand applied to every element
1481 | auto m1 = NdArray::Arange(3.f);
1482 | auto m_add = Add(m1, 2.f);
1483 | auto m_sub = Subtract(m1, 2.f);
1484 | auto m_mul = Multiply(m1, 2.f);
1485 | auto m_div = Divide(m1, 2.f);
1486 | CheckNdArray(m_add, "[2, 3, 4]");
1487 | CheckNdArray(m_sub, "[-2, -1, 0]");
1488 | CheckNdArray(m_mul, "[0, 2, 4]");
1489 | CheckNdArray(m_div, "[0, 0.5, 1]");
1490 | }
1491 |
1492 | SECTION("Function Arithmetic (float, NdArray)") {  // scalar left operand: Subtract/Divide compute 2 - x and 2 / x
1493 | auto m1 = NdArray::Arange(3.f);
1494 | auto m_add = Add(2.f, m1);
1495 | auto m_sub = Subtract(2.f, m1);
1496 | auto m_mul = Multiply(2.f, m1);
1497 | auto m_div = Divide(2.f, m1);
1498 | CheckNdArray(m_add, "[2, 3, 4]");
1499 | CheckNdArray(m_sub, "[2, 1, 0]");
1500 | CheckNdArray(m_mul, "[0, 2, 4]");
1501 | CheckNdArray(m_div, "[inf, 2, 1]");  // 2/0 -> inf
1502 | }
1503 |
1504 | SECTION("Function Comparison") {  // comparisons yield 0/1 arrays; m2 (shape (1), value 1) broadcasts against m1 = [0, 1, 2]
1505 | auto m1 = NdArray::Arange(3.f);
1506 | auto m2 = NdArray::Zeros(1) + 1.f;
1507 | CheckNdArray(Equal(m1, m2), "[0, 1, 0]");
1508 | CheckNdArray(NotEqual(m1, m2), "[1, 0, 1]");
1509 | CheckNdArray(Greater(m1, m2), "[0, 0, 1]");
1510 | CheckNdArray(GreaterEqual(m1, m2), "[0, 1, 1]");
1511 | CheckNdArray(Less(m1, m2), "[1, 0, 0]");
1512 | CheckNdArray(LessEqual(m1, m2), "[1, 1, 0]");
1513 | CheckNdArray(Equal(m1, 1.f), "[0, 1, 0]");
1514 | CheckNdArray(NotEqual(m1, 1.f), "[1, 0, 1]");
1515 | CheckNdArray(Greater(m1, 1.f), "[0, 0, 1]");
1516 | CheckNdArray(GreaterEqual(m1, 1.f), "[0, 1, 1]");
1517 | CheckNdArray(Less(m1, 1.f), "[1, 0, 0]");
1518 | CheckNdArray(LessEqual(m1, 1.f), "[1, 1, 0]");
1519 | CheckNdArray(Equal(1.f, m1), "[0, 1, 0]");  // scalar on the left: operand order is reversed
1520 | CheckNdArray(NotEqual(1.f, m1), "[1, 0, 1]");
1521 | CheckNdArray(Greater(1.f, m1), "[1, 0, 0]");
1522 | CheckNdArray(GreaterEqual(1.f, m1), "[1, 1, 0]");
1523 | CheckNdArray(Less(1.f, m1), "[0, 0, 1]");
1524 | CheckNdArray(LessEqual(1.f, m1), "[0, 1, 1]");
1525 | }
1526 |
1527 | SECTION("Function Dot") {  // (2,3) . (3) gives (2); with a one-element operand on either side the result is m2 scaled by 2
1528 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
1529 | auto m2 = NdArray::Arange(3.f);
1530 | auto m12 = Dot(m1, m2);
1531 | auto m_a = Dot(m2, {2.f});
1532 | auto m_b = Dot({2.f}, m2);
1533 | CheckNdArray(m12, "[5, 14]");
1534 | CheckNdArray(m_a, "[0, 2, 4]");
1535 | CheckNdArray(m_b, "[0, 2, 4]");
1536 | }
1537 |
1538 | SECTION("Function Cross") {  // cross product of two 3-element vectors
1539 | NdArray m1 = {1.f, 2.f, 3.f};
1540 | NdArray m2 = {4.f, 5.f, 6.f};
1541 | auto m12 = Cross(m1, m2);
1542 | CheckNdArray(m12, "[-3, 6, -3]");
1543 | }
1544 |
1545 | SECTION("Function Basic Math") {  // element-wise math functions on non-in-place (lvalue) inputs
1546 | auto m1 = NdArray::Arange(3.f);
1547 | auto m2 = NdArray::Arange(7.f) / 3.f - 1.f;  // [-1, -2/3, -1/3, 0, 1/3, 2/3, 1]
1548 | CheckNdArray(-m1, "[-0, -1, -2]");
1549 | CheckNdArray(Abs(-m1), "[0, 1, 2]");
1550 | CheckNdArray(Sign(m2), "[-1, -1, -1, 0, 1, 1, 1]");
1551 | CheckNdArray(Ceil(m2), "[-1, -0, -0, 0, 1, 1, 1]");  // ceil of a negative fraction prints as -0
1552 | CheckNdArray(Floor(m2), "[-1, -1, -1, 0, 0, 0, 1]");
1553 | CheckNdArray(Clip(m2, -0.5f, 0.4f),
1554 | "[-0.5, -0.5, -0.333333, 0, 0.333333, 0.4, 0.4]");
1555 | CheckNdArray(Sqrt(m1), "[0, 1, 1.41421]");
1556 | CheckNdArray(Exp(m1), "[1, 2.71828, 7.38906]");
1557 | CheckNdArray(Log(m1), "[-inf, 0, 0.693147]");  // log(0) -> -inf
1558 | CheckNdArray(Square(m1), "[0, 1, 4]");
1559 | CheckNdArray(Power(m1, m1 + 1.f), "[0, 1, 8]");
1560 | CheckNdArray(Power(m1, 4.f), "[0, 1, 16]");
1561 | CheckNdArray(Power(4.f, m1), "[1, 4, 16]");
1562 | }
1563 |
1564 | SECTION("Function Trigonometric") {  // inputs are [0, 1, 2] in radians
1565 | auto m1 = NdArray::Arange(3.f);
1566 | CheckNdArray(Sin(m1), "[0, 0.841471, 0.909297]");
1567 | CheckNdArray(Cos(m1), "[1, 0.540302, -0.416147]");
1568 | CheckNdArray(Tan(m1), "[0, 1.55741, -2.18504]");
1569 | }
1570 |
1571 | SECTION("Function Inverse-Trigonometric") {  // m1 = [-1, 0, 1], m2 = [0, 100, 200]
1572 | auto m1 = NdArray::Arange(3) - 1.f;
1573 | auto m2 = NdArray::Arange(3) * 100.f;
1574 | CheckNdArray(ArcSin(m1), "[-1.5708, 0, 1.5708]");
1575 | CheckNdArray(ArcCos(m1), "[3.14159, 1.5708, 0]");
1576 | CheckNdArray(ArcTan(m2), "[0, 1.5608, 1.5658]");
1577 | CheckNdArray(ArcTan2(m1, m2), "[-1.5708, 0, 0.00499996]");  // first argument is y, second is x: atan2(-1, 0) = -pi/2
1578 | CheckNdArray(ArcTan2(m1, 2.f), "[-0.463648, 0, 0.463648]");
1579 | CheckNdArray(ArcTan2(2.f, m1), "[2.03444, 1.5708, 1.10715]");
1580 | }
1581 |
1582 | SECTION("Function Axis") {  // reductions over several axes at once; negative axes count from the end
1583 | auto m1 = NdArray::Arange(36.f).reshape(2, 3, 2, 3);
1584 | CheckNdArray(Sum(m1, {1, 3}),
1585 | "[[63, 90],\n"
1586 | " [225, 252]]");
1587 | CheckNdArray(Sum(m1, {-1, -3}), // negative axis
1588 | "[[63, 90],\n"
1589 | " [225, 252]]");
1590 | auto m2 = NdArray::Arange(12.f).reshape(2, 3, 2) - 6.f;
1591 | CheckNdArray(Min(m2, {2, 1}), "[-6, 0]");
1592 | CheckNdArray(Min(m2, {-2, -1}), "[-6, 0]"); // negative axis
1593 | CheckNdArray(Max(m2, {2, 1}), "[-1, 5]");
1594 | CheckNdArray(Max(m2, {-2, -1}), "[-1, 5]"); // negative axis
1595 | CheckNdArray(Mean(m2, {2, 1}), "[-3.5, 2.5]");
1596 | CheckNdArray(Mean(m2, {-2, -1}), "[-3.5, 2.5]"); // negative axis
1597 | }
1598 |
1599 | SECTION("Function Shape") {  // m1 has shape (1, 2, 1, 3, 1)
1600 | auto m1 = NdArray::Arange(6.f).reshape(1, 2, 1, 3, 1);
1601 | CHECK(Reshape(m1, {2, 1, 3}).shape() == Shape{2, 1, 3});
1602 | CHECK(Squeeze(m1).shape() == Shape{2, 3});
1603 | CHECK(Squeeze(m1, {0, 2}).shape() == Shape{2, 3, 1});
1604 | CHECK(Squeeze(m1, {-1, -3}).shape() == Shape{1, 2, 3});
1605 | CHECK(ExpandDims(m1, 0).shape() == Shape{1, 1, 2, 1, 3, 1});
1606 | CHECK(ExpandDims(m1, 2).shape() == Shape{1, 2, 1, 1, 3, 1});
1607 | CHECK(ExpandDims(m1, 5).shape() == Shape{1, 2, 1, 3, 1, 1});
1608 | CHECK(ExpandDims(m1, -1).shape() == Shape{1, 2, 1, 3, 1, 1});
1609 | CHECK(ExpandDims(m1.ravel(), -1).shape() == Shape{6, 1});
1610 | CHECK_THROWS(Squeeze(m1, {-2}));  // axis 3 has size 3 and cannot be squeezed
1611 | CHECK_THROWS(Squeeze(m1, {5}));  // axis 5 is out of range for rank 5
1612 | CHECK_THROWS(ExpandDims(m1, 6));  // valid insertion positions are 0..5
1613 | }
1614 |
1615 | SECTION("Function Stack") {  // stacking two (4,3) arrays gives rank 3, so the valid axes are -3..2
1616 | auto m1 = NdArray::Arange(12.f).reshape(4, 3);
1617 | auto m2 = NdArray::Arange(12.f).reshape(4, 3) + 1.f;
1618 | CheckNdArray(Stack({m1, m2}, 0),
1619 | "[[[0, 1, 2],\n"
1620 | " [3, 4, 5],\n"
1621 | " [6, 7, 8],\n"
1622 | " [9, 10, 11]],\n"
1623 | " [[1, 2, 3],\n"
1624 | " [4, 5, 6],\n"
1625 | " [7, 8, 9],\n"
1626 | " [10, 11, 12]]]");
1627 | CheckNdArray(Stack({m1, m2}, 1),
1628 | "[[[0, 1, 2],\n"
1629 | " [1, 2, 3]],\n"
1630 | " [[3, 4, 5],\n"
1631 | " [4, 5, 6]],\n"
1632 | " [[6, 7, 8],\n"
1633 | " [7, 8, 9]],\n"
1634 | " [[9, 10, 11],\n"
1635 | " [10, 11, 12]]]");
1636 | CheckNdArray(Stack({m1, m2}, 2),
1637 | "[[[0, 1],\n"
1638 | " [1, 2],\n"
1639 | " [2, 3]],\n"
1640 | " [[3, 4],\n"
1641 | " [4, 5],\n"
1642 | " [5, 6]],\n"
1643 | " [[6, 7],\n"
1644 | " [7, 8],\n"
1645 | " [8, 9]],\n"
1646 | " [[9, 10],\n"
1647 | " [10, 11],\n"
1648 | " [11, 12]]]");
1649 | CheckNdArray(Stack({m1, m2}, -3), // negative axis
1650 | "[[[0, 1, 2],\n"
1651 | " [3, 4, 5],\n"
1652 | " [6, 7, 8],\n"
1653 | " [9, 10, 11]],\n"
1654 | " [[1, 2, 3],\n"
1655 | " [4, 5, 6],\n"
1656 | " [7, 8, 9],\n"
1657 | " [10, 11, 12]]]");
1658 | CHECK_THROWS(Stack({m1, m2}, 3));
1659 | CHECK_THROWS(Stack({m1, m2}, -4)); // negative axis
1660 | }
1661 |
1662 | SECTION("Function Concatenate") {  // every axis except the concatenation axis must match, otherwise it throws
1663 | auto m1 = NdArray::Arange(12.f).reshape(4, 3);
1664 | auto m2 = NdArray::Arange(6.f).reshape(2, 3) + 1.f;
1665 | CheckNdArray(Concatenate({m1, m2}, 0),
1666 | "[[0, 1, 2],\n"
1667 | " [3, 4, 5],\n"
1668 | " [6, 7, 8],\n"
1669 | " [9, 10, 11],\n"
1670 | " [1, 2, 3],\n"
1671 | " [4, 5, 6]]");
1672 | CHECK_THROWS(Concatenate({m1, m2}, 1));
1673 | CHECK_THROWS(Concatenate({m1, m2}, 2));
1674 | CHECK_THROWS(Concatenate({m1, m2}, -1));
1675 | CHECK_THROWS(Concatenate({m1, m2}, -3));
1676 | auto m3 = NdArray::Arange(8.f).reshape(4, 2) + 1.f;
1677 | CheckNdArray(Concatenate({m1, m3}, 1),
1678 | "[[0, 1, 2, 1, 2],\n"
1679 | " [3, 4, 5, 3, 4],\n"
1680 | " [6, 7, 8, 5, 6],\n"
1681 | " [9, 10, 11, 7, 8]]");
1682 | CHECK_THROWS(Concatenate({m1, m3}, 0));
1683 | CHECK_THROWS(Concatenate({m1, m3}, 2));
1684 | CHECK_THROWS(Concatenate({m1, m3}, -2));
1685 | CHECK_THROWS(Concatenate({m1, m3}, -3));
1686 |
1687 | auto m4 = NdArray::Arange(12.f).reshape(4, 1, 3);
1688 | auto m5 = NdArray::Arange(6.f).reshape(2, 1, 3) + 1.f;
1689 | CheckNdArray(Concatenate({m4, m5}, 0),
1690 | "[[[0, 1, 2]],\n"
1691 | " [[3, 4, 5]],\n"
1692 | " [[6, 7, 8]],\n"
1693 | " [[9, 10, 11]],\n"
1694 | " [[1, 2, 3]],\n"
1695 | " [[4, 5, 6]]]");
1696 | CHECK_THROWS(Concatenate({m4, m5}, 1));
1697 | CHECK_THROWS(Concatenate({m4, m5}, 2));
1698 | CHECK_THROWS(Concatenate({m4, m5}, -1));
1699 | CHECK_THROWS(Concatenate({m4, m5}, -2));
1700 | CHECK_THROWS(Concatenate({m4, m5}, -4));
1701 | auto m6 = NdArray::Arange(8.f).reshape(4, 1, 2) + 1.f;
1702 | CheckNdArray(Concatenate({m4, m6}, 2),
1703 | "[[[0, 1, 2, 1, 2]],\n"
1704 | " [[3, 4, 5, 3, 4]],\n"
1705 | " [[6, 7, 8, 5, 6]],\n"
1706 | " [[9, 10, 11, 7, 8]]]");
1707 | CHECK_THROWS(Concatenate({m4, m6}, 0));
1708 | CHECK_THROWS(Concatenate({m4, m6}, 1));
1709 | CHECK_THROWS(Concatenate({m4, m6}, -2));
1710 | CHECK_THROWS(Concatenate({m4, m6}, -3));
1711 | CHECK_THROWS(Concatenate({m4, m6}, -4));
1712 | }
1713 |
1714 | SECTION("Function Split by indices") {  // split m1 (2,4,2) along axis 1 at the given indices; k indices give k+1 pieces
1715 | auto m1 = NdArray::Arange(16.f).reshape(2, 4, 2);
1716 |
1717 | auto r0 = Split(m1, {1, 1}, 1);  // repeated index -> empty middle piece
1718 | CHECK(r0.size() == 3);
1719 | CheckNdArray(r0[0],
1720 | "[[[0, 1]],\n"
1721 | " [[8, 9]]]");
1722 | CheckNdArray(r0[1], "[]");
1723 | CheckNdArray(r0[2],
1724 | "[[[2, 3],\n"
1725 | " [4, 5],\n"
1726 | " [6, 7]],\n"
1727 | " [[10, 11],\n"
1728 | " [12, 13],\n"
1729 | " [14, 15]]]");
1730 |
1731 | auto r1 = Split(m1, {2, 0}, 1);  // decreasing index -> empty middle piece; last piece spans from 0 to the end
1732 | CHECK(r1.size() == 3);
1733 | CheckNdArray(r1[0],
1734 | "[[[0, 1],\n"
1735 | " [2, 3]],\n"
1736 | " [[8, 9],\n"
1737 | " [10, 11]]]");
1738 | CheckNdArray(r1[1], "[]");
1739 | CheckNdArray(r1[2],
1740 | "[[[0, 1],\n"
1741 | " [2, 3],\n"
1742 | " [4, 5],\n"
1743 | " [6, 7]],\n"
1744 | " [[8, 9],\n"
1745 | " [10, 11],\n"
1746 | " [12, 13],\n"
1747 | " [14, 15]]]");
1748 |
1749 | auto r2 = Split(m1, {0, 2, 3}, 1);  // leading 0 -> empty first piece
1750 | CHECK(r2.size() == 4);
1751 | CheckNdArray(r2[0], "[]");
1752 | CheckNdArray(r2[1],
1753 | "[[[0, 1],\n"
1754 | " [2, 3]],\n"
1755 | " [[8, 9],\n"
1756 | " [10, 11]]]");
1757 | CheckNdArray(r2[2],
1758 | "[[[4, 5]],\n"
1759 | " [[12, 13]]]");
1760 | CheckNdArray(r2[3],
1761 | "[[[6, 7]],\n"
1762 | " [[14, 15]]]");
1763 |
1764 | auto r3 = Split(m1, {2, 4}, 1);  // index equal to the axis size -> empty last piece
1765 | CHECK(r3.size() == 3);
1766 | CheckNdArray(r3[0],
1767 | "[[[0, 1],\n"
1768 | " [2, 3]],\n"
1769 | " [[8, 9],\n"
1770 | " [10, 11]]]");
1771 | CheckNdArray(r3[1],
1772 | "[[[4, 5],\n"
1773 | " [6, 7]],\n"
1774 | " [[12, 13],\n"
1775 | " [14, 15]]]");
1776 | CheckNdArray(r3[2], "[]");
1777 |
1778 | auto r4 = Split(m1, {2, 4}, -2); // negative axis
1779 | CHECK(r4.size() == 3);
1780 | CheckNdArray(r4[0],
1781 | "[[[0, 1],\n"
1782 | " [2, 3]],\n"
1783 | " [[8, 9],\n"
1784 | " [10, 11]]]");
1785 | CheckNdArray(r4[1],
1786 | "[[[4, 5],\n"
1787 | " [6, 7]],\n"
1788 | " [[12, 13],\n"
1789 | " [14, 15]]]");
1790 | CheckNdArray(r4[2], "[]");
1791 | }
1792 |
1793 | SECTION("Function Split by n_section") {  // split m1 (2,4,2) into n equal sections along the given axis
1794 | auto m1 = NdArray::Arange(16.f).reshape(2, 4, 2);
1795 | auto r0 = Split(m1, 2, 1);
1796 | CHECK(r0.size() == 2);
1797 | CheckNdArray(r0[0],
1798 | "[[[0, 1],\n"
1799 | " [2, 3]],\n"
1800 | " [[8, 9],\n"
1801 | " [10, 11]]]");
1802 | CheckNdArray(r0[1],
1803 | "[[[4, 5],\n"
1804 | " [6, 7]],\n"
1805 | " [[12, 13],\n"
1806 | " [14, 15]]]");
1807 |
1808 | auto r1 = Split(m1, 4, 1);
1809 | CHECK(r1.size() == 4);
1810 | CheckNdArray(r1[0],
1811 | "[[[0, 1]],\n"
1812 | " [[8, 9]]]");
1813 | CheckNdArray(r1[1],
1814 | "[[[2, 3]],\n"
1815 | " [[10, 11]]]");
1816 | CheckNdArray(r1[2],
1817 | "[[[4, 5]],\n"
1818 | " [[12, 13]]]");
1819 | CheckNdArray(r1[3],
1820 | "[[[6, 7]],\n"
1821 | " [[14, 15]]]");
1822 |
1823 | auto r2 = Split(m1, 2, 2);
1824 | CHECK(r2.size() == 2);
1825 | CheckNdArray(r2[0],
1826 | "[[[0],\n"
1827 | " [2],\n"
1828 | " [4],\n"
1829 | " [6]],\n"
1830 | " [[8],\n"
1831 | " [10],\n"
1832 | " [12],\n"
1833 | " [14]]]");
1834 | CheckNdArray(r2[1],
1835 | "[[[1],\n"
1836 | " [3],\n"
1837 | " [5],\n"
1838 | " [7]],\n"
1839 | " [[9],\n"
1840 | " [11],\n"
1841 | " [13],\n"
1842 | " [15]]]");
1843 |
1844 | auto r3 = Split(m1, 4, -2);  // negative axis: same as axis 1
1845 | CHECK(r3.size() == 4);
1846 | CheckNdArray(r3[0],
1847 | "[[[0, 1]],\n"
1848 | " [[8, 9]]]");
1849 | CheckNdArray(r3[1],
1850 | "[[[2, 3]],\n"
1851 | " [[10, 11]]]");
1852 | CheckNdArray(r3[2],
1853 | "[[[4, 5]],\n"
1854 | " [[12, 13]]]");
1855 | CheckNdArray(r3[3],
1856 | "[[[6, 7]],\n"
1857 | " [[14, 15]]]");
1858 | }
1859 |
1860 | SECTION("Function Separate") {  // Separate along an axis, then Stack along the same axis, must round-trip to m
1861 | auto m = NdArray::Arange(16.f).reshape(2, 4, 2);
1862 | // Axis 0
1863 | auto r0 = Separate(m, 0);
1864 | CHECK(r0.size() == 2);
1865 | auto m0 = Stack(r0, 0);
1866 | CHECK(All(m == m0));
1867 | // Axis 1
1868 | auto r1 = Separate(m, 1);
1869 | CHECK(r1.size() == 4);
1870 | auto m1 = Stack(r1, 1);
1871 | CHECK(All(m == m1));
1872 | // Axis 2
1873 | auto r2 = Separate(m, 2);
1874 | CHECK(r2.size() == 2);
1875 | auto m2 = Stack(r2, 2);
1876 | CHECK(All(m == m2));
1877 | // Negative axis
1878 | auto r3 = Separate(m, -1);
1879 | CHECK(r3.size() == 2);
1880 | auto m3 = Stack(r3, -1);
1881 | CHECK(All(m == m3));
1882 | }
1883 |
1884 | SECTION("Function Transpose") {  // Transpose reverses the order of all axes
1885 | auto m1 = NdArray::Arange(6.f).reshape(3, 2);
1886 | CheckNdArray(Transpose(m1),
1887 | "[[0, 2, 4],\n"
1888 | " [1, 3, 5]]");
1889 | auto m2 = NdArray::Arange(8).reshape(2, 2, 2);
1890 | CheckNdArray(Transpose(m2),
1891 | "[[[0, 4],\n"
1892 | " [2, 6]],\n"
1893 | " [[1, 5],\n"
1894 | " [3, 7]]]");
1895 | }
1896 |
1897 | SECTION("Function Swapaxes") {  // for rank 2 and 3, swapping the first and last axes equals a full axis reversal
1898 | auto m1 = NdArray::Arange(6.f).reshape(3, 2);
1899 | CheckNdArray(Swapaxes(m1, -1, -2),
1900 | "[[0, 2, 4],\n"
1901 | " [1, 3, 5]]");
1902 | auto m2 = NdArray::Arange(8).reshape(2, 2, 2);
1903 | CheckNdArray(Swapaxes(m2, 0, 2),
1904 | "[[[0, 4],\n"
1905 | " [2, 6]],\n"
1906 | " [[1, 5],\n"
1907 | " [3, 7]]]");
1908 | }
1909 |
1910 | SECTION("Function BroadcastTo") {  // broadcast a (3) array to (2,3) and (2,2,3) by repeating it
1911 | CheckNdArray(BroadcastTo(NdArray::Arange(3.f), {2, 3}),
1912 | "[[0, 1, 2],\n"
1913 | " [0, 1, 2]]");
1914 | CheckNdArray(BroadcastTo(NdArray::Arange(3.f), {2, 2, 3}),
1915 | "[[[0, 1, 2],\n"
1916 | " [0, 1, 2]],\n"
1917 | " [[0, 1, 2],\n"
1918 | " [0, 1, 2]]]");
1919 | }
1920 |
1921 | SECTION("Function Inverse (2d)") {  // m1 = [[1, 2], [3, 4]]; m1 . Inv(m1) must be the identity
1922 | auto m1 = NdArray::Arange(4).reshape(2, 2) + 1.f;
1923 | auto m2 = Inv(m1);
1924 | CheckNdArray(m2,
1925 | "[[-2, 1],\n"
1926 | " [1.5, -0.5]]",
1927 | 4); // Low precision
1928 | CheckNdArray(m1.dot(m2),
1929 | "[[1, 0],\n"
1930 | " [0, 1]]",
1931 | 4); // Low precision
1932 | }
1933 |
1934 | SECTION("Function Inverse (high-dim)") {  // Inv inverts each 2x2 matrix of a (2,3,2,2) batch independently
1935 | auto m1 = NdArray::Arange(24).reshape(2, 3, 2, 2) + 1.f;
1936 | auto m2 = Inv(m1);
1937 | CheckNdArray(m2,
1938 | "[[[[-2, 1],\n"
1939 | " [1.5, -0.5]],\n"
1940 | " [[-4, 3],\n"
1941 | " [3.5, -2.5]],\n"
1942 | " [[-6, 5],\n"
1943 | " [5.5, -4.5]]],\n"
1944 | " [[[-8, 7],\n"
1945 | " [7.5, -6.5]],\n"
1946 | " [[-10, 9],\n"
1947 | " [9.5, -8.5]],\n"
1948 | " [[-12, 11],\n"
1949 | " [11.5, -10.5]]]]",
1950 | 4); // Low precision
1951 | }
1952 |
1953 | // ----------------------- In-place Operator function ----------------------
1954 | SECTION("Function in-place basic") {  // rvalue inputs may be reused as the output buffer, observed through id()
1955 | // In-place
1956 | NdArray m1 = NdArray::Arange(3);
1957 | NdArray m2 = NdArray::Arange(3);
1958 | uintptr_t m1_id = m1.id();
1959 | uintptr_t m2_id = m2.id();
1960 | NdArray m3 = Power(std::move(m1), std::move(m2));
1961 | CHECK((m3.id() == m1_id || m3.id() == m2_id)); // m3 is not new array
1962 | // Not in-place
1963 | CheckNdArrayNotInplace(-NdArray::Arange(3.f), "[0, 1, 2]",
1964 | static_cast(Abs));
1965 | // No matching broadcast
1966 | NdArray m4 = NdArray::Arange(6).reshape(2, 1, 3);  // (2,1,3) with (2,1) broadcasts to (2,2,3), matching neither input
1967 | NdArray m5 = NdArray::Arange(2).reshape(2, 1);
1968 | uintptr_t m4_id = m4.id();
1969 | uintptr_t m5_id = m5.id();
1970 | NdArray m6 = Power(std::move(m4), std::move(m5));
1971 | CHECK((m6.id() != m4_id && m6.id() != m5_id)); // m6 is new array
1972 | }
1973 |
1974 | SECTION("Single (inplace)") {  // Positive/Negative on rvalues; presumably CheckNdArrayInplace also verifies buffer reuse (helper defined elsewhere)
1975 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1, 2]",
1976 | static_cast(Positive));
1977 | CheckNdArrayInplace(NdArray::Arange(3.f), "[-0, -1, -2]",
1978 | static_cast(Negative));
1979 | }
1980 |
1981 | SECTION("Function Arithmetic (NdArray, NdArray) (in-place both)") {  // both operands are temporaries: (2,3) op (3)
1982 | CheckNdArrayInplace(
1983 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f),
1984 | "[[0, 2, 4],\n"
1985 | " [3, 5, 7]]",
1986 | static_cast(Add));
1987 | CheckNdArrayInplace(
1988 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f),
1989 | "[[0, 0, 0],\n"
1990 | " [3, 3, 3]]",
1991 | static_cast(Subtract));
1992 | CheckNdArrayInplace(
1993 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f),
1994 | "[[0, 1, 4],\n"
1995 | " [0, 4, 10]]",
1996 | static_cast(Multiply));
1997 | CheckNdArrayInplace(
1998 | NdArray::Arange(6.f).reshape(2, 3), NdArray::Arange(3.f) + 1.f,  // divisor shifted by 1 to avoid division by zero
1999 | "[[0, 0.5, 0.666667],\n"
2000 | " [3, 2, 1.66667]]",
2001 | static_cast(Divide));
2002 | }
2003 |
2004 | SECTION("Function Arithmetic (NdArray, NdArray) (in-place right)") {  // left operand is an lvalue (shape (3)); right is a (2,3) temporary
2005 | auto m1 = NdArray::Arange(3.f);
2006 | auto m2 = NdArray::Arange(3.f) + 1.f;
2007 | CheckNdArrayInplace(
2008 | m1, NdArray::Arange(6.f).reshape(2, 3),
2009 | "[[0, 2, 4],\n"
2010 | " [3, 5, 7]]",
2011 | static_cast(Add));
2012 | CheckNdArrayInplace(
2013 | m1, NdArray::Arange(6.f).reshape(2, 3),
2014 | "[[0, 0, 0],\n"
2015 | " [-3, -3, -3]]",
2016 | static_cast(Subtract));
2017 | CheckNdArrayInplace(
2018 | m1, NdArray::Arange(6.f).reshape(2, 3),
2019 | "[[0, 1, 4],\n"
2020 | " [0, 4, 10]]",
2021 | static_cast(Multiply));
2022 | CheckNdArrayInplace(
2023 | m2, NdArray::Arange(6.f).reshape(2, 3),  // [1, 2, 3] / [[0..2], [3..5]]: 1/0 -> inf
2024 | "[[inf, 2, 1.5],\n"
2025 | " [0.333333, 0.5, 0.6]]",
2026 | static_cast(Divide));
2027 | }
2028 |
2029 | SECTION("Function Arithmetic (NdArray, NdArray) (in-place left)") {  // left operand is a (2,3) temporary; right is an lvalue (shape (3))
2030 | auto m2 = NdArray::Arange(3.f);
2031 | auto m3 = NdArray::Arange(3.f) + 1.f;
2032 | CheckNdArrayInplace(
2033 | NdArray::Arange(6.f).reshape(2, 3), m2,
2034 | "[[0, 2, 4],\n"
2035 | " [3, 5, 7]]",
2036 | static_cast(Add));
2037 | CheckNdArrayInplace(
2038 | NdArray::Arange(6.f).reshape(2, 3), m2,
2039 | "[[0, 0, 0],\n"
2040 | " [3, 3, 3]]",
2041 | static_cast(Subtract));
2042 | CheckNdArrayInplace(
2043 | NdArray::Arange(6.f).reshape(2, 3), m2,
2044 | "[[0, 1, 4],\n"
2045 | " [0, 4, 10]]",
2046 | static_cast(Multiply));
2047 | CheckNdArrayInplace(
2048 | NdArray::Arange(6.f).reshape(2, 3), m3,  // divisor m3 = [1, 2, 3] avoids division by zero
2049 | "[[0, 0.5, 0.666667],\n"
2050 | " [3, 2, 1.66667]]",
2051 | static_cast(Divide));
2052 | }
2053 |
2054 | SECTION("Function Arithmetic (NdArray, float) (inplace)") {  // rvalue array with a scalar on the right
2055 | CheckNdArrayInplace(NdArray::Arange(3.f), 2.f, "[2, 3, 4]",
2056 | static_cast(Add));
2057 | CheckNdArrayInplace(
2058 | NdArray::Arange(3.f), 2.f, "[-2, -1, 0]",
2059 | static_cast(Subtract));
2060 | CheckNdArrayInplace(
2061 | NdArray::Arange(3.f), 2.f, "[0, 2, 4]",
2062 | static_cast(Multiply));
2063 | CheckNdArrayInplace(NdArray::Arange(3.f), 2.f, "[0, 0.5, 1]",
2064 | static_cast(Divide));
2065 | }
2066 |
2067 | SECTION("Function Arithmetic (float, NdArray) (inplace)") {  // scalar on the left with an rvalue array
2068 | CheckNdArrayInplace(2.f, NdArray::Arange(3.f), "[2, 3, 4]",
2069 | static_cast(Add));
2070 | CheckNdArrayInplace(
2071 | 2.f, NdArray::Arange(3.f), "[2, 1, 0]",
2072 | static_cast(Subtract));
2073 | CheckNdArrayInplace(
2074 | 2.f, NdArray::Arange(3.f), "[0, 2, 4]",
2075 | static_cast(Multiply));
2076 | CheckNdArrayInplace(2.f, NdArray::Arange(3.f), "[inf, 2, 1]",  // 2/0 -> inf
2077 | static_cast(Divide));
2078 | }
2079 |
2080 | SECTION("Function Comparison (NdArray, NdArray) (inplace both)") {  // both operands are temporaries; Zeros(1) + 1 broadcasts against Arange(3)
2081 | CheckNdArrayInplace(
2082 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 0]",
2083 | static_cast(Equal));
2084 | CheckNdArrayInplace(
2085 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 1]",
2086 | static_cast(NotEqual));
2087 | CheckNdArrayInplace(
2088 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 0, 1]",
2089 | static_cast(Greater));
2090 | CheckNdArrayInplace(
2091 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[0, 1, 1]",
2092 | static_cast(GreaterEqual));
2093 | CheckNdArrayInplace(
2094 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 0, 0]",
2095 | static_cast(Less));
2096 | CheckNdArrayInplace(
2097 | NdArray::Arange(3.f), NdArray::Zeros(1) + 1.f, "[1, 1, 0]",
2098 | static_cast(LessEqual));
2099 | }
2100 |
2101 | SECTION("Function Comparison (NdArray, NdArray) (inplace left)") {  // left operand is the rvalue; right is lvalue m2 (shape (1))
2102 | auto m2 = NdArray::Zeros(1) + 1.f;
2103 | CheckNdArrayInplace(
2104 | NdArray::Arange(3.f), m2, "[0, 1, 0]",
2105 | static_cast(Equal));
2106 | CheckNdArrayInplace(
2107 | NdArray::Arange(3.f), m2, "[1, 0, 1]",
2108 | static_cast(NotEqual));
2109 | CheckNdArrayInplace(
2110 | NdArray::Arange(3.f), m2, "[0, 0, 1]",
2111 | static_cast(Greater));
2112 | CheckNdArrayInplace(NdArray::Arange(3.f), m2, "[0, 1, 1]",
2113 | static_cast(
2114 | GreaterEqual));
2115 | CheckNdArrayInplace(
2116 | NdArray::Arange(3.f), m2, "[1, 0, 0]",
2117 | static_cast(Less));
2118 | CheckNdArrayInplace(
2119 | NdArray::Arange(3.f), m2, "[1, 1, 0]",
2120 | static_cast(LessEqual));
2121 | }
2122 |
2123 | SECTION("Function Comparison (NdArray, NdArray) (inplace right)") {  // right operand is the rvalue; left is lvalue m1 (shape (1))
2124 | auto m1 = NdArray::Zeros(1) + 1.f;
2125 | CheckNdArrayInplace(
2126 | m1, NdArray::Arange(3.f), "[0, 1, 0]",
2127 | static_cast(Equal));
2128 | CheckNdArrayInplace(
2129 | m1, NdArray::Arange(3.f), "[1, 0, 1]",
2130 | static_cast(NotEqual));
2131 | CheckNdArrayInplace(
2132 | m1, NdArray::Arange(3.f), "[1, 0, 0]",
2133 | static_cast(Greater));
2134 | CheckNdArrayInplace(m1, NdArray::Arange(3.f), "[1, 1, 0]",
2135 | static_cast(
2136 | GreaterEqual));
2137 | CheckNdArrayInplace(
2138 | m1, NdArray::Arange(3.f), "[0, 0, 1]",
2139 | static_cast(Less));
2140 | CheckNdArrayInplace(
2141 | m1, NdArray::Arange(3.f), "[0, 1, 1]",
2142 | static_cast(LessEqual));
2143 | }
2144 |
2145 | SECTION("Function Comparison (NdArray, float) (inplace)") {  // rvalue array with a scalar on the right
2146 | CheckNdArrayInplace(NdArray::Arange(3.f), 1.f, "[0, 1, 0]",
2147 | static_cast(Equal));
2148 | CheckNdArrayInplace(
2149 | NdArray::Arange(3.f), 1.f, "[1, 0, 1]",
2150 | static_cast(NotEqual));
2151 | CheckNdArrayInplace(
2152 | NdArray::Arange(3.f), 1.f, "[0, 0, 1]",
2153 | static_cast(Greater));
2154 | CheckNdArrayInplace(
2155 | NdArray::Arange(3.f), 1.f, "[0, 1, 1]",
2156 | static_cast(GreaterEqual));
2157 | CheckNdArrayInplace(NdArray::Arange(3.f), 1.f, "[1, 0, 0]",
2158 | static_cast(Less));
2159 | CheckNdArrayInplace(
2160 | NdArray::Arange(3.f), 1.f, "[1, 1, 0]",
2161 | static_cast(LessEqual));
2162 | }
2163 |
2164 | SECTION("Function Comparison (float, NdArray) (inplace)") {  // scalar on the left with an rvalue array on the right
2165 | CheckNdArrayInplace(1.f, NdArray::Arange(3.f), "[0, 1, 0]",
2166 | static_cast(Equal));
2167 | CheckNdArrayInplace(
2168 | 1.f, NdArray::Arange(3.f), "[1, 0, 1]",
2169 | static_cast(NotEqual));
2170 | CheckNdArrayInplace(
2171 | 1.f, NdArray::Arange(3.f), "[1, 0, 0]",
2172 | static_cast(Greater));
2173 | CheckNdArrayInplace(
2174 | 1.f, NdArray::Arange(3.f), "[1, 1, 0]",
2175 | static_cast(GreaterEqual));
2176 | CheckNdArrayInplace(1.f, NdArray::Arange(3.f), "[0, 0, 1]",
2177 | static_cast(Less));
2178 | CheckNdArrayInplace(
2179 | 1.f, NdArray::Arange(3.f), "[0, 1, 1]",
2180 | static_cast(LessEqual));
2181 | }
2182 |
2183 | SECTION("Function Basic Math (in-place)") {  // in-place variants of the element-wise math functions on rvalue inputs
2184 | CheckNdArrayInplace(-NdArray::Arange(3.f), "[0, 1, 2]",
2185 | static_cast(Abs));
2186 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f,
2187 | "[-1, -1, -1, 0, 1, 1, 1]",
2188 | static_cast(Sign));
2189 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f,
2190 | "[-1, -0, -0, 0, 1, 1, 1]",
2191 | static_cast(Ceil));
2192 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f,
2193 | "[-1, -1, -1, 0, 0, 0, 1]",
2194 | static_cast(Floor));
2195 | auto clip_bind = std::bind(  // fixes Clip's bounds to [-0.5, 0.4], leaving a unary callable
2196 | static_cast(Clip),
2197 | std::placeholders::_1, -0.5f, 0.4f);
2198 | CheckNdArrayInplace(NdArray::Arange(7.f) / 3.f - 1.f,
2199 | "[-0.5, -0.5, -0.333333, 0, 0.333333, 0.4, 0.4]",
2200 | clip_bind);
2201 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1, 1.41421]",
2202 | static_cast(Sqrt));
2203 | CheckNdArrayInplace(NdArray::Arange(3.f), "[1, 2.71828, 7.38906]",
2204 | static_cast(Exp));
2205 | CheckNdArrayInplace(NdArray::Arange(3.f), "[-inf, 0, 0.693147]",
2206 | static_cast(Log));
2207 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1, 4]",
2208 | static_cast(Square));
2209 | CheckNdArrayInplace(
2210 | NdArray::Arange(3.f), NdArray::Arange(3.f) + 1.f, "[0, 1, 8]",  // both operands are rvalues
2211 | static_cast(Power));
2212 | const auto m1 = NdArray::Arange(3.f) + 1.f;
2213 | const auto m2 = NdArray::Arange(3.f);
2214 | CheckNdArrayInplace(
2215 | m2, NdArray::Arange(3.f) + 1.f, "[0, 1, 8]",  // lvalue base, rvalue exponent
2216 | static_cast(Power));
2217 | CheckNdArrayInplace(
2218 | NdArray::Arange(3.f), m1, "[0, 1, 8]",  // rvalue base, lvalue exponent
2219 | static_cast(Power));
2220 | CheckNdArrayInplace(NdArray::Arange(3.f), 4.f, "[0, 1, 16]",
2221 | static_cast(Power));
2222 | CheckNdArrayInplace(4.f, NdArray::Arange(3.f), "[1, 4, 16]",
2223 | static_cast(Power));
2224 | }
2225 |
2226 | SECTION("Function Trigonometric (in-place)") {  // sin/cos/tan on rvalue inputs [0, 1, 2] (radians)
2227 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 0.841471, 0.909297]",
2228 | static_cast(Sin));
2229 | CheckNdArrayInplace(NdArray::Arange(3.f), "[1, 0.540302, -0.416147]",
2230 | static_cast(Cos));
2231 | CheckNdArrayInplace(NdArray::Arange(3.f), "[0, 1.55741, -2.18504]",
2232 | static_cast(Tan));
2233 | }
2234 |
2235 | SECTION("Function Inverse-Trigonometric (in-place)") {  // lvalues m1/m2 are used to mix lvalue and rvalue ArcTan2 arguments
2236 | auto m1 = NdArray::Arange(3) - 1.f;
2237 | auto m2 = NdArray::Arange(3) * 100.f;
2238 | CheckNdArrayInplace(NdArray::Arange(3.f) - 1.f, "[-1.5708, 0, 1.5708]",
2239 | static_cast(ArcSin));
2240 | CheckNdArrayInplace(NdArray::Arange(3.f) - 1.f, "[3.14159, 1.5708, 0]",
2241 | static_cast(ArcCos));
2242 | CheckNdArrayInplace(NdArray::Arange(3.f) * 100.f, "[0, 1.5608, 1.5658]",
2243 | static_cast(ArcTan));
2244 | CheckNdArrayInplace(
2245 | NdArray::Arange(3.f) - 1.f, NdArray::Arange(3.f) * 100.f,  // both operands are rvalues
2246 | "[-1.5708, 0, 0.00499996]",
2247 | static_cast(ArcTan2));
2248 | CheckNdArrayInplace(
2249 | m1, NdArray::Arange(3.f) * 100.f, "[-1.5708, 0, 0.00499996]",  // lvalue y, rvalue x
2250 | static_cast(ArcTan2));
2251 | CheckNdArrayInplace(
2252 | NdArray::Arange(3.f) - 1.f, m2, "[-1.5708, 0, 0.00499996]",  // rvalue y, lvalue x
2253 | static_cast(ArcTan2));
2254 | CheckNdArrayInplace(
2255 | NdArray::Arange(3.f) - 1.f, 2.f, "[-0.463648, 0, 0.463648]",
2256 | static_cast(ArcTan2));
2257 | CheckNdArrayInplace(
2258 | 2.f, NdArray::Arange(3.f) - 1.f, "[2.03444, 1.5708, 1.10715]",
2259 | static_cast(ArcTan2));
2260 | }
2261 |
2262 | SECTION("Function Where") {  // Where(cond, x, y) picks x where cond is true, else y; x/y may be broadcast arrays or scalars (no in-place checks here)
2263 | auto m1 = NdArray::Arange(6.f).reshape(2, 3);
2264 | CheckNdArray(Where(2.f < m1, NdArray::Ones(2, 1), NdArray::Arange(3)),  // cond = [[0, 0, 0], [1, 1, 1]]
2265 | "[[0, 1, 2],\n"
2266 | " [1, 1, 1]]");
2267 | CheckNdArray(Where(2.f < m1, 1.f, NdArray::Arange(3)),
2268 | "[[0, 1, 2],\n"
2269 | " [1, 1, 1]]");
2270 | CheckNdArray(Where(2.f < m1, NdArray::Arange(3), 0.f),
2271 | "[[0, 0, 0],\n"
2272 | " [0, 1, 2]]");
2273 | CheckNdArray(Where(2.f < m1, 1.f, 0.f),
2274 | "[[0, 0, 0],\n"
2275 | " [1, 1, 1]]");
2276 | }
2277 |
2278 | SECTION("Function Inverse (in-place)") {  // Inv on an rvalue [[1, 2], [3, 4]]; lvalue Inv is covered by "Function Inverse (2d)"
2281 | CheckNdArrayInplace(NdArray::Arange(4).reshape(2, 2) + 1.f,
2282 | "[[-2, 1],\n"
2283 | " [1.5, -0.5]]",
2284 | static_cast(Inv),
2285 | 4); // Low precision
2286 | }
2287 | }
2288 |
--------------------------------------------------------------------------------
/tests/test_dot_major.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include