├── demo
│   ├── digit-tsne.jpg
│   ├── first-fig.jpg
│   └── framework.jpg
├── README.md
├── data
│   ├── dataloader.py
│   ├── usps.py
│   └── centroid.py
├── model.py
├── main.py
└── utils.py
/demo/digit-tsne.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qy-feng/Margin-Openset/HEAD/demo/digit-tsne.jpg
--------------------------------------------------------------------------------
/demo/first-fig.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qy-feng/Margin-Openset/HEAD/demo/first-fig.jpg
--------------------------------------------------------------------------------
/demo/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qy-feng/Margin-Openset/HEAD/demo/framework.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Margin-Openset
2 | This is the implementation of [Attract or Distract: Explore the Margin of Open Set](https://openaccess.thecvf.com/content_ICCV_2019/html/Feng_Attract_or_Distract_Exploit_the_Margin_of_Open_Set_ICCV_2019_paper.html) (ICCV 2019).
3 |
4 |
5 |
6 | ***
7 | ### Requirements
8 |
9 | - Pytorch 0.4
10 | - scikit-learn
11 |
12 | ### Usage
13 | SVHN -> MNIST
14 | ```
  15 | python main.py --task s2m --gpu 0 --epochs 100
16 | ```
17 | USPS -> MNIST
18 | ```
  19 | python main.py --task u2m --gpu 0 --epochs 100
20 | ```
21 | MNIST -> USPS
22 | ```
  23 | python main.py --task m2u --gpu 0 --epochs 100
24 | ```
25 |
26 | ***
27 | ### digit-TSNE
28 |
29 |
30 | ***
31 | ### Bibtex
32 |
  33 | If this project helped you, give it a ⭐️ and consider citing our work:
34 | ```
35 | @InProceedings{Feng_2019_ICCV,
36 | author = {Feng, Qianyu and Kang, Guoliang and Fan, Hehe and Yang, Yi},
37 | title = {Attract or Distract: Exploit the Margin of Open Set},
38 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
39 | month = {October},
40 | year = {2019}
41 | }
42 | ```
43 |
44 |
45 |
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from torchvision.datasets import MNIST
4 | from torchvision.datasets import SVHN
5 | from .usps import *
6 |
7 |
def get_data(args):
    """Build the source/target DataLoaders for the chosen adaptation task.

    Args:
        args: parsed CLI namespace; reads ``task`` ('s2m' | 'u2m' | 'm2u')
            and ``batch_size``.

    Returns:
        (src_loader, tgt_loader): shuffled DataLoaders over the relabeled
        source (known classes only) and target (unknowns collapsed) sets.
    """
    norm_rgb = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    norm_gray = transforms.Normalize((0.5,), (0.5,))

    if args.task == 's2m':
        # SVHN (RGB, resized to 32) -> MNIST rendered as RGB at 32x32.
        src_tf = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            norm_rgb,
        ])
        tgt_tf = transforms.Compose([
            transforms.Resize(32),
            transforms.Lambda(lambda x: x.convert("RGB")),
            transforms.ToTensor(),
            norm_rgb,
        ])
        src_data = SVHN('../data', split='train', download=True, transform=src_tf)
        tgt_data = MNIST('../data', train=True, download=True, transform=tgt_tf)
    elif args.task == 'u2m':
        # USPS with light augmentation -> plain MNIST.
        src_tf = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            norm_gray,
        ])
        tgt_tf = transforms.Compose([transforms.ToTensor(), norm_gray])
        src_data = USPS('../data', train=True, download=True, transform=src_tf)
        tgt_data = MNIST('../data', train=True, download=True, transform=tgt_tf)
    else:
        # m2u: plain MNIST -> plain USPS.
        plain_tf = transforms.Compose([transforms.ToTensor(), norm_gray])
        src_data = MNIST('../data', train=True, download=True, transform=plain_tf)
        tgt_data = USPS('../data', train=True, download=True, transform=plain_tf)

    src_data, tgt_data = relabel_data(src_data, tgt_data, args.task)

    src_loader = torch.utils.data.DataLoader(src_data,
                                             batch_size=args.batch_size,
                                             shuffle=True, num_workers=0)
    tgt_loader = torch.utils.data.DataLoader(tgt_data,
                                             batch_size=args.batch_size,
                                             shuffle=True, num_workers=0)
    return src_loader, tgt_loader
61 |
def relabel_data(src_data, tgt_data, task, known_cnum=5):
    """Filter source data to known classes and collapse target unknowns.

    Source samples with label >= ``known_cnum`` are dropped; target samples
    with label >= ``known_cnum`` are relabeled to the single 'unknown' id
    ``known_cnum``. Both datasets are modified in place and returned.
    """
    kept_data, kept_labels = [], []
    if task == 's2m':
        # SVHN exposes samples as .data / .labels.
        for sample, label in zip(src_data.data, src_data.labels):
            if int(label) < known_cnum:
                kept_data.append(sample)
                kept_labels.append(label)
        src_data.data = kept_data
        src_data.labels = kept_labels
    else:
        # MNIST / USPS expose samples as .train_data / .train_labels.
        for sample, label in zip(src_data.train_data, src_data.train_labels):
            if int(label) < known_cnum:
                kept_data.append(sample)
                kept_labels.append(label)
        src_data.train_data = kept_data
        src_data.train_labels = kept_labels

    # Target keeps every sample; unknown classes share one label.
    for idx in range(len(tgt_data.train_data)):
        if int(tgt_data.train_labels[idx]) >= known_cnum:
            tgt_data.train_labels[idx] = known_cnum

    return src_data, tgt_data
