├── .ipynb_checkpoints └── README-checkpoint.md ├── LICENSE ├── README.md ├── paper-code ├── C-discover-symmetry │ ├── .ipynb_checkpoints │ │ └── L-conv-discover-large-angle2021-06-04-checkpoint.ipynb │ ├── L-conv-discover-2021-06-04.ipynb │ └── L-conv-discover-large-angle2021-06-04.ipynb └── D-image-experiments │ ├── L-conv-extra-exp-2020-11-21.ipynb │ ├── L-conv-extra-exp-2020-11-24.ipynb │ ├── L-conv-usage-vis-2020-11-23.ipynb │ ├── lconv.py │ ├── run_LieConv_cifar100.py │ ├── run_test-v2.py │ └── run_test.py └── src ├── .ipynb_checkpoints ├── examples-checkpoint.ipynb └── lconv-checkpoint.py ├── __pycache__ └── lconv.cpython-38.pyc ├── examples.ipynb └── lconv.py /.ipynb_checkpoints/README-checkpoint.md: -------------------------------------------------------------------------------- 1 | # Lie Algebra Convolutional Network (L-conv) implementation 2 | __Paper:__ [Automatic Symmetry Discovery with Lie Algebra Convolutional Network](https://papers.nips.cc/paper/2021/file/148148d62be67e0916a833931bd32b26-Paper.pdf) _Nima Dehmamy, Robin Walters, Yanchen Liu, Dashun Wang, Rose Yu_ NeurIPS 2021 3 | (find updated versions on [arXiv](https://arxiv.org/abs/2109.07103)) 4 | 5 | 6 | ## Contents 7 | A simple implementation of the L-conv layer in PyTorch (>=1.8) can be found in `src/lconv.py`. 8 | The L-conv layer acts similarly to a graph convolutional layer (GCN), so prepare your input in a similar fashion (e.g. flatten the spatial dimensions). 9 | The input should have shape `(batch, channels, #nodes)` (e.g. for an image, #nodes = #pixels). 10 | This repository also contains code and notebooks for the experiments in the paper (appendices C and D) under `paper-code`. 11 | Most experiments in appendix D use an older (but equivalent) implementation in TensorFlow (>=2.1). 12 | The comparison with LieConv in appendix D requires the [LieConv](https://github.com/mfinzi/LieConv) package. 13 | 14 | ### TBA soon: 15 | Examples of usage will be added soon. 
16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Nima Dehmamy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lie Algebra Convolutional Network (L-conv) implementation 2 | __Paper:__ [Automatic Symmetry Discovery with Lie Algebra Convolutional Network](https://papers.nips.cc/paper/2021/file/148148d62be67e0916a833931bd32b26-Paper.pdf) _Nima Dehmamy, Robin Walters, Yanchen Liu, Dashun Wang, Rose Yu_ NeurIPS 2021 3 | (find updated versions on [arXiv](https://arxiv.org/abs/2109.07103)) 4 | 5 | 6 | ## Contents 7 | A simple implementation of the L-conv layer in PyTorch (>=1.8) can be found in `src/lconv.py`. 8 | The L-conv layer acts similarly to a graph convolutional layer (GCN), so prepare your input in a similar fashion (e.g. flatten the spatial dimensions). 9 | The input should have shape `(batch, channels, #nodes)` (e.g. for an image, #nodes = #pixels). 10 | This repository also contains code and notebooks for the experiments in the paper (appendices C and D) under `paper-code`. 11 | Most experiments in appendix D use an older (but equivalent) implementation in TensorFlow (>=2.1). 12 | The comparison with LieConv in appendix D requires the [LieConv](https://github.com/mfinzi/LieConv) package. 13 | 14 | ### TBA soon: 15 | Examples of usage will be added soon. 
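Until the official examples are added, here is a minimal sketch of the input preparation described under Contents: flattening the spatial dimensions of an image batch into the `(batch, channels, #nodes)` layout. It uses NumPy only, so it runs without PyTorch installed; the helper name `to_lconv_input` is hypothetical and is not part of `src/lconv.py`. With a PyTorch tensor, the analogous calls would be `Tensor.permute` and `Tensor.reshape`.

```python
import numpy as np


def to_lconv_input(images):
    """Flatten a channels-last image batch (batch, H, W, C) into (batch, C, H*W).

    Hypothetical helper for illustration; not part of this repository.
    """
    batch, height, width, channels = images.shape
    # Move channels ahead of the spatial axes: (batch, C, H, W).
    channels_first = np.transpose(images, (0, 3, 1, 2))
    # Merge the two spatial axes into one node axis: (batch, C, H*W).
    return channels_first.reshape(batch, channels, height * width)


# A dummy CIFAR-sized batch: 4 images, 32x32 pixels, 3 channels.
images = np.zeros((4, 32, 32, 3), dtype=np.float32)
x = to_lconv_input(images)
print(x.shape)  # (4, 3, 1024)
```

If the images are already stored channels-first as `(batch, C, H, W)`, the transpose step is unnecessary and a single reshape to `(batch, C, H*W)` is enough.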
16 | -------------------------------------------------------------------------------- /paper-code/D-image-experiments/L-conv-extra-exp-2020-11-21.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/", 9 | "height": 34 10 | }, 11 | "colab_type": "code", 12 | "id": "L6Q9A-AvgeCn", 13 | "outputId": "c2755a01-3585-4ab1-e378-a3eb70478363" 14 | }, 15 | "outputs": [ 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "Populating the interactive namespace from numpy and matplotlib\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "%pylab inline\n", 26 | "\n", 27 | "%config InlineBackend.figure_format = 'retina'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import lconv\n", 37 | "import tensorflow as tf\n", 38 | "\n", 39 | "import os\n", 40 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" \n", 41 | "\n", 42 | "sess = tf.compat.v1.InteractiveSession()\n", 43 | "# K = tf.keras.backend\n", 44 | "\n", 45 | "from tensorflow.keras import Model, Sequential\n", 46 | "from tensorflow.keras.layers import Layer, Input, Flatten, Reshape, Dense, Conv2D, MaxPool2D\n", 47 | "\n", 48 | "import pickle as pk\n", 49 | "import json\n", 50 | "# import numpy as np\n", 51 | "\n", 52 | "\n", 53 | "from scipy import ndimage\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "\n", 63 | "def rotated_ims_rand(x):\n", 64 | " return np.float32([ndimage.rotate(i, (np.random.rand()-.5)*180, reshape=False, mode='nearest') for i in x])\n", 65 | " \n", 66 | "class Scramble_x:\n", 67 | " def __init__(self,x):\n", 68 | " s = x.shape[1:-1]\n", 69 | " self.idx = np.argsort(np.random.rand(np.prod(s)))\n", 70 | " r,c = 
self.idx // s[1], self.idx % s[1]\n", 71 | " self.x = np.float32([i[r,c].reshape(s+(x.shape[-1],)) for i in x])\n", 72 | "\n" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 4, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "\n", 82 | "# Defaults \n", 83 | "configs= {\n", 84 | " 'dataset': dict(name= 'cifar100', #'mnist' ,#'mnist',cifar100, \n", 85 | " rotate=True, \n", 86 | " scramble=True,\n", 87 | " ),\n", 88 | " 'net': dict(architecture= 'lconv', # 'cnn', 'fc', \n", 89 | " num_filters=32, \n", 90 | " kernel_size=9, \n", 91 | " L_hid= [8], #[16], \n", 92 | " activation = 'relu',\n", 93 | " L_trainable = True,\n", 94 | " num_layers = 1,\n", 95 | " ),\n", 96 | "}\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "Rotating images\n", 109 | "Scrambling images\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "\n", 115 | "dataset_name = configs['dataset']['name']\n", 116 | "\n", 117 | "dataset = getattr(tf.keras.datasets, dataset_name).load_data()\n", 118 | "(x_train, y_train), (x_test,y_test) = dataset\n", 119 | "if len(x_train.shape) == 3:\n", 120 | " # mnist channel is missing\n", 121 | " x_train = x_train[...,np.newaxis]\n", 122 | " \n", 123 | "# normalize\n", 124 | "x_train = x_train/x_train[:100].max() -.5 \n", 125 | "# make categorical\n", 126 | "y_train = tf.keras.utils.to_categorical(y_train)\n", 127 | "\n", 128 | "\n", 129 | "results = {'configs':configs,}\n", 130 | "\n", 131 | "if configs['dataset']['rotate']:\n", 132 | " print('Rotating images')\n", 133 | " x_train = rotated_ims_rand(x_train)\n", 134 | " \n", 135 | "if configs['dataset']['scramble']:\n", 136 | " print('Scrambling images')\n", 137 | " scr = Scramble_x(x_train)\n", 138 | " x_train = scr.x\n", 139 | " results['scramble_idx']=scr.idx.tolist()\n", 140 | "\n", 141 | "##### Make model 
#####\n", 142 | "\n", 143 | "net = configs['net']\n", 144 | "# arch = net['architecture']\n" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 6, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "Model: \"functional_1\"\n", 157 | "_________________________________________________________________\n", 158 | "Layer (type) Output Shape Param # \n", 159 | "=================================================================\n", 160 | "input_1 (InputLayer) [(None, 32, 32, 3)] 0 \n", 161 | "_________________________________________________________________\n", 162 | "tf_op_layer_Reshape (TensorF [(None, 1024, 3)] 0 \n", 163 | "_________________________________________________________________\n", 164 | "l__conv (L_Conv) (None, 1024, 32) 132224 \n", 165 | "_________________________________________________________________\n", 166 | "flatten (Flatten) (None, 32768) 0 \n", 167 | "_________________________________________________________________\n", 168 | "dense (Dense) (None, 100) 3276900 \n", 169 | "=================================================================\n", 170 | "Total params: 3,409,124\n", 171 | "Trainable params: 3,409,124\n", 172 | "Non-trainable params: 0\n", 173 | "_________________________________________________________________\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "net = configs['net']\n", 179 | "\n", 180 | "inp = Input(x_train[0].shape)\n", 181 | "\n", 182 | "x = inp\n", 183 | "for _ in range(net['num_layers']):\n", 184 | " x = tf.reshape(x, shape=(-1,np.prod(x.shape[1:-1]), x.shape[-1]))\n", 185 | " lay = lconv.L_Conv(num_filters= net['num_filters'], \n", 186 | " kernel_size= net['kernel_size'], \n", 187 | " L_hid = net['L_hid'], \n", 188 | " activation = net['activation'],)\n", 189 | "\n", 190 | " x = lay(x)\n", 191 | " lay.L.trainable = net['L_trainable']\n", 192 | "\n", 193 | "\n", 194 | "x = Flatten()(x)\n", 195 | "\n", 196 | "# x = 
Dense(100, activation = 'relu')(x)\n", 197 | "\n", 198 | "out = Dense(y_train.shape[-1], activation='softmax')(x)\n", 199 | "\n", 200 | "model = Model(inputs = [inp], outputs = [out])\n", 201 | "model.compile(loss = tf.keras.losses.categorical_crossentropy, metrics = ['accuracy'])\n", 202 | "\n", 203 | "model.summary()\n", 204 | "\n" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 7, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "h = model.fit(x_train, y_train, validation_split=0.2, epochs=10)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "# Shallow FC with parameters matching the L-conv model " 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "The shallow FC network has roughly `xs * u + u * n_out` weights (ignoring biases),\n", 228 | "where `xs` is the flattened input size, `u` the number of hidden units and `n_out` the number of classes; `u` is chosen so that this matches the L-conv model's parameter count." 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 8, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "data": { 238 | "text/plain": [ 239 | "1075" 240 | ] 241 | }, 242 | "execution_count": 8, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "xs = np.prod(inp.shape[1:])\n", 249 | "# u = int(0.5+ lay.count_params()/ xs)\n", 250 | "u = int(.5+ model.count_params() / (model.output_shape[-1] + xs ))\n", 251 | "u" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 9, 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "name": "stdout", 261 | "output_type": "stream", 262 | "text": [ 263 | "Model: \"functional_3\"\n", 264 | "_________________________________________________________________\n", 265 | "Layer (type) Output Shape Param # \n", 266 | "=================================================================\n", 267 | "input_2 (InputLayer) [(None, 32, 32, 3)] 0 \n", 268 | "_________________________________________________________________\n", 269 | "flatten_1 (Flatten) (None, 3072) 0 \n", 270 | 
"_________________________________________________________________\n", 271 | "dense_1 (Dense) (None, 1075) 3303475 \n", 272 | "_________________________________________________________________\n", 273 | "dense_2 (Dense) (None, 100) 107600 \n", 274 | "=================================================================\n", 275 | "Total params: 3,411,075\n", 276 | "Trainable params: 3,411,075\n", 277 | "Non-trainable params: 0\n", 278 | "_________________________________________________________________\n" 279 | ] 280 | } 281 | ], 282 | "source": [ 283 | "net = configs['net']\n", 284 | "\n", 285 | "inp = Input(x_train[0].shape)\n", 286 | "\n", 287 | "x = inp\n", 288 | "\n", 289 | "x = Flatten()(inp)\n", 290 | "# FC comparable to L-conv, but no shared weights \n", 291 | "x = Dense(u, activation = net['activation'])(x)\n", 292 | "\n", 293 | "out = Dense(y_train.shape[-1], activation='softmax')(x)\n", 294 | "\n", 295 | "model = Model(inputs = [inp], outputs = [out])\n", 296 | "model.compile(loss = tf.keras.losses.categorical_crossentropy, metrics = ['accuracy'])\n", 297 | "\n", 298 | "model.summary()\n", 299 | "\n" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 10, 305 | "metadata": {}, 306 | "outputs": [ 307 | { 308 | "name": "stdout", 309 | "output_type": "stream", 310 | "text": [ 311 | "Epoch 1/30\n", 312 | "1250/1250 [==============================] - 5s 4ms/step - loss: 4.3505 - accuracy: 0.1206 - val_loss: 4.1713 - val_accuracy: 0.1420\n", 313 | "Epoch 2/30\n", 314 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.9070 - accuracy: 0.1848 - val_loss: 4.2039 - val_accuracy: 0.1650\n", 315 | "Epoch 3/30\n", 316 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.7058 - accuracy: 0.2255 - val_loss: 4.1117 - val_accuracy: 0.1878\n", 317 | "Epoch 4/30\n", 318 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.5451 - accuracy: 0.2544 - val_loss: 4.8140 - val_accuracy: 0.1781\n", 319 | 
"Epoch 5/30\n", 320 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.3958 - accuracy: 0.2869 - val_loss: 4.3992 - val_accuracy: 0.1821\n", 321 | "Epoch 6/30\n", 322 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.2576 - accuracy: 0.3169 - val_loss: 4.5950 - val_accuracy: 0.1988\n", 323 | "Epoch 7/30\n", 324 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.1092 - accuracy: 0.3451 - val_loss: 4.9388 - val_accuracy: 0.1893\n", 325 | "Epoch 8/30\n", 326 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.9819 - accuracy: 0.3800 - val_loss: 5.1262 - val_accuracy: 0.1933\n", 327 | "Epoch 9/30\n", 328 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.8600 - accuracy: 0.4033 - val_loss: 5.4323 - val_accuracy: 0.1927\n", 329 | "Epoch 10/30\n", 330 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.7234 - accuracy: 0.4313 - val_loss: 5.6304 - val_accuracy: 0.1902\n", 331 | "Epoch 11/30\n", 332 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.6202 - accuracy: 0.4598 - val_loss: 5.6795 - val_accuracy: 0.2021\n", 333 | "Epoch 12/30\n", 334 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.4936 - accuracy: 0.4812 - val_loss: 6.2062 - val_accuracy: 0.1952\n", 335 | "Epoch 13/30\n", 336 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.3621 - accuracy: 0.5110 - val_loss: 6.3498 - val_accuracy: 0.2006\n", 337 | "Epoch 14/30\n", 338 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.2828 - accuracy: 0.5314 - val_loss: 6.8185 - val_accuracy: 0.2076\n", 339 | "Epoch 15/30\n", 340 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.1961 - accuracy: 0.5562 - val_loss: 6.9912 - val_accuracy: 0.1992\n", 341 | "Epoch 16/30\n", 342 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.0824 - accuracy: 0.5715 - val_loss: 7.3662 - val_accuracy: 0.2000\n", 343 | 
"Epoch 17/30\n", 344 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.0096 - accuracy: 0.5952 - val_loss: 7.7374 - val_accuracy: 0.1947\n", 345 | "Epoch 18/30\n", 346 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.9498 - accuracy: 0.6051 - val_loss: 8.0833 - val_accuracy: 0.2032\n", 347 | "Epoch 19/30\n", 348 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.8755 - accuracy: 0.6250 - val_loss: 8.6338 - val_accuracy: 0.1993\n", 349 | "Epoch 20/30\n", 350 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.7955 - accuracy: 0.6419 - val_loss: 8.6264 - val_accuracy: 0.2086\n", 351 | "Epoch 21/30\n", 352 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.7687 - accuracy: 0.6562 - val_loss: 8.9263 - val_accuracy: 0.2076\n", 353 | "Epoch 22/30\n", 354 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.6695 - accuracy: 0.6693 - val_loss: 9.6866 - val_accuracy: 0.2030\n", 355 | "Epoch 23/30\n", 356 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.6299 - accuracy: 0.6830 - val_loss: 9.7473 - val_accuracy: 0.2070\n", 357 | "Epoch 24/30\n", 358 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.5655 - accuracy: 0.6931 - val_loss: 10.7333 - val_accuracy: 0.2005\n", 359 | "Epoch 25/30\n", 360 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.5255 - accuracy: 0.7073 - val_loss: 10.8067 - val_accuracy: 0.2012\n", 361 | "Epoch 26/30\n", 362 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.4941 - accuracy: 0.7142 - val_loss: 10.9942 - val_accuracy: 0.2002\n", 363 | "Epoch 27/30\n", 364 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.4608 - accuracy: 0.7237 - val_loss: 11.6645 - val_accuracy: 0.2034\n", 365 | "Epoch 28/30\n", 366 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.3985 - accuracy: 0.7359 - val_loss: 11.6220 - val_accuracy: 
0.2058\n", 367 | "Epoch 29/30\n", 368 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.3767 - accuracy: 0.7433 - val_loss: 12.3072 - val_accuracy: 0.2017\n", 369 | "Epoch 30/30\n", 370 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.3341 - accuracy: 0.7549 - val_loss: 12.7778 - val_accuracy: 0.1996\n" 371 | ] 372 | } 373 | ], 374 | "source": [ 375 | "h = model.fit(x_train, y_train, validation_split=0.2, epochs=30)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 11, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "data": { 385 | "image/png": "\n", 386 | "text/plain": [ 387 | "
" 388 | ] 389 | }, 390 | "metadata": { 391 | "image/png": { 392 | "height": 248, 393 | "width": 373 394 | }, 395 | "needs_background": "light" 396 | }, 397 | "output_type": "display_data" 398 | } 399 | ], 400 | "source": [ 401 | "for k in ['accuracy', 'val_accuracy']:\n", 402 | " plot(h.history[k], )" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 12, 408 | "metadata": {}, 409 | "outputs": [ 410 | { 411 | "name": "stdout", 412 | "output_type": "stream", 413 | "text": [ 414 | "./results-v2/cifar100/FC_shallow-u1075-n_lay-1-act-relu-rotate=False-scramble=False.json\n" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "model_name = f\"FC_shallow-u{u}\" \n", 420 | "\n", 421 | "\n", 422 | "model_name += f\"-n_lay-{net['num_layers']}\"\n", 423 | "model_name += f\"-act-{net['activation']}\"\n", 424 | "num_params = model.count_params() \n", 425 | "out_file_name = f\"./results-v2/{dataset_name}/{model_name}-rotate={bool(configs['dataset']['rotate'])}-scramble={bool(configs['dataset']['scramble'])}.json\"\n", 426 | "\n", 427 | "configs['net']['architecture'] = 'fc_shallow'\n", 428 | "\n", 429 | "results = {}\n", 430 | "results.update({\n", 431 | " 'num_params':num_params,\n", 432 | " 'result':h.history,\n", 433 | " 'configs':configs,\n", 434 | "# 'result': {k: np.float32(v).tolist() for k,v in h.history.items()}, # bug in json or TF2\n", 435 | " })\n", 436 | "\n", 437 | "\n", 438 | "\n", 439 | "# for k,v in results['result'].items():\n", 440 | "# print(k,type(v))\n", 441 | "\n", 442 | "import os\n", 443 | "\n", 444 | "# print(h.history)\n", 445 | "\n", 446 | "dirs = os.path.split(out_file_name)[0]\n", 447 | "os.makedirs(dirs,exist_ok=True)\n", 448 | "\n", 449 | "print(out_file_name)\n", 450 | "\n", 451 | "json.dump(results, open(out_file_name, 'w'))" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "# CNN + Max pool" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 
7, 464 | "metadata": {}, 465 | "outputs": [ 466 | { 467 | "data": { 468 | "text/plain": [ 469 | "{'dataset': {'name': 'cifar100', 'rotate': True, 'scramble': True},\n", 470 | " 'net': {'architecture': 'lconv',\n", 471 | " 'num_filters': 32,\n", 472 | " 'kernel_size': 9,\n", 473 | " 'L_hid': [8],\n", 474 | " 'activation': 'relu',\n", 475 | " 'L_trainable': True,\n", 476 | " 'num_layers': 1}}" 477 | ] 478 | }, 479 | "execution_count": 7, 480 | "metadata": {}, 481 | "output_type": "execute_result" 482 | } 483 | ], 484 | "source": [ 485 | "net['num_layers'] = 1\n", 486 | "configs\n" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 8, 492 | "metadata": {}, 493 | "outputs": [ 494 | { 495 | "name": "stdout", 496 | "output_type": "stream", 497 | "text": [ 498 | "Model: \"functional_3\"\n", 499 | "_________________________________________________________________\n", 500 | "Layer (type) Output Shape Param # \n", 501 | "=================================================================\n", 502 | "input_2 (InputLayer) [(None, 32, 32, 3)] 0 \n", 503 | "_________________________________________________________________\n", 504 | "conv2d (Conv2D) (None, 30, 30, 32) 896 \n", 505 | "_________________________________________________________________\n", 506 | "max_pooling2d (MaxPooling2D) (None, 29, 29, 32) 0 \n", 507 | "_________________________________________________________________\n", 508 | "flatten_1 (Flatten) (None, 26912) 0 \n", 509 | "_________________________________________________________________\n", 510 | "dense_1 (Dense) (None, 100) 2691300 \n", 511 | "=================================================================\n", 512 | "Total params: 2,692,196\n", 513 | "Trainable params: 2,692,196\n", 514 | "Non-trainable params: 0\n", 515 | "_________________________________________________________________\n" 516 | ] 517 | } 518 | ], 519 | "source": [ 520 | "net = configs['net']\n", 521 | "\n", 522 | "kx = int(0.5+np.sqrt(net['kernel_size']))\n", 523 
| "ky = int(0.5+net['kernel_size']/kx)\n", 524 | "kernel_size = (kx,ky)\n", 525 | "\n", 526 | "\n", 527 | "inp = Input(x_train[0].shape)\n", 528 | "\n", 529 | "x = inp\n", 530 | "for _ in range(net['num_layers']):\n", 531 | " x = Conv2D(filters=net['num_filters'], \n", 532 | " kernel_size=kernel_size, \n", 533 | " activation = net['activation'])(x)\n", 534 | "\n", 535 | " x = MaxPool2D(pool_size=(2, 2),strides=(1, 1), padding='valid')(x)\n", 536 | "\n", 537 | "\n", 538 | "x = Flatten()(x)\n", 539 | "\n", 540 | "# x = Dense(100, activation = 'relu')(x)\n", 541 | "\n", 542 | "out = Dense(y_train.shape[-1], activation='softmax')(x)\n", 543 | "\n", 544 | "model = Model(inputs = [inp], outputs = [out])\n", 545 | "model.compile(loss = tf.keras.losses.categorical_crossentropy, metrics = ['accuracy'])\n", 546 | "\n", 547 | "model.summary()\n", 548 | "\n" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": 9, 554 | "metadata": {}, 555 | "outputs": [ 556 | { 557 | "name": "stdout", 558 | "output_type": "stream", 559 | "text": [ 560 | "Epoch 1/30\n", 561 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.9167 - accuracy: 0.1236 - val_loss: 3.7271 - val_accuracy: 0.1473\n", 562 | "Epoch 2/30\n", 563 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.5135 - accuracy: 0.1927 - val_loss: 3.6933 - val_accuracy: 0.1582\n", 564 | "Epoch 3/30\n", 565 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.3421 - accuracy: 0.2239 - val_loss: 3.6467 - val_accuracy: 0.1758\n", 566 | "Epoch 4/30\n", 567 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.2010 - accuracy: 0.2546 - val_loss: 3.6889 - val_accuracy: 0.1663\n", 568 | "Epoch 5/30\n", 569 | "1250/1250 [==============================] - 5s 4ms/step - loss: 3.0754 - accuracy: 0.2845 - val_loss: 3.8055 - val_accuracy: 0.1799\n", 570 | "Epoch 6/30\n", 571 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.9545 - 
accuracy: 0.3071 - val_loss: 3.7576 - val_accuracy: 0.1786\n", 572 | "Epoch 7/30\n", 573 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.8452 - accuracy: 0.3301 - val_loss: 3.8578 - val_accuracy: 0.1732\n", 574 | "Epoch 8/30\n", 575 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.7313 - accuracy: 0.3489 - val_loss: 3.8588 - val_accuracy: 0.1688\n", 576 | "Epoch 9/30\n", 577 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.6140 - accuracy: 0.3762 - val_loss: 3.9497 - val_accuracy: 0.1739\n", 578 | "Epoch 10/30\n", 579 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.5057 - accuracy: 0.4000 - val_loss: 3.9963 - val_accuracy: 0.1631\n", 580 | "Epoch 11/30\n", 581 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.3924 - accuracy: 0.4230 - val_loss: 4.0768 - val_accuracy: 0.1632\n", 582 | "Epoch 12/30\n", 583 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.2901 - accuracy: 0.4454 - val_loss: 4.1800 - val_accuracy: 0.1581\n", 584 | "Epoch 13/30\n", 585 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.1808 - accuracy: 0.4697 - val_loss: 4.2969 - val_accuracy: 0.1600\n", 586 | "Epoch 14/30\n", 587 | "1250/1250 [==============================] - 5s 4ms/step - loss: 2.0804 - accuracy: 0.4894 - val_loss: 4.4049 - val_accuracy: 0.1651\n", 588 | "Epoch 15/30\n", 589 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.9885 - accuracy: 0.5112 - val_loss: 4.4964 - val_accuracy: 0.1609\n", 590 | "Epoch 16/30\n", 591 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.8966 - accuracy: 0.5336 - val_loss: 4.6298 - val_accuracy: 0.1590\n", 592 | "Epoch 17/30\n", 593 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.8040 - accuracy: 0.5528 - val_loss: 4.6630 - val_accuracy: 0.1627\n", 594 | "Epoch 18/30\n", 595 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.7130 
- accuracy: 0.5712 - val_loss: 4.8223 - val_accuracy: 0.1548\n", 596 | "Epoch 19/30\n", 597 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.6312 - accuracy: 0.5905 - val_loss: 4.9647 - val_accuracy: 0.1609\n", 598 | "Epoch 20/30\n", 599 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.5499 - accuracy: 0.6095 - val_loss: 5.1244 - val_accuracy: 0.1548\n", 600 | "Epoch 21/30\n", 601 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.4663 - accuracy: 0.6308 - val_loss: 5.3002 - val_accuracy: 0.1550\n", 602 | "Epoch 22/30\n", 603 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.3956 - accuracy: 0.6469 - val_loss: 5.4341 - val_accuracy: 0.1556\n", 604 | "Epoch 23/30\n", 605 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.3252 - accuracy: 0.6654 - val_loss: 5.7170 - val_accuracy: 0.1451\n", 606 | "Epoch 24/30\n", 607 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.2559 - accuracy: 0.6827 - val_loss: 5.7569 - val_accuracy: 0.1495\n", 608 | "Epoch 25/30\n", 609 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.1830 - accuracy: 0.6993 - val_loss: 5.9255 - val_accuracy: 0.1458\n", 610 | "Epoch 26/30\n", 611 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.1174 - accuracy: 0.7147 - val_loss: 6.1386 - val_accuracy: 0.1498\n", 612 | "Epoch 27/30\n", 613 | "1250/1250 [==============================] - 5s 4ms/step - loss: 1.0562 - accuracy: 0.7307 - val_loss: 6.3020 - val_accuracy: 0.1465\n", 614 | "Epoch 28/30\n", 615 | "1250/1250 [==============================] - 5s 4ms/step - loss: 0.9958 - accuracy: 0.7465 - val_loss: 6.4740 - val_accuracy: 0.1424\n", 616 | "Epoch 29/30\n", 617 | "1250/1250 [==============================] - 5s 4ms/step - loss: 0.9368 - accuracy: 0.7599 - val_loss: 6.6097 - val_accuracy: 0.1445\n", 618 | "Epoch 30/30\n", 619 | "1250/1250 [==============================] - 5s 4ms/step - loss: 
0.8849 - accuracy: 0.7746 - val_loss: 6.8493 - val_accuracy: 0.1477\n" 620 | ] 621 | } 622 | ], 623 | "source": [ 624 | "h = model.fit(x_train, y_train, validation_split=0.2, epochs=30)" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 10, 630 | "metadata": {}, 631 | "outputs": [ 632 | { 633 | "data": { 634 | "image/png": "\n", 635 | "text/plain": [ 636 | "
" 637 | ] 638 | }, 639 | "metadata": { 640 | "image/png": { 641 | "height": 249, 642 | "width": 373 643 | }, 644 | "needs_background": "light" 645 | }, 646 | "output_type": "display_data" 647 | } 648 | ], 649 | "source": [ 650 | "for k in ['accuracy', 'val_accuracy']:\n", 651 | " plot(h.history[k], )" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 11, 657 | "metadata": {}, 658 | "outputs": [ 659 | { 660 | "name": "stdout", 661 | "output_type": "stream", 662 | "text": [ 663 | "Model: \"functional_3\"\n", 664 | "_________________________________________________________________\n", 665 | "Layer (type) Output Shape Param # \n", 666 | "=================================================================\n", 667 | "input_2 (InputLayer) [(None, 32, 32, 3)] 0 \n", 668 | "_________________________________________________________________\n", 669 | "conv2d (Conv2D) (None, 30, 30, 32) 896 \n", 670 | "_________________________________________________________________\n", 671 | "max_pooling2d (MaxPooling2D) (None, 29, 29, 32) 0 \n", 672 | "_________________________________________________________________\n", 673 | "flatten_1 (Flatten) (None, 26912) 0 \n", 674 | "_________________________________________________________________\n", 675 | "dense_1 (Dense) (None, 100) 2691300 \n", 676 | "=================================================================\n", 677 | "Total params: 2,692,196\n", 678 | "Trainable params: 2,692,196\n", 679 | "Non-trainable params: 0\n", 680 | "_________________________________________________________________\n" 681 | ] 682 | } 683 | ], 684 | "source": [ 685 | "model.summary()" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": null, 691 | "metadata": {}, 692 | "outputs": [], 693 | "source": [] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 12, 698 | "metadata": {}, 699 | "outputs": [ 700 | { 701 | "name": "stdout", 702 | "output_type": "stream", 703 | "text": [ 704 | 
"./results-v2/cifar100/CNN_maxpool-nf32-ker(3, 3)-n_lay-1-act-relu-rotate=True-scramble=True.json\n" 705 | ] 706 | } 707 | ], 708 | "source": [ 709 | "configs['net']['architecture'] = 'cnn_maxpool'\n", 710 | "\n", 711 | "model_name = f\"CNN_maxpool-nf{net['num_filters']}-ker{kernel_size}\"\n", 712 | "\n", 713 | "model_name += f\"-n_lay-{net['num_layers']}\"\n", 714 | "model_name += f\"-act-{net['activation']}\"\n", 715 | "num_params = model.count_params() \n", 716 | "\n", 717 | "out_file_name = f\"./results-v2/{dataset_name}/{model_name}-rotate={bool(configs['dataset']['rotate'])}-scramble={bool(configs['dataset']['scramble'])}.json\"\n", 718 | "\n", 719 | "configs['net']['architecture'] = 'cnn_maxpool'\n", 720 | "\n", 721 | "results = {}\n", 722 | "results.update({\n", 723 | " 'num_params':num_params,\n", 724 | " 'result':h.history,\n", 725 | " 'configs':configs,\n", 726 | "# 'result': {k: np.float32(v).tolist() for k,v in h.history.items()}, # bug in json or TF2\n", 727 | " })\n", 728 | "\n", 729 | "\n", 730 | "\n", 731 | "# for k,v in results['result'].items():\n", 732 | "# print(k,type(v))\n", 733 | "\n", 734 | "import os\n", 735 | "\n", 736 | "# print(h.history)\n", 737 | "\n", 738 | "dirs = os.path.split(out_file_name)[0]\n", 739 | "os.makedirs(dirs,exist_ok=True)\n", 740 | "\n", 741 | "print(out_file_name)\n", 742 | "\n", 743 | "\n", 744 | "json.dump(results, open(out_file_name, 'w'))" 745 | ] 746 | } 747 | ], 748 | "metadata": { 749 | "kernelspec": { 750 | "display_name": "TF2.0 Py3", 751 | "language": "python", 752 | "name": "tf2" 753 | }, 754 | "language_info": { 755 | "codemirror_mode": { 756 | "name": "ipython", 757 | "version": 3 758 | }, 759 | "file_extension": ".py", 760 | "mimetype": "text/x-python", 761 | "name": "python", 762 | "nbconvert_exporter": "python", 763 | "pygments_lexer": "ipython3", 764 | "version": "3.7.5" 765 | } 766 | }, 767 | "nbformat": 4, 768 | "nbformat_minor": 4 769 | } 770 | 
-------------------------------------------------------------------------------- /paper-code/D-image-experiments/lconv.py: -------------------------------------------------------------------------------- 1 | from tensorflow import reduce_sum, concat, reduce_max 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras.activations import deserialize 5 | 6 | from numpy import newaxis,prod 7 | 8 | 9 | class L_module(Layer): 10 | def __init__(self, n_L, out_dim = None, hidden_units = [], activation = 'linear', **kws): 11 | super(L_module, self).__init__(**kws) 12 | self.params = dict(n_L = n_L, activation = activation, hidden_units = hidden_units) 13 | self.out_dim = out_dim 14 | 15 | 16 | def build(self, input_shape): 17 | # print(input_shape) 18 | d = prod(input_shape[1:]) 19 | 20 | space_dim, feature_dim = input_shape[-2:] 21 | out_dim = self.out_dim or space_dim #int(space_dim/self.stride) 22 | 23 | n_L = self.params['n_L'] 24 | hidden_units = self.params['hidden_units'] 25 | 26 | self.hidden_layers = [] 27 | in_dim = space_dim 28 | for i,u in enumerate(hidden_units): 29 | self.hidden_layers += [self.add_weight(shape=(n_L, u, in_dim), 30 | initializer='glorot_normal', trainable=True, name='Lh_%d' %i)] 31 | in_dim = u 32 | 33 | self.L = self.add_weight(shape=(n_L, out_dim, in_dim), initializer='glorot_normal', trainable=True, name='L') 34 | 35 | def call(self, inputs): 36 | n_L = self.params['n_L'] 37 | 38 | x = inputs 39 | 40 | for l in self.hidden_layers: 41 | x = l @ x 42 | x = self.L @ x 43 | act = deserialize(self.params['activation']) 44 | return act(x) 45 | 46 | class L_Conv(Model): 47 | def __init__(self, 48 | num_filters: int, 49 | kernel_size: int, 50 | #stride: int = 1, 51 | activation = 'relu', 52 | L_hid =[], L_act = 'linear', 53 | ): 54 | """Assumes channel last input x: (batch, space, features). 55 | Uses stride to scale space dimension: out_dim = int(space/stride). 
56 | call: L @ (x @ W + b) 57 | """ 58 | super(L_Conv, self).__init__() 59 | self.num_filters = num_filters 60 | self.kernel_size = kernel_size 61 | self.stride = 1 # stride 62 | self.activation = activation 63 | self.L_params = dict(n_L = kernel_size-1, hidden_units = L_hid, activation = L_act) 64 | 65 | def get_L(self, input_shape): 66 | # assume channel last 67 | space_dim, feature_dim = input_shape[-2:] 68 | out_dim = int(space_dim/self.stride) 69 | 70 | # num_L = kernel-1 b/c original input will be concat 71 | # L = self.add_weight(shape=(self.kernel_size - 1, out_dim, space_dim), 72 | # initializer='glorot_normal', 73 | # trainable=True, name='L') 74 | L = L_module(out_dim = out_dim, **self.L_params) 75 | return L 76 | 77 | def build(self, input_shape): 78 | self.L = self.get_L(input_shape) 79 | self.w = self.add_weight(shape=(self.kernel_size, input_shape[-1], self.num_filters), 80 | initializer='glorot_normal', 81 | trainable=True, name = 'w') 82 | self.b = self.add_weight(shape=(self.kernel_size,1, self.num_filters), 83 | initializer='zeros', 84 | trainable=True, name = 'b') 85 | self.activation_layer = deserialize(self.activation) 86 | 87 | def call(self, inputs): 88 | x0 = inputs[:,newaxis] 89 | 90 | # (batch, space, features) --> (batch, 1, space, features) 91 | #x = self.L @ x0 92 | x = self.L(x0) 93 | # (batch, 1, space, features) --> (batch, kernel_size, space/stride, features) 94 | 95 | x = concat([x, x0], axis = 1) # add back the original 96 | 97 | x = x @ self.w + self.b 98 | # (batch, kernel_size, space, features) --> (batch, kernel_size, space, num_filters) 99 | 100 | x = reduce_sum(x, axis = 1) 101 | 102 | return self.activation_layer(x) 103 | 104 | class L_Conv_max(L_Conv): 105 | def __init__(self, stride: int = 1, **kws): 106 | """Does max_i(L_i L_j x) 107 | """ 108 | super(L_Conv_max, self).__init__(**kws) 109 | 110 | def call(self, inputs): 111 | x0 = inputs[:,newaxis] 112 | 113 | #print(x0.shape) 114 | # (batch, space, features) --> 
(batch, 1, space, features) 115 | #x = self.L @ x0 116 | x = self.L(x0) 117 | # (batch, 1, space, features) --> (batch, kernel_size, space, features) 118 | 119 | #print(x.shape) 120 | x = concat([x, x0], axis = 1) # add back the original 121 | 122 | #print(x.shape) 123 | # apply L again 124 | x1 = self.L(x[:,:,newaxis]) 125 | # (batch, kernel_size, space, features) --> (batch, kernel_size, kernel_size, space, features) 126 | 127 | #print(x1.shape, x.shape) 128 | x = concat([x1, x[:,:,newaxis]], axis = 2) # add back the original 129 | 130 | x = x @ self.w + self.b 131 | # (batch, kernel_size, kernel_size, space, features) --> (batch, kernel_size, kernel_size, space, num_filters) 132 | 133 | x = reduce_sum(x, axis = 1) 134 | # (batch, kernel_size, space, num_filters) 135 | 136 | # max pooling 137 | x = reduce_max(x, axis = 1) 138 | # (batch, space, num_filters) 139 | 140 | return self.activation_layer(x) 141 | 142 | 143 | class L_Conv_strided(L_Conv): 144 | def __init__(self, stride: int = 1, **kws): 145 | """Assumes channel last input x: (batch, space, features). 146 | Uses stride to scale space dimension: out_dim = int(space/stride). 147 | call: L @ (x @ W + b) 148 | """ 149 | super(L_Conv_strided, self).__init__(**kws) 150 | self.stride = stride 151 | self.L_params['n_L'] += 1 # to make up for lack of residual conn. 
152 | 153 | def call(self, inputs): 154 | x0 = inputs[:,newaxis] 155 | 156 | # (batch, space, features) --> (batch, 1, space, features) 157 | #x = self.L @ x0 158 | x = self.L(x0) 159 | # (batch, 1, space, features) --> (batch, kernel_size, space/stride, features) 160 | 161 | # x = tf.concat([x, x0], axis = 1) # add back the original 162 | 163 | x = x @ self.w + self.b 164 | # (batch, kernel_size, space, features) --> (batch, kernel_size, space, num_filters) 165 | 166 | x = reduce_sum(x, axis = 1) 167 | 168 | return self.activation_layer(x) -------------------------------------------------------------------------------- /paper-code/D-image-experiments/run_LieConv_cifar100.py: -------------------------------------------------------------------------------- 1 | """LieConv Baseline experiments. 2 | requires: 3 | https://github.com/mfinzi/LieConv 4 | 5 | Usage: 6 | $ python3 run_LieConv_cifar100.py --epochs 100 --nlay 2 --ker 256 --lr 3e-3 --bn 0 --rot 1 --scr 1 7 | 8 | --rot: rotate images 9 | --scr: scramble images (fixed shuffling of pixels in all images) 10 | --nlay: number of layers 11 | --bn: batchnorm (0: no batchnorm; used in baseline experiments) 12 | 13 | """ 14 | 15 | 16 | import torch 17 | import torchvision.transforms as transforms 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from torch.utils.data import DataLoader 22 | from oil.utils.utils import LoaderTo, cosLr, islice 23 | from oil.tuning.study import train_trial 24 | from oil.datasetup.datasets import split_dataset 25 | from oil.utils.parallel import try_multigpu_parallelize 26 | from oil.model_trainers.classifier import Classifier 27 | from functools import partial 28 | from torch.optim import Adam 29 | from oil.tuning.args import argupdated_config 30 | import copy 31 | import lie_conv.lieGroups as lieGroups 32 | import lie_conv.lieConv as lieConv 33 | from lie_conv.lieConv import ImgLieResnet 34 | # from lie_conv.datasets import MnistRotDataset, RotMNIST 35 | from 
oil.datasetup.datasets import EasyIMGDataset 36 | 37 | from lie_conv.utils import Named, export, Expression, FixedNumpySeed, RandomZrotation, GaussianNoise 38 | from lie_conv.utils import Named 39 | import numpy as np 40 | from PIL import Image 41 | from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \ 42 | verify_str_arg 43 | from torchvision.datasets.vision import VisionDataset 44 | import torchvision 45 | 46 | class RotCIFAR10(EasyIMGDataset,torchvision.datasets.CIFAR10): 47 | # """ Unofficial RotMNIST dataset created on the fly by rotating MNIST""" 48 | means = (0.5,) 49 | stds = (0.25,) 50 | num_targets = 10 51 | def __init__(self,*args,dataseed=0,transform=None,**kwargs): 52 | super().__init__(*args,download=True,**kwargs) 53 | N = len(self) 54 | with FixedNumpySeed(dataseed): 55 | angles = torch.rand(N)*2*np.pi 56 | with torch.no_grad(): 57 | # R = torch.zeros(N,2,2) 58 | # R[:,0,0] = R[:,1,1] = angles.cos() 59 | # R[:,0,1] = R[:,1,0] = angles.sin() 60 | # R[:,1,0] *=-1 61 | # Build affine matrices for random translation of each image 62 | affineMatrices = torch.zeros(N,2,3) 63 | affineMatrices[:,0,0] = angles.cos() 64 | affineMatrices[:,1,1] = angles.cos() 65 | affineMatrices[:,0,1] = angles.sin() 66 | affineMatrices[:,1,0] = -angles.sin() 67 | # affineMatrices[:,0,2] = -2*np.random.randint(-self.max_trans, self.max_trans+1, bs)/w 68 | # affineMatrices[:,1,2] = 2*np.random.randint(-self.max_trans, self.max_trans+1, bs)/h 69 | # self.data = self.data.unsqueeze(1).float() 70 | self.data = torch.as_tensor(self.data.transpose((0,3,1,2))).float() 71 | flowgrid = F.affine_grid(affineMatrices, size = self.data.size()) 72 | self.data = F.grid_sample(self.data, flowgrid) 73 | normalize = transforms.Normalize((127.5,) ,(255,)) 74 | self.data = normalize(self.data) 75 | def __getitem__(self,idx): 76 | return self.data[idx], int(self.targets[idx]) 77 | def default_aug_layers(self): 78 | return RandomRotateTranslate(0)# no 
translation 79 | 80 | 81 | class CIFAR100(EasyIMGDataset,torchvision.datasets.CIFAR100): 82 | # """ Unofficial RotMNIST dataset created on the fly by rotating MNIST""" 83 | means = (0.5,) 84 | stds = (0.25,) 85 | num_targets = 100 86 | def __init__(self,*args,dataseed=0,transform=None,**kwargs): 87 | super().__init__(*args,download=True,**kwargs) 88 | N = len(self) 89 | with torch.no_grad(): 90 | self.data = torch.as_tensor(self.data.transpose((0,3,1,2))).float() 91 | normalize = transforms.Normalize((127.5,) ,(255,)) 92 | self.data = normalize(self.data) 93 | def __getitem__(self,idx): 94 | return self.data[idx], int(self.targets[idx]) 95 | def default_aug_layers(self): 96 | return RandomRotateTranslate(0)# no translation 97 | 98 | 99 | class RotCIFAR100(EasyIMGDataset,torchvision.datasets.CIFAR100): 100 | # """ Unofficial RotMNIST dataset created on the fly by rotating MNIST""" 101 | means = (0.5,) 102 | stds = (0.25,) 103 | num_targets = 100 104 | def __init__(self,*args,dataseed=0,transform=None,**kwargs): 105 | super().__init__(*args,download=True,**kwargs) 106 | N = len(self) 107 | with FixedNumpySeed(dataseed): 108 | angles = torch.rand(N)*2*np.pi 109 | with torch.no_grad(): 110 | # R = torch.zeros(N,2,2) 111 | # R[:,0,0] = R[:,1,1] = angles.cos() 112 | # R[:,0,1] = R[:,1,0] = angles.sin() 113 | # R[:,1,0] *=-1 114 | # Build affine matrices for random translation of each image 115 | affineMatrices = torch.zeros(N,2,3) 116 | affineMatrices[:,0,0] = angles.cos() 117 | affineMatrices[:,1,1] = angles.cos() 118 | affineMatrices[:,0,1] = angles.sin() 119 | affineMatrices[:,1,0] = -angles.sin() 120 | # affineMatrices[:,0,2] = -2*np.random.randint(-self.max_trans, self.max_trans+1, bs)/w 121 | # affineMatrices[:,1,2] = 2*np.random.randint(-self.max_trans, self.max_trans+1, bs)/h 122 | # self.data = self.data.unsqueeze(1).float() 123 | self.data = torch.as_tensor(self.data.transpose((0,3,1,2))).float() 124 | flowgrid = F.affine_grid(affineMatrices, size = 
self.data.size()) 125 | self.data = F.grid_sample(self.data, flowgrid) 126 | normalize = transforms.Normalize((127.5,) ,(255,)) 127 | self.data = normalize(self.data) 128 | def __getitem__(self,idx): 129 | return self.data[idx], int(self.targets[idx]) 130 | def default_aug_layers(self): 131 | return RandomRotateTranslate(0)# no translation 132 | 133 | class RotScramCIFAR100(RotCIFAR100): 134 | """ Scrambled""" 135 | def __init__(self,*args,**kwargs): 136 | super().__init__(*args,**kwargs) 137 | with torch.no_grad(): 138 | idx = torch.randperm(self.data[0,0].nelement()) 139 | self.data = self.data.view(*self.data.shape[:2], -1)[:,:,idx].view(self.data.size()) 140 | 141 | def makeTrainer(*, dataset=RotCIFAR100, network=ImgLieResnet, num_epochs=100, 142 | bs=50, lr=3e-3, aug=False,#True, 143 | optim=Adam, device='cuda', trainer=Classifier, 144 | split={'train':40000}, small_test=False, net_config={}, opt_config={}, 145 | trainer_config={'log_dir':None}): 146 | 147 | # Prep the datasets splits, model, and dataloaders 148 | datasets = split_dataset(dataset(f'~/datasets/{dataset}/'),splits=split) 149 | datasets['test'] = dataset(f'~/datasets/{dataset}/', train=False) 150 | device = torch.device(device) 151 | model = network(num_targets=datasets['train'].num_targets,**net_config).to(device) 152 | if aug: model = torch.nn.Sequential(datasets['train'].default_aug_layers(),model) 153 | model,bs = try_multigpu_parallelize(model,bs) 154 | 155 | dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'), 156 | num_workers=0,pin_memory=False),device) for k,v in datasets.items()} 157 | dataloaders['Train'] = islice(dataloaders['train'],1+len(dataloaders['train'])//10) 158 | if small_test: dataloaders['test'] = islice(dataloaders['test'],1+len(dataloaders['train'])//10) 159 | # Add some extra defaults if SGD is chosen 160 | opt_constr = partial(optim, lr=lr, **opt_config) 161 | lr_sched = cosLr(num_epochs) 162 | return 
trainer(model,dataloaders,opt_constr,lr_sched,**trainer_config) 163 | 164 | 165 | 166 | import argparse 167 | parser = argparse.ArgumentParser(description='LieConv Tests') 168 | parser.add_argument('--rot', type=int, default=1, metavar='N', 169 | help='rotated CIFAR100 (default: True)') 170 | parser.add_argument('--scr', type=int, default=0, metavar='N', 171 | help='scramble (default: False)') 172 | parser.add_argument('--ker', type=int, default=128, metavar='N', 173 | help='k in LieConv layer (default: 128)') 174 | parser.add_argument('--nlay', type=int, default=2, metavar='N', 175 | help='number of layers (default: 2)') 176 | parser.add_argument('--epochs', type=int, default=40, metavar='N', 177 | help='number of epochs to train (default: 40)') 178 | parser.add_argument('--lr', type=float, default=3e-3, metavar='N', 179 | help='learning rate (default: 3e-3)') 180 | parser.add_argument('--bn', type=int, default=1, metavar='N', 181 | help='batch normalization (default: True)') 182 | 183 | 184 | 185 | args = parser.parse_args() 186 | 187 | SCRAMBLE = args.scr 188 | ROTATE = args.rot 189 | 190 | ker = args.ker 191 | nlay = args.nlay 192 | batchnorm = bool(args.bn) 193 | EPOCHS = args.epochs 194 | 195 | 196 | 197 | 198 | Trial = train_trial(makeTrainer) 199 | defaults = copy.deepcopy(makeTrainer.__kwdefaults__) 200 | 201 | if SCRAMBLE: # checked first so that --rot 1 --scr 1 selects the rotated and scrambled set 202 | defaults['dataset'] = RotScramCIFAR100 203 | elif ROTATE: 204 | defaults['dataset'] = RotCIFAR100 205 | else: 206 | print("=============\n\n Using default CIFAR100\n\n=============") 207 | defaults['dataset'] = CIFAR100 208 | 209 | 210 | defaults['net_config'] = dict(chin=3, 211 | num_layers=nlay, 212 | k=ker, 213 | bn= batchnorm 214 | ) 215 | defaults['num_epochs'] = EPOCHS 216 | defaults['lr'] = args.lr 217 | 218 | print(defaults) 219 | fnam = f'./results/lie_conv-cifar100{"-rot" if ROTATE else ""}{"-scr" if SCRAMBLE else ""}-lay{nlay}-k{ker}.pkl' 220 | 221 | print('\n', fnam,'\n') 222 | 223 | results = 
Trial(defaults) 224 | 225 | net = ImgLieResnet(**defaults['net_config']) 226 | param_size_list = [(name,tuple(param.shape)) for name, param in net.net.named_parameters() if param.requires_grad] 227 | 228 | out = dict( net_configs = results[0], 229 | results = results[1].to_dict(), 230 | params = param_size_list, 231 | total_params = sum([np.prod(i[1]) for i in param_size_list]) 232 | ) 233 | 234 | print('# params: ', out['total_params']) 235 | import pickle 236 | pickle.dump(out, open(fnam, 'wb')) 237 | 238 | 239 | 240 | # if __name__=="__main__": 241 | # Trial = train_trial(makeTrainer) 242 | # defaults = copy.deepcopy(makeTrainer.__kwdefaults__) 243 | # defaults['save'] = False 244 | # Trial(argupdated_config(defaults,namespace=(lieConv,lieGroups))) 245 | -------------------------------------------------------------------------------- /paper-code/D-image-experiments/run_test-v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import lconv 4 | import tensorflow as tf 5 | 6 | sess = tf.compat.v1.InteractiveSession() 7 | # K = tf.keras.backend 8 | 9 | from tensorflow.keras import Model, Sequential 10 | from tensorflow.keras.layers import Layer, Input, Flatten, Reshape, Dense, Conv2D, MaxPool2D 11 | 12 | import pickle as pk 13 | import json 14 | import numpy as np 15 | 16 | 17 | from scipy import ndimage 18 | 19 | def rotated_ims_rand(x): 20 | return np.float32([ndimage.rotate(i, (np.random.rand()-.5)*180, reshape=False, mode='nearest') for i in x]) 21 | 22 | class Scramble_x: 23 | def __init__(self,x): 24 | s = x.shape[1:-1] 25 | self.idx = np.argsort(np.random.rand(np.prod(s))) 26 | r,c = np.int0(self.idx/s[0]), (self.idx % s[1]) 27 | self.x = np.float32([i[r,c].reshape(s+(x.shape[-1],)) for i in x]) 28 | 29 | 30 | # Defaults 31 | configs= { 32 | 'dataset': dict(name='mnist' ,#'mnist',cifar100, 33 | rotate=False, 34 | scramble=False,), 35 | 'net': dict(architecture= 'lconv', # 'cnn', 'fc', 36 | 
num_filters=32, 37 | kernel_size=9, 38 | L_hid= [8], #[16], 39 | activation = 'relu', 40 | L_trainable = True, 41 | num_layers = 1, 42 | ), 43 | } 44 | 45 | 46 | import argparse 47 | 48 | parser = argparse.ArgumentParser(description='Run experiments using L-conv or baselines.') 49 | # Add the arguments 50 | # parser.add_argument('dataset_name',#action='store', 51 | # # metavar='--data', 52 | # type=str, #required=True, 53 | # help='mnist, cifar10, cifar100') 54 | 55 | # Add the arguments 56 | # parser.add_argument('config_json_file', #action='store', 57 | # # metavar='--config', 58 | # type=str, #required=True, 59 | # help='json file of model configuration') 60 | parser.add_argument('--architecture', action='store', type=str, help="lconv (default), cnn, fc") 61 | parser.add_argument('--dataset', action='store', type=str, help='mnist (default), cifar10, cifar100, fashion_mnist') 62 | parser.add_argument('--rotate', action='store_const', const=True) 63 | parser.add_argument('--scramble', action='store_const', const=True) 64 | parser.add_argument('--lrand', action='store_true', help='Whether L are random or trainable') 65 | parser.add_argument('--epochs', action='store', type=int) 66 | parser.add_argument('--num_layers', action='store', type=int, help='Number of layers before classification layer') 67 | parser.add_argument('--hid', action='store', type=int, help='# hidden units in L for low-rank encoding') 68 | 69 | # parser.add_argument('--test', action='store', type=bool, required=False) 70 | 71 | args = parser.parse_args() 72 | # print(args.architecture) 73 | # exit() 74 | 75 | # print(args.dataset_name) 76 | # print(args.config_json_file) 77 | # configs = json.load(open(args.config_json_file,'r')) 78 | 79 | if args.architecture: 80 | configs['net']['architecture'] = args.architecture 81 | 82 | if args.dataset: 83 | configs['dataset']['name'] = args.dataset 84 | 85 | if args.rotate!=None: 86 | configs['dataset']['rotate'] = args.rotate 87 | 88 | if 
args.scramble!=None: 89 | configs['dataset']['scramble'] = args.scramble 90 | 91 | 92 | print(args.lrand) 93 | 94 | if args.lrand: 95 | configs['net']['L_trainable'] = False 96 | 97 | if args.hid: 98 | configs['net']['L_hid'] = [args.hid] 99 | 100 | if args.num_layers: 101 | configs['net']['num_layers'] = args.num_layers 102 | 103 | EPOCHS = args.epochs or 30 104 | 105 | print(configs) 106 | 107 | # exit() 108 | 109 | # cf: conv filters, ker: kernel size, d: dense (FC) units 110 | # make_model_name = lambda cf,ker, d: ''.join([('c%s(k%d)' %(cf,ker) if len(cf) else ''),('d%s'%d if len(d) else '')]) or 'base' 111 | # cf: conv filters, ker: kernel size, d: dense (FC) units 112 | 113 | 114 | 115 | dataset_name = configs['dataset']['name'] 116 | 117 | dataset = eval("tf.keras.datasets.%s.load_data()" %dataset_name) 118 | (x_train, y_train), (x_test,y_test) = dataset 119 | if len(x_train.shape) == 3: 120 | # mnist channel is missing 121 | x_train = x_train[...,np.newaxis] 122 | 123 | # normalize 124 | x_train = x_train/x_train[:100].max() -.5 125 | # make categorical 126 | y_train = tf.keras.utils.to_categorical(y_train) 127 | 128 | 129 | results = {'configs':configs,} 130 | 131 | if configs['dataset']['rotate']: 132 | print('Rotating images') 133 | x_train = rotated_ims_rand(x_train) 134 | 135 | if configs['dataset']['scramble']: 136 | print('Scrambling images') 137 | scr = Scramble_x(x_train) 138 | x_train = scr.x 139 | results['scramble_idx']=scr.idx.tolist() 140 | 141 | ##### Make model ##### 142 | 143 | net = configs['net'] 144 | # arch = net['architecture'] 145 | 146 | kernel_size = net['kernel_size'] 147 | # L_hid = net['L_hid'] 148 | # activation = net['activation'] 149 | # L_trainable = net['L_trainable'] 150 | 151 | 152 | inp = Input(x_train[0].shape) 153 | 154 | x = inp 155 | for _ in range(net['num_layers']): 156 | if net['architecture']=='lconv': 157 | x = tf.reshape(x, shape=(-1,np.prod(x.shape[1:-1]), x.shape[-1])) 158 | lay = lconv.L_Conv(num_filters= 
net['num_filters'], 159 | kernel_size= net['kernel_size'], 160 | L_hid = net['L_hid'], 161 | activation = net['activation'],) 162 | 163 | x = lay(x) 164 | lay.L.trainable = net['L_trainable'] 165 | 166 | elif net['architecture']=='cnn': 167 | kx = int(0.5+np.sqrt(net['kernel_size'])) 168 | ky = int(0.5+net['kernel_size']/kx) 169 | kernel_size = (kx,ky) 170 | cnn = Conv2D(filters=net['num_filters'], 171 | kernel_size=kernel_size, 172 | activation = net['activation']) 173 | x = cnn(x) # feed the previous layer's output so that stacked layers chain 174 | 175 | elif net['architecture']=='fc': 176 | k = net['kernel_size'] 177 | nf = net['num_filters'] 178 | hid = net['L_hid'][0] 179 | act = net['activation'] 180 | xs = np.prod(inp.shape[1:-1]) 181 | 182 | x = Flatten()(x) 183 | # FC comparable to L-conv, but no shared weights 184 | x = Dense(k*hid, activation = act)(x) 185 | x = Dense(xs*nf, activation = act)(x) 186 | 187 | 188 | if net['architecture']!='fc': 189 | x = Flatten()(x) 190 | 191 | # x = Dense(100, activation = 'relu')(x) 192 | 193 | out = Dense(y_train.shape[-1], activation='softmax')(x) 194 | 195 | model = Model(inputs = [inp], outputs = [out]) 196 | model.compile(loss = tf.keras.losses.categorical_crossentropy, metrics = ['accuracy']) 197 | 198 | model.summary() 199 | 200 | 201 | # exit() 202 | 203 | ##### Train model 204 | 205 | h = model.fit(x_train, y_train, validation_split=0.2, epochs=EPOCHS) 206 | 207 | ##### model name and results 208 | non_trainable = 0 209 | if net['architecture'] == 'lconv': 210 | model_name = f"L-conv-nf{net['num_filters']}-hid{net['L_hid']}-L_trainable{net['L_trainable']}-ker{net['kernel_size']}" 211 | non_trainable = (0 if net['L_trainable'] else lay.L.count_params()) 212 | elif net['architecture'] == 'cnn': 213 | model_name = f"CNN-nf{net['num_filters']}-ker{kernel_size}" 214 | elif net['architecture'] =='fc': 215 | model_name = f"FC-nf{net['num_filters']}-hid{net['L_hid']}-ker{kernel_size}" 216 | 217 | 218 | model_name += f"-n_lay-{net['num_layers']}" 219 | model_name += 
f"-act-{net['activation']}" 220 | num_params = model.count_params() - non_trainable 221 | 222 | out_file_name = f"./results-v2/{dataset_name}/{model_name}-rotate={configs['dataset']['rotate']}-scramble={configs['dataset']['scramble']}.json" 223 | 224 | results = {} 225 | results.update({ 226 | 'num_params':num_params, 227 | 'result':h.history, 228 | 'configs':configs, 229 | # 'result': {k: np.float32(v).tolist() for k,v in h.history.items()}, # bug in json or TF2 230 | }) 231 | 232 | 233 | 234 | # for k,v in results['result'].items(): 235 | # print(k,type(v)) 236 | 237 | import os 238 | 239 | # print(h.history) 240 | 241 | dirs = os.path.split(out_file_name)[0] 242 | os.makedirs(dirs,exist_ok=True) 243 | 244 | print(out_file_name) 245 | 246 | 247 | json.dump(results, open(out_file_name, 'w')) -------------------------------------------------------------------------------- /paper-code/D-image-experiments/run_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import lconv 4 | import tensorflow as tf 5 | 6 | sess = tf.compat.v1.InteractiveSession() 7 | # K = tf.keras.backend 8 | 9 | from tensorflow.keras import Model, Sequential 10 | from tensorflow.keras.layers import Layer, Input, Flatten, Reshape, Dense, Conv2D, MaxPool2D 11 | 12 | import pickle as pk 13 | import json 14 | import numpy as np 15 | 16 | 17 | from scipy import ndimage 18 | 19 | def rotated_ims_rand(x): 20 | return np.float32([ndimage.rotate(i, (np.random.rand()-.5)*180, reshape=False, mode='nearest') for i in x]) 21 | 22 | class Scramble_x: 23 | def __init__(self,x): 24 | s = x.shape[1:-1] 25 | self.idx = np.argsort(np.random.rand(np.prod(s))) 26 | r,c = np.int0(self.idx/s[0]), (self.idx % s[1]) 27 | self.x = np.float32([i[r,c].reshape(s+(x.shape[-1],)) for i in x]) 28 | 29 | 30 | # Defaults 31 | configs= { 32 | 'dataset': dict(name='mnist' ,#'mnist',cifar100, 33 | rotate=False, 34 | scramble=False,), 35 | 'net': 
dict(architecture= 'lconv', # 'cnn', 'fc', 36 | num_filters=32, 37 | kernel_size=9, 38 | L_hid= [8], #[16], 39 | activation = 'relu', 40 | L_trainable = True), 41 | } 42 | 43 | 44 | import argparse 45 | 46 | parser = argparse.ArgumentParser(description='Run experiments using L-conv or baselines.') 47 | # Add the arguments 48 | # parser.add_argument('dataset_name',#action='store', 49 | # # metavar='--data', 50 | # type=str, #required=True, 51 | # help='mnist, cifar10, cifar100') 52 | 53 | # Add the arguments 54 | # parser.add_argument('config_json_file', #action='store', 55 | # # metavar='--config', 56 | # type=str, #required=True, 57 | # help='json file of model configuration') 58 | parser.add_argument('--architecture', action='store', type=str, help="lconv (default), cnn, fc") 59 | parser.add_argument('--dataset', action='store', type=str, help='mnist (default), cifar10, cifar100, fashion_mnist') 60 | parser.add_argument('--rotate', action='store_const', const=True) 61 | parser.add_argument('--scramble', action='store_const', const=True) 62 | parser.add_argument('--lrand', action='store_true', help='Whether L are random or trainable') 63 | parser.add_argument('--epochs', action='store', type=int) 64 | parser.add_argument('--hid', action='store', type=int, help='# hidden units in L for low-rank encoding') 65 | 66 | # parser.add_argument('--test', action='store', type=bool, required=False) 67 | 68 | args = parser.parse_args() 69 | # print(args.architecture) 70 | # exit() 71 | 72 | # print(args.dataset_name) 73 | # print(args.config_json_file) 74 | # configs = json.load(open(args.config_json_file,'r')) 75 | 76 | if args.architecture: 77 | configs['net']['architecture'] = args.architecture 78 | 79 | if args.dataset: 80 | configs['dataset']['name'] = args.dataset 81 | 82 | if args.rotate!=None: 83 | configs['dataset']['rotate'] = args.rotate 84 | 85 | if args.scramble!=None: 86 | configs['dataset']['scramble'] = args.scramble 87 | 88 | 89 | print(args.lrand) 90 | 91 | 
if args.lrand: 92 | configs['net']['L_trainable'] = False 93 | 94 | if args.hid: 95 | configs['net']['L_hid'] = [args.hid] 96 | 97 | 98 | EPOCHS = args.epochs or 30 99 | 100 | print(configs) 101 | 102 | # exit() 103 | 104 | # cf: conv filters, ker: kernel size, d: dense (FC) units 105 | # make_model_name = lambda cf,ker, d: ''.join([('c%s(k%d)' %(cf,ker) if len(cf) else ''),('d%s'%d if len(d) else '')]) or 'base' 106 | # cf: conv filters, ker: kernel size, d: dense (FC) units 107 | 108 | 109 | 110 | dataset_name = configs['dataset']['name'] 111 | 112 | dataset = eval("tf.keras.datasets.%s.load_data()" %dataset_name) 113 | (x_train, y_train), (x_test,y_test) = dataset 114 | if len(x_train.shape) == 3: 115 | # mnist channel is missing 116 | x_train = x_train[...,np.newaxis] 117 | 118 | # normalize 119 | x_train = x_train/x_train[:100].max() -.5 120 | # make categorical 121 | y_train = tf.keras.utils.to_categorical(y_train) 122 | 123 | 124 | results = {'configs':configs,} 125 | 126 | if configs['dataset']['rotate']: 127 | print('Rotating images') 128 | x_train = rotated_ims_rand(x_train) 129 | 130 | if configs['dataset']['scramble']: 131 | print('Scrambling images') 132 | scr = Scramble_x(x_train) 133 | x_train = scr.x 134 | results['scramble_idx']=scr.idx.tolist() 135 | 136 | ##### Make model ##### 137 | 138 | net = configs['net'] 139 | # arch = net['architecture'] 140 | 141 | kernel_size = net['kernel_size'] 142 | # L_hid = net['L_hid'] 143 | # activation = net['activation'] 144 | # L_trainable = net['L_trainable'] 145 | 146 | 147 | inp = Input(x_train[0].shape) 148 | 149 | if net['architecture']=='lconv': 150 | x = tf.reshape(inp, shape=(-1,np.prod(inp.shape[1:-1]), inp.shape[-1])) 151 | lay = lconv.L_Conv(num_filters= net['num_filters'], 152 | kernel_size= kernel_size, 153 | L_hid = net['L_hid'], 154 | activation = net['activation'],) 155 | 156 | x = lay(x) 157 | lay.L.trainable = net['L_trainable'] 158 | 159 | elif net['architecture']=='cnn': 160 | kx = 
int(round(np.sqrt(kernel_size))) 161 | ky = int(round(kernel_size/kx)) 162 | kernel_size = (kx,ky) 163 | x = Conv2D(filters=net['num_filters'], kernel_size=kernel_size, activation = net['activation'])(inp) 164 | # x = cnn(inp) 165 | elif net['architecture']=='fc': 166 | k = net['kernel_size'] 167 | nf = net['num_filters'] 168 | hid = net['L_hid'][0] 169 | act = net['activation'] 170 | xs = np.prod(inp.shape[1:-1]) 171 | 172 | x = Flatten()(inp) 173 | # FC comparable to L-conv, but no shared weights 174 | x = Dense(k*hid, activation = act)(x) 175 | x = Dense(xs*nf, activation = act)(x) 176 | 177 | if net['architecture']!='fc': 178 | x = Flatten()(x) 179 | 180 | # x = Dense(100, activation = 'relu')(x) 181 | 182 | out = Dense(y_train.shape[-1], activation='softmax')(x) 183 | 184 | model = Model(inputs = [inp], outputs = [out]) 185 | model.compile(loss = tf.keras.losses.categorical_crossentropy, metrics = ['accuracy']) 186 | 187 | model.summary() 188 | 189 | 190 | 191 | ##### Train model 192 | 193 | h = model.fit(x_train, y_train, validation_split=0.2, epochs=EPOCHS) 194 | 195 | ##### model name and results 196 | non_trainable = 0 197 | if net['architecture'] == 'lconv': 198 | model_name = f"L-conv-nf{net['num_filters']}-hid{net['L_hid']}-L_trainable{net['L_trainable']}-ker{kernel_size}" 199 | non_trainable = (0 if net['L_trainable'] else lay.L.count_params()) 200 | elif net['architecture'] == 'cnn': 201 | model_name = f"CNN-nf{net['num_filters']}-ker{kernel_size}" 202 | elif net['architecture'] =='fc': 203 | model_name = f"FC-nf{net['num_filters']}-hid{net['L_hid']}-ker{kernel_size}" 204 | 205 | 206 | model_name += f"-act-{net['activation']}" 207 | num_params = model.count_params() - non_trainable 208 | 209 | out_file_name = f"./results/{dataset_name}/{model_name}-rotate={configs['dataset']['rotate']}-scrambled={configs['dataset']['scramble']}.json" 210 | 211 | results = {} 212 | results.update({ 213 | 'num_params':num_params, 214 | 'result':h.history, 215 | # 
'result': {k: np.float32(v).tolist() for k,v in h.history.items()}, # bug in json or TF2 216 | }) 217 | 218 | # for k,v in results['result'].items(): 219 | # print(k,type(v)) 220 | 221 | import os 222 | 223 | # print(h.history) 224 | 225 | dirs = os.path.split(out_file_name)[0] 226 | os.makedirs(dirs,exist_ok=True) 227 | 228 | print(out_file_name) 229 | 230 | 231 | json.dump(results, open(out_file_name, 'w')) -------------------------------------------------------------------------------- /src/.ipynb_checkpoints/examples-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c71613f7-6f96-4c85-aa8f-e56eb35b30cf", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from lconv import Lconv" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "65d9c034-6cfe-4640-b86b-c0f14d0cf971", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "id": "4124cfe2-6c17-47ea-aaa9-b59c10afeba0", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "n = 10\n", 31 | "d = 7\n", 32 | "c = 3\n", 33 | "# channel first\n", 34 | "x = torch.rand(n,c,d)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "67a78d5c-758d-49bd-96bc-c4abc7dd3160", 40 | "metadata": {}, 41 | "source": [ 42 | "## Basic usage" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 5, 48 | "id": "70f5dfcb-e608-4eed-8dab-cf6353463b23", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "num_L = 2\n", 53 | "co = 5\n", 54 | "l = Lconv(d,num_L, cin=c, cout= co)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 6, 60 | "id": "b9f51920-6495-425f-94b9-2e9a51dbb5ec", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "y = l(x)" 65 | ] 66 | }, 67 | { 68 | "cell_type": 
"code", 69 | "execution_count": 7, 70 | "id": "2e380c7d-5f1e-456e-b3a5-21042bb6ac4e", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "torch.Size([10, 5, 7])" 77 | ] 78 | }, 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "y.shape" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "5c6c2860-59e8-4a6f-a2f8-5673d885feb6", 91 | "metadata": {}, 92 | "source": [ 93 | "# Updated soon" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "aad5bb07-2ac4-43bb-a678-4913219b9840", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.8.1" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /src/.ipynb_checkpoints/lconv-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Lconv(nn.Module): 6 | """ L-conv layer with full L """ 7 | def __init__(self,d,num_L=1,cin=1,cout=1,rank=8): 8 | """ 9 | L:(num_L, d, d) 10 | Wi: (num_L, cout, cin) 11 | """ 12 | super().__init__() 13 | self.L = nn.Parameter(torch.Tensor(num_L, d, d)) 14 | self.Wi = nn.Parameter(torch.Tensor(num_L+1, cout, cin)) # W^0 = Wi[0], W^0\epsion^i = Wi[1:] 15 | 16 | # initialize weights and biases 17 | nn.init.kaiming_normal_(self.L) 18 | nn.init.kaiming_normal_(self.Wi) 19 | 20 | def forward(self, x): 21 | # 
x:(batch, channel, flat_d) 22 | # res = x W0 23 | residual = torch.einsum('bcd,oc->bod', x, self.Wi[0] ) 24 | # y = Li x Wi 25 | y = torch.einsum('kdf,bcf,koc->bod', self.L, x, self.Wi[1:]) 26 | return y + residual 27 | 28 | class Reshape(nn.Module): 29 | def __init__(self,shape=None): 30 | self.shape = shape 31 | super().__init__() 32 | def forward(self,x): 33 | return x.view(-1,*self.shape) 34 | 35 | -------------------------------------------------------------------------------- /src/__pycache__/lconv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nimadehmamy/L-conv-code/5a8abfbff3f6564771234df3e177d1d4aafe371d/src/__pycache__/lconv.cpython-38.pyc -------------------------------------------------------------------------------- /src/examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c71613f7-6f96-4c85-aa8f-e56eb35b30cf", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from lconv import Lconv" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "65d9c034-6cfe-4640-b86b-c0f14d0cf971", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "id": "4124cfe2-6c17-47ea-aaa9-b59c10afeba0", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "n = 10\n", 31 | "d = 7\n", 32 | "c = 3\n", 33 | "# channel first\n", 34 | "x = torch.rand(n,c,d)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "67a78d5c-758d-49bd-96bc-c4abc7dd3160", 40 | "metadata": {}, 41 | "source": [ 42 | "## Basic usage" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 5, 48 | "id": "70f5dfcb-e608-4eed-8dab-cf6353463b23", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "num_L = 2\n", 
53 | "co = 5\n", 54 | "l = Lconv(d,num_L, cin=c, cout= co)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 6, 60 | "id": "b9f51920-6495-425f-94b9-2e9a51dbb5ec", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "y = l(x)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 7, 70 | "id": "2e380c7d-5f1e-456e-b3a5-21042bb6ac4e", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "torch.Size([10, 5, 7])" 77 | ] 78 | }, 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "y.shape" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "5c6c2860-59e8-4a6f-a2f8-5673d885feb6", 91 | "metadata": {}, 92 | "source": [ 93 | "# Updated soon" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "aad5bb07-2ac4-43bb-a678-4913219b9840", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.8.1" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /src/lconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Lconv(nn.Module): 6 | """ L-conv layer with full L """ 7 | def __init__(self,d,num_L=1,cin=1,cout=1,rank=8): 8 | """ 9 | L: (num_L, d, d) 10 | Wi: (num_L+1, cout, cin); Wi[0] is the residual weight W^0 11 | """ 12 | super().__init__() 13 | self.L = 
nn.Parameter(torch.Tensor(num_L, d, d)) 14 | self.Wi = nn.Parameter(torch.Tensor(num_L+1, cout, cin)) # W^0 = Wi[0], W^0\epsilon^i = Wi[1:] 15 | 16 | # initialize weights (this layer has no bias term) 17 | nn.init.kaiming_normal_(self.L) 18 | nn.init.kaiming_normal_(self.Wi) 19 | 20 | def forward(self, x): 21 | # x: (batch, channel, flat_d) 22 | # residual term: x W^0 23 | residual = torch.einsum('bcd,oc->bod', x, self.Wi[0]) 24 | # L-conv term: sum over i of L_i x W^i 25 | y = torch.einsum('kdf,bcf,koc->bod', self.L, x, self.Wi[1:]) 26 | return y + residual 27 | 28 | class Reshape(nn.Module): 29 | def __init__(self,shape=None): 30 | super().__init__() 31 | self.shape = shape 32 | def forward(self,x): 33 | return x.view(-1,*self.shape) 34 | 35 | --------------------------------------------------------------------------------