├── model ├── __init__.py ├── decoder.py ├── aspp.py ├── model.py ├── mobilenet.py ├── resnet.py ├── utils.py └── refiner.py ├── LICENSE └── README.md /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MattingBase, MattingRefine 2 | from .utils import load_torch_weights -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 University of Washington 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# ---- /LICENSE: line 22 (blank) ----

# ---- /model/decoder.py ----
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU


class Decoder(Model):
    """Coarse-to-fine decoder that fuses encoder skip features.

    The input is the 5-tuple ``(x4, x3, x2, x1, x0)``, ordered from the
    deepest encoder feature to the raw network input.

    Each of the first three stages does the following:

    1. Bilinearly resizes the running tensor to the next skip tensor's
       height/width.
    2. Concatenates the skip tensor on the channel axis.
    3. Applies 3x3 conv (no bias) -> BatchNorm -> ReLU.

    The last stage resizes to ``x0``'s size, concatenates ``x0`` and applies
    a biased 3x3 conv with no normalization or activation.

    Args:
        channels: Four output channel counts, one per stage.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3(filters, use_bias=False):
            return Conv2D(filters, 3, padding='SAME', use_bias=use_bias)

        def batch_norm():
            # NOTE(review): Keras updates moving stats as
            # moving * momentum + batch * (1 - momentum), the inverse of
            # PyTorch's convention, so momentum=0.1 here is not PyTorch's
            # default 0.1 (that would be 0.9). This only matters when training.
            return BatchNormalization(momentum=0.1, epsilon=1e-5)

        # Creation order and attribute names are relied on elsewhere
        # (load_torch_weights reads decoder.conv1..conv4 / bn1..bn3).
        self.conv1 = conv3x3(channels[0])
        self.bn1 = batch_norm()
        self.conv2 = conv3x3(channels[1])
        self.bn2 = batch_norm()
        self.conv3 = conv3x3(channels[2])
        self.bn3 = batch_norm()
        self.conv4 = conv3x3(channels[3], use_bias=True)
        self.relu = ReLU()

    def call(self, x, training=None):
        """Decode the encoder outputs into a ``channels[3]``-channel map at x0's size."""
        deepest, skip3, skip2, skip1, raw_input = x

        hidden = deepest
        for skip, conv, norm in ((skip3, self.conv1, self.bn1),
                                 (skip2, self.conv2, self.bn2),
                                 (skip1, self.conv3, self.bn3)):
            hidden = tf.image.resize(hidden, tf.shape(skip)[1:3])
            hidden = tf.concat([hidden, skip], axis=-1)
            hidden = conv(hidden, training=training)
            hidden = norm(hidden, training=training)
            hidden = self.relu(hidden, training=training)

        hidden = tf.image.resize(hidden, tf.shape(raw_input)[1:3])
        hidden = tf.concat([hidden, raw_input], axis=-1)
        return self.conv4(hidden, training=training)