--------------------------------------------------------------------------------
/data/usps.py:
--------------------------------------------------------------------------------
1 | """Dataset setting and data loader for USPS.
2 | Modified from
3 | https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
4 | """
5 |
6 | import gzip
7 | import os
8 | import pickle
9 | import urllib
10 | from PIL import Image
11 |
12 | import numpy as np
13 | import torch
14 | import torch.utils.data as data
15 | from torch.utils.data.sampler import WeightedRandomSampler
16 | from torchvision import datasets, transforms
17 |
18 |
class USPS(data.Dataset):
    """USPS Dataset.

    Loads the 28x28 USPS digit pickle (train split: 7438 samples, test
    split: 1860) and serves each sample as a PIL grayscale image plus an
    int64 label.

    Args:
        root (string): Root directory where the dataset file lives.
        train (bool, optional): If True, use the training split, otherwise
            the test split.
        transform (callable, optional): A function/transform that takes in
            a PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset
            from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again.
    """

    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of Train = 7438, Num of Test 1860
        self.transform = transform
        self.dataset_size = None

        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        # Pixels in the pickle are floats; convert to uint8 for PIL.
        # BUG FIX: the original converted only the train split, leaving the
        # test split as float arrays that break Image.fromarray(mode='L').
        self.train_data = self.train_data * 255.0
        self.train_data = np.squeeze(self.train_data).astype(np.uint8)

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, label = self.train_data[index], self.train_labels[index]
        img = Image.fromarray(img, mode='L')
        img = img.copy()
        if self.transform is not None:
            img = self.transform(img)
        return img, label.astype("int64")

    def __len__(self):
        """Return size of dataset."""
        return len(self.train_data)

    def _check_exists(self):
        """Check if dataset is downloaded and in the right place."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download the dataset pickle into root (no-op if already present)."""
        # BUG FIX: the file-level 'import urllib' does not guarantee the
        # 'urllib.request' submodule is loaded; import it explicitly here.
        import urllib.request
        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, os.path.abspath(filename)))
        urllib.request.urlretrieve(self.url, filename)
        print("[DONE]")
        return

    def load_samples(self):
        """Load (images, labels) for the active split from the gzip pickle."""
        filename = os.path.join(self.root, self.filename)
        f = gzip.open(filename, "rb")
        data_set = pickle.load(f, encoding="bytes")
        f.close()
        # data_set[0] is the train split, data_set[1] the test split.
        if self.train:
            images = data_set[0][0]
            labels = data_set[0][1]
            self.dataset_size = labels.shape[0]
        else:
            images = data_set[1][0]
            labels = data_set[1][1]
            self.dataset_size = labels.shape[0]
        return images, labels
110 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
class Conv_Block(nn.Module):
    """Conv2d followed by LeakyReLU activation and BatchNorm2d."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super(Conv_Block, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride)
        self.relu = torch.nn.LeakyReLU()
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        # Convolution -> activation -> normalization, in that order.
        return self.bn(self.relu(self.conv(x)))
19 |
20 |
class Dense_Block(nn.Module):
    """Linear layer followed by LeakyReLU activation and BatchNorm1d."""

    def __init__(self, in_features, out_features):
        super(Dense_Block, self).__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.relu = torch.nn.LeakyReLU()
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, x):
        # Linear -> activation -> normalization, in that order.
        return self.bn(self.relu(self.fc(x)))
33 |
34 |
class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales gradients by -lambd in backward.

    BUG FIX: rewritten with the static-method autograd API. The previous
    instance-style Function (``__init__`` storing state, calling the
    instance) is the legacy API that modern PyTorch (>= 1.3) rejects at
    runtime.
    """

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient w.r.t. lambd (a plain float), hence the trailing None.
        return grad_output * -ctx.lambd, None


def grad_reverse(x, p=1):
    """Gradient-reversal layer with a sigmoid ramp-up on progress p in [0, 1]."""
    lambd = 2. / (1. + np.exp(-10 * p)) - 1
    return GradReverse.apply(x, lambd)
48 |
49 |
class Generator_s2m(nn.Module):
    """Feature extractor for SVHN -> MNIST: 4 conv blocks + 2 dense blocks,
    producing 100-d features."""

    def __init__(self):
        super(Generator_s2m, self).__init__()
        self.conv1 = Conv_Block(3, 64, kernel_size=5)
        self.conv2 = Conv_Block(64, 64, kernel_size=5)
        self.conv3 = Conv_Block(64, 128, kernel_size=3, stride=2)
        self.conv4 = Conv_Block(128, 128, kernel_size=3, stride=2)
        self.fc1 = Dense_Block(3200, 100)
        self.fc2 = Dense_Block(100, 100)

    def forward(self, x):
        out = x
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = block(out)
        out = out.view(out.size(0), -1)  # flatten conv features
        out = self.fc1(out)
        return self.fc2(out)
