├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── cudarray ├── .gitignore ├── __init__.py ├── base.py ├── batch │ ├── __init__.py │ └── linalg.py ├── cudarray.py ├── elementwise.py ├── extra │ ├── __init__.py │ └── array.py ├── helpers.py ├── linalg.py ├── nnet │ ├── __init__.py │ ├── conv.py │ ├── image.py │ ├── math.py │ ├── pool.py │ └── special.py ├── numpy_backend │ ├── __init__.py │ └── nnet │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── conv.py │ │ ├── conv_bc01.pyx │ │ ├── lrnorm_bc01.pyx │ │ ├── pool.py │ │ ├── pool_bc01.pyx │ │ └── special.py ├── random.py ├── reduction.py └── wrap │ ├── __init__.py │ ├── array_data.pxd │ ├── array_data.pyx │ ├── array_ops.pxd │ ├── array_ops.pyx │ ├── blas.pxd │ ├── blas.pyx │ ├── cudart.pxd │ ├── cudart.pyx │ ├── cudnn.pxd │ ├── cudnn.pyx │ ├── elementwise.pxd │ ├── elementwise.pyx │ ├── image.pxd │ ├── image.pyx │ ├── nnet.pxd │ ├── nnet.pyx │ ├── random.pxd │ ├── random.pyx │ ├── reduction.pxd │ └── reduction.pyx ├── examples ├── benchmark_conv.py └── test.py ├── include └── cudarray │ ├── array_ops.hpp │ ├── blas.hpp │ ├── common.hpp │ ├── elementwise.hpp │ ├── image │ ├── img2win.hpp │ └── rescale.hpp │ ├── nnet │ ├── conv_bc01_matmul.hpp │ ├── cudnn.hpp │ ├── one_hot.hpp │ └── pool_b01.hpp │ ├── random.hpp │ └── reduction.hpp ├── requirements.txt ├── setup.py └── src ├── array_ops.cu ├── blas.cpp ├── elementwise.cu ├── image ├── img2win.cu └── rescale.cu ├── nnet ├── conv_bc01_matmul.cpp ├── cudnn.cpp ├── one_hot.cu └── pool_b01.cu ├── random.cu └── reduction.cu /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | *.o 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .cache 41 | nosetests.xml 42 | coverage.xml 43 | 44 | # Translations 45 | *.mo 46 | *.pot 47 | 48 | # Django stuff: 49 | *.log 50 | 51 | # Sphinx documentation 52 | docs/_build/ 53 | 54 | # PyBuilder 55 | target/ 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Anders Boesen Lindbo Larsen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 
21 | SOFTWARE. 
22 | 
23 | 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 
1 | ifndef CUDA_PREFIX 
2 | CUDA_PREFIX = /usr/local/cuda 
3 | endif 
4 | ifndef INSTALL_PREFIX 
5 | INSTALL_PREFIX=/usr/local 
6 | endif 
7 | 
8 | 
9 | SRC_DIR = ./src 
10 | 
11 | SRCS = $(SRC_DIR)/nnet/conv_bc01_matmul.cpp \ 
12 | $(SRC_DIR)/blas.cpp \ 
13 | $(SRC_DIR)/nnet/cudnn.cpp 
14 | 
15 | CUDA_SRCS = $(SRC_DIR)/array_ops.cu \ 
16 | $(SRC_DIR)/elementwise.cu \ 
17 | $(SRC_DIR)/reduction.cu \ 
18 | $(SRC_DIR)/nnet/pool_b01.cu \ 
19 | $(SRC_DIR)/random.cu \ 
20 | $(SRC_DIR)/image/img2win.cu \ 
21 | $(SRC_DIR)/image/rescale.cu \ 
22 | $(SRC_DIR)/nnet/one_hot.cu 
23 | 
24 | 
25 | INCLUDE_DIRS = ./include 
26 | INCLUDE_DIRS += $(CUDA_PREFIX)/include 
27 | 
28 | ifneq ($(wildcard $(CUDA_PREFIX)/lib64),) 
29 | # Use lib64 if it exists 
30 | LIB_DIRS += $(CUDA_PREFIX)/lib64 
31 | endif 
32 | LIB_DIRS += $(CUDA_PREFIX)/lib 
33 | LIBS += cudart cublas cufft curand 
34 | 
35 | ifeq ($(CUDNN_ENABLED), 1) 
36 | C_FLAGS += -DCUDNN_ENABLED 
37 | LIBS += cudnn 
38 | endif 
39 | 
40 | ifndef CUDA_ARCH 
41 | # By default, libcudarray is built for a range of different CUDA 
42 | # architectures. You can speed up compilation time by selecting only the 
43 | # architecture for your GPU. 
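# For example, a hypothetical single-architecture build for a GPU of compute
# capability 3.5 (other capabilities work the same way) could skip the default
# list below by overriding the variable on the command line:
#
#   make CUDA_ARCH="-gencode arch=compute_35,code=sm_35"
#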
44 | CUDA_ARCH = -gencode arch=compute_20,code=sm_20 \ 
45 | -gencode arch=compute_20,code=compute_20 \ 
46 | -gencode arch=compute_30,code=sm_30 \ 
47 | -gencode arch=compute_30,code=compute_30 \ 
48 | -gencode arch=compute_35,code=sm_35 \ 
49 | -gencode arch=compute_35,code=compute_35 
50 | endif 
51 | 
52 | export PATH := $(CUDA_PREFIX)/bin:$(PATH) 
53 | 
54 | CXX = g++ 
55 | NVCC = nvcc 
56 | BUILD_DIR = ./build 
57 | OBJS = $(SRCS:.cpp=.o) $(CUDA_SRCS:.cu=.o) 
58 | LIBCUDARRAY = libcudarray.so 
59 | LIBCUDARRAY_BUILD = $(BUILD_DIR)/$(LIBCUDARRAY) 
60 | LIBCUDARRAY_INSTALL = $(INSTALL_PREFIX)/lib/$(LIBCUDARRAY) 
61 | 
62 | INCLUDES += $(foreach include_dir,$(INCLUDE_DIRS),-I$(include_dir)) 
63 | C_FLAGS += -O3 -fPIC -Wall -Wfatal-errors -D_FORCE_INLINES 
64 | NVCC_FLAGS = $(CUDA_ARCH) -O3 --compiler-options '$(C_FLAGS)' \ 
65 | --ftz=true --prec-div=false -prec-sqrt=false --fmad=true 
66 | LDFLAGS += $(foreach lib_dir,$(LIB_DIRS),-L$(lib_dir)) \ 
67 | $(foreach lib,$(LIBS),-l$(lib)) 
68 | 
69 | 
70 | $(LIBCUDARRAY_BUILD) : $(OBJS) 
71 | mkdir -p $(BUILD_DIR) 
72 | $(CXX) -shared $(C_FLAGS) -o $@ $^ $(LDFLAGS) 
73 | 
74 | %.o : %.cpp 
75 | $(CXX) $(C_FLAGS) $(INCLUDES) -c -o $@ $< 
76 | 
77 | %.o : %.cu 
78 | $(NVCC) $(NVCC_FLAGS) $(INCLUDES) -c -o $@ $< 
79 | 
80 | all: $(LIBCUDARRAY_BUILD) 
81 | 
82 | $(LIBCUDARRAY_INSTALL) : $(LIBCUDARRAY_BUILD) 
83 | cp $(LIBCUDARRAY_BUILD) $(LIBCUDARRAY_INSTALL) 
84 | 
85 | install: $(INSTALL_PREFIX)/lib/$(LIBCUDARRAY) 
86 | 
87 | uninstall: 
88 | rm $(LIBCUDARRAY_INSTALL) 
89 | 
90 | .PHONY: clean 
91 | clean: 
92 | rm -f $(OBJS) $(LIBCUDARRAY_BUILD) 
93 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 
1 | ## CUDA-based NumPy 
2 | 
3 | CUDArray is a CUDA-accelerated subset of the [NumPy](http://www.numpy.org/) library. 
4 | The goal of CUDArray is to combine the ease of development of NumPy with the computational power of Nvidia GPUs in a lightweight and extensible framework. 
5 | 
6 | CUDArray currently imposes many limitations in order to span a manageable subset of the NumPy library. 
7 | Nonetheless, it supports a neural network pipeline as demonstrated in the project [deeppy](http://github.com/andersbll/deeppy/). 
8 | 
9 | 
10 | ### Features 
11 | - Drop-in replacement for NumPy (limitations apply). 
12 | - Fast array operations based on cuBLAS, cuRAND and cuDNN. 
13 | - (somewhat) Simple C++/CUDA wrapper based on Cython. 
14 | - Extends NumPy with specialized functions for neural networks. 
15 | - CPU fall-back when CUDA is not available. 
16 | 
17 | 
18 | ### Installation 
19 | ##### With CUDA back-end 
20 | First, you should consider specifying the following environment variables. 
21 | - `INSTALL_PREFIX` (default: `/usr/local`). Installation path for libcudarray. For the Anaconda Python distribution this should be `/path/to/anaconda`. 
22 | - `CUDA_PREFIX` (default: `/usr/local/cuda`). Path to the CUDA SDK organized in `bin/`, `lib/`, `include/` folders. 
23 | - `CUDNN_ENABLED`. Set `CUDNN_ENABLED` to `1` to include cuDNN operations in `libcudarray`. 
24 | 
25 | Then build and install libcudarray with 
26 | 
27 | make 
28 | make install 
29 | 
30 | Finally, install the cudarray Python package: 
31 | 
32 | python setup.py install 
33 | 
34 | 
35 | ##### Without CUDA back-end 
36 | Install the cudarray Python package: 
37 | 
38 | python setup.py --without-cuda install 
39 | 
40 | 
41 | ### Documentation 
42 | Please consult the [technical report][techreport] for now. 
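In the meantime, the snippet below is a minimal usage sketch (not taken from the report; it only exercises functions defined in this repository, and the random input matrix is just an example):

    import numpy as np
    import cudarray as ca

    a = ca.array(np.random.randn(3, 3))   # host -> device copy, cast to float32
    b = ca.ones_like(a)
    c = ca.dot(a, b) + 2*a                 # cuBLAS matmul plus elementwise arithmetic
    print(np.array(c))                     # device -> host copy for printing

Setting `CUDARRAY_BACKEND=numpy` before importing runs the same code on the NumPy back-end (see `cudarray/__init__.py`). 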
43 | Proper documentation is on the TODO list. 44 | 45 | 46 | ### Contact 47 | Feel free to report an [issue](http://github.com/andersbll/cudarray/issues) for feature requests and bug reports. 48 | 49 | For a more informal chat, visit #cudarray on the [freenode](http://freenode.net/) IRC network. 50 | 51 | 52 | ### Citation 53 | If you use CUDArray for research, please cite the [technical report][techreport]: 54 | 55 | @techreport{larsen2014cudarray, 56 | author = "Larsen, Anders Boesen Lindbo", 57 | title = "{CUDArray}: {CUDA}-based {NumPy}", 58 | institution = "Department of Applied Mathematics and Computer Science, Technical University of Denmark", 59 | year = "2014", 60 | number = "DTU Compute 2014-21", 61 | } 62 | 63 | 64 | ### TODO 65 | - Proper transpose support, 66 | - Add functionality for copying from NumPy array to existing CUDArray array. 67 | - FFT module based on cuFFT. 68 | - Unit tests! 69 | - Add documentation to wiki. 70 | - Windows/OS X support. 71 | 72 | 73 | ### Influences 74 | Thanks to the following projects for inspiration. 75 | - [cudamat](http://github.com/cudamat/cudamat) 76 | - [PyCUDA](http://mathema.tician.de/software/pycuda/) 77 | - [mshadow](http://github.com/tqchen/mshadow/) 78 | - [Caffe](http://caffe.berkeleyvision.org/) 79 | - [CUDPP](http://cudpp.github.io/) 80 | 81 | 82 | [techreport]: http://www2.compute.dtu.dk/~abll/pubs/larsen2014cudarray.pdf 83 | -------------------------------------------------------------------------------- /cudarray/.gitignore: -------------------------------------------------------------------------------- 1 | # Cython generated C files 2 | *.c 3 | *.cpp 4 | -------------------------------------------------------------------------------- /cudarray/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _gpu_id = int(os.getenv('CUDARRAY_DEVICE', '0')) 4 | 5 | _backend = os.getenv('CUDARRAY_BACKEND') 6 | if _backend is not None: 7 | _force_backend = True 8 | _backend = _backend.lower() 9 | _valid_backends = ['numpy', 'cuda'] 10 | if _backend not in _valid_backends: 11 | raise RuntimeError('Invalid back-end "%s" specified. Valid options ' 12 | 'are: %s' % (_backend, _valid_backends)) 13 | else: 14 | # If no back-end specified, try CUDA with NumPy fall-back. 15 | _force_backend = False 16 | _backend = 'cuda' 17 | 18 | if _backend == 'cuda': 19 | try: 20 | from .cudarray import * 21 | from .base import * 22 | from .linalg import * 23 | from .elementwise import * 24 | from .reduction import * 25 | from . import random 26 | from . import nnet 27 | from . import batch 28 | from . 
import extra 29 | wrap.cudart.initialize(_gpu_id) 30 | except: 31 | if _force_backend: 32 | print('CUDArray: Failed to load CUDA back-end.') 33 | raise 34 | else: 35 | print('CUDArray: CUDA back-end not available, using NumPy.') 36 | # Try NumPy instead 37 | _backend = 'numpy' 38 | 39 | if _backend == 'numpy': 40 | from .numpy_backend import * 41 | 42 | 43 | __version__ = '0.1.dev' 44 | -------------------------------------------------------------------------------- /cudarray/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cudarray 3 | from .wrap import array_ops 4 | 5 | 6 | def transpose(a): 7 | if a.ndim != 2: 8 | raise ValueError('transpose is implemented for 2D arrays only') 9 | a_trans = a.view() 10 | a_trans.shape = (a.shape[1], a.shape[0]) 11 | a_trans.transposed = True 12 | return a_trans 13 | 14 | 15 | def reshape(a, newshape): 16 | a = ascontiguousarray(a) 17 | size = a.size 18 | if isinstance(newshape, int): 19 | newshape = (newshape,) 20 | newsize = np.prod(newshape) 21 | if size != newsize: 22 | if newsize < 0: 23 | # negative newsize means there is a -1 in newshape 24 | newshape = list(newshape) 25 | newshape[newshape.index(-1)] = -size // newsize 26 | newshape = tuple(newshape) 27 | else: 28 | raise ValueError('cannot reshape %s to %s' % (a.shape, newshape)) 29 | a_reshaped = a.view() 30 | a_reshaped.shape = newshape 31 | return a_reshaped 32 | 33 | 34 | def copyto(dst, src): 35 | if src.shape != dst.shape: 36 | raise ValueError('out.shape does not match result') 37 | if src.dtype != dst.dtype: 38 | raise ValueError('dtype mismatch') 39 | n = src.size 40 | if isinstance(src, np.ndarray): 41 | if isinstance(dst, np.ndarray): 42 | np.copyto(dst, src) 43 | else: 44 | src = np.ascontiguousarray(src) 45 | array_ops._to_device(src, n, dst._data) 46 | else: 47 | src = ascontiguousarray(src) 48 | if isinstance(dst, np.ndarray): 49 | array_ops._to_host(src._data, n, dst) 50 | else: 51 | array_ops._copy(src._data, n, dst._data) 52 | 53 | 54 | def ascontiguousarray(a): 55 | if not a.transposed: 56 | return a 57 | out = cudarray.empty_like(a) 58 | n, m = a.shape 59 | array_ops._transpose(a._data, m, n, out._data) 60 | return out 61 | 62 | 63 | bool_ = np.int32 64 | int_ = np.int32 65 | float_ = np.float32 66 | -------------------------------------------------------------------------------- /cudarray/batch/__init__.py: -------------------------------------------------------------------------------- 1 | from .linalg import Dot, dot 2 | -------------------------------------------------------------------------------- /cudarray/batch/linalg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cudarray as ca 3 | from ..wrap import blas 4 | from ..linalg import matmul_shape 5 | 6 | 7 | class Dot(object): 8 | def __init__(self, a, b, out=None): 9 | self.batch_size = a.shape[0] 10 | self.a = a 11 | self.b = b 12 | if a.dtype != b.dtype: 13 | raise ValueError('dtype mismatch') 14 | out_shape = (self.batch_size,) + matmul_shape(a.shape[1:], b.shape[1:]) 15 | if out is None: 16 | out = base.empty(out_shape, dtype=a.dtype) 17 | else: 18 | if out_shape != out.shape: 19 | raise ValueError('out.shape does not match result') 20 | if a.dtype != out.dtype: 21 | raise ValueError('dtype mismatch') 22 | self.out = out 23 | a_stride = np.prod(a.shape[1:]) 24 | b_stride = np.prod(b.shape[1:]) 25 | out_stride = np.prod(out.shape[1:]) 26 | self.blas_batch = 
blas.BLASBatch_f( 27 | a._data, b._data, out._data, self.batch_size, a_stride, b_stride, 28 | out_stride 29 | ) 30 | if a.ndim == b.ndim == 3: 31 | m, k = a.shape[1:3] 32 | n = b.shape[2] 33 | 34 | def fun(): 35 | self.blas_batch.gemm(blas.no_trans_op, blas.no_trans_op, m, n, 36 | k, 1.0, 0.0) 37 | return self.out 38 | self.perform = fun 39 | else: 40 | raise ValueError('invalid array dimensionality') 41 | 42 | 43 | def dot(a, b, out=None): 44 | return Dot(a, b, out).perform() 45 | -------------------------------------------------------------------------------- /cudarray/cudarray.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .wrap.array_data import ArrayData 3 | from .wrap import array_ops 4 | from . import elementwise 5 | from . import base 6 | from . import helpers 7 | 8 | 9 | class ndarray(object): 10 | def __init__(self, shape, dtype=None, np_data=None, array_data=None, 11 | array_owner=None): 12 | shape = helpers.require_iterable(shape) 13 | if shape == (): 14 | shape = (1,) 15 | self.shape = shape 16 | self.transposed = False 17 | self.isbool = False 18 | if dtype is None: 19 | if np_data is None: 20 | dtype = np.dtype(base.float_) 21 | else: 22 | dtype = np_data.dtype 23 | if dtype == np.dtype('float64'): 24 | dtype = np.dtype(base.float_) 25 | elif dtype == np.dtype('int64'): 26 | dtype = np.dtype(base.int_) 27 | elif dtype == np.dtype('bool'): 28 | dtype = np.dtype(base.bool_) 29 | self.isbool = True 30 | else: 31 | dtype = np.dtype(dtype) 32 | 33 | if np_data is not None: 34 | np_data = np.require(np_data, dtype=dtype, requirements='C') 35 | if array_data is None: 36 | self._data = ArrayData(self.size, dtype, np_data) 37 | else: 38 | self._data = array_data 39 | 40 | def __array__(self): 41 | np_array = np.empty(self.shape, dtype=self.dtype) 42 | self._data.to_numpy(np_array) 43 | if self.isbool: 44 | np_array = np_array.astype(np.dtype('bool')) 45 | return np_array 46 | 47 | def __str__(self): 48 | return self.__array__().__str__() 49 | 50 | def __repr__(self): 51 | return self.__array__().__repr__() 52 | 53 | def _same_array(self, other): 54 | return self.data == other.data 55 | 56 | @property 57 | def data(self): 58 | return self._data.data 59 | 60 | @property 61 | def dtype(self): 62 | return self._data.dtype 63 | 64 | @property 65 | def itemsize(self): 66 | return self._data.dtype.itemsize 67 | 68 | @property 69 | def nbytes(self): 70 | return self.size*self.itemsize 71 | 72 | @property 73 | def ndim(self): 74 | return len(self.shape) 75 | 76 | @property 77 | def size(self): 78 | return helpers.prod(self.shape) 79 | 80 | @property 81 | def T(self): 82 | return base.transpose(self) 83 | 84 | def view(self): 85 | return ndarray(self.shape, self.dtype, None, self._data) 86 | 87 | def fill(self, value): 88 | array_ops._fill(self._data, self.size, value) 89 | 90 | def __len__(self): 91 | return self.shape[0] 92 | 93 | def __add__(self, other): 94 | return elementwise.add(self, other) 95 | 96 | def __radd__(self, other): 97 | return elementwise.add(other, self) 98 | 99 | def __iadd__(self, other): 100 | return elementwise.add(self, other, self) 101 | 102 | def __sub__(self, other): 103 | return elementwise.subtract(self, other) 104 | 105 | def __rsub__(self, other): 106 | return elementwise.subtract(other, self) 107 | 108 | def __isub__(self, other): 109 | return elementwise.subtract(self, other, self) 110 | 111 | def __mul__(self, other): 112 | return elementwise.multiply(self, other) 113 | 114 | def 
__rmul__(self, other): 115 | return elementwise.multiply(other, self) 116 | 117 | def __imul__(self, other): 118 | return elementwise.multiply(self, other, self) 119 | 120 | def __div__(self, other): 121 | return elementwise.divide(self, other) 122 | 123 | def __rdiv__(self, other): 124 | return elementwise.divide(other, self) 125 | 126 | def __idiv__(self, other): 127 | return elementwise.divide(self, other, self) 128 | 129 | def __truediv__(self, other): 130 | return elementwise.divide(self, other) 131 | 132 | def __rtruediv__(self, other): 133 | return elementwise.divide(other, self) 134 | 135 | def __itruediv__(self, other): 136 | return elementwise.divide(self, other, self) 137 | 138 | def __pow__(self, other): 139 | return elementwise.power(self, other) 140 | 141 | def __rpow__(self, other): 142 | return elementwise.power(other, self) 143 | 144 | def __ipow__(self, other): 145 | return elementwise.power(self, other, self) 146 | 147 | def __eq__(self, other): 148 | return elementwise.equal(self, other) 149 | 150 | def __gt__(self, other): 151 | return elementwise.greater(self, other) 152 | 153 | def __ge__(self, other): 154 | return elementwise.greater_equal(self, other) 155 | 156 | def __lt__(self, other): 157 | return elementwise.less(self, other) 158 | 159 | def __le__(self, other): 160 | return elementwise.less_equal(self, other) 161 | 162 | def __ne__(self, other): 163 | return elementwise.not_equal(self, other) 164 | 165 | def __neg__(self): 166 | return elementwise.negative(self) 167 | 168 | def __ineg__(self): 169 | return elementwise.negative(self, self) 170 | 171 | def __getitem__(self, indices): 172 | if isinstance(indices, int): 173 | # Speedup case with a single index 174 | view_shape = self.shape[1:] 175 | view_size = helpers.prod(view_shape) 176 | offset = indices * view_size 177 | data_view = ArrayData(view_size, self.dtype, owner=self._data, 178 | offset=offset) 179 | return ndarray(view_shape, self.dtype, np_data=None, 180 | array_data=data_view) 181 | elif isinstance(indices, slice): 182 | indices = (indices,) 183 | # Standardize indices to a list of slices 184 | elif len(indices) > len(self.shape): 185 | raise IndexError('too many indices for array') 186 | 187 | view_shape = [] 188 | rest_must_be_contiguous = False 189 | offset = 0 190 | for i, dim in enumerate(self.shape): 191 | start = 0 192 | stop = dim 193 | append_dim = True 194 | if i < len(indices): 195 | idx = indices[i] 196 | if isinstance(idx, int): 197 | append_dim = False 198 | start = idx 199 | stop = idx+1 200 | elif isinstance(idx, slice): 201 | if idx.start is not None: 202 | start = idx.start 203 | if idx.stop is not None: 204 | stop = idx.stop 205 | if idx.step is not None: 206 | raise NotImplementedError('only contiguous indices ' 207 | + 'are supported') 208 | elif idx is Ellipsis: 209 | diff = self.ndim - len(indices) 210 | indices = indices[:i] + [slice(None)]*diff + indices[i:] 211 | return self[indices] 212 | else: 213 | raise IndexError('only integers, slices and ellipsis are ' 214 | + 'valid indices') 215 | 216 | view_dim = stop-start 217 | offset = offset * dim + start 218 | if append_dim: 219 | view_shape.append(view_dim) 220 | if rest_must_be_contiguous and view_dim < dim: 221 | raise NotImplementedError('only contiguous indices are ' 222 | + 'supported') 223 | if view_dim > 1: 224 | rest_must_be_contiguous = True 225 | 226 | view_shape = tuple(view_shape) 227 | view_size = helpers.prod(view_shape) 228 | 229 | # Construct view 230 | data_view = ArrayData(view_size, self.dtype, 
owner=self._data, 231 | offset=offset) 232 | return ndarray(view_shape, self.dtype, np_data=None, 233 | array_data=data_view) 234 | 235 | def __setitem__(self, indices, c): 236 | view = self.__getitem__(indices) 237 | base.copyto(view, c) 238 | 239 | 240 | def array(object, dtype=None, copy=True): 241 | np_array = np.array(object) 242 | return ndarray(np_array.shape, np_data=np_array) 243 | 244 | 245 | def empty(shape, dtype=None): 246 | return ndarray(shape, dtype=dtype) 247 | 248 | 249 | def empty_like(a, dtype=None): 250 | if not isinstance(a, (np.ndarray, ndarray)): 251 | a = np.array(a) 252 | return ndarray(a.shape, dtype=a.dtype) 253 | 254 | 255 | def ones(shape, dtype=None): 256 | return array(np.ones(shape, dtype=dtype)) 257 | 258 | 259 | def ones_like(a, dtype=None): 260 | if not isinstance(a, (np.ndarray, ndarray)): 261 | a = np.array(a) 262 | return array(np.ones_like(a, dtype=dtype)) 263 | 264 | 265 | def zeros(shape, dtype=None): 266 | a = empty(shape, dtype) 267 | a.fill(0) 268 | return a 269 | 270 | 271 | def zeros_like(a, dtype=None): 272 | if not isinstance(a, (np.ndarray, ndarray)): 273 | a = np.array(a) 274 | return array(np.zeros_like(a, dtype=dtype)) 275 | -------------------------------------------------------------------------------- /cudarray/elementwise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cudarray 3 | from .wrap import elementwise 4 | from . import helpers 5 | from . import base 6 | 7 | 8 | def broadcast_type(shape1, shape2): 9 | if shape1 == shape2: 10 | return None 11 | 12 | error = ValueError('operands could not be broadcast together with shapes ' 13 | + str(shape1) + ' ' + str(shape2)) 14 | 15 | # Bring shapes to same length by setting missing trailing dimensions to 1's 16 | len_diff = len(shape1) - len(shape2) 17 | if len_diff > 0: 18 | shape2 = (1,)*len_diff + shape2 19 | 20 | # Find out which axes to broadcast 21 | b_axes = [] 22 | for a_idx, (a1, a2) in enumerate(zip(shape1, shape2)): 23 | if a1 != a2: 24 | if a2 == 1: 25 | b_axes.append(a_idx) 26 | else: 27 | raise error 28 | 29 | ndim = len(shape1) 30 | # Detect leading broadcast 31 | if b_axes == list(range(len(b_axes))): 32 | k = 1 33 | m = helpers.prod([shape1[a] for a in b_axes]) 34 | n = helpers.prod(shape1) // m 35 | return elementwise.btype_leading, k, m, n 36 | # Detect trailing broadcast 37 | if b_axes == list(range(ndim-len(b_axes), ndim)): 38 | k = 1 39 | m = helpers.prod([shape1[a] for a in b_axes]) 40 | n = helpers.prod(shape1) // m 41 | return elementwise.btype_trailing, k, m, n 42 | # Detect inner broadcast 43 | if b_axes == list(range(b_axes[0], b_axes[0] + len(b_axes))): 44 | k = helpers.prod(shape1[:b_axes[0]]) 45 | m = helpers.prod(shape1[b_axes[0]:b_axes[-1]+1]) 46 | n = helpers.prod(shape1[b_axes[-1]+1:]) 47 | return elementwise.btype_inner, k, m, n 48 | # Detect outer broadcast 49 | for i in range(1, len(b_axes)): 50 | if b_axes[i-1] + 1 != b_axes[i]: 51 | split_idx = i 52 | break 53 | b_axes_leading = b_axes[:split_idx] 54 | b_axes_trailing = b_axes[split_idx:] 55 | if (b_axes_leading == list(range(len(b_axes_leading))) 56 | and b_axes_trailing == list(range(ndim-len(b_axes_trailing), 57 | ndim))): 58 | k = helpers.prod(shape1[:b_axes_leading[-1]+1]) 59 | m = helpers.prod(shape1[b_axes_leading[-1]+1:b_axes_trailing[0]]) 60 | n = helpers.prod(shape1[b_axes_trailing[0]:]) 61 | return elementwise.btype_outer, k, m, n 62 | 63 | raise error 64 | 65 | 66 | def binary(op, x1, x2, out=None, cmp_op=False): 67 
| if np.isscalar(x1) or np.isscalar(x2): 68 | if np.isscalar(x1): 69 | flip = True 70 | scalar = x1 71 | array = x2 72 | else: 73 | flip = False 74 | array = x1 75 | scalar = x2 76 | 77 | array = base.ascontiguousarray(array) 78 | 79 | if (array.dtype == np.dtype('int32') and isinstance(scalar, (int)) 80 | or cmp_op): 81 | out_dtype = np.dtype('int32') 82 | else: 83 | out_dtype = np.dtype('float32') 84 | 85 | # Create/check output array 86 | if out is None: 87 | out = cudarray.empty(array.shape, dtype=out_dtype) 88 | else: 89 | if out.shape != array.shape: 90 | raise ValueError('out.shape does not match result') 91 | if out.dtype != out_dtype: 92 | raise ValueError('dtype mismatch') 93 | n = array.size 94 | if cmp_op: 95 | elementwise._binary_cmp_scalar(op, array._data, scalar, n, 96 | out._data, flip) 97 | else: 98 | elementwise._binary_scalar(op, array._data, scalar, n, out._data, 99 | flip) 100 | return out 101 | 102 | x1 = base.ascontiguousarray(x1) 103 | x2 = base.ascontiguousarray(x2) 104 | 105 | if x1.dtype == x2.dtype == np.dtype('int32') or cmp_op: 106 | out_dtype = np.dtype('int32') 107 | else: 108 | out_dtype = np.dtype('float32') 109 | 110 | # Create/check output array 111 | if x1.size < x2.size: 112 | flip = True 113 | x1, x2 = x2, x1 114 | else: 115 | flip = False 116 | if out is None: 117 | out = cudarray.empty(x1.shape, dtype=out_dtype) 118 | else: 119 | if out.shape != x1.shape: 120 | raise ValueError('out.shape does not match result') 121 | if out.dtype != out_dtype: 122 | raise ValueError('dtype mismatch') 123 | 124 | btype = broadcast_type(x1.shape, x2.shape) 125 | if btype is None: 126 | n = x1.size 127 | if cmp_op: 128 | elementwise._binary_cmp(op, x1._data, x2._data, n, out._data) 129 | else: 130 | elementwise._binary(op, x1._data, x2._data, n, out._data) 131 | return out 132 | else: 133 | btype, k, m, n = btype 134 | if cmp_op: 135 | if flip and op in [elementwise.lt_op, elementwise.gt_op, 136 | elementwise.lt_eq_op, elementwise.gt_eq_op]: 137 | raise NotImplementedError( 138 | 'Broadcast of non-commutative operations not supported' 139 | ) 140 | 141 | elementwise._binary_cmp_broadcast(op, btype, x1._data, x2._data, 142 | k, m, n, out._data) 143 | else: 144 | if flip and op in [elementwise.sub_op, elementwise.div_op, 145 | elementwise.pow_op]: 146 | raise NotImplementedError( 147 | 'Broadcast of non-commutative operations not supported' 148 | ) 149 | elementwise._binary_broadcast(op, btype, x1._data, x2._data, k, m, 150 | n, out._data) 151 | return out 152 | 153 | 154 | def add(x1, x2, out=None): 155 | return binary(elementwise.add_op, x1, x2, out) 156 | 157 | 158 | def subtract(x1, x2, out=None): 159 | return binary(elementwise.sub_op, x1, x2, out) 160 | 161 | 162 | def multiply(x1, x2, out=None): 163 | return binary(elementwise.mul_op, x1, x2, out) 164 | 165 | 166 | def divide(x1, x2, out=None): 167 | return binary(elementwise.div_op, x1, x2, out) 168 | 169 | 170 | def power(x1, x2, out=None): 171 | return binary(elementwise.pow_op, x1, x2, out) 172 | 173 | 174 | def maximum(x1, x2, out=None): 175 | return binary(elementwise.max_op, x1, x2, out) 176 | 177 | 178 | def minimum(x1, x2, out=None): 179 | return binary(elementwise.min_op, x1, x2, out) 180 | 181 | 182 | def equal(x1, x2, out=None): 183 | return binary(elementwise.eq_op, x1, x2, out, True) 184 | 185 | 186 | def greater(x1, x2, out=None): 187 | return binary(elementwise.gt_op, x1, x2, out, True) 188 | 189 | 190 | def greater_equal(x1, x2, out=None): 191 | return binary(elementwise.gt_eq_op, x1, x2, out, 
True) 192 | 193 | 194 | def less(x1, x2, out=None): 195 | return binary(elementwise.lt_op, x1, x2, out, True) 196 | 197 | 198 | def less_equal(x1, x2, out=None): 199 | return binary(elementwise.lt_eq_op, x1, x2, out, True) 200 | 201 | 202 | def not_equal(x1, x2, out=None): 203 | return binary(elementwise.neq_op, x1, x2, out, True) 204 | 205 | 206 | def unary(op, x, out=None): 207 | x = base.ascontiguousarray(x) 208 | out_shape = x.shape 209 | if out is None: 210 | out = cudarray.empty(out_shape, dtype=x.dtype) 211 | else: 212 | if not out_shape == out.shape: 213 | raise ValueError('out.shape does not match result') 214 | if not x.dtype == out.dtype: 215 | raise ValueError('dtype mismatch') 216 | n = x.size 217 | elementwise._unary(op, x._data, n, out._data) 218 | return out 219 | 220 | 221 | def absolute(x, out=None): 222 | return unary(elementwise.abs_op, x, out) 223 | fabs = absolute 224 | 225 | 226 | def cos(x, out=None): 227 | return unary(elementwise.cos_op, x, out) 228 | 229 | 230 | def exp(x, out=None): 231 | return unary(elementwise.exp_op, x, out) 232 | 233 | 234 | def fabs(x, out=None): 235 | return unary(elementwise.abs_op, x, out) 236 | 237 | 238 | def log(x, out=None): 239 | return unary(elementwise.log_op, x, out) 240 | 241 | 242 | def log1p(x, out=None): 243 | return unary(elementwise.log1p_op, x, out) 244 | 245 | 246 | def negative(x, out=None): 247 | return unary(elementwise.neg_op, x, out) 248 | 249 | 250 | def sin(x, out=None): 251 | return unary(elementwise.sin_op, x, out) 252 | 253 | 254 | def sqrt(x, out=None): 255 | return unary(elementwise.sqrt_op, x, out) 256 | 257 | 258 | def tanh(x, out=None): 259 | return unary(elementwise.tanh_op, x, out) 260 | 261 | 262 | def clip(a, a_min, a_max, out=None): 263 | a = base.ascontiguousarray(a) 264 | out_shape = a.shape 265 | if out is None: 266 | out = cudarray.empty(out_shape, dtype=a.dtype) 267 | else: 268 | if not out_shape == out.shape: 269 | raise ValueError('out.shape does not match result') 270 | if not a.dtype == out.dtype: 271 | raise ValueError('dtype mismatch') 272 | n = a.size 273 | elementwise._clip(a._data, a_min, a_max, n, out._data) 274 | return out 275 | -------------------------------------------------------------------------------- /cudarray/extra/__init__.py: -------------------------------------------------------------------------------- 1 | from .array import concatenate, split 2 | -------------------------------------------------------------------------------- /cudarray/extra/array.py: -------------------------------------------------------------------------------- 1 | import cudarray as ca 2 | from ..wrap import array_ops 3 | from ..helpers import prod 4 | 5 | 6 | def concatenate(a, b, axis=0, out=None): 7 | ndim = a.ndim 8 | a_shp = a.shape 9 | b_shp = b.shape 10 | 11 | d_concat = a_shp[axis] + b_shp[axis] 12 | out_shp = a_shp[:axis] + (d_concat,) + a_shp[axis+1:] 13 | if out is None: 14 | out = ca.empty(out_shp, dtype=a.dtype) 15 | else: 16 | if out.shape != out_shp: 17 | raise ValueError('shape mismatch') 18 | 19 | da = a_shp[axis] 20 | db = b_shp[axis] 21 | if ndim < 3: 22 | a_shp = a_shp + (1,)*(3-ndim) 23 | b_shp = b_shp + (1,)*(3-ndim) 24 | elif ndim > 3: 25 | if axis == 0: 26 | a_shp = a_shp[axis], prod(a_shp[1:]), 1 27 | b_shp = b_shp[axis], prod(b_shp[1:]), 1 28 | elif axis + 1 == ndim: 29 | a_shp = 1, prod(a_shp[:axis]), a_shp[axis] 30 | b_shp = 1, prod(b_shp[:axis]), b_shp[axis] 31 | axis = 2 32 | else: 33 | a_shp = prod(a_shp[:axis]), a_shp[axis], prod(a_shp[axis+1:]) 34 | b_shp = 
prod(b_shp[:axis]), b_shp[axis], prod(b_shp[axis+1:]) 35 | axis = 1 36 | d0, d1, d2 = a_shp[:axis] + (d_concat,) + a_shp[axis+1:] 37 | array_ops._concatenate(a._data, b._data, axis, d0, d1, d2, da, db, 38 | out._data) 39 | return out 40 | 41 | 42 | def split(arr, a_size, axis=0, out_a=None, out_b=None): 43 | shp = arr.shape 44 | ndim = arr.ndim 45 | da = a_size 46 | db = shp[axis]-a_size 47 | 48 | out_a_shp = shp[:axis] + (da,) + shp[axis+1:] 49 | out_b_shp = shp[:axis] + (db,) + shp[axis+1:] 50 | if out_a is None: 51 | out_a = ca.empty(out_a_shp, dtype=arr.dtype) 52 | else: 53 | if out_a.shape != out_a_shp: 54 | raise ValueError('shape mismatch') 55 | if out_b is None: 56 | out_b = ca.empty(out_b_shp, dtype=arr.dtype) 57 | else: 58 | if out_b.shape != out_b_shp: 59 | raise ValueError('shape mismatch') 60 | 61 | if ndim < 3: 62 | shp = shp + (1,)*(3-ndim) 63 | elif ndim > 3: 64 | if axis == 0: 65 | shp = shp[axis], prod(shp[1:]), 1 66 | elif axis + 1 == ndim: 67 | shp = 1, prod(shp[:axis]), shp[axis] 68 | axis = 2 69 | else: 70 | shp = prod(shp[:axis]), shp[axis], prod(shp[axis+1:]) 71 | axis = 1 72 | 73 | d0, d1, d2 = shp 74 | array_ops._split(arr._data, axis, d0, d1, d2, da, db, out_a._data, 75 | out_b._data) 76 | return out_a, out_b 77 | -------------------------------------------------------------------------------- /cudarray/helpers.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import operator 3 | 4 | 5 | def normalize_axis(axis, ndim): 6 | if axis is None: 7 | return tuple(range(ndim)) 8 | elif isinstance(axis, int): 9 | return (axis,) 10 | elif isinstance(axis, tuple): 11 | return tuple(sorted(axis)) 12 | else: 13 | raise ValueError('invalid axis type') 14 | 15 | 16 | def normalize_shape(shape): 17 | if isinstance(shape, int): 18 | return (shape,) 19 | elif isinstance(shape, tuple): 20 | return shape 21 | else: 22 | raise ValueError('invalid shape') 23 | 24 | 25 | def prod(iterable): 26 | return reduce(operator.mul, iterable, 1) 27 | 28 | 29 | def require_iterable(x): 30 | if hasattr(x, '__iter__'): 31 | return x 32 | else: 33 | return [x] 34 | -------------------------------------------------------------------------------- /cudarray/linalg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .wrap import blas 4 | from . 
import cudarray 5 | 6 | 7 | def matmul_shape(a_shape, b_shape): 8 | a_ndim = len(a_shape) 9 | b_ndim = len(b_shape) 10 | if a_ndim == 1 and b_ndim == 2: 11 | if a_shape[0] != b_shape[0]: 12 | raise ValueError('shape mismatch') 13 | return (b_shape[1],) 14 | elif a_ndim == 2 and b_ndim == 1: 15 | if a_shape[1] != b_shape[0]: 16 | raise ValueError('shape mismatch') 17 | return (a_shape[0],) 18 | elif a_ndim == 2 and b_ndim == 2: 19 | if a_shape[1] != b_shape[0]: 20 | raise ValueError('shape mismatch') 21 | return (a_shape[0], b_shape[1]) 22 | else: 23 | raise ValueError('only 1D and 2D arrays are supported') 24 | 25 | 26 | def inner(a, b): 27 | if a.dtype != b.dtype: 28 | raise ValueError('dtype mismatch') 29 | if not a.ndim == b.ndim == 1: 30 | raise ValueError('shape mismatch') 31 | if a.size != b.size: 32 | raise ValueError('size mismatch') 33 | return blas.dot_(a._data, b._data, a.size) 34 | 35 | 36 | def dot(a, b, out=None): 37 | if a.ndim == b.ndim == 1: 38 | return inner(a, b) 39 | 40 | if a.dtype != b.dtype: 41 | raise ValueError('dtype mismatch') 42 | 43 | out_shape = matmul_shape(a.shape, b.shape) 44 | if out is None: 45 | out = cudarray.empty(out_shape, dtype=a.dtype) 46 | else: 47 | if out_shape != out.shape: 48 | raise ValueError('out.shape does not match result') 49 | if a.dtype != out.dtype: 50 | raise ValueError('dtype mismatch') 51 | 52 | if a.ndim == b.ndim == 2: 53 | m, k = a.shape[:2] 54 | n = b.shape[1] 55 | transA = blas.trans_op if a.transposed else blas.no_trans_op 56 | transB = blas.trans_op if b.transposed else blas.no_trans_op 57 | blas.gemm_(a._data, b._data, transA, transB, m, n, k, 1.0, 0.0, 58 | out._data) 59 | elif a.ndim == 2 and b.ndim == 1: 60 | m, n = a.shape 61 | trans = blas.trans_op if a.transposed else blas.no_trans_op 62 | blas.gemv_(a._data, b._data, trans, m, n, 1.0, 0.0, out._data) 63 | elif a.ndim == 1 and b.ndim == 2: 64 | n, m = b.shape 65 | trans = blas.no_trans_op if b.transposed else blas.trans_op 66 | blas.gemv_(b._data, a._data, trans, m, n, 1.0, 0.0, out._data) 67 | else: 68 | raise ValueError('invalid array dimensionality') 69 | return out 70 | -------------------------------------------------------------------------------- /cudarray/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .image import * 2 | from .math import * 3 | from .special import * 4 | from .conv import * 5 | from .pool import * 6 | -------------------------------------------------------------------------------- /cudarray/nnet/conv.py: -------------------------------------------------------------------------------- 1 | import cudarray as ca 2 | from ..wrap import nnet 3 | 4 | 5 | try: 6 | from ..wrap import cudnn 7 | _default_impl = 'cudnn' 8 | except: 9 | _default_impl = 'cudarray' 10 | 11 | 12 | class ConvBC01(object): 13 | def __init__(self, padding, strides, impl=None): 14 | self.padding = padding 15 | self.strides = strides 16 | self.impl = _default_impl if impl is None else impl 17 | if self.impl == 'cudarray': 18 | self.conv_cudnn = None 19 | elif self.impl == 'cudnn': 20 | self.conv_cudnn = cudnn.conv_bc01_cudnn(padding, strides) 21 | else: 22 | raise ValueError('invalid implementation: %s' % self.impl) 23 | 24 | def fprop(self, imgs, filters, convout=None): 25 | b, c, img_h, img_w = imgs.shape 26 | f, c_filters, filter_h, filter_w = filters.shape 27 | if c != c_filters: 28 | raise ValueError('channel mismatch') 29 | if imgs.dtype != filters.dtype: 30 | raise ValueError('dtype mismatch') 31 | img_shape = 
(img_h, img_w) 32 | self.imgs_shape = imgs.shape 33 | filter_shape = (filter_h, filter_w) 34 | convout_shape = self.output_shape(imgs.shape, f, (filter_h, filter_w)) 35 | if convout is None: 36 | convout = ca.empty(convout_shape, dtype=imgs.dtype) 37 | else: 38 | if convout.shape != convout_shape: 39 | raise ValueError('convout.shape does not match result') 40 | if convout.dtype != imgs.dtype: 41 | raise ValueError('dtype mismatch') 42 | if self.impl == 'cudarray': 43 | nnet._conv_bc01_matmul( 44 | imgs._data, filters._data, b, c, f, img_shape, filter_shape, 45 | self.padding, self.strides, convout._data 46 | ) 47 | else: 48 | self.conv_cudnn.fprop( 49 | imgs._data, filters._data, b, c, f, img_shape, filter_shape, 50 | convout._data 51 | ) 52 | 53 | return convout 54 | 55 | def bprop(self, imgs, filters, convout_d, to_filters=True, to_imgs=True, 56 | filters_d=None, imgs_d=None): 57 | if imgs is not None: 58 | b, c, img_h, img_w = imgs.shape 59 | else: 60 | b, c, img_h, img_w = self.imgs_shape 61 | if filters is not None: 62 | f, c, filter_h, filter_w = filters.shape 63 | if imgs_d is not None: 64 | b, c, img_h, img_w = imgs_d.shape 65 | b_convout, f_convout, convout_h, convout_w = convout_d.shape 66 | imgs_shape = self.imgs_shape 67 | img_shape = imgs_shape[2:] 68 | filter_shape = (filter_h, filter_w) 69 | 70 | if to_filters: 71 | if filters_d is None: 72 | filters_d = ca.empty(filters.shape, dtype=filters.dtype) 73 | else: 74 | if filters_d.shape != filters.shape: 75 | raise ValueError('filters_d.shape does not match result') 76 | if filters_d.dtype != filters.dtype: 77 | raise ValueError('dtype mismatch') 78 | if self.impl == 'cudarray': 79 | nnet._conv_bc01_matmul_bprop_filters( 80 | imgs._data, convout_d._data, b, c, f, img_shape, 81 | filter_shape, self.padding, self.strides, filters_d._data 82 | ) 83 | 84 | if to_imgs: 85 | if imgs_d is None: 86 | imgs_d = ca.empty(imgs_shape, dtype=convout_d.dtype) 87 | elif imgs is not None: 88 | if imgs_d.shape != imgs.shape: 89 | raise ValueError('imgs_d.shape does not match result') 90 | if imgs_d.dtype != imgs.dtype: 91 | raise ValueError('dtype mismatch') 92 | if self.impl == 'cudarray': 93 | nnet._conv_bc01_matmul_bprop_imgs( 94 | filters._data, convout_d._data, b, c, f, img_shape, 95 | filter_shape, self.padding, self.strides, imgs_d._data 96 | ) 97 | 98 | if self.impl == 'cudnn': 99 | imgs_ = None if imgs is None else imgs._data 100 | imgs_d_ = imgs_d._data if to_imgs else None 101 | filters_ = filters_d._data if to_filters else None 102 | self.conv_cudnn.bprop(imgs_, filters._data, convout_d._data, 103 | imgs_d_, filters_) 104 | 105 | return filters_d, imgs_d 106 | 107 | def output_shape(self, imgs_shape, n_filters, filter_shape): 108 | b, _, img_h, img_w = imgs_shape 109 | out_shape = ((img_h + 2*self.padding[0] - filter_shape[0]) // 110 | self.strides[0] + 1, 111 | (img_w + 2*self.padding[1] - filter_shape[1]) // 112 | self.strides[1] + 1) 113 | return (b, n_filters) + out_shape 114 | -------------------------------------------------------------------------------- /cudarray/nnet/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cudarray as ca 3 | from ..wrap import image 4 | 5 | 6 | def rescale(imgs, factor, method, out=None): 7 | img_h, img_w = imgs.shape[-2:] 8 | batch_shape = imgs.shape[:-2] 9 | n_imgs = np.prod(batch_shape) 10 | if factor > 1: 11 | scaled_h = int(np.floor(img_h * factor)) 12 | scaled_w = int(np.floor(img_w * factor)) 13 | else: 14 | scaled_h 
= int(np.ceil(img_h * factor)) 15 | scaled_w = int(np.ceil(img_w * factor)) 16 | out_shape = batch_shape + (scaled_h, scaled_w) 17 | if out is None: 18 | out = ca.empty(out_shape, dtype=imgs.dtype) 19 | else: 20 | if out.shape != out_shape: 21 | raise ValueError('shape mismatch') 22 | method = image.sample_methods[method] 23 | image._rescale(imgs._data, factor, method, n_imgs, img_h, img_w, out._data) 24 | return out 25 | -------------------------------------------------------------------------------- /cudarray/nnet/math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..elementwise import unary 4 | from ..wrap import elementwise 5 | 6 | 7 | def relu(x, out=None): 8 | return unary(elementwise.relu_op, x, out) 9 | 10 | 11 | def relu_d(x, out=None): 12 | return unary(elementwise.relu_d_op, x, out) 13 | 14 | 15 | def sigmoid(x, out=None): 16 | return unary(elementwise.sigmoid_op, x, out) 17 | 18 | 19 | def sigmoid_d(x, out=None): 20 | return unary(elementwise.sigmoid_d_op, x, out) 21 | 22 | 23 | def softplus(x, out=None): 24 | return unary(elementwise.softplus_op, x, out) 25 | 26 | 27 | def softplus_d(x, out=None): 28 | return unary(elementwise.softplus_d_op, x, out) 29 | 30 | 31 | def tanh_d(x, out=None): 32 | return unary(elementwise.tanh_d_op, x, out) 33 | -------------------------------------------------------------------------------- /cudarray/nnet/pool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cudarray as ca 3 | from ..wrap import nnet 4 | 5 | 6 | try: 7 | from ..wrap import cudnn 8 | _default_impl = 'cudnn' 9 | except: 10 | _default_impl = 'cudarray' 11 | 12 | 13 | class PoolB01(object): 14 | def __init__(self, win_shape, padding, strides, method='max', impl=None): 15 | self.win_shape = win_shape 16 | self.padding = padding 17 | self.strides = strides 18 | self.impl = _default_impl if impl is None else impl 19 | if method not in ['max', 'avg']: 20 | raise ValueError('invalid pooling method: %s' % method) 21 | self.method = method 22 | if self.impl == 'cudarray': 23 | self.mask = None 24 | elif self.impl == 'cudnn': 25 | self.last_poolout = None 26 | self.pool_cudnn = cudnn.PoolBC01CuDNN_f(win_shape, padding, 27 | strides, method) 28 | else: 29 | raise ValueError('invalid implementation: %s' % self.impl) 30 | 31 | def fprop(self, imgs, poolout=None): 32 | poolout_shape = self.output_shape(imgs.shape) 33 | if poolout is None: 34 | poolout = ca.empty(poolout_shape, dtype=imgs.dtype) 35 | else: 36 | if poolout_shape != poolout.shape: 37 | raise ValueError('poolout.shape does not match result') 38 | if imgs.dtype != poolout.dtype: 39 | raise ValueError('dtype mismatch') 40 | 41 | img_shape = imgs.shape[-2:] 42 | if self.impl == 'cudarray': 43 | n_imgs = np.prod(imgs.shape[:-2]) 44 | if self.method == 'max': 45 | if self.mask is None or self.mask.shape != poolout_shape: 46 | self.mask = ca.empty(poolout_shape, 47 | dtype=np.dtype('int32')) 48 | nnet._max_pool_b01( 49 | imgs._data, n_imgs, img_shape, self.win_shape, 50 | self.padding, self.strides, poolout._data, self.mask._data 51 | ) 52 | else: 53 | nnet._avg_pool_b01( 54 | imgs._data, n_imgs, img_shape, self.win_shape, 55 | self.padding, self.strides, poolout._data 56 | ) 57 | else: 58 | n_imgs, n_channels = imgs.shape[:2] 59 | self.last_imgs = imgs 60 | self.last_poolout = poolout 61 | self.pool_cudnn.fprop( 62 | imgs._data, imgs.shape, poolout._data 63 | ) 64 | return poolout 65 | 66 | def 
bprop(self, img_shape, poolout_d, imgs_d=None): 67 | n_imgs_shape = poolout_d.shape[:-2] 68 | imgs_shape = n_imgs_shape + img_shape 69 | 70 | if imgs_d is None: 71 | imgs_d = ca.empty(imgs_shape, dtype=poolout_d.dtype) 72 | else: 73 | if imgs_d.shape != imgs_d.shape: 74 | raise ValueError('poolout.shape does not match result') 75 | if imgs_d.dtype != poolout_d.dtype: 76 | raise ValueError('dtype mismatch') 77 | 78 | if self.impl == 'cudarray': 79 | n_imgs = np.prod(n_imgs_shape) 80 | if self.method == 'max': 81 | nnet._max_pool_b01_bprop( 82 | poolout_d._data, self.mask._data, n_imgs, img_shape, 83 | self.win_shape, self.padding, self.strides, imgs_d._data 84 | ) 85 | else: 86 | nnet._avg_pool_b01_bprop( 87 | poolout_d._data, n_imgs, img_shape, self.win_shape, 88 | self.padding, self.strides, imgs_d._data 89 | ) 90 | else: 91 | self.pool_cudnn.bprop( 92 | self.last_imgs._data, self.last_poolout._data, poolout_d._data, 93 | imgs_d._data 94 | ) 95 | return imgs_d 96 | 97 | def output_shape(self, imgs_shape): 98 | n_imgs_shape = imgs_shape[:-2] 99 | img_h, img_w = imgs_shape[-2:] 100 | out_shape = ((img_h + 2*self.padding[0] - self.win_shape[0]) // 101 | self.strides[0] + 1, 102 | (img_w + 2*self.padding[1] - self.win_shape[1]) // 103 | self.strides[1] + 1) 104 | return n_imgs_shape + out_shape 105 | 106 | def __getstate__(self): 107 | ignore = ['mask', 'last_poolout', 'last_imgs'] 108 | return dict((k, None) if k in ignore else (k, v) 109 | for k, v in self.__dict__.items()) 110 | -------------------------------------------------------------------------------- /cudarray/nnet/special.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cudarray as ca 3 | from ..wrap import nnet 4 | 5 | 6 | def softmax(x): 7 | e = ca.exp(x - ca.amax(x, axis=1, keepdims=True)) 8 | return e/ca.sum(e, axis=1, keepdims=True) 9 | 10 | 11 | def categorical_cross_entropy(y_pred, y_true, eps=1e-15): 12 | # Assumes one-hot encoding. 13 | y_pred = ca.clip(y_pred, eps, 1 - eps) 14 | # XXX: do we need to normalize? 
15 | y_pred /= ca.sum(y_pred, axis=1, keepdims=True) 16 | loss = -ca.sum(y_true * ca.log(y_pred), axis=1) 17 | return loss 18 | 19 | 20 | def one_hot_encode(labels, n_classes, out=None): 21 | out_shape = (labels.size, n_classes) 22 | if labels.dtype != np.dtype('int32'): 23 | raise ValueError('labels.dtype must be int') 24 | if out is None: 25 | out = ca.empty(out_shape) 26 | else: 27 | if out.shape != out_shape: 28 | raise ValueError('shape mismatch') 29 | nnet._one_hot_encode(labels._data, n_classes, out_shape[0], out._data) 30 | return out 31 | 32 | 33 | def one_hot_decode(one_hot, out=None): 34 | out_shape = (one_hot.shape[0],) 35 | if out is None: 36 | out = ca.empty(out_shape, dtype=np.dtype('int32')) 37 | else: 38 | if out.dtype != np.dtype('int32'): 39 | raise ValueError('out.dtype must be int') 40 | if out.shape != out_shape: 41 | raise ValueError('shape mismatch') 42 | ca.argmax(one_hot, axis=1, out=out) 43 | return out 44 | -------------------------------------------------------------------------------- /cudarray/numpy_backend/__init__.py: -------------------------------------------------------------------------------- 1 | from numpy import * 2 | from .nnet import * 3 | -------------------------------------------------------------------------------- /cudarray/numpy_backend/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .special import * 3 | from .conv_bc01 import * 4 | from .pool_bc01 import * 5 | from .lrnorm_bc01 import * 6 | from .conv import * 7 | from .pool import * 8 | -------------------------------------------------------------------------------- /cudarray/numpy_backend/nnet/activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _output(result, out): 5 | if out is None: 6 | return result 7 | else: 8 | np.copyto(out, result) 9 | return out 10 | 11 | 12 | def sigmoid(x, out=None): 13 | result = 1.0/(1.0+np.exp(-x)) 14 | return _output(result, out) 15 | 16 | 17 | def sigmoid_d(x, out=None): 18 | s = sigmoid(x) 19 | result = s*(1-s) 20 | return _output(result, out) 21 | 22 | 23 | def tanh_d(x, out=None): 24 | result = 1-np.tanh(x)**2 25 | return _output(result, out) 26 | 27 | 28 | def relu(x, out=None): 29 | result = np.maximum(0.0, x) 30 | return _output(result, out) 31 | 32 | 33 | def relu_d(x, out=None): 34 | result = np.zeros(x.shape) 35 | result[x >= 0] = 1 36 | return _output(result, out) 37 | 38 | 39 | def softplus(x, out=None): 40 | result = np.log1p(np.exp(x)) 41 | mask = x > 25.0 42 | result[mask] = x[mask] 43 | return _output(result, out) 44 | 45 | 46 | def softplus_d(x, out=None): 47 | result = 1.0 - 1.0/(1.0 + np.exp(x)) 48 | result[x > 25.0] = 1.0 49 | return _output(result, out) 50 | -------------------------------------------------------------------------------- /cudarray/numpy_backend/nnet/conv.py: -------------------------------------------------------------------------------- 1 | from .conv_bc01 import * 2 | import cudarray as ca 3 | 4 | 5 | class ConvBC01(object): 6 | def __init__(self, padding, strides): 7 | self.padding = padding 8 | self.strides = strides 9 | 10 | def fprop(self, imgs, filters, convout=None): 11 | b, c, img_h, img_w = imgs.shape 12 | f, c_filters, filter_h, filter_w = filters.shape 13 | if c != c_filters: 14 | raise ValueError('channel mismatch') 15 | if imgs.dtype != filters.dtype: 16 | raise ValueError('dtype mismatch') 17 | 18 | convout_shape = 
self.output_shape(imgs.shape, f, (filter_h, filter_w)) 19 | if convout is None: 20 | convout = ca.empty(convout_shape, dtype=imgs.dtype) 21 | else: 22 | if convout.shape != convout_shape: 23 | raise ValueError('convout.shape does not match result') 24 | if convout.dtype != imgs.dtype: 25 | raise ValueError('dtype mismatch') 26 | 27 | conv_bc01(imgs=imgs, 28 | filters=filters, 29 | padding=self.padding, 30 | strides=self.strides, 31 | convout=convout) 32 | 33 | self.last_imgs = imgs 34 | 35 | return convout 36 | 37 | def bprop(self, imgs, filters, convout_d, to_filters=True, to_imgs=True, 38 | filters_d=None, imgs_d=None): 39 | if imgs is None: 40 | imgs = self.last_imgs 41 | b, c, _, _ = imgs.shape 42 | f, c_filters, _, _ = filters.shape 43 | b_convout, f_convout, _, _ = convout_d.shape 44 | 45 | if b != b_convout: 46 | raise ValueError('batch mismatch') 47 | if f != f_convout: 48 | raise ValueError('filter mismatch') 49 | if c != c_filters: 50 | raise ValueError('channel mismatch') 51 | 52 | if imgs.dtype != filters.dtype != convout_d.dtype: 53 | raise ValueError('dtype mismatch') 54 | 55 | if filters_d is None: 56 | filters_d = ca.empty(filters.shape, dtype=filters.dtype) 57 | 58 | if imgs_d is None: 59 | imgs_d = ca.empty(imgs.shape, dtype=imgs.dtype) 60 | 61 | conv_bc01_bprop(imgs=imgs, 62 | convout_d=convout_d, 63 | filters=filters, 64 | padding=self.padding, 65 | strides=self.strides, 66 | imgs_grad=imgs_d, 67 | filters_grad=filters_d) 68 | 69 | return filters_d, imgs_d 70 | 71 | def output_shape(self, imgs_shape, n_filters, filter_shape): 72 | b, _, img_h, img_w = imgs_shape 73 | out_shape = ((img_h + 2*self.padding[0] - filter_shape[0]) 74 | / self.strides[0] + 1, 75 | (img_w + 2*self.padding[1] - filter_shape[1]) 76 | / self.strides[1] + 1) 77 | return (b, n_filters) + out_shape 78 | -------------------------------------------------------------------------------- /cudarray/numpy_backend/nnet/conv_bc01.pyx: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import cython 4 | #from cython.parallel import parallel, prange, threadlocal 5 | cimport numpy as np 6 | 7 | 8 | DTYPE = np.float 9 | ctypedef np.float_t DTYPE_t 10 | ctypedef Py_ssize_t uint 11 | 12 | 13 | cdef inline int int_max(int a, int b) nogil: return a if a >= b else b 14 | cdef inline int int_min(int a, int b) nogil: return a if a <= b else b 15 | 16 | 17 | @cython.boundscheck(False) 18 | @cython.wraparound(False) 19 | def conv_bc01(np.ndarray[DTYPE_t, ndim=4] imgs, 20 | np.ndarray[DTYPE_t, ndim=4] filters, 21 | tuple padding, 22 | tuple strides, 23 | np.ndarray[DTYPE_t, ndim=4] convout): 24 | """ Multi-image, multi-channel convolution 25 | imgs has shape (n_imgs, n_channels_in, img_h, img_w) 26 | filters has shape (n_channels_out, n_channels_in, filter_h, filter_w) 27 | """ 28 | # TODO: support padding and striding 29 | 30 | cdef uint n_imgs = imgs.shape[0] 31 | cdef uint img_h = imgs.shape[2] 32 | cdef uint img_w = imgs.shape[3] 33 | cdef uint n_channels_in = filters.shape[1] 34 | cdef uint n_channels_out = filters.shape[0] 35 | cdef uint fil_h = filters.shape[2] 36 | cdef uint fil_w = filters.shape[3] 37 | 38 | cdef int fil_mid_h = fil_h // 2 39 | cdef int fil_mid_w = fil_w // 2 40 | 41 | cdef uint i, c_in, c_out 42 | cdef uint img_y, img_x, fil_y, fil_x 43 | cdef DTYPE_t value 44 | 45 | cdef int y, x, y_off_min, y_off_max, y_off, x_off_min, x_off_max, x_off, mid_off_h, mid_off_w, img_x_center, img_y_center 46 | 47 | 
"""mid_off only add one to max iff filter is of an uneaven sice 48 | This is done because filters of uneaven size have center shifte one Back-propagate 49 | [ 1, 1 , x , 1] wher x is center for a 1X4 filter""" 50 | mid_off_h = fil_h % 2 51 | mid_off_w = fil_w % 2 52 | 53 | cdef uint stride_h = strides[0] 54 | cdef uint stride_w = strides[1] 55 | 56 | cdef uint padding_h = padding[0] 57 | cdef uint padding_w = padding[1] 58 | 59 | cdef uint out_h = convout.shape[2] 60 | cdef uint out_w = convout.shape[3] 61 | 62 | for i in range(n_imgs): 63 | for c_out in range(n_channels_out): 64 | for y in range(out_h): 65 | img_y_center = y*stride_h+fil_mid_h 66 | y_off_min = int_max(-img_y_center, -padding_h-fil_mid_h) 67 | y_off_max = int_min(img_h-img_y_center, fil_mid_h+mid_off_h-padding_h) 68 | for x in range(out_w): 69 | img_x_center = x*stride_w+fil_mid_w 70 | x_off_min = int_max(-img_x_center, -padding_w-fil_mid_w) 71 | x_off_max = int_min(img_w-img_x_center, fil_mid_w+mid_off_w-padding_w) 72 | value = 0.0 73 | for y_off in range(y_off_min, y_off_max): 74 | for x_off in range(x_off_min, x_off_max): 75 | img_y = (img_y_center + y_off) 76 | img_x = (img_x_center + x_off) 77 | fil_y = (fil_mid_h + padding_h + y_off) 78 | fil_x = (fil_mid_w + padding_w + x_off) 79 | for c_in in range(n_channels_in): 80 | value += imgs[i, c_in, img_y, img_x] * filters[c_out, c_in, fil_y, fil_x] 81 | convout[i, c_out, y, x] = value 82 | 83 | return convout 84 | 85 | @cython.boundscheck(False) 86 | @cython.wraparound(False) 87 | def conv_bc01_bprop(np.ndarray[DTYPE_t, ndim=4] imgs, 88 | np.ndarray[DTYPE_t, ndim=4] convout_d, 89 | np.ndarray[DTYPE_t, ndim=4] filters, 90 | tuple padding, 91 | tuple strides, 92 | np.ndarray[DTYPE_t, ndim=4] imgs_grad, 93 | np.ndarray[DTYPE_t, ndim=4] filters_grad): 94 | """ Back-propagate gradients of multi-image, multi-channel convolution 95 | imgs has shape (b, c, img_h, img_w) 96 | filters has shape (f, c_filters, img_h, img_w) 97 | convout has shape (b_convout, f_convout, img_h, img_w) 98 | """ 99 | cdef uint img_channels = imgs.shape[1] 100 | cdef uint img_h = imgs.shape[2] 101 | cdef uint img_w = imgs.shape[3] 102 | cdef uint b_convout = convout_d.shape[0] 103 | cdef uint f_convout = convout_d.shape[1] 104 | cdef uint convout_d_h = convout_d.shape[2] 105 | cdef uint convout_d_w = convout_d.shape[3] 106 | 107 | cdef uint fil_h = filters.shape[2] 108 | cdef uint fil_w = filters.shape[3] 109 | cdef int fil_mid_h = fil_h // 2 110 | cdef int fil_mid_w = fil_w // 2 111 | 112 | cdef uint i, c_convout, c_imgs 113 | cdef uint img_y, img_x, fil_y, fil_x 114 | cdef DTYPE_t convout_d_value 115 | cdef int y, x, y_off_min, y_off_max, y_off, x_off_min, x_off_max 116 | cdef int x_off, mid_off_h, mid_off_w, img_x_center, img_y_center 117 | 118 | """mid_off only add one to max iff filter is of an uneaven sice 119 | This is done because filters of uneaven size have center shifte one Back-propagate 120 | [ 1, 1 , x , 1] wher x is center for a 1X4 filter""" 121 | mid_off_h = fil_h % 2 122 | mid_off_w = fil_w % 2 123 | 124 | cdef uint stride_h = strides[0] 125 | cdef uint stride_w = strides[1] 126 | 127 | cdef uint padding_h = padding[0] 128 | cdef uint padding_w = padding[1] 129 | 130 | imgs_grad[...] = 0 131 | filters_grad[...] 
= 0 132 | for i in range(b_convout): 133 | for c_convout in range(f_convout): 134 | for y in range(convout_d_h): 135 | img_y_center = y*stride_h+fil_mid_h 136 | y_off_min = int_max(-img_y_center, -padding_h-fil_mid_h) 137 | y_off_max = int_min(img_h-img_y_center, fil_mid_h+mid_off_h-padding_h) 138 | for x in range(convout_d_w): 139 | convout_d_value = convout_d[i, c_convout, y, x] 140 | img_x_center = x*stride_w+fil_mid_w 141 | x_off_min = int_max(-img_x_center, -padding_w-fil_mid_w) 142 | x_off_max = int_min(img_w-img_x_center, fil_mid_w+mid_off_w-padding_w) 143 | value = 0.0 144 | for y_off in range(y_off_min, y_off_max): 145 | for x_off in range(x_off_min, x_off_max): 146 | img_y = (img_y_center + y_off) 147 | img_x = (img_x_center + x_off) 148 | fil_y = (fil_mid_h + padding_h + y_off) 149 | fil_x = (fil_mid_w + padding_w + x_off) 150 | for c_imgs in range(img_channels): 151 | imgs_grad[i, c_imgs, img_y, img_x] += filters[c_convout, c_imgs, fil_y, fil_x] * convout_d_value 152 | filters_grad[c_convout, c_imgs, fil_y, fil_x] += imgs[i, c_imgs, img_y, img_x] * convout_d_value 153 | # filters_grad[...] /= n_imgs 154 | -------------------------------------------------------------------------------- /cudarray/numpy_backend/nnet/lrnorm_bc01.pyx: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import cython 4 | cimport numpy as np 5 | 6 | 7 | DTYPE = np.float 8 | ctypedef np.float_t DTYPE_t 9 | ctypedef Py_ssize_t uint 10 | 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | def lrnorm_bc01(np.ndarray[DTYPE_t, ndim=4] imgs, 15 | uint N, 16 | DTYPE_t alpha, 17 | DTYPE_t beta, 18 | DTYPE_t k): 19 | """ 20 | imgs has shape (n_imgs, n_channels, img_h, img_w) 21 | """ 22 | 23 | cdef DTYPE_t norm_window 24 | cdef uint n_imgs = imgs.shape[0] 25 | cdef uint n_channels = imgs.shape[1] 26 | cdef uint img_h = imgs.shape[2] 27 | cdef uint img_w = imgs.shape[3] 28 | 29 | cdef uint half = N // 2 30 | cdef uint tailLength = N - half 31 | 32 | cdef uint max_channel 33 | 34 | cdef DTYPE_t a_i 35 | cdef DTYPE_t a_half 36 | 37 | tail = tailLength*[0.0] 38 | 39 | cdef uint i, y, x, a, c 40 | 41 | for i in range(n_imgs): 42 | for y in range(img_h): 43 | for x in range(img_w): 44 | norm_window = 0.0 45 | tail = tailLength*[0.0] 46 | 47 | for a in range(N + 1): 48 | addToNormWindow(norm_window, imgs[i, a, y, x]) 49 | 50 | for c in range(n_channels): 51 | a_i = imgs[i, c, y, x] 52 | a_half = tail.pop(0) 53 | tail.append(a_i) 54 | #Normalazation 55 | imgs[i, c, y, x] = calcNormCal(a_i, norm_window, alpha, beta, k) 56 | #Move the window for next channel 57 | max_channel = half + c + 1 58 | #Move window if possible 59 | if (max_channel < n_channels and (c >= N) ): 60 | addToNormWindow(norm_window, imgs[i, max_channel, y, x]) 61 | #Remove privius channel from sum 62 | norm_window -= (a_half * a_half) 63 | 64 | return imgs 65 | 66 | @cython.profile(False) 67 | @cython.boundscheck(False) 68 | @cython.wraparound(False) 69 | cdef inline DTYPE_t calcNormCal(DTYPE_t a_i, 70 | DTYPE_t norm_window, 71 | DTYPE_t alpha, 72 | DTYPE_t beta, 73 | DTYPE_t k): 74 | return a_i / ((k + alpha * norm_window) ** beta) 75 | 76 | @cython.profile(False) 77 | @cython.boundscheck(False) 78 | @cython.wraparound(False) 79 | cdef inline addToNormWindow(DTYPE_t norm_window, 80 | DTYPE_t val): 81 | norm_window += (val * val) 82 | -------------------------------------------------------------------------------- 
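For reference, the cross-channel normalization that lrnorm_bc01 above targets is the usual LRN formula, b_i = a_i / (k + alpha * sum_j a_j^2)^beta, with the sum taken over a window of neighbouring channels. The sketch below is a plain-NumPy illustration of that formula only; the function name, the centred-window convention, and the sample call are assumptions made for this example and are not part of the library.

import numpy as np

def lrnorm_bc01_reference(imgs, N, alpha, beta, k):
    # imgs has shape (n_imgs, n_channels, img_h, img_w)
    out = np.empty(imgs.shape)
    n_channels = imgs.shape[1]
    half = N // 2
    for c in range(n_channels):
        # Sum of squares over the channel window centred on c
        lo = max(0, c - half)
        hi = min(n_channels, c + half + 1)
        norm_window = np.sum(imgs[:, lo:hi] ** 2, axis=1)
        out[:, c] = imgs[:, c] / (k + alpha * norm_window) ** beta
    return out

# Example call (hypothetical values):
imgs = np.random.randn(2, 8, 5, 5)
normed = lrnorm_bc01_reference(imgs, N=5, alpha=1e-4, beta=0.75, k=2.0)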
/cudarray/numpy_backend/nnet/pool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cudarray as ca 3 | from .pool_bc01 import pool_bc01, bprop_pool_bc01 4 | 5 | 6 | class PoolB01(object): 7 | def __init__(self, win_shape, padding, strides, method='max'): 8 | self.win_shape = win_shape 9 | self.padding = padding 10 | self.strides = strides 11 | if method not in ['max', 'avg']: 12 | raise ValueError('invalid pooling method') 13 | if method == 'max': 14 | self.method = 0 15 | elif method == 'avg': 16 | self.method = 1 17 | 18 | self.mask = None 19 | 20 | def fprop(self, imgs, poolout=None): 21 | poolout_shape = self.output_shape(imgs.shape) 22 | if poolout is None: 23 | poolout = ca.empty(poolout_shape, dtype=imgs.dtype) 24 | else: 25 | if poolout_shape != poolout.shape: 26 | raise ValueError('poolout.shape does not match result') 27 | if imgs.dtype != poolout.dtype: 28 | raise ValueError('dtype mismatch') 29 | 30 | if self.mask is None or self.mask.shape[:-1] != poolout_shape: 31 | self.mask = ca.empty(poolout_shape + (2,), dtype=np.dtype('int_')) 32 | 33 | pool_bc01(imgs=imgs, 34 | win_shape=self.win_shape, 35 | strides=self.strides, 36 | padding=self.padding, 37 | poolout=poolout, 38 | type=self.method, 39 | switches=self.mask) 40 | 41 | return poolout 42 | 43 | def bprop(self, img_shape, poolout_d, imgs_d=None): 44 | n_imgs_shape = poolout_d.shape[:-2] 45 | imgs_shape = n_imgs_shape + img_shape 46 | 47 | if imgs_d is None: 48 | imgs_d = ca.empty(imgs_shape, dtype=poolout_d.dtype) 49 | else: 50 | if imgs_d.shape != imgs_d.shape: 51 | raise ValueError('poolout.shape does not match result') 52 | if imgs_d.dtype != poolout_d.dtype: 53 | raise ValueError('dtype mismatch') 54 | 55 | bprop_pool_bc01(poolout_grad=poolout_d, 56 | win_shape=self.win_shape, 57 | strides=self.strides, 58 | padding=self.padding, 59 | type=self.method, 60 | switches=self.mask, 61 | imgs_grad=imgs_d) 62 | return imgs_d 63 | 64 | def output_shape(self, imgs_shape): 65 | n_imgs_shape = imgs_shape[:-2] 66 | img_h, img_w = imgs_shape[-2:] 67 | out_shape = ((img_h + 2*self.padding[0] - self.win_shape[0]) // 68 | self.strides[0] + 1, 69 | (img_w + 2*self.padding[1] - self.win_shape[1]) // 70 | self.strides[1] + 1) 71 | return n_imgs_shape + out_shape 72 | -------------------------------------------------------------------------------- /cudarray/numpy_backend/nnet/pool_bc01.pyx: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import cython 4 | cimport numpy as np 5 | 6 | cdef int POOL_MAX = 0 7 | cdef int POOL_MEAN = 1 8 | 9 | DTYPE = np.float 10 | ctypedef np.float_t DTYPE_t 11 | ctypedef Py_ssize_t uint 12 | 13 | cdef inline DTYPE_t dtype_t_max(DTYPE_t a, DTYPE_t b): return a if a >= b else b 14 | 15 | cdef inline int int_max(int a, int b): return a if a >= b else b 16 | cdef inline int int_min(int a, int b): return a if a <= b else b 17 | 18 | 19 | @cython.boundscheck(False) 20 | @cython.wraparound(False) 21 | def pool_bc01(np.ndarray[DTYPE_t, ndim=4] imgs, 22 | tuple win_shape, 23 | tuple strides, 24 | tuple padding, 25 | np.ndarray[DTYPE_t, ndim=4] poolout, 26 | uint type, 27 | np.ndarray[np.int_t, ndim=5] switches): 28 | """ Multi-image, multi-channel pooling 29 | imgs has shape (n_imgs, n_channels, img_h, img_w) 30 | win_shape has shape (win_h, win_w) 31 | strides has shape (stride_y, stride_x) 32 | poolout has shape (n_imgs, n_channels, img_h//stride_y, 
img_w//stride_x) 33 | switches has shape (n_imgs, n_channels, img_h//stride_y, img_w//stride_x, 2) 34 | """ 35 | cdef uint pool_h = win_shape[0] 36 | cdef uint pool_w = win_shape[1] 37 | cdef uint pool_size = pool_h * pool_w 38 | cdef uint stride_x = strides[1] 39 | cdef uint stride_y = strides[0] 40 | cdef uint padding_x = padding[1] 41 | cdef uint padding_y = padding[0] 42 | cdef uint n_imgs = imgs.shape[0] 43 | cdef uint n_channels = imgs.shape[1] 44 | cdef uint img_h = imgs.shape[2] 45 | cdef uint img_w = imgs.shape[3] 46 | cdef uint out_h = poolout.shape[2] 47 | cdef uint out_w = poolout.shape[3] 48 | 49 | 50 | cdef uint i, c, y, x, y_out, x_out 51 | cdef int y_min, y_max, x_min, x_max 52 | cdef uint img_y, img_x 53 | cdef uint img_y_max = 0 54 | cdef uint img_x_max = 0 55 | cdef DTYPE_t value, new_value 56 | for i in range(n_imgs): 57 | for c in range(n_channels): 58 | for y_out in range(out_h): 59 | y = y_out*stride_y-padding_y 60 | y_min = int_max(y, 0) 61 | y_max = int_min(y+pool_h, img_h) 62 | for x_out in range(out_w): 63 | x = x_out*stride_x-padding_x 64 | x_min = int_max(x, 0) 65 | x_max = int_min(x+pool_w, img_w) 66 | if (type == POOL_MAX): 67 | value = -9e99 68 | else: 69 | value = 0 70 | 71 | for img_y in range(y_min, y_max): 72 | for img_x in range(x_min, x_max): 73 | if (type == POOL_MAX): 74 | new_value = imgs[i, c, img_y, img_x] 75 | if new_value > value: 76 | value = new_value 77 | img_y_max = img_y 78 | img_x_max = img_x 79 | else: 80 | value += imgs[i, c, img_y, img_x] 81 | if (type == POOL_MAX): 82 | poolout[i, c, y_out, x_out] = value 83 | switches[i, c, y_out, x_out, 0] = img_y_max 84 | switches[i, c, y_out, x_out, 1] = img_x_max 85 | else: 86 | poolout[i, c, y_out, x_out] = value / pool_size 87 | 88 | @cython.boundscheck(False) 89 | @cython.wraparound(False) 90 | def bprop_pool_bc01(np.ndarray[DTYPE_t, ndim=4] poolout_grad, 91 | tuple win_shape, 92 | tuple strides, 93 | tuple padding, 94 | uint type, 95 | np.ndarray[np.int_t, ndim=5] switches, 96 | np.ndarray[DTYPE_t, ndim=4] imgs_grad): 97 | 98 | cdef uint n_imgs = poolout_grad.shape[0] 99 | cdef uint n_channels = poolout_grad.shape[1] 100 | cdef uint poolout_h = poolout_grad.shape[2] 101 | cdef uint poolout_w = poolout_grad.shape[3] 102 | 103 | cdef uint pool_h = win_shape[0] 104 | cdef uint pool_w = win_shape[1] 105 | cdef uint pool_size = pool_h * pool_w 106 | cdef uint stride_x = strides[1] 107 | cdef uint stride_y = strides[0] 108 | cdef uint padding_x = padding[1] 109 | cdef uint padding_y = padding[0] 110 | 111 | cdef uint i, c, y, x, img_y, img_y_min, img_x_min, img_y_max, img_x_max 112 | 113 | imgs_grad[...] 
= 0
114 | 
115 |     for i in range(n_imgs):
116 |         for c in range(n_channels):
117 |             for y in range(poolout_h):
118 |                 for x in range(poolout_w):
119 |                     if (type == POOL_MEAN):
120 |                         img_y_min = y * stride_y - padding_y
121 |                         img_x_min = x * stride_x - padding_x
122 |                         img_y_max = img_y_min + pool_h
123 |                         img_x_max = img_x_min + pool_w
124 |                         # Accumulate: overlapping pooling windows may cover the same input element
125 |                         imgs_grad[i, c, img_y_min : img_y_max, img_x_min : img_x_max] += (poolout_grad[i, c, y, x] / pool_size)
126 |                     elif (type == POOL_MAX):
127 |                         img_y = switches[i, c, y, x, 0]
128 |                         img_x = switches[i, c, y, x, 1]
129 |                         # Accumulate the gradient at the max location recorded in switches
130 |                         imgs_grad[i, c, img_y, img_x] += poolout_grad[i, c, y, x]
131 |     return imgs_grad
132 | 
--------------------------------------------------------------------------------
/cudarray/numpy_backend/nnet/special.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def softmax(X):
5 |     e = np.exp(X - np.amax(X, axis=1, keepdims=True))
6 |     return e/np.sum(e, axis=1, keepdims=True)
7 | 
8 | 
9 | def categorical_cross_entropy(y_pred, y_true, eps=1e-15):
10 |     # Assumes one-hot encoding.
11 |     y_pred = np.clip(y_pred, eps, 1 - eps)
12 |     # XXX: do we need to normalize?
13 |     y_pred /= y_pred.sum(axis=1, keepdims=True)
14 |     loss = -np.sum(y_true * np.log(y_pred), axis=1)
15 |     return loss
16 | 
17 | 
18 | def one_hot_encode(labels, n_classes, out=None):
19 |     out_shape = (labels.size, n_classes)
20 |     if labels.dtype != np.dtype(int):
21 |         raise ValueError('labels.dtype must be int')
22 |     if out is None:
23 |         out = np.empty(out_shape)
24 |     else:
25 |         if out.shape != out_shape:
26 |             raise ValueError('shape mismatch')
27 |     out.fill(0)
28 |     if labels.size == 1:
29 |         out[0, labels] = 1
30 |     else:
31 |         for c in range(n_classes):
32 |             out[labels == c, c] = 1
33 |     return out
34 | 
35 | 
36 | def one_hot_decode(one_hot, out=None):
37 |     out_shape = (one_hot.shape[0],)
38 |     if out is None:
39 |         out = np.empty(out_shape, dtype=np.dtype(int))
40 |     else:
41 |         if out.dtype != np.dtype(int):
42 |             raise ValueError('out.dtype must be int')
43 |         if out.shape != out_shape:
44 |             raise ValueError('shape mismatch')
45 |     result = np.argmax(one_hot, axis=1)
46 |     np.copyto(out, result)
47 |     return out
48 | 
--------------------------------------------------------------------------------
/cudarray/random.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from .wrap import random
4 | from . import cudarray
5 | from . import helpers
6 | 
7 | 
8 | def seed(val=None):
9 |     if val is None:
10 |         raise ValueError('not implemented')
11 |     random._seed(val)
12 | 
13 | 
14 | def normal(loc=0.0, scale=1.0, size=None):
15 |     if size is None:
16 |         return np.random.normal(loc, scale, size)
17 |     size = helpers.normalize_shape(size)
18 |     n = helpers.prod(size)
19 |     # cuRAND number generation requires an even number of elements.
20 |     n = n if n % 2 == 0 else n + 1
21 |     out = cudarray.empty(n)
22 |     random._random_normal(out._data, loc, scale, n)
23 |     out.shape = size
24 |     return out
25 | 
26 | 
27 | def uniform(low=0.0, high=1.0, size=None):
28 |     if size is None:
29 |         return np.random.uniform(low, high, size)
30 |     size = helpers.normalize_shape(size)
31 |     n = helpers.prod(size)
32 |     # cuRAND number generation requires an even number of elements.
33 | n = n if n % 2 == 0 else n + 1 34 | out = cudarray.empty(n) 35 | random._random_uniform(out._data, low, high, n) 36 | out.shape = size 37 | return out 38 | -------------------------------------------------------------------------------- /cudarray/reduction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .wrap import reduction 4 | from . import cudarray 5 | from . import helpers 6 | from . import base 7 | 8 | 9 | REDUCE_ALL = 0 10 | REDUCE_LEADING = 1 11 | REDUCE_TRAILING = 2 12 | 13 | 14 | def reduce_shape(shape, axis, keepdims): 15 | if keepdims: 16 | out_shape = list(shape) 17 | for a in axis: 18 | out_shape[a] = 1 19 | return tuple(out_shape) 20 | all_axis = tuple(range(len(shape))) 21 | if axis == all_axis: 22 | return (1,) 23 | else: 24 | return tuple(shape[a] for a in all_axis if a not in axis) 25 | 26 | 27 | def reduce_type(axis, ndim): 28 | all_axis = tuple(range(ndim)) 29 | if axis == all_axis: 30 | return REDUCE_ALL 31 | elif axis == all_axis[:len(axis)]: 32 | return REDUCE_LEADING 33 | elif axis == all_axis[-len(axis):]: 34 | return REDUCE_TRAILING 35 | raise ValueError('reduction of middle axes not implemented') 36 | 37 | 38 | def reduce(op, a, axis=None, dtype=None, out=None, keepdims=False, 39 | to_int_op=False): 40 | axis = helpers.normalize_axis(axis, a.ndim) 41 | out_shape = reduce_shape(a.shape, axis, keepdims) 42 | 43 | if to_int_op: 44 | out_dtype = np.dtype('int32') 45 | else: 46 | out_dtype = a.dtype 47 | 48 | if out is None: 49 | out = cudarray.empty(out_shape, out_dtype) 50 | else: 51 | if not out.shape == out_shape: 52 | raise ValueError('out.shape does not match result') 53 | if not out.dtype == out_dtype: 54 | raise ValueError('dtype mismatch') 55 | 56 | rtype = reduce_type(axis, a.ndim) 57 | if rtype == REDUCE_ALL: 58 | if to_int_op: 59 | reduction._reduce_to_int(op, a._data, a.size, out._data) 60 | else: 61 | reduction._reduce(op, a._data, a.size, out._data) 62 | return out 63 | 64 | a = base.ascontiguousarray(a) 65 | 66 | if rtype == REDUCE_LEADING: 67 | n = helpers.prod(out_shape) 68 | m = a.size / n 69 | if to_int_op: 70 | reduction._reduce_mat_to_int(op, a._data, m, n, True, out._data) 71 | else: 72 | reduction._reduce_mat(op, a._data, m, n, True, out._data) 73 | else: 74 | m = helpers.prod(out_shape) 75 | n = a.size / m 76 | if to_int_op: 77 | reduction._reduce_mat_to_int(op, a._data, m, n, False, out._data) 78 | else: 79 | reduction._reduce_mat(op, a._data, m, n, False, out._data) 80 | return out 81 | 82 | 83 | def amax(a, axis=None, dtype=None, out=None, keepdims=False): 84 | return reduce(reduction.max_op, a, axis, dtype, out, keepdims) 85 | 86 | 87 | def mean(a, axis=None, dtype=None, out=None, keepdims=False): 88 | return reduce(reduction.mean_op, a, axis, dtype, out, keepdims) 89 | 90 | 91 | def amin(a, axis=None, dtype=None, out=None, keepdims=False): 92 | return reduce(reduction.min_op, a, axis, dtype, out, keepdims) 93 | 94 | 95 | def sum(a, axis=None, dtype=None, out=None, keepdims=False): 96 | return reduce(reduction.sum_op, a, axis, dtype, out, keepdims) 97 | 98 | 99 | def argmax(a, axis=None, dtype=None, out=None, keepdims=False): 100 | return reduce(reduction.argmax_op, a, axis, dtype, out, keepdims, True) 101 | 102 | 103 | def argmin(a, axis=None, dtype=None, out=None, keepdims=False): 104 | return reduce(reduction.argmin_op, a, axis, dtype, out, keepdims, True) 105 | -------------------------------------------------------------------------------- 
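As a sanity check on the bookkeeping in reduce() above: a reduction over leading or trailing axes is mapped onto a 2-D (m, n) matrix reduction, with m and n derived from the output shape. The following sketch mirrors that mapping in plain NumPy; it is illustrative only and does not use cudarray itself.

import numpy as np

a = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)

# Leading axes (0, 1): out_shape == (4,), so n = 4 and m = a.size // n = 6.
# The kernel then reduces an (m, n) matrix over its leading dimension.
n = 4
m = a.size // n
assert np.allclose(a.reshape(m, n).sum(axis=0), a.sum(axis=(0, 1)))

# Trailing axes (1, 2): out_shape == (2,), so m = 2 and n = a.size // m = 12.
# The kernel reduces the (m, n) matrix over its trailing dimension.
m = 2
n = a.size // m
assert np.allclose(a.reshape(m, n).sum(axis=1), a.sum(axis=(1, 2)))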
/cudarray/wrap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andersbll/cudarray/a2cffbb1434db9a7e6ed83211300d23d47630d2e/cudarray/wrap/__init__.py -------------------------------------------------------------------------------- /cudarray/wrap/array_data.pxd: -------------------------------------------------------------------------------- 1 | from libcpp cimport bool 2 | cimport numpy as np 3 | 4 | cdef extern from 'cudarray/common.hpp' namespace 'cudarray': 5 | ctypedef int bool_t; 6 | 7 | cdef class ArrayData: 8 | cdef public np.dtype dtype 9 | cdef public unsigned int nbytes 10 | cdef void *dev_ptr 11 | cdef ArrayData owner 12 | cdef size_t size 13 | cdef unsigned int offset 14 | 15 | 16 | cdef bool_t *bool_ptr(ArrayData a) 17 | cdef float *float_ptr(ArrayData a) 18 | cdef int *int_ptr(ArrayData a) 19 | cdef bool is_int(ArrayData a) 20 | cdef bool is_float(ArrayData a) 21 | -------------------------------------------------------------------------------- /cudarray/wrap/array_data.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | from .cudart cimport * 4 | from .array_data cimport ArrayData 5 | 6 | 7 | cdef class ArrayData: 8 | def __init__(self, size_t size, np.dtype dtype, np.ndarray np_data=None, 9 | ArrayData owner=None, unsigned int offset=0): 10 | self.size = size 11 | self.dtype = dtype 12 | self.nbytes = size*dtype.itemsize 13 | self.owner = owner 14 | self.offset = offset 15 | if owner is None: 16 | cudaCheck(cudaMalloc(&self.dev_ptr, self.nbytes)) 17 | else: 18 | self.dev_ptr = ( owner.dev_ptr) + offset*dtype.itemsize 19 | if np_data is not None: 20 | cudaCheck(cudaMemcpyAsync(self.dev_ptr, np.PyArray_DATA(np_data), 21 | self.nbytes, cudaMemcpyHostToDevice)) 22 | 23 | def to_numpy(self, np_array): 24 | cudaCheck(cudaMemcpy(np.PyArray_DATA(np_array), self.dev_ptr, 25 | self.nbytes, cudaMemcpyDeviceToHost)) 26 | return np_array 27 | 28 | def __dealloc__(self): 29 | if self.owner is None: 30 | cudaFree(self.dev_ptr) 31 | 32 | def __reduce__(self): 33 | if self.owner is not None: 34 | np_array = None 35 | else: 36 | np_array = np.empty(self.size, dtype=self.dtype) 37 | self.to_numpy(np_array) 38 | args = (self.size, self.dtype, np_array, self.owner, self.offset) 39 | return (ArrayData, args) 40 | 41 | @property 42 | def data(self): 43 | return self.dev_ptr 44 | 45 | @property 46 | def itemsize(self): 47 | return self.dtype.itemsize 48 | 49 | 50 | cdef bool_t *bool_ptr(ArrayData a): 51 | return a.dev_ptr 52 | 53 | 54 | cdef float *float_ptr(ArrayData a): 55 | return a.dev_ptr 56 | 57 | 58 | cdef int *int_ptr(ArrayData a): 59 | return a.dev_ptr 60 | 61 | 62 | cdef bool is_int(ArrayData a): 63 | return a.dtype == np.dtype('int32') 64 | 65 | 66 | cdef bool is_float(ArrayData a): 67 | return a.dtype == np.dtype('float32') 68 | -------------------------------------------------------------------------------- /cudarray/wrap/array_ops.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from 'cudarray/array_ops.hpp' namespace 'cudarray': 2 | 3 | void concatenate[T]( 4 | const T *a, const T *b, unsigned int axis, unsigned int d0, 5 | unsigned int d1, unsigned int d2, unsigned int da, unsigned int db, 6 | T *c 7 | ) 8 | 9 | void split[T]( 10 | const T *c, unsigned int axis, unsigned int d0, unsigned int d1, 11 | unsigned int d2, unsigned int da, unsigned int db, T *a, T *b 12 | ) 
13 | 14 | void transpose[T](const T *a, unsigned int n, unsigned int m, T *b) 15 | 16 | void as[Ta, Tb](const Ta *a, unsigned int n, Tb *b) 17 | 18 | void fill[T](T *a, unsigned int n, T alpha) 19 | 20 | void copy[T](const T *a, unsigned int n, T *b) 21 | 22 | void to_device[T](const T *a, unsigned int n, T *b) 23 | 24 | void to_host[T](const T *a, unsigned int n, T *b) 25 | -------------------------------------------------------------------------------- /cudarray/wrap/array_ops.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | cimport array_ops 3 | from .array_data cimport (ArrayData, bool_ptr, float_ptr, int_ptr, is_int, 4 | is_float) 5 | 6 | 7 | def _concatenate(ArrayData a, ArrayData b, unsigned int axis, unsigned int d0, 8 | unsigned int d1, unsigned int d2, unsigned int da, 9 | unsigned int db, ArrayData c): 10 | if is_float(a): 11 | array_ops.concatenate(float_ptr(a), float_ptr(b), axis, d0, d1, d2, da, 12 | db, float_ptr(c)) 13 | elif is_int(a): 14 | array_ops.concatenate(int_ptr(a), int_ptr(b), axis, d0, d1, d2, da, db, 15 | int_ptr(c)) 16 | else: 17 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 18 | 19 | 20 | def _split(ArrayData c, unsigned int axis, unsigned int d0, unsigned int d1, 21 | unsigned int d2, unsigned int da, unsigned int db, ArrayData a, 22 | ArrayData b): 23 | if is_float(a): 24 | array_ops.split(float_ptr(c), axis, d0, d1, d2, da, db, float_ptr(a), 25 | float_ptr(b)) 26 | elif is_int(a): 27 | array_ops.split(int_ptr(c), axis, d0, d1, d2, da, db, int_ptr(a), 28 | int_ptr(b)) 29 | else: 30 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 31 | 32 | 33 | def _transpose(ArrayData a, unsigned int m, unsigned int n, ArrayData out): 34 | if is_float(a): 35 | array_ops.transpose(float_ptr(a), m, n, float_ptr(out)) 36 | elif is_int(a): 37 | array_ops.transpose(int_ptr(a), m, n, int_ptr(out)) 38 | else: 39 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 40 | 41 | 42 | def _asfloat(ArrayData a, unsigned int n, ArrayData out): 43 | if is_int(a): 44 | array_ops.as[int, float](int_ptr(a), n, float_ptr(out)) 45 | else: 46 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 47 | 48 | 49 | def _asint(ArrayData a, unsigned int n, ArrayData out): 50 | if is_float(a): 51 | array_ops.as[float, int](float_ptr(a), n, int_ptr(out)) 52 | else: 53 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 54 | 55 | 56 | def _fill(ArrayData a, unsigned int n, alpha): 57 | if is_int(a): 58 | array_ops.fill(int_ptr(a), n, alpha) 59 | elif is_float(a): 60 | array_ops.fill(float_ptr(a), n, alpha) 61 | else: 62 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 63 | 64 | 65 | def _copy(ArrayData a, unsigned int n, ArrayData out): 66 | if is_int(a): 67 | array_ops.copy(int_ptr(a), n, int_ptr(out)) 68 | elif is_float(a): 69 | array_ops.copy(float_ptr(a), n, float_ptr(out)) 70 | else: 71 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 72 | 73 | 74 | def _to_device(np.ndarray a, unsigned int n, ArrayData out): 75 | if is_int(out): 76 | array_ops.to_device(np.PyArray_DATA(a), n, int_ptr(out)) 77 | elif is_float(out): 78 | array_ops.to_device(np.PyArray_DATA(a), n, float_ptr(out)) 79 | else: 80 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 81 | 82 | 83 | def _to_host(ArrayData a, unsigned int n, np.ndarray out): 84 | if is_int(a): 85 | array_ops.to_host(int_ptr(a), n, np.PyArray_DATA(out)) 86 | elif is_float(a): 87 | 
array_ops.to_host(float_ptr(a), n, np.PyArray_DATA(out)) 88 | else: 89 | raise ValueError('type (%s) not implemented' % str(a.dtype)) 90 | -------------------------------------------------------------------------------- /cudarray/wrap/blas.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "cudarray/blas.hpp" namespace 'cudarray': 2 | enum TransposeOp: 3 | OP_TRANS 4 | OP_NO_TRANS 5 | 6 | T dot[T](const T *a, const T *b, unsigned int n) 7 | 8 | void gemv[T](const T *A, const T *b, TransposeOp trans, unsigned int m, 9 | unsigned int n, T alpha, T beta, T *c) 10 | 11 | void gemm[T](const T *A, const T *B, TransposeOp transA, 12 | TransposeOp transB, unsigned int m, unsigned int n, 13 | unsigned int k, T alpha, T beta, T *C) 14 | 15 | cdef cppclass BLASBatch[T]: 16 | BLASBatch(const T *A, const T *B, T *C, unsigned int batch_size, 17 | int Astride, int Bstride, int Cstride) 18 | 19 | BLASBatch(const T **A, const T **B, T **C, unsigned int batch_size) 20 | 21 | void gemm(TransposeOp transA, TransposeOp transB, unsigned int m, 22 | unsigned int n, unsigned int k, T alpha, T beta) 23 | -------------------------------------------------------------------------------- /cudarray/wrap/blas.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | cimport blas 3 | from .array_data cimport ArrayData, float_ptr, is_float 4 | 5 | 6 | no_trans_op = blas.OP_NO_TRANS 7 | trans_op = blas.OP_TRANS 8 | 9 | 10 | def dot_(ArrayData a, ArrayData b, unsigned int n): 11 | if is_float(a): 12 | return blas.dot(float_ptr(a), float_ptr(b), n) 13 | else: 14 | raise ValueError('type %s not implemented' % str(a.dtype)) 15 | 16 | 17 | def gemv_(ArrayData A, ArrayData x, blas.TransposeOp trans, unsigned int m, 18 | unsigned int n, alpha, beta, ArrayData y): 19 | if is_float(A): 20 | blas.gemv(float_ptr(A), float_ptr(x), trans, m, n, alpha, 21 | beta, y.dev_ptr) 22 | else: 23 | raise ValueError('type %s not implemented' % str(A.dtype)) 24 | 25 | 26 | def gemm_(ArrayData A, ArrayData B, blas.TransposeOp transA, 27 | blas.TransposeOp transB, unsigned int m, unsigned int n, 28 | unsigned int k, alpha, beta, ArrayData C): 29 | if is_float(A): 30 | blas.gemm(float_ptr(A), float_ptr(B), transA, transB, m, n, k, 31 | alpha, beta, float_ptr(C)) 32 | else: 33 | raise ValueError('type %s not implemented' % str(A.dtype)) 34 | 35 | 36 | cdef class BLASBatch_f: 37 | cdef BLASBatch[float] *ptr 38 | def __init__(self, ArrayData A, ArrayData B, ArrayData C, int batch_size, 39 | int Astride, int Bstride, int Cstride): 40 | self.ptr = new BLASBatch[float](float_ptr(A), float_ptr(B), 41 | float_ptr(C), batch_size, Astride, Bstride, Cstride) 42 | 43 | def __dealloc__(self): 44 | del self.ptr 45 | 46 | def gemm(self, blas.TransposeOp transA, blas.TransposeOp transB, 47 | unsigned int m, unsigned int n, unsigned int k, float alpha, 48 | float beta): 49 | self.ptr.gemm(transA, transB, m, n, k, alpha, beta) 50 | 51 | 52 | cpdef blas_batch(ArrayData A, ArrayData B, ArrayData C, int batch_size, 53 | int Astride, int Bstride, int Cstride): 54 | cdef BLASBatch[float] *ptr 55 | if is_float(A): 56 | return BLASBatch_f(A, B, C, batch_size, Astride, Bstride, Cstride) 57 | -------------------------------------------------------------------------------- /cudarray/wrap/cudart.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "driver_types.h": 2 | enum cudaMemcpyKind: 3 | cudaMemcpyHostToHost 4 
| cudaMemcpyHostToDevice 5 | cudaMemcpyDeviceToHost 6 | cudaMemcpyDeviceToDevice 7 | cudaMemcpyDefault 8 | 9 | enum cudaError: 10 | cudaSuccess 11 | 12 | ctypedef cudaError cudaError_t 13 | 14 | ctypedef struct CUstream_st: 15 | pass 16 | ctypedef CUstream_st *cudaStream_t 17 | 18 | 19 | cdef extern from "cuda_runtime_api.h": 20 | cudaError_t cudaMemcpy(void *dst, const void *src, size_t count, 21 | cudaMemcpyKind kind) 22 | cudaError_t cudaMalloc(void **devPtr, size_t size) 23 | cudaError_t cudaFree(void *devPtr) 24 | const char* cudaGetErrorString(cudaError_t error) 25 | 26 | cudaError_t cudaDeviceSynchronize() 27 | cudaError_t cudaGetLastError() 28 | 29 | cudaError_t cudaMemcpyAsync(void *dst, const void *src, size_t count, 30 | cudaMemcpyKind kind) 31 | cudaError_t cudaSetDevice(int device) 32 | 33 | 34 | cpdef initialize(int device_id) 35 | cdef cudaCheck(cudaError_t status) 36 | cpdef cudaSyncCheck() 37 | -------------------------------------------------------------------------------- /cudarray/wrap/cudart.pyx: -------------------------------------------------------------------------------- 1 | from cudart cimport * 2 | 3 | 4 | cpdef initialize(int device_id): 5 | cudaCheck(cudaSetDevice(device_id)) 6 | # Establish context 7 | cudaCheck(cudaFree(0)) 8 | 9 | 10 | cdef cudaCheck(cudaError_t status): 11 | if status != cudaSuccess: 12 | raise ValueError(cudaGetErrorString(status)) 13 | 14 | 15 | cpdef cudaSyncCheck(): 16 | cudaCheck(cudaDeviceSynchronize()) 17 | cudaCheck(cudaGetLastError()) 18 | -------------------------------------------------------------------------------- /cudarray/wrap/cudnn.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "cudarray/nnet/cudnn.hpp" namespace 'cudarray': 2 | enum PoolMode: 3 | POOL_AVG 4 | POOL_MAX 5 | 6 | cdef cppclass PoolBC01CuDNN[T]: 7 | PoolBC01CuDNN(int n_img_dims, int *win_shape, int *padding, 8 | int *strides, PoolMode pool_mode) 9 | 10 | void fprop(const T *imgs, int *imgs_shape, T *poolout) 11 | 12 | void bprop(const T *imgs, const T* poolout, const T *poolout_d, 13 | T *imgs_d) 14 | 15 | 16 | cdef cppclass ConvBC01CuDNN[T]: 17 | ConvBC01CuDNN(int pad_y, int pad_x, int stride_y, int stride_x) 18 | 19 | void fprop(const T *imgs, const T *filters, int n_imgs, int n_channels, 20 | int n_filters, int img_h, int img_w, int filter_h, int filter_w, 21 | T *convout) 22 | 23 | void bprop(const T *imgs, const T *filters, const T *convout_d, 24 | T *imgs_d, T *filters_d) 25 | -------------------------------------------------------------------------------- /cudarray/wrap/cudnn.pyx: -------------------------------------------------------------------------------- 1 | from cpython cimport array as c_array 2 | from array import array 3 | cimport numpy as np 4 | from .array_data cimport ArrayData, float_ptr 5 | cimport cudnn 6 | 7 | 8 | 9 | cdef class PoolBC01CuDNN_f: 10 | cdef PoolBC01CuDNN[float] *ptr 11 | cdef tuple win_shape 12 | cdef tuple padding 13 | cdef tuple strides 14 | cdef str mode 15 | 16 | def __init__(self, win_shape, padding, strides, mode): 17 | cdef c_array.array win_shape_ = array('i', win_shape) 18 | cdef c_array.array padding_ = array('i', padding) 19 | cdef c_array.array strides_ = array('i', strides) 20 | self.win_shape = win_shape 21 | self.padding = padding 22 | self.strides = strides 23 | self.mode = mode 24 | if mode == 'avg': 25 | mode = POOL_AVG 26 | elif mode == 'max': 27 | mode = POOL_MAX 28 | else: 29 | raise ValueError('Invalid mode: %s' % mode) 30 | self.ptr = new 
PoolBC01CuDNN[float]( 31 | len(win_shape), win_shape_.data.as_ints, padding_.data.as_ints, 32 | strides_.data.as_ints, mode 33 | ) 34 | 35 | def __dealloc__(self): 36 | del self.ptr 37 | 38 | def __reduce__(self): 39 | args = (self.win_shape, self.padding, self.strides, self.mode) 40 | return (PoolBC01CuDNN_f, args) 41 | 42 | def fprop(self, ArrayData imgs, imgs_shape, ArrayData poolout): 43 | cdef c_array.array imgs_shape_ = array('i', imgs_shape) 44 | self.ptr.fprop(float_ptr(imgs), imgs_shape_.data.as_ints, 45 | float_ptr(poolout)) 46 | 47 | def bprop(self, ArrayData imgs, ArrayData poolout, ArrayData poolout_d, 48 | ArrayData imgs_d): 49 | self.ptr.bprop( 50 | imgs.dev_ptr, 51 | poolout.dev_ptr, 52 | poolout_d.dev_ptr, imgs_d.dev_ptr 53 | ) 54 | 55 | 56 | cdef class ConvBC01CuDNN_f: 57 | cdef ConvBC01CuDNN[float] *ptr 58 | cdef tuple padding 59 | cdef tuple strides 60 | def __init__(self, padding, strides): 61 | self.padding = padding 62 | self.strides = strides 63 | self.ptr = new ConvBC01CuDNN[float]( 64 | padding[0], padding[1], strides[0], strides[1] 65 | ) 66 | 67 | def __dealloc__(self): 68 | del self.ptr 69 | 70 | def __reduce__(self): 71 | args = (self.padding, self.strides) 72 | return (ConvBC01CuDNN_f, args) 73 | 74 | def fprop(self, ArrayData imgs, ArrayData filters, int n_imgs, 75 | int n_channels, int n_filters, img_shape, filter_shape, 76 | ArrayData convout): 77 | cdef int img_h = img_shape[0] 78 | cdef int img_w = img_shape[1] 79 | cdef int filter_h = filter_shape[0] 80 | cdef int filter_w = filter_shape[1] 81 | self.ptr.fprop(float_ptr(imgs), float_ptr(filters), n_imgs, n_channels, 82 | n_filters, img_h, img_w, filter_h, filter_w, float_ptr(convout)) 83 | 84 | def bprop(self, ArrayData imgs, ArrayData filters, ArrayData convout_d, 85 | ArrayData imgs_d, ArrayData filters_d): 86 | cdef float *imgs_ptr = NULL if imgs is None \ 87 | else float_ptr(imgs) 88 | cdef float *imgs_d_ptr = NULL if imgs_d is None \ 89 | else float_ptr(imgs_d) 90 | cdef float *filters_d_ptr = NULL if filters_d is None \ 91 | else float_ptr(filters_d) 92 | self.ptr.bprop(imgs_ptr, float_ptr(filters), 93 | float_ptr(convout_d), imgs_d_ptr, filters_d_ptr) 94 | 95 | 96 | def conv_bc01_cudnn(padding, strides): 97 | # TODO: only float is supported 98 | return ConvBC01CuDNN_f(padding, strides) 99 | -------------------------------------------------------------------------------- /cudarray/wrap/elementwise.pxd: -------------------------------------------------------------------------------- 1 | from libcpp cimport bool 2 | 3 | cdef extern from 'cudarray/common.hpp' namespace 'cudarray': 4 | ctypedef int bool_t; 5 | 6 | cdef extern from 'cudarray/elementwise.hpp' namespace 'cudarray': 7 | enum BroadcastType: 8 | BROADCAST_INNER 9 | BROADCAST_LEADING 10 | BROADCAST_OUTER 11 | BROADCAST_TRAILING 12 | 13 | enum BinaryOp: 14 | ADD_OP 15 | DIV_OP 16 | MAX_B_OP 17 | MIN_B_OP 18 | MUL_OP 19 | POW_OP 20 | SUB_OP 21 | 22 | void binary[Ta, Tb, Tc](BinaryOp op, const Ta *a, const Tb *b, 23 | unsigned int n, Tc *c) 24 | void binary_scalar[Ta, Talpha, Tb](BinaryOp op, const Ta *a, Talpha alpha, 25 | unsigned int n, Tb *b) 26 | void binary_scalar_[Talpha, Ta, Tb](BinaryOp op, Talpha alpha, const Ta *a, 27 | unsigned int n, Tb *b) 28 | void binary_broadcast[Ta, Tb, Tc](BinaryOp op, BroadcastType btype, 29 | const Ta *a, const Tb *b, unsigned int k, unsigned int m, 30 | unsigned int n, Tc *c) 31 | 32 | 33 | enum BinaryCmpOp: 34 | EQ_OP 35 | GT_OP 36 | GT_EQ_OP 37 | LT_OP 38 | LT_EQ_OP 39 | NEQ_OP 40 | 41 | void 
binary_cmp[Ta, Tb](BinaryCmpOp op, const Ta *a, const Tb *b, 42 | unsigned int n, bool_t *c) 43 | void binary_cmp_scalar[T](BinaryCmpOp op, const T *a, T alpha, 44 | unsigned int n, bool_t *b) 45 | void binary_cmp_scalar_[T](BinaryCmpOp op, T alpha, const T *a, 46 | unsigned int n, bool_t *b) 47 | void binary_cmp_broadcast[Ta, Tb](BinaryCmpOp op, BroadcastType btype, 48 | const Ta *a, const Tb *b, unsigned int k, unsigned int m, 49 | unsigned int n, bool_t *c) 50 | 51 | 52 | enum UnaryOp: 53 | ABS_OP 54 | COS_OP 55 | EXP_OP 56 | LOG_OP 57 | LOG1P_OP 58 | NEG_OP 59 | RELU_OP 60 | RELU_D_OP 61 | SIGMOID_OP 62 | SIGMOID_D_OP 63 | SOFTPLUS_OP 64 | SOFTPLUS_D_OP 65 | SIN_OP 66 | SQRT_OP 67 | TANH_OP 68 | TANH_D_OP 69 | 70 | void unary[T](UnaryOp op, const T *a, unsigned int n, T *b) 71 | 72 | void clip[T](const T *a, T a_min, T a_max, unsigned int n, T *b) 73 | -------------------------------------------------------------------------------- /cudarray/wrap/elementwise.pyx: -------------------------------------------------------------------------------- 1 | from cpython cimport bool 2 | cimport numpy as np 3 | cimport elementwise 4 | from .array_data cimport (ArrayData, bool_ptr, float_ptr, int_ptr, is_int, 5 | is_float) 6 | 7 | btype_inner = BROADCAST_INNER 8 | btype_leading = BROADCAST_LEADING 9 | btype_outer = BROADCAST_OUTER 10 | btype_trailing = BROADCAST_TRAILING 11 | 12 | add_op = ADD_OP 13 | div_op = DIV_OP 14 | max_op = MAX_B_OP 15 | min_op = MIN_B_OP 16 | mul_op = MUL_OP 17 | pow_op = POW_OP 18 | sub_op = SUB_OP 19 | 20 | abs_op = ABS_OP 21 | cos_op = COS_OP 22 | exp_op = EXP_OP 23 | log_op = LOG_OP 24 | log1p_op = LOG1P_OP 25 | neg_op = NEG_OP 26 | relu_op = RELU_OP 27 | relu_d_op = RELU_D_OP 28 | sigmoid_op = SIGMOID_OP 29 | sigmoid_d_op = SIGMOID_D_OP 30 | softplus_op = SOFTPLUS_OP 31 | softplus_d_op = SOFTPLUS_D_OP 32 | sin_op = SIN_OP 33 | sqrt_op = SQRT_OP 34 | tanh_op = TANH_OP 35 | tanh_d_op = TANH_D_OP 36 | 37 | eq_op = EQ_OP 38 | gt_op = GT_OP 39 | gt_eq_op = GT_EQ_OP 40 | lt_op = LT_OP 41 | lt_eq_op = LT_EQ_OP 42 | neq_op = NEQ_OP 43 | 44 | 45 | def _binary(BinaryOp op, ArrayData a, ArrayData b, unsigned int n, 46 | ArrayData c): 47 | if is_float(a) and is_float(b): 48 | elementwise.binary(op, float_ptr(a), float_ptr(b), n, float_ptr(c)) 49 | elif is_float(a) and is_int(b): 50 | elementwise.binary(op, float_ptr(a), int_ptr(b), n, float_ptr(c)) 51 | elif is_int(a) and is_float(b): 52 | elementwise.binary(op, int_ptr(a), float_ptr(b), n, float_ptr(c)) 53 | elif is_int(a) and is_int(b): 54 | elementwise.binary(op, int_ptr(a), int_ptr(b), n, int_ptr(c)) 55 | else: 56 | raise ValueError('types (%s, %s) not implemented' 57 | % (str(a.dtype), str(b.dtype))) 58 | 59 | 60 | def _binary_scalar(BinaryOp op, ArrayData a, alpha, unsigned int n, 61 | ArrayData b, bool flip_operands): 62 | if is_float(a): 63 | if flip_operands: 64 | elementwise.binary_scalar_(op, alpha, float_ptr(a), n, 65 | float_ptr(b)) 66 | else: 67 | elementwise.binary_scalar(op, float_ptr(a), alpha, n, 68 | float_ptr(b)) 69 | elif is_int(a) and isinstance(alpha, float): 70 | if flip_operands: 71 | elementwise.binary_scalar_(op, alpha, int_ptr(a), n, 72 | float_ptr(b)) 73 | else: 74 | elementwise.binary_scalar(op, int_ptr(a), alpha, n, 75 | float_ptr(b)) 76 | elif is_int(a) and isinstance(alpha, int): 77 | if flip_operands: 78 | elementwise.binary_scalar_(op, alpha, int_ptr(a), n, 79 | int_ptr(b)) 80 | else: 81 | elementwise.binary_scalar(op, int_ptr(a), alpha, n, 82 | int_ptr(b)) 83 | else: 84 | raise ValueError('types 
(%s, %s) not implemented' 85 | % (str(a.dtype), type(alpha))) 86 | 87 | 88 | def _binary_broadcast(BinaryOp op, BroadcastType btype, ArrayData a, 89 | ArrayData b, unsigned int k, unsigned int m, unsigned int n, 90 | ArrayData c): 91 | if is_float(a) and is_float(b): 92 | elementwise.binary_broadcast(op, btype, float_ptr(a), float_ptr(b), k, 93 | m, n, float_ptr(c)) 94 | elif is_float(a) and is_int(b): 95 | elementwise.binary_broadcast(op, btype, float_ptr(a), int_ptr(b), k, m, 96 | n, float_ptr(c)) 97 | elif is_int(a) and is_float(b): 98 | elementwise.binary_broadcast(op, btype, int_ptr(a), float_ptr(b), k, m, 99 | n, float_ptr(c)) 100 | elif is_int(a) and is_int(b): 101 | elementwise.binary_broadcast(op, btype, int_ptr(a), int_ptr(b), k, m, 102 | n, int_ptr(c)) 103 | else: 104 | raise ValueError('types (%s, %s) not implemented' 105 | % (str(a.dtype), str(b.dtype))) 106 | 107 | 108 | def _binary_cmp(BinaryCmpOp op, ArrayData a, ArrayData b, unsigned int n, 109 | ArrayData c): 110 | if is_float(a) and is_float(b): 111 | elementwise.binary_cmp[float, float](op, float_ptr(a), float_ptr(b), n, 112 | bool_ptr(c)) 113 | elif is_float(a) and is_int(b): 114 | elementwise.binary_cmp[float, int](op, float_ptr(a), int_ptr(b), n, 115 | bool_ptr(c)) 116 | elif is_int(a) and is_float(b): 117 | elementwise.binary_cmp[int, float](op, int_ptr(a), float_ptr(b), n, 118 | bool_ptr(c)) 119 | elif is_int(a) and is_int(b): 120 | elementwise.binary_cmp[int, int](op, int_ptr(a), int_ptr(b), n, 121 | bool_ptr(c)) 122 | else: 123 | raise ValueError('types (%s, %s) not implemented' 124 | % (str(a.dtype), str(b.dtype))) 125 | 126 | 127 | def _binary_cmp_scalar(BinaryCmpOp op, ArrayData a, alpha, unsigned int n, 128 | ArrayData b, bool flip_operands): 129 | if is_float(a): 130 | if flip_operands: 131 | elementwise.binary_cmp_scalar_[float]( 132 | op, alpha, float_ptr(a), n, bool_ptr(b) 133 | ) 134 | else: 135 | elementwise.binary_cmp_scalar[float]( 136 | op, float_ptr(a), alpha, n, bool_ptr(b) 137 | ) 138 | elif is_int(a): 139 | if flip_operands: 140 | elementwise.binary_cmp_scalar_[int]( 141 | op, alpha, int_ptr(a), n, bool_ptr(b) 142 | ) 143 | else: 144 | elementwise.binary_cmp_scalar[int]( 145 | op, int_ptr(a), alpha, n, bool_ptr(b) 146 | ) 147 | else: 148 | raise ValueError('types (%s, %s) not implemented' 149 | % (str(a.dtype), type(alpha))) 150 | 151 | 152 | def _binary_cmp_broadcast(BinaryCmpOp op, BroadcastType btype, ArrayData a, 153 | ArrayData b, unsigned int k, unsigned int m, unsigned int n, ArrayData c): 154 | if is_float(a) and is_float(b): 155 | elementwise.binary_cmp_broadcast[float, float](op, btype, float_ptr(a), 156 | float_ptr(b), k, m, n, bool_ptr(c)) 157 | elif is_float(a) and is_int(b): 158 | elementwise.binary_cmp_broadcast[float, int](op, btype, float_ptr(a), 159 | int_ptr(b), k, m, n, bool_ptr(c)) 160 | elif is_int(a) and is_float(b): 161 | elementwise.binary_cmp_broadcast[int, float](op, btype, int_ptr(a), 162 | float_ptr(b), k, m, n, bool_ptr(c)) 163 | elif is_int(a) and is_int(b): 164 | elementwise.binary_cmp_broadcast[int, int](op, btype, int_ptr(a), 165 | int_ptr(b), k, m, n, bool_ptr(c)) 166 | else: 167 | raise ValueError('types (%s, %s) not implemented' 168 | % (str(a.dtype), str(b.dtype))) 169 | 170 | 171 | def _unary(UnaryOp op, ArrayData a, unsigned int n, ArrayData b): 172 | if is_float(a): 173 | elementwise.unary(op, float_ptr(a), n, float_ptr(b)) 174 | elif is_int(a): 175 | elementwise.unary(op, int_ptr(a), n, int_ptr(b)) 176 | else: 177 | raise ValueError('type %s not 
implemented' % str(a.dtype)) 178 | 179 | 180 | def _clip(ArrayData a, a_min, a_max, unsigned int n, ArrayData b): 181 | if is_float(a): 182 | elementwise.clip[float](float_ptr(a), a_min, a_max, n, float_ptr(b)) 183 | elif is_int(a): 184 | elementwise.clip[int](int_ptr(a), a_min, a_max, n, int_ptr(b)) 185 | else: 186 | raise ValueError('type %s not implemented' % str(a.dtype)) 187 | -------------------------------------------------------------------------------- /cudarray/wrap/image.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from 'cudarray/image/rescale.hpp' namespace 'cudarray': 2 | enum SampleMethod: 3 | BILINEAR_SAMPLING 4 | NEAREST_SAMPLING 5 | PERFORATED_SAMPLING 6 | 7 | void rescale[T](const T *imgs, float factor, SampleMethod method, 8 | int n_imgs, int img_h, int img_w, T *imgs_scaled) 9 | 10 | -------------------------------------------------------------------------------- /cudarray/wrap/image.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | cimport image 3 | from .array_data cimport (ArrayData, float_ptr, int_ptr, is_int, is_float) 4 | 5 | 6 | sample_methods = { 7 | 'bilinear': BILINEAR_SAMPLING, 8 | 'nearest': NEAREST_SAMPLING, 9 | 'perforated': PERFORATED_SAMPLING, 10 | } 11 | 12 | 13 | def _rescale(ArrayData imgs, float factor, SampleMethod method, int n_imgs, 14 | int img_h, int img_w, ArrayData out): 15 | if is_float(imgs): 16 | image.rescale(float_ptr(imgs), factor, method, n_imgs, img_h, img_w, 17 | float_ptr(out)) 18 | else: 19 | raise ValueError('type (%s) not implemented' % str(out.dtype)) 20 | 21 | -------------------------------------------------------------------------------- /cudarray/wrap/nnet.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from 'cudarray/nnet/conv_bc01_matmul.hpp' namespace 'cudarray': 2 | 3 | void conv_bc01_matmul[T](const T *imgs, const T *filters, int n_imgs, 4 | int n_channels, int n_filters, int img_h, int img_w, int filter_h, 5 | int filter_w, int pad_y, int pad_x, int stride_y, int stride_x, 6 | T *convout) 7 | 8 | void conv_bc01_matmul_bprop_filters[T](const T *imgs, const T *convout_d, 9 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 10 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 11 | int stride_x, T *filters_d) 12 | 13 | void conv_bc01_matmul_bprop_imgs[T](const T *filters, const T *convout_d, 14 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 15 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 16 | int stride_x, T *imgs_d) 17 | 18 | 19 | cdef extern from 'cudarray/nnet/pool_b01.hpp' namespace 'cudarray': 20 | 21 | void max_pool_b01[T](const T* imgs, int n_imgs, int img_h, int img_w, 22 | int win_h, int win_w, int pad_y, int pad_x, int stride_y, int stride_x, 23 | T* out, int* mask) 24 | 25 | void max_pool_b01_bprob[T](const T* out_d, const int* mask, int n_imgs, 26 | int img_h, int img_w, int win_h, int win_w, int pad_y, int pad_x, 27 | int stride_y, int stride_x, T* imgs_d) 28 | 29 | void avg_pool_b01[T](const T* imgs, int n_imgs, int img_h, int img_w, 30 | int win_h, int win_w, int pad_y, int pad_x, int stride_y, int stride_x, 31 | T* out) 32 | 33 | void avg_pool_b01_bprob[T](const T* out_d, int n_imgs, int img_h, 34 | int img_w, int win_h, int win_w, int pad_y, int pad_x, int stride_y, 35 | int stride_x, T* imgs_d) 36 | 37 | 38 | cdef extern from 'cudarray/nnet/one_hot.hpp' namespace 
'cudarray': 39 | void one_hot_encode[T](const int *labels, int n_classes, int n, T *out) 40 | -------------------------------------------------------------------------------- /cudarray/wrap/nnet.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | cimport nnet 3 | from .array_data cimport ArrayData, float_ptr, int_ptr, is_float 4 | 5 | 6 | def _conv_bc01_matmul(ArrayData imgs, ArrayData filters, int n_imgs, 7 | int n_channels, int n_filters, img_shape, filter_shape, padding, strides, 8 | ArrayData convout): 9 | cdef int img_h = img_shape[0] 10 | cdef int img_w = img_shape[1] 11 | cdef int filter_h = filter_shape[0] 12 | cdef int filter_w = filter_shape[1] 13 | cdef int pad_y = padding[0] 14 | cdef int pad_x = padding[1] 15 | cdef int stride_y = strides[0] 16 | cdef int stride_x = strides[1] 17 | if is_float(imgs): 18 | nnet.conv_bc01_matmul(float_ptr(imgs), float_ptr(filters), 19 | n_imgs, n_channels, n_filters, img_h, img_w, filter_h, filter_w, 20 | pad_y, pad_x, stride_y, stride_x, float_ptr(convout)) 21 | else: 22 | raise ValueError('type %s not implemented' % str(imgs.dtype)) 23 | 24 | 25 | def _conv_bc01_matmul_bprop_filters(ArrayData imgs, ArrayData convout_d, 26 | int n_imgs, int n_channels, int n_filters, img_shape, filter_shape, 27 | padding, strides, ArrayData filters_d): 28 | cdef int img_h = img_shape[0] 29 | cdef int img_w = img_shape[1] 30 | cdef int filter_h = filter_shape[0] 31 | cdef int filter_w = filter_shape[1] 32 | cdef int pad_y = padding[0] 33 | cdef int pad_x = padding[1] 34 | cdef int stride_y = strides[0] 35 | cdef int stride_x = strides[1] 36 | if is_float(imgs): 37 | nnet.conv_bc01_matmul_bprop_filters(float_ptr(imgs), 38 | float_ptr(convout_d), n_imgs, n_channels, n_filters, img_h, img_w, 39 | filter_h, filter_w, pad_y, pad_x, stride_y, stride_x, 40 | float_ptr(filters_d)) 41 | else: 42 | raise ValueError('type %s not implemented' % str(imgs.dtype)) 43 | 44 | 45 | def _conv_bc01_matmul_bprop_imgs(ArrayData filters, ArrayData convout_d, 46 | int n_imgs, int n_channels, int n_filters, img_shape, filter_shape, 47 | padding, strides, ArrayData imgs_d): 48 | cdef int img_h = img_shape[0] 49 | cdef int img_w = img_shape[1] 50 | cdef int filter_h = filter_shape[0] 51 | cdef int filter_w = filter_shape[1] 52 | cdef int pad_y = padding[0] 53 | cdef int pad_x = padding[1] 54 | cdef int stride_y = strides[0] 55 | cdef int stride_x = strides[1] 56 | if is_float(filters): 57 | nnet.conv_bc01_matmul_bprop_imgs(float_ptr(filters), 58 | float_ptr(convout_d), n_imgs, n_channels, n_filters, img_h, img_w, 59 | filter_h, filter_w, pad_y, pad_x, stride_y, stride_x, 60 | float_ptr(imgs_d)) 61 | else: 62 | raise ValueError('type %s not implemented' % str(filters.dtype)) 63 | 64 | 65 | def _max_pool_b01(ArrayData imgs, int n_imgs, img_shape, win_shape, padding, 66 | strides, ArrayData out, ArrayData mask): 67 | cdef int img_h = img_shape[0] 68 | cdef int img_w = img_shape[1] 69 | cdef int win_h = win_shape[0] 70 | cdef int win_w = win_shape[1] 71 | cdef int pad_y = padding[0] 72 | cdef int pad_x = padding[1] 73 | cdef int stride_y = strides[0] 74 | cdef int stride_x = strides[1] 75 | if is_float(imgs): 76 | nnet.max_pool_b01(float_ptr(imgs), n_imgs, img_h, img_w, 77 | win_h, win_w, pad_y, pad_x, stride_y, stride_x, 78 | float_ptr(out), int_ptr(mask)) 79 | else: 80 | raise ValueError('type %s not implemented' % str(imgs.dtype)) 81 | 82 | 83 | def _max_pool_b01_bprop(ArrayData out_d, ArrayData mask, int n_imgs, img_shape, 84 | 
win_shape, padding, strides, ArrayData imgs_d): 85 | cdef int img_h = img_shape[0] 86 | cdef int img_w = img_shape[1] 87 | cdef int win_h = win_shape[0] 88 | cdef int win_w = win_shape[1] 89 | cdef int pad_y = padding[0] 90 | cdef int pad_x = padding[1] 91 | cdef int stride_y = strides[0] 92 | cdef int stride_x = strides[1] 93 | if is_float(out_d): 94 | nnet.max_pool_b01_bprob(float_ptr(out_d), int_ptr(mask), 95 | n_imgs, img_h, img_w, win_h, win_w, pad_y, pad_x, stride_y, 96 | stride_x, float_ptr(imgs_d)) 97 | else: 98 | raise ValueError('type %s not implemented' % str(out_d.dtype)) 99 | 100 | 101 | def _avg_pool_b01(ArrayData imgs, int n_imgs, img_shape, win_shape, padding, 102 | strides, ArrayData out): 103 | cdef int img_h = img_shape[0] 104 | cdef int img_w = img_shape[1] 105 | cdef int win_h = win_shape[0] 106 | cdef int win_w = win_shape[1] 107 | cdef int pad_y = padding[0] 108 | cdef int pad_x = padding[1] 109 | cdef int stride_y = strides[0] 110 | cdef int stride_x = strides[1] 111 | if is_float(imgs): 112 | nnet.avg_pool_b01(float_ptr(imgs), n_imgs, img_h, img_w, 113 | win_h, win_w, pad_y, pad_x, stride_y, stride_x, 114 | float_ptr(out)) 115 | else: 116 | raise ValueError('type %s not implemented' % str(imgs.dtype)) 117 | 118 | 119 | def _avg_pool_b01_bprop(ArrayData out_d, int n_imgs, img_shape, win_shape, 120 | padding, strides, ArrayData imgs_d): 121 | cdef int img_h = img_shape[0] 122 | cdef int img_w = img_shape[1] 123 | cdef int win_h = win_shape[0] 124 | cdef int win_w = win_shape[1] 125 | cdef int pad_y = padding[0] 126 | cdef int pad_x = padding[1] 127 | cdef int stride_y = strides[0] 128 | cdef int stride_x = strides[1] 129 | if is_float(out_d): 130 | nnet.avg_pool_b01_bprob(float_ptr(out_d), n_imgs, img_h, img_w, win_h, 131 | win_w, pad_y, pad_x, stride_y, stride_x, float_ptr(imgs_d)) 132 | else: 133 | raise ValueError('type %s not implemented' % str(out_d.dtype)) 134 | 135 | 136 | def _one_hot_encode(ArrayData labels, int n_classes, int n, ArrayData out): 137 | if is_float(out): 138 | nnet.one_hot_encode(int_ptr(labels), n_classes, n, float_ptr(out)) 139 | else: 140 | raise ValueError('type %s not implemented' % str(out.dtype)) 141 | -------------------------------------------------------------------------------- /cudarray/wrap/random.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from 'cudarray/random.hpp' namespace 'cudarray': 2 | void seed(unsigned long long val) 3 | 4 | void random_normal[T](T *a, T mu, T sigma, unsigned int n) 5 | 6 | void random_uniform[T](T *a, T low, T high, unsigned int n) 7 | -------------------------------------------------------------------------------- /cudarray/wrap/random.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | cimport random 3 | from .array_data cimport ArrayData, float_ptr, is_float 4 | 5 | 6 | def _seed(val): 7 | random.seed( val) 8 | 9 | 10 | def _random_normal(ArrayData a, mu, sigma, unsigned int n): 11 | if is_float(a): 12 | random.random_normal(float_ptr(a), mu, sigma, n) 13 | else: 14 | raise ValueError('type %s not implemented' % str(a.dtype)) 15 | 16 | 17 | def _random_uniform(ArrayData a, low, high, unsigned int n): 18 | if is_float(a): 19 | random.random_uniform(float_ptr(a), low, high, n) 20 | else: 21 | raise ValueError('type %s not implemented' % str(a.dtype)) 22 | -------------------------------------------------------------------------------- /cudarray/wrap/reduction.pxd: 
-------------------------------------------------------------------------------- 1 | from libcpp cimport bool 2 | 3 | cdef extern from 'cudarray/reduction.hpp' namespace 'cudarray': 4 | enum ReduceOp: 5 | MAX_OP 6 | MEAN_OP 7 | MIN_OP 8 | SUM_OP 9 | 10 | enum ReduceToIntOp: 11 | ARGMAX_OP 12 | ARGMIN_OP 13 | 14 | void reduce[T](ReduceOp op, const T *a, unsigned int n, T *b) 15 | void reduce_mat[T](ReduceOp op, const T *a, unsigned int m, unsigned int n, 16 | bool reduce_leading, T *b) 17 | 18 | void reduce_to_int[T](ReduceToIntOp op, const T *a, unsigned int n, int *b) 19 | void reduce_mat_to_int[T](ReduceToIntOp op, const T *a, unsigned int m, 20 | unsigned int n, bool reduce_leading, int *b) 21 | -------------------------------------------------------------------------------- /cudarray/wrap/reduction.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | cimport reduction 3 | from .array_data cimport ArrayData, float_ptr, int_ptr, is_int, is_float 4 | 5 | 6 | max_op = MAX_OP 7 | mean_op = MEAN_OP 8 | min_op = MIN_OP 9 | sum_op = SUM_OP 10 | 11 | argmax_op = ARGMAX_OP 12 | argmin_op = ARGMIN_OP 13 | 14 | 15 | def _reduce(ReduceOp op, ArrayData a, unsigned int n, ArrayData out): 16 | if is_float(a): 17 | reduction.reduce(op, float_ptr(a), n, float_ptr(out)) 18 | elif is_int(a): 19 | reduction.reduce(op, int_ptr(a), n, int_ptr(out)) 20 | else: 21 | raise ValueError('type %s not implemented' % str(a.dtype)) 22 | 23 | 24 | def _reduce_mat(ReduceOp op, ArrayData a, unsigned int m, unsigned int n, 25 | bool reduce_leading, ArrayData out): 26 | if is_float(a): 27 | reduction.reduce_mat(op, float_ptr(a), m, n, reduce_leading, 28 | float_ptr(out)) 29 | elif is_int(a): 30 | reduction.reduce_mat(op, int_ptr(a), m, n, reduce_leading, 31 | int_ptr(out)) 32 | else: 33 | raise ValueError('type %s not implemented' % str(a.dtype)) 34 | 35 | 36 | 37 | def _reduce_to_int(ReduceToIntOp op, ArrayData a, unsigned int n, 38 | ArrayData out): 39 | if is_float(a): 40 | reduction.reduce_to_int(op, float_ptr(a), n, int_ptr(out)) 41 | elif is_int(a): 42 | reduction.reduce_to_int(op, int_ptr(a), n, int_ptr(out)) 43 | else: 44 | raise ValueError('type %s not implemented' % str(a.dtype)) 45 | 46 | 47 | def _reduce_mat_to_int(ReduceToIntOp op, ArrayData a, unsigned int m, 48 | unsigned int n, bool reduce_leading, ArrayData out): 49 | if is_float(a): 50 | reduction.reduce_mat_to_int(op, float_ptr(a), m, n, reduce_leading, 51 | int_ptr(out)) 52 | elif is_int(a): 53 | reduction.reduce_mat_to_int(op, int_ptr(a), m, n, reduce_leading, 54 | int_ptr(out)) 55 | else: 56 | raise ValueError('type %s not implemented' % str(a.dtype)) 57 | -------------------------------------------------------------------------------- /examples/benchmark_conv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import time 5 | import numpy as np 6 | import theano 7 | import theano.tensor as T 8 | 9 | from theano.sandbox.cuda.basic_ops import gpu_from_host, host_from_gpu 10 | from pylearn2.sandbox.cuda_convnet.filter_acts import FilterActs 11 | from pylearn2.sandbox.cuda_convnet.weight_acts import WeightActs 12 | from pylearn2.sandbox.cuda_convnet.img_acts import ImageActs 13 | 14 | import cudarray as ca 15 | 16 | 17 | def avg_running_time(fun): 18 | n_iter = 20 19 | start_time = time.time() 20 | for _ in range(n_iter): 21 | fun() 22 | duration = time.time() - start_time 23 | return duration / float(n_iter) 
24 | 25 | 26 | def allclose(a, b): 27 | atol = 1e-3 28 | rtol = 1e-3 29 | return np.allclose(a, b, atol=atol, rtol=rtol) 30 | 31 | 32 | def benchmark(n_imgs, n_channels, img_shape, n_filters, filter_shape, pad): 33 | print('\nn_imgs: %i, n_channels: %i, img_shape: (%i, %i), ' 34 | % ((n_imgs, n_channels) + img_shape) 35 | + 'n_filters: %i, filter_shape: (%i, %i), pad: %i' 36 | % ((n_filters,) + filter_shape + (pad,))) 37 | 38 | # Setup arrays 39 | padding = (pad, pad) 40 | strides = (1, 1) 41 | img_h, img_w = img_shape 42 | filter_h, filter_w = filter_shape 43 | convout_h = img_h + 2*pad - filter_h + 1 44 | convout_w = img_w + 2*pad - filter_w + 1 45 | 46 | imgs_bc01_shape = (n_imgs, n_channels, img_h, img_w) 47 | filters_bc01_shape = (n_filters, n_channels, filter_h, filter_w) 48 | 49 | imgs_bc01 = np.random.randn(n_imgs, n_channels, img_h, img_w) 50 | imgs_c01b = np.transpose(imgs_bc01, (1, 2, 3, 0)) 51 | filters_fc01 = np.random.randn(n_filters, n_channels, filter_h, filter_w) 52 | filters_c01f = np.transpose(filters_fc01, (1, 2, 3, 0)) 53 | convout_bc01 = np.random.randn(n_imgs, n_filters, convout_h, convout_w) 54 | convout_c01b = np.transpose(convout_bc01, (1, 2, 3, 0)) 55 | 56 | imgs_bc01_t = theano.shared(imgs_bc01.astype(theano.config.floatX)) 57 | imgs_c01b_t = theano.shared(imgs_c01b.astype(theano.config.floatX)) 58 | filters_fc01_t = theano.shared(filters_fc01.astype(theano.config.floatX)) 59 | filters_c01f_t = theano.shared(filters_c01f.astype(theano.config.floatX)) 60 | convout_bc01_t = theano.shared(convout_bc01.astype(theano.config.floatX)) 61 | convout_c01b_t = theano.shared(convout_c01b.astype(theano.config.floatX)) 62 | imgs_bc01_ca = ca.array(imgs_bc01) 63 | filters_fc01_ca = ca.array(filters_fc01) 64 | convout_bc01_ca = ca.array(convout_bc01) 65 | 66 | # Forward propagation 67 | print('fprop') 68 | convout_cc_op = FilterActs(stride=1, partial_sum=4, pad=pad) 69 | convout_cc_expr = convout_cc_op(imgs_c01b_t, filters_c01f_t) 70 | convout_cc_fun = theano.function([], convout_cc_expr) 71 | convout_cc = convout_cc_fun() 72 | convout_cc = np.transpose(convout_cc, (3, 0, 1, 2)) 73 | 74 | def convout_ca_fun(): 75 | convout = ca.nnet.conv_bc01(imgs_bc01_ca, filters_fc01_ca, padding, 76 | strides) 77 | return convout 78 | convout_ca = np.array(convout_ca_fun()) 79 | print(' correct: ' + str(allclose(convout_ca, convout_cc))) 80 | duration_cc = avg_running_time(convout_cc_fun) 81 | duration_ca = avg_running_time(convout_ca_fun) 82 | print(' avg. duration: cuda_convnet: %.4f ca: %.4f' 83 | % (duration_cc, duration_ca)) 84 | print(' speedup: %.2f' % (duration_cc/duration_ca)) 85 | del convout_cc_op 86 | del convout_cc_expr 87 | del convout_cc_fun 88 | 89 | # Back propagation, imgs 90 | print('bprop_imgs') 91 | dimgs_cc_op = ImageActs(stride=1, partial_sum=1, pad=pad) 92 | dimgs_cc_expr = dimgs_cc_op(convout_c01b_t, filters_c01f_t) 93 | dimgs_cc_fun = theano.function([], dimgs_cc_expr) 94 | dimgs_cc = dimgs_cc_fun() 95 | dimgs_cc = np.transpose(dimgs_cc, (3, 0, 1, 2)) 96 | 97 | def dimgs_ca_fun(): 98 | return ca.nnet.conv_bc01_bprop_imgs(filters_fc01_ca, convout_bc01_ca, 99 | img_shape, padding, strides) 100 | dimgs_ca = np.array(dimgs_ca_fun()) 101 | print(' correct: ' + str(allclose(dimgs_ca, dimgs_cc))) 102 | duration_cc = avg_running_time(dimgs_cc_fun) 103 | duration_ca = avg_running_time(dimgs_ca_fun) 104 | print(' avg. 
duration: cuda_convnet: %.4f ca: %.4f' 105 | % (duration_cc, duration_ca)) 106 | print(' speedup: %.2f' % (duration_cc/duration_ca)) 107 | del dimgs_cc_op 108 | del dimgs_cc_expr 109 | del dimgs_cc_fun 110 | 111 | # Back propagation, filters 112 | dfilters_cc_op = WeightActs(stride=1, partial_sum=1, pad=pad) 113 | dfilters_cc_expr = dfilters_cc_op(imgs_c01b_t, convout_c01b_t, 114 | T.as_tensor_variable(filter_shape)) 115 | dfilters_cc_fun = theano.function([], dfilters_cc_expr) 116 | dfilters_cc = dfilters_cc_fun()[0] 117 | dfilters_cc = np.transpose(dfilters_cc, (3, 0, 1, 2)) 118 | 119 | def dfilters_ca_fun(): 120 | return ca.nnet.conv_bc01_bprop_filters(imgs_bc01_ca, convout_bc01_ca, 121 | filter_shape, padding, strides) 122 | dfilters_ca = np.array(dfilters_ca_fun()) 123 | 124 | print('bprop_filters') 125 | print(' correct: ' + str(allclose(dfilters_ca, dfilters_cc))) 126 | duration_cc = avg_running_time(dfilters_cc_fun) 127 | duration_ca = avg_running_time(dfilters_ca_fun) 128 | print(' avg. duration: cuda_convnet: %.4f ca: %.4f' 129 | % (duration_cc, duration_ca)) 130 | print(' speedup: %.2f' % (duration_cc/duration_ca)) 131 | 132 | 133 | def run(): 134 | np.random.seed(1) 135 | # Configurations are given in the form 136 | # (n_imgs, n_channels, img_shape, n_filters, filter_shape, padding) 137 | configurations = [ 138 | # From the original paper 139 | # http://arxiv.org/abs/1312.5851 140 | (128, 3, (32, 32), 96, (11, 11), 0), 141 | (128, 96, (32, 32), 256, (7, 7), 0), 142 | (128, 256, (16, 16), 384, (5, 5), 0), 143 | (128, 384, (16, 16), 384, (5, 5), 0), 144 | (128, 384, (16, 16), 384, (3, 3), 0), 145 | # From Sander Dieleman 146 | # http://benanne.github.io/2014/05/12/fft-convolutions-in-theano.html 147 | (64, 3, (96, 96), 128, (16, 16), 0), 148 | (64, 128, (32, 32), 64, (8, 8), 0), 149 | (128, 32, (54, 54), 64, (6, 6), 0), 150 | (128, 128, (16, 16), 128, (8, 8), 0), 151 | (128, 1024, (32, 32), 128, (4, 4), 0), 152 | # Exotic shapes and padding 153 | (5, 3, (5, 5), 16, (3, 3), 1), 154 | (64, 32, (32, 32), 32, (5, 5), 2), 155 | (64, 1, (17, 19), 32, (7, 7), 4), 156 | (64, 3, (9, 16), 32, (7, 7), 4), 157 | # Typical CNN layers for CIFAR-10 158 | (128, 3, (32, 32), 64, (5, 5), 2), 159 | (128, 64, (16, 16), 64, (5, 5), 2), 160 | (128, 64, (8, 8), 64, (5, 5), 2), 161 | ] 162 | 163 | for conf in configurations: 164 | benchmark(*conf) 165 | 166 | 167 | if __name__ == '__main__': 168 | run() 169 | -------------------------------------------------------------------------------- /include/cudarray/array_ops.hpp: -------------------------------------------------------------------------------- 1 | #ifndef ARRAY_OPS_HPP_ 2 | #define ARRAY_OPS_HPP_ 3 | 4 | namespace cudarray { 5 | 6 | template 7 | void concatenate( 8 | const T *a, const T *b, unsigned int axis, unsigned int d0, 9 | unsigned int d1, unsigned int d2, unsigned int da, unsigned int db, T *c 10 | ); 11 | 12 | template 13 | void split( 14 | const T *c, unsigned int axis, unsigned int d0, unsigned int d1, 15 | unsigned int d2, unsigned int da, unsigned int db, T *a, T *b 16 | ); 17 | 18 | template 19 | void transpose(const T *a, unsigned int n, unsigned int m, T *b); 20 | 21 | template 22 | void as(const Ta *a, unsigned int n, Tb *b); 23 | 24 | template 25 | void fill(T *a, unsigned int n, T alpha); 26 | 27 | template 28 | void copy(const T *a, unsigned int n, T *b); 29 | 30 | template 31 | void to_device(const T *a, unsigned int n, T *b); 32 | 33 | template 34 | void to_host(const T *a, unsigned int n, T *b); 35 | 36 | } 37 | 38 | #endif 
// ARRAY_OPS_HPP_ 39 | -------------------------------------------------------------------------------- /include/cudarray/blas.hpp: -------------------------------------------------------------------------------- 1 | #ifndef BLAS_HPP_ 2 | #define BLAS_HPP_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | namespace cudarray { 10 | // TODO: implement more from 11 | // http://docs.nvidia.com/cuda/pdf/CUBLAS_Library.pdf 12 | 13 | enum TransposeOp { 14 | OP_TRANS = CUBLAS_OP_T, 15 | OP_NO_TRANS = CUBLAS_OP_N 16 | }; 17 | 18 | template 19 | T dot(const T *a, const T *b, unsigned int n); 20 | 21 | template 22 | void gemv(const T *A, const T *b, TransposeOp trans, unsigned int m, 23 | unsigned int n, T alpha, T beta, T *c); 24 | 25 | template 26 | void gemm(const T *A, const T *B, TransposeOp transA, TransposeOp transB, 27 | unsigned int m, unsigned int n, unsigned int k, T alpha, T beta, 28 | T *C); 29 | 30 | 31 | template 32 | class BLASBatch { 33 | public: 34 | BLASBatch(const T **A, const T **B, T **C, unsigned int batch_size); 35 | BLASBatch(const T *A, const T *B, T *C, unsigned int batch_size, int Astride, 36 | int Bstride, int Cstride); 37 | 38 | ~BLASBatch(); 39 | 40 | void gemm(TransposeOp transA, TransposeOp transB, unsigned int m, 41 | unsigned int n, unsigned int k, T alpha, T beta); 42 | private: 43 | unsigned int batch_size; 44 | const float **As_dev; 45 | const float **Bs_dev; 46 | float **Cs_dev; 47 | }; 48 | 49 | 50 | const char* cublas_message(cublasStatus_t status); 51 | 52 | inline void cublas_check(cublasStatus_t status, const char *file, int line) { 53 | if (status != CUBLAS_STATUS_SUCCESS) { 54 | std::ostringstream o; 55 | o << file << ":" << line << ": " << cublas_message(status); 56 | throw std::runtime_error(o.str()); 57 | } 58 | } 59 | 60 | #define CUBLAS_CHECK(status) { cublas_check((status), __FILE__, __LINE__); } 61 | 62 | 63 | /* 64 | Singleton class to handle cuBLAS resources. 
65 | */ 66 | class CUBLAS { 67 | public: 68 | inline static CUBLAS &instance() { 69 | static CUBLAS instance_; 70 | return instance_; 71 | } 72 | 73 | inline static cublasHandle_t &handle() { 74 | return instance().handle_; 75 | } 76 | 77 | private: 78 | cublasHandle_t handle_; 79 | CUBLAS() { 80 | CUBLAS_CHECK(cublasCreate(&handle_)); 81 | } 82 | ~CUBLAS() { 83 | CUBLAS_CHECK(cublasDestroy(handle_)); 84 | } 85 | CUBLAS(CUBLAS const&); 86 | void operator=(CUBLAS const&); 87 | }; 88 | 89 | } 90 | 91 | #endif // BLAS_HPP_ 92 | -------------------------------------------------------------------------------- /include/cudarray/common.hpp: -------------------------------------------------------------------------------- 1 | #ifndef COMMON_HPP_ 2 | #define COMMON_HPP_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | #define CUDA_GRID_STRIDE_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < n; \ 14 | i += blockDim.x * gridDim.x) 15 | 16 | 17 | namespace cudarray { 18 | 19 | 20 | typedef int bool_t; 21 | 22 | const int kNumBuffers = 32; 23 | const int kNumBlockThreads = 512; 24 | 25 | inline int cuda_blocks(int n_threads) { 26 | return (n_threads + kNumBlockThreads - 1) / kNumBlockThreads; 27 | } 28 | 29 | inline void cuda_check(cudaError_t status, const char *file, int line) { 30 | if (status != cudaSuccess) { 31 | std::ostringstream o; 32 | o << file << ":" << line << ": " << cudaGetErrorString(status); 33 | throw std::runtime_error(o.str()); 34 | } 35 | } 36 | 37 | #define CUDA_CHECK(status) { cuda_check((status), __FILE__, __LINE__); } 38 | 39 | inline void cuda_kernel_check(const char *file, int line) { 40 | cudaError_t status = cudaPeekAtLastError(); 41 | if (status != cudaSuccess) { 42 | std::ostringstream o; 43 | o << file << ":" << line << ": " << cudaGetErrorString(status); 44 | throw std::runtime_error(o.str()); 45 | } 46 | } 47 | 48 | #define CUDA_KERNEL_CHECK { cuda_kernel_check(__FILE__, __LINE__); } 49 | 50 | inline void cuda_check_sync(const char *file, const int line) { 51 | cuda_check(cudaDeviceSynchronize(), file, line); 52 | } 53 | 54 | #define CUDA_CHECK_SYNC { cuda_check_sync(__FILE__, __LINE__); } 55 | 56 | 57 | /* 58 | Singleton class to handle CUDA resources. 
59 | */ 60 | class CUDA { 61 | public: 62 | inline static CUDA &instance() { 63 | static CUDA instance_; 64 | return instance_; 65 | } 66 | 67 | /* 68 | Request a memory pointer to device memory 69 | */ 70 | inline static void *buffer(size_t size, unsigned int idx=0) { 71 | if (instance().buffer_sizes[idx] < size) { 72 | if (instance().buffers[idx]) { 73 | CUDA_CHECK(cudaFree(instance().buffers[idx])); 74 | } 75 | instance().buffer_sizes[idx] = size; 76 | CUDA_CHECK(cudaMalloc(&instance().buffers[idx], size)); 77 | } 78 | return instance().buffers[idx]; 79 | } 80 | 81 | private: 82 | void *buffers[kNumBuffers]; 83 | size_t buffer_sizes[kNumBuffers]; 84 | 85 | CUDA() { 86 | for(int i = 0; i < kNumBuffers; i++) { 87 | buffers[i] = NULL; 88 | buffer_sizes[i] = 0; 89 | } 90 | } 91 | 92 | ~CUDA() { 93 | for(int i = 0; i < kNumBuffers; i++) { 94 | if (buffers[i]) { 95 | CUDA_CHECK(cudaFree(instance().buffers[i])); 96 | } 97 | } 98 | } 99 | 100 | CUDA(CUDA const&); 101 | void operator=(CUDA const&); 102 | }; 103 | 104 | 105 | } // cudarray 106 | 107 | #endif // COMMON_HPP_ 108 | -------------------------------------------------------------------------------- /include/cudarray/elementwise.hpp: -------------------------------------------------------------------------------- 1 | #ifndef ELEMENTWISE_HPP_ 2 | #define ELEMENTWISE_HPP_ 3 | 4 | #include 5 | 6 | 7 | namespace cudarray { 8 | 9 | enum BroadcastType { 10 | BROADCAST_INNER, BROADCAST_LEADING, BROADCAST_OUTER, BROADCAST_TRAILING, 11 | }; 12 | 13 | enum BinaryOp { 14 | ADD_OP, DIV_OP, MAX_B_OP, MIN_B_OP, MUL_OP, POW_OP, SUB_OP, 15 | }; 16 | 17 | template 18 | void binary(BinaryOp op, const Ta *a, const Tb *b, unsigned int n, Tc *c); 19 | 20 | template 21 | void binary_scalar(BinaryOp op, const Ta *a, Talpha alpha, unsigned int n, 22 | Tb *b); 23 | 24 | template 25 | void binary_scalar_(BinaryOp op, Talpha alpha, const Ta *a, unsigned int n, 26 | Tb *b); 27 | 28 | template 29 | void binary_broadcast(BinaryOp op, BroadcastType btype, const Ta *a, 30 | const Tb *b, unsigned int k, unsigned int m, unsigned int n, Tc *c); 31 | 32 | 33 | enum BinaryCmpOp { 34 | EQ_OP, GT_OP, GT_EQ_OP, LT_OP, LT_EQ_OP, NEQ_OP, 35 | }; 36 | 37 | template 38 | void binary_cmp(BinaryCmpOp op, const Ta *a, const Tb *b, unsigned int n, 39 | bool_t *c); 40 | 41 | template 42 | void binary_cmp_scalar(BinaryCmpOp op, const T *a, T alpha, 43 | unsigned int n, bool_t *b); 44 | 45 | template 46 | void binary_cmp_scalar_(BinaryCmpOp op, T alpha, const T *a, 47 | unsigned int n, bool_t *b); 48 | 49 | template 50 | void binary_cmp_broadcast(BinaryCmpOp op, BroadcastType btype, const Ta *a, 51 | const Tb *b, unsigned int k, unsigned int m, unsigned int n, bool_t *c); 52 | 53 | 54 | enum UnaryOp { 55 | ABS_OP, COS_OP, EXP_OP, LOG_OP, LOG1P_OP, NEG_OP, SIN_OP, SQRT_OP, TANH_OP, 56 | RELU_OP, RELU_D_OP, SIGMOID_OP, SIGMOID_D_OP, SOFTPLUS_OP, SOFTPLUS_D_OP, 57 | TANH_D_OP, 58 | }; 59 | 60 | template 61 | void unary(UnaryOp op, const T *a, unsigned int n, T *b); 62 | 63 | template 64 | void clip(const T *a, T a_min, T a_max, unsigned int n, T *b); 65 | 66 | } 67 | 68 | #endif // ELEMENTWISE_HPP_ 69 | -------------------------------------------------------------------------------- /include/cudarray/image/img2win.hpp: -------------------------------------------------------------------------------- 1 | #ifndef IMG2WIN_HPP_ 2 | #define IMG2WIN_HPP_ 3 | 4 | namespace cudarray { 5 | 6 | template 7 | void img2win(const T *imgs, int n_imgs, int img_h, int img_w, int win_h, 8 | int win_w, int 
pad_y, int pad_x, int stride_y, int stride_x, T *wins); 9 | 10 | template 11 | void win2img(const T *wins, int n_imgs, int img_h, int img_w, int win_h, 12 | int win_w, int pad_y, int pad_x, int stride_y, int stride_x, T *imgs); 13 | 14 | } 15 | 16 | #endif // IMG2WIN_HPP_ 17 | -------------------------------------------------------------------------------- /include/cudarray/image/rescale.hpp: -------------------------------------------------------------------------------- 1 | #ifndef RESCALE_HPP_ 2 | #define RESCALE_HPP_ 3 | 4 | namespace cudarray { 5 | 6 | enum SampleMethod { 7 | BILINEAR_SAMPLING, NEAREST_SAMPLING, PERFORATED_SAMPLING, 8 | }; 9 | 10 | template 11 | void rescale(const T *imgs, float factor, SampleMethod method, int n_imgs, 12 | int img_h, int img_w, T *imgs_scaled); 13 | 14 | } 15 | 16 | #endif // RESCALE_HPP_ 17 | -------------------------------------------------------------------------------- /include/cudarray/nnet/conv_bc01_matmul.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CONV_BC01_MATMUL_HPP_ 2 | #define CONV_BC01_MATMUL_HPP_ 3 | 4 | namespace cudarray { 5 | 6 | template 7 | void conv_bc01_matmul(const T *imgs, const T *filters, int n_imgs, 8 | int n_channels, int n_filters, int img_h, int img_w, int filter_h, 9 | int filter_w, int pad_y, int pad_x, int stride_y, int stride_x, 10 | T *convout); 11 | 12 | template 13 | void conv_bc01_matmul_bprop_imgs(const T *filters, const T *convout_d, 14 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 15 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 16 | int stride_x, T *imgs_d); 17 | 18 | template 19 | void conv_bc01_matmul_bprop_filters(const T *imgs, const T *convout_d, 20 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 21 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 22 | int stride_x, T *filters_d); 23 | 24 | } 25 | 26 | #endif // CONV_BC01_MATMUL_HPP_ 27 | -------------------------------------------------------------------------------- /include/cudarray/nnet/cudnn.hpp: -------------------------------------------------------------------------------- 1 | #ifdef CUDNN_ENABLED 2 | 3 | #ifndef CUDNN_HPP_ 4 | #define CUDNN_HPP_ 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace cudarray { 11 | 12 | enum PoolMode {POOL_AVG, POOL_MAX}; 13 | 14 | const int MAX_IMG_DIMS = 3; 15 | const int WORKSPACE_LIMIT = 1024*1024*1024; 16 | 17 | template 18 | class PoolBC01CuDNN { 19 | public: 20 | PoolBC01CuDNN(int n_img_dims, int *win_shape, int *padding, int *strides, 21 | PoolMode pool_mode); 22 | ~PoolBC01CuDNN(); 23 | 24 | void fprop(const T *imgs, int *imgs_shape, T *poolout); 25 | 26 | void bprop(const T *imgs, const T* poolout, const T *poolout_d, T *imgs_d); 27 | 28 | private: 29 | int n_img_dims; 30 | int win_shape[MAX_IMG_DIMS]; 31 | int padding[MAX_IMG_DIMS]; 32 | int strides[MAX_IMG_DIMS]; 33 | int imgs_shape[MAX_IMG_DIMS + 2]; 34 | cudnnPoolingMode_t pool_mode; 35 | cudnnTensorDescriptor_t imgs_desc; 36 | cudnnTensorDescriptor_t poolout_desc; 37 | cudnnPoolingDescriptor_t pool_desc; 38 | }; 39 | 40 | 41 | template 42 | class ConvBC01CuDNN { 43 | public: 44 | ConvBC01CuDNN(int pad_y, int pad_x, int stride_y, int stride_x); 45 | ~ConvBC01CuDNN(); 46 | 47 | void fprop(const T *imgs, const T *filters, int n_imgs, int n_channels, 48 | int n_filters, int img_h, int img_w, int filter_h, int filter_w, 49 | T *convout); 50 | 51 | void bprop(const T* imgs, const T* filters, const T *convout_d, T 
*imgs_d, 52 | T *filters_d); 53 | 54 | private: 55 | int pad_y; 56 | int pad_x; 57 | int stride_y; 58 | int stride_x; 59 | int n_imgs; 60 | int n_channels; 61 | int n_filters; 62 | int img_h; 63 | int img_w; 64 | int filter_h; 65 | int filter_w; 66 | cudnnTensorDescriptor_t imgs_desc; 67 | cudnnTensorDescriptor_t convout_desc; 68 | cudnnFilterDescriptor_t filters_desc; 69 | cudnnConvolutionDescriptor_t conv_desc; 70 | cudnnConvolutionFwdAlgo_t fwd_algo; 71 | cudnnConvolutionBwdFilterAlgo_t bwd_filters_algo; 72 | cudnnConvolutionBwdDataAlgo_t bwd_imgs_algo; 73 | size_t workspace_size; 74 | }; 75 | 76 | 77 | const char* cudnn_message(cudnnStatus_t status); 78 | 79 | inline void cudnn_check(cudnnStatus_t status, const char *file, int line) { 80 | if (status != CUDNN_STATUS_SUCCESS) { 81 | std::ostringstream o; 82 | o << file << ":" << line << ": " << cudnn_message(status); 83 | throw std::runtime_error(o.str()); 84 | } 85 | } 86 | 87 | #define CUDNN_CHECK(status) { cudnn_check((status), __FILE__, __LINE__); } 88 | 89 | /* 90 | Singleton class to handle cuDNN resources. 91 | */ 92 | class CUDNN { 93 | public: 94 | static const float one; 95 | static const float zero; 96 | 97 | inline static CUDNN &instance() { 98 | static CUDNN instance_; 99 | return instance_; 100 | } 101 | 102 | inline static cudnnHandle_t &handle() { 103 | return instance().handle_; 104 | } 105 | 106 | private: 107 | cudnnHandle_t handle_; 108 | CUDNN() { 109 | CUDNN_CHECK(cudnnCreate(&handle_)); 110 | } 111 | ~CUDNN() { 112 | CUDNN_CHECK(cudnnDestroy(handle_)); 113 | } 114 | CUDNN(CUDNN const&); 115 | void operator=(CUDNN const&); 116 | }; 117 | 118 | 119 | } // cudarray 120 | 121 | #endif // CUDNN_HPP_ 122 | 123 | #endif // CUDNN_ENABLED 124 | -------------------------------------------------------------------------------- /include/cudarray/nnet/one_hot.hpp: -------------------------------------------------------------------------------- 1 | #ifndef ONE_HOT_HPP_ 2 | #define ONE_HOT_HPP_ 3 | 4 | namespace cudarray { 5 | 6 | template 7 | void one_hot_encode(const int *labels, int n_classes, int n, T *out); 8 | 9 | } 10 | 11 | #endif // ONE_HOT_HPP_ 12 | -------------------------------------------------------------------------------- /include/cudarray/nnet/pool_b01.hpp: -------------------------------------------------------------------------------- 1 | #ifndef POOL_B01_HPP_ 2 | #define POOL_B01_HPP_ 3 | 4 | namespace cudarray { 5 | 6 | template 7 | void max_pool_b01(const T* imgs, int n_imgs, int img_h, int img_w, int win_h, 8 | int win_w, int pad_y, int pad_x, int stride_y, int stride_x, T* poolout, 9 | int* mask); 10 | 11 | template 12 | void max_pool_b01_bprob(const T* poolout_d, const int* mask, int n_imgs, 13 | int img_h, int img_w, int win_h, int win_w, int pad_y, int pad_x, 14 | int stride_y, int stride_x, T* imgs_d); 15 | 16 | template 17 | void avg_pool_b01(const T* imgs, int n_imgs, int img_h, int img_w, int win_h, 18 | int win_w, int pad_y, int pad_x, int stride_y, int stride_x, T* poolout); 19 | 20 | template 21 | void avg_pool_b01_bprob(const T* poolout_d, int n_imgs, int img_h, int img_w, 22 | int win_h, int win_w, int pad_y, int pad_x, int stride_y, int stride_x, 23 | T* imgs_d); 24 | 25 | } 26 | 27 | #endif // POOL_B01_HPP_ 28 | -------------------------------------------------------------------------------- /include/cudarray/random.hpp: -------------------------------------------------------------------------------- 1 | #ifndef RANDOM_HPP_ 2 | #define RANDOM_HPP_ 3 | 4 | #include 5 | #include 6 | #include 7 
| #include 8 | #include 9 | #include 10 | 11 | 12 | namespace cudarray { 13 | 14 | void seed(unsigned long long val); 15 | 16 | template 17 | void random_normal(T *a, T mu, T sigma, unsigned int n); 18 | 19 | template 20 | void random_uniform(T *a, T low, T high, unsigned int n); 21 | 22 | 23 | const char* curand_message(curandStatus_t status); 24 | 25 | inline void curand_check(curandStatus_t status, const char *file, int line) { 26 | if (status != CURAND_STATUS_SUCCESS) { 27 | std::ostringstream o; 28 | o << file << ":" << line << ": " << curand_message(status); 29 | throw std::runtime_error(o.str()); 30 | } 31 | } 32 | 33 | #define CURAND_CHECK(status) { curand_check((status), __FILE__, __LINE__); } 34 | 35 | 36 | /* 37 | Singleton class to handle cuRAND resources. 38 | */ 39 | class CURAND { 40 | public: 41 | inline static CURAND &instance() { 42 | static CURAND instance_; 43 | return instance_; 44 | } 45 | 46 | inline static curandGenerator_t &generator() { 47 | return instance().generator_; 48 | } 49 | 50 | private: 51 | curandGenerator_t generator_; 52 | CURAND() { 53 | CURAND_CHECK(curandCreateGenerator(&generator_, 54 | CURAND_RNG_PSEUDO_DEFAULT)); 55 | std::srand(std::time(NULL)+getpid()); 56 | CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(generator_, std::rand())); 57 | } 58 | ~CURAND() { 59 | } 60 | CURAND(CURAND const&); 61 | void operator=(CURAND const&); 62 | }; 63 | 64 | } 65 | 66 | #endif // RANDOM_HPP_ 67 | -------------------------------------------------------------------------------- /include/cudarray/reduction.hpp: -------------------------------------------------------------------------------- 1 | #ifndef REDUCTION_HPP_ 2 | #define REDUCTION_HPP_ 3 | 4 | namespace cudarray { 5 | 6 | enum ReduceOp { 7 | MAX_OP, MEAN_OP, MIN_OP, SUM_OP 8 | }; 9 | 10 | enum ReduceToIntOp { 11 | ARGMAX_OP, ARGMIN_OP 12 | }; 13 | 14 | template 15 | void reduce(ReduceOp op, const T *a, unsigned int n, T *b); 16 | 17 | template 18 | void reduce_mat(ReduceOp op, const T *a, unsigned int m, unsigned int n, 19 | bool reduce_leading, T *b); 20 | 21 | template 22 | void reduce_to_int(ReduceToIntOp op, const T *a, unsigned int n, int *b); 23 | 24 | template 25 | void reduce_mat_to_int(ReduceToIntOp op, const T *a, unsigned int m, 26 | unsigned int n, bool reduce_leading, int *b); 27 | 28 | } 29 | 30 | #endif // REDUCTION_HPP_ 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython>=0.21 2 | numpy>=1.8 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import glob 5 | import re 6 | import numpy 7 | 8 | from setuptools import setup, find_packages, Feature, Command 9 | from Cython.Build import cythonize 10 | from Cython.Distutils import build_ext 11 | from Cython.Distutils.extension import Extension 12 | 13 | 14 | def read(fname): 15 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 16 | 17 | 18 | def cuda_extensions(): 19 | cuda_dir = os.getenv('CUDA_PREFIX', '/usr/local/cuda') 20 | cuda_include_dir = os.path.join(cuda_dir, 'include') 21 | cuda_library_dir = os.path.join(cuda_dir, 'lib64') 22 | if not os.path.exists(cuda_library_dir): 23 | # Use lib if lib64 does not exist 24 | cuda_library_dir = os.path.join(cuda_dir, 'lib') 25 | 26 | library_dirs = [cuda_library_dir] 
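    # Descriptive note (not in the original source): the extension build is
    # configured through environment variables. CUDA_PREFIX locates the CUDA
    # toolkit (preferring lib64/ over lib/ when it exists), INSTALL_PREFIX
    # optionally adds <prefix>/lib to the library search path (where
    # libcudarray is expected to be installed), and CUDNN_ENABLED=1 builds
    # and links the extra cuDNN wrapper extension further below.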
27 | prefix = os.getenv('INSTALL_PREFIX') 28 | if prefix is not None: 29 | library_dirs.append(os.path.join(prefix, 'lib')) 30 | 31 | cudarray_dir = './cudarray' 32 | cudarray_include_dir = './include' 33 | include_dirs = [cuda_include_dir, cudarray_include_dir, 34 | numpy.get_include()] 35 | cython_include_dirs = ['./cudarray/wrap'] 36 | extra_compile_args = ['-O3', '-fPIC', '-Wall', '-Wfatal-errors'] 37 | libraries = ['cudart', 'cudarray'] 38 | extra_link_args = ['-fPIC'] 39 | language = 'c++' 40 | 41 | def make_extension(name): 42 | return Extension( 43 | name='cudarray.wrap.' + name, 44 | sources=[os.path.join(cudarray_dir, 'wrap', name + '.pyx')], 45 | language=language, 46 | include_dirs=include_dirs, 47 | cython_include_dirs=cython_include_dirs, 48 | extra_compile_args=extra_compile_args, 49 | library_dirs=library_dirs, 50 | libraries=libraries, 51 | extra_link_args=extra_link_args, 52 | ) 53 | ext_names = ['cudart', 'array_data', 'array_ops', 'elementwise', 54 | 'reduction', 'blas', 'random', 'nnet', 'image'] 55 | exts = list(map(make_extension, ext_names)) 56 | 57 | if os.getenv('CUDNN_ENABLED') == '1': 58 | cudnn_ext = Extension( 59 | name='cudarray.wrap.cudnn', 60 | sources=[os.path.join(cudarray_dir, 'wrap', 'cudnn.pyx')], 61 | language=language, 62 | include_dirs=include_dirs, 63 | cython_include_dirs=cython_include_dirs, 64 | extra_compile_args=['-DCUDNN_ENABLED'] + extra_compile_args, 65 | library_dirs=library_dirs, 66 | libraries=libraries+['cudnn'], 67 | extra_link_args=extra_link_args, 68 | ) 69 | exts.append(cudnn_ext) 70 | return exts 71 | 72 | 73 | def numpy_extensions(): 74 | cython_srcs = [ 75 | 'cudarray/numpy_backend/nnet/conv_bc01.pyx', 76 | 'cudarray/numpy_backend/nnet/pool_bc01.pyx', 77 | 'cudarray/numpy_backend/nnet/lrnorm_bc01.pyx', 78 | ] 79 | return cythonize(cython_srcs, include_path=[numpy.get_include()]) 80 | 81 | 82 | class Clean(Command): 83 | description = 'Remove Cython generated files.' 
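    # Descriptive note (not in the original source): Clean follows the
    # standard distutils/setuptools Command protocol. user_options lists the
    # command-line flags the command accepts (none here),
    # initialize_options() and finalize_options() must be defined even when
    # they have nothing to do, and run() performs the actual work of removing
    # the files Cython generated next to each .pyx source.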
84 | user_options = [] 85 | 86 | def initialize_options(self): 87 | self.report = 'clean' 88 | 89 | def finalize_options(self): 90 | pass 91 | 92 | def run(self): 93 | for ext in (cuda_extensions() + numpy_extensions()): 94 | for pyx in ext.sources: 95 | if pyx.endswith('.pyx'): 96 | prefix = pyx[:-4] 97 | for suffix in ['.c', '.cpp', '.so', '.dll']: 98 | files = glob.glob(prefix + '*' + suffix) 99 | for f in files: 100 | if os.path.exists(f): 101 | os.unlink(f) 102 | 103 | 104 | with open('requirements.txt') as f: 105 | install_requires = [l.strip() for l in f] 106 | setup_requires = [r for r in install_requires if r.startswith('cython')] 107 | 108 | 109 | version = None 110 | regex = re.compile(r'''^__version__ = ['"]([^'"]*)['"]''') 111 | with open(os.path.join('cudarray', '__init__.py')) as f: 112 | for line in f: 113 | mo = regex.search(line) 114 | if mo is not None: 115 | version = mo.group(1) 116 | break 117 | if version is None: 118 | raise RuntimeError('Could not find version number') 119 | 120 | 121 | setup( 122 | name='cudarray', 123 | version=version, 124 | author='Anders Boesen Lindbo Larsen', 125 | author_email='abll@dtu.dk', 126 | description='CUDA-based Numpy array and operations', 127 | license='MIT', 128 | url='http://compute.dtu.dk/~abll', 129 | packages=find_packages(), 130 | setup_requires=setup_requires, 131 | install_requires=install_requires, 132 | long_description=read('README.md'), 133 | classifiers=[ 134 | 'Development Status :: 4 - Beta', 135 | 'Intended Audience :: Developers', 136 | 'Intended Audience :: Science/Research', 137 | 'License :: OSI Approved :: MIT License', 138 | 'Operating System :: OS Independent', 139 | 'Programming Language :: Python', 140 | 'Topic :: Scientific/Engineering', 141 | ], 142 | include_dirs=[numpy.get_include()], 143 | features={ 144 | 'cuda': Feature( 145 | description='CUDA back-end', 146 | standard=True, 147 | remove=['cudarray.wrap'], 148 | ext_modules=cuda_extensions(), 149 | ), 150 | 'numpy': Feature( 151 | description='Numpy back-end', 152 | standard=True, 153 | remove=['cudarray.numpy_backend'], 154 | ext_modules=numpy_extensions(), 155 | ), 156 | }, 157 | cmdclass={ 158 | 'build_ext': build_ext, 159 | 'clean': Clean, 160 | }, 161 | zip_safe=False, 162 | ) 163 | -------------------------------------------------------------------------------- /src/array_ops.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cudarray/common.hpp" 3 | #include "cudarray/array_ops.hpp" 4 | 5 | 6 | const int TILE_DIM = 32; 7 | const int BLOCK_ROWS = 8; 8 | 9 | namespace cudarray { 10 | 11 | 12 | template 13 | __global__ void kernel_concatenate( 14 | const T *a, const T *b, unsigned int d0, unsigned int d1, unsigned int d2, 15 | unsigned int da, unsigned int db, T *c 16 | ) { 17 | CUDA_GRID_STRIDE_LOOP(c_idx, d0*d1*d2) { 18 | unsigned int c2_idx = c_idx % d2; 19 | unsigned int c1_idx = (c_idx / d2) % d1; 20 | unsigned int c0_idx = c_idx / d2 / d1; 21 | unsigned int a_idx; 22 | unsigned int b_idx; 23 | bool from_a; 24 | if (axis == 0) { 25 | a_idx = (c0_idx*d1 + c1_idx)*d2 + c2_idx; 26 | b_idx = ((c0_idx-da)*d1 + c1_idx)*d2 + c2_idx; 27 | from_a = c0_idx < da; 28 | } 29 | if (axis == 1) { 30 | a_idx = (c0_idx*da + c1_idx)*d2 + c2_idx; 31 | b_idx = (c0_idx*db + (c1_idx-da))*d2 + c2_idx; 32 | from_a = c1_idx < da; 33 | } 34 | if (axis == 2) { 35 | a_idx = (c0_idx*d1 + c1_idx)*da + c2_idx; 36 | b_idx = (c0_idx*d1 + c1_idx)*db + (c2_idx-da); 37 | from_a = c2_idx < da; 38 | } 39 | c[c_idx] = 
from_a ? a[a_idx] : b[b_idx]; 40 | } 41 | } 42 | 43 | template 44 | void concatenate( 45 | const T *a, const T *b, unsigned int axis, unsigned int d0, 46 | unsigned int d1, unsigned int d2, unsigned int da, unsigned int db, T *c 47 | ) { 48 | unsigned int n = d0*d1*d2; 49 | if (axis == 0) { 50 | kernel_concatenate<<>>( 51 | a, b, d0, d1, d2, da, db, c 52 | ); 53 | } else if (axis == 1) { 54 | kernel_concatenate<<>>( 55 | a, b, d0, d1, d2, da, db, c 56 | ); 57 | } else if (axis == 2) { 58 | kernel_concatenate<<>>( 59 | a, b, d0, d1, d2, da, db, c 60 | ); 61 | } else { 62 | throw std::runtime_error("invalid axis"); 63 | } 64 | CUDA_KERNEL_CHECK; 65 | } 66 | 67 | template void concatenate( 68 | const float *a, const float *b, unsigned int axis, unsigned int d0, 69 | unsigned int d1, unsigned int d2, unsigned int da, unsigned int db, 70 | float *c 71 | ); 72 | template void concatenate( 73 | const int *a, const int *b, unsigned int axis, unsigned int d0, 74 | unsigned int d1, unsigned int d2, unsigned int da, unsigned int db, 75 | int *c 76 | ); 77 | 78 | 79 | template 80 | __global__ void kernel_split( 81 | const T *c, unsigned int d0, unsigned int d1, unsigned int d2, 82 | unsigned int da, unsigned int db, T *a, T *b 83 | ) { 84 | CUDA_GRID_STRIDE_LOOP(c_idx, d0*d1*d2) { 85 | unsigned int c2_idx = c_idx % d2; 86 | unsigned int c1_idx = (c_idx / d2) % d1; 87 | unsigned int c0_idx = c_idx / d2 / d1; 88 | unsigned int a_idx; 89 | unsigned int b_idx; 90 | bool from_a; 91 | if (axis == 0) { 92 | a_idx = (c0_idx*d1 + c1_idx)*d2 + c2_idx; 93 | b_idx = ((c0_idx-da)*d1 + c1_idx)*d2 + c2_idx; 94 | from_a = c0_idx < da; 95 | } 96 | if (axis == 1) { 97 | a_idx = (c0_idx*da + c1_idx)*d2 + c2_idx; 98 | b_idx = (c0_idx*db + (c1_idx-da))*d2 + c2_idx; 99 | from_a = c1_idx < da; 100 | } 101 | if (axis == 2) { 102 | a_idx = (c0_idx*d1 + c1_idx)*da + c2_idx; 103 | b_idx = (c0_idx*d1 + c1_idx)*db + (c2_idx-da); 104 | from_a = c2_idx < da; 105 | } 106 | T val = c[c_idx]; 107 | if (from_a) { 108 | a[a_idx] = val; 109 | } else { 110 | b[b_idx] = val; 111 | } 112 | } 113 | } 114 | 115 | 116 | template 117 | void split( 118 | const T *c, unsigned int axis, unsigned int d0, unsigned int d1, 119 | unsigned int d2, unsigned int da, unsigned int db, T *a, T *b 120 | ) { 121 | unsigned int n = d0*d1*d2; 122 | if (axis == 0) { 123 | kernel_split<<>>( 124 | c, d0, d1, d2, da, db, a, b 125 | ); 126 | } else if (axis == 1) { 127 | kernel_split<<>>( 128 | c, d0, d1, d2, da, db, a, b 129 | ); 130 | } else if (axis == 2) { 131 | kernel_split<<>>( 132 | c, d0, d1, d2, da, db, a, b 133 | ); 134 | } else { 135 | throw std::runtime_error("invalid axis"); 136 | } 137 | CUDA_KERNEL_CHECK; 138 | } 139 | 140 | template void split( 141 | const float *c, unsigned int axis, unsigned int d0, unsigned int d1, 142 | unsigned int d2, unsigned int da, unsigned int db, float *a, float *b 143 | ); 144 | template void split( 145 | const int *c, unsigned int axis, unsigned int d0, unsigned int d1, 146 | unsigned int d2, unsigned int da, unsigned int db, int *a, int *b 147 | ); 148 | 149 | 150 | // Adapted from 151 | // http://devblogs.nvidia.com/parallelforall/efficient-matrix-transpose-cuda-cc/ 152 | template 153 | __global__ void kernel_transpose(const T *a, unsigned int m, unsigned int n, 154 | T *b) { 155 | __shared__ T tile[TILE_DIM][TILE_DIM+1]; 156 | 157 | int x = blockIdx.x * blockDim.x + threadIdx.x; 158 | int y = blockIdx.y * TILE_DIM + threadIdx.y; 159 | for (int i = 0; i < TILE_DIM; i += blockDim.y) { 160 | int y_ = y + i; 161 | if 
(mTileMultiple || y_ < m) { 162 | if (nTileMultiple || x < n) { 163 | tile[threadIdx.y + i][threadIdx.x] = a[y_*n + x]; 164 | } 165 | } 166 | } 167 | __syncthreads(); 168 | 169 | x = blockIdx.y * blockDim.x + threadIdx.x; 170 | y = blockIdx.x * TILE_DIM + threadIdx.y; 171 | for (int i = 0; i < TILE_DIM; i += blockDim.y) { 172 | int y_ = y + i; 173 | if (nTileMultiple || y_ < n) { 174 | if (mTileMultiple || x < m) { 175 | b[y_*m + x] = tile[threadIdx.x][threadIdx.y + i]; 176 | } 177 | } 178 | } 179 | } 180 | 181 | 182 | #define ceildiv(a, b) (((a)+(b)-1)/(b)) 183 | 184 | template 185 | void transpose(const T *a, unsigned int m, unsigned int n, T *b) { 186 | dim3 blocks(ceildiv(n,TILE_DIM), ceildiv(m,TILE_DIM), 1); 187 | dim3 threads(TILE_DIM, BLOCK_ROWS, 1); 188 | if (m % TILE_DIM) { 189 | if (n % TILE_DIM) { 190 | kernel_transpose<<>>(a, m, n, b); 191 | } else { 192 | kernel_transpose<<>>(a, m, n, b); 193 | } 194 | } else { 195 | if (n % TILE_DIM) { 196 | kernel_transpose<<>>(a, m, n, b); 197 | } else { 198 | kernel_transpose<<>>(a, m, n, b); 199 | } 200 | } 201 | CUDA_KERNEL_CHECK; 202 | } 203 | 204 | template void transpose(const int *a, unsigned int m, unsigned int n, 205 | int *b); 206 | template void transpose(const float *a, unsigned int m, unsigned int n, 207 | float *b); 208 | 209 | 210 | template 211 | __global__ void kernel_as(const Ta *a, unsigned int n, Tb *b) { 212 | CUDA_GRID_STRIDE_LOOP(idx, n) { 213 | b[idx] = (Tb) a[idx]; 214 | } 215 | } 216 | 217 | template 218 | void as(const Ta *a, unsigned int n, Tb *b) { 219 | kernel_as<<>>(a, n, b); 220 | CUDA_KERNEL_CHECK; 221 | } 222 | 223 | template void as(const int *a, unsigned int n, float *b); 224 | template void as(const float *a, unsigned int n, int *b); 225 | 226 | 227 | template 228 | __global__ void kernel_fill(T *a, unsigned int n, T alpha) { 229 | CUDA_GRID_STRIDE_LOOP(idx, n) { 230 | a[idx] = alpha; 231 | } 232 | } 233 | 234 | template 235 | void fill(T *a, unsigned int n, T alpha) { 236 | kernel_fill<<>>(a, n, alpha); 237 | CUDA_KERNEL_CHECK; 238 | } 239 | 240 | template void fill(int *a, unsigned int n, int alpha); 241 | template void fill(float *a, unsigned int n, float alpha); 242 | 243 | 244 | template 245 | void copy(const T *a, unsigned int n, T *b) { 246 | CUDA_CHECK(cudaMemcpy(b, a, n*sizeof(T), cudaMemcpyDeviceToDevice)); 247 | } 248 | 249 | template void copy(const int *a, unsigned int n, int *b); 250 | template void copy(const float *a, unsigned int n, float *b); 251 | 252 | 253 | template 254 | void to_device(const T *a, unsigned int n, T *b) { 255 | CUDA_CHECK(cudaMemcpy(b, a, n*sizeof(T), cudaMemcpyHostToDevice)); 256 | } 257 | 258 | template void to_device(const int *a, unsigned int n, int *b); 259 | template void to_device(const float *a, unsigned int n, float *b); 260 | 261 | 262 | template 263 | void to_host(const T *a, unsigned int n, T *b) { 264 | CUDA_CHECK(cudaMemcpy(b, a, n*sizeof(T), cudaMemcpyDeviceToHost)); 265 | } 266 | 267 | template void to_host(const int *a, unsigned int n, int *b); 268 | template void to_host(const float *a, unsigned int n, float *b); 269 | 270 | } 271 | -------------------------------------------------------------------------------- /src/blas.cpp: -------------------------------------------------------------------------------- 1 | #include "cudarray/common.hpp" 2 | #include "cudarray/blas.hpp" 3 | 4 | namespace cudarray { 5 | 6 | 7 | cublasStatus_t cublas_dot(cublasHandle_t handle, int n, const float *x, 8 | int incx, const float *y, int incy, float *result) { 9 | 
return cublasSdot(handle, n, x, incx, y, incy, result); 10 | } 11 | 12 | template 13 | T dot(const T *a, const T *b, unsigned int n) { 14 | T result; 15 | CUBLAS_CHECK(cublas_dot(CUBLAS::handle(), n, a, 1, b, 1, &result)); 16 | return result; 17 | } 18 | 19 | template float dot(const float *x, const float *y, unsigned int n); 20 | 21 | 22 | 23 | cublasStatus_t cublas_gemv(cublasHandle_t handle, cublasOperation_t trans, 24 | int m, int n, const float *alpha, const float *A, int lda, const float *x, 25 | int incx, const float *beta, float *y, int incy) { 26 | return cublasSgemv(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, 27 | incy); 28 | } 29 | 30 | template 31 | void gemv(const T *A, const T *b, TransposeOp trans, unsigned int m, 32 | unsigned int n, T alpha, T beta, T *c) { 33 | cublasOperation_t cuTrans; 34 | if (trans == OP_TRANS) { 35 | cuTrans = CUBLAS_OP_N; 36 | unsigned int tmp = n; 37 | n = m; 38 | m = tmp; 39 | } else { 40 | cuTrans = CUBLAS_OP_T; 41 | } 42 | int lda = n; 43 | CUBLAS_CHECK(cublas_gemv(CUBLAS::handle(), cuTrans, n, m, &alpha, A, lda, b, 44 | 1, &beta, c, 1)); 45 | } 46 | 47 | template void gemv(const float *A, const float *b, TransposeOp trans, 48 | unsigned int m, unsigned int n, float alpha, float beta, float *c); 49 | 50 | 51 | 52 | cublasStatus_t cublas_gemm(cublasHandle_t handle, cublasOperation_t transA, 53 | cublasOperation_t transB, int m, int n, int k, const float *alpha, 54 | const float *A, int lda, const float *B, int ldb, const float *beta, 55 | float *C, int ldc) { 56 | return cublasSgemm(handle, transA, transB, m, n, k, alpha, A, lda, B, ldb, 57 | beta, C, ldc); 58 | } 59 | 60 | template 61 | void gemm(const T *A, const T *B, TransposeOp transA, TransposeOp transB, 62 | unsigned int m, unsigned int n, unsigned int k, T alpha, T beta, 63 | T *C) { 64 | int lda = (transA == OP_NO_TRANS) ? k : m; 65 | int ldb = (transB == OP_NO_TRANS) ? 
n : k; 66 | int ldc = n; 67 | cublasOperation_t cuTransA = (cublasOperation_t) transA; 68 | cublasOperation_t cuTransB = (cublasOperation_t) transB; 69 | CUBLAS_CHECK(cublas_gemm(CUBLAS::handle(), cuTransB, cuTransA, n, m, k, 70 | &alpha, B, ldb, A, lda, &beta, C, ldc)); 71 | } 72 | 73 | template void gemm(const float *A, const float *B, TransposeOp transA, 74 | TransposeOp transB, unsigned int m, unsigned int n, unsigned int k, 75 | float alpha, float beta, float *C); 76 | 77 | 78 | template 79 | T **dev_ptrs(const T *base, int num, int stride) { 80 | T **ptrs_host = new T*[num]; 81 | int idx = 0; 82 | for(int n = 0; n < num; ++n){ 83 | ptrs_host[idx] = (T *) base + n * stride; 84 | idx++; 85 | } 86 | T **ptrs_dev; 87 | CUDA_CHECK(cudaMalloc((void **) &ptrs_dev, num*sizeof(T *))); 88 | CUDA_CHECK(cudaMemcpy(ptrs_dev, ptrs_host, num*sizeof(T *), 89 | cudaMemcpyHostToDevice)); 90 | delete []ptrs_host; 91 | return ptrs_dev; 92 | } 93 | 94 | 95 | template 96 | BLASBatch::BLASBatch(const T **As, const T **Bs, T **Cs, 97 | unsigned int batch_size) : batch_size(batch_size) { 98 | size_t ptrs_size = batch_size * sizeof(T **); 99 | CUDA_CHECK(cudaMalloc((void **) &As_dev, ptrs_size)); 100 | CUDA_CHECK(cudaMemcpy(As_dev, As, batch_size*sizeof(float *), 101 | cudaMemcpyHostToDevice)); 102 | CUDA_CHECK(cudaMalloc((void **) &Bs_dev, ptrs_size)); 103 | CUDA_CHECK(cudaMemcpy(Bs_dev, Bs, batch_size*sizeof(float *), 104 | cudaMemcpyHostToDevice)); 105 | CUDA_CHECK(cudaMalloc((void **) &Cs_dev, ptrs_size)); 106 | CUDA_CHECK(cudaMemcpy(Cs_dev, Cs, batch_size*sizeof(float *), 107 | cudaMemcpyHostToDevice)); 108 | } 109 | 110 | 111 | template 112 | BLASBatch::BLASBatch(const T *A, const T *B, T *C, 113 | unsigned int batch_size, int Astride, int Bstride, int Cstride) 114 | : batch_size(batch_size) { 115 | As_dev = (const float **) dev_ptrs(A, batch_size, Astride); 116 | Bs_dev = (const float **) dev_ptrs(B, batch_size, Bstride); 117 | Cs_dev = dev_ptrs(C, batch_size, Cstride); 118 | } 119 | 120 | template 121 | BLASBatch::~BLASBatch() { 122 | CUDA_CHECK(cudaFree(As_dev)); 123 | CUDA_CHECK(cudaFree(Bs_dev)); 124 | CUDA_CHECK(cudaFree(Cs_dev)); 125 | } 126 | 127 | 128 | cublasStatus_t cublas_gemm_batched(cublasHandle_t handle, 129 | cublasOperation_t transA, cublasOperation_t transB, int m, int n, int k, 130 | const float *alpha, const float *Aarray[], int lda, const float *Barray[], 131 | int ldb, const float *beta, float *Carray[], int ldc, int batchCount) { 132 | return cublasSgemmBatched(handle, transA, transB, m, n, k, alpha, Aarray, 133 | lda, Barray, ldb, beta, Carray, ldc, batchCount); 134 | } 135 | 136 | template 137 | void BLASBatch::gemm(TransposeOp transA, TransposeOp transB, unsigned int m, 138 | unsigned int n, unsigned int k, T alpha, T beta) { 139 | int lda = (transA == OP_NO_TRANS) ? k : m; 140 | int ldb = (transB == OP_NO_TRANS) ? 
n : k; 141 | int ldc = n; 142 | cublasOperation_t cuTransA = (cublasOperation_t) transA; 143 | cublasOperation_t cuTransB = (cublasOperation_t) transB; 144 | CUBLAS_CHECK(cublas_gemm_batched(CUBLAS::handle(), cuTransB, cuTransA, n, m, 145 | k, &alpha, Bs_dev, ldb, As_dev, lda, &beta, Cs_dev, ldc, batch_size)); 146 | } 147 | 148 | template class BLASBatch; 149 | 150 | 151 | const char *cublas_message(cublasStatus_t status){ 152 | switch(status) { 153 | case CUBLAS_STATUS_SUCCESS: 154 | return "The operation completed successfully."; 155 | case CUBLAS_STATUS_NOT_INITIALIZED: 156 | return "The cuBLAS library was not initialized."; 157 | case CUBLAS_STATUS_ALLOC_FAILED: 158 | return "Resource allocation failed inside the cuBLAS library."; 159 | case CUBLAS_STATUS_INVALID_VALUE: 160 | return "An unsupported value or parameter was passed to the function."; 161 | case CUBLAS_STATUS_ARCH_MISMATCH: 162 | return "The function requires a feature absent from the GPU."; 163 | case CUBLAS_STATUS_MAPPING_ERROR: 164 | return "An access to GPU memory space failed."; 165 | case CUBLAS_STATUS_EXECUTION_FAILED: 166 | return "The GPU program failed to execute."; 167 | case CUBLAS_STATUS_INTERNAL_ERROR: 168 | return "An internal cuBLAS operation failed."; 169 | // case CUBLAS_STATUS_NOT_SUPPORTED: 170 | // return "The functionnality requested is not supported."; 171 | // case CUBLAS_STATUS_LICENSE_ERROR: 172 | // return "The functionality requested requires some license."; 173 | default: 174 | throw std::runtime_error("invalid cublasStatus_t"); 175 | } 176 | } 177 | 178 | } 179 | -------------------------------------------------------------------------------- /src/image/img2win.cu: -------------------------------------------------------------------------------- 1 | #include "cudarray/common.hpp" 2 | #include "cudarray/image/img2win.hpp" 3 | 4 | 5 | namespace cudarray { 6 | 7 | inline static int ceil_div(int x, int y) { 8 | return (x + y - 1) / y; 9 | } 10 | 11 | template 12 | __global__ void kernel_img2win(const T *imgs, int n_threads, int n_imgs, 13 | int img_h, int img_w, int wins_h, int wins_w, int win_h, int win_w, 14 | int pad_y, int pad_x, int stride_y, int stride_x, T *wins) { 15 | int win_size = win_h*win_w; 16 | CUDA_GRID_STRIDE_LOOP(idx, n_threads) { 17 | int wins_x = idx % wins_w; 18 | int wins_y = (idx / wins_w) % wins_h; 19 | // window offset 20 | int k = (idx / wins_w / wins_h) % win_size; 21 | // image idx 22 | int n = idx / wins_w / wins_h / win_size * group_size; 23 | 24 | int img_x = wins_x * stride_x - pad_x + (k % win_w); 25 | int img_y = wins_y * stride_y - pad_y + k / win_w; 26 | imgs += (n*img_h + img_y)*img_w + img_x; 27 | wins += ((n*win_size + k)*wins_h + wins_y)*wins_w + wins_x; 28 | bool valid = img_x >= 0 && img_x < img_w && img_y >= 0 && img_y < img_h; 29 | 30 | for (int i = 0; i < group_size; ++i) { 31 | if (i+n < n_imgs) { 32 | if (valid) { 33 | *wins = *imgs; 34 | } else { 35 | *wins = 0.0; 36 | } 37 | } 38 | wins += win_size * wins_h * wins_w; 39 | imgs += img_h * img_w; 40 | } 41 | } 42 | } 43 | 44 | template 45 | void img2win(const T *imgs, int n_imgs, int img_h, int img_w, int win_h, 46 | int win_w, int pad_y, int pad_x, int stride_y, int stride_x, T *wins) { 47 | int wins_h = (img_h + 2*pad_y - win_h) / stride_y + 1; 48 | int wins_w = (img_w + 2*pad_x - win_w) / stride_x + 1; 49 | const int group_size = 32; 50 | int n_threads = ceil_div(n_imgs, group_size)*win_h*win_w*wins_h*wins_w; 51 | kernel_img2win 52 | <<>>( 53 | imgs, n_threads, n_imgs, img_h, img_w, wins_h, wins_w, 
win_h, win_w, 54 | pad_y, pad_x, stride_y, stride_x, wins 55 | ); 56 | CUDA_KERNEL_CHECK; 57 | } 58 | template void img2win(const float *imgs, int n_imgs, int img_h, int img_w, 59 | int win_h, int win_w, int pad_y, int pad_x, int stride_y, int stride_x, 60 | float *wins); 61 | 62 | 63 | 64 | template 65 | __global__ void kernel_win2img(const T* wins, int n_threads, int n_imgs, 66 | int img_h, int img_w, int wins_h, int wins_w, int win_h, int win_w, 67 | int pad_y, int pad_x, int stride_y, int stride_x, T *imgs) { 68 | CUDA_GRID_STRIDE_LOOP(idx, n_threads) { 69 | int img_x = idx % img_w + pad_x; 70 | int img_y = (idx / img_w) % img_h + pad_y; 71 | int n = idx / img_w / img_h; 72 | 73 | int wins_x_start = (img_x < win_w) ? 0 : (img_x - win_w) / stride_x + 1; 74 | int wins_x_end = min(img_x / stride_x + 1, wins_w); 75 | int wins_y_start = (img_y < win_h) ? 0 : (img_y - win_h) / stride_y + 1; 76 | int wins_y_end = min(img_y / stride_y + 1, wins_h); 77 | 78 | int wins_y_offset = (1 - stride_y * win_w * wins_h) * wins_w; 79 | int wins_x_offset = (1 - stride_x * wins_h * wins_w); 80 | 81 | wins += (n * win_h * win_w + img_y * win_w + img_x) * wins_h * wins_w; 82 | T sum = 0; 83 | for (int wins_y = wins_y_start; wins_y < wins_y_end; ++wins_y) { 84 | for (int wins_x = wins_x_start; wins_x < wins_x_end; ++wins_x) { 85 | sum += wins[wins_y * wins_y_offset + wins_x * wins_x_offset]; 86 | } 87 | } 88 | imgs[idx] = sum; 89 | } 90 | } 91 | 92 | template 93 | void win2img(const T *wins, int n_imgs, int img_h, int img_w, int win_h, 94 | int win_w, int pad_y, int pad_x, int stride_y, int stride_x, T *imgs) { 95 | int wins_h = (img_h + 2*pad_y - win_h) / stride_y + 1; 96 | int wins_w = (img_w + 2*pad_x - win_w) / stride_x + 1; 97 | int n_threads = n_imgs * img_h * img_w; 98 | kernel_win2img<<>>( 99 | wins, n_threads, n_imgs, img_h, img_w, wins_h, wins_w, win_h, win_w, 100 | pad_y, pad_x, stride_y, stride_x, imgs); 101 | CUDA_KERNEL_CHECK; 102 | } 103 | 104 | template void win2img(const float *wins, int n_imgs, int img_h, int img_w, 105 | int win_h, int win_w, int pad_y, int pad_x, int stride_y, int stride_x, 106 | float *imgs); 107 | } 108 | -------------------------------------------------------------------------------- /src/image/rescale.cu: -------------------------------------------------------------------------------- 1 | #include "cudarray/common.hpp" 2 | #include "cudarray/image/rescale.hpp" 3 | #include 4 | 5 | namespace cudarray { 6 | 7 | template 8 | __global__ void kernel_rescale_bilinear( 9 | const T * __restrict__ imgs, float factor, int n_imgs, int img_h, 10 | int img_w, int scaled_h, int scaled_w, T * __restrict__ imgs_scaled) { 11 | CUDA_GRID_STRIDE_LOOP(idx, n_imgs*scaled_h*scaled_w) { 12 | int x = idx % scaled_w; 13 | int y = (idx / scaled_w) % scaled_h; 14 | int n = idx / scaled_w / scaled_h; 15 | float img_x = (x+0.5) / (img_w*factor) * (img_w - 1); 16 | float img_y = (y+0.5) / (img_h*factor) * (img_h - 1); 17 | int img_x0 = max((int) floor(img_x), 0); 18 | int img_y0 = max((int) floor(img_y), 0); 19 | int img_x1 = min((int) img_x+1, img_w-1); 20 | int img_y1 = min((int) img_y+1, img_h-1); 21 | 22 | T val_00 = imgs[(n*img_h + img_y0)*img_w + img_x0]; 23 | T val_01 = imgs[(n*img_h + img_y0)*img_w + img_x1]; 24 | T val_10 = imgs[(n*img_h + img_y1)*img_w + img_x0]; 25 | T val_11 = imgs[(n*img_h + img_y1)*img_w + img_x1]; 26 | 27 | float a = img_x - img_x0; 28 | float b = img_y - img_y0; 29 | 30 | T val_0 = a*val_01 + (1.0 - a)*val_00; 31 | T val_1 = a*val_11 + (1.0 - a)*val_10; 32 | T val = 
b*val_1 + (1.0 - b)*val_0; 33 | 34 | imgs_scaled[(n*scaled_h + y)*scaled_w + x] = val; 35 | } 36 | } 37 | 38 | template 39 | __global__ void kernel_upsample_perforated( 40 | const T *imgs, int factor, int n_imgs, int img_h, int img_w, 41 | int scaled_h, int scaled_w, T *imgs_scaled) { 42 | CUDA_GRID_STRIDE_LOOP(idx, n_imgs*scaled_h*scaled_w) { 43 | int x = idx % scaled_w; 44 | int y = (idx / scaled_w) % scaled_h; 45 | int n = idx / scaled_w / scaled_h; 46 | T val; 47 | if (x % factor || y % factor) { 48 | val = 0; 49 | } else { 50 | int img_x = x / factor; 51 | int img_y = y / factor; 52 | val = imgs[(n*img_h + img_y)*img_w + img_x]; 53 | } 54 | imgs_scaled[(n*scaled_h + y)*scaled_w + x] = val; 55 | } 56 | } 57 | 58 | 59 | template 60 | __global__ void kernel_rescale_nearest( 61 | const T *imgs, float factor, int n_imgs, int img_h, int img_w, 62 | int scaled_h, int scaled_w, T *imgs_scaled) { 63 | CUDA_GRID_STRIDE_LOOP(idx, n_imgs*scaled_h*scaled_w) { 64 | int x = idx % scaled_w; 65 | int y = (idx / scaled_w) % scaled_h; 66 | int n = idx / scaled_w / scaled_h; 67 | int img_x = floor(x / factor); 68 | int img_y = floor(y / factor); 69 | T val = imgs[(n*img_h + img_y)*img_w + img_x]; 70 | imgs_scaled[(n*scaled_h + y)*scaled_w + x] = val; 71 | } 72 | } 73 | 74 | 75 | template 76 | void rescale(const T *imgs, float factor, SampleMethod method, int n_imgs, 77 | int img_h, int img_w, T *imgs_scaled) { 78 | if (factor <= 0) { 79 | throw std::runtime_error("Factor must be positive."); 80 | } 81 | int scaled_h; 82 | int scaled_w; 83 | if (factor < 1) { 84 | scaled_h = ceil(img_h*factor); 85 | scaled_w = ceil(img_w*factor); 86 | } else { 87 | scaled_h = floor(img_h*factor); 88 | scaled_w = floor(img_w*factor); 89 | } 90 | int n_threads = n_imgs * scaled_h * scaled_w; 91 | switch(method) { 92 | case BILINEAR_SAMPLING: 93 | kernel_rescale_bilinear<<>>( 94 | imgs, factor, n_imgs, img_h, img_w, scaled_h, scaled_w, imgs_scaled 95 | ); 96 | break; 97 | case NEAREST_SAMPLING: 98 | kernel_rescale_nearest<<>>( 99 | imgs, factor, n_imgs, img_h, img_w, scaled_h, scaled_w, imgs_scaled 100 | ); 101 | break; 102 | case PERFORATED_SAMPLING: 103 | if (factor < 1) { 104 | kernel_rescale_nearest<<>>( 105 | imgs, factor, n_imgs, img_h, img_w, scaled_h, scaled_w, imgs_scaled 106 | ); 107 | } else { 108 | if (ceilf(factor) != factor) { 109 | throw std::runtime_error("Factor must be integer for perforated upscaling."); 110 | } 111 | kernel_upsample_perforated<<>>( 112 | imgs, factor, n_imgs, img_h, img_w, scaled_h, scaled_w, imgs_scaled 113 | ); 114 | } 115 | break; 116 | default: 117 | throw std::runtime_error("Invalid method."); 118 | } 119 | CUDA_KERNEL_CHECK; 120 | } 121 | 122 | 123 | template void rescale(const float *imgs, float factor, SampleMethod method, 124 | int n_imgs, int img_h, int img_w, float *imgs_scaled); 125 | 126 | } 127 | -------------------------------------------------------------------------------- /src/nnet/conv_bc01_matmul.cpp: -------------------------------------------------------------------------------- 1 | #include "cudarray/common.hpp" 2 | #include "cudarray/blas.hpp" 3 | #include "cudarray/nnet/conv_bc01_matmul.hpp" 4 | #include "cudarray/image/img2win.hpp" 5 | 6 | // The following convolution operations by matrix multiplication are heavily 7 | // inspired by those from Caffe, http://caffe.berkeleyvision.org/ 8 | 9 | namespace cudarray { 10 | 11 | template 12 | void conv_bc01_matmul(const T *imgs, const T *filters, int n_imgs, 13 | int n_channels, int n_filters, int img_h, int img_w, int 
filter_h, 14 | int filter_w, int pad_y, int pad_x, int stride_y, int stride_x, 15 | T *convout) { 16 | int convout_h = (img_h + 2 * pad_y - filter_h) / stride_y + 1; 17 | int convout_w = (img_w + 2 * pad_x - filter_w) / stride_x + 1; 18 | int win_size = filter_h * filter_w; 19 | T *buffer = (T *) CUDA::buffer(sizeof(float) * n_channels * win_size 20 | * convout_h * convout_w); 21 | int m = n_filters; 22 | int k = n_channels*win_size; 23 | int n = convout_h * convout_w; 24 | for (int i = 0; i < n_imgs; ++i) { 25 | const T *img = imgs + i * n_channels * img_h * img_w; 26 | img2win(img, n_channels, img_h, img_w, filter_h, filter_w, pad_y, pad_x, 27 | stride_y, stride_x, buffer); 28 | T *convout_img = convout + i * n_filters * convout_h * convout_w; 29 | gemm(filters, buffer, OP_NO_TRANS, OP_NO_TRANS, m, n, k, (T) 1.0, (T) 0.0, 30 | convout_img); 31 | } 32 | } 33 | template void conv_bc01_matmul(const float *imgs, const float *filters, 34 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 35 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 36 | int stride_x, float *convout); 37 | 38 | 39 | 40 | template 41 | void conv_bc01_matmul_bprop_imgs(const T *filters, const T *convout_d, 42 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 43 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 44 | int stride_x, T *imgs_d) { 45 | int convout_h = (img_h + 2 * pad_y - filter_h) / stride_y + 1; 46 | int convout_w = (img_w + 2 * pad_x - filter_w) / stride_x + 1; 47 | int win_size = filter_h * filter_w; 48 | T *buffer = (T *) CUDA::buffer(sizeof(float) * n_channels * win_size 49 | * convout_h * convout_w); 50 | int m = n_channels * win_size; 51 | int k = n_filters; 52 | int n = convout_h * convout_w; 53 | for (int i = 0; i < n_imgs; ++i) { 54 | const T *convout_img_d = convout_d + i * n_filters * convout_h * convout_w; 55 | gemm(filters, convout_img_d, OP_TRANS, OP_NO_TRANS, m, n, k, (T) 1.0, 56 | (T) 0.0, buffer); 57 | 58 | T *img_d = imgs_d + i * n_channels * img_h * img_w; 59 | win2img(buffer, n_channels, img_h, img_w, filter_h, filter_w, pad_y, pad_x, 60 | stride_y, stride_x, img_d); 61 | } 62 | } 63 | template void conv_bc01_matmul_bprop_imgs(const float *filters, const float *convout_d, 64 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 65 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 66 | int stride_x, float *imgs_d); 67 | 68 | 69 | 70 | template 71 | void conv_bc01_matmul_bprop_filters(const T *imgs, const T *convout_d, 72 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 73 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 74 | int stride_x, T *filters_d) { 75 | int convout_h = (img_h + 2 * pad_y - filter_h) / stride_y + 1; 76 | int convout_w = (img_w + 2 * pad_x - filter_w) / stride_x + 1; 77 | int win_size = filter_h * filter_w; 78 | T *buffer = (T *) CUDA::buffer(sizeof(float) * n_channels * win_size 79 | * convout_h * convout_w); 80 | int m = n_filters; 81 | int k = convout_h * convout_w; 82 | int n = n_channels*win_size; 83 | for (int i = 0; i < n_imgs; ++i) { 84 | const T *img = imgs + i * n_channels * img_h * img_w; 85 | img2win(img, n_channels, img_h, img_w, filter_h, filter_w, pad_y, pad_x, 86 | stride_y, stride_x, buffer); 87 | const T *convout_img_d = convout_d + i * n_filters * convout_h * convout_w; 88 | T beta = i > 0 ? 
1.0 : 0.0; 89 | gemm(convout_img_d, buffer, OP_NO_TRANS, OP_TRANS, m, n, k, (T) 1.0, beta, 90 | filters_d); 91 | } 92 | } 93 | template void conv_bc01_matmul_bprop_filters(const float *imgs, const float *convout_d, 94 | int n_imgs, int n_channels, int n_filters, int img_h, int img_w, 95 | int filter_h, int filter_w, int pad_y, int pad_x, int stride_y, 96 | int stride_x, float *filters_d); 97 | 98 | } 99 | -------------------------------------------------------------------------------- /src/nnet/cudnn.cpp: -------------------------------------------------------------------------------- 1 | #ifdef CUDNN_ENABLED 2 | 3 | #include 4 | #include "cudarray/common.hpp" 5 | #include "cudarray/nnet/cudnn.hpp" 6 | 7 | namespace cudarray { 8 | 9 | 10 | const float CUDNN::one = 1.0f; 11 | const float CUDNN::zero = 0.0f; 12 | 13 | 14 | template 15 | PoolBC01CuDNN::PoolBC01CuDNN(int n_img_dims, int *win_shape, int *padding, 16 | int *strides, PoolMode pool_mode) : n_img_dims(n_img_dims) { 17 | if (n_img_dims > MAX_IMG_DIMS + 2) { 18 | throw std::runtime_error("More than 3 image dimensions."); 19 | } 20 | 21 | for (int i = 0; i < n_img_dims; ++i) { 22 | this->win_shape[i] = win_shape[i]; 23 | this->padding[i] = padding[i]; 24 | this->strides[i] = strides[i]; 25 | } 26 | for (int i = 0; i < n_img_dims + 2; ++i) { 27 | imgs_shape[i] = -1; 28 | } 29 | this->pool_mode = pool_mode == POOL_MAX ? CUDNN_POOLING_MAX : 30 | CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; 31 | CUDNN_CHECK(cudnnCreateTensorDescriptor(&imgs_desc)); 32 | CUDNN_CHECK(cudnnCreateTensorDescriptor(&poolout_desc)); 33 | CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc)); 34 | } 35 | 36 | 37 | template 38 | PoolBC01CuDNN::~PoolBC01CuDNN() { 39 | CUDNN_CHECK(cudnnDestroyTensorDescriptor(imgs_desc)); 40 | CUDNN_CHECK(cudnnDestroyTensorDescriptor(poolout_desc)); 41 | CUDNN_CHECK(cudnnDestroyPoolingDescriptor(pool_desc)); 42 | } 43 | 44 | 45 | void array_strides(int n_dims, const int *shape, int *strides) { 46 | int stride = 1; 47 | for (int i = n_dims-1; i >= 0; --i) { 48 | strides[i] = stride; 49 | stride *= shape[i]; 50 | } 51 | } 52 | 53 | template 54 | void PoolBC01CuDNN::fprop(const T *imgs, int *imgs_shape, T *poolout) { 55 | bool new_shape = false; 56 | int n_imgs_dims = n_img_dims + 2; 57 | for (int i = 0; i < n_imgs_dims; ++i) { 58 | if (this->imgs_shape[i] != imgs_shape[i]) { 59 | new_shape = true; 60 | break; 61 | } 62 | } 63 | 64 | if (new_shape) { 65 | for (int i = 0; i < n_imgs_dims; ++i) { 66 | this->imgs_shape[i] = imgs_shape[i]; 67 | } 68 | int imgs_strides[n_imgs_dims]; 69 | array_strides(n_imgs_dims, imgs_shape, imgs_strides); 70 | CUDNN_CHECK(cudnnSetTensorNdDescriptor( 71 | imgs_desc, CUDNN_DATA_FLOAT, n_imgs_dims, imgs_shape, imgs_strides 72 | )); 73 | 74 | CUDNN_CHECK(cudnnSetPoolingNdDescriptor( 75 | pool_desc, pool_mode, CUDNN_PROPAGATE_NAN, n_img_dims, win_shape, 76 | padding, strides 77 | )); 78 | 79 | int poolout_shape[n_imgs_dims]; 80 | poolout_shape[0] = imgs_shape[0]; 81 | poolout_shape[1] = imgs_shape[1]; 82 | for (int i = 0; i < n_img_dims; ++i) { 83 | poolout_shape[i+2] = (imgs_shape[i+2] + 2*padding[i] - win_shape[i]) 84 | / strides[i] + 1; 85 | } 86 | 87 | int poolout_strides[n_imgs_dims]; 88 | array_strides(n_imgs_dims, poolout_shape, poolout_strides); 89 | CUDNN_CHECK(cudnnSetTensorNdDescriptor( 90 | poolout_desc, CUDNN_DATA_FLOAT, n_imgs_dims, poolout_shape, 91 | poolout_strides 92 | )); 93 | } 94 | 95 | CUDNN_CHECK(cudnnPoolingForward( 96 | CUDNN::handle(), pool_desc, &CUDNN::one, imgs_desc, imgs, 
&CUDNN::zero, 97 | poolout_desc, poolout 98 | )); 99 | } 100 | 101 | 102 | template 103 | void PoolBC01CuDNN::bprop(const T *imgs, const T* poolout, 104 | const T *poolout_d, T *imgs_d) { 105 | CUDNN_CHECK(cudnnPoolingBackward( 106 | CUDNN::handle(), pool_desc, &CUDNN::one, poolout_desc, poolout, 107 | poolout_desc, poolout_d, imgs_desc, imgs, &CUDNN::zero, imgs_desc, imgs_d 108 | )); 109 | } 110 | 111 | 112 | template class PoolBC01CuDNN; 113 | 114 | 115 | 116 | template 117 | ConvBC01CuDNN::ConvBC01CuDNN(int pad_y, int pad_x, int stride_y, 118 | int stride_x) : pad_y(pad_y), pad_x(pad_x), stride_y(stride_y), 119 | stride_x(stride_x), n_imgs(0), n_channels(0), n_filters(0), img_h(0), 120 | img_w(0), filter_h(0), filter_w(0), workspace_size(0) { 121 | CUDNN_CHECK(cudnnCreateTensorDescriptor(&imgs_desc)); 122 | CUDNN_CHECK(cudnnCreateTensorDescriptor(&convout_desc)); 123 | CUDNN_CHECK(cudnnCreateFilterDescriptor(&filters_desc)); 124 | CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); 125 | } 126 | 127 | 128 | template 129 | ConvBC01CuDNN::~ConvBC01CuDNN() { 130 | CUDNN_CHECK(cudnnDestroyTensorDescriptor(imgs_desc)); 131 | CUDNN_CHECK(cudnnDestroyTensorDescriptor(convout_desc)); 132 | CUDNN_CHECK(cudnnDestroyFilterDescriptor(filters_desc)); 133 | CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc)); 134 | } 135 | 136 | 137 | template 138 | void ConvBC01CuDNN::fprop(const T *imgs, const T *filters, int n_imgs, 139 | int n_channels, int n_filters, int img_h, int img_w, int filter_h, 140 | int filter_w, T *convout) { 141 | bool set_conv_desc = false; 142 | if (n_imgs != this->n_imgs || n_channels != this->n_channels || 143 | img_h != this->img_h || img_w != this->img_w) { 144 | CUDNN_CHECK(cudnnSetTensor4dDescriptor( 145 | imgs_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n_imgs, n_channels, 146 | img_h, img_w 147 | )); 148 | this->n_imgs = n_imgs; 149 | this->n_channels = n_channels; 150 | this->img_h = img_h; 151 | this->img_w = img_w; 152 | set_conv_desc = true; 153 | } 154 | if (n_filters != this->n_filters || n_channels != this->n_channels || 155 | filter_h != this->filter_h || filter_w != this->filter_w) { 156 | CUDNN_CHECK(cudnnSetFilter4dDescriptor( 157 | filters_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, n_filters, 158 | n_channels, filter_h, filter_w 159 | )); 160 | this->n_filters = n_filters; 161 | this->n_channels = n_channels; 162 | this->filter_h = filter_h; 163 | this->filter_w = filter_w; 164 | set_conv_desc = true; 165 | } 166 | if (set_conv_desc) { 167 | CUDNN_CHECK(cudnnSetConvolution2dDescriptor( 168 | conv_desc, pad_y, pad_x, stride_y, stride_x, 1, 1, CUDNN_CONVOLUTION 169 | )); 170 | int n, c, h, w; 171 | CUDNN_CHECK(cudnnGetConvolution2dForwardOutputDim( 172 | conv_desc, imgs_desc, filters_desc, &n, &c, &h, &w 173 | )); 174 | CUDNN_CHECK(cudnnSetTensor4dDescriptor( 175 | convout_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w 176 | )); 177 | const int n_requestedAlgo = 16; 178 | int n_returnedAlgo; 179 | cudnnConvolutionFwdAlgoPerf_t fwd_algo_perf[n_requestedAlgo]; 180 | CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm( 181 | CUDNN::handle(), imgs_desc, filters_desc, conv_desc, convout_desc, 182 | n_requestedAlgo, &n_returnedAlgo, fwd_algo_perf 183 | )); 184 | if (n_returnedAlgo == 0) { 185 | throw std::runtime_error("No cudnnConvolutionFwdAlgoPerf_t found"); 186 | } 187 | 188 | fwd_algo = fwd_algo_perf[0].algo; 189 | cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo_perf[n_requestedAlgo]; 190 | CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm( 191 | 
CUDNN::handle(), filters_desc, convout_desc, conv_desc, imgs_desc, 192 | n_requestedAlgo, &n_returnedAlgo, bwd_data_algo_perf 193 | )); 194 | if (n_returnedAlgo == 0) { 195 | throw std::runtime_error("No cudnnConvolutionBwdDataAlgoPerf_t found"); 196 | } 197 | 198 | bwd_imgs_algo = bwd_data_algo_perf[0].algo; 199 | cudnnConvolutionBwdFilterAlgoPerf_t bwd_filters_algo_perf[n_requestedAlgo]; 200 | CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm( 201 | CUDNN::handle(), imgs_desc, convout_desc, conv_desc, filters_desc, 202 | n_requestedAlgo, &n_returnedAlgo, bwd_filters_algo_perf 203 | )); 204 | if (n_returnedAlgo == 0) { 205 | throw std::runtime_error("No cudnnConvolutionBwdFilterAlgoPerf_t found"); 206 | } 207 | bwd_filters_algo = bwd_filters_algo_perf[0].algo; 208 | size_t fwd_workspace_size; 209 | size_t bwd_imgs_workspace_size; 210 | size_t bwd_filters_workspace_size; 211 | CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize( 212 | CUDNN::handle(), imgs_desc, filters_desc, conv_desc, convout_desc, 213 | fwd_algo, &fwd_workspace_size 214 | )); 215 | CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize( 216 | CUDNN::handle(), filters_desc, convout_desc, conv_desc, imgs_desc, 217 | bwd_imgs_algo, &bwd_imgs_workspace_size 218 | )); 219 | CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize( 220 | CUDNN::handle(), imgs_desc, convout_desc, conv_desc, filters_desc, 221 | bwd_filters_algo, &bwd_filters_workspace_size 222 | )); 223 | workspace_size = std::max(fwd_workspace_size, bwd_imgs_workspace_size); 224 | workspace_size = std::max(workspace_size, bwd_filters_workspace_size); 225 | } 226 | void *workspace = NULL; 227 | if (workspace_size > 0) { 228 | workspace = CUDA::buffer(workspace_size); 229 | } 230 | CUDNN_CHECK(cudnnConvolutionForward( 231 | CUDNN::handle(), &CUDNN::one, imgs_desc, imgs, filters_desc, filters, 232 | conv_desc, fwd_algo, workspace, workspace_size, &CUDNN::zero, 233 | convout_desc, convout 234 | )); 235 | } 236 | 237 | 238 | template 239 | void ConvBC01CuDNN::bprop(const T* imgs, const T* filters, 240 | const T *convout_d, T *imgs_d, T *filters_d) { 241 | void *workspace = NULL; 242 | if (workspace_size > 0) { 243 | workspace = CUDA::buffer(workspace_size); 244 | } 245 | if (filters_d) { 246 | CUDNN_CHECK(cudnnConvolutionBackwardFilter( 247 | CUDNN::handle(), &CUDNN::one, imgs_desc, imgs, convout_desc, convout_d, 248 | conv_desc, bwd_filters_algo, workspace, workspace_size, &CUDNN::zero, 249 | filters_desc, filters_d 250 | )); 251 | } 252 | if (imgs_d) { 253 | CUDNN_CHECK(cudnnConvolutionBackwardData( 254 | CUDNN::handle(), &CUDNN::one, filters_desc, filters, convout_desc, 255 | convout_d, conv_desc, bwd_imgs_algo, workspace, workspace_size, 256 | &CUDNN::zero, imgs_desc, imgs_d 257 | )); 258 | } 259 | } 260 | 261 | template class ConvBC01CuDNN; 262 | 263 | 264 | const char *cudnn_message(cudnnStatus_t status){ 265 | switch(status) { 266 | case CUDNN_STATUS_SUCCESS: 267 | return "The operation completed successfully"; 268 | case CUDNN_STATUS_NOT_INITIALIZED: 269 | return "The cuDNN library was not initialized properly."; 270 | case CUDNN_STATUS_ALLOC_FAILED: 271 | return "Resource allocation failed inside the cuDNN library."; 272 | case CUDNN_STATUS_BAD_PARAM: 273 | return "An incorrect parameter was passed to the function."; 274 | case CUDNN_STATUS_INTERNAL_ERROR: 275 | return "An internal cuDNN operation failed."; 276 | case CUDNN_STATUS_INVALID_VALUE: 277 | return "CUDNN_STATUS_INVALID_VALUE"; 278 | case CUDNN_STATUS_ARCH_MISMATCH: 279 | return "The 
function requires a feature absent from the GPU"; 280 | case CUDNN_STATUS_MAPPING_ERROR: 281 | return "An access to GPU memory space failed."; 282 | case CUDNN_STATUS_EXECUTION_FAILED: 283 | return "The GPU program failed to execute."; 284 | case CUDNN_STATUS_NOT_SUPPORTED: 285 | return "The functionality not presently supported by cuDNN."; 286 | case CUDNN_STATUS_LICENSE_ERROR: 287 | return "The functionality requested requires some license."; 288 | default: 289 | throw std::runtime_error("invalid cudnnStatus_t"); 290 | } 291 | } 292 | 293 | 294 | } // cudarray 295 | 296 | #endif // CUDNN_ENABLED 297 | -------------------------------------------------------------------------------- /src/nnet/one_hot.cu: -------------------------------------------------------------------------------- 1 | #include "cudarray/common.hpp" 2 | #include "cudarray/nnet/one_hot.hpp" 3 | 4 | namespace cudarray { 5 | 6 | template 7 | __global__ void kernel_one_hot_encode(const int *labels, int n_classes, int n, 8 | T *out) { 9 | CUDA_GRID_STRIDE_LOOP(idx, n*n_classes) { 10 | int class_idx = idx % n_classes; 11 | int label_idx = idx / n_classes; 12 | out[idx] = labels[label_idx] == class_idx ? 1.0 : 0.0; 13 | } 14 | } 15 | 16 | template 17 | void one_hot_encode(const int *labels, int n_classes, int n, T *out) { 18 | kernel_one_hot_encode<<>>( 19 | labels, n_classes, n, out); 20 | CUDA_KERNEL_CHECK; 21 | } 22 | 23 | template void one_hot_encode(const int *labels, int n_classes, int n, 24 | float *out); 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/nnet/pool_b01.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cudarray/common.hpp" 3 | #include "cudarray/nnet/pool_b01.hpp" 4 | 5 | namespace cudarray { 6 | 7 | // The implementations below are inspired by those found in the Caffe framework 8 | 9 | template 10 | __global__ void max_pool_b01(int n_threads, const T* imgs, 11 | int img_h, int img_w, int poolout_h, int poolout_w, int win_h, int win_w, 12 | int pad_y, int pad_x, int stride_y, int stride_x, T* poolout, int* mask) { 13 | CUDA_GRID_STRIDE_LOOP(idx, n_threads) { 14 | int poolout_x = idx % poolout_w; 15 | int poolout_y = (idx / poolout_w) % poolout_h; 16 | int n = idx / poolout_w / poolout_h; 17 | int img_y_start = poolout_y * stride_y - pad_y; 18 | int img_x_start = poolout_x * stride_x - pad_x; 19 | int img_y_end = min(img_y_start + win_h, img_h); 20 | int img_x_end = min(img_x_start + win_w, img_w); 21 | img_y_start = max(img_y_start, 0); 22 | img_x_start = max(img_x_start, 0); 23 | T maxval = -FLT_MAX; 24 | int maxidx = -1; 25 | imgs += n * img_h * img_w; 26 | for (int h = img_y_start; h < img_y_end; ++h) { 27 | for (int w = img_x_start; w < img_x_end; ++w) { 28 | if (imgs[h * img_w + w] > maxval) { 29 | maxidx = h * img_w + w; 30 | maxval = imgs[maxidx]; 31 | } 32 | } 33 | } 34 | poolout[idx] = maxval; 35 | mask[idx] = maxidx; 36 | } 37 | } 38 | 39 | template 40 | void max_pool_b01(const T* imgs, int n_imgs, int img_h, int img_w, int win_h, 41 | int win_w, int pad_y, int pad_x, int stride_y, int stride_x, T* poolout, 42 | int* mask) { 43 | int poolout_h = (img_h + 2*pad_y - win_h) / stride_y + 1; 44 | int poolout_w = (img_w + 2*pad_x - win_w) / stride_x + 1; 45 | int n_threads = n_imgs * poolout_h * poolout_w; 46 | max_pool_b01<<>>( 47 | n_threads, imgs, img_h, img_w, poolout_h, poolout_w, win_h, win_w, pad_y, 48 | pad_x, stride_y, stride_x, poolout, mask); 49 | CUDA_KERNEL_CHECK; 50 | } 51 | 
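// Note (added commentary, not part of the original source): the kernel above
// treats the input as n_imgs independent (img_h, img_w) planes; callers
// typically fold batch and channels into this single leading dimension, hence
// the "b01" name. The forward pass stores the flat argmax index of every
// pooling window in `mask`, so max_pool_b01_bprob below can route each output
// gradient straight back to the winning input element without recomputing
// window maxima.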
52 | template void max_pool_b01(const float* imgs, int n_imgs, int img_h, 53 | int img_w, int win_h, int win_w, int pad_y, int pad_x, int stride_y, 54 | int stride_x, float* poolout, int* mask); 55 | 56 | 57 | 58 | template 59 | __global__ void max_pool_b01_bprob(int n_threads, const T* poolout_d, 60 | const int* mask, int img_h, int img_w, int poolout_h, int poolout_w, 61 | int win_h, int win_w, int pad_y, int pad_x, int stride_y, int stride_x, 62 | T* imgs_d) { 63 | CUDA_GRID_STRIDE_LOOP(idx, n_threads) { 64 | int img_x = idx % img_w; 65 | int img_y = (idx / img_w) % img_h; 66 | int n = idx / img_w / img_h; 67 | int poolout_y_start = (img_y + pad_y < win_h) 68 | ? 0 : (img_y + pad_y - win_h) / stride_y + 1; 69 | int poolout_y_end = min((img_y + pad_y) / stride_y + 1, poolout_h); 70 | int poolout_x_start = (img_x + pad_x < win_w) 71 | ? 0 : (img_x + pad_x - win_w) / stride_x + 1; 72 | int poolout_x_end = min((img_x + pad_x) / stride_x + 1, poolout_w); 73 | int offset = n * poolout_h * poolout_w; 74 | poolout_d += offset; 75 | mask += offset; 76 | T gradient = 0; 77 | for (int ph = poolout_y_start; ph < poolout_y_end; ++ph) { 78 | for (int pw = poolout_x_start; pw < poolout_x_end; ++pw) { 79 | if (mask[ph * poolout_w + pw] == img_y * img_w + img_x) { 80 | gradient += poolout_d[ph * poolout_w + pw]; 81 | } 82 | } 83 | } 84 | imgs_d[idx] = gradient; 85 | } 86 | } 87 | 88 | template 89 | void max_pool_b01_bprob(const T* poolout_d, const int* mask, int n_imgs, 90 | int img_h, int img_w, int win_h, int win_w, int pad_y, int pad_x, 91 | int stride_y, int stride_x, T* imgs_d) { 92 | int poolout_h = (img_h + 2*pad_y - win_h) / stride_y + 1; 93 | int poolout_w = (img_w + 2*pad_x - win_w) / stride_x + 1; 94 | int n_threads = n_imgs * img_h * img_w; 95 | max_pool_b01_bprob<<>>( 96 | n_threads, poolout_d, mask, img_h, img_w, poolout_h, poolout_w, win_h, 97 | win_w, pad_y, pad_x, stride_y, stride_x, imgs_d); 98 | CUDA_KERNEL_CHECK; 99 | } 100 | 101 | template void max_pool_b01_bprob(const float* poolout_d, const int* mask, 102 | int n_imgs, int img_h, int img_w, int win_h, int win_w, int pad_y, 103 | int pad_x, int stride_y, int stride_x, float* imgs_d); 104 | 105 | 106 | 107 | 108 | template 109 | __global__ void avg_pool_b01(int n_threads, const T* imgs, 110 | int img_h, int img_w, int poolout_h, int poolout_w, int win_h, int win_w, 111 | int pad_y, int pad_x, int stride_y, int stride_x, T* poolout) { 112 | CUDA_GRID_STRIDE_LOOP(idx, n_threads) { 113 | int poolout_x = idx % poolout_w; 114 | int poolout_y = (idx / poolout_w) % poolout_h; 115 | int n = idx / poolout_w / poolout_h; 116 | int img_y_start = poolout_y * stride_y - pad_y; 117 | int img_x_start = poolout_x * stride_x - pad_x; 118 | int img_y_end = min(img_y_start + win_h, img_h); 119 | int img_x_end = min(img_x_start + win_w, img_w); 120 | img_y_start = max(img_y_start, 0); 121 | img_x_start = max(img_x_start, 0); 122 | T sum = 0; 123 | imgs += n * img_h * img_w; 124 | for (int h = img_y_start; h < img_y_end; ++h) { 125 | for (int w = img_x_start; w < img_x_end; ++w) { 126 | sum += imgs[h * img_w + w]; 127 | } 128 | } 129 | poolout[idx] = sum / (win_h*win_w); 130 | } 131 | } 132 | 133 | template 134 | void avg_pool_b01(const T* imgs, int n_imgs, int img_h, int img_w, int win_h, 135 | int win_w, int pad_y, int pad_x, int stride_y, int stride_x, T* poolout) { 136 | int poolout_h = (img_h + 2*pad_y - win_h) / stride_y + 1; 137 | int poolout_w = (img_w + 2*pad_x - win_w) / stride_x + 1; 138 | int n_threads = n_imgs * poolout_h * poolout_w; 139 | 
avg_pool_b01<<>>( 140 | n_threads, imgs, img_h, img_w, poolout_h, poolout_w, win_h, win_w, pad_y, 141 | pad_x, stride_y, stride_x, poolout); 142 | CUDA_KERNEL_CHECK; 143 | } 144 | 145 | template void avg_pool_b01(const float* imgs, int n_imgs, int img_h, 146 | int img_w, int win_h, int win_w, int pad_y, int pad_x, int stride_y, 147 | int stride_x, float* poolout); 148 | 149 | 150 | 151 | template 152 | __global__ void avg_pool_b01_bprob(int n_threads, const T* poolout_d, 153 | int img_h, int img_w, int poolout_h, int poolout_w, int win_h, int win_w, 154 | int pad_y, int pad_x, int stride_y, int stride_x, T* imgs_d) { 155 | CUDA_GRID_STRIDE_LOOP(idx, n_threads) { 156 | int img_x = idx % img_w; 157 | int img_y = (idx / img_w) % img_h; 158 | int n = idx / img_w / img_h; 159 | int poolout_y_start = (img_y + pad_y < win_h) 160 | ? 0 : (img_y + pad_y - win_h) / stride_y + 1; 161 | int poolout_y_end = min((img_y + pad_y) / stride_y + 1, poolout_h); 162 | int poolout_x_start = (img_x + pad_x < win_w) 163 | ? 0 : (img_x + pad_x - win_w) / stride_x + 1; 164 | int poolout_x_end = min((img_x + pad_x) / stride_x + 1, poolout_w); 165 | int offset = n * poolout_h * poolout_w; 166 | poolout_d += offset; 167 | T gradient = 0; 168 | for (int ph = poolout_y_start; ph < poolout_y_end; ++ph) { 169 | for (int pw = poolout_x_start; pw < poolout_x_end; ++pw) { 170 | gradient += poolout_d[ph * poolout_w + pw]; 171 | } 172 | } 173 | imgs_d[idx] = gradient / (win_h * win_w); 174 | } 175 | } 176 | 177 | template 178 | void avg_pool_b01_bprob(const T* poolout_d, int n_imgs, int img_h, int img_w, 179 | int win_h, int win_w, int pad_y, int pad_x, int stride_y, int stride_x, 180 | T* imgs_d) { 181 | int poolout_h = (img_h + 2*pad_y - win_h) / stride_y + 1; 182 | int poolout_w = (img_w + 2*pad_x - win_w) / stride_x + 1; 183 | int n_threads = n_imgs * img_h * img_w; 184 | avg_pool_b01_bprob<<>>( 185 | n_threads, poolout_d, img_h, img_w, poolout_h, poolout_w, win_h, 186 | win_w, pad_y, pad_x, stride_y, stride_x, imgs_d); 187 | CUDA_KERNEL_CHECK; 188 | } 189 | 190 | template void avg_pool_b01_bprob(const float* poolout_d, int n_imgs, int img_h, 191 | int img_w, int win_h, int win_w, int pad_y, int pad_x, int stride_y, 192 | int stride_x, float* imgs_d); 193 | 194 | 195 | } 196 | -------------------------------------------------------------------------------- /src/random.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cudarray/common.hpp" 3 | #include "cudarray/random.hpp" 4 | 5 | 6 | namespace cudarray { 7 | 8 | 9 | void seed(unsigned long long val) { 10 | CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(CURAND::generator(), 11 | val)); 12 | } 13 | 14 | template <> 15 | void random_normal(float *a, float mu, float sigma, unsigned int n) { 16 | CURAND_CHECK(curandGenerateNormal(CURAND::generator(), a, n, mu, sigma)); 17 | } 18 | 19 | 20 | template 21 | __global__ void kernel_stretch(T *a, T alpha, T beta, unsigned int n) { 22 | CUDA_GRID_STRIDE_LOOP(idx, n) { 23 | a[idx] = alpha*a[idx] + beta; 24 | } 25 | } 26 | 27 | 28 | template <> 29 | void random_uniform(float *a, float low, float high, unsigned int n) { 30 | CURAND_CHECK(curandGenerateUniform(CURAND::generator(), a, n)); 31 | if (high != 1.0 || low != 0.0) { 32 | float alpha = high - low; 33 | float beta = low; 34 | kernel_stretch<<>>(a, alpha, beta, n); 35 | } 36 | } 37 | 38 | 39 | const char* curand_message(curandStatus_t status) { 40 | switch (status) { 41 | case CURAND_STATUS_SUCCESS: 42 | return "No 
errors."; 43 | case CURAND_STATUS_VERSION_MISMATCH: 44 | return "Header file and linked library version do not match."; 45 | case CURAND_STATUS_NOT_INITIALIZED: 46 | return "Generator not initialized."; 47 | case CURAND_STATUS_ALLOCATION_FAILED: 48 | return "Memory allocation failed."; 49 | case CURAND_STATUS_TYPE_ERROR: 50 | return "Generator is wrong type."; 51 | case CURAND_STATUS_OUT_OF_RANGE: 52 | return "Argument out of range."; 53 | case CURAND_STATUS_LENGTH_NOT_MULTIPLE: 54 | return "Length requested is not a multple of dimension."; 55 | case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: 56 | return "GPU does not have double precision required by MRG32k3a."; 57 | case CURAND_STATUS_LAUNCH_FAILURE: 58 | return "Kernel launch failure."; 59 | case CURAND_STATUS_PREEXISTING_FAILURE: 60 | return "Preexisting failure on library entry."; 61 | case CURAND_STATUS_INITIALIZATION_FAILED: 62 | return "Initialization of CUDA failed."; 63 | case CURAND_STATUS_ARCH_MISMATCH: 64 | return "Architecture mismatch, GPU does not support requested feature."; 65 | case CURAND_STATUS_INTERNAL_ERROR: 66 | return "Internal library error."; 67 | default: 68 | throw std::runtime_error("invalid curandStatus_t"); 69 | } 70 | } 71 | 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/reduction.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "cudarray/common.hpp" 5 | #include "cudarray/reduction.hpp" 6 | #include "cudarray/elementwise.hpp" 7 | 8 | // The parallel reductions below are heavily based on 9 | // http://developer.download.nvidia.com/assets/cuda/files/reduction.pdf 10 | // and http://cudpp.github.io/ 11 | 12 | // TODO: parallelize reduce_to_int() and reduce_mat() à la reduce() 13 | 14 | namespace cudarray { 15 | 16 | template 17 | struct SharedMemory { 18 | __device__ T* pointer() const; 19 | }; 20 | template <> 21 | __device__ inline int *SharedMemory::pointer() const { 22 | extern __shared__ int s_int[]; 23 | return s_int; 24 | } 25 | template <> 26 | __device__ inline float *SharedMemory::pointer() const { 27 | extern __shared__ float s_float[]; 28 | return s_float; 29 | } 30 | 31 | 32 | template 33 | struct MaxOp { 34 | __device__ T identity() const; 35 | __device__ T operator()(const T a, const T b) { 36 | return max(a, b); 37 | } 38 | }; 39 | template <> 40 | __device__ inline int MaxOp::identity() const { 41 | return INT_MIN; 42 | } 43 | template <> 44 | __device__ inline float MaxOp::identity() const { 45 | return -FLT_MAX; 46 | } 47 | 48 | template 49 | struct MinOp { 50 | __device__ T identity() const; 51 | __device__ T operator()(const T a, const T b) { 52 | return min(a, b); 53 | } 54 | }; 55 | template <> 56 | __device__ inline int MinOp::identity() const { 57 | return INT_MAX; 58 | } 59 | template <> 60 | __device__ inline float MinOp::identity() const { 61 | return FLT_MAX; 62 | } 63 | 64 | template 65 | struct MulOp { 66 | __device__ T identity() { 67 | return (T) 1; 68 | } 69 | __device__ T operator()(const T a, const T b) { 70 | return a * b; 71 | } 72 | }; 73 | 74 | template 75 | struct AddOp { 76 | __device__ T identity() { 77 | return (T) 0; 78 | } 79 | __device__ T operator()(const T a, const T b) { 80 | return a + b; 81 | } 82 | }; 83 | 84 | template 85 | __global__ void reduce(const T *a, unsigned int n, T *b) { 86 | Op op; 87 | if (block_size == 1) { 88 | if (n == 1) { 89 | b[0] = a[0]; 90 | } else if (n == 2) { 91 | b[0] = op(a[0], a[1]); 92 | } 
93 | } else { 94 | unsigned int tid = threadIdx.x; 95 | unsigned int i = blockIdx.x*(block_size*2) + threadIdx.x; 96 | unsigned int gridSize = block_size*2*gridDim.x; 97 | 98 | SharedMemory smem; 99 | volatile T* sdata = smem.pointer(); 100 | T reduced = op.identity(); 101 | 102 | // Reduce multiple elements per thread. 103 | while (i < n) { 104 | reduced = op(reduced, a[i]); 105 | // Check array bounds 106 | if (i + block_size < n) { 107 | reduced = op(reduced, a[i+block_size]); 108 | } 109 | i += gridSize; 110 | } 111 | 112 | // Reduce in shared memory 113 | sdata[tid] = reduced; 114 | __syncthreads(); 115 | 116 | #pragma unroll 117 | for (unsigned int i=512; i >= 2; i >>= 1) { 118 | if (block_size >= i) { 119 | if (tid < (i << 1)) { 120 | sdata[tid] = reduced = op(reduced, sdata[tid + i]); 121 | } 122 | // No need to sync threads in the same warp 123 | if (tid >= 32) 124 | __syncthreads(); 125 | } 126 | } 127 | 128 | // Write reduced block back to global memory 129 | if (tid == 0) { 130 | b[blockIdx.x] = sdata[0]; 131 | } 132 | } 133 | } 134 | 135 | const unsigned int max_blocks = 64; 136 | const unsigned int reduce_cta_size = 256; 137 | inline unsigned int ceil_pow2(unsigned int x) { 138 | --x; 139 | x |= x >> 1; 140 | x |= x >> 2; 141 | x |= x >> 4; 142 | x |= x >> 8; 143 | x |= x >> 16; 144 | return ++x; 145 | } 146 | 147 | unsigned int n_reduce_blocks(unsigned int n) { 148 | return min(max_blocks, (n + (2*reduce_cta_size - 1)) / (2*reduce_cta_size)); 149 | } 150 | 151 | unsigned int n_reduce_threads(unsigned int n) { 152 | return n > 2 * reduce_cta_size ? reduce_cta_size : max(1, ceil_pow2(n) / 2); 153 | } 154 | 155 | template 156 | void reduce_blocks(const T *a, unsigned int n, T *b) { 157 | unsigned int n_threads = n_reduce_threads(n); 158 | dim3 block(n_threads, 1, 1); 159 | 160 | unsigned int n_blocks = n_reduce_blocks(n); 161 | dim3 grid(n_blocks, 1, 1); 162 | int smem_size = reduce_cta_size * sizeof(T); 163 | 164 | switch (block.x) { 165 | case 512: 166 | reduce<<>>(a, n, b); 167 | break; 168 | case 256: 169 | reduce<<>>(a, n, b); 170 | break; 171 | case 128: 172 | reduce<<>>(a, n, b); 173 | break; 174 | case 64: 175 | reduce<<>>(a, n, b); 176 | break; 177 | case 32: 178 | reduce<<>>(a, n, b); 179 | break; 180 | case 16: 181 | reduce<<>>(a, n, b); 182 | break; 183 | case 8: 184 | reduce<<>>(a, n, b); 185 | break; 186 | case 4: 187 | reduce<<>>(a, n, b); 188 | break; 189 | case 2: 190 | reduce<<>>(a, n, b); 191 | break; 192 | case 1: 193 | reduce<<>>(a, n, b); 194 | break; 195 | } 196 | } 197 | 198 | template 199 | void reduce(const T *a, unsigned int n, T *b) { 200 | unsigned int n_blocks = n_reduce_blocks(n); 201 | if (n_blocks > 1) { 202 | T *buf = (T *) CUDA::buffer(n_blocks*sizeof(T)); 203 | reduce_blocks(a, n, buf); 204 | reduce_blocks(buf, n_blocks, b); 205 | } else { 206 | reduce_blocks(a, n, b); 207 | } 208 | } 209 | 210 | template 211 | void reduce(ReduceOp op, const T *a, unsigned int n, T *b) { 212 | switch (op) { 213 | case MAX_OP: 214 | reduce >(a, n, b); 215 | break; 216 | case MEAN_OP: 217 | reduce >(a, n, b); 218 | binary_scalar(DIV_OP, b, (T) n, 1, b); 219 | break; 220 | case MIN_OP: 221 | reduce >(a, n, b); 222 | break; 223 | case SUM_OP: 224 | reduce >(a, n, b); 225 | break; 226 | } 227 | } 228 | 229 | template void reduce(ReduceOp op, const float *a, unsigned int n, 230 | float *b); 231 | template void reduce(ReduceOp op, const int *a, unsigned int n, 232 | int *b); 233 | 234 | 235 | 236 | #define REDUCE_OP(name, ident_f, ident_i, reduce_op, scale_op, 
select_op) \ 237 | template \ 238 | struct name; \ 239 | template <> \ 240 | struct name { \ 241 | __device__ inline static float identity() { \ 242 | return ident_f; \ 243 | } \ 244 | template \ 245 | __device__ inline static void reduce(volatile Ta a, volatile int idx, \ 246 | volatile Tb &b, volatile int &b_idx) { \ 247 | reduce_op; \ 248 | } \ 249 | template \ 250 | __device__ inline static void scale(volatile Tb &b, volatile float n) { \ 251 | scale_op; \ 252 | } \ 253 | template \ 254 | __device__ inline static void select(volatile Tb &b, volatile Ta a, \ 255 | volatile int idx) { \ 256 | select_op; \ 257 | } \ 258 | }; \ 259 | template <> \ 260 | struct name { \ 261 | __device__ inline static int identity() { \ 262 | return ident_i; \ 263 | } \ 264 | template \ 265 | __device__ inline static void reduce(volatile Ta a, volatile int idx, \ 266 | volatile Tb &b, volatile int &b_idx) { \ 267 | reduce_op; \ 268 | } \ 269 | template \ 270 | __device__ inline static void scale(volatile Tb &b, volatile float n) { \ 271 | scale_op; \ 272 | } \ 273 | template \ 274 | __device__ inline static void select(volatile Tb &b, volatile Ta a, \ 275 | volatile int idx) { \ 276 | select_op; \ 277 | } \ 278 | }; 279 | 280 | REDUCE_OP(max_op, -FLT_MAX, INT_MIN, if (a > b) b = a, , b = a) 281 | REDUCE_OP(mean_op, 0.0f, 0, b += a, b /= n, b = a) 282 | REDUCE_OP(min_op, FLT_MAX, INT_MAX, if (a < b) b = a, , b = a) 283 | REDUCE_OP(sum_op, 0.0f, 0, b += a, , b = a) 284 | REDUCE_OP(argmax_op, -FLT_MAX, INT_MIN, if (a > b) {b = a; b_idx=idx;}, , b = idx) 285 | REDUCE_OP(argmin_op, FLT_MAX, INT_MAX, if (a < b) {b = a; b_idx=idx;}, , b = idx) 286 | 287 | 288 | 289 | template 290 | __global__ void kernel_reduce(const Ta *a, unsigned int n, Tb *b) { 291 | CUDA_GRID_STRIDE_LOOP(idx, 1) { 292 | Ta a_ = Op::identity(); 293 | int idx_ = 0; 294 | for (unsigned int i = 0; i < n; ++i) { 295 | Op::reduce(*a, i, a_, idx_); 296 | ++a; 297 | } 298 | Op::scale(a_, n); 299 | Op::select(*b, a_, idx_); 300 | } 301 | } 302 | 303 | 304 | template 305 | void reduce(const Ta *a, unsigned int n, Tb *b) { 306 | kernel_reduce<<>>(a, n, b); 307 | } 308 | 309 | 310 | 311 | 312 | template 313 | void reduce_to_int(ReduceToIntOp op, const T *a, unsigned int n, int *b) { 314 | switch (op) { 315 | case ARGMAX_OP: 316 | reduce >(a, n, b); 317 | break; 318 | case ARGMIN_OP: 319 | reduce >(a, n, b); 320 | break; 321 | } 322 | } 323 | 324 | template void reduce_to_int(ReduceToIntOp op, const float *a, 325 | unsigned int n, int *b); 326 | template void reduce_to_int(ReduceToIntOp op, const int *a, 327 | unsigned int n, int *b); 328 | 329 | 330 | 331 | 332 | 333 | template 334 | __global__ void kernel_reduce_mat(const Ta *a, unsigned int m, unsigned int n, 335 | Tb *b) { 336 | unsigned int n_threads; 337 | if (reduce_leading) { 338 | n_threads = n; 339 | } else { 340 | n_threads = m; 341 | } 342 | 343 | CUDA_GRID_STRIDE_LOOP(idx, n_threads) { 344 | if (reduce_leading) { 345 | a += idx; 346 | b += idx; 347 | } else { 348 | a += idx * n; 349 | b += idx; 350 | } 351 | 352 | Ta a_ = Op::identity(); 353 | int idx_ = 0; 354 | if (reduce_leading) { 355 | for (unsigned int i = 0; i < m; ++i) { 356 | Op::reduce(*a, i, a_, idx_); 357 | a += n; 358 | } 359 | } else { 360 | for (unsigned int i = 0; i < n; ++i) { 361 | Op::reduce(*a, i, a_, idx_); 362 | ++a; 363 | } 364 | } 365 | 366 | if (reduce_leading) { 367 | Op::scale(a_, m); 368 | } else { 369 | Op::scale(a_, n); 370 | } 371 | Op::select(*b, a_, idx_); 372 | } 373 | } 374 | 375 | template 376 | void 
reduce_mat(const Ta *a, unsigned int m, unsigned int n, 377 | bool reduce_leading, Tb *b) { 378 | if (reduce_leading) { 379 | kernel_reduce_mat<<>> 380 | (a, m, n, b); 381 | } else { 382 | kernel_reduce_mat<<>> 383 | (a, m, n, b); 384 | } 385 | } 386 | 387 | template 388 | void reduce_mat(ReduceOp op, const T *a, unsigned int m, unsigned int n, 389 | bool reduce_leading, T *b) { 390 | switch (op) { 391 | case MAX_OP: 392 | reduce_mat >(a, m, n, reduce_leading, b); 393 | break; 394 | case MEAN_OP: 395 | reduce_mat >(a, m, n, reduce_leading, b); 396 | break; 397 | case MIN_OP: 398 | reduce_mat >(a, m, n, reduce_leading, b); 399 | break; 400 | case SUM_OP: 401 | reduce_mat >(a, m, n, reduce_leading, b); 402 | break; 403 | } 404 | } 405 | 406 | template void reduce_mat(ReduceOp op, const float *a, unsigned int m, 407 | unsigned int n, bool reduce_leading, float *b); 408 | template void reduce_mat(ReduceOp op, const int *a, unsigned int m, 409 | unsigned int n, bool reduce_leading, int *b); 410 | 411 | 412 | template 413 | void reduce_mat_to_int(ReduceToIntOp op, const T *a, unsigned int m, 414 | unsigned int n, bool reduce_leading, int *b) { 415 | switch (op) { 416 | case ARGMAX_OP: 417 | reduce_mat >(a, m, n, reduce_leading, b); 418 | break; 419 | case ARGMIN_OP: 420 | reduce_mat >(a, m, n, reduce_leading, b); 421 | break; 422 | } 423 | } 424 | 425 | template void reduce_mat_to_int(ReduceToIntOp op, const float *a, 426 | unsigned int m, unsigned int n, bool reduce_leading, int *b); 427 | template void reduce_mat_to_int(ReduceToIntOp op, const int *a, 428 | unsigned int m, unsigned int n, bool reduce_leading, int *b); 429 | 430 | } 431 | --------------------------------------------------------------------------------
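Usage note: a minimal host-side sketch of how the single-value reduction entry point in src/reduction.cu might be driven directly from C++. It assumes ReduceOp and MAX_OP are exposed by cudarray/reduction.hpp in the cudarray namespace as used above; the helper name device_max and the unchecked cudaMalloc/cudaMemcpy calls are illustrative assumptions, not part of the library, which normally reaches this code through its Cython wrappers.

// Illustrative only: reduce n device-resident floats to their maximum.
#include <cuda_runtime.h>
#include "cudarray/reduction.hpp"

float device_max(const float *d_a, unsigned int n) {
  float *d_out = NULL;
  cudaMalloc(&d_out, sizeof(float));                  // 1-element output buffer
  cudarray::reduce(cudarray::MAX_OP, d_a, n, d_out);  // writes the max to d_out[0]
  float result;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_out);
  return result;
}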