├── data
│   ├── snomed_label_to_meta_grouping.json
│   ├── snomed_labels.json
│   ├── snomed_to_meta_map.json
│   └── snomed_labels_to_name.json
├── README.md
├── csu_data.py
├── util.py
└── csu_model.py
/data/snomed_label_to_meta_grouping.json:
--------------------------------------------------------------------------------
1 | {"0": [12, 34, 13],
2 | "1": [40],
3 | "2": [19, 17, 26, 24, 3, 11, 0, 23, 27],
4 | "3": [18, 21, 28, 20, 4],
5 | "4": [14],
6 | "5": [29],
7 | "6": [25, 33],
8 | "7": [30],
9 | "8": [35],
10 | "9": [15, 39],
11 | "10": [31],
12 | "11": [7, 9],
13 | "12": [2, 10, 38],
14 | "13": [36, 37],
15 | "14": [5, 22],
16 | "15": [6], "16": [1, 41],
17 | "17": [8, 16, 32]}
18 |
--------------------------------------------------------------------------------
/data/snomed_labels.json:
--------------------------------------------------------------------------------
1 | ["414916001", "37064009", "404177007", "363246002", "405538007", "32895009", "414025005", "362972006", "85983004", "173300003", "414032001", "2492009", "68843000", "17322007", "74732009", "422400008", "75478009", "414029004", "271737000", "85828009", "362970003", "414022008", "66091009", "420134006", "75934005", "362966006", "473010000", "362969004", "414026006", "118940003", "49601007", "42030000", "417163006", "128127008", "40733004", "50043002", "105969002", "928000", "128598002", "53619000", "399981008", "404684003"]
2 |
--------------------------------------------------------------------------------
/data/snomed_to_meta_map.json:
--------------------------------------------------------------------------------
1 | {"414026006": "4", "362970003": "4", "74732009": "5", "42030000": "11", "128127008": "7", "173300003": "12", "50043002": "9", "473010000": "3", "417163006": "18", "49601007": "8", "40733004": "1", "128598002": "13", "405538007": "4", "2492009": "3", "362972006": "12", "420134006": "3", "414032001": "13", "105969002": "14", "37064009": "17", "414022008": "4", "271737000": "4", "118940003": "6", "404177007": "13", "66091009": "15", "85828009": "3", "928000": "14", "75934005": "3", "414025005": "16", "53619000": "10", "414916001": "3", "32895009": "15", "75478009": "18", "422400008": "10", "399981008": "2", "414029004": "3", "362969004": "3", "68843000": "1", "404684003": "17", "362966006": "7", "17322007": "1", "85983004": "18", "363246002": "3"}
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepTag
2 | This repository contains the source code for the DeepTag project.
3 |
4 | Unfortunately, due to a data sharing agreement, we are not at liberty to share our data.
5 |
6 | However, we are able to share our experiment framework code. Note that the randomness in cuDNN (the GPU deep-learning library we use) cannot be completely fixed, so the training process is stochastic. Replicating our numbers exactly is therefore difficult; we ran each experiment 5 times and report the averages.
7 |
8 | `csu_model.py` is a long file that sets up the `Trainer` and `Experiment` classes. `Trainer` manages a `Classifier`: it trains and loads classifiers. `Experiment` manages the entire experiment and its folder structure, automatically recording each random run and computing average statistics. The code is modular and can be reused in other ML research projects.
9 |
10 | `csu_data.py` preprocesses the data.
11 |
12 | `Learning_to_Reject.ipynb` is the Jupyter Notebook that we used to compute the abstention algorithm. It loads previously saved output from a trained classifier and uses it to train a new abstention model. The plots in the paper are generated from this notebook.
13 |
14 | `data/snomed_label_to_meta_grouping.json` contains the label (disease) similarity grouping that we defined. We hope this list will be of general value to people working with SNOMED-CT disease-level codes.
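As a quick, minimal sketch of how the files under `data/` fit together (assuming Python and the repository root as the working directory): `snomed_labels.json` lists the 42 SNOMED-CT codes in label-index order, `snomed_labels_to_name.json` lists their human-readable names in the same order, and `snomed_label_to_meta_grouping.json` maps each meta-category id to the label indices it contains.

```python
import json

with open('data/snomed_labels.json') as f:
    codes = json.load(f)        # 42 SNOMED-CT codes, label indices 0..41
with open('data/snomed_labels_to_name.json') as f:
    names = json.load(f)        # human-readable names, same index order
with open('data/snomed_label_to_meta_grouping.json') as f:
    grouping = json.load(f)     # meta-category id -> list of label indices

for meta_id in sorted(grouping, key=int):
    members = ', '.join('{} [{}]'.format(names[i], codes[i]) for i in grouping[meta_id])
    print('meta group {}: {}'.format(meta_id, members))
```

`snomed_to_meta_map.json` offers the code-keyed view, mapping each SNOMED code string to a meta-category id; note that its ids appear to be 1-indexed ("1" through "18"), while the grouping keys above are 0-indexed ("0" through "17").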
15 |
--------------------------------------------------------------------------------
/data/snomed_labels_to_name.json:
--------------------------------------------------------------------------------
1 | ["Obesity (disorder)", "Hyperproteinemia (disorder)", "Angioedema and/or urticaria (disorder)", "Nutritional deficiency associated condition (disorder)", "Spontaneous hemorrhage (disorder)", "Hereditary disease (disorder)", "Disorder of fetus or newborn (disorder)", "Disorder of labor / delivery (disorder)", "Disorder caused by exposure to ionizing radiation (disorder)", "Disorder of pregnancy (disorder)", "Disorder of pigmentation (disorder)", "Nutritional disorder (disorder)", "Disease caused by Arthropod (disorder)", "Disease caused by parasite (disorder)", "Mental disorder (disorder)", "Vomiting (disorder)", "Poisoning (disorder)", "Disorder of immune function (disorder)", "Anemia (disorder)", "Autoimmune disease (disorder)", "Disorder of hemostatic system (disorder)", "Disorder of cellular component of blood (disorder)", "Congenital disease (disorder)", "Propensity to adverse reactions (disorder)", "Metabolic disease (disorder)", "Disorder of auditory system (disorder)", "Hypersensitivity condition (disorder)", "Disorder of endocrine system (disorder)", "Disorder of hematopoietic cell proliferation (disorder)", "Disorder of nervous system (disorder)", "Disorder of cardiovascular system (disorder)", "Disorder of the genitourinary system (disorder)", "Traumatic AND/OR non-traumatic injury (disorder)", "Visual system disorder (disorder)", "Infectious disease (disorder)", "Disorder of respiratory system (disorder)", "Disorder of connective tissue (disorder)", "Disorder of musculoskeletal system (disorder)", "Disorder of integument (disorder)", "Disorder of digestive system (disorder)", "Neoplasm and/or hamartoma (disorder)", "Clinical finding (finding)"]
2 |
--------------------------------------------------------------------------------
/csu_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Process top-level SNOMED codes
3 | """
4 |
5 | """
6 | Takes in raw data and preprocesses it into a format
7 | suitable for TorchText to train on
8 | """
9 |
10 | import re
11 | import numpy as np
12 |
13 | """
14 | We train without diagnosis, and with multilabel
15 | """
16 |
17 | np.random.seed(1234)
18 |
19 | # ======== Split =========
20 | train_size = 0.9
21 |
22 | assert (train_size < 1 and train_size > 0)
23 | split_proportions = {
24 |     "train": train_size,
25 |     "valid": (1 - train_size) / 2,
26 |     "test": (1 - train_size) / 2
27 | }
28 | assert (sum([split_proportions[split] for split in split_proportions]) == 1)
29 |
30 | print("the data split is: {}".format(split_proportions))
31 |
32 | inflating_test_set = True
33 |
34 | # maybe not predicting 17 (it's a catch-all disease)
35 |
36 | def write_to_tsv(data, file_name, label_list):
37 |     # we are translating labels here
38 |     with open(file_name, 'wb') as f:
39 |         for line in data:
40 |             mapped_labels = [str(label_list.index(l)) for l in line[1].split()]
41 |             f.write(line[0] + '\t' + " ".join(mapped_labels) + '\n')
42 |
43 |
44 | def 
count_freq(list_labels): 45 | dic = {} 46 | for l in list_labels: 47 | if l not in dic: 48 | dic[l] = 1 49 | else: 50 | dic[l] += 1 51 | return dic 52 | 53 | 54 | def get_most_freq_label(dic): 55 | most_f_l = None 56 | most_f_f = 0. 57 | for l, f in dic.iteritems(): 58 | if f > most_f_f: 59 | most_f_l = l 60 | return most_f_l 61 | 62 | def collapse_label(labels): 63 | # Note: for SNOMED we no longer take out category 17 (no longer exist) 64 | labels = labels.strip() 65 | # labels = labels.replace('17', '') 66 | list_labels = filter(lambda l: len(l) > 0, labels.split('-')) 67 | # if len(list_labels) == 0: 68 | # list_labels = ['17'] # meaning it only has 17 69 | set_labels = set(list_labels) # remove redundancies 70 | return list(set_labels) 71 | 72 | def cleanhtml(raw_html): 73 | cleanr = re.compile('<.*?>') 74 | cleantext = re.sub(cleanr, '', raw_html) 75 | cleantext = re.sub(r'^https?:\/\/.*[\r\n]*', '', cleantext, flags=re.MULTILINE) 76 | return cleantext 77 | 78 | # TODO: 2. Preserve things like "Texas A&M", the ampersand in the middle 79 | def preprocess_text(text, no_description): 80 | no_html = cleanhtml(text) 81 | one_white_space = ' '.join(no_html.split()) 82 | no_html_entities = re.sub('&[a-z]+;', '', one_white_space) 83 | 84 | if no_description: 85 | # delete both diagnosis and discharge status 86 | no_html_entities = no_html_entities.split('Diagnosis:')[0] 87 | 88 | return no_html_entities 89 | 90 | 91 | if __name__ == '__main__': 92 | header = True 93 | 94 | examples = [] 95 | labels_dist = [] 96 | with open("../../data/csu/final_csu_file_snomed", 'r') as f: 97 | for line in f: 98 | if header: 99 | header = False 100 | continue 101 | columns = line.split('\t') 102 | labels = columns[-1] 103 | 104 | text = preprocess_text(columns[4], no_description=True) 105 | 106 | seq_labels = collapse_label(labels) 107 | labels_dist.extend(seq_labels) 108 | # start from 0, and also join back to " " separation 109 | examples.append([text, " ".join(seq_labels)]) 110 | 111 | # import matplotlib.pyplot as plt 112 | # 113 | # n, bins, patches = plt.hist(labels_dist, 50, normed=1, facecolor='green', alpha=0.75) 114 | # plt.show() 115 | 116 | import csv 117 | with open("../../data/csu/Files_for_parsing/snomed_ICD_mapped.csv", 'r') as f: 118 | csv_reader = csv.reader(f, delimiter=';') 119 | snomed_code_to_name = {} 120 | for row in csv_reader: 121 | snomed_code_to_name[row[0]] = row[1] 122 | 123 | labels_dist = count_freq(labels_dist) 124 | 125 | print("number of labels is {}".format(len(labels_dist))) 126 | 127 | with open("../../data/csu/snomed_dist.csv", 'wb') as f: 128 | for k, v in labels_dist.items(): 129 | f.write(snomed_code_to_name[k] + "," + str(v) + "\n") 130 | 131 | labels_prob = map(lambda t: (t[0], float(t[1]) / sum(labels_dist.values())), labels_dist.items()) 132 | 133 | labels_prob = sorted(labels_prob, key=lambda t: t[1]) 134 | 135 | print "code, n, p" 136 | for k, prob in labels_prob: 137 | print "{}, {}, {}".format(k, labels_dist[k], prob) 138 | 139 | label_list = [t[0] for t in labels_prob] 140 | 141 | # process them into tsv format, but also collect frequency distribution 142 | serial_numbers = range(len(examples)) 143 | np.random.shuffle(serial_numbers) 144 | 145 | train_numbers = serial_numbers[:int(np.rint(len(examples) * split_proportions['train']))] 146 | valid_numbers = serial_numbers[ 147 | int(np.rint(len(examples) * split_proportions['train'])): \ 148 | int(np.rint(len(examples) * (split_proportions['train'] + split_proportions['valid'])))] 149 | test_numbers = 
serial_numbers[ 150 | int(np.rint(len(examples) * (split_proportions['train'] + split_proportions['valid']))):] 151 | 152 | print( 153 | "train/valid/test number of examples: {}/{}/{}".format(len(train_numbers), len(valid_numbers), 154 | len(test_numbers))) 155 | train, valid, test = [], [], [] 156 | 157 | for tn in train_numbers: 158 | train.append(examples[tn]) 159 | for tn in valid_numbers: 160 | valid.append(examples[tn]) 161 | for tn in test_numbers: 162 | test.append(examples[tn]) 163 | 164 | write_to_tsv(train, "../../data/csu/snomed_multi_label_no_des_train.tsv", label_list) 165 | write_to_tsv(valid, "../../data/csu/snomed_multi_label_no_des_valid.tsv", label_list) 166 | write_to_tsv(test, "../../data/csu/snomed_multi_label_no_des_test.tsv", label_list) 167 | 168 | import json 169 | with open('../../data/csu/snomed_labels.json', 'wb') as f: 170 | json.dump(label_list, f) 171 | 172 | names = [snomed_code_to_name[l] for l in label_list] 173 | # index matches 0 to 41 174 | with open('../../data/csu/snomed_labels_to_name.json', 'wb') as f: 175 | json.dump(names, f) 176 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import six 3 | from torchtext.data.field import RawField, Field 4 | from torchtext.vocab import Vocab 5 | from torchtext.data.pipeline import Pipeline 6 | from collections import Counter, OrderedDict 7 | from torch.autograd import Variable 8 | from torchtext.data.dataset import Dataset 9 | from torch.nn import Module 10 | 11 | 12 | class MultiMarginHierarchyLoss(Module): 13 | r"""Creates a criterion that optimizes a multi-class classification hinge 14 | loss (margin-based loss) between input `x` (a 2D mini-batch `Tensor`) and 15 | output `y` (which is a 1D tensor of target class indices, 16 | `0` <= `y` <= `x.size(1)`): 17 | 18 | For each mini-batch sample:: 19 | 20 | loss(x, y) = sum_i(max(0, (margin - x[y] + x[i]))^p) / x.size(0) 21 | where `i == 0` to `x.size(0)` and `i != y`. 22 | 23 | Optionally, you can give non-equal weighting on the classes by passing 24 | a 1D `weight` tensor into the constructor. 25 | 26 | The loss function then becomes: 27 | 28 | loss(x, y) = sum_i(max(0, w[y] * (margin - x[y] - x[i]))^p) / x.size(0) 29 | 30 | By default, the losses are averaged over observations for each minibatch. 31 | However, if the field `size_average` is set to ``False``, 32 | the losses are instead summed. 
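For instance, with ``p = 1``, ``margin = 1``, a single sample ``x = [0.1, 0.5, 0.3]`` and target ``y = 1``, the unweighted loss above works out to ``(max(0, 1 - 0.5 + 0.1) + max(0, 1 - 0.5 + 0.3)) / 3 = 1.4 / 3 ≈ 0.47``.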
33 | """ 34 | 35 | def __init__(self, neighbor_maps, class_size=2, p=1, neighbor_margin=0.5, 36 | margin=1, weight=None, size_average=True): 37 | super(MultiMarginHierarchyLoss, self).__init__() 38 | if p != 1 and p != 2: 39 | raise ValueError("only p == 1 and p == 2 supported") 40 | assert weight is None or weight.dim() == 1 41 | self.class_size = class_size 42 | self.neighbor_maps = neighbor_maps 43 | self.neighbor_margin = neighbor_margin 44 | self.p = p 45 | self.margin = margin 46 | self.size_average = size_average 47 | self.weight = weight 48 | 49 | def forward(self, input, target): 50 | # return multi_margin_loss(input, target, self.p, self.margin, 51 | # self.weight, self.size_average) 52 | batch_size = input.size(0) 53 | y_indices = target.nonzero() 54 | for b in range(batch_size): 55 | tgt_labels = y_indices[b, :] 56 | l = [self.neighbor_maps[str(tgt_label)] for tgt_label in tgt_labels] 57 | neighbor_inds = [item for sublist in l for item in sublist] # for entire group 58 | 59 | # compute loss 60 | for i in range(self.class_size): 61 | pass 62 | 63 | return 64 | 65 | 66 | # we also implement the latest BCELoss without reduction 67 | 68 | def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True, reduce=True): 69 | r"""Function that measures Binary Cross Entropy between target and output 70 | logits. 71 | See :class:`~torch.nn.BCEWithLogitsLoss` for details. 72 | Args: 73 | input: Variable of arbitrary shape 74 | target: Variable of the same shape as input 75 | weight (Variable, optional): a manual rescaling weight 76 | if provided it's repeated to match input tensor shape 77 | size_average (bool, optional): By default, the losses are averaged 78 | over observations for each minibatch. However, if the field 79 | sizeAverage is set to False, the losses are instead summed 80 | for each minibatch. Default: ``True`` 81 | reduce (bool, optional): By default, the losses are averaged or summed over 82 | observations for each minibatch depending on size_average. When reduce 83 | is False, returns a loss per input/target element instead and ignores 84 | size_average. Default: True 85 | Examples:: 86 | >>> input = torch.randn(3, requires_grad=True) 87 | >>> target = torch.FloatTensor(3).random_(2) 88 | >>> loss = F.binary_cross_entropy_with_logits(input, target) 89 | >>> loss.backward() 90 | """ 91 | if not (target.size() == input.size()): 92 | raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size())) 93 | 94 | max_val = (-input).clamp(min=0) 95 | loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log() 96 | 97 | if weight is not None: 98 | loss = loss * weight 99 | 100 | if not reduce: 101 | return loss 102 | elif size_average: 103 | return loss.mean() 104 | else: 105 | return loss.sum() 106 | 107 | 108 | class BCEWithLogitsLoss(Module): 109 | r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single 110 | class. This version is more numerically stable than using a plain `Sigmoid` 111 | followed by a `BCELoss` as, by combining the operations into one layer, 112 | we take advantage of the log-sum-exp trick for numerical stability. 113 | The loss can be described as: 114 | .. math:: 115 | \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad 116 | l_n = - w_n \left[ t_n \cdot \log \sigma(x_n) 117 | + (1 - t_n) \cdot \log (1 - \sigma(x_n)) \right], 118 | where :math:`N` is the batch size. If reduce is ``True``, then 119 | .. 
math:: 120 | \ell(x, y) = \begin{cases} 121 | \operatorname{mean}(L), & \text{if}\; \text{size_average} = \text{True},\\ 122 | \operatorname{sum}(L), & \text{if}\; \text{size_average} = \text{False}. 123 | \end{cases} 124 | This is used for measuring the error of a reconstruction in for example 125 | an auto-encoder. Note that the targets `t[i]` should be numbers 126 | between 0 and 1. 127 | Args: 128 | weight (Tensor, optional): a manual rescaling weight given to the loss 129 | of each batch element. If given, has to be a Tensor of size 130 | "nbatch". 131 | size_average (bool, optional): By default, the losses are averaged 132 | over observations for each minibatch. However, if the field 133 | size_average is set to ``False``, the losses are instead summed for 134 | each minibatch. Default: ``True`` 135 | reduce (bool, optional): By default, the losses are averaged or summed over 136 | observations for each minibatch depending on size_average. When reduce 137 | is False, returns a loss per input/target element instead and ignores 138 | size_average. Default: True 139 | Shape: 140 | - Input: :math:`(N, *)` where `*` means, any number of additional 141 | dimensions 142 | - Target: :math:`(N, *)`, same shape as the input 143 | Examples:: 144 | >>> loss = nn.BCEWithLogitsLoss() 145 | >>> input = torch.randn(3, requires_grad=True) 146 | >>> target = torch.FloatTensor(3).random_(2) 147 | >>> output = loss(input, target) 148 | >>> output.backward() 149 | """ 150 | 151 | def __init__(self, weight=None, size_average=True, reduce=True): 152 | super(BCEWithLogitsLoss, self).__init__() 153 | self.size_average = size_average 154 | self.reduce = reduce 155 | self.register_buffer('weight', weight) 156 | 157 | def forward(self, input, target): 158 | if self.weight is not None: 159 | var = Variable(self.weight) if not isinstance(self.weight, Variable) else self.weight 160 | return binary_cross_entropy_with_logits(input, target, 161 | var, 162 | self.size_average, 163 | reduce=self.reduce) 164 | else: 165 | return binary_cross_entropy_with_logits(input, target, 166 | size_average=self.size_average, 167 | reduce=self.reduce) 168 | 169 | 170 | class ReversibleField(Field): 171 | def __init__(self, **kwargs): 172 | if kwargs.get('tokenize') is list: 173 | self.use_revtok = False 174 | else: 175 | self.use_revtok = True 176 | if kwargs.get('tokenize') not in ('revtok', 'subword', list): 177 | kwargs['tokenize'] = 'revtok' 178 | if 'unk_token' not in kwargs: 179 | kwargs['unk_token'] = ' UNK ' 180 | super(ReversibleField, self).__init__(**kwargs) 181 | 182 | def reverse(self, batch): 183 | if self.use_revtok: 184 | try: 185 | import revtok 186 | except ImportError: 187 | print("Please install revtok.") 188 | raise 189 | if not self.batch_first: 190 | batch = batch.t() 191 | with torch.cuda.device_of(batch): 192 | batch = batch.tolist() 193 | batch = [[self.vocab.itos[ind] for ind in ex] for ex in batch] # denumericalize 194 | 195 | def trim(s, t): 196 | sentence = [] 197 | for w in s: 198 | if w == t: 199 | break 200 | sentence.append(w) 201 | return sentence 202 | 203 | batch = [trim(ex, self.eos_token) for ex in batch] # trim past frst eos 204 | 205 | def filter_special(tok): 206 | return tok not in (self.init_token, self.pad_token) 207 | 208 | batch = [filter(filter_special, ex) for ex in batch] 209 | if self.use_revtok: 210 | return [revtok.detokenize(ex) for ex in batch] 211 | return [' '.join(ex) for ex in batch] 212 | 213 | 214 | class MultiLabelField(RawField): 215 | """Defines a datatype together with 
instructions for converting to Tensor. 216 | 217 | Field class models common text processing datatypes that can be represented 218 | by tensors. It holds a Vocab object that defines the set of possible values 219 | for elements of the field and their corresponding numerical representations. 220 | The Field object also holds other parameters relating to how a datatype 221 | should be numericalized, such as a tokenization method and the kind of 222 | Tensor that should be produced. 223 | 224 | If a Field is shared between two columns in a dataset (e.g., question and 225 | answer in a QA dataset), then they will have a shared vocabulary. 226 | 227 | Attributes: 228 | sequential: Whether the datatype represents sequential data. If False, 229 | no tokenization is applied. Default: True. 230 | use_vocab: Whether to use a Vocab object. If False, the data in this 231 | field should already be numerical. Default: True. 232 | init_token: A token that will be prepended to every example using this 233 | field, or None for no initial token. Default: None. 234 | eos_token: A token that will be appended to every example using this 235 | field, or None for no end-of-sentence token. Default: None. 236 | fix_length: A fixed length that all examples using this field will be 237 | padded to, or None for flexible sequence lengths. Default: None. 238 | tensor_type: The torch.Tensor class that represents a batch of examples 239 | of this kind of data. Default: torch.LongTensor. 240 | preprocessing: The Pipeline that will be applied to examples 241 | using this field after tokenizing but before numericalizing. Many 242 | Datasets replace this attribute with a custom preprocessor. 243 | Default: None. 244 | postprocessing: A Pipeline that will be applied to examples using 245 | this field after numericalizing but before the numbers are turned 246 | into a Tensor. The pipeline function takes the batch as a list, 247 | the field's Vocab, and train (a bool). 248 | Default: None. 249 | lower: Whether to lowercase the text in this field. Default: False. 250 | tokenize: The function used to tokenize strings using this field into 251 | sequential examples. If "spacy", the SpaCy English tokenizer is 252 | used. Default: str.split. 253 | include_lengths: Whether to return a tuple of a padded minibatch and 254 | a list containing the lengths of each examples, or just a padded 255 | minibatch. Default: False. 256 | batch_first: Whether to produce tensors with the batch dimension first. 257 | Default: False. 258 | pad_token: The string token used as padding. Default: "". 259 | unk_token: The string token used to represent OOV words. Default: "". 260 | """ 261 | 262 | vocab_cls = Vocab 263 | # Dictionary mapping PyTorch tensor types to the appropriate Python 264 | # numeric type. 
265 | tensor_types = { 266 | torch.FloatTensor: float, 267 | torch.cuda.FloatTensor: float, 268 | torch.DoubleTensor: float, 269 | torch.cuda.DoubleTensor: float, 270 | torch.HalfTensor: float, 271 | torch.cuda.HalfTensor: float, 272 | 273 | torch.ByteTensor: int, 274 | torch.cuda.ByteTensor: int, 275 | torch.CharTensor: int, 276 | torch.cuda.CharTensor: int, 277 | torch.ShortTensor: int, 278 | torch.cuda.ShortTensor: int, 279 | torch.IntTensor: int, 280 | torch.cuda.IntTensor: int, 281 | torch.LongTensor: int, 282 | torch.cuda.LongTensor: int 283 | } 284 | 285 | def __init__( 286 | self, label_size, sequential=True, use_vocab=True, init_token=None, 287 | eos_token=None, fix_length=None, tensor_type=torch.LongTensor, 288 | preprocessing=None, postprocessing=None, lower=False, 289 | tokenize=(lambda s: s.split()), include_lengths=False, 290 | batch_first=False, pad_token="", unk_token=""): 291 | self.sequential = sequential 292 | self.use_vocab = use_vocab 293 | self.init_token = init_token 294 | self.eos_token = eos_token 295 | self.unk_token = unk_token 296 | self.fix_length = fix_length 297 | self.tensor_type = tensor_type 298 | self.preprocessing = preprocessing 299 | self.postprocessing = postprocessing 300 | self.lower = lower 301 | self.tokenize = tokenize 302 | self.include_lengths = include_lengths 303 | self.batch_first = batch_first 304 | self.pad_token = pad_token if self.sequential else None 305 | 306 | self.label_size = label_size 307 | 308 | def preprocess(self, x): 309 | """Load a single example using this field, tokenizing if necessary. 310 | 311 | If the input is a Python 2 `str`, it will be converted to Unicode 312 | first. If `sequential=True`, it will be tokenized. Then the input 313 | will be optionally lowercased and passed to the user-provided 314 | `preprocessing` Pipeline.""" 315 | if (six.PY2 and isinstance(x, six.string_types) and not 316 | isinstance(x, six.text_type)): 317 | x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x) 318 | # will strip and then split here! 319 | if self.sequential and isinstance(x, six.text_type): 320 | x = self.tokenize(x.rstrip('\n')) 321 | if self.lower: 322 | x = Pipeline(six.text_type.lower)(x) 323 | if self.preprocessing is not None: 324 | return self.preprocessing(x) 325 | else: 326 | return x 327 | 328 | def process(self, batch, device, train): 329 | """ Process a list of examples to create a torch.Tensor. 330 | 331 | Pad, numericalize, and postprocess a batch and create a tensor. 332 | 333 | Args: 334 | batch (list(object)): A list of object from a batch of examples. 335 | Returns: 336 | data (torch.autograd.Varaible): Processed object given the input 337 | and custom postprocessing Pipeline. 338 | """ 339 | padded = self.pad(batch) 340 | tensor = self.numericalize(padded, device=device, train=train) 341 | return tensor 342 | 343 | def pad(self, minibatch): 344 | """Pad a batch of examples using this field. 345 | 346 | Pads to self.fix_length if provided, otherwise pads to the length of 347 | the longest example in the batch. Prepends self.init_token and appends 348 | self.eos_token if those attributes are not None. Returns a tuple of the 349 | padded list and a list containing lengths of each example if 350 | `self.include_lengths` is `True` and `self.sequential` is `True`, else just 351 | returns the padded list. If `self.sequential` is `False`, no padding is applied. 
352 | """ 353 | minibatch = list(minibatch) 354 | 355 | # we handle "padding" at numericalization 356 | return minibatch 357 | # 358 | # if self.fix_length is None: 359 | # max_len = max(len(x) for x in minibatch) 360 | # else: 361 | # max_len = self.fix_length + ( 362 | # self.init_token, self.eos_token).count(None) - 2 363 | # padded, lengths = [], [] 364 | # for x in minibatch: 365 | # padded.append( 366 | # ([] if self.init_token is None else [self.init_token]) + 367 | # list(x[:max_len]) + 368 | # ([] if self.eos_token is None else [self.eos_token]) + 369 | # [self.pad_token] * max(0, max_len - len(x))) 370 | # lengths.append(len(padded[-1]) - max(0, max_len - len(x))) 371 | # if self.include_lengths: 372 | # return (padded, lengths) 373 | # return padded 374 | 375 | def build_vocab(self, *args, **kwargs): 376 | """Construct the Vocab object for this field from one or more datasets. 377 | 378 | Arguments: 379 | Positional arguments: Dataset objects or other iterable data 380 | sources from which to construct the Vocab object that 381 | represents the set of possible values for this field. If 382 | a Dataset object is provided, all columns corresponding 383 | to this field are used; individual columns can also be 384 | provided directly. 385 | Remaining keyword arguments: Passed to the constructor of Vocab. 386 | """ 387 | counter = Counter() 388 | sources = [] 389 | for arg in args: 390 | if isinstance(arg, Dataset): 391 | sources += [getattr(arg, name) for name, field in 392 | arg.fields.items() if field is self] 393 | else: 394 | sources.append(arg) 395 | for data in sources: 396 | for x in data: 397 | if not self.sequential: 398 | x = [x] 399 | counter.update(x) 400 | specials = list(OrderedDict.fromkeys( 401 | tok for tok in [self.unk_token, self.pad_token, self.init_token, 402 | self.eos_token] 403 | if tok is not None)) 404 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 405 | 406 | def numericalize(self, arr, device=None, train=True): 407 | """Turn a batch of examples that use this field into a Variable. 408 | 409 | If the field has include_lengths=True, a tensor of lengths will be 410 | included in the return value. 411 | 412 | Arguments: 413 | arr (List[List[str]], or tuple of (List[List[str]], List[int])): 414 | List of tokenized and padded examples, or tuple of List of 415 | tokenized and padded examples and List of lengths of each 416 | example if self.include_lengths is True. 417 | device (-1 or None): Device to create the Variable's Tensor on. 418 | Use -1 for CPU and None for the currently active GPU device. 419 | Default: None. 420 | train (boolean): Whether the batch is for a training set. 421 | If False, the Variable will be created with volatile=True. 422 | Default: True. 
423 | """ 424 | if self.include_lengths and not isinstance(arr, tuple): 425 | raise ValueError("Field has include_lengths set to True, but " 426 | "input data is not a tuple of " 427 | "(data batch, batch lengths).") 428 | if isinstance(arr, tuple): 429 | arr, lengths = arr 430 | lengths = torch.LongTensor(lengths) 431 | 432 | if self.use_vocab: 433 | if self.sequential: 434 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] 435 | else: 436 | arr = [self.vocab.stoi[x] for x in arr] 437 | 438 | if self.postprocessing is not None: 439 | arr = self.postprocessing(arr, self.vocab, train) 440 | else: 441 | if self.tensor_type not in self.tensor_types: 442 | raise ValueError( 443 | "Specified Field tensor_type {} can not be used with " 444 | "use_vocab=False because we do not know how to numericalize it. " 445 | "Please raise an issue at " 446 | "https://github.com/pytorch/text/issues".format(self.tensor_type)) 447 | numericalization_func = self.tensor_types[self.tensor_type] 448 | # It doesn't make sense to explictly coerce to a numeric type if 449 | # the data is sequential, since it's unclear how to coerce padding tokens 450 | # to a numeric type. 451 | if not self.sequential: 452 | arr = [numericalization_func(x) for x in arr] 453 | if self.sequential: 454 | batches = [] 455 | for x in arr: 456 | zeros = [0.] * self.label_size 457 | for l in x: 458 | zeros[int(l)] = 1. 459 | batches.append(zeros) 460 | arr = batches 461 | if self.postprocessing is not None: 462 | arr = self.postprocessing(arr, None, train) 463 | 464 | arr = self.tensor_type(arr) 465 | if device == -1: 466 | if self.sequential: 467 | arr = arr.contiguous() 468 | else: 469 | arr = arr.cuda(device) 470 | 471 | return Variable(arr, volatile=not train) 472 | -------------------------------------------------------------------------------- /csu_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Store modular components for Jupyter Notebook 3 | """ 4 | import json 5 | import numpy as np 6 | import os 7 | import csv 8 | import logging 9 | import random 10 | import math 11 | from sklearn import metrics 12 | from scipy import stats 13 | from os.path import join as pjoin 14 | from scipy.special import expit as sigmoid 15 | 16 | from collections import defaultdict 17 | from itertools import combinations, izip 18 | import torch 19 | import torch.optim as optim 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import Variable 23 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 24 | from torchtext import data 25 | from util import MultiLabelField, ReversibleField, BCEWithLogitsLoss, MultiMarginHierarchyLoss 26 | 27 | 28 | def get_ci(vals, return_range=False): 29 | if len(set(vals)) == 1: 30 | return (vals[0], vals[0]) 31 | loc = np.mean(vals) 32 | scale = np.std(vals) / np.sqrt(len(vals)) 33 | range_0, range_1 = stats.t.interval(0.95, len(vals) - 1, loc=loc, scale=scale) 34 | if return_range: 35 | return range_0, range_1 36 | else: 37 | return range_1 - loc 38 | 39 | 40 | class Config(dict): 41 | def __init__(self, **kwargs): 42 | super(Config, self).__init__(**kwargs) 43 | self.__dict__.update(**kwargs) 44 | 45 | def __setitem__(self, key, item): 46 | self.__dict__[key] = item 47 | 48 | def __getitem__(self, key): 49 | return self.__dict__[key] 50 | 51 | def __repr__(self): 52 | return repr(self.__dict__) 53 | 54 | def __len__(self): 55 | return len(self.__dict__) 56 | 57 | def __delitem__(self, key): 58 | del self.__dict__[key] 
59 | 60 | def clear(self): 61 | return self.__dict__.clear() 62 | 63 | def copy(self): 64 | return self.__dict__.copy() 65 | 66 | def has_key(self, k): 67 | return k in self.__dict__ 68 | 69 | def update(self, *args, **kwargs): 70 | return self.__dict__.update(*args, **kwargs) 71 | 72 | def keys(self): 73 | return self.__dict__.keys() 74 | 75 | def values(self): 76 | return self.__dict__.values() 77 | 78 | def items(self): 79 | return self.__dict__.items() 80 | 81 | def pop(self, *args): 82 | return self.__dict__.pop(*args) 83 | 84 | def __cmp__(self, dict_): 85 | return self.__cmp__(self.__dict__, dict_) 86 | 87 | def __contains__(self, item): 88 | return item in self.__dict__ 89 | 90 | def __iter__(self): 91 | return iter(self.__dict__) 92 | 93 | def __unicode__(self): 94 | return unicode(repr(self.__dict__)) 95 | 96 | 97 | # then we can make special class for different types of model 98 | # each config is used to build a classifier and a trainer, so one for each 99 | class LSTMBaseConfig(Config): 100 | def __init__(self, emb_dim=100, hidden_size=512, depth=1, label_size=42, bidir=False, 101 | c=False, m=False, co=False, 102 | dropout=0.2, emb_update=True, clip_grad=5., seed=1234, 103 | rand_unk=True, run_name="default", emb_corpus="gigaword", avg_run_times=1, 104 | conv_enc=0, 105 | **kwargs): 106 | # run_name: the folder for the trainer 107 | # c: cluster, m: meta, co: co-occurence constraint 108 | super(LSTMBaseConfig, self).__init__(emb_dim=emb_dim, 109 | hidden_size=hidden_size, 110 | depth=depth, 111 | label_size=label_size, 112 | bidir=bidir, 113 | c=c, 114 | m=m, 115 | co=co, 116 | dropout=dropout, 117 | emb_update=emb_update, 118 | clip_grad=clip_grad, 119 | seed=seed, 120 | rand_unk=rand_unk, 121 | run_name=run_name, 122 | emb_corpus=emb_corpus, 123 | avg_run_times=avg_run_times, 124 | conv_enc=conv_enc, 125 | **kwargs) 126 | 127 | 128 | class LSTM_w_C_Config(LSTMBaseConfig): 129 | def __init__(self, sigma_M, sigma_B, sigma_W, **kwargs): 130 | super(LSTM_w_C_Config, self).__init__(sigma_M=sigma_M, 131 | sigma_B=sigma_B, 132 | sigma_W=sigma_W, 133 | c=True, 134 | **kwargs) 135 | 136 | 137 | class LSTM_w_M_Config(LSTMBaseConfig): 138 | def __init__(self, beta, **kwargs): 139 | super(LSTM_w_M_Config, self).__init__(beta=beta, m=True, **kwargs) 140 | 141 | 142 | class LSTM_w_Co_config(LSTMBaseConfig): 143 | def __init__(self, x_max=100, alpha=0.75, gamma=1e-3, use_csu=True, 144 | use_pp=False, glove=False, ppmi=False, 145 | **kwargs): 146 | """ 147 | :param x_max: int (default: 100) 148 | Words with frequency greater than this are given weight 1.0. 149 | Words with frequency under this are given weight (c/xmax)**alpha 150 | where c is their count in mat (see the paper, eq. (9)). 151 | :param alpha: float (default: 0.75) 152 | Exponent in the weighting function (see the paper, eq. (9)). 
153 | :param gamma: float(default=1e-3) 154 | The strength of this penalty 155 | :param use_csu: use co-occurence frequency from CSU 156 | :param use_pp: use co-occurence frequency from PP 157 | :param glove: use GlOVE style loss, otherwise 158 | :param ppmi: we want this to be false because if it's negative, then we want that too 159 | :param kwargs: 160 | """ 161 | super(LSTM_w_Co_config, self).__init__(co=True, 162 | x_max=x_max, 163 | alpha=alpha, 164 | gamma=gamma, 165 | use_csu=use_csu, 166 | use_pp=use_pp, 167 | glove=glove, 168 | ppmi=ppmi, 169 | **kwargs) 170 | 171 | 172 | """ 173 | Hierarchical ConvNet 174 | """ 175 | 176 | 177 | class ConvNetEncoder(nn.Module): 178 | def __init__(self, config): 179 | super(ConvNetEncoder, self).__init__() 180 | 181 | self.word_emb_dim = config['word_emb_dim'] 182 | self.enc_lstm_dim = config['enc_lstm_dim'] 183 | 184 | self.convnet1 = nn.Sequential( 185 | nn.Conv1d(self.word_emb_dim, 2 * self.enc_lstm_dim, kernel_size=3, 186 | stride=1, padding=1), 187 | nn.ReLU(inplace=True), 188 | ) 189 | self.convnet2 = nn.Sequential( 190 | nn.Conv1d(2 * self.enc_lstm_dim, 2 * self.enc_lstm_dim, kernel_size=3, 191 | stride=1, padding=1), 192 | nn.ReLU(inplace=True), 193 | ) 194 | self.convnet3 = nn.Sequential( 195 | nn.Conv1d(2 * self.enc_lstm_dim, 2 * self.enc_lstm_dim, kernel_size=3, 196 | stride=1, padding=1), 197 | nn.ReLU(inplace=True), 198 | ) 199 | self.convnet4 = nn.Sequential( 200 | nn.Conv1d(2 * self.enc_lstm_dim, 2 * self.enc_lstm_dim, kernel_size=3, 201 | stride=1, padding=1), 202 | nn.ReLU(inplace=True), 203 | ) 204 | 205 | def forward(self, sent_tuple): 206 | # sent_len: [max_len, ..., min_len] (batch) 207 | # sent: Variable(seqlen x batch x worddim) 208 | 209 | sent, sent_len = sent_tuple 210 | 211 | sent = sent.transpose(0, 1).transpose(1, 2).contiguous() 212 | # batch, nhid, seqlen) 213 | 214 | sent = self.convnet1(sent) 215 | u1 = torch.max(sent, 2)[0] 216 | 217 | sent = self.convnet2(sent) 218 | u2 = torch.max(sent, 2)[0] 219 | 220 | sent = self.convnet3(sent) 221 | u3 = torch.max(sent, 2)[0] 222 | 223 | sent = self.convnet4(sent) 224 | u4 = torch.max(sent, 2)[0] 225 | 226 | emb = torch.cat((u1, u2, u3, u4), 1) 227 | 228 | return emb 229 | 230 | 231 | """ 232 | Normal ConvNet 233 | """ 234 | 235 | 236 | class NormalConvNetEncoder(nn.Module): 237 | def __init__(self, config): 238 | super(NormalConvNetEncoder, self).__init__() 239 | self.word_emb_dim = config['word_emb_dim'] 240 | self.enc_lstm_dim = config['enc_lstm_dim'] 241 | self.conv = nn.Conv2d(in_channels=1, out_channels=self.enc_lstm_dim, kernel_size=(3, self.word_emb_dim), 242 | stride=(1, self.word_emb_dim)) 243 | 244 | def encode(self, inputs): 245 | output = inputs.transpose(0, 1).unsqueeze(1) # [batch_size, in_kernel, seq_length, embed_dim] 246 | output = F.relu(self.conv(output)) # conv -> [batch_size, out_kernel, seq_length, 1] 247 | output = output.squeeze(3).max(2)[0] # max_pool -> [batch_size, out_kernel] 248 | return output 249 | 250 | def forward(self, sent_tuple): 251 | # sent_len: [max_len, ..., min_len] (batch) 252 | # sent: Variable(seqlen x batch x worddim) 253 | sent, sent_len = sent_tuple 254 | emb = self.encode(sent) 255 | return emb 256 | 257 | 258 | """ 259 | https://github.com/Shawn1993/cnn-text-classification-pytorch/blob/master/model.py 260 | 352 stars 261 | """ 262 | 263 | 264 | class CNN_Text_Encoder(nn.Module): 265 | def __init__(self, config): 266 | super(CNN_Text_Encoder, self).__init__() 267 | 268 | self.word_emb_dim = config['word_emb_dim'] 269 | 270 | 
# V = args.embed_num 271 | # D = args.embed_dim 272 | # C = args.class_num 273 | Ci = 1 274 | Co = config['kernel_num'] # 100 275 | Ks = config['kernel_sizes'] # '3,4,5' 276 | # len(Ks)*Co 277 | 278 | # self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks] 279 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, self.word_emb_dim)) for K in Ks]) 280 | ''' 281 | self.conv13 = nn.Conv2d(Ci, Co, (3, D)) 282 | self.conv14 = nn.Conv2d(Ci, Co, (4, D)) 283 | self.conv15 = nn.Conv2d(Ci, Co, (5, D)) 284 | ''' 285 | # self.dropout = nn.Dropout(args.dropout) 286 | # self.fc1 = nn.Linear(len(Ks) * Co, C) 287 | 288 | def conv_and_pool(self, x, conv): 289 | x = F.relu(conv(x)).squeeze(3) # (N, Co, W) 290 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 291 | return x 292 | 293 | def forward(self, x): 294 | # x = self.embed(x) # (N, W, D) 295 | 296 | x = x[0].transpose(0, 1).unsqueeze(1) 297 | # x = x.unsqueeze(1) # (N, Ci, W, D) 298 | 299 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks) 300 | 301 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks) 302 | 303 | x = torch.cat(x, 1) 304 | 305 | ''' 306 | x1 = self.conv_and_pool(x,self.conv13) #(N,Co) 307 | x2 = self.conv_and_pool(x,self.conv14) #(N,Co) 308 | x3 = self.conv_and_pool(x,self.conv15) #(N,Co) 309 | x = torch.cat((x1, x2, x3), 1) # (N,len(Ks)*Co) 310 | ''' 311 | # x = self.dropout(x) # (N, len(Ks)*Co) 312 | # logit = self.fc1(x) # (N, C) 313 | return x 314 | 315 | 316 | class Classifier(nn.Module): 317 | def __init__(self, vocab, config): 318 | super(Classifier, self).__init__() 319 | self.config = config 320 | self.drop = nn.Dropout(config.dropout) # embedding dropout 321 | if config.conv_enc == 1: 322 | kernel_size = config.hidden_size / 8 323 | print(kernel_size) 324 | self.encoder = ConvNetEncoder({ 325 | 'word_emb_dim': config.emb_dim, 326 | 'enc_lstm_dim': kernel_size if not config.bidir else kernel_size * 2 327 | }) 328 | d_out = config.hidden_size if not config.bidir else config.hidden_size * 2 329 | elif config.conv_enc == 2: 330 | kernel_size = config.hidden_size 331 | print(kernel_size) 332 | self.encoder = NormalConvNetEncoder({ 333 | 'word_emb_dim': config.emb_dim, 334 | 'enc_lstm_dim': kernel_size if not config.bidir else kernel_size * 2 335 | }) 336 | d_out = config.hidden_size if not config.bidir else config.hidden_size * 2 337 | elif config.conv_enc == 3: 338 | kernel_num = config.hidden_size / 3 339 | kernel_num = kernel_num if not config.bidir else kernel_num * 2 340 | self.encoder = CNN_Text_Encoder({ 341 | 'word_emb_dim': config.emb_dim, 342 | 'kernel_sizes': [3, 4, 5], 343 | 'kernel_num': kernel_num 344 | }) 345 | d_out = len([3, 4, 5]) * kernel_num 346 | else: 347 | self.encoder = nn.LSTM( 348 | config.emb_dim, 349 | config.hidden_size, 350 | config.depth, 351 | dropout=config.dropout, 352 | bidirectional=config.bidir) # ha...not even bidirectional 353 | d_out = config.hidden_size if not config.bidir else config.hidden_size * 2 354 | 355 | self.out = nn.Linear(d_out, config.label_size) # include bias, to prevent bias assignment 356 | self.embed = nn.Embedding(len(vocab), config.emb_dim) 357 | self.embed.weight.data.copy_(vocab.vectors) 358 | self.embed.weight.requires_grad = True if config.emb_update else False 359 | 360 | def forward(self, input, lengths=None): 361 | output_vecs = self.get_vectors(input, lengths) 362 | return self.get_logits(output_vecs) 363 | 364 | def get_vectors(self, input, lengths=None): 365 | embed_input = self.embed(input) 366 | 367 | if 
self.config.conv_enc: 368 | output = self.encoder((embed_input, lengths.view(-1).tolist())) 369 | return output 370 | 371 | packed_emb = embed_input 372 | if lengths is not None: 373 | lengths = lengths.view(-1).tolist() 374 | packed_emb = nn.utils.rnn.pack_padded_sequence(embed_input, lengths) 375 | 376 | output, hidden = self.encoder(packed_emb) # embed_input 377 | 378 | if lengths is not None: 379 | output = unpack(output)[0] 380 | 381 | # we ignored negative masking 382 | return output 383 | 384 | def get_logits(self, output_vec): 385 | if self.config.conv_enc: 386 | output = output_vec 387 | else: 388 | output = torch.max(output_vec, 0)[0].squeeze(0) 389 | return self.out(output) 390 | 391 | def get_softmax_weight(self): 392 | return self.out.weight 393 | 394 | 395 | """ 396 | Interpretation module 397 | """ 398 | 399 | 400 | def propagate_three(a, b, c, activation): 401 | a_contrib = 0.5 * (activation(a + c) - activation(c) + 402 | activation(a + b + c) - activation(b + c)) 403 | b_contrib = 0.5 * (activation(b + c) - activation(c) + 404 | activation(a + b + c) - activation(a + c)) 405 | return a_contrib, b_contrib, activation(c) 406 | 407 | 408 | # propagate tanh nonlinearity 409 | def propagate_tanh_two(a, b): 410 | return 0.5 * (np.tanh(a) + (np.tanh(a + b) - np.tanh(b))), 0.5 * (np.tanh(b) + (np.tanh(a + b) - np.tanh(a))) 411 | 412 | 413 | def propagate_max_two(a, b, d=0): 414 | # need to return a, b with the same shape... 415 | indices = np.argmax(a + b, axis=d) 416 | a_mask = np.zeros_like(a) 417 | a_mask[indices, np.arange(a.shape[1])] = 1 418 | a = a * a_mask 419 | 420 | b_mask = np.zeros_like(b) 421 | b_mask[indices, np.arange(b.shape[1])] = 1 422 | b = b * b_mask 423 | 424 | return a, b 425 | 426 | 427 | class BaseLSTM(object): 428 | def __init__(self, model, bilstm=False): 429 | self.model = model 430 | weights = model.encoder.state_dict() 431 | 432 | self.optimizer = torch.optim.SGD(self.model.parameters(), 0.1) 433 | 434 | self.hidden_dim = model.config.hidden_size 435 | 436 | self.W_ii, self.W_if, self.W_ig, self.W_io = np.split( 437 | weights['weight_ih_l0'], 4, 0) 438 | self.W_hi, self.W_hf, self.W_hg, self.W_ho = np.split( 439 | weights['weight_hh_l0'], 4, 0) 440 | self.b_i, self.b_f, self.b_g, self.b_o = np.split( 441 | weights['bias_ih_l0'].numpy() + weights['bias_hh_l0'].numpy(), 442 | 4) 443 | 444 | if bilstm: 445 | self.rev_W_ii, self.rev_W_if, self.rev_W_ig, self.rev_W_io = np.split( 446 | weights['weight_ih_l0_reverse'], 4, 0) 447 | self.rev_W_hi, self.rev_W_hf, self.rev_W_hg, self.rev_W_ho = np.split( 448 | weights['weight_hh_l0_reverse'], 4, 0) 449 | self.rev_b_i, self.rev_b_f, self.rev_b_g, self.rev_b_o = np.split( 450 | weights['bias_ih_l0_reverse'].numpy( 451 | ) + weights['bias_hh_l0_reverse'].numpy(), 452 | 4) 453 | 454 | self.word_emb_dim = 100 455 | 456 | self.classifiers = [ 457 | (self.model.out.weight.data.numpy(), 458 | self.model.out.bias.data.numpy()) 459 | ] 460 | 461 | def zero_grad(self): 462 | self.optimizer.zero_grad() 463 | 464 | def classify(self, final_res): 465 | # note that u, v could be positional!! 
don't mix the two 466 | for c in self.classifiers: 467 | w, b = c 468 | final_res = np.dot(w, final_res) + b 469 | return final_res 470 | 471 | 472 | class MaxPoolingCDBiLSTM(BaseLSTM): 473 | def cell(self, prev_h, prev_c, x_i): 474 | # x_i = word_vecs[i] 475 | rel_i = np.dot(self.W_hi, prev_h) 476 | rel_g = np.dot(self.W_hg, prev_h) 477 | rel_f = np.dot(self.W_hf, prev_h) 478 | rel_o = np.dot(self.W_ho, prev_h) 479 | 480 | rel_i = sigmoid(rel_i + np.dot(self.W_ii, x_i) + self.b_i) 481 | rel_g = np.tanh(rel_g + np.dot(self.W_ig, x_i) + self.b_g) 482 | rel_f = sigmoid(rel_f + np.dot(self.W_if, x_i) + self.b_f) 483 | rel_o = sigmoid(rel_o + np.dot(self.W_io, x_i) + self.b_o) 484 | 485 | c_t = rel_f * prev_c + rel_i * rel_g 486 | h_t = rel_o * np.tanh(c_t) 487 | 488 | return h_t, c_t 489 | 490 | def rev_cell(self, prev_h, prev_c, x_i): 491 | # x_i = word_vecs[i] 492 | rel_i = np.dot(self.rev_W_hi, prev_h) 493 | rel_g = np.dot(self.rev_W_hg, prev_h) 494 | rel_f = np.dot(self.rev_W_hf, prev_h) 495 | rel_o = np.dot(self.rev_W_ho, prev_h) 496 | 497 | rel_i = sigmoid(rel_i + np.dot(self.rev_W_ii, x_i) + self.rev_b_i) 498 | rel_g = np.tanh(rel_g + np.dot(self.rev_W_ig, x_i) + self.rev_b_g) 499 | rel_f = sigmoid(rel_f + np.dot(self.rev_W_if, x_i) + self.rev_b_f) 500 | rel_o = sigmoid(rel_o + np.dot(self.rev_W_io, x_i) + self.rev_b_o) 501 | 502 | c_t = rel_f * prev_c + rel_i * rel_g 503 | h_t = rel_o * np.tanh(c_t) 504 | 505 | return h_t, c_t 506 | 507 | def run_bi_lstm(self, sent): 508 | # this is used as validation 509 | # sent: [legnth, dim=100] 510 | word_vecs = sent 511 | 512 | T = word_vecs.shape[0] 513 | 514 | hidden_states = np.zeros((T, self.hidden_dim)) 515 | rev_hidden_states = np.zeros((T, self.hidden_dim)) 516 | 517 | cell_states = np.zeros((T, self.hidden_dim)) 518 | rev_cell_states = np.zeros((T, self.hidden_dim)) 519 | 520 | for i in range(T): 521 | if i > 0: 522 | # this is just the prev hidden state 523 | prev_h = hidden_states[i - 1] 524 | prev_c = cell_states[i - 1] 525 | else: 526 | prev_h = np.zeros(self.hidden_dim) 527 | prev_c = np.zeros(self.hidden_dim) 528 | 529 | new_h, new_c = self.cell(prev_h, prev_c, word_vecs[i]) 530 | 531 | hidden_states[i] = new_h 532 | cell_states[i] = new_c 533 | 534 | for i in reversed(range(T)): 535 | # 20, 19, 18, 17, ... 536 | if i < T - 1: 537 | # this is just the prev hidden state 538 | prev_h = rev_hidden_states[i + 1] 539 | prev_c = rev_cell_states[i + 1] 540 | else: 541 | prev_h = np.zeros(self.hidden_dim) 542 | prev_c = np.zeros(self.hidden_dim) 543 | 544 | new_h, new_c = self.rev_cell(prev_h, prev_c, word_vecs[i]) 545 | 546 | rev_hidden_states[i] = new_h 547 | rev_cell_states[i] = new_c 548 | 549 | # stack second dimension 550 | return np.hstack([hidden_states, rev_hidden_states]), np.hstack([cell_states, rev_cell_states]) 551 | 552 | def get_word_level_scores(self, sentence, sentence_len, label_idx): 553 | """ 554 | :param sentence: word embeddings of [T, d] 555 | :return: 556 | """ 557 | # texts = gen_tiles(text_orig, method='cd', sweep_dim=1).transpose() 558 | # starts, stops = tiles_to_cd(texts) 559 | # [0, 1, 2,...], [0, 1, 2,...] 560 | 561 | self.zero_grad() 562 | 563 | # contextual decomposition 564 | rel_A, irrel_A = self.cd_encode(sentence) # already masked 565 | 566 | # Gradient part! 567 | # now we actually fire up the encoder, and get gradients w.r.t. 
hidden states 568 | # run the actual model to compute gradients 569 | sentence_emb = self.model.embed(sentence) 570 | lengths = sentence_len.view(-1).tolist() 571 | 572 | output, hidden = self.model.encoder(sentence_emb) 573 | # output_vec = unpack(output)[0] 574 | sent_output = torch.max(output, 0)[0].squeeze(0) 575 | clf_output = self.model.out(sent_output) 576 | # output_vec is the hidden states we want!! (T, hid_state_dim) 577 | 578 | # TODO: fix this part 579 | # y = clf_output[label_idx] 580 | # label_id = torch.max(clf_output, 0)[1] 581 | 582 | # compute A score 583 | y = clf_output[label_idx] 584 | grad = torch.autograd.grad(y, output, retain_graph=True)[0] 585 | 586 | scores_A = grad.data.squeeze() * torch.from_numpy(rel_A).float() 587 | 588 | # (sent_len, num_label) 589 | return scores_A.sum(dim=1), clf_output 590 | 591 | def extract_keywords(self, sentence, sentence_len, dataset, score_values, label_keyword_dict, label_size=42, threshold=0.2): 592 | # sentence: x 593 | # sentence_len: x_len 594 | # text_score_tup_list: [('surgery', 4.0), ...] 595 | 596 | self.zero_grad() 597 | 598 | # contextual decomposition 599 | rel_A, irrel_A = self.cd_encode(sentence) # already masked 600 | 601 | # Gradient part! 602 | # now we actually fire up the encoder, and get gradients w.r.t. hidden states 603 | # run the actual model to compute gradients 604 | sentence_emb = self.model.embed(sentence) 605 | lengths = sentence_len.view(-1).tolist() 606 | # packed_emb = nn.utils.rnn.pack_padded_sequence(sentence_emb, lengths) 607 | # output, hidden = self.model.encoder(packed_emb) 608 | output, hidden = self.model.encoder(sentence_emb) 609 | # output_vec = unpack(output)[0] 610 | sent_output = torch.max(output, 0)[0].squeeze(0) 611 | clf_output = self.model.out(sent_output) 612 | # output_vec is the hidden states we want!! 
(T, hid_state_dim) 613 | 614 | # y = clf_output[label_idx] 615 | # label_id = torch.max(clf_output, 0)[1] 616 | 617 | text = [dataset.TEXT.vocab.itos[idx] for idx in sentence.data] 618 | 619 | # compute A score 620 | for label_idx in range(label_size): 621 | self.zero_grad() 622 | 623 | y = clf_output[label_idx] 624 | grad = torch.autograd.grad(y, output, retain_graph=True)[0] 625 | 626 | scores_A = grad.data.squeeze() * torch.from_numpy(rel_A).float() 627 | scores_A = scores_A.sum(dim=1).data.squeeze().numpy().tolist() 628 | 629 | assert len(scores_A) == len(text) 630 | 631 | score_values.extend(scores_A) 632 | 633 | for g, t in zip(scores_A, text): 634 | if g > threshold: 635 | label_keyword_dict[label_idx].append(t) 636 | 637 | # we don't return anything :) 638 | return 639 | 640 | def cd_encode(self, sentences): 641 | rel_h, irrel_h, _ = self.flat_cd_text(sentences) 642 | rev_rel_h, rev_irrel_h, _ = self.flat_cd_text(sentences, reverse=True) 643 | rel = np.hstack([rel_h, rev_rel_h]) # T, 2*d 644 | irrel = np.hstack([irrel_h, rev_irrel_h]) # T, 2*d 645 | # again, hidden-states = rel + irrel 646 | 647 | # we mask both 648 | rel_masked, irrel_masked = propagate_max_two(rel, irrel) 649 | 650 | # (2*d), actual sentence representation 651 | return rel_masked, irrel_masked 652 | 653 | def flat_cd_text(self, sentence, reverse=False): 654 | # collects relevance for word 0 to sent_length 655 | # not considering interactions between words; merely collecting word contribution 656 | 657 | # word_vecs = self.model.embed(batch.text)[:, 0].data 658 | word_vecs = self.model.embed(sentence).squeeze().data.numpy() 659 | 660 | T = word_vecs.shape[0] 661 | 662 | # so prev_h is always irrelevant 663 | # there's no rel_h because we only look at each time step individually 664 | 665 | # relevant cell states, irrelevant cell states 666 | relevant = np.zeros((T, self.hidden_dim)) 667 | irrelevant = np.zeros((T, self.hidden_dim)) 668 | 669 | relevant_h = np.zeros((T, self.hidden_dim)) 670 | # keep track of the entire hidden state 671 | irrelevant_h = np.zeros((T, self.hidden_dim)) 672 | 673 | hidden_states = np.zeros((T, self.hidden_dim)) 674 | cell_states = np.zeros((T, self.hidden_dim)) 675 | 676 | if not reverse: 677 | W_ii, W_if, W_ig, W_io = self.W_ii, self.W_if, self.W_ig, self.W_io 678 | W_hi, W_hf, W_hg, W_ho = self.W_hi, self.W_hf, self.W_hg, self.W_ho 679 | b_i, b_f, b_g, b_o = self.b_i, self.b_f, self.b_g, self.b_o 680 | else: 681 | W_ii, W_if, W_ig, W_io = self.rev_W_ii, self.rev_W_if, self.rev_W_ig, self.rev_W_io 682 | W_hi, W_hf, W_hg, W_ho = self.rev_W_hi, self.rev_W_hf, self.rev_W_hg, self.rev_W_ho 683 | b_i, b_f, b_g, b_o = self.rev_b_i, self.rev_b_f, self.rev_b_g, self.rev_b_o 684 | 685 | # strategy: keep using prev_h as irrel_h 686 | # every time, make sure h = irrel + rel, then prev_h = h 687 | 688 | indices = range(T) if not reverse else reversed(range(T)) 689 | for i in indices: 690 | first_cond = i > 0 if not reverse else i < T - 1 691 | if first_cond: 692 | ret_idx = i - 1 if not reverse else i + 1 693 | prev_c = cell_states[ret_idx] 694 | prev_h = hidden_states[ret_idx] 695 | else: 696 | prev_c = np.zeros(self.hidden_dim) 697 | prev_h = np.zeros(self.hidden_dim) 698 | 699 | irrel_i = np.dot(W_hi, prev_h) 700 | irrel_g = np.dot(W_hg, prev_h) 701 | irrel_f = np.dot(W_hf, prev_h) 702 | irrel_o = np.dot(W_ho, prev_h) 703 | 704 | rel_i = np.dot(W_ii, word_vecs[i]) 705 | rel_g = np.dot(W_ig, word_vecs[i]) 706 | rel_f = np.dot(W_if, word_vecs[i]) 707 | rel_o = np.dot(W_io, word_vecs[i]) 708 | 
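# propagate_three(a, b, c, act), defined above, splits act(a + b + c) exactly into a
# relevant (word) contribution, an irrelevant (context) contribution, and a bias
# contribution act(c), i.e. rel_contrib + irrel_contrib + bias_contrib == act(a + b + c).
# The updates below preserve hidden_states[i] == relevant_h[i] + irrelevant_h[i] and
# cell_states[i] == relevant[i] + irrelevant[i] at every time step.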
709 | # this remains unchanged 710 | rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three( 711 | rel_i, irrel_i, b_i, sigmoid) 712 | rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three( 713 | rel_g, irrel_g, b_g, np.tanh) 714 | 715 | relevant[i] = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + \ 716 | bias_contrib_i * rel_contrib_g 717 | irrelevant[i] = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + \ 718 | (rel_contrib_i + bias_contrib_i) * irrel_contrib_g 719 | 720 | relevant[i] += bias_contrib_i * bias_contrib_g 721 | # if i >= start and i < stop: 722 | # relevant[i] += bias_contrib_i * bias_contrib_g 723 | # else: 724 | # irrelevant[i] += bias_contrib_i * bias_contrib_g 725 | 726 | cond = i > 0 if not reverse else i < T - 1 727 | if cond: 728 | rel_contrib_f, irrel_contrib_f, bias_contrib_f = propagate_three( 729 | rel_f, irrel_f, b_f, sigmoid) 730 | 731 | # not sure if this is completely correct 732 | irrelevant[i] += (rel_contrib_f + 733 | irrel_contrib_f + bias_contrib_f) * prev_c 734 | 735 | # recompute o-gate 736 | o = sigmoid(rel_o + irrel_o + b_o) 737 | rel_contrib_o, irrel_contrib_o, bias_contrib_o = propagate_three( 738 | rel_o, irrel_o, b_o, sigmoid) 739 | # from current cell state 740 | new_rel_h, new_irrel_h = propagate_tanh_two( 741 | relevant[i], irrelevant[i]) 742 | # relevant_h[i] = new_rel_h * (rel_contrib_o + bias_contrib_o) 743 | # irrelevant_h[i] = new_rel_h * (irrel_contrib_o) + new_irrel_h * (rel_contrib_o + irrel_contrib_o + bias_contrib_o) 744 | relevant_h[i] = o * new_rel_h 745 | irrelevant_h[i] = o * new_irrel_h 746 | 747 | hidden_states[i] = relevant_h[i] + irrelevant_h[i] 748 | cell_states[i] = relevant[i] + irrelevant[i] 749 | 750 | return relevant_h, irrelevant_h, hidden_states 751 | 752 | 753 | # this dataset can also take in 5-class classification 754 | class Dataset(object): 755 | def __init__(self, path='./data/csu/', 756 | dataset_prefix='snomed_multi_label_no_des_', 757 | # test_data_name='adobe_abbr_matched_snomed_multi_label_no_des_test.tsv', 758 | test_data_name='adobe_combined_abbr_matched_snomed_multi_label_no_des_test.tsv', 759 | # change this to 'adobe_combined_abbr_matched_snomed_multi_label_no_des_test.tsv' 760 | label_size=42, fix_length=None): 761 | self.TEXT = ReversibleField(sequential=True, include_lengths=True, lower=False, fix_length=fix_length) 762 | self.LABEL = MultiLabelField(sequential=True, use_vocab=False, label_size=label_size, 763 | tensor_type=torch.FloatTensor, fix_length=fix_length) 764 | 765 | # it's actually this step that will take 5 minutes 766 | self.train, self.val, self.test = data.TabularDataset.splits( 767 | path=path, train=dataset_prefix + 'train.tsv', 768 | validation=dataset_prefix + 'valid.tsv', 769 | test=dataset_prefix + 'test.tsv', format='tsv', 770 | fields=[('Text', self.TEXT), ('Description', self.LABEL)]) 771 | 772 | self.external_test = data.TabularDataset(path=path + test_data_name, 773 | format='tsv', 774 | fields=[('Text', self.TEXT), ('Description', self.LABEL)]) 775 | 776 | self.is_vocab_bulit = False 777 | self.iterators = [] 778 | self.test_iterator = None 779 | 780 | def init_emb(self, vocab, init="randn", num_special_toks=2, silent=False): 781 | # we can try randn or glorot 782 | # mode="unk"|"all", all means initialize everything 783 | emb_vectors = vocab.vectors 784 | sweep_range = len(vocab) 785 | running_norm = 0. 
786 | num_non_zero = 0 787 | total_words = 0 788 | for i in range(num_special_toks, sweep_range): 789 | if len(emb_vectors[i, :].nonzero()) == 0: 790 | # std = 0.5 is based on the norm of average GloVE word vectors 791 | if init == "randn": 792 | torch.nn.init.normal(emb_vectors[i], mean=0, std=0.5) 793 | else: 794 | num_non_zero += 1 795 | running_norm += torch.norm(emb_vectors[i]) 796 | total_words += 1 797 | if not silent: 798 | print("average GloVE norm is {}, number of known words are {}, total number of words are {}".format( 799 | running_norm / num_non_zero, num_non_zero, total_words)) # directly printing into Jupyter Notebook 800 | 801 | def build_vocab(self, config, silent=False): 802 | if config.emb_corpus == 'common_crawl': 803 | self.TEXT.build_vocab(self.train, vectors="glove.840B.300d") 804 | config.emb_dim = 300 # change the config emb dimension 805 | else: 806 | self.TEXT.build_vocab(self.train, vectors="glove.6B.{}d".format(config.emb_dim)) 807 | self.is_vocab_bulit = True 808 | self.vocab = self.TEXT.vocab 809 | if config.rand_unk: 810 | if not silent: 811 | print("initializing random vocabulary") 812 | self.init_emb(self.vocab, silent=silent) 813 | 814 | def get_iterators(self, device, val_batch_size=128): 815 | if not self.is_vocab_bulit: 816 | raise Exception("Vocabulary is not built yet..needs to call build_vocab()") 817 | 818 | if len(self.iterators) > 0: 819 | return self.iterators # return stored iterator 820 | 821 | # only get them after knowing the device (inside trainer or evaluator) 822 | train_iter, val_iter, test_iter = data.Iterator.splits( 823 | (self.train, self.val, self.test), sort_key=lambda x: len(x.Text), # no global sort, but within-batch-sort 824 | batch_sizes=(32, val_batch_size, val_batch_size), device=device, 825 | sort_within_batch=True, repeat=False) 826 | 827 | return train_iter, val_iter, test_iter 828 | 829 | def get_test_iterator(self, device): 830 | if not self.is_vocab_bulit: 831 | raise Exception("Vocabulary is not built yet..needs to call build_vocab()") 832 | 833 | if self.test_iterator is not None: 834 | return self.test_iterator 835 | 836 | external_test_iter = data.Iterator(self.external_test, 128, sort_key=lambda x: len(x.Text), 837 | device=device, train=False, repeat=False, sort_within_batch=True) 838 | return external_test_iter 839 | 840 | def get_lm_iterator(self, device): 841 | # get language modeling data iterators 842 | pass 843 | 844 | 845 | # compute loss 846 | class ClusterLoss(nn.Module): 847 | def __init__(self, config, cluster_path='./data/csu/snomed_label_to_meta_grouping.json'): 848 | super(ClusterLoss, self).__init__() 849 | 850 | with open(cluster_path, 'rb') as f: 851 | label_grouping = json.load(f) 852 | 853 | self.meta_category_groups = label_grouping.values() 854 | self.config = config 855 | 856 | def forward(self, softmax_weight, batch_size): 857 | w_bar = softmax_weight.sum(1) / self.config.label_size # w_bar 858 | 859 | omega_mean = softmax_weight.pow(2).sum() 860 | omega_between = 0. 861 | omega_within = 0. 
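        # the loop below accumulates the three penalty terms on the (d, label_size) softmax weight
        # matrix: omega_mean is the overall squared norm, omega_between measures how far each
        # meta-group's mean column w_c_bar drifts from the global mean w_bar, and omega_within
        # measures the spread of columns inside each group; they are combined as
        # sigma_M * omega_mean + (sigma_B * omega_between + sigma_W * omega_within) / batch_size.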
862 | 863 | for c in xrange(len(self.meta_category_groups)): 864 | m_c = len(self.meta_category_groups[c]) 865 | w_c_bar = softmax_weight[:, self.meta_category_groups[c]].sum(1) / m_c 866 | omega_between += m_c * (w_c_bar - w_bar).pow(2).sum() 867 | for i in self.meta_category_groups[c]: 868 | # this value will be 0 for singleton group 869 | omega_within += (softmax_weight[:, i] - w_c_bar).pow(2).sum() 870 | 871 | aux_loss = omega_mean * self.config.sigma_M + (omega_between * self.config.sigma_B + 872 | omega_within * self.config.sigma_W) / batch_size 873 | 874 | return aux_loss 875 | 876 | 877 | class MetaLoss(nn.Module): 878 | def __init__(self, config, cluster_path='./data/csu/snomed_label_to_meta_grouping.json', 879 | label_to_meta_map_path='./data/csu/snomed_label_to_meta_map.json'): 880 | super(MetaLoss, self).__init__() 881 | 882 | with open(cluster_path, 'rb') as f: 883 | self.label_grouping = json.load(f) 884 | 885 | with open(label_to_meta_map_path, 'rb') as f: 886 | self.meta_label_mapping = json.load(f) 887 | 888 | self.meta_label_size = len(self.label_grouping) 889 | self.config = config 890 | 891 | # your original classifier did this wrong...found a bug 892 | self.bce_loss = nn.BCELoss() # this takes in probability (after sigmoid) 893 | 894 | # now that this becomes somewhat independent...maybe you can examine this more closely? 895 | def generate_meta_y(self, indices, meta_label_size, batch_size): 896 | a = np.array([[0.] * meta_label_size for _ in range(batch_size)], dtype=np.float32) 897 | matched = defaultdict(set) 898 | for b, l in indices: 899 | if b not in matched: 900 | a[b, self.meta_label_mapping[str(l)]] = 1. 901 | matched[b].add(self.meta_label_mapping[str(l)]) 902 | elif self.meta_label_mapping[str(l)] not in matched[b]: 903 | a[b, self.meta_label_mapping[str(l)]] = 1. 904 | matched[b].add(self.meta_label_mapping[str(l)]) 905 | assert np.sum(a <= 1) == a.size 906 | return a 907 | 908 | def forward(self, logits, true_y, device): 909 | batch_size = logits.size(0) 910 | y_hat = torch.sigmoid(logits) 911 | meta_probs = [] 912 | for i in range(self.meta_label_size): 913 | # 1 - (1 - p_1)(...)(1 - p_n) 914 | meta_prob = (1 - y_hat[:, self.label_grouping[str(i)]]).prod(1) 915 | meta_probs.append(meta_prob) # in this version we don't do threshold....(originally we did) 916 | 917 | meta_probs = torch.stack(meta_probs, dim=1) 918 | assert meta_probs.size(1) == self.meta_label_size 919 | 920 | # generate meta-label 921 | y_indices = true_y.nonzero() 922 | meta_y = self.generate_meta_y(y_indices.data.cpu().numpy().tolist(), self.meta_label_size, 923 | batch_size) 924 | meta_y = Variable(torch.from_numpy(meta_y)) if device == -1 else Variable(torch.from_numpy(meta_y)).cuda(device) 925 | 926 | meta_loss = self.bce_loss(meta_probs, meta_y) * self.config.beta 927 | return meta_loss 928 | 929 | 930 | def log_of_array_ignoring_zeros(M): 931 | """Returns an array containing the logs of the nonzero 932 | elements of M. Zeros are left alone since log(0) isn't 933 | defined. 
934 | """ 935 | log_M = M.copy() 936 | mask = log_M > 0 937 | log_M[mask] = np.log(log_M[mask]) 938 | return log_M 939 | 940 | 941 | def observed_over_expected(df): 942 | col_totals = df.sum(axis=0) 943 | total = col_totals.sum() 944 | row_totals = df.sum(axis=1) 945 | expected = np.outer(row_totals, col_totals) / total 946 | oe = df / expected 947 | return oe 948 | 949 | 950 | def pmi(df, positive=True): 951 | df = observed_over_expected(df) 952 | # Silence distracting warnings about log(0): 953 | with np.errstate(divide='ignore'): 954 | df = np.log(df) 955 | df[np.isnan(df)] = 0.0 # log(0) = 0 956 | if positive: 957 | df[df < 0] = 0.0 958 | return df 959 | 960 | 961 | class CoOccurenceLoss(nn.Module): 962 | def __init__(self, config, 963 | csu_path='./data/csu/label_co_matrix.npy', 964 | pp_path='./data/csu/pp_combined_label_co_matrix.npy', 965 | device=-1): 966 | super(CoOccurenceLoss, self).__init__() 967 | self.co_mat_path = csu_path if config.use_csu else pp_path 968 | self.co_mat = np.load(self.co_mat_path) 969 | self.X = self.co_mat 970 | self.glove = self.config.glove 971 | 972 | logging.info("using co_matrix {}".format(self.co_mat_path)) 973 | self.n = config.hidden_size # N-dim rep 974 | self.m = config.label_size 975 | 976 | self.gamma = self.config.gamma 977 | if self.glove: 978 | self.C = torch.empty(self.m, self.n) 979 | self.C = Variable(self.C.uniform_(-0.5, 0.5)).cuda(device) 980 | self.B = torch.empty(2, self.m) 981 | self.B = Variable(self.B.uniform_(-0.5, 0.5)).cuda(device) 982 | 983 | self.indices = list(range(self.m)) # label_size 984 | 985 | # Precomputable GloVe values: 986 | self.X_log = log_of_array_ignoring_zeros(self.X) 987 | self.X_weights = (np.minimum(self.X, config.xmax) / config.xmax) ** config.alpha # eq. (9) 988 | 989 | # iterate on the upper triangular matrix, off-diagonal 990 | self.iu1 = np.triu_indices(41, 1) # 820 iterations 991 | else: 992 | self.X = Variable(pmi(self.X, positive=self.config.ppmi), requires_grad=False).cuda(device) 993 | self.mse = nn.MSELoss() 994 | 995 | def forward(self, softmax_weight): 996 | # this computes a straight-through pass of the GloVE objective 997 | # similar to "Auxiliary" training 998 | # return the loss 999 | # softmax_weight: [d, |Y|] 1000 | if self.glove: 1001 | loss = 0. 1002 | for i, j in zip(self.iu1[0], self.iu1[1]): 1003 | if self.X[i, j] > 0.0: 1004 | # Cost is J' based on eq. 
(8) in the paper: 1005 | # (1, |Y|) dot (1, |Y|) 1006 | diff = softmax_weight[:, i].dot(self.C[j]) + self.B[0, i] + self.B[1, j] - self.X_log[i, j] 1007 | loss += self.X_weights[i, j] * diff # f(X_ij) * (w_i w_j + b_i + b_j - log X_ij) 1008 | # this is the summation, not average 1009 | else: 1010 | # softmax_weight: (d, m) 1011 | # (m, d) (d, m) 1012 | a = torch.matmul(torch.transpose(softmax_weight, 1, 0), softmax_weight) 1013 | loss = self.mse(a, self.X) 1014 | return loss * self.gamma 1015 | 1016 | 1017 | # maybe we should evaluate inside this 1018 | # currently each Trainer is tied to one GPU, so we don't have to worry about 1019 | # Each trainer is associated with a config and classifier actually...so should be associated with a log 1020 | # Experiment class will create a central folder, and it will have sub-folder for each trainer 1021 | # central folder will have an overall summary...(Experiment will also have ways to do 5 random seed exp) 1022 | class Trainer(object): 1023 | def __init__(self, classifier, dataset, config, save_path, device, load=False, run_order=0, 1024 | **kwargs): 1025 | # save_path: where to save log and model 1026 | if load: 1027 | # or we can add a new keyword... 1028 | if os.path.exists(pjoin(save_path, 'model-{}.pickle'.format(run_order))): 1029 | self.classifier = torch.load(pjoin(save_path, 'model-{}.pickle'.format(run_order))).cuda(device) 1030 | else: 1031 | self.classifier = torch.load(pjoin(save_path, 'model.pickle')).cuda(device) 1032 | else: 1033 | self.classifier = classifier.cuda(device) 1034 | 1035 | # replace old cached config with new config 1036 | self.classifier.config = config 1037 | 1038 | self.dataset = dataset 1039 | self.device = device 1040 | self.config = config 1041 | self.save_path = save_path 1042 | 1043 | self.train_iter, self.val_iter, self.test_iter = self.dataset.get_iterators(device) 1044 | self.external_test_iter = self.dataset.get_test_iterator(device) 1045 | 1046 | if config.m: 1047 | self.aux_loss = MetaLoss(config, **kwargs) 1048 | elif config.c: 1049 | self.aux_loss = ClusterLoss(config, **kwargs) 1050 | 1051 | self.bce_logit_loss = BCEWithLogitsLoss(reduce=False) 1052 | 1053 | need_grad = lambda x: x.requires_grad 1054 | self.optimizer = optim.Adam( 1055 | filter(need_grad, classifier.parameters()), 1056 | lr=0.001) # obviously we could use config to control this 1057 | 1058 | # setting up logging 1059 | if not os.path.exists(save_path): 1060 | os.makedirs(save_path) 1061 | logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', 1062 | datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) 1063 | file_handler = logging.FileHandler("{0}/log.txt".format(save_path)) 1064 | self.logger = logging.getLogger(save_path.split('/')[-1]) # so that no model is sharing logger 1065 | self.logger.addHandler(file_handler) 1066 | 1067 | self.logger.info(config) 1068 | 1069 | def load(self, run_order): 1070 | self.classifier = torch.load(pjoin(self.save_path, 'model-{}.pickle').format(run_order)).cuda(self.device) 1071 | 1072 | def pretrain(self, epochs=15): 1073 | # train loop 1074 | # even though without attention...LM can still play a role here to learn word embeddings 1075 | pass 1076 | 1077 | def train(self, run_order=0, epochs=5, no_print=True): 1078 | # train loop 1079 | exp_cost = None 1080 | for e in range(epochs): 1081 | self.classifier.train() 1082 | for iter, data in enumerate(self.train_iter): 1083 | self.classifier.zero_grad() 1084 | (x, x_lengths), y = data.Text, data.Description 1085 | 1086 | # output_vec = 
self.classifier.get_vectors(x, x_lengths) # this is just logit (before calling sigmoid) 1087 | # final_rep = torch.max(output_vec, 0)[0].squeeze(0) 1088 | # logits = self.classifier.get_logits(output_vec) 1089 | 1090 | logits = self.classifier(x, x_lengths) 1091 | 1092 | batch_size = x.size(0) 1093 | 1094 | if self.config.c: 1095 | softmax_weight = self.classifier.get_softmax_weight() 1096 | aux_loss = self.aux_loss(softmax_weight, batch_size) 1097 | elif self.config.m: 1098 | aux_loss = self.aux_loss(logits, y, self.device) 1099 | else: 1100 | aux_loss = 0. 1101 | 1102 | loss = self.bce_logit_loss(logits, y).mean() + aux_loss 1103 | loss.backward() 1104 | 1105 | torch.nn.utils.clip_grad_norm(self.classifier.parameters(), self.config.clip_grad) 1106 | self.optimizer.step() 1107 | 1108 | if not exp_cost: 1109 | exp_cost = loss.data[0] 1110 | else: 1111 | exp_cost = 0.99 * exp_cost + 0.01 * loss.data[0] 1112 | 1113 | if iter % 100 == 0: 1114 | self.logger.info( 1115 | "iter {} lr={} train_loss={} exp_cost={} \n".format(iter, self.optimizer.param_groups[0]['lr'], 1116 | loss.data[0], exp_cost)) 1117 | self.logger.info("enter validation...") 1118 | valid_em, micro_tup, macro_tup = self.evaluate(is_test=False) 1119 | self.logger.info("epoch {} lr={:.6f} train_loss={:.6f} valid_acc={:.6f}\n".format( 1120 | e + 1, self.optimizer.param_groups[0]['lr'], loss.data[0], valid_em 1121 | )) 1122 | 1123 | # save model 1124 | torch.save(self.classifier, pjoin(self.save_path, 'model-{}.pickle'.format(run_order))) 1125 | 1126 | def test(self, silent=False, return_by_label_stats=False, return_instances=False): 1127 | self.logger.info("compute test set performance...") 1128 | return self.evaluate(is_test=True, silent=silent, return_by_label_stats=return_by_label_stats, 1129 | return_instances=return_instances) 1130 | 1131 | def get_abstention_data_iter(self, data_iter): 1132 | batched_x_list = [] 1133 | batched_y_list = [] 1134 | batched_y_hat_list = [] 1135 | batched_loss_list = [] 1136 | 1137 | for iter, data in enumerate(data_iter): 1138 | (x, x_lengths), y = data.Text, data.Description 1139 | output_vec = self.classifier.get_vectors(x, x_lengths) # this is just logit (before calling sigmoid) 1140 | final_rep = torch.max(output_vec, 0)[0].squeeze(0) 1141 | logits = self.classifier.get_logits(output_vec) 1142 | loss = self.bce_logit_loss(logits, y) # this per-example 1143 | 1144 | # We create new Tensor Variable 1145 | batched_x_list.append(final_rep.detach()) 1146 | batched_y_list.append(y.detach()) 1147 | batched_y_hat_list.append(logits.detach()) # .tolist() 1148 | batched_loss_list.append(loss.detach()) # .tolist() 1149 | 1150 | return batched_x_list, batched_y_list, batched_y_hat_list, batched_loss_list 1151 | 1152 | def get_abstention_data(self): 1153 | # used by Abstention model 1154 | self.classifier.eval() 1155 | train_data = self.get_abstention_data_iter(self.train_iter) 1156 | test_data = self.get_abstention_data_iter(self.test_iter) 1157 | 1158 | return train_data, test_data 1159 | 1160 | def save_error_examples(self, error_dict, label_names, save_address): 1161 | import codecs 1162 | # error_dict: {label: []} 1163 | # creates 42 files in the given directory 1164 | if not os.path.exists(pjoin(self.save_path, save_address)): 1165 | os.makedirs(pjoin(self.save_path, save_address)) 1166 | 1167 | for label_i, error_examples in error_dict.iteritems(): 1168 | file_name = label_names[label_i].replace('AND/OR', '').replace('and/or', '').replace('/', '') 1169 | with codecs.open(pjoin(self.save_path, 
save_address, file_name + '.txt'), 'w', encoding='utf-8') as f: 1170 | for e_tup in error_examples: 1171 | f.write(e_tup[0] + '\t' + '-'.join([str(x) for x in e_tup[1]]) + '\n') # x tab y 1172 | 1173 | def get_error_examples(self, is_external=False, save_address=None, label_names=None): 1174 | # this function is slower to run than evaluate() 1175 | # save_address needs to point to a folder, not a file 1176 | 1177 | self.classifier.eval() 1178 | data_iter = self.test_iter if not is_external else self.external_test_iter 1179 | 1180 | all_x, error_dict = [], defaultdict(list) 1181 | all_preds, all_y_labels = [], [] # we traverse these two numpy array, then pick out things 1182 | 1183 | for iter, data in enumerate(data_iter): 1184 | (x, x_lengths), y = data.Text, data.Description 1185 | logits = self.classifier(x, x_lengths) 1186 | preds = (torch.sigmoid(logits) > 0.5).data.cpu().numpy().astype(float) 1187 | all_preds.append(preds) 1188 | all_y_labels.append(y.data.cpu().numpy()) 1189 | 1190 | orig_text = self.dataset.TEXT.reverse(x.data) 1191 | all_x.extend(orig_text) 1192 | 1193 | preds = np.vstack(all_preds) 1194 | ys = np.vstack(all_y_labels) 1195 | 1196 | for ith in range(len(all_x)): 1197 | # traverse one by one to find ill match 1198 | if (preds[ith] == ys[ith]).sum() == self.config.label_size: 1199 | continue # perfectly matched 1200 | 1201 | for jth in range(self.config.label_size): 1202 | if preds[ith][jth] != ys[ith][jth]: 1203 | error_dict[jth].append((all_x[ith], ys[ith].nonzero()[0].tolist())) 1204 | # jth disease, append text, will result in duplication 1205 | 1206 | if save_address is not None: 1207 | assert label_names is not None 1208 | self.save_error_examples(error_dict, label_names, save_address) 1209 | 1210 | return error_dict 1211 | 1212 | def evaluate(self, is_test=False, is_external=False, silent=False, return_by_label_stats=False, 1213 | return_instances=False, return_roc_auc=False): 1214 | self.classifier.eval() 1215 | data_iter = self.test_iter if is_test else self.val_iter # evaluate on CSU 1216 | data_iter = self.external_test_iter if is_external else data_iter # evaluate on adobe 1217 | 1218 | all_preds, all_y_labels, all_confs = [], [], [] 1219 | 1220 | for iter, data in enumerate(data_iter): 1221 | (x, x_lengths), y = data.Text, data.Description 1222 | logits = self.classifier(x, x_lengths) 1223 | 1224 | preds = (torch.sigmoid(logits) > 0.5).data.cpu().numpy().astype(float) 1225 | all_preds.append(preds) 1226 | all_y_labels.append(y.data.cpu().numpy()) 1227 | all_confs.append(torch.sigmoid(logits).data.cpu().numpy().astype(float)) 1228 | 1229 | preds = np.vstack(all_preds) 1230 | ys = np.vstack(all_y_labels) 1231 | confs = np.vstack(all_confs) 1232 | 1233 | if not silent: 1234 | self.logger.info("\n" + metrics.classification_report(ys, preds, digits=3)) # write to file 1235 | 1236 | # this is actually the accurate exact match 1237 | em = metrics.accuracy_score(ys, preds) 1238 | accu = np.array([metrics.accuracy_score(ys[:, i], preds[:, i]) for i in range(self.config.label_size)], 1239 | dtype='float32') 1240 | p, r, f1, s = metrics.precision_recall_fscore_support(ys, preds, average=None) 1241 | 1242 | # because some labels are NOT present in the test set, we need to message this function 1243 | # filter out labels that have no examples 1244 | 1245 | # this code works :) 1246 | roc_auc = np.zeros(ys.shape[1]) 1247 | roc_auc[:] = 0.5 # base value for ROC AUC 1248 | non_zero_label_idices = ys.sum(0).nonzero() 1249 | 1250 | non_zero_ys = np.squeeze(ys[:, 
non_zero_label_idices]) 1251 | non_zero_preds = np.squeeze(preds[:, non_zero_label_idices]) 1252 | non_zero_roc_auc = metrics.roc_auc_score(non_zero_ys, non_zero_preds, average=None) 1253 | 1254 | roc_auc[non_zero_label_idices] = non_zero_roc_auc 1255 | 1256 | if return_by_label_stats and return_roc_auc: 1257 | return p, r, f1, s, accu, roc_auc 1258 | elif return_by_label_stats: 1259 | return p, r, f1, s, accu 1260 | elif return_instances: 1261 | return ys, preds, confs 1262 | 1263 | micro_p, micro_r, micro_f1 = np.average(p, weights=s), np.average(r, weights=s), np.average(f1, weights=s) 1264 | 1265 | # compute Macro-F1 here 1266 | # if is_external: 1267 | # # include clinical finding 1268 | # macro_p, macro_r, macro_f1 = np.average(p[14:]), np.average(r[14:]), np.average(f1[14:]) 1269 | # else: 1270 | # # anything > 10 1271 | # macro_p, macro_r, macro_f1 = np.average(np.take(p, [12] + range(21, 42))), \ 1272 | # np.average(np.take(r, [12] + range(21, 42))), \ 1273 | # np.average(np.take(f1, [12] + range(21, 42))) 1274 | 1275 | # we switch to non-zero macro computing, this can figure out boost from rarest labels 1276 | if is_external: 1277 | # include clinical finding 1278 | macro_p, macro_r, macro_f1 = np.average(p[p.nonzero()]), np.average(r[r.nonzero()]), \ 1279 | np.average(f1[f1.nonzero()]) 1280 | else: 1281 | # anything > 10 1282 | macro_p, macro_r, macro_f1 = np.average(p[p.nonzero()]), \ 1283 | np.average(r[r.nonzero()]), \ 1284 | np.average(f1[f1.nonzero()]) 1285 | 1286 | return em, (micro_p, micro_r, micro_f1), (macro_p, macro_r, macro_f1) 1287 | 1288 | 1289 | class AbstentionConfig(Config): 1290 | def __init__(self, obj_loss=False, obj_accu=False, 1291 | inp_logit=False, inp_pred=False, inp_h=False, inp_conf=False, clip_grad=5., no_shrink=True, 1292 | dropout=0.): 1293 | # some logic checks 1294 | assert inp_logit + inp_pred + inp_h + inp_conf == 1, "only one input type" 1295 | assert obj_loss + obj_accu == 1, "only one objective type" 1296 | 1297 | super(AbstentionConfig, self).__init__(obj_loss=obj_loss, 1298 | obj_accu=obj_accu, # objective is accu 1299 | inp_logit=inp_logit, # input is logit (before sigmoid) 1300 | inp_pred=inp_pred, # input is pred (after sigmoid) 1301 | inp_h=inp_h, 1302 | inp_conf=inp_conf, 1303 | clip_grad=clip_grad, 1304 | no_shrink=no_shrink, 1305 | dropout=dropout) 1306 | 1307 | 1308 | class RejectModel(nn.Module): 1309 | def __init__(self, config, deeptag_config): 1310 | super(RejectModel, self).__init__() 1311 | if config['inp_h']: 1312 | reject_dim = deeptag_config.hidden_size 1313 | if deeptag_config.bidir is True: 1314 | reject_dim *= 2 1315 | else: 1316 | reject_dim = deeptag_config.label_size 1317 | 1318 | if config['no_shrink']: 1319 | self.reject_model = nn.Sequential( 1320 | nn.Linear(reject_dim, int(reject_dim)), 1321 | nn.SELU(), 1322 | nn.Linear(int(reject_dim), int(reject_dim)), 1323 | nn.SELU(), 1324 | nn.Linear(int(reject_dim), 1)) 1325 | else: 1326 | self.reject_model = nn.Sequential( 1327 | nn.Linear(reject_dim, int(reject_dim / 2.)), 1328 | nn.SELU(), 1329 | nn.Linear(int(reject_dim / 2.), int(reject_dim / 4.)), 1330 | nn.SELU(), 1331 | nn.Linear(int(reject_dim / 4.), 1)) 1332 | 1333 | def pred(self, x): 1334 | return self.reject_model(x) 1335 | 1336 | def reject(self, x, gamma=0.): 1337 | # x: (batch_size, rej_dim) 1338 | rej_choices = self.reject_model(x) > gamma 1339 | return rej_choices 1340 | 1341 | 1342 | class Abstention(object): 1343 | # Similar to Experiment class but used to train and manage Reject_Model 1344 | def 
__init__(self, experiment, deeptag_config): 1345 | self.mse_loss = torch.nn.MSELoss() 1346 | self.sigmoid = torch.nn.Sigmoid() 1347 | 1348 | self.dataset = experiment.dataset 1349 | self.experiment = experiment 1350 | self.deeptag_config = deeptag_config 1351 | 1352 | def get_reject_model(self, config, gpu_id=-1): 1353 | reject_model = RejectModel(config, self.deeptag_config) 1354 | if gpu_id != -1: 1355 | reject_model.cuda(gpu_id) 1356 | return reject_model 1357 | 1358 | def train_loss(self, config, train_data, device, epochs=3, lr=0.001, print_log=False): 1359 | # each training requires a new optimizer 1360 | # these are already Variables 1361 | # batched_x_list, batched_y_list, batched_y_hat_list, batched_loss_list = train_data 1362 | # we used to return losses, but now we all know it works..so no need 1363 | 1364 | reject_model = self.get_reject_model(config, device) 1365 | rej_optimizer = optim.Adam(reject_model.parameters(), lr=lr) 1366 | 1367 | exp_cost = None 1368 | 1369 | for n in range(epochs): 1370 | iteration = 0 1371 | # Already variables on CUDA devices 1372 | print("training at epoch {}".format(n)) 1373 | for x, y, y_hat, orig_loss in izip(*train_data): 1374 | reject_model.zero_grad() 1375 | inp = None 1376 | if config.inp_logit: 1377 | inp = y_hat 1378 | elif config.inp_pred: 1379 | inp = torch.sigmoid(y_hat) 1380 | elif config.inp_h: 1381 | inp = x 1382 | elif config.inp_conf: 1383 | inp = torch.sigmoid(y_hat).cpu().apply_(lambda x: x if x >= 0.5 else 1 - x).cuda(device) 1384 | 1385 | pred_obj = torch.squeeze(reject_model.pred(inp)) 1386 | 1387 | if config['obj_loss']: 1388 | true_obj = orig_loss.mean(dim=1) 1389 | elif config['obj_accu']: 1390 | preds = torch.sigmoid(y_hat) > 0.5 1391 | true_obj = (preds.type_as(y) == y).type_as(y).mean(dim=1) 1392 | 1393 | loss = self.mse_loss(pred_obj, true_obj) 1394 | 1395 | loss.backward() 1396 | torch.nn.utils.clip_grad_norm(reject_model.parameters(), config['clip_grad']) 1397 | rej_optimizer.step() 1398 | 1399 | if not exp_cost: 1400 | exp_cost = loss.data[0] 1401 | else: 1402 | exp_cost = 0.99 * exp_cost + 0.01 * loss.data[0] 1403 | 1404 | if iteration % 100 == 0 and print_log: 1405 | # avg_rej_rate = rej_scores.mean().data[0] 1406 | print("iter {} lr={} train_loss={} exp_cost={} \n".format(iteration, 1407 | rej_optimizer.param_groups[0]['lr'], 1408 | loss.data[0], exp_cost)) 1409 | iteration += 1 1410 | 1411 | return reject_model 1412 | 1413 | @staticmethod 1414 | def compute_exactly_k(k, probs): 1415 | # k: int 1416 | # probs: [float] (confidence score!!) 1417 | score = 0. 1418 | for idx_tup in combinations(range(len(probs)), k): 1419 | success = 1. 1420 | for idx in idx_tup: 1421 | success += math.log(probs[idx]) 1422 | failure = 1. 
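            # success above accumulated log(p) over the labels chosen in idx_tup; failure below
            # accumulates log(1 - p) over the remaining labels, and score sums these terms over
            # every size-k subset.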
1423 | for idx in set(range(len(probs))) - set(idx_tup): 1424 | failure += math.log(1 - probs[idx]) 1425 | score += success + failure 1426 | 1427 | return score 1428 | 1429 | def drop(self, data_iter, reject_model, drop_portion, config, device, conf_abstention=False, return_dropped=False, 1430 | weighted_f1=True): 1431 | # apply to whatever documents we want and tag them with abstention priority scores 1432 | # data_iter should be the test set of CSU 1433 | # data_iter is actually not an iterator 1434 | reject_model.eval() 1435 | score_reverse = True if config['obj_loss'] and not conf_abstention else False 1436 | 1437 | prior_score_pred_y_pairs = [] 1438 | 1439 | for x, y, y_hat, orig_loss in izip(*data_iter): 1440 | if not conf_abstention: 1441 | inp = None 1442 | if config.inp_logit: 1443 | inp = y_hat 1444 | elif config.inp_pred: 1445 | inp = torch.sigmoid(y_hat) 1446 | elif config.inp_h: 1447 | inp = x 1448 | elif config.inp_conf: 1449 | inp = torch.sigmoid(y_hat).cpu().apply_(lambda x: x if x >= 0.5 else 1 - x).cuda(device) 1450 | 1451 | pred_obj = torch.squeeze(reject_model.pred(inp)) 1452 | abs_scores = pred_obj.data.cpu().numpy().tolist() 1453 | else: 1454 | # conf_abstention methods 1455 | abs_scores = [] 1456 | confs = torch.sigmoid(y_hat).cpu().apply_(lambda x: x if x >= 0.5 else 1 - x) 1457 | confs = confs.data.numpy().tolist() 1458 | for y_hhat in confs: 1459 | abs_score = self.compute_exactly_k(42, y_hhat) 1460 | abs_scores.append(abs_score) 1461 | 1462 | y_hat = y_hat.data.cpu().numpy() 1463 | y = y.data.cpu().numpy() 1464 | for i, abs_score in enumerate(abs_scores): 1465 | preds = (y_hat[i] > 0.5).astype(float) 1466 | y_np = y[i] 1467 | prior_score_pred_y_pairs.append((abs_score, [preds, y_np])) 1468 | 1469 | # dropping process 1470 | total_examples = len(prior_score_pred_y_pairs) 1471 | drop_num = int(math.ceil(total_examples * drop_portion)) 1472 | 1473 | # drop from smallest value to largest value (accuracy) 1474 | sorted_list = sorted(prior_score_pred_y_pairs, key=lambda x: x[0], reverse=score_reverse) 1475 | 1476 | accepted_exs = sorted_list[drop_num:] # take examples after drop_num 1477 | rejected_exs = sorted_list[:drop_num] 1478 | 1479 | # then we compute the EM, micro-F1, macro-F1 1480 | all_preds, all_y_labels = [], [] 1481 | for ex in accepted_exs: 1482 | pred, y = ex[1] 1483 | all_preds.append(pred); 1484 | all_y_labels.append(y) 1485 | 1486 | if return_dropped: 1487 | rej_preds = [] 1488 | rej_y_labels = [] 1489 | for ex in rejected_exs: 1490 | pred, y = ex[1] 1491 | rej_preds.append(pred); 1492 | rej_y_labels.append(y) 1493 | 1494 | preds = np.vstack(all_preds) 1495 | ys = np.vstack(all_y_labels) 1496 | 1497 | # this is actually the accurate exact match 1498 | em = metrics.accuracy_score(ys, preds) 1499 | p, r, f1, s = metrics.precision_recall_fscore_support(ys, preds, average=None) 1500 | f1 = np.average(f1, weights=s) if weighted_f1 else np.average(f1[f1.nonzero()]) 1501 | 1502 | if return_dropped: 1503 | return rej_preds, rej_y_labels, em, f1 1504 | 1505 | return em, f1 1506 | 1507 | def get_ems_f1s(self, data_iter, model, config, device, conf_abstention=False, weighted_f1=True): 1508 | # data_iter: test data 1509 | # data_iter, reject_model, drop_portion, config, device 1510 | ems = []; 1511 | f1s = [] 1512 | rej_portions = np.linspace(0., 0.9, num=9) 1513 | for rej_p in rej_portions: 1514 | em, f1 = self.drop(data_iter, model, rej_p, config, device, conf_abstention, weighted_f1=weighted_f1) 1515 | ems.append(em); 1516 | f1s.append(f1) 1517 | return ems, 
f1s 1518 | 1519 | def get_deeptag_data(self, run_order, device, rebuild_vocab=True): 1520 | # send the model in here, we run it 1521 | # need to specify which model to load (exact number) 1522 | # the "data" obtained are universal -- meaning they stay the same during the 1523 | # abstention module training 1524 | if rebuild_vocab: 1525 | self.dataset.build_vocab(self.deeptag_config, True) 1526 | 1527 | self.experiment.set_run_random_seed(run_order) 1528 | 1529 | trainer = self.experiment.get_trainer(self.deeptag_config, device, run_order, build_vocab=False, load=True) 1530 | train_data, test_data = trainer.get_abstention_data() 1531 | 1532 | return train_data, test_data 1533 | 1534 | def save_deeptag_data(self, run_order, device, rebuild_vocab=True): 1535 | # save it into the same format as LTR Vol 2. 1536 | train_data, test_data = self.get_deeptag_data(run_order, device, rebuild_vocab) 1537 | list_train_data = [] 1538 | list_test_data = [] 1539 | for i in range(len(train_data)): 1540 | list_train_data.append([train_data[i][j].data.cpu().numpy().tolist() for j in range(len(train_data[i]))]) 1541 | for i in range(len(test_data)): 1542 | list_test_data.append([test_data[i][j].data.cpu().numpy().tolist() for j in range(len(test_data[i]))]) 1543 | 1544 | with open(pjoin(self.experiment.exp_save_path, self.deeptag_config.run_name, 1545 | "train_data.json"), 'wb') as f: 1546 | json.dump(list_train_data, f) 1547 | 1548 | with open(pjoin(self.experiment.exp_save_path, self.deeptag_config.run_name, 1549 | "test_data.json"), 'wb') as f: 1550 | json.dump(list_test_data, f) 1551 | 1552 | 1553 | # Experiment class can also be "handled" by Jupyter Notebook 1554 | # Usage guide: 1555 | # config also manages random seed. So it's possible to just swap in and out random seed from config 1556 | # to run an average, can write it into another function inside Experiment class called `repeat_execute()` 1557 | # also, currently once trainer is deleted, the classifier pointer would be lost...completely 1558 | class Experiment(object): 1559 | def __init__(self, dataset, exp_save_path): 1560 | """ 1561 | :param dataset: Dataset class 1562 | :param exp_save_path: the overall saving folder 1563 | """ 1564 | if not os.path.exists(exp_save_path): 1565 | os.makedirs(exp_save_path) 1566 | 1567 | self.dataset = dataset 1568 | self.exp_save_path = exp_save_path 1569 | self.saved_random_states = [49537527, 50069528, 44150907, 25982144, 12302344] 1570 | 1571 | # we never want to overwrite this file 1572 | if not os.path.exists(pjoin(exp_save_path, "all_runs_stats.csv")): 1573 | with open(pjoin(self.exp_save_path, "all_runs_stats.csv"), 'w') as f: 1574 | csv_writer = csv.writer(f) 1575 | csv_writer.writerow(['model', 'CSU EM', 'CSU micro-P', 'CSU micro-R', 'CSU micro-F1', 1576 | 'CSU macro-P', 'CSU macro-R', 'CSU macro-F1', 1577 | 'PP EM', 'PP micro-P', 'PP micro-R', 'PP micro-F1', 1578 | 'PP macro-P', 'PP macro-R', 'PP macro-F1']) 1579 | 1580 | def get_trainer(self, config, device, run_order=0, build_vocab=False, load=False, silent=True, **kwargs): 1581 | # build each trainer and classifier by config; or reload classifier 1582 | # **kwargs: additional commands for the two losses 1583 | 1584 | if build_vocab: 1585 | self.dataset.build_vocab(config, silent) # because we might try different word embedding size 1586 | 1587 | self.set_random_seed(config) 1588 | 1589 | classifier = Classifier(self.dataset.vocab, config) 1590 | logging.info(classifier) 1591 | trainer_folder = config.run_name if config.run_name != 'default' else
self.config_to_string(config) 1592 | trainer = Trainer(classifier, self.dataset, config, 1593 | save_path=pjoin(self.exp_save_path, trainer_folder), 1594 | device=device, load=load, run_order=run_order, **kwargs) 1595 | 1596 | return trainer 1597 | 1598 | def set_random_seed(self, config): 1599 | seed = config.seed 1600 | torch.manual_seed(seed) 1601 | np.random.seed(seed) 1602 | random.seed(seed) 1603 | torch.cuda.manual_seed_all(config.seed) # need to seed cuda too 1604 | 1605 | # I'm not sure if after setting random seed, should we set random state again... 1606 | def set_run_random_seed(self, run_order): 1607 | seed = self.saved_random_states[run_order] 1608 | torch.manual_seed(seed) 1609 | np.random.seed(seed) 1610 | random.seed(seed) 1611 | torch.cuda.manual_seed_all(seed) 1612 | 1613 | def config_to_string(self, config): 1614 | # we compare config to baseline config, if values are modified, we produce it into string 1615 | model_name = "mod" # this will be the "baseline" 1616 | base_config = LSTMBaseConfig() 1617 | for k, new_v in config.items(): 1618 | if k in base_config.keys(): 1619 | old_v = base_config[k] 1620 | if old_v != new_v: 1621 | model_name += "_{}_{}".format(k, new_v) 1622 | else: 1623 | model_name += "_{}_{}".format(k, new_v) 1624 | 1625 | return model_name.replace('.', '').replace('-', '_') # for 1e-3 to 1e_3 1626 | 1627 | def record_meta_result(self, meta_results, append, config, file_name='all_runs_stats.csv'): 1628 | # this records result one line at a time! 1629 | mode = 'a' if append else 'w' 1630 | model_str = self.config_to_string(config) 1631 | 1632 | csu_em, csu_micro_tup, csu_macro_tup, \ 1633 | pp_em, pp_micro_tup, pp_macro_tup = meta_results 1634 | 1635 | with open(pjoin(self.exp_save_path, file_name), mode=mode) as f: 1636 | csv_writer = csv.writer(f) 1637 | csv_writer.writerow([model_str, csu_em, csu_micro_tup[0], 1638 | csu_micro_tup[1], csu_micro_tup[2], 1639 | csu_macro_tup[0], csu_macro_tup[1], csu_macro_tup[2], 1640 | pp_em, pp_micro_tup[0], pp_micro_tup[1], pp_micro_tup[2], 1641 | pp_macro_tup[0], pp_macro_tup[1], pp_macro_tup[2]]) 1642 | 1643 | def record_per_run_result(self, meta_results, append, trainer_path, run_order, print_header=False): 1644 | mode = 'a' if append else 'w' 1645 | 1646 | csu_em, csu_micro_tup, csu_macro_tup, \ 1647 | pp_em, pp_micro_tup, pp_macro_tup = meta_results 1648 | 1649 | with open(pjoin(trainer_path, "avg_run_stats.csv"), mode=mode) as f: 1650 | csv_writer = csv.writer(f) 1651 | if print_header: 1652 | csv_writer.writerow(['run order', 'CSU EM', 'CSU micro-P', 'CSU micro-R', 'CSU micro-F1', 1653 | 'CSU macro-P', 'CSU macro-R', 'CSU macro-F1', 1654 | 'PP EM', 'PP micro-P', 'PP micro-R', 'PP micro-F1', 1655 | 'PP macro-P', 'PP macro-R', 'PP macro-F1']) 1656 | 1657 | csv_writer.writerow(['runtime_{}'.format(run_order), csu_em, csu_micro_tup[0], 1658 | csu_micro_tup[1], csu_micro_tup[2], 1659 | csu_macro_tup[0], csu_macro_tup[1], csu_macro_tup[2], 1660 | pp_em, pp_micro_tup[0], pp_micro_tup[1], pp_micro_tup[2], 1661 | pp_macro_tup[0], pp_macro_tup[1], pp_macro_tup[2]]) 1662 | 1663 | def execute_trainer(self, trainer, train_epochs=5, append=True): 1664 | # used jointly with `get_trainer()` 1665 | # the benefit of this function is it will record meta-result into a file... 
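        # typical call sequence (illustrative; any Experiment instance works):
        #     trainer = curr_exp.get_trainer(config, device, build_vocab=True)
        #     curr_exp.execute_trainer(trainer)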
1666 | # use this to "evaluate" a model 1667 | trainer.train(epochs=train_epochs) 1668 | csu_em, csu_micro_tup, csu_macro_tup = trainer.test() 1669 | trainer.logger.info("===== Evaluating on PP data =====") 1670 | pp_em, pp_micro_tup, pp_macro_tup = trainer.evaluate(is_external=True) 1671 | trainer.logger.info("PP accuracy = {}".format(pp_em)) 1672 | self.record_meta_result([csu_em, csu_micro_tup, csu_macro_tup, 1673 | pp_em, pp_micro_tup, pp_macro_tup], 1674 | append=append, config=trainer.config) 1675 | 1676 | def execute(self, config, device, train_epochs=5, append=True): 1677 | # combined get_trainer() and execute_trainer() 1678 | # this is also "training"...not evaluating 1679 | agg_csu_ems, agg_pp_ems = [], [] 1680 | agg_csu_micro_tup, agg_csu_macro_tup = [], [] 1681 | agg_pp_micro_tup, agg_pp_macro_tup = [], [] 1682 | 1683 | self.dataset.build_vocab(config, True) 1684 | trainer_folder = config.run_name if config.run_name != 'default' else self.config_to_string(config) 1685 | 1686 | for run_order in range(config.avg_run_times): 1687 | self.set_run_random_seed(run_order) # hopefully this is enough... 1688 | 1689 | classifier = Classifier(self.dataset.vocab, config) 1690 | trainer = Trainer(classifier, self.dataset, config, 1691 | save_path=pjoin(self.exp_save_path, trainer_folder), 1692 | device=device) 1693 | 1694 | trainer.train(run_order, train_epochs) 1695 | csu_em, csu_micro_tup, csu_macro_tup = trainer.test() 1696 | 1697 | trainer.logger.info("===== Evaluating on PP data =====") 1698 | pp_em, pp_micro_tup, pp_macro_tup = trainer.evaluate(is_external=True) 1699 | trainer.logger.info("PP accuracy = {}".format(pp_em)) 1700 | 1701 | print_header = run_order == 0 1702 | 1703 | self.record_per_run_result([csu_em, csu_micro_tup, csu_macro_tup, 1704 | pp_em, pp_micro_tup, pp_macro_tup], 1705 | append=append, trainer_path=trainer.save_path, run_order=run_order, 1706 | print_header=print_header) 1707 | 1708 | agg_csu_ems.append(csu_em); 1709 | agg_pp_ems.append(pp_em) 1710 | agg_csu_micro_tup.append(np.array(csu_micro_tup)); 1711 | agg_csu_macro_tup.append(np.array(csu_macro_tup)) 1712 | agg_pp_micro_tup.append(np.array(pp_micro_tup)); 1713 | agg_pp_macro_tup.append(np.array(pp_macro_tup)) 1714 | 1715 | csu_avg_em, pp_avg_em = np.average(agg_csu_ems), np.average(agg_pp_ems) 1716 | csu_avg_micro, csu_avg_macro = np.average(agg_csu_micro_tup, axis=0).tolist(), np.average(agg_csu_macro_tup, 1717 | axis=0).tolist() 1718 | pp_avg_micro, pp_avg_macro = np.average(agg_pp_micro_tup, axis=0).tolist(), np.average(agg_pp_macro_tup, 1719 | axis=0).tolist() 1720 | 1721 | self.record_meta_result([csu_avg_em, csu_avg_micro, csu_avg_macro, 1722 | pp_avg_em, pp_avg_micro, pp_avg_macro], 1723 | append=append, config=config) 1724 | 1725 | def delete_trainer(self, trainer): 1726 | # move all parameters to cpu and then delete the pointer 1727 | trainer.classifier.cpu() 1728 | del trainer.classifier 1729 | del trainer 1730 | 1731 | def compute_label_metrics_ci(self, config, list_metric_matrix): 1732 | label_list_metric = [[] for _ in range(config.label_size)] 1733 | mean, ubs, lbs = [], [], [] 1734 | 1735 | for j in range(config.label_size): 1736 | for mm in list_metric_matrix: 1737 | label_list_metric[j].append(mm[j]) 1738 | 1739 | for j in range(config.label_size): 1740 | mean.append(np.mean(label_list_metric[j])) 1741 | lb, ub = get_ci(label_list_metric[j], return_range=True) 1742 | ubs.append(ub); 1743 | lbs.append(lb) 1744 | 1745 | return mean, ubs, lbs 1746 | 1747 | def get_meta_result(self, config, 
device, rebuild_vocab=False, silent=False, return_avg=True, 1748 | print_to_file=False, file_name='', append=True): 1749 | # returns: csu_avg_em, csu_avg_micro, csu_avg_macro, pp_avg_em, pp_avg_micro, pp_avg_macro 1750 | # basically ONE row in the results table. 1751 | # return_avg: return 5 runs individually (for std, ci calculation), or return average only 1752 | if rebuild_vocab: 1753 | self.dataset.build_vocab(config, True) 1754 | 1755 | agg_csu_ems, agg_pp_ems = [], [] 1756 | agg_csu_micro_tup, agg_csu_macro_tup = [], [] 1757 | agg_pp_micro_tup, agg_pp_macro_tup = [], [] 1758 | 1759 | for run_order in range(config.avg_run_times): 1760 | if not silent: 1761 | print("Executing order {}".format(run_order)) 1762 | trainer = self.get_trainer(config, device, run_order, build_vocab=False, load=True) 1763 | csu_em, csu_micro_tup, csu_macro_tup = trainer.test(silent=silent) 1764 | pp_em, pp_micro_tup, pp_macro_tup = trainer.evaluate(is_external=True, silent=silent) 1765 | 1766 | agg_csu_ems.append(csu_em); 1767 | agg_csu_micro_tup.append(np.array(csu_micro_tup)) 1768 | agg_csu_macro_tup.append(np.array(csu_macro_tup)) 1769 | agg_pp_micro_tup.append(np.array(pp_micro_tup)); 1770 | agg_pp_macro_tup.append(np.array(pp_macro_tup)) 1771 | agg_pp_ems.append(pp_em) 1772 | 1773 | csu_avg_em, pp_avg_em = np.average(agg_csu_ems), np.average(agg_pp_ems) 1774 | csu_avg_micro, csu_avg_macro = np.average(agg_csu_micro_tup, axis=0).tolist(), np.average(agg_csu_macro_tup, 1775 | axis=0).tolist() 1776 | pp_avg_micro, pp_avg_macro = np.average(agg_pp_micro_tup, axis=0).tolist(), np.average(agg_pp_macro_tup, 1777 | axis=0).tolist() 1778 | 1779 | if print_to_file: 1780 | assert file_name != '' 1781 | self.record_meta_result([csu_avg_em, csu_avg_micro, csu_avg_macro, 1782 | pp_avg_em, pp_avg_micro, pp_avg_macro], 1783 | append=append, config=config, file_name=file_name) 1784 | elif return_avg: 1785 | return [csu_avg_em, csu_avg_micro[0], 1786 | csu_avg_micro[1], csu_avg_micro[2], 1787 | csu_avg_macro[0], csu_avg_macro[1], csu_avg_macro[2], 1788 | pp_avg_em, pp_avg_micro[0], pp_avg_micro[1], pp_avg_micro[2], 1789 | pp_avg_macro[0], pp_avg_macro[1], pp_avg_macro[2]] 1790 | else: 1791 | return [agg_csu_ems, agg_csu_micro_tup, agg_csu_macro_tup, 1792 | agg_pp_ems, agg_pp_micro_tup, agg_pp_macro_tup] 1793 | 1794 | def get_meta_header(self): 1795 | # return a list of headers 1796 | # in real scenario, the first column is often 'model name' or 'run order' 1797 | return ['CSU EM', 'CSU micro-P', 'CSU micro-R', 'CSU micro-F1', 1798 | 'CSU macro-P', 'CSU macro-R', 'CSU macro-F1', 1799 | 'PP EM', 'PP micro-P', 'PP micro-R', 'PP micro-F1', 1800 | 'PP macro-P', 'PP macro-R', 'PP macro-F1'] 1801 | 1802 | def evaluate(self, config, device, is_external=False, rebuild_vocab=False, silent=False, 1803 | return_f1_ci=False): 1804 | # Similr to trainer.evaluate() signature 1805 | # but allows to handle multi-run averaging! 1806 | # we also always return by_label_stats 1807 | # return: p,r,f1,s,accu 1808 | 1809 | if rebuild_vocab: 1810 | self.dataset.build_vocab(config, True) 1811 | 1812 | agg_p, agg_r, agg_f1, agg_accu, agg_roc_auc = 0., 0., 0., 0., 0. 
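        # running sums below are taken over the avg_run_times saved runs (model-0.pickle,
        # model-1.pickle, ...) and divided by the run count at the end, so the returned
        # per-label p/r/f1/accuracy/ROC-AUC values are means across runs.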
1813 | agg_f1_list = [] 1814 | 1815 | for run_order in range(config.avg_run_times): 1816 | if not silent: 1817 | print("Executing order {}".format(run_order)) 1818 | trainer = self.get_trainer(config, device, run_order, build_vocab=False, load=True) 1819 | p, r, f1, s, accu, roc_auc = trainer.evaluate(is_test=True, is_external=is_external, return_by_label_stats=True, 1820 | silent=True, return_roc_auc=True) 1821 | agg_p += p; 1822 | agg_r += r; 1823 | agg_f1 += f1; 1824 | agg_accu += accu; 1825 | agg_roc_auc += roc_auc; 1826 | agg_f1_list.append(f1) 1827 | 1828 | if return_f1_ci: 1829 | return self.compute_label_metrics_ci(config, agg_f1_list) 1830 | 1831 | agg_p, agg_r, agg_f1, agg_accu, agg_roc_auc = agg_p / float(config.avg_run_times), agg_r / float(config.avg_run_times), \ 1832 | agg_f1 / float(config.avg_run_times), agg_accu / float(config.avg_run_times), \ 1833 | agg_roc_auc / float(config.avg_run_times) 1834 | 1835 | return agg_p, agg_r, agg_f1, agg_accu, agg_roc_auc 1836 | 1837 | def get_performance(self, config): 1838 | # actually looks into trainer's actual file 1839 | # returns: [(avg, std, CI), ...] 1840 | trainer_folder = config.run_name if config.run_name != 'default' else self.config_to_string(config) 1841 | 1842 | stat_array = defaultdict(list) 1843 | cat_size = 0. 1844 | 1845 | with open(pjoin(self.exp_save_path, trainer_folder, 'avg_run_stats.csv'), 'r') as f: 1846 | csv_reader = csv.reader(f) 1847 | for i, line in enumerate(csv_reader): 1848 | if i == 0: 1849 | continue 1850 | cat_size = len(line[1:]) 1851 | for j, stat in enumerate(line[1:]): 1852 | stat_array[j].append(float(stat)) 1853 | 1854 | stats_res = [0.] * cat_size 1855 | 1856 | for j in range(cat_size): 1857 | stats_res[j] = (np.mean(stat_array[j]), np.std(stat_array[j]), get_ci(stat_array[j])) 1858 | 1859 | return stats_res 1860 | 1861 | 1862 | # Important! 
Each time you use "get_iterators", must restore previous random state 1863 | # otherwise the sampling procedure will be different 1864 | def run_baseline(device, label_size): 1865 | random.setstate(orig_state) 1866 | lstm_base_c = LSTMBaseConfig(emb_corpus=emb_corpus, avg_run_times=avg_run_times, 1867 | label_size=label_size, 1868 | conv_enc=use_conv) 1869 | curr_exp.execute(lstm_base_c, train_epochs=train_epochs, device=device) 1870 | # trainer = curr_exp.get_trainer(config=lstm_base_c, device=device, build_vocab=True) 1871 | # curr_exp.execute(trainer=trainer) 1872 | 1873 | 1874 | def run_bidir_baseline(device, label_size): 1875 | random.setstate(orig_state) 1876 | lstm_bidir_c = LSTMBaseConfig(bidir=True, emb_corpus=emb_corpus, avg_run_times=avg_run_times, 1877 | label_size=label_size, 1878 | conv_enc=use_conv) 1879 | curr_exp.execute(lstm_bidir_c, train_epochs=train_epochs, device=device) 1880 | # trainer = curr_exp.get_trainer(config=lstm_bidir_c, device=device, build_vocab=True) 1881 | # curr_exp.execute(trainer=trainer) 1882 | 1883 | 1884 | def run_m_penalty(device, beta=1e-3, bidir=False): 1885 | random.setstate(orig_state) 1886 | config = LSTM_w_M_Config(beta, bidir=bidir, emb_corpus=emb_corpus, avg_run_times=avg_run_times, 1887 | conv_enc=use_conv) 1888 | curr_exp.execute(config, train_epochs=train_epochs, device=device) 1889 | # trainer = curr_exp.get_trainer(config=config, device=device, build_vocab=True) 1890 | # curr_exp.execute(trainer=trainer) 1891 | 1892 | 1893 | def run_c_penalty(device, sigma_M, sigma_B, sigma_W, bidir=False): 1894 | random.setstate(orig_state) 1895 | config = LSTM_w_C_Config(sigma_M, sigma_B, sigma_W, bidir=bidir, emb_corpus=emb_corpus, 1896 | avg_run_times=avg_run_times, conv_enc=use_conv) 1897 | curr_exp.execute(config, train_epochs=train_epochs, device=device) 1898 | # trainer = curr_exp.get_trainer(config=config, device=device, build_vocab=True) 1899 | # curr_exp.execute(trainer=trainer) 1900 | 1901 | 1902 | use_conv = 0 1903 | 1904 | if __name__ == '__main__': 1905 | # if we just call this file, it will set up an interactive console 1906 | random.seed(1234) 1907 | 1908 | # we get the original random state, and simply reset during each run 1909 | orig_state = random.getstate() 1910 | 1911 | action = raw_input("enter branches of default actions: active | baseline | meta | cluster \n") 1912 | 1913 | device_num = int(raw_input("enter the GPU device number \n")) 1914 | assert -1 <= device_num <= 3, "GPU ID must be between -1 and 3" 1915 | 1916 | exp_name = raw_input("enter the experiment name, default is 'csu_new_exp', skip to use default: ") 1917 | exp_name = 'csu_new_exp' if exp_name.strip() == '' else exp_name 1918 | 1919 | emb_corpus = raw_input("enter embedding choice, skip for default: gigaword | common_crawl \n") 1920 | emb_corpus = 'gigaword' if emb_corpus.strip() == '' else emb_corpus 1921 | assert emb_corpus == 'gigaword' or emb_corpus == 'common_crawl' 1922 | 1923 | avg_run_times = raw_input("enter run times (intger), maximum 5: \n") # default 1, but should run 5 times 1924 | avg_run_times = 1 if avg_run_times.strip() == '' else int(avg_run_times) 1925 | avg_run_times = 5 if avg_run_times > 5 else avg_run_times 1926 | 1927 | dataset_number = raw_input("enter dataset name prefix id (1=snomed_multi_label_no_des_ \n " 1928 | "2=snomed_revised_fields_multi_label_no_des_ \n" 1929 | "3=snomed_all_fields_multi_label_no_des_\n" 1930 | "4=snomed_fine_grained_multi_label_no_des_): \n") 1931 | 1932 | label_size = 42 1933 | test_data_name = 
'adobe_combined_abbr_matched_snomed_multi_label_no_des_test.tsv' 1934 | if dataset_number.strip() == "": 1935 | print("Default choice to 1") 1936 | dataset_prefix = 'snomed_multi_label_no_des_' 1937 | elif int(dataset_number) == 1: 1938 | dataset_prefix = 'snomed_multi_label_no_des_' 1939 | elif int(dataset_number) == 2: 1940 | dataset_prefix = 'snomed_revised_fields_multi_label_no_des_' 1941 | elif int(dataset_number) == 3: 1942 | dataset_prefix = 'snomed_all_fields_multi_label_no_des_' 1943 | elif int(dataset_number) == 4: 1944 | dataset_prefix = 'snomed_fine_grained_multi_label_no_des_' 1945 | label_size = 4577 1946 | test_data_name = 'adobe_combined_abbr_matched_snomed_fine_grained_label_no_des_test.tsv' 1947 | 1948 | conv_encoder = raw_input("Use conv_encoder or not? 0/1(Hierarchical)/2(Normal)/3(TextCNN) \n") 1949 | assert (conv_encoder == '0' or conv_encoder == '1' or conv_encoder == '2' or conv_encoder == '3') 1950 | 1951 | global use_conv 1952 | use_conv = int(conv_encoder.strip()) 1953 | 1954 | train_epochs = raw_input("Enter the number of training epochs: (default 5) \n") 1955 | if train_epochs.strip() == "": 1956 | train_epochs = 5 1957 | else: 1958 | train_epochs = int(train_epochs.strip()) 1959 | 1960 | print("loading in dataset...will take 3-4 minutes...") 1961 | dataset = Dataset(dataset_prefix=dataset_prefix, label_size=label_size, test_data_name=test_data_name) 1962 | 1963 | curr_exp = Experiment(dataset=dataset, exp_save_path='./{}/'.format(exp_name)) 1964 | 1965 | if action == 'active': 1966 | import IPython; 1967 | 1968 | IPython.embed() 1969 | elif action == 'baseline': 1970 | # baseline LSTM 1971 | run_baseline(device_num, label_size) 1972 | run_bidir_baseline(device_num, label_size) 1973 | elif action == 'meta': 1974 | # baseline LSTM + M 1975 | # run_m_penalty(device_num, beta=1e-3) 1976 | # run_m_penalty(device_num, beta=1e-4) 1977 | 1978 | # run_baseline(device_num) 1979 | # run_bidir_baseline(device_num) 1980 | # 1981 | # # baseline LSTM + M + bidir 1982 | assert 'fine_grained' not in dataset_prefix 1983 | run_m_penalty(device_num, beta=1e-4, bidir=True) 1984 | run_m_penalty(device_num, beta=1e-3, bidir=True) 1985 | # 1986 | # run_c_penalty(device_num, sigma_M=1e-5, sigma_B=1e-4, sigma_W=1e-4, bidir=True) 1987 | # run_c_penalty(device_num, sigma_M=1e-4, sigma_B=1e-3, sigma_W=1e-3, bidir=True) 1988 | # 1989 | run_m_penalty(device_num, beta=1e-4) 1990 | run_m_penalty(device_num, beta=1e-3) 1991 | 1992 | # run_c_penalty(device_num, sigma_M=1e-5, sigma_B=1e-4, sigma_W=1e-4) 1993 | # run_c_penalty(device_num, sigma_M=1e-4, sigma_B=1e-3, sigma_W=1e-3) 1994 | 1995 | elif action == 'cluster': 1996 | assert 'fine_grained' not in dataset_prefix 1997 | # baseline LSTM + C 1998 | run_c_penalty(device_num, sigma_M=1e-5, sigma_B=1e-4, sigma_W=1e-4) 1999 | run_c_penalty(device_num, sigma_M=1e-4, sigma_B=1e-3, sigma_W=1e-3) 2000 | 2001 | # baseline LSTM + C + bidir 2002 | run_c_penalty(device_num, sigma_M=1e-5, sigma_B=1e-4, sigma_W=1e-4, bidir=True) 2003 | run_c_penalty(device_num, sigma_M=1e-4, sigma_B=1e-3, sigma_W=1e-3, bidir=True) 2004 | else: 2005 | print("Non-identifiable action: {}".format(action)) 2006 | --------------------------------------------------------------------------------
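A minimal, self-contained sketch (toy arrays; not part of the repository) of the abstention sweep implemented by `Abstention.drop` and `get_ems_f1s` above: examples are sorted by an abstention score, the lowest-scoring fraction is rejected, and exact match is recomputed on the accepted remainder.

import numpy as np
from sklearn import metrics

preds  = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])   # toy multi-label predictions
ys     = np.array([[1, 0], [1, 1], [1, 1], [0, 1]])   # toy gold labels
scores = np.array([0.9, 0.2, 0.8, 0.1])               # hypothetical abstention scores (higher = keep)

order = np.argsort(scores)  # ascending, so the least trusted examples are dropped first
for drop_portion in (0.0, 0.25, 0.5):
    drop_num = int(np.ceil(len(scores) * drop_portion))
    kept = order[drop_num:]
    print(drop_portion, metrics.accuracy_score(ys[kept], preds[kept]))  # exact match on the kept set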