├── .gitignore
├── LICENSE
├── README.md
├── include
│   ├── array.h
│   ├── metal.h
│   └── mtlcpp.h
├── metal-cpp_macOS14.2_iOS17.2
│   └── Metal
│       └── Metal.hpp
└── test
    ├── Makefile
    ├── bench.cpp
    ├── doctest.h
    ├── mnist.cpp
    ├── nanobench.h
    ├── sample_weight.csv
    ├── t10k-images-idx3-ubyte
    ├── t10k-labels-idx1-ubyte
    ├── test.cpp
    ├── test_2lnn.cpp
    ├── test_array.cpp
    ├── test_examples.cpp
    └── test_perceptron.cpp

/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | .DS_Store
3 | metal-cpp/*
4 | test/test
5 | test/bench
6 | test/mnist
7 | misc/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 yhirose
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | mtlcpp 2 | ====== 3 | 4 | A header-only C++20 linear algebra library for Metal on MacOS 5 | 6 | * This project is still in development and is far from reaching the first alpha version :) 7 | * Data types supported in this library are `int` and `float` only, since Metal doesn't support `double` 8 | * This library uses GPU cores in Apple M1 chip with [Metal-cpp](https://developer.apple.com/metal/cpp/) 9 | 10 | Build and run unit tests and benchmark 11 | -------------------------------------- 12 | 13 | * Install Xcode Command Line Tools 14 | * Run the following commands in Terminal 15 | 16 | ```bash 17 | cd test 18 | make 19 | ``` 20 | 21 | Benchmark as of 3/2/2024 on M1 MacBook Pro 14 22 | --------------------------------------------- 23 | 24 | | ns/op | op/s | benchmark 25 | |--------------------:|--------------------:|:---------- 26 | | 150,856,709.00 | 6.63 | CPU: `a + b` 27 | | 2,262,442.07 | 442.00 | GPU: `a + b` 28 | | 1,351,401.59 | 739.97 | Eigen: `a + b` 29 | | 964,220,500.00 | 1.04 | CPU: `a.dot(b)` 30 | | 1,094,602.35 | 913.57 | GPU: `a.dot(b)` 31 | | 3,002,299.36 | 333.08 | Eigen: `a * b` 32 | 33 | ```cpp 34 | // test/bench.cpp 35 | 36 | // `add` benchmark 37 | const size_t n = 10'000'000; 38 | 39 | auto a = mtl::ones({n}); 40 | auto b = mtl::ones({n}); 41 | auto c = mtl::array(); 42 | 43 | mtl::use_cpu(); 44 | Bench().run("CPU: a + b", [&] { 45 | c = a + b; 46 | }); 47 | 48 | mtl::use_gpu(); 49 | Bench().run("GPU: a + b", [&] { 50 | c = a + b; 51 | }); 52 | 53 | auto aa = Eigen::Vector::Ones(n); 54 | auto bb = Eigen::Vector::Ones(n); 55 | auto cc = Eigen::Vector(n); 56 | 57 | Bench().run("Eigen: a + b", [&] { 58 | cc = aa + bb; 59 | }); 60 | 61 | // `dot` benchmark 62 | auto a = mtl::ones({1000, 1000}); 63 | auto b = mtl::ones({1000, 100}); 64 | auto c = mtl::array(); 65 | 66 | mtl::use_cpu(); 67 | Bench().run("CPU: a.dot(b)", [&] { 68 | c = a.dot(b); 69 | }); 70 | 71 | mtl::use_gpu(); 72 | Bench().run("GPU: a.dot(b)", [&] { 73 | c = a.dot(b); 74 | }); 75 | 76 | auto aa = Eigen::Matrix::Ones(1000, 1000); 77 | auto bb = Eigen::Matrix::Ones(1000, 100); 78 | auto cc = Eigen::Matrix(); 79 | 80 | Bench().run("Eigen: a * b", [&] { 81 | cc = aa * bb; 82 | }); 83 | ``` 84 | 85 | Operations 86 | ---------- 87 | 88 | ### GPU and CPU 89 | 90 | * `+` (add) 91 | * `-` (sub) 92 | * `*` (mul) 93 | * `/` (div) 94 | * `dot` (dot product) 95 | 96 | ### CPU only 97 | 98 | * `==` 99 | * `clone` 100 | * `constants` 101 | * `empty` 102 | * `zeros` 103 | * `ones` 104 | * `random` 105 | * `transpose` 106 | * `sigmoid` 107 | * `sum` 108 | * `mean` 109 | * `min` 110 | * `max` 111 | * `count` 112 | * `all` 113 | * `softmax` 114 | * `argmax` 115 | * `array_equal` 116 | * `allclose` 117 | 118 | License 119 | ------- 120 | 121 | MIT license (© 2024 Yuji Hirose) 122 | -------------------------------------------------------------------------------- /include/array.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace mtl { 12 | 13 | using shape_type = std::vector; 14 | using strides_type = shape_type; 15 | 16 | //------------------------------------------------------------------------------ 17 | 18 | template 19 | class array { 20 | public: 21 | array() = default; 22 | array(array 
&&rhs) = default; 23 | array(const array &rhs) = default; 24 | array &operator=(const array &rhs) = default; 25 | 26 | array(const shape_type &shape, T val); 27 | array(const shape_type &shape, std::input_iterator auto it); 28 | array(const shape_type &shape, std::ranges::input_range auto &&r); 29 | array(std::ranges::input_range auto &&r); 30 | array(T val); 31 | 32 | array(nested_initializer_list l); 33 | array(nested_initializer_list l); 34 | array(nested_initializer_list l); 35 | array(nested_initializer_list l); 36 | 37 | //---------------------------------------------------------------------------- 38 | 39 | template 40 | array clone() const; 41 | 42 | //---------------------------------------------------------------------------- 43 | 44 | array operator==(const array &rhs) const; 45 | array operator!=(const array &rhs) const; 46 | array operator>(const array &rhs) const; 47 | array operator<(const array &rhs) const; 48 | array operator>=(const array &rhs) const; 49 | array operator<=(const array &rhs) const; 50 | 51 | //---------------------------------------------------------------------------- 52 | 53 | size_t buffer_element_count() const; 54 | size_t buffer_bytes() const; 55 | 56 | T *buffer_data(); 57 | const T *buffer_data() const; 58 | 59 | //---------------------------------------------------------------------------- 60 | 61 | size_t element_count() const; 62 | size_t length() const; 63 | 64 | size_t dimension() const; 65 | const shape_type &shape() const; 66 | const strides_type &strides() const; 67 | 68 | void reshape(const shape_type &shape); 69 | 70 | // TODO: can return a reference for performance? 71 | const auto broadcast(const shape_type &target_shape) const; 72 | 73 | array transpose() const; 74 | 75 | //---------------------------------------------------------------------------- 76 | 77 | T at() const; 78 | T &at(); 79 | 80 | T at(size_t i) const; 81 | T &at(size_t i); 82 | 83 | T at(size_t x, size_t y) const; 84 | T &at(size_t x, size_t y); 85 | 86 | T at(size_t x, size_t y, size_t z) const; 87 | T &at(size_t x, size_t y, size_t z); 88 | 89 | T at(const std::vector &position) const; 90 | T &at(const std::vector &position); 91 | 92 | template 93 | auto take() const; 94 | 95 | //---------------------------------------------------------------------------- 96 | 97 | array operator[](size_t row) const; 98 | 99 | //---------------------------------------------------------------------------- 100 | 101 | auto element_begin(); 102 | auto element_end(); 103 | 104 | auto element_cbegin() const; 105 | auto element_cend() const; 106 | 107 | auto elements(); 108 | auto elements() const; 109 | 110 | //---------------------------------------------------------------------------- 111 | 112 | auto begin(); 113 | auto end(); 114 | 115 | auto begin() const; 116 | auto end() const; 117 | 118 | auto cbegin() const; 119 | auto cend() const; 120 | 121 | template 122 | auto rows(); 123 | 124 | template 125 | auto rows() const; 126 | 127 | //---------------------------------------------------------------------------- 128 | 129 | void set(std::input_iterator auto it); 130 | void set(std::initializer_list l); 131 | 132 | //---------------------------------------------------------------------------- 133 | 134 | void constants(T val); 135 | void zeros(); 136 | void ones(); 137 | void random(); 138 | 139 | //---------------------------------------------------------------------------- 140 | 141 | array operator+(const array &rhs) const; 142 | array operator-(const array &rhs) const; 143 | array 
operator*(const array &rhs) const; 144 | array operator/(const array &rhs) const; 145 | 146 | array pow(const array &rhs) const; 147 | 148 | void operator+=(const array &rhs); 149 | void operator-=(const array &rhs); 150 | void operator*=(const array &rhs); 151 | void operator/=(const array &rhs); 152 | 153 | //---------------------------------------------------------------------------- 154 | 155 | array dot(const array &rhs) const; 156 | 157 | array linear(const array& W, const array& b) const; 158 | 159 | 160 | //---------------------------------------------------------------------------- 161 | 162 | array sigmoid() const; 163 | 164 | //---------------------------------------------------------------------------- 165 | 166 | T sum() const; 167 | array sum(size_t axis) const; 168 | 169 | float mean() const; 170 | array mean(size_t axis) const; 171 | 172 | T min() const; 173 | T max() const; 174 | 175 | size_t count() const; 176 | 177 | bool all(arithmetic auto val) const; 178 | template 179 | bool all(U fn) const; 180 | 181 | array softmax() const; 182 | auto argmax() const; 183 | 184 | float mean_square_error(const array &rhs) const; 185 | 186 | //---------------------------------------------------------------------------- 187 | 188 | std::string print_shape_type(const shape_type &shape) const; 189 | std::string print_shape() const; 190 | std::string print_strides() const; 191 | std::string print_data_type() const; 192 | std::string print_info() const; 193 | std::string print_array() const; 194 | 195 | private: 196 | shape_type shape_; 197 | strides_type strides_; 198 | metal::storage storage_; 199 | 200 | //---------------------------------------------------------------------------- 201 | 202 | void allocate_buffer_(); 203 | 204 | //---------------------------------------------------------------------------- 205 | 206 | void copy_initializer_list_(const auto &l); 207 | 208 | //---------------------------------------------------------------------------- 209 | 210 | void bounds_check_(size_t i) const; 211 | void bounds_check_(size_t x, size_t y) const; 212 | void bounds_check_(size_t x, size_t y, size_t z) const; 213 | 214 | //---------------------------------------------------------------------------- 215 | 216 | template 217 | void enumerate_position_(size_t shape_index, std::vector &position, 218 | U fn) const; 219 | template 220 | void enumerate_position_(U fn) const; 221 | 222 | //---------------------------------------------------------------------------- 223 | 224 | static auto broadcast_(const array &lhs, const array &rhs, auto cb); 225 | 226 | template 227 | array apply_binary_operation_(const array &rhs, auto ope) const; 228 | 229 | //---------------------------------------------------------------------------- 230 | 231 | enum class ArithmeticOperation { 232 | Add = 0, 233 | Sub, 234 | Mul, 235 | Div, 236 | Pow, 237 | }; 238 | 239 | static auto gpu_arithmetic_operation_(const array &lhs, const array &rhs, 240 | ArithmeticOperation ope); 241 | static auto cpu_arithmetic_operation_(const array &lhs, const array &rhs, 242 | ArithmeticOperation ope); 243 | 244 | static auto arithmetic_operation_(const array &lhs, const array &rhs, 245 | ArithmeticOperation ope); 246 | 247 | //---------------------------------------------------------------------------- 248 | 249 | static array cpu_dot_operation_(const array &lhs, const array &rhs); 250 | static array gpu_dot_operation_(const array &lhs, const array &rhs); 251 | template 252 | array dot_operation_(const array &rhs, U fn) const; 253 | 
}; 254 | 255 | //---------------------------------------------------------------------------- 256 | 257 | template 258 | std::ostream &operator<<(std::ostream &os, const array &arr); 259 | 260 | //---------------------------------------------------------------------------- 261 | 262 | template 263 | array operator+(auto lhs, const array &rhs); 264 | 265 | template 266 | array operator-(auto lhs, const array &rhs); 267 | 268 | template 269 | array operator*(auto lhs, const array &rhs); 270 | 271 | template 272 | array operator/(auto lhs, const array &rhs); 273 | 274 | //---------------------------------------------------------------------------- 275 | 276 | template 277 | array where(const array &cond, T x, T y); 278 | 279 | //---------------------------------------------------------------------------- 280 | 281 | template 282 | bool array_equal(const array &a, const array &b); 283 | 284 | template 285 | bool is_close(T a, U b, float tolerance = 1e-3); 286 | 287 | template 288 | bool is_close(T a, U b); 289 | 290 | template 291 | bool allclose(const array &a, const array &b, float tolerance = 1e-3); 292 | 293 | //---------------------------------------------------------------------------- 294 | 295 | template 296 | auto empty(const shape_type &shape); 297 | 298 | template 299 | auto zeros(const shape_type &shape); 300 | 301 | template 302 | auto ones(const shape_type &shape); 303 | 304 | auto random(const shape_type &shape); 305 | 306 | //----------------------------------------------------------------------------- 307 | // Implementation 308 | //----------------------------------------------------------------------------- 309 | 310 | template 311 | inline array::array(const shape_type &shape, T val) { 312 | reshape(shape); 313 | allocate_buffer_(); 314 | constants(val); 315 | } 316 | 317 | template 318 | inline array::array(const shape_type &shape, std::input_iterator auto it) { 319 | reshape(shape); 320 | allocate_buffer_(); 321 | set(it); 322 | } 323 | 324 | template 325 | inline array::array(const shape_type &shape, 326 | std::ranges::input_range auto &&r) { 327 | reshape(shape); 328 | allocate_buffer_(); 329 | set(r.begin()); 330 | } 331 | 332 | template 333 | inline array::array(std::ranges::input_range auto &&r) { 334 | size_t element_count = std::ranges::distance(r); 335 | reshape({element_count}); 336 | allocate_buffer_(); 337 | set(r.begin()); 338 | } 339 | 340 | template 341 | inline array::array(T val) : array(shape_type({}), T{}) { 342 | *buffer_data() = val; 343 | } 344 | 345 | template 346 | struct depth_ { 347 | static constexpr size_t value = 0; 348 | }; 349 | 350 | template 351 | struct depth_> { 352 | static constexpr size_t value = 1 + depth_::value; 353 | }; 354 | 355 | template 356 | struct shape_value_ { 357 | template 358 | static constexpr size_t value(T l) { 359 | return l.size() == 0 ? 
0 : shape_value_::value(*l.begin()); 360 | } 361 | }; 362 | 363 | template <> 364 | struct shape_value_<0> { 365 | template 366 | static constexpr size_t value(T l) { 367 | return l.size(); 368 | } 369 | }; 370 | 371 | template 372 | constexpr shape_type shape_(T l, std::index_sequence) { 373 | return {shape_type::value_type(shape_value_::value(l))...}; 374 | } 375 | 376 | template 377 | constexpr size_t nested_initializer_list_dimension_() { 378 | return depth_::value; 379 | }; 380 | 381 | template 382 | constexpr shape_type nested_initializer_list_shape_(T l) { 383 | return shape_( 384 | l, std::make_index_sequence()>()); 385 | } 386 | 387 | template 388 | inline array::array(nested_initializer_list l) 389 | : array(nested_initializer_list_shape_(l), T{}) { 390 | copy_initializer_list_(l); 391 | } 392 | 393 | template 394 | inline array::array(nested_initializer_list l) 395 | : array(nested_initializer_list_shape_(l), T{}) { 396 | copy_initializer_list_(l); 397 | } 398 | 399 | template 400 | inline array::array(nested_initializer_list l) 401 | : array(nested_initializer_list_shape_(l), T{}) { 402 | copy_initializer_list_(l); 403 | } 404 | 405 | template 406 | inline array::array(nested_initializer_list l) 407 | : array(nested_initializer_list_shape_(l), T{}) { 408 | copy_initializer_list_(l); 409 | } 410 | 411 | //---------------------------------------------------------------------------- 412 | 413 | template 414 | template 415 | inline array array::clone() const { 416 | auto tmp = array(shape_, U{}); 417 | for (size_t i = 0; i < element_count(); i++) { 418 | tmp.at(i) = static_cast(at(i)); 419 | } 420 | return tmp; 421 | } 422 | 423 | //---------------------------------------------------------------------------- 424 | 425 | template 426 | inline array array::operator==(const array &rhs) const { 427 | return apply_binary_operation_( 428 | rhs, [](auto lhs, auto rhs) { return lhs == rhs; }); 429 | } 430 | 431 | template 432 | inline array array::operator!=(const array &rhs) const { 433 | return apply_binary_operation_( 434 | rhs, [](auto lhs, auto rhs) { return lhs != rhs; }); 435 | } 436 | 437 | template 438 | inline array array::operator>(const array &rhs) const { 439 | return apply_binary_operation_( 440 | rhs, [](auto lhs, auto rhs) { return lhs > rhs; }); 441 | } 442 | 443 | template 444 | inline array array::operator>=(const array &rhs) const { 445 | return apply_binary_operation_( 446 | rhs, [](auto lhs, auto rhs) { return lhs >= rhs; }); 447 | } 448 | 449 | template 450 | inline array array::operator<(const array &rhs) const { 451 | return apply_binary_operation_( 452 | rhs, [](auto lhs, auto rhs) { return lhs < rhs; }); 453 | } 454 | 455 | template 456 | inline array array::operator<=(const array &rhs) const { 457 | return apply_binary_operation_( 458 | rhs, [](auto lhs, auto rhs) { return lhs <= rhs; }); 459 | } 460 | 461 | //---------------------------------------------------------------------------- 462 | 463 | template 464 | inline size_t array::buffer_element_count() const { 465 | return storage_.len; 466 | } 467 | 468 | template 469 | inline size_t array::buffer_bytes() const { 470 | return storage_.len * sizeof(T); 471 | } 472 | 473 | template 474 | inline T *array::buffer_data() { 475 | return static_cast(storage_.buf->contents()) + storage_.off; 476 | } 477 | 478 | template 479 | inline const T *array::buffer_data() const { 480 | return static_cast(storage_.buf->contents()) + storage_.off; 481 | } 482 | 483 | 
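// Illustrative usage sketch (not part of the original header; template
// arguments such as array<T> are stripped elsewhere in this dump, so the
// exact syntax below is an assumption based on the declarations above): the
// nested_initializer_list constructors deduce the shape from the brace
// nesting depth, and clone() copies the elements into a fresh buffer,
// optionally casting them to another element type.
//
//   mtl::array<float> m = {{1, 2, 3}, {4, 5, 6}};  // shape {2, 3}
//   assert(m.shape() == mtl::shape_type({2, 3}));
//   auto mi = m.clone<int>();                      // element-wise cast copy
//   assert((m == m.clone()).all(true));            // compare, then all()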
//---------------------------------------------------------------------------- 484 | 485 | template 486 | inline size_t array::element_count() const { 487 | // TODO: cache 488 | size_t count = 1; 489 | for (auto n : shape_) { 490 | count *= n; 491 | } 492 | return count; 493 | } 494 | 495 | template 496 | inline size_t array::length() const { 497 | if (shape_.empty()) { 498 | throw std::runtime_error("array: cannot call with a scalar value."); 499 | } 500 | return shape_[0]; 501 | } 502 | 503 | template 504 | inline size_t array::dimension() const { 505 | return shape_.size(); 506 | } 507 | 508 | template 509 | inline const shape_type &array::shape() const { 510 | return shape_; 511 | } 512 | 513 | template 514 | inline const strides_type &array::strides() const { 515 | return strides_; 516 | } 517 | 518 | template 519 | inline void array::reshape(const shape_type &shape) { 520 | // TODO: check the shape 521 | shape_ = shape; 522 | 523 | // strides 524 | strides_.clear(); 525 | strides_.push_back(1); 526 | if (!strides_.empty()) { 527 | for (int i = shape.size() - 1; i > 0; i--) { 528 | auto n = strides_.front() * shape[i]; 529 | strides_.insert(strides_.begin(), n); 530 | } 531 | } 532 | } 533 | 534 | template 535 | inline const auto array::broadcast(const shape_type &target_shape) const { 536 | if (target_shape.size() < dimension()) { 537 | throw std::runtime_error("array: invalid shape for broadcast."); 538 | } else if (target_shape.size() == dimension()) { 539 | return *this; 540 | } 541 | 542 | auto diff = target_shape.size() - dimension(); 543 | for (size_t i = 0; i < dimension(); i++) { 544 | if (shape_[i] != target_shape[i + diff]) { 545 | throw std::runtime_error("array: invalid shape for broadcast."); 546 | } 547 | } 548 | 549 | array tmp = *this; 550 | tmp.shape_ = target_shape; 551 | 552 | // strides 553 | tmp.strides_.clear(); 554 | tmp.strides_.push_back(1); 555 | if (!strides_.empty()) { 556 | for (int i = target_shape.size() - 1; i > 0; i--) { 557 | auto n = i <= diff ? 
0 : tmp.strides_.front() * target_shape[i]; 558 | tmp.strides_.insert(tmp.strides_.begin(), n); 559 | } 560 | } 561 | return tmp; 562 | } 563 | 564 | template 565 | inline array array::transpose() const { 566 | if (dimension() == 1) { 567 | auto tmp = clone(); 568 | tmp.reshape({1, element_count()}); 569 | 570 | auto it = element_cbegin(); 571 | for (size_t col = 0; col < element_count(); col++) { 572 | tmp.at(0, col) = *it; 573 | ++it; 574 | } 575 | return tmp; 576 | } 577 | 578 | if (dimension() == 2) { 579 | if (shape_[0] == 1) { 580 | auto tmp = clone(); 581 | tmp.reshape({element_count()}); 582 | 583 | auto it = element_cbegin(); 584 | for (size_t row = 0; row < element_count(); row++) { 585 | tmp.at(row) = *it; 586 | ++it; 587 | } 588 | return tmp; 589 | } else { 590 | auto shape = shape_; 591 | std::ranges::reverse(shape); 592 | 593 | auto tmp = clone(); 594 | tmp.reshape(shape); 595 | 596 | auto it = element_cbegin(); 597 | for (size_t col = 0; col < shape[1]; col++) { 598 | for (size_t row = 0; row < shape[0]; row++) { 599 | tmp.at(row, col) = *it; 600 | ++it; 601 | } 602 | } 603 | return tmp; 604 | } 605 | } 606 | 607 | if (dimension() == 3) { 608 | auto shape = shape_; 609 | std::ranges::reverse(shape); 610 | 611 | auto tmp = clone(); 612 | tmp.reshape(shape); 613 | 614 | auto it = element_cbegin(); 615 | for (size_t z = 0; z < shape[2]; z++) { 616 | for (size_t y = 0; y < shape[1]; y++) { 617 | for (size_t x = 0; x < shape[0]; x++) { 618 | tmp.at(x, y, z) = *it; 619 | ++it; 620 | } 621 | } 622 | } 623 | return tmp; 624 | } 625 | 626 | throw std::runtime_error("array: can't do `transpose` operation."); 627 | } 628 | 629 | //---------------------------------------------------------------------------- 630 | 631 | template 632 | inline T array::at() const { 633 | return *buffer_data(); 634 | } 635 | 636 | template 637 | inline T &array::at() { 638 | return *buffer_data(); 639 | } 640 | 641 | template 642 | inline T array::at(size_t i) const { 643 | bounds_check_(i); 644 | return buffer_data()[i % buffer_element_count()]; 645 | } 646 | 647 | template 648 | inline T &array::at(size_t i) { 649 | bounds_check_(i); 650 | return buffer_data()[i % buffer_element_count()]; 651 | } 652 | 653 | template 654 | inline T array::at(size_t x, size_t y) const { 655 | bounds_check_(x, y); 656 | return buffer_data()[strides_[0] * x + y]; 657 | } 658 | 659 | template 660 | inline T &array::at(size_t x, size_t y) { 661 | bounds_check_(x, y); 662 | return buffer_data()[strides_[0] * x + y]; 663 | } 664 | 665 | template 666 | inline T array::at(size_t x, size_t y, size_t z) const { 667 | bounds_check_(x, y, z); 668 | return buffer_data()[(strides_[0] * x) + (strides_[1] * y) + z]; 669 | } 670 | 671 | template 672 | inline T &array::at(size_t x, size_t y, size_t z) { 673 | bounds_check_(x, y, z); 674 | return buffer_data()[(strides_[0] * x) + (strides_[1] * y) + z]; 675 | } 676 | 677 | template 678 | inline T array::at(const std::vector &position) const { 679 | // TODO: bounds_check_(position); 680 | size_t buffer_index = 0; 681 | for (size_t i = 0; i < position.size(); i++) { 682 | buffer_index += strides_[i] * position[i]; 683 | } 684 | return buffer_data()[buffer_index]; 685 | } 686 | 687 | template 688 | inline T &array::at(const std::vector &position) { 689 | // TODO: bounds_check_(position); 690 | size_t buffer_index = 0; 691 | for (size_t i = 0; i < position.size(); i++) { 692 | buffer_index += strides_[i] * position[i]; 693 | } 694 | return buffer_data()[buffer_index]; 695 | } 696 | 697 | 
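// Illustrative note (not part of the original header): reshape() builds
// row-major strides, which is what the at() overloads above use to map an
// index tuple to a flat buffer offset. For a shape of {2, 3, 4} the strides
// come out as {12, 4, 1}, so
//
//   at(x, y, z) == buffer_data()[12 * x + 4 * y + z]
//
// which is exactly the (strides_[0] * x) + (strides_[1] * y) + z expression
// evaluated for that shape.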
template 698 | template 699 | inline auto array::take() const { 700 | if constexpr (I == 0) { 701 | return std::tuple<>(); 702 | } else { 703 | auto t = take(); 704 | return std::tuple_cat(t, std::tuple(at(I - 1))); 705 | } 706 | } 707 | 708 | //---------------------------------------------------------------------------- 709 | 710 | template 711 | inline array array::operator[](size_t row) const { 712 | if (dimension() == 0 || row >= shape_[0]) { 713 | throw std::runtime_error("array: row is out of bounds."); 714 | } 715 | 716 | array tmp(*this); 717 | 718 | auto s = shape(); 719 | s.erase(s.begin()); 720 | tmp.reshape(s); 721 | 722 | auto stride = strides_[0]; 723 | tmp.storage_.off = storage_.off + stride * row; 724 | tmp.storage_.len = stride; 725 | return tmp; 726 | } 727 | 728 | //---------------------------------------------------------------------------- 729 | 730 | template 731 | class element_iterator { 732 | public: 733 | using difference_type = std::ptrdiff_t; 734 | using reference = T &; 735 | using iterator_concept = std::forward_iterator_tag; 736 | 737 | element_iterator(array *arr, size_t i) : arr_(arr), i_(i) {} 738 | 739 | element_iterator &operator++() { 740 | ++i_; 741 | return *this; 742 | } 743 | 744 | element_iterator operator++(int) { 745 | auto tmp = *this; 746 | ++(*this); 747 | return tmp; 748 | } 749 | 750 | reference &operator*() { return arr_->at(i_); } 751 | 752 | friend bool operator==(const element_iterator &a, const element_iterator &b) { 753 | return a.i_ == b.i_; 754 | }; 755 | 756 | private: 757 | array *arr_ = nullptr; 758 | size_t i_ = 0; 759 | }; 760 | 761 | template 762 | class const_element_iterator { 763 | public: 764 | using difference_type = std::ptrdiff_t; 765 | using value_type = T; 766 | using iterator_concept = std::forward_iterator_tag; 767 | 768 | const_element_iterator(const array *arr, size_t i) : arr_(arr), i_(i) {} 769 | 770 | const_element_iterator &operator++() { 771 | ++i_; 772 | return *this; 773 | } 774 | 775 | const_element_iterator operator++(int) { 776 | auto tmp = *this; 777 | ++(*this); 778 | return tmp; 779 | } 780 | 781 | value_type operator*() const { return arr_->at(i_); } 782 | 783 | friend bool operator==(const const_element_iterator &a, 784 | const const_element_iterator &b) { 785 | return a.i_ == b.i_; 786 | }; 787 | 788 | private: 789 | const array *arr_ = nullptr; 790 | size_t i_ = 0; 791 | }; 792 | 793 | template 794 | struct element_range { 795 | element_range(array *arr) : arr_(arr) {} 796 | auto begin() { return element_iterator(arr_, 0); } 797 | auto end() { return element_iterator(arr_, arr_->element_count()); } 798 | array *arr_ = nullptr; 799 | }; 800 | 801 | template 802 | struct const_element_range { 803 | const_element_range(const array *arr) : arr_(arr) {} 804 | auto begin() { return const_element_iterator(arr_, 0); } 805 | auto end() { return const_element_iterator(arr_, arr_->element_count()); } 806 | auto cbegin() const { return const_element_iterator(arr_, 0); } 807 | auto cend() const { 808 | return const_element_iterator(arr_, arr_->element_count()); 809 | } 810 | const array *arr_ = nullptr; 811 | }; 812 | 813 | template 814 | inline auto array::element_begin() { 815 | return element_iterator(this, 0); 816 | } 817 | 818 | template 819 | inline auto array::element_end() { 820 | return element_iterator(this, element_count()); 821 | } 822 | 823 | template 824 | inline auto array::element_cbegin() const { 825 | return const_element_iterator(this, 0); 826 | } 827 | 828 | template 829 | inline auto 
array::element_cend() const { 830 | return const_element_iterator(this, element_count()); 831 | } 832 | 833 | template 834 | inline auto array::elements() { 835 | return element_range(this); 836 | } 837 | 838 | template 839 | inline auto array::elements() const { 840 | return const_element_range(this); 841 | } 842 | 843 | //---------------------------------------------------------------------------- 844 | 845 | template 846 | class row_iterator { 847 | public: 848 | using difference_type = std::ptrdiff_t; 849 | using value_type = array; 850 | using iterator_concept = std::forward_iterator_tag; 851 | 852 | row_iterator(array *arr, size_t i) : arr_(arr), i_(i) {} 853 | 854 | row_iterator &operator++() { 855 | ++i_; 856 | return *this; 857 | } 858 | 859 | row_iterator operator++(int) { 860 | auto tmp = *this; 861 | ++(*this); 862 | return tmp; 863 | } 864 | 865 | value_type operator*() { return (*arr_)[i_]; } 866 | 867 | friend bool operator==(const row_iterator &a, const row_iterator &b) { 868 | return a.i_ == b.i_; 869 | }; 870 | 871 | private: 872 | array *arr_ = nullptr; 873 | size_t i_ = 0; 874 | }; 875 | 876 | template 877 | class const_row_iterator { 878 | public: 879 | using difference_type = std::ptrdiff_t; 880 | using value_type = array; 881 | using iterator_concept = std::forward_iterator_tag; 882 | 883 | const_row_iterator(const array *arr, size_t i) : arr_(arr), i_(i) {} 884 | 885 | const_row_iterator &operator++() { 886 | ++i_; 887 | return *this; 888 | } 889 | 890 | const_row_iterator operator++(int) { 891 | auto tmp = *this; 892 | ++(*this); 893 | return tmp; 894 | } 895 | 896 | value_type operator*() const { return (*arr_)[i_]; } 897 | 898 | friend bool operator==(const const_row_iterator &a, 899 | const const_row_iterator &b) { 900 | return a.i_ == b.i_; 901 | }; 902 | 903 | private: 904 | const array *arr_ = nullptr; 905 | size_t i_ = 0; 906 | }; 907 | 908 | template 909 | class row_tuple_iterator { 910 | public: 911 | using difference_type = std::ptrdiff_t; 912 | using reference = array &; 913 | using iterator_concept = std::forward_iterator_tag; 914 | 915 | row_tuple_iterator(array *arr, size_t i) : arr_(arr), i_(i) {} 916 | 917 | row_tuple_iterator &operator++() { 918 | ++i_; 919 | return *this; 920 | } 921 | 922 | row_tuple_iterator operator++(int) { 923 | auto tmp = *this; 924 | ++(*this); 925 | return tmp; 926 | } 927 | 928 | auto operator*() const { return (*arr_)[i_].template take(); } 929 | 930 | friend bool operator==(const row_tuple_iterator &a, 931 | const row_tuple_iterator &b) { 932 | return a.i_ == b.i_; 933 | }; 934 | 935 | private: 936 | array *arr_ = nullptr; 937 | size_t i_ = 0; 938 | }; 939 | 940 | template 941 | class const_row_tuple_iterator { 942 | public: 943 | using difference_type = std::ptrdiff_t; 944 | using iterator_concept = std::forward_iterator_tag; 945 | 946 | const_row_tuple_iterator(const array *arr, size_t i) : arr_(arr), i_(i) {} 947 | 948 | const_row_tuple_iterator &operator++() { 949 | ++i_; 950 | return *this; 951 | } 952 | 953 | const_row_tuple_iterator operator++(int) { 954 | auto tmp = *this; 955 | ++(*this); 956 | return tmp; 957 | } 958 | 959 | auto operator*() const { return (*arr_)(i_).template take(); } 960 | 961 | friend bool operator==(const const_row_tuple_iterator &a, 962 | const const_row_tuple_iterator &b) { 963 | return a.i_ == b.i_; 964 | }; 965 | 966 | private: 967 | const array *arr_ = nullptr; 968 | size_t i_ = 0; 969 | }; 970 | 971 | template 972 | struct row_range { 973 | row_range(array *arr) : arr_(arr) {} 974 | 
auto begin() { return row_iterator(arr_, 0); } 975 | auto end() { return row_iterator(arr_, arr_->shape()[0]); } 976 | array *arr_ = nullptr; 977 | }; 978 | 979 | template 980 | struct const_row_range { 981 | const_row_range(const array *arr) : arr_(arr) {} 982 | auto begin() const { return const_row_iterator(arr_, 0); } 983 | auto end() const { return const_row_iterator(arr_, arr_->shape()[0]); } 984 | auto cbegin() const { return const_row_iterator(arr_, 0); } 985 | auto cend() const { return const_row_iterator(arr_, arr_->shape()[0]); } 986 | const array *arr_ = nullptr; 987 | }; 988 | 989 | template 990 | struct row_tuple_range { 991 | row_tuple_range(array *arr) : arr_(arr) {} 992 | auto begin() { return row_tuple_iterator(arr_, 0); } 993 | auto end() { return row_tuple_iterator(arr_, arr_->shape()[0]); } 994 | array *arr_ = nullptr; 995 | }; 996 | 997 | template 998 | struct const_row_tuple_range { 999 | const_row_tuple_range(array *arr) : arr_(arr) {} 1000 | auto cbegin() const { return const_row_tuple_iterator(arr_, 0); } 1001 | auto cend() const { 1002 | return const_row_tuple_iterator(arr_, arr_->shape()[0]); 1003 | } 1004 | const array *arr_ = nullptr; 1005 | }; 1006 | 1007 | template 1008 | inline auto array::begin() { 1009 | return row_iterator(this, 0); 1010 | } 1011 | 1012 | template 1013 | inline auto array::end() { 1014 | return row_iterator(this, shape_[0]); 1015 | } 1016 | 1017 | template 1018 | inline auto array::begin() const { 1019 | return const_row_iterator(this, 0); 1020 | } 1021 | 1022 | template 1023 | inline auto array::end() const { 1024 | return const_row_iterator(this, shape_[0]); 1025 | } 1026 | 1027 | template 1028 | inline auto array::cbegin() const { 1029 | return const_row_iterator(this, 0); 1030 | } 1031 | template 1032 | inline auto array::cend() const { 1033 | return const_row_iterator(this, shape_[0]); 1034 | } 1035 | 1036 | template 1037 | template 1038 | inline auto array::rows() { 1039 | if constexpr (N == 0) { 1040 | return row_range(this); 1041 | } else { 1042 | return row_tuple_range(this); 1043 | } 1044 | } 1045 | 1046 | template 1047 | template 1048 | inline auto array::rows() const { 1049 | if constexpr (N == 0) { 1050 | return const_row_range(this); 1051 | } else { 1052 | return const_row_tuple_range(this); 1053 | } 1054 | } 1055 | 1056 | //---------------------------------------------------------------------------- 1057 | 1058 | template 1059 | inline void array::set(std::input_iterator auto it) { 1060 | // TODO: parallel operation on GPU 1061 | for (size_t i = 0; i < element_count(); i++) { 1062 | at(i) = *it++; 1063 | } 1064 | } 1065 | 1066 | template 1067 | inline void array::set(std::initializer_list l) { 1068 | std::ranges::copy(l, element_begin()); 1069 | } 1070 | 1071 | //---------------------------------------------------------------------------- 1072 | 1073 | template 1074 | inline void array::constants(T val) { 1075 | std::fill(buffer_data(), buffer_data() + buffer_element_count(), val); 1076 | } 1077 | 1078 | template 1079 | inline void array::zeros() { 1080 | constants(0); 1081 | }; 1082 | 1083 | template 1084 | inline void array::ones() { 1085 | constants(1); 1086 | }; 1087 | 1088 | template 1089 | inline void array::random() { 1090 | std::generate(buffer_data(), buffer_data() + buffer_element_count(), []() { 1091 | return static_cast(static_cast(rand()) / RAND_MAX); 1092 | }); 1093 | } 1094 | 1095 | //---------------------------------------------------------------------------- 1096 | 1097 | template 1098 | inline array 
array::operator+(const array &rhs) const { 1099 | return arithmetic_operation_(*this, rhs, ArithmeticOperation::Add); 1100 | } 1101 | 1102 | template 1103 | inline array array::operator-(const array &rhs) const { 1104 | return arithmetic_operation_(*this, rhs, ArithmeticOperation::Sub); 1105 | } 1106 | 1107 | template 1108 | inline array array::operator*(const array &rhs) const { 1109 | return arithmetic_operation_(*this, rhs, ArithmeticOperation::Mul); 1110 | } 1111 | 1112 | template 1113 | inline array array::operator/(const array &rhs) const { 1114 | return arithmetic_operation_(*this, rhs, ArithmeticOperation::Div); 1115 | } 1116 | 1117 | template 1118 | inline array array::pow(const array &rhs) const { 1119 | return arithmetic_operation_(*this, rhs, ArithmeticOperation::Pow); 1120 | } 1121 | 1122 | template 1123 | inline void array::operator+=(const array &rhs) { 1124 | // TODO: in-place 1125 | *this = arithmetic_operation_(*this, rhs, ArithmeticOperation::Add); 1126 | } 1127 | 1128 | template 1129 | inline void array::operator-=(const array &rhs) { 1130 | // TODO: in-place 1131 | *this = arithmetic_operation_(*this, rhs, ArithmeticOperation::Sub); 1132 | } 1133 | 1134 | template 1135 | inline void array::operator*=(const array &rhs) { 1136 | // TODO: in-place 1137 | *this = arithmetic_operation_(*this, rhs, ArithmeticOperation::Mul); 1138 | } 1139 | 1140 | template 1141 | inline void array::operator/=(const array &rhs) { 1142 | // TODO: in-place 1143 | *this = arithmetic_operation_(*this, rhs, ArithmeticOperation::Div); 1144 | } 1145 | 1146 | //---------------------------------------------------------------------------- 1147 | 1148 | template 1149 | inline array array::dot(const array &rhs) const { 1150 | switch (device_) { 1151 | case Device::GPU: 1152 | return dot_operation_(rhs, gpu_dot_operation_); 1153 | case Device::CPU: 1154 | return dot_operation_(rhs, cpu_dot_operation_); 1155 | } 1156 | } 1157 | 1158 | template 1159 | inline array array::linear(const array& W, const array& b) const { 1160 | return dot(W) + b; 1161 | } 1162 | 1163 | //---------------------------------------------------------------------------- 1164 | 1165 | template 1166 | inline array array::sigmoid() const { 1167 | // TODO: parallel operation on GPU 1168 | auto tmp = array(shape_, 0.0); 1169 | for (size_t i = 0; i < element_count(); i++) { 1170 | tmp.at(i) = 1.0 / (1.0 + std::exp(-static_cast(at(i)))); 1171 | } 1172 | return tmp; 1173 | } 1174 | 1175 | //---------------------------------------------------------------------------- 1176 | 1177 | template 1178 | inline T array::sum() const { 1179 | return std::accumulate(element_cbegin(), element_cend(), T{}); 1180 | } 1181 | 1182 | template 1183 | inline array array::sum(size_t axis) const { 1184 | auto s = shape_; 1185 | s.erase(s.begin() + axis); 1186 | 1187 | auto tmp = array(s, T{}); 1188 | 1189 | enumerate_position_([&](const auto &pos) { 1190 | auto p = pos; 1191 | p.erase(p.begin() + axis); 1192 | 1193 | tmp.at(p) += at(pos); 1194 | }); 1195 | 1196 | return tmp; 1197 | } 1198 | 1199 | template 1200 | inline float array::mean() const { 1201 | return sum() / static_cast(element_count()); 1202 | } 1203 | 1204 | template 1205 | inline array array::mean(size_t axis) const { 1206 | auto t = sum(axis); 1207 | if constexpr (std::is_same_v) { 1208 | return t / shape_[axis]; 1209 | } else { 1210 | return t.template clone() / shape_[axis]; 1211 | } 1212 | } 1213 | 1214 | template 1215 | inline T array::min() const { 1216 | T min_val = std::numeric_limits::max(); 
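  // (Illustrative comment, not in the original.) min_val starts at the largest
  // representable T, so the first buffer element always replaces it below.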
1217 | for (size_t i = 0; i < buffer_element_count(); i++) { 1218 | auto val = buffer_data()[i]; 1219 | if (val < min_val) { 1220 | min_val = val; 1221 | } 1222 | } 1223 | return min_val; 1224 | } 1225 | 1226 | template 1227 | inline T array::max() const { 1228 | T max_val = std::numeric_limits::min(); 1229 | for (size_t i = 0; i < buffer_element_count(); i++) { 1230 | auto val = buffer_data()[i]; 1231 | if (val > max_val) { 1232 | max_val = val; 1233 | } 1234 | } 1235 | return max_val; 1236 | } 1237 | 1238 | template 1239 | inline size_t array::count() const { 1240 | size_t cnt = 0; 1241 | for (size_t i = 0; i < element_count(); i++) { 1242 | if (at(i)) { 1243 | cnt++; 1244 | } 1245 | } 1246 | return cnt; 1247 | } 1248 | 1249 | template 1250 | inline bool array::all(arithmetic auto val) const { 1251 | for (size_t i = 0; i < buffer_element_count(); i++) { 1252 | if (buffer_data()[i] != val) { 1253 | return false; 1254 | } 1255 | } 1256 | return true; 1257 | } 1258 | 1259 | template 1260 | template 1261 | inline bool array::all(U fn) const { 1262 | for (size_t i = 0; i < buffer_element_count(); i++) { 1263 | if (!fn(buffer_data()[i])) { 1264 | return false; 1265 | } 1266 | } 1267 | return true; 1268 | } 1269 | 1270 | template 1271 | inline array array::softmax() const { 1272 | if (dimension() == 1) { 1273 | auto c = min(); 1274 | auto tmp = array(shape_, 0.0); 1275 | 1276 | for (size_t i = 0; i < element_count(); i++) { 1277 | tmp.at(i) = std::exp(at(i) - c); 1278 | } 1279 | return tmp / tmp.sum(); 1280 | } else if (dimension() == 2) { 1281 | auto tmp = array(shape_, 0.0); 1282 | 1283 | for (size_t i = 0; i < shape_[0]; i++) { 1284 | const auto row = (*this)[i]; 1285 | auto c = row.min(); 1286 | for (size_t j = 0; j < row.element_count(); j++) { 1287 | tmp[i].at(j) = std::exp(row.at(j) - c); 1288 | } 1289 | auto smax = tmp[i] / tmp[i].sum(); 1290 | 1291 | for (size_t j = 0; j < row.element_count(); j++) { 1292 | tmp[i].at(j) = smax.at(j); 1293 | } 1294 | } 1295 | return tmp; 1296 | } 1297 | 1298 | throw std::runtime_error( 1299 | "array: softmax is available only for 1 or 2 dimension array."); 1300 | } 1301 | 1302 | template 1303 | inline auto array::argmax() const { 1304 | if (dimension() == 2) { 1305 | auto row_count = shape_[0]; 1306 | auto tmp = array({row_count}, 0); 1307 | 1308 | for (size_t i = 0; i < row_count; i++) { 1309 | const auto row = (*this)[i]; 1310 | 1311 | size_t max_index = 0; 1312 | { 1313 | T max_val = std::numeric_limits::min(); 1314 | for (size_t j = 0; j < row.buffer_element_count(); j++) { 1315 | auto val = row.buffer_data()[j]; 1316 | if (val > max_val) { 1317 | max_val = val; 1318 | max_index = j; 1319 | } 1320 | } 1321 | } 1322 | 1323 | tmp.at(i) = max_index; 1324 | } 1325 | return tmp; 1326 | } 1327 | 1328 | throw std::runtime_error("array: argmax is available for 2 dimension array."); 1329 | } 1330 | 1331 | template 1332 | inline float array::mean_square_error(const array &rhs) const { 1333 | return (*this - rhs).pow(2).mean(); 1334 | } 1335 | 1336 | //---------------------------------------------------------------------------- 1337 | 1338 | template 1339 | inline std::string array::print_shape_type(const shape_type &shape) const { 1340 | std::stringstream ss; 1341 | ss << "{"; 1342 | for (size_t i = 0; i < shape.size(); i++) { 1343 | if (i != 0) { 1344 | ss << ", "; 1345 | } 1346 | ss << shape[i]; 1347 | } 1348 | ss << "}"; 1349 | return ss.str(); 1350 | } 1351 | 1352 | template 1353 | inline std::string array::print_shape() const { 1354 | return 
print_shape_type(shape_); 1355 | } 1356 | 1357 | template 1358 | inline std::string array::print_strides() const { 1359 | return print_shape_type(strides_); 1360 | } 1361 | 1362 | template 1363 | inline std::string array::print_data_type() const { 1364 | if constexpr (std::is_same_v) { 1365 | return "float"; 1366 | } else { 1367 | return "int"; 1368 | } 1369 | } 1370 | 1371 | template 1372 | inline std::string array::print_info() const { 1373 | std::stringstream ss; 1374 | ss << "dtype: " << print_data_type() << ", dim: " << dimension() 1375 | << ", shape: " << print_shape() << ", strides: " << print_strides(); 1376 | return ss.str(); 1377 | } 1378 | 1379 | template 1380 | inline std::string array::print_array() const { 1381 | auto loop = [&](auto self, auto &os, auto dim, auto arr_index) { 1382 | auto n = shape_[dim]; 1383 | if (dim + 1 == dimension()) { 1384 | for (size_t i = 0; i < n; i++, arr_index++) { 1385 | if (i > 0) { 1386 | os << ", "; 1387 | } 1388 | if constexpr (std::is_same_v) { 1389 | os << (at(arr_index) ? "true" : "false"); 1390 | } else { 1391 | os << at(arr_index); 1392 | } 1393 | } 1394 | return arr_index; 1395 | } 1396 | 1397 | for (size_t i = 0; i < n; i++) { 1398 | if (dim < dimension() && i > 0) { 1399 | os << ",\n"; 1400 | if (dimension() >= 3 && dim == 0 && i > 0) { 1401 | os << "\n"; 1402 | } 1403 | for (size_t j = 0; j <= dim; j++) { 1404 | os << " "; 1405 | } 1406 | } 1407 | os << '{'; 1408 | arr_index = self(self, os, dim + 1, arr_index); 1409 | os << '}'; 1410 | } 1411 | return arr_index; 1412 | }; 1413 | 1414 | std::stringstream ss; 1415 | if (dimension() == 0) { 1416 | ss << at(); 1417 | } else { 1418 | ss << '{'; 1419 | loop(loop, ss, 0, 0); 1420 | ss << '}'; 1421 | } 1422 | return ss.str(); 1423 | } 1424 | 1425 | //---------------------------------------------------------------------------- 1426 | 1427 | template 1428 | inline void array::allocate_buffer_() { 1429 | storage_.off = 0; 1430 | storage_.len = element_count(); 1431 | storage_.buf = metal::default_device().make_buffer(storage_.len * sizeof(T)); 1432 | } 1433 | 1434 | //---------------------------------------------------------------------------- 1435 | 1436 | template 1437 | constexpr size_t nested_initializer_item_count_(const T &l) { 1438 | return 1; 1439 | } 1440 | 1441 | template 1442 | constexpr size_t nested_initializer_item_count_(std::initializer_list l) { 1443 | size_t count = 0; 1444 | for (auto it = l.begin(); it != l.end(); ++it) { 1445 | count += nested_initializer_item_count_(*it); 1446 | } 1447 | return count; 1448 | } 1449 | 1450 | template 1451 | constexpr void nested_initializer_copy_(T &&dst, const auto &src) { 1452 | *dst++ = src; 1453 | } 1454 | 1455 | template 1456 | constexpr void nested_initializer_copy_(T &&dst, std::initializer_list src) { 1457 | for (auto it = src.begin(); it != src.end(); ++it) { 1458 | nested_initializer_copy_(std::forward(dst), *it); 1459 | } 1460 | } 1461 | 1462 | template 1463 | inline void array::copy_initializer_list_(const auto &l) { 1464 | if (nested_initializer_item_count_(l) != element_count()) { 1465 | throw std::runtime_error("array: invalid initializer list."); 1466 | } 1467 | nested_initializer_copy_(buffer_data(), l); 1468 | } 1469 | 1470 | //---------------------------------------------------------------------------- 1471 | 1472 | template 1473 | inline void array::bounds_check_(size_t i) const { 1474 | if (strides_.empty() || i >= element_count()) { 1475 | throw std::runtime_error("array: index is out of bounds."); 1476 | } 1477 | } 
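// Illustrative usage sketch (not part of the original header; the template
// arguments are stripped in this dump, so the exact syntax is an assumption):
// softmax() rescales each row to sum to 1 (it shifts by the row minimum
// first, which does not change the result because softmax is shift-invariant),
// and argmax() returns the per-row index of the largest element.
//
//   mtl::array<float> y = {{1, 2, 3}, {3, 2, 1}};
//   auto p = y.softmax();   // each row of p sums to 1
//   auto k = p.argmax();    // k.at(0) == 2, k.at(1) == 0
//   assert(mtl::is_close(p[0].sum(), 1.0f));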
1478 | 1479 | template 1480 | inline void array::bounds_check_(size_t x, size_t y) const { 1481 | if (dimension() != 2 || x >= shape_[0] || y >= shape_[1]) { 1482 | throw std::runtime_error("array: (x, y) is out of bounds."); 1483 | } 1484 | } 1485 | 1486 | template 1487 | inline void array::bounds_check_(size_t x, size_t y, size_t z) const { 1488 | if (dimension() != 3 || x >= shape_[0] || y >= shape_[1] || z >= shape_[2]) { 1489 | throw std::runtime_error("array: (x, y, z) is out of bounds."); 1490 | } 1491 | } 1492 | 1493 | //---------------------------------------------------------------------------- 1494 | 1495 | template 1496 | template 1497 | inline void array::enumerate_position_(size_t shape_index, 1498 | std::vector &position, 1499 | U fn) const { 1500 | if (shape_index == shape_.size()) { 1501 | fn(position); 1502 | return; 1503 | } 1504 | 1505 | for (size_t i = 0; i < shape_[shape_index]; i++) { 1506 | position[shape_index] = i; 1507 | enumerate_position_(shape_index + 1, position, fn); 1508 | } 1509 | } 1510 | 1511 | template 1512 | template 1513 | inline void array::enumerate_position_(U fn) const { 1514 | std::vector position(shape_.size()); 1515 | for (size_t i = 0; i < shape_[0]; i++) { 1516 | position[0] = i; 1517 | enumerate_position_(1, position, fn); 1518 | } 1519 | } 1520 | 1521 | //---------------------------------------------------------------------------- 1522 | 1523 | template 1524 | inline auto array::broadcast_(const array &lhs, const array &rhs, auto cb) { 1525 | if (lhs.shape() == rhs.shape()) { 1526 | return cb(lhs, rhs); 1527 | } else if (lhs.dimension() < rhs.dimension()) { 1528 | return cb(lhs.broadcast(rhs.shape()), rhs); 1529 | } else if (lhs.dimension() > rhs.dimension()) { 1530 | return cb(lhs, rhs.broadcast(lhs.shape())); 1531 | } 1532 | throw std::runtime_error("array: invalid operation."); 1533 | } 1534 | 1535 | template 1536 | template 1537 | inline array array::apply_binary_operation_(const array &rhs, 1538 | auto ope) const { 1539 | return broadcast_(*this, rhs, [ope](const auto &lhs, const auto &rhs) { 1540 | // TODO: parallel operation on GPU 1541 | auto tmp = array(lhs.shape(), U{}); 1542 | for (size_t i = 0; i < lhs.element_count(); i++) { 1543 | tmp.at(i) = ope(lhs.at(i), rhs.at(i)); 1544 | } 1545 | return tmp; 1546 | }); 1547 | } 1548 | 1549 | //---------------------------------------------------------------------------- 1550 | 1551 | template 1552 | inline auto array::gpu_arithmetic_operation_(const array &lhs, 1553 | const array &rhs, 1554 | ArithmeticOperation ope) { 1555 | return broadcast_(lhs, rhs, [ope](const auto &lhs, const auto &rhs) { 1556 | auto tmp = array(lhs.shape(), T{}); 1557 | switch (ope) { 1558 | case ArithmeticOperation::Add: 1559 | metal::default_device().add(lhs.storage_, rhs.storage_, 1560 | tmp.storage_); 1561 | break; 1562 | case ArithmeticOperation::Sub: 1563 | metal::default_device().sub(lhs.storage_, rhs.storage_, 1564 | tmp.storage_); 1565 | break; 1566 | case ArithmeticOperation::Mul: 1567 | metal::default_device().mul(lhs.storage_, rhs.storage_, 1568 | tmp.storage_); 1569 | break; 1570 | case ArithmeticOperation::Div: 1571 | metal::default_device().div(lhs.storage_, rhs.storage_, 1572 | tmp.storage_); 1573 | break; 1574 | case ArithmeticOperation::Pow: 1575 | metal::default_device().pow(lhs.storage_, rhs.storage_, 1576 | tmp.storage_); 1577 | break; 1578 | default: 1579 | assert(false); 1580 | break; 1581 | } 1582 | return tmp; 1583 | }); 1584 | } 1585 | 1586 | template 1587 | inline auto 
array::cpu_arithmetic_operation_(const array &lhs, 1588 | const array &rhs, 1589 | ArithmeticOperation ope) { 1590 | switch (ope) { 1591 | case ArithmeticOperation::Add: 1592 | return lhs.apply_binary_operation_( 1593 | rhs, [](auto lhs, auto rhs) { return lhs + rhs; }); 1594 | break; 1595 | case ArithmeticOperation::Sub: 1596 | return lhs.apply_binary_operation_( 1597 | rhs, [](auto lhs, auto rhs) { return lhs - rhs; }); 1598 | break; 1599 | case ArithmeticOperation::Mul: 1600 | return lhs.apply_binary_operation_( 1601 | rhs, [](auto lhs, auto rhs) { return lhs * rhs; }); 1602 | break; 1603 | case ArithmeticOperation::Div: 1604 | return lhs.apply_binary_operation_( 1605 | rhs, [](auto lhs, auto rhs) { return lhs / rhs; }); 1606 | break; 1607 | case ArithmeticOperation::Pow: 1608 | return lhs.apply_binary_operation_( 1609 | rhs, [](auto lhs, auto rhs) { return std::pow(lhs, rhs); }); 1610 | break; 1611 | default: 1612 | assert(false); 1613 | break; 1614 | } 1615 | } 1616 | 1617 | template 1618 | inline auto array::arithmetic_operation_(const array &lhs, const array &rhs, 1619 | ArithmeticOperation ope) { 1620 | switch (device_) { 1621 | case Device::GPU: 1622 | return gpu_arithmetic_operation_(lhs, rhs, ope); 1623 | case Device::CPU: 1624 | return cpu_arithmetic_operation_(lhs, rhs, ope); 1625 | } 1626 | } 1627 | 1628 | //---------------------------------------------------------------------------- 1629 | 1630 | template 1631 | inline array array::cpu_dot_operation_(const array &lhs, 1632 | const array &rhs) { 1633 | auto rows = lhs.shape_[0]; 1634 | auto cols = rhs.shape_[1]; 1635 | auto m = lhs.shape_[1]; 1636 | auto tmp = array({rows, cols}, T{}); 1637 | 1638 | for (size_t row = 0; row < rows; row++) { 1639 | for (size_t col = 0; col < cols; col++) { 1640 | T val{}; 1641 | for (size_t i = 0; i < m; i++) { 1642 | val += lhs.at(row, i) * rhs.at(i, col); 1643 | } 1644 | tmp.at(row, col) = val; 1645 | } 1646 | } 1647 | return tmp; 1648 | } 1649 | 1650 | template 1651 | inline array array::gpu_dot_operation_(const array &lhs, 1652 | const array &rhs) { 1653 | auto tmp = array({lhs.shape_[0], rhs.shape_[1]}, T{}); 1654 | 1655 | metal::default_device().dot(lhs.storage_, rhs.storage_, tmp.storage_, 1656 | lhs.shape_[1], lhs.shape_[0], rhs.shape_[1]); 1657 | 1658 | return tmp; 1659 | } 1660 | 1661 | template 1662 | template 1663 | inline array array::dot_operation_(const array &rhs, U fn) const { 1664 | if (dimension() == 2 && rhs.dimension() == 2 && shape_[1] == rhs.shape_[0]) { 1665 | return fn(*this, rhs); 1666 | } 1667 | 1668 | if (dimension() == 1 && rhs.dimension() == 2 && shape_[0] == rhs.shape_[0]) { 1669 | auto lhs2 = *this; 1670 | lhs2.reshape({1, shape_[0]}); 1671 | 1672 | auto tmp = fn(lhs2, rhs); 1673 | tmp.reshape({rhs.shape_[1]}); 1674 | return tmp; 1675 | } 1676 | 1677 | if (dimension() == 2 && rhs.dimension() == 1 && shape_[1] == rhs.shape_[0]) { 1678 | auto rhs2 = rhs; 1679 | rhs2.reshape({rhs.shape_[0], 1}); 1680 | 1681 | auto tmp = fn(*this, rhs2); 1682 | tmp.reshape({shape_[0]}); 1683 | return tmp; 1684 | } 1685 | 1686 | if (dimension() == 1 && rhs.dimension() == 1 && shape_[0] == rhs.shape_[0]) { 1687 | auto lhs2 = *this; 1688 | lhs2.reshape({1, shape_[0]}); 1689 | 1690 | auto rhs2 = rhs; 1691 | rhs2.reshape({rhs.shape_[0], 1}); 1692 | 1693 | auto tmp = fn(lhs2, rhs2); 1694 | tmp.reshape({}); 1695 | return tmp; 1696 | } 1697 | 1698 | throw std::runtime_error("array: can't do `dot` operation."); 1699 | } 1700 | 1701 | 
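// Illustrative usage sketch (not part of the original header; template
// arguments are stripped in this dump, so the syntax is an assumption):
// broadcast_() lets the element-wise operators combine arrays of different
// rank, and dot_operation_() reshapes 1-D operands so that every dot() is
// ultimately executed as a 2-D matrix multiply.
//
//   mtl::array<float> m = {{1, 2, 3}, {4, 5, 6}};  // shape {2, 3}
//   mtl::array<float> v = {10, 20, 30};            // shape {3}
//   auto s = m + v;     // v is broadcast row-wise        -> shape {2, 3}
//   auto d = m.dot(v);  // {2, 3} . {3}                   -> shape {2}
//   auto x = v.dot(v);  // {3} . {3}, i.e. inner product  -> 0-dimension array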
//---------------------------------------------------------------------------- 1702 | 1703 | template 1704 | inline std::ostream &operator<<(std::ostream &os, const array &arr) { 1705 | os << arr.print_array(); 1706 | return os; 1707 | } 1708 | 1709 | //---------------------------------------------------------------------------- 1710 | 1711 | template 1712 | inline array operator+(auto lhs, const array &rhs) { 1713 | return array(static_cast(lhs)) + rhs; 1714 | } 1715 | 1716 | template 1717 | inline array operator-(auto lhs, const array &rhs) { 1718 | return array(static_cast(lhs)) - rhs; 1719 | } 1720 | 1721 | template 1722 | inline array operator*(auto lhs, const array &rhs) { 1723 | return array(static_cast(lhs)) * rhs; 1724 | } 1725 | 1726 | template 1727 | inline array operator/(auto lhs, const array &rhs) { 1728 | return array(static_cast(lhs)) / rhs; 1729 | } 1730 | 1731 | //---------------------------------------------------------------------------- 1732 | 1733 | template 1734 | inline array where(const array &cond, T x, T y) { 1735 | // TODO: parallel operation on GPU 1736 | auto tmp = array(cond.shape(), T{}); 1737 | for (size_t i = 0; i < cond.element_count(); i++) { 1738 | tmp.at(i) = cond.at(i) ? x : y; 1739 | } 1740 | return tmp; 1741 | } 1742 | 1743 | //---------------------------------------------------------------------------- 1744 | 1745 | template 1746 | inline bool array_equal(const array &a, const array &b) { 1747 | if (&a != &b) { 1748 | if (a.shape() != b.shape()) { 1749 | return false; 1750 | } 1751 | 1752 | for (size_t i = 0; i < a.element_count(); i++) { 1753 | if (a.at(i) != b.at(i)) { 1754 | return false; 1755 | } 1756 | } 1757 | } 1758 | return true; 1759 | } 1760 | 1761 | template 1762 | inline bool is_close(T a, U b, float tolerance) { 1763 | return std::abs(static_cast(a) - static_cast(b)) <= tolerance; 1764 | } 1765 | 1766 | template 1767 | inline bool is_close(T a, U b) { 1768 | return a == b; 1769 | } 1770 | 1771 | template 1772 | inline bool allclose(const array &a, const array &b, float tolerance) { 1773 | if (&a != &b) { 1774 | if (a.shape() != b.shape()) { 1775 | return false; 1776 | } 1777 | 1778 | for (size_t i = 0; i < a.element_count(); i++) { 1779 | if constexpr (std::is_same_v) { 1780 | if (std::abs(a.at(i) - b.at(i)) > tolerance) { 1781 | return false; 1782 | } 1783 | } else { 1784 | if (a.at(i) != b.at(i)) { 1785 | return false; 1786 | } 1787 | } 1788 | } 1789 | } 1790 | return true; 1791 | } 1792 | 1793 | //---------------------------------------------------------------------------- 1794 | 1795 | template 1796 | inline auto empty(const shape_type &shape) { 1797 | return array(shape, T{}); 1798 | } 1799 | 1800 | template 1801 | inline auto zeros(const shape_type &shape) { 1802 | return array(shape, 0); 1803 | } 1804 | 1805 | template 1806 | inline auto ones(const shape_type &shape) { 1807 | return array(shape, 1); 1808 | } 1809 | 1810 | inline auto random(const shape_type &shape) { 1811 | auto tmp = array(shape, 0.0); 1812 | tmp.random(); 1813 | return tmp; 1814 | } 1815 | 1816 | }; // namespace mtl 1817 | -------------------------------------------------------------------------------- /include/metal.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace mtl { 8 | 9 | template 10 | concept value_type = 11 | std::same_as || std::same_as || std::same_as; 12 | 13 | template 14 | concept arithmetic = std::is_arithmetic_v; 15 | 16 | 
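// Illustrative note (not in the original header): value_type constrains the
// element types an mtl::array may hold to the few the Metal kernels below
// handle; the stripped template arguments above appear to name float, int,
// and bool, consistent with the README's note that double is not supported.
// The arithmetic concept mirrors std::is_arithmetic_v and is used for scalar
// arguments such as array::all(val).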
//----------------------------------------------------------------------------- 17 | 18 | template 19 | struct releaser { 20 | void operator()(T* p) { 21 | if (p != nullptr) { 22 | p->release(); 23 | } else { 24 | throw std::runtime_error( 25 | "metal: This managed resource object has already been released..."); 26 | } 27 | } 28 | }; 29 | 30 | template 31 | inline auto managed(T* p) { 32 | return std::shared_ptr(p, releaser()); 33 | } 34 | 35 | template 36 | using managed_ptr = std::shared_ptr; 37 | 38 | //----------------------------------------------------------------------------- 39 | 40 | template 41 | struct nested_initializer_list_ { 42 | using nested_type = nested_initializer_list_::type; 43 | using type = std::initializer_list; 44 | }; 45 | 46 | template 47 | struct nested_initializer_list_ { 48 | using type = T; 49 | }; 50 | 51 | template 52 | using nested_initializer_list = nested_initializer_list_::type; 53 | 54 | //----------------------------------------------------------------------------- 55 | 56 | enum class Device { 57 | GPU = 0, 58 | CPU, 59 | }; 60 | 61 | static Device device_ = Device::GPU; 62 | 63 | inline void use_cpu() { device_ = Device::CPU; } 64 | 65 | inline void use_gpu() { device_ = Device::GPU; } 66 | 67 | //----------------------------------------------------------------------------- 68 | 69 | class metal { 70 | public: 71 | struct storage { 72 | managed_ptr buf; 73 | size_t off = 0; 74 | size_t len = 0; 75 | }; 76 | 77 | metal(managed_ptr device); 78 | 79 | managed_ptr make_buffer(NS::UInteger length); 80 | 81 | template 82 | void add(const storage& A, const storage& B, storage& OUT); 83 | 84 | template 85 | void sub(const storage& A, const storage& B, storage& OUT); 86 | 87 | template 88 | void mul(const storage& A, const storage& B, storage& OUT); 89 | 90 | template 91 | void div(const storage& A, const storage& B, storage& OUT); 92 | 93 | template 94 | void pow(const storage& A, const storage& B, storage& OUT); 95 | 96 | template 97 | void dot(const storage& A, const storage& B, storage& OUT, uint32_t A_cols, 98 | uint32_t OUT_rows, uint32_t OUT_cols); 99 | 100 | static metal& default_device() { 101 | static auto metal_ = metal(managed(MTL::CreateSystemDefaultDevice())); 102 | return metal_; 103 | } 104 | 105 | private: 106 | enum class DataType { 107 | Float = 0, 108 | Integer, 109 | }; 110 | 111 | managed_ptr device_; 112 | 113 | managed_ptr pso_add_; 114 | managed_ptr pso_sub_; 115 | managed_ptr pso_mul_; 116 | managed_ptr pso_div_; 117 | managed_ptr pso_pow_; 118 | managed_ptr pso_dot_; 119 | 120 | managed_ptr queue_; 121 | 122 | template 123 | void arithmetic_operation_(const storage& A, const storage& B, storage& OUT, 124 | managed_ptr pso); 125 | 126 | managed_ptr create_compute_pipeline_state_object_( 127 | managed_ptr device, managed_ptr library, 128 | const char* name); 129 | }; 130 | 131 | //----------------------------------------------------------------------------- 132 | // Implementation 133 | //----------------------------------------------------------------------------- 134 | 135 | static const char* msl_source_ = R"( 136 | 137 | #include 138 | 139 | using namespace metal; 140 | 141 | template 142 | void arithmetic_operation_( 143 | device const void* A, 144 | device const void* B, 145 | device void* OUT, 146 | constant uint32_t& A_length, 147 | constant uint32_t& B_length, 148 | uint gid) 149 | { 150 | auto A_arr = static_cast(A); 151 | auto B_arr = static_cast(B); 152 | auto OUT_arr = reinterpret_cast(OUT); 153 | 154 | // broadcast 
offset 155 | auto A_index = gid % A_length; 156 | auto B_index = gid % B_length; 157 | 158 | OUT_arr[gid] = Ope()(A_arr[A_index], B_arr[B_index]); 159 | } 160 | 161 | template struct add_ { T operator()(T a, T b) { return a + b; } }; 162 | template struct sub_ { T operator()(T a, T b) { return a - b; } }; 163 | template struct mul_ { T operator()(T a, T b) { return a * b; } }; 164 | template struct div_ { T operator()(T a, T b) { return a / b; } }; 165 | 166 | struct powf_ { float operator()(float a, float b) { return pow(a, b); } }; 167 | struct powi_ { int operator()(int a, int b) { 168 | return round(pow(static_cast(a), static_cast(b))); 169 | } }; 170 | 171 | template 172 | void dot_operatoin( 173 | device const void* A, 174 | device const void* B, 175 | device void* OUT, 176 | constant uint32_t& A_cols, 177 | constant uint32_t& OUT_raws, 178 | constant uint32_t& OUT_cols, 179 | uint2 gid) 180 | { 181 | auto A_arr = static_cast(A); 182 | auto B_arr = static_cast(B); 183 | auto OUT_arr = reinterpret_cast(OUT); 184 | 185 | auto irow = gid.y; 186 | auto icol = gid.x; 187 | 188 | T val{}; 189 | for (uint32_t i = 0; i < A_cols; i++) { 190 | auto aval = A_arr[A_cols * irow + i]; 191 | auto bval = B_arr[OUT_cols * i + icol]; 192 | val += aval * bval; 193 | } 194 | 195 | OUT_arr[OUT_cols * irow + icol] = val; 196 | } 197 | 198 | constant uint32_t Float = 0; 199 | 200 | kernel void add( 201 | device const void* A, 202 | device const void* B, 203 | device void* OUT, 204 | constant uint32_t& A_length, 205 | constant uint32_t& B_length, 206 | constant uint32_t& dtype, 207 | uint gid [[thread_position_in_grid]]) 208 | { 209 | if (dtype == Float) { 210 | arithmetic_operation_, float>(A, B, OUT, A_length, B_length, gid); 211 | } else { 212 | arithmetic_operation_, int>(A, B, OUT, A_length, B_length, gid); 213 | } 214 | } 215 | 216 | kernel void sub( 217 | device const void* A, 218 | device const void* B, 219 | device void* OUT, 220 | constant uint32_t& A_length, 221 | constant uint32_t& B_length, 222 | constant uint32_t& dtype, 223 | uint gid [[thread_position_in_grid]]) 224 | { 225 | if (dtype == Float) { 226 | arithmetic_operation_, float>(A, B, OUT, A_length, B_length, gid); 227 | } else { 228 | arithmetic_operation_, int>(A, B, OUT, A_length, B_length, gid); 229 | } 230 | } 231 | 232 | kernel void mul( 233 | device const void* A, 234 | device const void* B, 235 | device void* OUT, 236 | constant uint32_t& A_length, 237 | constant uint32_t& B_length, 238 | constant uint32_t& dtype, 239 | uint gid [[thread_position_in_grid]]) 240 | { 241 | if (dtype == Float) { 242 | arithmetic_operation_, float>(A, B, OUT, A_length, B_length, gid); 243 | } else { 244 | arithmetic_operation_, int>(A, B, OUT, A_length, B_length, gid); 245 | } 246 | } 247 | 248 | kernel void div( 249 | device const void* A, 250 | device const void* B, 251 | device void* OUT, 252 | constant uint32_t& A_length, 253 | constant uint32_t& B_length, 254 | constant uint32_t& dtype, 255 | uint gid [[thread_position_in_grid]]) 256 | { 257 | if (dtype == Float) { 258 | arithmetic_operation_, float>(A, B, OUT, A_length, B_length, gid); 259 | } else { 260 | arithmetic_operation_, int>(A, B, OUT, A_length, B_length, gid); 261 | } 262 | } 263 | 264 | kernel void pow( 265 | device const void* A, 266 | device const void* B, 267 | device void* OUT, 268 | constant uint32_t& A_length, 269 | constant uint32_t& B_length, 270 | constant uint32_t& dtype, 271 | uint gid [[thread_position_in_grid]]) 272 | { 273 | if (dtype == Float) { 274 | 
arithmetic_operation_(A, B, OUT, A_length, B_length, gid); 275 | } else { 276 | arithmetic_operation_(A, B, OUT, A_length, B_length, gid); 277 | } 278 | } 279 | 280 | kernel void dot( 281 | device const void* A, 282 | device const void* B, 283 | device void* OUT, 284 | constant uint32_t& A_cols, 285 | constant uint32_t& OUT_raws, 286 | constant uint32_t& OUT_cols, 287 | constant uint32_t& dtype, 288 | uint2 gid [[thread_position_in_grid]]) 289 | { 290 | if (dtype == Float) { 291 | dot_operatoin(A, B, OUT, A_cols, OUT_raws, OUT_cols, gid); 292 | } else { 293 | dot_operatoin(A, B, OUT, A_cols, OUT_raws, OUT_cols, gid); 294 | } 295 | } 296 | 297 | )"; 298 | 299 | //----------------------------------------------------------------------------- 300 | 301 | inline metal::metal(managed_ptr device) : device_(device) { 302 | if (device == nullptr) { 303 | throw std::runtime_error("metal: Failed to create the default library."); 304 | } 305 | 306 | // Compile a Metal library 307 | auto src = NS::String::string(msl_source_, NS::ASCIIStringEncoding); 308 | NS::Error* error = nullptr; 309 | 310 | auto lib = managed(device->newLibrary(src, nullptr, &error)); 311 | if (lib == nullptr) { 312 | std::stringstream ss; 313 | ss << "metal: Failed to compile the Metal library, error " << error << "."; 314 | throw std::runtime_error(ss.str()); 315 | } 316 | 317 | // Create pipeline state objects 318 | pso_add_ = create_compute_pipeline_state_object_(device, lib, "add"); 319 | pso_sub_ = create_compute_pipeline_state_object_(device, lib, "sub"); 320 | pso_mul_ = create_compute_pipeline_state_object_(device, lib, "mul"); 321 | pso_div_ = create_compute_pipeline_state_object_(device, lib, "div"); 322 | pso_pow_ = create_compute_pipeline_state_object_(device, lib, "pow"); 323 | pso_dot_ = create_compute_pipeline_state_object_(device, lib, "dot"); 324 | 325 | // Create a command queue 326 | queue_ = managed(device->newCommandQueue()); 327 | 328 | if (queue_ == nullptr) { 329 | throw std::runtime_error("metal: Failed to find the command queue."); 330 | } 331 | } 332 | 333 | inline managed_ptr metal::make_buffer(NS::UInteger length) { 334 | return managed(device_->newBuffer(length, MTL::ResourceStorageModeShared)); 335 | } 336 | 337 | template 338 | inline void metal::add(const storage& A, const storage& B, storage& OUT) { 339 | arithmetic_operation_(A, B, OUT, pso_add_); 340 | } 341 | 342 | template 343 | inline void metal::sub(const storage& A, const storage& B, storage& OUT) { 344 | arithmetic_operation_(A, B, OUT, pso_sub_); 345 | } 346 | 347 | template 348 | inline void metal::mul(const storage& A, const storage& B, storage& OUT) { 349 | arithmetic_operation_(A, B, OUT, pso_mul_); 350 | } 351 | 352 | template 353 | inline void metal::div(const storage& A, const storage& B, storage& OUT) { 354 | arithmetic_operation_(A, B, OUT, pso_div_); 355 | } 356 | 357 | template 358 | inline void metal::pow(const storage& A, const storage& B, storage& OUT) { 359 | arithmetic_operation_(A, B, OUT, pso_pow_); 360 | } 361 | 362 | template 363 | inline void metal::dot(const storage& A, const storage& B, storage& OUT, 364 | uint32_t A_cols, uint32_t OUT_rows, uint32_t OUT_cols) { 365 | auto pso = pso_dot_; 366 | auto dtype = std::is_same_v 367 | ? 
static_cast(DataType::Float) 368 | : static_cast(DataType::Integer); 369 | auto commandBuffer = queue_->commandBuffer(); 370 | auto computeEncoder = commandBuffer->computeCommandEncoder(); 371 | 372 | computeEncoder->setComputePipelineState(pso.get()); 373 | computeEncoder->setBuffer(A.buf.get(), A.off * sizeof(T), 0); 374 | computeEncoder->setBuffer(B.buf.get(), B.off * sizeof(T), 1); 375 | computeEncoder->setBuffer(OUT.buf.get(), OUT.off * sizeof(T), 2); 376 | computeEncoder->setBytes(&A_cols, sizeof(uint32_t), 3); 377 | computeEncoder->setBytes(&OUT_rows, sizeof(uint32_t), 4); 378 | computeEncoder->setBytes(&OUT_cols, sizeof(uint32_t), 5); 379 | computeEncoder->setBytes(&dtype, sizeof(uint32_t), 6); 380 | 381 | auto grid_size = MTL::Size::Make(OUT_cols, OUT_rows, 1); 382 | auto w = pso->threadExecutionWidth(); 383 | auto h = pso->maxTotalThreadsPerThreadgroup() / w; 384 | auto threads_size = MTL::Size::Make(w, h, 1); 385 | 386 | computeEncoder->dispatchThreads(grid_size, threads_size); 387 | computeEncoder->endEncoding(); 388 | 389 | commandBuffer->commit(); 390 | commandBuffer->waitUntilCompleted(); 391 | } 392 | 393 | template 394 | inline void metal::arithmetic_operation_( 395 | const storage& A, const storage& B, storage& OUT, 396 | managed_ptr pso) { 397 | auto dtype = std::is_same_v 398 | ? static_cast(DataType::Float) 399 | : static_cast(DataType::Integer); 400 | 401 | auto commandBuffer = queue_->commandBuffer(); 402 | auto computeEncoder = commandBuffer->computeCommandEncoder(); 403 | 404 | computeEncoder->setComputePipelineState(pso.get()); 405 | computeEncoder->setBuffer(A.buf.get(), A.off * sizeof(T), 0); 406 | computeEncoder->setBuffer(B.buf.get(), B.off * sizeof(T), 1); 407 | computeEncoder->setBuffer(OUT.buf.get(), OUT.off * sizeof(T), 2); 408 | computeEncoder->setBytes(&A.len, sizeof(uint32_t), 3); 409 | computeEncoder->setBytes(&B.len, sizeof(uint32_t), 4); 410 | computeEncoder->setBytes(&dtype, sizeof(uint32_t), 5); 411 | 412 | auto grid_size = MTL::Size::Make(OUT.len, 1, 1); 413 | auto w = pso->threadExecutionWidth(); 414 | auto h = pso->maxTotalThreadsPerThreadgroup() / w; 415 | auto threads_size = MTL::Size::Make(w, h, 1); 416 | 417 | computeEncoder->dispatchThreads(grid_size, threads_size); 418 | computeEncoder->endEncoding(); 419 | 420 | commandBuffer->commit(); 421 | commandBuffer->waitUntilCompleted(); 422 | } 423 | 424 | inline managed_ptr 425 | metal::create_compute_pipeline_state_object_(managed_ptr device, 426 | managed_ptr library, 427 | const char* name) { 428 | auto str = NS::String::string(name, NS::ASCIIStringEncoding); 429 | auto fn = managed(library->newFunction(str)); 430 | 431 | if (fn == nullptr) { 432 | std::stringstream ss; 433 | ss << "metal: Failed to find the " << name << " function."; 434 | throw std::runtime_error(ss.str()); 435 | } 436 | 437 | NS::Error* error = nullptr; 438 | auto pso = managed(device->newComputePipelineState(fn.get(), &error)); 439 | 440 | if (pso == nullptr) { 441 | std::stringstream ss; 442 | ss << "metal: Failed to created pipeline state object, error " << error 443 | << "."; 444 | throw std::runtime_error(ss.str()); 445 | } 446 | 447 | return pso; 448 | } 449 | 450 | }; // namespace mtl 451 | -------------------------------------------------------------------------------- /include/mtlcpp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "./metal.h" 4 | #include "./array.h" 5 | 
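// Typical use, as an illustrative sketch based on test/bench.cpp and the unit
// tests below (not a complete program): exactly one translation unit must
// define NS_PRIVATE_IMPLEMENTATION and MTL_PRIVATE_IMPLEMENTATION before
// including this header, then:
//
//   mtl::use_gpu();                    // or mtl::use_cpu()
//   auto a = mtl::ones({1000, 1000});
//   auto b = mtl::ones({1000, 100});
//   auto c = a.dot(b);                 // {1000, 100} matrix product
//   auto s = c.sum();                  // aggregate over all elements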
-------------------------------------------------------------------------------- /test/Makefile: -------------------------------------------------------------------------------- 1 | METAL_CPP_DIR = ../metal-cpp_macOS14.2_iOS17.2 2 | 3 | CXX = clang++ 4 | CXXFLAGS = -std=c++20 -O2 -I$(METAL_CPP_DIR) -I../include -framework Foundation -framework Metal -framework MetalKit 5 | 6 | INC = ../include/metal.h ../include/array.h ../include/mtlcpp.h 7 | 8 | TEST_SRC = test_array.cpp test_examples.cpp test_perceptron.cpp test_2lnn.cpp test.cpp 9 | BENCH_SRC = bench.cpp 10 | MNIST_SRC = mnist.cpp 11 | 12 | all : test bench mnist 13 | ./test 14 | ./bench 15 | ./mnist 16 | 17 | test : $(INC) $(TEST_SRC) Makefile 18 | $(CXX) -o $@ $(CXXFLAGS) $(TEST_SRC) 19 | 20 | bench : $(INC) $(BENCH_SRC) Makefile 21 | $(CXX) -o $@ $(CXXFLAGS) $(BENCH_SRC) 22 | 23 | mnist : $(INC) $(MNIST_SRC) Makefile 24 | $(CXX) -o $@ $(CXXFLAGS) $(MNIST_SRC) 25 | 26 | clean: 27 | rm -rf test bench mnist 28 | 29 | -------------------------------------------------------------------------------- /test/bench.cpp: -------------------------------------------------------------------------------- 1 | #define NS_PRIVATE_IMPLEMENTATION 2 | #define MTL_PRIVATE_IMPLEMENTATION 3 | #define ANKERL_NANOBENCH_IMPLEMENT 4 | 5 | #include 6 | 7 | #include 8 | 9 | #include "nanobench.h" 10 | 11 | using namespace ankerl::nanobench; 12 | 13 | void add() { 14 | const size_t n = 10'000'000; 15 | 16 | auto a = mtl::ones({n}); 17 | auto b = mtl::ones({n}); 18 | auto c = mtl::array(); 19 | 20 | mtl::use_cpu(); 21 | Bench().run("CPU: a + b", [&] { c = a + b; }); 22 | 23 | mtl::use_gpu(); 24 | Bench().minEpochIterations(100).run("GPU: a + b", [&] { c = a + b; }); 25 | 26 | auto aa = Eigen::Vector::Ones(n); 27 | auto bb = Eigen::Vector::Ones(n); 28 | auto cc = Eigen::Vector(n); 29 | 30 | Bench().minEpochIterations(100).run("Eigen: a + b", [&] { cc = aa + bb; }); 31 | } 32 | 33 | void dot() { 34 | auto a = mtl::ones({1000, 1000}); 35 | auto b = mtl::ones({1000, 100}); 36 | auto c = mtl::array(); 37 | 38 | mtl::use_cpu(); 39 | Bench().run("CPU: a.dot(b)", [&] { c = a.dot(b); }); 40 | 41 | mtl::use_gpu(); 42 | Bench().minEpochIterations(100).run("GPU: a.dot(b)", [&] { c = a.dot(b); }); 43 | 44 | auto aa = 45 | Eigen::Matrix::Ones(1000, 1000); 46 | auto bb = 47 | Eigen::Matrix::Ones(1000, 100); 48 | auto cc = Eigen::Matrix(); 49 | 50 | Bench().minEpochIterations(100).run("Eigen: a * b", [&] { cc = aa * bb; }); 51 | } 52 | 53 | int main(void) { 54 | add(); 55 | dot(); 56 | } 57 | -------------------------------------------------------------------------------- /test/mnist.cpp: -------------------------------------------------------------------------------- 1 | #define NS_PRIVATE_IMPLEMENTATION 2 | #define MTL_PRIVATE_IMPLEMENTATION 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | class memory_mapped_file { 12 | public: 13 | memory_mapped_file(const char* path) { 14 | fd_ = open(path, O_RDONLY); 15 | assert(fd_ != -1); 16 | 17 | struct stat sb; 18 | fstat(fd_, &sb); 19 | size_ = sb.st_size; 20 | 21 | addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); 22 | assert(addr_ != MAP_FAILED); 23 | } 24 | 25 | ~memory_mapped_file() { 26 | munmap(addr_, size_); 27 | close(fd_); 28 | } 29 | 30 | const char* data() const { return (const char*)addr_; } 31 | 32 | private: 33 | int fd_ = -1; 34 | size_t size_ = 0; 35 | void* addr_ = MAP_FAILED; 36 | }; 37 | 38 | class mnist_data { 39 | public: 40 | mnist_data(const char* label_data_path, 
const char* image_data_path) 41 | : label_data_(label_data_path), image_data_(image_data_path) { 42 | if (read32(label_data_, 0) != 2049) { 43 | throw std::runtime_error("invalid label data format."); 44 | } 45 | 46 | if (read32(image_data_, 0) != 2051) { 47 | throw std::runtime_error("invalid image data format."); 48 | } 49 | 50 | number_of_items_ = read32(label_data_, 4); 51 | if (number_of_items_ != read32(image_data_, 4)) { 52 | throw std::runtime_error("image data doesn't match label data."); 53 | } 54 | 55 | labels_ = reinterpret_cast(label_data_.data() + 8); 56 | 57 | number_of_rows_ = read32(image_data_, 8); 58 | number_of_columns_ = read32(image_data_, 12); 59 | pixels_ = reinterpret_cast(image_data_.data() + 16); 60 | } 61 | 62 | size_t size() const { return number_of_items_; } 63 | 64 | // Label 65 | const uint8_t* label_data() const { return labels_; } 66 | 67 | size_t label(size_t i) const { return labels_[i]; } 68 | 69 | // Image 70 | const float* normalized_image_data() const { 71 | if (normalized_pixels_.empty()) { 72 | auto total_pixels = image_pixel_size() * size(); 73 | normalized_pixels_.resize(total_pixels); 74 | for (size_t i = 0; i < total_pixels; i++) { 75 | normalized_pixels_[i] = (float)pixels_[i] / (float)255; 76 | } 77 | } 78 | return normalized_pixels_.data(); 79 | } 80 | 81 | size_t image_rows() const { return number_of_rows_; } 82 | 83 | size_t image_columns() const { return number_of_columns_; } 84 | 85 | size_t image_pixel_size() const { 86 | return number_of_rows_ * number_of_columns_; 87 | } 88 | 89 | uint8_t pixel(size_t row, size_t col) const { 90 | return pixels_[number_of_columns_ * row + col]; 91 | } 92 | 93 | private: 94 | uint32_t read32(const memory_mapped_file& mm, size_t off) { 95 | auto p = mm.data() + off; 96 | return __builtin_bswap32(*reinterpret_cast(p)); 97 | } 98 | 99 | size_t number_of_items_ = 0; 100 | size_t number_of_rows_ = 0; 101 | size_t number_of_columns_ = 0; 102 | 103 | const uint8_t* labels_ = nullptr; 104 | const uint8_t* pixels_ = nullptr; 105 | 106 | memory_mapped_file label_data_; 107 | memory_mapped_file image_data_; 108 | 109 | mutable std::vector normalized_pixels_; 110 | }; 111 | 112 | struct Network { 113 | std::map> w; 114 | std::map> b; 115 | }; 116 | 117 | Network init_network() { 118 | auto split = [](const char* b, const char* e, char d, 119 | std::function fn) { 120 | size_t i = 0; 121 | size_t beg = 0; 122 | while (e ? 
(b + i < e) : (b[i] != '\0')) { 123 | if (b[i] == d) { 124 | fn(&b[beg], &b[i]); 125 | beg = i + 1; 126 | } 127 | i++; 128 | } 129 | if (i) { 130 | fn(&b[beg], &b[i]); 131 | } 132 | }; 133 | 134 | Network network; 135 | std::ifstream f("sample_weight.csv"); 136 | std::string line; 137 | while (std::getline(f, line)) { 138 | std::replace(line.begin(), line.end(), ',', ' '); 139 | std::istringstream s(line); 140 | 141 | std::string label; 142 | size_t rows; 143 | size_t cols; 144 | s >> label >> rows >> cols; 145 | 146 | std::vector values; 147 | { 148 | values.reserve(rows * cols); 149 | 150 | auto count = rows; 151 | while (count > 0) { 152 | std::getline(f, line); 153 | split(&line[0], &line[line.size() - 1], ',', 154 | [&](auto b, auto /*e*/) { values.push_back(std::atof(b)); }); 155 | count--; 156 | } 157 | } 158 | 159 | if (rows > 1) { 160 | network.w[label] = mtl::array({rows, cols}, values); 161 | } else { 162 | network.b[label] = mtl::array({cols}, values); 163 | } 164 | } 165 | return network; 166 | } 167 | 168 | auto predict(const Network& network, const mtl::array& x) { 169 | auto W1 = network.w.at("W1"); 170 | auto W2 = network.w.at("W2"); 171 | auto W3 = network.w.at("W3"); 172 | auto b1 = network.b.at("b1"); 173 | auto b2 = network.b.at("b2"); 174 | auto b3 = network.b.at("b3"); 175 | 176 | auto a1 = x.dot(W1) + b1; 177 | auto z1 = a1.sigmoid(); 178 | auto a2 = z1.dot(W2) + b2; 179 | auto z2 = a2.sigmoid(); 180 | auto a3 = z2.dot(W3) + b3; 181 | auto y = a3.softmax(); 182 | return y; 183 | } 184 | 185 | int main(int argc, const char** argv) { 186 | try { 187 | if (argc > 1) { 188 | if (std::string("--cpu") == argv[1]) { 189 | mtl::use_cpu(); 190 | } 191 | } 192 | 193 | auto data = mnist_data("t10k-labels-idx1-ubyte", "t10k-images-idx3-ubyte"); 194 | auto network = init_network(); 195 | 196 | auto p = data.normalized_image_data(); 197 | 198 | size_t batch_size = 100; 199 | size_t accuracy_cnt = 0; 200 | 201 | for (auto i = 0u; i < data.size(); i += batch_size) { 202 | auto x = mtl::array({batch_size, data.image_pixel_size()}, 203 | p + data.image_pixel_size() * i); 204 | auto y = predict(network, x); 205 | auto e = mtl::array({batch_size}, data.label_data() + i); 206 | auto a = y.argmax(); 207 | 208 | auto r = e == a; 209 | accuracy_cnt += r.count(); 210 | } 211 | 212 | auto accuracy = (double)accuracy_cnt / (double)data.size(); 213 | std::cout << "MNIST Accuracy: " << accuracy << std::endl; 214 | } catch (const std::runtime_error& e) { 215 | std::cerr << e.what() << std::endl; 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /test/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhirose/mtlcpp/59e5108a25a90957469d631423a0089c157feb01/test/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /test/t10k-labels-idx1-ubyte: -------------------------------------------------------------------------------- 1 | '                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
                                                                                                                                                                                                                                                                                                                                                                                                                                  -------------------------------------------------------------------------------- /test/test.cpp: -------------------------------------------------------------------------------- 1 | #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN 2 | #include "doctest.h" 3 | 4 | #define ANKERL_NANOBENCH_IMPLEMENT 5 | #include "nanobench.h" 6 | 7 | #define NS_PRIVATE_IMPLEMENTATION 8 | #define MTL_PRIVATE_IMPLEMENTATION 9 | #include 10 | -------------------------------------------------------------------------------- /test/test_2lnn.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "doctest.h" 4 | 5 | mtl::array mean_square_error_derivative(float dout, 6 | const mtl::array& out, 7 | const mtl::array& Y) { 8 | return dout * (2 * (out - Y)); 9 | } 10 | 11 | mtl::array sigmoid_derivative(const mtl::array& dout, 12 | const mtl::array& x) { 13 | auto y = x.sigmoid(); 14 | return dout * (y * (1 - y)); 15 | } 16 | 17 | std::tuple, mtl::array, mtl::array> 18 | linear_derivative(const mtl::array& dout, const mtl::array& x, 19 | const mtl::array& W) { 20 | auto dx = dout.dot(W.transpose()); 21 | auto dW = x.transpose().dot(dout); 22 | auto db = dout.sum(0); 23 | return {dx, dW, db}; 24 | } 25 | 26 | struct TwoLayerNeuralNetwork { 27 | mtl::array W1 = mtl::random({2, 3}) * 2.0 - 1.0; 28 | mtl::array b1 = mtl::random({3}) * 2.0 - 1.0; 29 | 30 | mtl::array W2 = mtl::random({3, 1}) * 2.0 - 1.0; 31 | mtl::array b2 = mtl::random({1}) * 2.0 - 1.0; 32 | 33 | mtl::array x; 34 | mtl::array net1; 35 | mtl::array out1; 36 | mtl::array net2; 37 | mtl::array out2; 38 | 39 | mtl::array Y; 40 | 41 | mtl::array forward(const mtl::array& x) { 42 | // Input → Hidden 43 | auto net1 = x.linear(W1, b1); 44 | auto out1 = net1.sigmoid(); 45 | 46 | // Hidden → Output 47 | auto net2 = out1.linear(W2, b2); 48 | auto out2 = net2.sigmoid(); 49 | 50 | // Save variables for backpropagation 51 | this->x = x; 52 | this->net1 = net1; 53 | this->out1 = out1; 54 | this->net2 = net2; 55 | this->out2 = out2; 56 | 57 | return out2; 58 | } 59 | 60 | float loss(const mtl::array& out, const mtl::array& Y) { 61 | // Save variables for back propagation 62 | this->Y = Y; 63 | 64 | return out.mean_square_error(Y); 65 | } 66 | 67 | std::tuple, mtl::array, mtl::array, 68 | mtl::array> 69 | backward() { 70 | auto dout = mean_square_error_derivative(1.0, this->out2, this->Y); 71 | dout = sigmoid_derivative(dout, this->net2); 72 | 73 | const auto& [dout1, dW2, db2] = 74 | linear_derivative(dout, this->out1, this->W2); 75 | 76 | dout = sigmoid_derivative(dout1, this->net1); 77 | 78 | const auto& [dx, dW1, db1] = linear_derivative(dout, this->x, this->W1); 79 | 80 | return {dW1, db1, dW2, db2}; 81 | } 82 | }; 83 | 84 | mtl::array predict(TwoLayerNeuralNetwork& model, 85 | const mtl::array& x) { 86 | auto out = model.forward(x); // 0..1 87 | return mtl::where(out > 0.5, 1, 0); 88 | } 89 | 90 | void train(TwoLayerNeuralNetwork& model, const mtl::array& X, 91 | const mtl::array& Y, size_t epochs, float learning_rate) { 92 | std::vector losses; 93 | 94 | for (size_t epoch = 0; epoch < epochs; epoch++) { 
95 | // Save variables for back propagation 96 | auto out = model.forward(X); 97 | auto loss = model.loss(out, Y); 98 | 99 | // Get gradients of weight parameters 100 | const auto& [dW1, db1, dW2, db2] = model.backward(); 101 | 102 | // Update weights 103 | model.W1 -= dW1 * learning_rate; 104 | model.b1 -= db1 * learning_rate; 105 | model.W2 -= dW2 * learning_rate; 106 | model.b2 -= db2 * learning_rate; 107 | 108 | losses.push_back(loss); 109 | 110 | // Show progress message 111 | if (epoch % (epochs / 10) == 0) { 112 | printf("Epoch: %zu, Loss: %f\n", epoch, loss); 113 | } 114 | } 115 | } 116 | 117 | TEST_CASE("array: mean_square_error") { 118 | auto a = mtl::array{1, 2, 3, 4}; 119 | auto b = mtl::array{0, 2, 3, 6}; 120 | auto mean = a.mean_square_error(b); 121 | CHECK(mean == 1.25); 122 | } 123 | 124 | TEST_CASE("2 layer NN: xor") { 125 | auto X = mtl::array{ 126 | {0, 0}, 127 | {0, 1}, 128 | {1, 0}, 129 | {1, 1}, 130 | }; 131 | 132 | auto Y_XOR = mtl::array{ 133 | {0}, 134 | {1}, 135 | {1}, 136 | {0}, 137 | }; 138 | 139 | TwoLayerNeuralNetwork m; 140 | 141 | train(m, X, Y_XOR, 2000, 0.5); 142 | 143 | auto out = predict(m, X); 144 | std::cout << out << std::endl; 145 | } 146 | -------------------------------------------------------------------------------- /test/test_array.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "doctest.h" 7 | 8 | using namespace mtl; 9 | 10 | auto itoa(size_t size, size_t init = 1) { 11 | return std::views::iota(init) | std::views::take(size); 12 | } 13 | 14 | //------------------------------------------------------------------------------ 15 | 16 | TEST_CASE("array: scalar size") { 17 | auto s = array(100); 18 | CHECK(s.element_count() == 1); 19 | CHECK_THROWS_WITH_AS(s.length(), "array: cannot call with a scalar value.", 20 | std::runtime_error); 21 | CHECK(s.dimension() == 0); 22 | CHECK(s.shape() == shape_type{}); 23 | CHECK(s.at() == 100); 24 | } 25 | 26 | //------------------------------------------------------------------------------ 27 | 28 | TEST_CASE("array: vector size") { 29 | auto v = empty({3}); 30 | CHECK(v.element_count() == 3); 31 | CHECK(v.length() == 3); 32 | CHECK(v.dimension() == 1); 33 | CHECK(v.shape() == shape_type{3}); 34 | CHECK(v.shape()[0] == 3); 35 | } 36 | 37 | TEST_CASE("array: vector initializer") { 38 | auto v = array{1, 2, 3, 4}; 39 | CHECK(v.element_count() == 4); 40 | } 41 | 42 | TEST_CASE("vector: container") { 43 | std::vector a{1, 2, 3, 4}; 44 | 45 | auto v1 = array({a.size() - 1}, a); 46 | CHECK(v1.element_count() == 3); 47 | CHECK(array_equal(v1, {1, 2, 3})); 48 | 49 | auto v2 = array({a.size() + 1}, a); 50 | CHECK(v2.element_count() == 5); 51 | CHECK(array_equal(v1, {1, 2, 3})); 52 | 53 | auto v3 = array(a); 54 | CHECK(v3.element_count() == 4); 55 | CHECK(array_equal(v3, {1, 2, 3, 4})); 56 | } 57 | 58 | TEST_CASE("array: vector ranges") { 59 | auto v = array({10}, std::views::iota(1) | std::views::take(10)); 60 | CHECK(v.element_count() == 10); 61 | CHECK(array_equal(v, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); 62 | } 63 | 64 | TEST_CASE("array: vector `clone`") { 65 | auto a = ones({8}); 66 | auto b = a; 67 | a.zeros(); 68 | CHECK(array_equal(a, b)); 69 | 70 | b = a.clone(); 71 | a.ones(); 72 | CHECK(!array_equal(a, b)); 73 | } 74 | 75 | TEST_CASE("array: vector assignment operator") { 76 | auto v = zeros({8}); 77 | for (size_t i = 0; i < v.element_count(); i++) { 78 | v.at(i) = 1; 79 | } 80 | CHECK(array_equal(ones({8}), v)); 
81 | } 82 | 83 | TEST_CASE("array: vector bounds check") { 84 | auto v = array({10}, std::views::iota(0) | std::views::take(10)); 85 | CHECK(v.at(9) == 9); 86 | CHECK_THROWS_WITH_AS(v.at(10), "array: index is out of bounds.", 87 | std::runtime_error); 88 | } 89 | 90 | TEST_CASE("array: vector range-for") { 91 | auto v = zeros({8}); 92 | std::fill(v.buffer_data(), v.buffer_data() + v.buffer_element_count(), 1); 93 | CHECK(array_equal(ones({8}), v)); 94 | } 95 | 96 | TEST_CASE("array: vector arithmatic operations") { 97 | constexpr size_t element_count = 16; 98 | 99 | auto a = array{7.82637e-06, 0.131538, 0.755605, 0.45865, 100 | 0.532767, 0.218959, 0.0470446, 0.678865, 101 | 0.679296, 0.934693, 0.383502, 0.519416, 102 | 0.830965, 0.0345721, 0.0534616, 0.5297}; 103 | 104 | auto b = array{0.671149, 0.00769819, 0.383416, 0.0668422, 105 | 0.417486, 0.686773, 0.588977, 0.930436, 106 | 0.846167, 0.526929, 0.0919649, 0.653919, 107 | 0.415999, 0.701191, 0.910321, 0.762198}; 108 | 109 | CHECK(allclose(a + b, {0.671157, 0.139236, 1.13902, 0.525492, 0.950253, 110 | 0.905732, 0.636021, 1.6093, 1.52546, 1.46162, 0.475467, 111 | 1.17334, 1.24696, 0.735763, 0.963782, 1.2919})); 112 | 113 | CHECK(allclose( 114 | a - b, {-0.671141, 0.12384, 0.372189, 0.391808, 0.115281, -0.467814, 115 | -0.541932, -0.251571, -0.166871, 0.407764, 0.291537, -0.134503, 116 | 0.414966, -0.666619, -0.856859, -0.232498})); 117 | 118 | CHECK(allclose( 119 | a * b, {5.25266e-06, 0.0010126, 0.289711, 0.0306572, 0.222423, 0.150375, 120 | 0.0277082, 0.63164, 0.574798, 0.492517, 0.0352687, 0.339656, 121 | 0.345681, 0.0242416, 0.0486672, 0.403736})); 122 | 123 | CHECK( 124 | allclose(a / b, {1.16612e-05, 17.0869, 1.97072, 6.86168, 1.27613, 125 | 0.318823, 0.0798751, 0.72962, 0.802792, 1.77385, 4.17009, 126 | 0.794312, 1.99752, 0.0493048, 0.0587283, 0.694964})); 127 | } 128 | 129 | TEST_CASE("array: vector arithmatic operation errors") { 130 | auto a = random({4}); 131 | auto b = random({8}); 132 | CHECK(!array_equal(a, b)); 133 | CHECK_THROWS_WITH_AS(a + b, "array: invalid operation.", std::runtime_error); 134 | } 135 | 136 | TEST_CASE("array: vector `pow` operation") { 137 | { 138 | auto a = array{1, 2, 3}; 139 | auto b = array{2, 2, 2}; 140 | CHECK(array_equal(a.pow(b), {1, 4, 9})); 141 | CHECK(array_equal(b.pow(a), {2, 4, 8})); 142 | } 143 | { 144 | auto a = array{1.0, 2.0, 3.0}; 145 | auto b = array{2.0, 2.0, 2.0}; 146 | CHECK(allclose(a.pow(b), {1.0, 4.0, 9.0})); 147 | CHECK(allclose(b.pow(a), {2.0, 4.0, 8.0})); 148 | } 149 | } 150 | 151 | //------------------------------------------------------------------------------ 152 | 153 | TEST_CASE("array: matrix size") { 154 | auto m = empty({3, 4}); 155 | CHECK(m.element_count() == 12); 156 | CHECK(m.shape() == shape_type{3, 4}); 157 | CHECK(m.shape()[0] == 3); 158 | CHECK(m.shape()[1] == 4); 159 | CHECK(m.dimension() == 2); 160 | } 161 | 162 | TEST_CASE("array: matrix container") { 163 | auto m1 = array{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; 164 | CHECK(m1.element_count() == 12); 165 | CHECK(m1.dimension() == 1); 166 | CHECK(m1.shape() == shape_type{12}); 167 | CHECK(m1.strides() == strides_type{1}); 168 | 169 | auto m2 = array{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; 170 | CHECK(m2.element_count() == 12); 171 | CHECK(m2.dimension() == 2); 172 | CHECK(m2.shape() == shape_type{3, 4}); 173 | CHECK(m2.strides() == strides_type{4, 1}); 174 | 175 | auto m3 = array{{{1, 2, 3}, {4, 5, 6}}, {{7, 8, 9}, {10, 11, 12}}}; 176 | CHECK(m3.element_count() == 12); 177 | CHECK(m3.dimension() == 
3); 178 | CHECK(m3.shape() == shape_type{2, 2, 3}); 179 | CHECK(m3.strides() == strides_type{6, 3, 1}); 180 | 181 | CHECK_THROWS_WITH_AS( 182 | (array{{{1, 2, 3}, {4, 5}}, {{7, 8, 9}, {10, 11, 12}}}), 183 | "array: invalid initializer list.", std::runtime_error); 184 | } 185 | 186 | TEST_CASE("array: matrix ranges") { 187 | auto m = array({3, 4}, std::views::iota(1) | std::views::take(12)); 188 | 189 | size_t i = 0; 190 | for (size_t row = 0; row < m.shape()[0]; row++) { 191 | for (size_t col = 0; col < m.shape()[1]; col++) { 192 | CHECK(m.at(row, col) == m.at(i)); 193 | i++; 194 | } 195 | } 196 | 197 | CHECK(array_equal(m, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}})); 198 | } 199 | 200 | TEST_CASE("array: matrix arithmatic operations") { 201 | auto r = itoa(12); 202 | auto a = array({3, 4}, r); 203 | auto b = array({3, 4}, r); 204 | CHECK(array_equal(a + b, {{2, 4, 6, 8}, {10, 12, 14, 16}, {18, 20, 22, 24}})); 205 | CHECK(array_equal(a - b, {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}})); 206 | CHECK(array_equal(a * b, 207 | {{1, 4, 9, 16}, {25, 36, 49, 64}, {81, 100, 121, 144}})); 208 | CHECK(array_equal(a / b, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}})); 209 | } 210 | 211 | TEST_CASE("array: matrix arithmatic operations with scalar") { 212 | auto a = array{{1, 2}, {3, 4}}; 213 | CHECK(array_equal(a + 1, {{2, 3}, {4, 5}})); 214 | CHECK(array_equal(a - 1, {{0, 1}, {2, 3}})); 215 | CHECK(array_equal(a * 2, {{2, 4}, {6, 8}})); 216 | CHECK(array_equal(a / 2, {{0.5, 1}, {1.5, 2}})); 217 | CHECK(array_equal(1 + a, {{2, 3}, {4, 5}})); 218 | CHECK(array_equal(1 - a, {{0, -1}, {-2, -3}})); 219 | CHECK(array_equal(2 * a, {{2, 4}, {6, 8}})); 220 | CHECK(array_equal(2 / a, {{2, 1}, {2.0 / 3.0, 0.5}})); 221 | } 222 | 223 | TEST_CASE("array: matrix v*v `dot` operation") { 224 | auto a = array({4}, itoa(4)); 225 | auto b = array({4}, itoa(4)); 226 | auto out = a.dot(b); 227 | CHECK(out.shape() == shape_type{}); 228 | CHECK(array_equal(out, array(30))); 229 | } 230 | 231 | TEST_CASE("array: matrix m*m `dot` operation") { 232 | auto a = array({3, 4}, itoa(12)); 233 | auto b = array({4, 2}, itoa(8)); 234 | auto out = a.dot(b); 235 | CHECK(out.shape() == shape_type{3, 2}); 236 | CHECK(array_equal(out, {{50, 60}, {114, 140}, {178, 220}})); 237 | } 238 | 239 | TEST_CASE("array: matrix v*m `dot` operation") { 240 | auto a = array({4}, itoa(4)); 241 | auto b = array({4, 2}, itoa(8)); 242 | auto out = a.dot(b); 243 | CHECK(out.shape() == shape_type{2}); 244 | CHECK(array_equal(out, {50, 60})); 245 | } 246 | 247 | TEST_CASE("array: matrix m*v `dot` operation") { 248 | auto a = array({2, 4}, itoa(8)); 249 | auto b = array({4}, itoa(4)); 250 | auto out = a.dot(b); 251 | CHECK(out.shape() == shape_type{2}); 252 | CHECK(array_equal(out, {30, 70})); 253 | } 254 | 255 | TEST_CASE("array: matrix transpose") { 256 | auto v = array{1, 2, 3, 4}; 257 | auto vT = v.transpose(); 258 | CHECK(vT.element_count() == 4); 259 | CHECK(vT.dimension() == 2); 260 | CHECK(vT.shape() == shape_type{1, 4}); 261 | CHECK(array_equal(vT, {{1, 2, 3, 4}})); 262 | 263 | auto vT2 = vT.transpose(); 264 | CHECK(vT2.element_count() == 4); 265 | CHECK(vT2.dimension() == 1); 266 | CHECK(vT2.shape() == shape_type{4}); 267 | CHECK(array_equal(vT2, {1, 2, 3, 4})); 268 | 269 | auto m2 = array{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; 270 | auto m2T = m2.transpose(); 271 | CHECK(m2T.element_count() == 12); 272 | CHECK(m2T.dimension() == 2); 273 | CHECK(m2T.shape() == shape_type{4, 3}); 274 | CHECK(array_equal(m2T, {{1, 5, 9}, {2, 6, 10}, {3, 7, 11}, 
{4, 8, 12}})); 275 | 276 | auto m2T2 = m2T.transpose(); 277 | CHECK(m2T2.element_count() == m2.element_count()); 278 | CHECK(m2T2.dimension() == m2.dimension()); 279 | CHECK(m2T2.shape() == m2.shape()); 280 | CHECK(array_equal(m2T2, m2)); 281 | 282 | auto m3 = array{{{1, 2, 3, 4}, {5, 6, 7, 8}}, 283 | {{9, 10, 11, 12}, {13, 14, 15, 16}}}; 284 | CHECK(m3.element_count() == 16); 285 | CHECK(m3.dimension() == 3); 286 | CHECK(m3.shape() == shape_type{2, 2, 4}); 287 | 288 | auto m3T = m3.transpose(); 289 | CHECK(m3T.element_count() == 16); 290 | CHECK(m3T.dimension() == 3); 291 | CHECK(m3T.shape() == shape_type{4, 2, 2}); 292 | 293 | auto m3T2 = m3T.transpose(); 294 | CHECK(m3T2.element_count() == m3.element_count()); 295 | CHECK(m3T2.dimension() == m3.dimension()); 296 | CHECK(m3T2.shape() == m3.shape()); 297 | CHECK(array_equal(m3T2, m3)); 298 | } 299 | 300 | TEST_CASE("array: matrix broadcast") { 301 | auto a = array{{1, 2, 3}, {4, 5, 6}}; 302 | auto b = a.broadcast({3, 2, 3}); 303 | 304 | CHECK(array_equal(b, {{{1, 2, 3}, {4, 5, 6}}, 305 | {{1, 2, 3}, {4, 5, 6}}, 306 | {{1, 2, 3}, {4, 5, 6}}})); 307 | 308 | CHECK(b.element_count() == 18); 309 | CHECK(b.buffer_element_count() == 6); 310 | CHECK(b.buffer_bytes() == 6 * sizeof(int)); 311 | 312 | CHECK(b.at(0) == 1); 313 | CHECK(b.at(b.element_count() - 1) == 6); 314 | 315 | CHECK(b.at(0, 0, 0) == 1); 316 | CHECK(b.at(1, 1, 0) == 4); 317 | CHECK(b.at(2, 1, 2) == 6); 318 | 319 | CHECK(b.strides().size() == 3); 320 | CHECK(b.strides()[0] == 0); 321 | CHECK(b.strides()[1] == 3); 322 | CHECK(b.strides()[2] == 1); 323 | } 324 | 325 | TEST_CASE("array: matrix arithmatic operations with broadcast") { 326 | auto a_2_3 = array{{1, 2, 3}, {4, 5, 6}}; 327 | auto a_2_2_3 = array{{{1, 2, 3}, {4, 5, 6}}, {{7, 8, 9}, {10, 11, 12}}}; 328 | 329 | auto b = array(1); 330 | auto b_3 = array{1, 2, 3}; 331 | auto b_2_3 = array{{1, 2, 3}, {4, 5, 6}}; 332 | 333 | CHECK(array_equal(a_2_3 + b, {{2, 3, 4}, {5, 6, 7}})); 334 | CHECK(array_equal(a_2_2_3 + b, 335 | {{{2, 3, 4}, {5, 6, 7}}, {{8, 9, 10}, {11, 12, 13}}})); 336 | CHECK(array_equal(a_2_3 + b_3, {{2, 4, 6}, {5, 7, 9}})); 337 | CHECK(array_equal(a_2_2_3 + b_3, 338 | {{{2, 4, 6}, {5, 7, 9}}, {{8, 10, 12}, {11, 13, 15}}})); 339 | CHECK(array_equal(a_2_2_3 + b_2_3, 340 | {{{2, 4, 6}, {8, 10, 12}}, {{8, 10, 12}, {14, 16, 18}}})); 341 | 342 | CHECK(array_equal(b + a_2_3, {{2, 3, 4}, {5, 6, 7}})); 343 | CHECK(array_equal(b + a_2_2_3, 344 | {{{2, 3, 4}, {5, 6, 7}}, {{8, 9, 10}, {11, 12, 13}}})); 345 | CHECK(array_equal(b_3 + a_2_3, {{2, 4, 6}, {5, 7, 9}})); 346 | CHECK(array_equal(b_3 + a_2_2_3, 347 | {{{2, 4, 6}, {5, 7, 9}}, {{8, 10, 12}, {11, 13, 15}}})); 348 | CHECK(array_equal(b_2_3 + a_2_2_3, 349 | {{{2, 4, 6}, {8, 10, 12}}, {{8, 10, 12}, {14, 16, 18}}})); 350 | } 351 | 352 | TEST_CASE("array: matrix slice") { 353 | auto t = array{ 354 | {{1, 2, 3}, {4, 5, 6}}, 355 | {{7, 8, 9}, {10, 11, 12}}, 356 | {{13, 14, 15}, {16, 17, 18}}, 357 | }; 358 | 359 | CHECK_THROWS_WITH_AS(t[3], "array: row is out of bounds.", 360 | std::runtime_error); 361 | 362 | auto m = t[1]; 363 | auto v = m[1]; 364 | auto s = v[1]; 365 | 366 | CHECK(array_equal(m, {{7, 8, 9}, {10, 11, 12}})); 367 | CHECK(array_equal(v, {10, 11, 12})); 368 | CHECK(array_equal(s, array(11))); 369 | 370 | s.at() += 100; 371 | 372 | CHECK(array_equal(t, {{{1, 2, 3}, {4, 5, 6}}, 373 | {{7, 8, 9}, {10, 111, 12}}, 374 | {{13, 14, 15}, {16, 17, 18}}})); 375 | CHECK(array_equal(m, {{7, 8, 9}, {10, 111, 12}})); 376 | CHECK(array_equal(v, {10, 111, 12})); 377 | 
CHECK(array_equal(s, array(111))); 378 | 379 | m.zeros(); 380 | 381 | CHECK(array_equal(t, {{{1, 2, 3}, {4, 5, 6}}, 382 | {{0, 0, 0}, {0, 0, 0}}, 383 | {{13, 14, 15}, {16, 17, 18}}})); 384 | CHECK(array_equal(m, {{0, 0, 0}, {0, 0, 0}})); 385 | CHECK(array_equal(v, {0, 0, 0})); 386 | CHECK(array_equal(s, array(0))); 387 | } 388 | 389 | //------------------------------------------------------------------------------ 390 | 391 | TEST_CASE("array: aggregate functions") { 392 | auto v = array{1, 2, 3, 4, 5, 6}; 393 | 394 | auto t = array{ 395 | {{1, 2, 3}, {4, 5, 6}}, 396 | {{7, 8, 9}, {10, 11, 12}}, 397 | {{13, 14, 15}, {16, 17, 18}}, 398 | }; 399 | 400 | CHECK(v.min() == 1); 401 | CHECK(v.max() == 6); 402 | CHECK(t.min() == 1); 403 | CHECK(t.max() == 18); 404 | 405 | CHECK(v.sum() == 21); 406 | CHECK(t.sum() == 171); 407 | CHECK(array_equal(t.sum(0), {{21, 24, 27}, {30, 33, 36}})); 408 | CHECK(array_equal(t.sum(1), {{5, 7, 9}, {17, 19, 21}, {29, 31, 33}})); 409 | CHECK(array_equal(t.sum(2), {{6, 15}, {24, 33}, {42, 51}})); 410 | CHECK(is_close(array{1.1, 2.2}.sum(), 3.3)); 411 | CHECK(is_close(array{1, 2}.sum(), 3l)); 412 | 413 | CHECK(v.mean() == 3.5); 414 | CHECK(t.mean() == 9.5); 415 | 416 | CHECK(array_equal(t.mean(0), array{{7, 8, 9}, {10, 11, 12}})); 417 | CHECK(array_equal( 418 | t.mean(1), 419 | array{{2.5, 3.5, 4.5}, {8.5, 9.5, 10.5}, {14.5, 15.5, 16.5}})); 420 | CHECK(array_equal(t.mean(2), array{{2, 5}, {8, 11}, {14, 17}})); 421 | } 422 | 423 | TEST_CASE("array: softmax") { 424 | auto v = array{1, 2, 3, 4, 5, 6}; 425 | auto m = array{{7, 8, 9}, {10, 11, 12}}; 426 | 427 | auto vsm = v.softmax(); 428 | auto msm = m.softmax(); 429 | 430 | CHECK(vsm.sum() == 1); 431 | CHECK(vsm.all([](auto x) { return x <= 1; })); 432 | 433 | CHECK(array_equal(msm.sum(1), array{1, 1})); 434 | CHECK(msm.all([](auto x) { return x <= 1; })); 435 | } 436 | 437 | TEST_CASE("array: iterators") { 438 | auto t = array{ 439 | {{1, 2, 3}, {4, 5, 6}}, 440 | {{7, 8, 9}, {10, 11, 12}}, 441 | {{13, 14, 15}, {16, 17, 18}}, 442 | }; 443 | 444 | for (auto row : t) { 445 | for (auto &x : row.elements()) { 446 | x += 100; 447 | } 448 | } 449 | 450 | const auto ct = t; 451 | 452 | int cur = 101; 453 | for (auto row : ct.rows()) { 454 | for (const auto &x : row.elements()) { 455 | CHECK(x == cur++); 456 | } 457 | } 458 | 459 | cur = 101; 460 | for (auto row : ct.rows()) { 461 | for (auto [a, b, c] : row.rows<3>()) { 462 | CHECK(a == cur++); 463 | CHECK(b == cur++); 464 | CHECK(c == cur++); 465 | } 466 | } 467 | } 468 | -------------------------------------------------------------------------------- /test/test_examples.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "doctest.h" 4 | 5 | using namespace mtl; 6 | 7 | TEST_CASE("example: create empty array") { 8 | auto i = empty({2, 3, 2}); 9 | auto f = empty({2, 3, 2}); 10 | // auto d = empty({2, 3}); // cannot compile... 
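// (That commented-out line is presumably an array of double; it is rejected at
// compile time because the `value_type` concept in include/metal.h only admits
// the element types the Metal backend supports.)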
11 | } 12 | 13 | TEST_CASE("example: create array with constants") { 14 | auto s = array(1); 15 | auto v = array{1, 2, 3, 4, 5, 6}; 16 | auto m = array{{1, 2}, {3, 4}, {5, 6}}; 17 | auto t = array{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}; 18 | 19 | // std::cout << s.print_info() << std::endl << s << std::endl << std::endl; 20 | // std::cout << v.print_info() << std::endl << v << std::endl << std::endl; 21 | // std::cout << m.print_info() << std::endl << m << std::endl << std::endl; 22 | // std::cout << t.print_info() << std::endl << t << std::endl << std::endl; 23 | 24 | // dtype: float, dim: 0, shape: {}, strides: {1} 25 | // 1 26 | // 27 | // dtype: float, dim: 1, shape: {6}, strides: {1} 28 | // {1, 2, 3, 4, 5, 6} 29 | // 30 | // dtype: float, dim: 2, shape: {3, 2}, strides: {2, 1} 31 | // {{1, 2}, 32 | // {3, 4}, 33 | // {5, 6}} 34 | // 35 | // dtype: float, dim: 3, shape: {2, 3, 2}, strides: {6, 2, 1} 36 | // {{{1, 2}, 37 | // {3, 4}, 38 | // {5, 6}}, 39 | // 40 | // {{7, 8}, 41 | // {9, 10}, 42 | // {11, 12}}} 43 | } 44 | 45 | TEST_CASE("example: create array with shape") { 46 | auto zeros1 = array({2, 3, 2}, 0); 47 | auto zeros2 = zeros({2, 3, 2}); 48 | CHECK(array_equal(zeros1, zeros2)); 49 | 50 | auto ones1 = array({2, 3, 2}, 1); 51 | auto ones2 = ones({2, 3, 2}); 52 | CHECK(array_equal(ones1, ones2)); 53 | 54 | auto rand = random({2, 3, 2}); 55 | CHECK(rand.all([](auto val) { return 0 <= val && val < 1.0; })); 56 | 57 | auto v = std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; 58 | auto from_iterator = array({2, 3, 2}, v.begin()); 59 | auto from_range1 = array({2, 3, 2}, v); 60 | auto from_range2 = array({2, 3, 2}, std::views::iota(1)); 61 | auto expected = array({ 62 | {{1, 2}, {3, 4}, {5, 6}}, 63 | {{7, 8}, {9, 10}, {11, 12}}, 64 | }); 65 | CHECK(array_equal(from_iterator, expected)); 66 | CHECK(array_equal(from_range1, expected)); 67 | CHECK(array_equal(from_range2, expected)); 68 | } 69 | 70 | TEST_CASE("example: clone array") { 71 | auto a = ones({4}); 72 | 73 | auto cloned = a.clone(); 74 | cloned.zeros(); 75 | CHECK(array_equal(a, {1, 1, 1, 1})); 76 | 77 | auto assigned = a; 78 | assigned.zeros(); 79 | CHECK(array_equal(a, {0, 0, 0, 0})); 80 | } 81 | 82 | TEST_CASE("example: arithmatic operations") { 83 | auto a = array{{1, 2}, {3, 4}}; 84 | auto b = array{{1, 2}, {3, 4}}; 85 | 86 | auto add = a + b; 87 | CHECK(array_equal(add, {{2, 4}, {6, 8}})); 88 | 89 | auto sub = a - b; 90 | CHECK(array_equal(sub, {{0, 0}, {0, 0}})); 91 | 92 | auto mul = a * b; 93 | CHECK(array_equal(mul, {{1, 4}, {9, 16}})); 94 | 95 | auto div = a / b; 96 | CHECK(array_equal(div, {{1, 1}, {1, 1}})); 97 | } 98 | 99 | TEST_CASE("example: dot operation") { 100 | auto x = array{1, 2, 3}; 101 | auto W = array{{1, 2}, {3, 4}, {5, 6}}; 102 | 103 | auto y = x.dot(W); 104 | CHECK(array_equal(y, {22, 28})); 105 | } 106 | 107 | -------------------------------------------------------------------------------- /test/test_perceptron.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "doctest.h" 4 | 5 | class LogicGate { 6 | private: 7 | float w0 = 0.1; 8 | float w1 = 0.1; 9 | float b = 0.1; 10 | 11 | auto range(size_t count) { 12 | return std::views::iota(1) | std::views::take(count); 13 | } 14 | 15 | int predict(int x0, int x1) { 16 | auto y = (x0 * w0) + (x1 * w1) + b; 17 | return y > 0 ? 
1 : 0; 18 | } 19 | 20 | public: 21 | LogicGate(mtl::array&& dataset) { 22 | auto max_iteration = 10; 23 | auto learning_rate = 1.0; 24 | 25 | for (auto n : range(max_iteration)) { 26 | for (auto [x0, x1, t] : dataset.rows<3>()) { 27 | auto y = predict(x0, x1); 28 | auto diff = t - y; 29 | auto update = diff * learning_rate; 30 | 31 | w0 += update * x0; 32 | w1 += update * x1; 33 | b += update; 34 | } 35 | } 36 | } 37 | 38 | int operator()(int x0, int x1) { return predict(x0, x1); } 39 | }; 40 | 41 | TEST_CASE("perceptron: nand") { 42 | auto AND = LogicGate({ 43 | {0, 0, 0}, 44 | {0, 1, 0}, 45 | {1, 0, 0}, 46 | {1, 1, 1}, 47 | }); 48 | 49 | auto OR = LogicGate({ 50 | {0, 0, 0}, 51 | {0, 1, 1}, 52 | {1, 0, 1}, 53 | {1, 1, 1}, 54 | }); 55 | 56 | auto NAND = LogicGate({ 57 | {0, 0, 1}, 58 | {0, 1, 1}, 59 | {1, 0, 1}, 60 | {1, 1, 0}, 61 | }); 62 | 63 | auto XOR = [&](int x0, int x1) { return AND(NAND(x0, x1), OR(x0, x1)); }; 64 | 65 | CHECK(AND(0, 0) == 0); 66 | CHECK(AND(0, 1) == 0); 67 | CHECK(AND(1, 0) == 0); 68 | CHECK(AND(1, 1) == 1); 69 | 70 | CHECK(OR(0, 0) == 0); 71 | CHECK(OR(0, 1) == 1); 72 | CHECK(OR(1, 0) == 1); 73 | CHECK(OR(1, 1) == 1); 74 | 75 | CHECK(NAND(0, 0) == 1); 76 | CHECK(NAND(0, 1) == 1); 77 | CHECK(NAND(1, 0) == 1); 78 | CHECK(NAND(1, 1) == 0); 79 | 80 | CHECK(XOR(0, 0) == 0); 81 | CHECK(XOR(0, 1) == 1); 82 | CHECK(XOR(1, 0) == 1); 83 | CHECK(XOR(1, 1) == 0); 84 | } 85 | --------------------------------------------------------------------------------
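For reference, the training loop in the `LogicGate` constructor above is the classic single-layer perceptron rule; a sketch of the math it performs, with η standing for `learning_rate` and t for the target column of each dataset row:

```latex
y = \mathbf{1}\!\left[\, w_0 x_0 + w_1 x_1 + b > 0 \,\right], \qquad
w_i \leftarrow w_i + \eta \, (t - y) \, x_i, \qquad
b \leftarrow b + \eta \, (t - y)
```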