"""PyTorch implementation of the Global Reasoning unit (GloRe).

Implements the unit described in "Graph-Based Global Reasoning Networks":
https://research.fb.com/wp-content/uploads/2019/05/Graph-Based-Global-Reasoning-Networks.pdf

The repository README shows a visualization of the B matrix for some MNIST
examples, in which the columns correspond to the N different features:
https://raw.githubusercontent.com/marekjg/glore/710752f5b7bad78e717aa828b1270a4f1f1e95b0/b_matrix.png
"""
from torch import nn


class GCN(nn.Module):
    """Graph convolution built from two 1x1 ``Conv1d`` layers.

    The input has shape ``(batch, dim_1_channels, dim_2_channels)``.
    ``conv1d_1`` mixes along the second axis (in ``GloRe`` this is the N graph
    nodes, i.e. a learned adjacency). ``conv1d_2`` then mixes along the third
    axis (the per-node state channels).

    NOTE(review): the paper writes the graph convolution as
    ``Z = ((I - A_g) V) W_g``. This module has no identity (residual) term and
    no nonlinearity between the two convolutions; confirm that the deviation
    is intended.

    Args:
        dim_1_channels: size of the second input axis (number of nodes in GloRe).
        dim_2_channels: size of the third input axis (state channels in GloRe).
    """

    def __init__(self, dim_1_channels, dim_2_channels):
        super().__init__()

        self.conv1d_1 = nn.Conv1d(dim_1_channels, dim_1_channels, 1)
        self.conv1d_2 = nn.Conv1d(dim_2_channels, dim_2_channels, 1)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        # Mix along axis 1, then move axis 2 into the channel slot for conv1d_2.
        h = self.conv1d_1(x).permute(0, 2, 1)
        # Restore the original axis order.
        return self.conv1d_2(h).permute(0, 2, 1)


class GloRe(nn.Module):
    """Global Reasoning unit: project to a graph, reason over it, project back.

    Args:
        in_channels: channel count of the input feature map.
        mid_channels: number of state channels per graph node.
        N: number of graph nodes.
    """

    def __init__(self, in_channels, mid_channels, N):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.N = N

        # 1x1 conv reducing the input to mid_channels features per pixel.
        self.phi = nn.Conv2d(in_channels, mid_channels, 1)
        # 1x1 conv producing the N pixel-to-node projection maps (rows of B).
        self.theta = nn.Conv2d(in_channels, N, 1)
        self.gcn = GCN(N, mid_channels)
        # 1x1 conv mapping the reprojected features back to in_channels.
        self.phi_inv = nn.Conv2d(mid_channels, in_channels, 1)

    def forward(self, x):
        """Apply the unit and add its output to ``x`` as a residual.

        Args:
            x: 4-D tensor ``(batch, in_channels, h, w)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, _, h, w = x.shape
        num_pixels = h * w

        # Explicit sizes (not -1): -1 cannot be inferred when batch_size == 0.
        B = self.theta(x).view(batch_size, self.N, num_pixels)
        x_reduced = self.phi(x).view(batch_size, self.mid_channels, num_pixels)

        # Project pixel features onto the nodes: (batch, N, mid_channels).
        v = B.bmm(x_reduced.permute(0, 2, 1))
        z = self.gcn(v)

        # Reverse projection with the same B, computed as z^T @ B so the result
        # (batch, mid_channels, h * w) is already contiguous for the view.
        y = z.permute(0, 2, 1).bmm(B)
        y = y.view(batch_size, self.mid_channels, h, w)
        return x + self.phi_inv(y)