├── env.lua ├── TH ├── THBlas.c ├── THLapack.c ├── THStorage.c ├── THBlas.h ├── THLapack.h ├── THConfig.cmake.in ├── THMemoryFile.h ├── THLogAdd.h ├── TH.h ├── THStorage.h ├── THDiskFile.h ├── THTensor.c ├── THGenerateFloatTypes.h ├── generic │ ├── THTensorCopy.h │ ├── THTensorCopy.c │ ├── THStorageCopy.h │ ├── THTensorLapack.h │ ├── THBlas.h │ ├── THStorageCopy.c │ ├── THTensorRandom.h │ ├── THLapack.h │ ├── THVector.c │ ├── THStorage.h │ ├── THTensorConv.h │ ├── THStorage.c │ ├── THTensor.h │ ├── THLapack.c │ ├── THTensorMath.h │ └── THTensorRandom.c ├── THAllocator.h ├── THTensor.h ├── THGenerateIntTypes.h ├── THFilePrivate.h ├── THTensorMacros.h ├── THGenerateAllTypes.h ├── THLogAdd.c ├── THGeneral.h.in ├── cmake │ ├── FindSSE.cmake │ ├── FindARM.cmake │ └── FindLAPACK.cmake ├── THRandom.h ├── THGeneral.c ├── THFile.h ├── THFile.c ├── THAllocator.c ├── THRandom.c └── CMakeLists.txt ├── init.lua ├── rocks └── torch-9.scm-1.rockspec ├── CMakeLists.txt ├── README.md ├── dispatch.lua ├── register.lua ├── cmake ├── template.lua ├── TorchTemplate.cmake └── FindLAPACK.cmake ├── registernumbers.lua ├── COPYRIGHT.txt ├── dimapply.lua ├── timer.lua ├── random.lua ├── apply.lua ├── conv.lua ├── lapack.lua ├── storage.lua ├── display.lua ├── serialization.lua ├── tensorop.lua └── memoryfile.lua /env.lua: -------------------------------------------------------------------------------- 1 | local torch = {} 2 | 3 | return torch 4 | -------------------------------------------------------------------------------- /TH/THBlas.c: -------------------------------------------------------------------------------- 1 | #include "THBlas.h" 2 | 3 | #include "generic/THBlas.c" 4 | #include "THGenerateAllTypes.h" 5 | -------------------------------------------------------------------------------- /TH/THLapack.c: -------------------------------------------------------------------------------- 1 | #include "THLapack.h" 2 | 3 | #include "generic/THLapack.c" 4 | #include 
"THGenerateFloatTypes.h" 5 | -------------------------------------------------------------------------------- /TH/THStorage.c: -------------------------------------------------------------------------------- 1 | #include "THStorage.h" 2 | 3 | #include "generic/THStorage.c" 4 | #include "THGenerateAllTypes.h" 5 | 6 | #include "generic/THStorageCopy.c" 7 | #include "THGenerateAllTypes.h" 8 | -------------------------------------------------------------------------------- /TH/THBlas.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_BLAS_INC 2 | #define TH_BLAS_INC 3 | 4 | #include "THGeneral.h" 5 | 6 | #define THBlas_(NAME) TH_CONCAT_4(TH,Real,Blas_,NAME) 7 | 8 | #include "generic/THBlas.h" 9 | #include "THGenerateAllTypes.h" 10 | 11 | #endif 12 | -------------------------------------------------------------------------------- /TH/THLapack.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_LAPACK_INC 2 | #define TH_LAPACK_INC 3 | 4 | #include "THGeneral.h" 5 | 6 | #define THLapack_(NAME) TH_CONCAT_4(TH,Real,Lapack_,NAME) 7 | 8 | #include "generic/THLapack.h" 9 | #include "THGenerateFloatTypes.h" 10 | 11 | #endif 12 | -------------------------------------------------------------------------------- /TH/THConfig.cmake.in: -------------------------------------------------------------------------------- 1 | # Find the TH includes and library 2 | # 3 | # TH_INCLUDE_DIR -- where to find the includes 4 | # TH_LIBRARIES -- list of libraries to link against 5 | # TH_FOUND -- set to 1 if found 6 | 7 | SET(TH_FOUND 1) 8 | SET(TH_INCLUDE_DIR "@TH_INCLUDE_DIR@") 9 | SET(TH_LIBRARIES "@TH_LIBRARIES@") 10 | -------------------------------------------------------------------------------- /TH/THMemoryFile.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_MEMORY_FILE_INC 2 | #define TH_MEMORY_FILE_INC 3 | 4 | #include "THFile.h" 5 | #include 
"THStorage.h" 6 | 7 | TH_API THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode); 8 | TH_API THFile *THMemoryFile_new(const char *mode); 9 | 10 | TH_API THCharStorage *THMemoryFile_storage(THFile *self); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /TH/THLogAdd.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_LOG_ADD_INC 2 | #define TH_LOG_ADD_INC 3 | 4 | #include "THGeneral.h" 5 | 6 | TH_API const double THLog2Pi; 7 | TH_API const double THLogZero; 8 | TH_API const double THLogOne; 9 | 10 | TH_API double THLogAdd(double log_a, double log_b); 11 | TH_API double THLogSub(double log_a, double log_b); 12 | TH_API double THExpMinusApprox(const double x); 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /TH/TH.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_INC 2 | #define TH_INC 3 | 4 | #include "THGeneral.h" 5 | 6 | #include "THBlas.h" 7 | #ifdef USE_LAPACK 8 | #include "THLapack.h" 9 | #endif 10 | 11 | #include "THVector.h" 12 | #include "THLogAdd.h" 13 | #include "THRandom.h" 14 | #include "THStorage.h" 15 | #include "THTensor.h" 16 | #include "THTensorApply.h" 17 | #include "THTensorDimApply.h" 18 | 19 | #include "THFile.h" 20 | #include "THDiskFile.h" 21 | #include "THMemoryFile.h" 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /TH/THStorage.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_STORAGE_INC 2 | #define TH_STORAGE_INC 3 | 4 | #include "THGeneral.h" 5 | #include "THAllocator.h" 6 | 7 | #define THStorage TH_CONCAT_3(TH,Real,Storage) 8 | #define THStorage_(NAME) TH_CONCAT_4(TH,Real,Storage_,NAME) 9 | 10 | /* fast access methods */ 11 | #define TH_STORAGE_GET(storage, idx) ((storage)->data[(idx)]) 12 | #define 
TH_STORAGE_SET(storage, idx, value) ((storage)->data[(idx)] = (value)) 13 | 14 | #include "generic/THStorage.h" 15 | #include "THGenerateAllTypes.h" 16 | 17 | #include "generic/THStorageCopy.h" 18 | #include "THGenerateAllTypes.h" 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /TH/THDiskFile.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_DISK_FILE_INC 2 | #define TH_DISK_FILE_INC 3 | 4 | #include "THFile.h" 5 | 6 | TH_API THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet); 7 | TH_API THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet); 8 | 9 | TH_API const char *THDiskFile_name(THFile *self); 10 | 11 | TH_API int THDiskFile_isLittleEndianCPU(void); 12 | TH_API int THDiskFile_isBigEndianCPU(void); 13 | TH_API void THDiskFile_nativeEndianEncoding(THFile *self); 14 | TH_API void THDiskFile_littleEndianEncoding(THFile *self); 15 | TH_API void THDiskFile_bigEndianEncoding(THFile *self); 16 | 17 | #endif 18 | -------------------------------------------------------------------------------- /TH/THTensor.c: -------------------------------------------------------------------------------- 1 | #include "THTensor.h" 2 | #include "THVector.h" 3 | #include "THBlas.h" 4 | #include "THLapack.h" 5 | #include "THRandom.h" 6 | #include "THTensorDimApply.h" 7 | 8 | #include "generic/THTensor.c" 9 | #include "THGenerateAllTypes.h" 10 | 11 | #include "generic/THTensorCopy.c" 12 | #include "THGenerateAllTypes.h" 13 | 14 | #include "generic/THTensorRandom.c" 15 | #include "THGenerateAllTypes.h" 16 | 17 | #include "generic/THTensorMath.c" 18 | #include "THGenerateAllTypes.h" 19 | 20 | #include "generic/THTensorConv.c" 21 | #include "THGenerateAllTypes.h" 22 | 23 | #include "generic/THTensorLapack.c" 24 | #include "THGenerateFloatTypes.h" 25 | -------------------------------------------------------------------------------- 
/TH/THGenerateFloatTypes.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h" 3 | #endif 4 | 5 | #define real float 6 | #define accreal double 7 | #define Real Float 8 | #define TH_REAL_IS_FLOAT 9 | #line 1 TH_GENERIC_FILE 10 | #include TH_GENERIC_FILE 11 | #undef accreal 12 | #undef real 13 | #undef Real 14 | #undef TH_REAL_IS_FLOAT 15 | 16 | #define real double 17 | #define accreal double 18 | #define Real Double 19 | #define TH_REAL_IS_DOUBLE 20 | #line 1 TH_GENERIC_FILE 21 | #include TH_GENERIC_FILE 22 | #undef accreal 23 | #undef real 24 | #undef Real 25 | #undef TH_REAL_IS_DOUBLE 26 | 27 | #undef TH_GENERIC_FILE 28 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | if not jit then 2 | error('FATAL: torch9 is luajit *only*') 3 | end 4 | 5 | local ffi = require 'ffi' 6 | 7 | ffi.cdef[[ 8 | void free(void *ptr); 9 | void *malloc(size_t size); 10 | void *realloc(void *ptr, size_t size); 11 | typedef unsigned char byte; 12 | ]] 13 | 14 | require 'torch.timer' 15 | 16 | require 'torch.storage' 17 | require 'torch.tensor' 18 | 19 | require 'torch.apply' 20 | require 'torch.dimapply' 21 | require 'torch.maths' 22 | require 'torch.lapack' 23 | require 'torch.conv' 24 | require 'torch.tensorop' 25 | require 'torch.random' 26 | 27 | require 'torch.file' 28 | require 'torch.diskfile' 29 | require 'torch.memoryfile' 30 | 31 | require 'torch.serialization' 32 | 33 | local torch = require 'torch.env' 34 | 35 | torch.Tensor = torch.DoubleTensor 36 | 37 | return torch 38 | -------------------------------------------------------------------------------- /TH/generic/THTensorCopy.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define 
TH_GENERIC_FILE "generic/THTensorCopy.h" 3 | #else 4 | 5 | /* Support for copy between different Tensor types */ 6 | 7 | TH_API void THTensor_(copy)(THTensor *tensor, THTensor *src); 8 | TH_API void THTensor_(copyByte)(THTensor *tensor, struct THByteTensor *src); 9 | TH_API void THTensor_(copyChar)(THTensor *tensor, struct THCharTensor *src); 10 | TH_API void THTensor_(copyShort)(THTensor *tensor, struct THShortTensor *src); 11 | TH_API void THTensor_(copyInt)(THTensor *tensor, struct THIntTensor *src); 12 | TH_API void THTensor_(copyLong)(THTensor *tensor, struct THLongTensor *src); 13 | TH_API void THTensor_(copyFloat)(THTensor *tensor, struct THFloatTensor *src); 14 | TH_API void THTensor_(copyDouble)(THTensor *tensor, struct THDoubleTensor *src); 15 | 16 | #endif 17 | -------------------------------------------------------------------------------- /TH/generic/THTensorCopy.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THTensorCopy.c" 3 | #else 4 | 5 | void THTensor_(copy)(THTensor *tensor, THTensor *src) 6 | { 7 | TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = (real)(*src_data);) 8 | } 9 | 10 | #define IMPLEMENT_THTensor_COPY(TYPENAMESRC, TYPE_SRC) \ 11 | void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \ 12 | { \ 13 | TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)(*src_data);) \ 14 | } 15 | 16 | IMPLEMENT_THTensor_COPY(Byte, unsigned char) 17 | IMPLEMENT_THTensor_COPY(Char, char) 18 | IMPLEMENT_THTensor_COPY(Short, short) 19 | IMPLEMENT_THTensor_COPY(Int, int) 20 | IMPLEMENT_THTensor_COPY(Long, long) 21 | IMPLEMENT_THTensor_COPY(Float, float) 22 | IMPLEMENT_THTensor_COPY(Double, double) 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /TH/THAllocator.h: -------------------------------------------------------------------------------- 1 | #ifndef 
TH_ALLOCATOR_INC 2 | #define TH_ALLOCATOR_INC 3 | 4 | #include "THGeneral.h" 5 | 6 | /* Custom allocator 7 | */ 8 | typedef struct THAllocator { 9 | void* (*malloc)(void*, long); 10 | void* (*realloc)(void*, void*, long); 11 | void (*free)(void*, void*); 12 | } THAllocator; 13 | 14 | /* default malloc/free allocator. malloc and realloc raise an error (using 15 | * THError) on allocation failure. 16 | */ 17 | extern THAllocator THDefaultAllocator; 18 | 19 | /* file map allocator 20 | */ 21 | typedef struct THMapAllocatorContext_ THMapAllocatorContext; 22 | THMapAllocatorContext *THMapAllocatorContext_new(const char *filename, int shared); 23 | long THMapAllocatorContext_size(THMapAllocatorContext *ctx); 24 | void THMapAllocatorContext_free(THMapAllocatorContext *ctx); 25 | 26 | extern THAllocator THMapAllocator; 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /TH/THTensor.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_TENSOR_INC 2 | #define TH_TENSOR_INC 3 | 4 | #include "THStorage.h" 5 | #include "THTensorApply.h" 6 | 7 | #define THTensor TH_CONCAT_3(TH,Real,Tensor) 8 | #define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME) 9 | 10 | /* basics */ 11 | #include "generic/THTensor.h" 12 | #include "THGenerateAllTypes.h" 13 | 14 | #include "generic/THTensorCopy.h" 15 | #include "THGenerateAllTypes.h" 16 | 17 | #include "THTensorMacros.h" 18 | 19 | /* random numbers */ 20 | #include "THRandom.h" 21 | #include "generic/THTensorRandom.h" 22 | #include "THGenerateAllTypes.h" 23 | 24 | /* maths */ 25 | #include "generic/THTensorMath.h" 26 | #include "THGenerateAllTypes.h" 27 | 28 | /* convolutions */ 29 | #include "generic/THTensorConv.h" 30 | #include "THGenerateAllTypes.h" 31 | 32 | /* lapack support */ 33 | #include "generic/THTensorLapack.h" 34 | #include "THGenerateFloatTypes.h" 35 | 36 | #endif 37 | 
-------------------------------------------------------------------------------- /TH/generic/THStorageCopy.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THStorageCopy.h" 3 | #else 4 | 5 | /* Support for copy between different Storage types */ 6 | 7 | TH_API void THStorage_(rawCopy)(THStorage *storage, real *src); 8 | TH_API void THStorage_(copy)(THStorage *storage, THStorage *src); 9 | TH_API void THStorage_(copyByte)(THStorage *storage, struct THByteStorage *src); 10 | TH_API void THStorage_(copyChar)(THStorage *storage, struct THCharStorage *src); 11 | TH_API void THStorage_(copyShort)(THStorage *storage, struct THShortStorage *src); 12 | TH_API void THStorage_(copyInt)(THStorage *storage, struct THIntStorage *src); 13 | TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src); 14 | TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src); 15 | TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src); 16 | 17 | #endif 18 | -------------------------------------------------------------------------------- /TH/generic/THTensorLapack.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THTensorLapack.h" 3 | #else 4 | 5 | TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); 6 | TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); 7 | TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobz, const char *uplo); 8 | TH_API void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobvr); 9 | TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char *jobu); 10 | TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor 
*ra_, THTensor *a, const char *jobu); 11 | TH_API void THTensor_(getri)(THTensor *ra_, THTensor *a); 12 | TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a); 13 | TH_API void THTensor_(potrf)(THTensor *ra_, THTensor *a); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /rocks/torch-9.scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "torch" 2 | version = "9.scm-1" 3 | 4 | source = { 5 | url = "git://github.com/andresy/torch9.git", 6 | } 7 | 8 | description = { 9 | summary = "Torch9", 10 | detailed = [[ 11 | Torch9 provides a Matlab-like environment for state-of-the-art machine 12 | learning algorithms. It provides a very efficient implementation, thanks 13 | to Luajit and few C code lines for critical inner loops. 14 | ]], 15 | homepage = "http://www.torch.ch", 16 | license = "BSD" 17 | } 18 | 19 | dependencies = { 20 | "lua >= 5.1", 21 | "argcheck >= 1", 22 | "class >= 1" 23 | } 24 | 25 | build = { 26 | type = "command", 27 | build_command = "cmake -E make_directory build && cd build && cmake -DCMAKE_INSTALL_PREFIX=$(PREFIX) -DCMAKE_BUILD_TYPE=Release -DLUA_PATH_DIR=$(LUADIR)/torch -DLUA_CPATH_DIR=$(LIBDIR) -DLUA_EXECUTABLE=$(LUA) .. 
&& $(MAKE)", 28 | install_command = "cd build && $(MAKE) install" 29 | } 30 | -------------------------------------------------------------------------------- /TH/generic/THBlas.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THBlas.h" 3 | #else 4 | 5 | /* Level 1 */ 6 | TH_API void THBlas_(swap)(long n, real *x, long incx, real *y, long incy); 7 | TH_API void THBlas_(scal)(long n, real a, real *x, long incx); 8 | TH_API void THBlas_(copy)(long n, real *x, long incx, real *y, long incy); 9 | TH_API void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy); 10 | TH_API real THBlas_(dot)(long n, real *x, long incx, real *y, long incy); 11 | 12 | /* Level 2 */ 13 | TH_API void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, real *x, long incx, real beta, real *y, long incy); 14 | TH_API void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long incy, real *a, long lda); 15 | 16 | /* Level 3 */ 17 | TH_API void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc); 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6) 2 | 3 | LIST(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") 4 | 5 | INCLUDE(TorchTemplate) 6 | 7 | SET(luasrc env.lua init.lua register.lua registernumbers.lua timer.lua apply.lua dimapply.lua 8 | display.lua random.lua file.lua diskfile.lua memoryfile.lua 9 | serialization.lua TH.lua) 10 | 11 | ADD_TORCH_TEMPLATE("storage.lua" luasrc "torch") 12 | ADD_TORCH_TEMPLATE("tensor.lua" luasrc "torch") 13 | ADD_TORCH_TEMPLATE("maths.lua" luasrc "torch") 14 | ADD_TORCH_TEMPLATE("lapack.lua" luasrc "torch") 15 | 
ADD_TORCH_TEMPLATE("conv.lua" luasrc "torch") 16 | ADD_TORCH_TEMPLATE("tensorop.lua" luasrc "torch") 17 | 18 | INSTALL(FILES ${luasrc} 19 | DESTINATION "${LUA_PATH_DIR}") 20 | 21 | # TH ### 22 | 23 | FILE(RELATIVE_PATH TH_INSTALL_BIN_SUBDIR "${CMAKE_INSTALL_PREFIX}" "${LUA_CPATH_DIR}") 24 | FILE(RELATIVE_PATH TH_INSTALL_LIB_SUBDIR "${CMAKE_INSTALL_PREFIX}" "${LUA_CPATH_DIR}") 25 | SET(TH_INSTALL_INCLUDE_SUBDIR "include") 26 | SET(TH_INSTALL_CMAKE_SUBDIR "cmake/TH") 27 | 28 | ADD_SUBDIRECTORY(TH) 29 | -------------------------------------------------------------------------------- /TH/generic/THStorageCopy.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THStorageCopy.c" 3 | #else 4 | 5 | void THStorage_(rawCopy)(THStorage *storage, real *src) 6 | { 7 | long i; 8 | for(i = 0; i < storage->size; i++) 9 | storage->data[i] = src[i]; 10 | } 11 | 12 | void THStorage_(copy)(THStorage *storage, THStorage *src) 13 | { 14 | THArgCheck(storage->size == src->size, 2, "size mismatch"); 15 | THStorage_(rawCopy)(storage, src->data); 16 | } 17 | 18 | 19 | #define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \ 20 | void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ 21 | { \ 22 | long i; \ 23 | THArgCheck(storage->size == src->size, 2, "size mismatch"); \ 24 | for(i = 0; i < storage->size; i++) \ 25 | storage->data[i] = (real)src->data[i]; \ 26 | } 27 | 28 | IMPLEMENT_THStorage_COPY(Byte) 29 | IMPLEMENT_THStorage_COPY(Char) 30 | IMPLEMENT_THStorage_COPY(Short) 31 | IMPLEMENT_THStorage_COPY(Int) 32 | IMPLEMENT_THStorage_COPY(Long) 33 | IMPLEMENT_THStorage_COPY(Float) 34 | IMPLEMENT_THStorage_COPY(Double) 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /TH/generic/THTensorRandom.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define 
TH_GENERIC_FILE "generic/THTensorRandom.h" 3 | #else 4 | 5 | TH_API void THTensor_(random)(THTensor *self, THGenerator *_generator); 6 | TH_API void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p); 7 | TH_API void THTensor_(bernoulli)(THTensor *self, THGenerator *_generator, double p); 8 | 9 | #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) 10 | TH_API void THTensor_(uniform)(THTensor *self, THGenerator *_generator, double a, double b); 11 | TH_API void THTensor_(normal)(THTensor *self, THGenerator *_generator, double mean, double stdv); 12 | TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lambda); 13 | TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma); 14 | TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv); 15 | TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement); 16 | #endif 17 | 18 | #if defined(TH_REAL_IS_LONG) 19 | TH_API void THTensor_(getRNGState)(THGenerator *_generator, THTensor *self); 20 | TH_API void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self); 21 | #endif 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Torch9 Core Library. 2 | =============== 3 | 4 | Torch9 provides a Matlab-like environment for state-of-the-art machine 5 | learning algorithms. It is easy to use and provides a very efficient 6 | implementation, thanks to an easy and fast scripting language (Luajit) and 7 | few C code lines for critical inner loops. 8 | 9 | This package provides the core of the Torch9 distribution. 10 | 11 | Note: Torch9 is still beta. 
12 | 13 | Installation 14 | ------------ 15 | 16 | ### Requirements 17 | 18 | * C compiler 19 | * [luajit](http://www.luajit.org) 20 | * [cmake](http://www.cmake.org) 21 | * git 22 | * [luarocks](http://www.luarocks.org) 23 | 24 | ### Getting the latest version from git: 25 | 26 | ```sh 27 | # first get the argcheck dependency 28 | luarocks build https://raw.github.com/andresy/argcheck/master/rocks/argcheck-scm-1.rockspec 29 | 30 | # now get torch 31 | luarocks build https://raw.github.com/andresy/torch9/master/rocks/torch-9.scm-1.rockspec 32 | ``` 33 | 34 | Running 35 | ======= 36 | ```sh 37 | $ luajit -ltorch 38 | Torch 9.0 -- Copyright (C) 2001-2013 Idiap, NEC Labs, NYU. http://www.torch.ch/ 39 | LuaJIT 2.0.0 -- Copyright (C) 2005-2012 Mike Pall. http://luajit.org/ 40 | JIT: ON CMOV SSE2 SSE3 SSE4.1 fold cse dce fwd dse narrow loop abc sink fuse 41 | > 42 | ``` 43 | 44 | Documentation 45 | ============= 46 | 47 | Coming soon. 48 | 49 | 50 | -------------------------------------------------------------------------------- /TH/THGenerateIntTypes.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #error "You must define TH_GENERIC_FILE before including THGenerateIntTypes.h" 3 | #endif 4 | 5 | #define real unsigned char 6 | #define accreal long 7 | #define Real Byte 8 | #define TH_REAL_IS_BYTE 9 | #line 1 TH_GENERIC_FILE 10 | #include TH_GENERIC_FILE 11 | #undef real 12 | #undef accreal 13 | #undef Real 14 | #undef TH_REAL_IS_BYTE 15 | 16 | #define real char 17 | #define accreal long 18 | #define Real Char 19 | #define TH_REAL_IS_CHAR 20 | #line 1 TH_GENERIC_FILE 21 | #include TH_GENERIC_FILE 22 | #undef real 23 | #undef accreal 24 | #undef Real 25 | #undef TH_REAL_IS_CHAR 26 | 27 | #define real short 28 | #define accreal long 29 | #define Real Short 30 | #define TH_REAL_IS_SHORT 31 | #line 1 TH_GENERIC_FILE 32 | #include TH_GENERIC_FILE 33 | #undef real 34 | #undef accreal 35 | #undef Real 36 | 
#undef TH_REAL_IS_SHORT 37 | 38 | #define real int 39 | #define accreal long 40 | #define Real Int 41 | #define TH_REAL_IS_INT 42 | #line 1 TH_GENERIC_FILE 43 | #include TH_GENERIC_FILE 44 | #undef real 45 | #undef accreal 46 | #undef Real 47 | #undef TH_REAL_IS_INT 48 | 49 | #define real long 50 | #define accreal long 51 | #define Real Long 52 | #define TH_REAL_IS_LONG 53 | #line 1 TH_GENERIC_FILE 54 | #include TH_GENERIC_FILE 55 | #undef real 56 | #undef accreal 57 | #undef Real 58 | #undef TH_REAL_IS_LONG 59 | 60 | #undef TH_GENERIC_FILE 61 | -------------------------------------------------------------------------------- /dispatch.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local dispatch = 4 | argcheck{ 5 | {{name="idx", type="number", default=1}}, 6 | function(idx) 7 | local env = {idx=idx, funcs={}, type=type, select=select, error=error, string=string} 8 | local func = 9 | function(...) 10 | local typename = type(select(idx, ...)) 11 | local func = funcs[typename] 12 | if not func then 13 | error(string.format('function not implemented for type <%s>', typename)) 14 | else 15 | return func(...) 16 | end 17 | end 18 | setfenv(func, env) 19 | return func 20 | end, 21 | 22 | {{name="idxfunc", type="function"}}, 23 | function(idxfunc) -- ARG, the idxfunc might be the same for several functions... 24 | local env = {funcs={}} 25 | setmetatable(env, {__index=getfenv(1)}) 26 | local func = 27 | function(...) 28 | setfenv(idxfunc, env) -- ... workaround 29 | return idxfunc(...) 
30 | end 31 | setfenv(func, env) -- because that is how we define new dispatch functions below 32 | return func 33 | end, 34 | 35 | {{name="func", type="function"}, 36 | {name="typename", type="string"}, 37 | {name="functypename", type="function"}}, 38 | function(func, typename, functypename) 39 | getfenv(func).funcs[typename] = functypename 40 | end 41 | } 42 | 43 | return dispatch 44 | -------------------------------------------------------------------------------- /register.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | 3 | local function tablecopyarg(tbl, method) 4 | local newtbl = {} 5 | for k,v in pairs(tbl) do 6 | if k ~= 'method' then 7 | newtbl[k] = v 8 | end 9 | end 10 | if method and tbl.method then 11 | for k,v in pairs(tbl.method) do 12 | newtbl[k] = v 13 | end 14 | end 15 | return newtbl 16 | end 17 | 18 | local function tablecopy(tbl, method) 19 | local newtbl = {} 20 | for k,v in pairs(tbl) do 21 | if k ~= 'name' then 22 | if type(k) == 'number' and type(v) == 'table' then 23 | newtbl[k] = tablecopyarg(v, method) 24 | else 25 | newtbl[k] = v 26 | end 27 | end 28 | end 29 | return newtbl 30 | end 31 | 32 | local function register(args, namespace, metatable) 33 | local name = args.name 34 | 35 | assert(name, 'missing function name') 36 | 37 | if namespace then 38 | local args_f = tablecopy(args) 39 | if namespace[name] then 40 | args_f.chain = namespace[name] 41 | end 42 | namespace[name] = argcheck(args_f) 43 | end 44 | 45 | if metatable then 46 | local args_m = tablecopy(args, true) 47 | if args_m[1] then 48 | args_m[1].name = "self" 49 | end 50 | if metatable[name] then 51 | args_m.chain = metatable[name] 52 | end 53 | metatable[name] = argcheck(args_m) 54 | end 55 | 56 | end 57 | 58 | return register 59 | -------------------------------------------------------------------------------- /TH/THFilePrivate.h: 
-------------------------------------------------------------------------------- 1 | struct THFile__ 2 | { 3 | struct THFileVTable *vtable; 4 | 5 | int isQuiet; 6 | int isReadable; 7 | int isWritable; 8 | int isBinary; 9 | int isAutoSpacing; 10 | int hasError; 11 | }; 12 | 13 | /* virtual table definition */ 14 | 15 | struct THFileVTable 16 | { 17 | int (*isOpened)(THFile *self); 18 | 19 | long (*readByte)(THFile *self, unsigned char *data, long n); 20 | long (*readChar)(THFile *self, char *data, long n); 21 | long (*readShort)(THFile *self, short *data, long n); 22 | long (*readInt)(THFile *self, int *data, long n); 23 | long (*readLong)(THFile *self, long *data, long n); 24 | long (*readFloat)(THFile *self, float *data, long n); 25 | long (*readDouble)(THFile *self, double *data, long n); 26 | long (*readString)(THFile *self, const char *format, char **str_); 27 | 28 | long (*writeByte)(THFile *self, unsigned char *data, long n); 29 | long (*writeChar)(THFile *self, char *data, long n); 30 | long (*writeShort)(THFile *self, short *data, long n); 31 | long (*writeInt)(THFile *self, int *data, long n); 32 | long (*writeLong)(THFile *self, long *data, long n); 33 | long (*writeFloat)(THFile *self, float *data, long n); 34 | long (*writeDouble)(THFile *self, double *data, long n); 35 | long (*writeString)(THFile *self, const char *str, long size); 36 | 37 | void (*synchronize)(THFile *self); 38 | void (*seek)(THFile *self, long position); 39 | void (*seekEnd)(THFile *self); 40 | long (*position)(THFile *self); 41 | void (*close)(THFile *self); 42 | void (*free)(THFile *self); 43 | }; 44 | -------------------------------------------------------------------------------- /TH/generic/THLapack.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THLapack.h" 3 | #else 4 | 5 | /* AX=B */ 6 | TH_API void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int 
ldb, int* info); 7 | /* ||AX-B|| */ 8 | TH_API void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info); 9 | /* Eigenvals */ 10 | TH_API void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info); 11 | /* Non-sym eigenvals */ 12 | TH_API void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr, real *wi, real* vl, int ldvl, real *vr, int ldvr, real *work, int lwork, int *info); 13 | /* svd */ 14 | TH_API void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info); 15 | /* LU decomposition */ 16 | TH_API void THLapack_(getrf)(int m, int n, real *a, int lda, int *ipiv, int *info); 17 | /* Matrix Inverse */ 18 | TH_API void THLapack_(getri)(int n, real *a, int lda, int *ipiv, real *work, int lwork, int* info); 19 | 20 | /* Positive Definite matrices */ 21 | /* Cholesky factorization */ 22 | TH_API void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info); 23 | /* Matrix inverse based on Cholesky factorization */ 24 | TH_API void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info); 25 | /* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */ 26 | TH_API void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info); 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /TH/THTensorMacros.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_TENSOR_MACROS_INC 2 | #define TH_TENSOR_MACROS_INC 3 | 4 | /* fast method to access to tensor data */ 5 | 6 | #define THTensor_fastGet1d(self, x0) \ 7 | (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]]) 8 | 9 | #define THTensor_fastGet2d(self, x0, x1) \ 10 | 
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]]) 11 | 12 | #define THTensor_fastGet3d(self, x0, x1, x2) \ 13 | (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]]) 14 | 15 | #define THTensor_fastGet4d(self, x0, x1, x2, x3) \ 16 | (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]]) 17 | 18 | #define THTensor_fastSet1d(self, x0, value) \ 19 | (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]] = value) 20 | 21 | #define THTensor_fastSet2d(self, x0, x1, value) \ 22 | (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]] = value) 23 | 24 | #define THTensor_fastSet3d(self, x0, x1, x2, value) \ 25 | (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]] = value) 26 | 27 | #define THTensor_fastSet4d(self, x0, x1, x2, x3, value) \ 28 | (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]] = value) 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /TH/generic/THVector.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THVector.c" 3 | #else 4 | 5 | static TH_INLINE void THVector_(fill)(real *x, const real c, const long n) { 6 | long i = 0; 7 | 8 | for(; i < n-4; i += 4) 9 | { 10 | x[i] = c; 11 | x[i+1] = c; 12 | x[i+2] = c; 13 | x[i+3] = c; 14 | } 15 | 16 | for(; i < n; i++) 17 | x[i] = c; 18 | } 19 | 20 | static TH_INLINE void THVector_(add)(real *y, const real *x, const real c, const long n) 21 | { 22 | long i = 0; 23 | 24 | for(;i < n-4; i += 4) 25 | { 26 | y[i] += c * x[i]; 27 | y[i+1] += c * x[i+1]; 28 | y[i+2] += c * 
x[i+2]; 29 | y[i+3] += c * x[i+3]; 30 | } 31 | 32 | for(; i < n; i++) 33 | y[i] += c * x[i]; 34 | } 35 | 36 | static TH_INLINE void THVector_(diff)(real *z, const real *x, const real *y, const long n) 37 | { 38 | long i = 0; 39 | 40 | for(; i < n-4; i += 4) 41 | { 42 | z[i] = x[i] - y[i]; 43 | z[i+1] = x[i+1] - y[i+1]; 44 | z[i+2] = x[i+2] - y[i+2]; 45 | z[i+3] = x[i+3] - y[i+3]; 46 | } 47 | 48 | for(; i < n; i++) 49 | z[i] = x[i] - y[i]; 50 | } 51 | 52 | static TH_INLINE void THVector_(scale)(real *y, const real c, const long n) 53 | { 54 | long i = 0; 55 | 56 | for(; i < n-4; i +=4) 57 | { 58 | y[i] *= c; 59 | y[i+1] *= c; 60 | y[i+2] *= c; 61 | y[i+3] *= c; 62 | } 63 | 64 | for(; i < n; i++) 65 | y[i] *= c; 66 | } 67 | 68 | static TH_INLINE void THVector_(mul)(real *y, const real *x, const long n) 69 | { 70 | long i = 0; 71 | 72 | for(; i < n-4; i += 4) 73 | { 74 | y[i] *= x[i]; 75 | y[i+1] *= x[i+1]; 76 | y[i+2] *= x[i+2]; 77 | y[i+3] *= x[i+3]; 78 | } 79 | 80 | for(; i < n; i++) 81 | y[i] *= x[i]; 82 | } 83 | 84 | #endif 85 | -------------------------------------------------------------------------------- /cmake/template.lua: -------------------------------------------------------------------------------- 1 | local src = arg[1] 2 | local dst = arg[2] 3 | 4 | local types = {'byte', 'char', 'short', 'int', 'long', 'float', 'double'} 5 | local Types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'} 6 | local taccs = {'long', 'long', 'long', 'long', 'long', 'double', 'double'} 7 | 8 | local f = io.open(src) 9 | local txt = f:read('*all') 10 | f:close() 11 | 12 | for i=1,#types do 13 | local real, Real, accreal = types[i], Types[i], taccs[i] 14 | local txt = txt 15 | 16 | while txt:match('([%p%s])real([%p%s])') do 17 | txt = txt:gsub('([%p%s])real([%p%s])', '%1' .. real .. '%2') 18 | end 19 | 20 | while txt:match('([%p%s])accreal([%p%s])') do 21 | txt = txt:gsub('([%p%s])accreal([%p%s])', '%1' .. accreal .. 
'%2') 22 | end 23 | 24 | txt = txt:gsub('Real', Real) 25 | 26 | local dst = dst:gsub('(%.[^%.]+)$', '_' .. real .. '%1') 27 | assert(dst ~= src, 'source and destination are same') 28 | local f = io.open(dst, 'w') 29 | f:write(txt) 30 | f:close() 31 | end 32 | 33 | local basename, ext = dst:match('([^/\\]+)(%.[^%.]+)$') 34 | if not basename or not ext then 35 | error('could not determine destination file basename/extension') 36 | end 37 | 38 | local txt = {} 39 | if ext == '.lua' then 40 | local module = arg[3] 41 | assert(module, 'module name should be provided') 42 | for i=1,#types do 43 | table.insert(txt, string.format("require '%s.%s_%s'", module, basename, types[i])) 44 | end 45 | else 46 | for i=1,#types do 47 | table.insert(txt, string.format('#include "%s_%s%s"', basename, types[i], ext)) 48 | end 49 | end 50 | assert(dst ~= src, 'source and destination are same') 51 | local f = io.open(dst, 'w') 52 | f:write(table.concat(txt, '\n')) 53 | f:close() 54 | -------------------------------------------------------------------------------- /TH/THGenerateAllTypes.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h" 3 | #endif 4 | 5 | #define real unsigned char 6 | #define accreal long 7 | #define Real Byte 8 | #define TH_REAL_IS_BYTE 9 | #line 1 TH_GENERIC_FILE 10 | /*#line 1 "THByteStorage.h"*/ 11 | #include TH_GENERIC_FILE 12 | #undef real 13 | #undef accreal 14 | #undef Real 15 | #undef TH_REAL_IS_BYTE 16 | 17 | #define real char 18 | #define accreal long 19 | #define Real Char 20 | #define TH_REAL_IS_CHAR 21 | #line 1 TH_GENERIC_FILE 22 | #include TH_GENERIC_FILE 23 | #undef real 24 | #undef accreal 25 | #undef Real 26 | #undef TH_REAL_IS_CHAR 27 | 28 | #define real short 29 | #define accreal long 30 | #define Real Short 31 | #define TH_REAL_IS_SHORT 32 | #line 1 TH_GENERIC_FILE 33 | #include TH_GENERIC_FILE 34 | #undef 
real 35 | #undef accreal 36 | #undef Real 37 | #undef TH_REAL_IS_SHORT 38 | 39 | #define real int 40 | #define accreal long 41 | #define Real Int 42 | #define TH_REAL_IS_INT 43 | #line 1 TH_GENERIC_FILE 44 | #include TH_GENERIC_FILE 45 | #undef real 46 | #undef accreal 47 | #undef Real 48 | #undef TH_REAL_IS_INT 49 | 50 | #define real long 51 | #define accreal long 52 | #define Real Long 53 | #define TH_REAL_IS_LONG 54 | #line 1 TH_GENERIC_FILE 55 | #include TH_GENERIC_FILE 56 | #undef real 57 | #undef accreal 58 | #undef Real 59 | #undef TH_REAL_IS_LONG 60 | 61 | #define real float 62 | #define accreal double 63 | #define Real Float 64 | #define TH_REAL_IS_FLOAT 65 | #line 1 TH_GENERIC_FILE 66 | #include TH_GENERIC_FILE 67 | #undef real 68 | #undef accreal 69 | #undef Real 70 | #undef TH_REAL_IS_FLOAT 71 | 72 | #define real double 73 | #define accreal double 74 | #define Real Double 75 | #define TH_REAL_IS_DOUBLE 76 | #line 1 TH_GENERIC_FILE 77 | #include TH_GENERIC_FILE 78 | #undef real 79 | #undef accreal 80 | #undef Real 81 | #undef TH_REAL_IS_DOUBLE 82 | 83 | #undef TH_GENERIC_FILE 84 | -------------------------------------------------------------------------------- /TH/THLogAdd.c: -------------------------------------------------------------------------------- 1 | #include "THLogAdd.h" 2 | 3 | #ifdef USE_DOUBLE 4 | #define MINUS_LOG_THRESHOLD -39.14 5 | #else 6 | #define MINUS_LOG_THRESHOLD -18.42 7 | #endif 8 | 9 | const double THLog2Pi=1.83787706640934548355; 10 | const double THLogZero=-THInf; 11 | const double THLogOne=0; 12 | 13 | double THLogAdd(double log_a, double log_b) 14 | { 15 | double minusdif; 16 | 17 | if (log_a < log_b) 18 | { 19 | double tmp = log_a; 20 | log_a = log_b; 21 | log_b = tmp; 22 | } 23 | 24 | minusdif = log_b - log_a; 25 | #ifdef DEBUG 26 | if (isnan(minusdif)) 27 | THError("THLogAdd: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a); 28 | #endif 29 | if (minusdif < MINUS_LOG_THRESHOLD) 30 | return log_a; 31 
| else 32 | return log_a + log1p(exp(minusdif)); 33 | } 34 | 35 | double THLogSub(double log_a, double log_b) 36 | { 37 | double minusdif; 38 | 39 | if (log_a < log_b) 40 | THError("LogSub: log_a (%f) should be greater than log_b (%f)", log_a, log_b); 41 | 42 | minusdif = log_b - log_a; 43 | #ifdef DEBUG 44 | if (isnan(minusdif)) 45 | THError("LogSub: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a); 46 | #endif 47 | if (log_a == log_b) 48 | return THLogZero; 49 | else if (minusdif < MINUS_LOG_THRESHOLD) 50 | return log_a; 51 | else 52 | return log_a + log1p(-exp(minusdif)); 53 | } 54 | 55 | /* Credits to Leon Bottou */ 56 | double THExpMinusApprox(double x) 57 | { 58 | #define EXACT_EXPONENTIAL 0 59 | #if EXACT_EXPONENTIAL 60 | return exp(-x); 61 | #else 62 | /* fast approximation of exp(-x) for x positive */ 63 | # define A0 (1.0) 64 | # define A1 (0.125) 65 | # define A2 (0.0078125) 66 | # define A3 (0.00032552083) 67 | # define A4 (1.0172526e-5) 68 | if (x < 13.0) 69 | { 70 | /* assert(x>=0); */ 71 | double y; 72 | y = A0+x*(A1+x*(A2+x*(A3+x*A4))); 73 | y *= y; 74 | y *= y; 75 | y *= y; 76 | y = 1/y; 77 | return y; 78 | } 79 | return 0; 80 | # undef A0 81 | # undef A1 82 | # undef A2 83 | # undef A3 84 | # undef A4 85 | #endif 86 | } 87 | -------------------------------------------------------------------------------- /TH/THGeneral.h.in: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERAL_INC 2 | #define TH_GENERAL_INC 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #cmakedefine USE_BLAS 14 | #cmakedefine USE_LAPACK 15 | #cmakedefine BLAS_IS_ACCELERATE 16 | #cmakedefine BLAS_F2C 17 | 18 | #ifdef __cplusplus 19 | # define TH_EXTERNC extern "C" 20 | #else 21 | # define TH_EXTERNC extern 22 | #endif 23 | 24 | #ifdef WIN32 25 | # ifdef TH_EXPORTS 26 | # define TH_API TH_EXTERNC __declspec(dllexport) 27 | # else 28 | # 
define TH_API TH_EXTERNC __declspec(dllimport) 29 | # endif 30 | #else 31 | # define TH_API TH_EXTERNC 32 | #endif 33 | 34 | #define THInf DBL_MAX 35 | 36 | #define TH_INLINE @TH_INLINE@ 37 | 38 | #ifndef __cplusplus 39 | #define inline @TH_INLINE@ 40 | #endif 41 | 42 | #ifndef M_PI 43 | # define M_PI 3.14159265358979323846 44 | #endif 45 | 46 | TH_API double THLog1p(const double x); 47 | TH_API void THError(const char *fmt, ...); 48 | TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg, void *data), void *data ); 49 | TH_API void THArgCheck(int condition, int argNumber, const char *msg); 50 | TH_API void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data), void *data ); 51 | TH_API void* THAlloc(long size); 52 | TH_API void* THRealloc(void *ptr, long size); 53 | TH_API void THFree(void *ptr); 54 | 55 | #define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y) 56 | #define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y 57 | 58 | #define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z) 59 | #define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z 60 | 61 | #define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w) 62 | #define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w 63 | 64 | #define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y) 65 | #define TH_CONCAT_2_EXPAND(x,y) x ## y 66 | 67 | #define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z) 68 | #define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z 69 | 70 | #define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w 71 | #define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w) 72 | 73 | #define THMin(X, Y) ((X) < (Y) ? (X) : (Y)) 74 | #define THMax(X, Y) ((X) > (Y) ? 
(X) : (Y)) 75 | 76 | #ifdef _MSC_VER 77 | # define log1p(x) THLog1p(x) 78 | #define snprintf _snprintf 79 | #define popen _popen 80 | #define pclose _pclose 81 | #endif 82 | 83 | #endif 84 | -------------------------------------------------------------------------------- /TH/generic/THStorage.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THStorage.h" 3 | #else 4 | 5 | /* on pourrait avoir un liste chainee 6 | qui initialise math, lab structures (or more). 7 | mouais -- complique. 8 | 9 | Pb: THMapStorage is kind of a class 10 | THLab_()... comment je m'en sors? 11 | 12 | en template, faudrait que je les instancie toutes!!! oh boy! 13 | Et comment je sais que c'est pour Cuda? Le type float est le meme dans les <> 14 | 15 | au bout du compte, ca serait sur des pointeurs float/double... etc... = facile. 16 | primitives?? 17 | */ 18 | 19 | #define TH_STORAGE_REFCOUNTED 1 20 | #define TH_STORAGE_RESIZABLE 2 21 | #define TH_STORAGE_FREEMEM 4 22 | 23 | typedef struct THStorage 24 | { 25 | real *data; 26 | long size; 27 | int refcount; 28 | char flag; 29 | THAllocator *allocator; 30 | void *allocatorContext; 31 | } THStorage; 32 | 33 | TH_API real* THStorage_(data)(const THStorage*); 34 | TH_API long THStorage_(size)(const THStorage*); 35 | 36 | /* slow access -- checks everything */ 37 | TH_API void THStorage_(set)(THStorage*, long, real); 38 | TH_API real THStorage_(get)(const THStorage*, long); 39 | 40 | TH_API THStorage* THStorage_(new)(void); 41 | TH_API THStorage* THStorage_(newWithSize)(long size); 42 | TH_API THStorage* THStorage_(newWithSize1)(real); 43 | TH_API THStorage* THStorage_(newWithSize2)(real, real); 44 | TH_API THStorage* THStorage_(newWithSize3)(real, real, real); 45 | TH_API THStorage* THStorage_(newWithSize4)(real, real, real, real); 46 | TH_API THStorage* THStorage_(newWithMapping)(const char *filename, long size, int shared); 47 | 48 | /* takes 
ownership of data */ 49 | TH_API THStorage* THStorage_(newWithData)(real *data, long size); 50 | 51 | TH_API THStorage* THStorage_(newWithAllocator)(long size, 52 | THAllocator* allocator, 53 | void *allocatorContext); 54 | TH_API THStorage* THStorage_(newWithDataAndAllocator)( 55 | real* data, long size, THAllocator* allocator, void *allocatorContext); 56 | 57 | /* should not differ with API */ 58 | TH_API void THStorage_(setFlag)(THStorage *storage, const char flag); 59 | TH_API void THStorage_(clearFlag)(THStorage *storage, const char flag); 60 | TH_API void THStorage_(retain)(THStorage *storage); 61 | 62 | /* might differ with other API (like CUDA) */ 63 | TH_API void THStorage_(free)(THStorage *storage); 64 | TH_API void THStorage_(resize)(THStorage *storage, long size); 65 | TH_API void THStorage_(fill)(THStorage *storage, real value); 66 | 67 | #endif 68 | -------------------------------------------------------------------------------- /cmake/TorchTemplate.cmake: -------------------------------------------------------------------------------- 1 | MACRO(ADD_TORCH_TEMPLATE filename) 2 | GET_FILENAME_COMPONENT(_ext ${filename} EXT) 3 | GET_FILENAME_COMPONENT(_file ${filename} NAME_WE) 4 | IF(NOT ${ARGV2} STREQUAL "") 5 | LIST(APPEND tpl_${ARGV2}_${filename}_files "${CMAKE_CURRENT_BINARY_DIR}/${filename}") 6 | ENDIF() 7 | LIST(APPEND tpl_${ARGV2}_${filename}_files "${CMAKE_CURRENT_BINARY_DIR}/${_file}_byte${_ext}") 8 | LIST(APPEND tpl_${ARGV2}_${filename}_files "${CMAKE_CURRENT_BINARY_DIR}/${_file}_char${_ext}") 9 | LIST(APPEND tpl_${ARGV2}_${filename}_files "${CMAKE_CURRENT_BINARY_DIR}/${_file}_short${_ext}") 10 | LIST(APPEND tpl_${ARGV2}_${filename}_files "${CMAKE_CURRENT_BINARY_DIR}/${_file}_int${_ext}") 11 | LIST(APPEND tpl_${ARGV2}_${filename}_files "${CMAKE_CURRENT_BINARY_DIR}/${_file}_long${_ext}") 12 | LIST(APPEND tpl_${ARGV2}_${filename}_files "${CMAKE_CURRENT_BINARY_DIR}/${_file}_float${_ext}") 13 | LIST(APPEND tpl_${ARGV2}_${filename}_files 
"${CMAKE_CURRENT_BINARY_DIR}/${_file}_double${_ext}") 14 | 15 | ADD_CUSTOM_COMMAND( 16 | OUTPUT ${tpl_${ARGV2}_${filename}_files} 17 | COMMAND ${LUA_EXECUTABLE} ARGS 18 | "${CMAKE_CURRENT_SOURCE_DIR}/cmake/template.lua" 19 | "${CMAKE_CURRENT_SOURCE_DIR}/${filename}" 20 | "${CMAKE_CURRENT_BINARY_DIR}/${filename}" 21 | "${ARGV2}" 22 | DEPENDS 23 | "${CMAKE_CURRENT_SOURCE_DIR}/${filename}" 24 | "${CMAKE_CURRENT_SOURCE_DIR}/cmake/template.lua") 25 | 26 | SET_PROPERTY(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/${_file}_byte${_ext}" PROPERTY COMPILE_DEFINITIONS REAL_IS_BYTE byte=unsigned\ char) 27 | SET_PROPERTY(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/${_file}_char${_ext}" PROPERTY COMPILE_DEFINITIONS REAL_IS_CHAR byte=unsigned\ char) 28 | SET_PROPERTY(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/${_file}_short${_ext}" PROPERTY COMPILE_DEFINITIONS REAL_IS_SHORT byte=unsigned\ char) 29 | SET_PROPERTY(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/${_file}_int${_ext}" PROPERTY COMPILE_DEFINITIONS REAL_IS_INT byte=unsigned\ char) 30 | SET_PROPERTY(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/${_file}_long${_ext}" PROPERTY COMPILE_DEFINITIONS REAL_IS_LONG byte=unsigned\ char) 31 | SET_PROPERTY(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/${_file}_float${_ext}" PROPERTY COMPILE_DEFINITIONS REAL_IS_FLOAT byte=unsigned\ char) 32 | SET_PROPERTY(SOURCE "${CMAKE_CURRENT_BINARY_DIR}/${_file}_double${_ext}" PROPERTY COMPILE_DEFINITIONS REAL_IS_DOUBLE byte=unsigned\ char) 33 | 34 | IF(${ARGC} GREATER 1) 35 | LIST(APPEND ${ARGV1} ${tpl_${ARGV2}_${filename}_files}) 36 | ENDIF() 37 | 38 | IF(NOT ${ARGV2} STREQUAL "") 39 | ADD_CUSTOM_TARGET(tpl_tpl_${ARGV2}_${_file} ALL DEPENDS ${tpl_${ARGV2}_${filename}_files}) 40 | ELSE() 41 | ADD_CUSTOM_TARGET(tpl_${_file} ALL DEPENDS ${tpl_${ARGV2}_${filename}_files}) 42 | ENDIF() 43 | 44 | ENDMACRO() 45 | -------------------------------------------------------------------------------- /TH/cmake/FindSSE.cmake: -------------------------------------------------------------------------------- 1 | 
INCLUDE(CheckCSourceRuns) 2 | INCLUDE(CheckCXXSourceRuns) 3 | 4 | SET(SSE1_CODE " 5 | #include 6 | 7 | int main() 8 | { 9 | __m128 a; 10 | float vals[4] = {0,0,0,0}; 11 | a = _mm_loadu_ps(vals); 12 | return 0; 13 | }") 14 | 15 | SET(SSE2_CODE " 16 | #include 17 | 18 | int main() 19 | { 20 | __m128d a; 21 | double vals[2] = {0,0}; 22 | a = _mm_loadu_pd(vals); 23 | return 0; 24 | }") 25 | 26 | SET(SSE3_CODE " 27 | #include 28 | 29 | int main( ) 30 | { 31 | const int vals[4] = {0,0,0,0}; 32 | __m128i a; 33 | a = _mm_lddqu_si128( (const __m128i*)vals ); 34 | return 0; 35 | }") 36 | 37 | SET(SSE4_1_CODE " 38 | #include 39 | 40 | int main () 41 | { 42 | __m128i a, b; 43 | __m128i res = _mm_max_epi8(a, b); 44 | 45 | return 0; 46 | } 47 | ") 48 | 49 | SET(SSE4_2_CODE " 50 | #include 51 | 52 | int main() 53 | { 54 | __m128i a, b, c; 55 | c = _mm_cmpgt_epi64(a, b); 56 | return 0; 57 | } 58 | ") 59 | 60 | MACRO(CHECK_SSE lang type flags) 61 | SET(__FLAG_I 1) 62 | SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) 63 | FOREACH(__FLAG ${flags}) 64 | IF(NOT ${lang}_${type}_FOUND) 65 | SET(CMAKE_REQUIRED_FLAGS ${__FLAG}) 66 | IF(lang STREQUAL "CXX") 67 | CHECK_CXX_SOURCE_RUNS("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) 68 | ELSE() 69 | CHECK_C_SOURCE_RUNS("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) 70 | ENDIF() 71 | IF(${lang}_HAS_${type}_${__FLAG_I}) 72 | SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support") 73 | SET(${lang}_${type}_FLAGS "${__FLAG}" CACHE STRING "${lang} ${type} flags") 74 | ENDIF() 75 | MATH(EXPR __FLAG_I "${__FLAG_I}+1") 76 | ENDIF() 77 | ENDFOREACH() 78 | SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) 79 | 80 | IF(NOT ${lang}_${type}_FOUND) 81 | SET(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support") 82 | SET(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags") 83 | ENDIF() 84 | 85 | MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS) 86 | 87 | ENDMACRO() 88 | 89 | CHECK_SSE(C "SSE1" 
" ;-msse;/arch:SSE") 90 | CHECK_SSE(C "SSE2" " ;-msse2;/arch:SSE2") 91 | CHECK_SSE(C "SSE3" " ;-msse3;/arch:SSE3") 92 | CHECK_SSE(C "SSE4_1" " ;-msse4.1;-msse4;/arch:SSE4") 93 | CHECK_SSE(C "SSE4_2" " ;-msse4.2;-msse4;/arch:SSE4") 94 | 95 | CHECK_SSE(CXX "SSE1" " ;-msse;/arch:SSE") 96 | CHECK_SSE(CXX "SSE2" " ;-msse2;/arch:SSE2") 97 | CHECK_SSE(CXX "SSE3" " ;-msse3;/arch:SSE3") 98 | CHECK_SSE(CXX "SSE4_1" " ;-msse4.1;-msse4;/arch:SSE4") 99 | CHECK_SSE(CXX "SSE4_2" " ;-msse4.2;-msse4;/arch:SSE4") 100 | -------------------------------------------------------------------------------- /TH/cmake/FindARM.cmake: -------------------------------------------------------------------------------- 1 | # Check if the processor is an ARM and if Neon instruction are available on the machine where 2 | # the project is compiled. 3 | 4 | IF(CMAKE_SYSTEM_NAME MATCHES "Linux") 5 | EXEC_PROGRAM(cat ARGS "/proc/cpuinfo" OUTPUT_VARIABLE CPUINFO) 6 | 7 | #neon instruction can be found on the majority part of modern ARM processor 8 | STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE ${CPUINFO}) 9 | STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) 10 | IF (NEON_TRUE) 11 | set(NEON_FOUND true CACHE BOOL "NEON available on host") 12 | ELSE (NEON_TRUE) 13 | set(NEON_FOUND false CACHE BOOL "NEON available on host") 14 | ENDIF (NEON_TRUE) 15 | 16 | #Find the processor type (for now OMAP3 or OMAP4) 17 | STRING(REGEX REPLACE "^.*(OMAP3).*$" "\\1" OMAP3_THERE ${CPUINFO}) 18 | STRING(COMPARE EQUAL "OMAP3" "${OMAP3_THERE}" OMAP3_TRUE) 19 | IF (OMAP3_TRUE) 20 | set(CORTEXA8_FOUND true CACHE BOOL "OMAP3 available on host") 21 | ELSE (OMAP3_TRUE) 22 | set(CORTEXA8_FOUND false CACHE BOOL "OMAP3 available on host") 23 | ENDIF (OMAP3_TRUE) 24 | 25 | #Find the processor type (for now OMAP3 or OMAP4) 26 | STRING(REGEX REPLACE "^.*(OMAP4).*$" "\\1" OMAP4_THERE ${CPUINFO}) 27 | STRING(COMPARE EQUAL "OMAP4" "${OMAP4_THERE}" OMAP4_TRUE) 28 | IF (OMAP4_TRUE) 29 | set(CORTEXA9_FOUND true CACHE BOOL "OMAP4 
available on host") 30 | ELSE (OMAP4_TRUE) 31 | set(CORTEXA9_FOUND false CACHE BOOL "OMAP4 available on host") 32 | ENDIF (OMAP4_TRUE) 33 | 34 | ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin") 35 | EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE 36 | CPUINFO) 37 | 38 | #neon instruction can be found on the majority part of modern ARM processor 39 | STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE "${CPUINFO}") # quoted so an empty sysctl result is still passed as the input string 40 | STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) 41 | IF (NEON_TRUE) 42 | set(NEON_FOUND true CACHE BOOL "NEON available on host") 43 | ELSE (NEON_TRUE) 44 | set(NEON_FOUND false CACHE BOOL "NEON available on host") 45 | ENDIF (NEON_TRUE) 46 | 47 | ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows") 48 | # TODO 49 | set(CORTEXA8_FOUND false CACHE BOOL "OMAP3 not available on host") 50 | set(CORTEXA9_FOUND false CACHE BOOL "OMAP4 not available on host") 51 | set(NEON_FOUND false CACHE BOOL "NEON not available on host") 52 | ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux") 53 | set(CORTEXA8_FOUND false CACHE BOOL "OMAP3 not available on host") 54 | set(CORTEXA9_FOUND false CACHE BOOL "OMAP4 not available on host") 55 | set(NEON_FOUND false CACHE BOOL "NEON not available on host") 56 | ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux") 57 | 58 | if(NOT NEON_FOUND) 59 | MESSAGE(STATUS "Could not find hardware support for NEON on this machine.") 60 | endif(NOT NEON_FOUND) 61 | if(NOT CORTEXA8_FOUND) 62 | MESSAGE(STATUS "No OMAP3 processor on this machine.") 63 | endif(NOT CORTEXA8_FOUND) 64 | if(NOT CORTEXA9_FOUND) 65 | MESSAGE(STATUS "No OMAP4 processor on this machine.") 66 | endif(NOT CORTEXA9_FOUND) 67 | mark_as_advanced(NEON_FOUND) 68 | -------------------------------------------------------------------------------- /TH/THRandom.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_RANDOM_INC 2 | #define TH_RANDOM_INC 3 | 4 | #include "THGeneral.h" 5 | 6 | #define _MERSENNE_STATE_N 624 7 |
#define _MERSENNE_STATE_M 397 8 | typedef struct THGenerator { 9 | /* The initial seed. */ 10 | unsigned long the_initial_seed; 11 | int left; /* = 1; */ 12 | int initf; /* = 0; */ 13 | unsigned long *next; 14 | unsigned long state[_MERSENNE_STATE_N]; /* the array for the state vector */ 15 | /********************************/ 16 | 17 | /* For normal distribution */ 18 | double normal_x; 19 | double normal_y; 20 | double normal_rho; 21 | int normal_is_valid; /* = 0; */ 22 | } THGenerator; 23 | 24 | #define torch_Generator "torch.Generator" 25 | 26 | /* Create a new random number generator stream */ 27 | TH_API THGenerator * THGenerator_new(); 28 | 29 | /* Free a random number generator stream */ 30 | TH_API void THGenerator_free(THGenerator *gen); 31 | 32 | /* Initializes the random number generator with the current time (granularity: seconds) and returns the seed. */ 33 | TH_API unsigned long THRandom_seed(THGenerator *_generator); 34 | 35 | /* Initializes the random number generator with the given long "the_seed_". */ 36 | TH_API void THRandom_manualSeed(THGenerator *_generator, unsigned long the_seed_); 37 | 38 | /* Returns the starting seed used. */ 39 | TH_API unsigned long THRandom_initialSeed(THGenerator *_generator); 40 | 41 | /* Generates a uniform 32 bits integer. */ 42 | TH_API unsigned long THRandom_random(THGenerator *_generator); 43 | 44 | /* Generates a uniform random number on [0,1[. */ 45 | TH_API double THRandom_uniform(THGenerator *_generator, double a, double b); 46 | 47 | /** Generates a random number from a normal distribution. 48 | (With mean #mean# and standard deviation #stdv >= 0#). 49 | */ 50 | TH_API double THRandom_normal(THGenerator *_generator, double mean, double stdv); 51 | 52 | /** Generates a random number from an exponential distribution. 53 | The density is $p(x) = lambda * exp(-lambda * x)$, where 54 | lambda is a positive number. 
55 | */ 56 | TH_API double THRandom_exponential(THGenerator *_generator, double lambda); 57 | 58 | /** Returns a random number from a Cauchy distribution. 59 | The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$ 60 | */ 61 | TH_API double THRandom_cauchy(THGenerator *_generator, double median, double sigma); 62 | 63 | /** Generates a random number from a log-normal distribution. 64 | (#mean > 0# is the mean of the log-normal distribution 65 | and #stdv# is its standard deviation). 66 | */ 67 | TH_API double THRandom_logNormal(THGenerator *_generator, double mean, double stdv); 68 | 69 | /** Generates a random number from a geometric distribution. 70 | It returns an integer #i#, where $p(i) = (1-p) * p^(i-1)$. 71 | p must satisfy $0 < p < 1$. 72 | */ 73 | TH_API int THRandom_geometric(THGenerator *_generator, double p); 74 | 75 | /* Returns true with probability $p$ and false with probability $1-p$ (p > 0). */ 76 | TH_API int THRandom_bernoulli(THGenerator *_generator, double p); 77 | 78 | /* returns the random number state */ 79 | TH_API void THRandom_getState(THGenerator *_generator, unsigned long *state, long *offset, long *_left); 80 | 81 | /* sets the random number state */ 82 | TH_API void THRandom_setState(THGenerator *_generator, unsigned long *state, long offset, long _left); 83 | #endif 84 | -------------------------------------------------------------------------------- /TH/THGeneral.c: -------------------------------------------------------------------------------- 1 | #include "THGeneral.h" 2 | 3 | #ifndef TH_HAVE_THREAD 4 | #define __thread 5 | #endif 6 | /* Torch Error Handling */ 7 | static void defaultTorchErrorHandlerFunction(const char *msg, void *data) 8 | { 9 | printf("$ Error: %s\n", msg); 10 | exit(-1); 11 | } 12 | 13 | static __thread void (*torchErrorHandlerFunction)(const char *msg, void *data) = defaultTorchErrorHandlerFunction; 14 | static __thread void *torchErrorHandlerData; 15 | 16 | void THError(const char *fmt, ...) 
17 | { 18 | char msg[1024]; 19 | va_list args; 20 | 21 | /* vasprintf not standard */ 22 | /* vsnprintf: how to handle if does not exists? */ 23 | va_start(args, fmt); 24 | vsnprintf(msg, 1024, fmt, args); 25 | va_end(args); 26 | 27 | (*torchErrorHandlerFunction)(msg, torchErrorHandlerData); 28 | } 29 | 30 | void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg, void *data), void *data ) 31 | { 32 | if(torchErrorHandlerFunction_) 33 | torchErrorHandlerFunction = torchErrorHandlerFunction_; 34 | else 35 | torchErrorHandlerFunction = defaultTorchErrorHandlerFunction; 36 | torchErrorHandlerData = data; 37 | } 38 | 39 | /* Torch Arg Checking Handling */ 40 | static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data) 41 | { 42 | if(msg) 43 | printf("$ Invalid argument %d: %s\n", argNumber, msg); 44 | else 45 | printf("$ Invalid argument %d\n", argNumber); 46 | exit(-1); 47 | } 48 | 49 | static __thread void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data) = defaultTorchArgErrorHandlerFunction; 50 | static __thread void *torchArgErrorHandlerData; 51 | 52 | void THArgCheck(int condition, int argNumber, const char *msg) 53 | { 54 | if(!condition) 55 | (*torchArgErrorHandlerFunction)(argNumber, msg, torchArgErrorHandlerData); 56 | } 57 | 58 | void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber, const char *msg, void *data), void *data ) 59 | { 60 | if(torchArgErrorHandlerFunction_) 61 | torchArgErrorHandlerFunction = torchArgErrorHandlerFunction_; 62 | else 63 | torchArgErrorHandlerFunction = defaultTorchArgErrorHandlerFunction; 64 | torchArgErrorHandlerData = data; 65 | } 66 | 67 | void* THAlloc(long size) 68 | { 69 | void *ptr; 70 | 71 | if(size < 0) 72 | THError("$ Torch: invalid memory size -- maybe an overflow?"); 73 | 74 | if(size == 0) 75 | return NULL; 76 | 77 | ptr = malloc(size); 78 | if(!ptr) 79 | THError("$ Torch: not enough memory: you tried to allocate 
%dGB. Buy new RAM!", size/1073741824); 80 | 81 | return ptr; 82 | } 83 | 84 | void* THRealloc(void *ptr, long size) 85 | { 86 | if(!ptr) 87 | return(THAlloc(size)); 88 | 89 | if(size == 0) 90 | { 91 | THFree(ptr); 92 | return NULL; 93 | } 94 | 95 | if(size < 0) 96 | THError("$ Torch: invalid memory size -- maybe an overflow?"); 97 | 98 | ptr = realloc(ptr, size); 99 | if(!ptr) 100 | THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824); 101 | return ptr; 102 | } 103 | 104 | void THFree(void *ptr) 105 | { 106 | free(ptr); 107 | } 108 | 109 | double THLog1p(const double x) 110 | { 111 | #ifdef _MSC_VER 112 | volatile double y = 1 + x; 113 | return log(y) - ((y-1)-x)/y ; /* cancels errors with IEEE arithmetic */ 114 | #else 115 | return log1p(x); 116 | #endif 117 | } 118 | -------------------------------------------------------------------------------- /registernumbers.lua: -------------------------------------------------------------------------------- 1 | local register_ = require 'torch.register' 2 | local torch = require 'torch.env' 3 | 4 | local function copy_args(args) 5 | local tbl = {} 6 | for k,v in pairs(args) do 7 | tbl[k] = v 8 | end 9 | return tbl 10 | end 11 | 12 | -- handle numbers type 13 | local function register(args, namespace, metatable) 14 | local nidx 15 | for idx,arg in ipairs(args) do 16 | if arg.type == 'numbers' then 17 | if nidx then 18 | error('only one argument can be of type') 19 | end 20 | nidx = idx 21 | end 22 | end 23 | if nidx then 24 | assert(args.call, ' is supposed to be used together with ') 25 | 26 | -- with table 27 | local new_args = copy_args(args) 28 | new_args[nidx] = copy_args(new_args[nidx]) -- avoid modification with no warning 29 | new_args[nidx].type = 'table' 30 | local funcargs = {} 31 | local callargs = {} 32 | for i=1,#new_args do 33 | table.insert(funcargs, string.format('arg%d', i)) 34 | table.insert(callargs, string.format('arg%d', i)) 35 | end 36 | callargs[nidx] = 
'numbers' 37 | funcargs = table.concat(funcargs, ', ') 38 | callargs = table.concat(callargs, ', ') 39 | 40 | local numbers = torch.LongStorage() 41 | local code = [[ 42 | local call 43 | local numbers 44 | return function(%s) 45 | local sz = #arg%d 46 | numbers:resize(sz) 47 | for i=1,sz do 48 | numbers.__data[i-1] = arg%d[i] 49 | end 50 | return call(%s) 51 | end 52 | ]] 53 | code = string.format(code, funcargs, nidx, nidx, callargs) 54 | code = loadstring(code)() 55 | debug.setupvalue(code, 1, numbers) 56 | debug.setupvalue(code, 2, args.call) 57 | new_args.call = code 58 | register_(new_args, namespace, metatable) 59 | 60 | -- with numbers (up to N) 61 | local N = 5 62 | local new_args = copy_args(args) 63 | table.remove(new_args, nidx) 64 | for i=1,N do 65 | local arg = copy_args(args[nidx]) 66 | arg.name = arg.name .. i 67 | arg.type = "number" 68 | if i > 1 then 69 | arg.default = 0 70 | end 71 | table.insert(new_args, nidx+i-1, arg) 72 | end 73 | local funcargs = {} 74 | local callargs = {} 75 | for i=1,#new_args do 76 | table.insert(funcargs, string.format('arg%d', i)) 77 | table.insert(callargs, string.format('arg%d', i)) 78 | end 79 | callargs[nidx] = 'numbers' 80 | for i=2,N do 81 | table.remove(callargs, nidx+1) 82 | end 83 | funcargs = table.concat(funcargs, ', ') 84 | callargs = table.concat(callargs, ', ') 85 | 86 | local numbers = torch.LongStorage(5) 87 | local code = [[ 88 | local call 89 | local numbers 90 | return function(%s) 91 | numbers.__data[0] = arg%d 92 | numbers.__data[1] = arg%d 93 | numbers.__data[2] = arg%d 94 | numbers.__data[3] = arg%d 95 | numbers.__data[4] = arg%d 96 | return call(%s) 97 | end 98 | ]] 99 | code = string.format(code, funcargs, nidx, nidx+1, nidx+2, nidx+3, nidx+4, callargs) 100 | code = loadstring(code)() 101 | debug.setupvalue(code, 1, numbers) 102 | debug.setupvalue(code, 2, args.call) 103 | new_args.call = code 104 | register_(new_args, namespace, metatable) 105 | 106 | -- with LongStorage 107 | local new_args 
= copy_args(args) 108 | args[nidx].type = 'torch.LongStorage' 109 | register_(new_args, namespace, metatable) 110 | else 111 | register_(args, namespace, metatable) 112 | end 113 | end 114 | 115 | return register 116 | -------------------------------------------------------------------------------- /COPYRIGHT.txt: -------------------------------------------------------------------------------- 1 | =============================================================================== 2 | 3 | Torch9 -- http://github.com/andresy/torch9 4 | 5 | Copyright (c) 2011-2013 Idiap Research Institute (Ronan Collobert) 6 | Copyright (c) 2011-2013 NEC Laboratories America (Koray Kavukcuoglu) 7 | Copyright (c) 2011-2013 NYU (Clement Farabet) 8 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 9 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 10 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 11 | 12 | All rights reserved. 13 | 14 | Redistribution and use in source and binary forms, with or without 15 | modification, are permitted provided that the following conditions are met: 16 | 17 | 1. Redistributions of source code must retain the above copyright 18 | notice, this list of conditions and the following disclaimer. 19 | 20 | 2. Redistributions in binary form must reproduce the above copyright 21 | notice, this list of conditions and the following disclaimer in the 22 | documentation and/or other materials provided with the distribution. 23 | 24 | 3. Neither the names of NEC Laboratories American and Idiap Research 25 | Institute nor the names of its contributors may be used to endorse or 26 | promote products derived from this software without specific prior 27 | written permission. 
28 | 29 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 30 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 31 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 32 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 33 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 34 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 35 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 36 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 37 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 38 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | =============================================================================== 42 | [Torch9 includes Mersenne Twister code which has this license statement: ] 43 | 44 | Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, 45 | All rights reserved. 46 | 47 | Redistribution and use in source and binary forms, with or without 48 | modification, are permitted provided that the following conditions 49 | are met: 50 | 51 | 1. Redistributions of source code must retain the above copyright 52 | notice, this list of conditions and the following disclaimer. 53 | 54 | 2. Redistributions in binary form must reproduce the above copyright 55 | notice, this list of conditions and the following disclaimer in the 56 | documentation and/or other materials provided with the distribution. 57 | 58 | 3. The names of its contributors may not be used to endorse or promote 59 | products derived from this software without specific prior written 60 | permission.
61 | 62 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 63 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 64 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 65 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 66 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 67 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 68 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 69 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 70 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 71 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 72 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 73 | 74 | -------------------------------------------------------------------------------- /dimapply.lua: -------------------------------------------------------------------------------- 1 | local torch = require 'torch.env' 2 | 3 | -- narg: #of input tensors 4 | -- dim: #of dim of the input tensors 5 | -- dima: dimension over which we apply 6 | local function generatedimapply_narg_dim_dima(narg, dim, dima) 7 | local func = {} 8 | local funcarg = {} 9 | for n=1,narg do 10 | table.insert(funcarg, string.format('t%d', n)) 11 | end 12 | table.insert(func, string.format('return function(%s, func)', table.concat(funcarg, ', '))) 13 | for n=1,narg do 14 | for i=0,dim-1 do 15 | table.insert(func, string.format('local t%dsz%d, t%dst%d = tonumber(t%d.__size[%d]), tonumber(t%d.__stride[%d])', n, i, n, i, n, i, n, i)) 16 | end 17 | table.insert(func, string.format('local t%ddata = t%d.__storage.__data+t%d.__storageOffset', n, n, n)) 18 | end 19 | for i=0,dim-1 do 20 | if i ~= dima then 21 | table.insert(func, string.format('for i%d=0,t%dsz%d-1 do', i, 1, i)) 22 | end 23 | end 24 | local funcarg = {} 25 | for n=1,narg do 26 | local ptr = {string.format('t%ddata', n)} 27 | for 
i=0,dim-1 do 28 | if i ~= dima then 29 | table.insert(ptr, string.format('+i%d*t%dst%d', i, n, i)) 30 | end 31 | end 32 | table.insert(funcarg, string.format('t%dsz%d', n, dima)) 33 | table.insert(funcarg, table.concat(ptr, '')) 34 | table.insert(funcarg, string.format('t%dst%d', n, dima)) 35 | end 36 | table.insert(func, string.format('func(%s)', table.concat(funcarg, ', '))) 37 | for i=0,dim-1 do 38 | if i ~= dima then 39 | table.insert(func, 'end') 40 | end 41 | end 42 | table.insert(func, 'end') 43 | return table.concat(func, '\n') 44 | end 45 | 46 | local function generatedimapply_n(n) -- NOTE(review): never called in this file, and generatedimapply_dim_n is not defined here; looks like dead/unfinished code 47 | local func = {} 48 | local decl = {} 49 | for i=1,n do 50 | table.insert(decl, string.format('t%d', i)) 51 | end 52 | table.insert(func, table.concat({string.format('function torch.dimapply%d(', n), 53 | table.concat(decl, ', '), 54 | ', dim, func)'}, '')) 55 | table.insert(func, 'local ndim = t1.__nDimension') 56 | table.insert(func, 'dim = dim - 1') 57 | for ndim=1,2 do 58 | table.insert(func, string.format('%sif ndim == %d then', ndim == 1 and '' or 'else', ndim)) 59 | table.insert(func, generatedimapply_dim_n(ndim, n)) 60 | end 61 | table.insert(func, 'else') 62 | table.insert(func, 'error("the provided tensor has too many dimensions")') 63 | table.insert(func, 'end') -- if/elseif 64 | table.insert(func, 'end') 65 | return table.concat(func, '\n') 66 | end 67 | 68 | local dimapply1funcs = {} 69 | function torch.rawdimapply(t1, dim, func) 70 | local dim1 = tonumber(t1.__nDimension) -- tonumber() as apply.lua does: if __nDimension is boxed 64-bit cdata, each read would be a distinct table key (cache never hits) and an unusable numeric loop bound 71 | dimapply1funcs[dim1] = dimapply1funcs[dim1] or {} 72 | local dimapplyfunc = dimapply1funcs[dim1][dim-1] 73 | if not dimapplyfunc then 74 | dimapplyfunc = loadstring(generatedimapply_narg_dim_dima(1, dim1, dim-1))() 75 | dimapply1funcs[dim1][dim-1] = dimapplyfunc 76 | end 77 | dimapplyfunc(t1, func) 78 | end 79 | 80 | local dimapply2funcs = {} 81 | function torch.rawdimapply2(t1, t2, dim, func) 82 | local dim1 = tonumber(t1.__nDimension) 83 | dimapply2funcs[dim1] = dimapply2funcs[dim1] or {} 84 | local
dimapplyfunc = dimapply2funcs[dim1][dim-1] 85 | if not dimapplyfunc then 86 | dimapplyfunc = loadstring(generatedimapply_narg_dim_dima(2, dim1, dim-1))() 87 | dimapply2funcs[dim1][dim-1] = dimapplyfunc 88 | end 89 | dimapplyfunc(t1, t2, func) 90 | end 91 | 92 | local dimapply3funcs = {} 93 | function torch.rawdimapply3(t1, t2, t3, dim, func) 94 | local dim1 = tonumber(t1.__nDimension) 95 | dimapply3funcs[dim1] = dimapply3funcs[dim1] or {} 96 | local dimapplyfunc = dimapply3funcs[dim1][dim-1] 97 | if not dimapplyfunc then 98 | dimapplyfunc = loadstring(generatedimapply_narg_dim_dima(3, dim1, dim-1))() 99 | dimapply3funcs[dim1][dim-1] = dimapplyfunc 100 | end 101 | dimapplyfunc(t1, t2, t3, func) 102 | end 103 | -------------------------------------------------------------------------------- /timer.lua: -------------------------------------------------------------------------------- 1 | local torch = require 'torch.env' 2 | local class = require 'class' 3 | local ffi = require 'ffi' 4 | 5 | if jit.os == 'OSX' then 6 | ffi.cdef([[ 7 | typedef long time_t; 8 | 9 | typedef struct timeval { 10 | time_t tv_sec; 11 | int tv_usec; 12 | }; 13 | 14 | int gettimeofday(struct timeval* t, void* tzp); 15 | ]]) 16 | else 17 | ffi.cdef([[ 18 | typedef long time_t; 19 | 20 | typedef struct timeval { 21 | time_t tv_sec; 22 | time_t tv_usec; 23 | }; 24 | 25 | int gettimeofday(struct timeval* t, void* tzp); 26 | ]]) 27 | end 28 | 29 | ffi.cdef([[ 30 | struct rusage { 31 | struct timeval ru_utime; /* user time used */ 32 | struct timeval ru_stime; /* system time used */ 33 | long ru_maxrss; /* integral max resident set size */ 34 | long ru_ixrss; /* integral shared text memory size */ 35 | long ru_idrss; /* integral unshared data size */ 36 | long ru_isrss; /* integral unshared stack size */ 37 | long ru_minflt; /* page reclaims */ 38 | long ru_majflt; /* page faults */ 39 | long ru_nswap; /* swaps */ 40 | long ru_inblock; /* block input operations */ 41 | long ru_oublock; /* block output
operations */ 42 | long ru_msgsnd; /* messages sent */ 43 | long ru_msgrcv; /* messages received */ 44 | long ru_nsignals; /* signals received */ 45 | long ru_nvcsw; /* voluntary context switches */ 46 | long ru_nivcsw; /* involuntary context switches */ 47 | }; 48 | 49 | int getrusage(int who, struct rusage *r_usage); 50 | 51 | ]]) 52 | 53 | local Timer = class('torch.Timer') 54 | torch.Timer = Timer 55 | 56 | Timer.RUSAGE_SELF = 0 57 | Timer.RUSAGE_CHILDREN = -1 58 | 59 | function Timer.real() 60 | local time = ffi.new("struct timeval") 61 | ffi.C.gettimeofday(time, nil) 62 | return (tonumber(time.tv_sec) + tonumber(time.tv_usec)/1000000.0) 63 | end 64 | 65 | function Timer.user() 66 | local time = ffi.new("struct rusage") 67 | ffi.C.getrusage(Timer.RUSAGE_SELF, time) 68 | return (tonumber(time.ru_utime.tv_sec) + tonumber(time.ru_utime.tv_usec)/1000000.0) 69 | end 70 | 71 | function Timer.sys() 72 | local time = ffi.new("struct rusage") 73 | ffi.C.getrusage(Timer.RUSAGE_SELF, time) 74 | return (tonumber(time.ru_stime.tv_sec) + tonumber(time.ru_stime.tv_usec)/1000000.0) 75 | end 76 | 77 | function Timer:__init() 78 | self:reset() 79 | end 80 | 81 | function Timer:reset() 82 | self.__isRunning = true 83 | self.__totalrealtime = 0 84 | self.__totalusertime = 0 85 | self.__totalsystime = 0 86 | self.__startrealtime = Timer.real() 87 | self.__startusertime = Timer.user() 88 | self.__startsystime = Timer.sys() 89 | return self 90 | end 91 | 92 | function Timer:stop() 93 | if self.__isRunning then 94 | self.__totalrealtime = self.__totalrealtime + Timer.real() - self.__startrealtime 95 | self.__totalusertime = self.__totalusertime + Timer.user() - self.__startusertime 96 | self.__totalsystime = self.__totalsystime + Timer.sys() - self.__startsystime 97 | self.__isRunning = false 98 | end 99 | end 100 | 101 | function Timer:resume() 102 | if not self.__isRunning then 103 | self.__startrealtime = Timer.real() 104 | self.__startusertime = Timer.user() 105 | 
self.__startsystime = Timer.sys() 106 | self.__isRunning = true 107 | end 108 | end 109 | 110 | function Timer:time() 111 | return { 112 | real = self.__isRunning and (self.__totalrealtime + Timer.real() - self.__startrealtime) or self.__totalrealtime, 113 | user = self.__isRunning and (self.__totalusertime + Timer.user() - self.__startusertime) or self.__totalusertime, 114 | sys = self.__isRunning and (self.__totalsystime + Timer.sys() - self.__startsystime) or self.__totalsystime 115 | } 116 | end 117 | -------------------------------------------------------------------------------- /TH/THFile.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_FILE_INC 2 | #define TH_FILE_INC 3 | 4 | #include "THStorage.h" 5 | 6 | typedef struct THFile__ THFile; 7 | 8 | TH_API int THFile_isOpened(THFile *self); 9 | TH_API int THFile_isQuiet(THFile *self); 10 | TH_API int THFile_isReadable(THFile *self); 11 | TH_API int THFile_isWritable(THFile *self); 12 | TH_API int THFile_isBinary(THFile *self); 13 | TH_API int THFile_isAutoSpacing(THFile *self); 14 | TH_API int THFile_hasError(THFile *self); 15 | 16 | TH_API void THFile_binary(THFile *self); 17 | TH_API void THFile_ascii(THFile *self); 18 | TH_API void THFile_autoSpacing(THFile *self); 19 | TH_API void THFile_noAutoSpacing(THFile *self); 20 | TH_API void THFile_quiet(THFile *self); 21 | TH_API void THFile_pedantic(THFile *self); 22 | TH_API void THFile_clearError(THFile *self); 23 | 24 | /* scalar */ 25 | TH_API unsigned char THFile_readByteScalar(THFile *self); 26 | TH_API char THFile_readCharScalar(THFile *self); 27 | TH_API short THFile_readShortScalar(THFile *self); 28 | TH_API int THFile_readIntScalar(THFile *self); 29 | TH_API long THFile_readLongScalar(THFile *self); 30 | TH_API float THFile_readFloatScalar(THFile *self); 31 | TH_API double THFile_readDoubleScalar(THFile *self); 32 | 33 | TH_API void THFile_writeByteScalar(THFile *self, unsigned char scalar); 34 | TH_API 
void THFile_writeCharScalar(THFile *self, char scalar); 35 | TH_API void THFile_writeShortScalar(THFile *self, short scalar); 36 | TH_API void THFile_writeIntScalar(THFile *self, int scalar); 37 | TH_API void THFile_writeLongScalar(THFile *self, long scalar); 38 | TH_API void THFile_writeFloatScalar(THFile *self, float scalar); 39 | TH_API void THFile_writeDoubleScalar(THFile *self, double scalar); 40 | 41 | /* storage */ 42 | TH_API long THFile_readByte(THFile *self, THByteStorage *storage); 43 | TH_API long THFile_readChar(THFile *self, THCharStorage *storage); 44 | TH_API long THFile_readShort(THFile *self, THShortStorage *storage); 45 | TH_API long THFile_readInt(THFile *self, THIntStorage *storage); 46 | TH_API long THFile_readLong(THFile *self, THLongStorage *storage); 47 | TH_API long THFile_readFloat(THFile *self, THFloatStorage *storage); 48 | TH_API long THFile_readDouble(THFile *self, THDoubleStorage *storage); 49 | 50 | TH_API long THFile_writeByte(THFile *self, THByteStorage *storage); 51 | TH_API long THFile_writeChar(THFile *self, THCharStorage *storage); 52 | TH_API long THFile_writeShort(THFile *self, THShortStorage *storage); 53 | TH_API long THFile_writeInt(THFile *self, THIntStorage *storage); 54 | TH_API long THFile_writeLong(THFile *self, THLongStorage *storage); 55 | TH_API long THFile_writeFloat(THFile *self, THFloatStorage *storage); 56 | TH_API long THFile_writeDouble(THFile *self, THDoubleStorage *storage); 57 | 58 | /* raw */ 59 | TH_API long THFile_readByteRaw(THFile *self, unsigned char *data, long n); 60 | TH_API long THFile_readCharRaw(THFile *self, char *data, long n); 61 | TH_API long THFile_readShortRaw(THFile *self, short *data, long n); 62 | TH_API long THFile_readIntRaw(THFile *self, int *data, long n); 63 | TH_API long THFile_readLongRaw(THFile *self, long *data, long n); 64 | TH_API long THFile_readFloatRaw(THFile *self, float *data, long n); 65 | TH_API long THFile_readDoubleRaw(THFile *self, double *data, long n); 66 | 
TH_API long THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */ 67 | 68 | TH_API long THFile_writeByteRaw(THFile *self, unsigned char *data, long n); 69 | TH_API long THFile_writeCharRaw(THFile *self, char *data, long n); 70 | TH_API long THFile_writeShortRaw(THFile *self, short *data, long n); 71 | TH_API long THFile_writeIntRaw(THFile *self, int *data, long n); 72 | TH_API long THFile_writeLongRaw(THFile *self, long *data, long n); 73 | TH_API long THFile_writeFloatRaw(THFile *self, float *data, long n); 74 | TH_API long THFile_writeDoubleRaw(THFile *self, double *data, long n); 75 | TH_API long THFile_writeStringRaw(THFile *self, const char *str, long size); 76 | 77 | TH_API void THFile_synchronize(THFile *self); 78 | TH_API void THFile_seek(THFile *self, long position); 79 | TH_API void THFile_seekEnd(THFile *self); 80 | TH_API long THFile_position(THFile *self); 81 | TH_API void THFile_close(THFile *self); 82 | TH_API void THFile_free(THFile *self); 83 | 84 | #endif 85 | -------------------------------------------------------------------------------- /random.lua: -------------------------------------------------------------------------------- 1 | local register_ = require 'torch.register' 2 | local argcheck = require 'argcheck' 3 | local torch = require 'torch.env' 4 | local class = require 'class' 5 | local ffi = require 'ffi' 6 | local C = require 'torch.TH' 7 | 8 | -- DEBUG: should register() be in argcheck? 
9 | local function register(args) 10 | return register_(args, torch, class.metatable('torch.Generator')) 11 | end 12 | 13 | local Generator = class('torch.Generator', nil, ffi.typeof('THGenerator&')) 14 | torch.Generator = Generator 15 | 16 | Generator.new = argcheck{ 17 | call = 18 | function() 19 | local self = C.THGenerator_new()[0] 20 | ffi.gc(self, C.THGenerator_free) 21 | return self 22 | end 23 | } 24 | 25 | Generator.__factory = Generator.new 26 | 27 | torch.__generator = torch.__generator or torch.Generator() 28 | 29 | ffi.metatype('THGenerator', class.metatable('torch.Generator')) 30 | 31 | register{ 32 | name = "random", 33 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 34 | {name="a", type="number", default=1}, 35 | {name="b", type="number"}, 36 | call = 37 | function(generator, a, b) 38 | generator = generator or torch.__generator 39 | return tonumber(C.THRandom_random(generator)) % (b+1-a)+a 40 | end 41 | } 42 | 43 | register{ 44 | name = "random", 45 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 46 | call = 47 | function(generator) 48 | return tonumber(C.THRandom_random(generator or torch.__generator)) -- fall back to the global generator like every other overload instead of handing nil (a NULL pointer) to C 49 | end 50 | } 51 | 52 | register{ 53 | name = "manualSeed", 54 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 55 | {name="seed", type="number"}, 56 | call = 57 | function(generator, seed) 58 | generator = generator or torch.__generator 59 | C.THRandom_manualSeed(generator, seed) 60 | return generator 61 | end 62 | } 63 | 64 | register{ 65 | name = "uniform", 66 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 67 | {name="a", type="number", default=0}, 68 | {name="b", type="number", default=1}, 69 | call = 70 | function(generator, a, b) 71 | generator = generator or torch.__generator 72 | return tonumber(C.THRandom_uniform(generator, a, b)) 73 | end 74 | } 75 | 76 | register{ 77 | name = "normal", 78 | {name="generator", type="torch.Generator", opt=true,
method={opt=false}}, 79 | {name="a", type="number", default=0}, 80 | {name="b", type="number", default=1}, 81 | call = 82 | function(generator, a, b) 83 | generator = generator or torch.__generator 84 | return tonumber(C.THRandom_normal(generator, a, b)) 85 | end 86 | } 87 | 88 | register{ 89 | name = "cauchy", 90 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 91 | {name="a", type="number", default=0}, 92 | {name="b", type="number", default=1}, 93 | call = 94 | function(generator, a, b) 95 | generator = generator or torch.__generator 96 | return tonumber(C.THRandom_cauchy(generator, a, b)) 97 | end 98 | } 99 | 100 | register{ 101 | name = "logNormal", 102 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 103 | {name="a", type="number", default=1}, 104 | {name="b", type="number", default=2}, 105 | call = 106 | function(generator, a, b) 107 | generator = generator or torch.__generator 108 | return tonumber(C.THRandom_logNormal(generator, a, b)) 109 | end 110 | } 111 | 112 | register{ 113 | name = "exponential", 114 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 115 | {name="a", type="number", default=1}, 116 | call = 117 | function(generator, a) 118 | generator = generator or torch.__generator 119 | return tonumber(C.THRandom_exponential(generator, a)) 120 | end 121 | } 122 | 123 | register{ 124 | name = "geometric", 125 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 126 | {name="a", type="number"}, 127 | call = 128 | function(generator, a) 129 | generator = generator or torch.__generator 130 | return tonumber(C.THRandom_geometric(generator, a)) 131 | end 132 | } 133 | 134 | register{ 135 | name = "bernoulli", 136 | {name="generator", type="torch.Generator", opt=true, method={opt=false}}, 137 | {name="a", type="number", default=0.5}, 138 | call = 139 | function(generator, a) 140 | generator = generator or torch.__generator 141 | return 
tonumber(C.THRandom_bernoulli(generator, a)) 142 | end 143 | } 144 | -------------------------------------------------------------------------------- /apply.lua: -------------------------------------------------------------------------------- 1 | local torch = require 'torch.env' 2 | 3 | -- NOTE: 4 | -- the c1, c2... c5 trick is due to VARG not being compiled in luaJIT 5 | 6 | local function generate_apply(dim) 7 | local func = {} 8 | local funcarg = {} 9 | for n=1,#dim do 10 | table.insert(funcarg, string.format('t%d', n)) 11 | end 12 | table.insert(func, string.format('return function(%s, func, c1, c2, c3, c4, c5)', table.concat(funcarg, ', '))) 13 | for n=1,#dim do 14 | for i=0,dim[n]-1 do 15 | table.insert(func, string.format('local t%dsz%d, t%dst%d = tonumber(t%d.__size[%d]), tonumber(t%d.__stride[%d])', n, i, n, i, n, i, n, i)) 16 | end 17 | table.insert(func, string.format('local t%ddata = t%d.__storage.__data + t%d.__storageOffset', n, n, n)) 18 | end 19 | for n=1,#dim do 20 | for i=0,dim[n]-1 do 21 | table.insert(func, string.format('local t%di%d = 0', n, i)) 22 | end 23 | end 24 | local cond = {} 25 | for n=1,#dim do 26 | if dim[n] > 1 then 27 | table.insert(cond, string.format('t%di0 < t%dsz0', n, n)) 28 | end 29 | end 30 | 31 | if #cond > 0 then 32 | table.insert(func, string.format('while %s do', table.concat(cond, ' and '))) 33 | end 34 | 35 | local maxarg = {} 36 | for n=1,#dim do 37 | if dim[n] > 0 then 38 | table.insert(maxarg, string.format('t%dsz%d-t%di%d', n, dim[n]-1, n, dim[n]-1)) 39 | else 40 | table.insert(maxarg, '0') 41 | end 42 | end 43 | table.insert(func, string.format('local r = math.min(%s)', table.concat(maxarg, ', '))) 44 | 45 | -- do stuff 46 | local funcarg = {} 47 | for n=1,#dim do 48 | local data = {string.format('t%ddata', n)} 49 | for i=0,dim[n]-1 do 50 | table.insert(data, string.format(' + t%di%d*t%dst%d', n, i, n, i)) 51 | end 52 | table.insert(funcarg, table.concat(data, '')) 53 | if dim[n] > 0 then 54 | 
table.insert(funcarg, string.format('t%dst%d', n, dim[n]-1)) 55 | else 56 | table.insert(funcarg, '0') 57 | end 58 | end 59 | table.insert(func, string.format('func(r, %s, c1, c2, c3, c4, c5)', table.concat(funcarg, ', '))) 60 | 61 | 62 | for n=1,#dim do 63 | if dim[n] > 0 then 64 | table.insert(func, string.format('t%di%d = t%di%d + r', n, dim[n]-1, n, dim[n]-1)) 65 | end 66 | if dim[n] > 1 then 67 | table.insert(func, string.format('if t%di%d == t%dsz%d then', n, dim[n]-1, n, dim[n]-1)) 68 | table.insert(func, string.format('t%di%d = 0', n, dim[n]-1)) 69 | for i=dim[n]-2,0,-1 do 70 | table.insert(func, string.format('t%di%d = t%di%d + 1', n, i, n, i)) 71 | if i > 0 then 72 | table.insert(func, string.format('if t%di%d == t%dsz%d then', n, i, n, i)) 73 | table.insert(func, string.format('t%di%d = 0', n, i)) 74 | end 75 | end 76 | for i=dim[n]-2,1,-1 do 77 | table.insert(func, 'end') 78 | end 79 | table.insert(func, 'end') 80 | end 81 | end 82 | 83 | if #cond > 0 then 84 | table.insert(func, 'end') 85 | end 86 | table.insert(func, 'end') 87 | return table.concat(func, '\n') 88 | end 89 | 90 | local applyfuncs = {} 91 | function torch.rawapply(t1, func, c1, c2, c3, c4, c5) 92 | local dim = tonumber(t1.__nDimension) 93 | local applyfunc = applyfuncs[dim] 94 | if not applyfunc then 95 | applyfunc = loadstring(generate_apply({dim}))() 96 | applyfuncs[dim] = applyfunc 97 | end 98 | applyfunc(t1, func, c1, c2, c3, c4, c5) 99 | end 100 | 101 | local apply2funcs = {} 102 | function torch.rawapply2(t1, t2, func, c1, c2, c3, c4, c5) 103 | local dim1 = tonumber(t1.__nDimension) 104 | local dim2 = tonumber(t2.__nDimension) 105 | apply2funcs[dim1] = apply2funcs[dim1] or {} 106 | local applyfunc = apply2funcs[dim1][dim2] 107 | if not applyfunc then 108 | applyfunc = loadstring(generate_apply({dim1,dim2}))() 109 | apply2funcs[dim1][dim2] = applyfunc 110 | end 111 | applyfunc(t1, t2, func, c1, c2, c3, c4, c5) 112 | end 113 | 114 | local apply3funcs = {} 115 | function 
torch.rawapply3(t1, t2, t3, func, c1, c2, c3, c4, c5) 116 | local dim1 = tonumber(t1.__nDimension) 117 | local dim2 = tonumber(t2.__nDimension) 118 | local dim3 = tonumber(t3.__nDimension) 119 | apply3funcs[dim1] = apply3funcs[dim1] or {} 120 | apply3funcs[dim1][dim2] = apply3funcs[dim1][dim2] or {} 121 | local applyfunc = apply3funcs[dim1][dim2][dim3] 122 | if not applyfunc then 123 | applyfunc = loadstring(generate_apply({dim1,dim2,dim3}))() 124 | apply3funcs[dim1][dim2][dim3] = applyfunc 125 | end 126 | applyfunc(t1, t2, t3, func, c1, c2, c3, c4, c5) 127 | end 128 | -------------------------------------------------------------------------------- /TH/generic/THTensorConv.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THTensorConv.h" 3 | #else 4 | 5 | 6 | TH_API void THTensor_(validXCorr2Dptr)(real *r_, 7 | real alpha, 8 | real *t_, long ir, long ic, 9 | real *k_, long kr, long kc, 10 | long sr, long sc); 11 | 12 | TH_API void THTensor_(validConv2Dptr)(real *r_, 13 | real alpha, 14 | real *t_, long ir, long ic, 15 | real *k_, long kr, long kc, 16 | long sr, long sc); 17 | 18 | TH_API void THTensor_(fullXCorr2Dptr)(real *r_, 19 | real alpha, 20 | real *t_, long ir, long ic, 21 | real *k_, long kr, long kc, 22 | long sr, long sc); 23 | 24 | TH_API void THTensor_(fullConv2Dptr)(real *r_, 25 | real alpha, 26 | real *t_, long ir, long ic, 27 | real *k_, long kr, long kc, 28 | long sr, long sc); 29 | 30 | TH_API void THTensor_(validXCorr2DRevptr)(real *r_, 31 | real alpha, 32 | real *t_, long ir, long ic, 33 | real *k_, long kr, long kc, 34 | long sr, long sc); 35 | 36 | TH_API void THTensor_(conv2DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol); 37 | TH_API void THTensor_(conv2DRevgerm)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol); 38 | TH_API void THTensor_(conv2Dger)(THTensor 
*r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); 39 | TH_API void THTensor_(conv2Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); 40 | TH_API void THTensor_(conv2Dmm)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); 41 | TH_API void THTensor_(conv2Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); 42 | TH_API void THTensor_(conv2Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); 43 | 44 | TH_API void THTensor_(validXCorr3Dptr)(real *r_, 45 | real alpha, 46 | real *t_, long it, long ir, long ic, 47 | real *k_, long kt, long kr, long kc, 48 | long st, long sr, long sc); 49 | 50 | TH_API void THTensor_(validConv3Dptr)(real *r_, 51 | real alpha, 52 | real *t_, long it, long ir, long ic, 53 | real *k_, long kt, long kr, long kc, 54 | long st, long sr, long sc); 55 | 56 | TH_API void THTensor_(fullXCorr3Dptr)(real *r_, 57 | real alpha, 58 | real *t_, long it, long ir, long ic, 59 | real *k_, long kt, long kr, long kc, 60 | long st, long sr, long sc); 61 | 62 | TH_API void THTensor_(fullConv3Dptr)(real *r_, 63 | real alpha, 64 | real *t_, long it, long ir, long ic, 65 | real *k_, long kt, long kr, long kc, 66 | long st, long sr, long sc); 67 | 68 | TH_API void THTensor_(validXCorr3DRevptr)(real *r_, 69 | real alpha, 70 | real *t_, long it, long ir, long ic, 71 | real *k_, long kt, long kr, long kc, 72 | long st, long sr, long sc); 73 | 74 | TH_API void THTensor_(conv3DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol); 75 | TH_API void THTensor_(conv3Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const 
char *vf, const char *xc); 76 | TH_API void THTensor_(conv3Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); 77 | TH_API void THTensor_(conv3Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); 78 | TH_API void THTensor_(conv3Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); 79 | 80 | #endif 81 | -------------------------------------------------------------------------------- /TH/THFile.c: -------------------------------------------------------------------------------- 1 | #include "THFile.h" 2 | #include "THFilePrivate.h" 3 | 4 | #define IMPLEMENT_THFILE_RW(TYPEC, TYPE) \ 5 | long THFile_read##TYPEC##Raw(THFile *self, TYPE *data, long n) \ 6 | { \ 7 | return (*self->vtable->read##TYPEC)(self, data, n); \ 8 | } \ 9 | \ 10 | long THFile_write##TYPEC##Raw(THFile *self, TYPE *data, long n) \ 11 | { \ 12 | return (*self->vtable->write##TYPEC)(self, data, n); \ 13 | } 14 | 15 | IMPLEMENT_THFILE_RW(Byte, unsigned char) 16 | IMPLEMENT_THFILE_RW(Char, char) 17 | IMPLEMENT_THFILE_RW(Short, short) 18 | IMPLEMENT_THFILE_RW(Int, int) 19 | IMPLEMENT_THFILE_RW(Long, long) 20 | IMPLEMENT_THFILE_RW(Float, float) 21 | IMPLEMENT_THFILE_RW(Double, double) 22 | 23 | long THFile_readStringRaw(THFile *self, const char *format, char **str_) 24 | { 25 | return self->vtable->readString(self, format, str_); 26 | } 27 | 28 | long THFile_writeStringRaw(THFile *self, const char *str, long size) 29 | { 30 | return self->vtable->writeString(self, str, size); 31 | } 32 | 33 | void THFile_synchronize(THFile *self) 34 | { 35 | self->vtable->synchronize(self); 36 | } 37 | 38 | void THFile_seek(THFile *self, long position) 39 | { 40 | self->vtable->seek(self, position); 41 | } 42 | 43 | void THFile_seekEnd(THFile *self) 44 | { 45 | 
self->vtable->seekEnd(self); 46 | } 47 | 48 | long THFile_position(THFile *self) 49 | { 50 | return self->vtable->position(self); 51 | } 52 | 53 | void THFile_close(THFile *self) 54 | { 55 | self->vtable->close(self); 56 | } 57 | 58 | void THFile_free(THFile *self) 59 | { 60 | self->vtable->free(self); 61 | } 62 | 63 | int THFile_isOpened(THFile *self) 64 | { 65 | return self->vtable->isOpened(self); 66 | } 67 | 68 | #define IMPLEMENT_THFILE_FLAGS(FLAG) \ 69 | int THFile_##FLAG(THFile *self) \ 70 | { \ 71 | return self->FLAG; \ 72 | } 73 | 74 | IMPLEMENT_THFILE_FLAGS(isQuiet) 75 | IMPLEMENT_THFILE_FLAGS(isReadable) 76 | IMPLEMENT_THFILE_FLAGS(isWritable) 77 | IMPLEMENT_THFILE_FLAGS(isBinary) 78 | IMPLEMENT_THFILE_FLAGS(isAutoSpacing) 79 | IMPLEMENT_THFILE_FLAGS(hasError) 80 | 81 | void THFile_binary(THFile *self) 82 | { 83 | self->isBinary = 1; 84 | } 85 | 86 | void THFile_ascii(THFile *self) 87 | { 88 | self->isBinary = 0; 89 | } 90 | 91 | void THFile_autoSpacing(THFile *self) 92 | { 93 | self->isAutoSpacing = 1; 94 | } 95 | 96 | void THFile_noAutoSpacing(THFile *self) 97 | { 98 | self->isAutoSpacing = 0; 99 | } 100 | 101 | void THFile_quiet(THFile *self) 102 | { 103 | self->isQuiet = 1; 104 | } 105 | 106 | void THFile_pedantic(THFile *self) 107 | { 108 | self->isQuiet = 0; 109 | } 110 | 111 | void THFile_clearError(THFile *self) 112 | { 113 | self->hasError = 0; 114 | } 115 | 116 | #define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE) \ 117 | TYPE THFile_read##TYPEC##Scalar(THFile *self) \ 118 | { \ 119 | TYPE scalar; \ 120 | THFile_read##TYPEC##Raw(self, &scalar, 1); \ 121 | return scalar; \ 122 | } \ 123 | \ 124 | void THFile_write##TYPEC##Scalar(THFile *self, TYPE scalar) \ 125 | { \ 126 | THFile_write##TYPEC##Raw(self, &scalar, 1); \ 127 | } 128 | 129 | IMPLEMENT_THFILE_SCALAR(Byte, unsigned char) 130 | IMPLEMENT_THFILE_SCALAR(Char, char) 131 | IMPLEMENT_THFILE_SCALAR(Short, short) 132 | IMPLEMENT_THFILE_SCALAR(Int, int) 133 | IMPLEMENT_THFILE_SCALAR(Long, long) 134 | 
IMPLEMENT_THFILE_SCALAR(Float, float) 135 | IMPLEMENT_THFILE_SCALAR(Double, double) 136 | 137 | #define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \ 138 | long THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ 139 | { \ 140 | return THFile_read##TYPEC##Raw(self, storage->data, storage->size); \ 141 | } \ 142 | \ 143 | long THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ 144 | { \ 145 | return THFile_write##TYPEC##Raw(self, storage->data, storage->size); \ 146 | } 147 | 148 | IMPLEMENT_THFILE_STORAGE(Byte, unsigned char) 149 | IMPLEMENT_THFILE_STORAGE(Char, char) 150 | IMPLEMENT_THFILE_STORAGE(Short, short) 151 | IMPLEMENT_THFILE_STORAGE(Int, int) 152 | IMPLEMENT_THFILE_STORAGE(Long, long) 153 | IMPLEMENT_THFILE_STORAGE(Float, float) 154 | IMPLEMENT_THFILE_STORAGE(Double, double) 155 | -------------------------------------------------------------------------------- /conv.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local argcheck = require 'argcheck' 3 | local torch = require 'torch.env' 4 | local class = require 'class' 5 | local C = require 'torch.TH' 6 | local register_ = require 'torch.registernumbers' 7 | 8 | -- handle method/function 9 | local function register(args) 10 | if args.nomethod and not args.nofunction then 11 | return register_(args, torch, nil) 12 | elseif args.nofunction and not args.nomethod then 13 | return register_(args, nil, class.metatable('torch.RealTensor')) 14 | else 15 | return register_(args, torch, class.metatable('torch.RealTensor')) 16 | end 17 | end 18 | 19 | register{ 20 | name = "conv2", 21 | {name="dst", type="torch.RealTensor", opt=true, method={opt=false}}, 22 | {name="src1", type="torch.RealTensor", method={opt=true}}, 23 | {name="src2", type="torch.RealTensor"}, 24 | {name="opt", type="string", default='V'}, 25 | call = 26 | function(dst, src1, src2, opt) 27 | assert(opt == 'F' or opt == 'V', 'option must be F or V') 28 |
local res = src1 and dst or torch.RealTensor() 29 | src1 = src1 or dst 30 | if src1.__nDimension == 2 and src2.__nDimension == 2 then 31 | C.THRealTensor_conv2Dmul(res, 0, 1, src1, src2, 1, 1, opt, 'C') 32 | elseif src1.__nDimension == 3 and src2.__nDimension == 3 then 33 | C.THRealTensor_conv2Dcmul(res, 0, 1, src1, src2, 1, 1, opt, 'C') 34 | elseif src1.__nDimension == 3 and src2.__nDimension == 4 then 35 | C.THRealTensor_conv2Dmv(res, 0, 1, src1, src2, 1, 1, opt, 'C') 36 | else 37 | error('invalid source dimensions (expected: 2/2 or 3/3 or 3/4)') 38 | end 39 | return res 40 | end 41 | } 42 | 43 | register{ 44 | name = "xcorr2", 45 | {name="dst", type="torch.RealTensor", opt=true, method={opt=false}}, 46 | {name="src1", type="torch.RealTensor", method={opt=true}}, 47 | {name="src2", type="torch.RealTensor"}, 48 | {name="opt", type="string", default='V'}, 49 | call = 50 | function(dst, src1, src2, opt) 51 | assert(opt == 'F' or opt == 'V', 'option must be F or V') 52 | local res = src1 and dst or torch.RealTensor() 53 | src1 = src1 or dst 54 | if src1.__nDimension == 2 and src2.__nDimension == 2 then 55 | C.THRealTensor_conv2Dmul(res, 0, 1, src1, src2, 1, 1, opt, 'X') 56 | elseif src1.__nDimension == 3 and src2.__nDimension == 3 then 57 | C.THRealTensor_conv2Dcmul(res, 0, 1, src1, src2, 1, 1, opt, 'X') 58 | elseif src1.__nDimension == 3 and src2.__nDimension == 4 then 59 | C.THRealTensor_conv2Dmv(res, 0, 1, src1, src2, 1, 1, opt, 'X') 60 | else 61 | error('invalid source dimensions (expected: 2/2 or 3/3 or 3/4)') 62 | end 63 | return res 64 | end 65 | } 66 | 67 | register{ 68 | name = "conv3", 69 | {name="dst", type="torch.RealTensor", opt=true, method={opt=false}}, 70 | {name="src1", type="torch.RealTensor", method={opt=true}}, 71 | {name="src2", type="torch.RealTensor"}, 72 | {name="opt", type="string", default='V'}, 73 | call = 74 | function(dst, src1, src2, opt) 75 | assert(opt == 'F' or opt == 'V', 'option must be F or V') 76 | local res = src1 and dst or
torch.RealTensor() 77 | src1 = src1 or dst 78 | if src1.__nDimension == 3 and src2.__nDimension == 3 then 79 | C.THRealTensor_conv3Dmul(res, 0, 1, src1, src2, 1, 1, 1, opt, 'C') 80 | elseif src1.__nDimension == 4 and src2.__nDimension == 4 then 81 | C.THRealTensor_conv3Dcmul(res, 0, 1, src1, src2, 1, 1, 1, opt, 'C') 82 | elseif src1.__nDimension == 4 and src2.__nDimension == 5 then 83 | C.THRealTensor_conv3Dmv(res, 0, 1, src1, src2, 1, 1, 1, opt, 'C') 84 | else 85 | error('invalid source dimensions (expected: 3/3 or 4/4 or 4/5)') 86 | end 87 | return res 88 | end 89 | } 90 | 91 | register{ 92 | name = "xcorr3", 93 | {name="dst", type="torch.RealTensor", opt=true, method={opt=false}}, 94 | {name="src1", type="torch.RealTensor", method={opt=true}}, 95 | {name="src2", type="torch.RealTensor"}, 96 | {name="opt", type="string", default='V'}, 97 | call = 98 | function(dst, src1, src2, opt) 99 | assert(opt == 'F' or opt == 'V', 'option must be F or V') 100 | local res = src1 and dst or torch.RealTensor() 101 | src1 = src1 or dst 102 | if src1.__nDimension == 3 and src2.__nDimension == 3 then 103 | C.THRealTensor_conv3Dmul(res, 0, 1, src1, src2, 1, 1, 1, opt, 'X') 104 | elseif src1.__nDimension == 4 and src2.__nDimension == 4 then 105 | C.THRealTensor_conv3Dcmul(res, 0, 1, src1, src2, 1, 1, 1, opt, 'X') 106 | elseif src1.__nDimension == 4 and src2.__nDimension == 5 then 107 | C.THRealTensor_conv3Dmv(res, 0, 1, src1, src2, 1, 1, 1, opt, 'X') 108 | else 109 | error('invalid source dimensions (expected: 3/3 or 4/4 or 4/5)') 110 | end 111 | return res 112 | end 113 | } 114 | -------------------------------------------------------------------------------- /TH/generic/THStorage.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THStorage.c" 3 | #else 4 | 5 | real* THStorage_(data)(const THStorage *self) 6 | { 7 | return self->data; 8 | } 9 | 10 | long THStorage_(size)(const THStorage
*self) 11 | { 12 | return self->size; 13 | } 14 | 15 | THStorage* THStorage_(new)(void) 16 | { 17 | return THStorage_(newWithSize)(0); 18 | } 19 | 20 | THStorage* THStorage_(newWithSize)(long size) 21 | { 22 | return THStorage_(newWithAllocator)(size, &THDefaultAllocator, NULL); 23 | } 24 | 25 | THStorage* THStorage_(newWithAllocator)(long size, 26 | THAllocator *allocator, 27 | void *allocatorContext) 28 | { 29 | THStorage *storage = THAlloc(sizeof(THStorage)); 30 | storage->data = allocator->malloc(allocatorContext, sizeof(real)*size); 31 | storage->size = size; 32 | storage->refcount = 1; 33 | storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; 34 | storage->allocator = allocator; 35 | storage->allocatorContext = allocatorContext; 36 | return storage; 37 | } 38 | 39 | THStorage* THStorage_(newWithMapping)(const char *filename, long size, int shared) 40 | { 41 | THMapAllocatorContext *ctx = THMapAllocatorContext_new(filename, shared); 42 | 43 | THStorage *storage = THStorage_(newWithAllocator)(size, 44 | &THMapAllocator, 45 | ctx); 46 | 47 | if(size <= 0) 48 | storage->size = THMapAllocatorContext_size(ctx)/sizeof(real); 49 | 50 | THStorage_(clearFlag)(storage, TH_STORAGE_RESIZABLE); 51 | 52 | return storage; 53 | } 54 | 55 | THStorage* THStorage_(newWithSize1)(real data0) 56 | { 57 | THStorage *self = THStorage_(newWithSize)(1); 58 | self->data[0] = data0; 59 | return self; 60 | } 61 | 62 | THStorage* THStorage_(newWithSize2)(real data0, real data1) 63 | { 64 | THStorage *self = THStorage_(newWithSize)(2); 65 | self->data[0] = data0; 66 | self->data[1] = data1; 67 | return self; 68 | } 69 | 70 | THStorage* THStorage_(newWithSize3)(real data0, real data1, real data2) 71 | { 72 | THStorage *self = THStorage_(newWithSize)(3); 73 | self->data[0] = data0; 74 | self->data[1] = data1; 75 | self->data[2] = data2; 76 | return self; 77 | } 78 | 79 | THStorage* THStorage_(newWithSize4)(real data0, real data1, real data2, real data3) 80 | { 81
| THStorage *self = THStorage_(newWithSize)(4); 82 | self->data[0] = data0; 83 | self->data[1] = data1; 84 | self->data[2] = data2; 85 | self->data[3] = data3; 86 | return self; 87 | } 88 | 89 | void THStorage_(setFlag)(THStorage *storage, const char flag) 90 | { 91 | storage->flag |= flag; 92 | } 93 | 94 | void THStorage_(clearFlag)(THStorage *storage, const char flag) 95 | { 96 | storage->flag &= ~flag; 97 | } 98 | 99 | void THStorage_(retain)(THStorage *storage) 100 | { 101 | if(storage && (storage->flag & TH_STORAGE_REFCOUNTED)) 102 | ++storage->refcount; 103 | } 104 | 105 | void THStorage_(free)(THStorage *storage) 106 | { 107 | if(!storage) 108 | return; 109 | 110 | if((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount > 0)) 111 | { 112 | if(--storage->refcount == 0) 113 | { 114 | if(storage->flag & TH_STORAGE_FREEMEM) 115 | storage->allocator->free(storage->allocatorContext, storage->data); 116 | THFree(storage); 117 | } 118 | } 119 | } 120 | 121 | THStorage* THStorage_(newWithData)(real *data, long size) 122 | { 123 | return THStorage_(newWithDataAndAllocator)(data, size, 124 | &THDefaultAllocator, NULL); 125 | } 126 | 127 | THStorage* THStorage_(newWithDataAndAllocator)(real* data, long size, 128 | THAllocator* allocator, 129 | void* allocatorContext) { 130 | THStorage *storage = THAlloc(sizeof(THStorage)); 131 | storage->data = data; 132 | storage->size = size; 133 | storage->refcount = 1; 134 | storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; 135 | storage->allocator = allocator; 136 | storage->allocatorContext = allocatorContext; 137 | return storage; 138 | } 139 | 140 | void THStorage_(resize)(THStorage *storage, long size) 141 | { 142 | if(storage->flag & TH_STORAGE_RESIZABLE) 143 | { 144 | storage->data = storage->allocator->realloc( 145 | storage->allocatorContext, 146 | storage->data, 147 | sizeof(real)*size); 148 | storage->size = size; 149 | } 150 | } 151 | 152 | void THStorage_(fill)(THStorage
*storage, real value) 153 | { 154 | long i; 155 | for(i = 0; i < storage->size; i++) 156 | storage->data[i] = value; 157 | } 158 | 159 | void THStorage_(set)(THStorage *self, long idx, real value) 160 | { 161 | THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); 162 | self->data[idx] = value; 163 | } 164 | 165 | real THStorage_(get)(const THStorage *self, long idx) 166 | { 167 | THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); 168 | return self->data[idx]; 169 | } 170 | 171 | #endif 172 | -------------------------------------------------------------------------------- /lapack.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local argcheck = require 'argcheck' 3 | local torch = require 'torch.env' 4 | local class = require 'class' 5 | local C = require 'torch.TH' 6 | local register_ = require 'torch.registernumbers' 7 | 8 | -- handle method/function 9 | local function register(args) 10 | if args.nomethod and not args.nofunction then 11 | return register_(args, torch, nil) 12 | elseif args.nofunction and not args.nomethod then 13 | return register_(args, nil, class.metatable('torch.RealTensor')) 14 | else 15 | return register_(args, torch, class.metatable('torch.RealTensor')) 16 | end 17 | end 18 | 19 | if 'real' == 'float' or 'real' == 'double' then 20 | 21 | register{ 22 | name = "gesv", 23 | {name="B", type="torch.RealTensor"}, 24 | {name="A", type="torch.RealTensor"}, 25 | call = 26 | function(B, A) 27 | local X = torch.RealTensor() 28 | local LU = torch.RealTensor() 29 | C.THRealTensor_gesv(X, LU, B, A) 30 | return X, LU 31 | end 32 | } 33 | 34 | register{ 35 | nomethod = true, 36 | name = "gesv", 37 | {name="X", type="torch.RealTensor"}, 38 | {name="LU", type="torch.RealTensor"}, 39 | {name="B", type="torch.RealTensor"}, 40 | {name="A", type="torch.RealTensor"}, 41 | call = 42 | function(X, LU, B, A) 43 | C.THRealTensor_gesv(X, LU, B, A) 44 | return X, LU 45 |
end 46 | } 47 | 48 | register{ 49 | name = "gels", 50 | {name="B", type="torch.RealTensor"}, 51 | {name="A", type="torch.RealTensor"}, 52 | call = 53 | function(B, A) 54 | local X = torch.RealTensor() 55 | local LU = torch.RealTensor() 56 | C.THRealTensor_gels(X, LU, B, A) 57 | return X, LU 58 | end 59 | } 60 | 61 | register{ 62 | nomethod = true, 63 | name = "gels", 64 | {name="X", type="torch.RealTensor"}, 65 | {name="LU", type="torch.RealTensor"}, 66 | {name="B", type="torch.RealTensor"}, 67 | {name="A", type="torch.RealTensor"}, 68 | call = 69 | function(X, LU, B, A) 70 | C.THRealTensor_gels(X, LU, B, A) 71 | return X, LU 72 | end 73 | } 74 | 75 | register{ 76 | name = "symeig", 77 | {name="A", type="torch.RealTensor"}, 78 | {name="opteig", type="string", default='N'}, 79 | {name="opttriang", type="string", default='U'}, 80 | call = 81 | function(A, opteig, opttriang) 82 | assert(opteig == 'N' or opteig == 'V', 'opteig: N or V expected') 83 | assert(opttriang == 'L' or opttriang == 'U', 'opttriang: L or U expected') 84 | local E = torch.RealTensor() 85 | local V = torch.RealTensor() 86 | C.THRealTensor_syev(E, V, A, opteig, opttriang) 87 | return E, V 88 | end 89 | } 90 | 91 | register{ 92 | nomethod = true, 93 | name = "symeig", 94 | {name="E", type="torch.RealTensor"}, 95 | {name="V", type="torch.RealTensor"}, 96 | {name="A", type="torch.RealTensor"}, 97 | {name="opteig", type="string", default='N'}, 98 | {name="opttriang", type="string", default='U'}, 99 | call = 100 | function(E, V, A, opteig, opttriang) 101 | assert(opteig == 'N' or opteig == 'V', 'opteig: N or V expected') 102 | assert(opttriang == 'L' or opttriang == 'U', 'opttriang: L or U expected') 103 | C.THRealTensor_syev(E, V, A, opteig, opttriang) 104 | return E, V 105 | end 106 | } 107 | 108 | register{ 109 | name = "eig", 110 | {name="A", type="torch.RealTensor"}, 111 | {name="opteig", type="string", default='N'}, 112 | call = 113 | function(A, opteig) 114 | assert(opteig == 'N' or opteig ==
'V', 'opteig: N or V expected') 115 | local E = torch.RealTensor() 116 | local V = torch.RealTensor() 117 | C.THRealTensor_geev(E, V, A, opteig) 118 | return E, V 119 | end 120 | } 121 | 122 | register{ 123 | nomethod = true, 124 | name = "eig", 125 | {name="E", type="torch.RealTensor"}, 126 | {name="V", type="torch.RealTensor"}, 127 | {name="A", type="torch.RealTensor"}, 128 | {name="opteig", type="string", default='N'}, 129 | call = 130 | function(E, V, A, opteig) 131 | assert(opteig == 'N' or opteig == 'V', 'opteig: N or V expected') 132 | C.THRealTensor_geev(E, V, A, opteig) 133 | return E, V 134 | end 135 | } 136 | 137 | register{ 138 | name = "svd", 139 | {name="A", type="torch.RealTensor"}, 140 | {name="opteig", type="string", default='S'}, 141 | call = 142 | function(A, opteig) 143 | assert(opteig == 'S' or opteig == 'A', 'opteig: S or A expected') 144 | local U = torch.RealTensor() 145 | local S = torch.RealTensor() 146 | local V = torch.RealTensor() 147 | C.THRealTensor_gesvd(U, S, V, A, opteig) 148 | return U, S, V 149 | end 150 | } 151 | 152 | register{ 153 | nomethod = true, 154 | name = "svd", 155 | {name="U", type="torch.RealTensor"}, 156 | {name="S", type="torch.RealTensor"}, 157 | {name="V", type="torch.RealTensor"}, 158 | {name="A", type="torch.RealTensor"}, 159 | {name="opteig", type="string", default='S'}, 160 | call = 161 | function(U, S, V, A, opteig) 162 | assert(opteig == 'S' or opteig == 'A', 'opteig: S or A expected') 163 | C.THRealTensor_gesvd(U, S, V, A, opteig) 164 | return U, S, V 165 | end 166 | } 167 | 168 | for _, name in ipairs{'inverse', 'potri', 'potrf'} do 169 | local cname = name == 'inverse' and 'getri' or name 170 | local func = C['THRealTensor_' ..
cname] 171 | register{ 172 | name = name, 173 | {name="dst", type="torch.RealTensor", opt=true, method={opt=false}}, 174 | {name="src", type="torch.RealTensor", method={opt=true}}, 175 | call = 176 | function(dst, src) 177 | local res = src and dst or torch.RealTensor() 178 | src = src or dst 179 | func(res, src) 180 | return res 181 | end 182 | } 183 | 184 | end 185 | 186 | end 187 | -------------------------------------------------------------------------------- /storage.lua: -------------------------------------------------------------------------------- 1 | -- todo: 2 | -- RealRealStorage (more convenient, if only for THRealStorage, RealStorage...) 3 | -- update the template script accordingly 4 | -- make TH func return storage, tensors... this would avoid extra lua func redirections 5 | -- prefix the struct fields (.data as .__data...) in the FFI declarations 6 | 7 | local display = require 'torch.display' 8 | local argcheck = require 'argcheck' 9 | local torch = require 'torch.env' 10 | local class = require 'class' 11 | local ffi = require 'ffi' 12 | local C = require 'torch.TH' 13 | 14 | local RealStorage = class('torch.RealStorage', nil, ffi.typeof('THRealStorage&')) 15 | torch.RealStorage = RealStorage 16 | 17 | RealStorage.__factory = 18 | function() 19 | local self = C.THRealStorage_new()[0] 20 | ffi.gc(self, C.THRealStorage_free) 21 | return self 22 | end 23 | 24 | RealStorage.new = argcheck{ 25 | {name="size", type="number", default=0}, 26 | nonamed = true, 27 | call = 28 | function(size) 29 | local self = C.THRealStorage_newWithSize(size)[0] 30 | ffi.gc(self, C.THRealStorage_free) 31 | return self 32 | end 33 | } 34 | 35 | argcheck{ 36 | {name="table", type="table"}, 37 | chain = RealStorage.new, 38 | nonamed = true, 39 | call = 40 | function(tbl) 41 | local size = #tbl 42 | local self = C.THRealStorage_newWithSize(size)[0] -- must be local: a bare assignment would write the global `self` 43 | ffi.gc(self, C.THRealStorage_free) 44 | for i=1,size do 45 | self.__data[i-1] = tbl[i] 46 | end 47 | return
self 48 | end 49 | } 50 | 51 | argcheck{ 52 | {name="filename", type="string"}, 53 | {name="shared", type="boolean", default=false}, 54 | {name="size", type="number", default=0}, 55 | chain = RealStorage.new, 56 | nonamed = true, 57 | call = 58 | function(filename, shared, size) 59 | local self = C.THRealStorage_newWithMapping(filename, size, shared and 1 or 0)[0] 60 | ffi.gc(self, C.THRealStorage_free) 61 | return self 62 | end 63 | } 64 | 65 | RealStorage.fill = argcheck{ 66 | {name="self", type="torch.RealStorage"}, 67 | {name="value", type="number"}, 68 | call = C.THRealStorage_fill 69 | } 70 | 71 | RealStorage.size = argcheck{ 72 | {name="self", type="torch.RealStorage"}, 73 | call = 74 | function(self) 75 | return tonumber(self.__size) 76 | end 77 | } 78 | 79 | RealStorage.resize = argcheck{ 80 | {name="self", type="torch.RealStorage"}, 81 | {name="size", type="number"}, 82 | call = C.THRealStorage_resize 83 | } 84 | 85 | RealStorage.rawCopy = argcheck{ 86 | {name="self", type="torch.RealStorage"}, 87 | {name="data", type="cdata"}, 88 | call = 89 | function(self, data) 90 | ffi.copy(self.__data, data, ffi.sizeof('real')*self.__size) 91 | return self 92 | end 93 | } 94 | 95 | RealStorage.totable = argcheck{ 96 | {name="self", type="torch.RealStorage"}, 97 | call = 98 | function(self) 99 | local tbl = {} 100 | for i=1,tonumber(self.__size) do -- tonumber(): a numeric 'for' limit must be a Lua number, and __size is an FFI field 101 | tbl[i] = self.__data[i-1] 102 | end 103 | return tbl 104 | end 105 | } 106 | 107 | if "RealStorage" == "CharRealStorage" or "RealStorage" == "ByteRealStorage" then 108 | RealStorage.string = argcheck{ 109 | {name="self", type="torch.RealStorage"}, 110 | call = 111 | function(self) 112 | return ffi.string(self.__data, self.__size) 113 | end 114 | } 115 | 116 | argcheck{ 117 | {name="self", type="torch.RealStorage"}, 118 | {name="data", type="string"}, 119 | chain = RealStorage.string, 120 | call = 121 | function(self, data) 122 | self:resize(#data) 123 | C.THRealStorage_rawCopy(self, ffi.cast('real*', data)) 124 | return self
125 | end 126 | } 127 | end 128 | 129 | RealStorage.copy = argcheck{ 130 | {name="self", type='torch.RealStorage'}, 131 | {name="src", type='torch.RealStorage'}, 132 | call = C.THRealStorage_copy 133 | } 134 | 135 | argcheck{ 136 | {name="self", type='torch.RealStorage'}, 137 | {name="src", type='torch.ByteStorage'}, 138 | chain = RealStorage.copy, 139 | call = C.THRealStorage_copyByte 140 | } 141 | 142 | argcheck{ 143 | {name="self", type='torch.RealStorage'}, 144 | {name="src", type='torch.CharStorage'}, 145 | chain = RealStorage.copy, 146 | call = C.THRealStorage_copyChar 147 | } 148 | 149 | argcheck{ 150 | {name="self", type='torch.RealStorage'}, 151 | {name="src", type='torch.ShortStorage'}, 152 | chain = RealStorage.copy, 153 | call = C.THRealStorage_copyShort 154 | } 155 | 156 | argcheck{ 157 | {name="self", type='torch.RealStorage'}, 158 | {name="src", type='torch.IntStorage'}, 159 | chain = RealStorage.copy, 160 | call = C.THRealStorage_copyInt 161 | } 162 | 163 | argcheck{ 164 | {name="self", type='torch.RealStorage'}, 165 | {name="src", type='torch.LongStorage'}, 166 | chain = RealStorage.copy, 167 | call = C.THRealStorage_copyLong 168 | } 169 | 170 | argcheck{ 171 | {name="self", type='torch.RealStorage'}, 172 | {name="src", type='torch.FloatStorage'}, 173 | chain = RealStorage.copy, 174 | call = C.THRealStorage_copyFloat 175 | } 176 | 177 | argcheck{ 178 | {name="self", type='torch.RealStorage'}, 179 | {name="src", type='torch.DoubleStorage'}, 180 | chain = RealStorage.copy, 181 | call = C.THRealStorage_copyDouble 182 | } 183 | 184 | function RealStorage:__index(k) 185 | if type(k) == 'number' then 186 | if k > 0 and k <= tonumber(self.__size) then 187 | return tonumber(self.__data[k-1]) 188 | else 189 | error('index out of bounds') 190 | end 191 | else 192 | return RealStorage[k] 193 | end 194 | end 195 | 196 | function RealStorage:__newindex(k, v) 197 | if type(k) == 'number' then 198 | if k > 0 and k <= tonumber(self.__size) then 199 | self.__data[k-1] = v
200 | else 201 | error('index out of bounds') 202 | end 203 | else 204 | rawset(self, k, v) 205 | end 206 | end 207 | 208 | function RealStorage:__len() 209 | return tonumber(self.__size) -- plain Lua number, consistent with RealStorage.size 210 | end 211 | 212 | function RealStorage:write(file) 213 | file:writeLong(self.__size) 214 | file:writeRaw('real', self.__data, self.__size) 215 | end 216 | 217 | function RealStorage:read(file) 218 | local size = file:readLong() 219 | C.THRealStorage_resize(self, size) -- size the storage to the stored length before reading into it (rawInitWithSize was never defined) 220 | file:readRaw('real', self.__data, self.__size) 221 | end 222 | 223 | RealStorage.__tostring = display.storage 224 | 225 | ffi.metatype('THRealStorage', getmetatable(RealStorage)) 226 | -------------------------------------------------------------------------------- /TH/generic/THTensor.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THTensor.h" 3 | #else 4 | 5 | /* Lua-style? dim, storageoffset, ... and the methods? */ 6 | 7 | #define TH_TENSOR_REFCOUNTED 1 8 | 9 | typedef struct THTensor 10 | { 11 | long *size; 12 | long *stride; 13 | int nDimension; 14 | 15 | THStorage *storage; 16 | long storageOffset; 17 | int refcount; 18 | 19 | char flag; 20 | 21 | } THTensor; 22 | 23 | 24 | /**** access methods ****/ 25 | TH_API THStorage* THTensor_(storage)(const THTensor *self); 26 | TH_API long THTensor_(storageOffset)(const THTensor *self); 27 | TH_API int THTensor_(nDimension)(const THTensor *self); 28 | TH_API long THTensor_(size)(const THTensor *self, int dim); 29 | TH_API long THTensor_(stride)(const THTensor *self, int dim); 30 | TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self); 31 | TH_API THLongStorage *THTensor_(newStrideOf)(THTensor *self); 32 | TH_API real *THTensor_(data)(const THTensor *self); 33 | 34 | TH_API void THTensor_(setFlag)(THTensor *self, const char flag); 35 | TH_API void THTensor_(clearFlag)(THTensor *self, const char flag); 36 | 37 | 38 | /**** creation methods ****/ 39 | TH_API THTensor
*THTensor_(new)(void); 40 | TH_API THTensor *THTensor_(newWithTensor)(THTensor *tensor); 41 | /* stride might be NULL */ 42 | TH_API THTensor *THTensor_(newWithStorage)(THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_); 43 | TH_API THTensor *THTensor_(newWithStorage1d)(THStorage *storage_, long storageOffset_, 44 | long size0_, long stride0_); 45 | TH_API THTensor *THTensor_(newWithStorage2d)(THStorage *storage_, long storageOffset_, 46 | long size0_, long stride0_, 47 | long size1_, long stride1_); 48 | TH_API THTensor *THTensor_(newWithStorage3d)(THStorage *storage_, long storageOffset_, 49 | long size0_, long stride0_, 50 | long size1_, long stride1_, 51 | long size2_, long stride2_); 52 | TH_API THTensor *THTensor_(newWithStorage4d)(THStorage *storage_, long storageOffset_, 53 | long size0_, long stride0_, 54 | long size1_, long stride1_, 55 | long size2_, long stride2_, 56 | long size3_, long stride3_); 57 | 58 | /* stride might be NULL */ 59 | TH_API THTensor *THTensor_(newWithSize)(THLongStorage *size_, THLongStorage *stride_); 60 | TH_API THTensor *THTensor_(newWithSize1d)(long size0_); 61 | TH_API THTensor *THTensor_(newWithSize2d)(long size0_, long size1_); 62 | TH_API THTensor *THTensor_(newWithSize3d)(long size0_, long size1_, long size2_); 63 | TH_API THTensor *THTensor_(newWithSize4d)(long size0_, long size1_, long size2_, long size3_); 64 | 65 | TH_API THTensor *THTensor_(newClone)(THTensor *self); 66 | TH_API THTensor *THTensor_(newContiguous)(THTensor *tensor); 67 | TH_API THTensor *THTensor_(newSelect)(THTensor *tensor, int dimension_, long sliceIndex_); 68 | TH_API THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long firstIndex_, long size_); 69 | TH_API THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_); 70 | TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_); 71 | 72 | TH_API void THTensor_(resize)(THTensor
*tensor, THLongStorage *size, THLongStorage *stride); 73 | TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src); 74 | TH_API void THTensor_(resize1d)(THTensor *tensor, long size0_); 75 | TH_API void THTensor_(resize2d)(THTensor *tensor, long size0_, long size1_); 76 | TH_API void THTensor_(resize3d)(THTensor *tensor, long size0_, long size1_, long size2_); 77 | TH_API void THTensor_(resize4d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_); 78 | TH_API void THTensor_(resize5d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_, long size4_); 79 | 80 | TH_API void THTensor_(set)(THTensor *self, THTensor *src); 81 | TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_); 82 | TH_API void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_, 83 | long size0_, long stride0_); 84 | TH_API void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_, 85 | long size0_, long stride0_, 86 | long size1_, long stride1_); 87 | TH_API void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_, 88 | long size0_, long stride0_, 89 | long size1_, long stride1_, 90 | long size2_, long stride2_); 91 | TH_API void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_, 92 | long size0_, long stride0_, 93 | long size1_, long stride1_, 94 | long size2_, long stride2_, 95 | long size3_, long stride3_); 96 | 97 | TH_API void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension_, long firstIndex_, long size_); 98 | TH_API void THTensor_(select)(THTensor *self, THTensor *src, int dimension_, long sliceIndex_); 99 | TH_API void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1_, int dimension2_); 100 | TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, long size_, long step_); 101 | 102 | TH_API void
THTensor_(squeeze)(THTensor *self, THTensor *src); 103 | TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_); 104 | 105 | TH_API int THTensor_(isContiguous)(const THTensor *self); 106 | TH_API long THTensor_(nElement)(const THTensor *self); 107 | 108 | TH_API void THTensor_(retain)(THTensor *self); 109 | TH_API void THTensor_(free)(THTensor *self); 110 | TH_API void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst); 111 | 112 | /* Slow access methods [check everything] */ 113 | TH_API void THTensor_(set1d)(THTensor *tensor, long x0, real value); 114 | TH_API void THTensor_(set2d)(THTensor *tensor, long x0, long x1, real value); 115 | TH_API void THTensor_(set3d)(THTensor *tensor, long x0, long x1, long x2, real value); 116 | TH_API void THTensor_(set4d)(THTensor *tensor, long x0, long x1, long x2, long x3, real value); 117 | 118 | TH_API real THTensor_(get1d)(const THTensor *tensor, long x0); 119 | TH_API real THTensor_(get2d)(const THTensor *tensor, long x0, long x1); 120 | TH_API real THTensor_(get3d)(const THTensor *tensor, long x0, long x1, long x2); 121 | TH_API real THTensor_(get4d)(const THTensor *tensor, long x0, long x1, long x2, long x3); 122 | 123 | #endif 124 | -------------------------------------------------------------------------------- /TH/generic/THLapack.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THLapack.c" 3 | #else 4 | 5 | 6 | TH_EXTERNC void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info); 7 | TH_EXTERNC void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info); 8 | TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); 9 | TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int
*lwork, int *info); 10 | TH_EXTERNC void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info); 11 | TH_EXTERNC void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info); 12 | TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); 13 | TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); 14 | TH_EXTERNC void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info); 15 | TH_EXTERNC void sgesvd_(char *jobu, char *jobvt, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *info); 16 | TH_EXTERNC void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info); 17 | TH_EXTERNC void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info); 18 | TH_EXTERNC void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info); 19 | TH_EXTERNC void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info); 20 | TH_EXTERNC void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info); 21 | TH_EXTERNC void spotrf_(char *uplo, int *n, float *a, int *lda, int *info); 22 | TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info); 23 | TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info); 24 | TH_EXTERNC void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info); 25 | TH_EXTERNC void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info); 26 | 27 | 28 | void THLapack_(gesv)(int n,
int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info) 29 | { 30 | #ifdef USE_LAPACK 31 | #if defined(TH_REAL_IS_DOUBLE) 32 | dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info); 33 | #else 34 | sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info); 35 | #endif 36 | #else 37 | THError("gesv : Lapack library not found in compile time\n"); 38 | #endif 39 | return; 40 | } 41 | 42 | void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info) 43 | { 44 | #ifdef USE_LAPACK 45 | #if defined(TH_REAL_IS_DOUBLE) 46 | dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); 47 | #else 48 | sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); 49 | #endif 50 | #else 51 | THError("gels : Lapack library not found in compile time\n"); 52 | #endif 53 | } 54 | 55 | void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info) 56 | { 57 | #ifdef USE_LAPACK 58 | #if defined(TH_REAL_IS_DOUBLE) 59 | dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); 60 | #else 61 | ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); 62 | #endif 63 | #else 64 | THError("syev : Lapack library not found in compile time\n"); 65 | #endif 66 | } 67 | 68 | void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr, real *wi, real* vl, int ldvl, real *vr, int ldvr, real *work, int lwork, int *info) 69 | { 70 | #ifdef USE_LAPACK 71 | #if defined(TH_REAL_IS_DOUBLE) 72 | dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); 73 | #else 74 | sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); 75 | #endif 76 | #else 77 | THError("geev : Lapack library not found in compile time\n"); 78 | #endif 79 | } 80 | 81 | void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info) 82 | {
83 | #ifdef USE_LAPACK 84 | #if defined(TH_REAL_IS_DOUBLE) 85 | dgesvd_( &jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info); 86 | #else 87 | sgesvd_( &jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info); 88 | #endif 89 | #else 90 | THError("gesvd : Lapack library not found in compile time\n"); 91 | #endif 92 | } 93 | 94 | /* LU decomposition */ 95 | void THLapack_(getrf)(int m, int n, real *a, int lda, int *ipiv, int *info) 96 | { 97 | #ifdef USE_LAPACK 98 | #if defined(TH_REAL_IS_DOUBLE) 99 | dgetrf_(&m, &n, a, &lda, ipiv, info); 100 | #else 101 | sgetrf_(&m, &n, a, &lda, ipiv, info); 102 | #endif 103 | #else 104 | THError("getrf : Lapack library not found in compile time\n"); 105 | #endif 106 | } 107 | /* Matrix Inverse */ 108 | void THLapack_(getri)(int n, real *a, int lda, int *ipiv, real *work, int lwork, int* info) 109 | { 110 | #ifdef USE_LAPACK 111 | #if defined(TH_REAL_IS_DOUBLE) 112 | dgetri_(&n, a, &lda, ipiv, work, &lwork, info); 113 | #else 114 | sgetri_(&n, a, &lda, ipiv, work, &lwork, info); 115 | #endif 116 | #else 117 | THError("getri : Lapack library not found in compile time\n"); 118 | #endif 119 | } 120 | 121 | /* Cholesky factorization */ 122 | void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info) 123 | { 124 | #ifdef USE_LAPACK 125 | #if defined(TH_REAL_IS_DOUBLE) 126 | dpotrf_(&uplo, &n, a, &lda, info); 127 | #else 128 | spotrf_(&uplo, &n, a, &lda, info); 129 | #endif 130 | #else 131 | THError("potrf : Lapack library not found in compile time\n"); 132 | #endif 133 | } 134 | 135 | /* Cholesky factorization based Matrix Inverse */ 136 | void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info) 137 | { 138 | #ifdef USE_LAPACK 139 | #if defined(TH_REAL_IS_DOUBLE) 140 | dpotri_(&uplo, &n, a, &lda, info); 141 | #else 142 | spotri_(&uplo, &n, a, &lda, info); 143 | #endif 144 | #else 145 | THError("potri: Lapack library not found in compile time\n"); 146 | #endif 147 | } 148 |
149 | /* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */
150 | void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info)
151 | {
152 | #ifdef USE_LAPACK
153 | #if defined(TH_REAL_IS_DOUBLE)
154 |   dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
155 | #else
156 |   spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
157 | #endif
158 | #else
159 |   THError("potrs: Lapack library not found in compile time\n");
160 | #endif
161 | }
162 |
163 | #endif
164 |
--------------------------------------------------------------------------------
/display.lua:
--------------------------------------------------------------------------------
1 | local torch = require 'torch.env'
2 | local class = require 'class'
3 | -- math.log10 and math.pow are absent in Lua 5.3+; there math.log(x, 10) is computed with C log10, so results match.
4 | local display = {}
5 | local log10 = math.log10 or function(x) return math.log(x, 10) end
6 | local function storageformat(self) -- returns format, scale (nil when 1) and field width sz covering every value of self
7 |   local intMode = true
8 |   local expMin = math.huge
9 |   local expMax = -math.huge
10 |   -- intMode: every value is integral; expMin/expMax: smallest/largest magnitude, turned into digit counts below
11 |   for i=1,self:size() do
12 |     local z = tonumber(self[i])
13 |     if z ~= math.ceil(z) then
14 |       intMode = false
15 |     end
16 |     expMin = math.min(expMin, math.abs(z))
17 |     expMax = math.max(expMax, math.abs(z))
18 |   end
19 |   if expMin ~= 0 then
20 |     expMin = math.floor(log10(expMin)) + 1
21 |   end
22 |   if expMax ~= 0 then
23 |     expMax = math.floor(log10(expMax)) + 1
24 |   end
25 |
26 |   local format
27 |   local scale
28 |   local sz
29 |   if intMode then
30 |     if expMax > 9 then
31 |       format = "%11.4e"
32 |       sz = 11
33 |     else
34 |       format = "%SZd"
35 |       sz = expMax + 1
36 |     end
37 |   else
38 |     if expMax-expMin > 4 then
39 |       format = "%SZ.4e"
40 |       sz = 11
41 |       if math.abs(expMax) > 99 or math.abs(expMin) > 99 then
42 |         sz = sz + 1
43 |       end
44 |     else
45 |       if expMax > 5 or expMax < 0 then
46 |         format = "%SZ.4f"
47 |         sz = 7
48 |         scale = 10^(expMax-1)
49 |       else
50 |         format = "%SZ.4f"
51 |         if expMax == 0 then
52 |           sz = 7
53 |         else
54 |           sz = expMax+6
55 |         end
56 |       end
57 |     end
58 |   end
59 |   format = string.gsub(format, 'SZ', sz)
60 |   if scale == 1 then
61 |     scale = nil
62 |   end
63 |   return format, scale, sz
64 | end
65 | -- An empty storage leaves sz = NaN, so the format can contain 'nan' or '-nan' (spelling is platform dependent); callers then fall back to '%f'.
66 | function display.storage(self)
67 |   local strt = {'\n'}
68 |   local format, scale = storageformat(self)
69 |   if format:find('nan', 1, true) then format = '%f' end
70 |   format = format .. '\n'
71 |   if scale then
72 |     table.insert(strt, string.format('%g *\n', scale))
73 |     for i = 1,self:size() do
74 |       table.insert(strt, string.format(format, self[i]/scale))
75 |     end
76 |   else
77 |     for i = 1,self:size() do
78 |       table.insert(strt, string.format(format, self[i]))
79 |     end
80 |   end
81 |   table.insert(strt, string.format('[%s of size %d]\n', class.type(self), self:size()))
82 |   local str = table.concat(strt)
83 |   return str
84 | end
85 | -- Render a 2D tensor row by row, splitting the columns into chunks that fit within 80 characters.
86 | local function displaymatrix(self, indent)
87 |   local format, scale, sz = storageformat(self:storage())
88 |   if format:find('nan', 1, true) then format = '%f' end
89 |   scale = scale or 1
90 |   indent = indent or ''
91 |   local strt = {indent}
92 |   local nColumnPerLine = math.floor((80-#indent)/(sz+1))
93 |   local firstColumn = 1
94 |   local lastColumn = -1
95 |   while firstColumn <= self:size(2) do
96 |     if firstColumn + nColumnPerLine - 1 <= self:size(2) then
97 |       lastColumn = firstColumn + nColumnPerLine - 1
98 |     else
99 |       lastColumn = self:size(2)
100 |     end
101 |     if nColumnPerLine < self:size(2) then
102 |       if firstColumn ~= 1 then
103 |         table.insert(strt, '\n')
104 |       end
105 |       table.insert(strt, string.format('Columns %d to %d\n%s', firstColumn, lastColumn, indent))
106 |     end
107 |     if scale ~= 1 then
108 |       table.insert(strt, string.format('%g *\n %s', scale, indent))
109 |     end
110 |     for l=1,self:size(1) do
111 |       local row = self:select(1, l)
112 |       for c=firstColumn,lastColumn do
113 |         table.insert(strt, string.format(format, row[c]/scale))
114 |         if c == lastColumn then
115 |           table.insert(strt, '\n')
116 |           if l~=self:size(1) then
117 |             if scale ~= 1 then
118 |               table.insert(strt, indent .. ' ')
119 |             else
120 |               table.insert(strt, indent)
121 |             end
122 |           end
123 |         else
124 |           table.insert(strt, ' ')
125 |         end
126 |       end
127 |     end
128 |     firstColumn = lastColumn + 1
129 |   end
130 |   local str = table.concat(strt)
131 |   return str
132 | end
133 | -- Render an N-D tensor (N > 2) as a sequence of 2D slices, each headed "(i,j,...,.,.) = ".
134 | local function displaytensor(self)
135 |   local counter = torch.LongStorage(self:nDimension()-2)
136 |   local strt = {''}
137 |   local finished
138 |   counter:fill(1)
139 |   counter[1] = 0
140 |   while true do
141 |     for i=1,self:nDimension()-2 do
142 |       counter[i] = counter[i] + 1
143 |       if counter[i] > self:size(i) then
144 |         if i == self:nDimension()-2 then
145 |           finished = true
146 |           break
147 |         end
148 |         counter[i] = 1
149 |       else
150 |         break
151 |       end
152 |     end
153 |     if finished then
154 |       break
155 |     end
156 |     if #strt > 1 then
157 |       table.insert(strt, '\n')
158 |     end
159 |     table.insert(strt, '(')
160 |     local tensor = self
161 |     for i=1,self:nDimension()-2 do
162 |       tensor = tensor:select(1, counter[i])
163 |       table.insert(strt, counter[i] .. ',')
164 |     end
165 |     table.insert(strt, '.,.) = \n')
166 |     table.insert(strt, displaymatrix(tensor, ' '))
167 |   end
168 |   local str = table.concat(strt)
169 |   return str
170 | end
171 | -- Pretty-print a tensor of any dimension, followed by a "[<type> of dimension ...]" footer.
172 | function display.tensor(self)
173 |   -- output pieces are collected in strt and joined once at the end
174 |   local strt = {''}
175 |   if self:nDimension() == 0 then
176 |     table.insert(strt, string.format('[%s with no dimension]\n', class.type(self)))
177 |   else
178 |     if self:nDimension() == 1 then
179 |       local format,scale,sz = storageformat(self:storage())
180 |       if format:find('nan', 1, true) then format = '%f' end
181 |       format = format .. '\n'
182 |       if scale then
183 |         table.insert(strt, string.format('%g *\n', scale))
184 |         for i = 1,self:size(1) do
185 |           table.insert(strt, string.format(format, self[i]/scale))
186 |         end
187 |       else
188 |         for i = 1,self:size(1) do
189 |           table.insert(strt, string.format(format, self[i]))
190 |         end
191 |       end
192 |       table.insert(strt, string.format('[%s of dimension %d]\n', class.type(self), self:size(1)))
193 |     elseif self:nDimension() == 2 then
194 |       table.insert(strt, displaymatrix(self))
195 |       table.insert(strt, string.format('[%s of dimension %dx%d]\n', class.type(self), self:size(1), self:size(2)))
196 |     else
197 |       table.insert(strt, displaytensor(self))
198 |       table.insert(strt, string.format('[%s of dimension ', class.type(self)))
199 |       for i=1,self:nDimension() do
200 |         table.insert(strt, self:size(i))
201 |         if i ~= self:nDimension() then
202 |           table.insert(strt, 'x')
203 |         end
204 |       end
205 |       table.insert(strt, ']\n')
206 |     end
207 |   end
208 |   local str = table.concat(strt)
209 |   return str
210 | end
211 |
212 | return display
213 |
--------------------------------------------------------------------------------
/cmake/FindLAPACK.cmake:
--------------------------------------------------------------------------------
1 | # - Find LAPACK library
2 | # This module finds an installed fortran library that implements the LAPACK
3 | # linear-algebra interface (see http://www.netlib.org/lapack/).
4 | #
5 | # The approach follows that taken for the autoconf macro file, acx_lapack.m4
6 | # (distributed at http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html).
7 | #
8 | # This module sets the following variables:
9 | #  LAPACK_FOUND - set to true if a library implementing the LAPACK interface is found
10 | #  LAPACK_LIBRARIES - list of libraries (using full path name) for LAPACK
11 | #  LAPACK_INFO - implementation found (mkl, open, goto, acml, accelerate, veclib or generic); LAPACK_INCLUDE_DIR is set for mkl only
12 | # Note: I do not think it is a good idea to mixup different BLAS/LAPACK versions
13 | # Hence, this script wants to find a Lapack library matching your Blas library
14 |
15 | # Do nothing if LAPACK was found before
16 | IF(NOT LAPACK_FOUND)
17 |
18 | SET(LAPACK_LIBRARIES)
19 | SET(LAPACK_INFO)
20 |
21 | IF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
22 |   FIND_PACKAGE(BLAS)
23 | ELSE(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
24 |   FIND_PACKAGE(BLAS REQUIRED)
25 | ENDIF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
26 |
27 | # Old search lapack script
28 | include(CheckFortranFunctionExists)
29 | include(CheckFunctionExists) # check_function_exists() is used below; do not rely on FindBLAS having loaded it
30 | macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas)
31 |   # This macro checks for the existence of the combination of fortran libraries
32 |   # given by _list. If the combination is found, this macro checks (using the
33 |   # Check_Fortran_Function_Exists macro) whether can link against that library
34 |   # combination using the name of a routine given by _name using the linker
35 |   # flags given by _flags. If the combination of libraries is found and passes
36 |   # the link test, LIBRARIES is set to the list of complete library paths that
37 |   # have been found. Otherwise, LIBRARIES is set to FALSE.
38 |   # N.B. _prefix is the prefix applied to the names of all cached variables that
39 |   # are generated internally and marked advanced by this macro.
40 |   set(_libraries_work TRUE)
41 |   set(${LIBRARIES})
42 |   set(_combined_name)
43 |   foreach(_library ${_list})
44 |     set(_combined_name ${_combined_name}_${_library})
45 |     if(_libraries_work)
46 |       if (WIN32)
47 |         find_library(${_prefix}_${_library}_LIBRARY
48 |           NAMES ${_library} PATHS ENV LIB PATHS ENV PATH)
49 |       else (WIN32)
50 |         if(APPLE)
51 |           find_library(${_prefix}_${_library}_LIBRARY
52 |             NAMES ${_library}
53 |             PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
54 |             ENV DYLD_LIBRARY_PATH)
55 |         else(APPLE)
56 |           find_library(${_prefix}_${_library}_LIBRARY
57 |             NAMES ${_library}
58 |             PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
59 |             ENV LD_LIBRARY_PATH)
60 |         endif(APPLE)
61 |       endif(WIN32)
62 |       mark_as_advanced(${_prefix}_${_library}_LIBRARY)
63 |       set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
64 |       set(_libraries_work ${${_prefix}_${_library}_LIBRARY})
65 |     endif(_libraries_work)
66 |   endforeach(_library ${_list})
67 |   if(_libraries_work)
68 |     # Test this combination of libraries.
69 |     set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas})
70 |     if (CMAKE_Fortran_COMPILER_WORKS)
71 |       check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS)
72 |     else (CMAKE_Fortran_COMPILER_WORKS)
73 |       check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS)
74 |     endif (CMAKE_Fortran_COMPILER_WORKS)
75 |     set(CMAKE_REQUIRED_LIBRARIES)
76 |     mark_as_advanced(${_prefix}${_combined_name}_WORKS)
77 |     set(_libraries_work ${${_prefix}${_combined_name}_WORKS})
78 |   endif(_libraries_work)
79 |   if(NOT _libraries_work)
80 |     set(${LIBRARIES} FALSE)
81 |   endif(NOT _libraries_work)
82 | endmacro(Check_Lapack_Libraries)
83 |
84 | # The vendor probes below overwrite CMAKE_REQUIRED_LIBRARIES; keep the caller's value so it can be restored.
85 | if(BLAS_FOUND)
86 |   set(_lapack_saved_required_libraries "${CMAKE_REQUIRED_LIBRARIES}")
87 |   # Intel MKL
88 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "mkl"))
89 |     IF(MKL_LAPACK_LIBRARIES)
90 |       SET(LAPACK_LIBRARIES ${MKL_LAPACK_LIBRARIES} ${MKL_LIBRARIES})
91 |     ELSE(MKL_LAPACK_LIBRARIES)
92 |       SET(LAPACK_LIBRARIES ${MKL_LIBRARIES})
93 |     ENDIF(MKL_LAPACK_LIBRARIES)
94 |     SET(LAPACK_INCLUDE_DIR ${MKL_INCLUDE_DIR})
95 |     SET(LAPACK_INFO "mkl")
96 |   ENDIF()
97 |
98 |   # OpenBlas
99 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "open"))
100 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
101 |     check_function_exists("cheev_" OPEN_LAPACK_WORKS)
102 |     if(OPEN_LAPACK_WORKS)
103 |       SET(LAPACK_INFO "open")
104 |     else()
105 |       message(STATUS "It seems OpenBlas has not been compiled with Lapack support")
106 |     endif()
107 |   endif()
108 |
109 |   # GotoBlas
110 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "goto"))
111 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
112 |     check_function_exists("cheev_" GOTO_LAPACK_WORKS)
113 |     if(GOTO_LAPACK_WORKS)
114 |       SET(LAPACK_INFO "goto")
115 |     else()
116 |       message(STATUS "It seems GotoBlas has not been compiled with Lapack support")
117 |     endif()
118 |   endif()
119 |
120 |   # ACML
121 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "acml"))
122 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
123 |     check_function_exists("cheev_" ACML_LAPACK_WORKS)
124 |     if(ACML_LAPACK_WORKS)
125 |       SET(LAPACK_INFO "acml")
126 |     else()
127 |       message(STATUS "Strangely, this ACML library does not support Lapack?!")
128 |     endif()
129 |   endif()
130 |
131 |   # Accelerate
132 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "accelerate"))
133 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
134 |     check_function_exists("cheev_" ACCELERATE_LAPACK_WORKS)
135 |     if(ACCELERATE_LAPACK_WORKS)
136 |       SET(LAPACK_INFO "accelerate")
137 |     else()
138 |       message(STATUS "Strangely, this Accelerate library does not support Lapack?!")
139 |     endif()
140 |   endif()
141 |
142 |   # vecLib
143 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "veclib"))
144 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
145 |     check_function_exists("cheev_" VECLIB_LAPACK_WORKS)
146 |     if(VECLIB_LAPACK_WORKS)
147 |       SET(LAPACK_INFO "veclib")
148 |     else()
149 |       message(STATUS "Strangely, this vecLib library does not support Lapack?!")
150 |     endif()
151 |   endif()
152 |
153 |   # Generic LAPACK library?
154 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "generic"))
155 |     check_lapack_libraries(
156 |       LAPACK_LIBRARIES
157 |       LAPACK
158 |       cheev
159 |       ""
160 |       "lapack"
161 |       "${BLAS_LIBRARIES}"
162 |       )
163 |     if(LAPACK_LIBRARIES)
164 |       SET(LAPACK_INFO "generic")
165 |     endif(LAPACK_LIBRARIES)
166 |   endif()
167 |   set(CMAKE_REQUIRED_LIBRARIES "${_lapack_saved_required_libraries}") # undo the probes' changes for the caller
168 | else(BLAS_FOUND)
169 |   message(STATUS "LAPACK requires BLAS")
170 | endif(BLAS_FOUND)
171 |
172 | if(LAPACK_INFO)
173 |   set(LAPACK_FOUND TRUE)
174 | else(LAPACK_INFO)
175 |   set(LAPACK_FOUND FALSE)
176 | endif(LAPACK_INFO)
177 | unset(_lapack_saved_required_libraries)
178 | IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
179 |   message(FATAL_ERROR "Cannot find a library with LAPACK API. Please specify library location.")
180 | ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
181 | IF(NOT LAPACK_FIND_QUIETLY)
182 |   IF(LAPACK_FOUND)
183 |     MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})")
184 |   ELSE(LAPACK_FOUND)
185 |     MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.")
186 |   ENDIF(LAPACK_FOUND)
187 | ENDIF(NOT LAPACK_FIND_QUIETLY)
188 |
189 | # Do nothing if LAPACK was found before
190 | ENDIF(NOT LAPACK_FOUND)
191 |
--------------------------------------------------------------------------------
/TH/cmake/FindLAPACK.cmake:
--------------------------------------------------------------------------------
1 | # - Find LAPACK library
2 | # This module finds an installed fortran library that implements the LAPACK
3 | # linear-algebra interface (see http://www.netlib.org/lapack/).
4 | #
5 | # The approach follows that taken for the autoconf macro file, acx_lapack.m4
6 | # (distributed at http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html).
7 | #
8 | # This module sets the following variables:
9 | #  LAPACK_FOUND - set to true if a library implementing the LAPACK interface is found
10 | #  LAPACK_LIBRARIES - list of libraries (using full path name) for LAPACK
11 | #  LAPACK_INFO - implementation found (mkl, open, goto, acml, accelerate, veclib or generic); LAPACK_INCLUDE_DIR is set for mkl only
12 | # Note: I do not think it is a good idea to mixup different BLAS/LAPACK versions
13 | # Hence, this script wants to find a Lapack library matching your Blas library
14 |
15 | # Do nothing if LAPACK was found before
16 | IF(NOT LAPACK_FOUND)
17 |
18 | SET(LAPACK_LIBRARIES)
19 | SET(LAPACK_INFO)
20 |
21 | IF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
22 |   FIND_PACKAGE(BLAS)
23 | ELSE(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
24 |   FIND_PACKAGE(BLAS REQUIRED)
25 | ENDIF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
26 |
27 | # Old search lapack script
28 | include(CheckFortranFunctionExists)
29 | include(CheckFunctionExists) # check_function_exists() is used below; do not rely on FindBLAS having loaded it
30 | macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas)
31 |   # This macro checks for the existence of the combination of fortran libraries
32 |   # given by _list. If the combination is found, this macro checks (using the
33 |   # Check_Fortran_Function_Exists macro) whether can link against that library
34 |   # combination using the name of a routine given by _name using the linker
35 |   # flags given by _flags. If the combination of libraries is found and passes
36 |   # the link test, LIBRARIES is set to the list of complete library paths that
37 |   # have been found. Otherwise, LIBRARIES is set to FALSE.
38 |   # N.B. _prefix is the prefix applied to the names of all cached variables that
39 |   # are generated internally and marked advanced by this macro.
40 |   set(_libraries_work TRUE)
41 |   set(${LIBRARIES})
42 |   set(_combined_name)
43 |   foreach(_library ${_list})
44 |     set(_combined_name ${_combined_name}_${_library})
45 |     if(_libraries_work)
46 |       if (WIN32)
47 |         find_library(${_prefix}_${_library}_LIBRARY
48 |           NAMES ${_library} PATHS ENV LIB PATHS ENV PATH)
49 |       else (WIN32)
50 |         if(APPLE)
51 |           find_library(${_prefix}_${_library}_LIBRARY
52 |             NAMES ${_library}
53 |             PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
54 |             ENV DYLD_LIBRARY_PATH)
55 |         else(APPLE)
56 |           find_library(${_prefix}_${_library}_LIBRARY
57 |             NAMES ${_library}
58 |             PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
59 |             ENV LD_LIBRARY_PATH)
60 |         endif(APPLE)
61 |       endif(WIN32)
62 |       mark_as_advanced(${_prefix}_${_library}_LIBRARY)
63 |       set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
64 |       set(_libraries_work ${${_prefix}_${_library}_LIBRARY})
65 |     endif(_libraries_work)
66 |   endforeach(_library ${_list})
67 |   if(_libraries_work)
68 |     # Test this combination of libraries.
69 |     set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas})
70 |     if (CMAKE_Fortran_COMPILER_WORKS)
71 |       check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS)
72 |     else (CMAKE_Fortran_COMPILER_WORKS)
73 |       check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS)
74 |     endif (CMAKE_Fortran_COMPILER_WORKS)
75 |     set(CMAKE_REQUIRED_LIBRARIES)
76 |     mark_as_advanced(${_prefix}${_combined_name}_WORKS)
77 |     set(_libraries_work ${${_prefix}${_combined_name}_WORKS})
78 |   endif(_libraries_work)
79 |   if(NOT _libraries_work)
80 |     set(${LIBRARIES} FALSE)
81 |   endif(NOT _libraries_work)
82 | endmacro(Check_Lapack_Libraries)
83 |
84 | # The vendor probes below overwrite CMAKE_REQUIRED_LIBRARIES; keep the caller's value so it can be restored.
85 | if(BLAS_FOUND)
86 |   set(_lapack_saved_required_libraries "${CMAKE_REQUIRED_LIBRARIES}")
87 |   # Intel MKL
88 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "mkl"))
89 |     IF(MKL_LAPACK_LIBRARIES)
90 |       SET(LAPACK_LIBRARIES ${MKL_LAPACK_LIBRARIES} ${MKL_LIBRARIES})
91 |     ELSE(MKL_LAPACK_LIBRARIES)
92 |       SET(LAPACK_LIBRARIES ${MKL_LIBRARIES})
93 |     ENDIF(MKL_LAPACK_LIBRARIES)
94 |     SET(LAPACK_INCLUDE_DIR ${MKL_INCLUDE_DIR})
95 |     SET(LAPACK_INFO "mkl")
96 |   ENDIF()
97 |
98 |   # OpenBlas
99 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "open"))
100 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
101 |     check_function_exists("cheev_" OPEN_LAPACK_WORKS)
102 |     if(OPEN_LAPACK_WORKS)
103 |       SET(LAPACK_INFO "open")
104 |     else()
105 |       message(STATUS "It seems OpenBlas has not been compiled with Lapack support")
106 |     endif()
107 |   endif()
108 |
109 |   # GotoBlas
110 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "goto"))
111 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
112 |     check_function_exists("cheev_" GOTO_LAPACK_WORKS)
113 |     if(GOTO_LAPACK_WORKS)
114 |       SET(LAPACK_INFO "goto")
115 |     else()
116 |       message(STATUS "It seems GotoBlas has not been compiled with Lapack support")
117 |     endif()
118 |   endif()
119 |
120 |   # ACML
121 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "acml"))
122 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
123 |     check_function_exists("cheev_" ACML_LAPACK_WORKS)
124 |     if(ACML_LAPACK_WORKS)
125 |       SET(LAPACK_INFO "acml")
126 |     else()
127 |       message(STATUS "Strangely, this ACML library does not support Lapack?!")
128 |     endif()
129 |   endif()
130 |
131 |   # Accelerate
132 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "accelerate"))
133 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
134 |     check_function_exists("cheev_" ACCELERATE_LAPACK_WORKS)
135 |     if(ACCELERATE_LAPACK_WORKS)
136 |       SET(LAPACK_INFO "accelerate")
137 |     else()
138 |       message(STATUS "Strangely, this Accelerate library does not support Lapack?!")
139 |     endif()
140 |   endif()
141 |
142 |   # vecLib
143 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "veclib"))
144 |     SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
145 |     check_function_exists("cheev_" VECLIB_LAPACK_WORKS)
146 |     if(VECLIB_LAPACK_WORKS)
147 |       SET(LAPACK_INFO "veclib")
148 |     else()
149 |       message(STATUS "Strangely, this vecLib library does not support Lapack?!")
150 |     endif()
151 |   endif()
152 |
153 |   # Generic LAPACK library?
154 |   IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "generic"))
155 |     check_lapack_libraries(
156 |       LAPACK_LIBRARIES
157 |       LAPACK
158 |       cheev
159 |       ""
160 |       "lapack"
161 |       "${BLAS_LIBRARIES}"
162 |       )
163 |     if(LAPACK_LIBRARIES)
164 |       SET(LAPACK_INFO "generic")
165 |     endif(LAPACK_LIBRARIES)
166 |   endif()
167 |   set(CMAKE_REQUIRED_LIBRARIES "${_lapack_saved_required_libraries}") # undo the probes' changes for the caller
168 | else(BLAS_FOUND)
169 |   message(STATUS "LAPACK requires BLAS")
170 | endif(BLAS_FOUND)
171 |
172 | if(LAPACK_INFO)
173 |   set(LAPACK_FOUND TRUE)
174 | else(LAPACK_INFO)
175 |   set(LAPACK_FOUND FALSE)
176 | endif(LAPACK_INFO)
177 | unset(_lapack_saved_required_libraries)
178 | IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
179 |   message(FATAL_ERROR "Cannot find a library with LAPACK API. Please specify library location.")
180 | ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
181 | IF(NOT LAPACK_FIND_QUIETLY)
182 |   IF(LAPACK_FOUND)
183 |     MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})")
184 |   ELSE(LAPACK_FOUND)
185 |     MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.")
186 |   ENDIF(LAPACK_FOUND)
187 | ENDIF(NOT LAPACK_FIND_QUIETLY)
188 |
189 | # Do nothing if LAPACK was found before
190 | ENDIF(NOT LAPACK_FOUND)
191 |
--------------------------------------------------------------------------------
/TH/generic/THTensorMath.h:
--------------------------------------------------------------------------------
1 | #ifndef TH_GENERIC_FILE
2 | #define TH_GENERIC_FILE "generic/THTensorMath.h"
3 | #else
4 |
5 | TH_API void THTensor_(fill)(THTensor *r_, real value);
6 | TH_API void THTensor_(zero)(THTensor *r_);
7 |
8 | TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value);
9 | TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
10 | TH_API void THTensor_(maskedSelect)(THTensor *tensor, THTensor* src, THByteTensor *mask);
11 |
12 | TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
13 | TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
14 | TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, real val);
15 |
16 | TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
17 |
18 | TH_API real THTensor_(minall)(THTensor *t);
19 | TH_API real THTensor_(maxall)(THTensor *t);
20 | TH_API accreal THTensor_(sumall)(THTensor *t);
21 |
22 | TH_API void THTensor_(add)(THTensor *r_, THTensor *t, real value);
23 | TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, real value);
24 | TH_API void THTensor_(div)(THTensor *r_, THTensor *t, real value);
25 |
26 | TH_API void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src);
27 | TH_API void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src);
28 | TH_API void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src);
29 |
30 | TH_API void THTensor_(addcmul)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2);
31 | TH_API void THTensor_(addcdiv)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2);
32 |
33 | TH_API void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat, THTensor *vec);
34 | TH_API void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat1, THTensor *mat2);
35 | TH_API void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2);
36 |
37 | TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain);
38 |
39 | TH_API long THTensor_(numel)(THTensor *t);
40 | TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
41 | TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
42 | TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension);
43 | TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension);
44 | TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
45 | TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);
46 | TH_API void THTensor_(sign)(THTensor *r_, THTensor *t);
47 | TH_API accreal THTensor_(trace)(THTensor *t);
48 | TH_API void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension);
49 |
50 | TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size);
51 | TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size);
52 | TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k);
53 | TH_API void THTensor_(eye)(THTensor *r_, long n, long m);
54 | TH_API void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step);
55 | TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, long n);
56 |
57 | TH_API void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size);
58 | TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder);
59 | TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, long k);
60 | TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, long k);
61 | TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);
62 | /* Element-wise comparisons of t against a scalar; the result is written to a THByteTensor. */
63 | TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, real value);
64 | TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, real value);
65 | TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, real value);
66 | TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, real value);
67 | TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, real value);
68 | TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, real value);
69 | /* Same scalar comparisons; the ...T variants write the result into a THTensor of the input's own type. */
70 | TH_API void THTensor_(ltValueT)(THTensor *r_, THTensor* t, real value);
71 | TH_API void THTensor_(leValueT)(THTensor *r_, THTensor* t, real value);
72 | TH_API void THTensor_(gtValueT)(THTensor *r_, THTensor* t, real value);
73 | TH_API void THTensor_(geValueT)(THTensor *r_, THTensor* t, real value);
74 | TH_API void THTensor_(neValueT)(THTensor *r_, THTensor* t, real value);
75 | TH_API void THTensor_(eqValueT)(THTensor *r_, THTensor* t, real value);
76 | /* Element-wise comparisons of two tensors ta and tb; the result is written to a THByteTensor. */
77 | TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
78 | TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
79 | TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
80 | TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
81 | TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
82 | TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
83 | /* Tensor-vs-tensor comparisons whose result is a THTensor of the input's own type. */
84 | TH_API void THTensor_(ltTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
85 | TH_API void THTensor_(leTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
86 | TH_API void THTensor_(gtTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
87 | TH_API void THTensor_(geTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
88 | TH_API void THTensor_(neTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
89 | TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
90 | /* abs is declared here for int/long; the float/double declaration is in the floating-point section below. */
91 | #if defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
92 | TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);
93 | #endif
94 | /* Declarations from here to the matching #endif exist only when real is float or double. */
95 | #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
96 |
97 | TH_API void THTensor_(log)(THTensor *r_, THTensor *t);
98 | TH_API void THTensor_(log1p)(THTensor *r_, THTensor *t);
99 | TH_API void THTensor_(exp)(THTensor *r_, THTensor *t);
100 | TH_API void THTensor_(cos)(THTensor *r_, THTensor *t);
101 | TH_API void THTensor_(acos)(THTensor *r_, THTensor *t);
102 | TH_API void THTensor_(cosh)(THTensor *r_, THTensor *t);
103 | TH_API void THTensor_(sin)(THTensor *r_, THTensor *t);
104 | TH_API void THTensor_(asin)(THTensor *r_, THTensor *t);
105 | TH_API void THTensor_(sinh)(THTensor *r_, THTensor *t);
106 | TH_API void THTensor_(tan)(THTensor *r_, THTensor *t);
107 | TH_API void THTensor_(atan)(THTensor *r_, THTensor *t);
108 | TH_API void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty);
109 | TH_API void THTensor_(tanh)(THTensor *r_, THTensor *t);
110 | TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, real value);
111 | TH_API void THTensor_(sqrt)(THTensor *r_, THTensor *t);
112 | TH_API void THTensor_(ceil)(THTensor *r_, THTensor *t);
113 | TH_API void THTensor_(floor)(THTensor *r_, THTensor *t);
114 | TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);
115 | /* mean/std/var/norm reduce along `dimension`; dist and histc take whole tensors. NOTE(review): the meaning of `flag` in std/var is defined by the implementation (presumably biased vs. unbiased) - confirm. */
116 | TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension);
117 | TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag);
118 | TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag);
119 | TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension);
120 | TH_API accreal THTensor_(dist)(THTensor *a, THTensor *b, real value);
121 | TH_API void THTensor_(histc)(THTensor *hist, THTensor *tensor, long nbins, real minvalue, real maxvalue);
122
| 123 | TH_API accreal THTensor_(meanall)(THTensor *self); 124 | TH_API accreal THTensor_(varall)(THTensor *self); 125 | TH_API accreal THTensor_(stdall)(THTensor *self); 126 | TH_API accreal THTensor_(normall)(THTensor *t, real value); 127 | 128 | TH_API void THTensor_(linspace)(THTensor *r_, real a, real b, long n); 129 | TH_API void THTensor_(logspace)(THTensor *r_, real a, real b, long n); 130 | TH_API void THTensor_(rand)(THTensor *r_, THGenerator *_generator, THLongStorage *size); 131 | TH_API void THTensor_(randn)(THTensor *r_, THGenerator *_generator, THLongStorage *size); 132 | 133 | #endif 134 | 135 | #endif 136 | -------------------------------------------------------------------------------- /TH/THAllocator.c: -------------------------------------------------------------------------------- 1 | #include "THAllocator.h" 2 | 3 | /* stuff for mapped files */ 4 | #ifdef _WIN32 5 | #include 6 | #endif 7 | 8 | #if HAVE_MMAP 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #endif 15 | /* end of stuff for mapped files */ 16 | 17 | static void *THDefaultAllocator_alloc(void* ctx, long size) { 18 | return THAlloc(size); 19 | } 20 | 21 | static void *THDefaultAllocator_realloc(void* ctx, void* ptr, long size) { 22 | return THRealloc(ptr, size); 23 | } 24 | 25 | static void THDefaultAllocator_free(void* ctx, void* ptr) { 26 | THFree(ptr); 27 | } 28 | 29 | THAllocator THDefaultAllocator = { 30 | &THDefaultAllocator_alloc, 31 | &THDefaultAllocator_realloc, 32 | &THDefaultAllocator_free 33 | }; 34 | 35 | #if defined(_WIN32) || defined(HAVE_MMAP) 36 | 37 | struct THMapAllocatorContext_ { 38 | char *filename; /* file name */ 39 | int shared; /* is shared or not */ 40 | long size; /* mapped size */ 41 | }; 42 | 43 | THMapAllocatorContext *THMapAllocatorContext_new(const char *filename, int shared) 44 | { 45 | THMapAllocatorContext *ctx = THAlloc(sizeof(THMapAllocatorContext)); 46 | 47 | ctx->filename = THAlloc(strlen(filename)+1); 48 | 
strcpy(ctx->filename, filename); 49 | ctx->shared = shared; 50 | ctx->size = 0; 51 | 52 | return ctx; 53 | } 54 | 55 | long THMapAllocatorContext_size(THMapAllocatorContext *ctx) 56 | { 57 | return ctx->size; 58 | } 59 | 60 | void THMapAllocatorContext_free(THMapAllocatorContext *ctx) 61 | { 62 | THFree(ctx->filename); 63 | THFree(ctx); 64 | } 65 | 66 | static void *THMapAllocator_alloc(void* ctx_, long size) 67 | { 68 | THMapAllocatorContext *ctx = ctx_; 69 | void *data = NULL; 70 | 71 | #ifdef _WIN32 72 | { 73 | HANDLE hfile; 74 | HANDLE hmfile; 75 | DWORD size_hi, size_lo; 76 | size_t hfilesz; 77 | 78 | /* open file */ 79 | /* FILE_FLAG_RANDOM_ACCESS ? */ 80 | if(ctx->shared) 81 | { 82 | hfile = CreateFileA(ctx->filename, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0); 83 | if (hfile == INVALID_HANDLE_VALUE) 84 | THError("could not open file <%s> in read-write mode", ctx->filename); 85 | } 86 | else 87 | { 88 | hfile = CreateFileA(ctx->filename, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); 89 | if (hfile == INVALID_HANDLE_VALUE) 90 | THError("could not open file <%s> in read-only mode", ctx->filename); 91 | } 92 | 93 | size_lo = GetFileSize(hfile, &size_hi); 94 | if(sizeof(size_t) > 4) 95 | { 96 | hfilesz = ((size_t)size_hi) << 32; 97 | hfilesz |= size_lo; 98 | } 99 | else 100 | hfilesz = (size_t)(size_lo); 101 | 102 | if(size > 0) 103 | { 104 | if(size > hfilesz) 105 | { 106 | if(ctx->shared) 107 | { 108 | #if SIZEOF_SIZE_T > 4 109 | size_hi = (DWORD)((size) >> 32); 110 | size_lo = (DWORD)((size) & 0xFFFFFFFF); 111 | #else 112 | size_hi = 0; 113 | size_lo = (DWORD)(size); 114 | #endif 115 | if((SetFilePointer(hfile, size_lo, &size_hi, FILE_BEGIN)) == INVALID_SET_FILE_POINTER) 116 | { 117 | CloseHandle(hfile); 118 | THError("unable to stretch file <%s> to the right size", ctx->filename); 119 | } 120 | if(SetEndOfFile(hfile) == 0) 121 | { 122 | 
CloseHandle(hfile); 123 | THError("unable to write to file <%s>", ctx->filename); 124 | } 125 | } 126 | else 127 | { 128 | CloseHandle(hfile); 129 | THError("file <%s> size is smaller than the required mapping size <%ld>", ctx->filename, size); 130 | } 131 | } 132 | } 133 | else 134 | size = hfilesz; 135 | 136 | ctx->size = size; /* if we are here, it must be the right size */ 137 | 138 | #if SIZEOF_SIZE_T > 4 139 | size_hi = (DWORD)((ctx->size) >> 32); 140 | size_lo = (DWORD)((ctx->size) & 0xFFFFFFFF); 141 | #else 142 | size_hi = 0; 143 | size_lo = (DWORD)(ctx->size); 144 | #endif 145 | 146 | /* get map handle */ 147 | if(ctx->shared) 148 | { 149 | if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, size_hi, size_lo, NULL)) == NULL ) 150 | THError("could not create a map on file <%s>", ctx->filename); 151 | } 152 | else 153 | { 154 | if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, size_hi, size_lo, NULL)) == NULL ) 155 | THError("could not create a map on file <%s>", ctx->filename); 156 | } 157 | 158 | /* map the stuff */ 159 | if(ctx->shared) 160 | data = MapViewOfFile(hmfile, FILE_MAP_ALL_ACCESS, 0, 0, 0); 161 | else 162 | data = MapViewOfFile(hmfile, FILE_MAP_COPY, 0, 0, 0); 163 | 164 | CloseHandle(hfile); 165 | CloseHandle(hmfile); 166 | } 167 | #else 168 | { 169 | /* open file */ 170 | int fd; 171 | int fdsz; 172 | 173 | if(ctx->shared) 174 | { 175 | if((fd = open(ctx->filename, O_RDWR | O_CREAT, (mode_t)0600)) == -1) 176 | THError("unable to open file <%s> in read-write mode", ctx->filename); 177 | } 178 | else 179 | { 180 | if((fd = open(ctx->filename, O_RDONLY)) == -1) 181 | THError("unable to open file <%s> in read-only mode", ctx->filename); 182 | } 183 | if((fdsz = lseek(fd, 0, SEEK_END)) == -1) 184 | { 185 | close(fd); 186 | THError("unable to seek at end of file <%s>", ctx->filename); 187 | } 188 | if(size > 0) 189 | { 190 | if(size > fdsz) 191 | { 192 | if(ctx->shared) 193 | { 194 | if((fdsz = lseek(fd, size-1, SEEK_SET)) == -1) 
195 | { 196 | close(fd); 197 | THError("unable to stretch file <%s> to the right size", ctx->filename); 198 | } 199 | if((write(fd, "", 1)) != 1) /* note that the string "" contains the '\0' byte ... */ 200 | { 201 | close(fd); 202 | THError("unable to write to file <%s>", ctx->filename); 203 | } 204 | } 205 | else 206 | { 207 | close(fd); 208 | THError("file <%s> size is smaller than the required mapping size <%ld>", ctx->filename, size); 209 | } 210 | } 211 | } 212 | else 213 | size = fdsz; 214 | 215 | ctx->size = size; /* if we are here, it must be the right size */ 216 | 217 | /* map it */ 218 | if(ctx->shared) 219 | data = mmap(NULL, ctx->size, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); 220 | else 221 | data = mmap(NULL, ctx->size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); 222 | 223 | if(data == MAP_FAILED) 224 | data = NULL; /* let's be sure it is NULL */ 225 | 226 | close(fd); 227 | } 228 | #endif 229 | 230 | return data; 231 | } 232 | 233 | static void *THMapAllocator_realloc(void* ctx, void* ptr, long size) { 234 | THError("cannot realloc mapped data"); 235 | return NULL; 236 | } 237 | 238 | static void THMapAllocator_free(void* ctx_, void* data) { 239 | THMapAllocatorContext *ctx = ctx_; 240 | 241 | #ifdef _WIN32 242 | if(!UnmapViewOfFile((LPINT)data)) 243 | THError("could not unmap the shared memory file"); 244 | #else 245 | if (munmap(data, ctx->size)) 246 | THError("could not unmap the shared memory file"); 247 | #endif 248 | 249 | THMapAllocatorContext_free(ctx); 250 | } 251 | 252 | #else 253 | 254 | THMapAllocatorContext *THMapAllocatorContext_new(const char *filename, int shared) { 255 | THError("file mapping not supported on your system"); 256 | return NULL; 257 | } 258 | 259 | void THMapAllocatorContext_free(THMapAllocatorContext *ctx) { 260 | THError("file mapping not supported on your system"); 261 | } 262 | 263 | static void *THMapAllocator_alloc(void* ctx_, long size) { 264 | THError("file mapping not supported on your system"); 265 | return 
NULL; 266 | } 267 | 268 | static void *THMapAllocator_realloc(void* ctx, void* ptr, long size) { 269 | THError("file mapping not supported on your system"); 270 | return NULL; 271 | } 272 | 273 | static void THMapAllocator_free(void* ctx, void* data) { 274 | THError("file mapping not supported on your system"); 275 | } 276 | 277 | #endif 278 | 279 | THAllocator THMapAllocator = { 280 | &THMapAllocator_alloc, 281 | &THMapAllocator_realloc, 282 | &THMapAllocator_free 283 | }; 284 | -------------------------------------------------------------------------------- /serialization.lua: -------------------------------------------------------------------------------- 1 | local torch = require 'torch.env' 2 | local class = require 'class' 3 | local File = class.metatable('torch.File') 4 | 5 | function File:writeBool(value) 6 | if value then 7 | self:writeInt(1) 8 | else 9 | self:writeInt(0) 10 | end 11 | end 12 | 13 | function File:readBool() 14 | return (self:readInt() == 1) 15 | end 16 | 17 | local TYPE_NIL = 0 18 | local TYPE_NUMBER = 1 19 | local TYPE_STRING = 2 20 | local TYPE_TABLE = 3 21 | local TYPE_TORCH = 4 22 | local TYPE_BOOLEAN = 5 23 | local TYPE_FUNCTION = 6 24 | 25 | function File:isWritableObject(object) 26 | local typename = class.type(object) 27 | local typeidx 28 | if type(object) ~= 'boolean' and not object then 29 | typeidx = TYPE_NIL 30 | elseif torch.metatable(typename) then 31 | typeidx = TYPE_TORCH 32 | elseif typename == 'table' then 33 | typeidx = TYPE_TABLE 34 | elseif typename == 'number' then 35 | typeidx = TYPE_NUMBER 36 | elseif typename == 'string' then 37 | typeidx = TYPE_STRING 38 | elseif typename == 'boolean' then 39 | typeidx = TYPE_BOOLEAN 40 | elseif typename == 'function' and pcall(string.dump, object) then 41 | typeidx = TYPE_FUNCTION 42 | end 43 | return typeidx 44 | end 45 | 46 | function File:writeObject(object, force) 47 | -- keep a record of written objects 48 | self.__writeObjects = self.__writeObjects or {} 49 | 
self.__writeObjectsRef = self.__writeObjectsRef or {} 50 | 51 | -- if nil object, only write the type and return 52 | if type(object) ~= 'boolean' and not object then 53 | self:writeInt(TYPE_NIL) 54 | return 55 | end 56 | 57 | -- check the type we are dealing with 58 | local typeidx = self:isWritableObject(object) 59 | if not typeidx then 60 | error(string.format('unwritable object <%s>', type(object))) 61 | end 62 | self:writeInt(typeidx) 63 | 64 | if typeidx == TYPE_NUMBER then 65 | self:writeDouble(object) 66 | elseif typeidx == TYPE_BOOLEAN then 67 | self:writeBool(object) 68 | elseif typeidx == TYPE_STRING then 69 | local stringStorage = torch.CharStorage():string(object) 70 | self:writeInt(stringStorage:size()) 71 | self:writeChar(stringStorage) 72 | elseif typeidx == TYPE_FUNCTION then 73 | local upvalues, nupvalues = {}, 0 -- count explicitly: table.insert(t, nil) is a no-op, so #upvalues never advanced past a nil upvalue (infinite loop) 74 | while true do 75 | local name,value = debug.getupvalue(object, nupvalues+1) 76 | if not name then break end 77 | nupvalues = nupvalues + 1; upvalues[nupvalues] = value 78 | end 79 | local dumped = string.dump(object) 80 | local stringStorage = torch.CharStorage():string(dumped) 81 | self:writeInt(stringStorage:size()) 82 | self:writeChar(stringStorage) 83 | self:writeObject(upvalues) 84 | elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE then 85 | -- check it exists already 86 | local objects = self.__writeObjects 87 | local objectsRef = self.__writeObjectsRef 88 | local index = objects[object] 89 | 90 | if index and (not force) then 91 | -- if already exists, write only its index 92 | self:writeInt(index) 93 | else 94 | -- else write the object itself 95 | index = objects.nWriteObject or 0 96 | index = index + 1 97 | objects[object] = index 98 | objectsRef[object] = index -- we make sure the object is not going to disappear 99 | self:writeInt(index) 100 | objects.nWriteObject = index 101 | 102 | if typeidx == TYPE_TORCH then 103 | local version = 'V ' .. object.__version 104 | self:writeInt(#version) -- backward compat 105 | self:write(version ..
'\n') 106 | local className = class.type(object) 107 | self:writeInt(#className) -- backward compat 108 | self:write(className .. '\n') 109 | if object.write then 110 | object:write(self) 111 | else 112 | local var = {} 113 | for k,v in pairs(object) do 114 | if self:isWritableObject(v) then 115 | var[k] = v 116 | else 117 | print(string.format('$ Warning: cannot write object field <%s>', k)) 118 | end 119 | end 120 | self:writeObject(var) 121 | end 122 | else -- it is a table 123 | local size = 0; for k,v in pairs(object) do size = size + 1 end 124 | self:writeInt(size) 125 | for k,v in pairs(object) do 126 | self:writeObject(k) 127 | self:writeObject(v) 128 | end 129 | end 130 | end 131 | else 132 | error('unwritable object') 133 | end 134 | end 135 | 136 | function File:readObject() 137 | -- keep a record of read objects 138 | self.__readObjects = self.__readObjects or {} 139 | 140 | -- read the typeidx 141 | local typeidx = self:readInt() 142 | 143 | -- is it nil? 144 | if typeidx == TYPE_NIL then 145 | return nil 146 | end 147 | 148 | if typeidx == TYPE_NUMBER then 149 | return self:readDouble() 150 | elseif typeidx == TYPE_BOOLEAN then 151 | return self:readBool() 152 | elseif typeidx == TYPE_STRING then 153 | local size = self:readInt() 154 | return self:readChar(size):string() 155 | elseif typeidx == TYPE_FUNCTION then 156 | local size = self:readInt() 157 | local dumped = self:readChar(size):string() 158 | local func = assert(loadstring(dumped)) -- surface the load error instead of failing later on a nil func 159 | local upvalues = self:readObject() 160 | for index,upvalue in pairs(upvalues) do -- pairs, not ipairs: nil upvalues leave holes in the table 161 | debug.setupvalue(func, index, upvalue) 162 | end 163 | return func 164 | elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH then 165 | -- read the index 166 | local index = self:readInt() 167 | 168 | -- check it is loaded already 169 | local objects = self.__readObjects 170 | if objects[index] then 171 | return objects[index] 172 | end 173 | 174 | -- otherwise read it 175 | if typeidx == TYPE_TORCH then 176 | local version,
className, versionNumber 177 | self:readInt() -- backward compat 178 | version = self:read('*l') 179 | versionNumber = tonumber(string.match(version, '^V (.*)$')) 180 | if not versionNumber then 181 | className = version 182 | versionNumber = 0 -- file created before existence of versioning system 183 | else 184 | self:readInt() -- backward compat 185 | className = self:read('*l') 186 | end 187 | if not torch.metatable(className) then 188 | error(string.format('unknown Torch class <%s>', className)) 189 | end 190 | local object = torch.factory(className) 191 | objects[index] = object 192 | if object.read then 193 | object:read(self, versionNumber) 194 | else 195 | local var = self:readObject() 196 | for k,v in pairs(var) do 197 | object[k] = v 198 | end 199 | end 200 | return object 201 | else -- it is a table 202 | local size = self:readInt() 203 | local object = {} 204 | objects[index] = object 205 | for i = 1,size do 206 | local k = self:readObject() 207 | local v = self:readObject() 208 | object[k] = v 209 | end 210 | return object 211 | end 212 | else 213 | error('unknown object') 214 | end 215 | end 216 | 217 | -- simple helpers to save/load arbitrary objects/tables 218 | function torch.save(filename, object, mode) 219 | mode = mode or 'binary' 220 | local file = torch.DiskFile(filename, 'w') 221 | file[mode](file) 222 | file:writeObject(object) 223 | file:close() 224 | end 225 | 226 | function torch.load(filename, mode) 227 | mode = mode or 'binary' 228 | local file = torch.DiskFile(filename, 'r') 229 | file[mode](file) 230 | local object = file:readObject() 231 | file:close() 232 | return object 233 | end 234 | 235 | -- simple helpers to serialize/deserialize arbitrary objects/tables 236 | function torch.serialize(object) 237 | local f = torch.MemoryFile() 238 | f:writeObject(object) 239 | local s = f:storage():string() 240 | f:close() 241 | return s 242 | end 243 | 244 | function torch.deserialize(str) 245 | local x = torch.CharStorage():string(str) 246 | 
local tx = torch.CharTensor(x) 247 | local xp = torch.CharStorage(x:size(1)+1) 248 | local txp = torch.CharTensor(xp) 249 | txp:narrow(1,1,tx:size(1)):copy(tx) 250 | txp[tx:size(1)+1] = 0 251 | local f = torch.MemoryFile(xp) 252 | local object = f:readObject() 253 | f:close() 254 | return object 255 | end 256 | -------------------------------------------------------------------------------- /tensorop.lua: -------------------------------------------------------------------------------- 1 | local display = require 'torch.display' 2 | local torch = require 'torch.env' 3 | local class = require 'class' 4 | local ffi = require 'ffi' 5 | local C = require 'torch.TH' 6 | 7 | local RealTensor = class.metatable('torch.RealTensor') 8 | 9 | local function index_table(self, k, v) 10 | assert(#k <= self.__nDimension, 'invalid table size') 11 | local cdim = 0 12 | local res 13 | self = C.THRealTensor_newWithTensor(self)[0] 14 | for dim=0,self.__nDimension-1 do 15 | local z = k[dim+1] 16 | if type(z) == 'number' then 17 | z = z - 1 18 | if z < 0 then 19 | z = self.__size[cdim] + z + 1 20 | end 21 | assert(z >= 0 and z < self.__size[cdim], 'out of range') 22 | if self.__nDimension == 1 then 23 | res = self.__storage.__data+self.__storageOffset+z*self.__stride[0] 24 | else 25 | C.THRealTensor_select(self, nil, cdim, z) 26 | end 27 | elseif type(z) == 'table' then 28 | local a = 0 29 | local b = self.__size[cdim]-1 30 | 31 | local zz = z[1] 32 | if type(zz) == 'number' then 33 | a = zz-1 34 | b = a 35 | end 36 | if a < 0 then 37 | a = self.__size[cdim] + a + 1 38 | end 39 | assert(a >= 0 and a < self.__size[cdim], 'out of range') 40 | 41 | local zz = z[2] 42 | if type(zz) == 'number' then 43 | b = zz-1 44 | end 45 | if b < 0 then 46 | b = self.__size[cdim] + b + 1 47 | end 48 | assert(b >= 0 and b < self.__size[cdim], 'out of range') 49 | 50 | assert(b >= a, 'end index must be greater or equal to start index') 51 | C.THRealTensor_narrow(self, nil, cdim, a, b-a+1) 52 | cdim = cdim + 
1 53 | elseif type(z) ~= 'nil' then 54 | error('invalid table') 55 | end 56 | end 57 | if v then 58 | if res then 59 | res[0] = v 60 | C.THRealTensor_free(self) 61 | else 62 | self:copy(v) -- DEBUG: this could fail 63 | C.THRealTensor_free(self) 64 | end 65 | else 66 | if res then 67 | local value = tonumber(res[0]) -- dereference first: tonumber() of a pointer cdata returns nil 68 | C.THRealTensor_free(self); return value 69 | else 70 | ffi.gc(self, C.THRealTensor_free) 71 | return self 72 | end 73 | end 74 | end 75 | 76 | function RealTensor:__index(k) 77 | local type_k = class.type(k) 78 | if type_k == 'number' then 79 | if self.__nDimension == 1 then 80 | assert(k > 0 and k <= self.__size[0], 'out of range') 81 | return tonumber( self.__storage.__data[(k-1)*self.__stride[0]+self.__storageOffset] ) 82 | elseif self.__nDimension > 1 then 83 | assert(k > 0 and k <= self.__size[0], 'out of range') 84 | return self:select(1, k) 85 | else 86 | error('empty tensor') 87 | end 88 | elseif type_k == 'torch.LongStorage' then 89 | assert(k.__size == self.__nDimension, 'invalid storage size') 90 | local idx = self.__storageOffset 91 | for dim=0,tonumber(k.__size)-1 do 92 | local z = k.__data[dim]-1 93 | assert(z >= 0 and z < self.__size[dim], 'out of range') 94 | idx = idx + z*self.__stride[dim] 95 | end 96 | return tonumber(self.__storage.__data[idx]) 97 | elseif type_k == 'torch.ByteTensor' then 98 | local vals = torch.RealTensor() 99 | C.THRealTensor_maskedSelect(vals, self, k) 100 | return vals 101 | elseif type_k == 'table' then 102 | return index_table(self, k) 103 | else 104 | return RealTensor[k] 105 | end 106 | end 107 | 108 | function RealTensor:__newindex(k, v) 109 | local type_k = class.type(k) 110 | local type_v = class.type(v) 111 | if type_k == 'number' then 112 | if type_v == 'number' then 113 | if self.__nDimension == 1 then 114 | assert(k > 0 and k <= self.__size[0], 'out of range') 115 | self.__storage.__data[self.__storageOffset+(k-1)*self.__stride[0]] = v 116 | elseif self.__nDimension > 1 then 117 | local t =
C.THRealTensor_newWithTensor(self) -- a fresh view of self; the old argument was the not-yet-defined local t (nil) 118 | C.THRealTensor_narrow(t, nil, 0, k-1, 1) 119 | C.THRealTensor_fill(t, v) 120 | C.THRealTensor_free(t) 121 | else 122 | error('empty tensor') 123 | end 124 | elseif 125 | type_v == 'torch.ByteTensor' 126 | or type_v == 'torch.CharTensor' 127 | or type_v == 'torch.ShortTensor' 128 | or type_v == 'torch.IntTensor' 129 | or type_v == 'torch.LongTensor' 130 | or type_v == 'torch.FloatTensor' 131 | or type_v == 'torch.DoubleTensor' then 132 | local t = self:narrow(1, k, 1) -- use gc, as this can fail 133 | t:copy(v) 134 | end 135 | elseif type_k == 'torch.LongStorage' then 136 | assert(type_v == 'number', 'number expected as value for a LongStorage as key') 137 | assert(k.__size == self.__nDimension, 'invalid storage size') 138 | local idx = self.__storageOffset 139 | for dim=0,tonumber(k.__size)-1 do 140 | local z = k.__data[dim]-1 141 | assert(z >= 0 and z < self.__size[dim], 'out of range') 142 | idx = idx + z*self.__stride[dim] 143 | end 144 | self.__storage.__data[idx] = v 145 | elseif type_k == 'torch.ByteTensor' then 146 | if type_v == 'number' then 147 | C.THRealTensor_maskedFill(self, k, v) 148 | elseif type_v == 'torch.RealTensor' then 149 | C.THRealTensor_maskedCopy(self, k, v) 150 | else 151 | error('when using a mask as a key, number or tensor are expected as value') 152 | end 153 | elseif type_k == 'table' then 154 | index_table(self, k, v) 155 | else 156 | rawset(self, k, v) 157 | end 158 | end 159 | 160 | RealTensor.__tostring = display.tensor 161 | 162 | function RealTensor.__add(t1, t2) 163 | local type_t1 = class.type(t1) 164 | local type_t2 = class.type(t2) 165 | 166 | local r = torch.RealTensor() 167 | if type_t1 == 'torch.RealTensor' and type_t2 == 'number' then 168 | r:resizeAs(t1) 169 | r:fill(t2) 170 | r:add(t1) 171 | elseif type_t1 == 'number' and type_t2 == 'torch.RealTensor' then 172 | r:resizeAs(t2) 173 | r:fill(t1) 174 | r:add(t2) 175 | elseif type_t1 == 'torch.RealTensor' and type_t2 ==
'torch.RealTensor' then 176 | r:resizeAs(t1) 177 | r:copy(t1) 178 | r:add(t2) 179 | else 180 | error('two tensors or one tensor and one number expected') 181 | end 182 | 183 | return r 184 | end 185 | 186 | function RealTensor.__sub(t1, t2) 187 | local type_t1 = class.type(t1) 188 | local type_t2 = class.type(t2) 189 | 190 | local r = torch.RealTensor() 191 | if type_t1 == 'torch.RealTensor' and type_t2 == 'number' then 192 | r:resizeAs(t1) 193 | r:copy(t1) 194 | r:add(-t2) 195 | elseif type_t1 == 'number' and type_t2 == 'torch.RealTensor' then 196 | r:resizeAs(t2) 197 | r:fill(t1) 198 | r:add(-1, t2) -- number - tensor: subtract the tensor operand (t1 is the number here) 199 | elseif type_t1 == 'torch.RealTensor' and type_t2 == 'torch.RealTensor' then 200 | r:resizeAs(t1) 201 | r:copy(t1) 202 | r:add(-1, t2) 203 | else 204 | error('two tensors or one tensor and one number expected') 205 | end 206 | 207 | return r 208 | end 209 | 210 | function RealTensor.__unm(self) 211 | local r = torch.RealTensor() 212 | r:resizeAs(self) 213 | r:zero() 214 | r:add(-1, self) 215 | return r 216 | end 217 | 218 | function RealTensor.__mul(t1, t2) 219 | local type_t1 = class.type(t1) 220 | local type_t2 = class.type(t2) 221 | 222 | local r = torch.RealTensor() 223 | if type_t1 == 'torch.RealTensor' and type_t2 == 'number' then 224 | r:resizeAs(t1) 225 | r:zero() 226 | r:add(t2, t1) 227 | elseif type_t1 == 'number' and type_t2 == 'torch.RealTensor' then 228 | r:resizeAs(t2) 229 | r:zero() 230 | r:add(t1, t2) 231 | elseif type_t1 == 'torch.RealTensor' and type_t2 == 'torch.RealTensor' then 232 | if t1.__nDimension == 1 and t2.__nDimension == 1 then 233 | return t1:dot(t2) 234 | elseif t1.__nDimension == 2 and t2.__nDimension == 1 then 235 | return t1:mv(t2) 236 | elseif t1.__nDimension == 2 and t2.__nDimension == 2 then 237 | return t1:mm(t2) 238 | else 239 | error(string.format('multiplication between %dD and %dD tensors not yet supported', 240 | t1.__nDimension, t2.__nDimension)) 241 | end 242 | else 243 | error('two tensors or one tensor and one number
expected') 244 | end 245 | 246 | return r 247 | end 248 | 249 | function RealTensor.__div(t1, t2) 250 | local type_t1 = class.type(t1) 251 | local type_t2 = class.type(t2) 252 | 253 | assert(type_t2 == 'number', 'number expected') 254 | 255 | local r = torch.RealTensor() 256 | r:resizeAs(t1) 257 | r:copy(t1) 258 | r:mul(1/t2) 259 | 260 | return r 261 | end 262 | -------------------------------------------------------------------------------- /memoryfile.lua: -------------------------------------------------------------------------------- 1 | local argcheck = require 'argcheck' 2 | local torch = require 'torch.env' 3 | local class = require 'class' 4 | local ffi = require 'ffi' 5 | 6 | ffi.cdef[[ 7 | int snprintf(char *restrict s, size_t n, const char *restrict format, ...); 8 | int sscanf(const char *restrict s, const char *restrict format, ...); 9 | ]] 10 | 11 | local MemoryFile = class('torch.MemoryFile', 'torch.File') 12 | torch.MemoryFile = MemoryFile 13 | 14 | local function grow(self, size) 15 | if self.__position + size + 1 >= self.__buffersize then 16 | local gsz = math.max(self.__position + size + 1, -- count trailing '\0' 17 | self.__buffersize * 2, self.__growsize + size) -- doubling keeps appends amortized O(1); without it every write past ~1KB reallocated to the exact size 18 | 19 | if self.__buffer then 20 | ffi.gc(self.__buffer, nil) 21 | self.__buffer = ffi.cast('char*', ffi.C.realloc(self.__buffer, gsz)) 22 | ffi.gc(self.__buffer, ffi.C.free) 23 | else 24 | self.__buffer = ffi.cast('char*', ffi.C.malloc(gsz)) 25 | ffi.gc(self.__buffer, ffi.C.free) 26 | end 27 | 28 | assert(self.__buffer ~= nil, 'out of memory') 29 | self.__buffersize = gsz 30 | self.__buffer[gsz-1] = 0 31 | end 32 | end 33 | 34 | MemoryFile.isOpened = argcheck{ 35 | {name="self", type="torch.MemoryFile"}, 36 | call = 37 | function(self) 38 | -- close() sets __buffer to nil: report that state instead of asserting, so this can return false 39 | return self.__buffer ~= nil 40 | end 41 | } 42 | 43 | MemoryFile.synchronize = argcheck{ 44 | {name="self", type="torch.MemoryFile"}, 45 | call = 46 | function(self) 47 | end 48 | } 49 | 50 | MemoryFile.seek =
argcheck{ 51 | {name="self", type="torch.MemoryFile"}, 52 | {name="position", type="number"}, 53 | call = 54 | function(self, position) 55 | assert(self.__buffer, 'attempt to use a closed file') 56 | if position < 0 or position > self.__size then -- position == size (end of data) is a valid seek target, as seekEnd() uses 57 | self.__hasError = 1 58 | if not self.__isQuiet then 59 | error('unable to seek in file') 60 | end 61 | else -- quiet failure: keep the current position rather than jumping out of bounds 62 | self.__position = position end 63 | return self 64 | end 65 | } 66 | 67 | MemoryFile.seekEnd = argcheck{ 68 | {name="self", type="torch.MemoryFile"}, 69 | call = 70 | function(self) 71 | assert(self.__buffer, 'attempt to use a closed file') 72 | self.__position = self.__size 73 | return self 74 | end 75 | } 76 | 77 | MemoryFile.position = argcheck{ 78 | {name="self", type="torch.MemoryFile"}, 79 | call = 80 | function(self) 81 | assert(self.__buffer, 'attempt to use a closed file') 82 | return self.__position 83 | end 84 | } 85 | 86 | MemoryFile.close = argcheck{ 87 | {name="self", type="torch.MemoryFile"}, 88 | call = 89 | function(self) 90 | assert(self.__buffer, 'attempt to use a closed file') 91 | ffi.gc(self.__buffer, nil) 92 | ffi.C.free(self.__buffer) 93 | self.__buffer = nil 94 | return self 95 | end 96 | } 97 | 98 | MemoryFile.__write = argcheck{ 99 | {name="self", type="torch.MemoryFile"}, 100 | {name="data", type="cdata"}, 101 | {name="elemsize", type="number"}, 102 | {name="size", type="number"}, 103 | call = 104 | function(self, data, elemsize, size) 105 | assert(self.__buffer, 'attempt to write in a closed file') 106 | assert(self.__isWritable, 'read-only file') 107 | grow(self, size*elemsize) 108 | ffi.copy(self.__buffer+self.__position, 109 | data, 110 | size*elemsize) 111 | self.__position = self.__position + size*elemsize 112 | self.__size = math.max(self.__position, self.__size) 113 | self.__buffer[self.__size] = 0 114 | return size 115 | end 116 | } 117 | 118 | MemoryFile.__read = argcheck{ 119 | {name="self", type="torch.MemoryFile"}, 120 | {name="data", type="cdata"}, 121 |
{name="elemsize", type="number"}, 122 | {name="size", type="number"}, 123 | call = 124 | function(self, data, elemsize, size) 125 | assert(self.__buffer, 'attempt to read in a closed file') 126 | assert(self.__isReadable, 'write-only file') 127 | local n = math.min(math.floor((self.__size-self.__position)/elemsize), size) 128 | if n > 0 then 129 | ffi.copy(data, self.__buffer+self.__position, n*elemsize) 130 | self.__position = self.__position + n*elemsize -- position is a byte offset; n counts elements 131 | end 132 | return n 133 | end 134 | } 135 | 136 | local format2cast = { 137 | ['%hhu'] = ffi.typeof('unsigned char'), 138 | ['%hhd'] = ffi.typeof('char'), 139 | ['%hd'] = ffi.typeof('short'), 140 | ['%d'] = ffi.typeof('int'), 141 | ['%ld'] = ffi.typeof('long'), 142 | ['%g'] = ffi.typeof('float'), 143 | ['%lg'] = ffi.typeof('double'), 144 | } 145 | 146 | MemoryFile.__printf = argcheck{ 147 | {name="self", type="torch.MemoryFile"}, 148 | {name="format", type="string"}, 149 | {name="data", type="cdata"}, 150 | {name="size", type="number"}, 151 | call = 152 | function(self, format, data, size) 153 | assert(self.__buffer, 'attempt to write in a closed file') 154 | assert(self.__isWritable, 'read-only file') 155 | local cast = format2cast[format] 156 | for i=0,size-1 do 157 | repeat 158 | local szm = self.__buffersize-self.__position 159 | local szw = ffi.C.snprintf(self.__buffer+self.__position, 160 | szm, 161 | format, 162 | cast(data[i])) 163 | 164 | if szm <= szw then 165 | grow(self, szw) 166 | else 167 | self.__position = self.__position + szw 168 | self.__size = math.max(self.__position, self.__size) 169 | self.__buffer[self.__size] = 0 170 | end 171 | until szm > szw 172 | end 173 | return size 174 | end 175 | } 176 | 177 | MemoryFile.__scanf = argcheck{ 178 | {name="self", type="torch.MemoryFile"}, 179 | {name="format", type="string"}, 180 | {name="data", type="cdata"}, 181 | {name="size", type="number"}, 182 | call = 183 | function(self, format, data, size) 184 | assert(self.__buffer, 'attempt to read in a
closed file') 185 | assert(self.__isReadable, 'write-only file') 186 | format = format .. "%n" 187 | local p = ffi.new('int[1]') 188 | local n = 0 189 | for i=0,size-1 do 190 | local ret = ffi.C.sscanf(self.__buffer+self.__position, format, data+i, p) 191 | if ret <= 0 then 192 | break 193 | else 194 | self.__position = self.__position + tonumber(p[0]) 195 | n = n + 1 196 | end 197 | end 198 | 199 | if self.__isAutoSpacing and size > 0 then 200 | if self.__position < self.__size and self.__buffer[self.__position] == string.byte('\n') then 201 | self.__position = self.__position + 1 202 | end 203 | end 204 | 205 | return n 206 | end 207 | } 208 | 209 | MemoryFile.__gets = argcheck{ 210 | {name="self", type="torch.MemoryFile"}, 211 | call = 212 | function(self) 213 | assert(self.__buffer, 'attempt to read in a closed file') 214 | assert(self.__isReadable, 'write-only file') 215 | local size = self.__size-self.__position 216 | local buffer = self.__buffer + self.__position 217 | local ret = string.byte('\n') 218 | local eof = (size == 0) 219 | for i=0,size-1 do 220 | if buffer[i] == ret then 221 | size = i + 1 222 | break 223 | end 224 | end 225 | 226 | self.__position = self.__position + size 227 | if size > 0 and buffer[size-1] == ret then -- at EOF size is 0: never read buffer[-1] 228 | size = size - 1 229 | end 230 | 231 | if not eof then 232 | local str = ffi.string(buffer, size) 233 | return str 234 | end 235 | end 236 | } 237 | 238 | MemoryFile.__init = argcheck{ 239 | {name="self", type="torch.MemoryFile"}, 240 | {name="mode", type="string", default='rw'}, 241 | {name="quiet", type="boolean", default=false}, 242 | call = 243 | function(self, mode, quiet) 244 | assert(mode == 'r' or mode == 'w' or mode == 'rw', 'invalid mode (r, w or rw expected)') 245 | 246 | self.__growsize = 1024 247 | self.__buffersize = 0 248 | self.__position = 0 249 | self.__size = 0 250 | grow(self, 0) 251 | 252 | self.__isQuiet = quiet 253 | self.__isReadable = (mode == 'r') or (mode == 'rw') 254 | self.__isWritable = (mode == 'w') or
(mode == 'rw') 255 | self.__isBinary = false 256 | self.__isAutoSpacing = true 257 | self.__hasError = false 258 | 259 | return self 260 | end 261 | } 262 | -------------------------------------------------------------------------------- /TH/generic/THTensorRandom.c: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERIC_FILE 2 | #define TH_GENERIC_FILE "generic/THTensorRandom.c" 3 | #else 4 | 5 | TH_API void THTensor_(random)(THTensor *self, THGenerator *_generator) 6 | { 7 | #if defined(TH_REAL_IS_BYTE) 8 | TH_TENSOR_APPLY(real, self, *self_data = (unsigned char)(THRandom_random(_generator) % (UCHAR_MAX+1));); 9 | #elif defined(TH_REAL_IS_CHAR) 10 | TH_TENSOR_APPLY(real, self, *self_data = (char)(THRandom_random(_generator) % (CHAR_MAX+1));); 11 | #elif defined(TH_REAL_IS_SHORT) 12 | TH_TENSOR_APPLY(real, self, *self_data = (short)(THRandom_random(_generator) % (SHRT_MAX+1));); 13 | #elif defined(TH_REAL_IS_INT) 14 | TH_TENSOR_APPLY(real, self, *self_data = (int)(THRandom_random(_generator) % (INT_MAX+1UL));); 15 | #elif defined(TH_REAL_IS_LONG) 16 | TH_TENSOR_APPLY(real, self, *self_data = (long)(THRandom_random(_generator) % (LONG_MAX+1UL));); 17 | #elif defined(TH_REAL_IS_FLOAT) 18 | TH_TENSOR_APPLY(real, self, *self_data = (float)(THRandom_random(_generator) % ((1UL << FLT_MANT_DIG)+1));); 19 | #elif defined(TH_REAL_IS_DOUBLE) 20 | TH_TENSOR_APPLY(real, self, *self_data = (float)(THRandom_random(_generator) % ((1UL << DBL_MANT_DIG)+1));); 21 | #else 22 | #error "Unknown type" 23 | #endif 24 | } 25 | 26 | TH_API void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p) 27 | { 28 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(_generator, p);); 29 | } 30 | 31 | TH_API void THTensor_(bernoulli)(THTensor *self, THGenerator *_generator, double p) 32 | { 33 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p);); 34 | } 35 | 36 | #if 
defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) 37 | 38 | TH_API void THTensor_(uniform)(THTensor *self, THGenerator *_generator, double a, double b) 39 | { 40 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_uniform(_generator, a, b);); 41 | } 42 | 43 | TH_API void THTensor_(normal)(THTensor *self, THGenerator *_generator, double mean, double stdv) 44 | { 45 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_normal(_generator, mean, stdv);); 46 | } 47 | 48 | TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lambda) 49 | { 50 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(_generator, lambda);); 51 | } 52 | 53 | TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma) 54 | { 55 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_cauchy(_generator, median, sigma);); 56 | } 57 | 58 | TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv) 59 | { 60 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_logNormal(_generator, mean, stdv);); 61 | } 62 | 63 | TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement) 64 | { 65 | int start_dim = THTensor_(nDimension)(prob_dist); 66 | long n_dist; 67 | long n_categories; 68 | THTensor* cum_dist; 69 | int i,j,k; 70 | 71 | if (start_dim == 1) 72 | { 73 | THTensor_(resize2d)(prob_dist, 1, THTensor_(size)(prob_dist, 0)); 74 | } 75 | 76 | n_dist = THTensor_(size)(prob_dist, 0); 77 | n_categories = THTensor_(size)(prob_dist, 1); 78 | 79 | THArgCheck(n_sample > 0, 2, "cannot sample n_sample < 0 samples"); 80 | 81 | if (!with_replacement) 82 | { 83 | THArgCheck((!with_replacement) && (n_sample <= n_categories), 2, \ 84 | "cannot sample n_sample > prob_dist:size(1) samples without replacement"); 85 | } 86 | 87 | /* cumulative probability distribution vector */ 88 | cum_dist = 
THTensor_(newWithSize1d)(n_categories); 89 | 90 | /* will contain multinomial samples (category indices to be returned) */ 91 | THLongTensor_resize2d(self, n_dist , n_sample); 92 | 93 | for (i=0; istorage, \ 101 | prob_dist->storageOffset+i*prob_dist->stride[0]+j*prob_dist->stride[1] \ 102 | ); 103 | THStorage_(set)( 104 | cum_dist->storage, \ 105 | cum_dist->storageOffset+j*cum_dist->stride[0], \ 106 | sum \ 107 | ); 108 | } 109 | THArgCheck((sum > 0), 2, "invalid multinomial distribution (sum of probabilities <= 0)"); 110 | /* normalize cumulative probability distribution so that last val is 1 111 | i.e. dosen't assume original prob_dist row sums to one */ 112 | if ( (sum > 0) || ( ( sum < 1.00001) && (sum > 0.99999) ) ) 113 | { 114 | for (j=0; jstride[0]] /= sum; 117 | } 118 | } 119 | 120 | for (j=0; j 0) 133 | { 134 | mid_pointer = left_pointer + (right_pointer - left_pointer) / 2; 135 | cum_prob = THStorage_(get)( \ 136 | cum_dist->storage, \ 137 | cum_dist->storageOffset+mid_pointer*cum_dist->stride[0] \ 138 | ); 139 | if (cum_prob < uniform_sample) 140 | { 141 | left_pointer = mid_pointer + 1; 142 | } 143 | else 144 | { 145 | right_pointer = mid_pointer; 146 | } 147 | } 148 | sample_idx = left_pointer; 149 | 150 | /* store in result tensor (will be incremented for lua compat by wrapper) */ 151 | THLongStorage_set( \ 152 | self->storage, \ 153 | self->storageOffset+i*self->stride[0]+j*self->stride[1], \ 154 | sample_idx \ 155 | ); 156 | 157 | /* Once a sample is drawn, it cannot be drawn again. ie sample without replacement */ 158 | if (!with_replacement) 159 | { 160 | /* update cumulative distribution so that sample cannot be drawn again */ 161 | real diff; 162 | real new_val = 0; 163 | real sum; 164 | 165 | if (sample_idx != 0) 166 | { 167 | new_val = THStorage_(get)( \ 168 | cum_dist->storage, \ 169 | cum_dist->storageOffset+(sample_idx-1)*cum_dist->stride[0] \ 170 | ); 171 | } 172 | /* marginal cumulative mass (i.e. 
original probability) of sample */ 173 | diff = THStorage_(get)( \ 174 | cum_dist->storage, \ 175 | cum_dist->storageOffset+sample_idx*cum_dist->stride[0] \ 176 | ) - new_val; 177 | /* new sum of marginals is not one anymore... */ 178 | sum = 1.0 - diff; 179 | for (k=0; kstorage, \ 183 | cum_dist->storageOffset+k*cum_dist->stride[0] \ 184 | ); 185 | if (k >= sample_idx) 186 | { 187 | /* remove sampled probability mass from later cumulative probabilities */ 188 | new_val -= diff; 189 | } 190 | /* make total marginals sum to one */ 191 | new_val /= sum; 192 | THStorage_(set)( \ 193 | cum_dist->storage, \ 194 | cum_dist->storageOffset+k*cum_dist->stride[0], \ 195 | new_val \ 196 | ); 197 | } 198 | } 199 | } 200 | } 201 | 202 | THTensor_(free)(cum_dist); 203 | 204 | if (start_dim == 1) 205 | { 206 | THLongTensor_resize1d(self, n_sample); 207 | THTensor_(resize1d)(prob_dist, n_categories); 208 | } 209 | } 210 | 211 | #endif 212 | 213 | #if defined(TH_REAL_IS_LONG) 214 | TH_API void THTensor_(getRNGState)(THGenerator *_generator, THTensor *self) 215 | { 216 | unsigned long *data; 217 | long *offset; 218 | long *left; 219 | 220 | THTensor_(resize1d)(self,626); 221 | data = (unsigned long *)THTensor_(data)(self); 222 | offset = (long *)data+624; 223 | left = (long *)data+625; 224 | 225 | THRandom_getState(_generator, data, offset, left); 226 | } 227 | 228 | TH_API void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self) 229 | { 230 | unsigned long *data; 231 | long *offset; 232 | long *left; 233 | 234 | THArgCheck(THTensor_(nElement)(self) == 626, 1, "state should have 626 elements"); 235 | data = (unsigned long *)THTensor_(data)(self); 236 | offset = (long *)(data+624); 237 | left = (long *)(data+625); 238 | 239 | THRandom_setState(_generator, data, *offset, *left); 240 | } 241 | #endif 242 | 243 | #endif 244 | -------------------------------------------------------------------------------- /TH/THRandom.c: 
-------------------------------------------------------------------------------- 1 | #include "THGeneral.h" 2 | #include "THRandom.h" 3 | 4 | 5 | /* Code for the Mersenne Twister random generator.... */ 6 | #define n _MERSENNE_STATE_N 7 | #define m _MERSENNE_STATE_M 8 | THGenerator* THGenerator_new() 9 | { 10 | THGenerator *self = THAlloc(sizeof(THGenerator)); 11 | self->left = 1; 12 | self->initf = 0; 13 | self->normal_is_valid = 0; 14 | return self; 15 | } 16 | 17 | void THGenerator_free(THGenerator *self) 18 | { 19 | THFree(self); 20 | } 21 | 22 | unsigned long THRandom_seed(THGenerator *_generator) 23 | { 24 | unsigned long s = (unsigned long)time(0); 25 | THRandom_manualSeed(_generator, s); 26 | return s; 27 | } 28 | 29 | /* The next 4 methods are taken from http://www.math.keio.ac.jp/matumoto/emt.html 30 | Here is the copyright: 31 | Some minor modifications have been made to adapt to "my" C... */ 32 | 33 | /* 34 | A C-program for MT19937, with initialization improved 2002/2/10. 35 | Coded by Takuji Nishimura and Makoto Matsumoto. 36 | This is a faster version by taking Shawn Cokus's optimization, 37 | Matthe Bellew's simplification, Isaku Wada's double version. 38 | 39 | Before using, initialize the state by using init_genrand(seed) 40 | or init_by_array(init_key, key_length). 41 | 42 | Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, 43 | All rights reserved. 44 | 45 | Redistribution and use in source and binary forms, with or without 46 | modification, are permitted provided that the following conditions 47 | are met: 48 | 49 | 1. Redistributions of source code must retain the above copyright 50 | notice, this list of conditions and the following disclaimer. 51 | 52 | 2. Redistributions in binary form must reproduce the above copyright 53 | notice, this list of conditions and the following disclaimer in the 54 | documentation and/or other materials provided with the distribution. 55 | 56 | 3.
The names of its contributors may not be used to endorse or promote 57 | products derived from this software without specific prior written 58 | permission. 59 | 60 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 61 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 62 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 63 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 64 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 65 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 66 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 67 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 68 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 69 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 70 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 71 | 72 | 73 | Any feedback is very welcome. 74 | http://www.math.keio.ac.jp/matumoto/emt.html 75 | email: matumoto@math.keio.ac.jp 76 | */ 77 | 78 | /* Macros for the Mersenne Twister random generator... */ 79 | /* Period parameters */ 80 | /* #define n 624 */ 81 | /* #define m 397 */ 82 | #define MATRIX_A 0x9908b0dfUL /* constant vector a */ 83 | #define UMASK 0x80000000UL /* most significant w-r bits */ 84 | #define LMASK 0x7fffffffUL /* least significant r bits */ 85 | #define MIXBITS(u,v) ( ((u) & UMASK) | ((v) & LMASK) ) 86 | #define TWIST(u,v) ((MIXBITS(u,v) >> 1) ^ ((v)&1UL ? MATRIX_A : 0UL)) 87 | /*********************************************************** That's it.
*/ 88 | 89 | void THRandom_manualSeed(THGenerator *_generator, unsigned long the_seed_) 90 | { 91 | int j; 92 | _generator->the_initial_seed = the_seed_; 93 | _generator->state[0] = _generator->the_initial_seed & 0xffffffffUL; 94 | for(j = 1; j < n; j++) 95 | { 96 | _generator->state[j] = (1812433253UL * (_generator->state[j-1] ^ (_generator->state[j-1] >> 30)) + j); 97 | /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ 98 | /* In the previous versions, mSBs of the seed affect */ 99 | /* only mSBs of the array state[]. */ 100 | /* 2002/01/09 modified by makoto matsumoto */ 101 | _generator->state[j] &= 0xffffffffUL; /* for >32 bit machines */ 102 | } 103 | _generator->left = 1; _generator->normal_is_valid = 0; /* reseeding must also drop the cached Box-Muller value, or normal() is not reproducible */ 104 | _generator->initf = 1; 105 | } 106 | 107 | unsigned long THRandom_initialSeed(THGenerator *_generator) 108 | { 109 | if(_generator->initf == 0) 110 | { 111 | THRandom_seed(_generator); 112 | } 113 | 114 | return _generator->the_initial_seed; 115 | } 116 | 117 | void THRandom_nextState(THGenerator *_generator) 118 | { 119 | unsigned long *p = _generator->state; 120 | int j; 121 | 122 | /* if init_genrand() has not been called, */ 123 | /* a default initial seed is used */ 124 | if(_generator->initf == 0) 125 | THRandom_seed(_generator); 126 | 127 | _generator->left = n; 128 | _generator->next = _generator->state; 129 | 130 | for(j = n-m+1; --j; p++) 131 | *p = p[m] ^ TWIST(p[0], p[1]); 132 | 133 | for(j = m; --j; p++) 134 | *p = p[m-n] ^ TWIST(p[0], p[1]); 135 | 136 | *p = p[m-n] ^ TWIST(p[0], _generator->state[0]); 137 | } 138 | 139 | unsigned long THRandom_random(THGenerator *_generator) 140 | { 141 | unsigned long y; 142 | 143 | if (--(_generator->left) == 0) 144 | THRandom_nextState(_generator); 145 | y = *((_generator->next)++); 146 | 147 | /* Tempering */ 148 | y ^= (y >> 11); 149 | y ^= (y << 7) & 0x9d2c5680UL; 150 | y ^= (y << 15) & 0xefc60000UL; 151 | y ^= (y >> 18); 152 | 153 | return y; 154 | } 155 | 156 | /* generates a random number on [0,1)-double-interval */ 157
| static double __uniform__(THGenerator *_generator) 158 | { 159 | unsigned long y; 160 | 161 | if (--(_generator->left) == 0) 162 | THRandom_nextState(_generator); 163 | y = *((_generator->next)++); 164 | 165 | /* Tempering */ 166 | y ^= (y >> 11); 167 | y ^= (y << 7) & 0x9d2c5680UL; 168 | y ^= (y << 15) & 0xefc60000UL; 169 | y ^= (y >> 18); 170 | 171 | return (double)y * (1.0/4294967296.0); 172 | /* divided by 2^32 */ 173 | } 174 | 175 | /********************************************************* 176 | 177 | Thanks *a lot* Takuji Nishimura and Makoto Matsumoto! 178 | 179 | Now my own code... 180 | 181 | *********************************************************/ 182 | 183 | double THRandom_uniform(THGenerator *_generator, double a, double b) 184 | { 185 | return(__uniform__(_generator) * (b - a) + a); 186 | } 187 | 188 | double THRandom_normal(THGenerator *_generator, double mean, double stdv) 189 | { 190 | THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive"); 191 | 192 | /* This is known as the Box-Muller method */ 193 | if(!_generator->normal_is_valid) 194 | { 195 | _generator->normal_x = __uniform__(_generator); 196 | _generator->normal_y = __uniform__(_generator); 197 | _generator->normal_rho = sqrt(-2. * log(1.0-_generator->normal_y)); 198 | _generator->normal_is_valid = 1; 199 | } 200 | else 201 | _generator->normal_is_valid = 0; 202 | 203 | if(_generator->normal_is_valid) 204 | return _generator->normal_rho*cos(2.*M_PI*_generator->normal_x)*stdv+mean; 205 | else 206 | return _generator->normal_rho*sin(2.*M_PI*_generator->normal_x)*stdv+mean; 207 | } 208 | 209 | double THRandom_exponential(THGenerator *_generator, double lambda) 210 | { 211 | return(-1. / lambda * log(1-__uniform__(_generator))); 212 | } 213 | 214 | double THRandom_cauchy(THGenerator *_generator, double median, double sigma) 215 | { 216 | return(median + sigma * tan(M_PI*(__uniform__(_generator)-0.5))); 217 | } 218 | 219 | /* You would have to be crazy to use this.
220 | Oh well. */ 221 | double THRandom_logNormal(THGenerator *_generator, double mean, double stdv) 222 | { 223 | double zm = mean*mean; 224 | double zs = stdv*stdv; 225 | THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive"); 226 | return(exp(THRandom_normal(_generator, log(zm/sqrt(zs + zm)), sqrt(log(zs/zm+1)) ))); 227 | } 228 | 229 | int THRandom_geometric(THGenerator *_generator, double p) 230 | { 231 | THArgCheck(p > 0 && p < 1, 1, "must be > 0 and < 1"); 232 | return((int)(log(1-__uniform__(_generator)) / log(p)) + 1); 233 | } 234 | 235 | int THRandom_bernoulli(THGenerator *_generator, double p) 236 | { 237 | THArgCheck(p >= 0 && p <= 1, 1, "must be >= 0 and <= 1"); 238 | return(__uniform__(_generator) < p); /* u is in [0,1), so P(u < p) == p exactly (p == 0 never yields 1) */ 239 | } 240 | 241 | /* returns the random number state */ 242 | void THRandom_getState(THGenerator *_generator, unsigned long *_state, long *offset, long *_left) 243 | { 244 | if(_generator->initf == 0) 245 | THRandom_seed(_generator); 246 | memmove(_state, _generator->state, n*sizeof(unsigned long)); 247 | *offset = (long)(_generator->next - _generator->state); 248 | *_left = _generator->left; 249 | } 250 | 251 | /* sets the random number state */ 252 | void THRandom_setState(THGenerator *_generator, unsigned long *_state, long offset, long _left) 253 | { THArgCheck(_left >= 1 && _left <= n && (_left == 1 || (offset >= 0 && offset + _left <= n + 1)), 3, "invalid random generator state"); /* left-1 more reads of next[] happen before nextState() resets it; left == 1 resets immediately */ 254 | memmove(_generator->state, _state, n*sizeof(unsigned long)); 255 | _generator->next = _generator->state + offset; 256 | _generator->left = _left; 257 | _generator->initf = 1; 258 | } 259 | -------------------------------------------------------------------------------- /TH/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.6) 2 | 3 | SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) 4 | 5 | # Can be compiled standalone 6 | IF(NOT TH_INSTALL_BIN_SUBDIR 7 | OR NOT TH_INSTALL_LIB_SUBDIR 8 | OR NOT TH_INSTALL_INCLUDE_SUBDIR 9 | OR NOT TH_INSTALL_CMAKE_SUBDIR) 10 | 11 |
SET(TH_INSTALL_BIN_SUBDIR "bin" CACHE PATH "TH install binary subdirectory") 12 | SET(TH_INSTALL_LIB_SUBDIR "lib" CACHE PATH "TH install library subdirectory") 13 | SET(TH_INSTALL_INCLUDE_SUBDIR "include" CACHE PATH "TH install include subdirectory") 14 | SET(TH_INSTALL_CMAKE_SUBDIR "share/cmake/TH" CACHE PATH "TH install cmake subdirectory") 15 | ENDIF() 16 | 17 | # flags 18 | 19 | IF(MSVC) 20 | # respect the standard 21 | ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1) 22 | ENDIF(MSVC) 23 | 24 | # OpenMP support? 25 | SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?") 26 | IF (APPLE AND CMAKE_COMPILER_IS_GNUCC) 27 | EXEC_PROGRAM (uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) 28 | STRING (REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION}) 29 | MESSAGE (STATUS "MAC OS Darwin Version: ${DARWIN_VERSION}") 30 | IF (DARWIN_VERSION GREATER 9) 31 | SET(APPLE_OPENMP_SUCKS 1) 32 | ENDIF (DARWIN_VERSION GREATER 9) 33 | EXECUTE_PROCESS (COMMAND ${CMAKE_C_COMPILER} -dumpversion 34 | OUTPUT_VARIABLE GCC_VERSION OUTPUT_STRIP_TRAILING_WHITESPACE) 35 | IF (APPLE_OPENMP_SUCKS AND GCC_VERSION VERSION_LESS 4.6.2) 36 | MESSAGE(STATUS "Warning: Disabling OpenMP (unstable with this version of GCC)") 37 | MESSAGE(STATUS " Install GCC >= 4.6.2 or change your OS to enable OpenMP") 38 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas") 39 | SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?"
FORCE) 40 | ENDIF () 41 | ENDIF () 42 | 43 | IF (WITH_OPENMP) 44 | FIND_PACKAGE(OpenMP) 45 | IF(OPENMP_FOUND) 46 | MESSAGE(STATUS "Compiling with OpenMP support") 47 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 48 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 49 | SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 50 | ENDIF(OPENMP_FOUND) 51 | ENDIF (WITH_OPENMP) 52 | 53 | # ARM specific flags 54 | FIND_PACKAGE(ARM) 55 | IF (NEON_FOUND) 56 | MESSAGE(STATUS "Neon found with compiler flag : -mfpu=neon -D__NEON__") 57 | SET(CMAKE_C_FLAGS "-mfpu=neon -D__NEON__ ${CMAKE_C_FLAGS}") 58 | ENDIF (NEON_FOUND) 59 | IF (CORTEXA8_FOUND) 60 | MESSAGE(STATUS "Cortex-A8 Found with compiler flag : -mcpu=cortex-a8") 61 | SET(CMAKE_C_FLAGS "-mcpu=cortex-a8 -fprefetch-loop-arrays ${CMAKE_C_FLAGS}") 62 | ENDIF (CORTEXA8_FOUND) 63 | IF (CORTEXA9_FOUND) 64 | MESSAGE(STATUS "Cortex-A9 Found with compiler flag : -mcpu=cortex-a9") 65 | SET(CMAKE_C_FLAGS "-mcpu=cortex-a9 ${CMAKE_C_FLAGS}") 66 | ENDIF (CORTEXA9_FOUND) 67 | 68 | IF(UNIX) 69 | INCLUDE(CheckFunctionExists) 70 | SET(CMAKE_EXTRA_INCLUDE_FILES "sys/mman.h") 71 | CHECK_FUNCTION_EXISTS(mmap HAVE_MMAP) 72 | IF(HAVE_MMAP) 73 | ADD_DEFINITIONS(-DHAVE_MMAP=1) 74 | ENDIF(HAVE_MMAP) 75 | ENDIF(UNIX) 76 | 77 | FIND_PACKAGE(SSE) 78 | IF(C_SSE2_FOUND) 79 | SET(CMAKE_C_FLAGS "${C_SSE2_FLAGS} -DUSE_SSE2 ${CMAKE_C_FLAGS}") 80 | ENDIF(C_SSE2_FOUND) 81 | IF(C_SSE3_FOUND) 82 | SET(CMAKE_C_FLAGS "${C_SSE3_FLAGS} -DUSE_SSE3 ${CMAKE_C_FLAGS}") 83 | ENDIF(C_SSE3_FOUND) 84 | IF(C_SSE4_1_FOUND) 85 | SET(CMAKE_C_FLAGS "${C_SSE4_1_FLAGS} -DUSE_SSE4_1 ${CMAKE_C_FLAGS}") 86 | ENDIF(C_SSE4_1_FOUND) 87 | IF(C_SSE4_2_FOUND) 88 | SET(CMAKE_C_FLAGS "${C_SSE4_2_FLAGS} -DUSE_SSE4_2 ${CMAKE_C_FLAGS}") 89 | ENDIF(C_SSE4_2_FOUND) 90 | 91 | SET(hdr 92 | THGeneral.h THAllocator.h THStorage.h THTensor.h THTensorApply.h THBlas.h 93 | THLapack.h THLogAdd.h THRandom.h THVector.h) 94 | 95 | SET(src 96 |
THGeneral.c THAllocator.c THStorage.c THTensor.c THBlas.c THLapack.c 97 | THLogAdd.c THRandom.c THFile.c THDiskFile.c THMemoryFile.c) 98 | 99 | SET(src ${src} ${hdr}) 100 | ADD_LIBRARY(TH SHARED ${src}) 101 | 102 | FIND_PACKAGE(BLAS) 103 | IF(BLAS_FOUND) 104 | SET(USE_BLAS 1) 105 | TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES}) 106 | ENDIF(BLAS_FOUND) 107 | 108 | FIND_PACKAGE(LAPACK) 109 | IF(LAPACK_FOUND) 110 | SET(USE_LAPACK 1) 111 | TARGET_LINK_LIBRARIES(TH ${LAPACK_LIBRARIES}) 112 | ENDIF(LAPACK_FOUND) 113 | 114 | IF(BLAS_IS_ACCELERATE) 115 | MESSAGE(STATUS "BLAS FOUND IS ACCELERATE: Fix for sdot") 116 | ENDIF() 117 | 118 | INCLUDE(CheckCSourceRuns) 119 | SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) 120 | FOREACH(KEYWORD "inline" "__inline__" "__inline") 121 | IF(NOT DEFINED C_INLINE) 122 | 123 | SET(CMAKE_REQUIRED_FLAGS "-Dinline=${KEYWORD} ${CMAKE_C_FLAGS}") 124 | CHECK_C_SOURCE_RUNS(" 125 | static inline int static_foo() 126 | { 127 | return 0; 128 | } 129 | 130 | int main(int argc, char *argv[]) 131 | { 132 | static_foo(); 133 | return 0; 134 | }" C_HAS_${KEYWORD}) 135 | 136 | IF(C_HAS_${KEYWORD}) 137 | SET(C_INLINE TRUE) 138 | # Right now i put it in THGeneral.h -- debatable 139 | # ADD_DEFINITIONS("-Dinline=${KEYWORD}") 140 | SET(TH_INLINE ${KEYWORD}) 141 | MESSAGE(STATUS "C inline is supported (${KEYWORD})") 142 | ENDIF(C_HAS_${KEYWORD}) 143 | ENDIF(NOT DEFINED C_INLINE) 144 | ENDFOREACH(KEYWORD) 145 | SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) 146 | 147 | IF(NOT DEFINED C_INLINE) 148 | MESSAGE(STATUS "C inline seems not supported") 149 | # Right now i put it in THGeneral.h -- debatable 150 | # ADD_DEFINITIONS("-Dinline=") 151 | SET(TH_INLINE "") 152 | ENDIF(NOT DEFINED C_INLINE) 153 | 154 | # Is __thread supported?
155 | INCLUDE(CheckCSourceCompiles) 156 | CHECK_C_SOURCE_COMPILES("static __thread int x = 1; int main() { return x; }" C_HAS_THREAD) 157 | IF(NOT C_HAS_THREAD) # the check always defines C_HAS_THREAD (empty on failure), so test its value, not DEFINED 158 | MESSAGE(STATUS "Warning: __thread is not supported, generating thread-unsafe code") 159 | ENDIF(NOT C_HAS_THREAD) 160 | IF(C_HAS_THREAD) 161 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTH_HAVE_THREAD") 162 | ENDIF(C_HAS_THREAD) 163 | 164 | INCLUDE_DIRECTORIES("${CMAKE_CURRENT_BINARY_DIR}") 165 | CONFIGURE_FILE(THGeneral.h.in "${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h") 166 | 167 | INSTALL(TARGETS TH 168 | EXPORT TH-exports 169 | RUNTIME DESTINATION "${TH_INSTALL_BIN_SUBDIR}" 170 | LIBRARY DESTINATION "${TH_INSTALL_LIB_SUBDIR}" 171 | ARCHIVE DESTINATION "${TH_INSTALL_LIB_SUBDIR}") 172 | 173 | INSTALL(FILES 174 | TH.h 175 | THAllocator.h 176 | THBlas.h 177 | THDiskFile.h 178 | THFile.h 179 | THFilePrivate.h 180 | ${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h 181 | THGenerateAllTypes.h 182 | THGenerateFloatTypes.h 183 | THGenerateIntTypes.h 184 | THLapack.h 185 | THLogAdd.h 186 | THMemoryFile.h 187 | THRandom.h 188 | THStorage.h 189 | THTensor.h 190 | THTensorApply.h 191 | THTensorDimApply.h 192 | THTensorMacros.h 193 | THVector.h 194 | DESTINATION "${TH_INSTALL_INCLUDE_SUBDIR}/TH") 195 | 196 | INSTALL(FILES 197 | generic/THBlas.c 198 | generic/THBlas.h 199 | generic/THLapack.c 200 | generic/THLapack.h 201 | generic/THStorage.c 202 | generic/THStorage.h 203 | generic/THStorageCopy.c 204 | generic/THStorageCopy.h 205 | generic/THTensor.c 206 | generic/THTensor.h 207 | generic/THTensorConv.c 208 | generic/THTensorConv.h 209 | generic/THTensorCopy.c 210 | generic/THTensorCopy.h 211 | generic/THTensorLapack.c 212 | generic/THTensorLapack.h 213 | generic/THTensorMath.c 214 | generic/THTensorMath.h 215 | generic/THTensorRandom.c 216 | generic/THTensorRandom.h 217 | generic/THVector.c 218 | DESTINATION "${TH_INSTALL_INCLUDE_SUBDIR}/TH/generic") 219 | 220 | 221 | IF (WIN32 AND NOT CYGWIN) 222 |
SET(BLAS_INSTALL_LIBRARIES "OFF" 223 | CACHE BOOL "Copy the required BLAS DLLs into the TH install dirs") 224 | ENDIF (WIN32 AND NOT CYGWIN) 225 | 226 | MACRO(Install_Required_Library ln) 227 | get_filename_component(libpath ${ln} PATH) 228 | get_filename_component(libname ${ln} NAME_WE) 229 | file(GLOB libdlls "${libpath}/${libname}*.dll") 230 | install(PROGRAMS ${libdlls} 231 | DESTINATION "${TH_INSTALL_BIN_SUBDIR}") 232 | ENDMACRO(Install_Required_Library) 233 | 234 | IF (BLAS_FOUND AND BLAS_INSTALL_LIBRARIES) 235 | IF (BLAS_goto2_LIBRARY) 236 | Install_Required_Library(${BLAS_goto2_LIBRARY}) 237 | Install_Required_Library("${libpath}/libgfortran") 238 | Install_Required_Library("${libpath}/libquadmath") 239 | Install_Required_Library("${libpath}/libgcc") 240 | ENDIF() 241 | IF (BLAS_openblas_LIBRARY) 242 | Install_Required_Library(${BLAS_openblas_LIBRARY}) 243 | Install_Required_Library("${libpath}/libquadmath") 244 | Install_Required_Library("${libpath}/libgfortran") 245 | Install_Required_Library("${libpath}/libquadmath") 246 | Install_Required_Library("${libpath}/libgcc") 247 | ENDIF() 248 | ENDIF() 249 | 250 | # Create THConfig.cmake 251 | GET_TARGET_PROPERTY(TH_OUTPUT_NAME TH LOCATION) 252 | GET_FILENAME_COMPONENT(TH_OUTPUT_NAME ${TH_OUTPUT_NAME} NAME) 253 | SET(TH_LIBRARIES "${CMAKE_INSTALL_PREFIX}/${TH_INSTALL_LIB_SUBDIR}/${TH_OUTPUT_NAME}") 254 | SET(TH_INCLUDE_DIR "${CMAKE_INSTALL_PREFIX}/${TH_INSTALL_INCLUDE_SUBDIR}/TH") 255 | CONFIGURE_FILE(THConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/THConfig.cmake") 256 | INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/THConfig.cmake" 257 | DESTINATION "${TH_INSTALL_CMAKE_SUBDIR}") 258 | --------------------------------------------------------------------------------