├── data └── data_prepare.txt ├── requirements.txt ├── README.md ├── models.py ├── tools.py ├── resnet.py ├── data.py └── main.py /data/data_prepare.txt: -------------------------------------------------------------------------------- 1 | You can place the downloaded datasets in this folder. 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.10.0 2 | matplotlib==3.3.1 3 | numpy==1.19.1 4 | torch==1.2.0 5 | torchvision==0.4.0 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Part-dependent Label Noise 2 | 3 | NeurIPS'20: Part-dependent Label Noise: Towards Instance-dependent Label Noise (PyTorch implementation). 4 | 5 | This is the code for the paper: 6 | [Part-dependent Label Noise: Towards Instance-dependent Label Noise](https://arxiv.org/pdf/2006.07836.pdf) 7 | Xiaobo Xia, Tongliang Liu, Bo Han, Nannan Wang, Mingming Gong, Haifeng Liu, Gang Niu, Dacheng Tao, Masashi Sugiyama. 8 | 9 | 10 | ## Dependencies 11 | We implement our method with PyTorch on an NVIDIA Tesla V100 GPU. The environment is as below: 12 | - [Ubuntu 16.04 Desktop](https://ubuntu.com/download) 13 | - [PyTorch](https://PyTorch.org/), version = 1.2.0 14 | - [CUDA](https://developer.nvidia.com/cuda-downloads), version = 10.0 15 | - [Anaconda3](https://www.anaconda.com/) 16 | 17 | ### Install requirements.txt 18 | ~~~ 19 | pip install -r requirements.txt 20 | ~~~ 21 | 22 | ## Experiments 23 | We verify the effectiveness of the proposed method on synthetic noisy datasets. In this repository, we provide the [datasets](https://drive.google.com/open?id=1Tz3W3JVYv2nu-mdM6x33KSnRIY1B7ygQ) used in the experiments (the images and labels have been converted to .npy format). Put the downloaded datasets in the `data` folder; the layout expected by the data loaders is sketched below.
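For reference, the loaders in `data.py` read per-dataset subfolders of `.npy` arrays. Below is a minimal sanity check, a sketch assuming the downloaded archive is unpacked directly into `data/`; swap `mnist` for `cifar10`, `svhn`, or `fashionmnist` as needed.

```python
import numpy as np

# Layout expected by data.py (one subfolder per dataset):
#   data/mnist/train_images.npy   data/mnist/train_labels.npy
#   data/mnist/test_images.npy    data/mnist/test_labels.npy
images = np.load('data/mnist/train_images.npy')
labels = np.load('data/mnist/train_labels.npy')
print(images.shape, labels.shape)  # the exact shapes depend on the dataset
```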
24 | Here is a training example: 25 | ```bash 26 | python main.py \ 27 | --dataset mnist \ 28 | --noise_rate 0.2 \ 29 | --gpu 0 30 | ``` 31 | If you find this code useful in your research, please cite: 32 | ```bibtex 33 | @inproceedings{xia2020part, 34 | title={Part-dependent Label Noise: Towards Instance-dependent Label Noise}, 35 | author={Xia, Xiaobo and Liu, Tongliang and Han, Bo and Wang, Nannan and Gong, Mingming and Liu, Haifeng and Niu, Gang and Tao, Dacheng and Sugiyama, Masashi}, 36 | booktitle={NeurIPS}, 37 | year={2020} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | def init_params(net): 7 | '''Init layer parameters.''' 8 | for m in net.modules(): 9 | if isinstance(m, nn.Conv2d): 10 | init.kaiming_normal_(m.weight, mode='fan_out') 11 | 12 | elif isinstance(m, nn.Linear): 13 | init.normal_(m.weight, mean=0, std=1e-3) 14 | 15 | def norm(T): 16 | row_abs = torch.abs(T) 17 | row_sum = torch.sum(row_abs, 1) 18 | T_norm = row_abs / row_sum 19 | return T_norm 20 | 21 | 22 | 23 | class Matrix_optimize(nn.Module): 24 | def __init__(self, basis_num, num_classes): 25 | super(Matrix_optimize, self).__init__() 26 | self.basis_matrix = self._make_layer(basis_num, num_classes) 27 | for m in self.modules(): 28 | if isinstance(m, nn.Linear): 29 | init.normal_(m.weight, std=1e-1) 30 | 31 | def _make_layer(self, basis_num, num_classes): 32 | 33 | layers = [] 34 | for i in range(0, basis_num): 35 | layers.append(nn.Linear(num_classes, 1, False)) 36 | return nn.Sequential(*layers) 37 | 38 | def forward(self, W, num_classes): 39 | results = torch.zeros(num_classes, 1) 40 | for i in range(len(W)): 41 | 42 | coefficient_matrix = float(W[i]) * torch.eye(num_classes, num_classes) 43 | self.basis_matrix[i].weight.data = norm(self.basis_matrix[i].weight.data) # s.t. non-negative entries that sum to 1
44 | anchor_vector = self.basis_matrix[i](coefficient_matrix) 45 | results += anchor_vector 46 | self.basis_matrix[i].weight.data = norm(self.basis_matrix[i].weight.data) 47 | return results 48 | 49 | 50 | class LeNet(nn.Module): 51 | def __init__(self): 52 | super(LeNet, self).__init__() 53 | self.conv1 = nn.Conv2d(1,6,5,stride=1,padding=2) 54 | self.conv2 = nn.Conv2d(6, 16, 5) 55 | self.fc1 = nn.Linear(400, 120) 56 | self.fc2 = nn.Linear(120, 84) 57 | self.fc3 = nn.Linear(84, 10) 58 | self.T_revision = nn.Linear(10, 10, False) 59 | 60 | def forward(self, x, revision=True): 61 | correction = self.T_revision.weight 62 | out = F.relu(self.conv1(x)) 63 | out = F.max_pool2d(out, 2) 64 | out = F.relu(self.conv2(out)) 65 | out = F.max_pool2d(out, 2) 66 | out = out.view(out.size(0), -1) 67 | out = F.relu(self.fc1(out)) 68 | out_1 = F.relu(self.fc2(out)) # -> representations 69 | out_2 = self.fc3(out_1) 70 | if revision == True: 71 | return out_1, out_2, correction 72 | else: 73 | return out_1, out_2 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from math import inf 5 | from scipy import stats 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | 9 | def get_instance_noisy_label(n, dataset, labels, num_classes, feature_size, norm_std, seed): 10 | # n -> noise_rate 11 | # dataset -> mnist, cifar10 # not train_loader 12 | # labels -> labels (targets) 13 | # label_num -> class number 14 | # feature_size -> the size of input images (e.g. 28*28) 15 | # norm_std -> default 0.1 16 | # seed -> random_seed 17 | print("building dataset...") 18 | label_num = num_classes 19 | np.random.seed(int(seed)) 20 | torch.manual_seed(int(seed)) 21 | torch.cuda.manual_seed(int(seed)) 22 | 23 | P = [] 24 | flip_distribution = stats.truncnorm((0 - n) / norm_std, (1 - n) / norm_std, loc=n, scale=norm_std) 25 | flip_rate = flip_distribution.rvs(labels.shape[0]) 26 | 27 | if isinstance(labels, list): 28 | labels = torch.FloatTensor(labels) 29 | labels = labels.cuda() 30 | 31 | W = np.random.randn(label_num, feature_size, label_num) 32 | 33 | 34 | W = torch.FloatTensor(W).cuda() 35 | for i, (x, y) in enumerate(dataset): 36 | # 1*m * m*10 = 1*10 37 | x = x.cuda() 38 | A = x.view(1, -1).mm(W[y]).squeeze(0) 39 | A[y] = -inf 40 | A = flip_rate[i] * F.softmax(A, dim=0) 41 | A[y] += 1 - flip_rate[i] 42 | P.append(A) 43 | P = torch.stack(P, 0).cpu().numpy() 44 | l = [i for i in range(label_num)] 45 | new_label = [np.random.choice(l, p=P[i]) for i in range(labels.shape[0])] 46 | record = [[0 for _ in range(label_num)] for i in range(label_num)] 47 | 48 | for a, b in zip(labels, new_label): 49 | a, b = int(a), int(b) 50 | record[a][b] += 1 51 | 52 | 53 | pidx = np.random.choice(range(P.shape[0]), 1000) 54 | cnt = 0 55 | for i in range(1000): 56 | if labels[pidx[i]] == 0: 57 | a = P[pidx[i], :] 58 | cnt += 1 59 | if cnt >= 10: 60 | break 61 | return np.array(new_label) 62 | 63 | def norm(T): 64 | row_abs = torch.abs(T) 65 | row_sum = torch.sum(row_abs, 1).unsqueeze(1) 66 | T_norm = row_abs / row_sum 67 | return T_norm 68 | 69 | 70 | 71 | def fit(X, num_classes, percentage, filter_outlier=False): 72 | # number of classes 73 | c = num_classes 74 | T = np.empty((c, c)) # +1 -> index 75 | eta_corr = X 76 | ind = [] 77 | for i in np.arange(c): 78 | if not filter_outlier: 79 | idx_best = np.argmax(eta_corr[:, i]) 80 | else: 81 
| eta_thresh = np.percentile(eta_corr[:, i], percentage,interpolation='higher') 82 | robust_eta = eta_corr[:, i] 83 | robust_eta[robust_eta >= eta_thresh] = 0.0 84 | idx_best = np.argmax(robust_eta) 85 | ind.append(idx_best) 86 | for j in np.arange(c): 87 | T[i, j] = eta_corr[idx_best, j] 88 | 89 | return T, ind 90 | 91 | def data_split(data, targets, split_percentage, seed=1): 92 | 93 | num_samples = int(targets.shape[0]) 94 | np.random.seed(int(seed)) 95 | train_set_index = np.random.choice(num_samples, int(num_samples*split_percentage), replace=False) 96 | index = np.arange(data.shape[0]) 97 | val_set_index = np.delete(index, train_set_index) 98 | train_set, val_set = data[train_set_index, :], data[val_set_index, :] 99 | train_labels, val_labels = targets[train_set_index], targets[val_set_index] 100 | 101 | return train_set, val_set, train_labels, val_labels 102 | 103 | 104 | def transform_target(label): 105 | label = np.array(label) 106 | target = torch.from_numpy(label).long() 107 | return target 108 | 109 | def init_params(net): 110 | '''Init layer parameters.''' 111 | for m in net.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | nn.init.kaiming_normal(m.weight, mode='fan_out') 114 | 115 | elif isinstance(m, nn.Linear): 116 | nn.init.normal_(m.weight, std=1e-1) 117 | 118 | return net -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Linear(nn.Module): 7 | def __init__(self, in_features, out_features): 8 | 9 | super(Linear, self).__init__() 10 | self.w = nn.Parameter(torch.randn(in_features, out_features)) 11 | 12 | def forward(self, x): 13 | x = x.mm(self.w) 14 | return x 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion*planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion*planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 
| ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | class ResNet(nn.Module): 67 | def __init__(self, block, num_blocks, num_classes): 68 | super(ResNet, self).__init__() 69 | self.in_planes = 64 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = nn.Linear(512*block.expansion, num_classes) 77 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 78 | self.T_revision = nn.Linear(num_classes, num_classes, False) 79 | 80 | 81 | def _make_layer(self, block, planes, num_blocks, stride): 82 | strides = [stride] + [1]*(num_blocks-1) 83 | layers = [] 84 | for stride in strides: 85 | layers.append(block(self.in_planes, planes, stride)) 86 | self.in_planes = planes * block.expansion 87 | return nn.Sequential(*layers) 88 | 89 | def forward(self, x, revision=True): 90 | 91 | correction = self.T_revision.weight 92 | 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | out = self.avgpool(out) 99 | out_1 = out.view(out.size(0), -1) 100 | out_2 = self.linear(out_1) 101 | if revision == True: 102 | return out_1, out_2, correction 103 | else: 104 | return out_1, out_2 105 | 106 | 107 | class ResNet_F(nn.Module): 108 | def __init__(self, block, num_blocks, num_classes): 109 | super(ResNet_F, self).__init__() 110 | self.in_planes = 64 111 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=0, bias=False) 112 | self.bn1 = nn.BatchNorm2d(64) 113 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 114 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 117 | self.linear = nn.Linear(512 * block.expansion, num_classes) 118 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 119 | self.T_revision = nn.Linear(num_classes, num_classes, False) 120 | 121 | def _make_layer(self, block, planes, num_blocks, stride): 122 | strides = [stride] + [1] * (num_blocks - 1) 123 | layers = [] 124 | for stride in strides: 125 | layers.append(block(self.in_planes, planes, stride)) 126 | self.in_planes = planes * block.expansion 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x, revision=True): 130 | 131 | correction = self.T_revision.weight 132 | 133 | out = F.relu(self.bn1(self.conv1(x))) 134 | out = self.layer1(out) 135 | out = self.layer2(out) 136 | out = self.layer3(out) 137 | out = self.layer4(out) 138 | out = self.avgpool(out) 139 | out_1 = out.view(out.size(0), -1) 140 | out_2 = self.linear(out_1) 141 | if revision == True: 142 | return out_1, out_2, correction 143 | else: 144 | return out_1, out_2 145 | 146 | 147 | def ResNet18(num_classes): 148 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 149 | 150 | def ResNet18_F(num_classes): 151 | return ResNet_F(BasicBlock, [2,2,2,2], num_classes) 152 | 153 | def ResNet34(num_classes): 154 | return ResNet(BasicBlock, 
[3,4,6,3], num_classes) 155 | 156 | def ResNet50(num_classes): 157 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 158 | 159 | def ResNet101(num_classes): 160 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 161 | 162 | def ResNet152(num_classes): 163 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 164 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.utils.data as Data 3 | from PIL import Image 4 | 5 | import tools 6 | import torch 7 | 8 | class mnist_dataset(Data.Dataset): 9 | def __init__(self, train=True, transform=None, target_transform=None, noise_rate=0.2, split_percentage=0.9, seed=1, num_classes=10, feature_size=28*28, norm_std=0.1): 10 | 11 | self.transform = transform 12 | self.target_transform = target_transform 13 | self.train = train 14 | original_images = np.load('data/mnist/train_images.npy') 15 | original_labels = np.load('data/mnist/train_labels.npy') 16 | data = torch.from_numpy(original_images).float() 17 | targets = torch.from_numpy(original_labels) 18 | 19 | dataset = zip(data, targets) 20 | new_labels = tools.get_instance_noisy_label(noise_rate, dataset, targets, num_classes, feature_size, norm_std, seed) 21 | 22 | self.train_data, self.val_data, self.train_labels, self.val_labels = tools.data_split(original_images, new_labels, split_percentage,seed) 23 | 24 | def __getitem__(self, index): 25 | 26 | if self.train: 27 | img, label = self.train_data[index], self.train_labels[index] 28 | else: 29 | img, label = self.val_data[index], self.val_labels[index] 30 | 31 | img = Image.fromarray(img) 32 | 33 | if self.transform is not None: 34 | img = self.transform(img) 35 | 36 | if self.target_transform is not None: 37 | label = self.target_transform(label) 38 | 39 | return img, label 40 | def __len__(self): 41 | 42 | if self.train: 43 | return len(self.train_data) 44 | 45 | else: 46 | return len(self.val_data) 47 | 48 | 49 | class mnist_test_dataset(Data.Dataset): 50 | def __init__(self, transform=None, target_transform=None): 51 | 52 | self.transform = transform 53 | self.target_transform = target_transform 54 | 55 | self.test_data = np.load('data/mnist/test_images.npy') 56 | self.test_labels = np.load('data/mnist/test_labels.npy') - 1 # 0-9 57 | 58 | def __getitem__(self, index): 59 | 60 | img, label = self.test_data[index], self.test_labels[index] 61 | 62 | img = Image.fromarray(img) 63 | 64 | if self.transform is not None: 65 | img = self.transform(img) 66 | 67 | if self.target_transform is not None: 68 | label = self.target_transform(label) 69 | 70 | return img, label 71 | 72 | def __len__(self): 73 | return len(self.test_data) 74 | 75 | class cifar10_dataset(Data.Dataset): 76 | def __init__(self, train=True, transform=None, target_transform=None, noise_rate=0.2, split_percentage=0.9, seed=1, num_classes=10, feature_size=3*32*32, norm_std=0.1): 77 | 78 | self.transform = transform 79 | self.target_transform = target_transform 80 | self.train = train 81 | 82 | original_images = np.load('data/cifar10/train_images.npy') 83 | original_labels = np.load('data/cifar10/train_labels.npy') 84 | data = torch.from_numpy(original_images).float() 85 | targets = torch.from_numpy(original_labels) 86 | 87 | dataset = zip(data, targets) 88 | new_labels = tools.get_instance_noisy_label(noise_rate, dataset, targets, num_classes, feature_size, norm_std, seed) 89 | 90 | 91 | self.train_data, self.val_data, 
self.train_labels, self.val_labels = tools.data_split(original_images, new_labels, split_percentage,seed) 92 | if self.train: 93 | self.train_data = self.train_data.reshape((-1,3,32,32)) 94 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) 95 | print(self.train_data.shape) 96 | 97 | else: 98 | self.val_data = self.val_data.reshape((-1, 3,32,32)) 99 | self.val_data = self.val_data.transpose((0, 2, 3, 1)) 100 | 101 | def __getitem__(self, index): 102 | 103 | if self.train: 104 | img, label = self.train_data[index], self.train_labels[index] 105 | 106 | else: 107 | img, label = self.val_data[index], self.val_labels[index] 108 | 109 | img = Image.fromarray(img) 110 | 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | 114 | if self.target_transform is not None: 115 | label = self.target_transform(label) 116 | 117 | return img, label 118 | def __len__(self): 119 | 120 | if self.train: 121 | return len(self.train_data) 122 | 123 | else: 124 | return len(self.val_data) 125 | 126 | class cifar10_test_dataset(Data.Dataset): 127 | def __init__(self, transform=None, target_transform=None): 128 | 129 | self.transform = transform 130 | self.target_transform = target_transform 131 | 132 | self.test_data = np.load('data/cifar10/test_images.npy') 133 | self.test_labels = np.load('data/cifar10/test_labels.npy') 134 | self.test_data = self.test_data.reshape((10000,3,32,32)) 135 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) 136 | 137 | def __getitem__(self, index): 138 | 139 | img, label = self.test_data[index], self.test_labels[index] 140 | 141 | img = Image.fromarray(img) 142 | 143 | if self.transform is not None: 144 | img = self.transform(img) 145 | 146 | if self.target_transform is not None: 147 | label = self.target_transform(label) 148 | 149 | return img, label 150 | 151 | def __len__(self): 152 | return len(self.test_data) 153 | 154 | class svhn_dataset(Data.Dataset): 155 | def __init__(self, train=True, transform=None, target_transform=None, noise_rate=0.2, split_percentage=0.9, seed=1, num_classes=10, feature_size=3*32*32, norm_std=0.1): 156 | 157 | self.transform = transform 158 | self.target_transform = target_transform 159 | self.train = train 160 | 161 | original_images = np.load('data/svhn/train_images.npy') 162 | original_labels = np.load('data/svhn/train_labels.npy') 163 | data = torch.from_numpy(original_images).float() 164 | targets = torch.from_numpy(original_labels) 165 | 166 | dataset = zip(data, targets) 167 | new_labels = tools.get_instance_noisy_label(noise_rate, dataset, targets, num_classes, feature_size, norm_std, seed) 168 | 169 | self.train_data, self.val_data, self.train_labels, self.val_labels = tools.data_split(original_images, new_labels, split_percentage,seed) 170 | if self.train: 171 | self.train_data = self.train_data.reshape((-1,3,32,32)) 172 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) 173 | 174 | else: 175 | self.val_data = self.val_data.reshape((-1, 3,32,32)) 176 | self.val_data = self.val_data.transpose((0, 2, 3, 1)) 177 | 178 | def __getitem__(self, index): 179 | 180 | if self.train: 181 | img, label = self.train_data[index], self.train_labels[index] 182 | 183 | else: 184 | img, label = self.val_data[index], self.val_labels[index] 185 | 186 | img = Image.fromarray(img) 187 | 188 | if self.transform is not None: 189 | img = self.transform(img) 190 | 191 | if self.target_transform is not None: 192 | label = self.target_transform(label) 193 | 194 | return img, label 195 | def __len__(self): 196 | 197 | if 
self.train: 198 | return len(self.train_data) 199 | 200 | else: 201 | return len(self.val_data) 202 | 203 | class svhn_test_dataset(Data.Dataset): 204 | def __init__(self, transform=None, target_transform=None): 205 | 206 | self.transform = transform 207 | self.target_transform = target_transform 208 | 209 | self.test_data = np.load('data/svhn/test_images.npy') 210 | self.test_labels = np.load('data/svhn/test_labels.npy') 211 | self.test_data = self.test_data.reshape((-1,3,32,32)) 212 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) 213 | 214 | def __getitem__(self, index): 215 | 216 | img, label = self.test_data[index], self.test_labels[index] 217 | 218 | img = Image.fromarray(img) 219 | 220 | if self.transform is not None: 221 | img = self.transform(img) 222 | 223 | if self.target_transform is not None: 224 | label = self.target_transform(label) 225 | 226 | return img, label 227 | 228 | def __len__(self): 229 | return len(self.test_data) 230 | 231 | class fashionmnist_dataset(Data.Dataset): 232 | def __init__(self, train=True, transform=None, target_transform=None, noise_rate=0.2, split_percentage=0.9, seed=1, num_classes=10, feature_size=784, norm_std=0.1): 233 | 234 | self.transform = transform 235 | self.target_transform = target_transform 236 | self.train = train 237 | 238 | original_images = np.load('data/fashionmnist/train_images.npy') 239 | original_labels = np.load('data/fashionmnist/train_labels.npy') 240 | data = torch.from_numpy(original_images).float() 241 | targets = torch.from_numpy(original_labels) 242 | 243 | dataset = zip(data, targets) 244 | new_labels = tools.get_instance_noisy_label(noise_rate, dataset, targets, num_classes, feature_size, norm_std, seed) 245 | 246 | 247 | self.train_data, self.val_data, self.train_labels, self.val_labels = tools.data_split(original_images, new_labels, split_percentage,seed) 248 | 249 | def __getitem__(self, index): 250 | 251 | if self.train: 252 | img, label = self.train_data[index], self.train_labels[index] 253 | 254 | else: 255 | img, label = self.val_data[index], self.val_labels[index] 256 | 257 | img = Image.fromarray(img) 258 | 259 | if self.transform is not None: 260 | img = self.transform(img) 261 | 262 | if self.target_transform is not None: 263 | label = self.target_transform(label) 264 | 265 | return img, label 266 | def __len__(self): 267 | 268 | if self.train: 269 | return len(self.train_data) 270 | 271 | else: 272 | return len(self.val_data) 273 | 274 | class fashionmnist_test_dataset(Data.Dataset): 275 | def __init__(self, transform=None, target_transform=None): 276 | 277 | self.transform = transform 278 | self.target_transform = target_transform 279 | 280 | self.test_data = np.load('data/fashionmnist/test_images.npy') 281 | self.test_labels = np.load('data/fashionmnist/test_labels.npy') 282 | 283 | 284 | def __getitem__(self, index): 285 | 286 | img, label = self.test_data[index], self.test_labels[index] 287 | 288 | img = Image.fromarray(img) 289 | 290 | if self.transform is not None: 291 | img = self.transform(img) 292 | 293 | if self.target_transform is not None: 294 | label = self.target_transform(label) 295 | 296 | return img, label 297 | 298 | def __len__(self): 299 | return len(self.test_data) 300 | 301 | 302 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 
| from torch.autograd import Variable 8 | from models import LeNet, Matrix_optimize 9 | import torchvision.transforms as transforms 10 | import numpy as np 11 | import argparse 12 | import datetime 13 | import resnet 14 | import tools 15 | import data 16 | import copy 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--lr', type = float, default = 0.01) 20 | parser.add_argument('--lr_revision', type = float, default = 5e-7) 21 | parser.add_argument('--result_dir', type = str, help = 'dir to save result txt files', default = 'results/') 22 | parser.add_argument('--model_dir', type=str, help='dir to save model files', default='model/') 23 | parser.add_argument('--noise_rate', type = float, help = 'corruption rate, should be less than 1', default = 0.2) 24 | parser.add_argument('--noise_type', type = str, help='[instance, symmetric]', default='instance') 25 | parser.add_argument('--dataset', type = str, help = 'mnist, cifar10, cifar100', default = 'mnist') 26 | parser.add_argument('--n_epoch_1', type = int, help = 'estimate', default=10) 27 | parser.add_argument('--n_epoch_2', type = int, help = 'loss correction',default=100) 28 | parser.add_argument('--n_epoch_3', type = int, help = 'revision',default=50) 29 | parser.add_argument('--n_epoch_4', type = int, help = 'learn matrix',default=1500) 30 | parser.add_argument('--iteration_nmf', type = int, default=20) 31 | parser.add_argument('--optimizer', type = str, default='SGD') 32 | parser.add_argument('--seed', type = int, default=5) 33 | parser.add_argument('--print_freq', type = int, default=100) 34 | parser.add_argument('--num_workers', type = int, default=8, help='how many subprocesses to use for data loading') 35 | parser.add_argument('--model_type', type = str, help='[ce, ours]', default='ours') 36 | parser.add_argument('--split_percentage', type = float, help = 'train and validation', default=0.9) 37 | parser.add_argument('--norm_std', type = float, help = 'distribution ', default=0.1) 38 | parser.add_argument('--num_classes', type = int, help = 'num_classes', default=10) 39 | parser.add_argument('--feature_size', type = int, help = 'the size of feature_size', default=784) 40 | parser.add_argument('--dim', type = int, help = 'the dim of representations', default=84) 41 | parser.add_argument('--basis', type = int, help = 'the num of basis', default=10) 42 | parser.add_argument('--weight_decay', type = float, help = 'weight', default=1e-4) 43 | parser.add_argument('--momentum', type = float, help = 'momentum', default=0.9) 44 | parser.add_argument('--gpu', type = int, help = 'ind of gpu', default=0) 45 | args = parser.parse_args() 46 | # 47 | torch.cuda.set_device(args.gpu) 48 | # Seed 49 | torch.manual_seed(args.seed) 50 | torch.cuda.manual_seed(args.seed) 51 | 52 | # Hyper Parameters 53 | batch_size = 128 54 | learning_rate = args.lr 55 | 56 | # load dataset 57 | if args.dataset=='mnist': 58 | args.feature_size = 28 * 28 59 | args.num_classes = 10 60 | args.n_epoch_1, args.n_epoch_2, args.n_epoch_3 = 5, 20, 50 61 | args.dim = 84 62 | args.basis = 10 63 | train_dataset = data.mnist_dataset(True, 64 | transform = transforms.Compose([ 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.1307, ),(0.3081, )),]), 67 | target_transform=tools.transform_target, 68 | noise_rate=args.noise_rate, 69 | split_percentage=args.split_percentage, 70 | seed=args.seed) 71 | 72 | val_dataset = data.mnist_dataset(False, 73 | transform = transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.1307, ),(0.3081, 
)),]), 76 | target_transform=tools.transform_target, 77 | noise_rate=args.noise_rate, 78 | split_percentage=args.split_percentage, 79 | seed=args.seed) 80 | 81 | 82 | test_dataset = data.mnist_test_dataset( 83 | transform = transforms.Compose([ 84 | transforms.ToTensor(), 85 | transforms.Normalize((0.1307, ),(0.3081, )),]), 86 | target_transform=tools.transform_target) 87 | 88 | 89 | 90 | if args.dataset=='fashionmnist': 91 | args.feature_size = 28 * 28 92 | args.num_classes = 10 93 | args.n_epoch_1, args.n_epoch_2, args.n_epoch_3 = 5, 20, 50 94 | args.dim = 512 95 | args.basis = 10 96 | train_dataset = data.fashionmnist_dataset(True, 97 | transform = transforms.Compose([ 98 | transforms.ToTensor(), 99 | transforms.Normalize((0.1307, ),(0.3081, )),]), 100 | target_transform=tools.transform_target, 101 | noise_rate=args.noise_rate, 102 | split_percentage=args.split_percentage, 103 | seed=args.seed) 104 | 105 | val_dataset = data.fashionmnist_dataset(False, 106 | transform = transforms.Compose([ 107 | transforms.ToTensor(), 108 | transforms.Normalize((0.1307, ),(0.3081, )),]), 109 | target_transform=tools.transform_target, 110 | noise_rate=args.noise_rate, 111 | split_percentage=args.split_percentage, 112 | seed=args.seed) 113 | 114 | 115 | test_dataset = data.fashionmnist_test_dataset( 116 | transform = transforms.Compose([ 117 | transforms.ToTensor(), 118 | transforms.Normalize((0.1307, ),(0.3081, )),]), 119 | target_transform=tools.transform_target) 120 | if args.dataset=='cifar10': 121 | args.num_classes = 10 122 | args.feature_size = 3 * 32 * 32 123 | args.n_epoch_1, args.n_epoch_2, args.n_epoch_3 = 5, 50, 50 124 | args.dim = 512 125 | args.basis = 20 126 | args.iteration_nmf = 10 127 | train_dataset = data.cifar10_dataset(True, 128 | transform = transforms.Compose([ 129 | transforms.ToTensor(), 130 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 131 | ]), 132 | target_transform=tools.transform_target, 133 | noise_rate=args.noise_rate, 134 | split_percentage=args.split_percentage, 135 | seed=args.seed) 136 | 137 | val_dataset = data.cifar10_dataset(False, 138 | transform = transforms.Compose([ 139 | transforms.ToTensor(), 140 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 141 | ]), 142 | target_transform=tools.transform_target, 143 | noise_rate=args.noise_rate, 144 | split_percentage=args.split_percentage, 145 | seed=args.seed) 146 | 147 | 148 | test_dataset = data.cifar10_test_dataset( 149 | transform = transforms.Compose([ 150 | transforms.ToTensor(), 151 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)), 152 | ]), 153 | target_transform=tools.transform_target) 154 | 155 | if args.dataset=='svhn': 156 | args.num_classes = 10 157 | args.feature_size = 3 * 32 * 32 158 | args.n_epoch_1, args.n_epoch_2, args.n_epoch_3 = 5, 50, 50 159 | args.dim = 512 160 | args.basis = 10 161 | train_dataset = data.svhn_dataset(True, 162 | transform = transforms.Compose([ 163 | transforms.ToTensor(), 164 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)), 165 | ]), 166 | target_transform=tools.transform_target, 167 | noise_rate=args.noise_rate, 168 | split_percentage=args.split_percentage, 169 | seed=args.seed) 170 | 171 | val_dataset = data.svhn_dataset(False, 172 | transform = transforms.Compose([ 173 | transforms.ToTensor(), 174 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)), 175 | ]), 176 | target_transform=tools.transform_target, 177 | noise_rate=args.noise_rate, 178 | split_percentage=args.split_percentage, 
179 | seed=args.seed) 180 | 181 | 182 | test_dataset = data.svhn_test_dataset( 183 | transform = transforms.Compose([ 184 | transforms.ToTensor(), 185 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)), 186 | ]), 187 | target_transform=tools.transform_target) 188 | 189 | 190 | # mkdir 191 | model_save_dir = args.model_dir + '/' + args.dataset + '/' + 'noise_rate_%s'%(args.noise_rate) 192 | 193 | if not os.path.exists(model_save_dir): 194 | os.system('mkdir -p %s'%(model_save_dir)) 195 | 196 | save_dir = args.result_dir +'/' +args.dataset+'/%s/' % args.model_type 197 | 198 | if not os.path.exists(save_dir): 199 | os.system('mkdir -p %s' % save_dir) 200 | 201 | model_str = args.dataset + '_%s_' % args.model_type + args.noise_type + '_' + str(args.noise_rate) 202 | 203 | txtfile = save_dir + "/" + model_str + ".txt" 204 | nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 205 | if os.path.exists(txtfile): 206 | os.system('mv %s %s' % (txtfile, txtfile+".bak-%s" % nowTime)) 207 | 208 | 209 | def norm(T): 210 | row_sum = np.sum(T, 1) 211 | T_norm = T / row_sum 212 | return T_norm 213 | 214 | 215 | def train_m(V, r, k, e): 216 | 217 | m, n = np.shape(V) 218 | W = np.mat(np.random.random((m, r))) 219 | H = np.mat(np.random.random((r, n))) 220 | data = [] 221 | 222 | for x in range(k): 223 | V_pre = np.dot(W, H) 224 | E = V - V_pre 225 | err = 0.0 226 | err = np.sum(np.square(E)) 227 | data.append(err) 228 | if err < e: # threshold 229 | break 230 | 231 | a = np.dot(W.T, V) # Hkj 232 | b = np.dot(np.dot(W.T, W), H) 233 | 234 | for i_1 in range(r): 235 | for j_1 in range(n): 236 | if b[i_1, j_1] != 0: 237 | H[i_1, j_1] = H[i_1, j_1] * a[i_1, j_1] / b[i_1, j_1] 238 | 239 | c = np.dot(V, H.T) 240 | d = np.dot(np.dot(W, H), H.T) 241 | for i_2 in range(m): 242 | for j_2 in range(r): 243 | if d[i_2, j_2] != 0: 244 | W[i_2, j_2] = W[i_2, j_2] * c[i_2, j_2] / d[i_2, j_2] 245 | 246 | 247 | 248 | W = norm(W) 249 | 250 | 251 | return W, H, data 252 | 253 | def accuracy(logit, target, topk=(1,)): 254 | """Computes the precision@k for the specified values of k""" 255 | output = F.softmax(logit, dim=1) 256 | maxk = max(topk) 257 | batch_size = target.size(0) 258 | 259 | _, pred = output.topk(maxk, 1, True, True) 260 | pred = pred.t() 261 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 262 | 263 | res = [] 264 | for k in topk: 265 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 266 | res.append(correct_k.mul_(100.0 / batch_size)) 267 | return res 268 | 269 | # Train the Model 270 | 271 | def train(model, train_loader, epoch, optimizer, criterion): 272 | print('Training %s...' 
% model_str) 273 | 274 | train_total=0 275 | train_correct=0 276 | 277 | 278 | for i, (data, labels) in enumerate(train_loader): 279 | data = data.cuda() 280 | labels = labels.cuda() 281 | 282 | # Forward + Backward + Optimize 283 | optimizer.zero_grad() 284 | _, logits=model(data, revision=False) 285 | prec1, = accuracy(logits, labels, topk=(1, )) 286 | train_total+=1 287 | train_correct+=prec1 288 | loss = criterion(logits, labels) 289 | loss.backward() 290 | optimizer.step() 291 | 292 | if (i+1) % args.print_freq == 0: 293 | print('Epoch [%d/%d], Iter [%d/%d] Training Accuracy: %.4F, Loss: %.4f' 294 | %(epoch+1, args.n_epoch_1, i+1, len(train_dataset)//batch_size, prec1, loss.item())) 295 | 296 | train_acc=float(train_correct)/float(train_total) 297 | 298 | return train_acc 299 | 300 | def train_correction(model, train_loader, epoch, optimizer, W_group, basis_matrix_group, batch_size, num_classes, basis): 301 | print('Training %s...' % model_str) 302 | 303 | train_total=0 304 | train_correct=0 305 | 306 | for i, (data, labels) in enumerate(train_loader): 307 | loss = 0. 308 | data = data.cuda() 309 | labels = labels.cuda() 310 | 311 | # Forward + Backward + Optimize 312 | optimizer.zero_grad() 313 | _, logits=model(data, revision=False) 314 | 315 | logits_ = F.softmax(logits, dim=1) 316 | logits_correction_total = torch.zeros(len(labels), num_classes) 317 | for j in range(len(labels)): 318 | idx = i * batch_size + j 319 | matrix = matrix_combination(basis_matrix_group, W_group, idx, num_classes, basis) 320 | matrix = torch.from_numpy(matrix).float().cuda() 321 | logits_single = logits_[j, :].unsqueeze(0) 322 | logits_correction = logits_single.mm(matrix) 323 | pro1 = logits_single[:, labels[j]] 324 | pro2 = logits_correction[:, labels[j]] 325 | beta = Variable(pro1/pro2, requires_grad=True) 326 | logits_correction = torch.log(logits_correction+1e-12) 327 | logits_single = torch.log(logits_single + 1e-12) 328 | loss_ = beta * F.nll_loss(logits_single, labels[j].unsqueeze(0)) 329 | loss += loss_ 330 | logits_correction_total[j, :] = logits_correction 331 | logits_correction_total = logits_correction_total.cuda() 332 | loss = loss / len(labels) 333 | prec1, = accuracy(logits_correction_total, labels, topk=(1, )) 334 | train_total+=1 335 | train_correct+=prec1 336 | loss.backward() 337 | optimizer.step() 338 | 339 | if (i+1) % args.print_freq == 0: 340 | print('Epoch [%d/%d], Iter [%d/%d] Training Accuracy: %.4F, Loss: %.4f' 341 | %(epoch+1, args.n_epoch_2, i+1, len(train_dataset)//batch_size, prec1, loss.item())) 342 | 343 | train_acc=float(train_correct)/float(train_total) 344 | return train_acc 345 | 346 | def val_correction(model, val_loader, epoch, W_group, basis_matrix_group, batch_size, num_classes, basis): 347 | print('Validating %s...' % model_str) 348 | 349 | val_total=0 350 | val_correct=0 351 | 352 | loss_total = 0. 353 | for i, (data, labels) in enumerate(val_loader): 354 | 355 | data = data.cuda() 356 | labels = labels.cuda() 357 | 358 | # Forward + Backward + Optimize 359 | loss = 0. 
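# Per example: rebuild its transition matrix as a coefficient-weighted sum of the basis matrices (matrix_combination), row-normalize it, map the clean softmax output through it, and weight the NLL of the given noisy label by beta = clean probability / corrected probability of that label.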
360 | _, logits=model(data, revision=False) 361 | 362 | logits_ = F.softmax(logits, dim=1) 363 | logits_correction_total = torch.zeros(len(labels), num_classes) 364 | for j in range(len(labels)): 365 | idx = i * batch_size + j 366 | matrix = matrix_combination(basis_matrix_group, W_group, idx, num_classes, basis) 367 | matrix = norm(matrix) 368 | matrix = torch.from_numpy(matrix).float().cuda() 369 | 370 | logits_single = logits_[j, :].unsqueeze(0) 371 | logits_correction = logits_single.mm(matrix) 372 | pro1 = logits_single[:, labels[j]] 373 | pro2 = logits_correction[:, labels[j]] 374 | beta = Variable(pro1/pro2, requires_grad=False) 375 | logits_correction = torch.log(logits_correction+1e-8) 376 | loss_ = beta * F.nll_loss(logits_correction, labels[j].unsqueeze(0)) 377 | if torch.isnan(loss_) == True: 378 | loss_ = 0. 379 | loss += loss_ 380 | logits_correction_total[j, :] = logits_correction 381 | 382 | logits_correction_total = logits_correction_total.cuda() 383 | loss = loss / len(labels) 384 | prec1, = accuracy(logits_correction_total, labels, topk=(1, )) 385 | val_total+=1 386 | val_correct+=prec1 387 | 388 | loss_total += loss.item() 389 | 390 | if (i+1) % args.print_freq == 0: 391 | print('Epoch [%d/%d], Iter [%d/%d] Training Accuracy: %.4F, Loss: %.4f' 392 | %(epoch+1, args.n_epoch_2, i+1, len(train_dataset)//batch_size, prec1, loss.item())) 393 | 394 | val_acc=float(val_correct)/float(val_total) 395 | 396 | return val_acc 397 | 398 | 399 | def train_revision(model, train_loader, epoch, optimizer, W_group, basis_matrix_group, batch_size, num_classes, basis): 400 | print('Training %s...' % model_str) 401 | 402 | train_total=0 403 | train_correct=0 404 | 405 | for i, (data, labels) in enumerate(train_loader): 406 | 407 | data = data.cuda() 408 | labels = labels.cuda() 409 | loss = 0. 
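# Revision stage: the learnable T_revision weight returned as 'revision' is added to each example's combined transition matrix and re-normalized (tools.norm) before the correction, so the slack variable is fine-tuned jointly with the classifier.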
410 | # Forward + Backward + Optimize 411 | optimizer.zero_grad() 412 | _, logits, revision = model(data, revision=True) 413 | 414 | 415 | logits_ = F.softmax(logits, dim=1) 416 | logits_correction_total = torch.zeros(len(labels), num_classes) 417 | for j in range(len(labels)): 418 | idx = i * batch_size + j 419 | matrix = matrix_combination(basis_matrix_group, W_group, idx, num_classes, basis) 420 | matrix = torch.from_numpy(matrix).float().cuda() 421 | matrix = tools.norm(matrix + revision) 422 | 423 | logits_single = logits_[j, :].unsqueeze(0) 424 | logits_correction = logits_single.mm(matrix) 425 | pro1 = logits_single[:, labels[j]] 426 | pro2 = logits_correction[:, labels[j]] 427 | beta = pro1/ pro2 428 | logits_correction = torch.log(logits_correction+1e-12) 429 | logits_single = torch.log(logits_single+1e-12) 430 | loss_ = beta * F.nll_loss(logits_single, labels[j].unsqueeze(0)) 431 | loss += loss_ 432 | logits_correction_total[j, :] = logits_correction 433 | logits_correction_total = logits_correction_total.cuda() 434 | loss = loss / len(labels) 435 | prec1, = accuracy(logits_correction_total, labels, topk=(1, )) 436 | train_total+=1 437 | train_correct+=prec1 438 | 439 | loss.backward() 440 | optimizer.step() 441 | 442 | if (i+1) % args.print_freq == 0: 443 | print('Epoch [%d/%d], Iter [%d/%d] Train Accuracy: %.4F, Loss: %.4f' 444 | %(epoch+1, args.n_epoch_3, i+1, len(train_dataset)//batch_size, prec1, loss.item())) 445 | 446 | train_acc=float(train_correct)/float(train_total) 447 | return train_acc 448 | 449 | 450 | def val_revision(model, train_loader, epoch, W_group, basis_matrix_group, batch_size, num_classes, basis): 451 | 452 | val_total=0 453 | val_correct=0 454 | 455 | for i, (data, labels) in enumerate(train_loader): 456 | model.eval() 457 | data = data.cuda() 458 | labels = labels.cuda() 459 | loss = 0. 460 | # Forward + Backward + Optimize 461 | 462 | _, logits, revision = model(data, revision=True) 463 | 464 | logits_ = F.softmax(logits, dim=1) 465 | logits_correction_total = torch.zeros(len(labels), num_classes) 466 | for j in range(len(labels)): 467 | idx = i * batch_size + j 468 | matrix = matrix_combination(basis_matrix_group, W_group, idx, num_classes, basis) 469 | matrix = torch.from_numpy(matrix).float().cuda() 470 | matrix = tools.norm(matrix + revision) 471 | logits_single = logits_[j, :].unsqueeze(0) 472 | logits_correction = logits_single.mm(matrix) 473 | pro1 = logits_single[:, labels[j]] 474 | pro2 = logits_correction[:, labels[j]] 475 | beta = Variable(pro1/pro2, requires_grad=True) 476 | logits_correction = torch.log(logits_correction+1e-12) 477 | loss_ = beta * F.nll_loss(logits_correction, labels[j].unsqueeze(0)) 478 | loss += loss_ 479 | logits_correction_total[j, :] = logits_correction 480 | logits_correction_total = logits_correction_total.cuda() 481 | prec1, = accuracy(logits_correction_total, labels, topk=(1, )) 482 | val_total+=1 483 | val_correct+=prec1 484 | if (i+1) % args.print_freq == 0: 485 | print('Epoch [%d/%d], Iter [%d/%d] Val Accuracy: %.4F, Loss: %.4f' 486 | %(epoch+1, args.n_epoch_3, i+1, len(val_dataset)//batch_size, prec1, loss.item())) 487 | 488 | val_acc = float(val_correct)/float(val_total) 489 | 490 | return val_acc 491 | 492 | 493 | 494 | 495 | 496 | 497 | # Evaluate the Model 498 | def evaluate(test_loader, model): 499 | print('Evaluating %s...' % model_str) 500 | model.eval() # Change model to 'eval' mode. 
501 | correct1 = 0 502 | total1 = 0 503 | for data, labels in test_loader: 504 | 505 | data = data.cuda() 506 | _, logits = model(data, revision=False) 507 | outputs = F.softmax(logits, dim=1) 508 | _, pred1 = torch.max(outputs.data, 1) 509 | total1 += labels.size(0) 510 | correct1 += (pred1.cpu() == labels.long()).sum() 511 | 512 | acc = 100*float(correct1)/float(total1) 513 | 514 | return acc 515 | 516 | 517 | def respresentations_extract(train_loader, model, num_sample, dim_respresentations, batch_size): 518 | 519 | model.eval() 520 | A = torch.rand(num_sample, dim_respresentations) 521 | ind = int(num_sample / batch_size) 522 | with torch.no_grad(): 523 | for i, (data, labels) in enumerate(train_loader): 524 | data = data.cuda() 525 | logits, _ = model(data, revision=False) 526 | if i < ind: 527 | A[i*batch_size:(i+1)*batch_size, :] = logits 528 | else: 529 | A[ind*batch_size:, :] = logits 530 | 531 | return A.cpu().numpy() 532 | 533 | 534 | def probability_extract(train_loader, model, num_sample, num_classes, batch_size): 535 | 536 | model.eval() 537 | A = torch.rand(num_sample, num_classes) 538 | ind = int(num_sample / batch_size) 539 | with torch.no_grad(): 540 | for i, (data, labels) in enumerate(train_loader): 541 | data = data.cuda() 542 | _ , logits = model(data, revision=False) 543 | logits = F.softmax(logits, dim=1) 544 | if i < ind: 545 | A[i*batch_size:(i+1)*batch_size, :] = logits 546 | else: 547 | A[ind*batch_size:, :] = logits 548 | 549 | return A.cpu().numpy() 550 | 551 | 552 | 553 | def estimate_matrix(logits_matrix, model_save_dir): 554 | transition_matrix_group = np.empty((args.basis, args.num_classes, args.num_classes)) 555 | idx_matrix_group = np.empty((args.num_classes, args.basis)) 556 | a = np.linspace(97, 99, args.basis) 557 | a = list(a) 558 | for i in range(len(a)): 559 | percentage = a[i] 560 | index = int(i) 561 | logits_matrix_ = copy.deepcopy(logits_matrix) 562 | transition_matrix, idx = tools.fit(logits_matrix_, args.num_classes, percentage, True) 563 | transition_matrix = norm(transition_matrix) 564 | idx_matrix_group[:, index] = np.array(idx) 565 | transition_matrix_group[index] = transition_matrix 566 | idx_group_save_dir = model_save_dir + '/' + 'idx_group.npy' 567 | group_save_dir = model_save_dir + '/' + 'T_group.npy' 568 | np.save(idx_group_save_dir, idx_matrix_group) 569 | np.save(group_save_dir, transition_matrix_group) 570 | return idx_matrix_group, transition_matrix_group 571 | 572 | def basis_matrix_optimize(model, optimizer, basis, num_classes, W_group, transition_matrix_group, idx_matrix_group, func, model_save_dir, epochs): 573 | basis_matrix_group = np.empty((basis, num_classes, num_classes)) 574 | 575 | for i in range(num_classes): 576 | 577 | model = tools.init_params(model) 578 | for epoch in range(epochs): 579 | loss_total = 0. 
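# For class i, fit the basis matrices so that, combined with the anchor example's part coefficients W, they match each percentile-based estimate of that class's transition-matrix row T under the loss 'func' (MSE in main.py); training stops early once the accumulated loss drops below 0.02.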
580 | for j in range(basis): 581 | class_1_idx = int(idx_matrix_group[i, j]) 582 | W = list(np.array(W_group[class_1_idx, :])) 583 | T = torch.from_numpy(transition_matrix_group[j, i, :][:, np.newaxis]).float() 584 | prediction = model(W[0], num_classes) 585 | optimizer.zero_grad() 586 | loss = func(prediction, T) 587 | loss.backward() 588 | optimizer.step() 589 | loss_total += loss 590 | if loss_total < 0.02: 591 | break 592 | 593 | for x in range(basis): 594 | parameters = np.array(model.basis_matrix[x].weight.data) 595 | 596 | basis_matrix_group[x, i, :] = parameters 597 | A_save_dir = model_save_dir + '/' + 'A.npy' 598 | np.save(A_save_dir, basis_matrix_group) 599 | return basis_matrix_group 600 | 601 | 602 | def matrix_combination(basis_matrix_group, W_group, idx, num_classes, basis): 603 | coefficient = W_group[idx, :] 604 | 605 | M = np.zeros((num_classes, num_classes)) 606 | for i in range(basis): 607 | 608 | temp = float(coefficient[0, i]) * basis_matrix_group[i, :, :] 609 | M += temp 610 | for i in range(M.shape[0]): 611 | for j in range(M.shape[1]): 612 | if M[i,j]<1e-6: 613 | M[i,j] = 0. 614 | return M 615 | 616 | 617 | 618 | def main(): 619 | # Data Loader (Input Pipeline) 620 | print('loading dataset...') 621 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 622 | batch_size=batch_size, 623 | num_workers=args.num_workers, 624 | drop_last=False, 625 | shuffle=False) 626 | 627 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 628 | batch_size=batch_size, 629 | num_workers=args.num_workers, 630 | drop_last=False, 631 | shuffle=False) 632 | 633 | 634 | 635 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 636 | batch_size=batch_size, 637 | num_workers=args.num_workers, 638 | drop_last=False, 639 | shuffle=False) 640 | # Define models 641 | print('building model...') 642 | if args.dataset == 'mnist': 643 | clf1 = LeNet() 644 | if args.dataset == 'fashionmnist': 645 | clf1 = resnet.ResNet18_F(10) 646 | if args.dataset == 'cifar10': 647 | clf1 = resnet.ResNet34(10) 648 | if args.dataset == 'svhn': 649 | clf1 = resnet.ResNet34(10) 650 | 651 | clf1.cuda() 652 | optimizer = torch.optim.SGD(clf1.parameters(), lr=args.lr, weight_decay=args.weight_decay) 653 | 654 | with open(txtfile, "a") as myfile: 655 | myfile.write('epoch train_acc val_acc test_acc\n') 656 | 657 | epoch = 0 658 | train_acc = 0 659 | val_acc = 0 660 | # evaluate models with random weights 661 | test_acc=evaluate(test_loader, clf1) 662 | print('Epoch [%d/%d] Test Accuracy on the %s test data: Model1 %.4f %%' % (epoch+1, args.n_epoch_1, len(test_dataset), test_acc)) 663 | # save results 664 | with open(txtfile, "a") as myfile: 665 | myfile.write(str(int(epoch)) + ' ' + str(train_acc) + ' ' + str(val_acc) + ' ' + str(test_acc) + ' ' + "\n") 666 | 667 | 668 | best_acc = 0.0 669 | # training 670 | for epoch in range(1, args.n_epoch_1): 671 | # train models 672 | clf1.train() 673 | train_acc = train(clf1, train_loader, epoch, optimizer, nn.CrossEntropyLoss()) 674 | # validation 675 | val_acc = evaluate(val_loader, clf1) 676 | # evaluate models 677 | test_acc = evaluate(test_loader, clf1) 678 | 679 | 680 | # save results 681 | print('Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%' % (epoch+1, args.n_epoch_1, len(train_dataset), train_acc)) 682 | print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' % (epoch+1, args.n_epoch_1, len(val_dataset), val_acc)) 683 | print('Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' % (epoch+1, 
args.n_epoch_1, len(test_dataset), test_acc)) 684 | with open(txtfile, "a") as myfile: 685 | myfile.write(str(int(epoch)) + ' ' + str(train_acc) + ' ' + str(val_acc) + ' ' + str(test_acc) + ' ' + "\n") 686 | 687 | if val_acc > best_acc: 688 | best_acc = val_acc 689 | torch.save(clf1.state_dict(), model_save_dir + '/'+ 'model.pth') 690 | 691 | print('Running matrix factorization...') 692 | clf1.load_state_dict(torch.load(model_save_dir + '/'+ 'model.pth')) 693 | A = respresentations_extract(train_loader, clf1, len(train_dataset), args.dim, batch_size) 694 | A_val = respresentations_extract(val_loader, clf1, len(val_dataset), args.dim, batch_size) 695 | A_total = np.append(A, A_val, axis=0) 696 | W_total, H_total, error = train_m(A_total, args.basis, args.iteration_nmf, 1e-5) 697 | for i in range(W_total.shape[0]): 698 | for j in range(W_total.shape[1]): 699 | if W_total[i,j]<1e-6: 700 | W_total[i,j] = 0. 701 | W = W_total[0:len(train_dataset), :] 702 | W_val = W_total[len(train_dataset):, :] 703 | print('Estimating transition matrices... Please wait...') 704 | logits_matrix = probability_extract(train_loader, clf1, len(train_dataset), args.num_classes, batch_size) 705 | idx_matrix_group, transition_matrix_group = estimate_matrix(logits_matrix, model_save_dir) 706 | logits_matrix_val = probability_extract(val_loader, clf1, len(val_dataset), args.num_classes, batch_size) 707 | idx_matrix_group_val, transition_matrix_group_val = estimate_matrix(logits_matrix_val, model_save_dir) 708 | func = nn.MSELoss() 709 | 710 | model = Matrix_optimize(args.basis, args.num_classes) 711 | optimizer_1 = torch.optim.Adam(model.parameters(), lr=0.001) 712 | basis_matrix_group = basis_matrix_optimize(model, optimizer_1, args.basis, args.num_classes, W, 713 | transition_matrix_group, idx_matrix_group, func, model_save_dir, args.n_epoch_4) 714 | 715 | basis_matrix_group_val = basis_matrix_optimize(model, optimizer_1, args.basis, args.num_classes, W_val, 716 | transition_matrix_group_val, idx_matrix_group_val, func, model_save_dir, args.n_epoch_4) 717 | 718 | for i in range(basis_matrix_group.shape[0]): 719 | for j in range(basis_matrix_group.shape[1]): 720 | for k in range(basis_matrix_group.shape[2]): 721 | if basis_matrix_group[i, j, k] < 1e-6: 722 | basis_matrix_group[i, j, k] = 0.
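# Loss-correction training (n_epoch_2) follows: the classifier is retrained with the learned part-dependent transition matrices, and the revision stage (n_epoch_3) afterwards fine-tunes the T_revision slack with the small learning rate lr_revision.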
723 | 724 | optimizer_ = torch.optim.SGD(clf1.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum) 725 | 726 | 727 | best_acc = 0.0 728 | for epoch in range(1, args.n_epoch_2): 729 | # train model 730 | clf1.train() 731 | 732 | train_acc = train_correction(clf1, train_loader, epoch, optimizer_, W, basis_matrix_group, batch_size, args.num_classes, args.basis) 733 | # validation 734 | val_acc = val_correction(clf1, val_loader, epoch, W_val, basis_matrix_group_val, batch_size, args.num_classes, args.basis) 735 | 736 | # evaluate models 737 | test_acc = evaluate(test_loader, clf1) 738 | if val_acc > best_acc: 739 | best_acc = val_acc 740 | torch.save(clf1.state_dict(), model_save_dir + '/'+ 'model.pth') 741 | with open(txtfile, "a") as myfile: 742 | myfile.write(str(int(epoch)) + ' ' + str(train_acc) + ' ' + str(val_acc) + ' ' + str(test_acc) + ' ' + "\n") 743 | # save results 744 | print('Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%' % (epoch+1, args.n_epoch_2, len(train_dataset), train_acc)) 745 | print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' % (epoch+1, args.n_epoch_2, len(val_dataset), val_acc)) 746 | print('Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' % (epoch+1, args.n_epoch_2, len(test_dataset), test_acc)) 747 | 748 | clf1.load_state_dict(torch.load(model_save_dir + '/'+ 'model.pth')) 749 | optimizer_r = torch.optim.Adam(clf1.parameters(), lr=args.lr_revision, weight_decay=args.weight_decay) 750 | nn.init.constant_(clf1.T_revision.weight, 0.0) 751 | 752 | for epoch in range(1, args.n_epoch_3): 753 | # train models 754 | clf1.train() 755 | train_acc = train_revision(clf1, train_loader, epoch, optimizer_r, W, basis_matrix_group, batch_size, args.num_classes, args.basis) 756 | # validation 757 | val_acc = val_revision(clf1, val_loader, epoch, W_val, basis_matrix_group, batch_size, args.num_classes, args.basis) 758 | # evaluate models 759 | test_acc = evaluate(test_loader, clf1) 760 | with open(txtfile, "a") as myfile: 761 | myfile.write(str(int(epoch)) + ' ' + str(train_acc) + ' ' + str(val_acc) + ' ' + str(test_acc) + ' ' + "\n") 762 | 763 | # save results 764 | print('Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%' % (epoch+1, args.n_epoch_3, len(train_dataset), train_acc)) 765 | print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' % (epoch+1, args.n_epoch_3, len(val_dataset), val_acc)) 766 | print('Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' % (epoch+1, args.n_epoch_3, len(test_dataset), test_acc)) 767 | 768 | if __name__=='__main__': 769 | main() 770 | --------------------------------------------------------------------------------