69 |
70 |
class Classifier_s2m(nn.Module):
    """Linear classification head over the 100-d s2m features."""

    def __init__(self, n_output):
        super(Classifier_s2m, self).__init__()
        self.fc = nn.Linear(100, n_output)

    def forward(self, x):
        return self.fc(x)
79 |
80 |
class Generator_u2m(nn.Module):
    """Feature extractor for USPS/MNIST tasks: LeNet-style convs + dropout,
    producing 500-d features."""

    def __init__(self):
        super(Generator_u2m, self).__init__()
        self.conv1 = Conv_Block(1, 20, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.conv2 = Conv_Block(20, 50, kernel_size=5)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.drop = nn.Dropout()
        self.fc = Dense_Block(800, 500)

    def forward(self, x):
        out = self.pool1(self.conv1(x))
        out = self.pool2(self.conv2(out))
        out = out.view(out.size(0), -1)  # flatten to (batch, 800)
        return self.fc(self.drop(out))
100 |
101 |
class Classifier_u2m(nn.Module):
    """Linear classification head over the 500-d u2m/m2u features."""

    def __init__(self, n_output):
        super(Classifier_u2m, self).__init__()
        self.fc = nn.Linear(500, n_output)

    def forward(self, x):
        return self.fc(x)
110 |
111 |
class Net(nn.Module):
    """Generator + classifier pair for one adaptation task.

    The classifier emits 6 logits (known classes plus one 'unknown').
    With ``adv=True`` the features pass through a gradient-reversal layer
    driven by progress ``p`` before classification.
    """

    def __init__(self, task='s2m'):
        super(Net, self).__init__()
        if task == 's2m':
            self.generator = Generator_s2m()
            self.classifier = Classifier_s2m(6)
        elif task == 'u2m' or task == 'm2u':
            self.generator = Generator_u2m()
            self.classifier = Classifier_u2m(6)

        # He init for convs; unit-gain affine for every batch norm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='leaky_relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, p=None, adv=False):
        feat = self.generator(x)
        if adv:
            # Reverse gradients flowing back into the generator.
            feat = grad_reverse(feat, p)
        return feat, self.classifier(feat)
--------------------------------------------------------------------------------
/data/centroid.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, absolute_import
2 | import numpy as np
3 | import torch as t
4 | from utils import *
5 |
6 |
class Centroids(object):
    """Running per-class centroids for source and target features.

    Holds one centroid row per class and domain, updated each batch by a
    similarity-weighted blend (weight from ``cal_sim``). Relies on
    ``torch``, ``F``, ``np``, ``to_np`` and ``cal_sim`` being in scope via
    this module's ``from utils import *``.
    """

    def __init__(self, class_num, dim, use_cuda):
        # Initialize near (but not exactly) zero so similarity against a
        # fresh centroid is well-defined.
        self.class_num = class_num
        self.src_ctrs = t.ones((class_num, dim))
        self.tgt_ctrs = t.ones((class_num, dim))
        self.src_ctrs *= 1e-10
        self.tgt_ctrs *= 1e-10
        if use_cuda:
            self.src_ctrs = self.src_ctrs.cuda()
            self.tgt_ctrs = self.tgt_ctrs.cuda()


    def get_centroids(self, domain=None, cid=None):
        # Return one domain's centroids (optionally a single row), or both
        # stacks when no domain is given.
        if domain == 'source':
            return self.src_ctrs if cid is None else self.src_ctrs[cid, :]
        elif domain == 'target':
            return self.tgt_ctrs if cid is None else self.tgt_ctrs[cid, :]
        else:
            return self.src_ctrs, self.tgt_ctrs

    @torch.no_grad()
    def update(self, feat_s, pred_s, label_s, feat_t, pred_t):
        """Update source centroids from labels, target ones from pseudo-labels."""
        self.upd_src_centroids(feat_s, pred_s, label_s)
        self.upd_tgt_centroids(feat_t, pred_t)

    @torch.no_grad()
    def upd_src_centroids(self, feats, probs, labels):
        """Blend per-class batch means of source features into src_ctrs.

        NOTE(review): iterates ``class_num - 1`` classes while the target
        update below iterates all ``class_num`` — confirm which is intended,
        since the caller already constructs this with the known-class count.
        """
        # feats = to_np(feats)
        labels = to_np(labels)
        # last_centroids = to_np(self.src_ctrs)
        probs = to_np(F.softmax(probs, dim=1))  # computed but unused below

        for i in range(self.class_num - 1):
            # Require at least two samples of class i for a batch mean.
            if np.sum(labels == i) > 1:
                last_centroid = self.src_ctrs[i, :]
                data_idx = np.argwhere(labels == i)
                new_centroid = t.mean(feats[data_idx, :], 0).squeeze()
                # Similarity-weighted EMA: the closer the batch mean is to
                # the old centroid, the more of the new mean is kept.
                cs = cal_sim(new_centroid, last_centroid)
                # print(cs)
                new_centroid = cs * new_centroid + (1 - cs) * last_centroid
                self.src_ctrs[i, :] = new_centroid

    @torch.no_grad()
    def upd_tgt_centroids(self, feats, probs):
        """Blend per-pseudo-class batch means of target features into tgt_ctrs.

        The blend weight is the similarity of the new batch mean to the
        *source* centroid of the same class.
        """
        # feats = to_np(feats)
        # last_centroids = to_np(self.tgt_ctrs)
        # src_centroids = to_np(self.src_ctrs)
        _, pseudo_label = probs.max(1, keepdim=True)
        pseudo_label = to_np(pseudo_label)
        probs = to_np(F.softmax(probs, dim=1))  # computed but unused below

        for i in range(self.class_num):
            if np.sum(pseudo_label == i) > 1:
                data_idx = np.argwhere(pseudo_label == i)
                new_centroid = t.mean(feats[data_idx, :], 0).squeeze()
                last_centroid = self.tgt_ctrs[i, :]
                # if last_centroids[i] != np.zeros_like((1, feats.shape[0])):
                cs = cal_sim(new_centroid, self.src_ctrs[i, :])
                # print(cs)
                new_centroid = cs * new_centroid + (1 - cs) * last_centroid
                self.tgt_ctrs[i, :] = new_centroid
68 |
69 |
def crit_intra(feats, y, centers, lambd=1e-3):
    """Ratio loss between own-center similarity and other-center similarity.

    For each sample, takes its ``cal_sim`` score against its own class
    center and divides by the summed score against all centers minus its
    own, then sums and scales by ``lambd``.
    """
    n_classes = len(centers)
    n_samples = y.shape[0]

    # Broadcast every feature against every center: (batch, class, dim).
    ctr_rep = centers.expand(n_samples, -1, -1)
    feat_rep = feats.expand(n_classes, -1, -1).transpose(1, 0)
    # distance_centers = (expanded_feats - expanded_centers).pow(2).sum(dim=-1)
    sims = cal_sim(feat_rep, ctr_rep).reshape(n_samples, n_classes)

    own = sims.gather(1, y.unsqueeze(1))   # score against own center
    rest = sims.sum(dim=-1) - own          # score against the remaining centers

    eps = 1e-6
    ratio = (1 / 2.0 / n_samples / n_classes) * own / (rest + eps)
    return ratio.sum() * lambd
90 |
91 |
def crit_inter(center1, center2, lambd=1e-3):
    """Mean row-wise similarity between two centroid stacks, scaled by lambd.

    Returns the scaled mean and the raw per-class similarity vector.
    """
    # dists = F.pairwise_distance(center1, center2)
    # loss = t.mean(dists)
    dists = cal_sim(center1, center2)
    loss = 0
    for d in dists:
        loss = loss + d
    loss /= center1.shape[0]
    loss *= lambd
    return loss, dists
104 |
105 |
def crit_contrast(feats, probs, s_ctds, t_ctds, lambd=1e-3):
    """Contrastive loss between target features and source centroids.

    Confident target samples (max softmax prob >= 0.3) predicted as a known
    class accumulate similarity toward their source centroid (D_k). A
    margin term for unknown samples (D_u) is computed but currently
    excluded from the returned loss.
    """
    batch_num = feats.shape[0]
    class_num = s_ctds.shape[0]
    probs = F.softmax(probs, dim=-1)
    max_probs, preds = probs.max(1, keepdim=True)
    # print(probs.shape, max_probs.shape)
    # Indices of confidently-predicted target samples only.
    select_index = t.nonzero(max_probs.squeeze() >= 0.3).squeeze(1)
    select_index = select_index.cpu().tolist()

    # todo: calculate margins
    # dist_ctds = cal_cossim(to_np(s_ctds), to_np(t_ctds))
    # Per-class similarity between paired source/target centroids.
    dist_ctds = cal_sim(s_ctds, t_ctds)
    # print('dist_ctds', dist_ctds.shape)

    # Margin per class: gap between the mean centroid similarity and this
    # class's own similarity, normalized by the other-class count.
    M = np.ones(class_num)
    for i in range(class_num):
        # M[i] = np.sum(dist_ctds[i, :]) - dist_ctds[i, i]
        M[i] = dist_ctds.mean() - dist_ctds[i]
        M[i] /= class_num - 1
    # print('M', M)

    # todo: calculate D_k between known samples to its source centroid &
    # todo: calculate D_u distances between unknown samples to all source centroids
    D_k, n_k = 0, 1e-5  # accumulated similarity / count for known predictions
    D_u, n_u = 0, 1e-5  # accumulated margin excess for unknown predictions
    for i in select_index:
        class_id = preds[i][0]
        # NOTE(review): any predicted id < class_num is treated as 'known';
        # ids beyond that fall into the unknown-margin branch.
        if class_id < class_num:
            # D_k += F.pairwise_distance(feats[i, :], s_ctds[class_id]).squeeze()
            # print(feats.shape, i)
            D_k += cal_sim(feats[i, :], s_ctds[class_id, :])
            # print('D_k', D_k)
            n_k += 1
        else:
            # todo: judge if unknown sample in the radius region of known centroid
            rp_feats = feats[i, :].unsqueeze(0).repeat(class_num, 1)

            # dist_known = F.pairwise_distance(rp_feats, s_ctds)
            dist_known = cal_sim(rp_feats, s_ctds)
            # print('dist_known', len(dist_known), dist_known)

            # Penalize only similarities that fall below the mean margin.
            M_mean = M.mean()
            outliers = dist_known < M_mean
            dist_margin = (dist_known - M_mean) * outliers.float()
            D_u += dist_margin.sum()

    loss = D_k / n_k  # - D_u / n_u  (unknown-margin term currently disabled)
    return loss.mean() * lambd
154 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optimizer
9 |
10 | from data.dataloader import *
11 | from data.centroid import *
12 | from model import Net
13 | import utils
14 |
15 |
def main(args):
    """Entry point: build data, model, optimizer; train; report best results."""
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(args.gpu))
        is_cuda = True
    else:
        device = torch.device("cpu")
        is_cuda = False

    src_loader, tgt_loader = get_data(args)

    model = Net(task=args.task).to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    if args.resume:
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])

        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))

    best_acc = 0
    best_label = []
    best_result = []

    # Centroids for the known classes only (class_num includes 'unknown').
    all_centroids = Centroids(args.class_num - 1, 100, use_cuda=is_cuda)

    try:
        # start training
        for epoch in range(args.epochs):
            data = (src_loader, tgt_loader, all_centroids)

            all_centroids = train(model, optimizer, data, epoch, device, args)

            result, gt_label, acc = test(model, tgt_loader, epoch, device, args)

            is_best = acc > best_acc
            if is_best:
                best_acc = acc
                best_label = gt_label
                # BUG FIX: this was assigned to 'best_pred', so 'best_result'
                # (used below and in the KeyboardInterrupt handler) stayed
                # empty and the final report was computed on no predictions.
                best_result = result

            utils.save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc
            }, is_best, args.check_dir)

        print("------Best-------")
        utils.cal_acc(best_label, best_result, args.class_num)

    except KeyboardInterrupt:
        print("------Best-------")
        utils.cal_acc(best_label, best_result, args.class_num)
