├── .gitignore
├── LICENSE
├── README.md
├── models
    ├── simplenet_pt.py
    ├── simplenet_tf.py
    ├── simplenet_tf_mixed.py
    └── vgg.py
├── pt_deconv.py
├── tests
    ├── __init__.py
    ├── test_pt_deconv.py
    ├── test_tf_deconv.py
    └── test_tf_deconv_mixed_prec.py
├── tf_deconv.py
├── tf_deconv_mixed_prec.py
├── train_cifar10.py
└── utils
    ├── optim.py
    └── schedule.py


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

log_cifar10/*
checkpoints/*

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Somshubra Majumdar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Neural Deconvolutions (TensorFlow)

TensorFlow implementation of the `FastDeconv2D` and `ChannelDeconv2D` layers from the paper [Network Deconvolution](https://openreview.net/forum?id=rkeu30EtvS) by Ye et al. The code is ported from the original repository: https://github.com/yechengxi/deconvolution/.

The TensorFlow implementation also supports mixed precision training, allowing larger batch sizes with no reduction in accuracy (see `tf_deconv_mixed_prec.py`).

# Usage

Simply download the `tf_deconv.py` script and import the `ChannelDeconv2D` and `FastDeconv2D` layers. Mixed precision versions of these layers are the equivalent classes in `tf_deconv_mixed_prec.py`.

A baseline model is provided in `models/vgg.py` to try out the architecture. `FastDeconv2D` can replace most `Conv2D` layers.

## Important Note
-----------------

It is crucial to initialize your models properly to obtain the expected performance.

1) All `FastDeconv2D` kernels are initialized by default using `he_uniform`, and their biases by `BiasHeUniform`.

2) The final `Dense` layer should use `kernel_initializer='he_uniform'` and `bias_initializer=BiasHeUniform()`.
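
Both notes depend on the layer's fan-in. Keras's built-in `he_uniform` samples from U(-limit, limit) with limit = sqrt(6 / fan_in). `BiasHeUniform` is defined in `tf_deconv.py`, which is not reproduced here, so its exact rule is an assumption: the pure-Python sketch below supposes it follows PyTorch's default bias rule for convolution and linear layers, U(-1/sqrt(fan_in), 1/sqrt(fan_in)). The helper names (`conv_fan_in`, `he_uniform_limit`, `bias_uniform_limit`, `sample_uniform`) are hypothetical and only illustrate the arithmetic.

```python
import math
import random


def conv_fan_in(kernel_h, kernel_w, in_channels):
    """Fan-in of a 2D convolution kernel: inputs feeding one output unit."""
    return kernel_h * kernel_w * in_channels


def he_uniform_limit(fan_in):
    """Bound used by Keras 'he_uniform': samples come from U(-limit, limit)."""
    return math.sqrt(6.0 / fan_in)


def bias_uniform_limit(fan_in):
    """Hypothetical BiasHeUniform bound. ASSUMPTION: it follows PyTorch's
    default conv/linear bias rule, U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
    return 1.0 / math.sqrt(fan_in)


def sample_uniform(limit, n, seed=0):
    """Draw n values from U(-limit, limit) with a fixed seed."""
    rng = random.Random(seed)
    return [rng.uniform(-limit, limit) for _ in range(n)]


# A 3x3 FastDeconv2D layer with 64 input channels, as in models/vgg.py.
fan_in = conv_fan_in(3, 3, 64)
kernel_limit = he_uniform_limit(fan_in)
bias_limit = bias_uniform_limit(fan_in)
print(fan_in)                  # 576
print(round(kernel_limit, 4))  # 0.1021
print(round(bias_limit, 4))    # 0.0417
biases = sample_uniform(bias_limit, n=128)
```

Under this assumption the bias bound is about 2.4 times smaller than the kernel bound (1/sqrt(fan_in) versus sqrt(6/fan_in)), so biases start close to zero relative to the weights.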

--------

```python
import tensorflow as tf
from tf_deconv import FastDeconv2D, ChannelDeconv2D, BiasHeUniform

kernel_size = 3

cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(tf.keras.Model):
    def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        assert vgg_name in cfg.keys(), "Choose VGG model from {}".format(cfg.keys())

        self.features = self._make_layers(cfg[vgg_name])
        self.channel_deconv = ChannelDeconv2D(block=512)
        self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax',
                                                kernel_initializer='he_uniform',
                                                bias_initializer=BiasHeUniform(),
                                                )

    def call(self, x, training=None, mask=None):
        out = self.features(x, training=training)
        out = self.channel_deconv(out, training=training)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3

        for x in cfg:
            if x == 'M':
                layers.append(tf.keras.layers.MaxPool2D())
            else:
                if in_channels == 3:
                    deconv = FastDeconv2D(in_channels, x, kernel_size=(kernel_size, kernel_size), padding='same',
                                          freeze=True, n_iter=15, block=64, activation='relu')
                else:
                    deconv = FastDeconv2D(in_channels, x, kernel_size=(kernel_size, kernel_size), padding='same',
                                          block=64, activation='relu')

                layers.append(deconv)
                in_channels = x

        layers.append(tf.keras.layers.GlobalAveragePooling2D())
        return tf.keras.Sequential(layers)
```

# Dependencies
- TensorFlow 2.1+
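
# How `FastDeconv2D` works (sketch)

At its core, `FastDeconv2D` (see `pt_deconv.py`) subtracts the mean of its im2col patches, approximates the inverse square root of their regularised covariance with a coupled Newton–Schulz iteration (`isqrt_newton_schulz_autograd`), and then folds that whitening matrix into the convolution itself (step 4 of `forward`: `w = weight @ deconv`, `b = bias - w @ mean`). The NumPy sketch below reproduces those steps for the simplest case only: `groups == 1`, a single block, and a dense filter standing in for the convolution. It ignores running statistics, the sampling stride and channel blocking; the toy data, shapes and function name `isqrt_newton_schulz` are assumptions for illustration.

```python
import numpy as np


def isqrt_newton_schulz(A, num_iters=30):
    """Approximate A^(-1/2) for a symmetric positive-definite matrix A.

    Same coupled iteration as isqrt_newton_schulz_autograd in pt_deconv.py:
    after scaling A by its Frobenius norm, Y converges to a square root and
    Z to an inverse square root of the scaled matrix.
    """
    norm_a = np.linalg.norm(A)  # Frobenius norm, like torch's A.norm()
    eye = np.eye(A.shape[0])
    Y = A / norm_a
    Z = eye.copy()
    for _ in range(num_iters):
        T = 0.5 * (3.0 * eye - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    return Z / np.sqrt(norm_a)  # undo the scaling


rng = np.random.default_rng(0)
# Toy "im2col" matrix: 2000 patches with 4 mildly correlated features.
mixing = np.array([[1.0, 0.5, 0.0, 0.0],
                   [0.0, 1.0, 0.3, 0.0],
                   [0.0, 0.0, 1.0, 0.2],
                   [0.0, 0.0, 0.0, 1.0]])
patches = rng.normal(size=(2000, 4)) @ mixing
eps = 1e-5

# Steps 2-3: subtract the mean, regularise the covariance, invert its square root.
mean = patches.mean(axis=0)
centered = patches - mean
cov = centered.T @ centered / patches.shape[0] + eps * np.eye(4)
deconv = isqrt_newton_schulz(cov)
whitened = centered @ deconv.T
print(np.round(whitened.T @ whitened / patches.shape[0], 3))  # approximately the identity

# Step 4: fold the whitening into a hypothetical 3-output linear filter.
weight = rng.normal(size=(3, 4))
bias = rng.normal(size=3)
w_folded = weight @ deconv
b_folded = bias - w_folded @ mean
out_whitened = whitened @ weight.T + bias
out_folded = patches @ w_folded.T + b_folded
print(np.allclose(out_whitened, out_folded))  # True
```

Folding means the whitened patches never have to be materialised: the layer whitens implicitly by running an ordinary `F.conv2d` with the adjusted weights and bias. The sketch runs 30 iterations so its checks are tight, whereas the PyTorch layers default to `n_iter=5`, so their inverse square root is only approximate.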
--------------------------------------------------------------------------------
/models/simplenet_pt.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from pt_deconv import FastDeconv1D, FastDeconv2D, ChannelDeconv1D, ChannelDeconv2D


class SimpleNet1D(nn.Module):

    def __init__(self, num_classes, num_channels=64, groups=1, channel_deconv_loc="pre", blocks=64):
        super().__init__()

        self.num_channels = num_channels
        self.num_classes = num_classes
        self.channel_deconv_loc = channel_deconv_loc

        self.conv1 = FastDeconv1D(3, num_channels, kernel_size=3, stride=2,
                                  groups=1, padding=1,
                                  n_iter=5, momentum=0.1, block=blocks)

        self.conv2 = FastDeconv1D(num_channels, num_channels, kernel_size=3, stride=2,
                                  groups=groups, padding=1,
                                  n_iter=5, momentum=0.1, block=blocks)

        self.final_conv = ChannelDeconv1D(block=num_channels, momentum=0.1)

        self.gap = nn.AdaptiveAvgPool1d(1)
        self.clf = nn.Linear(num_channels, num_classes)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = F.relu(x, inplace=True)
        x = self.conv2(x)
        x = F.relu(x, inplace=True)

        if self.channel_deconv_loc == 'pre':
            x = self.final_conv(x)
            x = self.gap(x)
        else:
            x = self.gap(x)
            x = self.final_conv(x)

        x = x.view(-1, self.num_channels)
        x = self.clf(x)
        return x


class SimpleNet2D(nn.Module):

    def __init__(self, num_classes, num_channels=64, groups=1, channel_deconv_loc="pre", blocks=64):
        super().__init__()

        self.num_channels = num_channels
        self.num_classes = num_classes
        self.channel_deconv_loc = channel_deconv_loc

        self.conv1 = FastDeconv2D(3, num_channels, kernel_size=(3, 3), stride=(2, 2),
                                  groups=1, padding=1,
                                  n_iter=5, momentum=0.9, block=blocks)

        self.conv2 = FastDeconv2D(num_channels, num_channels, kernel_size=(3, 3), stride=(2, 2),
                                  groups=groups, padding=1,
                                  n_iter=5, momentum=0.9, block=blocks)

        self.final_conv = ChannelDeconv2D(block=num_channels, momentum=0.1)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.clf = nn.Linear(num_channels, num_classes)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = F.relu(x, inplace=True)
        x = self.conv2(x)
        x = F.relu(x, inplace=True)

        if self.channel_deconv_loc == 'pre':
            x = self.final_conv(x)
            x = self.gap(x)
        else:
            x = self.gap(x)
            x = self.final_conv(x)

        x = x.view(-1, self.num_channels)
        x = self.clf(x)
        return x

--------------------------------------------------------------------------------
/models/simplenet_tf.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tf_deconv import FastDeconv1D, FastDeconv2D, ChannelDeconv1D, ChannelDeconv2D


class SimpleNet1D(tf.keras.Model):

    def __init__(self, num_classes, num_channels=64, groups=1, channel_deconv_loc="pre", blocks=64):
        super().__init__()

        self.num_channels = num_channels
        self.num_classes = num_classes
        self.channel_deconv_loc = channel_deconv_loc

        self.conv1 = FastDeconv1D(3, num_channels, kernel_size=3, stride=2,
                                  padding='same', activation='relu', groups=1,
                                  n_iter=5, momentum=0.1, block=blocks)

        self.conv2 = FastDeconv1D(num_channels, num_channels, kernel_size=3, stride=2,
                                  padding='same', activation='relu', groups=groups,
                                  n_iter=5, momentum=0.1, block=blocks)

        self.final_conv = ChannelDeconv1D(block=num_channels, momentum=0.1)

        self.gap = tf.keras.layers.GlobalAveragePooling1D()
        self.clf = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs, training=training)
        x = self.conv2(x, training=training)

        if self.channel_deconv_loc == 'pre':
            x = self.final_conv(x, training=training)
            x = self.gap(x)
        else:
            x = self.gap(x)
            x = self.final_conv(x, training=training)

        x = self.clf(x)

        return x


class SimpleNet2D(tf.keras.Model):

    def __init__(self, num_classes, num_channels=64, groups=1, channel_deconv_loc="pre", blocks=64):
        super().__init__()

        self.num_channels = num_channels
        self.num_classes = num_classes
        self.channel_deconv_loc = channel_deconv_loc

        self.conv1 = FastDeconv2D(3, num_channels, kernel_size=(3, 3), stride=(2, 2),
                                  padding='same', activation='relu', groups=1,
                                  n_iter=5, momentum=0.9, block=blocks)

        self.conv2 = FastDeconv2D(num_channels, num_channels, kernel_size=(3, 3), stride=(2, 2),
                                  padding='same', activation='relu', groups=groups,
                                  n_iter=5, momentum=0.9, block=blocks)

        self.final_conv = ChannelDeconv2D(block=num_channels, momentum=0.1)

        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.clf = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs, training=training)
        x = self.conv2(x, training=training)

        if self.channel_deconv_loc == 'pre':
            x = self.final_conv(x, training=training)
            x = self.gap(x)
        else:
            x = self.gap(x)
            x = self.final_conv(x, training=training)

        x = self.clf(x)

        return x

--------------------------------------------------------------------------------
/models/simplenet_tf_mixed.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tf_deconv_mixed_prec import FastDeconv1D, FastDeconv2D, ChannelDeconv1D, ChannelDeconv2D


class SimpleNet1D(tf.keras.Model):

    def __init__(self, num_classes, num_channels=64, groups=1, channel_deconv_loc="pre", blocks=64, **kwargs):
        super().__init__(**kwargs)

        self.num_channels = num_channels
        self.num_classes = num_classes
        self.channel_deconv_loc = channel_deconv_loc

        self.conv1 = FastDeconv1D(3, num_channels, kernel_size=3, stride=2,
                                  padding='same', activation='relu', groups=1,
                                  n_iter=5, momentum=0.1, block=blocks, **kwargs)

        self.conv2 = FastDeconv1D(num_channels, num_channels, kernel_size=3, stride=2,
                                  padding='same', activation='relu', groups=groups,
                                  n_iter=5, momentum=0.1, block=blocks, **kwargs)

        self.final_conv = ChannelDeconv1D(block=num_channels, momentum=0.1, **kwargs)

        self.gap = tf.keras.layers.GlobalAveragePooling1D(**kwargs)
        self.clf = tf.keras.layers.Dense(num_classes, activation=None, **kwargs)
        self.softmax = tf.keras.layers.Activation('softmax', dtype='float32')

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs, training=training)
        x = self.conv2(x, training=training)

        if self.channel_deconv_loc == 'pre':
            x = self.final_conv(x, training=training)
            x = self.gap(x)
        else:
            x = self.gap(x)
            x = self.final_conv(x, training=training)

        x = self.clf(x)
        x = self.softmax(x)

        return x


class SimpleNet2D(tf.keras.Model):

    def __init__(self, num_classes, num_channels=64, groups=1, channel_deconv_loc="pre", blocks=64, **kwargs):
        super().__init__(**kwargs)

        self.num_channels = num_channels
        self.num_classes = num_classes
        self.channel_deconv_loc = channel_deconv_loc

        self.conv1 = FastDeconv2D(3, num_channels, kernel_size=(3, 3), stride=(2, 2),
                                  padding='same', activation='relu', groups=1,
                                  n_iter=5, momentum=0.9, block=blocks, **kwargs)

        self.conv2 = FastDeconv2D(num_channels, num_channels, kernel_size=(3, 3), stride=(2, 2),
                                  padding='same', activation='relu', groups=groups,
                                  n_iter=5, momentum=0.9, block=blocks, **kwargs)

        self.final_conv = ChannelDeconv2D(block=num_channels, momentum=0.1, **kwargs)

        self.gap = tf.keras.layers.GlobalAveragePooling2D(**kwargs)
        self.clf = tf.keras.layers.Dense(num_classes, activation=None, **kwargs)
        self.softmax = tf.keras.layers.Activation('softmax', dtype='float32')

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs, training=training)
        x = self.conv2(x, training=training)

        if self.channel_deconv_loc == 'pre':
            x = self.final_conv(x, training=training)
            x = self.gap(x)
        else:
            x = self.gap(x)
            x = self.final_conv(x, training=training)

        x = self.clf(x)
        x = self.softmax(x)

        return x

--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
"""VGG11/13/16/19 in TensorFlow (Keras), built from FastDeconv2D layers."""
import tensorflow as tf
from tf_deconv import FastDeconv2D, ChannelDeconv2D, BiasHeUniform

kernel_size = 3

cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(tf.keras.Model):
    def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        assert vgg_name in cfg.keys(), "Choose VGG model from {}".format(cfg.keys())

        self.features = self._make_layers(cfg[vgg_name])
        self.channel_deconv = ChannelDeconv2D(block=512)
        self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax',
                                                kernel_initializer='he_uniform',
                                                bias_initializer=BiasHeUniform(),
                                                )

    def call(self, x, training=None, mask=None):
        out = self.features(x, training=training)
        out = self.channel_deconv(out, training=training)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3

        for x in cfg:
            if x == 'M':
                layers.append(tf.keras.layers.MaxPool2D())
            else:
                if in_channels == 3:
                    deconv = FastDeconv2D(in_channels, x, kernel_size=(kernel_size, kernel_size), padding='same',
                                          freeze=True, n_iter=15, block=64, activation='relu')
                else:
                    deconv = FastDeconv2D(in_channels, x, kernel_size=(kernel_size, kernel_size), padding='same',
                                          block=64, activation='relu')

                layers.append(deconv)
                in_channels = x

        layers.append(tf.keras.layers.GlobalAveragePooling2D())
        return tf.keras.Sequential(layers)


if __name__ == '__main__':
    x = tf.zeros([16, 32, 32, 3])
    model = VGG(vgg_name='VGG16', num_classes=10)

    # trace the model
    model_traced = tf.function(model)

    out = model_traced(x, training=True)
    print(out.shape)

    # 14.71 M trainable params, 18.97 M total params; matches paper
    model.summary()

--------------------------------------------------------------------------------
/pt_deconv.py:
--------------------------------------------------------------------------------
""" Modified from https://github.com/yechengxi/deconvolution/blob/master/models/deconv.py """
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.modules import conv
from torch.nn.modules.utils import _pair


# iteratively solve for inverse sqrt of a matrix
def isqrt_newton_schulz_autograd(A, numIters):
    dim = A.shape[0]
    normA = A.norm()
    Y = A.div(normA)
    I = torch.eye(dim, dtype=A.dtype, device=A.device)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device)

    for i in range(numIters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    # A_sqrt = Y*torch.sqrt(normA)
    A_isqrt = Z / torch.sqrt(normA)
    return A_isqrt


def isqrt_newton_schulz_autograd_batch(A, numIters):
    batchSize, dim, _ = A.shape
    normA = A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
    Y = A.div(normA)
    I = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)

    for i in range(numIters):
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)
    # A_sqrt = Y*torch.sqrt(normA)
    A_isqrt = Z / torch.sqrt(normA)

    return A_isqrt


# deconvolve channels
class ChannelDeconv2D(nn.Module):
    def __init__(self, block, eps=1e-2, n_iter=5, momentum=0.1, sampling_stride=3):
        super(ChannelDeconv2D, self).__init__()

        self.eps = eps
        self.n_iter = n_iter
        self.momentum = momentum
        self.block = block

        self.register_buffer('running_mean1', torch.zeros(block, 1))
        # self.register_buffer('running_cov', torch.eye(block))
        self.register_buffer('running_deconv', torch.eye(block))
        self.register_buffer('running_mean2', torch.zeros(1, 1))
        self.register_buffer('running_var', torch.ones(1, 1))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        self.sampling_stride = sampling_stride

    def forward(self, x):
        x_shape = x.shape
        if len(x.shape) == 2:
            x = x.view(x.shape[0], x.shape[1], 1, 1)
        if len(x.shape) == 3:
            print('Error! Unsupported tensor shape.')

        N, C, H, W = x.size()
        B = self.block

        # take the first c channels out for deconv
        c = int(C / B) * B
        if c == 0:
            print('Error! block should be set smaller.')

        # step 1. remove mean
        if c != C:
            x1 = x[:, :c].permute(1, 0, 2, 3).contiguous().view(B, -1)
        else:
            x1 = x.permute(1, 0, 2, 3).contiguous().view(B, -1)

        if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
            x1_s = x1[:, ::self.sampling_stride ** 2]
        else:
            x1_s = x1

        mean1 = x1_s.mean(-1, keepdim=True)

        if self.num_batches_tracked == 0:
            self.running_mean1.copy_(mean1.detach())
        if self.training:
            self.running_mean1.mul_(1 - self.momentum)
            self.running_mean1.add_(mean1.detach() * self.momentum)
        else:
            mean1 = self.running_mean1

        x1 = x1 - mean1

        # step 2. calculate deconv@x1 = cov^(-0.5)@x1
        if self.training:
            cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(B, dtype=x.dtype, device=x.device)
            deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)

        if self.num_batches_tracked == 0:
            # self.running_cov.copy_(cov.detach())
            self.running_deconv.copy_(deconv.detach())

        if self.training:
            # self.running_cov.mul_(1-self.momentum)
            # self.running_cov.add_(cov.detach()*self.momentum)
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            # cov = self.running_cov
            deconv = self.running_deconv

        x1 = deconv @ x1

        # reshape to N,c,H,W
        x1 = x1.view(c, N, H, W).contiguous().permute(1, 0, 2, 3)

        # normalize the remaining channels
        if c != C:
            x_tmp = x[:, c:].view(N, -1)
            if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
                x_s = x_tmp[:, ::self.sampling_stride ** 2]
            else:
                x_s = x_tmp

            mean2 = x_s.mean()
            var = x_s.var()

            if self.num_batches_tracked == 0:
                self.running_mean2.copy_(mean2.detach())
                self.running_var.copy_(var.detach())

            if self.training:
                self.running_mean2.mul_(1 - self.momentum)
                self.running_mean2.add_(mean2.detach() * self.momentum)
                self.running_var.mul_(1 - self.momentum)
                self.running_var.add_(var.detach() * self.momentum)
            else:
                mean2 = self.running_mean2
                var = self.running_var

            x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt()
            x1 = torch.cat([x1, x_tmp], dim=1)

        if self.training:
            self.num_batches_tracked.add_(1)

        if len(x_shape) == 2:
            x1 = x1.view(x_shape)
        return x1


class FastDeconv2D(conv._ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3, freeze=False, freeze_iter=100):
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter = 0
        self.track_running_stats = True
        super(FastDeconv2D, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
            False, _pair(0), groups, bias, padding_mode='zeros')

        if block > in_channels:
            block = in_channels
        else:
            if in_channels % block != 0:
                block = math.gcd(block, in_channels)

        if groups > 1:
            # grouped conv
            block = in_channels // groups

        self.block = block

        # use the paired self.kernel_size so an int kernel_size also works
        self.num_features = self.kernel_size[0] * self.kernel_size[1] * block
        if groups == 1:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_deconv', torch.eye(self.num_features))
        else:
            self.register_buffer('running_mean', torch.zeros(self.kernel_size[0] * self.kernel_size[1] * in_channels))
            self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))

        stride_int = stride[0] if type(stride) in (list, tuple) else stride
        self.sampling_stride = sampling_stride * stride_int
        self.counter = 0
        self.freeze_iter = freeze_iter
        self.freeze = freeze

    def forward(self, x):
        N, C, H, W = x.shape
        B = self.block
        frozen = self.freeze and (self.counter > self.freeze_iter)
        if self.training and self.track_running_stats:
            self.counter += 1
            self.counter %= (self.freeze_iter * 10)

        if self.training and (not frozen):

            # 1. im2col: N x cols x pixels -> N*pixels x cols
            if self.kernel_size[0] > 1:
                X = torch.nn.functional.unfold(x, self.kernel_size, self.dilation, self.padding,
                                               self.sampling_stride).transpose(1, 2).contiguous()
            else:
                # channel wise
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]

            if self.groups == 1:
                # (C//B*N*pixels,k*k*B)
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                X = X.view(-1, X.shape[-1])

            # 2. subtract mean
            X_mean = X.mean(0)
            X = X - X_mean.unsqueeze(0)

            # 3. calculate COV, COV^(-0.5), then deconv
            if self.groups == 1:
                # Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                # keyword beta/alpha replace the removed positional addmm(beta, input, alpha, mat1, mat2) form
                Cov = torch.addmm(Id, X.t(), X, beta=self.eps, alpha=1. / X.shape[0])
                deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            else:
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(
                    self.groups, self.num_features, self.num_features)
                Cov = torch.baddbmm(Id, X.transpose(1, 2), X, beta=self.eps, alpha=1. / X.shape[1])

                deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)

            if self.track_running_stats:
                self.running_mean.mul_(1 - self.momentum)
                self.running_mean.add_(X_mean.detach() * self.momentum)
                # track stats for evaluation
                self.running_deconv.mul_(1 - self.momentum)
                self.running_deconv.add_(deconv.detach() * self.momentum)

        else:
            X_mean = self.running_mean
            deconv = self.running_deconv

        # 4. X * deconv * conv = X * (deconv * conv)
        if self.groups == 1:
            w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(
                -1, self.num_features) @ deconv
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ deconv
            b = self.bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(self.bias.shape)

        w = w.view(self.weight.shape)
        x = F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)

        return x


""" 1D Conv Wrapper """


class ChannelDeconv1D(ChannelDeconv2D):

    def __init__(self, block, eps=1e-5, n_iter=5, momentum=0.1, sampling_stride=3):
        super(ChannelDeconv1D, self).__init__(block=block, eps=eps, n_iter=n_iter,
                                              momentum=momentum, sampling_stride=sampling_stride)

    def forward(self, x: torch.Tensor):
        # insert dummy dimension in time channel
        shape = x.size()

        if len(shape) == 3:
            x_expanded = x.unsqueeze(-1)  # [N, C, T, 1]
        else:
            x_expanded = x

        out = super(ChannelDeconv1D, self).forward(x_expanded)

        if len(shape) == 3:
            # remove dummy dimension
            x = out.squeeze(-1)  # [N, C', T / stride]
        else:
            x = out

        return x
class FastDeconv1D(FastDeconv2D):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3, freeze=False, freeze_iter=100):
        kernel_size = (kernel_size, 1)
        stride = (stride, 1)
        padding = (padding, 0)
        dilation = (dilation, 1)
        super(FastDeconv1D, self).__init__(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, stride=stride, padding=padding,
                                           dilation=dilation, bias=bias, groups=groups, eps=eps,
                                           n_iter=n_iter, momentum=momentum, block=block,
                                           sampling_stride=sampling_stride, freeze=freeze, freeze_iter=freeze_iter)

    def forward(self, x: torch.Tensor):
        # insert dummy dimension in time channel
        x_expanded = x.unsqueeze(-1)  # [N, C, T, 1]

        out = super(FastDeconv1D, self).forward(x_expanded)

        # remove dummy dimension
        x = out.squeeze(-1)  # [N, C', T / stride]
        return x

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/tf_neural_deconvolution/692403f8ed255ceb5f1c10136aa17babbaf7f7b4/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_pt_deconv.py:
--------------------------------------------------------------------------------
import pytest

import torch
from models import simplenet_pt


@pytest.mark.parametrize("groups", [1, 32, 64])
@pytest.mark.parametrize("channel_deconv_loc", ['pre', 'post'])
@pytest.mark.parametrize("blocks", [1, 32, 64])
def test_fastdeconv_1d(groups, channel_deconv_loc, blocks):
    """ Test 1D variant """
    x = torch.zeros([16, 3, 24])
    model2 = simplenet_pt.SimpleNet1D(num_classes=10, num_channels=64, groups=groups,
                                      channel_deconv_loc=channel_deconv_loc, blocks=blocks)
    model2.train()
    out = model2(x)
    assert list(out.shape) == [16, 10]

    model2.eval()
    out = model2(x)
    assert list(out.shape) == [16, 10]


@pytest.mark.parametrize("groups", [1, 32, 64])
@pytest.mark.parametrize("channel_deconv_loc", ['pre', 'post'])
@pytest.mark.parametrize("blocks", [1, 32, 64])
def test_fastdeconv_2d(groups, channel_deconv_loc, blocks):
    """ Test 2D variant """
    x = torch.zeros([16, 3, 32, 32])
    model2 = simplenet_pt.SimpleNet2D(num_classes=10, num_channels=64, groups=groups,
                                      channel_deconv_loc=channel_deconv_loc, blocks=blocks)
    model2.train()
    out = model2(x)
    assert list(out.shape) == [16, 10]

    model2.eval()
    out = model2(x)
    assert list(out.shape) == [16, 10]


if __name__ == '__main__':
    pytest.main([__file__])

--------------------------------------------------------------------------------
/tests/test_tf_deconv.py:
--------------------------------------------------------------------------------
import pytest
import six
import tensorflow as tf

from models import simplenet_tf


def tf_context(func):
    @six.wraps(func)
    def wrapped(*args, **kwargs):
        # Run tests only on the gpu as grouped convs are not supported on cpu
        with tf.device('gpu:0'):
            out = func(*args, **kwargs)
        return out
    return wrapped


@pytest.mark.parametrize("groups", [1, 32, 64])
@pytest.mark.parametrize("channel_deconv_loc", ['pre', 'post'])
@pytest.mark.parametrize("blocks", [1, 32, 64])
@tf_context
def test_fastdeconv_1d(groups, channel_deconv_loc, blocks):
    """ Test 1D variant """
    x = tf.zeros([16, 24, 3])
    model2 = simplenet_tf.SimpleNet1D(num_classes=10, num_channels=64, groups=groups,
                                      channel_deconv_loc=channel_deconv_loc, blocks=blocks)

    # trace the model
    model_traced2 = tf.function(model2)

    out = model_traced2(x, training=True)
    assert out.shape == [16, 10]

    out = model_traced2(x, training=False)
    assert out.shape == [16, 10]


@pytest.mark.parametrize("groups", [1, 32, 64])
@pytest.mark.parametrize("channel_deconv_loc", ['pre', 'post'])
@pytest.mark.parametrize("blocks", [1, 32, 64])
@tf_context
def test_fastdeconv_2d(groups, channel_deconv_loc, blocks):
    """ Test 2D variant """
    x = tf.zeros([16, 32, 32, 3])
    model2 = simplenet_tf.SimpleNet2D(num_classes=10, num_channels=64, groups=groups,
                                      channel_deconv_loc=channel_deconv_loc, blocks=blocks)

    # trace the model
    model_traced2 = tf.function(model2)

    out = model_traced2(x, training=True)
    assert out.shape == [16, 10]

    out = model_traced2(x, training=False)
    assert out.shape == [16, 10]


if __name__ == '__main__':
    pytest.main([__file__])

--------------------------------------------------------------------------------
/tests/test_tf_deconv_mixed_prec.py:
--------------------------------------------------------------------------------
import pytest
import six
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

from models import simplenet_tf_mixed


def tf_context(func):
    @six.wraps(func)
    def wrapped(*args, **kwargs):
        # Run tests only on the gpu as grouped convs are not supported on cpu
        with tf.device('gpu:0'):
            out = func(*args, **kwargs)
        return out
    return wrapped


def tf_mixed_precision(func):
    @six.wraps(func)
    def wrapped(*args, **kwargs):
        # Run in mixed precision mode, then return to float32 mode
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_policy(policy)

        # call method in mixed precision mode
        try:
            out = func(*args, **kwargs)
        finally:
            # Return to float32 precision mode
            policy = mixed_precision.Policy('float32')
            mixed_precision.set_policy(policy)

        return out
    return wrapped


@pytest.mark.parametrize("groups", [1, 32, 64])
@pytest.mark.parametrize("channel_deconv_loc", ['pre', 'post'])
@pytest.mark.parametrize("blocks", [1, 32, 64])
@tf_context
@tf_mixed_precision
def test_fastdeconv_1d_mixed_precision(groups, channel_deconv_loc, blocks):
    """ Test 1D variant """
    policy = mixed_precision.global_policy()

    x = tf.zeros([16, 24, 3], dtype=tf.float16)
    model2 = simplenet_tf_mixed.SimpleNet1D(num_classes=10, num_channels=64, groups=groups,
                                            channel_deconv_loc=channel_deconv_loc, blocks=blocks,
                                            dtype=policy)

    # trace the model
    model_traced2 = tf.function(model2)

    out = model_traced2(x, training=True)
    assert out.shape == [16, 10]

    out = model_traced2(x, training=False)
    assert out.shape == [16, 10]


@pytest.mark.parametrize("groups", [1, 32, 64])
@pytest.mark.parametrize("channel_deconv_loc", ['pre', 'post'])
@pytest.mark.parametrize("blocks", [1, 32, 64])
@tf_context
@tf_mixed_precision
def test_fastdeconv_2d_mixed_precision(groups, channel_deconv_loc, blocks):
    """ Test 2D variant """
    policy = mixed_precision.global_policy()

    x = tf.zeros([16, 32, 32, 3], dtype=tf.float16)
    model2 = simplenet_tf_mixed.SimpleNet2D(num_classes=10, num_channels=64, groups=groups,
                                            channel_deconv_loc=channel_deconv_loc, blocks=blocks,
                                            dtype=policy)

    # trace the model
    model_traced2 = tf.function(model2)

    out = model_traced2(x, training=True)
    assert out.shape == [16, 10]

    out = model_traced2(x, training=False)
    assert out.shape == [16, 10]
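

# Hypothetical reference check (not part of the original repository): a minimal,
# framework-free sketch of the coupled Newton-Schulz iteration that
# `isqrt_newton_schulz_autograd` implements with TensorFlow ops. It assumes a small
# symmetric positive-definite input. A is first divided by its Frobenius norm so the
# iteration converges; the result is then rescaled by 1 / sqrt(||A||_F).
# The names `_matmul_ref`, `_isqrt_newton_schulz_ref` and the test are illustrative only.
import math


def _matmul_ref(a, b):
    # plain-Python matrix product of two nested lists
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def _isqrt_newton_schulz_ref(A, num_iters=10):
    n = len(A)
    norm_a = math.sqrt(sum(v * v for row in A for v in row))  # Frobenius norm
    Y = [[v / norm_a for v in row] for row in A]
    I = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    Z = [row[:] for row in I]
    for _ in range(num_iters):
        ZY = _matmul_ref(Z, Y)
        T = [[0.5 * (3.0 * I[i][j] - ZY[i][j]) for j in range(n)] for i in range(n)]
        Y = _matmul_ref(Y, T)  # Y -> (A / ||A||_F)^(1/2)
        Z = _matmul_ref(T, Z)  # Z -> (A / ||A||_F)^(-1/2)
    scale = math.sqrt(norm_a)
    return [[v / scale for v in row] for row in Z]


def test_isqrt_newton_schulz_reference():
    # For a diagonal matrix the inverse square root is known in closed form.
    Z = _isqrt_newton_schulz_ref([[4.0, 0.0], [0.0, 9.0]])
    assert abs(Z[0][0] - 0.5) < 1e-9 and abs(Z[1][1] - 1.0 / 3.0) < 1e-9
    # For a dense SPD matrix, Z @ A @ Z should recover the identity.
    A = [[4.0, 1.0], [1.0, 3.0]]
    Z = _isqrt_newton_schulz_ref(A)
    ZAZ = _matmul_ref(_matmul_ref(Z, A), Z)
    for i in range(2):
        for j in range(2):
            assert abs(ZAZ[i][j] - (1.0 if i == j else 0.0)) < 1e-9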


if __name__ == '__main__':
    pytest.main([__file__])

--------------------------------------------------------------------------------
/tf_deconv.py:
--------------------------------------------------------------------------------
import math
import tensorflow as tf

from tensorflow.python.keras.layers.convolutional import Conv
from tensorflow.python.keras.utils import conv_utils


class BiasHeUniform(tf.keras.initializers.VarianceScaling):
    def __init__(self, seed=None):
        super(BiasHeUniform, self).__init__(scale=1. / 3., mode='fan_in', distribution='uniform', seed=seed)


# iteratively solve for the inverse square root of a matrix (coupled Newton-Schulz iteration)
def isqrt_newton_schulz_autograd(A: tf.Tensor, numIters: int):
    dim = tf.shape(A)[0]
    normA = tf.norm(A, ord='fro', axis=[0, 1])
    Y = A / normA

    with tf.device(A.device):
        I = tf.eye(dim, dtype=A.dtype)
        Z = tf.eye(dim, dtype=A.dtype)

    for i in range(numIters):
        T = 0.5 * (3.0 * I - tf.matmul(Z, Y))
        Y = tf.matmul(Y, T)
        Z = tf.matmul(T, Z)

    A_isqrt = Z / tf.sqrt(normA)
    return A_isqrt


def isqrt_newton_schulz_autograd_batch(A: tf.Tensor, numIters: int):
    Ashape = tf.shape(A)  # [batch, _, C]
    batchSize, dim = Ashape[0], Ashape[-1]

    normA = tf.reshape(A, (batchSize, -1))
    normA = tf.norm(normA, ord=2, axis=1)
    normA = tf.reshape(normA, [batchSize, 1, 1])

    Y = A / normA

    with tf.device(A.device):
        I = tf.expand_dims(tf.eye(dim, dtype=A.dtype), 0)
        Z = tf.expand_dims(tf.eye(dim, dtype=A.dtype), 0)

        I = tf.broadcast_to(I, Ashape)
        Z = tf.broadcast_to(Z, Ashape)

    for i in range(numIters):
        T = 0.5 * (3.0 * I - tf.matmul(Z, Y))
        Y = tf.matmul(Y, T)
        Z = tf.matmul(T, Z)

    A_isqrt = Z / tf.sqrt(normA)

    return A_isqrt


class ChannelDeconv2D(tf.keras.layers.Layer):
    def __init__(self, block, eps=1e-5, n_iter=5, momentum=0.1, sampling_stride=3, **kwargs):
        super(ChannelDeconv2D, self).__init__(**kwargs)

        self.eps = eps
        self.n_iter = n_iter
        self.momentum = momentum
        self.block = block
        self.sampling_stride = sampling_stride

        self.running_mean1 = tf.Variable(tf.zeros([block, 1]), trainable=False, dtype=self.dtype)
        self.running_mean2 = tf.Variable(tf.zeros([]), trainable=False, dtype=self.dtype)
        self.running_var = tf.Variable(tf.ones([]), trainable=False, dtype=self.dtype)
        self.running_deconv = tf.Variable(tf.eye(block), trainable=False, dtype=self.dtype)
        self.num_batches_tracked = tf.Variable(tf.convert_to_tensor(0, dtype=tf.int64), trainable=False)

        self.block_eye = tf.eye(block)

    def build(self, input_shape):
        in_channels = input_shape[-1]
        self.in_channels = in_channels

        if int(in_channels / self.block) * self.block == 0:
            raise ValueError("`block` must be smaller than in_channels.")

        # change rank based on 3d or 4d tensor input
        channel_axis = -1
        self.input_spec = tf.keras.layers.InputSpec(min_ndim=2,
                                                    max_ndim=4,
                                                    axes={channel_axis: in_channels})

        self.built = True

    @tf.function
    def call(self, x, training=None):
        x_shape = tf.shape(x)
        x_original_shape = x_shape

        if len(x.shape) == 2:
            x = tf.reshape(x, [x_shape[0], 1, 1, x_shape[1]])

        x_shape = tf.shape(x)

        N, H, W, C = x_shape[0], x_shape[1], x_shape[2], x_shape[3]
        block = self.block

        # take the first c channels out for deconv
        c = tf.cast(C / block, tf.int32) * block

        # step 1. remove mean
        if c != C:
            x1 = tf.reshape(tf.transpose(x[:, :, :, :c], [3, 0, 1, 2]), [block, -1])
        else:
            x1 = tf.reshape(tf.transpose(x, [3, 0, 1, 2]), [block, -1])

        if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
            x1_s = x1[:, ::self.sampling_stride ** 2]
        else:
            x1_s = x1

        mean1 = tf.reduce_mean(x1_s, axis=-1, keepdims=True)  # [blocks, 1]

        if self.num_batches_tracked == 0:
            self.running_mean1.assign(mean1)

        if training:
            running_mean1 = self.momentum * mean1 + (1. - self.momentum) * self.running_mean1
            self.running_mean1.assign(running_mean1)
        else:
            mean1 = self.running_mean1

        x1 = x1 - mean1

        # step 2. calculate deconv@x1 = cov^(-0.5)@x1
        if training:
            scale_ = tf.cast(tf.shape(x1_s)[1], x1_s.dtype)
            cov = (tf.matmul(x1_s, tf.transpose(x1_s)) / scale_) + self.eps * self.block_eye
            deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
        else:
            deconv = self.running_deconv

        if self.num_batches_tracked == 0:
            self.running_deconv.assign(deconv)

        if training:
            running_deconv = self.momentum * deconv + (1. - self.momentum) * self.running_deconv
            self.running_deconv.assign(running_deconv)
        else:
            deconv = self.running_deconv

        x1 = tf.matmul(deconv, x1)

        # reshape to [c, N, H, W], then move channels last
        x1 = tf.reshape(x1, [c, N, H, W])
        x1 = tf.transpose(x1, [1, 2, 3, 0])  # [N, H, W, c]

        # normalize the remaining channels
        if c != C:
            x_tmp = tf.reshape(x[:, :, :, c:], [N, -1])

            if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
                x_s = x_tmp[:, ::self.sampling_stride ** 2]
            else:
                x_s = x_tmp

            mean2, var = tf.nn.moments(x_s, axes=[0, 1])

            if self.num_batches_tracked == 0:
                self.running_mean2.assign(mean2)
                self.running_var.assign(var)

            if training:
                running_mean2 = self.momentum * mean2 + (1. - self.momentum) * self.running_mean2
                running_var = self.momentum * var + (1. - self.momentum) * self.running_var
                self.running_mean2.assign(running_mean2)
                self.running_var.assign(running_var)
            else:
                mean2 = self.running_mean2
                var = self.running_var

            # standardize: (x - mean) / sqrt(var + eps)
            x_tmp = (x[:, :, :, c:] - mean2) / tf.sqrt(var + self.eps)
            x1 = tf.concat([x1, x_tmp], axis=-1)

        if training:
            self.num_batches_tracked.assign_add(1)

        if len(x_original_shape) == 2:
            x1 = tf.reshape(x1, x_original_shape)
        else:
            x_intshape = x.shape
            x1 = tf.ensure_shape(x1, x_intshape)

        return x1


class FastDeconv2D(Conv):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding='valid', dilation_rate=1,
                 activation=None, use_bias=True, groups=1, eps=1e-5, n_iter=5, momentum=0.1, block=64,
                 sampling_stride=3, freeze=False, freeze_iter=100, kernel_initializer='he_uniform',
                 bias_initializer=BiasHeUniform(), **kwargs):
        self.in_channels = in_channels
        self.groups = groups
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter = 0
        self.track_running_stats = True

        if in_channels % self.groups != 0:
            raise ValueError(
                'The number of input channels must be evenly divisible by the number '
                'of groups. Received groups={}, but the input has {} channels '.format(self.groups,
                                                                                     in_channels))
        if out_channels is not None and out_channels % self.groups != 0:
            raise ValueError(
                'The number of filters must be evenly divisible by the number of '
                'groups. Received: groups={}, filters={}'.format(groups, out_channels))

        super(FastDeconv2D, self).__init__(
            2, out_channels, kernel_size, stride, padding, dilation_rate=dilation_rate,
            activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer, **kwargs
        )

        if block > in_channels:
            block = in_channels
        else:
            if in_channels % block != 0:
                block = math.gcd(block, in_channels)
                print("`in_channels` not divisible by `block`, computing new `block` value as %d" % (block))

        if groups > 1:
            block = in_channels // groups

        self.block = block

        kernel_size_int_0 = kernel_size[0] if type(kernel_size) in (list, tuple) else kernel_size
        kernel_size_int_1 = kernel_size[1] if type(kernel_size) in (list, tuple) else kernel_size
        self.num_features = kernel_size_int_0 * kernel_size_int_1 * block

        if self.groups == 1:
            self.running_mean = tf.Variable(tf.zeros(self.num_features), trainable=False, dtype=self.dtype)
            self.running_deconv = tf.Variable(tf.eye(self.num_features), trainable=False, dtype=self.dtype)
        else:
            self.running_mean = tf.Variable(tf.zeros(kernel_size_int_0 * kernel_size_int_1 * in_channels),
                                            trainable=False, dtype=self.dtype)

            deconv_buff = tf.eye(self.num_features)
            deconv_buff = tf.expand_dims(deconv_buff, axis=0)
            deconv_buff = tf.tile(deconv_buff, [in_channels // block, 1, 1])
            self.running_deconv = tf.Variable(deconv_buff, trainable=False, dtype=self.dtype)

        stride_int = stride[0] if type(stride) in (list, tuple) else stride
        self.sampling_stride = sampling_stride * stride_int
        self.counter = tf.Variable(tf.convert_to_tensor(0, dtype=tf.int64), trainable=False)
        self.freeze_iter = freeze_iter
        self.freeze = freeze

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        kernel_shape = self.kernel_size + (input_channel // self.groups, self.filters)

        self.kernel = self.add_weight(
            name='kernel',
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
            dtype=self.dtype)
        if self.use_bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
                dtype=self.dtype)
        else:
            self.bias = None
        channel_axis = self._get_channel_axis()

        # change rank based on 3d or 4d tensor input
        ndim = len(input_shape)

        self.input_spec = tf.keras.layers.InputSpec(min_ndim=3,
                                                    max_ndim=4,
                                                    axes={channel_axis: input_channel})

        self._build_conv_op_input_shape = input_shape
        self._build_input_channel = input_channel
        self._padding_op = self._get_padding_op()
        self._conv_op_data_format = conv_utils.convert_data_format(
            self.data_format, self.rank + 2)
        self.built = True

    @tf.function(experimental_compile=False)
    def call(self, x, training=None):
        x_shape = tf.shape(x)
        N, H, W, C = x_shape[0], x_shape[1], x_shape[2], x_shape[3]

        block = self.block
        frozen = self.freeze and (self.counter > self.freeze_iter)

        if training and self.track_running_stats:
            counter = self.counter + 1
            counter = counter % (self.freeze_iter * 10)
            self.counter.assign(counter)

        if training and (not frozen):

            # 1. im2col: N x cols x pixels -> N*pixels x cols
            if self.kernel_size[0] > 1:
                # [N, L, L, C * K^2]
                X = tf.image.extract_patches(x,
                                             sizes=[1] + list(self.kernel_size) + [1],
                                             strides=[1, self.sampling_stride, self.sampling_stride, 1],
                                             rates=[1, self.dilation_rate[0], self.dilation_rate[1], 1],
                                             padding=str(self.padding).upper())

                X = tf.reshape(X, [N, -1, C * self.kernel_size[0] * self.kernel_size[1]])  # [N, L^2, C * K^2]

            else:
                # channel-wise: [N, H, W, C] -> [N * H * W, C] -> [N * H / S * W / S, C]
                X = tf.reshape(x, [-1, C])[::self.sampling_stride ** 2, :]

            if self.groups == 1:
                # (C//B*N*pixels,k*k*B)
                X = tf.reshape(X, [-1, self.num_features, C // block])
                X = tf.transpose(X, [0, 2, 1])
                X = tf.reshape(X, [-1, self.num_features])
            else:
                X_shape_ = tf.shape(X)
                X = tf.reshape(X, [-1, X_shape_[-1]])  # [N, L^2, C * K^2] -> [N * L^2, C * K^2]

            # 2. subtract mean
            X_mean = tf.reduce_mean(X, axis=0)
            X = X - tf.expand_dims(X_mean, axis=0)

            # 3. calculate COV, COV^(-0.5), then deconv
            if self.groups == 1:
                scale = tf.cast(tf.shape(X)[0], X.dtype)
                Id = tf.eye(X.shape[1], dtype=X.dtype)
                # addmm op
                Cov = self.eps * Id + (1. / scale) * tf.matmul(tf.transpose(X), X)
                deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            else:
                X = tf.reshape(X, [-1, self.groups, self.num_features])
                X = tf.transpose(X, [1, 0, 2])  # [groups, -1, num_features]

                Id = tf.eye(self.num_features, dtype=X.dtype)
                Id = tf.broadcast_to(Id, [self.groups, self.num_features, self.num_features])

                scale = tf.cast(tf.shape(X)[1], X.dtype)
                Cov = self.eps * Id + (1. / scale) * tf.matmul(tf.transpose(X, [0, 2, 1]), X)

                deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)

            if self.track_running_stats:
                running_mean = self.momentum * X_mean + (1. - self.momentum) * self.running_mean
                running_deconv = self.momentum * deconv + (1. - self.momentum) * self.running_deconv

                # track stats for evaluation
                self.running_mean.assign(running_mean)
                self.running_deconv.assign(running_deconv)

        else:
            X_mean = self.running_mean
            deconv = self.running_deconv

        # 4. X * deconv * conv = X * (deconv * conv)
        if self.groups == 1:
            w = tf.reshape(self.kernel, [C // block, self.num_features, -1])
            w = tf.transpose(w, [0, 2, 1])
            w = tf.reshape(w, [-1, self.num_features])
            w = tf.matmul(w, deconv)

            if self.use_bias:
                b_dash = tf.matmul(w, (tf.expand_dims(X_mean, axis=-1)))
                b_dash = tf.reshape(b_dash, [self.filters, -1])
                b_dash = tf.reduce_sum(b_dash, axis=1)
                b = self.bias - b_dash
            else:
                b = 0.

            w = tf.reshape(w, [C // block, -1, self.num_features])
            w = tf.transpose(w, [0, 2, 1])

        else:
            w = tf.reshape(self.kernel, [C // block, -1, self.num_features])
            w = tf.matmul(w, deconv)

            if self.use_bias:
                b_dash = tf.matmul(w, tf.reshape(X_mean, [-1, self.num_features, 1]))
                b_dash = tf.reshape(b_dash, self.bias.shape)
                b = self.bias - b_dash
            else:
                b = 0.

        w = tf.reshape(w, self.kernel.shape)

        x = tf.nn.conv2d(x, w, self.strides, str(self.padding).upper(), dilations=self.dilation_rate)
        if self.use_bias:
            x = tf.nn.bias_add(x, b, data_format="NHWC")

        if self.activation is not None:
            return self.activation(x)
        else:
            return x


""" 1D Compat layers """


class ChannelDeconv1D(ChannelDeconv2D):

    def __init__(self, block, eps=1e-5, n_iter=5, momentum=0.1, sampling_stride=3, **kwargs):
        super(ChannelDeconv1D, self).__init__(block=block, eps=eps, n_iter=n_iter,
                                              momentum=momentum, sampling_stride=sampling_stride, **kwargs)

        self.input_spec = tf.keras.layers.InputSpec(min_ndim=2, max_ndim=3)

    @tf.function
    def call(self, x, training=None):
        # insert a dummy width dimension after the time axis
        shape = x.shape

        if len(shape) == 3:
            x_expanded = tf.expand_dims(x, axis=2)  # [N, T, 1, C]
        else:
            x_expanded = x

        out = super(ChannelDeconv1D, self).call(x_expanded, training=training)

        if len(shape) == 3:
            # remove dummy dimension (channel deconv preserves the input shape)
            x = tf.squeeze(out, axis=2)  # [N, T, C]
        else:
            x = out

        return x


class FastDeconv1D(FastDeconv2D):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding='valid', dilation_rate=1,
                 activation=None, use_bias=True, groups=1, eps=1e-5, n_iter=5, momentum=0.1, block=64,
                 sampling_stride=3, freeze=False, freeze_iter=100, kernel_initializer='he_uniform',
                 bias_initializer=BiasHeUniform(), **kwargs):
        kernel_size = (kernel_size, 1)
        stride = (stride, 1)
        super(FastDeconv1D, self).__init__(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, stride=stride, padding=padding,
                                           dilation_rate=dilation_rate, activation=activation,
                                           use_bias=use_bias, groups=groups, eps=eps,
                                           n_iter=n_iter, momentum=momentum, block=block,
                                           sampling_stride=sampling_stride, freeze=freeze, freeze_iter=freeze_iter,
                                           kernel_initializer=kernel_initializer, bias_initializer=bias_initializer,
                                           **kwargs)

        self.input_spec = tf.keras.layers.InputSpec(ndim=3)

    @tf.function(experimental_compile=False)
    def call(self, x, training=None):
        # insert a dummy width dimension after the time axis
        x_expanded = tf.expand_dims(x, axis=2)  # [N, T, 1, C]

        out = super(FastDeconv1D, self).call(x_expanded, training=training)

        # remove dummy dimension
        x = tf.squeeze(out, axis=2)  # [N, T / stride, C']
        return x

--------------------------------------------------------------------------------
/tf_deconv_mixed_prec.py:
--------------------------------------------------------------------------------
import math
import tensorflow as tf

from tensorflow.python.keras.layers.convolutional import Conv
from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras.mixed_precision import experimental as mixed_precision


class BiasHeUniform(tf.keras.initializers.VarianceScaling):
    def __init__(self, seed=None):
        super(BiasHeUniform, self).__init__(scale=1. / 3., mode='fan_in', distribution='uniform', seed=seed)


# iteratively solve for the inverse square root of a matrix (computed in float32)
def isqrt_newton_schulz_autograd(A: tf.Tensor, numIters: int):
    dim = tf.shape(A)[0]
    A_dtype = A.dtype
    A = tf.cast(A, tf.float32)

    normA = tf.norm(A, ord='fro', axis=[0, 1])
    Y = A / normA

    with tf.device(A.device):
        I = tf.eye(dim, dtype=A.dtype)
        Z = tf.eye(dim, dtype=A.dtype)

    for i in range(numIters):
        T = 0.5 * (3.0 * I - tf.matmul(Z, Y))
        Y = tf.matmul(Y, T)
        Z = tf.matmul(T, Z)

    A_isqrt = Z / tf.sqrt(normA)
    A_isqrt = tf.cast(A_isqrt, A_dtype)
    return A_isqrt


def isqrt_newton_schulz_autograd_batch(A: tf.Tensor, numIters: int):
    Ashape = tf.shape(A)  # [batch, _, C]
    batchSize, dim = Ashape[0], Ashape[-1]

    A_dtype = A.dtype
    A = tf.cast(A, tf.float32)

    normA = tf.reshape(A, (batchSize, -1))
    normA = tf.norm(normA, ord=2, axis=1)
    normA = tf.reshape(normA, [batchSize, 1, 1])

    Y = A / normA

    with tf.device(A.device):
        I = tf.expand_dims(tf.eye(dim, dtype=A.dtype), 0)
        Z = tf.expand_dims(tf.eye(dim, dtype=A.dtype), 0)

        I = tf.broadcast_to(I, Ashape)
        Z = tf.broadcast_to(Z, Ashape)

    for i in range(numIters):
        T = 0.5 * (3.0 * I - tf.matmul(Z, Y))
        Y = tf.matmul(Y, T)
        Z = tf.matmul(T, Z)

    A_isqrt = Z / tf.sqrt(normA)
    A_isqrt = tf.cast(A_isqrt, A_dtype)

    return A_isqrt


class ChannelDeconv2D(tf.keras.layers.Layer):
    def __init__(self, block, eps=1e-5, n_iter=5, momentum=0.1, sampling_stride=3, **kwargs):
        super(ChannelDeconv2D, self).__init__(**kwargs)

        self.eps = eps
        self.n_iter = n_iter
        self.momentum = momentum
        self.block = block
        self.sampling_stride = sampling_stride

        self.running_mean1 = tf.Variable(tf.zeros([block, 1]), trainable=False, dtype='float32')
        self.running_mean2 = tf.Variable(tf.zeros([]), trainable=False, dtype='float32')
        self.running_var = tf.Variable(tf.ones([]), trainable=False, dtype='float32')
        self.running_deconv = tf.Variable(tf.eye(block), trainable=False, dtype='float32')
        self.num_batches_tracked = tf.Variable(tf.convert_to_tensor(0, dtype=tf.int64), trainable=False)

        self.block_eye = tf.eye(block)

    def build(self, input_shape):
        in_channels = input_shape[-1]
        self.in_channels = in_channels

        if int(in_channels / self.block) * self.block == 0:
            raise ValueError("`block` must be smaller than in_channels.")

        # change rank based on 3d or 4d tensor input
        channel_axis = -1
        self.input_spec = tf.keras.layers.InputSpec(min_ndim=2,
                                                    max_ndim=4,
                                                    axes={channel_axis: in_channels})

        self.built = True

    @tf.function
    def call(self, x, training=None):
        x_shape = tf.shape(x)
        x_original_shape = x_shape

        if len(x.shape) == 2:
            x = tf.reshape(x, [x_shape[0], 1, 1, x_shape[1]])

        x_shape = tf.shape(x)

        N, H, W, C = x_shape[0], x_shape[1], x_shape[2], x_shape[3]
        block = self.block

        # take the first c channels out for deconv
        c = tf.cast(C / block, tf.int32) * block

        # step 1. remove mean
        if c != C:
            x1 = tf.reshape(tf.transpose(x[:, :, :, :c], [3, 0, 1, 2]), [block, -1])
        else:
            x1 = tf.reshape(tf.transpose(x, [3, 0, 1, 2]), [block, -1])

        if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
            x1_s = x1[:, ::self.sampling_stride ** 2]
        else:
            x1_s = x1

        mean1 = tf.reduce_mean(x1_s, axis=-1, keepdims=True)  # [blocks, 1]

        if self.num_batches_tracked == 0:
            self.running_mean1.assign(tf.cast(mean1, self.running_mean1.dtype))

        if training:
            running_mean1 = tf.cast(self.momentum * mean1, self.running_mean1.dtype) + (1. - self.momentum) * self.running_mean1
            self.running_mean1.assign(running_mean1)
        else:
            mean1 = tf.cast(self.running_mean1, mean1.dtype)

        x1 = x1 - mean1

        # step 2. calculate deconv@x1 = cov^(-0.5)@x1
        if training:
            scale_ = tf.cast(tf.shape(x1_s)[1], x1_s.dtype)
            cov = (tf.matmul(x1_s, tf.transpose(x1_s)) / scale_) + tf.cast(self.eps * self.block_eye, x1_s.dtype)
            deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
        else:
            deconv = tf.cast(self.running_deconv, x1_s.dtype)

        if self.num_batches_tracked == 0:
            self.running_deconv.assign(tf.cast(deconv, self.running_deconv.dtype))

        if training:
            running_deconv = tf.cast(self.momentum * deconv, self.running_deconv.dtype) + (1. - self.momentum) * self.running_deconv
            self.running_deconv.assign(running_deconv)
        else:
            deconv = tf.cast(self.running_deconv, x.dtype)

        deconv = tf.cast(deconv, x1.dtype)
        x1 = tf.matmul(deconv, x1)

        # reshape to [c, N, H, W], then move channels last
        x1 = tf.reshape(x1, [c, N, H, W])
        x1 = tf.transpose(x1, [1, 2, 3, 0])  # [N, H, W, c]

        # normalize the remaining channels
        if c != C:
            x_tmp = tf.reshape(x[:, :, :, c:], [N, -1])

            if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
                x_s = x_tmp[:, ::self.sampling_stride ** 2]
            else:
                x_s = x_tmp

            x_s = tf.cast(x_s, self.running_mean2.dtype)
            mean2, var = tf.nn.moments(x_s, axes=[0, 1])

            if self.num_batches_tracked == 0:
                self.running_mean2.assign(mean2)
                self.running_var.assign(var)

            if training:
                running_mean2 = self.momentum * mean2 + (1. - self.momentum) * self.running_mean2
                running_var = self.momentum * var + (1. - self.momentum) * self.running_var
                self.running_mean2.assign(running_mean2)
                self.running_var.assign(running_var)
            else:
                mean2 = self.running_mean2
                var = self.running_var

            # standardize in float32: (x - mean) / sqrt(var + eps)
            x_tmp = (tf.cast(x[:, :, :, c:], mean2.dtype) - mean2) / tf.sqrt(var + self.eps)
            x_tmp = tf.cast(x_tmp, x1.dtype)
            x1 = tf.concat([x1, x_tmp], axis=-1)

        if training:
            self.num_batches_tracked.assign_add(1)

        if len(x_original_shape) == 2:
            x1 = tf.reshape(x1, x_original_shape)
        else:
            x_intshape = x.shape
            x1 = tf.ensure_shape(x1, x_intshape)

        return x1


class FastDeconv2D(Conv):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding='valid', dilation_rate=1,
                 activation=None, use_bias=True, groups=1, eps=1e-5, n_iter=5, momentum=0.1, block=64,
                 sampling_stride=3, freeze=False, freeze_iter=100, kernel_initializer='he_uniform',
                 bias_initializer=BiasHeUniform(), **kwargs):
        self.in_channels = in_channels
        self.groups = groups
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter = 0
        self.track_running_stats = True

        if in_channels % self.groups != 0:
            raise ValueError(
                'The number of input channels must be evenly divisible by the number '
                'of groups. Received groups={}, but the input has {} channels '.format(self.groups,
                                                                                     in_channels))
        if out_channels is not None and out_channels % self.groups != 0:
            raise ValueError(
                'The number of filters must be evenly divisible by the number of '
                'groups. Received: groups={}, filters={}'.format(groups, out_channels))

        super(FastDeconv2D, self).__init__(
            2, out_channels, kernel_size, stride, padding, dilation_rate=dilation_rate,
            activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer, **kwargs
        )

        if block > in_channels:
            block = in_channels
        else:
            if in_channels % block != 0:
                block = math.gcd(block, in_channels)
                print("`in_channels` not divisible by `block`, computing new `block` value as %d" % (block))

        if groups > 1:
            block = in_channels // groups

        self.block = block

        kernel_size_int_0 = kernel_size[0] if type(kernel_size) in (list, tuple) else kernel_size
        kernel_size_int_1 = kernel_size[1] if type(kernel_size) in (list, tuple) else kernel_size
        self.num_features = kernel_size_int_0 * kernel_size_int_1 * block

        if self.groups == 1:
            self.running_mean = tf.Variable(tf.zeros(self.num_features), trainable=False, dtype='float32')
            self.running_deconv = tf.Variable(tf.eye(self.num_features), trainable=False, dtype='float32')
        else:
            self.running_mean = tf.Variable(tf.zeros(kernel_size_int_0 * kernel_size_int_1 * in_channels),
                                            trainable=False, dtype='float32')

            deconv_buff = tf.eye(self.num_features)
            deconv_buff = tf.expand_dims(deconv_buff, axis=0)
            deconv_buff = tf.tile(deconv_buff, [in_channels // block, 1, 1])
            self.running_deconv = tf.Variable(deconv_buff, trainable=False, dtype='float32')

        stride_int = stride[0] if type(stride) in (list, tuple) else stride
        self.sampling_stride = sampling_stride * stride_int
        self.counter = tf.Variable(tf.convert_to_tensor(0, dtype=tf.int64), trainable=False)
        self.freeze_iter = freeze_iter
        self.freeze = freeze

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
input_channel = self._get_input_channel(input_shape) 273 | kernel_shape = self.kernel_size + (input_channel // self.groups, self.filters) 274 | 275 | self.kernel = self.add_weight( 276 | name='kernel', 277 | shape=kernel_shape, 278 | initializer=self.kernel_initializer, 279 | regularizer=self.kernel_regularizer, 280 | constraint=self.kernel_constraint, 281 | trainable=True, 282 | dtype=self.dtype) 283 | if self.use_bias: 284 | self.bias = self.add_weight( 285 | name='bias', 286 | shape=(self.filters,), 287 | initializer=self.bias_initializer, 288 | regularizer=self.bias_regularizer, 289 | constraint=self.bias_constraint, 290 | trainable=True, 291 | dtype=self.dtype) 292 | else: 293 | self.bias = None 294 | channel_axis = self._get_channel_axis() 295 | 296 | # change rank based on 3d or 4d tensor input 297 | ndim = len(input_shape) 298 | 299 | self.input_spec = tf.keras.layers.InputSpec(min_ndim=3, 300 | max_ndim=4, 301 | axes={channel_axis: input_channel}) 302 | 303 | self._build_conv_op_input_shape = input_shape 304 | self._build_input_channel = input_channel 305 | self._padding_op = self._get_padding_op() 306 | self._conv_op_data_format = conv_utils.convert_data_format( 307 | self.data_format, self.rank + 2) 308 | self.built = True 309 | 310 | @tf.function(experimental_compile=False) 311 | def call(self, x, training=None): 312 | x_shape = tf.shape(x) 313 | N, H, W, C = x_shape[0], x_shape[1], x_shape[2], x_shape[3] 314 | 315 | block = self.block 316 | frozen = self.freeze and (self.counter > self.freeze_iter) 317 | 318 | if training and self.track_running_stats: 319 | counter = self.counter + 1 320 | counter = counter % (self.freeze_iter * 10) 321 | self.counter.assign(counter) 322 | 323 | if training and (not frozen): 324 | 325 | # 1. 
im2col: N x cols x pixels -> N*pixels x cols 326 | if self.kernel_size[0] > 1: 327 | # [N, L, L, C * K^2] 328 | X = tf.image.extract_patches(tf.cast(x, tf.float32), 329 | sizes=[1] + list(self.kernel_size) + [1], 330 | strides=[1, self.sampling_stride, self.sampling_stride, 1], 331 | rates=[1, self.dilation_rate[0], self.dilation_rate[1], 1], 332 | padding=str(self.padding).upper()) 333 | 334 | X = tf.cast(X, x.dtype) 335 | X = tf.reshape(X, [N, -1, C * self.kernel_size[0] * self.kernel_size[1]]) # [N, L^2, C * K^2] 336 | 337 | else: 338 | # channel-wise: [N, H, W, C] -> [N * H * W, C] -> [N * H * W / S^2, C] 339 | X = tf.reshape(x, [-1, C])[::self.sampling_stride ** 2, :] 340 | 341 | if self.groups == 1: 342 | # (C//B*N*pixels,k*k*B) 343 | X = tf.reshape(X, [-1, self.num_features, C // block]) 344 | X = tf.transpose(X, [0, 2, 1]) 345 | X = tf.reshape(X, [-1, self.num_features]) 346 | else: 347 | X_shape_ = tf.shape(X) 348 | X = tf.reshape(X, [-1, X_shape_[-1]]) # [N, L^2, C * K^2] -> [N * L^2, C * K^2] 349 | 350 | # 2. subtract mean 351 | # X = tf.cast(X, tf.float32) 352 | X_mean = tf.reduce_mean(X, axis=0) 353 | X = X - tf.expand_dims(X_mean, axis=0) 354 | 355 | # 3. calculate COV, COV^(-0.5), then deconv 356 | if self.groups == 1: 357 | scale = tf.cast(tf.shape(X)[0], X.dtype) 358 | Id = tf.eye(X.shape[1], dtype=X.dtype) 359 | # addmm op 360 | Cov = self.eps * Id + (1. / scale) * tf.matmul(tf.transpose(X), X) 361 | deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter) 362 | else: 363 | X = tf.cast(X, tf.float32) 364 | X = tf.reshape(X, [-1, self.groups, self.num_features]) 365 | X = tf.transpose(X, [1, 0, 2]) # [groups, -1, num_features] 366 | 367 | Id = tf.eye(self.num_features, dtype=X.dtype) 368 | Id = tf.broadcast_to(Id, [self.groups, self.num_features, self.num_features]) 369 | 370 | scale = tf.cast(tf.shape(X)[1], X.dtype) 371 | Cov = self.eps * Id + (1.
/ scale) * tf.matmul(tf.transpose(X, [0, 2, 1]), X) 372 | 373 | deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter) 374 | deconv = tf.cast(deconv, x.dtype) 375 | 376 | if self.track_running_stats: 377 | running_mean = tf.cast(self.momentum * X_mean, self.running_mean.dtype) + (1. - self.momentum) * self.running_mean 378 | running_deconv = tf.cast(self.momentum * deconv, self.running_deconv.dtype) + (1. - self.momentum) * self.running_deconv 379 | 380 | # track stats for evaluation 381 | self.running_mean.assign(running_mean) 382 | self.running_deconv.assign(running_deconv) 383 | 384 | else: 385 | X_mean = tf.cast(self.running_mean, x.dtype) 386 | deconv = tf.cast(self.running_deconv, x.dtype) 387 | 388 | # 4. X * deconv * conv = X * (deconv * conv) 389 | if self.groups == 1: 390 | w = tf.reshape(self.kernel, [C // block, self.num_features, -1]) 391 | w = tf.transpose(w, [0, 2, 1]) 392 | w = tf.reshape(w, [-1, self.num_features]) 393 | w = tf.matmul(w, tf.cast(deconv, w.dtype)) 394 | 395 | if self.use_bias: 396 | b_dash = tf.matmul(w, tf.cast(tf.expand_dims(X_mean, axis=-1), dtype=w.dtype)) 397 | b_dash = tf.reshape(b_dash, [self.filters, -1]) 398 | b_dash = tf.reduce_sum(b_dash, axis=1) 399 | b = self.bias - b_dash 400 | else: 401 | b = 0. 402 | 403 | w = tf.reshape(w, [C // block, -1, self.num_features]) 404 | w = tf.transpose(w, [0, 2, 1]) 405 | 406 | else: 407 | w = tf.reshape(self.kernel, [C // block, -1, self.num_features]) 408 | w = tf.matmul(w, tf.cast(deconv, w.dtype)) 409 | 410 | if self.use_bias: 411 | b_dash = tf.matmul(w, tf.cast(tf.reshape(X_mean, [-1, self.num_features, 1]), dtype=w.dtype)) 412 | b_dash = tf.reshape(b_dash, self.bias.shape) 413 | b = self.bias - b_dash 414 | else: 415 | b = 0. 
416 | 417 | w = tf.reshape(w, self.kernel.shape) 418 | 419 | x = tf.nn.conv2d(x, w, self.strides, str(self.padding).upper(), dilations=self.dilation_rate) 420 | if self.use_bias: 421 | x = tf.nn.bias_add(x, b, data_format="NHWC") 422 | 423 | if self.activation is not None: 424 | return self.activation(x) 425 | else: 426 | return x 427 | 428 | 429 | """ 1D Compat layers """ 430 | 431 | 432 | class ChannelDeconv1D(ChannelDeconv2D): 433 | 434 | def __init__(self, block, eps=1e-5, n_iter=5, momentum=0.1, sampling_stride=3, **kwargs): 435 | super(ChannelDeconv1D, self).__init__(block=block, eps=eps, n_iter=n_iter, 436 | momentum=momentum, sampling_stride=sampling_stride, **kwargs) 437 | 438 | self.input_spec = tf.keras.layers.InputSpec(min_ndim=2, max_ndim=3) 439 | 440 | @tf.function 441 | def call(self, x, training=None): 442 | # insert dummy dimension in time channel 443 | shape = x.shape 444 | 445 | if len(shape) == 3: 446 | x_expanded = tf.expand_dims(x, axis=2) # [N, T, 1, C] 447 | else: 448 | x_expanded = x 449 | 450 | out = super(ChannelDeconv1D, self).call(x_expanded, training=training) 451 | 452 | if len(shape) == 3: 453 | # remove dummy dimension 454 | x = tf.squeeze(out, axis=2) # [N, T / stride, C'] 455 | else: 456 | x = out 457 | 458 | return x 459 | 460 | 461 | class FastDeconv1D(FastDeconv2D): 462 | 463 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding='valid', dilation_rate=1, 464 | activation=None, use_bias=True, groups=1, eps=1e-5, n_iter=5, momentum=0.1, block=64, 465 | sampling_stride=3, freeze=False, freeze_iter=100, kernel_initializer='he_uniform', 466 | bias_initializer=BiasHeUniform(), **kwargs): 467 | kernel_size = (kernel_size, 1) 468 | stride = (stride, 1) 469 | super(FastDeconv1D, self).__init__(in_channels=in_channels, out_channels=out_channels, 470 | kernel_size=kernel_size, stride=stride, padding=padding, 471 | dilation_rate=dilation_rate, activation=activation, 472 | use_bias=use_bias, groups=groups, eps=eps, 
473 | n_iter=n_iter, momentum=momentum, block=block, 474 | sampling_stride=sampling_stride, freeze=freeze, freeze_iter=freeze_iter, 475 | kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, 476 | **kwargs) 477 | 478 | self.input_spec = tf.keras.layers.InputSpec(ndim=3) 479 | 480 | @tf.function(experimental_compile=False) 481 | def call(self, x, training=None): 482 | # insert dummy dimension in time channel 483 | x_expanded = tf.expand_dims(x, axis=2) # [N, T, 1, C] 484 | 485 | out = super(FastDeconv1D, self).call(x_expanded, training=training) 486 | 487 | # remove dummy dimension 488 | x = tf.squeeze(out, axis=2) # [N, T / stride, C'] 489 | return x 490 | -------------------------------------------------------------------------------- /train_cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import datetime as dt 4 | import tensorflow as tf 5 | 6 | from utils.optim import AdamW, SGDW 7 | from utils.schedule import CosineDecay 8 | 9 | # from models.simplenet2d import SimpleNet2D  # disabled: there is no models/simplenet2d.py in this repo, and SimpleNet2D is unused below 10 | from models.vgg import VGG 11 | 12 | 13 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() 14 | 15 | CIFAR_10_MEAN = tf.convert_to_tensor([0.4914, 0.4822, 0.4465]) 16 | CIFAR_10_STD = tf.convert_to_tensor([0.2023, 0.1994, 0.2010]) 17 | CIFAR_MEAN = tf.reshape(CIFAR_10_MEAN, [1, 1, 3]) 18 | CIFAR_STD = tf.reshape(CIFAR_10_STD, [1, 1, 3]) 19 | 20 | 21 | def augment(x, y): 22 | x = tf.image.resize_with_crop_or_pad(x, 36, 36) 23 | x = tf.image.random_crop(x, [32, 32, 3]) 24 | x = tf.image.random_flip_left_right(x) 25 | return x, y 26 | 27 | 28 | # Dataset pipelines 29 | train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000) 30 | train_dataset = train_dataset.map(lambda x, y: (tf.cast(x, tf.float32) / 255., y)) 31 | train_dataset = train_dataset.map(augment, num_parallel_calls=os.cpu_count()) 32 | train_dataset = train_dataset.map(lambda
x, y: ((x - CIFAR_MEAN) / CIFAR_STD, y)) 33 | train_dataset = train_dataset.batch(128) 34 | train_dataset = train_dataset.prefetch(4) 35 | 36 | test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) 37 | test_dataset = test_dataset.map(lambda x, y: (tf.cast(x, tf.float32) / 255., y)) 38 | test_dataset = test_dataset.map(lambda x, y: ((x - CIFAR_MEAN) / CIFAR_STD, y)) 39 | test_dataset = test_dataset.batch(100) 40 | 41 | logdir = './log_cifar10/{}'.format(dt.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) 42 | if not os.path.exists(logdir): 43 | os.makedirs(logdir) 44 | 45 | if not os.path.exists('checkpoints/cifar10/'): 46 | os.makedirs('checkpoints/cifar10/') 47 | 48 | callbacks = [ 49 | # Write TensorBoard logs to the timestamped `logdir` under ./log_cifar10/ 50 | tf.keras.callbacks.TensorBoard(log_dir=logdir, update_freq='batch', profile_batch=0), 51 | tf.keras.callbacks.ModelCheckpoint('checkpoints/cifar10/', monitor='loss', 52 | verbose=2, save_best_only=True, 53 | save_weights_only=True, mode='min') 54 | ] 55 | 56 | # model = SimpleNet(num_classes=10, num_channels=64) 57 | model = VGG(vgg_name='VGG16', num_classes=10) 58 | 59 | epochs = 1 # should be 1, 20 or 100 60 | 61 | # SGDW Optimizer 62 | total_steps = math.ceil(len(x_train) / float(128)) * max(1, epochs) 63 | lr = CosineDecay(0.1, decay_steps=total_steps, alpha=1e-6) 64 | optimizer = SGDW(lr, momentum=0.9, nesterov=True, weight_decay=0.001) 65 | 66 | model.compile(optimizer=optimizer, 67 | loss='sparse_categorical_crossentropy', 68 | metrics=['acc']) 69 | 70 | model.fit(train_dataset, epochs=epochs, 71 | validation_data=test_dataset, 72 | callbacks=callbacks) 73 | -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | # Imported from the TF Addons library - https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/SGDW 2 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Base class to make optimizers weight decay ready.""" 17 | 18 | import tensorflow as tf 19 | from typing import Union, Callable, Type 20 | 21 | 22 | def _ref(var): 23 | return var.ref() if hasattr(var, "ref") else var.experimental_ref() 24 | 25 | 26 | class DecoupledWeightDecayExtension: 27 | """This class allows optimizers to be extended with decoupled weight decay. 28 | It implements the decoupled weight decay described by Loshchilov & Hutter 29 | (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is 30 | decoupled from the optimization steps w.r.t. the loss function. 31 | For SGD variants, this simplifies hyperparameter search since it decouples 32 | the settings of weight decay and learning rate. 33 | For adaptive gradient algorithms, it regularizes variables with large 34 | gradients more than L2 regularization would, which was shown to yield 35 | better training loss and generalization error in the paper above. 36 | This class alone is not an optimizer but rather extends existing 37 | optimizers with decoupled weight decay. We explicitly define the two 38 | examples used in the above paper (SGDW and AdamW), but in general this 39 | can extend any OptimizerX by using 40 | `extend_with_decoupled_weight_decay(OptimizerX)`, which returns a new 41 | class whose constructor takes `weight_decay` as its first argument.
42 | In order for it to work, it must be the first class the Optimizer with 43 | weight decay inherits from, e.g. 44 | ```python 45 | class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): 46 | def __init__(self, weight_decay, *args, **kwargs): 47 | super(AdamW, self).__init__(weight_decay, *args, **kwargs) 48 | ``` 49 | Note: this extension decays weights BEFORE applying the update based 50 | on the gradient, i.e. this extension only has the desired behaviour for 51 | optimizers which do not depend on the value of 'var' in the update step! 52 | Note: when applying a decay to the learning rate, be sure to manually apply 53 | the decay to the `weight_decay` as well. For example: 54 | ```python 55 | step = tf.Variable(0, trainable=False) 56 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 57 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 58 | # lr and wd can be a function or a tensor 59 | lr = 1e-1 * schedule(step) 60 | wd = lambda: 1e-4 * schedule(step) 61 | # ... 62 | optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) 63 | ``` 64 | """ 65 | 66 | def __init__(self, weight_decay, **kwargs): 67 | """Extension class that adds weight decay to an optimizer. 68 | Args: 69 | weight_decay: A `Tensor` or a floating point value, the factor by 70 | which a variable is decayed in the update step. 71 | **kwargs: Keyword arguments forwarded to the base optimizer's 72 | constructor. 73 | """ 74 | wd = kwargs.pop("weight_decay", weight_decay) 75 | super().__init__(**kwargs) 76 | self._decay_var_list = None # is set in minimize or apply_gradients 77 | self._set_hyper("weight_decay", wd) 78 | 79 | def get_config(self): 80 | config = super().get_config() 81 | config.update( 82 | {"weight_decay": self._serialize_hyperparameter("weight_decay"),} 83 | ) 84 | return config 85 | 86 | def minimize(self, loss, var_list, grad_loss=None, name=None, decay_var_list=None): 87 | """Minimize `loss` by updating `var_list`.
88 | This method simply computes gradient using `tf.GradientTape` and calls 89 | `apply_gradients()`. If you want to process the gradient before 90 | applying then call `tf.GradientTape` and `apply_gradients()` explicitly 91 | instead of using this function. 92 | Args: 93 | loss: A callable taking no arguments which returns the value to 94 | minimize. 95 | var_list: list or tuple of `Variable` objects to update to 96 | minimize `loss`, or a callable returning the list or tuple of 97 | `Variable` objects. Use callable when the variable list would 98 | otherwise be incomplete before `minimize` since the variables 99 | are created at the first time `loss` is called. 100 | grad_loss: Optional. A `Tensor` holding the gradient computed for 101 | `loss`. 102 | decay_var_list: Optional list of variables to be decayed. Defaults 103 | to all variables in var_list. 104 | name: Optional name for the returned operation. 105 | Returns: 106 | An Operation that updates the variables in `var_list`. 107 | Raises: 108 | ValueError: If some of the variables are not `Variable` objects. 109 | """ 110 | self._decay_var_list = ( 111 | set([_ref(v) for v in decay_var_list]) if decay_var_list else False 112 | ) 113 | return super().minimize(loss, var_list=var_list, grad_loss=grad_loss, name=name) 114 | 115 | def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwargs): 116 | """Apply gradients to variables. 117 | This is the second part of `minimize()`. It returns an `Operation` that 118 | applies gradients. 119 | Args: 120 | grads_and_vars: List of (gradient, variable) pairs. 121 | name: Optional name for the returned operation. Default to the 122 | name passed to the `Optimizer` constructor. 123 | decay_var_list: Optional list of variables to be decayed. Defaults 124 | to all variables in var_list. 125 | **kwargs: Additional arguments to pass to the base optimizer's 126 | apply_gradient method, e.g., TF2.2 added an argument 127 | `all_reduce_sum_gradients`. 
128 | Returns: 129 | An `Operation` that applies the specified gradients. 130 | Raises: 131 | TypeError: If `grads_and_vars` is malformed. 132 | ValueError: If none of the variables have gradients. 133 | """ 134 | self._decay_var_list = ( 135 | set([_ref(v) for v in decay_var_list]) if decay_var_list else False 136 | ) 137 | return super().apply_gradients(grads_and_vars, name=name, **kwargs) 138 | 139 | def _decay_weights_op(self, var): 140 | if not self._decay_var_list or _ref(var) in self._decay_var_list: 141 | return var.assign_sub( 142 | self._get_hyper("weight_decay", var.dtype) * var, self._use_locking 143 | ) 144 | return tf.no_op() 145 | 146 | def _decay_weights_sparse_op(self, var, indices): 147 | if not self._decay_var_list or _ref(var) in self._decay_var_list: 148 | update = -self._get_hyper("weight_decay", var.dtype) * tf.gather( 149 | var, indices 150 | ) 151 | return self._resource_scatter_add(var, indices, update) 152 | return tf.no_op() 153 | 154 | # Here, we overwrite the apply functions that the base optimizer calls. 155 | # super().apply_x resolves to the apply_x function of the BaseOptimizer. 156 | 157 | def _resource_apply_dense(self, grad, var): 158 | with tf.control_dependencies([self._decay_weights_op(var)]): 159 | return super()._resource_apply_dense(grad, var) 160 | 161 | def _resource_apply_sparse(self, grad, var, indices): 162 | decay_op = self._decay_weights_sparse_op(var, indices) 163 | with tf.control_dependencies([decay_op]): 164 | return super()._resource_apply_sparse(grad, var, indices) 165 | 166 | 167 | def extend_with_decoupled_weight_decay( 168 | base_optimizer: Type[tf.keras.optimizers.Optimizer], 169 | ) -> Type[tf.keras.optimizers.Optimizer]: 170 | """Factory function returning an optimizer class with decoupled weight 171 | decay. 172 | Returns an optimizer class. An instance of the returned class computes the 173 | update step of `base_optimizer` and additionally decays the weights. 
174 | E.g., the class returned by 175 | `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is 176 | equivalent to `tfa.optimizers.AdamW`. 177 | The API of the new optimizer class slightly differs from the API of the 178 | base optimizer: 179 | - The first argument to the constructor is the weight decay rate. 180 | - `minimize` and `apply_gradients` accept the optional keyword argument 181 | `decay_var_list`, which specifies the variables that should be decayed. 182 | If `None`, all variables that are optimized are decayed. 183 | Usage example: 184 | ```python 185 | MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)  # a new class 186 | optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) 187 | # update var1, var2 but only decay var1 188 | optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1]) 189 | ``` 190 | 191 | Note: this extension decays weights BEFORE applying the update based 192 | on the gradient, i.e. this extension only has the desired behaviour for 193 | optimizers which do not depend on the value of 'var' in the update step! 194 | Note: when applying a decay to the learning rate, be sure to manually apply 195 | the decay to the `weight_decay` as well. For example: 196 | ```python 197 | step = tf.Variable(0, trainable=False) 198 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 199 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 200 | # lr and wd can be a function or a tensor 201 | lr = 1e-1 * schedule(step) 202 | wd = lambda: 1e-4 * schedule(step) 203 | # ... 204 | optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) 205 | ``` 206 | Note: you might want to register your own custom optimizer using 207 | `tf.keras.utils.get_custom_objects()`. 208 | Args: 209 | base_optimizer: An optimizer class that inherits from 210 | tf.optimizers.Optimizer.
211 | Returns: 212 | A new optimizer class that inherits from DecoupledWeightDecayExtension 213 | and base_optimizer. 214 | """ 215 | 216 | class OptimizerWithDecoupledWeightDecay( 217 | DecoupledWeightDecayExtension, base_optimizer 218 | ): 219 | """Base_optimizer with decoupled weight decay. 220 | This class computes the update step of `base_optimizer` and 221 | additionally decays the variable with the weight decay being 222 | decoupled from the optimization steps w.r.t. to the loss 223 | function, as described by Loshchilov & Hutter 224 | (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this 225 | simplifies hyperparameter search since it decouples the settings 226 | of weight decay and learning rate. For adaptive gradient 227 | algorithms, it regularizes variables with large gradients more 228 | than L2 regularization would, which was shown to yield better 229 | training loss and generalization error in the paper above. 230 | """ 231 | 232 | def __init__( 233 | self, weight_decay, *args, **kwargs 234 | ): 235 | # super delegation is necessary here 236 | super().__init__(weight_decay, *args, **kwargs) 237 | 238 | return OptimizerWithDecoupledWeightDecay 239 | 240 | 241 | class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): 242 | """Optimizer that implements the Momentum algorithm with weight_decay. 243 | This is an implementation of the SGDW optimizer described in "Decoupled 244 | Weight Decay Regularization" by Loshchilov & Hutter 245 | (https://arxiv.org/abs/1711.05101) 246 | ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). 247 | It computes the update step of `tf.keras.optimizers.SGD` and additionally 248 | decays the variable. Note that this is different from adding 249 | L2 regularization on the variables to the loss. Decoupling the weight decay 250 | from other hyperparameters (in particular the learning rate) simplifies 251 | hyperparameter search. 252 | For further information see the documentation of the SGD Optimizer. 
253 | This optimizer can also be instantiated as 254 | ```python 255 | extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)( 256 | weight_decay=weight_decay) 257 | ``` 258 | Note: when applying a decay to the learning rate, be sure to manually apply 259 | the decay to the `weight_decay` as well. For example: 260 | ```python 261 | step = tf.Variable(0, trainable=False) 262 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 263 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 264 | # lr and wd can be a function or a tensor 265 | lr = 1e-1 * schedule(step) 266 | wd = lambda: 1e-4 * schedule(step) 267 | # ... 268 | optimizer = tfa.optimizers.SGDW( 269 | learning_rate=lr, weight_decay=wd, momentum=0.9) 270 | ``` 271 | """ 272 | 273 | def __init__( 274 | self, 275 | learning_rate=0.001, 276 | momentum=0.0, 277 | nesterov=False, 278 | weight_decay=0.0, 279 | name: str = "SGDW", 280 | **kwargs 281 | ): 282 | """Construct a new SGDW optimizer. 283 | For further information see the documentation of the SGD Optimizer. 284 | Args: 285 | learning_rate: float hyperparameter >= 0. Learning rate. 286 | momentum: float hyperparameter >= 0 that accelerates SGD in the 287 | relevant direction and dampens oscillations. 288 | nesterov: boolean. Whether to apply Nesterov momentum. 289 | name: Optional name prefix for the operations created when applying 290 | gradients. Defaults to 'SGDW'. 291 | **kwargs: keyword arguments. Allowed to be {`clipnorm`, 292 | `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients by 293 | norm; `clipvalue` clips gradients by value; `decay` is 294 | included for backward compatibility to allow time-inverse decay 295 | of the learning rate; `lr` is included for backward compatibility, 296 | but `learning_rate` is recommended instead.
297 | """ 298 | super().__init__( 299 | weight_decay, 300 | learning_rate=learning_rate, 301 | momentum=momentum, 302 | nesterov=nesterov, 303 | name=name, 304 | **kwargs, 305 | ) 306 | 307 | 308 | class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): 309 | """Optimizer that implements the Adam algorithm with weight decay. 310 | This is an implementation of the AdamW optimizer described in "Decoupled 311 | Weight Decay Regularization" by Loshchilov & Hutter 312 | (https://arxiv.org/abs/1711.05101) 313 | ([pdf](https://arxiv.org/pdf/1711.05101.pdf)). 314 | It computes the update step of `tf.keras.optimizers.Adam` and additionally 315 | decays the variable. Note that this is different from adding L2 316 | regularization on the variables to the loss: it regularizes variables with 317 | large gradients more than L2 regularization would, which was shown to yield 318 | better training loss and generalization error in the paper above. 319 | For further information see the documentation of the Adam Optimizer. 320 | This optimizer can also be instantiated as 321 | ```python 322 | extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)( 323 | weight_decay=weight_decay) 324 | ``` 325 | Note: when applying a decay to the learning rate, be sure to manually apply 326 | the decay to the `weight_decay` as well. For example: 327 | ```python 328 | step = tf.Variable(0, trainable=False) 329 | schedule = tf.optimizers.schedules.PiecewiseConstantDecay( 330 | [10000, 15000], [1e-0, 1e-1, 1e-2]) 331 | # lr and wd can be a function or a tensor 332 | lr = 1e-1 * schedule(step) 333 | wd = lambda: 1e-4 * schedule(step) 334 | # ...
335 | optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) 336 | ``` 337 | """ 338 | 339 | def __init__( 340 | self, 341 | learning_rate=0.001, 342 | beta_1=0.9, 343 | beta_2=0.999, 344 | epsilon=1e-07, 345 | amsgrad=False, 346 | weight_decay=0.0, 347 | name: str = "AdamW", 348 | **kwargs 349 | ): 350 | """Construct a new AdamW optimizer. 351 | For further information see the documentation of the Adam Optimizer. 352 | Args: 353 | weight_decay: A Tensor or a floating point value. The weight decay. 354 | learning_rate: A Tensor or a floating point value. The learning 355 | rate. 356 | beta_1: A float value or a constant float tensor. The exponential 357 | decay rate for the 1st moment estimates. 358 | beta_2: A float value or a constant float tensor. The exponential 359 | decay rate for the 2nd moment estimates. 360 | epsilon: A small constant for numerical stability. This epsilon is 361 | "epsilon hat" in the Kingma and Ba paper (in the formula just 362 | before Section 2.1), not the epsilon in Algorithm 1 of the 363 | paper. 364 | amsgrad: boolean. Whether to apply the AMSGrad variant of this 365 | algorithm from the paper "On the Convergence of Adam and 366 | Beyond". 367 | name: Optional name for the operations created when applying 368 | gradients. Defaults to "AdamW". 369 | **kwargs: keyword arguments. Allowed to be {`clipnorm`, 370 | `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients by 371 | norm; `clipvalue` clips gradients by value; `decay` is 372 | included for backward compatibility to allow time-inverse decay 373 | of the learning rate; `lr` is included for backward compatibility, 374 | but `learning_rate` is recommended instead.
375 | """ 376 | super().__init__( 377 | weight_decay, 378 | learning_rate=learning_rate, 379 | beta_1=beta_1, 380 | beta_2=beta_2, 381 | epsilon=epsilon, 382 | amsgrad=amsgrad, 383 | name=name, 384 | **kwargs, 385 | ) 386 | -------------------------------------------------------------------------------- /utils/schedule.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tensorflow as tf 3 | 4 | 5 | class CosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule): 6 | """A LearningRateSchedule that uses a cosine decay schedule.""" 7 | 8 | def __init__( 9 | self, 10 | initial_learning_rate, 11 | decay_steps, 12 | alpha=0.0, 13 | name=None): 14 | """Applies cosine decay to the learning rate. 15 | 16 | See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent 17 | with Warm Restarts. https://arxiv.org/abs/1608.03983 18 | 19 | When training a model, it is often recommended to lower the learning rate as 20 | the training progresses. This schedule applies a cosine decay function 21 | to an optimizer step, given a provided initial learning rate. 22 | It requires a `step` value to compute the decayed learning rate. You can 23 | just pass a TensorFlow variable that you increment at each training step. 24 | 25 | The schedule is a 1-arg callable that produces a decayed learning 26 | rate when passed the current optimizer step. This can be useful for changing 27 | the learning rate value across different invocations of optimizer functions.
28 | It is computed as: 29 | 30 | ```python 31 | def decayed_learning_rate(step): 32 | step = min(step, decay_steps) 33 | cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps)) 34 | decayed = (initial_learning_rate - alpha) * cosine_decay + alpha 35 | return decayed 36 | ``` 37 | 38 | Example usage: 39 | ```python 40 | decay_steps = 1000 41 | lr_decayed_fn = CosineDecay( 42 | initial_learning_rate, decay_steps) 43 | ``` 44 | 45 | You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` 46 | as the learning rate. The learning rate schedule is also serializable and 47 | deserializable using `tf.keras.optimizers.schedules.serialize` and 48 | `tf.keras.optimizers.schedules.deserialize`. 49 | 50 | Args: 51 | initial_learning_rate: A scalar `float32` or `float64` Tensor or a 52 | Python number. The initial learning rate. 53 | decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. 54 | Number of steps to decay over. 55 | alpha: A scalar `float32` or `float64` Tensor or a Python number. 56 | Minimum learning rate (an absolute value, reached at `decay_steps`). 57 | name: String. Optional name of the operation. Defaults to 'CosineDecay'. 58 | Returns: 59 | A 1-arg callable learning rate schedule that takes the current optimizer 60 | step and outputs the decayed learning rate, a scalar `Tensor` of the same 61 | type as `initial_learning_rate`.
62 | """ 63 | super(CosineDecay, self).__init__() 64 | 65 | self.initial_learning_rate = initial_learning_rate 66 | self.decay_steps = decay_steps 67 | self.alpha = alpha 68 | self.name = name 69 | 70 | def __call__(self, step): 71 | with tf.name_scope(self.name or "CosineDecay"): 72 | initial_learning_rate = tf.convert_to_tensor( 73 | self.initial_learning_rate, name="initial_learning_rate") 74 | dtype = initial_learning_rate.dtype 75 | decay_steps = tf.cast(self.decay_steps, dtype) 76 | 77 | global_step_recomp = tf.cast(step, dtype) 78 | global_step_recomp = tf.minimum(global_step_recomp, decay_steps) 79 | completed_fraction = global_step_recomp / decay_steps 80 | cosine_decayed = 0.5 * (1.0 + tf.cos( 81 | tf.constant(math.pi) * completed_fraction)) 82 | 83 | decayed = (self.initial_learning_rate - self.alpha) * cosine_decayed + self.alpha 84 | return decayed 85 | 86 | def get_config(self): 87 | return { 88 | "initial_learning_rate": self.initial_learning_rate, 89 | "decay_steps": self.decay_steps, 90 | "alpha": self.alpha, 91 | "name": self.name 92 | } 93 | --------------------------------------------------------------------------------
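A minimal plain-Python sketch of the arithmetic performed by `CosineDecay.__call__` in `utils/schedule.py`, handy for checking learning-rate values without building a TensorFlow graph. The function name `cosine_decay_lr` is hypothetical and not part of this repo. As in `__call__`, the step is clamped to `decay_steps`, and `alpha` is added after scaling `initial_learning_rate - alpha`, so it acts as an absolute floor rather than a fraction of the initial rate. The usage section assumes the 50,000-image CIFAR-10 training split and the batch size of 128 used in `train_cifar10.py`.

```python
import math


def cosine_decay_lr(step, initial_learning_rate, decay_steps, alpha=0.0):
    """Hypothetical plain-Python mirror of CosineDecay.__call__ (illustrative sketch)."""
    # Clamp the step so the rate stays at `alpha` once decay_steps is reached.
    step = min(step, decay_steps)
    completed_fraction = step / decay_steps
    cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
    # `alpha` is an absolute floor here, not a fraction of the initial rate.
    return (initial_learning_rate - alpha) * cosine_decayed + alpha


if __name__ == "__main__":
    # Settings mirroring train_cifar10.py: lr 0.1, alpha 1e-6, batch 128, 1 epoch.
    total_steps = math.ceil(50000 / 128.0) * 1  # 391 steps per epoch
    for step in (0, total_steps // 2, total_steps, total_steps + 100):
        print(step, cosine_decay_lr(step, 0.1, total_steps, alpha=1e-6))
```

With `alpha=1e-6`, the rate at the end of training is 1e-6 itself, not 1e-6 times the initial rate of 0.1.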