├── deepheart
│   ├── __init__.py
│   ├── train_model.py
│   ├── model.py
│   └── parser.py
├── LICENSE
├── requirements.txt
└── README.md

/deepheart/__init__.py:
--------------------------------------------------------------------------------
from . import model
from . import parser
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
appnope==0.1.0
backports.shutil-get-terminal-size==1.0.0
cycler==0.10.0
decorator==4.0.10
h5py==2.6.0
ipdb==0.10.1
ipython==5.0.0
ipython-genutils==0.1.0
matplotlib==1.5.1
numpy==1.11.1
pandas==0.18.1
pathlib2==2.1.0
pexpect==4.2.0
pickleshare==0.7.3
prompt-toolkit==1.0.3
protobuf==3.0.0b2
ptyprocess==0.5.1
py==1.4.31
Pygments==2.1.3
pyparsing==2.1.4
pytest==2.9.2
python-dateutil==2.5.3
pytz==2016.4
scikit-learn==0.17.1
scipy==0.17.1
simplegeneric==0.8.1
six==1.10.0
traitlets==4.2.2
wcwidth==0.1.7
--------------------------------------------------------------------------------
/deepheart/train_model.py:
--------------------------------------------------------------------------------
from parser import PCG
from model import CNN
import sys

# Command-line values that count as "true" for the optional
# load_pretrained argument
true_strs = {"True", "true", "t"}


def load_and_train_model(data_path, load_pretrained):
    """Build (or reload) a PCG feature set and train the CNN on it."""
    pcg = PCG(data_path)

    if load_pretrained:
        # Reuse features previously cached by PCG.save()
        pcg.load("/tmp")
    else:
        # Parse the raw wav files and compute FFT features from scratch
        pcg.initialize_wav_data()

    cnn = CNN(pcg, epochs=100, dropout=0.5)
    cnn.train()


if __name__ == '__main__':
    # Usage: python train_model.py <data_path> [load_pretrained]
    data_path = sys.argv[1]

    load_pretrained = False
    if len(sys.argv) == 3:
        load_pretrained = sys.argv[2] in true_strs

    load_and_train_model(data_path, load_pretrained)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepHeart

DeepHeart is a neural network designed for the
[2016 PhysioNet Challenge](http://physionet.org/physiobank/database/challenge/2016/):
predicting cardiac abnormalities from phonocardiogram (PCG) data. The
challenge provides heart sound recordings from several patients, labeled
as normal or abnormal. Predicting patient health from PCG data is
difficult because of noise from several sources: talking, breathing,
intestinal sounds, etc.

To combat the excessive noise and the relatively small sample size, a
convolutional neural network is trained using Google's
[TensorFlow](http://github.com/tensorflow/tensorflow), which provides an
easy-to-use interface for building and efficiently running neural
networks.

Ideally, the raw wav files would be fed into a very deep TensorFlow
network and, with some careful regularization, the model would learn to
accurately separate signal from noise. To reduce the cost of training,
the number of hidden units is reduced in favor of some old-school feature
engineering: the fast Fourier transform (FFT). The FFT converts a signal
from the time domain into the frequency domain. The original signal is
also passed through a low-pass Butterworth filter that removes
frequencies above 4 Hz (240 beats per minute, well above any plausible
heart rate), and the filtered signal is likewise transformed into its
approximate frequency domain. A combination of these Fourier coefficients
is fed into the convolutional neural network.
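In code, the feature extraction looks roughly like this (a simplified
sketch of what `deepheart/parser.py` does; the wav filename and slice
sizes are illustrative):

```
import numpy as np
from scipy.fftpack import fft
from scipy.signal import butter, lfilter
from scipy.io import wavfile

rate, signal = wavfile.read("a0001.wav")  # one PCG recording

# Magnitudes of the first few hundred Fourier coefficients
fft_features = np.abs(fft(signal))[:400]

# Low-pass Butterworth filter: keep only frequencies plausible for a
# beating heart (4 Hz = 240 bpm), then FFT the filtered signal as well
b, a = butter(5, 4.0 / (0.5 * rate), btype='low')
lowpass_features = np.abs(fft(lfilter(b, a, signal)))[:200]

features = np.concatenate((fft_features, lowpass_features))
```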
# Installing

To run, set up a virtual environment (ensure python2.7, virtualenv, and
pip are in your PATH):

```
>> cd deepheart
>> virtualenv env
>> source env/bin/activate
>> pip install -r requirements.txt
```

Download the PhysioNet dataset:

```
>> wget http://physionet.org/physiobank/database/challenge/2016/training.zip
>> unzip training.zip
```

Install TensorFlow from [TensorFlow's site](https://www.tensorflow.org/versions/r0.9/get_started/os_setup.html#pip-installation)
(pip installation recommended).

Build a feature vector from the raw data and train the CNN:

```
>> python deepheart/train_model.py <data_path> [load_pretrained]
e.g.,
>> python deepheart/train_model.py training/ f
```

Note: by default this saves TensorBoard statistics to /tmp, which can be
visualized by launching

```
>> tensorboard --logdir=/tmp/train
```
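The pipeline can also be driven directly from Python (a minimal sketch
mirroring `deepheart/train_model.py`; it assumes the repository root is on
your `PYTHONPATH` and that `training/` holds the unzipped dataset):

```
from deepheart.parser import PCG
from deepheart.model import CNN

pcg = PCG("training/")     # crawl header and wav files under this path
pcg.initialize_wav_data()  # compute FFT features and cache them to /tmp

cnn = CNN(pcg, epochs=100, dropout=0.5)
cnn.train()
```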
# Performance

PhysioNet scores entries using the mean of sensitivity and specificity
(the fractions of abnormal and normal recordings correctly identified,
respectively). These summaries are computed and logged to TensorBoard as
well as printed to the terminal.

Currently, the TensorFlow CNN model converges to a mean score of 0.78.

# Disclaimer

This software is not intended for diagnostic purposes. It is designed only
for the PhysioNet data science competition. These statements have not been
evaluated by the FDA. This product is not intended to diagnose, treat,
cure, or prevent any disease.
--------------------------------------------------------------------------------
/deepheart/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
from datetime import datetime


class CNN:
    def __init__(self, pcg, nclasses=2, learning_rate=0.001,
                 epochs=5, batch_size=100, dropout=0.75, base_dir="/tmp",
                 model_name="cnn"):
        self.pcg = pcg
        self.nclasses = nclasses
        self.d_input = self.pcg.train.X.shape[1]
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.dropout = dropout
        self.nbatches = int(self.pcg.train.X.shape[0] / float(self.batch_size))
        self.model_name = model_name
        self.base_dir = base_dir

    def train(self):
        """
        Train a convolutional neural network over the input PCG dataset.

        This method is beefy: it is responsible for defining TensorFlow
        variables, defining the training objective function, defining
        summary statistics, creating the TensorFlow session, running
        gradient descent and, ultimately, writing statistics.

        In the future this will be refactored into more easily tested
        training segments.

        Parameters
        ----------
        None

        Returns
        -------
        None

        """
        print('begin train')
        print(self.__get_output_name())

        with tf.name_scope('input'):
            X = tf.placeholder(tf.float32, [None, self.d_input], name='X')
            y = tf.placeholder(tf.float32, [None, self.nclasses], name='y')
            do_drop = tf.placeholder(tf.float32, name='drop')

        with tf.name_scope('weights'):
            weights = {
                'wc1': tf.Variable(tf.random_normal([5, 1, 1, 32]), name='wc1'),
                'wc2': tf.Variable(tf.random_normal([5, 1, 32, 64]), name='wc2'),
                # Two max pools with ksize=2 shrink the input length by a
                # factor of 4, so the fully connected layer sees
                # d_input / 4 positions x 64 filters
                'wd1': tf.Variable(tf.random_normal([int(self.d_input / 4) * 1 * 64, 1024]), name='wd1'),
                'out': tf.Variable(tf.random_normal([1024, self.nclasses]), name='outW')
            }
        with tf.name_scope('biases'):
            biases = {
                'bc1': tf.Variable(tf.random_normal([32]), name='bc1'),
                'bc2': tf.Variable(tf.random_normal([64]), name='bc2'),
                'bd1': tf.Variable(tf.random_normal([1024]), name='bd1'),
                'out': tf.Variable(tf.random_normal([self.nclasses]), name='outB')
            }

        with tf.name_scope('pred'):
            pred = self.model1D(X, weights, biases, do_drop)

        with tf.name_scope('cost'):
            cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y, name='cost'))
            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(cost)

        dim = tf.shape(y)[0]

        with tf.name_scope('sensitivity'):
            # sensitivity = correctly predicted abnormal / total actual abnormal
            abnormal_idxs = tf.cast(tf.equal(tf.argmax(pred, 1), 1), tf.float32)
            # Column 1 of the one-hot labels marks the abnormal class
            y_abnormal = tf.reshape(tf.slice(y, [0, 1], [dim, 1]), [-1])
            abn = tf.mul(y_abnormal, abnormal_idxs)
            sensitivity = tf.reduce_sum(abn) / tf.reduce_sum(y_abnormal)
            tf.scalar_summary('sensitivity', sensitivity)

        with tf.name_scope('specificity'):
            # specificity = correctly predicted normal / total actual normal
            normal_idxs = tf.cast(tf.equal(tf.argmax(pred, 1), 0), tf.float32)
            # Column 0 of the one-hot labels marks the normal class
            y_normal = tf.reshape(tf.slice(y, [0, 0], [dim, 1]), [-1])
            normal = tf.mul(y_normal, normal_idxs)
            specificity = tf.reduce_sum(normal) / tf.reduce_sum(y_normal)
            tf.scalar_summary('specificity', specificity)

        # The PhysioNet score is the mean of sensitivity and specificity
        score = (sensitivity + specificity) / 2.0
        tf.scalar_summary('score', score)
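        # Worked micro-example (illustrative, not part of the graph): with
        # one-hot labels y = [[1, 0], [0, 1]] (one normal, one abnormal) and
        # predictions argmax(pred, 1) = [0, 1], sensitivity and specificity
        # are both 1/1, so score = 1.0; flipping both predictions gives 0.0.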
        init = tf.initialize_all_variables()

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)

            merged = tf.merge_all_summaries()
            train_writer = tf.train.SummaryWriter(os.path.join(self.base_dir, 'train'), sess.graph)

            for epoch in range(self.epochs):
                avg_cost = 0
                for batch in range(self.nbatches):
                    batch_x, batch_y = self.pcg.get_mini_batch(self.batch_size)
                    summary, _, c = sess.run([merged, optimizer, cost],
                                             feed_dict={X: batch_x,
                                                        y: batch_y,
                                                        do_drop: self.dropout})
                    # Use a monotonically increasing global step so summaries
                    # from later epochs do not overwrite earlier ones
                    train_writer.add_summary(summary, epoch * self.nbatches + batch)
                    avg_cost += c
                avg_cost /= float(self.nbatches)
                print('Epoch %s\tcost %s' % (epoch, avg_cost))

                if epoch % 10 == 0:
                    acc, sens, spec = sess.run([score, sensitivity, specificity],
                                               feed_dict={X: self.pcg.test.X,
                                                          y: self.pcg.test.y,
                                                          do_drop: 1.})
                    print('Score %s\tSensitivity %s\tSpecificity %s' % (acc, sens, spec))

                    saver.save(sess, self.__get_output_name())
                    print('Epoch written')

    def __get_output_name(self):
        now = datetime.now()
        time_str = "-%s" % now.date()
        model_path = os.path.join(self.base_dir, self.model_name + time_str + '.tnfl')
        return model_path

    def conv2d(self, x, w, b, strides=1):
        """
        A small helper function for calculating a 1D convolution
        from TensorFlow's conv2d method.

        Parameters
        ----------
        x: tensorflow.placeholder
            The feature vector
        w: tensorflow.Variable
            The unknown weights to learn
        b: tensorflow.Variable
            The unknown biases to learn
        strides: int
            The length of the stride to use for convolution

        Returns
        -------
        tensorflow.Variable
            A ReLU-activated convolution over the input feature vector

        """
        x = tf.nn.conv2d(x, w, strides=[1, strides, strides, 1], padding="SAME")
        x = tf.nn.bias_add(x, b)
        return tf.nn.relu(x)
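    # Shape walk-through for model1D below (assuming the default FFT
    # features, so d_input = 600):
    #   reshape:                [batch, 600, 1, 1]
    #   conv1 + max pool (2):   [batch, 300, 1, 32]
    #   conv2 + max pool (2):   [batch, 150, 1, 64]
    #   flatten -> fc1 -> out:  [batch, 150 * 64] -> [batch, 1024] -> [batch, 2]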
    def model1D(self, x, weights, biases, dropout):
        """
        A wrapper to chain several TensorFlow convolutional units together.
        This 1D model ultimately calls TensorFlow's conv2d, mapping a 1D
        feature vector to a collapsed 2D convolution.

        Parameters
        ----------
        x: tensorflow.placeholder
            A feature vector of size [None, no_features]

        weights: dict
            Dictionary of unknown weights to learn

        biases: dict
            Dictionary of unknown biases to learn

        dropout: float
            The keep probability used for dropout on the fully connected unit

        Returns
        -------
        out: tensorflow.Variable
            The result of applying multiple convolutional layers and
            a fully connected unit to the input feature vector

        """
        with tf.name_scope('reshape'):
            x = tf.reshape(x, shape=[-1, self.d_input, 1, 1])  # [n_images, width, height, n_channels]

        with tf.name_scope('conv1'):
            # conv2d already applies the ReLU activation
            conv1 = self.conv2d(x, weights['wc1'], biases['bc1'])
            conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 1, 1], strides=[1, 2, 1, 1], padding='SAME')

        with tf.name_scope('conv2'):
            conv2 = self.conv2d(conv1, weights['wc2'], biases['bc2'])
            conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 1, 1], strides=[1, 2, 1, 1], padding="SAME")

        with tf.name_scope('fullyConnected'):
            d_layer1 = weights['wd1'].get_shape().as_list()[0]
            fc1 = tf.reshape(conv2, [-1, d_layer1])
            fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
            fc1 = tf.nn.relu(fc1)
            fc1 = tf.nn.dropout(fc1, dropout)

        out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
        return out
--------------------------------------------------------------------------------
/deepheart/parser.py:
--------------------------------------------------------------------------------
import os
import pickle
import numpy as np
from scipy.io import wavfile
from scipy.fftpack import fft
from scipy.signal import butter, lfilter
from sklearn.preprocessing import normalize
from sklearn.cross_validation import train_test_split
from sklearn.utils import check_random_state
from collections import namedtuple


class PCG:
    """
    PCG is a container for loading phonocardiogram (PCG) data for the
    [2016 PhysioNet challenge](http://physionet.org/challenge/2016). Raw wav
    files are parsed into features, class labels are extracted from header
    files and the data is split into training and testing groups.
    """
    def __init__(self, basepath, random_state=42):
        self.basepath = basepath
        self.class_name_to_id = {"normal": 0, "abnormal": 1}
        self.nclasses = len(self.class_name_to_id.keys())

        self.train = None
        self.test = None

        self.n_samples = 0

        self.X = None
        self.y = None

        self.random_state = random_state

    def initialize_wav_data(self):
        """
        Load the original wav files and extract features. Warning: this can
        take a while due to slow FFTs.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        self.__load_wav_file()
        self.__split_train_test()
        # TODO: check if the directory exists
        self.save("/tmp")

    def save(self, save_path):
        """
        Persist the PCG class to disk

        Parameters
        ----------
        save_path: str
            Location on disk to store the parsed PCG metadata

        Returns
        -------
        None

        """
        np.save(os.path.join(save_path, "X.npy"), self.X)
        np.save(os.path.join(save_path, "y.npy"), self.y)
        with open(os.path.join(save_path, "meta"), "wb") as fout:
            pickle.dump((self.basepath, self.class_name_to_id, self.nclasses,
                         self.n_samples, self.random_state), fout)

    def load(self, load_path):
        """
        Load a previously stored PCG class.

        Parameters
        ----------
        load_path: str
            Location on disk to load parsed PCG data from

        Returns
        -------
        None

        """
        self.X = np.load(os.path.join(load_path, "X.npy"))
        self.y = np.load(os.path.join(load_path, "y.npy"))
        with open(os.path.join(load_path, "meta"), "rb") as fin:
            (self.basepath, self.class_name_to_id, self.nclasses,
             self.n_samples, self.random_state) = pickle.load(fin)
        self.__split_train_test()

    def __load_wav_file(self, doFFT=True):
        """
        Loads the physio 2016 challenge dataset from self.basepath by
        crawling the path. For each discovered wav file:

        * Attempt to parse the header file for the class label
        * Attempt to load the wav file
        * Calculate features from the wav file. If doFFT, features are
          the Fourier transform of the original signal. Else, features are
          the raw signal itself, truncated to a fixed length.

        Parameters
        ----------
        doFFT: bool
            True if the features to calculate are the FFT of the original
            signal

        Returns
        -------
        None
        """

        # First pass: count the samples and ensure each wav file has an
        # associated and parsable header file
        wav_file_names = []
        class_labels = []
        for root, dirs, files in os.walk(self.basepath):
            # Ignore the validation set for now!
            if "validation" in root:
                continue
            for file in files:
                if file.endswith('.wav'):
                    try:
                        # str.rstrip strips a character set, not a suffix,
                        # so slice the extension off instead
                        base_file_name = file[:-len(".wav")]
                        label_file_name = os.path.join(root, base_file_name + ".hea")

                        class_label = self.__parse_class_label(label_file_name)
                        class_labels.append(self.class_name_to_id[class_label])
                        wav_file_names.append(os.path.join(root, file))

                        self.n_samples += 1
                    except InvalidHeaderFileException as e:
                        print(e)

        if doFFT:
            fft_embedding_size = 400
            lowpass_embedding_size = 200
            X = np.zeros([self.n_samples, fft_embedding_size + lowpass_embedding_size])
        else:
            # Truncate each wav file to the minimum file length (10611).
            # (Note: this is bad and causes loss of information!)
            embedding_size = 10611
            X = np.zeros([self.n_samples, embedding_size])
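        # The default FFT feature length (400 + 200 = 600) must stay
        # divisible by 4: model.py sizes its fully connected layer as
        # int(d_input / 4) * 64 after two ksize-2 max pools. Note that the
        # raw-signal path (doFFT=False, length 10611) does not satisfy this.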
        for idx, wavfname in enumerate(wav_file_names):
            rate, wf = wavfile.read(wavfname)
            wf = normalize(wf.reshape(1, -1))

            if doFFT:
                # We only care about the magnitude of each frequency
                wf_fft = np.abs(fft(wf))
                wf_fft = wf_fft[:, :fft_embedding_size].reshape(-1)

                # Remove high frequencies with a low-pass Butterworth filter.
                # The human heart maxes out around 150 bpm = 2.5 Hz, so
                # filter out any frequency significantly above this
                nyquist = 0.5 * rate
                cutoff_freq = 4.0  # Hz
                b, a = butter(5, cutoff_freq / nyquist, btype='low', analog=False)
                wf_low_pass = lfilter(b, a, wf)

                # FFT the filtered signal
                wf_low_pass_fft = np.abs(fft(wf_low_pass))
                wf_low_pass_fft = wf_low_pass_fft[:, :lowpass_embedding_size].reshape(-1)

                features = np.concatenate((wf_fft, wf_low_pass_fft))
            else:
                # wf has shape (1, n) after normalize, so flatten before
                # truncating to the embedding size
                features = wf.reshape(-1)[:embedding_size]

            X[idx, :] = features

        self.X = X

        class_labels = np.array(class_labels)

        # Map from dense class labels to one-hot rows
        self.y = np.eye(self.nclasses)[class_labels]
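        # e.g. class_labels [0, 1, 1] (normal, abnormal, abnormal) becomes
        # [[1, 0], [0, 1], [0, 1]]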
    def __parse_class_label(self, label_file_name):
        """
        Parses PhysioBank header files, where the class label is located in
        the last line of the file. An example header file could contain:

            f0112 1 2000 60864
            f0112.wav 16+44 1 16 0 0 0 0 PCG
            # Normal

        Parameters
        ----------
        label_file_name: str
            Path to a specific header file

        Returns
        -------
        class_label: str
            One of `normal` or `abnormal`
        """
        with open(label_file_name, 'r') as fin:
            header = fin.readlines()

        comments = [line for line in header if line.startswith("#")]
        if len(comments) != 1:
            raise InvalidHeaderFileException("Invalid label file %s" % label_file_name)

        class_label = str(comments[0]).lstrip("#").strip().lower()

        if class_label not in self.class_name_to_id:
            raise InvalidHeaderFileException("Invalid class label %s" % class_label)

        return class_label

    def __split_train_test(self):
        """
        Splits the internal features (self.X) and class labels (self.y) into
        balanced training and test sets using sklearn's helper function.

        Notes:
        * If self.random_state is None, the split is randomly seeded;
          otherwise self.random_state defines the random seed used to
          deterministically split the training and test data
        * For now, class balancing is done by subsampling the
          overrepresented class. Ideally this would be pushed down to the
          cost function in TensorFlow.

        Returns
        -------
        None
        """
        mlData = namedtuple('ml_data', 'X y')

        num_normal, num_abnormal = np.sum(self.y, axis=0)

        # Rebalance the classes: keep every abnormal recording and keep each
        # normal recording with probability num_abnormal / num_normal
        # TODO: push this down into the cost function
        undersample_rate = num_abnormal / num_normal
        over_represented_idxs = self.y[:, 1] == 0
        under_represented_idxs = self.y[:, 1] == 1
        random_indexes_to_remove = np.random.rand(self.y.shape[0]) < undersample_rate
        sample_idxs = (over_represented_idxs & random_indexes_to_remove |
                       under_represented_idxs)

        X_balanced = self.X[sample_idxs, :]
        y_balanced = self.y[sample_idxs, :]

        X_train, X_test, y_train, y_test = train_test_split(X_balanced, y_balanced, test_size=0.25,
                                                            random_state=self.random_state)

        self.train = mlData(X=X_train, y=y_train)
        self.test = mlData(X=X_test, y=y_test)

    def get_mini_batch(self, batch_size):
        """
        Helper function for sampling mini-batches from the training set.
        Note: the random state here needs to be None, or the same mini-batch
        will be sampled eternally!

        Parameters
        ----------
        batch_size: int
            Number of elements to return in the mini-batch

        Returns
        -------
        X: np.ndarray
            A feature matrix subsampled from self.train

        y: np.ndarray
            A one-hot matrix of class labels subsampled from self.train
        """
        random_state = check_random_state(None)  # deliberately not self.random_state
        n_training_samples = self.train.X.shape[0]
        # randint's upper bound is exclusive, so every training sample is reachable
        minibatch_indices = random_state.randint(0, n_training_samples, batch_size)

        return self.train.X[minibatch_indices, :], self.train.y[minibatch_indices, :]


class InvalidHeaderFileException(Exception):
    def __init__(self, *args, **kwargs):
        super(InvalidHeaderFileException, self).__init__(*args, **kwargs)
--------------------------------------------------------------------------------