├── combined_cover.png
├── PytorchExamples
├── block.py
├── network.py
└── Denoising_demo.ipynb
├── README.md
└── TensorflowExamples
├── basicModules.py
└── layer_CSVT.py
/combined_cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avisekiit/TCSVT-LightWeight-CNNs/HEAD/combined_cover.png
--------------------------------------------------------------------------------
/PytorchExamples/block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 |
class depthwise_separable_conv(nn.Module):
    """Depthwise-separable 3x3 convolution.

    A grouped 3x3 convolution (one group per input channel, padding 1) is
    followed by a 1x1 convolution that maps ``nin`` channels to ``nout``.

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        # groups=nin gives every input channel its own 3x3 filter.
        self.depthwise = nn.Conv2d(
            in_channels=nin,
            out_channels=nin,
            kernel_size=3,
            padding=1,
            groups=nin,
        )
        # The 1x1 convolution mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels=nin,
            out_channels=nout,
            kernel_size=1,
        )

    def forward(self, x):
        """Apply the depthwise filter and then the pointwise projection."""
        return self.pointwise(self.depthwise(x))
16 |
class LIST(nn.Module):
    """LIST layer: a 1x1 squeeze feeding two parallel branches.

    The input is first reduced to ``int(input_channel / 4)`` channels by a
    1x1 convolution. That squeezed tensor is then processed by a plain 1x1
    convolution and by a depthwise-separable 3x3 convolution, each producing
    ``int(output_channel / 2)`` channels. The two results are concatenated
    along the channel axis (dim 1).

    Args:
        input_channel: number of channels of the incoming tensor.
        output_channel: nominal number of output channels; the layer emits
            ``2 * int(output_channel / 2)`` channels.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        squeezed_width = int(input_channel / 4)
        branch_width = int(output_channel / 2)
        # Attribute names (including the "squezze" spelling) are kept because
        # they are the keys under which the parameters are stored.
        self.squezze_layer = nn.Conv2d(input_channel, squeezed_width, kernel_size=1)
        self.fire_layer = nn.Conv2d(squeezed_width, branch_width, kernel_size=1)
        self.depthwise = depthwise_separable_conv(squeezed_width, branch_width)

    def forward(self, x):
        """Return the channel-wise concatenation of both branch outputs."""
        squeezed = self.squezze_layer(x)
        pointwise_branch = self.fire_layer(squeezed)
        separable_branch = self.depthwise(squeezed)
        # The separable branch comes first in the concatenated output.
        return torch.cat((separable_branch, pointwise_branch), 1)
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TCSVT-LightWeight-CNNs
2 | Code for our IEEE TCSVT Paper: Lightweight Modules for Efficient Deep Learning based Image Restoration
3 |
4 | Authors: Avisek Lahiri*, Sourav Bairagya*, Sutanu Bera, Siddhant Haldar, Prabir Kumar Biswas
5 | (* equal contribution)
6 | 1. Paper Link: https://arxiv.org/abs/2007.05835
7 | 2. IEEE Early Access Link: https://ieeexplore.ieee.org/document/9134805
8 |
9 | ### Key Points from Paper
10 | * Paper provides re-usable modules to be plugged and played to compress a given CNN
11 | * Select any favourite full-scale baseline for low-level vision applications
12 | * Replace 3X3 conv by **LIST** layer
13 | * Replace dilated conv layer by **GSAT** layer
14 | * Achieve efficient up/down-sample with **Bilinear SubSampling** followed by **LIST** layer
15 |
16 | ### TensorflowExamples
17 | This contains the basic proposed modules in Tensorflow
18 | TensorflowExamples/basicModules.py contains the proposed **LIST**, **GSAT** modules
19 | It also contains the framework for **LIST** based up/down-sampling in a CNN
20 |
21 | ### PytorchExamples
22 | It contains the basic proposed modules in Pytorch
23 |
24 | * block.py has the implementation of the **LIST** module; network.py uses it to build the **LIST**-based DnCNN denoising framework
25 | * Denoising_demo.ipynb is a notebook to reflect our training/inference setup for DnCNN experiments
26 |
27 | 
28 |
--------------------------------------------------------------------------------
/PytorchExamples/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from block import LIST
4 |
5 |
class DnCNN(nn.Module):
    """Full-scale DnCNN denoising baseline built from plain 3x3 convolutions.

    The network is a sequential stack: one Conv+ReLU block,
    ``num_of_layers - 2`` Conv+BatchNorm+ReLU blocks with 64 feature maps,
    and a final convolution back to ``channels``. All convolutions are 3x3
    with padding 1 and no bias.

    Args:
        channels: number of channels of the input (and output) image.
        num_of_layers: total number of convolution layers in the stack.
    """

    def __init__(self, channels=1, num_of_layers=17):
        super().__init__()
        features = 64
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        stack = [
            nn.Conv2d(in_channels=channels, out_channels=features, **conv_kwargs),
            nn.ReLU(inplace=True),
        ]
        for _ in range(num_of_layers - 2):
            stack.extend((
                nn.Conv2d(in_channels=features, out_channels=features, **conv_kwargs),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
            ))
        stack.append(
            nn.Conv2d(in_channels=features, out_channels=channels, **conv_kwargs))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.dncnn(x)
