├── test_run.sh ├── README.md ├── .gitignore ├── decorators.py ├── process_twitter_data.py ├── process_ner_data.py ├── process_ud_data.py ├── stochastic_layers.py ├── model_utils.py ├── config.py ├── vsl_g.py ├── vsl_gg.py ├── train_helper.py ├── models.py └── data_utils.py /test_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate py35 4 | 5 | python vsl_g.py \ 6 | --debug 1 \ 7 | --model g \ 8 | --data_file twitter.data \ 9 | --prior_file test_g \ 10 | --vocab_file twitter \ 11 | --tag_file tagger_twitter \ 12 | --embed_file wordvects.tw100w5-m40-it2 \ 13 | --n_iter 30000 \ 14 | --save_prior 1 \ 15 | --train_emb 0 \ 16 | --tie_weights 1 \ 17 | --embed_dim 100 \ 18 | --latent_z_size 50 \ 19 | --update_freq_label 1 \ 20 | --update_freq_unlabel 1 \ 21 | --rnn_size 100 \ 22 | --char_embed_dim 50 \ 23 | --char_hidden_size 100 \ 24 | --mlp_layer 2 \ 25 | --mlp_hidden_size 100 \ 26 | --learning_rate 1e-3 \ 27 | --vocab_size 100000 \ 28 | --batch_size 10 \ 29 | --kl_anneal_rate 1e-4 \ 30 | --print_every 100 \ 31 | --eval_every 1000 \ 32 | --summarize 1 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VSL 2 | 3 | A PyTorch implementation of "[Variational Sequential Labelers for Semi-Supervised Learning](http://ttic.uchicago.edu/~mchen/papers/mchen+etal.emnlp18.pdf)" (EMNLP 2018) 4 | 5 | 6 | ## Prerequisites 7 | 8 | - Python 3.5 9 | - PyTorch 0.3.0 10 | - Scikit-Learn 11 | - NumPy 12 | 13 | ## Data and Pretrained Embeddings 14 | 15 | Download: [Twitter](https://code.google.com/archive/p/ark-tweet-nlp/downloads), [Universal Dependencies](https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-1827?show=full), [Embeddings (for Twitter and UD)](https://drive.google.com/drive/folders/1oie43_thsbhhoUsOHlkyKj2iMpFNOrgA?usp=sharing) 16 | 17 | Run `process_{ner,twitter,ud}_data.py` first to generate `*.pkl` files and then use it as input for `vsl_{g,gg}.py`. 18 | 19 | ## Citation 20 | 21 | ``` 22 | @inproceedings{mchen-variational-18, 23 | author = {Mingda Chen and Qingming Tang and Karen Livescu and Kevin Gimpel}, 24 | title = {Variational Sequential Labelers for Semi-Supervised Learning}, 25 | booktitle = {Proc. of {EMNLP}}, 26 | year = {2018} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
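A minimal sketch of the preprocess-then-train workflow described in the README above, assuming Twitter data; the input paths are placeholders, while the output names (data/twitter1.0.data, twitter_tagfile) follow what process_twitter_data.py writes and the training flags mirror test_run.sh:

```bash
# placeholder input paths; point these at the downloaded Twitter POS files
# (process_twitter_data.py writes its pickle into data/, so create that directory first)
python process_twitter_data.py --train train.txt --dev dev.txt --test test.txt --ratio 1.0
# train the single-latent-variable ("g") model; see test_run.sh for the full flag set
python vsl_g.py --model g --data_file data/twitter1.0.data --vocab_file twitter \
    --tag_file twitter_tagfile --embed_file wordvects.tw100w5-m40-it2
```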
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /decorators.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import pickle 3 | import os 4 | 5 | 6 | def auto_init_args(init): 7 | def new_init(self, *args, **kwargs): 8 | arg_dict = inspect.signature(init).parameters 9 | arg_names = list(arg_dict.keys())[1:] # skip self 10 | proc_names = set() 11 | for name, arg in zip(arg_names, args): 12 | setattr(self, name, arg) 13 | proc_names.add(name) 14 | for name, arg in kwargs.items(): 15 | setattr(self, name, arg) 16 | proc_names.add(name) 17 | remain_names = set(arg_names) - proc_names 18 | if len(remain_names): 19 | for name in remain_names: 20 | setattr(self, name, arg_dict[name].default) 21 | init(self, *args, **kwargs) 22 | 23 | return new_init 24 | 25 | 26 | def auto_init_pytorch(init): 27 | def new_init(self, *args, **kwargs): 28 | init(self, *args, **kwargs) 29 | self.opt = self.init_optimizer( 30 | self.expe.config.opt, 31 | self.expe.config.lr, 32 | self.expe.config.l2) 33 | 34 | if self.use_cuda: 35 | self.cuda() 36 | self.expe.log.info("transferred model to gpu") 37 | 38 | return new_init 39 | 40 | 41 | class lazy_execute: 42 | @auto_init_args 43 | def __init__(self, func_name): 44 | pass 45 | 46 | def __call__(self, fn): 47 | func_name = self.func_name 48 | 49 | def new_fn(self, *args, **kwargs): 50 | if os.path.isfile(kwargs['file_name']): 51 | return getattr(self, func_name)(kwargs['file_name']) 52 | else: 53 | data = fn(self, *args, **kwargs) 54 | 55 | self.expe.log.info("saving to {}" 56 | .format(kwargs['file_name'])) 57 | with open(kwargs['file_name'], "wb+") as fp: 58 | pickle.dump(data, fp, protocol=-1) 59 | return data 60 | return new_fn 61 | -------------------------------------------------------------------------------- /process_twitter_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import logging 4 | 5 | from sklearn.model_selection import train_test_split 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser( 10 | description='Data Preprocessing for Twitter') 11 | parser.add_argument('--train', type=str, default=None, 12 | help='train data path') 13 | parser.add_argument('--dev', type=str, default=None, 14 | help='dev data path') 15 | parser.add_argument('--test', 
type=str, default=None, 16 | help='test data path') 17 | parser.add_argument('--ratio', type=float, default=1.0, 18 | help='training data ratio') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def process_file(data_file): 24 | logging.info("loading data from " + data_file + " ...") 25 | sents = [] 26 | tags = [] 27 | with open(data_file, 'r', encoding='utf-8') as df: 28 | for line in df.readlines(): 29 | if line.strip(): 30 | index = line.find('|||') 31 | if index == -1: 32 | raise ValueError('Format Error') 33 | sent = line[: index - 1] 34 | tag = line[index + 4: -1] 35 | sents.append(sent.split(' ')) 36 | tags.append(tag.split(' ')) 37 | return sents, tags 38 | 39 | 40 | if __name__ == "__main__": 41 | logging.basicConfig(level=logging.DEBUG, 42 | format='%(asctime)s %(message)s', 43 | datefmt='%m-%d %H:%M') 44 | args = get_args() 45 | train = process_file(args.train) 46 | dev = process_file(args.dev) 47 | test = process_file(args.test) 48 | 49 | tag_set = set(sum([sum(d[1], []) for d in [train, dev, test]], 50 | [])) 51 | with open("twitter_tagfile", "w+", encoding='utf-8') as fp: 52 | fp.write('\n'.join(sorted(list(tag_set)))) 53 | 54 | if args.ratio != 1: 55 | train_x, test_x, train_y, test_y = \ 56 | train_test_split(train[0], train[1], test_size=args.ratio) 57 | train = [test_x, test_y] 58 | assert len(train_x) == len(train_y) 59 | 60 | logging.info("#train: {}".format(len(train[0]))) 61 | logging.info("#dev: {}".format(len(dev[0]))) 62 | logging.info("#test: {}".format(len(test[0]))) 63 | 64 | pickle.dump( 65 | [train, dev, test], 66 | open("data/twitter{}.data".format(args.ratio), "wb+"), protocol=-1) 67 | -------------------------------------------------------------------------------- /process_ner_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import logging 4 | 5 | from sklearn.model_selection import train_test_split 6 | from collections import Counter 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser( 11 | description='Data Preprocessing for Named Entity Recognition') 12 | parser.add_argument('--train', type=str, default=None, 13 | help='train data path') 14 | parser.add_argument('--dev', type=str, default=None, 15 | help='dev data path') 16 | parser.add_argument('--test', type=str, default=None, 17 | help='test data path') 18 | parser.add_argument('--ratio', type=float, default=1., 19 | help='ratio of labeled data') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def process(word): 25 | word = "".join(c if not c.isdigit() else '0' for c in word) 26 | return word 27 | 28 | 29 | def process_file(data_file): 30 | logging.info("loading data from " + data_file + " ...") 31 | sents = [] 32 | tags = [] 33 | sent = [] 34 | tag = [] 35 | with open(data_file, 'r', encoding='utf-8') as df: 36 | for line in df.readlines(): 37 | if line[0:10] == '-DOCSTART-': 38 | continue 39 | if line.strip(): 40 | word = line.strip().split(" ")[0] 41 | t = line.strip().split(" ")[-1] 42 | sent.append(process(word)) 43 | tag.append(t) 44 | else: 45 | if sent and tag: 46 | sents.append(sent) 47 | tags.append(tag) 48 | sent = [] 49 | tag = [] 50 | return sents, tags 51 | 52 | 53 | if __name__ == "__main__": 54 | logging.basicConfig(level=logging.DEBUG, 55 | format='%(asctime)s %(message)s', 56 | datefmt='%m-%d %H:%M') 57 | args = get_args() 58 | train = process_file(args.train) 59 | dev = process_file(args.dev) 60 | test = process_file(args.test) 61 | 62 | tag_counter = 
Counter(sum(train[1], []) + 63 | sum(dev[1], []) + sum(test[1], [])) 64 | with open("ner_tagfile".format(args.ratio), "w+", encoding='utf-8') as fp: 65 | fp.write('\n'.join(sorted(tag_counter.keys()))) 66 | 67 | if args.ratio < 1: 68 | n_unlabel = len(train[0]) // 2 69 | X_train, X_test, y_train, y_test = \ 70 | train_test_split(train[0], train[1], test_size=args.ratio) 71 | other = [X_train, y_train] 72 | train = [X_test, y_test] 73 | 74 | X_train, X_test, y_train, y_test = \ 75 | train_test_split(other[0], other[1], test_size=n_unlabel) 76 | 77 | unlabel_data = X_test 78 | logging.info("#unlabeled data: {}".format(len(X_test))) 79 | 80 | with open("ner{}_unlabel.data".format(args.ratio), 81 | "w+", encoding='utf-8') as fp: 82 | fp.write( 83 | "\n".join([" ".join([w for w in sent]) 84 | for sent in unlabel_data])) 85 | logging.info( 86 | "unlabeled data saved to {}".format( 87 | "ner{}_unlabel.data".format(args.ratio))) 88 | 89 | logging.info("#train data: {}".format(len(train[0]))) 90 | logging.info("#dev data: {}".format(len(dev[0]))) 91 | logging.info("#test data: {}".format(len(test[0]))) 92 | 93 | pickle.dump( 94 | [train, dev, test], open("ner{}.data".format(args.ratio), "wb+"), 95 | protocol=-1) 96 | -------------------------------------------------------------------------------- /process_ud_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import logging 4 | 5 | from sklearn.model_selection import train_test_split 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser( 10 | description='Data Preprocessing for Universal Dependencies') 11 | parser.add_argument('--train', type=str, default=None, 12 | help='train data path') 13 | parser.add_argument('--dev', type=str, default=None, 14 | help='dev data path') 15 | parser.add_argument('--test', type=str, default=None, 16 | help='test data path') 17 | parser.add_argument('--output', type=str, default=None, 18 | help='output file name') 19 | parser.add_argument('--ratio', type=float, default=1.0, 20 | help='labeled data ratio') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def load_data(data_file): 26 | logging.info("loading data from {} ...".format(data_file)) 27 | sents = [] 28 | tags = [] 29 | sent = [] 30 | tag = [] 31 | with open(data_file, 'r', encoding="utf-8") as f: 32 | for line in f: 33 | if line[0] == "#": 34 | continue 35 | if line.strip(): 36 | token = line.strip("\n").split("\t") 37 | word = token[1] 38 | t = token[3] 39 | sent.append(word) 40 | tag.append(t) 41 | else: 42 | sents.append(sent) 43 | tags.append(tag) 44 | sent = [] 45 | tag = [] 46 | return sents, tags 47 | 48 | 49 | if __name__ == "__main__": 50 | logging.basicConfig(level=logging.DEBUG, 51 | format='%(asctime)s %(message)s', 52 | datefmt='%m-%d %H:%M') 53 | args = get_args() 54 | logging.info("##### training data #####") 55 | all_sents, all_tags = load_data(args.train) 56 | logging.info("random splitting training data with ratio of {}..." 
57 | .format(args.ratio)) 58 | train_sents, unlabel_sents, train_tags, unlabel_tags = \ 59 | train_test_split(all_sents, all_tags, 60 | train_size=args.ratio, shuffle=True) 61 | logging.info("#train sents: {}, #train words: {}, #train tags: {}" 62 | .format(len(train_sents), len(sum(train_sents, [])), 63 | len(sum(train_tags, [])))) 64 | logging.info("#unlabeled sents: {}" 65 | .format(len(unlabel_sents))) 66 | logging.info("##### dev data #####") 67 | dev_sents, dev_tags = load_data(args.dev) 68 | logging.info("#dev sents: {}, #dev words: {}, #dev tags: {}" 69 | .format(len(dev_sents), len(sum(dev_sents, [])), 70 | len(sum(dev_tags, [])))) 71 | logging.info("##### test data #####") 72 | test_sents, test_tags = load_data(args.test) 73 | logging.info("#test sents: {}, #test words: {}, #test tags: {}" 74 | .format(len(test_sents), len(sum(test_sents, [])), 75 | len(sum(test_tags, [])))) 76 | output = "data" if args.output is None else args.output 77 | output += ".ud" 78 | 79 | tag_set = set(sum([sum(d, []) for d in [all_tags, dev_tags, test_tags]], 80 | [])) 81 | with open("ud_tagfile", "w+", encoding='utf-8') as fp: 82 | fp.write('\n'.join(sorted(list(tag_set)))) 83 | dataset = {"train": [train_sents, train_tags], 84 | "unlabel": [unlabel_sents, unlabel_tags], 85 | "dev": [dev_sents, dev_tags], 86 | "test": [test_sents, test_tags]} 87 | pickle.dump(dataset, open(output, "wb+"), protocol=-1) 88 | logging.info("data saved to {}".format(output)) 89 | -------------------------------------------------------------------------------- /stochastic_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | 5 | from model_utils import gaussian 6 | 7 | 8 | class gaussian_layer(nn.Module): 9 | """ 10 | h 11 | | 12 | z 13 | | 14 | x 15 | """ 16 | def __init__(self, input_size, latent_z_size): 17 | super(gaussian_layer, self).__init__() 18 | self.input_size = input_size 19 | self.latent_z_size = latent_z_size 20 | 21 | self.q_mean2_mlp = nn.Linear(input_size, latent_z_size) 22 | self.q_logvar2_mlp = nn.Linear(input_size, latent_z_size) 23 | 24 | def forward(self, inputs, mask, sample): 25 | """ 26 | inputs: batch x batch_len x input_size 27 | """ 28 | batch_size, batch_len, _ = inputs.size() 29 | mean_qs = self.q_mean2_mlp(inputs) 30 | logvar_qs = self.q_logvar2_mlp(inputs) 31 | 32 | if sample: 33 | z = gaussian(mean_qs, logvar_qs) * mask.unsqueeze(-1) 34 | else: 35 | z = mean_qs * mask.unsqueeze(-1) 36 | 37 | return z, mean_qs, logvar_qs 38 | 39 | 40 | class gaussian_flat_layer(nn.Module): 41 | """ 42 | h 43 | / \ 44 | y z 45 | \ / 46 | x 47 | 48 | """ 49 | def __init__(self, input_size, latent_z_size, latent_y_size): 50 | super(gaussian_flat_layer, self).__init__() 51 | self.input_size = input_size 52 | self.latent_y_size = latent_y_size 53 | self.latent_z_size = latent_z_size 54 | 55 | self.q_mean_mlp = nn.Linear(input_size, latent_z_size) 56 | self.q_logvar_mlp = nn.Linear(input_size, latent_z_size) 57 | 58 | self.q_mean2_mlp = nn.Linear(input_size, latent_y_size) 59 | self.q_logvar2_mlp = nn.Linear(input_size, latent_y_size) 60 | 61 | def forward(self, inputs, mask, sample): 62 | """ 63 | inputs: batch x batch_len x input_size 64 | """ 65 | batch_size, batch_len, _ = inputs.size() 66 | 67 | mean_qs = self.q_mean_mlp(inputs) * mask.unsqueeze(-1) 68 | logvar_qs = self.q_logvar_mlp(inputs) * mask.unsqueeze(-1) 69 | 70 | mean2_qs = self.q_mean2_mlp(inputs) * mask.unsqueeze(-1) 71 | logvar2_qs = self.q_logvar2_mlp(inputs) * 
mask.unsqueeze(-1) 72 | 73 | if sample: 74 | y = gaussian(mean2_qs, logvar2_qs) * mask.unsqueeze(-1) 75 | else: 76 | y = mean2_qs * mask.unsqueeze(-1) 77 | 78 | if sample: 79 | z = gaussian(mean_qs, logvar_qs) * mask.unsqueeze(-1) 80 | else: 81 | z = mean_qs * mask.unsqueeze(-1) 82 | 83 | return z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs 84 | 85 | 86 | class gaussian_hier_layer(nn.Module): 87 | """ 88 | h 89 | | 90 | y 91 | | 92 | z 93 | | 94 | x 95 | """ 96 | def __init__(self, input_size, latent_z_size, latent_y_size): 97 | super(gaussian_hier_layer, self).__init__() 98 | self.input_size = input_size 99 | self.latent_y_size = latent_y_size 100 | self.latent_z_size = latent_z_size 101 | 102 | self.q_mean2_mlp = nn.Linear(input_size, latent_y_size) 103 | self.q_logvar2_mlp = nn.Linear(input_size, latent_y_size) 104 | 105 | self.q_mean_mlp = nn.Linear(input_size + latent_y_size, latent_z_size) 106 | self.q_logvar_mlp = nn.Linear( 107 | input_size + latent_y_size, latent_z_size) 108 | 109 | def forward(self, inputs, mask, sample): 110 | """ 111 | inputs: batch x batch_len x input_size 112 | """ 113 | batch_size, batch_len, _ = inputs.size() 114 | 115 | mean2_qs = self.q_mean2_mlp(inputs) 116 | logvar2_qs = self.q_logvar2_mlp(inputs) 117 | 118 | if sample: 119 | y = gaussian(mean2_qs, logvar2_qs) * mask.unsqueeze(-1) 120 | else: 121 | y = mean2_qs * mask.unsqueeze(-1) 122 | 123 | gauss_input = torch.cat([inputs, y], -1) 124 | mean_qs = self.q_mean_mlp(gauss_input) 125 | logvar_qs = self.q_logvar_mlp(gauss_input) 126 | 127 | if sample: 128 | z = gaussian(mean_qs, logvar_qs) * mask.unsqueeze(-1) 129 | else: 130 | z = mean_qs * mask.unsqueeze(-1) 131 | 132 | return z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs 133 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | 5 | from torch.autograd import Variable 6 | 7 | 8 | def gaussian(mean, logvar): 9 | return mean + torch.exp(0.5 * logvar) * \ 10 | Variable(logvar.data.new(mean.size()).normal_()) 11 | 12 | 13 | def to_one_hot(label, n_class): 14 | return Variable( 15 | label.data.new(label.size(0), label.size(1), n_class) 16 | .zero_().scatter_(2, label.data.unsqueeze(-1), 2)).float() 17 | 18 | 19 | def get_mlp_layer(input_size, hidden_size, output_size, n_layer): 20 | if n_layer == 0: 21 | layer = nn.Linear(input_size, output_size) 22 | else: 23 | layer = nn.Sequential( 24 | nn.Linear(input_size, hidden_size), 25 | nn.ReLU()) 26 | 27 | for i in range(n_layer - 1): 28 | layer.add_module( 29 | str(len(layer)), 30 | nn.Linear(hidden_size, hidden_size)) 31 | layer.add_module(str(len(layer)), nn.ReLU()) 32 | 33 | layer.add_module( 34 | str(len(layer)), 35 | nn.Linear(hidden_size, output_size)) 36 | return layer 37 | 38 | 39 | def get_rnn_output(inputs, mask, cell, bidir=False, initial_state=None): 40 | """ 41 | Args: 42 | inputs: batch_size x seq_len x n_feat 43 | mask: batch_size x seq_len 44 | initial_state: batch_size x num_layers x hidden_size 45 | cell: GRU/LSTM/RNN 46 | """ 47 | seq_lengths = torch.sum(mask, dim=-1).squeeze(-1) 48 | sorted_len, sorted_idx = seq_lengths.sort(0, descending=True) 49 | index_sorted_idx = sorted_idx\ 50 | .view(-1, 1, 1).expand_as(inputs) 51 | sorted_inputs = inputs.gather(0, index_sorted_idx.long()) 52 | packed_seq = torch.nn.utils.rnn.pack_padded_sequence( 53 | sorted_inputs, sorted_len.long().cpu().data.numpy(), batch_first=True) 54 | 
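# run the RNN over the packed, length-sorted batch; the outputs are unpacked
# and gathered back into the original batch order below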
out, _ = cell(packed_seq, hx=initial_state) 55 | unpacked, unpacked_len = \ 56 | torch.nn.utils.rnn.pad_packed_sequence( 57 | out, batch_first=True) 58 | _, original_idx = sorted_idx.sort(0, descending=False) 59 | unsorted_idx = original_idx\ 60 | .view(-1, 1, 1).expand_as(unpacked) 61 | output_seq = unpacked.gather(0, unsorted_idx.long()) 62 | idx = (seq_lengths - 1).view(-1, 1).expand( 63 | output_seq.size(0), output_seq.size(2)).unsqueeze(1) 64 | final_state = output_seq.gather(1, idx.long()).squeeze(1) 65 | if bidir: 66 | hsize = final_state.size(-1) // 2 67 | final_state_fw = final_state[:, :hsize] 68 | final_state_bw = output_seq[:, 0, hsize:] 69 | final_state = torch.cat([final_state_fw, final_state_bw], dim=-1) 70 | return output_seq, final_state, seq_lengths 71 | 72 | 73 | def get_rnn(rnn_type): 74 | if rnn_type.lower() == "lstm": 75 | return nn.LSTM 76 | elif rnn_type.lower() == "gru": 77 | return nn.GRU 78 | elif rnn_type.lower() == "rnn": 79 | return nn.RNN 80 | else: 81 | raise NotImplementedError("invalid rnn type: {}".format(rnn_type)) 82 | 83 | 84 | def kl_normal2_normal2(mean1, log_var1, mean2, log_var2): 85 | return 0.5 * log_var2 - 0.5 * log_var1 + \ 86 | (torch.exp(log_var1) + (mean1 - mean2) ** 2) / \ 87 | (2 * torch.exp(log_var2) + 1e-10) - 0.5 88 | 89 | 90 | def compute_KL_div(mean_q, log_var_q, mean_prior, log_var_prior): 91 | kl_divergence = kl_normal2_normal2( 92 | mean_q, log_var_q, mean_prior, log_var_prior) 93 | return kl_divergence 94 | 95 | 96 | def compute_KL_div2(mean, log_var): 97 | return - 0.5 * (1 + log_var - mean.pow(2) - log_var.exp()) 98 | 99 | 100 | class char_rnn(nn.Module): 101 | def __init__(self, rnn_type, vocab_size, embed_dim, hidden_size): 102 | super(char_rnn, self).__init__() 103 | self.char_embed = nn.Embedding(vocab_size, embed_dim) 104 | self.hidden_size = hidden_size 105 | self.char_cell = get_rnn(rnn_type)( 106 | input_size=embed_dim, 107 | hidden_size=hidden_size, 108 | bidirectional=True, 109 | batch_first=True) 110 | 111 | def forward(self, chars, chars_mask, data_mask): 112 | """ 113 | chars: batch size x seq len x # chars 114 | chars_mask: batch size x seq len x # chars 115 | data_mask: batch size x seq len 116 | """ 117 | char_output = [] 118 | batch_size, seq_len, char_len = chars.size() 119 | for i in range(batch_size): 120 | char = chars[i, :, :] 121 | char_mask = chars_mask[i, :, :] 122 | # trim off extra padding 123 | word_length = int(data_mask[i, :].sum()) 124 | n_char = int(char_mask.sum(-1).max()) 125 | char = char[:word_length, :n_char] 126 | char_mask = char_mask[:word_length, :n_char] 127 | # char: word length x char len 128 | char_vec = self.char_embed(char.long()) 129 | _, final_state, _ = get_rnn_output( 130 | char_vec, char_mask, self.char_cell, bidir=True) 131 | # final_state: word length x hidden size 132 | padding = final_state.data.new( 133 | seq_len - word_length, 2 * self.hidden_size).zero_() 134 | if padding.dim(): 135 | final_output = torch.cat( 136 | [final_state, Variable(padding)], dim=0) 137 | else: 138 | final_output = final_state 139 | char_output.append(final_output) 140 | char_outputs = torch.stack(char_output, 0) 141 | # batch size x seq len x hidden size 142 | return char_outputs 143 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | UNK_WORD_IDX = 0 5 | UNK_WORD = "UUUNKKK" 6 | UNK_CHAR_IDX = 0 7 | UNK_CHAR = "UUUNKKK" 8 | 9 | 10 | def 
str2bool(v): 11 | return v.lower() in ('yes', 'true', 't', '1', 'y') 12 | 13 | 14 | def get_parser(): 15 | parser = argparse.ArgumentParser( 16 | description='Variational Sequential Labelers \ 17 | for Semi-supervised learning') 18 | parser.register('type', 'bool', str2bool) 19 | 20 | basic_group = parser.add_argument_group('basics') 21 | basic_group.add_argument('--debug', type="bool", default=False, 22 | help='whether to activate debug mode \ 23 | (default: False)') 24 | basic_group.add_argument('--model', type=str, default='g', 25 | choices=['g', 'flat', 'hier'], 26 | help='type of model (default: g)') 27 | basic_group.add_argument('--random_seed', type=int, default=0, 28 | help='Random seed (default: 0)') 29 | 30 | data = parser.add_argument_group('data') 31 | data.add_argument('--prefix', type=str, default=None, 32 | help='save file prefix (default: None)') 33 | data.add_argument('--data_file', type=str, default=None, 34 | help='path to training data file (default: None)') 35 | data.add_argument('--unlabel_file', type=str, default=None, 36 | help='path to unlabeled file (default: None)') 37 | data.add_argument('--vocab_file', type=str, default=None, 38 | help='path to vocab file (default: None)') 39 | data.add_argument('--tag_file', type=str, default=None, 40 | help='path to tag file (default: None)') 41 | data.add_argument('--embed_file', type=str, default=None, 42 | help='path to embedding file (default: None)') 43 | data.add_argument('--use_unlabel', type="bool", default=None, 44 | help='whether to use unlabeled data (default: None)') 45 | data.add_argument('--prior_file', type=str, default=None, 46 | help='path to saved prior file (default: None)') 47 | 48 | data.add_argument('--embed_type', type=str, default='twitter', 49 | choices=['glove', 'twitter', 'ud'], 50 | help='types of embedding file (default: twitter)') 51 | 52 | config = parser.add_argument_group('configs') 53 | config.add_argument('-edim', '--embed_dim', 54 | dest='edim', type=int, default=100, 55 | help='embedding dimension (default: 100)') 56 | config.add_argument('-rtype', '--rnn_type', 57 | dest='rtype', type=str, default='gru', 58 | choices=['gru', 'lstm', 'rnn'], 59 | help='types of optimizer: gru (default), lstm, rnn') 60 | config.add_argument('-tw', '--tie_weights', 61 | dest='tw', type='bool', default=True, 62 | help='whether to tie weights (default: True)') 63 | 64 | # Character level model detail 65 | config.add_argument('-cdim', '--char_embed_dim', 66 | dest='cdim', type=int, default=15, 67 | help='character embedding dimension (default: 15)') 68 | config.add_argument('-chsize', '--char_hidden_size', 69 | dest='chsize', type=int, default=15, 70 | help='character rnn hidden size (default: 15)') 71 | 72 | # Latent variable specs 73 | config.add_argument('-zsize', '--latent_z_size', 74 | dest='zsize', type=int, default=100, 75 | help='dimension of latent variable (default: 100)') 76 | config.add_argument('-ysize', '--latent_y_size', 77 | dest='ysize', type=int, default=25, 78 | help='dimension of latent variable (default: 25)') 79 | config.add_argument('-rsize', '--rnn_size', 80 | dest='rsize', type=int, default=100, 81 | help='dimension of recurrent nnet (default: 100)') 82 | config.add_argument('-mhsize', '--mlp_hidden_size', 83 | dest='mhsize', type=int, default=100, 84 | help='hidden dimension of feedforward nnet \ 85 | (default: 100)') 86 | config.add_argument('-mlayer', '--mlp_layer', 87 | dest='mlayer', type=int, default=2, 88 | help='number of layers of feedforward nnet \ 89 | (default: 2)') 90 | 
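# fixed log-variance used when sampling the reconstruction x from its predicted mean (see vsl_g.forward)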
config.add_argument('-xvar', '--latent_x_logvar', 91 | dest='xvar', type=float, default=1e-3, 92 | help='log variance of latent variable x \ 93 | (default: 1e-3)') 94 | 95 | # KL annealing 96 | config.add_argument('-klr', '--kl_anneal_rate', 97 | dest='klr', type=float, default=1e-3, 98 | help='annealing rate (default: 1e-3)') 99 | # Loss specs 100 | config.add_argument('-ur', '--unlabel_ratio', 101 | dest='ur', type=float, default=0.1, 102 | help='unlabeled loss ratio (default: 0.1)') 103 | config.add_argument('-ufl', '--update_freq_label', 104 | dest='ufl', type=int, default=1, 105 | help='frequency of updating prior for labeled data \ 106 | (default: 1)') 107 | config.add_argument('-ufu', '--update_freq_unlabel', 108 | dest='ufu', type=int, default=1, 109 | help='frequency of updating prior for unlabeled data \ 110 | (default: 1)') 111 | 112 | train = parser.add_argument_group('training') 113 | train.add_argument('--opt', type=str, default='adam', 114 | choices=['adam', 'sgd', 'rmsprop'], 115 | help='types of optimizer: adam (default), \ 116 | sgd, rmsprop') 117 | train.add_argument('--n_iter', type=int, default=30000, 118 | help='number of iterations (default: 30000)') 119 | train.add_argument('--batch_size', type=int, default=10, 120 | help='labeled data batch size (default: 10)') 121 | train.add_argument('--unlabel_batch_size', type=int, default=10, 122 | help='unlabeled data batch size (default: 10)') 123 | train.add_argument('--vocab_size', type=int, default=50000, 124 | help='maximum number of words in vocabulary \ 125 | (default: 50000)') 126 | config.add_argument('--char_vocab_size', type=int, default=300, 127 | help='character vocabulary size (default: 300)') 128 | train.add_argument('--train_emb', type="bool", default=False, 129 | help='whether to train word embedding (default: False)') 130 | train.add_argument('--save_prior', type="bool", default=False, 131 | help='whether to save trained prior (default: False)') 132 | train.add_argument('-lr', '--learning_rate', 133 | dest='lr', 134 | type=float, default=1e-3, 135 | help='learning rate (default: 1e-3)') 136 | train.add_argument('--l2', type=float, default=0, 137 | help='weight decay rate (default: 0)') 138 | train.add_argument('--grad_clip', type=float, default=10., 139 | help='gradient clipping (default: 10)') 140 | train.add_argument('--f1_score', type="bool", default=False, 141 | help='whether to report F1 score (default: False)') 142 | 143 | misc = parser.add_argument_group('misc') 144 | misc.add_argument('--print_every', type=int, default=10, 145 | help='print training details after \ 146 | this number of iterations (default: 10)') 147 | misc.add_argument('--eval_every', type=int, default=100, 148 | help='evaluate model after \ 149 | this number of iterations (default: 100)') 150 | misc.add_argument('--summarize', type="bool", default=False, 151 | help='whether to summarize training stats\ 152 | (default: False)') 153 | return parser 154 | -------------------------------------------------------------------------------- /vsl_g.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import config 3 | import train_helper 4 | import data_utils 5 | 6 | import numpy as np 7 | 8 | from models import vsl_g 9 | from tensorboardX import SummaryWriter 10 | 11 | best_dev_res = test_res = 0 12 | 13 | 14 | def run(e): 15 | global best_dev_res, test_res 16 | 17 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 18 | dp = data_utils.data_processor(experiment=e) 19 | data, W = 
dp.process() 20 | 21 | label_logvar_buffer = \ 22 | train_helper.prior_buffer(data.train[0], e.config.zsize, 23 | experiment=e, 24 | freq=e.config.ufl, 25 | name="label_logvar", 26 | init_path=e.config.prior_file) 27 | label_mean_buffer = \ 28 | train_helper.prior_buffer(data.train[0], e.config.zsize, 29 | experiment=e, 30 | freq=e.config.ufl, 31 | name="label_mean", 32 | init_path=e.config.prior_file) 33 | 34 | all_buffer = [label_logvar_buffer, label_mean_buffer] 35 | 36 | e.log.info("labeled buffer size: logvar: {}, mean: {}" 37 | .format(len(label_logvar_buffer), len(label_mean_buffer))) 38 | 39 | if e.config.use_unlabel: 40 | unlabel_logvar_buffer = \ 41 | train_helper.prior_buffer(data.unlabel[0], e.config.zsize, 42 | experiment=e, 43 | freq=e.config.ufu, 44 | name="unlabel_logvar", 45 | init_path=e.config.prior_file) 46 | unlabel_mean_buffer = \ 47 | train_helper.prior_buffer(data.unlabel[0], e.config.zsize, 48 | experiment=e, 49 | freq=e.config.ufu, 50 | name="unlabel_mean", 51 | init_path=e.config.prior_file) 52 | 53 | all_buffer += [unlabel_logvar_buffer, unlabel_mean_buffer] 54 | 55 | e.log.info("unlabeled buffer size: logvar: {}, mean: {}" 56 | .format(len(unlabel_logvar_buffer), 57 | len(unlabel_mean_buffer))) 58 | 59 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 60 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 61 | 62 | model = vsl_g( 63 | word_vocab_size=len(data.vocab), 64 | char_vocab_size=len(data.char_vocab), 65 | n_tags=len(data.tag_vocab), 66 | embed_dim=e.config.edim if W is None else W.shape[1], 67 | embed_init=W, 68 | experiment=e) 69 | 70 | e.log.info(model) 71 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 72 | 73 | if e.config.summarize: 74 | writer = SummaryWriter(e.experiment_dir) 75 | 76 | label_batch = data_utils.minibatcher( 77 | word_data=data.train[0], 78 | char_data=data.train[1], 79 | label=data.train[2], 80 | batch_size=e.config.batch_size, 81 | shuffle=True) 82 | 83 | if e.config.use_unlabel: 84 | unlabel_batch = data_utils.minibatcher( 85 | word_data=data.unlabel[0], 86 | char_data=data.unlabel[1], 87 | label=data.unlabel[0], 88 | batch_size=e.config.unlabel_batch_size, 89 | shuffle=True) 90 | 91 | evaluator = train_helper.evaluator(data.inv_tag_vocab, model, e) 92 | 93 | e.log.info("Training start ...") 94 | label_stats = train_helper.tracker( 95 | ["loss", "logloss", "kl_div", "sup_loss"]) 96 | unlabel_stats = train_helper.tracker( 97 | ["loss", "logloss", "kl_div"]) 98 | 99 | for it in range(e.config.n_iter): 100 | model.train() 101 | kl_temp = train_helper.get_kl_temp(e.config.klr, it, 1.0) 102 | 103 | try: 104 | l_data, l_mask, l_char, l_char_mask, l_label, l_ixs = \ 105 | next(label_batch) 106 | except StopIteration: 107 | pass 108 | 109 | lp_logvar = label_logvar_buffer[l_ixs] 110 | lp_mean = label_mean_buffer[l_ixs] 111 | 112 | l_loss, l_logloss, l_kld, sup_loss, lq_mean, lq_logvar, _ = \ 113 | model(l_data, l_mask, l_char, l_char_mask, 114 | l_label, lp_mean, lp_logvar, kl_temp) 115 | 116 | label_logvar_buffer.update_buffer(l_ixs, lq_logvar, l_mask.sum(-1)) 117 | label_mean_buffer.update_buffer(l_ixs, lq_mean, l_mask.sum(-1)) 118 | 119 | label_stats.update( 120 | {"loss": l_loss, "logloss": l_logloss, "kl_div": l_kld, 121 | "sup_loss": sup_loss}, l_mask.sum()) 122 | 123 | if not e.config.use_unlabel: 124 | model.optimize(l_loss) 125 | 126 | else: 127 | try: 128 | u_data, u_mask, u_char, u_char_mask, u_label, u_ixs = \ 129 | next(unlabel_batch) 130 | except StopIteration: 131 | pass 132 | 133 | up_logvar 
= unlabel_logvar_buffer[u_ixs] 134 | up_mean = unlabel_mean_buffer[u_ixs] 135 | 136 | u_loss, u_logloss, u_kld, _, uq_mean, uq_logvar, _ = \ 137 | model(u_data, u_mask, u_char, u_char_mask, 138 | None, up_mean, up_logvar, kl_temp) 139 | 140 | unlabel_logvar_buffer.update_buffer( 141 | u_ixs, uq_logvar, u_mask.sum(-1)) 142 | unlabel_mean_buffer.update_buffer( 143 | u_ixs, uq_mean, u_mask.sum(-1)) 144 | 145 | unlabel_stats.update( 146 | {"loss": u_loss, "logloss": u_logloss, "kl_div": u_kld}, 147 | u_mask.sum()) 148 | 149 | model.optimize(l_loss + e.config.ur * u_loss) 150 | 151 | if (it + 1) % e.config.print_every == 0: 152 | summary = label_stats.summarize( 153 | "it: {} (max: {}), kl_temp: {:.2f}, labeled".format( 154 | it + 1, len(label_batch), kl_temp)) 155 | if e.config.use_unlabel: 156 | summary += unlabel_stats.summarize(", unlabeled") 157 | e.log.info(summary) 158 | if e.config.summarize: 159 | writer.add_scalar( 160 | "label/kl_temp", kl_temp, it) 161 | for name, value in label_stats.stats.items(): 162 | writer.add_scalar( 163 | "label/" + name, value, it) 164 | if e.config.use_unlabel: 165 | for name, value in unlabel_stats.stats.items(): 166 | writer.add_scalar( 167 | "unlabel/" + name, value, it) 168 | label_stats.reset() 169 | unlabel_stats.reset() 170 | if (it + 1) % e.config.eval_every == 0: 171 | 172 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 173 | 174 | dev_perf, dev_res = evaluator.evaluate(data.dev) 175 | 176 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 177 | 178 | if e.config.summarize: 179 | for n, v in dev_perf.items(): 180 | writer.add_scalar( 181 | "dev/" + n, v, it) 182 | 183 | if best_dev_res < dev_res: 184 | best_dev_res = dev_res 185 | 186 | e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25) 187 | 188 | test_perf, test_res = evaluator.evaluate(data.test) 189 | 190 | e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25) 191 | 192 | model.save( 193 | dev_perf=dev_perf, 194 | test_perf=test_perf, 195 | iteration=it) 196 | 197 | if e.config.save_prior: 198 | for buf in all_buffer: 199 | buf.save() 200 | 201 | if e.config.summarize: 202 | writer.add_scalar( 203 | "dev/best_result", best_dev_res, it) 204 | for n, v in test_perf.items(): 205 | writer.add_scalar( 206 | "test/" + n, v, it) 207 | e.log.info("best dev result: {:.4f}, " 208 | "test result: {:.4f}, " 209 | .format(best_dev_res, test_res)) 210 | label_stats.reset() 211 | unlabel_stats.reset() 212 | 213 | 214 | if __name__ == '__main__': 215 | 216 | args = config.get_parser().parse_args() 217 | args.use_cuda = torch.cuda.is_available() 218 | 219 | def exit_handler(*args): 220 | print(args) 221 | print("best dev result: {:.4f}, " 222 | "test result: {:.4f}" 223 | .format(best_dev_res, test_res)) 224 | exit() 225 | 226 | train_helper.register_exit_handler(exit_handler) 227 | 228 | np.random.seed(args.random_seed) 229 | torch.manual_seed(args.random_seed) 230 | with train_helper.experiment(args, args.prefix) as e: 231 | 232 | e.log.info("*" * 25 + " ARGS " + "*" * 25) 233 | e.log.info(args) 234 | e.log.info("*" * 25 + " ARGS " + "*" * 25) 235 | 236 | run(e) 237 | -------------------------------------------------------------------------------- /vsl_gg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import config 3 | import train_helper 4 | import data_utils 5 | 6 | import numpy as np 7 | 8 | from models import vsl_gg 9 | from tensorboardX import SummaryWriter 10 | 11 | best_dev_res = test_res = 0 12 | 13 | 14 | def 
run(e): 15 | global best_dev_res, test_res 16 | 17 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 18 | dp = data_utils.data_processor(experiment=e) 19 | data, W = dp.process() 20 | 21 | label_logvar1_buffer = \ 22 | train_helper.prior_buffer(data.train[0], e.config.zsize, 23 | experiment=e, 24 | freq=e.config.ufl, 25 | name="label_logvar1", 26 | init_path=e.config.prior_file) 27 | label_mean1_buffer = \ 28 | train_helper.prior_buffer(data.train[0], e.config.zsize, 29 | experiment=e, 30 | freq=e.config.ufl, 31 | name="label_mean1", 32 | init_path=e.config.prior_file) 33 | 34 | label_logvar2_buffer = \ 35 | train_helper.prior_buffer(data.train[0], e.config.ysize, 36 | experiment=e, 37 | freq=e.config.ufl, 38 | name="label_logvar2", 39 | init_path=e.config.prior_file) 40 | label_mean2_buffer = \ 41 | train_helper.prior_buffer(data.train[0], e.config.ysize, 42 | experiment=e, 43 | freq=e.config.ufl, 44 | name="label_mean2", 45 | init_path=e.config.prior_file) 46 | 47 | all_buffer = [label_logvar1_buffer, label_mean1_buffer, 48 | label_logvar2_buffer, label_mean2_buffer] 49 | 50 | e.log.info("labeled buffer size: logvar1: {}, mean1: {}, " 51 | "logvar2: {}, mean2: {}" 52 | .format(len(label_logvar1_buffer), len(label_mean1_buffer), 53 | len(label_logvar2_buffer), len(label_mean2_buffer))) 54 | 55 | if e.config.use_unlabel: 56 | unlabel_logvar1_buffer = \ 57 | train_helper.prior_buffer(data.unlabel[0], e.config.zsize, 58 | experiment=e, 59 | freq=e.config.ufu, 60 | name="unlabel_logvar1", 61 | init_path=e.config.prior_file) 62 | unlabel_mean1_buffer = \ 63 | train_helper.prior_buffer(data.unlabel[0], e.config.zsize, 64 | experiment=e, 65 | freq=e.config.ufu, 66 | name="unlabel_mean1", 67 | init_path=e.config.prior_file) 68 | 69 | unlabel_logvar2_buffer = \ 70 | train_helper.prior_buffer(data.unlabel[0], e.config.ysize, 71 | experiment=e, 72 | freq=e.config.ufu, 73 | name="unlabel_logvar2", 74 | init_path=e.config.prior_file) 75 | unlabel_mean2_buffer = \ 76 | train_helper.prior_buffer(data.unlabel[0], e.config.ysize, 77 | experiment=e, 78 | freq=e.config.ufu, 79 | name="unlabel_mean2", 80 | init_path=e.config.prior_file) 81 | 82 | all_buffer += [unlabel_logvar1_buffer, unlabel_mean1_buffer, 83 | unlabel_logvar2_buffer, unlabel_mean2_buffer] 84 | 85 | e.log.info("unlabeled buffer size: logvar1: {}, mean1: {}, " 86 | "logvar2: {}, mean2: {}" 87 | .format(len(unlabel_logvar1_buffer), 88 | len(unlabel_mean1_buffer), 89 | len(unlabel_logvar2_buffer), 90 | len(unlabel_mean2_buffer))) 91 | 92 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 93 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 94 | 95 | model = vsl_gg( 96 | word_vocab_size=len(data.vocab), 97 | char_vocab_size=len(data.char_vocab), 98 | n_tags=len(data.tag_vocab), 99 | embed_dim=e.config.edim if W is None else W.shape[1], 100 | embed_init=W, 101 | experiment=e) 102 | 103 | e.log.info(model) 104 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 105 | 106 | if e.config.summarize: 107 | writer = SummaryWriter(e.experiment_dir) 108 | 109 | label_batch = data_utils.minibatcher( 110 | word_data=data.train[0], 111 | char_data=data.train[1], 112 | label=data.train[2], 113 | batch_size=e.config.batch_size, 114 | shuffle=True) 115 | 116 | if e.config.use_unlabel: 117 | unlabel_batch = data_utils.minibatcher( 118 | word_data=data.unlabel[0], 119 | char_data=data.unlabel[1], 120 | label=data.unlabel[0], 121 | batch_size=e.config.unlabel_batch_size, 122 | shuffle=True) 123 | 124 | evaluator = 
train_helper.evaluator(data.inv_tag_vocab, model, e) 125 | 126 | e.log.info("Training start ...") 127 | label_stats = train_helper.tracker( 128 | ["loss", "logloss", "kl_div", "sup_loss"]) 129 | unlabel_stats = train_helper.tracker( 130 | ["loss", "logloss", "kl_div"]) 131 | 132 | for it in range(e.config.n_iter): 133 | model.train() 134 | kl_temp = train_helper.get_kl_temp(e.config.klr, it, 1.0) 135 | 136 | try: 137 | l_data, l_mask, l_char, l_char_mask, l_label, l_ixs = \ 138 | next(label_batch) 139 | except StopIteration: 140 | pass 141 | 142 | lp_logvar1 = label_logvar1_buffer[l_ixs] 143 | lp_mean1 = label_mean1_buffer[l_ixs] 144 | lp_logvar2 = label_logvar2_buffer[l_ixs] 145 | lp_mean2 = label_mean2_buffer[l_ixs] 146 | 147 | l_loss, l_logloss, l_kld, sup_loss, \ 148 | lq_mean1, lq_logvar1, lq_mean2, lq_logvar2, _ = \ 149 | model(l_data, l_mask, l_char, l_char_mask, 150 | l_label, [lp_mean1, lp_mean2], [lp_logvar1, lp_logvar2], 151 | kl_temp) 152 | 153 | label_logvar1_buffer.update_buffer(l_ixs, lq_logvar1, l_mask.sum(-1)) 154 | label_mean1_buffer.update_buffer(l_ixs, lq_mean1, l_mask.sum(-1)) 155 | 156 | label_logvar2_buffer.update_buffer(l_ixs, lq_logvar2, l_mask.sum(-1)) 157 | label_mean2_buffer.update_buffer(l_ixs, lq_mean2, l_mask.sum(-1)) 158 | 159 | label_stats.update( 160 | {"loss": l_loss, "logloss": l_logloss, "kl_div": l_kld, 161 | "sup_loss": sup_loss}, l_mask.sum()) 162 | 163 | if not e.config.use_unlabel: 164 | model.optimize(l_loss) 165 | 166 | else: 167 | try: 168 | u_data, u_mask, u_char, u_char_mask, _, u_ixs = \ 169 | next(unlabel_batch) 170 | except StopIteration: 171 | pass 172 | 173 | up_logvar1 = unlabel_logvar1_buffer[u_ixs] 174 | up_mean1 = unlabel_mean1_buffer[u_ixs] 175 | 176 | up_logvar2 = unlabel_logvar2_buffer[u_ixs] 177 | up_mean2 = unlabel_mean2_buffer[u_ixs] 178 | 179 | u_loss, u_logloss, u_kld, _, \ 180 | uq_mean1, uq_logvar1, uq_mean2, uq_logvar2, _ = \ 181 | model(u_data, u_mask, u_char, u_char_mask, 182 | None, [up_mean1, up_mean2], [up_logvar1, up_logvar2], 183 | kl_temp) 184 | 185 | unlabel_logvar1_buffer.update_buffer( 186 | u_ixs, uq_logvar1, u_mask.sum(-1)) 187 | unlabel_mean1_buffer.update_buffer( 188 | u_ixs, uq_mean1, u_mask.sum(-1)) 189 | 190 | unlabel_logvar2_buffer.update_buffer( 191 | u_ixs, uq_logvar2, u_mask.sum(-1)) 192 | unlabel_mean2_buffer.update_buffer( 193 | u_ixs, uq_mean2, u_mask.sum(-1)) 194 | 195 | unlabel_stats.update( 196 | {"loss": u_loss, "logloss": u_logloss, "kl_div": u_kld}, 197 | u_mask.sum()) 198 | 199 | model.optimize(l_loss + e.config.ur * u_loss) 200 | 201 | if (it + 1) % e.config.print_every == 0: 202 | summary = label_stats.summarize( 203 | "it: {} (max: {}), kl_temp: {:.2f}, labeled".format( 204 | it + 1, len(label_batch), kl_temp)) 205 | if e.config.use_unlabel: 206 | summary += unlabel_stats.summarize(", unlabeled") 207 | e.log.info(summary) 208 | if e.config.summarize: 209 | writer.add_scalar( 210 | "label/kl_temp", kl_temp, it) 211 | for name, value in label_stats.stats.items(): 212 | writer.add_scalar( 213 | "label/" + name, value, it) 214 | if e.config.use_unlabel: 215 | for name, value in unlabel_stats.stats.items(): 216 | writer.add_scalar( 217 | "unlabel/" + name, value, it) 218 | label_stats.reset() 219 | unlabel_stats.reset() 220 | if (it + 1) % e.config.eval_every == 0: 221 | 222 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 223 | 224 | dev_perf, dev_res = evaluator.evaluate(data.dev) 225 | 226 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 227 | 228 | if e.config.summarize: 
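# write this evaluation step's dev metrics to TensorBoard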
229 | for n, v in dev_perf.items(): 230 | writer.add_scalar( 231 | "dev/" + n, v, it) 232 | 233 | if best_dev_res < dev_res: 234 | best_dev_res = dev_res 235 | 236 | e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25) 237 | 238 | test_perf, test_res = evaluator.evaluate(data.test) 239 | 240 | e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25) 241 | 242 | model.save( 243 | dev_perf=dev_perf, 244 | test_perf=test_perf, 245 | iteration=it) 246 | 247 | if e.config.save_prior: 248 | for buf in all_buffer: 249 | buf.save() 250 | 251 | if e.config.summarize: 252 | writer.add_scalar( 253 | "dev/best_result", best_dev_res, it) 254 | for n, v in test_perf.items(): 255 | writer.add_scalar( 256 | "test/" + n, v, it) 257 | e.log.info("best dev result: {:.4f}, " 258 | "test result: {:.4f}, " 259 | .format(best_dev_res, test_res)) 260 | label_stats.reset() 261 | unlabel_stats.reset() 262 | 263 | 264 | if __name__ == '__main__': 265 | 266 | args = config.get_parser().parse_args() 267 | args.use_cuda = torch.cuda.is_available() 268 | 269 | def exit_handler(*args): 270 | print(args) 271 | print("best dev result: {:.4f}, " 272 | "test result: {:.4f}" 273 | .format(best_dev_res, test_res)) 274 | exit() 275 | 276 | train_helper.register_exit_handler(exit_handler) 277 | 278 | np.random.seed(args.random_seed) 279 | torch.manual_seed(args.random_seed) 280 | with train_helper.experiment(args, args.prefix) as e: 281 | 282 | e.log.info("*" * 25 + " ARGS " + "*" * 25) 283 | e.log.info(args) 284 | e.log.info("*" * 25 + " ARGS " + "*" * 25) 285 | 286 | run(e) 287 | -------------------------------------------------------------------------------- /train_helper.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import argparse 4 | import os 5 | import pickle 6 | 7 | import numpy as np 8 | 9 | from config import get_parser 10 | from data_utils import minibatcher 11 | from decorators import auto_init_args 12 | 13 | 14 | def register_exit_handler(exit_handler): 15 | import atexit 16 | import signal 17 | 18 | atexit.register(exit_handler) 19 | signal.signal(signal.SIGTERM, exit_handler) 20 | signal.signal(signal.SIGINT, exit_handler) 21 | 22 | 23 | def get_kl_temp(kl_anneal_rate, curr_iteration, max_temp): 24 | temp = np.exp(kl_anneal_rate * curr_iteration) - 1. 25 | return float(np.minimum(temp, max_temp)) 26 | 27 | 28 | class tracker: 29 | @auto_init_args 30 | def __init__(self, names): 31 | assert len(names) > 0 32 | self.reset() 33 | 34 | def __getitem__(self, name): 35 | return self.values.get(name, 0) / self.counter if self.counter else 0 36 | 37 | def __len__(self): 38 | return len(self.names) 39 | 40 | def reset(self): 41 | self.values = dict({name: 0. 
for name in self.names}) 42 | self.counter = 0 43 | self.create_time = time.time() 44 | 45 | def update(self, named_values, count): 46 | """ 47 | named_values: dictionary with each item as name: value 48 | """ 49 | self.counter += count 50 | for name, value in named_values.items(): 51 | self.values[name] += value.data.cpu().numpy()[0] * count 52 | 53 | def summarize(self, output=""): 54 | if output: 55 | output += ", " 56 | for name in self.names: 57 | output += "{}: {:.3f}, ".format( 58 | name, self.values[name] / self.counter if self.counter else 0) 59 | output += "elapsed time: {:.1f}(s)".format( 60 | time.time() - self.create_time) 61 | return output 62 | 63 | @property 64 | def stats(self): 65 | return {n: v / self.counter if self.counter else 0 66 | for n, v in self.values.items()} 67 | 68 | 69 | class experiment: 70 | @auto_init_args 71 | def __init__(self, config, experiments_prefix, logfile_name="log"): 72 | """Create a new Experiment instance. 73 | 74 | Modified based on: https://github.com/ex4sperans/mag 75 | 76 | Args: 77 | logfile_name: str, naming for log file. This can be useful to 78 | separate logs for different runs on the same experiment 79 | experiments_prefix: str, a prefix to the path where 80 | experiment will be saved 81 | """ 82 | 83 | # get all defaults 84 | all_defaults = {} 85 | for key in vars(config): 86 | all_defaults[key] = get_parser().get_default(key) 87 | 88 | self.default_config = all_defaults 89 | 90 | if not config.debug: 91 | if os.path.isdir(self.experiment_dir): 92 | raise ValueError("log exists: {}".format(self.experiment_dir)) 93 | 94 | print(config) 95 | self._makedir() 96 | 97 | self._make_misc_dir() 98 | 99 | def _makedir(self): 100 | os.makedirs(self.experiment_dir, exist_ok=False) 101 | 102 | def _make_misc_dir(self): 103 | os.makedirs(self.config.prior_file, exist_ok=True) 104 | os.makedirs(self.config.vocab_file, exist_ok=True) 105 | 106 | @property 107 | def experiment_dir(self): 108 | if self.config.debug: 109 | return "./" 110 | else: 111 | # get namespace for each group of args 112 | arg_g = dict() 113 | for group in get_parser()._action_groups: 114 | group_d = {a.dest: self.default_config.get(a.dest, None) 115 | for a in group._group_actions} 116 | arg_g[group.title] = argparse.Namespace(**group_d) 117 | 118 | # skip default value 119 | identifier = "" 120 | for key, value in sorted(vars(arg_g["configs"]).items()): 121 | if getattr(self.config, key) != value: 122 | identifier += key + str(getattr(self.config, key)) 123 | return os.path.join(self.experiments_prefix, identifier) 124 | 125 | @property 126 | def log_file(self): 127 | return os.path.join(self.experiment_dir, self.logfile_name) 128 | 129 | def register_directory(self, dirname): 130 | directory = os.path.join(self.experiment_dir, dirname) 131 | os.makedirs(directory, exist_ok=True) 132 | setattr(self, dirname, directory) 133 | 134 | def _register_existing_directories(self): 135 | for item in os.listdir(self.experiment_dir): 136 | fullpath = os.path.join(self.experiment_dir, item) 137 | if os.path.isdir(fullpath): 138 | setattr(self, item, fullpath) 139 | 140 | def __enter__(self): 141 | 142 | if self.config.debug: 143 | logging.basicConfig( 144 | level=logging.DEBUG, 145 | format='%(asctime)s %(levelname)s: %(message)s', 146 | datefmt='%m-%d %H:%M') 147 | else: 148 | print("log saving to", self.log_file) 149 | logging.basicConfig( 150 | filename=self.log_file, 151 | filemode='w+', level=logging.INFO, 152 | format='%(asctime)s %(levelname)s: %(message)s', 153 | datefmt='%m-%d 
%H:%M') 154 | 155 | self.log = logging.getLogger() 156 | return self 157 | 158 | def __exit__(self, *args): 159 | logging.shutdown() 160 | 161 | 162 | class prior_buffer: 163 | def __init__(self, inputs, dim, freq, name, experiment, init_path=None): 164 | self.dim = dim 165 | self.freq = freq 166 | self.expe = experiment 167 | if init_path is not None: 168 | self.path = os.path.join(init_path, name + "_" + str(dim)) 169 | else: 170 | self.path = init_path 171 | if self.path is None or not os.path.isfile(self.path): 172 | self.buffer = np.asarray( 173 | [np.zeros((len(r), dim)).astype('float32') for r in inputs]) 174 | elif self.path is not None and os.path.isfile(self.path): 175 | self.buffer = self.load() 176 | else: 177 | raise ValueError( 178 | "invalid initial path for prior buffer: {}".format(init_path)) 179 | self.count = [0] * len(inputs) 180 | 181 | def __len__(self): 182 | return len(self.buffer) 183 | 184 | def update_buffer(self, ixs, post, seq_len): 185 | """ 186 | Args: 187 | ixs: list of index 188 | post: batch size x batch length x dim 189 | seq_len: batch size 190 | """ 191 | for p, i, l in zip(post.data.cpu().numpy(), ixs, seq_len): 192 | new_i = i % len(self) 193 | if self.count[new_i] % self.freq == 0: 194 | assert len(self.buffer[new_i]) == l 195 | self.buffer[new_i] = p[:int(l), :] 196 | self.count[new_i] += 1 197 | 198 | def __getitem__(self, ixs): 199 | get_buffer = self.buffer[ixs] 200 | max_len = np.max([len(b) for b in get_buffer]) 201 | batch_size = len(ixs) 202 | 203 | pad_buffer = np.zeros((batch_size, max_len, self.dim)) \ 204 | .astype("float32") 205 | for i, b in enumerate(get_buffer): 206 | pad_buffer[i, :len(b), :] = b 207 | return pad_buffer 208 | 209 | def save(self): 210 | pickle.dump(self.buffer, open(self.path, "wb+"), protocol=-1) 211 | self.expe.log.info("prior saved to: {}".format(self.path)) 212 | 213 | def load(self): 214 | with open(self.path, "rb+") as infile: 215 | priors = pickle.load(infile) 216 | self.expe.log.info("prior loaded from: {}".format(self.path)) 217 | return priors 218 | 219 | 220 | class accuracy_reporter: 221 | def __init__(self): 222 | self.right_count = 0 223 | self.instance_count = 0 224 | 225 | def update(self, pred, label, mask): 226 | self.right_count += ((pred == label) * mask).sum() 227 | self.instance_count += mask.sum() 228 | 229 | def report(self): 230 | acc = self.right_count / self.instance_count \ 231 | if self.instance_count else 0.0 232 | return {"acc": acc, "f1": 0., "prec": 0., "rec": 0.}, acc 233 | 234 | 235 | class f1_reporter: 236 | """ 237 | modified based on: https://github.com/kimiyoung/transfer/blob/master/ner_span.py 238 | """ 239 | def __init__(self, inv_tag_vocab): 240 | self.inv_tag_vocab = inv_tag_vocab 241 | self.instance_count = 0 242 | self.right_count = 0 243 | self.tp = 0 244 | self.fp = 0 245 | self.fn = 0 246 | 247 | @staticmethod 248 | def extract_ent(y, m, inv_tag_vocab): 249 | def label_decode(label): 250 | if label == 'O': 251 | return 'O', 'O' 252 | return tuple(label.split('-')) 253 | 254 | def new_match(y_prev, y_next): 255 | l_prev, l_next = inv_tag_vocab[y_prev], inv_tag_vocab[y_next] 256 | c1_prev, c2_prev = label_decode(l_prev) 257 | c1_next, c2_next = label_decode(l_next) 258 | if c2_prev != c2_next: 259 | return False 260 | if c1_next not in ['I', 'E']: 261 | return False 262 | return True 263 | 264 | ret = set() 265 | i = 0 266 | while i < y.shape[0]: 267 | if m[i] == 0: 268 | i += 1 269 | continue 270 | c1, c2 = label_decode(inv_tag_vocab[y[i]]) 271 | if c1 in ['O', 'I', 
'E']: 272 | i += 1 273 | continue 274 | if c1 == 'S': 275 | ret.add((i, i + 1, c2)) 276 | i += 1 277 | continue 278 | j = i + 1 279 | if j == y.shape[0]: 280 | break 281 | end = False 282 | while m[j] != 0 and not end and new_match(y[i], y[j]): 283 | ic1, ic2 = label_decode(inv_tag_vocab[y[j]]) 284 | if ic1 == 'E': 285 | end = True 286 | break 287 | j += 1 288 | if not end: 289 | i += 1 290 | continue 291 | ret.add((i, j, c2)) 292 | i = j 293 | return ret 294 | 295 | def update(self, pred, label, mask): 296 | pred, label, mask = pred.flatten(), label.flatten(), mask.flatten() 297 | self.right_count += ((label == pred) * mask).sum() 298 | self.instance_count += mask.sum() 299 | 300 | p_ent = f1_reporter.extract_ent(pred, mask, self.inv_tag_vocab) 301 | y_ent = f1_reporter.extract_ent(label, mask, self.inv_tag_vocab) 302 | 303 | for ent in p_ent: 304 | if ent in y_ent: 305 | self.tp += 1 306 | else: 307 | self.fp += 1 308 | for ent in y_ent: 309 | if ent not in p_ent: 310 | self.fn += 1 311 | 312 | def report(self): 313 | acc = self.right_count / self.instance_count \ 314 | if self.instance_count else 0.0 315 | prec = 1.0 * self.tp / (self.tp + self.fp) \ 316 | if self.tp + self.fp > 0 else 0.0 317 | recall = 1.0 * self.tp / (self.tp + self.fn) \ 318 | if self.tp + self.fn > 0 else 0.0 319 | f1 = 2.0 * prec * recall / (prec + recall) \ 320 | if prec + recall > 0 else 0.0 321 | return {"acc": acc, "f1": f1, "prec": prec, "rec": recall}, f1 322 | 323 | 324 | class evaluator: 325 | @auto_init_args 326 | def __init__(self, inv_tag_vocab, model, experiment): 327 | self.expe = experiment 328 | 329 | def evaluate(self, data): 330 | self.model.eval() 331 | eval_stats = tracker(["log_loss"]) 332 | if self.expe.config.f1_score: 333 | reporter = f1_reporter(self.inv_tag_vocab) 334 | else: 335 | reporter = accuracy_reporter() 336 | for data, mask, char, char_mask, label, _ in \ 337 | minibatcher( 338 | word_data=data[0], 339 | char_data=data[1], 340 | label=data[2], 341 | batch_size=100, 342 | shuffle=False): 343 | outputs = self.model(data, mask, char, char_mask, 344 | label, None, None, 1.0) 345 | pred, log_loss = outputs[-1], outputs[1] 346 | reporter.update(pred, label, mask) 347 | 348 | eval_stats.update( 349 | {"log_loss": log_loss}, mask.sum()) 350 | perf, res = reporter.report() 351 | summary = eval_stats.summarize( 352 | ", ".join([x[0] + ": {:.5f}".format(x[1]) 353 | for x in sorted(perf.items())])) 354 | self.expe.log.info(summary) 355 | return perf, res 356 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import model_utils 4 | 5 | import torch.nn as nn 6 | import stochastic_layers as sl 7 | import torch.nn.functional as F 8 | 9 | from decorators import auto_init_pytorch 10 | from torch.autograd import Variable 11 | 12 | 13 | class base(nn.Module): 14 | def __init__(self, word_vocab_size, char_vocab_size, embed_dim, 15 | embed_init, n_tags, experiment): 16 | super(base, self).__init__() 17 | self.expe = experiment 18 | self.use_cuda = self.expe.config.use_cuda 19 | self.char_encoder = model_utils.char_rnn( 20 | rnn_type=self.expe.config.rtype, 21 | vocab_size=char_vocab_size, 22 | embed_dim=self.expe.config.cdim, 23 | hidden_size=self.expe.config.chsize) 24 | 25 | self.word_embed = nn.Embedding(word_vocab_size, embed_dim) 26 | 27 | if embed_init is not None: 28 | self.word_embed.weight.data.copy_(torch.from_numpy(embed_init)) 29 | 
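# copy the pretrained vectors in place; this embedding matrix is also shared with the x2token output layer when tie_weights is enabled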
self.expe.log.info("Initialized with pretrained word embedding") 30 | if not self.expe.config.train_emb: 31 | self.word_embed.weight.requires_grad = False 32 | self.expe.log.info("Word Embedding not trainable") 33 | 34 | self.word_encoder = model_utils.get_rnn(self.expe.config.rtype)( 35 | input_size=(embed_dim + 2 * self.expe.config.chsize), 36 | hidden_size=self.expe.config.rsize, 37 | bidirectional=True, 38 | batch_first=True) 39 | 40 | self.x2token = nn.Linear( 41 | embed_dim, word_vocab_size, bias=False) 42 | 43 | if self.expe.config.tw: 44 | self.x2token.weight = self.word_embed.weight 45 | 46 | def get_input_vecs(self, data, mask, char, char_mask): 47 | word_emb = self.word_embed(data.long()) 48 | 49 | char_input = self.char_encoder(char, char_mask, mask) 50 | data_emb = torch.cat([char_input, word_emb], dim=-1) 51 | 52 | return data_emb 53 | 54 | def to_var(self, inputs): 55 | if self.use_cuda: 56 | if isinstance(inputs, Variable): 57 | inputs = inputs.cuda() 58 | inputs.volatile = self.volatile 59 | return inputs 60 | else: 61 | if not torch.is_tensor(inputs): 62 | inputs = torch.from_numpy(inputs) 63 | return Variable(inputs.cuda(), volatile=self.volatile) 64 | else: 65 | if isinstance(inputs, Variable): 66 | inputs = inputs.cpu() 67 | inputs.volatile = self.volatile 68 | return inputs 69 | else: 70 | if not torch.is_tensor(inputs): 71 | inputs = torch.from_numpy(inputs) 72 | return Variable(inputs, volatile=self.volatile) 73 | 74 | def to_vars(self, *inputs): 75 | return [self.to_var(inputs_) if inputs_ is not None and inputs_.size 76 | else None for inputs_ in inputs] 77 | 78 | def optimize(self, loss): 79 | self.opt.zero_grad() 80 | loss.backward() 81 | if self.expe.config.grad_clip is not None: 82 | torch.nn.utils.clip_grad_norm( 83 | self.parameters(), self.expe.config.grad_clip) 84 | self.opt.step() 85 | 86 | def init_optimizer(self, opt_type, learning_rate, weight_decay): 87 | if opt_type.lower() == "adam": 88 | optimizer = torch.optim.Adam 89 | elif opt_type.lower() == "rmsprop": 90 | optimizer = torch.optim.RMSprop 91 | elif opt_type.lower() == "sgd": 92 | optimizer = torch.optim.SGD 93 | else: 94 | raise NotImplementedError("invalid optimizer: {}".format(opt_type)) 95 | 96 | opt = optimizer( 97 | params=filter( 98 | lambda p: p.requires_grad, self.parameters() 99 | ), 100 | lr=learning_rate, 101 | weight_decay=weight_decay) 102 | return opt 103 | 104 | def save(self, dev_perf, test_perf, iteration): 105 | save_path = os.path.join(self.expe.experiment_dir, "model.ckpt") 106 | checkpoint = { 107 | "dev_perf": dev_perf, 108 | "test_perf": test_perf, 109 | "iteration": iteration, 110 | "state_dict": self.state_dict(), 111 | "config": self.expe.config 112 | } 113 | torch.save(checkpoint, save_path) 114 | self.expe.log.info("model saved to {}".format(save_path)) 115 | 116 | def load(self, checkpointed_state_dict=None): 117 | if checkpointed_state_dict is None: 118 | save_path = os.path.join(self.expe.experiment_dir, "model.ckpt") 119 | checkpoint = torch.load(save_path, 120 | map_location=lambda storage, 121 | loc: storage) 122 | self.load_state_dict(checkpoint['state_dict']) 123 | self.expe.log.info("model loaded from {}".format(save_path)) 124 | else: 125 | self.load_state_dict(checkpointed_state_dict) 126 | self.expe.log.info("model loaded!") 127 | 128 | @property 129 | def volatile(self): 130 | return not self.training 131 | 132 | @property 133 | def sampling(self): 134 | return self.training 135 | 136 | 137 | class vsl_g(base): 138 | @auto_init_pytorch 139 | def 
__init__(self, word_vocab_size, char_vocab_size, embed_dim, 140 | embed_init, n_tags, experiment): 141 | super(vsl_g, self).__init__( 142 | word_vocab_size, char_vocab_size, embed_dim, embed_init, 143 | n_tags, experiment) 144 | assert self.expe.config.model.lower() == "g" 145 | self.to_latent_variable = sl.gaussian_layer( 146 | input_size=2 * self.expe.config.rsize, 147 | latent_z_size=self.expe.config.zsize) 148 | 149 | self.classifier = nn.Linear(self.expe.config.zsize, n_tags) 150 | 151 | self.z2x = model_utils.get_mlp_layer( 152 | input_size=self.expe.config.zsize, 153 | hidden_size=self.expe.config.mhsize, 154 | output_size=embed_dim, 155 | n_layer=self.expe.config.mlayer) 156 | 157 | def forward( 158 | self, data, mask, char, char_mask, label, 159 | prior_mean, prior_logvar, kl_temp): 160 | data, mask, char, char_mask, label, prior_mean, prior_logvar = \ 161 | self.to_vars(data, mask, char, char_mask, label, 162 | prior_mean, prior_logvar) 163 | 164 | batch_size, batch_len = data.size() 165 | input_vecs = self.get_input_vecs(data, mask, char, char_mask) 166 | hidden_vecs, _, _ = model_utils.get_rnn_output( 167 | input_vecs, mask, self.word_encoder) 168 | 169 | z, mean_qs, logvar_qs = \ 170 | self.to_latent_variable(hidden_vecs, mask, self.sampling) 171 | 172 | mean_x = self.z2x(z) 173 | 174 | x = model_utils.gaussian( 175 | mean_x, Variable(mean_x.data.new(1).fill_(self.expe.config.xvar))) 176 | 177 | x_pred = self.x2token(x) 178 | 179 | if label is None: 180 | sup_loss = class_logits = None 181 | else: 182 | class_logits = self.classifier(z) 183 | sup_loss = F.cross_entropy( 184 | class_logits.view(batch_size * batch_len, -1), 185 | label.view(-1).long(), 186 | reduce=False).view_as(data) * mask 187 | sup_loss = sup_loss.sum(-1) / mask.sum(-1) 188 | 189 | log_loss = F.cross_entropy( 190 | x_pred.view(batch_size * batch_len, -1), 191 | data.view(-1).long(), 192 | reduce=False).view_as(data) * mask 193 | log_loss = log_loss.sum(-1) / mask.sum(-1) 194 | 195 | if prior_mean is not None and prior_logvar is not None: 196 | kl_div = model_utils.compute_KL_div( 197 | mean_qs, logvar_qs, prior_mean, prior_logvar) 198 | 199 | kl_div = (kl_div * mask.unsqueeze(-1)).sum(-1) 200 | kl_div = kl_div.sum(-1) / mask.sum(-1) 201 | 202 | loss = log_loss + kl_temp * kl_div 203 | else: 204 | kl_div = None 205 | loss = log_loss 206 | 207 | if sup_loss is not None: 208 | loss = loss + sup_loss 209 | 210 | return loss.mean(), log_loss.mean(), \ 211 | kl_div.mean() if kl_div is not None else None, \ 212 | sup_loss.mean() if sup_loss is not None else None, \ 213 | mean_qs, logvar_qs, \ 214 | class_logits.data.cpu().numpy().argmax(-1) \ 215 | if class_logits is not None else None 216 | 217 | 218 | class vsl_gg(base): 219 | @auto_init_pytorch 220 | def __init__(self, word_vocab_size, char_vocab_size, embed_dim, 221 | embed_init, n_tags, experiment): 222 | super(vsl_gg, self).__init__( 223 | word_vocab_size, char_vocab_size, embed_dim, embed_init, 224 | n_tags, experiment) 225 | if self.expe.config.model.lower() == "flat": 226 | self.to_latent_variable = sl.gaussian_flat_layer( 227 | input_size=2 * self.expe.config.rsize, 228 | latent_z_size=self.expe.config.zsize, 229 | latent_y_size=self.expe.config.ysize) 230 | yzsize = self.expe.config.zsize + self.expe.config.ysize 231 | elif self.expe.config.model.lower() == "hier": 232 | self.to_latent_variable = sl.gaussian_hier_layer( 233 | input_size=2 * self.expe.config.rsize, 234 | latent_z_size=self.expe.config.zsize, 235 | latent_y_size=self.expe.config.ysize) 236 
| yzsize = self.expe.config.zsize 237 | else: 238 | raise ValueError( 239 | "invalid model type: {}".format(self.expe.config.model)) 240 | 241 | self.classifier = nn.Linear(self.expe.config.ysize, n_tags) 242 | 243 | self.yz2x = model_utils.get_mlp_layer( 244 | input_size=yzsize, 245 | hidden_size=self.expe.config.mhsize, 246 | output_size=embed_dim, 247 | n_layer=self.expe.config.mlayer) 248 | 249 | def forward( 250 | self, data, mask, char, char_mask, label, 251 | prior_mean, prior_logvar, kl_temp): 252 | if prior_mean is not None: 253 | prior_mean1, prior_mean2 = prior_mean 254 | prior_logvar1, prior_logvar2 = prior_logvar 255 | else: 256 | prior_mean1 = prior_mean2 = prior_logvar1 = prior_logvar2 = None 257 | 258 | data, mask, char, char_mask, label, prior_mean1, \ 259 | prior_mean2, prior_logvar1, prior_logvar2 = \ 260 | self.to_vars(data, mask, char, char_mask, label, 261 | prior_mean1, prior_mean2, 262 | prior_logvar1, prior_logvar2) 263 | 264 | batch_size, batch_len = data.size() 265 | input_vecs = self.get_input_vecs(data, mask, char, char_mask) 266 | hidden_vecs, _, _ = model_utils.get_rnn_output( 267 | input_vecs, mask, self.word_encoder) 268 | 269 | z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs = \ 270 | self.to_latent_variable(hidden_vecs, mask, self.sampling) 271 | 272 | if self.expe.config.model.lower() == "flat": 273 | yz = torch.cat([z, y], dim=-1) 274 | elif self.expe.config.model.lower() == "hier": 275 | yz = z 276 | 277 | mean_x = self.yz2x(yz) 278 | 279 | x = model_utils.gaussian( 280 | mean_x, Variable(mean_x.data.new(1).fill_(self.expe.config.xvar))) 281 | 282 | x_pred = self.x2token(x) 283 | 284 | if label is None: 285 | sup_loss = class_logits = None 286 | else: 287 | class_logits = self.classifier(y) 288 | sup_loss = F.cross_entropy( 289 | class_logits.view(batch_size * batch_len, -1), 290 | label.view(-1).long(), 291 | reduce=False).view_as(data) * mask 292 | sup_loss = sup_loss.sum(-1) / mask.sum(-1) 293 | 294 | log_loss = F.cross_entropy( 295 | x_pred.view(batch_size * batch_len, -1), 296 | data.view(-1).long(), 297 | reduce=False).view_as(data) * mask 298 | log_loss = log_loss.sum(-1) / mask.sum(-1) 299 | 300 | if prior_mean is not None: 301 | kl_div1 = model_utils.compute_KL_div( 302 | mean_qs, logvar_qs, prior_mean1, prior_logvar1) 303 | kl_div2 = model_utils.compute_KL_div( 304 | mean2_qs, logvar2_qs, prior_mean2, prior_logvar2) 305 | 306 | kl_div = (kl_div1 * mask.unsqueeze(-1)).sum(-1) + \ 307 | (kl_div2 * mask.unsqueeze(-1)).sum(-1) 308 | kl_div = kl_div.sum(-1) / mask.sum(-1) 309 | 310 | loss = log_loss + kl_temp * kl_div 311 | else: 312 | kl_div = None 313 | loss = log_loss 314 | 315 | if sup_loss is not None: 316 | loss = loss + sup_loss 317 | 318 | return loss.mean(), log_loss.mean(), \ 319 | kl_div.mean() if kl_div is not None else None, \ 320 | sup_loss.mean() if sup_loss is not None else None, \ 321 | mean_qs, logvar_qs, mean2_qs, logvar2_qs, \ 322 | class_logits.data.cpu().numpy().argmax(-1) \ 323 | if class_logits is not None else None 324 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | 6 | from collections import Counter 7 | 8 | from decorators import auto_init_args, lazy_execute 9 | from config import UNK_WORD_IDX, UNK_WORD, UNK_CHAR_IDX, \ 10 | UNK_CHAR 11 | 12 | 13 | class data_holder: 14 | @auto_init_args 15 | def __init__(self, train, dev, test, 
unlabel, 16 | tag_vocab, vocab, char_vocab): 17 | self.inv_vocab = {i: w for w, i in vocab.items()} 18 | self.inv_tag_vocab = {i: w for w, i in tag_vocab.items()} 19 | 20 | 21 | class data_processor: 22 | def __init__(self, experiment): 23 | self.expe = experiment 24 | 25 | def process(self): 26 | fn = "vocab_" + str(self.expe.config.vocab_size) 27 | vocab_file = os.path.join(self.expe.config.vocab_file, fn) 28 | 29 | self.expe.log.info("loading data from {} ...".format( 30 | self.expe.config.data_file)) 31 | with open(self.expe.config.data_file, "rb+") as infile: 32 | train_data, dev_data, test_data = pickle.load(infile) 33 | 34 | train_v_data = train_data[0] 35 | unlabeled_data = None 36 | 37 | if self.expe.config.use_unlabel: 38 | unlabeled_data = self._load_sent(self.expe.config.unlabel_file) 39 | train_v_data = unlabeled_data + train_data[0] 40 | 41 | W, vocab, char_vocab = \ 42 | self._build_vocab_from_embedding( 43 | train_v_data, dev_data[0] + test_data[0], 44 | self.expe.config.embed_file, 45 | self.expe.config.vocab_size, self.expe.config.char_vocab_size, 46 | file_name=vocab_file) 47 | 48 | tag_vocab = self._load_tag(self.expe.config.tag_file) 49 | 50 | self.expe.log.info("tag vocab size: {}".format(len(tag_vocab))) 51 | 52 | train_data = self._label_to_idx( 53 | train_data[0], train_data[1], vocab, char_vocab, tag_vocab) 54 | dev_data = self._label_to_idx( 55 | dev_data[0], dev_data[1], vocab, char_vocab, tag_vocab) 56 | test_data = self._label_to_idx( 57 | test_data[0], test_data[1], vocab, char_vocab, tag_vocab) 58 | 59 | def cal_stats(data): 60 | unk_count = 0 61 | total_count = 0 62 | leng = [] 63 | for sent in data: 64 | leng.append(len(sent)) 65 | for w in sent: 66 | if w == UNK_WORD_IDX: 67 | unk_count += 1 68 | total_count += 1 69 | return (unk_count, total_count, unk_count / total_count), \ 70 | (len(leng), max(leng), min(leng), sum(leng) / len(leng)) 71 | 72 | train_unk_stats, train_len_stats = cal_stats(train_data[0]) 73 | self.expe.log.info("#train data: {}, max len: {}, " 74 | "min len: {}, avg len: {:.2f}" 75 | .format(*train_len_stats)) 76 | 77 | self.expe.log.info("#unk in train sentences: {}" 78 | .format(train_unk_stats)) 79 | 80 | dev_unk_stats, dev_len_stats = cal_stats(dev_data[0]) 81 | self.expe.log.info("#dev data: {}, max len: {}, " 82 | "min len: {}, avg len: {:.2f}" 83 | .format(*dev_len_stats)) 84 | 85 | self.expe.log.info("#unk in dev sentences: {}" 86 | .format(dev_unk_stats)) 87 | 88 | test_unk_stats, test_len_stats = cal_stats(test_data[0]) 89 | self.expe.log.info("#test data: {}, max len: {}, " 90 | "min len: {}, avg len: {:.2f}" 91 | .format(*test_len_stats)) 92 | 93 | self.expe.log.info("#unk in test sentences: {}" 94 | .format(test_unk_stats)) 95 | 96 | if self.expe.config.use_unlabel: 97 | unlabeled_data = self._unlabel_to_idx( 98 | unlabeled_data, vocab, char_vocab) 99 | un_unk_stats, un_len_stats = cal_stats(unlabeled_data[0]) 100 | self.expe.log.info("#unlabeled data: {}, max len: {}, " 101 | "min len: {}, avg len: {:.2f}" 102 | .format(*un_len_stats)) 103 | 104 | self.expe.log.info("#unk in unlabeled sentences: {}" 105 | .format(un_unk_stats)) 106 | 107 | data = data_holder( 108 | train=train_data, 109 | dev=dev_data, 110 | test=test_data, 111 | unlabel=unlabeled_data, 112 | tag_vocab=tag_vocab, 113 | vocab=vocab, 114 | char_vocab=char_vocab) 115 | 116 | return data, W 117 | 118 | def _load_tag(self, path): 119 | self.expe.log.info("loading tags from " + path) 120 | tag = {} 121 | with open(path, 'r') as f: 122 | for (n, i) in 
enumerate(f): 123 | tag[i.strip()] = n 124 | return tag 125 | 126 | def _load_sent(self, path): 127 | self.expe.log.info("loading data from " + path) 128 | sents = [] 129 | with open(path, "r+", encoding='utf-8') as df: 130 | for line in df: 131 | if line.strip(): 132 | words = line.strip("\n").split(" ") 133 | sents.append(words) 134 | return sents 135 | 136 | def _label_to_idx(self, sentences, tags, vocab, char_vocab, tag_vocab): 137 | sentence_holder = [] 138 | sent_char_holder = [] 139 | tag_holder = [] 140 | for sentence, tag in zip(sentences, tags): 141 | chars = [] 142 | words = [] 143 | for w in sentence: 144 | words.append(vocab.get(w, 0)) 145 | chars.append([char_vocab.get(c, 0) for c in w]) 146 | sentence_holder.append(words) 147 | sent_char_holder.append(chars) 148 | tag_holder.append([tag_vocab[t] for t in tag]) 149 | self.expe.log.info("#sent: {}".format(len(sentence_holder))) 150 | self.expe.log.info("#word: {}".format(len(sum(sentence_holder, [])))) 151 | self.expe.log.info("#tag: {}".format(len(sum(tag_holder, [])))) 152 | return np.asarray(sentence_holder), np.asarray(sent_char_holder), \ 153 | np.asarray(tag_holder) 154 | 155 | def _unlabel_to_idx(self, sentences, vocab, char_vocab): 156 | sentence_holder = [] 157 | sent_char_holder = [] 158 | for sentence in sentences: 159 | chars = [] 160 | words = [] 161 | for w in sentence: 162 | words.append(vocab.get(w, 0)) 163 | chars.append([char_vocab.get(c, 0) for c in w]) 164 | sentence_holder.append(words) 165 | sent_char_holder.append(chars) 166 | self.expe.log.info("#sent: {}".format(len(sentence_holder))) 167 | return np.asarray(sentence_holder), np.asarray(sent_char_holder) 168 | 169 | def _load_twitter_embedding(self, path): 170 | with open(path, 'r', encoding='utf8') as fp: 171 | # word_vectors: word --> vector 172 | word_vectors = {} 173 | for line in fp: 174 | line = line.strip("\n").split("\t") 175 | word_vectors[line[0]] = np.array( 176 | list(map(float, line[1].split(" "))), dtype='float32') 177 | vocab_embed = word_vectors.keys() 178 | embed_dim = word_vectors[next(iter(vocab_embed))].shape[0] 179 | 180 | return word_vectors, vocab_embed, embed_dim 181 | 182 | def _load_glove_embedding(self, path): 183 | with open(path, 'r', encoding='utf8') as fp: 184 | # word_vectors: word --> vector 185 | word_vectors = {} 186 | for line in fp: 187 | line = line.strip("\n").split(" ") 188 | word_vectors[line[0]] = np.array( 189 | list(map(float, line[1:])), dtype='float32') 190 | vocab_embed = word_vectors.keys() 191 | embed_dim = word_vectors[next(iter(vocab_embed))].shape[0] 192 | 193 | return word_vectors, vocab_embed, embed_dim 194 | 195 | def _load_ud_embedding(self, path): 196 | from gensim.models.keyedvectors import KeyedVectors 197 | word_vectors = KeyedVectors.load_word2vec_format(path, binary=True) 198 | vocab_embed = word_vectors.vocab.keys() 199 | embed_dim = word_vectors[next(iter(vocab_embed))].shape[0] 200 | return word_vectors, vocab_embed, embed_dim 201 | 202 | def _build_vocab_from_data(self, train_sents, devtest_sents): 203 | self.expe.log.info("vocab file not exist, start building") 204 | train_char_vocab = Counter() 205 | train_vocab = Counter() 206 | for sent in train_sents: 207 | for w in sent: 208 | train_vocab[w] += 1 209 | for c in w: 210 | train_char_vocab[c] += 1 211 | devtest_vocab = Counter() 212 | for sent in devtest_sents: 213 | for w in sent: 214 | devtest_vocab[w] += 1 215 | 216 | return train_char_vocab, train_vocab, devtest_vocab 217 | 218 | @lazy_execute("_load_from_pickle") 219 | def 
_build_vocab_from_embedding( 220 | self, train_sents, devtest_sents, embed_file, 221 | vocab_size, char_vocab_size, file_name): 222 | self.expe.log.info("loading embedding file from {}".format(embed_file)) 223 | if self.expe.config.embed_type.lower() == "glove": 224 | word_vectors, vocab_embed, embed_dim = \ 225 | self._load_glove_embedding(embed_file) 226 | elif self.expe.config.embed_type.lower() == "twitter": 227 | word_vectors, vocab_embed, embed_dim = \ 228 | self._load_twitter_embedding(embed_file) 229 | else: 230 | word_vectors, vocab_embed, embed_dim = \ 231 | self._load_ud_embedding(embed_file) 232 | 233 | train_char_vocab, train_vocab, devtest_vocab = \ 234 | self._build_vocab_from_data(train_sents, devtest_sents) 235 | 236 | word_vocab = train_vocab + devtest_vocab 237 | 238 | char_ls = train_char_vocab.most_common(char_vocab_size) 239 | self.expe.log.info('#Chars: {}'.format(len(char_ls))) 240 | for key in char_ls[:5]: 241 | self.expe.log.info(key) 242 | self.expe.log.info('...') 243 | for key in char_ls[-5:]: 244 | self.expe.log.info(key) 245 | char_vocab = {c[0]: index + 1 for (index, c) in enumerate(char_ls)} 246 | 247 | char_vocab[UNK_CHAR] = UNK_CHAR_IDX 248 | 249 | self.expe.log.info("char vocab size: {}".format(len(char_vocab))) 250 | 251 | vocab = {UNK_WORD: UNK_WORD_IDX} 252 | W = [np.random.uniform(-0.1, 0.1, size=(1, embed_dim))] 253 | n = 0 254 | for w, c in sorted(word_vocab.items(), key=lambda x: -x[1]): 255 | if w in vocab_embed: 256 | W.append(word_vectors[w][None, :]) 257 | vocab[w] = n + 1 258 | n += 1 259 | elif w.lower() in vocab_embed: 260 | W.append(word_vectors[w.lower()][None, :]) 261 | vocab[w] = n + 1 262 | n += 1 263 | W = np.concatenate(W, axis=0).astype('float32') 264 | 265 | self.expe.log.info( 266 | "{}/{} words are initialized with loaded embeddings." 
267 | .format(n, len(vocab))) 268 | return W, vocab, char_vocab 269 | 270 | def _load_from_pickle(self, file_name): 271 | self.expe.log.info("loading from {}".format(file_name)) 272 | with open(file_name, "rb") as fp: 273 | data = pickle.load(fp) 274 | return data 275 | 276 | 277 | class minibatcher: 278 | @auto_init_args 279 | def __init__(self, word_data, char_data, label, batch_size, shuffle): 280 | self._reset() 281 | 282 | def __len__(self): 283 | return len(self.idx_pool) 284 | 285 | def _reset(self): 286 | self.pointer = 0 287 | idx_list = np.arange(len(self.word_data)) 288 | if self.shuffle: 289 | np.random.shuffle(idx_list) 290 | self.idx_pool = [idx_list[i: i + self.batch_size] 291 | for i in range(0, len(self.word_data), 292 | self.batch_size)] 293 | 294 | def _pad(self, word_data, char_data, labels): 295 | max_word_len = max([len(sent) for sent in word_data]) 296 | max_char_len = max([len(char) for sent in char_data 297 | for char in sent]) 298 | 299 | input_data = \ 300 | np.zeros((len(word_data), max_word_len)).astype("float32") 301 | input_mask = \ 302 | np.zeros((len(word_data), max_word_len)).astype("float32") 303 | input_char = \ 304 | np.zeros( 305 | (len(word_data), max_word_len, max_char_len)).astype("float32") 306 | input_char_mask = \ 307 | np.zeros( 308 | (len(word_data), max_word_len, max_char_len)).astype("float32") 309 | input_label = \ 310 | np.zeros((len(word_data), max_word_len)).astype("float32") 311 | 312 | for i, (sent, chars, label) in enumerate( 313 | zip(word_data, char_data, labels)): 314 | input_data[i, :len(sent)] = \ 315 | np.asarray(list(sent)).astype("float32") 316 | input_label[i, :len(label)] = \ 317 | np.asarray(list(label)).astype("float32") 318 | input_mask[i, :len(sent)] = 1. 319 | 320 | for k, char in enumerate(chars): 321 | input_char[i, k, :len(char)] = \ 322 | np.asarray(char).astype("float32") 323 | input_char_mask[i, k, :len(char)] = 1 324 | 325 | return [input_data, input_mask, input_char, 326 | input_char_mask, input_label] 327 | 328 | def __iter__(self): 329 | return self 330 | 331 | def __next__(self): 332 | if self.pointer == len(self.idx_pool): 333 | self._reset() 334 | raise StopIteration() 335 | 336 | idx = self.idx_pool[self.pointer] 337 | sents, chars, label = \ 338 | self.word_data[idx], self.char_data[idx], self.label[idx] 339 | 340 | self.pointer += 1 341 | return self._pad(sents, chars, label) + [idx] 342 | --------------------------------------------------------------------------------
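
A worked example of what `minibatcher` produces may be useful. The sketch below is not part of the repository: the word, character, and tag indices are invented, and `dtype=object` is used only so NumPy keeps the ragged per-sentence lists in the same form that `data_processor._label_to_idx` returns. Each batch is padded to its own longest sentence and longest word, with masks marking the real positions.

```
import numpy as np
from data_utils import minibatcher

# Toy ragged data: three "sentences" of word indices, per-word character
# indices, and per-token tag indices (all index values are made up).
word_data = np.asarray([[3, 7, 2], [5, 1], [9, 4, 6, 8]], dtype=object)
char_data = np.asarray([[[1, 2], [3], [4, 5, 6]],
                        [[2, 2], [7]],
                        [[1], [3, 4], [5], [6, 7, 8]]], dtype=object)
label = np.asarray([[0, 1, 0], [2, 1], [0, 0, 1, 2]], dtype=object)

for words, mask, chars, char_mask, tags, idx in minibatcher(
        word_data=word_data, char_data=char_data, label=label,
        batch_size=2, shuffle=False):
    # First batch: words/mask/tags have shape (2, 3), chars/char_mask (2, 3, 3),
    # idx is [0 1]; positions beyond each true length stay zero and are masked.
    print(words.shape, chars.shape, tags.shape, idx)
```

This is the same batching path that `evaluator.evaluate` uses at test time, where it iterates over `minibatcher(..., batch_size=100, shuffle=False)` and feeds the padded arrays plus masks straight into the model's `forward`.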