# ---- /README.md ----
# (line 1) # Real-Time High-Resolution Background Matting (TensorFlow)
2 | 3 | This repo contains a TensorFlow 2 implementation of Real-Time High-Resolution Background Matting. For more information and to download the weights, please visit our [official repo](https://github.com/PeterL1n/BackgroundMattingV2). 4 | 5 | The TensorFlow implementation is experimental. We find that PyTorch has faster inference speed, so we suggest using the official PyTorch version whenever possible. 6 | 7 | ## Use our model 8 | 9 | We reimplement our model natively in TensorFlow 2 and provide a script to load PyTorch weights directly into the TensorFlow model. 10 | 11 | ```python 12 | import tensorflow as tf 13 | import torch # For loading PyTorch weights. 14 | 15 | from model import MattingRefine, load_torch_weights 16 | 17 | # Enable mixed precision; it reduces memory usage and may make model inference faster. 18 | tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) 19 | 20 | # Create TensorFlow model 21 | model = MattingRefine(backbone='resnet50', 22 | backbone_scale=0.25, 23 | refine_mode='sampling', 24 | refine_sample_pixels=80000) 25 | 26 | # Load PyTorch weights into TensorFlow model. 27 | load_torch_weights(model, torch.load('PATH_TO_PYTORCH_WEIGHTS.pth')) 28 | 29 | src = tf.random.normal((1, 1080, 1920, 3)) 30 | bgr = tf.random.normal((1, 1080, 1920, 3)) 31 | 32 | # Faster inference with tf.function 33 | # Note that the first call of the model wrapped with 34 | # tf.function will be slow. 35 | model = tf.function(model, experimental_relax_shapes=True) 36 | 37 | pha, fgr = model([src, bgr], training=False)[:2] 38 | ``` 39 | 40 | ## Download weights 41 | 42 | Please visit the [official repo](https://github.com/PeterL1n/BackgroundMattingV2) for details.
# ---- /README.md: line 43 (blank) ----

# ---- /model/aspp.py ----
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dropout


class ASPP(Model):
    """Atrous Spatial Pyramid Pooling block.

    Five branches see the same (channels-last) input:

    - a 1x1 conv;
    - three dilated 3x3 convs;
    - an image-level pooling branch.

    Their outputs are concatenated on the channel axis and projected back to
    ``filters`` channels with 1x1 conv -> BatchNorm -> ReLU -> Dropout(0.1).

    Args:
        filters: Output channels of every branch and of the projection.
        dilation_rates: Three dilation rates for the 3x3 branches. The default
            is a tuple so no mutable object is shared between instances; any
            indexable sequence of length >= 3 (e.g. a list) is accepted.
    """

    def __init__(self, filters, dilation_rates=(3, 6, 9)):
        super().__init__()
        self.aspp1 = ASPPConv(filters, 1, 1)
        self.aspp2 = ASPPConv(filters, 3, dilation_rates[0])
        self.aspp3 = ASPPConv(filters, 3, dilation_rates[1])
        self.aspp4 = ASPPConv(filters, 3, dilation_rates[2])
        self.pool = ASPPPooling(filters)
        # NOTE(review): Keras BatchNormalization updates moving stats as
        # moving * momentum + batch * (1 - momentum), the inverse of PyTorch's
        # convention, so momentum=0.1 is not PyTorch's default 0.1 (that would
        # be 0.9). This only matters when training.
        self.project = Sequential([
            Conv2D(filters, 1, use_bias=False),
            BatchNormalization(momentum=0.1, epsilon=1e-5),
            ReLU(),
            Dropout(0.1)
        ])

    def call(self, x, training=None):
        """Apply all five branches to ``x`` and fuse them with the projection."""
        x = tf.concat([
            self.aspp1(x, training=training),
            self.aspp2(x, training=training),
            self.aspp3(x, training=training),
            self.aspp4(x, training=training),
            self.pool(x, training=training)
        ], axis=-1)
        x = self.project(x, training=training)
        return x


class ASPPConv(Model):
    """One ASPP conv branch: (dilated) conv with 'SAME' padding -> BatchNorm -> ReLU."""

    def __init__(self, filters, kernel_size, dilation_rate):
        super().__init__()
        self.conv = Conv2D(filters, kernel_size, padding='SAME', dilation_rate=dilation_rate, use_bias=False)
        self.bn = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.relu = ReLU()

    def call(self, x, training=None):
        x = self.conv(x, training=training)
        x = self.bn(x, training=training)
        x = self.relu(x, training=training)
        return x


class ASPPPooling(Model):
    """Image-level ASPP branch.

    Global average pool -> 1x1 conv -> BatchNorm -> ReLU, then broadcast back
    to the input's height and width with a nearest-neighbour resize.
    """

    def __init__(self, filters):
        super().__init__()
        self.pool = GlobalAveragePooling2D()
        self.conv = Conv2D(filters, 1, use_bias=False)
        self.bn = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.relu = ReLU()

    def call(self, x, training=None):
        h, w = tf.shape(x)[1], tf.shape(x)[2]
        x = self.pool(x, training=training)
        # Restore the spatial axes dropped by global pooling: (B, C) -> (B, 1, 1, C).
        x = x[:, None, None, :]
        x = self.conv(x, training=training)
        x = self.bn(x, training=training)
        x = self.relu(x, training=training)
        x = tf.image.resize(x, (h, w), 'nearest')
        return x


# ---- /model/model.py ----
import tensorflow as tf
from tensorflow.keras.models import Model

from .resnet import ResNetEncoder
from .mobilenet import MobileNetV2Encoder
from .aspp import ASPP
from .decoder import Decoder
from .refiner import Refiner


class MattingBase(Model):
    """Base matting network: backbone encoder -> ASPP -> decoder.

    Args:
        backbone: "resnet50", "resnet101" or "mobilenetv2".
    """

    def __init__(self, backbone):
        super().__init__()
        assert backbone in ["resnet50", "resnet101", "mobilenetv2"]
        if backbone in ['resnet50', 'resnet101']:
            self.backbone = ResNetEncoder(backbone)
        else:
            self.backbone = MobileNetV2Encoder()
        self.aspp = ASPP(256, [3, 6, 9])
        # Decoder output channels: 1 alpha + 3 foreground + 1 error + 32 hidden.
        self.decoder = Decoder([128, 64, 48, (1 + 3 + 1 + 32)])

    def call(self, x, training=None, _output_fgr_as_residual=False):
        """Run the base network.

        Args:
            x: ``[src, bgr]``. The two are concatenated on the last axis, so
                they must share batch and spatial dimensions (channels-last).
            training: Forwarded to every sub-layer.
            _output_fgr_as_residual: Internal flag (MattingRefine passes True).
                When True, ``fgr`` is the raw predicted residual instead of
                ``clip(residual + src, 0, 1)``.

        Returns:
            ``(pha, fgr, err, hid)``:

            - pha: alpha, 1 channel, clipped to [0, 1];
            - fgr: foreground, 3 channels;
            - err: error map, 1 channel, clipped to [0, 1];
            - hid: ReLU'd hidden features, the remaining channels.
        """
        src, bgr = x
        x = tf.concat([src, bgr], axis=-1)
        x, *shortcuts = self.backbone(x, training=training)
        x = self.aspp(x, training=training)
        x = self.decoder([x, *shortcuts], training=training)

        pha = tf.clip_by_value(x[:, :, :, 0:1], 0, 1)
        fgr = x[:, :, :, 1:4]
        err = tf.clip_by_value(x[:, :, :, 4:5], 0, 1)
        hid = tf.nn.relu(x[:, :, :, 5:])

        if not _output_fgr_as_residual:
            fgr = tf.clip_by_value(fgr + src, 0, 1)

        return pha, fgr, err, hid


class MattingRefine(MattingBase):
    """Two-stage matting network.

    The base network runs on a downscaled copy of the inputs; the Refiner
    then produces full-resolution alpha and foreground.

    Args:
        backbone: See MattingBase.
        backbone_scale: Scale at which the base network runs (must be <= 1/2).
        refine_mode: 'full', 'sampling' or 'thresholding' (see Refiner).
        refine_sample_pixels: Pixel budget for 'sampling' mode.
        refine_threshold: Error threshold for 'thresholding' mode.
    """

    def __init__(self,
                 backbone: str,
                 backbone_scale: float = 1/4,
                 refine_mode: str = 'sampling',
                 refine_sample_pixels: int = 80_000,
                 refine_threshold: float = 0.7):
        assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
        super().__init__(backbone)
        self.backbone_scale = backbone_scale
        self.refiner = Refiner(refine_mode, refine_sample_pixels, refine_threshold)

    def call(self, x, training=None):
        """Run base + refiner on ``[src, bgr]`` (equal shapes, H and W divisible by 4).

        Returns:
            ``(pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm)``: full-resolution
            alpha/foreground (clipped to [0, 1]), the base network's
            alpha/foreground/error at ``backbone_scale``, and the refinement
            region map produced by the Refiner.
        """
        src, bgr = x
        tf.debugging.assert_equal(tf.shape(src), tf.shape(bgr),
                                  'src and bgr must have equal size.')
        tf.debugging.assert_equal(tf.shape(src)[1:3] // 4 * 4, tf.shape(src)[1:3],
                                  'src and bgr must have width and height that are divisible by 4')

        # Size is truncated toward zero by the float -> int cast.
        size_sm = tf.cast(tf.cast(tf.shape(src)[1:3], tf.float32) * self.backbone_scale, tf.int32)
        src_sm = tf.image.resize(src, size_sm)
        bgr_sm = tf.image.resize(bgr, size_sm)

        # Base
        pha_sm, fgr_sm, err_sm, hid_sm = super().call([src_sm, bgr_sm], training=training, _output_fgr_as_residual=True)

        # Refiner
        pha, fgr, ref_sm = self.refiner([src, bgr, pha_sm, fgr_sm, err_sm, hid_sm], training=training)

        # Clamp outputs; foregrounds are residuals over the (resized) source.
        pha = tf.clip_by_value(pha, 0, 1)
        fgr = tf.clip_by_value(fgr + src, 0, 1)
        fgr_sm = tf.clip_by_value(fgr_sm + src_sm, 0, 1)

        return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm


# ---- /model/mobilenet.py ----
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, ZeroPadding2D, DepthwiseConv2D


class MobileNetV2Encoder(Model):
    """MobileNetV2 feature encoder whose last blocks are dilated instead of strided.

    ``call`` returns ``(x4, x3, x2, x1, x0)`` at 1/16, 1/8, 1/4, 1/2 and 1/1
    of the input resolution.
    """

    def __init__(self):
        super().__init__()
        # Arguments: InvertedResidual(inp, oup, strides, expand_ratio[, dilation_rate]).
        self.features = [
            ConvBNReLU(32, 3, 2),
            InvertedResidual(32, 16, 1, 1),
            InvertedResidual(16, 24, 2, 6),
            InvertedResidual(24, 24, 1, 6),
            InvertedResidual(24, 32, 2, 6),
            InvertedResidual(32, 32, 1, 6),
            InvertedResidual(32, 32, 1, 6),
            InvertedResidual(32, 64, 2, 6),
            InvertedResidual(64, 64, 1, 6),
            InvertedResidual(64, 64, 1, 6),
            InvertedResidual(64, 64, 1, 6),
            InvertedResidual(64, 96, 1, 6),
            InvertedResidual(96, 96, 1, 6),
            InvertedResidual(96, 96, 1, 6),
            InvertedResidual(96, 160, 1, 6),
            InvertedResidual(160, 160, 1, 6, 2),
            InvertedResidual(160, 160, 1, 6, 2),
            InvertedResidual(160, 320, 1, 6, 2),
        ]

    def call(self, x, training=None):
        x0 = x  # 1/1
        for i in range(0, 2):
            x = self.features[i](x, training=training)
        x1 = x  # 1/2
        for i in range(2, 4):
            x = self.features[i](x, training=training)
        x2 = x  # 1/4
        for i in range(4, 7):
            x = self.features[i](x, training=training)
        x3 = x  # 1/8
        for i in range(7, 18):
            x = self.features[i](x, training=training)
        x4 = x  # 1/16
        return x4, x3, x2, x1, x0


class InvertedResidual(Model):
    """MobileNetV2 inverted residual block.

    The block is an optional 1x1 expansion (omitted when
    ``expand_ratio == 1``), then a 3x3 depthwise conv, then a linear 1x1
    projection. The input is added back when ``strides == 1`` and
    ``inp == oup``.
    """

    def __init__(self, inp, oup, strides, expand_ratio, dilation_rate=1):
        super().__init__()
        self.use_res_connect = strides == 1 and inp == oup
        hidden_filters = int(round(inp * expand_ratio))

        if expand_ratio != 1:
            self.pw = ConvBNReLU(hidden_filters, 1)
        self.dw = ConvBNReLU(hidden_filters, 3, strides, True, dilation_rate)
        self.pw_linear = ConvBNReLU(oup, 1, activation=False)

    def call(self, x, training=None):
        identity = x
        if hasattr(self, 'pw'):
            x = self.pw(x, training=training)
        x = self.dw(x, training=training)
        x = self.pw_linear(x, training=training)
        if self.use_res_connect:
            x += identity
        return x


class ConvBNReLU(Sequential):
    """Optional zero padding -> (depthwise) conv -> BatchNorm -> optional ReLU6.

    Sub-layers are stored as named attributes (``pad``, ``conv``, ``bn``,
    ``relu``) and applied by the overridden ``call``; the weight loader reads
    ``conv`` and ``bn`` by name.
    """

    def __init__(self, filters, kernel_size, strides=1, depthwise=False, dilation_rate=1, activation=True):
        super().__init__()
        # A dilated kernel spans dilation_rate * (kernel_size - 1) + 1 pixels,
        # so this much padding per side keeps outputs centred (PyTorch-style
        # symmetric padding followed by an unpadded conv). The previous
        # formula (kernel_size * dilation_rate - 1) // 2 agreed with this only
        # for the (kernel, dilation) pairs used in this file and over-padded
        # otherwise (e.g. kernel 3, dilation 3 gave 4 instead of 3).
        padding = dilation_rate * (kernel_size - 1) // 2
        if padding != 0:
            self.pad = ZeroPadding2D((padding, padding))
        if depthwise:
            self.conv = DepthwiseConv2D(kernel_size, strides, dilation_rate=dilation_rate, use_bias=False)
        else:
            self.conv = Conv2D(filters, kernel_size, strides, dilation_rate=dilation_rate, use_bias=False)
        self.bn = BatchNormalization(momentum=0.1, epsilon=1e-5)
        if activation:
            self.relu = ReLU(6)

    def call(self, x, training=None):
        if hasattr(self, 'pad'):
            x = self.pad(x, training=training)
        x = self.conv(x, training=training)
        x = self.bn(x, training=training)
        if hasattr(self, 'relu'):
            x = self.relu(x, training=training)
        return x


# ---- /model/resnet.py ----
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, MaxPool2D, ZeroPadding2D


class ResNetEncoder(Model):
    """ResNet-50/101 feature encoder with a dilated (non-strided) last stage.

    Args:
        variant: "resnet50" or "resnet101"; any other value raises KeyError.
    """

    def __init__(self, variant):
        super().__init__()

        blocks = {'resnet50': [3, 4, 6, 3], 'resnet101': [3, 4, 23, 3]}[variant]
        filters = [64, 256, 512, 1024, 2048]

        # Explicit symmetric zero padding followed by unpadded conv / pool.
        self.pad1 = ZeroPadding2D(3)
        self.conv1 = Conv2D(filters[0], 7, 2, use_bias=False)
        self.bn1 = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.relu = ReLU()
        self.pad2 = ZeroPadding2D(1)
        self.maxpool = MaxPool2D(3, 2)

        self.layer1 = self._make_layer(filters[1], blocks[0], strides=1, dilation_rate=1)
        self.layer2 = self._make_layer(filters[2], blocks[1], strides=2, dilation_rate=1)
        self.layer3 = self._make_layer(filters[3], blocks[2], strides=2, dilation_rate=1)
        # Stride 1 + dilation keeps the deepest features at 1/16 resolution.
        self.layer4 = self._make_layer(filters[4], blocks[3], strides=1, dilation_rate=2)

    def _make_layer(self, filters, blocks, strides, dilation_rate):
        """Build a stage of ``blocks`` bottleneck blocks.

        The first block carries the stride and a projection shortcut and uses
        dilation 1. The remaining blocks use stride 1, ``dilation_rate`` and
        identity shortcuts.
        """
        layers = [ResNetBlock(filters, 3, strides, 1, True)]
        for _ in range(1, blocks):
            layers.append(ResNetBlock(filters, 3, 1, dilation_rate, False))
        return Sequential(layers)

    def call(self, x, training=None):
        x0 = x  # 1/1
        x = self.pad1(x)
        x = self.conv1(x, training=training)
        x = self.bn1(x, training=training)
        x = self.relu(x, training=training)
        x1 = x  # 1/2
        x = self.pad2(x)
        x = self.maxpool(x, training=training)
        x = self.layer1(x, training=training)
        x2 = x  # 1/4
        x = self.layer2(x, training=training)
        x3 = x  # 1/8
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x4 = x  # 1/16
        return x4, x3, x2, x1, x0


class ResNetBlock(Model):
    """Bottleneck block.

    The main path is 1x1 -> (strided, dilated) 3x3 -> 1x1 with a bottleneck
    width of ``filters // 4``; the shortcut is added before the final ReLU.
    When ``conv_shortcut`` is True the shortcut is a strided 1x1 conv +
    BatchNorm; otherwise it is the identity.
    """

    def __init__(self, filters, kernel_size=3, strides=1, dilation_rate=1, conv_shortcut=True):
        super().__init__()
        self.conv1 = Conv2D(filters // 4, 1, use_bias=False)
        self.bn1 = BatchNormalization(momentum=0.1, epsilon=1e-5)
        # Padding equal to the dilation keeps a 3x3 dilated conv size-preserving at stride 1.
        self.pad2 = ZeroPadding2D(dilation_rate)
        self.conv2 = Conv2D(filters // 4, kernel_size, strides, dilation_rate=dilation_rate, use_bias=False)
        self.bn2 = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.conv3 = Conv2D(filters, 1, use_bias=False)
        self.bn3 = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.relu = ReLU()
        if conv_shortcut:
            self.convd = Conv2D(filters, 1, strides, use_bias=False)
            self.bnd = BatchNormalization(momentum=0.1, epsilon=1e-5)

    def call(self, x, training=None):
        if hasattr(self, 'convd'):
            shortcut = self.convd(x, training=training)
            shortcut = self.bnd(shortcut, training=training)
        else:
            shortcut = x

        x = self.conv1(x, training=training)
        x = self.bn1(x, training=training)
        x = self.relu(x, training=training)
        x = self.pad2(x, training=training)
        x = self.conv2(x, training=training)
        x = self.bn2(x, training=training)
        x = self.relu(x, training=training)
        x = self.conv3(x, training=training)
        x = self.bn3(x, training=training)
        x += shortcut
        x = self.relu(x, training=training)
        return x
# ---- /model/utils.py ----
import tensorflow as tf

from .resnet import ResNetEncoder
from .mobilenet import MobileNetV2Encoder


def load_torch_weights(model, state_dict, default_size=(1080, 1920)):
    """Copy weights from a PyTorch state dict into a MattingBase/MattingRefine model.

    The model is first built by calling it once on random inputs of
    ``default_size`` (for MattingRefine both dimensions must be divisible by 4).
    Then every conv / BatchNorm layer is filled from the state-dict entry with
    the matching PyTorch module name (e.g. ``'backbone.conv1.weight'``).

    Args:
        model: A MattingBase or MattingRefine instance.
        state_dict: Mapping from PyTorch parameter names to torch tensors,
            e.g. the result of ``torch.load(...)``. Tensors may live on any
            device; they are copied to host memory.
        default_size: ``(height, width)`` of the dummy inputs used to build the model.

    Raises:
        KeyError: If an expected parameter name is missing from ``state_dict``.
    """
    # Build model
    model([tf.random.normal((1, *default_size, 3)), tf.random.normal((1, *default_size, 3))], training=False)

    # ResNet backbone
    if isinstance(model.backbone, ResNetEncoder):
        load_conv_weights(model.backbone.conv1, state_dict, 'backbone.conv1')
        load_bn_weights(model.backbone.bn1, state_dict, 'backbone.bn1')
        for l in range(1, 5):
            for b, resblock in enumerate(getattr(model.backbone, f'layer{l}').layers):
                if hasattr(resblock, 'convd'):
                    load_conv_weights(resblock.convd, state_dict, f'backbone.layer{l}.{b}.downsample.0')
                    load_bn_weights(resblock.bnd, state_dict, f'backbone.layer{l}.{b}.downsample.1')
                load_conv_weights(resblock.conv1, state_dict, f'backbone.layer{l}.{b}.conv1')
                load_conv_weights(resblock.conv2, state_dict, f'backbone.layer{l}.{b}.conv2')
                load_conv_weights(resblock.conv3, state_dict, f'backbone.layer{l}.{b}.conv3')
                load_bn_weights(resblock.bn1, state_dict, f'backbone.layer{l}.{b}.bn1')
                load_bn_weights(resblock.bn2, state_dict, f'backbone.layer{l}.{b}.bn2')
                load_bn_weights(resblock.bn3, state_dict, f'backbone.layer{l}.{b}.bn3')

    # MobileNet backbone
    if isinstance(model.backbone, MobileNetV2Encoder):
        load_conv_weights(model.backbone.features[0].conv, state_dict, 'backbone.features.0.0')
        load_bn_weights(model.backbone.features[0].bn, state_dict, 'backbone.features.0.1')
        # features[1] has expand_ratio 1, so it has no 'pw' expansion conv.
        load_conv_weights(model.backbone.features[1].dw.conv, state_dict, 'backbone.features.1.conv.0.0', True)
        load_bn_weights(model.backbone.features[1].dw.bn, state_dict, 'backbone.features.1.conv.0.1')
        load_conv_weights(model.backbone.features[1].pw_linear.conv, state_dict, 'backbone.features.1.conv.1')
        load_bn_weights(model.backbone.features[1].pw_linear.bn, state_dict, 'backbone.features.1.conv.2')
        for i in range(2, 18):
            load_conv_weights(model.backbone.features[i].pw.conv, state_dict, f'backbone.features.{i}.conv.0.0')
            load_bn_weights(model.backbone.features[i].pw.bn, state_dict, f'backbone.features.{i}.conv.0.1')
            load_conv_weights(model.backbone.features[i].dw.conv, state_dict, f'backbone.features.{i}.conv.1.0', True)
            load_bn_weights(model.backbone.features[i].dw.bn, state_dict, f'backbone.features.{i}.conv.1.1')
            load_conv_weights(model.backbone.features[i].pw_linear.conv, state_dict, f'backbone.features.{i}.conv.2')
            load_bn_weights(model.backbone.features[i].pw_linear.bn, state_dict, f'backbone.features.{i}.conv.3')

    # ASPP
    for i in range(4):
        load_conv_weights(getattr(model.aspp, f'aspp{i+1}').conv, state_dict, f'aspp.convs.{i}.0')
        load_bn_weights(getattr(model.aspp, f'aspp{i+1}').bn, state_dict, f'aspp.convs.{i}.1')
    load_conv_weights(model.aspp.pool.conv, state_dict, 'aspp.convs.4.1')
    load_bn_weights(model.aspp.pool.bn, state_dict, 'aspp.convs.4.2')
    load_conv_weights(model.aspp.project.layers[0], state_dict, 'aspp.project.0')
    load_bn_weights(model.aspp.project.layers[1], state_dict, 'aspp.project.1')

    # Decoder
    load_conv_weights(model.decoder.conv1, state_dict, 'decoder.conv1')
    load_bn_weights(model.decoder.bn1, state_dict, 'decoder.bn1')
    load_conv_weights(model.decoder.conv2, state_dict, 'decoder.conv2')
    load_bn_weights(model.decoder.bn2, state_dict, 'decoder.bn2')
    load_conv_weights(model.decoder.conv3, state_dict, 'decoder.conv3')
    load_bn_weights(model.decoder.bn3, state_dict, 'decoder.bn3')
    load_conv_weights(model.decoder.conv4, state_dict, 'decoder.conv4')

    # Refiner
    if hasattr(model, 'refiner'):
        load_conv_weights(model.refiner.conv1, state_dict, 'refiner.conv1')
        load_bn_weights(model.refiner.bn1, state_dict, 'refiner.bn1')
        load_conv_weights(model.refiner.conv2, state_dict, 'refiner.conv2')
        load_bn_weights(model.refiner.bn2, state_dict, 'refiner.bn2')
        load_conv_weights(model.refiner.conv3, state_dict, 'refiner.conv3')
        load_bn_weights(model.refiner.bn3, state_dict, 'refiner.bn3')
        load_conv_weights(model.refiner.conv4, state_dict, 'refiner.conv4')


def _to_numpy(tensor):
    """Return a host-memory NumPy copy of a torch tensor.

    ``.cpu()`` makes this work for CUDA-resident state dicts, e.g. from
    ``torch.load`` without ``map_location``. ``.detach()`` makes it work for
    tensors that track gradients.
    """
    return tensor.detach().cpu().numpy()


def load_conv_weights(conv, state_dict, name, depthwise_conv=False):
    """Load ``name.weight`` (and ``name.bias`` if present) into a Keras conv layer.

    PyTorch conv kernels are (out, in, kh, kw) and are permuted to Keras'
    (kh, kw, in, out). Depthwise kernels, (channels, 1, kh, kw) in PyTorch, are
    permuted to (kh, kw, channels, 1).
    """
    weight = state_dict[name + '.weight']
    if depthwise_conv:
        weight = _to_numpy(weight.permute(2, 3, 0, 1))
    else:
        weight = _to_numpy(weight.permute(2, 3, 1, 0))
    if name + '.bias' in state_dict:
        bias = _to_numpy(state_dict[name + '.bias'])
        conv.set_weights([weight, bias])
    else:
        conv.set_weights([weight])


def load_bn_weights(bn, state_dict, name):
    """Load a PyTorch BatchNorm's affine params and running stats into a Keras BatchNormalization.

    Keras expects the order (gamma, beta, moving_mean, moving_variance).
    """
    # Converted like the conv weights so device-resident tensors are handled too.
    weight = _to_numpy(state_dict[name + '.weight'])
    bias = _to_numpy(state_dict[name + '.bias'])
    running_mean = _to_numpy(state_dict[name + '.running_mean'])
    running_var = _to_numpy(state_dict[name + '.running_var'])
    bn.set_weights([weight, bias, running_mean, running_var])


# ---- /model/refiner.py ----
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU


class Refiner(Model):
    """Refines base-network alpha/foreground to full resolution.

    Modes:
        'full': run the refinement convolutions densely over the whole image.
        'sampling': refine the ``sample_pixels // 16`` quarter-resolution
            locations with the largest error. Each location is a 4x4
            full-resolution patch. The count is capped at the number of
            available locations.
        'thresholding': refine every quarter-resolution location whose error
            exceeds ``threshold``.

    Outputs outside the refined patches are the upsampled base predictions.
    """

    # When True, 'sampling' mode drops selected locations whose error is <= 0.
    prevent_oversampling = True

    def __init__(self, mode, sample_pixels, threshold):
        super().__init__()
        assert mode in ['full', 'sampling', 'thresholding']

        self.mode = mode
        self.sample_pixels = sample_pixels
        self.threshold = threshold

        # All convs use 'valid' padding (Conv2D default); patches and the
        # full-mode input are padded explicitly to compensate.
        channels = [24, 16, 12, 4]
        self.conv1 = Conv2D(channels[0], 3, use_bias=False)
        self.bn1 = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.conv2 = Conv2D(channels[1], 3, use_bias=False)
        self.bn2 = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.conv3 = Conv2D(channels[2], 3, use_bias=False)
        self.bn3 = BatchNormalization(momentum=0.1, epsilon=1e-5)
        self.conv4 = Conv2D(channels[3], 3, use_bias=True)
        self.relu = ReLU()

    def call(self, x, training=None):
        """Refine predictions.

        Args:
            x: ``[src, bgr, pha, fgr, err, hid]``. ``src``/``bgr`` are full
                resolution (H and W divisible by 4); the rest are base-network
                outputs at a lower resolution (channels-last).

        Returns:
            ``(pha, fgr, ref)``:

            - pha, fgr: unclipped full-resolution outputs (fgr is a residual);
            - ref: float refinement mask of shape (B, H/4, W/4, 1).
        """
        src, bgr, pha, fgr, err, hid = x

        H_full, W_full = tf.shape(src)[1], tf.shape(src)[2]
        H_half, W_half = H_full // 2, W_full // 2
        H_quat, W_quat = H_full // 4, W_full // 4

        src_bgr = tf.concat([src, bgr], axis=-1)

        if self.mode != 'full':
            err = tf.image.resize(err, (H_quat, W_quat))
            ref = self.select_refinement_regions(err)
            # (batch, row, col) of every quarter-resolution location to refine.
            idx = tf.where(ref[:, :, :, 0])

            # 2x2 half-resolution patches with 3 pixels of context -> 8x8.
            x = tf.concat([hid, pha, fgr], axis=-1)
            x = tf.image.resize(x, (H_half, W_half))
            x = self.crop_patch(x, idx, 2, 3)

            y = tf.image.resize(src_bgr, (H_half, W_half))
            y = self.crop_patch(y, idx, 2, 3)

            x = self.conv1(tf.concat([x, y], axis=-1), training=training)
            x = self.bn1(x, training=training)
            x = self.relu(x, training=training)
            x = self.conv2(x, training=training)
            x = self.bn2(x, training=training)
            x = self.relu(x, training=training)

            # 4x4 -> 8x8 to match 4x4 full-resolution patches with 2 pixels of context.
            x = tf.image.resize(x, (8, 8), 'nearest')
            y = self.crop_patch(src_bgr, idx, 4, 2)

            x = self.conv3(tf.concat([x, y], axis=-1), training=training)
            x = self.bn3(x, training=training)
            x = self.relu(x, training=training)
            x = self.conv4(x, training=training)

            pha = tf.image.resize(pha, (H_full, W_full))
            pha = self.replace_patch(pha, x[:, :, :, :1], idx)

            fgr = tf.image.resize(fgr, (H_full, W_full))
            fgr = self.replace_patch(fgr, x[:, :, :, 1:], idx)
        else:
            x = tf.concat([hid, pha, fgr], axis=-1)
            x = tf.image.resize(x, (H_half, W_half))
            y = tf.image.resize(src_bgr, (H_half, W_half))

            x = tf.concat([x, y], axis=-1)
            x = tf.pad(x, tf.constant([[0, 0], [3, 3], [3, 3], [0, 0]]))

            x = self.conv1(x, training=training)
            x = self.bn1(x, training=training)
            x = self.relu(x, training=training)
            x = self.conv2(x, training=training)
            x = self.bn2(x, training=training)
            x = self.relu(x, training=training)

            x = tf.image.resize(x, (H_full + 4, W_full + 4), 'nearest')
            y = tf.pad(src_bgr, tf.constant([[0, 0], [2, 2], [2, 2], [0, 0]]))
            x = tf.concat([x, y], axis=-1)

            x = self.conv3(x, training=training)
            x = self.bn3(x, training=training)
            x = self.relu(x, training=training)
            x = self.conv4(x, training=training)

            pha = x[:, :, :, :1]
            fgr = x[:, :, :, 1:]
            # Channels-last, like the mask returned by the other modes
            # (previously built channels-first as (B, 1, H/4, W/4)).
            ref = tf.ones((tf.shape(src)[0], H_quat, W_quat, 1), dtype=src.dtype)

        return pha, fgr, ref

    def crop_patch(self, x, idx, size: int, padding: int):
        """Crop a ``(size + 2 * padding)``-square patch around each indexed location.

        Location (b, r, c) covers pixels ``[r * size, r * size + size)`` of
        ``x[b]`` (same for columns), extended by ``padding`` on each side.
        Out-of-bounds pixels are filled with 0 by ``crop_and_resize``.
        """
        box_indices = tf.cast(idx[:, 0], tf.dtypes.int32)

        y1 = idx[:, 1] * size - padding
        x1 = idx[:, 2] * size - padding
        y2 = idx[:, 1] * size + (size - 1) + padding
        x2 = idx[:, 2] * size + (size - 1) + padding

        # crop_and_resize takes box corners normalized so that 0 and 1 map to
        # the first and last pixel centres.
        shape = tf.shape(x)
        h = tf.cast(shape[1] - 1, tf.dtypes.float32)
        w = tf.cast(shape[2] - 1, tf.dtypes.float32)
        y1 = tf.cast(y1, tf.dtypes.float32) / h
        x1 = tf.cast(x1, tf.dtypes.float32) / w
        y2 = tf.cast(y2, tf.dtypes.float32) / h
        x2 = tf.cast(x2, tf.dtypes.float32) / w

        boxes = tf.stack([y1, x1, y2, x2], axis=1)
        return tf.image.crop_and_resize(x, boxes, box_indices, (size + 2 * padding, size + 2 * padding), 'nearest')

    def replace_patch(self, x, y, idx):
        """Overwrite the 4x4 blocks of ``x`` addressed by ``idx`` with the patches ``y``."""
        shape = tf.shape(x)
        # (B, H, W, C) -> (B, H/4, W/4, 4, 4, C) so each (b, r, c) addresses one 4x4 block.
        x = tf.reshape(x, (shape[0], shape[1] // 4, 4, shape[2] // 4, 4, shape[3]))
        x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
        x = tf.tensor_scatter_nd_update(x, idx, y)
        x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
        x = tf.reshape(x, shape)
        return x

    def select_refinement_regions(self, err):
        """Return a float mask of shape (B, H, W, 1) marking locations to refine.

        Args:
            err: Quarter-resolution error map of shape (B, H, W, 1).
        """
        if self.mode == 'sampling':
            B = tf.shape(err)[0]
            H = tf.shape(err)[1]
            W = tf.shape(err)[2]
            # Cap k at the number of locations: top_k fails when k exceeds the
            # row length, which happens for small inputs.
            k = tf.minimum(self.sample_pixels // 16, H * W)
            idx = tf.reshape(err, (B, -1))
            idx = tf.math.top_k(idx, k=k, sorted=False)[1]
            # Pair each flat spatial index with its batch index for scatter_nd.
            idx = tf.stack([tf.tile(tf.range(B)[:, None], (1, k)), idx], axis=2)
            idx = tf.reshape(idx, (-1, 2))
            ref = tf.scatter_nd(idx, tf.ones(B * k, dtype=err.dtype), (B, H * W))
            ref = tf.reshape(ref, (B, H, W, 1))
            if self.prevent_oversampling:
                ref *= tf.cast(err > 0, err.dtype)
        else:
            # Cast so every mode returns the same float mask (was a bool tensor).
            ref = tf.cast(err > self.threshold, err.dtype)
        return ref