├── README.md ├── local relation layer.PNG └── local_relation_layer.py /README.md: -------------------------------------------------------------------------------- 1 | ## Local Relation Networks for Image Recognition 2 | A PyTorch implementation of the local relation layer from Local Relation Networks for Image Recognition [[paper](https://arxiv.org/pdf/1904.11491.pdf)]. 3 | 4 | 5 | ![Local-Relational-Layer](local%20relation%20layer.PNG) 6 | 7 | ## Background 8 | This is an unofficial implementation of the Local Relation Layer. 9 | There is an earlier implementation, [local-relational-nets](https://github.com/gan3sh500/local-relational-nets), but it fails when imported. 10 | We therefore modified it and implemented a runnable version. 11 | 12 | 13 | ## To use the layer: 14 | ``` 15 | from local_relation_layer import LocalRelationalLayer 16 | 17 | layer = LocalRelationalLayer(channels=64,k=7,stride=1,m=8) 18 | ... 19 | output = layer(input) 20 | ``` 21 | 22 | ## Note: 23 | The paper does not specify how the 2 x k x k geometric-prior input is constructed, so we assume it consists of two k x k matrices: one holding each window position's offset along the x axis and the other its offset along the y axis (e.g. for k = 3, x offsets [[-1,0,1],[-1,0,1],[-1,0,1]] and y offsets [[1,1,1],[0,0,0],[-1,-1,-1]]). See the code for details. 
import torch


class GeometryPrior(torch.nn.Module):
    """Learned geometry prior over a k x k window.

    A two-layer MLP made of 1x1 convolutions maps the (dx, dy) offset of every
    position in a k x k window to ``channels`` logits.

    Args:
        k: window size. Assumed odd so that the window has a centre pixel;
            even values fail in ``forward``.
        channels: number of output channels (logit maps).
        multiplier: hidden-layer width relative to ``channels``.
    """

    def __init__(self, k, channels, multiplier=0.5):
        super(GeometryPrior, self).__init__()
        self.channels = channels
        self.k = k
        # Keep at least one hidden unit: int(0.5 * 1) == 0 would otherwise
        # collapse the MLP to a constant bias when channels == 1.
        hidden = max(1, int(multiplier * channels))
        self.l1 = torch.nn.Conv2d(2, hidden, 1)
        self.l2 = torch.nn.Conv2d(hidden, channels, 1)

    def forward(self):
        """Return geometry-prior logits of shape [1, channels, k, k]."""
        # The paper does not specify how the 2 x k x k position input is built;
        # we use one k x k map of x offsets and one k x k map of y offsets,
        # e.g. for k = 3:
        # [[[-1,0,1],[-1,0,1],[-1,0,1]], [[1,1,1],[0,0,0],[-1,-1,-1]]]
        weight = self.l1.weight
        r = self.k // 2
        # Create the offsets on the parameters' device/dtype so the module
        # works wherever it was moved (CPU even when CUDA exists, .half(), ...).
        offsets = torch.arange(-r, r + 1, device=weight.device, dtype=weight.dtype)
        x_position = offsets.view(1, -1).expand(self.k, -1)          # columns: -r..r
        y_position = offsets.flip(0).view(-1, 1).expand(-1, self.k)  # rows: r..-r
        position = torch.stack((x_position, y_position), 0).unsqueeze(0)  # [1, 2, k, k]
        return self.l2(torch.nn.functional.relu(self.l1(position)))


