├── Agricultural land.png ├── README.md ├── __pycache__ ├── config.cpython-35.pyc ├── config.cpython-37.pyc ├── main.cpython-35.pyc └── predict.cpython-35.pyc ├── cat_to_name.json ├── combine.py ├── config.py ├── cut.py ├── main.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-37.pyc │ ├── alexnet.cpython-35.pyc │ ├── alexnet.cpython-37.pyc │ ├── densenet.cpython-35.pyc │ ├── densenet.cpython-37.pyc │ ├── googlenet.cpython-35.pyc │ ├── googlenet.cpython-37.pyc │ ├── inception.cpython-35.pyc │ ├── inception.cpython-37.pyc │ ├── mnasnet.cpython-35.pyc │ ├── mnasnet.cpython-37.pyc │ ├── mobilenet.cpython-35.pyc │ ├── mobilenet.cpython-37.pyc │ ├── resnet.cpython-35.pyc │ ├── resnet.cpython-37.pyc │ ├── shufflenetv2.cpython-35.pyc │ ├── shufflenetv2.cpython-37.pyc │ ├── squeezenet.cpython-35.pyc │ ├── squeezenet.cpython-37.pyc │ ├── utils.cpython-35.pyc │ ├── utils.cpython-37.pyc │ ├── vgg.cpython-35.pyc │ └── vgg.cpython-37.pyc ├── _utils.py ├── alexnet.py ├── densenet.py ├── googlenet.py ├── inception.py ├── mnasnet.py ├── mobilenet.py ├── resnet.py ├── shufflenetv2.py ├── squeezenet.py ├── utils.py └── vgg.py ├── predict.py └── trained_models └── data_record.pth /Agricultural land.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/Agricultural land.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch_Remote-Sensing-Image-Classification 2 | 3 | Based on Google Earth remote sensing images of Wuhan in 2016, using pre-trained models Alexnet, VGG11, Resnet-50 and Resnet-101 for transfer learning, the classification model can predict the category of unknown remote sensing images with 90.48% accuracy (ResNet-101) 4 | 5 | 基于2016年武汉市GoogleEarth遥感影像,利用预训练模型Alexnet、VGG11、Resnet-50和Resnet-101进行迁移学习,分类模型可预测未知遥感影像的类别 6 | -------------------------------------------------------------------------------- /__pycache__/config.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/__pycache__/config.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/main.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/__pycache__/main.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/predict.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/__pycache__/predict.cpython-35.pyc -------------------------------------------------------------------------------- /cat_to_name.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/cat_to_name.json -------------------------------------------------------------------------------- /combine.py: -------------------------------------------------------------------------------- 1 | import os 2 | from osgeo import gdal 3 | 4 | #读图像文件 5 | def read_img(self, filename): 6 | 7 | dataset = gdal.Open(filename) #打开文件 8 | 9 | im_width = dataset.RasterXSize #栅格矩阵的列数 10 | im_height = dataset.RasterYSize #栅格矩阵的行数 11 | 12 | im_geotrans = dataset.GetGeoTransform() #仿射矩阵 13 | im_proj = dataset.GetProjection() #地图投影信息 14 | im_data = dataset.ReadAsArray(0, 0, im_width, im_height) #将数据写成数组,对应栅格矩阵 15 | 16 | del dataset 17 | return im_proj, im_geotrans, im_data 18 | 19 | def write_img(self, filename, im_proj, im_geotrans, im_data): 20 | 21 | # 判断栅格数据的数据类型 22 | if 'int8' in im_data.dtype.name: 23 | datatype = gdal.GDT_Byte 24 | elif 'int16' in im_data.dtype.name: 25 | datatype = gdal.GDT_UInt16 26 | else: 27 | datatype = gdal.GDT_Float32 28 | 29 | # 判读数组维数 30 | if len(im_data.shape) == 3: 31 | im_bands, im_height, im_width = im_data.shape 32 | else: 33 | im_bands, (im_height, im_width) = 1, im_data.shape 34 | 35 | # 创建文件 36 | driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间 37 | dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) 38 | 39 | dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 40 | dataset.SetProjection(im_proj) # 写入投影 41 | 42 | if im_bands == 1: 43 | dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据 44 | else: 45 | for i in range(im_bands): 46 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) 47 | 48 | del dataset 49 | 50 | 51 | 52 | if __name__ == "__main__": 53 | os.chdir(r'C:/Users/SchaferHolz/Desktop/image') 54 | proj, geotrans, data = read_img('whu.tif') # 读数据 55 | print(proj) 56 | print(geotrans) 57 | #print(data) 58 | print(data.shape) 59 | channel, width, height = data.shape 60 | for i in range(width // 200): # 切割成200*200小图 61 | for j in range(height // 200): 62 | cur_image = data[:, i * 200:(i + 1) * 200, j * 200:(j + 1) * 200] 63 | write_img('images/raw1/{}_{}.tif'.format(i, j), proj, geotrans, cur_image) #写数据 64 | os.chdir(r'C:/Users/SchaferHolz/Desktop/image/images/raw1') 65 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | #数据集的类别 4 | NUM_CLASSES = 11 5 | 6 | #训练时batch的大小 7 | BATCH_SIZE = 128 8 | 9 | #训练轮数 10 | NUM_EPOCHS= 25 11 | 12 | #训练完成,精度和损失文件的保存路径,默认保存在trained_models下 13 | TRAINED_MODEL = 'trained_models/data_record.pth' 14 | 15 | #数据集的存放位置 16 | TRAIN_DATASET_DIR = 'data/train' 17 | VALID_DATASET_DIR = 'data/val' 18 | -------------------------------------------------------------------------------- /cut.py: -------------------------------------------------------------------------------- 1 | import os 2 | import predict 3 | from osgeo import gdal 4 | 5 | 6 | class GRID: 7 | #读图像文件 8 | def read_img(self, filename): 9 | 10 | dataset = gdal.Open(filename) #打开文件 11 | 12 | im_width = dataset.RasterXSize #栅格矩阵的列数 13 | im_height = dataset.RasterYSize #栅格矩阵的行数 14 | 15 | im_geotrans = dataset.GetGeoTransform() #仿射矩阵 16 | im_proj = dataset.GetProjection() #地图投影信息 17 | im_data = dataset.ReadAsArray(0, 0, im_width, im_height) #将数据写成数组,对应栅格矩阵 18 | 19 | del dataset 20 | return im_proj, im_geotrans, im_data 21 | 22 | #写文件,以写成tif为例 23 | def write_img(self, filename, im_proj, im_geotrans, im_data): 24 | #gdal数据类型包括 25 | #gdal.GDT_Byte, 26 | #gdal.GDT_UInt16, gdal.GDT_Int16, gdal.GDT_UInt32, gdal.GDT_Int32, 27 | #gdal.GDT_Float32, gdal.GDT_Float64 28 | 29 | #判断栅格数据的数据类型 30 | if 'int8' in im_data.dtype.name: 31 | datatype = gdal.GDT_Byte 32 | elif 'int16' in im_data.dtype.name: 33 | datatype = gdal.GDT_UInt16 34 | else: 35 | datatype = gdal.GDT_Float32 36 | 37 | #判读数组维数 38 | if len(im_data.shape) == 3: 39 | im_bands, im_height, im_width = im_data.shape 40 | else: 41 | im_bands, (im_height, im_width) = 1, im_data.shape 42 | 43 | #创建文件 44 | driver = gdal.GetDriverByName("GTiff") #数据类型必须有,因为要计算需要多大内存空间 45 | dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) 46 | 47 | dataset.SetGeoTransform(im_geotrans) #写入仿射变换参数 48 | dataset.SetProjection(im_proj) #写入投影 49 | 50 | if im_bands == 1: 51 | dataset.GetRasterBand(1).WriteArray(im_data) #写入数组数据 52 | else: 53 | for i in range(im_bands): 54 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) 55 | 56 | del dataset 57 | 58 | 59 | 60 | 61 | if __name__ == "__main__": 62 | os.chdir(r'C:\Users\SchaferHolz\Desktop\image') 63 | proj, geotrans, data = GRID().read_img('whu.tif') # 读数据 64 | print(proj) 65 | print(geotrans) 66 | #print(data) 67 | print(data.shape) 68 | channel, width, height = data.shape 69 | for i in range(width // 200): # 切割成200*200小图 70 | for j in range(height // 200): 71 | cur_image = data[:, i * 200:(i + 1) * 200, j * 200:(j + 1) * 200] 72 | #GRID().write_img('images/raw1/{}_{}.tif'.format(i, j), proj, geotrans, cur_image) #写数据 73 | 74 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, models, transforms 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | import time 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import models 11 | import config 12 | 13 | import json 14 | from PIL import Image 15 | 16 | #数据增强 17 | train_transforms = transforms.Compose( 18 | [transforms.RandomResizedCrop(size=156, scale=(0.8, 1.0)), #随机裁剪到156*156 19 | transforms.RandomRotation(degrees=15), #随机旋转 20 | transforms.RandomHorizontalFlip(), #随机水平翻转 21 | transforms.CenterCrop(size=124), #中心裁剪到124*124 22 | transforms.ToTensor(), #转化成张量 23 | transforms.Normalize([0.485, 0.456, 0.406], #归一化 24 | [0.229, 0.224, 0.225]) 25 | ]) 26 | 27 | test_valid_transforms = transforms.Compose( 28 | [transforms.Resize(156), 29 | transforms.CenterCrop(124), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.485, 0.456, 0.406], 32 | [0.229, 0.224, 0.225]) 33 | ]) 34 | 35 | #利用Dataloader加载数据 36 | train_directory = config.TRAIN_DATASET_DIR 37 | valid_directory = config.VALID_DATASET_DIR 38 | 39 | batch_size = config.BATCH_SIZE 40 | num_classes = config.NUM_CLASSES 41 | 42 | train_datasets = datasets.ImageFolder(train_directory, transform=train_transforms) 43 | train_data_size = len(train_datasets) 44 | train_data = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True) 45 | 46 | ''' 47 | #获取索引到类名的映射,以便查看测试影像的输出类 48 | idx_to_class = {v: k for k, v in train_datasets.class_to_idx.items()} 49 | print(idx_to_class) 50 | 51 | with open('cat_to_name.json', 'r') as f: 52 | cat_to_name = json.load(f) 53 | print(cat_to_name) 54 | ''' 55 | 56 | valid_datasets = datasets.ImageFolder(valid_directory,transform=test_valid_transforms) 57 | valid_data_size = len(valid_datasets) 58 | valid_data = torch.utils.data.DataLoader(valid_datasets, batch_size=batch_size, shuffle=True) 59 | 60 | #print(train_data_size, valid_data_size) 61 | 62 | #使用Resnet-50的预训练模型进行迁移学习 63 | resnet50 = models.resnet50(pretrained=True) 64 | 65 | #查看更改前的模型参数 66 | #print('before:{%s}\n'%resnet50) 67 | 68 | for param in resnet50.parameters(): 69 | param.requires_grad = False #冻结预训练网络中的参数 70 | 71 | fc_inputs = resnet50.fc.in_features 72 | resnet50.fc = nn.Sequential( 73 | nn.Linear(fc_inputs, 256), #将resnet50最后的全连接层输入给256输出单元的线性层 74 | nn.ReLU(), 75 | nn.Dropout(0.4), 76 | nn.Linear(256, 11), 77 | nn.LogSoftmax(dim=1) #输出11通道softmax层 78 | ) 79 | #查看更改后的模型参数 80 | #print('after:{%s}\n'%resnet50) 81 | 82 | #定义损失函数和优化器 83 | loss_func = nn.NLLLoss() 84 | optimizer = optim.SGD(resnet50.parameters(), lr = 0.01) 85 | 86 | #训练过程 87 | def train_and_valid(model, loss_function, optimizer, epochs=25): 88 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #若有gpu可用则用gpu 89 | record = [] 90 | best_acc = 0.0 91 | best_epoch = 0 92 | 93 | for epoch in range(epochs): #训练epochs轮 94 | epoch_start = time.time() 95 | print("Epoch: {}/{}".format(epoch + 1, epochs)) 96 | 97 | model.train() #训练 98 | 99 | train_loss = 0.0 100 | train_acc = 0.0 101 | valid_loss = 0.0 102 | valid_acc = 0.0 103 | 104 | for i, (inputs, labels) in enumerate(train_data): 105 | inputs = inputs.to(device) 106 | labels = labels.to(device) 107 | #print(labels) 108 | 109 | optimizer.zero_grad() #梯度清零 110 | 111 | outputs = model(inputs) #数据前馈,正向传播 112 | 113 | loss = loss_function(outputs, labels) #输出误差 114 | 115 | loss.backward() #反向传播 116 | 117 | optimizer.step() #优化器更新参数 118 | 119 | train_loss += loss.item() * inputs.size(0) 120 | 121 | ret, predictions = torch.max(outputs.data, 1) 122 | correct_counts = predictions.eq(labels.data.view_as(predictions)) 123 | 124 | acc = torch.mean(correct_counts.type(torch.FloatTensor)) 125 | 126 | train_acc += acc.item() * inputs.size(0) 127 | 128 | with torch.no_grad(): 129 | model.eval() #验证 130 | 131 | for j, (inputs, labels) in enumerate(valid_data): 132 | inputs = inputs.to(device) 133 | labels = labels.to(device) 134 | 135 | outputs = model(inputs) 136 | 137 | loss = loss_function(outputs, labels) 138 | 139 | valid_loss += loss.item() * inputs.size(0) 140 | 141 | ret, predictions = torch.max(outputs.data, 1) 142 | correct_counts = predictions.eq(labels.data.view_as(predictions)) 143 | 144 | acc = torch.mean(correct_counts.type(torch.FloatTensor)) 145 | 146 | valid_acc += acc.item() * inputs.size(0) 147 | 148 | avg_train_loss = train_loss / train_data_size 149 | avg_train_acc = train_acc / train_data_size 150 | 151 | avg_valid_loss = valid_loss / valid_data_size 152 | avg_valid_acc = valid_acc / valid_data_size 153 | 154 | record.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc]) 155 | 156 | if avg_valid_acc > best_acc : #记录最高准确性的模型 157 | best_acc = avg_valid_acc 158 | best_epoch = epoch + 1 159 | 160 | epoch_end = time.time() 161 | 162 | print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format( 163 | epoch + 1, avg_valid_loss, avg_train_acc * 100, avg_valid_loss, avg_valid_acc * 100, 164 | epoch_end - epoch_start)) 165 | print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch)) 166 | 167 | torch.save(model, 'trained_models/resnet50_model_' + str(epoch + 1) + '.pth') 168 | return model, record 169 | 170 | ''' 171 | def predict(model, test_image_name): 172 | 173 | transform = test_valid_transforms 174 | 175 | test_image = Image.open(test_image_name).convert('RGB') 176 | plt.imshow(test_image) 177 | 178 | test_image_tensor = transform(test_image) 179 | 180 | if torch.cuda.is_available(): 181 | test_image_tensor = test_image_tensor.view(1, 3, 124, 124).cuda() 182 | else: 183 | test_image_tensor = test_image_tensor.view(1, 3, 124, 124) 184 | 185 | with torch.no_grad(): 186 | model.eval() 187 | # Model outputs log probabilities 188 | start = time.time() 189 | out = model(test_image_tensor) 190 | stop = time.time() 191 | print('cost time', stop - start) 192 | ps = torch.exp(out) 193 | topk, topclass = ps.topk(3, dim=1) 194 | names = [] 195 | for i in range(3): 196 | names.append(cat_to_name[idx_to_class[topclass.cpu().numpy()[0][i]]]) 197 | print("Predcition", i + 1, ":", names[i], ", Score: ", 198 | topk.cpu().numpy()[0][i]) 199 | plt.barh([2, 1, 0], topk.cpu().numpy(), tick_label=names) 200 | ''' 201 | 202 | #结果 203 | if __name__=='__main__': 204 | num_epochs = config.NUM_EPOCHS 205 | trained_model, record = train_and_valid(resnet50, loss_func, optimizer, num_epochs) 206 | torch.save(record, config.TRAINED_MODEL) 207 | 208 | record = np.array(record) 209 | plt.plot(record[:, 0:2]) 210 | plt.legend(['Train Loss', 'Valid Loss']) 211 | plt.xlabel('Epoch Number') 212 | plt.ylabel('Loss') 213 | plt.ylim(0, 1) 214 | plt.savefig('loss.png') 215 | plt.show() 216 | 217 | plt.plot(record[:, 2:4]) 218 | plt.legend(['Train Accuracy', 'Valid Accuracy']) 219 | plt.xlabel('Epoch Number') 220 | plt.ylabel('Accuracy') 221 | plt.ylim(0, 1) 222 | plt.savefig('accuracy.png') 223 | plt.show() 224 | 225 | ''' 226 | model = torch.load('trained_models/resnet50_model_23.pth') 227 | predict(model, '61.png') 228 | ''' -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import * 2 | from .resnet import * 3 | from .vgg import * 4 | from .squeezenet import * 5 | from .inception import * 6 | from .densenet import * 7 | from .googlenet import * 8 | from .mobilenet import * 9 | from .mnasnet import * 10 | from .shufflenetv2 import * 11 | 12 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/alexnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/alexnet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/alexnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/alexnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/densenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/densenet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/densenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/densenet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/googlenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/googlenet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/googlenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/googlenet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/inception.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/inception.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/inception.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/inception.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/mnasnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/mnasnet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/mnasnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/mnasnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/mobilenet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/resnet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/shufflenetv2.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/shufflenetv2.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/shufflenetv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/shufflenetv2.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/squeezenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/squeezenet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/squeezenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/squeezenet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/vgg.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/models/__pycache__/vgg.cpython-37.pyc -------------------------------------------------------------------------------- /models/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class IntermediateLayerGetter(nn.ModuleDict): 8 | """ 9 | Module wrapper that returns intermediate layers from a model 10 | 11 | It has a strong assumption that the modules have been registered 12 | into the model in the same order as they are used. 13 | This means that one should **not** reuse the same nn.Module 14 | twice in the forward if you want this to work. 15 | 16 | Additionally, it is only able to query submodules that are directly 17 | assigned to the model. So if `model` is passed, `model.feature1` can 18 | be returned, but not `model.feature1.layer2`. 19 | 20 | Arguments: 21 | model (nn.Module): model on which we will extract the features 22 | return_layers (Dict[name, new_name]): a dict containing the names 23 | of the modules for which the activations will be returned as 24 | the key of the dict, and the value of the dict is the name 25 | of the returned activation (which the user can specify). 26 | 27 | Examples:: 28 | 29 | >>> m = torchvision.models.resnet18(pretrained=True) 30 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 31 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 32 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 33 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 34 | >>> print([(k, v.shape) for k, v in out.items()]) 35 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 36 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 37 | """ 38 | def __init__(self, model, return_layers): 39 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 40 | raise ValueError("return_layers are not present in model") 41 | 42 | orig_return_layers = return_layers 43 | return_layers = {k: v for k, v in return_layers.items()} 44 | layers = OrderedDict() 45 | for name, module in model.named_children(): 46 | layers[name] = module 47 | if name in return_layers: 48 | del return_layers[name] 49 | if not return_layers: 50 | break 51 | 52 | super(IntermediateLayerGetter, self).__init__(layers) 53 | self.return_layers = orig_return_layers 54 | 55 | def forward(self, x): 56 | out = OrderedDict() 57 | for name, module in self.named_children(): 58 | x = module(x) 59 | if name in self.return_layers: 60 | out_name = self.return_layers[name] 61 | out[out_name] = x 62 | return out 63 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['AlexNet', 'alexnet'] 7 | 8 | 9 | model_urls = { 10 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 11 | } 12 | 13 | 14 | class AlexNet(nn.Module): 15 | 16 | def __init__(self, num_classes=1000): 17 | super(AlexNet, self).__init__() 18 | self.features = nn.Sequential( 19 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=3, stride=2), 22 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(kernel_size=3, stride=2), 25 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.MaxPool2d(kernel_size=3, stride=2), 32 | ) 33 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 34 | self.classifier = nn.Sequential( 35 | nn.Dropout(), 36 | nn.Linear(256 * 6 * 6, 4096), 37 | nn.ReLU(inplace=True), 38 | nn.Dropout(), 39 | nn.Linear(4096, 4096), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(4096, num_classes), 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.features(x) 46 | x = self.avgpool(x) 47 | x = torch.flatten(x, 1) 48 | x = self.classifier(x) 49 | return x 50 | 51 | 52 | def alexnet(pretrained=False, progress=True, **kwargs): 53 | r"""AlexNet model architecture from the 54 | `"One weird trick..." `_ paper. 55 | 56 | Args: 57 | pretrained (bool): If True, returns a model pre-trained on ImageNet 58 | progress (bool): If True, displays a progress bar of the download to stderr 59 | """ 60 | model = AlexNet(**kwargs) 61 | if pretrained: 62 | state_dict = load_state_dict_from_url(model_urls['alexnet'], 63 | progress=progress) 64 | model.load_state_dict(state_dict) 65 | return model 66 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | from .utils import load_state_dict_from_url 8 | 9 | 10 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 11 | 12 | model_urls = { 13 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 14 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 15 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 16 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 17 | } 18 | 19 | 20 | def _bn_function_factory(norm, relu, conv): 21 | def bn_function(*inputs): 22 | concated_features = torch.cat(inputs, 1) 23 | bottleneck_output = conv(relu(norm(concated_features))) 24 | return bottleneck_output 25 | 26 | return bn_function 27 | 28 | 29 | class _DenseLayer(nn.Sequential): 30 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 31 | super(_DenseLayer, self).__init__() 32 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 33 | self.add_module('relu1', nn.ReLU(inplace=True)), 34 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 35 | growth_rate, kernel_size=1, stride=1, 36 | bias=False)), 37 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 38 | self.add_module('relu2', nn.ReLU(inplace=True)), 39 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 40 | kernel_size=3, stride=1, padding=1, 41 | bias=False)), 42 | self.drop_rate = drop_rate 43 | self.memory_efficient = memory_efficient 44 | 45 | def forward(self, *prev_features): 46 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 47 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 48 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 49 | else: 50 | bottleneck_output = bn_function(*prev_features) 51 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 52 | if self.drop_rate > 0: 53 | new_features = F.dropout(new_features, p=self.drop_rate, 54 | training=self.training) 55 | return new_features 56 | 57 | 58 | class _DenseBlock(nn.Module): 59 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False): 60 | super(_DenseBlock, self).__init__() 61 | for i in range(num_layers): 62 | layer = _DenseLayer( 63 | num_input_features + i * growth_rate, 64 | growth_rate=growth_rate, 65 | bn_size=bn_size, 66 | drop_rate=drop_rate, 67 | memory_efficient=memory_efficient, 68 | ) 69 | self.add_module('denselayer%d' % (i + 1), layer) 70 | 71 | def forward(self, init_features): 72 | features = [init_features] 73 | for name, layer in self.named_children(): 74 | new_features = layer(*features) 75 | features.append(new_features) 76 | return torch.cat(features, 1) 77 | 78 | 79 | class _Transition(nn.Sequential): 80 | def __init__(self, num_input_features, num_output_features): 81 | super(_Transition, self).__init__() 82 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 83 | self.add_module('relu', nn.ReLU(inplace=True)) 84 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 85 | kernel_size=1, stride=1, bias=False)) 86 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 87 | 88 | 89 | class DenseNet(nn.Module): 90 | r"""Densenet-BC model class, based on 91 | `"Densely Connected Convolutional Networks" `_ 92 | 93 | Args: 94 | growth_rate (int) - how many filters to add each layer (`k` in paper) 95 | block_config (list of 4 ints) - how many layers in each pooling block 96 | num_init_features (int) - the number of filters to learn in the first convolution layer 97 | bn_size (int) - multiplicative factor for number of bottle neck layers 98 | (i.e. bn_size * k features in the bottleneck layer) 99 | drop_rate (float) - dropout rate after each dense layer 100 | num_classes (int) - number of classification classes 101 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 102 | but slower. Default: *False*. See `"paper" `_ 103 | """ 104 | 105 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 106 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False): 107 | 108 | super(DenseNet, self).__init__() 109 | 110 | # First convolution 111 | self.features = nn.Sequential(OrderedDict([ 112 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 113 | padding=3, bias=False)), 114 | ('norm0', nn.BatchNorm2d(num_init_features)), 115 | ('relu0', nn.ReLU(inplace=True)), 116 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 117 | ])) 118 | 119 | # Each denseblock 120 | num_features = num_init_features 121 | for i, num_layers in enumerate(block_config): 122 | block = _DenseBlock( 123 | num_layers=num_layers, 124 | num_input_features=num_features, 125 | bn_size=bn_size, 126 | growth_rate=growth_rate, 127 | drop_rate=drop_rate, 128 | memory_efficient=memory_efficient 129 | ) 130 | self.features.add_module('denseblock%d' % (i + 1), block) 131 | num_features = num_features + num_layers * growth_rate 132 | if i != len(block_config) - 1: 133 | trans = _Transition(num_input_features=num_features, 134 | num_output_features=num_features // 2) 135 | self.features.add_module('transition%d' % (i + 1), trans) 136 | num_features = num_features // 2 137 | 138 | # Final batch norm 139 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 140 | 141 | # Linear layer 142 | self.classifier = nn.Linear(num_features, num_classes) 143 | 144 | # Official init from torch repo. 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv2d): 147 | nn.init.kaiming_normal_(m.weight) 148 | elif isinstance(m, nn.BatchNorm2d): 149 | nn.init.constant_(m.weight, 1) 150 | nn.init.constant_(m.bias, 0) 151 | elif isinstance(m, nn.Linear): 152 | nn.init.constant_(m.bias, 0) 153 | 154 | def forward(self, x): 155 | features = self.features(x) 156 | out = F.relu(features, inplace=True) 157 | out = F.adaptive_avg_pool2d(out, (1, 1)) 158 | out = torch.flatten(out, 1) 159 | out = self.classifier(out) 160 | return out 161 | 162 | 163 | def _load_state_dict(model, model_url, progress): 164 | # '.'s are no longer allowed in module names, but previous _DenseLayer 165 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 166 | # They are also in the checkpoints in model_urls. This pattern is used 167 | # to find such keys. 168 | pattern = re.compile( 169 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 170 | 171 | state_dict = load_state_dict_from_url(model_url, progress=progress) 172 | for key in list(state_dict.keys()): 173 | res = pattern.match(key) 174 | if res: 175 | new_key = res.group(1) + res.group(2) 176 | state_dict[new_key] = state_dict[key] 177 | del state_dict[key] 178 | model.load_state_dict(state_dict) 179 | 180 | 181 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, 182 | **kwargs): 183 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 184 | if pretrained: 185 | _load_state_dict(model, model_urls[arch], progress) 186 | return model 187 | 188 | 189 | def densenet121(pretrained=False, progress=True, **kwargs): 190 | r"""Densenet-121 model from 191 | `"Densely Connected Convolutional Networks" `_ 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | progress (bool): If True, displays a progress bar of the download to stderr 196 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 197 | but slower. Default: *False*. See `"paper" `_ 198 | """ 199 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 200 | **kwargs) 201 | 202 | 203 | def densenet161(pretrained=False, progress=True, **kwargs): 204 | r"""Densenet-161 model from 205 | `"Densely Connected Convolutional Networks" `_ 206 | 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | progress (bool): If True, displays a progress bar of the download to stderr 210 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 211 | but slower. Default: *False*. See `"paper" `_ 212 | """ 213 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 214 | **kwargs) 215 | 216 | 217 | def densenet169(pretrained=False, progress=True, **kwargs): 218 | r"""Densenet-169 model from 219 | `"Densely Connected Convolutional Networks" `_ 220 | 221 | Args: 222 | pretrained (bool): If True, returns a model pre-trained on ImageNet 223 | progress (bool): If True, displays a progress bar of the download to stderr 224 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 225 | but slower. Default: *False*. See `"paper" `_ 226 | """ 227 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 228 | **kwargs) 229 | 230 | 231 | def densenet201(pretrained=False, progress=True, **kwargs): 232 | r"""Densenet-201 model from 233 | `"Densely Connected Convolutional Networks" `_ 234 | 235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | progress (bool): If True, displays a progress bar of the download to stderr 238 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 239 | but slower. Default: *False*. See `"paper" `_ 240 | """ 241 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 242 | **kwargs) 243 | 244 | 245 | if __name__ == "__main__": 246 | dense121 = densenet121(pretrained=False,progress=True,num_classes=40) 247 | print(dense121) -------------------------------------------------------------------------------- /models/googlenet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .utils import load_state_dict_from_url 7 | 8 | __all__ = ['GoogLeNet', 'googlenet'] 9 | 10 | model_urls = { 11 | # GoogLeNet ported from TensorFlow 12 | 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', 13 | } 14 | 15 | _GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) 16 | 17 | 18 | def googlenet(pretrained=False, progress=True, **kwargs): 19 | r"""GoogLeNet (Inception v1) model architecture from 20 | `"Going Deeper with Convolutions" `_. 21 | 22 | Args: 23 | pretrained (bool): If True, returns a model pre-trained on ImageNet 24 | progress (bool): If True, displays a progress bar of the download to stderr 25 | aux_logits (bool): If True, adds two auxiliary branches that can improve training. 26 | Default: *False* when pretrained is True otherwise *True* 27 | transform_input (bool): If True, preprocesses the input according to the method with which it 28 | was trained on ImageNet. Default: *False* 29 | """ 30 | if pretrained: 31 | if 'transform_input' not in kwargs: 32 | kwargs['transform_input'] = True 33 | if 'aux_logits' not in kwargs: 34 | kwargs['aux_logits'] = False 35 | if kwargs['aux_logits']: 36 | warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' 37 | 'so make sure to train them') 38 | original_aux_logits = kwargs['aux_logits'] 39 | kwargs['aux_logits'] = True 40 | kwargs['init_weights'] = False 41 | model = GoogLeNet(**kwargs) 42 | state_dict = load_state_dict_from_url(model_urls['googlenet'], 43 | progress=progress) 44 | model.load_state_dict(state_dict) 45 | if not original_aux_logits: 46 | model.aux_logits = False 47 | del model.aux1, model.aux2 48 | return model 49 | 50 | return GoogLeNet(**kwargs) 51 | 52 | 53 | class GoogLeNet(nn.Module): 54 | 55 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True): 56 | super(GoogLeNet, self).__init__() 57 | self.aux_logits = aux_logits 58 | self.transform_input = transform_input 59 | 60 | self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) 61 | self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 62 | self.conv2 = BasicConv2d(64, 64, kernel_size=1) 63 | self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) 64 | self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 65 | 66 | self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) 67 | self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) 68 | self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 69 | 70 | self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) 75 | self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 76 | 77 | self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) 78 | self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) 79 | 80 | if aux_logits: 81 | self.aux1 = InceptionAux(512, num_classes) 82 | self.aux2 = InceptionAux(528, num_classes) 83 | 84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 | self.dropout = nn.Dropout(0.2) 86 | self.fc = nn.Linear(1024, num_classes) 87 | 88 | if init_weights: 89 | self._initialize_weights() 90 | 91 | def _initialize_weights(self): 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 94 | import scipy.stats as stats 95 | X = stats.truncnorm(-2, 2, scale=0.01) 96 | values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 97 | values = values.view(m.weight.size()) 98 | with torch.no_grad(): 99 | m.weight.copy_(values) 100 | elif isinstance(m, nn.BatchNorm2d): 101 | nn.init.constant_(m.weight, 1) 102 | nn.init.constant_(m.bias, 0) 103 | 104 | def forward(self, x): 105 | if self.transform_input: 106 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 107 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 108 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 109 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 110 | 111 | # N x 3 x 224 x 224 112 | x = self.conv1(x) 113 | # N x 64 x 112 x 112 114 | x = self.maxpool1(x) 115 | # N x 64 x 56 x 56 116 | x = self.conv2(x) 117 | # N x 64 x 56 x 56 118 | x = self.conv3(x) 119 | # N x 192 x 56 x 56 120 | x = self.maxpool2(x) 121 | 122 | # N x 192 x 28 x 28 123 | x = self.inception3a(x) 124 | # N x 256 x 28 x 28 125 | x = self.inception3b(x) 126 | # N x 480 x 28 x 28 127 | x = self.maxpool3(x) 128 | # N x 480 x 14 x 14 129 | x = self.inception4a(x) 130 | # N x 512 x 14 x 14 131 | if self.training and self.aux_logits: 132 | aux1 = self.aux1(x) 133 | 134 | x = self.inception4b(x) 135 | # N x 512 x 14 x 14 136 | x = self.inception4c(x) 137 | # N x 512 x 14 x 14 138 | x = self.inception4d(x) 139 | # N x 528 x 14 x 14 140 | if self.training and self.aux_logits: 141 | aux2 = self.aux2(x) 142 | 143 | x = self.inception4e(x) 144 | # N x 832 x 14 x 14 145 | x = self.maxpool4(x) 146 | # N x 832 x 7 x 7 147 | x = self.inception5a(x) 148 | # N x 832 x 7 x 7 149 | x = self.inception5b(x) 150 | # N x 1024 x 7 x 7 151 | 152 | x = self.avgpool(x) 153 | # N x 1024 x 1 x 1 154 | x = torch.flatten(x, 1) 155 | # N x 1024 156 | x = self.dropout(x) 157 | x = self.fc(x) 158 | # N x 1000 (num_classes) 159 | if self.training and self.aux_logits: 160 | return _GoogLeNetOutputs(x, aux2, aux1) 161 | return x 162 | 163 | 164 | class Inception(nn.Module): 165 | 166 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): 167 | super(Inception, self).__init__() 168 | 169 | self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) 170 | 171 | self.branch2 = nn.Sequential( 172 | BasicConv2d(in_channels, ch3x3red, kernel_size=1), 173 | BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) 174 | ) 175 | 176 | self.branch3 = nn.Sequential( 177 | BasicConv2d(in_channels, ch5x5red, kernel_size=1), 178 | BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) 179 | ) 180 | 181 | self.branch4 = nn.Sequential( 182 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 183 | BasicConv2d(in_channels, pool_proj, kernel_size=1) 184 | ) 185 | 186 | def forward(self, x): 187 | branch1 = self.branch1(x) 188 | branch2 = self.branch2(x) 189 | branch3 = self.branch3(x) 190 | branch4 = self.branch4(x) 191 | 192 | outputs = [branch1, branch2, branch3, branch4] 193 | return torch.cat(outputs, 1) 194 | 195 | 196 | class InceptionAux(nn.Module): 197 | 198 | def __init__(self, in_channels, num_classes): 199 | super(InceptionAux, self).__init__() 200 | self.conv = BasicConv2d(in_channels, 128, kernel_size=1) 201 | 202 | self.fc1 = nn.Linear(2048, 1024) 203 | self.fc2 = nn.Linear(1024, num_classes) 204 | 205 | def forward(self, x): 206 | # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 207 | x = F.adaptive_avg_pool2d(x, (4, 4)) 208 | # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 209 | x = self.conv(x) 210 | # N x 128 x 4 x 4 211 | x = torch.flatten(x, 1) 212 | # N x 2048 213 | x = F.relu(self.fc1(x), inplace=True) 214 | # N x 1024 215 | x = F.dropout(x, 0.7, training=self.training) 216 | # N x 1024 217 | x = self.fc2(x) 218 | # N x 1000 (num_classes) 219 | 220 | return x 221 | 222 | 223 | class BasicConv2d(nn.Module): 224 | 225 | def __init__(self, in_channels, out_channels, **kwargs): 226 | super(BasicConv2d, self).__init__() 227 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 228 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 229 | 230 | def forward(self, x): 231 | x = self.conv(x) 232 | x = self.bn(x) 233 | return F.relu(x, inplace=True) 234 | -------------------------------------------------------------------------------- /models/inception.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .utils import load_state_dict_from_url 6 | 7 | 8 | __all__ = ['Inception3', 'inception_v3'] 9 | 10 | 11 | model_urls = { 12 | # Inception v3 ported from TensorFlow 13 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 14 | } 15 | 16 | _InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) 17 | 18 | 19 | def inception_v3(pretrained=False, progress=True, **kwargs): 20 | r"""Inception v3 model architecture from 21 | `"Rethinking the Inception Architecture for Computer Vision" `_. 22 | 23 | .. note:: 24 | **Important**: In contrast to the other models the inception_v3 expects tensors with a size of 25 | N x 3 x 299 x 299, so ensure your images are sized accordingly. 26 | 27 | Args: 28 | pretrained (bool): If True, returns a model pre-trained on ImageNet 29 | progress (bool): If True, displays a progress bar of the download to stderr 30 | aux_logits (bool): If True, add an auxiliary branch that can improve training. 31 | Default: *True* 32 | transform_input (bool): If True, preprocesses the input according to the method with which it 33 | was trained on ImageNet. Default: *False* 34 | """ 35 | if pretrained: 36 | if 'transform_input' not in kwargs: 37 | kwargs['transform_input'] = True 38 | if 'aux_logits' in kwargs: 39 | original_aux_logits = kwargs['aux_logits'] 40 | kwargs['aux_logits'] = True 41 | else: 42 | original_aux_logits = True 43 | model = Inception3(**kwargs) 44 | state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], 45 | progress=progress) 46 | model.load_state_dict(state_dict) 47 | if not original_aux_logits: 48 | model.aux_logits = False 49 | del model.AuxLogits 50 | return model 51 | 52 | return Inception3(**kwargs) 53 | 54 | 55 | class Inception3(nn.Module): 56 | 57 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): 58 | super(Inception3, self).__init__() 59 | self.aux_logits = aux_logits 60 | self.transform_input = transform_input 61 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 62 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 63 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 64 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 65 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 66 | self.Mixed_5b = InceptionA(192, pool_features=32) 67 | self.Mixed_5c = InceptionA(256, pool_features=64) 68 | self.Mixed_5d = InceptionA(288, pool_features=64) 69 | self.Mixed_6a = InceptionB(288) 70 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 71 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 72 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 73 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 74 | if aux_logits: 75 | self.AuxLogits = InceptionAux(768, num_classes) 76 | self.Mixed_7a = InceptionD(768) 77 | self.Mixed_7b = InceptionE(1280) 78 | self.Mixed_7c = InceptionE(2048) 79 | self.fc = nn.Linear(2048, num_classes) 80 | 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 83 | import scipy.stats as stats 84 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 85 | X = stats.truncnorm(-2, 2, scale=stddev) 86 | values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 87 | values = values.view(m.weight.size()) 88 | with torch.no_grad(): 89 | m.weight.copy_(values) 90 | elif isinstance(m, nn.BatchNorm2d): 91 | nn.init.constant_(m.weight, 1) 92 | nn.init.constant_(m.bias, 0) 93 | 94 | def forward(self, x): 95 | if self.transform_input: 96 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 97 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 98 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 99 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 100 | # N x 3 x 299 x 299 101 | x = self.Conv2d_1a_3x3(x) 102 | # N x 32 x 149 x 149 103 | x = self.Conv2d_2a_3x3(x) 104 | # N x 32 x 147 x 147 105 | x = self.Conv2d_2b_3x3(x) 106 | # N x 64 x 147 x 147 107 | x = F.max_pool2d(x, kernel_size=3, stride=2) 108 | # N x 64 x 73 x 73 109 | x = self.Conv2d_3b_1x1(x) 110 | # N x 80 x 73 x 73 111 | x = self.Conv2d_4a_3x3(x) 112 | # N x 192 x 71 x 71 113 | x = F.max_pool2d(x, kernel_size=3, stride=2) 114 | # N x 192 x 35 x 35 115 | x = self.Mixed_5b(x) 116 | # N x 256 x 35 x 35 117 | x = self.Mixed_5c(x) 118 | # N x 288 x 35 x 35 119 | x = self.Mixed_5d(x) 120 | # N x 288 x 35 x 35 121 | x = self.Mixed_6a(x) 122 | # N x 768 x 17 x 17 123 | x = self.Mixed_6b(x) 124 | # N x 768 x 17 x 17 125 | x = self.Mixed_6c(x) 126 | # N x 768 x 17 x 17 127 | x = self.Mixed_6d(x) 128 | # N x 768 x 17 x 17 129 | x = self.Mixed_6e(x) 130 | # N x 768 x 17 x 17 131 | if self.training and self.aux_logits: 132 | aux = self.AuxLogits(x) 133 | # N x 768 x 17 x 17 134 | x = self.Mixed_7a(x) 135 | # N x 1280 x 8 x 8 136 | x = self.Mixed_7b(x) 137 | # N x 2048 x 8 x 8 138 | x = self.Mixed_7c(x) 139 | # N x 2048 x 8 x 8 140 | # Adaptive average pooling 141 | x = F.adaptive_avg_pool2d(x, (1, 1)) 142 | # N x 2048 x 1 x 1 143 | x = F.dropout(x, training=self.training) 144 | # N x 2048 x 1 x 1 145 | x = torch.flatten(x, 1) 146 | # N x 2048 147 | x = self.fc(x) 148 | # N x 1000 (num_classes) 149 | if self.training and self.aux_logits: 150 | return _InceptionOutputs(x, aux) 151 | return x 152 | 153 | 154 | class InceptionA(nn.Module): 155 | 156 | def __init__(self, in_channels, pool_features): 157 | super(InceptionA, self).__init__() 158 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 159 | 160 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 161 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 162 | 163 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 164 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 165 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 166 | 167 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 168 | 169 | def forward(self, x): 170 | branch1x1 = self.branch1x1(x) 171 | 172 | branch5x5 = self.branch5x5_1(x) 173 | branch5x5 = self.branch5x5_2(branch5x5) 174 | 175 | branch3x3dbl = self.branch3x3dbl_1(x) 176 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 177 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 178 | 179 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 180 | branch_pool = self.branch_pool(branch_pool) 181 | 182 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 183 | return torch.cat(outputs, 1) 184 | 185 | 186 | class InceptionB(nn.Module): 187 | 188 | def __init__(self, in_channels): 189 | super(InceptionB, self).__init__() 190 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 191 | 192 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 193 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 194 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 195 | 196 | def forward(self, x): 197 | branch3x3 = self.branch3x3(x) 198 | 199 | branch3x3dbl = self.branch3x3dbl_1(x) 200 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 201 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 202 | 203 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 204 | 205 | outputs = [branch3x3, branch3x3dbl, branch_pool] 206 | return torch.cat(outputs, 1) 207 | 208 | 209 | class InceptionC(nn.Module): 210 | 211 | def __init__(self, in_channels, channels_7x7): 212 | super(InceptionC, self).__init__() 213 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 214 | 215 | c7 = channels_7x7 216 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 217 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 218 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 219 | 220 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 221 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 222 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 223 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 224 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 225 | 226 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 227 | 228 | def forward(self, x): 229 | branch1x1 = self.branch1x1(x) 230 | 231 | branch7x7 = self.branch7x7_1(x) 232 | branch7x7 = self.branch7x7_2(branch7x7) 233 | branch7x7 = self.branch7x7_3(branch7x7) 234 | 235 | branch7x7dbl = self.branch7x7dbl_1(x) 236 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 237 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 238 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 239 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 240 | 241 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 242 | branch_pool = self.branch_pool(branch_pool) 243 | 244 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 245 | return torch.cat(outputs, 1) 246 | 247 | 248 | class InceptionD(nn.Module): 249 | 250 | def __init__(self, in_channels): 251 | super(InceptionD, self).__init__() 252 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 253 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 254 | 255 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 256 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 257 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 258 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 259 | 260 | def forward(self, x): 261 | branch3x3 = self.branch3x3_1(x) 262 | branch3x3 = self.branch3x3_2(branch3x3) 263 | 264 | branch7x7x3 = self.branch7x7x3_1(x) 265 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 266 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 267 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 268 | 269 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 270 | outputs = [branch3x3, branch7x7x3, branch_pool] 271 | return torch.cat(outputs, 1) 272 | 273 | 274 | class InceptionE(nn.Module): 275 | 276 | def __init__(self, in_channels): 277 | super(InceptionE, self).__init__() 278 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 279 | 280 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 281 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 282 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 283 | 284 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 285 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 286 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 287 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 288 | 289 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 290 | 291 | def forward(self, x): 292 | branch1x1 = self.branch1x1(x) 293 | 294 | branch3x3 = self.branch3x3_1(x) 295 | branch3x3 = [ 296 | self.branch3x3_2a(branch3x3), 297 | self.branch3x3_2b(branch3x3), 298 | ] 299 | branch3x3 = torch.cat(branch3x3, 1) 300 | 301 | branch3x3dbl = self.branch3x3dbl_1(x) 302 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 303 | branch3x3dbl = [ 304 | self.branch3x3dbl_3a(branch3x3dbl), 305 | self.branch3x3dbl_3b(branch3x3dbl), 306 | ] 307 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 308 | 309 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 310 | branch_pool = self.branch_pool(branch_pool) 311 | 312 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 313 | return torch.cat(outputs, 1) 314 | 315 | 316 | class InceptionAux(nn.Module): 317 | 318 | def __init__(self, in_channels, num_classes): 319 | super(InceptionAux, self).__init__() 320 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 321 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 322 | self.conv1.stddev = 0.01 323 | self.fc = nn.Linear(768, num_classes) 324 | self.fc.stddev = 0.001 325 | 326 | def forward(self, x): 327 | # N x 768 x 17 x 17 328 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 329 | # N x 768 x 5 x 5 330 | x = self.conv0(x) 331 | # N x 128 x 5 x 5 332 | x = self.conv1(x) 333 | # N x 768 x 1 x 1 334 | # Adaptive average pooling 335 | x = F.adaptive_avg_pool2d(x, (1, 1)) 336 | # N x 768 x 1 x 1 337 | x = torch.flatten(x, 1) 338 | # N x 768 339 | x = self.fc(x) 340 | # N x 1000 341 | return x 342 | 343 | 344 | class BasicConv2d(nn.Module): 345 | 346 | def __init__(self, in_channels, out_channels, **kwargs): 347 | super(BasicConv2d, self).__init__() 348 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 349 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 350 | 351 | def forward(self, x): 352 | x = self.conv(x) 353 | x = self.bn(x) 354 | return F.relu(x, inplace=True) 355 | -------------------------------------------------------------------------------- /models/mnasnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from .utils import load_state_dict_from_url 6 | 7 | __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] 8 | 9 | _MODEL_URLS = { 10 | "mnasnet0_5": 11 | "https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth", 12 | "mnasnet0_75": None, 13 | "mnasnet1_0": 14 | "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", 15 | "mnasnet1_3": None 16 | } 17 | 18 | # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is 19 | # 1.0 - tensorflow. 20 | _BN_MOMENTUM = 1 - 0.9997 21 | 22 | 23 | class _InvertedResidual(nn.Module): 24 | 25 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, 26 | bn_momentum=0.1): 27 | super(_InvertedResidual, self).__init__() 28 | assert stride in [1, 2] 29 | assert kernel_size in [3, 5] 30 | mid_ch = in_ch * expansion_factor 31 | self.apply_residual = (in_ch == out_ch and stride == 1) 32 | self.layers = nn.Sequential( 33 | # Pointwise 34 | nn.Conv2d(in_ch, mid_ch, 1, bias=False), 35 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 36 | nn.ReLU(inplace=True), 37 | # Depthwise 38 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, 39 | stride=stride, groups=mid_ch, bias=False), 40 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 41 | nn.ReLU(inplace=True), 42 | # Linear pointwise. Note that there's no activation. 43 | nn.Conv2d(mid_ch, out_ch, 1, bias=False), 44 | nn.BatchNorm2d(out_ch, momentum=bn_momentum)) 45 | 46 | def forward(self, input): 47 | if self.apply_residual: 48 | return self.layers(input) + input 49 | else: 50 | return self.layers(input) 51 | 52 | 53 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, 54 | bn_momentum): 55 | """ Creates a stack of inverted residuals. """ 56 | assert repeats >= 1 57 | # First one has no skip, because feature map size changes. 58 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, 59 | bn_momentum=bn_momentum) 60 | remaining = [] 61 | for _ in range(1, repeats): 62 | remaining.append( 63 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, 64 | bn_momentum=bn_momentum)) 65 | return nn.Sequential(first, *remaining) 66 | 67 | 68 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9): 69 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default 70 | bias, will round up, unless the number is no more than 10% greater than the 71 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ 72 | assert 0.0 < round_up_bias < 1.0 73 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) 74 | return new_val if new_val >= round_up_bias * val else new_val + divisor 75 | 76 | 77 | def _scale_depths(depths, alpha): 78 | """ Scales tensor depths as in reference MobileNet code, prefers rouding up 79 | rather than down. """ 80 | return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] 81 | 82 | 83 | class MNASNet(torch.nn.Module): 84 | """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. 85 | >>> model = MNASNet(1000, 1.0) 86 | >>> x = torch.rand(1, 3, 224, 224) 87 | >>> y = model(x) 88 | >>> y.dim() 89 | 1 90 | >>> y.nelement() 91 | 1000 92 | """ 93 | 94 | def __init__(self, alpha, num_classes=1000, dropout=0.2): 95 | super(MNASNet, self).__init__() 96 | depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) 97 | layers = [ 98 | # First layer: regular conv. 99 | nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), 100 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 101 | nn.ReLU(inplace=True), 102 | # Depthwise separable, no skip. 103 | nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), 104 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 105 | nn.ReLU(inplace=True), 106 | nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), 107 | nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), 108 | # MNASNet blocks: stacks of inverted residuals. 109 | _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM), 110 | _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM), 111 | _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM), 112 | _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM), 113 | _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM), 114 | _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM), 115 | # Final mapping to classifier input. 116 | nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False), 117 | nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), 118 | nn.ReLU(inplace=True), 119 | ] 120 | self.layers = nn.Sequential(*layers) 121 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), 122 | nn.Linear(1280, num_classes)) 123 | self._initialize_weights() 124 | 125 | def forward(self, x): 126 | x = self.layers(x) 127 | # Equivalent to global avgpool and removing H and W dimensions. 128 | x = x.mean([2, 3]) 129 | return self.classifier(x) 130 | 131 | def _initialize_weights(self): 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 135 | nonlinearity="relu") 136 | if m.bias is not None: 137 | nn.init.zeros_(m.bias) 138 | elif isinstance(m, nn.BatchNorm2d): 139 | nn.init.ones_(m.weight) 140 | nn.init.zeros_(m.bias) 141 | elif isinstance(m, nn.Linear): 142 | nn.init.normal_(m.weight, 0.01) 143 | nn.init.zeros_(m.bias) 144 | 145 | 146 | def _load_pretrained(model_name, model, progress): 147 | if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: 148 | raise ValueError( 149 | "No checkpoint is available for model type {}".format(model_name)) 150 | checkpoint_url = _MODEL_URLS[model_name] 151 | model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) 152 | 153 | 154 | def mnasnet0_5(pretrained=False, progress=True, **kwargs): 155 | """MNASNet with depth multiplier of 0.5 from 156 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 157 | `_. 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | progress (bool): If True, displays a progress bar of the download to stderr 161 | """ 162 | model = MNASNet(0.5, **kwargs) 163 | if pretrained: 164 | _load_pretrained("mnasnet0_5", model, progress) 165 | return model 166 | 167 | 168 | def mnasnet0_75(pretrained=False, progress=True, **kwargs): 169 | """MNASNet with depth multiplier of 0.75 from 170 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 171 | `_. 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | progress (bool): If True, displays a progress bar of the download to stderr 175 | """ 176 | model = MNASNet(0.75, **kwargs) 177 | if pretrained: 178 | _load_pretrained("mnasnet0_75", model, progress) 179 | return model 180 | 181 | 182 | def mnasnet1_0(pretrained=False, progress=True, **kwargs): 183 | """MNASNet with depth multiplier of 1.0 from 184 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 185 | `_. 186 | Args: 187 | pretrained (bool): If True, returns a model pre-trained on ImageNet 188 | progress (bool): If True, displays a progress bar of the download to stderr 189 | """ 190 | model = MNASNet(1.0, **kwargs) 191 | if pretrained: 192 | _load_pretrained("mnasnet1_0", model, progress) 193 | return model 194 | 195 | 196 | def mnasnet1_3(pretrained=False, progress=True, **kwargs): 197 | """MNASNet with depth multiplier of 1.3 from 198 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 199 | `_. 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | progress (bool): If True, displays a progress bar of the download to stderr 203 | """ 204 | model = MNASNet(1.3, **kwargs) 205 | if pretrained: 206 | _load_pretrained("mnasnet1_3", model, progress) 207 | return model 208 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .utils import load_state_dict_from_url 3 | 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class ConvBNReLU(nn.Sequential): 34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 35 | padding = (kernel_size - 1) // 2 36 | super(ConvBNReLU, self).__init__( 37 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 38 | nn.BatchNorm2d(out_planes), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | 43 | class InvertedResidual(nn.Module): 44 | def __init__(self, inp, oup, stride, expand_ratio): 45 | super(InvertedResidual, self).__init__() 46 | self.stride = stride 47 | assert stride in [1, 2] 48 | 49 | hidden_dim = int(round(inp * expand_ratio)) 50 | self.use_res_connect = self.stride == 1 and inp == oup 51 | 52 | layers = [] 53 | if expand_ratio != 1: 54 | # pw 55 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 56 | layers.extend([ 57 | # dw 58 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 59 | # pw-linear 60 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 61 | nn.BatchNorm2d(oup), 62 | ]) 63 | self.conv = nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | if self.use_res_connect: 67 | return x + self.conv(x) 68 | else: 69 | return self.conv(x) 70 | 71 | 72 | class MobileNetV2(nn.Module): 73 | def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): 74 | """ 75 | MobileNet V2 main class 76 | 77 | Args: 78 | num_classes (int): Number of classes 79 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 80 | inverted_residual_setting: Network structure 81 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 82 | Set to 1 to turn off rounding 83 | """ 84 | super(MobileNetV2, self).__init__() 85 | block = InvertedResidual 86 | input_channel = 32 87 | last_channel = 1280 88 | 89 | if inverted_residual_setting is None: 90 | inverted_residual_setting = [ 91 | # t, c, n, s 92 | [1, 16, 1, 1], 93 | [6, 24, 2, 2], 94 | [6, 32, 3, 2], 95 | [6, 64, 4, 2], 96 | [6, 96, 3, 1], 97 | [6, 160, 3, 2], 98 | [6, 320, 1, 1], 99 | ] 100 | 101 | # only check the first element, assuming user knows t,c,n,s are required 102 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 103 | raise ValueError("inverted_residual_setting should be non-empty " 104 | "or a 4-element list, got {}".format(inverted_residual_setting)) 105 | 106 | # building first layer 107 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 108 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 109 | features = [ConvBNReLU(3, input_channel, stride=2)] 110 | # building inverted residual blocks 111 | for t, c, n, s in inverted_residual_setting: 112 | output_channel = _make_divisible(c * width_mult, round_nearest) 113 | for i in range(n): 114 | stride = s if i == 0 else 1 115 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 116 | input_channel = output_channel 117 | # building last several layers 118 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 119 | # make it nn.Sequential 120 | self.features = nn.Sequential(*features) 121 | 122 | # building classifier 123 | self.classifier = nn.Sequential( 124 | nn.Dropout(0.2), 125 | nn.Linear(self.last_channel, num_classes), 126 | ) 127 | 128 | # weight initialization 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 132 | if m.bias is not None: 133 | nn.init.zeros_(m.bias) 134 | elif isinstance(m, nn.BatchNorm2d): 135 | nn.init.ones_(m.weight) 136 | nn.init.zeros_(m.bias) 137 | elif isinstance(m, nn.Linear): 138 | nn.init.normal_(m.weight, 0, 0.01) 139 | nn.init.zeros_(m.bias) 140 | 141 | def forward(self, x): 142 | x = self.features(x) 143 | x = x.mean([2, 3]) 144 | x = self.classifier(x) 145 | return x 146 | 147 | 148 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 149 | """ 150 | Constructs a MobileNetV2 architecture from 151 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 152 | 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | progress (bool): If True, displays a progress bar of the download to stderr 156 | """ 157 | model = MobileNetV2(**kwargs) 158 | if pretrained: 159 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 160 | progress=progress) 161 | model.load_state_dict(state_dict) 162 | return model 163 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | __constants__ = ['downsample'] 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | 79 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 80 | base_width=64, dilation=1, norm_layer=None): 81 | super(Bottleneck, self).__init__() 82 | if norm_layer is None: 83 | norm_layer = nn.BatchNorm2d 84 | width = int(planes * (base_width / 64.)) * groups 85 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 86 | self.conv1 = conv1x1(inplanes, width) 87 | self.bn1 = norm_layer(width) 88 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 89 | self.bn2 = norm_layer(width) 90 | self.conv3 = conv1x1(width, planes * self.expansion) 91 | self.bn3 = norm_layer(planes * self.expansion) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | def forward(self, x): 97 | identity = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | if self.downsample is not None: 111 | identity = self.downsample(x) 112 | 113 | out += identity 114 | out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class ResNet(nn.Module): 120 | 121 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 122 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 123 | norm_layer=None): 124 | super(ResNet, self).__init__() 125 | if norm_layer is None: 126 | norm_layer = nn.BatchNorm2d 127 | self._norm_layer = norm_layer 128 | 129 | self.inplanes = 64 130 | self.dilation = 1 131 | if replace_stride_with_dilation is None: 132 | # each element in the tuple indicates if we should replace 133 | # the 2x2 stride with a dilated convolution instead 134 | replace_stride_with_dilation = [False, False, False] 135 | if len(replace_stride_with_dilation) != 3: 136 | raise ValueError("replace_stride_with_dilation should be None " 137 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 138 | self.groups = groups 139 | self.base_width = width_per_group 140 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 141 | bias=False) 142 | self.bn1 = norm_layer(self.inplanes) 143 | self.relu = nn.ReLU(inplace=True) 144 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 145 | self.layer1 = self._make_layer(block, 64, layers[0]) 146 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 147 | dilate=replace_stride_with_dilation[0]) 148 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 149 | dilate=replace_stride_with_dilation[1]) 150 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 151 | dilate=replace_stride_with_dilation[2]) 152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | # Zero-initialize the last BN in each residual branch, 163 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 164 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 165 | if zero_init_residual: 166 | for m in self.modules(): 167 | if isinstance(m, Bottleneck): 168 | nn.init.constant_(m.bn3.weight, 0) 169 | elif isinstance(m, BasicBlock): 170 | nn.init.constant_(m.bn2.weight, 0) 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 173 | norm_layer = self._norm_layer 174 | downsample = None 175 | previous_dilation = self.dilation 176 | if dilate: 177 | self.dilation *= stride 178 | stride = 1 179 | if stride != 1 or self.inplanes != planes * block.expansion: 180 | downsample = nn.Sequential( 181 | conv1x1(self.inplanes, planes * block.expansion, stride), 182 | norm_layer(planes * block.expansion), 183 | ) 184 | 185 | layers = [] 186 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 187 | self.base_width, previous_dilation, norm_layer)) 188 | self.inplanes = planes * block.expansion 189 | for _ in range(1, blocks): 190 | layers.append(block(self.inplanes, planes, groups=self.groups, 191 | base_width=self.base_width, dilation=self.dilation, 192 | norm_layer=norm_layer)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | x = self.conv1(x) 198 | x = self.bn1(x) 199 | x = self.relu(x) 200 | x = self.maxpool(x) 201 | 202 | x = self.layer1(x) 203 | x = self.layer2(x) 204 | x = self.layer3(x) 205 | x = self.layer4(x) 206 | 207 | x = self.avgpool(x) 208 | x = torch.flatten(x, 1) 209 | x = self.fc(x) 210 | 211 | return x 212 | 213 | 214 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 215 | model = ResNet(block, layers, **kwargs) 216 | if pretrained: 217 | state_dict = load_state_dict_from_url(model_urls[arch], 218 | progress=progress) 219 | model.load_state_dict(state_dict) 220 | return model 221 | 222 | 223 | def resnet18(pretrained=False, progress=True, **kwargs): 224 | r"""ResNet-18 model from 225 | `"Deep Residual Learning for Image Recognition" `_ 226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | progress (bool): If True, displays a progress bar of the download to stderr 230 | """ 231 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 232 | **kwargs) 233 | 234 | 235 | def resnet34(pretrained=False, progress=True, **kwargs): 236 | r"""ResNet-34 model from 237 | `"Deep Residual Learning for Image Recognition" `_ 238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 244 | **kwargs) 245 | 246 | 247 | def resnet50(pretrained=False, progress=True, **kwargs): 248 | r"""ResNet-50 model from 249 | `"Deep Residual Learning for Image Recognition" `_ 250 | 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 256 | **kwargs) 257 | 258 | 259 | def resnet101(pretrained=False, progress=True, **kwargs): 260 | r"""ResNet-101 model from 261 | `"Deep Residual Learning for Image Recognition" `_ 262 | 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 | """ 267 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 268 | **kwargs) 269 | 270 | 271 | def resnet152(pretrained=False, progress=True, **kwargs): 272 | r"""ResNet-152 model from 273 | `"Deep Residual Learning for Image Recognition" `_ 274 | 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | progress (bool): If True, displays a progress bar of the download to stderr 278 | """ 279 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 280 | **kwargs) 281 | 282 | 283 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 284 | r"""ResNeXt-50 32x4d model from 285 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 286 | 287 | Args: 288 | pretrained (bool): If True, returns a model pre-trained on ImageNet 289 | progress (bool): If True, displays a progress bar of the download to stderr 290 | """ 291 | kwargs['groups'] = 32 292 | kwargs['width_per_group'] = 4 293 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 294 | pretrained, progress, **kwargs) 295 | 296 | 297 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 298 | r"""ResNeXt-101 32x8d model from 299 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 300 | 301 | Args: 302 | pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | progress (bool): If True, displays a progress bar of the download to stderr 304 | """ 305 | kwargs['groups'] = 32 306 | kwargs['width_per_group'] = 8 307 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 308 | pretrained, progress, **kwargs) 309 | 310 | 311 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 312 | r"""Wide ResNet-50-2 model from 313 | `"Wide Residual Networks" `_ 314 | 315 | The model is the same as ResNet except for the bottleneck number of channels 316 | which is twice larger in every block. The number of channels in outer 1x1 317 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 318 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | progress (bool): If True, displays a progress bar of the download to stderr 323 | """ 324 | kwargs['width_per_group'] = 64 * 2 325 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 326 | pretrained, progress, **kwargs) 327 | 328 | 329 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 330 | r"""Wide ResNet-101-2 model from 331 | `"Wide Residual Networks" `_ 332 | 333 | The model is the same as ResNet except for the bottleneck number of channels 334 | which is twice larger in every block. The number of channels in outer 1x1 335 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 336 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 337 | 338 | Args: 339 | pretrained (bool): If True, returns a model pre-trained on ImageNet 340 | progress (bool): If True, displays a progress bar of the download to stderr 341 | """ 342 | kwargs['width_per_group'] = 64 * 2 343 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 344 | pretrained, progress, **kwargs) 345 | -------------------------------------------------------------------------------- /models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = [ 7 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 8 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' 9 | ] 10 | 11 | model_urls = { 12 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 13 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 14 | 'shufflenetv2_x1.5': None, 15 | 'shufflenetv2_x2.0': None, 16 | } 17 | 18 | 19 | def channel_shuffle(x, groups): 20 | # type: (torch.Tensor, int) -> torch.Tensor 21 | batchsize, num_channels, height, width = x.data.size() 22 | channels_per_group = num_channels // groups 23 | 24 | # reshape 25 | x = x.view(batchsize, groups, 26 | channels_per_group, height, width) 27 | 28 | x = torch.transpose(x, 1, 2).contiguous() 29 | 30 | # flatten 31 | x = x.view(batchsize, -1, height, width) 32 | 33 | return x 34 | 35 | 36 | class InvertedResidual(nn.Module): 37 | def __init__(self, inp, oup, stride): 38 | super(InvertedResidual, self).__init__() 39 | 40 | if not (1 <= stride <= 3): 41 | raise ValueError('illegal stride value') 42 | self.stride = stride 43 | 44 | branch_features = oup // 2 45 | assert (self.stride != 1) or (inp == branch_features << 1) 46 | 47 | if self.stride > 1: 48 | self.branch1 = nn.Sequential( 49 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 50 | nn.BatchNorm2d(inp), 51 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 52 | nn.BatchNorm2d(branch_features), 53 | nn.ReLU(inplace=True), 54 | ) 55 | else: 56 | self.branch1 = nn.Sequential() 57 | 58 | self.branch2 = nn.Sequential( 59 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 60 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 61 | nn.BatchNorm2d(branch_features), 62 | nn.ReLU(inplace=True), 63 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 64 | nn.BatchNorm2d(branch_features), 65 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 66 | nn.BatchNorm2d(branch_features), 67 | nn.ReLU(inplace=True), 68 | ) 69 | 70 | @staticmethod 71 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 72 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 73 | 74 | def forward(self, x): 75 | if self.stride == 1: 76 | x1, x2 = x.chunk(2, dim=1) 77 | out = torch.cat((x1, self.branch2(x2)), dim=1) 78 | else: 79 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 80 | 81 | out = channel_shuffle(out, 2) 82 | 83 | return out 84 | 85 | 86 | class ShuffleNetV2(nn.Module): 87 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000): 88 | super(ShuffleNetV2, self).__init__() 89 | 90 | if len(stages_repeats) != 3: 91 | raise ValueError('expected stages_repeats as list of 3 positive ints') 92 | if len(stages_out_channels) != 5: 93 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 94 | self._stage_out_channels = stages_out_channels 95 | 96 | input_channels = 3 97 | output_channels = self._stage_out_channels[0] 98 | self.conv1 = nn.Sequential( 99 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 100 | nn.BatchNorm2d(output_channels), 101 | nn.ReLU(inplace=True), 102 | ) 103 | input_channels = output_channels 104 | 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | 107 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 108 | for name, repeats, output_channels in zip( 109 | stage_names, stages_repeats, self._stage_out_channels[1:]): 110 | seq = [InvertedResidual(input_channels, output_channels, 2)] 111 | for i in range(repeats - 1): 112 | seq.append(InvertedResidual(output_channels, output_channels, 1)) 113 | setattr(self, name, nn.Sequential(*seq)) 114 | input_channels = output_channels 115 | 116 | output_channels = self._stage_out_channels[-1] 117 | self.conv5 = nn.Sequential( 118 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 119 | nn.BatchNorm2d(output_channels), 120 | nn.ReLU(inplace=True), 121 | ) 122 | 123 | self.fc = nn.Linear(output_channels, num_classes) 124 | 125 | def forward(self, x): 126 | x = self.conv1(x) 127 | x = self.maxpool(x) 128 | x = self.stage2(x) 129 | x = self.stage3(x) 130 | x = self.stage4(x) 131 | x = self.conv5(x) 132 | x = x.mean([2, 3]) # globalpool 133 | x = self.fc(x) 134 | return x 135 | 136 | 137 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 138 | model = ShuffleNetV2(*args, **kwargs) 139 | 140 | if pretrained: 141 | model_url = model_urls[arch] 142 | if model_url is None: 143 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 144 | else: 145 | state_dict = load_state_dict_from_url(model_url, progress=progress) 146 | model.load_state_dict(state_dict) 147 | 148 | return model 149 | 150 | 151 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 152 | """ 153 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 154 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 155 | `_. 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | progress (bool): If True, displays a progress bar of the download to stderr 160 | """ 161 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 162 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 163 | 164 | 165 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 166 | """ 167 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 168 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 169 | `_. 170 | 171 | Args: 172 | pretrained (bool): If True, returns a model pre-trained on ImageNet 173 | progress (bool): If True, displays a progress bar of the download to stderr 174 | """ 175 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 176 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 177 | 178 | 179 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): 180 | """ 181 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 182 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 183 | `_. 184 | 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | progress (bool): If True, displays a progress bar of the download to stderr 188 | """ 189 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, 190 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 191 | 192 | 193 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): 194 | """ 195 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 196 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 197 | `_. 198 | 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | progress (bool): If True, displays a progress bar of the download to stderr 202 | """ 203 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, 204 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 205 | -------------------------------------------------------------------------------- /models/squeezenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from .utils import load_state_dict_from_url 5 | 6 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] 7 | 8 | model_urls = { 9 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 10 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', 11 | } 12 | 13 | 14 | class Fire(nn.Module): 15 | 16 | def __init__(self, inplanes, squeeze_planes, 17 | expand1x1_planes, expand3x3_planes): 18 | super(Fire, self).__init__() 19 | self.inplanes = inplanes 20 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 21 | self.squeeze_activation = nn.ReLU(inplace=True) 22 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 23 | kernel_size=1) 24 | self.expand1x1_activation = nn.ReLU(inplace=True) 25 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 26 | kernel_size=3, padding=1) 27 | self.expand3x3_activation = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | x = self.squeeze_activation(self.squeeze(x)) 31 | return torch.cat([ 32 | self.expand1x1_activation(self.expand1x1(x)), 33 | self.expand3x3_activation(self.expand3x3(x)) 34 | ], 1) 35 | 36 | 37 | class SqueezeNet(nn.Module): 38 | 39 | def __init__(self, version='1_0', num_classes=1000): 40 | super(SqueezeNet, self).__init__() 41 | self.num_classes = num_classes 42 | if version == '1_0': 43 | self.features = nn.Sequential( 44 | nn.Conv2d(3, 96, kernel_size=7, stride=2), 45 | nn.ReLU(inplace=True), 46 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 47 | Fire(96, 16, 64, 64), 48 | Fire(128, 16, 64, 64), 49 | Fire(128, 32, 128, 128), 50 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 51 | Fire(256, 32, 128, 128), 52 | Fire(256, 48, 192, 192), 53 | Fire(384, 48, 192, 192), 54 | Fire(384, 64, 256, 256), 55 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 56 | Fire(512, 64, 256, 256), 57 | ) 58 | elif version == '1_1': 59 | self.features = nn.Sequential( 60 | nn.Conv2d(3, 64, kernel_size=3, stride=2), 61 | nn.ReLU(inplace=True), 62 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 63 | Fire(64, 16, 64, 64), 64 | Fire(128, 16, 64, 64), 65 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 66 | Fire(128, 32, 128, 128), 67 | Fire(256, 32, 128, 128), 68 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 69 | Fire(256, 48, 192, 192), 70 | Fire(384, 48, 192, 192), 71 | Fire(384, 64, 256, 256), 72 | Fire(512, 64, 256, 256), 73 | ) 74 | else: 75 | # FIXME: Is this needed? SqueezeNet should only be called from the 76 | # FIXME: squeezenet1_x() functions 77 | # FIXME: This checking is not done for the other models 78 | raise ValueError("Unsupported SqueezeNet version {version}:" 79 | "1_0 or 1_1 expected".format(version=version)) 80 | 81 | # Final convolution is initialized differently from the rest 82 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) 83 | self.classifier = nn.Sequential( 84 | nn.Dropout(p=0.5), 85 | final_conv, 86 | nn.ReLU(inplace=True), 87 | nn.AdaptiveAvgPool2d((1, 1)) 88 | ) 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | if m is final_conv: 93 | init.normal_(m.weight, mean=0.0, std=0.01) 94 | else: 95 | init.kaiming_uniform_(m.weight) 96 | if m.bias is not None: 97 | init.constant_(m.bias, 0) 98 | 99 | def forward(self, x): 100 | x = self.features(x) 101 | x = self.classifier(x) 102 | return torch.flatten(x, 1) 103 | 104 | 105 | def _squeezenet(version, pretrained, progress, **kwargs): 106 | model = SqueezeNet(version, **kwargs) 107 | if pretrained: 108 | arch = 'squeezenet' + version 109 | state_dict = load_state_dict_from_url(model_urls[arch], 110 | progress=progress) 111 | model.load_state_dict(state_dict) 112 | return model 113 | 114 | 115 | def squeezenet1_0(pretrained=False, progress=True, **kwargs): 116 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 117 | accuracy with 50x fewer parameters and <0.5MB model size" 118 | `_ paper. 119 | 120 | Args: 121 | pretrained (bool): If True, returns a model pre-trained on ImageNet 122 | progress (bool): If True, displays a progress bar of the download to stderr 123 | """ 124 | return _squeezenet('1_0', pretrained, progress, **kwargs) 125 | 126 | 127 | def squeezenet1_1(pretrained=False, progress=True, **kwargs): 128 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo 129 | `_. 130 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 131 | than SqueezeNet 1.0, without sacrificing accuracy. 132 | 133 | Args: 134 | pretrained (bool): If True, returns a model pre-trained on ImageNet 135 | progress (bool): If True, displays a progress bar of the download to stderr 136 | """ 137 | return _squeezenet('1_1', pretrained, progress, **kwargs) 138 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch.hub import load_state_dict_from_url 3 | except ImportError: 4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 5 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class VGG(nn.Module): 25 | 26 | def __init__(self, features, num_classes=1000, init_weights=True): 27 | super(VGG, self).__init__() 28 | self.features = features 29 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 30 | self.classifier = nn.Sequential( 31 | nn.Linear(512 * 7 * 7, 4096), 32 | nn.ReLU(True), 33 | nn.Dropout(), 34 | nn.Linear(4096, 4096), 35 | nn.ReLU(True), 36 | nn.Dropout(), 37 | nn.Linear(4096, num_classes), 38 | ) 39 | if init_weights: 40 | self._initialize_weights() 41 | 42 | def forward(self, x): 43 | x = self.features(x) 44 | x = self.avgpool(x) 45 | x = torch.flatten(x, 1) 46 | x = self.classifier(x) 47 | return x 48 | 49 | def _initialize_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 53 | if m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.Linear): 59 | nn.init.normal_(m.weight, 0, 0.01) 60 | nn.init.constant_(m.bias, 0) 61 | 62 | 63 | def make_layers(cfg, batch_norm=False): 64 | layers = [] 65 | in_channels = 3 66 | for v in cfg: 67 | if v == 'M': 68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 69 | else: 70 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 71 | if batch_norm: 72 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 73 | else: 74 | layers += [conv2d, nn.ReLU(inplace=True)] 75 | in_channels = v 76 | return nn.Sequential(*layers) 77 | 78 | 79 | cfgs = { 80 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 81 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 82 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 83 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 84 | } 85 | 86 | 87 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 88 | if pretrained: 89 | kwargs['init_weights'] = False 90 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 91 | if pretrained: 92 | state_dict = load_state_dict_from_url(model_urls[arch], 93 | progress=progress) 94 | model.load_state_dict(state_dict) 95 | return model 96 | 97 | 98 | def vgg11(pretrained=False, progress=True, **kwargs): 99 | r"""VGG 11-layer model (configuration "A") from 100 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 101 | 102 | Args: 103 | pretrained (bool): If True, returns a model pre-trained on ImageNet 104 | progress (bool): If True, displays a progress bar of the download to stderr 105 | """ 106 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 107 | 108 | 109 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 110 | r"""VGG 11-layer model (configuration "A") with batch normalization 111 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 112 | 113 | Args: 114 | pretrained (bool): If True, returns a model pre-trained on ImageNet 115 | progress (bool): If True, displays a progress bar of the download to stderr 116 | """ 117 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 118 | 119 | 120 | def vgg13(pretrained=False, progress=True, **kwargs): 121 | r"""VGG 13-layer model (configuration "B") 122 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 123 | 124 | Args: 125 | pretrained (bool): If True, returns a model pre-trained on ImageNet 126 | progress (bool): If True, displays a progress bar of the download to stderr 127 | """ 128 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 129 | 130 | 131 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 132 | r"""VGG 13-layer model (configuration "B") with batch normalization 133 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 134 | 135 | Args: 136 | pretrained (bool): If True, returns a model pre-trained on ImageNet 137 | progress (bool): If True, displays a progress bar of the download to stderr 138 | """ 139 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 140 | 141 | 142 | def vgg16(pretrained=False, progress=True, **kwargs): 143 | r"""VGG 16-layer model (configuration "D") 144 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 145 | 146 | Args: 147 | pretrained (bool): If True, returns a model pre-trained on ImageNet 148 | progress (bool): If True, displays a progress bar of the download to stderr 149 | """ 150 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 151 | 152 | 153 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 154 | r"""VGG 16-layer model (configuration "D") with batch normalization 155 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | progress (bool): If True, displays a progress bar of the download to stderr 160 | """ 161 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 162 | 163 | 164 | def vgg19(pretrained=False, progress=True, **kwargs): 165 | r"""VGG 19-layer model (configuration "E") 166 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 167 | 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | progress (bool): If True, displays a progress bar of the download to stderr 171 | """ 172 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 173 | 174 | 175 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 176 | r"""VGG 19-layer model (configuration 'E') with batch normalization 177 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 178 | 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | progress (bool): If True, displays a progress bar of the download to stderr 182 | """ 183 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 184 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from PIL import Image 4 | 5 | import json 6 | import matplotlib.pyplot as plt 7 | import main 8 | 9 | #获取索引到类名的映射,以便查看测试影像的输出类 10 | idx_to_class = {v: k for k, v in main.train_datasets.class_to_idx.items()} 11 | print(idx_to_class) 12 | 13 | with open('cat_to_name.json', 'r', encoding='gbk') as f: 14 | cat_to_name = json.load(f) 15 | print(cat_to_name) 16 | 17 | def predict(model, test_image_name): 18 | 19 | transform = main.test_valid_transforms 20 | 21 | test_image = Image.open(test_image_name).convert('RGB') 22 | plt.imshow(test_image) 23 | 24 | test_image_tensor = transform(test_image) 25 | 26 | if torch.cuda.is_available(): 27 | test_image_tensor = test_image_tensor.view(1, 3, 124, 124).cuda() 28 | else: 29 | test_image_tensor = test_image_tensor.view(1, 3, 124, 124) 30 | 31 | with torch.no_grad(): 32 | model.eval() 33 | # Model outputs log probabilities 34 | start = time.time() 35 | out = model(test_image_tensor) 36 | stop = time.time() 37 | print('cost time', stop - start) 38 | ps = torch.exp(out) 39 | topk, topclass = ps.topk(3, dim=1) 40 | names = [] 41 | for i in range(3): 42 | names.append(cat_to_name[idx_to_class[topclass.cpu().numpy()[0][i]]]) 43 | print("Predcition", i + 1, ":", names[i], ", Score: ", 44 | topk.cpu().numpy()[0][i]) 45 | 46 | if __name__ == '__main__': 47 | model = torch.load('trained_models/resnet50_model_23.pth') 48 | predict(model, '61.png') -------------------------------------------------------------------------------- /trained_models/data_record.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MUYang99/Pytorch_Remote-Sensing-Image-Classification/a9e4c94f02f2883f89cc53102e1372d2d1ece763/trained_models/data_record.pth --------------------------------------------------------------------------------