├── .gitignore
├── README.md
├── fake_data.csv
├── main.py
├── training_mentornet
│   ├── data_generator.py
│   ├── models.py
│   ├── reader.py
│   └── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.ipynb
3 | *.model
4 | .ipynb_checkpoints
5 | *.p
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MentorNet
2 | A PyTorch implementation.
3 |
4 | Related paper:
5 | **MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks on Corrupted Labels**
7 |
8 | Lu Jiang, Zhengyuan Zhou, Thomas Leung, Li-Jia Li, Li Fei-Fei
9 |
10 | Presented at [ICML 2018](https://icml.cc/Conferences/2018)
11 |
12 | Related code:
13 | [MentorNet(google)](https://github.com/google/mentornet)
14 |
15 |
16 | ## Usage (how to train MentorNet alone, NOT jointly with a StudentNet)
17 |
18 | - First, train your student model on a noisy dataset for which you also have a corresponding clean version of the labels.
19 | - Store each sample's id, epoch, noisy label, clean label, and loss (on the noisy label) in a CSV file, one space-separated row per record:
20 | ```
21 | 'id' 'epoch' 'noisy label' 'clean label' 'loss on the noisy label'
22 | ...
23 | ```
24 | A sample CSV file is provided as `fake_data.csv`; a minimal logging sketch is given after this list.
25 | - To preprocess the CSV file, run:
26 | ```
27 | python main.py --process_data=true --raw_csv="/path/to/raw.csv" --data_path="/path/to/save"
28 | ```
29 |
30 | - To train MentorNet on a processed dataset, run:
31 | ```
32 | python main.py --process_data=false --processed_path="/path/to/blah_percentile_40" --epoch=10 --device="cpu" --batch_size=32 --show_progress_bar=false
33 | ```
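34 |
35 | For reference, here is a minimal sketch of how such a CSV could be produced while training the student model. The model, dataloader, and function name here are hypothetical; only the space-separated row format `id epoch noisy_label clean_label loss` is what `data_generator.py` expects:
36 | ```
37 | import csv
38 | import torch.nn.functional as F
39 |
40 | def dump_student_losses(model, dataloader, num_epochs, out_csv='student_losses.csv'):
41 |     with open(out_csv, 'w', newline='') as f:
42 |         writer = csv.writer(f)
43 |         row_id = 0
44 |         for epoch in range(num_epochs):
45 |             epoch_pct = int(100 * epoch / num_epochs)  # epoch as a 0-99 percentage
46 |             for x, noisy_y, clean_y in dataloader:
47 |                 # per-sample loss on the noisy labels
48 |                 losses = F.cross_entropy(model(x), noisy_y, reduction='none')
49 |                 for ny, cy, l in zip(noisy_y, clean_y, losses):
50 |                     # one space-separated record per row, as in fake_data.csv
51 |                     writer.writerow([f'{row_id} {epoch_pct} {ny.item()} {cy.item()} {l.item()}'])
52 |                     row_id += 1
53 | ```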
34 |
35 |
36 | UPDATE:
37 | - 8.2.2020: added the `MentorNet_nn` class; it runs, but I am not sure whether it is correct.
38 | - 8.3.2020: added the `MentorNet` class for training `MentorNet_nn` together with a StudentNet; it runs successfully on CUDA. Likewise, I am not sure whether it is correct.
39 | - 8.4.2020: added the dataset, dataloader, and data generator. The original TF version didn't use much TF in this part, so I largely copied it over.
40 | - 8.7.2020: added the MentorNet trainer class. NOT YET TESTED.
41 | - 8.8.2020: `MentorNet_nn` can now be trained with the trainer in `train.py`. The training loss decreases, so I assume it works to some extent.
42 | - 4.30.2022: fixed issue [#1](https://github.com/Furyton/MentorNet_pytorch/issues/1#issue-1221720127), where the `upper_bound` in `utils.py` was wrong :(
43 |
--------------------------------------------------------------------------------
/fake_data.csv:
--------------------------------------------------------------------------------
1 | 0 22 2 2 0.05691285813466621
2 | 1 76 2 2 0.056815593200683634
3 | 2 60 1 1 0.12056107374301746
4 | 3 1 2 2 0.159492985471586
5 | 4 44 1 0 6.717469349361247
6 | 5 90 1 2 6.830548000255914
7 | 6 23 2 0 6.290676035999195
8 | 7 76 0 2 6.329583181283654
9 | 8 87 0 1 8.661717142980276
10 | 9 74 1 2 11.134296615736087
11 | 10 21 2 0 1.0346242189673147
12 | 11 54 1 0 2.5055557234120225
13 | 12 60 2 0 2.606250696154226
14 | 13 33 1 1 0.046921637365221695
15 | 14 22 1 1 0.11263098118565658
16 | 15 69 0 1 2.2997700224204984
17 | 16 65 0 0 0.0373146238042274
18 | 17 16 2 1 4.226040600549471
19 | 18 3 0 1 2.1201052851379556
20 | 19 52 2 0 0.3958823733131397
21 | 20 87 0 1 8.213779736066613
22 | 21 75 1 1 0.001494964358933663
23 | 22 53 1 2 5.873861773963313
24 | 23 91 2 2 0.0822743114142723
25 | 24 51 2 0 13.85085120645872
26 | 25 28 1 2 32.23975933838351
27 | 26 53 0 0 0.04190800574217698
28 | 27 85 0 0 0.032179821968826246
29 | 28 47 2 2 0.04158338452345285
30 | 29 90 1 0 7.140184943335309
31 | 30 42 1 1 0.020865305930664692
32 | 31 67 1 0 4.731765464995269
33 | 32 82 0 0 0.021521450774386428
34 | 33 38 1 2 5.740770764387923
35 | 34 15 2 1 0.5084269640505258
36 | 35 25 0 0 0.036877018953376736
37 | 36 26 2 2 0.11262093065124061
38 | 37 15 1 2 0.014155369731683451
39 | 38 43 1 0 3.5555985792814897
40 | 39 56 2 2 0.004519786207402029
41 | 40 67 1 2 2.6183489473147032
42 | 41 32 0 0 0.10131803358206182
43 | 42 70 0 0 0.1614432338564425
44 | 43 71 1 0 4.812864434391505
45 | 44 56 0 2 12.192982673556276
46 | 45 34 2 1 2.5368475292864465
47 | 46 28 0 0 0.0982702075428674
48 | 47 21 0 1 2.4557079474370593
49 | 48 5 0 1 7.883699702627938
50 | 49 84 1 0 10.236504794282741
51 | 50 69 0 1 0.04795654629913004
52 | 51 13 0 2 15.341296555343432
53 | 52 3 2 2 0.04530274093982806
54 | 53 38 2 0 1.810023426142629
55 | 54 65 1 2 6.98936511943484
56 | 55 24 2 0 21.906559878399733
57 | 56 14 1 1 0.19793085185541484
58 | 57 83 1 1 0.040657945522877804
59 | 58 60 1 1 0.03176805328540366
60 | 59 15 2 2 0.1001530071720402
61 | 60 65 0 0 0.08902024510259868
62 | 61 33 2 0 5.142237543565848
63 | 62 76 1 0 0.577218405310012
64 | 63 78 1 0 12.171434481977068
65 | 64 76 1 2 12.243057878292948
66 | 65 62 2 1 17.85465493768766
67 | 66 82 0 0 0.24268996846231058
68 | 67 33 0 0 0.04774318613669772
69 | 68 84 2 1 4.216486702193051
70 | 69 35 2 2 0.10726174246600004
71 | 70 6 0 0 0.04715024202317429
72 | 71 95 0 2 2.084648494815446
73 | 72 34 1 0 4.863903369606167
74 | 73 82 0 2 6.79074720647762
75 | 74 31 2 0 12.214015239022762
76 | 75 1 1 0 1.3310131668497027
77 | 76 9 1 2 11.981662354229353
78 | 77 53 0 0 0.10528355573677413
79 | 78 51 2 2 0.11912398273087903
80 | 79 62 0 2 2.1182673862837937
81 | 80 67 0 0 0.06209842396361868
82 | 81 64 1 0 3.6901312215305797
83 | 82 4 2 1 2.3392164515235017
84 | 83 26 2 0 2.391972242310052
85 | 84 74 2 2 0.06515422660350689
86 | 85 94 2 1 6.57543424894547
87 | 86 90 2 2 0.16321269508860758
88 | 87 69 2 2 0.045674089700306154
89 | 88 71 2 0 7.9125040784755445
90 | 89 50 0 2 14.63713403140689
91 | 90 2 0 1 6.79027110338563
92 | 91 85 2 2 0.024759614974846945
93 | 92 82 1 1 0.005775853155256345
94 | 93 83 1 0 5.897403137391079
95 | 94 61 2 0 7.265893281006673
96 | 95 56 0 1 9.022589019442128
97 | 96 58 0 1 7.983617541976239
98 | 97 93 2 0 10.9563978498973
99 | 98 11 2 0 13.395929245851262
100 | 99 83 0 0 0.12207978508741155
101 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import training_mentornet.train as train
4 | import training_mentornet.data_generator as data_generator
9 |
10 | def str2bool(v):
11 |     return v.lower() == 'true'
12 |
13 | if __name__ == '__main__':
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument('--process_data', type=str2bool)
17 | parser.add_argument('--raw_csv', type=str, default=None, help="raw csv file path")
18 | parser.add_argument('--data_path', type=str, default=None, help="where you want to save the processed dataset")
19 | parser.add_argument('--processed_path', type=str, default=None)
20 | parser.add_argument('--train_dir', type=str, default='trial')
21 | parser.add_argument('--epoch',type=int, default=10)
22 | parser.add_argument('--device',type=str,default='cpu')
23 | parser.add_argument('--lr',type=float,default=0.1)
24 | parser.add_argument('--batch_size', type=int,default=32)
25 | parser.add_argument('--show_progress_bar',type=str2bool, default=False)
26 |
27 | config = parser.parse_args()
28 |
29 | if config.process_data:
30 | data_generator.generate_data_driven(config.raw_csv, config.data_path)
31 | else:
32 |         tr = train.trainer(train_dir=config.train_dir, data_path=config.processed_path, show_progress_bar=config.show_progress_bar, epoch=config.epoch, mini_batch_size=config.batch_size, learning_rate=config.lr, device=config.device)
33 |
34 | tr.train()
--------------------------------------------------------------------------------
/training_mentornet/data_generator.py:
--------------------------------------------------------------------------------
1 | """Generates training data for learning/updating MentorNet."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | import csv
9 | import pickle
10 | import numpy as np
13 |
14 |
15 | def read_from_csv(input_csv_file):
16 | """Reads Data from an input CSV file.
17 |
18 | Args:
19 | input_csv_file: the path of the CSV file.
20 |
21 | Returns:
22 |       a dict mapping each row id to a list [epoch_percentage, noisy label, clean label, loss].
23 | """
24 | data = {}
25 | with open(input_csv_file, 'r') as csv_file_in:
26 | reader = csv.reader(csv_file_in)
27 | for row in reader:
28 |             for cell in row:
29 | rdata = cell.strip().split(' ')
30 | rid = rdata[0]
31 | rdata = [float(t) for t in rdata[1:]]
32 | data[rid] = rdata
34 | return data
35 |
36 |
37 | def generate_data_driven(input_csv_filename,
38 | outdir,
39 | percentile_range='40,50,60,70,80,90'):
40 | """Generates a data-driven trainable dataset, given a CSV.
41 |
42 | Refer to README.md for details on how to format the CSV.
43 |
44 | Args:
45 |     input_csv_filename: the path of the CSV file. After the leading id column is stripped, each row holds
46 | 0: epoch_percentage
47 | 1: noisy label
48 | 2: clean label
49 | 3: loss
50 | outdir: directory to save the training data.
51 | percentile_range: the percentiles used to compute the moving average.
52 | """
53 | raw = read_from_csv(input_csv_filename)
54 |
55 | raw = np.array([i for i in raw.values()])
56 | dataset_name = os.path.splitext(os.path.basename(input_csv_filename))[0]
57 |
58 | percentile_range = percentile_range.split(',')
59 | percentile_range = [int(x) for x in percentile_range]
60 |
61 | for percentile in percentile_range:
62 | percentile = int(percentile)
63 |         p_percentile = np.percentile(raw[:, 3], percentile)
64 |
65 | v_star = np.float32(raw[:, 1] == raw[:, 2])
66 |
67 | l = raw[:, 3]
68 |         diff = raw[:, 3] - p_percentile
69 | # label not used in the current version.
70 | y = np.array([0] * len(v_star))
71 | epoch_percentage = raw[:, 0]
72 |
73 | data = np.vstack((l, diff, y, epoch_percentage, v_star))
74 | data = np.transpose(data)
75 |
76 | perm = np.arange(data.shape[0])
77 | np.random.shuffle(perm)
78 |         data = data[perm]
79 |
80 | tr_size = int(data.shape[0] * 0.8)
81 |
82 | tr = data[0:tr_size]
83 |         ts = data[tr_size:]
84 |
85 | cur_outdir = os.path.join(
86 | outdir, '{}_percentile_{}'.format(dataset_name, percentile))
87 | if not os.path.exists(cur_outdir):
88 | os.makedirs(cur_outdir)
89 |
90 | print('training_shape={} test_shape={}'.format(tr.shape, ts.shape))
91 | print(cur_outdir)
92 | with open(os.path.join(cur_outdir, 'tr.p'), 'wb') as outfile:
93 | pickle.dump(tr, outfile)
94 |
95 | with open(os.path.join(cur_outdir, 'ts.p'), 'wb') as outfile:
96 | pickle.dump(ts, outfile)
97 |
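98 |
99 | # Usage sketch (an assumption, not part of the original pipeline): build the
100 | # training pickles from the bundled fake_data.csv; the output directory name
101 | # is illustrative.
102 | if __name__ == '__main__':
103 |     generate_data_driven('fake_data.csv', 'processed_data')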
--------------------------------------------------------------------------------
/training_mentornet/models.py:
--------------------------------------------------------------------------------
1 | # Baseline Models, not used for now
--------------------------------------------------------------------------------
/training_mentornet/reader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import torch
7 | import pickle
8 | import numpy as np
9 | import torch.utils.data as data_utils
10 |
11 | class Dataset(data_utils.Dataset):
12 |     def __init__(self, indir, split_name) -> None:
13 |         super().__init__()
14 |         # each row: [loss, loss diff to the percentile, label, epoch percentage, v_star]
15 |         with open(os.path.join(indir, split_name + '.p'), 'rb') as f:
16 |             self._data = pickle.load(f)
17 |         self._num_examples = self._data.shape[0]
18 |         self.feat_dim = self._data.shape[1] - 1
19 |
20 | def __len__(self) -> int:
21 | return self._num_examples
22 |
23 | @property
24 | def is_binary_label(self):
25 | unique_labels = np.unique(self._data[:, -1])
26 | if len(unique_labels) == 2 and (0 in unique_labels) and (
27 | 1 in unique_labels):
28 | return True
29 | return False
30 |
31 |     def __getitem__(self, index: int):
32 |         # cast to float32 so features and targets match the model's parameter dtype
33 |         return torch.tensor(self._data[index], dtype=torch.float32)
33 |
34 | def get_train_dataloader(data_path: str, device: str = 'cpu', batch_size: int = 32, worker_num: int = 2):
35 |     # note: `device` is currently unused here; the model moves tensors itself
36 |     return data_utils.DataLoader(Dataset(data_path, 'tr'), batch_size=batch_size, shuffle=True, num_workers=worker_num, pin_memory=True)
37 |
38 | def get_test_dataloader(data_path: str, device: str = 'cpu', batch_size: int = 32, worker_num: int = 2):
39 |     return data_utils.DataLoader(Dataset(data_path, 'ts'), batch_size=batch_size, num_workers=worker_num, pin_memory=True)
39 |
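40 |
41 | # Usage sketch (an assumption, not part of the original file): load a split
42 | # produced by data_generator.py; the path below is illustrative.
43 | if __name__ == '__main__':
44 |     loader = get_train_dataloader('processed_data/fake_data_percentile_40')
45 |     batch = next(iter(loader))
46 |     print(batch.shape)  # [batch_size, 5]: loss, diff, label, epoch, v_star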
--------------------------------------------------------------------------------
/training_mentornet/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import utils  # top-level utils.py; importable because main.py lives in the repo root
4 | import torch
5 | from . import reader
9 |
20 |
21 | class trainer:
22 | def __init__(self, train_dir,
23 | data_path,
24 | device='cpu',
25 | mini_batch_size=32,
26 | learning_rate=0.1,
27 | worker_num=2,
28 | epoch=2,
29 | show_progress_bar=False,
30 | is_binary_label=True):
31 |
32 |
33 | self.train_dir = train_dir
34 | self.data_path = data_path
35 | self.device = device
36 | self.mini_batch_size = mini_batch_size
37 | self.learning_rate = learning_rate
38 | self.worker_num = worker_num
39 | self.epoch = epoch
40 | self.show_progress_bar = show_progress_bar
41 | self.is_binary_label = is_binary_label
42 |
43 | if not os.path.exists(train_dir):
44 | os.makedirs(train_dir)
45 |
46 | self.train_dataLoader = reader.get_train_dataloader(data_path=data_path,
47 | device=device,
48 | batch_size=mini_batch_size,
49 | worker_num=worker_num)
50 |
51 | self.test_dataLoader = reader.get_test_dataloader(data_path=data_path,
52 | device=device,
53 | batch_size=mini_batch_size,
54 | worker_num=worker_num)
55 |
56 |         self.model = utils.MentorNet_nn(device=device)
57 | self.BCEloss = torch.nn.BCEWithLogitsLoss(reduction='mean')
58 | self.MSEloss = torch.nn.MSELoss(reduction='mean')
59 |
60 | self.optim = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
61 |
62 |         self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optim, gamma=0.9)
63 |
64 |
65 | def train(self):
66 | val_loss = self.test()
67 |
68 | for epoch in range(self.epoch):
69 | print("start training epoch: ", epoch)
70 |
71 | self.train_one_epoch(epoch)
72 |
73 | cur_loss = self.test()
74 |
75 | if cur_loss < val_loss:
76 |                 print('validation loss improved; saving best checkpoint')
77 | val_loss = cur_loss
78 | self.save('best')
79 |
80 |             self.lr_scheduler.step()
81 |
82 | self.save('final')
83 |
84 |
85 | def train_one_epoch(self, epoch):
86 | self.model.train()
87 |
88 | iterator = self.train_dataLoader if not self.show_progress_bar else tqdm(self.train_dataLoader)
89 |
90 | tot_loss = 0
91 | tot_batch = 0
92 |
93 | for batch_idx, batch in enumerate(iterator):
94 | self.optim.zero_grad()
95 |
96 | loss = self.calculate_loss(batch)
97 |
98 | tot_loss += loss.item()
99 |
100 | tot_batch += 1
101 |
102 | loss.backward()
103 |
104 | self.optim.step()
105 |
106 | if self.show_progress_bar:
107 | iterator.set_description('Epoch {}, loss {:.3f} '.format(epoch + 1, tot_loss / tot_batch))
108 |
109 | print(f'epoch: {epoch}, train loss: {tot_loss / tot_batch}')
110 |
111 | def test(self):
112 | self.model.eval()
113 |
114 | tot_loss = 0
115 | tot_batch = 0
116 |
117 | with torch.no_grad():
118 | iterator = self.test_dataLoader if not self.show_progress_bar else tqdm(self.test_dataLoader)
119 |
120 | for batch_idx, batch in enumerate(iterator):
121 | loss = self.calculate_loss(batch)
122 |
123 | tot_loss += loss.item()
124 | tot_batch += 1
125 |
126 | if self.show_progress_bar:
127 | iterator.set_description('test loss {:.3f} '.format(tot_loss / tot_batch))
128 |
129 | print('test loss=', tot_loss / tot_batch)
130 |
131 | return tot_loss / tot_batch
132 |
133 |     def calculate_loss(self, batch: torch.Tensor):
134 |         # batch columns: 0 loss, 1 loss diff, 2 label, 3 epoch percentage, 4 v_star
135 |         v_truth = batch[:, 4].reshape(-1, 1)
136 |         input_data = batch[:, 0:4]
136 |
137 | v = self.model(input_data)
138 |
139 | if self.is_binary_label:
140 | loss = self.BCEloss(v, v_truth)
141 | else:
142 | loss = self.MSEloss(torch.sigmoid(v), v_truth)
143 |
144 | return loss
145 |
146 | def save(self, tag: str):
147 | torch.save({'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optim.state_dict()}, os.path.join(self.train_dir, '{}.model'.format(tag)))
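148 |
149 |
150 | if __name__ == '__main__':
151 |     # Smoke test (an assumption, not part of the original pipeline): train on
152 |     # randomly generated data shaped like data_generator.py's output.
153 |     # Run from the repo root as: python -m training_mentornet.train
154 |     import pickle
155 |     import tempfile
156 |     import numpy as np
157 |
158 |     tmp = tempfile.mkdtemp()
159 |     for split, n in (('tr', 80), ('ts', 20)):
160 |         feats = np.random.rand(n, 4)
161 |         feats[:, 2] = 0                             # label column, always 0 here
162 |         feats[:, 3] = np.random.randint(0, 100, n)  # epoch percentage
163 |         v_star = np.random.randint(0, 2, (n, 1))    # binary target weights
164 |         with open(os.path.join(tmp, split + '.p'), 'wb') as f:
165 |             pickle.dump(np.hstack([feats, v_star]), f)
166 |
167 |     trainer(train_dir=tmp, data_path=tmp, epoch=1).train()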
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | # import torch.nn.functional as F
5 |
6 | class MentorNet_nn(torch.nn.Module):
7 | def __init__(self, label_embedding_size=2,
8 | epoch_embedding_size=5,
9 | num_fc_nodes=20,
10 | device="cpu"):
11 | """
12 | Args:
13 | label_embedding_size: the embedding size for the label feature.
14 |
15 | epoch_embedding_size: the embedding size for the epoch feature.
16 |
17 | num_fc_nodes: number of hidden nodes in the fc layer.
18 | Input:
19 | input_features: a [batch_size, 4] tensor. Each dimension corresponds to
20 | 0: loss, 1: loss difference to the moving average, 2: label and 3: epoch,
21 | where epoch is an integer between 0 and 99 (the first and the last epoch).
22 |
23 | input_feature: B x 4
24 | Output:
25 |             v: a [batch_size, 1] weight vector.
26 | """
27 | super(MentorNet_nn, self).__init__()
28 |
29 | self.device = device
30 |
31 | self.label_embedding = torch.nn.Embedding(num_embeddings=2, embedding_dim=label_embedding_size).to(device)
32 |
33 | self.epoch_embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=epoch_embedding_size).to(device)
34 |
35 | self.bi_lstm_cell = torch.nn.LSTM(input_size=2, hidden_size=1,bidirectional=True,batch_first=True,num_layers=1).to(device)
36 |
37 | self.feat_size = label_embedding_size + epoch_embedding_size + 2
38 |
39 | self.fc1 = torch.nn.Linear(self.feat_size, num_fc_nodes).to(device)
40 | self.fc2 = torch.nn.Linear(num_fc_nodes, 1, bias=True).to(device)
41 |
42 | def forward(self, input_features):
43 | input_features = input_features.to(self.device)
44 | losses = input_features[:, 0]
45 | loss_diffs = input_features[:, 1]
46 |
47 | lstm_inputs = torch.stack([losses, loss_diffs], dim=-1).to(self.device).to(torch.float32)
48 |
49 |         if len(losses.shape) <= 1:
50 |             # single-step input: reshape lstm_inputs to B x 1 x 2
51 |             lstm_inputs.unsqueeze_(1)
54 |
55 | # lstm_inputs should be B x N x 2
56 | # where N is the num_steps, B is the batch size
57 |
58 |
59 | lstm_output, _ = self.bi_lstm_cell(lstm_inputs)
60 |
61 | # lstm_output should be B x N x 2
62 | # where '2' is due to bidirectional setting
63 |
64 | loss_variance = lstm_output.sum(1) # B x 2
65 |
66 | labels = input_features[:, 2].reshape((-1, 1)).to(torch.int64)
67 |
68 | epochs = input_features[:, 3].reshape((-1, 1)).to(torch.int64)
69 |         epochs = torch.clamp(epochs, max=99)  # the epoch embedding has 100 entries
70 |
71 | # epoch_embedding.weight.requires_grad = False
72 |
73 | label_inputs = self.label_embedding(labels).squeeze(1) # B x D
74 | epoch_inputs = self.epoch_embedding(epochs).squeeze(1) # B x D
75 |
78 | feat = torch.cat([label_inputs, epoch_inputs, loss_variance], -1).to(self.device)
79 |
80 | fc_1 = self.fc1(feat)
81 | output_1 = torch.tanh(fc_1)
82 |
83 | return self.fc2(output_1)
84 |
85 |
86 | class MentorNet(torch.nn.Module):
87 | def __init__(self, burn_in_epoch=18,
88 | fixed_epoch_after_burn_in = True,
89 | loss_moving_average_decay=0.9,
90 | device="cpu"):
91 | """
92 | The MentorNet to train with the StudentNet.
93 | Args:
94 | burn_in_epoch: the number of burn_in_epoch. In the first burn_in_epoch, all samples have 1.0 weights.
95 |
96 | fixed_epoch_after_burn_in: whether to fix the epoch after the burn-in.
97 |
98 | loss_moving_average_decay: the decay factor to compute the moving average.
99 | Input:
100 | epoch: a tensor [batch_size, 1] representing the training percentage. Each epoch is an integer between 0 and 99.
101 |
102 | loss: a tensor [batch_size, 1] representing the sample loss.
103 |
104 | labels: a tensor [batch_size, 1] representing the label. Every label is set to 0 in the current version.
105 |
106 | loss_p_percentile: a 1-d tensor of size 100, where each element is the p-percentile at that epoch to compute the moving average.
107 |
108 | example_dropout_rates: a 1-d tensor of size 100, where each element is the dropout rate at that epoch. Dropping out means the probability of setting sample weights to zeros proposed in Liang, Junwei, et al. "Learning to Detect Concepts from Webly-Labeled Video Data." IJCAI. 2016.
109 | """
110 | super(MentorNet, self).__init__()
111 |
112 | self.device = device
113 |
114 | self.fixed_epoch_after_burn_in = fixed_epoch_after_burn_in
115 |
116 | self.burn_in_epoch = burn_in_epoch
117 |
118 | self.loss_moving_average_decay = loss_moving_average_decay
119 |
120 | self.mentor = MentorNet_nn(device=device)
121 |
122 | self.loss_moving_avg = None
123 |
124 | def forward(self, epoch, loss, labels, loss_p_percentile, example_dropout_rates):
125 | # epoch : B x 1
126 | # loss : B x 1
127 | # labels: B x 1
128 | # loss_p_percentile: 100
129 | # example_dropout_rates: 100
130 |
131 | burn_in_epoch = torch.tensor([[self.burn_in_epoch]] * epoch.shape[0]).to(self.device)
132 |
133 | if not self.fixed_epoch_after_burn_in:
134 | cur_epoch = epoch
135 | else:
136 | cur_epoch = epoch.min(burn_in_epoch)
137 |
138 | # cur_epoch : B x 1
139 |
140 | v_ones = torch.ones(loss.size(), dtype=torch.float32).to(self.device)
141 |
142 | v_zeros = torch.zeros(loss.size(), dtype=torch.float32).to(self.device)
143 |
144 | upper_bound = torch.where(cur_epoch < burn_in_epoch - 1, v_ones, v_zeros).to(self.device)
145 |
146 |         # TODO: risky: this assumes every sample in the batch is at the same
147 |         #       epoch, since only the first sample's dropout rate is used.
148 |         this_dropout_rate = example_dropout_rates[cur_epoch][0][0]
149 |
150 |         # TODO: risky: gathers one percentile value per sample by epoch index.
151 |         this_percentile = loss_p_percentile[cur_epoch].squeeze()
151 |
152 | percentile_loss = torch.tensor(np.percentile(loss.cpu(), this_percentile.cpu()), dtype=torch.float32).unsqueeze(-1).to(self.device)
153 |
154 | # percentile_loss : B x 1
155 |
156 | if self.loss_moving_avg is None:
157 | self.loss_moving_avg = (1 - self.loss_moving_average_decay) * percentile_loss
158 | else:
159 | self.loss_moving_avg = self.loss_moving_avg * self.loss_moving_average_decay + (1 - self.loss_moving_average_decay) * percentile_loss
160 |
161 | # loss_moving_avg : B x 1
162 |
165 | input_data = torch.stack([loss, self.loss_moving_avg, labels, cur_epoch.to(torch.float32)], 1).squeeze(-1).to(self.device)
166 |
169 | v = self.mentor(input_data).sigmoid().max(upper_bound)
170 |
173 | dropout_num = int(torch.ceil(v.size()[0] * (1 - this_dropout_rate)).item())
174 |
175 | idx = torch.tensor(random.sample(range(v.size()[0]), dropout_num), dtype=torch.int64).to(self.device)
176 |
177 | dropout_v = torch.zeros(v.size()[0]).to(self.device)
178 | dropout_v[idx] = 1
179 |
180 |         # zero out the weights of the dropped-out examples
181 |         return (v.squeeze() * dropout_v).unsqueeze(-1)
183 |
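184 |
185 | if __name__ == '__main__':
186 |     # Smoke test (an assumption, not part of the original file): push random
187 |     # batches through both modules; shapes follow the docstrings above.
188 |     torch.manual_seed(0)
189 |     B = 8
190 |     feats = torch.rand(B, 4)
191 |     feats[:, 2] = 0                                    # label, always 0 here
192 |     feats[:, 3] = torch.randint(0, 100, (B,)).float()  # epoch percentage
193 |     print(MentorNet_nn()(feats).shape)                 # torch.Size([8, 1])
194 |
195 |     mnet = MentorNet(burn_in_epoch=18)
196 |     epoch = torch.randint(0, 100, (B, 1))
197 |     loss = torch.rand(B, 1)
198 |     labels = torch.zeros(B, 1)
199 |     percentiles = torch.full((100,), 70.0)             # p-percentile per epoch
200 |     drop_rates = torch.full((100,), 0.1)               # dropout rate per epoch
201 |     print(mnet(epoch, loss, labels, percentiles, drop_rates).shape)  # torch.Size([8, 1])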
--------------------------------------------------------------------------------