├── README.md ├── layers.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | # Simple Baselines For Image Restoration [[Paper]](https://arxiv.org/abs/2204.04676) 2 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Union 2 | import tensorflow as tf 3 | import tensorflow.keras.backend as K 4 | 5 | 6 | ''' 7 | class PlainBLock(tf.keras.layers.Layer): 8 | def __init__(self, 9 | n_filters: int, 10 | dw_expansion: int = 2, 11 | ffn_expansion: int = 2 12 | ): 13 | super(PlainBLock, self).__init__() 14 | 15 | self.n_filters = n_filters 16 | self.dw_filters = n_filters * dw_expansion 17 | self.ffn_filters = n_filters * ffn_expansion 18 | 19 | self.spatial = tf.keras.Sequential([ 20 | tf.keras.layers.Conv2D(self.dw_filters, 21 | activation=None, 22 | kernel_size=1, 23 | strides=1, 24 | padding='VALID' 25 | ), 26 | tf.keras.layers.DepthwiseConv2D(kernel_size=3, 27 | strides=1, 28 | padding='SAME', 29 | activation=None 30 | ), 31 | tf.keras.layers.ReLU(), 32 | tf.keras.layers.Conv2D(self.n_filters, 33 | kernel_size=1, 34 | activation=None, 35 | strides=1, 36 | padding='VALID' 37 | ) 38 | ]) 39 | self.channel = tf.keras.Sequential([ 40 | tf.keras.layers.Dense(self.ffn_filters, 41 | activation=None 42 | ), 43 | tf.keras.layers.ReLU(), 44 | tf.keras.layers.Dense(self.n_filters, 45 | activation=None 46 | ) 47 | ]) 48 | 49 | def call(self, inputs, *args, **kwargs): 50 | inputs = self.spatial(inputs) + inputs 51 | inputs = self.channel(inputs) + inputs 52 | return inputs 53 | 54 | 55 | class CAModule(tf.keras.layers.Layer): 56 | def __init__(self, 57 | n_filters: int, 58 | reduction_rate: int = 4 59 | ): 60 | super(CAModule, self).__init__() 61 | 62 | self.n_filters = n_filters 63 | self.reduction_filters = int(self.n_filters // self.reduction_rate) 64 | 65 | self.pool = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2], keepdims=True)) 66 | self.forward = tf.keras.Sequential([ 67 | tf.keras.layers.Dense(self.reduction_filters, 68 | activation='relu' 69 | ), 70 | tf.keras.layers.Dense(self.n_filters, 71 | activation=None 72 | ) 73 | ]) 74 | 75 | def call(self, inputs, *args, **kwargs): 76 | attention = self.pool(inputs) 77 | attention = self.forward(attention) 78 | inputs = inputs * attention 79 | 80 | 81 | class BaselineBlock(tf.keras.layers.Layer): 82 | def __init__(self, 83 | n_filters: int, 84 | dw_expansion: int = 2, 85 | ffn_expansion: int = 2 86 | ): 87 | super(BaselineBlock, self).__init__() 88 | 89 | self.n_filters = n_filters 90 | self.dw_filters = n_filters * dw_expansion 91 | self.ffn_filters = n_filters * ffn_expansion 92 | 93 | self.spatial = tf.keras.Sequential([ 94 | tf.keras.layers.LayerNormalization(), 95 | tf.keras.layers.Conv2D(self.dw_filters, 96 | kernel_size=1, 97 | strides=1, 98 | activation=None, 99 | padding='VALID' 100 | ), 101 | tf.keras.layers.DepthwiseConv2D(kernel_size=3, 102 | strides=1, 103 | activation=None, 104 | padding='SAME' 105 | ), 106 | tf.keras.layers.Activation('gelu'), 107 | tf.keras.layers.Conv2D(self.n_filters, 108 | kernel_size=1, 109 | strides=1, 110 | activation=None, 111 | padding='VALID' 112 | ) 113 | ]) 114 | self.channel = tf.keras.Sequential([ 115 | tf.keras.layers.LayerNormalization(), 116 | tf.keras.layers.Dense(self.ffn_filters, 117 | activation=None 118 | ), 119 | tf.keras.layers.Activation('gelu'), 120 | tf.keras.layers.Dense(self.n_filters, 121 | activation=None 122 | ) 123 | ]) 124 | 125 | def call(self, inputs, *args, **kwargs): 126 | inputs = self.spatial(inputs) + inputs 127 | inputs = self.channel(inputs) + inputs 128 | return inputs 129 | ''' 130 | 131 | 132 | def edge_padding2d(x, h_pad, w_pad): 133 | if h_pad[0] != 0: 134 | x_up = tf.gather(x, indices=[0], axis=1) 135 | x_up = tf.concat([x_up for _ in range(h_pad[0])], axis=1) 136 | x = tf.concat([x_up, x], axis=1) 137 | if h_pad[1] != 0: 138 | x_down = tf.gather(tf.reverse(x, axis=[1]), indices=[0], axis=1) 139 | x_down = tf.concat([x_down for _ in range(h_pad[1])], axis=1) 140 | x = tf.concat([x, x_down], axis=1) 141 | if w_pad[0] != 0: 142 | x_left = tf.gather(x, indices=[0], axis=2) 143 | x_left = tf.concat([x_left for _ in range(w_pad[0])], axis=2) 144 | x = tf.concat([x_left, x], axis=2) 145 | if w_pad[1] != 0: 146 | x_right= tf.gather(tf.reverse(x, axis=[2]), indices=[0], axis=2) 147 | x_right = tf.concat([x_right for _ in range(w_pad[1])], axis=2) 148 | x = tf.concat([x, x_right], axis=2) 149 | return x 150 | 151 | 152 | class LocalAvgPool2D(tf.keras.layers.Layer): 153 | def __init__( 154 | self, local_size: Sequence[int] 155 | ): 156 | super(LocalAvgPool2D, self).__init__() 157 | self.local_size = local_size 158 | 159 | def call(self, inputs, training): 160 | if training: 161 | return tf.reduce_mean(inputs, axis=[1, 2], keepdims=True) 162 | 163 | _, h, w, _ = inputs.get_shape().as_list() 164 | kh = min(h, self.local_size[0]) 165 | kw = min(w, self.local_size[1]) 166 | inputs = tf.pad(inputs, 167 | [[0, 0], 168 | [1, 0], 169 | [1, 0], 170 | [0, 0]] 171 | ) 172 | inputs = tf.cumsum(tf.cumsum(inputs, axis=2), axis=1) 173 | s1 = tf.slice(inputs, 174 | [0, 0, 0, 0], 175 | [-1, kh, kw, -1] 176 | ) 177 | s2 = tf.slice(inputs, 178 | [0, 0, (w - kw)+1, 0], 179 | [-1, kw, -1, -1] 180 | ) 181 | s3 = tf.slice(inputs, 182 | [0, (h - kh)+1, 0, 0], 183 | [-1, -1, kw, -1] 184 | ) 185 | s4 = tf.slice(inputs, 186 | [0, (h - kh)+1, (w - kw)+1, 0], 187 | [-1, -1, -1, -1] 188 | ) 189 | local_ap = (s4 + s1 - s2 - s3) / (kh * kw) 190 | 191 | _, h_, w_, _ = local_ap.get_shape().as_list() 192 | h_pad, w_pad = [(h - h_) // 2, (h - h_ + 1) // 2], [(w - w_) // 2, (w - w_ + 1) // 2] 193 | local_ap = edge_padding2d(local_ap, h_pad, w_pad) 194 | return local_ap 195 | 196 | 197 | class PixelShuffle(tf.keras.layers.Layer): 198 | def __init__(self, upsample_rate): 199 | super(PixelShuffle, self).__init__() 200 | self.upsample_rate = upsample_rate 201 | 202 | def call(self, inputs, *args, **kwargs): 203 | return tf.nn.depth_to_space( 204 | inputs, block_size=self.upsample_rate 205 | ) 206 | 207 | 208 | class SimpleGate(tf.keras.layers.Layer): 209 | def __init__(self): 210 | super(SimpleGate, self).__init__() 211 | 212 | def call(self, inputs, *args, **kwargs): 213 | x1, x2 = tf.split( 214 | inputs, num_or_size_splits=2, axis=-1 215 | ) 216 | return x1 * x2 217 | 218 | 219 | class SimpleChannelAttention(tf.keras.layers.Layer): 220 | def __init__( 221 | self, n_filters: int, kh: int, kw: int 222 | ): 223 | super(SimpleChannelAttention, self).__init__() 224 | self.n_filters = n_filters 225 | self.kh = kh 226 | self.kw = kw 227 | 228 | self.pool = LocalAvgPool2D((kh, kw)) 229 | self.w = tf.keras.layers.Dense( 230 | self.n_filters, activation=None 231 | ) 232 | 233 | def call(self, inputs, *args, **kwargs): 234 | attention = self.pool(inputs) 235 | attention = self.w(attention) 236 | return attention * inputs 237 | 238 | 239 | class NAFBlock(tf.keras.layers.Layer): 240 | def __init__( 241 | self, n_filters: int, dropout_rate: float, kh: int, 242 | kw: int, dw_expansion: int = 2, ffn_expansion: int = 2 243 | ): 244 | super(NAFBlock, self).__init__() 245 | self.n_filters = n_filters 246 | self.dropout_rate = dropout_rate 247 | self.kh = kh 248 | self.kw = kw 249 | self.dw_filters = n_filters * dw_expansion 250 | self.ffn_filters = n_filters * ffn_expansion 251 | 252 | self.spatial = tf.keras.Sequential([ 253 | tf.keras.layers.LayerNormalization(), 254 | tf.keras.layers.Conv2D( 255 | self.dw_filters, kernel_size=1, strides=1, padding='VALID', 256 | activation=None 257 | ), 258 | tf.keras.layers.DepthwiseConv2D( 259 | kernel_size=3, strides=1, padding='SAME', activation=None 260 | ), 261 | SimpleGate(), 262 | SimpleChannelAttention( 263 | self.n_filters, self.kh, self.kw 264 | ), 265 | tf.keras.layers.Conv2D( 266 | self.n_filters, kernel_size=1, strides=1, padding='VALID', 267 | activation=None 268 | ) 269 | ]) 270 | self.spatial_drop = tf.keras.layers.Dropout(self.dropout_rate) 271 | 272 | self.channel = tf.keras.Sequential([ 273 | tf.keras.layers.LayerNormalization(), 274 | tf.keras.layers.Dense(self.ffn_filters, 275 | activation=None 276 | ), 277 | SimpleGate(), 278 | tf.keras.layers.Dense(self.n_filters, 279 | activation=None 280 | ) 281 | ]) 282 | self.channel_drop = tf.keras.layers.Dropout(self.dropout_rate) 283 | 284 | self.beta = tf.Variable( 285 | tf.zeros((1, 1, 1, self.n_filters)), 286 | trainable=True, 287 | dtype=tf.float32 288 | ) 289 | self.gamma = tf.Variable( 290 | tf.zeros((1, 1, 1, self.n_filters)), 291 | trainable=True, 292 | dtype=tf.float32 293 | ) 294 | 295 | def call(self, inputs, *args, **kwargs): 296 | inputs = self.spatial_drop(self.spatial(inputs)) * self.beta + inputs 297 | inputs = self.channel_drop(self.channel(inputs)) * self.gamma + inputs 298 | return inputs 299 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from typing import Sequence, Optional 3 | from layers import * 4 | 5 | 6 | class NAFNet(tf.keras.models.Model): 7 | def __init__( 8 | self, width: int = 16, n_middle_blocks: int = 1, n_enc_blocks: Sequence[int] = (1, 1, 1, 28), 9 | n_dec_blocks: Sequence[int] = (1, 1, 1, 1), dropout_rate: float = 0., 10 | train_size: Sequence[Optional[int]] = (None, 256, 256, 3), tlsc_rate: float = 1.5 11 | ): 12 | super(NAFNet, self).__init__() 13 | self.width = width 14 | self.n_middle_blocks = n_middle_blocks 15 | self.n_enc_blocks = n_enc_blocks 16 | self.n_dec_blocks = n_dec_blocks 17 | self.dropout_rate = dropout_rate 18 | self.train_size = train_size 19 | self.tlsc_rate = tlsc_rate 20 | n_stages = len(n_enc_blocks) 21 | kh, kw = int(train_size[0] * tlsc_rate), int(train_size[1] * tlsc_rate) 22 | 23 | self.to_features = tf.keras.layers.Conv2D( 24 | width, kernel_size=3, padding='SAME', activation=None, 25 | strides=1 26 | ) 27 | self.to_rgb = tf.keras.layers.Conv2D( 28 | 3, kernel_size=3, padding='SAME', activation=None, 29 | strides=1 30 | ) 31 | self.encoders = [] 32 | self.downs = [] 33 | for i, n in enumerate(n_enc_blocks): 34 | self.encoders.append( 35 | tf.keras.Sequential([ 36 | NAFBlock( 37 | width * (2 ** i), dropout_rate, kh // (2 ** i), kw // (2 ** i) 38 | ) for _ in range(n) 39 | ]) 40 | ) 41 | self.downs.append( 42 | tf.keras.layers.Conv2D( 43 | width * (2 ** (i + 1)), kernel_size=2, padding='valid', strides=2, 44 | activation=None 45 | ) 46 | ) 47 | self.middles = tf.keras.Sequential([ 48 | NAFBlock( 49 | width * (2 ** n_stages), dropout_rate, kh // (2 ** n_stages), kw // (2 ** n_stages) 50 | ) for _ in range(n_middle_blocks) 51 | ]) 52 | self.decoders = [] 53 | self.ups = [] 54 | for i, n in enumerate(n_dec_blocks): 55 | self.ups.append( 56 | tf.keras.Sequential([ 57 | tf.keras.layers.Conv2D( 58 | width * (2 ** (n_stages - i)) * 2, kernel_size=1, padding='VALID', activation=None, 59 | strides=1 60 | ), 61 | PixelShuffle(2) 62 | ]) 63 | ) 64 | self.decoders.append( 65 | tf.keras.Sequential([ 66 | NAFBlock( 67 | width * (2 ** (n_stages - (i + 1))), dropout_rate, 68 | kh // (2 ** (n_stages - (i + 1))), kw // (2 ** (n_stages - (i + 1))) 69 | ) for _ in range(n) 70 | ]) 71 | ) 72 | 73 | @tf.function 74 | def forward(self, x, training=False): 75 | features = self.to_features(x, training=training) 76 | 77 | encs = [] 78 | for encoder, down in zip(self.encoders, self.downs): 79 | features = encoder(features, training=training) 80 | encs.append(features) 81 | features = down(features, training=training) 82 | 83 | features = self.middles(features, training=training) 84 | 85 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 86 | features = up(features, training=training) 87 | features = features + enc_skip 88 | features = decoder(features, training=training) 89 | 90 | x_res = self.to_rgb(features, training=training) 91 | x = x + x_res 92 | return x 93 | 94 | def call(self, inputs, training=None, mask=None): 95 | if training is None: 96 | training = False 97 | return self.forward(inputs, training=training) 98 | --------------------------------------------------------------------------------