├── .DS_Store
├── LICENSE
├── README.md
├── adaconv.py
└── assets
├── archi.png
├── teaser.png
└── usage.png
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/AdaConv-Tensorflow/a33b912077fe374424bc61f6e78adf655158bfbe/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## AdaConv — Simple TensorFlow Implementation [[Paper]](https://openaccess.thecvf.com/content/CVPR2021/papers/Chandran_Adaptive_Convolutions_for_Structure-Aware_Style_Transfer_CVPR_2021_paper.pdf)
2 | ### : Adaptive Convolutions for Structure-Aware Style Transfer (CVPR 2021)
3 |
4 | ## Note
5 | This repository does not implement the full method; it only implements the core modules of the paper.
6 |
7 |
8 |

9 |

10 |
11 |
12 | ## Requirements
13 | * `Tensorflow == 2.5.0`
14 |
15 | ## Usage
16 | ```python
17 | feats = tf.random.normal(shape=[5, 64, 64, 256])
18 | style_w = tf.random.normal(shape=[5, 512])
19 |
20 | kp = KernelPredict(in_channels=feats.shape[-1], group_div=1)
21 | adac = AdaConv(channels=1024, group_div=1)
22 |
23 | w_spatial, w_pointwise, bias = kp(style_w)
24 | x = adac([feats, w_spatial, w_pointwise, bias]) # [5, 64, 64, 1024]
25 | ```
26 |
27 | ## Reference
28 | * https://github.com/RElbers/ada-conv-pytorch
29 |
30 | ## Author
31 | [Junho Kim](http://bit.ly/jhkim_ai)
32 |
--------------------------------------------------------------------------------
/adaconv.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import Sequential
3 |
class KernelPredict(tf.keras.layers.Layer):
    """Predict per-sample AdaConv kernels and biases from a style descriptor.

    Args:
        in_channels: channel count of the content feature map the predicted
            kernels will be applied to (used as both the input and output
            channel count of the predicted convolutions).
        kernel_size: kernel size of the ``Conv2D`` that predicts the spatial
            weights; only used when ``conv=True``.
        conv: if True, ``style_w`` is passed through ``Conv2D`` /
            ``GlobalAvgPool2D`` layers, so it must be a 4-D feature map
            ``[batch, height, width, style_channels]``; the predicted spatial
            kernel then has that map's ``height x width`` size. If False,
            ``style_w`` goes through ``Dense`` layers and every predicted
            kernel is ``1 x 1``.
        group_div: channels per group; ``n_groups = in_channels // group_div``.
        name: layer name.

    NOTE(review): when ``conv=False`` the ``kernel_size`` argument is ignored
    and the spatial kernel is always 1x1 -- confirm this is intended.
    """

    def __init__(self, in_channels, kernel_size=3, conv=False, group_div=1, name='KernelPredict'):
        super(KernelPredict, self).__init__(name=name)
        self.channels = in_channels  # content feature map channels
        self.kernel_size = kernel_size
        self.group_div = group_div
        self.conv = conv
        self.n_groups = self.channels // self.group_div

        # Each predicted grouped kernel has (channels // n_groups) input
        # channels and `channels` output channels.
        n_kernel_units = self.channels * self.channels // self.n_groups

        if self.conv:
            self.w_spatial_layer = tf.keras.layers.Conv2D(filters=n_kernel_units, kernel_size=self.kernel_size,
                                                          strides=1, use_bias=True, padding='SAME',
                                                          name='w_spatial_conv')
            self.w_point_layer = Sequential([tf.keras.layers.GlobalAvgPool2D(name='gap_point_conv'),
                                             tf.keras.layers.Dense(units=n_kernel_units,
                                                                   use_bias=True, name='w_point_fc')])
            self.bias = Sequential([tf.keras.layers.GlobalAvgPool2D(name='gap_point_bias'),
                                    tf.keras.layers.Dense(units=self.channels,
                                                          use_bias=True, name='bias_fc')])
        else:  # fully-connected
            self.w_spatial_layer = tf.keras.layers.Dense(units=n_kernel_units,
                                                         use_bias=True, name='w_spatial_fc')
            self.w_point_layer = tf.keras.layers.Dense(units=n_kernel_units,
                                                       use_bias=True, name='w_point_fc')
            self.bias = tf.keras.layers.Dense(units=self.channels,
                                              use_bias=True, name='bias_fc')

    def call(self, style_w, training=None, mask=None):
        """Predict kernels for every sample in the batch.

        Returns:
            w_spatial = [batch, k_h, k_w, in_channels // n_groups, in_channels]
                (k_h = k_w = 1 when conv=False)
            w_pointwise = [batch, 1, 1, in_channels // n_groups, in_channels]
            bias = [batch, in_channels]
        """
        in_per_group = self.channels // self.n_groups

        w_spatial = self.w_spatial_layer(style_w)

        if self.conv:
            # The stride-1 SAME conv keeps the style map's height and width,
            # which become the spatial kernel size. Read them separately so
            # non-square style maps are reshaped correctly.
            k_h, k_w = w_spatial.shape[1], w_spatial.shape[2]
        else:
            k_h, k_w = 1, 1

        # The batch axis is -1 so the reshapes also work when the static
        # batch size is unknown (None).
        w_spatial = tf.reshape(w_spatial, shape=[-1, k_h, k_w, in_per_group, self.channels])  # in, out

        w_pointwise = self.w_point_layer(style_w)
        w_pointwise = tf.reshape(w_pointwise, shape=[-1, 1, 1, in_per_group, self.channels])

        bias = self.bias(style_w)
        bias = tf.reshape(bias, shape=[-1, self.channels])

        return w_spatial, w_pointwise, bias
48 |
class AdaConv(tf.keras.layers.Layer):
    """Adaptive convolution driven by per-sample predicted kernels.

    For each sample, the layer:
    1. instance-normalizes the content features (per-channel spatial mean/std);
    2. applies that sample's own grouped spatial kernel;
    3. applies that sample's own grouped pointwise kernel plus bias.
    The result then passes through a regular, trainable ``Conv2D`` with
    ``channels`` output filters.

    Args:
        channels: number of output filters of the final shared convolution.
        kernel_size: kernel size of the final shared convolution.
        group_div: channels per group; ``n_groups = in_channels // group_div``.
        name: layer name.
    """

    def __init__(self, channels, kernel_size=3, group_div=1, name='AdaConv'):
        super(AdaConv, self).__init__(name=name)

        self.channels = channels
        self.kernel_size = kernel_size
        self.group_div = group_div

        self.conv = tf.keras.layers.Conv2D(filters=self.channels, kernel_size=self.kernel_size,
                                           strides=1, use_bias=True, padding='SAME', name='conv1')

    def build(self, input_shape):
        # input_shape is the list of shapes of [x, w_spatial, w_pointwise, bias].
        # The convolutions below infer the group count from the kernel shapes;
        # n_groups is kept as a public attribute for callers that read it.
        self.n_groups = input_shape[0][-1] // self.group_div

    def call(self, inputs, training=None, mask=None):
        """
        x = [batch, height, width, channels]
        w_spatial = [batch, ws_height, ws_width, in_channels // n_groups, out_channels]
        w_pointwise = [batch, wp_height, wp_width, in_channels // n_groups, out_channels]
        bias = [batch, out_channels]
        """

        x, w_spatial, w_pointwise, bias = inputs

        x = self._normalize(x)

        def _per_sample(elems):
            # map_fn strips the batch axis; _apply_weights expects a batch of one.
            _x, _ws, _wp, _b = elems
            out = self._apply_weights(_x[tf.newaxis], _ws[tf.newaxis], _wp[tf.newaxis], _b[tf.newaxis])
            return out[0]

        # map_fn (instead of a Python loop over x.shape[0]) also works when
        # the batch size is only known at run time, e.g. inside tf.function.
        x = tf.map_fn(_per_sample, (x, w_spatial, w_pointwise, bias), fn_output_signature=x.dtype)
        # Make the channel dimension statically known for the Conv2D below.
        x.set_shape([None, None, None, w_pointwise.shape[-1]])

        x = self.conv(x)

        return x

    def _normalize(self, x, eps=1e-5):
        # Instance normalization: per-sample, per-channel over height and width.
        mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)
        x_norm = (x - mean) / (std + eps)

        return x_norm

    def _apply_weights(self, x, w_spatial, w_pointwise, bias):
        """Apply one sample's predicted kernels to that sample's features.

        x = [1, height, width, channels]
        w_spatial = [1, ws_height, ws_width, in_channels // n_groups, out_channels]
        w_pointwise = [1, wp_height, wp_width, in_channels // n_groups, out_channels]
        bias = [1, out_channels]
        """

        # The predicted kernels are used directly as conv filters rather than
        # copied into freshly-built Keras layers via set_weights. set_weights
        # converts to numpy, which cuts the gradient to the kernel predictor,
        # creates new variables on every call, and fails inside tf.function.
        # Grouping is implied by the filter shape:
        # groups = input_channels // filter_in_channels, as in Conv2D(groups=...).

        # spatial conv
        x = tf.nn.conv2d(x, w_spatial[0], strides=1, padding='SAME')

        # pointwise conv + bias
        x = tf.nn.conv2d(x, w_pointwise[0], strides=1, padding='VALID')
        x = tf.nn.bias_add(x, bias[0])

        return x
124 |
125 |
# test code (runs only when this file is executed directly, not on import)
if __name__ == '__main__':
    feats = tf.random.normal(shape=[5, 64, 64, 256])
    style_w = tf.random.normal(shape=[5, 512])

    kp = KernelPredict(in_channels=feats.shape[-1], group_div=1)
    adac = AdaConv(channels=1024, group_div=1)

    w_spatial, w_pointwise, bias = kp(style_w)
    x = adac([feats, w_spatial, w_pointwise, bias])
    print(x.shape)  # expected: (5, 64, 64, 1024)
136 |
137 |
--------------------------------------------------------------------------------
/assets/archi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/AdaConv-Tensorflow/a33b912077fe374424bc61f6e78adf655158bfbe/assets/archi.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/AdaConv-Tensorflow/a33b912077fe374424bc61f6e78adf655158bfbe/assets/teaser.png
--------------------------------------------------------------------------------
/assets/usage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/AdaConv-Tensorflow/a33b912077fe374424bc61f6e78adf655158bfbe/assets/usage.png
--------------------------------------------------------------------------------