├── LICENSE ├── README.md ├── architectures ├── __init__.py ├── arch.py ├── convlarge.py ├── densenet.py ├── dpn.py ├── lenet.py ├── mobilenet.py ├── mobilenetv2.py ├── preact_resnet.py ├── resnet.py ├── resnext.py ├── senet.py ├── shufflenet.py └── vgg.py ├── main.py ├── results ├── epslab2013v2_cifar10-4k_20-01-07-epochs-400.txt ├── etempensv2_cifar10-4k_20-01-07-epochs-400.txt ├── ictv2_cifar10-4k_20-01-07-epochs-400.txt ├── ipslab2013v2_cifar10-4k_20-01-07-epochs-400.txt ├── itempensv1_cifar10-4k_20-01-07-epochs-400.txt ├── itempensv2_cifar10-4k_19-12-28-epochs-400.txt ├── mixmatchv2_cifar10-4k_epochs-400.txt ├── mtv2_cifar10-4k_19-12-29-epochs-400.txt ├── piv1_cifar10-4k_19-12-30-epochs-200.txt ├── piv2_cifar10-4k_20-01-07-epochs-400.txt ├── vatv1_cifar10-4k_20-04-11-08-40.txt └── vatv2_cifar10-4k_20-04-12-09-07.txt ├── run.sh ├── trainer ├── ICTv1.py ├── ICTv2.py ├── MeanTeacherv1.py ├── MeanTeacherv2.py ├── MixMatch.py ├── PIv1.py ├── PIv2.py ├── VATv1.py ├── VATv2.py ├── __init__.py ├── eFixMatch.py ├── eMixPseudoLabelv1.py ├── eMixPseudoLabelv2.py ├── ePseudoLabel2013v1.py ├── ePseudoLabel2013v2.py ├── eTempensv1.py ├── eTempensv2.py ├── iFixMatch.py ├── iPseudoLabel2013v1.py ├── iPseudoLabel2013v2.py ├── iTempensv1.py └── iTempensv2.py └── utils ├── __init__.py ├── config.py ├── context.py ├── data_utils.py ├── datasets.py ├── dist.py ├── loss.py ├── mixup.py ├── ramps.py └── randAug.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 iBelieveCJM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tricks of Semi-supervised Deep Learning --Pytorch 2 | 3 | The repository implements the following semi-supervised deep learning methods: 4 | 5 | - **PseudoLabel 2013**: The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks (ICMLW 2013) 6 | 7 | - **PI&Tempens**: Temporal Ensembling for Semi-Supervised Learning (ICLR 2017) 8 | 9 | - **MeanTeacher**: Mean Teachers are better Role Models (NIPS 2017) 10 | 11 | - **VAT**: Virtual Adversarial Training: A Regularization Method for Supervised and Semi-supervised Learning (TPAMI 2018) 12 | 13 | - **ICT**: Interpolation Consistency Training for Semi-supervised Learning (IJCAI 2019) 14 | 15 | - **MixMatch**: A Holistic Approach to Semi-Supervised Learning (NeurIPS 2019) 16 | 17 | - **FixMatch**: Simplifying Semi-Supervised Learning with Consistency and Confidence (2020) 18 | 19 | This repository was created for my blog [半监督深度学习训练和实现小Tricks](https://zhuanlan.zhihu.com/p/100252944). The hyper-parameters are therefore set for a fair comparison rather than for the best performance. 20 | 21 | ### The environment: 22 | 23 | - Ubuntu 16.04 + CUDA 9.0 24 | 25 | - Python 3.6.5 :: Anaconda 26 | 27 | - PyTorch 0.4.1 and torchvision 0.2.1 28 | 29 | ### To run the code: 30 | 31 | The script *run.sh* includes some examples. You can try it as follows: 32 | 33 | ```shell 34 | bash run.sh [gpu_id] 35 | ``` 36 | 37 | ### Some experimental results: 38 | 39 | I have not run all the models in this repository. Some results are available in the *results* directory, and the following results come from this repository and from the older code on which it is built. 40 | 41 | The following table shows the error rates (%) of the CIFAR-10 experiment with 4000 labeled training samples. The parameter settings are the same as the examples in *run.sh*. 42 | 43 | | | iPseudoLabel2013 | ePseudoLabel2013 | MeanTeacher | MixMatch | iFixMatch | 44 | |------- | ---------------- | ---------------- | ----------- | -------- | --------- | 45 | |original | | | 12.31 | 6.24 | 4.26 | 46 | | v1 | 20.03 | 12.03 | 10.59 | 6.70 | 6.63 | 47 | | v2 | 15.82 | 10.82 | 9.46 | 6.89 | 6.44 | 48 | | | iTempens | eTempens | PI | ICT\* | VAT | 49 | |original | | 12.16 | 13.20 | 7.29 | 11.36 | 50 | | v1 | 10.98 | 10.74 | 14.11 | 7.12 | 13.84 | 51 | | v2 | 13.53 | 10.24 | 12.89 | 6.74 | 12.67 | 52 | 53 | 54 | | | eMixPseudoLabelv1 | eMixPseudoLabelv2 | 55 | |------- | ----------------- | ----------------- | 56 | | soft | 7.30 | 7.08 | 57 | | hard | 7.33 | 7.20 | 58 | 59 | Notes: 60 | 61 | - MeanTeacher was the first model implemented in this repository, so the hyper-parameters have effectively been tuned for MeanTeacher. 62 | 63 | - My ICT differs from the original one; the main difference is the unsupervised loss for unlabeled data. 64 | 65 | - eMixPseudoLabel is ePseudoLabel2013 with MixUp. 66 | 67 | - Instead of KL divergence, my VAT uses MSE, which brings a further performance improvement.
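For reference, a minimal sketch of such an MSE-on-softmax consistency loss. The trainers import `mse_with_softmax` from *utils/loss.py*, which is not reproduced in this excerpt; the body below is only an illustration of the idea and may differ from the repository's actual implementation:

```python
# Illustrative sketch only: MSE consistency on softmax outputs, of the kind
# used here in place of KL divergence. The repository's own utils/loss.py
# is not shown in this excerpt and may be implemented differently.
import torch.nn.functional as F

def mse_with_softmax(logits, target_logits):
    """Mean squared error between two predicted class distributions."""
    assert logits.size() == target_logits.size()
    return F.mse_loss(F.softmax(logits, dim=1),
                      F.softmax(target_logits, dim=1))
```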
68 | -------------------------------------------------------------------------------- /architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorch/be90060b3017e99b8c53a596110cb5931ec9f38c/architectures/__init__.py -------------------------------------------------------------------------------- /architectures/arch.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | from functools import wraps 3 | 4 | from architectures.lenet import LeNet 5 | from architectures.vgg import VGG11, VGG13, VGG16, VGG19 6 | from architectures.resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 7 | from architectures.preact_resnet import PreActResNet18, PreActResNet34, PreActResNet50, PreActResNet101, PreActResNet152 8 | from architectures.densenet import DenseNet_cifar, DenseNet121, DenseNet169, DenseNet201, DenseNet161 9 | from architectures.resnext import ResNeXt29_2x64d,ResNeXt29_4x64d,ResNeXt29_8x64d,ResNeXt29_32x4d 10 | from architectures.senet import SENet18 11 | from architectures.dpn import DPN26, DPN92 12 | from architectures.shufflenet import ShuffleNetG2, ShuffleNetG3 13 | from architectures.mobilenet import MobileNetV1 14 | from architectures.mobilenetv2 import MobileNetV2 15 | from architectures.convlarge import convLarge 16 | 17 | arch = { 18 | 'lenet': LeNet, 19 | 'vgg11': VGG11, 20 | 'vgg13': VGG13, 21 | 'vgg16': VGG16, 22 | 'vgg19': VGG19, 23 | 'resnet18': ResNet18, 24 | 'resnet34': ResNet34, 25 | 'resnet50': ResNet50, 26 | 'resnet101': ResNet101, 27 | 'resnet152': ResNet152, 28 | 'preact_resnet18': PreActResNet18, 29 | 'preact_resnet34': PreActResNet34, 30 | 'preact_resnet50': PreActResNet50, 31 | 'preact_resnet101': PreActResNet101, 32 | 'preact_resnet152': PreActResNet152, 33 | 'densenet121': DenseNet121, 34 | 'densenet169': DenseNet169, 35 | 'densenet201': DenseNet201, 36 | 'densenet161': DenseNet161, 37 | 'densenet': DenseNet_cifar, 38 | 'resnext29_2x64d': ResNeXt29_2x64d, 39 | 'resnext29_4x64d': ResNeXt29_4x64d, 40 | 'resnext29_8x64d': ResNeXt29_8x64d, 41 | 'resnext29_32x4d': ResNeXt29_32x4d, 42 | 'senet': SENet18, 43 | 'dpn26': DPN26, 44 | 'dpn92': DPN92, 45 | 'shuffleG2': ShuffleNetG2, 46 | 'shuffleG3': ShuffleNetG3, 47 | 'mobileV1': MobileNetV1, 48 | 'mobileV2': MobileNetV2, 49 | 'cnn13': convLarge 50 | } 51 | 52 | 53 | def RegisterArch(arch_name): 54 | """Register a model. 55 | The file that uses this decorator must be imported 56 | so that the decorated model function is registered in `arch`. 57 | """ 58 | def wrapper(f): 59 | arch[arch_name] = f 60 | return f 61 | return wrapper 62 | -------------------------------------------------------------------------------- /architectures/convlarge.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils import weight_norm 6 | 7 | class GaussianNoise(nn.Module): 8 | 9 | def __init__(self, std): 10 | super(GaussianNoise, self).__init__() 11 | self.std = std 12 | 13 | def forward(self, x): 14 | zeros = torch.zeros(x.size()).cuda() 15 | n = torch.normal(zeros, std=self.std) # zero-mean Gaussian noise with std=self.std 16 | return x + n 17 | 18 | class CNN_block(nn.Module): 19 | 20 | def __init__(self, in_plane, out_plane, kernel_size, padding, activation): 21 | super(CNN_block, self).__init__() 22 | 23 | self.act = activation 24 | self.conv = 
nn.Conv2d(in_plane, 25 | out_plane, 26 | kernel_size, 27 | padding=padding) 28 | 29 | self.bn = nn.BatchNorm2d(out_plane) 30 | 31 | def forward(self, x): 32 | return self.act(self.bn(self.conv(x))) 33 | 34 | class CNN(nn.Module): 35 | 36 | def __init__(self, block, num_blocks, num_classes=10, drop_ratio=0.0): 37 | super(CNN, self).__init__() 38 | 39 | self.in_plane = 3 40 | self.out_plane = 128 41 | self.gn = GaussianNoise(0.15) 42 | self.act = nn.LeakyReLU(0.1) 43 | self.layer1 = self._make_layer(block, num_blocks[0], 128, 3, padding=1) 44 | self.mp1 = nn.MaxPool2d(2, stride=2, padding=0) 45 | self.drop1 = nn.Dropout(drop_ratio) 46 | self.layer2 = self._make_layer(block, num_blocks[1], 256, 3, padding=1) 47 | self.mp2 = nn.MaxPool2d(2, stride=2, padding=0) 48 | self.drop2 = nn.Dropout(drop_ratio) 49 | self.layer3 = self._make_layer(block, num_blocks[2], 50 | [512, 256, self.out_plane], 51 | [3, 1, 1], 52 | padding=0) 53 | self.ap3 = nn.AdaptiveAvgPool2d(1) 54 | self.fc1 = nn.Linear(self.out_plane, num_classes) 55 | 56 | def _make_layer(self, block, num_blocks, planes, kernel_size, padding=1): 57 | if isinstance(planes, int): 58 | planes = [planes]*num_blocks 59 | if isinstance(kernel_size, int): 60 | kernel_size = [kernel_size]*num_blocks 61 | layers = [] 62 | for plane, ks in zip(planes, kernel_size): 63 | layers.append(block(self.in_plane, plane, ks, padding, self.act)) 64 | self.in_plane = plane 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | out = self.layer1(x) 69 | out = self.mp1(out) 70 | out = self.drop1(out) 71 | out = self.layer2(out) 72 | out = self.mp2(out) 73 | out = self.drop1(out) 74 | out = self.layer3(out) 75 | out = self.ap3(out) 76 | 77 | out = out.view(out.size(0), -1) 78 | return self.fc1(out) 79 | 80 | 81 | def convLarge(num_classes, drop_ratio=0.0): 82 | return CNN(CNN_block, [3,3,3], num_classes, drop_ratio) 83 | 84 | def test(): 85 | print('--- run conv_large test ---') 86 | x = torch.randn(2,3,32,32) 87 | for net in [convLarge(10)]: 88 | print(net) 89 | y = net(x) 90 | print(y.size()) 91 | -------------------------------------------------------------------------------- /architectures/densenet.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Bottleneck(nn.Module): 8 | 9 | def __init__(self, in_planes, growth_rate): 10 | super(Bottleneck, self).__init__() 11 | self.bn1 = nn.BatchNorm2d(in_planes) 12 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 13 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 14 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 15 | 16 | def forward(self, x): 17 | out = self.conv1(F.relu(self.bn1(x))) 18 | out = self.conv2(F.relu(self.bn2(out))) 19 | return torch.cat([out, x], 1) 20 | 21 | class Transition(nn.Module): 22 | 23 | def __init__(self, in_planes, out_planes): 24 | super(Transition, self).__init__() 25 | self.bn = nn.BatchNorm2d(in_planes) 26 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 27 | 28 | def forward(self, x): 29 | out = self.conv(F.relu(self.bn(x))) 30 | return F.avg_pool2d(out, 2) 31 | 32 | class DenseNet(nn.Module): 33 | 34 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 35 | super(DenseNet, self).__init__() 36 | self.growth_rate = growth_rate 37 | 38 | num_planes = 2*growth_rate 39 | self.conv1 = nn.Conv2d(3, 
num_planes, kernel_size=3, padding=1, bias=False) 40 | 41 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 42 | num_planes += nblocks[0]*growth_rate 43 | out_planes = int(math.floor(num_planes*reduction)) 44 | self.trans1 = Transition(num_planes, out_planes) 45 | num_planes = out_planes 46 | 47 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 48 | num_planes += nblocks[1]*growth_rate 49 | out_planes = int(math.floor(num_planes*reduction)) 50 | self.trans2 = Transition(num_planes, out_planes) 51 | num_planes = out_planes 52 | 53 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 54 | num_planes += nblocks[2]*growth_rate 55 | out_planes = int(math.floor(num_planes*reduction)) 56 | self.trans3 = Transition(num_planes, out_planes) 57 | num_planes = out_planes 58 | 59 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 60 | num_planes += nblocks[3]*growth_rate 61 | 62 | self.bn = nn.BatchNorm2d(num_planes) 63 | self.fc1 = nn.Linear(num_planes, num_classes) 64 | 65 | def _make_dense_layers(self, block, in_planes, nblock): 66 | layers = [] 67 | for i in range(nblock): 68 | layers.append(block(in_planes, self.growth_rate)) 69 | in_planes += self.growth_rate 70 | return nn.Sequential(*layers) 71 | 72 | def forward(self, x): 73 | out = self.conv1(x) 74 | out = self.trans1(self.dense1(out)) 75 | out = self.trans2(self.dense2(out)) 76 | out = self.trans3(self.dense3(out)) 77 | out = self.dense4(out) 78 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 79 | out = out.view(out.size(0), -1) 80 | return self.fc1(out) 81 | 82 | def DenseNet121(num_classes): 83 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, num_classes=num_classes) 84 | 85 | def DenseNet169(num_classes): 86 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, num_classes=num_classes) 87 | 88 | def DenseNet201(num_classes): 89 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, num_classes=num_classes) 90 | 91 | def DenseNet161(num_classes): 92 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, num_classes=num_classes) 93 | 94 | def DenseNet_cifar(num_classes): 95 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12, num_classes=num_classes) 96 | 97 | def test(): 98 | print('--- run densenet test ---') 99 | x = torch.randn(2,3,32,32) 100 | for net in [DenseNet121(10), DenseNet169(10), DenseNet201(10), DenseNet161(10), DenseNet_cifar(10)]: 101 | print(net) 102 | y = net(x) 103 | print(y.size()) 104 | 105 | -------------------------------------------------------------------------------- /architectures/dpn.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Bottleneck(nn.Module): 7 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 8 | super(Bottleneck, self).__init__() 9 | self.out_planes = out_planes 10 | self.dense_depth = dense_depth 11 | 12 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 15 | self.bn2 = nn.BatchNorm2d(in_planes) 16 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 17 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 18 | 19 | self.shortcut = nn.Sequential() 20 | if first_layer: 21 | self.shortcut = 
nn.Sequential( 22 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(out_planes+dense_depth) 24 | ) 25 | 26 | def forward(self, x): 27 | out = F.relu(self.bn1(self.conv1(x))) 28 | out = F.relu(self.bn2(self.conv2(out))) 29 | out = self.bn3(self.conv3(out)) 30 | x = self.shortcut(x) 31 | d = self.out_planes 32 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]],1) 33 | return F.relu(out) 34 | 35 | class DPN(nn.Module): 36 | 37 | def __init__(self, cfg, num_classes): 38 | super(DPN, self).__init__() 39 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 40 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 41 | 42 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(64) 44 | self.last_planes = 64 45 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 46 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 47 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 48 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 49 | self.fc1 = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], num_classes) 50 | 51 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 52 | strides = [stride] + [1]*(num_blocks-1) 53 | layers = [] 54 | for i, stride in enumerate(strides): 55 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 56 | self.last_planes = out_planes + (i+2)*dense_depth 57 | return nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = self.layer1(out) 62 | out = self.layer2(out) 63 | out = self.layer3(out) 64 | out = self.layer4(out) 65 | out = F.avg_pool2d(out, 4) 66 | out = out.view(out.size(0), -1) 67 | return self.fc1(out) 68 | 69 | def DPN26(num_classes): 70 | cfg = { 71 | 'in_planes': (96,192,384,768), 72 | 'out_planes': (256,512,1024,2048), 73 | 'num_blocks': (2,2,2,2), 74 | 'dense_depth': (16,32,24,128), 75 | } 76 | return DPN(cfg, num_classes) 77 | 78 | def DPN92(num_classes): 79 | cfg = { 80 | 'in_planes': (96,192,384,768), 81 | 'out_planes': (256,512,1024,2048), 82 | 'num_blocks': (3,4,20,3), 83 | 'dense_depth': (16,32,24,128), 84 | } 85 | return DPN(cfg, num_classes) 86 | 87 | def test(): 88 | print('--- run dpn test ---') 89 | x = torch.randn(2,3,32,32) 90 | for net in [DPN26(10), DPN92(10)]: 91 | y = net(x) 92 | print(y.size()) 93 | -------------------------------------------------------------------------------- /architectures/lenet.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class LeNet(nn.Module): 7 | 8 | def __init__(self, num_classes): 9 | super(LeNet, self).__init__() 10 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5) #default: stride=1,padding=0 11 | self.conv2 = nn.Conv2d(6, 16,kernel_size=5) 12 | self.fc1 = nn.Linear(16*5*5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = nn.Linear(84, num_classes) 15 | 16 | def forward(self, x): 17 | out = F.relu(self.conv1(x)) 18 | out = F.max_pool2d(out, 2) 19 | out = F.relu(self.conv2(out)) 20 | out = F.max_pool2d(out, 2) 21 | 22 | out = out.view(out.size(0), -1) 23 | out = F.relu(self.fc1(out)) 24 
| out = F.relu(self.fc2(out)) 25 | 26 | return self.fc3(out) 27 | 28 | def test(): 29 | print('--- run lenet test ---') 30 | net = LeNet(10); 31 | print(net) 32 | x = torch.randn(2, 3, 32, 32) 33 | y = net(x) 34 | print(y.size()) 35 | -------------------------------------------------------------------------------- /architectures/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Block(nn.Module): 7 | ''' Depthwise conv + Pointwise conv''' 8 | def __init__(self, in_planes, out_planes, stride=1): 9 | super(Block, self).__init__() 10 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 11 | self.bn1 = nn.BatchNorm2d(in_planes) 12 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 13 | self.bn2 = nn.BatchNorm2d(out_planes) 14 | 15 | def forward(self, x): 16 | out = F.relu(self.bn1(self.conv1(x))) 17 | out = F.relu(self.bn2(self.conv2(out))) 18 | return out 19 | 20 | class MobileNet(nn.Module): 21 | 22 | def __init__(self, cfg, num_classes): 23 | super(MobileNet, self).__init__() 24 | self.cfg = cfg 25 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 26 | self.bn1 = nn.BatchNorm2d(32) 27 | self.layers = self._make_layers(in_planes=32) 28 | self.fc1 = nn.Linear(1024, num_classes) 29 | 30 | def _make_layers(self, in_planes): 31 | layers = [] 32 | for x in self.cfg: 33 | out_planes = x if isinstance(x, int) else x[0] 34 | stride = 1 if isinstance(x, int) else x[1] 35 | layers.append(Block(in_planes, out_planes, stride)) 36 | in_planes = out_planes 37 | return nn.Sequential(*layers) 38 | 39 | def forward(self, x): 40 | out = F.relu(self.bn1(self.conv1(x))) 41 | out = self.layers(out) 42 | out = F.avg_pool2d(out, 2) 43 | out = out.view(out.size(0), -1) 44 | return self.fc1(out) 45 | 46 | def MobileNetV1(num_classes): 47 | #(128,2) means planes=128,stride=2 48 | # 128 mean planes=128, by default stride=1 49 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 50 | return MobileNet(cfg, num_classes) 51 | 52 | def test(): 53 | print('--- run mobilenet test ---') 54 | x = torch.randn(2,3,32,32) 55 | net = MobileNetV1(10) 56 | y = net(x) 57 | print(y.size()) 58 | -------------------------------------------------------------------------------- /architectures/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Block(nn.Module): 7 | '''expand + depthwise + pointwise''' 8 | def __init__(self, in_planes, out_planes, expansion, stride): 9 | super(Block, self).__init__() 10 | self.stride = stride 11 | 12 | planes = expansion * in_planes 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride==1 and in_planes!=out_planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=1, 
stride=1, padding=0, bias=False), 24 | nn.BatchNorm2d(out_planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | out = out + self.shortcut(x) if self.stride==1 else out 32 | return out 33 | 34 | class _MobileNetV2(nn.Module): 35 | 36 | def __init__(self, cfg, num_classes): 37 | super(_MobileNetV2, self).__init__() 38 | self.cfg = cfg 39 | # note. change conv1 stride 2->1 for cifar10 40 | self.conv1 = nn.Conv2d(3,32,kernel_size=3, stride=1, padding=1, bias=False) 41 | self.bn1 = nn.BatchNorm2d(32) 42 | self.layers = self._make_layers(in_planes=32) 43 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 44 | self.bn2 = nn.BatchNorm2d(1280) 45 | self.fc1 = nn.Linear(1280, num_classes) 46 | 47 | def _make_layers(self, in_planes): 48 | layers = [] 49 | for expansion, out_planes, num_blocks, stride in self.cfg: 50 | strides = [stride] + [1]*(num_blocks-1) 51 | for stride in strides: 52 | layers.append(Block(in_planes, out_planes, expansion, stride)) 53 | in_planes = out_planes 54 | return nn.Sequential(*layers) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = self.layers(out) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | # note. change pooling kernel_size 7->4 for cifar10 61 | out = F.avg_pool2d(out, 4) 62 | out = out.view(out.size(0), -1) 63 | return self.fc1(out) 64 | 65 | def MobileNetV2(num_classes): 66 | #(expansion, out_planes, num_blocks, stride) 67 | cfg = [ 68 | (1, 16, 1, 1), 69 | (6, 24, 2, 2), 70 | (6, 32, 3, 1), 71 | (6, 64, 4, 2), 72 | (6, 96, 3, 1), 73 | (6,160, 3, 2), 74 | (6,320, 1, 1) 75 | ] 76 | return _MobileNetV2(cfg, num_classes) 77 | 78 | def test(): 79 | print('--- run mobilenet test ---') 80 | x = torch.randn(2,3,32,32) 81 | net = MobileNetV2(10) 82 | y = net(x) 83 | print(y.size()) 84 | -------------------------------------------------------------------------------- /architectures/preact_resnet.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | """Identity Mappings in Deep Residual Networks 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class PreActBlock(nn.Module): 9 | expansion = 1 10 | 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(PreActBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 17 | 18 | if stride!=1 or in_planes!=self.expansion*planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 21 | ) 22 | 23 | def forward(self, x): 24 | out = F.relu(self.bn1(x)) 25 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 26 | out = self.conv1(out) 27 | out = self.conv2(F.relu(self.bn2(out))) 28 | out += shortcut 29 | return out 30 | 31 | class PreActBottleneck(nn.Module): 32 | expansion = 4 33 | 34 | def __init__(self, in_planes, planes, stride=1): 35 | super(PreActBottleneck, self).__init__() 36 | self.bn1 = nn.BatchNorm2d(in_planes) 37 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 
padding=1, bias=False) 40 | self.bn3 = nn.BatchNorm2d(planes) 41 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 42 | 43 | if stride!=1 or in_planes!=self.expansion*planes: 44 | self.shortcut = nn.Sequential( 45 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 46 | ) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(x)) 50 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 51 | out = self.conv1(out) 52 | out = self.conv2(F.relu(self.bn2(out))) 53 | out = self.conv3(F.relu(self.bn3(out))) 54 | out += shortcut 55 | return out 56 | 57 | class PreActResNet(nn.Module): 58 | 59 | def __init__(self, block, num_blocks, num_classes): 60 | super(PreActResNet, self).__init__() 61 | self.in_planes = 64 62 | 63 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 64 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 65 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 66 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 67 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 68 | self.fc1 = nn.Linear(512*block.expansion, num_classes) 69 | 70 | def _make_layer(self, block, planes, num_blocks, stride): 71 | strides = [stride] + [1]*(num_blocks-1) 72 | layers = [] 73 | for stride in strides: 74 | layers.append(block(self.in_planes, planes, stride)) 75 | self.in_planes = planes * block.expansion 76 | return nn.Sequential(*layers) 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | out = self.layer1(out) 81 | out = self.layer2(out) 82 | out = self.layer3(out) 83 | out = self.layer4(out) 84 | out = F.avg_pool2d(out, 4) 85 | out = out.view(out.size(0),-1) 86 | return self.fc1(out) 87 | 88 | def PreActResNet18(num_classes): 89 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes) 90 | 91 | def PreActResNet34(num_classes): 92 | return PreActResNet(PreActBlock, [3,4,6,3], num_classes) 93 | 94 | def PreActResNet50(num_classes): 95 | return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes) 96 | 97 | def PreActResNet101(num_classes): 98 | return PreActResNet(PreActBottleneck, [3,4,23,3], num_classes) 99 | 100 | def PreActResNet152(num_classes): 101 | return PreActResNet(PreActBottleneck, [3,8,36,3], num_classes) 102 | 103 | def test(): 104 | print('--- run preact_resnet test') 105 | x = torch.randn(2,3,32,32) 106 | for net in [PreActResNet18(10), PreActResNet34(10), PreActResNet50(10), PreActResNet101(10), PreActResNet152(10)]: 107 | print(net) 108 | y = net(x) 109 | print(y.size()) 110 | -------------------------------------------------------------------------------- /architectures/resnet.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1): 10 | super(BasicBlock, self).__init__() 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | 16 | self.shortcut = nn.Sequential() 17 | if stride!=1 or in_planes!=self.expansion*planes: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, 
bias=False), 20 | nn.BatchNorm2d(self.expansion*planes) 21 | ) 22 | 23 | def forward(self, x): 24 | out = F.relu(self.bn1(self.conv1(x))) 25 | out = self.bn2(self.conv2(out)) 26 | out += self.shortcut(x) 27 | return F.relu(out) 28 | 29 | class Bottleneck(nn.Module): 30 | expansion = 4 31 | 32 | def __init__(self, in_planes, planes, stride=1): 33 | super(Bottleneck, self).__init__() 34 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 39 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 40 | 41 | self.shortcut = nn.Sequential() 42 | if stride!=1 or in_planes!=self.expansion*planes: 43 | self.shortcut = nn.Sequential( 44 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 45 | nn.BatchNorm2d(self.expansion*planes) 46 | ) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | out = F.relu(self.bn2(self.conv2(out))) 51 | out = self.bn3(self.conv3(out)) 52 | out += self.shortcut(x) 53 | return F.relu(out) 54 | 55 | class ResNet(nn.Module): 56 | 57 | def __init__(self, block, num_blocks, num_classes): 58 | super(ResNet, self).__init__() 59 | self.in_planes = 64 60 | 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(64) 63 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 64 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 65 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 66 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 67 | self.fc1 = nn.Linear(512*block.expansion, num_classes) 68 | 69 | def _make_layer(self, block, planes, num_blocks, stride): 70 | strides = [stride] + [1]*(num_blocks-1) 71 | layers = [] 72 | for stride in strides: 73 | layers.append(block(self.in_planes, planes, stride)) 74 | self.in_planes = planes * block.expansion 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = F.relu(self.bn1(self.conv1(x))) 79 | out = self.layer1(out) 80 | out = self.layer2(out) 81 | out = self.layer3(out) 82 | out = self.layer4(out) 83 | out = F.avg_pool2d(out, 4) 84 | out = out.view(out.size(0), -1) 85 | return self.fc1(out), out 86 | 87 | def ResNet18(num_classes): 88 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 89 | 90 | def ResNet34(num_classes): 91 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 92 | 93 | def ResNet50(num_classes): 94 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 95 | 96 | def ResNet101(num_classes): 97 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 98 | 99 | def ResNet152(num_classes): 100 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 101 | 102 | def test(): 103 | print('--- run resnet test ---') 104 | x = torch.randn(2,3,32,32) 105 | for net in [ResNet18(10), ResNet34(10), ResNet50(10), ResNet101(10), ResNet152(10)]: 106 | print(net) 107 | y = net(x) 108 | print(y.size()) 109 | -------------------------------------------------------------------------------- /architectures/resnext.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Block(nn.Module): 7 | """Grouped convolution block""" 8 | 
expansion = 2 9 | 10 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 11 | super(Block, self).__init__() 12 | group_width = cardinality * bottleneck_width 13 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(group_width) 15 | self.conv2 = nn.Conv2d(group_width,group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 16 | self.bn2 = nn.BatchNorm2d(group_width) 17 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride!=1 or in_planes!=self.expansion*group_width: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion*group_width) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | out += self.shortcut(x) 32 | return F.relu(out) 33 | 34 | class ResNeXt(nn.Module): 35 | 36 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 37 | super(ResNeXt, self).__init__() 38 | self.cardinality = cardinality 39 | self.bottleneck_width = bottleneck_width 40 | self.in_planes = 64 41 | 42 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(64) 44 | self.layer1 = self._make_layer(num_blocks[0], 1) 45 | self.layer2 = self._make_layer(num_blocks[1], 2) 46 | self.layer3 = self._make_layer(num_blocks[2], 2) 47 | self.fc1 = nn.Linear(cardinality*bottleneck_width*8, num_classes) 48 | 49 | def _make_layer(self, num_blocks, stride): 50 | strides = [stride]+[1]*(num_blocks-1) 51 | layers = [] 52 | for stride in strides: 53 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 54 | self.in_planes = Block.expansion*self.cardinality*self.bottleneck_width 55 | # Increase bottleneck_width by 2 after each stage. 
56 | self.bottleneck_width *= 2 57 | return nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = self.layer1(out) 62 | out = self.layer2(out) 63 | out = self.layer3(out) 64 | out = F.avg_pool2d(out, 8) 65 | out = out.view(out.size(0), -1) 66 | return self.fc1(out) 67 | 68 | def ResNeXt29_2x64d(num_classes): 69 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 70 | 71 | def ResNeXt29_4x64d(num_classes): 72 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 73 | 74 | def ResNeXt29_8x64d(num_classes): 75 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 76 | 77 | def ResNeXt29_32x4d(num_classes): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 79 | 80 | def test(): 81 | print('--- run resnext test ---') 82 | x = torch.randn(2,3,32,32) 83 | for net in [ResNeXt29_2x64d(10), ResNeXt29_4x64d(10), ResNeXt29_8x64d(10), ResNeXt29_32x4d(10)]: 84 | y = net(x) 85 | print(y.size()) 86 | -------------------------------------------------------------------------------- /architectures/senet.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class BasicBlock(nn.Module): 7 | 8 | def __init__(self, in_planes, planes, stride=1): 9 | super(BasicBlock, self).__init__() 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 11 | self.bn1 = nn.BatchNorm2d(planes) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 13 | self.bn2 = nn.BatchNorm2d(planes) 14 | 15 | self.shortcut = nn.Sequential() 16 | if stride!=1 or in_planes!=planes: 17 | self.shortcut = nn.Sequential( 18 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 19 | nn.BatchNorm2d(planes) 20 | ) 21 | 22 | # SE layers (Use nn.Conv2d instead of nn.Linear) 23 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 24 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 25 | 26 | def forward(self, x): 27 | out = F.relu(self.bn1(self.conv1(x))) 28 | out = self.bn2(self.conv2(out)) 29 | 30 | # Squeeze 31 | w = F.avg_pool2d(out, out.size(2)) 32 | w = F.relu(self.fc1(w)) 33 | #w = F.sigmoid(self.fc2(w)) # 0.4.0 34 | w = torch.sigmoid(self.fc2(w)) # 0.4.1.post2 35 | # Excitation 36 | out = out * w 37 | 38 | out += self.shortcut(x) 39 | return F.relu(out) 40 | 41 | class PreActBlock(nn.Module): 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(PreActBlock, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 49 | 50 | if stride!=1 or in_planes!=planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | # SE layers 56 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 57 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(x)) 61 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 62 | out = self.conv1(out) 63 | out = self.conv2(F.relu(self.bn2(out))) 64 | 65 | # Squeeze 66 | w = F.avg_pool2d(out, (out.size(2),)) 67 | w = F.relu(self.fc1(w)) 68 | #w = 
F.sigmoid(self.fc2(w)) # 0.4.0 69 | w = torch.sigmoid(self.fc2(w)) # 0.4.1.post2 70 | # Excitation 71 | out = out * w 72 | 73 | out += shortcut 74 | return out 75 | 76 | class SENet(nn.Module): 77 | def __init__(self, block, num_blocks, num_classes=10): 78 | super(SENet, self).__init__() 79 | self.in_planes = 64 80 | 81 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(64) 83 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 84 | self.layer2 = self._make_layer(block,128, num_blocks[1], stride=2) 85 | self.layer3 = self._make_layer(block,256, num_blocks[2], stride=2) 86 | self.layer4 = self._make_layer(block,512, num_blocks[3], stride=2) 87 | self.fc1 = nn.Linear(512, num_classes) 88 | 89 | def _make_layer(self, block, planes, num_blocks, stride): 90 | strides = [stride]+[1]*(num_blocks-1) 91 | layers = [] 92 | for stride in strides: 93 | layers.append(block(self.in_planes, planes, stride)) 94 | self.in_planes = planes 95 | return nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | out = F.relu(self.bn1(self.conv1(x))) 99 | out = self.layer1(out) 100 | out = self.layer2(out) 101 | out = self.layer3(out) 102 | out = self.layer4(out) 103 | out = F.avg_pool2d(out, 4) 104 | out = out.view(out.size(0), -1) 105 | return self.fc1(out) 106 | 107 | def SENet18(num_classes): 108 | return SENet(PreActBlock, [2,2,2,2]) 109 | 110 | def test(): 111 | print('--- run senet test ---') 112 | x = torch.randn(2,3,32,32) 113 | for net in [SENet18(10)]: 114 | print(net) 115 | y = net(x) 116 | print(y.size()) 117 | -------------------------------------------------------------------------------- /architectures/shufflenet.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class ShuffleBlock(nn.Module): 7 | 8 | def __init__(self, groups): 9 | super(ShuffleBlock, self).__init__() 10 | self.groups = groups 11 | 12 | def forward(self, x): 13 | '''Channel shuffle 14 | [N,C,H,W]->[N,g,C/g,H,W]->[N,C/g,g,H,W]->[N,C,H,W] 15 | ''' 16 | N,C,H,W = x.size() 17 | g = self.groups 18 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W) 19 | 20 | class Bottleneck(nn.Module): 21 | 22 | def __init__(self, in_planes, out_planes, stride, groups): 23 | super(Bottleneck, self).__init__() 24 | self.stride = stride 25 | 26 | mid_planes = out_planes//4 27 | g = 1 if in_planes==24 else groups 28 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 29 | self.bn1 = nn.BatchNorm2d(mid_planes) 30 | self.shuffle1 = ShuffleBlock(groups=g) 31 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 32 | self.bn2 = nn.BatchNorm2d(mid_planes) 33 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 34 | self.bn3 = nn.BatchNorm2d(out_planes) 35 | 36 | self.shortcut = nn.Sequential() 37 | if stride == 2: 38 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 39 | 40 | def forward(self, x): 41 | out = F.relu(self.bn1(self.conv1(x))) 42 | out = self.shuffle1(out) 43 | out = F.relu(self.bn2(self.conv2(out))) 44 | out = self.bn3(self.conv3(out)) 45 | res = self.shortcut(x) 46 | out = F.relu(torch.cat([out,res],1)) if self.stride==2 else F.relu(out+res) 47 | return out 48 | 49 | class ShuffleNet(nn.Module): 50 | 51 | def __init__(self, cfg, num_classes): 52 | 
super(ShuffleNet, self).__init__() 53 | out_planes = cfg['out_planes'] 54 | num_blocks = cfg['num_blocks'] 55 | groups = cfg['groups'] 56 | 57 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 58 | self.bn1 = nn.BatchNorm2d(24) 59 | self.in_planes = 24 60 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 61 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 62 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 63 | self.fc1 = nn.Linear(out_planes[2], num_classes) 64 | 65 | def _make_layer(self, out_planes, num_blocks, groups): 66 | layers = [] 67 | for i in range(num_blocks): 68 | stride = 2 if i==0 else 1 69 | cat_planes = self.in_planes if i==0 else 0 70 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 71 | self.in_planes = out_planes 72 | return nn.Sequential(*layers) 73 | 74 | def forward(self, x): 75 | out = F.relu(self.bn1(self.conv1(x))) 76 | out = self.layer1(out) 77 | out = self.layer2(out) 78 | out = self.layer3(out) 79 | out = F.avg_pool2d(out, 4) 80 | out = out.view(out.size(0), -1) 81 | return self.fc1(out) 82 | 83 | def ShuffleNetG2(num_classes): 84 | cfg = { 85 | 'out_planes': [200,400,800], 86 | 'num_blocks': [4,8,4], 87 | 'groups': 2, 88 | } 89 | return ShuffleNet(cfg, num_classes) 90 | 91 | def ShuffleNetG3(num_classes): 92 | cfg = { 93 | 'out_planes': [240,480,960], 94 | 'num_blocks': [4,8,4], 95 | 'groups': 3, 96 | } 97 | return ShuffleNet(cfg, num_classes) 98 | 99 | def test(): 100 | print('--- run shufflenet test ---') 101 | x = torch.randn(2,3,32,32) 102 | for net in [ShuffleNetG2(10), ShuffleNetG3(10)]: 103 | y = net(x) 104 | print(y.size()) 105 | -------------------------------------------------------------------------------- /architectures/vgg.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | 5 | cfg = { 6 | 'vgg11':[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 7 | 'vgg13':[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'vgg16':[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 9 | 'vgg19':[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 10 | } 11 | 12 | def VGG11(num_classes): 13 | return VGG('vgg11', num_classes) 14 | 15 | def VGG13(num_classes): 16 | return VGG('vgg13', num_classes) 17 | 18 | def VGG16(num_classes): 19 | return VGG('vgg16', num_classes) 20 | 21 | def VGG19(num_classes): 22 | return VGG('vgg19', num_classes) 23 | 24 | class VGG(nn.Module): 25 | 26 | def __init__(self, vgg_name, num_classes): 27 | super(VGG, self).__init__() 28 | self.features = self._make_layers(cfg[vgg_name]) 29 | self.fc1 = nn.Linear(512, num_classes) 30 | 31 | def forward(self, x): 32 | out = self.features(x) 33 | out = out.view(out.size(0), -1) 34 | return self.fc1(out) 35 | 36 | def _make_layers(self, cfg): 37 | layers = [] 38 | in_channels = 3 39 | for x in cfg: 40 | if x == 'M': 41 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 42 | else: 43 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 44 | nn.BatchNorm2d(x), 45 | nn.ReLU(inplace=True)] 46 | in_channels = x 47 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 48 | return nn.Sequential(*layers) 49 | 50 | def test(): 51 | print('--- run vgg test ---') 52 | x = torch.randn(2,3,32,32) 53 | for net in [VGG11(10), VGG13(10), VGG16(10), 
VGG19(10)]: 54 | print(net) 55 | y = net(x) 56 | print(y.size()) 57 | 58 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import os 3 | import torch 4 | from torch import optim 5 | from torch.optim import lr_scheduler 6 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 7 | 8 | from utils import datasets 9 | from utils.ramps import exp_warmup 10 | from utils.config import parse_commandline_args 11 | from utils.data_utils import DataSetWarpper 12 | from utils.data_utils import TwoStreamBatchSampler 13 | from utils.data_utils import TransformTwice as twice 14 | from architectures.arch import arch 15 | 16 | from trainer import * 17 | build_model = { 18 | 'mtv1': MeanTeacherv1.Trainer, 19 | 'mtv2': MeanTeacherv2.Trainer, 20 | 'piv1': PIv1.Trainer, 21 | 'piv2': PIv2.Trainer, 22 | 'vatv1': VATv1.Trainer, 23 | 'vatv2': VATv2.Trainer, 24 | 'epslab2013v1': ePseudoLabel2013v1.Trainer, 25 | 'epslab2013v2': ePseudoLabel2013v2.Trainer, 26 | 'ipslab2013v1': iPseudoLabel2013v1.Trainer, 27 | 'ipslab2013v2': iPseudoLabel2013v2.Trainer, 28 | 'etempensv1': eTempensv1.Trainer, 29 | 'etempensv2': eTempensv2.Trainer, 30 | 'itempensv1': iTempensv1.Trainer, 31 | 'itempensv2': iTempensv2.Trainer, 32 | 'ictv1': ICTv1.Trainer, 33 | 'ictv2': ICTv2.Trainer, 34 | 'mixmatch': MixMatch.Trainer, 35 | 'ifixmatch': iFixMatch.Trainer, 36 | 'efixmatch': eFixMatch.Trainer, 37 | 'emixpslabv1': eMixPseudoLabelv1.Trainer, 38 | 'emixpslabv2': eMixPseudoLabelv2.Trainer, 39 | } 40 | 41 | def create_loaders_v1(trainset, evalset, label_idxs, unlab_idxs, 42 | num_classes, 43 | config): 44 | if config.data_twice: trainset.transform = twice(trainset.transform) 45 | if config.data_idxs: trainset = DataSetWarpper(trainset, num_classes) 46 | ## two-stream batch loader 47 | batch_size = config.sup_batch_size + config.usp_batch_size 48 | batch_sampler = TwoStreamBatchSampler( 49 | unlab_idxs, label_idxs, batch_size, config.sup_batch_size) 50 | train_loader = torch.utils.data.DataLoader(trainset, 51 | batch_sampler=batch_sampler, 52 | num_workers=config.workers, 53 | pin_memory=True) 54 | ## test batch loader 55 | eval_loader = torch.utils.data.DataLoader(evalset, 56 | batch_size=batch_size, 57 | shuffle=False, 58 | num_workers=2*config.workers, 59 | pin_memory=True, 60 | drop_last=False) 61 | return train_loader, eval_loader 62 | 63 | 64 | def create_loaders_v2(trainset, evalset, label_idxs, unlab_idxs, 65 | num_classes, 66 | config): 67 | if config.data_twice: trainset.transform = twice(trainset.transform) 68 | if config.data_idxs: trainset = DataSetWarpper(trainset, num_classes) 69 | ## supervised batch loader 70 | label_sampler = SubsetRandomSampler(label_idxs) 71 | label_batch_sampler = BatchSampler(label_sampler, config.sup_batch_size, 72 | drop_last=True) 73 | label_loader = torch.utils.data.DataLoader(trainset, 74 | batch_sampler=label_batch_sampler, 75 | num_workers=config.workers, 76 | pin_memory=True) 77 | ## unsupervised batch loader 78 | if not config.label_exclude: unlab_idxs += label_idxs 79 | unlab_sampler = SubsetRandomSampler(unlab_idxs) 80 | unlab_batch_sampler = BatchSampler(unlab_sampler, config.usp_batch_size, 81 | drop_last=True) 82 | unlab_loader = torch.utils.data.DataLoader(trainset, 83 | batch_sampler=unlab_batch_sampler, 84 | num_workers=config.workers, 85 | pin_memory=True) 86 | ## test batch loader 87 | eval_loader = 
torch.utils.data.DataLoader(evalset, 88 | batch_size=config.sup_batch_size, 89 | shuffle=False, 90 | num_workers=2*config.workers, 91 | pin_memory=True, 92 | drop_last=False) 93 | return label_loader, unlab_loader, eval_loader 94 | 95 | 96 | def create_optim(params, config): 97 | if config.optim == 'sgd': 98 | optimizer = optim.SGD(params, config.lr, 99 | momentum=config.momentum, 100 | weight_decay=config.weight_decay, 101 | nesterov=config.nesterov) 102 | elif config.optim == 'adam': 103 | optimizer = optim.Adam(params, config.lr) 104 | return optimizer 105 | 106 | 107 | def create_lr_scheduler(optimizer, config): 108 | if config.lr_scheduler == 'cos': 109 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 110 | T_max=config.epochs, 111 | eta_min=config.min_lr) 112 | elif config.lr_scheduler == 'multistep': 113 | if config.steps is None: return None 114 | if isinstance(config.steps, int): config.steps = [config.steps] 115 | scheduler = lr_scheduler.MultiStepLR(optimizer, 116 | milestones=config.steps, 117 | gamma=config.gamma) 118 | elif config.lr_scheduler == 'exp-warmup': 119 | lr_lambda = exp_warmup(config.rampup_length, 120 | config.rampdown_length, 121 | config.epochs) 122 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 123 | elif config.lr_scheduler == 'none': 124 | scheduler = None 125 | else: 126 | raise ValueError("No such scheduler: {}".format(config.lr_scheduler)) 127 | return scheduler 128 | 129 | 130 | def run(config): 131 | print(config) 132 | print("pytorch version : {}".format(torch.__version__)) 133 | ## create save directory 134 | if config.save_freq!=0 and not os.path.exists(config.save_dir): 135 | os.makedirs(config.save_dir) 136 | ## prepare data 137 | dconfig = datasets.load[config.dataset](config.num_labels) 138 | if config.model[-1]=='1': 139 | loaders = create_loaders_v1(**dconfig, config=config) 140 | elif config.model[-1]=='2' or config.model[-5:]=='match': 141 | loaders = create_loaders_v2(**dconfig, config=config) 142 | else: 143 | raise ValueError('No such model: {}'.format(config.model)) 144 | 145 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 146 | ## prepare architecture 147 | net = arch[config.arch](dconfig['num_classes'], config.drop_ratio) 148 | net = net.to(device) 149 | optimizer = create_optim(net.parameters(), config) 150 | scheduler = create_lr_scheduler(optimizer, config) 151 | 152 | ## run the model 153 | MTbased = set(['mt', 'ict']) 154 | if config.model[:-2] in MTbased or config.model[-5:]=='match': 155 | net2 = arch[config.arch](dconfig['num_classes'], config.drop_ratio) 156 | net2 = net2.to(device) 157 | trainer = build_model[config.model](net, net2, optimizer, device, config) 158 | else: 159 | trainer = build_model[config.model](net, optimizer, device, config) 160 | trainer.loop(config.epochs, *loaders, scheduler=scheduler) 161 | 162 | 163 | if __name__ == '__main__': 164 | config = parse_commandline_args() 165 | run(config) 166 | -------------------------------------------------------------------------------- /trainer/ICTv1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import mse_with_softmax 11 | from utils.loss import softmax_loss_mean 12 | from utils.mixup import * 13 | from utils.ramps import exp_rampup 14 | from utils.datasets import decode_label 15 | from 
utils.data_utils import NO_LABEL 16 | 17 | class Trainer: 18 | 19 | def __init__(self, model, ema_model, optimizer, device, config): 20 | print("ICT-v1") 21 | self.model = model 22 | self.ema_model = ema_model 23 | self.optimizer = optimizer 24 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 25 | self.mixup_loss = mixup_ce_loss_with_softmax #mixup_mse_loss_with_softmax 26 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 27 | config.dataset, config.num_labels, 28 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 29 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 30 | self.global_step = 0 31 | self.epoch = 0 32 | self.alpha = config.mixup_alpha 33 | self.usp_weight = config.usp_weight 34 | self.ema_decay = config.ema_decay 35 | self.rampup = exp_rampup(config.weight_rampup) 36 | self.device = device 37 | self.save_freq = config.save_freq 38 | self.print_freq = config.print_freq 39 | 40 | def train_iteration(self, data_loader, print_freq): 41 | loop_info = defaultdict(list) 42 | label_n, unlab_n = 0, 0 43 | for batch_idx, ((x1, x2), targets) in enumerate(data_loader): 44 | self.global_step += 1 45 | x1, x2, targets = [t.to(self.device) for t in (x1,x2,targets)] 46 | ##=== decode targets === 47 | lmask, umask = self.decode_targets(targets) 48 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 49 | 50 | ##=== forward === 51 | outputs = self.model(x1) 52 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 53 | loop_info['lloss'].append(loss.item()) 54 | 55 | ##=== Semi-supervised Training === 56 | ## update mean-teacher 57 | self.update_ema(self.model, self.ema_model, self.ema_decay, self.global_step) 58 | with torch.no_grad(): 59 | ema_outputs = self.ema_model(x2) # x2 seems better 60 | ema_outputs = ema_outputs.detach() 61 | ## mixup consistency loss 62 | # mixup-loss-v1 63 | #mixed_x, mixed_y, lam = mixup_one_target(x1, ema_outputs, self.alpha, self.device) 64 | #mixed_outputs = self.model(mixed_x) 65 | ##mix_loss = softmax_loss_mean(mixed_outputs, mixed_y) 66 | #mix_loss = mse_with_softmax(mixed_outputs, mixed_y) 67 | 68 | # mixup-loss-v2 69 | mixed_x, y_a, y_b, lam = mixup_two_targets(x1, ema_outputs, self.alpha, 70 | self.device, is_bias=False) 71 | mixed_outputs = self.model(mixed_x) 72 | mix_loss = self.mixup_loss(mixed_outputs, y_a, y_b, lam) 73 | 74 | mix_loss *= self.rampup(self.epoch)*self.usp_weight 75 | loss += mix_loss; loop_info['aMix'].append(mix_loss.item()) 76 | 77 | ## backwark 78 | self.optimizer.zero_grad() 79 | loss.backward() 80 | self.optimizer.step() 81 | 82 | ##=== log info === 83 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 84 | loop_info['lacc'].append(targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item()) 85 | loop_info['uacc'].append(targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item()) 86 | loop_info['u2acc'].append(targets[umask].eq(ema_outputs[umask].max(1)[1]).float().sum().item()) 87 | if print_freq>0 and (batch_idx%print_freq)==0: 88 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 89 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 90 | return loop_info, label_n 91 | 92 | def test_iteration(self, data_loader, print_freq): 93 | loop_info = defaultdict(list) 94 | label_n, unlab_n = 0, 0 95 | for batch_idx, (data, targets) in enumerate(data_loader): 96 | data, targets = data.to(self.device), targets.to(self.device) 97 | lbs, ubs = data.size(0), -1 98 | 99 | ##=== forward === 100 | outputs = self.model(data) 101 | 
ema_outputs = self.ema_model(data) 102 | loss = self.ce_loss(outputs, targets) 103 | loop_info['lloss'].append(loss.item()) 104 | 105 | ##=== log info === 106 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 107 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 108 | loop_info['l2acc'].append(targets.eq(ema_outputs.max(1)[1]).float().sum().item()) 109 | if print_freq>0 and (batch_idx%print_freq)==0: 110 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 111 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 112 | return loop_info, label_n 113 | 114 | def train(self, data_loader, print_freq=20): 115 | self.model.train() 116 | self.ema_model.train() 117 | with torch.enable_grad(): 118 | return self.train_iteration(data_loader, print_freq) 119 | 120 | def test(self, data_loader, print_freq=10): 121 | self.model.eval() 122 | self.ema_model.eval() 123 | with torch.no_grad(): 124 | return self.test_iteration(data_loader, print_freq) 125 | 126 | def loop(self, epochs, train_data, test_data, scheduler=None): 127 | best_info, best_acc, n = None, 0., 0 128 | for ep in range(epochs): 129 | self.epoch = ep 130 | if scheduler is not None: scheduler.step() 131 | print("------ Training epochs: {} ------".format(ep)) 132 | self.train(train_data, self.print_freq) 133 | print("------ Testing epochs: {} ------".format(ep)) 134 | info, n = self.test(test_data, self.print_freq) 135 | acc = sum(info['lacc']) / n 136 | if acc>best_acc: best_info, best_acc = info, acc 137 | ## save model 138 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 139 | self.save(ep) 140 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 141 | 142 | def update_ema(self, model, ema_model, alpha, global_step): 143 | alpha = min(1 - 1 / (global_step +1), alpha) 144 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 145 | ema_param.data.mul_(alpha).add_(1-alpha, param.data) 146 | 147 | def decode_targets(self, targets): 148 | label_mask = targets.ge(0) 149 | unlab_mask = targets.le(NO_LABEL) 150 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 151 | return label_mask, unlab_mask 152 | 153 | def gen_info(self, info, lbs, ubs, iteration=True): 154 | ret = [] 155 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 156 | for k, val in info.items(): 157 | n = nums[k[0]] 158 | v = val[-1] if iteration else sum(val) 159 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 160 | ret.append(s) 161 | return '\t'.join(ret) 162 | 163 | def save(self, epoch, **kwargs): 164 | if self.save_dir is not None: 165 | model_out_path = Path(self.save_dir) 166 | state = {"epoch": epoch, 167 | "weight": self.model.state_dict()} 168 | if not model_out_path.exists(): 169 | model_out_path.mkdir() 170 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 171 | torch.save(state, save_target) 172 | print('==> save model to {}'.format(save_target)) 173 | -------------------------------------------------------------------------------- /trainer/ICTv2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | from itertools import cycle 10 | 11 | from utils.loss import mse_with_softmax 12 | from utils.loss import softmax_loss_mean 13 | from utils.mixup import * 14 | from utils.ramps import exp_rampup 15 | from utils.datasets import 
decode_label 16 | from utils.data_utils import NO_LABEL 17 | 18 | from pdb import set_trace 19 | 20 | class Trainer: 21 | 22 | def __init__(self, model, ema_model, optimizer, device, config): 23 | print("ICT-v2") 24 | self.model = model 25 | self.ema_model = ema_model 26 | self.optimizer = optimizer 27 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 28 | self.mixup_loss = mixup_ce_loss_with_softmax #mixup_mse_loss_with_softmax 29 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 30 | config.dataset, config.num_labels, 31 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 32 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 33 | self.global_step = 0 34 | self.epoch = 0 35 | self.alpha = config.mixup_alpha 36 | self.usp_weight = config.usp_weight 37 | self.ema_decay = config.ema_decay 38 | self.rampup = exp_rampup(config.weight_rampup) 39 | self.device = device 40 | self.save_freq = config.save_freq 41 | self.print_freq = config.print_freq 42 | 43 | def train_iteration(self, label_loader, unlab_loader, print_freq): 44 | loop_info = defaultdict(list) 45 | batch_idx, label_n, unlab_n = 0, 0, 0 46 | for (label_x, label_y), (unlab_x, unlab_y) in zip(cycle(label_loader), unlab_loader): 47 | self.global_step += 1; batch_idx+=1; 48 | label_x, label_y = label_x.to(self.device), label_y.to(self.device) 49 | unlab_x, unlab_y = unlab_x.to(self.device), unlab_y.to(self.device) 50 | ##=== decode targets of unlabeled data === 51 | self.decode_targets(unlab_y) 52 | lbs, ubs = label_x.size(0), unlab_x.size(0) 53 | 54 | ##=== forward === 55 | outputs = self.model(label_x) 56 | loss = self.ce_loss(outputs, label_y) 57 | loop_info['lSup'].append(loss.item()) 58 | 59 | ##=== Semi-supervised Training === 60 | ## update mean-teacher 61 | self.update_ema(self.model, self.ema_model, self.ema_decay, self.global_step) 62 | with torch.no_grad(): 63 | ema_outputs_u = self.ema_model(unlab_x) 64 | 65 | ## mixup-consistency loss 66 | # mixup-loss-v1 67 | #mixed_ux, mixed_uy, lam = mixup_one_target(unlab_x, ema_outputs_u, 68 | # self.alpha, self.device) 69 | #mixed_outputs_u = self.model(mixed_ux) 70 | ##mix_loss = softmax_loss_mean(mixed_outputs_u, mixed_uy) 71 | #mix_loss = mse_with_softmax(mixed_outputs_u, mixed_uy) 72 | 73 | # mixup-loss-v2 74 | mixed_ux, uy_a, uy_b, lam = mixup_two_targets(unlab_x, ema_outputs_u, 75 | self.alpha, self.device, is_bias=False) 76 | mixed_outputs_u = self.model(mixed_ux) 77 | mix_loss = self.mixup_loss(mixed_outputs_u, uy_a, uy_b, lam) 78 | 79 | mix_loss *= self.rampup(self.epoch)*self.usp_weight 80 | loss += mix_loss; loop_info['uMix'].append(mix_loss.item()) 81 | 82 | ## backwark 83 | self.optimizer.zero_grad() 84 | loss.backward() 85 | self.optimizer.step() 86 | 87 | ##=== log info === 88 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 89 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 90 | loop_info['u2acc'].append(unlab_y.eq(ema_outputs_u.max(1)[1]).float().sum().item()) 91 | if print_freq>0 and (batch_idx%print_freq)==0: 92 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 93 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 94 | return loop_info, label_n 95 | 96 | def test_iteration(self, data_loader, print_freq): 97 | loop_info = defaultdict(list) 98 | label_n, unlab_n = 0, 0 99 | for batch_idx, (data, targets) in enumerate(data_loader): 100 | data, targets = data.to(self.device), targets.to(self.device) 101 | lbs, ubs = data.size(0), -1 102 | 103 | ##=== 
forward === 104 | outputs = self.model(data) 105 | loss = self.ce_loss(outputs, targets) 106 | loop_info['lSup'].append(loss.item()) 107 | 108 | with torch.no_grad(): 109 | ema_outputs = self.ema_model(data) 110 | 111 | ##=== log info === 112 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 113 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 114 | loop_info['l2acc'].append(targets.eq(ema_outputs.max(1)[1]).float().sum().item()) 115 | if print_freq>0 and (batch_idx%print_freq)==0: 116 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 117 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 118 | return loop_info, label_n 119 | 120 | def train(self, label_loader, unlab_loader, print_freq=20): 121 | self.model.train() 122 | self.ema_model.train() 123 | with torch.enable_grad(): 124 | return self.train_iteration(label_loader, unlab_loader, print_freq) 125 | 126 | def test(self, data_loader, print_freq=10): 127 | self.model.eval() 128 | self.ema_model.eval() 129 | with torch.no_grad(): 130 | return self.test_iteration(data_loader, print_freq) 131 | 132 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 133 | best_acc, n, best_info = 0., 0., None 134 | for ep in range(epochs): 135 | self.epoch = ep 136 | if scheduler is not None: scheduler.step() 137 | print("------ Training epochs: {} ------".format(ep)) 138 | self.train(label_data, unlab_data, self.print_freq) 139 | print("------ Testing epochs: {} ------".format(ep)) 140 | info, n = self.test(test_data, self.print_freq) 141 | acc = sum(info['lacc']) / n 142 | if acc>best_acc: best_acc, best_info = acc, info 143 | ## save model 144 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 145 | self.save(ep) 146 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 147 | 148 | def update_ema(self, model, ema_model, alpha, global_step): 149 | alpha = min(1 - 1 / (global_step +1), alpha) 150 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 151 | ema_param.data.mul_(alpha).add_(1-alpha, param.data) 152 | 153 | def decode_targets(self, targets): 154 | label_mask = targets.ge(0) 155 | unlab_mask = targets.le(NO_LABEL) 156 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 157 | return label_mask, unlab_mask 158 | 159 | def gen_info(self, info, lbs, ubs, iteration=True): 160 | ret = [] 161 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 162 | for k, val in info.items(): 163 | n = nums[k[0]] 164 | v = val[-1] if iteration else sum(val) 165 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 166 | ret.append(s) 167 | return '\t'.join(ret) 168 | 169 | def save(self, epoch, **kwargs): 170 | if self.save_dir is not None: 171 | model_out_path = Path(self.save_dir) 172 | state = {"epoch": epoch, 173 | "weight": self.model.state_dict()} 174 | if not model_out_path.exists(): 175 | model_out_path.mkdir() 176 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 177 | torch.save(state, save_target) 178 | print('==> save model to {}'.format(save_target)) 179 | -------------------------------------------------------------------------------- /trainer/MeanTeacherv1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import mse_with_softmax 11 | from utils.ramps import exp_rampup 12 
| from utils.datasets import decode_label 13 | from utils.data_utils import NO_LABEL 14 | 15 | from pdb import set_trace 16 | 17 | class Trainer: 18 | 19 | def __init__(self, model, ema_model, optimizer, device, config): 20 | print("MeanTeacher-v1") 21 | self.model = model 22 | self.ema_model = ema_model 23 | self.optimizer = optimizer 24 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 25 | self.cons_loss = mse_with_softmax #F.mse_loss 26 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 27 | config.dataset, config.num_labels, 28 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 29 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 30 | self.usp_weight = config.usp_weight 31 | self.ema_decay = config.ema_decay 32 | self.rampup = exp_rampup(config.weight_rampup) 33 | self.save_freq = config.save_freq 34 | self.print_freq = config.print_freq 35 | self.device = device 36 | self.global_step = 0 37 | self.epoch = 0 38 | 39 | def train_iteration(self, data_loader, print_freq): 40 | loop_info = defaultdict(list) 41 | label_n, unlab_n = 0, 0 42 | for batch_idx, ((x1, x2), targets) in enumerate(data_loader): 43 | self.global_step += 1 44 | x1, x2, targets = [t.to(self.device) for t in (x1,x2,targets)] 45 | ##=== decode targets === 46 | lmask, umask = self.decode_targets(targets) 47 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 48 | 49 | ##=== forward === 50 | outputs = self.model(x1) 51 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 52 | loop_info['lloss'].append(loss.item()) 53 | 54 | ##=== Semi-supervised Training === 55 | self.update_ema(self.model, self.ema_model, self.ema_decay, self.global_step) 56 | ## consistency loss 57 | with torch.no_grad(): 58 | ema_outputs = self.ema_model(x2) 59 | ema_outputs = ema_outputs.detach() 60 | cons_loss = self.cons_loss(outputs, ema_outputs) 61 | cons_loss *= self.rampup(self.epoch)*self.usp_weight 62 | loss += cons_loss; loop_info['aCons'].append(cons_loss.item()) 63 | 64 | ## backward 65 | self.optimizer.zero_grad() 66 | loss.backward() 67 | self.optimizer.step() 68 | 69 | ##=== log info === 70 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 71 | loop_info['lacc'].append(targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item()) 72 | loop_info['uacc'].append(targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item()) 73 | loop_info['u2acc'].append(targets[umask].eq(ema_outputs[umask].max(1)[1]).float().sum().item()) 74 | if print_freq>0 and (batch_idx%print_freq)==0: 75 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 76 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 77 | return loop_info, label_n 78 | 79 | def test_iteration(self, data_loader, print_freq): 80 | loop_info = defaultdict(list) 81 | label_n, unlab_n = 0, 0 82 | for batch_idx, (data, targets) in enumerate(data_loader): 83 | data, targets = data.to(self.device), targets.to(self.device) 84 | lbs, ubs = data.size(0), -1 85 | 86 | ##=== forward === 87 | outputs = self.model(data) 88 | ema_outputs = self.ema_model(data) 89 | loss = self.ce_loss(outputs, targets) 90 | loop_info['lloss'].append(loss.item()) 91 | 92 | ##=== log info === 93 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 94 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 95 | loop_info['l2acc'].append(targets.eq(ema_outputs.max(1)[1]).float().sum().item()) 96 | if print_freq>0 and (batch_idx%print_freq)==0: 97 | print(f"[test][{batch_idx:<3}]", 
self.gen_info(loop_info, lbs, ubs)) 98 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 99 | return loop_info, label_n 100 | 101 | def train(self, data_loader, print_freq=20): 102 | self.model.train() 103 | self.ema_model.train() 104 | with torch.enable_grad(): 105 | return self.train_iteration(data_loader, print_freq) 106 | 107 | def test(self, data_loader, print_freq=10): 108 | self.model.eval() 109 | self.ema_model.eval() 110 | with torch.no_grad(): 111 | return self.test_iteration(data_loader, print_freq) 112 | 113 | def loop(self, epochs, train_data, test_data, scheduler=None): 114 | best_info, best_acc, n = None, 0., 0 115 | for ep in range(epochs): 116 | self.epoch = ep 117 | if scheduler is not None: scheduler.step() 118 | print("------ Training epochs: {} ------".format(ep)) 119 | self.train(train_data, self.print_freq) 120 | print("------ Testing epochs: {} ------".format(ep)) 121 | info, n = self.test(test_data, self.print_freq) 122 | acc = sum(info['lacc']) / n 123 | if acc>best_acc: best_info, best_acc = info, acc 124 | ## save model 125 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 126 | self.save(ep) 127 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 128 | 129 | def update_ema(self, model, ema_model, alpha, global_step): 130 | alpha = min(1 - 1 / (global_step +1), alpha) 131 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 132 | ema_param.data.mul_(alpha).add_(1-alpha, param.data) 133 | 134 | def decode_targets(self, targets): 135 | label_mask = targets.ge(0) 136 | unlab_mask = targets.le(NO_LABEL) 137 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 138 | return label_mask, unlab_mask 139 | 140 | def gen_info(self, info, lbs, ubs, iteration=True): 141 | ret = [] 142 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 143 | for k, val in info.items(): 144 | n = nums[k[0]] 145 | v = val[-1] if iteration else sum(val) 146 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 147 | ret.append(s) 148 | return '\t'.join(ret) 149 | 150 | def save(self, epoch, **kwargs): 151 | if self.save_dir is not None: 152 | model_out_path = Path(self.save_dir) 153 | state = {"epoch": epoch, 154 | "weight": self.model.state_dict()} 155 | if not model_out_path.exists(): 156 | model_out_path.mkdir() 157 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 158 | torch.save(state, save_target) 159 | print('==> save model to {}'.format(save_target)) 160 | -------------------------------------------------------------------------------- /trainer/MeanTeacherv2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from itertools import cycle 9 | from collections import defaultdict 10 | 11 | from utils.loss import mse_with_softmax 12 | from utils.ramps import exp_rampup 13 | from utils.datasets import decode_label 14 | from utils.data_utils import NO_LABEL 15 | 16 | from pdb import set_trace 17 | 18 | class Trainer: 19 | 20 | def __init__(self, model, ema_model, optimizer, device, config): 21 | print("MeanTeacher-v2") 22 | self.model = model 23 | self.ema_model = ema_model 24 | self.optimizer = optimizer 25 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 26 | self.cons_loss = mse_with_softmax #F.mse_loss 27 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 28 | config.dataset, config.num_labels, 29 | 
datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 30 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 31 | self.usp_weight = config.usp_weight 32 | self.ema_decay = config.ema_decay 33 | self.rampup = exp_rampup(config.weight_rampup) 34 | self.save_freq = config.save_freq 35 | self.print_freq = config.print_freq 36 | self.device = device 37 | self.global_step = 0 38 | self.epoch = 0 39 | 40 | def train_iteration(self, label_loader, unlab_loader, print_freq): 41 | loop_info = defaultdict(list) 42 | batch_idx, label_n, unlab_n = 0, 0, 0 43 | for ((x1,_), label_y), ((u1,u2), unlab_y) in zip(cycle(label_loader), unlab_loader): 44 | self.global_step += 1 45 | label_x, unlab_x1, unlab_x2 = x1.to(self.device), u1.to(self.device), u2.to(self.device) 46 | label_y, unlab_y = label_y.to(self.device), unlab_y.to(self.device) 47 | ##=== decode targets === 48 | self.decode_targets(unlab_y) 49 | lbs, ubs = x1.size(0), u1.size(0) 50 | 51 | ##=== forward === 52 | outputs = self.model(label_x) 53 | loss = self.ce_loss(outputs, label_y) 54 | loop_info['lloss'].append(loss.item()) 55 | 56 | 57 | ##=== Semi-supervised Training === 58 | ## update mean-teacher 59 | self.update_ema(self.model, self.ema_model, self.ema_decay, self.global_step) 60 | ## consistency loss 61 | unlab_outputs = self.model(unlab_x1) 62 | with torch.no_grad(): 63 | ema_outputs = self.ema_model(unlab_x2) 64 | ema_outputs = ema_outputs.detach() 65 | cons_loss = self.cons_loss(unlab_outputs, ema_outputs) 66 | cons_loss *= self.rampup(self.epoch)*self.usp_weight 67 | loss += cons_loss; loop_info['uloss'].append(cons_loss.item()) 68 | 69 | ## backwark 70 | self.optimizer.zero_grad() 71 | loss.backward() 72 | self.optimizer.step() 73 | 74 | ##=== log info === 75 | batch_idx, label_n, unlab_n = batch_idx+1, label_n+lbs, unlab_n+ubs 76 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 77 | loop_info['uacc'].append(unlab_y.eq(unlab_outputs.max(1)[1]).float().sum().item()) 78 | loop_info['u2acc'].append(unlab_y.eq(ema_outputs.max(1)[1]).float().sum().item()) 79 | if print_freq>0 and (batch_idx%print_freq)==0: 80 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 81 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 82 | return loop_info, label_n 83 | 84 | def test_iteration(self, data_loader, print_freq): 85 | loop_info = defaultdict(list) 86 | label_n, unlab_n = 0, 0 87 | for batch_idx, (data, targets) in enumerate(data_loader): 88 | data, targets = data.to(self.device), targets.to(self.device) 89 | lbs, ubs = data.size(0), -1 90 | 91 | ##=== forward === 92 | outputs = self.model(data) 93 | ema_outputs = self.ema_model(data) 94 | loss = self.ce_loss(outputs, targets) 95 | loop_info['lloss'].append(loss.item()) 96 | 97 | ##=== log info === 98 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 99 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 100 | loop_info['l2acc'].append(targets.eq(ema_outputs.max(1)[1]).float().sum().item()) 101 | if print_freq>0 and (batch_idx%print_freq)==0: 102 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 103 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 104 | return loop_info, label_n 105 | 106 | 107 | def train(self, label_loader, unlab_loader, print_freq=20): 108 | self.model.train() 109 | self.ema_model.train() 110 | with torch.enable_grad(): 111 | return self.train_iteration(label_loader, unlab_loader, print_freq) 112 | 113 | def test(self, data_loader, 
print_freq=10): 114 | self.model.eval() 115 | self.ema_model.eval() 116 | with torch.no_grad(): 117 | return self.test_iteration(data_loader, print_freq) 118 | 119 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 120 | best_acc, n, best_info = 0., 0., None 121 | for ep in range(epochs): 122 | self.epoch = ep 123 | if scheduler is not None: scheduler.step() 124 | print("------ Training epochs: {} ------".format(ep)) 125 | self.train(label_data, unlab_data, self.print_freq) 126 | print("------ Testing epochs: {} ------".format(ep)) 127 | info, n = self.test(test_data, self.print_freq) 128 | acc = sum(info['lacc'])/n 129 | if acc>best_acc: best_acc, best_info = acc, info 130 | ## save model 131 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 132 | self.save(ep) 133 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 134 | 135 | def update_ema(self, model, ema_model, alpha, global_step): 136 | alpha = min(1 - 1 / (global_step +1), alpha) 137 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 138 | ema_param.data.mul_(alpha).add_(1-alpha, param.data) 139 | 140 | def decode_targets(self, targets): 141 | label_mask = targets.ge(0) 142 | unlab_mask = targets.le(NO_LABEL) 143 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 144 | return label_mask, unlab_mask 145 | 146 | def gen_info(self, info, lbs, ubs, iteration=True): 147 | ret = [] 148 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 149 | for k, val in info.items(): 150 | n = nums[k[0]] 151 | v = val[-1] if iteration else sum(val) 152 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 153 | ret.append(s) 154 | return '\t'.join(ret) 155 | 156 | def save(self, epoch, **kwargs): 157 | if self.save_dir is not None: 158 | model_out_path = Path(self.save_dir) 159 | state = {"epoch": epoch, 160 | "weight": self.model.state_dict()} 161 | if not model_out_path.exists(): 162 | model_out_path.mkdir() 163 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 164 | torch.save(state, save_target) 165 | print('==> save model to {}'.format(save_target)) 166 | -------------------------------------------------------------------------------- /trainer/PIv1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import mse_with_softmax 11 | from utils.ramps import exp_rampup 12 | from utils.datasets import decode_label 13 | from utils.data_utils import NO_LABEL 14 | 15 | class Trainer: 16 | 17 | def __init__(self, model, optimizer, device, config): 18 | print('PI-v1') 19 | self.model = model 20 | self.optimizer = optimizer 21 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 22 | self.cons_loss = mse_with_softmax #F.mse_loss 23 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 24 | config.dataset, config.num_labels, 25 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 26 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 27 | self.usp_weight = config.usp_weight 28 | self.rampup = exp_rampup(config.weight_rampup) 29 | self.save_freq = config.save_freq 30 | self.print_freq = config.print_freq 31 | self.device = device 32 | self.epoch = 0 33 | 34 | def train_iteration(self, data_loader, print_freq): 35 | loop_info = defaultdict(list) 36 | label_n, unlab_n = 0, 0 37 | for batch_idx, ((x1, 
x2), targets) in enumerate(data_loader): 38 | x1, x2, targets = [t.to(self.device) for t in (x1,x2,targets)] 39 | ##=== decode targets === 40 | lmask, umask = self.decode_targets(targets) 41 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 42 | 43 | ##=== forward === 44 | outputs = self.model(x1) 45 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 46 | loop_info['lloss'].append(loss.item()) 47 | 48 | ##=== Semi-supervised Training === 49 | ## consistency loss 50 | with torch.no_grad(): 51 | pi_outputs = self.model(x2) 52 | pi_outputs = pi_outputs.detach() 53 | cons_loss = self.cons_loss(outputs, pi_outputs) 54 | cons_loss *= self.rampup(self.epoch)*self.usp_weight 55 | loss += cons_loss; loop_info['aCons'].append(cons_loss.item()) 56 | 57 | ## backward 58 | self.optimizer.zero_grad() 59 | loss.backward() 60 | self.optimizer.step() 61 | 62 | ##=== log info === 63 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 64 | lacc = targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item() 65 | uacc = targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item() 66 | loop_info['lacc'].append(lacc) 67 | loop_info['uacc'].append(uacc) 68 | if print_freq>0 and (batch_idx%print_freq)==0: 69 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 70 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 71 | return loop_info, label_n 72 | 73 | def test_iteration(self, data_loader, print_freq): 74 | loop_info = defaultdict(list) 75 | label_n, unlab_n = 0, 0 76 | for batch_idx, (data, targets) in enumerate(data_loader): 77 | data, targets = data.to(self.device), targets.to(self.device) 78 | lbs, ubs = data.size(0), -1 79 | 80 | ##=== forward === 81 | outputs = self.model(data) 82 | loss = self.ce_loss(outputs, targets) 83 | loop_info['lloss'].append(loss.item()) 84 | 85 | ##=== log info === 86 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 87 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 88 | if print_freq>0 and (batch_idx%print_freq)==0: 89 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 90 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 91 | return loop_info, label_n 92 | 93 | def train(self, data_loader, print_freq=20): 94 | self.model.train() 95 | with torch.enable_grad(): 96 | return self.train_iteration(data_loader, print_freq) 97 | 98 | def test(self, data_loader, print_freq=10): 99 | self.model.eval() 100 | with torch.no_grad(): 101 | return self.test_iteration(data_loader, print_freq) 102 | 103 | def loop(self, epochs, train_data, test_data, scheduler=None): 104 | best_info, best_acc, n = None, 0., 0 105 | for ep in range(epochs): 106 | self.epoch = ep 107 | if scheduler is not None: scheduler.step() 108 | print("------ Training epochs: {} ------".format(ep)) 109 | self.train(train_data, self.print_freq) 110 | print("------ Testing epochs: {} ------".format(ep)) 111 | info, n = self.test(test_data, self.print_freq) 112 | acc = sum(info['lacc']) / n 113 | if acc>best_acc: best_info, best_acc = info, acc 114 | ## save model 115 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 116 | self.save(ep) 117 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 118 | 119 | def decode_targets(self, targets): 120 | label_mask = targets.ge(0) 121 | unlab_mask = targets.le(NO_LABEL) 122 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 123 | return label_mask, unlab_mask 124 | 125 | def gen_info(self, info, lbs, ubs, iteration=True): 126 | ret = [] 
127 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 128 | for k, val in info.items(): 129 | n = nums[k[0]] 130 | v = val[-1] if iteration else sum(val) 131 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 132 | ret.append(s) 133 | return '\t'.join(ret) 134 | 135 | def save(self, epoch, **kwargs): 136 | if self.save_dir is not None: 137 | model_out_path = Path(self.save_dir) 138 | state = {"epoch": epoch, 139 | "weight": self.model.state_dict()} 140 | if not model_out_path.exists(): 141 | model_out_path.mkdir() 142 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 143 | torch.save(state, save_target) 144 | print('==> save model to {}'.format(save_target)) 145 | -------------------------------------------------------------------------------- /trainer/PIv2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from itertools import cycle 9 | from collections import defaultdict 10 | 11 | from utils.loss import mse_with_softmax 12 | from utils.ramps import exp_rampup 13 | from utils.datasets import decode_label 14 | from utils.data_utils import NO_LABEL 15 | 16 | class Trainer: 17 | 18 | def __init__(self, model, optimizer, device, config): 19 | print('PI-v2') 20 | self.model = model 21 | self.optimizer = optimizer 22 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 23 | self.cons_loss = mse_with_softmax #F.mse_loss 24 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 25 | config.dataset, config.num_labels, 26 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 27 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 28 | self.usp_weight = config.usp_weight 29 | self.rampup = exp_rampup(config.weight_rampup) 30 | self.save_freq = config.save_freq 31 | self.print_freq = config.print_freq 32 | self.device = device 33 | self.epoch = 0 34 | 35 | def train_iteration(self, label_loader, unlab_loader, print_freq): 36 | loop_info = defaultdict(list) 37 | batch_idx, label_n, unlab_n = 0, 0, 0 38 | for ((x1,_), label_y), ((u1,u2), unlab_y) in zip(cycle(label_loader), unlab_loader): 39 | label_x, unlab_x1, unlab_x2 = x1.to(self.device), u1.to(self.device), u2.to(self.device) 40 | label_y, unlab_y = label_y.to(self.device), unlab_y.to(self.device) 41 | ##=== decode targets === 42 | self.decode_targets(unlab_y) 43 | lbs, ubs = x1.size(0), u1.size(0) 44 | 45 | ##=== forward === 46 | outputs = self.model(label_x) 47 | loss = self.ce_loss(outputs, label_y) 48 | loop_info['lloss'].append(loss.item()) 49 | 50 | ##=== Semi-supervised Training === 51 | ## consistency loss 52 | unlab_outputs = self.model(unlab_x1) 53 | with torch.no_grad(): 54 | pi_outputs = self.model(unlab_x2) 55 | pi_outputs = pi_outputs.detach() 56 | cons_loss = self.cons_loss(unlab_outputs, pi_outputs) 57 | cons_loss *= self.rampup(self.epoch)*self.usp_weight 58 | loss += cons_loss; loop_info['uloss'].append(cons_loss.item()) 59 | 60 | ## backwark 61 | self.optimizer.zero_grad() 62 | loss.backward() 63 | self.optimizer.step() 64 | 65 | ##=== log info === 66 | batch_idx, label_n, unlab_n = batch_idx+1, label_n+lbs, unlab_n+ubs 67 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 68 | loop_info['uacc'].append(unlab_y.eq(unlab_outputs.max(1)[1]).float().sum().item()) 69 | if print_freq>0 and (batch_idx%print_freq)==0: 70 | print(f"[train][{batch_idx:<3}]", 
self.gen_info(loop_info, lbs, ubs)) 71 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 72 | return loop_info, label_n 73 | 74 | def test_iteration(self, data_loader, print_freq): 75 | loop_info = defaultdict(list) 76 | label_n, unlab_n = 0, 0 77 | for batch_idx, (data, targets) in enumerate(data_loader): 78 | data, targets = data.to(self.device), targets.to(self.device) 79 | lbs, ubs = data.size(0), -1 80 | 81 | ##=== forward === 82 | outputs = self.model(data) 83 | loss = self.ce_loss(outputs, targets) 84 | loop_info['lloss'].append(loss.item()) 85 | 86 | ##=== log info === 87 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 88 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 89 | if print_freq>0 and (batch_idx%print_freq)==0: 90 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 91 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 92 | return loop_info, label_n 93 | 94 | def train(self, label_loader, unlab_loader, print_freq=20): 95 | self.model.train() 96 | with torch.enable_grad(): 97 | return self.train_iteration(label_loader, unlab_loader, print_freq) 98 | 99 | def test(self, data_loader, print_freq=10): 100 | self.model.eval() 101 | with torch.no_grad(): 102 | return self.test_iteration(data_loader, print_freq) 103 | 104 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 105 | best_acc, n, best_info = 0., 0., None 106 | for ep in range(epochs): 107 | self.epoch = ep 108 | if scheduler is not None: scheduler.step() 109 | print("------ Training epochs: {} ------".format(ep)) 110 | self.train(label_data, unlab_data, self.print_freq) 111 | print("------ Testing epochs: {} ------".format(ep)) 112 | info, n = self.test(test_data, self.print_freq) 113 | acc = sum(info['lacc'])/n 114 | if acc>best_acc: best_acc, best_info = acc, info 115 | ## save model 116 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 117 | self.save(ep) 118 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 119 | 120 | def decode_targets(self, targets): 121 | label_mask = targets.ge(0) 122 | unlab_mask = targets.le(NO_LABEL) 123 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 124 | return label_mask, unlab_mask 125 | 126 | def gen_info(self, info, lbs, ubs, iteration=True): 127 | ret = [] 128 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 129 | for k, val in info.items(): 130 | n = nums[k[0]] 131 | v = val[-1] if iteration else sum(val) 132 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 133 | ret.append(s) 134 | return '\t'.join(ret) 135 | 136 | def save(self, epoch, **kwargs): 137 | if self.save_dir is not None: 138 | model_out_path = Path(self.save_dir) 139 | state = {"epoch": epoch, 140 | "weight": self.model.state_dict()} 141 | if not model_out_path.exists(): 142 | model_out_path.mkdir() 143 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 144 | torch.save(state, save_target) 145 | print('==> save model to {}'.format(save_target)) 146 | -------------------------------------------------------------------------------- /trainer/VATv1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import mse_with_softmax 11 | from utils.loss import kl_div_with_logit 12 | from utils.ramps import exp_rampup 13 | from 
utils.context import disable_tracking_bn_stats 14 | from utils.datasets import decode_label 15 | from utils.data_utils import NO_LABEL 16 | 17 | class Trainer: 18 | 19 | def __init__(self, model, optimizer, device, config): 20 | print('VAT-v1') 21 | self.model = model 22 | self.optimizer = optimizer 23 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 24 | self.cons_loss = mse_with_softmax #kl_div_with_logit 25 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 26 | config.dataset, config.num_labels, 27 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 28 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 29 | self.usp_weight = config.usp_weight 30 | self.rampup = exp_rampup(config.weight_rampup) 31 | self.save_freq = config.save_freq 32 | self.print_freq = config.print_freq 33 | self.device = device 34 | self.epoch = 0 35 | self.xi = config.xi 36 | self.eps = config.eps 37 | self.n_power = config.n_power 38 | 39 | def train_iteration(self, data_loader, print_freq): 40 | loop_info = defaultdict(list) 41 | label_n, unlab_n = 0, 0 42 | for batch_idx, (data, targets) in enumerate(data_loader): 43 | data, targets = data.to(self.device), targets.to(self.device) 44 | ##=== decode targets === 45 | lmask, umask = self.decode_targets(targets) 46 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 47 | 48 | ##=== forward === 49 | outputs = self.model(data) 50 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 51 | loop_info['lloss'].append(loss.item()) 52 | 53 | ##=== Semi-supervised Training === 54 | ## local distributional smoothness (LDS) 55 | with torch.no_grad(): 56 | vlogits = outputs.clone().detach() 57 | with disable_tracking_bn_stats(self.model): 58 | r_vadv = self.gen_r_vadv(data, vlogits, self.n_power) 59 | rlogits = self.model(data + r_vadv) 60 | lds = self.cons_loss(rlogits, vlogits) 61 | lds *= self.rampup(self.epoch)*self.usp_weight 62 | loss += lds; loop_info['avat'].append(lds.item()) 63 | 64 | ## backwark 65 | self.optimizer.zero_grad() 66 | loss.backward() 67 | self.optimizer.step() 68 | 69 | ##=== log info === 70 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 71 | lacc = targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item() 72 | uacc = targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item() 73 | loop_info['lacc'].append(lacc) 74 | loop_info['uacc'].append(uacc) 75 | if print_freq>0 and (batch_idx%print_freq)==0: 76 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 77 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 78 | return loop_info, label_n 79 | 80 | def test_iteration(self, data_loader, print_freq): 81 | loop_info = defaultdict(list) 82 | label_n, unlab_n = 0, 0 83 | for batch_idx, (data, targets) in enumerate(data_loader): 84 | data, targets = data.to(self.device), targets.to(self.device) 85 | lbs, ubs = data.size(0), -1 86 | 87 | ##=== forward === 88 | outputs = self.model(data) 89 | loss = self.ce_loss(outputs, targets) 90 | loop_info['lloss'].append(loss.item()) 91 | 92 | ##=== log info === 93 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 94 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 95 | if print_freq>0 and (batch_idx%print_freq)==0: 96 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 97 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 98 | return loop_info, label_n 99 | 100 | def train(self, data_loader, print_freq=20): 101 | self.model.train() 102 | with 
torch.enable_grad(): 103 | return self.train_iteration(data_loader, print_freq) 104 | 105 | def test(self, data_loader, print_freq=10): 106 | self.model.eval() 107 | with torch.no_grad(): 108 | return self.test_iteration(data_loader, print_freq) 109 | 110 | def loop(self, epochs, train_data, test_data, scheduler=None): 111 | best_info, best_acc, n = None, 0., 0 112 | for ep in range(epochs): 113 | self.epoch = ep 114 | if scheduler is not None: scheduler.step() 115 | print("------ Training epochs: {} ------".format(ep)) 116 | self.train(train_data, self.print_freq) 117 | print("------ Testing epochs: {} ------".format(ep)) 118 | info, n = self.test(test_data, self.print_freq) 119 | acc = sum(info['lacc']) / n 120 | if acc>best_acc: best_info, best_acc = info, acc 121 | ## save model 122 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 123 | self.save(ep) 124 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 125 | 126 | def __l2_normalize(self, d): 127 | d_abs_max = torch.max( 128 | torch.abs(d.view(d.size(0),-1)), 1, keepdim=True)[0].view( 129 | d.size(0),1,1,1) 130 | d /= (1e-12 + d_abs_max) 131 | d /= torch.sqrt(1e-6 + torch.sum( 132 | torch.pow(d,2.0), tuple(range(1, len(d.size()))), keepdim=True)) 133 | return d 134 | 135 | def gen_r_vadv(self, x, vlogits, niter): 136 | # perpare random unit tensor 137 | d = torch.rand(x.shape).sub(0.5).to(self.device) 138 | d = self.__l2_normalize(d) 139 | # calc adversarial perturbation 140 | for _ in range(niter): 141 | d.requires_grad_() 142 | rlogits = self.model(x + self.xi * d) 143 | adv_dist = self.cons_loss(rlogits, vlogits) 144 | adv_dist.backward() 145 | d = self.__l2_normalize(d.grad) 146 | self.model.zero_grad() 147 | return self.eps * d 148 | 149 | def decode_targets(self, targets): 150 | label_mask = targets.ge(0) 151 | unlab_mask = targets.le(NO_LABEL) 152 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 153 | return label_mask, unlab_mask 154 | 155 | def gen_info(self, info, lbs, ubs, iteration=True): 156 | ret = [] 157 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 158 | for k, val in info.items(): 159 | n = nums[k[0]] 160 | v = val[-1] if iteration else sum(val) 161 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 162 | ret.append(s) 163 | return '\t'.join(ret) 164 | 165 | def save(self, epoch, **kwargs): 166 | if self.save_dir is not None: 167 | model_out_path = Path(self.save_dir) 168 | state = {"epoch": epoch, 169 | "weight": self.model.state_dict()} 170 | if not model_out_path.exists(): 171 | model_out_path.mkdir() 172 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 173 | torch.save(state, save_target) 174 | print('==> save model to {}'.format(save_target)) 175 | -------------------------------------------------------------------------------- /trainer/VATv2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from itertools import cycle 9 | from collections import defaultdict 10 | 11 | from utils.loss import entropy_y_x 12 | from utils.loss import mse_with_softmax 13 | from utils.loss import kl_div_with_logit 14 | from utils.ramps import exp_rampup 15 | from utils.context import disable_tracking_bn_stats 16 | from utils.datasets import decode_label 17 | from utils.data_utils import NO_LABEL 18 | 19 | class Trainer: 20 | 21 | def __init__(self, model, optimizer, device, config): 22 | print('VAT-v2') 23 
| self.model = model 24 | self.optimizer = optimizer 25 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 26 | self.cons_loss = mse_with_softmax #kl_div_with_logit 27 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 28 | config.dataset, config.num_labels, 29 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 30 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 31 | self.usp_weight = config.usp_weight 32 | self.rampup = exp_rampup(config.weight_rampup) 33 | self.save_freq = config.save_freq 34 | self.print_freq = config.print_freq 35 | self.device = device 36 | self.epoch = 0 37 | self.xi = config.xi 38 | self.eps = config.eps 39 | self.n_power = config.n_power 40 | 41 | def train_iteration(self, label_loader, unlab_loader, print_freq): 42 | loop_info = defaultdict(list) 43 | batch_idx, label_n, unlab_n = 0, 0, 0 44 | for (label_x, label_y), (unlab_x, unlab_y) in zip(cycle(label_loader), unlab_loader): 45 | label_x, label_y = label_x.to(self.device), label_y.to(self.device) 46 | unlab_x, unlab_y = unlab_x.to(self.device), unlab_y.to(self.device) 47 | ##=== decode targets of unlabeled data === 48 | self.decode_targets(unlab_y) 49 | lbs, ubs = label_x.size(0), unlab_x.size(0) 50 | 51 | ##=== forward === 52 | outputs = self.model(label_x) 53 | loss = self.ce_loss(outputs, label_y) 54 | loop_info['lloss'].append(loss.item()) 55 | 56 | ##=== Semi-supervised Training === 57 | ## local distributional smoothness (LDS) 58 | unlab_outputs = self.model(unlab_x) 59 | with torch.no_grad(): 60 | vlogits = unlab_outputs.clone().detach() 61 | with disable_tracking_bn_stats(self.model): 62 | r_vadv = self.gen_r_vadv(unlab_x, vlogits, self.n_power) 63 | rlogits = self.model(unlab_x + r_vadv) 64 | lds = self.cons_loss(rlogits, vlogits) 65 | lds *= self.rampup(self.epoch)*self.usp_weight 66 | loss += lds; loop_info['avat'].append(lds.item()) 67 | 68 | ## backwark 69 | self.optimizer.zero_grad() 70 | loss.backward() 71 | self.optimizer.step() 72 | 73 | ##=== log info === 74 | batch_idx, label_n, unlab_n = batch_idx+1, label_n+lbs, unlab_n+ubs 75 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 76 | loop_info['uacc'].append(unlab_y.eq(unlab_outputs.max(1)[1]).float().sum().item()) 77 | if print_freq>0 and (batch_idx%print_freq)==0: 78 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 79 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 80 | return loop_info, label_n 81 | 82 | def test_iteration(self, data_loader, print_freq): 83 | loop_info = defaultdict(list) 84 | label_n, unlab_n = 0, 0 85 | for batch_idx, (data, targets) in enumerate(data_loader): 86 | data, targets = data.to(self.device), targets.to(self.device) 87 | lbs, ubs = data.size(0), -1 88 | 89 | ##=== forward === 90 | outputs = self.model(data) 91 | loss = self.ce_loss(outputs, targets) 92 | loop_info['lloss'].append(loss.item()) 93 | 94 | ##=== log info === 95 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 96 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 97 | if print_freq>0 and (batch_idx%print_freq)==0: 98 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 99 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 100 | return loop_info, label_n 101 | 102 | def train(self, label_loader, unlab_loader, print_freq=20): 103 | self.model.train() 104 | with torch.enable_grad(): 105 | return self.train_iteration(label_loader, unlab_loader, print_freq) 106 | 107 
| def test(self, data_loader, print_freq=10): 108 | self.model.eval() 109 | with torch.no_grad(): 110 | return self.test_iteration(data_loader, print_freq) 111 | 112 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 113 | best_info, best_acc, n = None, 0., 0 114 | for ep in range(epochs): 115 | self.epoch = ep 116 | if scheduler is not None: scheduler.step() 117 | print("------ Training epochs: {} ------".format(ep)) 118 | self.train(label_data, unlab_data, self.print_freq) 119 | print("------ Testing epochs: {} ------".format(ep)) 120 | info, n = self.test(test_data, self.print_freq) 121 | acc = sum(info['lacc']) / n 122 | if acc>best_acc: best_info, best_acc = info, acc 123 | ## save model 124 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 125 | self.save(ep) 126 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 127 | 128 | def __l2_normalize(self, d): 129 | d_abs_max = torch.max( 130 | torch.abs(d.view(d.size(0),-1)), 1, keepdim=True)[0].view( 131 | d.size(0),1,1,1) 132 | d /= (1e-12 + d_abs_max) 133 | d /= torch.sqrt(1e-6 + torch.sum( 134 | torch.pow(d,2.0), tuple(range(1, len(d.size()))), keepdim=True)) 135 | return d 136 | 137 | def gen_r_vadv(self, x, vlogits, niter): 138 | # prepare random unit tensor 139 | d = torch.rand(x.shape).sub(0.5).to(self.device) 140 | d = self.__l2_normalize(d) 141 | # calc adversarial perturbation 142 | for _ in range(niter): 143 | d.requires_grad_() 144 | rlogits = self.model(x + self.xi * d) 145 | adv_dist = self.cons_loss(rlogits, vlogits) 146 | adv_dist.backward() 147 | d = self.__l2_normalize(d.grad) 148 | self.model.zero_grad() 149 | return self.eps * d 150 | 151 | def decode_targets(self, targets): 152 | label_mask = targets.ge(0) 153 | unlab_mask = targets.le(NO_LABEL) 154 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 155 | return label_mask, unlab_mask 156 | 157 | def gen_info(self, info, lbs, ubs, iteration=True): 158 | ret = [] 159 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 160 | for k, val in info.items(): 161 | n = nums[k[0]] 162 | v = val[-1] if iteration else sum(val) 163 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 164 | ret.append(s) 165 | return '\t'.join(ret) 166 | 167 | def save(self, epoch, **kwargs): 168 | if self.save_dir is not None: 169 | model_out_path = Path(self.save_dir) 170 | state = {"epoch": epoch, 171 | "weight": self.model.state_dict()} 172 | if not model_out_path.exists(): 173 | model_out_path.mkdir() 174 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 175 | torch.save(state, save_target) 176 | print('==> save model to {}'.format(save_target)) 177 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from . import MeanTeacherv1 2 | from . import MeanTeacherv2 3 | from . import PIv1 4 | from . import PIv2 5 | from . import VATv1 6 | from . import VATv2 7 | from . import ICTv1 8 | from . import ICTv2 9 | from . import ePseudoLabel2013v1 10 | from . import ePseudoLabel2013v2 11 | from . import iPseudoLabel2013v1 12 | from . import iPseudoLabel2013v2 13 | from . import eTempensv1 14 | from . import eTempensv2 15 | from . import iTempensv1 16 | from . import iTempensv2 17 | from . import MixMatch 18 | from . import iFixMatch 19 | from . import eFixMatch 20 | from . import eMixPseudoLabelv1 21 | from . 
import eMixPseudoLabelv2 22 | -------------------------------------------------------------------------------- /trainer/eMixPseudoLabelv1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import softmax_loss_mean 11 | from utils.loss import one_hot 12 | from utils.mixup import * 13 | from utils.ramps import exp_rampup, pseudo_rampup 14 | from utils.datasets import decode_label 15 | from utils.data_utils import NO_LABEL 16 | 17 | from pdb import set_trace 18 | 19 | class Trainer: 20 | 21 | def __init__(self, model, optimizer, device, config): 22 | print('MixUp-Pseudo-Label-v1 with {} epoch pseudo labels'.format( 23 | 'soft' if config.soft else 'hard')) 24 | self.model = model 25 | self.optimizer = optimizer 26 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 27 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 28 | config.dataset, config.num_labels, 29 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 30 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 31 | self.alpha = config.mixup_alpha 32 | self.usp_weight = config.usp_weight 33 | #self.rampup = pseudo_rampup(config.t1, config.t2) 34 | self.rampup = exp_rampup(config.weight_rampup) 35 | self.save_freq = config.save_freq 36 | self.print_freq = config.print_freq 37 | self.device = device 38 | self.epoch = 0 39 | self.soft = config.soft 40 | self.mixup_loss = mixup_ce_loss_with_softmax #mixup_mse_loss_with_softmax 41 | if not self.soft: self.mixup_loss = mixup_ce_loss_hard 42 | 43 | def train_iteration(self, data_loader, print_freq): 44 | loop_info = defaultdict(list) 45 | label_n, unlab_n = 0, 0 46 | for batch_idx, (data, targets, idxs) in enumerate(data_loader): 47 | data, targets = data.to(self.device), targets.to(self.device) 48 | ##=== decode targets === 49 | lmask, umask = self.decode_targets(targets) 50 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 51 | 52 | ##=== forward === 53 | outputs = self.model(data) 54 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 55 | loop_info['lloss'].append(loss.item()) 56 | 57 | ##=== Semi-supervised Training === 58 | ## mixup pslab loss 59 | iter_unlab_pslab = self.epoch_pslab[idxs] 60 | mixed_x, y_a, y_b, lam = mixup_two_targets(data, iter_unlab_pslab, 61 | self.alpha, self.device, is_bias=False) 62 | mixed_outputs = self.model(mixed_x) 63 | mix_loss = self.mixup_loss(mixed_outputs, y_a, y_b, lam) 64 | mix_loss *= self.rampup(self.epoch)*self.usp_weight 65 | loss += mix_loss; loop_info['aMix'].append(mix_loss.item()) 66 | 67 | ## update pseudo labels 68 | with torch.no_grad(): 69 | pseudo_preds = outputs.clone() if self.soft else outputs.max(1)[1] 70 | self.epoch_pslab[idxs] = pseudo_preds.detach() 71 | ## backward 72 | self.optimizer.zero_grad() 73 | loss.backward() 74 | self.optimizer.step() 75 | 76 | ##=== log info === 77 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 78 | lacc = targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item() 79 | uacc = targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item() 80 | loop_info['lacc'].append(lacc) 81 | loop_info['uacc'].append(uacc) 82 | if print_freq>0 and (batch_idx%print_freq)==0: 83 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 84 | print(f">>>[train]", self.gen_info(loop_info, label_n, 
unlab_n, False)) 85 | return loop_info, label_n 86 | 87 | def test_iteration(self, data_loader, print_freq): 88 | loop_info = defaultdict(list) 89 | label_n, unlab_n = 0, 0 90 | for batch_idx, (data, targets) in enumerate(data_loader): 91 | data, targets = data.to(self.device), targets.to(self.device) 92 | lbs, ubs = data.size(0), -1 93 | 94 | ##=== forward === 95 | outputs = self.model(data) 96 | loss = self.ce_loss(outputs, targets) 97 | loop_info['lloss'].append(loss.item()) 98 | 99 | ##=== log info === 100 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 101 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 102 | if print_freq>0 and (batch_idx%print_freq)==0: 103 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 104 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 105 | return loop_info, label_n 106 | 107 | def train(self, data_loader, print_freq=20): 108 | self.model.train() 109 | with torch.enable_grad(): 110 | return self.train_iteration(data_loader, print_freq) 111 | 112 | def test(self, data_loader, print_freq=10): 113 | self.model.eval() 114 | with torch.no_grad(): 115 | return self.test_iteration(data_loader, print_freq) 116 | 117 | def loop(self, epochs, train_data, test_data, scheduler=None): 118 | ## construct epoch pseudo labels 119 | init_pslab = self.create_soft_pslab if self.soft else self.create_pslab 120 | self.epoch_pslab = init_pslab(n_samples=len(train_data.dataset), 121 | n_classes=train_data.dataset.num_classes) 122 | ## main process 123 | best_info, best_acc, n = None, 0., 0 124 | for ep in range(epochs): 125 | self.epoch = ep 126 | if scheduler is not None: scheduler.step() 127 | print("------ Training epochs: {} ------".format(ep)) 128 | self.train(train_data, self.print_freq) 129 | print("------ Testing epochs: {} ------".format(ep)) 130 | info, n = self.test(test_data, self.print_freq) 131 | acc = sum(info['lacc']) / n 132 | if acc>best_acc: best_info, best_acc = info, acc 133 | ## save model 134 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 135 | self.save(ep) 136 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 137 | 138 | def create_pslab(self, n_samples, n_classes, dtype='rand'): 139 | if dtype=='rand': 140 | pslab = torch.randint(0, n_classes, (n_samples,)) 141 | elif dtype=='zero': 142 | pslab = torch.zeros(n_samples) 143 | else: 144 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 145 | return pslab.long().to(self.device) 146 | 147 | def create_soft_pslab(self, n_samples, n_classes, dtype='rand'): 148 | if dtype=='rand': 149 | rlabel = torch.randint(0, n_classes, (n_samples,)).long() 150 | pslab = one_hot(rlabel, n_classes) 151 | elif dtype=='zero': 152 | pslab = torch.zeros(n_samples, n_classes) 153 | else: 154 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 155 | return pslab.to(self.device) 156 | 157 | def decode_targets(self, targets): 158 | label_mask = targets.ge(0) 159 | unlab_mask = targets.le(NO_LABEL) 160 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 161 | return label_mask, unlab_mask 162 | 163 | def gen_info(self, info, lbs, ubs, iteration=True): 164 | ret = [] 165 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 166 | for k, val in info.items(): 167 | n = nums[k[0]] 168 | v = val[-1] if iteration else sum(val) 169 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 170 | ret.append(s) 171 | return '\t'.join(ret) 172 | 173 | def save(self, epoch, **kwargs): 174 | if self.save_dir is not None: 175 | model_out_path = 
Path(self.save_dir) 176 | state = {"epoch": epoch, 177 | "weight": self.model.state_dict()} 178 | if not model_out_path.exists(): 179 | model_out_path.mkdir() 180 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 181 | torch.save(state, save_target) 182 | print('==> save model to {}'.format(save_target)) 183 | -------------------------------------------------------------------------------- /trainer/ePseudoLabel2013v1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import softmax_loss_mean 11 | from utils.loss import one_hot 12 | from utils.ramps import exp_rampup, pseudo_rampup 13 | from utils.datasets import decode_label 14 | from utils.data_utils import NO_LABEL 15 | 16 | from pdb import set_trace 17 | 18 | class Trainer: 19 | 20 | def __init__(self, model, optimizer, device, config): 21 | print('Pseudo-Label-v1 2013 with {} epoch pseudo labels'.format( 22 | 'soft' if config.soft else 'hard')) 23 | self.model = model 24 | self.optimizer = optimizer 25 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 26 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 27 | config.dataset, config.num_labels, 28 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 29 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 30 | self.usp_weight = config.usp_weight 31 | #self.rampup = pseudo_rampup(config.t1, config.t2) 32 | self.rampup = exp_rampup(config.weight_rampup) 33 | self.save_freq = config.save_freq 34 | self.print_freq = config.print_freq 35 | self.device = device 36 | self.epoch = 0 37 | self.soft = config.soft 38 | self.unlab_loss = softmax_loss_mean if self.soft else self.ce_loss 39 | 40 | def train_iteration(self, data_loader, print_freq): 41 | loop_info = defaultdict(list) 42 | label_n, unlab_n = 0, 0 43 | for batch_idx, (data, targets, idxs) in enumerate(data_loader): 44 | data, targets = data.to(self.device), targets.to(self.device) 45 | ##=== decode targets === 46 | lmask, umask = self.decode_targets(targets) 47 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 48 | 49 | ##=== forward === 50 | outputs = self.model(data) 51 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 52 | loop_info['lloss'].append(loss.item()) 53 | 54 | ##=== Semi-supervised Training === 55 | iter_unlab_pslab = self.epoch_pslab[idxs[umask]] 56 | uloss = self.unlab_loss(outputs[umask], iter_unlab_pslab) 57 | uloss *= self.rampup(self.epoch)*self.usp_weight 58 | loss += uloss; loop_info['uloss'].append(uloss.item()) 59 | ## update pseudo labels 60 | with torch.no_grad(): 61 | pseudo_preds = outputs.clone() if self.soft else outputs.max(1)[1] 62 | self.epoch_pslab[idxs] = pseudo_preds.detach() 63 | ## backward 64 | self.optimizer.zero_grad() 65 | loss.backward() 66 | self.optimizer.step() 67 | 68 | ##=== log info === 69 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 70 | lacc = targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item() 71 | uacc = targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item() 72 | loop_info['lacc'].append(lacc) 73 | loop_info['uacc'].append(uacc) 74 | if print_freq>0 and (batch_idx%print_freq)==0: 75 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 76 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 77 | return 
loop_info, label_n 78 | 79 | def test_iteration(self, data_loader, print_freq): 80 | loop_info = defaultdict(list) 81 | label_n, unlab_n = 0, 0 82 | for batch_idx, (data, targets) in enumerate(data_loader): 83 | data, targets = data.to(self.device), targets.to(self.device) 84 | lbs, ubs = data.size(0), -1 85 | 86 | ##=== forward === 87 | outputs = self.model(data) 88 | loss = self.ce_loss(outputs, targets) 89 | loop_info['lloss'].append(loss.item()) 90 | 91 | ##=== log info === 92 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 93 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 94 | if print_freq>0 and (batch_idx%print_freq)==0: 95 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 96 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 97 | return loop_info, label_n 98 | 99 | def train(self, data_loader, print_freq=20): 100 | self.model.train() 101 | with torch.enable_grad(): 102 | return self.train_iteration(data_loader, print_freq) 103 | 104 | def test(self, data_loader, print_freq=10): 105 | self.model.eval() 106 | with torch.no_grad(): 107 | return self.test_iteration(data_loader, print_freq) 108 | 109 | def loop(self, epochs, train_data, test_data, scheduler=None): 110 | ## construct epoch pseudo labels 111 | init_pslab = self.create_soft_pslab if self.soft else self.create_pslab 112 | self.epoch_pslab = init_pslab(n_samples=len(train_data.dataset), 113 | n_classes=train_data.dataset.num_classes) 114 | ## main process 115 | best_info, best_acc, n = None, 0., 0 116 | for ep in range(epochs): 117 | self.epoch = ep 118 | if scheduler is not None: scheduler.step() 119 | print("------ Training epochs: {} ------".format(ep)) 120 | self.train(train_data, self.print_freq) 121 | print("------ Testing epochs: {} ------".format(ep)) 122 | info, n = self.test(test_data, self.print_freq) 123 | acc = sum(info['lacc']) / n 124 | if acc>best_acc: best_info, best_acc = info, acc 125 | ## save model 126 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 127 | self.save(ep) 128 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 129 | 130 | def create_pslab(self, n_samples, n_classes, dtype='rand'): 131 | if dtype=='rand': 132 | pslab = torch.randint(0, n_classes, (n_samples,)) 133 | elif dtype=='zero': 134 | pslab = torch.zeros(n_samples) 135 | else: 136 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 137 | return pslab.long().to(self.device) 138 | 139 | def create_soft_pslab(self, n_samples, n_classes, dtype='rand'): 140 | if dtype=='rand': 141 | rlabel = torch.randint(0, n_classes, (n_samples,)).long() 142 | pslab = one_hot(rlabel, n_classes) 143 | elif dtype=='zero': 144 | pslab = torch.zeros(n_samples, n_classes) 145 | else: 146 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 147 | return pslab.to(self.device) 148 | 149 | def decode_targets(self, targets): 150 | label_mask = targets.ge(0) 151 | unlab_mask = targets.le(NO_LABEL) 152 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 153 | return label_mask, unlab_mask 154 | 155 | def gen_info(self, info, lbs, ubs, iteration=True): 156 | ret = [] 157 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 158 | for k, val in info.items(): 159 | n = nums[k[0]] 160 | v = val[-1] if iteration else sum(val) 161 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 162 | ret.append(s) 163 | return '\t'.join(ret) 164 | 165 | def save(self, epoch, **kwargs): 166 | if self.save_dir is not None: 167 | model_out_path = Path(self.save_dir) 168 | state = 
{"epoch": epoch, 169 | "weight": self.model.state_dict()} 170 | if not model_out_path.exists(): 171 | model_out_path.mkdir() 172 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 173 | torch.save(state, save_target) 174 | print('==> save model to {}'.format(save_target)) 175 | -------------------------------------------------------------------------------- /trainer/ePseudoLabel2013v2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from itertools import cycle 9 | from collections import defaultdict 10 | 11 | from utils.loss import softmax_loss_mean 12 | from utils.loss import one_hot 13 | from utils.ramps import exp_rampup, pseudo_rampup 14 | from utils.datasets import decode_label 15 | from utils.data_utils import NO_LABEL 16 | 17 | from pdb import set_trace 18 | 19 | class Trainer: 20 | 21 | def __init__(self, model, optimizer, device, config): 22 | print('Pseudo-Label-v2 2013 with {} epoch pseudo labels'.format( 23 | 'soft' if config.soft else 'hard')) 24 | self.model = model 25 | self.optimizer = optimizer 26 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 27 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 28 | config.dataset, config.num_labels, 29 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 30 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 31 | self.usp_weight = config.usp_weight 32 | #self.rampup = pseudo_rampup(config.t1, config.t2) 33 | self.rampup = exp_rampup(config.weight_rampup) 34 | self.save_freq = config.save_freq 35 | self.print_freq = config.print_freq 36 | self.device = device 37 | self.epoch = 0 38 | self.soft = config.soft 39 | self.unlab_loss = softmax_loss_mean if self.soft else self.ce_loss 40 | 41 | def train_iteration(self, label_loader, unlab_loader, print_freq): 42 | loop_info = defaultdict(list) 43 | batch_idx, label_n, unlab_n = 0, 0, 0 44 | for (label_x, label_y, ldx), (unlab_x, unlab_y, udx) in zip(cycle(label_loader), unlab_loader): 45 | label_x, label_y = label_x.to(self.device), label_y.to(self.device) 46 | unlab_x, unlab_y = unlab_x.to(self.device), unlab_y.to(self.device) 47 | ##=== decode targets of unlabeled data === 48 | self.decode_targets(unlab_y) 49 | lbs, ubs = label_x.size(0), unlab_x.size(0) 50 | 51 | ##=== forward === 52 | outputs = self.model(label_x) 53 | loss = self.ce_loss(outputs, label_y) 54 | loop_info['lSup'].append(loss.item()) 55 | 56 | ##=== Semi-supervised Training Phase === 57 | ## pslab loss 58 | unlab_outputs = self.model(unlab_x) 59 | iter_unlab_pslab = self.epoch_pslab[udx] 60 | uloss = self.unlab_loss(unlab_outputs, iter_unlab_pslab) 61 | uloss *= self.rampup(self.epoch)*self.usp_weight 62 | loss += uloss; loop_info['uloss'].append(uloss.item()) 63 | 64 | ## update pseudo labels 65 | with torch.no_grad(): 66 | pseudo_preds = unlab_outputs.clone() if self.soft else unlab_outputs.max(1)[1] 67 | self.epoch_pslab[udx] = pseudo_preds.detach() 68 | 69 | ## backward 70 | self.optimizer.zero_grad() 71 | loss.backward() 72 | self.optimizer.step() 73 | 74 | ##=== log info === 75 | batch_idx, label_n, unlab_n = batch_idx+1, label_n+lbs, unlab_n+ubs 76 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 77 | loop_info['uacc'].append(unlab_y.eq(unlab_outputs.max(1)[1]).float().sum().item()) 78 | if print_freq>0 and (batch_idx%print_freq)==0: 79 | 
print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 80 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 81 | return loop_info, label_n 82 | 83 | def test_iteration(self, data_loader, print_freq): 84 | loop_info = defaultdict(list) 85 | label_n, unlab_n = 0, 0 86 | for batch_idx, (data, targets) in enumerate(data_loader): 87 | data, targets = data.to(self.device), targets.to(self.device) 88 | ##=== decode targets === 89 | lbs, ubs = data.size(0), -1 90 | 91 | ##=== forward === 92 | outputs = self.model(data) 93 | loss = self.ce_loss(outputs, targets) 94 | loop_info['lloss'].append(loss.item()) 95 | 96 | ##=== log info === 97 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 98 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 99 | if print_freq>0 and (batch_idx%print_freq)==0: 100 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 101 | print(">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 102 | return loop_info, label_n 103 | 104 | def train(self, label_loader, unlab_loader, print_freq=20): 105 | self.model.train() 106 | with torch.enable_grad(): 107 | return self.train_iteration(label_loader, unlab_loader, print_freq) 108 | 109 | def test(self, data_loader, print_freq=10): 110 | self.model.eval() 111 | with torch.no_grad(): 112 | return self.test_iteration(data_loader, print_freq) 113 | 114 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 115 | ## construct epoch pseudo labels 116 | init_pslab = self.create_soft_pslab if self.soft else self.create_pslab 117 | self.epoch_pslab = init_pslab(n_samples=len(unlab_data.dataset), 118 | n_classes=unlab_data.dataset.num_classes) 119 | ## main process 120 | best_info, best_acc, n = None, 0., 0 121 | for ep in range(epochs): 122 | self.epoch = ep 123 | if scheduler is not None: scheduler.step() 124 | print("------ Training epochs: {} ------".format(ep)) 125 | self.train(label_data, unlab_data, self.print_freq) 126 | print("------ Testing epochs: {} ------".format(ep)) 127 | info, n = self.test(test_data, self.print_freq) 128 | acc = sum(info['lacc']) / n 129 | if acc>best_acc: best_info, best_acc = info, acc 130 | ## save model 131 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 132 | self.save(ep) 133 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 134 | 135 | def create_pslab(self, n_samples, n_classes, dtype='rand'): 136 | if dtype=='rand': 137 | pslab = torch.randint(0, n_classes, (n_samples,)) 138 | elif dtype=='zero': 139 | pslab = torch.zeros(n_samples) 140 | else: 141 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 142 | return pslab.long().to(self.device) 143 | 144 | def create_soft_pslab(self, n_samples, n_classes, dtype='rand'): 145 | if dtype=='rand': 146 | rlabel = torch.randint(0, n_classes, (n_samples,)).long() 147 | pslab = one_hot(rlabel, n_classes) 148 | elif dtype=='zero': 149 | pslab = torch.zeros(n_samples, n_classes) 150 | else: 151 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 152 | return pslab.to(self.device) 153 | 154 | def decode_targets(self, targets): 155 | label_mask = targets.ge(0) 156 | unlab_mask = targets.le(NO_LABEL) 157 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 158 | return label_mask, unlab_mask 159 | 160 | def gen_info(self, info, lbs, ubs, iteration=True): 161 | ret = [] 162 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 163 | for k, val in info.items(): 164 | n = nums[k[0]] 165 | v = val[-1] if iteration else sum(val) 166 | 
s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 167 | ret.append(s) 168 | return '\t'.join(ret) 169 | 170 | def save(self, epoch, **kwargs): 171 | if self.save_dir is not None: 172 | model_out_path = Path(self.save_dir) 173 | state = {"epoch": epoch, 174 | "weight": self.model.state_dict()} 175 | if not model_out_path.exists(): 176 | model_out_path.mkdir() 177 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 178 | torch.save(state, save_target) 179 | print('==> save model to {}'.format(save_target)) 180 | -------------------------------------------------------------------------------- /trainer/eTempensv1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import mse_with_softmax 11 | from utils.ramps import exp_rampup 12 | from utils.data_utils import NO_LABEL 13 | 14 | class Trainer: 15 | 16 | def __init__(self, model, optimizer, device, config): 17 | print('Tempens-v1 with epoch pseudo labels') 18 | self.model = model 19 | self.optimizer = optimizer 20 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 21 | self.mse_loss = mse_with_softmax # F.mse_loss 22 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 23 | config.dataset, config.num_labels, 24 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 25 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 26 | self.device = device 27 | self.usp_weight = config.usp_weight 28 | self.ema_decay = config.ema_decay 29 | self.rampup = exp_rampup(config.rampup_length) 30 | self.save_freq = config.save_freq 31 | self.print_freq = config.print_freq 32 | self.epoch = 0 33 | self.start_epoch = 0 34 | 35 | def train_iteration(self, data_loader, print_freq): 36 | loop_info = defaultdict(list) 37 | label_n, unlab_n = 0, 0 38 | for batch_idx, (data, targets, idxs) in enumerate(data_loader): 39 | data, targets = data.to(self.device), targets.to(self.device) 40 | ##=== decode targets === 41 | lmask, umask = self.decode_targets(targets) 42 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 43 | 44 | ##=== forward === 45 | outputs = self.model(data) 46 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 47 | loop_info['lloss'].append(loss.item()) 48 | 49 | ##=== Semi-supervised Training Phase === 50 | iter_unlab_pslab = self.epoch_pslab[idxs] 51 | tmp_loss = self.mse_loss(outputs, iter_unlab_pslab) 52 | tmp_loss *= self.rampup(self.epoch)*self.usp_weight 53 | loss += tmp_loss; loop_info['aTmp'].append(tmp_loss.item()) 54 | ## update pseudo labels 55 | with torch.no_grad(): 56 | self.epoch_pslab[idxs] = outputs.clone().detach() 57 | 58 | ## bachward 59 | self.optimizer.zero_grad() 60 | loss.backward() 61 | self.optimizer.step() 62 | 63 | ##=== log info === 64 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 65 | lacc = targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item() 66 | uacc = targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item() 67 | loop_info['lacc'].append(lacc) 68 | loop_info['uacc'].append(uacc) 69 | if print_freq>0 and (batch_idx%print_freq)==0: 70 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 71 | # temporal ensemble 72 | self.update_ema_predictions() # update every epoch 73 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 74 | return loop_info, 
label_n 75 | 76 | def test_iteration(self, data_loader, print_freq): 77 | loop_info = defaultdict(list) 78 | label_n, unlab_n = 0, 0 79 | for batch_idx, (data, targets) in enumerate(data_loader): 80 | data, targets = data.to(self.device), targets.to(self.device) 81 | lbs, ubs = data.size(0), -1 82 | 83 | ##=== forward === 84 | outputs = self.model(data) 85 | loss = self.ce_loss(outputs, targets) 86 | loop_info['lloss'].append(loss.item()) 87 | 88 | ##=== log info === 89 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 90 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 91 | if print_freq>0 and (batch_idx%print_freq)==0: 92 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 93 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 94 | return loop_info, label_n 95 | 96 | def train(self, data_loader, print_freq=20): 97 | self.model.train() 98 | with torch.enable_grad(): 99 | return self.train_iteration(data_loader, print_freq) 100 | 101 | def test(self, data_loader, print_freq=10): 102 | self.model.eval() 103 | with torch.no_grad(): 104 | return self.test_iteration(data_loader, print_freq) 105 | 106 | def update_ema_predictions(self): 107 | """update every epoch""" 108 | self.ema_pslab = (self.ema_decay*self.ema_pslab) + (1.0-self.ema_decay)*self.epoch_pslab 109 | self.epoch_pslab = self.ema_pslab / (1.0 - self.ema_decay**((self.epoch-self.start_epoch)+1.0)) 110 | 111 | def loop(self, epochs, train_data, test_data, scheduler=None): 112 | ## construct epoch pseudo labels 113 | self.epoch_pslab = self.create_soft_pslab(n_samples=len(train_data.dataset), 114 | n_classes=train_data.dataset.num_classes, 115 | dtype='rand') 116 | self.ema_pslab = self.create_soft_pslab(n_samples=len(train_data.dataset), 117 | n_classes=train_data.dataset.num_classes, 118 | dtype='zero') 119 | ## main process 120 | best_info, best_acc, n = None, 0., 0 121 | for ep in range(epochs): 122 | self.epoch = ep 123 | if scheduler is not None: scheduler.step() 124 | print("------ Training epochs: {} ------".format(ep)) 125 | self.train(train_data, self.print_freq) 126 | print("------ Testing epochs: {} ------".format(ep)) 127 | info, n = self.test(test_data, self.print_freq) 128 | acc = sum(info['lacc']) / n 129 | if acc>best_acc: best_info, best_acc = info, acc 130 | ## save model 131 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 132 | self.save(ep) 133 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 134 | 135 | def create_soft_pslab(self, n_samples, n_classes, dtype='rand'): 136 | if dtype=='rand': 137 | pslab = torch.randint(0, n_classes, (n_samples,n_classes)) 138 | elif dtype=='zero': 139 | pslab = torch.zeros(n_samples, n_classes) 140 | else: 141 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 142 | return pslab.to(self.device) 143 | 144 | def decode_targets(self, targets): 145 | labeled_mask = targets.ge(0) 146 | unlabeled_mask = targets.le(NO_LABEL) 147 | targets[unlabeled_mask] = NO_LABEL*targets[unlabeled_mask]-1 148 | return labeled_mask, unlabeled_mask 149 | 150 | def gen_info(self, info, lbs, ubs, iteration=True): 151 | ret = [] 152 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 153 | for k, val in info.items(): 154 | n = nums[k[0]] 155 | v = val[-1] if iteration else sum(val) 156 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 157 | ret.append(s) 158 | return '\t'.join(ret) 159 | 160 | def save(self, epoch, **kwargs): 161 | if self.save_dir is not None: 162 | model_out_path = Path(self.save_dir) 163 | 
state = {"epoch": epoch, 164 | "weight": self.model.state_dict()} 165 | if not model_out_path.exists(): 166 | model_out_path.mkdir() 167 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 168 | torch.save(state, save_target) 169 | print('==> save model to {}'.format(save_target)) 170 | -------------------------------------------------------------------------------- /trainer/eTempensv2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from itertools import cycle 9 | from collections import defaultdict 10 | 11 | from utils.loss import mse_with_softmax 12 | from utils.ramps import exp_rampup 13 | from utils.datasets import decode_label 14 | from utils.data_utils import NO_LABEL 15 | 16 | class Trainer: 17 | 18 | def __init__(self, model, optimizer, device, config): 19 | print('Tempens-v2 with epoch pseudo labels') 20 | self.model = model 21 | self.optimizer = optimizer 22 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 23 | self.mse_loss = mse_with_softmax # F.mse_loss 24 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 25 | config.dataset, config.num_labels, 26 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 27 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 28 | self.device = device 29 | self.usp_weight = config.usp_weight 30 | self.ema_decay = config.ema_decay 31 | self.rampup = exp_rampup(config.rampup_length) 32 | self.save_freq = config.save_freq 33 | self.print_freq = config.print_freq 34 | self.epoch = 0 35 | self.start_epoch = 0 36 | 37 | def train_iteration(self, label_loader, unlab_loader, print_freq): 38 | loop_info = defaultdict(list) 39 | batch_idx, label_n, unlab_n = 0, 0, 0 40 | for (label_x, label_y, ldx), (unlab_x, unlab_y, udx) in zip(cycle(label_loader), unlab_loader): 41 | label_x, label_y = label_x.to(self.device), label_y.to(self.device) 42 | unlab_x, unlab_y = unlab_x.to(self.device), unlab_y.to(self.device) 43 | ##=== decode targets of unlabeled data === 44 | self.decode_targets(unlab_y) 45 | lbs, ubs = label_x.size(0), unlab_x.size(0) 46 | 47 | ##=== forward === 48 | outputs = self.model(label_x) 49 | loss = self.ce_loss(outputs, label_y) 50 | loop_info['lSup'].append(loss.item()) 51 | 52 | ##=== Semi-supervised Training Phase === 53 | unlab_outputs = self.model(unlab_x) 54 | iter_unlab_pslab = self.epoch_pslab[udx] 55 | uloss = self.mse_loss(unlab_outputs, iter_unlab_pslab) 56 | uloss *= self.rampup(self.epoch)*self.usp_weight 57 | loss += uloss; loop_info['uTmp'].append(uloss.item()) 58 | ## update pseudo labels 59 | with torch.no_grad(): 60 | self.epoch_pslab[udx] = unlab_outputs.clone().detach() 61 | ## bachward 62 | self.optimizer.zero_grad() 63 | loss.backward() 64 | self.optimizer.step() 65 | 66 | ##=== log info === 67 | batch_idx, label_n, unlab_n = batch_idx+1, label_n+lbs, unlab_n+ubs 68 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 69 | loop_info['uacc'].append(unlab_y.eq(unlab_outputs.max(1)[1]).float().sum().item()) 70 | if print_freq>0 and (batch_idx%print_freq)==0: 71 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 72 | # temporal ensemble 73 | self.update_ema_predictions() # update every epoch 74 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 75 | return loop_info, label_n 76 | 77 | def test_iteration(self, 
data_loader, print_freq): 78 | loop_info = defaultdict(list) 79 | label_n, unlab_n = 0, 0 80 | for batch_idx, (data, targets) in enumerate(data_loader): 81 | data, targets = data.to(self.device), targets.to(self.device) 82 | ##=== decode targets === 83 | lbs, ubs = data.size(0), -1 84 | 85 | ##=== forward === 86 | outputs = self.model(data) 87 | loss = self.ce_loss(outputs, targets) 88 | loop_info['lloss'].append(loss.item()) 89 | 90 | ##=== log info === 91 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 92 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 93 | if print_freq>0 and (batch_idx%print_freq)==0: 94 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 95 | print(">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 96 | return loop_info, label_n 97 | 98 | def train(self, label_loader, unlab_loader, print_freq=20): 99 | self.model.train() 100 | with torch.enable_grad(): 101 | return self.train_iteration(label_loader, unlab_loader, print_freq) 102 | 103 | def test(self, data_loader, print_freq=10): 104 | self.model.eval() 105 | with torch.no_grad(): 106 | return self.test_iteration(data_loader, print_freq) 107 | 108 | def update_ema_predictions(self): 109 | """update every epoch""" 110 | self.ema_pslab = (self.ema_decay*self.ema_pslab) + (1.0-self.ema_decay)*self.epoch_pslab 111 | self.epoch_pslab = self.ema_pslab / (1.0 - self.ema_decay**((self.epoch-self.start_epoch)+1.0)) 112 | 113 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 114 | ## construct epoch pseudo labels 115 | self.epoch_pslab = self.create_soft_pslab(n_samples=len(unlab_data.dataset), 116 | n_classes=unlab_data.dataset.num_classes, 117 | dtype='rand') 118 | self.ema_pslab = self.create_soft_pslab(n_samples=len(unlab_data.dataset), 119 | n_classes=unlab_data.dataset.num_classes, 120 | dtype='zero') 121 | ## main process 122 | best_info, best_acc, n = None, 0., 0 123 | for ep in range(epochs): 124 | self.epoch = ep 125 | if scheduler is not None: scheduler.step() 126 | print("------ Training epochs: {} ------".format(ep)) 127 | self.train(label_data, unlab_data, self.print_freq) 128 | print("------ Testing epochs: {} ------".format(ep)) 129 | info, n = self.test(test_data, self.print_freq) 130 | acc = sum(info['lacc']) / n 131 | if acc>best_acc: best_info, best_acc = info, acc 132 | ## save model 133 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 134 | self.save(ep) 135 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 136 | 137 | def create_soft_pslab(self, n_samples, n_classes, dtype='rand'): 138 | if dtype=='rand': 139 | pslab = torch.randint(0, n_classes, (n_samples,n_classes)) 140 | elif dtype=='zero': 141 | pslab = torch.zeros(n_samples, n_classes) 142 | else: 143 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 144 | return pslab.to(self.device) 145 | 146 | def decode_targets(self, targets): 147 | label_mask = targets.ge(0) 148 | unlab_mask = targets.le(NO_LABEL) 149 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 150 | return label_mask, unlab_mask 151 | 152 | def gen_info(self, info, lbs, ubs, iteration=True): 153 | ret = [] 154 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 155 | for k, val in info.items(): 156 | n = nums[k[0]] 157 | v = val[-1] if iteration else sum(val) 158 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 159 | ret.append(s) 160 | return '\t'.join(ret) 161 | 162 | def save(self, epoch, **kwargs): 163 | if self.save_dir is not None: 164 | model_out_path = 
Path(self.save_dir) 165 | state = {"epoch": epoch, 166 | "weight": self.model.state_dict()} 167 | if not model_out_path.exists(): 168 | model_out_path.mkdir() 169 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 170 | torch.save(state, save_target) 171 | print('==> save model to {}'.format(save_target)) 172 | -------------------------------------------------------------------------------- /trainer/iFixMatch.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | from itertools import cycle 10 | 11 | from utils.loss import one_hot 12 | from utils.ramps import exp_rampup 13 | from utils.mixup import * 14 | from utils.datasets import decode_label 15 | from utils.data_utils import NO_LABEL 16 | 17 | class Trainer: 18 | 19 | def __init__(self, model, ema_model, optimizer, device, config): 20 | print("FixMatch") 21 | self.model = model 22 | self.ema_model = ema_model 23 | self.optimizer = optimizer 24 | self.lce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 25 | self.uce_loss = torch.nn.CrossEntropyLoss(reduction='none') 26 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 27 | config.dataset, config.num_labels, 28 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 29 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 30 | self.usp_weight = config.usp_weight 31 | self.threshold = config.threshold 32 | self.ema_decay = config.ema_decay 33 | self.rampup = exp_rampup(config.weight_rampup) 34 | self.save_freq = config.save_freq 35 | self.print_freq = config.print_freq 36 | self.device = device 37 | self.global_step = 0 38 | self.epoch = 0 39 | 40 | def train_iteration(self, label_loader, unlab_loader, print_freq): 41 | loop_info = defaultdict(list) 42 | batch_idx, label_n, unlab_n = 0, 0, 0 43 | for ((x1,_), label_y), ((wu,su), unlab_y) in zip(cycle(label_loader), unlab_loader): 44 | self.global_step += 1; batch_idx+=1; 45 | label_x, weak_u, strong_u = x1.to(self.device), wu.to(self.device), su.to(self.device) 46 | label_y, unlab_y = label_y.to(self.device), unlab_y.to(self.device) 47 | ##=== decode targets === 48 | self.decode_targets(unlab_y) 49 | lbs, ubs = x1.size(0), wu.size(0) 50 | 51 | ##=== forward === 52 | outputs = self.model(label_x) 53 | loss = self.lce_loss(outputs, label_y) 54 | loop_info['lloss'].append(loss.item()) 55 | 56 | ##=== Semi-supervised Training === 57 | ## update mean-teacher 58 | self.update_ema(self.model, self.ema_model, self.ema_decay, self.global_step) 59 | ## use the outputs of weak unlabeled data as pseudo labels 60 | with torch.no_grad(): 61 | woutputs = self.model(weak_u) 62 | woutputs = F.softmax(woutputs, 1) 63 | wprobs, wpslab = woutputs.max(1) 64 | ## cross-entropy loss for confident unlabeled data 65 | mask = wprobs.ge(self.threshold).float() 66 | soutputs = self.model(strong_u) 67 | uloss = torch.mean(mask* self.uce_loss(soutputs, wpslab)) 68 | uloss *= self.usp_weight 69 | loss += uloss; loop_info['uloss'].append(uloss.item()) 70 | 71 | ##=== backward === 72 | self.optimizer.zero_grad() 73 | loss.backward() 74 | self.optimizer.step() 75 | 76 | ##=== log info === 77 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 78 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 79 | loop_info['uacc'].append(unlab_y.eq(soutputs.max(1)[1]).float().sum().item()) 80 | if 
print_freq>0 and (batch_idx%print_freq)==0: 81 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 82 | self.update_bn(self.model, self.ema_model) 83 | print(f">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 84 | return loop_info, label_n 85 | 86 | def test_iteration(self, data_loader, print_freq): 87 | loop_info = defaultdict(list) 88 | label_n, unlab_n = 0, 0 89 | for batch_idx, (data, targets) in enumerate(data_loader): 90 | data, targets = data.to(self.device), targets.to(self.device) 91 | lbs, ubs = data.size(0), -1 92 | 93 | ##=== forward === 94 | outputs = self.model(data) 95 | ema_outputs = self.ema_model(data) 96 | 97 | ##=== log info === 98 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 99 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 100 | loop_info['l2acc'].append(targets.eq(ema_outputs.max(1)[1]).float().sum().item()) 101 | if print_freq>0 and (batch_idx%print_freq)==0: 102 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 103 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 104 | return loop_info, label_n 105 | 106 | def train(self, label_loader, unlab_loader, print_freq=20): 107 | self.model.train() 108 | self.ema_model.train() 109 | with torch.enable_grad(): 110 | return self.train_iteration(label_loader, unlab_loader, print_freq) 111 | 112 | def test(self, data_loader, print_freq=10): 113 | self.model.eval() 114 | self.ema_model.eval() 115 | with torch.no_grad(): 116 | return self.test_iteration(data_loader, print_freq) 117 | 118 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 119 | best_acc, n, best_info = 0., 0., None 120 | for ep in range(epochs): 121 | self.epoch = ep 122 | if scheduler is not None: scheduler.step() 123 | print("------ Training epochs: {} ------".format(ep)) 124 | self.train(label_data, unlab_data, self.print_freq) 125 | print("------ Testing epochs: {} ------".format(ep)) 126 | info, n = self.test(test_data, self.print_freq) 127 | acc = sum(info['lacc'])/n 128 | if acc>best_acc: best_acc, best_info = acc, info 129 | ## save model 130 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 131 | self.save(ep) 132 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 133 | 134 | def update_ema(self, model, ema_model, alpha, global_step): 135 | alpha = min(1 - 1 / (global_step +1), alpha) 136 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 137 | ema_param.data.mul_(alpha).add_(1-alpha, param.data) 138 | 139 | def update_bn(self, model, ema_model): 140 | for m2, m1 in zip(ema_model.named_modules(), model.named_modules()): 141 | if ('bn' in m2[0]) and ('bn' in m1[0]): 142 | bn2, bn1 = m2[1].state_dict(), m1[1].state_dict() 143 | bn2['running_mean'].data.copy_(bn1['running_mean'].data) 144 | bn2['running_var'].data.copy_(bn1['running_var'].data) 145 | bn2['num_batches_tracked'].data.copy_(bn1['num_batches_tracked'].data) 146 | 147 | def decode_targets(self, targets): 148 | label_mask = targets.ge(0) 149 | unlab_mask = targets.le(NO_LABEL) 150 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 151 | return label_mask, unlab_mask 152 | 153 | def gen_info(self, info, lbs, ubs, iteration=True): 154 | ret = [] 155 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 156 | for k, val in info.items(): 157 | n = nums[k[0]] 158 | v = val[-1] if iteration else sum(val) 159 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 160 | ret.append(s) 161 | return '\t'.join(ret) 162 | 163 | def 
save(self, epoch, **kwargs): 164 | if self.save_dir is not None: 165 | model_out_path = Path(self.save_dir) 166 | state = {"epoch": epoch, 167 | "weight": self.model.state_dict()} 168 | if not model_out_path.exists(): 169 | model_out_path.mkdir() 170 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 171 | torch.save(state, save_target) 172 | print('==> save model to {}'.format(save_target)) 173 | -------------------------------------------------------------------------------- /trainer/iPseudoLabel2013v1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.ramps import exp_rampup, pseudo_rampup 11 | from utils.datasets import decode_label 12 | from utils.data_utils import NO_LABEL 13 | 14 | class Trainer: 15 | 16 | def __init__(self, model, optimizer, device, config): 17 | print('Pseudo-Label-v1 2013 with iteration pseudo labels') 18 | self.model = model 19 | self.optimizer = optimizer 20 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 21 | self.save_dir = '{}_{}-{}_{}'.format(config.arch, config.dataset, 22 | config.num_labels, 23 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 24 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 25 | self.usp_weight = config.usp_weight 26 | #self.rampup = pseudo_rampup(config.t1, config.t2) 27 | self.rampup = exp_rampup(config.weight_rampup) 28 | self.save_freq = config.save_freq 29 | self.print_freq = config.print_freq 30 | self.device = device 31 | self.epoch = 0 32 | 33 | def train_iteration(self, data_loader, print_freq): 34 | loop_info = defaultdict(list) 35 | label_n, unlab_n = 0, 0 36 | for batch_idx, (data, targets) in enumerate(data_loader): 37 | data, targets = data.to(self.device), targets.to(self.device) 38 | ##=== decode targets === 39 | lmask, umask = self.decode_targets(targets) 40 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 41 | 42 | ##=== forward === 43 | outputs = self.model(data) 44 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 45 | loop_info['lloss'].append(loss.item()) 46 | 47 | ##=== Semi-supervised Training === 48 | with torch.no_grad(): 49 | iter_unlab_pslab = outputs.max(1)[1] 50 | iter_unlab_pslab.detach_() 51 | uloss = self.ce_loss(outputs[umask], iter_unlab_pslab[umask]) 52 | uloss *= self.rampup(self.epoch)*self.usp_weight 53 | loss += uloss; loop_info['uloss'].append(uloss.item()) 54 | ## backward 55 | self.optimizer.zero_grad() 56 | loss.backward() 57 | self.optimizer.step() 58 | 59 | ##=== log info === 60 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 61 | lacc = targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item() 62 | uacc = targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item() 63 | loop_info['lacc'].append(lacc) 64 | loop_info['uacc'].append(uacc) 65 | if print_freq>0 and (batch_idx%print_freq)==0: 66 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 67 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 68 | return loop_info, label_n 69 | 70 | def test_iteration(self, data_loader, print_freq): 71 | loop_info = defaultdict(list) 72 | label_n, unlab_n = 0, 0 73 | for batch_idx, (data, targets) in enumerate(data_loader): 74 | data, targets = data.to(self.device), targets.to(self.device) 75 | lbs, ubs = data.size(0), -1 76 | 77 | ##=== forward 
=== 78 | outputs = self.model(data) 79 | loss = self.ce_loss(outputs, targets) 80 | loop_info['lloss'].append(loss.item()) 81 | 82 | ##=== log info === 83 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 84 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 85 | if print_freq>0 and (batch_idx%print_freq)==0: 86 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 87 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 88 | return loop_info, label_n 89 | 90 | 91 | def train(self, data_loader, print_freq=20): 92 | self.model.train() 93 | with torch.enable_grad(): 94 | return self.train_iteration(data_loader, print_freq) 95 | 96 | def test(self, data_loader, print_freq=10): 97 | self.model.eval() 98 | with torch.no_grad(): 99 | return self.test_iteration(data_loader, print_freq) 100 | 101 | def loop(self, epochs, train_data, test_data, scheduler=None): 102 | best_info, best_acc, n = None, 0., 0 103 | for ep in range(epochs): 104 | self.epoch = ep 105 | if scheduler is not None: scheduler.step() 106 | print("------ Training epochs: {} ------".format(ep)) 107 | self.train(train_data, self.print_freq) 108 | print("------ Testing epochs: {} ------".format(ep)) 109 | info, n = self.test(test_data, self.print_freq) 110 | acc = sum(info['lacc']) / n 111 | if acc>best_acc: best_info, best_acc = info, acc 112 | ## save model 113 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 114 | self.save(ep) 115 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 116 | 117 | def decode_targets(self, targets): 118 | label_mask = targets.ge(0) 119 | unlab_mask = targets.le(NO_LABEL) 120 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 121 | return label_mask, unlab_mask 122 | 123 | def gen_info(self, info, lbs, ubs, iteration=True): 124 | ret = [] 125 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 126 | for k, val in info.items(): 127 | n = nums[k[0]] 128 | v = val[-1] if iteration else sum(val) 129 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 130 | ret.append(s) 131 | return '\t'.join(ret) 132 | 133 | def save(self, epoch, **kwargs): 134 | if self.save_dir is not None: 135 | model_out_path = Path(self.save_dir) 136 | state = {"epoch": epoch, 137 | "weight": self.model.state_dict()} 138 | if not model_out_path.exists(): 139 | model_out_path.mkdir() 140 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 141 | torch.save(state, save_target) 142 | print('==> save model to {}'.format(save_target)) 143 | -------------------------------------------------------------------------------- /trainer/iPseudoLabel2013v2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from itertools import cycle 9 | from collections import defaultdict 10 | 11 | from utils.datasets import decode_label 12 | from utils.data_utils import NO_LABEL 13 | from utils.ramps import exp_rampup, pseudo_rampup 14 | 15 | class Trainer: 16 | 17 | def __init__(self, model, optimizer, device, config): 18 | print('Pseudo-Label-v2 2013 with iteration pseudo labels') 19 | self.model = model 20 | self.optimizer = optimizer 21 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 22 | self.save_dir = '{}-{}_{}-{}_{}'.format(config.arch, config.model, 23 | config.dataset, config.num_labels, 24 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 25 | self.save_dir 
= os.path.join(config.save_dir, self.save_dir) 26 | self.usp_weight = config.usp_weight 27 | #self.rampup = pseudo_rampup(config.t1, config.t2) 28 | self.rampup = exp_rampup(config.weight_rampup) 29 | self.save_freq = config.save_freq 30 | self.print_freq = config.print_freq 31 | self.device = device 32 | self.epoch = 0 33 | 34 | def train_iteration(self, label_loader, unlab_loader, print_freq): 35 | loop_info = defaultdict(list) 36 | batch_idx, label_n, unlab_n = 0, 0, 0 37 | for (label_x, label_y), (unlab_x, unlab_y) in zip(cycle(label_loader), unlab_loader): 38 | label_x, label_y = label_x.to(self.device), label_y.to(self.device) 39 | unlab_x, unlab_y = unlab_x.to(self.device), unlab_y.to(self.device) 40 | ##=== decode targets of unlabeled data === 41 | self.decode_targets(unlab_y) 42 | lbs, ubs = label_x.size(0), unlab_x.size(0) 43 | 44 | ##=== forward === 45 | outputs = self.model(label_x) 46 | loss = self.ce_loss(outputs, label_y) 47 | loop_info['lloss'].append(loss.item()) 48 | 49 | ##=== Semi-supervised Training === 50 | ## pslab loss 51 | unlab_outputs = self.model(unlab_x) 52 | with torch.no_grad(): 53 | iter_unlab_pslab = unlab_outputs.max(1)[1] 54 | iter_unlab_pslab.detach_() 55 | uloss = self.ce_loss(unlab_outputs, iter_unlab_pslab) 56 | uloss *= self.rampup(self.epoch)*self.usp_weight 57 | loss += uloss; loop_info['uloss'].append(uloss.item()) 58 | ## backward 59 | self.optimizer.zero_grad() 60 | loss.backward() 61 | self.optimizer.step() 62 | 63 | ##=== log info === 64 | batch_idx, label_n, unlab_n = batch_idx+1, label_n+lbs, unlab_n+ubs 65 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 66 | loop_info['uacc'].append(unlab_y.eq(unlab_outputs.max(1)[1]).float().sum().item()) 67 | if print_freq>0 and (batch_idx%print_freq)==0: 68 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 69 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 70 | return loop_info, label_n 71 | 72 | def test_iteration(self, data_loader, print_freq): 73 | loop_info = defaultdict(list) 74 | label_n, unlab_n = 0, 0 75 | for batch_idx, (data, targets) in enumerate(data_loader): 76 | data, targets = data.to(self.device), targets.to(self.device) 77 | ##=== decode targets === 78 | lbs, ubs = data.size(0), -1 79 | 80 | ##=== forward === 81 | outputs = self.model(data) 82 | loss = self.ce_loss(outputs, targets) 83 | loop_info['lloss'].append(loss.item()) 84 | 85 | ##=== log info === 86 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 87 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 88 | if print_freq>0 and (batch_idx%print_freq)==0: 89 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 90 | print(">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 91 | return loop_info, label_n 92 | 93 | def train(self, label_loader, unlab_loader, print_freq=20): 94 | self.model.train() 95 | with torch.enable_grad(): 96 | return self.train_iteration(label_loader, unlab_loader, print_freq) 97 | 98 | def test(self, data_loader, print_freq=10): 99 | self.model.eval() 100 | with torch.no_grad(): 101 | return self.test_iteration(data_loader, print_freq) 102 | 103 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 104 | best_info, best_acc, n = None, 0., 0 105 | for ep in range(epochs): 106 | self.epoch = ep 107 | if scheduler is not None: scheduler.step() 108 | print("------ Training epochs: {} ------".format(ep)) 109 | self.train(label_data, unlab_data, self.print_freq) 110 
| print("------ Testing epochs: {} ------".format(ep)) 111 | info, n = self.test(test_data, self.print_freq) 112 | acc = sum(info['lacc']) / n 113 | if acc>best_acc: best_info, best_acc = info, acc 114 | ## save model 115 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 116 | self.save(ep) 117 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 118 | 119 | def decode_targets(self, targets): 120 | label_mask = targets.ge(0) 121 | unlab_mask = targets.le(NO_LABEL) 122 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 123 | return label_mask, unlab_mask 124 | 125 | def gen_info(self, info, lbs, ubs, iteration=True): 126 | ret = [] 127 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 128 | for k, val in info.items(): 129 | n = nums[k[0]] 130 | v = val[-1] if iteration else sum(val) 131 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 132 | ret.append(s) 133 | return '\t'.join(ret) 134 | 135 | def save(self, epoch, **kwargs): 136 | if self.save_dir is not None: 137 | model_out_path = Path(self.save_dir) 138 | state = {"epoch": epoch, 139 | "weight": self.model.state_dict()} 140 | if not model_out_path.exists(): 141 | model_out_path.mkdir() 142 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 143 | torch.save(state, save_target) 144 | print('==> save model to {}'.format(save_target)) 145 | -------------------------------------------------------------------------------- /trainer/iTempensv1.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from collections import defaultdict 9 | 10 | from utils.loss import mse_with_softmax 11 | from utils.ramps import exp_rampup, pseudo_rampup 12 | from utils.datasets import decode_label 13 | from utils.data_utils import NO_LABEL 14 | 15 | from pdb import set_trace 16 | 17 | class Trainer: 18 | 19 | def __init__(self, model, optimizer, device, config): 20 | print('Tempens-v1 with iteration pseudo labels') 21 | self.model = model 22 | self.optimizer = optimizer 23 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 24 | self.mse_loss = mse_with_softmax # F.mse_loss 25 | self.save_dir = '{}_{}-{}_{}'.format(config.arch, config.dataset, 26 | config.num_labels, 27 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 28 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 29 | self.usp_weight = config.usp_weight 30 | self.save_freq = config.save_freq 31 | self.print_freq = config.print_freq 32 | self.device = device 33 | self.epoch = 0 34 | self.start_epoch = 0 35 | self.ema_decay = config.ema_decay 36 | self.rampup = exp_rampup(config.rampup_length) 37 | 38 | def train_iteration(self, data_loader, print_freq): 39 | loop_info = defaultdict(list) 40 | label_n, unlab_n = 0, 0 41 | for batch_idx, (data, targets, idxs) in enumerate(data_loader): 42 | data, targets = data.to(self.device), targets.to(self.device) 43 | ##=== decode targets === 44 | lmask, umask = self.decode_targets(targets) 45 | lbs, ubs = lmask.float().sum().item(), umask.float().sum().item() 46 | 47 | ##=== forward === 48 | outputs = self.model(data) 49 | loss = self.ce_loss(outputs[lmask], targets[lmask]) 50 | loop_info['lloss'].append(loss.item()) 51 | 52 | ##=== Semi-supervised Training === 53 | with torch.no_grad(): 54 | ema_iter_pslab = self.update_ema(outputs.clone().detach(), idxs) 55 | uloss = self.mse_loss(outputs, ema_iter_pslab) 56 | uloss *= 
self.rampup(self.epoch)*self.usp_weight 57 | loss += uloss; loop_info['aTmp'].append(uloss.item()) 58 | 59 | ## bachward 60 | self.optimizer.zero_grad() 61 | loss.backward() 62 | self.optimizer.step() 63 | 64 | ##=== log info === 65 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 66 | lacc = targets[lmask].eq(outputs[lmask].max(1)[1]).float().sum().item() 67 | uacc = targets[umask].eq(outputs[umask].max(1)[1]).float().sum().item() 68 | u2acc = targets[umask].eq(ema_iter_pslab[umask].max(1)[1]).float().sum().item() 69 | loop_info['lacc'].append(lacc) 70 | loop_info['uacc'].append(uacc) 71 | loop_info['u2acc'].append(u2acc) 72 | if print_freq>0 and (batch_idx%print_freq)==0: 73 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 74 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 75 | return loop_info, label_n 76 | 77 | def test_iteration(self, data_loader, print_freq): 78 | loop_info = defaultdict(list) 79 | label_n, unlab_n = 0, 0 80 | for batch_idx, (data, targets) in enumerate(data_loader): 81 | data, targets = data.to(self.device), targets.to(self.device) 82 | lbs, ubs = data.size(0), -1 83 | 84 | ##=== forward === 85 | outputs = self.model(data) 86 | loss = self.ce_loss(outputs, targets) 87 | loop_info['lloss'].append(loss.item()) 88 | 89 | ##=== log info === 90 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 91 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 92 | if print_freq>0 and (batch_idx%print_freq)==0: 93 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 94 | print(f">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 95 | return loop_info, label_n 96 | 97 | def train(self, data_loader, print_freq=20): 98 | self.model.train() 99 | with torch.enable_grad(): 100 | return self.train_iteration(data_loader, print_freq) 101 | 102 | def test(self, data_loader, print_freq=10): 103 | self.model.eval() 104 | with torch.no_grad(): 105 | return self.test_iteration(data_loader, print_freq) 106 | 107 | def update_ema(self, iter_pslab, idxs): 108 | """update every iteration""" 109 | ema_iter_pslab = (self.ema_decay*self.ema_pslab[idxs]) + (1.0-self.ema_decay)*iter_pslab 110 | self.ema_pslab[idxs] = ema_iter_pslab 111 | return ema_iter_pslab / (1.0 - self.ema_decay**((self.epoch-self.start_epoch)+1.0)) 112 | 113 | def loop(self, epochs, train_data, test_data, scheduler=None): 114 | ## construct epoch pseudo labels 115 | self.ema_pslab = self.create_soft_pslab(n_samples=len(train_data.dataset), 116 | n_classes=train_data.dataset.num_classes, 117 | dtype='zero') 118 | ## main process 119 | best_info, best_acc, n = None, 0., 0 120 | for ep in range(epochs): 121 | self.epoch = ep 122 | if scheduler is not None: scheduler.step() 123 | print("------ Training epochs: {} ------".format(ep)) 124 | self.train(train_data, self.print_freq) 125 | print("------ Testing epochs: {} ------".format(ep)) 126 | info, n = self.test(test_data, self.print_freq) 127 | acc = sum(info['lacc']) / n 128 | if acc>best_acc: best_info, best_acc = info, acc 129 | ## save model 130 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 131 | self.save(ep) 132 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 133 | 134 | def create_soft_pslab(self, n_samples, n_classes, dtype='rand'): 135 | if dtype=='rand': 136 | pslab = torch.randint(0, n_classes, (n_samples,n_classes)) 137 | elif dtype=='zero': 138 | pslab = torch.zeros(n_samples, n_classes) 139 | else: 140 | raise ValueError('Unknown pslab dtype: 
{}'.format(dtype)) 141 | return pslab.to(self.device) 142 | 143 | def decode_targets(self, targets): 144 | label_mask = targets.ge(0) 145 | unlab_mask = targets.le(NO_LABEL) 146 | targets[unlab_mask] = decode_label(targets[unlab_mask]) 147 | return label_mask, unlab_mask 148 | 149 | def gen_info(self, info, lbs, ubs, iteration=True): 150 | ret = [] 151 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 152 | for k, val in info.items(): 153 | n = nums[k[0]] 154 | v = val[-1] if iteration else sum(val) 155 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 156 | ret.append(s) 157 | return '\t'.join(ret) 158 | 159 | 160 | def save(self, epoch, **kwargs): 161 | if self.save_dir is not None: 162 | model_out_path = Path(self.save_dir) 163 | state = {"epoch": epoch, 164 | "weight": self.model.state_dict()} 165 | if not model_out_path.exists(): 166 | model_out_path.mkdir() 167 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 168 | torch.save(state, save_target) 169 | print('==> save model to {}'.format(save_target)) 170 | -------------------------------------------------------------------------------- /trainer/iTempensv2.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import datetime 7 | from pathlib import Path 8 | from itertools import cycle 9 | from collections import defaultdict 10 | 11 | from utils.loss import mse_with_softmax 12 | from utils.ramps import exp_rampup, pseudo_rampup 13 | from utils.datasets import decode_label 14 | from utils.data_utils import NO_LABEL 15 | 16 | class Trainer: 17 | 18 | def __init__(self, model, optimizer, device, config): 19 | print('Tempens-v2 with iteration pseudo labels') 20 | self.model = model 21 | self.optimizer = optimizer 22 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=NO_LABEL) 23 | self.mse_loss = mse_with_softmax # F.mse_loss 24 | self.save_dir = '{}_{}-{}_{}'.format(config.arch, config.dataset, 25 | config.num_labels, 26 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")) 27 | self.save_dir = os.path.join(config.save_dir, self.save_dir) 28 | self.save_freq = config.save_freq 29 | self.print_freq = config.print_freq 30 | self.device = device 31 | self.epoch = 0 32 | self.start_epoch = 0 33 | self.usp_weight = config.usp_weight 34 | self.ema_decay = config.ema_decay 35 | self.rampup = exp_rampup(config.rampup_length) 36 | 37 | def train_iteration(self, label_loader, unlab_loader, print_freq): 38 | loop_info = defaultdict(list) 39 | batch_idx, label_n, unlab_n = 0, 0, 0 40 | for (label_x, label_y, ldx), (unlab_x, unlab_y, udx) in zip(cycle(label_loader), unlab_loader): 41 | label_x, label_y = label_x.to(self.device), label_y.to(self.device) 42 | unlab_x, unlab_y = unlab_x.to(self.device), unlab_y.to(self.device) 43 | ##=== decode targets of unlabeled data === 44 | self.decode_targets(unlab_y) 45 | lbs, ubs = label_x.size(0), unlab_x.size(0) 46 | 47 | ##=== forward === 48 | outputs = self.model(label_x) 49 | loss = self.ce_loss(outputs, label_y) 50 | loop_info['lSup'].append(loss.item()) 51 | 52 | ##=== Semi-supervised Training Phase === 53 | unlab_outputs = self.model(unlab_x) 54 | with torch.no_grad(): 55 | ema_iter_pslab = self.update_ema(unlab_outputs.clone().detach(), udx) 56 | uloss = self.mse_loss(unlab_outputs, ema_iter_pslab) 57 | uloss *= self.rampup(self.epoch)*self.usp_weight 58 | loss += uloss; loop_info['uTmp'].append(uloss.item()) 59 | 60 | ## bachward 61 | 
self.optimizer.zero_grad() 62 | loss.backward() 63 | self.optimizer.step() 64 | 65 | ##=== log info === 66 | batch_idx, label_n, unlab_n = batch_idx+1, label_n+lbs, unlab_n+ubs 67 | loop_info['lacc'].append(label_y.eq(outputs.max(1)[1]).float().sum().item()) 68 | loop_info['uacc'].append(unlab_y.eq(unlab_outputs.max(1)[1]).float().sum().item()) 69 | if print_freq>0 and (batch_idx%print_freq)==0: 70 | print(f"[train][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 71 | print(">>>[train]", self.gen_info(loop_info, label_n, unlab_n, False)) 72 | return loop_info, label_n 73 | 74 | def test_iteration(self, data_loader, print_freq): 75 | loop_info = defaultdict(list) 76 | label_n, unlab_n = 0, 0 77 | for batch_idx, (data, targets) in enumerate(data_loader): 78 | data, targets = data.to(self.device), targets.to(self.device) 79 | ##=== decode targets === 80 | lbs, ubs = data.size(0), -1 81 | 82 | ##=== forward === 83 | outputs = self.model(data) 84 | loss = self.ce_loss(outputs, targets) 85 | loop_info['lloss'].append(loss.item()) 86 | 87 | ##=== log info === 88 | label_n, unlab_n = label_n+lbs, unlab_n+ubs 89 | loop_info['lacc'].append(targets.eq(outputs.max(1)[1]).float().sum().item()) 90 | if print_freq>0 and (batch_idx%print_freq)==0: 91 | print(f"[test][{batch_idx:<3}]", self.gen_info(loop_info, lbs, ubs)) 92 | print(">>>[test]", self.gen_info(loop_info, label_n, unlab_n, False)) 93 | return loop_info, label_n 94 | 95 | def train(self, label_loader, unlab_loader, print_freq=20): 96 | self.model.train() 97 | with torch.enable_grad(): 98 | return self.train_iteration(label_loader, unlab_loader, print_freq) 99 | 100 | def test(self, data_loader, print_freq=10): 101 | self.model.eval() 102 | with torch.no_grad(): 103 | return self.test_iteration(data_loader, print_freq) 104 | 105 | def update_ema(self, iter_pslab, idxs): 106 | """update every iteration""" 107 | ema_iter_pslab = (self.ema_decay*self.ema_pslab[idxs]) + (1.0-self.ema_decay)*iter_pslab 108 | self.ema_pslab[idxs] = ema_iter_pslab 109 | return ema_iter_pslab / (1.0 - self.ema_decay**((self.epoch-self.start_epoch)+1.0)) 110 | 111 | def loop(self, epochs, label_data, unlab_data, test_data, scheduler=None): 112 | ## construct epoch pseudo labels 113 | self.ema_pslab = self.create_soft_pslab(n_samples=len(unlab_data.dataset), 114 | n_classes=unlab_data.dataset.num_classes, 115 | dtype='zero') 116 | ## main process 117 | best_info, best_acc, n = None, 0., 0 118 | for ep in range(epochs): 119 | self.epoch = ep 120 | if scheduler is not None: scheduler.step() 121 | print("------ Training epochs: {} ------".format(ep)) 122 | self.train(label_data, unlab_data, self.print_freq) 123 | print("------ Testing epochs: {} ------".format(ep)) 124 | info, n = self.test(test_data, self.print_freq) 125 | acc = sum(info['lacc']) / n 126 | if acc>best_acc: best_info, best_acc = info, acc 127 | ## save model 128 | if self.save_freq!=0 and (ep+1)%self.save_freq == 0: 129 | self.save(ep) 130 | print(f">>>[best]", self.gen_info(best_info, n, n, False)) 131 | 132 | def create_soft_pslab(self, n_samples, n_classes, dtype='rand'): 133 | if dtype=='rand': 134 | pslab = torch.randint(0, n_classes, (n_samples,n_classes)) 135 | elif dtype=='zero': 136 | pslab = torch.zeros(n_samples, n_classes) 137 | else: 138 | raise ValueError('Unknown pslab dtype: {}'.format(dtype)) 139 | return pslab.to(self.device) 140 | 141 | def decode_targets(self, targets): 142 | label_mask = targets.ge(0) 143 | unlab_mask = targets.le(NO_LABEL) 144 | targets[unlab_mask] = 
decode_label(targets[unlab_mask]) 145 | return label_mask, unlab_mask 146 | 147 | def gen_info(self, info, lbs, ubs, iteration=True): 148 | ret = [] 149 | nums = {'l': lbs, 'u':ubs, 'a': lbs+ubs} 150 | for k, val in info.items(): 151 | n = nums[k[0]] 152 | v = val[-1] if iteration else sum(val) 153 | s = f'{k}: {v/n:.3%}' if k[-1]=='c' else f'{k}: {v:.5f}' 154 | ret.append(s) 155 | return '\t'.join(ret) 156 | 157 | def save(self, epoch, **kwargs): 158 | if self.save_dir is not None: 159 | model_out_path = Path(self.save_dir) 160 | state = {"epoch": epoch, 161 | "weight": self.model.state_dict()} 162 | if not model_out_path.exists(): 163 | model_out_path.mkdir() 164 | save_target = model_out_path / "model_epoch_{}.pth".format(epoch) 165 | torch.save(state, save_target) 166 | print('==> save model to {}'.format(save_target)) 167 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorch/be90060b3017e99b8c53a596110cb5931ec9f38c/utils/__init__.py -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | __all__ = ['parse_commandline_args'] 4 | 5 | 6 | def create_parser(): 7 | parser = argparse.ArgumentParser(description='Semi-supervised Training --PyTorch') 8 | 9 | # Log and save 10 | parser.add_argument('--print-freq', default=20, type=int, metavar='N', help='display frequency (default: 20)') 11 | parser.add_argument('--save-freq', default=0, type=int, metavar='EPOCHS', help='checkpoint frequency (default: 0)') 12 | parser.add_argument('--save-dir', default='./checkpoints', type=str, metavar='DIR') 13 | 14 | # Data 15 | parser.add_argument('--dataset', metavar='DATASET') 16 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') 17 | parser.add_argument('--num-labels', type=int, metavar='N', help='number of labeled samples') 18 | parser.add_argument('--sup-batch-size', default=100, type=int, metavar='N', help='batch size for supervised data (default: 100)') 19 | parser.add_argument('--usp-batch-size', default=100, type=int, metavar='N', help='batch size for unsupervised data (default: 100)') 20 | 21 | # Data pre-processing 22 | parser.add_argument('--data-twice', default=False, type=str2bool, metavar='BOOL', help='use two data streams (default: False)') 23 | parser.add_argument('--data-idxs', default=False, type=str2bool, metavar='BOOL', help='enable indexes of samples (default: False)') 24 | parser.add_argument('--label-exclude', type=str2bool, metavar='BOOL', help='exclude labeled samples in unsupervised batch') 25 | 26 | # Architecture 27 | parser.add_argument('--arch', '-a', metavar='ARCH') 28 | parser.add_argument('--model', metavar='MODEL') 29 | parser.add_argument('--drop-ratio', default=0., type=float, help='ratio of dropout (default: 0)') 30 | 31 | # Optimization 32 | parser.add_argument('--epochs', type=int, metavar='N', help='number of total training epochs') 33 | parser.add_argument('--optim', default="sgd", type=str, metavar='TYPE', choices=['sgd', 'adam']) 34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum (default: 0.9)') 35 | parser.add_argument('--nesterov', default=False, type=str2bool, metavar='BOOL', help='use 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorch/be90060b3017e99b8c53a596110cb5931ec9f38c/utils/__init__.py
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | __all__ = ['parse_commandline_args']
4 | 
5 | 
6 | def create_parser():
7 |     parser = argparse.ArgumentParser(description='Semi-supervised Training --PyTorch')
8 | 
9 |     # Log and save
10 |     parser.add_argument('--print-freq', default=20, type=int, metavar='N', help='display frequency (default: 20)')
11 |     parser.add_argument('--save-freq', default=0, type=int, metavar='EPOCHS', help='checkpoint frequency (default: 0)')
12 |     parser.add_argument('--save-dir', default='./checkpoints', type=str, metavar='DIR')
13 | 
14 |     # Data
15 |     parser.add_argument('--dataset', metavar='DATASET')
16 |     parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
17 |     parser.add_argument('--num-labels', type=int, metavar='N', help='number of labeled samples')
18 |     parser.add_argument('--sup-batch-size', default=100, type=int, metavar='N', help='batch size for supervised data (default: 100)')
19 |     parser.add_argument('--usp-batch-size', default=100, type=int, metavar='N', help='batch size for unsupervised data (default: 100)')
20 | 
21 |     # Data pre-processing
22 |     parser.add_argument('--data-twice', default=False, type=str2bool, metavar='BOOL', help='use two data streams (default: False)')
23 |     parser.add_argument('--data-idxs', default=False, type=str2bool, metavar='BOOL', help='enable indexes of samples (default: False)')
24 |     parser.add_argument('--label-exclude', type=str2bool, metavar='BOOL', help='exclude labeled samples from the unsupervised batch')
25 | 
26 |     # Architecture
27 |     parser.add_argument('--arch', '-a', metavar='ARCH')
28 |     parser.add_argument('--model', metavar='MODEL')
29 |     parser.add_argument('--drop-ratio', default=0., type=float, help='ratio of dropout (default: 0)')
30 | 
31 |     # Optimization
32 |     parser.add_argument('--epochs', type=int, metavar='N', help='number of total training epochs')
33 |     parser.add_argument('--optim', default="sgd", type=str, metavar='TYPE', choices=['sgd', 'adam'])
34 |     parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum (default: 0.9)')
35 |     parser.add_argument('--nesterov', default=False, type=str2bool, metavar='BOOL', help='use nesterov momentum (default: False)')
36 |     parser.add_argument('--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
37 | 
38 |     # LR scheduler
39 |     parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='max learning rate (default: 0.1)')
40 |     parser.add_argument('--lr-scheduler', default="cos", type=str, choices=['cos', 'multistep', 'exp-warmup', 'none'])
41 |     parser.add_argument('--min-lr', default=1e-4, type=float, metavar='LR', help='minimum learning rate (default: 1e-4)')
42 |     parser.add_argument('--steps', type=int, nargs='+', metavar='N', help='decay steps for multistep scheduler')
43 |     parser.add_argument('--gamma', type=float, help='factor of learning rate decay')
44 |     parser.add_argument('--rampup-length', type=int, metavar='EPOCHS', help='length of the ramp-up')
45 |     parser.add_argument('--rampdown-length', type=int, metavar='EPOCHS', help='length of the ramp-down')
46 | 
47 |     # Pseudo-Label 2013
48 |     parser.add_argument('--t1', type=float, metavar='EPOCHS', help='T1')
49 |     parser.add_argument('--t2', type=float, metavar='EPOCHS', help='T2')
50 |     parser.add_argument('--soft', type=str2bool, help='use soft pseudo label')
51 | 
52 |     # VAT
53 |     parser.add_argument('--xi', type=float, metavar='W', help='xi for VAT')
54 |     parser.add_argument('--eps', type=float, metavar='W', help='epsilon for VAT')
55 |     parser.add_argument('--n-power', type=int, metavar='N', help='the iteration number of power iteration method in VAT')
56 | 
57 |     # FixMatch
58 |     parser.add_argument('--threshold', type=float, metavar='W', help='threshold for confident predictions in FixMatch')
59 | 
60 |     # MeanTeacher-based method
61 |     parser.add_argument('--ema-decay', type=float, metavar='W', help='ema weight decay')
62 | 
63 |     # Mixup-based method
64 |     parser.add_argument('--mixup-alpha', type=float, metavar='W', help='mixup alpha for beta distribution')
65 | 
66 |     # Opt for loss
67 |     parser.add_argument('--usp-weight', default=1.0, type=float, metavar='W', help='the upper bound of the unsupervised loss weight (default: 1.0)')
68 |     parser.add_argument('--weight-rampup', default=30, type=int, metavar='EPOCHS', help='the length of the weight rampup (default: 30)')
69 |     parser.add_argument('--ent-weight', type=float, metavar='W', help='the weight of minEnt regularization')
70 | 
71 |     return parser
72 | 
73 | 
74 | def parse_commandline_args():
75 |     return create_parser().parse_args()
76 | 
77 | 
78 | def str2bool(v):
79 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
80 |         return True
81 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
82 |         return False
83 |     else:
84 |         raise argparse.ArgumentTypeError('Boolean value expected.')
85 | 
--------------------------------------------------------------------------------
/utils/context.py:
--------------------------------------------------------------------------------
1 | #!coding:utf-8
2 | import contextlib
3 | 
4 | @contextlib.contextmanager
5 | def disable_tracking_bn_stats(model):
6 | 
7 |     def switch_attr(m):
8 |         if hasattr(m, 'track_running_stats'):
9 |             m.track_running_stats ^= True
10 | 
11 |     model.apply(switch_attr)
12 |     yield
13 |     model.apply(switch_attr)
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import itertools
3 | import numpy as np
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | import torchvision.transforms as transforms
7 | from 
torch.utils.data.sampler import Sampler 8 | 9 | NO_LABEL = -1 10 | 11 | class DataSetWarpper(Dataset): 12 | """Enable dataset to output index of sample 13 | """ 14 | def __init__(self, dataset, num_classes): 15 | self.dataset = dataset 16 | self.num_classes = num_classes 17 | 18 | def __getitem__(self, index): 19 | sample, label = self.dataset[index] 20 | return sample, label, index 21 | 22 | def __len__(self): 23 | return len(self.dataset) 24 | 25 | class TransformTwice: 26 | 27 | def __init__(self, transform): 28 | self.transform = transform 29 | 30 | def __call__(self, inp): 31 | out1 = self.transform(inp) 32 | out2 = self.transform(inp) 33 | return out1, out2 34 | 35 | class TransformWeakStrong: 36 | 37 | def __init__(self, trans1, trans2): 38 | self.transform1 = trans1 39 | self.transform2 = trans2 40 | 41 | def __call__(self, inp): 42 | out1 = self.transform1(inp) 43 | out2 = self.transform2(inp) 44 | return out1, out2 45 | 46 | 47 | class TwoStreamBatchSampler(Sampler): 48 | """Iterate two sets of indices 49 | """ 50 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 51 | self.primary_indices = primary_indices 52 | self.primary_batch_size = batch_size - secondary_batch_size 53 | self.secondary_indices = secondary_indices 54 | self.secondary_batch_size = secondary_batch_size 55 | 56 | assert len(self.primary_indices) >= self.primary_batch_size > 0 57 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 58 | 59 | def __iter__(self): 60 | primary_iter = iterate_once(self.primary_indices) 61 | secondary_iter = iterate_eternally(self.secondary_indices) 62 | return ( 63 | secondary_batch + primary_batch 64 | for (primary_batch, secondary_batch) 65 | in zip(grouper(primary_iter, self.primary_batch_size), 66 | grouper(secondary_iter, self.secondary_batch_size)) 67 | ) 68 | 69 | def __len__(self): 70 | return len(self.primary_indices) // self.primary_batch_size 71 | 72 | def iterate_once(iterable): 73 | return np.random.permutation(iterable) 74 | 75 | def iterate_eternally(indices, is_shuffle=True): 76 | shuffleFunc = np.random.permutation if is_shuffle else lambda x: x 77 | def infinite_shuffles(): 78 | while True: 79 | yield shuffleFunc(indices) 80 | return itertools.chain.from_iterable(infinite_shuffles()) 81 | 82 | def grouper(iterable, n): 83 | args = [iter(iterable)]*n 84 | return zip(*args) 85 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | #!coding:utf-8 2 | import numpy as np 3 | import torchvision as tv 4 | import torchvision.transforms as transforms 5 | 6 | from utils.randAug import RandAugmentMC 7 | from utils.data_utils import NO_LABEL 8 | from utils.data_utils import TransformWeakStrong as wstwice 9 | 10 | load = {} 11 | def register_dataset(dataset): 12 | def warpper(f): 13 | load[dataset] = f 14 | return f 15 | return warpper 16 | 17 | def encode_label(label): 18 | return NO_LABEL* (label +1) 19 | 20 | def decode_label(label): 21 | return NO_LABEL * label -1 22 | 23 | def split_relabel_data(np_labs, labels, label_per_class, 24 | num_classes): 25 | """ Return the labeled indexes and unlabeled_indexes 26 | """ 27 | labeled_idxs = [] 28 | unlabed_idxs = [] 29 | for id in range(num_classes): 30 | indexes = np.where(np_labs==id)[0] 31 | np.random.shuffle(indexes) 32 | labeled_idxs.extend(indexes[:label_per_class]) 33 | unlabed_idxs.extend(indexes[label_per_class:]) 34 | 
np.random.shuffle(labeled_idxs) 35 | np.random.shuffle(unlabed_idxs) 36 | ## relabel dataset 37 | for idx in unlabed_idxs: 38 | labels[idx] = encode_label(labels[idx]) 39 | 40 | return labeled_idxs, unlabed_idxs 41 | 42 | 43 | @register_dataset('cifar10') 44 | def cifar10(n_labels, data_root='./data-local/cifar10/'): 45 | channel_stats = dict(mean = [0.4914, 0.4822, 0.4465], 46 | std = [0.2023, 0.1994, 0.2010]) 47 | train_transform = transforms.Compose([ 48 | transforms.Pad(2, padding_mode='reflect'), 49 | transforms.ColorJitter(brightness=0.4, contrast=0.4, 50 | saturation=0.4, hue=0.1), 51 | transforms.RandomCrop(32), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.ToTensor(), 54 | transforms.Normalize(**channel_stats) 55 | ]) 56 | eval_transform = transforms.Compose([ 57 | transforms.ToTensor(), 58 | transforms.Normalize(**channel_stats) 59 | ]) 60 | trainset = tv.datasets.CIFAR10(data_root, train=True, download=True, 61 | transform=train_transform) 62 | evalset = tv.datasets.CIFAR10(data_root, train=False, download=True, 63 | transform=eval_transform) 64 | num_classes = 10 65 | label_per_class = n_labels // num_classes 66 | labeled_idxs, unlabed_idxs = split_relabel_data( 67 | np.array(trainset.train_labels), 68 | trainset.train_labels, 69 | label_per_class, 70 | num_classes) 71 | return { 72 | 'trainset': trainset, 73 | 'evalset': evalset, 74 | 'label_idxs': labeled_idxs, 75 | 'unlab_idxs': unlabed_idxs, 76 | 'num_classes': num_classes 77 | } 78 | 79 | @register_dataset('wscifar10') 80 | def wscifar10(n_labels, data_root='./data-local/cifar10/'): 81 | channel_stats = dict(mean = [0.4914, 0.4822, 0.4465], 82 | std = [0.2023, 0.1994, 0.2010]) 83 | weak = transforms.Compose([ 84 | transforms.RandomHorizontalFlip(), 85 | transforms.Pad(2, padding_mode='reflect'), 86 | transforms.RandomCrop(32), 87 | transforms.ToTensor(), 88 | transforms.Normalize(**channel_stats) 89 | ]) 90 | strong = transforms.Compose([ 91 | transforms.RandomHorizontalFlip(), 92 | transforms.Pad(2, padding_mode='reflect'), 93 | transforms.RandomCrop(32), 94 | RandAugmentMC(n=2, m=10), 95 | transforms.ToTensor(), 96 | transforms.Normalize(**channel_stats) 97 | ]) 98 | train_transform = wstwice(weak, strong) 99 | eval_transform = transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize(**channel_stats) 102 | ]) 103 | trainset = tv.datasets.CIFAR10(data_root, train=True, download=True, 104 | transform=train_transform) 105 | evalset = tv.datasets.CIFAR10(data_root, train=False, download=True, 106 | transform=eval_transform) 107 | num_classes = 10 108 | label_per_class = n_labels // num_classes 109 | labeled_idxs, unlabed_idxs = split_relabel_data( 110 | np.array(trainset.train_labels), 111 | trainset.train_labels, 112 | label_per_class, 113 | num_classes) 114 | return { 115 | 'trainset': trainset, 116 | 'evalset': evalset, 117 | 'label_idxs': labeled_idxs, 118 | 'unlab_idxs': unlabed_idxs, 119 | 'num_classes': num_classes 120 | } 121 | 122 | 123 | @register_dataset('cifar100') 124 | def cifar100(n_labels, data_root='./data-local/cifar100/'): 125 | channel_stats = dict(mean = [0.5071, 0.4867, 0.4408], 126 | std = [0.2675, 0.2565, 0.2761]) 127 | train_transform = transforms.Compose([ 128 | transforms.Pad(2, padding_mode='reflect'), 129 | transforms.ColorJitter(brightness=0.4, contrast=0.4, 130 | saturation=0.4, hue=0.1), 131 | transforms.RandomCrop(32), 132 | transforms.RandomHorizontalFlip(), 133 | transforms.ToTensor(), 134 | transforms.Normalize(**channel_stats) 135 | ]) 136 | eval_transform = 
transforms.Compose([
137 |         transforms.ToTensor(),
138 |         transforms.Normalize(**channel_stats)
139 |     ])
140 |     trainset = tv.datasets.CIFAR100(data_root, train=True, download=True,
141 |                                     transform=train_transform)
142 |     evalset = tv.datasets.CIFAR100(data_root, train=False, download=True,
143 |                                    transform=eval_transform)
144 |     num_classes = 100
145 |     label_per_class = n_labels // num_classes
146 |     labeled_idxs, unlabed_idxs = split_relabel_data(
147 |                                     np.array(trainset.train_labels),
148 |                                     trainset.train_labels,
149 |                                     label_per_class,
150 |                                     num_classes)
151 |     return {
152 |         'trainset': trainset,
153 |         'evalset': evalset,
154 |         'label_idxs': labeled_idxs,
155 |         'unlab_idxs': unlabed_idxs,
156 |         'num_classes': num_classes
157 |     }
158 | 
--------------------------------------------------------------------------------
/utils/dist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def eucl_dist(x, y):
4 |     """ Compute Pairwise (Squared Euclidean) Distance
5 | 
6 |     Input:
7 |         x: embedding of size M x D
8 |         y: embedding of size N x D
9 | 
10 |     Output:
11 |         dist: pairwise distance of size M x N
12 |     """
13 |     x2 = torch.sum(x**2, dim=1, keepdim=True).expand(-1, y.size(0))
14 |     y2 = torch.sum(y**2, dim=1, keepdim=True).t().expand(x.size(0), -1)
15 |     xy = x.mm(y.t())
16 |     return x2 - 2*xy + y2
17 | 
18 | 
19 | def rbf_graph(x, y, sigma):
20 |     diff = eucl_dist(x, y)
21 |     g = torch.exp(diff / (-2.0 * sigma**2))
22 |     return g / torch.sum(g, dim=1, keepdim=True)
23 | 
24 | 
25 | def cosine_dist(x, y):
26 |     """ Compute Pairwise Cosine Similarity
27 | 
28 |     Input:
29 |         x: embedding of size M x D
30 |         y: embedding of size N x D
31 | 
32 |     Output:
33 |         dist: pairwise similarity of size M x N
34 |     """
35 |     xy = x.mm(y.t())
36 |     x_norm = torch.norm(x, p=2, dim=1, keepdim=True)
37 |     y_norm = torch.norm(y, p=2, dim=1, keepdim=True)
38 |     xy_norm = x_norm.mm( y_norm.t() )
39 |     return xy / xy_norm.add(1e-10)
40 | 
41 | 
42 | def neighbor_graph(x):
43 |     neighbor_n = x.size(0)
44 |     x1 = x.unsqueeze(0).expand(neighbor_n, -1)
45 |     x2 = x.unsqueeze(1).expand(-1, neighbor_n)
46 |     return x1.eq(x2).float()
47 | 
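The helpers above are small enough to sanity-check by hand: `eucl_dist` returns squared Euclidean distances, and `rbf_graph` row-normalizes the corresponding RBF affinities. A short check (illustration only, not a file from the repository, assuming the repo root is on PYTHONPATH):

```python
# Illustration only: verify eucl_dist / rbf_graph on tiny embeddings.
import torch
from utils.dist import eucl_dist, rbf_graph

x = torch.tensor([[0., 0.], [3., 4.]])            # 2 x 2 embeddings
y = torch.tensor([[0., 0.], [3., 4.], [6., 8.]])  # 3 x 2 embeddings

print(eucl_dist(x, y))                    # squared distances: [[0, 25, 100], [25, 0, 25]]
print(rbf_graph(x, y, sigma=1.0).sum(1))  # each row of the affinity graph sums to 1
```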
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | #!coding:utf-8
2 | import torch
3 | from torch.nn import functional as F
4 | 
5 | def kl_div_with_logit(input_logits, target_logits):
6 |     assert input_logits.size()==target_logits.size()
7 |     targets = F.softmax(target_logits, dim=1)
8 |     return F.kl_div(F.log_softmax(input_logits,1), targets)
9 | 
10 | def entropy_y_x(logit):
11 |     soft_logit = F.softmax(logit, dim=1)
12 |     return -torch.mean(torch.sum(soft_logit* F.log_softmax(logit,dim=1), dim=1))
13 | 
14 | def softmax_loss_no_reduce(input_logits, target_logits, eps=1e-10):
15 |     assert input_logits.size()==target_logits.size()
16 |     target_soft = F.softmax(target_logits, dim=1)
17 |     return -torch.sum(target_soft* F.log_softmax(input_logits+eps,dim=1), dim=1)
18 | 
19 | def softmax_loss_mean(input_logits, target_logits, eps=1e-10):
20 |     assert input_logits.size()==target_logits.size()
21 |     target_soft = F.softmax(target_logits, dim=1)
22 |     return -torch.mean(torch.sum(target_soft* F.log_softmax(input_logits+eps,dim=1), dim=1))
23 | 
24 | def sym_mse(logit1, logit2):
25 |     assert logit1.size()==logit2.size()
26 |     return torch.mean((logit1 - logit2)**2)
27 | 
28 | def sym_mse_with_softmax(logit1, logit2):
29 |     assert logit1.size()==logit2.size()
30 |     return torch.mean((F.softmax(logit1,1) - F.softmax(logit2,1))**2)
31 | 
32 | def mse_with_softmax(logit1, logit2):
33 |     assert logit1.size()==logit2.size()
34 |     return F.mse_loss(F.softmax(logit1,1), F.softmax(logit2,1))
35 | 
36 | def one_hot(targets, nClass):
37 |     logits = torch.zeros(targets.size(0), nClass).to(targets.device)
38 |     return logits.scatter_(1,targets.unsqueeze(1),1)
39 | 
40 | def label_smooth(one_hot_labels, epsilon=0.1):
41 |     nClass = one_hot_labels.size(1)
42 |     return ((1.-epsilon)*one_hot_labels + (epsilon/nClass))
43 | 
44 | def uniform_prior_loss(logits):
45 |     logit_avg = torch.mean(F.softmax(logits,dim=1), dim=0)
46 |     num_classes, device = logits.size(1), logits.device
47 |     p = torch.ones(num_classes).to(device) / num_classes
48 |     return -torch.sum(torch.log(logit_avg) * p)
49 | 
--------------------------------------------------------------------------------
/utils/mixup.py:
--------------------------------------------------------------------------------
1 | #!coding:utf-8
2 | import torch
3 | from torch.nn import functional as F
4 | 
5 | import numpy as np
6 | 
7 | def mixup_one_target(x, y, alpha=1.0, device='cuda', is_bias=False):
8 |     """Returns mixed inputs, mixed targets, and lambda
9 |     """
10 |     if alpha > 0:
11 |         lam = np.random.beta(alpha, alpha)
12 |     else:
13 |         lam = 1
14 |     if is_bias: lam = max(lam, 1-lam)
15 | 
16 |     index = torch.randperm(x.size(0)).to(device)
17 | 
18 |     mixed_x = lam*x + (1-lam)*x[index, :]
19 |     mixed_y = lam*y + (1-lam)*y[index]
20 |     return mixed_x, mixed_y, lam
21 | 
22 | 
23 | def mixup_two_targets(x, y, alpha=1.0, device='cuda', is_bias=False):
24 |     """Returns mixed inputs, pairs of targets, and lambda
25 |     """
26 |     if alpha > 0:
27 |         lam = np.random.beta(alpha, alpha)
28 |     else:
29 |         lam = 1
30 |     if is_bias: lam = max(lam, 1-lam)
31 | 
32 |     index = torch.randperm(x.size(0)).to(device)
33 | 
34 |     mixed_x = lam*x + (1-lam)*x[index, :]
35 |     y_a, y_b = y, y[index]
36 |     return mixed_x, y_a, y_b, lam
37 | 
38 | 
39 | def mixup_ce_loss_soft(preds, targets_a, targets_b, lam):
40 |     """ mixed categorical cross-entropy loss for soft labels
41 |     """
42 |     mixup_loss_a = -torch.mean(torch.sum(targets_a* F.log_softmax(preds, dim=1), dim=1))
43 |     mixup_loss_b = -torch.mean(torch.sum(targets_b* F.log_softmax(preds, dim=1), dim=1))
44 | 
45 |     mixup_loss = lam* mixup_loss_a + (1- lam)* mixup_loss_b
46 |     return mixup_loss
47 | 
48 | 
49 | def mixup_ce_loss_hard(preds, targets_a, targets_b, lam):
50 |     """ mixed categorical cross-entropy loss
51 |     """
52 |     mixup_loss_a = F.nll_loss(F.log_softmax(preds, dim=1), targets_a)
53 |     mixup_loss_b = F.nll_loss(F.log_softmax(preds, dim=1), targets_b)
54 | 
55 |     mixup_loss = lam* mixup_loss_a + (1- lam)* mixup_loss_b
56 |     return mixup_loss
57 | 
58 | 
59 | def mixup_ce_loss_with_softmax(preds, targets_a, targets_b, lam):
60 |     """ mixed categorical cross-entropy loss
61 |     """
62 |     mixup_loss_a = -torch.mean(torch.sum(F.softmax(targets_a,1)* F.log_softmax(preds, dim=1), dim=1))
63 |     mixup_loss_b = -torch.mean(torch.sum(F.softmax(targets_b,1)* F.log_softmax(preds, dim=1), dim=1))
64 | 
65 |     mixup_loss = lam* mixup_loss_a + (1- lam)* mixup_loss_b
66 |     return mixup_loss
67 | 
68 | 
69 | def mixup_mse_loss_with_softmax(preds, targets_a, targets_b, lam):
70 |     """ mixed categorical mse loss
71 |     """
72 |     mixup_loss_a = F.mse_loss(F.softmax(preds,1), F.softmax(targets_a,1))
73 |     mixup_loss_b = F.mse_loss(F.softmax(preds,1), F.softmax(targets_b,1))
74 | 
75 |     mixup_loss = lam* mixup_loss_a + (1- lam)* mixup_loss_b
76 |     return mixup_loss
77 | 
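The MixUp helpers above are meant to be paired: one of the mixup_*_loss functions consumes the outputs of mixup_two_targets (hard, index targets) or mixup_one_target (soft targets, e.g. built with one_hot from utils/loss.py). A minimal sketch with a stand-in linear classifier (illustration only, not a file from the repository):

```python
# Illustration only: typical wiring of the MixUp helpers on a CIFAR-like batch.
import torch
from utils.mixup import mixup_two_targets, mixup_ce_loss_hard

model = torch.nn.Linear(3 * 32 * 32, 10)          # stand-in classifier
x = torch.randn(8, 3, 32, 32)                     # small fake image batch
y = torch.randint(0, 10, (8,))                    # hard labels (class indices)

mixed_x, y_a, y_b, lam = mixup_two_targets(x, y, alpha=1.0, device='cpu')
logits = model(mixed_x.view(8, -1))
loss = mixup_ce_loss_hard(logits, y_a, y_b, lam)  # lam-weighted CE against both targets
loss.backward()
```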
-------------------------------------------------------------------------------- /utils/ramps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def pseudo_rampup(T1, T2): 4 | def warpper(epoch): 5 | if epoch > T1: 6 | alpha = (epoch-T1) / (T2-T1) 7 | if epoch > T2: 8 | alpha = 1.0 9 | else: 10 | alpha = 0.0 11 | return alpha 12 | return warpper 13 | 14 | 15 | def exp_rampup(rampup_length): 16 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 17 | def warpper(epoch): 18 | if epoch < rampup_length: 19 | epoch = np.clip(epoch, 0.0, rampup_length) 20 | phase = 1.0 - epoch / rampup_length 21 | return float(np.exp(-5.0 * phase * phase)) 22 | else: 23 | return 1.0 24 | return warpper 25 | 26 | 27 | def linear_rampup(rampup_length): 28 | """Linear rampup""" 29 | def warpper(epoch): 30 | if epoch < rampup_length: 31 | return epoch / rampup_length 32 | else: 33 | return 1.0 34 | return warpper 35 | 36 | 37 | def exp_rampdown(rampdown_length, num_epochs): 38 | """Exponential rampdown from https://arxiv.org/abs/1610.02242""" 39 | def warpper(epoch): 40 | if epoch >= (num_epochs - rampdown_length): 41 | ep = .5* (epoch - (num_epochs - rampdown_length)) 42 | return float(np.exp(-(ep * ep) / rampdown_length)) 43 | else: 44 | return 1.0 45 | return warpper 46 | 47 | 48 | def cosine_rampdown(rampdown_length, num_epochs): 49 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 50 | def warpper(epoch): 51 | if epoch >= (num_epochs - rampdown_length): 52 | ep = .5* (epoch - (num_epochs - rampdown_length)) 53 | return float(.5 * (np.cos(np.pi * ep / rampdown_length) + 1)) 54 | else: 55 | return 1.0 56 | return warpper 57 | 58 | 59 | def exp_warmup(rampup_length, rampdown_length, num_epochs): 60 | rampup = exp_rampup(rampup_length) 61 | rampdown = exp_rampdown(rampdown_length, num_epochs) 62 | def warpper(epoch): 63 | return rampup(epoch)*rampdown(epoch) 64 | return warpper 65 | 66 | 67 | def test_warmup(): 68 | warmup = exp_warmup(80, 50, 500) 69 | for ep in range(500): 70 | print(warmup(ep)) 71 | -------------------------------------------------------------------------------- /utils/randAug.py: -------------------------------------------------------------------------------- 1 | """MIT License 2 | 3 | Copyright (c) 2019 Jungdae Kim, Qing Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | """ 23 | 24 | # code in this file is adpated from 25 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 26 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 27 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 28 | import logging 29 | import random 30 | 31 | import numpy as np 32 | import PIL 33 | import PIL.ImageOps 34 | import PIL.ImageEnhance 35 | import PIL.ImageDraw 36 | from PIL import Image 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | PARAMETER_MAX = 10 41 | 42 | 43 | def AutoContrast(img, **kwarg): 44 | return PIL.ImageOps.autocontrast(img) 45 | 46 | 47 | def Brightness(img, v, max_v, bias=0): 48 | v = _float_parameter(v, max_v) + bias 49 | return PIL.ImageEnhance.Brightness(img).enhance(v) 50 | 51 | 52 | def Color(img, v, max_v, bias=0): 53 | v = _float_parameter(v, max_v) + bias 54 | return PIL.ImageEnhance.Color(img).enhance(v) 55 | 56 | 57 | def Contrast(img, v, max_v, bias=0): 58 | v = _float_parameter(v, max_v) + bias 59 | return PIL.ImageEnhance.Contrast(img).enhance(v) 60 | 61 | 62 | def Cutout(img, v, max_v, bias=0): 63 | if v == 0: 64 | return img 65 | v = _float_parameter(v, max_v) + bias 66 | v = int(v * min(img.size)) 67 | return CutoutAbs(img, v) 68 | 69 | 70 | def CutoutAbs(img, v, **kwarg): 71 | w, h = img.size 72 | x0 = np.random.uniform(0, w) 73 | y0 = np.random.uniform(0, h) 74 | x0 = int(max(0, x0 - v / 2.)) 75 | y0 = int(max(0, y0 - v / 2.)) 76 | x1 = int(min(w, x0 + v)) 77 | y1 = int(min(h, y0 + v)) 78 | xy = (x0, y0, x1, y1) 79 | # gray 80 | color = (127, 127, 127) 81 | img = img.copy() 82 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 83 | return img 84 | 85 | 86 | def Equalize(img, **kwarg): 87 | return PIL.ImageOps.equalize(img) 88 | 89 | 90 | def Identity(img, **kwarg): 91 | return img 92 | 93 | 94 | def Invert(img, **kwarg): 95 | return PIL.ImageOps.invert(img) 96 | 97 | 98 | def Posterize(img, v, max_v, bias=0): 99 | v = _int_parameter(v, max_v) + bias 100 | return PIL.ImageOps.posterize(img, v) 101 | 102 | 103 | def Rotate(img, v, max_v, bias=0): 104 | v = _int_parameter(v, max_v) + bias 105 | if random.random() < 0.5: 106 | v = -v 107 | return img.rotate(v) 108 | 109 | 110 | def Sharpness(img, v, max_v, bias=0): 111 | v = _float_parameter(v, max_v) + bias 112 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 113 | 114 | 115 | def ShearX(img, v, max_v, bias=0): 116 | v = _float_parameter(v, max_v) + bias 117 | if random.random() < 0.5: 118 | v = -v 119 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 120 | 121 | 122 | def ShearY(img, v, max_v, bias=0): 123 | v = _float_parameter(v, max_v) + bias 124 | if random.random() < 0.5: 125 | v = -v 126 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 127 | 128 | 129 | def Solarize(img, v, max_v, bias=0): 130 | v = _int_parameter(v, max_v) + bias 131 | return PIL.ImageOps.solarize(img, 256 - v) 132 | 133 | 134 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 135 | v = _int_parameter(v, max_v) + bias 136 | if random.random() < 0.5: 137 | v = -v 138 | img_np = np.array(img).astype(np.int) 139 | img_np = img_np + v 140 | img_np = np.clip(img_np, 0, 255) 141 | img_np = img_np.astype(np.uint8) 142 | img = Image.fromarray(img_np) 143 | return PIL.ImageOps.solarize(img, threshold) 144 | 145 | 146 | def TranslateX(img, v, max_v, bias=0): 147 | v = _float_parameter(v, max_v) + bias 148 | if random.random() < 0.5: 149 | v = -v 150 
| v = int(v * img.size[0]) 151 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 152 | 153 | 154 | def TranslateY(img, v, max_v, bias=0): 155 | v = _float_parameter(v, max_v) + bias 156 | if random.random() < 0.5: 157 | v = -v 158 | v = int(v * img.size[1]) 159 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 160 | 161 | 162 | def _float_parameter(v, max_v): 163 | return float(v) * max_v / PARAMETER_MAX 164 | 165 | 166 | def _int_parameter(v, max_v): 167 | return int(v * max_v / PARAMETER_MAX) 168 | 169 | 170 | def fixmatch_augment_pool(): 171 | # FixMatch paper 172 | augs = [(AutoContrast, None, None), 173 | (Brightness, 0.9, 0.05), 174 | (Color, 0.9, 0.05), 175 | (Contrast, 0.9, 0.05), 176 | (Equalize, None, None), 177 | (Identity, None, None), 178 | (Posterize, 4, 4), 179 | (Rotate, 30, 0), 180 | (Sharpness, 0.9, 0.05), 181 | (ShearX, 0.3, 0), 182 | (ShearY, 0.3, 0), 183 | (Solarize, 256, 0), 184 | (TranslateX, 0.3, 0), 185 | (TranslateY, 0.3, 0)] 186 | return augs 187 | 188 | 189 | def my_augment_pool(): 190 | # Test 191 | augs = [(AutoContrast, None, None), 192 | (Brightness, 1.8, 0.1), 193 | (Color, 1.8, 0.1), 194 | (Contrast, 1.8, 0.1), 195 | (Cutout, 0.2, 0), 196 | (Equalize, None, None), 197 | (Invert, None, None), 198 | (Posterize, 4, 4), 199 | (Rotate, 30, 0), 200 | (Sharpness, 1.8, 0.1), 201 | (ShearX, 0.3, 0), 202 | (ShearY, 0.3, 0), 203 | (Solarize, 256, 0), 204 | (SolarizeAdd, 110, 0), 205 | (TranslateX, 0.45, 0), 206 | (TranslateY, 0.45, 0)] 207 | return augs 208 | 209 | 210 | class RandAugmentPC(object): 211 | def __init__(self, n, m): 212 | assert n >= 1 213 | assert 1 <= m <= 10 214 | self.n = n 215 | self.m = m 216 | self.augment_pool = my_augment_pool() 217 | 218 | def __call__(self, img): 219 | ops = random.choices(self.augment_pool, k=self.n) 220 | for op, max_v, bias in ops: 221 | prob = np.random.uniform(0.2, 0.8) 222 | if random.random() + prob >= 1: 223 | img = op(img, v=self.m, max_v=max_v, bias=bias) 224 | img = CutoutAbs(img, 16) 225 | return img 226 | 227 | 228 | class RandAugmentMC(object): 229 | def __init__(self, n, m): 230 | assert n >= 1 231 | assert 1 <= m <= 10 232 | self.n = n 233 | self.m = m 234 | self.augment_pool = fixmatch_augment_pool() 235 | 236 | def __call__(self, img): 237 | ops = random.choices(self.augment_pool, k=self.n) 238 | for op, max_v, bias in ops: 239 | v = np.random.randint(1, self.m) 240 | if random.random() < 0.5: 241 | img = op(img, v=v, max_v=max_v, bias=bias) 242 | img = CutoutAbs(img, 16) 243 | return img 244 | 245 | --------------------------------------------------------------------------------
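The wscifar10 entry in utils/datasets.py above already wires RandAugmentMC(n=2, m=10) into the strong branch of its weak/strong transform pair; the same call can be exercised on its own, as in the sketch below (illustration only, not a file from the repository, using a gray placeholder instead of a real CIFAR-10 image):

```python
# Illustration only: apply RandAugmentMC directly to a PIL image.
from PIL import Image
from utils.randAug import RandAugmentMC

augment = RandAugmentMC(n=2, m=10)                  # up to 2 random ops, magnitude parameter m=10
img = Image.new('RGB', (32, 32), (128, 128, 128))   # placeholder for a CIFAR-10 image
strong_img = augment(img)                           # augmented PIL image, with a 16-pixel Cutout applied last
```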