├── LICENSE ├── README.md ├── adaconv.py └── assets ├── archi.png ├── teaser.png └── usage.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## AdaConv — Simple TensorFlow Implementation [[Paper]](https://openaccess.thecvf.com/content/CVPR2021/papers/Chandran_Adaptive_Convolutions_for_Structure-Aware_Style_Transfer_CVPR_2021_paper.pdf) 2 | ### : Adaptive Convolutions for Structure-Aware Style Transfer (CVPR 2021) 3 | 4 | ## Note 5 | This repository implements only the core modules of the paper (`KernelPredict` and `AdaConv` in `adaconv.py`), not the full training and style-transfer pipeline. 6 | 7 | <div align="center">
 8 |   <img src="./assets/teaser.png"> 9 |   <img src="./assets/archi.png"> 10 | </div>
11 | 12 | ## Requirements 13 | * `Tensorflow == 2.5.0` 14 | 15 | ## Usage 16 | ```python 17 | feats = tf.random.normal(shape=[5, 64, 64, 256]) 18 | style_w = tf.random.normal(shape=[5, 512]) 19 | 20 | kp = KernelPredict(in_channels=feats.shape[-1], group_div=1) 21 | adac = AdaConv(channels=1024, group_div=1) 22 | 23 | w_spatial, w_pointwise, bias = kp(style_w) 24 | x = adac([feats, w_spatial, w_pointwise, bias]) # [5, 64, 64, 1024] 25 | ``` 26 | 27 | ## Reference 28 | * https://github.com/RElbers/ada-conv-pytorch 29 | 30 | ## Author 31 | [Junho Kim](http://bit.ly/jhkim_ai) 32 | -------------------------------------------------------------------------------- /adaconv.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Sequential 3 | 4 | class KernelPredict(tf.keras.layers.Layer): 5 | def __init__(self, in_channels, kernel_size=3, conv=False, group_div=1, name='KernelPredict'): 6 | super(KernelPredict, self).__init__(name=name) 7 | self.channels = in_channels # content feature map channels 8 | self.kernel_size = kernel_size 9 | self.group_div = group_div 10 | self.conv = conv 11 | self.n_groups = self.channels // self.group_div 12 | 13 | if self.conv: 14 | self.w_spatial_layer = tf.keras.layers.Conv2D(filters=self.channels * self.channels // self.n_groups, kernel_size=self.kernel_size, 15 | strides=1, use_bias=True, padding='SAME', name='w_spatial_conv') 16 | self.w_point_layer = Sequential([tf.keras.layers.GlobalAvgPool2D(name='gap_point_conv'), 17 | tf.keras.layers.Dense(units=self.channels * self.channels // self.n_groups, 18 | use_bias=True, name='w_point_fc')]) 19 | self.bias = Sequential([tf.keras.layers.GlobalAvgPool2D(name='gap_point_bias'), 20 | tf.keras.layers.Dense(units=self.channels, 21 | use_bias=True, name='bias_fc')]) 22 | else: # fully-connected 23 | self.w_spatial_layer = tf.keras.layers.Dense(units=self.channels * self.channels // self.n_groups, 24 | use_bias=True, name='w_spatial_fc') 25 | self.w_point_layer = tf.keras.layers.Dense(units=self.channels * self.channels // self.n_groups, 26 | use_bias=True, name='w_point_fc') 27 | self.bias = tf.keras.layers.Dense(units=self.channels, 28 | use_bias=True, name='bias_fc') 29 | 30 | def call(self, style_w, training=None, mask=None): 31 | batch_size = style_w.shape[0] 32 | style_w_size = style_w.shape[1] 33 | 34 | w_spatial = self.w_spatial_layer(style_w) 35 | 36 | if self.conv: 37 | w_spatial = tf.reshape(w_spatial, shape=[batch_size, style_w_size, style_w_size, self.channels // self.n_groups, self.channels]) 38 | else: 39 | w_spatial = tf.reshape(w_spatial, shape=[batch_size, 1, 1, self.channels // self.n_groups, self.channels]) # in, out 40 | 41 | w_pointwise = self.w_point_layer(style_w) 42 | w_pointwise = tf.reshape(w_pointwise, shape=[batch_size, 1, 1, self.channels // self.n_groups, self.channels]) 43 | 44 | bias = self.bias(style_w) 45 | bias = tf.reshape(bias, shape=[batch_size, self.channels]) 46 | 47 | return w_spatial, w_pointwise, bias 48 | 49 | class AdaConv(tf.keras.layers.Layer): 50 | def __init__(self, channels, kernel_size=3, group_div=1, name='AdaConv'): 51 | super(AdaConv, self).__init__(name=name) 52 | 53 | self.channels = channels 54 | self.kernel_size = kernel_size 55 | self.group_div = group_div 56 | 57 | self.conv = tf.keras.layers.Conv2D(filters=self.channels, kernel_size=self.kernel_size, 58 | strides=1, use_bias=True, padding='SAME', name='conv1') 59 | 60 | def build(self, input_shape): 61 | self.n_groups = 
input_shape[0][-1] // self.group_div 62 | 63 | def call(self, inputs, training=None, mask=None): 64 | """ 65 | x = [batch, height, width, channels] 66 | w_spatial = [batch, ws_height, ws_width, in_channels, out_channels] 67 | w_pointwise = [batch, wp_height, wp_width, in_channels, out_channels] 68 | bias = [batch, out_channels] 69 | """ 70 | 71 | x, w_spatial, w_pointwise, bias = inputs 72 | batch_size = x.shape[0] 73 | xs = [] 74 | 75 | x = self._normalize(x) 76 | 77 | for i in range(batch_size): 78 | _x = self._apply_weights(x[i:i+1], w_spatial[i:i+1], w_pointwise[i:i+1], bias[i:i+1]) 79 | xs.append(_x) 80 | 81 | x = tf.concat(xs, axis=0) 82 | x = self.conv(x) 83 | 84 | return x 85 | 86 | def _normalize(self, x, eps=1e-5): 87 | mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True) 88 | std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True) 89 | x_norm = (x - mean) / (std + eps) 90 | 91 | return x_norm 92 | 93 | def _apply_weights(self, x, w_spatial, w_pointwise, bias): 94 | """ 95 | x = [1, height, width, channels] 96 | w_spatial = [1, ws_height, ws_width, in_channels, out_channels] 97 | w_pointwise = [1, wp_height, wp_width, in_channels, out_channels] 98 | bias = [1, out_channels] 99 | """ 100 | 101 | # spatial conv 102 | spatial_out_channels = w_spatial.shape[-1] 103 | spatial_kernel_size = w_spatial.shape[1] 104 | spatial_conv = tf.keras.layers.Conv2D(filters=spatial_out_channels, kernel_size=spatial_kernel_size, 105 | strides=1, use_bias=False, padding='SAME', groups=self.n_groups, name='spatial_conv') 106 | 107 | spatial_conv.build(x.shape) 108 | spatial_conv.set_weights(w_spatial) 109 | x = spatial_conv(x) 110 | 111 | # pointwise conv 112 | point_out_channels = w_pointwise.shape[-1] 113 | point_kernel_size = w_pointwise.shape[1] 114 | w_pointwise = tf.squeeze(w_pointwise, axis=0) 115 | bias = tf.squeeze(bias, axis=0) 116 | 117 | point_conv = tf.keras.layers.Conv2D(filters=point_out_channels, kernel_size=point_kernel_size, 118 | strides=1, use_bias=True, padding='VALID', groups=self.n_groups, name='point_conv') 119 | point_conv.build(x.shape) 120 | point_conv.set_weights([w_pointwise, bias]) 121 | x = point_conv(x) 122 | 123 | return x 124 | 125 | 126 | # test code 127 | feats = tf.random.normal(shape=[5, 64, 64, 256]) 128 | style_w = tf.random.normal(shape=[5, 512]) 129 | 130 | kp = KernelPredict(in_channels=feats.shape[-1], group_div=1) 131 | adac = AdaConv(channels=1024, group_div=1) 132 | 133 | w_spatial, w_pointwise, bias = kp(style_w) 134 | x = adac([feats, w_spatial, w_pointwise, bias]) 135 | print(x.shape) 136 | 137 | -------------------------------------------------------------------------------- /assets/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/AdaConv-Tensorflow/a33b912077fe374424bc61f6e78adf655158bfbe/assets/archi.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/AdaConv-Tensorflow/a33b912077fe374424bc61f6e78adf655158bfbe/assets/teaser.png -------------------------------------------------------------------------------- /assets/usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/AdaConv-Tensorflow/a33b912077fe374424bc61f6e78adf655158bfbe/assets/usage.png 
--------------------------------------------------------------------------------
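
For readers of this dump, the sketch below shows one way the two core modules above might be chained across several decoder scales, in the spirit of the multi-scale design described in the paper. It is a minimal illustration, not part of the repository: the content features, the global style vector, the per-stage channel widths, and the nearest-neighbour upsampling are all assumptions made for the example; only `KernelPredict` and `AdaConv` come from `adaconv.py`.

```python
import tensorflow as tf
from adaconv import KernelPredict, AdaConv  # core modules from this repository

# Assumed inputs: content features from some encoder and a global style vector
# from some style encoder. Shapes here are placeholders, not values from the paper.
content_feats = tf.random.normal(shape=[2, 32, 32, 512])   # [batch, H, W, C]
style_w = tf.random.normal(shape=[2, 512])                  # global style descriptor

x = content_feats
for out_channels in [512, 256, 128]:                        # assumed decoder widths, coarse to fine
    in_channels = x.shape[-1]

    # Each decoder scale gets its own kernel predictor and AdaConv layer.
    kp = KernelPredict(in_channels=in_channels, group_div=1, name=f'kp_{out_channels}')
    adac = AdaConv(channels=out_channels, group_div=1, name=f'adaconv_{out_channels}')

    # Predict per-image spatial kernels, pointwise kernels and biases from the
    # style vector, then apply them to the (instance-normalized) content features.
    w_spatial, w_pointwise, bias = kp(style_w)
    x = adac([x, w_spatial, w_pointwise, bias])

    # Move to the next, finer scale (plain nearest-neighbour upsampling as a stand-in).
    x = tf.keras.layers.UpSampling2D(size=2)(x)

print(x.shape)  # (2, 256, 256, 128) with the assumed settings above
```

In the paper, the style descriptor is produced by a dedicated style encoder and the stylized features are ultimately decoded back to an image; those parts are not included in this repository, which is why they are only stubbed with random tensors here.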