├── LICENSE ├── README.md ├── deform_conv_v2.py ├── scaled_mnist ├── __pycache__ │ ├── archs.cpython-36.pyc │ └── dataset.cpython-36.pyc ├── archs.py └── dataset.py ├── scaled_mnist_train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Takato Kimura 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of Deformable ConvNets v2 2 | 3 | This repository contains code for **Deformable ConvNets v2 (Modulated Deformable Convolution)** based on [Deformable ConvNets v2: More Deformable, Better Results 4 | ](https://arxiv.org/abs/1811.11168) implemented in PyTorch. 
This implementation of deformable convolution is based on [ChunhuanLin/deform_conv_pytorch](https://github.com/ChunhuanLin/deform_conv_pytorch), thanks to ChunhuanLin. 5 | 6 | ## TODO 7 | - [x] Initialize weight of modulated deformable convolution as described in the paper 8 | - [x] Learning rates of offset and modulation are set to different values from other layers 9 | - [x] Results of ScaledMNIST experiments 10 | - [x] Support different stride 11 | - [ ] Support deformable group 12 | - [ ] DeepLab + DCNv2 13 | - [ ] Results of VOC segmentation experiments 14 | 15 | ## Requirements 16 | - Python 3.6 17 | - PyTorch 1.0 18 | 19 | ## Usage 20 | Replace a regular convolution (`conv2` in the model below) with modulated deformable convolution: 21 | ```python 22 | class ConvNet(nn.Module): 23 | def __init__(self): 24 | self.relu = nn.ReLU(inplace=True) 25 | self.pool = nn.MaxPool2d((2, 2)) 26 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 27 | 28 | self.conv1 = nn.Conv2d(3, 32, 3, padding=1) 29 | self.bn1 = nn.BatchNorm2d(32) 30 | self.conv2 = DeformConv2d(32, 64, 3, padding=1, modulation=True) 31 | self.bn2 = nn.BatchNorm2d(64) 32 | 33 | self.fc = nn.Linear(64, 10) 34 | 35 | def forward(self, x): 36 | x = self.relu(self.bn1(self.conv1(x))) 37 | x = self.pool(x) 38 | x = self.relu(self.bn2(self.conv2(x))) 39 | 40 | x = self.avg_pool(x) 41 | x = x.view(x.shape[0], -1) 42 | x = self.fc(x) 43 | 44 | return x 45 | ``` 46 | 47 | ## Training 48 | ### ScaledMNIST 49 | ScaledMNIST is MNIST with each image randomly rescaled. 
50 | 51 | Use modulated deformable convolution at conv3~4: 52 | ``` 53 | python train.py --arch ScaledMNISTNet --deform True --modulation True --min-deform-layer 3 54 | ``` 55 | Use deformable convolution at conv3~4: 56 | ``` 57 | python train.py --arch ScaledMNISTNet --deform True --modulation False --min-deform-layer 3 58 | ``` 59 | Use only regular convolution: 60 | ``` 61 | python train.py --arch ScaledMNISTNet --deform False --modulation False 62 | ``` 63 | 64 | ## Results 65 | ### ScaledMNIST 66 | | Model | Accuracy (%) | Loss | 67 | |:------------------------|:-----------------:|:--------:| 68 | | w/o DCN | 97.22 | 0.113| 69 | | w/ DCN @conv4 | 98.60 | 0.049| 70 | | w/ DCN @conv3~4 | 98.95 | 0.035| 71 | | w/ DCNv2 @conv4 | 98.45 | 0.058| 72 | | w/ DCNv2 @conv3~4 | **99.21** | **0.027**| 73 | 74 | 75 | 76 | # 从https://github.com/4uiiurz1/pytorch-deform-conv-v2 克隆过来 77 | 78 | 79 | 80 | 添加了DCN的关键注释,记得配合【GiantPandaCV】公众号的《再思考可变形卷积》一文食用。 81 | 82 | -------------------------------------------------------------------------------- /deform_conv_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DeformConv2d(nn.Module): 6 | # inc表示输入通道数 7 | # outc 表示输出通道数 8 | # kernel_size表示卷积核尺寸 9 | # stride 卷积核滑动步长 10 | # bias 偏置 11 | # modulation DCNV1还是DCNV2的开关 12 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False): 13 | """ 14 | Args: 15 | modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2). 
16 | """ 17 | super(DeformConv2d, self).__init__() 18 | self.kernel_size = kernel_size 19 | self.padding = padding 20 | self.stride = stride 21 | self.zero_padding = nn.ZeroPad2d(padding) 22 | # 普通的卷积层,即获得了偏移量之后的特征图再接一个普通卷积 23 | self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) 24 | # 获得偏移量,卷积核的通道数应该为2xkernel_sizexkernel_size 25 | self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 26 | # 偏移量初始化为0 27 | nn.init.constant_(self.p_conv.weight, 0) 28 | # 注册module反向传播的hook函数, 可以查看当前层参数的梯度 29 | self.p_conv.register_backward_hook(self._set_lr) 30 | # 将modulation赋值给当前类 31 | self.modulation = modulation 32 | if modulation: 33 | # 如果是DCN V2,还多了一个权重参数,用m_conv来表示 34 | self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 35 | nn.init.constant_(self.m_conv.weight, 0) 36 | # 注册module反向传播的hook函数, 可以查看当前层参数的梯度 37 | self.m_conv.register_backward_hook(self._set_lr) 38 | 39 | # 静态方法 类或实例均可调用,这函数的结合hook可以输出你想要的Variable的梯度 40 | @staticmethod 41 | def _set_lr(module, grad_input, grad_output): 42 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) 43 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) 44 | 45 | # 前向传播函数 46 | def forward(self, x): 47 | # 获得输入特征图x的偏移量 48 | # 假设输入特征图shape是[1,3,32,32],然后卷积核是3x3, 49 | # 输出通道数为32,那么offset的shape是[1,2*3*3,32] 50 | offset = self.p_conv(x) 51 | # 如果是DCN V2那么还需要获得输入特征图x偏移量的权重项 52 | # 假设输入特征图shape是[1,3,32,32],然后卷积核是3x3, 53 | # 输出通道数为32,那么offset的权重shape是[1,3*3,32] 54 | if self.modulation: 55 | m = torch.sigmoid(self.m_conv(x)) 56 | # dtype = torch.float32 57 | dtype = offset.data.type() 58 | # 卷积核尺寸大小 59 | ks = self.kernel_size 60 | # N=2*3*3/2=3*3=9 61 | N = offset.size(1) // 2 62 | # 如果需要Padding就先Padding 63 | if self.padding: 64 | x = self.zero_padding(x) 65 | 66 | # p的shape为(b, 2N, h, w) 67 | # 这个函数用来获取所有的卷积核偏移之后相对于原始特征图x的坐标(现在是浮点数) 68 | p = self._get_p(offset, dtype) 69 | 70 | # 
我们学习出的量是float类型的,而像素坐标都是整数类型的, 71 | # 所以我们还要用双线性插值的方法去推算相应的值 72 | # 维度转换,现在p的维度为(b, h, w, 2N) 73 | p = p.contiguous().permute(0, 2, 3, 1) 74 | # floor是向下取整 75 | q_lt = p.detach().floor() 76 | # +1相当于原始坐标向上取整 77 | q_rb = q_lt + 1 78 | # 将q_lt即左上角坐标的值限制在图像范围内 79 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() 80 | # 将q_rb即右下角坐标的值限制在图像范围内 81 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() 82 | # 用q_lt的前半部分坐标q_lt_x和q_rb的后半部分q_rb_y组合成q_lb 83 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) 84 | # 同理 85 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 86 | 87 | # 对p的坐标也要限制在图像范围内 88 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1) 89 | 90 | # bilinear kernel (b, h, w, N) 91 | # 双线性插值的4个系数 92 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 93 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 94 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 95 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 96 | 97 | # (b, c, h, w, N) 98 | # 现在只获取了坐标值,我们最终木的是获取相应坐标上的值, 99 | # 这里我们通过self._get_x_q()获取相应值。 100 | x_q_lt = self._get_x_q(x, q_lt, N) 101 | x_q_rb = self._get_x_q(x, q_rb, N) 102 | x_q_lb = self._get_x_q(x, q_lb, N) 103 | x_q_rt = self._get_x_q(x, q_rt, N) 104 | 105 | # (b, c, h, w, N) 106 | # 双线性插值计算 107 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 108 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 109 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 110 | g_rt.unsqueeze(dim=1) * x_q_rt 111 | 112 | # modulation 113 | if self.modulation: 114 | m = m.contiguous().permute(0, 2, 3, 1) 115 | m = m.unsqueeze(dim=1) 116 | m = torch.cat([m for _ in 
range(x_offset.size(1))], dim=1) 117 | x_offset *= m 118 | 119 | # 在获取所有值后我们计算出x_offset,但是x_offset的size 120 | # 是(b,c,h,w,N),我们的目的是将最终的输出结果的size变 121 | # 成和x一致即(b,c,h,w),所以在最后用了一个reshape的操作。 122 | # 这里ks=3 123 | x_offset = self._reshape_x_offset(x_offset, ks) 124 | out = self.conv(x_offset) 125 | 126 | return out 127 | 128 | # 通过函数_get_p_n生成了卷积的相对坐标,其中卷积的中心点被看成原点 129 | # 然后其它点的坐标都是相对于原点来说的,例如self.kernel_size=3,通 130 | # 过torch.meshgrid生成从(-1,-1)到(1,1)9个坐标。将坐标的x和y 131 | # 分别存储,然后再将x,y以(1,2N,1,1)的形式返回,这样我们就获取了一 132 | # 个卷积核的所有相对坐标。 133 | def _get_p_n(self, N, dtype): 134 | # p_n_x = tensor([[-1, -1, -1], 135 | # [ 0, 0, 0], 136 | # [ 1, 1, 1]]) 137 | # p_n_y = tensor([[-1, 0, 1], 138 | # [-1, 0, 1], 139 | # [-1, 0, 1]]) 140 | p_n_x, p_n_y = torch.meshgrid( 141 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), 142 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) 143 | # (2N, 1) 144 | # p_n = tensor([-1, -1, -1, 0, 0, 0, 1, 1, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1]) 145 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) 146 | # p_n.shape=[1,2*N,1,1] 147 | p_n = p_n.view(1, 2*N, 1, 1).type(dtype) 148 | 149 | return p_n 150 | 151 | # 获取卷积核在特征图上对应的中心坐标,也即论文公式中的p_0 152 | # 通过torch.mershgrid生成所有的中心坐标,然后通过kernel_size 153 | # 推断初始坐标,然后通过stride推断所有的中心坐标,这里注意一下, 154 | # 代码默认torch.arange从1开始,实际上这是kernel_size为3时的情况, 155 | # 严谨一点torch.arrange应该从kernel_size//2开始,这个实现只适合3x3的卷积。 156 | def _get_p_0(self, h, w, N, dtype): 157 | # 设w = 7, h = 5, stride = 1 158 | # 有p_0_x = tensor([[1, 1, 1, 1, 1, 1, 1], 159 | # [2, 2, 2, 2, 2, 2, 2], 160 | # [3, 3, 3, 3, 3, 3, 3], 161 | # [4, 4, 4, 4, 4, 4, 4], 162 | # [5, 5, 5, 5, 5, 5, 5]]) 163 | # p_0_x.shape = [5, 7] 164 | # p_0_y = tensor([[1, 2, 3, 4, 5, 6, 7], 165 | # [1, 2, 3, 4, 5, 6, 7], 166 | # [1, 2, 3, 4, 5, 6, 7], 167 | # [1, 2, 3, 4, 5, 6, 7], 168 | # [1, 2, 3, 4, 5, 6, 7]]) 169 | # p_0_y.shape = [5, 7] 170 | p_0_x, p_0_y = torch.meshgrid( 171 | torch.arange(1, h*self.stride+1, 
self.stride), 172 | torch.arange(1, w*self.stride+1, self.stride)) 173 | # p_0_x的shape为torch.Size([1, 9, 5, 7]) 174 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) 175 | # p_0_y的shape为torch.Size([1, 9, 5, 7]) 176 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) 177 | # p_0的shape为torch.Size([1, 18, 5, 7]) 178 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) 179 | 180 | return p_0 181 | 182 | def _get_p(self, offset, dtype): 183 | # N = 18 / 2 = 9 184 | # h = 32 185 | # w = 32 186 | N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) 187 | 188 | # (1, 2N, 1, 1) 189 | p_n = self._get_p_n(N, dtype) 190 | # (1, 2N, h, w) 191 | p_0 = self._get_p_0(h, w, N, dtype) 192 | # 卷积坐标加上之前学习出的offset后就是论文提出的公式(2)也就是加上了偏置后的卷积操作。 193 | # 比如p(在N=0时)p_0就是中心坐标,而p_n=(-1,-1),所以此时的p就是卷积核中心坐标加上 194 | # (-1,-1)(即红色块左上方的块)再加上offset。同理可得N=1,N=2...分别代表了一个卷积核 195 | # 上各个元素。 196 | p = p_0 + p_n + offset 197 | return p 198 | 199 | # 通过self._get_x_q()获取偏移坐标对应的值 200 | # 201 | def _get_x_q(self, x, q, N): 202 | # 输入x是我们最早输入的数据x,q则是我们的坐标信息。 203 | # 首先我们获取q的相关尺寸信息(b,h,w,2N),再获取x 204 | # 的w保存在padding_w中,将x(b,c,h,w)通过view变成(b,c,h*w)。 205 | # 这样子就把x的坐标信息压缩在了最后一个维度(h*w),这样做的 206 | # 目的是为了使用tensor.gather()通过坐标来获取相应值。 207 | # (这里注意下q的h,w和x的h,w不一定相同,比如stride不为1的时候) 208 | b, h, w, _ = q.size() 209 | padded_w = x.size(3) 210 | c = x.size(1) 211 | # (b, c, h*w) 212 | x = x.contiguous().view(b, c, -1) 213 | # 同样地,由于(h,w)被压缩成了(h*w)所以在这个维度上,每过w个 214 | # 元素,就代表了一行,所以我们的坐标index=offset_x*w+offset_y 215 | # (这样就能在h*w上找到(h,w)的相应坐标)同时再把偏移expand()到 216 | # 每一个通道最后返回x_offset(b,c,h,w,N)。(最后输出x_offset的 217 | # h,w指的是x的h,w而不是q的) 218 | 219 | # (b, h, w, N) 220 | index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y 221 | # (b, c, h*w*N) 222 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 223 | 224 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 225 | 226 | return x_offset 227 | 228 | 
@staticmethod 229 | def _reshape_x_offset(x_offset, ks): 230 | # 函数首先获取了x_offset的所有size信息,然后以kernel_size为 231 | # 单位进行reshape,因为N=kernel_size*kernel_size,所以我们 232 | # 分两次进行reshape,第一次先把输入view成(b,c,h,ks*w,ks), 233 | # 第二次再view将size变成(b,c,h*ks,w*ks) 234 | 235 | b, c, h, w, N = x_offset.size() 236 | x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1) 237 | x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks) 238 | 239 | return x_offset 240 | -------------------------------------------------------------------------------- /scaled_mnist/__pycache__/archs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBuf/pytorch-deform-conv-v2-explain/9c6970583c405b138949c2a035ca4d7093098cd5/scaled_mnist/__pycache__/archs.cpython-36.pyc -------------------------------------------------------------------------------- /scaled_mnist/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBuf/pytorch-deform-conv-v2-explain/9c6970583c405b138949c2a035ca4d7093098cd5/scaled_mnist/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /scaled_mnist/archs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import torch 7 | from torchvision import models 8 | import torchvision 9 | 10 | from deform_conv_v2 import * 11 | 12 | 13 | class ScaledMNISTNet(nn.Module): 14 | def __init__(self, args, num_classes): 15 | super().__init__() 16 | 17 | self.relu = nn.ReLU(inplace=True) 18 | self.sigmoid = nn.Sigmoid() 19 | self.pool = nn.MaxPool2d((2, 2)) 20 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 21 | 22 | features = [] 23 | inplanes = 1 24 | outplanes = 
32 25 | for i in range(4): 26 | if args.deform and args.min_deform_layer <= i+1: 27 | features.append(DeformConv2d(inplanes, outplanes, 3, padding=1, bias=False, modulation=args.modulation)) 28 | else: 29 | features.append(nn.Conv2d(inplanes, outplanes, 3, padding=1, bias=False)) 30 | features.append(nn.BatchNorm2d(outplanes)) 31 | features.append(self.relu) 32 | if i == 1: 33 | features.append(self.pool) 34 | inplanes = outplanes 35 | outplanes *= 2 36 | self.features = nn.Sequential(*features) 37 | 38 | self.fc = nn.Linear(256, 10) 39 | 40 | def forward(self, input): 41 | x = self.features(input) 42 | x = self.avg_pool(x) 43 | x = x.view(x.shape[0], -1) 44 | output = self.fc(x) 45 | 46 | return output 47 | -------------------------------------------------------------------------------- /scaled_mnist/dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from torch.utils.data import Dataset 3 | from matplotlib import pyplot as plt 4 | import cv2 5 | import numpy as np 6 | import random 7 | import scipy.ndimage as ndi 8 | from tqdm import tqdm 9 | import os 10 | from PIL import Image 11 | 12 | 13 | class ScaledMNIST(Dataset): 14 | def __init__(self, train=True, transform=None, target_transform=None): 15 | self.transform = transform 16 | self.target_transform = target_transform 17 | self.train = train # training set or test set 18 | 19 | if not os.path.exists('input/scaled_mnist_train.npz'): 20 | train_dataset = datasets.MNIST('~/data', train=True, download=True) 21 | train_imgs, train_labels = train_dataset.train_data.numpy(), train_dataset.train_labels.numpy() 22 | 23 | scaled_train_imgs = [] 24 | for i in tqdm(range(len(train_imgs))): 25 | img = np.pad(train_imgs[i], 14, 'constant') 26 | img = random_zoom(img[:,:,np.newaxis], (0.5, 1.5)) 27 | scaled_train_imgs.append(img[:,:,0]) 28 | scaled_train_imgs = np.array(scaled_train_imgs) 29 | 30 | 
np.savez('input/scaled_mnist_train.npz', images=scaled_train_imgs, labels=train_labels) 31 | 32 | if not os.path.exists('input/scaled_mnist_test.npz'): 33 | test_dataset = datasets.MNIST('~/data', train=False, download=True) 34 | test_imgs, test_labels = test_dataset.test_data.numpy(), test_dataset.test_labels.numpy() 35 | 36 | scaled_test_imgs = [] 37 | for i in tqdm(range(len(test_imgs))): 38 | img = np.pad(test_imgs[i], 14, 'constant') 39 | img = random_zoom(img[:,:,np.newaxis], (0.5, 1.5)) 40 | scaled_test_imgs.append(img[:,:,0]) 41 | scaled_test_imgs = np.array(scaled_test_imgs) 42 | 43 | np.savez('input/scaled_mnist_test.npz', images=scaled_test_imgs, labels=test_labels) 44 | 45 | if self.train: 46 | scaled_mnist_train = np.load('input/scaled_mnist_train.npz') 47 | self.train_data = scaled_mnist_train['images'] 48 | self.train_labels = scaled_mnist_train['labels'] 49 | else: 50 | scaled_mnist_test = np.load('input/scaled_mnist_test.npz') 51 | self.test_data = scaled_mnist_test['images'] 52 | self.test_labels = scaled_mnist_test['labels'] 53 | 54 | def __getitem__(self, index): 55 | """ 56 | Args: 57 | index (int): Index 58 | 59 | Returns: 60 | tuple: (image, target) where target is index of the target class. 
61 | """ 62 | if self.train: 63 | img, target = self.train_data[index], self.train_labels[index] 64 | else: 65 | img, target = self.test_data[index], self.test_labels[index] 66 | 67 | # doing this so that it is consistent with all other datasets 68 | # to return a PIL Image 69 | img = Image.fromarray(img, mode='L') 70 | 71 | if self.transform is not None: 72 | img = self.transform(img) 73 | 74 | if self.target_transform is not None: 75 | target = self.target_transform(target) 76 | 77 | return img, target 78 | 79 | def __len__(self): 80 | if self.train: 81 | return len(self.train_data) 82 | else: 83 | return len(self.test_data) 84 | 85 | 86 | def transform_matrix_offset_center(matrix, x, y): 87 | o_x = float(x) / 2 + 0.5 88 | o_y = float(y) / 2 + 0.5 89 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 90 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 91 | transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) 92 | return transform_matrix 93 | 94 | 95 | def apply_transform(x, transform_matrix, channel_axis=0, 96 | fill_mode='nearest', cval=0.): 97 | x = np.rollaxis(x, channel_axis, 0) 98 | final_affine_matrix = transform_matrix[:2, :2] 99 | final_offset = transform_matrix[:2, 2] 100 | channel_images = [ndi.interpolation.affine_transform( 101 | x_channel, 102 | final_affine_matrix, 103 | final_offset, 104 | order=0, 105 | mode=fill_mode, 106 | cval=cval) for x_channel in x] 107 | x = np.stack(channel_images, axis=0) 108 | x = np.rollaxis(x, 0, channel_axis + 1) 109 | return x 110 | 111 | 112 | def random_zoom(X, zoom_range, row_axis=0, col_axis=1, channel_axis=2, 113 | fill_mode='nearest', cval=0.): 114 | if len(zoom_range) != 2: 115 | raise ValueError('`zoom_range` should be a tuple or list of two floats. 
' 116 | 'Received arg: ', zoom_range) 117 | 118 | z = np.random.uniform(zoom_range[0], zoom_range[1]) 119 | zoom_matrix = np.array([[z, 0, 0], 120 | [0, z, 0], 121 | [0, 0, 1]]) 122 | 123 | h, w = X.shape[row_axis], X.shape[col_axis] 124 | transform_matrix = transform_matrix_offset_center(zoom_matrix, h, w) 125 | X = apply_transform(X, transform_matrix, channel_axis, fill_mode, cval) 126 | 127 | return X 128 | -------------------------------------------------------------------------------- /scaled_mnist_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from tqdm import tqdm 5 | import pandas as pd 6 | import joblib 7 | from collections import OrderedDict 8 | from datetime import datetime 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | 17 | from utils import * 18 | from scaled_mnist.dataset import ScaledMNIST 19 | import scaled_mnist.archs as archs 20 | 21 | arch_names = archs.__dict__.keys() 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--name', default=None, 28 | help='model name: (default: arch+timestamp)') 29 | parser.add_argument('--arch', '-a', metavar='ARCH', default='ScaledMNISTNet', 30 | choices=arch_names, 31 | help='model architecture: ' + 32 | ' | '.join(arch_names) + 33 | ' (default: ScaledMNISTNet)') 34 | parser.add_argument('--deform', default=True, type=str2bool, 35 | help='use deform conv') 36 | parser.add_argument('--modulation', default=True, type=str2bool, 37 | help='use modulated deform conv') 38 | parser.add_argument('--min-deform-layer', default=3, type=int, 39 | help='minimum number of layer using deform conv') 40 | parser.add_argument('--epochs', default=10, type=int, metavar='N', 41 | help='number of total 
epochs to run') 42 | parser.add_argument('--optimizer', default='SGD', 43 | choices=['Adam', 'SGD'], 44 | help='loss: ' + 45 | ' | '.join(['Adam', 'SGD']) + 46 | ' (default: Adam)') 47 | parser.add_argument('--lr', '--learning-rate', default=1e-2, type=float, 48 | metavar='LR', help='initial learning rate') 49 | parser.add_argument('--momentum', default=0.5, type=float, 50 | help='momentum') 51 | parser.add_argument('--weight-decay', default=1e-4, type=float, 52 | help='weight decay') 53 | parser.add_argument('--nesterov', default=False, type=str2bool, 54 | help='nesterov') 55 | 56 | args = parser.parse_args() 57 | 58 | return args 59 | 60 | 61 | def train(args, train_loader, model, criterion, optimizer, epoch, scheduler=None): 62 | losses = AverageMeter() 63 | scores = AverageMeter() 64 | 65 | model.train() 66 | 67 | for i, (input, target) in tqdm(enumerate(train_loader), total=len(train_loader)): 68 | input = input.cuda() 69 | target = target.cuda() 70 | 71 | output = model(input) 72 | loss = criterion(output, target) 73 | 74 | acc = accuracy(output, target)[0] 75 | 76 | losses.update(loss.item(), input.size(0)) 77 | scores.update(acc.item(), input.size(0)) 78 | 79 | # compute gradient and do optimizing step 80 | optimizer.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | log = OrderedDict([ 85 | ('loss', losses.avg), 86 | ('acc', scores.avg), 87 | ]) 88 | 89 | return log 90 | 91 | 92 | def validate(args, val_loader, model, criterion): 93 | losses = AverageMeter() 94 | scores = AverageMeter() 95 | 96 | # switch to evaluate mode 97 | model.eval() 98 | 99 | with torch.no_grad(): 100 | for i, (input, target) in tqdm(enumerate(val_loader), total=len(val_loader)): 101 | input = input.cuda() 102 | target = target.cuda() 103 | 104 | output = model(input) 105 | loss = criterion(output, target) 106 | 107 | acc = accuracy(output, target)[0] 108 | 109 | losses.update(loss.item(), input.size(0)) 110 | scores.update(acc.item(), input.size(0)) 111 | 112 | log = 
OrderedDict([ 113 | ('loss', losses.avg), 114 | ('acc', scores.avg), 115 | ]) 116 | 117 | return log 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | 123 | if args.name is None: 124 | args.name = '%s' %args.arch 125 | if args.deform: 126 | args.name += '_wDCN' 127 | if args.modulation: 128 | args.name += 'v2' 129 | args.name += '_c%d-4' %args.min_deform_layer 130 | 131 | if not os.path.exists('models/%s' %args.name): 132 | os.makedirs('models/%s' %args.name) 133 | 134 | print('Config -----') 135 | for arg in vars(args): 136 | print('%s: %s' %(arg, getattr(args, arg))) 137 | print('------------') 138 | 139 | with open('models/%s/args.txt' %args.name, 'w') as f: 140 | for arg in vars(args): 141 | print('%s: %s' %(arg, getattr(args, arg)), file=f) 142 | 143 | joblib.dump(args, 'models/%s/args.pkl' %args.name) 144 | 145 | criterion = nn.CrossEntropyLoss().cuda() 146 | 147 | cudnn.benchmark = True 148 | 149 | # data loading code 150 | transform_train = transforms.Compose([ 151 | transforms.ToTensor(), 152 | transforms.Normalize((0.1307,), (0.3081,)) 153 | ]) 154 | 155 | transform_test = transforms.Compose([ 156 | transforms.ToTensor(), 157 | transforms.Normalize((0.1307,), (0.3081,)) 158 | ]) 159 | 160 | train_set = ScaledMNIST( 161 | train=True, 162 | transform=transform_train) 163 | train_loader = torch.utils.data.DataLoader( 164 | train_set, 165 | batch_size=32, 166 | shuffle=True, 167 | num_workers=8) 168 | 169 | test_set = ScaledMNIST( 170 | train=False, 171 | transform=transform_train) 172 | test_loader = torch.utils.data.DataLoader( 173 | test_set, 174 | batch_size=32, 175 | shuffle=False, 176 | num_workers=8) 177 | 178 | num_classes = 10 179 | 180 | # create model 181 | model = archs.__dict__[args.arch](args, num_classes) 182 | model = model.cuda() 183 | 184 | print(model) 185 | 186 | if args.optimizer == 'Adam': 187 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) 188 | elif args.optimizer == 'SGD': 189 | 
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, 190 | momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov) 191 | 192 | log = pd.DataFrame(index=[], columns=[ 193 | 'epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc' 194 | ]) 195 | 196 | best_acc = 0 197 | for epoch in range(args.epochs): 198 | print('Epoch [%d/%d]' %(epoch, args.epochs)) 199 | 200 | # train for one epoch 201 | train_log = train(args, train_loader, model, criterion, optimizer, epoch) 202 | # evaluate on validation set 203 | val_log = validate(args, test_loader, model, criterion) 204 | 205 | print('loss %.4f - acc %.4f - val_loss %.4f - val_acc %.4f' 206 | %(train_log['loss'], train_log['acc'], val_log['loss'], val_log['acc'])) 207 | 208 | tmp = pd.Series([ 209 | epoch, 210 | 1e-1, 211 | train_log['loss'], 212 | train_log['acc'], 213 | val_log['loss'], 214 | val_log['acc'], 215 | ], index=['epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc']) 216 | 217 | log = log.append(tmp, ignore_index=True) 218 | log.to_csv('models/%s/log.csv' %args.name, index=False) 219 | 220 | if val_log['acc'] > best_acc: 221 | torch.save(model.state_dict(), 'models/%s/model.pth' %args.name) 222 | best_acc = val_log['acc'] 223 | print("=> saved best model") 224 | 225 | print("best val_acc: %f" %best_acc) 226 | 227 | 228 | if __name__ == '__main__': 229 | main() 230 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | from PIL import Image 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def str2bool(v): 10 | if v.lower() in ['true', 1]: 11 | return True 12 | elif v.lower() in ['false', 0]: 13 | return False 14 | else: 15 | raise argparse.ArgumentTypeError('Boolean value expected.') 16 | 17 | 18 | def count_params(model): 19 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the current value and all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the average.

        Args:
            val: Latest (per-sample average) value of the metric.
            n: Number of samples that ``val`` was averaged over.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An empty batch (n == 0) before any real update would otherwise
        # divide by zero; keep the previous average in that case.
        if self.count:
            self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: Class scores; ``topk`` is taken along dimension 1.
        target: Ground-truth class indices, one per row of ``output``.
        topk: Values of k to evaluate.

    Returns:
        A list with one 1-element tensor per k holding the top-k accuracy in
        percent.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout of
            # `pred`, so view(-1) raises for k > 1; reshape copies when needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res