├── ANN.h ├── BatchProducer.h ├── BatchWiseDropoutLayer.h ├── Batches.h ├── CudaUtils.h ├── Data ├── Artificial │ ├── .gitignore │ └── artificial.dataset.py ├── CIFAR10 │ └── .gitignore └── MNIST │ └── .gitignore ├── Dataset.h ├── Layer.h ├── Makefile ├── NetworkArchitectures.h ├── README.md ├── Rng.h ├── SampleWiseDropoutLayer.h ├── SigmoidLayer.h ├── SimpleLayer.h ├── SoftmaxClassifier.h ├── dropout.h ├── readArtificialDataset.h ├── readCIFAR10.h ├── readMNIST.h ├── runArtificial.cu ├── runCifar10.cu └── runMnist.cu /ANN.h: -------------------------------------------------------------------------------- 1 | class ANN { 2 | public: 3 | vector ann; 4 | int nInputFeatures; 5 | int nOutputFeatures; 6 | int inputSpatialSize; 7 | int nClasses; 8 | int nTop; 9 | float inputDropout; 10 | ANN (int nInputFeatures, 11 | int nClasses, 12 | int cudaDevice=0, 13 | float inputDropout=0.0f, 14 | int nTop=1) : 15 | nInputFeatures(nInputFeatures), 16 | nClasses(nClasses), 17 | nTop(nTop), 18 | inputDropout(inputDropout) { 19 | nOutputFeatures=nInputFeatures; 20 | initializeGPU(cudaDevice); 21 | if (inputDropout>0) 22 | cout << "Input Layer Dropout " << inputDropout << endl; 23 | } 24 | void addSimpleLayer(int nFeatures, ActivationFunction activationFn=RELU) { 25 | cout << nFeatures <<"N " << sigmoidNames[activationFn] << endl; 26 | ann.push_back(new SimpleLayer(nOutputFeatures, nFeatures,activationFn)); 27 | nOutputFeatures=nFeatures; 28 | } 29 | void addBatchWiseDropoutLayer(int nFeatures, float dropout, ActivationFunction activationFn=RELU) { 30 | cout << nFeatures <<"N " << sigmoidNames[activationFn] << endl; 31 | cout << "Dropout " << dropout << endl; 32 | ann.push_back(new BatchWiseDropoutLayer(nOutputFeatures, nFeatures,dropout,activationFn)); 33 | nOutputFeatures=nFeatures; 34 | } 35 | void addSampleWiseDropoutLayer(int nFeatures, float dropout, ActivationFunction activationFn=RELU) { 36 | cout << "Dropout " << dropout << endl; 37 | cout << nFeatures <<"N " << sigmoidNames[activationFn] << endl; 38 | ann.push_back(new SampleWiseDropoutLayer(nOutputFeatures, nFeatures, dropout, activationFn)); 39 | nOutputFeatures=nFeatures; 40 | } 41 | void processBatch(Batch& batch, float learningRate=0.1) { 42 | vector interfaces(1); 43 | interfaces[0]=&batch.i; 44 | for (int i=0;iforwards(*interfaces[i],*interfaces[i+1]); 48 | } 49 | if (batch.i.type==TRAINBATCH) 50 | { 51 | SoftmaxClassifier(*interfaces.back(),nTop,batch.labels,batch.mistakes); 52 | for (int i=ann.size()-1;i>=0;i--) { 53 | ann[i]->backwards(*interfaces[i],*interfaces[i+1],learningRate); 54 | } 55 | } 56 | else if (batch.i.type==TESTBATCH) 57 | { 58 | SoftmaxClassifier(*interfaces.back(),nTop,batch.labels,batch.mistakes); 59 | } 60 | else if (batch.i.type==UNLABELLEDBATCH) 61 | { 62 | ofstream f; 63 | f.open("predictions.labels", ios::app); 64 | vector > predictions=SoftmaxClassifier(*interfaces.back(),nTop); 65 | for (int j=0;j(epoch)+string(".ann"); 77 | ifstream f; 78 | f.open(filename.c_str(),ios::out | ios::binary); 79 | if (f) { 80 | cout << "Loading network parameters from " << filename << endl; 81 | } else { 82 | cout <<"Cannot find " << filename << endl; 83 | exit(EXIT_FAILURE); 84 | } 85 | for (int i=0;iloadWeightsFromStream(f); 87 | f.close(); 88 | } 89 | void saveWeights(string baseName, int epoch) { 90 | string filename=string(baseName)+string("_epoch-")+boost::lexical_cast(epoch)+string(".ann"); 91 | ofstream f; 92 | f.open(filename.c_str(),ios::binary); 93 | if (f) { 94 | for (int i=0;iputWeightsToStream(f); 96 | f.close(); 97 | } 
else { 98 | cout <<"Cannot write " << filename << endl; 99 | exit(EXIT_FAILURE); 100 | } 101 | } 102 | }; 103 | 104 | 105 | 106 | float iterate(ANN& ann, Dataset &dataset, int batchSize=128, float learningRate=0.1, bool verbose=false) { 107 | float errorRate=0; 108 | BatchProducer bp(dataset,batchSize,ann.inputDropout); 109 | if (dataset.type==UNLABELLEDBATCH) { 110 | ofstream f; 111 | f.open("predictions.labels"); 112 | } 113 | int ctr=0; 114 | while(Batch* batch=bp.nextBatch()) { 115 | ann.processBatch(*batch,learningRate); 116 | ctr++; 117 | if (verbose and !(ctr & (ctr-1))) 118 | cout << ctr << " " << batch->mistakes << endl; 119 | errorRate+=batch->mistakes*1.0/dataset.samples.size(); 120 | delete batch; 121 | } 122 | cout << dataset.name 123 | << " Mistakes: " 124 | << 100*errorRate 125 | << "%" 126 | << endl; 127 | return errorRate; 128 | } 129 | -------------------------------------------------------------------------------- /BatchProducer.h: -------------------------------------------------------------------------------- 1 | class BatchProducer { 2 | public: 3 | int batchCounter; 4 | boost::thread_group workers; 5 | Dataset& dataset; 6 | vector v; 7 | int batchSize; 8 | int nThreads; 9 | float inputDropout; 10 | 11 | Batch* nextBatch() { 12 | if (batchCounterbatchCounter+5*nThreads) 25 | boost::this_thread::sleep(boost::posix_time::milliseconds(10)); 26 | Batch* batch = 27 | new Batch(dataset.type, dataset.nFeatures, inputDropout); 28 | for (int i=c*batchSize;idistort(rng); 31 | pic->codifyInputData(*batch); 32 | delete pic; 33 | } else { 34 | dataset.samples[i]->codifyInputData(*batch); 35 | } 36 | } 37 | v[c]=batch; 38 | } 39 | } 40 | BatchProducer (Dataset &dataset, int batchSize, float inputDropout=0.0f, int nThreads=4) : 41 | batchCounter(0), dataset(dataset), batchSize(batchSize), inputDropout(inputDropout), nThreads(nThreads) { 42 | v.resize((dataset.samples.size()+batchSize-1)/batchSize,NULL); 43 | for (int nThread=0; nThread W; //Weights 5 | vectorCUDA mW; //momentum 6 | vectorCUDA w; //shrunk versions 7 | vectorCUDA dw; //For backprop 8 | vectorCUDA B; //Weights 9 | vectorCUDA mB; //momentum 10 | vectorCUDA b; //shrunk versions 11 | vectorCUDA db; //For backprop 12 | ActivationFunction fn; 13 | public: 14 | int nFeaturesIn; 15 | int nFeaturesOut; 16 | float dropout; 17 | BatchWiseDropoutLayer(int nFeaturesIn, int nFeaturesOut, 18 | float dropout=0,ActivationFunction fn=NOSIGMOID) : 19 | nFeaturesIn(nFeaturesIn), nFeaturesOut(nFeaturesOut), 20 | dropout(dropout), fn(fn) { 21 | float scale=0; 22 | if (fn!=SOFTMAX) 23 | scale=powf(nFeaturesIn,-0.5); 24 | W.resize (nFeaturesIn*nFeaturesOut); W.setUniform(-scale,scale); 25 | mW.resize (nFeaturesIn*nFeaturesOut); mW.setZero(); 26 | B.resize (nFeaturesOut); B.setZero(); 27 | mB.resize (nFeaturesOut); mB.setZero(); 28 | 29 | } 30 | void forwards 31 | (BatchInterface &input, 32 | BatchInterface &output) { 33 | output.type=input.type; 34 | output.batchSize=input.batchSize; 35 | output.nFeatures=nFeaturesOut; 36 | int o=nFeaturesOut*(input.type==TRAINBATCH?(1.0f-dropout):1.0f); 37 | output.featuresPresent.hVector()=rng.NchooseM(nFeaturesOut,o); 38 | assert(input.nFeatures==nFeaturesIn); 39 | output.features.resize(output.batchSize*output.featuresPresent.size()); 40 | 41 | if (input.type==TRAINBATCH) 42 | output.dfeatures.resize(output.batchSize*output.featuresPresent.size()); 43 | 44 | if (input.type==TRAINBATCH and nFeaturesIn+nFeaturesOut>input.featuresPresent.size()+output.featuresPresent.size()) { 45 | 
w.resize(input.featuresPresent.size()*output.featuresPresent.size()); 46 | b.resize(output.featuresPresent.size()); 47 | dShrinkMatrixForDropout 48 | <<>> 49 | (W.dPtr(), w.dPtr(), 50 | input.featuresPresent.dPtr(), 51 | output.featuresPresent.dPtr(), 52 | output.nFeatures, 53 | output.featuresPresent.size()); 54 | dShrinkVectorForDropout<<<1,NTHREADS>>>(B.dPtr(), b.dPtr(), 55 | output.featuresPresent.dPtr(), 56 | output.nFeatures, 57 | output.featuresPresent.size()); 58 | cudaCheckError(); 59 | replicateArray(b.dPtr(), output.features.dPtr(), output.batchSize, output.featuresPresent.size()); 60 | d_rowMajorSGEMM_alphaAB_betaC(cublasHandle, 61 | input.features.dPtr(), w.dPtr(), output.features.dPtr(), 62 | output.batchSize, input.featuresPresent.size(), output.featuresPresent.size(), 63 | 1.0f, 1.0f,__FILE__,__LINE__); 64 | cudaCheckError(); 65 | 66 | } else { 67 | replicateArray(B.dPtr(), output.features.dPtr(), output.batchSize, output.featuresPresent.size()); 68 | d_rowMajorSGEMM_alphaAB_betaC(cublasHandle, 69 | input.features.dPtr(), W.dPtr(), output.features.dPtr(), 70 | output.batchSize, input.nFeatures, output.nFeatures, 71 | 1.0f-dropout, 1.0f-dropout,__FILE__,__LINE__); 72 | cudaCheckError(); 73 | } 74 | applySigmoid(output, output, fn); 75 | } 76 | void backwards(BatchInterface &input, 77 | BatchInterface &output, 78 | float learningRate=0.1) { 79 | applySigmoidBackProp(output, output, fn); 80 | 81 | dw.resize(input.featuresPresent.size()*output.featuresPresent.size()); 82 | db.resize(output.featuresPresent.size()); 83 | 84 | d_rowMajorSGEMM_alphaAtB_betaC(cublasHandle, 85 | input.features.dPtr(), output.dfeatures.dPtr(), dw.dPtr(), 86 | input.featuresPresent.size(), output.batchSize, output.featuresPresent.size(), 87 | 1.0, 0.0); 88 | db.setZero(); 89 | columnSum(output.dfeatures.dPtr(), db.dPtr(), output.batchSize, output.featuresPresent.size()); 90 | cudaCheckError(); 91 | 92 | if (nFeaturesIn+nFeaturesOut>input.featuresPresent.size()+output.featuresPresent.size()) { 93 | if (input.dfeatures.size()>0) { 94 | d_rowMajorSGEMM_alphaABt_betaC(cublasHandle, 95 | output.dfeatures.dPtr(), w.dPtr(), input.dfeatures.dPtr(), 96 | output.batchSize,output.featuresPresent.size(),input.featuresPresent.size(), 97 | 1.0, 0.0); 98 | cudaCheckError(); 99 | } 100 | 101 | dGradientDescentMatrixNAGlite<<>> 102 | (dw.dPtr(), mW.dPtr(), W.dPtr(), 103 | output.nFeatures, output.featuresPresent.size(), 104 | input.featuresPresent.dPtr(), output.featuresPresent.dPtr(), 105 | learningRate); 106 | 107 | dGradientDescentVectorNAGlite<<<1,NTHREADS>>> 108 | (db.dPtr(), mB.dPtr(), B.dPtr(), 109 | output.nFeatures, output.featuresPresent.size(), 110 | output.featuresPresent.dPtr(), 111 | learningRate); 112 | } else { 113 | if (input.dfeatures.size()>0) { 114 | d_rowMajorSGEMM_alphaABt_betaC(cublasHandle, 115 | output.dfeatures.dPtr(), W.dPtr(), input.dfeatures.dPtr(), 116 | output.batchSize,nFeaturesOut,nFeaturesIn, 117 | 1.0, 0.0); 118 | cudaCheckError(); 119 | } 120 | dGradientDescentNAG<<>> 121 | (dw.dPtr(), mW.dPtr(), W.dPtr(), nFeaturesOut, learningRate); 122 | dGradientDescentNAG<<<1,KERNELBLOCKSIZE>>> 123 | (db.dPtr(), mB.dPtr(), B.dPtr(), nFeaturesOut, learningRate); 124 | cudaCheckError(); 125 | } 126 | } 127 | void loadWeightsFromStream(ifstream &f) { 128 | f.read((char*)&W.hVector()[0],sizeof(float)*W.size()); 129 | f.read((char*)&B.hVector()[0],sizeof(float)*B.size()); 130 | }; 131 | void putWeightsToStream(ofstream &f) { 132 | f.write((char*)&W.hVector()[0],sizeof(float)*W.size()); 133 | 
f.write((char*)&B.hVector()[0],sizeof(float)*B.size()); 134 | }; 135 | }; 136 | -------------------------------------------------------------------------------- /Batches.h: -------------------------------------------------------------------------------- 1 | enum batchType {TRAINBATCH, TESTBATCH, UNLABELLEDBATCH}; 2 | 3 | 4 | class BatchInterface { 5 | public: 6 | batchType type; 7 | int batchSize; // Number of training/test samples 8 | int nFeatures; // Features per sample 9 | vectorCUDA features; // For the forwards pass 10 | vectorCUDA dfeatures; // For the backwards/backpropagation pass 11 | vectorCUDA featuresPresent; // For batchwise dropout - rng.NchooseM(nFeatures,featuresPresent.size()); 12 | // Not dropped out features 13 | vectorCUDA dropoutMask; // For SampleWiseDropoutLayers 14 | void summary() { 15 | cout << "---------------------------------------------------\n"; 16 | cout << "type" << type << endl; 17 | cout << "batchSize" << batchSize << endl; 18 | cout << "nFeatures" << nFeatures << endl; 19 | cout << "featuresPresent.size()" << featuresPresent.size() < labels; 30 | int mistakes; 31 | float inputDropout; //Use for batch-wise dropout during testing 32 | Batch(batchType type, int nFeatures, float inputDrop=0.0f) { 33 | i.type=type; 34 | i.batchSize=0; 35 | i.nFeatures=nFeatures; 36 | if (type==TRAINBATCH) { 37 | RNG rng; 38 | i.featuresPresent.hVector()=rng.NchooseM(nFeatures,nFeatures*(1-inputDrop)); 39 | inputDropout=0; 40 | } else { 41 | i.featuresPresent.hVector()=range(nFeatures); 42 | inputDropout=inputDrop; 43 | } 44 | mistakes=0; 45 | } 46 | }; 47 | -------------------------------------------------------------------------------- /CudaUtils.h: -------------------------------------------------------------------------------- 1 | // https://gist.github.com/ashwin/2652488#file-cudaerrorcheck-cu 2 | // Define this to turn on error checking 3 | //#define CUDA_ERROR_CHECK 4 | 5 | #define cudaSafeCall( err ) __cudaSafeCall( err, __FILE__, __LINE__ ) 6 | #define cudaCheckError() { __cudaCheckError( __FILE__, __LINE__ ); } 7 | 8 | inline void __cudaSafeCall( cudaError err, const char *file, const int line ) 9 | { 10 | #ifdef CUDA_ERROR_CHECK 11 | if ( cudaSuccess != err ) 12 | { 13 | fprintf( stderr, "cudaSafeCall() failed at %s:%i : %s\n", 14 | file, line, cudaGetErrorString( err ) ); 15 | exit( -1 ); 16 | } 17 | #endif 18 | return; 19 | } 20 | 21 | inline void __cudaCheckError( const char *file, const int line ) 22 | { 23 | #ifdef CUDA_ERROR_CHECK 24 | cudaError err = cudaGetLastError(); 25 | if ( cudaSuccess != err ) 26 | { 27 | fprintf( stderr, "cudaCheckError() failed at %s:%i : %s\n", 28 | file, line, cudaGetErrorString( err ) ); 29 | exit( -1 ); 30 | } 31 | 32 | // More careful checking. However, this will affect performance. 33 | // Comment away if needed. 
34 | err = cudaDeviceSynchronize(); 35 | if( cudaSuccess != err ) 36 | { 37 | fprintf( stderr, "cudaCheckError() with sync failed at %s:%i : %s\n", 38 | file, line, cudaGetErrorString( err ) ); 39 | exit( -1 ); 40 | } 41 | #endif 42 | 43 | return; 44 | } 45 | 46 | cublasHandle_t cublasHandle; 47 | #define NTHREADS 512 48 | #define KERNELBLOCKSIZE 32 49 | 50 | static void cublasError(cublasStatus_t error,const char* file = 0, int linenumber = 0) 51 | { 52 | switch (error) 53 | { 54 | case CUBLAS_STATUS_SUCCESS: 55 | break; 56 | 57 | case CUBLAS_STATUS_NOT_INITIALIZED: 58 | cout << file << " " << linenumber< class vectorCUDA { 165 | private: 166 | t* d_vec; 167 | int dsize; //When on GPU 168 | std::vector vec; 169 | public: 170 | bool onGPU; 171 | void copyToCPU() { 172 | if (onGPU) { 173 | onGPU=false; 174 | if (dsize>0) { 175 | vec.resize(dsize); 176 | cudaSafeCall(cudaMemcpy(&vec[0],d_vec,sizeof(t)*dsize,cudaMemcpyDeviceToHost)); 177 | cudaSafeCall(cudaFree(d_vec)); 178 | } 179 | } 180 | } 181 | void copyToGPU() { 182 | if (!onGPU) { 183 | onGPU=true; 184 | if (vec.size()>0) { 185 | dsize=vec.size(); 186 | cudaSafeCall(cudaMalloc((void**) &d_vec, sizeof(t)*dsize)); 187 | cudaSafeCall(cudaMemcpy(d_vec,&vec[0],sizeof(t)*dsize,cudaMemcpyHostToDevice)); 188 | vec.clear(); 189 | } else { 190 | dsize=0; 191 | } 192 | } 193 | } 194 | void copyToGPU(cudaStream_t stream) { 195 | if (!onGPU) { 196 | onGPU=true; 197 | if (vec.size()>0) { 198 | dsize=vec.size(); 199 | cudaSafeCall(cudaMalloc((void**) &d_vec, sizeof(t)*dsize)); 200 | cudaSafeCall(cudaMemcpyAsync(d_vec,&vec[0],sizeof(t)*dsize,cudaMemcpyHostToDevice,stream)); 201 | vec.clear(); 202 | } 203 | } 204 | } 205 | t*& dPtr() { 206 | copyToGPU(); 207 | return d_vec; 208 | } 209 | vector& hVector() { 210 | copyToCPU(); 211 | return vec; 212 | } 213 | int size() { 214 | if (onGPU) return dsize; 215 | return vec.size(); 216 | } 217 | float meanAbs() { 218 | float total=0; 219 | for (int i=0;i0) 258 | cudaSafeCall(cudaFree(d_vec)); 259 | if (n>0) 260 | cudaSafeCall(cudaMalloc((void**) &d_vec, sizeof(t)*n)); 261 | dsize=n; 262 | } 263 | } else { 264 | vec.resize(n); 265 | } 266 | } 267 | vectorCUDA(bool onGPU=true, int dsize=0) : onGPU(onGPU), dsize(dsize) { 268 | if (onGPU && dsize>0) { 269 | cudaSafeCall(cudaMalloc((void**) &d_vec, sizeof(t)*dsize)); 270 | } else { 271 | vec.resize(dsize); 272 | } 273 | } 274 | ~vectorCUDA() { 275 | if (onGPU && dsize>0) 276 | cudaSafeCall(cudaFree(d_vec)); 277 | } 278 | void printSubset(const char *name, int nCol,int maxPrint=10) { 279 | RNG rng; 280 | copyToCPU(); 281 | int nRow=vec.size()/nCol; 282 | cout << name << " " << nRow << " " << nCol << endl; 283 | vector rr=rng.NchooseM(nRow,min(maxPrint,nRow)); 284 | vector rc=rng.NchooseM(nCol,min(maxPrint,nCol)); 285 | for (int i=0;i range(int n) { 298 | vector ret(n); 299 | for (int i=0; i1) { 330 | biases[blockIdx.x]/=acc; 331 | for(int i=blockIdx.x; i0) 400 | dColumnSum<<>>(matrix, target, nRows, nColumns); 401 | if (nColumns%KERNELBLOCKSIZE>0) { 402 | int o=nColumns/KERNELBLOCKSIZE*KERNELBLOCKSIZE; 403 | dColumnSum<<>>(matrix+o, target+o, nRows, nColumns); 404 | } 405 | cudaCheckError(); 406 | } 407 | 408 | __global__ void dReplicateArray 409 | (float* src, float* dst, int nColumns) { 410 | int i=blockIdx.x*nColumns; 411 | for (int j=threadIdx.x;j>> 419 | (src, dst+processed*nColumns, nColumns); 420 | processed+=batch; 421 | } 422 | cudaCheckError(); 423 | } 424 | -------------------------------------------------------------------------------- 
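The layers in this repository lean on two of the helpers above: replicateArray writes the bias row into every row of the output matrix before the SGEMM adds the weighted inputs, and columnSum accumulates per-column gradients for the bias update (the db/dB vectors are zeroed first). The following is a minimal CPU sketch of what those two kernels compute; the host* names and the tiny main() are illustrative only and are not part of this repository.

#include <cstdio>
#include <vector>

// Copy the nColumns-long bias row 'src' into each of the nRows rows of 'dst'
// (the effect of the dReplicateArray kernel launched by replicateArray).
void hostReplicateArray(const float* src, float* dst, int nRows, int nColumns) {
  for (int r = 0; r < nRows; ++r)
    for (int c = 0; c < nColumns; ++c)
      dst[r * nColumns + c] = src[c];
}

// Accumulate each column of the nRows x nColumns matrix into 'target'
// (the effect of the dColumnSum kernel launched by columnSum; used for the bias gradient).
void hostColumnSum(const float* matrix, float* target, int nRows, int nColumns) {
  for (int c = 0; c < nColumns; ++c) {
    float s = 0;
    for (int r = 0; r < nRows; ++r)
      s += matrix[r * nColumns + c];
    target[c] += s;
  }
}

int main() {
  const int rows = 3, cols = 4;
  std::vector<float> bias = {0.1f, 0.2f, 0.3f, 0.4f};
  std::vector<float> out(rows * cols, 0.0f);
  hostReplicateArray(bias.data(), out.data(), rows, cols); // every row of 'out' now holds the bias
  std::vector<float> colSums(cols, 0.0f);
  hostColumnSum(out.data(), colSums.data(), rows, cols);   // colSums[c] == rows * bias[c]
  for (int c = 0; c < cols; ++c) printf("%.2f ", colSums[c]);
  printf("\n");
  return 0;
}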
/Data/Artificial/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | !artificial.dataset.py 6 | -------------------------------------------------------------------------------- /Data/Artificial/artificial.dataset.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | numpy.random.seed(31415927) 3 | def randomWalk(n,l): 4 | """Random walk in the hypercube {0,1}^n with nearest neighbor edges. Walk of length l-1, so l points in total.""" 5 | a=numpy.zeros((l,n),dtype="float32") 6 | x=numpy.random.binomial(1,0.5,n) 7 | a[0]=x 8 | for j in xrange(1,l): 9 | i=numpy.random.randint(0,n) 10 | x[i]=1-x[i] 11 | a[j]=x 12 | return a 13 | def randomWalkSampleNoisy(rw,N): 14 | a=rw[numpy.random.randint(0,rw.shape[0],N)] 15 | b=numpy.random.binomial(1,0.4,(N,rw.shape[1])) 16 | return (a+b)%2 17 | 18 | nTrain=100000 19 | nTest=10000 20 | dims=1000 21 | walkLengths=1000 22 | nWalksPerClass=1 23 | nClasses=100 24 | fTrain=open("artificial.train.data","wb") 25 | fTest=open("artificial.test.data","wb") 26 | for cl in range(nClasses): 27 | for w in range(nWalksPerClass): 28 | rw=randomWalk(dims,walkLengths) 29 | s=randomWalkSampleNoisy(rw,nTrain/nWalksPerClass/nClasses) 30 | for i in range(s.shape[0]): 31 | numpy.array([cl],dtype="uint8").tofile(fTrain) 32 | s[i].astype("uint8").tofile(fTrain) 33 | s=randomWalkSampleNoisy(rw,nTest/nWalksPerClass/nClasses) 34 | for i in range(s.shape[0]): 35 | numpy.array([cl],dtype="uint8").tofile(fTest) 36 | s[i].astype("uint8").tofile(fTest) 37 | fTrain.close() 38 | fTest.close() 39 | -------------------------------------------------------------------------------- /Data/CIFAR10/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /Data/MNIST/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /Dataset.h: -------------------------------------------------------------------------------- 1 | class Datum { 2 | public: 3 | virtual void codifyInputData (Batch &batch)=0; 4 | virtual Datum* distort (RNG& rng) =0; 5 | int label; //-1 for unknown 6 | virtual ~Datum() {} 7 | }; 8 | 9 | 10 | class Dataset { 11 | public: 12 | string name; 13 | vector samples; 14 | int nFeatures; 15 | int nClasses; 16 | batchType type; 17 | void shuffle() { 18 | random_shuffle ( samples.begin(), samples.end()); 19 | } 20 | void summary() { 21 | cout << "Name: " << name << endl; 22 | cout << "nSamples: " << samples.size() << endl; 23 | cout << "nClasses: " << nClasses << endl; 24 | cout << "nFeatures: " << nFeatures << endl; 25 | } 26 | Dataset extractValidationSet(float p=0.1) { 27 | Dataset val; 28 | val.name=name+string(" Validation set"); 29 | name=name+string(" minus Validation set"); 30 | val.nFeatures=nFeatures; 31 | val.nClasses=nClasses; 32 | val.type=TESTBATCH; 33 | shuffle(); 34 | int size=samples.size()*p; 35 | for (;size>0;size--) { 36 | val.samples.push_back(samples.back()); 37 | samples.pop_back(); 38 | } 39 | return val; 40 | } 41 | Dataset subset(int n) { 42 | Dataset subset(*this); 43 | 
subset.shuffle(); 44 | subset.samples.resize(n); 45 | return subset; 46 | } 47 | }; 48 | 49 | class vectorDatum : public Datum { 50 | public: 51 | vector features; 52 | void codifyInputData (Batch &batch) { 53 | for (int i=0;i features; 73 | void codifyInputData (Batch &batch) { 74 | for (int i=0;ifeatures[c*1024+y*32+x]= 101 | (xx>=0 and xx<32 and yy>=0 and yy<32)?features[c*1024+yy*32+xx]:0; 102 | } 103 | return a; 104 | } 105 | ~vectorDatum32_24() {} 106 | }; 107 | -------------------------------------------------------------------------------- /Layer.h: -------------------------------------------------------------------------------- 1 | class Layer { 2 | public: 3 | virtual void forwards 4 | (BatchInterface &input, BatchInterface &output) = 0; 5 | virtual void backwards 6 | (BatchInterface &input, BatchInterface &output, float learningRate=0.1) = 0; 7 | virtual void loadWeightsFromStream(ifstream &f) {}; 8 | virtual void putWeightsToStream(ofstream &f) {}; 9 | }; 10 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | cifar10: 2 | echo "Please put .bin files from http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz in Data/CIFAR10/" 3 | nvcc -o DropOutNet runCifar10.cu -lrt -lcublas -lboost_thread -lboost_system -arch sm_20 -O2 4 | DropOutNet 5 | mnist: 6 | echo "Please put http://yann.lecun.com/exdb/mnist/ ubyte files in Data/MNIST" 7 | nvcc -o DropOutNet runMnist.cu -lrt -lcublas -lboost_thread -lboost_system -arch sm_20 -O2 8 | DropOutNet 9 | artifical: 10 | cd Data/Artificial/; python artificial.dataset.py 11 | nvcc -o DropOutNet runArtificial.cu -lrt -lcublas -lboost_thread -lboost_system -arch sm_20 -O2 12 | DropOutNet 13 | -------------------------------------------------------------------------------- /NetworkArchitectures.h: -------------------------------------------------------------------------------- 1 | class SimpleNet : public ANN { 2 | public: 3 | SimpleNet(int nInputFeatures, int nFeatures, int nHiddenLayers, int nClasses, ActivationFunction fn, int cudaDevice=0, int nTop=1) : ANN(nInputFeatures, nClasses, cudaDevice, 0.0f, nTop) { 4 | for (int i=0;i nd(mean, sd); 25 | boost::variate_generator > var_nor(gen, nd); 27 | return mean+sd*var_nor(); 28 | } 29 | int bernoulli(float p) { 30 | if (uniform() 36 | int index(std::vector &v) { 37 | if (v.size()==0) std::cout << "RNG::index called for empty std::vector!\n"; 38 | return gen()%v.size(); 39 | } 40 | std::vector zNchooseM(int n, int m) { 41 | std::vector ret; 42 | for(int i=0;i NchooseM(int n, int m) { 47 | std::vector ret(m); 48 | int ctr=m; 49 | for(int i=0;i permutation(int n) { 54 | std::vector ret; 55 | for (int i=0;i 2 | #define CURANDBLOCKS 32 3 | __global__ void dCurandInit(curandState *d_state) { 4 | curand_init(blockIdx.x*KERNELBLOCKSIZE+threadIdx.x,0,0,&d_state[blockIdx.x*KERNELBLOCKSIZE+threadIdx.x]); 5 | } 6 | __global__ void dCurandBernoulliDrop 7 | (curandState *d_state, bool update, float* d, float p, int N, int n) { 8 | curandState state=d_state[blockIdx.x*KERNELBLOCKSIZE+threadIdx.x]; 9 | for (int i=blockIdx.x;i>>(d_state); 22 | } 23 | void drop(float* d, float p, int N, int n) { 24 | dCurandBernoulliDrop<<>>(d_state,false,d,p,N,n); 25 | } 26 | void dropD(float* d, float p, int N, int n) { 27 | dCurandBernoulliDrop<<>>(d_state,true,d,p,N,n); 28 | } 29 | }; 30 | /////////////////////////////////////////////////////////////////////////////////////////////////// 31 | 32 | class 
SampleWiseDropoutLayer : public Layer { 33 | private: 34 | RNG rng; 35 | vectorCUDA W; //Weights 36 | vectorCUDA mW; //momentum 37 | vectorCUDA dW; //For backprop 38 | vectorCUDA B; //Weights 39 | vectorCUDA mB; //momentum 40 | vectorCUDA dB; //For backprop 41 | ActivationFunction fn; 42 | public: 43 | int nFeaturesIn; 44 | int nFeaturesOut; 45 | float dropout; 46 | curandDropout cd; 47 | SampleWiseDropoutLayer(int nFeaturesIn, int nFeaturesOut, 48 | float dropout=0,ActivationFunction fn=NOSIGMOID) : 49 | nFeaturesIn(nFeaturesIn), nFeaturesOut(nFeaturesOut), 50 | dropout(dropout), fn(fn) { 51 | float scale=0; 52 | if (fn!=SOFTMAX) 53 | scale=powf(nFeaturesIn,-0.5); 54 | W.resize (nFeaturesIn*nFeaturesOut); W.setUniform(-scale,scale); 55 | mW.resize (nFeaturesIn*nFeaturesOut); mW.setZero(); 56 | B.resize (nFeaturesOut); B.setZero(); 57 | mB.resize (nFeaturesOut); mB.setZero(); 58 | 59 | } 60 | void forwards 61 | (BatchInterface &input, 62 | BatchInterface &output) { 63 | assert(input.nFeatures==nFeaturesIn); 64 | assert(input.featuresPresent.size()==nFeaturesIn); 65 | output.type=input.type; 66 | output.batchSize=input.batchSize; 67 | output.nFeatures=nFeaturesOut; 68 | output.featuresPresent.hVector()=range(nFeaturesOut); 69 | output.features.resize(output.batchSize*nFeaturesOut); 70 | replicateArray(B.dPtr(), output.features.dPtr(), output.batchSize, nFeaturesOut); 71 | 72 | if (input.type==TRAINBATCH) { 73 | output.dfeatures.resize(output.batchSize*nFeaturesOut); 74 | cd.drop(input.features.dPtr(),dropout,input.batchSize,nFeaturesIn); 75 | d_rowMajorSGEMM_alphaAB_betaC(cublasHandle, 76 | input.features.dPtr(), W.dPtr(), output.features.dPtr(), 77 | output.batchSize, nFeaturesIn, nFeaturesOut, 78 | 1.0f, 1.0f,__FILE__,__LINE__); 79 | 80 | } else { 81 | d_rowMajorSGEMM_alphaAB_betaC(cublasHandle, 82 | input.features.dPtr(), W.dPtr(), output.features.dPtr(), 83 | output.batchSize, nFeaturesIn, nFeaturesOut, 84 | 1.0f-dropout, 1.0f,__FILE__,__LINE__); 85 | } 86 | cudaCheckError(); 87 | applySigmoid(output, output, fn); 88 | } 89 | void backwards(BatchInterface &input, 90 | BatchInterface &output, 91 | float learningRate=0.1) { 92 | applySigmoidBackProp(output, output, fn); 93 | 94 | dW.resize(nFeaturesIn*nFeaturesOut); 95 | dB.resize(nFeaturesOut); 96 | 97 | d_rowMajorSGEMM_alphaAtB_betaC(cublasHandle, 98 | input.features.dPtr(), output.dfeatures.dPtr(), dW.dPtr(), 99 | nFeaturesIn, output.batchSize, nFeaturesOut, 100 | 1.0, 0.0); 101 | dB.setZero(); 102 | columnSum(output.dfeatures.dPtr(), dB.dPtr(), output.batchSize, nFeaturesOut); 103 | cudaCheckError(); 104 | 105 | if (input.dfeatures.size()>0) 106 | d_rowMajorSGEMM_alphaABt_betaC(cublasHandle, 107 | output.dfeatures.dPtr(), W.dPtr(), input.dfeatures.dPtr(), 108 | output.batchSize,nFeaturesOut,nFeaturesIn, 109 | 1.0, 0.0); 110 | cudaCheckError(); 111 | 112 | dGradientDescentNAG<<>> 113 | (dW.dPtr(), mW.dPtr(), W.dPtr(), nFeaturesOut, learningRate); 114 | dGradientDescentNAG<<<1,KERNELBLOCKSIZE>>> 115 | (dB.dPtr(), mB.dPtr(), B.dPtr(), nFeaturesOut, learningRate); 116 | cudaCheckError(); 117 | if (input.dfeatures.size()>0) 118 | cd.dropD(input.dfeatures.dPtr(),dropout,input.batchSize,nFeaturesIn); 119 | } 120 | void loadWeightsFromStream(ifstream &f) { 121 | f.read((char*)&W.hVector()[0],sizeof(float)*W.size()); 122 | f.read((char*)&B.hVector()[0],sizeof(float)*B.size()); 123 | }; 124 | void putWeightsToStream(ofstream &f) { 125 | f.write((char*)&W.hVector()[0],sizeof(float)*W.size()); 126 | 
f.write((char*)&B.hVector()[0],sizeof(float)*B.size()); 127 | }; 128 | }; 129 | -------------------------------------------------------------------------------- /SigmoidLayer.h: -------------------------------------------------------------------------------- 1 | // _____ _ _ _ 2 | // / ____(_) (_) | | 3 | // | (___ _ __ _ _ __ ___ ___ _ __| |___ 4 | // \___ \| |/ _` | '_ ` _ \ / _ \| |/ _` / __| 5 | // ____) | | (_| | | | | | | (_) | | (_| \__ \ 6 | // |_____/|_|\__, |_| |_| |_|\___/|_|\__,_|___/ 7 | // __/ | 8 | // |___/ 9 | 10 | enum ActivationFunction {NOSIGMOID, RELU, VLEAKYRELU, LEAKYRELU, LOGISTIC, TANH, SOFTMAX}; 11 | const char *sigmoidNames[] ={ "*" , "ReLU", "VeryLeakyReLU", "LeakyReLU", "LOGISTIC", "Tanh", "Softmax Classification"}; 12 | ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 13 | __global__ void dSigmoidLogistic 14 | (float* a, float* b, int nOut) { 15 | int i=blockIdx.x*nOut; 16 | for (int j=i+threadIdx.x; j>> 25 | (a+processed*nOut, b+processed*nOut, nOut); 26 | processed+=batch; 27 | } 28 | cudaCheckError(); 29 | } 30 | 31 | __global__ void dSigmoidBackpropLogistic 32 | (float* a, float* b, float* da, float* db, int nOut) { 33 | int i=blockIdx.x*nOut; 34 | for (int j=i+threadIdx.x; j>> 43 | (a+processed*nOut, b+processed*nOut, da+processed*nOut, db+processed*nOut, nOut); 44 | processed+=batch; 45 | } 46 | cudaCheckError(); 47 | } 48 | ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 49 | __global__ void dSigmoidTanh 50 | (float* a, float* b, int nOut) { 51 | int i=blockIdx.x*nOut; 52 | for (int j=i+threadIdx.x; j>> 61 | (a+processed*nOut, b+processed*nOut, nOut); 62 | processed+=batch; 63 | } 64 | cudaCheckError(); 65 | } 66 | 67 | __global__ void dSigmoidBackpropTanh 68 | (float* a, float* b, float* da, float* db, int nOut) { 69 | int i=blockIdx.x*nOut; 70 | for (int j=i+threadIdx.x; j>> 79 | (a+processed*nOut, b+processed*nOut, da+processed*nOut, db+processed*nOut, nOut); 80 | processed+=batch; 81 | } 82 | cudaCheckError(); 83 | } 84 | ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 85 | __global__ void dSigmoidReLu 86 | (float* a, float* b, int nOut) { 87 | int i=blockIdx.x*nOut; 88 | for (int j=i+threadIdx.x; j0)?a[j]:0; 90 | } 91 | } 92 | void sigmoidReLU(float* a, float* b, int count, int nOut) { 93 | int processed=0; 94 | while (processed>> 97 | (a+processed*nOut, b+processed*nOut, nOut); 98 | processed+=batch; 99 | } 100 | cudaCheckError(); 101 | } 102 | 103 | __global__ void dSigmoidBackpropReLu 104 | (float* a, float* b, float* da, float* db, int nOut) { 105 | int i=blockIdx.x*nOut; 106 | for (int j=i+threadIdx.x; j0)?db[j]:0; 108 | } 109 | } 110 | void sigmoidBackpropReLU(float* a, float* b, float* da, float* db, int count, int nOut) { 111 | int processed=0; 112 | while (processed>> 115 | (a+processed*nOut, b+processed*nOut, da+processed*nOut, db+processed*nOut, nOut); 116 | processed+=batch; 117 | } 118 | cudaCheckError(); 119 | } 120 | ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 121 | __global__ void dSigmoidLeakyReLu 122 | (float* a, float* b, int nOut, float alpha) { 123 | int i=blockIdx.x*nOut; 124 | for (int j=i+threadIdx.x; j0)?a[j]:(a[j]*alpha); 126 | } 127 | } 128 | void 
sigmoidLeakyReLU(float* a, float* b, int count, int nOut, float alpha=0.01) { 129 | int processed=0; 130 | while (processed>> 133 | (a+processed*nOut, b+processed*nOut, nOut,alpha); 134 | processed+=batch; 135 | } 136 | cudaCheckError(); 137 | } 138 | 139 | __global__ void dSigmoidBackpropLeakyReLu 140 | (float* a, float* b, float* da, float* db, int nOut, float alpha) { 141 | int i=blockIdx.x*nOut; 142 | for (int j=i+threadIdx.x; j0)?db[j]:(db[j]*alpha); 144 | __syncthreads(); 145 | } 146 | } 147 | void sigmoidBackpropLeakyReLU(float* a, float* b, float* da, float* db, int count, int nOut, float alpha=0.01) { 148 | int processed=0; 149 | while (processed>> 152 | (a+processed*nOut, b+processed*nOut, da+processed*nOut, db+processed*nOut, nOut,alpha); 153 | processed+=batch; 154 | } 155 | cudaCheckError(); 156 | } 157 | 158 | 159 | 160 | 161 | //SOFTMAX only occurs at the top layer; 162 | //derivative contained in calculation of initial d_delta. 163 | __global__ void dSigmoidSoftmax(float* a, float* b, int count, int nOut) { 164 | for(int i=threadIdx.x; imx) mx=a[i*nOut+k]; 169 | for (int k=0;k>> (input.features.dPtr(),output.features.dPtr(),output.batchSize,output.nFeatures); break; 224 | } 225 | } 226 | 227 | void applySigmoidBackProp(BatchInterface& input, BatchInterface& output, ActivationFunction fn) { 228 | switch(fn) { 229 | case TANH: 230 | sigmoidBackpropTanh 231 | (input.features.dPtr(),output.features.dPtr(), 232 | input.dfeatures.dPtr(), 233 | output.dfeatures.dPtr(), 234 | output.batchSize, 235 | output.featuresPresent.size()); 236 | break; 237 | case LOGISTIC: 238 | sigmoidBackpropLogistic 239 | (input.features.dPtr(),output.features.dPtr(), 240 | input.dfeatures.dPtr(), 241 | output.dfeatures.dPtr(), 242 | output.batchSize, 243 | output.featuresPresent.size()); 244 | break; 245 | case RELU: 246 | sigmoidBackpropReLU 247 | (input.features.dPtr(),output.features.dPtr(), 248 | input.dfeatures.dPtr(), 249 | output.dfeatures.dPtr(), 250 | output.batchSize, 251 | output.featuresPresent.size()); 252 | break; 253 | case LEAKYRELU: 254 | sigmoidBackpropLeakyReLU 255 | (input.features.dPtr(), 256 | output.features.dPtr(), 257 | input.dfeatures.dPtr(), 258 | output.dfeatures.dPtr(), 259 | output.batchSize, 260 | output.featuresPresent.size(), 261 | 0.01); 262 | break; 263 | case VLEAKYRELU: 264 | sigmoidBackpropLeakyReLU 265 | (input.features.dPtr(), 266 | output.features.dPtr(), 267 | input.dfeatures.dPtr(), 268 | output.dfeatures.dPtr(), 269 | output.batchSize, 270 | output.featuresPresent.size(), 271 | 0.333); 272 | break; 273 | case SOFTMAX: 274 | dSigmoidBackpropSoftmax <<<1,NTHREADS>>> 275 | (input.features.dPtr(),output.features.dPtr(), input.dfeatures.dPtr(),output.dfeatures.dPtr(), output.batchSize, output.nFeatures); break; 276 | } 277 | } 278 | 279 | class SigmoidLayer : public Layer { 280 | public: 281 | ActivationFunction fn; 282 | SigmoidLayer(ActivationFunction fn) : fn(fn) { 283 | cout << sigmoidNames[fn]< W; //Weights 5 | vectorCUDA mW; //momentum 6 | vectorCUDA dW; //For backprop 7 | vectorCUDA B; //Weights 8 | vectorCUDA mB; //momentum 9 | vectorCUDA dB; //For backprop 10 | ActivationFunction fn; 11 | public: 12 | int nFeaturesIn; 13 | int nFeaturesOut; 14 | SimpleLayer(int nFeaturesIn, int nFeaturesOut, 15 | ActivationFunction fn=NOSIGMOID) : 16 | nFeaturesIn(nFeaturesIn), nFeaturesOut(nFeaturesOut), 17 | fn(fn) { 18 | float scale=pow(6.0f/(nFeaturesIn+nFeaturesOut),0.5f); 19 | W.resize (nFeaturesIn*nFeaturesOut); W.setUniform(-scale,scale); 20 | mW.resize 
(nFeaturesIn*nFeaturesOut); mW.setZero(); 21 | B.resize (nFeaturesOut); B.setZero(); 22 | mB.resize (nFeaturesOut); mB.setZero(); 23 | 24 | } 25 | void forwards 26 | (BatchInterface &input, 27 | BatchInterface &output) { 28 | assert(input.nFeatures==nFeaturesIn); 29 | assert(input.featuresPresent.size()==nFeaturesIn); 30 | output.type=input.type; 31 | output.batchSize=input.batchSize; 32 | output.nFeatures=nFeaturesOut; 33 | output.featuresPresent.hVector()=range(nFeaturesOut); 34 | output.features.resize(output.batchSize*nFeaturesOut); 35 | replicateArray(B.dPtr(), output.features.dPtr(), output.batchSize, nFeaturesOut); 36 | d_rowMajorSGEMM_alphaAB_betaC(cublasHandle, 37 | input.features.dPtr(), W.dPtr(), output.features.dPtr(), 38 | output.batchSize, nFeaturesIn, nFeaturesOut, 39 | 1.0f, 1.0f,__FILE__,__LINE__); 40 | cudaCheckError(); 41 | applySigmoid(output, output, fn); 42 | 43 | if (input.type==TRAINBATCH) 44 | output.dfeatures.resize(output.batchSize*output.featuresPresent.size()); 45 | } 46 | void backwards(BatchInterface &input, 47 | BatchInterface &output, 48 | float learningRate=0.1) { 49 | applySigmoidBackProp(output, output, fn); 50 | 51 | dW.resize(nFeaturesIn*nFeaturesOut); 52 | dB.resize(nFeaturesOut); 53 | 54 | d_rowMajorSGEMM_alphaAtB_betaC(cublasHandle, 55 | input.features.dPtr(), output.dfeatures.dPtr(), dW.dPtr(), 56 | nFeaturesIn, output.batchSize, nFeaturesOut, 57 | 1.0, 0.0); 58 | dB.setZero(); 59 | columnSum(output.dfeatures.dPtr(), dB.dPtr(), output.batchSize, nFeaturesOut); 60 | cudaCheckError(); 61 | 62 | if (input.dfeatures.size()>0) 63 | d_rowMajorSGEMM_alphaABt_betaC(cublasHandle, 64 | output.dfeatures.dPtr(), W.dPtr(), input.dfeatures.dPtr(), 65 | output.batchSize,nFeaturesOut,nFeaturesIn, 66 | 1.0, 0.0); 67 | cudaCheckError(); 68 | 69 | dGradientDescentNAG<<>> 70 | (dW.dPtr(), mW.dPtr(), W.dPtr(), nFeaturesOut, learningRate); 71 | cudaCheckError(); 72 | dGradientDescentNAG<<<1,KERNELBLOCKSIZE>>> 73 | (dB.dPtr(), mB.dPtr(), B.dPtr(), nFeaturesOut, learningRate); 74 | cudaCheckError(); 75 | } 76 | void loadWeightsFromStream(ifstream &f) { 77 | f.read((char*)&W.hVector()[0],sizeof(float)*W.size()); 78 | f.read((char*)&B.hVector()[0],sizeof(float)*B.size()); 79 | }; 80 | void putWeightsToStream(ofstream &f) { 81 | f.write((char*)&W.hVector()[0],sizeof(float)*W.size()); 82 | f.write((char*)&B.hVector()[0],sizeof(float)*B.size()); 83 | }; 84 | }; 85 | -------------------------------------------------------------------------------- /SoftmaxClassifier.h: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////////////////////////// 2 | //Calculate softmaxProbability(i) - indicator(i=label) 3 | // for i=0,1,...N-1 with N the number of character classes. 4 | __global__ void dDerivativeOfCostWRTpreSoftmaxTopLevelWeights 5 | (int batchSize, float* topDelta, float* topGrid, 6 | int* labels, int N) { 7 | for (int k=0;kmaxP) { 21 | prediction=k; 22 | maxP=d_probs[i*nOut+k]; 23 | } 24 | } 25 | d_probs[i*nOut+prediction]-=1; 26 | d_predictions[i]=prediction; 27 | } 28 | } 29 | 30 | 31 | //Assume no dropout in the output layer! nClasses:=input.nFeatures. 
32 | vector<vector<int> > SoftmaxClassifier(BatchInterface& input, int nTop) {
33 |   assert(input.nFeatures==input.featuresPresent.size()); //Could bypass this requirement for training batches so long as all training labels present in the batch are a subset of elements in input.featuresPresent; useful if the number of classes is very large ?!?
34 |   vectorCUDA<float> probs(true, input.features.size());
35 |   cudaSafeCall(cudaMemcpy(probs.dPtr(),input.features.dPtr(), input.features.size()*sizeof(float), cudaMemcpyDeviceToDevice));
36 |   vectorCUDA<int> pred(true, nTop*input.batchSize);
37 |   for (int j=0;j<nTop;j++)
38 |     dMaxout<<<1,NTHREADS>>>
39 |       (probs.dPtr(), pred.dPtr()+j*input.batchSize,
40 |        input.batchSize, input.nFeatures);
41 |   cudaCheckError();
42 | 
43 |   vector<vector<int> > predictions(input.batchSize);
44 |   vector<int> &p=pred.hVector();
45 |   for (int i=0;i<input.batchSize;i++)
46 |     for (int j=0;j<nTop;j++)
47 |       predictions[i].push_back(p[i+j*input.batchSize]);
48 |   return predictions;
49 | }
50 | 
51 | 
52 | vector<vector<int> > SoftmaxClassifier(BatchInterface& input, int nTop, vectorCUDA<int> &labels, int& mistakes) {
53 |   vector<vector<int> > predictions=SoftmaxClassifier(input, nTop);
54 | 
55 |   mistakes+=input.batchSize;
56 |   for (int i=0;i<input.batchSize;i++)
57 |     for (int j=0;j<nTop;j++)
58 |       if (predictions[i][j]==labels.hVector()[i])
59 |         mistakes--;
60 | 
61 |   if (input.type==TRAINBATCH) {
62 |     //Begin backprop: d Cost / d (pre-softmax top-level activations)
63 | 
64 |     dDerivativeOfCostWRTpreSoftmaxTopLevelWeights
65 |       <<<1,NTHREADS>>>
66 |       (input.batchSize, input.dfeatures.dPtr(), input.features.dPtr(),
67 |        labels.dPtr(), input.nFeatures);
68 |     cudaCheckError();
69 |   }
70 |   return predictions;
71 | }
72 | 
--------------------------------------------------------------------------------
/dropout.h:
--------------------------------------------------------------------------------
 1 | //Convolutional case: half the forward pass without dropout. finish the forward pass 4 times in parallel with lots of dropout, first half of backprop in parallel. Sum derivatives to finish the backward pass
 2 | //Implement dropout pretraining for ReLU units
 3 | 
 4 | 
 5 | 
 6 | // Ben Graham, University of Warwick, 2014
 7 | // Batch-wise and sample-wise dropout
 8 | 
 9 | // N.B. BatchWiseDropoutLayer applies dropout to the output layer, so output.featuresPresent.size() <= output.nFeatures
10 | // It therefore has to accept input hidden layers with input.featuresPresent.size() <= input.nFeatures
11 | // Other layer-types assume input.featuresPresent.size() == input.nFeatures
12 | // SimpleLayer does no dropout at all.
13 | // SampleWiseDropoutLayer applies dropout to the input layer (by multiplying by zero). The output layer has full size.
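To make the distinction drawn in the comments above concrete, here is a small host-side sketch (plain C++, illustrative only, not part of the library). Batch-wise dropout draws one subset of feature indices per minibatch, shared by every sample, which is what lets BatchWiseDropoutLayer shrink W and B to the kept rows and columns (dShrinkMatrixForDropout) before the SGEMM. Sample-wise dropout draws an independent Bernoulli decision per (sample, feature) pair, so the full weight matrix is used and dropped activations are simply multiplied by zero (dCurandBernoulliDrop).

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

int main() {
  const int batchSize = 4, nFeatures = 10;
  const float dropout = 0.5f;
  std::mt19937 gen(42);

  // Batch-wise: one random subset of nFeatures*(1-dropout) features,
  // shared by the whole minibatch (cf. rng.NchooseM in BatchWiseDropoutLayer::forwards).
  std::vector<int> idx(nFeatures);
  std::iota(idx.begin(), idx.end(), 0);
  std::shuffle(idx.begin(), idx.end(), gen);
  std::vector<int> featuresPresent(idx.begin(), idx.begin() + int(nFeatures * (1 - dropout)));
  std::sort(featuresPresent.begin(), featuresPresent.end());
  printf("batch-wise: all %d samples keep features:", batchSize);
  for (int f : featuresPresent) printf(" %d", f);
  printf("\n");

  // Sample-wise: an independent keep/drop decision per sample and per feature
  // (cf. curandDropout::drop in SampleWiseDropoutLayer, which zeroes dropped inputs).
  std::bernoulli_distribution keep(1.0 - dropout);
  for (int s = 0; s < batchSize; ++s) {
    printf("sample-wise, sample %d keeps:", s);
    for (int f = 0; f < nFeatures; ++f)
      if (keep(gen)) printf(" %d", f);
    printf("\n");
  }
  return 0;
}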
14 | 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include "cuda.h" 37 | #include 38 | using namespace std; 39 | 40 | #include "Rng.h" 41 | #include "CudaUtils.h" 42 | #include "Batches.h" 43 | #include "Layer.h" 44 | #include "SigmoidLayer.h" 45 | #include "BatchWiseDropoutLayer.h" 46 | #include "SampleWiseDropoutLayer.h" 47 | #include "SimpleLayer.h" 48 | #include "SoftmaxClassifier.h" 49 | #include "Dataset.h" 50 | #include "BatchProducer.h" 51 | #include "ANN.h" 52 | #include "NetworkArchitectures.h" 53 | -------------------------------------------------------------------------------- /readArtificialDataset.h: -------------------------------------------------------------------------------- 1 | static void loadData(string filename, vector &characters, int n) { 2 | ifstream f(filename.c_str()); 3 | if (!f) { 4 | cout <<"Cannot find " << filename << endl; 5 | exit(EXIT_FAILURE);} 6 | unsigned char data[1001]; 7 | for (int i=0;ifeatures[j]=data[j+1]*2-1; 12 | characters.push_back(character); 13 | } 14 | } 15 | 16 | Dataset ArtificialTrainSet() { 17 | Dataset dataset; 18 | dataset.name="Artificial train set"; 19 | dataset.type=TRAINBATCH; 20 | dataset.nFeatures=1000; 21 | dataset.nClasses=100; 22 | string train("Data/Artificial/artificial.train.data"); 23 | loadData(train, dataset.samples,100000); 24 | return dataset; 25 | } 26 | Dataset ArtificialTestSet() { 27 | Dataset dataset; 28 | dataset.name="Artificial test set"; 29 | dataset.type=TESTBATCH; 30 | dataset.nFeatures=1000; 31 | dataset.nClasses=100; 32 | string train("Data/Artificial/artificial.test.data"); 33 | loadData(train, dataset.samples,10000); 34 | return dataset; 35 | } 36 | -------------------------------------------------------------------------------- /readCIFAR10.h: -------------------------------------------------------------------------------- 1 | void readCIFAR10File(vector &characters, const char* filename) { 2 | ifstream file(filename,ios::in|ios::binary); 3 | if (!file) { 4 | cout <<"Cannot find " << filename << endl; 5 | exit(EXIT_FAILURE); 6 | } 7 | unsigned char label; 8 | while (file.read((char*)&label,1)) { 9 | vectorDatum* character = new vectorDatum(3072,label); 10 | unsigned char bitmap[3072]; 11 | file.read((char*)bitmap,3072); 12 | for (int x=0;x<3072;x++) { 13 | character->features[x]=bitmap[x]/127.5-1; 14 | } 15 | characters.push_back(character); 16 | } 17 | file.close(); 18 | } 19 | Dataset Cifar10TrainSet() { 20 | Dataset dataset; 21 | dataset.name="CIFAR-10 train set"; 22 | dataset.type=TRAINBATCH; 23 | dataset.nFeatures=3072; 24 | dataset.nClasses=10; 25 | char filenameFormat[]="Data/CIFAR10/data_batch_%d.bin"; 26 | char filename[100]; 27 | for(int fileNumber=1;fileNumber<=5;fileNumber++) { 28 | sprintf(filename,filenameFormat,fileNumber); 29 | readCIFAR10File(dataset.samples,filename); 30 | } 31 | return dataset; 32 | } 33 | Dataset Cifar10TestSet() { 34 | Dataset dataset; 35 | dataset.name="CIFAR-10 test set"; 36 | dataset.type=TESTBATCH; 37 | dataset.nFeatures=3072; 38 | dataset.nClasses=10; 39 | char filenameTest[]="Data/CIFAR10/test_batch.bin"; 40 | readCIFAR10File(dataset.samples,filenameTest); 41 | return dataset; 42 | } 43 | -------------------------------------------------------------------------------- /readMNIST.h: 
-------------------------------------------------------------------------------- 1 | static int intToggleEndianness(int a) { 2 | int b=0; 3 | b+=a%256*(1<<24);a>>=8; 4 | b+=a%256*(1<<16);a>>=8; 5 | b+=a%256*(1<< 8);a>>=8; 6 | b+=a%256*(1<< 0); 7 | return b; 8 | } 9 | 10 | static void loadMnistC(string filename, vector &characters) { 11 | ifstream f(filename.c_str()); 12 | if (!f) { 13 | cout <<"Cannot find " << filename << endl; 14 | exit(EXIT_FAILURE);} 15 | int a,n1,n2,n3; 16 | f.read((char*)&a,4); 17 | f.read((char*)&a,4); 18 | n1=intToggleEndianness(a); 19 | f.read((char*)&a,4); 20 | n2=intToggleEndianness(a); 21 | f.read((char*)&a,4); 22 | n3=intToggleEndianness(a); 23 | unsigned char *bitmap=new unsigned char[n2*n3]; 24 | for (int i1=0;i1features[j]=bitmap[j]/255.0; 29 | characters.push_back(character); 30 | } 31 | delete[] bitmap; 32 | } 33 | 34 | static void loadMnistL(string filename, vector &characters) { 35 | ifstream f(filename.c_str()); 36 | if (!f) { 37 | cout <<"Cannot find " << filename << endl; 38 | exit(EXIT_FAILURE);} 39 | int a,n; 40 | char l; 41 | f.read((char*)&a,4); 42 | f.read((char*)&a,4); 43 | n=intToggleEndianness(a); 44 | for (int i=0;ilabel=l; 47 | } 48 | } 49 | 50 | Dataset MnistTrainSet() { 51 | Dataset dataset; 52 | dataset.name="MNIST train set"; 53 | dataset.type=TRAINBATCH; 54 | dataset.nFeatures=784; 55 | dataset.nClasses=10; 56 | string trainC("Data/MNIST/train-images-idx3-ubyte"); 57 | string trainL("Data/MNIST/train-labels-idx1-ubyte"); 58 | loadMnistC(trainC, dataset.samples); 59 | loadMnistL(trainL, dataset.samples); 60 | return dataset; 61 | } 62 | Dataset MnistTestSet() { 63 | Dataset dataset; 64 | dataset.type=TESTBATCH; 65 | dataset.name="MNIST test set"; 66 | dataset.nFeatures=784; 67 | dataset.nClasses=10; 68 | string testC("Data/MNIST/t10k-images-idx3-ubyte"); 69 | string testL("Data/MNIST/t10k-labels-idx1-ubyte"); 70 | loadMnistC(testC, dataset.samples); 71 | loadMnistL(testL, dataset.samples); 72 | return dataset; 73 | } 74 | -------------------------------------------------------------------------------- /runArtificial.cu: -------------------------------------------------------------------------------- 1 | #include "dropout.h" 2 | #include "readArtificialDataset.h" 3 | 4 | int epoch=0; 5 | int cudaDevice=-1; 6 | 7 | int main() { 8 | Dataset trainSet=ArtificialTrainSet(); 9 | Dataset testSet=ArtificialTestSet(); 10 | 11 | int batchSize=100; 12 | trainSet.summary(); 13 | testSet.summary(); 14 | 15 | //SimpleNet ann(trainSet.nFeatures, 1000, 3, trainSet.nClasses, RELU, cudaDevice); 16 | BatchWiseDropoutNet ann(trainSet.nFeatures, 1000, 3, trainSet.nClasses, RELU, 0.5, 0.5, cudaDevice); 17 | //SampleWiseDropoutNet ann(trainSet.nFeatures, 1000, 3, trainSet.nClasses, RELU, 0.5, 0.5, cudaDevice); 18 | 19 | for (epoch++;epoch<=300;epoch++) { 20 | cout <<"epoch: " << epoch << " " << flush; 21 | trainSet.shuffle(); 22 | iterate(ann, trainSet, batchSize,0.001); 23 | iterate(ann, testSet, batchSize/4); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /runCifar10.cu: -------------------------------------------------------------------------------- 1 | #include "dropout.h" 2 | #include "readCIFAR10.h" 3 | 4 | int cudaDevice=-1; 5 | 6 | int main() { 7 | string baseName="Cifar10"; 8 | 9 | Dataset trainSet=Cifar10TrainSet(); 10 | Dataset testSet=Cifar10TestSet(); 11 | 12 | int batchSize=100; 13 | trainSet.summary(); 14 | testSet.summary(); 15 | 16 | //SimpleNet ann(trainSet.nFeatures, 1000, 3, 
trainSet.nClasses, RELU, cudaDevice);
17 |   BatchWiseDropoutNet ann(trainSet.nFeatures, 1000, 3, trainSet.nClasses, RELU, 0.2, 0.5, cudaDevice);
18 |   //SampleWiseDropoutNet ann(trainSet.nFeatures,4000, 3, trainSet.nClasses, RELU, 0.2, 0.5, cudaDevice);
19 |   for (int epoch=1;epoch<=1000;epoch++) {
20 |     cout <<"epoch: " << epoch <<" " << flush;
21 |     trainSet.shuffle();
22 |     iterate(ann, trainSet, batchSize,0.001);
23 |     if (epoch%10==0)
24 |       iterate(ann, testSet, batchSize/4);
25 | 
26 |   }
27 | }
28 | 
--------------------------------------------------------------------------------
/runMnist.cu:
--------------------------------------------------------------------------------
 1 | #include "dropout.h"
 2 | #include "readMNIST.h"
 3 | 
 4 | int cudaDevice=3;
 5 | 
 6 | int main() {
 7 |   string baseName="/tmp/mnist";
 8 | 
 9 |   Dataset trainSet=MnistTrainSet();
10 |   Dataset testSet=MnistTestSet();
11 | 
12 |   int batchSize=100;
13 |   trainSet.summary();
14 |   testSet.summary();
15 | 
16 |   //SimpleNet ann(trainSet.nFeatures, 800, 2, trainSet.nClasses, RELU, cudaDevice);
17 |   BatchWiseDropoutNet ann(trainSet.nFeatures, 800, 2, trainSet.nClasses, RELU, 0.2, 0.5, cudaDevice);
18 |   //SampleWiseDropoutNet ann(trainSet.nFeatures, 800, 2, trainSet.nClasses, RELU, 0.2, 0.5, cudaDevice);
19 | 
20 |   for (int epoch=1;epoch<=200;epoch++) {
21 |     cout <<"epoch: " << epoch << " " << flush;
22 |     trainSet.shuffle();
23 |     iterate(ann, trainSet, batchSize,0.01*exp(-epoch*0.01));
24 |     if (epoch%10==0)
25 |       iterate(ann, testSet, batchSize/4);
26 |   }
27 | }
28 | 
--------------------------------------------------------------------------------
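For reference, the three drivers differ only in the dataset readers and the layer sizes. Below is a sketch of a driver with a held-out validation split, assuming the BatchWiseDropoutNet constructor takes the same arguments as in runMnist.cu and runCifar10.cu (input features, hidden width, number of hidden layers, classes, activation, input dropout, hidden dropout, CUDA device) and that saveWeights behaves as declared in ANN.h; the epoch count, device number and file name here are made up for illustration.

#include "dropout.h"
#include "readMNIST.h"   // any of the readers works; MNIST is used for concreteness

int main() {
  Dataset trainSet = MnistTrainSet();
  Dataset testSet  = MnistTestSet();
  Dataset validationSet = trainSet.extractValidationSet(0.1);  // optional hold-out split

  int batchSize = 100;
  trainSet.summary();

  // 784 -> 800 -> 800 -> 10 with 20% input dropout and 50% hidden dropout.
  BatchWiseDropoutNet ann(trainSet.nFeatures, 800, 2, trainSet.nClasses,
                          RELU, 0.2, 0.5, 0 /*cudaDevice*/);

  for (int epoch = 1; epoch <= 100; epoch++) {
    cout << "epoch: " << epoch << " " << flush;
    trainSet.shuffle();
    iterate(ann, trainSet, batchSize, 0.01 * exp(-epoch * 0.01));  // decaying learning rate
    iterate(ann, validationSet, batchSize / 4);                    // monitor held-out error
    if (epoch % 10 == 0)
      ann.saveWeights("mnist_bwd", epoch);  // writes mnist_bwd_epoch-<epoch>.ann
  }
  iterate(ann, testSet, batchSize / 4);
  return 0;
}

Since saveWeights writes baseName_epoch-N.ann, a later run can restore the same architecture and resume with ann.loadWeights("mnist_bwd", N).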