class KeyQueryMap(torch.nn.Module):
    """1x1 convolution projecting C input channels to C // m channels."""

    def __init__(self, channels, m):
        super(KeyQueryMap, self).__init__()
        self.l = torch.nn.Conv2d(channels, channels // m, 1)

    def forward(self, x):
        """Project x of shape [N, C, H, W] to [N, C // m, H, W]."""
        return self.l(x)


class AppearanceComposability(torch.nn.Module):
    """Per-channel product between every key in a window and the window's centre query.

    Args:
        k: window size (assumed odd; the centre is at flat index k*k // 2).
        padding: zero padding passed to ``torch.nn.Unfold``.
        stride: stride passed to ``torch.nn.Unfold``.
    """

    def __init__(self, k, padding, stride):
        super(AppearanceComposability, self).__init__()
        self.k = k
        self.unfold = torch.nn.Unfold(k, 1, padding, stride)

    def forward(self, x):
        """Compute appearance logits.

        Args:
            x: tuple ``(key_map, query_map)``, each of shape [N, C, H, W].

        Returns:
            Tensor of shape [N, C, L, k, k], where L = H_out * W_out is the
            number of sliding-window positions.
        """
        key_map, query_map = x
        n, c = key_map.shape[0], key_map.shape[1]
        kk = self.k * self.k
        # Unfold yields [N, C*k*k, L] in channel-major order, so it can be
        # viewed as [N, C, k*k, L] and transposed to [N, C, L, k*k].
        key_unfold = self.unfold(key_map).view(n, c, kk, -1).transpose(2, 3)
        query_unfold = self.unfold(query_map).view(n, query_map.shape[1], kk, -1).transpose(2, 3)
        # Keep the centre element of each query window as [N, C, L, 1] so it
        # broadcasts against every key position in the window.
        query_center = query_unfold[..., kk // 2:kk // 2 + 1]
        return (key_unfold * query_center).reshape(n, c, -1, self.k, self.k)
def combine_prior(appearance_kernel, geometry_kernel):
    """Combine appearance and geometry logits into aggregation weights.

    Both inputs must end in a (k, k) window and broadcast against each other.
    The softmax is taken jointly over all k*k window positions, so the weights
    of every window sum to 1.

    Returns:
        Tensor with the broadcast shape of the two inputs.
    """
    logits = appearance_kernel + geometry_kernel
    # Normalise over the whole k x k window. A softmax over dim=-1 alone would
    # make each window row sum to 1 independently.
    weights = torch.nn.functional.softmax(logits.flatten(-2), dim=-1)
    return weights.reshape(logits.shape)


class LocalRelationalLayer(torch.nn.Module):
    """Local relation layer from "Local Relation Networks for Image Recognition".

    Args:
        channels: number of input (and output) channels C.
        k: size of the local window; must be odd.
        stride: stride of the sliding window.
        padding: zero padding applied on each spatial side.
        m: channel-sharing factor, giving C // m aggregation kernels. A falsy
            value (the default None) means 8.

    Raises:
        ValueError: if ``k`` is even or ``channels`` is not divisible by ``m``.
    """

    def __init__(self, channels, k, stride=1, padding=0, m=None):
        super(LocalRelationalLayer, self).__init__()
        self.channels = channels
        self.k = k
        self.stride = stride
        self.m = m or 8
        self.padding = padding
        if k % 2 != 1:
            raise ValueError(f"k must be odd, got {k}")
        if channels % self.m != 0:
            raise ValueError(f"channels ({channels}) must be divisible by m ({self.m})")
        self.kmap = KeyQueryMap(channels, self.m)
        self.qmap = KeyQueryMap(channels, self.m)
        self.ac = AppearanceComposability(k, self.padding, self.stride)
        # Use the resolved self.m: with the default m=None, `channels // m` raises.
        self.gp = GeometryPrior(k, channels // self.m)
        self.unfold = torch.nn.Unfold(k, 1, self.padding, self.stride)
        self.final1x1 = torch.nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Apply the layer to x of shape [N, C, H, W]; returns [N, C, H_out, W_out]."""
        n, c, h, w = x.shape
        kk = self.k * self.k
        km = self.kmap(x)  # [N, C/m, H, W]
        qm = self.qmap(x)  # [N, C/m, H, W]
        ak = self.ac((km, qm))  # [N, C/m, L, k, k], L = H_out * W_out
        gpk = self.gp()  # [1, C/m, k, k]
        # [N, 1, C/m, L, k, k]: the singleton axis broadcasts each kernel over
        # the m channels that share it.
        ck = combine_prior(ak, gpk.unsqueeze(2)).unsqueeze(1)
        # Unfold is channel-major: [N, C*k*k, L] -> [N, C, L, k*k].
        x_unfold = self.unfold(x).view(n, c, kk, -1).transpose(2, 3)
        # Channel index i * (C/m) + j is aggregated with kernel j.
        x_unfold = x_unfold.reshape(n, self.m, c // self.m, -1, self.k, self.k)
        h_out = (h + 2 * self.padding - self.k) // self.stride + 1
        w_out = (w + 2 * self.padding - self.k) // self.stride + 1
        pre_output = (ck * x_unfold).sum(dim=(-2, -1))  # [N, m, C/m, L]
        pre_output = pre_output.reshape(n, c, h_out, w_out)  # [N, C, H_out, W_out]
        return self.final1x1(pre_output)


if __name__ == '__main__':
    # Example: run the layer on whichever device is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    layer = LocalRelationalLayer(channels=64, k=7, stride=1, padding=0, m=8).to(device)
    sample = torch.zeros(2, 64, 19, 19, device=device)
    output = layer(sample)
    print(output.shape)  # torch.Size([2, 64, 13, 13])
--------------------------------------------------------------------------------