├── .gitignore ├── doc ├── add.pdf ├── max.pdf ├── biasadd.pdf ├── concat.pdf ├── conv2d.pdf └── operation.pdf ├── sample ├── helper.h ├── CMakeLists.txt ├── sample_test.sh ├── pytorch.py ├── tf2.py ├── copy.h ├── debug_model.cpp ├── count_mac.cpp ├── sample.cpp ├── dumper.h └── naive_quantization.cpp ├── LICENSE ├── sadl ├── layer_identity.h ├── layer_shape.h ├── layer_relu.h ├── options.h ├── layer_copy.h ├── layer_leakyrelu.h ├── dimensions.h ├── layers.h ├── layer_flatten.h ├── layer_placeholder.h ├── layer_reshape.h ├── layer_const.h ├── layer_expand.h ├── layer_maximum.h ├── layer_concat.h ├── layer_transpose.h ├── layer.h ├── layer_maxpool.h ├── layer_biasadd.h ├── layer_add.h ├── layer_mul.h ├── layer_matmul.h ├── tensor.h └── layer_conv2dtranspose.h ├── .clang-format └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | */CMakeLists.txt.user 3 | -------------------------------------------------------------------------------- /doc/add.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/SADL/HEAD/doc/add.pdf -------------------------------------------------------------------------------- /doc/max.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/SADL/HEAD/doc/max.pdf -------------------------------------------------------------------------------- /doc/biasadd.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/SADL/HEAD/doc/biasadd.pdf -------------------------------------------------------------------------------- /doc/concat.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/SADL/HEAD/doc/concat.pdf 
-------------------------------------------------------------------------------- /doc/conv2d.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/SADL/HEAD/doc/conv2d.pdf -------------------------------------------------------------------------------- /doc/operation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/SADL/HEAD/doc/operation.pdf -------------------------------------------------------------------------------- /sample/helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | inline sadl::layers::TensorInternalType::Type getModelType(const std::string &filename) 7 | { 8 | const std::string MAGICNUMBER = "SADL0002"; // 8-character signature expected at the start of a model file 9 | std::ifstream file(filename, std::ios::binary); 10 | if (!file) { 11 | std::cerr << "[ERROR] No file " << filename << std::endl; 12 | exit(-1); 13 | } 14 | char magic[9]; 15 | file.read(magic, 8); // NOTE(review): the outcome of this read is not checked (short file leaves magic unset) 16 | magic[8] = '\0'; // terminate so the 8 bytes read can be compared as a string 17 | std::string magic_s = magic; 18 | if (magic_s != MAGICNUMBER) 19 | { 20 | std::cerr << "[ERROR] Pb reading model: wrong magic " << magic_s << std::endl; 21 | exit(-1); 22 | } 23 | 24 | int8_t x = 0; 25 | file.read((char *) &x, sizeof(int8_t)); // single byte that follows the magic; stays 0 if the read fails 26 | return (sadl::layers::TensorInternalType::Type) x; // that byte is returned cast to the tensor type enum 27 | } 28 | 29 | -------------------------------------------------------------------------------- /sample/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5) 2 | 3 | project(sample LANGUAGES CXX) 4 | 5 | set(CMAKE_CXX_STANDARD 14) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | set(CMAKE_CXX_FLAGS "-std=c++14 -ffast-math -Wall -fstrict-aliasing") 8 | set(CMAKE_CXX_FLAGS_RELEASE "-O3") 9 | 10 | include_directories(..) 
11 | file(GLOB HEADER_FILES helper.h ../sadl/*.h ) 12 | 13 | add_executable(sample_generic sample.cpp ${HEADER_FILES}) 14 | add_executable(sample_simd256 sample.cpp ${HEADER_FILES}) 15 | add_executable(sample_simd512 sample.cpp ${HEADER_FILES}) 16 | set_target_properties(sample_simd256 PROPERTIES COMPILE_FLAGS "-mavx2 -DNDEBUG=1 " ) 17 | set_target_properties(sample_simd512 PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -DNDEBUG=1" ) 18 | 19 | add_executable(count_mac count_mac.cpp ${HEADER_FILES}) 20 | set_target_properties(count_mac PROPERTIES COMPILE_FLAGS "-DNDEBUG=1 " ) # must build in scalar mode to count MAC 21 | 22 | add_executable(debug_model debug_model.cpp ${HEADER_FILES}) 23 | set_target_properties(debug_model PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw" ) # must build in SIMD mode to debug SIMD issue 24 | 25 | add_executable(naive_quantization naive_quantization.cpp ${HEADER_FILES} dumper.h) 26 | 27 | 28 | -------------------------------------------------------------------------------- /sample/sample_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | echo "[INFO] BUILD SADL SAMPLE" 5 | # build sample 6 | mkdir -p sample_test; 7 | cd sample_test; 8 | cmake -DCMAKE_BUILD_TYPE=Release ../sample 9 | make 10 | echo "" 11 | 12 | echo "[INFO] TF2 -> ONNX -> SADL" 13 | # TF2 14 | python3 ../sample/tf2.py 2>/dev/null # output a tf2.onnx 15 | python3 ../converter/main.py --input_onnx tf2.onnx --output tf2.sadl 16 | ./sample_simd512 tf2.sadl 17 | echo "" 18 | 19 | echo "[INFO] PYTORCH -> ONNX -> SADL" 20 | # torch 21 | python3 ../sample/pytorch.py 2>/dev/null # output a pytorch.onnx 22 | python3 ../converter/main.py --input_onnx pytorch.onnx --output pytorch.sadl 23 | ./sample_simd512 pytorch.sadl 24 | echo "" 25 | 26 | echo "[INFO] DEBUG MODEL" 27 | ./debug_model pytorch.sadl > debug_model.log 28 | echo "see debug_model.log" 29 | echo "" 30 | 31 | echo "[INFO] COUNT MAC" 32 | ./count_mac 
pytorch.sadl 33 | 34 | echo "[INFO] WRITE INT16 MODEL" 35 | echo "0 15 1 8 2 0 3 8 6 8 7 0 8 8 10 9 11 0 12 8 14 8 15 1 16 8 20 8 21 0 22 9 24 8 25 0 26 8 28 8 29 0 30 8 34 8 35 0 36 8 38 0 39 0 42 8 43 0 44 8" | ./naive_quantization pytorch.sadl pytorch_int16.sadl; 36 | 37 | 38 | if [ -f tf2.sadl -a -f pytorch.sadl \ 39 | -a -f sample_generic -a -f sample_simd256 -a -f sample_simd512 \ 40 | -a -f count_mac \ 41 | -a -f debug_model \ 42 | -a -f naive_quantization ]; then 43 | exit 0; 44 | else 45 | exit 1; 46 | fi; 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | The copyright in this software is being made available under the BSD 4 | License, included below. This software may be subject to other third party 5 | and contributor rights, including patent rights, and no such rights are 6 | granted under this license. 7 | 8 | Copyright (c) 2010-2022, ITU/ISO/IEC 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, 15 | this list of conditions and the following disclaimer. 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 20 | be used to endorse or promote products derived from this software without 21 | specific prior written permission. 
22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 26 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 27 | BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 28 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 29 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 30 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 31 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 32 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 33 | THE POSSIBILITY OF SUCH DAMAGE. 34 | -------------------------------------------------------------------------------- /sample/pytorch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def weights_init(m): 7 | if isinstance(m, nn.Conv2d): 8 | torch.nn.init.xavier_uniform_(m.weight) 9 | torch.nn.init.zeros_(m.bias) 10 | 11 | 12 | h = 16 13 | w =16 14 | s = (3, h, w) 15 | 16 | 17 | class Model(nn.Module): 18 | def __init__(self): 19 | super(Model, self).__init__() 20 | nbf = 8 21 | self.nbf=nbf 22 | self.conv01 = nn.Conv2d(3, nbf, kernel_size=(3, 3), padding=1, bias=True) 23 | self.pool0 = nn.MaxPool2d(2,2) 24 | self.conv02 = nn.Conv2d(nbf, nbf, kernel_size=(3, 3), padding=1, bias=True) 25 | 26 | self.conv11 = nn.Conv2d(nbf, nbf, kernel_size=(3, 3), padding=1, bias=True) 27 | self.conv12 = nn.Conv2d(nbf, nbf, kernel_size=(3, 3), padding=1, bias=True) 28 | self.pool1 = nn.MaxPool2d(2,2) 29 | self.conv13 = nn.Conv2d(nbf, 2*nbf, kernel_size=(3, 3), padding=1, bias=True) 30 | 31 | self.conv21 = nn.Conv2d( 32 | 2 * nbf, 2 * nbf, kernel_size=(3, 3), padding=1, bias=True 33 | ) 34 | self.conv22 
= nn.Conv2d( 35 | 2 * nbf, 2 * nbf, kernel_size=(3, 3), padding=1, bias=True 36 | ) 37 | self.pool2 = nn.MaxPool2d(2,2) 38 | self.conv23 = nn.Conv2d( 39 | 2 * nbf, 4 * nbf, kernel_size=(3, 3), padding=1, bias=True 40 | ) 41 | 42 | self.linear = nn.Linear(4*nbf*h//8*w//8, 2) 43 | self.apply(weights_init) 44 | 45 | def forward(self, inputs): 46 | input, = inputs 47 | x = self.conv01(input) 48 | x = self.pool0(x) 49 | x = self.conv02(x) 50 | 51 | x0 = self.conv11(x) 52 | x0 = self.conv12(x0) 53 | x0 = x0+x 54 | x0 = self.pool1(x0) 55 | x0 = self.conv13(x0) 56 | 57 | x1 = self.conv21(x0) 58 | x1 = self.conv22(x1) 59 | x1 = x1+x0 60 | x1 = self.pool2(x1) 61 | x1 = self.conv23(x1) 62 | 63 | x1= x1.reshape((1,4*self.nbf*h//8*w//8)) 64 | y=self.linear(x1) 65 | return y 66 | 67 | 68 | model = Model() 69 | input0 = ( 70 | np.linspace(-1.0, 1, np.prod(s)).reshape((1, h, w, 3)).astype(np.float32) 71 | ) # in sadl, tensor are nhwc... 72 | input0 = np.transpose(input0,(0,3,1,2)) # transpose for pytorch 73 | inputs_torch = [ torch.from_numpy(input0)] 74 | inputs_torch[0].requires_grad=True 75 | output = model(inputs_torch) 76 | print("Output",output) 77 | torch.onnx.export(model, inputs_torch, "./pytorch.onnx") 78 | -------------------------------------------------------------------------------- /sample/tf2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import onnx 5 | import tensorflow as tf 6 | import tf2onnx 7 | 8 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 10 | tensor_fmt = "channels_last" 11 | 12 | s = (16, 16,3) 13 | inputs = tf.keras.Input(shape=s, name="input0", dtype=tf.float32) 14 | 15 | nbf = 8 16 | x = tf.keras.layers.Conv2D( 17 | nbf, 18 | (3, 3), 19 | activation="linear", 20 | data_format=tensor_fmt, 21 | use_bias=True, 22 | bias_initializer="glorot_uniform", 23 | padding="same", 24 | )(inputs) 25 | x = 
tf.keras.layers.MaxPool2D(2,data_format=tensor_fmt)(x) 26 | x = tf.keras.layers.Conv2D( 27 | nbf, 28 | (3, 3), 29 | activation="linear", 30 | data_format=tensor_fmt, 31 | use_bias=True, 32 | bias_initializer="glorot_uniform", 33 | padding="same", 34 | )(x) 35 | 36 | x0 = tf.keras.layers.Conv2D( 37 | nbf, 38 | kernel_size=(3, 3), 39 | activation="relu", 40 | use_bias=True, 41 | data_format=tensor_fmt, 42 | padding="same", 43 | )(x) 44 | x0 = tf.keras.layers.Conv2D( 45 | nbf, 46 | kernel_size=(3, 3), 47 | activation="relu", 48 | use_bias=True, 49 | data_format=tensor_fmt, 50 | padding="same", 51 | )(x0) 52 | x0 = x0 + x 53 | x0 = tf.keras.layers.MaxPool2D(2,data_format=tensor_fmt)(x0) 54 | x0 = tf.keras.layers.Conv2D( 55 | 2 * nbf, 56 | kernel_size=(3, 3), 57 | activation="relu", 58 | use_bias=True, 59 | data_format=tensor_fmt, 60 | padding="same", 61 | )(x0) 62 | 63 | x1 = tf.keras.layers.Conv2D( 64 | 2 * nbf, 65 | kernel_size=(3, 3), 66 | activation="relu", 67 | use_bias=True, 68 | data_format=tensor_fmt, 69 | padding="same", 70 | )(x0) 71 | x1 = tf.keras.layers.Conv2D( 72 | 2 * nbf, 73 | kernel_size=(3, 3), 74 | activation="relu", 75 | use_bias=True, 76 | data_format=tensor_fmt, 77 | padding="same", 78 | )(x1) 79 | x1 = x1 + x0 80 | x1 = tf.keras.layers.MaxPool2D(2,data_format=tensor_fmt)(x1) 81 | x1 = tf.keras.layers.Conv2D( 82 | 4 * nbf, 83 | kernel_size=(3, 3), 84 | activation="relu", 85 | use_bias=True, 86 | data_format=tensor_fmt, 87 | padding="same", 88 | )(x1) 89 | 90 | x2 = tf.keras.layers.Reshape((1,4*nbf*16//8*16//8))(x1) 91 | y = tf.keras.layers.Dense(2)(x2) 92 | model = tf.keras.Model(inputs=[inputs],outputs=y,name="cat_classifier") 93 | 94 | 95 | X = np.linspace(-1.0, 1, np.prod(s)).reshape((1,) + s) 96 | Y = model(X) 97 | 98 | model_onnx, _ = tf2onnx.convert.from_keras( 99 | model, [tf.TensorSpec(shape=(1,) + s, name="input0")], opset=13 100 | ) 101 | onnx.save(model_onnx, "./tf2.onnx") 102 | # print("Input\n",X) 103 | print("Output\n",Y) 104 | 105 | 
print("Model in tf2.onnx") 106 | -------------------------------------------------------------------------------- /sadl/layer_identity.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Identity : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // bring the base-class member into scope so it can be used without this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } // apply() swaps the input tensor with out_, so the input is modified 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version v) override; 53 | }; 54 | 55 | template bool Identity::apply(std::vector *> &in) 56 | { 57 | assert(in.size() == 1); 58 | assert(in[0]->dims() == out_.dims()); 59 | swap(*in[0], out_); // output takes the input's contents by swapping the two tensors 60 | return true; 61 | } 62 | 63 | template bool Identity::init(const std::vector *> &in) 64 | { 65 | if (in.size() != 1) 66 | return false; 67 | out_.resize(in[0]->dims()); // output has the same dimensions as the single input 68 | initDone_ = true; 69 | return true; 70 | } 71 | 72 | template bool Identity::loadInternal(std::istream &, Version) 73 | { 74 | return true; // nothing is read from the stream for this layer 75 | } 76 | 77 | } // namespace layers 78 | } // namespace sadl 79 | -------------------------------------------------------------------------------- /sadl/layer_shape.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. 
This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Shape : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // bring the base-class member into scope so it can be used without this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | 50 | protected: 51 | virtual bool loadInternal(std::istream &file, Version v) override; 52 | }; 53 | 54 | template bool Shape::apply(std::vector *> &in) 55 | { 56 | assert(in.size() == 1); 57 | (void) in; 58 | // nothing to compute here: init() already wrote the input's dimensions into out_ 59 | return true; 60 | } 61 | 62 | template bool Shape::init(const std::vector *> &in) 63 | { 64 | if (in.size() != 1) 65 | return false; 66 | Dimensions d; 67 | d.resize(1); 68 | d[0] = in[0]->dims().size(); // 1-D output with one element per dimension of the input 69 | out_.resize(d); 70 | copy(in[0]->dims().begin(), in[0]->dims().end(), out_.begin()); // store the input's dimension values as the output data 71 | initDone_ = true; 72 | return true; 73 | } 74 | 75 | template bool Shape::loadInternal(std::istream &, Version) 76 | { 77 | return true; // nothing is read from the stream for this layer 78 | } 79 | 80 | } // namespace layers 81 | } // namespace sadl 82 | -------------------------------------------------------------------------------- /sadl/layer_relu.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 
14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Relu : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // bring the base-class member into scope so it can be used without this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } // apply() swaps the input tensor with out_, so the input is modified 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version v) override; 53 | }; 54 | 55 | template bool Relu::apply(std::vector *> &in) 56 | { 57 | assert(in.size() == 1); 58 | assert(in[0]->dims() == out_.dims()); 59 | swap(*in[0], out_); // take the input's contents by swapping, then rectify out_ in place 60 | for (auto &x: out_) 61 | x = (x < 0) ? 0 : x; // negative values become 0, all others are kept 62 | return true; 63 | } 64 | 65 | template bool Relu::init(const std::vector *> &in) 66 | { 67 | if (in.size() != 1) 68 | return false; 69 | out_.resize(in[0]->dims()); // output has the same dimensions as the single input 70 | initDone_ = true; 71 | return true; 72 | } 73 | 74 | template bool Relu::loadInternal(std::istream &, Version) 75 | { 76 | return true; // nothing is read from the stream for this layer 77 | } 78 | 79 | } // namespace layers 80 | } // namespace sadl 81 | -------------------------------------------------------------------------------- /sadl/options.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 
14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | 35 | // build options 36 | // behavior 37 | #ifndef SATURATE_RESULT 38 | #define SATURATE_RESULT 1 // avoid overflow in int NN 39 | #endif 40 | // optimization 41 | // nothing/-msse42: no simd 42 | // -mavx2: avx2 43 | // -mavx2 -mfma: avx2 + fuse multiply/add 44 | // -mavx512bw -mavx512f: avx512 45 | // #define NDEBUG 1 // remove sanity tests 46 | 47 | // debug 48 | // #define DEBUG_VALUES 1 // show values 49 | // #define DEBUG_MODEL 1 // show pb with model 50 | // #define DEBUG_COUNTERS 1 // print overflow etc. 
51 | // #define DEBUG_PRINT 1 // print model info 52 | // #define DEBUG_SIMD 1 // tell about non simd version 53 | // #define DEBUG_KEEP_OUTPUT 1 // keep a copy of the output tensor 54 | #if SATURATE_RESULT 55 | #define SATURATE(X) if (!std::is_same::value) X = (X>ComputationType::max)?ComputationType::max:(X<-ComputationType::max?-ComputationType::max:X) 56 | #else 57 | #define SATURATE(X) 58 | #endif 59 | 60 | #if DEBUG_COUNTERS 61 | template T my_abs(T x) { return xcpt_op; if (my_abs(X) > ComputationType::max) ++this->cpt_overflow 63 | #define COUNTERS_MAC(X) ++this->cpt_mac; if (X!=0) ++this->cpt_mac_nz 64 | #else 65 | #define COUNTERS(X) (void)X 66 | #define COUNTERS_MAC(X) (void)X 67 | #endif 68 | 69 | 70 | #ifndef DUMP_MODEL_EXT 71 | #define DUMP_MODEL_EXT 72 | #endif 73 | namespace sadl 74 | { 75 | enum class Version 76 | { 77 | unknown = -1, 78 | sadl01 = 1, 79 | sadl02 = 2 80 | }; 81 | } 82 | -------------------------------------------------------------------------------- /sadl/layer_copy.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 
17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Copy : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // bring the base-class member into scope so it can be used without this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | 50 | void setInputLayer(typename Layer::Id id); 51 | 52 | protected: 53 | virtual bool loadInternal(std::istream &file, Version v) override; 54 | }; 55 | 56 | template bool Copy::apply(std::vector *> &in) 57 | { 58 | assert(in.size() == 1); 59 | assert(in[0]->dims() == out_.dims()); 60 | std::copy(in[0]->begin(), in[0]->end(), out_.begin()); // element-wise copy; unlike the swapping layers the input is only read 61 | out_.quantizer = in[0]->quantizer; // output carries the same quantizer as the input 62 | out_.border_skip = in[0]->border_skip; // output carries the same border_skip as the input 63 | 64 | return true; 65 | } 66 | 67 | template bool Copy::init(const std::vector *> &in) 68 | { 69 | if (in.size() != 1) 70 | return false; 71 | out_.resize(in[0]->dims()); // output has the same dimensions as the single input 72 | initDone_ = true; 73 | return true; 74 | } 75 | 76 | template void Copy::setInputLayer(typename Layer::Id iid) 77 | { 78 | this->inputs_id_.push_back(iid); // register iid as an input of this layer 79 | this->name_ = "copy"; 80 | } 81 | 82 | template bool Copy::loadInternal(std::istream &, Version) 83 | { 84 | return false; // always reports failure; NOTE(review): presumably this layer is created in code via setInputLayer() rather than loaded from a file - confirm 85 | } 86 | 87 | } // namespace layers 88 | } // namespace sadl 89 | -------------------------------------------------------------------------------- /sadl/layer_leakyrelu.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 
8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class LeakyRelu : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version) override; 53 | }; 54 | 55 | template bool LeakyRelu::apply(std::vector *> &in) 56 | { 57 | assert(in.size() == 2); 58 | assert(in[0]->dims() == out_.dims()); 59 | const Tensor &A = *in[1]; 60 | swap(*in[0], out_); 61 | // keep same qunatiz as input 62 | const typename ComputationType::type alpha = A[0]; 63 | 64 | const int alpha_q = A.quantizer; 65 | for (auto &x: out_) 66 | { 67 | if (x < 0) 68 | { 69 | typename ComputationType::type z = x * alpha; 70 | ComputationType::quantize(z, alpha_q); 71 | COUNTERS(z); 72 | // do not saturate because alpha<0 73 | x = z; 74 | } 75 | } 76 | 77 | return true; 78 | } 79 | 80 | template bool LeakyRelu::init(const std::vector *> &in) 81 | { 82 | if (in.size() != 2) 83 | return false; 84 | out_.resize(in[0]->dims()); 85 | initDone_ = true; 86 | return true; 87 | } 88 | 89 | template bool LeakyRelu::loadInternal(std::istream &, Version) 90 | { 91 | return true; 92 | } 93 | 94 | } // namespace layers 95 | } // namespace sadl 96 | -------------------------------------------------------------------------------- /sadl/dimensions.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 
5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
#pragma once
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <numeric>

namespace sadl
{
/// Fixed-capacity list of extents (at most MaxDim entries).
///
/// Sizes outside [0, MaxDim] are programming errors: they trigger an assert
/// in debug builds and are clamped in release builds so that the internal
/// array can never be overrun.
struct Dimensions
{
  static constexpr int MaxDim = 6;
  using iterator              = int *;
  using const_iterator        = const int *;

  Dimensions() = default;

  /// Builds the list from a braced list of extents, e.g. Dimensions{ 1, 8, 8, 3 }.
  Dimensions(std::initializer_list<int> L)
  {
    assert(static_cast<int>(L.size()) <= MaxDim);
    s_ = clampSize(static_cast<int>(L.size()));
    // copy at most s_ entries: never write past v_ even when asserts are off
    std::copy(L.begin(), L.begin() + s_, v_);
  }

  /// Changes the number of entries; the stored values are left as they are.
  void resize(int s)
  {
    assert(s >= 0 && s <= MaxDim);
    s_ = clampSize(s);
  }

  /// Number of entries currently in use.
  int size() const { return s_; }

  /// Product of all entries, computed in 64 bits (1 for an empty list).
  int64_t nbElements() const
  {
    return std::accumulate(v_, v_ + s_, static_cast<int64_t>(1), [](int64_t a, int64_t b) { return a * b; });
  }

  /// Access to entry k; k must address the internal storage (0 <= k < MaxDim).
  int operator[](int k) const
  {
    assert(k >= 0 && k < MaxDim);
    return v_[k];
  }
  int &operator[](int k)
  {
    assert(k >= 0 && k < MaxDim);
    return v_[k];
  }

  iterator       begin() { return v_; }
  iterator       end() { return v_ + s_; }
  const_iterator begin() const { return v_; }
  const_iterator end() const { return v_ + s_; }

  /// Two lists are equal when they have the same size and the same entries.
  bool operator==(const Dimensions &d) const { return d.s_ == s_ && std::equal(v_, v_ + s_, d.v_); }
  bool operator!=(const Dimensions &d) const { return !(*this == d); }

  /// Last entry; the list must not be empty.
  int back() const
  {
    assert(s_ > 0);
    return v_[s_ - 1];
  }

private:
  /// Restricts a requested size to the valid range [0, MaxDim].
  static int clampSize(int s)
  {
    if (s < 0)
      return 0;
    if (s > MaxDim)
      return MaxDim;
    return s;
  }

  int v_[MaxDim] = {};
  int s_         = 0;
};

}   // namespace sadl

//#if !NDEBUG
#include <iostream>
namespace sadl
{
/// Prints the entries as "( a b c )".
inline std::ostream &operator<<(std::ostream &out, const Dimensions &d)
{
  out << "( ";
  for (const int extent: d)
    out << extent << ' ';
  out << ')';
  return out;
}
}   // namespace sadl
5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | 35 | #include "layer_placeholder.h" 36 | #include "layer_reshape.h" 37 | #include "layer_const.h" 38 | #include "layer_identity.h" 39 | #include "layer_conv2d.h" // before matmul to get def of sum8_float 40 | #include "layer_conv2dtranspose.h" 41 | #include "layer_matmul.h" 42 | #include "layer_biasadd.h" 43 | #include "layer_add.h" 44 | #include "layer_relu.h" 45 | #include "layer_copy.h" 46 | #include "layer_maxpool.h" 47 | #include "layer_mul.h" 48 | #include "layer_concat.h" 49 | #include "layer_maximum.h" 50 | #include "layer_leakyrelu.h" 51 | #include "layer_transpose.h" 52 | #include "layer_flatten.h" 53 | #include "layer_shape.h" 54 | #include "layer_expand.h" 55 | 56 | namespace sadl 57 | { 58 | namespace layers 59 | { 60 | 61 | inline std::string opName(const OperationType::Type op) 62 | { 63 | #define DIRTYCASEPRINT(X) \ 64 | case OperationType::X: oss << #X; break 65 | std::ostringstream oss; 66 | switch (op) 67 | { 68 | DIRTYCASEPRINT(Copy); 69 | DIRTYCASEPRINT(Const); 70 | DIRTYCASEPRINT(Placeholder); 71 | DIRTYCASEPRINT(Identity); 72 | DIRTYCASEPRINT(BiasAdd); 73 | DIRTYCASEPRINT(MaxPool); 74 | DIRTYCASEPRINT(MatMul); 75 | DIRTYCASEPRINT(Reshape); 76 | DIRTYCASEPRINT(Relu); 77 | DIRTYCASEPRINT(Conv2D); 78 | DIRTYCASEPRINT(Conv2DTranspose); 79 | DIRTYCASEPRINT(Add); 80 | DIRTYCASEPRINT(Mul); 81 | DIRTYCASEPRINT(Concat); 82 | DIRTYCASEPRINT(Maximum); 83 | DIRTYCASEPRINT(LeakyRelu); 84 | DIRTYCASEPRINT(Transpose); 85 | DIRTYCASEPRINT(Flatten); 86 | DIRTYCASEPRINT(Shape); 87 | DIRTYCASEPRINT(Expand); 88 | default: oss << "??"; break; 89 | } 90 | return oss.str(); 91 | #undef DIRTYCASEPRINT 92 | } 93 | 94 | 95 | } // namespace layers 96 | 97 | } // namespace sadl 98 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -2 5 | 
AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: true 7 | AlignConsecutiveDeclarations: true 8 | AlignEscapedNewlines: Right 9 | AlignOperands: true 10 | AlignTrailingComments: true 11 | AllowAllParametersOfDeclarationOnNextLine: false 12 | AllowShortBlocksOnASingleLine: true 13 | AllowShortCaseLabelsOnASingleLine: true 14 | AllowShortFunctionsOnASingleLine: Inline 15 | AllowShortIfStatementsOnASingleLine: false 16 | AllowShortLoopsOnASingleLine: false 17 | AlwaysBreakAfterDefinitionReturnType: None 18 | AlwaysBreakAfterReturnType: None 19 | AlwaysBreakBeforeMultilineStrings: false 20 | AlwaysBreakTemplateDeclarations: false 21 | BinPackArguments: true 22 | BinPackParameters: true 23 | BraceWrapping: 24 | AfterClass: true 25 | AfterControlStatement: true 26 | AfterEnum: true 27 | AfterFunction: true 28 | AfterNamespace: true 29 | AfterObjCDeclaration: false 30 | AfterStruct: true 31 | AfterUnion: true 32 | AfterExternBlock: true 33 | BeforeCatch: true 34 | BeforeElse: true 35 | IndentBraces: false 36 | SplitEmptyFunction: false 37 | SplitEmptyRecord: false 38 | SplitEmptyNamespace: false 39 | BeforeLambdaBody : true 40 | BreakBeforeBinaryOperators: NonAssignment 41 | BreakBeforeBraces: Allman 42 | BreakBeforeInheritanceComma: false 43 | BreakBeforeTernaryOperators: true 44 | BreakConstructorInitializersBeforeComma: true 45 | BreakConstructorInitializers: BeforeComma 46 | BreakAfterJavaFieldAnnotations: false 47 | BreakStringLiterals: true 48 | ColumnLimit: 160 49 | CommentPragmas: '^ IWYU pragma:' 50 | CompactNamespaces: false 51 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 52 | ConstructorInitializerIndentWidth: 2 53 | ContinuationIndentWidth: 2 54 | Cpp11BracedListStyle: false 55 | DerivePointerAlignment: false 56 | DisableFormat: false 57 | ExperimentalAutoDetectBinPacking: false 58 | FixNamespaceComments: true 59 | ForEachMacros: 60 | - foreach 61 | - Q_FOREACH 62 | - BOOST_FOREACH 63 | IncludeBlocks: Preserve 64 | IncludeCategories: 65 | 
- Regex: '^"(llvm|llvm-c|clang|clang-c)/' 66 | Priority: 2 67 | - Regex: '^(<|"(gtest|gmock|isl|json)/)' 68 | Priority: 3 69 | - Regex: '.*' 70 | Priority: 1 71 | IncludeIsMainRegex: '(Test)?$' 72 | IndentCaseLabels: false 73 | IndentPPDirectives: None 74 | IndentWidth: 2 75 | IndentWrappedFunctionNames: true 76 | JavaScriptQuotes: Leave 77 | JavaScriptWrapImports: true 78 | KeepEmptyLinesAtTheStartOfBlocks: false 79 | MacroBlockBegin: '' 80 | MacroBlockEnd: '' 81 | MaxEmptyLinesToKeep: 1 82 | NamespaceIndentation: None 83 | ObjCBinPackProtocolList: Auto 84 | ObjCBlockIndentWidth: 2 85 | ObjCSpaceAfterProperty: false 86 | ObjCSpaceBeforeProtocolList: true 87 | PenaltyBreakAssignment: 2 88 | PenaltyBreakBeforeFirstCallParameter: 19 89 | PenaltyBreakComment: 300 90 | PenaltyBreakFirstLessLess: 120 91 | PenaltyBreakString: 1000 92 | PenaltyExcessCharacter: 1000000 93 | PenaltyReturnTypeOnItsOwnLine: 60 94 | PointerAlignment: Right 95 | ReflowComments: true 96 | SortIncludes: false 97 | SortUsingDeclarations: true 98 | SpaceAfterCStyleCast: true 99 | SpaceAfterTemplateKeyword: false 100 | SpaceBeforeAssignmentOperators: true 101 | SpaceBeforeCtorInitializerColon: true 102 | SpaceBeforeInheritanceColon: true 103 | SpaceBeforeParens: ControlStatements 104 | SpaceBeforeRangeBasedForLoopColon: false 105 | SpaceInEmptyParentheses: false 106 | SpacesBeforeTrailingComments: 3 107 | SpacesInAngles: false 108 | SpacesInContainerLiterals: true 109 | SpacesInCStyleCastParentheses: false 110 | SpacesInParentheses: false 111 | SpacesInSquareBrackets: false 112 | Standard: Cpp11 113 | TabWidth: 8 114 | UseTab: Never 115 | ... 
116 | 117 | -------------------------------------------------------------------------------- /sample/copy.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | using namespace std; 6 | 7 | template 8 | bool copy(const sadl::layers::Layer &layer, sadl::layers::Layer &layerQ) { 9 | // from loadPrefix 10 | layerQ.name_=layer.name_; 11 | layerQ.inputs_id_=layer.inputs_id_; 12 | // WARNING: SHOULD BE SYNC BY HAND WITH NEW LAYERS 13 | // IF LOADINTERNAL IMPLEMENTED FOR A LAYER 14 | switch(layerQ.op()) { 15 | case sadl::layers::OperationType::Add: break; 16 | case sadl::layers::OperationType::BiasAdd: break; 17 | case sadl::layers::OperationType::Concat: break; 18 | case sadl::layers::OperationType::Const: layerQ.out_.resize(layer.out_.dims()); for(int k=0;k &>(layerQ).strides_=dynamic_cast &>(layer).strides_; 21 | dynamic_cast &>(layerQ).pads_=dynamic_cast &>(layer).pads_; 22 | break; 23 | case sadl::layers::OperationType::Conv2DTranspose: 24 | dynamic_cast &>(layerQ).strides_=dynamic_cast &>(layer).strides_; 25 | dynamic_cast &>(layerQ).pads_=dynamic_cast &>(layer).pads_; 26 | dynamic_cast &>(layerQ).out_pads_=dynamic_cast &>(layer).out_pads_; 27 | break; 28 | case sadl::layers::OperationType::Copy: break; 29 | case sadl::layers::OperationType::Identity: break; 30 | case sadl::layers::OperationType::LeakyRelu: break; 31 | case sadl::layers::OperationType::MatMul: break; 32 | case sadl::layers::OperationType::MaxPool: 33 | dynamic_cast &>(layerQ).kernel_=dynamic_cast &>(layer).kernel_; 34 | dynamic_cast &>(layerQ).strides_=dynamic_cast &>(layer).strides_; 35 | dynamic_cast &>(layerQ).pads_=dynamic_cast &>(layer).pads_; 36 | break; 37 | case sadl::layers::OperationType::Maximum: break; 38 | case sadl::layers::OperationType::Mul: break; 39 | case sadl::layers::OperationType::Placeholder: /* do not copy q */; break; 40 | case sadl::layers::OperationType::Relu: break; 41 | case 
sadl::layers::OperationType::Reshape: break; 42 | case sadl::layers::OperationType::OperationTypeCount: break; 43 | case sadl::layers::OperationType::Transpose: 44 | dynamic_cast &>(layerQ).perm_=dynamic_cast &>(layer).perm_; 45 | break; 46 | case sadl::layers::OperationType::Flatten: 47 | dynamic_cast &>(layerQ).axis_=dynamic_cast &>(layer).axis_; 48 | dynamic_cast &>(layerQ).dim_=dynamic_cast &>(layer).dim_; 49 | break; 50 | case sadl::layers::OperationType::Shape: break; 51 | case sadl::layers::OperationType::Expand: break; 52 | // no default to get warning 53 | } 54 | 55 | return true; 56 | } 57 | 58 | template 59 | bool copy(const sadl::Model &model, sadl::Model &modelQ) { 60 | modelQ.version_ = model.version_; 61 | modelQ.data_.clear(); 62 | modelQ.data_.resize(model.data_.size()); 63 | modelQ.ids_input = model.ids_input; 64 | modelQ.ids_output = model.ids_output; 65 | int nb_layers = modelQ.data_.size(); 66 | for (int k = 0; k < nb_layers; ++k) { 67 | modelQ.data_[k].layer = sadl::createLayer(model.data_[k].layer->id(), model.data_[k].layer->op()); 68 | modelQ.data_[k].inputs.clear(); 69 | copy(*model.data_[k].layer, *modelQ.data_[k].layer); 70 | } 71 | return true; 72 | } 73 | 74 | -------------------------------------------------------------------------------- /sadl/layer_flatten.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 
8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Flatten : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version v) override; 53 | int32_t axis_; 54 | Dimensions dim_; // dims after flatten 55 | DUMP_MODEL_EXT; 56 | }; 57 | 58 | template bool Flatten::apply(std::vector *> &in) 59 | { 60 | assert(in.size() == 1); 61 | assert(in[0]->size() == out_.size()); 62 | // resize done at init 63 | swapData(*in[0], out_); 64 | 65 | return true; 66 | } 67 | 68 | template bool Flatten::init(const std::vector *> &in) 69 | { 70 | if (in.size() != 1) 71 | return false; 72 | SADL_DBG(std::cout << " - " << in[0]->dims() << std::endl); 73 | int nb_dim = axis_ + 1; 74 | dim_.resize(nb_dim); 75 | for (int k = 0; k < axis_; ++k) 76 | dim_[k] = in[0]->dims()[k]; 77 | int s = 1; 78 | for (int k = axis_; k < in[0]->dims().size(); ++k) 79 | s *= in[0]->dims()[k]; 80 | dim_[axis_] = s; 81 | SADL_DBG(std::cout << " - new shape: " << dim_ << std::endl); 82 | out_.resize(dim_); 83 | initDone_ = true; 84 | return true; 85 | } 86 | 87 | template bool Flatten::loadInternal(std::istream &file, Version) 88 | { 89 | // load values 90 | int32_t x = 0; 91 | file.read((char *) &x, sizeof(x)); 92 | if (x <= 0 || x > Dimensions::MaxDim) 93 | { 94 | std::cerr << "[ERROR] invalid axis: " << x << std::endl; 95 | return false; 96 | } 97 | axis_ = x; 98 | SADL_DBG(std::cout << " - start axis: " << axis_ << std::endl); 99 | return true; 100 | } 101 | 102 | } // namespace layers 103 | } // namespace sadl 104 | -------------------------------------------------------------------------------- 
/sample/debug_model.cpp: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | 34 | #define DEBUG_VALUES 1 // show values 35 | #define DEBUG_MODEL 1 // show pb with model 36 | #define DEBUG_PRINT 1 // print model info 37 | #define DEBUG_SIMD 1 // tell about non simd version 38 | 39 | #include 40 | #include 41 | #include 42 | #include 43 | #include "helper.h" 44 | 45 | using namespace std; 46 | 47 | namespace 48 | { 49 | 50 | template void infer(const string &filename) 51 | { 52 | sadl::Model model; 53 | ifstream file(filename, ios::binary); 54 | cout << "[INFO] Model loading" << endl; 55 | if (!model.load(file)) 56 | { 57 | cerr << "[ERROR] Unable to read model " << filename << endl; 58 | exit(-1); 59 | } 60 | 61 | // sadl::Tensor::skip_border = true; 62 | vector> inputs = model.getInputsTemplate(); 63 | 64 | cout << "[INFO] Model initilization" << endl; 65 | 66 | if (!model.init(inputs)) 67 | { 68 | cerr << "[ERROR] issue during initialization" << endl; 69 | exit(-1); 70 | } 71 | 72 | if (!model.apply(inputs)) 73 | { 74 | cerr << "[ERROR] issue during inference" << endl; 75 | exit(-1); 76 | } 77 | 78 | if (sadl::Tensor::skip_border) 79 | cout << "[INFO] discard border size=" << model.result().border_skip << endl; 80 | 81 | const int N = model.getIdsOutput().size(); 82 | for (int i = 0; i < N; ++i) 83 | cout << "[INFO] output " << i << '\n' << model.result(i) << endl; 84 | } 85 | 86 | } // namespace 87 | 88 | int main(int argc, char **argv) 89 | { 90 | if (argc != 2) 91 | { 92 
| cout << "[ERROR] sample filename_model" << endl; 93 | return 1; 94 | } 95 | 96 | const string filename_model = argv[1]; 97 | 98 | sadl::layers::TensorInternalType::Type type_model = getModelType(filename_model); 99 | switch (type_model) 100 | { 101 | case sadl::layers::TensorInternalType::Float: infer(filename_model); break; 102 | case sadl::layers::TensorInternalType::Int32: infer(filename_model); break; 103 | case sadl::layers::TensorInternalType::Int16: infer(filename_model); break; 104 | default: cerr << "[ERROR] unsupported type" << endl; exit(-1); 105 | } 106 | 107 | return 0; 108 | } 109 | -------------------------------------------------------------------------------- /sadl/layer_placeholder.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 
20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | // like an identity 41 | template class Placeholder : public Layer 42 | { 43 | public: 44 | using Layer::Layer; 45 | using Layer::out_; 46 | using Layer::initDone_; 47 | 48 | virtual bool apply(std::vector *> &in) override; 49 | virtual bool init(const std::vector *> &in) override; 50 | virtual bool mutateInput() const override { return true; } 51 | int quantizer() const { return q_; } 52 | Dimensions dims() const { return dims_; } 53 | 54 | protected: 55 | virtual bool loadInternal(std::istream &file, Version v) override; 56 | int q_ = -1000; // will override user input 57 | Dimensions dims_; // can be use as a hint by user 58 | DUMP_MODEL_EXT; 59 | }; 60 | 61 | template bool Placeholder::apply(std::vector *> &in) 62 | { 63 | assert(in.size() == 1); 64 | swap(*in[0], out_); 65 | if (q_ >= 0) 66 | { // v2 67 | out_.quantizer = q_; 68 | } 69 | out_.border_skip = 0; 70 | return true; 71 | } 72 | 73 | template bool Placeholder::init(const std::vector *> &in) 74 | { 75 | if (in.size() != 1) 76 | return false; 77 | 
out_.resize(in[0]->dims()); 78 | dims_ = in[0]->dims(); 79 | initDone_ = true; 80 | return true; 81 | } 82 | 83 | template bool Placeholder::loadInternal(std::istream &file, Version v) 84 | { 85 | int32_t x = 0; 86 | file.read((char *) &x, sizeof(x)); 87 | if (x <= 0 || x > Dimensions::MaxDim) 88 | { 89 | std::cerr << "[ERROR] invalid nb of dimensions: " << x << std::endl; 90 | return false; 91 | } 92 | dims_.resize(x); 93 | file.read((char *) dims_.begin(), sizeof(int) * x); 94 | // HACK 95 | if (dims_.size() == 1) 96 | { 97 | x = dims_[0]; 98 | dims_.resize(2); 99 | dims_[0] = 1; 100 | dims_[1] = x; 101 | } 102 | // END HACK 103 | file.read((char *) &q_, sizeof(q_)); 104 | SADL_DBG(std::cout << " - dim: " << dims_ << std::endl); 105 | SADL_DBG(std::cout << " - q: " << q_ << std::endl); 106 | return true; 107 | } 108 | 109 | } // namespace layers 110 | } // namespace sadl 111 | -------------------------------------------------------------------------------- /sample/count_mac.cpp: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 
17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | 34 | #define DEBUG_COUNTERS 1 // print overflow etc. 
35 | #include 36 | #include 37 | #include 38 | #include 39 | #include "helper.h" 40 | 41 | using namespace std; 42 | 43 | namespace 44 | { 45 | template void infer(const string &filename) 46 | { 47 | sadl::Model model; 48 | ifstream file(filename, ios::binary); 49 | cout << "[INFO] Model loading" << endl; 50 | if (!model.load(file)) 51 | { 52 | cerr << "[ERROR] Unable to read model " << filename << endl; 53 | exit(-1); 54 | } 55 | 56 | // sadl::Tensor::skip_border = true; 57 | vector> inputs = model.getInputsTemplate(); 58 | // fill with 1 59 | for(auto &t: inputs) { 60 | T v=1<<(t.quantizer); 61 | for(auto &x: t) x=v; 62 | } 63 | cout << "[INFO] Model initilization" << endl; 64 | 65 | if (!model.init(inputs)) 66 | { 67 | cerr << "[ERROR] issue during initialization" << endl; 68 | exit(-1); 69 | } 70 | 71 | model.resetCounters(); 72 | 73 | if (!model.apply(inputs)) 74 | { 75 | cerr << "[ERROR] issue during inference" << endl; 76 | exit(-1); 77 | } 78 | 79 | cout << "\n[INFO] Complexity assessment" << endl; 80 | auto stat = model.printOverflow(true); 81 | cout << "[INFO] ---------------------------------" << endl; 82 | cout << "[INFO] " << stat.op << " OPs" << endl; 83 | cout << "[INFO] " << stat.mac << " MACs" << endl; 84 | cout << "[INFO] " << stat.overflow << " overflow" << endl; 85 | //cout << "[INFO] " << stat.mac_nz << " MACs non 0" << endl; 86 | cout << "[INFO] ---------------------------------" << endl; 87 | } 88 | 89 | } // namespace 90 | 91 | int main(int argc, char **argv) 92 | { 93 | if (argc != 2) 94 | { 95 | cout << "[ERROR] count_mac filename_model" << endl; 96 | return 1; 97 | } 98 | 99 | const string filename_model = argv[1]; 100 | 101 | sadl::layers::TensorInternalType::Type type_model = getModelType(filename_model); 102 | switch (type_model) 103 | { 104 | case sadl::layers::TensorInternalType::Float: infer(filename_model); break; 105 | case sadl::layers::TensorInternalType::Int32: infer(filename_model); break; 106 | case 
sadl::layers::TensorInternalType::Int16: infer(filename_model); break; 107 | default: cerr << "[ERROR] unsupported type" << endl; exit(-1); 108 | } 109 | 110 | return 0; 111 | } 112 | -------------------------------------------------------------------------------- /sample/sample.cpp: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include "helper.h" 39 | 40 | using namespace std; 41 | 42 | namespace 43 | { 44 | 45 | template void infer(const string &filename) 46 | { 47 | 48 | sadl::Model model; 49 | ifstream file(filename, ios::binary); 50 | cout << "[INFO] Model loading" << endl; 51 | if (!model.load(file)) 52 | { 53 | cerr << "[ERROR] Unable to read model " << filename << endl; 54 | exit(-1); 55 | } 56 | 57 | 58 | // sadl::Tensor::skip_border = true; 59 | vector> inputs = model.getInputsTemplate(); 60 | cout << "[INFO] Model initilization" << endl; 61 | 62 | if (!model.init(inputs)) 63 | { 64 | cerr << "[ERROR] issue during initialization" << endl; 65 | exit(-1); 66 | } 67 | 68 | // fill input with values from -1 to 1 69 | double step = (1. + 1.) / (inputs[0].size() - 1); 70 | double x0 = -1.; 71 | for (auto &t: inputs) 72 | for (auto &x: t) 73 | { 74 | x = x0; 75 | x0 += step; 76 | } 77 | chrono::steady_clock::time_point t1 = chrono::steady_clock::now(); 78 | if (!model.apply(inputs)) 79 | { 80 | cerr << "[ERROR] issue during inference" << endl; 81 | exit(-1); 82 | } 83 | chrono::steady_clock::time_point t2 = chrono::steady_clock::now(); 84 | chrono::duration dt = chrono::duration_cast>(t2 - t1); 85 | cout << "[INFO] " << dt.count() * 1000. 
<< " ms" << endl; 86 | 87 | const int N = model.getIdsOutput().size(); 88 | for (int i = 0; i < N; ++i) 89 | cout << "[INFO] output " << i << '\n' << model.result(i) << endl; 90 | } 91 | 92 | } // namespace 93 | 94 | int main(int argc, char **argv) 95 | { 96 | if (argc != 2) 97 | { 98 | cout << "[ERROR] sample filename_model" << endl; 99 | return 1; 100 | } 101 | 102 | const string filename_model = argv[1]; 103 | 104 | sadl::layers::TensorInternalType::Type type_model = getModelType(filename_model); 105 | switch (type_model) 106 | { 107 | case sadl::layers::TensorInternalType::Float: infer(filename_model); break; 108 | case sadl::layers::TensorInternalType::Int32: infer(filename_model); break; 109 | case sadl::layers::TensorInternalType::Int16: infer(filename_model); break; 110 | default: cerr << "[ERROR] unsupported type" << endl; exit(-1); 111 | } 112 | 113 | return 0; 114 | } 115 | -------------------------------------------------------------------------------- /sadl/layer_reshape.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 
17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Reshape : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version v) override; 53 | }; 54 | 55 | // assume data in in[0] and shape in in[1] 56 | template bool Reshape::apply(std::vector *> &in) 57 | { 58 | assert(in.size() == 2); 59 | // second layer is reshape prms 60 | // assert(in[1]->dims().size()==2); 61 | assert(in[0]->size() == out_.size()); 62 | // resize done at init 63 | swapData(*in[0], out_); 64 | 65 | return true; 66 | } 67 | 68 | template bool Reshape::init(const std::vector *> &in) 69 | { 70 | if (in.size() != 2) 71 | return false; 72 | SADL_DBG(std::cout << " - " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 73 | // second layer is always reshape prms: value as int inside the tensor 74 | if (in[1]->dims().size() != 1) 75 | return false; 76 | Dimensions dim; 77 | dim.resize((int)in[1]->size()); 78 | if (!std::is_same::value&&in[1]->quantizer!=0) { 79 | std::cerr << "[ERROR] quantizer on reshape dimensions data layer" << std::endl; 80 | return false; 81 | } 82 | for (int k = 0; k < in[1]->size(); ++k) 83 | { 84 | if ((*in[1]) (k) == -1) 85 | { // keep dim of org 86 | dim[k] = in[0]->dims()[k]; 87 | } 88 | else 89 | { 90 | dim[k] = (int) ((*in[1]) (k)); 91 | } 92 | } 93 | if (dim.nbElements() != in[0]->dims().nbElements()) 94 | { 95 | std::cerr << "[ERROR] reshape incompatible sizes " << dim << ' ' << in[0]->dims() << std::endl; 96 | std::cerr << "[ERROR] "; 97 | for (int k = 0; k < in[1]->dims()[0]; ++k) 98 | std::cerr << (*in[1]) (k) << ' '; 99 | std::cerr << std::endl; 100 | 101 | 
return false; 102 | } 103 | SADL_DBG(std::cout << " - new shape: " << dim << std::endl); 104 | out_.resize(dim); 105 | initDone_ = true; 106 | return true; 107 | } 108 | 109 | template bool Reshape::loadInternal(std::istream &, Version) 110 | { 111 | return true; 112 | } 113 | 114 | } // namespace layers 115 | } // namespace sadl 116 | -------------------------------------------------------------------------------- /sadl/layer_const.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include 35 | 36 | #include "layer.h" 37 | 38 | namespace sadl 39 | { 40 | namespace layers 41 | { 42 | template class Const : public Layer 43 | { 44 | public: 45 | using Layer::Layer; 46 | using Layer::out_; // to avoid this-> 47 | using Layer::initDone_; 48 | 49 | virtual bool apply(std::vector *> &in) override; 50 | virtual bool init(const std::vector *> &in) override; 51 | 52 | protected: 53 | virtual bool loadInternal(std::istream &file, Version v) override; 54 | template void readTensor(std::istream &file, Tensor &out); 55 | DUMP_MODEL_EXT; 56 | }; 57 | 58 | template bool Const::apply(std::vector *> &in) 59 | { 60 | assert(in.size() == 0); 61 | (void) in; 62 | // assert(ptr==ptr) 63 | return true; 64 | } 65 | 66 | template bool Const::init(const std::vector *> &in) 67 | { 68 | if (in.size() != 0) 69 | return false; 70 | initDone_ = true; 71 | return true; 72 | } 73 | 74 | template template void Const::readTensor(std::istream &file, Tensor &out) 75 | { 76 | if (std::is_same::value) 77 | file.read((char *) out.data(), sizeof(T) * out.size()); 78 | else 79 | { 80 | std::vector data(out.size()); 81 | file.read((char *) data.data(), sizeof(U) * data.size()); 82 | for (int k = 0; k < (int) data.size(); ++k) 83 | out[k] = static_cast(data[k]); 84 | } 85 | } 86 | 87 | template bool Const::loadInternal(std::istream &file, Version v) 88 | { 89 | // load values 90 | int32_t x = 0; 91 | 
file.read((char *) &x, sizeof(x)); 92 | if (x <= 0 || x > Dimensions::MaxDim) 93 | { 94 | std::cerr << "[ERROR] invalid nb of dimensions: " << x << std::endl; 95 | return false; 96 | } 97 | Dimensions d; 98 | d.resize(x); 99 | for (int k = 0; k < d.size(); ++k) 100 | { 101 | file.read((char *) &x, sizeof(x)); 102 | d[k] = x; 103 | } 104 | 105 | if (d.nbElements() >= Tensor::kMaxSize) 106 | { 107 | std::cerr << "[ERROR] tensor too large? " << d.nbElements() << std::endl; 108 | return false; 109 | } 110 | out_.resize(d); 111 | SADL_DBG(std::cout << " - tensor: " << out_.dims() << std::endl); 112 | 113 | file.read((char *) &x, sizeof(x)); 114 | 115 | // cannot check internal type because tensor also used by reshape etc. 116 | switch (x) 117 | { 118 | case TensorInternalType::Int32: 119 | // assert((std::is_same::value)); 120 | file.read((char *) &out_.quantizer, sizeof(out_.quantizer)); 121 | readTensor(file, out_); 122 | break; 123 | case TensorInternalType::Float: 124 | // assert((std::is_same::value)); 125 | readTensor(file, out_); 126 | break; 127 | case TensorInternalType::Int16: 128 | // assert((std::is_same::value)); 129 | file.read((char *) &out_.quantizer, sizeof(out_.quantizer)); 130 | readTensor(file, out_); 131 | break; 132 | default: std::cerr << "[ERROR] unknown internal type " << x << std::endl; return false; 133 | } 134 | 135 | SADL_DBG(std::cout << " - data: "; for (int k = 0; k < 4 && k < out_.size(); ++k) std::cout << out_[k] << ' '; std::cout << " ...\n"); 136 | SADL_DBG(std::cout << " - quantizer: " << out_.quantizer << std::endl); 137 | // SADL_DBG(std::cout< 2 | 3 | template 4 | bool sadl::layers::Conv2D::dump(std::ostream &file) { 5 | int32_t x = strides_.size(); 6 | file.write((const char *)&x, sizeof(int32_t)); 7 | file.write((const char *)strides_.begin(), strides_.size() * sizeof(int32_t)); 8 | x=pads_.size(); 9 | file.write((const char *)&x, sizeof(int32_t)); 10 | file.write((const char *)pads_.begin(), pads_.size() * sizeof(int32_t)); 11 
| file.write((const char *)&q_, sizeof(q_)); 12 | return true; 13 | } 14 | 15 | template bool sadl::layers::Conv2DTranspose::dump(std::ostream &file) 16 | { 17 | int32_t x = strides_.size(); 18 | file.write((const char *) &x, sizeof(int32_t)); 19 | file.write((const char *) strides_.begin(), strides_.size() * sizeof(int32_t)); 20 | x = pads_.size(); 21 | file.write((const char *) &x, sizeof(int32_t)); 22 | file.write((const char *) pads_.begin(), pads_.size() * sizeof(int32_t)); 23 | x = out_pads_.size(); 24 | file.write((const char *) &x, sizeof(int32_t)); 25 | file.write((const char *) out_pads_.begin(), out_pads_.size() * sizeof(int32_t)); 26 | file.write((const char *) &q_, sizeof(q_)); 27 | return true; 28 | } 29 | 30 | template 31 | bool sadl::layers::MatMul::dump(std::ostream &file) { 32 | file.write((const char *)&q_, sizeof(q_)); 33 | return true; 34 | } 35 | 36 | template 37 | bool sadl::layers::Mul::dump(std::ostream &file) { 38 | file.write((const char *)&q_, sizeof(q_)); 39 | return true; 40 | } 41 | 42 | template 43 | bool sadl::layers::Placeholder::dump(std::ostream &file) { 44 | int32_t x = dims_.size(); 45 | file.write((const char*)&x, sizeof(x)); 46 | file.write((const char*)dims_.begin(), sizeof(int)*x); 47 | file.write((const char *)&q_, sizeof(q_)); 48 | return true; 49 | } 50 | 51 | template 52 | bool sadl::layers::MaxPool::dump(std::ostream &file) { 53 | int32_t x = strides_.size(); 54 | file.write((const char *)&x, sizeof(int32_t)); 55 | file.write((const char *)strides_.begin(), strides_.size() * sizeof(int32_t)); 56 | x = kernel_.size(); 57 | file.write((const char *)&x, sizeof(int32_t)); 58 | file.write((const char *)kernel_.begin(), kernel_.size() * sizeof(int32_t)); 59 | x=pads_.size(); 60 | file.write((const char *)&x, sizeof(int32_t)); 61 | file.write((const char *)pads_.begin(), pads_.size() * sizeof(int32_t)); 62 | return true; 63 | } 64 | 65 | template 66 | bool sadl::layers::Flatten::dump(std::ostream &file) { 67 | int32_t x = 
axis_; 68 | file.write((const char *)&x, sizeof(int32_t)); 69 | return true; 70 | } 71 | 72 | 73 | 74 | template 75 | bool sadl::layers::Const::dump(std::ostream &file) { 76 | // load values 77 | int32_t x = out_.dims().size(); 78 | file.write((const char *)&x, sizeof(x)); 79 | file.write((const char *)out_.dims().begin(), x * sizeof(int)); 80 | if (std::is_same::value) { 81 | x = TensorInternalType::Int16; 82 | } else if (std::is_same::value) { 83 | x = TensorInternalType::Int32; 84 | } else if (std::is_same::value) { 85 | x = TensorInternalType::Float; 86 | } else { 87 | std::cerr << "[ERROR] to do" << std::endl; 88 | exit(-1); 89 | } 90 | file.write((const char *)&x, sizeof(x)); 91 | 92 | if (!std::is_same::value) file.write((const char *)&out_.quantizer, sizeof(out_.quantizer)); 93 | file.write((const char *)out_.data(), out_.size() * sizeof(T)); 94 | return true; 95 | } 96 | 97 | template 98 | bool sadl::layers::Layer::dump(std::ostream &file) { 99 | // std::cout<<"todo? "< 104 | bool sadl::Model::dump(std::ostream &file) { 105 | char magic[9] = "SADL0002"; 106 | file.write(magic, 8); 107 | int32_t x = 0; 108 | if (std::is_same::value) 109 | x = layers::TensorInternalType::Float; 110 | else if (std::is_same::value) 111 | x = layers::TensorInternalType::Int32; 112 | else if (std::is_same::value) 113 | x = layers::TensorInternalType::Int16; 114 | else { 115 | std::cerr << "[ERROR] to do Model::dump" << std::endl; 116 | exit(-1); 117 | } 118 | file.write((const char *)&x, sizeof(int32_t)); 119 | 120 | int32_t nb_layers = data_.size(); 121 | file.write((const char *)&nb_layers, sizeof(int32_t)); 122 | int32_t nb = ids_input.size(); 123 | file.write((const char *)&nb, sizeof(int32_t)); 124 | file.write((const char *)ids_input.data(), sizeof(int32_t) * nb); 125 | nb = ids_output.size(); 126 | file.write((const char *)&nb, sizeof(int32_t)); 127 | file.write((const char *)ids_output.data(), sizeof(int32_t) * nb); 128 | 129 | 130 | for (int k = 0; k < nb_layers; ++k) { 
131 | // save header 132 | int32_t x = data_[k].layer->id(); 133 | file.write((const char *)&x, sizeof(int32_t)); 134 | x = data_[k].layer->op(); 135 | file.write((const char *)&x, sizeof(int32_t)); 136 | // savePrefix 137 | int32_t L = data_[k].layer->name_.size(); 138 | file.write((const char *)&L, sizeof(int32_t)); 139 | file.write((const char *)data_[k].layer->name_.c_str(), data_[k].layer->name_.size()); 140 | L = data_[k].layer->inputs_id_.size(); 141 | file.write((const char *)&L, sizeof(int32_t)); 142 | file.write((const char *)data_[k].layer->inputs_id_.data(), data_[k].layer->inputs_id_.size() * sizeof(int32_t)); 143 | data_[k].layer->dump(file); 144 | } 145 | return true; 146 | } 147 | 148 | -------------------------------------------------------------------------------- /sadl/layer_expand.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 
20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Expand : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | 50 | protected: 51 | virtual bool loadInternal(std::istream &file, Version v) override; 52 | }; 53 | 54 | // assume data in in[0] and shape in in[1] 55 | template bool Expand::apply(std::vector *> &in) 56 | { 57 | assert(in.size() == 2); 58 | // second layer is reshape prms, already process in init 59 | out_.border_skip = in[0]->border_skip; // adapt output width to bias 60 | out_.quantizer=in[0]->quantizer; 61 | 62 | if (in[0]->size() == 1) 63 | { // broadcast 64 | const auto v = (*in[0])[0]; 65 | fill(out_.begin(), out_.end(), v); 66 | } 67 | else 68 | { 69 | // quick hack: to improve 70 | if (out_.dims().size() == 4) 71 | { 72 | const Dimensions d = out_.dims(); 73 | assert(d[0] == 1); 74 | assert(in[0]->dims()[3] == 1); 75 | for (int i = 0; i < 1 
/*d[0]*/; ++i) 76 | { 77 | for (int j = 0; j < d[1]; ++j) 78 | { 79 | for (int k = 0; k < d[2]; ++k) 80 | { 81 | const auto offset_in0 = (d[2] * (d[1] * i + j) + k); 82 | const auto offset_in1 = d[3] * offset_in0; 83 | const auto v = in[0]->data()[offset_in0]; 84 | for (int l = 0; l < out_.dims()[3]; ++l) 85 | { 86 | out_.data()[offset_in1 + l] = v; 87 | } 88 | } 89 | } 90 | } 91 | } 92 | else 93 | { 94 | SADL_DBG(std::cout << "TODO" << std::endl); 95 | exit(-1); 96 | } 97 | } 98 | return true; 99 | } 100 | 101 | template bool Expand::init(const std::vector *> &in) 102 | { 103 | if (in.size() != 2) 104 | return false; 105 | SADL_DBG(std::cout << " - " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 106 | // second layer is always reshape prms: value as int inside the tensor 107 | if (in[1]->dims().size() != 1) 108 | return false; 109 | Dimensions dim; 110 | dim.resize((int)in[1]->size()); 111 | if (!std::is_same::value&&in[1]->quantizer!=0) { 112 | std::cerr << "[ERROR] quantizer on reshape dimensions data layer" << std::endl; 113 | return false; 114 | } 115 | copy(in[1]->begin(), in[1]->end(), dim.begin()); 116 | // current restriction: broadcast only scalar to shape or expand last channel =1 of a tensor of dim 4 117 | bool ok = false; 118 | if (in[0]->size() == 1) 119 | { 120 | ok = true; 121 | } 122 | else 123 | { 124 | if (in[0]->dims().size() != dim.size() || dim.size() != 4) 125 | { 126 | ok = false; 127 | } 128 | else 129 | { 130 | ok = (in[0]->dims().back() == 1); 131 | for (int k = 0; k < dim.size() - 1; ++k) 132 | if (in[0]->dims()[k] != dim[k]) 133 | ok = false; 134 | } 135 | } 136 | if (!ok) 137 | { 138 | std::cerr << "[ERROR] value to expand not supported " << in[0]->dims() << " expand to " << dim << std::endl; 139 | return false; 140 | } 141 | out_.resize(dim); 142 | SADL_DBG(std::cout << " - new shape: " << dim << std::endl); 143 | initDone_ = true; 144 | return true; 145 | } 146 | 147 | template bool Expand::loadInternal(std::istream &, 
Version) 148 | { 149 | return true; 150 | } 151 | 152 | } // namespace layers 153 | } // namespace sadl 154 | -------------------------------------------------------------------------------- /sadl/layer_maximum.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Maximum : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version) override; 53 | }; 54 | 55 | template bool Maximum::apply(std::vector *> &in) 56 | { 57 | assert(in.size() == 2); 58 | if (in[0] == in[1]) 59 | { 60 | std::cerr << " input aliasing" << std::endl; 61 | return false; 62 | } 63 | const int shift = -(in[1]->quantizer - in[0]->quantizer); 64 | swap(*in[0], out_); 65 | 66 | /* 67 | Looking at the initialization, if the condition 68 | below is false, necessarily, `in[1]->dims().size()` 69 | is equal to 1. 
70 | */ 71 | if (in[0]->dims() == in[1]->dims()) 72 | { 73 | for (auto it0 = out_.begin(), it1 = in[1]->begin(); it0 != out_.end(); ++it0, ++it1) 74 | { 75 | T z = *it1; 76 | ComputationType::shift_left(z, shift); 77 | *it0 = std::max(*it0, z); 78 | } 79 | } 80 | else 81 | { 82 | const Tensor &B{ *in[1] }; 83 | if (B.size() == 1) 84 | { 85 | T value{ B[0] }; 86 | ComputationType::shift_left(value, shift); 87 | for (auto it0 = out_.begin(); it0 != out_.end(); ++it0) 88 | { 89 | *it0 = std::max(*it0, value); 90 | } 91 | } 92 | else if (in[0]->dims().size() == 2) 93 | { 94 | const int N{ in[0]->dims()[0] }; 95 | const int H{ in[0]->dims()[1] }; 96 | for (int n = 0; n < N; ++n) 97 | for (int i = 0; i < H; ++i) 98 | { 99 | T z = B[i]; 100 | ComputationType::shift_left(z, shift); 101 | out_(n, i) = std::max(out_(n, i), z); 102 | } 103 | } 104 | else if (in[0]->dims().size() == 3) 105 | { 106 | const int N{ in[0]->dims()[0] }; 107 | const int H{ in[0]->dims()[1] }; 108 | const int W{ in[0]->dims()[2] }; 109 | for (int n = 0; n < N; ++n) 110 | for (int i = 0; i < H; ++i) 111 | for (int j = 0; j < W; ++j) 112 | { 113 | T z = B[j]; 114 | ComputationType::shift_left(z, shift); 115 | out_(n, i, j) = std::max(out_(n, i, j), z); 116 | } 117 | } 118 | else if (in[0]->dims().size() == 4) 119 | { 120 | const int N{ in[0]->dims()[0] }; 121 | const int H{ in[0]->dims()[1] }; 122 | const int W{ in[0]->dims()[2] }; 123 | const int K{ in[0]->dims()[3] }; 124 | for (int n = 0; n < N; ++n) 125 | for (int i = 0; i < H; ++i) 126 | for (int j = 0; j < W; ++j) 127 | for (int k = 0; k < K; ++k) 128 | { 129 | T z = B[k]; 130 | ComputationType::shift_left(z, shift); 131 | out_(n, i, j, k) = std::max(out_(n, i, j, k), z); 132 | } 133 | } 134 | } 135 | return true; 136 | } 137 | 138 | template bool Maximum::init(const std::vector *> &in) 139 | { 140 | SADL_DBG(std::cout << " - " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 141 | if (in.size() != 2) 142 | { 143 | return false; 144 | } 
145 | 146 | /* 147 | Broadcasting is supported. This means that either 148 | the two input Tensors have the same shape or the 149 | second input Tensor is a singleton or the second 150 | input Tensor is a vector and the last dimension 151 | of the first input Tensor is equal to the size 152 | of the second input Tensor. 153 | */ 154 | if (in[1]->size() == 1) 155 | { // singleton 156 | // ok 157 | } 158 | else if (in[1]->dims().size() == 1 || (in[1]->dims().size() == 2 && in[1]->dims()[0] == 1)) 159 | { 160 | if (in[1]->size() != in[0]->dims().back()) 161 | { // broadcast last tdim 162 | return false; 163 | } 164 | } 165 | else 166 | { 167 | if (!(in[0]->dims() == in[1]->dims())) 168 | { // same sim 169 | return false; 170 | } 171 | } 172 | out_.resize(in[0]->dims()); 173 | initDone_ = true; 174 | return true; 175 | } 176 | 177 | template bool Maximum::loadInternal(std::istream &, Version) 178 | { 179 | return true; 180 | } 181 | 182 | } // namespace layers 183 | } // namespace sadl 184 | -------------------------------------------------------------------------------- /sadl/layer_concat.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 
14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Concat : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | 50 | protected: 51 | virtual bool loadInternal(std::istream &file, Version v) override; 52 | }; 53 | 54 | template bool Concat::apply(std::vector *> &in) 55 | { 56 | assert(in.size() >= 3); 57 | const int nb_in = (int) in.size() - 1; // without axis inputs 58 | assert((*in[nb_in]).size() == 1 && ((*in[nb_in])[0] == in[0]->dims().size() - 1 || (*in[nb_in])[0] == -1)); // currently concat on last axis 59 | int shift[16] = {}; 60 | int qmin = in.front()->quantizer; 61 | if (!std::is_same::value) 62 | { 63 | assert(nb_in < 16); 64 | for (int i = 0; i < nb_in; ++i) 65 | { 66 | if (in[i]->quantizer < qmin) 67 | qmin = in[i]->quantizer; 68 | } 69 | for (int i = 0; i < nb_in; ++i) 70 | { 71 | shift[i] = in[i]->quantizer - qmin; 72 | } 73 | } 74 | out_.quantizer = qmin; // adapt output width to last input 75 | out_.border_skip = in[0]->border_skip; 76 | for (int i = 1; i < nb_in; ++i) 77 | out_.border_skip = std::max(out_.border_skip, in[i]->border_skip); 78 | 79 | const Dimensions dim = in[0]->dims(); 80 | if (dim.size() == 2) 81 | { 82 | for (int i = 0; i < dim[0]; ++i) 83 | { 84 | int offset = 0; 85 | for (int n = 0; n < nb_in; ++n) 86 | { 87 | const Tensor &A = *(in[n]); 88 | for (int j = 0; j < A.dims()[1]; ++j, ++offset) 89 | { 90 | T z = A(i, j); 91 | ComputationType::quantize(z, shift[n]); 92 | out_(i, offset) = z; 93 | } 94 | } 95 | } 96 | } 97 | else if (dim.size() == 3) 98 | { 99 | for (int i = 0; i < dim[0]; ++i) 100 | { 101 | for (int j = 0; j < dim[1]; ++j) 102 | { 103 | int offset = 0; 104 | for (int n = 0; n < nb_in; ++n) 105 | { 106 | const Tensor &A = *(in[n]); 107 | for 
(int k = 0; k < A.dims()[2]; ++k, ++offset) 108 | { 109 | T z = A(i, j, k); 110 | ComputationType::quantize(z, shift[n]); 111 | out_(i, j, offset) = z; 112 | } 113 | } 114 | } 115 | } 116 | } 117 | else if (dim.size() == 4) 118 | { 119 | for (int i = 0; i < dim[0]; ++i) 120 | { 121 | for (int j = 0; j < dim[1]; ++j) 122 | { 123 | for (int k = 0; k < dim[2]; ++k) 124 | { 125 | int offset = 0; 126 | for (int n = 0; n < nb_in; ++n) 127 | { 128 | const Tensor &A = *(in[n]); 129 | for (int l = 0; l < A.dims()[3]; ++l, ++offset) 130 | { 131 | T z = A(i, j, k, l); 132 | ComputationType::quantize(z, shift[n]); 133 | out_(i, j, k, offset) = z; 134 | } 135 | } 136 | } 137 | } 138 | } 139 | } 140 | else 141 | { 142 | // TO DO 143 | return false; 144 | } 145 | return true; 146 | } 147 | 148 | template bool Concat::init(const std::vector *> &in) 149 | { 150 | /* 151 | The axis of the concatenation is the third tensor 152 | in `in`. 153 | */ 154 | if (in.size() < 3) 155 | return false; 156 | if (in[0]->dims().size() < 1) 157 | return false; 158 | const int last_axis = in[0]->dims().size() - 1; 159 | 160 | // Currently, the concatenation is along the last axis. 
161 | int axis_idx = (int) in.size() - 1; 162 | if (!((*in[axis_idx]).size() == 1 && ((*in[axis_idx])[0] == last_axis || (*in[axis_idx])[0] == -1))) 163 | return false; 164 | 165 | // should have same shape 166 | int sum_dim = 0; 167 | for (int i = 1; i < axis_idx; i++) 168 | { 169 | if (in[0]->dims().size() != in[i]->dims().size()) 170 | return false; 171 | sum_dim += in[i]->dims()[last_axis]; 172 | for (int k = 0; k < last_axis; ++k) 173 | if (in[0]->dims()[k] != in[i]->dims()[k]) 174 | return false; 175 | } 176 | Dimensions dim = in[0]->dims(); 177 | dim[last_axis] += sum_dim; 178 | out_.resize(dim); 179 | initDone_ = true; 180 | return true; 181 | } 182 | 183 | template bool Concat::loadInternal(std::istream &, Version) 184 | { 185 | return true; 186 | } 187 | 188 | } // namespace layers 189 | } // namespace sadl 190 | -------------------------------------------------------------------------------- /sadl/layer_transpose.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 
17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Transpose : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version v) override; 53 | std::array perm_; 54 | }; 55 | 56 | // assume data in in[0] and shape in in[1] 57 | template bool Transpose::apply(std::vector *> &in) 58 | { 59 | Dimensions d = out_.dims(); // {in[0]->dims()[0], in[0]->dims()[3], 60 | // in[0]->dims()[1], in[0]->dims()[2]}; 61 | const auto &A = *in[0]; 62 | Dimensions Ad = A.dims(); 63 | if (d.size() == 1) 64 | { 65 | swapData(*in[0], out_); 66 | } 67 | else if (d.size() == 4) 68 | { 69 | out_.quantizer = in[0]->quantizer; 70 | // border to skip?? 71 | std::array index; 72 | std::array index_mapped; 73 | for (int k = 0; k < 4; ++k) 74 | index_mapped[k] = &index[perm_[k]]; 75 | 76 | for (index[0] = 0; index[0] < Ad[0]; ++index[0]) 77 | for (index[1] = 0; index[1] < Ad[1]; ++index[1]) 78 | for (index[2] = 0; index[2] < Ad[2]; ++index[2]) 79 | for (index[3] = 0; index[3] < Ad[3]; ++index[3]) 80 | { 81 | auto offsetA = (Ad[3] * (Ad[2] * (Ad[1] * index[0] + index[1]) + index[2]) + index[3]); 82 | auto offsetOut = (d[3] * (d[2] * (d[1] * *index_mapped[0] + *index_mapped[1]) + *index_mapped[2]) + *index_mapped[3]); 83 | out_[offsetOut] = A[offsetA]; 84 | } 85 | } 86 | else if (d.size() == 6) 87 | { // very naive version 88 | out_.quantizer = in[0]->quantizer; 89 | // border to skip?? 
90 | std::array index; 91 | std::array index_mapped; 92 | for (int k = 0; k < 6; ++k) 93 | index_mapped[k] = &index[perm_[k]]; 94 | 95 | for (index[0] = 0; index[0] < Ad[0]; ++index[0]) 96 | for (index[1] = 0; index[1] < Ad[1]; ++index[1]) 97 | for (index[2] = 0; index[2] < Ad[2]; ++index[2]) 98 | for (index[3] = 0; index[3] < Ad[3]; ++index[3]) 99 | for (index[4] = 0; index[4] < Ad[4]; ++index[4]) 100 | for (index[5] = 0; index[5] < Ad[5]; ++index[5]) 101 | { 102 | auto offsetA = Ad[5] * (Ad[4] * (Ad[3] * (Ad[2] * (Ad[1] * index[0] + index[1]) + index[2]) + index[3]) + index[4]) + index[5]; 103 | auto offsetOut = 104 | d[5] * (d[4] * (d[3] * (d[2] * (d[1] * *index_mapped[0] + *index_mapped[1]) + *index_mapped[2]) + *index_mapped[3]) + *index_mapped[4]) 105 | + *index_mapped[5]; 106 | out_[offsetOut] = A[offsetA]; 107 | } 108 | } 109 | else 110 | { 111 | std::cerr << "\nTODO Transpose case: " << in[0]->dims() << " => " << out_.dims() << std::endl; 112 | exit(-1); 113 | } 114 | // } 115 | return true; 116 | } 117 | 118 | template bool Transpose::init(const std::vector *> &in) 119 | { 120 | if (in.size() != 2) 121 | return false; 122 | SADL_DBG(std::cout << " - " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 123 | // second layer is always reshape prms: value as int inside the tensor 124 | if (in[1]->dims().size() != 1) 125 | return false; 126 | if (!std::is_same::value && in[1]->quantizer != 0) 127 | { 128 | std::cerr << "[ERROR] quantizer on reshape dimensions data layer" << std::endl; 129 | return false; 130 | } 131 | Dimensions dim; 132 | dim.resize((int) in[1]->size()); 133 | for (int k = 0; k < in[1]->size(); ++k) 134 | { 135 | if ((*in[1])(k) == -1) 136 | { // keep dim of org 137 | dim[k] = in[0]->dims()[k]; 138 | perm_[k] = k; 139 | } 140 | else 141 | { 142 | dim[k] = in[0]->dims()[(int) ((*in[1])(k))]; 143 | perm_[k] = (int) ((*in[1])(k)); 144 | } 145 | } 146 | if (dim.nbElements() != in[0]->dims().nbElements()) 147 | { 148 | std::cerr << "[ERROR] 
transpose incompatible sizes shuffle=[" << dim << "] input shape: " << in[0]->dims() << std::endl; 149 | return false; 150 | } 151 | SADL_DBG(std::cout << " - new shape: " << dim << std::endl); 152 | out_.resize(dim); 153 | initDone_ = true; 154 | return true; 155 | } 156 | 157 | template bool Transpose::loadInternal(std::istream &, Version) 158 | { 159 | return true; 160 | } 161 | 162 | } // namespace layers 163 | } // namespace sadl 164 | -------------------------------------------------------------------------------- /sadl/layer.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include "tensor.h" 39 | 40 | namespace sadl 41 | { 42 | template class Model; // fwd 43 | namespace layers 44 | { 45 | // should be similar to python dumper: 46 | struct OperationType 47 | { 48 | enum Type 49 | { 50 | Copy = -1, // internal layer 51 | Const = 1, // important to have const first 52 | Placeholder = 2, 53 | Identity = 3, 54 | BiasAdd = 4, 55 | MaxPool = 5, 56 | MatMul = 6, 57 | Reshape = 7, 58 | Relu = 8, 59 | Conv2D = 9, 60 | Add = 10, 61 | Concat = 11, 62 | Mul = 12, 63 | Maximum = 13, 64 | LeakyRelu = 14, 65 | Transpose = 15, 66 | Flatten = 16, 67 | Shape = 17, 68 | Expand = 18, 69 | Conv2DTranspose = 19, 70 | OperationTypeCount = 20 71 | }; 72 | }; 73 | 74 | struct TensorInternalType 75 | { 76 | enum Type 77 | { 78 | Int32 = 0, 79 | Float = 1, 80 | Int16 = 2, 81 | }; 82 | }; 83 | 84 | template class Layer 85 | { 86 | public: 87 | using Id = int32_t; 88 | using value_type = T; 89 | 90 | Layer(Id iid, OperationType::Type iop) : id_(iid), op_(iop) {} 91 | virtual ~Layer() = default; 92 | 93 | virtual bool apply(std::vector *> &in) = 0; // note: we ca modify inputs for optiz purpose 94 | virtual bool init(const std::vector *> &in) = 0; // run it once 95 | bool load(std::istream &file, Version v); 96 | 97 | bool initDone() const; 98 | virtual bool mutateInput() const { return false; } 99 | Tensor & output(); 100 
| const std::string & name() const; 101 | Id id() const; 102 | const std::vector &inputsId() const; 103 | OperationType::Type op() const; 104 | void replaceInputId(Id old, Id newid); 105 | #if DEBUG_MODEL 106 | bool computed_ = false; 107 | #endif 108 | #if DEBUG_KEEP_OUTPUT 109 | Tensor outcopy_; 110 | #endif 111 | #if DEBUG_COUNTERS 112 | int64_t cpt_op = 0; 113 | int64_t cpt_mac_nz = 0; 114 | int64_t cpt_mac = 0; 115 | int64_t cpt_overflow = 0; 116 | #endif 117 | protected: 118 | bool loadPrefix(std::istream &file, Version v); 119 | virtual bool loadInternal(std::istream &file, Version v) = 0; 120 | Tensor out_; 121 | const Id id_; 122 | const OperationType::Type op_; 123 | std::string name_; 124 | std::vector inputs_id_; 125 | bool initDone_ = false; 126 | template friend class sadl::Model; 127 | DUMP_MODEL_EXT; 128 | }; 129 | 130 | template bool Layer::load(std::istream &file, Version v) 131 | { 132 | return loadPrefix(file, v) && loadInternal(file, v); 133 | } 134 | 135 | template bool Layer::initDone() const 136 | { 137 | return initDone_; 138 | } 139 | 140 | template sadl::Tensor &Layer::output() 141 | { 142 | return out_; 143 | } 144 | 145 | template const std::string &Layer::name() const 146 | { 147 | return name_; 148 | } 149 | 150 | template typename Layer::Id Layer::id() const 151 | { 152 | return id_; 153 | } 154 | 155 | template const std::vector::Id> &Layer::inputsId() const 156 | { 157 | return inputs_id_; 158 | } 159 | 160 | template OperationType::Type Layer::op() const 161 | { 162 | return op_; 163 | } 164 | 165 | template void Layer::replaceInputId(Layer::Id old, Layer::Id newid) 166 | { 167 | std::replace(inputs_id_.begin(), inputs_id_.end(), old, newid); 168 | } 169 | 170 | template bool Layer::loadPrefix(std::istream &file, Version v) 171 | { 172 | initDone_ = false; 173 | int32_t L = 0; 174 | file.read((char *) &L, sizeof(int32_t)); 175 | constexpr int maxLength = 2048; 176 | assert(L > 0 && L + 1 < maxLength); // max name size 177 | char 
s[maxLength]; 178 | file.read(s, L); 179 | s[L] = '\0'; 180 | name_ = s; 181 | SADL_DBG(std::cout << " - name: " << name_ << '\n'); 182 | 183 | file.read((char *) &L, sizeof(int32_t)); 184 | assert(L >= 0 && L < 8); 185 | inputs_id_.resize(L); 186 | SADL_DBG(std::cout << " - inputs: "); 187 | for (auto &x: inputs_id_) 188 | { 189 | file.read((char *) &x, sizeof(int32_t)); 190 | SADL_DBG(std::cout << x << ' '); 191 | } 192 | SADL_DBG(std::cout << '\n'); 193 | return static_cast(file); 194 | } 195 | 196 | } // namespace layers 197 | 198 | } // namespace sadl 199 | -------------------------------------------------------------------------------- /sample/naive_quantization.cpp: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 
20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include 41 | #include 42 | #include 43 | #include 44 | #include 45 | 46 | #define DUMP_MODEL_EXT virtual bool dump(std::ostream &file) 47 | #define DEBUG_KEEP_OUTPUT 1 48 | // trick to access inner data 49 | #define private public 50 | #define protected public 51 | #include 52 | #undef private 53 | #undef protected 54 | #define private private 55 | #define protected public 56 | 57 | #include "helper.h" 58 | #include "dumper.h" 59 | #include "copy.h" 60 | 61 | using namespace std; 62 | 63 | namespace 64 | { 65 | constexpr int kNoQValue=-1000; 66 | 67 | bool toQuantize(sadl::layers::OperationType::Type type) { 68 | // negative logic 69 | return type != sadl::layers::OperationType::Add && 70 | type != sadl::layers::OperationType::BiasAdd && 71 | type != sadl::layers::OperationType::Concat && 72 | type != sadl::layers::OperationType::Copy && 73 | type != sadl::layers::OperationType::Expand && 74 | type != sadl::layers::OperationType::Flatten && 75 | type != sadl::layers::OperationType::Identity && 76 | type != 
sadl::layers::OperationType::LeakyRelu && 77 | type != sadl::layers::OperationType::MaxPool && 78 | type != sadl::layers::OperationType::Relu && 79 | type != sadl::layers::OperationType::Reshape && 80 | type != sadl::layers::OperationType::Shape && 81 | type != sadl::layers::OperationType::Transpose; 82 | } 83 | 84 | 85 | template 86 | void quantizeTensor(const sadl::Tensor &B, sadl::Tensor &Bq) { 87 | double Q = (1 << Bq.quantizer); 88 | for (int k = 0; k < B.size(); ++k) { 89 | double z = round(B[k] * Q); 90 | if (z <= -numeric_limits::max()) { 91 | z = -numeric_limits::max() + 1; 92 | } 93 | if (z >= numeric_limits::max()) { 94 | z = numeric_limits::max() - 1; 95 | } 96 | Bq[k] = (T)z; 97 | } 98 | } 99 | 100 | template 101 | void quantize(sadl::layers::Layer &layerQ, const sadl::layers::Layer &layer_float, int quantizer) { 102 | 103 | // layers with internal quantizer 104 | if (layerQ.op() == sadl::layers::OperationType::Conv2D) dynamic_cast &>(layerQ).q_ = quantizer; 105 | else if (layerQ.op() == sadl::layers::OperationType::MatMul) dynamic_cast &>(layerQ).q_ = quantizer; 106 | else if (layerQ.op() == sadl::layers::OperationType::Conv2DTranspose) dynamic_cast &>(layerQ).q_ = quantizer; 107 | else if (layerQ.op() == sadl::layers::OperationType::Mul) dynamic_cast &>(layerQ).q_ = quantizer; 108 | else if (layerQ.op() == sadl::layers::OperationType::Placeholder) dynamic_cast &>(layerQ).q_ = quantizer; 109 | else if (layerQ.op() == sadl::layers::OperationType::Const) { 110 | layerQ.out_.quantizer = quantizer; 111 | quantizeTensor(layer_float.out_, layerQ.out_); 112 | } else { 113 | cerr << "[ERROR] unsupported layer " << sadl::layers::opName(layerQ.op()) << endl; 114 | exit(-1); 115 | } 116 | } 117 | 118 | 119 | template void quantize(const string &filename,const string &filename_out,const std::vector &quantizers) 120 | { 121 | // load float model 122 | sadl::layers::TensorInternalType::Type type_model = getModelType(filename); 123 | if 
(type_model!=sadl::layers::TensorInternalType::Float) { 124 | std::cerr<<"[ERROR] please input a float model"< model; 129 | ifstream file(filename, ios::binary); 130 | cout << "[INFO] Model loading" << endl; 131 | if (!model.load(file)) 132 | { 133 | cerr << "[ERROR] Unable to read model " << filename << endl; 134 | exit(-1); 135 | } 136 | 137 | // init quantize model 138 | sadl::Model modelQ; 139 | if (!copy(model, modelQ)) { 140 | cerr << "[ERROR] Unable to copy model " << endl; 141 | exit(-1); 142 | } 143 | 144 | // we need to set the placeholders layers (input layers) size because init is not done 145 | auto inputs=model.getInputsTemplate(); 146 | std::vector> inputsQ{inputs.size()}; 147 | for (int s = 0; s < (int)inputsQ.size(); ++s) { 148 | inputsQ[s].resize(inputs[s].dims()); 149 | } 150 | int cpt = 0; 151 | for (auto &id_input: modelQ.ids_input) { 152 | auto &L = modelQ.getLayer(id_input); 153 | if (L.layer->op() == sadl::layers::OperationType::Placeholder) { 154 | assert(cpt<(int)inputs.size()); 155 | std::vector *> v = {&inputsQ[cpt]}; 156 | ++cpt; 157 | L.layer->init(v); 158 | } 159 | } 160 | // quantize each layer + set quantizer 161 | for (int k=0;k<(int)modelQ.data_.size();++k) { // 162 | auto &layer=*modelQ.data_[k].layer; 163 | if (toQuantize(layer.op())) { 164 | if (layer.id()>=(int)quantizers.size()||quantizers[layer.id()]==kNoQValue) { 165 | std::cerr << "[ERROR] need a quantizer for layer " << layer.id() <<" op="<> id_q; 198 | while(std::cin>>id>>N) { 199 | id_q.push_back({id,N}); 200 | max_id=max(max_id,id); 201 | } 202 | std::vector quantizers; 203 | quantizers.resize(max_id+1); 204 | fill(quantizers.begin(),quantizers.end(),kNoQValue); 205 | for(auto x: id_q) { 206 | quantizers[x.first]=x.second; 207 | } 208 | quantize(filename_model,filename_model_out,quantizers); 209 | 210 | return 0; 211 | } 212 | -------------------------------------------------------------------------------- /sadl/layer_maxpool.h: 
-------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class MaxPool : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | 50 | protected: 51 | virtual bool loadInternal(std::istream &file, Version v) override; 52 | Dimensions kernel_; 53 | Dimensions strides_; 54 | Dimensions pads_; 55 | DUMP_MODEL_EXT; 56 | }; 57 | 58 | // assume data in in[0] 59 | // data [batch, in_height, in_width, in_channels] 60 | // kernel [1, kernel_height, kernel_width, 1] 61 | // stride [1, stride_height, stride_width, 1] 62 | template bool MaxPool::apply(std::vector *> &in) 63 | { 64 | assert(in.size() == 1); 65 | assert(in[0]->dims().size() == 4); 66 | 67 | const Tensor &A = *in[0]; 68 | const int N = out_.dims()[0]; 69 | const int H = out_.dims()[1]; 70 | const int W = out_.dims()[2]; 71 | const int D = out_.dims()[3]; 72 | const int offset_end = kernel_[1] / 2; 73 | const int offset_start = kernel_[1] - 1 - offset_end; 74 | const int step = strides_[1]; 75 | const int in_H = in[0]->dims()[1]; 76 | 77 | // currently adhoc start 78 | int start = 0; 79 | if (step == 1) 80 | { 81 | start = 0; 82 | } 83 | else if (step == 2) 84 | { 85 | // if (in_H % 2 == 0) 86 | // 
start = 1; 87 | // else 88 | start = 0; 89 | } 90 | else if (step == 3) 91 | { 92 | if (in_H % 2 == 0) 93 | start = 0; 94 | else 95 | start = 1; 96 | } 97 | else 98 | { 99 | std::cerr << "[ERROR] to do" << std::endl; 100 | assert(false); 101 | exit(-1); 102 | } 103 | 104 | out_.quantizer = in[0]->quantizer; // adapt output width to bias 105 | out_.border_skip = in[0]->border_skip; // to check 106 | 107 | for (int im_nb = 0; im_nb < N; ++im_nb) 108 | { 109 | // loop on out 110 | for (int im_i = 0; im_i < H; ++im_i) 111 | { 112 | for (int im_j = 0; im_j < W; ++im_j) 113 | { 114 | for (int im_d = 0; im_d < D; ++im_d) 115 | { 116 | T xx = -std::numeric_limits::max(); 117 | for (int filter_i = -offset_start; filter_i <= offset_end; ++filter_i) 118 | { 119 | for (int filter_j = -offset_start; filter_j <= offset_end; ++filter_j) 120 | { 121 | int ii = im_i * step + filter_i + start; 122 | int jj = im_j * step + filter_j + start; 123 | if (A.in(im_nb, ii, jj, im_d)) 124 | { 125 | T x = A(im_nb, ii, jj, im_d); 126 | if (xx < x) 127 | xx = x; 128 | } 129 | } 130 | } 131 | out_(im_nb, im_i, im_j, im_d) = xx; 132 | } 133 | } 134 | } 135 | } 136 | 137 | return true; 138 | } 139 | 140 | // data [batch, in_height, in_width, in_channels] 141 | // kernel [filter_height, filter_width, in_channels, out_channels] 142 | template bool MaxPool::init(const std::vector *> &in) 143 | { 144 | if (in.size() != 1) 145 | return false; 146 | SADL_DBG(std::cout << " - input maxpool: " << in[0]->dims() << std::endl); 147 | SADL_DBG(std::cout << " - stride: " << strides_ << std::endl); 148 | SADL_DBG(std::cout << " - kernel: " << kernel_ << std::endl); 149 | if (in[0]->dims().size() != 4) 150 | return false; 151 | 152 | // convervative check 153 | if (kernel_.size() != 4) 154 | return false; 155 | // no pooling on batch and depth 156 | if (kernel_[0] != 1 || kernel_[3] != 1) 157 | return false; 158 | 159 | // no stride on batch and depth 160 | if (strides_.size() != 4) 161 | return false; 162 | if 
(strides_[0] != 1 || strides_[3] != 1) 163 | return false; 164 | 165 | // square filter 166 | if (kernel_[1] != kernel_[2]) 167 | return false; 168 | // square stride 169 | if (strides_[1] != strides_[2]) 170 | return false; 171 | 172 | Dimensions dim; 173 | 174 | dim.resize(4); 175 | dim[0] = in[0]->dims()[0]; 176 | constexpr int dilatation = 1; 177 | dim[1] = (int) floor((in[0]->dims()[1] + pads_[0] + pads_[2] - ((kernel_[1] - 1) * dilatation + 1)) / (float) strides_[1] + 1); 178 | dim[2] = (int) floor((in[0]->dims()[2] + pads_[1] + pads_[3] - ((kernel_[2] - 1) * dilatation + 1)) / (float) strides_[2] + 1); 179 | dim[3] = in[0]->dims()[3]; 180 | 181 | out_.resize(dim); 182 | SADL_DBG(std::cout << " - output: " << out_.dims() << std::endl); 183 | 184 | initDone_ = true; 185 | return true; 186 | } 187 | 188 | template bool MaxPool::loadInternal(std::istream &file, Version v) 189 | { 190 | // load values 191 | int32_t x = 0; 192 | file.read((char *) &x, sizeof(x)); 193 | if (x <= 0 || x > Dimensions::MaxDim) 194 | { 195 | std::cerr << "[ERROR] invalid nb of dimensions strides: " << x << std::endl; 196 | return false; 197 | } 198 | strides_.resize(x); 199 | for (int k = 0; k < strides_.size(); ++k) 200 | { 201 | file.read((char *) &x, sizeof(x)); 202 | strides_[k] = x; 203 | } 204 | SADL_DBG(std::cout << " - strides: " << strides_ << std::endl); 205 | if (strides_.size() != 4) 206 | { 207 | std::cerr << "[ERROR] invalid strides: " << strides_.size() << std::endl; 208 | return false; 209 | } 210 | if (strides_[0] != 1) 211 | { 212 | std::cerr << "[ERROR] invalid strides[0]: " << strides_[0] << std::endl; 213 | return false; 214 | } 215 | if (strides_[3] != 1) 216 | { 217 | std::cerr << "[ERROR] invalid strides[3]: " << strides_[3] << std::endl; 218 | return false; 219 | } 220 | if (strides_[1] != strides_[2]) 221 | { 222 | std::cerr << "[ERROR] invalid stride H Vs: " << strides_ << std::endl; 223 | return false; 224 | } 225 | 226 | x = 0; 227 | file.read((char *) &x, 
sizeof(x)); 228 | if (x <= 0 || x > Dimensions::MaxDim) 229 | { 230 | std::cerr << "[ERROR] invalid nb of dimensions kernel: " << x << std::endl; 231 | return false; 232 | } 233 | kernel_.resize(x); 234 | for (int k = 0; k < kernel_.size(); ++k) 235 | { 236 | file.read((char *) &x, sizeof(x)); 237 | kernel_[k] = x; 238 | } 239 | SADL_DBG(std::cout << " - kernel: " << kernel_ << std::endl); 240 | if (kernel_.size() != 4) 241 | { 242 | std::cerr << "[ERROR] invalid kernel: " << kernel_.size() << std::endl; 243 | return false; 244 | } 245 | if (kernel_[0] != 1) 246 | { 247 | std::cerr << "[ERROR] invalid kernel[0]: " << kernel_[0] << std::endl; 248 | return false; 249 | } 250 | if (kernel_[3] != 1) 251 | { 252 | std::cerr << "[ERROR] invalid kernel[3]: " << kernel_[3] << std::endl; 253 | return false; 254 | } 255 | if (kernel_[1] != kernel_[2]) 256 | { 257 | std::cerr << "[ERROR] invalid kernel H V: " << kernel_ << std::endl; 258 | return false; 259 | } 260 | file.read((char *) &x, sizeof(x)); 261 | if (x <= 0 || x > Dimensions::MaxDim) 262 | { 263 | std::cerr << "[ERROR] invalid nb of dimensions: " << x << std::endl; 264 | return false; 265 | } 266 | pads_.resize(x); 267 | for (int k = 0; k < pads_.size(); ++k) 268 | { 269 | file.read((char *) &x, sizeof(x)); 270 | pads_[k] = x; 271 | } 272 | SADL_DBG(std::cout << " - pads: " << pads_ << std::endl); 273 | return true; 274 | } 275 | 276 | } // namespace layers 277 | } // namespace sadl 278 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SADL: Small Adhoc Deep-Learning Library 2 | 3 | A small library to perform inference in pure C++. 4 | Models in ONNX format can be converted to a simple format compatible with the library. 5 | ONNX export feature is supported by all majors framework (TF1.x, TF2.x, PyTorch etc.). 
6 | Inference can be done completely in C++ without any external dependencies. 7 | 8 | 9 | ## Conversion instruction 10 | Conversion is performed from an ONNX file. 11 | In the sample directory, 2 examples are given. 12 | ```python 13 | import numpy as np 14 | import os 15 | import tf2onnx 16 | import onnx 17 | 18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 20 | tensor_fmt = 'channels_last' 21 | import tensorflow as tf 22 | 23 | s = (16, 16,3) 24 | inputs = tf.keras.Input(shape=s, name='input0', dtype=tf.float32) 25 | 26 | nbf = 8 27 | x = tf.keras.layers.Conv2D(nbf, (3,3) , activation='linear',data_format=tensor_fmt, use_bias=True,bias_initializer="glorot_uniform",padding='same')(inputs) 28 | x = tf.keras.layers.MaxPool2D(2,data_format=tensor_fmt)(x) 29 | x = tf.keras.layers.Conv2D(nbf, (3,3) , activation='linear',data_format=tensor_fmt, use_bias=True,bias_initializer="glorot_uniform",padding='same')(x) 30 | 31 | x0 = tf.keras.layers.Conv2D(nbf, kernel_size=(3, 3) , activation='relu', use_bias=True,data_format=tensor_fmt, padding='same')(x) 32 | x0 = tf.keras.layers.Conv2D(nbf, kernel_size=(3, 3) , activation='relu', use_bias=True,data_format=tensor_fmt, padding='same')(x0) 33 | x0 = x0 + x 34 | x0 = tf.keras.layers.MaxPool2D(2,data_format=tensor_fmt)(x0) 35 | x0 = tf.keras.layers.Conv2D(2*nbf, kernel_size=(3, 3) , activation='relu', use_bias=True,data_format=tensor_fmt, padding='same')(x0) 36 | 37 | x1 = tf.keras.layers.Conv2D(2*nbf, kernel_size=(3, 3) , activation='relu', use_bias=True,data_format=tensor_fmt, padding='same')(x0) 38 | x1 = tf.keras.layers.Conv2D(2*nbf, kernel_size=(3, 3) , activation='relu', use_bias=True,data_format=tensor_fmt, padding='same')(x1) 39 | x1 = x1 + x0 40 | x1 = tf.keras.layers.MaxPool2D(2,data_format=tensor_fmt)(x1) 41 | x1 = tf.keras.layers.Conv2D(4*nbf, kernel_size=(3, 3) , activation='relu', use_bias=True,data_format=tensor_fmt, padding='same')(x1) 42 | 43 | x2 = 
tf.keras.layers.Reshape((1,4*nbf*16//8*16//8))(x1) 44 | y = tf.keras.layers.Dense(2)(x2) 45 | model = tf.keras.Model(inputs=[inputs],outputs=y,name="cat_classifier") 46 | 47 | 48 | X = np.linspace(-1.,1,np.prod(s)).reshape((1,)+s) 49 | Y = model(X) 50 | 51 | model_onnx , _ = tf2onnx.convert.from_keras(model,[tf.TensorSpec(shape=(1,)+s,name="input0")],opset=13) 52 | onnx.save(model_onnx, "./tf2.onnx") 53 | print("Output\n",Y) 54 | print("Model in tf2.onnx") 55 | ``` 56 | 57 | Example of conversion 58 | ```shell 59 | # output a tf2.onnx 60 | python3 sample/tf2.py 61 | 62 | # convert the onnx to sadl 63 | python3 converter/main.py --input_onnx tf2.onnx --output tf2.sadl 64 | ``` 65 | 66 | In PyTorch the same example is given: 67 | ```shell 68 | # output a pytorch.onnx 69 | python3 sample/pytorch.py 70 | 71 | # convert the onnx to sadl 72 | python3 converter/main.py --input_onnx pytorch.onnx --output pytorch.sadl 73 | ``` 74 | Please note that models from frameworks using the NCHW data layout are converted to an NHWC data layout graph (one just needs to adapt the inputs and/or outputs). 75 | 76 | ## Build instruction 77 | The library is header-only and does not need to be built. 78 | However, when it is integrated into a piece of software, the build options drive several aspects: 79 | - several macros control the debug level and information level. The macros can be found in sadl/options.h 80 | - the SIMD options passed to the compiler control the level of SIMD activated in the code (the library does not select the SIMD level dynamically at run time) 81 | 82 | Several examples are given in the sample directory CMakeLists.txt: 83 | - count_mac: just assess the model complexity 84 | - debug_model: print information on potential issues in the model (wrong values, no SIMD layers etc.)
85 | - sample: inference of the model with 3 levels of optimization (generic, simd256, simd512) 86 | 87 | Example of program: 88 | ```c++ 89 | #include 90 | 91 | int main() { 92 | sadl::Model model; 93 | 94 | ifstream file("model.sadl", ios::binary); 95 | if (!model.load(file)) { 96 | cerr << "[ERROR] Unable to read model " << endl; 97 | exit(-1); 98 | } 99 | 100 | vector> inputs=model.getInputsTemplate(); 101 | if (!model.init(inputs)) { 102 | cerr << "[ERROR] issue during initialization" << endl; 103 | exit(-1); 104 | } 105 | 106 | if (!model.apply(inputs)) { 107 | cerr << "[ERROR] issue during inference" << endl; 108 | exit(-1); 109 | } 110 | const int N=model.nbOutput(); 111 | for(int i=0;i class BiasAdd : public Add 41 | { 42 | public: 43 | using Add::Add; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | }; 50 | 51 | template bool BiasAdd::apply(std::vector *> &in) 52 | { 53 | assert(in.size() == 2); 54 | if (in[0] == in[1]) 55 | { 56 | std::cerr << " input aliasing" << std::endl; 57 | return false; 58 | } 59 | 60 | const int shift = in[0]->quantizer - in[1]->quantizer; 61 | swap(*in[0], out_); 62 | // adapt output width to second input (which are the bias) in order to be able to rescale as desired the input 63 | out_.quantizer = in[1]->quantizer; 64 | 65 | if (shift < 0) 66 | { 67 | if (in[0]->dims() == in[1]->dims()) 68 | { 69 | for (auto it0 = out_.begin(), it1 = in[1]->begin(); it0 != out_.end(); ++it0, ++it1) 70 | { 71 | typename ComputationType::type z = *it0; 72 | ComputationType::shift_left(z, -shift); 73 | // ComputationType::quantize(z, shift); 74 | z += *it1; 75 | COUNTERS(z); 76 | SATURATE(z); 77 | *it0 = z; 78 | } 79 | } 80 | else 81 | { 82 | if (in[1]->size() == 1) 83 | { // ie in[0]->dims().size() == 1? 
happen if in[1] is a Const 84 | const Tensor &B = *in[1]; 85 | const T value = B[0]; 86 | for (auto &x: out_) 87 | { 88 | typename ComputationType::type z = x; 89 | ComputationType::shift_left(z, -shift); 90 | z += value; 91 | COUNTERS(z); 92 | SATURATE(z); 93 | x = z; 94 | } 95 | } 96 | else if (in[0]->dims().size() == 2) 97 | { 98 | const Tensor &B = *in[1]; 99 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 100 | const int N = in[0]->dims()[0]; 101 | const int H = in[0]->dims()[1]; 102 | for (int n = 0; n < N; ++n) 103 | for (int i = 0; i < H; ++i) 104 | { 105 | typename ComputationType::type z = out_(n, i); 106 | ComputationType::shift_left(z, -shift); 107 | z += B[i]; 108 | COUNTERS(z); 109 | SATURATE(z); 110 | out_(n, i) = z; 111 | } 112 | } 113 | else if (in[0]->dims().size() == 3) 114 | { 115 | const Tensor &B = *in[1]; 116 | const int N = in[0]->dims()[0]; 117 | const int H = in[0]->dims()[1]; 118 | const int W = in[0]->dims()[2]; 119 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 120 | for (int n = 0; n < N; ++n) 121 | for (int i = 0; i < H; ++i) 122 | for (int j = 0; j < W; ++j) 123 | { 124 | typename ComputationType::type z = out_(n, i, j); 125 | ComputationType::shift_left(z, -shift); 126 | z += B[j]; 127 | COUNTERS(z); 128 | SATURATE(z); 129 | out_(n, i, j) = z; 130 | } 131 | } 132 | else if (in[0]->dims().size() == 4) 133 | { 134 | const Tensor &B = *in[1]; 135 | const int N = in[0]->dims()[0]; 136 | const int H = in[0]->dims()[1]; 137 | const int W = in[0]->dims()[2]; 138 | const int K = in[0]->dims()[3]; 139 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 140 | for (int n = 0; n < N; ++n) 141 | for (int i = 0; i < H; ++i) 142 | for (int j = 0; j < W; ++j) 143 | for (int k = 0; k < K; ++k) 144 | { 145 | typename ComputationType::type z = out_(n, i, j, k); 146 | ComputationType::shift_left(z, -shift); 147 | z += B[k]; 148 | COUNTERS(z); 149 | SATURATE(z); 150 | 
out_(n, i, j, k) = z; 151 | } 152 | } 153 | } 154 | } 155 | else 156 | { 157 | if (in[0]->dims() == in[1]->dims()) 158 | { 159 | for (auto it0 = out_.begin(), it1 = in[1]->begin(); it0 != out_.end(); ++it0, ++it1) 160 | { 161 | typename ComputationType::type z = *it0; 162 | ComputationType::quantize(z, shift); 163 | z += *it1; 164 | COUNTERS(z); 165 | SATURATE(z); 166 | *it0 = z; 167 | } 168 | } 169 | else 170 | { 171 | if (in[1]->size() == 1) 172 | { // for constant 173 | const Tensor &B = *in[1]; 174 | const T value = B[0]; 175 | for (auto &x: out_) 176 | { 177 | typename ComputationType::type z = x; 178 | ComputationType::quantize(z, shift); 179 | z += value; 180 | COUNTERS(z); 181 | SATURATE(z); 182 | x = z; 183 | } 184 | } 185 | else if (in[0]->dims().size() == 2) 186 | { 187 | const Tensor &B = *in[1]; 188 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 189 | const int N = in[0]->dims()[0]; 190 | const int H = in[0]->dims()[1]; 191 | for (int n = 0; n < N; ++n) 192 | for (int i = 0; i < H; ++i) 193 | { 194 | typename ComputationType::type z = out_(n, i); 195 | ComputationType::quantize(z, shift); 196 | z += B[i]; 197 | COUNTERS(z); 198 | SATURATE(z); 199 | out_(n, i) = z; 200 | } 201 | } 202 | else if (in[0]->dims().size() == 3) 203 | { 204 | const Tensor &B = *in[1]; 205 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 206 | const int N = in[0]->dims()[0]; 207 | const int H = in[0]->dims()[1]; 208 | const int W = in[0]->dims()[2]; 209 | 210 | for (int n = 0; n < N; ++n) 211 | for (int i = 0; i < H; ++i) 212 | for (int j = 0; j < W; ++j) 213 | { 214 | typename ComputationType::type z = out_(n, i, j); 215 | ComputationType::quantize(z, shift); 216 | z += B[j]; 217 | COUNTERS(z); 218 | SATURATE(z); 219 | out_(n, i, j) = z; 220 | } 221 | } 222 | else if (in[0]->dims().size() == 4) 223 | { 224 | const Tensor &B = *in[1]; 225 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 226 
| const int N = in[0]->dims()[0]; 227 | const int H = in[0]->dims()[1]; 228 | const int W = in[0]->dims()[2]; 229 | const int K = in[0]->dims()[3]; 230 | 231 | for (int n = 0; n < N; ++n) 232 | for (int i = 0; i < H; ++i) 233 | for (int j = 0; j < W; ++j) 234 | for (int k = 0; k < K; ++k) 235 | { 236 | typename ComputationType::type z = out_(n, i, j, k); 237 | ComputationType::quantize(z, shift); 238 | z += B[k]; 239 | COUNTERS(z); 240 | SATURATE(z); 241 | out_(n, i, j, k) = z; 242 | } 243 | } 244 | } 245 | } 246 | return true; 247 | } 248 | 249 | template bool BiasAdd::init(const std::vector *> &in) 250 | { 251 | // convervative check 252 | if (in.size() != 2) 253 | return false; 254 | if (in[1]->dims().size() != 1) 255 | return false; 256 | if (in[0]->dims()[in[0]->dims().size() - 1] != in[1]->dims()[0]) 257 | return false; 258 | 259 | out_.resize(in[0]->dims()); 260 | initDone_ = true; 261 | return true; 262 | } 263 | 264 | } // namespace layers 265 | } // namespace sadl 266 | -------------------------------------------------------------------------------- /sadl/layer_add.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 
14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Add : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | 48 | virtual bool apply(std::vector *> &in) override; 49 | virtual bool init(const std::vector *> &in) override; 50 | virtual bool mutateInput() const override { return true; } 51 | 52 | protected: 53 | virtual bool loadInternal(std::istream &file, Version v) override; 54 | }; 55 | 56 | 57 | // TODO : check all dims, check loop for optiz 58 | template bool Add::apply(std::vector *> &in) 59 | { 60 | assert(in.size() == 2); 61 | if (in[0] == in[1]) 62 | { 63 | std::cerr << " input aliasing" << std::endl; 64 | return false; 65 | } 66 | 67 | 68 | int shift0,shift1,qfinal; 69 | if (in[0]->quantizer < in[1]->quantizer) { 70 | shift0=0; 71 | shift1=in[1]->quantizer-in[0]->quantizer; 72 | qfinal=in[0]->quantizer; 73 | } else { 74 | shift0=in[0]->quantizer-in[1]->quantizer; 75 | shift1=0; 76 | qfinal=in[1]->quantizer; 77 | } 78 | swap(*in[0], out_); 79 | // adapt output width to second input (which are the bias) in order to be able to rescale as desired the input 80 | out_.quantizer =qfinal; 81 | 82 | if (shift0>0) 83 | { 84 | const int shift=shift0; 85 | if (in[0]->dims() == in[1]->dims()) 86 | { 87 | for (auto it0 = out_.begin(), it1 = in[1]->begin(); it0 != out_.end(); ++it0, ++it1) 88 | { 89 | typename ComputationType::type z = *it0; 90 | ComputationType::quantize(z, shift); 91 | z += *it1; 92 | COUNTERS(z); 93 | SATURATE(z); 94 | *it0 = z; 95 | } 96 | } 97 | else 98 | { 99 | if (in[1]->size() == 1) 100 | { // ie in[0]->dims().size() == 1? 
happen if in[1] is a Const 101 | const Tensor &B = *in[1]; 102 | const T value = B[0]; 103 | for (auto &x: out_) 104 | { 105 | typename ComputationType::type z = x; 106 | ComputationType::quantize(z, shift); 107 | z += value; 108 | COUNTERS(z); 109 | SATURATE(z); 110 | x = z; 111 | } 112 | } 113 | else if (in[0]->dims().size() == 2) 114 | { 115 | const Tensor &B = *in[1]; 116 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 117 | const int N = in[0]->dims()[0]; 118 | const int H = in[0]->dims()[1]; 119 | for (int n = 0; n < N; ++n) 120 | for (int i = 0; i < H; ++i) 121 | { 122 | typename ComputationType::type z = out_(n, i); 123 | ComputationType::quantize(z, shift); 124 | z += B[i]; 125 | COUNTERS(z); 126 | SATURATE(z); 127 | out_(n, i) = z; 128 | } 129 | } 130 | else if (in[0]->dims().size() == 3) 131 | { 132 | const Tensor &B = *in[1]; 133 | const int N = in[0]->dims()[0]; 134 | const int H = in[0]->dims()[1]; 135 | const int W = in[0]->dims()[2]; 136 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 137 | for (int n = 0; n < N; ++n) 138 | for (int i = 0; i < H; ++i) 139 | for (int j = 0; j < W; ++j) 140 | { 141 | typename ComputationType::type z = out_(n, i, j); 142 | ComputationType::quantize(z, shift); 143 | z += B[j]; 144 | COUNTERS(z); 145 | SATURATE(z); 146 | out_(n, i, j) = z; 147 | } 148 | } 149 | else if (in[0]->dims().size() == 4) 150 | { 151 | const Tensor &B = *in[1]; 152 | const int N = in[0]->dims()[0]; 153 | const int H = in[0]->dims()[1]; 154 | const int W = in[0]->dims()[2]; 155 | const int K = in[0]->dims()[3]; 156 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 157 | for (int n = 0; n < N; ++n) 158 | for (int i = 0; i < H; ++i) 159 | for (int j = 0; j < W; ++j) 160 | for (int k = 0; k < K; ++k) 161 | { 162 | typename ComputationType::type z = out_(n, i, j, k); 163 | ComputationType::quantize(z, shift); 164 | z += B[k]; 165 | COUNTERS(z); 166 | SATURATE(z); 
167 | out_(n, i, j, k) = z; 168 | } 169 | } 170 | } 171 | } 172 | else // shift1 173 | { 174 | const int shift=shift1; 175 | if (in[0]->dims() == in[1]->dims()) 176 | { 177 | for (auto it0 = out_.begin(), it1 = in[1]->begin(); it0 != out_.end(); ++it0, ++it1) 178 | { 179 | typename ComputationType::type z = *it1; 180 | ComputationType::quantize(z, shift); 181 | z += *it0; 182 | COUNTERS(z); 183 | SATURATE(z); 184 | *it0 = z; 185 | } 186 | } 187 | else 188 | { 189 | if (in[1]->size() == 1) 190 | { // for constant 191 | const Tensor &B = *in[1]; 192 | T valt=B[0]; 193 | ComputationType::quantize(valt,shift); 194 | const T value = valt; 195 | for (auto &x: out_) 196 | { 197 | typename ComputationType::type z = x; 198 | z += value; 199 | COUNTERS(z); 200 | SATURATE(z); 201 | x = z; 202 | } 203 | } 204 | else if (in[0]->dims().size() == 2) 205 | { 206 | const Tensor &B = *in[1]; 207 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 208 | const int N = in[0]->dims()[0]; 209 | const int H = in[0]->dims()[1]; 210 | for (int n = 0; n < N; ++n) 211 | for (int i = 0; i < H; ++i) 212 | { 213 | typename ComputationType::type z = B[i]; 214 | ComputationType::quantize(z,shift); 215 | z +=out_(n, i); 216 | COUNTERS(z); 217 | SATURATE(z); 218 | out_(n, i) = z; 219 | } 220 | } 221 | else if (in[0]->dims().size() == 3) 222 | { 223 | const Tensor &B = *in[1]; 224 | assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 225 | const int N = in[0]->dims()[0]; 226 | const int H = in[0]->dims()[1]; 227 | const int W = in[0]->dims()[2]; 228 | 229 | for (int n = 0; n < N; ++n) 230 | for (int i = 0; i < H; ++i) 231 | for (int j = 0; j < W; ++j) 232 | { 233 | typename ComputationType::type z = B[j]; 234 | ComputationType::quantize(z,shift); 235 | z +=out_(n, i,j); 236 | COUNTERS(z); 237 | SATURATE(z); 238 | out_(n, i, j) = z; 239 | } 240 | } 241 | else if (in[0]->dims().size() == 4) 242 | { 243 | const Tensor &B = *in[1]; 244 | 
assert(B.dims().size() == 1 || (B.dims().size() == 2 && B.dims()[0] == 1)); 245 | const int N = in[0]->dims()[0]; 246 | const int H = in[0]->dims()[1]; 247 | const int W = in[0]->dims()[2]; 248 | const int K = in[0]->dims()[3]; 249 | 250 | for (int n = 0; n < N; ++n) 251 | for (int i = 0; i < H; ++i) 252 | for (int j = 0; j < W; ++j) 253 | for (int k = 0; k < K; ++k) 254 | { 255 | typename ComputationType::type z = B[k]; 256 | ComputationType::quantize(z,shift); 257 | z +=out_(n, i,j,k); 258 | COUNTERS(z); 259 | SATURATE(z); 260 | out_(n, i, j, k) = z; 261 | } 262 | } 263 | } 264 | } 265 | return true; 266 | } 267 | 268 | 269 | // data in in[0] 270 | // bias in in[1] 271 | // assume data shape [N,W,H,D] 272 | // assume bias shape [D] 273 | template bool Add::init(const std::vector *> &in) 274 | { 275 | if (in.size() != 2) 276 | return false; 277 | SADL_DBG(std::cout << " - " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 278 | 279 | // either broadcast from a tensor of size [n] (use when input is Const) or [1,n] 280 | // of add if same dimensions 281 | if (in[1]->dims().size() == 1) 282 | { 283 | if (in[1]->dims()[0] != 1 && in[1]->dims()[0] != in[0]->dims().back()) 284 | return false; 285 | } 286 | else if (in[1]->dims().size() == 2 && in[1]->dims()[0] == 1) 287 | { 288 | if (in[1]->dims()[1] != 1 && in[1]->dims()[1] != in[0]->dims().back()) 289 | return false; 290 | } 291 | else 292 | { 293 | if (!(in[0]->dims() == in[1]->dims())) 294 | return false; 295 | } 296 | out_.resize(in[0]->dims()); 297 | initDone_ = true; 298 | return true; 299 | } 300 | 301 | template bool Add::loadInternal(std::istream &, Version) 302 | { 303 | return true; 304 | } 305 | 306 | } // namespace layers 307 | } // namespace sadl 308 | -------------------------------------------------------------------------------- /sadl/layer_mul.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under 
the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | 36 | namespace sadl 37 | { 38 | namespace layers 39 | { 40 | template class Mul : public Layer 41 | { 42 | public: 43 | using Layer::Layer; 44 | using Layer::out_; // to avoid this-> 45 | using Layer::initDone_; 46 | 47 | virtual bool apply(std::vector *> &in) override; 48 | virtual bool init(const std::vector *> &in) override; 49 | virtual bool mutateInput() const override { return true; } 50 | 51 | protected: 52 | virtual bool loadInternal(std::istream &file, Version v) override; 53 | int q_ = 0; 54 | template bool apply_same_dim(std::vector *> &in); 55 | template bool apply_singleton(std::vector *> &in); 56 | template bool apply_dim2(std::vector *> &in); 57 | template bool apply_dim3(std::vector *> &in); 58 | template bool apply_dim4(std::vector *> &in); 59 | #if __AVX2__ 60 | bool apply_singleton_simd8(std::vector *> &in); 61 | bool apply_singleton_simd16(std::vector *> &in); 62 | #endif 63 | DUMP_MODEL_EXT; 64 | }; 65 | 66 | template bool Mul::apply(std::vector *> &in) 67 | { 68 | assert(in.size() == 2); 69 | if (in[0] == in[1]) 70 | { 71 | std::cerr << " input aliasing" << std::endl; 72 | return false; 73 | } 74 | swap(*in[0], out_); 75 | out_.border_skip = std::max(out_.border_skip, in[1]->border_skip); 76 | 77 | out_.quantizer -= q_; // q0-q 78 | assert(out_.quantizer >= 0); 79 | assert(in[1]->quantizer + q_ >= 0); 80 | 81 | const int last = in[0]->dims().back(); 82 | if (last % 16 == 0) 83 | { 84 | constexpr int NN = 16; 85 | if (in[0]->dims() == in[1]->dims()) 86 | { // product wise 87 | return apply_same_dim(in); 88 | } 89 | else if (in[1]->size() == 1) 90 | { // broadcast single element 91 | #if __AVX2__ 92 | return apply_singleton_simd16(in); 93 | #endif 94 | return apply_singleton(in); 95 | } 96 | else if (in[0]->dims().size() == 2) 97 | { 98 | return apply_dim2(in); 99 | } 100 | else if (in[0]->dims().size() == 3) 101 | { 102 | return apply_dim3(in); 103 | } 104 | else if (in[0]->dims().size() == 
4) 105 | { 106 | return apply_dim4(in); 107 | } 108 | } 109 | else if (last % 8 == 0) 110 | { 111 | constexpr int NN = 8; 112 | if (in[0]->dims() == in[1]->dims()) 113 | { // product wise 114 | return apply_same_dim(in); 115 | } 116 | else if (in[1]->size() == 1) 117 | { // broadcast single element 118 | #if __AVX2__ 119 | return apply_singleton_simd8(in); 120 | #endif 121 | return apply_singleton(in); 122 | } 123 | else if (in[0]->dims().size() == 2) 124 | { 125 | return apply_dim2(in); 126 | } 127 | else if (in[0]->dims().size() == 3) 128 | { 129 | return apply_dim3(in); 130 | } 131 | else if (in[0]->dims().size() == 4) 132 | { 133 | return apply_dim4(in); 134 | } 135 | } 136 | else 137 | { 138 | constexpr int NN = 1; 139 | if (in[0]->dims() == in[1]->dims()) 140 | { // product wise 141 | return apply_same_dim(in); 142 | } 143 | else if (in[1]->size() == 1) 144 | { // broadcast single element 145 | return apply_singleton(in); 146 | } 147 | else if (in[0]->dims().size() == 2) 148 | { 149 | return apply_dim2(in); 150 | } 151 | else if (in[0]->dims().size() == 3) 152 | { 153 | return apply_dim3(in); 154 | } 155 | else if (in[0]->dims().size() == 4) 156 | { 157 | return apply_dim4(in); 158 | } 159 | } 160 | 161 | return false; 162 | } 163 | 164 | template template bool Mul::apply_same_dim(std::vector *> &in) 165 | { 166 | const int shift = in[1]->quantizer + q_; 167 | #if __AVX2__ && DEBUG_SIMD 168 | std::cout << "\n[WARN] generic version mul sameDim (but likely vectorized) " << in[0]->dims() << ' ' << in[1]->dims() << " " 169 | << in[0]->dims().nbElements() / 1000 << " kMAC" << std::endl; 170 | #endif // SIMD 171 | // for (auto it0 = out_.begin(), it1 = in[1]->begin(); it0 != out_.end(); ++it0, ++it1) { 172 | const auto &B = *in[1]; 173 | const auto N = (out_.size() / NN) * NN; 174 | for (int k = 0; k < N; ++k) 175 | { 176 | typename ComputationType::type x = out_[k]; 177 | x *= B[k]; 178 | COUNTERS_MAC(B[k]); 179 | ComputationType::quantize(x, shift); 180 | 
COUNTERS(x); 181 | SATURATE(x); 182 | out_[k] = (T) x; 183 | } 184 | return true; 185 | } 186 | 187 | template template bool Mul::apply_singleton(std::vector *> &in) 188 | { 189 | const int shift = in[1]->quantizer + q_; 190 | const Tensor &B = *in[1]; 191 | #if __AVX2__ && DEBUG_SIMD 192 | std::cout << "[WARN] generic version mul singleton (but likely vectorized) " << in[0]->dims() << ' ' << in[1]->dims() << std::endl; 193 | #endif // SIMD 194 | const T value{ B[0] }; 195 | const auto N = (out_.size() / NN) * NN; 196 | // for (auto it0 = out_.begin(); it0 != out_.end(); ++it0) { 197 | for (int k = 0; k < N; ++k) 198 | { 199 | typename ComputationType::type x = out_[k]; 200 | x *= value; 201 | COUNTERS_MAC(value); 202 | ComputationType::quantize(x, shift); 203 | COUNTERS(x); 204 | SATURATE(x); 205 | out_[k] = (T) x; 206 | } 207 | return true; 208 | } 209 | 210 | template template bool Mul::apply_dim2(std::vector *> &in) 211 | { 212 | const int shift = in[1]->quantizer + q_; 213 | 214 | #if __AVX2__ && DEBUG_SIMD 215 | std::cout << "[WARN] generic version mul singleton (but likely vectorized) " << in[0]->dims() << ' ' << in[1]->dims() << std::endl; 216 | #endif // SIMD 217 | 218 | const Tensor &B = *in[1]; 219 | const int N = in[0]->dims()[0]; 220 | const int H = (in[0]->dims()[1] / NN) * NN; 221 | for (int n = 0; n < N; ++n) 222 | for (int i = 0; i < H; ++i) 223 | { 224 | typename ComputationType::type x = out_(n, i); 225 | x *= B[i]; 226 | COUNTERS_MAC(B[i]); 227 | ComputationType::quantize(x, shift); 228 | COUNTERS(x); 229 | SATURATE(x); 230 | out_(n, i) = (T) x; 231 | } 232 | return true; 233 | } 234 | 235 | template template bool Mul::apply_dim3(std::vector *> &in) 236 | { 237 | const int shift = in[1]->quantizer + q_; 238 | 239 | #if __AVX2__ && DEBUG_SIMD 240 | std::cout << "[WARN] generic version mul singleton " << in[0]->dims() << ' ' << in[1]->dims() << std::endl; 241 | #endif // SIMD 242 | 243 | const Tensor &B = *in[1]; 244 | const int N = 
in[0]->dims()[0]; 245 | const int H = in[0]->dims()[1]; 246 | const int W = (in[0]->dims()[2] / NN) * NN; 247 | for (int n = 0; n < N; ++n) 248 | for (int i = 0; i < H; ++i) 249 | for (int j = 0; j < W; ++j) 250 | { 251 | typename ComputationType::type x = out_(n, i, j); 252 | x *= B[j]; 253 | COUNTERS_MAC(B[j]); 254 | ComputationType::quantize(x, shift); 255 | COUNTERS(x); 256 | SATURATE(x); 257 | out_(n, i, j) = (T) x; 258 | } 259 | return true; 260 | } 261 | 262 | template template bool Mul::apply_dim4(std::vector *> &in) 263 | { 264 | const int shift = in[1]->quantizer + q_; 265 | 266 | #if __AVX2__ && DEBUG_SIMD 267 | std::cout << "[WARN] generic version mul singleton" << in[0]->dims() << ' ' << in[1]->dims() << std::endl; 268 | #endif // SIMD 269 | assert(in[0]->dims()[0] == 1); 270 | 271 | const Tensor &B = *in[1]; 272 | const int N = in[0]->dims()[0]; 273 | const int H = in[0]->dims()[1]; 274 | const int W = in[0]->dims()[2]; 275 | const int K = (in[0]->dims()[3] / NN) * NN; 276 | for (int n = 0; n < N; ++n) 277 | for (int i = 0; i < H; ++i) 278 | for (int j = 0; j < W; ++j) 279 | for (int k = 0; k < K; ++k) 280 | { 281 | typename ComputationType::type x = out_(n, i, j, k); 282 | x *= B[k]; 283 | COUNTERS_MAC(B[k]); 284 | ComputationType::quantize(x, shift); 285 | COUNTERS(x); 286 | SATURATE(x); 287 | out_(n, i, j, k) = (T) x; 288 | } 289 | 290 | return true; 291 | } 292 | 293 | #if __AVX2__ 294 | template<> inline bool Mul::apply_singleton_simd8(std::vector *> &in) 295 | { 296 | using T = float; 297 | const Tensor &B = *in[1]; 298 | const __m256 value = _mm256_set1_ps(B[0]); 299 | for (int k = 0; k < out_.size(); k += 8) 300 | { 301 | float *aptr = out_.data() + k; 302 | __m256 a = _mm256_load_ps(aptr); 303 | __m256 v = _mm256_mul_ps(a, value); 304 | _mm256_store_ps(aptr, v); 305 | } 306 | return true; 307 | } 308 | template<> inline bool Mul::apply_singleton_simd16(std::vector *> &in) 309 | { 310 | return apply_singleton_simd8(in); 311 | } 312 | 313 | 314 
| 315 | template bool Mul::apply_singleton_simd8(std::vector *> &in) 316 | { 317 | return apply_singleton<8>(in); 318 | } 319 | 320 | template bool Mul::apply_singleton_simd16(std::vector *> &in) 321 | { 322 | return apply_singleton<16>(in); 323 | } 324 | #endif 325 | // init: validates the shapes of the two operands and sizes out_ like in[0]. 326 | // in[0] is the data tensor (A), in[1] the second operand (B). 327 | // Returns false when the operand count or the shapes are not supported. 328 | // assume data shape [N,W,H,D] and broadcast operand shape [D] 329 | template bool Mul::init(const std::vector *> &in) 330 | { 331 | if (in.size() != 2) // must come first: in[0] and in[1] are dereferenced right below 332 | return false; 333 | SADL_DBG(std::cout << " - " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 334 | 335 | // cases: 336 | // same dim: element wise 337 | // if B has only one element -> broadcast to all elements of A 338 | // B has dim [n] or [1,n] and A[...,n] 339 | /* 340 | If B has a single dimension and it 341 | is not a singleton, the last dimension of the input 342 | tensor has to be equal to the dimension of B. 343 | */ 344 | 345 | if (in[1]->size() == 1) 346 | { 347 | // ok 348 | } 349 | else if (in[1]->dims().size() == 1 || (in[1]->dims().size() == 2 && in[1]->dims()[0] == 1)) 350 | { 351 | if (in[0]->dims().back() != in[1]->dims().back()) 352 | return false; 353 | } 354 | else if (in[0]->dims().size()>=2 && in[1]->size() == in[0]->dims().back()) { 355 | 356 | } 357 | else 358 | { 359 | if (!(in[0]->dims() == in[1]->dims())) 360 | return false; 361 | } 362 | out_.resize(in[0]->dims()); 363 | initDone_ = true; 364 | return true; 365 | } 366 | 367 | template bool Mul::loadInternal(std::istream &file, Version v) 368 | { 369 | file.read((char *) &q_, sizeof(q_)); 370 | SADL_DBG(std::cout << " - q: " << q_ << std::endl); 371 | 372 | return true; 373 | } 374 | 375 | } // namespace layers 376 | } // namespace sadl 377 | -------------------------------------------------------------------------------- /sadl/layer_matmul.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | *
License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | #pragma once 34 | #include "layer.h" 35 | #if __AVX2__ || __SSE4_2__ 36 | #include 37 | #endif 38 | 39 | namespace sadl 40 | { 41 | namespace layers 42 | { 43 | template class MatMul : public Layer 44 | { 45 | public: 46 | using Layer::Layer; 47 | using Layer::out_; // to avoid this-> 48 | using Layer::initDone_; 49 | 50 | virtual bool apply(std::vector *> &in) override; 51 | virtual bool init(const std::vector *> &in) override; 52 | 53 | protected: 54 | virtual bool loadInternal(std::istream &file, Version v) override; 55 | template bool apply_dim2(std::vector *> &in); 56 | template bool apply_dim3(std::vector *> &in); 57 | #if __AVX2__ 58 | bool apply_dim2_simd8(std::vector *> &in) { return apply_dim2<8>(in); } 59 | bool apply_dim2_simd16(std::vector *> &in) { return apply_dim2_simd8(in); } 60 | bool apply_sparse_matmul_simd16(std::vector *> &in); 61 | #endif 62 | int q_ = 0; 63 | DUMP_MODEL_EXT; 64 | }; 65 | 66 | template bool MatMul::apply(std::vector *> &in) 67 | { 68 | assert(in.size() == 2); 69 | #if __AVX2__ 70 | #define MULT8_DIM2 apply_dim2_simd8 71 | #define MULT16_DIM2 apply_dim2_simd16 72 | #else 73 | #define MULT8_DIM2 apply_dim2<8> 74 | #define MULT16_DIM2 apply_dim2<16> 75 | #endif 76 | const Tensor &A{ *in[0] }; 77 | const Tensor &B{ *in[1] }; 78 | out_.quantizer = A.quantizer - q_; 79 | assert(out_.quantizer >= 0); 80 | assert(in[1]->quantizer + q_ >= 0); 81 | int dum = A.dims().size(); 82 | // cases: 83 | // A: always a tensor 84 | // B: tensor or const 85 | // 1- A [x] B[x] || A [x,y] B[y,z] || A [x,y,z] B[x,z,t] 86 | // 2- A [1,x] B[x] || A [1,x,y] B[y,z] || A [1,x,y,z] B[x,z,t] 87 | if (A.dims().size() - 1 == B.dims().size()) 88 | dum--; 89 | const int H{ A.dims().back() }; // to be chnaged if SIMD for more than dim1 and dim2 90 | 91 | switch (dum) 92 | { 93 | case 2: 94 | if (H % 16 == 0) 95 | return MULT16_DIM2(in); 96 | else if (H % 8 == 0) 97 | return MULT8_DIM2(in); 98 | else 99 | return apply_dim2<1>(in); 100 | break; 101 | 
case 3: return apply_dim3<1>(in); break; 102 | default: std::cerr << "Logical error MatMul::apply(std::vector *> &in)" << A.dims() << ' ' << B.dims() << std::endl; return false; 103 | } 104 | } 105 | 106 | #if __AVX2__ 107 | template<> inline bool MatMul::apply_dim2_simd8(std::vector *> &in) 108 | { 109 | using T = float; 110 | const Tensor &A{ *in[0] }; 111 | const Tensor &B{ *in[1] }; 112 | const int last = A.dims().size() - 1; 113 | const int N{ A.dims()[last - 1] }; 114 | const int H{ A.dims()[last] }; 115 | const int R{ B.dims()[1] }; 116 | #if DEBUG_SIMD 117 | if (H >= 16) 118 | { 119 | std::cout << "\n[WARN] suboptimal SIMD version matmul dim2 " << A.dims() << ' ' << B.dims() << "(H=" << H << ") " << (N * R * H) / 1000 << " kMAC" 120 | << std::endl; 121 | } 122 | #endif 123 | assert(H % 8 == 0); 124 | for (int b = 0; b < N; ++b) 125 | { 126 | float *optr = out_.data() + R * b; 127 | for (int t = 0; t < R; ++t) 128 | { 129 | __m256 s = _mm256_setzero_ps(); 130 | const float *aptr = A.data() + b * H; 131 | const float *bptr = B.data() + t * H; // T * i + t (i, t); => B[t*H+i] if transposed 132 | for (int i = 0; i < H; i += 8, aptr += 8, bptr += 8) 133 | { 134 | __m256 a = _mm256_load_ps(aptr); 135 | __m256 b = _mm256_load_ps(bptr); 136 | #if __FMA__ 137 | s = _mm256_fmadd_ps(a, b, s); 138 | #else 139 | s = _mm256_add_ps(s, _mm256_mul_ps(a, b)); // s+= _mm256_mul_ps(a, b); // 140 | #endif 141 | } 142 | optr[t] = sum8_float(s); // out_(b, t) = sum8_float(s); 143 | } 144 | } 145 | return true; 146 | } 147 | 148 | #if __AVX512F__ 149 | template<> inline bool MatMul::apply_dim2_simd16(std::vector *> &in) 150 | { 151 | const Tensor &A{ *in[0] }; 152 | const Tensor &B{ *in[1] }; 153 | const int last = A.dims().size() - 1; 154 | const int N{ A.dims()[last - 1] }; 155 | const int H{ A.dims()[last] }; 156 | const int R{ B.dims()[1] }; 157 | assert(H % 16 == 0); 158 | for (int b = 0; b < N; ++b) 159 | { 160 | float *optr = out_.data() + R * b; 161 | for (int t = 0; t < 
R; ++t) 162 | { 163 | __m512 s = _mm512_setzero_ps(); 164 | const float *aptr = A.data() + b * H; 165 | 166 | // The matrix of weights is transposed. 167 | const float *bptr = B.data() + t * H; 168 | for (int i = 0; i < H; i += 16, aptr += 16, bptr += 16) 169 | { 170 | __m512 a = _mm512_load_ps(aptr); 171 | __m512 b = _mm512_load_ps(bptr); 172 | #if __FMA__ 173 | s = _mm512_fmadd_ps(a, b, s); 174 | #else 175 | s = _mm512_add_ps(s, _mm512_mul_ps(a, b)); 176 | #endif 177 | } 178 | optr[t] = sum16_float(s); 179 | } 180 | } 181 | return true; 182 | } 183 | #endif 184 | #endif 185 | 186 | template template bool MatMul::apply_dim2(std::vector *> &in) 187 | { 188 | const Tensor &A{ *in[0] }; 189 | const Tensor &B{ *in[1] }; 190 | const int shift{ in[1]->quantizer + q_ }; 191 | const int last = A.dims().size() - 1; 192 | const int N{ A.dims()[last - 1] }; 193 | const int H{ (A.dims()[last] / NN) * NN }; 194 | const int R{ B.dims().back() }; 195 | #if __AVX2__ && DEBUG_SIMD 196 | std::cout << "\n[WARN] generic version matmul dim2 " << A.dims() << ' ' << B.dims() << "(H=" << H << ") " << (N * R * H) / 1000 << " kMAC" << std::endl; 197 | #endif // SIMD 198 | if (A.dims().size() == 2) 199 | { 200 | for (int b = 0; b < N; ++b) 201 | { 202 | const T *aptr = A.data() + H * b; // A(b,i) => A[H*b] 203 | for (int t = 0; t < R; ++t) 204 | { 205 | typename ComputationType::type x = 0; 206 | const T * bptr = B.data() + t * H; // T * i + t (i, t); => B[t*H+i] if transposed 207 | for (int i = 0; i < H; ++i) 208 | { 209 | x += (typename ComputationType::type) aptr[i] * bptr[i]; // A(b,i)*B(i, t); 210 | COUNTERS_MAC(bptr[i]); 211 | } 212 | ComputationType::quantize(x, shift); 213 | COUNTERS(x); 214 | SATURATE(x); 215 | out_(b, t) = (T) x; 216 | } 217 | } 218 | } 219 | else 220 | { 221 | for (int b = 0; b < N; ++b) 222 | { 223 | const T *aptr = A.data() + H * b; // A(0,b,i) => A[H*b] 224 | for (int t = 0; t < R; ++t) 225 | { 226 | typename ComputationType::type x = 0; 227 | const T * bptr = 
B.data() + t * H; // T * i + t (i, t); => B[t*H+i] if transposed 228 | for (int i = 0; i < H; ++i) 229 | { 230 | x += (typename ComputationType::type) aptr[i] * bptr[i]; // A(0,b,i)*B(i, t); 231 | COUNTERS_MAC(bptr[i]); 232 | } 233 | ComputationType::quantize(x, shift); 234 | COUNTERS(x); 235 | SATURATE(x); 236 | out_(0, b, t) = (T) x; 237 | } 238 | } 239 | } 240 | return true; 241 | } 242 | 243 | template template bool MatMul::apply_dim3(std::vector *> &in) 244 | { 245 | const Tensor &A{ *in[0] }; 246 | const Tensor &B{ *in[1] }; 247 | const int shift{ in[1]->quantizer + q_ }; 248 | const int last = A.dims().size() - 1; 249 | const int N{ A.dims()[last - 2] }; 250 | const int H{ A.dims()[last - 1] }; 251 | const int W{ (A.dims()[last] / NN) * NN }; 252 | const int R{ B.dims().back() }; 253 | #if __AVX2__ && DEBUG_SIMD 254 | std::cout << "\n[WARN] generic version matmul dim3 " << A.dims() << ' ' << B.dims() << "(H=" << H << ") " << (N * R * H * W) / 1000 << " kMAC" << std::endl; 255 | #endif // SIMD 256 | if (A.dims().size() == 3) 257 | { 258 | for (int b = 0; b < N; ++b) 259 | { 260 | for (int i = 0; i < H; ++i) 261 | { 262 | for (int t = 0; t < R; ++t) 263 | { 264 | typename ComputationType::type x = 0; 265 | for (int j = 0; j < W; ++j) 266 | { 267 | x += (typename ComputationType::type) A(b, i, j) * B(b, j, t); 268 | COUNTERS_MAC(B(b, j, t)); 269 | } 270 | ComputationType::quantize(x, shift); 271 | COUNTERS(x); 272 | SATURATE(x); 273 | out_(b, i, t) = (T) x; 274 | } 275 | } 276 | } 277 | } 278 | else 279 | { // size==4 280 | for (int b = 0; b < N; ++b) 281 | { 282 | for (int i = 0; i < H; ++i) 283 | { 284 | for (int t = 0; t < R; ++t) 285 | { 286 | typename ComputationType::type x = 0; 287 | for (int j = 0; j < W; ++j) 288 | { 289 | x += (typename ComputationType::type) A(0, b, i, j) * B(b, j, t); 290 | COUNTERS_MAC(B(b, j, t)); 291 | } 292 | ComputationType::quantize(x, shift); 293 | COUNTERS(x); 294 | SATURATE(x); 295 | out_(0, b, i, t) = (T) x; 296 | } 297 | 
} 298 | } 299 | } 300 | return true; 301 | } 302 | 303 | template bool MatMul::init(const std::vector *> &in) 304 | { 305 | // old: 306 | // old: 307 | // multiply matrix of inner dim [a b ] or [x a b] or [ x a b y] (the [a b] 308 | // matrix) x and y should be same new: output[..., i, j] = sum_k (a[..., i, k] 309 | // * b[..., k, j]), for all indices i, j. 310 | 311 | SADL_DBG(std::cout << " - input matmul: " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 312 | 313 | if (in.size() != 2) 314 | { 315 | return false; 316 | } 317 | // cases: 318 | // A: always a tensor 319 | // B: const (because assumed transposed) 320 | // 1- A [x,y] B[y,z] || A [x,y,z] B[x,z,t] || A [1,x,y,z] B[1,x,z,t] 321 | // 2- A [1,x,y] B[y,z] || A [1,x,y,z] B[x,z,t] 322 | if (in[1]->dims().size() < 2 || in[1]->dims().size() > 3) 323 | { 324 | return false; 325 | } 326 | 327 | if (in[0]->dims().size() != in[1]->dims().size() && !(in[0]->dims().size() - 1 == in[1]->dims().size() && in[0]->dims()[0] == 1)) 328 | { 329 | return false; 330 | } 331 | 332 | if (in[0]->dims().size() != in[1]->dims().size() && !(in[0]->dims().size() - 1 == in[1]->dims().size() && in[0]->dims()[0] == 1)) 333 | { 334 | return false; 335 | } 336 | Dimensions dim = in[0]->dims(); 337 | const int last = in[0]->dims().size() - 1; 338 | 339 | if (in[0]->dims().size() - 1 == in[1]->dims().size()) 340 | { 341 | for (int k = 1; k < last - 1; ++k) 342 | { 343 | if (in[0]->dims()[k] != in[1]->dims()[k - 1]) 344 | { 345 | return false; 346 | } 347 | } 348 | if (in[0]->dims()[last] != in[1]->dims()[last - 2]) 349 | { 350 | return false; 351 | } 352 | } 353 | else 354 | { 355 | #if DEBUG_MODEL 356 | if (in[0]->dims()[0] != 1) 357 | std::cout << "[WARN] suspicious operation (likely second input not a Const)" << std::endl; 358 | #endif 359 | // Excluding the last two dimensions, the dimension 360 | // of index i in the first input Tensor must be equal 361 | // to the dimension of index i in the second input 362 | // Tensor. 
363 | for (int k = 0; k < last - 1; ++k) 364 | { 365 | if (in[0]->dims()[k] != in[1]->dims()[k]) 366 | { 367 | return false; 368 | } 369 | } 370 | if (in[0]->dims()[last] != in[1]->dims()[last - 1]) 371 | { 372 | return false; 373 | } 374 | } 375 | dim[last] = in[1]->dims().back(); 376 | out_.resize(dim); 377 | SADL_DBG(std::cout << " - output matmul: " << out_.dims() << std::endl); 378 | initDone_ = true; 379 | return true; 380 | } 381 | 382 | template bool MatMul::loadInternal(std::istream &file, Version v) 383 | { 384 | file.read((char *) &q_, sizeof(q_)); 385 | SADL_DBG(std::cout << " - q: " << q_ << std::endl); 386 | return true; 387 | } 388 | 389 | } // namespace layers 390 | } // namespace sadl 391 | -------------------------------------------------------------------------------- /sadl/tensor.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 
20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include 35 | #include 36 | #if _WIN32 || __USE_ISOC11 37 | #include 38 | #else 39 | #include 40 | #endif 41 | #include 42 | #include 43 | #include "options.h" 44 | 45 | #include "dimensions.h" 46 | 47 | namespace sadl 48 | { 49 | // tensor between layers: depth height width (or width height?) 
50 | template struct aligned_allocator 51 | { 52 | using pointer = T *; 53 | using const_pointer = const T *; 54 | using reference = T &; 55 | using const_reference = const T &; 56 | using value_type = T; 57 | using size_type = std::size_t; 58 | using difference_type = std::ptrdiff_t; 59 | 60 | pointer address(reference r) const { return &r; } 61 | const_pointer address(const_reference s) const { return &s; } 62 | size_type max_size() const { return (static_cast(0) - static_cast(1)) / sizeof(T); } 63 | template struct rebind 64 | { 65 | typedef aligned_allocator other; 66 | }; 67 | 68 | bool operator!=(const aligned_allocator &other) const { return !(*this == other); } 69 | void construct(pointer p, const_reference t) const 70 | { 71 | void *const pv = static_cast(p); 72 | new (pv) T(t); 73 | } 74 | void destroy(T *const p) const { p->~T(); } 75 | bool operator==(const aligned_allocator & /*other*/) const { return true; } 76 | 77 | aligned_allocator() = default; 78 | aligned_allocator(const aligned_allocator &) = default; 79 | ~aligned_allocator() = default; 80 | aligned_allocator &operator=(const aligned_allocator &) = delete; 81 | 82 | template aligned_allocator(const aligned_allocator &) {} 83 | 84 | pointer allocate(const std::size_t n) const 85 | { 86 | if (n == 0) 87 | return nullptr; 88 | size_t s = ((n * sizeof(T) + Alignment - 1) / Alignment) * Alignment; 89 | 90 | #if _WIN32 91 | #if __MINGW32__ 92 | void *const pv = __mingw_aligned_malloc(s, Alignment); 93 | #else 94 | void *const pv = _aligned_malloc(s, Alignment); 95 | #endif 96 | #else 97 | #if __USE_ISOC11 98 | void *const pv = aligned_alloc(Alignment, s); 99 | #else 100 | void *pv = nullptr; 101 | if (posix_memalign(&pv, Alignment, s)) 102 | { 103 | throw std::bad_alloc(); 104 | } 105 | #endif 106 | #endif 107 | 108 | if (!pv) 109 | throw std::bad_alloc(); 110 | return static_cast(pv); 111 | } 112 | 113 | #ifdef _WIN32 114 | void deallocate(T *const p, const std::size_t n) const { _aligned_free(p); 
} 115 | #else 116 | void deallocate(T *const p, const std::size_t /*n*/) const { free(p); } 117 | #endif 118 | 119 | template pointer allocate(const std::size_t n, const U * /* const hint */) const { return allocate(n); } 120 | }; 121 | 122 | template struct ComputationType 123 | { 124 | }; 125 | 126 | // predecl for friendness 127 | template class Tensor; 128 | template void swap(Tensor &t0, Tensor &t1); 129 | template void swapData(Tensor &t0, Tensor &t1); 130 | 131 | template class Tensor 132 | { 133 | public: 134 | using value_type = T; 135 | using Data = std::vector>; 136 | using iterator = typename Data::iterator; 137 | using const_iterator = typename Data::const_iterator; 138 | static bool skip_border; // to replace by inline global C++17 139 | 140 | Tensor() = default; 141 | explicit Tensor(Dimensions d); 142 | 143 | void resize(Dimensions d); 144 | 145 | // lineqar access 146 | value_type &operator[](int i); 147 | value_type operator[](int i) const; 148 | 149 | // tensor access 150 | value_type &operator()(int i); 151 | value_type operator()(int i) const; 152 | 153 | value_type &operator()(int i, int j); 154 | value_type operator()(int i, int j) const; 155 | 156 | value_type &operator()(int i, int j, int k); 157 | value_type operator()(int i, int j, int k) const; 158 | 159 | value_type & operator()(int i, int j, int k, int l); 160 | value_type operator()(int i, int j, int k, int l) const; 161 | const value_type *addr(int i, int j, int k, int l) const; 162 | 163 | bool in(int i) const; 164 | bool in(int i, int j) const; 165 | bool in(int i, int j, int k) const; 166 | bool in(int i, int j, int k, int l) const; 167 | void fill(value_type value); 168 | 169 | const Dimensions &dims() const; 170 | int64_t size() const; 171 | 172 | const value_type *data() const { return data_.data(); } 173 | value_type * data() { return data_.data(); } 174 | 175 | iterator begin() { return data_.begin(); } 176 | const_iterator begin() const { return data_.begin(); } 177 | 
iterator end() { return data_.end(); } 178 | const_iterator end() const { return data_.end(); } 179 | 180 | int quantizer = 0; // for int 181 | int border_skip = 0; 182 | static constexpr int64_t kMaxSize = 32LL*1024*1024*1024; 183 | 184 | Data &getData() { return data_; } 185 | 186 | private: 187 | Dimensions dims_; 188 | Data data_; 189 | friend void swap<>(Tensor &t0, Tensor &t1); 190 | friend void swapData<>(Tensor &t0, Tensor &t1); 191 | #if DEBUG_PRINT 192 | public: 193 | static bool verbose_; 194 | #endif 195 | }; 196 | 197 | // spe 198 | template<> struct ComputationType 199 | { 200 | using type = float; 201 | static constexpr type max = std::numeric_limits::max(); 202 | static void quantize(type, int) {} // nothing to do 203 | static void shift_left(type, int) {} // nothing to do 204 | }; 205 | 206 | template<> struct ComputationType 207 | { 208 | using type = int64_t; 209 | static constexpr type max = std::numeric_limits::max(); 210 | static void quantize(type &z, int q) { z >>= q; } 211 | static void shift_left(type &z, int q) { z <<= q; } 212 | static void quantize(int32_t &z, int q) { z >>= q; } 213 | static void shift_left(int32_t &z, int q) { z <<= q; } 214 | }; 215 | 216 | template<> struct ComputationType 217 | { 218 | using type = int32_t; 219 | static constexpr type max = std::numeric_limits::max(); 220 | static void quantize(type &z, int q) { z >>= q; } 221 | static void shift_left(type &z, int q) { z <<= q; } 222 | static void quantize(int16_t &z, int q) { z >>= q; } 223 | static void shift_left(int16_t &z, int q) { z <<= q; } 224 | }; 225 | 226 | // impl 227 | template bool Tensor::skip_border = false; 228 | 229 | template void swap(Tensor &t0, Tensor &t1) 230 | { 231 | std::swap(t0.dims_, t1.dims_); 232 | std::swap(t0.data_, t1.data_); 233 | std::swap(t0.quantizer, t1.quantizer); 234 | std::swap(t0.border_skip, t1.border_skip); 235 | } 236 | 237 | template void swapData(Tensor &t0, Tensor &t1) 238 | { 239 | assert(t0.size() == t1.size()); 240 
| std::swap(t0.data_, t1.data_); 241 | std::swap(t0.quantizer, t1.quantizer); 242 | std::swap(t0.border_skip, t1.border_skip); 243 | } 244 | 245 | template Tensor::Tensor(Dimensions d) 246 | { 247 | resize(d); 248 | } 249 | 250 | template const Dimensions &Tensor::dims() const 251 | { 252 | return dims_; 253 | } 254 | 255 | template int64_t Tensor::size() const 256 | { 257 | return data_.size(); 258 | } 259 | 260 | template void Tensor::resize(Dimensions d) 261 | { 262 | dims_ = d; 263 | int64_t m = dims_.nbElements(); 264 | // for(auto x: dims_) m*=x; 265 | assert(m < kMaxSize); 266 | data_.resize(m); 267 | } 268 | 269 | // TODO: variadic template to define all accesors 270 | template T &Tensor::operator[](int i) 271 | { 272 | return data_[i]; 273 | } 274 | 275 | template T &Tensor::operator()(int i) 276 | { 277 | assert(dims_.size() == 1); 278 | assert(i < dims_[0] && i >= 0); 279 | 280 | return data_[i]; 281 | } 282 | 283 | template bool Tensor::in(int i) const 284 | { 285 | return dims_.size() == 1 && i < dims_[0] && i >= 0; 286 | } 287 | 288 | template T Tensor::operator[](int i) const 289 | { 290 | return data_[i]; 291 | } 292 | 293 | template T Tensor::operator()(int i) const 294 | { 295 | assert(dims_.size() == 1); 296 | assert(i < dims_[0] && i >= 0); 297 | 298 | return data_[i]; 299 | } 300 | 301 | template T &Tensor::operator()(int i, int j) 302 | { 303 | assert(dims_.size() == 2); 304 | assert(i < dims_[0] && i >= 0); 305 | assert(j < dims_[1] && j >= 0); 306 | 307 | return data_[(int64_t)dims_[1] * i + j]; 308 | } 309 | 310 | template T Tensor::operator()(int i, int j) const 311 | { 312 | assert(dims_.size() == 2); 313 | assert(i < dims_[0] && i >= 0); 314 | assert(j < dims_[1] && j >= 0); 315 | 316 | return data_[(int64_t)dims_[1] * i + j]; 317 | } 318 | 319 | template bool Tensor::in(int i, int j) const 320 | { 321 | return dims_.size() == 2 && i < dims_[0] && i >= 0 && j < dims_[1] && j >= 0; 322 | } 323 | 324 | template T &Tensor::operator()(int i, 
int j, int k) 325 | { 326 | assert(dims_.size() == 3); 327 | assert(i < dims_[0] && i >= 0); 328 | assert(j < dims_[1] && j >= 0); 329 | assert(k < dims_[2] && k >= 0); 330 | 331 | return data_[(int64_t)dims_[2] * (dims_[1] * i + j) + k]; 332 | } 333 | 334 | template T Tensor::operator()(int i, int j, int k) const 335 | { 336 | assert(dims_.size() == 3); 337 | assert(i < dims_[0] && i >= 0); 338 | assert(j < dims_[1] && j >= 0); 339 | assert(k < dims_[2] && k >= 0); 340 | 341 | return data_[(int64_t)dims_[2] * (dims_[1] * i + j) + k]; 342 | } 343 | 344 | template bool Tensor::in(int i, int j, int k) const 345 | { 346 | return dims_.size() == 3 && i < dims_[0] && i >= 0 && j < dims_[1] && j >= 0 && k < dims_[2] && k >= 0; 347 | } 348 | 349 | template T &Tensor::operator()(int i, int j, int k, int l) 350 | { 351 | assert(dims_.size() == 4); 352 | assert(i < dims_[0] && i >= 0); 353 | assert(j < dims_[1] && j >= 0); 354 | assert(k < dims_[2] && k >= 0); 355 | assert(l < dims_[3] && l >= 0); 356 | 357 | return data_[(int64_t)dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l]; 358 | } 359 | 360 | template bool Tensor::in(int i, int j, int k, int l) const 361 | { 362 | return dims_.size() == 4 && i < dims_[0] && i >= 0 && j < dims_[1] && j >= 0 && k < dims_[2] && k >= 0 && l < dims_[3] && l >= 0; 363 | } 364 | 365 | template const T *Tensor::addr(int i, int j, int k, int l) const 366 | { 367 | assert(dims_.size() == 4); 368 | assert(i < dims_[0] && i >= 0); 369 | assert(j < dims_[1] && j >= 0); 370 | assert(k < dims_[2] && k >= 0); 371 | assert(l < dims_[3] && l >= 0); 372 | return &data_[(int64_t)dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l]; 373 | } 374 | 375 | template T Tensor::operator()(int i, int j, int k, int l) const 376 | { 377 | assert(dims_.size() == 4); 378 | assert(i < dims_[0] && i >= 0); 379 | assert(j < dims_[1] && j >= 0); 380 | assert(k < dims_[2] && k >= 0); 381 | assert(l < dims_[3] && l >= 0); 382 | return data_[(int64_t)dims_[3] * (dims_[2] 
* (dims_[1] * i + j) + k) + l]; 383 | } 384 | 385 | template void Tensor::fill(value_type value) 386 | { 387 | std::fill(data_.begin(), data_.end(), value); 388 | } 389 | 390 | } // namespace sadl 391 | 392 | #include 393 | #include 394 | 395 | #if DEBUG_PRINT 396 | template bool sadl::Tensor::verbose_ = true; 397 | 398 | #define SADL_DBG(X) \ 399 | if (sadl::Tensor::verbose_) \ 400 | { \ 401 | X; \ 402 | } 403 | #else 404 | #define SADL_DBG(X) 405 | #endif 406 | 407 | namespace sadl 408 | { 409 | template std::ostream &operator<<(std::ostream &out, const Tensor &t) 410 | { 411 | // adhoc 412 | if (t.dims().size() == 4u) 413 | { 414 | out << "["; 415 | if (t.dims()[0] > 1) 416 | out << '\n'; 417 | for (int k = 0; k < t.dims()[0]; ++k) 418 | { 419 | out << " ["; 420 | if (t.dims()[1] > 1) 421 | out << '\n'; 422 | for (int d = 0; d < t.dims()[1]; ++d) 423 | { 424 | out << " ["; 425 | if (t.dims()[2] > 1) 426 | out << '\n'; 427 | for (int i = 0; i < t.dims()[2]; ++i) 428 | { 429 | out << " ["; 430 | for (int j = 0; j < t.dims()[3]; ++j) 431 | out << t(k, d, i, j) << ' '; 432 | out << " ]"; 433 | if (t.dims()[2] > 1) 434 | out << '\n'; 435 | } 436 | out << " ]"; 437 | if (t.dims()[1] > 1) 438 | out << '\n'; 439 | } 440 | out << " ]"; 441 | if (t.dims()[0] > 1) 442 | out << '\n'; 443 | } 444 | out << "]"; 445 | } 446 | else if (t.dims().size() == 3u) 447 | { 448 | out << "["; 449 | for (int d = 0; d < t.dims()[0]; ++d) 450 | { 451 | out << " ["; 452 | if (t.dims()[0] > 1) 453 | out << '\n'; 454 | for (int i = 0; i < t.dims()[1]; ++i) 455 | { 456 | out << "["; 457 | if (t.dims()[1] > 1) 458 | out << '\n'; 459 | for (int j = 0; j < t.dims()[2]; ++j) 460 | out << t(d, i, j) << '\t'; 461 | out << " ]"; 462 | if (t.dims()[1] > 1) 463 | out << '\n'; 464 | } 465 | out << " ]"; 466 | if (t.dims()[0] > 1) 467 | out << '\n'; 468 | } 469 | out << "]"; 470 | } 471 | else if (t.dims().size() == 2u) 472 | { 473 | out << "["; 474 | for (int i = 0; i < t.dims()[0]; ++i) 475 | { 476 | 
out << "["; 477 | if (t.dims()[0] > 1) 478 | out << '\n'; 479 | for (int j = 0; j < t.dims()[1]; ++j) 480 | out << t(i, j) << ' '; 481 | out << " ]"; 482 | if (t.dims()[0] > 1) 483 | out << '\n'; 484 | } 485 | out << "]\n"; 486 | } 487 | else if (t.dims().size() == 1u) 488 | { 489 | out << "["; 490 | for (int j = 0; j < t.dims()[0]; ++j) 491 | out << t(j) << ' '; 492 | out << "]"; 493 | } 494 | else 495 | { 496 | out << "TODO\n"; 497 | } 498 | return out; 499 | } 500 | 501 | } // namespace sadl 502 | -------------------------------------------------------------------------------- /sadl/layer_conv2dtranspose.h: -------------------------------------------------------------------------------- 1 | /* The copyright in this software is being made available under the BSD 2 | * License, included below. This software may be subject to other third party 3 | * and contributor rights, including patent rights, and no such rights are 4 | * granted under this license. 5 | * 6 | * Copyright (c) 2010-2022, ITU/ISO/IEC 7 | * All rights reserved. 8 | * 9 | * Redistribution and use in source and binary forms, with or without 10 | * modification, are permitted provided that the following conditions are met: 11 | * 12 | * * Redistributions of source code must retain the above copyright notice, 13 | * this list of conditions and the following disclaimer. 14 | * * Redistributions in binary form must reproduce the above copyright notice, 15 | * this list of conditions and the following disclaimer in the documentation 16 | * and/or other materials provided with the distribution. 17 | * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may 18 | * be used to endorse or promote products derived from this software without 19 | * specific prior written permission. 
20 | * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 25 | * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 31 | * THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | #pragma once 34 | #include 35 | #if __AVX2__ 36 | #include 37 | #endif 38 | 39 | #include "layer.h" 40 | 
namespace sadl
{
namespace layers
{
// Transposed 2D convolution layer.
//
// in[0] is the data tensor and in[1] the kernel, both of rank 4.
// strides_, pads_ and out_pads_ hold the layer parameters and q_ the
// quantizer shift (the output quantizer is in[0]'s quantizer minus q_).
// tempo_ is a scratch tensor of the accumulation type T2; init() sizes it as
// the output extended by half a kernel on each side of both spatial
// dimensions.
template<typename T> class Conv2DTranspose : public Layer<T>
{
public:
  using Layer<T>::Layer;
  using Layer<T>::out_;   // to avoid this->
  using Layer<T>::initDone_;

  virtual bool apply(std::vector<Tensor<T> *> &in) override;
  virtual bool init(const std::vector<Tensor<T> *> &in) override;

protected:
  virtual bool loadInternal(std::istream &file, Version /*v*/) override;

  Dimensions strides_;
  Dimensions pads_;
  Dimensions out_pads_;
  int        q_ = 0;

  // Generic kernel, selected by apply() when no specialised version matches.
  // should never be used
  void conv2dtranspose(int nb_filters, int in_D, Tensor<T> &out_, const Tensor<T> &A, const Tensor<T> &kernel);
#if __AVX512F__
  // AVX512 kernel, instantiated by apply() for a fixed set of input depths.
  template<int in_D> void conv2dtranspose_simd512(int nb_filters, Tensor<T> &out_, const Tensor<T> &A, const Tensor<T> &kernel);
#endif

  using T2 = typename ComputationType<T>::type;
  Tensor<T2> tempo_;
  DUMP_MODEL_EXT;
};

71 | // assume data in in[0] and kernel in in[1] 72 | // data [batch, in_height, in_width, in_channels] 73 | // kernel 
[filter_height, filter_width, in_channels, out_channels] 74 | template bool Conv2DTranspose::apply(std::vector *> &in) 75 | { 76 | assert(in.size() == 2); 77 | assert(in[0]->dims().size() == 4); 78 | assert(in[1]->dims().size() == 4); 79 | const Tensor &A = *in[0]; 80 | const Tensor &kernel = *in[1]; 81 | out_.quantizer = A.quantizer - q_; 82 | out_.border_skip = A.border_skip; 83 | 84 | assert(out_.quantizer >= 0); 85 | assert(kernel.quantizer + q_ >= 0); 86 | 87 | const int nb_filters{ out_.dims()[3] }; // kernel.dims()[2] }; 88 | // int in_H{ A.dims()[1] }; 89 | // int in_W{ A.dims()[2] }; 90 | const int in_D{ A.dims()[3] }; 91 | //const int half_size{ kernel.dims()[0] / 2 }; 92 | //const int top{ pads_[0] }; 93 | //const int left{ pads_[1] }; 94 | // int start_h{ half_size - top }; 95 | // int start_w{ half_size - left }; 96 | 97 | #if __AVX512F__ 98 | switch(in_D) { 99 | case 32: conv2dtranspose_simd512<32>(nb_filters,out_,A,kernel); break; 100 | case 64: conv2dtranspose_simd512<64>(nb_filters,out_,A,kernel); break; 101 | case 128: conv2dtranspose_simd512<128>(nb_filters,out_,A,kernel); break; 102 | case 192: conv2dtranspose_simd512<192>(nb_filters,out_,A,kernel); break; 103 | case 256: conv2dtranspose_simd512<256>(nb_filters,out_,A,kernel); break; 104 | default: 105 | conv2dtranspose(nb_filters,in_D,out_,A,kernel); 106 | } 107 | #else 108 | conv2dtranspose(nb_filters,in_D,out_,A,kernel); 109 | #endif 110 | return true; 111 | } 112 | 113 | 114 | 115 | // data [batch, in_height, in_width, in_channels] 116 | // kernel [filter_height, filter_width, in_channels, out_channels] 117 | template bool Conv2DTranspose::init(const std::vector *> &in) 118 | { 119 | if (in.size() != 2) 120 | return false; 121 | SADL_DBG(std::cout << " - input conv2dtranspose: " << in[0]->dims() << ' ' << in[1]->dims() << std::endl); 122 | if (in[0]->dims().size() != 4) 123 | return false; 124 | if (in[1]->dims().size() != 4) 125 | return false; 126 | if (in[1]->dims()[0] != 
in[1]->dims()[1]) 127 | return false; 128 | if ((in[1]->dims()[0]) % 2 == 0) 129 | return false; 130 | 131 | // The spatial dimensions of a convolutional kernel must be 3 or 5 132 | if ( 133 | (in[1]->dims()[0] !=3 ) 134 | && (in[1]->dims()[0] !=5 )) 135 | return false; 136 | 137 | if (! ((in[1]->dims()[0] ==3 && pads_[1] == 1) || (in[1]->dims()[0] == 5 && pads_[1] == 2)) ) { 138 | return false; 139 | } 140 | Dimensions dim; 141 | dim.resize(4); 142 | dim[0] = in[0]->dims()[0]; 143 | dim[1] = (int) ceil(in[0]->dims()[1] * (float) strides_[1]); 144 | dim[2] = (int) ceil(in[0]->dims()[2] * (float) strides_[2]); 145 | dim[3] = in[1]->dims()[2]; 146 | out_.resize(dim); 147 | SADL_DBG(std::cout << " - output Conv2DTranspose: " << out_.dims() << std::endl); 148 | // init tempo 149 | int half_size=in[1]->dims()[0]/2; 150 | dim[1]+=half_size*2; 151 | dim[2]+=half_size*2; 152 | tempo_.resize(dim); 153 | initDone_ = true; 154 | return true; 155 | } 156 | 157 | template bool Conv2DTranspose::loadInternal(std::istream &file, Version /*v*/) 158 | { 159 | int32_t x = 0; 160 | file.read((char *) &x, sizeof(x)); 161 | if (x <= 0 || x > Dimensions::MaxDim) 162 | { 163 | std::cerr << "[ERROR] invalid nb of dimensions: " << x << std::endl; 164 | return false; 165 | } 166 | strides_.resize(x); 167 | for (int k = 0; k < strides_.size(); ++k) 168 | { 169 | file.read((char *) &x, sizeof(x)); 170 | strides_[k] = x; 171 | } 172 | if (strides_.size() == 2) 173 | { 174 | strides_ = Dimensions({ 1, strides_[0], strides_[1], 1 }); 175 | } 176 | if (strides_.size() != 4) 177 | { 178 | std::cerr << "[ERROR] invalid strides: " << strides_.size() << std::endl; 179 | return false; 180 | } 181 | if (strides_[0] != 1) 182 | { 183 | std::cerr << "[ERROR] invalid strides[0]: " << strides_[0] << std::endl; 184 | return false; 185 | } 186 | if (strides_[3] != 1) 187 | { 188 | std::cerr << "[ERROR] invalid strides[3]: " << strides_[3] << std::endl; 189 | return false; 190 | } 191 | if (strides_[1] != 2 || 
strides_[1] != 2) 192 | { 193 | std::cerr << "[ERROR] stride not 2: to check " << strides_ << std::endl; 194 | return false; 195 | } 196 | SADL_DBG(std::cout << " - strides: " << strides_ << std::endl); 197 | 198 | file.read((char *) &x, sizeof(x)); 199 | if (x <= 0 || x > Dimensions::MaxDim) 200 | { 201 | std::cerr << "[ERROR] invalid nb of dimensions: " << x << std::endl; 202 | return false; 203 | } 204 | pads_.resize(x); 205 | for (int k = 0; k < pads_.size(); ++k) 206 | { 207 | file.read((char *) &x, sizeof(x)); 208 | pads_[k] = x; 209 | if (x != 1 && x!=2) { 210 | std::cerr << "[ERROR] pads values not supported: " << x << std::endl; 211 | return false; 212 | } 213 | } 214 | SADL_DBG(std::cout << " - pads: " << pads_ << std::endl); 215 | 216 | file.read((char *) &x, sizeof(x)); 217 | if (x <= 0 || x > Dimensions::MaxDim) 218 | { 219 | std::cerr << "[ERROR] invalid nb of dimensions: " << x << std::endl; 220 | return false; 221 | } 222 | out_pads_.resize(x); 223 | for (int k = 0; k < out_pads_.size(); ++k) 224 | { 225 | file.read((char *) &x, sizeof(x)); 226 | out_pads_[k] = x; 227 | if (x != 1) { 228 | std::cerr << "[ERROR] output pads !=1 " << x << std::endl; 229 | return false; 230 | } 231 | } 232 | SADL_DBG(std::cout << " - out_pads: " << out_pads_ << std::endl); 233 | 234 | { 235 | file.read((char *) &q_, sizeof(q_)); 236 | SADL_DBG(std::cout << " - q: " << q_ << std::endl); 237 | } 238 | 239 | return true; 240 | } 241 | 242 | // should never be used for perf reasons 243 | template 244 | void Conv2DTranspose::conv2dtranspose(int nb_filters, int in_D,Tensor &out_, const Tensor &A, 245 | const Tensor &kernel) 246 | { 247 | #if DEBUG_SIMD && __AVX2__ 248 | const int in_H{ A.dims()[1] }; 249 | const int in_W{ A.dims()[2] }; 250 | std::cout << "\n[WARN] debug generic version convtranspose inD=" << in_D << " outD=" << nb_filters << 251 | //" s=[" << s_w << ' ' << s_h << "] " 252 | in_H << 'x' << in_W << " " << 253 | // << in_D * kernel.dims()[0] * kernel.dims()[1] 
* nb_filters * (in_H /) * (in_W / s_w) / 1000 << " kMAC" 254 | std::endl; 255 | #endif 256 | constexpr int im_nb = 0; 257 | assert(strides_[1]==2); 258 | assert(strides_[2]==2); 259 | int half_size=kernel.dims()[0]/2; 260 | constexpr int sw=2; 261 | constexpr int sh=2; 262 | 263 | const int shift = kernel.quantizer + q_; 264 | const int out_h=out_.dims()[1]; 265 | const int out_w=out_.dims()[2]; 266 | tempo_.fill(T2{}); 267 | 268 | for (int filter = 0; filter < nb_filters; ++filter) 269 | { 270 | for (int im_i = 0; im_i < out_h; im_i += sh) 271 | { 272 | for (int im_j = 0; im_j < out_w; im_j += sw) 273 | { 274 | const int i1=im_i/sh; 275 | const int j1=im_j/sw; 276 | assert(A.in(im_nb,i1,j1,0)); 277 | for (int filter_i = -half_size; filter_i <= half_size; ++filter_i) 278 | { 279 | // fixed 280 | for (int filter_j = -half_size; filter_j <= half_size; ++filter_j) 281 | { 282 | // fixed 283 | const int ki = half_size + filter_i; 284 | const int kj = half_size + filter_j; 285 | const int ii = im_i + ki; 286 | const int jj = im_j + kj; 287 | T2 s{}; 288 | for (int filter_d = 0; filter_d < in_D; ++filter_d) 289 | { 290 | s+= A(im_nb, i1, j1, filter_d) * kernel(ki, kj, filter, filter_d); 291 | COUNTERS_MAC(kernel(ki, kj, filter, filter_d)); 292 | } 293 | tempo_(im_nb,ii,jj,filter) += s; 294 | } 295 | } 296 | } 297 | } 298 | for (int im_i = 0; im_i < out_h; ++im_i) 299 | { 300 | for (int im_j = 0; im_j < out_w; ++im_j) 301 | { 302 | auto x=tempo_(im_nb,im_i+half_size,im_j+half_size,filter); 303 | ComputationType::quantize(x, shift); 304 | COUNTERS(x); 305 | SATURATE(x); 306 | out_(im_nb,im_i,im_j,filter)=static_cast(x); 307 | } 308 | } 309 | } 310 | } 311 | 312 | #if __AVX512F__ 313 | template<> 314 | template 315 | void Conv2DTranspose::conv2dtranspose_simd512(int nb_filters, Tensor &out_, const Tensor &A, const Tensor &kernel) 316 | { 317 | constexpr int im_nb = 0; 318 | assert(strides_[1]==2); 319 | assert(strides_[2]==2); 320 | 
assert(kernel.dims()[0]==kernel.dims()[1]); 321 | int half_size=kernel.dims()[0]/2; 322 | static_assert(in_D % 16 == 0, "Should be used with mod16 filters."); 323 | constexpr int sw=2; 324 | constexpr int sh=2; 325 | 326 | const int out_h=out_.dims()[1]; 327 | const int out_w=out_.dims()[2]; 328 | tempo_.fill(T2{}); 329 | 330 | for (int filter = 0; filter < nb_filters; ++filter) 331 | { 332 | for (int im_i = 0; im_i < out_h; im_i += sh) 333 | { 334 | for (int im_j = 0; im_j < out_w; im_j += sw) 335 | { 336 | const int i1=im_i/sh; 337 | const int j1=im_j/sw; 338 | assert(A.in(im_nb,i1,j1,0)); 339 | for (int filter_i = -half_size; filter_i <= half_size; ++filter_i) 340 | { 341 | // fixed 342 | for (int filter_j = -half_size; filter_j <= half_size; ++filter_j) 343 | { 344 | __m512 s = _mm512_setzero_ps(); 345 | const int ki = half_size + filter_i; 346 | const int kj = half_size + filter_j; 347 | const int ii = im_i + ki; 348 | const int jj = im_j + kj; 349 | const float *kptr = kernel.addr(ki, kj, filter, 0); 350 | const float *aptr = A.addr(im_nb, i1, j1, 0); 351 | // fixed 352 | for (int filter_d = 0; filter_d < in_D; filter_d+=16) 353 | { 354 | const __m512 k0 = _mm512_loadu_ps(kptr+filter_d); // not always aligned 355 | #if __FMA__ 356 | s = _mm512_fmadd_ps(k0, _mm512_load_ps(aptr+filter_d), s); 357 | #else 358 | const __m512 m0 = _mm512_mul_ps(k0, _mm512_load_ps(aptr+filter_d)); 359 | s = _mm512_add_ps(s, m0); 360 | #endif 361 | } 362 | tempo_(im_nb,ii,jj,filter) += sum16_float(s); 363 | } 364 | } 365 | } 366 | } 367 | for (int im_i = 0; im_i < out_h; ++im_i) 368 | { 369 | for (int im_j = 0; im_j < out_w; ++im_j) 370 | { 371 | auto x=tempo_(im_nb,im_i+half_size,im_j+half_size,filter); 372 | out_(im_nb,im_i,im_j,filter)=x; 373 | } 374 | } 375 | } 376 | } 377 | #endif 378 | 379 | #if __AVX512BW__ 380 | template<> 381 | template 382 | void Conv2DTranspose::conv2dtranspose_simd512(int nb_filters, Tensor &out_, const Tensor &A, const Tensor &kernel) 383 | { 384 | 
constexpr int im_nb = 0; 385 | assert(strides_[1]==2); 386 | assert(strides_[2]==2); 387 | static_assert(in_D % 32 == 0, "Should be used with mod32 filters."); 388 | #if DEBUG_COUNTERS || SATURATE_RESULT 389 | using T = int16_t; 390 | #endif 391 | assert(kernel.dims()[0]==kernel.dims()[1]); 392 | const int half_size=kernel.dims()[0]/2; 393 | static_assert(in_D % 32 == 0, "Should be used with mod32 filters."); 394 | constexpr int sw=2; 395 | constexpr int sh=2; 396 | const int out_h=out_.dims()[1]; 397 | const int out_w=out_.dims()[2]; 398 | tempo_.fill(T2{}); 399 | const int shift = kernel.quantizer + q_; 400 | 401 | for (int filter = 0; filter < nb_filters; ++filter) 402 | { 403 | for (int im_i = 0; im_i < out_h; im_i += sh) 404 | { 405 | for (int im_j = 0; im_j < out_w; im_j += sw) 406 | { 407 | const int i1=im_i/sh; 408 | const int j1=im_j/sw; 409 | assert(A.in(im_nb,i1,j1,0)); 410 | const T *aptr = A.addr(im_nb, i1, j1, 0); 411 | for (int filter_i = -half_size; filter_i <= half_size; ++filter_i) 412 | { 413 | // fixed 414 | for (int filter_j = -half_size; filter_j <= half_size; ++filter_j) 415 | { 416 | __m512i s = _mm512_setzero_si512(); 417 | const int ki = half_size + filter_i; 418 | const int kj = half_size + filter_j; 419 | const int ii = im_i + ki; 420 | const int jj = im_j + kj; 421 | const T *kptr = kernel.addr(ki, kj, filter, 0); 422 | // fixed 423 | for (int filter_d = 0; filter_d < in_D; filter_d+=32) 424 | { 425 | const __m512i k0 = _mm512_loadu_si512(kptr+filter_d); // not always aligned 426 | const __m512i v0 = _mm512_load_si512(aptr+filter_d); 427 | const __m512i mad0 = _mm512_madd_epi16(k0, v0); // res in si32 428 | s = _mm512_add_epi32(s, mad0); 429 | } 430 | tempo_(im_nb,ii,jj,filter) += _mm512_reduce_add_epi32(s); 431 | } 432 | } 433 | } 434 | } 435 | for (int im_i = 0; im_i < out_h; ++im_i) 436 | { 437 | for (int im_j = 0; im_j < out_w; ++im_j) 438 | { 439 | auto z=tempo_(im_nb,im_i+half_size,im_j+half_size,filter)>>shift; 440 | SATURATE(z); 
441 | out_(im_nb,im_i,im_j,filter)=z; 442 | } 443 | } 444 | } 445 | } 446 | #endif 447 | 448 | #if __AVX512BW__ || __AVX512F__ 449 | template 450 | template void Conv2DTranspose::conv2dtranspose_simd512(int nb_filters, Tensor &out_, const Tensor &A, const Tensor &kernel) { 451 | std::cerr<<"TODO "<