# Repository layout (as listed at the top of this dump):
#   ├── combined_cover.png
#   ├── PytorchExamples
#   │   ├── block.py
#   │   ├── network.py
#   │   └── Denoising_demo.ipynb
#   ├── README.md
#   └── TensorflowExamples
#       ├── basicModules.py
#       └── layer_CSVT.py
#
# /combined_cover.png:
#   https://raw.githubusercontent.com/avisekiit/TCSVT-LightWeight-CNNs/HEAD/combined_cover.png

# ---------------------------------------------------------------------------
# /PytorchExamples/block.py
# ---------------------------------------------------------------------------
import torch.nn as nn
import torch


class depthwise_separable_conv(nn.Module):
    """Depthwise-separable 3x3 convolution.

    A per-channel 3x3 convolution (``groups=nin``, ``padding=1``, so height
    and width are unchanged) followed by a 1x1 convolution that mixes the
    channels and maps ``nin`` channels to ``nout``.

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=3, padding=1, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        """Return ``pointwise(depthwise(x))``."""
        return self.pointwise(self.depthwise(x))


class LIST(nn.Module):
    """LIST layer: 1x1 squeeze followed by two parallel expand branches.

    The input is first reduced to ``input_channel // 4`` channels by a 1x1
    convolution.  The squeezed tensor then feeds two branches that each emit
    ``output_channel // 2`` channels: a plain 1x1 convolution and a
    depthwise-separable 3x3 convolution.  The branch outputs are concatenated
    along the channel axis (depthwise-separable branch first).

    Note that an odd ``output_channel`` yields ``output_channel - 1`` output
    channels because each branch gets ``output_channel // 2``.

    Args:
        input_channel: number of input channels (must be >= 4).
        output_channel: requested number of output channels (must be >= 2).

    Raises:
        ValueError: if either channel count is too small to give every
            internal convolution at least one channel.
    """

    def __init__(self, input_channel, output_channel):
        super(LIST, self).__init__()
        # Integer floor division instead of int(float division).
        squeezed = int(input_channel) // 4
        per_branch = int(output_channel) // 2
        if squeezed < 1 or per_branch < 1:
            raise ValueError(
                "LIST needs input_channel >= 4 and output_channel >= 2, got "
                "input_channel=%r, output_channel=%r" % (input_channel, output_channel)
            )
        # The attribute spelling `squezze_layer` is kept on purpose: it is part
        # of the state_dict keys of already-saved checkpoints.
        self.squezze_layer = nn.Conv2d(input_channel, squeezed, kernel_size=1)
        self.fire_layer = nn.Conv2d(squeezed, per_branch, kernel_size=1)
        self.depthwise = depthwise_separable_conv(squeezed, per_branch)

    def forward(self, x):
        """Squeeze ``x``, run both branches and concatenate them on dim 1."""
        squeezed = self.squezze_layer(x)
        pointwise_branch = self.fire_layer(squeezed)
        separable_branch = self.depthwise(squeezed)
        return torch.cat((separable_branch, pointwise_branch), 1)


# ---------------------------------------------------------------------------
# /README.md (beginning; the README continues on the following lines)
# ---------------------------------------------------------------------------
# # TCSVT-LightWeight-CNNs
# Code for our IEEE TCSVT Paper: Lightweight Modules for Efficient Deep
# Learning based Image Restoration
#
# Authors: Avisek Lahiri*, Sourav Bairagya*, Sutanu Bera, Siddhant Haldar,
# Prabir Kumar Biswas
5 | (* equal contribution)
6 | 1. Paper Link: https://arxiv.org/abs/2007.05835
7 | 2. IEEE Early Access Link: https://ieeexplore.ieee.org/document/9134805
8 | 9 | ### Key Points from Paper 10 | * The paper provides re-usable, plug-and-play modules to compress a given CNN 11 | * Select any favourite full-scale baseline for low-level vision applications 12 | * Replace 3X3 conv by **LIST** layer 13 | * Replace dilated conv layer by **GSAT** layer 14 | * Achieve efficient up/down-sampling with **Bilinear SubSampling** followed by **LIST** layer 15 | 16 | ### TensorflowExamples 17 | This contains the basic proposed modules in Tensorflow
18 | TensorflowExamples/basicModules.py contains the proposed **LIST**, **GSAT** modules
19 | It also contains the framework for **LIST** based up/down-sampling in a CNN
20 | 21 | ### PytorchExamples 22 | It contains the basic proposed modules in Pytorch
# ---------------------------------------------------------------------------
# /README.md (end)
# ---------------------------------------------------------------------------
# * block.py has the implementation of **LIST** module based DnCNN denoising
#   framework
# * Denoising_demo.ipynb is a notebook to reflect our training/inference setup
#   for DnCNN experiments
#
# ![Cover Picture](/combined_cover.png)

# ---------------------------------------------------------------------------
# /PytorchExamples/network.py
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


def _plain_conv(in_channels, out_channels):
    """Return the bias-free 3x3 convolution used by the full-scale DnCNN."""
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=3, padding=1, bias=False)


def _build_dncnn(channels, num_of_layers, features, make_hidden):
    """Assemble the DnCNN layer stack shared by both network variants.

    Layout: Conv+ReLU, then ``num_of_layers - 2`` repetitions of
    (hidden layer, BatchNorm2d, ReLU), then a final Conv back to ``channels``.
    ``make_hidden(in_channels, out_channels)`` builds each hidden layer.
    Modules are created in the same order as in the original code, so
    state_dict keys (``dncnn.<index>.*``) are unchanged.
    """
    if num_of_layers < 2:
        # The stack always contains the first and the last convolution, so a
        # smaller value would silently build a network of a different depth.
        raise ValueError("num_of_layers must be >= 2, got %r" % (num_of_layers,))
    layers = [_plain_conv(channels, features), nn.ReLU(inplace=True)]
    for _ in range(num_of_layers - 2):
        layers.append(make_hidden(features, features))
        layers.append(nn.BatchNorm2d(features))
        layers.append(nn.ReLU(inplace=True))
    layers.append(_plain_conv(features, channels))
    return nn.Sequential(*layers)


class DnCNN(nn.Module):
    """The bare-bone skeleton of the full-scale DnCNN denoising baseline.

    Args:
        channels: number of image channels of the input and the output.
        num_of_layers: total number of convolutional layers (>= 2).
        features: width of the hidden layers (64 in the original code).

    Raises:
        ValueError: if ``num_of_layers`` is smaller than 2.
    """

    def __init__(self, channels=1, num_of_layers=17, features=64):
        super(DnCNN, self).__init__()
        self.dncnn = _build_dncnn(channels, num_of_layers, features, _plain_conv)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.dncnn(x)


class DnCNN_cheap(nn.Module):
    """The bare-bone skeleton of the proposed cheaper version of DnCNN.

    Identical to :class:`DnCNN` except that every hidden 3x3 convolution is
    replaced by ``block_cls(features, features)``.

    Args:
        channels: number of image channels of the input and the output.
        num_of_layers: total number of layers (>= 2).
        features: width of the hidden layers (64 in the original code).
        block_cls: callable ``(in_channels, out_channels) -> nn.Module`` used
            for the hidden layers.  Defaults to ``LIST`` from ``block.py``.

    Raises:
        ValueError: if ``num_of_layers`` is smaller than 2.
    """

    def __init__(self, channels=1, num_of_layers=17, features=64, block_cls=None):
        super(DnCNN_cheap, self).__init__()
        if block_cls is None:
            # Imported lazily so that the baseline DnCNN above stays usable
            # even when block.py is not on the import path.
            from block import LIST
            block_cls = LIST
        self.dncnn = _build_dncnn(channels, num_of_layers, features, block_cls)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.dncnn(x)


# ---------------------------------------------------------------------------
# /TensorflowExamples/basicModules.py (header; the module skeletons follow on
# the next lines of this dump)
# ---------------------------------------------------------------------------
# from layer_CSVT import branchout, bottlenect, group_conv_dilated, group_conv_normal, batch_normalize
# NOTE(review): `bottlenect` above is a typo for `bottleneck`, and
# `channel_shuffle` / `tf` are used by the skeletons but never imported.
#
# In this file we are releasing the barebone skeleton of our modules.
# Users can takes these code snippets and plug in into their own
# realizations of LIST, GSAT, up/down sampling layers.
# ---------------------------------------------------------------------------
# /TensorflowExamples/basicModules.py (module skeletons)
#
# The original file listed these skeletons as bare top-level statements over
# free variables (x, in_channel, out_channel, is_training, ...).  They are
# expressed here as functions with those variables as explicit parameters;
# the variable-scope names and the order of operations are unchanged.
# ---------------------------------------------------------------------------
import tensorflow as tf

from layer_CSVT import (
    batch_normalize,
    bottleneck,
    branchout,
    channel_shuffle,
    group_conv_dilated,
    group_conv_normal,
)


def _squeeze_expand(x, in_channel, out_channel, is_training,
                    squeeze_scope, expand_scope, new_size=None):
    """Shared body of the LIST-based skeletons.

    Optionally resizes ``x`` bilinearly to ``new_size`` inside the squeeze
    scope, then applies bottleneck -> batch norm -> ReLU, followed by
    branchout -> batch norm -> ReLU in the expand scope.
    """
    with tf.variable_scope(squeeze_scope):
        if new_size is not None:
            x = tf.image.resize_bilinear(x, new_size)
        x = bottleneck(x, [1, 1, in_channel, in_channel // 4])
        x = batch_normalize(x, is_training)
        x = tf.nn.relu(x)
    with tf.variable_scope(expand_scope):
        x = branchout(x,
                      [1, 1, in_channel // 4, out_channel // 2],
                      [3, 3, in_channel // 4, out_channel // 2])
        x = batch_normalize(x, is_training)
        x = tf.nn.relu(x)
    return x


#============================================================
#*** LIST Module ***
def list_module(x, in_channel, out_channel, is_training,
                squeeze_scope='conv1_1', expand_scope='conv1_2'):
    """LIST layer.

    Args:
        x: input tensor.
        in_channel: number of channels taken as input to the LIST layer.
        out_channel: number of channels that will be output from the LIST layer.
        is_training: predicate forwarded to ``batch_normalize``.
        squeeze_scope, expand_scope: variable-scope names; pass distinct names
            when the layer is instantiated more than once in a graph.

    Returns:
        The output tensor of the LIST layer.
    """
    return _squeeze_expand(x, in_channel, out_channel, is_training,
                           squeeze_scope, expand_scope)
#===========================================================


#===========================================================
#*** GSAT Module ***
def gsat_module(x, in_channel, num_groups, dilation_factor, is_training,
                scope='gsat'):
    """GSAT layer.

    Args:
        x: input tensor.
        in_channel: number of channels taken as input to the GSAT layer; the
            layer also outputs ``in_channel`` channels so that the skip
            connection can be added.
        num_groups: number of groups for group convolution (see the GSAT
            section in the paper).
        dilation_factor: dilation rate of the grouped 3x3 convolution.
        is_training: predicate forwarded to ``batch_normalize``.
        scope: variable-scope name.

    Returns:
        ``relu(x + batch_norm(group_conv_1x1(shuffle(group_conv_dilated(x)))))``.
    """
    with tf.variable_scope(scope):
        skip = x
        x = group_conv_dilated(x, [3, 3, in_channel, in_channel], num_groups, dilation_factor)
        x = channel_shuffle(x, num_groups)
        x = group_conv_normal(x, [1, 1, in_channel, in_channel], num_groups)
        x = batch_normalize(x, is_training)
        x = skip + x
        x = tf.nn.relu(x)
    return x
#===========================================================


#===========================================================
#*** UPSAMPLING Module ***
#/////////////////////////////
# First deterministic upsampling then follow with a LIST layer.
def upsample_module(x, in_channel, out_channel, stride, is_training,
                    squeeze_scope='deconv1_1', expand_scope='deconv1_2'):
    """Bilinear upsampling by ``stride`` followed by a LIST layer.

    The target size is computed from the static shape of ``x`` (dims 1 and 2
    multiplied by ``stride``), exactly as in the original skeleton.
    """
    new_size = [x.shape[1] * stride, x.shape[2] * stride]
    return _squeeze_expand(x, in_channel, out_channel, is_training,
                           squeeze_scope, expand_scope, new_size=new_size)
#===========================================================


#===========================================================
#*** DOWNSAMPLING Module ***
#////////////////////////////////
# First deterministic downsampling then follow with a LIST layer.
def downsample_module(x, in_channel, out_channel, stride, is_training,
                      squeeze_scope='conv1_1', expand_scope='conv1_2'):
    """Bilinear downsampling by ``stride`` followed by a LIST layer.

    The target size is computed from the static shape of ``x`` (dims 1 and 2
    floor-divided by ``stride``), exactly as in the original skeleton.
    """
    new_size = [x.shape[1] // stride, x.shape[2] // stride]
    return _squeeze_expand(x, in_channel, out_channel, is_training,
                           squeeze_scope, expand_scope, new_size=new_size)
#===========================================================


# ---------------------------------------------------------------------------
# /PytorchExamples/Denoising_demo.ipynb (opening of the notebook JSON, kept
# verbatim; the notebook continues on the following lines of this dump)
# ---------------------------------------------------------------------------
# {
#  "nbformat": 4,
#  "nbformat_minor": 0,
#  "metadata": {
#   "kernelspec": {
#    "display_name": "Python 3",
#    "language": "python",
#    "name": "python3"
#   },
#   "language_info": {
#    "codemirror_mode": {
#     "name": "ipython",
#     "version": 3
#    },
#    "file_extension": ".py",
#    "mimetype": "text/x-python",
#    "name": "python",
#    "nbconvert_exporter": "python",
#    "pygments_lexer": "ipython3",
#    "version": "3.7.3"
#   },
#   "colab": {
#    "name": "Denoising_demo.ipynb",
#    "provenance": [],
"collapsed_sections": [] 26 | } 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "Sv9RI8SsxYeN", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "A Sample Notebook for realizing proposed compressed version of DnCNN framework\n", 37 | "for image denoising." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "fZk2K1bsxTW4", 44 | "colab_type": "code", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "import torch\n", 49 | "import torch.nn as nn\n", 50 | "import numpy as np\n", 51 | "import random\n", 52 | "import h5py\n", 53 | "\n", 54 | "from torch.utils.data import Dataset, DataLoader\n", 55 | "import torch.optim as optim\n", 56 | "from torch.autograd import Variable\n", 57 | "from skimage.measure.simple_metrics import compare_psnr\n", 58 | "import torch.nn.functional as F\n", 59 | "\n", 60 | "torch.backends.cudnn.benchmark = True\n", 61 | "\n", 62 | "from network import DnCNN, DnCNN_cheap" 63 | ], 64 | "execution_count": null, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "id": "rg7ouwIpxTW9", 71 | "colab_type": "code", 72 | "colab": {} 73 | }, 74 | "source": [ 75 | "def batch_PSNR(img, imclean, data_range):\n", 76 | " Img = img.data.cpu().numpy().astype(np.float32)\n", 77 | " Iclean = imclean.data.cpu().numpy().astype(np.float32)\n", 78 | " PSNR = 0\n", 79 | " for i in range(Img.shape[0]):\n", 80 | " PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)\n", 81 | " return (PSNR/Img.shape[0])" 82 | ], 83 | "execution_count": null, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "metadata": { 89 | "id": "BeqxDL1IxTXA", 90 | "colab_type": "code", 91 | "colab": {} 92 | }, 93 | "source": [ 94 | "class Dataset(Dataset):\n", 95 | " def __init__(self, train=True):\n", 96 | " super(Dataset, self).__init__()\n", 97 | " self.train = train\n", 98 | " # Store images in .h5 file format\n", 99 | " if self.train:\n", 100 | " h5f = 
h5py.File('train.h5', 'r')\n", 101 | " else:\n", 102 | " h5f = h5py.File('val.h5', 'r')\n", 103 | " self.keys = list(h5f.keys())\n", 104 | " random.shuffle(self.keys)\n", 105 | " h5f.close()\n", 106 | " def __len__(self):\n", 107 | " return len(self.keys)\n", 108 | " def __getitem__(self, index):\n", 109 | " if self.train:\n", 110 | " h5f = h5py.File('train.h5', 'r')\n", 111 | " else:\n", 112 | " h5f = h5py.File('val.h5', 'r')\n", 113 | " key = self.keys[index]\n", 114 | " data = np.array(h5f[key])\n", 115 | " h5f.close()\n", 116 | " return torch.Tensor(data)\n" 117 | ], 118 | "execution_count": null, 119 | "outputs": [] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "metadata": { 124 | "id": "TDWTN7qYxTXE", 125 | "colab_type": "code", 126 | "colab": {} 127 | }, 128 | "source": [ 129 | "noiseL=25\n", 130 | "learning_rate=0.001\n", 131 | "batchSize=64\n", 132 | "print('Loading dataset ...\\n')\n", 133 | "dataset_train = Dataset(train=True)\n", 134 | "dataset_val = Dataset(train=False)\n", 135 | "loader_train = DataLoader(dataset=dataset_train, num_workers=7, \n", 136 | " batch_size=batchSize, shuffle=True)\n", 137 | "model = DnCNN_cheap() # DnCNN()\n", 138 | "model=model.cuda()\n", 139 | "criterion = nn.MSELoss(size_average=False)\n", 140 | "optimizer = optim.Adam(model.parameters(), lr=learning_rate) \n", 141 | "best_val=0 \n", 142 | "\n", 143 | "for epoch in range(50):\n", 144 | " test_loss = 0\n", 145 | " epoch_loss = 0\n", 146 | " for i, data in enumerate(loader_train, 0):\n", 147 | " model.train()\n", 148 | " model.zero_grad()\n", 149 | " optimizer.zero_grad()\n", 150 | " img_train = data\n", 151 | " noise = torch.FloatTensor(img_train.size()).normal_(mean=0, std=noiseL/255.)\n", 152 | " imgn_train = img_train + noise\n", 153 | " img_train, imgn_train = Variable(img_train.cuda()), Variable(imgn_train.cuda())\n", 154 | " noise = Variable(noise.cuda())\n", 155 | " out_train = model(imgn_train)\n", 156 | " loss = criterion(out_train, noise) / 
(img_train.size()[0]*2)\n", 157 | " psnr_train = batch_PSNR(imgn_train-out_train, img_train, 1.)\n", 158 | " loss.backward()\n", 159 | " optimizer.step()\n", 160 | " epoch_loss = epoch_loss+loss.item()\n", 161 | " # results\n", 162 | " if i%30 == 0:\n", 163 | " print(\"[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f\" %\n", 164 | " (epoch+1, i+1, len(loader_train), loss.item(), psnr_train))\n", 165 | " model.eval()\n", 166 | " epoch_loss=epoch_loss/len(loader_train)\n", 167 | " psnr_val = 0\n", 168 | " for k in range(len(dataset_val)):\n", 169 | " img_val = torch.unsqueeze(dataset_val[k], 0)\n", 170 | " noise = torch.FloatTensor(img_val.size()).normal_(mean=0, std=noiseL/255.)\n", 171 | " imgn_val = img_val + noise\n", 172 | " img_val, imgn_val = Variable(img_val.cuda()), Variable(imgn_val.cuda())\n", 173 | " with torch.no_grad():\n", 174 | " out_val = torch.clamp(imgn_val-model(imgn_val), 0., 1.)\n", 175 | " psnr_val += batch_PSNR(out_val, img_val, 1.)\n", 176 | " test_loss += criterion(out_val, img_val) / (imgn_train.size()[0]*2)\n", 177 | " psnr_val /= len(dataset_val)\n", 178 | " test_loss /= len(dataset_val)\n", 179 | " print(\"[epoch %d] Train Loss: %.4f Val Loss: %.4f PSNR_val: %.4f\" %\n", 180 | " (epoch+1, epoch_loss,test_loss.item(), psnr_val))\n", 181 | " \n", 182 | " if epoch%10==0:\n", 183 | " learning_rate=learning_rate*0.5\n", 184 | " for param_group in optimizer.param_groups:\n", 185 | " param_group['lr'] = learning_rate\n", 186 | " \n", 187 | " if psnr_val>=best_val:\n", 188 | " torch.save(model.state_dict(),'model.pth')\n", 189 | " best_val=psnr_val " 190 | ], 191 | "execution_count": null, 192 | "outputs": [] 193 | } 194 | ] 195 | } -------------------------------------------------------------------------------- /TensorflowExamples/layer_CSVT.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def bottleneck(x, filter_shape): 4 | filters = tf.get_variable( 5 | name='weight', 6 | 
# ---------------------------------------------------------------------------
# /TensorflowExamples/layer_CSVT.py
#
# NOTE: the first lines of this file (the tensorflow import and the opening
# of bottleneck()) sit at the end of the preceding dump line; the complete
# definitions are restated here so that this section reads on its own.
# ---------------------------------------------------------------------------
import tensorflow as tf


def _xavier_variable(name, shape):
    """Create (or fetch) a trainable float32 variable with Xavier init."""
    return tf.get_variable(
        name=name,
        shape=shape,
        dtype=tf.float32,
        initializer=tf.contrib.layers.xavier_initializer(),
        trainable=True)


def bottleneck(x, filter_shape):
    """Stride-1 'SAME' convolution of ``x`` with a new ``weight`` variable.

    Args:
        x: input tensor.
        filter_shape: ``[height, width, in_channels, out_channels]``.
    """
    filters = _xavier_variable('weight', filter_shape)
    return tf.nn.conv2d(x, filters, [1, 1, 1, 1], padding='SAME')


def branchout(x, filter_shape1, filter_shape2):
    """Two parallel branches over ``x`` concatenated on the channel axis.

    Branch one is a plain convolution with ``filter_shape1``.  Branch two is
    a depthwise-separable convolution: a depthwise filter of size
    ``filter_shape2[0] x filter_shape2[1]`` (channel multiplier 1) followed by
    a 1x1 pointwise filter mapping ``filter_shape2[2]`` channels to
    ``filter_shape2[3]``.

    The original hard-coded the depthwise kernel to 3x3 (every visible caller
    passes 3x3) and reshaped the result to the static input shape; that
    reshape was a no-op whenever it succeeded and failed for unknown batch
    sizes, so it has been dropped.

    Returns:
        Tensor with ``filter_shape1[3] + filter_shape2[3]`` channels, plain
        branch first.
    """
    filters1 = _xavier_variable('weight1', filter_shape1)
    filters2_1 = _xavier_variable(
        'weight2_1', [filter_shape2[0], filter_shape2[1], filter_shape2[2], 1])
    filters2_2 = _xavier_variable(
        'weight2_2', [1, 1, filter_shape2[2], filter_shape2[3]])

    w = tf.nn.conv2d(x, filters1, [1, 1, 1, 1], padding='SAME')
    y = tf.nn.separable_conv2d(x, filters2_1, filters2_2, [1, 1, 1, 1], padding='SAME')
    return tf.concat([w, y], 3)


def _grouped_conv(x, filter_shape, groups, name_suffix, conv_fn):
    """Split the channels of ``x`` into ``groups`` slices and convolve each.

    One variable per group is created, named ``weight<k><name_suffix>`` with
    ``k`` starting at 1 (the names the original hand-unrolled code used for
    eight groups).  All variables are created before any convolution, as in
    the original, and the per-group outputs are concatenated on axis 3.
    """
    in_per_group = filter_shape[2] // groups
    out_per_group = filter_shape[3] // groups
    group_shape = [filter_shape[0], filter_shape[1], in_per_group, out_per_group]
    filters = [_xavier_variable('weight%d%s' % (k + 1, name_suffix), group_shape)
               for k in range(groups)]
    outputs = [
        conv_fn(x[:, :, :, k * in_per_group:(k + 1) * in_per_group], filters[k])
        for k in range(groups)
    ]
    return tf.concat(outputs, 3)


def group_conv_normal(x, filter_shape, groups, stride=1):
    """Grouped convolution with 'SAME' padding.

    The original was hand-unrolled for exactly eight groups regardless of
    ``groups``; any positive group count that divides the channel counts is
    now supported.  The former reshape to the input's static shape, which
    broke for ``stride != 1`` and unknown batch sizes, has been removed.

    Args:
        x: input tensor (channels on the last axis).
        filter_shape: ``[height, width, in_channels, out_channels]`` of the
            equivalent ungrouped filter.
        groups: number of channel groups.
        stride: spatial stride applied to both height and width.
    """
    return _grouped_conv(
        x, filter_shape, groups, '_',
        lambda chunk, filt: tf.nn.conv2d(chunk, filt, [1, stride, stride, 1], padding='SAME'))


def group_conv_dilated(x, filter_shape, groups, dilation):
    """Grouped atrous (dilated) convolution with 'SAME' padding.

    Same grouping scheme as ``group_conv_normal`` but each group is convolved
    with ``tf.nn.atrous_conv2d`` at rate ``dilation``; the variables are named
    ``weight1`` ... ``weight<groups>`` (no trailing underscore) so both grouped
    convolutions can share one variable scope.
    """
    return _grouped_conv(
        x, filter_shape, groups, '',
        lambda chunk, filt: tf.nn.atrous_conv2d(chunk, filt, dilation, padding='SAME'))


def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups``.

    The channel axis is viewed as ``(groups, channels // groups)``, the two
    factors are transposed, and the result is flattened back.  Leading
    dimensions that are statically unknown are taken from the dynamic shape
    (the original required a fully static shape); the channel count itself
    must still be static.
    """
    static_shape = x.get_shape().as_list()
    dynamic_shape = tf.shape(x)
    lead = [static_shape[i] if static_shape[i] is not None else dynamic_shape[i]
            for i in range(3)]
    channels = static_shape[3]

    y = tf.reshape(x, lead + [groups, channels // groups])
    y = tf.transpose(y, perm=[0, 1, 2, 4, 3])
    return tf.reshape(y, lead + [channels])


def batch_normalize(x, is_training, decay=0.99, epsilon=0.001):
    """Batch normalization over axes 0, 1 and 2 with moving statistics.

    Creates ``beta``, ``scale`` (trainable) and ``pop_mean``, ``pop_var``
    (non-trainable) in the current variable scope, one value per channel of
    the last axis of ``x``.

    Args:
        x: input tensor; its last dimension must be statically known.
        is_training: predicate handed to ``tf.cond``; when true the batch
            moments are used and the population statistics are updated with
            an exponential moving average, otherwise the population
            statistics are used.
        decay: moving-average decay for the population statistics.
        epsilon: variance epsilon passed to ``tf.nn.batch_normalization``.
    """
    dim = x.get_shape().as_list()[-1]
    # stddev=0.0 makes beta start at exactly zero.
    beta = tf.get_variable(
        name='beta',
        shape=[dim],
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(stddev=0.0),
        trainable=True)
    # NOTE(review): scale is drawn around the initializer's default mean with
    # stddev 0.1 rather than starting at 1 - kept as in the original; confirm
    # this is intended.
    scale = tf.get_variable(
        name='scale',
        shape=[dim],
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(stddev=0.1),
        trainable=True)
    pop_mean = tf.get_variable(
        name='pop_mean',
        shape=[dim],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False)
    pop_var = tf.get_variable(
        name='pop_var',
        shape=[dim],
        dtype=tf.float32,
        initializer=tf.constant_initializer(1.0),
        trainable=False)

    def bn_train():
        batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2])
        train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
        # Force the moving-average updates to run whenever the training
        # branch is evaluated.
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon)

    def bn_inference():
        return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon)

    return tf.cond(is_training, bn_train, bn_inference)