├── loss_plot.png ├── .gitignore ├── README.md ├── test.py ├── pytorchtools.py ├── trainer.py └── CSPResNet.py /loss_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robintzeng/Pytorch-CSPNet/HEAD/loss_plot.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | runs 2 | data 3 | *.pth 4 | *.pyc 5 | *.pt 6 | __pycache__ 7 | checkpoint.pt 8 | resnet.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-CSPNet 2 | The network reaches about 80% accuracy on CIFAR-10 without any special training schedule. 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import torch.optim as optim 8 | from tqdm import tqdm 9 | 10 | from CSPResNet import csp_resnet152 11 | 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | print(device) 14 | 15 | #PATH = 'checkpoint76.pt' 16 | net = csp_resnet152(pretrained=False,num_classes = 10) 17 | 18 | net.load_state_dict(torch.load('checkpoint.pt', map_location=device)) 19 | net.to(device) 20 | #net.load_state_dict(torch.load(PATH)) 21 | 22 | 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), 26 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) # match the normalization used in trainer.py 27 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 28 | download=True, transform=transform) 29 | testloader = torch.utils.data.DataLoader(testset, batch_size=64, 30 | shuffle=False, num_workers=2) 31 | 32 | classes = ('plane', 'car', 'bird', 'cat', 33 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 34 | 35 | 36 | class_correct = list(0. for i in range(10)) 37 | class_total = list(0.
for i in range(10)) 38 | with torch.no_grad(): 39 | for data in tqdm(testloader): 40 | images, labels = data[0].to(device), data[1].to(device) 41 | outputs = net(images) 42 | _, predicted = torch.max(outputs, 1) 43 | c = (predicted == labels).squeeze() 44 | for i in range(labels.size(0)): # iterate over the whole batch, not just the first 4 samples 45 | label = labels[i] 46 | class_correct[label] += c[i].item() 47 | class_total[label] += 1 48 | 49 | 50 | for i in range(10): 51 | print('Accuracy of %5s : %2d %%' % ( 52 | classes[i], 100 * class_correct[i] / class_total[i])) 53 | 54 | correct = 0 55 | total = 0 56 | with torch.no_grad(): 57 | for data in testloader: 58 | images, labels = data[0].to(device), data[1].to(device) 59 | outputs = net(images) 60 | _, predicted = torch.max(outputs.data, 1) 61 | total += labels.size(0) 62 | correct += (predicted == labels).sum().item() 63 | 64 | print('Accuracy of the network on the 10000 test images: %d %%' % ( 65 | 100 * correct / total)) 66 | -------------------------------------------------------------------------------- /pytorchtools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class EarlyStopping: 5 | """Early stops the training if validation loss doesn't improve after a given patience.""" 6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 7 | """ 8 | Args: 9 | patience (int): How long to wait after last time validation loss improved. 10 | Default: 7 11 | verbose (bool): If True, prints a message for each validation loss improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | path (str): Path for the checkpoint to be saved to. 16 | Default: 'checkpoint.pt' 17 | trace_func (function): trace print function. 18 | Default: print 19 | """ 20 | self.patience = patience 21 | self.verbose = verbose 22 | self.counter = 0 23 | self.best_score = None 24 | self.early_stop = False 25 | self.val_loss_min = np.inf 26 | self.delta = delta 27 | self.path = path 28 | self.trace_func = trace_func 29 | def __call__(self, val_loss, model): 30 | 31 | score = -val_loss 32 | 33 | if self.best_score is None: 34 | self.best_score = score 35 | self.save_checkpoint(val_loss, model) 36 | elif score < self.best_score + self.delta: 37 | self.counter += 1 38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 39 | if self.counter >= self.patience: 40 | self.early_stop = True 41 | else: 42 | self.best_score = score 43 | self.save_checkpoint(val_loss, model) 44 | self.counter = 0 45 | 46 | def save_checkpoint(self, val_loss, model): 47 | '''Saves model when validation loss decreases.''' 48 | if self.verbose: 49 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).
Saving model ...') 50 | torch.save(model.state_dict(), self.path) 51 | self.val_loss_min = val_loss -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import torch.optim as optim 8 | import numpy as np 9 | from tqdm import tqdm 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | from pytorchtools import EarlyStopping 12 | from CSPResNet import csp_resnet152,csp_resnet50 13 | import matplotlib.pyplot as plt 14 | 15 | class AddGaussianNoise(object): 16 | def __init__(self, mean=0., std=1.): 17 | self.std = std 18 | self.mean = mean 19 | 20 | def __call__(self, tensor): 21 | return tensor + torch.randn(tensor.size()) * self.std + self.mean 22 | 23 | def __repr__(self): 24 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 25 | 26 | classes = ('plane', 'car', 'bird', 'cat', 27 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 28 | 29 | def train(net,n_epoches = 500, 30 | patience =20, 31 | valid_size = 0.2, 32 | batch_size = 64): 33 | 34 | if torch.cuda.is_available(): 35 | device = torch.device("cuda:0") 36 | torch.backends.cudnn.benchmark = True 37 | print(device) 38 | else: 39 | device = torch.device("cpu") 40 | print(device) 41 | 42 | net.to(device) 43 | 44 | 45 | transform_train = transforms.Compose([ 46 | #transforms.RandomCrop(32, padding=4), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 50 | AddGaussianNoise(0., 1.) 
51 | ]) 52 | 53 | transform_test = transforms.Compose([ 54 | transforms.ToTensor(), 55 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 56 | ]) 57 | 58 | 59 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 60 | download=True, transform=transform_train) 61 | 62 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 63 | download=True, transform=transform_test) 64 | 65 | num_train = len(trainset) 66 | indices = list(range(num_train)) 67 | np.random.shuffle(indices) 68 | split = int(np.floor(valid_size * num_train)) 69 | train_idx, valid_idx = indices[split:], indices[:split] 70 | 71 | train_sampler = SubsetRandomSampler(train_idx) 72 | valid_sampler = SubsetRandomSampler(valid_idx) 73 | 74 | 75 | 76 | trainloader = torch.utils.data.DataLoader(trainset, 77 | batch_size=batch_size, 78 | sampler = train_sampler, 79 | num_workers=0) 80 | 81 | validloader = torch.utils.data.DataLoader(trainset, 82 | batch_size=batch_size, 83 | sampler=valid_sampler, 84 | num_workers=0) 85 | 86 | 87 | testloader = torch.utils.data.DataLoader(testset, 88 | batch_size=batch_size, 89 | num_workers=0) 90 | 91 | 92 | criterion = nn.CrossEntropyLoss() 93 | optimizer = torch.optim.Adam(net.parameters(),lr = 0.001) 94 | 95 | #optimizer = torch.optim.SGD(net.parameters(), lr=0.01, 96 | # momentum=0.9, weight_decay=5e-4) 97 | 98 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40,70], gamma=0.1) 99 | 100 | train_losses = [] 101 | valid_losses = [] 102 | test_acces = [] 103 | avg_train_losses = [] 104 | avg_valid_losses = [] 105 | 106 | early_stopping = EarlyStopping(patience=patience, verbose=True) 107 | 108 | 109 | 110 | 111 | for epoch in range(1, n_epoches + 1): # loop over the dataset multiple times 112 | running_loss = 0.0 113 | net.train() 114 | 115 | for i, data in tqdm(enumerate(trainloader, 0)): 116 | 117 | inputs, labels = data[0].to(device), data[1].to(device) 118 | 119 | # get the inputs; data is a list of [inputs, labels] 120 | #inputs, labels = data[0].to(device), data[1].to(device) 121 | # zero the parameter gradients 122 | optimizer.zero_grad() 123 | 124 | # forward + backward + optimize 125 | outputs = net(inputs) 126 | loss = criterion(outputs, labels) 127 | loss.backward() 128 | optimizer.step() 129 | 130 | train_losses.append(loss.item()) 131 | 132 | 133 | 134 | # print statistics 135 | running_loss += loss.item() 136 | if i % 200 == 199: # print every 200 mini-batches 137 | print('[%d, %5d] loss: %.3f' % 138 | (epoch, i + 1, running_loss / 200)) 139 | running_loss = 0.0 140 | 141 | 142 | scheduler.step() 143 | 144 | net.eval() # prep model for evaluation 145 | for data, target in validloader: 146 | # forward pass: compute predicted outputs by passing inputs to the model 147 | data, target = data.to(device), target.to(device) 148 | 149 | output = net(data) 150 | # calculate the loss 151 | loss = criterion(output, target) 152 | # record validation loss 153 | valid_losses.append(loss.item()) 154 | 155 | correct = 0 156 | total = 0 157 | with torch.no_grad(): 158 | for data in testloader: 159 | images, labels = data[0].to(device), data[1].to(device) 160 | outputs = net(images) 161 | _, predicted = torch.max(outputs.data, 1) 162 | total += labels.size(0) 163 | correct += (predicted == labels).sum().item() 164 | 165 | test_acc = 100 * correct / total 166 | print('Accuracy of the network on the 10000 test images: %d %%' % ( 167 | test_acc)) 168 | 169 | 170 | train_loss = np.average(train_losses) 171 | valid_loss =
np.average(valid_losses) 172 | avg_train_losses.append(train_loss) 173 | avg_valid_losses.append(valid_loss) 174 | test_acces.append(test_acc) 175 | 176 | 177 | epoch_len = len(str(n_epoches)) 178 | print_msg = (f'[{epoch:>{epoch_len}}/{n_epoches:>{epoch_len}}] ' + 179 | f'train_loss: {train_loss:.5f} ' + 180 | f'valid_loss: {valid_loss:.5f}') 181 | 182 | print(print_msg) 183 | train_losses = [] 184 | valid_losses = [] 185 | early_stopping(valid_loss, net) 186 | 187 | if early_stopping.early_stop: 188 | print("Early stopping") 189 | break 190 | net.load_state_dict(torch.load('checkpoint.pt')) 191 | 192 | return avg_train_losses, avg_valid_losses,test_acces 193 | 194 | 195 | if __name__ == "__main__": 196 | 197 | #net = csp_resnet152(pretrained=True,num_classes = 10) 198 | net = csp_resnet50(pretrained=False,model_path = "checkpoint res50_79.pt",num_classes = 10) 199 | 200 | #y = net(torch.randn(1, 3, 112, 112)) 201 | #print(y.size()) 202 | train_loss, valid_loss,test_acces = train(net,n_epoches = 500,patience =20,valid_size = 0.1,batch_size = 64) 203 | 204 | fig = plt.figure() 205 | plt.plot(range(1,len(train_loss)+1),train_loss, label='Training Loss') 206 | plt.plot(range(1,len(valid_loss)+1),valid_loss,label='Validation Loss') 207 | plt.plot(range(1,len(test_acces)+1),test_acces,label='Test accuracy') 208 | 209 | # find position of lowest validation loss 210 | minposs = valid_loss.index(min(valid_loss))+1 211 | plt.axvline(minposs, linestyle='--', color='r',label='Early Stopping Checkpoint') 212 | 213 | plt.xlabel('epochs') 214 | plt.ylabel('loss') 215 | plt.ylim(0, 2.5) # consistent scale 216 | plt.xlim(0, len(train_loss)+1) # consistent scale 217 | plt.grid(True) 218 | plt.legend() 219 | plt.tight_layout() 220 | plt.show() 221 | fig.savefig('loss_plot.png', bbox_inches='tight') 222 | -------------------------------------------------------------------------------- /CSPResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torch.optim as optim 5 | 6 | def conv3x3(in_planes, out_planes, stride=1,dilation=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=dilation,bias=False, dilation=dilation) 10 | 11 | 12 | def conv1x1(in_planes, out_planes): 13 | """1x1 convolution""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=1,stride=1, bias=False) 15 | 16 | class Linear(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, x): 21 | return x 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | tran_expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 29 | base_width=64, dilation=1, norm_layer=None): 30 | super(BasicBlock, self).__init__() 31 | if norm_layer is None: 32 | norm_layer = nn.BatchNorm2d 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = norm_layer(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = norm_layer(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, 
x): 47 | identity = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | identity = self.downsample(x) 58 | 59 | out += identity 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | class CSPBottleneck(nn.Module): 65 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 66 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 67 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 68 | # This variant is also known as ResNet V1.5 and improves accuracy according to 69 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 70 | 71 | expansion = 2 72 | tran_expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None,norm_layer=None): 75 | super(CSPBottleneck, self).__init__() 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | 79 | width = planes 80 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 81 | self.conv1 = conv1x1(inplanes, width) 82 | self.bn1 = norm_layer(width) 83 | self.conv2 = conv3x3(width, width, stride) 84 | self.bn2 = norm_layer(width) 85 | self.conv3 = conv1x1(width, planes * self.expansion) 86 | self.bn3 = norm_layer(planes * self.expansion) 87 | self.lrelu = nn.LeakyReLU() 88 | self.downsample = downsample 89 | self.stride = stride 90 | 91 | def forward(self, x): 92 | identity = x 93 | #print("x") 94 | 95 | out = self.conv1(x) 96 | 97 | #print("out") 98 | #print(type(out)) 99 | 100 | out = self.bn1(out) 101 | out = self.lrelu(out) 102 | 103 | 104 | 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | out = self.lrelu(out) 109 | 110 | out = self.conv3(out) 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | 117 | out += identity 118 | out = self.lrelu(out) 119 | 120 | return out 121 | 122 | class CSPBlock(nn.Module): 123 | 124 | def __init__(self, block, inplanes,blocks, stride=1, downsample=None, norm_layer=None, activation = None): 125 | super(CSPBlock, self).__init__() 126 | 127 | if norm_layer is None: 128 | norm_layer = nn.BatchNorm2d 129 | 130 | if activation is None: 131 | self.activation = nn.LeakyReLU(inplace=True) 132 | else: 133 | self.activation = activation() 134 | 135 | self.inplanes = inplanes 136 | self.norm_layer = norm_layer 137 | 138 | self.crossstage = nn.Conv2d(self.inplanes, self.inplanes*2, kernel_size=1, stride=1,bias=False) 139 | 140 | self.bn_crossstage = norm_layer(self.inplanes*2) 141 | 142 | ## first layer is different from others 143 | if(self.inplanes <= 64): 144 | self.conv1 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=1, stride=1,bias=False) 145 | self.bn1 = norm_layer(self.inplanes) 146 | self.layer_num = self.inplanes 147 | else: 148 | self.conv1 = nn.Conv2d(self.inplanes, self.inplanes*2, kernel_size=1, stride=1,bias=False) 149 | self.bn1 = norm_layer(self.inplanes*2) 150 | self.layer_num = self.inplanes*2 151 | 152 | 153 | 154 | 155 | self.layers = self._make_layer(block, self.inplanes, blocks) 156 | 157 | self.trans = nn.Conv2d(self.inplanes*2, self.inplanes*2, kernel_size=1, stride=1, bias=False) 158 | 159 | 160 | def forward(self, x): 161 | cross = self.crossstage(x) 162 | cross = self.bn_crossstage(cross) 163 | 164 | cross = self.activation(cross) 165 | 166 | origin = self.conv1(x) 167 
| 168 | origin = self.bn1(origin) 169 | 170 | origin = self.activation(origin) 171 | 172 | #print("origin") 173 | #print(type(origin)) 174 | 175 | origin = self.layers(origin) 176 | #origin = self.trans(origin) 177 | 178 | #out = origin 179 | out = torch.cat((origin,cross), dim=1) 180 | 181 | return out 182 | 183 | def _make_layer(self, block, planes, blocks, stride=1): 184 | 185 | norm_layer = self.norm_layer 186 | downsample = None 187 | 188 | if stride != 1 or self.layer_num != planes * block.expansion: 189 | downsample = nn.Sequential( 190 | conv1x1(self.inplanes, planes * block.expansion), 191 | norm_layer(planes * block.expansion), 192 | ) 193 | 194 | layers = [] 195 | 196 | if(self.inplanes <=64): 197 | layers.append(block(self.inplanes, planes, stride, downsample,norm_layer)) 198 | self.inplanes = planes * block.expansion 199 | else: 200 | self.inplanes = planes * block.expansion 201 | layers.append(block(self.inplanes, planes, stride, downsample,norm_layer)) 202 | 203 | for _ in range(1, blocks): 204 | layers.append(block(self.inplanes, planes,norm_layer=norm_layer)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | class CSPResNet(nn.Module): 209 | 210 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 211 | norm_layer=None): 212 | super(CSPResNet, self).__init__() 213 | if norm_layer is None: 214 | norm_layer = nn.BatchNorm2d 215 | self._norm_layer = norm_layer 216 | 217 | self.inplanes = 64 218 | 219 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,bias=False) 220 | self.bn1 = norm_layer(self.inplanes) 221 | self.lrelu = nn.LeakyReLU() 222 | 223 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 224 | 225 | 226 | 227 | 228 | self.layer1 = CSPBlock(block, 64, layers[0], activation = nn.LeakyReLU) ## 256 out 229 | 230 | self.part_tran1 = self._make_tran(64,block.tran_expansion) 231 | 232 | self.layer2 = CSPBlock(block, 128, layers[1]-1, activation=Linear) 233 | 234 | self.part_tran2 = self._make_tran(128,block.tran_expansion) 235 | 236 | 237 | self.layer3 = CSPBlock(block, 256, layers[2]-1, activation = Linear) 238 | 239 | self.part_tran3 = self._make_tran(256,block.tran_expansion) 240 | 241 | self.layer4 = CSPBlock(block, 512, layers[3]-1, activation = nn.LeakyReLU) 242 | 243 | self.conv2 = nn.Conv2d(512*block.tran_expansion,512*2, kernel_size=1,stride=1,bias=False) 244 | 245 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 246 | 247 | #self.conv3 = nn.Conv2d(512*2,num_classes, kernel_size=1,stride=1) 248 | 249 | 250 | self.fn = nn.Linear(512*2,num_classes) 251 | 252 | for m in self.modules(): 253 | if isinstance(m, nn.Conv2d): 254 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 255 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 256 | nn.init.constant_(m.weight, 1) 257 | nn.init.constant_(m.bias, 0) 258 | 259 | # Zero-initialize the last BN in each residual branch, 260 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
261 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 262 | if zero_init_residual: 263 | for m in self.modules(): 264 | if isinstance(m, CSPBottleneck): 265 | nn.init.constant_(m.bn3.weight, 0) 266 | elif isinstance(m, BasicBlock): 267 | nn.init.constant_(m.bn2.weight, 0) 268 | 269 | def _make_tran(self, base,tran_expansion): 270 | return nn.Sequential( 271 | conv1x1(base*tran_expansion,base*2), 272 | nn.BatchNorm2d(base*2), 273 | nn.LeakyReLU(), 274 | conv3x3(base*2, base*2, stride=2), 275 | nn.BatchNorm2d(base*2), 276 | nn.LeakyReLU() 277 | ) 278 | 279 | def _forward_impl(self, x): 280 | # See note [TorchScript super()] 281 | x = self.conv1(x) 282 | x = self.bn1(x) 283 | x = self.lrelu(x) 284 | x = self.maxpool(x) 285 | 286 | x = self.layer1(x) 287 | x = self.part_tran1(x) 288 | 289 | x = self.layer2(x) 290 | x = self.part_tran2(x) 291 | 292 | x = self.layer3(x) 293 | x = self.part_tran3(x) 294 | 295 | 296 | x = self.layer4(x) 297 | 298 | x = self.conv2(x) 299 | 300 | x = self.avgpool(x) 301 | 302 | 303 | x = x.view(-1,512*2) 304 | 305 | 306 | x = self.fn(x) 307 | 308 | return x 309 | 310 | def forward(self, x): 311 | return self._forward_impl(x) 312 | 313 | 314 | def _cspresnet(arch, block, layers, pretrained,model_path, **kwargs): 315 | model = CSPResNet(block, layers, **kwargs) 316 | if pretrained: 317 | state_dict = torch.load(model_path) 318 | model.load_state_dict(state_dict) 319 | return model 320 | 321 | def csp_resnet50(pretrained=False,model_path = "checkpoint.pt",**kwargs): 322 | r"""CSP variant of the ResNet-50 model from 323 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 324 | Args: 325 | pretrained (bool): If True, loads weights from model_path 326 | model_path (str): path of the checkpoint to load when pretrained is True 327 | """ 328 | return _cspresnet('cspresnet50', CSPBottleneck, [3, 4, 6, 3], pretrained,model_path = model_path, 329 | **kwargs) 330 | 331 | 332 | def csp_resnet101(pretrained=False,model_path = "checkpoint.pt", **kwargs): 333 | r"""CSP variant of the ResNet-101 model from 334 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 335 | Args: 336 | pretrained (bool): If True, loads weights from model_path 337 | model_path (str): path of the checkpoint to load when pretrained is True 338 | """ 339 | return _cspresnet('cspresnet101', CSPBottleneck, [3, 4, 23, 3], pretrained,model_path = model_path, 340 | **kwargs) 341 | 342 | 343 | def csp_resnet152(pretrained=False,model_path = "checkpoint.pt", **kwargs): 344 | r"""CSP variant of the ResNet-152 model from 345 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 346 | Args: 347 | pretrained (bool): If True, loads weights from model_path 348 | model_path (str): path of the checkpoint to load when pretrained is True 349 | """ 350 | return _cspresnet('cspresnet152', CSPBottleneck, [3, 8, 36, 3], pretrained, model_path = model_path, 351 | **kwargs) 352 | 353 | 354 | 355 | if __name__ == "__main__": 356 | net = csp_resnet152(pretrained=False,num_classes = 10) 357 | y = net(torch.randn(1, 3, 112, 112)) 358 | print(y.size()) 359 | 360 | --------------------------------------------------------------------------------
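For reference, a minimal quick-start sketch of how these pieces fit together (not part of the repo itself; it assumes trainer.py and CSPResNet.py are importable from the working directory, that ./data is writable for the CIFAR-10 download, and the variable names are illustrative):

import torch
from CSPResNet import csp_resnet50
from trainer import train

# build the CSP variant of ResNet-50 with a 10-class head for CIFAR-10
net = csp_resnet50(pretrained=False, num_classes=10)

# train() downloads CIFAR-10, splits off a validation set, applies early stopping,
# checkpoints the best weights to checkpoint.pt, and returns per-epoch averages
train_loss, valid_loss, test_acc = train(net, n_epoches=100, patience=20,
                                         valid_size=0.1, batch_size=64)

print('best test accuracy: %.2f %%' % max(test_acc))

Evaluation of a saved checkpoint on the full test set, including per-class accuracy, is what test.py above does.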