├── README.md
├── utils
│   ├── sampler.py
│   ├── Dataset.py
│   ├── randaugment.py
│   ├── utils_loss.py
│   ├── utils_algo.py
│   └── utils_data.py
├── models
│   ├── wideresnet.py
│   ├── resnet.py
│   ├── meta_layers.py
│   └── models.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # LaGAM: Positive-Unlabeled Learning by Latent Group-Aware Meta Disambiguation
2 | 
3 | This is the implementation of our CVPR 2024 paper: Positive-Unlabeled Learning by Latent Group-Aware Meta Disambiguation.
4 | 
5 | ## Training
6 | ### CIFAR-10-1
7 | 
8 | ```bash
9 | python train.py --exp-dir experiment --arch resnet18 --dataset cifar10 --positive_list 0,1,8,9 --warmup_epoch 20 --n_positive 1000 --n_valid 500 --num_cluster 5 --cont_cutoff --identifier classifier --knn_aug --num_neighbors 10 --epochs 400
10 | ```
11 | 
12 | ### CIFAR-10-2
13 | 
14 | ```bash
15 | python train.py --exp-dir experiment --arch resnet18 --dataset cifar10 --positive_list 0,1,8,9 --warmup_epoch 20 --n_positive 1000 --n_valid 500 --num_cluster 5 --cont_cutoff --identifier classifier --knn_aug --num_neighbors 10 --epochs 400 --reverse 1
16 | ```
17 | 
18 | ### CIFAR-100-1
19 | 
20 | ```bash
21 | python train.py --exp-dir experiment --arch resnet18 --dataset cifar100 --positive_list 18,19 --warmup_epoch 20 --n_positive 1000 --n_valid 500 --num_cluster 100 --cont_cutoff --identifier classifier --knn_aug --num_neighbors 10 --epochs 400
22 | ```
23 | 
24 | ### CIFAR-100-2
25 | 
26 | ```bash
27 | python train.py --exp-dir experiment --arch resnet18 --dataset cifar100 --positive_list 0,1,7,8,11,12,13,14,15,16 --warmup_epoch 20 --n_positive 1000 --n_valid 500 --num_cluster 100 --cont_cutoff --identifier classifier --knn_aug --num_neighbors 10 --epochs 400
28 | ```
29 | 
30 | ### STL-10-1
31 | 
32 | ```bash
33 | python train.py --exp-dir experiment --arch resnet18 --dataset stl10 --positive_list 0,2,3,8,9 --warmup_epoch 20 --n_positive 1000 --n_valid 500 --num_cluster 100 --cont_cutoff --identifier classifier --knn_aug --num_neighbors 10 --epochs 400
34 | ```
35 | 
36 | ### STL-10-2
37 | 
38 | ```bash
39 | python train.py --exp-dir experiment --arch resnet18 --dataset stl10 --positive_list 0,2,3,8,9 --warmup_epoch 20 --n_positive 1000 --n_valid 500 --num_cluster 100 --cont_cutoff --identifier classifier --knn_aug --num_neighbors 10 --epochs 400 --reverse 1
40 | ```
41 | 
42 | 
--------------------------------------------------------------------------------
/utils/sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from torch.utils.data.sampler import Sampler
4 | import pdb
5 | 
6 | 
7 | class RandomCycleIter:
8 |     def __init__(self, data, test_mode=False):
9 |         self.data_list = list(data)
10 |         self.length = len(self.data_list)
11 |         self.i = self.length - 1
12 |         self.test_mode = test_mode
13 | 
14 |     def __iter__(self):
15 |         return self
16 | 
17 |     def __next__(self):
18 |         self.i += 1
19 | 
20 |         if self.i == self.length:
21 |             self.i = 0
22 |             if not self.test_mode:
23 |                 random.shuffle(self.data_list)
24 |         return self.data_list[self.i]
25 | 
26 | 
27 | def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
28 |     i = 0
29 |     j = 0
30 |     while i < n:
31 |         if j >= num_samples_cls:
32 |             j = 0
33 | 
34 |         if j == 0:
35 |             temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]] * num_samples_cls))
36 |             yield temp_tuple[j]
37 |         else:
38 |             yield temp_tuple[j]
39 | 
40 |         i += 1
41 |         j += 1
42 | 
43 | 
44 | class 
ClassAwareSampler(Sampler): 45 | def __init__( 46 | self, 47 | labels, 48 | num_classes, 49 | num_samples_cls=1, 50 | ): 51 | cls_data_list = [list() for _ in range(num_classes)] 52 | for i, label in enumerate(labels): 53 | cls_data_list[label].append(i) 54 | cls_available = [] 55 | for i in range(num_classes): 56 | if len(cls_data_list[i]) > num_samples_cls: 57 | cls_available.append(i) 58 | self.class_iter = RandomCycleIter(cls_available) 59 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 60 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 61 | self.num_samples_cls = num_samples_cls 62 | 63 | def __iter__(self): 64 | return class_aware_sample_generator( 65 | self.class_iter, self.data_iter_list, self.num_samples, self.num_samples_cls 66 | ) 67 | 68 | def __len__(self): 69 | return self.num_samples 70 | -------------------------------------------------------------------------------- /utils/Dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as dsets 5 | from .randaugment import RandomAugment 6 | import copy 7 | 8 | 9 | class DatasetClass(Dataset): 10 | def __init__( 11 | self, images, pu_labels, true_labels, data_stats, transform=None, train=True 12 | ): 13 | self.images = images 14 | self.pu_labels = pu_labels 15 | self.true_labels = true_labels 16 | self.transform = transform 17 | self.train = train 18 | 19 | if self.transform is None: 20 | self.weak_transform = transforms.Compose( 21 | [ 22 | transforms.ToPILImage(), 23 | transforms.RandomHorizontalFlip(), 24 | transforms.RandomCrop(data_stats["size"], padding=4), 25 | transforms.ToTensor(), 26 | transforms.Normalize(data_stats["mean"], data_stats["std"]), 27 | ] 28 | ) 29 | self.strong_transform = copy.deepcopy(self.weak_transform) 30 | self.strong_transform.transforms.insert(1, RandomAugment(3, 5)) 31 | 32 | def update_targets(self, new_labels, idxes): 33 | self.pu_labels[idxes] = new_labels 34 | 35 | def __len__(self): 36 | return len(self.true_labels) 37 | 38 | def __getitem__(self, index): 39 | if not self.train: 40 | image = self.transform(self.images[index]) 41 | true_label = self.true_labels[index] 42 | return image, true_label 43 | else: 44 | if self.transform is None: 45 | image_w = self.weak_transform(self.images[index]) 46 | image_s = self.strong_transform(self.images[index]) 47 | label = self.pu_labels[index] 48 | true_label = self.true_labels[index] 49 | 50 | return image_w, image_s, label, true_label, index 51 | else: 52 | label = self.pu_labels[index] 53 | image = self.transform(self.images[index]) 54 | true_label = self.true_labels[index] 55 | return image, image, label, true_label, index 56 | -------------------------------------------------------------------------------- /utils/randaugment.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | 9 | def AutoContrast(img, _): 10 | return PIL.ImageOps.autocontrast(img) 11 | 12 | 13 | def Brightness(img, v): 14 | assert v >= 0.0 15 | return PIL.ImageEnhance.Brightness(img).enhance(v) 16 | 17 | 18 | def Color(img, v): 19 | assert v >= 0.0 20 | return PIL.ImageEnhance.Color(img).enhance(v) 21 | 22 | 23 | def Contrast(img, v): 24 | assert v >= 0.0 25 | return 
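A minimal usage sketch for `ClassAwareSampler` above (the tensors, label counts, and batch size are hypothetical, not part of the repo): the sampler cycles through the eligible classes so batches are drawn class-uniformly, which is useful when the positive set is small.

```python
# Hypothetical usage of ClassAwareSampler; data and labels are made up.
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils.sampler import ClassAwareSampler

labels = [0] * 90 + [1] * 10                 # imbalanced binary labels
dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.tensor(labels))

sampler = ClassAwareSampler(labels, num_classes=2, num_samples_cls=1)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)
# Each epoch yields num_samples = (largest class size) * (number of classes)
# = 90 * 2 = 180 indices, alternating between the two classes.
```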
PIL.ImageEnhance.Contrast(img).enhance(v)
26 | 
27 | 
28 | def Equalize(img, _):
29 |     return PIL.ImageOps.equalize(img)
30 | 
31 | 
32 | def Invert(img, _):
33 |     return PIL.ImageOps.invert(img)
34 | 
35 | 
36 | def Identity(img, v):
37 |     return img
38 | 
39 | 
40 | def Posterize(img, v):
41 |     v = int(v)
42 |     v = max(1, v)
43 |     return PIL.ImageOps.posterize(img, v)
44 | 
45 | 
46 | def Rotate(img, v):
47 |     return img.rotate(v)
48 | 
49 | 
50 | def Sharpness(img, v):
51 |     assert v >= 0.0
52 |     return PIL.ImageEnhance.Sharpness(img).enhance(v)
53 | 
54 | 
55 | def ShearX(img, v):
56 |     return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
57 | 
58 | 
59 | def ShearY(img, v):
60 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
61 | 
62 | 
63 | def TranslateX(img, v):
64 |     v = v * img.size[0]
65 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
66 | 
67 | 
68 | def TranslateXabs(img, v):
69 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
70 | 
71 | 
72 | def TranslateY(img, v):
73 |     v = v * img.size[1]
74 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
75 | 
76 | 
77 | def TranslateYabs(img, v):
78 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
79 | 
80 | 
81 | def Solarize(img, v):
82 |     assert 0 <= v <= 256
83 |     return PIL.ImageOps.solarize(img, v)
84 | 
85 | 
86 | def Cutout(img, v):
87 |     assert 0.0 <= v <= 0.5
88 |     if v <= 0.0:
89 |         return img
90 | 
91 |     v = v * img.size[0]
92 |     return CutoutAbs(img, v)
93 | 
94 | 
95 | def CutoutAbs(img, v):
96 |     if v < 0:
97 |         return img
98 |     w, h = img.size
99 |     x0 = np.random.uniform(w)
100 |     y0 = np.random.uniform(h)
101 | 
102 |     x0 = int(max(0, x0 - v / 2.0))
103 |     y0 = int(max(0, y0 - v / 2.0))
104 |     x1 = min(w, x0 + v)
105 |     y1 = min(h, y0 + v)
106 | 
107 |     xy = (x0, y0, x1, y1)
108 |     color = (125, 123, 114)
109 |     img = img.copy()
110 |     PIL.ImageDraw.Draw(img).rectangle(xy, color)
111 |     return img
112 | 
113 | 
114 | def augment_list():
115 |     l = [
116 |         (AutoContrast, 0, 1),
117 |         (Brightness, 0.05, 0.95),
118 |         (Color, 0.05, 0.95),
119 |         (Contrast, 0.05, 0.95),
120 |         (Equalize, 0, 1),
121 |         (Identity, 0, 1),
122 |         (Posterize, 4, 8),
123 |         (Rotate, -30, 30),
124 |         (Sharpness, 0.05, 0.95),
125 |         (ShearX, -0.3, 0.3),
126 |         (ShearY, -0.3, 0.3),
127 |         (Solarize, 0, 256),
128 |         (TranslateX, -0.3, 0.3),
129 |         (TranslateY, -0.3, 0.3),
130 |     ]
131 |     return l
132 | 
133 | 
134 | class RandomAugment:
135 |     def __init__(self, n, m):
136 |         self.n = n
137 |         self.m = m
138 |         self.augment_list = augment_list()
139 | 
140 |     def __call__(self, img):
141 |         ops = random.choices(self.augment_list, k=self.n)
142 |         for op, min_val, max_val in ops:
143 |             val = min_val + float(max_val - min_val) * random.random()
144 |             img = op(img, val)
145 |         cutout_val = random.random() * 0.5
146 |         img = Cutout(img, cutout_val)
147 |         return img
148 | 
149 | 
150 | if __name__ == "__main__":
151 |     randaug = RandomAugment(3, 5)
152 |     print(randaug)
153 |     for item in randaug.augment_list:
154 |         print(item)
155 | 
--------------------------------------------------------------------------------
/models/wideresnet.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | from torch.autograd import Variable
7 | from models.meta_layers import *
8 | import math
9 | 
10 | 
11 | class MetaConv2d(MetaModule):
12 |     def __init__(self, *args, **kwargs):
13 |         super().__init__()
14 |         ignore = 
nn.Conv2d(*args, **kwargs) 15 | 16 | self.in_channels = ignore.in_channels 17 | self.out_channels = ignore.out_channels 18 | self.stride = ignore.stride 19 | self.padding = ignore.padding 20 | self.dilation = ignore.dilation 21 | self.groups = ignore.groups 22 | self.kernel_size = ignore.kernel_size 23 | 24 | self.register_buffer("weight", to_var(ignore.weight.data, requires_grad=True)) 25 | 26 | if ignore.bias is not None: 27 | self.register_buffer("bias", to_var(ignore.bias.data, requires_grad=True)) 28 | else: 29 | self.register_buffer("bias", None) 30 | 31 | def forward(self, x): 32 | return F.conv2d( 33 | x, 34 | self.weight, 35 | self.bias, 36 | self.stride, 37 | self.padding, 38 | self.dilation, 39 | self.groups, 40 | ) 41 | 42 | def named_leaves(self): 43 | return [("weight", self.weight), ("bias", self.bias)] 44 | 45 | 46 | class MetaBasicBlock(MetaModule): 47 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 48 | super(MetaBasicBlock, self).__init__() 49 | 50 | self.bn1 = MetaBatchNorm2d(in_planes) 51 | self.relu1 = nn.ReLU(inplace=True) 52 | self.conv1 = MetaConv2d( 53 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 54 | ) 55 | self.bn2 = MetaBatchNorm2d(out_planes) 56 | self.relu2 = nn.ReLU(inplace=True) 57 | self.conv2 = MetaConv2d( 58 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False 59 | ) 60 | self.droprate = dropRate 61 | self.equalInOut = in_planes == out_planes 62 | self.convShortcut = ( 63 | (not self.equalInOut) 64 | and MetaConv2d( 65 | in_planes, 66 | out_planes, 67 | kernel_size=1, 68 | stride=stride, 69 | padding=0, 70 | bias=False, 71 | ) 72 | or None 73 | ) 74 | 75 | def forward(self, x): 76 | if not self.equalInOut: 77 | x = self.relu1(self.bn1(x)) 78 | else: 79 | out = self.relu1(self.bn1(x)) 80 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 81 | if self.droprate > 0: 82 | out = F.dropout(out, p=self.droprate, training=self.training) 83 | out = self.conv2(out) 84 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 85 | 86 | 87 | class MetaNetworkBlock(MetaModule): 88 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 89 | super(MetaNetworkBlock, self).__init__() 90 | self.layer = self._make_layer( 91 | block, in_planes, out_planes, nb_layers, stride, dropRate 92 | ) 93 | 94 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 95 | layers = [] 96 | for i in range(int(nb_layers)): 97 | layers.append( 98 | block( 99 | i == 0 and in_planes or out_planes, 100 | out_planes, 101 | i == 0 and stride or 1, 102 | dropRate, 103 | ) 104 | ) 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | return self.layer(x) 109 | 110 | 111 | class WideResNet(MetaModule): 112 | def __init__(self, num_classes=2, depth=28, widen_factor=10, dropRate=0.0): 113 | super(WideResNet, self).__init__() 114 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 115 | assert (depth - 4) % 6 == 0 116 | n = (depth - 4) / 6 117 | block = MetaBasicBlock 118 | self.conv1 = MetaConv2d( 119 | 3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False 120 | ) 121 | self.block1 = MetaNetworkBlock( 122 | n, nChannels[0], nChannels[1], block, 1, dropRate 123 | ) 124 | self.block2 = MetaNetworkBlock( 125 | n, nChannels[1], nChannels[2], block, 2, dropRate 126 | ) 127 | self.block3 = MetaNetworkBlock( 128 | n, nChannels[2], nChannels[3], block, 2, dropRate 129 | ) 130 | self.bn1 = 
MetaBatchNorm2d(nChannels[3])
131 |         self.relu = nn.ReLU(inplace=True)
132 |         self.fc = MetaLinear(nChannels[3], num_classes)
133 |         self.nChannels = nChannels[3]
134 | 
135 |         for m in self.modules():
136 |             if isinstance(m, MetaConv2d):
137 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
138 |                 m.weight.data.normal_(0, math.sqrt(2.0 / n))
139 |             elif isinstance(m, MetaBatchNorm2d):
140 |                 m.weight.data.fill_(1)
141 |                 m.bias.data.zero_()
142 |             elif isinstance(m, MetaLinear):
143 |                 m.bias.data.zero_()
144 | 
145 |     def forward(self, x):
146 |         out = self.conv1(x)
147 |         out = self.block1(out)
148 |         out = self.block2(out)
149 |         out = self.block3(out)
150 |         out = self.relu(self.bn1(out))
151 |         out = F.avg_pool2d(out, 8)
152 |         out = out.view(-1, self.nChannels)
153 |         return self.fc(out)
154 | 
--------------------------------------------------------------------------------
/utils/utils_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | 
5 | 
6 | class BCELoss(nn.Module):
7 |     def __init__(self, ent_loss=False):
8 |         super().__init__()
9 |         self.ent_loss = ent_loss
10 | 
11 |     def forward(self, preds, label, weight=None):
12 |         preds = torch.sigmoid(preds)
13 |         logits_ = torch.cat([1.0 - preds, preds], dim=1)
14 |         logits_ = torch.clamp(logits_, 1e-4, 1.0 - 1e-4)
15 | 
16 |         loss_entries = (-label * logits_.log()).sum(dim=0)
17 |         label_num_reverse = 1.0 / label.sum(dim=0)
18 |         loss = (loss_entries * label_num_reverse).sum()
19 | 
20 |         if self.ent_loss:
21 |             loss_ent = -(logits_ * logits_.log()).sum(1).mean()
22 |             loss = loss + loss_ent * 0.1
23 |         return loss
24 | 
25 | 
26 | class NegEntropy(object):
27 |     def __call__(self, outputs):
28 |         probs = torch.softmax(outputs, dim=1)
29 |         return torch.mean(torch.sum(probs.log() * probs, dim=1))
30 | 
31 | 
32 | def consistency_loss(
33 |     logits_w,
34 |     logits_s,
35 |     sin_label_idx,
36 |     name="ce",
37 |     T=1.0,
38 |     p_cutoff=0.0,
39 |     use_hard_labels=True,
40 | ):
41 |     assert name in ["ce", "L2"]
42 |     logits_w = logits_w.detach()
43 |     if name == "L2":
44 |         assert logits_w.size() == logits_s.size()
45 |         pred_w = torch.softmax(logits_w, dim=1).detach()
46 |         pred_s = torch.softmax(logits_s, dim=1)
47 |         return F.mse_loss(pred_s, pred_w, reduction="mean")
48 | 
49 |     elif name == "L2_mask":
50 |         pass
51 | 
52 |     elif name == "ce":
53 |         pseudo_label = torch.softmax(logits_w, dim=-1)
54 |         max_probs = pseudo_label[range(pseudo_label.shape[0]), sin_label_idx]
55 |         mask = max_probs.ge(p_cutoff).float()
56 | 
57 |         if use_hard_labels:
58 |             masked_loss = (
59 |                 ce_loss(logits_s, sin_label_idx, use_hard_labels, reduction="none")
60 |                 * mask
61 |             )
62 |         else:
63 |             pseudo_label = torch.softmax(logits_w / T, dim=-1)
64 |             masked_loss = ce_loss(logits_s, pseudo_label, use_hard_labels) * mask
65 |         return masked_loss.mean(), mask.mean()
66 | 
67 |     else:
68 |         raise NotImplementedError("consistency_loss: unknown variant " + name)
69 | 
70 | 
71 | class ContLoss(nn.Module):
72 |     def __init__(
73 |         self,
74 |         temperature=0.07,
75 |         cont_cutoff=False,
76 |         knn_aug=False,
77 |         num_neighbors=0,
78 |         contrastive_clustering=1,
79 |     ):
80 |         super().__init__()
81 |         self.temperature = temperature
82 |         self.contrastive_clustering = contrastive_clustering
83 |         self.cont_cutoff = cont_cutoff
84 |         self.knn_aug = knn_aug
85 |         self.num_neighbors = num_neighbors
86 | 
87 |     def forward(self, q, k, cluster_idxes=None, preds=None, start_knn_aug=False):
88 |         batch_size = q.shape[0]
89 | 
90 |         q_and_k = 
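A smoke-test sketch for `WideResNet` above (sizes are hypothetical). `forward` hard-codes an 8×8 average pool, so inputs are assumed to be 32×32 (32 → 16 → 8 across the two stride-2 stages):

```python
# Hypothetical shape check for WideResNet; depth/widen_factor chosen small.
import torch
from models.wideresnet import WideResNet

model = WideResNet(num_classes=2, depth=28, widen_factor=2)
x = torch.randn(4, 3, 32, 32)
if torch.cuda.is_available():      # meta layers put their weights on GPU when one exists
    model, x = model.cuda(), x.cuda()
print(model(x).shape)              # torch.Size([4, 2])
```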
torch.cat([q, k], dim=0) 91 | l_i = torch.einsum("nc,kc->nk", [q, q_and_k]) / self.temperature 92 | 93 | self_mask = torch.ones_like(l_i, dtype=torch.float) 94 | self_mask = ( 95 | torch.scatter(self_mask, 1, torch.arange(batch_size).view(-1, 1).cuda(), 0) 96 | .detach() 97 | .cuda() 98 | ) 99 | 100 | positive_mask_i = torch.zeros_like(l_i, dtype=torch.float) 101 | positive_mask_i = ( 102 | torch.scatter( 103 | positive_mask_i, 104 | 1, 105 | batch_size + torch.arange(batch_size).view(-1, 1).cuda(), 106 | 1, 107 | ) 108 | .detach() 109 | .cuda() 110 | ) 111 | 112 | l_i_exp = torch.exp(l_i) 113 | l_i_exp_sum = torch.sum((l_i_exp * self_mask), dim=1, keepdim=True) 114 | 115 | loss = -torch.sum( 116 | torch.log(l_i_exp / l_i_exp_sum) * positive_mask_i, dim=1 117 | ).mean() 118 | 119 | if cluster_idxes is not None and self.contrastive_clustering: 120 | cluster_idxes = cluster_idxes.view(-1, 1) 121 | cluster_idxes_kq = torch.cat([cluster_idxes, cluster_idxes], dim=0) 122 | mask = torch.eq(cluster_idxes, cluster_idxes_kq.T).float().cuda() 123 | 124 | if self.cont_cutoff: 125 | preds = preds.detach() 126 | pred_labels = (preds > 0.5) * 1 127 | pred_labels = pred_labels.view(-1, 1) 128 | pred_labels_kq = torch.cat([pred_labels, pred_labels], dim=0) 129 | label_mask = torch.eq(pred_labels, pred_labels_kq.T).float().cuda() 130 | 131 | mask = mask * label_mask 132 | 133 | if self.knn_aug and start_knn_aug: 134 | cosine_corr = q @ q_and_k.T 135 | _, kNN_index = torch.topk( 136 | cosine_corr, k=self.num_neighbors, dim=-1, largest=True 137 | ) 138 | mask_kNN = torch.scatter( 139 | torch.zeros(mask.shape).cuda(), 1, kNN_index, 1 140 | ) 141 | mask = ((mask + mask_kNN) > 0.5) * 1 142 | 143 | mask = mask.float().detach().cuda() 144 | batch_size = q.shape[0] 145 | anchor_dot_contrast = torch.div( 146 | torch.matmul(q, q_and_k.T), self.temperature 147 | ) 148 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 149 | logits = anchor_dot_contrast - logits_max.detach() 150 | 151 | logits_mask = torch.scatter( 152 | torch.ones_like(mask), 1, torch.arange(batch_size).view(-1, 1).cuda(), 0 153 | ) 154 | mask = mask * logits_mask 155 | 156 | exp_logits = torch.exp(logits) * logits_mask 157 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12) 158 | 159 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 160 | 161 | loss_prot = -mean_log_prob_pos.mean() 162 | loss += loss_prot 163 | 164 | return loss 165 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | from models.meta_layers import * 7 | 8 | from torch.utils.checkpoint import checkpoint 9 | 10 | 11 | class PreActBlockMeta(MetaModule): 12 | """Pre-activation version of the BasicBlock.""" 13 | 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlockMeta, self).__init__() 18 | self.bn1 = MetaBatchNorm2d(in_planes) 19 | self.conv1 = MetaConv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 21 | ) 22 | self.bn2 = MetaBatchNorm2d(planes) 23 | self.conv2 = MetaConv2d( 24 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 25 | ) 26 | 27 | if stride != 1 or in_planes != self.expansion * planes: 28 | self.shortcut = nn.Sequential( 29 | MetaConv2d( 30 | in_planes, 31 | self.expansion * 
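Note that `consistency_loss` above calls `ce_loss`, which is not defined anywhere in `utils_loss.py`. A minimal compatible helper, modeled on FixMatch-style implementations, might look like the sketch below; this is an assumption about the intended behavior, not code from the repo:

```python
# Assumed ce_loss helper (missing from utils_loss.py): hard labels are class
# indices, soft labels are full distributions over classes.
import torch.nn.functional as F

def ce_loss(logits, targets, use_hard_labels=True, reduction="none"):
    if use_hard_labels:
        return F.cross_entropy(logits, targets, reduction=reduction)
    log_pred = F.log_softmax(logits, dim=-1)
    return -(targets * log_pred).sum(dim=-1)
```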
planes, 32 | kernel_size=1, 33 | stride=stride, 34 | bias=False, 35 | ) 36 | ) 37 | 38 | def forward(self, x): 39 | out = F.relu(self.bn1(x)) 40 | shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x 41 | out = self.conv1(out) 42 | out = self.conv2(F.relu(self.bn2(out))) 43 | out += shortcut 44 | return out 45 | 46 | 47 | class PreActBottleneckMeta(MetaModule): 48 | """Pre-activation version of the original Bottleneck module.""" 49 | 50 | expansion = 4 51 | 52 | def __init__(self, in_planes, planes, stride=1): 53 | super(PreActBottleneckMeta, self).__init__() 54 | self.bn1 = MetaBatchNorm2d(in_planes) 55 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=1, bias=False) 56 | self.bn2 = MetaBatchNorm2d(planes) 57 | self.conv2 = MetaConv2d( 58 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 59 | ) 60 | self.bn3 = MetaBatchNorm2d(planes) 61 | self.conv3 = MetaConv2d( 62 | planes, self.expansion * planes, kernel_size=1, bias=False 63 | ) 64 | 65 | if stride != 1 or in_planes != self.expansion * planes: 66 | self.shortcut = nn.Sequential( 67 | MetaConv2d( 68 | in_planes, 69 | self.expansion * planes, 70 | kernel_size=1, 71 | stride=stride, 72 | bias=False, 73 | ) 74 | ) 75 | 76 | def forward(self, x): 77 | out = F.relu(self.bn1(x)) 78 | shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x 79 | out = self.conv1(out) 80 | out = self.conv2(F.relu(self.bn2(out))) 81 | out = self.conv3(F.relu(self.bn3(out))) 82 | out += shortcut 83 | return out 84 | 85 | 86 | class PreActResNetMeta(MetaModule): 87 | def __init__(self, block, num_blocks, num_classes=10, use_checkpoint=False): 88 | super(PreActResNetMeta, self).__init__() 89 | self.in_planes = 64 90 | 91 | self.conv1 = MetaConv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 93 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 94 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 95 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 96 | 97 | input_dim = 512 * block.expansion 98 | self.classifier = MetaLinear(input_dim, num_classes) 99 | 100 | self.fc4 = MetaLinear(input_dim, input_dim) 101 | self.fc5 = MetaLinear(input_dim, 128) 102 | self.head = nn.Sequential(self.fc4, nn.ReLU(), self.fc5) 103 | 104 | self.use_checkpoint = use_checkpoint 105 | 106 | def _make_layer(self, block, planes, num_blocks, stride): 107 | strides = [stride] + [1] * (num_blocks - 1) 108 | layers = [] 109 | for stride in strides: 110 | layers.append(block(self.in_planes, planes, stride)) 111 | self.in_planes = planes * block.expansion 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x, flag_feature=False): 115 | out = x 116 | out = out + torch.zeros( 117 | 1, dtype=out.dtype, device=out.device, requires_grad=True 118 | ) 119 | if self.use_checkpoint: 120 | out = checkpoint(self.conv1, out) 121 | out = checkpoint(self.layer1, out) 122 | out = checkpoint(self.layer2, out) 123 | out = checkpoint(self.layer3, out) 124 | out = checkpoint(self.layer4, out) 125 | else: 126 | out = self.conv1(out) 127 | out = self.layer1(out) 128 | out = self.layer2(out) 129 | out = self.layer3(out) 130 | out = self.layer4(out) 131 | out = F.adaptive_max_pool2d(out, 1) 132 | out = out.view(out.size(0), -1) 133 | y = self.classifier(out) 134 | feat_cl = F.normalize(self.head(out), dim=1) 135 | if flag_feature: 136 | return y, feat_cl 137 | else: 138 | return y 139 | 140 | 141 | def 
preact_resnet_meta18(): 142 | return PreActResNetMeta(PreActBlockMeta, [2, 2, 2, 2]) 143 | 144 | 145 | def preact_resnet_meta2332(): 146 | return PreActResNetMeta(PreActBlockMeta, [2, 3, 3, 2]) 147 | 148 | 149 | def preact_resnet_meta3333(): 150 | return PreActResNetMeta(PreActBlockMeta, [3, 3, 3, 3]) 151 | 152 | 153 | def preact_resnet_meta34(): 154 | return PreActResNetMeta(PreActBlockMeta, [3, 4, 6, 3]) 155 | 156 | 157 | def preact_resnet_meta50(): 158 | return PreActResNetMeta(PreActBottleneckMeta, [3, 4, 6, 3]) 159 | 160 | 161 | def preActResNetMeta101(): 162 | return PreActResNetMeta(PreActBottleneckMeta, [3, 4, 23, 3]) 163 | 164 | 165 | def preActResNetMeta152(): 166 | return PreActResNetMeta(PreActBottleneckMeta, [3, 8, 36, 3]) 167 | -------------------------------------------------------------------------------- /utils/utils_algo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | import copy 7 | import faiss 8 | import time 9 | from sklearn.metrics import f1_score, roc_auc_score, precision_score, recall_score 10 | 11 | 12 | class AverageMeter(object): 13 | """Computes and stores the average and current value""" 14 | 15 | def __init__(self, name, fmt=":f"): 16 | self.name = name 17 | self.fmt = fmt 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | def __str__(self): 33 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 34 | return fmtstr.format(**self.__dict__) 35 | 36 | 37 | class ProgressMeter(object): 38 | def __init__(self, num_batches, meters, prefix=""): 39 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 40 | self.meters = meters 41 | self.prefix = prefix 42 | 43 | def display(self, batch): 44 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 45 | entries += [str(meter) for meter in self.meters] 46 | print("\t".join(entries)) 47 | 48 | def _get_batch_fmtstr(self, num_batches): 49 | num_digits = len(str(num_batches // 1)) 50 | fmt = "{:" + str(num_digits) + "d}" 51 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 52 | 53 | 54 | def adjust_learning_rate(args, optimizer, epoch): 55 | lr = args.lr 56 | if args.cosine: 57 | eta_min = lr * (args.lr_decay_rate**3) 58 | lr = ( 59 | eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2 60 | ) 61 | else: 62 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 63 | if steps > 0: 64 | lr = lr * (args.lr_decay_rate**steps) 65 | 66 | for param_group in optimizer.param_groups: 67 | param_group["lr"] = lr 68 | 69 | 70 | def accuracy(output, target, topk=(1,)): 71 | """Computes the accuracy over the k top predictions for the specified values of k""" 72 | with torch.no_grad(): 73 | maxk = max(topk) 74 | batch_size = target.size(0) 75 | 76 | output, target = output.cpu(), target.cpu() 77 | _, pred = output.topk(maxk, 1, True, True) 78 | 79 | # auc = roc_auc_score(target, output[:, 1]) 80 | # f1 = f1_score(target, pred[:, 0], average=None) 81 | # f1 = np.append(f1, f1_score(target, pred[:, 0], average="macro")) 82 | # f1 = np.append(f1, f1_score(target, pred[:, 0], average="micro")) 83 | # recall = recall_score(target, pred[:, 0], average=None) 84 | # recall = np.append(recall, recall_score(target, 
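A usage sketch for the ResNet factory functions above (batch and input size are hypothetical). With `flag_feature=True` the model also returns the L2-normalized 128-d projection used by the contrastive loss:

```python
import torch
from models.resnet import preact_resnet_meta18

model = preact_resnet_meta18()
x = torch.randn(2, 3, 32, 32)
if torch.cuda.is_available():
    model, x = model.cuda(), x.cuda()
logits, feat = model(x, flag_feature=True)
print(logits.shape, feat.shape)   # torch.Size([2, 10]) torch.Size([2, 128])
```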
pred[:, 0], average="macro")) 85 | # recall = np.append(recall, recall_score(target, pred[:, 0], average="micro")) 86 | # precision = precision_score(target, pred[:, 0], average=None) 87 | # precision = np.append( 88 | # precision, precision_score(target, pred[:, 0], average="macro") 89 | # ) 90 | # precision = np.append( 91 | # precision, precision_score(target, pred[:, 0], average="micro") 92 | # ) 93 | 94 | pred = pred.t() 95 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 96 | 97 | res = [] 98 | for k in topk: 99 | correct_k = correct[:k].reshape((-1,)).float().sum(0, keepdim=True) 100 | res.append(correct_k.mul_(100.0 / batch_size)) 101 | # return res, auc, f1, recall, precision 102 | return res 103 | 104 | 105 | def accuracy_check(loader, model, device): 106 | with torch.no_grad(): 107 | total, num_samples = 0, 0 108 | for images, labels in loader: 109 | labels, images = labels.to(device), images.to(device) 110 | outputs, _ = model(images) 111 | _, predicted = torch.max(outputs.data, 1) 112 | total += (predicted == labels).sum().item() 113 | num_samples += labels.size(0) 114 | return total / num_samples 115 | 116 | 117 | def sigmoid_rampup(current, rampup_length, exp_coe=5.0): 118 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 119 | if rampup_length == 0: 120 | return 1.0 121 | else: 122 | current = np.clip(current, 0.0, rampup_length) 123 | phase = 1.0 - current / rampup_length 124 | return float(np.exp(-exp_coe * phase * phase)) 125 | 126 | 127 | def linear_rampup(current, rampup_length): 128 | """Linear rampup""" 129 | assert current >= 0 and rampup_length >= 0 130 | if current >= rampup_length: 131 | return 1.0 132 | else: 133 | return current / rampup_length 134 | 135 | 136 | def cosine_rampdown(current, rampdown_length): 137 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 138 | assert 0 <= current <= rampdown_length 139 | return float(0.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 140 | 141 | 142 | def run_kmeans(x, args): 143 | """ 144 | Args: 145 | x: data to be clustered 146 | """ 147 | start_time = time.time() 148 | print("performing kmeans clustering") 149 | results = {"im2cluster": [], "centroids": [], "density": []} 150 | 151 | num_cluster = args.num_cluster 152 | d = x.shape[1] 153 | k = int(num_cluster) 154 | clus = faiss.Clustering(d, k) 155 | clus.verbose = False 156 | clus.niter = 20 157 | clus.nredo = 5 158 | clus.max_points_per_centroid = 1000 159 | clus.min_points_per_centroid = 10 160 | 161 | res = faiss.StandardGpuResources() 162 | cfg = faiss.GpuIndexFlatConfig() 163 | cfg.useFloat16 = False 164 | cfg.device = args.gpu 165 | index = faiss.GpuIndexFlatL2(res, d, cfg) 166 | 167 | clus.train(x, index) 168 | 169 | D, I = index.search(x, 1) 170 | im2cluster = [int(n[0]) for n in I] 171 | 172 | centroids = faiss.vector_to_array(clus.centroids).reshape(k, d) 173 | 174 | Dcluster = [[] for c in range(k)] 175 | for im, i in enumerate(im2cluster): 176 | Dcluster[i].append(D[im][0]) 177 | 178 | density = np.zeros(k) 179 | for i, dist in enumerate(Dcluster): 180 | if len(dist) > 1: 181 | d = (np.asarray(dist) ** 0.5).mean() / np.log(len(dist) + 10) 182 | density[i] = d 183 | 184 | dmax = density.max() 185 | for i, dist in enumerate(Dcluster): 186 | if len(dist) <= 1: 187 | density[i] = dmax 188 | 189 | density = density.clip(np.percentile(density, 10), np.percentile(density, 90)) 190 | density = args.temperature * density / density.mean() 191 | 192 | centroids = torch.Tensor(centroids).cuda() 193 | centroids = 
nn.functional.normalize(centroids, p=2, dim=1)
194 | 
195 |     im2cluster = torch.LongTensor(im2cluster).cuda()
196 |     density = torch.Tensor(density).cuda()
197 | 
198 |     results["centroids"] = centroids
199 |     results["density"] = density
200 |     results["im2cluster"] = im2cluster
201 | 
202 |     print("Kmeans end. Elapsed {} s".format(time.time() - start_time))
203 | 
204 |     return results
205 | 
--------------------------------------------------------------------------------
/utils/utils_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import torchvision.datasets as dsets
6 | import torchvision.transforms as transforms
7 | from .Dataset import DatasetClass
8 | from .utils_algo import *
9 | 
10 | 
11 | def load_dataset(args):
12 |     batch_size = args.batch_size
13 | 
14 |     if args.dataset == "cifar10":
15 |         mean, std, size = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261), 32
16 |         test_temp_dataset = dsets.CIFAR10(
17 |             root="./data", train=False, transform=None, download=True
18 |         )
19 |     elif args.dataset == "cifar100":
20 |         mean, std, size = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761), 32
21 |         test_temp_dataset = dsets.CIFAR100(
22 |             root="./data", train=False, transform=None, download=True
23 |         )
24 |         test_temp_dataset.targets = sparse2coarse(test_temp_dataset.targets)
25 |     elif args.dataset == "stl10":
26 |         mean, std, size = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 96
27 |         test_temp_dataset = dsets.STL10(
28 |             root="./data", split="test", transform=None, download=True
29 |         )
30 |         test_temp_dataset.targets = test_temp_dataset.labels
31 |         test_temp_dataset.data = transpose(
32 |             np.array(test_temp_dataset.data), source="NCHW", target="NHWC"
33 |         )
34 |     elif args.dataset == "alz":
35 |         mean, std, size = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 224
36 |         test_temp_dataset = dsets.CIFAR10(
37 |             root="./data", train=False, transform=None, download=True
38 |         )
39 |         test_temp_dataset.data, test_temp_dataset.targets = load_alz(train=False)
40 |     else:
41 |         raise NotImplementedError("Wrong dataset arguments.")
42 | 
43 |     data_stats = {"mean": mean, "std": std, "size": size}
44 | 
45 |     test_transform = transforms.Compose(
46 |         [transforms.ToTensor(), transforms.Normalize(mean, std)]
47 |     )
48 | 
49 |     data_test, test_labels_temp = np.array(test_temp_dataset.data), np.array(
50 |         test_temp_dataset.targets
51 |     )
52 |     test_labels = binarize_labels(args, test_labels_temp)
53 |     test_dataset = DatasetClass(
54 |         data_test, None, test_labels, data_stats, transform=test_transform, train=False
55 |     )
56 |     test_loader = torch.utils.data.DataLoader(
57 |         dataset=test_dataset, batch_size=batch_size * 4, shuffle=False, num_workers=4
58 |     )
59 | 
60 |     all_data, all_labels = get_train_data(args.dataset)
61 |     all_labels_pu = binarize_labels(args, all_labels)
62 |     train_labeled_idxs, train_unlabeled_idxs, valid_idxs, prior = train_val_split(
63 |         all_labels, args.n_positive, args.positive_list, args.n_valid
64 |     )
65 | 
66 |     train_idxs = []
67 |     for t in train_labeled_idxs:
68 |         train_idxs.append(t)
69 |     for t in train_unlabeled_idxs:
70 |         train_idxs.append(t)
71 |     train_data = all_data[train_idxs]
72 |     train_pu_labels = np.zeros(len(train_idxs))
73 |     train_pu_labels[: len(train_labeled_idxs)] = 1
74 |     train_labels = all_labels_pu[train_idxs]
75 | 
76 |     print("Training data num: ", len(train_pu_labels))
77 |     print("Positive label num: ", train_pu_labels.sum())
78 | 
79 |     if args.using_cont:
80 |         train_transform = None
81 |     else:
82 | 
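A sketch of how `run_kmeans` above is driven (the `args` fields shown are exactly the ones the function reads; the values are illustrative). It requires a GPU build of faiss and float32 features:

```python
import numpy as np
from types import SimpleNamespace
from utils.utils_algo import run_kmeans

args = SimpleNamespace(num_cluster=100, gpu=0, temperature=0.07)
features = np.random.randn(5000, 128).astype(np.float32)
features /= np.linalg.norm(features, axis=1, keepdims=True)  # L2-normalize

result = run_kmeans(features, args)
print(result["im2cluster"].shape, result["centroids"].shape)
# torch.Size([5000]) torch.Size([100, 128])
```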
train_transform = test_transform 83 | 84 | train_dataset = DatasetClass( 85 | train_data, train_pu_labels, train_labels, data_stats, transform=train_transform 86 | ) 87 | train_loader = torch.utils.data.DataLoader( 88 | dataset=train_dataset, 89 | batch_size=batch_size, 90 | shuffle=True, 91 | num_workers=4, 92 | pin_memory=True, 93 | drop_last=True, 94 | ) 95 | 96 | valid_data = all_data[valid_idxs] 97 | valid_labels = all_labels_pu[valid_idxs] 98 | valid_dataset = DatasetClass( 99 | valid_data, 100 | None, 101 | valid_labels, 102 | data_stats, 103 | transform=test_transform, 104 | train=False, 105 | ) 106 | valid_loader = torch.utils.data.DataLoader( 107 | dataset=valid_dataset, 108 | batch_size=batch_size, 109 | shuffle=True, 110 | num_workers=4, 111 | pin_memory=True, 112 | ) 113 | 114 | eval_dataset = DatasetClass( 115 | train_data, train_pu_labels, train_labels, data_stats, transform=test_transform 116 | ) 117 | eval_loader = torch.utils.data.DataLoader( 118 | dataset=eval_dataset, 119 | batch_size=batch_size, 120 | shuffle=True, 121 | num_workers=4, 122 | pin_memory=True, 123 | drop_last=False, 124 | ) 125 | 126 | dim = train_dataset.images.size / len(train_dataset.images) 127 | 128 | return train_loader, test_loader, valid_loader, eval_loader, dim 129 | 130 | 131 | def get_train_data(dataset): 132 | if dataset == "cifar10": 133 | temp_train_dataset = dsets.CIFAR10( 134 | root="./data", train=True, download=True, transform=None 135 | ) 136 | elif dataset == "cifar100": 137 | temp_train_dataset = dsets.CIFAR100( 138 | root="./data", train=True, download=True, transform=None 139 | ) 140 | temp_train_dataset.targets = sparse2coarse(temp_train_dataset.targets) 141 | elif dataset == "stl10": 142 | temp_train_dataset = dsets.STL10( 143 | root="./data", split="train+unlabeled", download=True, transform=None 144 | ) 145 | temp_train_dataset.targets = temp_train_dataset.labels 146 | temp_train_dataset.data = transpose( 147 | np.array(temp_train_dataset.data), source="NCHW", target="NHWC" 148 | ) 149 | elif dataset == "alz": 150 | return load_alz() 151 | data, labels = np.array(temp_train_dataset.data), np.array( 152 | temp_train_dataset.targets 153 | ) 154 | return data, labels 155 | 156 | 157 | def train_val_split(labels, n_labeled, positive_label_list, val_num): 158 | labels = np.array(labels) 159 | label_types = np.unique(labels) 160 | 161 | train_labeled_idxs = [] 162 | train_unlabeled_idxs = [] 163 | valid_idxs = [] 164 | n_labeled_per_class = n_labeled // len(positive_label_list) 165 | 166 | num_positive = 0 167 | label_num = len(label_types) 168 | if -1 in label_types: 169 | label_num -= 1 170 | val_num_per_clss = val_num // label_num 171 | 172 | for i in label_types: 173 | idxs = np.where(labels == i)[0] 174 | 175 | if i == -1: 176 | print( 177 | "[Warning] Label {} detected, collected to unlabeled data. 
".format(i) 178 | ) 179 | elif i != -1: 180 | np.random.shuffle(idxs) 181 | 182 | valid_piece = idxs[:val_num_per_clss] 183 | valid_idxs.extend(valid_piece) 184 | 185 | idxs = np.array(list(set(idxs.tolist()) - set(valid_piece.tolist()))) 186 | 187 | if i in positive_label_list: 188 | pos_piece = idxs[:n_labeled_per_class] 189 | train_labeled_idxs.extend(pos_piece) 190 | num_positive += len(idxs) 191 | idxs = list(set(idxs.tolist()) - set(pos_piece.tolist())) 192 | 193 | train_unlabeled_idxs.extend(idxs) 194 | 195 | np.random.shuffle(train_labeled_idxs) 196 | np.random.shuffle(train_unlabeled_idxs) 197 | 198 | prior = num_positive / (len(train_unlabeled_idxs) + len(train_labeled_idxs)) 199 | 200 | return train_labeled_idxs, train_unlabeled_idxs, valid_idxs, prior 201 | 202 | 203 | def binarize_labels(args, labels): 204 | if not args.reverse: 205 | return np.array([1 if l in args.positive_list else 0 for l in labels]) 206 | else: 207 | return np.array([1 if l not in args.positive_list else 0 for l in labels]) 208 | 209 | 210 | def transpose(x, source="NCHW", target="NHWC"): 211 | """ 212 | N: batch size 213 | H: height 214 | W: weight 215 | C: channel 216 | """ 217 | return x.transpose([source.index(d) for d in target]) 218 | 219 | 220 | def normalise(x, mean, std): 221 | x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)] 222 | x -= mean * 255 223 | x *= 1.0 / (255 * std) 224 | return x 225 | 226 | 227 | def sparse2coarse(targets): 228 | """Convert Pytorch CIFAR100 sparse targets to coarse targets. 229 | 230 | Usage: 231 | trainset = torchvision.datasets.CIFAR100(path) 232 | trainset.targets = sparse2coarse(trainset.targets) 233 | """ 234 | coarse_labels = np.array( 235 | [ 236 | 4, 237 | 1, 238 | 14, 239 | 8, 240 | 0, 241 | 6, 242 | 7, 243 | 7, 244 | 18, 245 | 3, 246 | 3, 247 | 14, 248 | 9, 249 | 18, 250 | 7, 251 | 11, 252 | 3, 253 | 9, 254 | 7, 255 | 11, 256 | 6, 257 | 11, 258 | 5, 259 | 10, 260 | 7, 261 | 6, 262 | 13, 263 | 15, 264 | 3, 265 | 15, 266 | 0, 267 | 11, 268 | 1, 269 | 10, 270 | 12, 271 | 14, 272 | 16, 273 | 9, 274 | 11, 275 | 5, 276 | 5, 277 | 19, 278 | 8, 279 | 8, 280 | 15, 281 | 13, 282 | 14, 283 | 17, 284 | 18, 285 | 10, 286 | 16, 287 | 4, 288 | 17, 289 | 4, 290 | 2, 291 | 0, 292 | 17, 293 | 4, 294 | 18, 295 | 17, 296 | 10, 297 | 3, 298 | 2, 299 | 12, 300 | 12, 301 | 16, 302 | 12, 303 | 1, 304 | 9, 305 | 19, 306 | 2, 307 | 10, 308 | 0, 309 | 1, 310 | 16, 311 | 12, 312 | 9, 313 | 13, 314 | 15, 315 | 13, 316 | 16, 317 | 19, 318 | 2, 319 | 4, 320 | 6, 321 | 19, 322 | 5, 323 | 5, 324 | 8, 325 | 19, 326 | 18, 327 | 1, 328 | 2, 329 | 15, 330 | 6, 331 | 0, 332 | 17, 333 | 8, 334 | 14, 335 | 13, 336 | ] 337 | ) 338 | return coarse_labels[targets] 339 | 340 | 341 | def load_alz(train=True): 342 | class_list = ["MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented"] 343 | data = [] 344 | targets = [] 345 | for class_name in class_list: 346 | if train: 347 | baseDir = "./data/Alzheimer/train/" + class_name 348 | else: 349 | baseDir = "./data/Alzheimer/test/" + class_name 350 | for file in os.listdir(baseDir): 351 | dir = baseDir + "/" + file 352 | img = cv2.imread(dir) 353 | img = cv2.resize(img, (224, 224)) 354 | data.append(img) 355 | targets.append(class_list.index(class_name)) 356 | 357 | return np.array(data), np.array(targets) 358 | -------------------------------------------------------------------------------- /models/meta_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision 6 | from torch.autograd import Variable 7 | import itertools 8 | from utils import * 9 | 10 | 11 | def to_var(x, requires_grad=True): 12 | if torch.cuda.is_available(): 13 | x = x.cuda() 14 | return Variable(x, requires_grad=requires_grad) 15 | 16 | 17 | class MetaModule(nn.Module): 18 | def params(self): 19 | for name, param in self.named_params(self): 20 | yield param 21 | 22 | def named_leaves(self): 23 | return [] 24 | 25 | def named_submodules(self): 26 | return [] 27 | 28 | def named_params(self, curr_module=None, memo=None, prefix=""): 29 | if memo is None: 30 | memo = set() 31 | 32 | if hasattr(curr_module, "named_leaves"): 33 | for name, p in curr_module.named_leaves(): 34 | if p is not None and p not in memo: 35 | memo.add(p) 36 | yield prefix + ("." if prefix else "") + name, p 37 | else: 38 | for name, p in curr_module._parameters.items(): 39 | if p is not None and p not in memo: 40 | memo.add(p) 41 | yield prefix + ("." if prefix else "") + name, p 42 | 43 | for mname, module in curr_module.named_children(): 44 | submodule_prefix = prefix + ("." if prefix else "") + mname 45 | for name, p in self.named_params(module, memo, submodule_prefix): 46 | yield name, p 47 | 48 | def update_params( 49 | self, 50 | lr_inner, 51 | first_order=False, 52 | source_params=None, 53 | detach=False, 54 | identifier=None, 55 | ): 56 | if source_params is not None: 57 | if identifier is None: 58 | named_params = self.named_params(self) 59 | else: 60 | named_params = [] 61 | for name, p in self.named_params(self): 62 | if (identifier in name) and len(p.shape) > 1: 63 | named_params.append((name, p)) 64 | for tgt, src in zip(named_params, source_params): 65 | name_t, param_t = tgt 66 | grad = src 67 | if first_order: 68 | grad = to_var(grad.detach().data) 69 | if grad is not None: 70 | tmp = param_t - lr_inner * grad 71 | self.set_param(self, name_t, tmp) 72 | else: 73 | for name, param in self.named_params(self): 74 | if not detach: 75 | grad = param.grad 76 | if first_order: 77 | grad = to_var(grad.detach().data) 78 | tmp = param - lr_inner * grad 79 | self.set_param(self, name, tmp) 80 | else: 81 | param = param.detach_() 82 | self.set_param(self, name, param) 83 | 84 | def set_param(self, curr_mod, name, param): 85 | if "." 
in name: 86 | n = name.split(".") 87 | module_name = n[0] 88 | rest = ".".join(n[1:]) 89 | for name, mod in curr_mod.named_children(): 90 | if module_name == name: 91 | self.set_param(mod, rest, param) 92 | break 93 | else: 94 | setattr(curr_mod, name, param) 95 | 96 | def detach_params(self): 97 | for name, param in self.named_params(self): 98 | self.set_param(self, name, param.detach()) 99 | 100 | def copy(self, other, same_var=False): 101 | for name, param in other.named_params(): 102 | if not same_var: 103 | param = to_var(param.data.clone(), requires_grad=True) 104 | self.set_param(name, param) 105 | 106 | 107 | class MetaLinear(MetaModule): 108 | def __init__(self, *args, **kwargs): 109 | super().__init__() 110 | ignore = nn.Linear(*args, **kwargs) 111 | 112 | self.register_buffer("weight", to_var(ignore.weight.data, requires_grad=True)) 113 | self.register_buffer("bias", to_var(ignore.bias.data, requires_grad=True)) 114 | 115 | def forward(self, x): 116 | return F.linear(x, self.weight, self.bias) 117 | 118 | def named_leaves(self): 119 | return [("weight", self.weight), ("bias", self.bias)] 120 | 121 | 122 | class MetaConv2d(MetaModule): 123 | def __init__(self, *args, **kwargs): 124 | super().__init__() 125 | ignore = nn.Conv2d(*args, **kwargs) 126 | 127 | self.stride = ignore.stride 128 | self.padding = ignore.padding 129 | self.dilation = ignore.dilation 130 | self.groups = ignore.groups 131 | 132 | self.register_buffer("weight", to_var(ignore.weight.data, requires_grad=True)) 133 | 134 | if ignore.bias is not None: 135 | self.register_buffer("bias", to_var(ignore.bias.data, requires_grad=True)) 136 | else: 137 | self.register_buffer("bias", None) 138 | 139 | def forward(self, x): 140 | return F.conv2d( 141 | x, 142 | self.weight, 143 | self.bias, 144 | self.stride, 145 | self.padding, 146 | self.dilation, 147 | self.groups, 148 | ) 149 | 150 | def named_leaves(self): 151 | return [("weight", self.weight), ("bias", self.bias)] 152 | 153 | 154 | class MetaConvTranspose2d(MetaModule): 155 | def __init__(self, *args, **kwargs): 156 | super().__init__() 157 | ignore = nn.ConvTranspose2d(*args, **kwargs) 158 | 159 | self.stride = ignore.stride 160 | self.padding = ignore.padding 161 | self.dilation = ignore.dilation 162 | self.groups = ignore.groups 163 | 164 | self.register_buffer("weight", to_var(ignore.weight.data, requires_grad=True)) 165 | 166 | if ignore.bias is not None: 167 | self.register_buffer("bias", to_var(ignore.bias.data, requires_grad=True)) 168 | else: 169 | self.register_buffer("bias", None) 170 | 171 | def forward(self, x, output_size=None): 172 | output_padding = self._output_padding(x, output_size) 173 | return F.conv_transpose2d( 174 | x, 175 | self.weight, 176 | self.bias, 177 | self.stride, 178 | self.padding, 179 | output_padding, 180 | self.groups, 181 | self.dilation, 182 | ) 183 | 184 | def named_leaves(self): 185 | return [("weight", self.weight), ("bias", self.bias)] 186 | 187 | 188 | class MetaBatchNorm2d(MetaModule): 189 | def __init__(self, *args, **kwargs): 190 | super(MetaBatchNorm2d, self).__init__() 191 | ignore = nn.BatchNorm2d(*args, **kwargs) 192 | 193 | self.num_features = ignore.num_features 194 | self.eps = ignore.eps 195 | self.momentum = ignore.momentum 196 | self.affine = ignore.affine 197 | self.track_running_stats = ignore.track_running_stats 198 | 199 | if self.affine: 200 | self.register_buffer( 201 | "weight", to_var(ignore.weight.data, requires_grad=True) 202 | ) 203 | self.register_buffer("bias", to_var(ignore.bias.data, 
requires_grad=True)) 204 | 205 | if self.track_running_stats: 206 | self.register_buffer("running_mean", torch.zeros(self.num_features)) 207 | self.register_buffer("running_var", torch.ones(self.num_features)) 208 | self.register_buffer( 209 | "num_batches_tracked", torch.tensor(0, dtype=torch.long) 210 | ) 211 | else: 212 | self.register_parameter("running_mean", None) 213 | self.register_parameter("running_var", None) 214 | 215 | def reset_running_stats(self): 216 | if self.track_running_stats: 217 | self.running_mean.zero_() 218 | self.running_var.fill_(1) 219 | self.num_batches_tracked.zero_() 220 | 221 | def reset_parameters(self): 222 | self.reset_running_stats() 223 | if self.affine: 224 | self.weight.data.uniform_() 225 | self.bias.data.zero_() 226 | 227 | def _check_input_dim(self, input): 228 | if input.dim() != 4: 229 | raise ValueError("expected 4D input (got {}D input)".format(input.dim())) 230 | 231 | def forward(self, x): 232 | self._check_input_dim(x) 233 | exponential_average_factor = 0.0 234 | 235 | if self.training and self.track_running_stats: 236 | self.num_batches_tracked += 1 237 | if self.momentum is None: 238 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 239 | else: 240 | exponential_average_factor = self.momentum 241 | 242 | return F.batch_norm( 243 | x, 244 | self.running_mean, 245 | self.running_var, 246 | self.weight, 247 | self.bias, 248 | self.training or not self.track_running_stats, 249 | self.momentum, 250 | self.eps, 251 | ) 252 | 253 | def named_leaves(self): 254 | return [("weight", self.weight), ("bias", self.bias)] 255 | 256 | def extra_repr(self): 257 | return ( 258 | "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " 259 | "track_running_stats={track_running_stats}".format(**self.__dict__) 260 | ) 261 | 262 | def _load_from_state_dict( 263 | self, 264 | state_dict, 265 | prefix, 266 | metadata, 267 | strict, 268 | missing_keys, 269 | unexpected_keys, 270 | error_msgs, 271 | ): 272 | version = metadata.get("version", None) 273 | 274 | if (version is None or version < 2) and self.track_running_stats: 275 | num_batches_tracked_key = prefix + "num_batches_tracked" 276 | if num_batches_tracked_key not in state_dict: 277 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 278 | 279 | super(MetaBatchNorm2d, self)._load_from_state_dict( 280 | state_dict, 281 | prefix, 282 | metadata, 283 | strict, 284 | missing_keys, 285 | unexpected_keys, 286 | error_msgs, 287 | ) 288 | 289 | 290 | class MetaBatchNorm1d(MetaModule): 291 | def __init__(self, *args, **kwargs): 292 | super(MetaBatchNorm1d, self).__init__() 293 | ignore = nn.BatchNorm1d(*args, **kwargs) 294 | 295 | self.num_features = ignore.num_features 296 | self.eps = ignore.eps 297 | self.momentum = ignore.momentum 298 | self.affine = ignore.affine 299 | self.track_running_stats = ignore.track_running_stats 300 | 301 | if self.affine: 302 | self.register_buffer( 303 | "weight", to_var(ignore.weight.data, requires_grad=True) 304 | ) 305 | self.register_buffer("bias", to_var(ignore.bias.data, requires_grad=True)) 306 | 307 | if self.track_running_stats: 308 | self.register_buffer("running_mean", torch.zeros(self.num_features)) 309 | self.register_buffer("running_var", torch.ones(self.num_features)) 310 | self.register_buffer( 311 | "num_batches_tracked", torch.tensor(0, dtype=torch.long) 312 | ) 313 | else: 314 | self.register_parameter("running_mean", None) 315 | self.register_parameter("running_var", None) 316 | 317 | def 
reset_running_stats(self): 318 | if self.track_running_stats: 319 | self.running_mean.zero_() 320 | self.running_var.fill_(1) 321 | self.num_batches_tracked.zero_() 322 | 323 | def reset_parameters(self): 324 | self.reset_running_stats() 325 | if self.affine: 326 | self.weight.data.uniform_() 327 | self.bias.data.zero_() 328 | 329 | def _check_input_dim(self, input): 330 | if input.dim() != 2 and input.dim() != 3: 331 | raise ValueError( 332 | "expected 2D or 3D input (got {}D input)".format(input.dim()) 333 | ) 334 | 335 | def forward(self, x): 336 | self._check_input_dim(x) 337 | exponential_average_factor = 0.0 338 | 339 | if self.training and self.track_running_stats: 340 | self.num_batches_tracked += 1 341 | if self.momentum is None: 342 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 343 | else: 344 | exponential_average_factor = self.momentum 345 | 346 | return F.batch_norm( 347 | x, 348 | self.running_mean, 349 | self.running_var, 350 | self.weight, 351 | self.bias, 352 | self.training or not self.track_running_stats, 353 | self.momentum, 354 | self.eps, 355 | ) 356 | 357 | def named_leaves(self): 358 | return [("weight", self.weight), ("bias", self.bias)] 359 | 360 | def extra_repr(self): 361 | return ( 362 | "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " 363 | "track_running_stats={track_running_stats}".format(**self.__dict__) 364 | ) 365 | 366 | def _load_from_state_dict( 367 | self, 368 | state_dict, 369 | prefix, 370 | metadata, 371 | strict, 372 | missing_keys, 373 | unexpected_keys, 374 | error_msgs, 375 | ): 376 | version = metadata.get("version", None) 377 | 378 | if (version is None or version < 2) and self.track_running_stats: 379 | num_batches_tracked_key = prefix + "num_batches_tracked" 380 | if num_batches_tracked_key not in state_dict: 381 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 382 | 383 | super(MetaBatchNorm1d, self)._load_from_state_dict( 384 | state_dict, 385 | prefix, 386 | metadata, 387 | strict, 388 | missing_keys, 389 | unexpected_keys, 390 | error_msgs, 391 | ) 392 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | import torch 10 | import torch.nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import numpy as np 17 | import dill 18 | from tqdm import tqdm 19 | from models.models import * 20 | from models.resnet import * 21 | from models.wideresnet import WideResNet 22 | from utils.utils_algo import * 23 | from utils.utils_data import * 24 | from utils.utils_loss import * 25 | import warnings 26 | 27 | warnings.filterwarnings("ignore") 28 | 29 | torch.set_printoptions(precision=2, sci_mode=False) 30 | 31 | parser = argparse.ArgumentParser( 32 | description="PyTorch implementation of LaGAM" 33 | ) 34 | parser.add_argument( 35 | "--dataset", 36 | default="cifar10", 37 | type=str, 38 | choices=["cifar10", "cifar100", "stl10", "alz"], 39 | help="dataset name (cifar10)", 40 | ) 41 | parser.add_argument( 42 | "--exp-dir", 43 | default="experiment", 44 | type=str, 45 | help="experiment directory for saving checkpoints and logs", 46 | ) 47 | parser.add_argument( 48 | "--no_verbose", action="store_true", 
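The point of the `Meta*` layers above is that weights live in buffers rather than `nn.Parameter`s, so an inner gradient step can overwrite them with graph-connected tensors via `update_params`. A minimal sketch of that inner step (layer sizes and learning rate are hypothetical):

```python
import torch
import torch.nn.functional as F
from models.meta_layers import MetaLinear

net = MetaLinear(8, 2)
x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))
if torch.cuda.is_available():          # meta buffers are created on GPU when one exists
    x, y = x.cuda(), y.cuda()
loss = F.cross_entropy(net(x), y)
grads = torch.autograd.grad(loss, list(net.params()), create_graph=True)
net.update_params(1e-3, source_params=grads)   # weights replaced, graph retained
```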
help="disable showing running statics" 49 | ) 50 | parser.add_argument( 51 | "-a", 52 | "--arch", 53 | metavar="ARCH", 54 | default="resnet18", 55 | choices=["resnet50", "resnet18", "CNN", "WRN", "CNN13"], 56 | help="network architecture", 57 | ) 58 | parser.add_argument( 59 | "-j", 60 | "--workers", 61 | default=8, 62 | type=int, 63 | help="number of data loading workers", 64 | ) 65 | parser.add_argument( 66 | "--epochs", default=400, type=int, help="number of total epochs to run" 67 | ) 68 | parser.add_argument( 69 | "--start-epoch", 70 | default=0, 71 | type=int, 72 | help="manual epoch number (useful on restarts)", 73 | ) 74 | parser.add_argument( 75 | "-b", 76 | "--batch-size", 77 | default=64, 78 | type=int, 79 | help="mini-batch size (default: 64), this is the total " 80 | "batch size of all GPUs on the current node when " 81 | "using Data Parallel or Distributed Data Parallel", 82 | ) 83 | parser.add_argument( 84 | "--lr", 85 | "--learning-rate", 86 | default=1e-3, 87 | type=float, 88 | metavar="LR", 89 | help="initial learning rate", 90 | dest="lr", 91 | ) 92 | parser.add_argument( 93 | "--lr_decay_epochs", 94 | type=str, 95 | default="250,300,350", 96 | help="where to decay lr, can be a list", 97 | ) 98 | parser.add_argument( 99 | "--lr_decay_rate", type=float, default=0.1, help="decay rate for learning rate" 100 | ) 101 | parser.add_argument( 102 | "--cosine", action="store_true", default=False, help="use cosine lr schedule" 103 | ) 104 | parser.add_argument( 105 | "--momentum", default=0.9, type=float, metavar="M", help="momentum of SGD solver" 106 | ) 107 | parser.add_argument( 108 | "--wd", 109 | "--weight-decay", 110 | default=1e-4, 111 | type=float, 112 | metavar="W", 113 | help="weight decay (default: 1e-4)", 114 | dest="weight_decay", 115 | ) 116 | parser.add_argument( 117 | "-p", "--print-freq", default=100, type=int, help="print frequency (default: 100)" 118 | ) 119 | parser.add_argument( 120 | "--seed", default=None, type=int, help="seed for initializing training. 
" 121 | ) 122 | parser.add_argument("--gpu", default=0, type=int, help="GPU id to use.") 123 | parser.add_argument("--num-class", default=10, type=int, help="number of class") 124 | 125 | parser.add_argument("--n_positive", default=1000, type=int, help="num_labeled data") 126 | parser.add_argument("--n_valid", default=500, type=int, help="number of valid examples") 127 | parser.add_argument( 128 | "--positive_list", default="0,1,8,9", type=str, help="list of positive labels" 129 | ) 130 | # Standatd setups: 131 | # CIFAR-10-1: 0,1,8,9 132 | # CIFAR-10-2: 2,3,4,5,6,7 133 | # STL-10-1: 0,2,3,8,9 134 | # STL-10-2: 1,4,5,6,7 135 | # CIFAR-100-1: 18,19 136 | # CIFAR-100-2: 0,1,7,8,11,12,13,14,15,16 137 | parser.add_argument( 138 | "--ent_loss", action="store_true", help="whether enable entropy loss" 139 | ) 140 | parser.add_argument("--mix_weight", default=1.0, type=float, help="mixup loss weight") 141 | parser.add_argument( 142 | "--rho_range", default="0.95,0.8", type=str, help="momentum updating parameter" 143 | ) 144 | parser.add_argument( 145 | "--warmup_epoch", default=20, type=int, help="epoch number of warm up" 146 | ) 147 | parser.add_argument( 148 | "--using_cont", type=int, default=1, help="whether using contrastive loss" 149 | ) 150 | 151 | parser.add_argument("--num_cluster", default=100, type=int, help="number of clusters") 152 | parser.add_argument("--temperature", default=0.07, type=float, help="mixup loss weight") 153 | parser.add_argument( 154 | "--cont_cutoff", action="store_true", help="whether cut off by classifier" 155 | ) 156 | parser.add_argument("--knn_aug", action="store_true", help="whether using kNN for CL") 157 | parser.add_argument("--num_neighbors", default=10, type=int, help="number of neighbors") 158 | parser.add_argument( 159 | "--identifier", 160 | default=None, 161 | type=str, 162 | help="identifier for meta layers, e.g. 
classifier", 163 | ) 164 | parser.add_argument( 165 | "--contrastive_clustering", 166 | default=1, 167 | type=int, 168 | help="whether using contrastive clustering", 169 | ) 170 | 171 | parser.add_argument("--reverse", default=0, type=int, help="whether inverse label") 172 | parser.add_argument("--tag", default="", type=str, help="special identifier") 173 | parser.add_argument("--save", default=0, type=int, help="whether save model") 174 | 175 | 176 | class Trainer: 177 | def __init__(self, args, model_func=None): 178 | self.args = args 179 | model_path = "{ds}_ep{ep}_we{we}_pos{pl}_nl{nl}_rho{rs}~{re}_co{co}_knn{knn}{k}_sd_{seed}".format( 180 | ds=args.dataset, 181 | ep=args.epochs, 182 | pl=str(args.positive_list), 183 | nl=args.n_positive, 184 | rs=args.rho_start, 185 | re=args.rho_end, 186 | knn=args.knn_aug, 187 | co=args.cont_cutoff, 188 | k=args.num_neighbors, 189 | we=args.warmup_epoch, 190 | seed=args.seed, 191 | ) 192 | args.exp_dir = os.path.join(args.exp_dir, model_path) 193 | if not os.path.exists(args.exp_dir): 194 | os.makedirs(args.exp_dir) 195 | 196 | if args.seed is not None: 197 | random.seed(args.seed) 198 | torch.manual_seed(args.seed) 199 | np.random.seed(args.seed) 200 | cudnn.deterministic = True 201 | 202 | train_loader, test_loader, valid_loader, eval_loader, dim = load_dataset( 203 | args=args 204 | ) 205 | 206 | self.train_loader = train_loader 207 | self.test_loader = test_loader 208 | self.valid_loader = valid_loader 209 | self.eval_loader = eval_loader 210 | self.dim = dim 211 | print("=> creating model '{}'".format(args.arch)) 212 | model = create_model(args, self.dim) 213 | 214 | optimizer = torch.optim.SGD( 215 | model.params(), 216 | args.lr, 217 | momentum=args.momentum, 218 | weight_decay=args.weight_decay, 219 | ) 220 | 221 | self.model = model 222 | self.optimizer = optimizer 223 | self.bce_loss = BCELoss(args.ent_loss) 224 | self.contrastive_loss = ContLoss( 225 | temperature=args.temperature, 226 | cont_cutoff=args.cont_cutoff, 227 | knn_aug=args.knn_aug, 228 | num_neighbors=args.num_neighbors, 229 | contrastive_clustering=args.contrastive_clustering, 230 | ) 231 | 232 | def train(self): 233 | args = self.args 234 | optimizer = self.optimizer 235 | 236 | best_acc = 0 237 | 238 | for epoch in range(args.start_epoch, args.epochs): 239 | adjust_learning_rate(args, optimizer, epoch) 240 | 241 | if epoch < args.warmup_epoch or args.using_cont == 0: 242 | self.train_loop(epoch) 243 | else: 244 | features = self.compute_features() 245 | cluster_result = run_kmeans(features, args) 246 | self.train_loop(epoch, cluster_result) 247 | 248 | acc_test = self.test() 249 | 250 | with open(os.path.join(args.exp_dir, "result.log"), "a+") as f: 251 | f.write( 252 | "Epoch {}: Acc {:.2f}, Best Acc {:.2f}. 
232 |     def train(self):
233 |         args = self.args
234 |         optimizer = self.optimizer
235 | 
236 |         best_acc = 0
237 | 
238 |         for epoch in range(args.start_epoch, args.epochs):
239 |             adjust_learning_rate(args, optimizer, epoch)
240 | 
241 |             if epoch < args.warmup_epoch or args.using_cont == 0:
242 |                 self.train_loop(epoch)
243 |             else:
244 |                 features = self.compute_features()
245 |                 cluster_result = run_kmeans(features, args)  # latent groups for the contrastive loss
246 |                 self.train_loop(epoch, cluster_result)
247 | 
248 |             acc_test = self.test()
249 | 
250 |             with open(os.path.join(args.exp_dir, "result.log"), "a+") as f:
251 |                 f.write(
252 |                     "Epoch {}: Acc {:.2f}, Best Acc {:.2f}. (lr {:.5f})\n".format(
253 |                         epoch, acc_test, best_acc, optimizer.param_groups[0]["lr"]
254 |                     )
255 |                 )
256 | 
257 |             if acc_test > best_acc:
258 |                 best_acc = acc_test
259 | 
260 |         if args.save:
261 |             file_name = "{}_{}_{}_{}_{}".format(
262 |                 args.dataset, args.arch, args.n_valid, args.num_cluster, args.tag
263 |             )
264 | 
265 |             torch.save(self.model, file_name + ".pth")
266 | 
267 |             with open(file_name + ".pkl", "wb") as f:
268 |                 dill.dump(self.train_loader.dataset, f)
269 | 
270 |     def train_loop(self, epoch, cluster_result=None):
271 |         args = self.args
272 |         train_loader = self.train_loader
273 |         model = self.model
274 |         optimizer = self.optimizer
275 |         bce_loss = self.bce_loss
276 |         contrastive_loss = self.contrastive_loss
277 | 
278 |         batch_time = AverageMeter("Time", ":1.2f")
279 |         data_time = AverageMeter("Data", ":1.2f")
280 |         acc_cls = AverageMeter("Acc@Cls", ":2.2f")
281 |         loss_cls_log = AverageMeter("Loss@Cls", ":2.2f")
282 |         loss_cont_log = AverageMeter("Loss@Cont", ":2.2f")
283 |         progress = ProgressMeter(
284 |             len(train_loader),
285 |             [batch_time, data_time, acc_cls, loss_cls_log, loss_cont_log],
286 |             prefix="Epoch: [{}]".format(epoch),
287 |         )
288 | 
289 |         model.train()
290 | 
291 |         updated_label_list = []
292 |         true_label_list = []
293 |         index_list = []
294 |         ema_param = (  # rho, interpolated linearly from rho_start to rho_end
295 |             1.0 * epoch / args.epochs * (args.rho_end - args.rho_start) + args.rho_start
296 |         )
297 | 
298 |         end = time.time()
299 | 
300 |         for i, (images, images_s, labels_, true_labels, index) in enumerate(
301 |             train_loader
302 |         ):
303 |             data_time.update(time.time() - end)
304 | 
305 |             if labels_.sum() == 0:  # skip batches without a single labeled positive
306 |                 continue
307 |             true_label_list.append(true_labels)
308 |             index_list.append(index)
309 | 
310 |             images, images_s, labels_, index = (
311 |                 images.cuda(),
312 |                 images_s.cuda(),
313 |                 labels_.cuda(),
314 |                 index.cuda(),
315 |             )
316 |             labels_ = labels_.unsqueeze(1)
317 |             labels = torch.cat([1 - labels_, labels_], dim=1).detach()  # soft two-column labels
318 |             Y_true = true_labels.long().detach().cuda()
319 |             bs = len(labels)
320 |             cluster_idxes = (
321 |                 None if cluster_result is None else cluster_result["im2cluster"][index]
322 |             )
323 | 
324 |             if epoch < args.warmup_epoch:
325 |                 labels_final = labels
326 |             else:
327 |                 meta_model = create_model(args, self.dim)
328 |                 meta_model.load_state_dict(model.state_dict())  # virtual copy of the current model
329 | 
330 |                 preds_meta = meta_model(images)
331 | 
332 |                 eps = to_var(torch.zeros(bs, 2).cuda())  # differentiable per-example label perturbation
333 |                 labels_meta = labels + eps
334 |                 loss = bce_loss(preds_meta, labels_meta)
335 | 
336 |                 meta_model.zero_grad()
337 | 
338 |                 params = []
339 |                 for name, p in meta_model.named_params(meta_model):
340 |                     if args.identifier in name and len(p.shape) > 1:
341 |                         params.append(p)
342 |                 grads = torch.autograd.grad(
343 |                     loss, params, create_graph=True, allow_unused=True
344 |                 )
345 |                 meta_lr = 0.001  # step size of the single virtual update
346 |                 meta_model.update_params(
347 |                     meta_lr, source_params=grads, identifier=args.identifier
348 |                 )
349 | 
350 |                 try:
351 |                     images_v, labels_v = next(valid_loader_iter)
352 |                 except (NameError, StopIteration):  # first batch, or validation iterator exhausted
353 |                     valid_loader_iter = iter(self.valid_loader)
354 |                     images_v, labels_v = next(valid_loader_iter)
355 | 
356 |                 images_v = images_v.cuda()
357 |                 labels_v = F.one_hot(labels_v.cuda(), 2).float()
358 | 
359 |                 preds_v = meta_model(images_v)
360 | 
361 |                 loss_meta_v = bce_loss(preds_v, labels_v)
362 |                 grad_eps = torch.autograd.grad(
363 |                     loss_meta_v, eps, only_inputs=True, allow_unused=True
364 |                 )[0]
365 | 
366 |                 eps = eps - grad_eps
367 |                 meta_detected_labels = eps.argmax(dim=1)
368 |                 meta_detected_labels[labels_.squeeze() == 1] = 1  # labeled positives are never flipped
369 |                 meta_detected_labels = F.one_hot(meta_detected_labels, 2)
370 |                 meta_detected_labels = meta_detected_labels.detach()
371 | 
372 |                 updated_labels = labels
373 |                 updated_labels = updated_labels * ema_param + meta_detected_labels * (
374 |                     1 - ema_param
375 |                 )
376 |                 labels_final = updated_labels.detach()
377 | 
378 |                 updated_label_list.append(updated_labels[:, 1].cpu())
379 | 
380 |                 del grad_eps, grads, params
381 | 
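            # NOTE (editor): the block above is the meta disambiguation step,
            # in outline:
            #   1. clone the model and perturb the soft labels with a
            #      zero-initialized, differentiable eps;
            #   2. take one virtual SGD step (meta_lr) on the perturbed BCE
            #      loss, restricted to parameters matching --identifier;
            #   3. compute the validation loss of the stepped model and
            #      backpropagate it to eps alone;
            #   4. the argmax of eps - grad_eps is the label assignment the
            #      clean validation batch prefers; known positives stay positive.
            # The result is folded into the soft labels with momentum ema_param.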
382 |             l = np.random.beta(4, 4)  # mixup coefficient, biased toward 0.5
383 |             l = max(l, 1 - l)
384 |             X_w_c = images
385 |             pseudo_label_c = labels_final
386 |             idx = torch.randperm(X_w_c.size(0))
387 |             X_w_c_rand = X_w_c[idx]
388 |             pseudo_label_c_rand = pseudo_label_c[idx]
389 |             X_w_c_mix = l * X_w_c + (1 - l) * X_w_c_rand
390 |             pseudo_label_c_mix = l * pseudo_label_c + (1 - l) * pseudo_label_c_rand
391 |             logits_mix = model(X_w_c_mix)
392 |             loss_mix = bce_loss(logits_mix, pseudo_label_c_mix)
393 | 
394 |             preds_final, feat_cont = model(images, flag_feature=True)
395 |             loss_cls = bce_loss(preds_final, labels_final)
396 | 
397 |             loss_final = loss_cls + args.mix_weight * loss_mix
398 | 
399 |             if args.using_cont:
400 |                 _, feat_cont_s = model(images_s, flag_feature=True)
401 |                 loss_cont = contrastive_loss(
402 |                     feat_cont,
403 |                     feat_cont_s,
404 |                     cluster_idxes,
405 |                     preds_final,
406 |                     start_knn_aug=epoch > 50,
407 |                 )
408 |                 loss_final = loss_final + loss_cont
409 | 
410 |                 loss_cont_log.update(loss_cont.item())  # only defined when the contrastive loss is on
411 |             loss_cls_log.update(loss_cls.item())  # log the classification term under its own meter
412 | 
413 |             acc = accuracy(
414 |                 torch.cat([1 - preds_final, preds_final], dim=1), Y_true
415 |             )
416 |             acc = acc[0]
417 |             acc_cls.update(acc[0])
418 | 
419 |             optimizer.zero_grad()
420 |             loss_final.backward()
421 |             optimizer.step()
422 |             batch_time.update(time.time() - end)
423 |             end = time.time()
424 |             if i % args.print_freq == 0:
425 |                 progress.display(i)
426 | 
427 |         if epoch >= args.warmup_epoch and not args.no_verbose:
428 |             true_label_list = torch.cat(true_label_list, dim=0)
429 |             updated_label_list = torch.cat(updated_label_list, dim=0)
430 |             index_list = torch.cat(index_list, dim=0)
431 | 
432 |             print(updated_label_list[:10])
433 |             print(true_label_list[:10])
434 | 
435 |             update_label_cate = (updated_label_list > 0.5) * 1
436 |             compare = update_label_cate == true_label_list
437 |             print(
438 |                 "New target accuracy: ",
439 |                 compare.sum() / len(compare),
440 |                 "; ema param: ",
441 |                 ema_param,
442 |             )
443 | 
444 |             self.train_loader.dataset.update_targets(
445 |                 updated_label_list.numpy(), index_list
446 |             )
447 | 
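    # NOTE (editor): the networks emit a single positive-class logit. test()
    # below expands it into a two-column score matrix so the generic top-1
    # `accuracy` helper applies directly:
    #
    #     p = torch.sigmoid(logit)            # P(y = positive)
    #     scores = torch.cat([1 - p, p], 1)   # column 0: negative, column 1: positive
    #
    # Top-1 over these two columns is equivalent to thresholding p at 0.5.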
448 |     def test(self):
449 |         model = self.model
450 |         test_loader = self.test_loader
451 | 
452 |         with torch.no_grad():
453 |             print("==> Evaluation...")
454 |             model.eval()
455 |             pred_list = []
456 |             true_list = []
457 |             for _, (images, labels) in enumerate(test_loader):
458 |                 images = images.cuda()
459 |                 outputs = model(images)
460 |                 pred = torch.sigmoid(outputs)
461 |                 pred = torch.cat([1 - pred, pred], dim=1)
462 |                 pred_list.append(pred.cpu())
463 |                 true_list.append(labels)
464 | 
465 |             pred_list = torch.cat(pred_list, dim=0)
466 |             true_list = torch.cat(true_list, dim=0)
467 | 
468 |             acc1 = accuracy(pred_list, true_list, topk=(1,))
469 |             acc1 = acc1[0]
470 |             print("==> Test Accuracy is %.2f%%" % (acc1))
471 |             # print("==> AUC, F1, Recall, Precision are: ")
472 |             # print(auc, f1, recall, precision)
473 |             return float(acc1)
474 | 
475 |     def compute_features(self):
476 |         model = self.model
477 |         model.eval()
478 |         feat_list = torch.zeros(len(self.eval_loader.dataset), 128)  # 128-d contrastive embeddings
479 |         with torch.no_grad():
480 |             for i, (images, _, _, _, index) in enumerate(self.eval_loader):
481 |                 images = images.cuda(non_blocking=True)
482 |                 _, feat = model(images, flag_feature=True)
483 |                 feat_list[index] = feat.cpu()
484 |         return feat_list.numpy()
485 | 
486 |     def save_checkpoint(
487 |         self,
488 |         state,
489 |         is_best,
490 |         filename="checkpoint.pth.tar",
491 |         best_file_name="model_best.pth.tar",
492 |     ):
493 |         torch.save(state, filename)
494 |         if is_best:
495 |             shutil.copyfile(filename, best_file_name)
496 | 
497 | 
498 | def create_model(args, dim=0):
499 |     if args.arch == "CNN":
500 |         if args.dataset == "stl10":
501 |             model = MixCNNSTL_CL(dim).cuda()
502 |         else:
503 |             model = MixCNNCIFAR_CL(dim).cuda()
504 |     elif args.arch == "WRN":
505 |         model = WideResNet()
506 |     elif args.arch == "CNN13":
507 |         model = MetaCNN()
508 |     else:
509 |         if args.dataset == "stl10":
510 |             model = PreActResNetMeta(
511 |                 PreActBlockMeta, [2, 2, 2, 2], num_classes=1, use_checkpoint=False
512 |             ).cuda()
513 |         elif args.dataset == "alz":
514 |             model = PreActResNetMeta(
515 |                 PreActBottleneckMeta, [3, 4, 6, 3], num_classes=1, use_checkpoint=False
516 |             ).cuda()
517 |         else:
518 |             model = PreActResNetMeta(
519 |                 PreActBlockMeta, [2, 2, 2, 2], num_classes=1, use_checkpoint=False
520 |             ).cuda()
521 |     model.cuda()  # covers the WRN / CNN13 branches as well
522 |     return model
523 | 
524 | 
525 | if __name__ == "__main__":
526 |     args = parser.parse_args()
527 |     args.rho_start, args.rho_end = [float(item) for item in args.rho_range.split(",")]
528 |     args.positive_list = [int(item) for item in args.positive_list.split(",")]
529 |     iterations = args.lr_decay_epochs.split(",")
530 |     args.lr_decay_epochs = []
531 |     for it in iterations:
532 |         args.lr_decay_epochs.append(int(it))
533 |     print(args)
534 |     trainer = Trainer(args)
535 |     trainer.train()
536 | 
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from torch.autograd import Variable
8 | from models.meta_layers import *
9 | import collections
10 | 
11 | from torch.utils.checkpoint import checkpoint
12 | 
13 | 
14 | class MyClassifier(nn.Module):
15 |     def zero_one_loss(self, h, t, is_logistic=False):
16 |         self.eval()
17 |         positive = 1
18 |         negative = 0 if is_logistic else -1
19 | 
20 |         n_p = (t == positive).sum()
21 |         n_n = (t == negative).sum()
22 |         size = n_p + n_n
23 | 
24 |         n_pp = (h == positive).sum()
25 |         t_p = ((h == positive) * (t == positive)).sum()
26 |         t_n = ((h == negative) * (t == negative)).sum()
27 |         f_p = n_n - t_n
28 |         f_n = n_p - t_p
29 | 
30 |         precision = 0.0 if t_p == 0 else t_p / (t_p + f_p)
31 |         recall = 0.0 if t_p == 0 else t_p / (t_p + f_n)
32 | 
33 |         return precision, recall, 1 - (t_p + t_n) / size, n_pp
34 | 
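    # NOTE (editor): zero_one_loss above is plain confusion-matrix arithmetic
    # over N = n_p + n_n labeled examples:
    #
    #     f_p = n_n - t_n               # false positives
    #     f_n = n_p - t_p               # false negatives
    #     precision = t_p / (t_p + f_p)
    #     recall    = t_p / (t_p + f_n)
    #     error     = 1 - (t_p + t_n) / N
    #
    # e.g. n_p = 60, n_n = 40, t_p = 50, t_n = 30 gives precision = recall = 50/60
    # and error = 1 - 80/100 = 0.2.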
35 |     def error(self, DataLoader, is_logistic=False):
36 |         targets_all = np.array([])
37 |         prediction_all = np.array([])
38 |         self.eval()
39 |         for data, _, target in DataLoader:
40 |             data = data.cuda()
41 |             t = target.detach().cpu().numpy()
42 |             size = len(t)
43 |             if is_logistic:
44 |                 h = np.reshape(torch.sigmoid(self(data)).detach().cpu().numpy(), size)
45 |                 h = np.where(h > 0.5, 1, 0).astype(np.int32)
46 |             else:
47 |                 h = np.reshape(torch.sign(self(data)).detach().cpu().numpy(), size)
48 | 
49 |             targets_all = np.hstack((targets_all, t))
50 |             prediction_all = np.hstack((prediction_all, h))
51 | 
52 |         return self.zero_one_loss(prediction_all, targets_all, is_logistic)
53 | 
54 |     def evaluation_with_density(self, DataLoader, prior):
55 |         targets_all = np.array([])
56 |         prediction_all = np.array([])
57 |         self.eval()
58 |         for data, target in DataLoader:
59 |             data = data.cuda()  # was `data.to(device)`, but no `device` is defined in this scope
60 |             t = target.detach().cpu().numpy()
61 |             size = len(t)
62 |             h = np.reshape(self(data).detach().cpu().numpy(), size)
63 |             h = self.predict_with_density_threshold(h, target, prior)
64 | 
65 |             targets_all = np.hstack((targets_all, t))
66 |             prediction_all = np.hstack((prediction_all, h))
67 | 
68 |         return self.zero_one_loss(prediction_all, targets_all)
69 | 
70 |     def predict_with_density_threshold(self, f_x, target, prior):
71 |         density_ratio = f_x / prior
72 |         sorted_density_ratio = np.sort(density_ratio)
73 |         size = len(density_ratio)
74 | 
75 |         n_pi = int(size * prior)  # predict the top prior-fraction of scores as positive
76 |         threshold = (
77 |             sorted_density_ratio[size - n_pi] + sorted_density_ratio[size - n_pi - 1]
78 |         ) / 2
79 |         h = np.sign(density_ratio - threshold).astype(np.int32)
80 |         return h
81 | 
82 | 
83 | class LeNet(MyClassifier, nn.Module):
84 |     def __init__(self, dim):
85 |         super(LeNet, self).__init__()
86 | 
87 |         self.input_dim = dim
88 | 
89 |         self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
90 |         self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
91 |         self.conv3 = nn.Conv2d(16, 120, kernel_size=5)
92 |         self.bn_conv1 = nn.BatchNorm2d(6)
93 |         self.bn_conv2 = nn.BatchNorm2d(16)
94 |         self.mp = nn.MaxPool2d(2)
95 |         self.relu = nn.ReLU()
96 |         self.fc1 = nn.Linear(120, 84)
97 |         self.bn_fc1 = nn.BatchNorm1d(84)
98 | 
99 |         self.layer1 = nn.Sequential(self.conv1, self.mp, self.relu)
100 |         self.layer2 = nn.Sequential(self.conv2, self.mp, self.relu)
101 |         self.layer3 = nn.Sequential(self.conv3, self.relu)
102 | 
103 |         self.layers = nn.ModuleList([self.layer1, self.layer2, self.layer3])
104 | 
105 |         self.layer4 = nn.Sequential(self.fc1, self.bn_fc1, self.relu)
106 |         self.classifier = nn.Linear(84, 1)
107 | 
108 |     def forward(self, x):
109 |         h = x
110 |         for i, layer_module in enumerate(self.layers):
111 |             h = layer_module(h)
112 | 
113 |         h = h.view(h.size(0), -1)
114 |         h = self.layer4(h)
115 |         h = self.classifier(h)
116 |         return h
117 | 
118 | 
119 | class MixLeNet(MyClassifier, nn.Module):
120 |     def __init__(self, dim):
121 |         super(MixLeNet, self).__init__()
122 | 
123 |         self.input_dim = dim
124 | 
125 |         self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
126 |         self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
127 |         self.conv3 = nn.Conv2d(16, 120, kernel_size=5)
128 |         self.bn_conv1 = nn.BatchNorm2d(6)
129 |         self.bn_conv2 = nn.BatchNorm2d(16)
130 |         self.mp = nn.MaxPool2d(2)
131 |         self.relu = nn.ReLU()
132 |         self.fc1 = nn.Linear(120, 84)
133 |         self.bn_fc1 = nn.BatchNorm1d(84)
134 | 
135 |         self.layer1 = nn.Sequential(self.conv1, self.mp, self.relu)
136 |         self.layer2 = nn.Sequential(self.conv2, self.mp, self.relu)
137 |         self.layer3 = nn.Sequential(self.conv3, self.relu)
138 | 
139 |         self.layers = nn.ModuleList([self.layer1, self.layer2, self.layer3])
140 | 
141 |         self.layer4 = nn.Sequential(self.fc1, self.bn_fc1, self.relu)
142 |         self.classifier = nn.Linear(84, 1)
143 | 
144 |     def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):  # mix_layer: depth of manifold mixup (-1 = input level; default 1000 = never)
145 |         h, h2 = x, x2
146 |         if mix_layer == -1:
147 |             if h2 is not None:
148 |                 h = l * h + (1.0 - l) * h2
149 | 
150 |         for i, layer_module in enumerate(self.layers):
151 |             if i <= mix_layer:
152 |                 h = layer_module(h)
153 | 
154 |                 if h2 is not None:
155 |                     h2 = layer_module(h2)
156 | 
157 |                 if i == mix_layer:
158 |                     if h2 is not None:
159 |                         h = l * h + (1.0 - l) * h2
160 | 
161 |             if i > mix_layer:
162 |                 h = layer_module(h)
163 | 
164 |         h_ = h.view(h.size(0), -1)
165 |         h_ = 
self.layer4(h_) 166 | h = self.classifier(h_) 167 | 168 | if flag_feature: 169 | return h, h_ 170 | else: 171 | return h 172 | 173 | 174 | class CNNSTL(MyClassifier, nn.Module): 175 | def __init__(self, dim): 176 | super(CNNSTL, self).__init__() 177 | 178 | self.input_dim = dim 179 | 180 | self.relu = nn.ReLU() 181 | self.conv1 = nn.Conv2d(3, 6, 3) 182 | self.conv2 = nn.Conv2d(6, 6, 3) 183 | self.mp = nn.MaxPool2d(2, 2) 184 | self.conv3 = nn.Conv2d(6, 16, 5) 185 | self.conv4 = nn.Conv2d(16, 32, 5) 186 | self.fc1 = nn.Linear(32 * 8 * 8, 120) 187 | self.fc2 = nn.Linear(120, 84) 188 | 189 | self.layer1 = nn.Sequential(self.conv1, self.relu, self.mp) 190 | self.layer2 = nn.Sequential(self.conv2, self.relu) 191 | self.layer3 = nn.Sequential(self.conv3, self.relu, self.mp) 192 | self.layer4 = nn.Sequential(self.conv4, self.relu, self.mp) 193 | 194 | self.layers = nn.ModuleList( 195 | [self.layer1, self.layer2, self.layer3, self.layer4] 196 | ) 197 | 198 | self.layer5 = nn.Sequential(self.fc1, self.relu, self.fc2, self.relu) 199 | 200 | self.classifier = nn.Linear(84, 1) 201 | 202 | def forward(self, x): 203 | h = x 204 | for i, layer_module in enumerate(self.layers): 205 | h = layer_module(h) 206 | 207 | h = h.view(h.size(0), -1) 208 | h = self.layer5(h) 209 | h = self.classifier(h) 210 | return h 211 | 212 | 213 | class MixCNNSTL(MyClassifier, nn.Module): 214 | def __init__(self, dim): 215 | super(MixCNNSTL, self).__init__() 216 | 217 | self.input_dim = dim 218 | 219 | self.relu = nn.ReLU() 220 | self.conv1 = nn.Conv2d(3, 6, 3) 221 | self.conv2 = nn.Conv2d(6, 6, 3) 222 | self.mp = nn.MaxPool2d(2, 2) 223 | self.conv3 = nn.Conv2d(6, 16, 5) 224 | self.conv4 = nn.Conv2d(16, 32, 5) 225 | self.fc1 = nn.Linear(32 * 8 * 8, 120) 226 | self.fc2 = nn.Linear(120, 84) 227 | 228 | self.layer1 = nn.Sequential(self.conv1, self.relu, self.mp) 229 | self.layer2 = nn.Sequential(self.conv2, self.relu) 230 | self.layer3 = nn.Sequential(self.conv3, self.relu, self.mp) 231 | self.layer4 = nn.Sequential(self.conv4, self.relu, self.mp) 232 | 233 | self.layers = nn.ModuleList( 234 | [self.layer1, self.layer2, self.layer3, self.layer4] 235 | ) 236 | 237 | self.layer5 = nn.Sequential(self.fc1, self.relu, self.fc2, self.relu) 238 | 239 | self.classifier = nn.Linear(84, 1) 240 | 241 | def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False): 242 | h, h2 = x, x2 243 | if mix_layer == -1: 244 | if h2 is not None: 245 | h = l * h + (1.0 - l) * h2 246 | 247 | for i, layer_module in enumerate(self.layers): 248 | if i <= mix_layer: 249 | h = layer_module(h) 250 | 251 | if h2 is not None: 252 | h2 = layer_module(h2) 253 | 254 | if i == mix_layer: 255 | if h2 is not None: 256 | h = l * h + (1.0 - l) * h2 257 | 258 | if i > mix_layer: 259 | h = layer_module(h) 260 | 261 | h_ = h.view(h.size(0), -1) 262 | h_ = self.layer5(h_) 263 | h = self.classifier(h_) 264 | 265 | if flag_feature: 266 | return h, h_ 267 | else: 268 | return h 269 | 270 | 271 | class CNNCIFAR(MyClassifier, nn.Module): 272 | def __init__(self, dim): 273 | super(CNNCIFAR, self).__init__() 274 | 275 | self.af = F.relu 276 | self.input_dim = dim 277 | 278 | self.conv1 = nn.Conv2d(3, 96, 3) 279 | self.conv2 = nn.Conv2d(96, 96, 3, stride=2) 280 | self.conv3 = nn.Conv2d(96, 192, 1) 281 | self.conv4 = nn.Conv2d(192, 10, 1) 282 | self.fc1 = nn.Linear(1960, 1000) 283 | self.fc2 = nn.Linear(1000, 1000) 284 | self.fc3 = nn.Linear(1000, 1) 285 | 286 | def forward(self, x): 287 | h = self.conv1(x) 288 | h = self.af(h) 289 | h = self.conv2(h) 290 | h = self.af(h) 
291 | h = self.conv3(h) 292 | h = self.af(h) 293 | h = self.conv4(h) 294 | h = self.af(h) 295 | 296 | h = h.view(h.size(0), -1) 297 | h = self.fc1(h) 298 | h = self.af(h) 299 | h = self.fc2(h) 300 | h = self.af(h) 301 | h = self.fc3(h) 302 | return h 303 | 304 | 305 | class MixCNNCIFAR(MyClassifier, MetaModule): 306 | def __init__(self, dim): 307 | super(MixCNNCIFAR, self).__init__() 308 | 309 | self.af = nn.ReLU() 310 | self.input_dim = dim 311 | 312 | self.conv_list = [ 313 | MetaConv2d(3, 96, 3), 314 | MetaConv2d(96, 96, 3, stride=2), 315 | MetaConv2d(96, 192, 1), 316 | MetaConv2d(192, 10, 1), 317 | ] 318 | self.fc1 = MetaLinear(1960, 1000) 319 | self.fc2 = MetaLinear(1000, 1000) 320 | self.fc3 = MetaLinear(1000, 1) 321 | 322 | self.layers = nn.ModuleList( 323 | [nn.Sequential(self.conv_list[i], self.af) for i in range(4)] 324 | ) 325 | 326 | self.classifier1 = nn.Sequential( 327 | self.fc1, 328 | self.af, 329 | self.fc2, 330 | self.af, 331 | ) 332 | 333 | def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False): 334 | h, h2 = x, x2 335 | if mix_layer == -1: 336 | if h2 is not None: 337 | h = l * h + (1.0 - l) * h2 338 | 339 | for i, layer_module in enumerate(self.layers): 340 | if i <= mix_layer: 341 | h = layer_module(h) 342 | 343 | if h2 is not None: 344 | h2 = layer_module(h2) 345 | 346 | if i == mix_layer: 347 | if h2 is not None: 348 | h = l * h + (1.0 - l) * h2 349 | 350 | if i > mix_layer: 351 | h = layer_module(h) 352 | 353 | h_ = h.view(h.size(0), -1) 354 | h_ = self.classifier1(h_) 355 | h = self.fc3(h_) 356 | 357 | if flag_feature: 358 | return h, h_ 359 | else: 360 | return h 361 | 362 | 363 | class MixCNNCIFAR_CL_(MyClassifier, MetaModule): 364 | def __init__(self, dim): 365 | super(MixCNNCIFAR_CL_, self).__init__() 366 | 367 | self.af = nn.ReLU() 368 | self.input_dim = dim 369 | 370 | self.conv_list = [ 371 | MetaConv2d(3, 96, 3), 372 | MetaConv2d(96, 96, 3, stride=2), 373 | MetaConv2d(96, 192, 1), 374 | MetaConv2d(192, 10, 1), 375 | ] 376 | self.fc1 = MetaLinear(1960, 1000) 377 | self.fc2 = MetaLinear(1000, 1000) 378 | self.classifier = MetaLinear(1000, 1) 379 | 380 | self.layers = nn.ModuleList( 381 | [nn.Sequential(self.conv_list[i], self.af) for i in range(4)] 382 | ) 383 | 384 | self.mlp = nn.Sequential( 385 | self.fc1, 386 | self.af, 387 | self.fc2, 388 | self.af, 389 | ) 390 | 391 | self.fc4 = MetaLinear(1000, 1000) 392 | self.fc5 = MetaLinear(1000, 128) 393 | self.head = nn.Sequential(self.fc4, nn.ReLU(), self.fc5) 394 | 395 | def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False): 396 | h, h2 = x, x2 397 | if mix_layer == -1: 398 | if h2 is not None: 399 | h = l * h + (1.0 - l) * h2 400 | 401 | for i, layer_module in enumerate(self.layers): 402 | if i <= mix_layer: 403 | h = layer_module(h) 404 | 405 | if h2 is not None: 406 | h2 = layer_module(h2) 407 | 408 | if i == mix_layer: 409 | if h2 is not None: 410 | h = l * h + (1.0 - l) * h2 411 | 412 | if i > mix_layer: 413 | h = layer_module(h) 414 | 415 | h_ = h.view(h.size(0), -1) 416 | h_feat = self.mlp(h_) 417 | h = self.classifier(h_feat) 418 | feat_cl = F.normalize(self.head(h_feat), dim=1) 419 | 420 | if flag_feature: 421 | return h, feat_cl 422 | else: 423 | return h 424 | 425 | 426 | class MixCNNCIFAR_CL(MyClassifier, MetaModule): 427 | def __init__(self, dim): 428 | super(MixCNNCIFAR_CL, self).__init__() 429 | 430 | self.af = nn.ReLU() 431 | self.input_dim = dim 432 | 433 | self.conv_list = [ 434 | MetaConv2d(3, 96, 3), 435 | MetaConv2d(96, 96, 3, stride=2), 436 | 
MetaConv2d(96, 192, 1),
437 |             MetaConv2d(192, 10, 1),
438 |         ]
439 |         self.fc1 = MetaLinear(1960, 1000)  # 10 ch * 14 * 14 after the two 3x3 convs on 32x32 CIFAR inputs; 21160 here was the 96x96 STL-10 size and breaks CIFAR
440 |         self.fc2 = MetaLinear(1000, 1000)
441 |         self.classifier = MetaLinear(1000, 1)
442 | 
443 |         self.layers = nn.ModuleList(
444 |             [nn.Sequential(self.conv_list[i], self.af) for i in range(4)]
445 |         )
446 | 
447 |         self.mlp = nn.Sequential(
448 |             self.fc1,
449 |             self.af,
450 |             self.fc2,
451 |             self.af,
452 |         )
453 | 
454 |         self.fc4 = MetaLinear(1000, 1000)
455 |         self.fc5 = MetaLinear(1000, 128)
456 |         self.head = nn.Sequential(self.fc4, nn.ReLU(), self.fc5)
457 | 
458 |     def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
459 |         h, h2 = x, x2
460 |         if mix_layer == -1:
461 |             if h2 is not None:
462 |                 h = l * h + (1.0 - l) * h2
463 | 
464 |         for i, layer_module in enumerate(self.layers):
465 |             if i <= mix_layer:
466 |                 h = layer_module(h)
467 | 
468 |                 if h2 is not None:
469 |                     h2 = layer_module(h2)
470 | 
471 |                 if i == mix_layer:
472 |                     if h2 is not None:
473 |                         h = l * h + (1.0 - l) * h2
474 | 
475 |             if i > mix_layer:
476 |                 h = layer_module(h)
477 | 
478 |         h_ = h.view(h.size(0), -1)
479 |         h_feat = self.mlp(h_)
480 |         h = self.classifier(h_feat)
481 |         feat_cl = F.normalize(self.head(h_feat), dim=1)  # L2-normalized projection for the contrastive loss
482 | 
483 |         if flag_feature:
484 |             return h, feat_cl
485 |         else:
486 |             return h
487 | 
488 | 
489 | class MixCNNSTL_CL(MyClassifier, MetaModule):
490 |     def __init__(self, dim):
491 |         super(MixCNNSTL_CL, self).__init__()
492 | 
493 |         self.af = nn.ReLU()
494 |         self.input_dim = dim
495 | 
496 |         self.conv_list = [
497 |             MetaConv2d(3, 96, 3),
498 |             MetaConv2d(96, 96, 3, stride=2),
499 |             MetaConv2d(96, 192, 1),
500 |             MetaConv2d(192, 10, 1),
501 |         ]
502 |         self.fc1 = MetaLinear(21160, 1000)  # 10 ch * 46 * 46 flattened map for 96x96 STL-10 inputs
503 |         self.fc2 = MetaLinear(1000, 1000)
504 |         self.classifier = MetaLinear(1000, 1)
505 | 
506 |         self.layers = nn.ModuleList(
507 |             [nn.Sequential(self.conv_list[i], self.af) for i in range(4)]
508 |         )
509 | 
510 |         self.mlp = nn.Sequential(
511 |             self.fc1,
512 |             self.af,
513 |             self.fc2,
514 |             self.af,
515 |         )
516 | 
517 |         self.fc4 = MetaLinear(1000, 1000)
518 |         self.fc5 = MetaLinear(1000, 128)
519 |         self.head = nn.Sequential(self.fc4, nn.ReLU(), self.fc5)
520 | 
521 |     def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
522 |         h, h2 = x, x2
523 |         if mix_layer == -1:
524 |             if h2 is not None:
525 |                 h = l * h + (1.0 - l) * h2
526 | 
527 |         for i, layer_module in enumerate(self.layers):
528 |             if i <= mix_layer:
529 |                 h = checkpoint(layer_module, h)  # gradient checkpointing trades compute for activation memory on the large STL-10 maps
530 | 
531 |                 if h2 is not None:
532 |                     h2 = checkpoint(layer_module, h2)
533 | 
534 |                 if i == mix_layer:
535 |                     if h2 is not None:
536 |                         h = l * h + (1.0 - l) * h2
537 | 
538 |             if i > mix_layer:
539 |                 h = checkpoint(layer_module, h)
540 | 
541 |         h_ = h.view(h.size(0), -1)
542 |         h_feat = self.mlp(h_)
543 |         h = self.classifier(h_feat)
544 |         feat_cl = F.normalize(self.head(h_feat), dim=1)
545 | 
546 |         if flag_feature:
547 |             return h, feat_cl
548 |         else:
549 |             return h
550 | 
551 | 
552 | def weights_init(m):
553 |     if isinstance(m, (nn.Conv2d, nn.Linear)):
554 |         nn.init.kaiming_normal_(m.weight)
555 |         if m.bias is not None:
556 |             nn.init.constant_(m.bias, 0.0)
557 | 
558 | 
559 | class MetaCNN(MetaModule):
560 |     def __init__(self, use_checkpoint=False):
561 |         super(MetaCNN, self).__init__()
562 |         self.conv1 = MetaConv2d(3, 96, kernel_size=3, padding=1)
563 |         self.bn1 = MetaBatchNorm2d(96)
564 |         self.conv2 = MetaConv2d(96, 96, kernel_size=3, padding=1)
565 |         self.bn2 = MetaBatchNorm2d(96)
566 |         self.conv3 = MetaConv2d(96, 96, kernel_size=3, stride=2, padding=1)
567 |         self.bn3 = 
MetaBatchNorm2d(96) 568 | self.conv4 = MetaConv2d(96, 192, kernel_size=3, padding=1) 569 | self.bn4 = MetaBatchNorm2d(192) 570 | self.conv5 = MetaConv2d(192, 192, kernel_size=3, padding=1) 571 | self.bn5 = MetaBatchNorm2d(192) 572 | self.conv6 = MetaConv2d(192, 192, kernel_size=3, stride=2, padding=1) 573 | self.bn6 = MetaBatchNorm2d(192) 574 | self.conv7 = MetaConv2d(192, 192, kernel_size=3, padding=1) 575 | self.bn7 = MetaBatchNorm2d(192) 576 | self.conv8 = MetaConv2d(192, 192, kernel_size=1) 577 | self.bn8 = MetaBatchNorm2d(192) 578 | self.conv9 = MetaConv2d(192, 10, kernel_size=1) 579 | self.bn9 = MetaBatchNorm2d(10) 580 | 581 | self.layer1 = nn.Sequential( 582 | self.conv1, 583 | self.bn1, 584 | nn.ReLU(), 585 | ) 586 | 587 | self.layer2 = nn.Sequential( 588 | self.conv2, 589 | self.bn2, 590 | nn.ReLU(), 591 | ) 592 | 593 | self.layer3 = nn.Sequential( 594 | self.conv3, 595 | self.bn3, 596 | nn.ReLU(), 597 | ) 598 | 599 | self.layer4 = nn.Sequential( 600 | self.conv4, 601 | self.bn4, 602 | nn.ReLU(), 603 | ) 604 | 605 | self.layer5 = nn.Sequential( 606 | self.conv5, 607 | self.bn5, 608 | nn.ReLU(), 609 | ) 610 | 611 | self.layer6 = nn.Sequential( 612 | self.conv6, 613 | self.bn6, 614 | nn.ReLU(), 615 | ) 616 | 617 | self.layer7 = nn.Sequential( 618 | self.conv7, 619 | self.bn7, 620 | nn.ReLU(), 621 | ) 622 | 623 | self.layer8 = nn.Sequential( 624 | self.conv8, 625 | self.bn8, 626 | nn.ReLU(), 627 | ) 628 | 629 | self.layer9 = nn.Sequential( 630 | self.conv9, 631 | self.bn9, 632 | nn.ReLU(), 633 | ) 634 | 635 | self.l1 = MetaLinear(640, 1000) 636 | self.l2 = MetaLinear(1000, 1000) 637 | self.classifier = MetaLinear(1000, 1) 638 | 639 | self.fc4 = MetaLinear(1000, 1000) 640 | self.fc5 = MetaLinear(1000, 128) 641 | self.head = nn.Sequential(self.fc4, nn.ReLU(), self.fc5) 642 | 643 | self.apply(weights_init) 644 | 645 | self.use_checkpoint = use_checkpoint 646 | 647 | def forward(self, x, flag_feature=False): 648 | out = x 649 | out = out + torch.zeros( 650 | 1, dtype=out.dtype, device=out.device, requires_grad=True 651 | ) 652 | 653 | if self.use_checkpoint: 654 | out = checkpoint(self.layer1, out) 655 | out = checkpoint(self.layer2, out) 656 | out = checkpoint(self.layer3, out) 657 | out = checkpoint(self.layer4, out) 658 | out = checkpoint(self.layer5, out) 659 | out = checkpoint(self.layer6, out) 660 | out = checkpoint(self.layer7, out) 661 | out = checkpoint(self.layer8, out) 662 | out = checkpoint(self.layer9, out) 663 | else: 664 | out = self.layer1(out) 665 | out = self.layer2(out) 666 | out = self.layer3(out) 667 | out = self.layer4(out) 668 | out = self.layer5(out) 669 | out = self.layer6(out) 670 | out = self.layer7(out) 671 | out = self.layer8(out) 672 | out = self.layer9(out) 673 | 674 | out = out.view(-1, 640) 675 | out = self.l1(out) 676 | out = F.relu(out) 677 | out = self.l2(out) 678 | h_feat = F.relu(out) 679 | 680 | h = self.classifier(h_feat) 681 | 682 | feat_cl = F.normalize(self.head(h_feat), dim=1) 683 | 684 | if flag_feature: 685 | return h, feat_cl 686 | else: 687 | return h 688 | --------------------------------------------------------------------------------
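A note on the gradient-checkpointed forwards in `MixCNNSTL_CL.forward` and `MetaCNN.forward`: `torch.utils.checkpoint` discards intermediate activations and recomputes them during backward, but the legacy (reentrant) implementation only rebuilds the autograd graph when at least one tensor *input* requires grad; module parameters captured inside the checkpointed function are not enough. That is why `MetaCNN.forward` adds a zero tensor with `requires_grad=True` to its input before checkpointing. The snippet below is an editor's minimal illustration of this trick, not part of the repository:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)  # a CIFAR-sized batch; inputs carry no grad

# Reentrant checkpointing with a no-grad input: the recomputed graph is
# disconnected, so nothing can be backpropagated through this segment
# (PyTorch also warns that none of the inputs require grad).
y = checkpoint(layer, x, use_reentrant=True)
print(y.requires_grad)  # False

# The MetaCNN trick: add a zero "anchor" that requires grad. The values are
# unchanged, but the checkpointed segment now participates in autograd.
x_anchored = x + torch.zeros(1, dtype=x.dtype, device=x.device, requires_grad=True)
y = checkpoint(layer, x_anchored, use_reentrant=True)
y.sum().backward()
print(layer.weight.grad is not None)  # True: parameter gradients flow again
```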