├── .gitignore ├── images ├── goodness.png ├── supervisedFF.png └── unsupervisedFF.png ├── old_model_based_implementation ├── images │ ├── summary_table.png │ ├── goodness_function.png │ └── dense_architecture.png ├── README.md └── ffobjects.py ├── examples ├── all_digits_as_positive │ └── data.py └── five_digits_as_positive │ └── data.py ├── ffobjects ├── utils.py ├── trainmgr.py └── ffobjects.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/goodness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmwkwok/forward_forward_algorithm/HEAD/images/goodness.png -------------------------------------------------------------------------------- /images/supervisedFF.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmwkwok/forward_forward_algorithm/HEAD/images/supervisedFF.png -------------------------------------------------------------------------------- /images/unsupervisedFF.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmwkwok/forward_forward_algorithm/HEAD/images/unsupervisedFF.png -------------------------------------------------------------------------------- /old_model_based_implementation/images/summary_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmwkwok/forward_forward_algorithm/HEAD/old_model_based_implementation/images/summary_table.png -------------------------------------------------------------------------------- /old_model_based_implementation/images/goodness_function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmwkwok/forward_forward_algorithm/HEAD/old_model_based_implementation/images/goodness_function.png -------------------------------------------------------------------------------- /old_model_based_implementation/images/dense_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmwkwok/forward_forward_algorithm/HEAD/old_model_based_implementation/images/dense_architecture.png -------------------------------------------------------------------------------- /examples/all_digits_as_positive/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ffobjects import BaseFFLayer 4 | 5 | NUM_CLASS = 10 6 | 7 | def preprocess(Xy): 8 | X, y = Xy 9 | X = X.astype(np.float32) / 255. 
10 | y = y.astype(np.int32) 11 | return X, y 12 | 13 | def _create_dataset(X, y, y_ff, seed, batch_size): 14 | ds = tf.data.Dataset.from_tensor_slices((X, y, y_ff)) 15 | if seed is not None: 16 | ds = ds.shuffle(4096, seed, True) 17 | return ds.batch(batch_size, True, tf.data.AUTOTUNE, True)\ 18 | .prefetch(tf.data.AUTOTUNE) 19 | 20 | def create_mnist_datasets(seed=10, batch_size=128, is_supervised_ff=False): 21 | (train_X, train_y), (valid_X, valid_y) = \ 22 | map(preprocess, tf.keras.datasets.mnist.load_data()) 23 | 24 | train_y_neg = train_y if not is_supervised_ff else gen_fake_y(train_y) 25 | valid_y_neg = valid_y if not is_supervised_ff else gen_fake_y(valid_y) 26 | 27 | train_ff_pos = _create_dataset( 28 | train_X, train_y, 29 | np.ones_like(train_y), seed, batch_size) 30 | train_ff_neg = _create_dataset( 31 | train_X, train_y_neg, 32 | np.zeros_like(train_y), seed, batch_size) 33 | valid_ff_pos = _create_dataset( 34 | valid_X, valid_y, 35 | np.ones_like(valid_y), seed, batch_size) 36 | valid_ff_neg = _create_dataset( 37 | valid_X, valid_y_neg, 38 | np.zeros_like(valid_y), seed, batch_size) 39 | 40 | datasets = [ 41 | (BaseFFLayer.TASK_TRAIN_POS, train_ff_pos), 42 | (BaseFFLayer.TASK_TRAIN_NEG, train_ff_neg), 43 | (BaseFFLayer.TASK_EVAL_POS, valid_ff_pos), 44 | ] 45 | 46 | if is_supervised_ff: 47 | datasets.extend([ 48 | (BaseFFLayer.TASK_EVAL_DUPED_POS, valid_ff_pos), 49 | ]) 50 | 51 | return datasets 52 | 53 | def gen_fake_y(y_true, num_class=NUM_CLASS): 54 | n = len(y_true) 55 | y_zero = np.zeros(n, dtype=np.int32) 56 | all_classes = np.expand_dims(np.arange(num_class), 0) 57 | 58 | a = np.expand_dims(y_true, 1) != all_classes 59 | b = np.expand_dims(y_zero, 1) + all_classes 60 | c = b[a].reshape((-1, num_class-1)) 61 | d = np.random.randint(0, num_class-1, size=n) 62 | y_fake = c[np.arange(n), d] 63 | 64 | return y_fake -------------------------------------------------------------------------------- /examples/five_digits_as_positive/data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../ffobjects/') 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from ffobjects import BaseFFLayer 7 | 8 | def preprocess(Xy): 9 | ''' 10 | Preprocessing. This example takes digits smaller than 5 as positive 11 | data, and the rest as negative data. 12 | Args: 13 | Xy: a tuple of `numpy.array` of data and targets. 14 | 15 | Returns: 16 | A positive `tuple` and a negative `tuple`. Each `tuple` has 17 | two `numpy.array` for data and targets respectively. 18 | ''' 19 | X, y = Xy 20 | X = X.astype(np.float32) / 255. 21 | y = y.astype(np.int32) 22 | return (X[y< 5], y[y< 5]), (X[y>=5], y[y>=5]) 23 | 24 | def _create_dataset(X, y, y_ff, seed, batch_size): 25 | ''' 26 | Args: 27 | X, y: `numpy.array` for data and targets. 28 | y_ff: `numpy.array`. It is `np.ones_like(...)` for positive data 29 | and `np.zeros_like(...)` for negative data. 30 | seed: When `None`, no shuffling is done. When `int`, reshuffling 31 | is done. 32 | batch_size: an `int` for the batch size. 33 | 34 | Returns: 35 | a `tf.data.Dataset`. 36 | ''' 37 | ds = tf.data.Dataset.from_tensor_slices((X, y, y_ff)) 38 | if seed is not None: 39 | ds = ds.shuffle(4096, seed, True) 40 | return ds.batch(batch_size, True, tf.data.AUTOTUNE, True)\ 41 | .prefetch(tf.data.AUTOTUNE) 42 | 43 | def create_mnist_datasets(seed=10, batch_size=128, include_duped=False): 44 | ''' 45 | Args: 46 | batch_size: an `int` for batch size. 47 | include_duped: `bool`. 
Set to `True` for supervised-wise FF 48 | training; `False` for unsupervised-wise FF training. 49 | 50 | Returns: 51 | When `include_duped` is `False`, returns a list of three 52 | datasets for the tasks of `TASK_TRAIN_POS`, `TASK_TRAIN_NEG` 53 | and `TASK_EVAL_POS` respectively. When it is `True`, returns 54 | a list of four datasets by adding an extra one for 55 | `TASK_EVAL_DUPED_POS` which is needed to evaluate 56 | supervised-wise FF layers. 57 | ''' 58 | 59 | ((train_pos_X, train_pos_y), (train_neg_X, train_neg_y)), \ 60 | ((valid_pos_X, valid_pos_y), (valid_neg_X, valid_neg_y)) =\ 61 | map(preprocess, tf.keras.datasets.mnist.load_data()) 62 | 63 | rng = np.random.default_rng(seed=seed) 64 | train_neg_y = rng.integers(0, 5, size=len(train_neg_y)) 65 | 66 | train_ff_pos = _create_dataset( 67 | train_pos_X, train_pos_y, np.ones_like(train_pos_y), seed, batch_size) 68 | train_ff_neg = _create_dataset( 69 | train_neg_X, train_neg_y, np.zeros_like(train_neg_y), seed, batch_size) 70 | eval_ff_train = _create_dataset( 71 | train_pos_X, train_pos_y, np.ones_like(train_pos_y), seed, batch_size) 72 | eval_ff_valid = _create_dataset( 73 | valid_pos_X, valid_pos_y, np.ones_like(valid_pos_y), seed, batch_size) 74 | 75 | datasets = [ 76 | # task, dataset name, dataset 77 | (BaseFFLayer.TASK_TRAIN_POS, train_ff_pos), 78 | (BaseFFLayer.TASK_TRAIN_NEG, train_ff_neg), 79 | # (BaseFFLayer.TASK_EVAL, eval_ff_train), 80 | (BaseFFLayer.TASK_EVAL_POS, eval_ff_valid), 81 | ] 82 | 83 | if include_duped: 84 | datasets.extend([ 85 | # (BaseFFLayer.TASK_EVAL_DUPED_POS, eval_ff_train), 86 | (BaseFFLayer.TASK_EVAL_DUPED_POS, eval_ff_valid), 87 | ]) 88 | 89 | return datasets -------------------------------------------------------------------------------- /ffobjects/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import tensorflow as tf 4 | from matplotlib import pyplot as plt 5 | 6 | def set_seed(seed=100): 7 | np.random.seed(seed) 8 | random.seed(seed) 9 | tf.random.set_seed(seed) 10 | 11 | class Plotter: 12 | def __init__(self, plots_per_row, figsize): 13 | self.axes = [] 14 | self.figsize = figsize 15 | self.plots_per_row = plots_per_row 16 | self.gen_axes() 17 | 18 | def get_ax(self): 19 | if len(self.axes) == 0: 20 | self.close() 21 | self.gen_axes() 22 | return self.axes.pop(0) 23 | 24 | def gen_axes(self): 25 | fig, axes = plt.subplots(1, self.plots_per_row, figsize=self.figsize) 26 | self.fig = fig 27 | self.axes.extend(axes) 28 | 29 | def close(self): 30 | plt.tight_layout() 31 | plt.show() 32 | 33 | def plot_training_curves(train_mgr): 34 | for _hist in train_mgr.history: 35 | metriced_layers = [l for l in _hist if l != 'trainable_layers'] 36 | 37 | plotter = Plotter(5, (20, 3)) 38 | plotter.fig.suptitle('trainable_layers:' + str(_hist['trainable_layers'])) 39 | for layer in metriced_layers: 40 | ax = plotter.get_ax() 41 | ax.set_title(layer) 42 | ax.set_xlabel('epoch') 43 | ax.set_ylabel('metric') 44 | ax.plot(_hist[layer]) 45 | plotter.close() 46 | 47 | # Routes related for FFRoutedDense 48 | def _ones_partial(indices, length): 49 | return tf.reduce_sum(tf.one_hot(indices, length), axis=0, keepdims=True) 50 | 51 | def _route_classes_to_units(classes, num_classes, units, num_units): 52 | a = _ones_partial(classes, num_classes) 53 | b = _ones_partial(units, num_units) 54 | return tf.transpose(a) @ b 55 | 56 | def _get_route(num_classes, num_units, classes_to_units_list): 57 | ''' 58 | Args: 59 | num_classes: `int`. 
total number of classes 60 | num_units: `int`. total number of FFDense units 61 | classes_to_units_list: `list` of `tuple` `(classes, units)`. 62 | `classes` is a `Tensor` of classes' numbers that are to be 63 | routed to `units`, which is a `Tensor` of units' numbers. 64 | e.g. (tf.range(0, 5), tf.range(0, 50)) maps classes 0-4 65 | to units 0-49. 66 | Returns: 67 | `Tensor` of shape (`num_classes`, `num_units`) representing 68 | which class is routed to which unit. 69 | ''' 70 | return tf.math.add_n([ 71 | _route_classes_to_units(classes, num_classes, units, num_units) 72 | for classes, units in classes_to_units_list]) 73 | 74 | def get_routes(num_classes, num_units, num_routes, seed=1, 75 | split_mode=None): 76 | ''' 77 | Given `num_classes` and `num_units`, generate `num_routes` routes 78 | based on the criteria set by `classes_to_units_list`. 79 | `num_routes` usually equals the number of hidden layers that need 80 | a route. 81 | `split_mode` is one of `['NoSplitting', 'RandomSplitting', 82 | 'SameSplitting']` 83 | ''' 84 | assert split_mode in ['NoSplitting', 'RandomSplitting', 'SameSplitting'] 85 | 86 | classes = tf.range(num_classes) 87 | classes = tf.random.shuffle(classes, seed=seed) 88 | 89 | for i in range(num_routes): 90 | if split_mode == 'NoSplitting': 91 | classes_to_units_list = [ 92 | (classes, tf.range(0, num_units)), 93 | ] 94 | else: 95 | classes_to_units_list = [ 96 | (classes[:5], tf.range(0, num_units//2)), 97 | (classes[5:], tf.range(num_units//2, num_units)), 98 | ] 99 | yield _get_route(num_classes, num_units, classes_to_units_list) 100 | 101 | if split_mode == 'RandomSplitting': 102 | classes = tf.random.shuffle(classes, seed=seed) 103 | -------------------------------------------------------------------------------- /old_model_based_implementation/README.md: -------------------------------------------------------------------------------- 1 | # Forward Forward algorithm in Tensorflow (development paused) 2 | 3 | Paper: [Geoffrey Hinton. The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf) 4 | 5 | ### 0. Background 6 | 1. Implemented examples of the supervised-wise forward-forward (FF) algorithm (paper section 3.2) and unsupervised FF (section 3.3), based on my understanding. 7 | 8 | 2. In the backprop (BP) algorithm, we do a forward pass through all layers, remembering many intermediate results that are used in the backward pass to update the layers' weights. In FF, we do two forward passes on all **hidden layers**, one with positive data and one with negative data. In the positive pass, after the data passes through a hidden layer, the layer performs a gradient descent step with the objective of minimizing the binary cross-entropy loss of a goodness function. Each sample is assumed positive (`y_true=1`), and the layer's activities are aggregated into a single goodness value used as `y_pred`. After the positive samples have passed through all the layers, the negative pass follows. As in the positive pass, each layer performs a separate gradient descent step on the negative data, except that all samples are assumed negative (`y_true=0`). A softmax Dense layer can be appended when the model is built; it does not perform FF training, but instead performs regular gradient descent on only the positive samples with the samples' original labels. 9 | 10 | ![goodness function](./images/goodness_function.png) 11 | 12 | 13 | 3. 
Each layer is optimized so that the goodness of the positive samples is close to 1 and that of the negative samples is close to 0. The goodness function suggested in the paper is the sum of squared activity values minus a threshold; the same goodness function is implemented here (a minimal sketch of this objective is included at the end of this README). 14 | 15 | 4. Used MNIST data (60000 training samples + 10000 test samples). Instead of predicting 10 different digits, this repo predicts only 5 (digits zero to four). Digits five to nine are reserved as negative samples. This is not how the paper did it; the paper used all digits as positive and created negatives by image augmentation. 16 | 17 | 5. In my implementation, each trainable layer has its own metric function, loss function and optimizer. 18 | 19 | 6. [**IMPORTANT**] In principle, the FF algorithm can save a lot of memory because nothing outside a layer needs to be remembered to update that layer's weights. However, memory saving is NOT my interest in this algorithm; my interest is how it works and what it can do. Therefore, my implementation does NOT realize that memory-saving capability. I am using a gradient tape over the whole model, so it saves as many intermediate results as the backprop algorithm would. However, I do NOT use anything from other layers to update a layer. In short, gradients of other layers are there, but I do not use them. This is FF, but it is not the ultimate, memory-saving version of FF. 20 | 21 | 22 | ### 1. Unsupervised-wise vs. supervised-wise FF 23 | 24 | In supervised-wise FF training, the label of a digit is one-hot encoded (e.g. `[0., 0., 1., 0., 0.]` stands for label `2`) and overlaid on the first 5 pixels of the image. At prediction, two approaches are possible and implemented: (1) overlay a "default" vector (`[0.2, 0.2, 0.2, 0.2, 0.2]`) and see which label the softmax layer predicts, or (2) copy an image 5 times, overlay a different one-hot-encoded label on each copy, pass all 5 of them to the model, and see which one has the highest accumulated goodness value. 25 | 26 | In unsupervised-wise training, the image is unchanged, and we rely on the softmax layer for class prediction. 27 | 28 | ### 2. Model description 29 | 30 | ![dense_architecture](./images/dense_architecture.png) 31 | 32 | Each hidden layer is trained with the FF algorithm. The normalized activities are concatenated and fed to a trainable softmax Dense layer. The unnormalized activities are also concatenated, and the untrainable goodness function is applied to them. 33 | 34 | ### 3. Results 35 | 36 | ![summary_table](./images/summary_table.png) 37 | Source: examples.ipynb 38 | 39 | - Consistent with the paper, BP (backprop) does better than FF (forward-forward), even with fewer epochs. 40 | - It is reasonable that the "accuracy by goodness softmax" is very poor with "unsupervised" data; however, it is interesting that it can still reach 41.5%. 41 | - "Accuracy by goodness softmax" is quite sensitive to initialization. 42 | - No hyperparameters were tuned for best-performing models; they are just for demo. 43 | - The known major difference between this and the paper's implementation is that my layers and my training sets are smaller. 44 | - Refer to examples.ipynb for the performance curves on the validation set. 
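For concreteness, below is a minimal, illustrative sketch of the per-layer objective described in points 2 and 3 above. It mirrors the `goodness` and `FFLoss` helpers in `ffobjects/ffobjects.py`; the threshold, batch size and layer width are made-up values for demonstration only.

``` python
import tensorflow as tf

def goodness(h, threshold=0.):
    # Sum of squared activities minus a threshold, computed per sample.
    return tf.reduce_sum(h**2, axis=-1) - threshold

# Per-layer FF objective: binary cross-entropy on the goodness logit.
# y_true is 1 for every sample in the positive pass and 0 in the negative pass.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

h = tf.random.normal((128, 100))            # activities of one hidden layer
loss_pos = bce(tf.ones(128), goodness(h))   # positive pass
loss_neg = bce(tf.zeros(128), goodness(h))  # negative pass
```

Each hidden layer runs its own gradient descent step on `loss_pos` (positive pass) and `loss_neg` (negative pass) with respect to its own weights only.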
45 | 46 | 47 | -------------------------------------------------------------------------------- /ffobjects/trainmgr.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ffobjects import BaseFFLayer 3 | 4 | class TrainMgr: 5 | def __init__(self, layers, train_seq): 6 | self.layers = layers 7 | self.train_seq = train_seq 8 | self.history = [] 9 | self.tf_train_fns = dict() 10 | self.metric_monitor_buffer = dict() 11 | self._get_all_metrics() 12 | ''' 13 | A training manager for the training process and metrics monitor. 14 | When calls the `ff_train()` method, it runs 3 layers of loop. 15 | The outermost loops over `trainable_layers_list` in which each 16 | element contains the a list of trainable layers and `epochs`. 17 | The middle loop iterates through the epochs. The innermost takes 18 | one dataset out at a time and carry the task associated to the 19 | dataset. 20 | Since different tasks requires different operation from the 21 | same layer, to exploit tensorflow graph while avoid retracing, 22 | the training manager keeps a dictionary of these graphed 23 | functions so that one graph serves one configuration of tasks, 24 | and the stored graphs may be reuse for the same config of tasks. 25 | 26 | Args: 27 | layers: a `dict` of FFLayers and / or tensorflow layers. 28 | train_seq: a python function that connects the `layers` up. 29 | ''' 30 | 31 | def ff_train(self, datasets, trainable_layers_list, show_metrics_max=[]): 32 | ''' 33 | Training the layers following the `train_seq`. 34 | 35 | Args: 36 | datasets: a `list` of `(task, tf.data.Dataset)`. `task` can 37 | take any one of `TASK_TRANSFORM`, `TASK_TRAIN_POS` 38 | `TASK_TRAIN_NEG`, `TASK_EVAL_POS`, `TASK_EVAL_DUPED_POS` 39 | defined in class `BaseFFLayer`. 40 | trainable_layers_list: a `list` of 41 | `[[layer_1, layer_2, ...], epochs]`. Each listing 42 | element describes which layers are trainable for 43 | `epochs` round of training. The layer name `layer_1` 44 | should reference back to the name in the `layers` 45 | dictionary that has been passed in when instaniating 46 | this object. 47 | show_metrics_max: a `list` of `str`. Listed layer's maximum 48 | metric value among epochs will be shown in the training 49 | progress bar. The layer name should reference back to 50 | the name in the `layers` dictionary that has been passed 51 | in when instaniating this object. 52 | 53 | Returns: 54 | self. 
55 | ''' 56 | ff_layers = {n: l for n, l in self.layers.items() 57 | if isinstance(l, BaseFFLayer)} 58 | 59 | for trainable_layers, epochs in trainable_layers_list: 60 | self._init_metric_monitor( 61 | show_metrics_max, trainable_layers, epochs) 62 | for epoch in range(epochs): 63 | self._reset_metric_objects() 64 | for task, dataset in datasets: 65 | for name, layer in ff_layers.items(): 66 | layer.ff_set_task( 67 | task if name in trainable_layers else\ 68 | BaseFFLayer.TASK_TRANSFORM) 69 | key = (task, *trainable_layers, str(dataset.element_spec)) 70 | signature = tf.data.DatasetSpec(dataset.element_spec) 71 | self.tf_train_fns\ 72 | .setdefault(key, self._get_new_fn(signature))(dataset) 73 | self._update_metrics(epoch) 74 | self._save_history_buffer() 75 | return self 76 | 77 | def _get_new_fn(self, dataset_signature): 78 | @tf.function(input_signature=[dataset_signature]) 79 | def _train_fn(dataset): 80 | for X, y, y_ff in dataset: 81 | self.train_seq(X, y, y_ff) 82 | return _train_fn 83 | 84 | # Utilities: Metric related 85 | def _get_all_metrics(self): 86 | self._metrics = {} 87 | for name, layer in self.layers.items(): 88 | if isinstance(layer, BaseFFLayer): 89 | if isinstance(layer.ff_metric, tf.keras.metrics.Metric): 90 | self._metrics[name] = layer.ff_metric 91 | if isinstance(layer.ff_metric_duped, tf.keras.metrics.Metric): 92 | self._metrics[f'{name}_duped'] = layer.ff_metric_duped 93 | 94 | def _init_metric_monitor(self, show_metrics_max, trainable_layers, epochs): 95 | self._monitoring_metrics = { 96 | name: metric for name, metric in self._metrics.items() \ 97 | if name.rstrip('_duped') in trainable_layers} 98 | 99 | self._best_metric_buffer = { 100 | f'best_{name}': -9999999 101 | for name, v in self._monitoring_metrics.items() \ 102 | if name.rstrip('_duped') in show_metrics_max} 103 | 104 | self._hist_buffer = dict( 105 | trainable_layers=trainable_layers, 106 | **{n: [] for n in self._monitoring_metrics}) 107 | 108 | pbar_names = list(self._monitoring_metrics) +\ 109 | list(self._best_metric_buffer) 110 | self._pbar = tf.keras.utils.Progbar(epochs, 111 | stateful_metrics=pbar_names) 112 | 113 | def _reset_metric_objects(self): 114 | [metric.reset_state() for metric in self._monitoring_metrics.values()] 115 | 116 | def _update_metrics(self, epoch): 117 | pbar_metric = [] 118 | for name, metric in self._monitoring_metrics.items(): 119 | v = metric.result() 120 | self._hist_buffer[name].append(v) 121 | pbar_metric.append((name, v)) 122 | for name, v in self._best_metric_buffer.items(): 123 | new_v = max(v, self._hist_buffer[name[5:]][-1]) 124 | self._best_metric_buffer[name] = new_v 125 | pbar_metric.append((name, self._best_metric_buffer[name])) 126 | self._pbar.update(epoch+1, pbar_metric) 127 | 128 | def _save_history_buffer(self): 129 | temp = {k: v if k not in self._monitoring_metrics or len(v) == 0 else\ 130 | tf.concat(v, 0).numpy() 131 | for k, v in self._hist_buffer.items() } 132 | self.history.append(temp) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Forward Forward algorithm in Tensorflow (Developing) 2 | 3 | Paper: [Geoffrey Hinton. The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf) 4 | 5 | ## Give up thinking it as a Model 6 | I have rethought and concluded that building FF-based layers into a Model is not flexible enough for researching new ideas. 
Therefore, I gave up the concept of Model that I had adopted in my [last implementation](https://github.com/rmwkwok/forward_forward_algorithm/tree/main/old_model_based_implementation), and re-implemented this work with only FFLayers (and no FFModel). An FFLayer class (e.g. `FFDense`) inherits from `BaseFFLayer`, which wraps a Tensorflow layer (e.g. `Dense`). 7 | 8 | ## The BaseFFLayer class 9 | The `BaseFFLayer` enables a created FFLayer to perform these tasks: 10 | 11 | 1. `TASK_TRANSFORM`. Perform the tensorflow layer's transformation; 12 | 2. `TASK_TRAIN_POS`. Positive pass training; 13 | 3. `TASK_TRAIN_NEG`. Negative pass training; 14 | 4. `TASK_EVAL_POS` / `TASK_EVAL_NEG`. Evaluation on positive / negative data; 15 | 5. `TASK_EVAL_DUPED_POS`. Evaluation specialized for a supervised-wise FF's goodness layer. 16 | 17 | ## The TrainMgr class, and an example of unsupervised-FF training 18 | Since there is no Model, a `TrainMgr` is built to template the training loops and carry out metric monitoring. It accepts a `dict` of layers and a training sequence function that connects the layers. For example, the following describes an unsupervised-wise FF sequence: 19 | 20 | ``` python 21 | layers = dict( 22 | # Utilities 23 | input = tf.keras.layers.InputLayer(name='input'), 24 | concat = tf.keras.layers.Concatenate(name='concat'), 25 | flatten = tf.keras.layers.Flatten(name='flatten'), 26 | preNorm = FFPreNorm(name='preNorm'), 27 | 28 | # FF layers 29 | b1 = FFDense(units=units, optimizer=Adam(0.0001), th_pos=.1, th_neg=.1, name=f'b1'), 30 | b2 = FFDense(units=units, optimizer=Adam(0.0001), th_pos=.1, th_neg=.1, name=f'b2'), 31 | b3 = FFDense(units=units, optimizer=Adam(0.0001), th_pos=.1, th_neg=.1, name=f'b3'), 32 | b4 = FFDense(units=units, optimizer=Adam(0.0001), th_pos=.1, th_neg=.1, name=f'b4'), 33 | 34 | # Classifiers 35 | softmax = FFSoftmax(units=NUM_CLASS, optimizer=Adam(0.001), name=f'softmax'), 36 | ) 37 | 38 | def train_seq(X, y, y_ff): 39 | x = layers['input'](X) 40 | x = layers['flatten'](x) 41 | b1a = layers['b1'].ff_do_task(x, y_ff) 42 | b1n = layers['preNorm'](b1a) 43 | b2a = layers['b2'].ff_do_task(b1n, y_ff) 44 | b2n = layers['preNorm'](b2a) 45 | b3a = layers['b3'].ff_do_task(b2n, y_ff) 46 | b3n = layers['preNorm'](b3a) 47 | b4a = layers['b4'].ff_do_task(b3n, y_ff) 48 | b4n = layers['preNorm'](b4a) 49 | softmax_x = layers['concat']([b2n, b3n, b4n]) 50 | y_pred = layers['softmax'].ff_do_task(softmax_x, y) 51 | ``` 52 | 53 | Pictorially: 54 | 55 | ![unsupervisedFF](./images/unsupervisedFF.png) 56 | 57 | ## How FF training works 58 | 59 | In each epoch of unsupervised-FF training, three datasets pass through the sequence. 60 | 1. First comes the positive data. After a positive mini-batch passes through the first `FFDense`, a gradient descent step updates that `FFDense`. The transformed mini-batch then passes through the next `FFDense`, another gradient descent step updates the second `FFDense`, and the same sequence of operations is repeated for the other `FFDense` layers and the `FFSoftmax`. 61 | 2. Then comes the negative pass, which goes through only the `FFDense` layers and performs gradient descent on them. 62 | 3. Last comes an evaluation dataset of only positive data. 63 | 64 | The three datasets each require the layers to perform a different set of tasks (among those listed above). Therefore, three different tensorflow graphs can be built from the same `train_seq` python function demonstrated above. To avoid graph retracing, the `TrainMgr` stores the built graphs so that they can be reused from epoch to epoch. Putting the pieces together, training looks like the sketch below. 
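This is a minimal usage sketch, not a verbatim excerpt from the repo: it assumes the `layers` dict and `train_seq` above are already defined in scope, that the `ffobjects/` folder is importable (e.g. via `sys.path`, as the examples do), and it borrows `create_mnist_datasets` from `examples/all_digits_as_positive/data.py`. The layer names, epoch count and monitored metric are illustrative.

``` python
import sys
sys.path.append('./ffobjects/')  # illustrative; the examples append '../../ffobjects/'

from trainmgr import TrainMgr
from utils import set_seed, plot_training_curves

set_seed(10)
datasets = create_mnist_datasets(seed=10, batch_size=128)  # list of (task, tf.data.Dataset)

train_mgr = TrainMgr(layers, train_seq)
train_mgr.ff_train(
    datasets,
    # Train the four FFDense layers and the softmax for 60 epochs,
    # showing the softmax's best metric value in the progress bar.
    trainable_layers_list=[[['b1', 'b2', 'b3', 'b4', 'softmax'], 60]],
    show_metrics_max=['softmax'],
)
plot_training_curves(train_mgr)  # one row of metric curves per training stage
```

After `ff_train` returns, `train_mgr.history` holds the recorded metric values for each training stage.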
65 | 66 | ## Unsupervised-wise FF vs. Supervised-wise FF 67 | 68 | My understanding is that an unsupervised-wise-FF-trained layer does not use any label information, whereas a supervised-wise-FF-trained layer does. Although the softmax that is appended as the last step of the unsupervised-wise FF algorithm uses the labels, it is not trained FF-wise, so the algorithm still counts as unsupervised-wise FF. 69 | 70 | ## Supervised-wise FF flow chart 71 | 72 | ![supervisedFF](./images/supervisedFF.png) 73 | An extra `FFOverlay` layer is added to the top of the sequence; it overlays each sample's one-hot encoded label onto the sample. At evaluation (`TASK_EVAL_DUPED_POS`), the `FFOverlay` replicates a sample into 5 copies, each overlaid with one of the 5 one-hot encoded labels. 74 | 75 | An `FFGoodness` layer replaces the `FFSoftmax`; it receives some of the `FFDense` layers' outputs and computes a goodness score. 76 | 77 | ## The goodness formula 78 | 79 | ![goodness](./images/goodness.png) 80 | 81 | Although the paper used this formula as an example for positive data, I used the same formula for negative data as well. 82 | 83 | ## Result 2 (Updated on 26 Feb 2023) 84 | ### The data 85 | 86 | This time, all digits were used as positive data, instead of just five of them as in Result 1. However, unlike the paper (again), I did not use hybrid digits as negative data for unsupervised-wise FF. Instead, I split the digits into two sets (A & B) and divided each layer into two halves. The first half took set A as positive data and set B as negative, whereas the second half did the opposite. Three digit-splitting schemes were explored: (for both supervised and unsupervised-wise FF) "random splitting" and "same splitting", and (only for supervised-wise FF) "no splitting". 87 | 88 | ### Benchmark and summary 89 | 90 | 1. unsupervised-wise FF: 91 | - 97.1% @ `units = 2000` and Random splitting 92 | 93 | 2. supervised-wise FF: 94 | - 94.7% @ `units = 2000` and Random splitting 95 | - 95.1% @ `units = 2000` and No splitting 96 | 97 | (No hyperparameter tuning was done) 98 | 99 | For more explanations and discussions, refer to [this article](https://medium.com/@rmwkwok/some-forward-forward-algorithm-experiments-3a9d6f9503b6?source=friends_link&sk=34556e26da24aaa7a7aa5499eff5a993). 100 | 101 | ## Result 1 (Updated on 14 Feb 2023) 102 | ### The data 103 | 104 | Unlike the paper, I used digits 0, 1, 2, 3, and 4 from the MNIST dataset as the positive-pass data and the rest of the digits as the negative-pass data. At training, both positive- and negative-pass data are used. At evaluation, only a separate set of positive-pass data is used. 105 | 106 | ### Benchmark and summary 107 | 108 | The basic supervised and unsupervised results use the following settings: 109 | 110 | 1. `units = 100` (number of units in each hidden layer) 111 | 2. `batch_size = 128` 112 | 3. data shuffling enabled 113 | 4. train all layers in every epoch 114 | 5. epochs = 200 115 | 116 | Their validation accuracies are: 117 | 118 | 1. unsupervised-wise FF: 119 | - 96.64% @ basic setting 120 | - 98.44% @ changing to `units = 500` and `batch_size = 512` 121 | - 99.06% @ changing to `units = 2000` and `batch_size = 512` 122 | 2. 
supervised-wise FF: 123 | - 76.41% @ basic setting 124 | - 95.76% @ changing to `units = 500` and `batch_size = 512` 125 | - 97.42% @ changing to `units = 2000` and `batch_size = 512` 126 | 127 | A similar setting with backprop can reach over 99% easily within 20 epochs. 128 | 129 | For other results, observations, validation curves, and the layers' details, refer to examples/five_digits_as_positive/main.ipynb. 130 | 131 | 132 | -------------------------------------------------------------------------------- /old_model_based_implementation/ffobjects.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | ''' 4 | This work attempts to reproduce the forward-forward algorithm in this 5 | paper: 6 | Author: Geoffrey Hinton. 7 | Title: The Forward-Forward Algorithm: Some Preliminary Investigations 8 | Link: https://www.cs.toronto.edu/~hinton/FFA13.pdf 9 | ''' 10 | 11 | class FFConstants: 12 | POS = 'pos' 13 | NEG = 'neg' 14 | 15 | def FFLayer(layer_class, **kwargs): 16 | ''' 17 | Creating a layer inheriting the `layer_class`, and carries the 18 | layer's own metric, loss and optimizer objects. If the loss or the 19 | optimizer is not defined, the layer won't be trained. `do_ff` is 20 | needed to specify whether the layer performs a forward-forward 21 | training or a single forward training. 22 | 23 | Example: 24 | This is a FF Dense layer. `do_ff`, `optimizer`, and `loss_fn` 25 | are required for it to be a trainable FF layer. 26 | 27 | ``` 28 | x = FFLayer(tf.keras.layers.Dense, units=32, activation='relu', 29 | do_ff=True, optimizer=Adam(0.0001), 30 | loss_fn=FFLoss(threshold=0.), 31 | ) 32 | ``` 33 | 34 | For a trainable non-FF layer, only the `optimizer` and `loss_fn` 35 | are required. 36 | 37 | Args: 38 | layer_class: A class inheriting `tf.keras.layers.Layer` 39 | **kwargs: arguments passed to the `layer_class` 40 | 41 | do_ff: A `bool` indicating whether this layer performs 42 | forward-forward training (True) or forward training (False). 43 | A softmax layer usually does forward training. 44 | 45 | optimizer: A `tf.keras.optimizers.Optimizer` object 46 | 47 | loss_fn: A `tf.keras.losses.Loss` object 48 | 49 | metric: A `tf.keras.metrics.Meric` object 50 | 51 | is_goodness_softmax: Set it to `True` if the layer serves to 52 | predict a class based on goodness, `False` otherwise. 
53 | 54 | report_metric_pos: A `bool` indicating whether metric evaluation 55 | result for the positive pass should be reported during 56 | training 57 | 58 | report_metric_neg: A `bool` indicating whether metric evaluation 59 | result for the negative pass should be reported during 60 | training 61 | 62 | Returns: 63 | A `Layer` object inheriting the `layer_class` which can be used 64 | to build tensorflow model 65 | ''' 66 | 67 | class Layer(layer_class): 68 | def __init__(self, do_ff=False, optimizer=None, loss_fn=None, 69 | metric=None, is_goodness_softmax=False, 70 | report_metric_pos=False, report_metric_neg=False, 71 | **kwargs): 72 | super().__init__(**kwargs) 73 | self.ff_do_ff = do_ff 74 | self.ff_metric = metric 75 | self.ff_loss_fn = loss_fn 76 | self.ff_optimizer = optimizer 77 | self.ff_is_goodness_softmax = is_goodness_softmax 78 | self.ff_report_metric = { 79 | FFConstants.POS: report_metric_pos, 80 | FFConstants.NEG: report_metric_neg, 81 | } 82 | 83 | def ff_reset_metric(self): 84 | if self.ff_metric: 85 | self.ff_metric.reset_state() 86 | 87 | def ff_get_metric_results(self, results_dict): 88 | if self.ff_metric: 89 | results_dict[self.name] = self.ff_metric.result().numpy() 90 | 91 | Layer.__name__ = layer_class.__name__ 92 | 93 | layer_object = Layer(**kwargs) 94 | 95 | if isinstance(layer_object, tf.keras.layers.InputLayer): 96 | # Copied from https://github.com/keras-team/keras/blob/e6784e4302c7b8cd116b74a784f4b78d60e83c26/keras/engine/input_layer.py#L442-L446 97 | outputs = layer_object._inbound_nodes[0].outputs 98 | if isinstance(outputs, list) and len(outputs) == 1: 99 | return outputs[0] 100 | else: 101 | return outputs 102 | else: 103 | return layer_object 104 | 105 | 106 | class FFModel(tf.keras.Model): 107 | ''' 108 | Create a FFModel inheriting `tf.keras.Model`. 109 | 110 | Example: 111 | 112 | This example creates a 3-layer NN where the hidden layers are FF 113 | trainable, whereas the output layer is trainable but not FF. All 114 | trainable layers, namely `y1, y2, y4`, and all evaluated layers, 115 | namely `y3, y4` are all added to the output when calling `FFModel`. 116 | 117 | ``` 118 | x0 = FFLayer(tf.keras.layers.InputLayer, input_shape=(100, )) 119 | 120 | y1 = FFLayer(tf.keras.layers.Dense, units=32, activation='relu', 121 | do_ff=True, optimizer=tf.keras.optimizers.Adam(0.0001), 122 | loss_fn=FFLoss(threshold=0.), 123 | name='dense 1')(x0) 124 | 125 | y2 = FFLayer(tf.keras.layers.Dense, units=16, activation='relu', 126 | do_ff=True, optimizer=tf.keras.optimizers.Adam(0.0001), 127 | loss_fn=FFLoss(threshold=0.), 128 | name='dense 2')(y1) 129 | 130 | y3 = FFLayer(tf.keras.layers.Concatenate, 131 | metric=tf.keras.metrics.SparseCategoricalAccuracy(), 132 | is_goodness_softmax=True, 133 | name='goodness softmax')([y1, y2]) 134 | 135 | y4 = FFLayer(tf.keras.layers.Dense, units=10, activation='linear', 136 | do_ff=False, 137 | optimizer=tf.keras.optimizers.Adam(0.0001), 138 | loss_fn=tf.keras.losses.SparseCategoricalCrossentropy( 139 | from_logits=True), 140 | metric=tf.keras.metrics.SparseCategoricalAccuracy(), 141 | name='dense softmax' 142 | )(y2) 143 | 144 | model = FFModel(x0, [y1, y2, y3, y4]) 145 | ``` 146 | 147 | To train the model, call `.ff_train(...)` instead of `.fit(...)`. 148 | 149 | ``` 150 | model.ff_train(ds_train, ds_valid_for_eval, epochs=10, eval_every=5) 151 | ``` 152 | ''' 153 | def __init__(self, *args, **kwargs): 154 | ''' 155 | Building a `FFModel` inheriting `tf.keras.Model`. 
156 | 157 | Args: 158 | *args, **kwargs: Passed into `tf.keras.Model`. 159 | 160 | Returns: 161 | A `FFModel`. 162 | ''' 163 | super().__init__(*args, **kwargs) 164 | self.ff_layers = [self.get_layer(n) for n in self.output_names] 165 | 166 | def ff_reset_all_metrics(self): 167 | for layer in self.ff_layers: 168 | layer.ff_reset_metric() 169 | 170 | def ff_get_all_metric_results(self, results_dict): 171 | for layer in self.ff_layers: 172 | layer.ff_get_metric_results(results_dict) 173 | 174 | def ff_print_record(self, record): 175 | ''' 176 | At the end of an epoch, print metric results. 177 | ''' 178 | epoch = record['epoch'] 179 | 180 | string = '' 181 | for layer in self.ff_layers: 182 | for pn in [FFConstants.POS, FFConstants.NEG]: 183 | 184 | temp = [] 185 | for tv in ['train', 'valid']: 186 | if layer.name in record[tv][pn] and\ 187 | layer.ff_report_metric[pn]: 188 | x = record[tv][pn][layer.name] 189 | temp.append(f'{x:.6f}') 190 | 191 | if len(temp): 192 | temp = ' '.join(temp) 193 | string = f'{string} | {layer.name}/{pn} {temp}' 194 | 195 | print(f'epoch {epoch: 5d}{string}') 196 | 197 | def ff_train(self, ds_train, ds_valid_for_eval, epochs, eval_every=5): 198 | ''' 199 | train the model. Use this method instead of `.fit(...)` for 200 | forward-forward algorithm. Use `.fit(...)` for backprop 201 | algorithm. 202 | 203 | Args: 204 | ds_train: a tuple of of two training datasets. The first is 205 | for the positive pass, and the second the negative pass 206 | ds_valid_for_eval: a list of tuples. Each tuple has two 207 | evaluation datasets. The first is for the positive pass, 208 | and the second the negative pass 209 | epochs: int. Number of training epochs 210 | eval_every: int. Evaluate once every N epochs. Which layer's 211 | evaluation will be printed is controlled by the 212 | the layer's ff_report_metric_pos and 213 | ff_report_metric_neg parameters 214 | 215 | Returns: 216 | history: a history of all evaluation results, reported or 217 | not. 218 | ''' 219 | history = [] 220 | _passes = [FFConstants.POS, FFConstants.NEG] 221 | 222 | do_evaluate = lambda e: e % eval_every == 0 or e == epochs-1 223 | 224 | for epoch in range(epochs): 225 | 226 | # Gradient descent 227 | for _pass, dataset in zip(_passes, ds_train): 228 | for X, y_true in dataset: 229 | self._ff_gradient_descent(X, y_true, _pass) 230 | 231 | # Evaluation 232 | if not do_evaluate(epoch): 233 | continue 234 | 235 | record = {'epoch': epoch, 236 | 'train': {FFConstants.POS: {}, FFConstants.NEG: {}, }, 237 | 'valid': {FFConstants.POS: {}, FFConstants.NEG: {}, },} 238 | 239 | for tv, ds in ds_valid_for_eval: 240 | for _pass, dataset in zip(_passes, ds): 241 | self.ff_reset_all_metrics() 242 | for X, y_true in dataset: 243 | self._ff_evaluate(X, y_true, _pass) 244 | self.ff_get_all_metric_results(record[tv][_pass]) 245 | 246 | history.append(record) 247 | self.ff_print_record(record) 248 | 249 | return history 250 | 251 | @tf.function 252 | def _ff_convert_label(self, y_true, layer, _pass): 253 | ''' 254 | In a forward-forward algorithm layer, the label is always 1 in a 255 | positive pass, and 0 in a negative pass. 
256 | ''' 257 | if layer.ff_do_ff: 258 | if _pass == FFConstants.POS: 259 | return y_true * 0 + 1 # always 1 260 | elif _pass == FFConstants.NEG: 261 | return y_true * 0 # always 0 262 | else: 263 | return y_true 264 | 265 | @tf.function 266 | def _ff_gradient_descent(self, X, y_true, _pass): 267 | with tf.GradientTape(persistent=True) as tape: 268 | losses = [] 269 | for layer, y_pred in zip(self.ff_layers, self(X, training=True)): 270 | if layer.ff_loss_fn and\ 271 | (layer.ff_do_ff or _pass == FFConstants.POS): 272 | yt = self._ff_convert_label(y_true, layer, _pass) 273 | losses.append((layer, layer.ff_loss_fn(yt, y_pred))) 274 | 275 | for layer, loss in losses: 276 | if layer.ff_optimizer: 277 | grads = tape.gradient(loss, layer.trainable_weights) 278 | layer.ff_optimizer.apply_gradients( 279 | zip(grads, layer.trainable_weights)) 280 | 281 | del tape 282 | 283 | @tf.function 284 | def _ff_evaluate(self, X, y_true, _pass): 285 | for name, y_pred in zip(self.output_names, self(X)): 286 | layer = self.get_layer(name) 287 | if layer.ff_do_ff or _pass == FFConstants.POS: 288 | if layer.ff_metric: 289 | if layer.ff_is_goodness_softmax: 290 | yt = y_true 291 | else: 292 | yt = self._ff_convert_label(y_true, layer, _pass) 293 | 294 | layer.ff_metric.update_state(yt, y_pred) 295 | -------------------------------------------------------------------------------- /ffobjects/ffobjects.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # Utils 4 | def goodness(x, threshold=0.): 5 | return tf.reduce_sum(x**2, axis=-1) - tf.cast(threshold, tf.float32) 6 | 7 | def preNorm(X): 8 | axis = tf.range(tf.rank(X))[1:] 9 | norm = tf.math.sqrt(tf.reduce_sum(X**2, axis=axis, keepdims=True)) 10 | return X/(norm + 1e-7) 11 | 12 | # Loss and Metric For FF-trained layers (e.g. FFDense). 13 | # A softmax layer is not trained FF-wise, so it does not use them. 14 | class FFLoss(tf.keras.losses.BinaryCrossentropy): 15 | def __init__(self, threshold, **kwargs): 16 | super().__init__(from_logits=True, **kwargs) 17 | self.threshold = threshold 18 | 19 | def __call__(self, y_true, y_pred): 20 | y_pred = goodness(y_pred, self.threshold) 21 | return super().__call__(y_true, y_pred) 22 | 23 | class FFMetric(tf.keras.metrics.BinaryCrossentropy): 24 | def __init__(self, **kwargs): 25 | super().__init__(from_logits=True, **kwargs) 26 | 27 | def update_state(self, y_true, y_pred): 28 | y_pred = goodness(y_pred) 29 | return super().update_state(y_true, y_pred) 30 | 31 | # FFLayers 32 | # The base class comes first, followed by FFLayers inheriting the 33 | # base class and a tf.keras.layers.Layer. 34 | class BaseFFLayer: 35 | TASK_TRANSFORM = 'TASK_TRANSFORM' 36 | TASK_TRAIN_POS = 'TASK_TRAIN_POS' 37 | TASK_TRAIN_NEG = 'TASK_TRAIN_NEG' 38 | TASK_EVAL_POS = 'TASK_EVAL_POS' 39 | TASK_EVAL_NEG = 'TASK_EVAL_NEG' 40 | TASK_EVAL_DUPED_POS = 'TASK_EVAL_DUPED_POS' 41 | 42 | def __init__( 43 | self, 44 | tfl, 45 | optimizer=None, 46 | metric=None, 47 | metric_duped=None, 48 | loss_pos=None, 49 | loss_neg=None, 50 | **kwargs 51 | ): 52 | ''' 53 | 6 tasks are predefined. They accept `X` and `y_true` as input, 54 | and produces `y_pred` as output: 55 | - transform: e.g. tf.keras.layers.Dense(X) 56 | - train_pos: optimize on positive pass data 57 | - train_neg: optimize on negative pass data 58 | - eval_pos: evaluate layer 59 | - eval_neg: evaluate layer 60 | - eval_duped_pos: evaluate layer while expecting 61 | "duplicated" data. 
At evaluation, data is duplicated 62 | before gets overlayed with different labels is passed 63 | through. This is for supervised-wise FF training. 64 | 65 | The `ff_do_task()` calls the default task which is settable by 66 | calling the `ff_set_task()` method. This arrangement is for the 67 | sake of building different tensorflow graphs with even the same 68 | python function that calls the `ff_do_task()` 69 | 70 | Args: 71 | tfl: a `tf.keras.layers.Layer` object. 72 | optimizer: `tf.keras.optimizers.Optimizer` object. Used in 73 | `TASK_TRAIN_POS` and `TASK_TRAIN_NEG`. 74 | metric: `tf.keras.metrics.Metric` object. Used in 75 | `TASK_EVAL_POS`. 76 | metric_duped: `tf.keras.metrics.Metric` object. Used in 77 | `TASK_EVAL_DUPED_POS`. 78 | loss_pos: `tf.keras.losses.Loss` object. Used in 79 | `TASK_TRAIN_POS`. 80 | loss_neg: `tf.keras.losses.Loss` object. Used in 81 | `TASK_TRAIN_NEG`. 82 | **kwargs: passed to the inherited `tf.keras.layers.Layer` 83 | object by the FFLayer that inherits this class. 84 | ''' 85 | self.tfl = tfl 86 | self.ff_opt = optimizer 87 | self.ff_metric = metric 88 | self.ff_metric_duped = metric_duped 89 | self.ff_loss_pos = loss_pos 90 | self.ff_loss_neg = loss_neg 91 | self.ff_task_fn = self.ff_task_transform 92 | self.tasks = { 93 | self.TASK_TRANSFORM: self.ff_task_transform, 94 | self.TASK_TRAIN_POS: self.ff_task_train_pos, 95 | self.TASK_TRAIN_NEG: self.ff_task_train_neg, 96 | self.TASK_EVAL_POS: self.ff_task_eval_pos, 97 | self.TASK_EVAL_NEG: self.ff_task_eval_neg, 98 | self.TASK_EVAL_DUPED_POS: self.ff_task_eval_duped_pos, 99 | } 100 | 101 | def ff_set_task(self, task): 102 | ''' 103 | Set the default task. 104 | ''' 105 | self.ff_task_fn = self.tasks[task] 106 | 107 | def ff_do_task(self, X, y_true=None): 108 | ''' 109 | Calls the default task. 
110 | ''' 111 | return self.ff_task_fn(X, y_true) 112 | 113 | def ff_task_transform(self, X, y_true=None): 114 | return self(X) 115 | 116 | def ff_task_train_pos(self, X, y_true=None): 117 | return self(X) 118 | 119 | def ff_task_train_neg(self, X, y_true=None): 120 | return self(X) 121 | 122 | def ff_task_eval_pos(self, X, y_true=None): 123 | return self(X) 124 | 125 | def ff_task_eval_neg(self, X, y_true=None): 126 | return self(X) 127 | 128 | def ff_task_eval_duped_pos(self, X, y_true=None): 129 | return self(X) 130 | 131 | def __call__(self, *args, **kwargs): 132 | return self.tfl(*args, **kwargs) 133 | 134 | @property 135 | def trainable_weights(self): 136 | return self.tfl.trainable_weights 137 | 138 | # FFLayers inheriting the base class and a tf.keras.layers.Layer 139 | class FFDense(BaseFFLayer): 140 | def __init__(self, th_pos, th_neg, optimizer, **tfl_kwargs): 141 | super().__init__( 142 | tfl=tf.keras.layers.Dense(activation='relu', **tfl_kwargs), 143 | optimizer=optimizer, 144 | metric=FFMetric(), 145 | loss_pos=FFLoss(th_pos), 146 | loss_neg=FFLoss(th_neg), 147 | ) 148 | 149 | def ff_task_train_pos(self, X, y_true): 150 | with tf.GradientTape() as tape: 151 | y_pred = self(X) 152 | loss = self.ff_loss_pos(y_true, y_pred) 153 | grads = tape.gradient(loss, self.trainable_weights) 154 | self.ff_opt.apply_gradients(zip(grads, self.trainable_weights)) 155 | return y_pred 156 | 157 | def ff_task_train_neg(self, X, y_true): 158 | with tf.GradientTape() as tape: 159 | y_pred = self(X) 160 | loss = self.ff_loss_neg(y_true, y_pred) 161 | grads = tape.gradient(loss, self.trainable_weights) 162 | self.ff_opt.apply_gradients(zip(grads, self.trainable_weights)) 163 | return y_pred 164 | 165 | def ff_task_eval_pos(self, X, y_true): 166 | y_pred = self(X) 167 | self.ff_metric.update_state(y_true, y_pred) 168 | return y_pred 169 | 170 | def ff_task_eval_neg(self, X, y_true): 171 | y_pred = self(X) 172 | self.ff_metric.update_state(y_true, y_pred) 173 | return y_pred 174 | 175 | class FFSoftmax(BaseFFLayer): 176 | def __init__(self, optimizer, **tfl_kwargs): 177 | super().__init__( 178 | tfl=tf.keras.layers.Dense(activation='linear', **tfl_kwargs), 179 | optimizer=optimizer, 180 | metric=tf.keras.metrics.SparseCategoricalAccuracy(), 181 | loss_pos=tf.keras.losses.SparseCategoricalCrossentropy(True), 182 | ) 183 | 184 | def ff_task_train_pos(self, X, y_true): 185 | with tf.GradientTape() as tape: 186 | y_pred = self(X) 187 | loss = self.ff_loss_pos(y_true, y_pred) 188 | grads = tape.gradient(loss, self.trainable_weights) 189 | self.ff_opt.apply_gradients(zip(grads, self.trainable_weights)) 190 | return y_pred 191 | 192 | def ff_task_eval_pos(self, X, y_true): 193 | y_pred = self(X) 194 | self.ff_metric.update_state(y_true, y_pred) 195 | return y_pred 196 | 197 | class FFGoodness(BaseFFLayer): 198 | def __init__(self, **tfl_kwargs): 199 | super().__init__( 200 | tfl=tf.keras.layers.Lambda(goodness, **tfl_kwargs), 201 | metric_duped=tf.keras.metrics.SparseCategoricalAccuracy(), 202 | ) 203 | 204 | def ff_task_eval_duped_pos(self, X, y_true): 205 | m = tf.shape(y_true)[0] 206 | y_pred = self(X) 207 | y_pred = tf.reshape(y_pred, (m, -1)) 208 | self.ff_metric_duped.update_state(y_true, y_pred) 209 | return y_pred 210 | 211 | class FFOverlay(BaseFFLayer): 212 | def __init__(self, embedding, **tfl_kwargs): 213 | ''' 214 | When `ff_task_eval_pos` or `ff_task_transform` is called, it 215 | overlays an embedding onto a sample based on the sample's 216 | `y_true`. 
When `ff_task_eval_duped_pos` is called, it overlays 217 | all embeddings onto each sample, transforming an `X` from 218 | shape `(samples, features)` to 219 | `(samples * embeddings, features)`. 220 | 221 | Args: 222 | embedding. a `Tensor` of shape 223 | `(number of embeddings, features)`. 224 | ''' 225 | self.ff_embedding = tf.cast(embedding, tf.float32) 226 | self.ff_emb_shape = tf.shape(self.ff_embedding)[1:] 227 | 228 | def function(X): 229 | X, y_true = X 230 | return X + tf.gather(self.ff_embedding, y_true) 231 | 232 | super().__init__( 233 | tfl=tf.keras.layers.Lambda(function, **tfl_kwargs), 234 | ) 235 | 236 | def ff_task_eval_duped_pos(self, X, y_true=None): 237 | X, y_true = X 238 | y_pred = tf.expand_dims(X, 1) + self.ff_embedding 239 | y_pred = tf.reshape(y_pred, (-1, *tf.unstack(self.ff_emb_shape))) 240 | return y_pred 241 | 242 | class FFPreNorm(BaseFFLayer): 243 | def __init__(self, **tfl_kwargs): 244 | super().__init__( 245 | tfl=tf.keras.layers.Lambda(preNorm, **tfl_kwargs), 246 | ) 247 | 248 | class FFRoutedDense(BaseFFLayer): 249 | def __init__(self, th_pos, th_neg, optimizer, **tfl_kwargs): 250 | super().__init__( 251 | tfl=tf.keras.layers.Dense(activation='relu', **tfl_kwargs), 252 | optimizer=optimizer, 253 | metric=FFMetric(), 254 | loss_pos=FFLoss(th_pos), 255 | loss_neg=FFLoss(th_neg), 256 | ) 257 | 258 | def ff_set_ctu_map(self, ctu_map_pos, ctu_map_neg): 259 | self._ff_ctu_map_pos = ctu_map_pos 260 | self._ff_ctu_map_neg = ctu_map_neg 261 | return self 262 | 263 | def ff_set_classes(self, classes): 264 | self._ff_classes = classes 265 | return self 266 | 267 | def _ff_route_y_pred(self, y_pred, ctu_map): 268 | route = tf.nn.embedding_lookup(ctu_map, self._ff_classes) 269 | y_pred_routed = y_pred * route 270 | return y_pred_routed 271 | 272 | def ff_task_train_pos(self, X, y_true): 273 | with tf.GradientTape() as tape: 274 | y_pred = self(X) 275 | y_pred_routed = self._ff_route_y_pred(y_pred, self._ff_ctu_map_pos) 276 | loss = self.ff_loss_pos(y_true, y_pred_routed) 277 | grads = tape.gradient(loss, self.trainable_weights) 278 | self.ff_opt.apply_gradients(zip(grads, self.trainable_weights)) 279 | return y_pred 280 | 281 | def ff_task_train_neg(self, X, y_true): 282 | with tf.GradientTape() as tape: 283 | y_pred = self(X) 284 | y_pred_routed = self._ff_route_y_pred(y_pred, self._ff_ctu_map_neg) 285 | loss = self.ff_loss_neg(y_true, y_pred_routed) 286 | grads = tape.gradient(loss, self.trainable_weights) 287 | self.ff_opt.apply_gradients(zip(grads, self.trainable_weights)) 288 | return y_pred 289 | 290 | def ff_task_eval_pos(self, X, y_true): 291 | y_pred = self(X) 292 | y_pred_routed = self._ff_route_y_pred(y_pred, self._ff_ctu_map_pos) 293 | self.ff_metric.update_state(y_true, y_pred_routed) 294 | return y_pred 295 | 296 | def ff_task_eval_neg(self, X, y_true): 297 | y_pred = self(X) 298 | y_pred_routed = self._ff_route_y_pred(y_pred, self._ff_ctu_map_neg) 299 | self.ff_metric.update_state(y_true, y_pred_routed) 300 | return y_pred 301 | 302 | # class FFClassFilter(BaseFFLayer): 303 | # def __init__(self, keep_classes, num_classes, **tfl_kwargs): 304 | # self.ff_keep_classes = tf.transpose( 305 | # tf.reduce_sum( 306 | # tf.one_hot(keep_classes, num_classes), 307 | # axis=0, 308 | # keepdims=True)) 309 | 310 | # def function(X): 311 | # X, y_true = X 312 | # arg = tf.equal( 313 | # tf.squeeze( 314 | # tf.one_hot(y_true, num_classes) @\ 315 | # self.ff_keep_classes), 1.) 
316 | # X = tf.boolean_mask(X, arg) 317 | # y_true = tf.boolean_mask(y_true, arg) 318 | # return (X, y_true) 319 | 320 | # super().__init__( 321 | # tfl=tf.keras.layers.Lambda(function, **tfl_kwargs), 322 | # ) 323 | 324 | # class FFGoodness2(BaseFFLayer): 325 | # def __init__(self, **tfl_kwargs): 326 | # super().__init__( 327 | # tfl=tf.keras.layers.Layer, 328 | # metric=tf.keras.metrics.SparseCategoricalAccuracy(), 329 | # **tfl_kwargs) 330 | 331 | # def ff_task_eval_pos(self, X, y_true): 332 | # m = tf.shape(y_true)[0] 333 | # y_pred = X 334 | # y_pred = tf.reshape(y_pred, (m, 10, -1)) 335 | # y_pred = goodness(y_pred) 336 | # self.ff_metric.update_state(y_true, y_pred) 337 | # return y_pred --------------------------------------------------------------------------------