├── ImageNet └── ILSVRC2012_devkit_t12 │ └── data │ ├── ILSVRC2012_validation_ground_truth.txt │ └── meta.mat ├── README.md ├── read_ImageNetData.py ├── resnext.py └── train.py /ImageNet/ILSVRC2012_devkit_t12/data/meta.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/miraclewkf/ResNeXt-PyTorch/0b1ad0252aba44b070a3d191d209443fb08f9f9c/ImageNet/ILSVRC2012_devkit_t12/data/meta.mat -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## This is the PyTorch implement of ResNeXt (train on ImageNet dataset) 2 | 3 | Paper: [Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/abs/1611.05431) 4 | 5 | 6 | # Usage 7 | 8 | ### Prepare data 9 | 10 | This code takes ImageNet dataset as example. You can download ImageNet dataset and put them as follows. I only provide `ILSVRC2012_dev_kit_t12` due to the restriction of memory, in other words, you need download `ILSVRC2012_img_train` and `ILSVRC2012_img_val`. 11 | 12 | ``` 13 | ├── train.py # train script 14 | ├── resnext.py # network of resnext 15 | ├── read_ImageNetData.py # ImageNet dataset read script 16 | ├── ImageData # train and validation data 17 | ├── ILSVRC2012_img_train 18 | ├── n01440764 19 | ├── ... 20 | ├── n15075141 21 | ├── ILSVRC2012_img_val 22 | ├── ILSVRC2012_dev_kit_t12 23 | ├── data 24 | ├── ILSVRC2012_validation_ground_truth.txt 25 | ├── meta.mat # the map between train file name and label 26 | ``` 27 | 28 | ### Train 29 | 30 | * If you want to train from scratch, you can run as follows: 31 | 32 | ``` 33 | python train.py --batch-size 256 --gpus 0,1,2,3 34 | ``` 35 | 36 | * If you want to train from one checkpoint, you can run as follows(for example train from `epoch_4.pth.tar`, the `--start-epoch` parameter is corresponding to the epoch of the checkpoint): 37 | 38 | ``` 39 | python train.py --batch-size 256 --gpus 0,1,2,3 --resume output/epoch_4.pth.tar --start-epoch 4 40 | ``` -------------------------------------------------------------------------------- /read_ImageNetData.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms, datasets 2 | import os 3 | import torch 4 | from PIL import Image 5 | import scipy.io as scio 6 | 7 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 8 | 9 | def ImageNetData(args): 10 | # data_transform, pay attention that the input of Normalize() is Tensor and the input of RandomResizedCrop() or RandomHorizontalFlip() is PIL Image 11 | data_transforms = { 12 | 'train': transforms.Compose([ 13 | transforms.Resize(256), 14 | transforms.RandomCrop(224), 15 | transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 18 | ]), 19 | 'val': transforms.Compose([ 20 | transforms.Resize(256), 21 | transforms.CenterCrop(224), 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 24 | ]), 25 | } 26 | image_datasets = {} 27 | #image_datasets['train'] = datasets.ImageFolder(os.path.join(args.data_dir, 'ILSVRC2012_img_train'), data_transforms['train']) 28 | 29 | image_datasets['train'] = ImageNetTrainDataSet(os.path.join(args.data_dir, 'ILSVRC2012_img_train'), 30 | os.path.join(args.data_dir, 'ILSVRC2012_devkit_t12', 'data', 'meta.mat'), 31 | data_transforms['train']) 32 | image_datasets['val'] = ImageNetValDataSet(os.path.join(args.data_dir, 'ILSVRC2012_img_val'), 33 | os.path.join(args.data_dir, 'ILSVRC2012_devkit_t12', 'data','ILSVRC2012_validation_ground_truth.txt'), 34 | data_transforms['val']) 35 | 36 | # wrap your data and label into Tensor 37 | dataloders = {x: torch.utils.data.DataLoader(image_datasets[x], 38 | batch_size=args.batch_size, 39 | shuffle=True, 40 | num_workers=args.num_workers) for x in ['train', 'val']} 41 | 42 | 43 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} 44 | return dataloders, dataset_sizes 45 | 46 | class ImageNetTrainDataSet(torch.utils.data.Dataset): 47 | def __init__(self, root_dir, img_label, data_transforms): 48 | label_array = scio.loadmat(img_label)['synsets'] 49 | label_dic = {} 50 | for i in range(1000): 51 | label_dic[label_array[i][0][1][0]] = i 52 | self.img_path = os.listdir(root_dir) 53 | self.data_transforms = data_transforms 54 | self.label_dic = label_dic 55 | self.root_dir = root_dir 56 | self.imgs = self._make_dataset() 57 | 58 | def __len__(self): 59 | return len(self.imgs) 60 | 61 | def __getitem__(self, item): 62 | data, label = self.imgs[item] 63 | img = Image.open(data).convert('RGB') 64 | if self.data_transforms is not None: 65 | try: 66 | img = self.data_transforms(img) 67 | except: 68 | print("Cannot transform image: {}".format(self.img_path[item])) 69 | return img, label 70 | 71 | def _make_dataset(self): 72 | class_to_idx = self.label_dic 73 | images = [] 74 | dir = os.path.expanduser(self.root_dir) 75 | for target in sorted(os.listdir(dir)): 76 | d = os.path.join(dir, target) 77 | if not os.path.isdir(d): 78 | continue 79 | 80 | for root, _, fnames in sorted(os.walk(d)): 81 | for fname in sorted(fnames): 82 | if self._is_image_file(fname): 83 | path = os.path.join(root, fname) 84 | item = (path, class_to_idx[target]) 85 | images.append(item) 86 | 87 | return images 88 | 89 | def _is_image_file(self, filename): 90 | """Checks if a file is an image. 91 | 92 | Args: 93 | filename (string): path to a file 94 | 95 | Returns: 96 | bool: True if the filename ends with a known image extension 97 | """ 98 | filename_lower = filename.lower() 99 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 100 | 101 | class ImageNetValDataSet(torch.utils.data.Dataset): 102 | def __init__(self, img_path, img_label, data_transforms): 103 | self.data_transforms = data_transforms 104 | img_names = os.listdir(img_path) 105 | img_names.sort() 106 | self.img_path = [os.path.join(img_path, img_name) for img_name in img_names] 107 | with open(img_label,"r") as input_file: 108 | lines = input_file.readlines() 109 | self.img_label = [(int(line)-1) for line in lines] 110 | 111 | def __len__(self): 112 | return len(self.img_path) 113 | 114 | def __getitem__(self, item): 115 | img = Image.open(self.img_path[item]).convert('RGB') 116 | label = self.img_label[item] 117 | if self.data_transforms is not None: 118 | try: 119 | img = self.data_transforms(img) 120 | except: 121 | print("Cannot transform image: {}".format(self.img_path[item])) 122 | return img, label 123 | -------------------------------------------------------------------------------- /resnext.py: -------------------------------------------------------------------------------- 1 | ''' 2 | New for ResNeXt: 3 | 1. Wider bottleneck 4 | 2. Add group for conv2 5 | ''' 6 | 7 | import torch.nn as nn 8 | import math 9 | 10 | __all__ = ['ResNeXt', 'resnext18', 'resnext34', 'resnext50', 'resnext101', 11 | 'resnext152'] 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, num_group=32): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes*2, stride) 25 | self.bn1 = nn.BatchNorm2d(planes*2) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes*2, planes*2, groups=num_group) 28 | self.bn2 = nn.BatchNorm2d(planes*2) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None, num_group=32): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes*2, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes*2) 58 | self.conv2 = nn.Conv2d(planes*2, planes*2, kernel_size=3, stride=stride, 59 | padding=1, bias=False, groups=num_group) 60 | self.bn2 = nn.BatchNorm2d(planes*2) 61 | self.conv3 = nn.Conv2d(planes*2, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class ResNeXt(nn.Module): 91 | 92 | def __init__(self, block, layers, num_classes=1000, num_group=32): 93 | self.inplanes = 64 94 | super(ResNeXt, self).__init__() 95 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 96 | bias=False) 97 | self.bn1 = nn.BatchNorm2d(64) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 100 | self.layer1 = self._make_layer(block, 64, layers[0], num_group) 101 | self.layer2 = self._make_layer(block, 128, layers[1], num_group, stride=2) 102 | self.layer3 = self._make_layer(block, 256, layers[2], num_group, stride=2) 103 | self.layer4 = self._make_layer(block, 512, layers[3], num_group, stride=2) 104 | self.avgpool = nn.AvgPool2d(7, stride=1) 105 | self.fc = nn.Linear(512 * block.expansion, num_classes) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def _make_layer(self, block, planes, blocks, num_group, stride=1): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | nn.Conv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample, num_group=num_group)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes, num_group=num_group)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.relu(x) 136 | x = self.maxpool(x) 137 | 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | x = self.layer4(x) 142 | 143 | x = self.avgpool(x) 144 | x = x.view(x.size(0), -1) 145 | x = self.fc(x) 146 | 147 | return x 148 | 149 | 150 | def resnext18( **kwargs): 151 | """Constructs a ResNeXt-18 model. 152 | """ 153 | model = ResNeXt(BasicBlock, [2, 2, 2, 2], **kwargs) 154 | return model 155 | 156 | 157 | def resnext34(**kwargs): 158 | """Constructs a ResNeXt-34 model. 159 | """ 160 | model = ResNeXt(BasicBlock, [3, 4, 6, 3], **kwargs) 161 | return model 162 | 163 | 164 | def resnext50(**kwargs): 165 | """Constructs a ResNeXt-50 model. 166 | """ 167 | model = ResNeXt(Bottleneck, [3, 4, 6, 3], **kwargs) 168 | return model 169 | 170 | 171 | def resnext101(**kwargs): 172 | """Constructs a ResNeXt-101 model. 173 | """ 174 | model = ResNeXt(Bottleneck, [3, 4, 23, 3], **kwargs) 175 | return model 176 | 177 | 178 | def resnext152(**kwargs): 179 | """Constructs a ResNeXt-152 model. 180 | """ 181 | model = ResNeXt(Bottleneck, [3, 8, 36, 3], **kwargs) 182 | return model -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | from torch.autograd import Variable 8 | import time 9 | import os 10 | from resnext import * 11 | import argparse 12 | from read_ImageNetData import ImageNetData 13 | 14 | def train_model(args, model, criterion, optimizer, scheduler, num_epochs, dataset_sizes): 15 | since = time.time() 16 | resumed = False 17 | 18 | best_model_wts = model.state_dict() 19 | 20 | for epoch in range(args.start_epoch+1,num_epochs): 21 | 22 | # Each epoch has a training and validation phase 23 | for phase in ['train','val']: 24 | if phase == 'train': 25 | if args.start_epoch > 0 and (not resumed): 26 | scheduler.step(args.start_epoch+1) 27 | resumed = True 28 | else: 29 | scheduler.step(epoch) 30 | model.train(True) # Set model to training mode 31 | else: 32 | model.train(False) # Set model to evaluate mode 33 | 34 | running_loss = 0.0 35 | running_corrects = 0 36 | 37 | tic_batch = time.time() 38 | # Iterate over data. 39 | for i, (inputs, labels) in enumerate(dataloders[phase]): 40 | # wrap them in Variable 41 | if use_gpu: 42 | inputs = Variable(inputs.cuda()) 43 | labels = Variable(labels.cuda()) 44 | else: 45 | inputs, labels = Variable(inputs), Variable(labels) 46 | 47 | # zero the parameter gradients 48 | optimizer.zero_grad() 49 | 50 | # forward 51 | outputs = model(inputs) 52 | _, preds = torch.max(outputs.data, 1) 53 | loss = criterion(outputs, labels) 54 | 55 | # backward + optimize only if in training phase 56 | if phase == 'train': 57 | loss.backward() 58 | optimizer.step() 59 | 60 | # statistics 61 | running_loss += loss.data[0] 62 | running_corrects += torch.sum(preds == labels.data) 63 | 64 | batch_loss = running_loss / ((i+1)*args.batch_size) 65 | batch_acc = running_corrects / ((i+1)*args.batch_size) 66 | 67 | if phase == 'train' and i%args.print_freq == 0: 68 | print('[Epoch {}/{}]-[batch:{}/{}] lr:{:.4f} {} Loss: {:.6f} Acc: {:.4f} Time: {:.4f}batch/sec'.format( 69 | epoch, num_epochs - 1, i, round(dataset_sizes[phase]/args.batch_size)-1, scheduler.get_lr()[0], phase, batch_loss, batch_acc, \ 70 | args.print_freq/(time.time()-tic_batch))) 71 | tic_batch = time.time() 72 | 73 | epoch_loss = running_loss / dataset_sizes[phase] 74 | epoch_acc = running_corrects / dataset_sizes[phase] 75 | 76 | print('{} Loss: {:.4f} Acc: {:.4f}'.format( 77 | phase, epoch_loss, epoch_acc)) 78 | 79 | if (epoch+1) % args.save_epoch_freq == 0: 80 | if not os.path.exists(args.save_path): 81 | os.makedirs(args.save_path) 82 | torch.save(model, os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth.tar")) 83 | 84 | time_elapsed = time.time() - since 85 | print('Training complete in {:.0f}m {:.0f}s'.format( 86 | time_elapsed // 60, time_elapsed % 60)) 87 | 88 | # load best model weights 89 | model.load_state_dict(best_model_wts) 90 | return model 91 | 92 | if __name__ == '__main__': 93 | 94 | parser = argparse.ArgumentParser(description="PyTorch implementation of SENet") 95 | parser.add_argument('--data-dir', type=str, default="/ImageNet") 96 | parser.add_argument('--batch-size', type=int, default=16) 97 | parser.add_argument('--num-class', type=int, default=1000) 98 | parser.add_argument('--num-epochs', type=int, default=100) 99 | parser.add_argument('--lr', type=float, default=0.1) 100 | parser.add_argument('--num-workers', type=int, default=0) 101 | parser.add_argument('--gpus', type=str, default=0) 102 | parser.add_argument('--print-freq', type=int, default=10) 103 | parser.add_argument('--save-epoch-freq', type=int, default=1) 104 | parser.add_argument('--save-path', type=str, default="output") 105 | parser.add_argument('--resume', type=str, default="", help="For training from one checkpoint") 106 | parser.add_argument('--start-epoch', type=int, default=0, help="Corresponding to the epoch of resume ") 107 | args = parser.parse_args() 108 | 109 | # read data 110 | dataloders, dataset_sizes = ImageNetData(args) 111 | 112 | # use gpu or not 113 | use_gpu = torch.cuda.is_available() 114 | print("use_gpu:{}".format(use_gpu)) 115 | 116 | # get model 117 | model = resnext50(num_classes = args.num_class) 118 | 119 | if args.resume: 120 | if os.path.isfile(args.resume): 121 | print(("=> loading checkpoint '{}'".format(args.resume))) 122 | checkpoint = torch.load(args.resume) 123 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.state_dict().items())} 124 | model.load_state_dict(base_dict) 125 | else: 126 | print(("=> no checkpoint found at '{}'".format(args.resume))) 127 | 128 | if use_gpu: 129 | model = model.cuda() 130 | model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')]) 131 | 132 | # define loss function 133 | criterion = nn.CrossEntropyLoss() 134 | 135 | # Observe that all parameters are being optimized 136 | optimizer_ft = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001) 137 | 138 | # Decay LR by a factor of 0.1 every 7 epochs 139 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1) 140 | 141 | model = train_model(args=args, 142 | model=model, 143 | criterion=criterion, 144 | optimizer=optimizer_ft, 145 | scheduler=exp_lr_scheduler, 146 | num_epochs=args.num_epochs, 147 | dataset_sizes=dataset_sizes) 148 | --------------------------------------------------------------------------------