├── src ├── Makefile ├── memoryBroker.cpp ├── utils.cpp ├── tensor.cpp └── contract.cpp ├── include ├── Makefile ├── tcl.h ├── memoryBroker.h ├── tcl_types.h ├── contract.h ├── tensor.h └── utils.h ├── .gitattributes ├── .gitignore ├── misc ├── tcl_1thread.png ├── tcl_24thread.png └── tcl_speedup.png ├── pythonAPI ├── tcl │ ├── __init__.py │ └── tcl.py └── setup.py ├── examples ├── Makefile ├── test_python.py └── contraction.cpp ├── Makefile ├── benchmark └── python │ └── benchmark.py ├── COPYING.LESSER ├── README.md └── LICENSE.txt /src/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | ${MAKE} -C ../ 3 | -------------------------------------------------------------------------------- /include/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | ${MAKE} -C ../ 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | benchmark/python/benchmark.sh linguist-vendored 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.dSYM 2 | *.pyc 3 | *tags 4 | *.o 5 | *.swp 6 | *.swo 7 | *.exe 8 | lib/ 9 | -------------------------------------------------------------------------------- /misc/tcl_1thread.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/springer13/tcl/HEAD/misc/tcl_1thread.png -------------------------------------------------------------------------------- /misc/tcl_24thread.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/springer13/tcl/HEAD/misc/tcl_24thread.png -------------------------------------------------------------------------------- /misc/tcl_speedup.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/springer13/tcl/HEAD/misc/tcl_speedup.png -------------------------------------------------------------------------------- /pythonAPI/tcl/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """TCL - Tensor Contraction Module based on the C++ tensor contraction library (TCL)""" 3 | from tcl import * 4 | -------------------------------------------------------------------------------- /examples/Makefile: -------------------------------------------------------------------------------- 1 | CXX_FLAGS=-O0 -g -std=c++11 2 | CXX_LINK= -L../lib/ -ltcl 3 | 4 | ifeq ($(CXX),icpc) 5 | CXX_FLAGS += -qopenmp -xhost 6 | else 7 | ifeq ($(CXX),g++) 8 | CXX_FLAGS += -fopenmp -mcpu=native 9 | else 10 | ifeq ($(CXX),clang++) 11 | CXX_FLAGS += -fopenmp -march=native 12 | endif 13 | endif 14 | endif 15 | 16 | scalar: 17 | ${MAKE} clean 18 | ${MAKE} all 19 | 20 | all: contraction.o 21 | ${CXX} contraction.o ${CXX_FLAGS} -o contraction.exe ${CXX_LINK} 22 | 23 | %.o: %.cpp 24 | ${CXX} ${CXX_FLAGS} -I../include/ -c $< -o $@ 25 | 26 | clean: 27 | rm -rf ./*.o ./*.exe 28 | -------------------------------------------------------------------------------- /pythonAPI/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from distutils.core import setup 3 | import os 4 | 5 | OKGREEN = '\033[92m' 6 | FAIL = '\033[91m' 7 | WARNING = '\033[93m' 8 | ENDC = '\033[0m' 9 | 10 | setup(name="tcl", 11 | version="0.1.0", 12 | description="Tensor Contraction Library", 13 | author="Paul Springer", 14 | author_email="springer@aices.rwth-aachen.de", 15 | packages=["tcl"] 16 | ) 17 | 18 | print("") 19 | output = "# "+ FAIL + "IMPORTANT"+ENDC+": execute 'export TCL_ROOT=%s/../' #"%(os.path.dirname(os.path.realpath(__file__))) 20 | 
print('#'*(len(output)-2*len(FAIL)+1)) 21 | print(output) 22 | print('#'*(len(output)-2*len(FAIL)+1)) 23 | print("") 24 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | INCLUDE_FLAGS=-I./include/ -I${HPTT_ROOT}/include/ 2 | 3 | BLAS_LIB=-L${BLIS_ROOT}/lib -lblis 4 | 5 | # MKL 6 | #BLAS_LIB=-L${MKLROOT}/lib/intel64 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread -lm -ldl 7 | #INCLUDE_FLAGS +=-I${MKLROOT}/include 8 | 9 | CXX_LINK=-L${HPTT_ROOT}/lib -lhptt ${BLAS_LIB} 10 | CXX_FLAGS=-O3 -std=c++11 -fPIC ${INCLUDE_FLAGS} -fopenmp -march=native 11 | 12 | scalar: 13 | ${MAKE} clean 14 | ${MAKE} scalar2 15 | 16 | scalar2: all 17 | 18 | SRC=$(wildcard ./src/*.cpp) 19 | OBJ=$(SRC:.cpp=.o) 20 | 21 | all: ${OBJ} 22 | mkdir -p lib 23 | ${CXX} ${OBJ} ${CXX_FLAGS} -o lib/libtcl.so -shared ${CXX_LINK} 24 | ar rvs lib/libtcl.a ${OBJ} 25 | 26 | %.o: %.cpp 27 | ${CXX} ${CXX_FLAGS} ${INCLUDE_PATH} -c $< -o $@ 28 | 29 | clean: 30 | rm -rf src/*.o lib/libtcl.so lib/libtcl.a 31 | -------------------------------------------------------------------------------- /include/tcl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 16 | */ 17 | 18 | #pragma once 19 | 20 | #include "tcl_types.h" 21 | #include "tensor.h" 22 | #include "contract.h" 23 | #include "memoryBroker.h" 24 | 25 | namespace tcl{ 26 | extern MemoryBroker memBroker; 27 | } 28 | -------------------------------------------------------------------------------- /include/memoryBroker.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 
16 | */ 17 | 18 | 19 | #pragma once 20 | #include 21 | #include 22 | 23 | namespace tcl 24 | { 25 | 26 | class MemoryBroker { 27 | public: 28 | MemoryBroker(); 29 | 30 | void alloc( size_t size ); 31 | char* requestMemory( size_t size ); 32 | void reset(); 33 | void release(); 34 | bool isInit() const; 35 | uint64_t size() const; 36 | 37 | private: 38 | 39 | char *ptr; 40 | size_t totalSize; 41 | size_t currentOffset; 42 | }; 43 | 44 | } 45 | -------------------------------------------------------------------------------- /examples/test_python.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | if sys.version_info[0] < 3: 4 | import tcl 5 | else: 6 | import tcl.tcl as tcl 7 | 8 | print("Testing float64") 9 | for i in range(100): 10 | A = np.random.rand(4,3) 11 | B = np.random.rand(3,3) 12 | C = np.zeros((4,3), dtype=np.float) 13 | np_result = np.einsum('ij,jk->ik', A, B) 14 | tcl_result = tcl.einsum('ij,jk->ik', A, B) 15 | tcl.tensorMult(1., A, 'i,j', B, 'j,k', 0., C, 'i,k') 16 | assert(np.linalg.norm(np_result - tcl_result ) < 1e-15) 17 | assert(np.linalg.norm(np_result - C) < 1e-15) 18 | 19 | print("Test success") 20 | 21 | print("Testing complex128") 22 | for i in range(100): 23 | A = np.random.rand(4,3) + 1j*np.random.rand(4,3) 24 | B = np.random.rand(3,3) + 1j*np.random.rand(3,3) 25 | C = np.zeros((4,3), dtype=A.dtype) 26 | np_result = np.einsum('ij,jk->ik', A, B) 27 | tcl_result = tcl.einsum('ij,jk->ik', A, B) 28 | tcl.tensorMult(1., A, 'i,j', B, 'j,k', 0., C, 'i,k') 29 | assert(np.linalg.norm(np_result - tcl_result ) < 1e-14) 30 | assert(np.linalg.norm(np_result - C) < 1e-14) 31 | 32 | print("Test success") 33 | 34 | print("Testing multiple Indices") 35 | for i in range(100): 36 | U = np.random.rand(2, 2, 4) 37 | T = np.random.rand(2, 2, 2, 2, 2) 38 | np_result = np.einsum('pqa,qldru->pldrua', U, T) 39 | tcl_result = tcl.einsum('pqa,qldru->pldrua', U, T) 40 | 
assert(np.linalg.norm(np_result - tcl_result) < 1e-15) 41 | 42 | print("Test success") 43 | 44 | -------------------------------------------------------------------------------- /include/tcl_types.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 
16 | */ 17 | 18 | 19 | #ifndef TCL_TYPES_H 20 | #define TCL_TYPES_H 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | namespace tcl 30 | { 31 | 32 | typedef int sizeType; 33 | typedef std::list indicesType; 34 | 35 | using FloatComplex = std::complex; 36 | using DoubleComplex = std::complex; 37 | 38 | enum error { 39 | SUCCESS = 0, 40 | INVALID_PARAMETER_0 = 1, 41 | INVALID_PARAMETER_1 = 2, 42 | INVALID_PARAMETER_2 = 3, 43 | INVALID_PARAMETER_3 = 4, 44 | INVALID_PARAMETER_4 = 5, 45 | INVALID_PARAMETER_5 = 6, 46 | INVALID_PARAMETER_6 = 7, 47 | INVALID_PARAMETER_7 = 8, 48 | INVALID_TENSOR_SIZE = 9, 49 | LOOP_INDEX_DETECTED = 10, 50 | TENSOR_CONTRACTION_UNSUPPORTED = 11, 51 | INTERNAL_ERROR_DIM_MISMATCH, 52 | INTERNAL_ERROR, 53 | }; 54 | 55 | } 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /src/memoryBroker.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 
16 | */ 17 | 18 | 19 | #include 20 | 21 | #include "memoryBroker.h" 22 | 23 | namespace tcl 24 | { 25 | MemoryBroker memBroker; 26 | 27 | MemoryBroker::MemoryBroker() : ptr(nullptr), totalSize(0), currentOffset(0) {} 28 | 29 | void MemoryBroker::alloc( size_t size ) 30 | { 31 | int dummy = posix_memalign((void**)&(this->ptr), 4096, size); 32 | this->totalSize = size; 33 | this->currentOffset = 0; 34 | } 35 | 36 | char* MemoryBroker::requestMemory( size_t size ) 37 | { 38 | assert( this->currentOffset + size <= this->totalSize ); 39 | 40 | char* ret = &ptr[currentOffset]; 41 | currentOffset += size; 42 | return ret; 43 | } 44 | void MemoryBroker::reset() { this->currentOffset = 0; } 45 | 46 | void MemoryBroker::release() { 47 | free(this->ptr); 48 | this->totalSize = 0; 49 | this->currentOffset = 0; 50 | } 51 | 52 | bool MemoryBroker::isInit() const { return totalSize != 0; } 53 | uint64_t MemoryBroker::size() const { return totalSize; } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /examples/contraction.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 
16 | */ 17 | 18 | 19 | #include 20 | #include 21 | 22 | #include 23 | 24 | int main(int argc, char** argv) 25 | { 26 | tcl::sizeType m = 5; 27 | tcl::sizeType n = 4; 28 | tcl::sizeType k1 = 2; 29 | tcl::sizeType k2 = 3; 30 | tcl::sizeType l1 = 6; 31 | 32 | float *dataA, *dataB, *dataC; 33 | posix_memalign((void**) &dataA, 64, sizeof(float) * ((size_t)k2)*m*k1*l1); 34 | posix_memalign((void**) &dataB, 64, sizeof(float) * ((size_t)n)*k2*k1*l1); 35 | posix_memalign((void**) &dataC, 64, sizeof(float) * ((size_t)m)*n*l1); 36 | 37 | // Initialize tensors (data is not owned by the tensors) 38 | tcl::Tensor A({k1,m,k2,l1}, dataA); 39 | tcl::Tensor B({n,k2,k1,l1}, dataB); 40 | tcl::Tensor C({m,n,l1}, dataC); 41 | 42 | // Data initialization 43 | #pragma omp parallel for 44 | for(int i=0; i < A.getTotalSize(); ++i) 45 | dataA[i] = (i+1)*7% 100; 46 | #pragma omp parallel for 47 | for(int i=0; i < B.getTotalSize(); ++i) 48 | dataB[i] = (i+1)*13% 100; 49 | #pragma omp parallel for 50 | for(int i=0; i < C.getTotalSize(); ++i) 51 | dataC[i] = (i+1)*5% 100; 52 | 53 | float alpha = 2; 54 | float beta = 4; 55 | 56 | // tensor contarction: C_{m,n} = alpha * A_{k2,m,k1} * B_{n,k2,k1} + beta * C_{m,n} 57 | auto err = tcl::tensorMult( alpha, A["k1,m,k2,l1"], B["n,k2,k1,l1"], beta, C["m,n,l1"] ); 58 | if( err != tcl::SUCCESS ){ 59 | printf("ERROR: %s\n", tcl::getErrorString(err)); 60 | exit(-1); 61 | } 62 | 63 | for(int i=0; i < m; ++i){ 64 | for(int j=0; j < n; ++j) 65 | std::cout<< dataC[j * m + i] << " "; 66 | std::cout<< "\n"; 67 | } 68 | 69 | return 0; 70 | } 71 | -------------------------------------------------------------------------------- /include/contract.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 
| * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 16 | */ 17 | 18 | 19 | #pragma once 20 | 21 | /// \page pg_contract Tensor Contraction 22 | /// 23 | /// \section sec_contract Tensor Contraction 24 | /// A tensor contraction is the generalization of a matrix-matrix multiplication to multiple dimensions 25 | 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | #include "tcl_types.h" 36 | #include "tensor.h" 37 | 38 | namespace tcl{ 39 | 40 | template 41 | error tensorMult(const floatType alpha, const Tensor *A, 42 | const Tensor *B, 43 | const floatType beta, Tensor *C); 44 | } 45 | 46 | extern "C"{ 47 | void sTensorMult(const float alpha, const float *A, const long *sizeA, const long *outerSizeA, const char* indA, 48 | const float *B, const long *sizeB, const long *outerSizeB, const char* indB, 49 | const float beta , float *C, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor = 0); 50 | 51 | void dTensorMult(const double alpha, const double *A, const long *sizeA, const long *outerSizeA, const char* indA, 52 | const double *B, const long *sizeB, const long *outerSizeB, const char* indB, 53 | const double beta , double *C, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor = 0); 54 | 55 | void cTensorMult(const float _Complex alpha, const float _Complex *A, const long *sizeA, const long *outerSizeA, const char* indA, 56 | const float _Complex *B, const long *sizeB, const long *outerSizeB, 
const char* indB, 57 | const float _Complex beta , float _Complex *C, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor = 0); 58 | 59 | void zTensorMult(const double _Complex alpha, const double _Complex *A, const long *sizeA, const long *outerSizeA, const char* indA, 60 | const double _Complex *B, const long *sizeB, const long *outerSizeB, const char* indB, 61 | const double _Complex beta , double _Complex *C, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor = 0); 62 | 63 | void randomNumaAwareInit(float *data, const long *size, int dim); 64 | } 65 | 66 | -------------------------------------------------------------------------------- /benchmark/python/benchmark.py: -------------------------------------------------------------------------------- 1 | import sys 2 | if sys.version_info[0] < 3: 3 | import tcl 4 | else: 5 | import tcl.tcl as tcl 6 | 7 | import numpy as np 8 | import time 9 | import argparse 10 | OKGREEN = '\033[92m' 11 | FAIL = '\033[91m' 12 | ENDC = '\033[0m' 13 | 14 | sizeA = [] 15 | sizeB = [] 16 | sizeC = [] 17 | order = 'F' #'F': Fortran: Column-major; 'C' : C : row-major 18 | alpha = 1.0 19 | beta = 0.0 20 | floatType = np.float32 21 | 22 | parser = argparse.ArgumentParser(description='Benchmark for tensor contractions using TCL and NumPY.') 23 | 24 | parser.add_argument('indA', metavar='indicesA', type=str, help="comma separated list of characters, which denote the indices of A (e.g., 'u,m,n')") 25 | parser.add_argument('indB', metavar='indicesB', type=str, help="comma separated list of characters, which denote the indices of B (e.g., 'a,u')") 26 | parser.add_argument('indC', metavar='indicesC', type=str, help="comma separated list of characters, which denote the indices of C (e.g., 'n,m,a')") 27 | parser.add_argument('sizes', metavar='sizes', type=str, help=",,... 
(e.g., a:100,b:200)") 28 | parser.add_argument('--useRowMajor', action='store_true', help="Uses a row-major data layout; by default we use column-major.") 29 | parser.add_argument('--floatType', metavar='floatType', type=str, help="float type can be either 'double' or 'float' (default)") 30 | parser.add_argument('--alpha', type=float, help='alpha scalar (default: 1.0)') 31 | parser.add_argument('--beta', type=float, help='beta scalar (default: 0.0)') 32 | 33 | args = parser.parse_args() 34 | if( args.useRowMajor): 35 | order = 'C' 36 | if( args.alpha): 37 | alpha = float(args.alpha) 38 | if( args.beta): 39 | beta = float(args.beta) 40 | if( args.floatType ): 41 | if( args.floatType == 'double' ): 42 | floatType = np.float64 43 | else: 44 | floatType = np.float32 45 | 46 | sizes = {} 47 | gflops = 2./1e9 48 | for pair in args.sizes.split(","): 49 | idx = pair.split(':')[0] 50 | size = int(pair.split(':')[1]) 51 | gflops *= size 52 | sizes[idx] = size 53 | 54 | indA = "" 55 | indB = "" 56 | indC = "" 57 | for idx in args.indA.split(","): 58 | indA+= idx + "," 59 | sizeA.append(sizes[idx]) 60 | indA= indA[:-1] 61 | for idx in args.indB.split(","): 62 | indB+= idx + "," 63 | sizeB.append(sizes[idx]) 64 | indB = indB[:-1] 65 | for idx in args.indC.split(","): 66 | indC+= idx + "," 67 | sizeC.append(sizes[idx]) 68 | indC= indC[:-1] 69 | 70 | Ma = np.random.rand(2500**2).astype('f') 71 | Mb = np.random.rand(2500**2).astype('f') 72 | A = np.empty(sizeA, order=order, dtype=floatType) 73 | B = np.empty(sizeB, order=order, dtype=floatType) 74 | C = np.empty(sizeC, order=order, dtype=floatType) 75 | tcl.randomNumaAwareInit(A) 76 | tcl.randomNumaAwareInit(B) 77 | tcl.randomNumaAwareInit(C) 78 | 79 | 80 | indC_np = "" 81 | axesA = [] 82 | axesB = [] 83 | for idx in args.indA.split(","): 84 | found = 0 85 | posB = indB.replace(',','').find(idx) 86 | posA = indA.replace(',','').find(idx) 87 | if( posB != -1 ): #contracted index 88 | axesA.append(posA) 89 | axesB.append(posB) 90 | 
for idxB in args.indB.split(","): 91 | if idxB == idx: 92 | found = 1 93 | if( not found ): 94 | indC_np += idx 95 | for idx in args.indB.split(","): 96 | found = 0 97 | for idxB in args.indA.split(","): 98 | if idxB == idx: 99 | found = 1 100 | if( not found ): 101 | indC_np += idx 102 | perm = [] 103 | for idx in indC.replace(',',''): 104 | perm.append(indC_np.find(idx)) 105 | 106 | # print(indA, indB, indC_np, indC.replace(',','')) 107 | # print(perm, axesA, axesB) 108 | 109 | timeTCL = 1e100 110 | for i in range(3): 111 | Mb = Ma *1.1 + Mb #trash cache 112 | s = time.time() 113 | tcl.tensorMult( alpha, A, indA, B, indB, beta, C, indC) 114 | timeTCL = min(timeTCL, time.time() - s) 115 | timeNP = 1e100 116 | 117 | for i in range(3): 118 | Mb = Ma *1.1 + Mb #trash cache 119 | s = time.time() 120 | # C_ = np.einsum("%s,%s->%s"%(indA.replace(',',''),indB.replace(',',''),indC.replace(',','')), A, B) 121 | if( indC.replace(',','') != indC_np ): #transpose required 122 | C_ = np.transpose(np.tensordot(A, B, axes=(axesA,axesB)),perm).copy(order=order) 123 | else: 124 | C_ = np.tensordot(A, B, axes=(axesA,axesB)) 125 | timeNP = min(time.time() - s, timeNP) 126 | 127 | if( not tcl.equal(C, C_, 1000) ): 128 | print("ERROR: validation" + FAIL + " failed!!!" 
+ ENDC) 129 | print(indC.replace(',','') != indC_np) #transpose required 130 | else: 131 | print("%.2f GFLOPS %.2f GFLOPS speedup: %.2fx"%( gflops/timeTCL, gflops/timeNP, timeNP/ timeTCL)) 132 | print(indC.replace(',','') != indC_np) #transpose required 133 | -------------------------------------------------------------------------------- /src/utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 
16 | */ 17 | 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | 27 | namespace tcl 28 | { 29 | 30 | const char* getErrorString( error err ){ 31 | switch (err){ 32 | case SUCCESS: 33 | return "SUCCESS"; 34 | case INVALID_PARAMETER_0: 35 | return "Parameter 0 is invalid."; 36 | case INVALID_PARAMETER_1: 37 | return "Parameter 1 is invalid."; 38 | case INVALID_PARAMETER_2: 39 | return "Parameter 2 is invalid."; 40 | case INVALID_PARAMETER_3: 41 | return "Parameter 3 is invalid."; 42 | case INVALID_PARAMETER_4: 43 | return "Parameter 4 is invalid."; 44 | case INVALID_PARAMETER_5: 45 | return "Parameter 5 is invalid."; 46 | case INVALID_PARAMETER_6: 47 | return "Parameter 6 is invalid."; 48 | case INVALID_PARAMETER_7: 49 | return "Parameter 7 is invalid."; 50 | case INVALID_TENSOR_SIZE: 51 | return "Tensor size invalid. Mismatch between the sizes of two indices."; 52 | case INTERNAL_ERROR: 53 | return "Internal error."; 54 | case TENSOR_CONTRACTION_UNSUPPORTED: 55 | return "The specified tensor contraction is not yet supported. 
Please open a ticket saying that you need this type of tensor contraction."; 56 | default: 57 | return "Unkown error."; 58 | } 59 | } 60 | int getNumThreads(){ 61 | auto tmp = std::getenv("OMP_NUM_THREADS"); 62 | if( tmp ) 63 | return std::max(1, atoi(tmp)); 64 | else 65 | return 1; 66 | } 67 | 68 | bool isIdentity(const std::vector &perm){ 69 | for(int i=0; i < perm.size(); ++i) 70 | if( i != perm[i] ) 71 | return false; 72 | return true; 73 | } 74 | 75 | template<> 76 | void gemm(const char *transa, const char *transb, 77 | const sizeType *m, const sizeType *n, const sizeType *k, 78 | const float *alpha, const float *a, 79 | const sizeType *lda, const float *b, const sizeType *ldb, 80 | const float *beta, float *c, const sizeType *ldc) 81 | { 82 | #ifdef DEBUG 83 | std::cout<< "GEMM: " << transa+'\0' << " "<< transb+'\0' << " "<< *m << " " << *n << " " << *k << std::endl; 84 | #endif 85 | sgemm_(transa, transb, m, n, k, 86 | alpha, a, lda, b, ldb, 87 | beta, c, ldc); 88 | } 89 | 90 | template<> 91 | void gemm(const char *transa, const char *transb, 92 | const sizeType *m, const sizeType *n, const sizeType *k, 93 | const double *alpha, const double *a, 94 | const sizeType *lda, const double *b, const sizeType *ldb, 95 | const double *beta, double *c, const sizeType *ldc) 96 | { 97 | #ifdef DEBUG 98 | std::cout<< "GEMM: " << transa+'\0' << " "<< transb+'\0' << " "<< *m << " " << *n << " " << *k << std::endl; 99 | #endif 100 | dgemm_(transa, transb, m, n, k, 101 | alpha, a, lda, b, ldb, 102 | beta, c, ldc); 103 | } 104 | 105 | template<> 106 | void gemm(const char *transa, const char *transb, 107 | const sizeType *m, const sizeType *n, const sizeType *k, 108 | const FloatComplex *alpha, const FloatComplex *a, 109 | const sizeType *lda, const FloatComplex *b, const sizeType *ldb, 110 | const FloatComplex *beta, FloatComplex *c, const sizeType *ldc) 111 | { 112 | #ifdef DEBUG 113 | std::cout<< "GEMM: " << transa+'\0' << " "<< transb+'\0' << " "<< *m << " " << *n << " 
" << *k << std::endl; 114 | #endif 115 | cgemm_(transa, transb, m, n, k, 116 | (const float _Complex*) alpha, (const float _Complex*)a, lda, (const float _Complex*)b, ldb, 117 | (const float _Complex*) beta, (float _Complex*)c, ldc); 118 | } 119 | 120 | template<> 121 | void gemm(const char *transa, const char *transb, 122 | const sizeType *m, const sizeType *n, const sizeType *k, 123 | const DoubleComplex *alpha, const DoubleComplex *a, 124 | const sizeType *lda, const DoubleComplex *b, const sizeType *ldb, 125 | const DoubleComplex *beta, DoubleComplex *c, const sizeType *ldc) 126 | { 127 | #ifdef DEBUG 128 | std::cout<< "GEMM: " << transa+'\0' << " "<< transb+'\0' << " "<< *m << " " << *n << " " << *k << std::endl; 129 | #endif 130 | zgemm_(transa, transb, m, n, k, 131 | (const double _Complex*) alpha, (const double _Complex*)a, lda, (const double _Complex*)b, ldb, 132 | (const double _Complex*) beta, (double _Complex*)c, ldc); 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/tensor.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 
16 | */ 17 | 18 | 19 | #include 20 | 21 | namespace tcl 22 | { 23 | template 24 | Tensor Tensor::getSubTensor( 25 | const indicesType &indices, 26 | const std::vector &size) const 27 | { 28 | if( size.size() > 0 && indices.size() != size.size() ) 29 | throw std::invalid_argument( 30 | "The size of the first and second paramter does not match." ); 31 | if( indices.size() <= 0 || indices.size() > _indices.size() ) 32 | throw std::invalid_argument( 33 | "The first parameter does not match the size of the parent tensor." ); 34 | 35 | std::vector subSize(_size); 36 | int j = 0; 37 | auto itSub = indices.begin(); 38 | int i = 0; 39 | for( auto idx : _indices ) 40 | { 41 | if(itSub != indices.end() && idx == *itSub) 42 | { 43 | if( size.size() > 0 && (size[j] > _size[i] || size[j] <= 0) ) 44 | throw std::invalid_argument( 45 | "Specified size of the subtensor is invalid. It must be smaller than the size of the parent tensor." ); 46 | 47 | if( size.size() > 0 ) 48 | subSize[i] = size[j]; 49 | ++j; 50 | itSub++; 51 | } else 52 | subSize[i] = 1; 53 | ++i; 54 | } 55 | if( j != indices.size() ) 56 | throw std::invalid_argument( "Some indices could not be found." 
); 57 | 58 | return Tensor(subSize, _data, _outerSize, _indices, _offsets); 59 | } 60 | 61 | template 62 | size_t Tensor::getTotalSize( const indicesType &indices ) const 63 | { 64 | size_t product = 1; 65 | if( indices.size() == 0 ) 66 | { 67 | for(int i=0; i < _size.size(); ++i) 68 | product *= _size[i]; 69 | } else { 70 | assert( _size.size() == _indices.size() ); 71 | 72 | int i = 0; 73 | for(auto idx : _indices) 74 | { 75 | if( find(idx, indices) ) 76 | product *= _size[i]; //only multiply if the dimsion was found in indices 77 | i++; 78 | } 79 | } 80 | return product; 81 | } 82 | 83 | 84 | template 85 | sizeType Tensor::getStride( const std::string &idx ) const 86 | { 87 | sizeType stride = 1; 88 | int i =0; 89 | for(auto a : _indices) 90 | { 91 | if( a == idx ) 92 | return stride; 93 | else 94 | stride *= _outerSize[i]; 95 | ++i; 96 | } 97 | 98 | return 0; // not found, this tensor does not depend on the given index 99 | } 100 | 101 | template 102 | sizeType Tensor::getSize( const std::string &idx ) const 103 | { 104 | int i =0; 105 | for(auto a : _indices) 106 | { 107 | if( a == idx ) 108 | return _size[i]; 109 | ++i; 110 | } 111 | 112 | return 0; // not found 113 | } 114 | 115 | template 116 | Tensor::Tensor( const std::vector &size, 117 | floatType *data, 118 | const std::vector &outerSize, 119 | const indicesType &indices, 120 | const std::vector &offsets 121 | ) : 122 | _data(data), 123 | _size(size), 124 | _outerSize(outerSize), 125 | _indices(indices), 126 | _offsets(offsets) 127 | { 128 | if( _outerSize.size() == 0 ) 129 | _outerSize = _size; 130 | if( _offsets.size() == 0 ) 131 | _offsets = std::vector(_size.size(),0); 132 | 133 | if( _data == nullptr ){ 134 | size_t totalSize = 1; 135 | for(auto s : _outerSize) 136 | totalSize *= s; 137 | int dummy = posix_memalign((void**) &_data, 64, sizeof(floatType) * totalSize); 138 | } 139 | } 140 | 141 | 142 | template 143 | void Tensor::print() const 144 | { 145 | std::cout<< "---------------------\n"; 146 
| printVector(_indices, "Indices"); 147 | printVector(_size, "Size"); 148 | printVector(_outerSize, "Outer size"); 149 | printVector(_offsets, "Offsets"); 150 | std::cout<< "---------------------\n"; 151 | } 152 | 153 | template class Tensor; 154 | template class Tensor; 155 | template class Tensor; 156 | template class Tensor; 157 | } 158 | 159 | -------------------------------------------------------------------------------- /include/tensor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 16 | */ 17 | 18 | 19 | #ifndef TCL_TENSOR_H 20 | #define TCL_TENSOR_H 21 | 22 | /// \class 23 | /// Tensor 24 | /// 25 | /// \page pg_tensor Tensor 26 | /// 27 | /// \section sec_tensor Tensor 28 | /// A tensor is a symbolic representation of a tensor; it does not own the 29 | /// corresponding data. 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | 41 | #include 42 | 43 | #include "tcl_types.h" 44 | #include "utils.h" 45 | 46 | namespace tcl{ 47 | 48 | /*! 49 | * \brief This class represents a tensor; a tensor does _not_ own the data, it 50 | * merely interprets the data as a multidimensional-array. 
51 | * 52 | * A tensor stores information about its dimensionality, size and a pointer 53 | * to its data. 54 | */ 55 | 56 | template 57 | class Tensor 58 | { 59 | public: 60 | 61 | Tensor() : _data(nullptr) {}; 62 | 63 | Tensor( const std::vector &size, 64 | floatType *data, 65 | const std::vector &outerSize = {}, 66 | const indicesType &indices = {}, 67 | const std::vector &offsets = {} 68 | ); 69 | 70 | Tensor( const Tensor& other ) : _data(other._data), 71 | _size(other._size), 72 | _outerSize(other._outerSize), 73 | _indices(other._indices), 74 | _offsets(other._offsets) 75 | {} 76 | ~Tensor() { } 77 | 78 | /** 79 | * Return a subtensor that is that is spanned by the indices. For 80 | * example, if this tensor is A_{m,n,k} \in \mathbb{R}^{M \times N \times K}} 81 | * and indices = {m,k}, then this function would return the subtensor 82 | * A_{m,k} \in \mathbb{R}^{M \times K} with a change outer size for the 83 | * m-dimension = M * N as oppossed to just M. 84 | * 85 | * \param[in] indices Indices of the desired subtensor. The indices 86 | * have to be in the same order as they appear in the parent tensor. 87 | * Moreover: 1 <= indices.size() <= this.getDim(). 88 | * \param[in] size Size of each dimension of the subtensor. size.size() == indices.size(). 
89 | */ 90 | Tensor getSubTensor( 91 | const indicesType &indices, 92 | const std::vector &size = {}) const; 93 | 94 | /** 95 | * Calculate the product of the sizes of the indices specified by 'indices' 96 | * \return Product of the sizes of the specified indices 97 | */ 98 | size_t getTotalSize( const indicesType &indices = {} ) const; 99 | 100 | /** 101 | * \return The stride of the specified index 102 | */ 103 | sizeType getStride( const std::string &idx ) const; 104 | 105 | /** 106 | * \return The size of the dimension corresponding to the specified index 107 | */ 108 | sizeType getSize( const std::string &idx ) const; 109 | 110 | const std::vector& getSize() const { return _size; } 111 | const std::vector& getOuterSize() const { return _outerSize; } 112 | const indicesType& getIndices() const { return _indices; } 113 | int getDim() const { return _size.size(); } 114 | floatType* getData() const { 115 | sizeType offset = 0; 116 | sizeType stride = 1; 117 | for(int i=0; i < _outerSize.size(); ++i){ 118 | offset += stride * _offsets[i]; 119 | stride *= _outerSize[i]; 120 | } 121 | 122 | return &_data[offset]; 123 | } 124 | int getIndexPos(std::string index) const { 125 | int pos = 0; 126 | for( auto idx : _indices ) 127 | if( idx == index ) 128 | return pos; 129 | else 130 | pos++; 131 | 132 | return -1; 133 | } 134 | 135 | void setOffset(int indexPos, sizeType offset) { 136 | _offsets[indexPos] = offset; 137 | } 138 | void setData(floatType *data) { _data = data; } 139 | 140 | Tensor* operator[] (const std::string &&indices){ 141 | split(indices, ',', _indices); 142 | return this; 143 | } 144 | Tensor* operator[] (const std::string &indices){ 145 | split(indices, ',', _indices); 146 | return this; 147 | } 148 | 149 | void print() const; 150 | 151 | private: 152 | 153 | /*************************************** 154 | * private member functions 155 | ***************************************/ 156 | 157 | floatType * _data; 158 | std::vector _size; 159 | std::vector 
_outerSize; 160 | indicesType _indices; 161 | std::vector _offsets; 162 | }; 163 | } 164 | 165 | #endif 166 | -------------------------------------------------------------------------------- /pythonAPI/tcl/tcl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ctypes 3 | from ctypes import cdll 4 | import os 5 | import random 6 | import re 7 | 8 | TCL_ROOT = "" 9 | try: 10 | TCL_ROOT = os.environ['TCL_ROOT'] 11 | except: 12 | print("[TCL] ERROR: TCL_ROOT environment variable is not set. Point TCL_ROOT to the folder which includes TCL_ROOT/lib/libtcl.so") 13 | exit(-1) 14 | 15 | # load TCL library 16 | lib = cdll.LoadLibrary(TCL_ROOT+"/lib/libtcl.so") 17 | 18 | def randomNumaAwareInit( A ): 19 | """ 20 | initializes the passed numpy.ndarray (which have to be created with 21 | numpy.empty) and initializes it with random data in paralle such that the 22 | pages are equally distributed among the numa nodes 23 | """ 24 | lib.randomNumaAwareInit( ctypes.c_void_p(A.ctypes.data), 25 | ctypes.cast(A.ctypes.shape, ctypes.POINTER(ctypes.c_voidp)), 26 | ctypes.c_int32(A.ndim) ) 27 | 28 | 29 | def tensorMult( alpha, A, indicesA, B, indicesB, beta, C, indicesC): 30 | """ 31 | This function computes the tensor contraction of A and B, yielding C. 32 | The tensor contraction is of the form: 33 | C[indicesC] = alpha * A[indicesA] * B[indicesB] + beta * C[indicesC] 34 | 35 | where alpha and beta are scalors and A, B, and C correspond to arbitrary 36 | dimensional arrays (i.e., tensors). The dimensionality of A, B, and C 37 | depends on their indices (which need to be separated by commas). 38 | 39 | For instance, the tensor contraction C[m1,n1,m2] = 1.3 * A[k1,m1,k2,m2] * B[k1,k2,n1] 40 | would be represented as: tensorMult(1.3, A, "k1,m1,k2,m2", B, 41 | "k1,k2,n1", 0.0, C, "m1,n1,m2"). 
 42 | """ 43 | 44 | dataA = ctypes.c_void_p(A.ctypes.data) 45 | sizeA = ctypes.cast(A.ctypes.shape, ctypes.POINTER(ctypes.c_voidp)) 46 | outerSizeA = sizeA 47 | dataB = ctypes.c_void_p(B.ctypes.data) 48 | sizeB = ctypes.cast(B.ctypes.shape, ctypes.POINTER(ctypes.c_voidp)) 49 | outerSizeB = sizeB 50 | dataC = ctypes.c_void_p(C.ctypes.data) 51 | sizeC = ctypes.cast(C.ctypes.shape, ctypes.POINTER(ctypes.c_voidp)) 52 | outerSizeC = sizeC 53 | indicesA = indicesA.encode('utf-8') 54 | indicesB = indicesB.encode('utf-8') 55 | indicesC = indicesC.encode('utf-8') 56 | indicesA = ctypes.c_char_p(indicesA) 57 | indicesB = ctypes.c_char_p(indicesB) 58 | indicesC = ctypes.c_char_p(indicesC) 59 | useRowMajor = 0 60 | if( A.flags['C_CONTIGUOUS'] ): 61 | useRowMajor = 1 62 | 63 | if( A.dtype == np.float32 ): 64 | lib.sTensorMult(ctypes.c_float(alpha), dataA, sizeA, outerSizeA, indicesA, 65 | dataB, sizeB, outerSizeB, indicesB, 66 | ctypes.c_float(beta) , dataC, sizeC, outerSizeC, indicesC, useRowMajor) 67 | elif( A.dtype == np.float64 ): 68 | lib.dTensorMult(ctypes.c_double(alpha), dataA, sizeA, outerSizeA, indicesA, 69 | dataB, sizeB, outerSizeB, indicesB, 70 | ctypes.c_double(beta) , dataC, sizeC, outerSizeC, indicesC, useRowMajor) 71 | elif( A.dtype == np.complex64 ): 72 | lib.cTensorMult(ctypes.c_float(alpha), dataA, sizeA, outerSizeA, indicesA, 73 | dataB, sizeB, outerSizeB, indicesB, 74 | ctypes.c_float(beta), dataC, sizeC, outerSizeC, indicesC, useRowMajor) 75 | elif( A.dtype == np.complex128 ): 76 | lib.zTensorMult(ctypes.c_double(alpha), dataA, sizeA, outerSizeA, indicesA, 77 | dataB, sizeB, outerSizeB, indicesB, 78 | ctypes.c_double(beta) , dataC, sizeC, outerSizeC, indicesC, useRowMajor) 79 | else: 80 | raise NotImplementedError 81 | 82 | def equal(A, B, numSamples=-1): 83 | """ Ensures that all elements of A and B are pretty much equal (due to limited machine precision) 84 | 85 | Parameter: 86 | numSamples: number of random samples to compare (-1: all).
This value is used to approximate this function and speed the result up." 87 | """ 88 | threshold = 1e-4 89 | A = np.reshape(A, A.size) 90 | B = np.reshape(B, B.size) 91 | error = 0 92 | samples = list(range(A.size)) 93 | if( numSamples != -1 ): 94 | samples = random.sample(samples, min(A.size,numSamples)) 95 | 96 | for i in samples: 97 | Aabs = abs(A[i]); 98 | Babs = abs(B[i]); 99 | absmax = max(Aabs, Babs); 100 | diff = Aabs - Babs; 101 | if( diff < 0 ): 102 | diff *= -1 103 | if(diff > 0): 104 | relError = diff / absmax; 105 | if(relError > 4e-5 and min(Aabs,Babs) > threshold ): 106 | error += 1 107 | return error == 0 108 | 109 | def einsum(string, *arg_list): 110 | """ 111 | A wrapper around np.einsum. We call TensorMult in TCL library if a tensor multiplication is called, i.e. two operands in einsum. 112 | For all other cases, we return the result of np.einsum. 113 | For simplicity, we always cast the nd.array to be contiguousarray. One could also choose to cast all nd.array to Fortran array. 114 | 115 | Input Parameters: 116 | We take exactly the same format as np.einsum function, with a string and the operands. 117 | """ 118 | if len(list(arg_list)) != 2: 119 | return np.einsum(string, *arg_list) 120 | else: 121 | T_a, T_b = arg_list 122 | 123 | if np.isfortran(T_a) and np.isfortran(T_b): 124 | order = 'F' 125 | # or do np.asfortranarray(T_a) (T_b) order ='F' 126 | # see also https://stackoverflow.com/questions/27567876/copying-array-changes-data-from-c-contiguous-to-f-contiguous 127 | else: 128 | # There are some situations where T is not Contiguous nor Fortran 129 | T_a = np.ascontiguousarray(T_a) 130 | T_b = np.ascontiguousarray(T_b) 131 | order = 'C' 132 | 133 | # [TODO] benchmark whether 'C', 'F' affect the result 134 | 135 | # [TODO] casting to higher precision data type?
136 | floatType = T_a.dtype 137 | if floatType != T_b.dtype: 138 | raise NotImplementedError 139 | 140 | np_indA, np_indB, np_indC = re.split(' , | ,|, |,|->', string) 141 | # [TODO] implement this ? 142 | if len(np_indC) == 0: 143 | return np.einsum(string, *arg_list) 144 | 145 | sizes = {} 146 | shapeA = T_a.shape 147 | for idx, ind in enumerate(np_indA): 148 | sizes[ind] = shapeA[idx] 149 | 150 | shapeB = T_b.shape 151 | for idx, ind in enumerate(np_indB): 152 | sizes[ind] = shapeB[idx] 153 | 154 | indA = ','.join(list(np_indA)) 155 | indB = ','.join(list(np_indB)) 156 | indC = ','.join(list(np_indC)) 157 | sizeC = [sizes[idx] for idx in np_indC] 158 | 159 | # For most case this is enough. 160 | # T_c = np.empty(sizeC, order=order, dtype=floatType) 161 | # However, it seems like for complex datatype, we need to zeros out T_c 162 | # such that it would not return wierd value. 163 | # This should be because of the interface in tensorMult is not written 164 | # for complex datatype 165 | T_c = np.zeros(sizeC, order=order, dtype=floatType) 166 | 167 | tensorMult(1, T_a, indA, T_b, indB, 0., T_c, indC) 168 | return T_c 169 | -------------------------------------------------------------------------------- /COPYING.LESSER: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 
14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 
48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 
90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | 167 | -------------------------------------------------------------------------------- /include/utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 
16 | */ 17 | 18 | 19 | #ifndef TCL_UTILS_H 20 | #define TCL_UTILS_H 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include "tcl_types.h" 29 | 30 | extern "C" 31 | { 32 | void sgemm_(const char *transa, const char *transb, 33 | const int *m, const int *n, const int *k, 34 | const float *alpha, const float *a, 35 | const int *lda, const float *b, const int *ldb, 36 | const float *beta, float *c, const int *ldc); 37 | void dgemm_(const char *transa, const char *transb, 38 | const int *m, const int *n, const int *k, 39 | const double *alpha, const double *a, 40 | const int *lda, const double *b, const int *ldb, 41 | const double *beta, double *c, const int *ldc); 42 | void cgemm_(const char *transa, const char *transb, 43 | const int *m, const int *n, const int *k, 44 | const float _Complex *alpha, const float _Complex *a, 45 | const int *lda, const float _Complex *b, const int *ldb, 46 | const float _Complex *beta, float _Complex *c, const int *ldc); 47 | void zgemm_(const char *transa, const char *transb, 48 | const int *m, const int *n, const int *k, 49 | const double _Complex *alpha, const double _Complex *a, 50 | const int *lda, const double _Complex *b, const int *ldb, 51 | const double _Complex *beta, double _Complex *c, const int *ldc); 52 | } 53 | 54 | namespace tcl 55 | { 56 | 57 | template 58 | static double getZeroThreshold(); 59 | template<> 60 | double getZeroThreshold() { return 1e-16;} 61 | template<> 62 | double getZeroThreshold() { return 1e-16;} 63 | template<> 64 | double getZeroThreshold() { return 1e-6;} 65 | template<> 66 | double getZeroThreshold() { return 1e-6;} 67 | 68 | const char* getErrorString( error err ); 69 | 70 | /** 71 | * concatinates the two vectors a and b and stores the result into c 72 | */ 73 | template 74 | void concatinate(const T &a, 75 | const T &b, 76 | T &c) 77 | { 78 | auto endA = a.end(); 79 | for(auto it = a.begin(); it != endA; it++) 80 | c.emplace_back(*it); 81 | auto endB = 
b.end(); 82 | for(auto it = b.begin(); it != endB; it++) 83 | c.emplace_back(*it); 84 | } 85 | /** 86 | * intersects the two vectors a and b and stores the result into c 87 | */ 88 | template 89 | T intersect(const T &a, const T &b) 90 | { 91 | T c; 92 | for(auto x1 : a) 93 | for(auto x2 : b) 94 | if( x1 == x2 ) 95 | c.emplace_back(x1); 96 | return c; 97 | } 98 | 99 | /** 100 | * find x in l 101 | */ 102 | template 103 | bool find(const T &x, const L &l) //TODO templetize L? 104 | { 105 | for( auto a : l ) 106 | if( a == x ) 107 | return true; 108 | return false; 109 | } 110 | 111 | /** 112 | * find x in l 113 | */ 114 | template 115 | bool findPos(const T &x, const L &l) //TODO templetize L? 116 | { 117 | int count = 0; 118 | for( auto a : l ) 119 | if( a == x ) 120 | return count; 121 | else 122 | count++; 123 | return -1; 124 | } 125 | 126 | /** 127 | * remove all elements of b from a 128 | */ 129 | template 130 | T setMinus(const T &a, const T &b) 131 | { 132 | T c; 133 | for(auto item : a) 134 | if( !find(item, b) ) 135 | c.emplace_back(item); 136 | return c; 137 | } 138 | 139 | /** 140 | * \return elements of 'toBeSorted' sorted according to their order in 'order' 141 | */ 142 | template 143 | void sortAccordingTo(const T &toBeSorted, const T &order, T &sorted ) 144 | { 145 | sorted.clear(); 146 | for( auto elem : order ) 147 | if( find(elem, toBeSorted) ) 148 | sorted.emplace_back(elem); 149 | assert( sorted.size() == toBeSorted.size() ); 150 | } 151 | 152 | /** 153 | * \return true iff all indices of 'subset' appear contiguously in 'indices' 154 | */ 155 | template 156 | bool indicesContiguous(const T &subset, const T &indices) 157 | { 158 | if( subset.size() <= 0 ) 159 | return true; 160 | int contiguousCount = 0; 161 | for( auto a : indices ){ 162 | if( find(a, subset) ){ 163 | contiguousCount++; 164 | }else if( contiguousCount == subset.size() ) 165 | return true; 166 | else if( contiguousCount > 0 ) 167 | return false; 168 | } 169 | return 
contiguousCount == subset.size(); 170 | } 171 | 172 | //! return the number of threads used within TCL 173 | int getNumThreads(); 174 | 175 | /** 176 | * checks if the provided permutation is the identity 177 | */ 178 | bool isIdentity(const std::vector &perm); 179 | 180 | /** 181 | * permute the input according to the permutation perm. 182 | */ 183 | template 184 | std::vector permute(const std::vector &perm, const std::vector &input) 185 | { 186 | assert( input.size() == perm.size() ); 187 | 188 | std::vector output; 189 | for(int i=0; i < perm.size(); ++i){ 190 | assert( perm[i] >= 0 && perm[i] < input.size() ); 191 | output.emplace_back(input[perm[i]]); 192 | } 193 | return output; 194 | } 195 | 196 | /** 197 | * Split the input string by the delimiter 198 | * 199 | * \param[in] str input string 200 | * \param[in] delim delimiter 201 | * \return output vector cotainign all tokens 202 | */ 203 | template 204 | void split( const std::string &str, char delim, T &output ) 205 | { 206 | output.clear(); 207 | std::string token; 208 | std::stringstream stream; 209 | stream.str(str); 210 | while ( std::getline(stream, token, delim) ) 211 | output.emplace_back(token); 212 | } 213 | /** 214 | * \param[in] input Indices of the input tensor 215 | * \param[in] output Indices of the output tensor 216 | * \return the permutation that is required to rearange the indices of input to match the order of the indices specified by 'output'. 
217 | */ 218 | template 219 | std::vector getPermutation( const T &input, const T &output) 220 | { 221 | assert( input.size() == output.size() ); 222 | 223 | std::vector perm; 224 | for(auto b : output) 225 | { 226 | int pos = 0; 227 | for(auto a : input) 228 | { 229 | if( b == a ){ 230 | perm.emplace_back(pos); 231 | break; 232 | } 233 | pos++; 234 | } 235 | } 236 | assert( perm.size() == output.size() ); 237 | 238 | return perm; 239 | } 240 | 241 | template 242 | void printVector(const T &vec, std::string label) 243 | { 244 | std::cout<< label << ": "; 245 | for( auto a : vec) 246 | std::cout<< a << " "; 247 | std::cout<< "\n"; 248 | } 249 | 250 | template 251 | void gemm(const char *transa, const char *transb, 252 | const sizeType *m, const sizeType *n, const sizeType *k, 253 | const floatType *alpha, const floatType *a, 254 | const sizeType *lda, const floatType *b, const sizeType *ldb, 255 | const floatType *beta, floatType *c, const sizeType *ldc); 256 | } 257 | 258 | #endif 259 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This project is no longer maintained, please consider to use its GPU counterpart: https://developer.nvidia.com/cutensor 2 | 3 | # Tensor Contraction Library (TCL) for C++ (and Python) 4 | 5 | TCL is a C++ library for high-performance tensor contractions; TCL also includes 6 | a wrapper for python and can be easily integrated into native python code. 7 | 8 | From a computational perspective, tensors 9 | can be interpreted as a generalization of matrices to multiple dimensions or simply as 10 | multidimensional arrays; likewise, tensor contractions 11 | are a generalization of the matrix-matrix multiplication to higher 12 | dimensions. 
For instance, A[i,k], B[k,j] and C[i,j] denote two-dimensional 13 | tensors (i.e., matrices) and C[i,j] = A[i,k] * B[k,j] represents a tensor 14 | contraction where the sum over 'k' as well as the loops over 'i' and 'j' are 15 | implicit. Further examples of tensor contractions are: 16 | 17 | * C[i0,j0,j1] = A[i0,k0] * B[j1,k0,j0] 18 | * C[i0,j0,j1,i1] = A[i0,k0,i1] * B[j1,k0,j0] 19 | * C[i0,j0,j1,i1] = A[k0,i0,k1,i1] * B[k1,j1,k0,j0] 20 | * C[i1,j1,j0,i0] = A[k0,i0,k1,k2,i1] * B[k1,j1,k0,j0,k2] 21 | * ... 22 | 23 | You can find additional information on tensor contractions in the [paper](https://arxiv.org/abs/1607.00145) listed 24 | below. 25 | 26 | # Requirements 27 | 28 | * A C++ compiler with c++11 support (I've tested it with: g++ 5.1.0, icpc 17.0.2). 29 | * Some BLAS library (e.g., [BLIS](https://github.com/flame/blis), ATLAS, MKL, 30 | OpenBlas) 31 | * [HPTT](https://github.com/springer13/hptt) for high-performance tensor transpositions 32 | 33 | # Install 34 | 35 | ## C/C++ library 36 | 37 | Install TCL's dependencies (see above). Then clone the repository into a desired directory and change to that location: 38 | 39 | git clone https://github.com/springer13/tcl.git 40 | cd tcl 41 | 42 | You might have to update the Makefile and specify the location of your BLAS and 43 | HPTT library, then continue with: 44 | 45 | make 46 | 47 | This should be it and you should see a libtcl.so in the ./lib/ directory. 48 | 49 | ## Python API 50 | 51 | To install the python API you have to: 52 | 53 | cd pythonAPI 54 | python setup.py install 55 | 56 | At this point you can import the tcl module in your python scripts and call the 57 | tcl.tensorMult() function (see ./benchmark/python/benchmark.py for further examples). 58 | 59 | Keep in mind that TCL is a multi-threaded and performance critical library. 
 60 | Thus, it is of great importance that you follow the following steps before you 61 | run your python script: 62 | 63 | * Specify the thread affinity (e.g., via environment variable KMP_AFFINITY, via taskset, ...) 64 | * Specify the amount of threads to be used via the OMP_NUM_THREADS environment 65 | variable. 66 | * Ensure that your python environment links against a multi-threaded BLAS (see 67 | numpy.\_\_config\_\_.show()) 68 | 69 | # Getting started 70 | 71 | First off, TCL supports any kind of tensor contractions (i.e., it is not limited 72 | to tensor contractions that can be mapped to GEMM). The idea behind TCL is that you only 73 | have to call a single function for any contraction: tcl::tensorMult(). Once you 74 | have specified the tensor contraction, TCL will _automatically_ map this tensor 75 | contraction to the most efficient kernel. 76 | 77 | TCL supports both column-major (default) and row-major data layouts. Column-major: indices are stored 78 | from left to right with the leftmost and rightmost index respectively being 79 | the fastest-varying (stride-1) index and the slowest-varying index; row-major: indices are stored 80 | from right to left. 81 | 82 | You can find a self-explanatory example under ./examples/contraction.cpp 83 | 84 | #include 85 | ... 86 | 87 | tcl::sizeType m = 5; 88 | tcl::sizeType n = 4; 89 | tcl::sizeType k1 = 2; 90 | tcl::sizeType k2 = 3; 91 | 92 | // Allocation of the Tensors (data is owned by the tensors) 93 | tcl::Tensor A({k2,m,k1}); 94 | tcl::Tensor B({n,k2,k1}); 95 | tcl::Tensor C({m,n}); 96 | 97 | // Data initialization (omitted) ... 98 | 99 | // tensor contraction: C_{m,n} = alpha * A_{k2,m,k1} * B_{n,k2,k1} + beta * C_{m,n} 100 | auto ret = tcl::tensorMult( alpha, A["k2,m,k1"], B["n,k2,k1"], beta, C["m,n"], 0 ); 101 | 102 | 103 | You just have to include the header (which can be found in ./include/) and link 104 | against tcl; an exemplary Makefile can be found in ./examples/Makefile. 
105 | 106 | ## C-Interface 107 | 108 | TCL also provides a C interface: 109 | 110 | void sTensorMult(const float alpha, const float *A, const long *sizeA, const long *outerSizeA, const char* indA, 111 | const float *B, const long *sizeB, const long *outerSizeB, const char* indB, 112 | const float beta , float *C, const long *sizeC, const long *outerSizeC, const char* indC, int useRowMajor = 0); 113 | 114 | void dTensorMult(const double alpha, const double *A, const long *sizeA, const long *outerSizeA, const char* indA, 115 | const double *B, const long *sizeB, const long *outerSizeB, const char* indB, 116 | const double beta , double *C, const long *sizeC, const long *outerSizeC, const char* indC, int useRowMajor = 0); 117 | 118 | The outerSizes enable the user to operate on subtensors; the outerSize may be NULL, in which 119 | case a dense tensor with outerSize=size is assumed. 120 | 121 | ## Python-Interface 122 | 123 | TCL now also offers a python-interface. The functionality offered by TCL is comparable to that of [numpy.einsum](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html): 124 | 125 | tensorMult( alpha, A, indicesA, B, indicesB, beta, C, indicesC) 126 | 127 | See docstring for additional information. 128 | 129 | Several examples can be found under ./benchmark/python/ 130 | 131 | # Key Features 132 | 133 | * Multi-threading support 134 | * TCL's current implementation is based on the 135 | Transpose-Transpose-GEMM-Transpose (TTGT) approach (see paper). 136 | * Support for single- and double-precision as well as complex data types. 
137 | 138 | 139 | # Performance Results 140 | 141 | ![hptt](https://github.com/springer13/tcl/blob/master/misc/tcl_speedup.png) 142 | 143 | The above Figure presents the speedup of TCL over the best 144 | reference among multiple state-of-the-art implementations (i.e., [Eigen](http://eigen.tuxfamily.org), 145 | [Tensor Toolbox](http://www.sandia.gov/~tgkolda/TensorToolbox), [NumPy](http://www.numpy.org/), [ITensor](http://itensor.org/)) for 1000 random tensor contractions running on a two 146 | socket Intel Haswell-EP E5-2680 v3 utilizing 24 threads. The cases are sorted with respect to the arithmetic 147 | intensity of an equally-sized matrix-matrix multiplication. 148 | 149 | We make the following observations: 150 | 151 | * All speedups are well above 1.0x; phrased differently, TCL exhibits positive speedups across all 1000 random tensor contractions. 152 | * The speedups are especially high for tensor contractions with a low arithmetic intensity (left side of the plot), reaching up to 18x. 153 | * The speedups decrease with increasing arithmetic intensity, this is due to the fact that the runtime of those contractions is dominated by a large GEMM, thus attaining close to the theoretical peak floating-point performance of the CPU. 154 | 155 | You can run your own benchmarks via: 156 | 157 | ./benchmark/python/benchmark.sh 158 | 159 | Notice that the full benchmark may take hours to complete. 160 | 161 | # Current limitations 162 | 163 | TCL currently requires additional auxiliary memory for the transposed tensors. 164 | This library should eventually also support the GEMM-like Tensor-Tensor (GETT) 165 | contraction approach (see paper), which yields better performance and does not 166 | require any auxiliary memory. 167 | 168 | 169 | # License 170 | 171 | This project is under LGPLv3 for now. If this license is too restrictive for you, 172 | please feel free to contact me via email (springer@aices.rwth-aachen.de). 
173 | 174 | 175 | # Citation 176 | 177 | In case you want to refer to TCL as part of a research paper, please cite the following 178 | article [(pdf)](https://arxiv.org/abs/1607.00145): 179 | ``` 180 | @article{tccg2016a, 181 | author = {Paul Springer and Paolo Bientinesi}, 182 | title = {{Design of a {H}igh-{P}erformance {GEMM}-like {T}ensor-{T}ensor {M}ultiplication}}, 183 | archivePrefix = "arXiv", 184 | eprint = {1607.00145}, 185 | primaryClass = {cs.MS, cs.PF}, 186 | journal = {CoRR}, 187 | year = {2016}, 188 | issue_date = {July 2016}, 189 | url = {http://arxiv.org/abs/1607.00145} 190 | } 191 | ``` 192 | 193 | 194 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. 
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. 
Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. 
Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. 
A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 
163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 
229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | {one line to give the program's name and a brief idea of what it does.} 635 | Copyright (C) {year} {name of author} 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | {project} Copyright (C) {year} {fullname} 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /src/contract.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Paul Springer (springer@aices.rwth-aachen.de) 3 | * 4 | * This program is free software: you can redistribute it and/or modify 5 | * it under the terms of the GNU Lesser General Public License as published by 6 | * the Free Software Foundation, either version 3 of the License, or 7 | * (at your option) any later version. 
8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program. If not, see . 16 | */ 17 | 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | 25 | #include 26 | #include 27 | 28 | //this macro is merely used to assist the compiler in its efforts to generate good code 29 | #define TCL_DUPLICATE(condition, ...) \ 30 | if (condition) { __VA_ARGS__ } \ 31 | else { __VA_ARGS__ } 32 | 33 | //#define TIMERS 34 | 35 | namespace tcl 36 | { /* One candidate execution plan for a contraction via TTGT (Transpose-Transpose-GEMM-Transpose): the permuted index orders for A, B and C plus the transpose/interchange flags handed to the GEMM. NOTE(review): TTGT expansion inferred from getBestTTGTCandidate and the gemm call further below -- confirm. */ 37 | struct TTGTCandidate{ 38 | bool transA; //! if set, then A is in form A_{k,m}; otherwise: A_{m,k} 39 | bool transB; //! if set, then B is in form B_{n,k}; otherwise: B_{k,n} 40 | bool interchangeAB; //! if set, the A and B are interchanged in the GEMM to yield a transposed C 41 | indicesType indicesA; 42 | indicesType indicesB; 43 | indicesType indicesC; 44 | /* true iff the stored index orders admit a single GEMM call: the m-, n- and k-index groups must each be contiguous within their operands and be ordered consistently across A, B and C (checked by the return expression below). */ 45 | bool isValid(){ 46 | auto loopIndices = intersect(indicesA, intersect(indicesB, indicesC)); 47 | //! also known as the free indices of A 48 | auto mIndices = setMinus(intersect(indicesA, indicesC), loopIndices); 49 | //! also known as the free indices of B 50 | auto nIndices = setMinus(intersect(indicesB, indicesC), loopIndices); 51 | //!
also known as the contracted indices 52 | auto kIndices = setMinus(intersect(indicesA, indicesB), loopIndices); 53 | 54 | indicesType mIndicesA; sortAccordingTo(mIndices, indicesA, mIndicesA); 55 | indicesType kIndicesA; sortAccordingTo(kIndices, indicesA, kIndicesA); 56 | indicesType kIndicesB; sortAccordingTo(kIndices, indicesB, kIndicesB); 57 | indicesType nIndicesB; sortAccordingTo(nIndices, indicesB, nIndicesB); 58 | indicesType mIndicesC; sortAccordingTo(mIndices, indicesC, mIndicesC); 59 | indicesType nIndicesC; sortAccordingTo(nIndices, indicesC, nIndicesC); 60 | 61 | bool mIndicesContiguousC = indicesContiguous(mIndices, indicesC); 62 | bool nIndicesContiguousC = indicesContiguous(nIndices, indicesC); 63 | bool mIndicesContiguousA = indicesContiguous(mIndices, indicesA); 64 | bool kIndicesContiguousA = indicesContiguous(kIndices, indicesA); 65 | bool nIndicesContiguousB = indicesContiguous(nIndices, indicesB); 66 | bool kIndicesContiguousB = indicesContiguous(kIndices, indicesB); 67 | 68 | return mIndicesC == mIndicesA && 69 | nIndicesC == nIndicesB && 70 | kIndicesA == kIndicesB && 71 | mIndicesContiguousC && nIndicesContiguousC && 72 | mIndicesContiguousA && kIndicesContiguousA && 73 | kIndicesContiguousB && nIndicesContiguousB; 74 | } 75 | 76 | void print(){ 77 | printf("TransA: %d TransB: %d TransC: %d\n", transA, transB, interchangeAB); 78 | printVector(indicesA, "Indices A"); 79 | printVector(indicesB, "Indices B"); 80 | printVector(indicesC, "Indices C"); 81 | } 82 | }; 83 | /* Heuristic data-movement cost of realizing a candidate: every operand whose target index order differs from its current order contributes its total size, scaled by 1.1 when even the stride-1 (first) index has to move (see the penalty terms below). */ 84 | double getTTGTCandidateCost(const indicesType &indicesA, const sizeType totalSizeA, 85 | const indicesType &indicesB, const sizeType totalSizeB, 86 | const indicesType &indicesC, const sizeType totalSizeC, const TTGTCandidate &candidate) 87 | { 88 | double cost = 0; 89 | auto permA = getPermutation(indicesA, candidate.indicesA); 90 | auto permB = getPermutation(indicesB, candidate.indicesB); 91 | auto permC = getPermutation(indicesC, candidate.indicesC); 92 | 93 | if(
!isIdentity(permA) ) 94 | { 95 | double penalty = (permA[0] == 0 ) ? 1 : 1.1; //favor transpositions for which the first index is unchanged 96 | cost += penalty * totalSizeA; 97 | } 98 | if( !isIdentity(permB) ) 99 | { 100 | double penalty = (permB[0] == 0 ) ? 1 : 1.1; //favor transpositions for which the first index is unchanged 101 | cost += penalty * totalSizeB; 102 | } 103 | if( !isIdentity(permC) ) 104 | { 105 | double penalty = (permC[0] == 0 ) ? 1 : 1.1; //favor transpositions for which the first index is unchanged 106 | cost += penalty * totalSizeC; 107 | } 108 | return cost; 109 | } 110 | 111 | /// target index order will be aIndices(sorted w.r.t. orderA) + bIndices(sorted w.r.t. orderB) + loopIndices 112 | void concatinate_helper(const indicesType &aIndices, const indicesType &orderA, const indicesType &bIndices, const indicesType &orderB, const indicesType &loopIndices, indicesType &target) 113 | { 114 | for(auto idx : orderA ) 115 | if( find(idx, aIndices) ) 116 | target.emplace_back(idx); 117 | for(auto idx : orderB ) 118 | if( find(idx, bIndices) ) 119 | target.emplace_back(idx); 120 | for(auto idx : loopIndices) 121 | target.emplace_back(idx); 122 | } 123 | 124 | /// determine the transpositions required in order to support indices that appear in all tensors: a tensor needs a new index order whenever a non-loop index sits behind one of its loop indices (the reverse scans below), since concatinate_helper always places the loop indices last 125 | void findPerm_helper(const indicesType &indicesA, const indicesType &indicesB, const indicesType &indicesC, 126 | const indicesType &mIndices, const indicesType &nIndices, const indicesType &kIndices, const indicesType &loopIndices, 127 | indicesType &newIndicesA, indicesType &newIndicesB, indicesType &newIndicesC) 128 | { 129 | bool transAreq = false; 130 | bool transBreq = false; 131 | bool transCreq = false; 132 | for(auto idx : loopIndices) 133 | { 134 | for(auto it = indicesA.rbegin(); it != indicesA.rend() && *it != idx; it++) 135 | if( not find(*it, loopIndices) ) { 136 | transAreq = true; 137 | break; 138 | } 139 | for(auto it = indicesB.rbegin(); it != indicesB.rend() && *it != idx;
it++) 140 | if( not find(*it, loopIndices) ) { 141 | transBreq = true; 142 | break; 143 | } 144 | for(auto it = indicesC.rbegin(); it != indicesC.rend() && *it != idx; it++) 145 | if( not find(*it, loopIndices) ) { 146 | transCreq = true; 147 | break; 148 | } 149 | } 150 | 151 | if( transAreq ){ 152 | if( find( indicesA.front(), mIndices ) ) //choose permutation that can preserve the stride-1 index 153 | concatinate_helper(mIndices, indicesC, kIndices, indicesB, loopIndices, newIndicesA); 154 | else 155 | concatinate_helper(kIndices, indicesB, mIndices, indicesC, loopIndices, newIndicesA); 156 | if( transBreq ){ 157 | if( find( indicesB.front(), nIndices ) ) //choose permutation that can preserve the stride-1 index 158 | concatinate_helper(nIndices, indicesC, kIndices, indicesB, loopIndices, newIndicesB); 159 | else 160 | concatinate_helper(kIndices, indicesB, nIndices, indicesC, loopIndices, newIndicesB); 161 | if( transCreq ) 162 | if( find( indicesB.front(), mIndices ) ) //choose permutation that can preserve the stride-1 index -- NOTE(review): this branch builds newIndicesC but tests indicesB.front(), while the analogous branches below test indicesC.front(); verify this is intentional 163 | concatinate_helper(mIndices, indicesC, nIndices, indicesC, loopIndices, newIndicesC); 164 | else 165 | concatinate_helper(nIndices, indicesC, mIndices, indicesC, loopIndices, newIndicesC); 166 | }else{ 167 | if( transCreq ) 168 | if( find( indicesC.front(), mIndices ) ) //choose permutation that can preserve the stride-1 index 169 | concatinate_helper(mIndices, indicesC, nIndices, indicesB, loopIndices, newIndicesC); 170 | else 171 | concatinate_helper(nIndices, indicesB, mIndices, indicesC, loopIndices, newIndicesC); 172 | } 173 | }else{ 174 | if( transBreq ){ 175 | if( find( indicesB.front(), nIndices ) ) //choose permutation that can preserve the stride-1 index 176 | concatinate_helper(nIndices, indicesC, kIndices, indicesA, loopIndices, newIndicesB); 177 | else 178 | concatinate_helper(kIndices, indicesA, nIndices, indicesC, loopIndices, newIndicesB); 179 | if( transCreq ) 180 | if( find( indicesC.front(), mIndices ) ) //choose
permutation that can preserve the stride-1 index 181 | concatinate_helper(mIndices, indicesA, nIndices, indicesC, loopIndices, newIndicesC); 182 | else 183 | concatinate_helper(nIndices, indicesC, mIndices, indicesA, loopIndices, newIndicesC); 184 | }else{ 185 | if( transCreq ) 186 | if( find( indicesC.front(), mIndices ) ) //choose permutation that can preserve the stride-1 index 187 | concatinate_helper(mIndices, indicesA, nIndices, indicesB, loopIndices, newIndicesC); 188 | else 189 | concatinate_helper(nIndices, indicesB, mIndices, indicesA, loopIndices, newIndicesC); 190 | } 191 | } 192 | if( not transAreq ) 193 | newIndicesA = indicesA; 194 | if( not transBreq ) 195 | newIndicesB = indicesB; 196 | if( not transCreq ) 197 | newIndicesC = indicesC; 198 | } 199 | /* Builds a GEMM-friendly target index order: aIndices followed by bIndices, with the loop indices appended last. If indicesT's first (stride-1) index does not lead aIndices, the two groups are emitted in swapped order and the trans flag is set instead. */ 200 | static void helperTranspose(const indicesType& loopIndices, const indicesType& aIndices, const indicesType& bIndices, const indicesType& indicesT, 201 | bool &trans, indicesType& indices ) 202 | { 203 | if( aIndices.front() == indicesT.front() ) { 204 | trans = false; 205 | concatinate(aIndices, bIndices, indices); 206 | }else{ 207 | trans = true; 208 | concatinate(bIndices, aIndices, indices); 209 | } 210 | for(auto x : loopIndices) 211 | indices.emplace_back(x); 212 | } 213 | /* Selects the TTGT candidate that transposes as few tensors as possible; when transpositions are unavoidable, the branch cascade that follows uses the totalSize* comparisons to prefer leaving the largest tensors untouched. The chosen index orders and transpose flags are written into candidate. */ 214 | void getBestTTGTCandidate(const indicesType &indicesA, const size_t totalSizeA, 215 | const indicesType &indicesB, const size_t totalSizeB, 216 | const indicesType &indicesC, const size_t totalSizeC, 217 | const indicesType &loopIndices, 218 | const indicesType &mIndices, 219 | const indicesType &nIndices, 220 | const indicesType &kIndices, 221 | TTGTCandidate &candidate) 222 | { 223 | bool mIndicesContiguousC = indicesContiguous(mIndices, indicesC); 224 | bool nIndicesContiguousC = indicesContiguous(nIndices, indicesC); 225 | bool mIndicesContiguousA = indicesContiguous(mIndices, indicesA); 226 | bool kIndicesContiguousA = indicesContiguous(kIndices, indicesA); 227 | bool nIndicesContiguousB =
indicesContiguous(nIndices, indicesB); 228 | bool kIndicesContiguousB = indicesContiguous(kIndices, indicesB); 229 | 230 | bool transposeRequiredA = !kIndicesContiguousA || !mIndicesContiguousA; 231 | bool transposeRequiredB = !kIndicesContiguousB || !nIndicesContiguousB; 232 | bool transposeRequiredC = !mIndicesContiguousC || !nIndicesContiguousC; 233 | 234 | indicesType mIndicesA; sortAccordingTo(mIndices, indicesA, mIndicesA); 235 | indicesType kIndicesA; sortAccordingTo(kIndices, indicesA, kIndicesA); 236 | indicesType kIndicesB; sortAccordingTo(kIndices, indicesB, kIndicesB); 237 | indicesType nIndicesB; sortAccordingTo(nIndices, indicesB, nIndicesB); 238 | indicesType mIndicesC; sortAccordingTo(mIndices, indicesC, mIndicesC); 239 | indicesType nIndicesC; sortAccordingTo(nIndices, indicesC, nIndicesC); 240 | 241 | // pick the best candidate that avoids as many transpositions as possible 242 | if( transposeRequiredB ){ 243 | if( transposeRequiredA ){ 244 | if( transposeRequiredC ){ 245 | //all tensors need to be transposed 246 | if( totalSizeA >= totalSizeB && totalSizeA >= totalSizeC ){ 247 | // transpose A 248 | helperTranspose(loopIndices,mIndicesA, kIndicesA, indicesA, candidate.transA, candidate.indicesA); 249 | // transpose C 250 | helperTranspose(loopIndices,mIndicesA, nIndicesC, indicesC, candidate.interchangeAB, candidate.indicesC); 251 | // transpose B 252 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 253 | } else if( totalSizeC >= totalSizeB && totalSizeC >= totalSizeA ){ 254 | // transpose C 255 | helperTranspose(loopIndices,mIndicesC, nIndicesC, indicesC, candidate.interchangeAB, candidate.indicesC); 256 | // transpose A 257 | helperTranspose(loopIndices,mIndicesC, kIndicesA, indicesA, candidate.transA, candidate.indicesA); 258 | // transpose B 259 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 260 | } else if( totalSizeB >= totalSizeA && 
totalSizeB >= totalSizeC ){ 261 | // transpose B 262 | helperTranspose(loopIndices,kIndicesB, nIndicesB, indicesB, candidate.transB, candidate.indicesB); 263 | // transpose C 264 | helperTranspose(loopIndices,mIndicesC, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 265 | // transpose A 266 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 267 | } 268 | }else{ 269 | if( totalSizeB > totalSizeA ){ 270 | // transpose B 271 | helperTranspose(loopIndices,kIndicesB, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 272 | // transpose A 273 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 274 | }else{ 275 | // transpose A 276 | helperTranspose(loopIndices,mIndicesC, kIndicesA, indicesA, candidate.transA, candidate.indicesA); 277 | // transpose B 278 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 279 | } 280 | //dont transpose C 281 | candidate.indicesC = indicesC; 282 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 283 | } 284 | }else{ 285 | if( transposeRequiredC ){ 286 | if( totalSizeC > totalSizeB ){ 287 | // transpose C 288 | helperTranspose(loopIndices,mIndicesA, nIndicesC, indicesC, candidate.interchangeAB, candidate.indicesC); 289 | // transpose B 290 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 291 | }else{ 292 | // transpose B 293 | helperTranspose(loopIndices,kIndicesA, nIndicesB, indicesB, candidate.transB, candidate.indicesB); 294 | // transpose C 295 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 296 | } 297 | //dont transpose A 298 | candidate.indicesA = indicesA; 299 | candidate.transA = kIndicesA.front() == indicesA.front(); 300 | }else{ 301 | if( mIndicesC != mIndicesA ){ 302 | // either A or C also needs to be permuted 303 | if( 
totalSizeA > totalSizeC ){ 304 | //dont transpose A 305 | candidate.indicesA = indicesA; 306 | candidate.transA = kIndicesA.front() == indicesA.front(); 307 | // transpose C 308 | helperTranspose(loopIndices,mIndicesA, nIndicesC, indicesC, candidate.interchangeAB, candidate.indicesC); 309 | }else{ 310 | //dont transpose C 311 | candidate.indicesC = indicesC; 312 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 313 | // transpose A 314 | helperTranspose(loopIndices,mIndicesC, kIndicesA, indicesA, candidate.transA, candidate.indicesA); 315 | } 316 | }else{ 317 | //dont transpose A 318 | candidate.indicesA = indicesA; 319 | candidate.transA = kIndicesA.front() == indicesA.front(); 320 | //dont transpose C 321 | candidate.indicesC = indicesC; 322 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 323 | } 324 | // transpose B 325 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 326 | } 327 | } 328 | }else{ 329 | if( transposeRequiredA ){ 330 | if( transposeRequiredC ){ 331 | if( totalSizeC > totalSizeA ){ 332 | // transpose C 333 | helperTranspose(loopIndices,mIndicesC, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 334 | // transpose A 335 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 336 | }else{ 337 | // transpose A 338 | helperTranspose(loopIndices,mIndicesA, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 339 | // transpose C 340 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 341 | } 342 | //dont transpose B 343 | candidate.indicesB = indicesB; 344 | candidate.transB = nIndicesB.front() == indicesB.front(); 345 | }else{ 346 | if( nIndicesC != nIndicesB ){ 347 | // either B or C also needs to be permuted 348 | if( totalSizeB > totalSizeC ){ 349 | //dont transpose B 350 | candidate.indicesB = indicesB; 351 | candidate.transB = 
nIndicesB.front() == indicesB.front(); 352 | // transpose C 353 | helperTranspose(loopIndices,mIndicesC, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 354 | }else{ 355 | //dont transpose C 356 | candidate.indicesC = indicesC; 357 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 358 | // transpose B 359 | helperTranspose(loopIndices,kIndicesB, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 360 | } 361 | }else{ 362 | //dont transpose B 363 | candidate.indicesB = indicesB; 364 | candidate.transB = nIndicesB.front() == indicesB.front(); 365 | //dont transpose C 366 | candidate.indicesC = indicesC; 367 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 368 | } 369 | // transpose A 370 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 371 | } 372 | }else{ 373 | if( transposeRequiredC ){ //not a, not b, c 374 | if( kIndicesA != kIndicesB ){ 375 | // either A or B also needs to be permuted 376 | if( totalSizeA > totalSizeB ){ 377 | //dont transpose A 378 | candidate.indicesA = indicesA; 379 | candidate.transA = kIndicesA.front() == indicesA.front(); 380 | // transpose B 381 | helperTranspose(loopIndices,kIndicesA, nIndicesB, indicesB, candidate.transB, candidate.indicesB); 382 | }else{ 383 | //dont transpose B 384 | candidate.indicesB = indicesB; 385 | candidate.transB = nIndicesB.front() == indicesB.front(); 386 | // transpose A 387 | helperTranspose(loopIndices,mIndicesA, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 388 | } 389 | }else{ 390 | //dont transpose A 391 | candidate.indicesA = indicesA; 392 | candidate.transA = kIndicesA.front() == indicesA.front(); 393 | //dont transpose B 394 | candidate.indicesB = indicesB; 395 | candidate.transB = nIndicesB.front() == indicesB.front(); 396 | } 397 | // transpose C 398 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 399 | }else{ 400 
| if( kIndicesA != kIndicesB ){ 401 | // transpose A or B 402 | if( mIndicesA != mIndicesC ){ 403 | // transpose A or C 404 | if( nIndicesC != nIndicesB ){ 405 | // transpose C or B 406 | // transpose atleast two tensors 407 | if( totalSizeA >= totalSizeB && totalSizeA >= totalSizeC ){ 408 | //dont transpose A 409 | candidate.indicesA = indicesA; 410 | candidate.transA = kIndicesA.front() == indicesA.front(); 411 | // transpose B 412 | helperTranspose(loopIndices,kIndicesA, nIndicesB, indicesB, candidate.transB, candidate.indicesB); 413 | // transpose C 414 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 415 | }else if( totalSizeC >= totalSizeB && totalSizeC >= totalSizeA ){ 416 | //dont transpose C 417 | candidate.indicesC = indicesC; 418 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 419 | // transpose B 420 | helperTranspose(loopIndices,kIndicesB, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 421 | // transpose A 422 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 423 | }else{ 424 | //dont transpose B 425 | candidate.indicesB = indicesB; 426 | candidate.transB = nIndicesB.front() == indicesB.front(); 427 | // transpose A 428 | helperTranspose(loopIndices,mIndicesA, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 429 | // transpose C 430 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 431 | } 432 | }else{ // A only is possible 433 | if( totalSizeA <= totalSizeB + totalSizeC ){ 434 | // transpose A 435 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 436 | //dont transpose B 437 | candidate.indicesB = indicesB; 438 | candidate.transB = nIndicesB.front() == indicesB.front(); 439 | //dont transpose C 440 | candidate.indicesC = indicesC; 441 | candidate.interchangeAB = nIndicesC.front() == 
indicesC.front(); 442 | }else{ 443 | //dont transpose A 444 | candidate.indicesA = indicesA; 445 | candidate.transA = kIndicesA.front() == indicesA.front(); 446 | // transpose B 447 | helperTranspose(loopIndices,kIndicesA, nIndicesB, indicesB, candidate.transB, candidate.indicesB); 448 | // transpose C 449 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 450 | } 451 | } 452 | }else{ 453 | if( nIndicesC != nIndicesB ){ 454 | // transpose C or B 455 | // either only B or A and C 456 | if( totalSizeB <= totalSizeA + totalSizeC ){ 457 | //dont transpose A 458 | candidate.indicesA = indicesA; 459 | candidate.transA = kIndicesA.front() == indicesA.front(); 460 | //dont transpose C 461 | candidate.indicesC = indicesC; 462 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 463 | // transpose B 464 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 465 | }else{ 466 | //dont transpose B 467 | candidate.indicesB = indicesB; 468 | candidate.transB = nIndicesB.front() == indicesB.front(); 469 | // transpose C 470 | helperTranspose(loopIndices,mIndicesC, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 471 | // transpose A 472 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 473 | } 474 | }else{ 475 | // transpose A or B 476 | if( totalSizeA > totalSizeB ){ 477 | //dont transpose A 478 | candidate.indicesA = indicesA; 479 | candidate.transA = kIndicesA.front() == indicesA.front(); 480 | //dont transpose C 481 | candidate.indicesC = indicesC; 482 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 483 | // transpose B 484 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 485 | }else{ 486 | //dont transpose B 487 | candidate.indicesB = indicesB; 488 | candidate.transB = nIndicesB.front() == indicesB.front(); 489 | //dont 
transpose C 490 | candidate.indicesC = indicesC; 491 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 492 | // transpose A 493 | helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 494 | } 495 | } 496 | } 497 | }else{ 498 | if( mIndicesA != mIndicesC ){ 499 | // transpose A or C 500 | if( nIndicesC != nIndicesB ){ 501 | // transpose C or B 502 | if( totalSizeC <= totalSizeA + totalSizeB ){ 503 | //dont transpose A 504 | candidate.indicesA = indicesA; 505 | candidate.transA = kIndicesA.front() == indicesA.front(); 506 | //dont transpose B 507 | candidate.indicesB = indicesB; 508 | candidate.transB = nIndicesB.front() == indicesB.front(); 509 | // transpose C 510 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 511 | }else{ 512 | //dont transpose C 513 | candidate.indicesC = indicesC; 514 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 515 | // transpose A 516 | helperTranspose(loopIndices,mIndicesC, kIndicesA, indicesA, candidate.transA, candidate.indicesA); 517 | // transpose B 518 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 519 | } 520 | }else{ 521 | if( totalSizeA > totalSizeC ){ 522 | //dont transpose A 523 | candidate.indicesA = indicesA; 524 | candidate.transA = kIndicesA.front() == indicesA.front(); 525 | //dont transpose B 526 | candidate.indicesB = indicesB; 527 | candidate.transB = nIndicesB.front() == indicesB.front(); 528 | // transpose C 529 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 530 | }else{ 531 | //dont transpose B 532 | candidate.indicesB = indicesB; 533 | candidate.transB = nIndicesB.front() == indicesB.front(); 534 | //dont transpose C 535 | candidate.indicesC = indicesC; 536 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 537 | // transpose A 538 | 
helperTranspose(loopIndices,mIndicesC, kIndicesB, indicesA, candidate.transA, candidate.indicesA); 539 | } 540 | } 541 | }else{ 542 | //dont transpose A 543 | candidate.indicesA = indicesA; 544 | candidate.transA = kIndicesA.front() == indicesA.front(); 545 | if( nIndicesC != nIndicesB ){ 546 | // transpose C or B 547 | if( totalSizeC > totalSizeB ){ 548 | //dont transpose C 549 | candidate.indicesC = indicesC; 550 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 551 | // transpose B 552 | helperTranspose(loopIndices,kIndicesA, nIndicesC, indicesB, candidate.transB, candidate.indicesB); 553 | }else{ 554 | //dont transpose B 555 | candidate.indicesB = indicesB; 556 | candidate.transB = nIndicesB.front() == indicesB.front(); 557 | // transpose C 558 | helperTranspose(loopIndices,mIndicesA, nIndicesB, indicesC, candidate.interchangeAB, candidate.indicesC); 559 | } 560 | }else{ 561 | //dont transpose B 562 | candidate.indicesB = indicesB; 563 | candidate.transB = nIndicesB.front() == indicesB.front(); 564 | //dont transpose C 565 | candidate.indicesC = indicesC; 566 | candidate.interchangeAB = nIndicesC.front() == indicesC.front(); 567 | } 568 | } 569 | } 570 | } 571 | } 572 | } 573 | #ifdef DEBUG 574 | if( !candidate.isValid() ){ 575 | std::cerr<< "ERROR: TTGT candidate is invalid.\n"; 576 | std::cout<< transposeRequiredA << transposeRequiredB << transposeRequiredC < 589 | void batchedGEMM( const floatType alpha, const floatType* A, const floatType *B, const floatType beta, floatType *C, const TTGTCandidate &candidate, 590 | const sizeType m, const sizeType n, const sizeType k, const std::vector &sizes, 591 | const std::vector &stridesA, const std::vector &stridesB, const std::vector &stridesC, const int level) 592 | { 593 | if( level >= stridesA.size() ) 594 | { 595 | if( candidate.interchangeAB ) 596 | if( candidate.transA ) 597 | if( candidate.transB ) 598 | // n,m <- n,k x k,m 599 | gemm("N", "N", &n, &m, &k, 600 | &alpha, B, &n, A, &k, &beta, C, 
&n); 601 | else 602 | // n,m <- k,n x k,m 603 | gemm("T", "N", &n, &m, &k, 604 | &alpha, B, &k, A, &k, &beta, C, &n); 605 | else 606 | if( candidate.transB ) 607 | // n,m <- n,k x m,k 608 | gemm("N", "T", &n, &m, &k, 609 | &alpha, B, &n, A, &m, &beta, C, &n); 610 | else 611 | // n,m <- k,n x m,k 612 | gemm("T", "T", &n, &m, &k, 613 | &alpha, B, &k, A, &m, &beta, C, &n); 614 | else 615 | if( candidate.transA ) 616 | if( candidate.transB ) 617 | // m,n <- k,m x n,k 618 | gemm("T", "T", &m, &n, &k, 619 | &alpha, A, &k, B, &n, &beta, C, &m); 620 | else 621 | // m,n <- k,m x k,n 622 | gemm("T", "N", &m, &n, &k, 623 | &alpha, A, &k, B, &k, &beta, C, &m); 624 | else 625 | if( candidate.transB ) 626 | // m,n <- m,k x n,k 627 | gemm("N", "T", &m, &n, &k, 628 | &alpha, A, &m, B, &n, &beta, C, &m); 629 | else 630 | // m,n <- m,k x k,n 631 | gemm("N", "N", &m, &n, &k, 632 | &alpha, A, &m, B, &k, &beta, C, &m); 633 | } else { 634 | auto end = sizes[level]; 635 | const auto strideA = stridesA[level]; 636 | const auto strideB = stridesB[level]; 637 | const auto strideC = stridesC[level]; 638 | //TODO call batched GEMM 639 | #pragma omp parallel for if(level==0) 640 | for(sizeType i=0; i < end; ++i) 641 | batchedGEMM( alpha, A + i * strideA, B + i * strideB, beta, C + i * strideC, 642 | candidate, m, n, k, sizes, stridesA, stridesB, stridesC, level+1 ); 643 | } 644 | } 645 | 646 | template 647 | error contractTTGT(const floatType alpha, const Tensor *A, const Tensor *B, const floatType beta, Tensor *C) 648 | { 649 | #ifdef TIMERS 650 | auto start = omp_get_wtime(); 651 | #endif 652 | 653 | auto indicesA = A->getIndices(); 654 | auto indicesB = B->getIndices(); 655 | auto indicesC = C->getIndices(); 656 | auto loopIndices = intersect(indicesA, intersect(indicesB, indicesC)); 657 | //! also known as the free indices of A 658 | auto mIndices = setMinus(intersect(indicesA, indicesC), loopIndices); 659 | //! 
also known as the free indices of B 660 | auto nIndices = setMinus(intersect(indicesB, indicesC), loopIndices); 661 | //! also known as the contracted indices 662 | auto kIndices = setMinus(intersect(indicesA, indicesB), loopIndices); 663 | 664 | if( mIndices.size() <= 0 || nIndices.size() <= 0 || kIndices.size() <= 0 ) // TTGT is not applicable; use fallback 665 | return contract(alpha, A, B, beta, C); 666 | 667 | auto totalSizeA = A->getTotalSize() * sizeof(floatType); 668 | auto totalSizeB = B->getTotalSize() * sizeof(floatType); 669 | auto totalSizeC = C->getTotalSize() * sizeof(floatType); 670 | TTGTCandidate candidate; 671 | 672 | if( loopIndices.size() > 0 ) 673 | { 674 | // ensure that loop indices are the outermost indices 675 | indicesType newIndicesA, newIndicesB, newIndicesC; 676 | findPerm_helper(indicesA, indicesB, indicesC, mIndices, nIndices, kIndices, loopIndices, 677 | newIndicesA, newIndicesB, newIndicesC); 678 | 679 | getBestTTGTCandidate(newIndicesA, totalSizeA, 680 | newIndicesB, totalSizeB, 681 | newIndicesC, totalSizeC, 682 | loopIndices, mIndices, nIndices, kIndices, candidate); 683 | 684 | }else{ 685 | getBestTTGTCandidate(indicesA, totalSizeA, 686 | indicesB, totalSizeB, 687 | indicesC, totalSizeC, 688 | loopIndices, mIndices, nIndices, kIndices, candidate); 689 | } 690 | 691 | #ifdef TIMERS 692 | start = omp_get_wtime(); 693 | #endif 694 | /********************* 695 | * Request auxiliary memory form the memory broker 696 | *********************/ 697 | auto permA = getPermutation(indicesA, candidate.indicesA); 698 | auto permB = getPermutation(indicesB, candidate.indicesB); 699 | auto permC = getPermutation(candidate.indicesC, indicesC); 700 | 701 | size_t requestedSize = 0; 702 | if( !isIdentity(permA) ) 703 | requestedSize += totalSizeA; 704 | if( !isIdentity(permB) ) 705 | requestedSize += totalSizeB; 706 | if( !isIdentity(permC) ) 707 | requestedSize += totalSizeC; 708 | 709 | if( requestedSize > memBroker.size() ) 710 | { 711 | if( 
memBroker.isInit() ) 712 | memBroker.release(); 713 | memBroker.alloc( requestedSize ); 714 | } 715 | 716 | int numThreads = getNumThreads(); 717 | #ifdef TIMERS 718 | auto memTime = omp_get_wtime() - start; 719 | printf("TCL membroker: %f\n",memTime); 720 | #endif 721 | 722 | floatType *bufferA(nullptr), *bufferB(nullptr), *bufferC(nullptr); 723 | floatType betaGEMM = beta; 724 | 725 | /********************* 726 | * Transpose A 727 | *********************/ 728 | #ifdef TIMERS 729 | start = omp_get_wtime(); 730 | #endif 731 | if( !isIdentity(permA) ) { 732 | // packing of A is required 733 | auto sizeA = permute(permA, A->getSize()); 734 | bufferA = (floatType*) memBroker.requestMemory(totalSizeA); 735 | #ifdef DEBUG 736 | printVector(A->getSize(), "sizeA in"); 737 | printVector(A->getOuterSize(), "outerSizeA in"); 738 | printVector(permA, "permA"); 739 | printVector(sizeA, "sizeA out"); 740 | #endif 741 | // create plan for Transposition 742 | auto plan = hptt::create_plan( permA, A->getDim(), 743 | 1, A->getData(), A->getSize(), A->getOuterSize(), 744 | 0, bufferA, sizeA, hptt::ESTIMATE, numThreads); 745 | plan->execute(); 746 | }else 747 | bufferA = A->getData(); 748 | #ifdef TIMERS 749 | auto packTimeA = omp_get_wtime() - start; 750 | printf("TCL A: %f\n",packTimeA); 751 | #endif 752 | 753 | /********************* 754 | * Transpose B 755 | *********************/ 756 | #ifdef TIMERS 757 | start = omp_get_wtime(); 758 | #endif 759 | if( !isIdentity(permB) ) 760 | { 761 | auto sizeB = permute(permB, B->getSize()); 762 | #ifdef DEBUG 763 | printVector(B->getSize(), "sizeB in"); 764 | printVector(B->getOuterSize(), "outerSizeB in"); 765 | printVector(permB, "permB"); 766 | printVector(sizeB, "sizeB out"); 767 | #endif 768 | // packing of B is required 769 | bufferB = (floatType*) memBroker.requestMemory(totalSizeB); 770 | // create plan for Transposition 771 | auto plan = hptt::create_plan( permB, B->getDim(), 772 | 1, B->getData(), B->getSize(), B->getOuterSize(), 
773 | 0, bufferB, sizeB, hptt::ESTIMATE, numThreads); 774 | plan->execute(); 775 | }else 776 | bufferB = B->getData(); 777 | #ifdef TIMERS 778 | auto packTimeB = omp_get_wtime() - start; 779 | printf("TCL B: %f\n",packTimeB); 780 | #endif 781 | 782 | if( !isIdentity(permC) ) { 783 | betaGEMM = 0; 784 | bufferC = (floatType*) memBroker.requestMemory(totalSizeC); 785 | }else 786 | bufferC = C->getData(); 787 | 788 | /********************* 789 | * GEMM 790 | *********************/ 791 | std::vector stridesA, stridesB, stridesC, sizes; 792 | if( loopIndices.size() > 0 ){ 793 | auto m = A->getTotalSize(mIndices); 794 | auto n = B->getTotalSize(nIndices); 795 | auto k = A->getTotalSize(kIndices); 796 | auto strideA = m * k; 797 | auto strideB = n * k; 798 | auto strideC = m * n; 799 | for( auto idx : loopIndices ) 800 | sizes.emplace_back(A->getSize(idx)); 801 | for( auto idx : candidate.indicesA ){ 802 | if( find(idx, loopIndices ) ) { 803 | stridesA.emplace_back(strideA); 804 | strideA *= A->getSize(idx); 805 | } 806 | } 807 | for( auto idx : candidate.indicesB ){ 808 | if( find(idx, loopIndices ) ) { 809 | stridesB.emplace_back(strideB); 810 | strideB *= B->getSize(idx); 811 | } 812 | } 813 | for( auto idx : candidate.indicesC ){ 814 | if( find(idx, loopIndices ) ) { 815 | stridesC.emplace_back(strideC); 816 | strideC *= C->getSize(idx); 817 | } 818 | } 819 | } 820 | 821 | #ifdef TIMERS 822 | start = omp_get_wtime(); 823 | #endif 824 | 825 | batchedGEMM( alpha, bufferA, bufferB, betaGEMM, bufferC, 826 | candidate, A->getTotalSize(mIndices), B->getTotalSize(nIndices), A->getTotalSize(kIndices), 827 | sizes, stridesA, stridesB, stridesC, 0 ); 828 | 829 | #ifdef TIMERS 830 | auto gemmTime = omp_get_wtime() - start; 831 | printf("TCL GEMM: %f\n", gemmTime); 832 | #endif 833 | 834 | /********************* 835 | * untranspose C 836 | *********************/ 837 | #ifdef TIMERS 838 | start = omp_get_wtime(); 839 | #endif 840 | if( !isIdentity(permC) ) { 841 | auto invPermC = 
getPermutation(indicesC, candidate.indicesC); 842 | auto sizeC = permute(invPermC, C->getSize()); 843 | #ifdef DEBUG 844 | printVector(C->getSize(), "sizeC out"); 845 | printVector(C->getOuterSize(), "outerSizeC out"); 846 | printVector(permC, "permC"); 847 | printVector(sizeC, "sizeC in"); 848 | #endif 849 | // create plan for Transposition 850 | auto plan = hptt::create_plan( permC, C->getDim(), 851 | 1, bufferC, sizeC, sizeC, 852 | beta, C->getData(), C->getOuterSize(), hptt::ESTIMATE, numThreads); 853 | plan->execute(); 854 | } 855 | #ifdef TIMERS 856 | auto packTimeC = omp_get_wtime() - start; 857 | printf("TCL: %f GFLOPS %f %f\n", totalSizeM*totalSizeN*totalSizeK*2. / 1e9 / (gemmTime+packTimeA+packTimeB+packTimeC), getBestTime / (gemmTime+packTimeA+packTimeB+packTimeC+getBestTime), (gemmTime+packTimeA+packTimeB+packTimeC+getBestTime)); 858 | #endif 859 | 860 | //free memory 861 | memBroker.reset(); 862 | return SUCCESS; 863 | } 864 | 865 | template 866 | floatType contractedLoops(const std::vector &stridesA, const floatType *dataA, 867 | const std::vector &stridesB, const floatType *dataB, 868 | const std::vector &sizes, const int loop) 869 | { 870 | const auto size = sizes[loop]; 871 | const auto strideA = stridesA[loop]; 872 | const auto strideB = stridesB[loop]; 873 | floatType tmp = 0; 874 | if( loop < sizes.size() - 1) 875 | for(int m=0; m < size; ++m) 876 | tmp += contractedLoops(stridesA, &dataA[m*strideA], stridesB, &dataB[m*strideB], sizes, loop+1); 877 | else 878 | TCL_DUPLICATE(strideA==1, // help the compiler to optimize the code 879 | TCL_DUPLICATE(strideB==1 , 880 | for(int m=0; m < size; ++m) 881 | tmp += dataA[m * strideA] * dataB[m*strideB]; 882 | ) 883 | ) 884 | 885 | return tmp; 886 | } 887 | 888 | template 889 | void freeLoops(floatType alpha, const std::vector &stridesA, const floatType *dataA, 890 | const std::vector &stridesB, const floatType *dataB, 891 | floatType beta, const std::vector &stridesC, floatType *dataC, 892 | const 
std::vector &sizes, const int numFreeLoops, const int numContractedLoops, const int loop) 893 | { 894 | if( numContractedLoops == 0 ) 895 | { 896 | if( numParallel == 2) 897 | { 898 | const auto s0 = sizes[0]; 899 | const auto s1 = sizes[1]; 900 | const auto stridesC1 = stridesC[1]; 901 | const auto stridesC0 = stridesC[0]; 902 | const auto stridesA1 = stridesA[1]; 903 | const auto stridesA0 = stridesA[0]; 904 | const auto stridesB1 = stridesB[1]; 905 | const auto stridesB0 = stridesB[0]; 906 | if( numParallel == numFreeLoops ){ 907 | #pragma omp parallel for collapse(2) 908 | for(int i=0; i < s0; ++i) 909 | for(int j=0; j < s1; ++j) 910 | TCL_DUPLICATE(stridesC1==0, 911 | TCL_DUPLICATE(stridesC1==1, 912 | TCL_DUPLICATE(stridesA1==0, 913 | TCL_DUPLICATE(stridesA1==1, 914 | TCL_DUPLICATE(stridesB1==0, 915 | TCL_DUPLICATE(stridesB1==1, 916 | if( betaIsZero ) 917 | dataC[i*stridesC0 + j*stridesC1] = alpha * dataA[i*stridesA0 + j*stridesA1] * dataB[i*stridesB0 + j*stridesB1]; 918 | else 919 | dataC[i*stridesC0 + j*stridesC1] = alpha * dataA[i*stridesA0 + j*stridesA1] * dataB[i*stridesB0 + j*stridesB1] + beta * dataC[i*stridesC0 + j*stridesC1]; 920 | )))))) 921 | }else{ 922 | #pragma omp parallel for collapse(2) 923 | for(int i=0; i < s0; ++i) 924 | for(int j=0; j < s1; ++j) 925 | freeLoops(alpha, stridesA, &dataA[i*stridesA0 + j*stridesA1], stridesB, &dataB[i*stridesB0 + j*stridesB1], beta, stridesC, &dataC[i*stridesC0 + j*stridesC1], 926 | sizes, numFreeLoops, numContractedLoops, loop+2); 927 | } 928 | }else if (numParallel == 1){ 929 | const auto s0 = sizes[0]; 930 | const auto stridesC0 = stridesC[0]; 931 | const auto stridesA0 = stridesA[0]; 932 | const auto stridesB0 = stridesB[0]; 933 | #pragma omp parallel for 934 | for(int i=0; i < s0; ++i) 935 | TCL_DUPLICATE(stridesC0==0, 936 | TCL_DUPLICATE(stridesC0==1, 937 | TCL_DUPLICATE(stridesA0==0, 938 | TCL_DUPLICATE(stridesA0==1, 939 | TCL_DUPLICATE(stridesB0==0, 940 | TCL_DUPLICATE(stridesB0==1, 941 | if( betaIsZero 
) 942 | dataC[i * stridesC0] = alpha * dataA[i * stridesA0] * dataB[i * stridesB0]; 943 | else 944 | dataC[i * stridesC0] = alpha * dataA[i * stridesA0] * dataB[i * stridesB0] + beta * dataC[i * stridesC0]; 945 | )))))) 946 | }else{ 947 | // recurse without spawning threads 948 | for(int i=0; i < sizes[loop]; ++i) 949 | freeLoops(alpha, stridesA, &dataA[i*stridesA[loop]], stridesB, &dataB[i*stridesB[loop]], beta, stridesC, &dataC[i*stridesC[loop]], 950 | sizes, numFreeLoops, numContractedLoops, loop+1); 951 | } 952 | }else{ 953 | if( numParallel == 2) 954 | { 955 | const auto s0 = sizes[0]; 956 | const auto s1 = sizes[1]; 957 | const auto stridesC1 = stridesC[1]; 958 | const auto stridesC0 = stridesC[0]; 959 | const auto stridesA1 = stridesA[1]; 960 | const auto stridesA0 = stridesA[0]; 961 | const auto stridesB1 = stridesB[1]; 962 | const auto stridesB0 = stridesB[0]; 963 | if( numParallel == numFreeLoops ){ 964 | #pragma omp parallel for collapse(2) 965 | for(int i=0; i < s0; ++i) 966 | for(int j=0; j < s1; ++j) 967 | TCL_DUPLICATE(stridesC1==0, 968 | TCL_DUPLICATE(stridesC1==1, 969 | TCL_DUPLICATE(stridesA1==0, 970 | TCL_DUPLICATE(stridesA1==1, 971 | TCL_DUPLICATE(stridesB1==0, 972 | TCL_DUPLICATE(stridesB1==1, 973 | if( betaIsZero ) 974 | //C[i*stridesC1 + j * stridesC0] = alpha * dataA[i*stridesA1 + j * stridesA0] * dataB[i*stridesB1 + j * stridesB0]; 975 | dataC[i*stridesC0 + j*stridesC1] = alpha * contractedLoops(stridesA, &dataA[i*stridesA0 + j*stridesA1], 976 | stridesB, &dataB[i*stridesB0 + j*stridesB1], sizes, loop+2); 977 | else 978 | dataC[i*stridesC0 + j*stridesC1] = alpha * contractedLoops(stridesA, &dataA[i*stridesA0 + j*stridesA1], 979 | stridesB, &dataB[i*stridesB0 + j*stridesB1], sizes, loop+2) + beta * dataC[i*stridesC0 + j * stridesC1]; 980 | )))))) 981 | }else{ 982 | #pragma omp parallel for collapse(2) 983 | for(int i=0; i < s0; ++i) 984 | for(int j=0; j < s1; ++j) 985 | freeLoops(alpha, stridesA, &dataA[i*stridesA0 + j*stridesA1], stridesB, 
&dataB[i*stridesB0 + j*stridesB1], beta, stridesC, &dataC[i*stridesC0 + j*stridesC1], 986 | sizes, numFreeLoops, numContractedLoops, loop+2); 987 | } 988 | }else if (numParallel == 1){ 989 | const auto s0 = sizes[0]; 990 | const auto stridesC0 = stridesC[0]; 991 | const auto stridesA0 = stridesA[0]; 992 | const auto stridesB0 = stridesB[0]; 993 | #pragma omp parallel for 994 | for(int i=0; i < s0; ++i) 995 | TCL_DUPLICATE(stridesC0==0, 996 | TCL_DUPLICATE(stridesC0==1, 997 | TCL_DUPLICATE(stridesA0==0, 998 | TCL_DUPLICATE(stridesA0==1, 999 | TCL_DUPLICATE(stridesB0==0, 1000 | TCL_DUPLICATE(stridesB0==1, 1001 | if( betaIsZero ) 1002 | dataC[i*stridesC0] = alpha * contractedLoops(stridesA, &dataA[i*stridesA0], 1003 | stridesB, &dataB[i*stridesB0], sizes, loop+1); 1004 | else 1005 | dataC[i*stridesC0] = alpha * contractedLoops(stridesA, &dataA[i*stridesA0], 1006 | stridesB, &dataB[i*stridesB0], sizes, loop+1) + beta * dataC[i*stridesC0]; 1007 | )))))) 1008 | }else if(loop < numFreeLoops){ 1009 | // recurse without spawning threads 1010 | 1011 | const auto stridesC0 = stridesC[loop]; 1012 | const auto stridesA0 = stridesA[loop]; 1013 | const auto stridesB0 = stridesB[loop]; 1014 | for(int i=0; i < sizes[loop]; ++i) 1015 | freeLoops(alpha, stridesA, &dataA[i*stridesA0], stridesB, &dataB[i*stridesB0], beta, stridesC, &dataC[i*stridesC0], 1016 | sizes, numFreeLoops, numContractedLoops, loop+1); 1017 | }else{ 1018 | // recurse without spawning threads 1019 | if( betaIsZero ) 1020 | (*dataC) = alpha * contractedLoops(stridesA, dataA, stridesB, dataB, sizes, loop); 1021 | else 1022 | (*dataC) = alpha * contractedLoops(stridesA, dataA, stridesB, dataB, sizes, loop) + beta * (*dataC); 1023 | } 1024 | } 1025 | } 1026 | 1027 | //! 
generic muliply method 1028 | template 1029 | error contract(const floatType alpha, const Tensor *A, const Tensor *B, const floatType beta, Tensor *C) 1030 | { 1031 | /******************* 1032 | * Determine Loop Order 1033 | ********************/ 1034 | auto loopIndices = intersect(A->getIndices(), intersect(B->getIndices(), C->getIndices())); 1035 | //! also known as the free indices of A 1036 | auto mIndices = setMinus(intersect(A->getIndices(), C->getIndices()), loopIndices); // avoid loop indices 1037 | //! also known as the free indices of B 1038 | auto nIndices = setMinus(intersect(B->getIndices(), C->getIndices()), loopIndices); 1039 | //! also known as the contracted indices 1040 | auto kIndices = setMinus(intersect(A->getIndices(), B->getIndices()), loopIndices); 1041 | 1042 | if( loopIndices.size() > 0 ) 1043 | return TENSOR_CONTRACTION_UNSUPPORTED; 1044 | 1045 | //ensure that a stride-1 indx is the innermost loop if it belongs to the kIndices 1046 | auto posB = findPos(B->getIndices().front(), kIndices); 1047 | auto posA = findPos(A->getIndices().front(), kIndices); 1048 | if( kIndices.size() > 1 && posA == -1 && posB != 0 ) 1049 | { 1050 | int count = 0; 1051 | while(count != posB){ 1052 | auto idx = kIndices.front(); kIndices.pop_front(); 1053 | kIndices.push_back(idx); 1054 | count++; 1055 | } 1056 | } 1057 | 1058 | const int numIndices = loopIndices.size() + mIndices.size() + nIndices.size() + kIndices.size(); 1059 | std::vector sizes(numIndices); 1060 | std::vector stridesA(numIndices); 1061 | std::vector stridesB(numIndices); 1062 | std::vector stridesC(numIndices); 1063 | std::list loopOrder; 1064 | 1065 | auto itA = A->getIndices().cbegin(); 1066 | auto itB = B->getIndices().cbegin(); 1067 | auto itC = C->getIndices().cbegin(); 1068 | int countIndices = 0; 1069 | // keep kIndices as innermost loops 1070 | for( auto idx : kIndices ){ 1071 | loopOrder.push_front(idx); 1072 | sizes[numIndices - countIndices - 1] = A->getSize(idx); 1073 | 
stridesA[numIndices - countIndices - 1] = A->getStride(idx); 1074 | stridesB[numIndices - countIndices - 1] = B->getStride(idx); 1075 | stridesC[numIndices - countIndices - 1] = 0; 1076 | countIndices++; 1077 | } 1078 | while(itA != A->getIndices().cend() || itB != B->getIndices().cend() || itC != C->getIndices().cend() ) 1079 | { 1080 | // add one index of C 1081 | for(; itC != C->getIndices().cend(); itC++) 1082 | { 1083 | if( !find(*itC, loopOrder) ){ 1084 | loopOrder.push_front(*itC); 1085 | sizes[numIndices-1 - countIndices] = C->getSize(*itC); 1086 | stridesA[numIndices - countIndices - 1] = A->getStride(*itC); 1087 | stridesB[numIndices - countIndices - 1] = B->getStride(*itC); 1088 | stridesC[numIndices - countIndices - 1] = C->getStride(*itC); 1089 | countIndices++; 1090 | break; 1091 | } 1092 | } 1093 | // add one index of A 1094 | for(; itA != A->getIndices().cend(); itA++) 1095 | { 1096 | if( !find(*itA, loopOrder) ){ 1097 | loopOrder.push_front(*itA); 1098 | sizes[numIndices - countIndices - 1] = A->getSize(*itA); 1099 | stridesA[numIndices - countIndices - 1] = A->getStride(*itA); 1100 | stridesB[numIndices - countIndices - 1] = B->getStride(*itA); 1101 | stridesC[numIndices - countIndices - 1] = C->getStride(*itA); 1102 | countIndices++; 1103 | break; 1104 | } 1105 | } 1106 | // add one index of B 1107 | for(; itB != B->getIndices().cend(); itB++) 1108 | { 1109 | if( !find(*itB, loopOrder) ){ 1110 | loopOrder.push_front(*itB); 1111 | sizes[numIndices - countIndices - 1] = B->getSize(*itB); 1112 | stridesA[numIndices - countIndices - 1] = A->getStride(*itB); 1113 | stridesB[numIndices - countIndices - 1] = B->getStride(*itB); 1114 | stridesC[numIndices - countIndices - 1] = C->getStride(*itB); 1115 | countIndices++; 1116 | break; 1117 | } 1118 | } 1119 | } 1120 | assert( countIndices == mIndices.size() + nIndices.size() + kIndices.size() + loopIndices.size() ); 1121 | // printVector(loopOrder, "Loop Order"); 1122 | 
/***************************************/ 1123 | 1124 | int numParallelizableLoops = mIndices.size() + nIndices.size() + loopIndices.size(); 1125 | if( numParallelizableLoops >= 2 ) 1126 | if( std::abs(beta) < getZeroThreshold()) 1127 | freeLoops(alpha, stridesA, A->getData(), 1128 | stridesB, B->getData(), 1129 | beta, stridesC, C->getData(), 1130 | sizes, numParallelizableLoops, kIndices.size(), 0); 1131 | else 1132 | freeLoops(alpha, stridesA, A->getData(), 1133 | stridesB, B->getData(), 1134 | beta, stridesC, C->getData(), 1135 | sizes, numParallelizableLoops, kIndices.size(), 0); 1136 | else if( numParallelizableLoops == 1 ) 1137 | if( std::abs(beta) < getZeroThreshold()) 1138 | freeLoops(alpha, stridesA, A->getData(), 1139 | stridesB, B->getData(), 1140 | beta, stridesC, C->getData(), 1141 | sizes, numParallelizableLoops, kIndices.size(), 0); 1142 | else 1143 | freeLoops(alpha, stridesA, A->getData(), 1144 | stridesB, B->getData(), 1145 | beta, stridesC, C->getData(), 1146 | sizes, numParallelizableLoops, kIndices.size(), 0); 1147 | else 1148 | if( std::abs(beta) < getZeroThreshold()) 1149 | freeLoops(alpha, stridesA, A->getData(), 1150 | stridesB, B->getData(), 1151 | beta, stridesC, C->getData(), 1152 | sizes, numParallelizableLoops, kIndices.size(), 0); 1153 | else 1154 | freeLoops(alpha, stridesA, A->getData(), 1155 | stridesB, B->getData(), 1156 | beta, stridesC, C->getData(), 1157 | sizes, numParallelizableLoops, kIndices.size(), 0); 1158 | return SUCCESS; 1159 | } 1160 | 1161 | 1162 | template 1163 | error processLoopedIndices(const floatType alpha, const Tensor *A, 1164 | const Tensor *B, 1165 | const floatType beta, Tensor *C, 1166 | indicesType::iterator it, const indicesType::iterator &end) 1167 | { 1168 | if( it != end ){ 1169 | auto loopIdx = *it; 1170 | auto loopSize = A->getSize(loopIdx); 1171 | auto loopStrideA = A->getStride(loopIdx); 1172 | auto loopStrideB = B->getStride(loopIdx); 1173 | auto loopStrideC = C->getStride(loopIdx); 1174 | 1175 
| const floatType *dataA = A->getData(); 1176 | const floatType *dataB = B->getData(); 1177 | floatType *dataC = C->getData(); 1178 | 1179 | for(int l=0; l < loopSize; ++l){ 1180 | A->setData(dataA + loopStrideA * l); 1181 | B->setData(dataB + loopStrideB * l); 1182 | C->setData(dataC + loopStrideC * l); 1183 | 1184 | processLoopedIndices(alpha, A, B, beta, C, it++, end); 1185 | } 1186 | }else{ 1187 | contractTTGT(alpha, A, B, beta, C); 1188 | } 1189 | return SUCCESS; 1190 | } 1191 | 1192 | template 1193 | error tensorMult(const floatType alpha, const Tensor *A, 1194 | const Tensor *B, 1195 | const floatType beta, Tensor *C) 1196 | { 1197 | #ifdef TIMERS 1198 | auto start = omp_get_wtime(); 1199 | #endif 1200 | auto indicesA = A->getIndices(); 1201 | auto indicesB = B->getIndices(); 1202 | auto indicesC = C->getIndices(); 1203 | 1204 | // error checking 1205 | if( A->getDim() <= 0 || A->getDim() != indicesA.size() ) 1206 | return INVALID_PARAMETER_2; 1207 | if( B->getDim() <= 0 || B->getDim() != indicesB.size() ) 1208 | return INVALID_PARAMETER_4; 1209 | if( C->getDim() <= 0 || C->getDim() != indicesC.size() ) 1210 | return INVALID_PARAMETER_7; 1211 | 1212 | // TODO merge indices 1213 | 1214 | // check for duplicates 1215 | for(auto it = indicesA.begin(); it != indicesA.end(); it++) 1216 | for(auto itt = std::next(it); itt != indicesA.end(); itt++) 1217 | if( *it == *itt ) 1218 | return INVALID_PARAMETER_2; 1219 | for(auto it = indicesB.begin(); it != indicesB.end(); it++) 1220 | for(auto itt = std::next(it); itt != indicesB.end(); itt++) 1221 | if( *it == *itt ) 1222 | return INVALID_PARAMETER_4; 1223 | for(auto it = indicesC.begin(); it != indicesC.end(); it++) 1224 | for(auto itt = std::next(it); itt != indicesC.end(); itt++) 1225 | if( *it == *itt ) 1226 | return INVALID_PARAMETER_7; 1227 | 1228 | // check for correct sizes 1229 | { 1230 | int i = 0; 1231 | for(auto it = indicesC.begin(); it != indicesC.end(); it++, i++){ 1232 | int j = 0; 1233 | for(auto itt = 
indicesA.begin(); itt != indicesA.end(); itt++, j++) 1234 | if( *it == *itt && A->getSize()[j] != C->getSize()[i] ) 1235 | return INVALID_TENSOR_SIZE; 1236 | } 1237 | i = 0; 1238 | for(auto it = indicesC.begin(); it != indicesC.end(); it++, i++){ 1239 | int j = 0; 1240 | for(auto itt = indicesB.begin(); itt != indicesB.end(); itt++, j++) 1241 | if( *it == *itt && B->getSize()[j] != C->getSize()[i] ) 1242 | return INVALID_TENSOR_SIZE; 1243 | } 1244 | i = 0; 1245 | for(auto it = indicesB.begin(); it != indicesB.end(); it++, i++){ 1246 | int j = 0; 1247 | for(auto itt = indicesA.begin(); itt != indicesA.end(); itt++, j++) 1248 | if( *it == *itt && A->getSize()[j] != B->getSize()[i] ) 1249 | return INVALID_TENSOR_SIZE; 1250 | } 1251 | } 1252 | 1253 | // swap A and B if the stride-1 index of C appears in B 1254 | if( find(indicesC.front(), indicesB) && !find(indicesC.front(), indicesA) ) 1255 | std::swap(A, B); 1256 | 1257 | return contractTTGT(alpha, A, B, beta, C); 1258 | } 1259 | 1260 | template error contract(const float alpha, const Tensor *A, const Tensor *B, const float beta, Tensor *C); 1261 | template error contractTTGT(const float alpha, const Tensor *A, const Tensor *B, const float beta, Tensor *C); 1262 | template error tensorMult(const float alpha, const Tensor *A, const Tensor *B, 1263 | const float beta, Tensor *C); 1264 | 1265 | template error contract(const double alpha, const Tensor *A, const Tensor *B, const double beta, Tensor *C); 1266 | template error contractTTGT(const double alpha, const Tensor *A, const Tensor *B, const double beta, Tensor *C); 1267 | template error tensorMult(const double alpha, const Tensor *A, const Tensor *B, 1268 | const double beta, Tensor *C); 1269 | 1270 | template error contract(const FloatComplex alpha, const Tensor *A, const Tensor *B, const FloatComplex beta, Tensor *C); 1271 | template error contractTTGT(const FloatComplex alpha, const Tensor *A, const Tensor *B, const FloatComplex beta, Tensor *C); 1272 | template 
error tensorMult(const FloatComplex alpha, const Tensor *A, const Tensor *B, 1273 | const FloatComplex beta, Tensor *C); 1274 | 1275 | template error contract(const DoubleComplex alpha, const Tensor *A, const Tensor *B, const DoubleComplex beta, Tensor *C); 1276 | template error contractTTGT(const DoubleComplex alpha, const Tensor *A, const Tensor *B, const DoubleComplex beta, Tensor *C); 1277 | template error tensorMult(const DoubleComplex alpha, const Tensor *A, const Tensor *B, 1278 | const DoubleComplex beta, Tensor *C); 1279 | } 1280 | 1281 | extern "C" 1282 | { 1283 | void sTensorMult(const float alpha, const float *dataA, const long *sizeA, const long *outerSizeA, const char* indA, 1284 | const float *dataB, const long *sizeB, const long *outerSizeB, const char* indB, 1285 | const float beta , float *dataC, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor ) 1286 | { 1287 | tcl::indicesType indicesA, indicesB, indicesC; 1288 | tcl::split(std::string(indA), ',', indicesA); 1289 | tcl::split(std::string(indB), ',', indicesB); 1290 | tcl::split(std::string(indC), ',', indicesC); 1291 | if( useRowMajor ){ 1292 | indicesA.reverse(); 1293 | indicesB.reverse(); 1294 | indicesC.reverse(); 1295 | } 1296 | int dimA = indicesA.size(); 1297 | int dimB = indicesB.size(); 1298 | int dimC = indicesC.size(); 1299 | std::vector sizeA_, outerSizeA_; 1300 | std::vector sizeB_, outerSizeB_; 1301 | std::vector sizeC_, outerSizeC_; 1302 | for(int i=0; i < dimA; ++i){ 1303 | int idx = i; 1304 | if( useRowMajor ) 1305 | idx = dimA - i - 1; 1306 | sizeA_.emplace_back(sizeA[idx]); 1307 | if( outerSizeA == nullptr ) 1308 | outerSizeA_.emplace_back(sizeA[idx]); 1309 | else 1310 | outerSizeA_.emplace_back(outerSizeA[idx]); 1311 | } 1312 | for(int i=0; i < dimB; ++i){ 1313 | int idx = i; 1314 | if( useRowMajor ) 1315 | idx = dimB - i - 1; 1316 | sizeB_.emplace_back(sizeB[idx]); 1317 | if( outerSizeB == nullptr ) 1318 | 
outerSizeB_.emplace_back(sizeB[idx]); 1319 | else 1320 | outerSizeB_.emplace_back(outerSizeB[idx]); 1321 | } 1322 | for(int i=0; i < dimC; ++i){ 1323 | int idx = i; 1324 | if( useRowMajor ) 1325 | idx = dimC - i - 1; 1326 | sizeC_.emplace_back(sizeC[idx]); 1327 | if( outerSizeC == nullptr ) 1328 | outerSizeC_.emplace_back(sizeC[idx]); 1329 | else 1330 | outerSizeC_.emplace_back(outerSizeC[idx]); 1331 | } 1332 | std::vector offsets; 1333 | 1334 | tcl::Tensor A( sizeA_, const_cast(dataA), outerSizeA_, indicesA, offsets); 1335 | tcl::Tensor B( sizeB_, const_cast(dataB), outerSizeB_, indicesB, offsets); 1336 | tcl::Tensor C( sizeC_, dataC , outerSizeC_, indicesC, offsets); 1337 | 1338 | if( tcl::tensorMult(alpha, &A, &B, beta, &C) != tcl::SUCCESS ) 1339 | printf("[TCL] ERROR: some error occured in tensorMult()\n"); 1340 | } 1341 | 1342 | void dTensorMult(const double alpha, const double *dataA, const long *sizeA, const long *outerSizeA, const char* indA, 1343 | const double *dataB, const long *sizeB, const long *outerSizeB, const char* indB, 1344 | const double beta , double *dataC, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor ) 1345 | { 1346 | tcl::indicesType indicesA, indicesB, indicesC; 1347 | tcl::split(std::string(indA), ',', indicesA); 1348 | tcl::split(std::string(indB), ',', indicesB); 1349 | tcl::split(std::string(indC), ',', indicesC); 1350 | if( useRowMajor ){ 1351 | indicesA.reverse(); 1352 | indicesB.reverse(); 1353 | indicesC.reverse(); 1354 | } 1355 | int dimA = indicesA.size(); 1356 | int dimB = indicesB.size(); 1357 | int dimC = indicesC.size(); 1358 | std::vector sizeA_, outerSizeA_; 1359 | std::vector sizeB_, outerSizeB_; 1360 | std::vector sizeC_, outerSizeC_; 1361 | for(int i=0; i < dimA; ++i){ 1362 | int idx = i; 1363 | if( useRowMajor ) 1364 | idx = dimA - i - 1; 1365 | sizeA_.emplace_back(sizeA[idx]); 1366 | if( outerSizeA == nullptr ) 1367 | outerSizeA_.emplace_back(sizeA[idx]); 1368 | else 1369 | 
outerSizeA_.emplace_back(outerSizeA[idx]); 1370 | } 1371 | for(int i=0; i < dimB; ++i){ 1372 | int idx = i; 1373 | if( useRowMajor ) 1374 | idx = dimB - i - 1; 1375 | sizeB_.emplace_back(sizeB[idx]); 1376 | if( outerSizeB == nullptr ) 1377 | outerSizeB_.emplace_back(sizeB[idx]); 1378 | else 1379 | outerSizeB_.emplace_back(outerSizeB[idx]); 1380 | } 1381 | for(int i=0; i < dimC; ++i){ 1382 | int idx = i; 1383 | if( useRowMajor ) 1384 | idx = dimC - i - 1; 1385 | sizeC_.emplace_back(sizeC[idx]); 1386 | if( outerSizeB == nullptr ) 1387 | outerSizeC_.emplace_back(sizeC[idx]); 1388 | else 1389 | outerSizeC_.emplace_back(outerSizeC[idx]); 1390 | } 1391 | std::vector offsets; 1392 | 1393 | tcl::Tensor A( sizeA_, const_cast(dataA), outerSizeA_, indicesA, offsets); 1394 | tcl::Tensor B( sizeB_, const_cast(dataB), outerSizeB_, indicesB, offsets); 1395 | tcl::Tensor C( sizeC_, dataC , outerSizeC_, indicesC, offsets); 1396 | 1397 | if( tcl::tensorMult(alpha, &A, &B, beta, &C) != tcl::SUCCESS ) 1398 | printf("[TCL] ERROR: some error occured in tensorMult()\n"); 1399 | } 1400 | 1401 | void cTensorMult(const float _Complex alpha, const float _Complex *dataA, const long *sizeA, const long *outerSizeA, const char* indA, 1402 | const float _Complex *dataB, const long *sizeB, const long *outerSizeB, const char* indB, 1403 | const float _Complex beta , float _Complex *dataC, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor ) 1404 | { 1405 | tcl::indicesType indicesA, indicesB, indicesC; 1406 | tcl::split(std::string(indA), ',', indicesA); 1407 | tcl::split(std::string(indB), ',', indicesB); 1408 | tcl::split(std::string(indC), ',', indicesC); 1409 | if( useRowMajor ){ 1410 | indicesA.reverse(); 1411 | indicesB.reverse(); 1412 | indicesC.reverse(); 1413 | } 1414 | int dimA = indicesA.size(); 1415 | int dimB = indicesB.size(); 1416 | int dimC = indicesC.size(); 1417 | std::vector sizeA_, outerSizeA_; 1418 | std::vector sizeB_, outerSizeB_; 1419 | 
std::vector sizeC_, outerSizeC_; 1420 | for(int i=0; i < dimA; ++i){ 1421 | int idx = i; 1422 | if( useRowMajor ) 1423 | idx = dimA - i - 1; 1424 | sizeA_.emplace_back(sizeA[idx]); 1425 | if( outerSizeA == nullptr ) 1426 | outerSizeA_.emplace_back(sizeA[idx]); 1427 | else 1428 | outerSizeA_.emplace_back(outerSizeA[idx]); 1429 | } 1430 | for(int i=0; i < dimB; ++i){ 1431 | int idx = i; 1432 | if( useRowMajor ) 1433 | idx = dimB - i - 1; 1434 | sizeB_.emplace_back(sizeB[idx]); 1435 | if( outerSizeB == nullptr ) 1436 | outerSizeB_.emplace_back(sizeB[idx]); 1437 | else 1438 | outerSizeB_.emplace_back(outerSizeB[idx]); 1439 | } 1440 | for(int i=0; i < dimC; ++i){ 1441 | int idx = i; 1442 | if( useRowMajor ) 1443 | idx = dimC - i - 1; 1444 | sizeC_.emplace_back(sizeC[idx]); 1445 | if( outerSizeB == nullptr ) 1446 | outerSizeC_.emplace_back(sizeC[idx]); 1447 | else 1448 | outerSizeC_.emplace_back(outerSizeC[idx]); 1449 | } 1450 | std::vector offsets; 1451 | 1452 | tcl::Tensor A( sizeA_, const_cast((const tcl::FloatComplex*)dataA), outerSizeA_, indicesA, offsets); 1453 | tcl::Tensor B( sizeB_, const_cast((const tcl::FloatComplex*)dataB), outerSizeB_, indicesB, offsets); 1454 | tcl::Tensor C( sizeC_, (tcl::FloatComplex*) dataC , outerSizeC_, indicesC, offsets); 1455 | 1456 | if( tcl::tensorMult((const tcl::FloatComplex) alpha, &A, &B, (const tcl::FloatComplex)beta, &C) != tcl::SUCCESS ) 1457 | printf("[TCL] ERROR: some error occured in tensorMult()\n"); 1458 | } 1459 | 1460 | void zTensorMult(const double _Complex alpha, const double _Complex *dataA, const long *sizeA, const long *outerSizeA, const char* indA, 1461 | const double _Complex *dataB, const long *sizeB, const long *outerSizeB, const char* indB, 1462 | const double _Complex beta , double _Complex *dataC, const long *sizeC, const long *outerSizeC, const char* indC, const int useRowMajor ) 1463 | { 1464 | tcl::indicesType indicesA, indicesB, indicesC; 1465 | tcl::split(std::string(indA), ',', indicesA); 1466 | 
void randomNumaAwareInit(float *data, const long *size, int dim)
{
   // Fills 'data' (a dense tensor of 'dim' modes with extents size[0..dim-1]) with a
   // deterministic pattern in [-500, 499].  The parallel first-touch write distributes
   // the pages across NUMA nodes under a first-touch allocation policy.
   long totalSize = 1;
   for(int i = 0; i < dim; i++)
      totalSize *= size[i];
   // BUGFIX: the loop index was 'int', which overflows (UB) once totalSize exceeds
   // INT_MAX; 'long' matches totalSize.  Values are unchanged for all smaller sizes.
#pragma omp parallel for
   for(long i = 0; i < totalSize; ++i)
      data[i] = (i + 1) % 1000 - 500;
}