├── .gitattributes ├── tinnet_examples ├── mnist │ ├── src │ │ └── main.cpp │ └── BUILD └── basic │ ├── src │ └── main.cpp │ └── BUILD ├── .bazelrc ├── tinnet ├── includes │ ├── node │ │ ├── Type.h │ │ ├── kernel │ │ │ ├── MathFunction.h │ │ │ ├── NNFunction.h │ │ │ └── BasicArithmetic.h │ │ ├── Node.h │ │ ├── Shape.h │ │ └── Builder.h │ ├── platform │ │ ├── Platform.h │ │ └── CallingConvention.h │ ├── compute │ │ ├── Denormal.h │ │ └── GEMM.h │ └── memory │ │ └── ScopedStorage.h ├── tests │ ├── memory │ │ └── scopedstorage.cpp │ ├── node │ │ ├── kernel │ │ │ ├── neg.cpp │ │ │ ├── log.cpp │ │ │ ├── add.cpp │ │ │ ├── sub.cpp │ │ │ ├── mul.cpp │ │ │ ├── relu.cpp │ │ │ └── div.cpp │ │ └── shape.cpp │ └── helper │ │ └── Random.h ├── src │ ├── compute │ │ ├── Denormal.cpp │ │ └── GEMM.cpp │ ├── node │ │ ├── kernel │ │ │ ├── MathFunction.cpp │ │ │ ├── NNFunction.cpp │ │ │ └── BasicArithmetic.cpp │ │ ├── Node.cpp │ │ ├── Shape.cpp │ │ └── Builder.cpp │ └── memory │ │ └── ScopedStorage.cpp └── BUILD ├── README.md ├── .gitignore ├── WORKSPACE └── .clang-format /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto -------------------------------------------------------------------------------- /tinnet_examples/mnist/src/main.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | 4 | int main() 5 | { 6 | return 0; 7 | } -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | build --crosstool_top=@llvm_toolchain//:toolchain 2 | test --crosstool_top=@llvm_toolchain//:toolchain -------------------------------------------------------------------------------- /tinnet_examples/basic/src/main.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | 6 | int main() 7 | { 8 | return 0; 9 | } -------------------------------------------------------------------------------- /tinnet/includes/node/Type.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_NODE_TYPE_H 3 | 4 | #define _TINNET_NODE_TYPE_H 5 | 6 | namespace tinnet::node { 7 | enum class Type { F32 }; 8 | } 9 | 10 | #endif -------------------------------------------------------------------------------- /tinnet/includes/platform/Platform.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_PLATFORM_PLATFORM_H 3 | 4 | #define _TINNET_PLATFORM_PLATFORM_H 5 | 6 | #ifdef _MSC_VER 7 | 8 | # define _TINNET_PLATFORM_WINDOWS 9 | 10 | #else 11 | 12 | # define _TINNET_PLATFORM_UNIX 13 | 14 | #endif 15 | 16 | #endif -------------------------------------------------------------------------------- /tinnet_examples/mnist/BUILD: -------------------------------------------------------------------------------- 1 | # https://docs.bazel.build/versions/master/be/c-cpp.html#cc_binary 2 | cc_binary( 3 | name = "mnist", 4 | srcs = ["src/main.cpp"], 5 | copts = [ 6 | "-std=c++17", 7 | "-O3 -mllvm -polly", 8 | ], 9 | deps = ["//tinnet:main"], 10 | ) 11 | -------------------------------------------------------------------------------- /tinnet_examples/basic/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_binary") 2 | 3 | # https://docs.bazel.build/versions/master/be/c-cpp.html#cc_binary 4 | 
cc_binary( 5 | name = "basic", 6 | srcs = ["src/main.cpp"], 7 | copts = [ 8 | "-std=c++17", 9 | "-O3 -mllvm -polly", 10 | ], 11 | deps = ["//tinnet:main"], 12 | ) 13 | -------------------------------------------------------------------------------- /tinnet/includes/compute/Denormal.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_COMPUTE_DENORMAL_H 3 | 4 | #define _TINNET_COMPUTE_DENORMAL_H 5 | 6 | namespace tinnet::compute { 7 | class Denormal final { 8 | public: 9 | const bool bFTZ; 10 | const bool bDAZ; 11 | 12 | public: 13 | Denormal(); 14 | ~Denormal(); 15 | }; 16 | } // namespace tinnet::compute 17 | 18 | #endif -------------------------------------------------------------------------------- /tinnet/includes/node/kernel/MathFunction.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_NODE_KERNEL_MATHFUNCTION_H 3 | 4 | #define _TINNET_NODE_KERNEL_MATHFUNCTION_H 5 | 6 | #include "tinnet/includes/memory/ScopedStorage.h" 7 | #include "tinnet/includes/node/Node.h" 8 | 9 | namespace tinnet::node::kernel { 10 | memory::ScopedStorage __kernel__log(Node *pNode); 11 | void __kernel__logGradient(Node *pNode, Node *pDeps); 12 | } // namespace tinnet::node::kernel 13 | 14 | #endif -------------------------------------------------------------------------------- /tinnet/includes/node/kernel/NNFunction.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_NODE_KERNEL_NNFUNCTION_H 3 | 4 | #define _TINNET_NODE_KERNEL_NNFUNCTION_H 5 | 6 | #include "tinnet/includes/memory/ScopedStorage.h" 7 | #include "tinnet/includes/node/Node.h" 8 | 9 | namespace tinnet::node::kernel { 10 | memory::ScopedStorage __kernel__relu(Node *pNode, float nA); 11 | void __kernel__reluGradient(Node *pNode, Node *pDeps, float nA); 12 | } // namespace tinnet::node::kernel 13 | 14 | #endif -------------------------------------------------------------------------------- /tinnet/tests/memory/scopedstorage.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/memory/ScopedStorage.h" 3 | 4 | #include "catch2.hpp" 5 | 6 | #include 7 | #include 8 | 9 | TEST_CASE("tinnet::memory::ScopedStorage") 10 | { 11 | SECTION("Alignment") 12 | { 13 | for (std::uint64_t nIndex{0}; nIndex < 1024; ++nIndex) { 14 | tinnet::memory::ScopedStorage sStorage{sizeof(float)}; 15 | 16 | CHECK((reinterpret_cast(sStorage.aligned()) % 32 == 0) == true); 17 | } 18 | } 19 | } -------------------------------------------------------------------------------- /tinnet/includes/platform/CallingConvention.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_PLATFORM_CALLINGCONVENTION_H 3 | 4 | #define _TINNET_PLATFORM_CALLINGCONVENTION_H 5 | 6 | #ifdef _TINNET_PLATFORM_WINDOWS 7 | 8 | # define _TINNET_CDECL __cdecl 9 | # define _TINNET_STDCALL __stdcall 10 | # define _TINNET_REGCALL __regcall 11 | 12 | #else 13 | 14 | # define _TINNET_CDECL __attribute__((cdecl)) 15 | # define _TINNET_STDCALL __attribute__((stdcall)) 16 | # define _TINNET_REGCALL __attribute__((regcall)) 17 | 18 | #endif 19 | 20 | #endif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TinNet@2 2 | 3 | A compact DNN library. 
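4 | 5 | ## Quick Example 6 | 7 | The snippet below is a minimal sketch of the graph API, pieced together from the unit tests under `tinnet/tests/node/kernel`; the include paths are taken from those tests, but treat the exact calls as illustrative rather than canonical. 8 | 9 | ``` 10 | #include "tinnet/includes/node/Builder.h" 11 | #include "tinnet/includes/node/Shape.h" 12 | 13 | int main() 14 | { 15 | 	float sLeft[]{1.f, 2.f}, sRight[]{3.f, 4.f}; 16 | 17 | 	// Wrap raw buffers into graph nodes, with gradients enabled. 18 | 	auto pLeft{tinnet::node::Builder::memory(tinnet::node::Shape{{2}}, sLeft, true)}; 19 | 	auto pRight{tinnet::node::Builder::memory(tinnet::node::Shape{{2}}, sRight, true)}; 20 | 21 | 	// Element-wise addition via the overloaded operator, then backpropagation. 22 | 	auto pSum{pLeft + pRight}; 23 | 	pSum->computeGradient(); 24 | 25 | 	return pSum->output().aligned<float>()[0] == 4.f ? 0 : 1; 26 | } 27 | ```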
28 | 29 | ## Build 30 | 31 | This project uses [Bazel](https://bazel.build/) (1.0 or above) as its build system and compiles with [Clang](https://clang.llvm.org/), which you do **NOT** need to install yourself; the toolchain is fetched automatically. 32 | 33 | To build, issue the command below. 34 | 35 | ``` 36 | bazel build //tinnet:main 37 | ``` 38 | 39 | Windows is not currently supported. 40 | 41 | ## Test 42 | 43 | To run all tests, issue the command below. 44 | 45 | ``` 46 | bazel test //tinnet:test 47 | ``` 48 | 49 | ## Example Usages 50 | 51 | Please refer to the example projects. 52 | 53 | ## Benchmarks 54 | 55 | TODO 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore backup files. 2 | *~ 3 | # Ignore Vim swap files. 4 | .*.swp 5 | # Ignore files generated by IDEs. 6 | /.classpath 7 | /.factorypath 8 | /.idea/ 9 | /.ijwb/ 10 | /.project 11 | /.settings 12 | /.vs/ 13 | /.vscode/ 14 | /bazel.iml 15 | # Ignore all bazel-* symlinks. There is no full list since this can change 16 | # based on the name of the directory bazel is cloned into. 17 | /bazel-* 18 | # Ignore outputs generated during Bazel bootstrapping. 19 | /output/ 20 | # Ignore jekyll build output. 21 | /production 22 | /.sass-cache 23 | # Bazelisk version file 24 | .bazelversion 25 | # User-specific .bazelrc 26 | user.bazelrc 27 | 28 | .DS_Store -------------------------------------------------------------------------------- /tinnet/src/compute/Denormal.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/compute/Denormal.h" 3 | 4 | #include <immintrin.h> 5 | 6 | namespace tinnet::compute { 7 | Denormal::Denormal() : 8 | bFTZ{_MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON}, 9 | bDAZ{_MM_GET_DENORMALS_ZERO_MODE() == _MM_DENORMALS_ZERO_ON} { 10 | _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); 11 | _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); 12 | } 13 | 14 | Denormal::~Denormal() { 15 | _MM_SET_FLUSH_ZERO_MODE(this->bFTZ ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF); 16 | _MM_SET_DENORMALS_ZERO_MODE(this->bDAZ ? 
_MM_DENORMALS_ZERO_ON : _MM_DENORMALS_ZERO_OFF); 17 | } 18 | } // namespace tinnet::compute -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | http_archive( 4 | name = "catch2", 5 | strip_prefix = "catch2-bazel-2.11.1", 6 | url = "https://github.com/AcrylicShrimp/catch2-bazel/archive/v2.11.1.tar.gz", 7 | ) 8 | 9 | http_archive( 10 | name = "com_grail_bazel_toolchain", 11 | strip_prefix = "bazel-toolchain-master", 12 | urls = ["https://github.com/grailbio/bazel-toolchain/archive/master.tar.gz"], 13 | ) 14 | 15 | load("@com_grail_bazel_toolchain//toolchain:deps.bzl", "bazel_toolchain_dependencies") 16 | 17 | bazel_toolchain_dependencies() 18 | 19 | load("@com_grail_bazel_toolchain//toolchain:rules.bzl", "llvm_toolchain") 20 | 21 | llvm_toolchain( 22 | name = "llvm_toolchain", 23 | llvm_version = "9.0.0", 24 | ) 25 | 26 | load("@llvm_toolchain//:toolchains.bzl", "llvm_register_toolchains") 27 | 28 | llvm_register_toolchains() 29 | -------------------------------------------------------------------------------- /tinnet/tests/node/kernel/neg.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "catch2.hpp" 3 | #include "tinnet/includes/node/Builder.h" 4 | #include "tinnet/includes/node/Shape.h" 5 | #include "tinnet/tests/helper/Random.h" 6 | 7 | #include <cstddef> 8 | 9 | TEST_CASE("tinnet::node::kernel::BasicArithmetic neg") 10 | { 11 | auto nLength{tinnet::test::helper::Random::genIndex()}; 12 | 13 | auto sNode{tinnet::test::helper::Random::genData(nLength)}; 14 | auto pNode{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sNode.data(), true)}; 15 | auto pResult{-pNode}; 16 | 17 | SECTION("Check forward") 18 | { 19 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 20 | CHECK(pResult->output().aligned<float>()[nIndex] == Approx(-pNode->output().aligned<float>()[nIndex])); 21 | } 22 | 23 | pResult->computeGradient(); 24 | 25 | SECTION("Check backward") 26 | { 27 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 28 | CHECK(pNode->gradient().aligned<float>()[nIndex] == Approx(-1.f)); 29 | } 30 | } -------------------------------------------------------------------------------- /tinnet/includes/memory/ScopedStorage.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_MEMORY_SCOPEDSTORAGE_H 3 | 4 | #define _TINNET_MEMORY_SCOPEDSTORAGE_H 5 | 6 | #include <cstddef> 7 | 8 | namespace tinnet::memory { 9 | class ScopedStorage final { 10 | private: 11 | void *pOrigin; 12 | void *pAligned; 13 | 14 | public: 15 | ScopedStorage(); 16 | ScopedStorage(std::size_t nSize); 17 | ScopedStorage(ScopedStorage &&sRhs) noexcept; 18 | ~ScopedStorage() noexcept; 19 | 20 | public: 21 | ScopedStorage &operator=(ScopedStorage &&sRhs); 22 | 23 | public: 24 | inline void *origin() const noexcept 25 | { 26 | return this->pOrigin; 27 | } 28 | template<class T> 29 | inline T *origin() const noexcept 30 | { 31 | return reinterpret_cast<T *>(this->pOrigin); 32 | } 33 | inline void *aligned() const noexcept 34 | { 35 | return this->pAligned; 36 | } 37 | template<class T> 38 | inline T *aligned() const noexcept 39 | { 40 | return reinterpret_cast<T *>(this->pAligned); 41 | } 42 | 43 | public: 44 | // TODO: Provide swap support here. 
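45 | // One possible shape for it, mirroring the move operations (a sketch, not yet part of the API): 46 | // friend void swap(ScopedStorage &sLhs, ScopedStorage &sRhs) noexcept 47 | // { std::swap(sLhs.pOrigin, sRhs.pOrigin); std::swap(sLhs.pAligned, sRhs.pAligned); }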
48 | }; 49 | } // namespace tinnet::memory 50 | 51 | #endif -------------------------------------------------------------------------------- /tinnet/tests/node/kernel/log.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "catch2.hpp" 3 | #include "tinnet/includes/node/Builder.h" 4 | #include "tinnet/includes/node/Shape.h" 5 | #include "tinnet/tests/helper/Random.h" 6 | 7 | #include <cmath> 8 | #include <cstddef> 9 | 10 | TEST_CASE("tinnet::node::kernel::MathFunction log") 11 | { 12 | auto nLength{tinnet::test::helper::Random::genIndex()}; 13 | 14 | auto sNode{tinnet::test::helper::Random::genPositiveData(nLength)}; 15 | auto pNode{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sNode.data(), true)}; 16 | auto pResult{tinnet::node::Builder::log(pNode)}; 17 | 18 | SECTION("Check forward") 19 | { 20 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 21 | CHECK( 22 | pResult->output().aligned<float>()[nIndex] 23 | == Approx(std::log(pNode->output().aligned<float>()[nIndex] + 1e-5f))); 24 | } 25 | 26 | pResult->computeGradient(); 27 | 28 | SECTION("Check backward") 29 | { 30 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 31 | CHECK( 32 | pNode->gradient().aligned<float>()[nIndex] 33 | == Approx(1.f / (pNode->output().aligned<float>()[nIndex] + 1e-5f))); 34 | } 35 | } -------------------------------------------------------------------------------- /tinnet/src/node/kernel/MathFunction.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/node/kernel/MathFunction.h" 3 | 4 | #include <cmath> 5 | #include <stdexcept> 6 | 7 | namespace tinnet::node::kernel { 8 | memory::ScopedStorage __kernel__log(Node *pNode) 9 | { 10 | if (!pNode) throw std::runtime_error{"invalid node"}; 11 | 12 | auto nSize{pNode->sShape.size()}; 13 | memory::ScopedStorage sResult{sizeof(float) * nSize}; 14 | 15 | auto *__restrict pD{sResult.aligned<float>()}; 16 | const auto *__restrict pL{pNode->output().aligned<float>()}; 17 | 18 | for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] = std::log(pL[nIndex] + 1e-5f); 19 | 20 | return sResult; 21 | } 22 | 23 | void __kernel__logGradient(Node *pNode, Node *pDeps) 24 | { 25 | auto nSize{pNode->sShape.size()}; 26 | auto *__restrict pD{pDeps->gradient().aligned<float>()}; 27 | const auto *__restrict pG{pNode->gradient().aligned<float>()}; 28 | const auto *__restrict pL{pNode->deps()[0]->output().aligned<float>()}; 29 | 30 | for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] += pG[nIndex] / (pL[nIndex] + 1e-5f); 31 | } 32 | } // namespace tinnet::node::kernel -------------------------------------------------------------------------------- /tinnet/includes/node/kernel/BasicArithmetic.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_NODE_KERNEL_BASICARITHMETIC_H 3 | 4 | #define _TINNET_NODE_KERNEL_BASICARITHMETIC_H 5 | 6 | #include "tinnet/includes/memory/ScopedStorage.h" 7 | #include "tinnet/includes/node/Node.h" 8 | 9 | namespace tinnet::node::kernel { 10 | memory::ScopedStorage __kernel__neg(Node *pNode); 11 | memory::ScopedStorage __kernel__add(Node *pLeft, Node *pRight); 12 | memory::ScopedStorage __kernel__sub(Node *pLeft, Node *pRight); 13 | memory::ScopedStorage __kernel__mul(Node *pLeft, Node *pRight); 14 | memory::ScopedStorage __kernel__div(Node *pLeft, Node *pRight); 15 | void __kernel__negGradient(Node *pNode, Node *pDeps); 16 | void __kernel__addGradient(Node *pNode, Node *pDeps); 17 | void __kernel__subLGradient(Node *pNode, Node *pDeps); 
18 | void __kernel__subRGradient(Node *pNode, Node *pDeps); 19 | void __kernel__mulLGradient(Node *pNode, Node *pDeps); 20 | void __kernel__mulRGradient(Node *pNode, Node *pDeps); 21 | void __kernel__divLGradient(Node *pNode, Node *pDeps); 22 | void __kernel__divRGradient(Node *pNode, Node *pDeps); 23 | } // namespace tinnet::node::kernel 24 | 25 | #endif -------------------------------------------------------------------------------- /tinnet/src/memory/ScopedStorage.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/memory/ScopedStorage.h" 3 | 4 | #include <cstdint> 5 | #include <cstdlib> 6 | #include <utility> 7 | 8 | namespace tinnet::memory { 9 | ScopedStorage::ScopedStorage() : pOrigin{nullptr}, pAligned{nullptr} {} 10 | 11 | ScopedStorage::ScopedStorage(std::size_t nSize) 12 | { 13 | this->pOrigin = std::malloc(nSize + 32); 14 | this->pAligned = reinterpret_cast<void *>( 15 | (reinterpret_cast<std::uintptr_t>(this->pOrigin) + 31) & ~31); // 32-byte alignment: round the origin up to the next 32-byte boundary. 16 | } 17 | 18 | ScopedStorage::ScopedStorage(ScopedStorage &&sRhs) noexcept : ScopedStorage() 19 | { 20 | using std::swap; 21 | 22 | swap(this->pOrigin, sRhs.pOrigin); 23 | swap(this->pAligned, sRhs.pAligned); 24 | } 25 | 26 | ScopedStorage::~ScopedStorage() noexcept 27 | { 28 | if (this->pOrigin) std::free(this->pOrigin); 29 | } 30 | 31 | ScopedStorage &ScopedStorage::operator=(ScopedStorage &&sRhs) 32 | { 33 | if (&sRhs == this) return *this; 34 | 35 | this->~ScopedStorage(); 36 | 37 | this->pOrigin = sRhs.pOrigin; 38 | this->pAligned = sRhs.pAligned; 39 | sRhs.pOrigin = nullptr; 40 | sRhs.pAligned = nullptr; 41 | 42 | return *this; 43 | } 44 | } // namespace tinnet::memory -------------------------------------------------------------------------------- /tinnet/tests/node/kernel/add.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "catch2.hpp" 3 | #include "tinnet/includes/node/Builder.h" 4 | #include "tinnet/includes/node/Shape.h" 5 | #include "tinnet/tests/helper/Random.h" 6 | 7 | #include <cstddef> 8 | 9 | TEST_CASE("tinnet::node::kernel::BasicArithmetic add") 10 | { 11 | auto nLength{tinnet::test::helper::Random::genIndex()}; 12 | 13 | auto sLeft{tinnet::test::helper::Random::genData(nLength)}; 14 | auto sRight{tinnet::test::helper::Random::genData(nLength)}; 15 | 16 | auto pLeft{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sLeft.data(), true)}; 17 | auto pRight{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sRight.data(), true)}; 18 | auto pResult{pLeft + pRight}; 19 | 20 | SECTION("Check forward") 21 | { 22 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 23 | CHECK(pResult->output().aligned<float>()[nIndex] == Approx(sLeft[nIndex] + sRight[nIndex])); 24 | } 25 | 26 | pResult->computeGradient(); 27 | 28 | SECTION("Check backward") 29 | { 30 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) { 31 | CHECK(pLeft->gradient().aligned<float>()[nIndex] == Approx(1.f)); 32 | CHECK(pRight->gradient().aligned<float>()[nIndex] == Approx(1.f)); 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /tinnet/tests/node/kernel/sub.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "catch2.hpp" 3 | #include "tinnet/includes/node/Builder.h" 4 | #include "tinnet/includes/node/Shape.h" 5 | #include "tinnet/tests/helper/Random.h" 6 | 7 | #include <cstddef> 8 | 9 | TEST_CASE("tinnet::node::kernel::BasicArithmetic sub") 10 | { 11 | auto 
nLength{tinnet::test::helper::Random::genIndex()}; 12 | 13 | auto sLeft{tinnet::test::helper::Random::genData(nLength)}; 14 | auto sRight{tinnet::test::helper::Random::genData(nLength)}; 15 | 16 | auto pLeft{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sLeft.data(), true)}; 17 | auto pRight{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sRight.data(), true)}; 18 | auto pResult{pLeft - pRight}; 19 | 20 | SECTION("Check forward") 21 | { 22 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 23 | CHECK(pResult->output().aligned()[nIndex] == Approx(sLeft[nIndex] - sRight[nIndex])); 24 | } 25 | 26 | pResult->computeGradient(); 27 | 28 | SECTION("Check backward") 29 | { 30 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) { 31 | CHECK(pLeft->gradient().aligned()[nIndex] == Approx(1.f)); 32 | CHECK(pRight->gradient().aligned()[nIndex] == Approx(-1.f)); 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /tinnet/tests/node/kernel/mul.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "catch2.hpp" 3 | #include "tinnet/includes/node/Builder.h" 4 | #include "tinnet/includes/node/Shape.h" 5 | #include "tinnet/tests/helper/Random.h" 6 | 7 | #include 8 | 9 | TEST_CASE("tinnet::node::kernel::BasicArithmetic mul") 10 | { 11 | auto nLength{tinnet::test::helper::Random::genIndex()}; 12 | 13 | auto sLeft{tinnet::test::helper::Random::genData(nLength)}; 14 | auto sRight{tinnet::test::helper::Random::genData(nLength)}; 15 | 16 | auto pLeft{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sLeft.data(), true)}; 17 | auto pRight{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sRight.data(), true)}; 18 | auto pResult{pLeft * pRight}; 19 | 20 | SECTION("Check forward") 21 | { 22 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 23 | CHECK(pResult->output().aligned()[nIndex] == Approx(sLeft[nIndex] * sRight[nIndex])); 24 | } 25 | 26 | pResult->computeGradient(); 27 | 28 | SECTION("Check backward") 29 | { 30 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) { 31 | CHECK(pLeft->gradient().aligned()[nIndex] == Approx(pRight->output().aligned()[nIndex])); 32 | CHECK(pRight->gradient().aligned()[nIndex] == Approx(pLeft->output().aligned()[nIndex])); 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /tinnet/tests/node/kernel/relu.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "catch2.hpp" 3 | #include "tinnet/includes/node/Builder.h" 4 | #include "tinnet/includes/node/Shape.h" 5 | #include "tinnet/tests/helper/Random.h" 6 | 7 | #include 8 | 9 | TEST_CASE("tinnet::node::kernel::NNFunction relu") 10 | { 11 | auto nLength{tinnet::test::helper::Random::genIndex()}; 12 | 13 | auto sNode{tinnet::test::helper::Random::genData(nLength)}; 14 | auto pNode{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sNode.data(), true)}; 15 | auto pResult{tinnet::node::Builder::relu(pNode, .1f)}; 16 | 17 | SECTION("Check forward") 18 | { 19 | auto fRectify{[](float nV, float nA) { 20 | return nV < .0f ? 
nA * nV : nV; 21 | }}; 22 | 23 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 24 | CHECK( 25 | pResult->output().aligned()[nIndex] 26 | == Approx(fRectify(pNode->output().aligned()[nIndex], .1f))); 27 | } 28 | 29 | pResult->computeGradient(); 30 | 31 | SECTION("Check backward") 32 | { 33 | auto fRectify{[](float nV, float nA) { 34 | return nV < .0f ? nA : 1.f; 35 | }}; 36 | 37 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 38 | CHECK( 39 | pNode->gradient().aligned()[nIndex] 40 | == Approx(fRectify(pNode->output().aligned()[nIndex], .1f))); 41 | } 42 | } -------------------------------------------------------------------------------- /tinnet/src/node/kernel/NNFunction.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/node/kernel/NNFunction.h" 3 | 4 | #include 5 | 6 | namespace tinnet::node::kernel { 7 | memory::ScopedStorage __kernel__relu(Node *pNode, float nA) 8 | { 9 | if (!pNode) throw std::runtime_error{"invalid node"}; 10 | 11 | auto nSize{pNode->sShape.size()}; 12 | memory::ScopedStorage sResult{sizeof(float) * nSize}; 13 | 14 | auto *__restrict pD{sResult.aligned()}; 15 | const auto *__restrict pL{pNode->output().aligned()}; 16 | 17 | auto fRectify{[](float nV, float nA) { 18 | return nV < .0f ? nA * nV : nV; 19 | }}; 20 | 21 | for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] = fRectify(pL[nIndex], nA); 22 | 23 | return sResult; 24 | } 25 | 26 | void __kernel__reluGradient(Node *pNode, Node *pDeps, float nA) 27 | { 28 | auto nSize{pNode->sShape.size()}; 29 | auto *__restrict pD{pDeps->gradient().aligned()}; 30 | const auto *__restrict pG{pNode->gradient().aligned()}; 31 | const auto *__restrict pL{pNode->deps()[0]->output().aligned()}; 32 | 33 | auto fRectify{[](float nV, float nA, float nG) { 34 | return nV < .0f ? 
nA * nG : nG; 35 | }}; 36 | 37 | for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] += fRectify(pL[nIndex], nA, pG[nIndex]); 38 | } 39 | } // namespace tinnet::node::kernel -------------------------------------------------------------------------------- /tinnet/tests/node/kernel/div.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "catch2.hpp" 3 | #include "tinnet/includes/node/Builder.h" 4 | #include "tinnet/includes/node/Shape.h" 5 | #include "tinnet/tests/helper/Random.h" 6 | 7 | #include <cstddef> 8 | 9 | TEST_CASE("tinnet::node::kernel::BasicArithmetic div") 10 | { 11 | auto nLength{tinnet::test::helper::Random::genIndex()}; 12 | 13 | auto sLeft{tinnet::test::helper::Random::genData(nLength)}; 14 | auto sRight{tinnet::test::helper::Random::genData(nLength)}; 15 | 16 | auto pLeft{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sLeft.data(), true)}; 17 | auto pRight{tinnet::node::Builder::memory(tinnet::node::Shape{{nLength}}, sRight.data(), true)}; 18 | auto pResult{pLeft / pRight}; 19 | 20 | SECTION("Check forward") 21 | { 22 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) 23 | CHECK(pResult->output().aligned<float>()[nIndex] == Approx(sLeft[nIndex] / (sRight[nIndex] + 1e-5f))); 24 | } 25 | 26 | pResult->computeGradient(); 27 | 28 | SECTION("Check backward") 29 | { 30 | for (std::size_t nIndex{0}; nIndex < nLength; ++nIndex) { 31 | CHECK( 32 | pLeft->gradient().aligned<float>()[nIndex] 33 | == Approx(1.f / (pRight->output().aligned<float>()[nIndex] + 1e-5f))); 34 | CHECK( 35 | pRight->gradient().aligned<float>()[nIndex] 36 | == Approx( 37 | pLeft->output().aligned<float>()[nIndex] 38 | / (pRight->output().aligned<float>()[nIndex] * pRight->output().aligned<float>()[nIndex] + 1e-5f))); 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /tinnet/tests/helper/Random.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_TEST_HELPER_RANDOM_H 3 | 4 | #define _TINNET_TEST_HELPER_RANDOM_H 5 | 6 | #include <cstddef> 7 | #include <functional> 8 | #include <random> 9 | #include <vector> 10 | 11 | namespace tinnet::test::helper { 12 | class Random final { 13 | public: 14 | Random() = delete; 15 | ~Random() noexcept = delete; 16 | 17 | public: 18 | static std::vector<float> genPositiveData(std::size_t nLength) 19 | { 20 | std::mt19937 sMT{std::random_device{}()}; 21 | std::uniform_real_distribution<float> sDist{.0f, 1024.f}; 22 | 23 | std::vector<float> sResult(nLength, .0f); 24 | 25 | for (auto &nElement: sResult) nElement = sDist(sMT); 26 | 27 | return sResult; 28 | } 29 | 30 | static std::vector<float> genData(std::size_t nLength) 31 | { 32 | std::mt19937 sMT{std::random_device{}()}; 33 | std::uniform_real_distribution<float> sDist{-1024.f, 1024.f}; 34 | 35 | std::vector<float> sResult(nLength, .0f); 36 | 37 | for (auto &nElement: sResult) nElement = sDist(sMT); 38 | 39 | return sResult; 40 | } 41 | 42 | static std::size_t genIndex() 43 | { 44 | std::mt19937 sMT{std::random_device{}()}; 45 | std::uniform_int_distribution<std::size_t> sDist{1, 128}; 46 | 47 | return sDist(sMT); 48 | } 49 | 50 | static void loopIndex(std::size_t nMaxIndex, const std::function<void(std::size_t)> &fFunc) 51 | { 52 | for (std::size_t nIndex{0}; nIndex < nMaxIndex; ++nIndex) fFunc(nIndex); 53 | } 54 | }; 55 | } // namespace tinnet::test::helper 56 | 57 | #endif -------------------------------------------------------------------------------- /tinnet/includes/node/Node.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_NODE_NODE_H 3 | 4 | 
#define _TINNET_NODE_NODE_H 5 | 6 | #include "tinnet/includes/memory/ScopedStorage.h" 7 | #include "tinnet/includes/node/Shape.h" 8 | #include "tinnet/includes/node/Type.h" 9 | 10 | #include <cstddef> 11 | #include <functional> 12 | #include <vector> 13 | 14 | namespace tinnet::node { 15 | class Node { 16 | public: 17 | using GFunc = std::function<void(Node *, Node *)>; 18 | 19 | public: 20 | const Type eType; 21 | const Shape sShape; 22 | const std::size_t nElement; 23 | 24 | private: 25 | bool bGradientEnabled; 26 | memory::ScopedStorage sOutput; 27 | memory::ScopedStorage sGradient; 28 | std::vector<Node *> sDeps; // Nodes that this instance depends on. 29 | std::vector<GFunc> sGFunction; 30 | 31 | public: 32 | Node( 33 | Type eType, 34 | Shape && sShape, 35 | bool bGradientEnabled, 36 | memory::ScopedStorage &&sOutput, 37 | std::vector<Node *> && sDeps, 38 | std::vector<GFunc> && sGFunction); 39 | virtual ~Node() noexcept = default; 40 | 41 | public: 42 | void computeGradient(); 43 | 44 | public: 45 | bool gradientEnabled() const noexcept 46 | { 47 | return this->bGradientEnabled; 48 | } 49 | const memory::ScopedStorage &output() const noexcept 50 | { 51 | return this->sOutput; 52 | } 53 | const memory::ScopedStorage &gradient() const noexcept 54 | { 55 | return this->sGradient; 56 | } 57 | const std::vector<Node *> &deps() const noexcept 58 | { 59 | return this->sDeps; 60 | } 61 | }; 62 | } // namespace tinnet::node 63 | 64 | #endif -------------------------------------------------------------------------------- /tinnet/includes/compute/GEMM.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_COMPUTE_GEMM_H 3 | 4 | #define _TINNET_COMPUTE_GEMM_H 5 | 6 | #include "tinnet/includes/platform/CallingConvention.h" 7 | 8 | #include <cstddef> 9 | 10 | namespace tinnet::compute { 11 | class GEMM final { 12 | public: 13 | GEMM() = delete; 14 | ~GEMM() = delete; 15 | 16 | public: 17 | static void _TINNET_REGCALL multiply( 18 | std::size_t nMaxIndex, 19 | std::size_t nRow, 20 | std::size_t nColumn, 21 | const float *__restrict pL, 22 | const float *__restrict pR, 23 | float *__restrict pD) noexcept; 24 | static void _TINNET_REGCALL multiplyAdd( 25 | std::size_t nMaxIndex, 26 | std::size_t nRow, 27 | std::size_t nColumn, 28 | const float *__restrict pL, 29 | const float *__restrict pR, 30 | float *__restrict pD) noexcept; 31 | 32 | static void _TINNET_REGCALL dMultiplyLeft( 33 | std::size_t nMaxIndex, 34 | std::size_t nRow, 35 | std::size_t nColumn, 36 | const float *__restrict pG, 37 | const float *__restrict pR, 38 | float *__restrict pD) noexcept; 39 | static void _TINNET_REGCALL dMultiplyAddLeft( 40 | std::size_t nMaxIndex, 41 | std::size_t nRow, 42 | std::size_t nColumn, 43 | const float *__restrict pG, 44 | const float *__restrict pR, 45 | float *__restrict pD) noexcept; 46 | 47 | static void _TINNET_REGCALL dMultiplyRight( 48 | std::size_t nMaxIndex, 49 | std::size_t nRow, 50 | std::size_t nColumn, 51 | const float *__restrict pG, 52 | const float *__restrict pL, 53 | float *__restrict pD) noexcept; 54 | static void _TINNET_REGCALL dMultiplyAddRight( 55 | std::size_t nMaxIndex, 56 | std::size_t nRow, 57 | std::size_t nColumn, 58 | const float *__restrict pG, 59 | const float *__restrict pL, 60 | float *__restrict pD) noexcept; 61 | }; 62 | } // namespace tinnet::compute 63 | 64 | #endif -------------------------------------------------------------------------------- /tinnet/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") 2 | 3 | # 
https://docs.bazel.build/versions/master/be/c-cpp.html#cc_library 4 | cc_library( 5 | name = "main", 6 | srcs = [ 7 | "src/compute/Denormal.cpp", 8 | "src/compute/GEMM.cpp", 9 | "src/memory/ScopedStorage.cpp", 10 | "src/node/Builder.cpp", 11 | "src/node/Node.cpp", 12 | "src/node/Shape.cpp", 13 | "src/node/kernel/BasicArithmetic.cpp", 14 | "src/node/kernel/MathFunction.cpp", 15 | "src/node/kernel/NNFunction.cpp", 16 | ], 17 | hdrs = [ 18 | "includes/compute/Denormal.h", 19 | "includes/compute/GEMM.h", 20 | "includes/memory/ScopedStorage.h", 21 | "includes/node/Builder.h", 22 | "includes/node/Node.h", 23 | "includes/node/Shape.h", 24 | "includes/node/Type.h", 25 | "includes/node/kernel/BasicArithmetic.h", 26 | "includes/node/kernel/MathFunction.h", 27 | "includes/node/kernel/NNFunction.h", 28 | "includes/platform/CallingConvention.h", 29 | "includes/platform/Platform.h", 30 | ], 31 | copts = [ 32 | "-std=c++17", 33 | "-ffast-math", 34 | "-fopenmp", 35 | "-mavx2", 36 | "-mfma", 37 | "-O3 -mllvm -polly", 38 | ], 39 | visibility = ["//visibility:public"], 40 | ) 41 | 42 | cc_test( 43 | name = "test", 44 | timeout = "short", 45 | srcs = [ 46 | "tests/helper/Random.h", 47 | "tests/memory/scopedstorage.cpp", 48 | "tests/node/kernel/add.cpp", 49 | "tests/node/kernel/div.cpp", 50 | "tests/node/kernel/log.cpp", 51 | "tests/node/kernel/mul.cpp", 52 | "tests/node/kernel/neg.cpp", 53 | "tests/node/kernel/relu.cpp", 54 | "tests/node/kernel/sub.cpp", 55 | "tests/node/shape.cpp", 56 | ], 57 | copts = [ 58 | "-std=c++17", 59 | "-ffast-math", 60 | "-O3 -mllvm -polly", 61 | ], 62 | deps = [ 63 | "//tinnet:main", 64 | "@catch2//:main", 65 | ], 66 | ) 67 | -------------------------------------------------------------------------------- /tinnet/src/node/Node.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/node/Node.h" 3 | 4 | #include <algorithm> 5 | #include <tuple> 6 | #include <vector> 7 | 8 | namespace tinnet::node { 9 | Node::Node( 10 | Type eType, 11 | Shape && sShape, 12 | bool bGradientEnabled, 13 | memory::ScopedStorage &&sOutput, 14 | std::vector<Node *> && sDeps, 15 | std::vector<GFunc> && sGFunction) : 16 | eType{eType}, 17 | sShape{std::move(sShape)}, 18 | nElement{this->sShape.size()}, 19 | bGradientEnabled{bGradientEnabled}, 20 | sOutput{std::move(sOutput)}, 21 | sDeps{std::move(sDeps)}, 22 | sGFunction{std::move(sGFunction)} 23 | { 24 | if (!this->bGradientEnabled) 25 | for (auto pDepsNode: this->sDeps) 26 | if ((this->bGradientEnabled = this->bGradientEnabled || pDepsNode->bGradientEnabled)) break; 27 | 28 | if (this->bGradientEnabled) this->sGradient = memory::ScopedStorage{sizeof(float) * this->sShape.size()}; 29 | } 30 | 31 | void Node::computeGradient() 32 | { 33 | if (!this->bGradientEnabled) return; 34 | 35 | // Fills the target gradient buffer with ones. 36 | std::fill(this->sGradient.aligned<float>(), this->sGradient.aligned<float>() + this->nElement, 1.f); 37 | 38 | // Builds a dependency list covering only the nodes that have gradients enabled. 
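39 | // The loop below is a breadth-first worklist over the dependency graph: it is seeded with this 40 | // node's direct dependencies and grows as every visited node contributes its own dependencies. 41 | // A node reachable along several paths is enqueued once per path, so its gradient accumulates via +=.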
42 | std::vector<std::tuple<Node *, Node *, std::size_t>> sDepsChain; 43 | sDepsChain.reserve(this->sDeps.size()); 44 | 45 | for (std::size_t nD{0}, nMaxD{this->sDeps.size()}; nD < nMaxD; ++nD) 46 | if (this->sDeps[nD]->bGradientEnabled) sDepsChain.emplace_back(this, this->sDeps[nD], nD); 47 | 48 | for (std::size_t nIndex{0}; nIndex != sDepsChain.size();) 49 | for (std::size_t nMaxIndex{sDepsChain.size()}; nIndex < nMaxIndex; ++nIndex) { 50 | auto *pNode{std::get<1>(sDepsChain[nIndex])}; 51 | 52 | sDepsChain.reserve(sDepsChain.size() + pNode->sDeps.size()); 53 | 54 | for (std::size_t nD{0}, nMaxD{pNode->sDeps.size()}; nD < nMaxD; ++nD) 55 | if (pNode->sDeps[nD]->bGradientEnabled) sDepsChain.emplace_back(pNode, pNode->sDeps[nD], nD); 56 | } 57 | 58 | // TODO: Support multiple types here. 59 | // Zero-fill the gradient buffer of every dependent node before accumulation. 60 | for (auto &sChainDeps: sDepsChain) { 61 | auto *pNode{std::get<1>(sChainDeps)}; 62 | 63 | std::fill(pNode->sGradient.aligned<float>(), pNode->sGradient.aligned<float>() + pNode->nElement, .0f); 64 | } 65 | 66 | // Calls all gradient computation kernels. 67 | for (auto &sChainDeps: sDepsChain) { 68 | auto nD{std::get<2>(sChainDeps)}; 69 | auto *pNode{std::get<0>(sChainDeps)}; 70 | auto *pDepsNode{std::get<1>(sChainDeps)}; 71 | 72 | pNode->sGFunction[nD](pNode, pDepsNode); 73 | } 74 | } 75 | } // namespace tinnet::node -------------------------------------------------------------------------------- /tinnet/includes/node/Shape.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_NODE_SHAPE_H 3 | 4 | #define _TINNET_NODE_SHAPE_H 5 | 6 | #include <cstddef> 7 | #include <functional> 8 | #include <numeric> 9 | #include <ostream> 10 | #include <utility> 11 | #include <vector> 12 | 13 | namespace tinnet::node { 14 | class Shape final { 15 | private: 16 | std::vector<std::size_t> sDimension; 17 | 18 | public: 19 | Shape(); 20 | Shape(const std::vector<std::size_t> &sRhs); 21 | Shape(std::vector<std::size_t> &&sRhs) noexcept; 22 | Shape(const Shape &sRhs) = default; 23 | Shape(Shape &&sRhs) noexcept = default; 24 | ~Shape() noexcept = default; 25 | 26 | public: 27 | Shape & operator=(const std::vector<std::size_t> &sRhs); 28 | Shape & operator=(std::vector<std::size_t> &&sRhs) noexcept; 29 | Shape & operator=(const Shape &sRhs) = default; 30 | Shape & operator=(Shape &&sRhs) noexcept = default; 31 | bool operator==(const Shape &sRhs) const; 32 | friend bool operator==(const Shape &sLhs, const std::vector<std::size_t> &sRhs); 33 | friend bool operator==(const std::vector<std::size_t> &sLhs, const Shape &sRhs); 34 | bool operator!=(const Shape &sRhs) const; 35 | friend bool operator!=(const Shape &sLhs, const std::vector<std::size_t> &sRhs); 36 | friend bool operator!=(const std::vector<std::size_t> &sLhs, const Shape &sRhs); 37 | friend std::ostream &operator<<(std::ostream &sLhs, const Shape &sRhs); 38 | Shape extend() const; 39 | Shape extend(std::size_t nRank) const; 40 | Shape shrink() const; 41 | Shape squeeze() const; 42 | static Shape broadcast(const Shape &sLhs, const Shape &sRhs); 43 | 44 | public: 45 | std::size_t &operator[](std::size_t nIndex) 46 | { 47 | return this->sDimension[nIndex]; 48 | } 49 | std::size_t operator[](std::size_t nIndex) const 50 | { 51 | return this->sDimension[nIndex]; 52 | } 53 | std::size_t rank() const noexcept 54 | { 55 | return this->sDimension.size(); 56 | } 57 | std::size_t size() const noexcept 58 | { 59 | return std::accumulate( 60 | this->sDimension.cbegin(), 61 | this->sDimension.cend(), 62 | 1, 63 | std::multiplies<std::size_t>{}); 64 | } 65 | bool scalar() const noexcept 66 | { 67 | return this->rank() == 0; 68 | } 69 | bool vector() const noexcept 70 | { 71 | return this->rank() == 1; 72 | } 73 | bool matrix() 
const noexcept 74 | { 75 | return this->rank() == 2; 76 | } 77 | bool tensor() const noexcept 78 | { 79 | return this->rank() >= 3; 80 | } 81 | 82 | public: 83 | friend void swap(Shape &sLhs, Shape &sRhs) noexcept 84 | { 85 | using std::swap; 86 | swap(sLhs.sDimension, sRhs.sDimension); 87 | } 88 | }; 89 | 90 | } // namespace tinnet::node 91 | 92 | #endif -------------------------------------------------------------------------------- /tinnet/includes/node/Builder.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _TINNET_NODE_BUILDER_H 3 | 4 | #define _TINNET_NODE_BUILDER_H 5 | 6 | #include "tinnet/includes/memory/ScopedStorage.h" 7 | #include "tinnet/includes/node/Node.h" 8 | #include "tinnet/includes/node/Shape.h" 9 | #include "tinnet/includes/node/Type.h" 10 | 11 | #include <memory> 12 | 13 | namespace tinnet::node { 14 | class Builder final { 15 | public: 16 | Builder() = delete; 17 | ~Builder() = delete; 18 | 19 | public: 20 | // TODO: Support multiple types here. 21 | // static std::unique_ptr<Node> memory(Type eType, Shape &&sShape, bool bGradientEnabled, std::uint8_t 22 | // *pOutput); 23 | static std::unique_ptr<Node> memory(Shape &&sShape, const float *pSource, bool bGradientEnabled = false); 24 | static std::unique_ptr<Node> neg(const std::unique_ptr<Node> &sLeft, bool bGradientEnabled = false); 25 | static std::unique_ptr<Node> 26 | add(const std::unique_ptr<Node> &sLeft, const std::unique_ptr<Node> &sRight, bool bGradientEnabled = false); 27 | static std::unique_ptr<Node> 28 | sub(const std::unique_ptr<Node> &sLeft, const std::unique_ptr<Node> &sRight, bool bGradientEnabled = false); 29 | static std::unique_ptr<Node> 30 | mul(const std::unique_ptr<Node> &sLeft, const std::unique_ptr<Node> &sRight, bool bGradientEnabled = false); 31 | static std::unique_ptr<Node> 32 | div(const std::unique_ptr<Node> &sLeft, const std::unique_ptr<Node> &sRight, bool bGradientEnabled = false); 33 | 34 | static std::unique_ptr<Node> log(const std::unique_ptr<Node> &sLeft, bool bGradientEnabled = false); 35 | 36 | static std::unique_ptr<Node> 37 | relu(const std::unique_ptr<Node> &sLeft, float nA = .1f, bool bGradientEnabled = false); 38 | }; 39 | 40 | } // namespace tinnet::node 41 | 42 | inline std::unique_ptr<tinnet::node::Node> operator-(std::unique_ptr<tinnet::node::Node> &sNode) 43 | { 44 | return tinnet::node::Builder::neg(sNode, false); 45 | } 46 | 47 | inline std::unique_ptr<tinnet::node::Node> 48 | operator+(std::unique_ptr<tinnet::node::Node> &sLeft, std::unique_ptr<tinnet::node::Node> &sRight) 49 | { 50 | return tinnet::node::Builder::add(sLeft, sRight, false); 51 | } 52 | 53 | inline std::unique_ptr<tinnet::node::Node> 54 | operator-(std::unique_ptr<tinnet::node::Node> &sLeft, std::unique_ptr<tinnet::node::Node> &sRight) 55 | { 56 | return tinnet::node::Builder::sub(sLeft, sRight, false); 57 | } 58 | 59 | inline std::unique_ptr<tinnet::node::Node> 60 | operator*(std::unique_ptr<tinnet::node::Node> &sLeft, std::unique_ptr<tinnet::node::Node> &sRight) 61 | { 62 | return tinnet::node::Builder::mul(sLeft, sRight, false); 63 | } 64 | 65 | inline std::unique_ptr<tinnet::node::Node> 66 | operator/(std::unique_ptr<tinnet::node::Node> &sLeft, std::unique_ptr<tinnet::node::Node> &sRight) 67 | { 68 | return tinnet::node::Builder::div(sLeft, sRight, false); 69 | } 70 | 71 | #endif -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | 2 | BasedOnStyle: LLVM 3 | 4 | ColumnLimit: 120 5 | 6 | TabWidth: 4 7 | IndentWidth: 4 8 | ContinuationIndentWidth: 4 9 | AccessModifierOffset: -4 10 | 11 | UseTab: Always 12 | 13 | Language: Cpp 14 | Standard: Cpp11 15 | 16 | KeepEmptyLinesAtTheStartOfBlocks: false 17 | MaxEmptyLinesToKeep: 1 18 | 19 | AlignAfterOpenBracket: AlwaysBreak 20 | AlignConsecutiveAssignments: true 21 | 
AlignConsecutiveDeclarations: true 22 | # AlignConsecutiveMacros: true 23 | AlignEscapedNewlines: Right 24 | AlignOperands: true 25 | AlignTrailingComments: true 26 | 27 | AllowAllArgumentsOnNextLine: false 28 | AllowAllConstructorInitializersOnNextLine: false 29 | AllowAllParametersOfDeclarationOnNextLine: false 30 | AllowShortBlocksOnASingleLine: true # Always 31 | AllowShortCaseLabelsOnASingleLine: true 32 | AllowShortFunctionsOnASingleLine: Empty 33 | AllowShortIfStatementsOnASingleLine: true # Always 34 | AllowShortLambdasOnASingleLine: Empty 35 | AllowShortLoopsOnASingleLine: true 36 | AlwaysBreakAfterReturnType: None 37 | AlwaysBreakBeforeMultilineStrings: false 38 | AlwaysBreakTemplateDeclarations: Yes 39 | 40 | BinPackArguments: false 41 | BinPackParameters: false 42 | 43 | BreakBeforeBinaryOperators: All 44 | BreakBeforeTernaryOperators: true 45 | BreakConstructorInitializers: AfterColon 46 | BreakInheritanceList: AfterColon 47 | BreakStringLiterals: true 48 | BreakBeforeBraces: Custom 49 | BraceWrapping: 50 | AfterCaseLabel: false 51 | AfterClass: false 52 | AfterEnum: false 53 | AfterFunction: true 54 | AfterNamespace: false 55 | AfterStruct: false 56 | AfterUnion: false 57 | AfterExternBlock: false 58 | BeforeCatch: true 59 | BeforeElse: true 60 | IndentBraces: false 61 | # BraceWrappingAfterControlStatementStyle: Never 62 | 63 | SpaceAfterCStyleCast: false 64 | SpaceAfterLogicalNot: false 65 | SpaceAfterTemplateKeyword: false 66 | SpaceBeforeAssignmentOperators: true 67 | SpaceBeforeCpp11BracedList: false 68 | SpaceBeforeCtorInitializerColon: true 69 | SpaceBeforeInheritanceColon: true 70 | SpaceBeforeParens: ControlStatements 71 | SpaceBeforeRangeBasedForLoopColon: false 72 | # SpaceInEmptyBlock: false 73 | SpaceInEmptyParentheses: false 74 | SpacesBeforeTrailingComments: 4 75 | SpacesInAngles: false 76 | SpacesInCStyleCastParentheses: false 77 | SpacesInContainerLiterals: false 78 | SpacesInParentheses: false 79 | SpacesInSquareBrackets: false 80 | 81 | IndentCaseLabels: false 82 | # IndentGotoLabels: false 83 | IndentPPDirectives: AfterHash 84 | IndentWrappedFunctionNames: true 85 | NamespaceIndentation: All 86 | 87 | CompactNamespaces: false 88 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 89 | Cpp11BracedListStyle: true 90 | 91 | ReflowComments: true 92 | FixNamespaceComments: true 93 | 94 | IncludeBlocks: Regroup 95 | SortIncludes: true 96 | SortUsingDeclarations: true 97 | 98 | DerivePointerAlignment: false 99 | PointerAlignment: Right 100 | 101 | ForEachMacros: ['BOOST_FOREACH'] -------------------------------------------------------------------------------- /tinnet/src/node/Shape.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/node/Shape.h" 3 | 4 | #include 5 | #include 6 | 7 | namespace tinnet::node { 8 | Shape::Shape() 9 | { 10 | this->sDimension.reserve(4); 11 | } 12 | 13 | Shape::Shape(const std::vector &sRhs) : sDimension{sRhs.cbegin(), sRhs.cend()} {} 14 | 15 | Shape::Shape(std::vector &&sRhs) noexcept : sDimension{std::move(sRhs)} {} 16 | 17 | Shape &Shape::operator=(const std::vector &sRhs) 18 | { 19 | this->sDimension = sRhs; 20 | return *this; 21 | } 22 | 23 | Shape &Shape::operator=(std::vector &&sRhs) noexcept 24 | { 25 | this->sDimension = std::move(sRhs); 26 | return *this; 27 | } 28 | 29 | bool Shape::operator==(const Shape &sRhs) const 30 | { 31 | return this->sDimension == sRhs.sDimension; 32 | } 33 | 34 | bool operator==(const Shape &sLhs, const std::vector &sRhs) 35 | 
{ 36 | return sLhs.sDimension == sRhs; 37 | } 38 | 39 | bool operator==(const std::vector &sLhs, const Shape &sRhs) 40 | { 41 | return sLhs == sRhs.sDimension; 42 | } 43 | 44 | bool Shape::operator!=(const Shape &sRhs) const 45 | { 46 | return this->sDimension != sRhs.sDimension; 47 | } 48 | 49 | bool operator!=(const Shape &sLhs, const std::vector &sRhs) 50 | { 51 | return sLhs.sDimension != sRhs; 52 | } 53 | 54 | bool operator!=(const std::vector &sLhs, const Shape &sRhs) 55 | { 56 | return sLhs != sRhs.sDimension; 57 | } 58 | 59 | std::ostream &operator<<(std::ostream &sLhs, const Shape &sRhs) 60 | { 61 | sLhs << "["; 62 | 63 | for (std::size_t nIndex{0}, nMaxIndex{sRhs.sDimension.size()}; nIndex < nMaxIndex; ++nIndex) { 64 | sLhs << sRhs.sDimension[nIndex]; 65 | 66 | if (nIndex + 1 != nMaxIndex) sLhs << ", "; 67 | } 68 | 69 | return sLhs << "]"; 70 | } 71 | 72 | Shape Shape::extend() const 73 | { 74 | auto sResult{*this}; 75 | sResult.sDimension.insert(sResult.sDimension.cbegin(), 1); 76 | 77 | return sResult; 78 | } 79 | 80 | Shape Shape::extend(std::size_t nRank) const 81 | { 82 | auto sResult{*this}; 83 | 84 | while (sResult.rank() < nRank) sResult.sDimension.insert(sResult.sDimension.cbegin(), 1); 85 | 86 | return sResult; 87 | } 88 | 89 | Shape Shape::shrink() const 90 | { 91 | auto sResult{*this}; 92 | 93 | while (!sResult.sDimension.empty() && sResult.sDimension.front() == 1) 94 | sResult.sDimension.erase(sResult.sDimension.cbegin()); 95 | 96 | return sResult; 97 | } 98 | 99 | Shape Shape::squeeze() const 100 | { 101 | Shape sResult; 102 | 103 | for (auto nSize: this->sDimension) 104 | if (nSize != 1) sResult.sDimension.emplace_back(nSize); 105 | 106 | return sResult; 107 | } 108 | 109 | Shape Shape::broadcast(const Shape &sLhs, const Shape &sRhs) 110 | { 111 | Shape sResult; 112 | 113 | auto iL{sLhs.sDimension.crbegin()}, iLEnd{sLhs.sDimension.crend()}, iR{sRhs.sDimension.crbegin()}, 114 | iREnd{sRhs.sDimension.crend()}; 115 | 116 | // TODO: Add special case handling - dimension contains 0 - here. 
117 | 118 | for (; iL != iLEnd && iR != iREnd; ++iL, ++iR) { 119 | if (*iL != *iR && *iL != 1 && *iR != 1) throw std::runtime_error{"unable to broadcast"}; 120 | 121 | sResult.sDimension.insert(sResult.sDimension.cbegin(), std::max(*iL, *iR)); 122 | } 123 | 124 | if (iL != iLEnd) 125 | for (; iL != iLEnd; ++iL) sResult.sDimension.insert(sResult.sDimension.cbegin(), *iL); 126 | else if (iR != iREnd) 127 | for (; iR != iREnd; ++iR) sResult.sDimension.insert(sResult.sDimension.cbegin(), *iR); 128 | 129 | return sResult; 130 | } 131 | } // namespace tinnet::node -------------------------------------------------------------------------------- /tinnet/src/node/Builder.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/node/Builder.h" 3 | 4 | #include "tinnet/includes/node/kernel/BasicArithmetic.h" 5 | #include "tinnet/includes/node/kernel/MathFunction.h" 6 | #include "tinnet/includes/node/kernel/NNFunction.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace tinnet::node { 14 | // std::unique_ptr Builder::wrap(Type eType, Shape &&sShape, bool bGradientEnabled, std::uint8_t *pOutput) 15 | // { 16 | // return std::make_unique( 17 | // eType, 18 | // std::move(sShape), 19 | // bGradientEnabled, 20 | // pOutput, 21 | // std::vector{}, 22 | // std::vector{}); 23 | // } 24 | 25 | std::unique_ptr Builder::memory(Shape &&sShape, const float *pSource, bool bGradientEnabled) 26 | { 27 | memory::ScopedStorage sOutput{sizeof(float) * sShape.size()}; 28 | std::copy(pSource, pSource + sShape.size(), sOutput.aligned()); 29 | 30 | return std::make_unique( 31 | Type::F32, 32 | std::move(sShape), 33 | bGradientEnabled, 34 | std::move(sOutput), 35 | std::vector{}, 36 | std::vector{}); 37 | } 38 | 39 | std::unique_ptr Builder::neg(const std::unique_ptr &sLeft, bool bGradientEnabled) 40 | { 41 | return std::make_unique( 42 | Type::F32, 43 | Shape{sLeft->sShape}, 44 | bGradientEnabled, 45 | kernel::__kernel__neg(sLeft.get()), 46 | std::vector{sLeft.get()}, 47 | std::vector{&kernel::__kernel__negGradient}); 48 | } 49 | 50 | std::unique_ptr 51 | Builder::add(const std::unique_ptr &sLeft, const std::unique_ptr &sRight, bool bGradientEnabled) 52 | { 53 | return std::make_unique( 54 | Type::F32, 55 | Shape{sLeft->sShape}, 56 | bGradientEnabled, 57 | kernel::__kernel__add(sLeft.get(), sRight.get()), 58 | std::vector{sLeft.get(), sRight.get()}, 59 | std::vector{&kernel::__kernel__addGradient, &kernel::__kernel__addGradient}); 60 | } 61 | 62 | std::unique_ptr 63 | Builder::sub(const std::unique_ptr &sLeft, const std::unique_ptr &sRight, bool bGradientEnabled) 64 | { 65 | return std::make_unique( 66 | Type::F32, 67 | Shape{sLeft->sShape}, 68 | bGradientEnabled, 69 | kernel::__kernel__sub(sLeft.get(), sRight.get()), 70 | std::vector{sLeft.get(), sRight.get()}, 71 | std::vector{&kernel::__kernel__subLGradient, &kernel::__kernel__subRGradient}); 72 | } 73 | 74 | std::unique_ptr 75 | Builder::mul(const std::unique_ptr &sLeft, const std::unique_ptr &sRight, bool bGradientEnabled) 76 | { 77 | return std::make_unique( 78 | Type::F32, 79 | Shape{sLeft->sShape}, 80 | bGradientEnabled, 81 | kernel::__kernel__mul(sLeft.get(), sRight.get()), 82 | std::vector{sLeft.get(), sRight.get()}, 83 | std::vector{&kernel::__kernel__mulLGradient, &kernel::__kernel__mulRGradient}); 84 | } 85 | 86 | std::unique_ptr 87 | Builder::div(const std::unique_ptr &sLeft, const std::unique_ptr &sRight, bool bGradientEnabled) 88 | { 89 | return std::make_unique( 90 | 
Type::F32, 91 | Shape{sLeft->sShape}, 92 | bGradientEnabled, 93 | kernel::__kernel__div(sLeft.get(), sRight.get()), 94 | std::vector{sLeft.get(), sRight.get()}, 95 | std::vector{&kernel::__kernel__divLGradient, &kernel::__kernel__divRGradient}); 96 | } 97 | 98 | std::unique_ptr Builder::log(const std::unique_ptr &sLeft, bool bGradientEnabled) 99 | { 100 | return std::make_unique( 101 | Type::F32, 102 | Shape{sLeft->sShape}, 103 | bGradientEnabled, 104 | kernel::__kernel__log(sLeft.get()), 105 | std::vector{sLeft.get()}, 106 | std::vector{&kernel::__kernel__logGradient}); 107 | } 108 | 109 | std::unique_ptr Builder::relu(const std::unique_ptr &sLeft, float nA, bool bGradientEnabled) 110 | { 111 | return std::make_unique( 112 | Type::F32, 113 | Shape{sLeft->sShape}, 114 | bGradientEnabled, 115 | kernel::__kernel__relu(sLeft.get(), nA), 116 | std::vector{sLeft.get()}, 117 | std::vector{[nA](auto *pNode, auto *pGradient) { 118 | kernel::__kernel__reluGradient(pNode, pGradient, nA); 119 | }}); 120 | } 121 | } // namespace tinnet::node -------------------------------------------------------------------------------- /tinnet/src/compute/GEMM.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "tinnet/includes/compute/GEMM.h" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace tinnet::compute { 10 | void _TINNET_REGCALL GEMM::multiply( 11 | std::size_t nMaxIndex, 12 | std::size_t nRow, 13 | std::size_t nColumn, 14 | const float *__restrict pL, 15 | const float *__restrict pR, 16 | float *__restrict pD) noexcept 17 | { 18 | std::fill(pD, pD + nRow * nColumn, .0f); 19 | GEMM::multiplyAdd(nMaxIndex, nRow, nColumn, pL, pR, pD); 20 | } 21 | 22 | void _TINNET_REGCALL GEMM::multiplyAdd( 23 | std::size_t nMaxIndex, 24 | std::size_t nRow, 25 | std::size_t nColumn, 26 | const float *__restrict pL, 27 | const float *__restrict pR, 28 | float *__restrict pD) noexcept 29 | { 30 | #pragma omp parallel for schedule(guided) default(shared) num_threads(static_cast (std::max ( \ 31 | 1u, \ 32 | std::min (nMaxIndex * nRow * nColumn / 1600000u, std::thread::hardware_concurrency())))) 33 | for (std::int64_t nR = 0; nR < static_cast(nRow); ++nR) 34 | for (std::size_t nC{0}; nC < nColumn; ++nC) { 35 | std::size_t nIndex{0}; 36 | auto sSum = _mm256_setzero_ps(); 37 | 38 | for (; nIndex + 8 <= nMaxIndex; nIndex += 8) 39 | sSum = _mm256_fmadd_ps( 40 | _mm256_loadu_ps(pL + nR * nMaxIndex + nIndex), 41 | _mm256_loadu_ps(pR + nC * nMaxIndex + nIndex), 42 | sSum); 43 | 44 | const auto sSum128 = _mm_add_ps(_mm256_extractf128_ps(sSum, 1), _mm256_castps256_ps128(sSum)); 45 | const auto sSum64 = _mm_add_ps(sSum128, _mm_movehl_ps(sSum128, sSum128)); 46 | const auto sSum32 = _mm_add_ss(sSum64, _mm_shuffle_ps(sSum64, sSum64, 0x55)); 47 | auto nSum = _mm_cvtss_f32(sSum32); 48 | 49 | for (; nIndex < nMaxIndex; ++nIndex) nSum += pL[nR * nMaxIndex + nIndex] * pR[nC * nMaxIndex + nIndex]; 50 | 51 | pD[nR * nColumn + nC] += nSum; 52 | } 53 | } 54 | 55 | void _TINNET_REGCALL GEMM::dMultiplyLeft( 56 | std::size_t nMaxIndex, 57 | std::size_t nRow, 58 | std::size_t nColumn, 59 | const float *__restrict pG, 60 | const float *__restrict pR, 61 | float *__restrict pD) noexcept 62 | { 63 | std::fill(pD, pD + nRow * nMaxIndex, .0f); 64 | GEMM::dMultiplyAddLeft(nMaxIndex, nRow, nColumn, pG, pR, pD); 65 | } 66 | 67 | void _TINNET_REGCALL GEMM::dMultiplyAddLeft( 68 | std::size_t nMaxIndex, 69 | std::size_t nRow, 70 | std::size_t nColumn, 71 | const float *__restrict pG, 72 | const 
float *__restrict pR, 73 | float *__restrict pD) noexcept 74 | { 75 | std::vector sRightTransposed(nMaxIndex * nColumn); 76 | auto *__restrict pRT{sRightTransposed.data()}; 77 | 78 | #pragma omp parallel default(shared) num_threads(static_cast (std::max ( \ 79 | 1u, \ 80 | std::min (nMaxIndex * nColumn / 1600000u, std::thread::hardware_concurrency())))) 81 | { 82 | #pragma omp for schedule(guided) 83 | for (std::int64_t nR = 0; nR < static_cast(nColumn); ++nR) 84 | for (std::size_t nC{0}; nC < nMaxIndex; ++nC) pRT[nC * nColumn + nR] = pR[nR * nMaxIndex + nC]; 85 | 86 | #pragma omp for schedule(guided) 87 | for (std::int64_t nR = 0; nR < static_cast(nRow); ++nR) 88 | for (std::size_t nC{0}; nC < nMaxIndex; ++nC) { 89 | std::size_t nIndex{0}; 90 | auto sSum = _mm256_setzero_ps(); 91 | 92 | for (; nIndex + 8 <= nColumn; nIndex += 8) 93 | sSum = _mm256_fmadd_ps( 94 | _mm256_loadu_ps(pG + nR * nColumn + nIndex), 95 | _mm256_loadu_ps(pRT + nC * nColumn + nIndex), 96 | sSum); 97 | 98 | const auto sSum128 = _mm_add_ps(_mm256_extractf128_ps(sSum, 1), _mm256_castps256_ps128(sSum)); 99 | const auto sSum64 = _mm_add_ps(sSum128, _mm_movehl_ps(sSum128, sSum128)); 100 | const auto sSum32 = _mm_add_ss(sSum64, _mm_shuffle_ps(sSum64, sSum64, 0x55)); 101 | auto nSum = _mm_cvtss_f32(sSum32); 102 | 103 | for (; nIndex < nColumn; ++nIndex) nSum += pG[nR * nColumn + nIndex] * pRT[nC * nColumn + nIndex]; 104 | 105 | pD[nR * nMaxIndex + nC] += nSum; 106 | } 107 | } 108 | } 109 | 110 | void _TINNET_REGCALL GEMM::dMultiplyRight( 111 | std::size_t nMaxIndex, 112 | std::size_t nRow, 113 | std::size_t nColumn, 114 | const float *__restrict pG, 115 | const float *__restrict pL, 116 | float *__restrict pD) noexcept 117 | { 118 | std::fill(pD, pD + nColumn * nMaxIndex, .0f); 119 | GEMM::dMultiplyAddRight(nMaxIndex, nRow, nColumn, pG, pL, pD); 120 | } 121 | 122 | void _TINNET_REGCALL GEMM::dMultiplyAddRight( 123 | std::size_t nMaxIndex, 124 | std::size_t nRow, 125 | std::size_t nColumn, 126 | const float *__restrict pG, 127 | const float *__restrict pL, 128 | float *__restrict pD) noexcept 129 | { 130 | std::vector sGradientTransposed(nRow * nColumn); 131 | std::vector sLeftTransposed(nRow * nMaxIndex); 132 | auto *__restrict pGT{sGradientTransposed.data()}; 133 | auto *__restrict pLT{sLeftTransposed.data()}; 134 | 135 | #pragma omp parallel default(shared) num_threads(static_cast (std::max ( \ 136 | 1u, \ 137 | std::min (nMaxIndex * nRow * nColumn / 1600000u, std::thread::hardware_concurrency())))) 138 | { 139 | #pragma omp for schedule(guided) nowait 140 | for (std::int64_t nR = 0; nR < static_cast(nRow); ++nR) 141 | for (std::size_t nC{0}; nC < nColumn; ++nC) pGT[nC * nRow + nR] = pG[nR * nColumn + nC]; 142 | 143 | #pragma omp for schedule(guided) 144 | for (std::int64_t nR = 0; nR < static_cast(nRow); ++nR) 145 | for (std::size_t nC{0}; nC < nMaxIndex; ++nC) pLT[nC * nRow + nR] = pL[nR * nMaxIndex + nC]; 146 | 147 | #pragma omp for schedule(guided) 148 | for (std::int64_t nR = 0; nR < static_cast(nColumn); ++nR) 149 | for (std::size_t nC{0}; nC < nMaxIndex; ++nC) { 150 | std::size_t nIndex{0}; 151 | auto sSum = _mm256_setzero_ps(); 152 | 153 | for (; nIndex + 8 <= nRow; nIndex += 8) 154 | sSum = _mm256_fmadd_ps( 155 | _mm256_loadu_ps(pGT + nR * nRow + nIndex), 156 | _mm256_loadu_ps(pLT + nC * nRow + nIndex), 157 | sSum); 158 | 159 | const auto sSum128 = _mm_add_ps(_mm256_extractf128_ps(sSum, 1), _mm256_castps256_ps128(sSum)); 160 | const auto sSum64 = _mm_add_ps(sSum128, _mm_movehl_ps(sSum128, sSum128)); 161 | const 
110 | 	void _TINNET_REGCALL GEMM::dMultiplyRight(
111 | 		std::size_t nMaxIndex,
112 | 		std::size_t nRow,
113 | 		std::size_t nColumn,
114 | 		const float *__restrict pG,
115 | 		const float *__restrict pL,
116 | 		float *__restrict pD) noexcept
117 | 	{
118 | 		std::fill(pD, pD + nColumn * nMaxIndex, .0f);
119 | 		GEMM::dMultiplyAddRight(nMaxIndex, nRow, nColumn, pG, pL, pD);
120 | 	}
121 | 
122 | 	void _TINNET_REGCALL GEMM::dMultiplyAddRight(
123 | 		std::size_t nMaxIndex,
124 | 		std::size_t nRow,
125 | 		std::size_t nColumn,
126 | 		const float *__restrict pG,
127 | 		const float *__restrict pL,
128 | 		float *__restrict pD) noexcept
129 | 	{
130 | 		std::vector<float> sGradientTransposed(nRow * nColumn);
131 | 		std::vector<float> sLeftTransposed(nRow * nMaxIndex);
132 | 		auto *__restrict pGT{sGradientTransposed.data()};
133 | 		auto *__restrict pLT{sLeftTransposed.data()};
134 | 
135 | #pragma omp parallel default(shared) num_threads(static_cast<int>(std::max<std::size_t>( \
136 | 	1u, \
137 | 	std::min<std::size_t>(nMaxIndex * nRow * nColumn / 1600000u, std::thread::hardware_concurrency()))))
138 | 		{
139 | #pragma omp for schedule(guided) nowait
140 | 			for (std::int64_t nR = 0; nR < static_cast<std::int64_t>(nRow); ++nR)
141 | 				for (std::size_t nC{0}; nC < nColumn; ++nC) pGT[nC * nRow + nR] = pG[nR * nColumn + nC];
142 | 
143 | #pragma omp for schedule(guided)
144 | 			for (std::int64_t nR = 0; nR < static_cast<std::int64_t>(nRow); ++nR)
145 | 				for (std::size_t nC{0}; nC < nMaxIndex; ++nC) pLT[nC * nRow + nR] = pL[nR * nMaxIndex + nC];
146 | 
147 | #pragma omp for schedule(guided)
148 | 			for (std::int64_t nR = 0; nR < static_cast<std::int64_t>(nColumn); ++nR)
149 | 				for (std::size_t nC{0}; nC < nMaxIndex; ++nC) {
150 | 					std::size_t nIndex{0};
151 | 					auto sSum = _mm256_setzero_ps();
152 | 
153 | 					for (; nIndex + 8 <= nRow; nIndex += 8)
154 | 						sSum = _mm256_fmadd_ps(
155 | 							_mm256_loadu_ps(pGT + nR * nRow + nIndex),
156 | 							_mm256_loadu_ps(pLT + nC * nRow + nIndex),
157 | 							sSum);
158 | 
159 | 					const auto sSum128 = _mm_add_ps(_mm256_extractf128_ps(sSum, 1), _mm256_castps256_ps128(sSum));
160 | 					const auto sSum64 = _mm_add_ps(sSum128, _mm_movehl_ps(sSum128, sSum128));
161 | 					const auto sSum32 = _mm_add_ss(sSum64, _mm_shuffle_ps(sSum64, sSum64, 0x55));
162 | 					auto nSum = _mm_cvtss_f32(sSum32);
163 | 
164 | 					for (; nIndex < nRow; ++nIndex) nSum += pGT[nR * nRow + nIndex] * pLT[nC * nRow + nIndex];
165 | 
166 | 					pD[nR * nMaxIndex + nC] += nSum;
167 | 				}
168 | 		}
169 | 	}
170 | } // namespace tinnet::compute
--------------------------------------------------------------------------------
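Read from the indexing above, every routine in this file contracts over the last axis of both operands: `multiplyAdd` treats `pL` as a row-major nRow x nMaxIndex matrix and `pR` as a row-major nColumn x nMaxIndex matrix, so it accumulates D += L * R^T; the two `dMultiply*` pairs produce the matching left and right gradients of that product, transposing into scratch buffers first so the hot loops read contiguously. A scalar reference that `GEMM::multiply` should agree with up to floating-point rounding (an illustration with a hypothetical name, not part of the build):

    #include <cstddef>

    // D (nRow x nColumn) = L (nRow x nMaxIndex) * R^T, with R stored
    // row-major as nColumn x nMaxIndex -- mirrors multiplyAdd's indexing.
    void naiveMultiply(
        std::size_t nMaxIndex,
        std::size_t nRow,
        std::size_t nColumn,
        const float *pL,
        const float *pR,
        float *pD)
    {
        for (std::size_t nR = 0; nR < nRow; ++nR)
            for (std::size_t nC = 0; nC < nColumn; ++nC) {
                float nSum = .0f;

                for (std::size_t nI = 0; nI < nMaxIndex; ++nI)
                    nSum += pL[nR * nMaxIndex + nI] * pR[nC * nMaxIndex + nI];

                pD[nR * nColumn + nC] = nSum;
            }
    }

The 1600000u constant in the num_threads clauses is a work threshold: below roughly 1.6M multiply-adds the clamp evaluates to one thread, so small matrices never pay OpenMP startup costs.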
/tinnet/tests/node/shape.cpp:
--------------------------------------------------------------------------------
1 | 
2 | #include "tinnet/includes/node/Shape.h"
3 | 
4 | #include "catch2.hpp"
5 | 
6 | #include <cstddef>
7 | #include <vector>
8 | 
9 | TEST_CASE("tinnet::node::Shape")
10 | {
11 | 	SECTION("Scalar shape")
12 | 	{
13 | 		tinnet::node::Shape sShape;
14 | 
15 | 		CHECK(sShape.rank() == 0);
16 | 		CHECK(sShape.size() == 1);
17 | 		CHECK(sShape.scalar() == true);
18 | 		CHECK(sShape.vector() == false);
19 | 		CHECK(sShape.matrix() == false);
20 | 		CHECK(sShape.tensor() == false);
21 | 	}
22 | 	SECTION("Vector shape")
23 | 	{
24 | 		tinnet::node::Shape sShape{{1}};
25 | 
26 | 		CHECK(sShape.rank() == 1);
27 | 		CHECK(sShape.size() == 1);
28 | 		CHECK(sShape.scalar() == false);
29 | 		CHECK(sShape.vector() == true);
30 | 		CHECK(sShape.matrix() == false);
31 | 		CHECK(sShape.tensor() == false);
32 | 	}
33 | 	SECTION("Matrix shape")
34 | 	{
35 | 		tinnet::node::Shape sShape{{1, 1}};
36 | 
37 | 		CHECK(sShape.rank() == 2);
38 | 		CHECK(sShape.size() == 1);
39 | 		CHECK(sShape.scalar() == false);
40 | 		CHECK(sShape.vector() == false);
41 | 		CHECK(sShape.matrix() == true);
42 | 		CHECK(sShape.tensor() == false);
43 | 	}
44 | 	SECTION("Tensor shape")
45 | 	{
46 | 		tinnet::node::Shape sShape{{1, 1, 1}};
47 | 
48 | 		CHECK(sShape.rank() == 3);
49 | 		CHECK(sShape.size() == 1);
50 | 		CHECK(sShape.scalar() == false);
51 | 		CHECK(sShape.vector() == false);
52 | 		CHECK(sShape.matrix() == false);
53 | 		CHECK(sShape.tensor() == true);
54 | 	}
55 | 	SECTION("Ctor, Rank, Size #1")
56 | 	{
57 | 		tinnet::node::Shape sShape{{1, 2, 3, 4, 5, 6, 7, 8, 9}};
58 | 
59 | 		CHECK(sShape.rank() == 9);
60 | 		CHECK(sShape.size() == 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9);
61 | 		CHECK(sShape.scalar() == false);
62 | 		CHECK(sShape.vector() == false);
63 | 		CHECK(sShape.matrix() == false);
64 | 		CHECK(sShape.tensor() == true);
65 | 	}
66 | 	SECTION("Ctor, Rank, Size #2")
67 | 	{
68 | 		tinnet::node::Shape sShape{tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}}};
69 | 
70 | 		CHECK(sShape.rank() == 9);
71 | 		CHECK(sShape.size() == 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9);
72 | 		CHECK(sShape.scalar() == false);
73 | 		CHECK(sShape.vector() == false);
74 | 		CHECK(sShape.matrix() == false);
75 | 		CHECK(sShape.tensor() == true);
76 | 	}
77 | 	SECTION("Comparison operators")
78 | 	{
79 | 		CHECK(
80 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} == std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 9})
81 | 			== true);
82 | 		CHECK(
83 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} == std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 8})
84 | 			== false);
85 | 		CHECK(
86 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} != std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 9})
87 | 			== false);
88 | 		CHECK(
89 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} != std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 8})
90 | 			== true);
91 | 		CHECK(
92 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
93 | 			== true);
94 | 		CHECK(
95 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 8}})
96 | 			== false);
97 | 		CHECK(
98 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} != tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
99 | 			== false);
100 | 		CHECK(
101 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} != tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 8}})
102 | 			== true);
103 | 
104 | 		CHECK(
105 | 			(std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 9} == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
106 | 			== true);
107 | 		CHECK(
108 | 			(std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 8} == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
109 | 			== false);
110 | 		CHECK(
111 | 			(std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 9} != tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
112 | 			== false);
113 | 		CHECK(
114 | 			(std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 8} != tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
115 | 			== true);
116 | 		CHECK(
117 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
118 | 			== true);
119 | 		CHECK(
120 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 8}} == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
121 | 			== false);
122 | 		CHECK(
123 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}} != tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
124 | 			== false);
125 | 		CHECK(
126 | 			(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 8}} != tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
127 | 			== true);
128 | 	}
129 | 	SECTION("Assignment operators")
130 | 	{
131 | 		CHECK(
132 | 			((tinnet::node::Shape{} = std::vector<std::size_t>{1, 2, 3, 4, 5, 6, 7, 8, 9})
133 | 			 == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
134 | 			== true);
135 | 		CHECK(
136 | 			((tinnet::node::Shape{} = tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
137 | 			 == tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
138 | 			== true);
139 | 	}
140 | 	SECTION("Extend, Shrink, Squeeze")
141 | 	{
142 | 		CHECK(tinnet::node::Shape{}.extend() == tinnet::node::Shape{{1}});
143 | 		CHECK(tinnet::node::Shape{{1}}.extend() == tinnet::node::Shape{{1, 1}});
144 | 		CHECK(tinnet::node::Shape{{1, 10}}.extend() == tinnet::node::Shape{{1, 1, 10}});
145 | 		CHECK(tinnet::node::Shape{{1, 10, 20}}.extend() == tinnet::node::Shape{{1, 1, 10, 20}});
146 | 
147 | 		CHECK(tinnet::node::Shape{{1, 10, 2, 20}}.extend(0) == tinnet::node::Shape{{1, 10, 2, 20}});
148 | 		CHECK(tinnet::node::Shape{{1, 10, 2, 20}}.extend(1) == tinnet::node::Shape{{1, 10, 2, 20}});
149 | 		CHECK(tinnet::node::Shape{{1, 10, 2, 20}}.extend(10) == tinnet::node::Shape{{1, 1, 1, 1, 1, 1, 1, 10, 2, 20}});
150 | 
151 | 		CHECK(tinnet::node::Shape{{1, 1, 1, 1}}.shrink() == tinnet::node::Shape{});
152 | 		CHECK(tinnet::node::Shape{{1, 10, 2, 20}}.shrink() == tinnet::node::Shape{{10, 2, 20}});
153 | 		CHECK(tinnet::node::Shape{{10, 20, 30, 40}}.shrink() == tinnet::node::Shape{{10, 20, 30, 40}});
154 | 
155 | 		CHECK(tinnet::node::Shape{{1, 1, 1, 1}}.squeeze() == tinnet::node::Shape{});
156 | 		CHECK(tinnet::node::Shape{{1, 10, 1, 20}}.squeeze() == tinnet::node::Shape{{10, 20}});
157 | 		CHECK(tinnet::node::Shape{{1, 10, 2, 20}}.squeeze() == tinnet::node::Shape{{10, 2, 20}});
158 | 	}
159 | 	SECTION("Broadcast")
160 | 	{
161 | 		CHECK(
162 | 			tinnet::node::Shape::broadcast(tinnet::node::Shape{{}}, tinnet::node::Shape{{}})
163 | 			== tinnet::node::Shape{{}});
164 | 		CHECK(
165 | 			tinnet::node::Shape::broadcast(tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}}, tinnet::node::Shape{{}})
166 | 			== tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}});
167 | 		CHECK(
168 | 			tinnet::node::Shape::broadcast(tinnet::node::Shape{{}}, tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}})
169 | 			== tinnet::node::Shape{{1, 2, 3, 4, 5, 6, 7, 8, 9}});
170 | 		CHECK(
171 | 			tinnet::node::Shape::broadcast(tinnet::node::Shape{{1, 20, 1, 40}}, tinnet::node::Shape{{10, 1, 30, 1}})
172 | 			== tinnet::node::Shape{{10, 20, 30, 40}});
173 | 		CHECK(
174 | 			tinnet::node::Shape::broadcast(tinnet::node::Shape{{10, 20, 30, 40}}, tinnet::node::Shape{{10, 20, 30, 40}})
175 | 			== tinnet::node::Shape{{10, 20, 30, 40}});
176 | 
177 | 		CHECK_THROWS(tinnet::node::Shape::broadcast(
178 | 			tinnet::node::Shape{{10, 20, 30, 40}},
179 | 			tinnet::node::Shape{{40, 30, 20, 10}}));
180 | 	}
181 | }
--------------------------------------------------------------------------------
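The Broadcast section pins down the rule the library uses: a dimension of extent 1 stretches to its counterpart, a scalar broadcasts against anything, and two differing non-1 extents throw. A hypothetical reference of that rule (`broadcastDims` is not the library's code, and the leading-1 padding side is an assumption, since every case above uses equal ranks or a scalar):

    #include <algorithm>
    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Reference broadcast on raw dimension lists.
    std::vector<std::size_t> broadcastDims(std::vector<std::size_t> sL, std::vector<std::size_t> sR)
    {
        // Pad the shorter shape with leading 1s so both ranks match.
        while (sL.size() < sR.size()) sL.insert(sL.begin(), 1);
        while (sR.size() < sL.size()) sR.insert(sR.begin(), 1);

        std::vector<std::size_t> sResult(sL.size());

        for (std::size_t nIndex{0}; nIndex < sL.size(); ++nIndex) {
            if (sL[nIndex] != sR[nIndex] && sL[nIndex] != 1 && sR[nIndex] != 1)
                throw std::runtime_error{"shape mismatch"};

            sResult[nIndex] = std::max(sL[nIndex], sR[nIndex]);
        }

        return sResult;
    }

Under this rule broadcastDims({1, 20, 1, 40}, {10, 1, 30, 1}) yields {10, 20, 30, 40} and broadcastDims({10, 20, 30, 40}, {40, 30, 20, 10}) throws, matching the CHECK and CHECK_THROWS cases above.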
/tinnet/src/node/kernel/BasicArithmetic.cpp:
--------------------------------------------------------------------------------
1 | 
2 | #include "tinnet/includes/node/kernel/BasicArithmetic.h"
3 | 
4 | #include <stdexcept>
5 | 
6 | namespace tinnet::node::kernel {
7 | 	memory::ScopedStorage __kernel__neg(Node *pNode)
8 | 	{
9 | 		if (!pNode) throw std::runtime_error{"invalid node"};
10 | 
11 | 		auto nSize{pNode->sShape.size()};
12 | 		memory::ScopedStorage sResult{sizeof(float) * nSize};
13 | 
14 | 		auto *__restrict pD{sResult.aligned()};
15 | 		const auto *__restrict pL{pNode->output().aligned()};
16 | 
17 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] = -pL[nIndex];
18 | 
19 | 		return sResult;
20 | 	}
21 | 
22 | 	memory::ScopedStorage __kernel__add(Node *pLeft, Node *pRight)
23 | 	{
24 | 		if (!pLeft || !pRight) throw std::runtime_error{"invalid node"};
25 | 		if (pLeft->eType != pRight->eType) throw std::runtime_error{"type mismatch"};
26 | 
27 | 		const auto sLeftShape = pLeft->sShape.squeeze();
28 | 		const auto sRightShape = pRight->sShape.squeeze();
29 | 
30 | 		if (sLeftShape != sRightShape) throw std::runtime_error{"shape mismatch"};
31 | 
32 | 		auto nSize{sLeftShape.size()}; // == sRightShape.size()
33 | 		memory::ScopedStorage sResult{sizeof(float) * nSize};
34 | 
35 | 		auto *__restrict pD{sResult.aligned()};
36 | 		const auto *__restrict pL{pLeft->output().aligned()};
37 | 		const auto *__restrict pR{pRight->output().aligned()};
38 | 
39 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] = pL[nIndex] + pR[nIndex];
40 | 
41 | 		return sResult;
42 | 	}
43 | 
44 | 	memory::ScopedStorage __kernel__sub(Node *pLeft, Node *pRight)
45 | 	{
46 | 		if (!pLeft || !pRight) throw std::runtime_error{"invalid node"};
47 | 		if (pLeft->eType != pRight->eType) throw std::runtime_error{"type mismatch"};
48 | 
49 | 		const auto sLeftShape = pLeft->sShape.squeeze();
50 | 		const auto sRightShape = pRight->sShape.squeeze();
51 | 
52 | 		if (sLeftShape != sRightShape) throw std::runtime_error{"shape mismatch"};
53 | 
54 | 		auto nSize{sLeftShape.size()}; // == sRightShape.size()
55 | 		memory::ScopedStorage sResult{sizeof(float) * nSize};
56 | 
57 | 		auto *__restrict pD{sResult.aligned()};
58 | 		const auto *__restrict pL{pLeft->output().aligned()};
59 | 		const auto *__restrict pR{pRight->output().aligned()};
60 | 
61 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] = pL[nIndex] - pR[nIndex];
62 | 
63 | 		return sResult;
64 | 	}
65 | 
66 | 	memory::ScopedStorage __kernel__mul(Node *pLeft, Node *pRight)
67 | 	{
68 | 		if (!pLeft || !pRight) throw std::runtime_error{"invalid node"};
69 | 		if (pLeft->eType != pRight->eType) throw std::runtime_error{"type mismatch"};
70 | 
71 | 		const auto sLeftShape = pLeft->sShape.squeeze();
72 | 		const auto sRightShape = pRight->sShape.squeeze();
73 | 
74 | 		if (sLeftShape != sRightShape) throw std::runtime_error{"shape mismatch"};
75 | 
76 | 		auto nSize{sLeftShape.size()}; // == sRightShape.size()
77 | 		memory::ScopedStorage sResult{sizeof(float) * nSize};
78 | 
79 | 		auto *__restrict pD{sResult.aligned()};
80 | 		const auto *__restrict pL{pLeft->output().aligned()};
81 | 		const auto *__restrict pR{pRight->output().aligned()};
82 | 
83 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] = pL[nIndex] * pR[nIndex];
84 | 
85 | 		return sResult;
86 | 	}
87 | 
88 | 	memory::ScopedStorage __kernel__div(Node *pLeft, Node *pRight)
89 | 	{
90 | 		if (!pLeft || !pRight) throw std::runtime_error{"invalid node"};
91 | 		if (pLeft->eType != pRight->eType) throw std::runtime_error{"type mismatch"};
92 | 
93 | 		const auto sLeftShape = pLeft->sShape.squeeze();
94 | 		const auto sRightShape = pRight->sShape.squeeze();
95 | 
96 | 		if (sLeftShape != sRightShape) throw std::runtime_error{"shape mismatch"};
97 | 
98 | 		auto nSize{sLeftShape.size()}; // == sRightShape.size()
99 | 		memory::ScopedStorage sResult{sizeof(float) * nSize};
100 | 
101 | 		auto *__restrict pD{sResult.aligned()};
102 | 		const auto *__restrict pL{pLeft->output().aligned()};
103 | 		const auto *__restrict pR{pRight->output().aligned()};
104 | 
105 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] = pL[nIndex] / (pR[nIndex] + 1e-5f);
106 | 
107 | 		return sResult;
108 | 	}
109 | 
110 | 	void __kernel__negGradient(Node *pNode, Node *pDeps)
111 | 	{
112 | 		auto nSize{pNode->sShape.size()};
113 | 		auto *__restrict pD{pDeps->gradient().aligned()};
114 | 		const auto *__restrict pG{pNode->gradient().aligned()};
115 | 
116 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] -= pG[nIndex];
117 | 	}
118 | 
119 | 	void __kernel__addGradient(Node *pNode, Node *pDeps)
120 | 	{
121 | 		auto nSize{pNode->sShape.size()};
122 | 		auto *__restrict pD{pDeps->gradient().aligned()};
123 | 		const auto *__restrict pG{pNode->gradient().aligned()};
124 | 
125 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] += pG[nIndex];
126 | 	}
127 | 
128 | 	void __kernel__subLGradient(Node *pNode, Node *pDeps)
129 | 	{
130 | 		auto nSize{pNode->sShape.size()};
131 | 		auto *__restrict pD{pDeps->gradient().aligned()};
132 | 		const auto *__restrict pG{pNode->gradient().aligned()};
133 | 
134 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] += pG[nIndex];
135 | 	}
136 | 
137 | 	void __kernel__subRGradient(Node *pNode, Node *pDeps)
138 | 	{
139 | 		auto nSize{pNode->sShape.size()};
140 | 		auto *__restrict pD{pDeps->gradient().aligned()};
141 | 		const auto *__restrict pG{pNode->gradient().aligned()};
142 | 
143 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] -= pG[nIndex];
144 | 	}
145 | 
146 | 	void __kernel__mulLGradient(Node *pNode, Node *pDeps)
147 | 	{
148 | 		auto nSize{pNode->sShape.size()};
149 | 		auto *__restrict pD{pDeps->gradient().aligned()};
150 | 		const auto *__restrict pG{pNode->gradient().aligned()};
151 | 		const auto *__restrict pR{pNode->deps()[1]->output().aligned()};
152 | 
153 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] += pG[nIndex] * pR[nIndex];
154 | 	}
155 | 
156 | 	void __kernel__mulRGradient(Node *pNode, Node *pDeps)
157 | 	{
158 | 		auto nSize{pNode->sShape.size()};
159 | 		auto *__restrict pD{pDeps->gradient().aligned()};
160 | 		const auto *__restrict pG{pNode->gradient().aligned()};
161 | 		const auto *__restrict pL{pNode->deps()[0]->output().aligned()};
162 | 
163 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] += pG[nIndex] * pL[nIndex];
164 | 	}
165 | 
166 | 	void __kernel__divLGradient(Node *pNode, Node *pDeps)
167 | 	{
168 | 		auto nSize{pNode->sShape.size()};
169 | 		auto *__restrict pD{pDeps->gradient().aligned()};
170 | 		const auto *__restrict pG{pNode->gradient().aligned()};
171 | 		const auto *__restrict pR{pNode->deps()[1]->output().aligned()};
172 | 
173 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex) pD[nIndex] += pG[nIndex] / (pR[nIndex] + 1e-5f);
174 | 	}
175 | 
176 | 	void __kernel__divRGradient(Node *pNode, Node *pDeps)
177 | 	{
178 | 		auto nSize{pNode->sShape.size()};
179 | 		auto *__restrict pD{pDeps->gradient().aligned()};
180 | 		const auto *__restrict pG{pNode->gradient().aligned()};
181 | 		const auto *__restrict pL{pNode->deps()[0]->output().aligned()};
182 | 		const auto *__restrict pR{pNode->deps()[1]->output().aligned()};
183 | 
184 | 		for (std::size_t nIndex{0}; nIndex < nSize; ++nIndex)
185 | 			pD[nIndex] -= pG[nIndex] * pL[nIndex] / (pR[nIndex] * pR[nIndex] + 1e-5f); // d(L/R)/dR = -L/R^2
186 | 	}
187 | } // namespace tinnet::node::kernel
--------------------------------------------------------------------------------
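Together with the Builder glue earlier, every forward kernel in this file has gradient counterparts that accumulate (upstream gradient) x (local derivative) into the dependency's buffer: += where the derivative is positive (add, both mul sides, the left sides of sub and div) and -= where it is negative (neg, the right sides of sub and div). A standalone arithmetic check of the div trio, using plain floats in place of Node/ScopedStorage (illustration only, with the same 1e-5 guard):

    #include <cstdio>

    int main()
    {
        const float nL = 6.f, nR = 2.f, nG = 1.f, nEps = 1e-5f;

        const float nOut = nL / (nR + nEps);              // __kernel__div:          ~3.0
        const float nGradL = nG / (nR + nEps);            // __kernel__divLGradient: ~0.5
        const float nGradR = -nG * nL / (nR * nR + nEps); // __kernel__divRGradient: ~-1.5

        std::printf("out=%.6f dL=%.6f dR=%.6f\n", nOut, nGradL, nGradR);
        return 0;
    }

Note that the two epsilons are not applied at the same spot: the forward pass guards the divisor (L / (R + 1e-5)), while the right-hand gradient guards the squared divisor (-L / (R^2 + 1e-5)), so the backward value approximates the derivative of the unguarded quotient rather than of the guarded one.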