├── .gitignore
├── Makefile
├── README.md
├── source
│   ├── 8bit_floats.hdf5
│   ├── DeepNeuralNetwork.cpp
│   ├── DeepNeuralNetwork.h
│   ├── Layer.cpp
│   ├── Layer.h
│   ├── WikiMaxoutNet.cpp
│   ├── WikiMaxoutNet.h
│   ├── WikiMaxoutNet_PCIe.cpp
│   ├── WikiMaxoutNet_PCIe.h
│   ├── WikiMaxoutNet_PCIe2.cpp
│   ├── WikiMaxoutNet_PCIe2.h
│   ├── WikiNetDist.cpp
│   ├── WikiNetDist.h
│   ├── basicOps.cu
│   ├── basicOps.cuh
│   ├── batchAllocator.cpp
│   ├── batchAllocator.h
│   ├── clusterKernels.cu
│   ├── clusterKernels.cuh
│   ├── clusterNet.cpp
│   ├── clusterNet.h
│   ├── test.cu
│   ├── util.cu
│   └── util.cuh
└── tests
    ├── basicOps_test.cu
    ├── basicOps_test.cuh
    ├── batchAllocator_test.cu
    ├── batchAllocator_test.cuh
    ├── clusterNet_test.cu
    ├── clusterNet_test.cuh
    ├── crowdflower_X_test.hdf5
    ├── crowdflower_y_test.hdf5
    ├── miniMNIST_test.cu
    ├── miniMNIST_test.cuh
    ├── mnist_mini_X.hdf5
    ├── mnist_mini_y.hdf5
    ├── numpy_arange_as_h5py.hdf5
    ├── scipy_sparse_arange_as_h5py.hdf5
    ├── testSuite.cu
    ├── testSuite.cuh
    ├── util_test.cu
    └── util_test.cuh
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | .metadata
3 | .*
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | CC = nvcc
2 | MPI_DIR=/usr/local/openmpi
3 | #MPI_DIR=/usr/mpi/openmpi-1.8.1
4 | HDF5_DIR = /home/tim/apps/hdf5-1.8.14/hdf5
5 | SZIP_DIR = /home/tim/apps/szip-2.1/szip/
6 | TOP := $(dir $(CURDIR)/$(word $(words $(MAKEFILE_LIST)),$(MAKEFILE_LIST)))
7 | TESTS := tests/testSuite.cu $(wildcard tests/*_test.cu)
8 | NODES=tim@10.0.0.2
9 | HOSTFILE=/home/tim/cluster_one_node
10 | #HOSTFILE=/home/tim/cluster_other_node
11 | #HOSTFILE=/home/tim/cluster
12 | SCR := $(wildcard source/*.cu) $(wildcard source/*.cpp)
13 | INCLUDE = -I $(MPI_DIR)/include -I $(TOP)source -I $(TOP)tests -I /usr/local/cuda/include -I $(HDF5_DIR)include -I $(SZIP_DIR)include
14 | LIB = -L $(MPI_DIR)/lib -L /usr/local/cuda/lib64 -L $(HDF5_DIR)lib -L $(SZIP_DIR)lib
15 | CFLAGS = -gencode arch=compute_35,code=sm_35 -lcusparse -lcublas -lcurand -lmpi_cxx -lmpi -lhdf5 -lhdf5_hl -lz $(LIB) $(INCLUDE)
16 | LINK = source/util.cu source/clusterKernels.cu source/basicOps.cu $(wildcard source/*.cpp)
17 |
18 | EXECSRC = build/clusterNet.out
19 | EXECTEST = build/testSuite.out
20 |
21 | all : $(EXECSRC) #$(EXECTEST)
22 |
23 | $(EXECSRC) : $(SCR)
24 | $(CC) $^ -o $@ $(CFLAGS)
25 |
26 | $(EXECTEST): $(SCR) $(TESTS)
27 | $(CC) $(TESTS) $(LINK) -o $@ $(CFLAGS)
28 |
29 | test:
30 | #scp $(TOP)$(EXECTEST) $(NODES):$(TOP)build/;
31 | $(MPI_DIR)/bin/mpirun -x LD_LIBRARY_PATH -np 2 $(TOP)$(EXECTEST)
32 |
33 | run:
34 | #scp $(TOP)$(EXECSRC) $(NODES):$(TOP)build/;
35 | $(MPI_DIR)/bin/mpirun -x LD_LIBRARY_PATH -np 2 $(TOP)$(EXECSRC)
36 |
--------------------------------------------------------------------------------
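(Build note: `make` compiles everything under source/ into build/clusterNet.out; the test binary is commented out of the default target and is built by the $(EXECTEST) rule. `make run` and `make test` launch the binaries through mpirun with two processes and an exported LD_LIBRARY_PATH, so the MPI, CUDA, HDF5, and SZIP paths at the top of the Makefile must be valid wherever the processes run.)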
/README.md:
--------------------------------------------------------------------------------
1 | clusterNet
2 | ==============
3 |
4 | Deep neural network framework for GPU clusters:
5 |
6 | - supports NVIDIA GPUDirect RDMA
7 | - easy distributed computation:
8 |
9 | Matrix C = dot(A,B); //uses one GPU
10 |       Matrix C = dotMPI(A,B); //uses all available GPUs in the machine or across the network
11 | - no delay between batches due to asynchronous memory copies to the GPU:
12 | gpu.init_batch_allocator(X, y, 128);
13 | for(int i = 0; i < gpu.m_total_batches; i++)
14 | {
15 | gpu.allocate_next_batch_async(); //loads the next batch while you do computations
16 | result = gpu.dot(gpu.m_current_batch_X,w1); //do your computations here
17 | gpu.replace_current_batch_with_next(); //get the next batch which is already loaded
18 | }
19 |
20 | - distributed weights that are larger than the memory of a single GPU:
21 |
22 | ClusterNet gpus = ClusterNet(argc,argv,12346);
23 |        Matrix *batch = gpus.rand(128,100000);//48 MB
24 | Matrix *out1 = empty(128,40000);//19 MB
25 | Matrix *out2 = empty(128,20000);//9 MB
26 | Matrix *W1 = gpus.distributed_uniformSqrtWeight(100000,40000);//15258 MB
27 | Matrix *W2 = gpus.distributed_uniformSqrtWeight(40000,20000);//3051 MB
28 | gpus.tick("Time taken");
29 | gpus.dotMPI(batch,W1,out1);
30 | gpus.dotMPI(out1,W2,out2);
31 | gpus.tock("Time taken");
32 | >>>Time taken: 117.704285 ms.
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
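For orientation, the snippets above fit together into one driver roughly as follows. This is a sketch only: it assumes the API exactly as shown in the README (`ClusterNet(argc, argv, port)`, `init_batch_allocator`, and the async batch calls), and the HDF5 paths and weight shape are placeholders.

    #include <clusterNet.h>
    #include <basicOps.cuh>

    int main(int argc, char *argv[])
    {
        ClusterNet gpu = ClusterNet(argc, argv, 12346);

        Matrix *X = read_hdf5("/path/to/X.hdf5");   // placeholder paths
        Matrix *y = read_hdf5("/path/to/y.hdf5");

        gpu.init_batch_allocator(X, y, 128);        // batch size 128
        Matrix *w1 = gpu.rand(X->cols, 1024);       // toy weight matrix
        for(int i = 0; i < gpu.m_total_batches; i++)
        {
            gpu.allocate_next_batch_async();                     // prefetch the next batch
            Matrix *result = gpu.dot(gpu.m_current_batch_X, w1); // compute on the current batch
            gpu.replace_current_batch_with_next();               // swap in the prefetched batch
        }
        return 0;
    }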
/source/8bit_floats.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/clusterNet/cb0bec556c480d26a8be4cd7ff0317ff661fab64/source/8bit_floats.hdf5
--------------------------------------------------------------------------------
/source/DeepNeuralNetwork.cpp:
--------------------------------------------------------------------------------
1 | #include <DeepNeuralNetwork.h> // NOTE: include names reconstructed from usage; the originals were stripped
2 | #include <batchAllocator.h>
3 | #include <clusterNet.h>
4 | #include <basicOps.cuh>
5 | #include <util.cuh>
6 | #include <iostream>
7 | #include <string>
8 | #include <vector>
9 | #include <math.h>
10 |
11 |
12 | using std::cout;
13 | using std::endl;
14 | using std::string;
15 |
16 | DeepNeuralNetwork::DeepNeuralNetwork(std::vector<int> lLayerSizes, Networktype_t net_type, ClusterNet gpus, BatchAllocator allocator, int categories)
17 | {
18 | m_gpus = gpus;
19 | m_BA = allocator;
20 | int device = 0;
21 | cudaGetDevice(&device);
22 | cout << "Active device: GPU" << device << endl;
23 |
24 | EPOCHS = 250;
25 | TRANSITION_EPOCH = 250;
26 | LEARNING_RATE = 0.01;
27 | LEARNING_RATE_DECAY = 0.99;
28 | MOMENTUM = 0.7;
29 | OUTPUT_IS_PROBABILITY = false;
30 | PRINT_MISSCLASSIFICATION = net_type == Classification ? true : false;
31 | MAIN_UNIT = Logistic;
32 | m_output_dim = categories;
33 | m_net_type = net_type;
34 | UPDATE_TYPE = NesterovRMSProp;
35 |
36 | RMSPROP_MOMENTUM = 0.9;
37 |
38 | init_network_layout(lLayerSizes);
39 | init_weights();
40 |
41 | }
42 |
43 | void DeepNeuralNetwork::init_network_layout(std::vector<int> lLayerSizes)
44 | {
45 | m_lLayers = lLayerSizes;
46 | if(m_net_type == Classification){ m_costFunction = Cross_Entropy;}
47 | if(m_net_type == Regression){ m_costFunction = Root_Squared_Error; }
48 |
49 | DROPOUT.push_back(0.2f);
50 | for(int i = 0;i < m_lLayers.size(); i++)
51 | {
52 | if(m_net_type == Classification){ lUnits.push_back(MAIN_UNIT); }
53 | if(m_net_type == Regression){ lUnits.push_back(Rectified_Linear); }
54 | DROPOUT.push_back(0.5f);
55 | }
56 | if(m_net_type == Classification){ lUnits.push_back(Softmax); }
57 | if(m_net_type == Regression){ lUnits.push_back(Linear); }
58 | }
59 |
60 | void DeepNeuralNetwork::init_weights()
61 | {
62 | int output_size = m_output_dim;
63 | if(m_net_type == Regression)
64 | output_size = m_BA.CURRENT_BATCH_Y->cols;
65 |
66 | if(m_BA.BATCH_METHOD == Distributed_weights || m_BA.BATCH_METHOD == Distributed_weights_sparse)
67 | {
68 | max_values.push_back(0.1f);
69 | W.push_back(m_gpus.distributed_uniformSqrtWeight(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
70 | B.push_back(zeros(1,m_lLayers[0]));
71 | M.push_back(m_gpus.distributed_zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
72 | B_M.push_back(zeros(1,m_lLayers[0]));
73 | MS.push_back(m_gpus.distributed_zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
74 | B_MS.push_back(zeros(1,m_lLayers[0]));
75 | GRAD.push_back(m_gpus.distributed_zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
76 | GRAD_approx.push_back(m_gpus.distributed_zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
77 | GRAD8bit.push_back(m_gpus.distributed_zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
78 | B_GRAD.push_back(zeros(1,m_lLayers[0]));
79 | for(int i = 0;i < (m_lLayers.size()-1); i++)
80 | {
81 | max_values.push_back(0.1f);
82 | W.push_back(m_gpus.distributed_uniformSqrtWeight(m_lLayers[i],m_lLayers[i+1]));
83 | B.push_back(zeros(1,m_lLayers[i+1]));
84 | M.push_back(m_gpus.distributed_zeros(m_lLayers[i],m_lLayers[i+1]));
85 | B_M.push_back(zeros(1,m_lLayers[i+1]));
86 | MS.push_back(m_gpus.distributed_zeros(m_lLayers[i],m_lLayers[i+1]));
87 | B_MS.push_back(zeros(1,m_lLayers[i+1]));
88 | GRAD.push_back(m_gpus.distributed_zeros(m_lLayers[i],m_lLayers[i+1]));
89 | GRAD_approx.push_back(m_gpus.distributed_zeros(m_lLayers[i],m_lLayers[i+1]));
90 | GRAD8bit.push_back(m_gpus.distributed_zeros(m_lLayers[i],m_lLayers[i+1]));
91 | B_GRAD.push_back(zeros(1,m_lLayers[i+1]));
92 | }
93 | max_values.push_back(0.1f);
94 | W.push_back(m_gpus.distributed_uniformSqrtWeight(m_lLayers.back(), output_size));
95 | B.push_back(zeros(1, output_size));
96 | M.push_back(m_gpus.distributed_zeros(m_lLayers.back(),output_size));
97 | B_M.push_back(zeros(1, output_size));
98 | MS.push_back(m_gpus.distributed_zeros(m_lLayers.back(),output_size));
99 | B_MS.push_back(zeros(1, output_size));
100 | GRAD.push_back(m_gpus.distributed_zeros(m_lLayers.back(),output_size));
101 | GRAD_approx.push_back(m_gpus.distributed_zeros(m_lLayers.back(),output_size));
102 | GRAD8bit.push_back(m_gpus.distributed_zeros(m_lLayers.back(),output_size));
103 | B_GRAD.push_back(zeros(1, output_size));
104 | }
105 | else
106 | {
107 | max_values.push_back(0.1f);
108 | W.push_back(m_gpus.uniformSqrtWeight(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
109 | B.push_back(zeros(1,m_lLayers[0]));
110 | M.push_back(zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
111 | B_M.push_back(zeros(1,m_lLayers[0]));
112 | MS.push_back(zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
113 | B_MS.push_back(zeros(1,m_lLayers[0]));
114 | GRAD.push_back(zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
115 | GRAD8bit.push_back(empty_char(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
116 | GRAD_approx.push_back(zeros(m_BA.CURRENT_BATCH->cols,m_lLayers[0]));
117 | B_GRAD.push_back(zeros(1,m_lLayers[0]));
118 | for(int i = 0;i < (m_lLayers.size()-1); i++)
119 | {
120 | max_values.push_back(0.1f);
121 | W.push_back(m_gpus.uniformSqrtWeight(m_lLayers[i],m_lLayers[i+1]));
122 | B.push_back(zeros(1,m_lLayers[i+1]));
123 | M.push_back(zeros(m_lLayers[i],m_lLayers[i+1]));
124 | B_M.push_back(zeros(1,m_lLayers[i+1]));
125 | MS.push_back(zeros(m_lLayers[i],m_lLayers[i+1]));
126 | B_MS.push_back(zeros(1,m_lLayers[i+1]));
127 | GRAD.push_back(zeros(m_lLayers[i],m_lLayers[i+1]));
128 | GRAD8bit.push_back(empty_char(m_lLayers[i],m_lLayers[i+1]));
129 | GRAD_approx.push_back(zeros(m_lLayers[i],m_lLayers[i+1]));
130 | B_GRAD.push_back(zeros(1,m_lLayers[i+1]));
131 | }
132 | max_values.push_back(0.1f);
133 | W.push_back(m_gpus.uniformSqrtWeight(m_lLayers.back(),output_size));
134 | B.push_back(zeros(1,output_size));
135 | M.push_back(zeros(m_lLayers.back(),output_size));
136 | B_M.push_back(zeros(1,output_size));
137 | MS.push_back(zeros(m_lLayers.back(),output_size));
138 | B_MS.push_back(zeros(1,output_size));
139 | GRAD.push_back(zeros(m_lLayers.back(),output_size));
140 | GRAD8bit.push_back(empty_char(m_lLayers.back(),output_size));
141 | GRAD_approx.push_back(zeros(m_lLayers.back(),output_size));
142 | B_GRAD.push_back(zeros(1,output_size));
143 | }
144 |
145 | for(int i = 0; i < W.size(); i++)
146 | cout << W[i]->rows << 'x' << W[i]->cols << endl;
147 |
148 | }
149 |
150 |
151 | void DeepNeuralNetwork::save_history()
152 | {
153 |
154 | Matrix *history = empty_cpu(train_history.size(),2);
155 |
156 | for(int i = 0; i < train_history.size(); i++)
157 | {
158 | history->data[i*2] = train_history[i];
159 | history->data[(i*2)+1] = cv_history[i];
160 | }
161 |
162 | ::write_hdf5("/home/tim/data/mnist/history.hdf5",history);
163 |
164 | free(history->data);
165 | free(history);
166 | }
167 |
168 | void DeepNeuralNetwork::train()
169 | {
170 | //if(OUTPUT_IS_PROBABILITY)
171 | //lUnits.back() = Double_Rectified_Linear;
172 |
173 |
174 | float original_learning_rate = LEARNING_RATE;
175 | for(int EPOCH = 0; EPOCH < EPOCHS; EPOCH++)
176 | {
177 | if(m_BA.BATCH_METHOD == Single_GPU || (m_BA.BATCH_METHOD != Single_GPU && m_gpus.MYRANK == 0))
178 | std::cout << "EPOCH: " << EPOCH + 1 << std::endl;
179 | MOMENTUM += 0.01;
180 | if(MOMENTUM > 0.95) MOMENTUM = 0.95;
181 |
182 |
183 | if(EPOCH > 0 && EPOCH % (TRANSITION_EPOCH-1) == 0)
184 | {
185 | TRANSITION_EPOCH = TRANSITION_EPOCH + (TRANSITION_EPOCH/4);
186 | cout << "Transition point reached: Halving dropout!" << endl;
187 | //m_update_type = NoMomentum;
188 | for(int i = 0; i < DROPOUT.size(); i++)
189 | DROPOUT[i] = DROPOUT[i] / 2.0;
190 |
191 |
192 | LEARNING_RATE_DECAY = 0.85f;
193 | }
194 |
195 |
196 |
197 | for(int i = 0; i < m_BA.TOTAL_BATCHES; i++)
198 | {
199 | nesterov_updates();
200 | m_BA.broadcast_batch_to_processes();
201 | feedforward(Dropout);
202 |
203 | if(m_BA.CURRENT_BATCH->isSparse == 0)
204 | m_BA.allocate_next_batch_async();
205 | backprop();
206 |
207 | weight_updates();
208 | free_variables();
209 |
210 | if(m_BA.CURRENT_BATCH->isSparse == 1)
211 | m_BA.allocate_next_batch_async();
212 | m_BA.replace_current_batch_with_next();
213 | }
214 | train_error();
215 | cross_validation_error();
216 |
217 |
218 | for(int i = 0; i < W.size(); i++)
219 | {
220 | abs(GRAD[i],GRAD[i]);
221 | max_values[i] = max(GRAD[i]);
222 | }
223 |
224 |
225 | LEARNING_RATE*=LEARNING_RATE_DECAY;
226 | save_history();
227 | }
228 |
229 | m_BA.finish_batch_allocator();
230 |
231 | }
232 |
233 | void DeepNeuralNetwork::backprop()
234 | {
235 | //backprop
236 | if(m_net_type == Classification)
237 | {
238 | Matrix *t = create_t_matrix(m_BA.CURRENT_BATCH_Y,m_output_dim);
239 | E.push_back(sub(Z.back(), t));
240 | cudaFree(t->data);
241 | }
242 | else
243 | {
244 | E.push_back(sub(Z.back(), m_BA.CURRENT_BATCH_Y));
245 | }
246 | for(int i = W.size()-1; i > 0; i--)
247 | {
248 | Matrix *bias_activation = ones(1,E.back()->rows);
249 | m_gpus.Tdot(Z[i],E.back(),GRAD[i]);
250 | m_gpus.dot(bias_activation,E.back(),B_GRAD[i]);
251 | cudaFree(bias_activation->data);
252 | derivative_function(i, Z[i]);
253 | E.push_back(m_gpus.dotT(E.back(), W[i]));
254 | mul(E.back(),Z[i],E.back());
255 | }
256 | Matrix *bias_activation = ones(1,E.back()->rows);
257 | m_gpus.Tdot(m_BA.CURRENT_BATCH,E.back(),GRAD[0]);
258 | m_gpus.dot(bias_activation,E.back(),B_GRAD[0]);
259 | cudaFree(bias_activation->data);
260 |
261 | }
262 |
263 | void DeepNeuralNetwork::free_variables()
264 | {
265 | for(int i = 0; i < D.size(); i++)
266 | {
267 | if(D[i]->isSparse == 0)
268 | cudaFree(D[i]->data);
269 | else
270 | {
271 | cudaFree(D[i]->data);
272 | cudaFree(D[i]->idx_cols);
273 | cudaFree(D[i]->ptr_rows);
274 | }
275 | }
276 | D.clear();
277 |
278 | for(int i = 1; i < Z.size(); i++)
279 | {
280 | if(Z[i]->isSparse == 0)
281 | cudaFree(Z[i]->data);
282 | else
283 | {
284 | cudaFree(Z[i]->data);
285 | cudaFree(Z[i]->idx_cols);
286 | cudaFree(Z[i]->ptr_rows);
287 | }
288 | }
289 | Z.clear();
290 |
291 | for(int i = 0; i < E.size(); i++)
292 | cudaFree(E[i]->data);
293 | E.clear();
294 |
295 | }
296 |
297 |
298 |
299 | void DeepNeuralNetwork::weight_updates()
300 | {
301 | for(int i = 0;i < GRAD.size(); i++)
302 | {
303 | if(UPDATE_TYPE == NesterovRMSProp)
304 | {
305 |
306 |
307 |
308 |
309 | RMSprop_with_nesterov_weight_update(MS[i],GRAD[i],W[i],M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
310 | RMSprop_with_nesterov_weight_update(B_MS[i],B_GRAD[i],B[i],B_M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
311 |
312 |
313 |
314 | }
315 | else if(UPDATE_TYPE == NesterovMomentum)
316 | {
317 | Nesterov_weight_update(MS[i],GRAD[i],W[i],M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
318 | Nesterov_weight_update(B_MS[i],B_GRAD[i],B[i],B_M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
319 | }
320 | else if(UPDATE_TYPE == RMSProp)
321 | {
322 |
323 |
324 |
325 |
326 |
327 | RMSprop_with_weight_update(MS[i],GRAD[i],W[i],M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
328 | RMSprop_with_weight_update(B_MS[i],B_GRAD[i],B[i],B_M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
329 |
330 |
331 |
332 |
333 |
334 | /*
335 | RMSprop_with_weight_update_8bit(MS[i],GRAD[i],W[i],M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
336 | RMSprop_with_weight_update(B_MS[i],B_GRAD[i],B[i],B_M[i],RMSPROP_MOMENTUM,LEARNING_RATE,m_BA.CURRENT_BATCH->rows, MOMENTUM);
337 | m_gpus.compression_8bit(GRAD[i],max_values[i],GRAD8bit[i]);
338 | m_gpus.decompression_8bit(GRAD8bit[i],max_values[i],GRAD_approx[i]);
339 | sub(W[i],GRAD_approx[i],W[i]);
340 | */
341 |
342 |
343 |
344 | //squared_error(GRAD[i],GRAD_approx[i],GRAD_approx[i]);
345 | //cout << "ERROR: " << sum(GRAD_approx[i]) << endl;
346 |
347 |
348 | }
349 | else if(UPDATE_TYPE == NoMomentum)
350 | {
351 | scalarMul(GRAD[i],LEARNING_RATE/(float)m_BA.CURRENT_BATCH->rows,GRAD[i]);
352 | sub(W[i],GRAD[i],W[i]);
353 | 			scalarMul(B_GRAD[i],LEARNING_RATE/(float)m_BA.CURRENT_BATCH->rows,B_GRAD[i]);
354 | sub(B[i],B_GRAD[i],B[i]);
355 | }
356 | }
357 | }
358 |
359 |
360 | void DeepNeuralNetwork::feedforward(FeedForward_t ff)
361 | {
362 | //scale up the weights
363 | if(ff == Dropout)
364 | {
365 | Z.push_back(m_BA.CURRENT_BATCH);
366 | for(int i = 0; i < W.size(); i++)
367 | {
368 | //D.push_back(Z.back());
369 | D.push_back(m_gpus.dropout(Z.back(),DROPOUT[i]));
370 | Z.push_back(m_gpus.dot(D.back(), W[i]));
371 | addMatrixVector(Z.back(),B[i],Z.back());
372 | activation_function(i, Z.back());
373 | }
374 | }
375 | else
376 | {
377 | //TODO: Correct input dropout rescaling
378 | if(ff == Train_error){ Z.push_back(m_BA.CURRENT_BATCH);}
379 | else{ Z.push_back(m_BA.CURRENT_BATCH_CV);}
380 |
381 |
382 | scalarMul(Z.back(), 1.0f-DROPOUT[0], Z.back());
383 |
384 | for(int i = 0; i < W.size(); i++)
385 | {
386 | Z.push_back(m_gpus.dot(Z.back(), W[i]));
387 | addMatrixVector(Z.back(),B[i],Z.back());
388 | activation_function(i, Z.back());
389 | if(i < W.size() -1)
390 | scalarMul(Z.back(), 1.0f-DROPOUT[i+1], Z.back());
391 | }
392 | }
393 |
394 | if(OUTPUT_IS_PROBABILITY)
395 | doubleRectifiedLinear(Z.back(),Z.back());
396 |
397 | }
398 |
399 | float DeepNeuralNetwork::get_errors(Batchtype_t batch_t)
400 | {
401 | float errors = 0;
402 | if(m_net_type == Classification || PRINT_MISSCLASSIFICATION)
403 | {
404 |
405 | Matrix *result = argmax(Z.back());
406 | Matrix *eq;
407 | if(m_net_type == Classification)
408 | {
409 | if(batch_t == Train){ eq = equal(result,m_BA.CURRENT_BATCH_Y);}
410 | else{ eq = equal(result,m_BA.CURRENT_BATCH_CV_Y);}
411 | }
412 | else
413 | {
414 | Matrix *argmax_regression_batch;
415 | if(batch_t == Train){argmax_regression_batch = argmax(m_BA.CURRENT_BATCH_Y); eq = equal(result,argmax_regression_batch);}
416 | else{argmax_regression_batch = argmax(m_BA.CURRENT_BATCH_CV_Y); eq = equal(result,argmax_regression_batch);}
417 | }
418 |
419 | float sum_value = sum(eq);
420 | missclassification_error += (Z.back()->rows - sum_value);
421 | cudaFree(result->data);
422 | cudaFree(eq->data);
423 | }
424 |
425 | if(m_net_type == Regression)
426 | {
427 | //Matrix *sqrErr = squared_error(Z.back(),batch_t == Train ? m_BA.CURRENT_BATCH_Y : m_BA.CURRENT_BATCH_CV_Y);
428 | Matrix *sqrErr = sub(Z.back(), batch_t == Train ? m_BA.CURRENT_BATCH_Y : m_BA.CURRENT_BATCH_CV_Y);
429 | square(sqrErr,sqrErr);
430 |
431 |
432 | errors = sum(sqrErr);
433 | errors /= m_BA.CURRENT_BATCH_Y->cols;
434 | errors *= batch_t == Train ? m_BA.CURRENT_BATCH_Y->rows : m_BA.CURRENT_BATCH_CV_Y->rows;
435 | errors = sqrt(errors);
436 | cudaFree(sqrErr->data);
437 | }
438 |
439 |
440 | return errors;
441 | }
442 |
443 | void DeepNeuralNetwork::activation_function(int layer, Matrix * A)
444 | {
445 | switch(lUnits[layer])
446 | {
447 | case Logistic:
448 | logistic(A,A);
449 | break;
450 | case Rectified_Linear:
451 | rectified_linear(A,A);
452 | break;
453 | case Softmax:
454 | softmax(A,A);
455 | break;
456 | case Double_Rectified_Linear:
457 | doubleRectifiedLinear(A,A);
458 | break;
459 | case Linear:
460 | break;
461 | }
462 | }
463 |
464 | void DeepNeuralNetwork::derivative_function(int layer, Matrix * A)
465 | {
466 | switch(lUnits[layer-1])
467 | {
468 | case Logistic:
469 | logisticGrad(A,A);
470 | break;
471 | case Rectified_Linear:
472 | rectified_linear_derivative(A,A);
473 | break;
474 | case Double_Rectified_Linear:
475 | double_rectified_linear_derivative(A,A);
476 | break;
477 | default:
478 | throw "Unknown unit";
479 | break;
480 | }
481 | }
482 |
483 | void DeepNeuralNetwork::nesterov_updates()
484 | {
485 | //nesterov updates
486 | for(int i = 0;i < M.size(); i++)
487 | {
488 | scalarMul(M[i],MOMENTUM,M[i]);
489 | add(W[i],M[i],W[i]);
490 | scalarMul(B_M[i],MOMENTUM,B_M[i]);
491 | add(B[i],B_M[i],B[i]);
492 | }
493 | }
494 |
495 | void DeepNeuralNetwork::train_error()
496 | {
497 | float errors = 0;
498 | missclassification_error = 0.0f;
499 | for(int i = 0; i < m_BA.TOTAL_BATCHES; i++)
500 | {
501 | m_BA.broadcast_batch_to_processes();
502 | feedforward(Train_error);
503 | m_BA.allocate_next_batch_async();
504 |
505 | errors += get_errors(Train);
506 |
507 | free_variables();
508 |
509 | m_BA.replace_current_batch_with_next();
510 | }
511 |
512 |
513 |
514 | if((m_BA.BATCH_METHOD == Single_GPU || (m_BA.BATCH_METHOD != Single_GPU && m_gpus.MYRANK == 0)) && m_net_type != Classification)
515 | std::cout << "Train error: " << errors/m_BA.TRAIN_SET_ROWS << std::endl;
516 | if((m_BA.BATCH_METHOD == Single_GPU || (m_BA.BATCH_METHOD != Single_GPU && m_gpus.MYRANK == 0)) &&
517 | PRINT_MISSCLASSIFICATION)
518 | std::cout << "Train classification error: " << missclassification_error/m_BA.TRAIN_SET_ROWS << std::endl;
519 |
520 | train_history.push_back(missclassification_error/m_BA.TRAIN_SET_ROWS);
521 | }
522 |
523 |
524 | void DeepNeuralNetwork::cross_validation_error()
525 | {
526 | float errors = 0;
527 | missclassification_error = 0.0f;
528 | for(int i = 0; i < m_BA.TOTAL_BATCHES_CV; i++)
529 | {
530 | m_BA.broadcast_batch_cv_to_processes();
531 | feedforward(CV_error);
532 | m_BA.allocate_next_cv_batch_async();
533 | errors += get_errors(CV);
534 | free_variables();
535 |
536 | m_BA.replace_current_cv_batch_with_next();
537 | }
538 |
539 |
540 | // cout << "Number of errors: " << missclassification_error << endl;
541 |
542 | if((m_BA.BATCH_METHOD == Single_GPU || (m_BA.BATCH_METHOD != Single_GPU && m_gpus.MYRANK == 0)) && m_net_type != Classification)
543 | std::cout << "Cross validation error: " << errors/m_BA.CV_SET_ROWS << std::endl;
544 | if((m_BA.BATCH_METHOD == Single_GPU || (m_BA.BATCH_METHOD != Single_GPU && m_gpus.MYRANK == 0)) &&
545 | PRINT_MISSCLASSIFICATION)
546 | std::cout << "Cross validation classification error: " << missclassification_error/m_BA.CV_SET_ROWS << std::endl;
547 |
548 |
549 | cv_history.push_back(missclassification_error/m_BA.CV_SET_ROWS );
550 | }
551 |
552 | Matrix* DeepNeuralNetwork::predict(Matrix *X)
553 | {
554 | int batch_size = 128;
555 | int rows = X->rows;
556 | int cols = X->cols;
557 |
558 | if(m_gpus.MPI_SIZE > 1)
559 | if(m_gpus.MYGPUID == 0)
560 | for(int i = 1; i < m_gpus.PCIe_RANKS.size();i++)
561 | {
562 | MPI_Send(&rows,1,MPI_INT,m_gpus.PCIe_RANKS[i],999,MPI_COMM_WORLD);
563 | MPI_Send(&cols,1,MPI_INT,m_gpus.PCIe_RANKS[i],999,MPI_COMM_WORLD);
564 | }
565 |
566 | else
567 | {
568 | MPI_Recv(&rows,1,MPI_INT,m_gpus.PCIe_RANKS[0],999,MPI_COMM_WORLD, MPI_STATUS_IGNORE);
569 | MPI_Recv(&cols,1,MPI_INT,m_gpus.PCIe_RANKS[0],999,MPI_COMM_WORLD, MPI_STATUS_IGNORE);
570 | }
571 |
572 | Matrix *batch = empty(batch_size,cols);
573 | Matrix *off_batch = empty(rows % batch_size,cols);
574 | Matrix *buffer = empty_cpu(batch_size,cols);
575 | Matrix *off_buffer = empty_cpu(rows % batch_size,cols);
576 |
577 | Matrix *out;
578 |
579 | int full_batches = (rows / batch_size)-1;
580 | for(int i = 0; i < (rows/batch_size) + 1; i++)
581 | {
582 | if(m_gpus.MYGPUID == 0)
583 | {
584 | if(X->isSparse == 0)
585 | {
586 | if(i < full_batches)
587 | cudaMemcpy(&batch->data[0],&X->data[(i*X->cols)*batch_size],batch->bytes,cudaMemcpyDefault);
588 | else
589 | cudaMemcpy(&off_batch->data[0],&X->data[(i*X->cols)*batch_size],off_batch->bytes,cudaMemcpyDefault);
590 | }
591 | else
592 | {
593 | if(i < full_batches)
594 | {
595 | slice_sparse_to_dense(X,buffer,i*batch_size,batch_size);
596 | cudaMemcpy(&batch->data[0],&buffer->data[0],buffer->bytes,cudaMemcpyDefault);
597 |
598 | }
599 | else
600 | {
601 | slice_sparse_to_dense(X,off_buffer,i*batch_size,X->rows % batch_size);
602 | cudaMemcpy(&off_batch->data[0],&off_buffer->data[0],off_buffer->bytes,cudaMemcpyDefault);
603 | }
604 | }
605 |
606 |
607 | if(m_gpus.MPI_SIZE > 1)
608 | if(i < full_batches)
609 | for(int i = 1; i < m_gpus.PCIe_RANKS.size();i++)
610 | {
611 | MPI_Send(batch->data,batch->size,MPI_FLOAT,m_gpus.PCIe_RANKS[i],999,MPI_COMM_WORLD);
612 | }
613 | else
614 | for(int i = 1; i < m_gpus.PCIe_RANKS.size();i++)
615 | MPI_Send(off_batch->data,off_batch->size,MPI_FLOAT,m_gpus.PCIe_RANKS[i],999,MPI_COMM_WORLD);
616 |
617 |
618 | }
619 | else
620 | {
621 |
622 | if(i < full_batches)
623 | {
624 | MPI_Recv(batch->data,batch->size,MPI_FLOAT,m_gpus.PCIe_RANKS[0],999,MPI_COMM_WORLD, MPI_STATUS_IGNORE);
625 | }
626 | else
627 | MPI_Recv(off_batch->data,off_batch->size,MPI_FLOAT,m_gpus.PCIe_RANKS[0],999,MPI_COMM_WORLD, MPI_STATUS_IGNORE);
628 |
629 | }
630 |
631 |
632 |
633 |
634 | if(i < full_batches)
635 | {
636 | to_col_major(batch,batch);
637 | Z.push_back(batch);
638 | }
639 | else
640 | {
641 | to_col_major(off_batch,off_batch);
642 | Z.push_back(off_batch);
643 | }
644 |
645 | //feed forward
646 | for(int j = 0; j < W.size(); j++)
647 | {
648 | Z.push_back(m_gpus.dot(Z.back(), W[j]));
649 | addMatrixVector(Z.back(),B[j],Z.back());
650 | activation_function(j, Z.back());
651 | if(j < W.size() -1)
652 | 				scalarMul(Z.back(),1.0f-DROPOUT[j+1],Z.back());
653 | }
654 |
655 | if(OUTPUT_IS_PROBABILITY)
656 | doubleRectifiedLinear(Z.back(),Z.back());
657 |
658 | if(m_gpus.MYGPUID == 0)
659 | {
660 | if(i == 0)
661 | out = empty_cpu(X->rows,Z.back()->cols);
662 |
663 | Matrix *host = to_host(Z.back());
664 | for(int k = 0; k < host->size; k++)
665 | out->data[(i*batch_size*host->cols) + k] = host->data[k];
666 |
667 | free(host->data);
668 |
669 | }
670 |
671 |
672 |
673 | free_variables();
674 |
675 |
676 |
677 | }
678 |
679 | cudaFree(batch->data);
680 | cudaFree(off_batch->data);
681 |
682 | return out;
683 |
684 |
685 | }
686 |
--------------------------------------------------------------------------------
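The commented-out block in weight_updates (lines 334-340) outlines the 8-bit gradient path that the GRAD8bit and GRAD_approx buffers are allocated for: compress the gradient against the per-layer running maximum of |GRAD| that train() records after each epoch (lines 218-222), decompress it, and apply the approximation instead of the raw gradient. Isolated as a helper, with signatures taken from their call sites in this file, it would look roughly like this (a sketch, not code from the repository):

    // Hypothetical helper mirroring the commented-out 8-bit path above.
    void apply_8bit_gradient(ClusterNet &gpus, Matrix *grad, Matrix *grad8bit,
                             Matrix *grad_approx, Matrix *w, float max_value)
    {
        gpus.compression_8bit(grad, max_value, grad8bit);          // quantize against the |grad| max
        gpus.decompression_8bit(grad8bit, max_value, grad_approx); // dequantize
        sub(w, grad_approx, w);                                    // apply the approximate gradient
    }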
/source/DeepNeuralNetwork.h:
--------------------------------------------------------------------------------
1 | #ifndef DeepNeuralNetwork_H
2 | #define DeepNeuralNetwork_H
3 | #include <clusterNet.h> // NOTE: include names reconstructed from usage
4 | #include <batchAllocator.h>
5 | #include <basicOps.cuh>
6 | #include <util.cuh>
7 | #include <vector>
8 | #include <string>
9 |
10 | typedef enum FeedForward_t
11 | {
12 | Dropout = 0,
13 | Train_error = 1,
14 | CV_error = 2
15 | } FeedForward_t;
16 |
17 | typedef enum Batchtype_t
18 | {
19 | Train = 0,
20 | CV = 1
21 | } Batchtype_t;
22 |
23 | typedef enum Networktype_t
24 | {
25 | Classification = 0,
26 | Regression = 1
27 | } Networktype_t;
28 |
29 |
30 |
31 |
32 |
33 | class DeepNeuralNetwork
34 | {
35 | public:
36 | 		DeepNeuralNetwork(std::vector<int> lLayerSizes, Networktype_t net_type, ClusterNet gpus, BatchAllocator allocator, int categories);
37 | void train();
38 | Matrix* predict(Matrix *X);
39 |
40 | float LEARNING_RATE;
41 | float LEARNING_RATE_DECAY;
42 | float MOMENTUM;
43 | float RMSPROP_MOMENTUM;
44 | int EPOCHS;
45 | bool OUTPUT_IS_PROBABILITY;
46 | int TRANSITION_EPOCH;
47 | bool PRINT_MISSCLASSIFICATION;
48 | Unittype_t MAIN_UNIT;
49 | 		std::vector<float> DROPOUT;
50 |
51 | WeightUpdateType_t UPDATE_TYPE;
52 |
53 | private:
54 | Costfunction_t m_costFunction;
55 | 		std::vector<Matrix*> D;
56 | 		std::vector<Matrix*> D_B;
57 | 		std::vector<Matrix*> Z;
58 | 		std::vector<Matrix*> Z_B;
59 | 		std::vector<Matrix*> E;
60 | 		std::vector<Unittype_t> lUnits;
61 | 		BatchAllocator m_BA;
62 | 		std::vector<Matrix*> W;
63 | 		std::vector<float> max_values;
64 | 		std::vector<Matrix*> B;
65 | 		std::vector<Matrix*> B_Activations;
66 | 		std::vector<Matrix*> B_Activations_CV;
67 | 		std::vector<Matrix*> M;
68 | 		std::vector<Matrix*> B_M;
69 | 		std::vector<Matrix*> GRAD;
70 | 		std::vector<Matrix*> GRAD8bit;
71 | 		std::vector<Matrix*> GRAD_approx;
72 | 		std::vector<Matrix*> B_GRAD;
73 | 		std::vector<Matrix*> MS;
74 | 		std::vector<Matrix*> B_MS;
75 | 		std::vector<float> train_history;
76 | 		std::vector<float> cv_history;
77 | 		std::vector<int> m_lLayers;
78 |
79 | ClusterNet m_gpus;
80 | int m_output_dim;
81 | Networktype_t m_net_type;
82 |
83 | float missclassification_error;
84 |
85 | 		void init_network_layout(std::vector<int> lLayerSizes);
86 | void init_weights();
87 | void nesterov_updates();
88 | void feedforward(FeedForward_t ff);
89 | void backprop();
90 | void weight_updates();
91 | void free_variables();
92 | float get_errors(Batchtype_t batch_t);
93 | void cross_validation_error();
94 | void train_error();
95 |
96 | void save_history();
97 |
98 | void activation_function(int layer, Matrix *A);
99 | void derivative_function(int layer, Matrix *A);
100 |
101 | };
102 |
103 | #endif
104 |
--------------------------------------------------------------------------------
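A construction sketch for the class above. The constructor signature and the public hyperparameters are as declared in the header; how the ClusterNet and BatchAllocator arguments are prepared is not shown in this excerpt, so they are taken as given:

    #include <DeepNeuralNetwork.h>

    void train_classifier(ClusterNet gpus, BatchAllocator allocator)
    {
        std::vector<int> layers;
        layers.push_back(1024);     // two hidden layers of 1024 units each
        layers.push_back(1024);

        DeepNeuralNetwork net(layers, Classification, gpus, allocator, 10); // 10 output classes
        net.EPOCHS = 100;           // public fields override the constructor defaults
        net.LEARNING_RATE = 0.003f;
        net.train();
    }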
/source/Layer.cpp:
--------------------------------------------------------------------------------
1 | #include <Layer.h> // NOTE: include names reconstructed from usage
2 | #include <iostream>
3 |
4 |
5 |
6 | using std::cout;
7 | using std::endl;
8 | using std::string;
9 | using std::vector;
10 |
11 | Layer::Layer(int unitcount, int start_batch_size, Unittype_t unit, ClusterNet *gpu){ init(unitcount, start_batch_size,unit,gpu); }
12 | Layer::Layer(int unitcount, Unittype_t unit){ init(unitcount, 0,unit, NULL); }
13 | Layer::Layer(int unitcount){ init(unitcount, 0,Rectified_Linear, NULL); }
14 |
15 | Layer::Layer(int unitcount, int start_batch_size, Unittype_t unit, Layer *prev, ClusterNet *gpu)
16 | { init(unitcount, start_batch_size,unit,gpu); prev->link_with_next_layer(this); }
17 | Layer::Layer(int unitcount, Unittype_t unit, Layer *prev){ init(unitcount, 0,unit, prev->GPU); prev->link_with_next_layer(this); }
18 | Layer::Layer(int unitcount, Layer *prev){ init(unitcount, 0,Rectified_Linear, prev->GPU); prev->link_with_next_layer(this); }
19 |
20 | void Layer::init(int unitcount, int start_batch_size, Unittype_t unit, ClusterNet *gpu)
21 | {
22 |
23 | next = NULL;
24 | prev = NULL;
25 | w_next = NULL;
26 | b_next = NULL;
27 | b_next_sync = NULL;
28 | w_rms_next = NULL;
29 | b_rms_next = NULL;
30 | b_grad_next = NULL;
31 |
32 | w_next_sync_send = NULL;
33 | b_next_sync_send = NULL;
34 | w_next_sync_recv = NULL;
35 | b_next_sync_recv = NULL;
36 |
37 | isSynchronizing = false;
38 |
39 | compression = bits_8;
40 |
41 | target = NULL;
42 | target_matrix = NULL;
43 | 	error = NULL; out_offsize = NULL; activation_offsize = NULL; error_offsize = NULL; bias_activations_offsize = NULL; target_matrix_offsize = NULL;
44 |
45 | LEARNING_RATE = 0.006;
46 | RMSPROP_MOMENTUM = 0.9f;
47 | UNIT_TYPE = unit;
48 | DROPOUT = 0.5f;
49 | UNITCOUNT = unitcount;
50 | BATCH_SIZE = start_batch_size;
51 | RUNNING_ERROR = 0.0f;
52 | RUNNING_SAMPLE_SIZE = 0.0f;
53 | L2 = 15.0f;
54 |
55 | MAX_GRAD_VALUE = 1.0f;
56 |
57 | UPDATE_TYPE = RMSProp;
58 | COST = Misclassification;
59 | PARALLELISM = None;
60 |
61 | GPU = gpu;
62 | count = 0;
63 |
64 | 	for(int i = 0; GPU && i < GPU->MPI_SIZE; i++)
65 | {
66 |
67 | send_request.push_back(new MPI_Request);
68 | recv_request.push_back(new MPI_Request);
69 | }
70 |
71 | 	max_grad_value_sync = GPU ? (float*)malloc(GPU->MPI_SIZE*sizeof(float)) : NULL;
72 |
73 | if(BATCH_SIZE > 0)
74 | {
75 | out = zeros(BATCH_SIZE, UNITCOUNT);
76 | bias_activations = ones(1, BATCH_SIZE);
77 | activation = zeros(BATCH_SIZE, UNITCOUNT);
78 | }
79 | else
80 | {
81 | out = NULL;
82 | bias_activations = NULL;
83 | activation = NULL;
84 | }
85 |
86 |
87 | 	mpi_buffer = GPU ? (float*)malloc(GPU->MPI_SIZE*sizeof(float)) : NULL;
88 | }
89 |
90 | void Layer::link_with_next_layer(Layer *next_layer)
91 | {
92 | next = next_layer;
93 | if(next->BATCH_SIZE == 0){ next->BATCH_SIZE = BATCH_SIZE; }
94 | if(!next->GPU){next->GPU = GPU;}
95 |
96 | if(PARALLELISM == ModelParallelism)
97 | {
98 | for(int i = 0; i < GPU->MPI_SIZE; i++)
99 | {
100 | vec_w_grad_next.push_back(GPU->distributed_zeros(UNITCOUNT,next_layer->UNITCOUNT));
101 | }
102 |
103 | Matrix *w = GPU->distributed_uniformSqrtWeight(UNITCOUNT,next_layer->UNITCOUNT);
104 | w_next = w;
105 | b_grad_next = GPU->distributed_zeros(1,next_layer->UNITCOUNT);
106 | b_rms_next = GPU->distributed_zeros(1,next_layer->UNITCOUNT);
107 | w_rms_next = GPU->distributed_zeros(UNITCOUNT,next_layer->UNITCOUNT);
108 |
109 | }
110 | else
111 | {
112 | Matrix *w = GPU->uniformSqrtWeight(UNITCOUNT,next_layer->UNITCOUNT);
113 | w_next = w;
114 | b_grad_next = zeros(1,next_layer->UNITCOUNT);
115 | b_rms_next = zeros(1,next_layer->UNITCOUNT);
116 | w_rms_next = zeros(UNITCOUNT,next_layer->UNITCOUNT);
117 | w_next_abs_max_buffer = zeros(UNITCOUNT,next_layer->UNITCOUNT);
118 |
119 | for(int i = 0; i < GPU->MPI_SIZE; i++)
120 | {
121 | vec_w_grad_next.push_back(zeros(UNITCOUNT,next_layer->UNITCOUNT));
122 | vec_w_grad_next_8bit.push_back(empty_char(UNITCOUNT,next_layer->UNITCOUNT));
123 | }
124 |
125 | w_next_sync_send = empty_char(UNITCOUNT,next_layer->UNITCOUNT);
126 | w_next_sync_recv = empty_char(UNITCOUNT,next_layer->UNITCOUNT);
127 | b_next_sync = zeros(1,next_layer->UNITCOUNT);
128 | b_next_sync_send = empty_char(1,next_layer->UNITCOUNT);
129 | b_next_sync_recv = empty_char(1,next_layer->UNITCOUNT);
130 |
131 | }
132 |
133 |
134 | Matrix *b = zeros(1,next_layer->UNITCOUNT);
135 | b_next = b;
136 | next->out = zeros(BATCH_SIZE, next->UNITCOUNT);
137 | next->activation = zeros(BATCH_SIZE, next->UNITCOUNT);
138 | next->error = zeros(BATCH_SIZE, next->UNITCOUNT);
139 | next->bias_activations = ones(1, BATCH_SIZE);
140 | next->prev = this;
141 |
142 |
143 | }
144 |
145 |
146 | void Layer::unit_activation(){ unit_activation(true); }
147 | void Layer::unit_activation(bool useDropout)
148 | {
149 | switch(UNIT_TYPE)
150 | {
151 | case Logistic:
152 | logistic(out,activation);
153 | break;
154 | case Rectified_Linear:
155 | rectified_linear(out,activation);
156 | break;
157 | case Softmax:
158 | softmax(out,out);
159 | break;
160 | case Double_Rectified_Linear:
161 | doubleRectifiedLinear(out,activation);
162 | break;
163 | case Linear:
164 | LinearUnit(out, activation);
165 | break;
166 | case Input:
167 | break;
168 | }
169 |
170 | if(UNIT_TYPE != Softmax)
171 | {
172 | if(useDropout)
173 | GPU->dropout(activation,out,DROPOUT);
174 | else
175 | scalarMul(activation,1.0f-DROPOUT, out);
176 | }
177 |
178 | }
179 |
180 | void Layer::activation_gradient()
181 | {
182 |
183 | switch(UNIT_TYPE)
184 | {
185 | case Logistic:
186 | logisticGrad(activation,out);
187 | break;
188 | case Rectified_Linear:
189 | rectified_linear_derivative(activation,out);
190 | break;
191 | case Double_Rectified_Linear:
192 | double_rectified_linear_derivative(activation,out);
193 | break;
194 | case Softmax:
195 | break;
196 | default:
197 | throw "Unknown unit";
198 | break;
199 | }
200 |
201 | }
202 |
203 | void Layer::handle_offsize()
204 | {
205 | if(!prev)
206 | {
207 | if(!out){ out = empty(activation->rows, activation->cols); }
208 | else if(out->rows != activation->rows)
209 | {
210 | cudaFree(out->data);
211 | free(out);
212 | out = empty(activation->rows, activation->cols);
213 | }
214 | }
215 | else
216 | {
217 | if(prev->out->rows != out->rows && (!out_offsize || out_offsize->rows != prev->out->rows))
218 | {
219 | if(out_offsize)
220 | {
221 | cudaFree(out_offsize->data);
222 | cudaFree(activation_offsize->data);
223 | cudaFree(error_offsize->data);
224 | cudaFree(bias_activations_offsize->data);
225 | cudaFree(target_matrix_offsize->data);
226 | }
227 |
228 | out_offsize = empty(prev->out->rows, UNITCOUNT);
229 | activation_offsize = empty(prev->out->rows, UNITCOUNT);
230 | error_offsize = empty(prev->out->rows, UNITCOUNT);
231 | bias_activations_offsize = empty(1,prev->out->rows);
232 | target_matrix_offsize = zeros(prev->out->rows, UNITCOUNT);
233 | }
234 |
235 |
236 | if(prev->out->rows != out->rows)
237 | {
238 | Matrix *swap;
239 | swap = out; out = out_offsize; out_offsize = swap;
240 | swap = activation; activation = activation_offsize; activation_offsize = swap;
241 | swap = error; error = error_offsize; error_offsize = swap;
242 | swap = bias_activations; bias_activations = bias_activations_offsize; bias_activations_offsize = swap;
243 | swap = target_matrix; target_matrix = target_matrix_offsize; target_matrix_offsize = swap;
244 | }
245 | }
246 |
247 | }
248 |
249 | void Layer::forward(){ forward(true); }
250 | void Layer::forward(bool useDropout)
251 | {
252 | handle_offsize();
253 | if(!prev){ unit_activation(useDropout); next->forward(useDropout); return; }
254 | if(PARALLELISM == DataParallelism && useDropout){ prev->wait_for_synchronization(); prev->weight_update(); }
255 |
256 | GPU->dot(prev->out,prev->w_next,out);
257 | addMatrixVector(out,prev->b_next,out);
258 | unit_activation(useDropout);
259 |
260 | if(next){ next->forward(useDropout); }
261 | }
262 |
263 |
264 | void Layer::running_error(bool isCV, int epoch)
265 | {
266 | if(!target){ next->running_error(isCV, epoch); return;}
267 |
268 | string text = "";
269 |
270 | Matrix *result;
271 | Matrix *eq;
272 | float sum_value = 0.0f;
273 | float size = 0.0f;
274 |
275 | if (!Train_errors.count(epoch))
276 | {
277 | 		Train_errors[epoch] = std::vector<float>();
278 | 		CV_errors[epoch] = std::vector<float>();
279 | }
280 |
281 |
282 |
283 |
284 |
285 | switch(COST)
286 | {
287 | case Misclassification:
288 | result = argmax(out);
289 | eq = equal(result,target);
290 | sum_value = sum(eq);
291 | sum_value = reduce_to_sum_root(sum_value);
292 | size = reduce_to_sum_root(out->rows);
293 | if(GPU->MYRANK == 0)
294 | {
295 | if(isCV)
296 | CV_errors[epoch].push_back(sum_value/size);
297 | else
298 | Train_errors[epoch].push_back(sum_value/size);
299 | }
300 | RUNNING_ERROR += (out->rows - sum_value);
301 | RUNNING_SAMPLE_SIZE += out->rows;
302 | break;
303 | default:
304 | throw "Unknown cost function!";
305 | break;
306 | }
307 |
308 | cudaFree(result->data);
309 | cudaFree(eq->data);
310 | }
311 |
312 |
313 |
314 | void Layer::backward_errors()
315 | {
316 | if(!target){ next->backward_errors(); }
317 | if(target)
318 | {
319 | if(out->cols != target->cols && !target_matrix){ target_matrix = zeros(BATCH_SIZE,out->cols); }
320 | if(out->cols != target->cols){ create_t_matrix(target,target_matrix); sub(out,target_matrix,error); return; }
321 | else{ sub(activation,target,error); return;}
322 | }
323 |
324 | if(UNIT_TYPE == Input){ backward_grads(); return; }
325 |
326 | activation_gradient();
327 | GPU->dotT(next->error, w_next,error);
328 | mul(error, out, error);
329 |
330 | }
331 |
332 | void Layer::backward_grads()
333 | {
334 | GPU->Tdot(activation, next->error, vec_w_grad_next[GPU->MYRANK]);
335 | MPI_synchronization_async();
336 | if(!next->target){ next->backward_grads(); }
337 | //GPU->dot(next->bias_activations, next->error,b_grad_next);
338 |
339 | }
340 |
341 | void Layer::MPI_synchronization_async()
342 | {
343 | if(PARALLELISM != DataParallelism){ return; }
344 |
345 | int target = GPU->MYRANK +1 == GPU->MPI_SIZE ? 0 : GPU->MYRANK+1;
346 | int source = GPU->MYRANK-1 == -1 ? GPU->MPI_SIZE-1 : GPU->MYRANK-1;
347 |
348 |
349 | if(compression == bits_8)
350 | {
351 |
352 | //cout << 1.0f/((float)out->rows) << endl;
353 | /*
354 | scalarMul(vec_w_grad_next[GPU->MYRANK],1.0f/((float)out->rows),vec_w_grad_next[GPU->MYRANK]);
355 |
356 | abs(vec_w_grad_next[GPU->MYRANK],w_next_abs_max_buffer);
357 | MAX_GRAD_VALUE = max(w_next_abs_max_buffer);
358 | MPI_Allgather(&MAX_GRAD_VALUE, 1, MPI_FLOAT, max_grad_value_sync, 1, MPI_FLOAT, MPI_COMM_WORLD);
359 | */
360 |
361 | //cout << max_grad_value_sync[GPU->MYRANK] << " vs. " << MAX_GRAD_VALUE << " and " << max_grad_value_sync[source] << endl;
362 |
363 |
364 |
365 | GPU->compression_8bit(vec_w_grad_next[GPU->MYRANK],MAX_GRAD_VALUE,vec_w_grad_next_8bit[GPU->MYRANK]);
366 | for (int i = 0; i < GPU->MPI_SIZE - 1; i++)
367 | {
368 | MPI_Isend(vec_w_grad_next_8bit[GPU->MYRANK]->char_data,vec_w_grad_next_8bit[GPU->MYRANK]->size,MPI_CHAR,target,999,MPI_COMM_WORLD, send_request[target]);
369 | MPI_Irecv(vec_w_grad_next_8bit[source]->char_data,vec_w_grad_next_8bit[source]->size,MPI_CHAR,source,999,MPI_COMM_WORLD,recv_request[source]);
370 | target = target +1 == GPU->MPI_SIZE ? 0 : target+1;
371 | source = source-1 == -1 ? GPU->MPI_SIZE-1 : source-1;
372 | }
373 | }
374 | else
375 | {
376 | for (int i = 0; i < GPU->MPI_SIZE - 1; i++)
377 | {
378 | MPI_Isend(vec_w_grad_next[GPU->MYRANK]->data,vec_w_grad_next[GPU->MYRANK]->size,MPI_FLOAT,target,999,MPI_COMM_WORLD, send_request[target]);
379 | MPI_Irecv(vec_w_grad_next[source]->data,vec_w_grad_next[source]->size,MPI_FLOAT,source,999,MPI_COMM_WORLD,recv_request[source]);
380 | target = target +1 == GPU->MPI_SIZE ? 0 : target+1;
381 | source = source-1 == -1 ? GPU->MPI_SIZE-1 : source-1;
382 |
383 | }
384 | }
385 | isSynchronizing = true;
386 |
387 |
388 |
389 |
390 | }
391 |
392 | void Layer::wait_for_synchronization()
393 | {
394 | if(target){ return; }
395 | if(!isSynchronizing){ return; }
396 | if(PARALLELISM != DataParallelism){ return; }
397 | //GPU->tick();
398 | //MPI_Wait(next->send_request,MPI_STATUS_IGNORE);_w_next_sync
399 |
400 | for(int i = 0; i < GPU->MPI_SIZE; i++)
401 | {
402 | if(i== GPU->MYRANK){ continue; }
403 | MPI_Wait(send_request[i],MPI_STATUS_IGNORE);
404 | MPI_Wait(recv_request[i],MPI_STATUS_IGNORE);
405 | }
406 |
407 | //float secs = GPU->tock()/1000.0f;
408 | //cout << w_next_sync->bytes/1024./1024./1024./secs << " GB/s" << endl;
409 | //printdim(w_next_sync);
410 | 	//cout << "pre decompress" << endl;
411 | //GPU->decompression_8bit(w_next_sync_recv,0.001,w_next_sync);
412 | //cout << "post decompress" << endl;
413 |
414 |
415 |
416 |
417 | /*
418 | MPI_Barrier(MPI_COMM_WORLD);
419 | cout << GPU->MYRANK << " " << sum(vec_w_grad_next[0]) << " 0" << endl;
420 | MPI_Barrier(MPI_COMM_WORLD);
421 | cout << GPU->MYRANK << " " << sum(vec_w_grad_next[1]) << " 1" << endl;
422 | MPI_Barrier(MPI_COMM_WORLD);
423 | cout << GPU->MYRANK << " " << sum(vec_w_grad_next[2]) << " 2" << endl;
424 | MPI_Barrier(MPI_COMM_WORLD);
425 | cout << GPU->MYRANK << " " << sum(vec_w_grad_next[3]) << " 3" << endl;
426 | MPI_Barrier(MPI_COMM_WORLD);
427 | */
428 |
429 |
430 | for(int i = 0; i < GPU->MPI_SIZE; i++)
431 | {
432 | if(i == GPU->MYRANK){ continue; }
433 | if(compression == bits_8){ GPU->decompression_8bit(vec_w_grad_next_8bit[i],max_grad_value_sync[i],vec_w_grad_next[i]); }
434 | add(vec_w_grad_next[GPU->MYRANK],vec_w_grad_next[i],vec_w_grad_next[GPU->MYRANK]);
435 | }
436 | isSynchronizing = false;
437 | }
438 |
439 | void Layer::weight_update()
440 | {
441 | if(target){ return; }
442 |
443 | if(PARALLELISM != DataParallelism)
444 | next->weight_update();
445 | //float *data = (float*)malloc(sizeof(float)*100);
446 |
447 | switch(UPDATE_TYPE)
448 | {
449 | case RMSProp:
450 |
451 | //CUDA_CHECK_RETURN(cudaMemcpy(data,vec_w_grad_next[GPU->MYRANK]->data,10*sizeof(float),cudaMemcpyDefault));
452 | //cout << "pre print" << endl;
453 |
454 | //for(int i; i < 100; i++){ cout << data[i] << endl;}
455 | //RMSprop_with_weight_update(w_rms_next,vec_w_grad_next[GPU->MYRANK],w_next,w_next,RMSPROP_MOMENTUM,LEARNING_RATE,out->rows*GPU->MPI_SIZE,MOMENTUM);
456 | RMSprop_with_weight_update(w_rms_next,vec_w_grad_next[GPU->MYRANK],w_next,w_next,RMSPROP_MOMENTUM,LEARNING_RATE,GPU->MPI_SIZE,MOMENTUM);
457 | //cout << "post print" << endl;
458 | //RMSprop_with_weight_update(b_rms_next,b_grad_next,b_next,b_next,RMSPROP_MOMENTUM,LEARNING_RATE/100.0f,out->rows,MOMENTUM);
459 | //scalarMul(b_grad_next, LEARNING_RATE/float(out->rows*GPU->MPI_SIZE) ,b_grad_next);
460 | //sub(b_next,b_grad_next,b_next);
461 |
462 | break;
463 | default:
464 | throw "Unknown update type!";
465 | break;
466 | }
467 | //free(data);
468 |
469 | //limit_magnitude();
470 |
471 | }
472 |
473 | void Layer::limit_magnitude()
474 | {
475 |
476 | square(w_next,vec_w_grad_next[GPU->MYRANK]);
477 | Matrix *temp = ones(vec_w_grad_next[GPU->MYRANK]->cols,1);
478 | Matrix *sums = GPU->dot(vec_w_grad_next[GPU->MYRANK],temp);
479 | renormalizeWeights(w_next,sums,L2);
480 | cudaFree(temp->data);
481 | cudaFree(sums->data);
482 | free(temp);
483 | free(sums);
484 |
485 | }
486 |
487 | void Layer::print_error(string message)
488 | {
489 | if(!target){ next->print_error(message); return;}
490 |
491 | if(GPU->MPI_SIZE > 1)
492 | {
493 | RUNNING_ERROR =reduce_to_sum_root(RUNNING_ERROR);
494 | RUNNING_SAMPLE_SIZE = reduce_to_sum_root(RUNNING_SAMPLE_SIZE);
495 | }
496 |
497 | if(GPU->MYRANK == 0)
498 | cout << message << RUNNING_ERROR/RUNNING_SAMPLE_SIZE << endl;
499 |
500 | RUNNING_ERROR = 0.0f;
501 | RUNNING_SAMPLE_SIZE = 0.0f;
502 | }
503 |
504 |
505 | float Layer::reduce_to_sum_root(float value)
506 | {
507 |
508 | MPI_Gather(&value, 1, MPI_FLOAT, mpi_buffer, 1, MPI_FLOAT, 0, MPI_COMM_WORLD);
509 | for(int i = 1; i < GPU->MPI_SIZE; i++)
510 | mpi_buffer[0] += mpi_buffer[i];
511 |
512 |
513 | return mpi_buffer[0];
514 |
515 | }
516 |
517 | void Layer::set_hidden_dropout(float dropout)
518 | {
519 | if(!next){ return; }
520 | next->DROPOUT = dropout;
521 | next->set_hidden_dropout(dropout);
522 | }
523 |
524 | void Layer::learning_rate_decay(float decay_rate)
525 | {
526 | if(!next){ return; }
527 | next->LEARNING_RATE *= decay_rate;
528 | next->learning_rate_decay(decay_rate);
529 | }
530 |
531 | void Layer::dropout_decay()
532 | {
533 | if(!prev){ cout << "Decaying dropout!" << endl; }
534 | if(!next){ return;}
535 |
536 | cout << "Setting dropout from " << DROPOUT << " to " << DROPOUT/2.0f << endl;
537 | DROPOUT /= 2.0f;
538 | next->dropout_decay();
539 | }
540 |
541 | Layer::~Layer()
542 | {
543 | cout << "destruct" << endl;
544 | }
545 |
546 |
547 |
--------------------------------------------------------------------------------
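MPI_synchronization_async exchanges gradients with a ring of point-to-point transfers: in each of MPI_SIZE - 1 rounds, every rank sends its own gradient buffer to a rotating target and receives a peer's buffer from a rotating source, so that after wait_for_synchronization each rank holds (and sums) all peers' gradients. A stripped-down, self-contained version of the same pattern on plain float buffers, with a blocking wait instead of the layer's deferred one:

    #include <mpi.h>
    #include <vector>

    // bufs[r] is rank r's gradient slot; each rank fills bufs[myrank] before the call.
    void ring_exchange(std::vector<std::vector<float> > &bufs, int myrank, int mpi_size)
    {
        int target = myrank + 1 == mpi_size ? 0 : myrank + 1;
        int source = myrank - 1 == -1 ? mpi_size - 1 : myrank - 1;

        std::vector<MPI_Request> reqs;
        for(int i = 0; i < mpi_size - 1; i++)
        {
            MPI_Request s, r;
            MPI_Isend(bufs[myrank].data(), (int)bufs[myrank].size(), MPI_FLOAT,
                      target, 999, MPI_COMM_WORLD, &s);
            MPI_Irecv(bufs[source].data(), (int)bufs[source].size(), MPI_FLOAT,
                      source, 999, MPI_COMM_WORLD, &r);
            reqs.push_back(s);
            reqs.push_back(r);
            target = target + 1 == mpi_size ? 0 : target + 1;    // rotate peers
            source = source - 1 == -1 ? mpi_size - 1 : source - 1;
        }
        MPI_Waitall((int)reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
    }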
/source/Layer.h:
--------------------------------------------------------------------------------
1 | #ifndef Layer_H
2 | #define Layer_H
3 | #include <clusterNet.h> // NOTE: include names reconstructed from usage
4 | #include <basicOps.cuh>
5 | #include <vector>
6 | #include <map>
7 | #include <string>
8 |
9 | #define CUDA_CHECK_RETURN(value) { \
10 | cudaError_t _m_cudaStat = value; \
11 | if (_m_cudaStat != cudaSuccess) { \
12 | fprintf(stderr, "Error %s at line %d in file %s\n", \
13 | cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
14 | exit(1); \
15 | } }
16 |
17 | class Layer
18 | {
19 | public:
20 | Matrix *b_grad_next;
21 | Layer *next;
22 | Layer *prev;
23 | Matrix *w_next;
24 | Matrix *b_next;
25 |
26 | 		std::vector<Matrix*> vec_w_grad_next;
27 | 		std::vector<Matrix*> vec_w_grad_next_8bit;
28 | Matrix *b_next_sync;
29 | Matrix *w_next_sync_recv;
30 | Matrix *b_next_sync_recv;
31 | Matrix *w_next_sync_send;
32 | Matrix *w_next_abs_max_buffer;
33 | Matrix *b_next_sync_send;
34 |
35 | Matrix *w_rms_next;
36 | Matrix *b_rms_next;
37 |
38 | Matrix *bias_activations;
39 | Matrix *out;
40 | Matrix *error;
41 | Matrix *activation;
42 |
43 | float *mpi_buffer;
44 |
45 | Matrix *out_offsize;
46 | Matrix *activation_offsize;
47 | Matrix *error_offsize;
48 | Matrix *bias_activations_offsize;
49 | Matrix *target_matrix_offsize;
50 |
51 | Matrix *target;
52 | Matrix *target_matrix;
53 |
54 | 		std::vector<MPI_Request*> send_request;
55 | 		std::vector<MPI_Request*> recv_request;
56 |
57 | 		std::map<int, std::vector<float> > CV_errors;
58 | 		std::map<int, std::vector<float> > Train_errors;
59 |
60 | ClusterNet *GPU;
61 |
62 | int count;
63 |
64 | float LEARNING_RATE;
65 | float MOMENTUM;
66 | float RMSPROP_MOMENTUM;
67 | float RUNNING_ERROR;
68 | float RUNNING_SAMPLE_SIZE;
69 | float L2;
70 | Unittype_t UNIT_TYPE;
71 | Costfunction_t COST;
72 | float DROPOUT;
73 | int UNITCOUNT;
74 | int BATCH_SIZE;
75 |
76 | bool isSynchronizing;
77 |
78 | float MAX_GRAD_VALUE;
79 | float *max_grad_value_sync;
80 |
81 | Compression_t compression;
82 |
83 | WeightUpdateType_t UPDATE_TYPE;
84 |
85 | ParallelismType_t PARALLELISM;
86 |
87 | virtual ~Layer();
88 | Layer(int unitcount, int start_batch_size, Unittype_t unit, ClusterNet *gpu);
89 | Layer(int unitcount, Unittype_t unit);
90 | Layer(int unitcount);
91 |
92 | Layer(int unitcount, int start_batch_size, Unittype_t unit, Layer *prev, ClusterNet *gpu);
93 | Layer(int unitcount, Unittype_t unit, Layer *prev);
94 | Layer(int unitcount, Layer *prev);
95 |
96 | virtual void forward();
97 | virtual void forward(bool useDropout);
98 | virtual void running_error(bool isCV, int epoch);
99 | virtual void backward_errors();
100 | virtual void backward_grads();
101 | virtual void print_error(std::string message);
102 | virtual void weight_update();
103 |
104 | virtual void MPI_synchronization_async();
105 | virtual void wait_for_synchronization();
106 |
107 | virtual void limit_magnitude();
108 |
109 | virtual void link_with_next_layer(Layer *next_layer);
110 | virtual void init(int unitcount, int start_batch_size, Unittype_t unit, ClusterNet *gpu);
111 | virtual void set_hidden_dropout(float dropout);
112 |
113 | virtual void dropout_decay();
114 | virtual void learning_rate_decay(float decay_rate);
115 |
116 | float reduce_to_sum_root(float value);
117 |
118 |
119 |
120 | private:
121 | virtual void unit_activation();
122 | virtual void unit_activation(bool useDropout);
123 | virtual void activation_gradient();
124 | void handle_offsize();
125 |
126 |
127 | };
128 |
129 | #endif
130 |
--------------------------------------------------------------------------------
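The prev-pointer constructors call link_with_next_layer, so a network is assembled by chaining layers; allocation of the weights, RMS buffers, and sync buffers happens inside the link step. A construction-only sketch (feeding batches and targets is handled elsewhere and is not shown in this excerpt):

    #include <Layer.h>

    void build_net(ClusterNet *gpu)
    {
        Layer input(784, 128, Input, gpu);    // input layer: 784 units, batch size 128
        Layer h1(1024, Logistic, &input);     // links itself to input via prev
        Layer h2(1024, Logistic, &h1);
        Layer softmax(10, Softmax, &h2);
        // After copying a batch into input.activation and setting softmax.target,
        // input.forward() walks the whole chain.
    }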
/source/WikiMaxoutNet.cpp:
--------------------------------------------------------------------------------
1 | #include <WikiMaxoutNet.h> // NOTE: include names reconstructed from usage
2 | #include <ctime>
3 | using std::cout;
4 | using std::endl;
5 |
6 | WikiMaxoutNet::WikiMaxoutNet(ClusterNet gpus)
7 | {
8 |
9 | int vocabSize = 100002;
10 | int nWordVectorDim = 120;
11 | int nWindowSize = 11;
12 | _layers.push_back(128);
13 | _learningRate = 0.1;
14 | _nCVErrorPeriodicity = 6000;
15 | _nCVErrorLength = 6000;
16 | MOMENTUM = 0.9;
17 | gpu = gpus;
18 | _nCurrentDataSet = gpu.MYRANK;
19 | _X = 0;
20 | int cv_set_number = 63;
21 | _CV_X = read_hdf5(("/home/tim/data/wiki/extracted2/AA/data100000/wiki_" + NumberToString(cv_set_number) + ".p").c_str());
22 | _nNextBatchNumber = 0;
23 | _nNextBatchNumber_CV = gpu.MYRANK*_nCVErrorLength;
24 | _nBatchSize = 512;
25 | _RMS_multiplier = 0.9f;
26 |
27 |
28 | cudaStreamCreate(&_streamNextBatch);
29 |
30 | Matrix *learning_rate_matrix_cpu = empty_cpu(nWordVectorDim,vocabSize);
31 |
32 | float learning_rate = 0.0000001;
33 | int next_level = 2000;
34 | for(int col = 0; col < vocabSize; col++)
35 | for(int row = 0; row < nWordVectorDim; row++)
36 | {
37 | if(col > next_level)
38 | {
39 | learning_rate = learning_rate * 10.00f;
40 | next_level = next_level == 50000 ? vocabSize : next_level;
41 | next_level = next_level == 25000 ? 50000 : next_level;
42 | next_level = next_level == 10000 ? 25000 : next_level;
43 | next_level = next_level == 2000 ? 10000 : next_level;
44 | }
45 |
46 | if((col == vocabSize-2) || (col == vocabSize-1))
47 | {
48 | learning_rate_matrix_cpu->data[col + (row*vocabSize)] = 0.0000001;
49 | }
50 | else
51 | {
52 | learning_rate_matrix_cpu->data[col + (row*vocabSize)] = learning_rate;
53 | }
54 | }
55 |
56 |
57 | learning_rate_matrix = to_gpu(learning_rate_matrix_cpu);
58 | free(learning_rate_matrix_cpu->data);
59 |
60 |
61 | useRMSProp = true;
62 |
63 | cout << "_layers: " << _layers[0] << endl;
64 | cout << "nWordVectorDim: " << nWordVectorDim << endl;
65 | cout << "_nBatchSize: " << _nBatchSize << endl;
66 | cout << "_learningRate: " << _learningRate << endl;
67 | cout << "Use RMSProp: " << useRMSProp << endl;
68 |
69 | W.push_back(gpu.uniformSqrtWeight(nWordVectorDim*nWindowSize,_layers[0]));
70 | W.push_back(gpu.uniformSqrtWeight(_layers[0], 1));
71 | B.push_back(zeros(1,_layers[0]));
72 | B.push_back(zeros(1,1));
73 | M.push_back(zeros(nWordVectorDim*nWindowSize,_layers[0]));
74 | M.push_back(zeros(_layers[0], 1));
75 | M_B.push_back(zeros(1,_layers[0]));
76 | M_B.push_back(zeros(1,1));
77 |
78 |
79 |
80 | for(int i = 0; i < W.size(); i++)
81 | {
82 | cout << sum(W[i]) << endl;
83 | MPI_Barrier(MPI_COMM_WORLD);
84 | }
85 |
86 | CV_container = empty_cpu(10000,1);
87 | for(int i = 0; i < CV_container->size; i++)
88 | CV_container->data[i] = 0.0f;
89 |
90 |
91 | if(gpu.MPI_SIZE == 0)
92 | gpu.MPI_SIZE = 1;
93 |
94 |
95 | cout << gpu.MPI_SIZE << " MPI SIZE" << endl;
96 | cout << gpu.MYRANK << " MYRANK " << endl;
97 | for(int i = W.size()-1; i >= 0; i--)
98 | {
99 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
100 | arrGRAD.push_back(gradX);
101 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
102 | arrGRAD.push_back(gradY);
103 | Matrix **gradX_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
104 | arrGRAD_B.push_back(gradX_B);
105 | Matrix **gradY_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
106 | arrGRAD_B.push_back(gradY_B);
107 | }
108 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
109 | arrGRAD.push_back(gradX);
110 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
111 | arrGRAD.push_back(gradY);
112 |
113 | cout << arrGRAD.size() << " size" << endl;
114 |
115 |
116 | for(int i = W.size()-1; i >= 0; i--)
117 | {
118 | MSGRAD.push_back(zeros(W[i]->rows, W[i]->cols));
119 | MSGRAD.push_back(zeros(W[i]->rows, W[i]->cols));
120 | MSBGRAD.push_back(zeros(B[i]->rows, B[i]->cols));
121 | MSBGRAD.push_back(zeros(B[i]->rows, B[i]->cols));
122 | }
123 |
124 | for(int j =0; j < gpu.MPI_SIZE; j++)
125 | {
126 | int idx = 0;
127 | for(int i = W.size()-1; i >= 0; i--)
128 | {
129 | arrGRAD[idx][j] = zeros(W[i]->rows, W[i]->cols);
130 | arrGRAD_B[idx][j] = zeros(B[i]->rows, B[i]->cols);
131 | idx++;
132 | arrGRAD[idx][j] = (zeros(W[i]->rows, W[i]->cols));
133 | arrGRAD_B[idx][j] = zeros(B[i]->rows, B[i]->cols);
134 | idx++;
135 | }
136 |
137 | arrGRAD[4][j] = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
138 | arrGRAD[5][j] = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
139 | }
140 |
141 |
142 |
143 | stackedVocabGrad_X = zeros(_nBatchSize*gpu.MPI_SIZE,nWordVectorDim*nWindowSize);
144 | stackedVocabGrad_Y = zeros(_nBatchSize*gpu.MPI_SIZE,nWordVectorDim*nWindowSize);
145 | stackedBatchIdx_X = zeros(_nBatchSize*gpu.MPI_SIZE,nWindowSize);
146 | stackedBatchIdx_Y = zeros(_nBatchSize*gpu.MPI_SIZE,nWindowSize);
147 | _Vocab = gpu.uniformSqrtWeight(nWordVectorDim,vocabSize);
148 | //_Vocab = gpu.sparseInitWeight(nWordVectorDim,vocabSize);
149 | //_Vocab = gpu.rand(nWordVectorDim,vocabSize);
150 | //scalarMul(_Vocab,0.01f,_Vocab);
151 | //scalarAdd(_Vocab,-0.5f,_Vocab);
152 | cout << sum(_Vocab) << endl;
153 | _Vocab_grad = zeros(nWordVectorDim,vocabSize);
154 | _MSVocab_grad = zeros(nWordVectorDim,vocabSize);
155 | _MSVocab_grad_Y = zeros(nWordVectorDim,vocabSize);
156 | M_VocabX = zeros(nWordVectorDim,vocabSize);
157 | M_VocabY = zeros(nWordVectorDim,vocabSize);
158 | _Vocab_grad_idx = zeros(1,vocabSize);
159 |
160 | d0 = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
161 | z1 = zeros(_nBatchSize, _layers[0]);
162 | a1_Y = zeros(_nBatchSize, _layers[0]);
163 | a1_idx_Y = zeros(_nBatchSize, _layers[0]);
164 | a1_X = zeros(_nBatchSize, _layers[0]);
165 | a1_idx_X = zeros(_nBatchSize, _layers[0]);
166 | d1 = zeros(_nBatchSize, _layers[0]);
167 | z2_X = zeros(_nBatchSize, 1);
168 | z2_Y = zeros(_nBatchSize, 1);
169 |
170 | out = zeros(_nBatchSize,1);
171 | pairwise_grad = zeros(_nBatchSize,1);
172 | e1 = empty(_nBatchSize,1);
173 | aB = ones(1,_nBatchSize);
174 | e2_partial = zeros(_nBatchSize,W[1]->rows);
175 | e2 = empty(_nBatchSize,e2_partial->cols);
176 |
177 |
178 | _batchX = zeros(_nBatchSize, nWordVectorDim*nWindowSize);
179 | _batchY = zeros(_nBatchSize, nWordVectorDim*nWindowSize);
180 |
181 | _currentBatchIdx_X = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
182 | _currentBatchIdx_Y = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
183 | for(int i = 0; i < gpu.MPI_SIZE; i++)
184 | {
185 | _currentBatchIdx_X[i] = zeros(_nBatchSize, nWindowSize);
186 | _currentBatchIdx_Y[i] = zeros(_nBatchSize, nWindowSize);
187 | }
188 | _nextBatchIdx = zeros(_nBatchSize,nWindowSize);
189 |
190 | _dSumError = 0.0;
191 |
192 | loadNextDataSet();
193 |
194 |
195 | }
196 | void WikiMaxoutNet::run()
197 | {
198 | allocateNextBatch(false);
199 |
200 | size_t freemem, total;
201 | cudaMemGetInfo(&freemem,&total);
202 | cout << freemem << endl;
203 |
204 |
205 | srand( time(NULL) );
206 | start = clock();
207 |
208 | int i = 0;
209 | while(true)
210 | {
211 | if(i > 0 && i % _nCVErrorPeriodicity == 0)
212 | {
213 | if( i > 0 && i % 12000 == 0)
214 | {
215 | cout << "Saving vocabulary matrix to disk..." << endl;
216 | Matrix *host = to_host(_Vocab);
217 | write_hdf5("/home/tim/data/wiki/vocab.hdf5",host);
218 | free(host->data);
219 | free(host);
220 | write_hdf5("/home/tim/data/wiki/CV_values.hdf5",CV_container);
221 | }
222 |
223 | double error = calculateError();
224 | CV_container->data[i/_nCVErrorPeriodicity] = (float)error;
225 | cout << "BatchNo: " << i << endl;
226 | cout << "Cross validation error: " << error << endl;
227 | i+=gpu.MPI_SIZE;
228 |
229 | //MOMENTUM+= 0.01;
230 | //if( MOMENTUM > 0.95)
231 | //MOMENTUM = 0.95;
232 |
233 | _RMS_multiplier-= 0.01;
234 | if( _RMS_multiplier < 0.25)
235 | _RMS_multiplier = 0.25;
236 |
237 |
238 | MOMENTUM-= 0.01;
239 | if( MOMENTUM < 0.25)
240 | MOMENTUM = 0.25;
241 |
242 | stop = clock();
243 |
244 | double time_interval_seconds = (double((stop - start)) / CLOCKS_PER_SEC) ;
245 | cout << "Approximate time left in hours: " << ((1.0f/(((i*_nBatchSize)/(float)_X->rows)/63.0))*time_interval_seconds/(float)3600.0) -
246 | (time_interval_seconds/(float)3600.0)<< endl;
247 |
248 | }
249 | else
250 | {
251 | nesterov();
252 | feedforward();
253 | backprop();
254 | weightUpdates();
255 | }
256 | //cout << i << endl;
257 | allocateNextBatch(false);
258 | i+=gpu.MPI_SIZE;
259 | }
260 | }
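
The "Approximate time left" print in run() is a linear extrapolation: with 63 data-set files of _X->rows windows each, the fraction of work done is (i*_nBatchSize/_X->rows)/63, and the remaining time is elapsed/fraction minus elapsed. A standalone sketch of the same arithmetic (function and parameter names are illustrative, not part of this codebase):

    #include <cstdio>

    // projected hours remaining under linear extrapolation; mirrors run()
    double hours_left(double elapsed_seconds, long long batches_seen,
                      int batch_size, long long rows_per_file, int n_files)
    {
        double fraction_done =
            ((double)batches_seen * batch_size / rows_per_file) / n_files;
        return (elapsed_seconds / fraction_done - elapsed_seconds) / 3600.0;
    }

    int main()
    {
        printf("%.1f h\n", hours_left(7200.0, 1280, 128, 1000000, 63));
        return 0;
    }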
261 |
262 | void WikiMaxoutNet::loadNextDataSet()
263 | {
264 |
265 |
266 | std::string path = "/home/tim/data/wiki/extracted2/AA/data100000/wiki_";
267 | //std::string path = "/home/tim/data/wiki/extracted2/AA/data/wiki_";
268 | std::string number = "";
269 | std::string ending = ".p";
270 | //std::string ending = ".p.hdf5";
271 |
272 | if(_nCurrentDataSet < 10)
273 | number += "0";
274 |
275 | number+= NumberToString(_nCurrentDataSet);
276 |
277 | cout << "Loading next data set: " << (path + number + ending) << endl;
278 | if(_X != 0)
279 | cudaFreeHost(_X->data);
280 | _X = read_hdf5((path + number + ending).c_str());
281 | _nCurrentDataSet += gpu.MPI_SIZE;
282 | _batches = _X->rows/ _nBatchSize;
283 | _nNextBatchNumber = 0;
284 | }
285 |
286 | void WikiMaxoutNet::allocateNextBatch(bool isCV)
287 | {
288 | if(!isCV)
289 | {
290 | if(_nNextBatchNumber < 0)
291 | _nNextBatchNumber = 0;
292 |
293 | if (_nBatchSize*11*(_nNextBatchNumber+1) > _X->size)
294 | loadNextDataSet();
295 |
296 | if(_nNextBatchNumber > 0 || _nCurrentDataSet > gpu.MYRANK)
297 | {
298 | cudaStreamSynchronize(_streamNextBatch);
299 | to_col_major(_nextBatchIdx, _currentBatchIdx_X[gpu.MYRANK]);
300 | gpu.construct_vocab_matrix(_currentBatchIdx_X[gpu.MYRANK], _currentBatchIdx_Y[gpu.MYRANK], _batchX, _batchY, _Vocab);
301 |
302 | gpu.add_to_queue(_currentBatchIdx_X);
303 | gpu.add_to_queue(_currentBatchIdx_Y);
304 | }
305 |
306 |
307 |
308 | cudaMemcpyAsync(_nextBatchIdx->data,&_X->data[_nBatchSize*11*_nNextBatchNumber],
309 | _nBatchSize*11*sizeof(float),
310 | cudaMemcpyHostToDevice,_streamNextBatch);
311 |
312 |
313 | _nNextBatchNumber+=1;
314 | }
315 | else
316 | {
317 | if(_nNextBatchNumber_CV > gpu.MYRANK*_nCVErrorLength)
318 | {
319 | cudaStreamSynchronize(_streamNextBatch);
320 | to_col_major(_nextBatchIdx, _currentBatchIdx_X[gpu.MYRANK]);
321 | gpu.construct_vocab_matrix(_currentBatchIdx_X[gpu.MYRANK], _currentBatchIdx_Y[gpu.MYRANK], _batchX, _batchY, _Vocab);
322 | }
323 |
324 | cudaMemcpyAsync(_nextBatchIdx->data,&_CV_X->data[_nBatchSize*11*_nNextBatchNumber_CV],
325 | _nBatchSize*11*sizeof(float),
326 | cudaMemcpyHostToDevice,_streamNextBatch);
327 |
328 |
329 | _nNextBatchNumber_CV+=1;
330 | }
331 |
332 |
333 | }
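
allocateNextBatch() is the double-buffering half of the pipeline: the host-to-device copy of batch i+1 is issued on _streamNextBatch while batch i is being consumed, and cudaStreamSynchronize() guards the buffer before reuse. A minimal CUDA sketch of the same pattern, assuming pinned host memory and omitting error checks (all names illustrative):

    #include <cuda_runtime.h>

    // prefetch batch i+1 on a side stream while batch i is consumed
    void train_epoch(const float *h_pinned, float *d_buf[2],
                     int n_batches, size_t batch_floats, cudaStream_t side)
    {
        cudaMemcpyAsync(d_buf[0], h_pinned, batch_floats * sizeof(float),
                        cudaMemcpyHostToDevice, side);          // batch 0
        for (int i = 0; i < n_batches; i++)
        {
            cudaStreamSynchronize(side);                        // batch i ready
            if (i + 1 < n_batches)                              // queue batch i+1
                cudaMemcpyAsync(d_buf[(i + 1) % 2],
                                h_pinned + (size_t)(i + 1) * batch_floats,
                                batch_floats * sizeof(float),
                                cudaMemcpyHostToDevice, side);
            // ... launch kernels that read d_buf[i % 2] on the default stream ...
        }
    }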
334 |
335 | void WikiMaxoutNet::nesterov()
336 | {
337 | //nesterov
338 | for(int i = 0;i < M.size(); i++)
339 | {
340 | scalarMul(M[i],MOMENTUM,M[i]);
341 | add(W[i],M[i],W[i]);
342 | }
343 |
344 | for(int i = 0;i < M_B.size(); i++)
345 | {
346 | scalarMul(M_B[i],MOMENTUM,M_B[i]);
347 | add(B[i],M_B[i],B[i]);
348 | }
349 |
350 | scalarMul(M_VocabX, MOMENTUM, M_VocabX);
351 | add(_Vocab,M_VocabX,_Vocab);
352 |
353 | scalarMul(M_VocabY, MOMENTUM, M_VocabY);
354 | add(_Vocab,M_VocabY,_Vocab);
355 |
356 | }
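
nesterov() performs the look-ahead half of Nesterov momentum: the velocity buffers (M, M_B, M_VocabX/Y) are decayed by MOMENTUM and added to the parameters before the gradient is evaluated; weightUpdates() later subtracts the gradient step from both the parameters and the velocities. A self-contained scalar sketch of the scheme, minimizing (w-3)^2 (names illustrative):

    #include <cstdio>

    int main()
    {
        float w = 0.0f;                      // weight
        float v = 0.0f;                      // velocity (the M buffers)
        const float momentum = 0.5f, lr = 0.01f;
        for (int step = 0; step < 2000; step++)
        {
            v *= momentum;                   // scalarMul(M, MOMENTUM, M)
            w += v;                          // add(W, M, W): look-ahead
            float g = 2.0f * (w - 3.0f);     // gradient at the shifted point
            w -= lr * g;                     // parameter step (weightUpdates)
            v -= lr * g;                     // velocity absorbs the same step
        }
        printf("w = %f\n", w);               // approaches the minimum at 3.0
        return 0;
    }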
357 |
358 |
359 | void WikiMaxoutNet::feedforward()
360 | {
361 | gpu.dot(_batchX,W[0],z1);
362 | addMatrixVector(z1,B[0],z1);
363 | logistic(z1,a1_X);
364 | gpu.dot(a1_X,W[1],z2_X);
365 | addMatrixVector(z2_X,B[1],z2_X);
366 |
367 | gpu.dot(_batchY,W[0],z1);
368 | addMatrixVector(z1,B[0],z1);
369 | logistic(z1,a1_Y);
370 | gpu.dot(a1_Y,W[1],z2_Y);
371 | addMatrixVector(z2_Y,B[1],z2_Y);
372 |
373 | }
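
feedforward() scores the observed window _batchX and the corrupted window _batchY through the same one-hidden-layer net: z1 = X·W[0] + B[0], a1 = logistic(z1), z2 = a1·W[1] + B[1]. A plain-CPU sketch of the per-example computation (illustrative names):

    #include <cmath>
    #include <vector>

    // score one example x through a logistic hidden layer and a linear output
    float score(const std::vector<float> &x,
                const std::vector<std::vector<float>> &W0,  // n_in x n_hidden
                const std::vector<float> &b0,
                const std::vector<float> &W1, float b1)     // n_hidden
    {
        std::vector<float> a1(W0[0].size());
        for (size_t j = 0; j < a1.size(); j++)
        {
            float z = b0[j];
            for (size_t i = 0; i < x.size(); i++) z += x[i] * W0[i][j];
            a1[j] = 1.0f / (1.0f + std::exp(-z));           // logistic(z1, a1)
        }
        float z2 = b1;
        for (size_t j = 0; j < a1.size(); j++) z2 += a1[j] * W1[j];
        return z2;                                          // the z2_X / z2_Y score
    }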
374 |
375 | void WikiMaxoutNet::weightUpdates()
376 | {
377 | float multiplier = _learningRate/(float)_nBatchSize;
378 |
379 | if(!useRMSProp)
380 | {
381 | scalarMul(M_VocabX,MOMENTUM,M_VocabX);
382 | scalarMul(M_VocabY,MOMENTUM,M_VocabY);
383 | for(int i = 0; i < M.size(); i++)
384 | scalarMul(M[i],MOMENTUM,M[i]);
385 | for(int i = 0; i < M_B.size(); i++)
386 | scalarMul(M_B[i],MOMENTUM,M_B[i]);
387 |
388 | while(gpu.get_queue_length() > (9*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
389 | vStackN(arrGRAD[4],stackedVocabGrad_Y,gpu.MPI_SIZE);
390 | update_vocab_with_gradient(stackedVocabGrad_Y,stackedBatchIdx_Y,_Vocab,multiplier/(float)gpu.MPI_SIZE);
391 | update_vocab_with_gradient(stackedVocabGrad_Y,stackedBatchIdx_Y,M_VocabY,multiplier/(float)gpu.MPI_SIZE);
392 | while(gpu.get_queue_length() > (8*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
393 | 		addGradientsN(arrGRAD[2],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)(gpu.MPI_SIZE*arrGRAD[2][gpu.MYRANK]->rows));
394 | sub(W[0],arrGRAD[2][gpu.MYRANK],W[0]);
395 | sub(M[0],arrGRAD[2][gpu.MYRANK],M[0]);
396 | while(gpu.get_queue_length() > (7*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
397 | vStackN(arrGRAD[5],stackedVocabGrad_X,gpu.MPI_SIZE);
398 | update_vocab_with_gradient(stackedVocabGrad_X,stackedBatchIdx_X,_Vocab,multiplier/(float)gpu.MPI_SIZE);
399 | update_vocab_with_gradient(stackedVocabGrad_X,stackedBatchIdx_X,M_VocabX,multiplier/(float)gpu.MPI_SIZE);
400 | while(gpu.get_queue_length() > (6*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
401 | addGradientsN(arrGRAD[3],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)(gpu.MPI_SIZE*arrGRAD[3][gpu.MYRANK]->rows));
402 | sub(W[0],arrGRAD[3][gpu.MYRANK],W[0]);
403 | sub(M[0],arrGRAD[3][gpu.MYRANK],M[0]);
404 | while(gpu.get_queue_length() > (5*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
405 | addGradientsN(arrGRAD[0],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)(gpu.MPI_SIZE*arrGRAD[0][gpu.MYRANK]->rows));
406 | sub(W[1],arrGRAD[0][gpu.MYRANK],W[1]);
407 | sub(M[1],arrGRAD[0][gpu.MYRANK],M[1]);
408 | while(gpu.get_queue_length() > (4*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
409 | addGradientsN(arrGRAD[1],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)(gpu.MPI_SIZE*arrGRAD[1][gpu.MYRANK]->rows));
410 | sub(W[1],arrGRAD[1][gpu.MYRANK],W[1]);
411 | sub(M[1],arrGRAD[1][gpu.MYRANK],M[1]);
412 | while(gpu.get_queue_length() > (3*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
413 | addGradientsN(arrGRAD_B[2],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)gpu.MPI_SIZE);
414 | sub(B[0],arrGRAD_B[2][gpu.MYRANK],B[0]);
415 | sub(M_B[0],arrGRAD_B[2][gpu.MYRANK],M_B[0]);
416 | while(gpu.get_queue_length() > (2*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
417 | addGradientsN(arrGRAD_B[0],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)gpu.MPI_SIZE);
418 | sub(B[1],arrGRAD_B[0][gpu.MYRANK],B[1]);
419 | sub(M_B[1],arrGRAD_B[0][gpu.MYRANK],M_B[1]);
420 | while(gpu.get_queue_length() > (1*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
421 | addGradientsN(arrGRAD_B[1],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)gpu.MPI_SIZE);
422 | sub(B[1],arrGRAD_B[1][gpu.MYRANK],B[1]);
423 | sub(M_B[1],arrGRAD_B[1][gpu.MYRANK],M_B[1]);
424 | while(gpu.get_queue_length() > 0){ gpu.pop_queue(); }
425 | addGradientsN(arrGRAD_B[3],gpu.MYRANK,gpu.MPI_SIZE,multiplier/(float)gpu.MPI_SIZE);
426 | sub(B[0],arrGRAD_B[3][gpu.MYRANK],B[0]);
427 | sub(M_B[0],arrGRAD_B[3][gpu.MYRANK],M_B[0]);
428 | }
429 | else
430 | {
431 |
432 | //10*MPI_SIZE gradients added
433 |
434 | fill_matrix(_Vocab_grad,0.0f);
435 | while(gpu.get_queue_length() > (9*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
436 | vStackN(arrGRAD[4],stackedVocabGrad_Y,gpu.MPI_SIZE);
437 | expand_vocab_gradient(stackedVocabGrad_Y,stackedBatchIdx_Y,_Vocab_grad);
438 |
439 |
440 |
441 |
442 | RMSprop_with_nesterov_weight_update(_MSVocab_grad_Y,_Vocab_grad,_Vocab,M_VocabY,_RMS_multiplier,_learningRate/(float)_nBatchSize,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
443 | while(gpu.get_queue_length() > (8*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
444 | addGradientsN(arrGRAD[2],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
445 | RMSprop_with_nesterov_weight_update(MSGRAD[2],arrGRAD[2][gpu.MYRANK],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[2][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
446 |
447 | fill_matrix(_Vocab_grad,0.0f);
448 | while(gpu.get_queue_length() > (7*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
449 | vStackN(arrGRAD[5],stackedVocabGrad_X,gpu.MPI_SIZE);
450 | expand_vocab_gradient(stackedVocabGrad_X,stackedBatchIdx_X,_Vocab_grad);
451 | RMSprop_with_nesterov_weight_update(_MSVocab_grad,_Vocab_grad,_Vocab,M_VocabX,_RMS_multiplier,_learningRate/(float)_nBatchSize,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
452 |
453 |
454 | while(gpu.get_queue_length() > (6*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
455 | addGradientsN(arrGRAD[3],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
456 | RMSprop_with_nesterov_weight_update(MSGRAD[3],arrGRAD[3][gpu.MYRANK],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[3][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
457 |
458 | while(gpu.get_queue_length() > (5*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
459 | addGradientsN(arrGRAD[0],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
460 | RMSprop_with_nesterov_weight_update(MSGRAD[0],arrGRAD[0][gpu.MYRANK],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[0][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
461 | while(gpu.get_queue_length() > (4*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
462 | addGradientsN(arrGRAD[1],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
463 | RMSprop_with_nesterov_weight_update(MSGRAD[1],arrGRAD[1][gpu.MYRANK],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[1][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
464 | while(gpu.get_queue_length() > (3*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
465 | addGradientsN(arrGRAD_B[2],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
466 | RMSprop_with_nesterov_weight_update(MSBGRAD[2],arrGRAD_B[2][gpu.MYRANK],B[0],M_B[0],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
467 | while(gpu.get_queue_length() > (2*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
468 | addGradientsN(arrGRAD_B[0],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
469 | RMSprop_with_nesterov_weight_update(MSBGRAD[0],arrGRAD_B[0][gpu.MYRANK],B[1],M_B[1],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
470 | while(gpu.get_queue_length() > (1*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
471 | addGradientsN(arrGRAD_B[1],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
472 | RMSprop_with_nesterov_weight_update(MSBGRAD[1],arrGRAD_B[1][gpu.MYRANK],B[1],M_B[1],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
473 | while(gpu.get_queue_length() > 0){ gpu.pop_queue(); }
474 | addGradientsN(arrGRAD_B[3],gpu.MYRANK,gpu.MPI_SIZE,1.0f/(float)gpu.MPI_SIZE);
475 | RMSprop_with_nesterov_weight_update(MSBGRAD[3],arrGRAD_B[3][gpu.MYRANK],B[0],M_B[0],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
476 |
477 | /*
478 |
479 | for(int i = 0; i < arrGRAD.size(); i++)
480 | {
481 | cout << "G: " << i << " " << sum(arrGRAD[i][gpu.MYRANK]) << endl;
482 | MPI_Barrier(MPI_COMM_WORLD);
483 | }
484 |
485 | for(int i = 0; i < arrGRAD_B.size(); i++)
486 | {
487 | cout << "B: " << i << " " << sum(arrGRAD_B[i][gpu.MYRANK]) << endl;
488 | MPI_Barrier(MPI_COMM_WORLD);
489 | }
490 | */
491 |
492 | }
493 |
494 |
495 | MPI_Barrier(MPI_COMM_WORLD);
496 |
497 | }
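
The while(gpu.get_queue_length() > k*(gpu.MPI_SIZE-1)) loops interleave communication with computation: backprop() queued ten gradient groups and each group produces MPI_SIZE-1 pending transfers, so draining the queue to 9*(MPI_SIZE-1) guarantees the first group has fully arrived, 8*(MPI_SIZE-1) the second, and so on. The update itself combines RMSprop scaling with the Nesterov velocity; a hedged per-element sketch of what RMSprop_with_nesterov_weight_update presumably computes (the actual kernel lives in clusterKernels.cu):

    #include <cmath>

    // one scalar parameter: ms = running mean square, w = weight, m = velocity
    void rmsprop_nesterov_step(float &ms, float &w, float &m,
                               float grad, float rms_decay, float lr)
    {
        ms = rms_decay * ms + (1.0f - rms_decay) * grad * grad; // EMA of grad^2
        float step = lr * grad / (std::sqrt(ms) + 1.0e-8f);     // normalized step
        w -= step;                                              // weight update
        m -= step;        // velocity keeps the step for the next look-ahead
    }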
498 |
499 | void WikiMaxoutNet::backprop()
500 | {
501 | pairwise_ranking(z2_X,z2_Y, out);
502 | pairwise_ranking_derivative(z2_X,z2_Y, pairwise_grad);
503 |
504 | mul(out, pairwise_grad, e1);
505 | gpu.dotT(e1, W[1],e2_partial);
506 |
507 | gpu.dot(aB,e1,arrGRAD_B[0][gpu.MYRANK]);
508 | gpu.tick();
509 | gpu.Tdot(a1_Y,e1,arrGRAD[0][gpu.MYRANK]);
510 |
511 | logisticGrad(a1_Y,a1_Y);
512 | mul(e2_partial,a1_Y,e2);
513 |
514 | gpu.Tdot(_batchY,e2,arrGRAD[2][gpu.MYRANK]);
515 | gpu.dot(aB,e2,arrGRAD_B[2][gpu.MYRANK]);
516 | gpu.dotT(e2,W[0],arrGRAD[4][gpu.MYRANK]);
517 |
518 |
519 |
520 | while(gpu.get_queue_length() > 0){ gpu.pop_queue(); }
521 |
522 | vStackN(_currentBatchIdx_X,stackedBatchIdx_X,gpu.MPI_SIZE);
523 | vStackN(_currentBatchIdx_Y,stackedBatchIdx_Y,gpu.MPI_SIZE);
524 |
525 | gpu.add_to_queue(arrGRAD[4]);
526 | gpu.add_to_queue(arrGRAD[2]);
527 |
528 |
529 | scalarMul(pairwise_grad,-1.0f,pairwise_grad);
530 | gpu.pop_queue();
531 | mul(out, pairwise_grad, e1);
532 | gpu.pop_queue();
533 |
534 | gpu.dot(aB,e1,arrGRAD_B[1][gpu.MYRANK]);
535 | gpu.pop_queue();
536 | gpu.Tdot(a1_X,e1,arrGRAD[1][gpu.MYRANK]);
537 | gpu.pop_queue();
538 | gpu.dotT(e1, W[1],e2_partial);
539 | gpu.pop_queue();
540 |
541 | logisticGrad(a1_X,a1_X);
542 | gpu.pop_queue();
543 | mul(e2_partial,a1_X,e2);
544 | gpu.pop_queue();
545 |
546 | gpu.Tdot(_batchX,e2,arrGRAD[3][gpu.MYRANK]);
547 | gpu.pop_queue();
548 | gpu.dot(aB,e2,arrGRAD_B[3][gpu.MYRANK]);
549 | gpu.pop_queue();
550 | gpu.dotT(e2,W[0],arrGRAD[5][gpu.MYRANK]);
551 |
552 | gpu.add_to_queue(arrGRAD[5]);
553 | gpu.add_to_queue(arrGRAD[3]);
554 |
555 |
556 | gpu.add_to_queue(arrGRAD[0]);
557 | gpu.add_to_queue(arrGRAD[1]);
558 |
559 | gpu.add_to_queue(arrGRAD_B[2]);
560 | gpu.add_to_queue(arrGRAD_B[0]);
561 | gpu.add_to_queue(arrGRAD_B[1]);
562 | gpu.add_to_queue(arrGRAD_B[3]);
563 |
564 | }
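
backprop() differentiates a pairwise ranking objective over the two tower scores: the net should rank the observed window (z2_X) above the corrupted one (z2_Y). The exact kernel definitions live in clusterKernels.cu; a plausible hinge form, consistent with the scalarMul(pairwise_grad,-1.0f,...) sign flip that switches from the Y-tower to the X-tower gradient, would be:

    #include <algorithm>

    // hinge-style ranking loss: zero once s_x beats s_y by the margin
    float pairwise_ranking_loss(float s_x, float s_y)
    {
        return std::max(0.0f, 1.0f - s_x + s_y);
    }

    // subgradient w.r.t. s_y; the X tower uses the negated value
    float pairwise_ranking_subgrad(float s_x, float s_y)
    {
        return (1.0f - s_x + s_y) > 0.0f ? 1.0f : 0.0f;
    }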
565 |
566 | double WikiMaxoutNet::calculateError()
567 | {
568 | //scalarMul(W[0],0.9,W[0]);
569 | allocateNextBatch(true);
570 | for(int i = 0; i < _nCVErrorLength; i+=gpu.MPI_SIZE)
571 | {
572 |
573 | feedforward();
574 |
575 | pairwise_ranking(z2_X,z2_Y, out);
576 | _dSumError += (double)sum(out);
577 |
578 | allocateNextBatch(true);
579 | }
580 | //scalarMul(W[0],1.1,W[0]);
581 | //size_t free, total;
582 | //cudaMemGetInfo(&free, &total);
583 | //cout << free << endl;
584 | //cout << "Free system memory: " << sysconf(_SC_PAGE_SIZE)*sysconf(_SC_PHYS_PAGES) << endl;
585 |
586 |
587 | double error = _dSumError/(double)(_nBatchSize*_nCVErrorLength/gpu.MPI_SIZE);
588 | _dSumError = 0.0;
589 |
590 |
591 |
592 | _nNextBatchNumber_CV = gpu.MYRANK*_nCVErrorLength;
593 |
594 | return error;
595 | }
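
Each rank strides through the held-out set in steps of gpu.MPI_SIZE, so it accumulates _nCVErrorLength/MPI_SIZE batches of _nBatchSize losses each; the divisor above is exactly that count. The same arithmetic in isolation (illustrative names):

    // mean held-out loss per example for one MPI rank
    double mean_cv_error(double sum_error, int batch_size,
                         int cv_length, int mpi_size)
    {
        int batches_seen = cv_length / mpi_size;  // the loop strides by mpi_size
        return sum_error / (double)(batch_size * batches_seen);
    }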
596 |
597 |
598 |
599 |
600 |
--------------------------------------------------------------------------------
/source/WikiMaxoutNet.h:
--------------------------------------------------------------------------------
1 | /*
2 | * WikiMaxoutNet.h
3 | *
4 | * Created on: Jun 25, 2014
5 | * Author: tim
6 | */
7 |
8 | #ifndef WIKIMAXOUTNET_H_
9 | #define WIKIMAXOUTNET_H_
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | class WikiMaxoutNet
24 | {
25 | public:
26 | WikiMaxoutNet(ClusterNet gpus);
27 | void run();
28 |
29 | private:
30 | Matrix **_currentBatchIdx_X;
31 | Matrix **_currentBatchIdx_Y;
32 | Matrix *_nextBatchIdx;
33 | ClusterNet gpu;
34 | Matrix *_X;
35 | Matrix *_CV_X;
36 | Matrix *_Vocab;
37 | Matrix *_Vocab_grad;
38 | Matrix *_MSVocab_grad;
39 | Matrix *_MSVocab_grad_Y;
40 | Matrix *M_VocabX;
41 | Matrix *M_VocabY;
42 | Matrix *_Vocab_grad_idx;
43 | Matrix *_batchX;
44 | Matrix *_batchY;
45 | Matrix *stackedVocabGrad_X;
46 | Matrix *stackedVocabGrad_Y;
47 | Matrix *stackedBatchIdx_X;
48 | Matrix *stackedBatchIdx_Y;
49 |
50 | Matrix *out;
51 | Matrix *pairwise_grad;
52 | Matrix *e1;
53 | Matrix *aB;
54 | Matrix *e2_partial;
55 | Matrix *e2;
56 | Matrix *CV_container;
57 |
58 | Matrix *learning_rate_matrix;
59 |
60 | int _nCurrentDataSet;
61 | int _nNextBatchNumber;
62 | int _nNextBatchNumber_CV;
63 | float _RMS_multiplier;
64 | int _nBatchSize;
65 | int _batches;
66 | 	std::vector<int> _layers;
67 | 	std::vector<Matrix*> W;
68 | 	std::vector<Matrix*> B;
69 | 	std::vector<Matrix*> M;
70 | 	std::vector<Matrix*> M_B;
71 | 	std::vector<Matrix**> arrGRAD;
72 | 	std::vector<Matrix*> MSGRAD;
73 | 	std::vector<Matrix**> arrGRAD_B;
74 | 	std::vector<Matrix*> MSBGRAD;
75 | clock_t start,stop;
76 |
77 | Matrix *d0;
78 | Matrix *z1;
79 | Matrix *a1_Y;
80 | Matrix *a1_idx_Y;
81 | Matrix *a1_X;
82 | Matrix *a1_idx_X;
83 | Matrix *d1;
84 | Matrix *z2_X;
85 | Matrix *z2_Y;
86 |
87 | cudaStream_t _streamNextBatch;
88 | double _dSumError;
89 | int _nCVErrorPeriodicity;
90 | int _nCVErrorLength;
91 | int _nMaxoutSize;
92 | float MOMENTUM;
93 | float _learningRate;
94 | int _totalNumberOfBatches;
95 |
96 | bool useRMSProp;
97 | bool useMaxout;
98 |
99 |
100 | void loadNextDataSet();
101 | void allocateNextBatch(bool isCV);
102 | void feedforward();
103 | void nesterov();
104 | double calculateError();
105 | void backprop();
106 | void weightUpdates();
107 | };
108 |
109 |
110 | #endif /* WIKIMAXOUTNET_H_ */
111 |
--------------------------------------------------------------------------------
/source/WikiMaxoutNet_PCIe.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | using std::cout;
4 | using std::endl;
5 |
6 |
7 | WikiMaxoutNet_PCIe::WikiMaxoutNet_PCIe(ClusterNet *gpus)
8 | {
9 |
10 | int vocabSize = 100002;
11 | int nWordVectorDim = 64;
12 | int nWindowSize = 11;
13 | _layers.push_back(512);
14 | _learningRate = 0.01;
15 | _nCVErrorPeriodicity = 6000;
16 | _nCVErrorLength = 6000;
17 | MOMENTUM = 0.5;
18 | gpu = gpus[0];
19 | _nCurrentDataSet = gpu.MYRANK;
20 | _X = 0;
21 | int cv_set_number = 63-gpu.MYRANK;
22 | cudaSetDevice(0);
23 | _CV_X = read_hdf5(("/home/tim/data/wiki/extracted2/AA/data100000/wiki_" + NumberToString(cv_set_number) + ".p").c_str());
24 | _nNextBatchNumber = 0;
25 | _nNextBatchNumber_CV = 0;
26 | _nBatchSize = 128;
27 | _RMS_multiplier = 0.9f;
28 |
29 |
30 | //TODO set weights equal
31 |
32 | gpu.GPU_COUNT = 3;
33 |
34 |
35 | for(int i = 0; i < gpu.GPU_COUNT; i++)
36 | {
37 | cudaSetDevice(i);
38 | cudaStream_t t;
39 | cudaStreamCreate(&t);
40 | _streamNextBatch.push_back(t);
41 |
42 | }
43 | cudaSetDevice(0);
44 |
45 |
46 |
47 |
48 | useRMSProp = true;
49 |
50 | cout << "_layers: " << _layers[0] << endl;
51 | cout << "nWordVectorDim: " << nWordVectorDim << endl;
52 | cout << "_nBatchSize: " << _nBatchSize << endl;
53 | cout << "_learningRate: " << _learningRate << endl;
54 | cout << "Use RMSProp: " << useRMSProp << endl;
55 |
56 | W.push_back(gpu.uniformSqrtWeight_PCIe(nWordVectorDim*nWindowSize,_layers[0]));
57 | W.push_back(gpu.uniformSqrtWeight_PCIe(_layers[0], 1));
58 | B.push_back(gpu.zeros_PCIe(1,_layers[0]));
59 | B.push_back(gpu.zeros_PCIe(1,1));
60 | M.push_back(gpu.zeros_PCIe(nWordVectorDim*nWindowSize,_layers[0]));
61 | M.push_back(gpu.zeros_PCIe(_layers[0], 1));
62 | M_B.push_back(gpu.zeros_PCIe(1,_layers[0]));
63 | M_B.push_back(gpu.zeros_PCIe(1,1));
64 |
65 |
66 |
67 | CV_container = empty_cpu(10000,1);
68 | for(int i = 0; i < CV_container->size; i++)
69 | CV_container->data[i] = 0.0f;
70 |
71 |
72 |
73 | 	cout << gpu.GPU_COUNT << " GPU COUNT" << endl;
74 | cout << gpu.MYRANK << " MYRANK " << endl;
75 | for(int i = W.size()-1; i >= 0; i--)
76 | {
77 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.GPU_COUNT);
78 | arrGRAD.push_back(gradX);
79 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.GPU_COUNT);
80 | arrGRAD.push_back(gradY);
81 | Matrix **gradX_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.GPU_COUNT);
82 | arrGRAD_B.push_back(gradX_B);
83 | Matrix **gradY_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.GPU_COUNT);
84 | arrGRAD_B.push_back(gradY_B);
85 | }
86 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.GPU_COUNT);
87 | arrGRAD.push_back(gradX);
88 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.GPU_COUNT);
89 | arrGRAD.push_back(gradY);
90 |
91 | cout << arrGRAD.size() << " size" << endl;
92 |
93 |
94 | for(int i = W.size()-1; i >= 0; i--)
95 | {
96 | MSGRAD.push_back(gpu.zeros_PCIe(W[i][0]->rows, W[i][0]->cols));
97 | MSGRAD.push_back(gpu.zeros_PCIe(W[i][0]->rows, W[i][0]->cols));
98 | MSBGRAD.push_back(gpu.zeros_PCIe(B[i][0]->rows, B[i][0]->cols));
99 | MSBGRAD.push_back(gpu.zeros_PCIe(B[i][0]->rows, B[i][0]->cols));
100 | }
101 |
102 | 	//zeros_gradient_PCIe allocates one buffer per GPU internally,
103 | 	//so this allocation pass runs exactly once
104 | 	int idx = 0;
105 | 	for(int i = W.size()-1; i >= 0; i--)
106 | 	{
107 | 		arrGRAD[idx] = gpu.zeros_gradient_PCIe(W[i][0]->rows, W[i][0]->cols);
108 | 		arrGRAD_B[idx] = gpu.zeros_gradient_PCIe(B[i][0]->rows, B[i][0]->cols);
109 | 		idx++;
110 | 		arrGRAD[idx] = gpu.zeros_gradient_PCIe(W[i][0]->rows, W[i][0]->cols);
111 | 		arrGRAD_B[idx] = gpu.zeros_gradient_PCIe(B[i][0]->rows, B[i][0]->cols);
112 | 		idx++;
113 | 	}
114 | 
115 | 	arrGRAD[4] = gpu.zeros_gradient_PCIe(_nBatchSize,nWordVectorDim*nWindowSize);
116 | 	arrGRAD[5] = gpu.zeros_gradient_PCIe(_nBatchSize,nWordVectorDim*nWindowSize);
117 | 
118 |
119 | _Vocab = gpu.uniformSqrtWeight_PCIe(nWordVectorDim,vocabSize);
120 | //_Vocab = gpu.sparseInitWeight(nWordVectorDim,vocabSize);
121 | //_Vocab = gpu.rand(nWordVectorDim,vocabSize);
122 | //scalarMul(_Vocab,0.01f,_Vocab);
123 | //scalarAdd(_Vocab,-0.5f,_Vocab);
124 | _Vocab_grad = gpu.zeros_PCIe(nWordVectorDim,vocabSize);
125 | _MSVocab_grad = gpu.zeros_PCIe(nWordVectorDim,vocabSize);
126 | _MSVocab_grad_Y = gpu.zeros_PCIe(nWordVectorDim,vocabSize);
127 | M_VocabX = gpu.zeros_PCIe(nWordVectorDim,vocabSize);
128 | M_VocabY = gpu.zeros_PCIe(nWordVectorDim,vocabSize);
129 | _Vocab_grad_idx = gpu.zeros_PCIe(1,vocabSize);
130 |
131 | d0 = gpu.zeros_PCIe(_nBatchSize,nWordVectorDim*nWindowSize);
132 | z1 = gpu.zeros_PCIe(_nBatchSize, _layers[0]);
133 | a1_Y = gpu.zeros_PCIe(_nBatchSize, _layers[0]);
134 | a1_idx_Y = gpu.zeros_PCIe(_nBatchSize, _layers[0]);
135 | a1_X = gpu.zeros_PCIe(_nBatchSize, _layers[0]);
136 | a1_idx_X = gpu.zeros_PCIe(_nBatchSize, _layers[0]);
137 | d1 = gpu.zeros_PCIe(_nBatchSize, _layers[0]);
138 | z2_X = gpu.zeros_PCIe(_nBatchSize, 1);
139 | z2_Y = gpu.zeros_PCIe(_nBatchSize, 1);
140 |
141 | out = gpu.zeros_PCIe(_nBatchSize,1);
142 | pairwise_grad = gpu.zeros_PCIe(_nBatchSize,1);
143 | e1 = gpu.zeros_PCIe(_nBatchSize,1);
144 | aB = gpu.ones_PCIe(1,_nBatchSize);
145 | e2_partial = gpu.zeros_PCIe(_nBatchSize,W[1][0]->rows);
146 | e2 = gpu.zeros_PCIe(_nBatchSize,e2_partial[0]->cols);
147 |
148 |
149 | _batchX = gpu.zeros_PCIe(_nBatchSize, nWordVectorDim*nWindowSize);
150 | _batchY = gpu.zeros_PCIe(_nBatchSize, nWordVectorDim*nWindowSize);
151 | _currentBatchIdx_X = gpu.zeros_PCIe(_nBatchSize,nWindowSize);
152 | _currentBatchIdx_Y = gpu.zeros_PCIe(_nBatchSize,nWindowSize);
153 | _nextBatchIdx = gpu.zeros_PCIe(_nBatchSize,nWindowSize);
154 |
155 | _dSumError = 0.0;
156 |
157 |
158 |
159 | loadNextDataSet();
160 | }
161 |
162 | void WikiMaxoutNet_PCIe::loadNextDataSet()
163 | {
164 |
165 |
166 | std::string path = "/home/tim/data/wiki/extracted2/AA/data100000/wiki_";
167 | //std::string path = "/home/tim/data/wiki/extracted2/AA/data/wiki_";
168 | std::string number = "";
169 | std::string ending = ".p";
170 | //std::string ending = ".p.hdf5";
171 |
172 | if(_nCurrentDataSet < 10)
173 | number += "0";
174 |
175 | number+= NumberToString(_nCurrentDataSet);
176 |
177 | cout << "Loading next data set: " << (path + number + ending) << endl;
178 | if(_X != 0)
179 | cudaFreeHost(_X->data);
180 | _X = read_hdf5((path + number + ending).c_str());
181 | _nCurrentDataSet += 1;
182 | _batches = _X->rows/ _nBatchSize;
183 | _nNextBatchNumber = 0;
184 | }
185 |
186 |
187 |
188 | void WikiMaxoutNet_PCIe::allocateNextBatch(bool isCV)
189 | {
190 | for(int i = 0; i < gpu.GPU_COUNT; i++)
191 | {
192 | cudaSetDevice(i);
193 |
194 | if(!isCV)
195 | {
196 | if(_nNextBatchNumber < 0)
197 | _nNextBatchNumber = 0;
198 |
199 | if (_nBatchSize*11*(_nNextBatchNumber+1) > _X->size)
200 | loadNextDataSet();
201 |
202 | if(_nNextBatchNumber >= gpu.GPU_COUNT)
203 | {
204 | cout << "pre sync" << endl;
205 | cudaStreamSynchronize(_streamNextBatch[i]);
206 | cout << "post sync" << endl;
207 | to_col_major(_nextBatchIdx[i], _currentBatchIdx_X[i]);
208 | cout << "post col major" << endl;
209 | gpu.construct_vocab_matrix(_currentBatchIdx_X[i], _currentBatchIdx_Y[i], _batchX[i], _batchY[i], _Vocab[i]);
210 | cout << "post constructed" << endl;
211 | }
212 |
213 |
214 |
215 | cudaMemcpyAsync(_nextBatchIdx[i]->data,&_X->data[_nBatchSize*11*_nNextBatchNumber],
216 | _nBatchSize*11*sizeof(float),
217 | cudaMemcpyHostToDevice,_streamNextBatch[i]);
218 |
219 |
220 | _nNextBatchNumber+=1;
221 | }
222 | else
223 | {
224 | if(_nNextBatchNumber_CV >= gpu.GPU_COUNT)
225 | {
226 | cudaStreamSynchronize(_streamNextBatch[i]);
227 | to_col_major(_nextBatchIdx[i], _currentBatchIdx_X[i]);
228 | gpu.construct_vocab_matrix(_currentBatchIdx_X[i], _currentBatchIdx_Y[i], _batchX[i], _batchY[i], _Vocab[i]);
229 | }
230 |
231 | cudaMemcpyAsync(_nextBatchIdx[i]->data,&_CV_X->data[_nBatchSize*11*_nNextBatchNumber_CV],
232 | _nBatchSize*11*sizeof(float),
233 | cudaMemcpyHostToDevice,_streamNextBatch[i]);
234 |
235 |
236 | _nNextBatchNumber_CV+=1;
237 | }
238 |
239 | }
240 |
241 |
242 | }
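
The _PCIe variant drives several GPUs from a single host thread: cudaSetDevice() selects the current device inside the loop, and one stream per device keeps the copies concurrent. The pattern in isolation (no error checks; names illustrative):

    #include <cuda_runtime.h>

    // issue one async H2D copy per GPU, then wait for all of them
    void broadcast_to_all_gpus(const float *h_src, float **d_dst, size_t bytes,
                               cudaStream_t *streams, int gpu_count)
    {
        for (int i = 0; i < gpu_count; i++)
        {
            cudaSetDevice(i);                       // switch current device
            cudaMemcpyAsync(d_dst[i], h_src, bytes,
                            cudaMemcpyHostToDevice, streams[i]);
        }
        for (int i = 0; i < gpu_count; i++)
        {
            cudaSetDevice(i);
            cudaStreamSynchronize(streams[i]);      // copy to GPU i finished
        }
    }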
243 |
244 |
245 |
246 |
247 | void WikiMaxoutNet_PCIe::run()
248 | {
249 | allocateNextBatch(false);
250 |
251 | size_t freemem, total;
252 | cudaMemGetInfo(&freemem,&total);
253 | cout << freemem << endl;
254 |
255 |
256 | srand( time(NULL) );
257 | start = clock();
258 |
259 | int i = 0;
260 | while(true)
261 | {
262 | if(i > 0 && i % _nCVErrorPeriodicity == 0)
263 | {
264 | if( i > 0 && i % 100000 == 0)
265 | {
266 | cout << "Saving vocabulary matrix to disk..." << endl;
267 | cudaSetDevice(0);
268 | Matrix *host = to_host(_Vocab[0]);
269 | write_hdf5("/home/tim/data/wiki/vocab.hdf5",host);
270 | free(host->data);
271 | free(host);
272 | write_hdf5("/home/tim/data/wiki/CV_values.hdf5",CV_container);
273 | }
274 |
275 | double error = calculateError();
276 | CV_container->data[i/_nCVErrorPeriodicity] = (float)error;
277 | cout << "BatchNo: " << i << endl;
278 | cout << "Cross validation error: " << error << endl;
279 | i+=gpu.GPU_COUNT;
280 |
281 | //MOMENTUM+= 0.01;
282 | //if( MOMENTUM > 0.95)
283 | //MOMENTUM = 0.95;
284 |
285 | _RMS_multiplier-= 0.01;
286 | if( _RMS_multiplier < 0.25)
287 | _RMS_multiplier = 0.25;
288 |
289 | stop = clock();
290 |
291 | double time_interval_seconds = (double((stop - start)) / CLOCKS_PER_SEC) ;
292 | cout << "Approximate time left in hours: " << ((1.0f/(((i*_nBatchSize)/(float)_X->rows)/63.0))*time_interval_seconds/(float)3600.0) -
293 | (time_interval_seconds/(float)3600.0)<< endl;
294 |
295 | }
296 | else
297 | {
298 | cout << "nesterov" << endl;
299 | nesterov();
300 | feedforward();
301 | backprop();
302 | weightUpdates();
303 | }
304 |
305 | cout << i << endl;
306 | allocateNextBatch(false);
307 | i+=gpu.GPU_COUNT;
308 |
309 | cout << i << endl;
310 | }
311 | }
312 |
313 |
314 |
315 |
316 |
317 | void WikiMaxoutNet_PCIe::nesterov()
318 | {
319 | //nesterov
320 | for(int i = 0;i < M.size(); i++)
321 | {
322 | gpu.scalarMul_PCIe(M[i],MOMENTUM,M[i]);
323 | gpu.add_PCIe(W[i],M[i],W[i]);
324 | }
325 |
326 | for(int i = 0;i < B.size(); i++)
327 | {
328 | gpu.scalarMul_PCIe(M_B[i],MOMENTUM,M_B[i]);
329 | gpu.add_PCIe(B[i],M_B[i],B[i]);
330 | }
331 |
332 | gpu.scalarMul_PCIe(M_VocabX, MOMENTUM, M_VocabX);
333 | gpu.add_PCIe(_Vocab,M_VocabX,_Vocab);
334 |
335 | gpu.scalarMul_PCIe(M_VocabY, MOMENTUM, M_VocabY);
336 | gpu.add_PCIe(_Vocab,M_VocabY,_Vocab);
337 |
338 | }
339 |
340 |
341 | void WikiMaxoutNet_PCIe::feedforward()
342 | {
343 | gpu.dotPCIe(_batchX,W[0],z1);
344 | gpu.addMatrixVector_PCIe(z1,B[0],z1);
345 | gpu.logistic_PCIe(z1,a1_X);
346 | gpu.dotPCIe(a1_X,W[1],z2_X);
347 | gpu.addMatrixVector_PCIe(z2_X,B[1],z2_X);
348 |
349 | gpu.dotPCIe(_batchY,W[0],z1);
350 | gpu.addMatrixVector_PCIe(z1,B[0],z1);
351 | gpu.logistic_PCIe(z1,a1_Y);
352 | gpu.dotPCIe(a1_Y,W[1],z2_Y);
353 | gpu.addMatrixVector_PCIe(z2_Y,B[1],z2_Y);
354 |
355 |
356 | }
357 |
358 |
359 | void WikiMaxoutNet_PCIe::weightUpdates()
360 | {
361 | float multiplier = _learningRate/(float)_nBatchSize;
362 |
363 | if(!useRMSProp)
364 | {
365 | /*
366 | scalarMul(arrGRAD[0][gpu.MYRANK],multiplier/(float)arrGRAD[0][gpu.MYRANK]->rows,arrGRAD[0][gpu.MYRANK]);
367 | scalarMul(arrGRAD[1][gpu.MYRANK],multiplier/(float)arrGRAD[1][gpu.MYRANK]->rows,arrGRAD[1][gpu.MYRANK]);
368 | scalarMul(arrGRAD[2][gpu.MYRANK],multiplier/(float)arrGRAD[1][gpu.MYRANK]->rows,arrGRAD[2][gpu.MYRANK]);
369 | scalarMul(arrGRAD[3][gpu.MYRANK],multiplier/(float)arrGRAD[1][gpu.MYRANK]->rows,arrGRAD[3][gpu.MYRANK]);
370 | scalarMul(arrGRAD_B[0][gpu.MYRANK],multiplier,arrGRAD_B[0][gpu.MYRANK]);
371 | scalarMul(arrGRAD_B[1][gpu.MYRANK],multiplier,arrGRAD_B[1][gpu.MYRANK]);
372 | scalarMul(arrGRAD_B[2][gpu.MYRANK],multiplier,arrGRAD_B[2][gpu.MYRANK]);
373 | scalarMul(arrGRAD_B[3][gpu.MYRANK],multiplier,arrGRAD_B[3][gpu.MYRANK]);
374 |
375 | sub(W[1],arrGRAD[0][gpu.MYRANK],W[1]);
376 | sub(W[1],arrGRAD[1][gpu.MYRANK],W[1]);
377 | sub(W[0],arrGRAD[2][gpu.MYRANK],W[0]);
378 | sub(W[0],arrGRAD[3][gpu.MYRANK],W[0]);
379 | sub(B[1],arrGRAD_B[0][gpu.MYRANK],B[1]);
380 | sub(B[1],arrGRAD_B[1][gpu.MYRANK],B[1]);
381 | sub(B[0],arrGRAD_B[2][gpu.MYRANK],B[0]);
382 | sub(B[0],arrGRAD_B[3][gpu.MYRANK],B[0]);
383 |
384 | update_vocab_with_gradient(arrGRAD[4][gpu.MYRANK],_currentBatchIdx_Y,_Vocab,multiplier);
385 | update_vocab_with_gradient(arrGRAD[5][gpu.MYRANK],_currentBatchIdx_X,_Vocab,multiplier);
386 | */
387 | }
388 | else
389 | {
390 |
391 | //10*GPU_COUNT gradients added
392 |
393 | cout << gpu.get_queue_length() << endl;
394 | while(gpu.get_queue_length() > 0)
395 | {
396 | gpu.pop_queue_PCIe();
397 | usleep(100);
398 | }
399 |
400 | 		cout << "gradient queue drained" << endl;
401 |
402 | for(int i = 0; i < gpu.GPU_COUNT; i++)
403 | {
404 | cudaSetDevice(i);
405 | fill_matrix(_Vocab_grad[i],0.0f);
406 | expand_vocab_gradient(arrGRAD[4][i],_currentBatchIdx_Y[i],_Vocab_grad[i]);
407 | }
408 | //cout << gpu.get_queue_length() << endl;
409 | gpu.RMSprop_with_nesterov_weight_update_PCIe(_MSVocab_grad_Y,_Vocab_grad,_Vocab,M_VocabY,_RMS_multiplier,_learningRate/(float)_nBatchSize,1, MOMENTUM);
410 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSGRAD[2],arrGRAD[2],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[2][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
411 | for(int i = 0; i < gpu.GPU_COUNT; i++)
412 | {
413 | cudaSetDevice(i);
414 | fill_matrix(_Vocab_grad[i],0.0f);
415 | expand_vocab_gradient(arrGRAD[5][i],_currentBatchIdx_X[i],_Vocab_grad[i]);
416 | }
417 | gpu.RMSprop_with_nesterov_weight_update_PCIe(_MSVocab_grad,_Vocab_grad,_Vocab,M_VocabX,_RMS_multiplier,_learningRate/(float)_nBatchSize,1, MOMENTUM);
418 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSGRAD[3],arrGRAD[3],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[3][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
419 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSGRAD[0],arrGRAD[0],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[0][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
420 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSGRAD[1],arrGRAD[1],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[1][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
421 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSBGRAD[2],arrGRAD_B[2],B[0],M_B[0],0.9f,_learningRate,_nBatchSize, MOMENTUM);
422 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSBGRAD[0],arrGRAD_B[0],B[1],M_B[1],0.9f,_learningRate,_nBatchSize, MOMENTUM);
423 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSBGRAD[1],arrGRAD_B[1],B[1],M_B[1],0.9f,_learningRate,_nBatchSize, MOMENTUM);
424 | gpu.RMSprop_with_nesterov_weight_update_PCIe(MSBGRAD[3],arrGRAD_B[3],B[0],M_B[0],0.9f,_learningRate,_nBatchSize, MOMENTUM);
425 |
426 | }
427 |
428 | }
429 |
430 |
431 | void WikiMaxoutNet_PCIe::backprop()
432 | {
433 | for(int i = 0; i < gpu.GPU_COUNT; i++)
434 | {
435 | cudaSetDevice(i);
436 | pairwise_ranking(z2_X[i],z2_Y[i], out[i]);
437 | pairwise_ranking_derivative(z2_X[i],z2_Y[i], pairwise_grad[i]);
438 | }
439 |
440 |
441 |
442 | gpu.mul_PCIe(out, pairwise_grad, e1);
443 | gpu.dotTPCIe(e1, W[1],e2_partial);
444 |
445 | gpu.dotPCIe(aB,e1,arrGRAD_B[0]);
446 |
447 | //gpu.add_to_queue_PCIe(arrGRAD_B[0]);
448 |
449 | gpu.TdotPCIe(a1_Y,e1,arrGRAD[0]);
450 | gpu.add_to_queue_PCIe(arrGRAD[0]);
451 |
452 | for(int i = 0; i < gpu.GPU_COUNT; i++)
453 | {
454 | cudaSetDevice(i);
455 | logisticGrad(a1_Y[i],a1_Y[i]);
456 | }
457 | gpu.mul_PCIe(e2_partial,a1_Y,e2);
458 |
459 |
460 | gpu.TdotPCIe(_batchY,e2,arrGRAD[2]);
461 | //gpu.add_to_queue_PCIe(arrGRAD[2]);
462 | gpu.dotPCIe(aB,e2,arrGRAD_B[2]);
463 | //gpu.add_to_queue_PCIe(arrGRAD_B[2]);
464 | gpu.dotTPCIe(e2,W[0],arrGRAD[4]);
465 | //gpu.add_to_queue_PCIe(arrGRAD[4]);
466 |
467 | gpu.scalarMul_PCIe(pairwise_grad,-1.0f,pairwise_grad);
468 | gpu.mul_PCIe(out, pairwise_grad, e1);
469 |
470 | gpu.dotPCIe(aB,e1,arrGRAD_B[1]);
471 | //gpu.add_to_queue_PCIe(arrGRAD_B[1]);
472 | gpu.TdotPCIe(a1_X,e1,arrGRAD[1]);
473 | //gpu.add_to_queue_PCIe(arrGRAD[1]);
474 | gpu.dotTPCIe(e1, W[1],e2_partial);
475 |
476 |
477 | for(int i = 0; i < gpu.GPU_COUNT; i++)
478 | {
479 | cudaSetDevice(i);
480 | logisticGrad(a1_X[i],a1_X[i]);
481 | }
482 | gpu.mul_PCIe(e2_partial,a1_X,e2);
483 |
484 |
485 | gpu.TdotPCIe(_batchX,e2,arrGRAD[3]);
486 | //gpu.add_to_queue_PCIe(arrGRAD[3]);
487 | gpu.dotPCIe(aB,e2,arrGRAD_B[3]);
488 | //gpu.add_to_queue_PCIe(arrGRAD_B[3]);
489 | gpu.dotTPCIe(e2,W[0],arrGRAD[5]);
490 | //gpu.add_to_queue_PCIe(arrGRAD[5]);
491 | }
492 |
493 |
494 | double WikiMaxoutNet_PCIe::calculateError()
495 | {
496 | //scalarMul(W[0],0.9,W[0]);
497 | allocateNextBatch(true);
498 | for(int i = 0; i < _nCVErrorLength; i+=gpu.GPU_COUNT)
499 | {
500 |
501 | feedforward();
502 |
503 | for(int j = 0; j < gpu.GPU_COUNT; j++)
504 | {
505 | cudaSetDevice(j);
506 | pairwise_ranking(z2_X[j],z2_Y[j], out[j]);
507 | }
508 | cudaSetDevice(0);
509 | _dSumError += (double)sum(out[0]);
510 |
511 | allocateNextBatch(true);
512 | }
513 | //scalarMul(W[0],1.1,W[0]);
514 | //size_t free, total;
515 | //cudaMemGetInfo(&free, &total);
516 | //cout << free << endl;
517 | //cout << "Free system memory: " << sysconf(_SC_PAGE_SIZE)*sysconf(_SC_PHYS_PAGES) << endl;
518 |
519 |
520 | double error = _dSumError/(double)(_nBatchSize*_nCVErrorLength);
521 | _dSumError = 0.0;
522 |
523 |
524 |
525 | _nNextBatchNumber_CV = 0;
526 |
527 | return error;
528 | }
529 |
530 |
531 |
532 |
533 |
534 |
--------------------------------------------------------------------------------
/source/WikiMaxoutNet_PCIe.h:
--------------------------------------------------------------------------------
1 |
2 |
3 | #ifndef WIKIMAXOUTNET_PCIE_H_
4 | #define WIKIMAXOUTNET_PCIE_H_
5 |
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 |
18 | class WikiMaxoutNet_PCIe
19 | {
20 | public:
21 | WikiMaxoutNet_PCIe(ClusterNet *gpus);
22 | void run();
23 |
24 | private:
25 | Matrix **_currentBatchIdx_X;
26 | Matrix **_currentBatchIdx_Y;
27 | Matrix **_nextBatchIdx;
28 | ClusterNet gpu;
29 | Matrix *_X;
30 | Matrix *_CV_X;
31 | Matrix **_Vocab;
32 | Matrix **_Vocab_grad;
33 | Matrix **_MSVocab_grad;
34 | Matrix **_MSVocab_grad_Y;
35 | Matrix **M_VocabX;
36 | Matrix **M_VocabY;
37 | Matrix **_Vocab_grad_idx;
38 | Matrix **_batchX;
39 | Matrix **_batchY;
40 | Matrix *stackedVocabGrad_X;
41 | Matrix *stackedVocabGrad_Y;
42 | Matrix *stackedBatchIdx_X;
43 | Matrix *stackedBatchIdx_Y;
44 |
45 | Matrix **out;
46 | Matrix **pairwise_grad;
47 | Matrix **e1;
48 | Matrix **aB;
49 | Matrix **e2_partial;
50 | Matrix **e2;
51 | Matrix *CV_container;
52 |
53 | Matrix *learning_rate_matrix;
54 |
55 |
56 |
57 | int _nCurrentDataSet;
58 | int _nNextBatchNumber;
59 | int _nNextBatchNumber_CV;
60 | float _RMS_multiplier;
61 | int _nBatchSize;
62 | int _batches;
63 | 	std::vector<int> _layers;
64 | 	std::vector<Matrix**> W;
65 | 	std::vector<Matrix**> B;
66 | 	std::vector<Matrix**> M;
67 | 	std::vector<Matrix**> M_B;
68 | 	std::vector<Matrix**> arrGRAD;
69 | 	std::vector<Matrix**> MSGRAD;
70 | 	std::vector<Matrix**> arrGRAD_B;
71 | 	std::vector<Matrix**> MSBGRAD;
72 | clock_t start,stop;
73 |
74 | Matrix **d0;
75 | Matrix **z1;
76 | Matrix **a1_Y;
77 | Matrix **a1_idx_Y;
78 | Matrix **a1_X;
79 | Matrix **a1_idx_X;
80 | Matrix **d1;
81 | Matrix **z2_X;
82 | Matrix **z2_Y;
83 |
84 | 	std::vector<cudaStream_t> _streamNextBatch;
85 | double _dSumError;
86 | int _nCVErrorPeriodicity;
87 | int _nCVErrorLength;
88 | float MOMENTUM;
89 | float _learningRate;
90 | int _totalNumberOfBatches;
91 |
92 | bool useRMSProp;
93 | bool useMaxout;
94 |
95 |
96 | void loadNextDataSet();
97 | void allocateNextBatch(bool isCV);
98 | void feedforward();
99 | void nesterov();
100 | double calculateError();
101 | void backprop();
102 | void weightUpdates();
103 | };
104 |
105 |
106 | #endif /* WIKIMAXOUTNET_PCIE_H_ */
107 |
--------------------------------------------------------------------------------
/source/WikiMaxoutNet_PCIe2.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | using std::cout;
4 | using std::endl;
5 |
6 | WikiMaxoutNet_PCIe2::WikiMaxoutNet_PCIe2()
7 | {
8 | current_device = 1;
9 | pthread_t t2;
10 | 	pthread_create(&t2,NULL,&WikiMaxoutNet_PCIe2::hello_helper,this);
11 | /*
12 | current_device = 1;
13 | pthread_t t2;
14 | pthread_create(&t2,NULL,WikiMaxoutNet_PCIe2::hello_helper,this);
15 | current_device = 1;
16 | pthread_t t3;
17 | pthread_create(&t3,NULL,WikiMaxoutNet_PCIe2::hello_helper,this);
18 | */
19 | }
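
pthread_create() only accepts a plain void *(*)(void *), so hello_helper (declared in the header) is a static trampoline: the object pointer travels through the void* context argument and is cast back before forwarding to the member function. The pattern in isolation:

    #include <pthread.h>
    #include <cstdio>

    struct Worker
    {
        void *work() { printf("running\n"); return 0; }

        // static trampoline: recover the object from the context pointer
        static void *trampoline(void *ctx) { return ((Worker *)ctx)->work(); }
    };

    int main()
    {
        Worker w;
        pthread_t t;
        pthread_create(&t, NULL, &Worker::trampoline, &w);
        pthread_join(t, NULL);
        return 0;
    }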
20 |
21 |
22 | void WikiMaxoutNet_PCIe2::init()
23 | {
24 |
25 | cout << current_device << endl;
26 | cudaSetDevice(current_device);
27 | cout << "Mydevice: " << current_device << endl;
28 | gpu = ClusterNet(12354);
29 |
30 | cout << "post cluster init" << endl;
31 | /*
32 | int vocabSize = 100002;
33 | int nWordVectorDim = 64;
34 | int nWindowSize = 11;
35 | _layers.push_back(512);
36 | _nMaxoutSize = 1;
37 | _learningRate = 0.01;
38 | _nCVErrorPeriodicity = 5000;
39 | _nCVErrorLength = 5000;
40 | MOMENTUM = 0.5;
41 | _nCurrentDataSet = gpu.MYRANK;
42 | _X = 0;
43 | int cv_set_number = 63-gpu.MYRANK;
44 | _CV_X = read_hdf5(("/home/tim/data/wiki/extracted2/AA/data100000/wiki_" + NumberToString(cv_set_number) + ".p").c_str());
45 | _nNextBatchNumber = 0;
46 | _nNextBatchNumber_CV = 0;
47 | _nBatchSize = 128;
48 | _RMS_multiplier = 0.9f;
49 | cudaStreamCreate(&_streamNextBatch);
50 |
51 | Matrix *learning_rate_matrix_cpu = empty_cpu(nWordVectorDim,vocabSize);
52 |
53 | float learning_rate = 0.0000001;
54 | int next_level = 2000;
55 | for(int col = 0; col < vocabSize; col++)
56 | for(int row = 0; row < nWordVectorDim; row++)
57 | {
58 | if(col > next_level)
59 | {
60 | learning_rate = learning_rate * 10.00f;
61 | next_level = next_level == 50000 ? vocabSize : next_level;
62 | next_level = next_level == 25000 ? 50000 : next_level;
63 | next_level = next_level == 10000 ? 25000 : next_level;
64 | next_level = next_level == 2000 ? 10000 : next_level;
65 | }
66 |
67 | if((col == vocabSize-2) || (col == vocabSize-1))
68 | {
69 | learning_rate_matrix_cpu->data[col + (row*vocabSize)] = 0.0000001;
70 | }
71 | else
72 | {
73 | learning_rate_matrix_cpu->data[col + (row*vocabSize)] = learning_rate;
74 | }
75 | }
76 |
77 |
78 | learning_rate_matrix = to_gpu(learning_rate_matrix_cpu);
79 | free(learning_rate_matrix_cpu->data);
80 |
81 |
82 | useRMSProp = true;
83 | useMaxout = false;
84 |
85 | cout << "_nMaxoutSize: " << _nMaxoutSize << endl;
86 | cout << "_layers: " << _layers[0] << endl;
87 | cout << "nWordVectorDim: " << nWordVectorDim << endl;
88 | cout << "_nBatchSize: " << _nBatchSize << endl;
89 | cout << "_learningRate: " << _learningRate << endl;
90 | cout << "Use RMSProp: " << useRMSProp << endl;
91 |
92 | W.push_back(gpu.uniformSqrtWeight(nWordVectorDim*nWindowSize,_layers[0]));
93 | W.push_back(gpu.uniformSqrtWeight(_layers[0]/_nMaxoutSize, 1));
94 | B.push_back(zeros(1,_layers[0]));
95 | B.push_back(zeros(1,1));
96 | M.push_back(zeros(nWordVectorDim*nWindowSize,_layers[0]));
97 | M.push_back(zeros(_layers[0]/_nMaxoutSize, 1));
98 | M_B.push_back(zeros(1,_layers[0]));
99 | M_B.push_back(zeros(1,1));
100 |
101 | CV_container = empty_cpu(10000,1);
102 | for(int i = 0; i < CV_container->size; i++)
103 | CV_container->data[i] = 0.0f;
104 |
105 |
106 | if(gpu.MPI_SIZE == 0)
107 | gpu.MPI_SIZE = 1;
108 |
109 |
110 | cout << gpu.MPI_SIZE << " MPI SIZE" << endl;
111 | cout << gpu.MYRANK << " MYRANK " << endl;
112 | for(int i = W.size()-1; i >= 0; i--)
113 | {
114 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
115 | arrGRAD.push_back(gradX);
116 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
117 | arrGRAD.push_back(gradY);
118 | Matrix **gradX_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
119 | arrGRAD_B.push_back(gradX_B);
120 | Matrix **gradY_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
121 | arrGRAD_B.push_back(gradY_B);
122 | }
123 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
124 | arrGRAD.push_back(gradX);
125 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
126 | arrGRAD.push_back(gradY);
127 |
128 | cout << arrGRAD.size() << " size" << endl;
129 |
130 |
131 | for(int i = W.size()-1; i >= 0; i--)
132 | {
133 | MSGRAD.push_back(zeros(W[i]->rows, W[i]->cols));
134 | MSGRAD.push_back(zeros(W[i]->rows, W[i]->cols));
135 | MSBGRAD.push_back(zeros(B[i]->rows, B[i]->cols));
136 | MSBGRAD.push_back(zeros(B[i]->rows, B[i]->cols));
137 | }
138 |
139 | for(int j =0; j < gpu.MPI_SIZE; j++)
140 | {
141 | int idx = 0;
142 | for(int i = W.size()-1; i >= 0; i--)
143 | {
144 | arrGRAD[idx][j] = zeros(W[i]->rows, W[i]->cols);
145 | arrGRAD_B[idx][j] = zeros(B[i]->rows, B[i]->cols);
146 | idx++;
147 | arrGRAD[idx][j] = (zeros(W[i]->rows, W[i]->cols));
148 | arrGRAD_B[idx][j] = zeros(B[i]->rows, B[i]->cols);
149 | idx++;
150 | }
151 |
152 | arrGRAD[4][j] = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
153 | arrGRAD[5][j] = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
154 | }
155 |
156 |
157 | MSGRAD.push_back(zeros(_nBatchSize*gpu.MPI_SIZE,nWordVectorDim*nWindowSize));
158 | MSGRAD.push_back(zeros(_nBatchSize*gpu.MPI_SIZE,nWordVectorDim*nWindowSize));
159 |
160 | _Vocab = gpu.uniformSqrtWeight(nWordVectorDim,vocabSize);
161 | //_Vocab = gpu.sparseInitWeight(nWordVectorDim,vocabSize);
162 | //_Vocab = gpu.rand(nWordVectorDim,vocabSize);
163 | //scalarMul(_Vocab,0.01f,_Vocab);
164 | //scalarAdd(_Vocab,-0.5f,_Vocab);
165 | cout << sum(_Vocab) << endl;
166 | _Vocab_grad = zeros(nWordVectorDim,vocabSize);
167 | _MSVocab_grad = zeros(nWordVectorDim,vocabSize);
168 | _MSVocab_grad_Y = zeros(nWordVectorDim,vocabSize);
169 | M_VocabX = zeros(nWordVectorDim,vocabSize);
170 | M_VocabY = zeros(nWordVectorDim,vocabSize);
171 | _Vocab_grad_idx = zeros(1,vocabSize);
172 |
173 | d0 = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
174 | z1 = zeros(_nBatchSize, _layers[0]);
175 | a1_Y = zeros(_nBatchSize, _layers[0]/_nMaxoutSize);
176 | a1_idx_Y = zeros(_nBatchSize, _layers[0]/_nMaxoutSize);
177 | a1_X = zeros(_nBatchSize, _layers[0]/_nMaxoutSize);
178 | a1_idx_X = zeros(_nBatchSize, _layers[0]/_nMaxoutSize);
179 | d1 = zeros(_nBatchSize, _layers[0]/_nMaxoutSize);
180 | z2_X = zeros(_nBatchSize, 1);
181 | z2_Y = zeros(_nBatchSize, 1);
182 |
183 | out = zeros(_nBatchSize,1);
184 | pairwise_grad = zeros(_nBatchSize,1);
185 | e1 = empty(_nBatchSize,1);
186 | aB = ones(1,_nBatchSize);
187 | e2_partial = zeros(_nBatchSize,W[1]->rows);
188 | e2 = empty(_nBatchSize,e2_partial->cols*_nMaxoutSize);
189 |
190 |
191 | _batchX = zeros(_nBatchSize, nWordVectorDim*nWindowSize);
192 | _batchY = zeros(_nBatchSize, nWordVectorDim*nWindowSize);
193 | _currentBatchIdx_X = zeros(_nBatchSize,nWindowSize);
194 | _currentBatchIdx_Y = zeros(_nBatchSize,nWindowSize);
195 | _nextBatchIdx = zeros(_nBatchSize,nWindowSize);
196 |
197 | _dSumError = 0.0;
198 | */
199 |
200 | loadNextDataSet();
201 | }
202 |
203 | void WikiMaxoutNet_PCIe2::run()
204 | {
205 |
206 | cout << current_device << endl;
207 | cudaSetDevice(current_device);
208 | init();
209 | allocateNextBatch(false);
210 |
211 |
212 | size_t freemem, total;
213 | cudaMemGetInfo(&freemem,&total);
214 | cout << freemem << endl;
215 |
216 |
217 | srand( time(NULL) );
218 | start = clock();
219 |
220 | int i = 0;
221 | while(true)
222 | {
223 | if(i > 0 && i % _nCVErrorPeriodicity == 0)
224 | {
225 | if( i > 0 && i % 100000 == 0)
226 | {
227 | cout << "Saving vocabulary matrix to disk..." << endl;
228 | Matrix *host = to_host(_Vocab);
229 | write_hdf5("/home/tim/data/wiki/vocab.hdf5",host);
230 | free(host->data);
231 | free(host);
232 | write_hdf5("/home/tim/data/wiki/CV_values.hdf5",CV_container);
233 | }
234 |
235 | double error = calculateError();
236 | CV_container->data[i/_nCVErrorPeriodicity] = (float)error;
237 | cout << "BatchNo: " << i << endl;
238 | cout << "Cross validation error: " << error << endl;
239 | i++;
240 |
241 | //MOMENTUM+= 0.01;
242 | //if( MOMENTUM > 0.95)
243 | //MOMENTUM = 0.95;
244 |
245 | _RMS_multiplier-= 0.01;
246 | if( _RMS_multiplier < 0.25)
247 | _RMS_multiplier = 0.25;
248 |
249 | stop = clock();
250 |
251 | double time_interval_seconds = (double((stop - start)) / CLOCKS_PER_SEC) ;
252 | cout << "Approximate time left in hours: " << ((1.0f/(((i*_nBatchSize)/(float)_X->rows)/63.0))*time_interval_seconds/(float)3600.0) -
253 | (time_interval_seconds/(float)3600.0)<< endl;
254 |
255 | }
256 | else
257 | {
258 | nesterov();
259 | cout << "nesterov" << endl;
260 | feedforward();
261 | backprop();
262 | weightUpdates();
263 | }
264 |
265 | //cout << i << endl;
266 | allocateNextBatch(false);
267 | i++;
268 | }
269 | }
270 |
271 | void WikiMaxoutNet_PCIe2::loadNextDataSet()
272 | {
273 |
274 | 
275 | std::string path = "/home/tim/data/wiki/extracted2/AA/data100000/wiki_";
276 | //std::string path = "/home/tim/data/wiki/extracted2/AA/data/wiki_";
277 | std::string number = "";
278 | std::string ending = ".p";
279 | //std::string ending = ".p.hdf5";
280 |
281 | if(_nCurrentDataSet < 10)
282 | number += "0";
283 |
284 | number+= NumberToString(_nCurrentDataSet);
285 |
286 | cout << "Loading next data set: " << (path + number + ending) << endl;
287 | if(_X != 0)
288 | cudaFreeHost(_X->data);
289 | _X = read_hdf5((path + number + ending).c_str());
290 | _nCurrentDataSet += gpu.MPI_SIZE;
291 | _batches = _X->rows/ _nBatchSize;
292 | _nNextBatchNumber = 0;
293 | }
294 |
295 | void WikiMaxoutNet_PCIe2::allocateNextBatch(bool isCV)
296 | {
297 | if(!isCV)
298 | {
299 | if(_nNextBatchNumber < 0)
300 | _nNextBatchNumber = 0;
301 |
302 | if (_nBatchSize*11*(_nNextBatchNumber+1) > _X->size)
303 | loadNextDataSet();
304 |
305 | if(_nNextBatchNumber > 0 || _nCurrentDataSet > gpu.MYRANK)
306 | {
307 | cudaStreamSynchronize(_streamNextBatch);
308 | to_col_major(_nextBatchIdx, _currentBatchIdx_X);
309 | gpu.construct_vocab_matrix(_currentBatchIdx_X, _currentBatchIdx_Y, _batchX, _batchY, _Vocab);
310 | }
311 |
312 |
313 |
314 | cudaMemcpyAsync(_nextBatchIdx->data,&_X->data[_nBatchSize*11*_nNextBatchNumber],
315 | _nBatchSize*11*sizeof(float),
316 | cudaMemcpyHostToDevice,_streamNextBatch);
317 |
318 |
319 | _nNextBatchNumber+=1;
320 | }
321 | else
322 | {
323 | if(_nNextBatchNumber_CV > 0)
324 | {
325 | cudaStreamSynchronize(_streamNextBatch);
326 | to_col_major(_nextBatchIdx, _currentBatchIdx_X);
327 | gpu.construct_vocab_matrix(_currentBatchIdx_X, _currentBatchIdx_Y, _batchX, _batchY, _Vocab);
328 | }
329 |
330 | cudaMemcpyAsync(_nextBatchIdx->data,&_CV_X->data[_nBatchSize*11*_nNextBatchNumber_CV],
331 | _nBatchSize*11*sizeof(float),
332 | cudaMemcpyHostToDevice,_streamNextBatch);
333 |
334 |
335 | _nNextBatchNumber_CV+=1;
336 | }
337 |
338 |
339 | }
340 |
341 | void WikiMaxoutNet_PCIe2::nesterov()
342 | {
343 | //nesterov
344 | for(int i = 0;i < M.size(); i++)
345 | {
346 | scalarMul(M[i],MOMENTUM,M[i]);
347 | add(W[i],M[i],W[i]);
348 | }
349 |
350 | for(int i = 0;i < B.size(); i++)
351 | {
352 | scalarMul(M_B[i],MOMENTUM,M_B[i]);
353 | add(B[i],M_B[i],B[i]);
354 | }
355 |
356 | scalarMul(M_VocabX, MOMENTUM, M_VocabX);
357 | add(_Vocab,M_VocabX,_Vocab);
358 |
359 | scalarMul(M_VocabY, MOMENTUM, M_VocabY);
360 | add(_Vocab,M_VocabY,_Vocab);
361 |
362 | }
363 |
364 |
365 | void WikiMaxoutNet_PCIe2::feedforward()
366 | {
367 |
368 | if(useMaxout)
369 | {
370 | //gpu.dropout(_batchX,d0,0.1);
371 | gpu.dot(_batchX,W[0],z1);
372 | addMatrixVector(z1,B[0],z1);
373 | maxout(z1, a1_X, a1_idx_X, _nMaxoutSize);
374 | gpu.dot(a1_X,W[1],z2_X);
375 | addMatrixVector(z2_X,B[1],z2_X);
376 |
377 | //gpu.dropout(_batchY,d0,0.1);
378 | gpu.dot(_batchY,W[0],z1);
379 | addMatrixVector(z1,B[0],z1);
380 | maxout(z1, a1_Y, a1_idx_Y, _nMaxoutSize);
381 | gpu.dot(a1_Y,W[1],z2_Y);
382 | addMatrixVector(z2_Y,B[1],z2_Y);
383 | }
384 | else
385 | {
386 | gpu.dot(_batchX,W[0],z1);
387 | addMatrixVector(z1,B[0],z1);
388 | logistic(z1,a1_X);
389 | gpu.dot(a1_X,W[1],z2_X);
390 | addMatrixVector(z2_X,B[1],z2_X);
391 |
392 | gpu.dot(_batchY,W[0],z1);
393 | addMatrixVector(z1,B[0],z1);
394 | logistic(z1,a1_Y);
395 | gpu.dot(a1_Y,W[1],z2_Y);
396 | addMatrixVector(z2_Y,B[1],z2_Y);
397 |
398 | }
399 | }
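
With useMaxout set, the hidden activation is maxout: each output unit takes the maximum over a group of _nMaxoutSize pre-activations, and the argmax index (a1_idx_X/Y) is stored so that expand_to_maxout_grad() can route the gradient back to the winning unit in backprop(). A plain-CPU sketch of the forward pass (illustrative names):

    #include <vector>

    // a[o] = max over group o of z; idx[o] = index of the winner
    void maxout_forward(const std::vector<float> &z, int group,
                        std::vector<float> &a, std::vector<int> &idx)
    {
        int n_out = (int)z.size() / group;
        a.assign(n_out, 0.0f);
        idx.assign(n_out, 0);
        for (int o = 0; o < n_out; o++)
        {
            int best = o * group;
            for (int k = 1; k < group; k++)
                if (z[o * group + k] > z[best]) best = o * group + k;
            a[o] = z[best];        // max over the group
            idx[o] = best;         // remembered for the backward pass
        }
    }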
400 |
401 | void WikiMaxoutNet_PCIe2::weightUpdates()
402 | {
403 | float multiplier = _learningRate/(float)_nBatchSize;
404 |
405 | if(!useRMSProp)
406 | {
407 | scalarMul(arrGRAD[0][gpu.MYRANK],multiplier/(float)arrGRAD[0][gpu.MYRANK]->rows,arrGRAD[0][gpu.MYRANK]);
408 | scalarMul(arrGRAD[1][gpu.MYRANK],multiplier/(float)arrGRAD[1][gpu.MYRANK]->rows,arrGRAD[1][gpu.MYRANK]);
409 | 		scalarMul(arrGRAD[2][gpu.MYRANK],multiplier/(float)arrGRAD[2][gpu.MYRANK]->rows,arrGRAD[2][gpu.MYRANK]);
410 | 		scalarMul(arrGRAD[3][gpu.MYRANK],multiplier/(float)arrGRAD[3][gpu.MYRANK]->rows,arrGRAD[3][gpu.MYRANK]);
411 | scalarMul(arrGRAD_B[0][gpu.MYRANK],multiplier,arrGRAD_B[0][gpu.MYRANK]);
412 | scalarMul(arrGRAD_B[1][gpu.MYRANK],multiplier,arrGRAD_B[1][gpu.MYRANK]);
413 | scalarMul(arrGRAD_B[2][gpu.MYRANK],multiplier,arrGRAD_B[2][gpu.MYRANK]);
414 | scalarMul(arrGRAD_B[3][gpu.MYRANK],multiplier,arrGRAD_B[3][gpu.MYRANK]);
415 |
416 | sub(W[1],arrGRAD[0][gpu.MYRANK],W[1]);
417 | sub(W[1],arrGRAD[1][gpu.MYRANK],W[1]);
418 | sub(W[0],arrGRAD[2][gpu.MYRANK],W[0]);
419 | sub(W[0],arrGRAD[3][gpu.MYRANK],W[0]);
420 | sub(B[1],arrGRAD_B[0][gpu.MYRANK],B[1]);
421 | sub(B[1],arrGRAD_B[1][gpu.MYRANK],B[1]);
422 | sub(B[0],arrGRAD_B[2][gpu.MYRANK],B[0]);
423 | sub(B[0],arrGRAD_B[3][gpu.MYRANK],B[0]);
424 |
425 | update_vocab_with_gradient(arrGRAD[4][gpu.MYRANK],_currentBatchIdx_Y,_Vocab,multiplier);
426 | update_vocab_with_gradient(arrGRAD[5][gpu.MYRANK],_currentBatchIdx_X,_Vocab,multiplier);
427 | }
428 | else
429 | {
430 |
431 | //10*MPI_SIZE gradients added
432 |
433 |
434 |
435 |
436 | while(gpu.get_queue_length() > (9*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
437 | fill_matrix(_Vocab_grad,0.0f);
438 | gpu.pop_queue();
439 | expand_vocab_gradient(arrGRAD[4][gpu.MYRANK],_currentBatchIdx_Y,_Vocab_grad);
440 | gpu.pop_queue();
441 | RMSprop_with_nesterov_weight_update(_MSVocab_grad_Y,_Vocab_grad,_Vocab,M_VocabY,_RMS_multiplier,_learningRate/(float)_nBatchSize,1, MOMENTUM);
442 | while(gpu.get_queue_length() > (8*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
443 | RMSprop_with_nesterov_weight_update(MSGRAD[2],arrGRAD[2][gpu.MYRANK],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[2][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
444 |
445 | while(gpu.get_queue_length() > (7*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
446 | fill_matrix(_Vocab_grad,0.0f);
447 | gpu.pop_queue();
448 | expand_vocab_gradient(arrGRAD[5][gpu.MYRANK],_currentBatchIdx_X,_Vocab_grad);
449 | gpu.pop_queue();
450 | RMSprop_with_nesterov_weight_update(_MSVocab_grad,_Vocab_grad,_Vocab,M_VocabX,_RMS_multiplier,_learningRate/(float)_nBatchSize,1, MOMENTUM);
451 |
452 | while(gpu.get_queue_length() > (6*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
453 | RMSprop_with_nesterov_weight_update(MSGRAD[3],arrGRAD[3][gpu.MYRANK],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[3][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
454 | gpu.pop_queue();
455 | while(gpu.get_queue_length() > (5*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
456 | RMSprop_with_nesterov_weight_update(MSGRAD[0],arrGRAD[0][gpu.MYRANK],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[0][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
457 | while(gpu.get_queue_length() > (4*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
458 | RMSprop_with_nesterov_weight_update(MSGRAD[1],arrGRAD[1][gpu.MYRANK],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[1][gpu.MYRANK]->rows,_nBatchSize, MOMENTUM);
459 | while(gpu.get_queue_length() > (3*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
460 | RMSprop_with_nesterov_weight_update(MSBGRAD[2],arrGRAD_B[2][gpu.MYRANK],B[0],M_B[0],0.9f,_learningRate,_nBatchSize, MOMENTUM);
461 | while(gpu.get_queue_length() > (2*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
462 | RMSprop_with_nesterov_weight_update(MSBGRAD[0],arrGRAD_B[0][gpu.MYRANK],B[1],M_B[1],0.9f,_learningRate,_nBatchSize, MOMENTUM);
463 | while(gpu.get_queue_length() > (1*(gpu.MPI_SIZE-1))){ gpu.pop_queue(); }
464 | RMSprop_with_nesterov_weight_update(MSBGRAD[1],arrGRAD_B[1][gpu.MYRANK],B[1],M_B[1],0.9f,_learningRate,_nBatchSize, MOMENTUM);
465 | while(gpu.get_queue_length() > 0){ gpu.pop_queue(); }
466 | RMSprop_with_nesterov_weight_update(MSBGRAD[3],arrGRAD_B[3][gpu.MYRANK],B[0],M_B[0],0.9f,_learningRate,_nBatchSize, MOMENTUM);
467 |
468 |
469 | }
470 |
471 | }
472 |
473 | void WikiMaxoutNet_PCIe2::backprop()
474 | {
475 | pairwise_ranking(z2_X,z2_Y, out);
476 | pairwise_ranking_derivative(z2_X,z2_Y, pairwise_grad);
477 |
478 | mul(out, pairwise_grad, e1);
479 | gpu.dotT(e1, W[1],e2_partial);
480 |
481 | gpu.dot(aB,e1,arrGRAD_B[0][gpu.MYRANK]);
482 | gpu.tick();
483 | gpu.Tdot(a1_Y,e1,arrGRAD[0][gpu.MYRANK]);
484 |
485 | if(!useMaxout)
486 | {
487 | logisticGrad(a1_Y,a1_Y);
488 | mul(e2_partial,a1_Y,e2);
489 | }
490 | else
491 | {
492 | expand_to_maxout_grad(e2_partial, a1_idx_Y,e2);
493 | }
494 | gpu.Tdot(_batchY,e2,arrGRAD[2][gpu.MYRANK]);
495 | gpu.dot(aB,e2,arrGRAD_B[2][gpu.MYRANK]);
496 | gpu.dotT(e2,W[0],arrGRAD[4][gpu.MYRANK]);
497 |
498 | gpu.add_to_queue(arrGRAD[4]);
499 | gpu.add_to_queue(arrGRAD[2]);
500 |
501 |
502 | scalarMul(pairwise_grad,-1.0f,pairwise_grad);
503 | gpu.pop_queue();
504 | mul(out, pairwise_grad, e1);
505 | gpu.pop_queue();
506 |
507 | gpu.dot(aB,e1,arrGRAD_B[1][gpu.MYRANK]);
508 | gpu.pop_queue();
509 | gpu.Tdot(a1_X,e1,arrGRAD[1][gpu.MYRANK]);
510 | gpu.pop_queue();
511 | gpu.dotT(e1, W[1],e2_partial);
512 | gpu.pop_queue();
513 |
514 |
515 | if(!useMaxout)
516 | {
517 | logisticGrad(a1_X,a1_X);
518 | gpu.pop_queue();
519 | mul(e2_partial,a1_X,e2);
520 | gpu.pop_queue();
521 | }
522 | else
523 | {
524 | expand_to_maxout_grad(e2_partial, a1_idx_X,e2);
525 | }
526 | gpu.Tdot(_batchX,e2,arrGRAD[3][gpu.MYRANK]);
527 | gpu.pop_queue();
528 | gpu.dot(aB,e2,arrGRAD_B[3][gpu.MYRANK]);
529 | gpu.pop_queue();
530 | gpu.dotT(e2,W[0],arrGRAD[5][gpu.MYRANK]);
531 |
532 | gpu.add_to_queue(arrGRAD[5]);
533 | gpu.add_to_queue(arrGRAD[3]);
534 |
535 |
536 | gpu.add_to_queue(arrGRAD[0]);
537 | gpu.add_to_queue(arrGRAD[1]);
538 |
539 | gpu.add_to_queue(arrGRAD_B[2]);
540 | gpu.add_to_queue(arrGRAD_B[0]);
541 | gpu.add_to_queue(arrGRAD_B[1]);
542 | gpu.add_to_queue(arrGRAD_B[3]);
543 |
544 | }
545 |
546 | double WikiMaxoutNet_PCIe2::calculateError()
547 | {
548 | //scalarMul(W[0],0.9,W[0]);
549 | allocateNextBatch(true);
550 | for(int i = 0; i < _nCVErrorLength; i++)
551 | {
552 |
553 | feedforward();
554 |
555 | pairwise_ranking(z2_X,z2_Y, out);
556 | _dSumError += (double)sum(out);
557 |
558 | allocateNextBatch(true);
559 | }
560 | //scalarMul(W[0],1.1,W[0]);
561 | //size_t free, total;
562 | //cudaMemGetInfo(&free, &total);
563 | //cout << free << endl;
564 | //cout << "Free system memory: " << sysconf(_SC_PAGE_SIZE)*sysconf(_SC_PHYS_PAGES) << endl;
565 |
566 |
567 | double error = _dSumError/(double)(_nBatchSize*_nCVErrorLength);
568 | _dSumError = 0.0;
569 |
570 |
571 |
572 | _nNextBatchNumber_CV = 0;
573 |
574 | return error;
575 | }
576 |
577 |
578 |
579 |
580 |
--------------------------------------------------------------------------------
/source/WikiMaxoutNet_PCIe2.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * WikiMaxoutNet_PCIe2.h
3 | *
4 | * Created on: Jun 25, 2014
5 | * Author: tim
6 | */
7 |
8 | #ifndef WIKIMAXOUTNET_PCIE2_H_
9 | #define WIKIMAXOUTNET_PCIE2_H_
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | class WikiMaxoutNet_PCIe2
24 | {
25 | public:
26 | WikiMaxoutNet_PCIe2();
27 | void run();
28 |
29 | void *hello(void)
30 | {
31 | run();
32 |
33 | return 0;
34 | }
35 |
36 | static void *hello_helper(void *context)
37 | {
38 | return ((WikiMaxoutNet_PCIe2 *)context)->hello();
39 | }
40 |
41 | void init();
42 |
43 | private:
44 | Matrix *_currentBatchIdx_X;
45 | Matrix *_currentBatchIdx_Y;
46 | Matrix *_nextBatchIdx;
47 | ClusterNet gpu;
48 | Matrix *_X;
49 | Matrix *_CV_X;
50 | Matrix *_Vocab;
51 | Matrix *_Vocab_grad;
52 | Matrix *_MSVocab_grad;
53 | Matrix *_MSVocab_grad_Y;
54 | Matrix *M_VocabX;
55 | Matrix *M_VocabY;
56 | Matrix *_Vocab_grad_idx;
57 | Matrix *_batchX;
58 | Matrix *_batchY;
59 | Matrix *stackedVocabGrad_X;
60 | Matrix *stackedVocabGrad_Y;
61 | Matrix *stackedBatchIdx_X;
62 | Matrix *stackedBatchIdx_Y;
63 |
64 | Matrix *out;
65 | Matrix *pairwise_grad;
66 | Matrix *e1;
67 | Matrix *aB;
68 | Matrix *e2_partial;
69 | Matrix *e2;
70 | Matrix *CV_container;
71 |
72 | Matrix *learning_rate_matrix;
73 |
74 |
75 |
76 | int _nCurrentDataSet;
77 | int _nNextBatchNumber;
78 | int _nNextBatchNumber_CV;
79 | float _RMS_multiplier;
80 | int _nBatchSize;
81 | int _batches;
82 | std::vector<int> _layers;
83 | std::vector<Matrix*> W;
84 | std::vector<Matrix*> B;
85 | std::vector<Matrix*> M;
86 | std::vector<Matrix*> M_B;
87 | std::vector<Matrix**> arrGRAD;
88 | std::vector<Matrix**> MSGRAD;
89 | std::vector<Matrix**> arrGRAD_B;
90 | std::vector<Matrix**> MSBGRAD;
91 | clock_t start,stop;
92 |
93 | int current_device;
94 |
95 | Matrix *d0;
96 | Matrix *z1;
97 | Matrix *a1_Y;
98 | Matrix *a1_idx_Y;
99 | Matrix *a1_X;
100 | Matrix *a1_idx_X;
101 | Matrix *d1;
102 | Matrix *z2_X;
103 | Matrix *z2_Y;
104 |
105 | cudaStream_t _streamNextBatch;
106 | double _dSumError;
107 | int _nCVErrorPeriodicity;
108 | int _nCVErrorLength;
109 | int _nMaxoutSize;
110 | float MOMENTUM;
111 | float _learningRate;
112 | int _totalNumberOfBatches;
113 |
114 | bool useRMSProp;
115 | bool useMaxout;
116 |
117 |
118 | void loadNextDataSet();
119 | void allocateNextBatch(bool isCV);
120 | void feedforward();
121 | void nesterov();
122 | double calculateError();
123 | void backprop();
124 | void weightUpdates();
125 | };
126 |
127 |
128 | #endif /* WIKIMAXOUTNET_PCIE2_H_ */
129 |
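130 | // Usage sketch (not from the original source): the hello/hello_helper pair above
131 | // exists so the net can be driven from a pthread. Something like the following
132 | // would start training on a background thread, assuming init() is called first.
133 | //
134 | //   WikiMaxoutNet_PCIe2 net;
135 | //   net.init();
136 | //   pthread_t t;
137 | //   pthread_create(&t, NULL, WikiMaxoutNet_PCIe2::hello_helper, &net);
138 | //   pthread_join(t, NULL);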
--------------------------------------------------------------------------------
/source/WikiNetDist.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | using std::cout;
4 | using std::endl;
5 |
6 | WikiNetDist::WikiNetDist(ClusterNet gpus)
7 | {
8 |
9 | int vocabSize = 100002;
10 | int nWordVectorDim = 120;
11 | int nWindowSize = 11;
12 | _layers.push_back(128);
13 | _learningRate = 0.001;
14 | _nCVErrorPeriodicity = 6000;
15 | _nCVErrorLength = 6000;
16 | MOMENTUM = 0.9;
17 | gpu = gpus;
18 | _nCurrentDataSet = 0;
19 | _X = 0;
20 | int cv_set_number = 63;
21 | _CV_X = read_hdf5(("/home/tim/data/wiki/extracted2/AA/data100000/wiki_" + NumberToString(cv_set_number) + ".p").c_str());
22 | _nNextBatchNumber = 0;
23 | _nNextBatchNumber_CV = 0;
24 | _nBatchSize = 256;
25 | _RMS_multiplier = 0.9f;
26 |
27 |
28 | cudaStreamCreate(&_streamNextBatch);
29 |
30 | useRMSProp = true;
31 |
32 |
33 | cout << "_layers: " << _layers[0] << endl;
34 | cout << "nWordVectorDim: " << nWordVectorDim << endl;
35 | cout << "_nBatchSize: " << _nBatchSize << endl;
36 | cout << "_learningRate: " << _learningRate << endl;
37 | cout << "Use RMSProp: " << useRMSProp << endl;
38 |
39 | cout << "layer size: " << _layers[0] << endl;
40 | W.push_back(gpu.distributed_uniformSqrtWeight(nWordVectorDim*nWindowSize,_layers[0]));
41 | W.push_back(gpu.distributed_uniformSqrtWeight(_layers[0], 1));
42 | B.push_back(zeros(1,_layers[0]));
43 | B.push_back(zeros(1,1));
44 | M.push_back(gpu.distributed_zeros(nWordVectorDim*nWindowSize,_layers[0]));
45 | M.push_back(gpu.distributed_zeros(_layers[0], 1));
46 | M_B.push_back(zeros(1,_layers[0]));
47 | M_B.push_back(zeros(1,1));
48 |
49 |
50 | CV_container = empty_cpu(10000,1);
51 | for(int i = 0; i < CV_container->size; i++)
52 | CV_container->data[i] = 0.0f;
53 |
54 |
55 | if(gpu.MPI_SIZE == 0)
56 | gpu.MPI_SIZE = 1;
57 |
58 |
59 | cout << gpu.MPI_SIZE << " MPI SIZE" << endl;
60 | cout << gpu.MYRANK << " MYRANK " << endl;
61 | for(int i = W.size()-1; i >= 0; i--)
62 | {
63 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
64 | arrGRAD.push_back(gradX);
65 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
66 | arrGRAD.push_back(gradY);
67 | Matrix **gradX_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
68 | arrGRAD_B.push_back(gradX_B);
69 | Matrix **gradY_B = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
70 | arrGRAD_B.push_back(gradY_B);
71 | }
72 | Matrix **gradX = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
73 | arrGRAD.push_back(gradX);
74 | Matrix **gradY = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
75 | arrGRAD.push_back(gradY);
76 |
77 | cout << arrGRAD.size() << " size" << endl;
78 |
79 |
80 | for(int i = W.size()-1; i >= 0; i--)
81 | {
82 | MSGRAD.push_back(gpu.distributed_zeros(W[i]->rows, W[i]->isDistributed ? W[i]->cols_distributed : W[i]->cols));
83 | MSGRAD.push_back(gpu.distributed_zeros(W[i]->rows, W[i]->isDistributed ? W[i]->cols_distributed : W[i]->cols));
84 | MSBGRAD.push_back(zeros(B[i]->rows, B[i]->cols));
85 | MSBGRAD.push_back(zeros(B[i]->rows, B[i]->cols));
86 | }
87 |
88 | for(int j =0; j < gpu.MPI_SIZE; j++)
89 | {
90 | int idx = 0;
91 | for(int i = W.size()-1; i >= 0; i--)
92 | {
93 | arrGRAD[idx][j] = gpu.distributed_zeros(W[i]->rows, W[i]->isDistributed ? W[i]->cols_distributed : W[i]->cols);
94 | arrGRAD_B[idx][j] = zeros(B[i]->rows, B[i]->cols);
95 | idx++;
96 | arrGRAD[idx][j] = (gpu.distributed_zeros(W[i]->rows, W[i]->isDistributed ? W[i]->cols_distributed : W[i]->cols));
97 | arrGRAD_B[idx][j] = zeros(B[i]->rows, B[i]->cols);
98 | idx++;
99 | }
100 |
101 | arrGRAD[4][j] = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
102 | arrGRAD[5][j] = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
103 | }
104 |
105 |
106 | /*
107 | if(gpu.MYRANK == 0)
108 | {
109 | Matrix *init_vocab = gpu.uniformSqrtWeight(nWordVectorDim,vocabSize);
110 | _Vocab = slice_rows(init_vocab,0,(nWordVectorDim/gpu.MPI_SIZE)-1);
111 | cout << _Vocab->size << endl;
112 |
113 | for(int i = 1; i < gpu.MPI_SIZE; i++)
114 | {
115 | Matrix *slice = slice_rows(init_vocab,(nWordVectorDim/gpu.MPI_SIZE)*i,((nWordVectorDim/gpu.MPI_SIZE)*(i+1))-1);
116 | MPI_Send(slice->data,slice->size,MPI_FLOAT,i,0,MPI_COMM_WORLD);
117 | cudaFree(slice->data);
118 | }
119 | }
120 | else
121 | {
122 | _Vocab = zeros(nWordVectorDim/gpu.MPI_SIZE,vocabSize);
123 | cout << _Vocab->size << endl;
124 | MPI_Recv(_Vocab->data,_Vocab->size,MPI_FLOAT,0,0,MPI_COMM_WORLD,MPI_STATUS_IGNORE);
125 | }
126 | */
127 |
128 |
129 | _Vocab = gpu.uniformSqrtWeight(nWordVectorDim/gpu.MPI_SIZE,vocabSize,nWordVectorDim,vocabSize);
130 | //_Vocab = gpu.sparseInitWeight(nWordVectorDim,vocabSize);
131 | //_Vocab = gpu.rand(nWordVectorDim/gpu.MPI_SIZE,vocabSize);
132 | //scalarMul(_Vocab,0.01f,_Vocab);
133 | //scalarAdd(_Vocab,-0.5f,_Vocab);
134 | cout << sum(_Vocab) << endl;
135 | _Vocab_grad = zeros(nWordVectorDim/gpu.MPI_SIZE,vocabSize);
136 | _Vocab_grad_full = zeros(nWordVectorDim,vocabSize);
137 | _MSVocab_grad = zeros(nWordVectorDim/gpu.MPI_SIZE,vocabSize);
138 | _MSVocab_grad_Y = zeros(nWordVectorDim/gpu.MPI_SIZE,vocabSize);
139 | M_VocabX = zeros(nWordVectorDim/gpu.MPI_SIZE,vocabSize);
140 | M_VocabY = zeros(nWordVectorDim/gpu.MPI_SIZE,vocabSize);
141 | _Vocab_grad_idx = zeros(1,vocabSize);
142 |
143 | d0 = zeros(_nBatchSize,nWordVectorDim*nWindowSize);
144 | z1 = zeros(_nBatchSize, _layers[0]);
145 | a1_Y = zeros(_nBatchSize, _layers[0]);
146 | a1_idx_Y = zeros(_nBatchSize, _layers[0]);
147 | a1_X = zeros(_nBatchSize, _layers[0]);
148 | a1_idx_X = zeros(_nBatchSize, _layers[0]);
149 | d1 = zeros(_nBatchSize, _layers[0]);
150 | z2_X = zeros(_nBatchSize, 1);
151 | z2_Y = zeros(_nBatchSize, 1);
152 |
153 | out = zeros(_nBatchSize,1);
154 | pairwise_grad = zeros(_nBatchSize,1);
155 | e1 = empty(_nBatchSize,1);
156 | aB = ones(1,_nBatchSize);
157 | e2_partial = zeros(_nBatchSize,W[1]->rows);
158 | e2 = empty(_nBatchSize,e2_partial->cols);
159 |
160 |
161 | _batchX = zeros(_nBatchSize, nWordVectorDim*nWindowSize);
162 | _batchY = zeros(_nBatchSize, nWordVectorDim*nWindowSize);
163 |
164 | stackedBatch_X = gpu.zeros_stacked(_nBatchSize, nWordVectorDim*nWindowSize/gpu.MPI_SIZE);
165 | stackedBatch_Y = gpu.zeros_stacked(_nBatchSize, nWordVectorDim*nWindowSize/gpu.MPI_SIZE);
166 |
167 | _currentBatchIdx_X = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
168 | _currentBatchIdx_Y = (Matrix**)malloc(sizeof(Matrix*)*gpu.MPI_SIZE);
169 | for(int i = 0; i < gpu.MPI_SIZE; i++)
170 | {
171 | _currentBatchIdx_X[i] = zeros(_nBatchSize, nWindowSize);
172 | _currentBatchIdx_Y[i] = zeros(_nBatchSize, nWindowSize);
173 | }
174 | _nextBatchIdx = zeros(_nBatchSize,nWindowSize);
175 |
176 | _dSumError = 0.0;
177 |
178 | loadNextDataSet();
179 |
180 |
181 | }
182 | void WikiNetDist::run()
183 | {
184 | allocateNextBatch(false);
185 |
186 | size_t freemem, total;
187 | cudaMemGetInfo(&freemem,&total);
188 | cout << freemem << endl;
189 |
190 |
191 | srand( time(NULL) );
192 | start = clock();
193 |
194 | int i = 0;
195 | while(true)
196 | {
197 | if(i > 0 && i % _nCVErrorPeriodicity == 0)
198 | {
199 | if( i > 0 && i % 12000 == 0)
200 | {
201 | cout << "Saving vocabulary matrix to disk..." << endl;
202 | Matrix *host = to_host(_Vocab);
203 | write_hdf5(("/home/tim/data/wiki/vocab" + NumberToString(gpu.MYRANK) + ".hdf5").c_str(),host);
204 | free(host->data);
205 | free(host);
206 | write_hdf5(("/home/tim/data/wiki/CV_values" + NumberToString(gpu.MYRANK) + ".hdf5").c_str(),CV_container);
207 | }
208 |
209 | double error = calculateError();
210 | CV_container->data[i/_nCVErrorPeriodicity] = (float)error;
211 | cout << "BatchNo: " << i << endl;
212 | cout << "Cross validation error: " << error << endl;
213 | i++;
214 |
215 | //MOMENTUM+= 0.01;
216 | //if( MOMENTUM > 0.95)
217 | //MOMENTUM = 0.95;
218 |
219 | _RMS_multiplier-= 0.01;
220 | if( _RMS_multiplier < 0.25)
221 | _RMS_multiplier = 0.25;
222 |
223 |
224 | MOMENTUM-= 0.01;
225 | if( MOMENTUM < 0.25)
226 | MOMENTUM = 0.25;
227 |
228 | stop = clock();
229 |
230 | double time_interval_seconds = (double((stop - start)) / CLOCKS_PER_SEC) ;
231 | cout << "Approximate time left in hours: " << ((1.0f/(((i*_nBatchSize)/(float)_X->rows)/63.0))*time_interval_seconds/(float)3600.0) -
232 | (time_interval_seconds/(float)3600.0)<< endl;
233 |
234 | }
235 | else
236 | {
237 | nesterov();
238 | feedforward();
239 | backprop();
240 | weightUpdates();
241 | }
242 |
243 | //cout << i << endl;
244 | allocateNextBatch(false);
245 | i++;
246 | }
247 | }
248 |
249 | void WikiNetDist::loadNextDataSet()
250 | {
251 |
252 |
253 | std::string path = "/home/tim/data/wiki/extracted2/AA/data100000/wiki_";
254 | //std::string path = "/home/tim/data/wiki/extracted2/AA/data/wiki_";
255 | std::string number = "";
256 | std::string ending = ".p";
257 | //std::string ending = ".p.hdf5";
258 |
259 | if(_nCurrentDataSet < 10)
260 | number += "0";
261 |
262 | number+= NumberToString(_nCurrentDataSet);
263 |
264 | cout << "Loading next data set: " << (path + number + ending) << endl;
265 | if(_X != 0)
266 | cudaFreeHost(_X->data);
267 | _X = read_hdf5((path + number + ending).c_str());
268 | _nCurrentDataSet++;
269 | _batches = _X->rows/ _nBatchSize;
270 | _nNextBatchNumber = 0;
271 | }
272 |
273 | void WikiNetDist::allocateNextBatch(bool isCV)
274 | {
275 | if(!isCV)
276 | {
277 | if(_nNextBatchNumber < 0)
278 | _nNextBatchNumber = 0;
279 |
280 | if (_nBatchSize*11*(_nNextBatchNumber+1) > _X->size)
281 | loadNextDataSet();
282 |
283 | if(_nNextBatchNumber > 0 || _nCurrentDataSet > 0)
284 | {
285 | cudaStreamSynchronize(_streamNextBatch);
286 | to_col_major(_nextBatchIdx, _currentBatchIdx_X[gpu.MYRANK]);
287 | gpu.construct_vocab_matrix(_currentBatchIdx_X[gpu.MYRANK], _currentBatchIdx_Y[gpu.MYRANK], stackedBatch_X[gpu.MYRANK], stackedBatch_Y[gpu.MYRANK], _Vocab);
288 | gpu.add_to_queue(stackedBatch_X);
289 | gpu.add_to_queue(stackedBatch_Y);
290 | while(gpu.get_queue_length() > 0){ gpu.pop_queue(); }
291 | concatVocabBatchesN(stackedBatch_X, stackedBatch_Y, _batchX, _batchY,11,gpu.MPI_SIZE);
292 | }
293 |
294 |
295 | cudaMemcpyAsync(_nextBatchIdx->data,&_X->data[_nBatchSize*11*_nNextBatchNumber],
296 | _nBatchSize*11*sizeof(float),
297 | cudaMemcpyHostToDevice,_streamNextBatch);
298 |
299 |
300 | _nNextBatchNumber+=1;
301 | }
302 | else
303 | {
304 | if(_nNextBatchNumber_CV > 0)
305 | {
306 | cudaStreamSynchronize(_streamNextBatch);
307 | to_col_major(_nextBatchIdx, _currentBatchIdx_X[gpu.MYRANK]);
308 | gpu.construct_vocab_matrix(_currentBatchIdx_X[gpu.MYRANK], _currentBatchIdx_Y[gpu.MYRANK], stackedBatch_X[gpu.MYRANK], stackedBatch_Y[gpu.MYRANK], _Vocab);
309 |
310 | gpu.add_to_queue(stackedBatch_X);
311 | gpu.add_to_queue(stackedBatch_Y);
312 | while(gpu.get_queue_length() > 0){ gpu.pop_queue(); }
313 | concatVocabBatchesN(stackedBatch_X, stackedBatch_Y, _batchX, _batchY,11,gpu.MPI_SIZE);
314 | }
315 |
316 | cudaMemcpyAsync(_nextBatchIdx->data,&_CV_X->data[_nBatchSize*11*_nNextBatchNumber_CV],
317 | _nBatchSize*11*sizeof(float),
318 | cudaMemcpyHostToDevice,_streamNextBatch);
319 |
320 |
321 | _nNextBatchNumber_CV+=1;
322 | }
323 |
324 |
325 | }
326 |
327 | void WikiNetDist::nesterov()
328 | {
329 | //Nesterov lookahead: move the weights along the momentum direction before the gradients are computed
330 | for(int i = 0;i < M.size(); i++)
331 | {
332 | scalarMul(M[i],MOMENTUM,M[i]);
333 | add(W[i],M[i],W[i]);
334 | }
335 |
336 | for(int i = 0;i < M_B.size(); i++)
337 | {
338 | scalarMul(M_B[i],MOMENTUM,M_B[i]);
339 | add(B[i],M_B[i],B[i]);
340 | }
341 |
342 | scalarMul(M_VocabX, MOMENTUM, M_VocabX);
343 | add(_Vocab,M_VocabX,_Vocab);
344 |
345 | scalarMul(M_VocabY, MOMENTUM, M_VocabY);
346 | add(_Vocab,M_VocabY,_Vocab);
347 |
348 | }
349 |
350 |
351 | void WikiNetDist::feedforward()
352 | {
353 | gpu.dot(_batchX,W[0],z1);
354 | addMatrixVector(z1,B[0],z1);
355 | logistic(z1,a1_X);
356 | gpu.dot(a1_X,W[1],z2_X);
357 | addMatrixVector(z2_X,B[1],z2_X);
358 |
359 | gpu.dot(_batchY,W[0],z1);
360 | addMatrixVector(z1,B[0],z1);
361 | logistic(z1,a1_Y);
362 | gpu.dot(a1_Y,W[1],z2_Y);
363 | addMatrixVector(z2_Y,B[1],z2_Y);
364 |
365 | }
366 |
367 | void WikiNetDist::weightUpdates()
368 | {
369 | float multiplier = _learningRate/(float)_nBatchSize;
370 |
371 | if(!useRMSProp)
372 | {
373 | scalarMul(M_VocabX,MOMENTUM,M_VocabX);
374 | scalarMul(M_VocabY,MOMENTUM,M_VocabY);
375 | for(int i = 0; i < M.size(); i++)
376 | scalarMul(M[i],MOMENTUM,M[i]);
377 | for(int i = 0; i < M_B.size(); i++)
378 | scalarMul(M_B[i],MOMENTUM,M_B[i]);
379 |
380 | update_vocab_with_gradient(arrGRAD[4][gpu.MYRANK],_currentBatchIdx_Y[gpu.MYRANK],_Vocab,multiplier/(float)gpu.MPI_SIZE);
381 | update_vocab_with_gradient(arrGRAD[4][gpu.MYRANK],_currentBatchIdx_Y[gpu.MYRANK],M_VocabY,multiplier/(float)gpu.MPI_SIZE);
382 | sub(W[0],arrGRAD[2][gpu.MYRANK],W[0]);
383 | sub(M[0],arrGRAD[2][gpu.MYRANK],M[0]);
384 | update_vocab_with_gradient(arrGRAD[5][gpu.MYRANK],_currentBatchIdx_X[gpu.MYRANK],_Vocab,multiplier/(float)gpu.MPI_SIZE);
385 | update_vocab_with_gradient(arrGRAD[5][gpu.MYRANK],_currentBatchIdx_X[gpu.MYRANK],M_VocabX,multiplier/(float)gpu.MPI_SIZE);
386 | sub(W[0],arrGRAD[3][gpu.MYRANK],W[0]);
387 | sub(M[0],arrGRAD[3][gpu.MYRANK],M[0]);
388 | sub(W[1],arrGRAD[0][gpu.MYRANK],W[1]);
389 | sub(M[1],arrGRAD[0][gpu.MYRANK],M[1]);
390 | sub(W[1],arrGRAD[1][gpu.MYRANK],W[1]);
391 | sub(M[1],arrGRAD[1][gpu.MYRANK],M[1]);
392 | sub(B[0],arrGRAD_B[2][gpu.MYRANK],B[0]);
393 | sub(M_B[0],arrGRAD_B[2][gpu.MYRANK],M_B[0]);
394 | sub(B[1],arrGRAD_B[0][gpu.MYRANK],B[1]);
395 | sub(M_B[1],arrGRAD_B[0][gpu.MYRANK],M_B[1]);
396 | sub(B[1],arrGRAD_B[1][gpu.MYRANK],B[1]);
397 | sub(M_B[1],arrGRAD_B[1][gpu.MYRANK],M_B[1]);
398 | sub(B[0],arrGRAD_B[3][gpu.MYRANK],B[0]);
399 | sub(M_B[0],arrGRAD_B[3][gpu.MYRANK],M_B[0]);
400 | }
401 | else
402 | {
403 |
404 | //10*MPI_SIZE gradients added
405 |
406 |
407 | fill_matrix(_Vocab_grad,0.0f);
408 | expand_partial_vocab_gradient(arrGRAD[4][gpu.MYRANK],_currentBatchIdx_Y[gpu.MYRANK],_Vocab_grad,gpu.MYRANK,gpu.MPI_SIZE);
409 |
410 | RMSprop_with_nesterov_weight_update(_MSVocab_grad_Y,_Vocab_grad,_Vocab,M_VocabY,_RMS_multiplier,_learningRate/(float)_nBatchSize,_nBatchSize, MOMENTUM);
411 |
412 | fill_matrix(_Vocab_grad,0.0f);
413 | expand_partial_vocab_gradient(arrGRAD[5][gpu.MYRANK],_currentBatchIdx_X[gpu.MYRANK],_Vocab_grad, gpu.MYRANK,gpu.MPI_SIZE);
414 | RMSprop_with_nesterov_weight_update(_MSVocab_grad,_Vocab_grad,_Vocab,M_VocabX,_RMS_multiplier,_learningRate/(float)_nBatchSize,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
415 |
416 |
417 |
418 | RMSprop_with_nesterov_weight_update(MSGRAD[2],arrGRAD[2][gpu.MYRANK],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[2][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
419 |
420 |
421 | RMSprop_with_nesterov_weight_update(MSGRAD[3],arrGRAD[3][gpu.MYRANK],W[0],M[0],0.9f,_learningRate/(float)arrGRAD[3][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
422 |
423 | RMSprop_with_nesterov_weight_update(MSGRAD[0],arrGRAD[0][gpu.MYRANK],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[0][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
424 | RMSprop_with_nesterov_weight_update(MSGRAD[1],arrGRAD[1][gpu.MYRANK],W[1],M[1],0.9f,_learningRate/(float)arrGRAD[1][gpu.MYRANK]->rows,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
425 | RMSprop_with_nesterov_weight_update(MSBGRAD[2],arrGRAD_B[2][gpu.MYRANK],B[0],M_B[0],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
426 | RMSprop_with_nesterov_weight_update(MSBGRAD[0],arrGRAD_B[0][gpu.MYRANK],B[1],M_B[1],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
427 | RMSprop_with_nesterov_weight_update(MSBGRAD[1],arrGRAD_B[1][gpu.MYRANK],B[1],M_B[1],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
428 | RMSprop_with_nesterov_weight_update(MSBGRAD[3],arrGRAD_B[3][gpu.MYRANK],B[0],M_B[0],0.9f,_learningRate,_nBatchSize*gpu.MPI_SIZE, MOMENTUM);
429 |
430 |
431 |
432 | }
433 |
434 |
435 | MPI_Barrier(MPI_COMM_WORLD);
436 |
437 | }
438 |
439 | void WikiNetDist::backprop()
440 | {
441 | pairwise_ranking(z2_X,z2_Y, out);
442 | pairwise_ranking_derivative(z2_X,z2_Y, pairwise_grad);
443 |
444 | mul(out, pairwise_grad, e1);
445 | gpu.dotT(e1, W[1],e2_partial);
446 |
447 | gpu.dot(aB,e1,arrGRAD_B[0][gpu.MYRANK]);
448 | gpu.Tdot(a1_Y,e1,arrGRAD[0][gpu.MYRANK]);
449 |
450 | logisticGrad(a1_Y,a1_Y);
451 | mul(e2_partial,a1_Y,e2);
452 |
453 |
454 | gpu.Tdot(_batchY,e2,arrGRAD[2][gpu.MYRANK]);
455 | gpu.dot(aB,e2,arrGRAD_B[2][gpu.MYRANK]);
456 | gpu.dotT(e2,W[0],arrGRAD[4][gpu.MYRANK]);
457 |
458 | scalarMul(pairwise_grad,-1.0f,pairwise_grad); //flip the sign for the X leg of the pairwise derivative
459 | mul(out, pairwise_grad, e1);
460 |
461 | gpu.dot(aB,e1,arrGRAD_B[1][gpu.MYRANK]);
462 | gpu.Tdot(a1_X,e1,arrGRAD[1][gpu.MYRANK]);
463 | gpu.dotT(e1, W[1],e2_partial);
464 |
465 | logisticGrad(a1_X,a1_X);
466 | mul(e2_partial,a1_X,e2);
467 |
468 | gpu.Tdot(_batchX,e2,arrGRAD[3][gpu.MYRANK]);
469 | gpu.dot(aB,e2,arrGRAD_B[3][gpu.MYRANK]);
470 | gpu.dotT(e2,W[0],arrGRAD[5][gpu.MYRANK]);
471 |
472 | }
473 |
474 | double WikiNetDist::calculateError()
475 | {
476 | //scalarMul(W[0],0.9,W[0]);
477 | allocateNextBatch(true);
478 | for(int i = 0; i < _nCVErrorLength; i++)
479 | {
480 | MPI_Barrier(MPI_COMM_WORLD);
481 | feedforward();
482 |
483 | pairwise_ranking(z2_X,z2_Y, out);
484 | _dSumError += (double)sum(out);
485 |
486 | allocateNextBatch(true);
487 | }
488 | //scalarMul(W[0],1.1,W[0]);
489 | //size_t free, total;
490 | //cudaMemGetInfo(&free, &total);
491 | //cout << free << endl;
492 | //cout << "Free system memory: " << sysconf(_SC_PAGE_SIZE)*sysconf(_SC_PHYS_PAGES) << endl;
493 |
494 |
495 | double error = _dSumError/(double)(_nBatchSize*_nCVErrorLength);
496 | _dSumError = 0.0;
497 |
498 |
499 |
500 | _nNextBatchNumber_CV = 0;
501 |
502 | return error;
503 | }
504 |
505 |
506 |
507 |
508 |
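509 | // Gradient slot layout (reconstructed from backprop() above; not in the original):
510 | //   arrGRAD[0]/arrGRAD[1]: W[1] gradients for the Y and X legs of the pairwise pass
511 | //   arrGRAD[2]/arrGRAD[3]: W[0] gradients (Y leg, X leg)
512 | //   arrGRAD[4]/arrGRAD[5]: errors propagated back to the word-vector inputs,
513 | //                          scattered into _Vocab_grad in weightUpdates()
514 | //   arrGRAD_B mirrors this indexing for the biases B[1] and B[0].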
--------------------------------------------------------------------------------
/source/WikiNetDist.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * WikiNetDist.h
3 | *
4 | * Created on: Jun 25, 2014
5 | * Author: tim
6 | */
7 |
8 | #ifndef WikiNetDist_H_
9 | #define WikiNetDist_H_
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | class WikiNetDist
24 | {
25 | public:
26 | WikiNetDist(ClusterNet gpus);
27 | void run();
28 |
29 | private:
30 | Matrix **_currentBatchIdx_X;
31 | Matrix **_currentBatchIdx_Y;
32 | Matrix *_nextBatchIdx;
33 | ClusterNet gpu;
34 | Matrix *_X;
35 | Matrix *_CV_X;
36 | Matrix *_Vocab;
37 | Matrix *_Vocab_grad;
38 | Matrix *_Vocab_grad_full;
39 | Matrix *_MSVocab_grad;
40 | Matrix *_MSVocab_grad_Y;
41 | Matrix *M_VocabX;
42 | Matrix *M_VocabY;
43 | Matrix *_Vocab_grad_idx;
44 | Matrix *_batchX;
45 | Matrix *_batchY;
46 | Matrix **stackedBatch_X;
47 | Matrix **stackedBatch_Y;
48 |
49 | Matrix *out;
50 | Matrix *pairwise_grad;
51 | Matrix *e1;
52 | Matrix *aB;
53 | Matrix *e2_partial;
54 | Matrix *e2;
55 | Matrix *CV_container;
56 |
57 | Matrix *learning_rate_matrix;
58 |
59 | int _nCurrentDataSet;
60 | int _nNextBatchNumber;
61 | int _nNextBatchNumber_CV;
62 | float _RMS_multiplier;
63 | int _nBatchSize;
64 | int _batches;
65 | std::vector<int> _layers;
66 | std::vector<Matrix*> W;
67 | std::vector<Matrix*> B;
68 | std::vector<Matrix*> M;
69 | std::vector<Matrix*> M_B;
70 | std::vector<Matrix**> arrGRAD;
71 | std::vector<Matrix*> MSGRAD;
72 | std::vector<Matrix**> arrGRAD_B;
73 | std::vector<Matrix*> MSBGRAD;
74 | clock_t start,stop;
75 |
76 | Matrix *d0;
77 | Matrix *z1;
78 | Matrix *a1_Y;
79 | Matrix *a1_idx_Y;
80 | Matrix *a1_X;
81 | Matrix *a1_idx_X;
82 | Matrix *d1;
83 | Matrix *z2_X;
84 | Matrix *z2_Y;
85 |
86 | cudaStream_t _streamNextBatch;
87 | double _dSumError;
88 | int _nCVErrorPeriodicity;
89 | int _nCVErrorLength;
90 | int _nMaxoutSize;
91 | float MOMENTUM;
92 | float _learningRate;
93 | int _totalNumberOfBatches;
94 |
95 | bool useRMSProp;
96 | bool useMaxout;
97 |
98 |
99 | void loadNextDataSet();
100 | void allocateNextBatch(bool isCV);
101 | void feedforward();
102 | void nesterov();
103 | double calculateError();
104 | void backprop();
105 | void weightUpdates();
106 | };
107 |
108 |
109 | #endif /* WikiNetDist_H_ */
110 |
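111 | // Usage sketch (assumed, mirroring how the other Wiki* nets are driven;
112 | // the seed value is illustrative):
113 | //
114 | //   ClusterNet gpu(argc, argv, 123635);
115 | //   WikiNetDist net(gpu);
116 | //   net.run(); //trains until killed; CV error is printed periodically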
--------------------------------------------------------------------------------
/source/basicOps.cuh:
--------------------------------------------------------------------------------
1 | #ifndef basicOps_H
2 | #define basicOps_H
3 |
4 | #define TILE_DIM (32)
5 | #define BLOCK_ROWS (8)
6 | #define COPY_BLOCK_SIZE 16
7 |
8 | #define RDM_NUMBERS_PER_THREAD (1024)
9 | #define THREADS_PER_BLOCKS (512)
10 | #define BLOCKS (4096)
11 |
12 | #define DOT_BLOCKS (128)
13 | #define TILE_SIZE (32)
14 | #define DOT_REPS (4)
15 |
16 | #include
17 | #include
18 | #include
19 | #include
20 |
21 | //working
22 | //128,16,8
23 | //64,16,4
24 |
25 | typedef struct Matrix
26 | {
27 | int rows;
28 | int cols;
29 | size_t bytes;
30 | int size;
31 | float *data;
32 | int isDistributed;
33 | int cols_distributed;
34 |
35 | unsigned char *char_data;
36 |
37 | int isSparse;
38 | size_t ptr_bytes;
39 | size_t idx_bytes;
40 | int *ptr_rows;
41 | int *idx_cols;
42 | } Matrix;
43 |
44 | Matrix *fill_matrix(int rows, int cols, float fill_value);
45 | void fill_matrix(Matrix *A, const float fill_value);
46 | void fill_gpuarray(float *A, const float fill_value, int size);
47 | void fill_gpuarray(int *A, const int fill_value, int size);
48 | void fill_sparse_with_zeros(Matrix *A);
49 | Matrix *ones(int rows, int cols);
50 | Matrix *zeros(int rows, int cols);
51 | Matrix *empty(int rows, int cols);
52 | Matrix *empty_char(int rows, int cols);
53 | Matrix *empty_pinned(int rows, int cols);
54 | Matrix *empty_cpu(int rows, int cols);
55 | Matrix *empty_pinned_sparse(int rows, int cols, float max_sparsity, float sparsity_buffer);
56 | Matrix *empty_pinned_sparse(int rows, int cols, int nonzeros);
57 | Matrix *empty_sparse(int rows, int cols, float max_sparsity, float sparsity_buffer);
58 | Matrix *empty_sparse(int rows, int cols, int nonzeros);
59 | Matrix *arange(int rows, int cols);
60 | Matrix *arange(int start, int rows, int cols);
61 |
62 | void uniformSqrtWeight(Matrix * uniform_rdm);
63 | void uniformSqrtWeight(Matrix * uniform_rdm, int in, int out);
64 | void sparseRdmWeight(Matrix *rdm, Matrix *idx, Matrix *out, int connections);
65 | void rand_int(Matrix *int_values,int low, int high);
66 |
67 | void add_to_z(Matrix *z, Matrix *z1, Matrix *y, int classes, Matrix *out);
68 |
69 | Matrix *add(Matrix *A, Matrix *B);
70 | void add(Matrix *A, Matrix *B, Matrix *out);
71 | Matrix *sub(Matrix *A, Matrix *B);
72 | void sub(Matrix *A, Matrix *B, Matrix *out);
73 | Matrix *mul(Matrix *A, Matrix *B);
74 | void mul(Matrix *A, Matrix *B, Matrix *out);
75 | Matrix *div(Matrix *A, Matrix *B);
76 | void div(Matrix *A, Matrix *B, Matrix *out);
77 | float sum(Matrix *A);
78 | float max(Matrix *A);
79 | int getNonZeroElements(Matrix *A);
80 | int getNonZeroColumns(Matrix *A);
81 |
82 | Matrix *to_host(Matrix *A);
83 | Matrix *to_host(Matrix *A, int is_row_major);
84 | Matrix *to_gpu(Matrix *A);
85 | Matrix *to_gpu(Matrix *A, int is_col_major);
86 | Matrix *T(Matrix *A);
87 | Matrix *to_col_major(Matrix *A);
88 | void to_col_major(Matrix *A, Matrix *out);
89 | Matrix *to_row_major(Matrix *A);
90 |
91 | Matrix *scalarMul(Matrix *A, float a);
92 | void scalarMul(Matrix *A, float a, Matrix *out);
93 | Matrix *scalarAdd(Matrix *A, float a);
94 | void scalarAdd(Matrix *A, float a, Matrix *out);
95 | void dropout(Matrix *A, Matrix *out, float dropout_rate);
96 | void dropout_cached(Matrix *A, Matrix *dropout, Matrix *out, int idx);
97 | void RMSprop(Matrix *RMS, Matrix *grad, float RMS_multiplier, float learning_rate, int batch_size);
98 | void RMSprop_with_momentum_update(Matrix *RMS, Matrix *grad, Matrix *w, Matrix *m, float RMS_multiplier, float learning_rate, int batch_size, float momentum);
99 | void RMSprop_with_momentum_weight_update(Matrix *RMS, Matrix *grad, Matrix *w, Matrix *m, float RMS_multiplier, float learning_rate, int batch_size, float momentum);
100 | void RMSprop_with_nesterov_weight_update(Matrix *RMS, Matrix *grad, Matrix *w, Matrix *m, float RMS_multiplier, float learning_rate, int batch_size, float momentum);
101 | void RMSprop_with_weight_update(Matrix *RMS, Matrix *grad, Matrix *w, Matrix *m, float RMS_multiplier, float learning_rate, int batch_size, float momentum);
102 | void RMSprop_with_weight_update_8bit(Matrix *RMS, Matrix *grad, Matrix *w, Matrix *m, float RMS_multiplier, float learning_rate, int batch_size, float momentum);
103 | void Nesterov_weight_update(Matrix *RMS, Matrix *grad, Matrix *w, Matrix *m, float RMS_multiplier, float learning_rate, int batch_size, float momentum);
104 |
105 | void LocalGrad(Matrix *z, Matrix *w, Matrix *y, float learning_rate, int batch_size, float momentum);
106 | void compression_8bit(Matrix *tbl_flt, Matrix *A, float precision, Matrix *out);
107 | void compression_8bit_test(Matrix *tbl, Matrix *A, float precision, Matrix *out);
108 |
109 | void decompression_8bit(Matrix *tbl_flt, Matrix *A, float precision, Matrix *out);
110 |
111 | void renormalizeWeights(Matrix *w, Matrix *unit_sums, float limit);
112 |
113 | Matrix *square(Matrix *A);
114 | void square(Matrix *A, Matrix *out);
115 | Matrix *abs(Matrix *A);
116 | void abs(Matrix *A, Matrix *out);
117 | Matrix *gpuExp(Matrix *A);
118 | void gpuExp(Matrix *A, Matrix *out);
119 | Matrix *logistic(Matrix *A);
120 | void logistic(Matrix *A, Matrix *out);
121 | Matrix *logisticGrad(Matrix *A);
122 | void logisticGrad(Matrix *A, Matrix *out);
123 | Matrix *gpuLog(Matrix *A);
124 | void gpuLog(Matrix *A, Matrix *out);
125 | Matrix *gpuSqrt(Matrix *A);
126 | void gpuSqrt(Matrix *A, Matrix *out);
127 | Matrix *doubleRectifiedLinear(Matrix *A);
128 | void doubleRectifiedLinear(Matrix *A, Matrix *out);
129 | Matrix *LinearUnit(Matrix *A);
130 | void LinearUnit(Matrix *A, Matrix *out);
131 | Matrix *hardTanH(Matrix *A);
132 | void hardTanH(Matrix *A, Matrix *out);
133 | Matrix *pairwise_ranking(Matrix *A, Matrix *B);
134 | void pairwise_ranking(Matrix *A, Matrix *B, Matrix *out);
135 | Matrix *pairwise_ranking_derivative(Matrix *A, Matrix *B);
136 | void pairwise_ranking_derivative(Matrix *A, Matrix *B, Matrix *out);
137 |
138 | Matrix *softmax(Matrix *A);
139 | void softmax(Matrix *A, Matrix *out);
140 | Matrix *subMatrixVector(Matrix *A, Matrix *v);
141 | void subMatrixVector(Matrix *A, Matrix *v, Matrix *out);
142 | Matrix *addMatrixVector(Matrix *A, Matrix *v);
143 | void addMatrixVector(Matrix *A, Matrix *v, Matrix *out);
144 | Matrix *addScaledMatrixVector(Matrix *A, Matrix *v, float weight);
145 | void addScaledMatrixVector(Matrix *A, Matrix *v, float weight, Matrix *out);
146 | Matrix *mulMatrixVector(Matrix *A, Matrix *v);
147 | void mulMatrixVector(Matrix *A, Matrix *v, Matrix *out);
148 | Matrix *argmax(Matrix *A);
149 | void argmax(Matrix* A, Matrix* out);
150 | Matrix *create_t_matrix(Matrix *labels, int max_label);
151 | void create_t_matrix(Matrix *labels, Matrix *out);
152 | Matrix *equal(Matrix *A, Matrix *B);
153 | void equal(Matrix *A, Matrix *B, Matrix *out);
154 | Matrix *maxColumnwise(Matrix *A);
155 | void maxColumnwise(Matrix *A, Matrix *out);
156 | Matrix **maxout(Matrix *A, int maxout_level);
157 | void maxout(Matrix *A, Matrix *out, Matrix *outargmax, int maxout_level);
158 |
159 | Matrix *rectified_linear(Matrix *A);
160 | void rectified_linear(Matrix *A, Matrix *out);
161 | Matrix *rectified_linear_derivative(Matrix *A);
162 | void rectified_linear_derivative(Matrix *A, Matrix *out);
163 | Matrix *double_rectified_linear_derivative(Matrix *A);
164 | void double_rectified_linear_derivative(Matrix *A, Matrix *out);
165 | Matrix *hardTanH_derivative(Matrix *A);
166 | void hardTanH_derivative(Matrix *A, Matrix *out);
167 | Matrix *squared_error(Matrix *A, Matrix *targets);
168 | void squared_error(Matrix *A, Matrix *targets, Matrix *out);
169 |
170 | void expand_to_maxout_grad(Matrix *error, Matrix *idx, Matrix *grad);
171 |
172 | Matrix *rand_numbers(int rows, int cols, Matrix *seeds);
173 | void rand_numbers(Matrix *out, Matrix *seeds);
174 |
175 | int checkMatrixOperation(Matrix *A, Matrix *B, Matrix *C, cublasOperation_t T1, cublasOperation_t T2, int blnMatrixProduct);
176 | int blnFaultySizes(Matrix *A, Matrix *B, Matrix *C);
177 | int blnFaultyMatrixProductSizes(Matrix *A, Matrix *B, Matrix *C, cublasOperation_t T1, cublasOperation_t T2);
178 | void printFaultySizeError(Matrix *A, Matrix *B, Matrix *C);
179 | void printFaultyMatrixProductSizeError(Matrix *A, Matrix *B, Matrix *C, cublasOperation_t T1, cublasOperation_t T2);
180 | void printData(Matrix *A);
181 |
182 | Matrix *slice_rows(Matrix *A, int start, int end);
183 | Matrix *slice_cols(Matrix *A, int start, int end);
184 | void vStack(Matrix *A, Matrix *B, Matrix *out);
185 | Matrix *vStack(Matrix *A, Matrix *B);
186 | void hStack(Matrix *A, Matrix *B, Matrix *out);
187 | Matrix *hStack(Matrix *A, Matrix *B);
188 | void hStackN(float** arrA, int general_size, Matrix *out, int matrices_count);
189 | void hStackN(Matrix** arrA, int general_size, Matrix *out, int matrices_count);
190 | void vStackN(Matrix** arrA, Matrix *out, int matrices_count);
191 |
192 | void addGradientsN(Matrix** arrA, int myrank, int matrices_count, float multiplier);
193 |
194 | void sparse_dot(Matrix *A, Matrix *B, Matrix *out);
195 | void construct_vocab_matrix(Matrix *vocab_idx, Matrix *vocab_idx_y, Matrix *batch_X, Matrix *batch_y, Matrix *vocab, Matrix *rdm_idx);
196 | //void update_vocab_with_gradient(Matrix *grad, Matrix *vocab_idx, Matrix *vocab, float learning_rate);
197 | void expand_double_vocab_gradient(Matrix *gradX, Matrix *gradY, Matrix *vocab_idx_X, Matrix *vocab_idx_Y, Matrix *vocab, Matrix *vocab_grad, Matrix *vocab_grad_idx, float learning_rate);
198 | void expand_vocab_gradient(Matrix *grad, Matrix *vocab_idx, Matrix *vocab_grad);
199 | void expand_partial_vocab_gradient(Matrix *grad, Matrix *vocab_idx, Matrix *vocab_grad, int matrix_idx, int matrix_count);
200 | void update_vocab_with_gradient(Matrix *grad, Matrix *vocab_idx, Matrix *vocab, float learning_rate);
201 | void concatVocabBatchesN(Matrix** arrBatch_X, Matrix **arrBatch_Y, Matrix *out_X, Matrix *out_Y, int window_size, int matrices_count);
202 | void expand_vocab_gradient_middle_word(Matrix *grad, Matrix *vocab_idx, Matrix *vocab_grad);
203 | void matmul(Matrix *A, Matrix *B, Matrix *out, int T1, int T2);
204 | void dot8bit(Matrix *charA, Matrix *charB, Matrix* out, Matrix *flt_tbl, float precisionA, float precisionB);
205 | void dot8bit_shared(Matrix *charA, Matrix *charB, Matrix* out, Matrix *flt_tbl, float precisionA, float precisionB);
206 | #endif
207 |
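208 | // Usage sketch (illustrative only; every function used here is declared above):
209 | //
210 | //   Matrix *A = ones(128, 64);       //GPU matrix filled with 1.0f
211 | //   Matrix *B = zeros(128, 64);
212 | //   add(A, B, B);                    //elementwise, result written into B
213 | //   logistic(B, B);                  //in-place sigmoid
214 | //   Matrix *h = to_host(B);          //copy back for inspection on the CPU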
--------------------------------------------------------------------------------
/source/batchAllocator.h:
--------------------------------------------------------------------------------
1 | #ifndef BatchAllocator_H
2 | #define BatchAllocator_H
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 |
15 | typedef enum BatchAllocationMethod_t
16 | {
17 | Single_GPU = 0,
18 | Batch_split = 1,
19 | Distributed_weights = 2,
20 | Distributed_weights_sparse = 4
21 | } BatchAllocationMethod_t;
22 |
23 | class BatchAllocator
24 | {
25 |
26 | public:
27 | BatchAllocator();
28 | BatchAllocationMethod_t BATCH_METHOD;
29 | Matrix *CURRENT_BATCH;
30 | Matrix *CURRENT_BATCH_Y;
31 | Matrix *CURRENT_BATCH_CV;
32 | Matrix *CURRENT_BATCH_CV_Y;
33 | int TOTAL_BATCHES;
34 | int TOTAL_BATCHES_CV;
35 | int BATCH_SIZE;
36 | int BATCH_SIZE_CV;
37 | int TRAIN_SET_ROWS;
38 | int CV_SET_ROWS;
39 |
40 | bool SKIP_LAST_BATCH;
41 |
42 | void finish_batch_allocator();
43 | void broadcast_batch_to_processes();
44 | void broadcast_batch_cv_to_processes();
45 | void allocate_next_batch_async();
46 | void allocate_next_cv_batch_async();
47 | void replace_current_batch_with_next();
48 | void replace_current_cv_batch_with_next();
49 |
50 | void init(Matrix *X, Matrix *y, float cross_validation_size, int batch_size, int cv_batch_size, ClusterNet *cluster, BatchAllocationMethod_t batchmethod);
51 | void init(std::string path_X, std::string path_y, float cross_validation_size, int batch_size, int cv_batch_size, ClusterNet *cluster, BatchAllocationMethod_t batchmethod);
52 | void init(Matrix *X, Matrix *y, float cross_validation_size, int batch_size, int cv_batch_size);
53 |
54 | void propagate_through_layers(Layer *root, DataPropagationType_t type, int epoch);
55 |
56 | int m_next_batch_number_cv;
57 | private:
58 | Matrix *m_next_batch_X;
59 | Matrix *m_next_batch_y;
60 | Matrix *m_next_batch_cv_X;
61 | Matrix *m_next_batch_cv_y;
62 | Matrix *m_full_X;
63 | Matrix *m_full_y;
64 |
65 | Matrix* m_next_buffer_X;
66 | Matrix* m_next_buffer_y;
67 | Matrix* m_next_buffer_cv_X;
68 | Matrix* m_next_buffer_cv_y;
69 |
70 | int m_next_batch_number;
71 | int m_Cols_X;
72 | int m_Cols_y;
73 | int m_Rows;
74 | int m_mygpuID;
75 | int m_myrank;
76 |
77 | int m_sparse_matrix_info_X[6];
78 | int m_sparse_matrix_info_y[6];
79 | int m_sparse_matrix_info_cv_X[6];
80 | int m_sparse_matrix_info_cv_y[6];
81 |
82 | ClusterNet *m_cluster;
83 | MPI_Status m_status;
84 |
85 | cudaStream_t m_streamNext_batch_X;
86 | cudaStream_t m_streamNext_batch_y;
87 | cudaStream_t m_streamNext_batch_cv_X;
88 | cudaStream_t m_streamNext_batch_cv_y;
89 |
90 | std::vector m_requests_send_X;
91 | std::vector m_requests_send_y;
92 | std::vector m_requests_send_cv_X;
93 | std::vector m_requests_send_cv_y;
94 |
95 | std::vector m_request_X;
96 | std::vector m_request_y;
97 | std::vector m_request_cv_X;
98 | std::vector m_request_cv_y;
99 |
100 | void MPI_get_dataset_dimensions();
101 | void init(float cross_validation_size, int batch_size, int cv_batch_size);
102 | void init_batch_buffer();
103 | void init_copy_to_buffer();
104 | void update_next_batch_matrix_info();
105 | void update_next_cv_batch_matrix_info();
106 |
107 |
108 | };
109 | #endif
110 |
111 |
112 |
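113 | // Usage sketch (single-GPU case; X, y and a ClusterNet instance gpu are assumed
114 | // to exist already):
115 | //
116 | //   BatchAllocator b;
117 | //   b.init(X, y, 0.2f, 128, 512, &gpu, Single_GPU);
118 | //   for(int i = 0; i < b.TOTAL_BATCHES; i++)
119 | //   {
120 | //       b.allocate_next_batch_async(); //prefetch while computing
121 | //       //... use b.CURRENT_BATCH and b.CURRENT_BATCH_Y here ...
122 | //       b.replace_current_batch_with_next();
123 | //   }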
--------------------------------------------------------------------------------
/source/clusterKernels.cuh:
--------------------------------------------------------------------------------
1 | #ifndef clusterKernels
2 | #define clusterKernels
3 |
4 | #include "curand.h"
5 | #include "curand_kernel.h"
6 | __global__ void kRdmNumbers(float *seed, int size, float *out);
7 | __global__ void kCompression_8bit_test(float *tbl, float *A, float precision, int size, float *out);
8 | __global__ void kCompression_8bit(float *flt_tbl, float *A, float precision, int size, unsigned char *out);
9 | __global__ void kCompression_8bit_float(float *flt_tbl, float *A, float precision, int size, float *out);
10 | __global__ void kDecompression_8bit(float *flt_tbl, unsigned char *A, float precision, int size, float *out);
11 | __global__ void kRenormalizeWeights(float *w, float *unit_sums, float limit, int rows, int cols);
12 | __global__ void kGetNonZeroColumns(float *A, float *out, int rows, int cols);
13 | __global__ void kGetNonZeroElements(float *A, float *out, int size);
14 | __global__ void kFill_with(float *m, float fill_value, int size);
15 | __global__ void kFill_with(int *m, int fill_value, int size);
16 | __global__ void kAdd(float *A,float *B, float *out, int size);
17 | __global__ void kAdd_to_z(float *z, float *z1, float *y, float *y_count, int batch_size, int units, float *out);
18 | __global__ void kSub(float *A,float *B, float *out, int size);
19 | __global__ void kSub_Sparse(float *A, float *data, int *ptr_rows, int *idx_cols, float *out, int rows, int cols, int size);
20 | __global__ void kMul(float *A,float *B, float *out, int size);
21 | __global__ void kDiv(float *A,float *B, float *out, int size);
22 | __global__ void kExp(float *A, float *out, int size);
23 | __global__ void kLog(float *A, float *out, int size);
24 | __global__ void kSqrt(float *A, float *out, int size);
25 | __global__ void kSquare(float *A, float *out, int size);
26 | __global__ void kAbs(float *A, float *out, int size);
27 | __global__ void kScalarMul(float *A, float scalar, float *out, int size);
28 | __global__ void kScalarAdd(float *A, float scalar, float *out, int size);
29 | __global__ void kTranspose(float *A, float *out, int width, int height);
30 | __global__ void setup_kernel(curandState *state, int seed);
31 | __global__ void generate_uniform_kernel(curandState *state, int size, float *out);
32 | __global__ void generate_normal_kernel(curandState *state, int size, float *out);
33 | __global__ void slice_rows(float *A, float *out, int size_out, int rows_A, int start, int end);
34 | __global__ void slice_cols(float *A, float *out, int start, int rows, int size_out);
35 | __global__ void vStack(float *A, float *B, float *out, int size_out, int rows_a, int rows, int cols);
36 | __global__ void hStack(float *A, float *B, float *out, int size_out, int size_a);
37 | __global__ void hStackN(float **arrA, int general_size, float *out, int size_out, int matrices_count);
38 | __global__ void hStackN(Matrix **arrA, int general_size, float *out, int size_out, int matrices_count);
39 | __global__ void vStackN(float **arrA, float *out, int rows, int cols);
40 | __global__ void AddGradientsN(float **arrA, int size, int myrank, int matrix_count, float multiplier);
41 | __global__ void kSoftMax(float* A, float* out, unsigned int rows, unsigned int cols);
42 | __device__ void reduceToMax(float* sdata, unsigned int tid);
43 | __device__ void reduceToSumLocal(float* sdata, unsigned int tid);
44 | __global__ void kSubMatrixVector(float *A, float *v, float *out, int rows, int size);
45 | __global__ void kAddMatrixVector(float *A, float *v, float *out, int rows, int size);
46 | __global__ void kMulMatrixVector(float *A, float *v, float *out, int rows, int size);
47 | __global__ void kAddScaledMatrixVector(float *A, float *v, float weight, float *out, int rows, int size);
48 | __global__ void kDot8bit(unsigned char *A, unsigned char *B, float *out, int rowsA, int colsA, int colsB, float *flt_tbl, float precisionA, float precisionB);
49 | __global__ void kDot8bit_shared(unsigned char *A, unsigned char *B, float *out, int rowsA, int colsA, int colsB, float *flt_tbl, float precisionA, float precisionB);
50 | __global__ void kArgmax(float* A, float* out, unsigned int height, unsigned int width);
51 | __global__ void kCreate_t_matrix(float *labels, float *out, int rows, int size);
52 | __global__ void kEqual(float *A, float *B, float *out, int size);
53 | __global__ void kSum(float *v, float *out, int size);
54 | __global__ void kLogistic(float *A, float *out, int size);
55 | __global__ void kLogisticGrad(float *A, float *out, int size);
56 | __global__ void kArange(float *out, int start, int rows, int cols, int size);
57 | __global__ void kDropout(float *A, float *rdm, float dropout, int size);
58 | __global__ void kDropout_cached(float *A, float *dropout, float *out, int current_idx, int size);
59 | __global__ void kRMSprop(float *RMS, float *grad, float RMS_multiplier, float learning_rate, int batch_size, int size);
60 | __global__ void kRMSprop_with_momentum_update(float *RMS, float *grad, float *w, float *m, float RMS_multiplier, float learning_rate, int batch_size, int size, float momentum);
61 | __global__ void kRMSprop_with_momentum_weight_update(float *RMS, float *grad, float *w, float *m, float RMS_multiplier, float learning_rate, int batch_size, int size, float momentum);
62 | __global__ void kLocalGrad (float *z, float *w, float *y, float *m, float learning_rate, int batch_size, int size, float momentum);
63 | __global__ void kRMSprop_with_nesterov_weight_update(float *RMS, float *grad, float *w, float *m, float RMS_multiplier, float learning_rate, int batch_size, int size, float momentum);
64 | __global__ void kNesterov_weight_update(float *RMS, float *grad, float *w, float *m, float RMS_multiplier, float learning_rate, int batch_size, int size, float momentum);
65 | __global__ void kRMSprop_with_weight_update(float *RMS, float *grad, float *w, float *m, float RMS_multiplier, float learning_rate, int batch_size, int size, float momentum);
66 | __global__ void kRMSprop_with_weight_update_8bit(float *RMS, float *grad, float *w, float *m, float RMS_multiplier, float learning_rate, int batch_size, int size, float momentum);
67 | __global__ void kCreateRdmSqrtWeight_Logistic(float *A, int in, int out, int size);
68 | __global__ void kRandInt(float *A, int lower_limit, int upper_limit, int size);
69 | __global__ void kCreateSparseRdmWeight(float *rdm, float* indices, float *out, int rows, int cols, int connections);
70 | __global__ void kRectifiedLinear(float *A, float *out, int size);
71 | __global__ void kRectifiedLinear_Derivative(float *A, float *out, int size);
72 | __global__ void kSquaredError(float *A, float *t, float *out, int size);
73 | __global__ void kLinear(float *A, float *out, int size);
74 | __global__ void kDoubleRectifiedLinear(float* A, float* out, int size);
75 | __global__ void kDoubleRectifiedLinear_Derivative(float *A, float *out, int size);
76 | __global__ void kSparseDot(int m, int n, int k, float *data, int* indptr, int* indices, float *dense_data, float* target, float beta, float alpha);
77 | __global__ void kPrintData(float *A, int size);
78 | __global__ void kHardTanH(float *A, float *out, int size);
79 | __global__ void kHardTanH_Derivative(float *A, float *out, int size);
80 | __global__ void kPairwise_ranking(float *A, float *B, float *out, int size);
81 | __global__ void kPairwise_ranking_derivative(float *A, float *B, float *out, int size);
82 | __global__ void kMaxColumnwise(float* mat, float* target, unsigned int width, unsigned int height);
83 | __global__ void kMaxout(float *A, float *out, float *outargmax, int maxout_level, unsigned int cols, unsigned int rows);
84 | __device__ void reduceToMaxAndArgMax(float* sdataMax, float* sdataArgMax, unsigned int tid, int threads);
85 | __global__ void kExpandToMaxoutGrad(float* error, float* indexes, float *out, int error_size, int error_rows, int maxout_level);
86 | __global__ void kConstructVocabMatrix(float *vocab_idx, float *vocab_idx_y, float* vocab, float *rdm_idx, float *batch_X, float *batch_Y);
87 | __global__ void kExpandDoubleVocabGradient(float *gradX, float *gradY, float *vocab_idx_X, float *vocab_idx_Y, float* vocab,
88 | float *vocab_grad, float *vocab_grad_idx, float learning_rate, int grad_size);
89 | __global__ void kExpandVocabGradient(float *grad, float *vocab_idx, float *vocab_grad);
90 | __global__ void kExpandPartialVocabGradient(float *grad, float *vocab_idx, float *vocab_grad, int matrix_idx, int matrix_count);
91 | __global__ void kExpandVocabGradientMiddleWord(float *grad, float *vocab_idx, float *vocab_grad);
92 | __global__ void kUpdateVocabWithGradient(float *grad, float *vocab_idx, float* vocab, float learning_rate);
93 | __global__ void MatMul(float* A, float* B, float* C, int ARows, int ACols, int BRows, int BCols, int CRows, int CCols);
94 | __global__ void sgemm_kernel_N_N_64_16_16_16_4(float* C,const float* A,const float* B, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta );
95 | __global__ void sgemm_kernel_N_T_64_16_4_16_4(float* C, const float* A, const float* B, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta );
96 | __global__ void sgemm_kernel_T_N_32_32_8_8_8(float* C, const float* A, const float* B, int m, int n, int k, int lda, int ldb, int ldc, float alpha, float beta );
97 | __global__ void sgemmNN( const float *A, int lda, const float *B, int ldb, float* C, int ldc, int k, float alpha, float beta );
98 | __global__ void concat_batches(float **batch_X, float **batch_Y, float *out_X, float *out_Y);
99 | #endif
100 |
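101 | // Launch sketch (assumed configuration; in this project the kernels are launched
102 | // from basicOps.cu and clusterNet.cpp, and BLOCKS/THREADS_PER_BLOCKS come from
103 | // basicOps.cuh):
104 | //
105 | //   //fill a device buffer of A->size floats with zeros
106 | //   kFill_with<<<BLOCKS, THREADS_PER_BLOCKS>>>(A->data, 0.0f, A->size);
107 | //   cudaDeviceSynchronize();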
--------------------------------------------------------------------------------
/source/clusterNet.h:
--------------------------------------------------------------------------------
1 | #ifndef ClusterNet_H
2 | #define ClusterNet_H
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 | #include
25 | #include
26 |
27 | typedef enum Unittype_t
28 | {
29 | Logistic = 0,
30 | Rectified_Linear = 1,
31 | Softmax = 2,
32 | Linear = 4,
33 | Double_Rectified_Linear = 8,
34 | Input = 16
35 | } Unittype_t;
36 |
37 | typedef enum DataPropagationType_t
38 | {
39 | Training = 0,
40 | Trainerror = 1,
41 | CVerror = 2
42 | } DataPropagationType_t;
43 |
44 |
45 | typedef enum WeightUpdateType_t
46 | {
47 | NesterovRMSProp = 0,
48 | NesterovMomentum = 1,
49 | RMSProp = 2,
50 | Momentum = 4,
51 | NoMomentum = 8
52 | } WeightUpdateType_t;
53 |
54 |
55 | typedef enum ParallelismType_t
56 | {
57 | None = 0,
58 | DataParallelism = 1,
59 | ModelParallelism = 2
60 | } ParallelismType_t;
61 |
62 | typedef enum Compression_t
63 | {
64 | bits_32 = 0,
65 | bits_8 = 1
66 | } Compression_t;
67 |
68 | typedef enum Costfunction_t
69 | {
70 | Cross_Entropy = 0,
71 | Squared_Error = 1,
72 | Root_Squared_Error = 2,
73 | Misclassification = 4
74 | } Costfunction_t;
75 |
76 |
77 | class ClusterNet
78 | {
79 |
80 |
81 | public:
82 | ClusterNet();
83 | ClusterNet(int seed);
84 | ClusterNet(int argc, char* argv[]);
85 | ClusterNet(int argc, char *argv[], int seed);
86 | ClusterNet(int argc, char* argv[], int seed, bool useSameSeed);
87 |
88 | void dotPCIe(Matrix **A, Matrix **B, Matrix **out);
89 | void dotTPCIe(Matrix **A, Matrix **B, Matrix **out);
90 | void TdotPCIe(Matrix **A, Matrix **B, Matrix **out);
91 | Matrix *dot(Matrix *A, Matrix *B);
92 | Matrix *Tdot(Matrix *A, Matrix *B);
93 | Matrix *dotT(Matrix *A, Matrix *B);
94 | void dot(Matrix *A, Matrix *B, Matrix *out);
95 | void Tdot(Matrix *A, Matrix *B, Matrix *out);
96 | void dotT(Matrix *A, Matrix *B, Matrix *out);
97 | Matrix *dotTMPI(Matrix *A, Matrix *B);
98 | Matrix *TdotMPI(Matrix *A, Matrix *B);
99 | Matrix *dotMPI(Matrix *A, Matrix *B);
100 | void dotMPI(Matrix *A, Matrix *B, Matrix *out);
101 | void TdotMPI(Matrix *A, Matrix *B, Matrix *out);
102 | void dotTMPI(Matrix *A, Matrix *B, Matrix *out);
103 |
104 | void add_PCIe(Matrix **A, Matrix **B, Matrix **out);
105 | void mul_PCIe(Matrix **A, Matrix **B, Matrix **out);
106 | void scalarMul_PCIe(Matrix **A, float a, Matrix **out);
107 | void addMatrixVector_PCIe(Matrix **A, Matrix **v, Matrix **out);
108 | void logistic_PCIe(Matrix **A, Matrix **out);
109 |
110 | Matrix *dot_sparse(Matrix *A, Matrix *B);
111 | Matrix *Tdot_sparse(Matrix *A, Matrix *B);
112 | Matrix *dotT_sparse(Matrix *A, Matrix *B);
113 | void dot_sparse(Matrix *A, Matrix *B, Matrix *out);
114 | void dotT_sparse(Matrix *A, Matrix *B, Matrix *out);
115 | void Tdot_sparse(Matrix *A, Matrix *B, Matrix *out);
116 |
117 | Matrix *distribute_rows_hdf5_file(std::string path);
118 | Matrix *distribute_file(std::string path);
119 | float *distribute_float(float number);
120 |
121 | void RMSprop_with_nesterov_weight_update_PCIe(Matrix **RMS, Matrix **grad, Matrix **w, Matrix **m, float RMS_multiplier, float learning_rate, int batch_size, float momentum);
122 |
123 | Matrix *rand(int rows, int cols);
124 | Matrix *rand_same_seed_MPI(int rows, int cols);
125 | void rand(int rows, int cols, bool useSameSeedGenerator, Matrix *out);
126 |
127 |
128 | Matrix *rand_numbers(int rows, int cols);
129 | void rand_numbers(int rows, int cols, Matrix *out);
130 |
131 |
132 | Matrix *randn(int rows, int cols);
133 | Matrix *randn(int rows, int cols, float mean, float std);
134 | void randn(int rows, int cols, float mean, float std, Matrix *out);
135 | Matrix *dropout(Matrix *A, float dropout_rate);
136 | void dropout(Matrix *A, Matrix *out, float dropout_rate);
137 | Matrix *rand_int(int rows, int cols, int low, int high);
138 |
139 | Matrix *compression_8bit(Matrix *A, float precision);
140 | void compression_8bit(Matrix *A, float precision, Matrix *out);
141 | Matrix *decompression_8bit(Matrix *A, float precision);
142 | void decompression_8bit(Matrix *A, float precision, Matrix *out);
143 | Matrix *compression_8bit_test(Matrix *A, float precision);
144 | void compression_8bit_test(Matrix *A, float precision, Matrix *out);
145 |
146 | void dot8bit(Matrix *A, Matrix *B, float precisionA, float precisionB, Matrix *out);
147 | Matrix *dot8bit(Matrix *A, Matrix *B, float precisionA, float precisionB);
148 | void dot8bit_shared(Matrix *A, Matrix *B, float precisionA, float precisionB, Matrix *out);
149 | Matrix *dot8bit_shared(Matrix *A, Matrix *B, float precisionA, float precisionB);
150 |
151 | void tick(std::string name);
152 | void tick();
153 | float tock(std::string name);
154 | float tock();
155 |
156 | void benchmark_dot();
157 | void shutdown_MPI();
158 | Matrix *distributed_uniformSqrtWeight(int rows, int cols);
159 | Matrix *distributed_sparseInitWeight(int rows, int cols);
160 | Matrix *distributed_zeros(int rows, int cols);
161 | Matrix *distributed_ones(int rows, int cols);
162 | Matrix *uniformSqrtWeight(int rows, int cols);
163 | Matrix *uniformSqrtWeight(int rows, int cols, int rows_stacked, int cols_stacked);
164 | Matrix *uniformSqrtWeight_sameSeed(int rows, int cols);
165 | Matrix *sparseInitWeight(int rows, int cols);
166 | Matrix *sparseInitWeight(int rows, int cols, int connections);
167 |
168 | Matrix *dense_to_sparse(Matrix *A);
169 | Matrix *sparse_to_dense(Matrix *A);
170 |
171 | void construct_vocab_matrix(Matrix *vocab_idx, Matrix *vocab_idx_y, Matrix *batch_X, Matrix *batch_y, Matrix *vocab);
172 | void add_to_queue(Matrix **gpuArray);
173 | bool pop_queue();
174 | void add_to_queue_PCIe(Matrix **gpuArray);
175 | bool pop_queue_PCIe();
176 | int get_queue_length();
177 |
178 | void addGradients_PCIe(Matrix **grad);
179 |
180 | Matrix **zeros_PCIe(int rows, int cols);
181 | Matrix **zeros_stacked(int rows, int cols);
182 | Matrix **zeros_gradient_PCIe(int rows, int cols);
183 | Matrix **uniformSqrtWeight_stacked(int rows, int cols);
184 | Matrix **ones_PCIe(int rows, int cols);
185 | Matrix **uniformSqrtWeight_PCIe(int rows, int cols);
186 |
187 | bool QUEUE_EMPTY;
188 |
189 | bool StartBackgroundQueue;
190 | int MYRANK;
191 | int NODES;
192 | int MYGPUID;
193 | int MPI_SIZE;
194 | int GPU_COUNT;
195 | std::vector<int> PCIe_RANKS;
196 | std::vector<int> MASTER_GPU_RANKS;
197 |
198 | int count;
199 |
200 | void *hello(void)
201 | {
202 | bool running = true;
203 | std::cout << "background queue thread started" << std::endl;
204 | while(running)
205 | {
206 | pop_queue_PCIe(); //drain pending PCIe transfers
207 | usleep(100);
208 | }
209 |
210 | return 0;
211 | }
212 |
213 | static void *hello_helper(void *context)
214 | {
215 | return ((ClusterNet *)context)->hello();
216 | }
217 |
218 | private:
219 | std::vector<cublasHandle_t> m_handle;
220 | cusparseHandle_t m_sparse_handle;
221 | curandGenerator_t m_generator;
222 | curandGenerator_t m_generator_same_seed;
223 | std::map m_dictTickTock;
224 | std::map m_dictTickTockCumulative;
225 | MPI_Request* m_requests;
226 | MPI_Request m_sendrequest;
227 | std::vector<MPI_Request> m_sendrequests;
228 | std::map m_matrixCache;
229 | std::map m_matrixCacheChar;
230 | std::map m_matrixHStackCache;
231 | std::map m_matrixCacheUsage;
232 | int m_gpucount;
233 | pthread_t *m_threads;
234 |
235 | Matrix *flt_tbl;
236 | Matrix *seeds;
237 | float *sync_floats;
238 |
239 | bool m_hasMPI;
240 | bool m_cublasInitialized;
241 | bool m_cusparseInitialized;
242 | bool waitingForTransfer;
243 | MPI_Status m_status;
244 | MPI_Comm m_MPIWorld;
245 | MPI_Request *m_request_queue;
246 | int *m_flag_queue;
247 | std::vector<Matrix*> m_send_queue;
248 | std::vector<Matrix*> m_receive_queue;
249 | std::vector<int> m_sendid_queue;
250 | std::vector<int> m_receiveid_queue;
251 | std::vector<cudaStream_t> m_streams_PCIe;
252 |
253 | int m_destination;
254 | int m_source;
255 |
256 | void dot(Matrix *A, Matrix *B, Matrix *out, cublasOperation_t T1, cublasOperation_t T2);
257 | void dot_sparse(Matrix *A, Matrix *B, Matrix *out, cublasOperation_t T1, cublasOperation_t T2);
258 | void dotMPI(Matrix *A, Matrix *B, Matrix *out, bool applyTranspose_A, bool applyTranspose_B);
259 | void init(int seed);
260 | void init_MPI(int argc, char *argv[]);
261 |
262 | void compute_PCIe_ranks();
263 | void compute_GPUID_and_Nodes();
264 |
265 |
266 | };
267 | #endif
268 |
269 |
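270 | // Naming convention (derived from the declarations above and their call sites):
271 | //   dot(A,B)  -> A * B      Tdot(A,B) -> A^T * B      dotT(A,B) -> A * B^T
272 | // The *MPI variants split the same products across all MPI ranks, e.g.:
273 | //
274 | //   ClusterNet gpu(argc, argv, 1234);  //seed is an illustrative value
275 | //   Matrix *C = gpu.dot(A, B);         //single GPU
276 | //   Matrix *D = gpu.dotMPI(A, B);      //all GPUs/nodes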
--------------------------------------------------------------------------------
/source/util.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | using std::string;
17 | using std::vector;
18 | using std::cout;
19 | using std::endl;
20 |
21 | Matrix *read_csv (const char* filename)
22 | {
23 | std::ifstream dStream(filename);
24 | int columns = 0;
25 | int rows = 0;
26 | vector<float> X;
27 |
28 | string line;
29 | while(std::getline(dStream,line))
30 | {
31 | std::stringstream lineStream(line);
32 | string cell;
33 | while(std::getline(lineStream,cell,','))
34 | {
35 | X.push_back(::atof(cell.c_str()));
36 |
37 | if(rows == 0)
38 | columns++;
39 | }
40 | rows++;
41 | }
42 |
43 | float *data;
44 | size_t bytes = columns*rows*sizeof(float);
45 | cudaHostAlloc(&data, bytes, cudaHostAllocPortable);
46 | memcpy(data,&X[0], columns*rows*sizeof(float));
47 |
48 | Matrix *out = (Matrix*)malloc(sizeof(Matrix));
49 | out->rows = rows;
50 | out->cols = columns;
51 | out->bytes = bytes;
52 | out->size = columns*rows;
53 | out->data = data;
54 | out->isDistributed = 0;
55 | out->cols_distributed = 0;
56 | out->isSparse = 0;
57 |
58 | return out;
59 | }
60 |
61 | void write_csv(const char* filename, Matrix *X, const char* header, Matrix *ids)
62 | {
63 | std::ofstream myfile;
64 | myfile.open(filename,std::ios::trunc);
65 | myfile << header << "\r\n";
66 | for(int row = 0; row< X->rows; row++)
67 | {
68 | for(int col = 0; col < X->cols; col++)
69 | {
70 | if(col > 0)
71 | myfile << ",";
72 | else
73 | myfile << (int)ids->data[row] << ",";
74 |
75 | myfile << std::fixed << X->data[(row*X->cols)+col];
76 | }
77 | myfile << "\r\n";
78 | }
79 | myfile.close();
80 | }
81 |
82 | void write_csv(const char* filename, Matrix *X)
83 | {
84 | std::ofstream myfile;
85 | myfile.open(filename,std::ios::trunc);
86 | for(int row = 0; row< X->rows; row++)
87 | {
88 | for(int col = 0; col < X->cols; col++)
89 | {
90 | if(col > 0)
91 | myfile << ",";
92 |
93 | myfile << std::fixed << X->data[(row*X->cols)+col];
94 | }
95 | myfile << "\r\n";
96 | }
97 | myfile.close();
98 | }
99 |
100 | Matrix *read_hdf5(const char *filepath){ return read_hdf5(filepath,"/Default"); }
101 | Matrix *read_hdf5(const char *filepath, const char *tag)
102 | {
103 | hid_t file_id, dataset_id;
104 |
105 | file_id = H5Fopen(filepath, H5F_ACC_RDWR, H5P_DEFAULT);
106 | dataset_id = H5Dopen2(file_id, tag, H5P_DEFAULT);
107 |
108 | hid_t dspace = H5Dget_space(dataset_id);
109 | hsize_t dims[2];
110 | H5Sget_simple_extent_dims(dspace, dims, NULL);
111 | size_t bytes = sizeof(float)*dims[0]*dims[1];
112 |
113 | float *data;
114 | cudaHostAlloc(&data, bytes, cudaHostAllocPortable);
115 |
116 | H5Dread(dataset_id, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT, data);
117 | H5Dclose(dataset_id);
118 | H5Fclose(file_id);
119 |
120 | Matrix *out = (Matrix*)malloc(sizeof(Matrix));
121 | out->rows = (int)dims[0];
122 | out->cols= (int)dims[1];
123 | out->bytes = bytes;
124 | out->data = data;
125 | out->size = (int)(dims[0]*dims[1]);
126 | out->isDistributed = 0;
127 | out->cols_distributed = 0;
128 | out->isSparse = 0;
129 |
130 | return out;
131 | }
132 |
133 | Matrix *read_sparse_hdf5(const char *filepath)
134 | {
135 | hid_t file_id, dataset_id_idx, dataset_id_ptr, dataset_id_data, dataset_id_shape, dspace;
136 | hsize_t dims[2];
137 | size_t bytes;
138 | file_id = H5Fopen(filepath, H5F_ACC_RDWR, H5P_DEFAULT);
139 | Matrix *out = (Matrix*)malloc(sizeof(Matrix));
140 |
141 | dataset_id_idx = H5Dopen2(file_id, "/indices", H5P_DEFAULT);
142 | dspace = H5Dget_space(dataset_id_idx);
143 | H5Sget_simple_extent_dims(dspace, dims, NULL);
144 | bytes = sizeof(int)*dims[0];
145 | int *idx;
146 | cudaHostAlloc(&idx, bytes, cudaHostAllocPortable);
147 | H5Dread(dataset_id_idx, H5T_NATIVE_INT, H5S_ALL, H5S_ALL, H5P_DEFAULT, idx);
148 | H5Dclose(dataset_id_idx);
149 |
150 | out->idx_bytes = sizeof(int)*dims[0];
151 | out->idx_cols = idx;
152 |
153 |
154 | dataset_id_ptr = H5Dopen2(file_id, "/indptr", H5P_DEFAULT);
155 | dspace = H5Dget_space(dataset_id_ptr);
156 | H5Sget_simple_extent_dims(dspace, dims, NULL);
157 | bytes = sizeof(int)*dims[0];
158 | int *ptr;
159 | cudaHostAlloc(&ptr, bytes, cudaHostAllocPortable);
160 | H5Dread(dataset_id_ptr, H5T_NATIVE_INT, H5S_ALL, H5S_ALL, H5P_DEFAULT, ptr);
161 | H5Dclose(dataset_id_ptr);
162 |
163 | out->ptr_bytes = sizeof(int)*dims[0];
164 | out->ptr_rows = ptr;
165 |
166 |
167 | dataset_id_data = H5Dopen2(file_id, "/data", H5P_DEFAULT);
168 | dspace = H5Dget_space(dataset_id_data);
169 | H5Sget_simple_extent_dims(dspace, dims, NULL);
170 | bytes = sizeof(float)*dims[0];
171 | float *data;
172 | cudaHostAlloc(&data, bytes, cudaHostAllocPortable);
173 | H5Dread(dataset_id_data, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT, data);
174 | H5Dclose(dataset_id_data);
175 |
176 | out->bytes = sizeof(float)*dims[0];
177 | out->size = (int)dims[0];
178 |
179 | dataset_id_shape = H5Dopen2(file_id, "/shape", H5P_DEFAULT);
180 | dspace = H5Dget_space(dataset_id_shape);
181 | H5Sget_simple_extent_dims(dspace, dims, NULL);
182 | bytes = sizeof(long)*dims[0];
183 | long shape[2];
184 | H5Dread(dataset_id_shape, H5T_NATIVE_LONG, H5S_ALL, H5S_ALL, H5P_DEFAULT, shape);
185 | H5Dclose(dataset_id_shape);
186 |
187 | H5Fclose(file_id);
188 |
189 |
190 | out->rows = (int)shape[0];
191 | out->cols= (int)shape[1];
192 | out->data = data;
193 | out->isDistributed = 0;
194 | out->isSparse = 1;
195 |
196 |
197 |
198 |
199 | return out;
200 | }
201 |
202 | void write_hdf5(const char * filepath, Matrix *A)
203 | {
204 | hid_t file_id, dataset_id, dataspace_id;
205 | hsize_t dims[2];
206 |
207 | file_id = H5Fcreate(filepath, H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
208 | dims[0] = A->rows;
209 | dims[1] = A->cols;
210 | dataspace_id = H5Screate_simple(2, dims, NULL);
211 | dataset_id = H5Dcreate2(file_id, "/Default", H5T_NATIVE_FLOAT, dataspace_id, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
212 |
213 | H5Dwrite(dataset_id, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT, A->data);
214 | H5Dclose(dataset_id);
215 | H5Fclose(file_id);
216 | }
217 |
218 | cudaEvent_t* tick()
219 | {
220 | cudaEvent_t* startstop;
221 | startstop = (cudaEvent_t*)malloc(2*sizeof(cudaEvent_t));
222 | cudaEventCreate(&startstop[0]);
223 | cudaEventCreate(&startstop[1]);
224 | cudaEventRecord(startstop[0], 0);
225 |
226 | return startstop;
227 | }
228 |
229 | float tock(cudaEvent_t* startstop){ return tock(startstop, "Time for the kernel(s): "); }
230 | float tock(cudaEvent_t* startstop, std::string text)
231 | {
232 | float time;
233 | cudaEventRecord(startstop[1], 0);
234 | cudaEventSynchronize(startstop[1]);
235 | cudaEventElapsedTime(&time, startstop[0], startstop[1]);
236 | //printf((text + ": %f ms.\n").c_str(), time);
237 | return time;
238 | }
239 | float tock(std::string text, float tocks)
240 | {
241 | //printf((text + ": %f ms.\n").c_str(), tocks);
242 | return tocks;
243 | }
244 | float tock(cudaEvent_t* startstop, float tocks)
245 | {
246 | float time;
247 | cudaEventRecord(startstop[1], 0);
248 | cudaEventSynchronize(startstop[1]);
249 | cudaEventElapsedTime(&time, startstop[0], startstop[1]);
250 |
251 | return time+tocks;
252 | }
253 |
254 |
255 |
256 | int test_eq(float f1, float f2, const char* message)
257 | {
258 | if(f1 == f2){ return 1;}
259 | else{ printf("%s: %f != %f\n", message, f1, f2); }
260 | return 0;
261 | }
262 |
263 | int test_eq(float f1, float f2, int idx1, int idx2, const char* message)
264 | {
265 | if(f1 == f2){ return 1;}
266 | else{ printf("%s: %f != %f for index %i and %i.\n", message, f1, f2, idx1, idx2); }
267 | return 0;
268 | }
269 |
270 | int test_eq(int i1, int i2, const char* message)
271 | {
272 | if(i1 == i2){ return 1;}
273 | else{ printf("%s: %i != %i\n", message, i1, i2); }
274 | return 0;
275 | }
276 |
277 | int test_eq(int i1, int i2, int idx1, int idx2, const char* message)
278 | {
279 | if(i1 == i2){ return 1;}
280 | else{ printf("%s: %i != %i for index %i and %i.\n", message, i1, i2, idx1, idx2); }
281 | return 0;
282 | }
283 |
284 | int test_matrix(Matrix *A, int rows, int cols)
285 | {
286 | if((A->rows == rows) &&
287 | (A->cols == cols) &&
288 | (A->size == cols*rows) &&
289 | (A->bytes == cols*rows*sizeof(float)))
290 | {return 1;}
291 | else
292 | {
293 | test_eq(A->rows,rows,"Matrix rows");
294 | test_eq(A->cols,cols,"Matrix cols");
295 | test_eq(A->size,cols*rows,"Matrix size");
296 | test_eq((int)(A->bytes),(int)(cols*rows*sizeof(float)),"Matrix bytes");
297 | }
298 |
299 | return 0;
300 | }
301 |
302 | void print_matrix(Matrix *A, int end_rows, int end_cols)
303 | {
304 | if(A->isSparse != 1)
305 | {
306 | for(int row = 0; row< end_rows; row++)
307 | {
308 | printf("[");
309 | for(int col =0; col < end_cols; col++)
310 | {
311 | if(A->data[(row*A->cols)+col] > 0.0f)
312 | printf("% f ",A->data[(row*A->cols)+col]);
313 | else
314 | printf("%f ",A->data[(row*A->cols)+col]);
315 | }
316 | printf("]\n");
317 | }
318 | printf("\n");
319 | }
320 | else
321 | {
322 | printf("[");
323 | for(int i = end_rows; i < end_cols; i++)
324 | printf("%f ",A->data[i]);
325 |
326 | printf("]\n");
327 | }
328 | }
329 |
330 | void print_matrix(Matrix *A, int start_row, int end_row, int start_col, int end_col)
331 | {
332 | assert(A->isSparse == 0);
333 |
334 | for(int row = start_row; row< end_row; row++)
335 | {
336 | printf("[");
337 | for(int col =start_col; col < end_col; col++)
338 | {
339 | if(A->data[(row*A->cols)+col] > 0.0f)
340 | printf("% f ",A->data[(row*A->cols)+col]);
341 | else
342 | printf("%f ",A->data[(row*A->cols)+col]);
343 | }
344 | printf("]\n");
345 | }
346 | printf("\n");
347 |
348 | }
349 |
350 | void printmat(Matrix *A)
351 | {
352 | Matrix * m = to_host(A);
353 | if(A->isSparse == 0)
354 | print_matrix(m,A->rows,A->cols);
355 | else
356 | print_matrix(m,0,A->size);
357 | free(m->data);
358 | free(m);
359 |
360 | }
361 |
362 | void printdim(Matrix *A)
363 | {
364 | cout << A->rows << "x" << A->cols << endl;
365 | }
366 |
367 | void printsum(Matrix *A)
368 | {
369 | cout << sum(A) << endl;
370 | }
371 |
372 | void printhostmat(Matrix *A)
373 | {
374 | if(A->isSparse == 0)
375 | print_matrix(A,A->rows,A->cols);
376 | else
377 | print_matrix(A,0,A->size);
378 | }
379 |
380 | void printmat(Matrix *A, int end_rows, int end_cols)
381 | {
382 | Matrix * m = to_host(A);
383 | print_matrix(m, end_rows, end_cols);
384 | free(m->data);
385 | free(m);
386 |
387 | }
388 |
389 | void printmat(Matrix *A, int start_row, int end_row, int start_col, int end_col)
390 | {
391 | Matrix * m = to_host(A);
392 | print_matrix(m, start_row, end_row, start_col, end_col);
393 | free(m->data);
394 | free(m);
395 |
396 | }
397 |
398 | bool replace(std::string& str, const std::string& from, const std::string& to)
399 | {
400 | size_t start_pos = str.find(from);
401 | if(start_pos == std::string::npos)
402 | return false;
403 | str.replace(start_pos, from.length(), to);
404 | return true;
405 | }
406 |
407 | void slice_sparse_to_dense(Matrix *X, Matrix *out, int start, int length)
408 | {
409 | int idx_from = 0;
410 | int idx_to = 0;
411 | int idx = 0;
412 |
413 | for(int i = 0; i < out->size; i++)
414 | out->data[i] = 0.0f;
415 |
416 | for(int row = 0; row < length; row++)
417 | {
418 | idx_from = X->ptr_rows[start + row];
419 | idx_to = X->ptr_rows[start + row + 1];
420 |
421 | for(int i = idx_from; i < idx_to; i++)
422 | {
423 | idx = X->idx_cols[i];
424 | out->data[(row*out->cols) + idx] = X->data[i];
425 | }
426 | }
427 |
428 |
429 |
430 | }
431 |
432 | float determine_max_sparsity(Matrix *X, int batch_size)
433 | {
434 |
435 | float max_sparsity = 0.0;
436 |
437 | Matrix *dense_batch = empty_cpu(batch_size,X->cols);
438 | int batches = (X->rows / batch_size);
439 | float batch_elements = batch_size*X->cols;
440 |
441 | float nonzero_count = 0.0f;
442 | for(int i = 0; i < batches; i++)
443 | {
444 | nonzero_count = (X->ptr_rows[(i+1)*batch_size] - X->ptr_rows[i*batch_size]);
445 |
446 | if(max_sparsity < (nonzero_count / batch_elements))
447 | max_sparsity = (nonzero_count / batch_elements);
448 |
449 | nonzero_count = 0.0f;
450 | }
451 |
452 | return max_sparsity;
453 |
454 | }
455 |
456 | #define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
457 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
458 | {
459 | if (code != cudaSuccess)
460 | {
461 | fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
462 | if (abort) exit(code);
463 | }
464 | }
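
The sparse `Matrix` returned by `read_sparse_hdf5` and consumed by `slice_sparse_to_dense` above uses the standard CSR layout: `ptr_rows[r]` and `ptr_rows[r+1]` delimit the nonzeros of row r, and `idx_cols[i]`/`data[i]` hold the column and value of the i-th stored nonzero. A minimal host-only sketch of that walk (the 2x3 matrix is hypothetical, chosen only to make the indices concrete):

    // CSR encoding of [[1 0 2],
    //                  [0 3 0]] -- illustrative only, not library code.
    int   ptr_rows[] = {0, 2, 3};           // row r owns nonzeros ptr_rows[r]..ptr_rows[r+1]-1
    int   idx_cols[] = {0, 2, 1};           // column index of each stored nonzero
    float vals[]     = {1.0f, 2.0f, 3.0f};  // value of each stored nonzero

    float dense[2*3] = {0.0f};
    for(int row = 0; row < 2; row++)
        for(int i = ptr_rows[row]; i < ptr_rows[row+1]; i++)
            dense[(row*3) + idx_cols[i]] = vals[i]; // same loop body as slice_sparse_to_dense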
465 |
466 |
467 |
468 |
469 |
--------------------------------------------------------------------------------
/source/util.cuh:
--------------------------------------------------------------------------------
1 | #ifndef util_H
2 | #define util_H
3 | #include <sstream>
4 | #include <basicOps.cuh>
5 |
6 | #define ASSERT(condition, message) \
7 | do { \
8 | if (! (condition)) { \
9 | std::cerr << "Assertion `" #condition "` failed in " << __FILE__ \
10 | << " line " << __LINE__ << ": " << message << std::endl; \
11 | std::exit(EXIT_FAILURE); \
12 | } \
13 | } while (false)
14 |
15 |
16 | Matrix *read_csv(const char* filename);
17 | void write_csv(const char* filename, Matrix *X, const char* header, Matrix *ids);
18 | void write_csv(const char* filename, Matrix *X);
19 | Matrix *read_hdf5(const char * filepath);
20 | Matrix *read_sparse_hdf5(const char * filepath);
21 | Matrix *read_hdf5(const char *filepath, const char *tag);
22 | void write_hdf5(const char * filepath, Matrix *A);
23 |
24 | cudaEvent_t* tick();
25 | float tock(cudaEvent_t* startstop);
26 | float tock(cudaEvent_t* startstop, std::string text);
27 | float tock(std::string text, float tocks);
28 | float tock(cudaEvent_t* startstop, float tocks);
29 | int test_eq(float f1, float f2, const char* message);
30 | int test_eq(float f1, float f2, int idx1, int idx2, const char* message);
31 | int test_eq(int i1, int i2, const char* message);
32 | int test_eq(int i1, int i2, int idx1, int idx2, const char* message);
33 | int test_matrix(Matrix *A, int rows, int cols);
34 | void printmat(Matrix *A);
35 | void printhostmat(Matrix *A);
36 | void printdim(Matrix *A);
37 | void printsum(Matrix *A);
38 | void printmat(Matrix *A, int end_rows, int end_cols);
39 | void printmat(Matrix *A, int start_row, int end_row, int start_col, int end_col);
40 | void print_matrix(Matrix *A, int end_rows, int end_cols);
41 | void print_matrix(Matrix *A, int start_row, int end_row, int start_col, int end_col);
42 | bool replace(std::string& str, const std::string& from, const std::string& to);
43 | void slice_sparse_to_dense(Matrix *X, Matrix *out, int start, int length);
44 | float determine_max_sparsity(Matrix *X, int batch_size);
45 | template <typename T> std::string NumberToString (T Number)
46 | {
47 | std::ostringstream ss;
48 | ss << Number;
49 | return ss.str();
50 | }
51 | #endif
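
A short usage sketch for the helpers declared above (a minimal example; the 512x512 matrix and the use of `ones` from basicOps.cuh are assumptions for illustration):

    cudaEvent_t *timer = tick();                 // start the CUDA event timer
    Matrix *A = ones(512, 512);                  // some GPU work to time (hypothetical)
    float ms = tock(timer);                      // elapsed milliseconds
    ASSERT(A->rows == A->cols, "expected a square matrix, got "
           + NumberToString(A->rows) + "x" + NumberToString(A->cols));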
52 |
53 |
54 |
--------------------------------------------------------------------------------
/tests/basicOps_test.cu:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <assert.h>
3 | #include <clusterNet.h>
4 | #include <basicOps.cuh>
5 | #include <util.cuh>
6 |
7 | using std::cout;
8 | using std::endl;
9 |
10 | int run_basicOps_test(ClusterNet gpus)
11 | {
12 |
13 | ClusterNet gpu = ClusterNet();
14 |
15 | Matrix *m1 = ones(5,6);
16 | Matrix *m2 = ones(5,6);
17 | Matrix *m3 = zeros(5,6);
18 | Matrix *out = zeros(5,6);
19 |
20 | //to_col_major test
21 | // 0 2 3
22 | // m1 = 0 0.83 59.1387
23 | //
24 | float m1_data[6] = {0,2,3,0,0.83,59.1387};
25 | size_t m1_bytes = 2*3*sizeof(float);
26 | Matrix *m1_cpu = (Matrix*)malloc(sizeof(Matrix));
27 | m1_cpu->rows = 2;
28 | m1_cpu->cols = 3;
29 | m1_cpu->bytes = m1_bytes;
30 | m1_cpu->size = 6;
31 | m1_cpu->data = m1_data;
32 |
33 | m1 = to_gpu(m1_cpu,1);
34 | //to_col_major test
35 | m1 = to_col_major(m1);
36 | float *test;
37 | test = (float*)malloc(m1->bytes);
38 | cudaMemcpy(test,m1->data,m1->bytes,cudaMemcpyDefault);
39 |
40 | assert(test_eq(test[0], 0.0f,"To col major data."));
41 | assert(test_eq(test[1], 0.0f,"To col major data."));
42 | assert(test_eq(test[2], 2.0f,"To col major data."));
43 | assert(test_eq(test[3], 0.83f,"To col major data."));
44 | assert(test_eq(test[4], 3.0f,"To col major data."));
45 | assert(test_eq(test[5], 59.1387f,"To col major data."));
46 |
47 |
48 |
49 | m1 = to_row_major(m1);
50 | cudaMemcpy(test,m1->data,m1->bytes,cudaMemcpyDefault);
51 |
52 | assert(test_eq(test[0], 0.0f,"To row major data."));
53 | assert(test_eq(test[1], 2.0f,"To row major data."));
54 | assert(test_eq(test[2], 3.0f,"To row major data."));
55 | assert(test_eq(test[3], 0.0f,"To row major data."));
56 | assert(test_eq(test[4], 0.83f,"To row major data."));
57 | assert(test_eq(test[5], 59.1387f,"To row major data."));
58 |
59 | assert(test_eq(getNonZeroElements(m1),4 ,"Get non-zero elements."));
60 |
61 |
62 | //test to_host
63 | //data is converted to column major and then back to row major
64 | Matrix *m_host = to_host(to_gpu(m1_cpu));
65 | assert(m_host->rows==m1->rows);
66 | assert(m_host->cols==m1->cols);
67 | assert(m_host->size==m1->size);
68 | assert(m_host->bytes==m1->bytes);
69 | for(int i = 0; i< 5; i++)
70 | {
71 | assert(m_host->data[i]==m1_cpu->data[i]);
72 | }
73 |
74 |
75 | //test fill_with
76 | m1 = ones(5,6);
77 | m_host = to_host(m1);
78 | for(int i = 0; i< 30; i++)
79 | {
80 | assert(m_host->data[i]==1.0f);
81 | }
82 |
83 | //test add
84 | m3 = add(m1,m2);
85 | m_host = to_host(m3);
86 | for(int i = 0; i< 30; i++)
87 | {
88 | assert(m_host->data[i]==2.0f);
89 | }
90 |
91 | //test to_gpu
92 | m_host = to_host(add(to_gpu(m_host),to_gpu(m_host)));
93 | for(int i = 0; i< 30; i++)
94 | {
95 | assert(test_eq(m_host->data[i],4.0f,"To gpu data"));
96 | }
97 |
98 | //test mul
99 | m3 = mul(m3,m3);
100 | m_host = to_host(m3);
101 | for(int i = 0; i< 30; i++)
102 | {
103 | assert(test_eq(m_host->data[i],4.0f,"Multiplication data"));
104 | }
105 |
106 | //test sub
107 | m3 = sub(m3,m1);
108 | m_host = to_host(m3);
109 | for(int i = 0; i< 30; i++)
110 | {
111 | assert(m_host->data[i]==3.0f);
112 | }
113 |
114 | //test div
115 | m2 = add(m1,m2); //2
116 | m3 = div(m3,m2);
117 | m_host = to_host(m3);
118 | for(int i = 0; i< 30; i++)
119 | {
120 | assert(m_host->data[i]==1.5f);
121 | }
122 |
123 | //test add with given output Matrix *
124 | add(m3,m2,out);
125 | m_host = to_host(out);
126 | for(int i = 0; i< 30; i++)
127 | {
128 | assert(m_host->data[i]==3.5f);
129 | }
130 |
131 | //test sub with given output Matrix *
132 | sub(m3,m2,out);
133 | m_host = to_host(out);
134 | for(int i = 0; i< 30; i++)
135 | {
136 | assert(m_host->data[i]==-0.5f);
137 | }
138 |
139 | //test mul with given output Matrix *
140 | mul(m3,m2,out);
141 | m_host = to_host(out);
142 | for(int i = 0; i< 30; i++)
143 | {
144 | assert(m_host->data[i]==3.0f);
145 | }
146 |
147 | //test div with given output Matrix *
148 | div(m3,m2,out);
149 | m_host = to_host(out);
150 | for(int i = 0; i< 30; i++)
151 | {
152 | assert(m_host->data[i]==0.75f);
153 | }
154 |
155 | //test exp
156 | m_host = to_host(gpuExp(zeros(5,6)));
157 | for(int i = 0; i< 30; i++)
158 | {
159 | assert(m_host->data[i]==1.0f);
160 | }
161 |
162 | //test scalar mul
163 | m_host = to_host(scalarMul(ones(5,6),1.83));
164 | for(int i = 0; i< 30; i++)
165 | {
166 | assert(m_host->data[i]==1.83f);
167 | }
168 |
169 | //test sqrt
170 | m_host = to_host(gpuSqrt(scalarMul(ones(5,6),4)));
171 | for(int i = 0; i< 30; i++)
172 | {
173 | assert(m_host->data[i]==2.0f);
174 | }
175 |
176 | //test log
177 | m_host = to_host(gpuLog(scalarMul(ones(5,6),2.0)));
178 | for(int i = 0; i< 30; i++)
179 | {
180 | assert(m_host->data[i]==log(2.0f));
181 | }
182 |
183 | //test square
184 | m_host = to_host(square(scalarMul(ones(5,6),2)));
185 | for(int i = 0; i< 30; i++)
186 | {
187 | assert(m_host->data[i]==4.0f);
188 | }
189 |
190 | //test blnFaultySizes
191 | assert(blnFaultySizes(ones(1,3),ones(2,3),ones(2,3))==1);
192 | assert(blnFaultySizes(ones(1,3),ones(1,3),ones(2,3))==1);
193 | assert(blnFaultySizes(ones(1,3),ones(1,3),ones(1,3))==0);
194 | assert(blnFaultySizes(ones(3,3),ones(3,3),ones(3,3))==0);
195 | //test blnFaultyMatrixSizes
196 | assert(blnFaultyMatrixProductSizes(ones(1,3),ones(1,3),ones(3,3),CUBLAS_OP_N,CUBLAS_OP_N)==1);
197 | assert(blnFaultyMatrixProductSizes(ones(3,1),ones(1,3),ones(2,2),CUBLAS_OP_N,CUBLAS_OP_N)==1);
198 | assert(blnFaultyMatrixProductSizes(ones(3,1),ones(1,3),ones(3,3),CUBLAS_OP_N,CUBLAS_OP_N)==0);
199 |
200 | //transpose test
201 | //column major order
202 | // 0 2 3
203 | // m1 = 0 0.83 59.1387
204 | //
205 | //test to_gpu with is_col_major = 1
206 | m_host = to_host(T(to_gpu(m1_cpu)));
207 | assert(test_eq(m_host->data[0],0.0f,"Transpose data."));
208 | assert(m_host->data[1]==0.0f);
209 | assert(m_host->data[2]==2.0f);
210 | assert(m_host->data[3]==0.83f);
211 | assert(m_host->data[4]==3.0f);
212 | assert(m_host->data[5]==59.1387f);
213 | assert(test_matrix(m_host,3,2));
214 |
215 | //to host and to gpu test
216 | // 0 2 3
217 | // m1 = 0 0.83 59.1387
218 | //
219 | //to gpu and to host should cancel each other out
220 | m_host = to_host(to_gpu(m1_cpu));
221 | assert(m_host->data[0]==0.0f);
222 | assert(m_host->data[1]==2.0f);
223 | assert(m_host->data[2]==3.0f);
224 | assert(m_host->data[3]==0.0f);
225 | assert(m_host->data[4]==0.83f);
226 | assert(m_host->data[5]==59.1387f);
227 | assert(test_matrix(m_host,2,3));
228 |
229 | //to_gpu for col major data test
230 | //col major data
231 | float m2_data[6] = {0,0,2.0f,0.83,3,59.1387};
232 | size_t m2_bytes = 2*3*sizeof(float);
233 | Matrix *m2_cpu = (Matrix*)malloc(sizeof(Matrix));
234 | m2_cpu->rows = 2;
235 | m2_cpu->cols = 3;
236 | m2_cpu->bytes = m2_bytes;
237 | m2_cpu->size = 6;
238 | m2_cpu->data = m2_data;
239 | m_host = to_host(to_gpu(m2_cpu,1));
240 | //should be in row major now
241 | assert(m_host->data[0]==0.0f);
242 | assert(m_host->data[1]==2.0f);
243 | assert(m_host->data[2]==3.0f);
244 | assert(m_host->data[3]==0.0f);
245 | assert(m_host->data[4]==0.83f);
246 | assert(m_host->data[5]==59.1387f);
247 | assert(test_matrix(m_host,2,3));
248 |
249 | //slice rows
250 | m1 = gpu.rand(10,10);
251 | m2 = to_host(slice_rows(m1, 2,5));
252 | m1 = to_host(m1);
253 | assert(test_matrix(m2,4,10));
254 | int idx = 0;
255 | for(int i = 20; i < 60; i++)
256 | {
257 | assert(test_eq(m1->data[i], m2->data[idx], idx, i , "Row slice data"));
258 | idx++;
259 | }
260 |
261 | //slice cols
262 | m1 = gpu.rand(10,10);
263 | m2 = to_host(slice_cols(m1, 2,5));
264 | m1 = to_host(m1);
265 | idx = 0;
266 | assert(test_matrix(m2,10,4));
267 |
268 |
269 | for(int i = 2; i < 100;i++)
270 | {
271 | if(((i % 10) < 6) &&
272 | ((i % 10) > 1))
273 | {
274 | assert(test_eq(m1->data[i], m2->data[idx], idx, i , "Col slice data"));
275 | idx++;
276 | }
277 | }
278 |
279 | //softmax test
280 | m1 = softmax(ones(2056,10));
281 | m_host = to_host(m1);
282 | assert(test_matrix(m_host,2056,10));
283 | for(int i = 0; i < m_host->size; i++)
284 | {
285 | assert(test_eq(m_host->data[i],0.1,"Softmax equal test"));
286 | }
287 |
288 | m1 = softmax(gpu.rand(2222,17));
289 | m_host = to_host(m1);
290 | assert(test_matrix(m_host,2222,17));
291 | float sum_value = 0;
292 | for(int i = 0; i < m_host->size; i++)
293 | {
294 | sum_value += m_host->data[i];
295 | if((i > 0) && (((i+1) % 17) == 0))
296 | {
297 | ASSERT((sum_value > 0.99) && (sum_value < 1.01), "Softmax row sum equal one");
298 | sum_value = 0.0f;
299 | }
300 | }
301 |
302 |
303 | m1 = zeros(10,10);
304 | m2 = ones(10,1);
305 | //sub matrix vector test: A - v
306 | m_host= to_host(subMatrixVector(m1,m2));
307 | assert(test_matrix(m_host,10,10));
308 | for(int i = 0; i < m_host->size; i++)
309 | {
310 | assert(test_eq(m_host->data[i],-1.0f, "Matrix - vector, equal data test"));
311 | }
312 | m3 = gpu.rand(13,17);
313 | Matrix *m4 = gpu.rand(1,17);
314 | m_host = to_host(addMatrixVector(m3,m4));
315 | m3 = to_host(m3);
316 | m4 = to_host(m4);
317 | assert(test_matrix(m_host,13,17));
318 | for(int row = 0; row < m_host->rows; row++)
319 | {
320 | for(int col = 0; col < m_host->cols; col++)
321 | assert(test_eq(m_host->data[(row*m_host->cols) + col], m3->data[(row*m_host->cols) + col] + m4->data[col], "Matrix + vector, equal data test"));
322 | }
323 |
324 | // 0 2 3
325 | // m1 = 0 0.83 59.1387
326 | //
327 | //argmax test
328 | //col_value = A[(i*cols) + idx];
329 | m1 = argmax(to_gpu(m1_cpu));
330 | m_host = to_host(m1);
331 | assert(test_matrix(m_host,2,1));
332 | assert(test_eq(m_host->data[0],2.0f, "Argmax test"));
333 | assert(test_eq(m_host->data[1],2.0f, "Argmax test"));
334 | m1 = gpu.rand(2056,10);
335 | m_host = to_host(argmax(m1));
336 | int counts[10] = {0,0,0,0,0,
337 | 0,0,0,0,0};
338 | assert(test_matrix(m_host,2056,1));
339 | for(int i = 0; i < m_host->size; i++)
340 | {
341 | counts[(int)m_host->data[i]]++;
342 | }
343 | for(int i = 0; i < 10; i++)
344 | {
345 | //expectation is 205.6 each;
346 | ASSERT((counts[i] > 140) && (counts[i] < 280), "Argmax value test");
347 | }
348 |
349 | //create t matrix test
350 | m1 = scalarMul(ones(10,1),4);
351 | m1 = create_t_matrix(m1,7);
352 | m_host = to_host(m1);
353 | assert(test_matrix(m_host,10,7));
354 | for(int i = 0; i < m_host->size; i++)
355 | {
356 | if((i % m1->cols) == 4)
357 | {
358 | assert(test_eq(m_host->data[i],1.0f, "Create t matrix data"));
359 | }
360 | else
361 | {
362 | assert(test_eq(m_host->data[i],0.0f, "Create t matrix data"));
363 | }
364 | }
365 |
366 | //equal test
367 | gpu = ClusterNet(12345);
368 | ClusterNet gpu2 = ClusterNet(12345);
369 | m2 = gpu.rand(10,7);
370 | m1 = gpu2.rand(10,7);
371 | m_host = to_host(equal(m1,m2));
372 | assert(test_matrix(m_host,10,7));
373 | for(int i = 0; i < m_host->size; i++)
374 | {
375 | assert(test_eq(m_host->data[i],1.0f, "Matrix matrix Equal data test"));
376 | }
377 | m1 = gpu2.rand(10,7);
378 | m_host = to_host(equal(m1,m2));
379 | assert(test_matrix(m_host,10,7));
380 | for(int i = 0; i < m_host->size; i++)
381 | {
382 | assert(test_eq(m_host->data[i],0.0f, "Matrix matrix Equal data test"));
383 | }
384 |
385 |
386 | //test sum
387 | m1 = ones(10,1);
388 | m2 = ones(1,10);
389 |
390 | ASSERT(sum(m1) == 10.0f, "Vector sum test");
391 | ASSERT(sum(m2) == 10.0f, "Vector sum test");
392 | m1 = ones(10,10);
393 | ASSERT(sum(m1) == 100.0f, "Vector sum test");
394 | ASSERT(sum(scalarMul(m2,1.73)) > 17.29f, "Vector sum test");
395 | ASSERT(sum(scalarMul(m2,1.73)) < 17.31f, "Vector sum test");
396 |
397 | //logistic test
398 | m1 = zeros(2,2);
399 | m1 = to_host(logistic(m1));
400 | assert(test_matrix(m1,2,2));
401 | for(int i = 0; i < m1->size; i++)
402 | {
403 | ASSERT(m1->data[i] == 0.5f,"Logistic data test.");
404 | }
405 | m1 = gpu.randn(100,100);
406 | m1 = to_host(logistic(m1));
407 | assert(test_matrix(m1,100,100));
408 | for(int i = 0; i < m1->size; i++)
409 | {
410 | ASSERT((m1->data[i] > 0.0f) && (m1->data[i] < 1.0f),"Logistic data test.");
411 | }
412 |
413 | //logistic grad test
414 | m1 = ones(2,2);
415 | m1 = to_host(logisticGrad(m1));
416 | assert(test_matrix(m1,2,2));
417 | for(int i = 0; i < m1->size; i++)
418 | {
419 | ASSERT(m1->data[i] == 0.0f,"Logistic data test.");
420 | }
421 | m1 = gpu.randn(100,100);
422 | m_host = to_host(m1);
423 | m1 = to_host(logisticGrad(m1));
424 | assert(test_matrix(m1,100,100));
425 | for(int i = 0; i < m1->size; i++)
426 | {
427 | ASSERT(m_host->data[i]*(1-m_host->data[i]) == m1->data[i],"Logistic data test.");
428 | }
429 |
430 | //arange test
431 | m1 = arange(10,7);
432 | m_host = to_host(m1);
433 | assert(test_matrix(m_host,10,7));
434 | for(int i = 0; i < m1->size; i++)
435 | {
436 | assert(test_eq(m_host->data[i],(float)i, "Arange data test."));
437 | }
438 |
439 | m1 = arange(101,10,7);
440 | m_host = to_host(m1);
441 | assert(test_matrix(m_host,10,7));
442 | for(int i = 0; i < m1->size; i++)
443 | {
444 | assert(test_eq(m_host->data[i],(float)(i + 101), "Arange data test."));
445 | }
446 |
447 | //cutoff to probability test
448 | m_host = to_host(doubleRectifiedLinear(gpu.randn(123,357,0,10)));
449 | assert(test_matrix(m_host,123,357));
450 | for(int i = 0; i < m_host->size; i++)
451 | ASSERT((m_host->data[i] <=1.0f) && (m_host->data[i] >=0.0f),"cutoff to probability test.");
452 |
453 |
454 | m1 = empty_sparse(17,83,0.01783,0.0);
455 | int elements = ceil(17*83*0.01783) + 1;
456 | ASSERT(m1->rows == 17, "empty sparse rows");
457 | ASSERT(m1->cols == 83, "empty sparse cols");
458 | ASSERT(m1->size == elements, "empty sparse size");
459 | ASSERT(m1->isSparse == 1, "empty sparse");
460 | ASSERT(m1->idx_bytes == sizeof(float)*elements, "empty sparse bytes");
461 | ASSERT(m1->bytes == sizeof(float)*elements, "empty sparse bytes");
462 | ASSERT(m1->ptr_bytes == sizeof(float)*(m1->rows + 1), "empty sparse bytes");
463 |
464 | m1 = empty_sparse(17,83,500);
465 | elements = 500;
466 | ASSERT(m1->rows == 17, "empty sparse rows");
467 | ASSERT(m1->cols == 83, "empty sparse cols");
468 | ASSERT(m1->size == elements, "empty sparse size");
469 | ASSERT(m1->isSparse == 1, "empty sparse");
470 | ASSERT(m1->idx_bytes == sizeof(float)*elements, "empty sparse bytes");
471 | ASSERT(m1->bytes == sizeof(float)*elements, "empty sparse bytes");
472 | ASSERT(m1->ptr_bytes == sizeof(float)*(m1->rows + 1), "empty sparse bytes");
473 |
474 | m1 = empty_pinned_sparse(171,837,0.01783,0.001110);
475 | elements = ceil(171*837*(0.01783+0.001110)) + 1;
476 | ASSERT(m1->rows == 171, "empty sparse rows");
477 | ASSERT(m1->cols == 837, "empty sparse cols");
478 | ASSERT(m1->size == elements, "empty sparse size");
479 | ASSERT(m1->isSparse == 1, "empty sparse");
480 | ASSERT(m1->idx_bytes == sizeof(float)*elements, "empty sparse bytes");
481 | ASSERT(m1->bytes == sizeof(float)*elements, "empty sparse bytes");
482 | ASSERT(m1->ptr_bytes == sizeof(float)*(m1->rows + 1), "empty sparse bytes");
483 |
484 | for(int i = 0; i < m1->size; i++)
485 | {
486 | ASSERT(m1->data[i] == 0.0f,"empty sparse data");
487 | ASSERT(m1->idx_cols[i] == 0.0f,"empty sparse data");
488 | }
489 |
490 | //fill_gpuarray test
491 | m1 = empty_sparse(10,10,10);
492 | fill_gpuarray(m1->ptr_rows,3,m1->rows+1);
493 | m1 = to_host(m1);
494 | for(int i = 0; i < m1->size; i++)
495 | assert(test_eq(m1->ptr_rows[i], 3,"fill_gpuarray test"));
496 |
497 | //sparse sub test
498 | m1 = ones(100,100);
499 | m2 = ones(100,100);
500 | m1 = gpu.dropout(m1,0.5);
501 | Matrix *s1 = gpu.dense_to_sparse(m1);
502 | m3 = sub(m2,s1);
503 | m4 = sub(m2,m1);
504 | m_host = to_host(m3);
505 | m4 = to_host(m4);
506 | int count = 0;
507 | for(int i = 0; i < m_host->size; i++)
508 | {
509 | ASSERT(m_host->data[i] == 1.0f || m_host->data[i] == 0.0f, "sub sparse test");
510 | assert(test_eq(m_host->data[i], m4->data[i], "sub sparse test"));
511 | if(m_host->data[i] == 0.0f)
512 | count++;
513 | }
514 | ASSERT(count > 4500 && count < 5500, "sub sparse test");
515 |
516 |
517 | m1 = gpu.rand(100,100);
518 | m2 = gpu.rand(100,100);
519 | m1 = gpu.dropout(m1,0.5);
520 | s1 = gpu.dense_to_sparse(m1);
521 | m3 = sub(m2,s1);
522 | m4 = sub(m2,m1);
523 | m_host = to_host(m3);
524 | m4 = to_host(m4);
525 | for(int i = 0; i < m_host->size; i++)
526 | assert(test_eq(m_host->data[i], m4->data[i], "sub sparse test"));
527 |
528 |
529 | //hard tanh test
530 | m1 = to_host(hardTanH(gpu.randn(137,457)));
531 | for(int i = 0; i < m1->size; i++)
532 | ASSERT(m1->data[i] >= -1.0 && m1->data[i] <= 1.0, "hardTanH test");
533 | assert(test_matrix(m1,137,457));
534 |
535 | //hard tanh derivative
536 | m3 = gpu.randn(137,400);
537 | m2 = to_host(m3);
538 | m1 = to_host(hardTanH_derivative(m3));
539 |
540 | assert(test_matrix(m1,137,400));
541 | for(int i = 0; i < m3->size; i++)
542 | ASSERT(((m2->data[i] < -1.0 && m1->data[i] == 0.0) ||
543 | (m2->data[i] > 1.0 && m1->data[i] == 0.0) ||
544 | (m2->data[i] >= -1.0 && m2->data[i] <= 1.0 && m1->data[i] == 1.0)), "hardTanH_derivative test");
545 |
546 | //pairwise ranking tests
547 | m3 = gpu.randn(137,450);
548 | m2 = gpu.randn(137,450);
549 | m1 = to_host(pairwise_ranking(m2,m3));
550 | m4 = to_host(pairwise_ranking_derivative(m2, m3));
551 | m3 = to_host(m3);
552 | m2 = to_host(m2);
553 | assert(test_matrix(m1,137,450));
554 | assert(test_matrix(m4,137,450));
555 | for(int i = 0; i < m1->size; i++)
556 | {
557 | ASSERT(m1->data[i] == ((1.0f - m2->data[i] + m3->data[i]) < 0.0f ? 0.0f : (1.0f - m2->data[i] + m3->data[i])), "pairwise ranking test");
558 | ASSERT(m4->data[i] == ((1.0f - m2->data[i] + m3->data[i]) > 0.0f ? 1.0f : 0.0f), "pairwise ranking test derivative");
559 | }
560 |
561 | //col max test
562 | m1 = gpu.rand(53,57);
563 | m2 = maxColumnwise(m1);
564 | m1 = to_host(m1);
565 | assert(test_matrix(m2,57,1));
566 | m2 = to_host(m2);
567 | float max_value = -2.0f;
568 | for(int col = 0; col < m1->cols; col++)
569 | {
570 | max_value = -2.0f;
571 | for(int row = 0; row < m1->rows; row++)
572 | if(m1->data[(row*m1->cols) + col] > max_value)
573 | max_value = m1->data[(row*m1->cols) + col];
574 |
575 | assert(test_eq(m2->data[col],max_value,"testing max col value"));
576 |
577 | }
578 |
579 |
580 | //maxout test
581 | m1 = gpu.rand(128,1736);
582 | int maxout_level = 8;
583 | m2 = maxout(m1,maxout_level)[0];
584 | m3 = maxout(m1,maxout_level)[1];
585 | m1 = to_host(m1);
586 | assert(test_matrix(m2,128,1736/8));
587 | assert(test_matrix(m3,128,1736/8));
588 | m2 = to_host(m2);
589 | m3 = to_host(m3);
590 | max_value = -2.0f;
591 | float max_col_value = 0.0f;
592 | for(int row = 0; row < m1->rows; row++)
593 | {
594 | max_value = -2.0f;
595 | for(int col = 0; col < m1->cols; col++)
596 | {
597 | if(m1->data[(row*m1->cols) + col] > max_value)
598 | {
599 | max_value = m1->data[(m1->cols*row) + col];
600 | max_col_value = (float)col;
601 | }
602 |
603 | if((col+1) % maxout_level == 0)
604 | {
605 | assert(test_eq(m2->data[(row*(m1->cols/maxout_level))+(col/maxout_level)],max_value,"testing maxout value"));
606 | assert(test_eq(m3->data[(row*(m1->cols/maxout_level))+(col/maxout_level)],max_col_value,"testing maxout index"));
607 | max_value = -2.0f;
608 | }
609 | }
610 | }
611 |
612 | //expand maxout grad test
613 | Matrix *grad = gpu.rand(2,8);
614 | Matrix *error = gpu.rand(2,4);
615 | m1 = gpu.rand(2,8);
616 | maxout_level = 2;
617 | m2 = maxout(m1,maxout_level)[0];
618 | m3 = maxout(m1,maxout_level)[1];
619 | expand_to_maxout_grad(error,m3,grad);
620 | grad = to_host(grad);
621 | m3 = to_host(m3);
622 | error = to_host(error);
623 | int maxout_block = 0;
624 | float value = 0.0f;
625 | for(int row = 0; row < grad->rows; row++)
626 | {
627 | for(int col = 0; col < grad->cols; col++)
628 | {
629 | value = grad->data[(row*grad->cols) + col];
630 | if(value != 0.0f)
631 | {
632 | assert(test_eq((int)m3->data[(row*m3->cols)+maxout_block],col,"test idx grad for expand maxout"));
633 | assert(test_eq(error->data[(row*error->cols)+maxout_block],value,"test value grad for expand maxout"));
634 | maxout_block++;
635 | }
636 | }
637 | maxout_block = 0;
638 | }
639 |
640 |
641 |
642 |
643 | //update vocab grad test
644 | int vocab_vector_size = 360;
645 | int batch_size = 127;
646 | int window_size = 21;
647 | int vocab_size = 73;
648 | grad = gpu.rand(batch_size,window_size*vocab_vector_size);
649 | Matrix *vocab_idx = gpu.rand_int(batch_size,window_size,0,vocab_size-1);
650 | Matrix *vocab = zeros(vocab_vector_size,vocab_size);
651 | m2 = to_host(vocab);
652 | expand_vocab_gradient(grad,vocab_idx,vocab);
653 |
654 | m1 = to_host(vocab_idx);
655 | m4 = to_host(vocab);
656 | m3 = to_host(grad);
657 |
658 | idx = 0;
659 | for(int row = 0; row < vocab_idx->rows; row++)
660 | for(int col = 0; col < vocab_idx->cols; col++)
661 | {
662 | idx = (int)m1->data[col + (row*m1->cols)];
663 | for(int i = 0; i < vocab_vector_size; i++)
664 | m2->data[idx + (vocab->cols*i)] += m3->data[(col*vocab_vector_size) + (row*grad->cols) + i];
665 |
666 |
667 | }
668 |
669 |
670 | test_eq(sum(to_gpu(m2)),sum(vocab),"expand gradient test");
671 |
672 |
673 | for(int row = 0; row < vocab_idx->rows; row++)
674 | for(int col = 0; col < vocab_idx->cols; col++)
675 | {
676 | idx = (int)m1->data[col + (row*m1->cols)];
677 | for(int i = 0; i < vocab_vector_size; i++)
678 | ASSERT((m2->data[idx + (vocab->cols*i)] + 0.0001 > m4->data[idx + (vocab->cols*i)]) && //0.0001 error in float arithmetic on the GPU
679 | (m2->data[idx + (vocab->cols*i)] - 0.0001 < m4->data[idx + (vocab->cols*i)]) ,"expand gradient test");
680 |
681 | }
682 |
683 |
684 |
685 |
686 | Matrix **partial_vocab = gpus.zeros_stacked(vocab_vector_size/gpus.MPI_SIZE,vocab_size);
687 | m2 = to_host(partial_vocab[gpus.MYRANK]);
688 | expand_partial_vocab_gradient(grad,vocab_idx,partial_vocab[gpus.MYRANK],gpus.MYRANK,gpus.MPI_SIZE);
689 | m1 = to_host(vocab_idx);
690 | m4 = to_host(partial_vocab[gpus.MYRANK]);
691 | m3 = to_host(grad);
692 |
693 |
694 | idx = 0;
695 | for(int row = 0; row < vocab_idx->rows; row++)
696 | for(int col = 0; col < vocab_idx->cols; col++)
697 | {
698 | idx = (int)m1->data[col + (row*m1->cols)];
699 | for(int i = 0; i < vocab_vector_size/gpus.MPI_SIZE; i++)
700 | m2->data[idx + (vocab->cols*i)] += m3->data[(col*vocab_vector_size) + (row*grad->cols) + i + (gpus.MYRANK*partial_vocab[gpus.MYRANK]->rows)];
701 |
702 |
703 | }
704 |
705 |
706 | gpus.add_to_queue(partial_vocab);
707 | while(gpus.get_queue_length() > 0) {gpus.pop_queue(); }
708 | for(int i = 1; i < gpus.MPI_SIZE; i++)
709 | {
710 | add(partial_vocab[0],partial_vocab[i],partial_vocab[0]);
711 | }
712 |
713 | test_eq(sum(partial_vocab[0]),sum(vocab),"expand partial gradient test");
714 |
715 | for(int row = 0; row < vocab_idx->rows; row++)
716 | for(int col = 0; col < vocab_idx->cols; col++)
717 | {
718 | idx = (int)m1->data[col + (row*m1->cols)];
719 | for(int i = 0; i < vocab_vector_size/gpus.MPI_SIZE; i++)
720 | ASSERT((m2->data[idx + (vocab->cols*i)] + 0.0001 > m4->data[idx + (vocab->cols*i)]) && //0.0001 error in float arithmetic on the GPU
721 | (m2->data[idx + (vocab->cols*i)] - 0.0001 < m4->data[idx + (vocab->cols*i)]) ,"expand partial gradient test");
722 |
723 | }
724 |
725 |
726 |
727 | Matrix *z1 = gpu.rand(10,5);
728 | Matrix *z2 = gpu.rand(10,5);
729 | Matrix *z = zeros(10,5);
730 | Matrix *y1 = gpu.rand_int(10,1,0,9);
731 | Matrix *y2 = gpu.rand_int(10,1,0,9);
732 |
733 | printmat(y1);
734 | printmat(y2);
735 | printmat(z1);
736 |
737 |
738 | add_to_z(z,z1,y1,10,z);
739 | add_to_z(z,z2,y2,10,z);
740 |
741 | z1 = to_host(z1);
742 | z2 = to_host(z2);
743 | Matrix *t = to_host(z);
744 | y1 = to_host(y1);
745 | y2 = to_host(y2);
746 | printhostmat(y1);
747 | printhostmat(y2);
748 |
749 | z = zeros(10,5);
750 | z = to_host(z);
751 |
752 | for(int row = 0; row < 10; row++)
753 | for(int col = 0; col < 5; col++)
754 | {
755 | int cls = (int)y1->data[row];
756 | z->data[(cls*z1->cols) + col] += z1->data[col + (row*z1->cols)];
757 | cls = (int)y2->data[row];
758 | z->data[(cls*z2->cols) + col] += z2->data[col + (row*z2->cols)];
759 |
760 | }
761 |
762 |
763 | printhostmat(z);
764 | printhostmat(t);
765 | /*
766 | for(int i = 0; i < 10; i++)
767 | {
768 | ASSERT((t->data[i]+ 0.0001 > z->data[i]) && //0.0001 error in float arithmetic on the GPU
769 | (t->data[i]- 0.0001 < z->data[i]) ,"expand partial gradient test");
770 | }
771 | */
772 |
773 |
774 |
775 |
776 |
777 | return 0;
778 | }
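
The to_col_major/to_row_major expectations at the top of this test follow from the index mapping between the two storage orders; as a worked sketch of that arithmetic (numbers taken from the `m1_cpu` example above):

    // For an R x C matrix, element (r, c) sits at r*C + c in row-major
    // storage and at c*R + r in column-major storage.
    int R = 2, C = 3;
    int r = 0, c = 1;              // the value 2.0f in m1_data
    int row_major_idx = r*C + c;   // = 1: m1_data[1] == 2.0f
    int col_major_idx = c*R + r;   // = 2: test[2]   == 2.0f after to_col_major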
779 |
780 |
781 |
782 |
--------------------------------------------------------------------------------
/tests/basicOps_test.cuh:
--------------------------------------------------------------------------------
1 | #ifndef basicOps_test_H
2 | #define basicOps_test_H
3 | #include <clusterNet.h>
4 | #include <basicOps.cuh>
5 | int run_basicOps_test(ClusterNet gpus);
6 | #endif
7 |
--------------------------------------------------------------------------------
/tests/batchAllocator_test.cuh:
--------------------------------------------------------------------------------
1 | #ifndef batchAllocator_test_H
2 | #define batchAllocator_test_H
3 | #include <clusterNet.h>
4 | int run_batchAllocator_test(ClusterNet gpus);
5 | #endif
6 |
--------------------------------------------------------------------------------
/tests/clusterNet_test.cuh:
--------------------------------------------------------------------------------
1 | #ifndef clusterNet_test_H
2 | #define clusterNet_test_H
3 | #include <clusterNet.h>
4 | int run_clusterNet_test(ClusterNet gpus);
5 | #endif
6 |
--------------------------------------------------------------------------------
/tests/crowdflower_X_test.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/clusterNet/cb0bec556c480d26a8be4cd7ff0317ff661fab64/tests/crowdflower_X_test.hdf5
--------------------------------------------------------------------------------
/tests/crowdflower_y_test.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/clusterNet/cb0bec556c480d26a8be4cd7ff0317ff661fab64/tests/crowdflower_y_test.hdf5
--------------------------------------------------------------------------------
/tests/miniMNIST_test.cu:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <vector>
3 | #include <unistd.h>
4 | #include <clusterNet.h>
5 | #include <basicOps.cuh>
6 | #include <util.cuh>
7 | #include <batchAllocator.h>
8 | #include <DeepNeuralNetwork.h>
9 |
10 | using std::cout;
11 | using std::endl;
12 |
13 | void run_miniMNIST_test(ClusterNet gpus)
14 | {
15 |
16 | // Tests RMSprop with weight updates, logistic grad.
17 | // Additionally tests the interplay between different functions.
18 |
19 | char buff[1024] = {0};
20 | ssize_t len = ::readlink("/proc/self/exe", buff, sizeof(buff)-1);
21 | std::string path = std::string(buff);
22 | replace(path,"/build/testSuite.out","/tests/");
23 |
24 | Matrix *X = read_hdf5((path + "/mnist_mini_X.hdf5").c_str());
25 | Matrix *y = read_hdf5((path + "/mnist_mini_y.hdf5").c_str());
26 |
27 | //Matrix *X = read_hdf5("/home/tim/data/mnist/X.hdf5"); //absolute dev-machine path, kept for reference
28 | //Matrix *y = read_hdf5("/home/tim/data/mnist/y.hdf5");
29 |
30 |
31 | Matrix *w1 = gpus.uniformSqrtWeight(784,1000);
32 | Matrix *w2 = gpus.uniformSqrtWeight(1000,10);
33 | Matrix *m1 = zeros(784,1000);
34 | Matrix *m2 = zeros(1000,10);
35 | Matrix *ms1 = zeros(784,1000);
36 | Matrix *ms2 = zeros(1000,10);
37 | Matrix *grad_w1_ms = zeros(784,1000);
38 | Matrix *grad_w2_ms = zeros(1000,10);
39 | Matrix *grad_w2 = empty(1000,10);
40 | Matrix *grad_w1 = empty(784,1000);
41 | float cv_error = 0.0f;
42 | float train_error = 0.0f;
43 |
44 | Matrix *z = zeros(10,1000);
45 |
46 | BatchAllocator b = BatchAllocator();
47 | b.init(X, y, 0.2, 128, 64);
48 | int epochs = 20;
49 | float learning_rate = 0.003;
50 | float momentum = 0.5;
51 | for(int EPOCH = 1; EPOCH < epochs; EPOCH++)
52 | {
53 | momentum += 0.01;
54 | if(momentum > 0.95) momentum = 0.95;
55 |
56 | for(int i = 0; i < b.TOTAL_BATCHES; i++)
57 | {
58 | b.broadcast_batch_to_processes();
59 |
60 |
61 | //nesterov updates
62 | scalarMul(m1,momentum,m1);
63 | scalarMul(m2,momentum,m2);
64 | add(w1,m1,w1);
65 | add(w2,m2,w2);
66 |
67 | //feedforward
68 | Matrix *d0 = gpus.dropout(b.CURRENT_BATCH,0.2);
69 | //print_gpus_matrix(w1);
70 | Matrix *z1 = gpus.dot(d0, w1);
71 | //logistic(z1, z1);
72 | rectified_linear(z1,z1);
73 |
74 | add_to_z(z,z1,b.CURRENT_BATCH_Y,10,z);
75 |
76 | b.allocate_next_batch_async();
77 |
78 | b.replace_current_batch_with_next();
79 |
80 | }
81 |
82 | scalarMul(z,1.0f/(b.TOTAL_BATCHES),z);
83 |
84 |
85 |
86 | for(int i = 0; i < b.TOTAL_BATCHES; i++)
87 | {
88 | b.broadcast_batch_to_processes();
89 |
90 |
91 | //nesterov updates
92 | scalarMul(m1,momentum,m1);
93 | scalarMul(m2,momentum,m2);
94 | add(w1,m1,w1);
95 | add(w2,m2,w2);
96 |
97 | //feedforward
98 | Matrix *d0 = gpus.dropout(b.CURRENT_BATCH,0.2);
99 | //print_gpus_matrix(w1);
100 | Matrix *z1 = gpus.dot(d0, w1);
101 | //logistic(z1, z1);
102 | rectified_linear(z1,z1);
103 | Matrix *d1 = gpus.dropout(z1,0.5);
104 | Matrix *a2 = gpus.dot(d1,w2);
105 | Matrix *out = softmax(a2);
106 | Matrix *t = create_t_matrix(b.CURRENT_BATCH_Y,10);
107 |
108 | b.allocate_next_batch_async();
109 |
110 | //backprop
111 | Matrix *e1 = sub(out, t);
112 | Matrix *e2 = gpus.dotT(e1, w2);
113 | gpus.Tdot(z1,e1,grad_w2);
114 |
115 | gpus.Tdot(b.CURRENT_BATCH,e2,grad_w1);
116 |
117 | //weight updates
118 | RMSprop_with_nesterov_weight_update(ms1,grad_w1,w1,m1,0.9f,learning_rate,b.CURRENT_BATCH->rows, momentum);
119 | RMSprop_with_nesterov_weight_update(ms2,grad_w2,w2,m2,0.9f,learning_rate,b.CURRENT_BATCH->rows, momentum);
120 |
121 | cudaFree(e1->data);
122 | cudaFree(e2->data);
123 | cudaFree(z1->data);
124 | cudaFree(a2->data);
125 | cudaFree(out->data);
126 | cudaFree(t->data);
127 | cudaFree(d0->data);
128 | cudaFree(d1->data);
129 |
130 | b.replace_current_batch_with_next();
131 |
132 | }
133 |
134 |
135 | train_error = 0;
136 | for(int i = 0; i < b.TOTAL_BATCHES; i++)
137 | {
138 |
139 | b.broadcast_batch_to_processes();
140 |
141 | Matrix *d0 = scalarMul(b.CURRENT_BATCH,0.8);
142 | Matrix *a1 = gpus.dot(d0,w1);
143 | logistic(a1, a1);
144 | Matrix *d1 = scalarMul(a1,0.5);
145 | Matrix *a2 = gpus.dot(d1,w2);
146 | Matrix *out = softmax(a2);
147 | Matrix *result = argmax(out);
148 | Matrix *eq = equal(result,b.CURRENT_BATCH_Y);
149 | b.allocate_next_batch_async();
150 | float sum_value = sum(eq);
151 |
152 | train_error += (b.CURRENT_BATCH->rows - sum_value)/ (1.0f * b.CURRENT_BATCH->rows *b.TOTAL_BATCHES) ;
153 |
154 | cudaFree(a1->data);
155 | cudaFree(a2->data);
156 | cudaFree(d0->data);
157 | cudaFree(d1->data);
158 | cudaFree(out->data);
159 | cudaFree(result->data);
160 | cudaFree(eq->data);
161 |
162 | b.replace_current_batch_with_next();
163 | }
164 |
165 | std::cout << "Train error: " << train_error << std::endl;
166 |
167 | cv_error = 0;
168 | for(int i = 0; i < b.TOTAL_BATCHES_CV; i++)
169 | {
170 | b.broadcast_batch_cv_to_processes();
171 | Matrix *d0 = scalarMul(b.CURRENT_BATCH_CV,0.8);
172 | Matrix *a1 = gpus.dot(d0,w1);
173 | logistic(a1, a1);
174 | Matrix *d1 = scalarMul(a1,0.5);
175 | Matrix *a2 = gpus.dot(d1,w2);
176 | Matrix *out = softmax(a2);
177 | Matrix *result = argmax(out);
178 | Matrix *eq = equal(result,b.CURRENT_BATCH_CV_Y);
179 | b.allocate_next_cv_batch_async();
180 | float sum_value = sum(eq);
181 |
182 | cv_error += (b.CURRENT_BATCH_CV->rows - sum_value)/ (1.0f * b.CURRENT_BATCH_CV->rows *b.TOTAL_BATCHES_CV) ;
183 |
184 | cudaFree(a1->data);
185 | cudaFree(a2->data);
186 | cudaFree(d0->data);
187 | cudaFree(d1->data);
188 | cudaFree(out->data);
189 | cudaFree(result->data);
190 | cudaFree(eq->data);
191 |
192 | b.replace_current_cv_batch_with_next();
193 | }
194 |
195 | std::cout << "Cross validation error: " << cv_error << std::endl;
196 |
197 | }
198 | /*
199 | ASSERT(train_error < 0.03f,"mini-MNIST train error 17 epochs < 0.03.");
200 | ASSERT(cv_error < 0.22f, "mini-MNIST cross validation error 17 epochs < 0.22.");
201 |
202 | b.finish_batch_allocator();
203 |
204 |
205 | Matrix *w1_dist = gpus.distributed_uniformSqrtWeight(784,1000);
206 | Matrix *w2_dist = gpus.distributed_uniformSqrtWeight(1000,10);
207 | Matrix *m1_dist = gpus.distributed_zeros(784,1000);
208 | Matrix *m2_dist = gpus.distributed_zeros(1000,10);
209 | Matrix *ms1_dist = gpus.distributed_zeros(784,1000);
210 | Matrix *ms2_dist = gpus.distributed_zeros(1000,10);
211 | Matrix *grad_w1_ms_dist = gpus.distributed_zeros(784,1000);
212 | Matrix *grad_w2_ms_dist = gpus.distributed_zeros(1000,10);
213 | Matrix *grad_w1_dist = gpus.distributed_zeros(784,1000);
214 | Matrix *grad_w2_dist = gpus.distributed_zeros(1000,10);
215 |
216 | BatchAllocator b_dist = BatchAllocator();
217 | b_dist.init(X, y, 0.2, 32, 64, gpus, Distributed_weights);
218 | for(int EPOCH = 1; EPOCH < epochs; EPOCH++)
219 | {
220 | momentum += 0.01;
221 | if(momentum > 0.95) momentum = 0.95;
222 | for(int i = 0; i < b_dist.TOTAL_BATCHES; i++)
223 | {
224 |
225 | b_dist.broadcast_batch_to_processes();
226 |
227 | //nesterov updates
228 | scalarMul(m1_dist,momentum,m1_dist);
229 | scalarMul(m2_dist,momentum,m2_dist);
230 | add(w1_dist,m1_dist,w1_dist);
231 | add(w2_dist,m2_dist,w2_dist);
232 |
233 | Matrix *d0 = gpus.dropout(b_dist.CURRENT_BATCH,0.2);
234 | //print_gpus_matrix(w1);
235 | Matrix *z1 = gpus.dot(d0, w1_dist);
236 | logistic(z1, z1);
237 | Matrix *d1 = gpus.dropout(z1,0.5);
238 | Matrix *a2 = gpus.dot(d1,w2_dist);
239 | Matrix *out = softmax(a2);
240 | Matrix *t = create_t_matrix(b_dist.CURRENT_BATCH_Y,10);
241 |
242 | b_dist.allocate_next_batch_async();
243 |
244 | //backprop
245 | Matrix *e1 = sub(out, t);
246 | Matrix *e2 = gpus.dotT(e1, w2_dist);
247 | gpus.Tdot(z1,e1,grad_w2_dist);
248 | logisticGrad(z1,z1);
249 | mul(e2,z1,e2);
250 | gpus.Tdot(b_dist.CURRENT_BATCH,e2,grad_w1_dist);
251 |
252 | RMSprop_with_nesterov_weight_update(ms1_dist,grad_w1_dist,w1_dist,m1_dist,0.9f,learning_rate,b_dist.CURRENT_BATCH->rows, momentum);
253 | RMSprop_with_nesterov_weight_update(ms2_dist,grad_w2_dist,w2_dist,m2_dist,0.9f,learning_rate,b_dist.CURRENT_BATCH->rows, momentum);
254 |
255 | cudaFree(e1->data);
256 | cudaFree(e2->data);
257 | cudaFree(z1->data);
258 | cudaFree(a2->data);
259 | cudaFree(out->data);
260 | cudaFree(t->data);
261 | cudaFree(d0->data);
262 | cudaFree(d1->data);
263 |
264 | b_dist.replace_current_batch_with_next();
265 |
266 | }
267 |
268 | train_error = 0;
269 | for(int i = 0; i < b_dist.TOTAL_BATCHES; i++)
270 | {
271 | b_dist.broadcast_batch_to_processes ();
272 |
273 | Matrix *d0 = scalarMul(b_dist.CURRENT_BATCH,0.8);
274 | Matrix *a1 = gpus.dot(d0,w1);
275 |
276 | logistic(a1, a1);
277 | Matrix *d1 = scalarMul(a1,0.5);
278 | Matrix *a2 = gpus.dot(d1,w2);
279 | Matrix *out = softmax(a2);
280 | Matrix *result = argmax(out);
281 | Matrix *eq = equal(result,b_dist.CURRENT_BATCH_Y);
282 | float sum_value = sum(eq);
283 |
284 | b_dist.allocate_next_batch_async();
285 |
286 | train_error += (b_dist.CURRENT_BATCH->rows - sum_value)/ (1.0f * b_dist.CURRENT_BATCH->rows *b_dist.TOTAL_BATCHES) ;
287 |
288 | cudaFree(a1->data);
289 | cudaFree(a2->data);
290 | cudaFree(out->data);
291 | cudaFree(d0->data);
292 | cudaFree(d1->data);
293 | cudaFree(result->data);
294 | cudaFree(eq->data);
295 |
296 | b_dist.replace_current_batch_with_next();
297 | }
298 |
299 | //std::cout << "Train error: " << train_error << std::endl;
300 |
301 | cv_error = 0;
302 | for(int i = 0; i < b_dist.TOTAL_BATCHES_CV; i++)
303 | {
304 | b_dist.broadcast_batch_cv_to_processes();
305 |
306 | Matrix *d0 = scalarMul(b_dist.CURRENT_BATCH_CV,0.8);
307 | Matrix *a1 = gpus.dot(d0,w1);
308 | logistic(a1, a1);
309 | Matrix *d1 = scalarMul(a1,0.5);
310 | Matrix *a2 = gpus.dot(d1,w2);
311 | Matrix *out = softmax(a2);
312 | Matrix *result = argmax(out);
313 | Matrix *eq = equal(result,b_dist.CURRENT_BATCH_CV_Y);
314 | float sum_value = sum(eq);
315 |
316 | b_dist.allocate_next_cv_batch_async();
317 |
318 | cv_error += (b_dist.CURRENT_BATCH_CV->rows - sum_value)/ (1.0f * b_dist.CURRENT_BATCH_CV->rows *b_dist.TOTAL_BATCHES_CV) ;
319 |
320 | cudaFree(a1->data);
321 | cudaFree(a2->data);
322 | cudaFree(d0->data);
323 | cudaFree(d1->data);
324 | cudaFree(out->data);
325 | cudaFree(result->data);
326 | cudaFree(eq->data);
327 |
328 | b_dist.replace_current_cv_batch_with_next();
329 | }
330 |
331 | //std::cout << "Cross validation error: " << cv_error << std::endl;
332 |
333 | }
334 |
335 |
336 | ASSERT(train_error < 0.03f,"mini-MNIST train error 17 epochs < 0.03.");
337 | ASSERT(cv_error < 0.22f, "mini-MNIST cross validation error 17 epochs < 0.22.");
338 |
339 |
340 |
341 | b_dist.finish_batch_allocator();
342 |
343 |
344 | // Maxout test
345 |
346 | // Tests RMSprop with weight updates, logistic grad.
347 | // Additionally tests the interplay between different functions.
348 |
349 | w1 = gpus.uniformSqrtWeight(784,1024);
350 | w2 = gpus.uniformSqrtWeight(128,10);
351 | m1 = zeros(784,1024);
352 | m2 = zeros(128,10);
353 | ms1 = zeros(784,1024);
354 | ms2 = zeros(128,10);
355 | grad_w1_ms = zeros(784,1024);
356 | grad_w2_ms = zeros(128,10);
357 | grad_w1 = empty(784,1024);
358 | grad_w2 = empty(128,10);
359 | cv_error = 0.0f;
360 | train_error = 0.0f;
361 |
362 | b = BatchAllocator();
363 | b.init(X, y, 0.2, 32, 64);
364 | epochs = 17;
365 | learning_rate = 0.01;
366 | momentum = 0.5;
367 | for(int EPOCH = 1; EPOCH < epochs; EPOCH++)
368 | {
369 | momentum += 0.01;
370 | if(momentum > 0.95) momentum = 0.95;
371 | for(int i = 0; i < b.TOTAL_BATCHES; i++)
372 | {
373 | b.broadcast_batch_to_processes();
374 |
375 | //nesterov updates
376 | scalarMul(m1,momentum,m1);
377 | scalarMul(m2,momentum,m2);
378 | add(w1,m1,w1);
379 | add(w2,m2,w2);
380 |
381 | //feedforward
382 | Matrix *d0 = gpus.dropout(b.CURRENT_BATCH,0.2);
383 | //print_gpus_matrix(w1);
384 | Matrix *z1 = gpus.dot(d0, w1);
385 | Matrix **a_paired = maxout(z1,8);
386 | Matrix *a1 = a_paired[0];
387 | Matrix *a1_idx = a_paired[1];
388 | Matrix *d1 = gpus.dropout(a1,0.5);
389 | Matrix *a2 = gpus.dot(d1,w2);
390 | Matrix *out = softmax(a2);
391 | Matrix *t = create_t_matrix(b.CURRENT_BATCH_Y,10);
392 |
393 | b.allocate_next_batch_async();
394 |
395 | //backprop
396 | Matrix *e1 = sub(out, t);
397 | Matrix *e2_partial = gpus.dotT(e1, w2);
398 | Matrix *e2 = empty(b.CURRENT_BATCH->rows,e2_partial->cols*8);
399 |
400 | gpus.Tdot(a1,e1,grad_w2);
401 | expand_to_maxout_grad(e2_partial, a1_idx,e2);
402 | gpus.Tdot(b.CURRENT_BATCH,e2,grad_w1);
403 |
404 | //weight updates
405 | RMSprop_with_nesterov_weight_update(ms1,grad_w1,w1,m1,0.9f,learning_rate,b.CURRENT_BATCH->rows, momentum);
406 | RMSprop_with_nesterov_weight_update(ms2,grad_w2,w2,m2,0.9f,learning_rate,b.CURRENT_BATCH->rows, momentum);
407 |
408 | cudaFree(e1->data);
409 | cudaFree(e2->data);
410 | cudaFree(e2_partial->data);
411 | cudaFree(z1->data);
412 | cudaFree(a1->data);
413 | cudaFree(a1_idx->data);
414 | cudaFree(a2->data);
415 | cudaFree(out->data);
416 | cudaFree(t->data);
417 | cudaFree(d0->data);
418 | cudaFree(d1->data);
419 | free(a_paired);
420 |
421 | b.replace_current_batch_with_next();
422 |
423 | }
424 |
425 |
426 |
427 | train_error = 0;
428 | for(int i = 0; i < b.TOTAL_BATCHES; i++)
429 | {
430 |
431 | b.broadcast_batch_to_processes();
432 |
433 | Matrix *d0 = scalarMul(b.CURRENT_BATCH,0.8);
434 | Matrix *z1 = gpus.dot(d0,w1);
435 | Matrix **a1_pair = maxout(z1,8);
436 | Matrix *a1 = a1_pair[0];
437 | Matrix *d1 = scalarMul(a1,0.5);
438 | Matrix *a2 = gpus.dot(d1,w2);
439 | Matrix *out = softmax(a2);
440 | Matrix *result = argmax(out);
441 | Matrix *eq = equal(result,b.CURRENT_BATCH_Y);
442 | b.allocate_next_batch_async();
443 | float sum_value = sum(eq);
444 |
445 | train_error += (b.CURRENT_BATCH->rows - sum_value)/ (1.0f * b.CURRENT_BATCH->rows *b.TOTAL_BATCHES) ;
446 |
447 | cudaFree(z1->data);
448 | cudaFree(a1->data);
449 | cudaFree(a1_pair[1]->data);
450 | cudaFree(a2->data);
451 | cudaFree(out->data);
452 | cudaFree(result->data);
453 | cudaFree(eq->data);
454 | cudaFree(d1->data);
455 | cudaFree(d0->data);
456 | free(a1_pair);
457 |
458 | b.replace_current_batch_with_next();
459 | }
460 |
461 | //std::cout << "MAXOUT Train error: " << train_error << std::endl;
462 |
463 |
464 |
465 | cv_error = 0;
466 | for(int i = 0; i < b.TOTAL_BATCHES_CV; i++)
467 | {
468 | b.broadcast_batch_cv_to_processes();
469 | Matrix *d0 = scalarMul(b.CURRENT_BATCH_CV,0.8);
470 | Matrix *z1 = gpus.dot(d0,w1);
471 | Matrix **a1_pair = maxout(z1,8);
472 | Matrix *a1 = a1_pair[0];
473 | Matrix *d1 = scalarMul(a1,0.5);
474 | Matrix *a2 = gpus.dot(d1,w2);
475 | Matrix *out = softmax(a2);
476 | Matrix *result = argmax(out);
477 | Matrix *eq = equal(result,b.CURRENT_BATCH_CV_Y);
478 | b.allocate_next_batch_async();
479 | float sum_value = sum(eq);
480 |
481 | cv_error += (b.CURRENT_BATCH_CV->rows - sum_value)/ (1.0f * b.CURRENT_BATCH_CV->rows *b.TOTAL_BATCHES_CV) ;
482 |
483 | cudaFree(z1->data);
484 | cudaFree(a1->data);
485 | cudaFree(a1_pair[1]->data);
486 | cudaFree(a2->data);
487 | cudaFree(out->data);
488 | cudaFree(result->data);
489 | cudaFree(eq->data);
490 | cudaFree(d0->data);
491 | cudaFree(d1->data);
492 | free(a1_pair);
493 |
494 | b.replace_current_cv_batch_with_next();
495 | }
496 |
497 | //std::cout << "MAXOUT Cross validation error: " << cv_error << std::endl;
498 |
499 | }
500 |
501 |
502 | ASSERT(train_error < 0.02f,"mini-MNIST train error 17 epochs < 0.02.");
503 | ASSERT(cv_error < 0.22f, "mini-MNIST cross validation error 17 epochs < 0.22.");
504 | */
505 | std::vector<int> layers;
506 | layers.push_back(768);
507 | layers.push_back(512);
508 |
509 |
510 | BatchAllocator allocator = BatchAllocator();
511 | allocator.init(X,y,0.2,128,256,&gpus, Distributed_weights);
512 | DeepNeuralNetwork net = DeepNeuralNetwork(layers,Classification, gpus, allocator, 10);
513 | net.EPOCHS = 1000;
514 | //net.LEARNING_RATE = 0.001;
515 | net.LEARNING_RATE = 0.001;
516 | net.train();
517 |
518 | if(gpus.MYRANK == 0)
519 | {
520 | cout << endl;
521 | cout << "Train error should be: 0.0025" << endl;
522 | cout << "Cross validation error should be: 0.13" << endl;
523 | }
524 |
525 | allocator = BatchAllocator();
526 | Matrix *t = to_host(create_t_matrix(to_gpu(y),10));
527 | allocator.init(X,t,0.2,128,256,&gpus, Distributed_weights);
528 | net = DeepNeuralNetwork(layers,Regression, gpus, allocator, 10);
529 | net.EPOCHS = 100;
530 | net.PRINT_MISSCLASSIFICATION = true;
531 | net.OUTPUT_IS_PROBABILITY = true;
532 | net.LEARNING_RATE = 0.0003;
533 | net.train();
534 |
535 | if(gpus.MYRANK == 0)
536 | {
537 | cout << endl;
538 | cout << "Train error should be about: 0.05" << endl;
539 | cout << "Cross validation error should be about: 0.25" << endl;
540 | }
541 |
542 |
543 | /*
544 | if(gpus.MYGPUID == 0)
545 | {
546 | X = read_sparse_hdf5((path + "crowdflower_X_test.hdf5").c_str());
547 | y = read_sparse_hdf5((path + "crowdflower_y_test.hdf5").c_str());
548 | }
549 | else
550 | {
551 | X = empty_pinned_sparse(1,1,1);
552 | y = empty_pinned_sparse(1,1,1);
553 | }
554 |
555 | b = BatchAllocator();
556 | b.init(X,y,0.2,128,512,gpus, Distributed_weights_sparse);
557 | layers.clear();
558 | layers.push_back(400);
559 | layers.push_back(400);
560 |
561 | net = DeepNeuralNetwork(layers,Regression,gpus,b,24);
562 | net.EPOCHS = 4;
563 | net.TRANSITION_EPOCH = 4;
564 | net.LEARNING_RATE = 0.0001;
565 | net.OUTPUT_IS_PROBABILITY = true;
566 | net.train();
567 | */
568 |
569 |
570 | }
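
At evaluation time the loops above replace dropout with a deterministic rescale: `scalarMul(b.CURRENT_BATCH, 0.8)` compensates for the 20% input dropout and `scalarMul(a1, 0.5)` for the 50% hidden dropout, so the expected input to each layer matches what the network saw during training. As an arithmetic sketch:

    // With dropout rate p a unit is kept with probability (1 - p), so the
    // training-time expectation of an activation a is (1 - p)*a. At test
    // time nothing is dropped, hence the explicit (1 - p) scaling above.
    float p_input = 0.2f, p_hidden = 0.5f;
    float input_scale  = 1.0f - p_input;   // 0.8f, applied to the input batch
    float hidden_scale = 1.0f - p_hidden;  // 0.5f, applied to hidden activations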
571 |
--------------------------------------------------------------------------------
/tests/miniMNIST_test.cuh:
--------------------------------------------------------------------------------
1 | #ifndef miniMNIST_test_H
2 | #define miniMNIST_test_H
3 | void run_miniMNIST_test(ClusterNet gpus);
4 | #endif
5 |
--------------------------------------------------------------------------------
/tests/mnist_mini_X.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/clusterNet/cb0bec556c480d26a8be4cd7ff0317ff661fab64/tests/mnist_mini_X.hdf5
--------------------------------------------------------------------------------
/tests/mnist_mini_y.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/clusterNet/cb0bec556c480d26a8be4cd7ff0317ff661fab64/tests/mnist_mini_y.hdf5
--------------------------------------------------------------------------------
/tests/numpy_arange_as_h5py.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/clusterNet/cb0bec556c480d26a8be4cd7ff0317ff661fab64/tests/numpy_arange_as_h5py.hdf5
--------------------------------------------------------------------------------
/tests/scipy_sparse_arange_as_h5py.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/clusterNet/cb0bec556c480d26a8be4cd7ff0317ff661fab64/tests/scipy_sparse_arange_as_h5py.hdf5
--------------------------------------------------------------------------------
/tests/testSuite.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <clusterNet.h>
3 | #include <basicOps_test.cuh>
4 | #include <clusterNet_test.cuh>
5 | #include <batchAllocator_test.cuh>
6 | #include <miniMNIST_test.cuh>
7 | #include <util_test.cuh>
8 |
9 |
10 | int main(int argc, char *argv[])
11 | {
12 | ClusterNet gpus = ClusterNet(argc,argv,132456);
13 | run_basicOps_test(gpus);
14 | run_clusterNet_test(gpus);
15 | run_batchAllocator_test(gpus);
16 | run_miniMNIST_test(gpus);
17 | run_util_test();
18 |
19 |
20 | printf("----------------------\n");
21 | printf("All tests passed successfully!\n");
22 | printf("----------------------\n");
23 |
24 | gpus.shutdown_MPI();
25 |
26 | }
27 |
--------------------------------------------------------------------------------
/tests/testSuite.cuh:
--------------------------------------------------------------------------------
1 | #ifndef testSuite
2 | #define testSuite
3 | #endif
4 |
--------------------------------------------------------------------------------
/tests/util_test.cu:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <assert.h>
3 | #include <unistd.h>
4 | #include <basicOps.cuh>
5 | #include <util.cuh>
6 |
7 | using std::cout;
8 | using std::endl;
9 |
10 | void run_util_test()
11 | {
12 | /*
13 | char buff[1024];
14 | ssize_t len = ::readlink("/proc/self/exe", buff, sizeof(buff)-1);
15 | std::string path = std::string(buff);
16 | replace(path,"/build/testSuite.out","/tests/");
17 |
18 | Matrix *X = read_hdf5((path + "/numpy_arange_as_h5py.hdf5").c_str());
19 | for(int i = 0;i < X->size; i++)
20 | assert(test_eq(X->data[i],(float)i,"HDF5 read for h5py data."));
21 |
22 | X = read_sparse_hdf5((path + "/scipy_sparse_arange_as_h5py.hdf5").c_str());
23 | for(int i = 0;i < X->size; i++)
24 | assert(test_eq(X->data[i],(float)(i+1),"HDF5 read sparse for h5py data."));
25 |
26 | int col_count = 1;
27 | for(int i = 0;i < 105; i++)
28 | {
29 | assert(test_eq(X->idx_cols[i],col_count,"HDF5 read sparse for h5py data."));
30 | col_count++;
31 | if(col_count == 50)
32 | col_count = 0;
33 |
34 | }
35 |
36 | int row_ptr = 0;
37 | for(int i = 0;i < X->rows-1; i++)
38 | {
39 | assert(test_eq(X->ptr_rows[i],row_ptr,"HDF5 read sparse for h5py data."));
40 | row_ptr += i == 0 ? 49 : 50;
41 | }
42 |
43 | ASSERT(determine_max_sparsity(X,X->rows) == (float)((X->rows*X->cols)-1)/(float)(X->rows*X->cols),"max sparsity test");
44 |
45 | Matrix *out = empty_pinned(X->rows,X->cols);
46 | slice_sparse_to_dense(X,out,0,X->rows);
47 | for(int i = 0;i < out->size; i++)
48 | assert(test_eq(out->data[i],(float)i,"slice sparse to dense test."));
49 |
50 | */
51 |
52 |
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/tests/util_test.cuh:
--------------------------------------------------------------------------------
1 | #ifndef util_test_H
2 | #define util_test_H
3 | void run_util_test();
4 | #endif
5 |
--------------------------------------------------------------------------------