├── README.md ├── machine_learning ├── data_mungers.py ├── neural_networks │ ├── __init__.py │ ├── basic_components.py │ ├── sequence_networks.py │ ├── tf_helpers.py │ └── torch_sequence_networks.py └── torch_helpers.py └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | # machine_learning 2 | a collection of packages for ML projects, written in Tensorflow's Python API 3 | 4 | To implement the neural speech decoders described in ["Machine translation of cortical activity to text with an encoder-decoder framework" (_Nature Neuroscience_, 2020)](https://www.nature.com/articles/s41593-020-0608-8), you probably also want to install the higher-level package, [`ecog2txt`](https://github.com/jgmakin/ecog2txt). 5 | 6 | ## Installation 7 | ``` 8 | pip install tensorflow-gpu 9 | git clone https://github.com/jgmakin/machine_learning.git 10 | pip install -e machine_learning 11 | ``` 12 | -------------------------------------------------------------------------------- /machine_learning/data_mungers.py: -------------------------------------------------------------------------------- 1 | # standard libraries 2 | import os 3 | import pdb 4 | import random 5 | from functools import reduce 6 | 7 | # third-party 8 | import torch 9 | from torch.utils.data import Dataset 10 | import torchvision 11 | import numpy as np 12 | from PIL import Image 13 | import tensorflow as tf 14 | import tensorflow_datasets as tfds 15 | 16 | # local 17 | from utils_jgm.machine_compatibility_utils import MachineCompatibilityUtils 18 | MCUs = MachineCompatibilityUtils() 19 | 20 | 21 | # define a new class to work around path issues 22 | class DualMNIST(torchvision.datasets.MNIST): 23 | def __init__(self, root=os.path.join(MCUs.get_path('data')), *args, **kwargs): 24 | super().__init__(root=root, *args, **kwargs) 25 | 26 | # ... 27 | N_classes = len(np.unique(self.targets)) 28 | 29 | # the indices sorted by label 30 | self.index_vector = torch.argsort(self.targets) 31 | self.nums_instances = np.sum( 32 | self.targets.numpy() == np.arange(N_classes).reshape([-1, 1]), 1 33 | ) 34 | 35 | # cumulative number of instances *not* including current class 36 | cums_instances = np.cumsum(self.nums_instances) 37 | self.cums_instances = np.append(0, cums_instances[:-1]) 38 | 39 | @property 40 | def raw_folder(self) -> str: 41 | # return os.path.join(self.root, self.__class__.__name__, 'raw') 42 | return os.path.join(self.root, 'MNIST', 'raw') 43 | 44 | def __getitem__(self, index: int) -> tuple[any, any]: 45 | """ 46 | Args: 47 | index (int): Index 48 | 49 | Returns: 50 | tuple: (image_a, image_b, target) where target is index of the 51 | target class. 52 | """ 53 | 54 | # Interpret index argument as index into the *sorted* examples 55 | # NB: By default, this will produce sorted examples. 
You need to use 56 | # shuffle=True during training 57 | iExample = index 58 | 59 | # convert to index of *unsorted* data and get the class id 60 | index_i = self.index_vector[iExample] 61 | class_id = self.targets[index_i] 62 | 63 | # useful integers 64 | N_instances = self.nums_instances[class_id] 65 | N_prev_class_examples = self.cums_instances[class_id] 66 | iInstance = iExample - N_prev_class_examples 67 | 68 | # jInstance = index wrt the first index of this class 69 | # jExample = index wrt index 0 (of the sorted data) 70 | jInstance = (random.randrange(1, N_instances) + iInstance) % N_instances 71 | jExample = jInstance + N_prev_class_examples 72 | 73 | # convert to index of *unsorted* data 74 | index_j = self.index_vector[jExample] 75 | 76 | # apply any 77 | if self.target_transform is not None: 78 | class_id = self.target_transform(class_id) 79 | 80 | image_i = self._image_proc(index_i) 81 | image_j = self._image_proc(index_j) 82 | 83 | return torch.cat((image_i, image_j), axis=0), class_id 84 | 85 | def _image_proc(self, index): 86 | image = self.data[index] 87 | 88 | # doing this so that it is consistent with all other datasets 89 | # to return a PIL Image 90 | image = Image.fromarray(image.numpy(), mode="L") 91 | 92 | if self.transform is not None: 93 | image = self.transform(image) 94 | 95 | return image 96 | 97 | 98 | # define a new class to work around path issues 99 | class SplitMNIST(torchvision.datasets.MNIST): 100 | def __init__(self, root=os.path.join(MCUs.get_path('data')), *args, **kwargs): 101 | super().__init__(root=root, *args, **kwargs) 102 | 103 | @property 104 | def raw_folder(self) -> str: 105 | # return os.path.join(self.root, self.__class__.__name__, 'raw') 106 | return os.path.join(self.root, 'MNIST', 'raw') 107 | 108 | def __getitem__(self, index: int) -> tuple[any, any]: 109 | """ 110 | Args: 111 | index (int): Index 112 | 113 | Returns: 114 | tuple: (image, target) where target is index of the target class 115 | and image is (2 x image_height/2 x image_width) 116 | """ 117 | 118 | # apply any 119 | class_id = self.targets[index] 120 | if self.target_transform is not None: 121 | class_id = self.target_transform(class_id) 122 | 123 | image = self._image_proc(index) 124 | 125 | # break into top and bottom 126 | # image = image.reshape(2, -1, image.shape[-1]) 127 | 128 | # break into left and right 129 | image = image.swapaxes(1, 2) 130 | image = image.reshape(2, -1, image.shape[-1]) 131 | image = image.swapaxes(1, 2) 132 | 133 | return image, class_id 134 | 135 | def _image_proc(self, index): 136 | image = self.data[index] 137 | 138 | # doing this so that it is consistent with all other datasets 139 | # to return a PIL Image 140 | image = Image.fromarray(image.numpy(), mode="L") 141 | 142 | if self.transform is not None: 143 | image = self.transform(image) 144 | 145 | return image 146 | 147 | 148 | class TFRecordDataLoader: 149 | def __init__( 150 | self, subnets_params, data_partition, N_cases, OOV_token, 151 | TARGETS_ARE_SEQUENCES=True, 152 | ): 153 | 154 | # don't let TF allocate the GPU to itself 155 | tf.config.set_visible_devices([], 'GPU') 156 | ds = self._tf_records_to_dataset( 157 | subnets_params, data_partition, N_cases, OOV_token, 158 | TARGETS_ARE_SEQUENCES, 159 | # num_shards_to_discard=0, DROP_REMAINDER=False 160 | ) 161 | N_batches = 0 162 | for batch in ds: 163 | N_batches += 1 164 | self.N_batches = N_batches 165 | self.ds = tfds.as_numpy(ds) 166 | # self.N_cases = N_cases 167 | self._iterator = None 168 | 169 | def __iter__(self): 170 | 
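        # NB on the iteration protocol below: the first call to __iter__ builds
        # an iterator over self.ds; every subsequent call rebuilds it, so each
        # epoch restarts from the first batch.
        #
        # Hedged usage sketch (the argument values shown are illustrative):
        #
        #   loader = TFRecordDataLoader(
        #       subnets_params, 'training', N_cases=128, OOV_token='<OOV>'
        #   )
        #   for epoch in range(N_epochs):
        #       for batch in loader:    # dict of numpy arrays (tfds.as_numpy)
        #           ...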
if self._iterator is None: 171 | self._iterator = iter(self.ds) 172 | else: 173 | self._reset() 174 | return self._iterator 175 | 176 | def _reset(self): 177 | self._iterator = iter(self.ds) 178 | 179 | def __next__(self): 180 | batch = next(self._iterator) 181 | return batch 182 | 183 | def __len__(self): 184 | return self. N_batches 185 | 186 | def _tf_records_to_dataset( 187 | self, subnets_params, data_partition, num_cases, OOV_token, 188 | TARGETS_ARE_SEQUENCES, num_shards_to_discard=0, DROP_REMAINDER=False, 189 | ): 190 | ''' 191 | Load, shuffle, batch and pad, and concatentate across subnets (for 192 | parallel transfer learning) all the data. 193 | ''' 194 | 195 | # accumulate datasets, one for each subnetwork 196 | dataset_list = [] 197 | for subnet_params in subnets_params: 198 | dataset = tf.data.TFRecordDataset([ 199 | subnet_params.tf_record_partial_path.format(block_id) 200 | for block_id in subnet_params.block_ids[data_partition]] 201 | ) 202 | dataset = dataset.map( 203 | lambda example_proto: _parse_protobuf_seq2seq_example( 204 | example_proto, subnet_params.data_manifests 205 | ), 206 | num_parallel_calls=tf.data.AUTOTUNE 207 | ) 208 | 209 | # filter data to include or exclude only specified decoder targets? 210 | decoder_targets_list = subnet_params.data_manifests[ 211 | 'decoder_targets'].get_feature_list() 212 | target_filter = TargetFilter( 213 | decoder_targets_list, subnet_params.target_specs, 214 | data_partition 215 | ) 216 | dataset = target_filter.filter_dataset(dataset) 217 | 218 | # filter out words not in the decoder_targets_list 219 | if not TARGETS_ARE_SEQUENCES: 220 | # ...then get rid of OOV examples 221 | OOV_id = ( 222 | decoder_targets_list.index(OOV_token) 223 | if OOV_token in decoder_targets_list else -1 224 | ) 225 | # NB that x['decoder_targets'].shape = [None, 1] 226 | dataset = dataset.filter( 227 | lambda x: tf.not_equal(x['decoder_targets'][0, 0], OOV_id) 228 | ) 229 | 230 | # discard some of the data?; shuffle; batch (evening out w/padding) 231 | if num_shards_to_discard > 0: 232 | dataset = dataset.shard(num_shards_to_discard+1, 0) 233 | dataset = dataset.shuffle(buffer_size=35000) # > greatest 234 | dataset = dataset.padded_batch( 235 | num_cases, 236 | padded_shapes=tf.compat.v1.data.get_output_shapes(dataset), 237 | padding_values={ 238 | key: data_manifest.padding_value 239 | for key, data_manifest in subnet_params.data_manifests.items() 240 | }, 241 | drop_remainder=DROP_REMAINDER 242 | ) 243 | 244 | # add id for "proprietary" parts of network under transfer learning 245 | dataset = dataset.map( 246 | lambda batch_of_protos_dict: { 247 | **batch_of_protos_dict, 'subnet_id': tf.constant( 248 | str(subnet_params.subnet_id), dtype=tf.string 249 | ) 250 | } 251 | ) 252 | dataset_list.append(dataset) 253 | 254 | # (randomly) interleave (sub-)batches w/o throwing anything away 255 | dataset = reduce( 256 | lambda set_a, set_b: set_a.concatenate(set_b), dataset_list 257 | ) 258 | dataset = dataset.shuffle(buffer_size=3000) 259 | ###### 260 | # Since your parse_protobuf_seq2seq_example isn't doing much, the 261 | # overhead associated with just scheduling the dataset.map will 262 | # dominate the cost of applying it. Therefore, tensorflow 263 | # recommends batching first, and applying a vectorized version of 264 | # parse_protobuf_seq2seq_example. But you shuffle first..... 
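        # A hedged sketch of that recommendation (not what is done here),
        # assuming the per-example parser were replaced with a batched
        # counterpart built on tf.io.parse_example:
        #
        #   dataset = dataset.batch(num_cases, drop_remainder=DROP_REMAINDER)
        #   dataset = dataset.map(
        #       vectorized_parser,  # hypothetical batched version of the parser
        #       num_parallel_calls=tf.data.AUTOTUNE
        #   )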
265 | ###### 266 | dataset = dataset.prefetch(tf.data.AUTOTUNE) #num_cases) 267 | 268 | return dataset 269 | 270 | 271 | def _parse_protobuf_seq2seq_example(example_proto, data_manifests): 272 | 273 | # parse the features using the data_descriptions and prepare the outputs 274 | feature_dict = { 275 | data_manifest.sequence_type: data_manifest.feature_value 276 | for data_manifest in data_manifests.values() 277 | } 278 | parsed_features = tf.io.parse_single_example( 279 | serialized=example_proto, features=feature_dict) 280 | example_dict = dict.fromkeys(data_manifests.keys()) 281 | 282 | # for each data_manifest (the number is indeterminate)... 283 | for key, data_manifest in data_manifests.items(): 284 | # ..."unflatten" the sequence of (possibly length-1) vectors and xform 285 | sequence_matrix = tf.reshape( 286 | parsed_features[data_manifest.sequence_type].values, 287 | (-1, data_manifest.num_features_raw) 288 | ) 289 | example_dict[key] = data_manifest.transform(sequence_matrix) 290 | 291 | return example_dict 292 | 293 | 294 | class TargetFilter: 295 | def __init__(self, unique_targets, target_specs, this_data_type): 296 | 297 | ''' 298 | # Example: 299 | target_specs = { 300 | 'validation': [ 301 | ['this', 'was', 'easy', 'for', 'us'], 302 | ['they', 'often', 'go', 'out', 'in', 'the', 'evening'], 303 | ['i', 'honour', 'my', 'mum'], 304 | ['a', 'doctor', 'was', 'in', 'the', 'ambulance', 'with', 'the', 'patient'], 305 | ['we', 'are', 'open', 'every', 'monday', 'evening'], 306 | ['withdraw', 'only', 'as', 'much', 'money', 'as', 'you', 'need'], 307 | ['allow', 'each', 'child', 'to', 'have', 'an', 'ice', 'pop'], 308 | ['is', 'she', 'going', 'with', 'you'] 309 | ] 310 | } 311 | ''' 312 | 313 | # fixed 314 | data_types = {'training', 'validation'} 315 | 316 | # convert target_specs dictionary entries from word- to index-based 317 | # NB: PROBABLY NOT GENERAL ENOUGH to work w/non-word_sequence data 318 | self.target_specs = {key: [ 319 | [unique_targets.index(w + '_') for w in target] + [1] 320 | for target in target_spec] for key, target_spec in target_specs.items() 321 | } 322 | 323 | # store for later use 324 | self.this_data_type = this_data_type 325 | self.other_data_type = (data_types - {this_data_type}).pop() 326 | 327 | def _test_special(self, fetch_target_indices, data_type): 328 | # Test if this tf_record target is among this dataset's target_specs. 329 | # NB that this function returns a (boolean) tf.tensor. 
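        # How the test below works: fetch_target_indices has shape [T, 1] (see
        # the note on decoder_targets above) and np.array(target_indices,
        # ndmin=2) has shape [1, K], so tf.equal broadcasts to a [T, K] matrix;
        # tf.linalg.diag_part then compares position i of the example against
        # position i of the special target (up to the shorter length), and
        # reduce_all is True only if every such position matches.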
330 | TEST_SPECIAL = tf.constant(False) 331 | for target_indices in self.target_specs[data_type]: 332 | TEST_MATCH = tf.reduce_all( 333 | tf.linalg.diag_part(tf.equal( 334 | fetch_target_indices, 335 | np.array(target_indices, ndmin=2)) 336 | )) 337 | TEST_SPECIAL = tf.logical_or(TEST_SPECIAL, TEST_MATCH) 338 | return TEST_SPECIAL 339 | 340 | def filter_dataset(self, dataset): 341 | if self.this_data_type in self.target_specs: 342 | return dataset.filter( 343 | lambda example_dict: self._test_special( 344 | example_dict['decoder_targets'], self.this_data_type 345 | )) 346 | elif self.other_data_type in self.target_specs: 347 | return dataset.filter( 348 | lambda example_dict: self._test_special( 349 | example_dict['decoder_targets'], self.other_data_type 350 | )) 351 | else: 352 | return dataset 353 | 354 | 355 | ##### 356 | # DEPRECATED: too slow 357 | # Importing tfrecords into pytorch 358 | ##### 359 | # class TFRecordPipe: 360 | # from machine_learning.torch_helpers import parse_protobuf_seq2seq_example 361 | # from torchdata.datapipes.iter import FileLister, FileOpener 362 | # from tfrecord.torch.dataset import MultiTFRecordDataset 363 | 364 | # def __init__( 365 | # self, 366 | # ##### 367 | # # for now, just one 368 | # subject 369 | # ##### 370 | # ): 371 | # self.data_manifests = subject.data_manifests 372 | # self.partial_path = subject.data_generator.tf_record_partial_path 373 | # self.block_ids = subject.block_ids 374 | 375 | # def parse_protobuf(self, example_proto): 376 | # ''' 377 | # Resizes data and converts words to indices. NB that all matrices in 378 | # the example_dict have size [T x N_features] 379 | # ''' 380 | 381 | # example_dict = TFRecordPipe.parse_protobuf_seq2seq_example( 382 | # example_proto, self.data_manifests, 383 | # ) 384 | 385 | # return example_dict 386 | 387 | # def pad_collate(self, batch): 388 | # ''' 389 | # Transforms 390 | # list of dictionaries of variable-length sequences 391 | # into 392 | # dictionary of tensors 393 | # (padded to account for variable lengths) 394 | 395 | # Each example in the batch is a dict. Each value in the dict has size 396 | # (T_i x N_features). 397 | # where T_i is the length of that particular example. The elements of 398 | # the output, batch_dict, have size 399 | # (N_cases x T x N_features), 400 | # where T is the length of the longest sequence in the batch. 401 | # ''' 402 | 403 | # batch_dict = { 404 | # key: torch.nn.utils.rnn.pad_sequence( 405 | # # [torch.tensor(example[key]) for example in batch], 406 | # [example[key] for example in batch], 407 | # batch_first=True, 408 | # padding_value=self.data_manifests[key].padding_value 409 | # ) for key in self.data_manifests.keys() 410 | # } 411 | 412 | # return batch_dict 413 | 414 | # def construct_pipe(self): 415 | # # vahidk or pytorch version? Both very slow 416 | # return self.construct_pipe_v() 417 | # return self.construct_pipe_t() 418 | 419 | # def construct_pipe_v(self): 420 | # # ... 421 | # # index_pattern = self.partial_path.replace('.tfrecord', '.tfindex') 422 | # # description = { 423 | # # 'ecog_sequence': 'float', 424 | # # 'phoneme_sequence': 'byte', 425 | # # 'text_sequence': 'byte', 426 | # # "audio_sequence": 'float', 427 | # # } 428 | 429 | # # unnormalized probabilities 430 | # splits = {block: 1.0 for block in self.block_ids['training']} 431 | 432 | # # ... 
433 | # dataset = MultiTFRecordDataset( 434 | # self.partial_path, 435 | # index_pattern=None, 436 | # splits=splits, 437 | # description=None, 438 | # infinite=False, 439 | # transform=self.parse_protobuf, 440 | # shuffle_queue_size=512, 441 | # ) 442 | # return dataset 443 | 444 | # def construct_pipe_t(self): 445 | # tf_record_dir, tf_record_name = os.path.split(self.partial_path) 446 | # datapipe = FileLister(tf_record_dir, tf_record_name.format('*')) 447 | # datapipe = FileOpener(datapipe, mode="b") 448 | # datapipe = datapipe.load_from_tfrecord() 449 | # datapipe = datapipe.shuffle() 450 | # datapipe = datapipe.sharding_filter() 451 | # datapipe = datapipe.map(self.parse_protobuf) 452 | 453 | # # # train, valid = tfrecord_datapipe.random_split( 454 | # # # # total_length=10, 455 | # # # weights={"train": 0.8, "valid": 0.2}, seed=0 456 | # # # ) 457 | 458 | # # return datapipe 459 | -------------------------------------------------------------------------------- /machine_learning/neural_networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Created: 07/05/17 2 | # by JGM 3 | -------------------------------------------------------------------------------- /machine_learning/neural_networks/basic_components.py: -------------------------------------------------------------------------------- 1 | # standard libraries 2 | import pdb 3 | 4 | # third-party packages 5 | import tensorflow as tf 6 | try: 7 | from tensor2tensor.layers import common_layers 8 | except ModuleNotFoundError: 9 | print('WARNING: tensor2tensor missing; skipping') 10 | 11 | # local 12 | from . import tf_helpers as tfh 13 | 14 | ''' 15 | A collection of methods for neural networks built in tensorflow. 16 | 17 | :Author: J.G. Makin (except where otherwise noted) 18 | 19 | Created: June 2017 20 | ''' 21 | 22 | 23 | def bias_decorator(preactivation_fxn): 24 | def bias_wrapper(inputs, Nin, Nout, **kwargs): 25 | USE_BIASES = kwargs.get('USE_BIASES', True) 26 | stiffness = kwargs.get('stiffness', 0) 27 | preactivation = preactivation_fxn(inputs, Nin, Nout, **kwargs) 28 | if USE_BIASES: 29 | biases = create_biases([Nout], stiffness=stiffness) 30 | preactivation += biases 31 | return preactivation 32 | 33 | return bias_wrapper 34 | 35 | 36 | @bias_decorator 37 | def tf_matmul_wrapper( 38 | inputs, Nin, Nout, stiffness=0, transpose_b=False, num_shards=None, 39 | USE_BIASES=True): 40 | wts_shape = (Nout, Nin) if transpose_b else (Nin, Nout) 41 | wts = create_weights(wts_shape, stiffness=stiffness, num_shards=num_shards) 42 | 43 | # This is something of a hack: When the inputs are integers, interpret them 44 | # as indices--in lieu of a one-hot representation--and therefore tf.gather 45 | # to extract the indexed column, rather than tf.matmul. 46 | # You would prefer to test sizes, but this is impossible: You need to hide 47 | # the final dimension (num_features) from Tensorflow to keep it from 48 | # noticing that different subjects have different num_features (which is 49 | # fine, because they're separated by tf.case, but TF doesn't get this). 
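    # Hedged illustration of the equivalence being exploited below: with
    # transpose_b=False and integer `inputs` of shape [N, 1] holding class
    # indices,
    #   tf.gather(wts, tf.reshape(inputs, [-1]))
    # returns the same [N, Nout] result as
    #   tf.matmul(tf.one_hot(tf.reshape(inputs, [-1]), Nin), wts),
    # i.e. an embedding lookup in place of multiplication by a one-hot matrix.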
50 | #### if (common_layers.shape_list(inputs)[-1] == 1) and (Nin != 1): 51 | if inputs.dtype is tf.int32: 52 | return tf.gather(wts, tf.reshape(inputs, [-1])) 53 | else: 54 | return tf.matmul(inputs, wts, transpose_b=transpose_b) 55 | 56 | 57 | @bias_decorator 58 | def tf_conv2d_wrapper( 59 | inputs, Nin, Nout, name, stiffness=0, filter_height=1, filter_width=1, 60 | strides=[1, 1, 1, 1], num_shards=None, USE_BIASES=True): 61 | wts_shape = [filter_height, filter_width, Nin, Nout] 62 | wts = create_weights(wts_shape, stiffness=stiffness, num_shards=num_shards) 63 | preactivations = tf.nn.conv2d( 64 | input=inputs, filters=wts, strides=strides, padding='VALID' 65 | ) 66 | ### provide the name to conv2d?? otherwise eliminate.... 67 | return preactivations 68 | 69 | 70 | def tf_max_pool_wrapper(inputs, name, ksize, strides): 71 | return tf.nn.max_pool2d( 72 | input=inputs, name=name, ksize=ksize, strides=strides, padding='VALID') 73 | 74 | 75 | def tf_avg_pool_wrapper(inputs, name, ksize, strides): 76 | return tf.nn.avg_pool2d( 77 | value=inputs, name=name, ksize=ksize, strides=strides, padding='VALID') 78 | 79 | 80 | def feed_forward_multi_layer( 81 | get_activations, Ninputs, layer_sizes, dropout_rate, net_name, 82 | preactivation_fxns=None, activation_fxns=None): 83 | 84 | Nlayers = len(layer_sizes) 85 | 86 | # fill in in default lists 87 | if not preactivation_fxns: 88 | preactivation_fxns = [tf_matmul_wrapper]*Nlayers 89 | if not activation_fxns: 90 | activation_fxns = [tf.nn.relu]*Nlayers 91 | 92 | for iLayer, (Noutputs, preactivation_fxn, activation_fxn) in enumerate(zip( 93 | layer_sizes, preactivation_fxns, activation_fxns)): 94 | 95 | # for consistency w/t2t translation model, swap for output projections 96 | wts_shape = (Ninputs, Noutputs) 97 | if hasattr(preactivation_fxn, 'TRANSPOSED'): 98 | if preactivation_fxn.TRANSPOSED: 99 | wts_shape = (Noutputs, Ninputs) 100 | layer_name = '%s_%i_%i_%i' % (net_name, *wts_shape, iLayer) 101 | 102 | # now run through a single layer 103 | get_activations = feed_forward_one_layer( 104 | get_activations, layer_name, Nin=Ninputs, Nout=Noutputs, 105 | preactivation_fxn=preactivation_fxn, activation_fxn=activation_fxn) 106 | get_activations = tf.nn.dropout(get_activations, rate=dropout_rate) 107 | Ninputs = Noutputs 108 | return get_activations, Ninputs 109 | 110 | 111 | def feed_forward_one_layer( 112 | input_tensor, layer_name, Nin=1, Nout=1, 113 | preactivation_fxn=tf_matmul_wrapper, activation_fxn=tf.nn.relu 114 | ): 115 | """NB that for convnets, Nout is number of *channels*""" 116 | with tf.compat.v1.variable_scope(layer_name, reuse=tf.compat.v1.AUTO_REUSE): 117 | preactivations = preactivation_fxn(input_tensor, Nin, Nout) 118 | activations = activation_fxn(preactivations, name='activation') 119 | 120 | # for visualization in tensorboard 121 | # variable_summaries(wts, 'weights') 122 | # variable_summaries(biases, 'biases') 123 | # variable_summaries(preactivations, 'preactivations_summary') 124 | # variable_summaries(activations, 'activations_summary') 125 | return activations 126 | 127 | 128 | def create_weights(weight_shape, stiffness=0, num_shards=None): 129 | # NB that lambda initialization is required for use under 130 | # tensorflow control loops 131 | if num_shards: 132 | # ...then create a tensor for each shard and merge 133 | # Borrowed from tensor2tensor.layers.modalities.py 134 | # to facilitate restoration of tensor2tensor models: 135 | shards = [] 136 | for iShard in range(num_shards): 137 | shard_size = (weight_shape[0] // 
num_shards) + ( 138 | 1 if iShard < weight_shape[0] % num_shards else 0) 139 | var_name = "weights_%d" % iShard 140 | shards.append(tf.compat.v1.get_variable( 141 | var_name, shape=[shard_size] + weight_shape[1:], 142 | initializer=lambda shape, dtype, partition_info: 143 | tf.compat.v1.truncated_normal(shape, stddev=0.1))) 144 | weights = tf.concat(shards, 0) 145 | # ret = eu.convert_gradient_to_tensor(ret) 146 | else: 147 | weights = tf.compat.v1.get_variable( 148 | 'weights', shape=weight_shape, 149 | initializer=lambda shape, dtype, partition_info: 150 | ###tf.compat.v1.truncated_normal(shape, stddev=0.005)) 151 | tf.compat.v1.truncated_normal(shape, stddev=0.1)) 152 | #cost = tf.multiply(tf.nn.l2_loss(weights), stiffness, name='weight_loss') 153 | #tf.add_to_collection(tf.GraphKeys.LOSSES, cost) 154 | return weights 155 | 156 | 157 | def create_biases(bias_shape, stiffness=0): 158 | ''' 159 | initial_values = tf.constant(0.1, shape=bias_shape) 160 | biases = tf.get_variable('biases', initializer=initial_values) 161 | ''' 162 | biases = tf.compat.v1.get_variable( 163 | 'biases', shape=bias_shape, 164 | initializer=lambda shape, dtype, partition_info: 165 | ###tf.constant(0.0, shape=shape)) 166 | tf.constant(0.1, shape=shape)) 167 | #cost = tf.multiply(tf.nn.l2_loss(biases), stiffness, name='weight_loss') 168 | #tf.add_to_collection(tf.GraphKeys.LOSSES, cost) 169 | return biases 170 | 171 | 172 | def LSTM_rnn( 173 | batch_sequences, sequence_lengths, hidden_layer_sizes, dropout, name, 174 | initial_state=None, BIDIRECTIONAL=False 175 | ): 176 | # Borrowed from the tensor2tensor library, and modified 177 | ''' 178 | Run LSTM cell on inputs, assuming they have size 179 | [N_cases x max_sequence_length x Ninputs]. 180 | 181 | Input arguments: 182 | ------- 183 | batch_sequences: 184 | sequence_lengths: 185 | hidden_size: 186 | num_hidden_layers: 187 | dropout: 188 | name: 189 | states: 190 | 191 | Outputs: 192 | ------- 193 | lstm_outputs: 194 | lstm_final_states: 195 | 196 | 197 | Un-/poorly documented behavior: tf.nn.dynamic_rnn returns two arguments, 198 | an output and a (final) state. Providing the sequence_lengths as input 199 | will "copy-through state [the second output] and zero-out outputs [the 200 | first output] when past a batch element's sequence length" (per the tf 201 | documentation). Therefore, since we're passing the sequence_lengths, we 202 | can safely use the last element of the final_state. So far so good. 203 | 204 | However, the final_state is itself a tuple, consisting of a cell state (c) 205 | and a hidden state (h). Per some stackexchange posts, e.g. this one, 206 | 207 | https://stackoverflow.com/questions/36817596/ 208 | 209 | you use the hidden state. On the difference b/n output and state, see 210 | 211 | https://stats.stackexchange.com/questions/330176 212 | 213 | Incidentally, the hidden state, final_state.h, is itself an array, with 214 | as many elements as layers in the RNN. For decoding purposes, you will 215 | typically want to use the last, final_state.h[-1]. 
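    For illustration only (a hedged sketch, not code used in this module): the
    state returned by this function is a per-layer tuple of LSTMStateTuples,
    so the last layer's hidden state can be recovered as

        outputs, states_tuple = LSTM_rnn(
            batch_sequences, sequence_lengths, [400, 400], 0.1, 'encoder'
        )
        final_hidden = states_tuple[-1].h   # e.g. to initialize a decoder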
216 | 217 | 218 | ''' 219 | 220 | ''' 221 | with tf.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 222 | layers = [ 223 | tf.keras.layers.LSTM( 224 | hidden_size, dropout=dropout, return_sequences=True) 225 | for _ in range(num_hidden_layers) 226 | ### tf.keras.layers.CuDNNLSTM doesn't support masking 227 | ] 228 | if BIDIRECTIONAL: 229 | ### This is messed up b/c the initial state could be for multiple 230 | ### layers 231 | layers = [tf.keras.layers.Bidirectional(layer, states) 232 | for layer in layers] 233 | stacked_layers = tf.keras.layers.StackedRNNCells(layers) 234 | masked_sequences = tf.keras.layers.Masking()(batch_sequences) 235 | outputs, states_tuple = stacked_layers(masked_sequences) 236 | 237 | 238 | ''' 239 | # for brevity 240 | def variational_dropout_lstm_cell(input_size): 241 | LSTMcell = tf.compat.v1.nn.rnn_cell.LSTMCell( 242 | input_size, name='basic_lstm_cell' 243 | ) 244 | return tf.compat.v1.nn.rnn_cell.DropoutWrapper( 245 | LSTMcell, input_keep_prob=1-dropout, 246 | ) 247 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 248 | forward_layers = [ 249 | variational_dropout_lstm_cell(layer_size) 250 | for layer_size in hidden_layer_sizes 251 | ] 252 | if BIDIRECTIONAL: 253 | backward_layers = [ 254 | variational_dropout_lstm_cell(layer_size) 255 | for layer_size in hidden_layer_sizes 256 | ] 257 | 258 | # see https://stackoverflow.com/questions/49242266/ 259 | (outputs, final_state_fw, final_state_bw 260 | ) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn( 261 | forward_layers, 262 | backward_layers, 263 | batch_sequences, 264 | sequence_length=sequence_lengths, 265 | initial_states_fw=initial_state, 266 | initial_states_bw=None, 267 | dtype=tf.float32) 268 | 269 | states_tuple = tuple(tf.compat.v1.nn.rnn_cell.LSTMStateTuple( 270 | tf.concat((state_fw.c, state_bw.c), 1), 271 | tf.concat((state_fw.h, state_bw.h), 1)) 272 | for state_fw, state_bw in zip(final_state_fw, final_state_bw)) 273 | 274 | # ########### 275 | # TF 2.x STYLE 276 | # 277 | # # But see this tensorflow stupidity: 278 | # # https://github.com/tensorflow/tensorflow/issues/35654 279 | # ########### 280 | 281 | # # transform initial state into the new format 282 | # outputs = tf.keras.layers.Masking(mask_value=0.0)(batch_sequences) 283 | # if initial_state is None: 284 | # states = None 285 | # else: 286 | # states = [ 287 | # initial_state[-1].h[:, :hidden_layer_sizes[0]/2], 288 | # initial_state[-1].c[:, :hidden_layer_sizes[0]/2], 289 | # initial_state[-1].h[:, -hidden_layer_sizes[0]/2:], 290 | # initial_state[-1].c[:, -hidden_layer_sizes[0]/2:], 291 | # ] 292 | 293 | # all_states = [] 294 | # for ii, layer_size in enumerate(hidden_layer_sizes): 295 | # outputs, *states = tf.keras.layers.Bidirectional( 296 | # tf.keras.layers.RNN( 297 | # tf.keras.layers.LSTMCell( 298 | # layer_size, recurrent_dropout=dropout 299 | # ), return_state=True, return_sequences=True 300 | # ) 301 | # )(outputs, states) 302 | 303 | # # accumulate the states, converting over to the old format 304 | # all_states.append( 305 | # tf.compat.v1.nn.rnn_cell.LSTMStateTuple( 306 | # tf.concat((states[1], states[3]), 1), 307 | # tf.concat((states[0], states[2]), 1) 308 | # ) 309 | # ) 310 | 311 | # # probably unnecessary 312 | # states_tuple = tuple(all_states) 313 | 314 | else: 315 | outputs, states_tuple = tf.compat.v1.nn.dynamic_rnn( 316 | tf.compat.v1.nn.rnn_cell.MultiRNNCell(forward_layers), 317 | batch_sequences, 318 | sequence_length=sequence_lengths, 319 | initial_state=initial_state, 320 | 
dtype=tf.float32) 321 | return outputs, states_tuple 322 | 323 | 324 | def variable_summaries(var, name): 325 | ''' 326 | Attach a lot of summaries to a Tensor (for TensorBoard visualization). 327 | ''' 328 | with tf.compat.v1.name_scope('summaries'): 329 | mean = tf.reduce_mean(input_tensor=var, name=name) 330 | tf.compat.v1.summary.scalar('mean', mean) 331 | # with tf.name_scope('stddev'): 332 | # stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 333 | # tf.summary.scalar('stddev', stddev) 334 | # tf.summary.scalar('max', tf.reduce_max(var)) 335 | # tf.summary.scalar('min', tf.reduce_min(var)) 336 | tf.compat.v1.summary.histogram(name, var) 337 | 338 | 339 | def sequences_tools(sequences): 340 | ''' 341 | Input arguments: 342 | -------- 343 | sequences: 344 | tensor of size (N_cases x max_sequence_length x Ndims) 345 | 346 | Returns: 347 | -------- 348 | index_sequences_elements: 349 | (sum_i^Nsequences seq_len(i) x 2) tensor listing all the non-zero 350 | indices in the tensor of sequences 351 | get_sequences_lengths: 352 | int32 tensor of size (N_cases) 353 | ''' 354 | 355 | # mask_binariwise is a (N_cases x max_sequence_length) matrix with 0s 356 | # wherever all elements of an input token are simultaneously zero, 357 | # and 1s elsewhere. Since all elements of an input token are 358 | # simultaneously zero only in the zero-padding, the 1s will be 359 | # contiguous, and the number of them in each row will be the 360 | # corresponding true sequence length 361 | mask_binariwise = tf.sign(tf.reduce_max(tf.abs(sequences), axis=2)) 362 | get_sequences_lengths = tf.reduce_sum(mask_binariwise, axis=1) 363 | get_sequences_lengths = tf.cast(get_sequences_lengths, tf.int32) 364 | # shouldn't it already be an int? 365 | index_sequences_elements = tf.cast( 366 | tf.compat.v1.where(tf.equal(mask_binariwise, 1)), tf.int32) 367 | 368 | return index_sequences_elements, get_sequences_lengths 369 | 370 | 371 | def occlude_sequence_features(get_sequences, occluded_features): 372 | ''' 373 | For all sequences in the zero-padded tensor get_sequences (with shape 374 | (N_cases x T_max x Nfeatures), replace the features labeled by index in 375 | occluded_features with their average values (exlcuding the zero-padding, 376 | of course). 377 | ''' 378 | 379 | index_sequences, _ = sequences_tools(get_sequences) 380 | desequence_sequences = tf.gather_nd(get_sequences, index_sequences) 381 | desequenced_shape = common_layers.shape_list(desequence_sequences) 382 | 383 | average_feature_activities = tf.reduce_mean( 384 | input_tensor=desequence_sequences, axis=0) 385 | occlude_desequenced_sequences = tf.stack( 386 | [ 387 | tf.fill((desequenced_shape[0], ), average_feature_activities[i]) 388 | if i in occluded_features else desequence_sequences[:, i] 389 | for i in range(desequenced_shape[1]) 390 | ], axis=1 391 | ) 392 | 393 | return tf.scatter_nd( 394 | index_sequences, occlude_desequenced_sequences, 395 | tf.shape(input=get_sequences)) 396 | 397 | 398 | def tf_expected_word_error_rates( 399 | references, hypotheses, get_sequence_log_probs, 400 | USE_BUILTIN=True, EXCLUDE_EOS=False, eos_id=1 401 | ): 402 | ''' 403 | Compute word error rate on the results of a beam search. In particular, 404 | tile the references to the beam size, and reshape into (N_cases*beam_width 405 | x max_sequence_length); then compute the word error rate in a vectorized 406 | way. 
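    In symbols (a restatement of what the code below computes): with
    unnormalized log probabilities l_k for the K hypotheses in the beam and
    word error rates w_k, the returned quantity is

        exp( logsumexp_k[l_k + log(w_k + 1)] - logsumexp_k[l_k] ) - 1
            = sum_k softmax(l)_k * w_k,

    i.e. the expected word error rate under the (normalized) beam
    probabilities; the +1/-1 merely avoids taking log(0) for perfect
    hypotheses.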
407 | 408 | Input arguments: 409 | -------- 410 | references: (N_cases x 1 x max_ref_length) 411 | hypotheses: (N_cases x beam_width x max_hyp_length) 412 | get_sequence_log_probs': (N_cases x beam_width) 413 | 414 | For a categorical distribution, the natural params are (possibly 415 | unnormalized) log probabilities. (I believe that the tensor2tensor 416 | beam_search, on whose outputs this is generally run, actually normalizes 417 | probabilities within the beam, but to be sure we treat these as 418 | unnormalized log probabilities.) 419 | 420 | Returns: 421 | -------- 422 | average_word_error_rate 423 | ''' 424 | 425 | # Ns 426 | N_cases = common_layers.shape_list(references)[0] 427 | beam_width = common_layers.shape_list(hypotheses)[1] 428 | N_sentences = N_cases*beam_width 429 | 430 | # tile references to have the same shape as hypotheses 431 | references = tf.reshape( 432 | tf.tile(references, [1, beam_width, 1]), [N_sentences, -1] 433 | ) 434 | hypotheses = tf.reshape(hypotheses, [N_sentences, -1]) 435 | 436 | # get word error rates 437 | get_word_error_rate = ( 438 | tf_word_error_rates_built_in if USE_BUILTIN else tf_word_error_rates 439 | ) 440 | word_error_rate_matrix = tf.reshape( 441 | get_word_error_rate(references, hypotheses, EXCLUDE_EOS, eos_id), 442 | [-1, beam_width] 443 | ) 444 | 445 | # take average under the hypotheses' probabilities 446 | logZ = tf.reduce_logsumexp(get_sequence_log_probs, axis=1) 447 | logXpctWplus1 = tf.reduce_logsumexp( 448 | get_sequence_log_probs + tf.math.log(word_error_rate_matrix + 1.0), 449 | axis=1) - logZ 450 | 451 | # we added 1 to avoid log(0); now subtract that 1 452 | return tf.exp(logXpctWplus1) - 1 453 | 454 | 455 | def tf_word_error_rates(references, hypotheses, EXCLUDE_EOS=False, eos_id=1): 456 | """ 457 | Tensorflow implementation of a vectorized version of word error rate, 458 | based on the Levenstein distance. The underlying algorithm is a variant 459 | on Wagner-Fisher/Needleman-Wunsch. 460 | 461 | Input arguments: 462 | -------- 463 | references: a tensor with shape (N_sentences x max_ref_length) 464 | hypotheses: a tensor with shape (N_sentences x max_hyp_length) 465 | 466 | Returns: 467 | -------- 468 | tensor with shape (N_sentences) 469 | 470 | 471 | Example: 472 | -------- 473 | etc..... 
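    A minimal worked illustration (values are made up; 0 acts as padding, so
    effective sequence lengths are the counts of nonzero entries):

        references = tf.constant([[7, 4, 2, 0, 0]])    # effective length 3
        hypotheses = tf.constant([[7, 9, 2, 5, 0]])    # effective length 4

    One substitution (4 -> 9) plus one insertion (5) gives a Levenshtein
    distance of 2, so the returned word error rate is 2/3.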
474 | 475 | """ 476 | # Created: 02/14/18 477 | # by JGM 478 | 479 | ###### 480 | # TO DO: 481 | # (1) Make use of the args EXCLUDE_EOS and eos_id 482 | ###### 483 | 484 | # Ns 485 | N_sentences, max_ref_length = common_layers.shape_list(references) 486 | max_hyp_length = common_layers.shape_list(hypotheses)[1] 487 | 488 | # upper bound on WER 489 | d_maxes = tf.fill( 490 | (N_sentences,), tf.maximum(max_ref_length, max_hyp_length) + 1 491 | ) 492 | 493 | # get all the sequence lengths 494 | _, get_ref_lengths = sequences_tools(tf.expand_dims(references, axis=2)) 495 | _, get_hyp_lengths = sequences_tools(tf.expand_dims(hypotheses, axis=2)) 496 | 497 | # the conditions and bodies for a pair of nested while loops 498 | def inner_cond(i_ref, i_hyp, distances): return i_hyp < max_hyp_length 499 | 500 | def outer_cond(i_ref, i_hyp, distances): return i_ref < max_ref_length 501 | 502 | def inner_body(i_ref, i_hyp, distances): 503 | return i_ref, i_hyp+1, tf_fisher_wagner_body(i_ref, i_hyp, distances) 504 | 505 | def outer_body(i_ref, i_hyp, distances): 506 | i_ref, _, distances = tf.while_loop( 507 | cond=inner_cond, body=inner_body, loop_vars=(i_ref, i_hyp, distances), 508 | parallel_iterations=1, back_prop=False) 509 | return i_ref+1, tf.constant(0), distances 510 | 511 | def tf_fisher_wagner_body(i_ref, i_hyp, distances): 512 | match = tf.compat.v1.where( 513 | tf.equal(references[:, i_ref], hypotheses[:, i_hyp]), 514 | distances[:, i_ref, i_hyp], 515 | d_maxes 516 | ) 517 | substitution = distances[:, i_ref, i_hyp] + 1 518 | insertion = distances[:, i_ref+1, i_hyp] + 1 519 | deletion = distances[:, i_ref, i_hyp+1] + 1 520 | updates = tf.reduce_min(input_tensor=tf.stack( 521 | [match, substitution, insertion, deletion], axis=1), axis=1) 522 | indices = tf.stack((tf.range(N_sentences), 523 | tf.fill([N_sentences], i_ref+1), 524 | tf.fill([N_sentences], i_hyp+1)), axis=1) 525 | return distances + tf.scatter_nd( 526 | indices, updates, 527 | shape=[N_sentences, max_ref_length+1, max_hyp_length+1]) 528 | 529 | # initialize 530 | i_ref0 = tf.constant(0) 531 | i_hyp0 = tf.constant(0) 532 | row_indices = tf.range(max_ref_length+1) 533 | first_row_indices = tf.stack( 534 | (row_indices, tf.fill([max_ref_length+1], 0)), axis=1) 535 | col_indices = tf.range(max_hyp_length+1) 536 | first_col_indices = tf.stack( 537 | (tf.fill([max_hyp_length+1], 0), col_indices), axis=1) 538 | indices = tf.concat((first_col_indices, first_row_indices), axis=0) 539 | updates = tf.concat((col_indices, row_indices), axis=0) 540 | distances0 = tf.scatter_nd(indices, updates, 541 | shape=[max_ref_length+1, max_hyp_length+1]) 542 | distances0 = tf.transpose( 543 | a=tf.tile(tf.expand_dims(distances0, axis=2), [1, 1, N_sentences]), 544 | perm=[2, 0, 1]) 545 | 546 | # run the nested while loops (Fisher-Wagner algorithm) 547 | _, _, distance_tensor = tf.while_loop( 548 | cond=outer_cond, body=outer_body, loop_vars=(i_ref0, i_hyp0, distances0), 549 | parallel_iterations=1, back_prop=False) 550 | 551 | # return just the distances at the end of each sentence 552 | distance_vector = tf.cast(tf.divide(tf.gather_nd(distance_tensor, tf.stack( 553 | [tf.range(N_sentences), get_ref_lengths, get_hyp_lengths], axis=1)), 554 | get_ref_lengths), tf.float32) 555 | return distance_vector 556 | 557 | 558 | def tf_word_error_rates_built_in( 559 | references, hypotheses, EXCLUDE_EOS=False, eos_id=1 560 | ): 561 | # Use tensorflow's word_error_rate calculator 562 | 563 | # HARD-CODED PAD_ID 564 | pad_id = 0 565 | ignore_ids = [pad_id, eos_id] if 
EXCLUDE_EOS else [pad_id] 566 | 567 | ##### 568 | # You should make this into a general function ("tf_broadcast_equal")... 569 | def extract_non_ignore_indices(sequences): 570 | return tf.compat.v1.where(tf.reduce_all( 571 | tf.not_equal(tf.expand_dims(sequences, 2), [[ignore_ids]]), 572 | axis=2 573 | )) 574 | ##### 575 | 576 | # ... 577 | index_references = extract_non_ignore_indices(references) 578 | sparse_references = tf.SparseTensor( 579 | index_references, 580 | tf.gather_nd(references, index_references), 581 | tf.cast(tf.shape(input=references), tf.int64) 582 | ) 583 | index_hypotheses = extract_non_ignore_indices(hypotheses) 584 | sparse_hypotheses = tf.SparseTensor( 585 | index_hypotheses, 586 | tf.gather_nd(hypotheses, index_hypotheses), 587 | tf.cast(tf.shape(input=hypotheses), tf.int64) 588 | ) 589 | return tf.edit_distance(sparse_hypotheses, sparse_references) 590 | 591 | 592 | def seq_log_probs_to_word_log_probs( 593 | get_beam_outputs, get_sequence_log_probs, Nclasses, 594 | index_sequences_elements, max_hyp_length, padding_value=0 595 | ): 596 | ''' 597 | :param get_outputs: (Nsequences x beam_width x max_prediction_length) 598 | :param get_sequence_log_probs: (Nsequences x beam_width) 599 | :param Nclasses: scalar 600 | :param index_sequence_elements: (sum_i^Nsequences seq_len(i) x 2), a list 601 | of all the (putative) non-zero indices in the tensor of sequences 602 | :param max_hyp_length: scalar tensor 603 | :return: score_as_unnorm_log_probs: (sum_i^Nsequences seq_len(i) x Nclasses), 604 | a tensor of log probabilities for each id, de-sequenced 605 | 606 | A sensible set of variables for a beam search to return is the set of the K 607 | most probable sequences and their probabilities, where K=beam_width. (These 608 | sequence_log_probs are not assumed to be normalized.) 609 | 610 | We want to expand the log probabilities to cover *all* tokens, not just the 611 | K most likely. Conceptually, this is straightforward: For each element of 612 | each sequence, exponentiate the log probabilities; compute the "leftover" 613 | probability for all ids outside the beam, and divide it up equally among 614 | them; compute the logarithm elementwise. Computationally, however, it is 615 | more complicated, b/c an effort must be made to avoid over- and underflows. 616 | 617 | Furthermore, to avoid doing any serious calculations, we have to make some 618 | simplifying choice for how to compute the "leftover" probabilities. Here, 619 | we basically assign each non-selected id probability 1/S, S=total number of 620 | possible sequences. That is, we pretend that each non-selected *sequence* 621 | has equal probability, 1/S, and then assume (what is certainly false) that 622 | each non-selected token at each time step *in each beam* can be assigned to 623 | exactly one of these non-selected sequences. Hence e.g., even if token 324 624 | appears in at least one beam at time step t, it will still be assigned 625 | probability 1/S at t in all beams where it did *not* appear. This 626 | facilitates summing log probabilities across the beams. 627 | 628 | Total number of sequences: For simplicity, ignore the end-of-sequence 629 | tokens. For a vocabulary of size N and a maximum sequence length of M, 630 | there are N possible sequences that end at the first step; N^2 that end 631 | at the second step; and so forth up to N^M. Thus altogether there are 632 | N^1 + N^2 + N^3 + ... + N^M 633 | = N^0 + N^1 + N^2 + N^3 + ... 
+ N^M - 1 634 | = (N^(M+1) - 1)/(N - 1) - 1 635 | ~= N^M 636 | sequences, where the approximation follows from the fact that, for N or M 637 | of any reasonable size, the -1s don't matter. Likewise, subtracting out 638 | the K in-beam sequences has no appreciable effect for any reasonable K. 639 | Hence the probability of each out-of-beam sequence is approximately N^-M, 640 | or again: 641 | log(out_beam_prob) = -M*log(N) 642 | 643 | Given the approximations, and more importantly since no attempt is made to 644 | decrease the in-beam probabilities by the probability assigned to out-of- 645 | beam ids, the result of logsumexp will be *unnormalized* log probabilities. 646 | These values are furthermore desequenced into shape 647 | (sum_i^N_cases targ_seq_len(i) x Nclasses) 648 | before returning. 649 | ''' 650 | 651 | # one-hotify and scale by log probabilities 652 | # -> (N_cases x beam_width x max_pred_length x Nclasses) 653 | # NB that the resulting tensor does *not* represent log probs, b/c it has 654 | # *zeros* in the out-of-beam locations 655 | in_beam_log_probs = tf.multiply( 656 | tf.one_hot(get_beam_outputs, Nclasses, axis=-1), 657 | tf.expand_dims(tf.expand_dims(get_sequence_log_probs, axis=-1), axis=-1) 658 | ) 659 | 660 | # pad out to max_hyp_length 661 | # -> (N_cases x beam_width x max_hyp_length x Nclasses) 662 | max_pred_length = common_layers.shape_list(get_beam_outputs)[2] 663 | in_beam_log_probs = tf.pad( 664 | tensor=in_beam_log_probs, 665 | paddings=[ 666 | [0, 0], 667 | [0, 0], 668 | [0, tf.maximum(max_hyp_length - max_pred_length, 0) + 1], 669 | [0, 0] 670 | ], 671 | constant_values=padding_value 672 | ) 673 | ### 674 | # This assumes the pad token=0. Ideally, you'd pass this in explicitly, 675 | # and then set constant_values= in tf.pad. 676 | ### 677 | 678 | # fill in zeros with (approximate) out-of-beam log probs (see above) 679 | out_beam_log_prob = tf.multiply( 680 | tf.cast(-max_hyp_length, tf.float32), 681 | tf.math.log(tf.cast(Nclasses, tf.float32)) 682 | ) 683 | out_beam_log_probs = tf.fill( 684 | common_layers.shape_list(in_beam_log_probs), out_beam_log_prob 685 | ) 686 | IS_OUT_OF_BEAM = tf.equal(in_beam_log_probs, 0) 687 | beam_log_probs = tf.compat.v1.where( 688 | IS_OUT_OF_BEAM, out_beam_log_probs, in_beam_log_probs 689 | ) 690 | 691 | # collapse across beam -> (N_cases x max_hyp_length x Nclasses) 692 | score_as_unnorm_log_probs = tf.reduce_logsumexp(beam_log_probs, axis=1) 693 | 694 | # de-sequence -> (sum_i^N_cases targ_seq_len(i) x Nclasses) 695 | score_as_unnorm_log_probs = tf.gather_nd( 696 | score_as_unnorm_log_probs, index_sequences_elements 697 | ) 698 | 699 | return score_as_unnorm_log_probs 700 | 701 | 702 | def fake_beam_for_sequence_targets( 703 | get_targets, get_natural_params, unique_targets_list, beam_width, pad_token 704 | ): 705 | ''' 706 | This function breaks each target and prediction at any spaces they contain, 707 | and treats the resulting lists as sentences between which to compute word 708 | error rates. (For targets that don't contain spaces, nothing interesting 709 | happens.) 710 | 711 | Returns: 712 | -------- 713 | references: 714 | (N_cases x 1 x max_ref_length) tensor of "reference" sentences 715 | hypotheses: 716 | (N_cases x beam_width x max_hyp_length) 717 | int32 tensor of size (N_cases) 718 | fake_beam_natural_params: 719 | (N_cases x beam_width) 720 | 721 | 722 | #### 723 | Why convert to strings? Why not just compute with the word indices?? 
724 | #### 725 | ''' 726 | 727 | # make tensors for the list of unique targets (single words or sentences) 728 | unique_targets_tensor = tf.constant( 729 | unique_targets_list, shape=[len(unique_targets_list), 1]) 730 | 731 | # ...and the list of unique *tokens* that the targets comprise 732 | unique_tokens_list = targets_to_tokens(unique_targets_list, pad_token) 733 | unique_tokens_tensor = tf.constant( 734 | unique_tokens_list, shape=[1, 1, len(unique_tokens_list)]) 735 | 736 | # pretend targets and predictions are themselves sequences 737 | _, fake_beam_ids = tf.nn.top_k(get_natural_params, k=beam_width) 738 | make_target_matrix = tf_sentence_to_word_ids( 739 | get_targets, unique_targets_tensor, 740 | unique_tokens_tensor, pad_token 741 | ) 742 | make_prediction_matrix = tf_sentence_to_word_ids( 743 | fake_beam_ids, unique_targets_tensor, unique_tokens_tensor, pad_token) 744 | references = tf.expand_dims(make_target_matrix, axis=1) 745 | hypotheses = tf.reshape( 746 | make_prediction_matrix, [tf.shape(fake_beam_ids)[0], beam_width, -1] 747 | ) 748 | 749 | # from *all* natural params, extract just some for a fake beam 750 | row_inds = tf.cast(tf.tile( 751 | tf.expand_dims(tf.range(tf.shape(fake_beam_ids)[0]), 1), 752 | (1, beam_width)), tf.int32) 753 | fake_beam_inds = tf.stack((tf.reshape(row_inds, [-1]), 754 | tf.reshape(fake_beam_ids, [-1])), 1) 755 | fake_beam_natural_params = tf.reshape( 756 | tf.gather_nd(get_natural_params, fake_beam_inds), [-1, beam_width] 757 | ) 758 | 759 | return references, hypotheses, fake_beam_natural_params 760 | 761 | 762 | def targets_to_tokens(unique_targets_list, pad_token): 763 | ''' 764 | This only does something interesting for 'trial' data, i.e. if the unique 765 | targets are single strings of *sentences* (containing spaces). In this 766 | case, the unique_tokens_list contains all words in those sentences, plus 767 | the pad_token. If the unique targets are single words, then the unique 768 | tokens will just be identical, although possibly re-ordered, and adding the 769 | pad_token if it wasn't originally present. 770 | 771 | ''' 772 | # get unique tokens by splitting into pieces (at spaces); exclude pad_token 773 | unique_tokens_list = list(set([ 774 | item for target in unique_targets_list for item in target.split(' ') 775 | if item != pad_token])) 776 | unique_tokens_list.sort() # enforce deterministic behavior 777 | 778 | # make sure the pad_token is at the beginning of the list [why?] 779 | unique_tokens_list = [pad_token] + unique_tokens_list 780 | 781 | return unique_tokens_list 782 | 783 | 784 | def tf_sentence_to_word_ids( 785 | sentence_ids, unique_targets_tensor, unique_tokens_tensor, pad_token): 786 | ''' 787 | (1) Given sentence *target* ids, extract the corresponding sentences 788 | (2) Then break these into words (N_cases, max_len, 1) 789 | (3) Convert from words to *token* ids, by broadcasting tf.equals. 
The 790 | resulting matrix will have the word ID in its third column, and its 791 | location (sentence number, word number) in the first two columns 792 | (4) Scatter back into a matrix, with zero padding 793 | ''' 794 | 795 | extract_sentences = tf.gather(unique_targets_tensor, sentence_ids) 796 | N_cases = tf.size(extract_sentences) 797 | extract_words = tf.expand_dims(tf.sparse.to_dense(tf.compat.v1.string_split( 798 | tf.reshape(extract_sentences, [-1])), default_value=pad_token), -1) 799 | id_tokens = tf.compat.v1.where(tf.equal(unique_tokens_tensor, extract_words)) 800 | return tf.scatter_nd( 801 | id_tokens[:, 0:2], id_tokens[:, 2], 802 | [tf.cast(N_cases, tf.int64), 1+tf.reduce_max(input_tensor=id_tokens[:, 1])]) 803 | 804 | 805 | def average_gradients(tower_grads): 806 | # Cribbed from the tensorflow tutorials: 807 | # tutorials/image/cifar10/cifar10_multi_gpu_train.py 808 | """Calculate the average gradient for each shared variable across 809 | all towers. 810 | 811 | Note that this function provides a synchronization point across 812 | all towers. 813 | 814 | Args: 815 | tower_grads: List of lists of (gradient, variable) tuples. The 816 | outer list is over individual gradients. The inner list is 817 | over the gradient calculation for each tower. 818 | Returns: 819 | List of pairs of (gradient, variable) where the gradient has 820 | been averaged across all towers. 821 | """ 822 | average_grads = [] 823 | for grad_and_vars in zip(*tower_grads): 824 | # Note that each grad_and_vars looks like the following: 825 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 826 | grads = [] 827 | for g, _ in grad_and_vars: 828 | # Add 0 dimension to the gradients to represent the tower. 829 | expanded_g = tf.expand_dims(g, 0) 830 | 831 | # Append on a 'tower' dimension which we will average over 832 | # below. 833 | grads.append(expanded_g) 834 | 835 | # Average over the 'tower' dimension. 836 | grad = tf.concat(axis=0, values=grads) 837 | grad = tf.reduce_mean(input_tensor=grad, axis=0) 838 | 839 | # Keep in mind that the Variables are redundant because they 840 | # are shared across towers. So .. we will just return the 841 | # first tower's pointer to the Variable. 842 | v = grad_and_vars[0][1] 843 | grad_and_var = (grad, v) 844 | average_grads.append(grad_and_var) 845 | return average_grads 846 | 847 | 848 | def tf_linear_interpolation(X, stretch_factor, axis=0): 849 | ''' 850 | Linearly interpolate sequences in `X ` along axis `axis`, according to 851 | the `stretch_factor`. 852 | 853 | Each point in a linear interpolation is a weighted sum of the closest 854 | values, above and below. 
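    For instance (illustrative numbers): with stretch_factor=2, output index 3
    corresponds to "index" 3/2 = 1.5 in the original sequence, so that output
    sample is 0.5*X[..., 1, ...] + 0.5*X[..., 2, ...] along the chosen axis.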
855 | ''' 856 | 857 | # get the original and new (resampled) max sample 858 | T_orig = tf.shape(input=X)[1] 859 | T_new = tf.round(stretch_factor*tf.cast(T_orig-1, tf.float32)) 860 | 861 | # interpolate "indices"--but NB that these are actually floats 862 | interpolate_inds = tf.range(T_new)/stretch_factor 863 | 864 | # the closest integer-valued indices *below* each interpolated value 865 | get_lower_inds = tf.cast(tf.floor(interpolate_inds), tf.int32) 866 | 867 | # extract points at these lower inds and at the next higher inds 868 | extract_lower_vals = tfh.fancy_indexing(X, get_lower_inds, axis=axis) 869 | extract_upper_vals = tfh.fancy_indexing(X, get_lower_inds+1, axis=axis) 870 | 871 | # linear interpolant is a weighted sum of these values 872 | get_w_lower = tf.cast(get_lower_inds + 1, tf.float32) - interpolate_inds 873 | get_w_upper = interpolate_inds - tf.cast(get_lower_inds, tf.float32) 874 | 875 | new_shape = [tf.constant(1) if i != axis else tf.shape(input=get_lower_inds)[0] 876 | for i in range(len(common_layers.shape_list(X)))] 877 | get_w_lower = tf.reshape(get_w_lower, new_shape) 878 | get_w_upper = tf.reshape(get_w_upper, new_shape) 879 | 880 | return get_w_lower*extract_lower_vals + get_w_upper*extract_upper_vals 881 | 882 | 883 | def swap(key, string): 884 | # In SequenceNetworks, keys are often constructed from the data_manifest 885 | # key by swapping out the word 'targets' for some other string. This is 886 | # just a shortcut for that process. 887 | return key.replace('targets', string) 888 | 889 | 890 | def cross_entropy(key, data_manifest, sequenced_op_dict): 891 | ''' 892 | ... 893 | 894 | In fact, this function *averages*, rather than sums, across all features 895 | given by a particular key. Although the result is not technically the 896 | cross entropy of the output, it is more easily comparable across keys and 897 | therefore facilitates the design of penalty_scales. 898 | ''' 899 | 900 | # desequence the targets and natural_params 901 | # NB that this enforces that the lengths of the predicted and actual 902 | # sequences match. This is of course *not* enforced when calculating the 903 | # word error rate, which is anyway computed from 'decoder_outputs', not 904 | # 'decoder_natural_params'. 905 | index_targets, get_lengths = sequences_tools(sequenced_op_dict[key]) 906 | targets = tf.gather_nd(sequenced_op_dict[key], index_targets) 907 | np_key = swap(key, 'natural_params') 908 | natural_params = tf.gather_nd(sequenced_op_dict[np_key], index_targets) 909 | 910 | # the form of the cross-entropy depends on the distribution 911 | if data_manifest.distribution == 'Gaussian': 912 | # average across features (axis=1) 913 | compute_cross_entropy = tf.reduce_mean( 914 | tf.square(natural_params - targets), 1)/2 915 | elif data_manifest.distribution == 'categorical': 916 | compute_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 917 | labels=tf.reshape(targets, [-1]), logits=natural_params 918 | ) 919 | elif data_manifest.distribution == 'CTC': 920 | ######### 921 | # Why do we need to pad the symbol dimensions with a zero?? 
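        # (Presumably because tf.compat.v1.nn.ctc_loss reserves the largest
        # class index, num_classes - 1, for the CTC blank label, so the extra
        # all-zeros logit column supplies that blank class.)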
922 | sequenced_natural_params = tf.pad( 923 | sequenced_op_dict[np_key], tf.constant([[0, 0], [0, 0], [0, 1]]) 924 | ) 925 | ######### 926 | 927 | # the labels need to be a SparseTensor 928 | sequenced_encoder_targets = tf.SparseTensor( 929 | tf.cast(index_targets, tf.int64), 930 | tf.reshape(targets, [-1]), 931 | tf.cast( 932 | [tf.shape(get_lengths)[0], tf.reduce_max(get_lengths)], 933 | tf.int64 934 | ) 935 | ) 936 | 937 | #### 938 | # ugh: not actually a cross entropy... 939 | #### 940 | compute_cross_entropy = tf.compat.v1.nn.ctc_loss( 941 | sequenced_encoder_targets, 942 | inputs=sequenced_natural_params, 943 | sequence_length=get_lengths, 944 | preprocess_collapse_repeated=True, 945 | ctc_merge_repeated=False, 946 | time_major=False 947 | ) 948 | else: 949 | # raise NotImplementedError( 950 | # "Only Gaussian, categorical cross entropies have been impl.") 951 | print('WARNING: unrecognized data_manifest.', end='') 952 | print('distribution; not computing a cross entropy') 953 | return 954 | 955 | # average across elements of the batch 956 | return tf.reduce_mean(compute_cross_entropy, 0) 957 | ###return tf.reduce_sum(compute_cross_entropy, 0) 958 | -------------------------------------------------------------------------------- /machine_learning/neural_networks/tf_helpers.py: -------------------------------------------------------------------------------- 1 | # standard libraries 2 | import pdb 3 | import sys 4 | import os 5 | 6 | # third-party packages 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.python.client import device_lib 10 | try: 11 | from tensor2tensor.layers import common_layers 12 | except ModuleNotFoundError: 13 | print('WARNING: tensor2tensor not found; skipping') 14 | 15 | # local 16 | from utils_jgm.toolbox import auto_attribute 17 | 18 | ''' 19 | A collection of helper functions for use with tensorflow 20 | 21 | :Author: J.G. 
Makin (except where otherwise noted) 22 | 23 | Cribbed from other JGM code: June 2018 24 | ''' 25 | 26 | PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable", 27 | "MutableHashTable", "MutableHashTableV2", 28 | "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2", 29 | "MutableDenseHashTable", "MutableDenseHashTableV2", 30 | "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp" 31 | "Assert", "StringFormat", "PrintV2" # added by JGM 32 | ) 33 | 34 | 35 | class GraphBuilder: 36 | @auto_attribute 37 | def __init__( 38 | self, 39 | # functions: 40 | training_data_fxn, 41 | assessment_data_fxn, 42 | training_net_builder, 43 | assessment_net_builder, 44 | optimizer, 45 | assessor, 46 | # other arguments: 47 | checkpoints_path, 48 | final_epoch, 49 | # arguments with default values: 50 | initial_epoch=0, 51 | EMA_decay=0.0, 52 | reuse_vars_scope=None, 53 | training_GPUs=None, 54 | assessment_GPU=0, 55 | # private; don't assign these to self: 56 | _restore_epoch=None, 57 | _restore_model=None, 58 | ): 59 | 60 | # construct and store other useful parameters 61 | if _restore_epoch is None: 62 | _restore_epoch = final_epoch 63 | self.last_checkpoint = self.checkpoints_path + '-%i' % _restore_epoch 64 | if reuse_vars_scope: 65 | self.final_epoch += _restore_epoch 66 | self.initial_epoch += _restore_epoch 67 | if _restore_model: 68 | token_type = os.path.split(os.path.split(os.path.split( 69 | self.last_checkpoint)[0])[0])[1] 70 | self.last_checkpoint = self.last_checkpoint.replace( 71 | token_type, _restore_model + '_' + token_type) 72 | 73 | def train_and_assess(self, assessment_epoch_interval=1): 74 | ''' 75 | Train and assess a neural netword built in tensorflow. 76 | ''' 77 | 78 | # construct, initialize training and assessment graphs 79 | (update_params, initialize_training_data, training_saver, training_sess, 80 | ) = self._build_training_graph() 81 | (assessment_sess, assessment_saver, assessments 82 | ) = self._build_assessment_graph() 83 | 84 | # start training 85 | assessment_step = 0 86 | try: 87 | for training_epoch in range(self.initial_epoch, self.final_epoch): 88 | print('training...') 89 | training_sess.run(initialize_training_data) 90 | while True: 91 | try: 92 | training_sess.run(update_params) 93 | except tf.errors.OutOfRangeError: 94 | break 95 | if training_epoch % assessment_epoch_interval == 0: 96 | print('assessing...') 97 | assessments = self._save_and_assess( 98 | training_sess, training_saver, training_epoch, 99 | assessment_sess, assessment_saver, assessment_step, 100 | assessments) 101 | assessment_step += 1 102 | else: 103 | # be sure to save the model (and assess) after the last epoch 104 | assessments = self._save_and_assess( 105 | training_sess, training_saver, training_epoch+1, 106 | assessment_sess, assessment_saver, assessment_step, 107 | assessments) 108 | 109 | self.close_all( 110 | training_sess, assessment_sess, 111 | *[struct.writer for (key, struct) in assessments.items()], 112 | # writer 113 | ) 114 | return assessments 115 | 116 | except: 117 | exc_type, exc_value, exc_traceback = sys.exc_info() 118 | print(str(exc_type)) 119 | print(str(exc_value)) 120 | print('AT LINE ' + str(exc_traceback.tb_lineno)) 121 | print('cleaning up...') 122 | self.close_all( 123 | training_sess, assessment_sess, 124 | *[struct.writer for (key, struct) in assessments.items()], 125 | # writer 126 | ) 127 | print('...cleaned up!') 128 | 129 | return assessments 130 | 131 | def assess(self): 132 | (assessment_sess, assessment_saver, assessments 133 | ) = 
self._build_assessment_graph() 134 | assessment_saver.restore(assessment_sess, self.last_checkpoint) 135 | for data_partition in assessments.keys(): 136 | assessments[data_partition] = self.assessor( 137 | assessment_sess, assessments[data_partition], 138 | self.final_epoch, 0, data_partition) 139 | return assessments 140 | 141 | def get_saliencies(self): 142 | # get gradients (of inputs) 143 | controller = '/cpu:0' 144 | initialize_data, tower_grads = self._parallel_differentiator(controller) 145 | with tf.compat.v1.name_scope("apply_gradients"), tf.device(controller): 146 | get_input_saliencies = self._average_tower_gradients(tower_grads) 147 | 148 | # create the session and restore the graph 149 | EMA = tf.train.ExponentialMovingAverage( 150 | decay=self.EMA_decay) if self.EMA_decay else None 151 | sess, saver = get_session_and_saver( 152 | EMA=EMA, allow_soft_placement=True, allow_growth=True) 153 | saver.restore(sess, self.last_checkpoint) 154 | 155 | # an abuse of the assessor... 156 | return self.assessor(sess, initialize_data, get_input_saliencies[0][0]) 157 | 158 | def _build_training_graph(self): 159 | 160 | training_graph = tf.Graph() 161 | with training_graph.as_default(): 162 | # set up trainer; open session; restore saved weights 163 | update_params, initialize_data, EMA = self._parallel_trainer() 164 | sess, saver = get_session_and_saver( 165 | EMA=EMA, allow_soft_placement=True, allow_growth=True 166 | ) 167 | self._restore_weights(sess, training_graph, EMA) 168 | training_graph.finalize() 169 | 170 | print('Training graph built...') 171 | return update_params, initialize_data, saver, sess 172 | 173 | def _restore_weights(self, sess, training_graph, EMA): 174 | if self.reuse_vars_scope: 175 | reuse_vars = tf.compat.v1.trainable_variables(self.reuse_vars_scope) 176 | if EMA is not None: 177 | # The keys are EMA names; the values are EMA variables where 178 | # they exist and otherwise the "regular" variables. NB that it 179 | # includes the AdaM variables, although these are filtered out 180 | # below by requiring that the vars be in reuse_vars. 181 | EMA_reuse_var_dict = EMA.variables_to_restore() 182 | reuse_vars_dict = {EMA.average_name(var): EMA_reuse_var_dict[ 183 | EMA.average_name(var)] for var in reuse_vars} 184 | else: 185 | reuse_vars_dict = { 186 | reuse_var.name.split(':')[0]: 187 | training_graph.get_tensor_by_name(reuse_var.name) 188 | for reuse_var in reuse_vars 189 | } 190 | # pdb.set_trace() 191 | training_restore_saver = tf.compat.v1.train.Saver(reuse_vars_dict) 192 | training_restore_saver.restore(sess, self.last_checkpoint) 193 | 194 | def _build_assessment_graph(self): 195 | 196 | # for malloc'ing some data storage 197 | num_epochs = self.final_epoch - self.initial_epoch 198 | 199 | # under this graph (and device)... 
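        # NB: the assessment graph and session are kept entirely separate from
        # the training graph and session; weights pass between them only via
        # the checkpoints written in _save_and_assess (training_saver.save ->
        # assessment_saver.restore), and when EMA_decay > 0 the assessment
        # saver restores the exponentially averaged copies of the variables.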
200 | assessment_graph = tf.Graph() 201 | #with assessment_graph.as_default(), tf.device( 202 | # '/gpu:%i' % self.assessment_GPU): 203 | with assessment_graph.as_default(): 204 | # ...create data; assess; prepare plots, init storage; open session 205 | GPU_op_dict, CPU_op_dict, assessments = self.assessment_data_fxn( 206 | num_epochs) 207 | self.assessment_net_builder(GPU_op_dict, CPU_op_dict) 208 | EMA = (tf.train.ExponentialMovingAverage(decay=self.EMA_decay) 209 | if self.EMA_decay else None) 210 | sess, saver = get_session_and_saver(EMA=EMA, allow_growth=True) 211 | assessment_graph.finalize() 212 | 213 | print('Assessment graph built...') 214 | return sess, saver, assessments 215 | 216 | def _save_and_assess( 217 | self, training_sess, training_saver, training_epoch, 218 | assessment_sess, assessment_saver, assessment_step, assessments 219 | ): 220 | this_checkpoint = training_saver.save( 221 | training_sess, self.checkpoints_path, global_step=training_epoch) 222 | assessment_saver.restore(assessment_sess, this_checkpoint) 223 | for data_partition in assessments.keys(): 224 | assessments[data_partition] = self.assessor( 225 | assessment_sess, assessments[data_partition], 226 | training_epoch, assessment_step, data_partition) 227 | return assessments 228 | 229 | def _parallel_trainer(self): 230 | # http://blog.s-schoener.com/2017-12-15-parallel-tensorflow-intro/ 231 | initialize_data, tower_grads = self._parallel_differentiator() 232 | update_weights = self._parallel_weight_updater(tower_grads) 233 | if self.EMA_decay: 234 | EMA = tf.train.ExponentialMovingAverage(decay=self.EMA_decay) 235 | update_moving_averages = EMA.apply(tf.compat.v1.trainable_variables()) 236 | update_params = tf.group((update_weights, update_moving_averages)) 237 | else: 238 | EMA = None 239 | update_params = update_weights 240 | 241 | return update_params, initialize_data, EMA 242 | 243 | def _parallel_differentiator(self, controller='/cpu:0'): 244 | # http://blog.s-schoener.com/2017-12-15-parallel-tensorflow-intro/ 245 | 246 | # return a list of device ids like`['/gpu:0', '/gpu:1']` 247 | devices = get_available_gpus() 248 | if not devices: 249 | devices = ['/cpu:0'] 250 | print('No GPUs available or requested! using the CPU -- jgm') 251 | elif self.training_GPUs is not None: 252 | devices = [devices[i] for i in self.training_GPUs] 253 | else: 254 | print('Using *all* %i GPUs...' 
% len(devices)) 255 | 256 | # the ops make data, to be placed either on the GPU or the CPU 257 | (GPU_op_dict, CPU_op_dict, initialize_data 258 | ) = self.training_data_fxn(len(devices)) 259 | 260 | # Get the current variable scope so we can reuse all variables we need 261 | # once we get to the nth iteration of the loop below 262 | tower_grads = [] 263 | for iDevice, device_id in enumerate(devices): 264 | print('Setting up tower on %s' % device_id) 265 | tower_name = 'tower_{}'.format(iDevice) 266 | 267 | # force onto controller device 268 | ########## 269 | # with tf.device(self._assign_to_device(device_id, controller)): 270 | ########## 271 | with tf.device('/gpu:{}'.format(iDevice)): 272 | # but see: https://stackoverflow.com/questions/45156542/ 273 | with tf.compat.v1.name_scope(tower_name) as scope: 274 | # compute gradients 275 | model_outputs = self.training_net_builder( 276 | {key: op[iDevice] for key, op in GPU_op_dict.items()}, 277 | CPU_op_dict, 278 | tower_name=tower_name 279 | ) 280 | with tf.compat.v1.name_scope("compute_gradients"): 281 | # get list of (gradient, variable) pairs 282 | grads_and_vars = self.optimizer.compute_gradients( 283 | *model_outputs) 284 | tower_grads.append(grads_and_vars) 285 | 286 | return initialize_data, tower_grads 287 | 288 | def _parallel_weight_updater(self, tower_grads, controller='/cpu:0'): 289 | 290 | # Apply the gradients on the controlling device 291 | with tf.compat.v1.name_scope("apply_gradients"), tf.device(controller): 292 | # back on the CPU, average the gradients from each tower 293 | 294 | # (gradient, variable) lists -> (gradient, variables) list 295 | gradients = self._average_tower_gradients(tower_grads) 296 | update_weights = self.optimizer.apply_gradients(gradients) 297 | 298 | return update_weights 299 | 300 | @staticmethod 301 | def _average_tower_gradients(tower_grads): 302 | ''' 303 | Calculate average gradient for each shared variable across all towers. 304 | 305 | See https://github.com/tensorflow/models/blob/master/tutorials/image/ 306 | cifar10/cifar10_multi_gpu_train.py#L101 307 | 308 | Note: this function provides a synchronization point across all towers. 309 | Args: 310 | tower_grads: List of lists of (gradient, variable) tuples. The outer 311 | list ranges over the devices. The inner list ranges over the 312 | different variables. 313 | Returns: 314 | List of pairs of (gradient, variable) where the gradient has 315 | been averaged across all towers. 316 | ''' 317 | averaged_grads_and_vars = [] 318 | for grad_and_var_all_towers in zip(*tower_grads): 319 | # Each grad_and_var_all_towers is 320 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 321 | 322 | #### 323 | # List comprehension fails for certain variables.... 324 | #grads_list = [g for g, _ in grad_and_var_all_towers] 325 | grads_list = [] 326 | for g, _ in grad_and_var_all_towers: 327 | expanded_g = tf.expand_dims(g, 0) 328 | grads_list.append(expanded_g) 329 | grads = tf.concat(axis=0, values=grads_list) 330 | #### 331 | grad = tf.reduce_mean(input_tensor=grads, axis=0) 332 | 333 | # Variables (grad_and_var_all_towers[iTower][1] for all iTower) are 334 | # redundant because they are shared across towers, so we can 335 | # return just the pointer from tower 0. 336 | averaged_grads_and_vars.append((grad, grad_and_var_all_towers[0][1])) 337 | return averaged_grads_and_vars 338 | 339 | @staticmethod 340 | def _assign_to_device(device, ps_device): 341 | ''' 342 | Returns a function to place variables on the ps_device. 
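        A typical use, mirroring the commented-out call in
        _parallel_differentiator, would be something like:

            with tf.device(self._assign_to_device(device_id, controller)):
                ...  # variables go to the controller; other ops to device_id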
343 | 344 | See https://github.com/tensorflow/tensorflow/issues/9517 345 | 346 | Args: 347 | device: Device for everything but variables 348 | ps_device: Device to put the variables on. Example values are 349 | /GPU:0 and /CPU:0. 350 | 351 | If ps_device is not set then variables will be placed on the default 352 | device. The best device for shared varibles depends on the platform as 353 | well as the model. Start with CPU:0 and then test GPU:0 to see if there 354 | is an improvement. 355 | ''' 356 | def _assign(op): 357 | node_def = op if isinstance(op, tf.compat.v1.NodeDef) else op.node_def 358 | if node_def.op in PS_OPS or 'read' in node_def.name: 359 | # If you don't do this, the 'read' ops for the kernels and 360 | # biases in tf.nn.rnn_cell.LSTMCell end up on GPU:0. For 361 | # details on the read op, see: 362 | # https://stackoverflow.com/questions/42783909/ 363 | # 364 | # or '/Assert' in op.name or '/summaries' in op.name: 365 | return ps_device 366 | else: 367 | return device 368 | return _assign 369 | 370 | @staticmethod 371 | def close_all(*args): 372 | for arg in args: 373 | arg.close() 374 | 375 | 376 | def get_session_and_saver( 377 | initialize_graph=None, EMA=None, allow_soft_placement=True, 378 | allow_growth=False, 379 | ): 380 | # if there isn't an initializer op, create it 381 | if initialize_graph is None: 382 | initialize_graph = tf.group( 383 | tf.compat.v1.global_variables_initializer(), 384 | tf.compat.v1.local_variables_initializer() 385 | ) 386 | 387 | # create a session 388 | sess_config = tf.compat.v1.ConfigProto( 389 | log_device_placement=True, allow_soft_placement=allow_soft_placement, 390 | ) 391 | sess_config.gpu_options.allow_growth = allow_growth 392 | sess = tf.compat.v1.Session(config=sess_config) 393 | 394 | # exponential moving average 395 | if EMA: 396 | saver = tf.compat.v1.train.Saver(EMA.variables_to_restore()) 397 | else: 398 | saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables()) 399 | sess.run(initialize_graph) 400 | 401 | return sess, saver 402 | 403 | 404 | def get_session_with_saved_model( 405 | restore_dir, allow_soft_placement=True, allow_growth=False): 406 | 407 | # well, this is just what you usually do 408 | sess_config = tf.compat.v1.ConfigProto( 409 | log_device_placement=True, 410 | allow_soft_placement=allow_soft_placement 411 | ) 412 | sess_config.gpu_options.allow_growth = allow_growth 413 | sess = tf.compat.v1.Session(config=sess_config) 414 | 415 | # load the model into (?) this session, and return it 416 | tf.compat.v1.saved_model.loader.load(sess, ["serve"], restore_dir) 417 | 418 | return sess 419 | 420 | 421 | def get_available_gpus(): 422 | ''' 423 | Returns a list of the identifiers of all visible GPUs. 424 | See https://stackoverflow.com/questions/38559755/ 425 | ''' 426 | local_device_protos = device_lib.list_local_devices() 427 | return [x.name for x in local_device_protos if x.device_type == 'GPU'] 428 | 429 | 430 | def hide_shape(x): 431 | ''' 432 | This is a rather subtle little function. There are times when you would 433 | like to hide a tensor x's shape from a tensorflow op--e.g., when that op 434 | is created in a branch of a conditional statement (tf.cond, tf.case) that 435 | will never be executed, but with which the shape of x is incompatible. 436 | Since tensorflow sets up the ops in *all* branches of the conditional, 437 | its shape inference will choke on x in this case. So first we obscure the 438 | shape with this function. 
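    As an illustrative sketch (the op and predicate names here are invented),
    if the branch that will never run would otherwise choke on x's static
    shape:

        y = tf.cond(
            pred=IS_SEQUENCE,
            true_fn=lambda: sequence_op(x),
            false_fn=lambda: shape_sensitive_op(hide_shape(x))
        )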
439 | 440 | I (JGM) copied it directly from here: 441 | https://github.com/tensorflow/tensorflow/issues/6906 442 | ''' 443 | return tf.cond( 444 | pred=tf.constant(True), 445 | true_fn=lambda: x, 446 | false_fn=lambda: tf.compat.v1.placeholder(x.dtype) 447 | ) 448 | 449 | 450 | def fancy_indexing(X, extract_inds, axis=0): 451 | ''' 452 | Select indices along axis `axis` from tensor `X` with indices in rank-1 453 | tensor `extract_inds`. 454 | ''' 455 | # Expand me to deal with all the interesting numpy cases 456 | 457 | # get the indices for gathering/scattering 458 | X_shape = tf.shape(input=X) 459 | make_grid_coords = [tf.range(X_shape[i]) for i in range(axis)] 460 | make_grid = tf.meshgrid(*make_grid_coords, extract_inds, indexing='ij') 461 | vectorize_grid = [tf.reshape(grid_coords, [-1]) for grid_coords in make_grid] 462 | matricize_grid = tf.stack(vectorize_grid, axis=1) 463 | 464 | # gather them up and reshape 465 | new_shape = [ 466 | X_shape[i] if i != axis else tf.shape(input=extract_inds)[0] 467 | for i in range(len(common_layers.shape_list(X))) 468 | ] 469 | return tf.reshape(tf.gather_nd(X, matricize_grid), new_shape) 470 | 471 | 472 | def rescale(get_X, xmin, xmax, zmin, zmax): 473 | scaling = (zmax - zmin)/(xmax - xmin) 474 | return scaling*(get_X - xmin) + zmin 475 | 476 | 477 | def tf_print(tensor, message="JGM TENSOR: ", **kwargs): 478 | print_op = tf.print(message, tensor, **kwargs) 479 | with tf.control_dependencies([print_op]): 480 | tensor = tf.identity(tensor) 481 | return tensor 482 | 483 | 484 | def make_feature_example(example_dict): 485 | ''' 486 | For this "example," construct a dictionary of "Features" with the same keys 487 | as the example_dict passed in. 488 | ''' 489 | 490 | feature_dict = {} 491 | for key, value in example_dict.items(): 492 | if type(value) is list: 493 | feature_dict[key] = _featurize_bytes_list(value) 494 | elif type(value) is np.ndarray: 495 | # *assume* it's a float, convert to single precision, and flatten 496 | feature_dict[key] = _featurize_float_list( 497 | np.float32(value).reshape(-1)) 498 | else: 499 | raise NotImplementedError( 500 | "Only list and ndarray features have been implemented") 501 | 502 | # transform the dictionary into Features 503 | features = tf.train.Features(feature=feature_dict) 504 | 505 | # transform Features into an Example, and return 506 | return tf.train.Example(features=features) 507 | 508 | 509 | def _featurize_bytes_list(value): 510 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 511 | 512 | 513 | def _featurize_int64_list(value): 514 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 515 | 516 | 517 | def _featurize_float_list(value): 518 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 519 | 520 | 521 | def parse_protobuf_seq2seq_example(example_proto, data_manifests): 522 | 523 | # parse the features using the data_descriptions and prepare the outputs 524 | feature_dict = { 525 | data_manifest.sequence_type: data_manifest.feature_value 526 | for data_manifest in data_manifests.values() 527 | } 528 | parsed_features = tf.io.parse_single_example( 529 | serialized=example_proto, features=feature_dict) 530 | example_dict = dict.fromkeys(data_manifests.keys()) 531 | 532 | # for each data_manifest (the number is indeterminate)... 
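    # (e.g., keys along the lines of 'encoder_inputs', 'encoder_1_targets',
    # 'decoder_targets'--whatever the data_manifests happen to contain)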
533 | for key, data_manifest in data_manifests.items(): 534 | # ..."unflatten" the sequence of (possibly length-1) vectors and xform 535 | sequence_matrix = tf.reshape( 536 | parsed_features[data_manifest.sequence_type].values, 537 | (-1, data_manifest.num_features_raw) 538 | ) 539 | example_dict[key] = data_manifest.transform(sequence_matrix) 540 | 541 | return example_dict 542 | 543 | 544 | def replace_with_gaussian_noise(data_op): 545 | return tf.random.normal(tf.shape(input=data_op)) 546 | 547 | 548 | def randomly_rotate_sequence(data_op): 549 | T = tf.shape(data_op)[0] 550 | return tf.roll(data_op, shift=T//2, axis=0) 551 | 552 | 553 | def string_seq_to_index_seq( 554 | sequence_matrix, unique_targets_list, eos_id_list, OOV_id 555 | ): 556 | ''' 557 | Convert a sequence of strings (sequence_matrix) into a sequence of indices 558 | into the unique_targets_list, where 559 | indices[:,1] = target_ids 560 | indices[:,0] = sequence positions. 561 | NB that the sequence_matrix must have size (N x 1) as opposed to (N, ). 562 | Strings not found in the unique_targets_list are converted to the OOV_id. 563 | 564 | As a final step, the sequence is appended with eos_id_list, which typically 565 | will hold either a single id or none at all (e.g. for single-word data). 566 | Note that the returned sequence is, like the input, (N x 1). 567 | ''' 568 | 569 | # naively get the indices for all elements of this sequence 570 | unique_bytes_list = [t.encode('utf-8') for t in unique_targets_list] 571 | indices = tf.cast(tf.compat.v1.where(tf.equal( 572 | tf.constant(unique_bytes_list, shape=[1, len(unique_bytes_list)]), 573 | sequence_matrix 574 | ### can be [:, None]? 575 | )), tf.int32) 576 | 577 | # If a sequence element is missing (because that target wasn't in the 578 | # unique_targets_list), replace it with the OOV_id. 
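    # Concretely (toy example): with unique_targets_list = ['a', 'b'] and
    # OOV_id = 2, the sequence [['a'], ['c'], ['b']] yields indices
    # [[0, 0], [2, 1]] above; scattering the found target_ids into a length-3
    # vector gives [0, 0, 1], the boolean mask of found positions is
    # [True, False, True], and where() then fills the hole with OOV_id,
    # producing [0, 2, 1].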
579 | target_shape = tf.shape(sequence_matrix)[0:1] 580 | all_OOV_vector = tf.fill(target_shape, OOV_id) 581 | updates = indices[:, 1] 582 | cull_non_OOV_target_ids = tf.scatter_nd( 583 | indices[:, 0, None], updates, target_shape) 584 | updates = tf.ones_like(indices[:, 1], dtype=tf.bool) 585 | mask_non_OOV_target_ids = tf.scatter_nd( 586 | indices[:, 0, None], updates, target_shape) 587 | replace_missing_with_OOV = tf.compat.v1.where( 588 | mask_non_OOV_target_ids, cull_non_OOV_target_ids, all_OOV_vector) 589 | 590 | # append the EOS_id before returning 591 | return tf.concat((replace_missing_with_OOV, eos_id_list), axis=0)[:, None] 592 | -------------------------------------------------------------------------------- /machine_learning/neural_networks/torch_sequence_networks.py: -------------------------------------------------------------------------------- 1 | # standard libraries 2 | import pdb 3 | from termcolor import cprint 4 | from IPython.display import clear_output 5 | import os 6 | import math 7 | from functools import partial 8 | 9 | # third-party libraries 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torchvision 15 | from torch.utils.tensorboard import SummaryWriter 16 | from torchtune.modules import RotaryPositionalEmbeddings 17 | ####### 18 | # from torch.profiler import profile, record_function, ProfilerActivity 19 | ####### 20 | 21 | # local 22 | from machine_learning.data_mungers import TFRecordDataLoader 23 | from machine_learning.torch_helpers import ( 24 | get_word_error_rate, sequences_tools, reverse_sequences 25 | ) 26 | from utils_jgm.toolbox import ( 27 | auto_attribute, wer_vector, close_factors, MutableNamedTuple 28 | ) 29 | from utils_jgm.machine_compatibility_utils import MachineCompatibilityUtils 30 | MCUs = MachineCompatibilityUtils() 31 | 32 | 33 | ''' 34 | Neural networks for sequence-to-label and sequence-to-sequence problems. 35 | 36 | Some portions inspired by the tutorial here: 37 | https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html 38 | 39 | :Author: J.G. Makin (except where otherwise noted) 40 | 41 | Created: November 2023 42 | by JGM 43 | ''' 44 | 45 | 46 | ############### 47 | # (1) transformer encoder: 48 | # (a) for word_*sequence* 49 | # (b) add positional encoding 50 | # 51 | # 52 | # 53 | # (16) NB: the penalty_scale probably means something different for MFCCs 54 | # b/c you are summing rather than averaging across the 13 dimensions... 55 | # (17) serial transfer learning 56 | # (18) print PER 57 | ############### 58 | 59 | 60 | ''' 61 | Data orderings: 62 | canonical: (N_cases x T x N_features) 63 | Conv1d: (N_cases x N_features x T) 64 | Embedding input: (*) 65 | Embedding output: (*, N_features) 66 | RNN inputs: (T x N_cases x N_features) 67 | RNN outputs: (T x N_cases x N_features) 68 | RNN states: (N_layers*N_directions x N_cases x N_features) 69 | reshaped states: (N_layers x N_cases x N_directions*N_features) 70 | ''' 71 | 72 | 73 | class Sequence2Sequence(nn.Module): 74 | @auto_attribute(CHECK_MANIFEST=True) 75 | def __init__( 76 | self, 77 | manifest, 78 | subnets_params, 79 | ##### 80 | # kwargs set in the manifest 81 | layer_sizes=None, 82 | FF_dropout=None, 83 | RNN_dropout=None, # misnomer 84 | TEMPORALLY_CONVOLVE=None, # !! 
currently does nothing 85 | ##### 86 | ENCODER_IS_BIDIRECTIONAL=True, 87 | training_GPUs=None, 88 | EOS_token='', 89 | pad_token='', 90 | TARGETS_ARE_SEQUENCES=True, 91 | max_hyp_length=20, 92 | coupling='attention', # 'final state' 93 | encoder_type='GRU', 94 | decoder_type='GRU', 95 | VERBOSE=True, 96 | ): 97 | super().__init__() 98 | 99 | if decoder_type == 'classifier': 100 | assert not TARGETS_ARE_SEQUENCES, ( 101 | "Non-sequence targets require the decoder to be a classifier!" 102 | ) 103 | 104 | # useful subnet_params 105 | #-------# 106 | # for now, assume there is only one decoder_targets_list 107 | self.decoder_targets_list = subnets_params[-1].data_manifests[ 108 | 'decoder_targets'].get_feature_list() 109 | #-------# 110 | self.EOS_token = EOS_token 111 | self.pad_token = pad_token 112 | 113 | self.EOS_id = self.decoder_targets_list.index(self.EOS_token) 114 | self.pad_id = self.decoder_targets_list.index(self.pad_token) 115 | 116 | # self.vprint('COUPLING %s ENCODER TO %s DECODER WITH %s' % ( 117 | # encoder_type, decoder_type, coupling 118 | # )) 119 | 120 | # ENCODER 121 | match encoder_type: 122 | case 'LSTM' | 'GRU': 123 | self.encoder = EncoderRNN( 124 | self.layer_sizes, self.FF_dropout, self.RNN_dropout, 125 | subnets_params, self.ENCODER_IS_BIDIRECTIONAL, 126 | encoder_type, coupling=coupling, VERBOSE=VERBOSE 127 | ) 128 | case 'transformer': 129 | self.encoder = EncoderTransformer( 130 | self.layer_sizes, self.FF_dropout, self.RNN_dropout, 131 | subnets_params, coupling=coupling, VERBOSE=VERBOSE 132 | ) 133 | self.ENCODER_IS_BIDIRECTIONAL = False 134 | case _: 135 | raise ValueError('Unrecognized decoder_type') 136 | 137 | # reshape the context to pass to the decoder 138 | self.reshape_context = partial( 139 | context_reshape, N_layers=len(self.layer_sizes['decoder_rnn']), 140 | N_directions=self.ENCODER_IS_BIDIRECTIONAL+1 141 | ) 142 | 143 | # DECODER 144 | self.vprint('COUPLING %s ENCODER ' % encoder_type, end='') 145 | if decoder_type == 'classifier': 146 | self.decoder = FinalStateClassifier( 147 | self.layer_sizes, self.FF_dropout, subnets_params, 148 | self.ENCODER_IS_BIDIRECTIONAL 149 | ) 150 | self.vprint('TO A FINAL-STATE CLASSIFIER') 151 | else: 152 | match coupling: 153 | case 'final state': 154 | self.decoder = DecoderRNN( 155 | self.layer_sizes, self.FF_dropout, self.RNN_dropout, 156 | subnets_params, self.EOS_id, max_hyp_length, decoder_type 157 | # use EOS as SOS, as in the TF1 version 158 | ) 159 | self.vprint('VIA FINAL HIDDEN STATE TO DECODER') 160 | case 'attention': 161 | key_length = layer_sizes['encoder_rnn'][-1]*( 162 | 1 + self.ENCODER_IS_BIDIRECTIONAL 163 | ) 164 | self.decoder = DecoderAttentionRNN( 165 | key_length, self.layer_sizes, self.FF_dropout, 166 | self.RNN_dropout, subnets_params, self.EOS_id, 167 | max_hyp_length, decoder_type 168 | # use EOS as SOS, as in the TF1 version 169 | ) 170 | self.vprint('VIA ATTENTION TO DECODER') 171 | case _: 172 | raise ValueError('Unrecognized decoder_type') 173 | 174 | # accumulate the loss functions over subjects and their losses 175 | self.loss_fxn_dicts = {} 176 | for subnet_params in subnets_params: 177 | subnet_id = str(subnet_params.subnet_id) 178 | self.loss_fxn_dicts[subnet_id] = {} 179 | for key in subnet_params.data_mapping: 180 | if key.endswith('targets'): 181 | data_manifest = subnet_params.data_manifests[key] 182 | self.loss_fxn_dicts[subnet_id][key] = ( 183 | get_cross_entropy_fxn(data_manifest.distribution), 184 | data_manifest.penalty_scale 185 | ) 186 | 187 | def forward(self, 
inputs, subnet_id, targets=None): 188 | ''' 189 | The inputs, targets, and natural_params have JGM "canonical ordering," 190 | 191 | (N_cases x T x N_features) and (N_cases x T) 192 | 193 | The RNN outputs have RNN ordering, 194 | 195 | (T x N_cases x N_features) 196 | 197 | and the RNN states ("similarly") have shape 198 | 199 | (N_layers*N_directions x N_cases x N_features) 200 | ''' 201 | 202 | # for storing useful outputs 203 | natural_params_dict = {} 204 | image_dict = {} 205 | 206 | # encode; project outputs; init decoder state; decode; project outputs 207 | encoder_rnn_outputs, encoder_final_states, inputs_lengths = self.encoder( 208 | inputs, subnet_id, natural_params_dict, image_dict 209 | ) 210 | if isinstance(encoder_final_states, tuple): 211 | encoder_final_states = tuple( 212 | self.reshape_context(states) for states in encoder_final_states 213 | ) 214 | else: 215 | encoder_final_states = self.reshape_context(encoder_final_states) 216 | self.decoder( 217 | encoder_rnn_outputs, encoder_final_states, targets, inputs_lengths, 218 | natural_params_dict, image_dict, 219 | ) 220 | 221 | return natural_params_dict, image_dict 222 | 223 | def print_sentences( 224 | self, most_probable_classes, decoder_targets, on_clr, N_sentences=10, 225 | PRINT_CRUDE_WER=False 226 | ): 227 | 228 | if PRINT_CRUDE_WER: 229 | accumulated_targets = [] 230 | accumulated_predictions = [] 231 | 232 | print() 233 | for iExample, (predicted_classes, target_classes) in enumerate(zip( 234 | most_probable_classes, decoder_targets 235 | )): 236 | predicted_words = class_indices_to_sequence( 237 | predicted_classes, self.decoder_targets_list, 238 | self.EOS_token, self.pad_token 239 | ) 240 | target_words = class_indices_to_sequence( 241 | target_classes, self.decoder_targets_list, 242 | self.EOS_token, self.pad_token 243 | ) 244 | 245 | # reduce cluter; don't print sentences for training data 246 | if True: # data_partition == 'validation': 247 | cprint( 248 | '{0:60} {1}'.format(target_words, predicted_words), 249 | on_color=on_clr 250 | ) 251 | 252 | if PRINT_CRUDE_WER: 253 | accumulated_targets.append(target_words.split()) 254 | accumulated_predictions.append(predicted_words.split()) 255 | 256 | # only print N_sentences sentences 257 | if iExample > N_sentences: 258 | break 259 | 260 | if PRINT_CRUDE_WER: 261 | print(' WERb: %1.3f' % np.mean( 262 | wer_vector(accumulated_targets, accumulated_predictions) 263 | )) 264 | 265 | def vprint(self, *args, **kwargs): 266 | if self.VERBOSE: 267 | print(*args, **kwargs) 268 | 269 | 270 | class EncoderRNN(nn.Module): 271 | def __init__( 272 | self, 273 | layer_sizes, 274 | FF_dropout, 275 | RNN_dropout, 276 | subnets_params, 277 | BIDIRECTIONAL, 278 | RNN_type, 279 | coupling=None, 280 | MAX_POOL=False, 281 | VERBOSE=True, 282 | ): 283 | super().__init__() 284 | 285 | # ... 
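        # layer_sizes is a dict of lists of layer widths, keyed by component;
        # an illustrative (made-up) example consistent with the keys used in
        # this module:
        #   {
        #       'encoder_embedding': [100, 100],
        #       'encoder_rnn': [400, 400],
        #       'encoder_1_projection': [150],
        #       'decoder_embedding': [150],
        #       'decoder_rnn': [400],
        #       'decoder_projection': [225],
        #   }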
286 | if len(np.unique(layer_sizes['encoder_rnn'])) > 1: 287 | raise NotImplementedError('Expected the same layer size for all layers') 288 | else: 289 | N_hidden = layer_sizes['encoder_rnn'][0] 290 | 291 | self.FF_dropout = FF_dropout 292 | self.RNN_dropout = RNN_dropout 293 | self.coupling = coupling 294 | 295 | # accumulate proprietary components (embeddings, projections) 296 | self.embeddings = nn.ModuleDict() 297 | self.projections = nn.ModuleDict() 298 | self.decimation_factors = {} 299 | for subnet_params in subnets_params: 300 | subnet_id = str(subnet_params.subnet_id) 301 | 302 | # embedding_layers 303 | self.embeddings[subnet_id] = MLConvEmbedding( 304 | subnet_params.data_manifests['encoder_inputs'].num_features, 305 | layer_sizes['encoder_embedding'], subnet_params, 306 | self.FF_dropout, MAX_POOL, VERBOSE=VERBOSE 307 | ) 308 | 309 | # decimation_factors 310 | self.decimation_factors[subnet_id] = subnet_params.decimation_factor 311 | 312 | # projections 313 | self.projections[subnet_id] = nn.ModuleDict() 314 | for key in subnet_params.data_mapping: 315 | if key.startswith('encoder') and key.endswith('targets'): 316 | 317 | # useful quantities 318 | data_manifest = subnet_params.data_manifests[key] 319 | RNN_layer = int(key.split('_')[1]) 320 | N_outputs = data_manifest.num_features 321 | Ns_hidden = layer_sizes['encoder_%i_projection' % RNN_layer] 322 | N_inputs = layer_sizes['encoder_rnn'][RNN_layer]*( 323 | BIDIRECTIONAL + 1 324 | ) 325 | 326 | # accumulate this encoder projection 327 | self.projections[subnet_id][key] = MultiLayerProjection( 328 | N_inputs, Ns_hidden, N_outputs, self.FF_dropout, 329 | ) 330 | 331 | # the RNN; you may need outputs from intermediate layers, so you have 332 | # to construct this one layer at a time 333 | self.RNNs = nn.ModuleList() 334 | RNN = getattr(nn, RNN_type) 335 | N_in = layer_sizes['encoder_embedding'][-1] 336 | for N_hidden in layer_sizes['encoder_rnn']: 337 | self.RNNs.append( 338 | RNN(N_in, N_hidden, num_layers=1, bidirectional=BIDIRECTIONAL) 339 | ) 340 | N_in = N_hidden*(1 + BIDIRECTIONAL) 341 | 342 | ############### 343 | # flatten_parameters() 344 | ############### 345 | 346 | def forward(self, inputs, subnet_id, natural_params_dict, image_dict): 347 | 348 | # get lengths of *downsampled* input sequences 349 | inputs_indices, inputs_lengths = sequences_tools( 350 | inputs[:, ::self.decimation_factors[subnet_id], :] 351 | ) 352 | 353 | # embed with temporal convolutions 354 | X = self.embeddings[subnet_id](inputs) 355 | 356 | # reverse? (a la Sutskever 2014); canonical ordering -> RNN ordering, 357 | # (T x N_cases x N_features) 358 | if self.coupling == 'final state': 359 | X = reverse_sequences(X, inputs_indices, inputs_lengths) 360 | X = X.permute(1, 0, 2) 361 | 362 | # the original TF version used dropout on the RNN *inputs* 363 | X = F.dropout( 364 | X, self.RNN_dropout, training=self.training, 365 | inplace=True 366 | ) 367 | 368 | # "pack" and pad 369 | X = nn.utils.rnn.pack_padded_sequence( 370 | X, inputs_lengths.to('cpu'), enforce_sorted=False 371 | ) 372 | 373 | # Run through RNN---layer by layer, to accumulate outputs for encoder 374 | # targeting. NB: the final_states may not be equal to the last nonpad 375 | # outputs because dropout is applied to the latter but not the former. 
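        # Shape sketch (illustrative): with a bidirectional RNN of N_hidden
        # units per direction, each element of all_outputs comes back from
        # pad_packed_sequence as (T x N_cases x 2*N_hidden), and each element
        # of all_final_states is (2 x N_cases x N_hidden) (a pair of such
        # tensors for LSTMs).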
376 | all_outputs = [] 377 | all_final_states = [] 378 | for RNN_layer in self.RNNs: 379 | X, final_states = RNN_layer(X) 380 | X, _ = nn.utils.rnn.pad_packed_sequence(X) 381 | all_outputs.append(X) 382 | all_final_states.append(final_states) 383 | X = F.dropout( 384 | X, self.RNN_dropout, training=self.training, inplace=True 385 | ) 386 | X = nn.utils.rnn.pack_padded_sequence( 387 | X, inputs_lengths.to('cpu'), enforce_sorted=False 388 | ) 389 | 390 | # pack the final states together the way PyTorch does 391 | # NB: this assumes that all RNN layers have the same dimensions 392 | if isinstance(all_final_states[0], tuple): 393 | # LSTM 394 | all_final_states = tuple( 395 | torch.stack(states).flatten(end_dim=1) 396 | for states in zip(*all_final_states) 397 | ) 398 | else: 399 | # GRU 400 | all_final_states = torch.stack(all_final_states).flatten(end_dim=1) 401 | 402 | # "project" into natural params and convert back to canonical ordering 403 | for key, projection in self.projections[subnet_id].items(): 404 | RNN_layer = int(key.split('_')[1]) 405 | natural_params_dict[key] = projection( 406 | all_outputs[RNN_layer] 407 | ).permute(1, 0, 2) 408 | 409 | return all_outputs, all_final_states, inputs_lengths 410 | 411 | 412 | class EncoderTransformer(nn.Module): 413 | def __init__( 414 | self, 415 | layer_sizes, 416 | FF_dropout, 417 | T_dropout, 418 | subnets_params, 419 | N_head=1, 420 | coupling=None, 421 | MAX_POOL=False, 422 | VERBOSE=True, 423 | ): 424 | super().__init__() 425 | 426 | # ... 427 | if len(np.unique(layer_sizes['encoder_rnn'])) > 1: 428 | raise NotImplementedError('Expected the same layer size for all layers') 429 | else: 430 | N_hidden = layer_sizes['encoder_rnn'][0] 431 | 432 | self.FF_dropout = FF_dropout 433 | self.T_dropout = T_dropout 434 | self.coupling = coupling 435 | self.N_head = N_head 436 | 437 | # accumulate proprietary components (embeddings, projections) 438 | self.embeddings = nn.ModuleDict() 439 | self.projections = nn.ModuleDict() 440 | self.decimation_factors = {} 441 | for subnet_params in subnets_params: 442 | subnet_id = str(subnet_params.subnet_id) 443 | 444 | # embedding_layers 445 | self.embeddings[subnet_id] = MLConvEmbedding( 446 | subnet_params.data_manifests['encoder_inputs'].num_features, 447 | layer_sizes['encoder_embedding'], subnet_params, 448 | self.FF_dropout, MAX_POOL, VERBOSE=VERBOSE 449 | ) 450 | 451 | # decimation_factors 452 | self.decimation_factors[subnet_id] = subnet_params.decimation_factor 453 | 454 | # projections 455 | self.projections[subnet_id] = nn.ModuleDict() 456 | for key in subnet_params.data_mapping: 457 | if key.startswith('encoder') and key.endswith('targets'): 458 | 459 | # useful quantities 460 | data_manifest = subnet_params.data_manifests[key] 461 | output_layer = int(key.split('_')[1]) 462 | N_outputs = data_manifest.num_features 463 | Ns_hidden = layer_sizes['encoder_%i_projection' % output_layer] 464 | N_inputs = layer_sizes['encoder_rnn'][output_layer] 465 | 466 | # accumulate this encoder projection 467 | self.projections[subnet_id][key] = MultiLayerProjection( 468 | N_inputs, Ns_hidden, N_outputs, self.FF_dropout, 469 | ) 470 | 471 | # from conv output dim to transformer input dim 472 | self.pretransformer = nn.Linear( 473 | layer_sizes['encoder_embedding'][-1], N_hidden 474 | ) 475 | # position encodings 476 | # max sentence length = (~20 samples/sec)(6.25 sec) \approx 128 477 | # self.positional_encoding = RotaryPositionalEmbeddings( 478 | # # N_hidden//self.N_head, 128 479 | # N_hidden, 128 480 | # ) 481 
| self.positional_encoding = PositionalEncoding( 482 | d_model=N_hidden, dropout=0, max_len=256 483 | ) 484 | encoder_layer = nn.TransformerEncoderLayer( 485 | d_model=N_hidden, nhead=self.N_head, dim_feedforward=N_hidden, 486 | dropout=self.T_dropout 487 | ) 488 | self.transformer_encoder = nn.ModuleList([ 489 | encoder_layer for _ in layer_sizes['encoder_rnn'] 490 | ]) 491 | 492 | ############### 493 | # flatten_parameters() 494 | ############### 495 | 496 | def forward(self, inputs, subnet_id, natural_params_dict, image_dict): 497 | 498 | # get lengths of *downsampled* input sequences 499 | inputs_indices, inputs_lengths = sequences_tools( 500 | inputs[:, ::self.decimation_factors[subnet_id], :] 501 | ) 502 | 503 | # embed with temporal convolutions 504 | X = self.embeddings[subnet_id](inputs) 505 | 506 | # the original TF version used dropout on the RNN *inputs* 507 | X = self.pretransformer(X) 508 | X = F.dropout( 509 | X, self.FF_dropout, training=self.training, inplace=True 510 | ) 511 | 512 | # switch out of batch-first 513 | X = X.permute(1, 0, 2) 514 | 515 | # add positional encodings 516 | # N_cases, T = X.shape[:2] 517 | # X = self.positional_encoding(X.view([N_cases, T, self.N_head, -1])) 518 | # X = self.positional_encoding(X.unsqueeze(2)).squeeze(2) 519 | X = self.positional_encoding(X) 520 | 521 | # transform 522 | # X = self.transformer_encoder(X) 523 | all_outputs = [] 524 | for encoder_layer in self.transformer_encoder: 525 | X = encoder_layer(X) 526 | all_outputs.append(X) 527 | 528 | ####### 529 | # Shouldn't this really be at the final time *of each sequence*? 530 | # return only the mean across time? 531 | # states = torch.mean(X, 0, keepdim=True) # average state 532 | states = X[-1:] # final state 533 | # states = X[inputs_lengths, torch.arange(len(inputs_lengths)), :].unsqueeze(0) 534 | ####### 535 | 536 | # "project" into natural params and convert back to canonical ordering 537 | for key, projection in self.projections[subnet_id].items(): 538 | encoder_layer = int(key.split('_')[1]) 539 | natural_params_dict[key] = projection( 540 | all_outputs[encoder_layer] 541 | ).permute(1, 0, 2) 542 | 543 | return all_outputs, states, inputs_lengths 544 | 545 | 546 | class FinalStateClassifier(nn.Module): 547 | def __init__( 548 | self, 549 | layer_sizes, 550 | FF_dropout, 551 | subnets_params, 552 | ENCODER_IS_BIDIRECTIONAL 553 | ): 554 | super().__init__() 555 | 556 | ######### 557 | # for now, assume that the decoder has no proprietary layers 558 | N_outputs = subnets_params[-1].data_manifests['decoder_targets'].num_features 559 | ######### 560 | 561 | # NB! Even though there is no decoder_rnn, len(layer_sizes['decoder_rnn']) 562 | # sets how deep into the encoder RNN the classifier looks. 563 | N_decoded_layers = len(layer_sizes['decoder_rnn']) 564 | output_layer_sizes = sum(layer_sizes['encoder_rnn'][-N_decoded_layers:]) 565 | N_in = output_layer_sizes*(1 + ENCODER_IS_BIDIRECTIONAL) 566 | 567 | # add the "projection" 568 | self.classifier_head = MultiLayerProjection( 569 | N_in, layer_sizes['decoder_projection'], N_outputs, FF_dropout, 570 | ) 571 | 572 | def forward( 573 | self, encoder_rnn_outputs, encoder_final_states, targets, inputs_lengths, 574 | natural_params_dict, image_dict 575 | ): 576 | ''' 577 | encoder_rnn_outputs, targets, inputs_lengths, and 578 | image_dict aren't necessary but are included here for consistency with 579 | DecoderAttentionRNN. 
580 | ''' 581 | 582 | # (N_layers x N_cases x N_hidden) 583 | if isinstance(encoder_final_states, tuple): 584 | # LSTM: use only hidden, not cell, states; 585 | hidden_states = encoder_final_states[0] 586 | else: 587 | hidden_states = encoder_final_states 588 | 589 | # treat activities at all layers as features on all fours w/each other 590 | # -> (N_cases x N_layers*N_hidden) 591 | hidden_states = hidden_states.permute(1, 0, 2).flatten(start_dim=1) 592 | 593 | # -> (N_cases x 1 x N_classes) 594 | natural_params_dict['decoder_targets'] = self.classifier_head( 595 | hidden_states).unsqueeze(1) 596 | 597 | ############## 598 | # targets, image_dict 599 | ############## 600 | 601 | 602 | class DecoderRNN(nn.Module): 603 | def __init__( 604 | self, 605 | layer_sizes, 606 | FF_dropout, 607 | RNN_dropout, 608 | subnets_params, 609 | SOS_id, 610 | max_hyp_length, 611 | RNN_type 612 | ): 613 | super().__init__() 614 | 615 | # these are required at run time 616 | self.FF_dropout = FF_dropout 617 | self.RNN_dropout = RNN_dropout 618 | self.SOS_id = SOS_id 619 | self.max_hyp_length = max_hyp_length 620 | 621 | # generalizing beyond this would be a lot of work 622 | if len(np.unique(layer_sizes['decoder_rnn'])) > 1: 623 | raise NotImplementedError('Expected the same layer size for all layers') 624 | else: 625 | N_hidden = layer_sizes['decoder_rnn'][0] 626 | 627 | #-------# 628 | # for now, assume that the decoder has no proprietary layers 629 | N_outputs = subnets_params[-1].data_manifests['decoder_targets'].num_features 630 | #-------# 631 | 632 | # (possibly) multi-layer embedding 633 | self.embedding = MLLinearEmbedding( 634 | N_outputs, layer_sizes['decoder_embedding'], self.FF_dropout 635 | ) 636 | 637 | ############ 638 | # You could in theory just make the decoder like the encoder: *loop* 639 | # across layers of the RNN and save all outputs. E.g., you could 640 | # imagine targeting different decoder layers.... 641 | # If len()==1, this dropout will have no effect 642 | N_in = layer_sizes['decoder_embedding'][-1] 643 | N_layers_RNN = len(layer_sizes['decoder_rnn']) 644 | self.RNN = getattr(nn, RNN_type)( 645 | N_in, N_hidden, num_layers=N_layers_RNN, dropout=RNN_dropout 646 | ) 647 | 648 | # flatten_parameters() 649 | ############ 650 | 651 | # add a "projection" 652 | self.decoder_projection = MultiLayerProjection( 653 | layer_sizes['decoder_rnn'][-1], layer_sizes['decoder_projection'], 654 | N_outputs, self.FF_dropout, 655 | ) 656 | 657 | def forward( 658 | self, encoder_rnn_outputs, encoder_final_states, targets, inputs_lengths, 659 | natural_params_dict, image_dict 660 | ): 661 | ''' 662 | encoder_rnn_outputs aren't really necessary but are included here for 663 | consistency with DecoderAttentionRNN 664 | ''' 665 | 666 | # ... 667 | states = encoder_final_states 668 | 669 | # Are we testing or training? 670 | if targets is None: 671 | # testing: use most probable prev. 
word as input; go one step at a time 672 | 673 | # encoder_rnn_outputs are in RNN ordering 674 | N_cases = encoder_rnn_outputs[-1].shape[1] 675 | 676 | # and the input to the embedding must have size (N_cases x 1) 677 | inputs = torch.full([N_cases, 1], self.SOS_id).to( 678 | encoder_rnn_outputs[-1].device 679 | ) 680 | 681 | # loop 682 | natural_params = [] 683 | for i in range(self.max_hyp_length): 684 | one_step_natural_params, states = self.forward_core(inputs, states) 685 | natural_params.append(one_step_natural_params) 686 | _, most_probable_classes = one_step_natural_params.topk(1) 687 | inputs = most_probable_classes[:, :, 0].detach() 688 | 689 | # terminate if all most_probable_classes are EOS? 690 | 691 | final_states = states 692 | natural_params = torch.cat(natural_params, dim=1) 693 | else: 694 | # ...we're training; use *shifted* targets as inputs... 695 | X = targets.roll(1, 1) 696 | X[:, 0] = self.SOS_id 697 | natural_params, final_states = self.forward_core(X, states) 698 | 699 | # update the natural_params_dict 700 | natural_params_dict['decoder_targets'] = natural_params 701 | 702 | def forward_core(self, inputs, initial_state): 703 | 704 | # embed inputs 705 | X = self.embedding(inputs) 706 | 707 | # the original TF implementation used dropout at the RNN *inputs* 708 | X = F.dropout( 709 | X, self.RNN_dropout, training=self.training, 710 | inplace=True 711 | ) 712 | 713 | # run through RNN, converting canonical to RNN ordering 714 | outputs, final_states = self.RNN(X.permute(1, 0, 2), initial_state) 715 | 716 | # "project" and convert back to canonical ordering 717 | natural_params = self.decoder_projection(outputs).permute(1, 0, 2) 718 | 719 | return natural_params, final_states 720 | 721 | 722 | class DecoderAttentionRNN(DecoderRNN): 723 | def __init__( 724 | self, 725 | key_length, 726 | layer_sizes, 727 | FF_dropout, 728 | RNN_dropout, 729 | subnets_params, 730 | SOS_id, 731 | max_hyp_length, 732 | RNN_type, 733 | N_hidden_attention=200, 734 | ): 735 | 736 | # It doesn't really make sense to have layerwise cross-attention with 737 | # variable numbers of units per encoder layer, because the context 738 | # is a weighted sum of tokens across time *and layers*. Plus PyTorch 739 | # wants multilayer RNNs to have a fixed number of hidden units across 740 | # layers anyway. So we assume this throughout. 
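        # Concretely (made-up sizes): with layer_sizes['decoder_rnn'] =
        # [400, 400], the queries to the attention are
        # (2 x 1 x N_cases x 400) and the keys are the last two encoder
        # layers' outputs, stacked as (2 x T x N_cases x key_length); see
        # BahdanauAttention.forward.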
741 | 742 | # create a DecoderRNN 743 | super().__init__( 744 | layer_sizes, FF_dropout, RNN_dropout, subnets_params, SOS_id, 745 | max_hyp_length, RNN_type, 746 | ) 747 | 748 | # add attention; Ns are necessarily the case 749 | query_length = layer_sizes['decoder_rnn'][-1] 750 | N_layers = len(layer_sizes['decoder_rnn']) 751 | self.attention = BahdanauAttention( 752 | query_length, key_length, N_hidden_attention, N_layers 753 | ) 754 | 755 | # RNN's input is [embedded_input, "context"], so we have to overwrite it 756 | self.RNN = getattr(nn, RNN_type)( 757 | input_size=self.RNN.input_size + key_length, 758 | hidden_size=query_length, 759 | num_layers=N_layers, 760 | dropout=RNN_dropout 761 | ) 762 | 763 | def forward( 764 | self, encoder_outputs, encoder_final_states, targets, inputs_lengths, 765 | natural_params_dict, image_dict 766 | ): 767 | 768 | # encoder_outputs are in RNN ordering 769 | N_cases = encoder_outputs[-1].shape[1] 770 | iMax = targets.shape[1] if targets is not None else self.max_hyp_length 771 | 772 | # use as many keys as there are decoder layers 773 | encoder_outputs = torch.stack(encoder_outputs[-self.RNN.num_layers:]) 774 | 775 | # and the input to the embedding must have size (N_cases x 1) 776 | inputs = torch.full([N_cases, 1], self.SOS_id).to(encoder_outputs.device) 777 | states = encoder_final_states 778 | natural_params = [] 779 | attn_weights = [] 780 | 781 | for i in range(iMax): 782 | one_step_natural_params, states, one_step_attn_weights = self.forward_step( 783 | inputs, states, encoder_outputs, inputs_lengths 784 | ) 785 | natural_params.append(one_step_natural_params) 786 | attn_weights.append(one_step_attn_weights) 787 | 788 | if targets is None: 789 | # testing: use most probable prev. word as input 790 | _, most_probable_classes = one_step_natural_params.topk(1) 791 | inputs = most_probable_classes[:, :, 0].detach() 792 | else: 793 | # training: use *actual* previous word as input 794 | inputs = targets[:, i, None] 795 | 796 | # (N_cases x T_out x N_out) 797 | natural_params_dict['decoder_targets'] = torch.cat(natural_params, dim=1) 798 | 799 | # just for plotting 800 | if not self.training: 801 | # (N_layers x Te x N_cases x Td) -> (N_cases x N_layers x Td x Te) 802 | attn_weights = torch.cat(attn_weights, dim=3).permute([2, 0, 3, 1]) 803 | # attn_weights /= attn_weights.amax(dim=(2, 3), keepdim=True) 804 | attn_weights /= attn_weights.amax(dim=(2), keepdim=True) 805 | 806 | # final-layer attention only 807 | attention_images = torchvision.utils.make_grid( 808 | attn_weights[:, -1:, :, :] 809 | ) 810 | 811 | # "average" over the duplicated single channel 812 | image_dict['attention'] = attention_images.mean(dim=0) 813 | 814 | def forward_step(self, inputs, states, encoder_outputs, inputs_lengths): 815 | ''' 816 | inputs: (N_cases x 1) 817 | states: (N_layers x N_cases x N_hidden_decoder) 818 | encoder_outputs: (N_layers x T x N_cases x N_hidden_encoder) 819 | ''' 820 | 821 | # get attention weights 822 | if isinstance(states, tuple): 823 | # use only hidden state, not cell state 824 | queries = states[0][:, None, :, :] 825 | else: 826 | queries = states[:, None, :, :] 827 | context, attn_weights = self.attention( 828 | queries, encoder_outputs, inputs_lengths 829 | ) 830 | 831 | # embed inputs---into size (N_cases x 1 x M) 832 | X = self.embedding(inputs) 833 | 834 | # concatenate "context" onto embedded input 835 | X = torch.cat((X, context[:, None, :]), dim=2) 836 | 837 | # the original TF implementation used dropout at the RNN *inputs* 838 | X = 
F.dropout( 839 | X, self.RNN_dropout, training=self.training, inplace=True 840 | ) 841 | 842 | # run through RNN, converting canonical to RNN ordering 843 | outputs, new_states = self.RNN(X.permute(1, 0, 2), states) 844 | 845 | # "project" and convert back to canonical ordering 846 | natural_params = self.decoder_projection(outputs).permute(1, 0, 2) 847 | 848 | # we return attn_weights only for plotting purposes 849 | return natural_params, new_states, attn_weights 850 | 851 | 852 | class BahdanauAttention(nn.Module): 853 | def __init__(self, query_length, key_length, N_hidden, N_layers): 854 | super().__init__() 855 | 856 | self.Q_list = nn.ModuleList( 857 | nn.Linear(query_length, N_hidden) for _ in range(N_layers) 858 | ) 859 | self.K_list = nn.ModuleList( 860 | nn.Linear(key_length, N_hidden) for _ in range(N_layers) 861 | ) 862 | self.V_list = nn.ModuleList( 863 | nn.Linear(N_hidden, 1) for _ in range(N_layers) 864 | ) 865 | self.N_layers = N_layers 866 | 867 | def forward(self, queries, keys, inputs_lengths): 868 | ''' 869 | queries: (N_layers x 1 x N_cases x query_length) 870 | keys: (N_layers x T x N_cases x key_length) 871 | 872 | context: (N_cases x key_length) 873 | weights: (N_layers x T x N_cases x 1) 874 | ''' 875 | 876 | # a tensor of size (N_layers x T x N_cases x 1) 877 | scores = torch.stack([ 878 | V(torch.tanh(Q(query) + K(key))) for V, Q, K, query, key in zip( 879 | self.V_list, self.Q_list, self.K_list, queries, keys 880 | ) 881 | ]) 882 | 883 | ###### 884 | # Masking doesn't actually seem to help 885 | # mask out (send to -inf) scores beyond the end of the sequences 886 | T, N_cases = scores.shape[1:3] 887 | mask = torch.arange(T, device=scores.device)[:, None] < inputs_lengths[None, :] 888 | scores = scores.masked_fill(mask[None, :, :, None] == 0, float('-inf')) 889 | ###### 890 | 891 | # -> (N_layers*T x N_cases x 1) to normalize over *both* layers and time 892 | scores = scores.view([-1, N_cases]) 893 | 894 | # normalize and compute convex combination of keys across layers/time 895 | weights = F.softmax(scores, dim=0) 896 | weights = weights.view([-1, T, N_cases, 1]) 897 | context = torch.sum(weights*keys, dim=(0, 1)) 898 | 899 | # (N_cases x key_length), plus attention weights for plotting 900 | return context, weights 901 | 902 | 903 | class MLLinearEmbedding(nn.Module): 904 | def __init__( 905 | self, 906 | N_in, 907 | layer_sizes, 908 | dropout, 909 | ): 910 | ''' 911 | Really just a MLP, but with the first layer expecting one-hot inputs 912 | 913 | ''' 914 | super().__init__() 915 | 916 | self.dropout = dropout 917 | 918 | self.layers = nn.ModuleList() 919 | for iLayer in range(len(layer_sizes)): 920 | N_out = layer_sizes[iLayer] 921 | Layer = nn.Embedding if iLayer == 0 else nn.Linear 922 | self.layers.append(Layer(N_in, N_out)) 923 | N_in = N_out 924 | 925 | def forward(self, inputs): 926 | 927 | # embed inputs 928 | X = inputs 929 | for layer in self.layers: 930 | X = F.dropout( 931 | F.relu(layer(X)), self.dropout, training=self.training, 932 | inplace=False 933 | ) 934 | 935 | return X 936 | 937 | 938 | class MLConvEmbedding(nn.Module): 939 | def __init__( 940 | self, 941 | N_in, 942 | layer_sizes, 943 | subnet_params, 944 | dropout, 945 | MAX_POOL=False, 946 | VERBOSE=True, 947 | ): 948 | super().__init__() 949 | 950 | # useful things for `forward` 951 | self.decimation_factor = subnet_params.decimation_factor 952 | self.dropout = dropout 953 | 954 | # there may be multiple layers 955 | self.layers = nn.ModuleList() 956 | 957 | # distribute decimation 
over multiple layers 958 | layer_strides = close_factors(self.decimation_factor, len(layer_sizes)) 959 | if VERBOSE: 960 | print('Temporally convolving with strides ' + repr(layer_strides)) 961 | 962 | # construct embedding network 963 | for N_out, layer_stride in zip(layer_sizes, layer_strides): 964 | self.layers.append(nn.Conv1d( 965 | N_in, N_out, layer_stride, layer_stride, 966 | bias=MAX_POOL, 967 | padding='valid', 968 | )) 969 | N_in = N_out 970 | 971 | ######### 972 | # max pool... 973 | ######### 974 | 975 | def forward(self, inputs): 976 | 977 | # canonical ordering -> conv ordering, (N_cases x N_features x T) 978 | X = inputs.permute(0, 2, 1) 979 | 980 | # In 'VALID'-style convolution, the data are not padded to accommodate 981 | # the filter, and the final (right-most) elements that don't fit a 982 | # filter are simply dropped. Here we pad by a sufficient amount to 983 | # ensure that no data are dropped. There's no danger in padding too 984 | # much because we will subsequently extract out only sequences of the 985 | # right inputs_lengths 986 | X = F.pad(X, [0, 4*self.decimation_factor]) 987 | 988 | # "embed" 989 | for layer in self.layers: 990 | X = F.dropout( 991 | layer(X), self.dropout, training=self.training, inplace=True 992 | ) 993 | 994 | # return in canonical ordering 995 | return X.permute(0, 2, 1) 996 | 997 | 998 | class MultiLayerProjection(nn.Module): 999 | def __init__( 1000 | self, 1001 | N_inputs, 1002 | Ns_hidden, 1003 | N_outputs, 1004 | FF_dropout, 1005 | ): 1006 | super().__init__() 1007 | 1008 | self.FF_dropout = FF_dropout 1009 | self.projections = nn.ModuleList() 1010 | Ns_out = Ns_hidden + [N_outputs] 1011 | N_in = N_inputs 1012 | for N_out in Ns_out: 1013 | self.projections.append(nn.Linear(N_in, N_out)) 1014 | N_in = N_out 1015 | 1016 | def forward(self, X): 1017 | ''' 1018 | arguments: 1019 | X: input tensor of size (* x N_inputs) 1020 | returns: 1021 | Y: output tensor of size (* x N_outputs) 1022 | ''' 1023 | 1024 | for projection in self.projections[:-1]: 1025 | X = F.dropout( 1026 | F.relu(projection(X)), self.FF_dropout, training=self.training, 1027 | inplace=False 1028 | ) 1029 | 1030 | ############# 1031 | # no nonlinearity on the final layer--but there is dropout (!) 1032 | # return F.dropout(self.projections[-1](X), self.FF_dropout, training=self.training) 1033 | return self.projections[-1](X) 1034 | ############# 1035 | 1036 | 1037 | def context_reshape(states, N_layers, N_directions=2): 1038 | ''' 1039 | Pytorch RNNs return the states with size 1040 | 1041 | (N_layers*N_directions x N_cases x other). 1042 | 1043 | Thus, the two different directions of a bidirectional RNN are concatenated 1044 | together along the *layers* dimension. Now, to initialize a unidirectional 1045 | RNN with 2*N_features from the final states of a bidirectional RNN with 1046 | N_features, we need to unpack the first dimension, permute, and then 1047 | flatten again: 1048 | 1049 | (N_layers x N_cases x N_directions*N_features) 1050 | 1051 | This function will also only select the last N_layers' worth of states, and 1052 | so can be used to hook up encoder and decoder RNNs of different depths. 1053 | Notice that if N_directions==1, then this is the *only* effect of this fxn. 
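    Illustrative example: a 2-layer bidirectional encoder with 256 units per
    direction and a batch of 32 hands in states of shape (4 x 32 x 256); with
    N_layers=2 and N_directions=2 this function returns (2 x 32 x 512).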
1054 | 1055 | The concatention ordering of the RNN was taken from here: 1056 | https://discuss.pytorch.org/t/ 1057 | how-can-i-know-which-part-of-h-n-of-bidirectional-rnn-is-for-backward-process/3883 1058 | ''' 1059 | 1060 | # useful sizes 1061 | _, N_cases, N_features = states.shape 1062 | 1063 | # break layers and directions into separate dimensions 1064 | states = states.reshape([-1, N_directions, N_cases, N_features]) 1065 | 1066 | # only grab the last N_layers' worth of states 1067 | states = states[-N_layers:] 1068 | 1069 | # put directions next to features and flatten 1070 | states = states.permute([0, 2, 1, 3]) 1071 | return torch.flatten(states, start_dim=2) 1072 | 1073 | 1074 | class SequenceTrainer(): 1075 | @auto_attribute(CHECK_MANIFEST=True) 1076 | def __init__( 1077 | self, 1078 | manifest, 1079 | subnets_params, 1080 | ##### 1081 | # kwargs set in the manifest 1082 | temperature=None, 1083 | EMA_decay=None, 1084 | beam_width=None, 1085 | assessment_epoch_interval=None, 1086 | tf_summaries_dir=None, 1087 | ##### 1088 | OOV_token='', 1089 | N_cases=128, 1090 | assessment_op_set={ 1091 | 'decoder_word_error_rate', 1092 | 'decoder_accuracy', 1093 | 'loss' 1094 | }, 1095 | REPORT_TRAINING_LOSS=True, 1096 | TARGETS_ARE_SEQUENCES=True, 1097 | ): 1098 | 1099 | class AssessmentTuple(MutableNamedTuple): 1100 | __slots__ = (['decoder_word_error_rates'] + list(self.assessment_op_set)) 1101 | 1102 | # create data loaders, tensorboard writers, assessments objects 1103 | data_partitions = ['training', 'validation'] 1104 | self.loaders = dict.fromkeys(data_partitions) 1105 | self.writers = dict.fromkeys(data_partitions) 1106 | self.assessments = dict.fromkeys(data_partitions) 1107 | for data_partition in data_partitions: 1108 | 1109 | # only assess on the *last* subject 1110 | params = ( 1111 | subnets_params[-1:] if data_partition == 'validation' 1112 | else subnets_params 1113 | ) 1114 | self.loaders[data_partition] = TFRecordDataLoader( 1115 | params, data_partition, N_cases, OOV_token, TARGETS_ARE_SEQUENCES, 1116 | ) 1117 | self.writers[data_partition] = SummaryWriter( 1118 | log_dir=os.path.join(self.tf_summaries_dir, data_partition) 1119 | ) 1120 | self.assessments[data_partition] = AssessmentTuple( 1121 | decoder_word_error_rates=None, 1122 | **dict.fromkeys(self.assessment_op_set) 1123 | ) 1124 | 1125 | def train_and_assess(self, N_epochs, sequence_net, device): 1126 | 1127 | ######## 1128 | # temporary hack 1129 | # self.assessment_epoch_interval = N_epochs - 1 1130 | ######## 1131 | 1132 | # init 1133 | optimizer = torch.optim.Adam(sequence_net.parameters(), lr=3e-4) 1134 | N_assessments = math.ceil(N_epochs/self.assessment_epoch_interval)+1 1135 | for assessment in self.assessments.values(): 1136 | assessment.decoder_word_error_rates = np.zeros((N_assessments)) 1137 | sequence_net.to(device) 1138 | 1139 | def batch_op_core( 1140 | device_batch, natural_params_dict, loss_fxn_dict, epoch_loss_dict 1141 | ): 1142 | 1143 | # overkill but organized this way for "elegance" 1144 | metadata_dicts = dict.fromkeys(['encoder', 'decoder']) 1145 | for coder in metadata_dicts.keys(): 1146 | 1147 | metadata_dicts[coder] = {} 1148 | if coder == 'encoder': 1149 | d = sequence_net.encoder.decimation_factors[device_batch['subnet_id']] 1150 | inds, lens = sequences_tools(device_batch['encoder_inputs'][:, ::d, :]) 1151 | metadata_dicts[coder]['decimation_factor'] = d 1152 | else: 1153 | inds, lens = sequences_tools(device_batch['decoder_targets']) 1154 | metadata_dicts[coder]['indices'] = inds 1155 
| metadata_dicts[coder]['lengths'] = lens 1156 | 1157 | # compute losses 1158 | complete_loss = 0 1159 | for key, natural_params in natural_params_dict.items(): 1160 | 1161 | # assemble the targets, their indices, and lengths 1162 | coder = key.split('_')[0] 1163 | targets = device_batch[key] 1164 | indices = metadata_dicts[coder]['indices'] 1165 | lengths = metadata_dicts[coder]['lengths'] 1166 | 1167 | # *encoder* targets are decimated and possibly reversed 1168 | if coder == 'encoder': 1169 | d = metadata_dicts[coder]['decimation_factor'] 1170 | targets = targets[:, ::d, :] 1171 | if sequence_net.coupling == 'final state': 1172 | targets = reverse_sequences(targets, indices, lengths) 1173 | 1174 | # accumulate loss 1175 | complete_loss += penalize_RNN( 1176 | natural_params, targets, indices, 1177 | *loss_fxn_dict[key], epoch_loss_dict, key 1178 | ) 1179 | 1180 | return complete_loss 1181 | 1182 | for epoch in range(N_epochs): 1183 | 1184 | # with profile( 1185 | # activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 1186 | # record_shapes=True 1187 | # ) as prof: 1188 | # with record_function("model_inference"): 1189 | self.batch_train(batch_op_core, sequence_net, optimizer, epoch, device) 1190 | # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) 1191 | 1192 | # validate 1193 | if (epoch % self.assessment_epoch_interval) == 0: 1194 | # clear output only when ready to print new validation results 1195 | clear_output(wait=True) 1196 | for data_partition in ['validation', 'training']: 1197 | self.batch_assess( 1198 | batch_op_core, sequence_net, data_partition, epoch, device 1199 | ) 1200 | 1201 | # store 1202 | self.assessments[data_partition].decoder_word_error_rates[ 1203 | epoch//self.assessment_epoch_interval 1204 | ] = self.assessments[data_partition].decoder_word_error_rate 1205 | 1206 | # for backward compatibility with unitrain 1207 | return self.assessments 1208 | 1209 | def batch_train(self, batch_op_core, net, optimizer, epoch, device): 1210 | ''' 1211 | Train for one epoch (all batches) 1212 | ''' 1213 | 1214 | net.train() 1215 | N_examples = 0 1216 | epoch_loss_dict = {} 1217 | for batch in self.loaders['training']: 1218 | 1219 | # put the data on (presumably) the GPU and pass thru network 1220 | device_batch = { 1221 | key: val.decode() if key == 'subnet_id' 1222 | else torch.tensor(val).to(device) 1223 | for key, val in batch.items() 1224 | } 1225 | natural_params_dict, image_dict = net( 1226 | device_batch['encoder_inputs'], device_batch['subnet_id'], 1227 | device_batch['decoder_targets'][:, :, 0] 1228 | ) 1229 | 1230 | # accumulate total number of examples 1231 | N_examples += device_batch['encoder_inputs'].shape[0] 1232 | 1233 | # ... 
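            # standard update: clear stale gradients, recompute the loss (a
            # penalty_scale-weighted sum over all target streams; see
            # penalize_RNN), then backprop and step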
1234 | optimizer.zero_grad() 1235 | loss = batch_op_core( 1236 | device_batch, natural_params_dict, 1237 | net.loss_fxn_dicts[device_batch['subnet_id']], epoch_loss_dict 1238 | ) 1239 | 1240 | # backprop and take a step downhill 1241 | loss.backward() 1242 | optimizer.step() 1243 | 1244 | # print the per-example loss 1245 | if self.REPORT_TRAINING_LOSS: 1246 | loss_string = ' '.join([ 1247 | '%s: %.2e' % (loss_name, loss_value/N_examples) 1248 | for loss_name, loss_value in epoch_loss_dict.items() 1249 | ]) 1250 | print('\n[ training ] epoch: %3i %s' % (epoch, loss_string), end='\t') 1251 | 1252 | def batch_assess(self, batch_op_core, net, data_partition, epoch, device): 1253 | ''' 1254 | Assess on this data_partition 1255 | ''' 1256 | 1257 | # init 1258 | net.eval() 1259 | N_examples = 0 1260 | epoch_WER = 0 1261 | epoch_loss_dict = {} 1262 | on_clr = 'on_yellow' if data_partition == 'training' else 'on_cyan' 1263 | with torch.no_grad(): 1264 | for batch in self.loaders[data_partition]: 1265 | 1266 | # put the data on (presumably) the GPU and pass thru network 1267 | device_batch = { 1268 | key: val.decode() if key == 'subnet_id' 1269 | else torch.tensor(val).to(device) 1270 | for key, val in batch.items() 1271 | } 1272 | 1273 | # put the data on (presumably) the GPU and pass thru network; 1274 | # DO NOT PASS TARGETS 1275 | natural_params_dict, image_dict = net( 1276 | device_batch['encoder_inputs'], device_batch['subnet_id'] 1277 | ) 1278 | 1279 | # accumulate total number of examples 1280 | N_examples += device_batch['encoder_inputs'].shape[0] 1281 | 1282 | # update losses in epoch_loss_dict 1283 | batch_op_core( 1284 | device_batch, natural_params_dict, 1285 | net.loss_fxn_dicts[device_batch['subnet_id']], epoch_loss_dict 1286 | ) 1287 | 1288 | # only consider the single sequence of most probable classes 1289 | _, most_probable_classes = natural_params_dict['decoder_targets'].topk(1) 1290 | if net.TARGETS_ARE_SEQUENCES: 1291 | most_probable_classes = terminate_sequences( 1292 | most_probable_classes, net.EOS_id, net.pad_id 1293 | ) 1294 | 1295 | # accumulate word error rates 1296 | WERs = get_word_error_rate( 1297 | device_batch['decoder_targets'], most_probable_classes 1298 | ) 1299 | epoch_WER += sum(WERs).item() 1300 | 1301 | # ... 1302 | net.print_sentences( 1303 | most_probable_classes, device_batch['decoder_targets'], on_clr 1304 | ) 1305 | 1306 | # just evaluate on a single batch 1307 | break 1308 | else: 1309 | raise ValueError('No %s data!' 
% data_partition) 1310 | 1311 | # report cross entropy(s) 1312 | print('[assessment] epoch: %3i ' % epoch, end='') 1313 | for loss_name, loss_value in epoch_loss_dict.items(): 1314 | 1315 | # divide cumulative errors by number of examples 1316 | per_example_loss = loss_value/N_examples 1317 | 1318 | # print to screen and write to tensorboard 1319 | print(' %s: %.2e' % (loss_name, per_example_loss), end='') 1320 | self.writers[data_partition].add_scalar( 1321 | 'summarize_%s' % loss_name, per_example_loss, epoch 1322 | ) 1323 | for image_name, image in image_dict.items(): 1324 | self.writers[data_partition].add_image( 1325 | image_name, image, dataformats='HW', 1326 | # max_outputs=16 1327 | ) 1328 | 1329 | # report WER(s) 1330 | per_example_WER = epoch_WER/N_examples 1331 | print( 1332 | ' WER: %1.3f' % per_example_WER, 1333 | ' (%s data)' % data_partition, 1334 | end='' 1335 | ) 1336 | self.writers[data_partition].add_scalar( 1337 | 'summarize_decoder_word_error_rate', per_example_WER, epoch 1338 | ) 1339 | self.writers[data_partition].flush() 1340 | 1341 | ############ 1342 | # hard-coded; generally, these might not be in the assessments 1343 | # store the assessments 1344 | self.assessments[data_partition].loss = per_example_loss 1345 | self.assessments[data_partition].decoder_word_error_rate = per_example_WER 1346 | # fake it 1347 | self.assessments[data_partition].decoder_accuracy = 1 - per_example_WER 1348 | ############ 1349 | 1350 | 1351 | class PositionalEncoding(nn.Module): 1352 | # see https://stackoverflow.com/questions/77444485/ 1353 | 1354 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 1355 | super().__init__() 1356 | self.dropout = nn.Dropout(p=dropout) 1357 | 1358 | position = torch.arange(max_len).unsqueeze(1) 1359 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 1360 | pe = torch.zeros(max_len, 1, d_model) 1361 | pe[:, 0, 0::2] = torch.sin(position * div_term) 1362 | pe[:, 0, 1::2] = torch.cos(position * div_term) 1363 | self.register_buffer('pe', pe) 1364 | 1365 | def forward(self, x): 1366 | """ 1367 | Arguments: 1368 | x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` 1369 | """ 1370 | x = x + self.pe[:x.size(0)] 1371 | return self.dropout(x) 1372 | 1373 | 1374 | def penalize_RNN( 1375 | natural_params, targets, targets_indices, loss_fxn, penalty_scale, 1376 | epoch_loss_dict, key 1377 | ): 1378 | 1379 | # compute loss 1380 | loss = loss_fxn(natural_params[targets_indices], targets[targets_indices]) 1381 | 1382 | # *accumulate* loss 1383 | CE_key = swap(key, 'cross_entropy') 1384 | if CE_key not in epoch_loss_dict: 1385 | epoch_loss_dict[CE_key] = 0 1386 | epoch_loss_dict[CE_key] += loss.item() 1387 | 1388 | return penalty_scale*loss 1389 | 1390 | 1391 | def class_indices_to_sequence(classes, targets_list, EOS_token, pad_token): 1392 | word_sequence = ''.join([targets_list[c] for c in classes]).replace( 1393 | '_', ' ').replace(pad_token, '').replace(EOS_token, '').rstrip() 1394 | 1395 | return word_sequence 1396 | 1397 | 1398 | def terminate_sequences(sequence_tensor, EOS_id, pad_id): 1399 | 1400 | # Create matrix like [[0, 1, 2, 3], [0, 1, 2, 3]] 1401 | all_inds = torch.arange(sequence_tensor.shape[1]).tile( 1402 | (sequence_tensor.shape[0], 1) 1403 | ).to(sequence_tensor.device) 1404 | 1405 | # In the worst case, no EOS_id occurs in a hypothesis; mark this explicitly 1406 | # (which allows the argmax to work). 
(Technically, this could fix a 1407 | # sentence that should have had an EOS_id at max_hyp_length but didn't.) 1408 | # NB that argmax returns the *first* instance where there are multiple. 1409 | sequence_tensor[:, -1, :] = EOS_id 1410 | final_inds = torch.argmax((sequence_tensor == 1).to(dtype=torch.int), dim=1) 1411 | 1412 | # now write pad_id into all entries of sequence_tensor beyond first EOS_id 1413 | sequence_tensor = torch.where( 1414 | all_inds[:, :, None] > final_inds[:, None, :], pad_id, sequence_tensor 1415 | ) 1416 | 1417 | return sequence_tensor 1418 | 1419 | 1420 | def get_cross_entropy_fxn(distribution): 1421 | ''' 1422 | Get the cross entropy function appropriate to this distribution. NB that 1423 | these functions *sum* across all examples. 1424 | 1425 | Cross entropies are computed in bits. 1426 | 1427 | However, the Gaussian cross entropy is off by an additive constant (the log 1428 | normalizer) 1429 | 1430 | Also NB: no lambda functions for compatibility with pickle 1431 | ''' 1432 | 1433 | if distribution == 'Gaussian': 1434 | return Gaussian_cross_entropy 1435 | elif distribution == 'categorical': 1436 | return categorical_cross_entropy 1437 | else: 1438 | raise NotImplementedError('%s cross entropy not yet implemented!' % distribution) 1439 | 1440 | 1441 | def Gaussian_cross_entropy(natural_params, targets): 1442 | # in TF1, you averaged across features 1443 | return np.log2(np.e)*nn.MSELoss(reduction='sum')(natural_params, targets) 1444 | 1445 | 1446 | def categorical_cross_entropy(natural_params, targets): 1447 | # Generic ints don't work so convert to int64. Also, your code expects 1448 | # targets to have a final dimension of 1; CrossEntropyLoss does not. 1449 | return nn.CrossEntropyLoss(reduction='sum')( 1450 | natural_params, targets[:, 0].to(torch.int64) 1451 | ) 1452 | 1453 | 1454 | def swap(key, string): 1455 | # In SequenceNetworks, keys are often constructed from the data_manifest 1456 | # key by swapping out the word 'targets' for some other string. This is 1457 | # just a shortcut for that process. 1458 | return key.replace('targets', string) 1459 | -------------------------------------------------------------------------------- /machine_learning/torch_helpers.py: -------------------------------------------------------------------------------- 1 | # standard libraries 2 | import pdb 3 | 4 | # third-party libraries 5 | import numpy as np 6 | import torch 7 | 8 | 9 | # for compatibility with tensorflow 10 | def parse_protobuf_seq2seq_example(example_proto, data_manifests): 11 | ''' 12 | NB that all sequence_matrices are [T x N_features] 13 | ''' 14 | 15 | example_dict = dict.fromkeys(data_manifests.keys()) 16 | 17 | for key, data_manifest in data_manifests.items(): 18 | # ..."unflatten" the sequence of (possibly length-1) vectors and xform 19 | 20 | sequence_vector = np.array(example_proto[data_manifest.sequence_type]) 21 | sequence_matrix = sequence_vector.reshape( 22 | [-1, data_manifest.num_features_raw] 23 | ) 24 | sequence_matrix = data_manifest.transform(sequence_matrix) 25 | if not (type(sequence_matrix) is torch.Tensor): 26 | sequence_matrix = torch.tensor(sequence_matrix) 27 | example_dict[key] = sequence_matrix 28 | 29 | return example_dict 30 | 31 | 32 | def fancy_indexing(X, extract_inds, axis=0): 33 | ############## 34 | # FIX ME 35 | ############## 36 | ''' 37 | Select indices along axis `axis` from tensor `X` with indices in rank-1 38 | tensor `extract_inds`. 
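E.g. (illustrative): for a 3-d tensor X and axis=1, the intended result is
the same as numpy-style fancy indexing, X[:, extract_inds, :]. NB: per the
FIX ME above, this helper still calls TensorFlow (tf, common_layers), which
this module does not import, so it is not runnable as written.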
39 | ''' 40 | # Expand me to deal with all the interesting numpy cases 41 | 42 | # get the indices for gathering/scattering 43 | X_shape = tf.shape(input=X) 44 | make_grid_coords = [tf.range(X_shape[i]) for i in range(axis)] 45 | make_grid = tf.meshgrid(*make_grid_coords, extract_inds, indexing='ij') 46 | vectorize_grid = [tf.reshape(grid_coords, [-1]) for grid_coords in make_grid] 47 | matricize_grid = tf.stack(vectorize_grid, axis=1) 48 | 49 | # gather them up and reshape 50 | new_shape = [ 51 | X_shape[i] if i != axis else tf.shape(input=extract_inds)[0] 52 | for i in range(len(common_layers.shape_list(X))) 53 | ] 54 | return tf.reshape(tf.gather_nd(X, matricize_grid), new_shape) 55 | 56 | 57 | def string_seq_to_index_seq( 58 | sequence_matrix, unique_targets_list, eos_id_list, OOV_index 59 | ): 60 | ''' 61 | Convert a sequence of strings (sequence_matrix) into a sequence of indices 62 | into the unique_targets_list, where 63 | indices[:, 1] = target_ids 64 | indices[:, 0] = sequence positions. 65 | NB that the sequence_matrix must have size (N x 1) as opposed to (N, ). 66 | Strings not found in the unique_targets_list are converted to the OOV_id. 67 | 68 | As a final step, the sequence is appended with eos_id_list, which typically 69 | will hold either a single id or none at all (e.g. for single-word data). 70 | Note that the returned sequence is, like the input, (N x 1). 71 | ''' 72 | 73 | # naively get the indices for all elements of this sequence 74 | unique_bytes_list = [t.encode('utf-8') for t in unique_targets_list] 75 | sequence_positions, target_ids = torch.where(torch.tensor( 76 | np.array(unique_bytes_list)[None, :] == sequence_matrix 77 | )) 78 | 79 | # If a sequence element is missing (because that target wasn't in the 80 | # unique_targets_list), replace it with the OOV_index. 81 | target_shape = sequence_matrix.shape[0:1] 82 | all_OOV_vector = torch.full(target_shape, OOV_index) 83 | index_sequence = all_OOV_vector.scatter_(0, sequence_positions, target_ids) 84 | 85 | # append the EOS_id 86 | index_sequence = torch.cat((index_sequence, torch.tensor(eos_id_list))) 87 | 88 | # return a "matrix" 89 | return index_sequence[:, None] 90 | 91 | 92 | def get_word_error_rate( 93 | references, hypotheses, m_cost=0, s_cost=1, i_cost=1, d_cost=1, cost_fxn=None 94 | ): 95 | """ 96 | Vectorized calculation of word error rate with Levenshtein distance. 97 | 98 | Works only for iterables up to 254 elements (uint8). 99 | O(nm) time and space complexity. 100 | 101 | Input arguments: 102 | -------- 103 | references : tensor of right-zero-padded rows of integers 104 | (N_sentences x max_length x 1) 105 | hypotheses : tensor of right-zero-padded rows of integers 106 | (N_sentences x max_length x 1) 107 | 108 | (Tensors are expected because of where this function gets used....) 109 | 110 | Returns: 111 | -------- 112 | numpy array (vector) of len(references) (== len(hypotheses)) 113 | 114 | Revised: 11/20/23 115 | re-wrote for PyTorch 116 | Created: 02/12/18 117 | by JGM 118 | Inspired by scalar version found here: 119 | https://martin-thoma.com/word-error-rate-calculation/ 120 | """ 121 | 122 | # ... 123 | device = references.device 124 | N_sentences = references.shape[0] 125 | if hypotheses.shape[0] != N_sentences: 126 | raise ValueError('no. of hypotheses must equal no. of references') 127 | 128 | # ... 
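# illustrative (hypothetical) example, with words coded as integers:
#   reference [5, 7, 9] vs. hypothesis [5, 8, 9, 2]
#   -> 1 substitution (7->8) + 1 insertion (2) = 2 edits; WER = 2/3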
129 | _, references_lengths = sequences_tools(references) 130 | _, hypotheses_lengths = sequences_tools(hypotheses) 131 | 132 | N_ref_max = max(references_lengths) 133 | N_hyp_max = max(hypotheses_lengths) 134 | d_max = max(N_ref_max, N_hyp_max) 135 | 136 | # initialize 137 | if cost_fxn is None: 138 | def cost_fxn(ref, hyp): 139 | return m_cost, s_cost, i_cost, d_cost 140 | 141 | distance_tensor = torch.zeros( 142 | (N_sentences, N_ref_max + 1, N_hyp_max + 1), dtype=torch.uint8 143 | ).to(device) 144 | distance_tensor[:, 0] = torch.arange(N_hyp_max + 1)[None, :] 145 | distance_tensor[:, :, 0] = torch.arange(N_ref_max + 1)[None, :] 146 | else: 147 | distance_tensor = torch.full( 148 | (N_sentences, N_ref_max + 1, N_hyp_max + 1), torch.inf, 149 | ).to(device) 150 | distance_tensor[:, 0, 0] = 0 151 | 152 | # compute minimum edit distance 153 | for i_ref in range(N_ref_max): 154 | for i_hyp in range(N_hyp_max): 155 | m_cost, s_cost, i_cost, d_cost = cost_fxn( 156 | references[:, i_ref], hypotheses[:, i_hyp]) 157 | match = m_cost + distance_tensor[:, i_ref, i_hyp] + d_max*( 158 | references[:, i_ref, 0] != hypotheses[:, i_hyp, 0]) 159 | substitution = s_cost + distance_tensor[:, i_ref, i_hyp] 160 | insertion = i_cost + distance_tensor[:, i_ref + 1, i_hyp] 161 | deletion = d_cost + distance_tensor[:, i_ref, i_hyp + 1] 162 | distance_tensor[:, i_ref+1, i_hyp+1], _ = torch.min(torch.stack( 163 | [match, substitution, insertion, deletion] 164 | ), dim=0) 165 | 166 | distances = distance_tensor[ 167 | (torch.arange(N_sentences), references_lengths, hypotheses_lengths) 168 | ] 169 | 170 | return distances/references_lengths 171 | 172 | 173 | def sequences_tools(sequences, as_tuple=True): 174 | ''' 175 | Input arguments: 176 | -------- 177 | sequences: 178 | tensor of size (N_cases x max_sequence_length x Ndims) 179 | 180 | Returns: 181 | -------- 182 | sequences_indices: 183 | (sum_i^Nsequences seq_len(i) x 2) tensor listing all the non-zero 184 | indices in the tensor of sequences 185 | sequences_lengths: 186 | int32 tensor of size (N_cases) 187 | ''' 188 | 189 | # binary_mask is a (N_cases x max_sequence_length) matrix with 0s 190 | # wherever all elements of an input token are simultaneously zero, 191 | # and 1s elsewhere. Since all elements of an input token are 192 | # simultaneously zero only in the zero-padding, the 1s will be 193 | # contiguous, and the number of them in each row will be the 194 | # corresponding true sequence length 195 | max_vals, _ = torch.max(torch.abs(sequences), axis=2) 196 | binary_mask = torch.sign(max_vals) 197 | sequences_lengths = torch.sum(binary_mask, axis=1, dtype=torch.int32) 198 | sequences_indices = torch.nonzero(binary_mask, as_tuple=as_tuple) 199 | 200 | return sequences_indices, sequences_lengths 201 | 202 | 203 | def reverse_sequences(sequences, indices, lengths): 204 | # You can get indices and lengths from sequences_tools, but NB that in 205 | # practice you often won't run this on sequences but on some precursor 206 | # thereof. 
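# Illustrative example: a row [a, b, c, <pad>] with true length 3 comes back
# as [c, b, a, <pad>]; only the non-padding steps listed in `indices` are
# permuted, and the padding positions are left untouched.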
207 | 208 | reversed_indices = ( 209 | indices[0], torch.repeat_interleave(lengths, lengths) - indices[1] - 1, 210 | ) 211 | new_sequences = sequences.clone() 212 | new_sequences[indices] = sequences[reversed_indices] 213 | 214 | return new_sequences 215 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import sys 3 | 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | install_requires=['numpy', 'scipy'] 8 | python_subversion = int(sys.version.split('.')[1]) 9 | if python_subversion > 8: 10 | install_requires = [ 11 | 'numpy', 'scipy', 'tensorflow', 'tensorflow-probability', 12 | 'tensorflow-datasets', 'torch', 'torchvision', 'seaborn', 13 | 'matplotlib', 'tensorflow-datasets' 14 | ] 15 | else: 16 | install_requires = [ 17 | 'numpy==1.22.4', 'scipy', 18 | ##### 19 | # these ought to be enforced by tensor2tensor, but they're not 20 | 'kfac==0.2.0', 21 | 'dopamine_rl==2.0.5', 22 | 'gym==0.12.4', 23 | 'absl-py==0.10.0', 24 | ##### 25 | 'tensorflow-probability==0.7', 26 | 'tensor2tensor==1.15.7', 27 | 'tfmpl', 28 | 'protobuf==3.20.3', 29 | # 'tensorflow-gpu==1.15.3' the cpu version will also work 30 | ] 31 | 32 | 33 | setuptools.setup( 34 | name="machine_learning", 35 | version="0.7.0", 36 | author="J.G. Makin", 37 | author_email="jgmakin@gmail.com", 38 | description="a collection of packages for ML projects, written in the Python APIs for Tensorflow and Pytorch", 39 | long_description=long_description, 40 | long_description_content_type="text/markdown", 41 | url="https://github.com/jgmakin/machine_learning", 42 | packages=setuptools.find_packages(), 43 | install_requires=install_requires, 44 | classifiers=[ 45 | "Development Status :: 3 - Alpha", 46 | "Intended Audience :: Science/Research", 47 | "Topic :: Scientific/Engineering", 48 | "Programming Language :: Python :: 3", 49 | # "License :: OSI Approved :: MIT License", 50 | "Operating System :: OS Independent", 51 | ], 52 | ) 53 | --------------------------------------------------------------------------------