77 |
78 |
def train(model, optimizer, data, epoch, device, args):
    """Run one training epoch and return the updated centroid tracker.

    Curriculum: source classification only at first; after ``pre_stage``
    epochs an adversarial BCE loss pushes the target 'unknown' probability
    toward ``args.th``; after ``adv_stage`` epochs the centroid-based
    losses (intra / inter / contrastive) are added.
    """
    src_loader, tgt_loader, all_centroids = data
    pre_stage = 5    # epochs of pure source training
    adv_stage = 15   # epochs before centroid losses start
    criterion_bce = nn.BCELoss()
    criterion_cel = nn.CrossEntropyLoss()

    model.train()

    for batch_idx, (batch_s, batch_t) in enumerate(zip(src_loader, tgt_loader)):
        global_step = epoch * len(src_loader) + batch_idx
        # Training progress in [0, 1), drives the gradient-reversal ramp-up.
        # BUG FIX: the original 'global_step / args.epochs * len(src_loader)'
        # multiplied by the loader length instead of dividing (operator
        # precedence), saturating the ramp almost immediately.
        p = global_step / (args.epochs * len(src_loader))
        lr = utils.adjust_learning_rate(optimizer, epoch, args,
                                        batch_idx, len(src_loader))
        data_s, label_s = batch_s
        data_s = data_s.to(device)
        label_s = label_s.to(device)
        data_t, label_t = batch_t
        data_t = data_t.to(device)
        # Constant BCE target: unknown probability should approach args.th.
        adv_label_t = torch.tensor([args.th]*len(data_t)).to(device)

        loss = 0
        optimizer.zero_grad()
        feat_s, pred_s = model(data_s)
        feat_t, pred_t = model(data_t, p, adv=True)

        # classification loss for known classes in source domain
        loss_cel = criterion_cel(pred_s, label_s)
        loss += loss_cel

        if epoch >= pre_stage:
            # adversarial loss for unknown class in target domain
            pred_t_prob_unk = F.softmax(pred_t, dim=1)[:, -1]
            loss_adv = criterion_bce(pred_t_prob_unk, adv_label_t)
            loss += loss_adv

            if epoch >= adv_stage:
                # Centroid-based alignment losses.
                all_centroids.update(feat_s, pred_s, label_s, feat_t, pred_t)
                s_ctds, t_ctds = all_centroids.get_centroids()

                loss_intra = crit_intra(feat_s, label_s, s_ctds)
                loss += loss_intra * args.lamb_s

                loss_inter, _ = crit_inter(s_ctds, t_ctds)
                loss += loss_inter * args.lamb_c

                loss_contr = crit_contrast(feat_t, pred_t, s_ctds, t_ctds)
                loss += loss_contr * args.lamb_t

        loss.backward()
        optimizer.step()

        if epoch >= pre_stage and batch_idx % args.log_interval == 0:
            print('Epoch: {} [{}/{} ({:.0f}%)] LR: {:.6f} \
Loss(cel): {:.4f} Loss(adv): {:.4f}\t'.format(
                epoch, batch_idx * args.batch_size,
                len(src_loader.dataset),
                100. * batch_idx / len(src_loader), lr,
                loss_cel.item(), loss_adv.item()))

    return all_centroids
