├── demo ├── digit-tsne.jpg ├── first-fig.jpg └── framework.jpg ├── README.md ├── data ├── dataloader.py ├── usps.py └── centroid.py ├── model.py ├── main.py └── utils.py /demo/digit-tsne.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qy-feng/Margin-Openset/HEAD/demo/digit-tsne.jpg -------------------------------------------------------------------------------- /demo/first-fig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qy-feng/Margin-Openset/HEAD/demo/first-fig.jpg -------------------------------------------------------------------------------- /demo/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qy-feng/Margin-Openset/HEAD/demo/framework.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Margin-Openset 2 | This is the implementation of [Attract or Distract: Explore the Margin of Open Set](https://openaccess.thecvf.com/content_ICCV_2019/html/Feng_Attract_or_Distract_Exploit_the_Margin_of_Open_Set_ICCV_2019_paper.html) (ICCV 2019). 3 | 4 | 5 | 6 | *** 7 | ### Requirements 8 | 9 | - Pytorch 0.4 10 | - scikit-learn 11 | 12 | ### Usage 13 | SVHN -> MNIST 14 | ``` 15 | python train.py --task s2m --gpu 0 --epochs 100 16 | ``` 17 | USPS -> MNIST 18 | ``` 19 | python train.py --task u2m --gpu 0 --epochs 100 20 | ``` 21 | MNIST -> USPS 22 | ``` 23 | python train.py --task m2u --gpu 0 --epochs 100 24 | ``` 25 | 26 | *** 27 | ### digit-TSNE 28 | 29 | 30 | *** 31 | ### Bibtex 32 | 33 | Give a ⭐️ if this project helped you, please also consider citing our work: 34 | ``` 35 | @InProceedings{Feng_2019_ICCV, 36 | author = {Feng, Qianyu and Kang, Guoliang and Fan, Hehe and Yang, Yi}, 37 | title = {Attract or Distract: Exploit the Margin of Open Set}, 38 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 39 | month = {October}, 40 | year = {2019} 41 | } 42 | ``` 43 | 44 | 45 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from torchvision.datasets import MNIST 4 | from torchvision.datasets import SVHN 5 | from .usps import * 6 | 7 | 8 | def get_data(args): 9 | if args.task == 's2m': 10 | src_data = SVHN('../data', split='train', download=True, 11 | transform=transforms.Compose([ 12 | transforms.Resize(32), 13 | transforms.ToTensor(), 14 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 15 | ])) 16 | 17 | tgt_data = MNIST('../data', train=True, download=True, 18 | transform=transforms.Compose([ 19 | transforms.Resize(32), 20 | transforms.Lambda(lambda x: x.convert("RGB")), 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 23 | ])) 24 | elif args.task == 'u2m': 25 | src_data = USPS('../data', train=True, download=True, 26 | transform=transforms.Compose([ 27 | transforms.RandomCrop(28, padding=4), 28 | transforms.RandomRotation(10), 29 | transforms.ToTensor(), 30 | transforms.Normalize((0.5,), (0.5,)) 31 | ])) 32 | 33 | tgt_data = MNIST('../data', train=True, download=True, 34 | transform=transforms.Compose([ 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.5,), (0.5,)) 37 | ])) 38 | else: 39 | src_data = MNIST('../data', train=True, download=True, 40 | transform=transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.5,), (0.5,)) 43 | ])) 44 | 45 | tgt_data = USPS('../data', train=True, download=True, 46 | transform=transforms.Compose([ 47 | transforms.ToTensor(), 48 | transforms.Normalize((0.5,), (0.5,)) 49 | ])) 50 | 51 | src_data, tgt_data = relabel_data(src_data, tgt_data, args.task) 52 | 53 | src_loader = torch.utils.data.DataLoader(src_data, 54 | batch_size=args.batch_size, 55 | shuffle=True, num_workers=0) 56 | 57 | tgt_loader = torch.utils.data.DataLoader(tgt_data, 58 | batch_size=args.batch_size, 59 | shuffle=True, num_workers=0) 60 | return src_loader, tgt_loader 61 | 62 | def relabel_data(src_data, tgt_data, task, known_cnum=5): 63 | image_path = [] 64 | image_label = [] 65 | if task == 's2m': 66 | for i in range(len(src_data.data)): 67 | if int(src_data.labels[i]) < known_cnum: 68 | image_path.append(src_data.data[i]) 69 | image_label.append(src_data.labels[i]) 70 | src_data.data = image_path 71 | src_data.labels = image_label 72 | else: 73 | for i in range(len(src_data.train_data)): 74 | if int(src_data.train_labels[i]) < known_cnum: 75 | image_path.append(src_data.train_data[i]) 76 | image_label.append(src_data.train_labels[i]) 77 | src_data.train_data = image_path 78 | src_data.train_labels = image_label 79 | 80 | for i in range(len(tgt_data.train_data)): 81 | if int(tgt_data.train_labels[i]) >= known_cnum: 82 | tgt_data.train_labels[i] = known_cnum 83 | 84 | return src_data, tgt_data -------------------------------------------------------------------------------- /data/usps.py: -------------------------------------------------------------------------------- 1 | """Dataset setting and data loader for USPS. 2 | Modified from 3 | https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py 4 | """ 5 | 6 | import gzip 7 | import os 8 | import pickle 9 | import urllib 10 | from PIL import Image 11 | 12 | import numpy as np 13 | import torch 14 | import torch.utils.data as data 15 | from torch.utils.data.sampler import WeightedRandomSampler 16 | from torchvision import datasets, transforms 17 | 18 | 19 | class USPS(data.Dataset): 20 | """USPS Dataset. 21 | Args: 22 | root (string): Root directory of dataset where dataset file exist. 23 | train (bool, optional): If True, resample from dataset randomly. 24 | download (bool, optional): If true, downloads the dataset 25 | from the internet and puts it in root directory. 26 | If dataset is already downloaded, it is not downloaded again. 27 | transform (callable, optional): A function/transform that takes in 28 | an PIL image and returns a transformed version. 29 | E.g, ``transforms.RandomCrop`` 30 | """ 31 | 32 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 33 | 34 | def __init__(self, root, train=True, transform=None, download=False): 35 | """Init USPS dataset.""" 36 | # init params 37 | self.root = os.path.expanduser(root) 38 | self.filename = "usps_28x28.pkl" 39 | self.train = train 40 | # Num of Train = 7438, Num ot Test 1860 41 | self.transform = transform 42 | self.dataset_size = None 43 | 44 | # download dataset. 45 | if download: 46 | self.download() 47 | if not self._check_exists(): 48 | raise RuntimeError("Dataset not found." + 49 | " You can use download=True to download it") 50 | 51 | self.train_data, self.train_labels = self.load_samples() 52 | if self.train: 53 | total_num_samples = self.train_labels.shape[0] 54 | indices = np.arange(total_num_samples) 55 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 56 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 57 | self.train_data *= 255.0 58 | self.train_data = np.squeeze(self.train_data).astype(np.uint8) 59 | 60 | def __getitem__(self, index): 61 | """Get images and target for data loader. 62 | Args: 63 | index (int): Index 64 | Returns: 65 | tuple: (image, target) where target is index of the target class. 66 | """ 67 | img, label = self.train_data[index], self.train_labels[index] 68 | img = Image.fromarray(img, mode='L') 69 | img = img.copy() 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | return img, label.astype("int64") 73 | 74 | def __len__(self): 75 | """Return size of dataset.""" 76 | return len(self.train_data) 77 | 78 | def _check_exists(self): 79 | """Check if dataset is download and in right place.""" 80 | return os.path.exists(os.path.join(self.root, self.filename)) 81 | 82 | def download(self): 83 | """Download dataset.""" 84 | filename = os.path.join(self.root, self.filename) 85 | dirname = os.path.dirname(filename) 86 | if not os.path.isdir(dirname): 87 | os.makedirs(dirname) 88 | if os.path.isfile(filename): 89 | return 90 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 91 | urllib.request.urlretrieve(self.url, filename) 92 | print("[DONE]") 93 | return 94 | 95 | def load_samples(self): 96 | """Load sample images from dataset.""" 97 | filename = os.path.join(self.root, self.filename) 98 | f = gzip.open(filename, "rb") 99 | data_set = pickle.load(f, encoding="bytes") 100 | f.close() 101 | if self.train: 102 | images = data_set[0][0] 103 | labels = data_set[0][1] 104 | self.dataset_size = labels.shape[0] 105 | else: 106 | images = data_set[1][0] 107 | labels = data_set[1][1] 108 | self.dataset_size = labels.shape[0] 109 | return images, labels 110 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class Conv_Block(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1): 9 | super(Conv_Block, self).__init__() 10 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride) 11 | self.relu = torch.nn.LeakyReLU() 12 | self.bn = nn.BatchNorm2d(out_channels) 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | x = self.relu(x) 17 | x = self.bn(x) 18 | return x 19 | 20 | 21 | class Dense_Block(nn.Module): 22 | def __init__(self, in_features, out_features): 23 | super(Dense_Block, self).__init__() 24 | self.fc = nn.Linear(in_features, out_features) 25 | self.relu = torch.nn.LeakyReLU() 26 | self.bn = nn.BatchNorm1d(out_features) 27 | 28 | def forward(self, x): 29 | x = self.fc(x) 30 | x = self.relu(x) 31 | x = self.bn(x) 32 | return x 33 | 34 | 35 | class GradReverse(torch.autograd.Function): 36 | def __init__(self, lambd): 37 | self.lambd = lambd 38 | 39 | def forward(self, x): 40 | return x.view_as(x) 41 | 42 | def backward(self, grad_output): 43 | return (grad_output * -self.lambd) 44 | 45 | def grad_reverse(x, p=1): 46 | lambd = 2. / (1. + np.exp(-10 * p)) - 1 47 | return GradReverse(lambd)(x) 48 | 49 | 50 | class Generator_s2m(nn.Module): 51 | def __init__(self): 52 | super(Generator_s2m, self).__init__() 53 | self.conv1 = Conv_Block(3, 64, kernel_size=5) 54 | self.conv2 = Conv_Block(64, 64, kernel_size=5) 55 | self.conv3 = Conv_Block(64, 128, kernel_size=3, stride=2) 56 | self.conv4 = Conv_Block(128, 128, kernel_size=3, stride=2) 57 | self.fc1 = Dense_Block(3200, 100) 58 | self.fc2 = Dense_Block(100, 100) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.conv2(x) 63 | x = self.conv3(x) 64 | x = self.conv4(x) 65 | x = x.view(x.size(0), -1) 66 | x = self.fc1(x) 67 | x = self.fc2(x) 68 | return x 69 | 70 | 71 | class Classifier_s2m(nn.Module): 72 | def __init__(self, n_output): 73 | super(Classifier_s2m, self).__init__() 74 | self.fc = nn.Linear(100, n_output) 75 | 76 | def forward(self, x): 77 | x = self.fc(x) 78 | return x 79 | 80 | 81 | class Generator_u2m(nn.Module): 82 | def __init__(self): 83 | super(Generator_u2m, self).__init__() 84 | self.conv1 = Conv_Block(1, 20, kernel_size=5) 85 | self.pool1 = nn.MaxPool2d(2, stride=2) 86 | self.conv2 = Conv_Block(20, 50, kernel_size=5) 87 | self.pool2 = nn.MaxPool2d(2, stride=2) 88 | self.drop = nn.Dropout() 89 | self.fc = Dense_Block(800, 500) 90 | 91 | def forward(self, x): 92 | x = self.conv1(x) 93 | x = self.pool1(x) 94 | x = self.conv2(x) 95 | x = self.pool2(x) 96 | x = x.view(x.size(0), -1) 97 | x = self.drop(x) 98 | x = self.fc(x) 99 | return x 100 | 101 | 102 | class Classifier_u2m(nn.Module): 103 | def __init__(self, n_output): 104 | super(Classifier_u2m, self).__init__() 105 | self.fc = nn.Linear(500, n_output) 106 | 107 | def forward(self, x): 108 | x = self.fc(x) 109 | return x 110 | 111 | 112 | class Net(nn.Module): 113 | def __init__(self, task='s2m'): 114 | super(Net, self).__init__() 115 | if task == 's2m': 116 | self.generator = Generator_s2m() 117 | self.classifier = Classifier_s2m(6) 118 | elif task =='u2m' or task == 'm2u': 119 | self.generator = Generator_u2m() 120 | self.classifier = Classifier_u2m(6) 121 | 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 125 | elif isinstance(m, nn.BatchNorm2d): 126 | nn.init.constant_(m.weight, 1) 127 | nn.init.constant_(m.bias, 0) 128 | elif isinstance(m, nn.BatchNorm1d): 129 | nn.init.constant_(m.weight, 1) 130 | nn.init.constant_(m.bias, 0) 131 | 132 | def forward(self, x, p=None, adv=False): 133 | x = self.generator(x) 134 | if adv == True: 135 | x = grad_reverse(x, p) 136 | y = self.classifier(x) 137 | return x, y -------------------------------------------------------------------------------- /data/centroid.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import numpy as np 3 | import torch as t 4 | from utils import * 5 | 6 | 7 | class Centroids(object): 8 | def __init__(self, class_num, dim, use_cuda): 9 | self.class_num = class_num 10 | self.src_ctrs = t.ones((class_num, dim)) 11 | self.tgt_ctrs = t.ones((class_num, dim)) 12 | self.src_ctrs *= 1e-10 13 | self.tgt_ctrs *= 1e-10 14 | if use_cuda: 15 | self.src_ctrs = self.src_ctrs.cuda() 16 | self.tgt_ctrs = self.tgt_ctrs.cuda() 17 | 18 | 19 | def get_centroids(self, domain=None, cid=None): 20 | if domain == 'source': 21 | return self.src_ctrs if cid is None else self.src_ctrs[cid, :] 22 | elif domain == 'target': 23 | return self.tgt_ctrs if cid is None else self.tgt_ctrs[cid, :] 24 | else: 25 | return self.src_ctrs, self.tgt_ctrs 26 | 27 | @torch.no_grad() 28 | def update(self, feat_s, pred_s, label_s, feat_t, pred_t): 29 | self.upd_src_centroids(feat_s, pred_s, label_s) 30 | self.upd_tgt_centroids(feat_t, pred_t) 31 | 32 | @torch.no_grad() 33 | def upd_src_centroids(self, feats, probs, labels): 34 | # feats = to_np(feats) 35 | labels = to_np(labels) 36 | # last_centroids = to_np(self.src_ctrs) 37 | probs = to_np(F.softmax(probs, dim=1)) 38 | 39 | for i in range(self.class_num - 1): 40 | if np.sum(labels == i) > 1: 41 | last_centroid = self.src_ctrs[i, :] 42 | data_idx = np.argwhere(labels == i) 43 | new_centroid = t.mean(feats[data_idx, :], 0).squeeze() 44 | cs = cal_sim(new_centroid, last_centroid) 45 | # print(cs) 46 | new_centroid = cs * new_centroid + (1 - cs) * last_centroid 47 | self.src_ctrs[i, :] = new_centroid 48 | 49 | @torch.no_grad() 50 | def upd_tgt_centroids(self, feats, probs): 51 | # feats = to_np(feats) 52 | # last_centroids = to_np(self.tgt_ctrs) 53 | # src_centroids = to_np(self.src_ctrs) 54 | _, pseudo_label = probs.max(1, keepdim=True) 55 | pseudo_label = to_np(pseudo_label) 56 | probs = to_np(F.softmax(probs, dim=1)) 57 | 58 | for i in range(self.class_num): 59 | if np.sum(pseudo_label == i) > 1: 60 | data_idx = np.argwhere(pseudo_label == i) 61 | new_centroid = t.mean(feats[data_idx, :], 0).squeeze() 62 | last_centroid = self.tgt_ctrs[i, :] 63 | # if last_centroids[i] != np.zeros_like((1, feats.shape[0])): 64 | cs = cal_sim(new_centroid, self.src_ctrs[i, :]) 65 | # print(cs) 66 | new_centroid = cs * new_centroid + (1 - cs) * last_centroid 67 | self.tgt_ctrs[i, :] = new_centroid 68 | 69 | 70 | def crit_intra(feats, y, centers, lambd=1e-3): 71 | class_num = len(centers) 72 | batch_size = y.shape[0] 73 | 74 | expanded_centers = centers.expand(batch_size, -1, -1) 75 | expanded_feats = feats.expand(class_num, -1, -1).transpose(1, 0) 76 | # distance_centers = (expanded_feats - expanded_centers).pow(2).sum(dim=-1) 77 | distance_centers = cal_sim(expanded_feats, expanded_centers) 78 | distance_centers = distance_centers.reshape(batch_size, class_num) 79 | 80 | intra_distances = distance_centers.gather(1, y.unsqueeze(1)) 81 | # intra_distances = distances_same.sum() 82 | inter_distances = distance_centers.sum(dim=-1) - intra_distances 83 | 84 | epsilon = 1e-6 85 | loss = (1 / 2.0 / batch_size / class_num) * intra_distances / \ 86 | (inter_distances + epsilon) 87 | loss = loss.sum() 88 | loss *= lambd 89 | return loss 90 | 91 | 92 | def crit_inter(center1, center2, lambd=1e-3): 93 | # dists = F.pairwise_distance(center1, center2) 94 | # loss = t.mean(dists) 95 | 96 | # dists = cal_cossim(center1.cpu().numpy(), center2.cpu().numpy()) 97 | dists = cal_sim(center1, center2) 98 | loss = 0 99 | for i in range(center1.shape[0]): 100 | loss += dists[i]#[i] 101 | loss /= center1.shape[0] 102 | loss *= lambd 103 | return loss, dists 104 | 105 | 106 | def crit_contrast(feats, probs, s_ctds, t_ctds, lambd=1e-3): 107 | batch_num = feats.shape[0] 108 | class_num = s_ctds.shape[0] 109 | probs = F.softmax(probs, dim=-1) 110 | max_probs, preds = probs.max(1, keepdim=True) 111 | # print(probs.shape, max_probs.shape) 112 | select_index = t.nonzero(max_probs.squeeze() >= 0.3).squeeze(1) 113 | select_index = select_index.cpu().tolist() 114 | 115 | # todo: calculate margins 116 | # dist_ctds = cal_cossim(to_np(s_ctds), to_np(t_ctds)) 117 | dist_ctds = cal_sim(s_ctds, t_ctds) 118 | # print('dist_ctds', dist_ctds.shape) 119 | 120 | M = np.ones(class_num) 121 | for i in range(class_num): 122 | # M[i] = np.sum(dist_ctds[i, :]) - dist_ctds[i, i] 123 | M[i] = dist_ctds.mean() - dist_ctds[i] 124 | M[i] /= class_num - 1 125 | # print('M', M) 126 | 127 | # todo: calculate D_k between known samples to its source centroid & 128 | # todo: calculate D_u distances between unknown samples to all source centroids 129 | D_k, n_k = 0, 1e-5 130 | D_u, n_u = 0, 1e-5 131 | for i in select_index: 132 | class_id = preds[i][0] 133 | if class_id < class_num: 134 | # D_k += F.pairwise_distance(feats[i, :], s_ctds[class_id]).squeeze() 135 | # print(feats.shape, i) 136 | D_k += cal_sim(feats[i, :], s_ctds[class_id, :]) 137 | # print('D_k', D_k) 138 | n_k += 1 139 | else: 140 | # todo: judge if unknown sample in the radius region of known centroid 141 | rp_feats = feats[i, :].unsqueeze(0).repeat(class_num, 1) 142 | 143 | # dist_known = F.pairwise_distance(rp_feats, s_ctds) 144 | dist_known = cal_sim(rp_feats, s_ctds) 145 | # print('dist_known', len(dist_known), dist_known) 146 | 147 | M_mean = M.mean() 148 | outliers = dist_known < M_mean 149 | dist_margin = (dist_known - M_mean) * outliers.float() 150 | D_u += dist_margin.sum() 151 | 152 | loss = D_k / n_k # - D_u / n_u 153 | return loss.mean() * lambd 154 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optimizer 9 | 10 | from data.dataloader import * 11 | from data.centroid import * 12 | from model import Net 13 | import utils 14 | 15 | 16 | def main(args): 17 | if torch.cuda.is_available(): 18 | device = torch.device("cuda:" + str(args.gpu)) 19 | is_cuda = True 20 | else: 21 | device = torch.device("cpu") 22 | is_cuda = False 23 | 24 | src_loader, tgt_loader = get_data(args) 25 | 26 | model = Net(task=args.task).to(device) 27 | 28 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 29 | momentum=args.momentum, 30 | weight_decay=args.weight_decay, 31 | nesterov=True) 32 | 33 | if args.resume: 34 | print("=> loading checkpoint '{}'".format(args.resume)) 35 | checkpoint = torch.load(args.resume) 36 | args.start_epoch = checkpoint['epoch'] 37 | best_acc = checkpoint['best_acc'] 38 | model.load_state_dict(checkpoint['state_dict']) 39 | 40 | print("=> loaded checkpoint '{}' (epoch {})" 41 | .format(args.resume, checkpoint['epoch'])) 42 | 43 | best_acc = 0 44 | best_label = [] 45 | best_result = [] 46 | 47 | # create centroids for known classes 48 | all_centroids = Centroids(args.class_num - 1, 100, use_cuda=is_cuda) 49 | 50 | try: 51 | # start training 52 | for epoch in range(args.epochs): 53 | data = (src_loader, tgt_loader, all_centroids) 54 | 55 | all_centroids = train(model, optimizer, data, epoch, device, args) 56 | 57 | result, gt_label, acc = test(model, tgt_loader, epoch, device, args) 58 | 59 | is_best = acc > best_acc 60 | if is_best: 61 | best_acc = acc 62 | best_label = gt_label 63 | best_pred = result 64 | 65 | utils.save_checkpoint({ 66 | 'epoch': epoch, 67 | 'state_dict': model.state_dict(), 68 | 'best_acc': best_acc 69 | }, is_best, args.check_dir) 70 | 71 | print ("------Best-------") 72 | utils.cal_acc(best_label, best_result, args.class_num) 73 | 74 | except KeyboardInterrupt: 75 | print ("------Best-------") 76 | utils.cal_acc(best_label, best_result, args.class_num) 77 | 78 | 79 | def train(model, optimizer, data, epoch, device, args): 80 | 81 | src_loader, tgt_loader, all_centroids = data 82 | pre_stage = 5 83 | adv_stage = 15 84 | criterion_bce = nn.BCELoss() 85 | criterion_cel = nn.CrossEntropyLoss() 86 | 87 | model.train() 88 | 89 | for batch_idx, (batch_s, batch_t) in enumerate(zip(src_loader, tgt_loader)): 90 | global_step = epoch * len(src_loader) + batch_idx 91 | p = global_step / args.epochs * len(src_loader) 92 | lr = utils.adjust_learning_rate(optimizer, epoch, args, 93 | batch_idx, len(src_loader)) 94 | data_s, label_s = batch_s 95 | data_s = data_s.to(device) 96 | label_s = label_s.to(device) 97 | data_t, label_t = batch_t 98 | data_t = data_t.to(device) 99 | adv_label_t = torch.tensor([args.th]*len(data_t)).to(device) 100 | 101 | loss = 0 102 | optimizer.zero_grad() 103 | feat_s, pred_s = model(data_s) 104 | feat_t, pred_t = model(data_t, p, adv=True) 105 | 106 | # classification loss for known classes in source domain 107 | loss_cel = criterion_cel(pred_s, label_s) 108 | loss += loss_cel 109 | 110 | if epoch >= pre_stage: 111 | # adversarial loss for unknown class in target domain 112 | pred_t_prob_unk = F.softmax(pred_t, dim=1)[:, -1] 113 | loss_adv = criterion_bce(pred_t_prob_unk, adv_label_t) 114 | loss += loss_adv 115 | 116 | if epoch >= adv_stage: 117 | all_centroids.update(feat_s, pred_s, label_s, feat_t, pred_t) 118 | s_ctds, t_ctds = all_centroids.get_centroids() 119 | 120 | loss_intra = crit_intra(feat_s, label_s, s_ctds) 121 | loss += loss_intra * args.lamb_s 122 | 123 | loss_inter, _ = crit_inter(s_ctds, t_ctds) 124 | loss += loss_inter * args.lamb_c 125 | 126 | loss_contr = crit_contrast(feat_t, pred_t, s_ctds, t_ctds) 127 | loss += loss_contr * args.lamb_t 128 | 129 | loss.backward() 130 | optimizer.step() 131 | 132 | if epoch >= pre_stage and batch_idx % args.log_interval == 0: 133 | print('Epoch: {} [{}/{} ({:.0f}%)] LR: {:.6f} \ 134 | Loss(cel): {:.4f} Loss(adv): {:.4f}\t'.format( 135 | epoch, batch_idx * args.batch_size, 136 | len(src_loader.dataset), 137 | 100. * batch_idx / len(src_loader), lr, 138 | loss_cel.item(), loss_adv.item())) 139 | 140 | return all_centroids 141 | 142 | 143 | def test(model, tgt_loader, epoch, device, args): 144 | 145 | loss = 0 146 | correct = 0 147 | result = [] 148 | gt_label = [] 149 | 150 | model.eval() 151 | criterion_cel = nn.CrossEntropyLoss() 152 | 153 | for batch_idx, (data_t, label) in enumerate(tgt_loader): 154 | data_t = data_t.to(device) 155 | label = label.to(device) 156 | 157 | feat, output = model(data_t) 158 | pred = output.max(1, keepdim=True)[1] 159 | loss += criterion_cel(output, label).item() 160 | 161 | for i in range(len(pred)): 162 | result.append(pred[i].item()) 163 | gt_label.append(label[i].item()) 164 | 165 | correct += pred.eq(label.view_as(pred)).sum().item() 166 | 167 | loss /= len(tgt_loader.dataset) 168 | 169 | utils.cal_acc(gt_label, result, args.class_num) 170 | acc = 100. * correct / len(tgt_loader.dataset) 171 | 172 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 173 | loss, correct, len(tgt_loader.dataset), 174 | 100. * correct / len(tgt_loader.dataset))) 175 | 176 | return result, gt_label, acc 177 | 178 | 179 | if __name__ == "__main__": 180 | 181 | parser = argparse.ArgumentParser(description='Openset-DA SVHN -> MNIST Example') 182 | parser.add_argument('--task', choices=['s2m', 'u2m', 'm2u'], default='s2m', 183 | help='domain adaptation sub-task') 184 | parser.add_argument('--class-num', type=int, default=6, help='number of classes') 185 | parser.add_argument('--th', type=float, default=0.5, metavar='TH', 186 | help='threshold for unknown class') 187 | parser.add_argument('--lamb-s', type=float, default=0.02) 188 | parser.add_argument('--lamb-c', type=float, default=0.005) 189 | parser.add_argument('--lamb-t', type=float, default=0.0001) 190 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 191 | help='input batch size for training (default: 128)') 192 | parser.add_argument('--epochs', type=int, default=100, metavar='E', 193 | help='number of epochs to train') 194 | parser.add_argument('--lr', type=float, default=0.0005, metavar='LR', 195 | help='learning rate') 196 | parser.add_argument('--lr-rampdown-epochs', default=101, type=int, 197 | help='length of learning rate cosine rampdown (>= length of training)') 198 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M') 199 | # parser.add_argument('--grl-rampup-epochs', default=20, type=int, metavar='EPOCHS', 200 | # help='length of grl rampup') 201 | parser.add_argument('--weight-decay', '--wd', default=1e-3, type=float, 202 | help='weight decay (default: 1e-3)') 203 | 204 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 205 | help='how many batches to wait before logging training status') 206 | parser.add_argument('--check_dir', default='checkpoint', type=str, 207 | help='directory to save checkpoint') 208 | parser.add_argument('--resume', default='', type=str, 209 | help='path to resume checkpoint (default: none)') 210 | parser.add_argument('--gpu', default='0', type=str, metavar='GPU', 211 | help='id(s) for CUDA_VISIBLE_DEVICES') 212 | 213 | args = parser.parse_args() 214 | 215 | torch.backends.cudnn.benchmark = True 216 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 217 | 218 | main(args) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from sklearn.metrics import accuracy_score 8 | from sklearn.metrics.pairwise import cosine_similarity 9 | 10 | 11 | def save_checkpoint(state, is_best, check_dir): 12 | filename = 'latest.pth.tar' 13 | torch.save(state, os.path.join(check_dir, filename)) 14 | if is_best: 15 | shutil.copyfile(os.path.join(check_dir, filename), 16 | os.path.join(check_dir, 'best.pth.tar')) 17 | 18 | 19 | def cal_acc(gt_label, pred_result, num): 20 | acc_sum = 0 21 | for n in range(num): 22 | y = [] 23 | pred_y = [] 24 | for i in range(len(gt_label)): 25 | gt = gt_label[i] 26 | pred = pred_result[i] 27 | if gt == n: 28 | y.append(gt) 29 | pred_y.append(pred) 30 | print ('{}: {:4f}'.format(n if n != (num - 1) else 'Unk', accuracy_score(y, pred_y))) 31 | if n == (num - 1): 32 | print ('Known Avg Acc: {:4f}'.format(acc_sum / (num - 1))) 33 | acc_sum += accuracy_score(y, pred_y) 34 | print ('Avg Acc: {:4f}'.format(acc_sum / num)) 35 | print ('Overall Acc : {:4f}'.format(accuracy_score(gt_label, pred_result))) 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | assert 0 <= current <= rampdown_length 40 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 41 | 42 | 43 | def to_np(x): 44 | return x.squeeze().cpu().detach().numpy() 45 | 46 | 47 | def get_src_centroids(data_loader, model, args): 48 | feats, labels, probs, preds = get_features(data_loader, model) 49 | centroids = [] 50 | for i in range(args.class_num - 1): 51 | data_idx = np.unique(np.argwhere(labels == i)) 52 | feats_i = feats[data_idx].squeeze() 53 | 54 | center_i = np.mean(feats_i, axis=0) 55 | centroids.append(center_i) 56 | 57 | centroids = np.array(centroids).squeeze() 58 | return torch.from_numpy(centroids).cuda() 59 | 60 | 61 | def get_tgt_centroids(data_loader, model, th, src_centroids, args): 62 | feats, labels, probs, preds = get_features(data_loader, model) 63 | src_centroids = to_np(src_centroids) 64 | tgt_dissim = cal_sim(src_centroids, feats, rev=True) 65 | centroids = [] 66 | for i in range(args.CLASS_NUM - 1): 67 | class_idx = np.unique(np.argwhere(preds == i)) 68 | easy_idx = np.unique(np.argwhere(tgt_dissim[i, :] <= th)) 69 | data_idx = np.intersect1d(class_idx, easy_idx) 70 | if len(data_idx) > 1: 71 | feats_i = feats[data_idx].squeeze() 72 | else: 73 | feats_i = np.zeros_like(feats) 74 | print(i, 'none') 75 | center_i = np.mean(feats_i, axis=0) 76 | centroids.append(center_i) 77 | 78 | centroids = np.array(centroids).squeeze() 79 | return torch.from_numpy(centroids).cuda() 80 | 81 | 82 | def upd_src_centroids(feats, labels, probs, last_centroids, args): 83 | new_centroids = [] 84 | feats = to_np(feats) 85 | labels = to_np(labels) 86 | last_centroids = to_np(last_centroids) 87 | probs = F.softmax(probs, dim=1) 88 | probs = to_np(probs) 89 | for i in range(args.class_num - 1): 90 | if np.sum(labels == i) > 0: 91 | data_idx = np.intersect1d(np.argwhere(labels == i), np.argwhere(probs[:, i] > 0.1)) 92 | new_centroid = np.mean(feats[data_idx], axis=0).reshape(1,-1) 93 | cs = cosine_similarity(new_centroid, last_centroids[i].reshape(1,-1))[0][0] 94 | new_centroid = cs * new_centroid + (1 - cs) * last_centroids[i] 95 | else: 96 | new_centroid = last_centroids[i] 97 | 98 | new_centroids.append(new_centroid.squeeze()) 99 | 100 | new_centroids = np.array(new_centroids) 101 | return torch.from_numpy(new_centroids).cuda() 102 | 103 | 104 | def upd_tgt_centroids(feats, probs, last_centroids, src_centroids, args): 105 | new_centroids = [] 106 | feats = to_np(feats) 107 | last_centroids = to_np(last_centroids) 108 | src_centroids = to_np(src_centroids) 109 | _, ps_labels = probs.max(1, keepdim=True) 110 | ps_labels = to_np(ps_labels) 111 | probs = F.softmax(probs, dim=1) 112 | probs = to_np(probs) 113 | for i in range(args.CLASS_NUM - 1): 114 | if np.sum(ps_labels == i) > 0: 115 | data_idx = np.intersect1d(np.argwhere(ps_labels == i), np.argwhere(probs[:, i] > 0.1)) 116 | new_centroid = np.mean(feats[data_idx], axis=0).reshape(1,-1) 117 | 118 | if last_centroids[i] != np.zeros_like((1, feats.shape[0])): 119 | cs = cosine_similarity(new_centroid, src_centroids[i].reshape(1,-1))[0][0] 120 | new_centroid = cs * new_centroid + (1 - cs) * last_centroids[i] 121 | else: 122 | new_centroid = last_centroids[i] 123 | 124 | new_centroids.append(new_centroid.squeeze()) 125 | 126 | new_centroids = np.array(new_centroids) 127 | return torch.from_numpy(new_centroids).cuda() 128 | 129 | 130 | def get_features(data_loader, model): 131 | model.eval() 132 | feats, labels = [], [] 133 | probs, preds = [], [] 134 | for batch_idx, batch_data in enumerate(data_loader): 135 | input, label = batch_data 136 | input, label = input.cuda(), label.cuda(non_blocking=True) 137 | 138 | feat, prob = model(input) 139 | prob, pred = prob.max(1, keepdim=True) 140 | 141 | feats.append(feat.cpu().detach().numpy()) 142 | labels.append(label.cpu().detach().numpy()) 143 | probs.append(prob.cpu().detach().numpy()) 144 | preds.append(pred.cpu().detach().numpy()) 145 | 146 | feats = np.concatenate(feats, axis=0) 147 | labels = np.concatenate(labels, axis=0) 148 | probs = np.concatenate(probs, axis=0) 149 | preds = np.concatenate(preds, axis=0) 150 | return feats, labels, probs, preds 151 | 152 | 153 | def cosine_rampdown(current, rampdown_length): 154 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 155 | assert 0 <= current <= rampdown_length 156 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 157 | 158 | 159 | def adjust_learning_rate(optimizer, epoch, args, 160 | step_in_epoch, total_steps_in_epoch): 161 | epoch = epoch + step_in_epoch / total_steps_in_epoch 162 | 163 | lr = args.lr * cosine_rampdown(epoch, args.lr_rampdown_epochs) 164 | 165 | for param_group in optimizer.param_groups: 166 | param_group['lr'] = lr 167 | 168 | return lr 169 | 170 | 171 | def cal_sim(x1, x2, metric='cosine'): 172 | # x = x1.clone() 173 | if len(x1.shape) != 2: 174 | x1 = x1.reshape(-1, x1.shape[-1]) 175 | if len(x2.shape) != 2: 176 | x2 = x2.reshape(-1, x2.shape[-1]) 177 | 178 | if metric == 'cosine': 179 | sim = (F.cosine_similarity(x1, x2) + 1) / 2 180 | else: 181 | sim = F.pairwise_distance(x1, x2) / torch.norm(x2, dim=1) 182 | return sim 183 | 184 | 185 | def result_log(best_epoch, acc_score, OS_score, all_score, args): 186 | with open(os.path.join(args.checkpoint, args.log_path), 'a') as f: 187 | f.write('Task %s\n' % args.task) 188 | f.write('init_lr %.5f, wd %.5f batch %d\n' % (args.lr, args.weight_decay, args.batch_size)) 189 | f.write('w_s %.5f | w_c %.5f | w_t %.5f\n' % (args.w_s, args.w_c, args.w_t)) 190 | f.write('Best(%d) OS* %.3f OS %.3f ALL %.3f unk %.3f\n' % (best_epoch, acc_score[0], acc_score[1], 191 | acc_score[2], acc_score[3])) 192 | f.write('(OS) OS* %.3f OS %.3f ALL %.3f unk %.3f\n' % (OS_score[0], OS_score[1], OS_score[2], OS_score[3])) 193 | f.write( 194 | '(all) OS* %.3f OS %.3f ALL %.3f unk %.3f\n' % (all_score[0], all_score[1], all_score[2], all_score[3])) 195 | 196 | 197 | # def cal_acc(gt_list, predict_list, num): 198 | # acc_sum = 0 199 | # acc_list = {} 200 | # for n in range(num): 201 | # y = [] 202 | # pred_y = [] 203 | # for i in range(len(gt_list)): 204 | # gt = gt_list[i] 205 | # predict = predict_list[i] 206 | # if gt == n: 207 | # y.append(gt) 208 | # pred_y.append(predict) 209 | # acc = accuracy_score(y, pred_y) 210 | # print('{}: {:4f}'.format(n if n != (num - 1) else 'Unk', acc)) 211 | # acc_list[n] = acc 212 | # if n == (num - 1): 213 | # OS_ = acc_sum * 1.0 / (num - 1) 214 | # print('Known Avg Acc: {:4f}'.format(OS_)) 215 | # unk = accuracy_score(y, pred_y) 216 | # acc_sum += accuracy_score(y, pred_y) 217 | # OS = acc_sum * 1.0 / num 218 | # all = accuracy_score(gt_list, predict_list) 219 | # print('Avg Acc: {:4f}'.format(OS)) 220 | # print('Overall Acc : {:4f}\n'.format(all)) 221 | # return OS_, OS, all, unk, acc_list --------------------------------------------------------------------------------