├── .DS_Store ├── data └── .DS_Store ├── losses.py ├── functional.py ├── options.py ├── README.md ├── vgg.py ├── main.py ├── main_music.py ├── workflow.py ├── utils.py └── conal.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zdchu/CoNAL/HEAD/.DS_Store -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zdchu/CoNAL/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def multi_loss(y_true, y_pred, loss_fn=torch.nn.CrossEntropyLoss(reduce='mean').cuda()): 5 | mask = y_true != -1 6 | y_pred = torch.transpose(y_pred, 1, 2) 7 | loss = loss_fn(y_pred[mask], y_true[mask]) 8 | return loss -------------------------------------------------------------------------------- /functional.py: -------------------------------------------------------------------------------- 1 | from torch.distributions import RelaxedOneHotCategorical, RelaxedBernoulli 2 | import torch.nn as nn 3 | 4 | 5 | def gumbel_sigmoid(input, temp): 6 | return RelaxedBernoulli(temp, probs=input).rsample() 7 | 8 | 9 | class GumbelSigmoid(nn.Module): 10 | def __init__(self, 11 | temp: float = 0.1, 12 | threshold: float = 0.5): 13 | super(GumbelSigmoid, self).__init__() 14 | self.temp = temp 15 | self.threshold = threshold 16 | 17 | def forward(self, input): 18 | if self.training: 19 | return gumbel_sigmoid(input, self.temp) 20 | else: 21 | return (input.sigmoid() >= self.threshold).float() -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | def model_opts(parser): 2 | parser.add_argument('--emb-dim', type=int, default=256) 3 | parser.add_argument('--hidden', default=[256, 128]) 4 | parser.add_argument('--dropout', default=0.5, type=float) 5 | parser.add_argument('--lr', default=0.01, type=float) 6 | parser.add_argument('--num_epochs', default=100, type=int) 7 | parser.add_argument('--mode', default='train') 8 | 9 | parser.add_argument('--root', type=str, default='') 10 | parser.add_argument('--batch_size', type=int, default=512) 11 | parser.add_argument('--new_data', type=bool, default=False) 12 | parser.add_argument('--num_samples', type=int, default=40000) 13 | parser.add_argument('--num_users', type=int, default=30) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Common Noise Adaptation Layers 2 | Code for AAAI 2021 long paper **Learning from Crowds by Modeling Common Confusions**. 3 | 4 | ## Dependencies 5 | * Python 3.5+ 6 | * PyTorch 1.2.0 7 | 8 | ## Datasets 9 | Download the real-world dataset [LabelMe](http://fprodrigues.com//deep_LabelMe.tar.gz) and [Music](http://fprodrigues.com//mturk-datasets.tar.gz). Please use the pretrained features in the *prepared* folder and put them into the *data* folder. 10 | 11 | ## Usage 12 | Run the model with default settings 13 | ``` 14 | python main.py 15 | ``` 16 | 17 | ## Citation 18 | Please cite our work if you find it useful to your research 19 | ``` 20 | @article{chu2020learning, 21 | title={Learning from Crowds by Modeling Common Confusions}, 22 | author={Chu, Zhendong and Ma, Jing and Wang, Hongning}, 23 | journal={arXiv preprint arXiv:2012.13052}, 24 | year={2020} 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class VGG(nn.Module): 6 | """ 7 | the common architecture for the left model 8 | """ 9 | def __init__(self, vgg_name, input_channels=3): 10 | super(VGG, self).__init__() 11 | self.cfg = { 12 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 13 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 14 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 15 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 16 | 'M'], 17 | } 18 | self.input_channels = input_channels 19 | self.features = self._make_layers(self.cfg[vgg_name]) 20 | self.classifier = nn.Linear(512, 10) 21 | 22 | def forward(self, x): 23 | out = self.features(x) 24 | out = out.view(out.size(0), -1) 25 | out = self.classifier(out) 26 | return F.softmax(out,dim=1) 27 | 28 | def _make_layers(self, cfg): 29 | layers = [] 30 | in_channels = self.input_channels 31 | for x in cfg: 32 | if x == 'M': 33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 34 | else: 35 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(x), 37 | nn.ReLU(inplace=True)] 38 | in_channels = x 39 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 40 | return nn.Sequential(*layers) 41 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from conal import * 2 | from utils import * 3 | from torch import optim 4 | from copy import deepcopy 5 | import argparse 6 | from options import * 7 | from torch.utils.data import DataLoader 8 | from workflow import * 9 | import random 10 | from conal import * 11 | from sklearn.decomposition import NMF 12 | from sklearn.metrics import accuracy_score 13 | 14 | seed = 12 15 | torch.manual_seed(seed) 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | 19 | dataset = 'labelme' 20 | model_dir = './model/' 21 | 22 | train_dataset = Dataset(mode='train', dataset=dataset, sparsity=0) 23 | trn_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True) 24 | 25 | valid_dataset = Dataset(mode='valid', dataset=dataset) 26 | val_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True) 27 | 28 | test_dataset = Dataset(mode='test', dataset=dataset) 29 | tst_loader = DataLoader(test_dataset, batch_size=32, shuffle=True) 30 | 31 | 32 | def main(opt, model=None): 33 | train_acc_list = [] 34 | test_acc_list = [] 35 | 36 | user_feature = np.eye(train_dataset.num_users) 37 | if model == None: 38 | model = torch.load(model_dir + 'model%s' % dataset) 39 | else: 40 | model = CoNAL(num_annotators=train_dataset.num_users, num_class=train_dataset.num_classes, 41 | input_dims=train_dataset.input_dims, user_feature=user_feature, gumbel_common=False).cuda() 42 | best_valid_acc = 0 43 | best_model = None 44 | lr = 1e-2 45 | for epoch in range(opt.num_epochs): 46 | optimizer = optim.Adam(model.parameters(), lr=lr) 47 | train_acc = train(train_loader=trn_loader, model=model, optimizer=optimizer, criterion=multi_loss, mode='common') 48 | valid_acc, valid_f1, _ = test(model=model, test_loader=val_loader) 49 | test_acc, test_f1, _ = test(model=model, test_loader=tst_loader) 50 | train_acc_list.append(train_acc) 51 | test_acc_list.append(test_acc) 52 | if valid_acc > best_valid_acc: 53 | best_valid_acc = valid_acc 54 | best_model = deepcopy(model) 55 | print('Epoch [%3d], Valid acc: %.5f, Valid f1: %.5f' % (epoch, valid_acc, valid_f1)) 56 | print('Test acc: %.5f, Test f1: %.5f' % (test_acc, test_f1)) 57 | 58 | test_acc, test_f1, _ = test(model=best_model, test_loader=tst_loader) 59 | print('Test acc: %.5f, Test f1: %.5f' % (test_acc, test_f1)) 60 | return best_model, test_acc 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | model_opts(parser) 66 | opt = parser.parse_args() 67 | 68 | test_acc = [] 69 | _, acc = main(opt, model=True) 70 | test_acc.append(acc) 71 | -------------------------------------------------------------------------------- /main_music.py: -------------------------------------------------------------------------------- 1 | from conal import * 2 | from utils import * 3 | from torch import optim 4 | from copy import deepcopy 5 | import argparse 6 | from options import * 7 | from torch.utils.data import DataLoader 8 | from workflow import * 9 | import random 10 | from conal import * 11 | from sklearn.decomposition import NMF 12 | from sklearn.metrics import accuracy_score 13 | 14 | seed = 1234 15 | torch.manual_seed(seed) 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | 19 | dataset = 'music' 20 | model_dir = './model/' 21 | 22 | train_dataset = Dataset(mode='train', dataset=dataset, sparsity=0) 23 | trn_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True) 24 | 25 | valid_dataset = Dataset(mode='test', dataset=dataset) 26 | val_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True) 27 | 28 | test_dataset = Dataset(mode='test', dataset=dataset) 29 | tst_loader = DataLoader(test_dataset, batch_size=32, shuffle=True) 30 | 31 | 32 | def main(opt, model=None): 33 | train_acc_list = [] 34 | test_acc_list = [] 35 | 36 | user_feature = np.eye(train_dataset.num_users) 37 | if model == None: 38 | model = torch.load(model_dir + 'model%s' % dataset) 39 | else: 40 | model = CoNAL_music(num_annotators=train_dataset.num_users, num_class=train_dataset.num_classes, 41 | input_dims=train_dataset.input_dims, user_feature=user_feature, gumbel_common=False).cuda() 42 | best_valid_acc = 0 43 | best_model = None 44 | lr = 1e-2 45 | for epoch in range(200): 46 | optimizer = optim.Adam(model.parameters(), lr=lr) 47 | train_acc = train(train_loader=trn_loader, model=model, optimizer=optimizer, criterion=multi_loss, mode='common') 48 | valid_acc, valid_f1, _ = test(model=model, test_loader=val_loader) 49 | test_acc, test_f1, _ = test(model=model, test_loader=tst_loader) 50 | train_acc_list.append(train_acc) 51 | test_acc_list.append(test_acc) 52 | if valid_acc > best_valid_acc: 53 | best_valid_acc = valid_acc 54 | best_model = deepcopy(model) 55 | print('Epoch [%3d], Valid acc: %.5f, Valid f1: %.5f' % (epoch, valid_acc, valid_f1)) 56 | print('Test acc: %.5f, Test f1: %.5f' % (test_acc, test_f1)) 57 | 58 | test_acc, test_f1, _ = test(model=best_model, test_loader=tst_loader) 59 | print('Test acc: %.5f, Test f1: %.5f' % (test_acc, test_f1)) 60 | return best_model, test_acc 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | model_opts(parser) 66 | opt = parser.parse_args() 67 | 68 | test_acc = [] 69 | _, acc = main(opt, model=True) 70 | test_acc.append(acc) 71 | -------------------------------------------------------------------------------- /workflow.py: -------------------------------------------------------------------------------- 1 | from losses import * 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from sklearn.metrics import f1_score 6 | import IPython 7 | 8 | loss_fn = torch.nn.CrossEntropyLoss(reduce='mean').cuda() 9 | def train(train_loader, model, optimizer, criterion=F.cross_entropy, mode='simple', annotators=None, pretrain=None, 10 | support = None, support_t = None, scale=0): 11 | model.train() 12 | correct = 0 13 | total = 0 14 | total_loss = 0 15 | loss = 0 16 | 17 | correct_rec = 0 18 | total_rec = 0 19 | for idx, input, targets, targets_onehot, true_labels in train_loader: 20 | input = input.cuda() 21 | targets = targets.cuda().long() 22 | targets_onehot = targets_onehot.cuda() 23 | targets_onehot[targets_onehot == -1] = 0 24 | true_labels = true_labels.cuda().long() 25 | 26 | if mode == 'simple': 27 | loss = 0 28 | if scale: 29 | cls_out, output, trace_norm = model(input) 30 | loss += scale * trace_norm 31 | mask = targets != -1 32 | y_pred = torch.transpose(output, 1, 2) 33 | y_true = torch.transpose(targets_onehot, 1, 2).float() 34 | loss += torch.mean(-y_true[mask] * torch.log(y_pred[mask])) 35 | else: 36 | cls_out, output = model(input) 37 | loss += criterion(targets, output) 38 | _, predicted = cls_out.max(1) 39 | correct += predicted.eq(true_labels).sum().item() 40 | total += true_labels.size(0) 41 | elif mode == 'common': 42 | rec_loss = 0 43 | loss = 0 44 | cls_out, output = model(input, mode='train') 45 | _, predicted = cls_out.max(1) 46 | correct += predicted.eq(true_labels).sum().item() 47 | total += true_labels.size(0) 48 | loss += criterion(targets, output) 49 | loss -= 0.00001 * torch.sum(torch.norm((model.kernel - model.common_kernel).view(targets.shape[1], -1), dim=1, p=2)) 50 | else: 51 | output, _ = model(input) 52 | loss = loss_fn(output, true_labels) 53 | _, predicted = output.max(1) 54 | correct += predicted.eq(true_labels).sum().item() 55 | total += true_labels.size(0) 56 | total_loss += loss 57 | optimizer.zero_grad() 58 | loss.backward() 59 | optimizer.step() 60 | if mode =='simple' or mode == 'common': 61 | print('Training acc: ', correct / total) 62 | return correct / total 63 | 64 | 65 | def test(model, test_loader): 66 | model.eval() 67 | correct = 0 68 | total = 0 69 | target = [] 70 | predict = [] 71 | for _, inputs, targets in test_loader: 72 | inputs = inputs.cuda() 73 | target.extend(targets.data.numpy()) 74 | targets = targets.cuda() 75 | 76 | total += targets.size(0) 77 | output, _ = model(inputs, mode='test') 78 | _, predicted = output.max(1) 79 | predict.extend(predicted.cpu().data.numpy()) 80 | correct += predicted.eq(targets).sum().item() 81 | acc = correct / total 82 | f1 = f1_score(target, predict, average='macro') 83 | 84 | classes = list(set(target)) 85 | classes.sort() 86 | acc_per_class = [] 87 | predict = np.array(predict) 88 | target = np.array(target) 89 | for i in range(len(classes)): 90 | instance_class = target == i 91 | acc_i = np.mean(predict[instance_class] == classes[i]) 92 | acc_per_class.append(acc_i) 93 | return acc, f1, acc_per_class 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils import data 3 | import torch 4 | from sklearn.metrics.pairwise import euclidean_distances 5 | from sklearn.model_selection import train_test_split 6 | import IPython 7 | import os 8 | from torchvision.datasets.utils import download_url, check_integrity 9 | import torchvision.transforms as transforms 10 | import sys 11 | import pandas as pd 12 | from PIL import Image 13 | import pickle 14 | import torchvision.models as models 15 | from sklearn.preprocessing import normalize 16 | 17 | 18 | def map_data(data): 19 | """ 20 | Map data to proper indices in case they are not in a continues [0, N) range 21 | 22 | Parameters 23 | ---------- 24 | data : np.int32 arrays 25 | 26 | Returns 27 | ------- 28 | mapped_data : np.int32 arrays 29 | n : length of mapped_data 30 | 31 | """ 32 | uniq = list(set(data)) 33 | 34 | id_dict = {old: new for new, old in enumerate(sorted(uniq))} 35 | data = np.array(list(map(lambda x: id_dict[x], data))) 36 | n = len(uniq) 37 | 38 | return data, id_dict, n 39 | 40 | def one_hot(target, n_classes): 41 | targets = np.array([target]).reshape(-1) 42 | one_hot_targets = np.eye(n_classes)[targets] 43 | return one_hot_targets 44 | 45 | 46 | def transform_onehot(answers, N_ANNOT, N_CLASSES, empty=-1): 47 | answers_bin_missings = [] 48 | for i in range(len(answers)): 49 | row = [] 50 | for r in range(N_ANNOT): 51 | if answers[i, r] == -1: 52 | row.append(empty * np.ones(N_CLASSES)) 53 | else: 54 | row.append(one_hot(answers[i, r], N_CLASSES)[0, :]) 55 | answers_bin_missings.append(row) 56 | answers_bin_missings = np.array(answers_bin_missings).swapaxes(1, 2) 57 | return answers_bin_missings 58 | 59 | 60 | class Dataset(data.Dataset): 61 | def __init__(self, mode='train', k=0, dataset='labelme', sparsity=0, test_ratio=0): 62 | if mode[:5] == 'train': 63 | self.mode = mode[:5] 64 | else: 65 | self.mode = mode 66 | 67 | if dataset == 'music': 68 | data_path = '../ldmi/data/music/' 69 | X = np.load(data_path + 'data_%s.npy' % self.mode) 70 | y = np.load(data_path + 'labels_%s.npy' % self.mode).astype(np.int) 71 | if mode == 'train': 72 | answers = np.load(data_path + '/answers.npy') 73 | self.answers = answers 74 | self.num_users = answers.shape[1] 75 | classes = np.unique(answers) 76 | if -1 in classes: 77 | self.num_classes = len(classes) - 1 78 | else: 79 | self.num_classes = len(classes) 80 | self.input_dims = X.shape[1] 81 | self.answers_onehot = transform_onehot(answers, answers.shape[1], self.num_classes) 82 | if dataset == 'labelme': 83 | data_path = '../ldmi/data/labelme/' 84 | X = np.load(data_path + self.mode + '/data_%s_vgg16.npy' % self.mode) 85 | y = np.load(data_path + self.mode + '/labels_%s.npy' % self.mode) 86 | X = X.reshape(X.shape[0], -1) 87 | if mode == 'train': 88 | answers = np.load(data_path + self.mode + '/answers.npy') 89 | self.answers = answers 90 | self.num_users = answers.shape[1] 91 | classes = np.unique(answers) 92 | if -1 in classes: 93 | self.num_classes = len(classes) - 1 94 | else: 95 | self.num_classes = len(classes) 96 | self.input_dims = X.shape[1] 97 | self.answers_onehot = transform_onehot(answers, answers.shape[1], 8) 98 | 99 | # y = np.load(data_path + self.mode + '/labels_%s.npy' % self.mode) 100 | # y = simple_majority_voting(answers) 101 | elif mode == 'train_dmi': 102 | answers = np.load(data_path + self.mode + '/answers.npy') 103 | self.answers = transform_onehot(answers, answers.shape[1], 8) 104 | self.num_users = answers.shape[1] 105 | classes = np.unique(answers) 106 | if -1 in classes: 107 | self.num_classes = len(classes) - 1 108 | else: 109 | self.num_classes = len(classes) 110 | self.input_dims = X.shape[1] 111 | train_num = int(len(X) * (1 - test_ratio)) 112 | self.X = torch.from_numpy(X).float()[:train_num] 113 | self.X_val = torch.from_numpy(X).float()[train_num:] 114 | if k: 115 | dist_mat = euclidean_distances(X, X) 116 | k_neighbors = np.argsort(dist_mat, 1)[:, :k] 117 | self.ins_feat = torch.from_numpy(X) 118 | self.k_neighbors = k_neighbors 119 | # self.X = torch.arange(0, len(dist_mat), 1) 120 | self.y = torch.from_numpy(y)[:train_num] 121 | self.y_val = torch.from_numpy(y)[train_num:] 122 | if mode == 'train': 123 | self.ans_val = answers[train_num:] 124 | 125 | 126 | def __len__(self): 127 | return len(self.X) 128 | 129 | def __getitem__(self, idx): 130 | if self.mode == 'train': 131 | return idx, self.X[idx], self.answers[idx], self.answers_onehot[idx], self.y[idx] 132 | else: 133 | return idx, self.X[idx], self.y[idx] 134 | 135 | 136 | def simple_majority_voting(response, empty=-1): 137 | mv = [] 138 | for row in response: 139 | bincount = np.bincount(row[row != empty]) 140 | mv.append(np.argmax(bincount)) 141 | return np.array(mv) 142 | -------------------------------------------------------------------------------- /conal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from torch.autograd import Variable 5 | from torch.nn import functional as F 6 | import math 7 | from vgg import * 8 | from torchvision import transforms 9 | import torch.nn.functional as F 10 | from functional import * 11 | 12 | 13 | class CoNAL(nn.Module): 14 | def __identity_init(self, shape): 15 | out = np.ones(shape) * 0 16 | if len(shape) == 3: 17 | for r in range(shape[0]): 18 | for i in range(shape[1]): 19 | out[r, i, i] = 2 20 | elif len(shape) == 2: 21 | for i in range(shape[1]): 22 | out[i, i] = 2 23 | return torch.Tensor(out).cuda() 24 | 25 | def __init__(self, num_annotators, input_dims, num_class, rate=0.5, conn_type='MW', backbone_model=None, user_feature=None 26 | , common_module='simple', num_side_features=None, nb=None, u_features=None, 27 | v_features=None, u_features_side=None, v_features_side=None, input_dim=None, emb_dim=None, hidden=None, gumbel_common=False): 28 | super(CoNAL, self).__init__() 29 | self.num_annotators = num_annotators 30 | self.conn_type = conn_type 31 | self.gumbel_sigmoid = GumbelSigmoid(temp=0.01) 32 | 33 | self.linear1 = nn.Linear(input_dims, 128) 34 | 35 | self.ln1 = nn.Linear(128, 256) 36 | self.ln2 = nn.Linear(256, 128) 37 | 38 | self.linear2 = nn.Linear(128, num_class) 39 | 40 | self.dropout1 = nn.Dropout(0.5) 41 | self.dropout2 = nn.Dropout(0.5) 42 | self.relu = nn.ReLU() 43 | self.rate = rate 44 | self.kernel = nn.Parameter(self.__identity_init((num_annotators, num_class, num_class)), 45 | requires_grad=True) 46 | 47 | self.common_kernel = nn.Parameter(self.__identity_init((num_class, num_class)) , 48 | requires_grad=True) 49 | 50 | self.backbone_model = None 51 | if backbone_model == 'vgg16': 52 | self.backbone_model = VGG('VGG16').cuda() 53 | self.feature = self.backbone_model.features 54 | self.classifier = self.backbone_model.classifier 55 | self.common_module = common_module 56 | 57 | if self.common_module == 'simple': 58 | com_emb_size = 20 59 | self.user_feature_vec = torch.from_numpy(user_feature).float().cuda() 60 | self.diff_linear_1 = nn.Linear(input_dims, 128) 61 | self.diff_linear_2 = nn.Linear(128, com_emb_size) 62 | self.user_feature_1 = nn.Linear(self.user_feature_vec.size(1), com_emb_size) 63 | self.bn_instance = torch.nn.BatchNorm1d(com_emb_size, affine=False) 64 | self.bn_user = torch.nn.BatchNorm1d(com_emb_size, affine=False) 65 | self.single_weight = nn.Linear(20, 1, bias=False) 66 | 67 | def simple_common_module(self, input): 68 | instance_difficulty = self.diff_linear_1(input) 69 | instance_difficulty = self.diff_linear_2(instance_difficulty) 70 | 71 | instance_difficulty = F.normalize(instance_difficulty) 72 | user_feature = self.user_feature_1(self.user_feature_vec) 73 | user_feature = F.normalize(user_feature) 74 | common_rate = torch.einsum('ij,kj->ik', (instance_difficulty, user_feature)) 75 | common_rate = torch.nn.functional.sigmoid(common_rate) 76 | return common_rate 77 | 78 | def forward(self, input, y=None, mode='train', support=None, support_t=None, idx=None): 79 | crowd_out = None 80 | if self.backbone_model: 81 | cls_out = self.backbone_model(input) 82 | else: 83 | x = input.view(input.size(0), -1) 84 | x = self.dropout1(F.relu(self.linear1(x))) 85 | x = self.linear2(x) 86 | cls_out = torch.nn.functional.softmax(x, dim=1) 87 | if mode == 'train': 88 | x = input.view(input.size(0), -1) 89 | if self.common_module == 'simple': 90 | common_rate = self.simple_common_module(x) 91 | common_prob = torch.einsum('ij,jk->ik', (cls_out, self.common_kernel)) 92 | indivi_prob = torch.einsum('ik,jkl->ijl', (cls_out, self.kernel)) 93 | 94 | crowd_out = common_rate[:, :, None] * common_prob[:, None, :] + (1 - common_rate[:, :, None]) * indivi_prob # single instance 95 | crowd_out = crowd_out.transpose(1, 2) 96 | if self.common_module == 'simple' or mode == 'test': 97 | return cls_out, crowd_out 98 | 99 | class CoNAL_music(nn.Module): 100 | def __identity_init(self, shape): 101 | out = np.ones(shape) * 0 102 | if len(shape) == 3: 103 | for r in range(shape[0]): 104 | for i in range(shape[1]): 105 | out[r, i, i] = 2 106 | elif len(shape) == 2: 107 | for i in range(shape[1]): 108 | out[i, i] = 2 109 | return torch.Tensor(out).cuda() 110 | 111 | def __init__(self, num_annotators, input_dims, num_class, rate=0.5, conn_type='MW', backbone_model=None, user_feature=None 112 | , common_module='simple', num_side_features=None, nb=None, u_features=None, 113 | v_features=None, u_features_side=None, v_features_side=None, input_dim=None, emb_dim=None, hidden=None, gumbel_common=False): 114 | super(CoNAL_music, self).__init__() 115 | self.num_annotators = num_annotators 116 | self.conn_type = conn_type 117 | self.gumbel_sigmoid = GumbelSigmoid(temp=0.01) 118 | 119 | self.linear1 = nn.Linear(input_dims, 128) 120 | 121 | self.ln1 = nn.Linear(128, 256) 122 | self.ln2 = nn.Linear(256, 128) 123 | 124 | self.bn = torch.nn.BatchNorm1d(input_dims, affine=False) 125 | self.bn1 = torch.nn.BatchNorm1d(128, affine=False) 126 | 127 | self.linear2 = nn.Linear(128, num_class) 128 | 129 | self.dropout1 = nn.Dropout(0.5) 130 | self.dropout2 = nn.Dropout(0.5) 131 | self.relu = nn.ReLU() 132 | self.rate = rate 133 | self.kernel = nn.Parameter(self.__identity_init((num_annotators, num_class, num_class)), 134 | requires_grad=True) 135 | 136 | self.common_kernel = nn.Parameter(self.__identity_init((num_class, num_class)) , 137 | requires_grad=True) 138 | 139 | self.backbone_model = None 140 | if backbone_model == 'vgg16': 141 | self.backbone_model = VGG('VGG16').cuda() 142 | self.feature = self.backbone_model.features 143 | self.classifier = self.backbone_model.classifier 144 | self.common_module = common_module 145 | 146 | if self.common_module == 'simple': 147 | com_emb_size = 80 148 | self.user_feature_vec = torch.from_numpy(user_feature).float().cuda() 149 | self.diff_linear_1 = nn.Linear(input_dims, 128) 150 | self.diff_linear_2 = nn.Linear(128, com_emb_size) 151 | self.user_feature_1 = nn.Linear(self.user_feature_vec.size(1), com_emb_size) 152 | self.bn_instance = torch.nn.BatchNorm1d(com_emb_size, affine=False) 153 | self.bn_user = torch.nn.BatchNorm1d(com_emb_size, affine=False) 154 | self.single_weight = nn.Linear(20, 1, bias=False) 155 | 156 | def simple_common_module(self, input): 157 | instance_difficulty = self.diff_linear_1(input) 158 | instance_difficulty = self.diff_linear_2(instance_difficulty) 159 | 160 | user_feature = self.user_feature_1(self.user_feature_vec) 161 | user_feature = F.normalize(user_feature) 162 | common_rate = torch.einsum('ij,kj->ik', (instance_difficulty, user_feature)) 163 | common_rate = torch.nn.functional.sigmoid(common_rate) 164 | return common_rate 165 | 166 | def forward(self, input, y=None, mode='train', support=None, support_t=None, idx=None): 167 | crowd_out = None 168 | if self.backbone_model: 169 | cls_out = self.backbone_model(input) 170 | else: 171 | x = input.view(input.size(0), -1) 172 | x = self.bn(x) 173 | x = self.dropout1(F.relu(self.linear1(x))) 174 | x = self.bn1(x) 175 | x = self.linear2(x) 176 | cls_out = torch.nn.functional.softmax(x, dim=1) 177 | if mode == 'train': 178 | x = input.view(input.size(0), -1) 179 | if self.common_module == 'simple': 180 | common_rate = self.simple_common_module(x) 181 | elif self.common_module == 'gcn': 182 | u = list(range(self.num_annotators)) 183 | common_rate, rec_out = self.gae(u, idx, support, support_t) 184 | common_rate = common_rate.transpose(0, 1) 185 | common_prob = torch.einsum('ij,jk->ik', (cls_out, self.common_kernel)) 186 | indivi_prob = torch.einsum('ik,jkl->ijl', (cls_out, self.kernel)) 187 | 188 | crowd_out = common_rate[:, :, None] * common_prob[:, None, :] + (1 - common_rate[:, :, None]) * indivi_prob # single instance 189 | crowd_out = crowd_out.transpose(1, 2) 190 | if self.common_module == 'simple' or mode == 'test': 191 | return cls_out, crowd_out 192 | 193 | 194 | 195 | 196 | --------------------------------------------------------------------------------