├── MetaNN ├── policy │ ├── policy_macro_end.hpp │ ├── policy_container.hpp │ ├── policy_macro_begin.hpp │ ├── policy_operations.hpp │ └── policy_selector.hpp ├── data │ ├── tags.hpp │ ├── lower_access.hpp │ ├── matrix │ │ ├── zero_matrix.hpp │ │ ├── one_hot_vector.hpp │ │ ├── trivial_matrix.hpp │ │ └── matrix.hpp │ ├── scalar.hpp │ ├── allocator.hpp │ ├── batch │ │ ├── duplicate.hpp │ │ ├── batch.hpp │ │ └── array.hpp │ └── traits.hpp ├── operator │ ├── abs.hpp │ ├── sign.hpp │ ├── tanh.hpp │ ├── softmax.hpp │ ├── sigmoid.hpp │ ├── tags.hpp │ ├── collapse.hpp │ ├── tanh_derivation.hpp │ ├── sigmoid_derivation.hpp │ ├── softmax_derivation.hpp │ ├── traits.hpp │ ├── negative_log_likelihood.hpp │ ├── interpolation.hpp │ ├── transpose.hpp │ ├── negative_log_likelihood_derivation.hpp │ ├── organizer.hpp │ ├── operators.hpp │ ├── dot.hpp │ ├── divide.hpp │ ├── subtract.hpp │ ├── element_mul.hpp │ └── add.hpp ├── param_initializer │ ├── fill_with_distribution.hpp │ ├── constant_filler.hpp │ ├── gaussian_filler.hpp │ ├── uniform_filler.hpp │ ├── init_policy.hpp │ ├── var_scale_filler.hpp │ └── param_initializer.hpp └── facility │ ├── data_copy.hpp │ └── var_type_dict.hpp ├── .editorconfig ├── test ├── test_layer.cpp ├── test_evaluation.cpp ├── test_policy.cpp ├── test.cpp ├── test_param_initializer.cpp ├── test_operator.cpp ├── Makefile ├── test_facility.cpp ├── test.hpp ├── test_data.cpp └── TestUtil.hpp ├── .gitignore ├── LICENSE ├── .vscode ├── tasks.json └── launch.json └── README.md /MetaNN/policy/policy_macro_end.hpp: -------------------------------------------------------------------------------- 1 | #undef TypePolicyObj 2 | #undef ValuePolicyObj 3 | #undef TypePolicyTemplate 4 | #undef ValuePolicyTemplate -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | root = true 3 | 4 | [*] 5 | indent_style = space 6 | indent_size = 4 7 | insert_final_newline = false 8 | trim_trailing_whitespace = false 9 | 10 | [Makefile] 11 | indent_style = tab 12 | -------------------------------------------------------------------------------- /test/test_layer.cpp: -------------------------------------------------------------------------------- 1 | #include "test.hpp" 2 | 3 | // using namespace MetaNN; 4 | 5 | 6 | void test_layer(TestUtil& util) 7 | { 8 | util.setTestGroup("layer"); 9 | { 10 | 11 | } 12 | util.showGroupResult(); 13 | } 14 | -------------------------------------------------------------------------------- /test/test_evaluation.cpp: -------------------------------------------------------------------------------- 1 | #include "test.hpp" 2 | 3 | // using namespace MetaNN; 4 | 5 | 6 | void test_evaluation(TestUtil& util) 7 | { 8 | util.setTestGroup("evaluation"); 9 | { 10 | 11 | } 12 | util.showGroupResult(); 13 | } -------------------------------------------------------------------------------- /test/test_policy.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "test.hpp" 6 | 7 | using namespace MetaNN; 8 | 9 | 10 | void test_policy(TestUtil& util) 11 | { 12 | util.setTestGroup("policy"); 13 | { 14 | 15 | } 16 | util.showGroupResult(); 17 | } 18 | -------------------------------------------------------------------------------- /test/test.cpp: -------------------------------------------------------------------------------- 1 | #include "test.hpp" 2 | 3 | 
int main(int argc, char const *argv[]) 4 | { 5 | bool showDetails = parseDetailFlag(argc, argv); 6 | TestUtil& util = getMetaNNTestUtil(showDetails); 7 | test_facility(); 8 | test_data(); 9 | // test_operator(); 10 | // test_policy(); 11 | // test_param_initializer(); 12 | // test_layer(); 13 | // test_evaluation(); 14 | util.showFinalResult(); 15 | return 0; 16 | } 17 | -------------------------------------------------------------------------------- /MetaNN/data/tags.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace MetaNN 4 | { 5 | 6 | // 数据类型分类 7 | // 注意:MetaNN并不区分向量和矩阵,向量被视为行或者列数为1的矩阵,涉及的运算也使用矩阵运算来表示,并且标签之间不存在层次包含关系,他们是互斥的 8 | struct CategoryTags 9 | { 10 | struct Scalar; // 标量 11 | struct Matrix; // 矩阵 12 | struct BatchScalar; // 标量列表 13 | struct BatchMatrix; // 矩阵列表 14 | }; 15 | 16 | // 硬件设备标签:当前仅支持使用CPU计算,但支持自行扩展 17 | struct DeviceTags 18 | { 19 | struct CPU; 20 | }; 21 | 22 | } // namespace MetaNN -------------------------------------------------------------------------------- /test/test_param_initializer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "test.hpp" 9 | 10 | using namespace MetaNN; 11 | 12 | 13 | void test_param_initializer(TestUtil& util) 14 | { 15 | util.setTestGroup("parameter initializer"); 16 | { 17 | 18 | } 19 | util.showGroupResult(); 20 | } 21 | -------------------------------------------------------------------------------- /MetaNN/operator/abs.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // 绝对值:仅针对矩阵和矩阵列表 9 | template 10 | class OpAbs 11 | { 12 | using RawT = std::remove_cvref_t; 13 | public: 14 | static auto eval(T&& data) 15 | { 16 | using ResType = UnaryOp; 17 | return ResType(std::forward(data)); 18 | } 19 | }; 20 | 21 | template requires MatrixC || BatchMatrixC 22 | auto abs(T&& data) 23 | { 24 | return OpAbs::eval(std::forward(data)); 25 | } 26 | 27 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/sign.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // sign函数:用于矩阵和矩阵列表 9 | template 10 | class OpSign 11 | { 12 | private: 13 | using RawT = std::remove_cvref_t; 14 | public: 15 | static auto eval(T&& data) 16 | { 17 | using ResType = UnaryOp; 18 | return ResType(std::forward(data)); 19 | } 20 | }; 21 | 22 | template requires MatrixC || BatchMatrixC 23 | auto sign(T&& data) 24 | { 25 | return OpSign::eval(std::forward(data)); 26 | } 27 | 28 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/tanh.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // tanh:双曲正切,仅支持矩阵或者矩阵列表 9 | template 10 | class OpTanh 11 | { 12 | private: 13 | using RawT = std::remove_cvref_t; 14 | public: 15 | static auto eval(T&& data) 16 | { 17 | using ResType = UnaryOp; 18 | return ResType(std::forward(data)); 19 | } 20 | }; 21 | 22 | template requires MatrixC || BatchMatrixC 23 | auto tanh(T&& data) 24 | { 25 | return OpTanh::eval(std::forward(data)); 26 | } 27 | 28 | } // namespace MetaNN 
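The three unary operators above (abs, sign, tanh) follow the same expression-template pattern the README describes: a call only wraps its argument in a UnaryOp node, and nothing is computed until the (still unimplemented) evaluation step. A minimal usage sketch follows; the include paths, the element type, and the composed expression's members are assumptions (template arguments are elided throughout this dump), not code from the repository:

```cpp
// Sketch only: include paths assume the Makefile's -I../MetaNN; evaluation is still a todo.
#include <data/matrix/matrix.hpp>
#include <operator/abs.hpp>
#include <operator/tanh.hpp>

void unaryOpSketch()
{
    MetaNN::Matrix<float, MetaNN::DeviceTags::CPU> m(2, 3);
    // Each call builds an expression node; no element of m is read or written here.
    auto expr = MetaNN::abs(MetaNN::tanh(m));
    // OpOrganizer carries the shape through the expression, so expr.rowNum() == 2
    // and expr.colNum() == 3; element values only exist after explicit evaluation.
}
```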
-------------------------------------------------------------------------------- /MetaNN/policy/policy_container.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace MetaNN 4 | { 5 | 6 | template 7 | struct PolicyContainer; 8 | 9 | template 10 | constexpr bool IsPolicyContainer = false; 11 | 12 | template 13 | constexpr bool IsPolicyContainer> = true; 14 | 15 | template 16 | struct SubPolicyContainer; 17 | 18 | template 19 | constexpr bool IsSubPolicyContainer = false; 20 | 21 | template 22 | constexpr bool IsSubPolicyContainer> = true; 23 | 24 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/softmax.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // VecSoftMax:将输入矩阵归一化,用于矩阵和矩阵列表 9 | template 10 | class OpVecSoftmax 11 | { 12 | private: 13 | using RawT = std::remove_cvref_t; 14 | public: 15 | static auto eval(T&& data) 16 | { 17 | using ResType = UnaryOp; 18 | return ResType(std::forward(data)); 19 | } 20 | }; 21 | 22 | template requires MatrixC || BatchMatrixC 23 | auto vecSoftmax(T&& data) 24 | { 25 | return OpVecSoftmax::eval(std::forward(data)); 26 | } 27 | 28 | 29 | } // namespace MetaNN -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # https://github.com/tch0/MyConfigurations/blob/master/GitIgnore/Cpp.gitignore 2 | 3 | # all files ignored, for all generated binary files 4 | * 5 | 6 | # track all directories 7 | !*/ 8 | 9 | # .vscode configuration files 10 | !.vscode/launch.json 11 | !.vscode/tasks.json 12 | 13 | # track .gitignore .editorconfig 14 | !.gitignore 15 | !.editorconfig 16 | 17 | # track *.c *.h *.cpp source files 18 | !*.c 19 | !*.cpp 20 | !*.h 21 | !*.hpp 22 | !*.cc 23 | 24 | # track makefiles 25 | !Makefile 26 | !makefile 27 | 28 | # track README.md 29 | !README.md 30 | 31 | # track LICENSE 32 | !LICENSE 33 | 34 | # track test .txt files 35 | !*.txt 36 | 37 | # add all specific files that need to be tracked here 38 | -------------------------------------------------------------------------------- /MetaNN/data/lower_access.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // 通过一个中间层,提供更底层的访问,比如矩阵级的访问 9 | // 不直接在矩阵的数据访问接口中提供,接口中不提供以保证用户普通使用场景的安全性,这个中间层仅暴露给库实现者 10 | 11 | // 为需要暴露的任何类提供LowerAccessImpl特化 12 | template 13 | struct LowerAccessImpl; 14 | 15 | template 16 | auto lowerAccess(TData&& p) 17 | { 18 | using RawType = std::remove_cvref_t; 19 | return LowerAccessImpl(std::forward(p)); 20 | } 21 | 22 | // 是否具有底层访问:满足能够通过该对象构造LowerAccessImpl的要求 23 | template 24 | concept LowerAccessC = requires 25 | { 26 | LowerAccessImpl(std::declval()); 27 | }; 28 | 29 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/sigmoid.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // sigmoid function : S(x) = 1/(1+e^(-x)) 6 | // map (-infinity, +inifinity) to (0, 1) 7 | 8 | namespace MetaNN 9 | { 10 | 11 | // Sigmoid:仅支持矩阵或者矩阵列表 12 | template 13 | class OpSigmoid 14 | { 15 | private: 16 | using RawT = std::remove_cvref_t; 17 | public: 18 | static auto eval(T&& 
data) 19 | { 20 | using ResType = UnaryOp; 21 | return ResType(std::forward(data)); 22 | } 23 | }; 24 | 25 | template requires MatrixC || BatchMatrixC 26 | auto sigmoid(T&& data) 27 | { 28 | return OpSigmoid::eval(std::forward(data)); 29 | } 30 | 31 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/tags.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace MetaNN 4 | { 5 | 6 | // 一元运算 7 | struct UnaryOpTags 8 | { 9 | struct Abs; 10 | struct Sigmoid; 11 | struct Sign; 12 | struct Tanh; 13 | struct Transpose; 14 | struct Collapse; 15 | struct VecSoftmax; 16 | }; 17 | 18 | // 二元运算 19 | struct BinaryOpTags 20 | { 21 | struct Add; 22 | struct Subtract; 23 | struct ElementMul; 24 | struct Divide; 25 | struct Dot; 26 | struct NegativeLogLikelihood; 27 | struct SigmoidDerivation; 28 | struct TanhDerivation; 29 | struct VecSoftmaxDerivation; 30 | }; 31 | 32 | // 三元运算 33 | struct TernaryOpTags 34 | { 35 | struct Interpolation; 36 | struct NegativeLogLikelihoodDerivation; 37 | }; 38 | 39 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/collapse.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // 折叠运算:对一个矩阵列表求和,生成一个矩阵,输入输出类型不一样,需要特化OpCategory 9 | template<> 10 | struct OpCategory_ 11 | { 12 | using type = CategoryTags::Matrix; 13 | }; 14 | 15 | template 16 | class OpCollapse 17 | { 18 | using RawT = std::remove_cvref_t; 19 | public: 20 | static auto eval(T&& data) 21 | { 22 | using ResType = UnaryOp; 23 | return ResType(std::forward(data)); 24 | } 25 | }; 26 | 27 | template 28 | auto collapse(T&& data) requires BatchMatrixC 29 | { 30 | return OpCollapse::eval(std::forward(data)); 31 | } 32 | 33 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/data/matrix/zero_matrix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace MetaNN 7 | { 8 | 9 | // 全零矩阵:即矩阵中元素全为0的平凡矩阵 10 | template 11 | class ZeroMatrix; 12 | 13 | template 14 | class ZeroMatrix 15 | { 16 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); 17 | public: 18 | using Category = CategoryTags::Matrix; 19 | using ElementType = TElem; 20 | using DeviceType = DeviceTags::CPU; 21 | public: 22 | ZeroMatrix(std::size_t row, std::size_t col) 23 | : m_rowNum(row) 24 | , m_colNum(col) 25 | { 26 | } 27 | 28 | // 访问接口 29 | std::size_t rowNum() const 30 | { 31 | return m_rowNum; 32 | } 33 | std::size_t colNum() const 34 | { 35 | return m_colNum; 36 | } 37 | 38 | // 求值接口: todo 39 | private: 40 | std::size_t m_rowNum; 41 | std::size_t m_colNum; 42 | // 求值结果缓存: todo 43 | }; 44 | 45 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/param_initializer/fill_with_distribution.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | // 使用特定分布初始化参数矩阵 8 | 9 | namespace MetaNN 10 | { 11 | 12 | namespace NsInitializer 13 | { 14 | 15 | template 16 | void fillWithDistribution(Matrix& data, TDist& dist, TEngine& engine) 17 | { 18 | if (!data.availableForWrite()) 19 | { 20 | throw std::runtime_error("Matrix is sharing, can 
not fill-in."); 21 | } 22 | 23 | auto acc = lowerAccess(data); 24 | std::size_t row = data.rowNum(); 25 | std::size_t col = data.colNum(); 26 | std::size_t rowLen = acc.rowLen(); 27 | auto p = acc.rawMutableMemory(); 28 | 29 | for (std::size_t i = 0; i < row; i++) 30 | { 31 | for (std::size_t j = 0; j < col; j++) 32 | { 33 | p[i] = static_cast(dist(engine)); 34 | } 35 | p += rowLen; 36 | } 37 | } 38 | 39 | } // namespace NsInitializer 40 | 41 | } // namespace MetaNN 42 | -------------------------------------------------------------------------------- /MetaNN/data/scalar.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace MetaNN 7 | { 8 | 9 | template 10 | class Scalar; 11 | 12 | template 13 | class Scalar 14 | { 15 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); 16 | public: 17 | using Category = CategoryTags::Scalar; 18 | using ElementType = TElem; 19 | using DeviceType = DeviceTags::CPU; 20 | public: 21 | Scalar(ElementType elem = {}) 22 | : m_elem(elem) {} 23 | // 拷贝构造、拷贝赋值、移动构造、移动赋值由编译器合成 24 | 25 | auto& value() { return m_elem; } 26 | 27 | auto value() const { return m_elem; } 28 | 29 | // 求值相关接口: todo 30 | bool operator==(const Scalar& rhs) const; 31 | 32 | template 33 | bool operator==(const TOtherType& rhs) const; 34 | 35 | template 36 | bool operator!=(const TData& rhs) const; 37 | private: 38 | ElementType m_elem; 39 | }; 40 | 41 | } // namespace MetaNN -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 tch0 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MetaNN/data/matrix/one_hot_vector.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | template 9 | class OneHotVector; 10 | 11 | template 12 | class OneHotVector 13 | { 14 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); 15 | public: 16 | using Category = CategoryTags::Matrix; 17 | using ElementType = TElem; 18 | using DeviceType = DeviceTags::CPU; 19 | public: 20 | // 行向量 21 | OneHotVector(std::size_t col, std::size_t hotPos) 22 | : m_colNum(col) 23 | , m_hotPos(hotPos) 24 | { 25 | } 26 | // 访问接口 27 | std::size_t rowNum() const 28 | { 29 | return 1; 30 | } 31 | std::size_t colNum() const 32 | { 33 | return m_colNum; 34 | } 35 | auto hotPos() const 36 | { 37 | return m_hotPos; 38 | } 39 | 40 | // 求值接口: todo 41 | private: 42 | std::size_t m_colNum; 43 | std::size_t m_hotPos; 44 | // 求值结果缓存: todo 45 | }; 46 | 47 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/policy/policy_macro_begin.hpp: -------------------------------------------------------------------------------- 1 | #define TypePolicyObj(PolicyName, Major, Minor, Value)\ 2 | struct PolicyName : virtual public Major\ 3 | {\ 4 | using MinorClass = Major::Minor##TypeCategory;\ 5 | using Minor = Major::Minor##TypeCategory::Value;\ 6 | } 7 | 8 | #define ValuePolicyObj(PolicyName, Major, Minor, Value)\ 9 | struct PolicyName : virtual public Major\ 10 | {\ 11 | using MinorClass = Major::Minor##ValueCategory;\ 12 | private:\ 13 | using type = std::remove_cvref_t;\ 14 | public:\ 15 | static constexpr type Minor = static_cast(Value);\ 16 | } 17 | 18 | #define TypePolicyTemplate(PolicyName, Major, Minor)\ 19 | template\ 20 | struct PolicyName : virtual public Major\ 21 | {\ 22 | using MinorClass = Major::Minor##TypeCategory;\ 23 | using Minor = T;\ 24 | } 25 | 26 | #define ValuePolicyTemplate(PolicyName, Major, Minor)\ 27 | template T>\ 28 | struct PolicyName : virtual public Major\ 29 | {\ 30 | using MinorClass = Major::Minor##ValueCategory;\ 31 | private:\ 32 | using type = std::remove_cvref_t;\ 33 | public:\ 34 | static constexpr type Minor = T;\ 35 | } 36 | -------------------------------------------------------------------------------- /MetaNN/operator/tanh_derivation.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // TanhDerivation operation 9 | // 支持类型: 10 | // 矩阵与矩阵 11 | // 矩阵列表与矩阵列表 12 | 13 | template 14 | class OpTanhDerivation 15 | { 16 | using RawT1 = std::remove_cvref_t; 17 | using RawT2 = std::remove_cvref_t; 18 | public: 19 | static auto eval(T1&& data1, T2&& data2) 20 | { 21 | static_assert(std::is_same_v, "Matrices with different element types can not do TanhDerivation directly"); 22 | static_assert(std::is_same_v, "Matrices with different device types can not do TanhDerivation directly"); 23 | 24 | using ResType = BinaryOp; 25 | return ResType(std::forward(data1), std::forward(data2)); 26 | } 27 | }; 28 | 29 | template 30 | requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 31 | auto tanhDerivation(T1&& data1, T2&& data2) 32 | { 33 | return OpTanhDerivation::eval(std::forward(data1), std::forward(data2)); 34 | } 35 | 36 | } // namespace MetaNN -------------------------------------------------------------------------------- 
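The policy macros defined above in policy/policy_macro_begin.hpp (and #undef'd again by policy_macro_end.hpp) generate small policy structs. As a concrete illustration, TypePolicyObj(PNormalVarScale, VarScaleFillerPolicy, Distribution, Normal) from param_initializer/init_policy.hpp later in this dump expands by hand to roughly the following (the trailing semicolon comes from the invocation site):

```cpp
struct PNormalVarScale : virtual public VarScaleFillerPolicy
{
    using MinorClass = VarScaleFillerPolicy::DistributionTypeCategory;           // used for conflict detection
    using Distribution = VarScaleFillerPolicy::DistributionTypeCategory::Normal; // the selected minor value
};
```

MinorClass is what policy_selector.hpp's MinorCheck_ compares in order to reject two policies that set the same minor class inside one PolicyContainer.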
/MetaNN/facility/data_copy.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace MetaNN 9 | { 10 | 11 | template 12 | void dataCopy(const Matrix& src, Matrix& dest) 13 | { 14 | std::size_t rowNum = src.rowNum(); 15 | std::size_t colNum = src.colNum(); 16 | if (rowNum != dest.rowNum() || colNum != dest.colNum()) 17 | { 18 | throw std::runtime_error("Error in dataCopy: matrix dimension mismatch!"); 19 | } 20 | const auto memSrc = lowerAccess(src); 21 | auto memDest = lowerAccess(dest); 22 | 23 | std::size_t srcRowLen = memSrc.rowLen(); 24 | std::size_t destRowLen = memDest.rowLen(); 25 | 26 | const TElem* pSrc = memSrc.rawMemory(); 27 | TElem* pDest = memDest.mutableRawMemory(); 28 | 29 | if (srcRowLen == colNum && destRowLen == colNum) 30 | { 31 | std::copy(pSrc, pSrc + rowNum * colNum, pDest); 32 | } 33 | else 34 | { 35 | for (std::size_t i = 0; i < rowNum; ++i) 36 | { 37 | std::copy(pSrc, pSrc + colNum, pDest); 38 | pDest += destRowLen; 39 | pSrc += srcRowLen; 40 | } 41 | } 42 | } 43 | 44 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/param_initializer/constant_filler.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace MetaNN 8 | { 9 | 10 | namespace NsConstantFiller 11 | { 12 | 13 | template 14 | void fill(Matrix& mat, const double& val) 15 | { 16 | if (!mat.availableForWrite()) 17 | { 18 | throw std::runtime_error("Matrix is sharing weight, can not fill in."); 19 | } 20 | 21 | auto acc = lowerAccess(mat); 22 | std::size_t row = mat.rowNum(); 23 | std::size_t col = mat.colNum(); 24 | std::size_t rowLen = acc.rowLen(); 25 | auto p = acc.mutableRawMemory(); 26 | for (std::size_t i = 0; i < row; i++) 27 | { 28 | for (std::size_t j = 0; j < col; j ++) 29 | { 30 | p[j] = static_cast(val); 31 | } 32 | p += rowLen; 33 | } 34 | } 35 | 36 | } // namespace NsConstantFiller 37 | 38 | class ConstantFiller 39 | { 40 | public: 41 | ConstantFiller(double val = 0) : m_val(val) {} 42 | template 43 | void fill(TData& data, std::size_t /*fan-in*/, std::size_t/*fan-out*/) 44 | { 45 | NsConstantFiller::fill(data, m_val); 46 | } 47 | private: 48 | double m_val; 49 | }; 50 | 51 | 52 | } // namespace MetaNN 53 | -------------------------------------------------------------------------------- /MetaNN/operator/sigmoid_derivation.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // SigmoidDerivation operation 9 | // 支持类型: 10 | // 矩阵与矩阵 11 | // 矩阵列表与矩阵列表 12 | 13 | template 14 | class OpSigmoidDerivation 15 | { 16 | using RawT1 = std::remove_cvref_t; 17 | using RawT2 = std::remove_cvref_t; 18 | public: 19 | static auto eval(T1&& data1, T2&& data2) 20 | { 21 | static_assert(std::is_same_v, "Matrices with different element types can not do SigmoidDerivation directly"); 22 | static_assert(std::is_same_v, "Matrices with different device types can not do SigmoidDerivation directly"); 23 | 24 | using ResType = BinaryOp; 25 | return ResType(std::forward(data1), std::forward(data2)); 26 | } 27 | }; 28 | 29 | template 30 | requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 31 | auto sigmoidDerivation(T1&& data1, T2&& data2) 32 | { 33 | return OpSigmoidDerivation::eval(std::forward(data1), std::forward(data2)); 34 | } 35 | 36
| } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/softmax_derivation.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // VecSoftmaxDerivation operation 9 | // 支持类型: 10 | // 矩阵与矩阵 11 | // 矩阵列表与矩阵列表 12 | 13 | template 14 | class OpVecSoftmaxDerivation 15 | { 16 | using RawT1 = std::remove_cvref_t; 17 | using RawT2 = std::remove_cvref_t; 18 | public: 19 | static auto eval(T1&& data1, T2&& data2) 20 | { 21 | static_assert(std::is_same_v, "Matrices with different element types can not do VecSoftmaxDerivation directly"); 22 | static_assert(std::is_same_v, "Matrices with different device types can not do VecSoftmaxDerivation directly"); 23 | 24 | using ResType = BinaryOp; 25 | return ResType(std::forward(data1), std::forward(data2)); 26 | } 27 | }; 28 | 29 | template 30 | requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 31 | auto vecSoftmaxDerivation(T1&& data1, T2&& data2) 32 | { 33 | return OpVecSoftmaxDerivation::eval(std::forward(data1), std::forward(data2)); 34 | } 35 | 36 | } // namespace MetaNN -------------------------------------------------------------------------------- /test/test_operator.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include "test.hpp" 28 | 29 | using namespace MetaNN; 30 | 31 | 32 | void test_OpCategory() 33 | { 34 | static_assert(std::same_as, ZeroMatrix, OneHotVector>, CategoryTags::Matrix>); 35 | } 36 | 37 | void test_operator(TestUtil& util) 38 | { 39 | util.setTestGroup("operator"); 40 | { 41 | 42 | } 43 | util.showGroupResult(); 44 | } 45 | -------------------------------------------------------------------------------- /test/Makefile: -------------------------------------------------------------------------------- 1 | # https://github.com/tch0/MyConfigurations/blob/master/MakefileTemplate/CppTemplate2.mk 2 | 3 | # Makefile template 2: 4 | # For multiple C++ files in one directory, compile into one executable. 
5 | 6 | # make debug=yes to compile with -g 7 | # make system=windows for windows system 8 | 9 | .PHONY : all run rund 10 | .PHONY .IGNORE : clean 11 | 12 | # add your own include path/library path/link library to CXXFLAGS 13 | CXX = g++ 14 | CXXFLAGS += -std=c++20 15 | CXXFLAGS += -Wall -Wextra -pedantic-errors -Wshadow -Wno-sign-compare 16 | # CXXFLAGS += -Wfatal-errors 17 | CXXFLAGS += -I../MetaNN 18 | RM = rm 19 | 20 | # final target: add your target here 21 | target = test 22 | 23 | # debug 24 | ifeq ($(debug), yes) 25 | CXXFLAGS += -g 26 | else 27 | CXXFLAGS += -O3 28 | CXXFLAGS += -DNDEBUG 29 | endif 30 | 31 | # filenames and targets 32 | all_source_files := $(wildcard *.cpp) 33 | all_targets := $(target) 34 | 35 | # all targetss 36 | all : $(all_targets) 37 | 38 | # compile 39 | $(all_targets) : $(all_source_files) 40 | $(CXX) $^ -o $@ $(CXXFLAGS) 41 | 42 | # run 43 | run : $(all_targets) 44 | ./$(all_targets) 45 | rund : $(all_targets) 46 | ./$(all_targets) -d 47 | 48 | # system: affect how to clean and executable file name 49 | # value: windows/unix 50 | system = unix 51 | ifeq ($(system), windows) 52 | all_targets := $(addsuffix .exe, $(all_targets)) 53 | RM := del 54 | endif 55 | 56 | # clean 57 | clean : 58 | -$(RM) $(all_targets) 59 | -------------------------------------------------------------------------------- /MetaNN/param_initializer/gaussian_filler.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace MetaNN 13 | { 14 | 15 | // 使用正态分布初始化参数矩阵:提供平均值与标准差与一个随机数种子 16 | template> 17 | class GaussianFiller 18 | { 19 | using TRandomEngine = typename PolicySelect::RandomEngine; 20 | public: 21 | GaussianFiller(double meanVal, double standardDeviation, unsigned seed = std::random_device{}()) 22 | : m_engine(seed) 23 | , m_meanVal(meanVal) 24 | , m_stdDeviation(standardDeviation) 25 | { 26 | if (standardDeviation <= 0) 27 | { 28 | throw std::runtime_error("Invalid standard derivation for gaussian ditribution."); 29 | } 30 | } 31 | template 32 | void fill(TData& data, std::size_t /*fan-in*/, std::size_t/*fan-out*/) 33 | { 34 | using ElementType = typename TData::ElementType; 35 | std::normal_distribution dist(m_meanVal, m_stdDeviation); 36 | NsInitializer::fillWithDistribution(data, dist, m_engine); 37 | } 38 | private: 39 | TRandomEngine m_engine; 40 | double m_meanVal; 41 | double m_stdDeviation; 42 | }; 43 | 44 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/traits.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace MetaNN 8 | { 9 | 10 | // 获取运算结果元素类型,如果一个新运算具有不同行为,那么需要对其进行特化 11 | template 12 | struct OpElementType_ 13 | { 14 | using type = typename TOp1::ElementType; 15 | }; 16 | 17 | template 18 | using OpElementType = typename OpElementType_::type; 19 | 20 | // 获取运算结果设备类型 21 | template 22 | struct OpDeviceType_ 23 | { 24 | using type = typename TOp1::DeviceType; 25 | }; 26 | 27 | template 28 | using OpDeviceType = typename OpDeviceType_::type; 29 | 30 | 31 | // 当参与运算的所有类型都是一个类别时,直接定义结果类别为其共同类别(比如矩阵和矩阵那么结果就是矩阵) 32 | // 对不同类别的运算则需要特化(比如标量和矩阵,那么结果就由具体运算的特化决定) 33 | template 34 | struct OpCategory_ 35 | { 36 | static_assert((true && ... 
&& std::is_same_v), "Data category mismatch."); 37 | using type = THeadCategory; 38 | }; 39 | 40 | template 41 | using OpCateCal = typename OpCategory_, DataCategory...>::type; 42 | 43 | // 求值逻辑:需要对具体类型特化 44 | template 45 | struct OpSeqContainer; 46 | 47 | template 48 | struct OpSeq_; 49 | 50 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/policy/policy_operations.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace MetaNN 7 | { 8 | 9 | // =============================== poliy existence check =============================== 10 | template 11 | struct PolicyExist_; 12 | 13 | template 14 | struct PolicyExist_, TPolicy> 15 | { 16 | static constexpr bool value = (std::is_same_v && 17 | std::is_same_v) || 18 | PolicyExist_, TPolicy>::value; 19 | }; 20 | 21 | // 跳过SubPolicyContainer 22 | template 23 | struct PolicyExist_, Ts2...>, TPolicy> 24 | { 25 | static constexpr bool value = PolicyExist_, TPolicy>::value; 26 | }; 27 | 28 | template 29 | struct PolicyExist_, TPolicy> 30 | { 31 | static constexpr bool value = false; 32 | }; 33 | 34 | template 35 | constexpr bool PolicyExist = PolicyExist_::value; 36 | 37 | 38 | // =============================== poliy derivation =============================== 39 | namespace NsPolicyDerive 40 | { 41 | 42 | 43 | 44 | } // namespace NsPolicyDerive 45 | 46 | 47 | } // namespace MetaNN -------------------------------------------------------------------------------- /test/test_facility.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "test.hpp" 7 | 8 | using namespace MetaNN; 9 | using namespace std::literals; 10 | 11 | void test_facility_var_type_dict(TestUtil& util); 12 | void test_facility_data_copy(TestUtil& util); 13 | 14 | void test_facility(TestUtil& util) 15 | { 16 | test_facility_var_type_dict(util); 17 | test_facility_data_copy(util); 18 | } 19 | 20 | using Params = VarTypeDict; 21 | template 22 | auto foo(const T& t) 23 | { 24 | auto a = t.template get(); 25 | const auto& b = t.template get(); 26 | auto& c = t.template get(); 27 | return std::tuple{a, b, c}; 28 | } 29 | 30 | void test_facility_var_type_dict(TestUtil& util) 31 | { 32 | util.setTestGroup("facilit.var_type_dict"); 33 | { 34 | auto res = foo(Params::create().set(1u).set(2.1).set("hello"s)); 35 | util.assertEqual(std::get<0>(res), 1u); 36 | util.assertEqual(std::get<1>(res), 2.1); 37 | util.assertEqual(std::get<2>(res), "hello"s); 38 | } 39 | util.showGroupResult(); 40 | } 41 | 42 | void test_facility_data_copy(TestUtil& util) 43 | { 44 | util.setTestGroup("facility.data_copy"); 45 | { 46 | Matrix mat1; 47 | iota(mat1); 48 | util.assertEqual(mat1.availableForWrite(), true); 49 | Matrix mat2; 50 | dataCopy(mat1, mat2); 51 | util.assertEqual(mat1, mat2); 52 | util.assertEqual(mat1.availableForWrite(), true); 53 | } 54 | util.showGroupResult(); 55 | } 56 | -------------------------------------------------------------------------------- /MetaNN/param_initializer/uniform_filler.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace MetaNN 13 | { 14 | 15 | // 使用均匀分布初始化参数矩阵:提供最大最小值和一个随机数种子 16 | template> 17 | class UniformFiller 18 | { 19 | using TRandomEngine 
= typename PolicySelect::RandomEngine; 20 | public: 21 | UniformFiller(double min, double max, unsigned seed = std::random_device{}()) 22 | : m_engine(seed) 23 | , m_min(min) 24 | , m_max(max) 25 | { 26 | if (min >= max) 27 | { 28 | throw std::runtime_error("Min if larger or equal than max for uniform ditribution."); 29 | } 30 | } 31 | template 32 | void fill(TData& data, std::size_t /*fan-in*/, std::size_t/*fan-out*/) 33 | { 34 | using ElementType = typename TData::ElementType; 35 | using DistType = std::conditional_t, 36 | std::uniform_int_distribution, 37 | std::uniform_real_distribution>; 38 | DistType dist(m_min, m_max); 39 | NsInitializer::fillWithDistribution(data, dist, m_engine); 40 | } 41 | private: 42 | TRandomEngine m_engine; 43 | double m_min; 44 | double m_max; 45 | }; 46 | 47 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/negative_log_likelihood.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // NegativeLogLikelihood operation 9 | // 支持类型: 10 | // 矩阵与矩阵:输出为标量 11 | // 矩阵列表与矩阵列表:输出为标量列表 12 | 13 | template<> 14 | struct OpCategory_ 15 | { 16 | using type = CategoryTags::Scalar; 17 | }; 18 | 19 | template<> 20 | struct OpCategory_ 21 | { 22 | using type = CategoryTags::BatchScalar; 23 | }; 24 | 25 | template 26 | class OpNegativeLogLikelihood 27 | { 28 | using RawT1 = std::remove_cvref_t; 29 | using RawT2 = std::remove_cvref_t; 30 | public: 31 | static auto eval(T1&& data1, T2&& data2) 32 | { 33 | static_assert(std::is_same_v, "Matrices with different element types can not do NegativeLogLikelihood directly"); 34 | static_assert(std::is_same_v, "Matrices with different device types can not do NegativeLogLikelihood directly"); 35 | 36 | using ResType = BinaryOp; 37 | return ResType(std::forward(data1), std::forward(data2)); 38 | } 39 | }; 40 | 41 | template 42 | requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 43 | auto negativeLogLikelihood(T1&& data1, T2&& data2) 44 | { 45 | return OpNegativeLogLikelihood::eval(std::forward(data1), std::forward(data2)); 46 | } 47 | 48 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/interpolation.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // Interpolation operation:三元运算符 9 | // 支持类型: 10 | // 矩阵、矩阵、矩阵 11 | // 矩阵列表、矩阵列表、矩阵列表 12 | 13 | template 14 | class OpInterpolation 15 | { 16 | using RawT1 = std::remove_cvref_t; 17 | using RawT2 = std::remove_cvref_t; 18 | using RawT3 = std::remove_cvref_t; 19 | public: 20 | static auto eval(T1&& data1, T2&& data2, T3&& data3) 21 | { 22 | static_assert(std::is_same_v, "Matrices with different element types can not do Interpolation directly"); 23 | static_assert(std::is_same_v, "Matrices with different element types can not do Interpolation directly"); 24 | static_assert(std::is_same_v, "Matrices with different device types can not do Interpolation directly"); 25 | static_assert(std::is_same_v, "Matrices with different device types can not do Interpolation directly"); 26 | 27 | using ResType = TernaryOp; 28 | return ResType(std::forward(data1), std::forward(data2), std::forward(data3)); 29 | } 30 | }; 31 | 32 | template 33 | requires (MatrixC && MatrixC && MatrixC) || 34 | (BatchMatrixC && BatchMatrixC && BatchMatrixC) 35 | auto 
interpolation(T1&& data1, T2&& data2, T3&& data3) 36 | { 37 | return OpInterpolation::eval(std::forward(data1), std::forward(data2), std::forward(data3)); 38 | } 39 | 40 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/transpose.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // 转置操作修改了OpOrganizer的默认行为,需要对OpOrganizer进行特化 9 | 10 | // 矩阵转置 11 | template<> 12 | class OpOrganizer 13 | { 14 | public: 15 | template 16 | OpOrganizer(const TData& data) 17 | : m_rowNum(data.colNum()) 18 | , m_colNum(data.rowNum()) 19 | { 20 | } 21 | 22 | std::size_t rowNum() const 23 | { 24 | return m_rowNum; 25 | } 26 | std::size_t colNum() const 27 | { 28 | return m_colNum; 29 | } 30 | 31 | private: 32 | std::size_t m_rowNum; 33 | std::size_t m_colNum; 34 | }; 35 | 36 | // 矩阵列表转置:转置其中每一个矩阵 37 | template<> 38 | class OpOrganizer 39 | : public OpOrganizer 40 | { 41 | using BaseType = OpOrganizer; 42 | public: 43 | template 44 | OpOrganizer(const TData& data) 45 | : BaseType(data) 46 | , m_batchNum(data.batchNum()) 47 | { 48 | } 49 | 50 | std::size_t batchNum() const 51 | { 52 | return m_batchNum; 53 | } 54 | 55 | private: 56 | std::size_t m_batchNum; 57 | }; 58 | 59 | // 转置运算 60 | template 61 | class OpTranspose 62 | { 63 | using RawT = std::remove_cvref_t; 64 | public: 65 | static auto eval(T&& data) 66 | { 67 | using ResType = UnaryOp; 68 | return ResType(std::forward(data)); 69 | } 70 | }; 71 | 72 | template requires MatrixC || BatchMatrixC 73 | auto transpose(T&& data) 74 | { 75 | return OpTranspose::eval(std::forward(data)); 76 | } 77 | 78 | } // namespace MetaNN 79 | -------------------------------------------------------------------------------- /MetaNN/param_initializer/init_policy.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace MetaNN 7 | { 8 | 9 | struct InitPolicy 10 | { 11 | using MajorClass = InitPolicy; 12 | 13 | struct OverallTypeCategory; 14 | struct WeightTypeCategory; 15 | struct BiasTypeCategory; 16 | 17 | using Overall = void; 18 | using Weight = void; 19 | using Bias = void; 20 | 21 | struct RandomEngineTypeCategory; 22 | using RandomEngine = std::mt19937; 23 | }; 24 | 25 | TypePolicyTemplate(PInitializerIs, InitPolicy, Overall); // 设置默认初始化器 26 | TypePolicyTemplate(PWeightInitializerIs, InitPolicy, Weight); // 设置权重初始化器 27 | TypePolicyTemplate(PBiasInitializerIs, InitPolicy, Bias); // 设置偏置初始化器 28 | TypePolicyTemplate(PRandomGeneratorIs, InitPolicy, RandomEngine); // 设置随机数引擎 29 | 30 | // VarScaleFiller的策略 31 | struct VarScaleFillerPolicy 32 | { 33 | using MajorClass = VarScaleFillerPolicy; 34 | 35 | struct DistributionTypeCategory 36 | { 37 | struct Uniform; 38 | struct Normal; 39 | }; 40 | using Distribution = DistributionTypeCategory::Uniform; 41 | 42 | struct ScaleModeTypeCategory 43 | { 44 | struct FanIn; 45 | struct FanOut; 46 | struct FanAvg; 47 | }; 48 | using ScaleMode = ScaleModeTypeCategory::FanAvg; 49 | }; 50 | 51 | TypePolicyObj(PNormalVarScale, VarScaleFillerPolicy, Distribution, Normal); 52 | TypePolicyObj(PUniformVarScale, VarScaleFillerPolicy, Distribution, Uniform); 53 | TypePolicyObj(PVarScaleFanIn, VarScaleFillerPolicy, ScaleMode, FanIn); 54 | TypePolicyObj(PVarScaleFanOut, VarScaleFillerPolicy, ScaleMode, FanOut); 55 | TypePolicyObj(PVarScaleFanAvg, VarScaleFillerPolicy, ScaleMode, FanAvg); 56 | 
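// How these policy objects are consumed (a sketch; the fillers' PolicySelect arguments are
// elided elsewhere in this dump, so the exact instantiation is an assumption):
//   PolicySelect<InitPolicy, PolicyContainer<>>::RandomEngine
//       -> std::mt19937, the default declared in InitPolicy above;
//   PolicySelect<InitPolicy, PolicyContainer<PRandomGeneratorIs<std::minstd_rand>>>::RandomEngine
//       -> std::minstd_rand, because the user-supplied policy overrides the default.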
57 | } // namespace MetaNN 58 | 59 | #include -------------------------------------------------------------------------------- /MetaNN/data/allocator.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace MetaNN 8 | { 9 | 10 | template 11 | struct Allocator; 12 | 13 | template<> 14 | struct Allocator 15 | { 16 | template 17 | static std::shared_ptr allocate(size_t elementSize) 18 | { 19 | return std::shared_ptr(new TElem[elementSize], [](TElem* ptr) { delete [] ptr; }); 20 | } 21 | }; 22 | 23 | // 维护Allocator分配的内存 24 | // 传递内存时同时传递智能指针,确保引用计数的正确性 25 | // 使用时只使用底层的内存,通常指向智能指针维护内存的开始,但也可能指向中间(比如涉及子矩阵的情况) 26 | // 该对象拷贝是浅拷贝,以避免大量数据的深拷贝 27 | // 读操作可以任意时候进行,但是写操作只能在引用计数为1,也就是无其他地方引用时进行,防止修改了共享了底层内存的其他数据造成错误。 28 | template 29 | class ContinuousMemory 30 | { 31 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); // 内存中保存的类型不应该有CVRef限定 32 | using ElementType = TElem; 33 | public: 34 | explicit ContinuousMemory(size_t size) 35 | : m_sp(Allocator::template allocate(size)) 36 | , m_pMemStart(m_sp.get()) 37 | { 38 | } 39 | ContinuousMemory(std::shared_ptr spMem, ElementType* pMemStart) 40 | : m_sp(std::move(spMem)) 41 | , m_pMemStart(pMemStart) 42 | { 43 | } 44 | auto rawMemory() const 45 | { 46 | return m_pMemStart; 47 | } 48 | const std::shared_ptr sharedPtr() const 49 | { 50 | return m_sp; 51 | } 52 | bool operator==(const ContinuousMemory& rhs) const 53 | { 54 | return m_sp == rhs.m_sp && m_pMemStart == rhs.m_pMemStart; 55 | } 56 | bool operator!=(const ContinuousMemory& rhs) const 57 | { 58 | return !(operator==(rhs)); 59 | } 60 | size_t useCount() const 61 | { 62 | return m_sp.use_count(); 63 | } 64 | private: 65 | std::shared_ptr m_sp; 66 | ElementType* m_pMemStart; 67 | }; 68 | 69 | } // namespace MetaNN -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | // https://github.com/tch0/MyConfigurations/blob/master/VsCodeCppConfig/tasks.json 2 | { 3 | "tasks": [ 4 | // nromal configuration: g++ compile single C++ files. 5 | { 6 | "type": "cppbuild", 7 | "label": "C/C++: g++ compile single file", 8 | "command": "g++", 9 | "args": [ 10 | "-fdiagnostics-color=always", 11 | "-g", 12 | "${file}", 13 | "-o", 14 | "${fileDirname}/${fileBasenameNoExtension}", 15 | "-Wall", 16 | "-std=c++20" 17 | ], 18 | "options": { 19 | "cwd": "${fileDirname}" 20 | }, 21 | "problemMatcher": [ 22 | "$gcc" 23 | ], 24 | "group": "build", 25 | "detail": "defualt task" 26 | }, 27 | // make: compile single files. 28 | { 29 | "type": "cppbuild", 30 | "label": "C/C++: make compile single file", 31 | "command": "make", 32 | "args": [ 33 | "debug=yes", 34 | "${fileDirname}/${fileBasenameNoExtension}" 35 | ], 36 | "options": { 37 | "cwd": "${fileDirname}" 38 | }, 39 | "problemMatcher": [ 40 | "$gcc" 41 | ], 42 | "group": "build", 43 | "detail": "" 44 | }, 45 | // make: compile all files into one target. 
46 | { 47 | "type": "cppbuild", 48 | "label": "C/C++: make compile all files into one target", 49 | "command": "make", 50 | "args": [ 51 | "debug=yes" 52 | // add your target name here if necessary 53 | ], 54 | "options": { 55 | "cwd": "${fileDirname}" 56 | }, 57 | "problemMatcher": [ 58 | "$gcc" 59 | ], 60 | "group": "build", 61 | "detail": "" 62 | } 63 | ], 64 | "version": "2.0.0" 65 | } 66 | -------------------------------------------------------------------------------- /MetaNN/data/matrix/trivial_matrix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace MetaNN 9 | { 10 | 11 | // 平凡矩阵:所有元素值都一样的矩阵 12 | template> 13 | class TrivialMatrix; 14 | 15 | template 16 | class TrivialMatrix 17 | { 18 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); 19 | public: 20 | using Category = CategoryTags::Matrix; 21 | using ElementType = TElem; 22 | using DeviceType = DeviceTags::CPU; 23 | public: 24 | TrivialMatrix(std::size_t row, std::size_t col, TScalar val) 25 | : m_rowNum(row) 26 | , m_colNum(col) 27 | , m_val(val) 28 | { 29 | } 30 | 31 | // 访问接口 32 | std::size_t rowNum() const 33 | { 34 | return m_rowNum; 35 | } 36 | std::size_t colNum() const 37 | { 38 | return m_colNum; 39 | } 40 | // 读访问接口 41 | auto elementValue() const 42 | { 43 | return m_val; 44 | } 45 | 46 | // 求值接口: todo 47 | 48 | private: 49 | std::size_t m_rowNum; 50 | std::size_t m_colNum; 51 | TScalar m_val; 52 | // 求值结果缓存: todo 53 | }; 54 | 55 | // 创建平凡矩阵,简化构造过程 56 | template 57 | auto makeTrivialMatrix(std::size_t row, std::size_t col, TVal&& val) 58 | { 59 | using RawVal = std::remove_cvref_t; 60 | if constexpr (IsScalarC) 61 | { 62 | static_assert(std::is_same_v || 63 | std::is_same_v); 64 | return TrivialMatrix(row, col, val); 65 | } 66 | else 67 | { 68 | TElem tmpElem = static_cast(val); 69 | Scalar scalar(std::move(tmpElem)); 70 | return TrivialMatrix>(row, col, std::move(scalar)); 71 | } 72 | } 73 | 74 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/negative_log_likelihood_derivation.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // NegativeLogLikelihoodDerivation operation:三元运算符 9 | // 支持类型: 10 | // 标量、矩阵、矩阵:输出为矩阵 11 | // 标量列表、矩阵列表、矩阵列表:输出为矩阵列表 12 | 13 | template<> 14 | struct OpCategory_ 16 | { 17 | using type = CategoryTags::Matrix; 18 | }; 19 | 20 | template<> 21 | struct OpCategory_ 23 | { 24 | using type = CategoryTags::BatchMatrix; 25 | }; 26 | 27 | 28 | template 29 | class OpNegativeLogLikelihoodDerivation 30 | { 31 | using RawT1 = std::remove_cvref_t; 32 | using RawT2 = std::remove_cvref_t; 33 | using RawT3 = std::remove_cvref_t; 34 | public: 35 | static auto eval(T1&& data1, T2&& data2, T3&& data3) 36 | { 37 | static_assert(std::is_same_v, "Matrices with different element types can not do NegativeLogLikelihoodDerivation directly"); 38 | static_assert(std::is_same_v, "Matrices with different element types can not do NegativeLogLikelihoodDerivation directly"); 39 | static_assert(std::is_same_v, "Matrices with different device types can not do NegativeLogLikelihoodDerivation directly"); 40 | static_assert(std::is_same_v, "Matrices with different device types can not do NegativeLogLikelihoodDerivation directly"); 41 | 42 | using ResType = TernaryOp; 43 | return ResType(std::forward(data1), 
std::forward(data2), std::forward(data3)); 44 | } 45 | }; 46 | 47 | template 48 | requires (MatrixC && MatrixC && MatrixC) || 49 | (BatchMatrixC && BatchMatrixC && BatchMatrixC) 50 | auto negativeLogLikelihoodDerivation(T1&& data1, T2&& data2, T3&& data3) 51 | { 52 | return OpNegativeLogLikelihoodDerivation::eval(std::forward(data1), std::forward(data2), std::forward(data3)); 53 | } 54 | 55 | } // namespace MetaNN -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | // https://github.com/tch0/MyConfigurations/blob/master/VsCodeCppConfig/launch.json 2 | { 3 | "configurations": [ 4 | // normal configuration: g++ compile single C++ files. 5 | { 6 | "name": "gdb debug single file", 7 | "type": "cppdbg", 8 | "request": "launch", 9 | "program": "${fileDirname}/${fileBasenameNoExtension}", 10 | "args": [], 11 | "stopAtEntry": false, 12 | "cwd": "${fileDirname}", 13 | "environment": [], 14 | "externalConsole": false, 15 | "MIMode": "gdb", 16 | "miDebuggerPath": "gdb", 17 | "setupCommands": [ 18 | { 19 | "description": "pretty printing for gdb", 20 | "text": "-enable-pretty-printing", 21 | "ignoreFailures": true 22 | } 23 | ], 24 | "preLaunchTask": "C/C++: g++ compile single file" 25 | }, 26 | // make: g++ compile single files. 27 | { 28 | "name": "make and gdb debug single file", 29 | "type": "cppdbg", 30 | "request": "launch", 31 | "program": "${fileDirname}/${fileBasenameNoExtension}", 32 | "args": [], 33 | "stopAtEntry": false, 34 | "cwd": "${fileDirname}", 35 | "environment": [], 36 | "externalConsole": false, 37 | "MIMode": "gdb", 38 | "miDebuggerPath": "gdb", 39 | "setupCommands": [ 40 | { 41 | "description": "pretty printing for gdb", 42 | "text": "-enable-pretty-printing", 43 | "ignoreFailures": true 44 | } 45 | ], 46 | "preLaunchTask": "C/C++: make compile single file" 47 | }, 48 | // make: g++ compile all source files into one executable. 49 | { 50 | "name": "make and gdb debug one target", 51 | "type": "cppdbg", 52 | "request": "launch", 53 | "program": "${fileDirname}/test", // replace with your target filename 54 | "args": [], // add your start arguments 55 | "stopAtEntry": false, 56 | "cwd": "${fileDirname}", 57 | "environment": [], 58 | "externalConsole": false, 59 | "MIMode": "gdb", 60 | "miDebuggerPath": "gdb", 61 | "setupCommands": [ 62 | { 63 | "description": "pretty printing for gdb", 64 | "text": "-enable-pretty-printing", 65 | "ignoreFailures": true 66 | } 67 | ], 68 | "preLaunchTask": "C/C++: make compile all files into one target" 69 | } 70 | ], 71 | "version": "2.0.0" 72 | } 73 | -------------------------------------------------------------------------------- /MetaNN/policy/policy_selector.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace MetaNN 7 | { 8 | 9 | namespace NsPolicySelect 10 | { 11 | 12 | // 组合多个策略的结果 13 | template 14 | struct PolicySelectResult; 15 | 16 | // 直接多继承即可组合,按道理来说经过了检查之后不会存在成员有冲突 17 | template 18 | struct PolicySelectResult> : public TPolicies... 
{}; 19 | 20 | // 过滤出所有特定MajorClass的策略对象 21 | template 22 | struct MajorFilter_ 23 | { 24 | using type = TResult; 25 | }; 26 | 27 | template 28 | struct MajorFilter_, TMajorClass, TCurPolicy, TRestPolicies...> 29 | { 30 | using type = typename MajorFilter_, 31 | PolicyContainer, 32 | PolicyContainer>, 33 | TMajorClass, 34 | TRestPolicies...>::type; 35 | }; 36 | 37 | // 检查策略容器中是否有互斥(相同的MinorClass)的策略对象,有的话返回false 38 | template 39 | struct MinorCheck_ 40 | { 41 | static constexpr bool value = true; 42 | }; 43 | 44 | template 45 | struct MinorCheck_> 46 | { 47 | static constexpr bool current = (true && ... && (!std::is_same_v)); 48 | static constexpr bool value = current && MinorCheck_::value; 49 | }; 50 | 51 | // 从策略容器中选择出所有相同MajorClass的策略对象 52 | template 53 | struct Selector_; 54 | 55 | template 56 | struct Selector_> 57 | { 58 | using TMF = typename MajorFilter_, TMajorClass, TPolicies...>::type; 59 | static_assert(MinorCheck_::value, "Minor class set conflict!"); 60 | 61 | using type = std::conditional_t>, // 筛选结果是否为空 62 | TMajorClass, // 为空则使用默认策略,也就是TMajorClass 63 | PolicySelectResult>; // 不为空则将这些策略与默认策略组合起来得到结果 64 | }; 65 | 66 | } // namespace NsPolicySelect 67 | 68 | template 69 | using PolicySelect = typename NsPolicySelect::Selector_::type; 70 | 71 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/param_initializer/var_scale_filler.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace MetaNN 15 | { 16 | 17 | // TensorFlow中的variance_scaling_filler,并由此构造出XavierFiller和MSRAFiller(todo yet!) 
18 | template> 19 | class VarScaleFiller 20 | { 21 | using TRandomEngine = typename PolicySelect::RandomEngine; 22 | public: 23 | VarScaleFiller(double factor, unsigned seed = std::random_device{}()) 24 | : m_engine(seed) 25 | , m_factor(factor) 26 | { 27 | } 28 | template 29 | void fill(TData& data, std::size_t fanIn, std::size_t fanOut) 30 | { 31 | using ScaleMode = typename PolicySelect::ScaleMode; 32 | double fan_factor = 0; 33 | if constexpr (std::is_same_v) 34 | { 35 | fan_factor = fanIn; 36 | } 37 | else if constexpr (std::is_same_v) 38 | { 39 | fan_factor = fanOut; 40 | } 41 | else if constexpr (std::is_same_v) 42 | { 43 | fan_factor = (fanIn + fanOut) / 2; 44 | } 45 | else 46 | { 47 | static_assert(DependencyFalse); 48 | } 49 | 50 | using DistType = typename PolicySelect::Distribution; 51 | using ElementType = typename TData::ElementType; 52 | if constexpr (std::is_same_v) 53 | { 54 | double limit = std::sqrt(3.0 * m_factor / fan_factor); 55 | std::uniform_real_distribution dist(-limit, limit); 56 | NsInitializer::fillWithDistribution(data, dist, m_engine); 57 | } 58 | else if constexpr (std::is_same_v) 59 | { 60 | double stdDeviation = std::sqrt(m_factor / fan_factor); 61 | std::normal_distribution dist(0, stdDeviation); 62 | NsInitializer::fillWithDistribution(data, dist, m_engine); 63 | } 64 | else 65 | { 66 | static_assert(DependencyFalse); 67 | } 68 | } 69 | private: 70 | TRandomEngine m_engine; 71 | double m_factor; 72 | }; 73 | 74 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/organizer.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace MetaNN 8 | { 9 | 10 | // TOpTag是操作标签,TCategory是结果类别,针对不同类别进行偏特化 11 | // 如果有运算不满足这里的默认行为,需要针对特定运算进行偏特化或者全特化 12 | template 13 | class OpOrganizer; 14 | 15 | template 16 | class OpOrganizer 17 | { 18 | public: 19 | template 20 | OpOrganizer(const THead& head, const TRemain&... remain) 21 | { 22 | } 23 | }; 24 | 25 | template 26 | class OpOrganizer 27 | { 28 | public: 29 | template 30 | OpOrganizer(const THead& head, const TRemain&... remain) 31 | : m_rowNum(head.rowNum()) 32 | , m_colNum(head.colNum()) 33 | { 34 | assert((true && ... && (head.rowNum() == remain.rowNum()))); 35 | assert((true && ... && (head.colNum() == remain.colNum()))); 36 | } 37 | std::size_t rowNum() const 38 | { 39 | return m_rowNum; 40 | } 41 | std::size_t colNum() const 42 | { 43 | return m_colNum; 44 | } 45 | private: 46 | std::size_t m_rowNum; 47 | std::size_t m_colNum; 48 | }; 49 | 50 | template 51 | class OpOrganizer 52 | { 53 | public: 54 | template 55 | OpOrganizer(const THead& head, const TRemain&... remain) 56 | : m_batchNum(head.batchNum()) 57 | { 58 | assert((true && ... && (head.batchNum() == remain.batchNum()))); 59 | } 60 | std::size_t batchNum() const 61 | { 62 | return m_batchNum; 63 | } 64 | private: 65 | std::size_t m_batchNum; 66 | }; 67 | 68 | template 69 | class OpOrganizer 70 | { 71 | public: 72 | template 73 | OpOrganizer(const THead& head, const TRemain&... remain) 74 | : m_rowNum(head.rowNum()) 75 | , m_colNum(head.colNum()) 76 | , m_batchNum(head.batchNum()) 77 | { 78 | assert((true && ... && (head.rowNum() == remain.rowNum()))); 79 | assert((true && ... && (head.colNum() == remain.colNum()))); 80 | assert((true && ...
&& (head.batchNum() == remain.batchNum()))); 81 | } 82 | std::size_t rowNum() const 83 | { 84 | return m_rowNum; 85 | } 86 | std::size_t colNum() const 87 | { 88 | return m_colNum; 89 | } 90 | std::size_t batchNum() const 91 | { 92 | return m_batchNum; 93 | } 94 | private: 95 | std::size_t m_rowNum; 96 | std::size_t m_colNum; 97 | std::size_t m_batchNum; 98 | }; 99 | 100 | } // namespace MetaNN 101 | -------------------------------------------------------------------------------- /MetaNN/data/batch/duplicate.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace MetaNN 8 | { 9 | 10 | // 将矩阵或者标量转换为包含相同值的矩阵列表或者标量列表 11 | // 比如用于矩阵与矩阵列表(或类似场景)的操作,首先将矩阵转换为矩阵的重复列表。 12 | // 最终的操作形式都是列表与列表或者非列表与非列表,可大幅减小冗余代码 13 | template 14 | class Duplicate; 15 | 16 | // 标量重复列表 17 | template 18 | class Duplicate 19 | { 20 | static_assert(std::is_same_v, TData>, "TData is not an available type"); 21 | public: 22 | using Category = CategoryTags::BatchScalar; 23 | using ElementType = typename TData::ElementType; 24 | using DeviceType = typename TData::DeviceType; 25 | public: 26 | Duplicate(TData data, std::size_t batchNum) 27 | : m_data(std::move(data)) 28 | , m_batchNum(batchNum) 29 | { 30 | assert(m_batchNum != 0); 31 | } 32 | 33 | // 查询接口 34 | std::size_t batchNum() const 35 | { 36 | return m_batchNum; 37 | } 38 | const TData& element() const 39 | { 40 | return m_data; 41 | } 42 | 43 | // 求值接口: todo 44 | private: 45 | TData m_data; 46 | std::size_t m_batchNum; 47 | }; 48 | 49 | // 矩阵重复列表 50 | template 51 | class Duplicate 52 | { 53 | static_assert(std::is_same_v, TData>, "TData is not an available type"); 54 | public: 55 | using Category = CategoryTags::BatchMatrix; 56 | using ElementType = typename TData::ElementType; 57 | using DeviceType = typename TData::DeviceType; 58 | public: 59 | Duplicate(TData data, std::size_t batchNum) 60 | : m_data(std::move(data)) 61 | , m_batchNum(batchNum) 62 | { 63 | assert(m_batchNum != 0); 64 | } 65 | 66 | // 查询接口 67 | std::size_t rowNum() const 68 | { 69 | return m_data.rowNum(); 70 | } 71 | std::size_t colNum() const 72 | { 73 | return m_data.colNum(); 74 | } 75 | std::size_t batchNum() const 76 | { 77 | return m_batchNum; 78 | } 79 | const TData& element() const 80 | { 81 | return m_data; 82 | } 83 | 84 | // 求值接口: todo 85 | 86 | private: 87 | TData m_data; 88 | std::size_t m_batchNum; 89 | // 求值缓存: todo 90 | }; 91 | 92 | // 快捷构造Duplicate 93 | template requires ScalarC || MatrixC 94 | auto makeDuplicate(std::size_t batchNum, TData&& data) 95 | { 96 | using RawDataType = std::remove_cvref_t; 97 | return Duplicate(std::forward(data), batchNum); 98 | } 99 | 100 | template requires ScalarC || MatrixC 101 | auto makeDuplicate(std::size_t batchNum, Args&&... 
args)
102 | {
103 |     using RawDataType = std::remove_cvref_t;
104 |     RawDataType tmp(std::forward(args)...);
105 |     return Duplicate(std::move(tmp), batchNum);
106 | }
107 | 
108 | } // namespace MetaNN
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C++模板元编程实战 (C++ Template Metaprogramming in Practice)
2 | 
3 | A companion code implementation written while reading 《[C++模板元编程实战:一个深度学习框架的初步实现](https://book.douban.com/subject/30394402/)》 (C++ Template Metaprogramming in Practice: a preliminary implementation of a deep learning framework).
4 | 
5 | Direct references:
6 | - [bluealert/MetaNN-book](https://github.com/bluealert/MetaNN-book)
7 | - [liwei-cpp/MetaNN](https://github.com/liwei-cpp/MetaNN)
8 | - Refactored and reorganized to some extent on top of them, using C++20.
9 | 
10 | Coding style and conventions:
11 | - C++20, header-only: add [`./MetaNN`](./MetaNN/) to the include directories and the library is ready to use.
12 | - All code lives in the namespace `MetaNN`.
13 | - 4-space indentation, braces on their own line, UpperCamelCase class names, lowerCamelCase function and variable names, all-lowercase snake_case file names.
14 | - Naming:
15 |   - Metafunctions whose output is a nested type or constant end with an underscore `_`.
16 |   - The corresponding metafunctions that yield the result directly, such as variable templates and alias templates, drop the trailing `_`.
17 |   - Concepts all end with the suffix `C`.
18 |   - Policies are prefixed with `P`, and the policy templates among them are suffixed with `Is`.
19 | 
20 | Running the tests:
21 | ```shell
22 | cd ./test
23 | make run
24 | ```
25 | 
26 | Current status:
27 | - Only the code of the first five chapters is done; the code up to Chapter 4 has been tested.
28 | - The basic layers of Chapter 6 and the composite and recurrent layers of Chapter 7 are the real core of the deep learning framework, not to mention the evaluation of Chapter 8.
29 | - I have understood the basic implementation approach of the deep learning framework described in the book, but in many places it is still unclear why things are done that way.
30 | - The last three chapters are full of code details, and implementing and testing every detail without really understanding the why would be torture, so this part is on hold for now: I only studied the principles, and I am not sure whether I will come back to implement it.
31 | 
32 | Summary of the book's key points and ideas:
33 | - The heterogeneous dictionary and policy templates are practical techniques used heavily in the implementation of basic and composite layers. The heterogeneous dictionary stores an arbitrary number of objects of arbitrary types in a single parameter. Policy templates use compile-time computation to configure all kinds of options (mainly in layers) without any runtime cost. (A short usage sketch of the dictionary and of expression construction follows this summary.)
34 | - The core idea of the type system is data sharing: the central data structure is the matrix, and the underlying storage is shared via `std::shared_ptr`. The other core idea is rich typing, a capability offered by compile-time polymorphism that avoids the runtime overhead of dynamic polymorphism; type categories are defined through a tag system rather than inheritance, and every type in a category supports the operations of that category. Special types can therefore be given specialized implementations, and as much type information as possible is preserved, which makes the final evaluation optimizations possible. The most basic type of a category is its principal type: a matrix that stores every element is the principal type, while a trivial matrix whose elements are all identical and which stores only one value is not. Using the concepts introduced in C++20 instead of SFINAE metaprogramming makes the type-system code much easier to write.
35 | - Expression templates are one of the gems of template metaprogramming. Operations are organized into a compile-time expression tree: when an operation is declared it is not executed immediately, only the fact that such a computation will take place is recorded, and the actual computation happens when the evaluation interface is explicitly called. This is lazy evaluation: expression templates treat operations as transformations of data, evaluate in one pass, avoid a large number of intermediate states and computations, and can optimize based on specific type information (for example, multiplying by a zero matrix can simply return a zero matrix without doing any work), which can greatly improve performance.
36 | - Organizing a deep learning model with raw expression templates is too primitive; a higher-level entity is needed to structure the model, and that is the layer. Designing basic layers requires taking many factors into account:
37 |   - Initialization of parameter matrices: some layers hold parameter matrices that are updated in every training round, so their initialization must be supported: the values may be a half-trained result loaded from a file, or random values or a sequence following some specific distribution used to fill them before the first training run. One also has to consider a layer being reused in several places of the model and sharing the same parameter matrix. More complex frameworks additionally have to consider how to share and update parameters during parallel training (not covered in this book).
38 |   - Forward propagation: the process in which data flows from the input to the output; layers with parameter matrices use them in the computation.
39 |   - Storing intermediate results: some layers need to store intermediate computation results so that gradients can be computed during back propagation. In the book this data is kept on a stack, filled during each forward pass and consumed during the backward pass. A check is also needed for whether this data is in an abnormal state (neutrality detection: for example, it was produced but never consumed).
40 |   - Back propagation: after forward propagation, data has flowed from the input end to the output end; usually one layer compares the output with the pre-labeled result (supervised learning), quantifies the difference through a loss function, and propagates it back toward the input end.
41 |   - Updating parameter matrices: during back propagation, data (gradient information) flows from the output end to the input end; the intermediate results stored during forward propagation are consumed (if any), and the parameter matrices are updated (if present and desired). Some layers may not want to update their parameters and only pass gradients through, or may not pass gradients at all; such settings should also be part of the design.
42 |   - Saving and loading parameter matrices: each completed round of forward and back propagation updates the model parameters once. Whether training has finished after many rounds or is stopped halfway, it must be possible to save the parameter matrices so that they can later be loaded losslessly and the model state perfectly restored.
43 |   - Training versus prediction: the typical end goal of a deep learning model is prediction (or generation; in short, producing output from input). At that point the model is already trained, its parameters are fixed, only forward propagation is performed, and back propagation is no longer needed. Things that need not be stored or computed in this state (such as intermediate results) can then be omitted to improve runtime performance (this is actually optimized at the evaluation stage).
44 | - Composite layers and recurrent layers:
45 |   - Complex layers are built by composing and connecting simple layers, and the resulting composite layer must be usable just like a basic layer.
46 |   - A composite layer may combine basic layers as well as other composite layers.
47 |   - A composite layer's declaration must provide: the sublayers it contains, which ports of which sublayers its inputs flow into, which interfaces of which sublayers flow to its outputs, and the data flow between sublayers.
48 |   - One problem composite layers face: a sublayer may originally not propagate gradient information, but after composition its predecessor may need the back-propagated gradients, in which case the successor has to update its own state to propagate gradients. This requires a compile-time topological sort and processing based on the resulting order.
49 |   - When setting policies for a composite layer, the distinct policies of its sublayers may also need to be considered. The book introduces sublayer policies nested inside the policy container (plus logic such as a sublayer's own policy overriding the composite layer's mutually exclusive sublayer policies); the related policy metafunctions all have to be modified to support this, including deep nesting.
50 |   - Finally, everything a basic layer has must be appropriately adapted and implemented for composite layers.
51 |   - I did not really understand recurrent layers; omitted.
52 | - Evaluation and optimization:
53 |   - The operations in forward and back propagation are all just the construction of expression templates. Evaluation is the process of converting these expression template objects into their corresponding principal types.
54 |   - Each expression template, together with its nested sub-expression templates and the underlying data, forms an expression tree. Since some intermediate results in a model are shared, strictly speaking it is a directed acyclic graph.
55 |   - To evaluate correctly, the evaluation order of the expression templates in this graph has to be organized.
56 |   - MetaNN also optimizes the evaluation process as much as possible through several techniques: avoiding repeated computation, merging computations of the same kind, and co-optimizing multiple operations.
57 |   - Further details omitted.
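A minimal usage sketch of the two ideas above (the heterogeneous dictionary and lazy expression construction). This is only an illustration, not code taken from the repository: the include paths and the angle-bracket template arguments are assumptions reconstructed from the forward declarations in `data/traits.hpp` and from the test code, the tag types `TagA`/`TagB` are hypothetical, and since the evaluation interfaces are still marked `// 求值接口: todo` in the headers, only the construction of expressions is shown, not their evaluation.

```cpp
// Sketch only: template arguments and include paths are assumptions reconstructed from
// data/traits.hpp and the tests; they may not match the original sources exactly.
#include <data/matrix/matrix.hpp>   // paths assumed relative to the ./MetaNN include directory
#include <data/scalar.hpp>
#include <facility/var_type_dict.hpp>
#include <operator/add.hpp>
#include <operator/dot.hpp>

#include <string>

using namespace MetaNN;

struct TagA; struct TagB;           // hypothetical tag types used as compile-time keys

void usageSketch()
{
    // Heterogeneous dictionary: create() yields a Values object whose slots are all
    // NullParameter; each rvalue-qualified set<Tag>() returns a *new* Values type with
    // that slot's type replaced, so the calls are chained on temporaries.
    auto dict = VarTypeDict<TagA, TagB>::create()
                    .set<TagA>(3.14)
                    .set<TagB>(std::string("metann"));
    double a = dict.get<TagA>();    // tag -> slot index is resolved at compile time

    // Expression templates: operator+ and dot() compute nothing here, they only build
    // BinaryOp nodes whose row/column numbers are recorded by OpOrganizer.
    Matrix<float, DeviceTags::CPU> m1(2, 3);
    Matrix<float, DeviceTags::CPU> m2(3, 4);
    auto sum  = Scalar<float, DeviceTags::CPU>(0.5f) + m1;  // scalar is wrapped into a TrivialMatrix
    auto prod = dot(m1, m2);                                // asserts m1.colNum() == m2.rowNum()
    static_assert(MatrixC<decltype(prod)>);                 // op results carry the Matrix category
}
```

The second half illustrates the design choice described in the summary: `operator+` and `dot` return `BinaryOp` nodes that merely remember their operands and shape, so a whole computation can be described up front and turned into principal-type data later, in one evaluation pass.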
-------------------------------------------------------------------------------- /MetaNN/data/matrix/matrix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace MetaNN 10 | { 11 | 12 | template 13 | class Matrix; 14 | 15 | // 提供底层访问接口 16 | template 17 | struct LowerAccessImpl>; 18 | 19 | // 为Batch提供前向声明 20 | template 21 | class Batch; 22 | 23 | template 24 | class Matrix 25 | { 26 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); 27 | friend struct LowerAccessImpl>; 28 | friend struct Batch; 29 | public: 30 | using Category = CategoryTags::Matrix; 31 | using ElementType = TElem; 32 | using DeviceType = DeviceTags::CPU; 33 | public: 34 | Matrix(std::size_t row = 0, std::size_t col = 0) 35 | : m_mem(row * col) 36 | , m_rowNum(row) 37 | , m_colNum(col) 38 | , m_rowLen(col) 39 | { 40 | } 41 | 42 | // 访问接口 43 | std::size_t rowNum() const 44 | { 45 | return m_rowNum; 46 | } 47 | std::size_t colNum() const 48 | { 49 | return m_colNum; 50 | } 51 | // 写操作,需要可写才能调用 52 | void setValue(std::size_t row, std::size_t col, ElementType val) 53 | { 54 | assert(availableForWrite()); 55 | assert(row < m_rowNum && col < m_colNum); 56 | m_mem.rawMemory()[row * m_rowLen + col] = val; 57 | } 58 | // 读操作,返回副本而非引用 59 | const auto operator()(std::size_t row, std::size_t col) const 60 | { 61 | assert(row < m_rowNum && col < m_colNum); 62 | return m_mem.rawMemory()[row * m_rowLen + col]; 63 | } 64 | bool availableForWrite() const 65 | { 66 | return m_mem.useCount() == 1; 67 | } 68 | 69 | // 子矩阵接口,浅拷贝,共享存储空间,区间前闭后开 70 | Matrix subMatrix(std::size_t rowBegin, std::size_t rowEnd, std::size_t colBegin, std::size_t colEnd) 71 | { 72 | assert(rowBegin < m_rowNum && colBegin < m_colNum); 73 | assert(rowEnd <= m_rowNum && colEnd <= m_colNum); 74 | TElem* pos = m_mem.rawMemory() + rowBegin * m_rowLen + colBegin; 75 | return Matrix(m_mem.sharedPtr(), pos, rowEnd - rowBegin, colEnd - colBegin, m_rowLen); 76 | } 77 | 78 | // 求值接口: todo 79 | 80 | private: 81 | // 为构造子矩阵准备 82 | Matrix(std::shared_ptr spMem, ElementType* pMemStart, 83 | std::size_t row, std::size_t col, std::size_t rowLen) 84 | : m_mem(spMem, pMemStart) 85 | , m_rowNum(row) 86 | , m_colNum(col) 87 | , m_rowLen(rowLen) 88 | { 89 | } 90 | private: 91 | ContinuousMemory m_mem; 92 | std::size_t m_rowNum; 93 | std::size_t m_colNum; 94 | std::size_t m_rowLen; 95 | }; 96 | 97 | // 底层访问 98 | template 99 | struct LowerAccessImpl> 100 | { 101 | LowerAccessImpl(Matrix p) 102 | : m_matrix(p) {} 103 | 104 | // 使用这个接口提供的指针进行写操作具有一定安全性隐患,因为不会检查共享数量 105 | // 所以这个只应该提供给库作者使用,提供性能更高的操作,相当于一个后门,使用时应当非常注意 106 | auto mutableRawMemory() 107 | { 108 | return m_matrix.m_mem.rawMemory(); 109 | } 110 | 111 | const auto rawMemory() const 112 | { 113 | return m_matrix.m_mem.rawMemory(); 114 | } 115 | 116 | std::size_t rowLen() const 117 | { 118 | return m_matrix.m_rowLen; 119 | } 120 | 121 | private: 122 | Matrix m_matrix; 123 | }; 124 | 125 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/operators.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace MetaNN 10 | { 11 | 12 | // 表达式模板:不提供写接口 13 | 14 | // 一元运算 15 | template 16 | class UnaryOp : public OpOrganizer> 17 | { 18 | static_assert(std::is_same_v, TData>, "TData is not an 
available type"); 19 | public: 20 | using Category = OpCateCal; 21 | using ElementType = OpElementType; 22 | using DeviceType = OpDeviceType; 23 | public: 24 | UnaryOp(TData data) 25 | : OpOrganizer(data) 26 | , m_data(std::move(data)) 27 | { 28 | } 29 | 30 | const TData& operand() const 31 | { 32 | return m_data; 33 | } 34 | 35 | // 求值接口: todo 36 | 37 | private: 38 | TData m_data; 39 | using TPrincipal = PrincipalDataType; 40 | }; 41 | 42 | // 二元运算 43 | template 44 | class BinaryOp : public OpOrganizer> 45 | { 46 | static_assert(std::is_same_v, TData1>, "TData1 is not an available type"); 47 | static_assert(std::is_same_v, TData2>, "TData2 is not an available type"); 48 | public: 49 | using Category = OpCateCal; 50 | using ElementType = OpElementType; 51 | using DeviceType = OpDeviceType; 52 | public: 53 | BinaryOp(TData1 data1, TData2 data2) 54 | : OpOrganizer(data1, data2) 55 | , m_data1(std::move(data1)) 56 | , m_data2(std::move(data2)) 57 | { 58 | } 59 | 60 | const TData1& operand1() const 61 | { 62 | return m_data1; 63 | } 64 | const TData2& operand2() const 65 | { 66 | return m_data2; 67 | } 68 | 69 | // 求值接口: todo 70 | 71 | private: 72 | TData1 m_data1; 73 | TData2 m_data2; 74 | using TPrincipal = PrincipalDataType; 75 | }; 76 | 77 | // 三元运算 78 | template 79 | class TernaryOp : public OpOrganizer> 80 | { 81 | static_assert(std::is_same_v, TData1>, "TData1 is not an available type"); 82 | static_assert(std::is_same_v, TData2>, "TData2 is not an available type"); 83 | static_assert(std::is_same_v, TData3>, "TData3 is not an available type"); 84 | public: 85 | using Category = OpCateCal; 86 | using ElementType = OpElementType; 87 | using DeviceType = OpDeviceType; 88 | public: 89 | TernaryOp(TData1 data1, TData2 data2, TData3 data3) 90 | : OpOrganizer(data1, data2, data3) 91 | , m_data1(std::move(data1)) 92 | , m_data2(std::move(data2)) 93 | , m_data3(std::move(data3)) 94 | { 95 | } 96 | 97 | const TData1& operand1() const 98 | { 99 | return m_data1; 100 | } 101 | const TData2& operand2() const 102 | { 103 | return m_data2; 104 | } 105 | const TData3& operand3() const 106 | { 107 | return m_data3; 108 | } 109 | 110 | // 求值接口: todo 111 | 112 | private: 113 | TData1 m_data1; 114 | TData2 m_data2; 115 | TData3 m_data3; 116 | using TPrincipal = PrincipalDataType; 117 | }; 118 | 119 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/data/traits.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace MetaNN 8 | { 9 | 10 | // 前向声明 11 | template class Scalar; 12 | template class Matrix; 13 | template class Batch; 14 | 15 | // 主体类型 16 | template 17 | struct PrincipalDataType_; 18 | 19 | template 20 | struct PrincipalDataType_ 21 | { 22 | using type = Scalar; 23 | }; 24 | 25 | template 26 | struct PrincipalDataType_ 27 | { 28 | using type = Matrix; 29 | }; 30 | 31 | template 32 | struct PrincipalDataType_ 33 | { 34 | using type = Batch; 35 | }; 36 | 37 | template 38 | struct PrincipalDataType_ 39 | { 40 | using type = Batch; 41 | }; 42 | 43 | template 44 | using PrincipalDataType = typename PrincipalDataType_::type; 45 | 46 | 47 | // 获取一个类型的类别,除了可以定义category嵌套类型以声明类别,也可以通过特化DataCategory_非侵入式声明 48 | template requires requires { typename T::Category; } 49 | struct DataCategory_ 50 | { 51 | using type = typename T::Category; 52 | }; 53 | // 对const和引用偏特化,使其更加通用 54 | template 55 | struct DataCategory_ : DataCategory_ {}; 56 | 
template 57 | struct DataCategory_ : DataCategory_ {}; 58 | template 59 | struct DataCategory_ : DataCategory_ {}; 60 | 61 | template 62 | using DataCategory = typename DataCategory_::type; 63 | 64 | // 合法数据类型的约束 65 | template 66 | concept ValidDataTypeC = requires 67 | { 68 | typename TDataType::ElementType; 69 | typename TDataType::DeviceType; 70 | }; 71 | 72 | template 73 | concept ValidMatrixTypeC = requires(const TDataType& data) 74 | { 75 | { data.rowNum() } -> std::same_as; 76 | { data.colNum() } -> std::same_as; 77 | }; 78 | 79 | template 80 | concept ValidBatchTypeC = requires(const TDataType& data) 81 | { 82 | { data.batchNum() } -> std::same_as; 83 | }; 84 | 85 | template 86 | concept ValidEvaluationTypeC = requires(TDataType data) 87 | { 88 | data; 89 | // todo yet! 90 | }; 91 | 92 | // 类别判断概念 93 | // Scalar 94 | template 95 | concept IsScalarC = std::is_same_v, CategoryTags::Scalar> && ValidDataTypeC; 96 | 97 | // Matrix 98 | template 99 | concept IsMatrixC = std::is_same_v, CategoryTags::Matrix> && ValidDataTypeC && ValidMatrixTypeC; 100 | 101 | // BatchScalar 102 | template 103 | concept IsBatchScalarC = std::is_same_v, CategoryTags::BatchScalar> && ValidDataTypeC && ValidBatchTypeC; 104 | 105 | // BatchMatrix 106 | template 107 | concept IsBatchMatrixC = std::is_same_v, CategoryTags::BatchMatrix> && ValidDataTypeC && ValidBatchTypeC && ValidMatrixTypeC; 108 | 109 | // 范围更广的类别判断概念,对引用和const修饰的复合类型则根据其底层类型判断 110 | template 111 | concept ScalarC = IsScalarC>; 112 | 113 | template 114 | concept MatrixC = IsMatrixC>; 115 | 116 | template 117 | concept BatchScalarC = IsBatchScalarC>; 118 | 119 | template 120 | concept BatchMatrixC = IsBatchMatrixC>; 121 | 122 | // 用于static_assert 123 | template 124 | constexpr bool DependencyFalse = false; 125 | 126 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/dot.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace MetaNN 7 | { 8 | 9 | // 矩阵乘法 10 | // 支持类型: 11 | // 矩阵与矩阵 12 | // 矩阵与矩阵列表 13 | // 矩阵列表与矩阵 14 | // 矩阵列表与矩阵列表 15 | 16 | // 重载OpOrganizer定义结果矩阵的行数和列数 17 | template<> 18 | class OpOrganizer 19 | { 20 | public: 21 | template 22 | OpOrganizer(const T1& data1, const T2& data2) 23 | : m_rowNum(data1.rowNum()) 24 | , m_colNum(data2.colNum()) 25 | { 26 | assert(data1.colNum() == data2.rowNum()); 27 | } 28 | 29 | std::size_t rowNum() const 30 | { 31 | return m_rowNum; 32 | } 33 | std::size_t colNum() const 34 | { 35 | return m_colNum; 36 | } 37 | 38 | private: 39 | std::size_t m_rowNum; 40 | std::size_t m_colNum; 41 | }; 42 | 43 | template<> 44 | class OpOrganizer 45 | { 46 | public: 47 | template 48 | OpOrganizer(const T1& data1, const T2& data2) 49 | : m_rowNum(data1.rowNum()) 50 | , m_colNum(data2.colNum()) 51 | , m_batchNum(data1.batchNum()) 52 | { 53 | assert(data1.colNum() == data2.rowNum()); 54 | assert(data1.batchNum() == data2.batchNum()); 55 | } 56 | 57 | std::size_t rowNum() const 58 | { 59 | return m_rowNum; 60 | } 61 | std::size_t colNum() const 62 | { 63 | return m_colNum; 64 | } 65 | 66 | private: 67 | std::size_t m_rowNum; 68 | std::size_t m_colNum; 69 | std::size_t m_batchNum; 70 | }; 71 | 72 | // 矩阵乘法运算 73 | template 74 | class OpDot 75 | { 76 | using RawT1 = std::remove_cvref_t; 77 | using RawT2 = std::remove_cvref_t; 78 | public: 79 | // 类别相同:矩阵与矩阵、矩阵列表与矩阵列表 80 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && MatrixC) || (BatchMatrixC && 
BatchMatrixC) 81 | { 82 | static_assert(std::is_same_v, "Matrices with different element types can not dot directly"); 83 | static_assert(std::is_same_v, "Matrices with different device types can not dot directly"); 84 | 85 | using ResType = BinaryOp; 86 | return ResType(std::forward(data1), std::forward(data2)); 87 | } 88 | // 矩阵与矩阵列表 89 | static auto eval(T1&& data1, T2&& data2) requires MatrixC && BatchMatrixC 90 | { 91 | static_assert(std::is_same_v, "Matrices with different element types can not dot directly"); 92 | static_assert(std::is_same_v, "Matrices with different device types can not dot directly"); 93 | 94 | Duplicate tmpDuplicateMatrix(std::forward(data1), data2.batchNum()); 95 | using ResType = BinaryOp, RawT2>; 96 | return ResType(std::move(tmpDuplicateMatrix), std::forward(data2)); 97 | } 98 | // 矩阵列表与矩阵 99 | static auto eval(T1&& data1, T2&& data2) requires BatchMatrixC && MatrixC 100 | { 101 | static_assert(std::is_same_v, "Matrices with different element types can not dot directly"); 102 | static_assert(std::is_same_v, "Matrices with different device types can not dot directly"); 103 | 104 | Duplicate tmpDuplicateMatrix(std::forward(data2), data1.batchNum()); 105 | using ResType = BinaryOp>; 106 | return ResType(std::forward(data1), std::move(tmpDuplicateMatrix)); 107 | 108 | } 109 | }; 110 | 111 | template 112 | requires (MatrixC && MatrixC) || 113 | (MatrixC && BatchMatrixC) || 114 | (BatchMatrixC && MatrixC) || 115 | (BatchMatrixC && BatchMatrixC) 116 | auto dot(T1&& data1, T2&& data2) 117 | { 118 | return OpDot::eval(std::forward(data1), std::forward(data2)); 119 | } 120 | 121 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/param_initializer/param_initializer.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace MetaNN 14 | { 15 | 16 | namespace NsParamInitializer 17 | { 18 | 19 | // 将传入的初始化策略中的模板参数作为VarTypeDict的模板实参,并且去重之后得到输出类型 20 | // 仅针对: PInitializerIs, PWeightInitializerIs, PBiasInitializerIs 21 | // MetaNN中有去重逻辑,这里没有加,暂时没看到去重的必要性。 22 | template 23 | struct FillerTagsFromPolicy_ 24 | { 25 | using type = TRes; 26 | }; 27 | 28 | template 29 | struct FillerTagsFromPolicy_, PInitializerIs, TRest...> 30 | : FillerTagsFromPolicy_, TRest...> {}; 31 | 32 | template 33 | struct FillerTagsFromPolicy_, PWeightInitializerIs, TRest...> 34 | : FillerTagsFromPolicy_, TRest...> {}; 35 | 36 | template 37 | struct FillerTagsFromPolicy_, PBiasInitializerIs, TRest...> 38 | : FillerTagsFromPolicy_, TRest...> {}; 39 | 40 | template 41 | struct FillerTagsFromPolicy_, SubPolicyContainer, TRest...> 42 | : FillerTagsFromPolicy_, TSub..., TRest...> {}; 43 | 44 | } // namespace NsParamInitializer 45 | 46 | template 47 | using FillerTags2NamedParams = typename NsParamInitializer::FillerTagsFromPolicy_, TPolicies...>::type; 48 | 49 | // TFiller是一个VarTypeDict类型 50 | template 51 | class ParamInitializer 52 | { 53 | public: 54 | using PolicyCont = TPolicyContainer; 55 | 56 | ParamInitializer(TFillers&& filler) 57 | : m_filler(std::move(filler)) 58 | { 59 | } 60 | 61 | // 初始化器的设置与获取 62 | template 63 | auto setFiller(TVal&& val) && 64 | { 65 | auto newFiller = std::move(m_filler).template set(std::forward(val)); 66 | using newFillerType = std::remove_cvref_t; 67 | return ParamInitializer(std::move(newFiller)); 68 | } 69 | template 70 | auto 
getFiller() 71 | { 72 | return m_filler.template get(); 73 | } 74 | 75 | // 参数矩阵的设置与获取 76 | template 77 | void setMatrix(const std::string& name, const Matrix& param) 78 | { 79 | if (m_params.find(name) != m_params.end()) 80 | { 81 | throw std::runtime_error("Duplicate parameter matrix: " + name); 82 | } 83 | m_params.insert({name, param}); 84 | } 85 | // 通过深拷贝方式获取参数矩阵,不会共享内存,从ParamInitializer获取的参数矩阵不会共享数据 86 | template 87 | void getMatrix(const std::string& name, Matrix& res) const 88 | { 89 | auto it = m_params.find(name); 90 | if (it == m_params.end()) 91 | { 92 | throw std::runtime_error("Parameter no exist: " + name); 93 | } 94 | const auto& mat = it->second; 95 | if (mat.rowNum() != res.rowNum() || mat.colNum() != res.colNum()) 96 | { 97 | throw std::runtime_error("Matrices dimension mismatch!"); 98 | } 99 | dataCopy(mat, res); 100 | } 101 | bool IsMatrixExist(const std::string& name) const 102 | { 103 | return m_params.find(name) != m_params.end(); 104 | } 105 | 106 | private: 107 | TFillers m_filler; 108 | std::map> m_params; 109 | }; 110 | 111 | template 112 | auto makeInitializer() 113 | { 114 | using DictType = FillerTags2NamedParams; // 获取传入的策略对象标签并构造出对应参数的VarTypeDict类型 115 | using FillerDictType = std::remove_cvref_t; // 构造异类词典对象类型VarTypeDict::Values 116 | return PramInitializer, FillerDictType>(DictType::create()); // 构造出参数初始化器对象 117 | } 118 | 119 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/divide.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // 除法操作 9 | // 支持类型: 10 | // 标量与矩阵 11 | // 标量与矩阵列表 12 | // 矩阵与矩阵 13 | // 矩阵与矩阵列表 14 | // 矩阵列表与矩阵列表 15 | 16 | template 17 | class OpDivide 18 | { 19 | using RawT1 = std::remove_cvref_t; 20 | using RawT2 = std::remove_cvref_t; 21 | public: 22 | // 类别相同:矩阵与矩阵、矩阵列表与矩阵列表,平凡实现 23 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 24 | { 25 | static_assert(std::is_same_v, "Matrices with different element types can not divide directly"); 26 | static_assert(std::is_same_v, "Matrices with different device types can not divide directly"); 27 | using ResType = BinaryOp; 28 | return ResType(std::forward(data1), std::forward(data2)); 29 | } 30 | // 标量与矩阵:将标量构造为平凡矩阵,转换为矩阵与矩阵操作 31 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && MatrixC) || (MatrixC && ScalarC) 32 | { 33 | if constexpr (ScalarC && MatrixC) 34 | { 35 | using ElementType = typename T2::ElementType; 36 | using DeviceType = typename T2::DeviceType; 37 | auto tmpTrivialMatix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 38 | using ResType = BinaryOp, RawT2>; 39 | return ResType(std::move(tmpTrivialMatix), std::forward(data2)); 40 | } 41 | else // MatrixC && ScalarC 42 | { 43 | return eval(std::forward(data2), std::forward(data1)); 44 | } 45 | } 46 | // 标量与矩阵列表:将标量构造为平凡矩阵的重复列表,转换为矩阵列表与矩阵列表操作 47 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && BatchMatrixC) || (BatchMatrixC && ScalarC) 48 | { 49 | if constexpr (ScalarC && BatchMatrixC) 50 | { 51 | using ElementType = typename T2::ElementType; 52 | using DeviceType = typename T2::DeviceType; 53 | auto tmpTrivialMatrix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 54 | auto tmpDuplicateTrivialMatrix = makeDuplicate(data2.batchNum(), std::move(tmpTrivialMatrix)); 55 | using ResType = BinaryOp, RawT2>; 56 | return 
ResType(std::move(tmpDuplicateTrivialMatrix), std::forward(data2)); 57 | } 58 | else // BatchMatrixC && ScalarC 59 | { 60 | return eval(std::forward(data2), std::forward(data1)); 61 | } 62 | } 63 | // 矩阵与矩阵列表:将矩阵构造为重复矩阵列表,转换为矩阵列表与矩阵列表操作 64 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && BatchMatrixC) || (BatchMatrixC && MatrixC) 65 | { 66 | static_assert(std::is_same_v, "Matrices with different element types can not divide directly"); 67 | static_assert(std::is_same_v, "Matrices with different device types can not divide directly"); 68 | if constexpr (MatrixC && BatchMatrixC) 69 | { 70 | auto tmpDuplicateMatrix = makeDuplicate(data2.batchNum(), std::move(data1)); 71 | using ResType = BinaryOp, RawT2>; 72 | return ResType(std::move(tmpDuplicateMatrix), std::forward(data2)); 73 | } 74 | else // BatchMatrixC && MatrixC 75 | { 76 | return eval(std::forward(data2), std::forward(data1)); 77 | } 78 | } 79 | }; 80 | 81 | template 82 | requires (ScalarC && MatrixC) || 83 | (MatrixC && ScalarC) || 84 | (ScalarC && BatchMatrixC) || 85 | (BatchMatrixC && ScalarC) || 86 | (MatrixC && MatrixC) || 87 | (MatrixC && BatchMatrixC) || 88 | (BatchMatrixC && MatrixC) || 89 | (BatchMatrixC && BatchMatrixC) 90 | auto operator/(T1&& data1, T2&& data2) 91 | { 92 | return OpDivide::eval(std::forward(data1), std::forward(data2)); 93 | } 94 | 95 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/subtract.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // 减法操作:类似于加法 9 | // 支持类型: 10 | // 标量与矩阵 11 | // 标量与矩阵列表 12 | // 矩阵与矩阵 13 | // 矩阵与矩阵列表 14 | // 矩阵列表与矩阵列表 15 | 16 | template 17 | class OpSubtract 18 | { 19 | using RawT1 = std::remove_cvref_t; 20 | using RawT2 = std::remove_cvref_t; 21 | public: 22 | // 类别相同:矩阵与矩阵、矩阵列表与矩阵列表,平凡实现 23 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 24 | { 25 | static_assert(std::is_same_v, "Matrices with different element types can not subtract directly"); 26 | static_assert(std::is_same_v, "Matrices with different device types can not subtract directly"); 27 | using ResType = BinaryOp; 28 | return ResType(std::forward(data1), std::forward(data2)); 29 | } 30 | // 标量与矩阵:将标量构造为平凡矩阵,转换为矩阵与矩阵操作 31 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && MatrixC) || (MatrixC && ScalarC) 32 | { 33 | if constexpr (ScalarC && MatrixC) 34 | { 35 | using ElementType = typename T2::ElementType; 36 | using DeviceType = typename T2::DeviceType; 37 | auto tmpTrivialMatix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 38 | using ResType = BinaryOp, RawT2>; 39 | return ResType(std::move(tmpTrivialMatix), std::forward(data2)); 40 | } 41 | else // MatrixC && ScalarC 42 | { 43 | return eval(std::forward(data2), std::forward(data1)); 44 | } 45 | } 46 | // 标量与矩阵列表:将标量构造为平凡矩阵的重复列表,转换为矩阵列表与矩阵列表操作 47 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && BatchMatrixC) || (BatchMatrixC && ScalarC) 48 | { 49 | if constexpr (ScalarC && BatchMatrixC) 50 | { 51 | using ElementType = typename T2::ElementType; 52 | using DeviceType = typename T2::DeviceType; 53 | auto tmpTrivialMatrix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 54 | auto tmpDuplicateTrivialMatrix = makeDuplicate(data2.batchNum(), std::move(tmpTrivialMatrix)); 55 | using ResType = BinaryOp, RawT2>; 56 | return ResType(std::move(tmpDuplicateTrivialMatrix), 
std::forward(data2)); 57 | } 58 | else // BatchMatrixC && ScalarC 59 | { 60 | return eval(std::forward(data2), std::forward(data1)); 61 | } 62 | } 63 | // 矩阵与矩阵列表:将矩阵构造为重复矩阵列表,转换为矩阵列表与矩阵列表操作 64 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && BatchMatrixC) || (BatchMatrixC && MatrixC) 65 | { 66 | static_assert(std::is_same_v, "Matrices with different element types can not subtract directly"); 67 | static_assert(std::is_same_v, "Matrices with different device types can not subtract directly"); 68 | if constexpr (MatrixC && BatchMatrixC) 69 | { 70 | auto tmpDuplicateMatrix = makeDuplicate(data2.batchNum(), std::move(data1)); 71 | using ResType = BinaryOp, RawT2>; 72 | return ResType(std::move(tmpDuplicateMatrix), std::forward(data2)); 73 | } 74 | else // BatchMatrixC && MatrixC 75 | { 76 | return eval(std::forward(data2), std::forward(data1)); 77 | } 78 | } 79 | }; 80 | 81 | template 82 | requires (ScalarC && MatrixC) || 83 | (MatrixC && ScalarC) || 84 | (ScalarC && BatchMatrixC) || 85 | (BatchMatrixC && ScalarC) || 86 | (MatrixC && MatrixC) || 87 | (MatrixC && BatchMatrixC) || 88 | (BatchMatrixC && MatrixC) || 89 | (BatchMatrixC && BatchMatrixC) 90 | auto operator-(T1&& data1, T2&& data2) 91 | { 92 | return OpSubtract::eval(std::forward(data1), std::forward(data2)); 93 | } 94 | 95 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/element_mul.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace MetaNN 6 | { 7 | 8 | // 元素乘法操作:注意这不是矩阵相乘,而是元素对应相乘 9 | // 支持类型: 10 | // 标量与矩阵 11 | // 标量与矩阵列表 12 | // 矩阵与矩阵 13 | // 矩阵与矩阵列表 14 | // 矩阵列表与矩阵列表 15 | 16 | template 17 | class OpElementMul 18 | { 19 | using RawT1 = std::remove_cvref_t; 20 | using RawT2 = std::remove_cvref_t; 21 | public: 22 | // 类别相同:矩阵与矩阵、矩阵列表与矩阵列表,平凡实现 23 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 24 | { 25 | static_assert(std::is_same_v, "Matrices with different element types can not multiply directly"); 26 | static_assert(std::is_same_v, "Matrices with different device types can not multiply directly"); 27 | using ResType = BinaryOp; 28 | return ResType(std::forward(data1), std::forward(data2)); 29 | } 30 | // 标量与矩阵:将标量构造为平凡矩阵,转换为矩阵与矩阵操作 31 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && MatrixC) || (MatrixC && ScalarC) 32 | { 33 | if constexpr (ScalarC && MatrixC) 34 | { 35 | using ElementType = typename T2::ElementType; 36 | using DeviceType = typename T2::DeviceType; 37 | auto tmpTrivialMatix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 38 | using ResType = BinaryOp, RawT2>; 39 | return ResType(std::move(tmpTrivialMatix), std::forward(data2)); 40 | } 41 | else // MatrixC && ScalarC 42 | { 43 | return eval(std::forward(data2), std::forward(data1)); 44 | } 45 | } 46 | // 标量与矩阵列表:将标量构造为平凡矩阵的重复列表,转换为矩阵列表与矩阵列表操作 47 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && BatchMatrixC) || (BatchMatrixC && ScalarC) 48 | { 49 | if constexpr (ScalarC && BatchMatrixC) 50 | { 51 | using ElementType = typename T2::ElementType; 52 | using DeviceType = typename T2::DeviceType; 53 | auto tmpTrivialMatrix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 54 | auto tmpDuplicateTrivialMatrix = makeDuplicate(data2.batchNum(), std::move(tmpTrivialMatrix)); 55 | using ResType = BinaryOp, RawT2>; 56 | return ResType(std::move(tmpDuplicateTrivialMatrix), std::forward(data2)); 
57 | } 58 | else // BatchMatrixC && ScalarC 59 | { 60 | return eval(std::forward(data2), std::forward(data1)); 61 | } 62 | } 63 | // 矩阵与矩阵列表:将矩阵构造为重复矩阵列表,转换为矩阵列表与矩阵列表操作 64 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && BatchMatrixC) || (BatchMatrixC && MatrixC) 65 | { 66 | static_assert(std::is_same_v, "Matrices with different element types can not multiply directly"); 67 | static_assert(std::is_same_v, "Matrices with different device types can not multiply directly"); 68 | if constexpr (MatrixC && BatchMatrixC) 69 | { 70 | auto tmpDuplicateMatrix = makeDuplicate(data2.batchNum(), std::move(data1)); 71 | using ResType = BinaryOp, RawT2>; 72 | return ResType(std::move(tmpDuplicateMatrix), std::forward(data2)); 73 | } 74 | else // BatchMatrixC && MatrixC 75 | { 76 | return eval(std::forward(data2), std::forward(data1)); 77 | } 78 | } 79 | }; 80 | 81 | template 82 | requires (ScalarC && MatrixC) || 83 | (MatrixC && ScalarC) || 84 | (ScalarC && BatchMatrixC) || 85 | (BatchMatrixC && ScalarC) || 86 | (MatrixC && MatrixC) || 87 | (MatrixC && BatchMatrixC) || 88 | (BatchMatrixC && MatrixC) || 89 | (BatchMatrixC && BatchMatrixC) 90 | auto operator*(T1&& data1, T2&& data2) 91 | { 92 | return OpElementMul::eval(std::forward(data1), std::forward(data2)); 93 | } 94 | 95 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/operator/add.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace MetaNN 8 | { 9 | 10 | // 加法操作 11 | // 支持类型: 12 | // 标量与矩阵 13 | // 标量与矩阵列表 14 | // 矩阵与矩阵 15 | // 矩阵与矩阵列表 16 | // 矩阵列表与矩阵列表 17 | 18 | template 19 | class OpAdd 20 | { 21 | using RawT1 = std::remove_cvref_t; 22 | using RawT2 = std::remove_cvref_t; 23 | public: 24 | // 类别相同:矩阵与矩阵、矩阵列表与矩阵列表,平凡实现 25 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && MatrixC) || (BatchMatrixC && BatchMatrixC) 26 | { 27 | static_assert(std::is_same_v, "Matrices with different element types can not add directly"); 28 | static_assert(std::is_same_v, "Matrices with different device types can not add directly"); 29 | 30 | using ResType = BinaryOp; 31 | return ResType(std::forward(data1), std::forward(data2)); 32 | } 33 | // 标量与矩阵:将标量构造为平凡矩阵,转换为矩阵与矩阵操作 34 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && MatrixC) || (MatrixC && ScalarC) 35 | { 36 | if constexpr (ScalarC && MatrixC) 37 | { 38 | using ElementType = typename T2::ElementType; 39 | using DeviceType = typename T2::DeviceType; 40 | auto tmpTrivialMatix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 41 | using ResType = BinaryOp, RawT2>; 42 | return ResType(std::move(tmpTrivialMatix), std::forward(data2)); 43 | } 44 | else // MatrixC && ScalarC 45 | { 46 | return eval(std::forward(data2), std::forward(data1)); 47 | } 48 | } 49 | // 标量与矩阵列表:将标量构造为平凡矩阵的重复列表,转换为矩阵列表与矩阵列表操作 50 | static auto eval(T1&& data1, T2&& data2) requires (ScalarC && BatchMatrixC) || (BatchMatrixC && ScalarC) 51 | { 52 | if constexpr (ScalarC && BatchMatrixC) 53 | { 54 | using ElementType = typename T2::ElementType; 55 | using DeviceType = typename T2::DeviceType; 56 | auto tmpTrivialMatrix = makeTrivialMatrix(data2.rowNum(), data2.colNum(), data1); 57 | auto tmpDuplicateTrivialMatrix = makeDuplicate(data2.batchNum(), std::move(tmpTrivialMatrix)); 58 | using ResType = BinaryOp, RawT2>; 59 | return ResType(std::move(tmpDuplicateTrivialMatrix), std::forward(data2)); 60 | } 61 | else // BatchMatrixC 
&& ScalarC 62 | { 63 | return eval(std::forward(data2), std::forward(data1)); 64 | } 65 | } 66 | // 矩阵与矩阵列表:将矩阵构造为重复矩阵列表,转换为矩阵列表与矩阵列表操作 67 | static auto eval(T1&& data1, T2&& data2) requires (MatrixC && BatchMatrixC) || (BatchMatrixC && MatrixC) 68 | { 69 | static_assert(std::is_same_v, "Matrices with different element types can not add directly"); 70 | static_assert(std::is_same_v, "Matrices with different device types can not add directly"); 71 | 72 | if constexpr (MatrixC && BatchMatrixC) 73 | { 74 | auto tmpDuplicateMatrix = makeDuplicate(data2.batchNum(), std::move(data1)); 75 | using ResType = BinaryOp, RawT2>; 76 | return ResType(std::move(tmpDuplicateMatrix), std::forward(data2)); 77 | } 78 | else // BatchMatrixC && MatrixC 79 | { 80 | return eval(std::forward(data2), std::forward(data1)); 81 | } 82 | } 83 | }; 84 | 85 | template 86 | requires (ScalarC && MatrixC) || 87 | (MatrixC && ScalarC) || 88 | (ScalarC && BatchMatrixC) || 89 | (BatchMatrixC && ScalarC) || 90 | (MatrixC && MatrixC) || 91 | (MatrixC && BatchMatrixC) || 92 | (BatchMatrixC && MatrixC) || 93 | (BatchMatrixC && BatchMatrixC) 94 | auto operator+(T1&& data1, T2&& data2) 95 | { 96 | return OpAdd::eval(std::forward(data1), std::forward(data2)); 97 | } 98 | 99 | } // namespace MetaNN 100 | -------------------------------------------------------------------------------- /MetaNN/facility/var_type_dict.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | // 异类类型词典,为了实现命名参数 8 | // 原理:通过下标建立从可变参数外层类模板到其中的嵌套可变参数类模板的参数之间的关系 9 | // 键是编译期常量,值是运行时对象。 10 | 11 | namespace MetaNN 12 | { 13 | 14 | struct NullParameter {}; 15 | 16 | namespace NsVarTypeDict 17 | { 18 | 19 | // 将N个NullParameter占位类型添加到类型容器左端,主要用于创建保存N个NullParameter的类型容器 20 | template class TCont, typename... Ts> 21 | struct Create_ 22 | { 23 | using type = typename Create_::type; 24 | }; 25 | 26 | template class TCont, typename... Ts> 27 | struct Create_<0, TCont, Ts...> 28 | { 29 | using type = TCont; 30 | }; 31 | 32 | // 替换typelist中的指定位置的类型为TVal,M是辅助变量,表示已经扫描过的类型数量 33 | template 34 | struct NewTupleType_; 35 | // N!=M的情况,继续扫描 36 | template class TCont, typename... TModifiedTypes, typename TCurType, typename... TRemainTypes> 37 | struct NewTupleType_, TCurType, TRemainTypes...> 38 | { 39 | using type = typename NewTupleType_, TRemainTypes...>::type; 40 | }; 41 | // N==M的情况,替换后直接返回 42 | template class TCont, typename... TModifiedTypes, typename TCurType, typename... TRemainTypes> 43 | struct NewTupleType_, TCurType, TRemainTypes...> 44 | { 45 | using type = TCont; 46 | }; 47 | 48 | template 49 | using NewTupleType = typename NewTupleType_::type; 50 | 51 | // 在多个类型中查找指定类型的下标,从左到右,从0开始 52 | template 53 | struct Tag2Id_; 54 | template 55 | struct Tag2Id_ 56 | { 57 | static constexpr size_t value = std::conditional_t, 58 | std::integral_constant, 59 | Tag2Id_>::value; // 不能直接用?:,需要阻止不满足条件的情况下的继续实例化 60 | }; 61 | 62 | template 63 | constexpr size_t Tag2Id = Tag2Id_::value; 64 | 65 | // 在类型容器中查找指定下标的类型 66 | template 67 | struct ContPosType_ 68 | { 69 | using type = void; // 这个必须要有,因为在递归终点std::conditional_t中使用了这个类型,所以也会被实例化 70 | }; 71 | template class TCont, typename TCurType, typename... 
TRemainTypes, size_t TagPos> 72 | struct ContPosType_, TagPos> 73 | { 74 | using type = std::conditional_t, TagPos - 1>::type>; 75 | }; 76 | 77 | template 78 | using ContPosType = typename ContPosType_::type; 79 | 80 | } // namespace NsVarTypeDict 81 | 82 | 83 | template 84 | struct VarTypeDict 85 | { 86 | template 87 | struct Values 88 | { 89 | private: 90 | std::shared_ptr m_tuple[sizeof...(TTypes)]; 91 | public: 92 | Values() = default; 93 | Values(std::shared_ptr(&&input)[sizeof...(TTypes)]) 94 | { 95 | for (size_t i = 0; i < sizeof...(TTypes); i++) 96 | { 97 | m_tuple[i] = std::move(input[i]); 98 | } 99 | } 100 | // TTag作为类型键值数组的键,每一次Set调用会使用TVal去替代TTag位置的类型,返回结果类型的值 101 | // 当前已保存的值会被移动到返回结果中,所有TTag都被替代完成才能通过编译 102 | template 103 | auto set(TVal&& val) && 104 | { 105 | using namespace NsVarTypeDict; 106 | constexpr static size_t TagPos = Tag2Id; 107 | using RawVal = std::decay_t; 108 | RawVal* tmp = new RawVal(std::forward(val)); 109 | m_tuple[TagPos] = std::shared_ptr(tmp, 110 | [](void* ptr) { 111 | RawVal* nptr = static_cast(ptr); 112 | delete nptr; 113 | } 114 | ); 115 | using new_type = NewTupleType, TTypes...>; 116 | return new_type(std::move(m_tuple)); 117 | } 118 | template 119 | const auto& get() const 120 | { 121 | using namespace NsVarTypeDict; 122 | constexpr static size_t TagPos = Tag2Id; 123 | using TupleType = std::decay_t; 124 | return *static_cast*>(m_tuple[TagPos].get()); 125 | } 126 | }; 127 | // 返回类型是Values<>,实参列表是sizeof...(Ts)个NullParameter 128 | static auto create() 129 | { 130 | using namespace NsVarTypeDict; 131 | using type = typename Create_::type; 132 | return type(); 133 | } 134 | }; 135 | 136 | } // MetaNN -------------------------------------------------------------------------------- /test/test.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "TestUtil.hpp" 13 | 14 | inline TestUtil& getMetaNNTestUtil(bool showDetails = true) 15 | { 16 | static TestUtil util(showDetails, "MetaNN"); 17 | return util; 18 | } 19 | 20 | // 针对MetaNN类型特化判等与输出 21 | // scalar 22 | template 23 | struct PrintObj 24 | { 25 | PrintObj(const T& val) : m_val(val) {} 26 | const T& m_val; 27 | void print(std::ostream& os) const 28 | { 29 | os << "Scalar: " << m_val.value(); 30 | } 31 | }; 32 | 33 | // matrix 34 | template 35 | struct PrintObj 36 | { 37 | PrintObj(const T& mat) : m_mat(mat) {} 38 | const T& m_mat; 39 | void print(std::ostream& os) const 40 | { 41 | os << "matrix: " << m_mat.rowNum() << "*" << m_mat.colNum() << "\n"; 42 | for (std::size_t i = 0; i < m_mat.rowNum(); ++i) 43 | { 44 | os << "\t["; 45 | for (std::size_t j = 0; j < m_mat.colNum(); ++j) 46 | { 47 | os << std::setw(3) << m_mat(i, j) << ","; 48 | } 49 | os << "]\n"; 50 | } 51 | } 52 | }; 53 | 54 | // batch scalar 55 | template 56 | struct PrintObj 57 | { 58 | PrintObj(const T& batch) : m_batch(batch) {} 59 | const T& m_batch; 60 | void print(std::ostream& os) const 61 | { 62 | os << "batch scalar: ["; 63 | for (std::size_t i = 0; i < m_batch.batchNum(); i++) 64 | { 65 | os << PrintObj>(m_batch[i]) << ", "; 66 | } 67 | } 68 | }; 69 | 70 | // batch matrix 71 | template 72 | struct PrintObj 73 | { 74 | PrintObj(const T& batch) : m_batch(batch) {} 75 | const T& m_batch; 76 | void print(std::ostream& os) const 77 | { 78 | os << "batch matrix: \n"; 79 | for (std::size_t i = 0; i < m_batch.batchNum(); i++) 80 | { 81 | os << "[" << i << "]: 
"; 82 | os << PrintObj>(m_batch[i]); 83 | } 84 | } 85 | }; 86 | 87 | // 判等 88 | template 89 | struct ObjEqual 90 | { 91 | bool operator()(const T1& val1, const T2& val2) const 92 | { 93 | return val1.value() == val2.value(); 94 | } 95 | }; 96 | 97 | template 98 | struct ObjEqual 99 | { 100 | bool operator()(const T1& mat1, const T2& mat2) const 101 | { 102 | if (mat1.rowNum() != mat2.rowNum() || mat1.colNum() != mat2.colNum()) 103 | { 104 | return false; 105 | } 106 | for (std::size_t i = 0; i < mat1.rowNum(); ++i) 107 | { 108 | for (std::size_t j = 0; j < mat1.colNum(); ++j) 109 | { 110 | if (mat1(i, j) != mat2(i, j)) 111 | { 112 | return false; 113 | } 114 | } 115 | } 116 | return true; 117 | } 118 | }; 119 | 120 | template 121 | struct ObjEqual 122 | { 123 | bool operator()(const T1& batch1, const T2& batch2) const 124 | { 125 | if (batch1.batchNum() != batch2.batchNum()) 126 | { 127 | return false; 128 | } 129 | for (std::size_t i = 0; i < batch1.batchNum(); ++i) 130 | { 131 | if (!ObjEqual, std::decay_t>()(batch1[i], batch2[i])) 132 | { 133 | return false; 134 | } 135 | } 136 | return true; 137 | } 138 | }; 139 | 140 | template 141 | struct ObjEqual 142 | { 143 | bool operator()(const T1& batch1, const T2& batch2) const 144 | { 145 | if (batch1.batchNum() != batch2.batchNum()) 146 | { 147 | return false; 148 | } 149 | for (std::size_t i = 0; i < batch1.batchNum(); ++i) 150 | { 151 | if (!ObjEqual, std::decay_t>()(batch1[i], batch2[i])) 152 | { 153 | return false; 154 | } 155 | } 156 | return true; 157 | } 158 | }; 159 | 160 | // 初始化一个矩阵 161 | template 162 | inline void iota(MetaNN::Matrix& mat) 163 | { 164 | int count = 0; 165 | for (std::size_t i = 0; i < mat.rowNum(); ++i) 166 | { 167 | for (std::size_t j = 0; j < mat.colNum(); ++j) 168 | { 169 | mat.setValue(i, j, count++); 170 | } 171 | } 172 | } 173 | 174 | // 初始化矩阵列表 175 | template 176 | inline void iota(MetaNN::Batch& batch) 177 | { 178 | int count = 0; 179 | for (std::size_t i = 0; i < batch.batchNum(); i++) 180 | { 181 | for (std::size_t j = 0; j < batch.rowNum(); j++) 182 | { 183 | for (std::size_t k = 0; k < batch.colNum(); k++) 184 | { 185 | batch.setValue(i, j, k, count++); 186 | } 187 | } 188 | } 189 | } 190 | 191 | // 测试函数声明 192 | void test_facility(TestUtil& util = getMetaNNTestUtil()); 193 | 194 | void test_data(TestUtil& util = getMetaNNTestUtil()); 195 | 196 | void test_operator(TestUtil& util = getMetaNNTestUtil()); 197 | 198 | void test_policy(TestUtil& util = getMetaNNTestUtil()); 199 | 200 | void test_param_initializer(TestUtil& util = getMetaNNTestUtil()); 201 | 202 | void test_layer(TestUtil& util = getMetaNNTestUtil()); 203 | 204 | void test_evaluation(TestUtil& util = getMetaNNTestUtil()); 205 | -------------------------------------------------------------------------------- /MetaNN/data/batch/batch.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace MetaNN 10 | { 11 | 12 | template 13 | class Batch; 14 | 15 | // 别名 16 | template 17 | using CpuBatchScalar = Batch; 18 | template 19 | using CpuBatchMatix = Batch; 20 | 21 | // 底层访问 22 | template 23 | struct LowerAccessImpl>; 24 | template 25 | struct LowerAccessImpl>; 26 | 27 | // 标量列表 28 | template 29 | class Batch 30 | { 31 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); 32 | friend struct LowerAccessImpl>; 33 | public: 34 | using Category = CategoryTags::BatchScalar; 35 | using ElementType = 
TElem; 36 | using DeviceType = DeviceTags::CPU; 37 | public: 38 | Batch(std::size_t length = 0) 39 | : m_mem(length) 40 | , m_len(length) 41 | { 42 | } 43 | 44 | // 查询接口 45 | size_t batchNum() const 46 | { 47 | return m_len; 48 | } 49 | bool availableForWrite() const 50 | { 51 | return m_mem.useCount() == 1; 52 | } 53 | // 写入接口 54 | void setValue(std::size_t index, ElementType val) 55 | { 56 | assert(availableForWrite()); 57 | assert(index < m_len); 58 | m_mem.rawMemory()[index] = val; 59 | } 60 | // 读取接口 61 | const auto operator[](std::size_t index) const 62 | { 63 | assert(index < m_len); 64 | return m_mem.rawMemory()[index]; 65 | } 66 | 67 | // 求值接口: todo 68 | private: 69 | ContinuousMemory m_mem; 70 | std::size_t m_len; 71 | }; 72 | 73 | // 标量列表底层访问 74 | template 75 | struct LowerAccessImpl> 76 | { 77 | LowerAccessImpl(Batch p) 78 | : m_data(std::move(p)) 79 | { 80 | } 81 | auto mutableRawMemory() 82 | { 83 | return m_data.m_mem.rawMemory(); 84 | } 85 | const auto rawMemory() const 86 | { 87 | return m_data.m_mem.rawMemory(); 88 | } 89 | private: 90 | Batch m_data; 91 | }; 92 | 93 | 94 | // 矩阵列表 95 | template 96 | class Batch 97 | { 98 | static_assert(std::is_same_v, TElem>, "TElem is not an available type"); 99 | friend struct LowerAccessImpl>; 100 | public: 101 | using Category = CategoryTags::BatchMatrix; 102 | using ElementType = TElem; 103 | using DeviceType = DeviceTags::CPU; 104 | public: 105 | Batch(std::size_t batchNum = 0, std::size_t row = 0, std::size_t col = 0) 106 | : m_mem(row * col * batchNum) 107 | , m_rowNum(row) 108 | , m_colNum(col) 109 | , m_batchNum(batchNum) 110 | , m_rowLen(col) 111 | , m_rawMatrixSize(row * col) 112 | { 113 | } 114 | // 查询接口 115 | std::size_t rowNum() const 116 | { 117 | return m_rowNum; 118 | } 119 | std::size_t colNum() const 120 | { 121 | return m_colNum; 122 | } 123 | std::size_t batchNum() const 124 | { 125 | return m_batchNum; 126 | } 127 | bool availableForWrite() const 128 | { 129 | return m_mem.useCount() == 1; 130 | } 131 | // 写入接口:写入具体的某个矩阵的某个值 132 | void setValue(std::size_t batchId, std::size_t row, std::size_t col, ElementType val) 133 | { 134 | assert(availableForWrite()); 135 | assert(row < m_rowNum && col < m_colNum && batchId < m_batchNum); 136 | m_mem.rawMemory()[batchId * m_rawMatrixSize + row * m_rowLen + col] = val; 137 | } 138 | // 读取接口:返回一个临时矩阵,共享存储,仅用于访问 139 | const auto operator[](std::size_t batchId) const 140 | { 141 | assert(batchId < m_batchNum); 142 | auto pos = m_mem.rawMemory() + batchId * m_rawMatrixSize; 143 | return Matrix(m_mem.sharedPtr(), pos, m_rowNum, m_colNum, m_rowLen); 144 | } 145 | 146 | // 子矩阵列表接口,浅拷贝,共享存储,区间前闭后开 147 | auto subBatchMatrix(std::size_t rowBegin, std::size_t rowEnd, std::size_t colBegin, std::size_t colEnd) 148 | { 149 | assert(rowBegin < m_rowNum && colBegin < m_colNum); 150 | assert(rowend <= m_rowNum && colEnd <= m_colNum); 151 | auto pos = m_mem.rawMemory() + rowBegin * m_rowLen + colBegin; 152 | return Batch(m_mem.sharedPtr(), pos, rowEnd - rowBegin, colEnd - colBegin, m_batchNum, m_rowLen, m_rawMatrixSize); 153 | } 154 | 155 | // 求值接口: todo 156 | 157 | private: 158 | Batch(std::shared_ptr sp, ElementType* pMemStart, 159 | std::size_t row, std::size_t col, std::size_t batchNum, std::size_t rowLen, std::size_t matrixSize) 160 | : m_mem(sp, pMemStart) 161 | , m_rowNum(row) 162 | , m_colNum(col) 163 | , m_batchNum(batchNum) 164 | , m_rowLen(rowLen) 165 | , m_rawMatrixSize(matrixSize) 166 | { 167 | } 168 | 169 | ContinuousMemory m_mem; // 内部数据存储于一维数组,并使用ContinuousMemory维护 170 | 
std::size_t m_rowNum; 171 | std::size_t m_colNum; 172 | std::size_t m_batchNum; 173 | std::size_t m_rowLen; 174 | std::size_t m_rawMatrixSize; // 原始的矩阵大小,也就是原始的矩阵行列之积 175 | }; 176 | 177 | // 矩阵列表底层访问 178 | template 179 | struct LowerAccessImpl> 180 | { 181 | LowerAccessImpl(Batch p) 182 | : m_data(std::move(p)) 183 | { 184 | } 185 | auto mutableRawMemory() 186 | { 187 | return m_data.m_mem.rawMemory(); 188 | } 189 | const auto rawMemory() const 190 | { 191 | return m_data.m_mem.rawMemory(); 192 | } 193 | std::size_t rowLen() const 194 | { 195 | return m_data.m_rowLen; 196 | } 197 | std::size_t rawMatrixSize() const 198 | { 199 | return m_data.m_rawMatrixSize; 200 | } 201 | private: 202 | Batch m_data; 203 | }; 204 | 205 | } // namespace MetaNN -------------------------------------------------------------------------------- /MetaNN/data/batch/array.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace MetaNN 12 | { 13 | 14 | // Batch是不可变的列表,而Array是可变列表(话说用Array命名可变列表会不会有点怪?) 15 | template 16 | class Array; 17 | 18 | // 标量数组 19 | template requires std::same_as, CategoryTags::Scalar> 20 | class Array 21 | { 22 | static_assert(std::is_same_v, TData>, "TData is not an available type"); 23 | public: 24 | using Category = CategoryTags::BatchScalar; 25 | using ElementType = typename TData::ElementType; 26 | using DeviceType = typename TData::DeviceType; 27 | public: 28 | Array(std::size_t = 0, std::size_t = 0) 29 | : m_buffer(new std::vector()) 30 | { 31 | } 32 | template 33 | Array(TIterator begin, TIterator end) 34 | : m_buffer(new std::vector(begin, end)) 35 | { 36 | } 37 | 38 | // 访问接口 39 | std::size_t batchNum() const 40 | { 41 | return m_buffer->size(); 42 | } 43 | std::size_t size() const 44 | { 45 | return m_buffer->size(); 46 | } 47 | bool availableForWrite() const 48 | { 49 | return m_buffer.use_count() == 1; // todo 50 | } 51 | 52 | // STL兼容接口 53 | void push_back(TData val) 54 | { 55 | assert(availableForWrite()); 56 | m_buffer->emplace_back(std::move(val)); 57 | } 58 | template 59 | void emplace_back(Args&&... 
args) 60 | { 61 | assert(availableForWrite()); 62 | TData tmp(std::forward(args)...); 63 | m_buffer->emplace_back(std::move(tmp)); 64 | } 65 | void reserve(std::size_t num) 66 | { 67 | assert(availableForWrite()); 68 | m_buffer->reserve(num); 69 | } 70 | void clear() 71 | { 72 | assert(availableForWrite()); 73 | m_buffer->clear(); 74 | } 75 | bool empty() const 76 | { 77 | return m_buffer->empty(); 78 | } 79 | const auto& operator[](std::size_t idx) const 80 | { 81 | return (*m_buffer)[idx]; 82 | } 83 | auto& operator[](std::size_t idx) 84 | { 85 | return (*m_buffer)[idx]; 86 | } 87 | auto begin() 88 | { 89 | return m_buffer->begin(); 90 | } 91 | auto begin() const 92 | { 93 | return m_buffer->begin(); 94 | } 95 | auto end() 96 | { 97 | return m_buffer->end(); 98 | } 99 | auto end() const 100 | { 101 | return m_buffer->end(); 102 | } 103 | 104 | // 求值接口: todo 105 | private: 106 | std::shared_ptr> m_buffer; 107 | // 求值缓存: todo 108 | }; 109 | 110 | // 矩阵数组 111 | template requires std::same_as, CategoryTags::Matrix> 112 | class Array 113 | { 114 | static_assert(std::is_same_v, TData>, "TData is not an available type"); 115 | public: 116 | using Category = CategoryTags::BatchMatrix; 117 | using ElementType = typename TData::ElementType; 118 | using DeviceType = typename TData::DeviceType; 119 | public: 120 | Array(std::size_t row = 0, std::size_t col = 0) 121 | : m_rowNum(row) 122 | , m_colNum(col) 123 | , m_buffer(new std::vector()) 124 | { 125 | } 126 | template 127 | Array(TIterator begin, TIterator end) 128 | : m_rowNum(0) 129 | , m_colNum(0) 130 | , m_buffer(new std::vector(begin, end)) 131 | { 132 | const auto& buffer = *m_buffer; 133 | if (!buffer.empty()) 134 | { 135 | m_rowNum = buffer[0].rowNum(); 136 | m_colNum = buffer[1].colNum(); 137 | for (std::size_t i = 1; i < buffer.size(); ++i) 138 | { 139 | if (buffer[i].rowNum() != m_rowNum || buffer[i].colNum() != m_colNum) 140 | { 141 | throw std::runtime_error("Dimension mismatch"); 142 | } 143 | } 144 | } 145 | } 146 | // 访问接口 147 | std::size_t rowNum() const 148 | { 149 | return m_rowNum; 150 | } 151 | std::size_t colNum() const 152 | { 153 | return m_colNum; 154 | } 155 | std::size_t batchNum() const 156 | { 157 | return m_buffer->size(); 158 | } 159 | std::size_t size() const 160 | { 161 | return m_buffer->size(); 162 | } 163 | bool availableForWrite() const 164 | { 165 | return m_buffer.use_count() == 1; // todo 166 | } 167 | // STL兼容接口 168 | void push_back(TData mat) 169 | { 170 | assert(availableForWrite()); 171 | if (mat.rowNum() != m_rowNum || mat.colNum() != m_colNum) 172 | { 173 | throw std::runtime_error("Dimension mismatch"); 174 | } 175 | m_buffer->push_back(std::move(mat)); 176 | } 177 | template 178 | void emplace_back(Args&&... 
args) 179 | { 180 | assert(availableForWrite()); 181 | TData tmp(std::forward(args)...); 182 | if (tmp.rowNum() != m_rowNum || tmp.colNum() != m_colNum) 183 | { 184 | throw std::runtime_error("Dimension mismatch"); 185 | } 186 | m_buffer->emplace_back(std::move(tmp)); 187 | } 188 | void reserve(std::size_t num) 189 | { 190 | assert(availableForWrite()); 191 | m_buffer->reserve(num); 192 | } 193 | void clear() 194 | { 195 | assert(availableForWrite()); 196 | m_buffer->clear(); 197 | } 198 | bool empty() const 199 | { 200 | return m_buffer->empty(); 201 | } 202 | const auto& operator[](std::size_t idx) const 203 | { 204 | return (*m_buffer)[idx]; 205 | } 206 | auto& operator[](std::size_t idx) 207 | { 208 | return (*m_buffer)[idx]; 209 | } 210 | auto begin() 211 | { 212 | return m_buffer->begin(); 213 | } 214 | auto begin() const 215 | { 216 | return m_buffer->begin(); 217 | } 218 | auto end() 219 | { 220 | return m_buffer->end(); 221 | } 222 | auto end() const 223 | { 224 | return m_buffer->end(); 225 | } 226 | 227 | // 求值接口: todo 228 | 229 | private: 230 | std::size_t m_rowNum; 231 | std::size_t m_colNum; 232 | std::shared_ptr> m_buffer; 233 | // 求值缓存: todo 234 | }; 235 | 236 | // 快捷构造Array 237 | template 238 | auto makeArray(TIterator beg, TIterator end) 239 | { 240 | using TData = typename std::iterator_traits::value_type; 241 | using RawData = std::remove_cvref_t; 242 | return Array(beg, end); 243 | } 244 | 245 | } // namespace MetaNN -------------------------------------------------------------------------------- /test/test_data.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "test.hpp" 15 | 16 | using namespace MetaNN; 17 | 18 | // scalar 19 | static_assert(ScalarC>); 20 | static_assert(ScalarC>); 21 | // matrix 22 | static_assert(MatrixC>); 23 | static_assert(MatrixC>); 24 | static_assert(MatrixC>>); 25 | static_assert(MatrixC>); 26 | static_assert(MatrixC>); 27 | static_assert(MatrixC>); 28 | static_assert(MatrixC>); 29 | static_assert(MatrixC>); 30 | static_assert(LowerAccessC>); 31 | static_assert(!LowerAccessC>); 32 | // batch scalar, batch matrix 33 | static_assert(BatchScalarC>); 34 | static_assert(BatchMatrixC>); 35 | static_assert(BatchScalarC>); 36 | static_assert(BatchMatrixC>); 37 | static_assert(LowerAccessC>); 38 | static_assert(LowerAccessC>); 39 | static_assert(BatchScalarC>>); 40 | static_assert(BatchMatrixC>>); 41 | static_assert(BatchScalarC>>); 42 | static_assert(BatchMatrixC>>); 43 | // PrincipalDataType 44 | static_assert(std::same_as, PrincipalDataType>); 45 | static_assert(std::same_as, PrincipalDataType>); 46 | static_assert(std::same_as, PrincipalDataType>); 47 | static_assert(std::same_as, PrincipalDataType>); 48 | 49 | void test_scalar(TestUtil& util); 50 | void test_matrix(TestUtil& util); 51 | void test_batch_scalar(TestUtil& util); 52 | void test_batch_matrix(TestUtil& util); 53 | void test_trivial_matrix(TestUtil& util); 54 | void test_zero_matrix(TestUtil& util); 55 | void test_one_hot_vector(TestUtil& util); 56 | void test_array(TestUtil& util); 57 | void test_duplicate(TestUtil& util); 58 | 59 | void test_data(TestUtil& util) 60 | { 61 | test_scalar(util); 62 | test_matrix(util); 63 | test_batch_scalar(util); 64 | test_batch_matrix(util); 65 | test_trivial_matrix(util); 66 | test_zero_matrix(util); 67 | test_one_hot_vector(util); 68 | 
test_array(util); 69 | test_duplicate(util); 70 | } 71 | 72 | void test_scalar(TestUtil& util) 73 | { 74 | util.setTestGroup("data.scalar"); 75 | { 76 | Scalar s(1.1); 77 | util.assertEqual(s.value(), 1.1); 78 | s.value() = 10; 79 | util.assertEqual(s.value(), 10); 80 | } 81 | { 82 | const Scalar s(1.2); 83 | util.assertEqual(s.value(), 1.2); 84 | } 85 | util.showGroupResult(); 86 | } 87 | 88 | void test_matrix(TestUtil& util) 89 | { 90 | util.setTestGroup("data.matrix"); 91 | { 92 | // rowNum, colNum 93 | Matrix mat(2, 3); 94 | iota(mat); 95 | util.assertEqual(mat.rowNum(), 2); 96 | util.assertEqual(mat.colNum(), 3); 97 | mat.setValue(1, 1, 10.5); 98 | util.assertEqual(mat(1, 1), 10.5); 99 | util.assertEqual(mat.availableForWrite(), true); 100 | // availableForWrite 101 | Matrix mat2(mat); 102 | util.assertEqual(mat2(1, 1), 10.5); 103 | util.assertEqual(mat.availableForWrite(), false); 104 | util.assertEqual(mat2.availableForWrite(), false); 105 | // lower access 106 | auto acc = lowerAccess(mat); 107 | util.assertEqual(acc.rowLen(), 3); 108 | util.assertEqual(*(acc.rawMemory() + 1 * 3 + 1), 10.5); 109 | } 110 | // subMatrix 111 | { 112 | Matrix mat(10, 10); 113 | iota(mat); 114 | Matrix mat2(10, 10); 115 | iota(mat2); 116 | auto sub1 = mat.subMatrix(3, 8, 4, 10); 117 | auto sub2 = mat2.subMatrix(3, 8, 4, 10); 118 | util.assertEqual(sub1, sub2); 119 | sub1.setValue(1, 1, -1); 120 | sub2.setValue(1, 1, -1); 121 | util.assertEqual(sub1, sub2); 122 | } 123 | util.showGroupResult(); 124 | } 125 | 126 | void test_batch_scalar(TestUtil& util) 127 | { 128 | util.setTestGroup("data.batch_scalar"); 129 | { 130 | using BatchScalarDouble = Batch; 131 | 132 | BatchScalarDouble s(10); 133 | util.assertEqual(s.batchNum(), 10); 134 | util.assertEqual(s.availableForWrite(), true); 135 | for (std::size_t i = 0; i < s.batchNum(); i++) 136 | { 137 | s.setValue(i, i); 138 | } 139 | util.assertEqual(s[1], 1); 140 | util.assertEqual(s[8], 8); 141 | // lower access 142 | { 143 | auto acc = lowerAccess(s); 144 | util.assertEqual(*(acc.rawMemory() + 3), 3); 145 | *(acc.mutableRawMemory() + 3) = 10; 146 | util.assertEqual(*(acc.rawMemory() + 3), 10); 147 | } 148 | } 149 | 150 | util.showGroupResult(); 151 | } 152 | 153 | void test_batch_matrix(TestUtil& util) 154 | { 155 | util.setTestGroup("data.batch_matrix"); 156 | { 157 | using BatchMatrixDouble = Batch; 158 | BatchMatrixDouble batch1(10, 2, 3); 159 | BatchMatrixDouble batch2(10, 2, 3); 160 | util.assertEqual(batch1.rowNum(), batch2.rowNum()); 161 | util.assertEqual(batch1.colNum(), batch2.colNum()); 162 | util.assertEqual(batch1.batchNum(), batch2.batchNum()); 163 | util.assertEqual(batch1.rowNum(), 2); 164 | util.assertEqual(batch1.colNum(), 3); 165 | util.assertEqual(batch1.batchNum(), 10); 166 | iota(batch1); 167 | iota(batch2); 168 | util.assertEqual(batch1, batch2); 169 | util.assertEqual(batch1.availableForWrite(), true); 170 | batch1.setValue(0, 0, 0, 10.2); 171 | util.assertEqual(batch1[0](0, 0), 10.2); 172 | Matrix mat(2, 3); 173 | iota(mat); 174 | util.assertEqual(mat, batch2[0]); 175 | // subBatchMatrix 176 | auto sub1 = batch1.subBatchMatrix(0, 2, 1, 3); 177 | auto sub2 = batch2.subBatchMatrix(0, 2, 1, 3); 178 | util.assertEqual(sub1, sub2); 179 | // lower access 180 | auto acc = lowerAccess(sub1); 181 | *(acc.mutableRawMemory()) = -1; 182 | util.assertEqual(batch1[0](0,1), -1); 183 | util.assertEqual(*acc.rawMemory(), -1); 184 | util.assertEqual(acc.rowLen(), 3); 185 | util.assertEqual(acc.rawMatrixSize(), 6); 186 | } 187 | 
util.showGroupResult(); 188 | } 189 | 190 | void test_trivial_matrix(TestUtil& util) 191 | { 192 | util.setTestGroup("data.trivial_matrix"); 193 | { 194 | TrivialMatrix mat(10, 10, 9.9); 195 | util.assertEqual(mat.rowNum(), 10); 196 | util.assertEqual(mat.colNum(), 10); 197 | util.assertEqual(mat.elementValue().value(), 9.9); 198 | } 199 | // makeTrivialMatrix 200 | { 201 | auto mat = makeTrivialMatrix(10, 10, 9.9); 202 | util.assertEqual(mat.rowNum(), 10); 203 | util.assertEqual(mat.colNum(), 10); 204 | util.assertEqual(mat.elementValue().value(), 9.9); 205 | } 206 | { 207 | Scalar s(3.3); 208 | auto mat = makeTrivialMatrix(10, 10, s); 209 | util.assertEqual(mat.rowNum(), 10); 210 | util.assertEqual(mat.colNum(), 10); 211 | util.assertEqual(mat.elementValue().value(), 3.3); 212 | } 213 | util.showGroupResult(); 214 | } 215 | 216 | void test_zero_matrix(TestUtil& util) 217 | { 218 | util.setTestGroup("data.zero_matrix"); 219 | { 220 | ZeroMatrix mat(10, 10); 221 | util.assertEqual(mat.rowNum(), 10); 222 | util.assertEqual(mat.colNum(), 10); 223 | } 224 | util.showGroupResult(); 225 | } 226 | 227 | void test_one_hot_vector(TestUtil& util) 228 | { 229 | util.setTestGroup("data.one_hot_vector"); 230 | { 231 | OneHotVector mat(10, 3); 232 | util.assertEqual(mat.rowNum(), 1); 233 | util.assertEqual(mat.colNum(), 10); 234 | util.assertEqual(mat.hotPos(), 3); 235 | } 236 | util.showGroupResult(); 237 | } 238 | 239 | void test_array(TestUtil& util) 240 | { 241 | util.setTestGroup("data.array"); 242 | // as batch scalar 243 | { 244 | Array> arr1; 245 | arr1.push_back(1.0); 246 | arr1.push_back(2.0); 247 | arr1.emplace_back(3.0); 248 | util.assertEqual(arr1.batchNum(), 3); 249 | util.assertEqual(arr1.size(), 3); 250 | util.assertEqual(arr1.availableForWrite(), true); 251 | Array> arr2(arr1.begin(), arr1.end()); 252 | util.assertEqual(arr1, arr2); 253 | util.assertEqual(arr1[0].value(), 1.0); 254 | util.assertEqual(arr1[2].value(), 3.0); 255 | arr1[0] = Scalar(9.9); 256 | util.assertEqual(arr1[0].value(), 9.9); 257 | util.assertEqual((*arr1.begin()).value(), 9.9); 258 | arr1.clear(); 259 | util.assertEqual(arr1.empty(), true); 260 | } 261 | // as batch matrix 262 | { 263 | Array> arr1(4, 5); 264 | Matrix mat(4, 5); 265 | iota(mat); 266 | arr1.push_back(mat); 267 | arr1.emplace_back(mat); 268 | util.assertEqual(arr1.size(), 2); 269 | util.assertEqual(arr1.rowNum(), 4); 270 | util.assertEqual(arr1.colNum(), 5); 271 | util.assertEqual(arr1.batchNum(), 2); 272 | util.assertEqual(arr1.availableForWrite(), true); 273 | Array> arr2(arr1.begin(), arr1.end()); 274 | util.assertEqual(arr1, arr2); 275 | util.assertEqual(arr1[0], mat); 276 | mat.setValue(1, 1, -1); 277 | arr1[0] = mat; 278 | util.assertEqual(arr1[0], mat); 279 | arr1.clear(); 280 | util.assertEqual(arr1.empty(), true); 281 | } 282 | // makeArray 283 | { 284 | Array> arr1(4, 5); 285 | Matrix mat(4, 5); 286 | iota(mat); 287 | arr1.push_back(mat); 288 | arr1.emplace_back(mat); 289 | 290 | auto arr2 = makeArray(arr1.begin(), arr1.end()); 291 | util.assertEqual(arr2.size(), 2); 292 | util.assertEqual(arr2[0], mat); 293 | } 294 | util.showGroupResult(); 295 | } 296 | 297 | void test_duplicate(TestUtil& util) 298 | { 299 | util.setTestGroup("data.duplicate"); 300 | // duplicate of scalar 301 | { 302 | Duplicate> dup(Scalar(9.9), 10); 303 | util.assertEqual(dup.batchNum(), 10); 304 | util.assertEqual(dup.element().value(), 9.9); 305 | } 306 | // duplicate of matrix 307 | { 308 | Matrix mat(4, 5); 309 | iota(mat); 310 | Duplicate> dup(mat, 10); 311 |
util.assertEqual(dup.rowNum(), 4); 312 | util.assertEqual(dup.colNum(), 5); 313 | util.assertEqual(dup.batchNum(), 10); 314 | util.assertEqual(dup.element(), mat); 315 | } 316 | // makeDuplicate 317 | { 318 | Matrix mat(4, 5); 319 | iota(mat); 320 | auto dup = makeDuplicate(10, mat); 321 | util.assertEqual(dup.rowNum(), 4); 322 | util.assertEqual(dup.colNum(), 5); 323 | util.assertEqual(dup.batchNum(), 10); 324 | util.assertEqual(dup.element(), mat); 325 | } 326 | util.showGroupResult(); 327 | } 328 | -------------------------------------------------------------------------------- /test/TestUtil.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include // since C++20 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | // parse the first argument: -d shows details 13 | inline bool parseDetailFlag(int argc, char const *argv[]) 14 | { 15 | return argc >= 2 && std::string(argv[1]) == "-d"; 16 | } 17 | 18 | // Generic printing helper; the default behavior is to call operator<<. A custom type that needs printing but does not define operator<< can either specialize PrintObj or define its own operator<<. 19 | template 20 | struct PrintObj 21 | { 22 | PrintObj(const T& val) : m_val(val) {} 23 | const T& m_val; 24 | void print(std::ostream& os) const 25 | { 26 | os << m_val; 27 | } 28 | }; 29 | 30 | template 31 | std::ostream& operator<<(std::ostream& os, const PrintObj& obj) 32 | { 33 | obj.print(os); 34 | return os; 35 | } 36 | 37 | template 38 | struct PrintObj> 39 | { 40 | PrintObj(const std::pair& val) : m_pair(val) {} 41 | const std::pair& m_pair; 42 | void print(std::ostream& os) const 43 | { 44 | os << "(" << m_pair.first << ", " << m_pair.second << ")"; 45 | } 46 | }; 47 | 48 | // Function object for equality checks; custom types can specialize it to provide a comparison distinct from operator==. 49 | template 50 | struct ObjEqual 51 | { 52 | bool operator()(const T1& val1, const T2& val2) const 53 | { 54 | return val1 == val2; 55 | } 56 | }; 57 | 58 | // manipulator for printing first N elements of a sequence 59 | template 60 | class PrintSequenceElements 61 | { 62 | friend std::ostream& operator<<(std::ostream& os, const PrintSequenceElements& p) 63 | { 64 | int count = 0; 65 | auto iter = p.begin; 66 | for (; iter != p.end && count < p.num; ++count, ++iter) 67 | { 68 | os << *iter << " "; 69 | } 70 | if (iter != p.end) 71 | { 72 | os << "..."; 73 | } 74 | return os; 75 | } 76 | public: 77 | PrintSequenceElements(const Iterator& _begin, const Iterator& _end, std::size_t _num) : begin(_begin), end(_end), num(_num) 78 | { 79 | } 80 | private: 81 | std::size_t num; 82 | const Iterator begin; 83 | const Iterator end; 84 | }; 85 | 86 | template 87 | PrintSequenceElements printContainerElememts(const Container& c, std::size_t num) 88 | { 89 | return PrintSequenceElements(c.begin(), c.end(), num); 90 | } 91 | 92 | template 93 | PrintSequenceElements printArrayElements(T* arr, std::size_t size, std::size_t num) 94 | { 95 | return PrintSequenceElements(arr, arr+size, num); 96 | } 97 | 98 | // test utilities 99 | class TestUtil 100 | { 101 | public: 102 | TestUtil(bool _show, const std::string& _target, int _lineNumberWidth = 4, int _maxSequenceLength = 20, std::ostream& _os = std::cout) 103 | : groupPassedCount(0) 104 | , groupTotalCount(0) 105 | , passedCount(0) 106 | , totalCount(0) 107 | , lineNumberWidth(_lineNumberWidth) 108 | , maxSequenceLength(_maxSequenceLength) 109 | , showDetails(_show) 110 | , target(_target) 111 | , os(_os) 112 | { 113 | os.clear(); 114 | } 115 | 116 | // Must be called before running a test group 117 | void setTestGroup(const std::string& group) 118 | { 119 | passedCount +=
groupPassedCount; 120 | totalCount += groupTotalCount; 121 | groupPassedCount = 0; 122 | groupTotalCount = 0; 123 | curGroup = group; 124 | if (showDetails) 125 | { 126 | os << "Test of " << curGroup << ":\n"; 127 | } 128 | } 129 | 130 | // Called at the end of a test group 131 | void showGroupResult() 132 | { 133 | os << std::boolalpha << std::dec; 134 | os << "Test result of " << std::setfill('_') << std::left << std::setw(30) << curGroup << ": "; 135 | os << std::right << std::setfill(' '); 136 | os << std::setw(3) << groupPassedCount << "/" << std::setw(3) << std::left << groupTotalCount << " passed"; 137 | os << (groupPassedCount == groupTotalCount ? "\n" : " --------------------------> failed\n"); 138 | if (showDetails) 139 | { 140 | os << "\n"; 141 | } 142 | passedCount += groupPassedCount; 143 | totalCount += groupTotalCount; 144 | groupPassedCount = 0; 145 | groupTotalCount = 0; 146 | } 147 | 148 | void showFinalResult() 149 | { 150 | os << std::boolalpha << std::dec; 151 | os << "Test results of " << target << ": "; 152 | os << std::setw(4) << std::right << passedCount << "/" << std::setw(4) << std::left << totalCount << " passed"; 153 | os << (passedCount == totalCount ? " ========================= success\n" : " ========================= failed\n"); 154 | } 155 | 156 | template> 157 | void assertEqual(const T1& t1, const T2& t2, const Equal& eq = ObjEqual(), const std::source_location& loc = std::source_location::current()) 158 | { 159 | bool res = eq(t1, t2); 160 | groupPassedCount += (res ? 1 : 0); 161 | groupTotalCount++; 162 | if (!res && showDetails) 163 | { 164 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 165 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 166 | << "assertEqual: " << "left value( " << PrintObj(t1) << " ), right value( " << PrintObj(t2) << " )\n"; 167 | } 168 | } 169 | 170 | template> 171 | void assertNotEqual(const T1& t1, const T2& t2, const Equal& eq = ObjEqual(), const std::source_location& loc = std::source_location::current()) 172 | { 173 | bool res = !eq(t1, t2); 174 | groupPassedCount += (res ? 1 : 0); 175 | groupTotalCount++; 176 | if (!res && showDetails) 177 | { 178 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 179 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 180 | << "assertNotEqual: " << "left value( " << PrintObj(t1) << " ), right value( " << PrintObj(t2) << " )\n"; 181 | } 182 | } 183 | 184 | template 185 | void assertSequenceEqual(const Container1& c1, const Container2& c2, const std::source_location& loc = std::source_location::current()) 186 | { 187 | bool res = std::equal(c1.begin(), c1.end(), c2.begin()); 188 | groupPassedCount += (res ? 1 : 0); 189 | groupTotalCount++; 190 | if (!res && showDetails) 191 | { 192 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 193 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 194 | << "assertSequenceEqual: " 195 | << "\n\tleft sequence: " << printContainerElememts(c1, maxSequenceLength) 196 | << "\n\tright sequence: " << printContainerElememts(c2, maxSequenceLength) << "\n"; 197 | } 198 | } 199 | 200 | template 201 | void assertArrayEqual(const T1* arr1, const T2* arr2, std::size_t size, const std::source_location& loc = std::source_location::current()) 202 | { 203 | bool res = std::equal(arr1, arr1 + size, arr2); 204 | groupPassedCount += (res ?
1 : 0); 205 | groupTotalCount++; 206 | if (!res && showDetails) 207 | { 208 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 209 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 210 | << "assertArrayEqual: " 211 | << "\n\tleft array: " << printArrayElements(arr1, size, maxSequenceLength) 212 | << "\n\tright array: " << printArrayElements(arr2, size, maxSequenceLength) << "\n"; 213 | } 214 | } 215 | 216 | // more generic version of assert sequence/array equal 217 | template 218 | void assertRangeEqual(ForwardIterator1 b1, ForwardIterator1 e1, ForwardIterator2 b2, const std::source_location& loc = std::source_location::current()) 219 | { 220 | bool res = std::equal(b1, e1, b2); 221 | groupPassedCount += (res ? 1 : 0); 222 | groupTotalCount++; 223 | if (!res && showDetails) 224 | { 225 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 226 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 227 | << "assertRangeEqual: " 228 | << "\n\tleft range: " << PrintSequenceElements(b1, e1, maxSequenceLength) 229 | << "\n\tright range: " << PrintSequenceElements(b2, std::next(b2, std::distance(b1, e1)), maxSequenceLength) << "\n"; 230 | } 231 | } 232 | template 233 | void assertRangeEqual(ForwardIterator1 b1, ForwardIterator1 e1, ForwardIterator2 b2, ForwardIterator2 e2, const std::source_location& loc = std::source_location::current()) 234 | { 235 | bool res = std::distance(b1, e1) == std::distance(b2, e2) && std::equal(b1, e1, b2); 236 | groupPassedCount += (res ? 1 : 0); 237 | groupTotalCount++; 238 | if (!res && showDetails) 239 | { 240 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 241 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 242 | << "assertRangeEqual: " 243 | << "\n\tleft range: " << PrintSequenceElements(b1, e1, maxSequenceLength) 244 | << "\n\tright range: " << PrintSequenceElements(b2, e2, maxSequenceLength) << "\n"; 245 | } 246 | } 247 | // assert a sequence is sorted 248 | template::value_type>> 249 | void assertSorted(InputIterator b, InputIterator e, const Compare& cmp = Compare(), const std::source_location& loc = std::source_location::current()) 250 | { 251 | bool res = std::is_sorted(b, e, cmp); 252 | groupPassedCount += (res ? 1 : 0); 253 | groupTotalCount++; 254 | if (!res && showDetails) 255 | { 256 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 257 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 258 | << "assertSorted: " 259 | << "\n\tsequence: " << PrintSequenceElements(b, e, maxSequenceLength) << "\n"; 260 | } 261 | } 262 | // assert that two sets are equal, ignoring the order of elements. 263 | template 264 | void assertSetEqual(const Container1& c1, const Container2& c2, const std::source_location& loc = std::source_location::current()) 265 | { 266 | bool res = (std::size(c1) == std::size(c2) && std::is_permutation(c1.begin(), c1.end(), c2.begin())); 267 | groupPassedCount += (res ?
1 : 0); 268 | groupTotalCount++; 269 | if (!res && showDetails) 270 | { 271 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 272 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 273 | << "assertSetEqual: " 274 | << "\n\tleft set: " << printContainerElememts(c1, maxSequenceLength) 275 | << "\n\tright set: " << printContainerElememts(c2, maxSequenceLength) << "\n"; 276 | } 277 | } 278 | template 279 | void assertSetEqual(ForwardIterator1 b1, ForwardIterator1 e1, ForwardIterator2 b2, ForwardIterator2 e2, const std::source_location& loc = std::source_location::current()) 280 | { 281 | bool res = (std::distance(b1, e1) == std::distance(b2, e2) && std::is_permutation(b1, e1, b2)); 282 | groupPassedCount += (res ? 1 : 0); 283 | groupTotalCount++; 284 | if (!res && showDetails) 285 | { 286 | os << std::boolalpha << std::dec << std::right << std::setfill(' '); 287 | os << loc.file_name() << ":" << std::setw(lineNumberWidth) << loc.line() << ": " 288 | << "assertSetEqual: " 289 | << "\n\tleft set: " << PrintSequenceElements(b1, e1, maxSequenceLength) 290 | << "\n\tright set: " << PrintSequenceElements(b2, e2, maxSequenceLength) << "\n"; 291 | } 292 | } 293 | private: 294 | int groupPassedCount; 295 | int groupTotalCount; 296 | int passedCount; 297 | int totalCount; 298 | int lineNumberWidth; // output width of line number 299 | int maxSequenceLength; // max output length of a sequence 300 | bool showDetails; 301 | std::string target; 302 | std::string curGroup; 303 | std::ostream& os; 304 | }; 305 | --------------------------------------------------------------------------------
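Note on usage: TestUtil is intended to be driven from a small test executable. Each test_* function (such as test_data in test_data.cpp above) brackets its assertions with setTestGroup and showGroupResult, and the driver prints the aggregate with showFinalResult. The sketch below is illustrative only and is not a file from the repository listing above; the target string "data" and the choice of groups are assumptions, while parseDetailFlag, the TestUtil constructor, and showFinalResult are used as declared in TestUtil.hpp.

// Hypothetical driver sketch (not part of the repository dump above).
#include "TestUtil.hpp"

void test_data(TestUtil& util); // defined in test_data.cpp, shown earlier

int main(int argc, char const *argv[])
{
    // Pass "-d" as the first argument to print per-assertion failure details.
    TestUtil util(parseDetailFlag(argc, argv), "data"); // "data" is an assumed target name
    test_data(util);        // runs the data.* test groups
    util.showFinalResult(); // aggregate pass/fail summary
    return 0;
}

Link this driver together with test_data.cpp (and any other test translation units it declares) to obtain a runnable test binary.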