├── LICENSE
├── README.md
├── complexity.jpg
└── dysample.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Wenze Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DySample: Learning to Upsample by Learning to Sample

<div align="center">
  <img src="complexity.jpg" width="600">
</div>

Code for the ICCV 2023 paper [Learning to Upsample by Learning to Sample](https://arxiv.org/abs/2308.15085).

We present DySample, an ultra-lightweight and effective dynamic upsampler. While recent kernel-based dynamic upsamplers such as CARAFE, FADE, and SAPA deliver impressive performance gains, they introduce considerable overhead, mostly due to the time-consuming dynamic convolution and the additional sub-network used to generate dynamic kernels. Moreover, the need for high-resolution feature guidance in FADE and SAPA limits their application scenarios. To address these concerns, we bypass dynamic convolution and formulate upsampling from the perspective of point sampling, which is more resource-efficient and can easily be implemented with standard built-in functions in PyTorch. We first showcase a naive design, then demonstrate step by step how to strengthen its upsampling behavior towards our new upsampler, DySample. Compared with former kernel-based dynamic upsamplers, DySample requires no customized CUDA package and has far fewer parameters, FLOPs, GPU memory, and latency. Besides these lightweight characteristics, DySample outperforms other upsamplers across five dense prediction tasks: semantic segmentation, object detection, instance segmentation, panoptic segmentation, and monocular depth estimation.
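
To make the point-sampling formulation concrete, here is a minimal sketch of the naive design: sampling positions are a regular grid plus per-pixel offsets, read out with the built-in `F.grid_sample`. The helper `naive_upsample` below is illustrative only, not part of this repository; in DySample the offsets are predicted from the input by a 1x1 convolution (see `dysample.py`).

```python
import torch
import torch.nn.functional as F

def naive_upsample(x, offset, scale=2):
    """Upsample x (B, C, H, W) by sampling it at dynamically offset positions.

    offset: (B, 2, scale*H, scale*W), a per-pixel (x, y) perturbation in
    units of input pixels, e.g. predicted from x by a 1x1 conv.
    """
    B, C, H, W = x.shape
    # Regular grid of output-pixel centers, expressed in input coordinates.
    ys = (torch.arange(scale * H, dtype=x.dtype, device=x.device) + 0.5) / scale
    xs = (torch.arange(scale * W, dtype=x.dtype, device=x.device) + 0.5) / scale
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    grid = torch.stack((grid_x, grid_y))            # (2, sH, sW), (x, y) order
    grid = grid.unsqueeze(0) + offset               # perturb each sampling point
    # Normalize absolute positions to [-1, 1] as required by F.grid_sample.
    norm = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1)
    grid = 2 * grid / norm - 1
    return F.grid_sample(x, grid.permute(0, 2, 3, 1), mode='bilinear',
                         align_corners=False, padding_mode='border')

x = torch.rand(2, 64, 16, 16)
out = naive_upsample(x, torch.zeros(2, 2, 32, 32))  # zero offsets ≈ bilinear upsampling
print(out.shape)  # torch.Size([2, 64, 32, 32])
```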

## Highlights

- **Fast:** DySample adopts a very simple implementation and runs fast;
- **Easy to use:** DySample does not require any extra CUDA package to be installed.

## Results

Object detection with Faster R-CNN on COCO
| Faster R-CNN | Backbone | Params | $AP$ | $AP_{50}$ | $AP_{75}$ | $AP_S$ | $AP_M$ | $AP_L$ | log | ckpt |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Nearest | R50 | 46.8M | 37.5 | 58.2 | 40.8 | 21.3 | 41.1 | 48.9 | | |
| DySample | R50 | +32.7K | 38.6 | 59.9 | 42.0 | 22.9 | 42.1 | 50.2 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/faster_rcnn_r50_fpn_dysample-lpg4_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/faster_rcnn_r50_fpn_dysample-lpg4_1x_coco.pth) |
| DySample+ | R50 | +65.5K | 38.7 | 60.0 | 42.2 | 22.5 | 42.4 | 50.2 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/faster_rcnn_r50_fpn_dysample-lpg4ds_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/faster_rcnn_r50_fpn_dysample-lpg4ds_1x_coco.pth) |
| Nearest | R101 | 65.8M | 39.4 | 60.1 | 43.1 | 22.4 | 43.7 | 51.1 | | |
| DySample+ | R101 | +65.5K | 40.5 | 61.6 | 43.8 | 24.2 | 44.5 | 52.3 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/faster_rcnn_r101_fpn_dysample-lpg4ds_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/faster_rcnn_r101_fpn_dysample-lpg4ds_1x_coco.pth) |

Instance segmentation with Mask R-CNN on COCO
| Bbox results | Backbone | Params | $AP$ | $AP_{50}$ | $AP_{75}$ | $AP_S$ | $AP_M$ | $AP_L$ |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Nearest | R50 | - | 38.3 | 58.7 | 42.0 | 21.9 | 41.8 | 50.2 |
| DySample | R50 | +32.7K | 39.2 | 60.3 | 43.0 | 23.5 | 42.5 | 51.0 |
| DySample+ | R50 | +65.5K | 39.6 | 60.6 | 43.5 | 23.5 | 43.1 | 50.8 |
| Nearest | R101 | - | 40.0 | 60.4 | 43.7 | 22.8 | 43.7 | 52.0 |
| DySample+ | R101 | +65.5K | 41.0 | 61.9 | 44.9 | 24.3 | 45.0 | 53.5 |

| Segm results | Backbone | Params | $AP$ | $AP_{50}$ | $AP_{75}$ | $AP_S$ | $AP_M$ | $AP_L$ | log | ckpt |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Nearest | R50 | - | 34.7 | 55.8 | 37.2 | 16.1 | 37.3 | 50.8 | | |
| DySample | R50 | +32.7K | 35.4 | 56.9 | 37.8 | 17.1 | 37.7 | 51.1 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/mask_rcnn_r50_fpn_dysample-lpg4_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/mask_rcnn_r50_fpn_dysample-lpg4_1x_coco.pth) |
| DySample+ | R50 | +65.5K | 35.7 | 57.4 | 38.1 | 17.6 | 38.5 | 51.5 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/mask_rcnn_r50_fpn_dysample-lpg4ds_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/mask_rcnn_r50_fpn_dysample-lpg4ds_1x_coco.pth) |
| Nearest | R101 | - | 36.0 | 57.6 | 38.5 | 16.5 | 39.3 | 52.2 | | |
| DySample+ | R101 | +65.5K | 36.8 | 58.7 | 39.5 | 17.5 | 40.0 | 53.8 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/mask_rcnn_r101_fpn_dysample-lpg4ds_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/mask_rcnn_r101_fpn_dysample-lpg4ds_1x_coco.pth) |

Panoptic segmentation with Panoptic FPN on COCO
| Panoptic FPN | Backbone | Params | $PQ$ | $PQ^{th}$ | $PQ^{st}$ | $SQ$ | $RQ$ | log | ckpt |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Nearest | R50 | 46.0M | 40.2 | 47.8 | 28.9 | 77.8 | 49.3 | | |
| DySample | R50 | +24.6K | 41.4 | 48.5 | 30.7 | 78.6 | 50.7 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/panoptic_fpn_r50_fpn_dysample-lpg4_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/panoptic_fpn_r50_fpn_dysample-lpg4_1x_coco.pth) |
| DySample+ | R50 | +49.2K | 41.5 | 48.5 | 30.8 | 78.3 | 50.7 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/panoptic_fpn_r50_fpn_dysample-lpg4ds_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/panoptic_fpn_r50_fpn_dysample-lpg4ds_1x_coco.pth) |
| Nearest | R101 | 65.0M | 42.2 | 50.1 | 30.3 | 78.3 | 51.4 | | |
| DySample+ | R101 | +49.2K | 43.0 | 50.2 | 32.1 | 78.6 | 52.4 | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/panoptic_fpn_r101_fpn_dysample-lpg4ds_1x_coco.log) | [Link](https://github.com/tiny-smart/detection-with-upsamplers/releases/download/checkpoint/panoptic_fpn_r101_fpn_dysample-lpg4ds_1x_coco.pth) |

## Usage

For application examples, refer to [detection-with-upsamplers](https://github.com/tiny-smart/detection-with-upsamplers) and [segmentation-with-upsamplers](https://github.com/tiny-smart/segmentation-with-upsamplers) to try the upsamplers with mmcv. A minimal standalone example is shown below.
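
The snippet below is a minimal standalone sketch, assuming `dysample.py` from this repository is on the import path; argument names follow the `DySample` constructor:

```python
import torch
from dysample import DySample

x = torch.rand(2, 64, 32, 32)      # (B, C, H, W) feature map

up = DySample(in_channels=64)      # defaults: scale=2, style='lp', groups=4
print(up(x).shape)                 # torch.Size([2, 64, 64, 64])

# 'pl' style pixel-shuffles first; in_channels must be divisible by scale**2.
# dyscope=True enables the learned dynamic scope (larger but stronger variant).
up_pl = DySample(in_channels=64, scale=2, style='pl', dyscope=True)
print(up_pl(x).shape)              # torch.Size([2, 64, 64, 64])
```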

## Citation

If you find DySample useful for your research, please cite:
```
@inproceedings{liu2023learning,
  title={Learning to Upsample by Learning to Sample},
  author={Liu, Wenze and Lu, Hao and Fu, Hongtao and Cao, Zhiguo},
  booktitle={Proc. IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2023}
}
```
--------------------------------------------------------------------------------
/complexity.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiny-smart/dysample/81a1de5caa95d55a0f5488425fa53ec7ef47f8f0/complexity.jpg
--------------------------------------------------------------------------------
/dysample.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


class DySample(nn.Module):
    """DySample: upsampling by dynamic point sampling.

    style='lp' predicts offsets at low resolution and pixel-shuffles them up;
    style='pl' pixel-shuffles the feature first and predicts offsets at high
    resolution. With dyscope=True, a learned input-dependent factor modulates
    the offset magnitude.
    """

    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        # 1x1 conv predicting a 2-D offset per sampling point and group.
        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        normal_init(self.offset, std=0.001)
        if dyscope:
            # Optional learned scope that rescales the offsets per position.
            self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False)
            constant_init(self.scope, val=0.)
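
        # 'init_pos' caches the static sub-pixel positions of the scale**2
        # output samples inside each input pixel (in units of input pixels);
        # the offsets predicted in forward() are added on top of these, so
        # zero predicted offsets reproduce a static, bilinear-like pattern.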
        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        # Evenly spaced sample positions inside one input pixel; for scale=2
        # this gives (-0.25, 0.25) along each axis, repeated per group.
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(
            1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        # offset: (B, 2 * groups * scale**2, H, W) -> absolute sampling
        # coordinates, normalized to [-1, 1] and laid out at the upsampled
        # resolution for F.grid_sample.
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        # 'lp' style: offsets predicted at low resolution.
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        # 'pl' style: offsets predicted at high resolution after pixel shuffle.
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)


if __name__ == '__main__':
    x = torch.rand(2, 64, 4, 7)
    dys = DySample(64)
    print(dys(x).shape)  # expected: torch.Size([2, 64, 8, 14])
--------------------------------------------------------------------------------