141 |
142 |
def test(model, tgt_loader, epoch, device, args):
    """Evaluate on the target loader; returns (predictions, labels, accuracy %)."""
    total_loss = 0
    correct = 0
    result = []
    gt_label = []

    model.eval()
    criterion_cel = nn.CrossEntropyLoss()

    for data_t, label in tgt_loader:
        data_t = data_t.to(device)
        label = label.to(device)

        _, output = model(data_t)
        pred = output.max(1, keepdim=True)[1]
        total_loss += criterion_cel(output, label).item()

        # Record flat prediction / ground-truth lists for cal_acc.
        for p_i, l_i in zip(pred, label):
            result.append(p_i.item())
            gt_label.append(l_i.item())

        correct += pred.eq(label.view_as(pred)).sum().item()

    total_loss /= len(tgt_loader.dataset)

    utils.cal_acc(gt_label, result, args.class_num)
    acc = 100. * correct / len(tgt_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        total_loss, correct, len(tgt_loader.dataset),
        100. * correct / len(tgt_loader.dataset)))

    return result, gt_label, acc
177 |
178 |
if __name__ == "__main__":

    # Command-line interface; defaults reproduce the paper's setup.
    parser = argparse.ArgumentParser(description='Openset-DA SVHN -> MNIST Example')
    parser.add_argument('--task', choices=['s2m', 'u2m', 'm2u'], default='s2m',
                        help='domain adaptation sub-task')
    # Total classes including the extra 'unknown' class.
    parser.add_argument('--class-num', type=int, default=6, help='number of classes')
    parser.add_argument('--th', type=float, default=0.5, metavar='TH',
                        help='threshold for unknown class')
    # Weights for the intra (s), inter (c) and contrastive (t) centroid losses.
    parser.add_argument('--lamb-s', type=float, default=0.02)
    parser.add_argument('--lamb-c', type=float, default=0.005)
    parser.add_argument('--lamb-t', type=float, default=0.0001)
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=100, metavar='E',
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                        help='learning rate')
    parser.add_argument('--lr-rampdown-epochs', default=101, type=int,
                        help='length of learning rate cosine rampdown (>= length of training)')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M')
    # parser.add_argument('--grl-rampup-epochs', default=20, type=int, metavar='EPOCHS',
    #                     help='length of grl rampup')
    parser.add_argument('--weight-decay', '--wd', default=1e-3, type=float,
                        help='weight decay (default: 1e-3)')

    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--check_dir', default='checkpoint', type=str,
                        help='directory to save checkpoint')
    parser.add_argument('--resume', default='', type=str,
                        help='path to resume checkpoint (default: none)')
    parser.add_argument('--gpu', default='0', type=str, metavar='GPU',
                        help='id(s) for CUDA_VISIBLE_DEVICES')

    args = parser.parse_args()

    # Fixed input sizes per task, so cudnn autotuning pays off.
    torch.backends.cudnn.benchmark = True
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    main(args)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import copy
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from sklearn.metrics import accuracy_score
8 | from sklearn.metrics.pairwise import cosine_similarity
9 |
10 |
def save_checkpoint(state, is_best, check_dir):
    """Write the latest checkpoint; mirror it to best.pth.tar when is_best."""
    latest_path = os.path.join(check_dir, 'latest.pth.tar')
    torch.save(state, latest_path)
    if is_best:
        shutil.copyfile(latest_path, os.path.join(check_dir, 'best.pth.tar'))
