├── .gitattributes ├── fsam.png ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── densenet.cpython-310.pyc │ ├── pyramidnet.cpython-310.pyc │ ├── resnet.cpython-310.pyc │ ├── vgg.cpython-310.pyc │ ├── vit.cpython-310.pyc │ └── wide_resnet.cpython-310.pyc ├── pyramidnet.py ├── resnet.py ├── vgg.py └── wide_resnet.py ├── readme.md ├── requirement.txt ├── results ├── FriendlySAM │ └── CIFAR100 │ │ └── resnet18 │ │ ├── FriendlySAM_cutout_0.2_1_0.6_200_resnet18_bz128_wd0.001_CIFAR100_cosine_seed1 │ │ └── log │ │ └── FriendlySAM_cutout_0.2_1_0.95_200_resnet18_bz128_wd0.001_CIFAR100_cosine_seed1 │ │ └── log └── SAM │ └── CIFAR100 │ └── resnet18 │ └── SAM_cutout_0.2_0_0_200_resnet18_bz128_wd0.001_CIFAR100_cosine_seed1 │ └── log ├── run.sh ├── train.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /fsam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/fsam.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .vgg import * 3 | from .wide_resnet import * -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/densenet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/models/__pycache__/densenet.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/pyramidnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/models/__pycache__/pyramidnet.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/models/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/models/__pycache__/vgg.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/vit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/models/__pycache__/vit.cpython-310.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/wide_resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/F-SAM/2e2f9899ab473880430640fa5def4ad8f5aa2401/models/__pycache__/wide_resnet.cpython-310.pyc -------------------------------------------------------------------------------- /models/pyramidnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | #from math import round 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | outchannel_ratio = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.bn1 = nn.BatchNorm2d(inplanes) 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn3 = nn.BatchNorm2d(planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | 30 | out = self.bn1(x) 31 | out = self.conv1(out) 32 | out = self.bn2(out) 33 | out = self.relu(out) 34 | out = self.conv2(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | shortcut = self.downsample(x) 39 | featuremap_size = shortcut.size()[2:4] 40 | else: 41 | shortcut = x 42 | featuremap_size = out.size()[2:4] 43 | 44 | batch_size = out.size()[0] 45 | residual_channel = out.size()[1] 46 | shortcut_channel = shortcut.size()[1] 47 | 48 | if residual_channel != shortcut_channel: 49 | padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 50 | out += torch.cat((shortcut, padding), 1) 51 | else: 52 | out += shortcut 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | outchannel_ratio = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.bn1 = nn.BatchNorm2d(inplanes) 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, (planes*1), kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d((planes*1)) 68 | self.conv3 = nn.Conv2d((planes*1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 69 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | 76 | out = self.bn1(x) 77 | out = self.conv1(out) 78 | 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | out = self.conv2(out) 82 | 83 | out = self.bn3(out) 84 | out = self.relu(out) 85 | out = self.conv3(out) 86 | 87 | out = self.bn4(out) 88 | 89 | if self.downsample is not None: 90 | shortcut = self.downsample(x) 91 | featuremap_size = shortcut.size()[2:4] 92 | else: 93 | shortcut = x 94 | featuremap_size = out.size()[2:4] 95 | 96 | batch_size = out.size()[0] 97 | residual_channel = out.size()[1] 98 | shortcut_channel = shortcut.size()[1] 99 | 100 | if residual_channel != 
shortcut_channel: 101 | padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 102 | out += torch.cat((shortcut, padding), 1) 103 | else: 104 | out += shortcut 105 | 106 | return out 107 | 108 | 109 | class PyramidNet(nn.Module): 110 | def __init__(self, depth, alpha, num_classes, bottleneck=True, dataset='cifar'): 111 | super(PyramidNet, self).__init__() 112 | self.dataset = dataset 113 | if self.dataset.startswith('cifar'): 114 | self.inplanes = 16 115 | if bottleneck == True: 116 | n = int((depth - 2) / 9) 117 | block = Bottleneck 118 | else: 119 | n = int((depth - 2) / 6) 120 | block = BasicBlock 121 | 122 | self.addrate = alpha / (3*n*1.0) 123 | 124 | self.input_featuremap_dim = self.inplanes 125 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 126 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 127 | 128 | self.featuremap_dim = self.input_featuremap_dim 129 | self.layer1 = self.pyramidal_make_layer(block, n) 130 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 131 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 132 | 133 | self.final_featuremap_dim = self.input_featuremap_dim 134 | self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) 135 | self.relu_final = nn.ReLU(inplace=True) 136 | self.avgpool = nn.AvgPool2d(8) 137 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 138 | 139 | elif dataset == 'imagenet': 140 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 141 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 142 | 143 | if layers.get(depth) is None: 144 | if bottleneck == True: 145 | blocks[depth] = Bottleneck 146 | temp_cfg = int((depth-2)/12) 147 | else: 148 | blocks[depth] = BasicBlock 149 | temp_cfg = int((depth-2)/8) 150 | 151 | layers[depth]= [temp_cfg, temp_cfg, temp_cfg, temp_cfg] 152 | print('=> the layer configuration for each stage is set to', layers[depth]) 153 | 154 | self.inplanes = 64 155 | self.addrate = alpha / (sum(layers[depth])*1.0) 156 | 157 | self.input_featuremap_dim = self.inplanes 158 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 159 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 160 | self.relu = nn.ReLU(inplace=True) 161 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 162 | 163 | self.featuremap_dim = self.input_featuremap_dim 164 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) 165 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) 166 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) 167 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) 168 | 169 | self.final_featuremap_dim = self.input_featuremap_dim 170 | self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) 171 | self.relu_final = nn.ReLU(inplace=True) 172 | self.avgpool = nn.AvgPool2d(7) 173 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 174 | 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 178 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | 183 | def pyramidal_make_layer(self, block, block_depth, stride=1): 184 | downsample = None 185 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 186 | downsample = nn.AvgPool2d((2,2), stride = (2, 2), ceil_mode=True) 187 | 188 | layers = [] 189 | self.featuremap_dim = self.featuremap_dim + self.addrate 190 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample)) 191 | for i in range(1, block_depth): 192 | temp_featuremap_dim = self.featuremap_dim + self.addrate 193 | layers.append(block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1)) 194 | self.featuremap_dim = temp_featuremap_dim 195 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | if self.dataset.startswith('cifar'): 201 | x = self.conv1(x) 202 | x = self.bn1(x) 203 | 204 | x = self.layer1(x) 205 | x = self.layer2(x) 206 | x = self.layer3(x) 207 | 208 | x = self.bn_final(x) 209 | x = self.relu_final(x) 210 | x = self.avgpool(x) 211 | x = x.view(x.size(0), -1) 212 | x = self.fc(x) 213 | 214 | elif self.dataset == 'imagenet': 215 | x = self.conv1(x) 216 | x = self.bn1(x) 217 | x = self.relu(x) 218 | x = self.maxpool(x) 219 | 220 | x = self.layer1(x) 221 | x = self.layer2(x) 222 | x = self.layer3(x) 223 | x = self.layer4(x) 224 | 225 | x = self.bn_final(x) 226 | x = self.relu_final(x) 227 | x = self.avgpool(x) 228 | x = x.view(x.size(0), -1) 229 | x = self.fc(x) 230 | 231 | return x 232 | 233 | if __name__ == "__main__": 234 | model = PyramidNet(110, 270, 10) 235 | print(model) 236 | for param in model.parameters(): 237 | print(param.shape) 238 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """resnet in pytorch 2 | 3 | 4 | 5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 
6 | 7 | Deep Residual Learning for Image Recognition 8 | https://arxiv.org/abs/1512.03385v1 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | class BasicBlock(nn.Module): 15 | """Basic Block for resnet 18 and resnet 34 16 | 17 | """ 18 | 19 | #BasicBlock and BottleNeck block 20 | #have different output size 21 | #we use class attribute expansion 22 | #to distinct 23 | expansion = 1 24 | 25 | def __init__(self, in_channels, out_channels, stride=1): 26 | super().__init__() 27 | 28 | #residual function 29 | self.residual_function = nn.Sequential( 30 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 31 | nn.BatchNorm2d(out_channels), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 35 | ) 36 | 37 | #shortcut 38 | self.shortcut = nn.Sequential() 39 | 40 | #the shortcut output dimension is not the same with residual function 41 | #use 1*1 convolution to match the dimension 42 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 43 | self.shortcut = nn.Sequential( 44 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 45 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 46 | ) 47 | 48 | def forward(self, x): 49 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 50 | 51 | class BottleNeck(nn.Module): 52 | """Residual block for resnet over 50 layers 53 | 54 | """ 55 | expansion = 4 56 | def __init__(self, in_channels, out_channels, stride=1): 57 | super().__init__() 58 | self.residual_function = nn.Sequential( 59 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 60 | nn.BatchNorm2d(out_channels), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 63 | nn.BatchNorm2d(out_channels), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 66 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 67 | ) 68 | 69 | self.shortcut = nn.Sequential() 70 | 71 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 72 | self.shortcut = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 74 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 75 | ) 76 | 77 | def forward(self, x): 78 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 79 | 80 | class ResNet(nn.Module): 81 | 82 | def __init__(self, block, num_block, num_classes=100): 83 | super().__init__() 84 | 85 | self.in_channels = 64 86 | 87 | self.conv1 = nn.Sequential( 88 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 89 | nn.BatchNorm2d(64), 90 | nn.ReLU(inplace=True)) 91 | #we use a different inputsize than the original paper 92 | #so conv2_x's stride is 1 93 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 94 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 95 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 96 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 97 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 98 | self.fc = nn.Linear(512 * block.expansion, num_classes) 99 | 100 | def _make_layer(self, block, out_channels, num_blocks, stride): 101 | """make resnet layers(by layer i didnt mean this 'layer' was the 102 | same as 
a neuron netowork layer, ex. conv layer), one layer may 103 | contain more than one residual block 104 | 105 | Args: 106 | block: block type, basic block or bottle neck block 107 | out_channels: output depth channel number of this layer 108 | num_blocks: how many blocks per layer 109 | stride: the stride of the first block of this layer 110 | 111 | Return: 112 | return a resnet layer 113 | """ 114 | 115 | # we have num_block blocks per layer, the first block 116 | # could be 1 or 2, other blocks would always be 1 117 | strides = [stride] + [1] * (num_blocks - 1) 118 | layers = [] 119 | for stride in strides: 120 | layers.append(block(self.in_channels, out_channels, stride)) 121 | self.in_channels = out_channels * block.expansion 122 | 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, x): 126 | output = self.conv1(x) 127 | output = self.conv2_x(output) 128 | output = self.conv3_x(output) 129 | output = self.conv4_x(output) 130 | output = self.conv5_x(output) 131 | output = self.avg_pool(output) 132 | output = output.view(output.size(0), -1) 133 | output = self.fc(output) 134 | 135 | return output 136 | 137 | class resnet18: 138 | base = ResNet 139 | args = list() 140 | kwargs = {'block': BasicBlock, 'num_block': [2, 2, 2, 2]} 141 | 142 | 143 | class resnet50: 144 | base = ResNet 145 | args = list() 146 | kwargs = {'block': BasicBlock, 'num_block': [3, 4, 6, 3]} 147 | 148 | # def resnet18(): 149 | # """ return a ResNet 18 object 150 | # """ 151 | # kwargs = {} 152 | # return ResNet(BasicBlock, [2, 2, 2, 2]) 153 | 154 | def resnet34(): 155 | """ return a ResNet 34 object 156 | """ 157 | return ResNet(BasicBlock, [3, 4, 6, 3]) 158 | 159 | # def resnet50(): 160 | # """ return a ResNet 50 object 161 | # """ 162 | # return ResNet(BottleNeck, [3, 4, 6, 3]) 163 | 164 | def resnet101(): 165 | """ return a ResNet 101 object 166 | """ 167 | return ResNet(BottleNeck, [3, 4, 23, 3]) 168 | 169 | def resnet152(): 170 | """ return a ResNet 152 object 171 | """ 172 | return ResNet(BottleNeck, [3, 8, 36, 3]) 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | VGG model definition 3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 4 | """ 5 | 6 | import math 7 | import torch.nn as nn 8 | import torchvision.transforms as transforms 9 | 10 | __all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN'] 11 | 12 | 13 | def make_layers(cfg, batch_norm=False): 14 | layers = list() 15 | in_channels = 3 16 | for v in cfg: 17 | if v == 'M': 18 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 19 | else: 20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 21 | if batch_norm: 22 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 23 | else: 24 | layers += [conv2d, nn.ReLU(inplace=True)] 25 | in_channels = v 26 | return nn.Sequential(*layers) 27 | 28 | 29 | cfg = { 30 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 31 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 32 | 512, 512, 512, 512, 'M'], 33 | } 34 | 35 | 36 | class VGG(nn.Module): 37 | def __init__(self, num_classes=10, depth=16, batch_norm=False): 38 | super(VGG, self).__init__() 39 | self.features = make_layers(cfg[depth], batch_norm) 40 | self.classifier = nn.Sequential( 41 | nn.Dropout(), 42 | nn.Linear(512, 512), 43 | nn.ReLU(True), 44 | 
nn.Dropout(), 45 | nn.Linear(512, 512), 46 | nn.ReLU(True), 47 | nn.Linear(512, num_classes), 48 | ) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. / n)) 54 | m.bias.data.zero_() 55 | 56 | def forward(self, x): 57 | x = self.features(x) 58 | x = x.view(x.size(0), -1) 59 | x = self.classifier(x) 60 | return x 61 | 62 | 63 | class Base: 64 | base = VGG 65 | args = list() 66 | kwargs = dict() 67 | transform_train = transforms.Compose([ 68 | transforms.RandomHorizontalFlip(), 69 | transforms.RandomCrop(32, padding=4), 70 | transforms.ToTensor(), 71 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 72 | ]) 73 | 74 | transform_test = transforms.Compose([ 75 | transforms.ToTensor(), 76 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 77 | ]) 78 | 79 | 80 | class VGG16(Base): 81 | pass 82 | 83 | 84 | class VGG16BN(Base): 85 | kwargs = {'batch_norm': True} 86 | 87 | 88 | class VGG19(Base): 89 | kwargs = {'depth': 19} 90 | 91 | 92 | class VGG19BN(Base): 93 | kwargs = {'depth': 19, 'batch_norm': True} -------------------------------------------------------------------------------- /models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | WideResNet model definition 3 | ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py 4 | """ 5 | 6 | import torchvision.transforms as transforms 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | import math 11 | 12 | __all__ = ['WideResNet28x10', 'WideResNet16x8'] 13 | 14 | from collections import OrderedDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | 21 | class BasicUnit(nn.Module): 22 | def __init__(self, channels: int, dropout: float): 23 | super(BasicUnit, self).__init__() 24 | self.block = nn.Sequential(OrderedDict([ 25 | ("0_normalization", nn.BatchNorm2d(channels)), 26 | ("1_activation", nn.ReLU(inplace=True)), 27 | ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)), 28 | ("3_normalization", nn.BatchNorm2d(channels)), 29 | ("4_activation", nn.ReLU(inplace=True)), 30 | ("5_dropout", nn.Dropout(dropout, inplace=True)), 31 | ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)), 32 | ])) 33 | 34 | def forward(self, x): 35 | return x + self.block(x) 36 | 37 | 38 | class DownsampleUnit(nn.Module): 39 | def __init__(self, in_channels: int, out_channels: int, stride: int, dropout: float): 40 | super(DownsampleUnit, self).__init__() 41 | self.norm_act = nn.Sequential(OrderedDict([ 42 | ("0_normalization", nn.BatchNorm2d(in_channels)), 43 | ("1_activation", nn.ReLU(inplace=True)), 44 | ])) 45 | self.block = nn.Sequential(OrderedDict([ 46 | ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)), 47 | ("1_normalization", nn.BatchNorm2d(out_channels)), 48 | ("2_activation", nn.ReLU(inplace=True)), 49 | ("3_dropout", nn.Dropout(dropout, inplace=True)), 50 | ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)), 51 | ])) 52 | self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False) 53 | 54 | def forward(self, x): 55 | x = self.norm_act(x) 56 | return self.block(x) + self.downsample(x) 57 | 58 | 59 | 
class Block(nn.Module): 60 | def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float): 61 | super(Block, self).__init__() 62 | self.block = nn.Sequential( 63 | DownsampleUnit(in_channels, out_channels, stride, dropout), 64 | *(BasicUnit(out_channels, dropout) for _ in range(depth)) 65 | ) 66 | 67 | def forward(self, x): 68 | return self.block(x) 69 | 70 | 71 | class WideResNet(nn.Module): 72 | def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, num_classes: int): 73 | super(WideResNet, self).__init__() 74 | 75 | self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor] 76 | self.block_depth = (depth - 4) // (3 * 2) 77 | 78 | self.f = nn.Sequential(OrderedDict([ 79 | ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)), 80 | ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)), 81 | ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)), 82 | ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)), 83 | ("4_normalization", nn.BatchNorm2d(self.filters[3])), 84 | ("5_activation", nn.ReLU(inplace=True)), 85 | ("6_pooling", nn.AvgPool2d(kernel_size=8)), 86 | ("7_flattening", nn.Flatten()), 87 | ("8_classification", nn.Linear(in_features=self.filters[3], out_features=num_classes)), 88 | ])) 89 | 90 | self._initialize() 91 | 92 | def _initialize(self): 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu") 96 | if m.bias is not None: 97 | m.bias.data.zero_() 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | elif isinstance(m, nn.Linear): 102 | m.weight.data.zero_() 103 | m.bias.data.zero_() 104 | 105 | def forward(self, x): 106 | return self.f(x) 107 | 108 | class WideResNet28x10: 109 | base = WideResNet 110 | args = list() 111 | kwargs = {'depth': 28, 'width_factor': 10, 'dropout': 0, 'in_channels': 3} 112 | transform_train = transforms.Compose([ 113 | transforms.RandomCrop(32, padding=4), 114 | transforms.RandomHorizontalFlip(), 115 | transforms.ToTensor(), 116 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 117 | ]) 118 | transform_test = transforms.Compose([ 119 | transforms.ToTensor(), 120 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 121 | ]) 122 | 123 | class WideResNet16x8: 124 | base = WideResNet 125 | args = list() 126 | kwargs = {'depth': 16, 'width_factor': 8, 'dropout': 0, 'in_channels': 3} 127 | transform_train = transforms.Compose([ 128 | transforms.RandomCrop(32, padding=4), 129 | transforms.RandomHorizontalFlip(), 130 | transforms.ToTensor(), 131 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 132 | ]) 133 | transform_test = transforms.Compose([ 134 | transforms.ToTensor(), 135 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 136 | ]) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Friendly Sharpness-Aware Minimization 2 | 3 | The code is the official implementation of our CVPR 2024 paper 4 | [Friendly Sharpness-Aware Minimization](https://arxiv.org/html/2403.12350v1). 
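For intuition, the perturbation rule described in the Introduction below can be summarized in a short PyTorch sketch. This is only a minimal illustration, not the `FriendlySAM` optimizer that `train.py` actually instantiates (see the logs under `results/`): the function name `fsam_step`, the `ema_grads` buffer, and the argument names are hypothetical, while the defaults for `rho`, `sigma`, and the EMA factor `lam` (lambda) mirror the values printed in those logs.

```
# Hypothetical sketch: one F-SAM-style update on top of a base SGD optimizer.
# ema_grads is a dict keyed by parameter, holding an exponential moving-average
# estimate m_t of the full gradient.
import torch

def fsam_step(model, loss_fn, inputs, targets, base_optimizer, ema_grads,
              rho=0.2, sigma=1.0, lam=0.95):
    # 1) First forward/backward pass: current batch gradient g_t.
    base_optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()

    params = [p for p in model.parameters() if p.grad is not None]
    perturbs = []
    with torch.no_grad():
        # 2) Update m_t = lam * m_{t-1} + (1 - lam) * g_t and form the
        #    "friendly" direction d_t = g_t - sigma * m_t (stochastic gradient noise).
        for p in params:
            if p not in ema_grads:
                ema_grads[p] = torch.zeros_like(p.grad)
            ema_grads[p].mul_(lam).add_(p.grad, alpha=1.0 - lam)
            perturbs.append(p.grad - sigma * ema_grads[p])

        # 3) Normalize d_t to the perturbation radius rho and perturb the weights.
        scale = rho / (torch.norm(torch.stack([d.norm(p=2) for d in perturbs])) + 1e-12)
        for p, d in zip(params, perturbs):
            p.add_(d, alpha=scale.item())

    # 4) Second forward/backward pass at the perturbed point, undo the
    #    perturbation, then let the base optimizer take its step.
    base_optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()
    with torch.no_grad():
        for p, d in zip(params, perturbs):
            p.sub_(d, alpha=scale.item())
    base_optimizer.step()
```

With `sigma=1.0` the ascent direction is exactly the batch gradient minus its EMA estimate, i.e. the batch-wise stochastic gradient noise, while `sigma=0` falls back to the plain batch gradient used by vanilla SAM (cf. the `SAM_cutout_0.2_0_0_...` run under `results/SAM`).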
5 | 6 | ## Introduction 7 | In this work, we reveal that the full gradient component in SAM’s adversarial weight perturbation does not contribute to generalization and, in fact, has undesirable effects. We then propose an efficient variant to mitigate these effects and solely utilize batch-wise stochastic gradient noise for weight perturbation. It further enhances the generalization performance of SAM and provides a fresh understanding of SAM's practical success. 8 | 9 | ![Illustration of F-SAM](fsam.png) 10 | 11 | ## Dependencies 12 | 13 | Install required dependencies: 14 | 15 | ``` 16 | pip install -r requirement.txt 17 | ``` 18 | 19 | ## How to run 20 | 21 | We show sample usage in `run.sh`: 22 | 23 | ``` 24 | bash run.sh 25 | ``` 26 | 27 | 28 | ## Citation 29 | If you find this work helpful, please cite: 30 | ``` 31 | @inproceedings{li2024friendly, 32 | title={Friendly Sharpness-Aware Minimization}, 33 | author={Li, Tao and Zhou, Pan and He, Zhengbao and Cheng, Xinwen and Huang, Xiaolin}, 34 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 35 | year={2024} 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1+cu118 2 | torchaudio==2.0.2+cu118 3 | torchvision==0.15.2+cu118 4 | timm==0.9.8 -------------------------------------------------------------------------------- /results/FriendlySAM/CIFAR100/resnet18/FriendlySAM_cutout_0.2_1_0.95_200_resnet18_bz128_wd0.001_CIFAR100_cosine_seed1/log: -------------------------------------------------------------------------------- 1 | save dir: results_final_final/FriendlySAM/CIFAR100/resnet18/FriendlySAM_cutout_0.2_1_0.95_200_resnet18_bz128_wd0.001_CIFAR100_cosine_seed1/checkpoints 2 | log dir: results_final_final/FriendlySAM/CIFAR100/resnet18/FriendlySAM_cutout_0.2_1_0.95_200_resnet18_bz128_wd0.001_CIFAR100_cosine_seed1 3 | Model: resnet18 4 | lambda: 0.95 5 | cutout: True 6 | cutout! 7 | cifar100 dataset!
8 | Files already downloaded and verified 9 | 391 10 | 50000 11 | optimizer: FriendlySAM 12 | FriendlySAM sigma: 1.0 lambda: 0.95 13 | FriendlySAM ( 14 | Parameter Group 0 15 | adaptive: 0 16 | dampening: 0 17 | differentiable: False 18 | foreach: None 19 | lr: 0.05 20 | maximize: False 21 | momentum: 0.9 22 | nesterov: False 23 | rho: 0.2 24 | weight_decay: 0.001 25 | ) 26 | Start training: 0 -> 200 27 | current lr 5.00000e-02 28 | Epoch: [0][0/391] Time 1.080 (1.080) Data 0.203 (0.203) Loss 4.7223 (4.7223) Prec@1 0.000 (0.000) 29 | Epoch: [0][200/391] Time 0.031 (0.041) Data 0.000 (0.001) Loss 4.0549 (4.2394) Prec@1 6.250 (5.204) 30 | Epoch: [0][390/391] Time 0.440 (0.039) Data 0.000 (0.001) Loss 4.0071 (4.0738) Prec@1 3.750 (7.120) 31 | Total time for epoch [0] : 15.123 32 | Test: [0/79] Time 0.107 (0.107) Loss 3.7475 (3.7475) Prec@1 12.500 (12.500) 33 | * Prec@1 11.150 34 | current lr 4.99969e-02 35 | Epoch: [1][0/391] Time 0.174 (0.174) Data 0.114 (0.114) Loss 3.8406 (3.8406) Prec@1 8.594 (8.594) 36 | Epoch: [1][200/391] Time 0.033 (0.037) Data 0.000 (0.001) Loss 3.5742 (3.7431) Prec@1 14.844 (11.221) 37 | Epoch: [1][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 3.2747 (3.6810) Prec@1 20.000 (12.394) 38 | Total time for epoch [1] : 14.053 39 | Test: [0/79] Time 0.118 (0.118) Loss 3.4410 (3.4410) Prec@1 19.531 (19.531) 40 | * Prec@1 17.140 41 | current lr 4.99877e-02 42 | Epoch: [2][0/391] Time 0.185 (0.185) Data 0.149 (0.149) Loss 3.5966 (3.5966) Prec@1 11.719 (11.719) 43 | Epoch: [2][200/391] Time 0.031 (0.037) Data 0.000 (0.001) Loss 3.4463 (3.4444) Prec@1 20.312 (16.628) 44 | Epoch: [2][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 3.2757 (3.3781) Prec@1 17.500 (17.978) 45 | Total time for epoch [2] : 14.048 46 | Test: [0/79] Time 0.105 (0.105) Loss 3.1228 (3.1228) Prec@1 22.656 (22.656) 47 | * Prec@1 21.770 48 | current lr 4.99722e-02 49 | Epoch: [3][0/391] Time 0.177 (0.177) Data 0.128 (0.128) Loss 3.0748 (3.0748) Prec@1 22.656 (22.656) 50 | Epoch: [3][200/391] Time 0.039 (0.036) Data 0.000 (0.001) Loss 3.3805 (3.1712) Prec@1 16.406 (21.980) 51 | Epoch: [3][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 3.1096 (3.1234) Prec@1 22.500 (22.682) 52 | Total time for epoch [3] : 13.720 53 | Test: [0/79] Time 0.117 (0.117) Loss 2.9183 (2.9183) Prec@1 28.125 (28.125) 54 | * Prec@1 25.720 55 | current lr 4.99507e-02 56 | Epoch: [4][0/391] Time 0.191 (0.191) Data 0.141 (0.141) Loss 2.9887 (2.9887) Prec@1 23.438 (23.438) 57 | Epoch: [4][200/391] Time 0.032 (0.037) Data 0.000 (0.001) Loss 2.9297 (2.9533) Prec@1 24.219 (26.042) 58 | Epoch: [4][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 3.0481 (2.8990) Prec@1 23.750 (26.902) 59 | Total time for epoch [4] : 13.926 60 | Test: [0/79] Time 0.118 (0.118) Loss 2.6665 (2.6665) Prec@1 35.156 (35.156) 61 | * Prec@1 30.490 62 | current lr 4.99229e-02 63 | Epoch: [5][0/391] Time 0.186 (0.186) Data 0.125 (0.125) Loss 2.6242 (2.6242) Prec@1 32.812 (32.812) 64 | Epoch: [5][200/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 2.7153 (2.6982) Prec@1 31.250 (30.978) 65 | Epoch: [5][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 2.5287 (2.6549) Prec@1 30.000 (32.082) 66 | Total time for epoch [5] : 13.910 67 | Test: [0/79] Time 0.110 (0.110) Loss 2.3932 (2.3932) Prec@1 34.375 (34.375) 68 | * Prec@1 34.610 69 | current lr 4.98890e-02 70 | Epoch: [6][0/391] Time 0.165 (0.165) Data 0.124 (0.124) Loss 2.6678 (2.6678) Prec@1 32.812 (32.812) 71 | Epoch: [6][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 2.3475 (2.4616) Prec@1 36.719 
(35.856) 72 | Epoch: [6][390/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 2.4082 (2.4285) Prec@1 33.750 (36.632) 73 | Total time for epoch [6] : 13.897 74 | Test: [0/79] Time 0.097 (0.097) Loss 2.2021 (2.2021) Prec@1 39.062 (39.062) 75 | * Prec@1 40.140 76 | current lr 4.98490e-02 77 | Epoch: [7][0/391] Time 0.178 (0.178) Data 0.118 (0.118) Loss 2.0725 (2.0725) Prec@1 46.875 (46.875) 78 | Epoch: [7][200/391] Time 0.036 (0.037) Data 0.000 (0.001) Loss 2.4552 (2.2442) Prec@1 34.375 (40.501) 79 | Epoch: [7][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 2.0271 (2.2284) Prec@1 50.000 (40.820) 80 | Total time for epoch [7] : 14.030 81 | Test: [0/79] Time 0.100 (0.100) Loss 1.9642 (1.9642) Prec@1 47.656 (47.656) 82 | * Prec@1 44.320 83 | current lr 4.98029e-02 84 | Epoch: [8][0/391] Time 0.159 (0.159) Data 0.116 (0.116) Loss 2.2025 (2.2025) Prec@1 46.875 (46.875) 85 | Epoch: [8][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 2.1863 (2.0955) Prec@1 40.625 (44.349) 86 | Epoch: [8][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 2.0547 (2.0731) Prec@1 50.000 (44.804) 87 | Total time for epoch [8] : 13.998 88 | Test: [0/79] Time 0.145 (0.145) Loss 1.9854 (1.9854) Prec@1 46.094 (46.094) 89 | * Prec@1 46.440 90 | current lr 4.97506e-02 91 | Epoch: [9][0/391] Time 0.163 (0.163) Data 0.121 (0.121) Loss 1.8238 (1.8238) Prec@1 53.125 (53.125) 92 | Epoch: [9][200/391] Time 0.047 (0.036) Data 0.000 (0.001) Loss 1.9943 (1.9628) Prec@1 45.312 (47.361) 93 | Epoch: [9][390/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 2.1354 (1.9484) Prec@1 37.500 (47.674) 94 | Total time for epoch [9] : 14.003 95 | Test: [0/79] Time 0.115 (0.115) Loss 1.8873 (1.8873) Prec@1 50.000 (50.000) 96 | * Prec@1 46.840 97 | current lr 4.96922e-02 98 | Epoch: [10][0/391] Time 0.176 (0.176) Data 0.133 (0.133) Loss 1.8271 (1.8271) Prec@1 49.219 (49.219) 99 | Epoch: [10][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.9511 (1.8538) Prec@1 47.656 (49.957) 100 | Epoch: [10][390/391] Time 0.034 (0.036) Data 0.001 (0.001) Loss 1.6726 (1.8487) Prec@1 57.500 (50.168) 101 | Total time for epoch [10] : 13.887 102 | Test: [0/79] Time 0.119 (0.119) Loss 1.5925 (1.5925) Prec@1 58.594 (58.594) 103 | * Prec@1 50.860 104 | current lr 4.96277e-02 105 | Epoch: [11][0/391] Time 0.227 (0.227) Data 0.170 (0.170) Loss 1.9841 (1.9841) Prec@1 49.219 (49.219) 106 | Epoch: [11][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 1.6576 (1.7566) Prec@1 50.781 (52.301) 107 | Epoch: [11][390/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 1.7557 (1.7669) Prec@1 52.500 (51.988) 108 | Total time for epoch [11] : 13.918 109 | Test: [0/79] Time 0.115 (0.115) Loss 1.8533 (1.8533) Prec@1 53.125 (53.125) 110 | * Prec@1 50.710 111 | current lr 4.95572e-02 112 | Epoch: [12][0/391] Time 0.159 (0.159) Data 0.115 (0.115) Loss 1.8231 (1.8231) Prec@1 52.344 (52.344) 113 | Epoch: [12][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.7106 (1.7073) Prec@1 56.250 (53.821) 114 | Epoch: [12][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.9403 (1.7074) Prec@1 55.000 (53.538) 115 | Total time for epoch [12] : 14.064 116 | Test: [0/79] Time 0.169 (0.169) Loss 1.6839 (1.6839) Prec@1 52.344 (52.344) 117 | * Prec@1 53.240 118 | current lr 4.94806e-02 119 | Epoch: [13][0/391] Time 0.172 (0.172) Data 0.131 (0.131) Loss 1.6159 (1.6159) Prec@1 53.906 (53.906) 120 | Epoch: [13][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.5824 (1.6489) Prec@1 60.156 (55.057) 121 | Epoch: [13][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.6527 (1.6426) Prec@1 
52.500 (55.286) 122 | Total time for epoch [13] : 13.966 123 | Test: [0/79] Time 0.131 (0.131) Loss 1.4623 (1.4623) Prec@1 57.812 (57.812) 124 | * Prec@1 55.500 125 | current lr 4.93979e-02 126 | Epoch: [14][0/391] Time 0.164 (0.164) Data 0.115 (0.115) Loss 1.4891 (1.4891) Prec@1 59.375 (59.375) 127 | Epoch: [14][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 1.8948 (1.5902) Prec@1 45.312 (56.681) 128 | Epoch: [14][390/391] Time 0.031 (0.035) Data 0.000 (0.001) Loss 1.5889 (1.5910) Prec@1 56.250 (56.574) 129 | Total time for epoch [14] : 13.843 130 | Test: [0/79] Time 0.103 (0.103) Loss 1.5179 (1.5179) Prec@1 55.469 (55.469) 131 | * Prec@1 55.350 132 | current lr 4.93092e-02 133 | Epoch: [15][0/391] Time 0.169 (0.169) Data 0.126 (0.126) Loss 1.5416 (1.5416) Prec@1 57.812 (57.812) 134 | Epoch: [15][200/391] Time 0.031 (0.035) Data 0.000 (0.001) Loss 1.5458 (1.5392) Prec@1 52.344 (57.929) 135 | Epoch: [15][390/391] Time 0.032 (0.035) Data 0.000 (0.001) Loss 1.5886 (1.5500) Prec@1 57.500 (57.760) 136 | Total time for epoch [15] : 13.603 137 | Test: [0/79] Time 0.148 (0.148) Loss 1.4142 (1.4142) Prec@1 60.938 (60.938) 138 | * Prec@1 55.360 139 | current lr 4.92146e-02 140 | Epoch: [16][0/391] Time 0.251 (0.251) Data 0.190 (0.190) Loss 1.5656 (1.5656) Prec@1 59.375 (59.375) 141 | Epoch: [16][200/391] Time 0.039 (0.036) Data 0.000 (0.001) Loss 1.4889 (1.5145) Prec@1 63.281 (58.493) 142 | Epoch: [16][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 1.5526 (1.5148) Prec@1 55.000 (58.674) 143 | Total time for epoch [16] : 13.855 144 | Test: [0/79] Time 0.133 (0.133) Loss 1.4223 (1.4223) Prec@1 64.844 (64.844) 145 | * Prec@1 57.790 146 | current lr 4.91139e-02 147 | Epoch: [17][0/391] Time 0.220 (0.220) Data 0.176 (0.176) Loss 1.5259 (1.5259) Prec@1 58.594 (58.594) 148 | Epoch: [17][200/391] Time 0.033 (0.037) Data 0.000 (0.001) Loss 1.5923 (1.4658) Prec@1 58.594 (60.036) 149 | Epoch: [17][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 1.3029 (1.4871) Prec@1 62.500 (59.282) 150 | Total time for epoch [17] : 14.174 151 | Test: [0/79] Time 0.133 (0.133) Loss 1.3500 (1.3500) Prec@1 60.156 (60.156) 152 | * Prec@1 57.770 153 | current lr 4.90073e-02 154 | Epoch: [18][0/391] Time 0.175 (0.175) Data 0.124 (0.124) Loss 1.3702 (1.3702) Prec@1 60.938 (60.938) 155 | Epoch: [18][200/391] Time 0.040 (0.036) Data 0.000 (0.001) Loss 1.5570 (1.4531) Prec@1 53.906 (60.160) 156 | Epoch: [18][390/391] Time 0.031 (0.035) Data 0.000 (0.001) Loss 1.7360 (1.4567) Prec@1 51.250 (60.060) 157 | Total time for epoch [18] : 13.734 158 | Test: [0/79] Time 0.099 (0.099) Loss 1.2751 (1.2751) Prec@1 62.500 (62.500) 159 | * Prec@1 57.770 160 | current lr 4.88948e-02 161 | Epoch: [19][0/391] Time 0.177 (0.177) Data 0.121 (0.121) Loss 1.2061 (1.2061) Prec@1 67.969 (67.969) 162 | Epoch: [19][200/391] Time 0.038 (0.036) Data 0.000 (0.001) Loss 1.3683 (1.3912) Prec@1 64.844 (61.975) 163 | Epoch: [19][390/391] Time 0.030 (0.035) Data 0.000 (0.001) Loss 1.3694 (1.4159) Prec@1 58.750 (61.342) 164 | Total time for epoch [19] : 13.729 165 | Test: [0/79] Time 0.109 (0.109) Loss 1.2923 (1.2923) Prec@1 64.062 (64.062) 166 | * Prec@1 59.710 167 | current lr 4.87764e-02 168 | Epoch: [20][0/391] Time 0.176 (0.176) Data 0.126 (0.126) Loss 1.6124 (1.6124) Prec@1 60.156 (60.156) 169 | Epoch: [20][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 1.4108 (1.3909) Prec@1 57.812 (61.913) 170 | Epoch: [20][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 1.4965 (1.4006) Prec@1 58.750 (61.570) 171 | Total time for epoch [20] : 13.931 172 
| Test: [0/79] Time 0.111 (0.111) Loss 1.4168 (1.4168) Prec@1 60.938 (60.938) 173 | * Prec@1 56.620 174 | current lr 4.86521e-02 175 | Epoch: [21][0/391] Time 0.181 (0.181) Data 0.122 (0.122) Loss 1.4468 (1.4468) Prec@1 61.719 (61.719) 176 | Epoch: [21][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.3725 (1.3573) Prec@1 62.500 (63.021) 177 | Epoch: [21][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 1.7540 (1.3790) Prec@1 51.250 (62.270) 178 | Total time for epoch [21] : 13.910 179 | Test: [0/79] Time 0.118 (0.118) Loss 1.3211 (1.3211) Prec@1 65.625 (65.625) 180 | * Prec@1 60.810 181 | current lr 4.85220e-02 182 | Epoch: [22][0/391] Time 0.177 (0.177) Data 0.126 (0.126) Loss 1.1393 (1.1393) Prec@1 71.875 (71.875) 183 | Epoch: [22][200/391] Time 0.044 (0.036) Data 0.000 (0.001) Loss 1.3178 (1.3453) Prec@1 64.844 (63.013) 184 | Epoch: [22][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 1.4000 (1.3575) Prec@1 58.750 (62.654) 185 | Total time for epoch [22] : 13.892 186 | Test: [0/79] Time 0.099 (0.099) Loss 1.2709 (1.2709) Prec@1 65.625 (65.625) 187 | * Prec@1 59.270 188 | current lr 4.83861e-02 189 | Epoch: [23][0/391] Time 0.208 (0.208) Data 0.162 (0.162) Loss 1.2530 (1.2530) Prec@1 63.281 (63.281) 190 | Epoch: [23][200/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 1.3245 (1.3191) Prec@1 60.938 (63.981) 191 | Epoch: [23][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 1.4516 (1.3304) Prec@1 66.250 (63.674) 192 | Total time for epoch [23] : 13.846 193 | Test: [0/79] Time 0.101 (0.101) Loss 1.2957 (1.2957) Prec@1 64.062 (64.062) 194 | * Prec@1 60.760 195 | current lr 4.82444e-02 196 | Epoch: [24][0/391] Time 0.168 (0.168) Data 0.120 (0.120) Loss 1.3835 (1.3835) Prec@1 60.938 (60.938) 197 | Epoch: [24][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.1301 (1.2942) Prec@1 74.219 (64.852) 198 | Epoch: [24][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 1.3156 (1.3127) Prec@1 65.000 (64.208) 199 | Total time for epoch [24] : 13.825 200 | Test: [0/79] Time 0.111 (0.111) Loss 1.1199 (1.1199) Prec@1 71.875 (71.875) 201 | * Prec@1 62.240 202 | current lr 4.80970e-02 203 | Epoch: [25][0/391] Time 0.165 (0.165) Data 0.124 (0.124) Loss 1.2802 (1.2802) Prec@1 64.062 (64.062) 204 | Epoch: [25][200/391] Time 0.041 (0.037) Data 0.000 (0.001) Loss 1.2222 (1.2735) Prec@1 67.188 (65.264) 205 | Epoch: [25][390/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.3526 (1.2938) Prec@1 60.000 (64.726) 206 | Total time for epoch [25] : 14.051 207 | Test: [0/79] Time 0.114 (0.114) Loss 1.3873 (1.3873) Prec@1 61.719 (61.719) 208 | * Prec@1 59.460 209 | current lr 4.79439e-02 210 | Epoch: [26][0/391] Time 0.170 (0.170) Data 0.129 (0.129) Loss 1.3970 (1.3970) Prec@1 64.062 (64.062) 211 | Epoch: [26][200/391] Time 0.032 (0.035) Data 0.000 (0.001) Loss 1.2081 (1.2733) Prec@1 71.094 (65.470) 212 | Epoch: [26][390/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 1.4791 (1.2924) Prec@1 58.750 (64.772) 213 | Total time for epoch [26] : 13.882 214 | Test: [0/79] Time 0.114 (0.114) Loss 1.2558 (1.2558) Prec@1 64.844 (64.844) 215 | * Prec@1 61.000 216 | current lr 4.77851e-02 217 | Epoch: [27][0/391] Time 0.193 (0.193) Data 0.139 (0.139) Loss 1.2837 (1.2837) Prec@1 66.406 (66.406) 218 | Epoch: [27][200/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 1.2417 (1.2458) Prec@1 70.312 (65.668) 219 | Epoch: [27][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 1.2700 (1.2648) Prec@1 63.750 (65.288) 220 | Total time for epoch [27] : 13.830 221 | Test: [0/79] Time 0.113 (0.113) Loss 1.2147 (1.2147) Prec@1 
69.531 (69.531) 222 | * Prec@1 62.350 223 | current lr 4.76207e-02 224 | Epoch: [28][0/391] Time 0.186 (0.186) Data 0.139 (0.139) Loss 1.2099 (1.2099) Prec@1 64.844 (64.844) 225 | Epoch: [28][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 1.2039 (1.2328) Prec@1 65.625 (66.262) 226 | Epoch: [28][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 1.3396 (1.2607) Prec@1 62.500 (65.548) 227 | Total time for epoch [28] : 13.741 228 | Test: [0/79] Time 0.118 (0.118) Loss 1.2044 (1.2044) Prec@1 66.406 (66.406) 229 | * Prec@1 61.220 230 | current lr 4.74507e-02 231 | Epoch: [29][0/391] Time 0.171 (0.171) Data 0.129 (0.129) Loss 1.2851 (1.2851) Prec@1 65.625 (65.625) 232 | Epoch: [29][200/391] Time 0.042 (0.036) Data 0.001 (0.001) Loss 1.2465 (1.2220) Prec@1 64.062 (66.457) 233 | Epoch: [29][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 1.3808 (1.2431) Prec@1 60.000 (66.052) 234 | Total time for epoch [29] : 13.977 235 | Test: [0/79] Time 0.096 (0.096) Loss 1.2305 (1.2305) Prec@1 68.750 (68.750) 236 | * Prec@1 60.400 237 | current lr 4.72752e-02 238 | Epoch: [30][0/391] Time 0.165 (0.165) Data 0.123 (0.123) Loss 0.9275 (0.9275) Prec@1 76.562 (76.562) 239 | Epoch: [30][200/391] Time 0.035 (0.037) Data 0.000 (0.001) Loss 1.1931 (1.2162) Prec@1 68.750 (66.694) 240 | Epoch: [30][390/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 1.2399 (1.2288) Prec@1 63.750 (66.300) 241 | Total time for epoch [30] : 14.081 242 | Test: [0/79] Time 0.109 (0.109) Loss 1.2226 (1.2226) Prec@1 66.406 (66.406) 243 | * Prec@1 63.610 244 | current lr 4.70941e-02 245 | Epoch: [31][0/391] Time 0.178 (0.178) Data 0.136 (0.136) Loss 1.0921 (1.0921) Prec@1 73.438 (73.438) 246 | Epoch: [31][200/391] Time 0.040 (0.036) Data 0.000 (0.001) Loss 1.3295 (1.1868) Prec@1 60.156 (67.522) 247 | Epoch: [31][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.2154 (1.2141) Prec@1 68.750 (66.722) 248 | Total time for epoch [31] : 14.016 249 | Test: [0/79] Time 0.099 (0.099) Loss 1.2594 (1.2594) Prec@1 64.062 (64.062) 250 | * Prec@1 62.480 251 | current lr 4.69077e-02 252 | Epoch: [32][0/391] Time 0.165 (0.165) Data 0.114 (0.114) Loss 1.1879 (1.1879) Prec@1 70.312 (70.312) 253 | Epoch: [32][200/391] Time 0.033 (0.035) Data 0.000 (0.001) Loss 1.0928 (1.1873) Prec@1 74.219 (67.724) 254 | Epoch: [32][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 1.0679 (1.2132) Prec@1 62.500 (66.802) 255 | Total time for epoch [32] : 13.756 256 | Test: [0/79] Time 0.106 (0.106) Loss 1.1971 (1.1971) Prec@1 66.406 (66.406) 257 | * Prec@1 61.340 258 | current lr 4.67158e-02 259 | Epoch: [33][0/391] Time 0.183 (0.183) Data 0.125 (0.125) Loss 1.3155 (1.3155) Prec@1 64.844 (64.844) 260 | Epoch: [33][200/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 1.1415 (1.1787) Prec@1 66.406 (67.728) 261 | Epoch: [33][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 1.4957 (1.1929) Prec@1 58.750 (67.360) 262 | Total time for epoch [33] : 13.744 263 | Test: [0/79] Time 0.114 (0.114) Loss 1.2990 (1.2990) Prec@1 63.281 (63.281) 264 | * Prec@1 59.330 265 | current lr 4.65186e-02 266 | Epoch: [34][0/391] Time 0.204 (0.204) Data 0.144 (0.144) Loss 1.1421 (1.1421) Prec@1 71.094 (71.094) 267 | Epoch: [34][200/391] Time 0.031 (0.037) Data 0.000 (0.001) Loss 1.2101 (1.1659) Prec@1 62.500 (68.171) 268 | Epoch: [34][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.2300 (1.1873) Prec@1 58.750 (67.498) 269 | Total time for epoch [34] : 14.039 270 | Test: [0/79] Time 0.117 (0.117) Loss 1.1284 (1.1284) Prec@1 69.531 (69.531) 271 | * Prec@1 64.070 272 | current lr 
4.63160e-02 273 | Epoch: [35][0/391] Time 0.176 (0.176) Data 0.124 (0.124) Loss 0.8910 (0.8910) Prec@1 77.344 (77.344) 274 | Epoch: [35][200/391] Time 0.041 (0.036) Data 0.000 (0.001) Loss 0.9192 (1.1475) Prec@1 71.875 (69.053) 275 | Epoch: [35][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 1.2841 (1.1764) Prec@1 58.750 (68.006) 276 | Total time for epoch [35] : 13.708 277 | Test: [0/79] Time 0.123 (0.123) Loss 1.1128 (1.1128) Prec@1 68.750 (68.750) 278 | * Prec@1 63.350 279 | current lr 4.61082e-02 280 | Epoch: [36][0/391] Time 0.177 (0.177) Data 0.128 (0.128) Loss 0.8853 (0.8853) Prec@1 75.781 (75.781) 281 | Epoch: [36][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.2811 (1.1427) Prec@1 67.188 (68.890) 282 | Epoch: [36][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 1.2902 (1.1707) Prec@1 58.750 (68.094) 283 | Total time for epoch [36] : 13.858 284 | Test: [0/79] Time 0.118 (0.118) Loss 1.3326 (1.3326) Prec@1 64.062 (64.062) 285 | * Prec@1 62.140 286 | current lr 4.58952e-02 287 | Epoch: [37][0/391] Time 0.168 (0.168) Data 0.122 (0.122) Loss 1.1050 (1.1050) Prec@1 70.312 (70.312) 288 | Epoch: [37][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.0733 (1.1209) Prec@1 71.875 (69.508) 289 | Epoch: [37][390/391] Time 0.050 (0.036) Data 0.000 (0.001) Loss 1.2255 (1.1582) Prec@1 68.750 (68.546) 290 | Total time for epoch [37] : 13.944 291 | Test: [0/79] Time 0.103 (0.103) Loss 1.0864 (1.0864) Prec@1 70.312 (70.312) 292 | * Prec@1 62.730 293 | current lr 4.56770e-02 294 | Epoch: [38][0/391] Time 0.185 (0.185) Data 0.134 (0.134) Loss 1.0666 (1.0666) Prec@1 73.438 (73.438) 295 | Epoch: [38][200/391] Time 0.032 (0.037) Data 0.000 (0.001) Loss 1.2919 (1.1232) Prec@1 64.844 (69.314) 296 | Epoch: [38][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.2672 (1.1483) Prec@1 66.250 (68.544) 297 | Total time for epoch [38] : 14.015 298 | Test: [0/79] Time 0.105 (0.105) Loss 1.0734 (1.0734) Prec@1 71.094 (71.094) 299 | * Prec@1 64.180 300 | current lr 4.54537e-02 301 | Epoch: [39][0/391] Time 0.166 (0.166) Data 0.118 (0.118) Loss 1.0989 (1.0989) Prec@1 68.750 (68.750) 302 | Epoch: [39][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.2583 (1.1215) Prec@1 61.719 (69.415) 303 | Epoch: [39][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 1.1707 (1.1425) Prec@1 71.250 (68.816) 304 | Total time for epoch [39] : 13.976 305 | Test: [0/79] Time 0.100 (0.100) Loss 1.0058 (1.0058) Prec@1 69.531 (69.531) 306 | * Prec@1 65.150 307 | current lr 4.52254e-02 308 | Epoch: [40][0/391] Time 0.169 (0.169) Data 0.116 (0.116) Loss 1.1316 (1.1316) Prec@1 68.750 (68.750) 309 | Epoch: [40][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.1409 (1.0986) Prec@1 69.531 (69.935) 310 | Epoch: [40][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 1.2202 (1.1273) Prec@1 66.250 (69.218) 311 | Total time for epoch [40] : 13.896 312 | Test: [0/79] Time 0.108 (0.108) Loss 1.1284 (1.1284) Prec@1 69.531 (69.531) 313 | * Prec@1 63.690 314 | current lr 4.49921e-02 315 | Epoch: [41][0/391] Time 0.157 (0.157) Data 0.111 (0.111) Loss 0.9477 (0.9477) Prec@1 71.875 (71.875) 316 | Epoch: [41][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.1201 (1.1037) Prec@1 71.094 (69.625) 317 | Epoch: [41][390/391] Time 0.030 (0.035) Data 0.000 (0.001) Loss 1.1093 (1.1232) Prec@1 72.500 (69.280) 318 | Total time for epoch [41] : 13.764 319 | Test: [0/79] Time 0.122 (0.122) Loss 1.1084 (1.1084) Prec@1 70.312 (70.312) 320 | * Prec@1 65.360 321 | current lr 4.47539e-02 322 | Epoch: [42][0/391] Time 0.161 (0.161) Data 
0.116 (0.116) Loss 1.0310 (1.0310) Prec@1 72.656 (72.656) 323 | Epoch: [42][200/391] Time 0.036 (0.036) Data 0.000 (0.001) Loss 1.2016 (1.0800) Prec@1 67.188 (70.608) 324 | Epoch: [42][390/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.2067 (1.1146) Prec@1 68.750 (69.554) 325 | Total time for epoch [42] : 14.011 326 | Test: [0/79] Time 0.103 (0.103) Loss 1.1241 (1.1241) Prec@1 67.188 (67.188) 327 | * Prec@1 64.990 328 | current lr 4.45108e-02 329 | Epoch: [43][0/391] Time 0.176 (0.176) Data 0.136 (0.136) Loss 0.9557 (0.9557) Prec@1 76.562 (76.562) 330 | Epoch: [43][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.1484 (1.0717) Prec@1 70.312 (71.199) 331 | Epoch: [43][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.3396 (1.0946) Prec@1 66.250 (70.312) 332 | Total time for epoch [43] : 14.134 333 | Test: [0/79] Time 0.108 (0.108) Loss 1.0538 (1.0538) Prec@1 70.312 (70.312) 334 | * Prec@1 65.230 335 | current lr 4.42628e-02 336 | Epoch: [44][0/391] Time 0.191 (0.191) Data 0.133 (0.133) Loss 0.9990 (0.9990) Prec@1 74.219 (74.219) 337 | Epoch: [44][200/391] Time 0.040 (0.036) Data 0.000 (0.001) Loss 1.1760 (1.0879) Prec@1 65.625 (70.215) 338 | Epoch: [44][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 1.2145 (1.1015) Prec@1 66.250 (69.982) 339 | Total time for epoch [44] : 13.929 340 | Test: [0/79] Time 0.094 (0.094) Loss 1.1627 (1.1627) Prec@1 65.625 (65.625) 341 | * Prec@1 64.490 342 | current lr 4.40101e-02 343 | Epoch: [45][0/391] Time 0.170 (0.170) Data 0.116 (0.116) Loss 1.0231 (1.0231) Prec@1 70.312 (70.312) 344 | Epoch: [45][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.3319 (1.0622) Prec@1 67.188 (71.035) 345 | Epoch: [45][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.1686 (1.0831) Prec@1 68.750 (70.404) 346 | Total time for epoch [45] : 13.983 347 | Test: [0/79] Time 0.110 (0.110) Loss 1.1447 (1.1447) Prec@1 65.625 (65.625) 348 | * Prec@1 66.750 349 | current lr 4.37528e-02 350 | Epoch: [46][0/391] Time 0.180 (0.180) Data 0.126 (0.126) Loss 1.2325 (1.2325) Prec@1 64.062 (64.062) 351 | Epoch: [46][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.0885 (1.0730) Prec@1 71.875 (70.896) 352 | Epoch: [46][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.2452 (1.0833) Prec@1 63.750 (70.640) 353 | Total time for epoch [46] : 13.978 354 | Test: [0/79] Time 0.096 (0.096) Loss 1.0211 (1.0211) Prec@1 71.094 (71.094) 355 | * Prec@1 66.790 356 | current lr 4.34908e-02 357 | Epoch: [47][0/391] Time 0.181 (0.181) Data 0.125 (0.125) Loss 1.1235 (1.1235) Prec@1 70.312 (70.312) 358 | Epoch: [47][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 0.9524 (1.0562) Prec@1 75.781 (71.416) 359 | Epoch: [47][390/391] Time 0.032 (0.035) Data 0.000 (0.001) Loss 1.0360 (1.0724) Prec@1 68.750 (70.946) 360 | Total time for epoch [47] : 13.714 361 | Test: [0/79] Time 0.118 (0.118) Loss 1.1440 (1.1440) Prec@1 67.969 (67.969) 362 | * Prec@1 63.400 363 | current lr 4.32242e-02 364 | Epoch: [48][0/391] Time 0.219 (0.219) Data 0.174 (0.174) Loss 1.0776 (1.0776) Prec@1 72.656 (72.656) 365 | Epoch: [48][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.3021 (1.0374) Prec@1 60.156 (71.618) 366 | Epoch: [48][390/391] Time 0.047 (0.035) Data 0.000 (0.001) Loss 1.1863 (1.0695) Prec@1 68.750 (70.730) 367 | Total time for epoch [48] : 13.877 368 | Test: [0/79] Time 0.130 (0.130) Loss 0.8845 (0.8845) Prec@1 75.000 (75.000) 369 | * Prec@1 66.590 370 | current lr 4.29532e-02 371 | Epoch: [49][0/391] Time 0.197 (0.197) Data 0.141 (0.141) Loss 0.9009 (0.9009) Prec@1 78.125 (78.125) 372 
| Epoch: [49][200/391] Time 0.037 (0.036) Data 0.001 (0.001) Loss 1.0031 (1.0357) Prec@1 74.219 (71.933) 373 | Epoch: [49][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 1.2646 (1.0572) Prec@1 65.000 (71.224) 374 | Total time for epoch [49] : 13.975 375 | Test: [0/79] Time 0.096 (0.096) Loss 1.1316 (1.1316) Prec@1 67.188 (67.188) 376 | * Prec@1 66.270 377 | current lr 4.26777e-02 378 | Epoch: [50][0/391] Time 0.159 (0.159) Data 0.120 (0.120) Loss 1.2429 (1.2429) Prec@1 66.406 (66.406) 379 | Epoch: [50][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.9329 (1.0402) Prec@1 75.000 (71.681) 380 | Epoch: [50][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 1.1209 (1.0538) Prec@1 63.750 (71.302) 381 | Total time for epoch [50] : 14.138 382 | Test: [0/79] Time 0.100 (0.100) Loss 1.2156 (1.2156) Prec@1 66.406 (66.406) 383 | * Prec@1 62.300 384 | current lr 4.23978e-02 385 | Epoch: [51][0/391] Time 0.166 (0.166) Data 0.116 (0.116) Loss 0.8190 (0.8190) Prec@1 75.781 (75.781) 386 | Epoch: [51][200/391] Time 0.034 (0.035) Data 0.000 (0.001) Loss 1.1333 (1.0192) Prec@1 69.531 (72.260) 387 | Epoch: [51][390/391] Time 0.030 (0.035) Data 0.000 (0.001) Loss 0.9274 (1.0468) Prec@1 71.250 (71.594) 388 | Total time for epoch [51] : 13.816 389 | Test: [0/79] Time 0.105 (0.105) Loss 0.9675 (0.9675) Prec@1 74.219 (74.219) 390 | * Prec@1 67.710 391 | current lr 4.21137e-02 392 | Epoch: [52][0/391] Time 0.197 (0.197) Data 0.136 (0.136) Loss 1.1376 (1.1376) Prec@1 69.531 (69.531) 393 | Epoch: [52][200/391] Time 0.041 (0.036) Data 0.000 (0.001) Loss 1.0149 (0.9980) Prec@1 67.969 (73.130) 394 | Epoch: [52][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.9875 (1.0335) Prec@1 72.500 (71.908) 395 | Total time for epoch [52] : 13.797 396 | Test: [0/79] Time 0.117 (0.117) Loss 0.9920 (0.9920) Prec@1 70.312 (70.312) 397 | * Prec@1 67.650 398 | current lr 4.18253e-02 399 | Epoch: [53][0/391] Time 0.179 (0.179) Data 0.125 (0.125) Loss 0.9081 (0.9081) Prec@1 75.781 (75.781) 400 | Epoch: [53][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 1.1957 (1.0134) Prec@1 64.844 (72.641) 401 | Epoch: [53][390/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 0.9722 (1.0309) Prec@1 71.250 (72.078) 402 | Total time for epoch [53] : 13.887 403 | Test: [0/79] Time 0.122 (0.122) Loss 1.0226 (1.0226) Prec@1 68.750 (68.750) 404 | * Prec@1 68.600 405 | current lr 4.15328e-02 406 | Epoch: [54][0/391] Time 0.164 (0.164) Data 0.124 (0.124) Loss 1.0617 (1.0617) Prec@1 72.656 (72.656) 407 | Epoch: [54][200/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 0.8784 (1.0013) Prec@1 78.125 (73.185) 408 | Epoch: [54][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.9817 (1.0165) Prec@1 73.750 (72.510) 409 | Total time for epoch [54] : 14.064 410 | Test: [0/79] Time 0.109 (0.109) Loss 1.0987 (1.0987) Prec@1 67.969 (67.969) 411 | * Prec@1 66.480 412 | current lr 4.12362e-02 413 | Epoch: [55][0/391] Time 0.175 (0.175) Data 0.127 (0.127) Loss 1.0582 (1.0582) Prec@1 67.188 (67.188) 414 | Epoch: [55][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.9239 (1.0050) Prec@1 73.438 (72.761) 415 | Epoch: [55][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.9558 (1.0173) Prec@1 75.000 (72.380) 416 | Total time for epoch [55] : 13.881 417 | Test: [0/79] Time 0.107 (0.107) Loss 1.1378 (1.1378) Prec@1 67.969 (67.969) 418 | * Prec@1 67.090 419 | current lr 4.09356e-02 420 | Epoch: [56][0/391] Time 0.196 (0.196) Data 0.138 (0.138) Loss 0.9094 (0.9094) Prec@1 79.688 (79.688) 421 | Epoch: [56][200/391] Time 0.032 (0.036) Data 0.000 (0.001) 
Loss 0.9961 (0.9844) Prec@1 71.875 (73.550) 422 | Epoch: [56][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 1.1121 (1.0083) Prec@1 66.250 (72.780) 423 | Total time for epoch [56] : 13.810 424 | Test: [0/79] Time 0.122 (0.122) Loss 1.0745 (1.0745) Prec@1 68.750 (68.750) 425 | * Prec@1 65.370 426 | current lr 4.06311e-02 427 | Epoch: [57][0/391] Time 0.179 (0.179) Data 0.132 (0.132) Loss 0.9158 (0.9158) Prec@1 73.438 (73.438) 428 | Epoch: [57][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 1.0290 (0.9767) Prec@1 73.438 (73.457) 429 | Epoch: [57][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.8608 (1.0015) Prec@1 85.000 (72.920) 430 | Total time for epoch [57] : 13.832 431 | Test: [0/79] Time 0.109 (0.109) Loss 1.0964 (1.0964) Prec@1 73.438 (73.438) 432 | * Prec@1 67.370 433 | current lr 4.03227e-02 434 | Epoch: [58][0/391] Time 0.187 (0.187) Data 0.137 (0.137) Loss 0.8527 (0.8527) Prec@1 76.562 (76.562) 435 | Epoch: [58][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 1.1241 (0.9668) Prec@1 69.531 (73.982) 436 | Epoch: [58][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 1.1489 (0.9884) Prec@1 70.000 (73.310) 437 | Total time for epoch [58] : 13.947 438 | Test: [0/79] Time 0.104 (0.104) Loss 1.0696 (1.0696) Prec@1 68.750 (68.750) 439 | * Prec@1 67.510 440 | current lr 4.00105e-02 441 | Epoch: [59][0/391] Time 0.173 (0.173) Data 0.128 (0.128) Loss 0.9455 (0.9455) Prec@1 75.000 (75.000) 442 | Epoch: [59][200/391] Time 0.039 (0.036) Data 0.000 (0.001) Loss 1.1182 (0.9668) Prec@1 66.406 (73.838) 443 | Epoch: [59][390/391] Time 0.038 (0.035) Data 0.000 (0.001) Loss 1.1531 (0.9889) Prec@1 68.750 (73.248) 444 | Total time for epoch [59] : 13.799 445 | Test: [0/79] Time 0.118 (0.118) Loss 0.9864 (0.9864) Prec@1 69.531 (69.531) 446 | * Prec@1 66.810 447 | current lr 3.96946e-02 448 | Epoch: [60][0/391] Time 0.173 (0.173) Data 0.127 (0.127) Loss 0.8649 (0.8649) Prec@1 75.781 (75.781) 449 | Epoch: [60][200/391] Time 0.040 (0.036) Data 0.000 (0.001) Loss 1.0392 (0.9796) Prec@1 75.000 (73.418) 450 | Epoch: [60][390/391] Time 0.037 (0.036) Data 0.000 (0.001) Loss 0.9569 (0.9850) Prec@1 78.750 (73.312) 451 | Total time for epoch [60] : 13.883 452 | Test: [0/79] Time 0.115 (0.115) Loss 1.0701 (1.0701) Prec@1 73.438 (73.438) 453 | * Prec@1 66.580 454 | current lr 3.93751e-02 455 | Epoch: [61][0/391] Time 0.171 (0.171) Data 0.130 (0.130) Loss 0.9020 (0.9020) Prec@1 74.219 (74.219) 456 | Epoch: [61][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.9789 (0.9517) Prec@1 71.875 (74.106) 457 | Epoch: [61][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.9323 (0.9727) Prec@1 70.000 (73.576) 458 | Total time for epoch [61] : 13.760 459 | Test: [0/79] Time 0.118 (0.118) Loss 0.9486 (0.9486) Prec@1 75.781 (75.781) 460 | * Prec@1 67.460 461 | current lr 3.90521e-02 462 | Epoch: [62][0/391] Time 0.217 (0.217) Data 0.166 (0.166) Loss 0.8728 (0.8728) Prec@1 75.000 (75.000) 463 | Epoch: [62][200/391] Time 0.035 (0.037) Data 0.000 (0.001) Loss 1.1394 (0.9504) Prec@1 65.625 (74.246) 464 | Epoch: [62][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 0.7575 (0.9696) Prec@1 83.750 (73.742) 465 | Total time for epoch [62] : 14.025 466 | Test: [0/79] Time 0.109 (0.109) Loss 0.9414 (0.9414) Prec@1 75.000 (75.000) 467 | * Prec@1 67.660 468 | current lr 3.87256e-02 469 | Epoch: [63][0/391] Time 0.194 (0.194) Data 0.153 (0.153) Loss 0.8440 (0.8440) Prec@1 80.469 (80.469) 470 | Epoch: [63][200/391] Time 0.036 (0.037) Data 0.000 (0.001) Loss 1.0363 (0.9413) Prec@1 67.969 (74.685) 471 | Epoch: 
[63][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 1.3105 (0.9633) Prec@1 63.750 (74.062) 472 | Total time for epoch [63] : 14.101 473 | Test: [0/79] Time 0.108 (0.108) Loss 1.0948 (1.0948) Prec@1 71.094 (71.094) 474 | * Prec@1 67.540 475 | current lr 3.83957e-02 476 | Epoch: [64][0/391] Time 0.159 (0.159) Data 0.113 (0.113) Loss 0.8467 (0.8467) Prec@1 75.781 (75.781) 477 | Epoch: [64][200/391] Time 0.037 (0.035) Data 0.001 (0.001) Loss 0.8654 (0.9327) Prec@1 78.906 (75.078) 478 | Epoch: [64][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 0.9656 (0.9546) Prec@1 73.750 (74.434) 479 | Total time for epoch [64] : 13.712 480 | Test: [0/79] Time 0.128 (0.128) Loss 0.9886 (0.9886) Prec@1 68.750 (68.750) 481 | * Prec@1 66.040 482 | current lr 3.80625e-02 483 | Epoch: [65][0/391] Time 0.158 (0.158) Data 0.116 (0.116) Loss 0.9209 (0.9209) Prec@1 75.000 (75.000) 484 | Epoch: [65][200/391] Time 0.039 (0.036) Data 0.000 (0.001) Loss 1.0225 (0.9233) Prec@1 71.875 (75.117) 485 | Epoch: [65][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.9934 (0.9496) Prec@1 75.000 (74.404) 486 | Total time for epoch [65] : 13.833 487 | Test: [0/79] Time 0.107 (0.107) Loss 0.8893 (0.8893) Prec@1 71.875 (71.875) 488 | * Prec@1 68.280 489 | current lr 3.77260e-02 490 | Epoch: [66][0/391] Time 0.182 (0.182) Data 0.132 (0.132) Loss 0.8678 (0.8678) Prec@1 78.906 (78.906) 491 | Epoch: [66][200/391] Time 0.036 (0.035) Data 0.000 (0.001) Loss 0.7546 (0.9279) Prec@1 78.906 (75.218) 492 | Epoch: [66][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 0.9398 (0.9424) Prec@1 73.750 (74.738) 493 | Total time for epoch [66] : 13.765 494 | Test: [0/79] Time 0.102 (0.102) Loss 1.1219 (1.1219) Prec@1 69.531 (69.531) 495 | * Prec@1 67.040 496 | current lr 3.73865e-02 497 | Epoch: [67][0/391] Time 0.165 (0.165) Data 0.121 (0.121) Loss 0.7447 (0.7447) Prec@1 84.375 (84.375) 498 | Epoch: [67][200/391] Time 0.036 (0.037) Data 0.000 (0.001) Loss 0.9157 (0.9179) Prec@1 75.781 (75.299) 499 | Epoch: [67][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.8274 (0.9398) Prec@1 75.000 (74.742) 500 | Total time for epoch [67] : 14.126 501 | Test: [0/79] Time 0.101 (0.101) Loss 1.0014 (1.0014) Prec@1 70.312 (70.312) 502 | * Prec@1 69.170 503 | current lr 3.70438e-02 504 | Epoch: [68][0/391] Time 0.186 (0.186) Data 0.136 (0.136) Loss 0.9484 (0.9484) Prec@1 72.656 (72.656) 505 | Epoch: [68][200/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 0.8614 (0.9174) Prec@1 78.125 (75.447) 506 | Epoch: [68][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.7986 (0.9317) Prec@1 77.500 (75.006) 507 | Total time for epoch [68] : 14.006 508 | Test: [0/79] Time 0.108 (0.108) Loss 1.0430 (1.0430) Prec@1 73.438 (73.438) 509 | * Prec@1 68.070 510 | current lr 3.66982e-02 511 | Epoch: [69][0/391] Time 0.167 (0.167) Data 0.127 (0.127) Loss 0.7949 (0.7949) Prec@1 78.125 (78.125) 512 | Epoch: [69][200/391] Time 0.039 (0.036) Data 0.000 (0.001) Loss 0.9330 (0.8884) Prec@1 75.781 (76.368) 513 | Epoch: [69][390/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 0.8001 (0.9133) Prec@1 80.000 (75.612) 514 | Total time for epoch [69] : 13.953 515 | Test: [0/79] Time 0.115 (0.115) Loss 0.9905 (0.9905) Prec@1 68.750 (68.750) 516 | * Prec@1 67.890 517 | current lr 3.63498e-02 518 | Epoch: [70][0/391] Time 0.183 (0.183) Data 0.123 (0.123) Loss 0.9090 (0.9090) Prec@1 78.125 (78.125) 519 | Epoch: [70][200/391] Time 0.038 (0.036) Data 0.001 (0.001) Loss 0.9782 (0.8894) Prec@1 71.094 (75.960) 520 | Epoch: [70][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 
1.2943 (0.9121) Prec@1 65.000 (75.426) 521 | Total time for epoch [70] : 13.871 522 | Test: [0/79] Time 0.115 (0.115) Loss 0.9382 (0.9382) Prec@1 74.219 (74.219) 523 | * Prec@1 68.350 524 | current lr 3.59985e-02 525 | Epoch: [71][0/391] Time 0.180 (0.180) Data 0.127 (0.127) Loss 0.8484 (0.8484) Prec@1 76.562 (76.562) 526 | Epoch: [71][200/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 0.9576 (0.8887) Prec@1 72.656 (76.271) 527 | Epoch: [71][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 0.7815 (0.9044) Prec@1 83.750 (75.878) 528 | Total time for epoch [71] : 13.865 529 | Test: [0/79] Time 0.117 (0.117) Loss 1.0875 (1.0875) Prec@1 69.531 (69.531) 530 | * Prec@1 68.740 531 | current lr 3.56445e-02 532 | Epoch: [72][0/391] Time 0.192 (0.192) Data 0.150 (0.150) Loss 0.9236 (0.9236) Prec@1 76.562 (76.562) 533 | Epoch: [72][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.7782 (0.8855) Prec@1 75.000 (76.384) 534 | Epoch: [72][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 0.9637 (0.9012) Prec@1 70.000 (75.852) 535 | Total time for epoch [72] : 13.690 536 | Test: [0/79] Time 0.118 (0.118) Loss 0.9425 (0.9425) Prec@1 67.188 (67.188) 537 | * Prec@1 68.910 538 | current lr 3.52879e-02 539 | Epoch: [73][0/391] Time 0.174 (0.174) Data 0.121 (0.121) Loss 0.7488 (0.7488) Prec@1 82.031 (82.031) 540 | Epoch: [73][200/391] Time 0.033 (0.035) Data 0.000 (0.001) Loss 0.8334 (0.8574) Prec@1 80.469 (76.866) 541 | Epoch: [73][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 0.6466 (0.8956) Prec@1 82.500 (75.884) 542 | Total time for epoch [73] : 13.753 543 | Test: [0/79] Time 0.122 (0.122) Loss 1.0443 (1.0443) Prec@1 71.875 (71.875) 544 | * Prec@1 66.290 545 | current lr 3.49287e-02 546 | Epoch: [74][0/391] Time 0.170 (0.170) Data 0.121 (0.121) Loss 0.8102 (0.8102) Prec@1 80.469 (80.469) 547 | Epoch: [74][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.9991 (0.8807) Prec@1 71.094 (76.586) 548 | Epoch: [74][390/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 0.7677 (0.8920) Prec@1 81.250 (76.222) 549 | Total time for epoch [74] : 14.029 550 | Test: [0/79] Time 0.140 (0.140) Loss 0.8876 (0.8876) Prec@1 71.094 (71.094) 551 | * Prec@1 69.870 552 | current lr 3.45671e-02 553 | Epoch: [75][0/391] Time 0.179 (0.179) Data 0.135 (0.135) Loss 0.8982 (0.8982) Prec@1 75.781 (75.781) 554 | Epoch: [75][200/391] Time 0.035 (0.037) Data 0.000 (0.001) Loss 0.8809 (0.8629) Prec@1 77.344 (77.013) 555 | Epoch: [75][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 0.6056 (0.8840) Prec@1 82.500 (76.486) 556 | Total time for epoch [75] : 14.023 557 | Test: [0/79] Time 0.130 (0.130) Loss 0.9308 (0.9308) Prec@1 73.438 (73.438) 558 | * Prec@1 68.600 559 | current lr 3.42031e-02 560 | Epoch: [76][0/391] Time 0.207 (0.207) Data 0.143 (0.143) Loss 0.9331 (0.9331) Prec@1 76.562 (76.562) 561 | Epoch: [76][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.9449 (0.8520) Prec@1 71.875 (77.616) 562 | Epoch: [76][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 1.0897 (0.8678) Prec@1 70.000 (77.010) 563 | Total time for epoch [76] : 14.150 564 | Test: [0/79] Time 0.102 (0.102) Loss 0.9549 (0.9549) Prec@1 73.438 (73.438) 565 | * Prec@1 70.200 566 | current lr 3.38369e-02 567 | Epoch: [77][0/391] Time 0.170 (0.170) Data 0.117 (0.117) Loss 0.7339 (0.7339) Prec@1 82.031 (82.031) 568 | Epoch: [77][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 0.7788 (0.8416) Prec@1 82.031 (77.523) 569 | Epoch: [77][390/391] Time 0.032 (0.035) Data 0.000 (0.001) Loss 0.9584 (0.8679) Prec@1 71.250 (76.866) 570 | Total time for 
epoch [77] : 13.787 571 | Test: [0/79] Time 0.100 (0.100) Loss 0.8438 (0.8438) Prec@1 75.000 (75.000) 572 | * Prec@1 71.220 573 | current lr 3.34684e-02 574 | Epoch: [78][0/391] Time 0.167 (0.167) Data 0.122 (0.122) Loss 0.8196 (0.8196) Prec@1 82.812 (82.812) 575 | Epoch: [78][200/391] Time 0.037 (0.037) Data 0.000 (0.001) Loss 0.7989 (0.8394) Prec@1 80.469 (77.833) 576 | Epoch: [78][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.9884 (0.8561) Prec@1 73.750 (77.248) 577 | Total time for epoch [78] : 14.091 578 | Test: [0/79] Time 0.129 (0.129) Loss 0.9554 (0.9554) Prec@1 75.000 (75.000) 579 | * Prec@1 69.460 580 | current lr 3.30979e-02 581 | Epoch: [79][0/391] Time 0.160 (0.160) Data 0.114 (0.114) Loss 0.7739 (0.7739) Prec@1 80.469 (80.469) 582 | Epoch: [79][200/391] Time 0.038 (0.037) Data 0.001 (0.001) Loss 0.7681 (0.8131) Prec@1 80.469 (78.315) 583 | Epoch: [79][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 0.8549 (0.8460) Prec@1 76.250 (77.316) 584 | Total time for epoch [79] : 14.120 585 | Test: [0/79] Time 0.103 (0.103) Loss 1.0503 (1.0503) Prec@1 70.312 (70.312) 586 | * Prec@1 70.300 587 | current lr 3.27254e-02 588 | Epoch: [80][0/391] Time 0.181 (0.181) Data 0.130 (0.130) Loss 0.7421 (0.7421) Prec@1 78.906 (78.906) 589 | Epoch: [80][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 0.9683 (0.8307) Prec@1 67.969 (78.098) 590 | Epoch: [80][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 0.9069 (0.8406) Prec@1 77.500 (77.768) 591 | Total time for epoch [80] : 13.869 592 | Test: [0/79] Time 0.106 (0.106) Loss 0.9640 (0.9640) Prec@1 70.312 (70.312) 593 | * Prec@1 69.900 594 | current lr 3.23510e-02 595 | Epoch: [81][0/391] Time 0.175 (0.175) Data 0.133 (0.133) Loss 0.8376 (0.8376) Prec@1 78.906 (78.906) 596 | Epoch: [81][200/391] Time 0.040 (0.035) Data 0.001 (0.001) Loss 0.7986 (0.8079) Prec@1 79.688 (78.599) 597 | Epoch: [81][390/391] Time 0.027 (0.035) Data 0.000 (0.001) Loss 0.7770 (0.8310) Prec@1 78.750 (77.858) 598 | Total time for epoch [81] : 13.766 599 | Test: [0/79] Time 0.101 (0.101) Loss 0.9862 (0.9862) Prec@1 71.875 (71.875) 600 | * Prec@1 69.960 601 | current lr 3.19748e-02 602 | Epoch: [82][0/391] Time 0.166 (0.166) Data 0.127 (0.127) Loss 0.6646 (0.6646) Prec@1 83.594 (83.594) 603 | Epoch: [82][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.9715 (0.8097) Prec@1 77.344 (78.630) 604 | Epoch: [82][390/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.6958 (0.8252) Prec@1 80.000 (78.258) 605 | Total time for epoch [82] : 13.951 606 | Test: [0/79] Time 0.099 (0.099) Loss 0.9612 (0.9612) Prec@1 71.094 (71.094) 607 | * Prec@1 71.040 608 | current lr 3.15968e-02 609 | Epoch: [83][0/391] Time 0.169 (0.169) Data 0.124 (0.124) Loss 0.6717 (0.6717) Prec@1 81.250 (81.250) 610 | Epoch: [83][200/391] Time 0.047 (0.036) Data 0.001 (0.001) Loss 0.7760 (0.7953) Prec@1 79.688 (79.069) 611 | Epoch: [83][390/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.8772 (0.8228) Prec@1 78.750 (78.344) 612 | Total time for epoch [83] : 14.021 613 | Test: [0/79] Time 0.098 (0.098) Loss 0.9135 (0.9135) Prec@1 73.438 (73.438) 614 | * Prec@1 70.550 615 | current lr 3.12172e-02 616 | Epoch: [84][0/391] Time 0.165 (0.165) Data 0.120 (0.120) Loss 0.8657 (0.8657) Prec@1 76.562 (76.562) 617 | Epoch: [84][200/391] Time 0.044 (0.036) Data 0.000 (0.001) Loss 0.8507 (0.7888) Prec@1 77.344 (79.283) 618 | Epoch: [84][390/391] Time 0.037 (0.035) Data 0.000 (0.001) Loss 0.8147 (0.8108) Prec@1 76.250 (78.722) 619 | Total time for epoch [84] : 13.820 620 | Test: [0/79] Time 0.101 (0.101) 
Loss 0.8485 (0.8485) Prec@1 74.219 (74.219) 621 | * Prec@1 70.630 622 | current lr 3.08361e-02 623 | Epoch: [85][0/391] Time 0.176 (0.176) Data 0.129 (0.129) Loss 0.6308 (0.6308) Prec@1 86.719 (86.719) 624 | Epoch: [85][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.8583 (0.7956) Prec@1 73.438 (79.380) 625 | Epoch: [85][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.9173 (0.8064) Prec@1 72.500 (78.886) 626 | Total time for epoch [85] : 13.990 627 | Test: [0/79] Time 0.100 (0.100) Loss 0.8902 (0.8902) Prec@1 72.656 (72.656) 628 | * Prec@1 70.990 629 | current lr 3.04536e-02 630 | Epoch: [86][0/391] Time 0.183 (0.183) Data 0.140 (0.140) Loss 0.6616 (0.6616) Prec@1 84.375 (84.375) 631 | Epoch: [86][200/391] Time 0.040 (0.037) Data 0.000 (0.001) Loss 0.7536 (0.7698) Prec@1 78.125 (80.049) 632 | Epoch: [86][390/391] Time 0.027 (0.037) Data 0.000 (0.001) Loss 1.0012 (0.7966) Prec@1 68.750 (79.182) 633 | Total time for epoch [86] : 14.292 634 | Test: [0/79] Time 0.103 (0.103) Loss 0.9377 (0.9377) Prec@1 69.531 (69.531) 635 | * Prec@1 69.210 636 | current lr 3.00697e-02 637 | Epoch: [87][0/391] Time 0.191 (0.191) Data 0.127 (0.127) Loss 0.7443 (0.7443) Prec@1 83.594 (83.594) 638 | Epoch: [87][200/391] Time 0.030 (0.037) Data 0.000 (0.001) Loss 0.9309 (0.7726) Prec@1 78.906 (79.742) 639 | Epoch: [87][390/391] Time 0.036 (0.035) Data 0.000 (0.001) Loss 0.7558 (0.7888) Prec@1 83.750 (79.288) 640 | Total time for epoch [87] : 13.856 641 | Test: [0/79] Time 0.135 (0.135) Loss 0.9196 (0.9196) Prec@1 71.875 (71.875) 642 | * Prec@1 70.910 643 | current lr 2.96845e-02 644 | Epoch: [88][0/391] Time 0.176 (0.176) Data 0.136 (0.136) Loss 0.6375 (0.6375) Prec@1 88.281 (88.281) 645 | Epoch: [88][200/391] Time 0.038 (0.036) Data 0.000 (0.001) Loss 0.8718 (0.7497) Prec@1 74.219 (80.605) 646 | Epoch: [88][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.7714 (0.7826) Prec@1 82.500 (79.644) 647 | Total time for epoch [88] : 13.763 648 | Test: [0/79] Time 0.116 (0.116) Loss 0.9830 (0.9830) Prec@1 68.750 (68.750) 649 | * Prec@1 70.710 650 | current lr 2.92982e-02 651 | Epoch: [89][0/391] Time 0.175 (0.175) Data 0.122 (0.122) Loss 0.6211 (0.6211) Prec@1 84.375 (84.375) 652 | Epoch: [89][200/391] Time 0.036 (0.036) Data 0.001 (0.001) Loss 0.9291 (0.7518) Prec@1 75.781 (80.418) 653 | Epoch: [89][390/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 0.9534 (0.7740) Prec@1 77.500 (79.808) 654 | Total time for epoch [89] : 13.956 655 | Test: [0/79] Time 0.100 (0.100) Loss 0.8735 (0.8735) Prec@1 72.656 (72.656) 656 | * Prec@1 72.090 657 | current lr 2.89109e-02 658 | Epoch: [90][0/391] Time 0.185 (0.185) Data 0.124 (0.124) Loss 0.6601 (0.6601) Prec@1 82.812 (82.812) 659 | Epoch: [90][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.7780 (0.7573) Prec@1 77.344 (80.407) 660 | Epoch: [90][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.8191 (0.7699) Prec@1 82.500 (79.870) 661 | Total time for epoch [90] : 13.905 662 | Test: [0/79] Time 0.115 (0.115) Loss 0.8789 (0.8789) Prec@1 71.094 (71.094) 663 | * Prec@1 70.850 664 | current lr 2.85225e-02 665 | Epoch: [91][0/391] Time 0.169 (0.169) Data 0.116 (0.116) Loss 0.6282 (0.6282) Prec@1 86.719 (86.719) 666 | Epoch: [91][200/391] Time 0.043 (0.037) Data 0.001 (0.001) Loss 0.7531 (0.7345) Prec@1 81.250 (80.974) 667 | Epoch: [91][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 0.7719 (0.7590) Prec@1 83.750 (80.350) 668 | Total time for epoch [91] : 13.977 669 | Test: [0/79] Time 0.123 (0.123) Loss 1.0070 (1.0070) Prec@1 72.656 (72.656) 670 | * Prec@1 
71.700 671 | current lr 2.81333e-02 672 | Epoch: [92][0/391] Time 0.191 (0.191) Data 0.143 (0.143) Loss 0.7697 (0.7697) Prec@1 78.906 (78.906) 673 | Epoch: [92][200/391] Time 0.032 (0.037) Data 0.000 (0.001) Loss 0.7462 (0.7280) Prec@1 78.125 (81.277) 674 | Epoch: [92][390/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.8710 (0.7499) Prec@1 75.000 (80.476) 675 | Total time for epoch [92] : 14.041 676 | Test: [0/79] Time 0.095 (0.095) Loss 0.8443 (0.8443) Prec@1 76.562 (76.562) 677 | * Prec@1 72.880 678 | current lr 2.77434e-02 679 | Epoch: [93][0/391] Time 0.196 (0.196) Data 0.127 (0.127) Loss 0.6725 (0.6725) Prec@1 86.719 (86.719) 680 | Epoch: [93][200/391] Time 0.033 (0.035) Data 0.000 (0.001) Loss 0.6393 (0.7195) Prec@1 84.375 (81.269) 681 | Epoch: [93][390/391] Time 0.031 (0.035) Data 0.000 (0.001) Loss 0.8859 (0.7401) Prec@1 75.000 (80.618) 682 | Total time for epoch [93] : 13.832 683 | Test: [0/79] Time 0.110 (0.110) Loss 0.9023 (0.9023) Prec@1 73.438 (73.438) 684 | * Prec@1 72.650 685 | current lr 2.73527e-02 686 | Epoch: [94][0/391] Time 0.189 (0.189) Data 0.134 (0.134) Loss 0.5984 (0.5984) Prec@1 85.938 (85.938) 687 | Epoch: [94][200/391] Time 0.039 (0.036) Data 0.000 (0.001) Loss 0.7725 (0.7211) Prec@1 76.562 (81.635) 688 | Epoch: [94][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.8354 (0.7386) Prec@1 78.750 (81.016) 689 | Total time for epoch [94] : 13.901 690 | Test: [0/79] Time 0.114 (0.114) Loss 0.8264 (0.8264) Prec@1 71.875 (71.875) 691 | * Prec@1 71.730 692 | current lr 2.69615e-02 693 | Epoch: [95][0/391] Time 0.170 (0.170) Data 0.120 (0.120) Loss 0.6213 (0.6213) Prec@1 81.250 (81.250) 694 | Epoch: [95][200/391] Time 0.036 (0.035) Data 0.000 (0.001) Loss 0.7331 (0.7235) Prec@1 78.906 (81.308) 695 | Epoch: [95][390/391] Time 0.031 (0.035) Data 0.000 (0.001) Loss 0.8372 (0.7405) Prec@1 77.500 (80.768) 696 | Total time for epoch [95] : 13.787 697 | Test: [0/79] Time 0.112 (0.112) Loss 0.8925 (0.8925) Prec@1 74.219 (74.219) 698 | * Prec@1 71.800 699 | current lr 2.65698e-02 700 | Epoch: [96][0/391] Time 0.186 (0.186) Data 0.135 (0.135) Loss 0.6204 (0.6204) Prec@1 86.719 (86.719) 701 | Epoch: [96][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.7124 (0.7068) Prec@1 80.469 (81.860) 702 | Epoch: [96][390/391] Time 0.031 (0.035) Data 0.000 (0.001) Loss 0.7460 (0.7264) Prec@1 77.500 (81.222) 703 | Total time for epoch [96] : 13.838 704 | Test: [0/79] Time 0.140 (0.140) Loss 0.8794 (0.8794) Prec@1 74.219 (74.219) 705 | * Prec@1 73.040 706 | current lr 2.61777e-02 707 | Epoch: [97][0/391] Time 0.163 (0.163) Data 0.116 (0.116) Loss 0.5690 (0.5690) Prec@1 88.281 (88.281) 708 | Epoch: [97][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.6204 (0.6850) Prec@1 81.250 (82.587) 709 | Epoch: [97][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.7916 (0.7109) Prec@1 80.000 (81.752) 710 | Total time for epoch [97] : 13.937 711 | Test: [0/79] Time 0.104 (0.104) Loss 0.8159 (0.8159) Prec@1 75.000 (75.000) 712 | * Prec@1 70.750 713 | current lr 2.57853e-02 714 | Epoch: [98][0/391] Time 0.218 (0.218) Data 0.169 (0.169) Loss 0.5778 (0.5778) Prec@1 88.281 (88.281) 715 | Epoch: [98][200/391] Time 0.041 (0.037) Data 0.000 (0.001) Loss 0.8184 (0.6903) Prec@1 72.656 (82.276) 716 | Epoch: [98][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 0.8012 (0.7096) Prec@1 78.750 (81.692) 717 | Total time for epoch [98] : 14.157 718 | Test: [0/79] Time 0.116 (0.116) Loss 0.9089 (0.9089) Prec@1 78.125 (78.125) 719 | * Prec@1 71.080 720 | current lr 2.53927e-02 721 | Epoch: [99][0/391] 
Time 0.180 (0.180) Data 0.127 (0.127) Loss 0.5689 (0.5689) Prec@1 85.156 (85.156) 722 | Epoch: [99][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.7537 (0.6764) Prec@1 77.344 (82.925) 723 | Epoch: [99][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.7280 (0.6962) Prec@1 81.250 (82.396) 724 | Total time for epoch [99] : 13.966 725 | Test: [0/79] Time 0.108 (0.108) Loss 0.8860 (0.8860) Prec@1 73.438 (73.438) 726 | * Prec@1 72.900 727 | current lr 2.50000e-02 728 | Epoch: [100][0/391] Time 0.172 (0.172) Data 0.117 (0.117) Loss 0.6923 (0.6923) Prec@1 81.250 (81.250) 729 | Epoch: [100][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 0.5902 (0.6654) Prec@1 85.938 (83.057) 730 | Epoch: [100][390/391] Time 0.027 (0.035) Data 0.000 (0.001) Loss 0.7726 (0.6919) Prec@1 78.750 (82.324) 731 | Total time for epoch [100] : 13.764 732 | Test: [0/79] Time 0.116 (0.116) Loss 0.8372 (0.8372) Prec@1 75.000 (75.000) 733 | * Prec@1 73.440 734 | current lr 2.46073e-02 735 | Epoch: [101][0/391] Time 0.165 (0.165) Data 0.125 (0.125) Loss 0.7940 (0.7940) Prec@1 80.469 (80.469) 736 | Epoch: [101][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 0.5811 (0.6587) Prec@1 87.500 (83.473) 737 | Epoch: [101][390/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 0.6703 (0.6853) Prec@1 85.000 (82.600) 738 | Total time for epoch [101] : 13.899 739 | Test: [0/79] Time 0.117 (0.117) Loss 0.9080 (0.9080) Prec@1 71.875 (71.875) 740 | * Prec@1 73.270 741 | current lr 2.42147e-02 742 | Epoch: [102][0/391] Time 0.178 (0.178) Data 0.114 (0.114) Loss 0.5726 (0.5726) Prec@1 83.594 (83.594) 743 | Epoch: [102][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.5606 (0.6583) Prec@1 88.281 (83.648) 744 | Epoch: [102][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.6608 (0.6750) Prec@1 78.750 (83.026) 745 | Total time for epoch [102] : 13.826 746 | Test: [0/79] Time 0.113 (0.113) Loss 0.7786 (0.7786) Prec@1 75.781 (75.781) 747 | * Prec@1 73.780 748 | current lr 2.38223e-02 749 | Epoch: [103][0/391] Time 0.171 (0.171) Data 0.127 (0.127) Loss 0.5665 (0.5665) Prec@1 85.156 (85.156) 750 | Epoch: [103][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.6064 (0.6464) Prec@1 79.688 (83.866) 751 | Epoch: [103][390/391] Time 0.032 (0.035) Data 0.000 (0.001) Loss 1.0066 (0.6661) Prec@1 68.750 (83.142) 752 | Total time for epoch [103] : 13.773 753 | Test: [0/79] Time 0.121 (0.121) Loss 0.9174 (0.9174) Prec@1 71.094 (71.094) 754 | * Prec@1 73.800 755 | current lr 2.34302e-02 756 | Epoch: [104][0/391] Time 0.213 (0.213) Data 0.168 (0.168) Loss 0.5763 (0.5763) Prec@1 88.281 (88.281) 757 | Epoch: [104][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.6916 (0.6416) Prec@1 82.812 (84.126) 758 | Epoch: [104][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.5834 (0.6580) Prec@1 87.500 (83.648) 759 | Total time for epoch [104] : 13.914 760 | Test: [0/79] Time 0.101 (0.101) Loss 0.8170 (0.8170) Prec@1 75.000 (75.000) 761 | * Prec@1 73.300 762 | current lr 2.30385e-02 763 | Epoch: [105][0/391] Time 0.180 (0.180) Data 0.126 (0.126) Loss 0.6032 (0.6032) Prec@1 88.281 (88.281) 764 | Epoch: [105][200/391] Time 0.033 (0.035) Data 0.000 (0.001) Loss 0.4375 (0.6280) Prec@1 90.625 (84.340) 765 | Epoch: [105][390/391] Time 0.032 (0.035) Data 0.000 (0.001) Loss 0.8477 (0.6505) Prec@1 75.000 (83.574) 766 | Total time for epoch [105] : 13.803 767 | Test: [0/79] Time 0.099 (0.099) Loss 0.8099 (0.8099) Prec@1 76.562 (76.562) 768 | * Prec@1 72.680 769 | current lr 2.26473e-02 770 | Epoch: [106][0/391] Time 0.208 (0.208) Data 0.159 
(0.159) Loss 0.8366 (0.8366) Prec@1 78.906 (78.906) 771 | Epoch: [106][200/391] Time 0.036 (0.036) Data 0.001 (0.001) Loss 0.6099 (0.6254) Prec@1 88.281 (84.674) 772 | Epoch: [106][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.6245 (0.6437) Prec@1 78.750 (84.054) 773 | Total time for epoch [106] : 13.883 774 | Test: [0/79] Time 0.127 (0.127) Loss 0.9314 (0.9314) Prec@1 72.656 (72.656) 775 | * Prec@1 72.730 776 | current lr 2.22566e-02 777 | Epoch: [107][0/391] Time 0.166 (0.166) Data 0.117 (0.117) Loss 0.6623 (0.6623) Prec@1 82.031 (82.031) 778 | Epoch: [107][200/391] Time 0.033 (0.037) Data 0.000 (0.001) Loss 0.6160 (0.6085) Prec@1 87.500 (85.218) 779 | Epoch: [107][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.7502 (0.6318) Prec@1 81.250 (84.306) 780 | Total time for epoch [107] : 13.980 781 | Test: [0/79] Time 0.119 (0.119) Loss 0.8392 (0.8392) Prec@1 75.000 (75.000) 782 | * Prec@1 72.420 783 | current lr 2.18667e-02 784 | Epoch: [108][0/391] Time 0.201 (0.201) Data 0.152 (0.152) Loss 0.6562 (0.6562) Prec@1 83.594 (83.594) 785 | Epoch: [108][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.6057 (0.5981) Prec@1 87.500 (85.452) 786 | Epoch: [108][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.5820 (0.6199) Prec@1 92.500 (84.716) 787 | Total time for epoch [108] : 13.821 788 | Test: [0/79] Time 0.107 (0.107) Loss 0.8533 (0.8533) Prec@1 73.438 (73.438) 789 | * Prec@1 73.130 790 | current lr 2.14775e-02 791 | Epoch: [109][0/391] Time 0.198 (0.198) Data 0.144 (0.144) Loss 0.6192 (0.6192) Prec@1 86.719 (86.719) 792 | Epoch: [109][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 0.6749 (0.6099) Prec@1 78.125 (85.001) 793 | Epoch: [109][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.6335 (0.6214) Prec@1 81.250 (84.622) 794 | Total time for epoch [109] : 13.968 795 | Test: [0/79] Time 0.114 (0.114) Loss 0.8124 (0.8124) Prec@1 75.000 (75.000) 796 | * Prec@1 74.260 797 | current lr 2.10891e-02 798 | Epoch: [110][0/391] Time 0.158 (0.158) Data 0.116 (0.116) Loss 0.6615 (0.6615) Prec@1 78.906 (78.906) 799 | Epoch: [110][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.5929 (0.5938) Prec@1 85.156 (85.230) 800 | Epoch: [110][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.5012 (0.6059) Prec@1 86.250 (84.934) 801 | Total time for epoch [110] : 13.871 802 | Test: [0/79] Time 0.102 (0.102) Loss 0.8050 (0.8050) Prec@1 75.000 (75.000) 803 | * Prec@1 74.570 804 | current lr 2.07018e-02 805 | Epoch: [111][0/391] Time 0.261 (0.261) Data 0.209 (0.209) Loss 0.5467 (0.5467) Prec@1 86.719 (86.719) 806 | Epoch: [111][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.5222 (0.5797) Prec@1 86.719 (85.871) 807 | Epoch: [111][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.7065 (0.5964) Prec@1 80.000 (85.424) 808 | Total time for epoch [111] : 14.011 809 | Test: [0/79] Time 0.112 (0.112) Loss 0.8533 (0.8533) Prec@1 76.562 (76.562) 810 | * Prec@1 74.220 811 | current lr 2.03155e-02 812 | Epoch: [112][0/391] Time 0.210 (0.210) Data 0.155 (0.155) Loss 0.5164 (0.5164) Prec@1 86.719 (86.719) 813 | Epoch: [112][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 0.5604 (0.5737) Prec@1 81.250 (86.256) 814 | Epoch: [112][390/391] Time 0.041 (0.036) Data 0.000 (0.001) Loss 0.6140 (0.5855) Prec@1 86.250 (85.796) 815 | Total time for epoch [112] : 13.901 816 | Test: [0/79] Time 0.116 (0.116) Loss 0.7671 (0.7671) Prec@1 75.781 (75.781) 817 | * Prec@1 74.400 818 | current lr 1.99303e-02 819 | Epoch: [113][0/391] Time 0.167 (0.167) Data 0.123 (0.123) Loss 0.5145 (0.5145) 
Prec@1 88.281 (88.281) 820 | Epoch: [113][200/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 0.6708 (0.5578) Prec@1 82.031 (86.528) 821 | Epoch: [113][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.6002 (0.5777) Prec@1 80.000 (85.866) 822 | Total time for epoch [113] : 13.894 823 | Test: [0/79] Time 0.099 (0.099) Loss 0.7233 (0.7233) Prec@1 78.906 (78.906) 824 | * Prec@1 73.330 825 | current lr 1.95464e-02 826 | Epoch: [114][0/391] Time 0.171 (0.171) Data 0.114 (0.114) Loss 0.5785 (0.5785) Prec@1 85.938 (85.938) 827 | Epoch: [114][200/391] Time 0.033 (0.035) Data 0.000 (0.001) Loss 0.6099 (0.5565) Prec@1 85.156 (86.598) 828 | Epoch: [114][390/391] Time 0.039 (0.035) Data 0.000 (0.001) Loss 0.5968 (0.5750) Prec@1 82.500 (85.998) 829 | Total time for epoch [114] : 13.797 830 | Test: [0/79] Time 0.111 (0.111) Loss 0.7554 (0.7554) Prec@1 76.562 (76.562) 831 | * Prec@1 74.520 832 | current lr 1.91639e-02 833 | Epoch: [115][0/391] Time 0.166 (0.166) Data 0.123 (0.123) Loss 0.3873 (0.3873) Prec@1 92.188 (92.188) 834 | Epoch: [115][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 0.5219 (0.5457) Prec@1 89.062 (87.294) 835 | Epoch: [115][390/391] Time 0.030 (0.035) Data 0.000 (0.001) Loss 0.3986 (0.5609) Prec@1 88.750 (86.652) 836 | Total time for epoch [115] : 13.744 837 | Test: [0/79] Time 0.107 (0.107) Loss 0.7871 (0.7871) Prec@1 75.000 (75.000) 838 | * Prec@1 75.090 839 | current lr 1.87828e-02 840 | Epoch: [116][0/391] Time 0.193 (0.193) Data 0.146 (0.146) Loss 0.4090 (0.4090) Prec@1 93.750 (93.750) 841 | Epoch: [116][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.4943 (0.5305) Prec@1 84.375 (87.438) 842 | Epoch: [116][390/391] Time 0.034 (0.035) Data 0.000 (0.001) Loss 0.4456 (0.5523) Prec@1 88.750 (86.752) 843 | Total time for epoch [116] : 13.813 844 | Test: [0/79] Time 0.118 (0.118) Loss 0.7362 (0.7362) Prec@1 78.125 (78.125) 845 | * Prec@1 75.180 846 | current lr 1.84032e-02 847 | Epoch: [117][0/391] Time 0.166 (0.166) Data 0.121 (0.121) Loss 0.6262 (0.6262) Prec@1 81.250 (81.250) 848 | Epoch: [117][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.5995 (0.5345) Prec@1 85.938 (87.484) 849 | Epoch: [117][390/391] Time 0.031 (0.036) Data 0.000 (0.001) Loss 0.6375 (0.5500) Prec@1 82.500 (86.984) 850 | Total time for epoch [117] : 13.882 851 | Test: [0/79] Time 0.097 (0.097) Loss 0.6708 (0.6708) Prec@1 79.688 (79.688) 852 | * Prec@1 75.440 853 | current lr 1.80252e-02 854 | Epoch: [118][0/391] Time 0.211 (0.211) Data 0.166 (0.166) Loss 0.5004 (0.5004) Prec@1 89.844 (89.844) 855 | Epoch: [118][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.4283 (0.5170) Prec@1 89.844 (88.056) 856 | Epoch: [118][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 0.5444 (0.5369) Prec@1 87.500 (87.432) 857 | Total time for epoch [118] : 14.080 858 | Test: [0/79] Time 0.150 (0.150) Loss 0.6987 (0.6987) Prec@1 78.125 (78.125) 859 | * Prec@1 74.700 860 | current lr 1.76490e-02 861 | Epoch: [119][0/391] Time 0.167 (0.167) Data 0.113 (0.113) Loss 0.4888 (0.4888) Prec@1 89.062 (89.062) 862 | Epoch: [119][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.5452 (0.5160) Prec@1 92.188 (88.169) 863 | Epoch: [119][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.7207 (0.5291) Prec@1 78.750 (87.624) 864 | Total time for epoch [119] : 13.836 865 | Test: [0/79] Time 0.105 (0.105) Loss 0.7411 (0.7411) Prec@1 76.562 (76.562) 866 | * Prec@1 75.920 867 | current lr 1.72746e-02 868 | Epoch: [120][0/391] Time 0.162 (0.162) Data 0.115 (0.115) Loss 0.4602 (0.4602) Prec@1 90.625 (90.625) 869 | 
Epoch: [120][200/391] Time 0.044 (0.037) Data 0.000 (0.001) Loss 0.5595 (0.4998) Prec@1 82.031 (88.806) 870 | Epoch: [120][390/391] Time 0.028 (0.037) Data 0.000 (0.001) Loss 0.4613 (0.5159) Prec@1 90.000 (88.140) 871 | Total time for epoch [120] : 14.273 872 | Test: [0/79] Time 0.100 (0.100) Loss 0.8041 (0.8041) Prec@1 72.656 (72.656) 873 | * Prec@1 75.570 874 | current lr 1.69021e-02 875 | Epoch: [121][0/391] Time 0.163 (0.163) Data 0.121 (0.121) Loss 0.4345 (0.4345) Prec@1 91.406 (91.406) 876 | Epoch: [121][200/391] Time 0.032 (0.036) Data 0.001 (0.001) Loss 0.5606 (0.4929) Prec@1 87.500 (88.923) 877 | Epoch: [121][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 0.5510 (0.5155) Prec@1 87.500 (88.040) 878 | Total time for epoch [121] : 14.072 879 | Test: [0/79] Time 0.115 (0.115) Loss 0.7023 (0.7023) Prec@1 75.781 (75.781) 880 | * Prec@1 75.590 881 | current lr 1.65316e-02 882 | Epoch: [122][0/391] Time 0.164 (0.164) Data 0.117 (0.117) Loss 0.4842 (0.4842) Prec@1 89.062 (89.062) 883 | Epoch: [122][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.5054 (0.4882) Prec@1 86.719 (89.051) 884 | Epoch: [122][390/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.4732 (0.5064) Prec@1 88.750 (88.526) 885 | Total time for epoch [122] : 13.965 886 | Test: [0/79] Time 0.116 (0.116) Loss 0.7853 (0.7853) Prec@1 75.781 (75.781) 887 | * Prec@1 75.620 888 | current lr 1.61631e-02 889 | Epoch: [123][0/391] Time 0.168 (0.168) Data 0.123 (0.123) Loss 0.4515 (0.4515) Prec@1 89.844 (89.844) 890 | Epoch: [123][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.4822 (0.4831) Prec@1 91.406 (89.331) 891 | Epoch: [123][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.7319 (0.4989) Prec@1 83.750 (88.772) 892 | Total time for epoch [123] : 13.774 893 | Test: [0/79] Time 0.099 (0.099) Loss 0.7097 (0.7097) Prec@1 78.125 (78.125) 894 | * Prec@1 76.250 895 | current lr 1.57969e-02 896 | Epoch: [124][0/391] Time 0.155 (0.155) Data 0.115 (0.115) Loss 0.3834 (0.3834) Prec@1 92.969 (92.969) 897 | Epoch: [124][200/391] Time 0.034 (0.035) Data 0.000 (0.001) Loss 0.4485 (0.4754) Prec@1 89.844 (89.513) 898 | Epoch: [124][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.6910 (0.4861) Prec@1 80.000 (89.098) 899 | Total time for epoch [124] : 13.740 900 | Test: [0/79] Time 0.123 (0.123) Loss 0.7374 (0.7374) Prec@1 78.125 (78.125) 901 | * Prec@1 75.680 902 | current lr 1.54329e-02 903 | Epoch: [125][0/391] Time 0.169 (0.169) Data 0.122 (0.122) Loss 0.4205 (0.4205) Prec@1 92.969 (92.969) 904 | Epoch: [125][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.4904 (0.4628) Prec@1 90.625 (89.984) 905 | Epoch: [125][390/391] Time 0.027 (0.035) Data 0.000 (0.001) Loss 0.4820 (0.4797) Prec@1 87.500 (89.298) 906 | Total time for epoch [125] : 13.853 907 | Test: [0/79] Time 0.122 (0.122) Loss 0.7572 (0.7572) Prec@1 74.219 (74.219) 908 | * Prec@1 76.170 909 | current lr 1.50713e-02 910 | Epoch: [126][0/391] Time 0.174 (0.174) Data 0.130 (0.130) Loss 0.4195 (0.4195) Prec@1 92.188 (92.188) 911 | Epoch: [126][200/391] Time 0.038 (0.036) Data 0.000 (0.001) Loss 0.4912 (0.4539) Prec@1 89.062 (90.085) 912 | Epoch: [126][390/391] Time 0.029 (0.035) Data 0.000 (0.001) Loss 0.6683 (0.4679) Prec@1 80.000 (89.700) 913 | Total time for epoch [126] : 13.814 914 | Test: [0/79] Time 0.103 (0.103) Loss 0.7793 (0.7793) Prec@1 72.656 (72.656) 915 | * Prec@1 75.850 916 | current lr 1.47121e-02 917 | Epoch: [127][0/391] Time 0.184 (0.184) Data 0.133 (0.133) Loss 0.4487 (0.4487) Prec@1 89.844 (89.844) 918 | Epoch: [127][200/391] Time 0.033 
(0.036) Data 0.000 (0.001) Loss 0.4015 (0.4387) Prec@1 89.844 (90.571) 919 | Epoch: [127][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 0.6407 (0.4591) Prec@1 87.500 (89.918) 920 | Total time for epoch [127] : 13.950 921 | Test: [0/79] Time 0.130 (0.130) Loss 0.7350 (0.7350) Prec@1 77.344 (77.344) 922 | * Prec@1 76.540 923 | current lr 1.43555e-02 924 | Epoch: [128][0/391] Time 0.175 (0.175) Data 0.118 (0.118) Loss 0.5255 (0.5255) Prec@1 86.719 (86.719) 925 | Epoch: [128][200/391] Time 0.038 (0.036) Data 0.000 (0.001) Loss 0.4965 (0.4458) Prec@1 87.500 (90.501) 926 | Epoch: [128][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.4674 (0.4554) Prec@1 91.250 (90.094) 927 | Total time for epoch [128] : 13.703 928 | Test: [0/79] Time 0.116 (0.116) Loss 0.7790 (0.7790) Prec@1 77.344 (77.344) 929 | * Prec@1 75.760 930 | current lr 1.40015e-02 931 | Epoch: [129][0/391] Time 0.209 (0.209) Data 0.154 (0.154) Loss 0.3588 (0.3588) Prec@1 89.844 (89.844) 932 | Epoch: [129][200/391] Time 0.041 (0.036) Data 0.000 (0.001) Loss 0.4358 (0.4321) Prec@1 89.844 (90.882) 933 | Epoch: [129][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 0.4050 (0.4491) Prec@1 91.250 (90.274) 934 | Total time for epoch [129] : 13.969 935 | Test: [0/79] Time 0.100 (0.100) Loss 0.7513 (0.7513) Prec@1 78.906 (78.906) 936 | * Prec@1 77.490 937 | current lr 1.36502e-02 938 | Epoch: [130][0/391] Time 0.161 (0.161) Data 0.116 (0.116) Loss 0.3765 (0.3765) Prec@1 91.406 (91.406) 939 | Epoch: [130][200/391] Time 0.035 (0.035) Data 0.000 (0.001) Loss 0.3894 (0.4215) Prec@1 93.750 (91.239) 940 | Epoch: [130][390/391] Time 0.027 (0.036) Data 0.000 (0.001) Loss 0.4448 (0.4381) Prec@1 90.000 (90.598) 941 | Total time for epoch [130] : 13.942 942 | Test: [0/79] Time 0.126 (0.126) Loss 0.7868 (0.7868) Prec@1 75.781 (75.781) 943 | * Prec@1 76.710 944 | current lr 1.33018e-02 945 | Epoch: [131][0/391] Time 0.174 (0.174) Data 0.114 (0.114) Loss 0.3400 (0.3400) Prec@1 92.969 (92.969) 946 | Epoch: [131][200/391] Time 0.032 (0.036) Data 0.000 (0.001) Loss 0.4247 (0.4058) Prec@1 92.969 (91.768) 947 | Epoch: [131][390/391] Time 0.028 (0.036) Data 0.000 (0.001) Loss 0.4633 (0.4213) Prec@1 88.750 (91.170) 948 | Total time for epoch [131] : 13.911 949 | Test: [0/79] Time 0.110 (0.110) Loss 0.7466 (0.7466) Prec@1 75.000 (75.000) 950 | * Prec@1 76.990 951 | current lr 1.29562e-02 952 | Epoch: [132][0/391] Time 0.191 (0.191) Data 0.141 (0.141) Loss 0.3828 (0.3828) Prec@1 92.188 (92.188) 953 | Epoch: [132][200/391] Time 0.035 (0.036) Data 0.000 (0.001) Loss 0.4541 (0.4096) Prec@1 89.844 (91.476) 954 | Epoch: [132][390/391] Time 0.030 (0.035) Data 0.000 (0.001) Loss 0.4677 (0.4195) Prec@1 90.000 (91.084) 955 | Total time for epoch [132] : 13.876 956 | Test: [0/79] Time 0.094 (0.094) Loss 0.7756 (0.7756) Prec@1 77.344 (77.344) 957 | * Prec@1 76.850 958 | current lr 1.26135e-02 959 | Epoch: [133][0/391] Time 0.159 (0.159) Data 0.112 (0.112) Loss 0.3659 (0.3659) Prec@1 92.188 (92.188) 960 | Epoch: [133][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.3823 (0.4038) Prec@1 92.188 (91.706) 961 | Epoch: [133][390/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 0.4149 (0.4127) Prec@1 91.250 (91.350) 962 | Total time for epoch [133] : 13.925 963 | Test: [0/79] Time 0.104 (0.104) Loss 0.6793 (0.6793) Prec@1 76.562 (76.562) 964 | * Prec@1 76.830 965 | current lr 1.22740e-02 966 | Epoch: [134][0/391] Time 0.176 (0.176) Data 0.125 (0.125) Loss 0.3454 (0.3454) Prec@1 94.531 (94.531) 967 | Epoch: [134][200/391] Time 0.029 (0.036) Data 0.000 (0.001) Loss 
0.3750 (0.3938) Prec@1 90.625 (91.935) 968 | Epoch: [134][390/391] Time 0.037 (0.035) Data 0.000 (0.001) Loss 0.6710 (0.4026) Prec@1 81.250 (91.754) 969 | Total time for epoch [134] : 13.797 970 | Test: [0/79] Time 0.112 (0.112) Loss 0.6066 (0.6066) Prec@1 81.250 (81.250) 971 | * Prec@1 77.300 972 | current lr 1.19375e-02 973 | Epoch: [135][0/391] Time 0.174 (0.174) Data 0.126 (0.126) Loss 0.3970 (0.3970) Prec@1 89.844 (89.844) 974 | Epoch: [135][200/391] Time 0.034 (0.036) Data 0.000 (0.001) Loss 0.3667 (0.3803) Prec@1 92.969 (92.498) 975 | Epoch: [135][390/391] Time 0.030 (0.036) Data 0.000 (0.001) Loss 0.6130 (0.3951) Prec@1 82.500 (91.946) 976 | Total time for epoch [135] : 13.984 977 | Test: [0/79] Time 0.111 (0.111) Loss 0.6877 (0.6877) Prec@1 82.031 (82.031) 978 | * Prec@1 77.000 979 | current lr 1.16043e-02 980 | Epoch: [136][0/391] Time 0.188 (0.188) Data 0.136 (0.136) Loss 0.3753 (0.3753) Prec@1 93.750 (93.750) 981 | Epoch: [136][200/391] Time 0.033 (0.035) Data 0.000 (0.001) Loss 0.3427 (0.3762) Prec@1 94.531 (92.522) 982 | Epoch: [136][390/391] Time 0.028 (0.035) Data 0.000 (0.001) Loss 0.5199 (0.3859) Prec@1 88.750 (92.142) 983 | Total time for epoch [136] : 13.722 984 | Test: [0/79] Time 0.110 (0.110) Loss 0.6875 (0.6875) Prec@1 76.562 (76.562) 985 | * Prec@1 77.380 986 | current lr 1.12744e-02 987 | Epoch: [137][0/391] Time 0.161 (0.161) Data 0.112 (0.112) Loss 0.3158 (0.3158) Prec@1 93.750 (93.750) 988 | Epoch: [137][200/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.4080 (0.3684) Prec@1 88.281 (92.736) 989 | Epoch: [137][390/391] Time 0.033 (0.036) Data 0.000 (0.001) Loss 0.5028 (0.3825) Prec@1 90.000 (92.322) 990 | Total time for epoch [137] : 14.158 991 | Test: [0/79] Time 0.128 (0.128) Loss 0.6859 (0.6859) Prec@1 80.469 (80.469) 992 | * Prec@1 77.810 993 | current lr 1.09479e-02 994 | Epoch: [138][0/391] Time 0.175 (0.175) Data 0.124 (0.124) Loss 0.4306 (0.4306) Prec@1 89.062 (89.062) 995 | Epoch: [138][200/391] Time 0.046 (0.038) Data 0.000 (0.001) Loss 0.4877 (0.3623) Prec@1 86.719 (92.778) 996 | Epoch: [138][390/391] Time 0.032 (0.038) Data 0.000 (0.001) Loss 0.3970 (0.3723) Prec@1 90.000 (92.542) 997 | Total time for epoch [138] : 14.876 998 | Test: [0/79] Time 0.118 (0.118) Loss 0.6829 (0.6829) Prec@1 79.688 (79.688) 999 | * Prec@1 77.780 1000 | current lr 1.06249e-02 1001 | Epoch: [139][0/391] Time 0.171 (0.171) Data 0.118 (0.118) Loss 0.4018 (0.4018) Prec@1 90.625 (90.625) 1002 | Epoch: [139][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.3682 (0.3477) Prec@1 93.750 (93.307) 1003 | Epoch: [139][390/391] Time 0.042 (0.038) Data 0.000 (0.001) Loss 0.4364 (0.3599) Prec@1 91.250 (93.000) 1004 | Total time for epoch [139] : 14.775 1005 | Test: [0/79] Time 0.124 (0.124) Loss 0.6918 (0.6918) Prec@1 75.781 (75.781) 1006 | * Prec@1 77.920 1007 | current lr 1.03054e-02 1008 | Epoch: [140][0/391] Time 0.170 (0.170) Data 0.114 (0.114) Loss 0.2377 (0.2377) Prec@1 99.219 (99.219) 1009 | Epoch: [140][200/391] Time 0.035 (0.038) Data 0.000 (0.001) Loss 0.2217 (0.3408) Prec@1 96.875 (93.525) 1010 | Epoch: [140][390/391] Time 0.038 (0.038) Data 0.000 (0.001) Loss 0.5046 (0.3580) Prec@1 91.250 (93.080) 1011 | Total time for epoch [140] : 14.970 1012 | Test: [0/79] Time 0.104 (0.104) Loss 0.7912 (0.7912) Prec@1 75.781 (75.781) 1013 | * Prec@1 78.260 1014 | current lr 9.98949e-03 1015 | Epoch: [141][0/391] Time 0.174 (0.174) Data 0.122 (0.122) Loss 0.3114 (0.3114) Prec@1 93.750 (93.750) 1016 | Epoch: [141][200/391] Time 0.040 (0.038) Data 0.000 (0.001) Loss 0.2975 (0.3385) 
Prec@1 94.531 (93.731) 1017 | Epoch: [141][390/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.2831 (0.3471) Prec@1 98.750 (93.468) 1018 | Total time for epoch [141] : 14.796 1019 | Test: [0/79] Time 0.111 (0.111) Loss 0.7100 (0.7100) Prec@1 77.344 (77.344) 1020 | * Prec@1 78.030 1021 | current lr 9.67732e-03 1022 | Epoch: [142][0/391] Time 0.244 (0.244) Data 0.195 (0.195) Loss 0.3631 (0.3631) Prec@1 92.188 (92.188) 1023 | Epoch: [142][200/391] Time 0.038 (0.038) Data 0.000 (0.001) Loss 0.2984 (0.3376) Prec@1 90.625 (93.731) 1024 | Epoch: [142][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.3489 (0.3447) Prec@1 92.500 (93.514) 1025 | Total time for epoch [142] : 14.964 1026 | Test: [0/79] Time 0.127 (0.127) Loss 0.6467 (0.6467) Prec@1 82.031 (82.031) 1027 | * Prec@1 78.700 1028 | current lr 9.36893e-03 1029 | Epoch: [143][0/391] Time 0.171 (0.171) Data 0.118 (0.118) Loss 0.3163 (0.3163) Prec@1 96.094 (96.094) 1030 | Epoch: [143][200/391] Time 0.039 (0.038) Data 0.000 (0.001) Loss 0.3558 (0.3245) Prec@1 95.312 (94.201) 1031 | Epoch: [143][390/391] Time 0.039 (0.039) Data 0.000 (0.001) Loss 0.3210 (0.3363) Prec@1 95.000 (93.782) 1032 | Total time for epoch [143] : 15.077 1033 | Test: [0/79] Time 0.106 (0.106) Loss 0.6173 (0.6173) Prec@1 82.031 (82.031) 1034 | * Prec@1 78.210 1035 | current lr 9.06440e-03 1036 | Epoch: [144][0/391] Time 0.172 (0.172) Data 0.128 (0.128) Loss 0.2283 (0.2283) Prec@1 95.312 (95.312) 1037 | Epoch: [144][200/391] Time 0.036 (0.039) Data 0.000 (0.001) Loss 0.3585 (0.3188) Prec@1 92.969 (94.154) 1038 | Epoch: [144][390/391] Time 0.035 (0.038) Data 0.000 (0.001) Loss 0.2924 (0.3274) Prec@1 95.000 (93.922) 1039 | Total time for epoch [144] : 14.987 1040 | Test: [0/79] Time 0.104 (0.104) Loss 0.6432 (0.6432) Prec@1 76.562 (76.562) 1041 | * Prec@1 78.080 1042 | current lr 8.76380e-03 1043 | Epoch: [145][0/391] Time 0.173 (0.173) Data 0.127 (0.127) Loss 0.3675 (0.3675) Prec@1 92.969 (92.969) 1044 | Epoch: [145][200/391] Time 0.035 (0.038) Data 0.000 (0.001) Loss 0.2542 (0.3084) Prec@1 95.312 (94.430) 1045 | Epoch: [145][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.3887 (0.3177) Prec@1 91.250 (94.226) 1046 | Total time for epoch [145] : 14.717 1047 | Test: [0/79] Time 0.108 (0.108) Loss 0.6350 (0.6350) Prec@1 80.469 (80.469) 1048 | * Prec@1 78.720 1049 | current lr 8.46720e-03 1050 | Epoch: [146][0/391] Time 0.182 (0.182) Data 0.125 (0.125) Loss 0.2960 (0.2960) Prec@1 94.531 (94.531) 1051 | Epoch: [146][200/391] Time 0.049 (0.038) Data 0.000 (0.001) Loss 0.2354 (0.3044) Prec@1 96.094 (94.753) 1052 | Epoch: [146][390/391] Time 0.029 (0.037) Data 0.000 (0.001) Loss 0.3975 (0.3143) Prec@1 91.250 (94.456) 1053 | Total time for epoch [146] : 14.655 1054 | Test: [0/79] Time 0.105 (0.105) Loss 0.6335 (0.6335) Prec@1 78.906 (78.906) 1055 | * Prec@1 78.310 1056 | current lr 8.17469e-03 1057 | Epoch: [147][0/391] Time 0.163 (0.163) Data 0.116 (0.116) Loss 0.2659 (0.2659) Prec@1 96.094 (96.094) 1058 | Epoch: [147][200/391] Time 0.035 (0.038) Data 0.000 (0.001) Loss 0.3241 (0.3034) Prec@1 93.750 (94.729) 1059 | Epoch: [147][390/391] Time 0.032 (0.038) Data 0.000 (0.001) Loss 0.4819 (0.3089) Prec@1 88.750 (94.572) 1060 | Total time for epoch [147] : 14.809 1061 | Test: [0/79] Time 0.143 (0.143) Loss 0.6520 (0.6520) Prec@1 78.906 (78.906) 1062 | * Prec@1 78.840 1063 | current lr 7.88632e-03 1064 | Epoch: [148][0/391] Time 0.229 (0.229) Data 0.182 (0.182) Loss 0.2571 (0.2571) Prec@1 96.094 (96.094) 1065 | Epoch: [148][200/391] Time 0.042 (0.038) Data 0.000 (0.001) Loss 
0.2602 (0.2969) Prec@1 96.094 (94.986) 1066 | Epoch: [148][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.3217 (0.3024) Prec@1 93.750 (94.736) 1067 | Total time for epoch [148] : 14.897 1068 | Test: [0/79] Time 0.125 (0.125) Loss 0.6474 (0.6474) Prec@1 78.125 (78.125) 1069 | * Prec@1 79.190 1070 | current lr 7.60218e-03 1071 | Epoch: [149][0/391] Time 0.183 (0.183) Data 0.132 (0.132) Loss 0.2378 (0.2378) Prec@1 96.875 (96.875) 1072 | Epoch: [149][200/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.3483 (0.2863) Prec@1 92.969 (95.200) 1073 | Epoch: [149][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.3668 (0.2960) Prec@1 93.750 (94.936) 1074 | Total time for epoch [149] : 14.765 1075 | Test: [0/79] Time 0.117 (0.117) Loss 0.6116 (0.6116) Prec@1 82.031 (82.031) 1076 | * Prec@1 78.880 1077 | current lr 7.32233e-03 1078 | Epoch: [150][0/391] Time 0.169 (0.169) Data 0.122 (0.122) Loss 0.2416 (0.2416) Prec@1 96.094 (96.094) 1079 | Epoch: [150][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.2649 (0.2851) Prec@1 96.094 (95.211) 1080 | Epoch: [150][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.2899 (0.2877) Prec@1 95.000 (95.160) 1081 | Total time for epoch [150] : 14.779 1082 | Test: [0/79] Time 0.110 (0.110) Loss 0.6698 (0.6698) Prec@1 78.906 (78.906) 1083 | * Prec@1 78.930 1084 | current lr 7.04684e-03 1085 | Epoch: [151][0/391] Time 0.168 (0.168) Data 0.123 (0.123) Loss 0.3019 (0.3019) Prec@1 93.750 (93.750) 1086 | Epoch: [151][200/391] Time 0.038 (0.037) Data 0.000 (0.001) Loss 0.2680 (0.2785) Prec@1 96.875 (95.495) 1087 | Epoch: [151][390/391] Time 0.031 (0.037) Data 0.000 (0.001) Loss 0.3541 (0.2820) Prec@1 91.250 (95.380) 1088 | Total time for epoch [151] : 14.593 1089 | Test: [0/79] Time 0.111 (0.111) Loss 0.6308 (0.6308) Prec@1 83.594 (83.594) 1090 | * Prec@1 79.460 1091 | current lr 6.77578e-03 1092 | Epoch: [152][0/391] Time 0.163 (0.163) Data 0.116 (0.116) Loss 0.3064 (0.3064) Prec@1 94.531 (94.531) 1093 | Epoch: [152][200/391] Time 0.038 (0.038) Data 0.000 (0.001) Loss 0.3032 (0.2695) Prec@1 93.750 (95.693) 1094 | Epoch: [152][390/391] Time 0.029 (0.038) Data 0.000 (0.001) Loss 0.2827 (0.2767) Prec@1 96.250 (95.396) 1095 | Total time for epoch [152] : 14.719 1096 | Test: [0/79] Time 0.109 (0.109) Loss 0.6270 (0.6270) Prec@1 78.906 (78.906) 1097 | * Prec@1 78.930 1098 | current lr 6.50922e-03 1099 | Epoch: [153][0/391] Time 0.205 (0.205) Data 0.160 (0.160) Loss 0.2095 (0.2095) Prec@1 96.094 (96.094) 1100 | Epoch: [153][200/391] Time 0.034 (0.039) Data 0.000 (0.001) Loss 0.3017 (0.2610) Prec@1 96.094 (95.993) 1101 | Epoch: [153][390/391] Time 0.031 (0.038) Data 0.000 (0.001) Loss 0.3696 (0.2705) Prec@1 91.250 (95.674) 1102 | Total time for epoch [153] : 14.837 1103 | Test: [0/79] Time 0.121 (0.121) Loss 0.5886 (0.5886) Prec@1 79.688 (79.688) 1104 | * Prec@1 79.540 1105 | current lr 6.24722e-03 1106 | Epoch: [154][0/391] Time 0.190 (0.190) Data 0.134 (0.134) Loss 0.2285 (0.2285) Prec@1 96.094 (96.094) 1107 | Epoch: [154][200/391] Time 0.035 (0.039) Data 0.000 (0.001) Loss 0.2221 (0.2610) Prec@1 96.875 (95.915) 1108 | Epoch: [154][390/391] Time 0.031 (0.039) Data 0.000 (0.001) Loss 0.2267 (0.2645) Prec@1 97.500 (95.770) 1109 | Total time for epoch [154] : 15.078 1110 | Test: [0/79] Time 0.105 (0.105) Loss 0.6393 (0.6393) Prec@1 81.250 (81.250) 1111 | * Prec@1 79.180 1112 | current lr 5.98985e-03 1113 | Epoch: [155][0/391] Time 0.177 (0.177) Data 0.127 (0.127) Loss 0.3184 (0.3184) Prec@1 92.188 (92.188) 1114 | Epoch: [155][200/391] Time 0.036 (0.039) Data 
0.000 (0.001) Loss 0.2916 (0.2567) Prec@1 96.875 (95.907) 1115 | Epoch: [155][390/391] Time 0.032 (0.038) Data 0.000 (0.001) Loss 0.2980 (0.2600) Prec@1 97.500 (95.918) 1116 | Total time for epoch [155] : 14.836 1117 | Test: [0/79] Time 0.112 (0.112) Loss 0.6666 (0.6666) Prec@1 79.688 (79.688) 1118 | * Prec@1 79.550 1119 | current lr 5.73717e-03 1120 | Epoch: [156][0/391] Time 0.187 (0.187) Data 0.130 (0.130) Loss 0.3066 (0.3066) Prec@1 93.750 (93.750) 1121 | Epoch: [156][200/391] Time 0.037 (0.039) Data 0.000 (0.001) Loss 0.2605 (0.2527) Prec@1 96.094 (96.125) 1122 | Epoch: [156][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.3066 (0.2564) Prec@1 91.250 (96.108) 1123 | Total time for epoch [156] : 14.901 1124 | Test: [0/79] Time 0.111 (0.111) Loss 0.5792 (0.5792) Prec@1 82.031 (82.031) 1125 | * Prec@1 79.710 1126 | current lr 5.48924e-03 1127 | Epoch: [157][0/391] Time 0.170 (0.170) Data 0.128 (0.128) Loss 0.2048 (0.2048) Prec@1 97.656 (97.656) 1128 | Epoch: [157][200/391] Time 0.039 (0.038) Data 0.000 (0.001) Loss 0.2679 (0.2434) Prec@1 92.969 (96.245) 1129 | Epoch: [157][390/391] Time 0.031 (0.037) Data 0.000 (0.001) Loss 0.2313 (0.2481) Prec@1 96.250 (96.142) 1130 | Total time for epoch [157] : 14.624 1131 | Test: [0/79] Time 0.117 (0.117) Loss 0.6472 (0.6472) Prec@1 77.344 (77.344) 1132 | * Prec@1 79.170 1133 | current lr 5.24612e-03 1134 | Epoch: [158][0/391] Time 0.236 (0.236) Data 0.191 (0.191) Loss 0.3249 (0.3249) Prec@1 92.969 (92.969) 1135 | Epoch: [158][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.2262 (0.2391) Prec@1 95.312 (96.459) 1136 | Epoch: [158][390/391] Time 0.046 (0.038) Data 0.000 (0.001) Loss 0.3351 (0.2486) Prec@1 95.000 (96.178) 1137 | Total time for epoch [158] : 14.733 1138 | Test: [0/79] Time 0.129 (0.129) Loss 0.6582 (0.6582) Prec@1 79.688 (79.688) 1139 | * Prec@1 79.410 1140 | current lr 5.00788e-03 1141 | Epoch: [159][0/391] Time 0.190 (0.190) Data 0.138 (0.138) Loss 0.2337 (0.2337) Prec@1 96.094 (96.094) 1142 | Epoch: [159][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.2667 (0.2383) Prec@1 96.875 (96.525) 1143 | Epoch: [159][390/391] Time 0.032 (0.038) Data 0.000 (0.001) Loss 0.3199 (0.2418) Prec@1 95.000 (96.390) 1144 | Total time for epoch [159] : 14.910 1145 | Test: [0/79] Time 0.182 (0.182) Loss 0.6246 (0.6246) Prec@1 80.469 (80.469) 1146 | * Prec@1 79.750 1147 | current lr 4.77458e-03 1148 | Epoch: [160][0/391] Time 0.186 (0.186) Data 0.132 (0.132) Loss 0.1674 (0.1674) Prec@1 98.438 (98.438) 1149 | Epoch: [160][200/391] Time 0.040 (0.038) Data 0.000 (0.001) Loss 0.2016 (0.2294) Prec@1 97.656 (96.712) 1150 | Epoch: [160][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.2109 (0.2336) Prec@1 97.500 (96.580) 1151 | Total time for epoch [160] : 14.717 1152 | Test: [0/79] Time 0.105 (0.105) Loss 0.6409 (0.6409) Prec@1 79.688 (79.688) 1153 | * Prec@1 79.690 1154 | current lr 4.54626e-03 1155 | Epoch: [161][0/391] Time 0.168 (0.168) Data 0.117 (0.117) Loss 0.2221 (0.2221) Prec@1 96.094 (96.094) 1156 | Epoch: [161][200/391] Time 0.035 (0.039) Data 0.000 (0.001) Loss 0.2271 (0.2286) Prec@1 97.656 (96.634) 1157 | Epoch: [161][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.2436 (0.2297) Prec@1 97.500 (96.662) 1158 | Total time for epoch [161] : 14.926 1159 | Test: [0/79] Time 0.114 (0.114) Loss 0.6096 (0.6096) Prec@1 79.688 (79.688) 1160 | * Prec@1 79.700 1161 | current lr 4.32299e-03 1162 | Epoch: [162][0/391] Time 0.190 (0.190) Data 0.141 (0.141) Loss 0.2544 (0.2544) Prec@1 96.094 (96.094) 1163 | Epoch: [162][200/391] Time 
0.036 (0.038) Data 0.000 (0.001) Loss 0.2306 (0.2246) Prec@1 96.875 (96.774) 1164 | Epoch: [162][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.2056 (0.2268) Prec@1 98.750 (96.752) 1165 | Total time for epoch [162] : 14.681 1166 | Test: [0/79] Time 0.103 (0.103) Loss 0.6010 (0.6010) Prec@1 81.250 (81.250) 1167 | * Prec@1 80.190 1168 | current lr 4.10482e-03 1169 | Epoch: [163][0/391] Time 0.172 (0.172) Data 0.126 (0.126) Loss 0.2087 (0.2087) Prec@1 96.875 (96.875) 1170 | Epoch: [163][200/391] Time 0.036 (0.038) Data 0.000 (0.001) Loss 0.2416 (0.2147) Prec@1 96.875 (97.170) 1171 | Epoch: [163][390/391] Time 0.032 (0.038) Data 0.000 (0.001) Loss 0.3018 (0.2214) Prec@1 96.250 (96.944) 1172 | Total time for epoch [163] : 14.807 1173 | Test: [0/79] Time 0.115 (0.115) Loss 0.5941 (0.5941) Prec@1 79.688 (79.688) 1174 | * Prec@1 79.960 1175 | current lr 3.89180e-03 1176 | Epoch: [164][0/391] Time 0.170 (0.170) Data 0.120 (0.120) Loss 0.1422 (0.1422) Prec@1 98.438 (98.438) 1177 | Epoch: [164][200/391] Time 0.033 (0.039) Data 0.000 (0.001) Loss 0.2693 (0.2177) Prec@1 96.094 (96.891) 1178 | Epoch: [164][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.2444 (0.2192) Prec@1 96.250 (96.942) 1179 | Total time for epoch [164] : 15.008 1180 | Test: [0/79] Time 0.120 (0.120) Loss 0.6044 (0.6044) Prec@1 81.250 (81.250) 1181 | * Prec@1 80.310 1182 | current lr 3.68400e-03 1183 | Epoch: [165][0/391] Time 0.173 (0.173) Data 0.124 (0.124) Loss 0.2501 (0.2501) Prec@1 94.531 (94.531) 1184 | Epoch: [165][200/391] Time 0.037 (0.038) Data 0.000 (0.001) Loss 0.2061 (0.2083) Prec@1 96.875 (97.132) 1185 | Epoch: [165][390/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.1554 (0.2128) Prec@1 98.750 (97.092) 1186 | Total time for epoch [165] : 14.750 1187 | Test: [0/79] Time 0.114 (0.114) Loss 0.6369 (0.6369) Prec@1 79.688 (79.688) 1188 | * Prec@1 79.900 1189 | current lr 3.48145e-03 1190 | Epoch: [166][0/391] Time 0.196 (0.196) Data 0.142 (0.142) Loss 0.1635 (0.1635) Prec@1 99.219 (99.219) 1191 | Epoch: [166][200/391] Time 0.038 (0.039) Data 0.000 (0.001) Loss 0.2015 (0.2072) Prec@1 97.656 (97.112) 1192 | Epoch: [166][390/391] Time 0.042 (0.038) Data 0.000 (0.001) Loss 0.2638 (0.2096) Prec@1 95.000 (97.072) 1193 | Total time for epoch [166] : 14.873 1194 | Test: [0/79] Time 0.118 (0.118) Loss 0.6109 (0.6109) Prec@1 81.250 (81.250) 1195 | * Prec@1 80.310 1196 | current lr 3.28421e-03 1197 | Epoch: [167][0/391] Time 0.174 (0.174) Data 0.125 (0.125) Loss 0.2204 (0.2204) Prec@1 95.312 (95.312) 1198 | Epoch: [167][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.2381 (0.2114) Prec@1 97.656 (97.038) 1199 | Epoch: [167][390/391] Time 0.029 (0.038) Data 0.000 (0.001) Loss 0.2490 (0.2109) Prec@1 95.000 (97.066) 1200 | Total time for epoch [167] : 14.742 1201 | Test: [0/79] Time 0.109 (0.109) Loss 0.6112 (0.6112) Prec@1 79.688 (79.688) 1202 | * Prec@1 80.570 1203 | current lr 3.09233e-03 1204 | Epoch: [168][0/391] Time 0.181 (0.181) Data 0.127 (0.127) Loss 0.1759 (0.1759) Prec@1 100.000 (100.000) 1205 | Epoch: [168][200/391] Time 0.041 (0.039) Data 0.000 (0.001) Loss 0.1585 (0.2002) Prec@1 99.219 (97.400) 1206 | Epoch: [168][390/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.2106 (0.2022) Prec@1 96.250 (97.252) 1207 | Total time for epoch [168] : 15.051 1208 | Test: [0/79] Time 0.117 (0.117) Loss 0.6145 (0.6145) Prec@1 78.906 (78.906) 1209 | * Prec@1 80.620 1210 | current lr 2.90586e-03 1211 | Epoch: [169][0/391] Time 0.187 (0.187) Data 0.136 (0.136) Loss 0.2001 (0.2001) Prec@1 98.438 (98.438) 1212 | Epoch: 
[169][200/391] Time 0.036 (0.038) Data 0.000 (0.001) Loss 0.1851 (0.1937) Prec@1 96.875 (97.532) 1213 | Epoch: [169][390/391] Time 0.062 (0.038) Data 0.000 (0.001) Loss 0.1984 (0.1967) Prec@1 97.500 (97.410) 1214 | Total time for epoch [169] : 14.836 1215 | Test: [0/79] Time 0.124 (0.124) Loss 0.5869 (0.5869) Prec@1 81.250 (81.250) 1216 | * Prec@1 80.670 1217 | current lr 2.72484e-03 1218 | Epoch: [170][0/391] Time 0.186 (0.186) Data 0.131 (0.131) Loss 0.1544 (0.1544) Prec@1 97.656 (97.656) 1219 | Epoch: [170][200/391] Time 0.032 (0.038) Data 0.000 (0.001) Loss 0.1972 (0.1962) Prec@1 97.656 (97.423) 1220 | Epoch: [170][390/391] Time 0.029 (0.038) Data 0.000 (0.001) Loss 0.2643 (0.1964) Prec@1 97.500 (97.354) 1221 | Total time for epoch [170] : 14.755 1222 | Test: [0/79] Time 0.119 (0.119) Loss 0.5817 (0.5817) Prec@1 82.812 (82.812) 1223 | * Prec@1 80.730 1224 | current lr 2.54931e-03 1225 | Epoch: [171][0/391] Time 0.172 (0.172) Data 0.126 (0.126) Loss 0.1705 (0.1705) Prec@1 96.875 (96.875) 1226 | Epoch: [171][200/391] Time 0.034 (0.039) Data 0.000 (0.001) Loss 0.1969 (0.1923) Prec@1 96.875 (97.509) 1227 | Epoch: [171][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.1923 (0.1958) Prec@1 97.500 (97.392) 1228 | Total time for epoch [171] : 15.028 1229 | Test: [0/79] Time 0.113 (0.113) Loss 0.5893 (0.5893) Prec@1 80.469 (80.469) 1230 | * Prec@1 81.100 1231 | current lr 2.37932e-03 1232 | Epoch: [172][0/391] Time 0.185 (0.185) Data 0.129 (0.129) Loss 0.2078 (0.2078) Prec@1 96.875 (96.875) 1233 | Epoch: [172][200/391] Time 0.036 (0.039) Data 0.000 (0.001) Loss 0.1959 (0.1889) Prec@1 98.438 (97.470) 1234 | Epoch: [172][390/391] Time 0.030 (0.039) Data 0.000 (0.001) Loss 0.1983 (0.1930) Prec@1 97.500 (97.428) 1235 | Total time for epoch [172] : 15.091 1236 | Test: [0/79] Time 0.102 (0.102) Loss 0.5962 (0.5962) Prec@1 81.250 (81.250) 1237 | * Prec@1 80.820 1238 | current lr 2.21492e-03 1239 | Epoch: [173][0/391] Time 0.199 (0.199) Data 0.151 (0.151) Loss 0.2046 (0.2046) Prec@1 97.656 (97.656) 1240 | Epoch: [173][200/391] Time 0.036 (0.038) Data 0.000 (0.001) Loss 0.2025 (0.1926) Prec@1 96.875 (97.477) 1241 | Epoch: [173][390/391] Time 0.029 (0.038) Data 0.000 (0.001) Loss 0.2432 (0.1940) Prec@1 95.000 (97.478) 1242 | Total time for epoch [173] : 14.890 1243 | Test: [0/79] Time 0.105 (0.105) Loss 0.5830 (0.5830) Prec@1 82.031 (82.031) 1244 | * Prec@1 80.730 1245 | current lr 2.05613e-03 1246 | Epoch: [174][0/391] Time 0.170 (0.170) Data 0.122 (0.122) Loss 0.1223 (0.1223) Prec@1 100.000 (100.000) 1247 | Epoch: [174][200/391] Time 0.035 (0.038) Data 0.000 (0.001) Loss 0.2026 (0.1837) Prec@1 97.656 (97.672) 1248 | Epoch: [174][390/391] Time 0.029 (0.038) Data 0.000 (0.001) Loss 0.2120 (0.1859) Prec@1 96.250 (97.608) 1249 | Total time for epoch [174] : 14.810 1250 | Test: [0/79] Time 0.116 (0.116) Loss 0.5876 (0.5876) Prec@1 78.906 (78.906) 1251 | * Prec@1 80.790 1252 | current lr 1.90301e-03 1253 | Epoch: [175][0/391] Time 0.221 (0.221) Data 0.170 (0.170) Loss 0.1131 (0.1131) Prec@1 100.000 (100.000) 1254 | Epoch: [175][200/391] Time 0.035 (0.039) Data 0.000 (0.001) Loss 0.1661 (0.1850) Prec@1 97.656 (97.711) 1255 | Epoch: [175][390/391] Time 0.029 (0.038) Data 0.000 (0.001) Loss 0.2525 (0.1840) Prec@1 97.500 (97.710) 1256 | Total time for epoch [175] : 14.989 1257 | Test: [0/79] Time 0.108 (0.108) Loss 0.5740 (0.5740) Prec@1 82.031 (82.031) 1258 | * Prec@1 80.960 1259 | current lr 1.75559e-03 1260 | Epoch: [176][0/391] Time 0.168 (0.168) Data 0.123 (0.123) Loss 0.1598 (0.1598) Prec@1 99.219 
(99.219) 1261 | Epoch: [176][200/391] Time 0.035 (0.039) Data 0.000 (0.001) Loss 0.1993 (0.1812) Prec@1 95.312 (97.730) 1262 | Epoch: [176][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.2460 (0.1851) Prec@1 96.250 (97.568) 1263 | Total time for epoch [176] : 14.790 1264 | Test: [0/79] Time 0.174 (0.174) Loss 0.5707 (0.5707) Prec@1 82.031 (82.031) 1265 | * Prec@1 80.910 1266 | current lr 1.61390e-03 1267 | Epoch: [177][0/391] Time 0.193 (0.193) Data 0.144 (0.144) Loss 0.2060 (0.2060) Prec@1 96.875 (96.875) 1268 | Epoch: [177][200/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.2272 (0.1795) Prec@1 95.312 (97.796) 1269 | Epoch: [177][390/391] Time 0.032 (0.037) Data 0.000 (0.001) Loss 0.2492 (0.1828) Prec@1 96.250 (97.632) 1270 | Total time for epoch [177] : 14.626 1271 | Test: [0/79] Time 0.124 (0.124) Loss 0.5910 (0.5910) Prec@1 80.469 (80.469) 1272 | * Prec@1 81.040 1273 | current lr 1.47798e-03 1274 | Epoch: [178][0/391] Time 0.201 (0.201) Data 0.149 (0.149) Loss 0.1514 (0.1514) Prec@1 99.219 (99.219) 1275 | Epoch: [178][200/391] Time 0.039 (0.039) Data 0.001 (0.001) Loss 0.1349 (0.1763) Prec@1 98.438 (97.800) 1276 | Epoch: [178][390/391] Time 0.045 (0.038) Data 0.000 (0.001) Loss 0.1482 (0.1769) Prec@1 97.500 (97.800) 1277 | Total time for epoch [178] : 14.793 1278 | Test: [0/79] Time 0.111 (0.111) Loss 0.6010 (0.6010) Prec@1 81.250 (81.250) 1279 | * Prec@1 80.880 1280 | current lr 1.34787e-03 1281 | Epoch: [179][0/391] Time 0.175 (0.175) Data 0.129 (0.129) Loss 0.1469 (0.1469) Prec@1 100.000 (100.000) 1282 | Epoch: [179][200/391] Time 0.035 (0.039) Data 0.000 (0.001) Loss 0.1678 (0.1753) Prec@1 98.438 (97.901) 1283 | Epoch: [179][390/391] Time 0.035 (0.038) Data 0.000 (0.001) Loss 0.3344 (0.1748) Prec@1 92.500 (97.842) 1284 | Total time for epoch [179] : 14.826 1285 | Test: [0/79] Time 0.112 (0.112) Loss 0.5980 (0.5980) Prec@1 80.469 (80.469) 1286 | * Prec@1 81.090 1287 | current lr 1.22359e-03 1288 | Epoch: [180][0/391] Time 0.163 (0.163) Data 0.119 (0.119) Loss 0.1785 (0.1785) Prec@1 98.438 (98.438) 1289 | Epoch: [180][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.1005 (0.1766) Prec@1 98.438 (97.816) 1290 | Epoch: [180][390/391] Time 0.042 (0.038) Data 0.000 (0.001) Loss 0.1598 (0.1760) Prec@1 100.000 (97.804) 1291 | Total time for epoch [180] : 14.746 1292 | Test: [0/79] Time 0.137 (0.137) Loss 0.6271 (0.6271) Prec@1 79.688 (79.688) 1293 | * Prec@1 81.200 1294 | current lr 1.10517e-03 1295 | Epoch: [181][0/391] Time 0.168 (0.168) Data 0.125 (0.125) Loss 0.1414 (0.1414) Prec@1 98.438 (98.438) 1296 | Epoch: [181][200/391] Time 0.039 (0.038) Data 0.000 (0.001) Loss 0.1825 (0.1694) Prec@1 97.656 (97.948) 1297 | Epoch: [181][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.1902 (0.1718) Prec@1 98.750 (97.906) 1298 | Total time for epoch [181] : 14.806 1299 | Test: [0/79] Time 0.114 (0.114) Loss 0.6055 (0.6055) Prec@1 81.250 (81.250) 1300 | * Prec@1 81.140 1301 | current lr 9.92658e-04 1302 | Epoch: [182][0/391] Time 0.155 (0.155) Data 0.111 (0.111) Loss 0.0858 (0.0858) Prec@1 99.219 (99.219) 1303 | Epoch: [182][200/391] Time 0.038 (0.038) Data 0.000 (0.001) Loss 0.1196 (0.1699) Prec@1 100.000 (97.909) 1304 | Epoch: [182][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.2107 (0.1715) Prec@1 97.500 (97.846) 1305 | Total time for epoch [182] : 14.793 1306 | Test: [0/79] Time 0.109 (0.109) Loss 0.5832 (0.5832) Prec@1 81.250 (81.250) 1307 | * Prec@1 81.220 1308 | current lr 8.86065e-04 1309 | Epoch: [183][0/391] Time 0.205 (0.205) Data 0.154 (0.154) Loss 0.1308 
(0.1308) Prec@1 97.656 (97.656) 1310 | Epoch: [183][200/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.1734 (0.1704) Prec@1 96.094 (97.983) 1311 | Epoch: [183][390/391] Time 0.031 (0.038) Data 0.000 (0.001) Loss 0.1905 (0.1714) Prec@1 96.250 (97.884) 1312 | Total time for epoch [183] : 14.817 1313 | Test: [0/79] Time 0.105 (0.105) Loss 0.6069 (0.6069) Prec@1 80.469 (80.469) 1314 | * Prec@1 81.110 1315 | current lr 7.85421e-04 1316 | Epoch: [184][0/391] Time 0.171 (0.171) Data 0.124 (0.124) Loss 0.1986 (0.1986) Prec@1 96.875 (96.875) 1317 | Epoch: [184][200/391] Time 0.037 (0.038) Data 0.000 (0.001) Loss 0.1681 (0.1676) Prec@1 98.438 (97.967) 1318 | Epoch: [184][390/391] Time 0.031 (0.039) Data 0.000 (0.001) Loss 0.1602 (0.1688) Prec@1 98.750 (97.956) 1319 | Total time for epoch [184] : 15.074 1320 | Test: [0/79] Time 0.103 (0.103) Loss 0.6065 (0.6065) Prec@1 79.688 (79.688) 1321 | * Prec@1 81.110 1322 | current lr 6.90752e-04 1323 | Epoch: [185][0/391] Time 0.171 (0.171) Data 0.118 (0.118) Loss 0.1575 (0.1575) Prec@1 99.219 (99.219) 1324 | Epoch: [185][200/391] Time 0.035 (0.040) Data 0.000 (0.001) Loss 0.1281 (0.1677) Prec@1 98.438 (97.971) 1325 | Epoch: [185][390/391] Time 0.047 (0.039) Data 0.000 (0.001) Loss 0.2148 (0.1673) Prec@1 96.250 (97.976) 1326 | Total time for epoch [185] : 15.262 1327 | Test: [0/79] Time 0.122 (0.122) Loss 0.5877 (0.5877) Prec@1 81.250 (81.250) 1328 | * Prec@1 81.280 1329 | current lr 6.02081e-04 1330 | Epoch: [186][0/391] Time 0.172 (0.172) Data 0.113 (0.113) Loss 0.2004 (0.2004) Prec@1 98.438 (98.438) 1331 | Epoch: [186][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.1698 (0.1685) Prec@1 99.219 (97.878) 1332 | Epoch: [186][390/391] Time 0.031 (0.039) Data 0.000 (0.001) Loss 0.2079 (0.1676) Prec@1 98.750 (97.904) 1333 | Total time for epoch [186] : 15.067 1334 | Test: [0/79] Time 0.105 (0.105) Loss 0.5811 (0.5811) Prec@1 80.469 (80.469) 1335 | * Prec@1 81.190 1336 | current lr 5.19430e-04 1337 | Epoch: [187][0/391] Time 0.179 (0.179) Data 0.124 (0.124) Loss 0.2247 (0.2247) Prec@1 95.312 (95.312) 1338 | Epoch: [187][200/391] Time 0.037 (0.039) Data 0.000 (0.001) Loss 0.1785 (0.1686) Prec@1 96.875 (97.893) 1339 | Epoch: [187][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.1480 (0.1671) Prec@1 100.000 (97.908) 1340 | Total time for epoch [187] : 14.883 1341 | Test: [0/79] Time 0.110 (0.110) Loss 0.5905 (0.5905) Prec@1 80.469 (80.469) 1342 | * Prec@1 81.250 1343 | current lr 4.42819e-04 1344 | Epoch: [188][0/391] Time 0.181 (0.181) Data 0.126 (0.126) Loss 0.1589 (0.1589) Prec@1 98.438 (98.438) 1345 | Epoch: [188][200/391] Time 0.037 (0.039) Data 0.000 (0.001) Loss 0.2299 (0.1683) Prec@1 96.875 (97.948) 1346 | Epoch: [188][390/391] Time 0.030 (0.038) Data 0.000 (0.001) Loss 0.1896 (0.1654) Prec@1 98.750 (97.978) 1347 | Total time for epoch [188] : 14.911 1348 | Test: [0/79] Time 0.103 (0.103) Loss 0.5922 (0.5922) Prec@1 80.469 (80.469) 1349 | * Prec@1 81.100 1350 | current lr 3.72267e-04 1351 | Epoch: [189][0/391] Time 0.195 (0.195) Data 0.139 (0.139) Loss 0.1541 (0.1541) Prec@1 96.875 (96.875) 1352 | Epoch: [189][200/391] Time 0.045 (0.038) Data 0.000 (0.001) Loss 0.1406 (0.1635) Prec@1 97.656 (98.041) 1353 | Epoch: [189][390/391] Time 0.032 (0.037) Data 0.000 (0.001) Loss 0.1374 (0.1639) Prec@1 98.750 (98.054) 1354 | Total time for epoch [189] : 14.634 1355 | Test: [0/79] Time 0.114 (0.114) Loss 0.5819 (0.5819) Prec@1 81.250 (81.250) 1356 | * Prec@1 81.260 1357 | current lr 3.07791e-04 1358 | Epoch: [190][0/391] Time 0.181 (0.181) Data 0.133 
(0.133) Loss 0.1283 (0.1283) Prec@1 99.219 (99.219) 1359 | Epoch: [190][200/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.1682 (0.1672) Prec@1 97.656 (98.029) 1360 | Epoch: [190][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.2071 (0.1668) Prec@1 98.750 (97.938) 1361 | Total time for epoch [190] : 14.764 1362 | Test: [0/79] Time 0.142 (0.142) Loss 0.5903 (0.5903) Prec@1 79.688 (79.688) 1363 | * Prec@1 81.350 1364 | current lr 2.49409e-04 1365 | Epoch: [191][0/391] Time 0.167 (0.167) Data 0.119 (0.119) Loss 0.1709 (0.1709) Prec@1 96.875 (96.875) 1366 | Epoch: [191][200/391] Time 0.033 (0.037) Data 0.000 (0.001) Loss 0.1411 (0.1622) Prec@1 99.219 (98.150) 1367 | Epoch: [191][390/391] Time 0.038 (0.037) Data 0.000 (0.001) Loss 0.1868 (0.1632) Prec@1 96.250 (98.138) 1368 | Total time for epoch [191] : 14.565 1369 | Test: [0/79] Time 0.117 (0.117) Loss 0.5936 (0.5936) Prec@1 78.906 (78.906) 1370 | * Prec@1 81.360 1371 | current lr 1.97132e-04 1372 | Epoch: [192][0/391] Time 0.176 (0.176) Data 0.127 (0.127) Loss 0.1609 (0.1609) Prec@1 98.438 (98.438) 1373 | Epoch: [192][200/391] Time 0.036 (0.038) Data 0.000 (0.001) Loss 0.1722 (0.1635) Prec@1 98.438 (97.928) 1374 | Epoch: [192][390/391] Time 0.031 (0.038) Data 0.000 (0.001) Loss 0.1735 (0.1639) Prec@1 97.500 (97.930) 1375 | Total time for epoch [192] : 14.871 1376 | Test: [0/79] Time 0.106 (0.106) Loss 0.5900 (0.5900) Prec@1 82.031 (82.031) 1377 | * Prec@1 81.410 1378 | current lr 1.50976e-04 1379 | Epoch: [193][0/391] Time 0.172 (0.172) Data 0.121 (0.121) Loss 0.1688 (0.1688) Prec@1 98.438 (98.438) 1380 | Epoch: [193][200/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.1947 (0.1656) Prec@1 96.094 (98.049) 1381 | Epoch: [193][390/391] Time 0.033 (0.038) Data 0.000 (0.001) Loss 0.1897 (0.1619) Prec@1 98.750 (98.118) 1382 | Total time for epoch [193] : 14.683 1383 | Test: [0/79] Time 0.123 (0.123) Loss 0.6005 (0.6005) Prec@1 81.250 (81.250) 1384 | * Prec@1 81.360 1385 | current lr 1.10951e-04 1386 | Epoch: [194][0/391] Time 0.199 (0.199) Data 0.155 (0.155) Loss 0.2398 (0.2398) Prec@1 96.094 (96.094) 1387 | Epoch: [194][200/391] Time 0.037 (0.038) Data 0.000 (0.001) Loss 0.1535 (0.1610) Prec@1 99.219 (98.057) 1388 | Epoch: [194][390/391] Time 0.030 (0.039) Data 0.000 (0.001) Loss 0.2247 (0.1615) Prec@1 96.250 (98.082) 1389 | Total time for epoch [194] : 15.058 1390 | Test: [0/79] Time 0.112 (0.112) Loss 0.5988 (0.5988) Prec@1 79.688 (79.688) 1391 | * Prec@1 81.170 1392 | current lr 7.70667e-05 1393 | Epoch: [195][0/391] Time 0.192 (0.192) Data 0.132 (0.132) Loss 0.1477 (0.1477) Prec@1 98.438 (98.438) 1394 | Epoch: [195][200/391] Time 0.040 (0.038) Data 0.000 (0.001) Loss 0.1193 (0.1558) Prec@1 98.438 (98.228) 1395 | Epoch: [195][390/391] Time 0.034 (0.038) Data 0.000 (0.001) Loss 0.1766 (0.1576) Prec@1 98.750 (98.216) 1396 | Total time for epoch [195] : 14.736 1397 | Test: [0/79] Time 0.117 (0.117) Loss 0.5917 (0.5917) Prec@1 81.250 (81.250) 1398 | * Prec@1 81.360 1399 | current lr 4.93318e-05 1400 | Epoch: [196][0/391] Time 0.196 (0.196) Data 0.143 (0.143) Loss 0.1420 (0.1420) Prec@1 99.219 (99.219) 1401 | Epoch: [196][200/391] Time 0.046 (0.039) Data 0.000 (0.001) Loss 0.1641 (0.1639) Prec@1 98.438 (98.018) 1402 | Epoch: [196][390/391] Time 0.037 (0.038) Data 0.000 (0.001) Loss 0.1968 (0.1613) Prec@1 98.750 (98.134) 1403 | Total time for epoch [196] : 14.892 1404 | Test: [0/79] Time 0.110 (0.110) Loss 0.5879 (0.5879) Prec@1 80.469 (80.469) 1405 | * Prec@1 81.430 1406 | current lr 2.77531e-05 1407 | Epoch: [197][0/391] Time 0.241 
(0.241) Data 0.192 (0.192) Loss 0.2146 (0.2146) Prec@1 98.438 (98.438) 1408 | Epoch: [197][200/391] Time 0.037 (0.038) Data 0.000 (0.001) Loss 0.1188 (0.1590) Prec@1 100.000 (98.259) 1409 | Epoch: [197][390/391] Time 0.029 (0.038) Data 0.000 (0.001) Loss 0.1097 (0.1607) Prec@1 97.500 (98.156) 1410 | Total time for epoch [197] : 14.822 1411 | Test: [0/79] Time 0.116 (0.116) Loss 0.6077 (0.6077) Prec@1 79.688 (79.688) 1412 | * Prec@1 81.280 1413 | current lr 1.23360e-05 1414 | Epoch: [198][0/391] Time 0.164 (0.164) Data 0.115 (0.115) Loss 0.1401 (0.1401) Prec@1 97.656 (97.656) 1415 | Epoch: [198][200/391] Time 0.040 (0.038) Data 0.000 (0.001) Loss 0.1876 (0.1594) Prec@1 98.438 (98.259) 1416 | Epoch: [198][390/391] Time 0.030 (0.037) Data 0.000 (0.001) Loss 0.1757 (0.1609) Prec@1 98.750 (98.172) 1417 | Total time for epoch [198] : 14.650 1418 | Test: [0/79] Time 0.116 (0.116) Loss 0.5906 (0.5906) Prec@1 80.469 (80.469) 1419 | * Prec@1 81.410 1420 | current lr 3.08419e-06 1421 | Epoch: [199][0/391] Time 0.170 (0.170) Data 0.122 (0.122) Loss 0.1614 (0.1614) Prec@1 98.438 (98.438) 1422 | Epoch: [199][200/391] Time 0.035 (0.039) Data 0.000 (0.001) Loss 0.2097 (0.1620) Prec@1 97.656 (98.127) 1423 | Epoch: [199][390/391] Time 0.031 (0.038) Data 0.000 (0.001) Loss 0.1808 (0.1620) Prec@1 98.750 (98.052) 1424 | Total time for epoch [199] : 14.753 1425 | Test: [0/79] Time 0.105 (0.105) Loss 0.5899 (0.5899) Prec@1 79.688 (79.688) 1426 | * Prec@1 81.320 1427 | final 1428 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | device=0 2 | seed=1 3 | datasets=CIFAR100 4 | model=resnet18 # resnet18 VGG16BN WideResNet28x10 5 | schedule=cosine 6 | wd=0.001 7 | epoch=200 8 | bz=128 9 | rho=0.2 10 | sigma=1 11 | lmbda=0.6 12 | opt=FriendlySAM # FriendlySAM SAM 13 | DST=results/$opt/$datasets/$model/$opt\_cutout\_$rho\_$sigma\_$lmbda\_$epoch\_$model\_bz$bz\_wd$wd\_$datasets\_$schedule\_seed$seed 14 | 15 | CUDA_VISIBLE_DEVICES=$device python -u train.py --datasets $datasets \ 16 | --arch=$model --epochs=$epoch --wd=$wd --randomseed $seed --lr 0.05 --rho $rho --optimizer $opt \ 17 | --save-dir=$DST/checkpoints --log-dir=$DST -p 200 --schedule $schedule -b $bz \ 18 | --cutout --sigma $sigma --lmbda $lmbda -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import numpy as np 5 | import random 6 | import sys 7 | 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | 18 | import torch.nn.functional as F 19 | from torch.optim.optimizer import Optimizer, required 20 | import math 21 | 22 | from PIL import Image, ImageFile 23 | ImageFile.LOAD_TRUNCATED_IMAGES = True 24 | 25 | from utils import * 26 | 27 | # Parse arguments 28 | parser = argparse.ArgumentParser(description='Regular SGD training') 29 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='SGD') 30 | parser.add_argument('--arch', '-a', metavar='ARCH', 31 | help='The architecture of the model') 32 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str, 
33 | help='The training datasets') 34 | parser.add_argument('--optimizer', metavar='OPTIMIZER', default='sgd', type=str, 35 | help='The optimizer for training') 36 | parser.add_argument('--schedule', metavar='SCHEDULE', default='step', type=str, 37 | help='The schedule for training') 38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=128, type=int, 45 | metavar='N', help='mini-batch size (default: 128)') 46 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 47 | metavar='LR', help='initial learning rate') 48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 49 | help='momentum') 50 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 51 | metavar='W', help='weight decay (default: 1e-4)') 52 | parser.add_argument('--print-freq', '-p', default=100, type=int, 53 | metavar='N', help='print frequency (default: 50 iterations)') 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--wandb', dest='wandb', action='store_true', 59 | help='use wandb to monitor statisitcs') 60 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 61 | help='use pre-trained model') 62 | parser.add_argument('--half', dest='half', action='store_true', 63 | help='use half-precision(16-bit) ') 64 | parser.add_argument('--save-dir', dest='save_dir', 65 | help='The directory used to save the trained models', 66 | default='save_temp', type=str) 67 | parser.add_argument('--log-dir', dest='log_dir', 68 | help='The directory used to save the log', 69 | default='save_temp', type=str) 70 | parser.add_argument('--log-name', dest='log_name', 71 | help='The log file name', 72 | default='log', type=str) 73 | parser.add_argument('--randomseed', 74 | help='Randomseed for training and initialization', 75 | type=int, default=1) 76 | parser.add_argument('--rho', default=0.1, type=float, 77 | metavar='RHO', help='rho for sam') 78 | parser.add_argument('--cutout', dest='cutout', action='store_true', 79 | help='use cutout data augmentation') 80 | parser.add_argument('--sigma', default=1, type=float, 81 | metavar='S', help='sigma for FriendlySAM') 82 | parser.add_argument('--lmbda', default=0.95, type=float, 83 | metavar='L', help='lambda for FriendlySAM') 84 | 85 | 86 | 87 | parser.add_argument('--noise_ratio', default=0.5, type=float, 88 | metavar='N', help='noise ratio for dataset') 89 | 90 | 91 | 92 | parser.add_argument('--img_size', type=int, default=224, help="Resolution size") 93 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 94 | help='Dropout rate (default: 0.)') 95 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 96 | help='Drop path rate (default: 0.1)') 97 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 98 | help='Drop block rate (default: None)') 99 | #vit parameters 100 | parser.add_argument('--patch', dest='patch', type=int, default=4) 101 | 
parser.add_argument('--dimhead', dest='dimhead', type=int, default=512) 102 | parser.add_argument('--convkernel', dest='convkernel', type=int, default=8) 103 | 104 | best_prec1 = 0 105 | 106 | import torch 107 | 108 | def disable_running_stats(model): 109 | def _disable(module): 110 | if isinstance(module, _BatchNorm): 111 | module.backup_momentum = module.momentum 112 | module.momentum = 0 113 | 114 | model.apply(_disable) 115 | 116 | def enable_running_stats(model): 117 | def _enable(module): 118 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): 119 | module.momentum = module.backup_momentum 120 | 121 | model.apply(_enable) 122 | 123 | 124 | 125 | # Record training statistics 126 | train_loss = [] 127 | train_err = [] 128 | post_train_loss = [] 129 | post_train_err = [] 130 | ori_train_loss = [] 131 | ori_train_err = [] 132 | test_loss = [] 133 | test_err = [] 134 | arr_time = [] 135 | 136 | p0 = None 137 | 138 | args = parser.parse_args() 139 | 140 | if args.wandb: 141 | import wandb 142 | wandb.init(project="TWA", entity="nblt") 143 | date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 144 | wandb.run.name = args.EXP + date 145 | 146 | def get_model_param_vec(model): 147 | # Return the model parameters as a vector 148 | 149 | vec = [] 150 | for name,param in model.named_parameters(): 151 | vec.append(param.data.detach().reshape(-1)) 152 | return torch.cat(vec, 0) 153 | 154 | 155 | def get_model_grad_vec(model): 156 | # Return the model gradient as a vector 157 | 158 | vec = [] 159 | for name,param in model.named_parameters(): 160 | vec.append(param.grad.detach().reshape(-1)) 161 | return torch.cat(vec, 0) 162 | 163 | def update_grad(model, grad_vec): 164 | idx = 0 165 | for name,param in model.named_parameters(): 166 | arr_shape = param.grad.shape 167 | size = 1 168 | for i in range(len(list(arr_shape))): 169 | size *= arr_shape[i] 170 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape) 171 | idx += size 172 | 173 | def update_param(model, param_vec): 174 | idx = 0 175 | for name,param in model.named_parameters(): 176 | arr_shape = param.data.shape 177 | size = 1 178 | for i in range(len(list(arr_shape))): 179 | size *= arr_shape[i] 180 | param.data = param_vec[idx:idx+size].reshape(arr_shape) 181 | idx += size 182 | 183 | sample_idx = 0 184 | 185 | def _cosine_annealing(step, total_steps, lr_max, lr_min): 186 | return lr_min + (lr_max - 187 | lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi)) 188 | 189 | def get_cosine_annealing_scheduler(optimizer, epochs, steps_per_epoch, base_lr): 190 | lr_min = 0.0 191 | total_steps = epochs * steps_per_epoch 192 | 193 | scheduler = torch.optim.lr_scheduler.LambdaLR( 194 | optimizer, 195 | lr_lambda=lambda step: _cosine_annealing( 196 | step, 197 | total_steps, 198 | 1, # since lr_lambda computes multiplicative factor 199 | lr_min / base_lr)) 200 | 201 | return scheduler 202 | 203 | def main(): 204 | 205 | global args, best_prec1, p0 206 | global param_avg, train_loss, train_err, post_train_loss, post_train_err, test_loss, test_err, arr_time 207 | 208 | set_seed(args.randomseed) 209 | 210 | # Check the save_dir exists or not 211 | if not os.path.exists(args.save_dir): 212 | os.makedirs(args.save_dir) 213 | 214 | # Check the log_dir exists or not 215 | if not os.path.exists(args.log_dir): 216 | os.makedirs(args.log_dir) 217 | 218 | 219 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name)) 220 | print ('save dir:', args.save_dir) 221 | print ('log dir:', args.log_dir) 222 | # Define model 
223 | # model = torch.nn.DataParallel(get_model(args)) 224 | model = get_model(args) 225 | model.cuda() 226 | 227 | # for n, p in model.named_parameters(): 228 | # print (n, p.shape) 229 | # while (True): pass 230 | 231 | # Optionally resume from a checkpoint 232 | if args.resume: 233 | # if os.path.isfile(args.resume): 234 | if os.path.isfile(os.path.join(args.save_dir, args.resume)): 235 | 236 | # model.load_state_dict(torch.load(os.path.join(args.save_dir, args.resume))) 237 | 238 | print ("=> loading checkpoint '{}'".format(args.resume)) 239 | checkpoint = torch.load(args.resume) 240 | args.start_epoch = checkpoint['epoch'] 241 | print ('from ', args.start_epoch) 242 | best_prec1 = checkpoint['best_prec1'] 243 | model.load_state_dict(checkpoint['state_dict']) 244 | print ("=> loaded checkpoint '{}' (epoch {})" 245 | .format(args.evaluate, checkpoint['epoch'])) 246 | else: 247 | print ("=> no checkpoint found at '{}'".format(args.resume)) 248 | 249 | cudnn.benchmark = True 250 | 251 | # Prepare Dataloader 252 | print ('lambda:', args.lmbda) 253 | print ('cutout:', args.cutout) 254 | if args.cutout: 255 | train_loader, val_loader = get_datasets_cutout(args) 256 | 257 | print (len(train_loader)) 258 | print (len(train_loader.dataset)) 259 | 260 | # define loss function (criterion) and optimizer 261 | criterion = nn.CrossEntropyLoss().cuda() 262 | 263 | if args.half: 264 | model.half() 265 | criterion.half() 266 | 267 | print ('optimizer:', args.optimizer) 268 | 269 | if args.optimizer == 'SAM': 270 | base_optimizer = torch.optim.SGD 271 | optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=0, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, \ 272 | nesterov=False) 273 | elif args.optimizer == 'SAM_adamw': 274 | base_optimizer = torch.optim.AdamW 275 | optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=0, lr=args.lr, weight_decay=args.weight_decay) 276 | elif args.optimizer == 'FriendlySAM': 277 | base_optimizer = torch.optim.SGD 278 | optimizer = FriendlySAM(model.parameters(), base_optimizer, rho=args.rho, sigma=args.sigma, lmbda=args.lmbda, adaptive=0, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,\ 279 | nesterov=False) 280 | elif args.optimizer == 'FriendlySAM_adamw': 281 | base_optimizer = torch.optim.AdamW 282 | optimizer = FriendlySAM(model.parameters(), base_optimizer, rho=args.rho, sigma=args.sigma, lmbda=args.lmbda, adaptive=0, lr=args.lr, weight_decay=args.weight_decay) 283 | 284 | print (optimizer) 285 | if args.schedule == 'step': 286 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer.base_optimizer, milestones=[60, 120,160], gamma=0.2, last_epoch=args.start_epoch - 1) 287 | elif args.schedule == 'cosine': 288 | # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=args.epochs) 289 | lr_scheduler = get_cosine_annealing_scheduler(optimizer, args.epochs, len(train_loader), args.lr) 290 | 291 | if args.evaluate: 292 | validate(val_loader, model, criterion) 293 | return 294 | 295 | is_best = 0 296 | print ('Start training: ', args.start_epoch, '->', args.epochs) 297 | 298 | # DLDR sampling 299 | torch.save(model.state_dict(), os.path.join(args.save_dir, str(0) + '.pt')) 300 | 301 | p0 = get_model_param_vec(model) 302 | 303 | for epoch in range(args.start_epoch, args.epochs): 304 | 305 | # train for one epoch 306 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 307 | train(train_loader, model, criterion, optimizer, 
lr_scheduler, epoch) 308 | 309 | 310 | # evaluate on validation set 311 | prec1 = validate(val_loader, model, criterion) 312 | 313 | # remember best prec@1 and save checkpoint 314 | is_best = prec1 > best_prec1 315 | best_prec1 = max(prec1, best_prec1) 316 | 317 | save_checkpoint({ 318 | 'state_dict': model.state_dict(), 319 | 'best_prec1': best_prec1, 320 | }, is_best, filename=os.path.join(args.save_dir, 'model.th')) 321 | 322 | print ('train loss: ', train_loss) 323 | print ('train err: ', train_err) 324 | print ('test loss: ', test_loss) 325 | print ('test err: ', test_err) 326 | print ('ori train loss: ', ori_train_loss) 327 | print ('ori train err: ', ori_train_err) 328 | print ('time: ', arr_time) 329 | # print ('best ema:', ema_max) 330 | 331 | def train(train_loader, model, criterion, optimizer, lr_scheduler, epoch): 332 | """ 333 | Run one train epoch 334 | """ 335 | global train_loss, train_err, post_train_loss, post_train_err, ori_train_loss, ori_train_err, arr_time 336 | 337 | batch_time = AverageMeter() 338 | data_time = AverageMeter() 339 | losses = AverageMeter() 340 | top1 = AverageMeter() 341 | 342 | # switch to train mode 343 | model.train() 344 | 345 | total_loss, total_err = 0, 0 346 | post_total_loss, post_total_err = 0, 0 347 | ori_total_loss, ori_total_err = 0, 0 348 | 349 | end = time.time() 350 | for i, (input, target) in enumerate(train_loader): 351 | 352 | # measure data loading time 353 | data_time.update(time.time() - end) 354 | 355 | target = target.cuda() 356 | input_var = input.cuda() 357 | target_var = target 358 | if args.half: 359 | input_var = input_var.half() 360 | 361 | enable_running_stats(model) 362 | 363 | # first forward-backward step 364 | predictions = model(input_var) 365 | loss = criterion(predictions, target_var) 366 | loss.mean().backward() 367 | optimizer.first_step(zero_grad=True) 368 | 369 | # second forward-backward step 370 | disable_running_stats(model) 371 | output_adv = model(input_var) 372 | loss_adv = criterion(output_adv, target_var) 373 | loss_adv.mean().backward() 374 | optimizer.second_step(zero_grad=True) 375 | 376 | 377 | lr_scheduler.step() 378 | 379 | output = predictions.float() 380 | loss = loss.float() 381 | 382 | total_loss += loss.item() * input_var.shape[0] 383 | total_err += (output.max(dim=1)[1] != target_var).sum().item() 384 | 385 | # measure accuracy and record loss 386 | prec1 = accuracy(output.data, target)[0] 387 | losses.update(loss.item(), input.size(0)) 388 | top1.update(prec1.item(), input.size(0)) 389 | 390 | # measure elapsed time 391 | batch_time.update(time.time() - end) 392 | end = time.time() 393 | 394 | ori_total_loss += loss_adv.item() * input_var.shape[0] 395 | ori_total_err += (output_adv.max(dim=1)[1] != target_var).sum().item() 396 | 397 | if i % args.print_freq == 0 or i == len(train_loader) - 1: 398 | print('Epoch: [{0}][{1}/{2}]\t' 399 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 400 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 401 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 402 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 403 | epoch, i, len(train_loader), batch_time=batch_time, 404 | data_time=data_time, loss=losses, top1=top1)) 405 | 406 | print ('Total time for epoch [{0}] : {1:.3f}'.format(epoch, batch_time.sum)) 407 | 408 | train_loss.append(total_loss / len(train_loader.dataset)) 409 | train_err.append(total_err / len(train_loader.dataset)) 410 | ori_train_loss.append(ori_total_loss / len(train_loader.dataset)) 411 | ori_train_err.append(ori_total_err / 
len(train_loader.dataset)) 412 | 413 | if args.wandb: 414 | wandb.log({"train loss": total_loss / len(train_loader.dataset)}) 415 | wandb.log({"train acc": 1 - total_err / len(train_loader.dataset)}) 416 | 417 | arr_time.append(batch_time.sum) 418 | 419 | def validate(val_loader, model, criterion): 420 | """ 421 | Run evaluation 422 | """ 423 | global test_err, test_loss 424 | 425 | total_loss = 0 426 | total_err = 0 427 | 428 | batch_time = AverageMeter() 429 | losses = AverageMeter() 430 | top1 = AverageMeter() 431 | 432 | # switch to evaluate mode 433 | model.eval() 434 | 435 | end = time.time() 436 | with torch.no_grad(): 437 | for i, (input, target) in enumerate(val_loader): 438 | target = target.cuda() 439 | input_var = input.cuda() 440 | target_var = target.cuda() 441 | 442 | if args.half: 443 | input_var = input_var.half() 444 | 445 | # compute output 446 | output = model(input_var) 447 | loss = criterion(output, target_var) 448 | 449 | output = output.float() 450 | loss = loss.float() 451 | 452 | total_loss += loss.item() * input_var.shape[0] 453 | total_err += (output.max(dim=1)[1] != target_var).sum().item() 454 | 455 | # measure accuracy and record loss 456 | prec1 = accuracy(output.data, target)[0] 457 | losses.update(loss.item(), input.size(0)) 458 | top1.update(prec1.item(), input.size(0)) 459 | 460 | # measure elapsed time 461 | batch_time.update(time.time() - end) 462 | end = time.time() 463 | 464 | if i % args.print_freq == 0: 465 | print('Test: [{0}/{1}]\t' 466 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 467 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 468 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 469 | i, len(val_loader), batch_time=batch_time, loss=losses, 470 | top1=top1)) 471 | 472 | print(' * Prec@1 {top1.avg:.3f}' 473 | .format(top1=top1)) 474 | 475 | test_loss.append(total_loss / len(val_loader.dataset)) 476 | test_err.append(total_err / len(val_loader.dataset)) 477 | 478 | if args.wandb: 479 | wandb.log({"test loss": total_loss / len(val_loader.dataset)}) 480 | wandb.log({"test acc": 1 - total_err / len(val_loader.dataset)}) 481 | 482 | return top1.avg 483 | 484 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 485 | """ 486 | Save the training model 487 | """ 488 | torch.save(state, filename) 489 | 490 | class AverageMeter(object): 491 | """Computes and stores the average and current value""" 492 | def __init__(self): 493 | self.reset() 494 | 495 | def reset(self): 496 | self.val = 0 497 | self.avg = 0 498 | self.sum = 0 499 | self.count = 0 500 | 501 | def update(self, val, n=1): 502 | self.val = val 503 | self.sum += val * n 504 | self.count += n 505 | self.avg = self.sum / self.count 506 | 507 | 508 | def accuracy(output, target, topk=(1,)): 509 | """Computes the precision@k for the specified values of k""" 510 | maxk = max(topk) 511 | batch_size = target.size(0) 512 | 513 | _, pred = output.topk(maxk, 1, True, True) 514 | pred = pred.t() 515 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 516 | 517 | res = [] 518 | for k in topk: 519 | correct_k = correct[:k].view(-1).float().sum(0) 520 | res.append(correct_k.mul_(100.0 / batch_size)) 521 | return res 522 | 523 | 524 | if __name__ == '__main__': 525 | main() 526 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.backends.cudnn as cudnn 5 | import 
torch.optim as optim 6 | import torch.utils.data 7 | import torch.nn.functional as F 8 | import torchvision.transforms as transforms 9 | import torchvision.datasets as datasets 10 | import torchvision.models as models_imagenet 11 | from models.pyramidnet import PyramidNet 12 | from timm.models.vision_transformer import VisionTransformer, _cfg 13 | from functools import partial 14 | from torch.utils.data import DataLoader, Dataset, Subset 15 | 16 | import numpy as np 17 | import random 18 | import os 19 | import time 20 | import models 21 | import sys 22 | import torch.utils.data as data 23 | from torchvision.datasets.utils import download_url, check_integrity 24 | import os.path 25 | import pickle 26 | from PIL import Image 27 | 28 | import torch.nn.functional as F 29 | from torch.optim.optimizer import Optimizer, required 30 | import math 31 | 32 | def set_seed(seed=1): 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed(seed) 37 | torch.backends.cudnn.deterministic = True 38 | torch.backends.cudnn.benchmark = False 39 | 40 | class Logger(object): 41 | def __init__(self,fileN ="Default.log"): 42 | self.terminal = sys.stdout 43 | self.log = open(fileN,"a") 44 | 45 | def write(self,message): 46 | self.terminal.write(message) 47 | self.log.write(message) 48 | 49 | def flush(self): 50 | self.terminal.flush() 51 | self.log.flush() 52 | 53 | ################################ datasets ####################################### 54 | 55 | import torchvision.transforms as transforms 56 | import torchvision.datasets as datasets 57 | from torch.utils.data import DataLoader, Subset 58 | from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder 59 | 60 | 61 | class Cutout: 62 | def __init__(self, size=16, p=0.5): 63 | self.size = size 64 | self.half_size = size // 2 65 | self.p = p 66 | 67 | def __call__(self, image): 68 | if torch.rand([1]).item() > self.p: 69 | return image 70 | 71 | left = torch.randint(-self.half_size, image.size(1) - self.half_size, [1]).item() 72 | top = torch.randint(-self.half_size, image.size(2) - self.half_size, [1]).item() 73 | right = min(image.size(1), left + self.size) 74 | bottom = min(image.size(2), top + self.size) 75 | 76 | image[:, max(0, left): right, max(0, top): bottom] = 0 77 | return image 78 | 79 | 80 | def unpickle(file): 81 | import _pickle as cPickle 82 | with open(file, 'rb') as fo: 83 | dict = cPickle.load(fo, encoding='latin1') 84 | return dict 85 | 86 | 87 | class cifar_dataset(Dataset): 88 | def __init__(self, dataset='cifar10', r=0.4, noise_mode='sym', root_dir='./datasets/cifar-10-batches-py', 89 | transform=None, mode='all', noise_file='cifar10.json', pred=[], probability=[], log=''): 90 | 91 | self.r = r # noise ratio 92 | self.transform = transform 93 | if dataset=='cifar100': 94 | root_dir = './datasets/cifar-100-python' 95 | self.mode = mode #mode 'test', 'all', 'labeled', 'unlabeled' 96 | self.transition = {0:0,2:0,4:7,7:7,1:1,9:1,3:5,5:3,6:6,8:8} # class transition for asymmetric noise 97 | self.noise_file = os.path.join(root_dir, noise_file) 98 | if self.mode=='test': 99 | if dataset=='cifar10': 100 | test_dic = unpickle('%s/test_batch'%root_dir) 101 | self.test_data = test_dic['data'] 102 | self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 103 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) #(1000,32,32,3) 104 | self.test_label = test_dic['labels'] 105 | elif dataset=='cifar100': 106 | test_dic = unpickle('%s/test'%root_dir) 107 | self.test_data = test_dic['data'] 108 
| self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 109 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) #(1000,32,32,3) 110 | self.test_label = test_dic['fine_labels'] 111 | else: #'train 112 | train_data=[] 113 | train_label=[] 114 | if dataset=='cifar10': 115 | for n in range(1,6): #1~5 116 | dpath = '%s/data_batch_%d'%(root_dir,n) 117 | data_dic = unpickle(dpath) 118 | train_data.append(data_dic['data']) 119 | train_label = train_label+data_dic['labels'] 120 | train_data = np.concatenate(train_data) 121 | elif dataset=='cifar100': 122 | train_dic = unpickle('%s/train'%root_dir) 123 | train_data = train_dic['data'] 124 | train_label = train_dic['fine_labels'] 125 | train_data = train_data.reshape((50000, 3, 32, 32)) 126 | train_data = train_data.transpose((0, 2, 3, 1)) #(5000,32,32,3) 127 | 128 | if os.path.exists(self.noise_file): 129 | noise_label = json.load(open(self.noise_file,"r")) 130 | else: #inject noise 131 | noise_label = [] 132 | idx = list(range(50000)) 133 | random.shuffle(idx) 134 | num_noise = int(self.r*50000) 135 | noise_idx = idx[:num_noise] 136 | for i in range(50000): 137 | if i in noise_idx: 138 | if noise_mode=='sym': 139 | if dataset=='cifar10': 140 | noiselabel = random.randint(0,9) 141 | elif dataset=='cifar100': 142 | noiselabel = random.randint(0,99) 143 | noise_label.append(noiselabel) 144 | elif noise_mode=='asym': 145 | noiselabel = self.transition[train_label[i]] 146 | noise_label.append(noiselabel) 147 | else: 148 | noise_label.append(train_label[i]) 149 | # print("save noisy labels to %s ..."%self.noise_file) 150 | # print('self.nose_file', type(self.noise_file), self.noise_file) 151 | # json.dump(noise_label,open(self.noise_file,"w")) 152 | 153 | if self.mode == 'all': 154 | self.train_data = train_data 155 | self.noise_label = noise_label 156 | else: 157 | if self.mode == "labeled": 158 | pred_idx = pred.nonzero()[0] 159 | self.probability = [probability[i] for i in pred_idx] 160 | 161 | clean = (np.array(noise_label)==np.array(train_label)) 162 | auc_meter = AUCMeter() 163 | auc_meter.reset() 164 | auc_meter.add(probability,clean) 165 | auc,_,_ = auc_meter.value() 166 | log.write('Numer of labeled samples:%d AUC:%.3f\n'%(pred.sum(),auc)) 167 | log.flush() 168 | 169 | elif self.mode == "unlabeled": 170 | pred_idx = (1-pred).nonzero()[0] 171 | 172 | self.train_data = train_data[pred_idx] 173 | self.noise_label = [noise_label[i] for i in pred_idx] 174 | print("%s data has a size of %d"%(self.mode,len(self.noise_label))) 175 | 176 | def __getitem__(self, index): 177 | if self.mode=='labeled': 178 | img, target, prob = self.train_data[index], self.noise_label[index], self.probability[index] 179 | img = Image.fromarray(img) 180 | img1 = self.transform(img) 181 | img2 = self.transform(img) 182 | return img1, img2, target, prob 183 | elif self.mode=='unlabeled': 184 | img = self.train_data[index] 185 | img = Image.fromarray(img) 186 | img1 = self.transform(img) 187 | img2 = self.transform(img) 188 | return img1, img2 189 | elif self.mode=='all': 190 | img, target = self.train_data[index], self.noise_label[index] 191 | img = Image.fromarray(img) 192 | img = self.transform(img) 193 | return (img, target) 194 | elif self.mode=='test': 195 | img, target = self.test_data[index], self.test_label[index] 196 | img = Image.fromarray(img) 197 | img = self.transform(img) 198 | return (img, target) 199 | 200 | def __len__(self): 201 | if self.mode!='test': 202 | return len(self.train_data) 203 | else: 204 | return len(self.test_data) 205 | 206 | class 
cifar_dataloader(): 207 | def __init__(self, dataset='cifar10', r=0.2, noise_mode='sym', batch_size=256, num_workers=4, cutout=False,root_dir='', log='', noise_file='cifar10.json'): 208 | self.dataset = dataset 209 | self.r = r 210 | self.noise_mode = noise_mode 211 | self.batch_size = batch_size 212 | self.num_workers = num_workers 213 | self.cutout = cutout 214 | self.root_dir = root_dir 215 | self.log = log 216 | self.noise_file = noise_file 217 | if self.dataset=='cifar10': 218 | if self.cutout: 219 | self.transform_train = transforms.Compose([ 220 | transforms.RandomHorizontalFlip(), 221 | transforms.RandomCrop(32, 4), 222 | transforms.ToTensor(), 223 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 224 | Cutout() 225 | ]) 226 | else: 227 | self.transform_train = transforms.Compose([ 228 | transforms.RandomHorizontalFlip(), 229 | transforms.RandomCrop(32, 4), 230 | transforms.ToTensor(), 231 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 232 | ]) 233 | elif self.dataset=='cifar100': 234 | if self.cutout: 235 | self.transform_train = transforms.Compose([ 236 | transforms.RandomHorizontalFlip(), 237 | transforms.RandomCrop(32, 4), 238 | transforms.ToTensor(), 239 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 240 | Cutout() 241 | # transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)), 242 | ]) 243 | else: 244 | self.transform_train = transforms.Compose([ 245 | transforms.RandomHorizontalFlip(), 246 | transforms.RandomCrop(32, 4), 247 | transforms.ToTensor(), 248 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 249 | # transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)), 250 | ]) 251 | self.transform_test = transforms.Compose([ 252 | transforms.ToTensor(), 253 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 254 | # transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)), 255 | ]) 256 | 257 | def get_loader(self): 258 | 259 | train_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, 260 | root_dir='./datasets/cifar-10-batches-py', transform=self.transform_train, mode="all", 261 | noise_file=self.noise_file) 262 | train_loader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True, 263 | num_workers=self.num_workers, pin_memory=True) 264 | 265 | val_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, 266 | root_dir='./datasets/cifar-10-batches-py', transform=self.transform_test, mode="test", 267 | noise_file=self.noise_file) 268 | val_loader = DataLoader(dataset=val_dataset, batch_size=self.batch_size, shuffle=True, 269 | num_workers=self.num_workers, pin_memory=True) 270 | return train_loader, val_loader 271 | 272 | def get_datasets_cutout(args): 273 | print ('cutout!') 274 | if args.datasets == 'CIFAR10': 275 | print ('cifar10 dataset!') 276 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 277 | 278 | train_loader = torch.utils.data.DataLoader( 279 | datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([ 280 | transforms.RandomHorizontalFlip(), 281 | transforms.RandomCrop(32, 4), 282 | transforms.ToTensor(), 283 | normalize, 284 | Cutout() 285 | ]), download=True), 286 | batch_size=args.batch_size, shuffle=True, 287 | num_workers=args.workers, pin_memory=True) 288 | 289 | val_loader = torch.utils.data.DataLoader( 290 | datasets.CIFAR10(root='./datasets/', train=False, 
transform=transforms.Compose([ 291 | transforms.ToTensor(), 292 | normalize, 293 | ])), 294 | batch_size=128, shuffle=False, 295 | num_workers=args.workers, pin_memory=True) 296 | 297 | elif args.datasets == 'CIFAR10_noise': 298 | print('cifar10 nosie dataset!') 299 | cifar10_noise = cifar_dataloader(dataset='cifar10', r=args.noise_ratio, batch_size=args.batch_size, num_workers=args.workers, cutout=args.cutout) 300 | train_loader, val_loader = cifar10_noise.get_loader() 301 | 302 | elif args.datasets == 'CIFAR100': 303 | print ('cifar100 dataset!') 304 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 305 | 306 | train_loader = torch.utils.data.DataLoader( 307 | datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([ 308 | transforms.RandomHorizontalFlip(), 309 | transforms.RandomCrop(32, 4), 310 | transforms.ToTensor(), 311 | normalize, 312 | Cutout() 313 | ]), download=True), 314 | batch_size=args.batch_size, shuffle=True, 315 | num_workers=args.workers, pin_memory=True) 316 | 317 | val_loader = torch.utils.data.DataLoader( 318 | datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([ 319 | transforms.ToTensor(), 320 | normalize, 321 | ])), 322 | batch_size=128, shuffle=False, 323 | num_workers=args.workers, pin_memory=True) 324 | 325 | return train_loader, val_loader 326 | 327 | 328 | def get_model(args): 329 | print('Model: {}'.format(args.arch)) 330 | 331 | 332 | if args.datasets == 'CIFAR10' or args.datasets == 'CIFAR10_noise': 333 | num_classes = 10 334 | elif args.datasets == 'CIFAR100': 335 | num_classes = 100 336 | elif args.datasets == 'ImageNet': 337 | num_classes = 1000 338 | 339 | if 'deit' in args.arch: 340 | model = VisionTransformer( 341 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 342 | norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=num_classes, drop_rate=args.drop, 343 | drop_path_rate=args.drop_path 344 | ) 345 | checkpoint = torch.hub.load_state_dict_from_url( 346 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 347 | map_location="cpu", check_hash=True 348 | ) 349 | checkpoint["model"].pop('head.weight') 350 | checkpoint["model"].pop('head.bias') 351 | 352 | model.load_state_dict(checkpoint["model"],strict=False) 353 | return model 354 | 355 | if args.datasets == 'ImageNet': 356 | return models_imagenet.__dict__[args.arch]() 357 | 358 | if args.arch == 'PyramidNet110': 359 | return PyramidNet(110, 270, num_classes) 360 | 361 | model_cfg = getattr(models, args.arch) 362 | 363 | return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs) 364 | 365 | 366 | class SAM(torch.optim.Optimizer): 367 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 368 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 369 | 370 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 371 | super(SAM, self).__init__(params, defaults) 372 | 373 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 374 | self.param_groups = self.base_optimizer.param_groups 375 | self.defaults.update(self.base_optimizer.defaults) 376 | 377 | @torch.no_grad() 378 | def first_step(self, zero_grad=False): 379 | grad_norm = self._grad_norm() 380 | for group in self.param_groups: 381 | scale = group["rho"] / (grad_norm + 1e-12) 382 | 383 | for p in group["params"]: 384 | if p.grad is None: continue 385 | self.state[p]["old_p"] = p.data.clone() 386 | e_w = (torch.pow(p, 
2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 387 | p.add_(e_w) # climb to the local maximum "w + e(w)" 388 | 389 | if zero_grad: self.zero_grad() 390 | 391 | @torch.no_grad() 392 | def second_step(self, zero_grad=False): 393 | for group in self.param_groups: 394 | for p in group["params"]: 395 | if p.grad is None: continue 396 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 397 | 398 | self.base_optimizer.step() # do the actual "sharpness-aware" update 399 | 400 | if zero_grad: self.zero_grad() 401 | 402 | @torch.no_grad() 403 | def step(self, closure=None): 404 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 405 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 406 | 407 | self.first_step(zero_grad=True) 408 | closure() 409 | self.second_step() 410 | 411 | def _grad_norm(self): 412 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 413 | norm = torch.norm( 414 | torch.stack([ 415 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 416 | for group in self.param_groups for p in group["params"] 417 | if p.grad is not None 418 | ]), 419 | p=2 420 | ) 421 | return norm 422 | 423 | def load_state_dict(self, state_dict): 424 | super().load_state_dict(state_dict) 425 | self.base_optimizer.param_groups = self.param_groups 426 | 427 | class FriendlySAM(torch.optim.Optimizer): 428 | def __init__(self, params, base_optimizer, rho=0.05, sigma=1, lmbda=0.9, adaptive=False, **kwargs): 429 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 430 | 431 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 432 | super(FriendlySAM, self).__init__(params, defaults) 433 | 434 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 435 | self.param_groups = self.base_optimizer.param_groups 436 | self.defaults.update(self.base_optimizer.defaults) 437 | self.sigma = sigma 438 | self.lmbda = lmbda 439 | print ('FriendlySAM sigma:', self.sigma, 'lambda:', self.lmbda) 440 | 441 | @torch.no_grad() 442 | def first_step(self, zero_grad=False): 443 | 444 | for group in self.param_groups: 445 | for p in group["params"]: 446 | if p.grad is None: continue 447 | grad = p.grad.clone() 448 | if not "momentum" in self.state[p]: 449 | self.state[p]["momentum"] = grad 450 | else: 451 | p.grad -= self.state[p]["momentum"] * self.sigma 452 | self.state[p]["momentum"] = self.state[p]["momentum"] * self.lmbda + grad * (1 - self.lmbda) 453 | 454 | grad_norm = self._grad_norm() 455 | for group in self.param_groups: 456 | scale = group["rho"] / (grad_norm + 1e-12) 457 | 458 | for p in group["params"]: 459 | if p.grad is None: continue 460 | self.state[p]["old_p"] = p.data.clone() 461 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 462 | p.add_(e_w) # climb to the local maximum "w + e(w)" 463 | 464 | if zero_grad: self.zero_grad() 465 | 466 | @torch.no_grad() 467 | def second_step(self, zero_grad=False): 468 | for group in self.param_groups: 469 | for p in group["params"]: 470 | if p.grad is None: continue 471 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 472 | 473 | self.base_optimizer.step() # do the actual "sharpness-aware" update 474 | 475 | if zero_grad: self.zero_grad() 476 | 477 | @torch.no_grad() 478 | def step(self, closure=None): 479 | assert closure is not None, "Sharpness Aware Minimization requires closure, 
but it was not provided" 480 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 481 | 482 | self.first_step(zero_grad=True) 483 | closure() 484 | self.second_step() 485 | 486 | def _grad_norm(self): 487 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 488 | norm = torch.norm( 489 | torch.stack([ 490 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 491 | for group in self.param_groups for p in group["params"] 492 | if p.grad is not None 493 | ]), 494 | p=2 495 | ) 496 | return norm 497 | 498 | def load_state_dict(self, state_dict): 499 | super().load_state_dict(state_dict) 500 | self.base_optimizer.param_groups = self.param_groups 501 | --------------------------------------------------------------------------------
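A minimal usage sketch of how train.py drives the FriendlySAM optimizer defined in utils.py, using the hyperparameters set in run.sh (rho=0.2, sigma=1, lmbda=0.6, lr=0.05, weight decay 0.001). It assumes `model`, `criterion`, `inputs`, and `targets` are already defined on the GPU, and that the enable_running_stats / disable_running_stats helpers from train.py are in scope. In first_step, FriendlySAM subtracts sigma times an exponential moving average of past gradients from the current gradient (the average itself is updated with factor lmbda), normalizes the result, and scales it by rho to form the perturbation e(w); second_step restores the original weights and applies the base SGD update computed at w + e(w).

import torch
from utils import FriendlySAM

# wrap SGD exactly as train.py does for --optimizer FriendlySAM
base_optimizer = torch.optim.SGD
optimizer = FriendlySAM(model.parameters(), base_optimizer,
                        rho=0.2, sigma=1.0, lmbda=0.6, adaptive=False,
                        lr=0.05, momentum=0.9, weight_decay=0.001, nesterov=False)

# first forward-backward pass: build the "friendly" ascent direction and move to w + e(w)
enable_running_stats(model)        # BatchNorm running statistics are updated only on this clean pass
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass at the perturbed point, then step back to w and update with SGD
disable_running_stats(model)       # freeze BatchNorm statistics for the perturbed pass
criterion(model(inputs), targets).backward()
optimizer.second_step(zero_grad=True)

Freezing the running statistics on the second pass (disable_running_stats sets BatchNorm momentum to 0 in train.py) keeps both forward passes under the same normalization, so the perturbed gradient is measured against the same statistics as the clean one.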