27 |
28 |
29 |
class DnCNN_cheap(nn.Module):
    """Cheaper DnCNN variant whose intermediate convolutions are LIST layers.

    The stack has the same layout as the full-scale baseline: a 3x3 Conv+ReLU
    block, ``num_of_layers - 2`` LIST+BatchNorm+ReLU blocks with 64 feature
    maps, and a final 3x3 convolution back to ``channels``. The first and last
    convolutions use padding 1 and no bias.

    Args:
        channels: number of channels of the input (and output) image.
        num_of_layers: total number of layers in the stack, counting the
            first and last convolutions.
    """

    def __init__(self, channels=1, num_of_layers=17):
        super().__init__()
        features = 64
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        stack = [
            nn.Conv2d(channels, features, **conv_kwargs),
            nn.ReLU(inplace=True),
        ]
        for _ in range(num_of_layers - 2):
            stack.extend((
                LIST(features, features),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
            ))
        stack.append(nn.Conv2d(features, channels, **conv_kwargs))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.dncnn(x)
50 |
--------------------------------------------------------------------------------
/TensorflowExamples/basicModules.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from layer_CSVT import (
    batch_normalize,
    bottleneck,
    branchout,
    channel_shuffle,
    group_conv_dilated,
    group_conv_normal,
)

# In this file we are releasing the bare-bone skeleton of our modules.
# Users can take these code snippets and plug them into their own
# realizations of LIST, GSAT and up/down-sampling layers.
#
# NOTE: these are snippets, not a runnable script. Each one expects the
# caller to have already defined `x` (the input tensor), `is_training`, and
# the size parameters it mentions (`in_channel`, `out_channel`, `num_groups`,
# `dilation_factor`, `stride`).