17 |
18 |
def cal_acc(gt_label, pred_result, num):
    """Print per-class, known-average, average, and overall accuracy.

    Class ``num - 1`` is treated as the 'unknown' class.
    """
    acc_sum = 0
    for n in range(num):
        y, pred_y = [], []
        for gt, pred in zip(gt_label, pred_result):
            if gt == n:
                y.append(gt)
                pred_y.append(pred)
        print ('{}: {:4f}'.format(n if n != (num - 1) else 'Unk',
                                  accuracy_score(y, pred_y)))
        if n == (num - 1):
            # Known average is printed before the unknown class is added in.
            print ('Known Avg Acc: {:4f}'.format(acc_sum / (num - 1)))
        acc_sum += accuracy_score(y, pred_y)
    print ('Avg Acc: {:4f}'.format(acc_sum / num))
    print ('Overall Acc : {:4f}'.format(accuracy_score(gt_label, pred_result)))
36 |
37 |
def cosine_rampdown(current, rampdown_length):
    """Half-cosine decay from 1 to 0 over [0, rampdown_length].

    NOTE: this function is defined again later in this module with an
    identical body; the later definition is the one that takes effect.
    """
    assert 0 <= current <= rampdown_length
    return float((np.cos(np.pi * current / rampdown_length) + 1) / 2)
