├── results_cifar100 ├── 10.npy ├── 15.npy ├── 20.npy ├── 25.npy ├── 30.npy ├── 35.npy └── 40.npy ├── models ├── __pycache__ │ └── resnet.cpython-36.pyc └── resnet.py ├── config.py ├── README.md ├── sampler.py ├── utils.py ├── arguments.py ├── custom_datasets.py ├── main.py ├── acc100.py ├── model.py ├── resnet.py └── solver.py /results_cifar100/10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/results_cifar100/10.npy -------------------------------------------------------------------------------- /results_cifar100/15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/results_cifar100/15.npy -------------------------------------------------------------------------------- /results_cifar100/20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/results_cifar100/20.npy -------------------------------------------------------------------------------- /results_cifar100/25.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/results_cifar100/25.npy -------------------------------------------------------------------------------- /results_cifar100/30.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/results_cifar100/30.npy -------------------------------------------------------------------------------- /results_cifar100/35.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/results_cifar100/35.npy -------------------------------------------------------------------------------- /results_cifar100/40.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/results_cifar100/40.npy -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beichen1996/SRAAL/HEAD/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | NUM_TRAIN = 50000 2 | NUM_VAL = 50000 - NUM_TRAIN 3 | BATCH = 128 4 | 5 | MARGIN = 1.0 6 | WEIGHT = 1.0 7 | 8 | CYCLES = 7 9 | 10 | EPOCH = 200 11 | LR = 0.1 12 | MILESTONES = [200,300] 13 | 14 | MOMENTUM = 0.9 15 | WDECAY = 5e-4 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # State-Relabeling Adversarial Active Learning 2 | Code for SRAAL [2020 CVPR Oral] 3 | 4 | # Requirements 5 | torch >= 1.6.0 6 | 7 | numpy >= 1.19.1 8 | 9 | tqdm >= 4.31.1 10 | 11 | # AL Results 12 | The AL sampling starts from 10% initial labeled pool(10.npy) and selects 5% data to label at each iteration. 13 | 14 | The result files locate in ./results_cifar100/ 15 | 16 | 17 | 18 | 19 | 20 | # To Train the Model 21 | 22 | python main.py 23 | 24 | # To Evaluate the Results 25 | 26 | python acc100.py 27 | 28 | -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | 5 | class AdversarySampler: 6 | def __init__(self, budget): 7 | self.budget = budget 8 | 9 | 10 | def sample(self, vae, discriminator, data, cuda): 11 | all_preds = [] 12 | all_indices = [] 13 | 14 | for images, _, indices in data: 15 | if cuda: 16 | images = images.cuda() 17 | 18 | with torch.no_grad(): 19 | _, _, mu, _ = vae(images, labeled = 0) 20 | preds = discriminator(mu) 21 | 22 | preds = preds.cpu().data 23 | all_preds.extend(preds) 24 | all_indices.extend(indices) 25 | 26 | all_preds = torch.stack(all_preds) 27 | all_preds = all_preds.view(-1) 28 | 29 | # select the points which the discriminator things are the most likely to be unlabeled 30 | _, querry_indices = torch.topk(all_preds, int(self.budget)) 31 | querry_pool_indices = np.asarray(all_indices)[querry_indices] 32 | 33 | return querry_pool_indices 34 | 35 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | 4 | def cifar10_transformer_train(): 5 | return transforms.Compose([ 6 | transforms.RandomCrop(32, padding=4), 7 | transforms.RandomHorizontalFlip(), 8 | transforms.ToTensor(), 9 | transforms.Normalize(mean=[0.4914009, 0.48215896, 0.44653079], std=[0.24703279, 0.24348423, 0.26158753]), 10 | ]) 11 | def cifar10_transformer_test(): 12 | return transforms.Compose([ 13 | transforms.ToTensor(), 14 | transforms.Normalize(mean=[0.4914009, 0.48215896, 0.44653079], std=[0.24703279, 0.24348423, 0.26158753]), 15 | ]) 16 | 17 | def cifar100_transformer_train(): 18 | return transforms.Compose([ 19 | transforms.RandomCrop(32, padding=4), 20 | transforms.RandomHorizontalFlip(), 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=[0.50707543, 0.48655024, 0.44091907], std=[0.26733398, 0.25643876, 0.27615029]), 23 | ]) 24 | def cifar100_transformer_test(): 25 | return transforms.Compose([ 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=[0.50707543, 0.48655024, 0.44091907], std=[0.26733398, 0.25643876, 0.27615029]), 28 | ]) 29 | 30 | 31 | def imagenet_transformer(): 32 | transform=transforms.Compose([ 33 | transforms.RandomResizedCrop(224), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | ]) 39 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--cuda', action='store_true', help='If training is to be done on a GPU') 7 | parser.add_argument('--dataset', type=str, default='cifar100', help='Name of the dataset used.') 8 | parser.add_argument('--classes', type=int, default=100, help='Number of classes in the target model.') 9 | parser.add_argument('--batch_size', type=int, default=250, help='Batch size used for training and testing') 10 | parser.add_argument('--train_iterations', type=int, default=20000, help='Number of training iterations') #20000 11 | parser.add_argument('--task_epochs', type=int, default=200, help='Number of training iterations') #200 12 | parser.add_argument('--latent_dim', type=int, default=200, help='The dimensionality of the VAE latent dimension') 13 | parser.add_argument('--data_path', type=str, default='../data', help='Path to where the data is') 14 | parser.add_argument('--beta', type=float, default=1, help='Hyperparameter for training. The parameter for VAE') 15 | parser.add_argument('--num_adv_steps', type=int, default=1, help='Number of adversary steps taken for every task model step') 16 | parser.add_argument('--num_vae_steps', type=int, default=2, help='Number of VAE steps taken for every task model step') 17 | parser.add_argument('--adversary_param', type=float, default=1, help='Hyperparameter for training. lambda2 in the paper') 18 | parser.add_argument('--log_name', type=str, default='accuracies.log', help='Final performance of the models will be saved with this name') 19 | parser.add_argument('--initnumber', type=int, default=5000, help='init samples to be labeled') 20 | parser.add_argument('--random_version', type=int, default=8, help='initial version to be used') 21 | parser.add_argument('--randompath', type=str, default='../random/', help='initial version to be used') 22 | parser.add_argument('--resultpath', type=str, default='./results_cifar100/', help='result for THE model') 23 | 24 | 25 | args = parser.parse_args() 26 | 27 | if not os.path.exists(args.data_path): 28 | os.mkdir(args.data_path) 29 | if not os.path.exists(args.resultpath): 30 | os.mkdir(args.resultpath) 31 | 32 | return args 33 | 34 | 35 | -------------------------------------------------------------------------------- /custom_datasets.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | import numpy 5 | 6 | from utils import * 7 | 8 | def imagenet_transformer(): 9 | transform=transforms.Compose([ 10 | transforms.RandomResizedCrop(224), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 14 | std=[0.229, 0.224, 0.225]) 15 | ]) 16 | 17 | 18 | class CIFAR10(Dataset): 19 | def __init__(self, path): 20 | self.cifar10 = datasets.CIFAR10(root=path, 21 | download=True, 22 | train=True, 23 | transform=cifar10_transformer_train()) 24 | 25 | def __getitem__(self, index): 26 | if isinstance(index, numpy.float64): 27 | index = index.astype(numpy.int64) 28 | 29 | data, target = self.cifar10[index] 30 | 31 | return data, target, index 32 | 33 | def __len__(self): 34 | return len(self.cifar10) 35 | 36 | 37 | class CIFAR100(Dataset): 38 | def __init__(self, path): 39 | self.cifar100 = datasets.CIFAR100(root=path, 40 | download=True, 41 | train=True, 42 | transform=cifar100_transformer_train()) 43 | 44 | def __getitem__(self, index): 45 | if isinstance(index, numpy.float64): 46 | index = index.astype(numpy.int64) 47 | 48 | data, target = self.cifar100[index] 49 | 50 | # Your transformations here (or set it in CIFAR10) 51 | 52 | return data, target, index 53 | 54 | def __len__(self): 55 | return len(self.cifar100) 56 | 57 | 58 | class ImageNet(Dataset): 59 | def __init__(self, path): 60 | self.imagenet = datasets.ImageFolder(root=path, transform=imagenet_transformer) 61 | 62 | def __getitem__(self, index): 63 | if isinstance(index, numpy.float64): 64 | index = index.astype(numpy.int64) 65 | data, target = self.imagenet[index] 66 | 67 | return data, target, index 68 | 69 | def __len__(self): 70 | return len(self.imagenet) 71 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | import torch.utils.data.sampler as sampler 4 | import torch.utils.data as data 5 | 6 | import numpy as np 7 | import argparse 8 | import random 9 | import os 10 | 11 | from custom_datasets import * 12 | import model 13 | from solver import Solver 14 | from utils import * 15 | import arguments 16 | import resnet 17 | 18 | 19 | def cifar_transformer(): 20 | return transforms.Compose([ 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | ]) 24 | def save_sample(listt, path): 25 | np.save(path, np.array(listt)) 26 | 27 | 28 | def main(args): 29 | if args.dataset == 'cifar10': 30 | test_dataloader = data.DataLoader( 31 | datasets.CIFAR10(args.data_path, download=True, transform=cifar_transformer(), train=False), 32 | batch_size=args.batch_size, drop_last=False) 33 | 34 | train_dataset = CIFAR10(args.data_path) 35 | args.num_images = 50000 36 | args.budget = 2500 37 | args.num_classes = 10 38 | elif args.dataset == 'cifar100': 39 | test_dataloader = data.DataLoader( 40 | datasets.CIFAR100(args.data_path, download=True, transform=cifar_transformer(), train=False), 41 | batch_size=args.batch_size, drop_last=False) 42 | 43 | train_dataset = CIFAR100(args.data_path) 44 | args.num_images = 50000 45 | args.budget = 2500 46 | args.num_classes = 100 47 | else: 48 | raise NotImplementedError 49 | 50 | all_indices = set(np.arange(args.num_images)) 51 | initial_indices = np.load( './results_cifar100/10.npy' ).tolist() 52 | sampler = data.sampler.SubsetRandomSampler(initial_indices) 53 | querry_dataloader = data.DataLoader(train_dataset, sampler=sampler, 54 | batch_size=args.batch_size, drop_last=True) 55 | 56 | args.cuda = torch.cuda.is_available() 57 | print('cuda:', args.cuda) 58 | solver = Solver(args, test_dataloader) 59 | splits = [0.15, 0.2, 0.25, 0.3, 0.35, 0.4] 60 | current_indices = list(initial_indices) 61 | accuracies = [] 62 | for split in splits: 63 | task_model = resnet.ResNet18(num_classes=args.classes).cuda() 64 | vae = model.VAE(args.latent_dim , num_classes = int(args.dataset[5:])).cuda() 65 | discriminator = model.Discriminator(args.latent_dim).cuda() 66 | unlabeled_indices = np.setdiff1d(list(all_indices), current_indices) 67 | unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices) 68 | unlabeled_dataloader = data.DataLoader(train_dataset, 69 | sampler=unlabeled_sampler, batch_size=args.batch_size, drop_last=False) 70 | 71 | # train the models on the current data 72 | vae, discriminator = solver.train(querry_dataloader, 73 | task_model, 74 | vae, 75 | discriminator, 76 | unlabeled_dataloader 77 | ) 78 | 79 | unlabeled_dataloader = data.DataLoader(train_dataset, 80 | sampler=unlabeled_sampler, batch_size=args.batch_size, drop_last=False) 81 | #sample based on discriminator's predicted state 82 | sampled_indices = solver.sample_for_labeling(vae, discriminator, unlabeled_dataloader) 83 | current_indices = list(current_indices) + list(sampled_indices) 84 | # save the selection into .npy file 85 | save_sample(current_indices, './results_cifar100/'+str(int(split*100))+'.npy') 86 | print(str(int(split*100))+'% samples is selected' ) 87 | sampler = data.sampler.SubsetRandomSampler(current_indices) 88 | querry_dataloader = data.DataLoader(train_dataset, sampler=sampler, 89 | batch_size=args.batch_size, drop_last=True) 90 | 91 | 92 | if __name__ == '__main__': 93 | args = arguments.get_args() 94 | main(args) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1): 10 | super(BasicBlock, self).__init__() 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | 16 | self.shortcut = nn.Sequential() 17 | if stride != 1 or in_planes != self.expansion*planes: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 20 | nn.BatchNorm2d(self.expansion*planes) 21 | ) 22 | 23 | def forward(self, x): 24 | out = F.relu(self.bn1(self.conv1(x))) 25 | out = self.bn2(self.conv2(out)) 26 | out += self.shortcut(x) 27 | out = F.relu(out) 28 | return out 29 | 30 | 31 | class Bottleneck(nn.Module): 32 | expansion = 4 33 | 34 | def __init__(self, in_planes, planes, stride=1): 35 | super(Bottleneck, self).__init__() 36 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 41 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 42 | 43 | self.shortcut = nn.Sequential() 44 | if stride != 1 or in_planes != self.expansion*planes: 45 | self.shortcut = nn.Sequential( 46 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 47 | nn.BatchNorm2d(self.expansion*planes) 48 | ) 49 | 50 | def forward(self, x): 51 | out = F.relu(self.bn1(self.conv1(x))) 52 | out = F.relu(self.bn2(self.conv2(out))) 53 | out = self.bn3(self.conv3(out)) 54 | out += self.shortcut(x) 55 | out = F.relu(out) 56 | return out 57 | 58 | 59 | class ResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=10): 61 | super(ResNet, self).__init__() 62 | self.in_planes = 64 63 | 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(64) 66 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 67 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 68 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 69 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 70 | self.linear = nn.Linear(512*block.expansion, num_classes) 71 | 72 | def _make_layer(self, block, planes, num_blocks, stride): 73 | strides = [stride] + [1]*(num_blocks-1) 74 | layers = [] 75 | for stride in strides: 76 | layers.append(block(self.in_planes, planes, stride)) 77 | self.in_planes = planes * block.expansion 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out1 = self.layer1(out) 83 | out2 = self.layer2(out1) 84 | out3 = self.layer3(out2) 85 | out4 = self.layer4(out3) 86 | out = F.avg_pool2d(out4, 4) 87 | out = out.view(out.size(0), -1) 88 | out = self.linear(out) 89 | return out, [out1, out2, out3, out4] 90 | 91 | 92 | def ResNet18(num_classes = 10): 93 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 94 | 95 | def ResNet34(): 96 | return ResNet(BasicBlock, [3,4,6,3]) 97 | 98 | def ResNet50(): 99 | return ResNet(Bottleneck, [3,4,6,3]) 100 | 101 | def ResNet101(): 102 | return ResNet(Bottleneck, [3,4,23,3]) 103 | 104 | def ResNet152(): 105 | return ResNet(Bottleneck, [3,8,36,3]) -------------------------------------------------------------------------------- /acc100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | # Torch 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | import torch.optim.lr_scheduler as lr_scheduler 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | 12 | # Torchvison 13 | import torchvision.transforms as T 14 | import torchvision.models as models 15 | from torchvision.datasets import CIFAR100, CIFAR10 16 | 17 | # Utils 18 | from tqdm import tqdm 19 | import models.resnet as resnet 20 | from config import * 21 | # Data 22 | train_transform = T.Compose([ 23 | T.RandomHorizontalFlip(), 24 | T.RandomCrop(size=32, padding=4), 25 | T.ToTensor(), 26 | T.Normalize([0.50707543, 0.48655024, 0.44091907], [0.26733398, 0.25643876, 0.27615029]) # T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100 27 | ]) 28 | 29 | test_transform = T.Compose([ 30 | T.ToTensor(), 31 | T.Normalize([0.50707543, 0.48655024, 0.44091907], [0.26733398, 0.25643876, 0.27615029]) # T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100 32 | ]) 33 | 34 | cifar100_train = CIFAR100('./data', train=True, download=True, transform=train_transform) 35 | cifar100_unlabeled = CIFAR100('./data', train=True, download=True, transform=test_transform) 36 | cifar100_test = CIFAR100('./data', train=False, download=True, transform=test_transform) 37 | 38 | def train_epoch(models, criterion, optimizers, dataloaders, epoch): 39 | models.train() 40 | for data in tqdm(dataloaders['train'], leave=False, total=len(dataloaders['train'])): 41 | inputs = data[0].cuda() 42 | labels = data[1].cuda() 43 | 44 | optimizers['backbone'].zero_grad() 45 | 46 | scores, features = models(inputs) 47 | target_loss = criterion(scores, labels) 48 | m_backbone_loss = torch.sum(target_loss) / target_loss.size(0) 49 | loss = m_backbone_loss 50 | 51 | loss.backward() 52 | optimizers['backbone'].step() 53 | return m_backbone_loss 54 | 55 | # 56 | def test(models, dataloaders, mode='val'): 57 | assert mode == 'val' or mode == 'test' 58 | models.eval() 59 | total = 0 60 | correct = 0 61 | with torch.no_grad(): 62 | for (inputs, labels) in dataloaders[mode]: 63 | inputs = inputs.cuda() 64 | labels = labels.cuda() 65 | 66 | scores, _ = models(inputs) 67 | _, preds = torch.max(scores.data, 1) 68 | total += labels.size(0) 69 | correct += (preds == labels).sum().item() 70 | 71 | return 100 * correct / total 72 | 73 | # 74 | def train(models, criterion, optimizers, schedulers, dataloaders, num_epochs): 75 | bestacc = 0. 76 | for epoch in range(num_epochs): 77 | schedulers['backbone'].step() 78 | losss = train_epoch(models, criterion, optimizers, dataloaders, epoch) 79 | if epoch % 30 == 0: 80 | acc = test(models, dataloaders, 'test') 81 | if bestacc < acc: 82 | bestacc = acc 83 | if epoch % 30 == 0: 84 | print('Val Acc: {:.3f}% \t '.format(acc)) 85 | return acc 86 | 87 | 88 | if __name__ == '__main__': 89 | vis = None 90 | plot_data = {'X': [], 'Y': [], 'legend': ['Backbone Loss']} 91 | EPOCH = 360 92 | BATCH = 128 93 | 94 | torch.backends.cudnn.benchmark = True 95 | for trial in range(1,7): 96 | resnet18 = resnet.ResNet18(num_classes=100).cuda() 97 | models = resnet18 98 | name = './results_cifar100/' + str(int(trial*5+10)) +'.npy' 99 | print (name) 100 | indices = np.load( name).tolist() 101 | labeled_set = indices 102 | all_indices = set(np.arange(NUM_TRAIN)) 103 | indices = list(range(NUM_TRAIN)) 104 | unlabeled_set = np.setdiff1d(indices, labeled_set).tolist() 105 | 106 | train_loader = DataLoader(cifar100_train, batch_size=BATCH, 107 | sampler=SubsetRandomSampler(labeled_set), 108 | pin_memory=True) 109 | test_loader = DataLoader(cifar100_test, batch_size=BATCH) 110 | dataloaders = {'train': train_loader, 'test': test_loader} 111 | criterion = nn.CrossEntropyLoss(reduction='none') 112 | optim_backbone = optim.SGD(models.parameters(), lr=LR, 113 | momentum=MOMENTUM, weight_decay=WDECAY) 114 | 115 | sched_backbone = lr_scheduler.MultiStepLR(optim_backbone, milestones=MILESTONES) 116 | 117 | optimizers = {'backbone': optim_backbone} 118 | schedulers = {'backbone': sched_backbone} 119 | 120 | # Training and test 121 | acc = train(models, criterion, optimizers, schedulers, dataloaders, EPOCH) 122 | print('Label set size {}%: Test acc {}'.format(len(labeled_set)/NUM_TRAIN*100, acc)) 123 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | 6 | 7 | class View(nn.Module): 8 | def __init__(self, size): 9 | super(View, self).__init__() 10 | self.size = size 11 | 12 | def forward(self, tensor): 13 | return tensor.view(self.size) 14 | 15 | 16 | class VAE(nn.Module): 17 | def __init__(self, z_dim=32, nc=3, num_classes = 10): 18 | print(" task is ", num_classes, " classification") 19 | super(VAE, self).__init__() 20 | self.z_dim = z_dim 21 | self.nc = nc 22 | self.encoder = nn.Sequential( 23 | nn.Conv2d(nc, 128, 4, 2, 1, bias=False), # B, 128, 32, 32 24 | nn.BatchNorm2d(128), 25 | nn.ReLU(True), 26 | nn.Conv2d(128, 256, 4, 2, 1, bias=False), # B, 256, 16, 16 27 | nn.BatchNorm2d(256), 28 | nn.ReLU(True), 29 | nn.Conv2d(256, 512, 4, 2, 1, bias=False), # B, 512, 8, 8 30 | nn.BatchNorm2d(512), 31 | nn.ReLU(True), 32 | nn.Conv2d(512, 1024, 4, 2, 1, bias=False), # B, 1024, 4, 4 33 | nn.BatchNorm2d(1024), 34 | nn.ReLU(True), 35 | View((-1, 1024*2*2)), # B, 1024*4*4 36 | ) 37 | 38 | self.fc_mu = nn.Linear(1024*2*2, z_dim) # B, z_dim 39 | self.fc_logvar = nn.Linear(1024*2*2, z_dim) # B, z_dim 40 | self.decoder = nn.Sequential( 41 | nn.Linear(z_dim//2, 1024*4*4), # B, 1024*8*8 42 | View((-1, 1024, 4, 4)), # B, 1024, 8, 8 43 | nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False), # B, 512, 16, 16 44 | nn.BatchNorm2d(512), 45 | nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), # B, 256, 32, 32 46 | nn.ReLU(True), 47 | nn.BatchNorm2d(256), 48 | nn.ReLU(True), 49 | nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), # B, 128, 64, 64 50 | nn.BatchNorm2d(128), 51 | nn.ReLU(True), 52 | nn.ConvTranspose2d(128, nc, 1), # B, nc, 64, 64 53 | ) 54 | self.classifier = nn.Sequential( 55 | nn.Linear(z_dim//2, 1024), 56 | nn.ReLU(True), 57 | nn.Linear(1024, 256), 58 | nn.ReLU(True), 59 | nn.Linear(256, num_classes), # B, nc, 64, 64 60 | ) 61 | self.weight_init() 62 | 63 | def weight_init(self): 64 | for block in self._modules: 65 | try: 66 | for m in self._modules[block]: 67 | kaiming_init(m) 68 | except: 69 | kaiming_init(block) 70 | 71 | def forward(self, x, labeled = 1): 72 | z = self._encode(x) 73 | mu, logvar = self.fc_mu(z), self.fc_logvar(z) 74 | 75 | z = self.reparameterize(mu, logvar) 76 | # Split the latent variable into two, one for reconstruction and another for targer learner. 77 | # The mu is the unified representation. 78 | m1, m2 = torch.split(z, [self.z_dim//2,self.z_dim//2], dim=1) 79 | 80 | x_recon = self._decode(m1) 81 | if(labeled == 1): 82 | pred_label = self._classi(m2) 83 | return x_recon, z, mu, logvar, pred_label 84 | else: 85 | return x_recon, z, mu, logvar 86 | 87 | def reparameterize(self, mu, logvar): 88 | stds = (0.5 * logvar).exp() 89 | epsilon = torch.randn(*mu.size()) 90 | if mu.is_cuda: 91 | stds, epsilon = stds.cuda(), epsilon.cuda() 92 | latents = epsilon * stds + mu 93 | return latents 94 | 95 | def _encode(self, x): 96 | return self.encoder(x) 97 | 98 | def _decode(self, z): 99 | return self.decoder(z) 100 | 101 | def _classi(self, z): 102 | return self.classifier(z) 103 | 104 | 105 | 106 | class Discriminator(nn.Module): 107 | def __init__(self, z_dim=10): 108 | super(Discriminator, self).__init__() 109 | self.z_dim = z_dim 110 | self.net = nn.Sequential( 111 | nn.Linear(z_dim, 512), 112 | nn.ReLU(True), 113 | nn.Linear(512, 512), 114 | nn.ReLU(True), 115 | nn.Linear(512, 512), 116 | nn.ReLU(True), 117 | nn.Linear(512, 512), 118 | nn.ReLU(True), 119 | nn.Linear(512, 1), 120 | nn.Sigmoid() 121 | ) 122 | self.weight_init() 123 | 124 | def weight_init(self): 125 | for block in self._modules: 126 | for m in self._modules[block]: 127 | kaiming_init(m) 128 | 129 | def forward(self, z): 130 | return self.net(z) 131 | 132 | 133 | def kaiming_init(m): 134 | if isinstance(m, (nn.Linear, nn.Conv2d)): 135 | init.kaiming_normal_(m.weight) 136 | if m.bias is not None: 137 | m.bias.data.fill_(0) 138 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 139 | m.weight.data.fill_(1) 140 | if m.bias is not None: 141 | m.bias.data.fill_(0) 142 | 143 | 144 | def normal_init(m, mean, std): 145 | if isinstance(m, (nn.Linear, nn.Conv2d)): 146 | m.weight.data.normal_(mean, std) 147 | if m.bias.data is not None: 148 | m.bias.data.zero_() 149 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 150 | m.weight.data.fill_(1) 151 | if m.bias.data is not None: 152 | m.bias.data.zero_() 153 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | Reference: 4 | https://github.com/kuangliu/pytorch-cifar 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = nn.Linear(512*block.expansion, num_classes) 77 | 78 | def _make_layer(self, block, planes, num_blocks, stride): 79 | strides = [stride] + [1]*(num_blocks-1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_planes, planes, stride)) 83 | self.in_planes = planes * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = F.relu(self.bn1(self.conv1(x))) 88 | out1 = self.layer1(out) 89 | out2 = self.layer2(out1) 90 | out3 = self.layer3(out2) 91 | out4 = self.layer4(out3) 92 | out = F.avg_pool2d(out4, 4) 93 | out = out.view(out.size(0), -1) 94 | out = self.linear(out) 95 | return out, [out1, out2, out3, out4] 96 | 97 | 98 | def ResNet18(num_classes = 10): 99 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 100 | 101 | def ResNet34(): 102 | return ResNet(BasicBlock, [3,4,6,3]) 103 | 104 | def ResNet50(): 105 | return ResNet(Bottleneck, [3,4,6,3]) 106 | 107 | def ResNet101(): 108 | return ResNet(Bottleneck, [3,4,23,3]) 109 | 110 | def ResNet152(): 111 | return ResNet(Bottleneck, [3,8,36,3]) 112 | 113 | 114 | 115 | def train_epoch(models, criterion, optimizers, dataloaders, epoch): 116 | models.train() 117 | for data in tqdm(dataloaders, leave=False, total=len(dataloaders)): 118 | inputs = data[0].cuda() 119 | labels = data[1].cuda() 120 | 121 | optimizers.zero_grad() 122 | 123 | scores, features = models(inputs) 124 | target_loss = criterion(scores, labels) 125 | m_backbone_loss = torch.sum(target_loss) / target_loss.size(0) 126 | loss = m_backbone_loss 127 | 128 | loss.backward() 129 | optimizers.step() 130 | return m_backbone_loss 131 | 132 | # 133 | def uncertainty_score(models, dataloaders, classes): 134 | models.eval() 135 | all_scores = [] 136 | all_indices = [] 137 | m = nn.Softmax(dim=1) 138 | with torch.no_grad(): 139 | for inputs, labels, indices in dataloaders: 140 | inputs = inputs.cuda() 141 | labels = labels.cuda() 142 | indices = indices.cuda() 143 | predictions, _ = models(inputs) 144 | predictions = m(predictions) 145 | #calculate the uncertainty score 146 | maximum, _ = torch.max(predictions, 1) 147 | pp1 = 1/classes - (1-maximum)/(classes-1) 148 | pp1 = (classes-1) * torch.pow(pp1, 2) 149 | pp2 = 1/classes - maximum 150 | pp2 = torch.pow(pp2, 2) 151 | pp3 = (1/classes) * (pp1 + pp2) 152 | var = torch.var(predictions, 1) 153 | score = 1 - maximum*pp3/var 154 | 155 | all_scores.extend(score.cpu().numpy()) 156 | all_indices.extend(indices.cpu().numpy()) 157 | 158 | all_indices1 = np.asarray(all_indices) 159 | all_scores1 = np.asarray(all_scores) 160 | dic = {} 161 | for i in range(len(all_scores1)): 162 | dic[all_indices1[i]] = all_scores1[i] 163 | return dic #Store the uncertainty scores into the dic with indices 164 | 165 | # 166 | def train(models, criterion, optimizers, schedulers, dataloaders, num_epochs): 167 | for epoch in range(num_epochs): 168 | schedulers.step() 169 | losss = train_epoch(models, criterion, optimizers, dataloaders, epoch) 170 | return 0 171 | 172 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.optim.lr_scheduler as lr_scheduler 5 | import resnet 6 | 7 | import os 8 | import numpy as np 9 | from sklearn.metrics import accuracy_score 10 | 11 | import sampler 12 | 13 | import sys 14 | 15 | 16 | 17 | class Solver: 18 | def __init__(self, args, test_dataloader): 19 | self.args = args 20 | self.test_dataloader = test_dataloader 21 | self.bce_loss = nn.BCELoss() 22 | self.mse_loss = nn.MSELoss() 23 | self.ce_loss = nn.CrossEntropyLoss() 24 | self.sampler = sampler.AdversarySampler(self.args.budget) 25 | 26 | def read_data(self, dataloader, labels=True): 27 | if labels: 28 | while True: 29 | for img, label, indices in dataloader: 30 | yield img, label 31 | else: 32 | while True: 33 | for img, _, indices in dataloader: 34 | yield img, indices 35 | 36 | def train(self, querry_dataloader, task_model, vae, discriminator, unlabeled_dataloader): 37 | labeled_data = self.read_data(querry_dataloader) 38 | unlabeled_data = self.read_data(unlabeled_dataloader, labels=False) 39 | 40 | optim_vae = optim.Adam(vae.parameters(), lr=5e-4) 41 | optim_discriminator = optim.Adam(discriminator.parameters(), lr=5e-4) 42 | 43 | vae.train() 44 | discriminator.train() 45 | task_model.train() 46 | change_lr_iter = self.args.train_iterations // 30 47 | 48 | # task_model for online uncertainty indicator 49 | print('Train the online uncertainty indicator:') 50 | criterion = nn.CrossEntropyLoss(reduction='none') 51 | optim_backbone = optim.SGD(task_model.parameters(), lr=0.1, 52 | momentum=0.9, weight_decay=5e-4) 53 | sched_backbone = lr_scheduler.MultiStepLR(optim_backbone, milestones= [200]) 54 | optimizers = optim_backbone 55 | schedulers = sched_backbone 56 | resnet.train(task_model, criterion, optimizers, schedulers, querry_dataloader, self.args.task_epochs) 57 | scoredic = resnet.uncertainty_score(task_model, unlabeled_dataloader, self.args.classes) 58 | #Above dict contains all the uncertainty scores for unlabeled samples 59 | #They will relabel all the unlabeled data's state. ps: the original state for unlabeled data is 1. 60 | 61 | print('The online uncertainty indicator is ready.') 62 | print('Train the generator and the discriminator.') 63 | #Begin to train the generator (VAE) and the discriminator. 64 | for iter_count in range(self.args.train_iterations): 65 | labeled_imgs, labels = next(labeled_data) 66 | unlabeled_imgs, indices = next(unlabeled_data) 67 | indices = indices.cpu().numpy() 68 | if self.args.cuda: 69 | labeled_imgs = labeled_imgs.cuda() 70 | unlabeled_imgs = unlabeled_imgs.cuda() 71 | labels = labels.cuda() 72 | if iter_count is not 0 and iter_count % change_lr_iter == 0: 73 | for param in optim_vae.param_groups: 74 | param['lr'] = param['lr'] * 0.85 75 | for param in optim_discriminator.param_groups: 76 | param['lr'] = param['lr'] * 0.85 77 | # VAE step 78 | for count in range(self.args.num_vae_steps): 79 | recon, z, mu, logvar, pred_label = vae(labeled_imgs, labeled = 1) 80 | labeled_task_loss = self.ce_loss(pred_label, labels) 81 | unsup_loss = self.vae_loss(labeled_imgs, recon, mu, logvar, self.args.beta) 82 | unlab_recon, unlab_z, unlab_mu, unlab_logvar = vae(unlabeled_imgs, labeled = 0) 83 | transductive_loss = self.vae_loss(unlabeled_imgs, 84 | unlab_recon, unlab_mu, unlab_logvar, self.args.beta) 85 | 86 | labeled_preds = discriminator(mu) 87 | unlabeled_preds = discriminator(unlab_mu) 88 | 89 | lab_real_preds = torch.ones(labeled_imgs.size(0)).cuda() 90 | unlab_real_preds = torch.ones(unlabeled_imgs.size(0)).cuda() 91 | 92 | dsc_loss = self.bce_loss(labeled_preds, lab_real_preds) + self.bce_loss(unlabeled_preds, unlab_real_preds) 93 | total_vae_loss = unsup_loss + transductive_loss + self.args.adversary_param * dsc_loss +5* labeled_task_loss 94 | optim_vae.zero_grad() 95 | total_vae_loss.backward() 96 | optim_vae.step() 97 | 98 | # sample new batch if needed to train the adversarial network 99 | if count < (self.args.num_vae_steps - 1): 100 | labeled_imgs, _ = next(labeled_data) 101 | unlabeled_imgs, indices = next(unlabeled_data) 102 | indices = indices.cpu().numpy() 103 | 104 | if self.args.cuda: 105 | labeled_imgs = labeled_imgs.cuda() 106 | unlabeled_imgs = unlabeled_imgs.cuda() 107 | labels = labels.cuda() 108 | 109 | # Discriminator step 110 | for count in range(self.args.num_adv_steps): 111 | with torch.no_grad(): 112 | _, _, mu, _ = vae(labeled_imgs, labeled = 0) 113 | _, _, unlab_mu, _ = vae(unlabeled_imgs, labeled = 0) 114 | 115 | labeled_preds = discriminator(mu) 116 | unlabeled_preds = discriminator(unlab_mu) 117 | 118 | #Relabeling the state of unlabeled samples, unlab_fake_preds is relabeled 119 | score = [] 120 | for i in range(len(indices)): 121 | score.append(scoredic[indices[i]]) 122 | # the score is the new state for unlabeled samples 123 | lab_real_preds = torch.zeros(labeled_imgs.size(0)).cuda() 124 | # replace the original state binary 1 with the relabeled state 125 | unlab_fake_preds = torch.tensor(score, dtype=torch.float).cuda() 126 | 127 | dsc_loss = self.bce_loss(labeled_preds, lab_real_preds) + self.bce_loss(unlabeled_preds, unlab_fake_preds) 128 | 129 | optim_discriminator.zero_grad() 130 | dsc_loss.backward() 131 | optim_discriminator.step() 132 | 133 | # sample new batch if needed to train the adversarial network 134 | if count < (self.args.num_adv_steps - 1): 135 | labeled_imgs, _ = next(labeled_data) 136 | unlabeled_imgs, indices = next(unlabeled_data) 137 | indices = indices.cpu().numpy() 138 | 139 | if self.args.cuda: 140 | labeled_imgs = labeled_imgs.cuda() 141 | unlabeled_imgs = unlabeled_imgs.cuda() 142 | labels = labels.cuda() 143 | 144 | if iter_count % 5000 == 0: 145 | print(' Current training iteration: {}'.format(iter_count) ) 146 | print(' Current vae model loss: {:.4f} {:.4f} '.format(total_vae_loss.item() , labeled_task_loss.item() )) 147 | print(' Current discriminator model loss: {:.4f}'.format(dsc_loss.item()) ) 148 | return vae, discriminator 149 | 150 | 151 | def sample_for_labeling(self, vae, discriminator, unlabeled_dataloader): 152 | querry_indices = self.sampler.sample(vae, 153 | discriminator, 154 | unlabeled_dataloader, 155 | self.args.cuda) 156 | 157 | return querry_indices 158 | 159 | 160 | def test(self, task_model): 161 | task_model.eval() 162 | total, correct = 0, 0 163 | for imgs, labels in self.test_dataloader: 164 | if self.args.cuda: 165 | imgs = imgs.cuda() 166 | 167 | with torch.no_grad(): 168 | preds = task_model(imgs) 169 | 170 | preds = torch.argmax(preds, dim=1).cpu().numpy() 171 | correct += accuracy_score(labels, preds, normalize=False) 172 | total += imgs.size(0) 173 | return correct / total * 100 174 | 175 | 176 | def vae_loss(self, x, recon, mu, logvar, beta): 177 | MSE = 5*self.mse_loss(recon, x) 178 | KLD = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mu**2 - 1. - logvar, 1)) 179 | KLD = KLD * beta 180 | return (MSE + KLD) 181 | --------------------------------------------------------------------------------