#============================================================
#*** LIST Module ***
# in_channel : number of channels taken as input to the LIST layer
# out_channel: number of channels that will be output from the LIST layer
with tf.variable_scope('conv1_1'):
    # 1x1 squeeze down to a quarter of the input channels.
    x = bottleneck(x, [1, 1, in_channel, in_channel//4])
    x = batch_normalize(x, is_training)
    x = tf.nn.relu(x)
with tf.variable_scope('conv1_2'):
    # Two parallel branches, each producing out_channel//2 channels.
    x = branchout(x, [1, 1, in_channel//4, out_channel//2], [3, 3, in_channel//4, out_channel//2])
    x = batch_normalize(x, is_training)
    x = tf.nn.relu(x)
#===========================================================


#===========================================================

#*** GSAT Module ***
# in_channel : number of channels taken as input to the GSAT layer
# num_groups : number of groups for group convolution (see GSAT section in paper)
with tf.variable_scope('gsat'):
    skip = x
    x = group_conv_dilated(x, [3, 3, in_channel, in_channel], num_groups, dilation_factor)
    x = channel_shuffle(x, num_groups)
    x = group_conv_normal(x, [1, 1, in_channel, in_channel], num_groups)
    x = batch_normalize(x, is_training)
    # Residual connection around the grouped convolutions.
    x = skip + x
    x = tf.nn.relu(x)
#===========================================================


#===========================================================
#*** UPSAMPLING Module ***
#/////////////////////////////
# First deterministic upsampling, then follow with a LIST layer.
with tf.variable_scope('deconv1_1'):
    x = tf.image.resize_bilinear(x, [x.shape[1]*stride, x.shape[2]*stride])
    x = bottleneck(x, [1, 1, in_channel, in_channel//4])
    x = batch_normalize(x, is_training)
    x = tf.nn.relu(x)
with tf.variable_scope('deconv1_2'):
    x = branchout(x, [1, 1, in_channel//4, out_channel//2], [3, 3, in_channel//4, out_channel//2])
    x = batch_normalize(x, is_training)
    x = tf.nn.relu(x)
#===========================================================


#===========================================================
#*** DOWNSAMPLING Module ***
#////////////////////////////////
# First deterministic downsampling, then follow with a LIST layer.
# NOTE(review): this snippet uses the same scope names ('conv1_1', 'conv1_2')
# as the LIST snippet above; presumably they need distinct names when both
# are placed in the same graph -- confirm against your variable-reuse setup.
with tf.variable_scope('conv1_1'):
    x = tf.image.resize_bilinear(x, [x.shape[1]//stride, x.shape[2]//stride])
    x = bottleneck(x, [1, 1, in_channel, in_channel//4])
    x = batch_normalize(x, is_training)
    x = tf.nn.relu(x)
with tf.variable_scope('conv1_2'):
    x = branchout(x, [1, 1, in_channel//4, out_channel//2], [3, 3, in_channel//4, out_channel//2])
    x = batch_normalize(x, is_training)
    x = tf.nn.relu(x)
#===========================================================
68 |
--------------------------------------------------------------------------------
/PytorchExamples/Denoising_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "Python 3",
7 | "language": "python",
8 | "name": "python3"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 3
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython3",
20 | "version": "3.7.3"
21 | },
22 | "colab": {
23 | "name": "Denoising_demo.ipynb",
24 | "provenance": [],
25 | "collapsed_sections": []
26 | }
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "Sv9RI8SsxYeN",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | "A sample notebook for realizing the proposed compressed version of the DnCNN framework\n",
37 | "for image denoising."
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "fZk2K1bsxTW4",
44 | "colab_type": "code",
45 | "colab": {}
46 | },
47 | "source": [
48 | "import torch\n",
49 | "import torch.nn as nn\n",
50 | "import numpy as np\n",
51 | "import random\n",
52 | "import h5py\n",
53 | "\n",
54 | "from torch.utils.data import Dataset, DataLoader\n",
55 | "import torch.optim as optim\n",
56 | "from torch.autograd import Variable\n",
57 | "from skimage.measure.simple_metrics import compare_psnr\n",
58 | "import torch.nn.functional as F\n",
59 | "\n",
60 | "torch.backends.cudnn.benchmark = True\n",
61 | "\n",
62 | "from network import DnCNN, DnCNN_cheap"
63 | ],
64 | "execution_count": null,
65 | "outputs": []
66 | },
67 | {
68 | "cell_type": "code",
69 | "metadata": {
70 | "id": "rg7ouwIpxTW9",
71 | "colab_type": "code",
72 | "colab": {}
73 | },
74 | "source": [
75 | "def batch_PSNR(img, imclean, data_range):\n",
76 | " Img = img.data.cpu().numpy().astype(np.float32)\n",
77 | " Iclean = imclean.data.cpu().numpy().astype(np.float32)\n",
78 | " PSNR = 0\n",
79 | " for i in range(Img.shape[0]):\n",
80 | " PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)\n",
81 | " return (PSNR/Img.shape[0])"
82 | ],
83 | "execution_count": null,
84 | "outputs": []
85 | },
86 | {
87 | "cell_type": "code",
88 | "metadata": {
89 | "id": "BeqxDL1IxTXA",
90 | "colab_type": "code",
91 | "colab": {}
92 | },
93 | "source": [
94 | "class Dataset(Dataset):\n",
95 | " def __init__(self, train=True):\n",
96 | " super(Dataset, self).__init__()\n",
97 | " self.train = train\n",
98 | " # Store images in .h5 file format\n",
99 | " if self.train:\n",
100 | " h5f = h5py.File('train.h5', 'r')\n",
101 | " else:\n",
102 | " h5f = h5py.File('val.h5', 'r')\n",
103 | " self.keys = list(h5f.keys())\n",
104 | " random.shuffle(self.keys)\n",
105 | " h5f.close()\n",
106 | " def __len__(self):\n",
107 | " return len(self.keys)\n",
108 | " def __getitem__(self, index):\n",
109 | " if self.train:\n",
110 | " h5f = h5py.File('train.h5', 'r')\n",
111 | " else:\n",
112 | " h5f = h5py.File('val.h5', 'r')\n",
113 | " key = self.keys[index]\n",
114 | " data = np.array(h5f[key])\n",
115 | " h5f.close()\n",
116 | " return torch.Tensor(data)\n"
117 | ],
118 | "execution_count": null,
119 | "outputs": []
120 | },
121 | {
122 | "cell_type": "code",
123 | "metadata": {
124 | "id": "TDWTN7qYxTXE",
125 | "colab_type": "code",
126 | "colab": {}
127 | },
128 | "source": [
129 | "noiseL=25\n",
130 | "learning_rate=0.001\n",
131 | "batchSize=64\n",
132 | "print('Loading dataset ...\\n')\n",
133 | "dataset_train = Dataset(train=True)\n",
134 | "dataset_val = Dataset(train=False)\n",
135 | "loader_train = DataLoader(dataset=dataset_train, num_workers=7, \n",
136 | " batch_size=batchSize, shuffle=True)\n",
137 | "model = DnCNN_cheap() # DnCNN()\n",
138 | "model=model.cuda()\n",
139 | "criterion = nn.MSELoss(size_average=False)\n",
140 | "optimizer = optim.Adam(model.parameters(), lr=learning_rate) \n",
141 | "best_val=0 \n",
142 | "\n",
143 | "for epoch in range(50):\n",
144 | " test_loss = 0\n",
145 | " epoch_loss = 0\n",
146 | " for i, data in enumerate(loader_train, 0):\n",
147 | " model.train()\n",
148 | " model.zero_grad()\n",
149 | " optimizer.zero_grad()\n",
150 | " img_train = data\n",
151 | " noise = torch.FloatTensor(img_train.size()).normal_(mean=0, std=noiseL/255.)\n",
152 | " imgn_train = img_train + noise\n",
153 | " img_train, imgn_train = Variable(img_train.cuda()), Variable(imgn_train.cuda())\n",
154 | " noise = Variable(noise.cuda())\n",
155 | " out_train = model(imgn_train)\n",
156 | " loss = criterion(out_train, noise) / (img_train.size()[0]*2)\n",
157 | " psnr_train = batch_PSNR(imgn_train-out_train, img_train, 1.)\n",
158 | " loss.backward()\n",
159 | " optimizer.step()\n",
160 | " epoch_loss = epoch_loss+loss.item()\n",
161 | " # results\n",
162 | " if i%30 == 0:\n",
163 | " print(\"[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f\" %\n",
164 | " (epoch+1, i+1, len(loader_train), loss.item(), psnr_train))\n",
165 | " model.eval()\n",
166 | " epoch_loss=epoch_loss/len(loader_train)\n",
167 | " psnr_val = 0\n",
168 | " for k in range(len(dataset_val)):\n",
169 | " img_val = torch.unsqueeze(dataset_val[k], 0)\n",
170 | " noise = torch.FloatTensor(img_val.size()).normal_(mean=0, std=noiseL/255.)\n",
171 | " imgn_val = img_val + noise\n",
172 | " img_val, imgn_val = Variable(img_val.cuda()), Variable(imgn_val.cuda())\n",
173 | " with torch.no_grad():\n",
174 | " out_val = torch.clamp(imgn_val-model(imgn_val), 0., 1.)\n",
175 | " psnr_val += batch_PSNR(out_val, img_val, 1.)\n",
176 | " test_loss += criterion(out_val, img_val) / (imgn_train.size()[0]*2)\n",
177 | " psnr_val /= len(dataset_val)\n",
178 | " test_loss /= len(dataset_val)\n",
179 | " print(\"[epoch %d] Train Loss: %.4f Val Loss: %.4f PSNR_val: %.4f\" %\n",
180 | " (epoch+1, epoch_loss,test_loss.item(), psnr_val))\n",
181 | " \n",
182 | " if epoch%10==0:\n",
183 | " learning_rate=learning_rate*0.5\n",
184 | " for param_group in optimizer.param_groups:\n",
185 | " param_group['lr'] = learning_rate\n",
186 | " \n",
187 | " if psnr_val>=best_val:\n",
188 | " torch.save(model.state_dict(),'model.pth')\n",
189 | " best_val=psnr_val "
190 | ],
191 | "execution_count": null,
192 | "outputs": []
193 | }
194 | ]
195 | }
--------------------------------------------------------------------------------
/TensorflowExamples/layer_CSVT.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
def bottleneck(x, filter_shape):
    """Convolve ``x`` with a newly created, Xavier-initialized kernel.

    The kernel is obtained through ``tf.get_variable`` under the name
    ``'weight'`` in the current variable scope; the convolution uses unit
    strides and SAME padding.

    Args:
        x: input tensor handed to ``tf.nn.conv2d``.
        filter_shape: shape of the kernel variable
            (presumably [height, width, in_channels, out_channels], as
            expected by ``tf.nn.conv2d`` -- confirm against callers).

    Returns:
        The result of the convolution.
    """
    xavier = tf.contrib.layers.xavier_initializer()
    kernel = tf.get_variable(
        'weight',
        shape=filter_shape,
        dtype=tf.float32,
        initializer=xavier,
        trainable=True)
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, kernel, unit_strides, padding='SAME')
11 |
def branchout(x, filter_shape1, filter_shape2):
    """Run ``x`` through two parallel convolutions and concatenate the results.

    Branch one is an ordinary convolution with a kernel of shape
    ``filter_shape1``. Branch two is a separable convolution whose depthwise
    kernel is always 3x3 with channel multiplier 1 and whose pointwise kernel
    maps ``filter_shape2[2]`` to ``filter_shape2[3]`` channels; only entries
    2 and 3 of ``filter_shape2`` are read. Both branches use unit strides and
    SAME padding.

    Args:
        x: input tensor.
        filter_shape1: kernel shape of the ordinary convolution.
        filter_shape2: shape describing the separable branch (see above).

    Returns:
        The two branch outputs concatenated along axis 3, reshaped to the
        leading three static dimensions of ``x`` and
        ``2 * filter_shape1[3]`` channels.
    """
    def _xavier_variable(var_name, var_shape):
        # A fresh initializer is built for every variable.
        return tf.get_variable(
            name=var_name,
            shape=var_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)

    separable_in = filter_shape2[2]
    separable_out = filter_shape2[3]

    plain_kernel = _xavier_variable('weight1', filter_shape1)
    depthwise_kernel = _xavier_variable('weight2_1', [3, 3, separable_in, 1])
    pointwise_kernel = _xavier_variable('weight2_2', [1, 1, separable_in, separable_out])

    unit_strides = [1, 1, 1, 1]
    plain_out = tf.nn.conv2d(x, plain_kernel, unit_strides, padding='SAME')
    separable_out_tensor = tf.nn.separable_conv2d(
        x, depthwise_kernel, pointwise_kernel, unit_strides, padding='SAME')

    merged = tf.concat([plain_out, separable_out_tensor], 3)
    # The target shape takes its channel count from filter_shape1 only, so it
    # matches the concatenation when both branches emit the same width.
    target_shape = (x.shape[0], x.shape[1], x.shape[2], 2*filter_shape1[3])
    return tf.reshape(merged, target_shape)
37 |
def group_conv_normal(x, filter_shape, groups, stride=1):
    """Grouped convolution: split channels into ``groups`` slices, convolve each.

    The last axis of ``x`` is cut into ``groups`` consecutive slices of
    ``filter_shape[2] // groups`` channels. Every slice is convolved with its
    own Xavier-initialized kernel (SAME padding) and the results are
    concatenated back along the last axis.

    The number of groups is now taken from ``groups`` instead of being
    hard-coded to 8. For ``groups == 8`` the variables keep their previous
    names (``'weight1_'`` ... ``'weight8_'``) and creation order.

    Args:
        x: input tensor, indexed as a 4-D tensor with channels last.
        filter_shape: full kernel shape; entries 0 and 1 are used as-is,
            entries 2 and 3 are divided by ``groups`` for each group kernel.
        groups: number of channel groups (positive integer).
        stride: stride applied on axes 1 and 2 of each convolution.

    Returns:
        The concatenated group outputs, reshaped to the static shape of ``x``.

    Raises:
        ValueError: if ``groups`` is smaller than 1.
    """
    if groups < 1:
        raise ValueError('groups must be a positive integer, got %r' % (groups,))

    in_per_group = filter_shape[2] // groups
    out_per_group = filter_shape[3] // groups
    group_kernel_shape = [filter_shape[0], filter_shape[1], in_per_group, out_per_group]

    # Variables, slices and convolutions are built in three separate passes
    # so that graph ops are created in the same order as before.
    kernels = [
        tf.get_variable(
            name='weight%d_' % (index + 1),
            shape=group_kernel_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        for index in range(groups)
    ]

    channel_slices = [
        x[:, :, :, index*in_per_group:(index + 1)*in_per_group]
        for index in range(groups)
    ]

    strides = [1, stride, stride, 1]
    group_outputs = [
        tf.nn.conv2d(channel_slice, kernel, strides, padding='SAME')
        for channel_slice, kernel in zip(channel_slices, kernels)
    ]

    z = tf.concat(group_outputs, 3)
    # Kept from the original: the output is reshaped to the shape of the
    # input, so the element counts of input and output must agree.
    # NOTE(review): presumably this also needs a fully defined static shape
    # for x -- confirm before using an unknown batch size.
    z = tf.reshape(z, (x.shape[0], x.shape[1], x.shape[2], x.shape[3]))

    return z
119 |
120 |
def group_conv_dilated(x, filter_shape, groups, dilation):
    """Grouped dilated convolution over ``groups`` channel slices.

    The last axis of ``x`` is cut into ``groups`` consecutive slices of
    ``filter_shape[2] // groups`` channels. Every slice goes through
    ``tf.nn.atrous_conv2d`` with its own Xavier-initialized kernel (SAME
    padding) and the results are concatenated back along the last axis.

    The number of groups is now taken from ``groups`` instead of being
    hard-coded to 8. For ``groups == 8`` the variables keep their previous
    names (``'weight1'`` ... ``'weight8'``) and creation order.

    Args:
        x: input tensor, indexed as a 4-D tensor with channels last.
        filter_shape: full kernel shape; entries 0 and 1 are used as-is,
            entries 2 and 3 are divided by ``groups`` for each group kernel.
        groups: number of channel groups (positive integer).
        dilation: rate passed to ``tf.nn.atrous_conv2d``.

    Returns:
        The concatenated group outputs, reshaped to the static shape of ``x``.

    Raises:
        ValueError: if ``groups`` is smaller than 1.
    """
    if groups < 1:
        raise ValueError('groups must be a positive integer, got %r' % (groups,))

    in_per_group = filter_shape[2] // groups
    out_per_group = filter_shape[3] // groups
    group_kernel_shape = [filter_shape[0], filter_shape[1], in_per_group, out_per_group]

    # Variables, slices and convolutions are built in three separate passes
    # so that graph ops are created in the same order as before.
    kernels = [
        tf.get_variable(
            name='weight%d' % (index + 1),
            shape=group_kernel_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        for index in range(groups)
    ]

    channel_slices = [
        x[:, :, :, index*in_per_group:(index + 1)*in_per_group]
        for index in range(groups)
    ]

    group_outputs = [
        tf.nn.atrous_conv2d(channel_slice, kernel, dilation, padding='SAME')
        for channel_slice, kernel in zip(channel_slices, kernels)
    ]

    z = tf.concat(group_outputs, 3)
    # Kept from the original: the output is reshaped to the shape of the
    # input, so the element counts of input and output must agree.
    # NOTE(review): presumably this also needs a fully defined static shape
    # for x -- confirm before using an unknown batch size.
    z = tf.reshape(z, (x.shape[0], x.shape[1], x.shape[2], x.shape[3]))

    return z
201 |
202 |
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` groups.

    The last axis is viewed as ``(groups, channels // groups)``, those two
    axes are swapped, and the result is flattened back to the shape of ``x``.

    Args:
        x: 4-D tensor whose last axis holds the channels.
        groups: number of groups the channel axis is divided into.

    Returns:
        A tensor with the shape of ``x`` and its channels reordered.
    """
    batch, height, width, channels = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
    grouped = tf.reshape(x, (batch, height, width, groups, channels//groups))
    # Swap the "group" axis with the "channel within group" axis.
    swapped = tf.transpose(grouped, perm=[0, 1, 2, 4, 3])
    return tf.reshape(swapped, (batch, height, width, channels))
209 |
210 |
def batch_normalize(x, is_training, decay=0.99, epsilon=0.001):
    """Batch-normalize ``x`` over axes 0-2 with tracked population statistics.

    Four variables sized by the last dimension of ``x`` are created in the
    current variable scope: trainable ``beta`` and ``scale``, and
    non-trainable ``pop_mean`` and ``pop_var``. ``tf.cond`` on
    ``is_training`` selects between two branches:

    * training: normalize with the batch moments and update the population
      statistics as ``pop * decay + batch * (1 - decay)``;
    * inference: normalize with the stored population statistics.

    Args:
        x: input tensor; moments are taken over axes [0, 1, 2].
        is_training: predicate handed to ``tf.cond``.
        decay: weight of the old population statistic in each update.
        epsilon: variance epsilon passed to ``tf.nn.batch_normalization``.

    Returns:
        The normalized tensor produced by the selected branch.
    """
    channels = x.get_shape().as_list()[-1]

    # stddev=0.0 makes the offset start at exactly zero.
    beta = tf.get_variable(
        'beta',
        shape=[channels],
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(stddev=0.0),
        trainable=True)
    scale = tf.get_variable(
        'scale',
        shape=[channels],
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(stddev=0.1),
        trainable=True)
    pop_mean = tf.get_variable(
        'pop_mean',
        shape=[channels],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=False)
    pop_var = tf.get_variable(
        'pop_var',
        shape=[channels],
        dtype=tf.float32,
        initializer=tf.constant_initializer(1.0),
        trainable=False)

    def _with_batch_statistics():
        batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2])
        update_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        update_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
        # Tie the moving-average updates to the normalized output so they run
        # whenever this branch is evaluated.
        with tf.control_dependencies([update_mean, update_var]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon)

    def _with_population_statistics():
        return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon)

    return tf.cond(is_training, _with_batch_statistics, _with_population_statistics)
249 |
--------------------------------------------------------------------------------