# Written by Jinghao Zhou
"""DRConv-pytorch.

An unofficial non-CUDA toy implementation of DRConv
(Dynamic Region-Aware Convolution, https://arxiv.org/abs/2003.12243).
"""

import torch.nn.functional as F
import torch.nn as nn
import torch

from torch.autograd import Variable, Function


def _region_one_hot(guide_feature):
    """Return a one-hot mask of the per-pixel argmax over dim 1 of ``guide_feature``.

    The mask has the dtype/device of ``guide_feature`` and an extra singleton
    axis inserted at dim 2. For a ``B x r x H x W`` guide it is therefore
    ``B x r x 1 x H x W`` and broadcasts against ``B x r x C x H x W``.
    """
    index = guide_feature.argmax(dim=1, keepdim=True)
    return torch.zeros_like(guide_feature).scatter_(1, index, 1).unsqueeze(2)


class asign_index(torch.autograd.Function):
    """Select, per pixel, the output of the region chosen by the guide feature.

    Forward: ``kernel`` (``B x r x C x H x W`` as produced by ``DRConv2d``) is
    multiplied by the one-hot argmax of ``guide_feature`` (``B x r x H x W``)
    over the region axis and summed over that axis, giving ``B x C x H x W``.

    Backward: ``kernel`` receives the upstream gradient masked by the same
    one-hot selection. The hard argmax has no useful gradient, so the
    gradient for ``guide_feature`` is computed as if the forward had used
    ``softmax(guide_feature, dim=1)`` in place of the one-hot mask.
    """

    @staticmethod
    def forward(ctx, kernel, guide_feature):
        ctx.save_for_backward(kernel, guide_feature)
        guide_mask = _region_one_hot(guide_feature)  # B x r x 1 x H x W
        return torch.sum(kernel * guide_mask, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        kernel, guide_feature = ctx.saved_tensors
        guide_mask = _region_one_hot(guide_feature)  # B x r x 1 x H x W
        grad_out = grad_output.unsqueeze(1)  # B x 1 x C x H x W
        grad_kernel = grad_out * guide_mask  # B x r x C x H x W
        # Gradient w.r.t. each region's mask value, reduced over channels C.
        grad_guide = (grad_out * kernel).sum(dim=2)  # B x r x H x W
        softmax = F.softmax(guide_feature, 1)  # B x r x H x W
        # Softmax vector-Jacobian product: s * (g - sum_r(s * g)).
        grad_guide = softmax * (grad_guide - (softmax * grad_guide).sum(dim=1, keepdim=True))
        return grad_kernel, grad_guide


def xcorr_slow(x, kernel, kwargs):
    """Cross-correlate each sample of ``x`` with its own kernel using a Python loop.

    Args:
        x: ``B x C x H x W`` input.
        kernel: ``B x (K*C) x kh x kw``; ``kernel[i]`` is reshaped to
            ``K x C x kh x kw`` and applied to sample ``i``.
        kwargs: dict of extra ``F.conv2d`` arguments (``padding``, ``stride``, ...).

    Returns:
        ``B x K x H' x W'`` tensor.
    """
    out = []
    for i in range(x.size(0)):
        px = x[i].unsqueeze(0)  # 1 x C x H x W
        pk = kernel[i]
        pk = pk.reshape(-1, px.size(1), pk.size(1), pk.size(2))
        out.append(F.conv2d(px, pk, **kwargs))
    return torch.cat(out, 0)


def _grouped_xcorr(x, kernel, groups, kwargs):
    """Cross-correlate every sample with its own kernel via one grouped conv2d.

    The batch is folded into the channel axis (``1 x (B*C) x H x W``) and the
    kernels into ``(B*K) x (C/groups) x kh x kw``. A single ``F.conv2d`` with
    ``groups * B`` groups then computes all samples at once. ``reshape`` is
    used instead of ``view`` so non-contiguous inputs are accepted.
    """
    batch, channel = x.size(0), x.size(1)
    px = x.reshape(1, batch * channel, x.size(2), x.size(3))
    pk = kernel.reshape(-1, channel // groups, kernel.size(2), kernel.size(3))
    po = F.conv2d(px, pk, **kwargs, groups=groups * batch)
    return po.reshape(batch, -1, po.size(2), po.size(3))


def xcorr_fast(x, kernel, kwargs):
    """Grouped-conv2d equivalent of :func:`xcorr_slow` (same arguments and result)."""
    return _grouped_xcorr(x, kernel, 1, kwargs)


class Corr(Function):
    """Batched cross-correlation as an autograd Function with an ONNX symbolic.

    ``Correlation`` uses it outside training mode. ``symbolic`` maps the op to
    a custom ONNX node named ``Corr`` during export.
    """

    @staticmethod
    def symbolic(g, x, kernel, groups, kwargs=None):
        # Must accept the same positional arguments as ``forward``. The conv
        # kwargs are not encoded in the exported node.
        return g.op("Corr", x, kernel, groups_i=groups)

    @staticmethod
    def forward(ctx, x, kernel, groups, kwargs):
        """Group conv2d to calculate cross correlation."""
        kwargs = kwargs if kwargs is not None else {}
        ctx.save_for_backward(x, kernel)
        ctx.groups = groups
        ctx.kwargs = kwargs
        return _grouped_xcorr(x, kernel, groups, kwargs)

    @staticmethod
    def backward(ctx, grad_output):
        x, kernel = ctx.saved_tensors
        # Recompute the grouped conv with autograd enabled and let autograd
        # supply the conv2d input/weight gradients.
        with torch.enable_grad():
            x_d = x.detach().requires_grad_()
            k_d = kernel.detach().requires_grad_()
            out = _grouped_xcorr(x_d, k_d, ctx.groups, ctx.kwargs)
            grad_x, grad_kernel = torch.autograd.grad(out, (x_d, k_d), grad_output)
        return (
            grad_x if ctx.needs_input_grad[0] else None,
            grad_kernel if ctx.needs_input_grad[1] else None,
            None,
            None,
        )


class Correlation(nn.Module):
    """Per-sample cross-correlation module.

    In training mode it uses ``xcorr_slow`` or ``xcorr_fast`` (selected by
    ``use_slow``). Otherwise it uses the ``Corr`` autograd Function, which
    also supports ONNX export.
    """

    use_slow = True  # class-wide default when ``use_slow`` is not given

    def __init__(self, use_slow=None):
        super(Correlation, self).__init__()
        self.use_slow = Correlation.use_slow if use_slow is None else use_slow

    def extra_repr(self):
        return "xcorr_slow" if self.use_slow else "xcorr_fast"

    def forward(self, x, kernel, **kwargs):
        if not self.training:
            return Corr.apply(x, kernel, 1, kwargs)
        if self.use_slow:
            return xcorr_slow(x, kernel, kwargs)
        return xcorr_fast(x, kernel, kwargs)


class DRConv2d(nn.Module):
    """Dynamic Region-aware Convolution.

    A kernel-generator branch produces ``region_num`` sets of
    ``out_channels x in_channels x kernel_size x kernel_size`` filters per
    sample. A guide branch produces ``region_num`` per-pixel scores. Each
    output pixel takes the result of the region with the highest score.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        kernel_size: (int) spatial size of the generated square kernels.
        region_num: number of regions ``r``.
        **kwargs: extra conv arguments (e.g. ``padding``, ``stride``). They are
            applied both to the guide convolution and to the dynamic
            correlation, so the two outputs have the same spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, region_num=8, **kwargs):
        super(DRConv2d, self).__init__()
        self.region_num = region_num

        self.conv_kernel = nn.Sequential(
            nn.AdaptiveAvgPool2d((kernel_size, kernel_size)),
            nn.Conv2d(in_channels, region_num * region_num, kernel_size=1),
            nn.Sigmoid(),
            # Grouped by region, so output channels are ordered region-major.
            nn.Conv2d(region_num * region_num, region_num * in_channels * out_channels,
                      kernel_size=1, groups=region_num),
        )
        self.conv_guide = nn.Conv2d(in_channels, region_num, kernel_size=kernel_size, **kwargs)

        self.corr = Correlation(use_slow=False)
        self.kwargs = kwargs
        self.asign_index = asign_index.apply

    def forward(self, input):
        kernel = self.conv_kernel(input)  # B x (r*in*out) x k x k
        kernel = kernel.view(kernel.size(0), -1, kernel.size(2), kernel.size(3))
        output = self.corr(input, kernel, **self.kwargs)  # B x (r*out) x H' x W'
        output = output.view(output.size(0), self.region_num, -1,
                             output.size(2), output.size(3))  # B x r x out x H' x W'
        guide_feature = self.conv_guide(input)  # B x r x H' x W'
        return self.asign_index(output, guide_feature)  # B x out x H' x W'


if __name__ == '__main__':
    B = 16
    in_channels = 256
    out_channels = 512
    size = 89
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    conv = DRConv2d(in_channels, out_channels, kernel_size=3, region_num=8).to(device)
    conv.train()
    x = torch.ones(B, in_channels, size, size, device=device)
    output = conv(x)
    print(x.shape, output.shape)

    # flops, params (thop is an optional dependency)
    try:
        from thop import profile
    except ImportError:
        print("thop is not installed; skipping FLOPs/params profiling")
    else:
        class PlainConv2d(nn.Module):
            """Static 3x3 convolution with the same channels, used as a baseline."""

            def __init__(self):
                super(PlainConv2d, self).__init__()
                self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3)

            def forward(self, inp):
                return self.conv(inp)

        conv2 = PlainConv2d().to(device)
        conv2.train()
        macs2, params2 = profile(conv2, inputs=(x, ))
        macs, params = profile(conv, inputs=(x, ))
        print(macs2, params2)
        print(macs, params)