41 |
42 |
def to_np(x):
    """Squeeze singleton dims and convert a (possibly GPU) tensor to numpy."""
    squeezed = x.squeeze()
    return squeezed.detach().cpu().numpy()
45 |
46 |
def get_src_centroids(data_loader, model, args):
    """Compute the mean feature per known class over a whole source loader."""
    feats, labels, probs, preds = get_features(data_loader, model)
    centroids = []
    for cls in range(args.class_num - 1):
        # Indices of all samples labeled as this class.
        idx = np.unique(np.argwhere(labels == cls))
        cls_feats = feats[idx].squeeze()
        centroids.append(np.mean(cls_feats, axis=0))

    stacked = np.array(centroids).squeeze()
    return torch.from_numpy(stacked).cuda()
59 |
60 |
def get_tgt_centroids(data_loader, model, th, src_centroids, args):
    """Compute target centroids from confident, source-similar predictions.

    A target sample contributes to class i's centroid when it is predicted
    as class i and its dissimilarity to that source centroid is within th.
    """
    feats, labels, probs, preds = get_features(data_loader, model)
    src_centroids = to_np(src_centroids)
    # NOTE(review): cal_sim() in this module takes (x1, x2, metric) with no
    # 'rev' keyword and operates on tensors, while numpy arrays are passed
    # here — this helper looks stale relative to cal_sim; confirm before use.
    tgt_dissim = cal_sim(src_centroids, feats, rev=True)
    centroids = []
    # BUG FIX: was args.CLASS_NUM, which does not exist on the argparse
    # namespace (main.py defines --class-num -> args.class_num, and the
    # sibling get_src_centroids already uses args.class_num).
    for i in range(args.class_num - 1):
        class_idx = np.unique(np.argwhere(preds == i))
        easy_idx = np.unique(np.argwhere(tgt_dissim[i, :] <= th))
        data_idx = np.intersect1d(class_idx, easy_idx)
        if len(data_idx) > 1:
            feats_i = feats[data_idx].squeeze()
        else:
            # No reliable samples: fall back to a zero stack for this class.
            feats_i = np.zeros_like(feats)
            print(i, 'none')
        center_i = np.mean(feats_i, axis=0)
        centroids.append(center_i)

    centroids = np.array(centroids).squeeze()
    return torch.from_numpy(centroids).cuda()
80 |
81 |
def upd_src_centroids(feats, labels, probs, last_centroids, args):
    """Similarity-weighted EMA update of source centroids from one batch."""
    feats = to_np(feats)
    labels = to_np(labels)
    last_centroids = to_np(last_centroids)
    probs = to_np(F.softmax(probs, dim=1))

    updated = []
    for cls in range(args.class_num - 1):
        if np.sum(labels == cls) > 0:
            # Confident samples of this class only (softmax prob > 0.1).
            idx = np.intersect1d(np.argwhere(labels == cls),
                                 np.argwhere(probs[:, cls] > 0.1))
            mean_feat = np.mean(feats[idx], axis=0).reshape(1, -1)
            # Blend toward the batch mean proportionally to its cosine
            # similarity with the previous centroid.
            cs = cosine_similarity(mean_feat,
                                   last_centroids[cls].reshape(1, -1))[0][0]
            centroid = cs * mean_feat + (1 - cs) * last_centroids[cls]
        else:
            centroid = last_centroids[cls]

        updated.append(centroid.squeeze())

    return torch.from_numpy(np.array(updated)).cuda()
102 |
103 |
def upd_tgt_centroids(feats, probs, last_centroids, src_centroids, args):
    """Similarity-weighted EMA update of target centroids from pseudo-labels.

    Confident pseudo-labeled target features are averaged per class and
    blended toward the previous centroid, weighted by cosine similarity to
    the corresponding *source* centroid.
    """
    new_centroids = []
    feats = to_np(feats)
    last_centroids = to_np(last_centroids)
    src_centroids = to_np(src_centroids)
    _, ps_labels = probs.max(1, keepdim=True)
    ps_labels = to_np(ps_labels)
    probs = to_np(F.softmax(probs, dim=1))
    # BUG FIX: was args.CLASS_NUM, which does not exist on the argparse
    # namespace (main.py defines --class-num -> args.class_num).
    for i in range(args.class_num - 1):
        if np.sum(ps_labels == i) > 0:
            # Confident pseudo-labeled samples of this class (prob > 0.1).
            data_idx = np.intersect1d(np.argwhere(ps_labels == i),
                                      np.argwhere(probs[:, i] > 0.1))
            new_centroid = np.mean(feats[data_idx], axis=0).reshape(1, -1)

            # BUG FIX: the original compared the centroid vector against an
            # array with '!=', whose multi-element result is ambiguous in a
            # boolean context (and the zeros_like argument was malformed).
            # The intent is to skip blending while the last centroid is
            # still uninitialized (all zeros).
            if np.any(last_centroids[i]):
                cs = cosine_similarity(new_centroid,
                                       src_centroids[i].reshape(1, -1))[0][0]
                new_centroid = cs * new_centroid + (1 - cs) * last_centroids[i]
        else:
            new_centroid = last_centroids[i]

        new_centroids.append(new_centroid.squeeze())

    new_centroids = np.array(new_centroids)
    return torch.from_numpy(new_centroids).cuda()
