├── .gitignore
├── README.md
├── fake_data.csv
├── main.py
├── training_mentornet
│   ├── data_generator.py
│   ├── models.py
│   ├── reader.py
│   └── train.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*.ipynb
*.model
.ipynb_checkpoints
*.p

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MentorNet

A PyTorch implementation.

Related paper:
**MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks on Corrupted Labels**
<br>
Lu Jiang, Zhengyuan Zhou, Thomas Leung, Li-Jia Li, Li Fei-Fei
<br>
Presented at [ICML 2018](https://icml.cc/Conferences/2018)

Related code:
[MentorNet (google)](https://github.com/google/mentornet)
## Usage (how to train MentorNet alone, NOT jointly with a StudentNet)

- First, train your student model on a noisy dataset for which you also have a corresponding clean version of the labels.
- Store the loss, epoch, and labels in a CSV file. Each row has the format below (a sketch of one way to produce this file is given at the end of this README):
```
'id' 'epoch' 'noisy label' 'clean label' 'loss on the noisy label'
...
```
A sample file is provided as `fake_data.csv`.
- To preprocess the CSV file into a trainable dataset, use this command:
```
python main.py --process_data=true --raw_csv="/path" --data_path="/save/path/to"
```
- To train MentorNet on the processed dataset, use this command:
```
python main.py --process_data=false --processed_path="/path/to/blah_percentile_40" --epoch=10 --device="cpu" --batch_size=32 --show_progress_bar=false
```

UPDATE:
- 8.2.2020: added the `MentorNet_nn` class; it runs, but I am not sure whether it is correct.
- 8.3.2020: added the `MentorNet` class for training `MentorNet_nn` together with a StudentNet; it runs successfully on CUDA. Again, I am not sure whether it is correct.
- 8.4.2020: added the dataset, dataloader, and data generator. The original TF version uses little TF-specific code in this part, so I copied it over largely as-is.
- 8.7.2020: added the MentorNet trainer class (NOT yet tested).
- 8.8.2020: `MentorNet_nn` can be trained using the trainer in `train.py`. The training loss decreases, so it appears to work to some extent.
- 4.30.2022: fixed issue [#1](https://github.com/Furyton/MentorNet_pytorch/issues/1#issue-1221720127), where `upper_bound` in `utils.py` was wrong :(
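
### Sketch: producing the loss CSV

For reference, here is one way such a CSV could be produced during student training. This is an illustrative sketch only: the `student_model` and the `(ids, inputs, noisy_labels, clean_labels)` batch layout of `dataloader` are assumptions, not part of this repository. Note that `read_from_csv` in `data_generator.py` keys rows by sample id, so each id should appear exactly once in the file.

```python
import torch
import torch.nn.functional as F

def export_loss_csv(csv_path, student_model, dataloader, epoch, device='cpu'):
    """Writes one 'id epoch noisy_label clean_label loss' row per sample."""
    student_model.eval()
    with open(csv_path, 'w') as f, torch.no_grad():
        for ids, inputs, noisy_labels, clean_labels in dataloader:
            logits = student_model(inputs.to(device))
            # Per-sample cross-entropy against the *noisy* labels.
            losses = F.cross_entropy(logits, noisy_labels.to(device),
                                     reduction='none')
            rows = zip(ids.tolist(), noisy_labels.tolist(),
                       clean_labels.tolist(), losses.tolist())
            for sid, noisy, clean, loss in rows:
                f.write('{} {} {} {} {}\n'.format(sid, epoch, noisy, clean, loss))
```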
--------------------------------------------------------------------------------
/fake_data.csv:
--------------------------------------------------------------------------------
0 22 2 2 0.05691285813466621
1 76 2 2 0.056815593200683634
2 60 1 1 0.12056107374301746
3 1 2 2 0.159492985471586
4 44 1 0 6.717469349361247
5 90 1 2 6.830548000255914
6 23 2 0 6.290676035999195
7 76 0 2 6.329583181283654
8 87 0 1 8.661717142980276
9 74 1 2 11.134296615736087
10 21 2 0 1.0346242189673147
11 54 1 0 2.5055557234120225
12 60 2 0 2.606250696154226
13 33 1 1 0.046921637365221695
14 22 1 1 0.11263098118565658
15 69 0 1 2.2997700224204984
16 65 0 0 0.0373146238042274
17 16 2 1 4.226040600549471
18 3 0 1 2.1201052851379556
19 52 2 0 0.3958823733131397
20 87 0 1 8.213779736066613
21 75 1 1 0.001494964358933663
22 53 1 2 5.873861773963313
23 91 2 2 0.0822743114142723
24 51 2 0 13.85085120645872
25 28 1 2 32.23975933838351
26 53 0 0 0.04190800574217698
27 85 0 0 0.032179821968826246
28 47 2 2 0.04158338452345285
29 90 1 0 7.140184943335309
30 42 1 1 0.020865305930664692
31 67 1 0 4.731765464995269
32 82 0 0 0.021521450774386428
33 38 1 2 5.740770764387923
34 15 2 1 0.5084269640505258
35 25 0 0 0.036877018953376736
36 26 2 2 0.11262093065124061
37 15 1 2 0.014155369731683451
38 43 1 0 3.5555985792814897
39 56 2 2 0.004519786207402029
40 67 1 2 2.6183489473147032
41 32 0 0 0.10131803358206182
42 70 0 0 0.1614432338564425
43 71 1 0 4.812864434391505
44 56 0 2 12.192982673556276
45 34 2 1 2.5368475292864465
46 28 0 0 0.0982702075428674
47 21 0 1 2.4557079474370593
48 5 0 1 7.883699702627938
49 84 1 0 10.236504794282741
50 69 0 1 0.04795654629913004
51 13 0 2 15.341296555343432
52 3 2 2 0.04530274093982806
53 38 2 0 1.810023426142629
54 65 1 2 6.98936511943484
55 24 2 0 21.906559878399733
56 14 1 1 0.19793085185541484
57 83 1 1 0.040657945522877804
58 60 1 1 0.03176805328540366
59 15 2 2 0.1001530071720402
60 65 0 0 0.08902024510259868
61 33 2 0 5.142237543565848
62 76 1 0 0.577218405310012
63 78 1 0 12.171434481977068
64 76 1 2 12.243057878292948
65 62 2 1 17.85465493768766
66 82 0 0 0.24268996846231058
67 33 0 0 0.04774318613669772
68 84 2 1 4.216486702193051
69 35 2 2 0.10726174246600004
70 6 0 0 0.04715024202317429
71 95 0 2 2.084648494815446
72 34 1 0 4.863903369606167
73 82 0 2 6.79074720647762
74 31 2 0 12.214015239022762
75 1 1 0 1.3310131668497027
76 9 1 2 11.981662354229353
77 53 0 0 0.10528355573677413
78 51 2 2 0.11912398273087903
79 62 0 2 2.1182673862837937
80 67 0 0 0.06209842396361868
81 64 1 0 3.6901312215305797
82 4 2 1 2.3392164515235017
83 26 2 0 2.391972242310052
84 74 2 2 0.06515422660350689
85 94 2 1 6.57543424894547
86 90 2 2 0.16321269508860758
87 69 2 2 0.045674089700306154
88 71 2 0 7.9125040784755445
89 50 0 2 14.63713403140689
90 2 0 1 6.79027110338563
91 85 2 2 0.024759614974846945
92 82 1 1 0.005775853155256345
93 83 1 0 5.897403137391079
94 61 2 0 7.265893281006673
95 56 0 1 9.022589019442128
96 58 0 1 7.983617541976239
97 93 2 0 10.9563978498973
98 11 2 0 13.395929245851262
99 83 0 0 0.12207978508741155

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse

import training_mentornet.train as train
import training_mentornet.data_generator as data_generator


def str2bool(v):
    # Interpret 'true' (case-insensitively) as True and anything else as False.
    return v.lower() == 'true'


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--process_data', type=str2bool)
    parser.add_argument('--raw_csv', type=str, default=None, help="raw csv file path")
    parser.add_argument('--data_path', type=str, default=None, help="where you want to save the processed dataset")
    parser.add_argument('--processed_path', type=str, default=None)
    parser.add_argument('--train_dir', type=str, default='trial')
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--show_progress_bar', type=str2bool, default=False)

    config = parser.parse_args()

    if config.process_data:
        data_generator.generate_data_driven(config.raw_csv, config.data_path)
    else:
        tr = train.trainer(train_dir=config.train_dir,
                           data_path=config.processed_path,
                           show_progress_bar=config.show_progress_bar,
                           epoch=config.epoch,
                           mini_batch_size=config.batch_size,
                           learning_rate=config.lr,
                           device=config.device)

        tr.train()
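
# Example invocations (paths are placeholders; see README.md):
#   python main.py --process_data=true --raw_csv=fake_data.csv --data_path=processed_data
#   python main.py --process_data=false --processed_path=processed_data/fake_data_percentile_40 --epoch=10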
--------------------------------------------------------------------------------
/training_mentornet/data_generator.py:
--------------------------------------------------------------------------------
"""Generates training data for learning/updating MentorNet."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import csv
import pickle

import numpy as np


def read_from_csv(input_csv_file):
    """Reads data from an input CSV file.

    Args:
        input_csv_file: the path of the CSV file.

    Returns:
        a dict mapping each sample id to a list
        [epoch, noisy label, clean label, loss].
    """
    data = {}
    with open(input_csv_file, 'r') as csv_file_in:
        reader = csv.reader(csv_file_in)
        for row in reader:
            for cell in row:
                rdata = cell.strip().split(' ')
                rid = rdata[0]
                rdata = [float(t) for t in rdata[1:]]
                data[rid] = rdata
    return data


def generate_data_driven(input_csv_filename,
                         outdir,
                         percentile_range='40,50,60,70,80,90'):
    """Generates a data-driven trainable dataset, given a CSV.

    Refer to README.md for details on how to format the CSV.

    Args:
        input_csv_filename: the path of the CSV file. After the id column is
            stripped off, the columns are
            0: epoch_percentage
            1: noisy label
            2: clean label
            3: loss
        outdir: directory to save the training data.
        percentile_range: the percentiles used to compute the moving average.
    """
    raw = read_from_csv(input_csv_filename)

    raw = np.array([i for i in raw.values()])
    dataset_name = os.path.splitext(os.path.basename(input_csv_filename))[0]

    percentile_range = [int(x) for x in percentile_range.split(',')]

    for percentile in percentile_range:
        p_percentile = np.percentile(raw[:, 3], percentile)

        # v_star is 1 when the noisy label agrees with the clean label.
        v_star = np.float32(raw[:, 1] == raw[:, 2])

        l = raw[:, 3]
        diff = raw[:, 3] - p_percentile
        # The label is not used in the current version.
        y = np.array([0] * len(v_star))
        epoch_percentage = raw[:, 0]

        data = np.vstack((l, diff, y, epoch_percentage, v_star))
        data = np.transpose(data)

        perm = np.arange(data.shape[0])
        np.random.shuffle(perm)
        data = data[perm]

        # 80/20 train/test split.
        tr_size = int(data.shape[0] * 0.8)

        tr = data[0:tr_size]
        ts = data[tr_size:]

        cur_outdir = os.path.join(
            outdir, '{}_percentile_{}'.format(dataset_name, percentile))
        if not os.path.exists(cur_outdir):
            os.makedirs(cur_outdir)

        print('training_shape={} test_shape={}'.format(tr.shape, ts.shape))
        print(cur_outdir)
        with open(os.path.join(cur_outdir, 'tr.p'), 'wb') as outfile:
            pickle.dump(tr, outfile)

        with open(os.path.join(cur_outdir, 'ts.p'), 'wb') as outfile:
            pickle.dump(ts, outfile)
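
# Illustrative usage sketch (assumes fake_data.csv sits in the working
# directory): generate the datasets and inspect one split. Each row is
# [loss, loss diff to the percentile, label (always 0), epoch, v_star].
if __name__ == '__main__':
    generate_data_driven('fake_data.csv', 'processed_data')
    with open('processed_data/fake_data_percentile_40/tr.p', 'rb') as f:
        tr = pickle.load(f)
    print(tr.shape)  # (80, 5) for the 100-row sample CSV with an 80/20 split
    print(tr[0])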
--------------------------------------------------------------------------------
/training_mentornet/models.py:
--------------------------------------------------------------------------------
# Baseline models, not used for now.

--------------------------------------------------------------------------------
/training_mentornet/reader.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import torch
import pickle
import numpy as np
import torch.utils.data as data_utils


class Dataset(data_utils.Dataset):
    def __init__(self, indir, split_name) -> None:
        super().__init__()
        with open(os.path.join(indir, split_name + '.p'), 'rb') as infile:
            self._data = pickle.load(infile)
        self._num_examples = self._data.shape[0]
        # The last column is the target weight v_star; the rest are features.
        self.feat_dim = self._data.shape[1] - 1

    def __len__(self) -> int:
        return self._num_examples

    @property
    def is_binary_label(self):
        unique_labels = np.unique(self._data[:, -1])
        return len(unique_labels) == 2 and (0 in unique_labels) and (1 in unique_labels)

    def __getitem__(self, index: int):
        # Cast to float32 so the targets match the model's output dtype.
        return torch.tensor(self._data[index], dtype=torch.float32)


def get_train_dataloader(data_path: str, device: str = 'cpu', batch_size: int = 32, worker_num: int = 2):
    return data_utils.DataLoader(Dataset(data_path, 'tr'), batch_size=batch_size,
                                 shuffle=True, num_workers=worker_num, pin_memory=True)


def get_test_dataloader(data_path: str, device: str = 'cpu', batch_size: int = 32, worker_num: int = 2):
    return data_utils.DataLoader(Dataset(data_path, 'ts'), batch_size=batch_size,
                                 num_workers=worker_num, pin_memory=True)
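
# Illustrative usage sketch (the data path is a placeholder):
if __name__ == '__main__':
    loader = get_train_dataloader('processed_data/fake_data_percentile_40',
                                  batch_size=4, worker_num=0)
    batch = next(iter(loader))
    # batch is B x 5: columns 0-3 are the MentorNet_nn input features
    # (loss, loss diff, label, epoch); column 4 is the target weight v_star.
    print(batch.shape, batch[:, 4])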
--------------------------------------------------------------------------------
/training_mentornet/train.py:
--------------------------------------------------------------------------------
import os

import torch
from tqdm import tqdm

import utils
from . import reader


class trainer:
    def __init__(self, train_dir,
                 data_path,
                 device='cpu',
                 mini_batch_size=32,
                 learning_rate=0.1,
                 worker_num=2,
                 epoch=2,
                 show_progress_bar=False,
                 is_binary_label=True):

        self.train_dir = train_dir
        self.data_path = data_path
        self.device = device
        self.mini_batch_size = mini_batch_size
        self.learning_rate = learning_rate
        self.worker_num = worker_num
        self.epoch = epoch
        self.show_progress_bar = show_progress_bar
        self.is_binary_label = is_binary_label

        if not os.path.exists(train_dir):
            os.makedirs(train_dir)

        self.train_dataLoader = reader.get_train_dataloader(data_path=data_path,
                                                            device=device,
                                                            batch_size=mini_batch_size,
                                                            worker_num=worker_num)

        self.test_dataLoader = reader.get_test_dataloader(data_path=data_path,
                                                          device=device,
                                                          batch_size=mini_batch_size,
                                                          worker_num=worker_num)

        # Build the model on the requested device.
        self.model = utils.MentorNet_nn(device=device)
        self.BCEloss = torch.nn.BCEWithLogitsLoss(reduction='mean')
        self.MSEloss = torch.nn.MSELoss(reduction='mean')

        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optim, gamma=0.9)

    def train(self):
        val_loss = self.test()

        for epoch in range(self.epoch):
            print("start training epoch:", epoch)

            self.train_one_epoch(epoch)

            cur_loss = self.test()

            if cur_loss < val_loss:
                print('validation loss improved')
                val_loss = cur_loss
                self.save('best')

            self.lr_scheduler.step()

        self.save('final')

    def train_one_epoch(self, epoch):
        self.model.train()

        iterator = tqdm(self.train_dataLoader) if self.show_progress_bar else self.train_dataLoader

        tot_loss = 0
        tot_batch = 0

        for batch_idx, batch in enumerate(iterator):
            self.optim.zero_grad()

            loss = self.calculate_loss(batch)

            tot_loss += loss.item()
            tot_batch += 1

            loss.backward()
            self.optim.step()

            if self.show_progress_bar:
                iterator.set_description('Epoch {}, loss {:.3f} '.format(epoch + 1, tot_loss / tot_batch))

        print(f'epoch: {epoch}, train loss: {tot_loss / tot_batch}')

    def test(self):
        self.model.eval()

        tot_loss = 0
        tot_batch = 0

        with torch.no_grad():
            iterator = tqdm(self.test_dataLoader) if self.show_progress_bar else self.test_dataLoader

            for batch_idx, batch in enumerate(iterator):
                loss = self.calculate_loss(batch)

                tot_loss += loss.item()
                tot_batch += 1

                if self.show_progress_bar:
                    iterator.set_description('test loss {:.3f} '.format(tot_loss / tot_batch))

        print('test loss=', tot_loss / tot_batch)

        return tot_loss / tot_batch

    def calculate_loss(self, batch: torch.Tensor):
        # Column 4 is the target weight v_star; columns 0-3 are the features
        # (loss, loss diff, label, epoch) expected by MentorNet_nn.
        v_truth = batch[:, 4].reshape(-1, 1).to(self.device)
        input_data = batch[:, 0:4]

        v = self.model(input_data)

        if self.is_binary_label:
            loss = self.BCEloss(v, v_truth)
        else:
            loss = self.MSEloss(torch.sigmoid(v), v_truth)

        return loss

    def save(self, tag: str):
        torch.save({'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optim.state_dict()},
                   os.path.join(self.train_dir, '{}.model'.format(tag)))
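
# Illustrative usage sketch: train MentorNet directly from Python. Because of
# the relative import above, run this as a module from the repository root,
# e.g. `python -m training_mentornet.train`. The data path is a placeholder.
if __name__ == '__main__':
    tr = trainer(train_dir='trial',
                 data_path='processed_data/fake_data_percentile_40',
                 device='cpu',
                 epoch=10,
                 mini_batch_size=32)
    tr.train()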
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import torch


class MentorNet_nn(torch.nn.Module):
    def __init__(self, label_embedding_size=2,
                 epoch_embedding_size=5,
                 num_fc_nodes=20,
                 device="cpu"):
        """
        Args:
            label_embedding_size: the embedding size for the label feature.

            epoch_embedding_size: the embedding size for the epoch feature.

            num_fc_nodes: number of hidden nodes in the fc layer.
        Input:
            input_features: a [batch_size, 4] tensor. Each dimension corresponds to
            0: loss, 1: loss difference to the moving average, 2: label and 3: epoch,
            where epoch is an integer between 0 and 99 (the first and the last epoch).
        Output:
            v: a [batch_size, 1] weight vector.
        """
        super(MentorNet_nn, self).__init__()

        self.device = device

        self.label_embedding = torch.nn.Embedding(num_embeddings=2, embedding_dim=label_embedding_size).to(device)

        self.epoch_embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=epoch_embedding_size).to(device)

        self.bi_lstm_cell = torch.nn.LSTM(input_size=2, hidden_size=1, bidirectional=True, batch_first=True, num_layers=1).to(device)

        self.feat_size = label_embedding_size + epoch_embedding_size + 2

        self.fc1 = torch.nn.Linear(self.feat_size, num_fc_nodes).to(device)
        self.fc2 = torch.nn.Linear(num_fc_nodes, 1, bias=True).to(device)

    def forward(self, input_features):
        input_features = input_features.to(self.device)
        losses = input_features[:, 0]
        loss_diffs = input_features[:, 1]

        lstm_inputs = torch.stack([losses, loss_diffs], dim=-1).to(self.device).to(torch.float32)

        # lstm_inputs should be B x N x 2, where N is the number of steps and
        # B is the batch size; add the step dimension for 1-D inputs.
        if len(losses.shape) <= 1:
            lstm_inputs = lstm_inputs.unsqueeze(1)

        lstm_output, _ = self.bi_lstm_cell(lstm_inputs)

        # lstm_output is B x N x 2, where '2' is due to the bidirectional setting.
        loss_variance = lstm_output.sum(1)  # B x 2

        labels = input_features[:, 2].reshape((-1, 1)).to(torch.int64)

        # Clamp the epoch feature to the embedding range [0, 99].
        epochs = input_features[:, 3].reshape((-1, 1)).to(torch.int64)
        epochs = torch.min(epochs, torch.ones([epochs.size()[0], 1], dtype=torch.int64).to(self.device) * 99).to(self.device)

        label_inputs = self.label_embedding(labels).squeeze(1)  # B x D
        epoch_inputs = self.epoch_embedding(epochs).squeeze(1)  # B x D

        feat = torch.cat([label_inputs, epoch_inputs, loss_variance], -1).to(self.device)

        fc_1 = self.fc1(feat)
        output_1 = torch.tanh(fc_1)

        return self.fc2(output_1)

class MentorNet(torch.nn.Module):
    def __init__(self, burn_in_epoch=18,
                 fixed_epoch_after_burn_in=True,
                 loss_moving_average_decay=0.9,
                 device="cpu"):
        """
        The MentorNet to train with the StudentNet.
        Args:
            burn_in_epoch: the number of burn-in epochs. During the burn-in, all samples have weight 1.0.

            fixed_epoch_after_burn_in: whether to fix the epoch feature after the burn-in.

            loss_moving_average_decay: the decay factor to compute the moving average.
        Input:
            epoch: a tensor [batch_size, 1] representing the training percentage. Each epoch is an integer between 0 and 99.

            loss: a tensor [batch_size, 1] representing the sample loss.

            labels: a tensor [batch_size, 1] representing the label. Every label is set to 0 in the current version.

            loss_p_percentile: a 1-d tensor of size 100, where each element is the p-percentile at that epoch to compute the moving average.

            example_dropout_rates: a 1-d tensor of size 100, where each element is the dropout rate at that epoch. Dropping out means the probability of setting sample weights to zero, as proposed in Liang, Junwei, et al. "Learning to Detect Concepts from Webly-Labeled Video Data." IJCAI. 2016.
        """
        super(MentorNet, self).__init__()

        self.device = device

        self.fixed_epoch_after_burn_in = fixed_epoch_after_burn_in

        self.burn_in_epoch = burn_in_epoch

        self.loss_moving_average_decay = loss_moving_average_decay

        self.mentor = MentorNet_nn(device=device)

        self.loss_moving_avg = None

    def forward(self, epoch, loss, labels, loss_p_percentile, example_dropout_rates):
        # epoch : B x 1
        # loss : B x 1
        # labels: B x 1
        # loss_p_percentile: 100
        # example_dropout_rates: 100

        burn_in_epoch = torch.tensor([[self.burn_in_epoch]] * epoch.shape[0]).to(self.device)

        if not self.fixed_epoch_after_burn_in:
            cur_epoch = epoch
        else:
            cur_epoch = epoch.min(burn_in_epoch)

        # cur_epoch : B x 1

        v_ones = torch.ones(loss.size(), dtype=torch.float32).to(self.device)
        v_zeros = torch.zeros(loss.size(), dtype=torch.float32).to(self.device)

        # During the burn-in, the weights are forced up to 1.0 via the elementwise max below.
        upper_bound = torch.where(cur_epoch < burn_in_epoch - 1, v_ones, v_zeros).to(self.device)

        # TODO: dangerous here; this uses only the first sample's epoch for the whole batch.
        this_dropout_rate = example_dropout_rates[cur_epoch][0][0]

        # TODO: dangerous here
        this_percentile = loss_p_percentile[cur_epoch].squeeze()

        percentile_loss = torch.tensor(np.percentile(loss.cpu(), this_percentile.cpu()), dtype=torch.float32).unsqueeze(-1).to(self.device)

        # percentile_loss : B x 1

        # Exponential moving average of the percentile loss.
        if self.loss_moving_avg is None:
            self.loss_moving_avg = (1 - self.loss_moving_average_decay) * percentile_loss
        else:
            self.loss_moving_avg = self.loss_moving_avg * self.loss_moving_average_decay + (1 - self.loss_moving_average_decay) * percentile_loss

        # loss_moving_avg : B x 1

        input_data = torch.stack([loss, self.loss_moving_avg, labels, cur_epoch.to(torch.float32)], 1).squeeze(-1).to(self.device)

        v = self.mentor(input_data).sigmoid().max(upper_bound)

        # Randomly keep a (1 - dropout_rate) fraction of the sample weights and zero the rest.
        dropout_num = int(torch.ceil(v.size()[0] * (1 - this_dropout_rate)).item())

        idx = torch.tensor(random.sample(range(v.size()[0]), dropout_num), dtype=torch.int64).to(self.device)

        dropout_v = torch.zeros(v.size()[0]).to(self.device)
        dropout_v[idx] = 1

        return (v.squeeze() * dropout_v).unsqueeze(-1)
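
# Illustrative sketch of how the learned weights could plug into a student
# training step. Everything below (the random losses, the epoch value) is a
# placeholder; this repository itself only trains MentorNet on pre-computed
# losses.
if __name__ == '__main__':
    mentor = MentorNet_nn()
    batch_size = 8

    losses = torch.rand(batch_size)          # per-sample student losses
    loss_diffs = losses - losses.mean()      # difference to a moving average
    labels = torch.zeros(batch_size)         # label feature (always 0 here)
    epochs = torch.full((batch_size,), 5.0)  # current epoch, in [0, 99]

    input_features = torch.stack([losses, loss_diffs, labels, epochs], dim=1)
    v = torch.sigmoid(mentor(input_features))  # B x 1 weights in (0, 1)

    # A student step would then minimize the weighted mean loss.
    weighted_loss = (v.squeeze(1) * losses).mean()
    print(v.squeeze(1))
    print(weighted_loss)
--------------------------------------------------------------------------------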