128 |
129 |
def get_features(data_loader, model):
    """Run the model over a loader, collecting features, labels, probs, preds.

    Returns four numpy arrays concatenated over all batches; 'probs' holds
    the max logit/probability and 'preds' its argmax (both keepdim).
    """
    model.eval()
    feats, labels, probs, preds = [], [], [], []
    for batch_data in data_loader:
        input, label = batch_data
        input, label = input.cuda(), label.cuda(non_blocking=True)

        feat, prob = model(input)
        prob, pred = prob.max(1, keepdim=True)

        for buf, tensor in ((feats, feat), (labels, label),
                            (probs, prob), (preds, pred)):
            buf.append(tensor.cpu().detach().numpy())

    return tuple(np.concatenate(buf, axis=0)
                 for buf in (feats, labels, probs, preds))
151 |
152 |
def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983

    Decays smoothly from 1 at current=0 to 0 at current=rampdown_length.
    """
    assert 0 <= current <= rampdown_length
    return float((np.cos(np.pi * current / rampdown_length) + 1) / 2)
157 |
158 |
def adjust_learning_rate(optimizer, epoch, args,
                         step_in_epoch, total_steps_in_epoch):
    """Apply a cosine-rampdown LR to every param group; return the new LR."""
    # Fractional epoch so the LR also decays within an epoch.
    fractional_epoch = epoch + step_in_epoch / total_steps_in_epoch
    lr = args.lr * cosine_rampdown(fractional_epoch, args.lr_rampdown_epochs)

    for group in optimizer.param_groups:
        group['lr'] = lr

    return lr
169 |
170 |
def cal_sim(x1, x2, metric='cosine'):
    """Row-wise similarity between x1 and x2 (each flattened to 2-D first).

    'cosine' maps cosine similarity from [-1, 1] into [0, 1]; any other
    metric yields pairwise distance normalized by the row norms of x2.
    """
    if len(x1.shape) != 2:
        x1 = x1.reshape(-1, x1.shape[-1])
    if len(x2.shape) != 2:
        x2 = x2.reshape(-1, x2.shape[-1])

    if metric == 'cosine':
        return (F.cosine_similarity(x1, x2) + 1) / 2
    return F.pairwise_distance(x1, x2) / torch.norm(x2, dim=1)
183 |
184 |
def result_log(best_epoch, acc_score, OS_score, all_score, args):
    """Append the run's hyper-parameters and score summaries to the log file."""
    log_path = os.path.join(args.checkpoint, args.log_path)
    with open(log_path, 'a') as f:
        f.write('Task %s\n' % args.task)
        f.write('init_lr %.5f, wd %.5f batch %d\n'
                % (args.lr, args.weight_decay, args.batch_size))
        f.write('w_s %.5f | w_c %.5f | w_t %.5f\n'
                % (args.w_s, args.w_c, args.w_t))
        # One line per selection criterion: best accuracy, best OS, all.
        f.write('Best(%d) OS* %.3f OS %.3f ALL %.3f unk %.3f\n'
                % (best_epoch, acc_score[0], acc_score[1],
                   acc_score[2], acc_score[3]))
        f.write('(OS) OS* %.3f OS %.3f ALL %.3f unk %.3f\n'
                % (OS_score[0], OS_score[1], OS_score[2], OS_score[3]))
        f.write('(all) OS* %.3f OS %.3f ALL %.3f unk %.3f\n'
                % (all_score[0], all_score[1], all_score[2], all_score[3]))
195 |
196 |
197 | # def cal_acc(gt_list, predict_list, num):
198 | # acc_sum = 0
199 | # acc_list = {}
200 | # for n in range(num):
201 | # y = []
202 | # pred_y = []
203 | # for i in range(len(gt_list)):
204 | # gt = gt_list[i]
205 | # predict = predict_list[i]
206 | # if gt == n:
207 | # y.append(gt)
208 | # pred_y.append(predict)
209 | # acc = accuracy_score(y, pred_y)
210 | # print('{}: {:4f}'.format(n if n != (num - 1) else 'Unk', acc))
211 | # acc_list[n] = acc
212 | # if n == (num - 1):
213 | # OS_ = acc_sum * 1.0 / (num - 1)
214 | # print('Known Avg Acc: {:4f}'.format(OS_))
215 | # unk = accuracy_score(y, pred_y)
216 | # acc_sum += accuracy_score(y, pred_y)
217 | # OS = acc_sum * 1.0 / num
218 | # all = accuracy_score(gt_list, predict_list)
219 | # print('Avg Acc: {:4f}'.format(OS))
220 | # print('Overall Acc : {:4f}\n'.format(all))
221 | # return OS_, OS, all, unk, acc_list
--------------------------------------------------------------------------------