├── image └── framework.png ├── old ├── constant.py ├── pipeline.py ├── util.py ├── eval_single.py ├── example_config.json ├── README.md ├── conlleval.py ├── train_single.py ├── data.py ├── task.py └── _model.py ├── constant.py ├── util.py ├── eval_single.py ├── eval_multi.py ├── example_config.json ├── README.md ├── data.py ├── conlleval.py ├── train_single.py ├── train_crosslingual.py ├── model.py └── train_multi.py /image/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limteng-rpi/mlmt/HEAD/image/framework.png -------------------------------------------------------------------------------- /old/constant.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | UNKNOWN_TOKEN = '$UNK$' 4 | UNKNOWN_TOKEN_INDEX = 1 5 | PADDING = '$PAD$' 6 | PADDING_INDEX = 0 7 | EMBED_START_IDX = 2 8 | CHAR_EMBED_START_IDX = 2 9 | 10 | EVAL_BATCH_SIZE = 200 11 | 12 | LOGGING_LEVEL = logging.INFO -------------------------------------------------------------------------------- /constant.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # UNKNOWN_TOKEN = '$UNK$' 4 | # UNKNOWN_TOKEN_INDEX = 1 5 | # PADDING = '$PAD$' 6 | # PADDING_INDEX = 0 7 | # EMBED_START_IDX = 2 8 | # CHAR_EMBED_START_IDX = 2 9 | 10 | PAD = '<$PAD$>' 11 | UNK = '<$UNK$>' 12 | PAD_INDEX = 0 13 | UNK_INDEX = 1 14 | TOKEN_PADS = [ 15 | (PAD, PAD_INDEX), 16 | (UNK, UNK_INDEX) 17 | ] 18 | CHAR_PADS = [ 19 | (PAD, PAD_INDEX), 20 | (UNK, UNK_INDEX) 21 | ] 22 | 23 | EVAL_BATCH_SIZE = 200 24 | 25 | LOGGING_LEVEL = logging.INFO 26 | 27 | PENN_TREEBANK_BRACKETS = { 28 | '-LRB-': '(', 29 | '-RRB-': ')', 30 | '-LSB-': '[', 31 | '-RSB-': ']', 32 | '-LCB-': '{', 33 | '-RCB-': '}', 34 | '``': '"', 35 | '\'\'': '"', 36 | '/.': '.', 37 | } -------------------------------------------------------------------------------- /old/pipeline.py: 
-------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | from task import build_tasks_from_file, MultiTask 4 | from util import get_logger 5 | from argparse import ArgumentParser 6 | 7 | logger = get_logger(__name__) 8 | 9 | arg_parser = ArgumentParser() 10 | 11 | # arg_parser.add_argument('-d', '--device', 12 | # type=int, default=0, help='GPU index') 13 | # arg_parser.add_argument('-t', '--thread', 14 | # type=int, default=5, help='Thread number') 15 | arg_parser.add_argument('-c', '--config', help='Configuration file') 16 | 17 | args = arg_parser.parse_args() 18 | 19 | # torch.cuda.set_device(args.device) 20 | # torch.set_num_threads(args.thread) 21 | config_file = args.config 22 | 23 | tasks, conf, _ = build_tasks_from_file(config_file, options=None) 24 | multitask = MultiTask(tasks, eval_freq=conf.training.eval_freq) 25 | 26 | try: 27 | for step in range(1, conf.training.max_step + 1): 28 | multitask.step() 29 | except Exception: 30 | traceback.print_exc() 31 | -------------------------------------------------------------------------------- /old/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import constant as C 4 | import conlleval 5 | 6 | 7 | def get_logger(name, level=C.LOGGING_LEVEL, log_file=None): 8 | """Get a logger by name. 9 | 10 | :param name: Logger name (usu. __name__). 11 | :param level: Logging level (default=logging.INFO). 12 | """ 13 | logger = logging.getLogger(name) 14 | logger.addHandler(logging.StreamHandler()) 15 | if log_file: 16 | logger.addHandler(logging.FileHandler(log_file, encoding='utf-8')) 17 | logger.setLevel(level) 18 | return logger 19 | 20 | 21 | def evaluate(results, idx_token, idx_label, writer=None): 22 | """Evaluate prediction results. 23 | 24 | :param results: A List of which each item is a tuple 25 | (predictions, gold labels, sequence lengths, tokens) of a batch. 
26 | :param idx_token: Index to token dictionary. 27 | :param idx_label: Index to label dictionary. 28 | :param writer: An object (file object) with a write() function. Extra output. 29 | :return: F-score, precision, and recall. 30 | """ 31 | # b: batch, s: sequence 32 | outputs = [] 33 | for preds_b, golds_b, len_b, tokens_b in results: 34 | for preds_s, golds_s, len_s, tokens_s in zip(preds_b, golds_b, len_b, tokens_b): 35 | l = int(len_s.data[0]) 36 | preds_s = preds_s.data.tolist()[:l] 37 | golds_s = golds_s.data.tolist()[:l] 38 | tokens_s = tokens_s.data.tolist()[:l] 39 | for p, g, t in zip(preds_s, golds_s, tokens_s): 40 | token = idx_token.get(t, C.UNKNOWN_TOKEN) 41 | outputs.append('{} {} {}'.format( 42 | token, idx_label[g], idx_label[p])) 43 | outputs.append('') 44 | counts = conlleval.evaluate(outputs) 45 | overall, by_type = conlleval.metrics(counts) 46 | conlleval.report(counts) 47 | if writer: 48 | conlleval.report(counts, out=writer) 49 | writer.flush() 50 | return overall.fscore, overall.prec, overall.rec 51 | 52 | 53 | class Config(dict): 54 | 55 | def __init__(self, *args, **kwargs): 56 | super(Config, self).__init__(*args, **kwargs) 57 | __getattr__ = dict.__getitem__ 58 | 59 | for arg in args: 60 | if isinstance(arg, dict): 61 | for k, v in arg.items(): 62 | if isinstance(v, dict): 63 | v = Config(v) 64 | if isinstance(v, list): 65 | v = [Config(x) if isinstance(x, dict) else x for x in v] 66 | self[k] = v 67 | if kwargs: 68 | for k, v in kwargs.items(): 69 | self[k] = v 70 | 71 | def __setattr__(self, key, value): 72 | self.__setitem__(key, value) 73 | 74 | def __setitem__(self, key, value): 75 | super(Config, self).__setitem__(key, value) 76 | self.__dict__.update({key: value}) 77 | 78 | def __delattr__(self, item): 79 | self.__delitem__(item) 80 | 81 | def __delitem__(self, key): 82 | super(Config, self).__delitem__(key) 83 | del self.__dict__[key] 84 | 85 | def set_dict(self, dict_obj): 86 | for k, v in dict_obj.items(): 87 | if 
isinstance(v, dict): 88 | v = Config(v) 89 | self[k] = v 90 | 91 | def update(self, dict_obj): 92 | for k, v in dict_obj.items(): 93 | if isinstance(v, dict): 94 | v = Config(v) 95 | if isinstance(v, list): 96 | v = [Config(x) if isinstance(x, dict) else x for x in v] 97 | self[k] = v 98 | 99 | def clone(self): 100 | return Config(dict(self)) 101 | 102 | @staticmethod 103 | def read(path): 104 | """Read configuration from JSON format file. 105 | 106 | :param path: Path to the configuration file. 107 | :return: Config object. 108 | """ 109 | # logger.info('loading configuration from {}'.format(path)) 110 | json_obj = json.load(open(path, 'r', encoding='utf-8')) 111 | return Config(json_obj) 112 | 113 | def update_value(self, keys, value): 114 | keys = keys.split('.') 115 | assert len(keys) > 0 116 | 117 | tgt = self 118 | for k in keys[:-1]: 119 | try: 120 | tgt = tgt[int(k)] 121 | except Exception: 122 | tgt = tgt[k] 123 | tgt[keys[-1]] = value -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import constant as C 4 | import conlleval 5 | 6 | 7 | def get_logger(name, level=C.LOGGING_LEVEL, log_file=None): 8 | """Get a logger by name. 9 | 10 | :param name: Logger name (usu. __name__). 11 | :param level: Logging level (default=logging.INFO). 12 | """ 13 | logger = logging.getLogger(name) 14 | logger.addHandler(logging.StreamHandler()) 15 | if log_file: 16 | logger.addHandler(logging.FileHandler(log_file, encoding='utf-8')) 17 | logger.setLevel(level) 18 | return logger 19 | 20 | 21 | def evaluate(results, idx_token, idx_label, writer=None): 22 | """Evaluate prediction results. 23 | 24 | :param results: A List of which each item is a tuple 25 | (predictions, gold labels, sequence lengths, tokens) of a batch. 26 | :param idx_token: Index to token dictionary. 27 | :param idx_label: Index to label dictionary. 
28 | :param writer: An object (file object) with a write() function. Extra output. 29 | :return: F-score, precision, and recall. 30 | """ 31 | # b: batch, s: sequence 32 | outputs = [] 33 | for preds_b, golds_b, len_b, tokens_b in results: 34 | for preds_s, golds_s, len_s, tokens_s in zip(preds_b, golds_b, len_b, tokens_b): 35 | l = int(len_s.item()) 36 | preds_s = preds_s.data.tolist()[:l] 37 | golds_s = golds_s.data.tolist()[:l] 38 | tokens_s = tokens_s.data.tolist()[:l] 39 | for p, g, t in zip(preds_s, golds_s, tokens_s): 40 | token = idx_token.get(t, C.UNK_INDEX) 41 | outputs.append('{} {} {}'.format( 42 | token, idx_label.get(g, 0), idx_label.get(p, 0))) 43 | outputs.append('') 44 | counts = conlleval.evaluate(outputs) 45 | overall, by_type = conlleval.metrics(counts) 46 | conlleval.report(counts) 47 | if writer: 48 | conlleval.report(counts, out=writer) 49 | writer.flush() 50 | return overall.fscore, overall.prec, overall.rec 51 | 52 | 53 | class Config(dict): 54 | 55 | def __init__(self, *args, **kwargs): 56 | super(Config, self).__init__(*args, **kwargs) 57 | __getattr__ = dict.__getitem__ 58 | 59 | for arg in args: 60 | if isinstance(arg, dict): 61 | for k, v in arg.items(): 62 | if isinstance(v, dict): 63 | v = Config(v) 64 | if isinstance(v, list): 65 | v = [Config(x) if isinstance(x, dict) else x for x in v] 66 | self[k] = v 67 | if kwargs: 68 | for k, v in kwargs.items(): 69 | self[k] = v 70 | 71 | def __setattr__(self, key, value): 72 | self.__setitem__(key, value) 73 | 74 | def __setitem__(self, key, value): 75 | super(Config, self).__setitem__(key, value) 76 | self.__dict__.update({key: value}) 77 | 78 | def __delattr__(self, item): 79 | self.__delitem__(item) 80 | 81 | def __delitem__(self, key): 82 | super(Config, self).__delitem__(key) 83 | del self.__dict__[key] 84 | 85 | def set_dict(self, dict_obj): 86 | for k, v in dict_obj.items(): 87 | if isinstance(v, dict): 88 | v = Config(v) 89 | self[k] = v 90 | 91 | def update(self, dict_obj): 92 | for 
k, v in dict_obj.items(): 93 | if isinstance(v, dict): 94 | v = Config(v) 95 | if isinstance(v, list): 96 | v = [Config(x) if isinstance(x, dict) else x for x in v] 97 | self[k] = v 98 | 99 | def clone(self): 100 | return Config(dict(self)) 101 | 102 | @staticmethod 103 | def read(path): 104 | """Read configuration from JSON format file. 105 | 106 | :param path: Path to the configuration file. 107 | :return: Config object. 108 | """ 109 | # logger.info('loading configuration from {}'.format(path)) 110 | json_obj = json.load(open(path, 'r', encoding='utf-8')) 111 | return Config(json_obj) 112 | 113 | def update_value(self, keys, value): 114 | keys = keys.split('.') 115 | assert len(keys) > 0 116 | 117 | tgt = self 118 | for k in keys[:-1]: 119 | try: 120 | tgt = tgt[int(k)] 121 | except Exception: 122 | tgt = tgt[k] 123 | tgt[keys[-1]] = value 124 | -------------------------------------------------------------------------------- /eval_single.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import traceback 5 | 6 | from torch.utils.data import DataLoader 7 | 8 | import constant as C 9 | 10 | import torch 11 | 12 | from argparse import ArgumentParser 13 | from model import Linear, LSTM, CRF, CharCNN, Highway, LstmCrf 14 | from util import evaluate 15 | from data import ConllParser, SeqLabelDataset, SeqLabelProcessor 16 | 17 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 18 | 19 | logging.basicConfig(level=logging.DEBUG) 20 | logger = logging.getLogger() 21 | 22 | argparser = ArgumentParser() 23 | 24 | argparser.add_argument('--model', help='Path to the model file') 25 | argparser.add_argument('--file', help='Path to the file to evaluate') 26 | argparser.add_argument('--log', help='Path to the log dir') 27 | argparser.add_argument('--gpu', action='store_true') 28 | argparser.add_argument('--device', default=0, type=int) 29 | 30 | args = argparser.parse_args() 31 | 
32 | use_gpu = args.gpu and torch.cuda.is_available() 33 | if use_gpu: 34 | torch.cuda.set_device(args.device) 35 | 36 | # Parameters 37 | model_file = args.model 38 | data_file = args.file 39 | 40 | log_writer = None 41 | if args.log: 42 | log_file = os.path.join(args.log, 'log.{}.txt'.format(timestamp)) 43 | log_writer = open(log_file, 'a', encoding='utf-8') 44 | logger.addHandler(logging.FileHandler(log_file, encoding='utf-8')) 45 | 46 | # Load saved model 47 | logger.info('Loading saved model from {}'.format(model_file)) 48 | state = torch.load(model_file) 49 | token_vocab = state['vocab']['token'] 50 | label_vocab = state['vocab']['label'] 51 | char_vocab = state['vocab']['char'] 52 | train_args = state['args'] 53 | charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])] 54 | for f in train_args['charcnn_filters'].split(';')] 55 | 56 | # Resume model 57 | logger.info('Resuming the model') 58 | word_embed = torch.nn.Embedding(train_args['word_embed_size'], 59 | train_args['word_embed_dim'], 60 | sparse=True, 61 | padding_idx=C.PAD_INDEX) 62 | char_embed = CharCNN(len(char_vocab), 63 | train_args['char_embed_dim'], 64 | filters=charcnn_filters) 65 | char_hw = Highway(char_embed.output_size, 66 | layer_num=train_args['charhw_layer'], 67 | activation=train_args['charhw_func']) 68 | feat_dim = word_embed.embedding_dim + char_embed.output_size 69 | lstm = LSTM(feat_dim, 70 | train_args['lstm_hidden_size'], 71 | batch_first=True, 72 | bidirectional=True, 73 | forget_bias=train_args['lstm_forget_bias']) 74 | crf = CRF(label_size=len(label_vocab) + 2) 75 | linear = Linear(in_features=lstm.output_size, 76 | out_features=len(label_vocab)) 77 | lstm_crf = LstmCrf( 78 | token_vocab, label_vocab, char_vocab, 79 | word_embedding=word_embed, 80 | char_embedding=char_embed, 81 | crf=crf, 82 | lstm=lstm, 83 | univ_fc_layer=linear, 84 | embed_dropout_prob=train_args['feat_dropout'], 85 | lstm_dropout_prob=train_args['lstm_dropout'], 86 | char_highway=char_hw if 
train_args['use_highway'] else None 87 | ) 88 | 89 | word_embed.load_state_dict(state['model']['word_embed']) 90 | char_embed.load_state_dict(state['model']['char_embed']) 91 | char_hw.load_state_dict(state['model']['char_hw']) 92 | lstm.load_state_dict(state['model']['lstm']) 93 | crf.load_state_dict(state['model']['crf']) 94 | linear.load_state_dict(state['model']['linear']) 95 | lstm_crf.load_state_dict(state['model']['lstm_crf']) 96 | 97 | if use_gpu: 98 | lstm_crf.cuda() 99 | 100 | # Load dataset 101 | logger.info('Loading data') 102 | parser = ConllParser() 103 | test_set = SeqLabelDataset(data_file, parser=parser) 104 | test_set.numberize(token_vocab, label_vocab, char_vocab) 105 | idx_token = {v: k for k, v in token_vocab.items()} 106 | idx_label = {v: k for k, v in label_vocab.items()} 107 | processor = SeqLabelProcessor(gpu=use_gpu) 108 | 109 | try: 110 | results = [] 111 | dataset_loss = [] 112 | for batch in DataLoader( 113 | test_set, 114 | batch_size=50, 115 | shuffle=False, 116 | collate_fn=processor.process 117 | ): 118 | tokens, labels, chars, seq_lens, char_lens = batch 119 | pred, loss = lstm_crf.predict( 120 | tokens, labels, seq_lens, chars, char_lens) 121 | results.append((pred, labels, seq_lens, tokens)) 122 | dataset_loss.append(loss.data[0]) 123 | 124 | dataset_loss = sum(dataset_loss) / len(dataset_loss) 125 | fscore, prec, rec = evaluate(results, idx_token, idx_label, 126 | writer=log_writer) 127 | if args.log: 128 | logger.info('Log file: {}'.format(log_file)) 129 | log_writer.close() 130 | except KeyboardInterrupt: 131 | traceback.print_exc() 132 | if log_writer: 133 | log_writer.close() 134 | -------------------------------------------------------------------------------- /eval_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import traceback 5 | 6 | import torch 7 | import constant as C 8 | 9 | from util import evaluate 10 | from argparse 
import ArgumentParser 11 | from torch.utils.data import DataLoader 12 | from model import Linear, LSTM, CRF, CharCNN, Highway, LstmCrf 13 | from data import ConllParser, SeqLabelDataset, SeqLabelProcessor 14 | 15 | 16 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 17 | 18 | logging.basicConfig(level=logging.DEBUG) 19 | logger = logging.getLogger() 20 | 21 | argparser = ArgumentParser() 22 | 23 | argparser.add_argument('--model', help='Path to the model file') 24 | argparser.add_argument('--file', help='Path to the file to evaluate') 25 | argparser.add_argument('--log', help='Path to the log dir') 26 | argparser.add_argument('--gpu', action='store_true') 27 | argparser.add_argument('--device', default=0, type=int) 28 | 29 | args = argparser.parse_args() 30 | 31 | use_gpu = args.gpu and torch.cuda.is_available() 32 | if use_gpu: 33 | torch.cuda.set_device(args.device) 34 | 35 | # Parameters 36 | model_file = args.model 37 | data_file = args.file 38 | 39 | log_writer = None 40 | if args.log: 41 | log_file = os.path.join(args.log, 'log.{}.txt'.format(timestamp)) 42 | log_writer = open(log_file, 'a', encoding='utf-8') 43 | logger.addHandler(logging.FileHandler(log_file, encoding='utf-8')) 44 | 45 | # Load saved model 46 | logger.info('Loading saved model from {}'.format(model_file)) 47 | state = torch.load(model_file) 48 | token_vocab = state['vocab']['token'] 49 | label_vocab = state['vocab']['label'] 50 | char_vocab = state['vocab']['char'] 51 | train_args = state['args'] 52 | charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])] 53 | for f in train_args['charcnn_filters'].split(';')] 54 | 55 | # Resume model 56 | logger.info('Resuming the model') 57 | word_embed = torch.nn.Embedding(train_args['word_embed_size'], 58 | train_args['word_embed_dim'], 59 | sparse=True, 60 | padding_idx=C.PAD_INDEX) 61 | char_embed = CharCNN(len(char_vocab), 62 | train_args['char_embed_dim'], 63 | filters=charcnn_filters) 64 | char_hw = 
Highway(char_embed.output_size, 65 | layer_num=train_args['charhw_layer'], 66 | activation=train_args['charhw_func']) 67 | feat_dim = word_embed.embedding_dim + char_embed.output_size 68 | lstm = LSTM(feat_dim, 69 | train_args['lstm_hidden_size'], 70 | batch_first=True, 71 | bidirectional=True, 72 | forget_bias=train_args['lstm_forget_bias']) 73 | crf = CRF(label_size=len(label_vocab) + 2) 74 | univ_linear = Linear(in_features=lstm.output_size, 75 | out_features=len(label_vocab)) 76 | spec_linear = Linear(in_features=lstm.output_size, 77 | out_features=len(label_vocab)) 78 | lstm_crf = LstmCrf( 79 | token_vocab, label_vocab, char_vocab, 80 | word_embedding=word_embed, 81 | char_embedding=char_embed, 82 | crf=crf, 83 | lstm=lstm, 84 | univ_fc_layer=univ_linear, 85 | spec_fc_layer=spec_linear, 86 | embed_dropout_prob=train_args['feat_dropout'], 87 | lstm_dropout_prob=train_args['lstm_dropout'], 88 | char_highway=char_hw if train_args['use_highway'] else None 89 | ) 90 | 91 | word_embed.load_state_dict(state['model']['word_embed']) 92 | char_embed.load_state_dict(state['model']['char_embed']) 93 | char_hw.load_state_dict(state['model']['char_hw']) 94 | lstm.load_state_dict(state['model']['lstm']) 95 | crf.load_state_dict(state['model']['crf']) 96 | univ_linear.load_state_dict(state['model']['univ_linear']) 97 | spec_linear.load_state_dict(state['model']['spec_linear']) 98 | lstm_crf.load_state_dict(state['model']['lstm_crf']) 99 | 100 | if use_gpu: 101 | lstm_crf.cuda() 102 | 103 | # Load dataset 104 | logger.info('Loading data') 105 | parser = ConllParser() 106 | test_set = SeqLabelDataset(data_file, parser=parser) 107 | test_set.numberize(token_vocab, label_vocab, char_vocab) 108 | idx_token = {v: k for k, v in token_vocab.items()} 109 | idx_label = {v: k for k, v in label_vocab.items()} 110 | processor = SeqLabelProcessor(gpu=use_gpu) 111 | 112 | try: 113 | results = [] 114 | dataset_loss = [] 115 | for batch in DataLoader( 116 | test_set, 117 | batch_size=50, 118 
| shuffle=False, 119 | collate_fn=processor.process 120 | ): 121 | tokens, labels, chars, seq_lens, char_lens = batch 122 | pred, loss = lstm_crf.predict( 123 | tokens, labels, seq_lens, chars, char_lens) 124 | results.append((pred, labels, seq_lens, tokens)) 125 | dataset_loss.append(loss.data.item()) 126 | 127 | dataset_loss = sum(dataset_loss) / len(dataset_loss) 128 | fscore, prec, rec = evaluate(results, idx_token, idx_label, 129 | writer=log_writer) 130 | if args.log: 131 | logger.info('Log file: {}'.format(log_file)) 132 | log_writer.close() 133 | except KeyboardInterrupt: 134 | traceback.print_exc() 135 | if log_writer: 136 | log_writer.close() 137 | -------------------------------------------------------------------------------- /old/eval_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import traceback 4 | import constant as C 5 | 6 | import torch 7 | 8 | from model import Linear, LSTM, CRF, CharCNN, Highway, LstmCrf, Embedding 9 | from argparse import ArgumentParser 10 | from util import get_logger, evaluate, Config 11 | from data import ( 12 | SequenceDataset, ConllParser, 13 | compute_metadata, count2vocab, numberize_datasets 14 | ) 15 | 16 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.gmtime()) 17 | 18 | argparser = ArgumentParser() 19 | 20 | argparser.add_argument('--model', help='Path to the model file') 21 | argparser.add_argument('--file', help='Path to the file to evaluate') 22 | argparser.add_argument('--log', help='Path to the log dir') 23 | argparser.add_argument('--gpu', default=1, type=int, help='Use GPU') 24 | argparser.add_argument('--gpu_idx', default=0, type=int) 25 | 26 | args = argparser.parse_args() 27 | 28 | # Parameters 29 | model_file = args.model 30 | data_file = args.file 31 | assert model_file, 'Model file is required' 32 | assert data_file, 'Data file is required' 33 | 34 | use_gpu = (args.gpu == 1) 35 | 36 | 37 | log_writer = None 38 | if 
args.log: 39 | log_file = os.path.join(args.log, 'log.{}.txt'.format(timestamp)) 40 | log_writer = open(log_file, 'a', encoding='utf-8') 41 | logger = get_logger(__name__, log_file=log_file) 42 | else: 43 | logger = get_logger(__name__) 44 | 45 | # Load saved model 46 | logger.info('Loading saved model from {}'.format(model_file)) 47 | state = torch.load(model_file) 48 | token_vocab = state['vocab']['token'] 49 | label_vocab = state['vocab']['label'] 50 | char_vocab = state['vocab']['char'] 51 | train_args = state['args'] 52 | charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])] 53 | for f in train_args['charcnn_filters'].split(';')] 54 | 55 | # Resume model 56 | logger.info('Resuming the model') 57 | word_embed = Embedding(Config({ 58 | 'num_embeddings': len(token_vocab), 59 | 'embedding_dim': train_args['word_embed_dim'], 60 | 'padding': C.EMBED_START_IDX, 61 | 'padding_idx': 0, 62 | 'sparse': True, 63 | 'trainable': True, 64 | 'stats': train_args['embed_skip_first'], 65 | 'vocab': token_vocab, 66 | 'ignore_case': train_args['word_ignore_case'] 67 | })) 68 | char_cnn = CharCNN(Config({ 69 | 'vocab_size': len(char_vocab), 70 | 'padding': C.CHAR_EMBED_START_IDX, 71 | 'dimension': train_args['char_embed_dim'], 72 | 'filters': charcnn_filters 73 | })) 74 | char_highway = Highway(Config({ 75 | 'num_layers': 2, 76 | 'size': char_cnn.output_size, 77 | 'activation': 'selu' 78 | })) 79 | lstm = LSTM(Config({ 80 | 'input_size': word_embed.output_size + char_cnn.output_size, 81 | 'hidden_size': train_args['lstm_hidden_size'], 82 | 'forget_bias': 1.0, 83 | 'batch_first': True, 84 | 'bidirectional': True 85 | })) 86 | crf = CRF(Config({ 87 | 'label_vocab': label_vocab 88 | })) 89 | output_linear = Linear(Config({ 90 | 'in_features': lstm.output_size, 91 | 'out_features': len(label_vocab) 92 | })) 93 | word_embed.load_state_dict(state['model']['word_embed']) 94 | char_cnn.load_state_dict(state['model']['char_cnn']) 95 | 
char_highway.load_state_dict(state['model']['char_highway']) 96 | lstm.load_state_dict(state['model']['lstm']) 97 | crf.load_state_dict(state['model']['crf']) 98 | output_linear.load_state_dict(state['model']['output_linear']) 99 | lstm_crf = LstmCrf( 100 | token_vocab=token_vocab, 101 | label_vocab=label_vocab, 102 | char_vocab=char_vocab, 103 | word_embedding=word_embed, 104 | char_embedding=char_cnn, 105 | crf=crf, 106 | lstm=lstm, 107 | univ_fc_layer=output_linear, 108 | embed_dropout_prob=train_args['embed_dropout'], 109 | lstm_dropout_prob=train_args['lstm_dropout'], 110 | linear_dropout_prob=train_args['linear_dropout'], 111 | char_highway=char_highway 112 | ) 113 | lstm_crf.load_state_dict(state['model']['lstm_crf']) 114 | 115 | if use_gpu: 116 | torch.cuda.set_device(args.gpu_idx) 117 | lstm_crf.cuda() 118 | else: 119 | lstm_crf.cpu() 120 | 121 | # Load dataset 122 | logger.info('Loading data') 123 | conll_parser = ConllParser(Config({ 124 | 'separator': '\t', 125 | 'token_col': 0, 126 | 'label_col': 1, 127 | 'skip_comment': True, 128 | })) 129 | test_set = SequenceDataset(Config({ 130 | 'path': data_file, 'parser': conll_parser 131 | })) 132 | numberize_datasets([(test_set, token_vocab, label_vocab, char_vocab)], 133 | token_ignore_case=train_args['word_ignore_case'], 134 | label_ignore_case=False, 135 | char_ignore_case=False) 136 | idx_token = {idx: token for token, idx in token_vocab.items()} 137 | idx_label = {idx: label for label, idx in label_vocab.items()} 138 | idx_token[C.UNKNOWN_TOKEN_INDEX] = C.UNKNOWN_TOKEN 139 | 140 | try: 141 | results = [] 142 | dataset_loss = [] 143 | for batch in test_set.get_dataset(gpu=use_gpu, 144 | shuffle_inst=False, 145 | batch_size=100): 146 | tokens, labels, chars, seq_lens, char_lens = batch 147 | pred, loss = lstm_crf.predict( 148 | tokens, labels, seq_lens, chars, char_lens) 149 | results.append((pred, labels, seq_lens, tokens)) 150 | dataset_loss.append(loss.data[0]) 151 | 152 | dataset_loss = 
sum(dataset_loss) / len(dataset_loss) 153 | fscore, prec, rec = evaluate(results, idx_token, idx_label, writer=log_writer) 154 | except KeyboardInterrupt: 155 | traceback.print_exc() 156 | if log_writer: 157 | log_writer.close() 158 | -------------------------------------------------------------------------------- /example_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training": { 3 | "eval_freq": 1000, 4 | "max_step": 50000, 5 | "gpu": true 6 | }, 7 | "datasets": [ 8 | { 9 | "name": "nld_ner", 10 | "language": "nld", 11 | "type": "sequence", 12 | "task": "ner", 13 | "parser": { 14 | "format": "conll", 15 | "token_col": 0, 16 | "label_col": 1 17 | }, 18 | "sample": 200, 19 | "batch_size": 19, 20 | "files": { 21 | "train": "/PATH/TO/ned.train.bioes", 22 | "dev": "/PATH/TO/ned.testa.bioes", 23 | "test": "/PATH/TO/ned.testb.bioes" 24 | } 25 | }, 26 | { 27 | "name": "nld_pos", 28 | "language": "nld", 29 | "type": "sequence", 30 | "task": "pos", 31 | "parser": { 32 | "format": "conll", 33 | "token_col": 1, 34 | "label_col": 3 35 | }, 36 | "batch_size": 19, 37 | "files": { 38 | "train": "/PATH/TO/nl-ud-train.conllu", 39 | "dev": "/PATH/TO/nl-ud-dev.conllu", 40 | "test": "/PATH/TO/nl-ud-test.conllu" 41 | } 42 | }, 43 | { 44 | "name": "eng_ner", 45 | "language": "eng", 46 | "type": "sequence", 47 | "task": "ner", 48 | "parser": { 49 | "format": "conll", 50 | "token_col": 0, 51 | "label_col": 1 52 | }, 53 | "batch_size": 19, 54 | "files": { 55 | "train": "/PATH/TO/eng.train.bioes", 56 | "dev": "/PATH/TO/eng.testa.bioes", 57 | "test": "/PATH/TO/eng.testb.bioes" 58 | } 59 | }, 60 | { 61 | "name": "eng_pos", 62 | "language": "eng", 63 | "type": "sequence", 64 | "task": "pos", 65 | "parser": { 66 | "format": "conll", 67 | "token_col": 1, 68 | "label_col": 3 69 | }, 70 | "batch_size": 19, 71 | "files": { 72 | "train": "/PATH/TO/en-ud-train.conllu", 73 | "dev": "/PATH/TO/en-ud-dev.conllu", 74 | "test": 
"/PATH/TO/en-ud-test.conllu" 75 | } 76 | } 77 | ], 78 | "tasks": [ 79 | { 80 | "name": "Dutch NER", 81 | "language": "nld", 82 | "task": "ner", 83 | "model": { 84 | "model": "lstm_crf", 85 | "word_embed": "nld_word_embed", 86 | "char_embed": "char_embed", 87 | "crf": "ner_crf", 88 | "lstm": "lstm", 89 | "univ_layer": "ner_univ_linear", 90 | "spec_layer": "ner_nld_linear", 91 | "embed_dropout": 0.0, 92 | "lstm_dropout": 0.6, 93 | "linear_dropout": 0.0, 94 | "use_char_embedding": true, 95 | "char_highway": "char_highway" 96 | }, 97 | "dataset": "nld_ner", 98 | "learning_rate": 0.02, 99 | "decay_rate": 0.9, 100 | "decay_step": 10000, 101 | "prob": 1, 102 | "gpu": true, 103 | "ref": true 104 | }, 105 | { 106 | "name": "Dutch POS", 107 | "language": "nld", 108 | "task": "pos", 109 | "model": { 110 | "model": "lstm_crf", 111 | "word_embed": "nld_word_embed", 112 | "char_embed": "char_embed", 113 | "crf": "pos_crf", 114 | "lstm": "lstm", 115 | "univ_layer": "pos_univ_linear", 116 | "spec_layer": "pos_nld_linear", 117 | "embed_dropout": 0.0, 118 | "lstm_dropout": 0.6, 119 | "linear_dropout": 0.0, 120 | "use_char_embedding": true, 121 | "char_highway": "char_highway" 122 | }, 123 | "dataset": "nld_pos", 124 | "learning_rate": 0.02, 125 | "decay_rate": 0.9, 126 | "decay_step": 10000, 127 | "prob": 0.1, 128 | "gpu": true 129 | }, 130 | { 131 | "name": "English NER", 132 | "language": "eng", 133 | "task": "ner", 134 | "model": { 135 | "model": "lstm_crf", 136 | "word_embed": "eng_word_embed", 137 | "char_embed": "char_embed", 138 | "crf": "ner_crf", 139 | "lstm": "lstm", 140 | "univ_layer": "ner_univ_linear", 141 | "spec_layer": "ner_eng_linear", 142 | "embed_dropout": 0.0, 143 | "lstm_dropout": 0.6, 144 | "linear_dropout": 0.0, 145 | "use_char_embedding": true, 146 | "char_highway": "char_highway" 147 | }, 148 | "dataset": "eng_ner", 149 | "learning_rate": 0.02, 150 | "decay_rate": 0.9, 151 | "decay_step": 10000, 152 | "prob": 1, 153 | "gpu": true 154 | }, 155 | { 156 | 
"name": "English POS", 157 | "language": "eng", 158 | "task": "pos", 159 | "model": { 160 | "model": "lstm_crf", 161 | "word_embed": "eng_word_embed", 162 | "char_embed": "char_embed", 163 | "crf": "pos_crf", 164 | "lstm": "lstm", 165 | "univ_layer": "pos_univ_linear", 166 | "spec_layer": "pos_eng_linear", 167 | "embed_dropout": 0.0, 168 | "lstm_dropout": 0.6, 169 | "linear_dropout": 0.0, 170 | "use_char_embedding": true, 171 | "char_highway": "char_highway" 172 | }, 173 | "dataset": "eng_pos", 174 | "learning_rate": 0.02, 175 | "decay_rate": 0.9, 176 | "decay_step": 10000, 177 | "prob": 0.1, 178 | "gpu": true 179 | } 180 | ], 181 | "components": [ 182 | { 183 | "name": "eng_word_embed", 184 | "model": "embedding", 185 | "language": "eng", 186 | "file": "/PATH/TO/enwiki.cbow.50d.txt", 187 | "stats": true, 188 | "padding": 2, 189 | "trainable": true, 190 | "allow_gpu": false, 191 | "dimension": 50, 192 | "padding_idx": 0, 193 | "sparse": true 194 | }, 195 | { 196 | "name": "nld_word_embed", 197 | "model": "embedding", 198 | "language": "nld", 199 | "file": "/PATH/TO/nlwiki.cbow.50d.txt", 200 | "stats": true, 201 | "padding": 2, 202 | "trainable": true, 203 | "allow_gpu": false, 204 | "dimension": 50, 205 | "padding_idx": 0, 206 | "sparse": true 207 | }, 208 | { 209 | "name": "char_embed", 210 | "model": "char_cnn", 211 | "dimension": 50, 212 | "filters": [[2, 20], [3, 20], [4, 20]] 213 | }, 214 | { 215 | "name": "lstm", 216 | "model": "lstm", 217 | "hidden_size": 171, 218 | "bidirectional": true, 219 | "forget_bias": 1.0, 220 | "batch_first": true, 221 | "dropout": 0.0 222 | }, 223 | { 224 | "name": "ner_crf", 225 | "model": "crf" 226 | }, 227 | { 228 | "name": "pos_crf", 229 | "model": "crf" 230 | }, 231 | { 232 | "name": "ner_univ_linear", 233 | "model": "linear", 234 | "position": "output" 235 | }, 236 | { 237 | "name": "ner_eng_linear", 238 | "model": "linear", 239 | "position": "output" 240 | }, 241 | { 242 | "name": "ner_nld_linear", 243 | "model": "linear", 
244 | "position": "output" 245 | }, 246 | { 247 | "name": "pos_univ_linear", 248 | "model": "linear", 249 | "position": "output" 250 | }, 251 | { 252 | "name": "pos_eng_linear", 253 | "model": "linear", 254 | "position": "output" 255 | }, 256 | { 257 | "name": "pos_nld_linear", 258 | "model": "linear", 259 | "position": "output" 260 | }, 261 | { 262 | "name": "char_highway", 263 | "model": "highway", 264 | "position": "char", 265 | "num_layers": 2, 266 | "activation": "selu" 267 | } 268 | ] 269 | } -------------------------------------------------------------------------------- /old/example_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training": { 3 | "eval_freq": 1000, 4 | "max_step": 50000, 5 | "gpu": true 6 | }, 7 | "datasets": [ 8 | { 9 | "name": "nld_ner", 10 | "language": "nld", 11 | "type": "sequence", 12 | "task": "ner", 13 | "parser": { 14 | "format": "conll", 15 | "token_col": 0, 16 | "label_col": 1 17 | }, 18 | "sample": 200, 19 | "batch_size": 19, 20 | "files": { 21 | "train": "/PATH/TO/ned.train.bioes", 22 | "dev": "/PATH/TO/ned.testa.bioes", 23 | "test": "/PATH/TO/ned.testb.bioes" 24 | } 25 | }, 26 | { 27 | "name": "nld_pos", 28 | "language": "nld", 29 | "type": "sequence", 30 | "task": "pos", 31 | "parser": { 32 | "format": "conll", 33 | "token_col": 1, 34 | "label_col": 3 35 | }, 36 | "batch_size": 19, 37 | "files": { 38 | "train": "/PATH/TO/nl-ud-train.conllu", 39 | "dev": "/PATH/TO/nl-ud-dev.conllu", 40 | "test": "/PATH/TO/nl-ud-test.conllu" 41 | } 42 | }, 43 | { 44 | "name": "eng_ner", 45 | "language": "eng", 46 | "type": "sequence", 47 | "task": "ner", 48 | "parser": { 49 | "format": "conll", 50 | "token_col": 0, 51 | "label_col": 1 52 | }, 53 | "batch_size": 19, 54 | "files": { 55 | "train": "/PATH/TO/eng.train.bioes", 56 | "dev": "/PATH/TO/eng.testa.bioes", 57 | "test": "/PATH/TO/eng.testb.bioes" 58 | } 59 | }, 60 | { 61 | "name": "eng_pos", 62 | "language": "eng", 63 | "type": 
"sequence", 64 | "task": "pos", 65 | "parser": { 66 | "format": "conll", 67 | "token_col": 1, 68 | "label_col": 3 69 | }, 70 | "batch_size": 19, 71 | "files": { 72 | "train": "/PATH/TO/en-ud-train.conllu", 73 | "dev": "/PATH/TO/en-ud-dev.conllu", 74 | "test": "/PATH/TO/en-ud-test.conllu" 75 | } 76 | } 77 | ], 78 | "tasks": [ 79 | { 80 | "name": "Dutch NER", 81 | "language": "nld", 82 | "task": "ner", 83 | "model": { 84 | "model": "lstm_crf", 85 | "word_embed": "nld_word_embed", 86 | "char_embed": "char_embed", 87 | "crf": "ner_crf", 88 | "lstm": "lstm", 89 | "univ_layer": "ner_univ_linear", 90 | "spec_layer": "ner_nld_linear", 91 | "embed_dropout": 0.0, 92 | "lstm_dropout": 0.6, 93 | "linear_dropout": 0.0, 94 | "use_char_embedding": true, 95 | "char_highway": "char_highway" 96 | }, 97 | "dataset": "nld_ner", 98 | "learning_rate": 0.02, 99 | "decay_rate": 0.9, 100 | "decay_step": 10000, 101 | "prob": 1, 102 | "gpu": true, 103 | "ref": true 104 | }, 105 | { 106 | "name": "Dutch POS", 107 | "language": "nld", 108 | "task": "pos", 109 | "model": { 110 | "model": "lstm_crf", 111 | "word_embed": "nld_word_embed", 112 | "char_embed": "char_embed", 113 | "crf": "pos_crf", 114 | "lstm": "lstm", 115 | "univ_layer": "pos_univ_linear", 116 | "spec_layer": "pos_nld_linear", 117 | "embed_dropout": 0.0, 118 | "lstm_dropout": 0.6, 119 | "linear_dropout": 0.0, 120 | "use_char_embedding": true, 121 | "char_highway": "char_highway" 122 | }, 123 | "dataset": "nld_pos", 124 | "learning_rate": 0.02, 125 | "decay_rate": 0.9, 126 | "decay_step": 10000, 127 | "prob": 0.1, 128 | "gpu": true 129 | }, 130 | { 131 | "name": "English NER", 132 | "language": "eng", 133 | "task": "ner", 134 | "model": { 135 | "model": "lstm_crf", 136 | "word_embed": "eng_word_embed", 137 | "char_embed": "char_embed", 138 | "crf": "ner_crf", 139 | "lstm": "lstm", 140 | "univ_layer": "ner_univ_linear", 141 | "spec_layer": "ner_eng_linear", 142 | "embed_dropout": 0.0, 143 | "lstm_dropout": 0.6, 144 | 
"linear_dropout": 0.0, 145 | "use_char_embedding": true, 146 | "char_highway": "char_highway" 147 | }, 148 | "dataset": "eng_ner", 149 | "learning_rate": 0.02, 150 | "decay_rate": 0.9, 151 | "decay_step": 10000, 152 | "prob": 1, 153 | "gpu": true 154 | }, 155 | { 156 | "name": "English POS", 157 | "language": "eng", 158 | "task": "pos", 159 | "model": { 160 | "model": "lstm_crf", 161 | "word_embed": "eng_word_embed", 162 | "char_embed": "char_embed", 163 | "crf": "pos_crf", 164 | "lstm": "lstm", 165 | "univ_layer": "pos_univ_linear", 166 | "spec_layer": "pos_eng_linear", 167 | "embed_dropout": 0.0, 168 | "lstm_dropout": 0.6, 169 | "linear_dropout": 0.0, 170 | "use_char_embedding": true, 171 | "char_highway": "char_highway" 172 | }, 173 | "dataset": "eng_pos", 174 | "learning_rate": 0.02, 175 | "decay_rate": 0.9, 176 | "decay_step": 10000, 177 | "prob": 0.1, 178 | "gpu": true 179 | } 180 | ], 181 | "components": [ 182 | { 183 | "name": "eng_word_embed", 184 | "model": "embedding", 185 | "language": "eng", 186 | "file": "/PATH/TO/enwiki.cbow.50d.txt", 187 | "stats": true, 188 | "padding": 2, 189 | "trainable": true, 190 | "allow_gpu": false, 191 | "dimension": 50, 192 | "padding_idx": 0, 193 | "sparse": true 194 | }, 195 | { 196 | "name": "nld_word_embed", 197 | "model": "embedding", 198 | "language": "nld", 199 | "file": "/PATH/TO/nlwiki.cbow.50d.txt", 200 | "stats": true, 201 | "padding": 2, 202 | "trainable": true, 203 | "allow_gpu": false, 204 | "dimension": 50, 205 | "padding_idx": 0, 206 | "sparse": true 207 | }, 208 | { 209 | "name": "char_embed", 210 | "model": "char_cnn", 211 | "dimension": 50, 212 | "filters": [[2, 20], [3, 20], [4, 20]] 213 | }, 214 | { 215 | "name": "lstm", 216 | "model": "lstm", 217 | "hidden_size": 171, 218 | "bidirectional": true, 219 | "forget_bias": 1.0, 220 | "batch_first": true, 221 | "dropout": 0.0 222 | }, 223 | { 224 | "name": "ner_crf", 225 | "model": "crf" 226 | }, 227 | { 228 | "name": "pos_crf", 229 | "model": "crf" 230 | }, 
231 | { 232 | "name": "ner_univ_linear", 233 | "model": "linear", 234 | "position": "output" 235 | }, 236 | { 237 | "name": "ner_eng_linear", 238 | "model": "linear", 239 | "position": "output" 240 | }, 241 | { 242 | "name": "ner_nld_linear", 243 | "model": "linear", 244 | "position": "output" 245 | }, 246 | { 247 | "name": "pos_univ_linear", 248 | "model": "linear", 249 | "position": "output" 250 | }, 251 | { 252 | "name": "pos_eng_linear", 253 | "model": "linear", 254 | "position": "output" 255 | }, 256 | { 257 | "name": "pos_nld_linear", 258 | "model": "linear", 259 | "position": "output" 260 | }, 261 | { 262 | "name": "char_highway", 263 | "model": "highway", 264 | "position": "char", 265 | "num_layers": 2, 266 | "activation": "selu" 267 | } 268 | ] 269 | } -------------------------------------------------------------------------------- /old/README.md: -------------------------------------------------------------------------------- 1 | I made modifications to a few functions in `data.py` and `model.py` when implementing `train_single.py`, `eval_single.py`, and `train_multi.py`. Not sure if `pipeline.py` was affected. I'll test it and finish the todo list as soon as possible. 2 | 3 | ## TODOs 4 | 5 | * Implement `eval_multi.py`. 6 | * Revise `build_tasks_from_file()` to support case-sensitive word embeddings. 7 | 8 | ## Requirements 9 | * Python 3.5+ 10 | * Pytorch 0.3.1 11 | * tqdm (used to display training progress) 12 | 13 | ## Architecture 14 | ![Overall architecture](https://github.com/limteng-rpi/mlmt/blob/master/image/framework.png) 15 | **Figure**: Multi-lingual Multi-task Architecture 16 | 17 | ## Pre-trained word embeddings 18 | 19 | Pre-trained word embeddings for English, Dutch, Spanish, Russian, and Chechen can be found at [this page](http://www.limteng.com/research/2018/05/14/pretrained-word-embeddings.html). 20 | 21 | **Update**: I added English, Dutch, and Spanish case-sensitive word embeddings. 
22 | 23 | ## Single-task Mono-lingual Model 24 | 25 | Train a new model: 26 | 27 | ``` 28 | python train_single.py --train --dev 29 | --test --log --model 30 | --max_epoch 50 --embedding --embed_skip_first 31 | --word_embed_dim 100 --char_embed_dim 50 32 | ``` 33 | 34 | Evaluate the trained model: 35 | 36 | ``` 37 | python eval_single.py --model --file 38 | --log 39 | ``` 40 | 41 | ## Multi-task Model 42 | 43 | In my original code, I use the `build_tasks_from_file` function in `task.py` to build the whole architecture from a configuration file (see the `Configuration` section). `pipeline.py` shows how to use this function. 44 | 45 | Train a new model: 46 | 47 | ``` 48 | python train_multi.py --train_tgt --dev_tgt 49 | --test_tgt --train_cl --dev_cl 50 | --test_cl --train_ct --dev_ct 51 | --test_ct --train_clct --dev_clct 52 | --test_clct --log 53 | --model --max_epoch 50 54 | --embedding1 --embedding2 --word_embed_dim 50 55 | ``` 56 | 57 | ## Configuration 58 | 59 | For complete configuration, see `example_config.json`.
60 | 61 | ```json 62 | { 63 | "training": { 64 | "eval_freq": 1000, # Evaluate the model every 1000 global steps 65 | "max_step": 50000, # Maximum number of training steps 66 | "gpu": true # Use GPU 67 | }, 68 | "datasets": [ # A list of data sets 69 | { 70 | "name": "nld_ner", # Data set name 71 | "language": "nld", # Data set language 72 | "type": "sequence", # Data set type; 'sequence' is the only supported value though 73 | "task": "ner", # Task (identical to the 'task' value of the corresponding task) 74 | "parser": { # Data set parser 75 | "format": "conll", # File format 76 | "token_col": 0, # Token column index 77 | "label_col": 1 # Label column index 78 | }, 79 | "sample": 200, # Sample number (optional): 'all', int, or float 80 | "batch_size": 19, # Batch size 81 | "files": { 82 | "train": "/PATH/TO/ned.train.bioes", # Path to the training set 83 | "dev": "/PATH/TO/ned.testa.bioes", # Path to the dev set 84 | "test": "/PATH/TO/ned.testb.bioes" # Path to the test set (optional) 85 | } 86 | }, 87 | ... 88 | ], 89 | "tasks": [ 90 | { 91 | "name": "Dutch NER", # Task name 92 | "language": "nld", # Task language 93 | "task": "ner", # Task 94 | "model": { # Components can be shared and are configured in 'components'. Just 95 | # put their names here.
96 | "model": "lstm_crf", # Model type 97 | "word_embed": "nld_word_embed", # Word embedding 98 | "char_embed": "char_embed", # Character embedding 99 | "crf": "ner_crf", # CRF layer 100 | "lstm": "lstm", # LSTM layer 101 | "univ_layer": "ner_univ_linear", # Universal/shared linear layer 102 | "spec_layer": "ner_nld_linear", # Language-specific linear layer 103 | "embed_dropout": 0.0, # Embedding dropout probability 104 | "lstm_dropout": 0.6, # LSTM output dropout probability 105 | "linear_dropout": 0.0, # Linear layer output dropout probability 106 | "use_char_embedding": true, # Use character embeddings 107 | "char_highway": "char_highway" # Highway networks for character embeddings 108 | }, 109 | "dataset": "nld_ner", # Data set name 110 | "learning_rate": 0.02, # Learning rate 111 | "decay_rate": 0.9, # Decay rate 112 | "decay_step": 10000, # Decay step 113 | "ref": true # Is the target task 114 | }, 115 | ... 116 | ], 117 | "components": [ 118 | { 119 | "name": "eng_word_embed", 120 | "model": "embedding", 121 | "language": "eng", 122 | "file": "/PATH/TO/enwiki.cbow.50d.txt", 123 | "stats": true, 124 | "padding": 2, 125 | "trainable": true, 126 | "allow_gpu": false, 127 | "dimension": 50, 128 | "padding_idx": 0, 129 | "sparse": true 130 | }, 131 | { 132 | "name": "nld_word_embed", 133 | "model": "embedding", 134 | "language": "nld", 135 | "file": "/PATH/TO/nlwiki.cbow.50d.txt", 136 | "stats": true, 137 | "padding": 2, 138 | "trainable": true, 139 | "allow_gpu": false, 140 | "dimension": 50, 141 | "padding_idx": 0, 142 | "sparse": true 143 | }, 144 | { 145 | "name": "char_embed", 146 | "model": "char_cnn", 147 | "dimension": 50, 148 | "filters": [[2, 20], [3, 20], [4, 20]] 149 | }, 150 | { 151 | "name": "lstm", 152 | "model": "lstm", 153 | "hidden_size": 171, 154 | "bidirectional": true, 155 | "forget_bias": 1.0, 156 | "batch_first": true, 157 | "dropout": 0.0 # Because we use a 1-layer LSTM. This value doesn't have any effect. 
158 | }, 159 | { 160 | "name": "ner_crf", 161 | "model": "crf" 162 | }, 163 | { 164 | "name": "pos_crf", 165 | "model": "crf" 166 | }, 167 | { 168 | "name": "ner_univ_linear", 169 | "model": "linear", 170 | "position": "output" 171 | }, 172 | { 173 | "name": "ner_eng_linear", 174 | "model": "linear", 175 | "position": "output" 176 | }, 177 | { 178 | "name": "ner_nld_linear", 179 | "model": "linear", 180 | "position": "output" 181 | }, 182 | { 183 | "name": "pos_univ_linear", 184 | "model": "linear", 185 | "position": "output" 186 | }, 187 | { 188 | "name": "pos_eng_linear", 189 | "model": "linear", 190 | "position": "output" 191 | }, 192 | { 193 | "name": "pos_nld_linear", 194 | "model": "linear", 195 | "position": "output" 196 | }, 197 | { 198 | "name": "char_highway", 199 | "model": "highway", 200 | "position": "char", 201 | "num_layers": 2, 202 | "activation": "selu" 203 | } 204 | ] 205 | } 206 | ``` 207 | 208 | ## Reference 209 | 210 | - Lin, Y., Yang, S., Stoyanov, V., Ji, H. (2018) *A Multi-lingual Multi-task Architecture for Low-resource Sequence Labeling*. Proceedings of The 56th Annual Meeting of the Association for Computational Linguistics. \[[pdf](http://nlp.cs.rpi.edu/paper/multilingualmultitask.pdf)\] 211 | 212 | ``` 213 | @inproceedings{ying2018multi, 214 | title = {A Multi-lingual Multi-task Architecture for Low-resource Sequence Labeling}, 215 | author = {Ying Lin and Shengqi Yang and Veselin Stoyanov and Heng Ji}, 216 | booktitle = {Proceedings of The 56th Annual Meeting of the Association for Computational Linguistics (ACL2018)}, 217 | year = {2018} 218 | } 219 | ``` 220 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | If what you want is a (monolingual, single-task) name tagging model, we have a new implementation at https://github.com/limteng-rpi/neural_name_tagging . 2 | 3 | Old files were moved to `old/`. 
4 | 5 | ## Requirements 6 | * Python 3.5+ 7 | * PyTorch 0.4.1 or PyTorch 1.0 (old scripts use PyTorch 0.3.1) 8 | * tqdm (used to display training progress) 9 | 10 | ## Architecture 11 | ![Overall architecture](https://github.com/limteng-rpi/mlmt/blob/master/image/framework.png) 12 | **Figure**: Multi-lingual Multi-task Architecture 13 | 14 | ## Pre-trained word embeddings 15 | 16 | Pre-trained word embeddings for English, Dutch, Spanish, Russian, and Chechen can be found at [this page](http://www.limteng.com/research/2018/05/14/pretrained-word-embeddings.html). 17 | 18 | **Update**: I added English, Dutch, and Spanish case-sensitive word embeddings. 19 | 20 | ## Single-task Mono-lingual Model 21 | 22 | Train a new model: 23 | 24 | ``` 25 | python train_single.py --train --dev 26 | --test --log --model 27 | --max_epoch 50 --word_embed 28 | --word_embed_dim 100 --char_embed_dim 50 29 | ``` 30 | 31 | Evaluate the trained model: 32 | 33 | ``` 34 | python eval_single.py --model --file 35 | --log 36 | ``` 37 | 38 | ## Multi-task Model 39 | 40 | In my original code, I use the `build_tasks_from_file` function in `old/task.py` to build the whole architecture from a configuration file (see the `Configuration` section). `old/pipeline.py` shows how to use this function. 41 | 42 | Train a new model: 43 | 44 | ``` 45 | python train_multi.py --train_tgt --dev_tgt 46 | --test_tgt --train_cl --dev_cl 47 | --test_cl --train_ct --dev_ct 48 | --test_ct --train_clct --dev_clct 49 | --test_clct --log 50 | --model --max_epoch 50 51 | --word_embed_1 --word_embed_2 --word_embed_dim 50 52 | ``` 53 | 54 | Evaluate the trained model: 55 | 56 | ``` 57 | python eval_multi.py --model --file 58 | --log 59 | ``` 60 | ## Configuration 61 | 62 | For complete configuration, see `example_config.json`.
63 | 64 | ```json 65 | { 66 | "training": { 67 | "eval_freq": 1000, # Evaluate the model every 1000 global steps 68 | "max_step": 50000, # Maximum number of training steps 69 | "gpu": true # Use GPU 70 | }, 71 | "datasets": [ # A list of data sets 72 | { 73 | "name": "nld_ner", # Data set name 74 | "language": "nld", # Data set language 75 | "type": "sequence", # Data set type; 'sequence' is the only supported value though 76 | "task": "ner", # Task (identical to the 'task' value of the corresponding task) 77 | "parser": { # Data set parser 78 | "format": "conll", # File format 79 | "token_col": 0, # Token column index 80 | "label_col": 1 # Label column index 81 | }, 82 | "sample": 200, # Sample number (optional): 'all', int, or float 83 | "batch_size": 19, # Batch size 84 | "files": { 85 | "train": "/PATH/TO/ned.train.bioes", # Path to the training set 86 | "dev": "/PATH/TO/ned.testa.bioes", # Path to the dev set 87 | "test": "/PATH/TO/ned.testb.bioes" # Path to the test set (optional) 88 | } 89 | }, 90 | ... 91 | ], 92 | "tasks": [ 93 | { 94 | "name": "Dutch NER", # Task name 95 | "language": "nld", # Task language 96 | "task": "ner", # Task 97 | "model": { # Components can be shared and are configured in 'components'. Just 98 | # put their names here.
99 | "model": "lstm_crf", # Model type 100 | "word_embed": "nld_word_embed", # Word embedding 101 | "char_embed": "char_embed", # Character embedding 102 | "crf": "ner_crf", # CRF layer 103 | "lstm": "lstm", # LSTM layer 104 | "univ_layer": "ner_univ_linear", # Universal/shared linear layer 105 | "spec_layer": "ner_nld_linear", # Language-specific linear layer 106 | "embed_dropout": 0.0, # Embedding dropout probability 107 | "lstm_dropout": 0.6, # LSTM output dropout probability 108 | "linear_dropout": 0.0, # Linear layer output dropout probability 109 | "use_char_embedding": true, # Use character embeddings 110 | "char_highway": "char_highway" # Highway networks for character embeddings 111 | }, 112 | "dataset": "nld_ner", # Data set name 113 | "learning_rate": 0.02, # Learning rate 114 | "decay_rate": 0.9, # Decay rate 115 | "decay_step": 10000, # Decay step 116 | "ref": true # Is the target task 117 | }, 118 | ... 119 | ], 120 | "components": [ 121 | { 122 | "name": "eng_word_embed", 123 | "model": "embedding", 124 | "language": "eng", 125 | "file": "/PATH/TO/enwiki.cbow.50d.txt", 126 | "stats": true, 127 | "padding": 2, 128 | "trainable": true, 129 | "allow_gpu": false, 130 | "dimension": 50, 131 | "padding_idx": 0, 132 | "sparse": true 133 | }, 134 | { 135 | "name": "nld_word_embed", 136 | "model": "embedding", 137 | "language": "nld", 138 | "file": "/PATH/TO/nlwiki.cbow.50d.txt", 139 | "stats": true, 140 | "padding": 2, 141 | "trainable": true, 142 | "allow_gpu": false, 143 | "dimension": 50, 144 | "padding_idx": 0, 145 | "sparse": true 146 | }, 147 | { 148 | "name": "char_embed", 149 | "model": "char_cnn", 150 | "dimension": 50, 151 | "filters": [[2, 20], [3, 20], [4, 20]] 152 | }, 153 | { 154 | "name": "lstm", 155 | "model": "lstm", 156 | "hidden_size": 171, 157 | "bidirectional": true, 158 | "forget_bias": 1.0, 159 | "batch_first": true, 160 | "dropout": 0.0 # Because we use a 1-layer LSTM. This value doesn't have any effect. 
161 | }, 162 | { 163 | "name": "ner_crf", 164 | "model": "crf" 165 | }, 166 | { 167 | "name": "pos_crf", 168 | "model": "crf" 169 | }, 170 | { 171 | "name": "ner_univ_linear", 172 | "model": "linear", 173 | "position": "output" 174 | }, 175 | { 176 | "name": "ner_eng_linear", 177 | "model": "linear", 178 | "position": "output" 179 | }, 180 | { 181 | "name": "ner_nld_linear", 182 | "model": "linear", 183 | "position": "output" 184 | }, 185 | { 186 | "name": "pos_univ_linear", 187 | "model": "linear", 188 | "position": "output" 189 | }, 190 | { 191 | "name": "pos_eng_linear", 192 | "model": "linear", 193 | "position": "output" 194 | }, 195 | { 196 | "name": "pos_nld_linear", 197 | "model": "linear", 198 | "position": "output" 199 | }, 200 | { 201 | "name": "char_highway", 202 | "model": "highway", 203 | "position": "char", 204 | "num_layers": 2, 205 | "activation": "selu" 206 | } 207 | ] 208 | } 209 | ``` 210 | 211 | ## Reference 212 | 213 | - Lin, Y., Yang, S., Stoyanov, V., Ji, H. (2018) *A Multi-lingual Multi-task Architecture for Low-resource Sequence Labeling*. Proceedings of The 56th Annual Meeting of the Association for Computational Linguistics. 
\[[pdf](http://nlp.cs.rpi.edu/paper/multilingualmultitask.pdf)\] 214 | 215 | ``` 216 | @inproceedings{ying2018multi, 217 | title = {A Multi-lingual Multi-task Architecture for Low-resource Sequence Labeling}, 218 | author = {Ying Lin and Shengqi Yang and Veselin Stoyanov and Heng Ji}, 219 | booktitle = {Proceedings of The 56th Annual Meeting of the Association for Computational Linguistics (ACL2018)}, 220 | year = {2018} 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import constant as C 3 | 4 | import logging 5 | from collections import Counter, defaultdict 6 | from random import shuffle, uniform, sample 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | def count2vocab(count, offset=0, pads=None, min_count=0, ignore_case=False): 13 | """Convert a token count dictionary to a vocabulary dict. 14 | :param count: Token count dictionary. 15 | :param offset: Begin start offset. 16 | :param pads: A list of padding (token, index) pairs. 17 | :param min_count: Minimum token count. 18 | :param ignore_case: Ignore token case. 19 | :return: Vocab dict. 20 | """ 21 | if ignore_case: 22 | count_ = defaultdict(int) 23 | for k, v in count.items(): 24 | count_[k.lower()] += v 25 | count = count_ 26 | 27 | vocab = {} 28 | for token, freq in count.items(): 29 | if freq > min_count: 30 | vocab[token] = len(vocab) + offset 31 | if pads: 32 | for k, v in pads: 33 | vocab[k] = v 34 | 35 | return vocab 36 | 37 | 38 | 39 | class Parser(object): 40 | 41 | def parse(self, path: str): 42 | raise NotImplementedError 43 | 44 | 45 | class ConllParser(Parser): 46 | 47 | def __init__(self, 48 | token_col: int = 0, 49 | label_col: int = 1, 50 | separator: str = '\t', 51 | skip_comment: bool = False): 52 | """ 53 | :param token_col: Token column (default=0). 
54 | :param label_col: Label column (default=1). 55 | :param separator: Separate character (default=\t). 56 | :param skip_comment: Skip lines starting with #. 57 | """ 58 | self.token_col = token_col 59 | self.label_col = label_col 60 | self.separator = separator 61 | self.skip_comment = skip_comment 62 | 63 | def parse(self, 64 | path: str): 65 | token_col = self.token_col 66 | label_col = self.label_col 67 | separator = self.separator 68 | skip_comment = self.skip_comment 69 | 70 | with open(path, 'r', encoding='utf-8') as r: 71 | current_doc = [] 72 | for line in r: 73 | line = line.rstrip() 74 | if skip_comment and line.startswith('#'): 75 | continue 76 | if line: 77 | segs = line.split(separator) 78 | token, label = segs[token_col].strip(), segs[label_col] 79 | token = C.PENN_TREEBANK_BRACKETS.get(token, token) 80 | if label in {'B-O', 'I-O', 'E-O', 'S-O'}: 81 | label = 'O' 82 | current_doc.append((token, label)) 83 | elif current_doc: 84 | tokens = [] 85 | labels = [] 86 | for token, label in current_doc: 87 | tokens.append(token) 88 | labels.append(label) 89 | current_doc = [] 90 | yield tokens, labels 91 | if current_doc: 92 | tokens = [] 93 | labels = [] 94 | for token, label in current_doc: 95 | tokens.append(token) 96 | labels.append(label) 97 | yield tokens, labels 98 | 99 | 100 | class SeqLabelDataset(Dataset): 101 | 102 | def __init__(self, 103 | path: str, 104 | parser: Parser, 105 | max_seq_len: int = -1): 106 | self.path = path 107 | self.parser = parser 108 | self.max_seq_len = max_seq_len 109 | self.raw_data = [] 110 | self.data = [] 111 | self.load() 112 | 113 | def __getitem__(self, 114 | idx: int): 115 | return self.data[idx] 116 | 117 | def __len__(self): 118 | return len(self.data) 119 | 120 | def numberize(self, 121 | token_vocab: dict, 122 | label_vocab: dict, 123 | char_vocab: dict = None, 124 | ignore_case: bool = False): 125 | for tokens, labels in self.raw_data: 126 | if ignore_case: 127 | tokens_ = [t.lower() for t in tokens] 128 | 
tokens_ = [token_vocab[t] if t in token_vocab 129 | else C.UNK_INDEX for t in tokens_] 130 | else: 131 | tokens_ = [token_vocab[t] if t in token_vocab 132 | else C.UNK_INDEX for t in tokens] 133 | labels_ = [label_vocab[l] for l in labels] 134 | chars = None 135 | if char_vocab: 136 | chars = [[char_vocab[c] if c in char_vocab 137 | else C.UNK_INDEX for c in t] for t in tokens] 138 | if self.max_seq_len > 0: 139 | chars = chars[:self.max_seq_len] 140 | if self.max_seq_len > 0: 141 | tokens_ = tokens_[:self.max_seq_len] 142 | labels_ = labels_[:self.max_seq_len] 143 | self.data.append((tokens_, labels_, chars)) 144 | 145 | def load(self): 146 | self.raw_data = [(tokens, labels) 147 | for tokens, labels in self.parser.parse(self.path)] 148 | 149 | def stats(self, 150 | token_ignore_case: bool = False, 151 | char_ignore_case: bool = False, 152 | label_ignore_case: bool = False, 153 | ): 154 | token_counter = Counter() 155 | char_counter = Counter() 156 | label_counter = Counter() 157 | for item in self.raw_data: 158 | tokens, labels = item[0], item[1] 159 | token_lower = [t.lower() for t in tokens] 160 | if char_ignore_case: 161 | for token in token_lower: 162 | for c in token: 163 | char_counter[c] += 1 164 | else: 165 | for token in tokens: 166 | for c in token: 167 | char_counter[c] += 1 168 | if token_ignore_case: 169 | token_counter.update(token_lower) 170 | else: 171 | token_counter.update(tokens) 172 | if label_ignore_case: 173 | label_counter.update([l.lower() for l in labels]) 174 | else: 175 | label_counter.update(labels) 176 | 177 | return token_counter, char_counter, label_counter 178 | 179 | 180 | class BatchProcessor(object): 181 | 182 | def process(self, batch): 183 | assert NotImplementedError 184 | 185 | 186 | class SeqLabelProcessor(BatchProcessor): 187 | 188 | def __init__(self, 189 | sort: bool = True, 190 | gpu: bool = False, 191 | padding_idx: int = C.PAD_INDEX, 192 | min_char_len: int = 4): 193 | self.sort = sort 194 | self.gpu = gpu 195 | 
self.padding_idx = padding_idx 196 | self.min_char_len = min_char_len 197 | 198 | def process(self, batch: list): 199 | padding_idx = self.padding_idx 200 | # if self.sort: 201 | batch.sort(key=lambda x: len(x[0]), reverse=True) 202 | 203 | seq_lens = [len(x[0]) for x in batch] 204 | max_seq_len = max(seq_lens) 205 | 206 | char_lens = [] 207 | for seq in batch: 208 | seq_char_lens = [len(x) for x in seq[2]] + \ 209 | [padding_idx] * (max_seq_len - len(seq[0])) 210 | char_lens.extend(seq_char_lens) 211 | max_char_len = max(max(char_lens), self.min_char_len) 212 | 213 | # Padding 214 | batch_tokens = [] 215 | batch_labels = [] 216 | batch_chars = [] 217 | for tokens, labels, chars in batch: 218 | batch_tokens.append(tokens + [padding_idx] * (max_seq_len - len(tokens))) 219 | batch_labels.append(labels + [padding_idx] * (max_seq_len - len(tokens))) 220 | batch_chars.extend( 221 | [x + [0] * (max_char_len - len(x)) for x in chars] 222 | # + [[0] * max_char_len] * (max_seq_len - len(tokens)) 223 | + [[0] * max_char_len for _ in range(max_seq_len - len(tokens))] 224 | ) 225 | 226 | batch_tokens = torch.LongTensor(batch_tokens) 227 | batch_labels = torch.LongTensor(batch_labels) 228 | batch_chars = torch.LongTensor(batch_chars) 229 | seq_lens = torch.LongTensor(seq_lens) 230 | char_lens = torch.LongTensor(char_lens) 231 | 232 | if self.gpu: 233 | batch_tokens = batch_tokens.cuda() 234 | batch_labels = batch_labels.cuda() 235 | batch_chars = batch_chars.cuda() 236 | seq_lens = seq_lens.cuda() 237 | char_lens = char_lens.cuda() 238 | 239 | return (batch_tokens, batch_labels, batch_chars, 240 | seq_lens, char_lens) 241 | -------------------------------------------------------------------------------- /conlleval.py: -------------------------------------------------------------------------------- 1 | # https://github.com/spyysalo/conlleval.py/blob/master/conlleval.py 2 | 3 | # Python version of the evaluation script from CoNLL'00- 4 | 5 | # Intentional differences: 6 | # - 
accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | import sys 13 | import re 14 | 15 | from collections import defaultdict, namedtuple 16 | 17 | ANY_SPACE = '' 18 | 19 | 20 | class FormatError(Exception): 21 | pass 22 | 23 | 24 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 25 | 26 | 27 | class EvalCounts(object): 28 | def __init__(self): 29 | self.correct_chunk = 0 # number of correctly identified chunks 30 | self.correct_tags = 0 # number of correct chunk tags 31 | self.found_correct = 0 # number of chunks in corpus 32 | self.found_guessed = 0 # number of identified chunks 33 | self.token_counter = 0 # token counter (ignores sentence breaks) 34 | 35 | # counts by type 36 | self.t_correct_chunk = defaultdict(int) 37 | self.t_found_correct = defaultdict(int) 38 | self.t_found_guessed = defaultdict(int) 39 | 40 | 41 | def parse_args(argv): 42 | import argparse 43 | parser = argparse.ArgumentParser( 44 | description='evaluate tagging results using CoNLL criteria', 45 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 46 | ) 47 | arg = parser.add_argument 48 | arg('-b', '--boundary', metavar='STR', default='-X-', 49 | help='sentence boundary') 50 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 51 | help='character delimiting items in input') 52 | arg('-o', '--otag', metavar='CHAR', default='O', 53 | help='alternative outside tag') 54 | arg('file', nargs='?', default=None) 55 | return parser.parse_args(argv) 56 | 57 | 58 | def parse_tag(t): 59 | m = re.match(r'^([^-]*)-(.*)$', t) 60 | return m.groups() if m else (t, '') 61 | 62 | 63 | def evaluate(iterable, options=None): 64 | if options is None: 65 | options = parse_args([]) # use defaults 66 | 67 | counts = EvalCounts() 68 | num_features = None # number of features per line 69 | in_correct = False # 
currently processed chunks is correct until now 70 | last_correct = 'O' # previous chunk tag in corpus 71 | last_correct_type = '' # type of previously identified chunk tag 72 | last_guessed = 'O' # previously identified chunk tag 73 | last_guessed_type = '' # type of previous chunk tag in corpus 74 | 75 | for line in iterable: 76 | line = line.rstrip('\r\n') 77 | 78 | if options.delimiter == ANY_SPACE: 79 | features = line.split() 80 | else: 81 | features = line.split(options.delimiter) 82 | 83 | if num_features is None: 84 | num_features = len(features) 85 | elif num_features != len(features) and len(features) != 0: 86 | raise FormatError('unexpected number of features: %d (%d)' % 87 | (len(features), num_features)) 88 | 89 | if len(features) == 0 or features[0] == options.boundary: 90 | features = [options.boundary, 'O', 'O'] 91 | if len(features) < 3: 92 | raise FormatError('unexpected number of features in line %s' % line) 93 | 94 | guessed, guessed_type = parse_tag(features.pop()) 95 | correct, correct_type = parse_tag(features.pop()) 96 | first_item = features.pop(0) 97 | 98 | if first_item == options.boundary: 99 | guessed = 'O' 100 | 101 | end_correct = end_of_chunk(last_correct, correct, 102 | last_correct_type, correct_type) 103 | end_guessed = end_of_chunk(last_guessed, guessed, 104 | last_guessed_type, guessed_type) 105 | start_correct = start_of_chunk(last_correct, correct, 106 | last_correct_type, correct_type) 107 | start_guessed = start_of_chunk(last_guessed, guessed, 108 | last_guessed_type, guessed_type) 109 | 110 | if in_correct: 111 | if (end_correct and end_guessed and 112 | last_guessed_type == last_correct_type): 113 | in_correct = False 114 | counts.correct_chunk += 1 115 | counts.t_correct_chunk[last_correct_type] += 1 116 | elif (end_correct != end_guessed or guessed_type != correct_type): 117 | in_correct = False 118 | 119 | if start_correct and start_guessed and guessed_type == correct_type: 120 | in_correct = True 121 | 122 | if 
start_correct: 123 | counts.found_correct += 1 124 | counts.t_found_correct[correct_type] += 1 125 | if start_guessed: 126 | counts.found_guessed += 1 127 | counts.t_found_guessed[guessed_type] += 1 128 | if first_item != options.boundary: 129 | if correct == guessed and guessed_type == correct_type: 130 | counts.correct_tags += 1 131 | counts.token_counter += 1 132 | 133 | last_guessed = guessed 134 | last_correct = correct 135 | last_guessed_type = guessed_type 136 | last_correct_type = correct_type 137 | 138 | if in_correct: 139 | counts.correct_chunk += 1 140 | counts.t_correct_chunk[last_correct_type] += 1 141 | 142 | return counts 143 | 144 | 145 | def uniq(iterable): 146 | seen = set() 147 | return [i for i in iterable if not (i in seen or seen.add(i))] 148 | 149 | 150 | def calculate_metrics(correct, guessed, total): 151 | tp, fp, fn = correct, guessed - correct, total - correct 152 | p = 0 if tp + fp == 0 else 1. * tp / (tp + fp) 153 | r = 0 if tp + fn == 0 else 1. * tp / (tp + fn) 154 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 155 | return Metrics(tp, fp, fn, p, r, f) 156 | 157 | 158 | def metrics(counts): 159 | c = counts 160 | overall = calculate_metrics( 161 | c.correct_chunk, c.found_guessed, c.found_correct 162 | ) 163 | by_type = {} 164 | uniq_keys = list(c.t_found_correct.keys()) + list(c.t_found_guessed.keys()) 165 | for t in set(uniq_keys): 166 | by_type[t] = calculate_metrics( 167 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 168 | ) 169 | return overall, by_type 170 | 171 | 172 | def report(counts, out=None): 173 | if out is None: 174 | out = sys.stdout 175 | 176 | overall, by_type = metrics(counts) 177 | 178 | c = counts 179 | out.write('processed %d tokens with %d phrases; ' % 180 | (c.token_counter, c.found_correct)) 181 | out.write('found: %d phrases; correct: %d.\n' % 182 | (c.found_guessed, c.correct_chunk)) 183 | 184 | if c.token_counter > 0: 185 | out.write('Acc: %6.3f%%; ' % 186 | (100. 
* c.correct_tags / c.token_counter)) 187 | out.write('P: %6.3f%%; ' % (100. * overall.prec)) 188 | out.write('R: %6.3f%%; ' % (100. * overall.rec)) 189 | out.write('F: %6.3f\n' % (100. * overall.fscore)) 190 | 191 | for i, m in sorted(by_type.items()): 192 | out.write('%12s: ' % i) 193 | out.write('P: %6.3f%%; ' % (100. * m.prec)) 194 | out.write('R: %6.3f%%; ' % (100. * m.rec)) 195 | out.write('F: %6.3f %d\n' % (100. * m.fscore, c.t_found_guessed[i])) 196 | 197 | 198 | def end_of_chunk(prev_tag, tag, prev_type, type_): 199 | # check if a chunk ended between the previous and current word 200 | # arguments: previous and current chunk tags, previous and current types 201 | chunk_end = False 202 | 203 | if prev_tag == 'E': chunk_end = True 204 | if prev_tag == 'S': chunk_end = True 205 | 206 | if prev_tag == 'B' and tag == 'B': chunk_end = True 207 | if prev_tag == 'B' and tag == 'S': chunk_end = True 208 | if prev_tag == 'B' and tag == 'O': chunk_end = True 209 | if prev_tag == 'I' and tag == 'B': chunk_end = True 210 | if prev_tag == 'I' and tag == 'S': chunk_end = True 211 | if prev_tag == 'I' and tag == 'O': chunk_end = True 212 | 213 | if prev_tag != 'O' and prev_tag != '.' 
and prev_type != type_: 214 | chunk_end = True 215 | 216 | # these chunks are assumed to have length 1 217 | if prev_tag == ']': chunk_end = True 218 | if prev_tag == '[': chunk_end = True 219 | 220 | return chunk_end 221 | 222 | 223 | def start_of_chunk(prev_tag, tag, prev_type, type_): 224 | # check if a chunk started between the previous and current word 225 | # arguments: previous and current chunk tags, previous and current types 226 | chunk_start = False 227 | 228 | if tag == 'B': chunk_start = True 229 | if tag == 'S': chunk_start = True 230 | 231 | if prev_tag == 'E' and tag == 'E': chunk_start = True 232 | if prev_tag == 'E' and tag == 'I': chunk_start = True 233 | if prev_tag == 'S' and tag == 'E': chunk_start = True 234 | if prev_tag == 'S' and tag == 'I': chunk_start = True 235 | if prev_tag == 'O' and tag == 'E': chunk_start = True 236 | if prev_tag == 'O' and tag == 'I': chunk_start = True 237 | 238 | if tag != 'O' and tag != '.' and prev_type != type_: 239 | chunk_start = True 240 | 241 | # these chunks are assumed to have length 1 242 | if tag == '[': chunk_start = True 243 | if tag == ']': chunk_start = True 244 | 245 | return chunk_start 246 | -------------------------------------------------------------------------------- /old/conlleval.py: -------------------------------------------------------------------------------- 1 | # https://github.com/spyysalo/conlleval.py/blob/master/conlleval.py 2 | 3 | # Python version of the evaluation script from CoNLL'00- 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | import sys 13 | import re 14 | 15 | from collections import defaultdict, namedtuple 16 | 17 | ANY_SPACE = '' 18 | 19 | 20 | class FormatError(Exception): 21 | pass 22 | 23 | 24 | Metrics = namedtuple('Metrics', 'tp fp fn 
prec rec fscore') 25 | 26 | 27 | class EvalCounts(object): 28 | def __init__(self): 29 | self.correct_chunk = 0 # number of correctly identified chunks 30 | self.correct_tags = 0 # number of correct chunk tags 31 | self.found_correct = 0 # number of chunks in corpus 32 | self.found_guessed = 0 # number of identified chunks 33 | self.token_counter = 0 # token counter (ignores sentence breaks) 34 | 35 | # counts by type 36 | self.t_correct_chunk = defaultdict(int) 37 | self.t_found_correct = defaultdict(int) 38 | self.t_found_guessed = defaultdict(int) 39 | 40 | 41 | def parse_args(argv): 42 | import argparse 43 | parser = argparse.ArgumentParser( 44 | description='evaluate tagging results using CoNLL criteria', 45 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 46 | ) 47 | arg = parser.add_argument 48 | arg('-b', '--boundary', metavar='STR', default='-X-', 49 | help='sentence boundary') 50 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 51 | help='character delimiting items in input') 52 | arg('-o', '--otag', metavar='CHAR', default='O', 53 | help='alternative outside tag') 54 | arg('file', nargs='?', default=None) 55 | return parser.parse_args(argv) 56 | 57 | 58 | def parse_tag(t): 59 | m = re.match(r'^([^-]*)-(.*)$', t) 60 | return m.groups() if m else (t, '') 61 | 62 | 63 | def evaluate(iterable, options=None): 64 | if options is None: 65 | options = parse_args([]) # use defaults 66 | 67 | counts = EvalCounts() 68 | num_features = None # number of features per line 69 | in_correct = False # currently processed chunks is correct until now 70 | last_correct = 'O' # previous chunk tag in corpus 71 | last_correct_type = '' # type of previously identified chunk tag 72 | last_guessed = 'O' # previously identified chunk tag 73 | last_guessed_type = '' # type of previous chunk tag in corpus 74 | 75 | for line in iterable: 76 | line = line.rstrip('\r\n') 77 | 78 | if options.delimiter == ANY_SPACE: 79 | features = line.split() 80 | else: 81 | 
features = line.split(options.delimiter) 82 | 83 | if num_features is None: 84 | num_features = len(features) 85 | elif num_features != len(features) and len(features) != 0: 86 | raise FormatError('unexpected number of features: %d (%d)' % 87 | (len(features), num_features)) 88 | 89 | if len(features) == 0 or features[0] == options.boundary: 90 | features = [options.boundary, 'O', 'O'] 91 | if len(features) < 3: 92 | raise FormatError('unexpected number of features in line %s' % line) 93 | 94 | guessed, guessed_type = parse_tag(features.pop()) 95 | correct, correct_type = parse_tag(features.pop()) 96 | first_item = features.pop(0) 97 | 98 | if first_item == options.boundary: 99 | guessed = 'O' 100 | 101 | end_correct = end_of_chunk(last_correct, correct, 102 | last_correct_type, correct_type) 103 | end_guessed = end_of_chunk(last_guessed, guessed, 104 | last_guessed_type, guessed_type) 105 | start_correct = start_of_chunk(last_correct, correct, 106 | last_correct_type, correct_type) 107 | start_guessed = start_of_chunk(last_guessed, guessed, 108 | last_guessed_type, guessed_type) 109 | 110 | if in_correct: 111 | if (end_correct and end_guessed and 112 | last_guessed_type == last_correct_type): 113 | in_correct = False 114 | counts.correct_chunk += 1 115 | counts.t_correct_chunk[last_correct_type] += 1 116 | elif (end_correct != end_guessed or guessed_type != correct_type): 117 | in_correct = False 118 | 119 | if start_correct and start_guessed and guessed_type == correct_type: 120 | in_correct = True 121 | 122 | if start_correct: 123 | counts.found_correct += 1 124 | counts.t_found_correct[correct_type] += 1 125 | if start_guessed: 126 | counts.found_guessed += 1 127 | counts.t_found_guessed[guessed_type] += 1 128 | if first_item != options.boundary: 129 | if correct == guessed and guessed_type == correct_type: 130 | counts.correct_tags += 1 131 | counts.token_counter += 1 132 | 133 | last_guessed = guessed 134 | last_correct = correct 135 | last_guessed_type = 
guessed_type 136 | last_correct_type = correct_type 137 | 138 | if in_correct: 139 | counts.correct_chunk += 1 140 | counts.t_correct_chunk[last_correct_type] += 1 141 | 142 | return counts 143 | 144 | 145 | def uniq(iterable): 146 | seen = set() 147 | return [i for i in iterable if not (i in seen or seen.add(i))] 148 | 149 | 150 | def calculate_metrics(correct, guessed, total): 151 | tp, fp, fn = correct, guessed - correct, total - correct 152 | p = 0 if tp + fp == 0 else 1. * tp / (tp + fp) 153 | r = 0 if tp + fn == 0 else 1. * tp / (tp + fn) 154 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 155 | return Metrics(tp, fp, fn, p, r, f) 156 | 157 | 158 | def metrics(counts): 159 | c = counts 160 | overall = calculate_metrics( 161 | c.correct_chunk, c.found_guessed, c.found_correct 162 | ) 163 | by_type = {} 164 | uniq_keys = list(c.t_found_correct.keys()) + list(c.t_found_guessed.keys()) 165 | for t in set(uniq_keys): 166 | by_type[t] = calculate_metrics( 167 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 168 | ) 169 | return overall, by_type 170 | 171 | 172 | def report(counts, out=None): 173 | if out is None: 174 | out = sys.stdout 175 | 176 | overall, by_type = metrics(counts) 177 | 178 | c = counts 179 | out.write('processed %d tokens with %d phrases; ' % 180 | (c.token_counter, c.found_correct)) 181 | out.write('found: %d phrases; correct: %d.\n' % 182 | (c.found_guessed, c.correct_chunk)) 183 | 184 | if c.token_counter > 0: 185 | out.write('Acc: %6.3f%%; ' % 186 | (100. * c.correct_tags / c.token_counter)) 187 | out.write('P: %6.3f%%; ' % (100. * overall.prec)) 188 | out.write('R: %6.3f%%; ' % (100. * overall.rec)) 189 | out.write('F: %6.3f\n' % (100. * overall.fscore)) 190 | 191 | for i, m in sorted(by_type.items()): 192 | out.write('%12s: ' % i) 193 | out.write('P: %6.3f%%; ' % (100. * m.prec)) 194 | out.write('R: %6.3f%%; ' % (100. * m.rec)) 195 | out.write('F: %6.3f %d\n' % (100. 
* m.fscore, c.t_found_guessed[i])) 196 | 197 | 198 | def end_of_chunk(prev_tag, tag, prev_type, type_): 199 | # check if a chunk ended between the previous and current word 200 | # arguments: previous and current chunk tags, previous and current types 201 | chunk_end = False 202 | 203 | if prev_tag == 'E': chunk_end = True 204 | if prev_tag == 'S': chunk_end = True 205 | 206 | if prev_tag == 'B' and tag == 'B': chunk_end = True 207 | if prev_tag == 'B' and tag == 'S': chunk_end = True 208 | if prev_tag == 'B' and tag == 'O': chunk_end = True 209 | if prev_tag == 'I' and tag == 'B': chunk_end = True 210 | if prev_tag == 'I' and tag == 'S': chunk_end = True 211 | if prev_tag == 'I' and tag == 'O': chunk_end = True 212 | 213 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_: 214 | chunk_end = True 215 | 216 | # these chunks are assumed to have length 1 217 | if prev_tag == ']': chunk_end = True 218 | if prev_tag == '[': chunk_end = True 219 | 220 | return chunk_end 221 | 222 | 223 | def start_of_chunk(prev_tag, tag, prev_type, type_): 224 | # check if a chunk started between the previous and current word 225 | # arguments: previous and current chunk tags, previous and current types 226 | chunk_start = False 227 | 228 | if tag == 'B': chunk_start = True 229 | if tag == 'S': chunk_start = True 230 | 231 | if prev_tag == 'E' and tag == 'E': chunk_start = True 232 | if prev_tag == 'E' and tag == 'I': chunk_start = True 233 | if prev_tag == 'S' and tag == 'E': chunk_start = True 234 | if prev_tag == 'S' and tag == 'I': chunk_start = True 235 | if prev_tag == 'O' and tag == 'E': chunk_start = True 236 | if prev_tag == 'O' and tag == 'I': chunk_start = True 237 | 238 | if tag != 'O' and tag != '.' 
and prev_type != type_: 239 | chunk_start = True 240 | 241 | # these chunks are assumed to have length 1 242 | if tag == '[': chunk_start = True 243 | if tag == ']': chunk_start = True 244 | 245 | return chunk_start 246 | -------------------------------------------------------------------------------- /train_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import traceback 5 | from collections import Counter 6 | 7 | import torch 8 | from torch import optim 9 | from torch.nn.utils import clip_grad_norm_ 10 | 11 | import constant as C 12 | from argparse import ArgumentParser 13 | 14 | from torch.utils.data import DataLoader 15 | 16 | from util import evaluate 17 | from data import ConllParser, SeqLabelDataset, SeqLabelProcessor, count2vocab 18 | from model import Linears, LSTM, CRF, CharCNN, Highway, LstmCrf, load_embedding 19 | 20 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 21 | 22 | logging.basicConfig(level=logging.DEBUG) 23 | logger = logging.getLogger() 24 | 25 | argparser = ArgumentParser() 26 | 27 | argparser.add_argument('--train', help='Path to the training set file') 28 | argparser.add_argument('--dev', help='Path to the dev set file') 29 | argparser.add_argument('--test', help='Path to the test set file') 30 | argparser.add_argument('--log', help='Path to the log dir') 31 | argparser.add_argument('--model', help='Path to the model file') 32 | argparser.add_argument('--batch_size', default=10, type=int, help='Batch size') 33 | argparser.add_argument('--max_epoch', default=100, type=int) 34 | argparser.add_argument('--word_embed', 35 | help='Path to the pre-trained embedding file') 36 | argparser.add_argument('--word_embed_dim', type=int, default=100, 37 | help='Word embedding dimension') 38 | argparser.set_defaults(word_ignore_case=False) 39 | argparser.add_argument('--char_embed_dim', type=int, default=50, 40 | help='Character embedding 
dimension') 41 | argparser.add_argument('--charcnn_filters', default='2,25;3,25;4,25', 42 | help='Character-level CNN filters') 43 | argparser.add_argument('--charhw_layer', default=1, type=int) 44 | argparser.add_argument('--charhw_func', default='relu') 45 | argparser.add_argument('--use_highway', action='store_true') 46 | argparser.add_argument('--lstm_hidden_size', default=100, type=int, 47 | help='LSTM hidden state size') 48 | argparser.add_argument('--lstm_forget_bias', default=0, type=float, 49 | help='LSTM forget bias') 50 | argparser.add_argument('--feat_dropout', default=.5, type=float, 51 | help='Word feature dropout probability') 52 | argparser.add_argument('--lstm_dropout', default=.5, type=float, 53 | help='LSTM output dropout probability') 54 | argparser.add_argument('--lr', default=0.005, type=float, 55 | help='Learning rate') 56 | argparser.add_argument('--momentum', default=.9, type=float) 57 | argparser.add_argument('--decay_rate', default=.9, type=float) 58 | argparser.add_argument('--decay_step', default=10000, type=int) 59 | argparser.add_argument('--grad_clipping', default=5, type=float) 60 | argparser.add_argument('--gpu', action='store_true') 61 | argparser.add_argument('--device', default=0, type=int) 62 | argparser.add_argument('--thread', default=5, type=int) 63 | 64 | args = argparser.parse_args() 65 | 66 | use_gpu = args.gpu and torch.cuda.is_available() 67 | if use_gpu: 68 | torch.cuda.set_device(args.device) 69 | 70 | # Model file 71 | model_dir = args.model 72 | assert model_dir and os.path.isdir(model_dir), 'Model output dir is required' 73 | model_file = os.path.join(model_dir, 'model.{}.mdl'.format(timestamp)) 74 | 75 | # Logging file 76 | log_writer = None 77 | if args.log: 78 | log_file = os.path.join(args.log, 'log.{}.txt'.format(timestamp)) 79 | log_writer = open(log_file, 'a', encoding='utf-8') 80 | logger.addHandler(logging.FileHandler(log_file, encoding='utf-8')) 81 | logger.info('----------') 82 | 
logger.info('Parameters:') 83 | for arg in vars(args): 84 | logger.info('{}: {}'.format(arg, getattr(args, arg))) 85 | logger.info('----------') 86 | 87 | # Data file 88 | logger.info('Loading data sets') 89 | parser = ConllParser(separator='\t', token_col=0, label_col=1, skip_comment=True) 90 | train_set = SeqLabelDataset(args.train, parser=parser) 91 | dev_set = SeqLabelDataset(args.dev, parser=parser) 92 | test_set = SeqLabelDataset(args.test, parser=parser) 93 | datasets = {'train': train_set, 'dev': dev_set, 'test': test_set} 94 | 95 | # Vocabs 96 | logger.info('Building vocabs') 97 | token_count, char_count, label_count = Counter(), Counter(), Counter() 98 | for _, ds in datasets.items(): 99 | tc, cc, lc = ds.stats() 100 | token_count.update(tc) 101 | char_count.update(cc) 102 | label_count.update(lc) 103 | token_vocab = count2vocab(token_count, offset=len(C.TOKEN_PADS), pads=C.TOKEN_PADS) 104 | char_vocab = count2vocab(char_count, offset=len(C.CHAR_PADS), pads=C.CHAR_PADS) 105 | label_vocab = count2vocab(label_count, offset=1, pads=[(C.PAD, C.PAD_INDEX)]) 106 | idx_token = {v: k for k, v in token_vocab.items()} 107 | idx_label = {v: k for k, v in label_vocab.items()} 108 | train_set.numberize(token_vocab, label_vocab, char_vocab) 109 | dev_set.numberize(token_vocab, label_vocab, char_vocab) 110 | test_set.numberize(token_vocab, label_vocab, char_vocab) 111 | print('#token: {}'.format(len(token_vocab))) 112 | print('#char: {}'.format(len(char_vocab))) 113 | print('#label: {}'.format(len(label_vocab))) 114 | 115 | # Embedding file 116 | word_embed = load_embedding(args.word_embed, 117 | dimension=args.word_embed_dim, 118 | vocab=token_vocab) 119 | charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])] 120 | for f in args.charcnn_filters.split(';')] 121 | char_embed = CharCNN(len(char_vocab), 122 | args.char_embed_dim, 123 | filters=charcnn_filters) 124 | char_hw = Highway(char_embed.output_size, 125 | layer_num=args.charhw_layer, 126 | 
activation=args.charhw_func) 127 | feat_dim = word_embed.embedding_dim + char_embed.output_size 128 | lstm = LSTM(feat_dim, 129 | args.lstm_hidden_size, 130 | batch_first=True, 131 | bidirectional=True, 132 | forget_bias=args.lstm_forget_bias 133 | ) 134 | crf = CRF(label_size=len(label_vocab) + 2) 135 | linear = Linears(in_features=lstm.output_size, 136 | out_features=len(label_vocab), 137 | hiddens=[lstm.output_size // 2]) 138 | lstm_crf = LstmCrf( 139 | token_vocab, label_vocab, char_vocab, 140 | word_embedding=word_embed, 141 | char_embedding=char_embed, 142 | crf=crf, 143 | lstm=lstm, 144 | univ_fc_layer=linear, 145 | embed_dropout_prob=args.feat_dropout, 146 | lstm_dropout_prob=args.lstm_dropout, 147 | char_highway=char_hw if args.use_highway else None 148 | ) 149 | if use_gpu: 150 | lstm_crf.cuda() 151 | torch.set_num_threads(args.thread) 152 | 153 | logger.debug(lstm_crf) 154 | 155 | # Task 156 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, lstm_crf.parameters()), 157 | lr=args.lr, momentum=args.momentum) 158 | processor = SeqLabelProcessor(gpu=use_gpu) 159 | 160 | train_args = vars(args) 161 | train_args['word_embed_size'] = word_embed.num_embeddings 162 | state = { 163 | 'model': { 164 | 'word_embed': word_embed.state_dict(), 165 | 'char_embed': char_embed.state_dict(), 166 | 'char_hw': char_hw.state_dict(), 167 | 'lstm': lstm.state_dict(), 168 | 'crf': crf.state_dict(), 169 | 'linear': linear.state_dict(), 170 | 'lstm_crf': lstm_crf.state_dict() 171 | }, 172 | 'args': train_args, 173 | 'vocab': { 174 | 'token': token_vocab, 175 | 'label': label_vocab, 176 | 'char': char_vocab, 177 | } 178 | } 179 | try: 180 | global_step = 0 181 | best_dev_score = best_test_score = 0.0 182 | 183 | for epoch in range(args.max_epoch): 184 | logger.info('Epoch {}: Training'.format(epoch)) 185 | 186 | best = False 187 | for ds in ['train', 'dev', 'test']: 188 | dataset = datasets[ds] 189 | epoch_loss = [] 190 | results = [] 191 | 192 | for batch in DataLoader( 193 
| dataset, 194 | batch_size=args.batch_size, 195 | shuffle=ds == 'train', 196 | drop_last=ds == 'train', 197 | collate_fn=processor.process 198 | ): 199 | optimizer.zero_grad() 200 | tokens, labels, chars, seq_lens, char_lens = batch 201 | if ds == 'train': 202 | global_step += 1 203 | loglik, _ = lstm_crf.loglik( 204 | tokens, labels, seq_lens, chars, char_lens) 205 | loss = -loglik.mean() 206 | loss.backward() 207 | clip_grad_norm_(lstm_crf.parameters(), args.grad_clipping) 208 | optimizer.step() 209 | else: 210 | pred, loss = lstm_crf.predict( 211 | tokens, labels, seq_lens, chars, char_lens) 212 | results.append((pred, labels, seq_lens, tokens)) 213 | 214 | epoch_loss.append(loss.item()) 215 | 216 | epoch_loss = sum(epoch_loss) / len(epoch_loss) 217 | logger.info('{} Loss: {:.4f}'.format(ds, epoch_loss)) 218 | 219 | if ds == 'dev' or ds == 'test': 220 | fscore, prec, rec = evaluate( 221 | results, idx_token, idx_label, writer=log_writer 222 | ) 223 | if ds == 'dev' and fscore > best_dev_score: 224 | logger.info('New best score: {:.4f}'.format(fscore)) 225 | best_dev_score = fscore 226 | best = True 227 | logger.info( 228 | 'Saving the current model to {}'.format(model_file)) 229 | torch.save(state, model_file) 230 | if best and ds == 'test': 231 | best_test_score = fscore 232 | 233 | # learning rate decay 234 | lr = args.lr * args.decay_rate ** (global_step / args.decay_step) 235 | for p in optimizer.param_groups: 236 | p['lr'] = lr 237 | logger.info('New learning rate: {}'.format(lr)) 238 | 239 | logger.info('Best score: {}'.format(best_dev_score)) 240 | logger.info('Best test score: {}'.format(best_test_score)) 241 | logger.info('Model file: {}'.format(model_file)) 242 | if args.log: 243 | logger.info('Log file: {}'.format(log_file)) 244 | log_writer.close() 245 | except Exception: 246 | traceback.print_exc() 247 | if log_writer: 248 | log_writer.close() -------------------------------------------------------------------------------- /old/train_single.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import traceback 4 | 5 | import tqdm 6 | 7 | import torch 8 | from torch import optim 9 | from torch.nn.utils import clip_grad_norm 10 | 11 | import constant as C 12 | from model import Linear, LSTM, CRF, CharCNN, Highway, LstmCrf, Embedding 13 | from argparse import ArgumentParser 14 | from util import get_logger, evaluate, Config 15 | from data import ( 16 | SequenceDataset, ConllParser, 17 | compute_metadata, count2vocab, numberize_datasets 18 | ) 19 | 20 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.gmtime()) 21 | 22 | argparser = ArgumentParser() 23 | 24 | argparser.add_argument('--train', help='Path to the training set file') 25 | argparser.add_argument('--dev', help='Path to the dev set file') 26 | argparser.add_argument('--test', help='Path to the test set file') 27 | argparser.add_argument('--log', help='Path to the log dir') 28 | argparser.add_argument('--model', help='Path to the model file') 29 | argparser.add_argument('--batch_size', default=10, type=int, help='Batch size') 30 | argparser.add_argument('--max_epoch', default=100, type=int) 31 | argparser.add_argument('--embedding', 32 | help='Path to the pre-trained embedding file') 33 | argparser.add_argument('--embed_skip_first', dest='embed_skip_first', 34 | action='store_true', 35 | help='Skip the first line of the embedding file') 36 | argparser.set_defaults(embed_skip_first=True) 37 | argparser.add_argument('--word_embed_dim', type=int, default=100, 38 | help='Word embedding dimension') 39 | argparser.add_argument('--word_ignore_case', dest='word_ignore_case', 40 | action='store_true') 41 | argparser.set_defaults(word_ignore_case=False) 42 | argparser.add_argument('--char_embed_dim', type=int, default=50, 43 | help='Character embedding dimension') 44 | argparser.add_argument('--charcnn_filters', default='2,25;3,25;4,25', 45 | help='Character-level CNN filters') 46 | 
argparser.add_argument('--lstm_hidden_size', default=100, type=int, 47 | help='LSTM hidden state size') 48 | argparser.add_argument('--embed_dropout', default=.2, type=float, 49 | help='Embedding dropout probability') 50 | argparser.add_argument('--lstm_dropout', default=.5, type=float, 51 | help='LSTM output dropout probability') 52 | argparser.add_argument('--linear_dropout', default=0, type=float, 53 | help='Output linear layer dropout probability') 54 | argparser.add_argument('--lr', default=0.005, type=float, 55 | help='Learning rate') 56 | argparser.add_argument('--momentum', default=.9, type=float) 57 | argparser.add_argument('--decay_rate', default=.9, type=float) 58 | argparser.add_argument('--decay_step', default=10000, type=int) 59 | argparser.add_argument('--grad_clipping', default=5, type=float) 60 | argparser.add_argument('--gpu', default=1, type=int) 61 | argparser.add_argument('--gpu_idx', default=0, type=int) 62 | 63 | args = argparser.parse_args() 64 | 65 | # Parameters 66 | model_file = args.model 67 | assert model_file, 'Model output path is required' 68 | model_file = os.path.join(args.model, 'model.{}.mdl'.format(timestamp)) 69 | 70 | embed_file = args.embedding 71 | charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])] 72 | for f in args.charcnn_filters.split(';')] 73 | use_gpu = (args.gpu == 1) 74 | word_ignore_case = args.word_ignore_case 75 | log_writer = None 76 | if args.log: 77 | log_file = os.path.join(args.log, 'log.{}.txt'.format(timestamp)) 78 | log_writer = open(log_file, 'a', encoding='utf-8') 79 | logger = get_logger(__name__, log_file=log_file) 80 | else: 81 | logger = get_logger(__name__) 82 | 83 | logger.info('----------') 84 | logger.info('Parameters:') 85 | for arg in vars(args): 86 | logger.info('{}: {}'.format(arg, getattr(args, arg))) 87 | logger.info('----------') 88 | 89 | # Parser for CoNLL format file 90 | conll_parser = ConllParser(Config({ 91 | 'separator': '\t', 92 | 'token_col': 0, 93 | 'label_col': 1, 
94 | 'skip_comment': True, 95 | })) 96 | 97 | # Load datasets 98 | logger.info('Loading datasets') 99 | train_set = SequenceDataset(Config({ 100 | 'path': args.train, 'parser': conll_parser, 'batch_size': args.batch_size})) 101 | dev_set = SequenceDataset(Config({ 102 | 'path': args.dev, 'parser': conll_parser})) 103 | test_set = SequenceDataset(Config({ 104 | 'path': args.test, 'parser': conll_parser})) 105 | datasets = {'train': train_set, 'dev': dev_set, 'test': test_set} 106 | 107 | # Vocabs 108 | logger.info('Building vocabularies') 109 | token_count, label_count, char_count = compute_metadata( 110 | [train_set, dev_set, test_set]) 111 | token_vocab = count2vocab([token_count], 112 | start_idx=C.EMBED_START_IDX, 113 | ignore_case=word_ignore_case) 114 | label_vocab = count2vocab([label_count], 115 | start_idx=0, 116 | sort=True, 117 | ignore_case=False) 118 | char_vocab = count2vocab([char_count], 119 | ignore_case=False, 120 | start_idx=C.CHAR_EMBED_START_IDX) 121 | if embed_file: 122 | logger.info('Scaning pre-trained embeddings') 123 | token_vocab = {} 124 | with open(embed_file, 'r', encoding='utf-8') as embed_r: 125 | if args.embed_skip_first: 126 | embed_r.readline() 127 | for line in embed_r: 128 | try: 129 | token = line[:line.find(' ')] 130 | if word_ignore_case: 131 | token = token.lower() 132 | if token not in token_vocab: 133 | token_vocab[token] = len(token_vocab) + C.EMBED_START_IDX 134 | if token.lower() not in token_vocab: 135 | token_vocab[token.lower()] = len(token_vocab) \ 136 | + C.EMBED_START_IDX 137 | except UnicodeDecodeError as e: 138 | logger.warning(e) 139 | idx_token = {idx: token for token, idx in token_vocab.items()} 140 | idx_label = {idx: label for label, idx in label_vocab.items()} 141 | idx_token[C.UNKNOWN_TOKEN_INDEX] = C.UNKNOWN_TOKEN 142 | 143 | # Numberize datasets 144 | logger.info('Numberizing datasets') 145 | numberize_datasets( 146 | [ 147 | (train_set, token_vocab, label_vocab, char_vocab), 148 | (dev_set, token_vocab, 
label_vocab, char_vocab), 149 | (test_set, token_vocab, label_vocab, char_vocab), 150 | ], 151 | token_ignore_case=word_ignore_case, 152 | label_ignore_case=False, 153 | char_ignore_case=False 154 | ) 155 | 156 | # Model components 157 | logger.info('Building the model') 158 | word_embed = Embedding(Config({ 159 | 'num_embeddings': len(token_vocab), 160 | 'embedding_dim': args.word_embed_dim, 161 | 'padding': C.EMBED_START_IDX, 162 | 'padding_idx': 0, 163 | 'sparse': True, 164 | 'trainable': True, 165 | 'file': embed_file, 166 | 'stats': args.embed_skip_first, 167 | 'vocab': token_vocab, 168 | 'ignore_case': word_ignore_case 169 | })) 170 | char_cnn = CharCNN(Config({ 171 | 'vocab_size': len(char_vocab), 172 | 'padding': C.CHAR_EMBED_START_IDX, 173 | 'dimension': args.char_embed_dim, 174 | 'filters': charcnn_filters 175 | })) 176 | char_highway = Highway(Config({ 177 | 'num_layers': 2, 178 | 'size': char_cnn.output_size, 179 | 'activation': 'selu' 180 | })) 181 | lstm = LSTM(Config({ 182 | 'input_size': word_embed.output_size + char_cnn.output_size, 183 | 'hidden_size': args.lstm_hidden_size, 184 | 'forget_bias': 1.0, 185 | 'batch_first': True, 186 | 'bidirectional': True 187 | })) 188 | crf = CRF(Config({ 189 | 'label_vocab': label_vocab 190 | })) 191 | output_linear = Linear(Config({ 192 | 'in_features': lstm.output_size, 193 | 'out_features': len(label_vocab) 194 | })) 195 | 196 | # LSTM CRF Model 197 | lstm_crf = LstmCrf( 198 | token_vocab=token_vocab, 199 | label_vocab=label_vocab, 200 | char_vocab=char_vocab, 201 | word_embedding=word_embed, 202 | char_embedding=char_cnn, 203 | crf=crf, 204 | lstm=lstm, 205 | univ_fc_layer=output_linear, 206 | embed_dropout_prob=args.embed_dropout, 207 | lstm_dropout_prob=args.lstm_dropout, 208 | linear_dropout_prob=args.linear_dropout, 209 | char_highway=char_highway 210 | ) 211 | 212 | if use_gpu: 213 | torch.cuda.set_device(args.gpu_idx) 214 | lstm_crf.cuda() 215 | 216 | # Task 217 | optimizer = optim.SGD(filter(lambda p: 
p.requires_grad, lstm_crf.parameters()), 218 | lr=args.lr, momentum=args.momentum) 219 | 220 | state = { 221 | 'model': { 222 | 'word_embed': word_embed.state_dict(), 223 | 'char_cnn': char_cnn.state_dict(), 224 | 'char_highway': char_highway.state_dict(), 225 | 'lstm': lstm.state_dict(), 226 | 'crf': crf.state_dict(), 227 | 'output_linear': output_linear.state_dict(), 228 | 'lstm_crf': lstm_crf.state_dict() 229 | }, 230 | 'args': vars(args), 231 | 'vocab': { 232 | 'token': token_vocab, 233 | 'label': label_vocab, 234 | 'char': char_vocab, 235 | } 236 | } 237 | 238 | try: 239 | global_step = 0 240 | best_dev_score = 0.0 241 | 242 | for epoch in range(args.max_epoch): 243 | logger.info('Epoch {}: Training'.format(epoch)) 244 | 245 | for ds in ['train', 'dev', 'test']: 246 | dataset = datasets[ds] 247 | epoch_loss = [] 248 | results = [] 249 | 250 | progress = tqdm.tqdm(total=dataset.batch_num(args.batch_size), 251 | mininterval=1, desc=ds) 252 | for batch in dataset.get_dataset( 253 | gpu=use_gpu, 254 | shuffle_inst=ds == 'train', 255 | batch_size=args.batch_size): 256 | optimizer.zero_grad() 257 | progress.update(1) 258 | tokens, labels, chars, seq_lens, char_lens = batch 259 | if ds == 'train': 260 | global_step += 1 261 | loglik, _ = lstm_crf.loglik( 262 | tokens, labels, seq_lens, chars, char_lens) 263 | loss = -loglik.mean() 264 | loss.backward() 265 | 266 | params = [p for n, p in lstm_crf.named_parameters() 267 | if 'embedding.weight' not in n] 268 | clip_grad_norm(params, args.grad_clipping) 269 | optimizer.step() 270 | 271 | else: 272 | pred, loss = lstm_crf.predict( 273 | tokens, labels, seq_lens, chars, char_lens) 274 | results.append((pred, labels, seq_lens, tokens)) 275 | 276 | epoch_loss.append(loss.data[0]) 277 | progress.close() 278 | 279 | epoch_loss = sum(epoch_loss) / len(epoch_loss) 280 | logger.info('{} Loss: {:.4f}'.format(ds, epoch_loss)) 281 | 282 | if ds == 'dev' or ds == 'test': 283 | fscore, prec, rec = evaluate( 284 | results, idx_token, 
idx_label, writer=log_writer) 285 | if ds == 'dev' and fscore > best_dev_score: 286 | logger.info('New best score: {:.4f}'.format(fscore)) 287 | best_dev_score = fscore 288 | logger.info( 289 | 'Saving the current model to {}'.format(model_file)) 290 | torch.save(state, model_file) 291 | 292 | # learning rate decay 293 | lr = args.lr * args.decay_rate ** (global_step / args.decay_step) 294 | for p in optimizer.param_groups: 295 | p['lr'] = lr 296 | logger.info('New learning rate: {}'.format(lr)) 297 | 298 | logger.info('Best score: {}'.format(best_dev_score)) 299 | logger.info('Model file: {}'.format(model_file)) 300 | if args.log: 301 | logger.info('Log file: {}'.format(log_file)) 302 | except Exception: 303 | traceback.print_exc() 304 | if log_writer: 305 | log_writer.close() 306 | -------------------------------------------------------------------------------- /old/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import constant as C 3 | 4 | from util import get_logger 5 | from collections import Counter, defaultdict 6 | from torch.autograd import Variable 7 | from random import shuffle, uniform, sample 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | PARSERS = {} 13 | DATASETS = {} 14 | 15 | 16 | def register_parser(name): 17 | def register(cls): 18 | if name not in PARSERS: 19 | PARSERS[name] = cls 20 | return cls 21 | return register 22 | 23 | 24 | def register_dataset(name): 25 | def register(cls): 26 | if name not in DATASETS: 27 | DATASETS[name] = cls 28 | return cls 29 | return register 30 | 31 | 32 | def create_parser(name, conf): 33 | if name in PARSERS: 34 | return PARSERS[name](conf) 35 | else: 36 | raise ValueError('Parser {} is not registered'.format(name)) 37 | 38 | 39 | def create_dataset(name, conf): 40 | if name in DATASETS: 41 | return DATASETS[name](conf) 42 | else: 43 | raise ValueError('Dataset {} is not registered'.format(name)) 44 | 45 | 46 | def compute_metadata(datasets): 
47 | """Compute tokens, labels, and characters in the given data sets. 48 | 49 | :param datasets: A list of data sets. 50 | :return: dicts of token, label, and character counts. 51 | """ 52 | token_count = defaultdict(int) 53 | label_count = defaultdict(int) 54 | char_count = defaultdict(int) 55 | 56 | for dataset in datasets: 57 | if dataset: 58 | t, l, c = dataset.metadata() 59 | for k, v in t.items(): 60 | token_count[k] += v 61 | for k, v in l.items(): 62 | label_count[k] += v 63 | for k, v in c.items(): 64 | char_count[k] += v 65 | 66 | return token_count, label_count, char_count 67 | 68 | 69 | def count2vocab(counts, 70 | start_idx=0, 71 | ignore_case=False, 72 | min_count=0, 73 | sort=False, 74 | sort_func=lambda x: (len(x[0]), x[0])): 75 | """ 76 | 77 | :param counts: 78 | :param start_idx: 79 | :param ignore_case: 80 | :param min_count: 81 | :param sort: Sort the keys. 82 | :param sort_func: Key sorting lambda function. 83 | :return: 84 | """ 85 | 86 | current_idx = start_idx 87 | merge_count = defaultdict(int) 88 | for count in counts: 89 | for k, v in count.items(): 90 | if ignore_case: 91 | k = k.lower() 92 | merge_count[k] += v 93 | 94 | vocab = {} 95 | if sort: 96 | merge_count_list = [(k, v) for k, v in merge_count.items()] 97 | merge_count_list.sort(key=sort_func) 98 | for k, v in merge_count_list: 99 | if v >= min_count: 100 | vocab[k] = current_idx 101 | current_idx += 1 102 | else: 103 | for k, v in merge_count.items(): 104 | if v >= min_count: 105 | vocab[k] = current_idx 106 | current_idx += 1 107 | return vocab 108 | 109 | 110 | def numberize_datasets(confs, 111 | token_ignore_case=True, 112 | label_ignore_case=False, 113 | char_ignore_case=False): 114 | for dataset, token_vocab, label_vocab, char_vocab in confs: 115 | dataset.numberize(token_vocab, 116 | label_vocab, 117 | char_vocab, 118 | token_ignore_case=token_ignore_case, 119 | label_ignore_case=label_ignore_case, 120 | char_ignore_case=char_ignore_case) 121 | 122 | 123 | 
@register_parser('conll') 124 | class ConllParser(object): 125 | """Parse CoNLL format file.""" 126 | 127 | # def __init__(self, 128 | # separator='\t', 129 | # token_col=0, 130 | # label_col=1, 131 | # skip_comment=True): 132 | def __init__(self, conf): 133 | """ 134 | :param conf: Config object with the following fields: 135 | - separator: Column separator (default='\t'). 136 | - token_col: Index of the token column. 137 | - label_col: Index of the label column. 138 | - skip_comment: Skip lines starting with '#'. 139 | """ 140 | # self.separator = separator 141 | # self.token_col = token_col 142 | # self.label_col = label_col 143 | # self.skip_comment = skip_comment 144 | self.separator = getattr(conf, 'separator', '\t') 145 | self.token_col = getattr(conf, 'token_col', 0) 146 | self.label_col = getattr(conf, 'label_col', 1) 147 | self.skip_comment = getattr(conf, 'skip_comment', True) 148 | 149 | def parse(self, path): 150 | """ 151 | :param path: Path to the file to be parsed. 152 | :return: Lists of tokens and labels. 
153 | """ 154 | with open(path, 'r', encoding='utf-8') as r: 155 | current_doc = [] 156 | for line in r: 157 | line = line.strip() 158 | if self.skip_comment and line.startswith('#'): 159 | continue 160 | if line: 161 | segs = line.split(self.separator) 162 | token, label = segs[self.token_col].strip(), segs[self.label_col] 163 | if label in {'B-O', 'I-O', 'E-O', 'S-O'}: 164 | label = 'O' 165 | current_doc.append((token, label)) 166 | elif current_doc: 167 | tokens = [] 168 | labels = [] 169 | for token, label in current_doc: 170 | tokens.append(token) 171 | labels.append(label) 172 | current_doc = [] 173 | yield tokens, labels 174 | if current_doc: 175 | tokens = [] 176 | labels = [] 177 | for token, label in current_doc: 178 | tokens.append(token) 179 | labels.append(label) 180 | yield tokens, labels 181 | 182 | 183 | @register_dataset('sequence') 184 | class SequenceDataset(object): 185 | 186 | # def __init__(self, 187 | # path, 188 | # parser, 189 | # batch_size=1, 190 | # sample=None, 191 | # max_len=10000): 192 | def __init__(self, conf): 193 | """ 194 | :param conf: Config object with the following fields: 195 | - path: Path to the data set. 196 | - parser: File parser. 197 | - batch_size: Batch size (default=1). 198 | - sample: Sample rate (default=None). It can be set to: 199 | * None or 'all': the data set won't be sampled. 200 | * An int number: sample examples from the data set. 201 | * A float number in (0, 1]: the data set will be sampled at the given rate. 202 | - max_len: Max example token number. 
203 | """ 204 | 205 | # self.path = path 206 | # self.parser = parser 207 | # self.batch_size = batch_size 208 | # self.sample = sample 209 | # self.max_len = max_len 210 | 211 | assert hasattr(conf, 'path'), 'dataset path is required' 212 | assert hasattr(conf, 'parser'), 'dataset parser is required' 213 | 214 | self.path = getattr(conf, 'path') 215 | self.parser = getattr(conf, 'parser') 216 | self.batch_size = getattr(conf, 'batch_size', 1) 217 | self.sample = getattr(conf, 'sample', None) 218 | self.max_len = getattr(conf, 'max_len', 10000) 219 | 220 | self.dataset = [] 221 | self.batches = [] 222 | self.dataset_numberized = [] 223 | self.doc_num = 0 224 | 225 | self.load() 226 | 227 | def load(self): 228 | if self.sample is None or self.sample == 'all': 229 | self.dataset = [x for x in self.parser.parse(self.path) 230 | if len(x[0]) < self.max_len] 231 | elif type(self.sample) is int: 232 | self.dataset = [x for x in self.parser.parse(self.path) 233 | if len(x[0]) < self.max_len] 234 | if len(self.dataset) > self.sample: 235 | self.dataset = sample(self.dataset, self.sample) 236 | elif type(self.sample) is float: 237 | assert 0 < self.sample <= 1 238 | self.dataset = [x for x in self.parser.parse(self.path) 239 | if uniform(0, 1) < self.sample 240 | and len(x[0]) < self.max_len] 241 | self.doc_num = len(self.dataset) 242 | 243 | def metadata(self): 244 | token_count = Counter() 245 | label_count = Counter() 246 | char_count = Counter() 247 | for tokens, labels in self.dataset: 248 | token_count.update(tokens) 249 | label_count.update(labels) 250 | for token in tokens: 251 | char_count.update([c for c in token]) 252 | return token_count, label_count, char_count 253 | 254 | def numberize(self, 255 | token_vocab, 256 | label_vocab, 257 | char_vocab, 258 | token_ignore_case=True, 259 | label_ignore_case=False, 260 | char_ignore_case=False): 261 | self.dataset_numberized = [] 262 | for tokens, labels in self.dataset: 263 | if char_ignore_case: 264 | chars = 
[t.lower() for t in tokens] 265 | else: 266 | chars = tokens 267 | char_idxs = [[char_vocab[c] if c in char_vocab 268 | else C.UNKNOWN_TOKEN_INDEX for c in t] for t in chars] 269 | if token_ignore_case: 270 | tokens = [t.lower() for t in tokens] 271 | token_idxs = [token_vocab[x] if x in token_vocab 272 | else token_vocab[x.lower()] if x.lower() in token_vocab 273 | else C.UNKNOWN_TOKEN_INDEX for x in tokens] 274 | if label_ignore_case: 275 | labels = [l.lower() for l in labels] 276 | label_idxs = [label_vocab[l] for l in labels] 277 | self.dataset_numberized.append((token_idxs, label_idxs, char_idxs)) 278 | 279 | def sample_batches(self, shuffle_inst=True): 280 | self.batches = [] 281 | inst_idxs = [i for i in range(len(self.dataset_numberized))] 282 | if shuffle_inst: 283 | shuffle(inst_idxs) 284 | self.batches = [inst_idxs[i:i + self.batch_size] for i in 285 | range(0, len(self.dataset_numberized), self.batch_size)] 286 | 287 | def get_batch(self, volatile=False, gpu=False, shuffle_inst=True): 288 | if len(self.batches) == 0: 289 | self.sample_batches(shuffle_inst) 290 | 291 | batch = self.batches.pop() 292 | batch = [self.dataset_numberized[idx] for idx in batch] 293 | batch.sort(key=lambda x: len(x[0]), reverse=True) 294 | 295 | seq_lens = [len(s[0]) for s in batch] 296 | max_seq_len = max(seq_lens) 297 | 298 | char_lens = [] 299 | for seq in batch: 300 | seq_char_lens = [len(x) for x in seq[2]] \ 301 | + [1] * (max_seq_len - len(seq[0])) 302 | char_lens.extend(seq_char_lens) 303 | max_char_len = max(max(char_lens), 4) 304 | 305 | tokens, labels, chars = [], [], [] 306 | for t, l, c in batch: 307 | tokens.append(t + [0] * (max_seq_len - len(t))) 308 | labels.append(l + [0] * (max_seq_len - len(l))) 309 | chars_padded = [x + [0] * (max_char_len - len(x)) 310 | for x in c] \ 311 | + [[0] * max_char_len] * (max_seq_len - len(t)) 312 | chars.extend(chars_padded) 313 | 314 | tokens = Variable(torch.LongTensor(tokens), volatile=volatile) 315 | labels = 
Variable(torch.LongTensor(labels), volatile=volatile) 316 | chars = Variable(torch.LongTensor(chars), volatile=volatile) 317 | seq_lens = Variable(torch.LongTensor(seq_lens), volatile=volatile) 318 | char_lens = Variable(torch.LongTensor(char_lens), volatile=volatile) 319 | 320 | if gpu: 321 | tokens = tokens.cuda() 322 | labels = labels.cuda() 323 | chars = chars.cuda() 324 | seq_lens = seq_lens.cuda() 325 | char_lens = char_lens.cuda() 326 | 327 | return tokens, labels, chars, seq_lens, char_lens 328 | 329 | def get_dataset(self, volatile=False, gpu=False, batch_size=100, 330 | shuffle_inst=False): 331 | batches_ = self.batches 332 | batch_size_ = self.batch_size 333 | 334 | self.batches = [] 335 | self.batch_size = batch_size 336 | self.sample_batches(shuffle_inst=shuffle_inst) 337 | 338 | while self.batches: 339 | yield self.get_batch(volatile, gpu) 340 | 341 | self.batches = batches_ 342 | self.batch_size = batch_size_ 343 | 344 | def batch_num(self, batch_size): 345 | return -(-len(self.dataset_numberized) // batch_size) -------------------------------------------------------------------------------- /train_crosslingual.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from tqdm import tqdm 4 | import time 5 | import logging 6 | import traceback 7 | from collections import Counter 8 | 9 | import torch 10 | from torch import optim 11 | from torch.nn.utils import clip_grad_norm_ 12 | 13 | import constant as C 14 | from random import shuffle 15 | from argparse import ArgumentParser 16 | 17 | from torch.utils.data import DataLoader 18 | 19 | from util import evaluate 20 | from data import ConllParser, SeqLabelDataset, SeqLabelProcessor, count2vocab 21 | from model import Linear, LSTM, CRF, CharCNN, Highway, LstmCrf, load_embedding 22 | 23 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 24 | 25 | logging.basicConfig(level=logging.DEBUG) 26 | logger = logging.getLogger() 27 | 28 | 
argparser = ArgumentParser()

# Target-language data sets.
argparser.add_argument('--train_tgt',
                       help='Path to the target-language training set file')
argparser.add_argument('--dev_tgt',
                       help='Path to the target-language dev set file')
argparser.add_argument('--test_tgt',
                       help='Path to the target-language test set file')
# Cross-lingual data sets: same task, different language.
argparser.add_argument('--train_cl',
                       help='Path to the cross-lingual training set file')
argparser.add_argument('--dev_cl',
                       help='Path to the cross-lingual dev set file')
argparser.add_argument('--test_cl',
                       help='Path to the cross-lingual test set file')

argparser.add_argument('--log', help='Path to the log dir')
# The training script requires an existing directory here, not a file path.
argparser.add_argument('--model',
                       help='Path to an existing model output directory')
argparser.add_argument('--batch_size', default=10, type=int, help='Batch size')
argparser.add_argument('--max_epoch', default=100, type=int,
                       help='Number of training epochs')
argparser.add_argument('--word_embed_1',
                       help='Path to the pre-trained embedding file for '
                            'lang 1 (target)')
argparser.add_argument('--word_embed_2',
                       help='Path to the pre-trained embedding file for '
                            'lang 2 (cross-lingual)')
argparser.add_argument('--word_embed_dim', type=int, default=100,
                       help='Word embedding dimension')
# No command-line flag sets this; it is only exposed as a parsed default.
argparser.set_defaults(word_ignore_case=False)
argparser.add_argument('--char_embed_dim', type=int, default=50,
                       help='Character embedding dimension')
argparser.add_argument('--charcnn_filters', default='2,25;3,25;4,25',
                       help='Character-level CNN filters')
argparser.add_argument('--charhw_layer', default=1, type=int,
                       help='Number of highway layers over character features')
argparser.add_argument('--charhw_func', default='relu',
                       help='Highway activation (torch.nn.functional name)')
argparser.add_argument('--use_highway', action='store_true',
                       help='Apply the highway layers to character features')
argparser.add_argument('--lstm_hidden_size', default=100, type=int,
                       help='LSTM hidden state size')
argparser.add_argument('--lstm_forget_bias', default=0, type=float,
                       help='LSTM forget bias')
argparser.add_argument('--feat_dropout', default=.5, type=float,
                       help='Word feature dropout probability')
argparser.add_argument('--lstm_dropout', default=.5, type=float, 64 | help='LSTM output dropout probability') 65 | argparser.add_argument('--lr', default=0.005, type=float, 66 | help='Learning rate') 67 | argparser.add_argument('--momentum', default=.9, type=float) 68 | argparser.add_argument('--decay_rate', default=.9, type=float) 69 | argparser.add_argument('--decay_step', default=10000, type=int) 70 | argparser.add_argument('--grad_clipping', default=5, type=float) 71 | argparser.add_argument('--gpu', action='store_true') 72 | argparser.add_argument('--device', default=0, type=int) 73 | argparser.add_argument('--thread', default=5, type=int) 74 | 75 | args = argparser.parse_args() 76 | batch_size = args.batch_size 77 | 78 | use_gpu = args.gpu and torch.cuda.is_available() 79 | if use_gpu: 80 | torch.cuda.set_device(args.device) 81 | torch.set_num_threads(args.thread) 82 | 83 | # Model file 84 | model_dir = args.model 85 | assert model_dir and os.path.isdir(model_dir), 'Model output dir is required' 86 | model_file = os.path.join(model_dir, 'model.{}.mdl'.format(timestamp)) 87 | 88 | # Logging file 89 | log_writer = None 90 | if args.log: 91 | log_file = os.path.join(args.log, 'log.{}.txt'.format(timestamp)) 92 | log_writer = open(log_file, 'a', encoding='utf-8') 93 | logger.addHandler(logging.FileHandler(log_file, encoding='utf-8')) 94 | logger.info('----------') 95 | logger.info('Parameters:') 96 | for arg in vars(args): 97 | logger.info('{}: {}'.format(arg, getattr(args, arg))) 98 | logger.info('----------') 99 | 100 | # Data file 101 | logger.info('Loading data sets') 102 | ner_parser = ConllParser(skip_comment=True, separator='\t') 103 | 104 | train_set_tgt = SeqLabelDataset(args.train_tgt, parser=ner_parser) 105 | dev_set_tgt = SeqLabelDataset(args.dev_tgt, parser=ner_parser) 106 | test_set_tgt = SeqLabelDataset(args.test_tgt, parser=ner_parser) 107 | 108 | train_set_cl = SeqLabelDataset(args.train_cl, parser=ner_parser) 109 | dev_set_cl = 
SeqLabelDataset(args.dev_cl, parser=ner_parser) 110 | test_set_cl = SeqLabelDataset(args.test_cl, parser=ner_parser) 111 | 112 | datasets = { 113 | 'tgt': {'train': train_set_tgt, 'dev': dev_set_tgt, 'test': test_set_tgt}, 114 | 'cl': {'train': train_set_cl, 'dev': dev_set_cl, 'test': test_set_cl}, 115 | } 116 | 117 | # Vocabs 118 | logger.info('Building vocabs') 119 | ( 120 | token_count_1, token_count_2, char_count, label_count_1 121 | ) = Counter(), Counter(), Counter(), Counter() 122 | for _, ds in datasets['tgt'].items(): 123 | tc, cc, lc = ds.stats() 124 | token_count_1.update(tc) 125 | char_count.update(cc) 126 | label_count_1.update(lc) 127 | for _, ds in datasets['cl'].items(): 128 | tc, cc, lc = ds.stats() 129 | token_count_2.update(tc) 130 | char_count.update(cc) 131 | label_count_1.update(lc) 132 | 133 | token_vocab_1 = count2vocab(token_count_1, offset=len(C.TOKEN_PADS), pads=C.TOKEN_PADS) 134 | token_vocab_2 = count2vocab(token_count_2, offset=len(C.TOKEN_PADS), pads=C.TOKEN_PADS) 135 | char_vocab = count2vocab(char_count, offset=len(C.CHAR_PADS), pads=C.CHAR_PADS) 136 | label_vocab_1 = count2vocab(label_count_1, offset=1, pads=[(C.PAD, C.PAD_INDEX)]) 137 | 138 | idx_token_1 = {idx: token for token, idx in token_vocab_1.items()} 139 | idx_token_2 = {idx: token for token, idx in token_vocab_2.items()} 140 | idx_label_1 = {idx: label for label, idx in label_vocab_1.items()} 141 | 142 | idx_tokens = { 143 | 'tgt': idx_token_1, 144 | 'cl': idx_token_2, 145 | } 146 | idx_labels = { 147 | 'tgt': idx_label_1, 148 | 'cl': idx_label_1, 149 | } 150 | 151 | print('#token (lang 1): {}'.format(len(token_vocab_1))) 152 | print('#token (lang 2): {}'.format(len(token_vocab_2))) 153 | print('#label: {}'.format(len(label_vocab_1))) 154 | print('#char: {}'.format(len(char_vocab))) 155 | 156 | train_set_tgt.numberize(token_vocab_1, label_vocab_1, char_vocab) 157 | dev_set_tgt.numberize(token_vocab_1, label_vocab_1, char_vocab) 158 | test_set_tgt.numberize(token_vocab_1, 
label_vocab_1, char_vocab) 159 | 160 | train_set_cl.numberize(token_vocab_2, label_vocab_1, char_vocab) 161 | dev_set_cl.numberize(token_vocab_2, label_vocab_1, char_vocab) 162 | test_set_cl.numberize(token_vocab_2, label_vocab_1, char_vocab) 163 | 164 | # Embedding file 165 | word_embed_1 = load_embedding(args.word_embed_1, 166 | dimension=args.word_embed_dim, 167 | vocab=token_vocab_1) 168 | word_embed_2 = load_embedding(args.word_embed_2, 169 | dimension=args.word_embed_dim, 170 | vocab=token_vocab_2) 171 | charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])] 172 | for f in args.charcnn_filters.split(';')] 173 | char_embed = CharCNN(len(char_vocab), 174 | args.char_embed_dim, 175 | filters=charcnn_filters) 176 | char_hw = Highway(char_embed.output_size, 177 | layer_num=args.charhw_layer, 178 | activation=args.charhw_func) 179 | feat_dim = args.word_embed_dim + char_embed.output_size 180 | lstm = LSTM(feat_dim, 181 | args.lstm_hidden_size, 182 | batch_first=True, 183 | bidirectional=True, 184 | forget_bias=args.lstm_forget_bias 185 | ) 186 | crf_1 = CRF(label_size=len(label_vocab_1) + 2) 187 | 188 | # Linear layers for task 1 189 | shared_linear_1 = Linear(in_features=lstm.output_size, 190 | out_features=len(label_vocab_1)) 191 | spec_linear_1_1 = Linear(in_features=lstm.output_size, 192 | out_features=len(label_vocab_1)) 193 | spec_linear_1_2 = Linear(in_features=lstm.output_size, 194 | out_features=len(label_vocab_1)) 195 | 196 | lstm_crf_tgt = LstmCrf( 197 | token_vocab_1, label_vocab_1, char_vocab, 198 | word_embedding=word_embed_1, 199 | char_embedding=char_embed, 200 | crf=crf_1, 201 | lstm=lstm, 202 | univ_fc_layer=shared_linear_1, 203 | spec_fc_layer=spec_linear_1_1, 204 | embed_dropout_prob=args.feat_dropout, 205 | lstm_dropout_prob=args.lstm_dropout, 206 | char_highway=char_hw if args.use_highway else None 207 | ) 208 | lstm_crf_cl = LstmCrf( 209 | token_vocab_2, label_vocab_1, char_vocab, 210 | word_embedding=word_embed_2, 211 | 
char_embedding=char_embed, 212 | crf=crf_1, 213 | lstm=lstm, 214 | univ_fc_layer=shared_linear_1, 215 | spec_fc_layer=spec_linear_1_2, 216 | embed_dropout_prob=args.feat_dropout, 217 | lstm_dropout_prob=args.lstm_dropout, 218 | char_highway=char_hw if args.use_highway else None 219 | ) 220 | 221 | if use_gpu: 222 | lstm_crf_tgt.cuda() 223 | lstm_crf_cl.cuda() 224 | models = { 225 | 'tgt': lstm_crf_tgt, 226 | 'cl': lstm_crf_cl 227 | } 228 | 229 | # Task 230 | optimizer_tgt = optim.SGD( 231 | filter(lambda p: p.requires_grad, lstm_crf_tgt.parameters()), 232 | lr=args.lr, momentum=args.momentum) 233 | optimizer_cl = optim.SGD( 234 | filter(lambda p: p.requires_grad, lstm_crf_cl.parameters()), 235 | lr=args.lr, momentum=args.momentum) 236 | optimizers = { 237 | 'tgt': optimizer_tgt, 238 | 'cl': optimizer_cl, 239 | } 240 | processor = SeqLabelProcessor(gpu=use_gpu) 241 | 242 | train_args = vars(args) 243 | train_args['word_embed_size'] = word_embed_1.num_embeddings 244 | state = { 245 | 'model': { 246 | 'word_embed': word_embed_1.state_dict(), 247 | 'char_embed': char_embed.state_dict(), 248 | 'char_hw': char_hw.state_dict(), 249 | 'lstm': lstm.state_dict(), 250 | 'crf': crf_1.state_dict(), 251 | 'univ_linear': shared_linear_1.state_dict(), 252 | 'spec_linear': spec_linear_1_1.state_dict(), 253 | 'lstm_crf': lstm_crf_tgt.state_dict() 254 | }, 255 | 'args': train_args, 256 | 'vocab': { 257 | 'token': token_vocab_1, 258 | 'label': label_vocab_1, 259 | 'char': char_vocab, 260 | } 261 | } 262 | 263 | # Calculate mixing rates 264 | batch_num = len(train_set_tgt) // batch_size 265 | r_tgt = math.sqrt(len(train_set_tgt)) 266 | r_cl = 1.0 * .1 * math.sqrt(len(datasets['cl']['train'])) 267 | num_cl = int(r_cl / r_tgt * batch_num) 268 | print('{}, {}'.format(batch_num, num_cl)) 269 | 270 | data_loaders = {} 271 | data_loader_iters = {} 272 | for task, task_datasets in datasets.items(): 273 | data_loaders[task] = { 274 | k: DataLoader(v, 275 | batch_size=batch_size, 276 | 
shuffle=k == 'train', 277 | drop_last=k == 'train', 278 | collate_fn=processor.process) 279 | for k, v in task_datasets.items() 280 | } 281 | data_loader_iters[task] = {k: iter(v) for k, v 282 | in data_loaders[task].items()} 283 | 284 | try: 285 | global_step = 0 286 | best_dev_score = best_test_score = 0.0 287 | 288 | for epoch in range(args.max_epoch): 289 | logger.info('Epoch {}: Training'.format(epoch + 1)) 290 | best = False 291 | 292 | for ds in ['train', 'dev', 'test']: 293 | if ds == 'train': 294 | tasks = ['tgt'] * batch_num + ['cl'] * num_cl 295 | shuffle(tasks) 296 | progress = tqdm(total=len(tasks), mininterval=1, 297 | desc=ds) 298 | for task in tasks: 299 | progress.update(1) 300 | global_step += 1 301 | model = models[task] 302 | optimizer = optimizers[task] 303 | optimizer.zero_grad() 304 | try: 305 | batch = next(data_loader_iters[task]['train']) 306 | except StopIteration: 307 | data_loader_iters[task]['train'] = iter( 308 | data_loaders[task]['train'] 309 | ) 310 | batch = next(data_loader_iters[task]['train']) 311 | tokens, labels, chars, seq_lens, char_lens = batch 312 | loglik, _ = model.loglik( 313 | tokens, labels, seq_lens, chars, char_lens) 314 | loss = -loglik.mean() 315 | loss.backward() 316 | 317 | clip_grad_norm_(model.parameters(), args.grad_clipping) 318 | optimizer.step() 319 | progress.close() 320 | else: 321 | for task in ['tgt', 'cl']: 322 | logger.info('task: {} dataset: {}'.format(task, ds)) 323 | results = [] 324 | for batch in data_loaders[task][ds]: 325 | tokens, labels, chars, seq_lens, char_lens = batch 326 | pred, _loss = models[task].predict( 327 | tokens, labels, seq_lens, chars, char_lens 328 | ) 329 | results.append((pred, labels, seq_lens, tokens)) 330 | fscore, prec, rec = evaluate( 331 | results, idx_tokens[task], idx_labels[task], 332 | writer=log_writer) 333 | if ds == 'dev' and task == 'tgt' and fscore > best_dev_score: 334 | logger.info('New best score: {:.4f}'.format(fscore)) 335 | best_dev_score = fscore 336 
| best = True 337 | logger.info('Saving the model to {}'.format(model_file)) 338 | torch.save(state, model_file) 339 | if best and ds == 'test' and task == 'tgt': 340 | best_test_score = fscore 341 | 342 | # learning rate decay 343 | lr = args.lr * args.decay_rate ** (global_step / args.decay_step) 344 | for opt in optimizers.values(): 345 | for p in opt.param_groups: 346 | p['lr'] = lr 347 | logger.info('New learning rate: {}'.format(lr)) 348 | 349 | logger.info('Best dev score: {}'.format(best_dev_score)) 350 | logger.info('Best test score: {}'.format(best_test_score)) 351 | logger.info('Model file: {}'.format(model_file)) 352 | if args.log: 353 | logger.info('Log file: {}'.format(log_file)) 354 | log_writer.close() 355 | except Exception: 356 | traceback.print_exc() 357 | if log_writer: 358 | log_writer.close() 359 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as I 6 | import torch.nn.utils.rnn as R 7 | import torch.nn.functional as F 8 | 9 | import re 10 | import logging 11 | import constant as C 12 | 13 | logger = logging.getLogger() 14 | 15 | 16 | def log_sum_exp(tensor, dim=0): 17 | """LogSumExp operation.""" 18 | m, _ = torch.max(tensor, dim) 19 | m_exp = m.unsqueeze(-1).expand_as(tensor) 20 | return m + torch.log(torch.sum(torch.exp(tensor - m_exp), dim)) 21 | 22 | 23 | def sequence_mask(lens, max_len=None): 24 | batch_size = lens.size(0) 25 | 26 | if max_len is None: 27 | max_len = lens.max().item() 28 | 29 | ranges = torch.arange(0, max_len).long() 30 | ranges = ranges.unsqueeze(0).expand(batch_size, max_len) 31 | 32 | if lens.data.is_cuda: 33 | ranges = ranges.cuda() 34 | 35 | lens_exp = lens.unsqueeze(1).expand_as(ranges) 36 | mask = ranges < lens_exp 37 | 38 | return mask 39 | 40 | 41 | def load_embedding(path: 
str, 42 | dimension: int, 43 | vocab: dict = None, 44 | skip_first_line: bool = True, 45 | ): 46 | logger.info('Scanning embedding file: {}'.format(path)) 47 | 48 | embed_vocab = set() 49 | lower_mapping = {} # lower case - original 50 | digit_mapping = {} # lower case + replace digit with 0 - original 51 | digit_pattern = re.compile('\d') 52 | with open(path, 'r', encoding='utf-8') as r: 53 | if skip_first_line: 54 | r.readline() 55 | for line in r: 56 | try: 57 | token = line.split(' ')[0].strip() 58 | if token: 59 | embed_vocab.add(token) 60 | token_lower = token.lower() 61 | token_digit = re.sub(digit_pattern, '0', token_lower) 62 | if token_lower not in lower_mapping: 63 | lower_mapping[token_lower] = token 64 | if token_digit not in digit_mapping: 65 | digit_mapping[token_digit] = token 66 | except UnicodeDecodeError: 67 | continue 68 | 69 | token_mapping = defaultdict(list) # embed token - vocab token 70 | for token in vocab: 71 | token_lower = token.lower() 72 | token_digit = re.sub(digit_pattern, '0', token_lower) 73 | if token in embed_vocab: 74 | token_mapping[token].append(token) 75 | elif token_lower in lower_mapping: 76 | token_mapping[lower_mapping[token_lower]].append(token) 77 | elif token_digit in digit_mapping: 78 | token_mapping[digit_mapping[token_digit]].append(token) 79 | 80 | logger.info('Loading embeddings') 81 | weight = [[.0] * dimension for _ in range(len(vocab))] 82 | with open(path, 'r', encoding='utf-8') as r: 83 | if skip_first_line: 84 | r.readline() 85 | for line in r: 86 | try: 87 | segs = line.rstrip().split(' ') 88 | token = segs[0] 89 | if token in token_mapping: 90 | vec = [float(v) for v in segs[1:]] 91 | for t in token_mapping.get(token): 92 | weight[vocab[t]] = vec.copy() 93 | except UnicodeDecodeError: 94 | continue 95 | except ValueError: 96 | continue 97 | embed = nn.Embedding( 98 | len(vocab), 99 | dimension, 100 | padding_idx=C.PAD_INDEX, 101 | sparse=True, 102 | _weight=torch.FloatTensor(weight) 103 | ) 104 | return 
embed 105 | 106 | 107 | class Linear(nn.Linear): 108 | def __init__(self, 109 | in_features: int, 110 | out_features: int, 111 | bias: bool = True): 112 | super(Linear, self).__init__(in_features, out_features, bias=bias) 113 | I.orthogonal_(self.weight) 114 | 115 | 116 | class Linears(nn.Module): 117 | def __init__(self, 118 | in_features: int, 119 | out_features: int, 120 | hiddens: list, 121 | bias: bool = True, 122 | activation: str = 'tanh'): 123 | super(Linears, self).__init__() 124 | assert len(hiddens) > 0 125 | 126 | self.in_features = in_features 127 | self.out_features = self.output_size = out_features 128 | 129 | in_dims = [in_features] + hiddens[:-1] 130 | self.linears = nn.ModuleList([Linear(in_dim, out_dim, bias=bias) 131 | for in_dim, out_dim 132 | in zip(in_dims, hiddens)]) 133 | self.output_linear = Linear(hiddens[-1], out_features, bias=bias) 134 | self.activation = getattr(F, activation) 135 | 136 | def forward(self, inputs): 137 | linear_outputs = inputs 138 | for linear in self.linears: 139 | linear_outputs = linear(linear_outputs) 140 | linear_outputs = self.activation(linear_outputs) 141 | return self.output_linear(linear_outputs) 142 | 143 | 144 | class Highway(nn.Module): 145 | def __init__(self, 146 | size: int, 147 | layer_num: int = 1, 148 | activation: str = 'relu'): 149 | super(Highway, self).__init__() 150 | self.size = self.output_size = size 151 | self.layer_num = layer_num 152 | self.activation = getattr(F, activation) 153 | self.non_linear = nn.ModuleList([Linear(size, size) 154 | for _ in range(layer_num)]) 155 | self.gate = nn.ModuleList([Linear(size, size) 156 | for _ in range(layer_num)]) 157 | 158 | def forward(self, inputs): 159 | for layer in range(self.layer_num): 160 | gate = F.sigmoid(self.gate[layer](inputs)) 161 | non_linear = self.activation(self.non_linear[layer](inputs)) 162 | inputs = gate * non_linear + (1 - gate) * inputs 163 | return inputs 164 | 165 | 166 | class LSTM(nn.LSTM): 167 | def __init__(self, 168 | 
input_size: int, 169 | hidden_size: int, 170 | num_layers: int = 1, 171 | bias: bool = True, 172 | batch_first: bool = False, 173 | dropout: float = 0, 174 | bidirectional: bool = False, 175 | forget_bias: float = 0 176 | ): 177 | super(LSTM, self).__init__(input_size=input_size, 178 | hidden_size=hidden_size, 179 | num_layers=num_layers, 180 | bias=bias, 181 | batch_first=batch_first, 182 | dropout=dropout, 183 | bidirectional=bidirectional) 184 | self.output_size = hidden_size * 2 if bidirectional else hidden_size 185 | self.forget_bias = forget_bias 186 | 187 | def initialize(self): 188 | for n, p in self.named_parameters(): 189 | if 'weight' in n: 190 | I.orthogonal_(p) 191 | elif 'bias' in n: 192 | bias_size = p.size(0) 193 | p[bias_size // 4:bias_size // 2].fill_(self.forget_bias) 194 | 195 | 196 | class CharCNN(nn.Module): 197 | 198 | def __init__(self, embedding_num, embedding_dim, filters): 199 | super(CharCNN, self).__init__() 200 | self.output_size = sum([x[1] for x in filters]) 201 | self.embedding = nn.Embedding(embedding_num, 202 | embedding_dim, 203 | padding_idx=0, 204 | sparse=True) 205 | self.convs = nn.ModuleList([nn.Conv2d(1, x[1], (x[0], embedding_dim)) 206 | for x in filters]) 207 | 208 | def forward(self, inputs): 209 | inputs_embed = self.embedding(inputs) 210 | inputs_embed = inputs_embed.unsqueeze(1) 211 | conv_outputs = [F.relu(conv(inputs_embed)).squeeze(3) 212 | for conv in self.convs] 213 | max_pool_outputs = [F.max_pool1d(i, i.size(2)).squeeze(2) 214 | for i in conv_outputs] 215 | outputs = torch.cat(max_pool_outputs, 1) 216 | return outputs 217 | 218 | 219 | class CRF(nn.Module): 220 | def __init__(self, label_size): 221 | super(CRF, self).__init__() 222 | 223 | self.label_size = label_size 224 | self.start = self.label_size - 2 225 | self.end = self.label_size - 1 226 | transition = torch.randn(self.label_size, self.label_size) 227 | self.transition = nn.Parameter(transition) 228 | self.initialize() 229 | 230 | def initialize(self): 
231 | self.transition.data[:, self.end] = -100.0 232 | self.transition.data[self.start, :] = -100.0 233 | 234 | def pad_logits(self, logits): 235 | # lens = lens.data 236 | batch_size, seq_len, label_num = logits.size() 237 | # pads = Variable(logits.data.new(batch_size, seq_len, 2).fill_(-1000.0), 238 | # requires_grad=False) 239 | pads = logits.new_full((batch_size, seq_len, 2), -1000.0, 240 | requires_grad=False) 241 | logits = torch.cat([logits, pads], dim=2) 242 | return logits 243 | 244 | def calc_binary_score(self, labels, lens): 245 | batch_size, seq_len = labels.size() 246 | 247 | # labels_ext = Variable(labels.data.new(batch_size, seq_len + 2)) 248 | labels_ext = labels.new_empty((batch_size, seq_len + 2)) 249 | labels_ext[:, 0] = self.start 250 | labels_ext[:, 1:-1] = labels 251 | mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long() 252 | # pad_stop = Variable(labels.data.new(1).fill_(self.end)) 253 | pad_stop = labels.new_full((1,), self.end, requires_grad=False) 254 | pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2) 255 | labels_ext = (1 - mask) * pad_stop + mask * labels_ext 256 | labels = labels_ext 257 | 258 | trn = self.transition 259 | trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size()) 260 | lbl_r = labels[:, 1:] 261 | lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0)) 262 | trn_row = torch.gather(trn_exp, 1, lbl_rexp) 263 | 264 | lbl_lexp = labels[:, :-1].unsqueeze(-1) 265 | trn_scr = torch.gather(trn_row, 2, lbl_lexp) 266 | trn_scr = trn_scr.squeeze(-1) 267 | 268 | mask = sequence_mask(lens + 1).float() 269 | trn_scr = trn_scr * mask 270 | score = trn_scr 271 | 272 | return score 273 | 274 | def calc_unary_score(self, logits, labels, lens): 275 | labels_exp = labels.unsqueeze(-1) 276 | scores = torch.gather(logits, 2, labels_exp).squeeze(-1) 277 | mask = sequence_mask(lens).float() 278 | scores = scores * mask 279 | return scores 280 | 281 | def calc_gold_score(self, logits, labels, lens): 282 | 
unary_score = self.calc_unary_score(logits, labels, lens).sum( 283 | 1).squeeze(-1) 284 | binary_score = self.calc_binary_score(labels, lens).sum(1).squeeze(-1) 285 | return unary_score + binary_score 286 | 287 | def calc_norm_score(self, logits, lens): 288 | batch_size, seq_len, feat_dim = logits.size() 289 | # alpha = logits.data.new(batch_size, self.label_size).fill_(-10000.0) 290 | alpha = logits.new_full((batch_size, self.label_size), -100.0) 291 | alpha[:, self.start] = 0 292 | # alpha = Variable(alpha) 293 | lens_ = lens.clone() 294 | 295 | logits_t = logits.transpose(1, 0) 296 | for logit in logits_t: 297 | logit_exp = logit.unsqueeze(-1).expand(batch_size, 298 | *self.transition.size()) 299 | alpha_exp = alpha.unsqueeze(1).expand(batch_size, 300 | *self.transition.size()) 301 | trans_exp = self.transition.unsqueeze(0).expand_as(alpha_exp) 302 | mat = logit_exp + alpha_exp + trans_exp 303 | alpha_nxt = log_sum_exp(mat, 2).squeeze(-1) 304 | 305 | mask = (lens_ > 0).float().unsqueeze(-1).expand_as(alpha) 306 | alpha = mask * alpha_nxt + (1 - mask) * alpha 307 | lens_ = lens_ - 1 308 | 309 | alpha = alpha + self.transition[self.end].unsqueeze(0).expand_as(alpha) 310 | norm = log_sum_exp(alpha, 1).squeeze(-1) 311 | 312 | return norm 313 | 314 | def viterbi_decode(self, logits, lens): 315 | """Borrowed from pytorch tutorial 316 | Arguments: 317 | logits: [batch_size, seq_len, n_labels] FloatTensor 318 | lens: [batch_size] LongTensor 319 | """ 320 | batch_size, seq_len, n_labels = logits.size() 321 | # vit = logits.data.new(batch_size, self.label_size).fill_(-10000) 322 | vit = logits.new_full((batch_size, self.label_size), -100.0) 323 | vit[:, self.start] = 0 324 | # vit = Variable(vit) 325 | c_lens = lens.clone() 326 | 327 | logits_t = logits.transpose(1, 0) 328 | pointers = [] 329 | for logit in logits_t: 330 | vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels) 331 | trn_exp = self.transition.unsqueeze(0).expand_as(vit_exp) 332 | vit_trn_sum = 
vit_exp + trn_exp 333 | vt_max, vt_argmax = vit_trn_sum.max(2) 334 | 335 | vt_max = vt_max.squeeze(-1) 336 | vit_nxt = vt_max + logit 337 | pointers.append(vt_argmax.squeeze(-1).unsqueeze(0)) 338 | 339 | mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt) 340 | vit = mask * vit_nxt + (1 - mask) * vit 341 | 342 | mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt) 343 | vit += mask * self.transition[self.end].unsqueeze( 344 | 0).expand_as(vit_nxt) 345 | 346 | c_lens = c_lens - 1 347 | 348 | pointers = torch.cat(pointers) 349 | scores, idx = vit.max(1) 350 | # idx = idx.squeeze(-1) 351 | paths = [idx.unsqueeze(1)] 352 | for argmax in reversed(pointers): 353 | idx_exp = idx.unsqueeze(-1) 354 | idx = torch.gather(argmax, 1, idx_exp) 355 | idx = idx.squeeze(-1) 356 | 357 | paths.insert(0, idx.unsqueeze(1)) 358 | 359 | paths = torch.cat(paths[1:], 1) 360 | scores = scores.squeeze(-1) 361 | 362 | return scores, paths 363 | 364 | class Model(nn.Module): 365 | def __init__(self): 366 | super(Model, self).__init__() 367 | self.gpu = False 368 | 369 | def cuda(self, device=None): 370 | self.gpu = True 371 | for module in self.children(): 372 | module.cuda(device) 373 | return self 374 | 375 | def cpu(self): 376 | self.gpu = False 377 | for module in self.children(): 378 | module.cpu() 379 | return self 380 | 381 | 382 | class LstmCrf(Model): 383 | def __init__(self, 384 | token_vocab, 385 | label_vocab, 386 | char_vocab, 387 | word_embedding, 388 | char_embedding, 389 | crf, 390 | lstm, 391 | input_layer=None, 392 | univ_fc_layer=None, 393 | spec_fc_layer=None, 394 | output_layer=None, 395 | embed_dropout_prob=0, 396 | lstm_dropout_prob=0, 397 | use_char_embedding=True, 398 | char_highway=None 399 | ): 400 | super(LstmCrf, self).__init__() 401 | 402 | self.token_vocab = token_vocab 403 | self.label_vocab = label_vocab 404 | self.char_vocab = char_vocab 405 | self.idx_label = {idx: label for label, idx in label_vocab.items()} 406 | self.embed_dropout_prob = 
embed_dropout_prob 407 | self.lstm_dropout_prob = lstm_dropout_prob 408 | self.use_char_embedding = use_char_embedding 409 | 410 | self.word_embedding = word_embedding 411 | self.char_embedding = char_embedding 412 | 413 | self.feat_dim = word_embedding.embedding_dim 414 | if use_char_embedding: 415 | self.feat_dim += char_embedding.output_size 416 | 417 | self.lstm = lstm 418 | self.input_layer = input_layer 419 | self.univ_fc_layer = univ_fc_layer 420 | self.spec_fc_layer = spec_fc_layer 421 | self.output_layer = output_layer 422 | self.crf = crf 423 | self.char_highway = char_highway 424 | self.lstm_dropout = nn.Dropout(p=lstm_dropout_prob) 425 | self.embed_dropout = nn.Dropout(p=embed_dropout_prob) 426 | self.label_size = len(label_vocab) 427 | if spec_fc_layer: 428 | self.spec_gate = Linear(spec_fc_layer.in_features, 429 | spec_fc_layer.out_features) 430 | 431 | def forward_model(self, inputs, lens, chars=None, char_lens=None): 432 | batch_size, seq_len = inputs.size() 433 | 434 | # Word embedding 435 | inputs_embed = self.word_embedding(inputs) 436 | 437 | # Character embedding 438 | if self.use_char_embedding: 439 | chars_embed = self.char_embedding(chars) 440 | if self.char_highway: 441 | chars_embed = self.char_highway(chars_embed) 442 | chars_embed = chars_embed.view(batch_size, seq_len, -1) 443 | inputs_embed = torch.cat([inputs_embed, chars_embed], dim=2) 444 | 445 | inputs_embed = self.embed_dropout(inputs_embed) 446 | 447 | # LSTM layer 448 | inputs_packed = R.pack_padded_sequence(inputs_embed, lens.tolist(), 449 | batch_first=True) 450 | lstm_out, _ = self.lstm(inputs_packed) 451 | lstm_out, _ = R.pad_packed_sequence(lstm_out, batch_first=True) 452 | 453 | lstm_out = lstm_out.contiguous().view(-1, self.lstm.output_size) 454 | lstm_out = self.lstm_dropout(lstm_out) 455 | 456 | # Fully-connected layer 457 | univ_feats = self.univ_fc_layer(lstm_out) 458 | if self.spec_fc_layer is not None: 459 | spec_feats = self.spec_fc_layer(lstm_out) 460 | gate = 
F.sigmoid(self.spec_gate(lstm_out)) 461 | outputs = gate * spec_feats + (1 - gate) * univ_feats 462 | else: 463 | outputs = univ_feats 464 | outputs = outputs.view(batch_size, seq_len, self.label_size) 465 | 466 | return outputs 467 | 468 | def predict(self, inputs, labels, lens, chars=None, char_lens=None): 469 | self.eval() 470 | 471 | loglik, logits = self.loglik(inputs, labels, lens, chars, char_lens) 472 | loss = -loglik.mean() 473 | scores, preds = self.crf.viterbi_decode(logits, lens) 474 | 475 | self.train() 476 | return preds, loss 477 | 478 | def loglik(self, inputs, labels, lens, chars=None, char_lens=None): 479 | logits = self.forward_model(inputs, lens, chars, char_lens) 480 | logits = self.crf.pad_logits(logits) 481 | norm_score = self.crf.calc_norm_score(logits, lens) 482 | gold_score = self.crf.calc_gold_score(logits, labels, lens) 483 | loglik = gold_score - norm_score 484 | 485 | return loglik, logits 486 | -------------------------------------------------------------------------------- /train_multi.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script demonstrates how to train a multi-lingual multi-task model with four 3 | tasks: 4 | - Target task 5 | - Auxiliary task 1 (different language) 6 | - Auxiliary task 2 (different task) 7 | - Auxiliary task 3 (different language and task) 8 | For example, if the target task is Spanish Name Tagging, related task Part-of- 9 | speech Tagging, and related language English, task 1 is English Name Tagging, 10 | task 2 is Spanish Part-of-speech Tagging, and task 3 is English Part-of-speech 11 | Tagging. 
12 | """ 13 | 14 | import math 15 | import os 16 | from tqdm import tqdm 17 | import time 18 | import logging 19 | import traceback 20 | from collections import Counter 21 | 22 | import torch 23 | from torch import optim 24 | from torch.nn.utils import clip_grad_norm_ 25 | 26 | import constant as C 27 | from random import shuffle 28 | from argparse import ArgumentParser 29 | 30 | from torch.utils.data import DataLoader 31 | 32 | from util import evaluate 33 | from data import ConllParser, SeqLabelDataset, SeqLabelProcessor, count2vocab 34 | from model import Linear, LSTM, CRF, CharCNN, Highway, LstmCrf, load_embedding 35 | 36 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 37 | 38 | logging.basicConfig(level=logging.DEBUG) 39 | logger = logging.getLogger() 40 | 41 | argparser = ArgumentParser() 42 | 43 | # Target 44 | argparser.add_argument('--train_tgt', help='Path to the training set file') 45 | argparser.add_argument('--dev_tgt', help='Path to the dev set file') 46 | argparser.add_argument('--test_tgt', help='Path to the test set file') 47 | # Cross-lingual: same task, diff languages 48 | argparser.add_argument('--train_cl', help='Path to the training set file') 49 | argparser.add_argument('--dev_cl', help='Path to the dev set file') 50 | argparser.add_argument('--test_cl', help='Path to the test set file') 51 | # Cross-task: same languages, diff tasks 52 | argparser.add_argument('--train_ct', help='Path to the training set file') 53 | argparser.add_argument('--dev_ct', help='Path to the dev set file') 54 | argparser.add_argument('--test_ct', help='Path to the test set file') 55 | # Cross-lingual Cross-task, diff languages and tasks 56 | argparser.add_argument('--train_clct', help='Path to the training set file') 57 | argparser.add_argument('--dev_clct', help='Path to the dev set file') 58 | argparser.add_argument('--test_clct', help='Path to the test set file') 59 | 60 | argparser.add_argument('--log', help='Path to the log dir') 61 | 
argparser.add_argument('--model', help='Path to the model file') 62 | argparser.add_argument('--batch_size', default=10, type=int, help='Batch size') 63 | argparser.add_argument('--max_epoch', default=100, type=int) 64 | argparser.add_argument('--word_embed_1', 65 | help='Path to the pre-trained embedding file for lang 1') 66 | argparser.add_argument('--word_embed_2', 67 | help='Path to the pre-trained embedding file for lang 2') 68 | argparser.add_argument('--word_embed_dim', type=int, default=100, 69 | help='Word embedding dimension') 70 | argparser.set_defaults(word_ignore_case=False) 71 | argparser.add_argument('--char_embed_dim', type=int, default=50, 72 | help='Character embedding dimension') 73 | argparser.add_argument('--charcnn_filters', default='2,25;3,25;4,25', 74 | help='Character-level CNN filters') 75 | argparser.add_argument('--charhw_layer', default=1, type=int) 76 | argparser.add_argument('--charhw_func', default='relu') 77 | argparser.add_argument('--use_highway', action='store_true') 78 | argparser.add_argument('--lstm_hidden_size', default=100, type=int, 79 | help='LSTM hidden state size') 80 | argparser.add_argument('--lstm_forget_bias', default=0, type=float, 81 | help='LSTM forget bias') 82 | argparser.add_argument('--feat_dropout', default=.5, type=float, 83 | help='Word feature dropout probability') 84 | argparser.add_argument('--lstm_dropout', default=.5, type=float, 85 | help='LSTM output dropout probability') 86 | argparser.add_argument('--lr', default=0.005, type=float, 87 | help='Learning rate') 88 | argparser.add_argument('--momentum', default=.9, type=float) 89 | argparser.add_argument('--decay_rate', default=.9, type=float) 90 | argparser.add_argument('--decay_step', default=10000, type=int) 91 | argparser.add_argument('--grad_clipping', default=5, type=float) 92 | argparser.add_argument('--gpu', action='store_true') 93 | argparser.add_argument('--device', default=0, type=int) 94 | argparser.add_argument('--thread', default=5, 
type=int) 95 | 96 | args = argparser.parse_args() 97 | batch_size = args.batch_size 98 | 99 | use_gpu = args.gpu and torch.cuda.is_available() 100 | if use_gpu: 101 | torch.cuda.set_device(args.device) 102 | torch.set_num_threads(args.thread) 103 | 104 | # Model file 105 | model_dir = args.model 106 | assert model_dir and os.path.isdir(model_dir), 'Model output dir is required' 107 | model_file = os.path.join(model_dir, 'model.{}.mdl'.format(timestamp)) 108 | 109 | # Logging file 110 | log_writer = None 111 | if args.log: 112 | log_file = os.path.join(args.log, 'log.{}.txt'.format(timestamp)) 113 | log_writer = open(log_file, 'a', encoding='utf-8') 114 | logger.addHandler(logging.FileHandler(log_file, encoding='utf-8')) 115 | logger.info('----------') 116 | logger.info('Parameters:') 117 | for arg in vars(args): 118 | logger.info('{}: {}'.format(arg, getattr(args, arg))) 119 | logger.info('----------') 120 | 121 | # Data file 122 | logger.info('Loading data sets') 123 | ner_parser = ConllParser(skip_comment=True) 124 | pos_parser = ConllParser(token_col=1, label_col=3, skip_comment=True) 125 | 126 | train_set_tgt = SeqLabelDataset(args.train_tgt, parser=ner_parser) 127 | dev_set_tgt = SeqLabelDataset(args.dev_tgt, parser=ner_parser) 128 | test_set_tgt = SeqLabelDataset(args.test_tgt, parser=ner_parser) 129 | 130 | train_set_cl = SeqLabelDataset(args.train_cl, parser=ner_parser) 131 | dev_set_cl = SeqLabelDataset(args.dev_cl, parser=ner_parser) 132 | test_set_cl = SeqLabelDataset(args.test_cl, parser=ner_parser) 133 | 134 | train_set_ct = SeqLabelDataset(args.train_ct, parser=pos_parser) 135 | dev_set_ct = SeqLabelDataset(args.dev_ct, parser=pos_parser) 136 | test_set_ct = SeqLabelDataset(args.test_ct, parser=pos_parser) 137 | 138 | train_set_clct = SeqLabelDataset(args.train_clct, parser=pos_parser) 139 | dev_set_clct = SeqLabelDataset(args.dev_clct, parser=pos_parser) 140 | test_set_clct = SeqLabelDataset(args.test_clct, parser=pos_parser) 141 | 142 | # datasets = 
{'train': train_set, 'dev': dev_set, 'test': test_set} 143 | datasets = { 144 | 'tgt': {'train': train_set_tgt, 'dev': dev_set_tgt, 'test': test_set_tgt}, 145 | 'cl': {'train': train_set_cl, 'dev': dev_set_cl, 'test': test_set_cl}, 146 | 'ct': {'train': train_set_ct, 'dev': dev_set_ct, 'test': test_set_ct}, 147 | 'clct': {'train': train_set_clct, 'dev': dev_set_clct, 'test': test_set_clct} 148 | } 149 | 150 | # Vocabs 151 | logger.info('Building vocabs') 152 | ( 153 | token_count_1, token_count_2, char_count, label_count_1, label_count_2 154 | ) = Counter(), Counter(), Counter(), Counter(), Counter() 155 | for _, ds in datasets['tgt'].items(): 156 | tc, cc, lc = ds.stats() 157 | token_count_1.update(tc) 158 | char_count.update(cc) 159 | label_count_1.update(lc) 160 | for _, ds in datasets['cl'].items(): 161 | tc, cc, lc = ds.stats() 162 | token_count_2.update(tc) 163 | char_count.update(cc) 164 | label_count_1.update(lc) 165 | for _, ds in datasets['ct'].items(): 166 | tc, cc, lc = ds.stats() 167 | token_count_1.update(tc) 168 | char_count.update(cc) 169 | label_count_2.update(lc) 170 | for _, ds in datasets['clct'].items(): 171 | tc, cc, lc = ds.stats() 172 | token_count_2.update(tc) 173 | char_count.update(cc) 174 | label_count_2.update(lc) 175 | token_vocab_1 = count2vocab(token_count_1, offset=len(C.TOKEN_PADS), pads=C.TOKEN_PADS) 176 | token_vocab_2 = count2vocab(token_count_2, offset=len(C.TOKEN_PADS), pads=C.TOKEN_PADS) 177 | char_vocab = count2vocab(char_count, offset=len(C.CHAR_PADS), pads=C.CHAR_PADS) 178 | label_vocab_1 = count2vocab(label_count_1, offset=1, pads=[(C.PAD, C.PAD_INDEX)]) 179 | label_vocab_2 = count2vocab(label_count_2, offset=1, pads=[(C.PAD, C.PAD_INDEX)]) 180 | 181 | idx_token_1 = {idx: token for token, idx in token_vocab_1.items()} 182 | idx_token_2 = {idx: token for token, idx in token_vocab_2.items()} 183 | idx_label_1 = {idx: label for label, idx in label_vocab_1.items()} 184 | idx_label_2 = {idx: label for label, idx in 
label_vocab_2.items()} 185 | idx_tokens = { 186 | 'tgt': idx_token_1, 187 | 'cl': idx_token_2, 188 | 'ct': idx_token_1, 189 | 'clct': idx_token_2 190 | } 191 | idx_labels = { 192 | 'tgt': idx_label_1, 193 | 'cl': idx_label_1, 194 | 'ct': idx_label_2, 195 | 'clct': idx_label_2 196 | } 197 | 198 | train_set_tgt.numberize(token_vocab_1, label_vocab_1, char_vocab) 199 | dev_set_tgt.numberize(token_vocab_1, label_vocab_1, char_vocab) 200 | test_set_tgt.numberize(token_vocab_1, label_vocab_1, char_vocab) 201 | 202 | train_set_cl.numberize(token_vocab_2, label_vocab_1, char_vocab) 203 | dev_set_cl.numberize(token_vocab_2, label_vocab_1, char_vocab) 204 | test_set_cl.numberize(token_vocab_2, label_vocab_1, char_vocab) 205 | 206 | train_set_ct.numberize(token_vocab_1, label_vocab_2, char_vocab) 207 | dev_set_ct.numberize(token_vocab_1, label_vocab_2, char_vocab) 208 | test_set_ct.numberize(token_vocab_1, label_vocab_2, char_vocab) 209 | 210 | train_set_clct.numberize(token_vocab_2, label_vocab_2, char_vocab) 211 | dev_set_clct.numberize(token_vocab_2, label_vocab_2, char_vocab) 212 | test_set_clct.numberize(token_vocab_2, label_vocab_2, char_vocab) 213 | 214 | # Embedding file 215 | word_embed_1 = load_embedding(args.word_embed_1, 216 | dimension=args.word_embed_dim, 217 | vocab=token_vocab_1) 218 | word_embed_2 = load_embedding(args.word_embed_2, 219 | dimension=args.word_embed_dim, 220 | vocab=token_vocab_2) 221 | charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])] 222 | for f in args.charcnn_filters.split(';')] 223 | char_embed = CharCNN(len(char_vocab), 224 | args.char_embed_dim, 225 | filters=charcnn_filters) 226 | char_hw = Highway(char_embed.output_size, 227 | layer_num=args.charhw_layer, 228 | activation=args.charhw_func) 229 | feat_dim = args.word_embed_dim + char_embed.output_size 230 | lstm = LSTM(feat_dim, 231 | args.lstm_hidden_size, 232 | batch_first=True, 233 | bidirectional=True, 234 | forget_bias=args.lstm_forget_bias 235 | ) 236 | crf_1 = 
CRF(label_size=len(label_vocab_1) + 2) 237 | crf_2 = CRF(label_size=len(label_vocab_2) + 2) 238 | # Linear layers for task 1 239 | shared_linear_1 = Linear(in_features=lstm.output_size, 240 | out_features=len(label_vocab_1)) 241 | spec_linear_1_1 = Linear(in_features=lstm.output_size, 242 | out_features=len(label_vocab_1)) 243 | spec_linear_1_2 = Linear(in_features=lstm.output_size, 244 | out_features=len(label_vocab_1)) 245 | # Linear layers for task 2 246 | shared_linear_2 = Linear(in_features=lstm.output_size, 247 | out_features=len(label_vocab_2)) 248 | spec_linear_2_1 = Linear(in_features=lstm.output_size, 249 | out_features=len(label_vocab_2)) 250 | spec_linear_2_2 = Linear(in_features=lstm.output_size, 251 | out_features=len(label_vocab_2)) 252 | 253 | lstm_crf_tgt = LstmCrf( 254 | token_vocab_1, label_vocab_1, char_vocab, 255 | word_embedding=word_embed_1, 256 | char_embedding=char_embed, 257 | crf=crf_1, 258 | lstm=lstm, 259 | univ_fc_layer=shared_linear_1, 260 | spec_fc_layer=spec_linear_1_1, 261 | embed_dropout_prob=args.feat_dropout, 262 | lstm_dropout_prob=args.lstm_dropout, 263 | char_highway=char_hw if args.use_highway else None 264 | ) 265 | lstm_crf_cl = LstmCrf( 266 | token_vocab_2, label_vocab_1, char_vocab, 267 | word_embedding=word_embed_2, 268 | char_embedding=char_embed, 269 | crf=crf_1, 270 | lstm=lstm, 271 | univ_fc_layer=shared_linear_1, 272 | spec_fc_layer=spec_linear_1_2, 273 | embed_dropout_prob=args.feat_dropout, 274 | lstm_dropout_prob=args.lstm_dropout, 275 | char_highway=char_hw if args.use_highway else None 276 | ) 277 | lstm_crf_ct = LstmCrf( 278 | token_vocab_1, label_vocab_2, char_vocab, 279 | word_embedding=word_embed_1, 280 | char_embedding=char_embed, 281 | crf=crf_2, 282 | lstm=lstm, 283 | univ_fc_layer=shared_linear_2, 284 | spec_fc_layer=spec_linear_2_1, 285 | embed_dropout_prob=args.feat_dropout, 286 | lstm_dropout_prob=args.lstm_dropout, 287 | char_highway=char_hw if args.use_highway else None 288 | ) 289 | lstm_crf_clct 
= LstmCrf( 290 | token_vocab_2, label_vocab_2, char_vocab, 291 | word_embedding=word_embed_2, 292 | char_embedding=char_embed, 293 | crf=crf_2, 294 | lstm=lstm, 295 | univ_fc_layer=shared_linear_2, 296 | spec_fc_layer=spec_linear_2_2, 297 | embed_dropout_prob=args.feat_dropout, 298 | lstm_dropout_prob=args.lstm_dropout, 299 | char_highway=char_hw if args.use_highway else None 300 | ) 301 | if use_gpu: 302 | lstm_crf_tgt.cuda() 303 | lstm_crf_cl.cuda() 304 | lstm_crf_ct.cuda() 305 | lstm_crf_clct.cuda() 306 | models = { 307 | 'tgt': lstm_crf_tgt, 308 | 'cl': lstm_crf_cl, 309 | 'ct': lstm_crf_ct, 310 | 'clct': lstm_crf_clct 311 | } 312 | 313 | # Task 314 | optimizer_tgt = optim.SGD( 315 | filter(lambda p: p.requires_grad, lstm_crf_tgt.parameters()), 316 | lr=args.lr, momentum=args.momentum) 317 | optimizer_cl = optim.SGD( 318 | filter(lambda p: p.requires_grad, lstm_crf_cl.parameters()), 319 | lr=args.lr, momentum=args.momentum) 320 | optimizer_ct = optim.SGD( 321 | filter(lambda p: p.requires_grad, lstm_crf_ct.parameters()), 322 | lr=args.lr, momentum=args.momentum) 323 | optimizer_clct = optim.SGD( 324 | filter(lambda p: p.requires_grad, lstm_crf_clct.parameters()), 325 | lr=args.lr, momentum=args.momentum) 326 | optimizers = { 327 | 'tgt': optimizer_tgt, 328 | 'cl': optimizer_cl, 329 | 'ct': optimizer_ct, 330 | 'clct': optimizer_clct 331 | } 332 | processor = SeqLabelProcessor(gpu=use_gpu) 333 | 334 | train_args = vars(args) 335 | train_args['word_embed_size'] = word_embed_1.num_embeddings 336 | state = { 337 | 'model': { 338 | 'word_embed': word_embed_1.state_dict(), 339 | 'char_embed': char_embed.state_dict(), 340 | 'char_hw': char_hw.state_dict(), 341 | 'lstm': lstm.state_dict(), 342 | 'crf': crf_1.state_dict(), 343 | 'univ_linear': shared_linear_1.state_dict(), 344 | 'spec_linear': spec_linear_1_1.state_dict(), 345 | 'lstm_crf': lstm_crf_tgt.state_dict() 346 | }, 347 | 'args': train_args, 348 | 'vocab': { 349 | 'token': token_vocab_1, 350 | 'label': 
label_vocab_1, 351 | 'char': char_vocab, 352 | } 353 | } 354 | 355 | # Calculate mixing rates 356 | batch_num = len(train_set_tgt) // batch_size 357 | r_tgt = math.sqrt(len(train_set_tgt)) 358 | r_cl = 1.0 * .1 * math.sqrt(len(datasets['cl']['train'])) 359 | r_ct = .1 * 1.0 * math.sqrt(len(datasets['ct']['train'])) 360 | r_clct = .1 * .1 * math.sqrt(len(datasets['clct']['train'])) 361 | num_cl = int(r_cl / r_tgt * batch_num) 362 | num_ct = int(r_ct / r_tgt * batch_num) 363 | num_clct = int(r_clct / r_tgt * batch_num) 364 | print('{}, {}, {}, {}'.format(batch_num, num_cl, num_ct, num_clct)) 365 | 366 | data_loaders = {} 367 | data_loader_iters = {} 368 | for task, task_datasets in datasets.items(): 369 | data_loaders[task] = { 370 | k: DataLoader(v, 371 | batch_size=batch_size, 372 | shuffle=k == 'train', 373 | collate_fn=processor.process) 374 | for k, v in task_datasets.items() 375 | } 376 | data_loader_iters[task] = {k: iter(v) for k, v 377 | in data_loaders[task].items()} 378 | 379 | try: 380 | global_step = 0 381 | best_dev_score = best_test_score = 0.0 382 | 383 | for epoch in range(args.max_epoch): 384 | logger.info('Epoch {}: Training'.format(epoch + 1)) 385 | best = False 386 | 387 | for ds in ['train', 'dev', 'test']: 388 | if ds == 'train': 389 | tasks = ['tgt'] * batch_num + ['cl'] * num_cl\ 390 | + ['ct'] * num_ct + ['clct'] * num_clct 391 | shuffle(tasks) 392 | progress = tqdm(total=len(tasks), mininterval=1, 393 | desc=ds) 394 | for task in tasks: 395 | progress.update(1) 396 | global_step += 1 397 | model = models[task] 398 | optimizer = optimizers[task] 399 | optimizer.zero_grad() 400 | try: 401 | batch = next(data_loader_iters[task]['train']) 402 | except StopIteration: 403 | data_loader_iters[task]['train'] = iter( 404 | data_loaders[task]['train'] 405 | ) 406 | batch = next(data_loader_iters[task]['train']) 407 | tokens, labels, chars, seq_lens, char_lens = batch 408 | loglik, _ = model.loglik( 409 | tokens, labels, seq_lens, chars, char_lens) 
410 | loss = -loglik.mean() 411 | loss.backward() 412 | 413 | clip_grad_norm_(model.parameters(), args.grad_clipping) 414 | optimizer.step() 415 | progress.close() 416 | else: 417 | for task in ['tgt', 'cl', 'ct', 'clct']: 418 | logger.info('task: {} dataset: {}'.format(task, ds)) 419 | results = [] 420 | for batch in data_loaders[task][ds]: 421 | tokens, labels, chars, seq_lens, char_lens = batch 422 | pred, _loss = models[task].predict( 423 | tokens, labels, seq_lens, chars, char_lens 424 | ) 425 | results.append((pred, labels, seq_lens, tokens)) 426 | fscore, prec, rec = evaluate( 427 | results, idx_tokens[task], idx_labels[task], 428 | writer=log_writer) 429 | if ds == 'dev' and task == 'tgt' and fscore > best_dev_score: 430 | logger.info('New best score: {:.4f}'.format(fscore)) 431 | best_dev_score = fscore 432 | best = True 433 | logger.info('Saving the model to {}'.format(model_file)) 434 | torch.save(state, model_file) 435 | if best and ds == 'test' and task == 'tgt': 436 | best_test_score = fscore 437 | 438 | # learning rate decay 439 | lr = args.lr * args.decay_rate ** (global_step / args.decay_step) 440 | for opt in optimizers.values(): 441 | for p in opt.param_groups: 442 | p['lr'] = lr 443 | logger.info('New learning rate: {}'.format(lr)) 444 | 445 | logger.info('Best dev score: {}'.format(best_dev_score)) 446 | logger.info('Best test score: {}'.format(best_test_score)) 447 | logger.info('Model file: {}'.format(model_file)) 448 | if args.log: 449 | logger.info('Log file: {}'.format(log_file)) 450 | log_writer.close() 451 | except Exception: 452 | traceback.print_exc() 453 | if log_writer: 454 | log_writer.close() -------------------------------------------------------------------------------- /old/task.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.optim as optim 3 | 4 | from torch.nn.utils import clip_grad_norm 5 | from numpy.random import choice 6 | from conlleval import 
evaluate, report, metrics 7 | from collections import defaultdict, namedtuple 8 | from data import (count2vocab, create_parser, create_dataset, 9 | numberize_datasets) 10 | from model import LstmCrf, create_module 11 | from util import Config, get_logger 12 | 13 | logger = get_logger(__name__) 14 | 15 | SCORES = namedtuple('SCORES', ['fscore', 'precision', 'recall', 'loss']) 16 | 17 | 18 | class Task(object): 19 | 20 | def __init__(self, 21 | name, 22 | model, 23 | datasets, 24 | vocabs, 25 | gpu=False, 26 | prob=1.0, 27 | lr=0.001, 28 | momentum=.9, 29 | decay_rate=.9, 30 | decay_step=10000, 31 | gradient_clipping=5.0, 32 | require_eval=True, 33 | ref=False, 34 | aux_task=False, 35 | aux_lang=False 36 | ): 37 | self.name = name 38 | self.model = model 39 | self.prob = prob 40 | self.gpu = gpu 41 | self.require_eval = require_eval 42 | 43 | self.datasets = datasets 44 | self.train = datasets.get('train', None) 45 | self.dev = datasets.get('dev', None) 46 | self.test = datasets.get('test', None) 47 | 48 | self.vocabs = vocabs 49 | self.token_vocab = vocabs.get('token') 50 | self.label_vocab = vocabs.get('label') 51 | self.char_vocab = vocabs.get('char') 52 | self.ref = ref 53 | 54 | self.optimizer = optim.SGD( 55 | filter(lambda p: p.requires_grad, model.parameters()), 56 | lr=lr, momentum=momentum) 57 | self.lr = lr 58 | self.momentum = momentum 59 | self.task_step = 0 60 | self.decay_rate = decay_rate 61 | self.decay_step = float(decay_step) 62 | self.gradient_clipping = gradient_clipping 63 | 64 | self.aux_task = aux_task 65 | self.aux_lang = aux_lang 66 | 67 | if gpu: 68 | self.model.cuda() 69 | 70 | def step(self): 71 | raise NotImplementedError() 72 | 73 | def eval(self, dataset_name, log_output=None): 74 | raise NotImplementedError() 75 | 76 | def learning_rate_decay(self): 77 | lr = self.lr * self.decay_rate ** (self.task_step / self.decay_step) 78 | for p in self.optimizer.param_groups: 79 | p['lr'] = lr 80 | 81 | def update_learning_rate(self, lr): 82 | 
for p in self.optimizer.param_groups: 83 | p['lr'] = lr 84 | 85 | 86 | class SequenceTask(Task): 87 | 88 | def __init__(self, 89 | name, 90 | model, 91 | datasets, 92 | vocabs, 93 | gpu=False, 94 | prob=1.0, 95 | lr=0.001, 96 | momentum=.9, 97 | decay_rate=.9, 98 | decay_step=10000, 99 | gradient_clipping=5.0, 100 | require_eval=True, 101 | ref=False, 102 | aux_task=False, 103 | aux_lang=False 104 | ): 105 | super(SequenceTask, self).__init__(name, 106 | model, 107 | datasets, 108 | vocabs, 109 | gpu, 110 | prob, 111 | lr, 112 | momentum, 113 | decay_rate, 114 | decay_step, 115 | gradient_clipping, 116 | require_eval, 117 | ref, 118 | aux_task, 119 | aux_lang 120 | ) 121 | self.label_size = len(self.label_vocab) 122 | self.idx_label = {i: l for l, i in self.label_vocab.items()} 123 | self.idx_token = {i: t for t, i in self.token_vocab.items()} 124 | 125 | 126 | def step(self): 127 | self.task_step += 1 128 | self.optimizer.zero_grad() 129 | ( 130 | tokens, labels, chars, seq_lens, char_lens 131 | ) = self.train.get_batch(gpu=self.gpu) 132 | loglik, _ = self.model.loglik(tokens, labels, seq_lens, chars, 133 | char_lens) 134 | loss = -loglik.mean() 135 | loss.backward() 136 | 137 | params = [] 138 | for n, p in self.model.named_parameters(): 139 | if 'embedding.weight' not in n: 140 | params.append(p) 141 | clip_grad_norm(params, self.gradient_clipping) 142 | self.optimizer.step() 143 | 144 | 145 | class NameTagging(SequenceTask): 146 | 147 | def eval(self, dataset_name, log_output=None): 148 | dataset = self.datasets.get(dataset_name, None) 149 | if dataset is None: 150 | return 151 | 152 | results = [] 153 | logger.info('Evaluating {} ({})'.format(self.name, dataset_name)) 154 | set_loss = 0 155 | for tokens, labels, chars, seq_lens, char_lens in dataset.get_dataset(volatile=True, gpu=self.gpu): 156 | preds, loss = self.model.predict(tokens, 157 | labels, 158 | seq_lens, 159 | chars, 160 | char_lens) 161 | set_loss += float(loss.data[0]) 162 | for pred, gold, 
seq_len, ts in zip(preds, labels, seq_lens, tokens): 163 | l = int(seq_len.data[0]) 164 | pred = pred.data.tolist()[:l] 165 | gold = gold.data.tolist()[:l] 166 | ts = ts.data.tolist()[:l] 167 | for p, g, t in zip(pred, gold, ts): 168 | t = self.idx_token.get(t, 'UNK') 169 | results.append('{} {} {}'.format(t, 170 | self.idx_label[g], 171 | self.idx_label[p])) 172 | results.append('') 173 | counts = evaluate(results) 174 | overall, by_type = metrics(counts) 175 | report(counts) 176 | logger.info('Loss: {:.5f}'.format(set_loss)) 177 | return SCORES(fscore=overall.fscore, 178 | precision=overall.prec, 179 | recall=overall.rec, 180 | loss=set_loss) 181 | 182 | 183 | class PosTagging(SequenceTask): 184 | 185 | def eval(self, dataset_name, log_output=None): 186 | dataset = self.datasets.get(dataset_name, None) 187 | if dataset is None: 188 | return 189 | 190 | total_num = 0 191 | correct_num = 0 192 | logger.info('Evaluating {} ({})'.format(self.name, dataset_name)) 193 | set_loss = 0 194 | 195 | results = [] 196 | for tokens, labels, chars, seq_lens, char_lens in dataset.get_dataset( 197 | volatile=True, gpu=self.gpu): 198 | preds, loss = self.model.predict(tokens, labels, seq_lens, chars, char_lens) 199 | set_loss += float(loss.data[0]) 200 | for pred, gold, seq_len, ts in zip(preds, labels, seq_lens, tokens): 201 | l = int(seq_len.data[0]) 202 | total_num += l 203 | pred = pred.data.tolist()[:l] 204 | gold = gold.data.tolist()[:l] 205 | pred = np.array(pred) 206 | gold = np.array(gold) 207 | correct = (pred == gold).sum() 208 | correct_num += correct 209 | accuracy = correct_num / total_num 210 | logger.info('Accuracy: {0:.5f}'.format(accuracy)) 211 | logger.info('Loss: {}'.format(set_loss)) 212 | return SCORES(fscore=accuracy, 213 | precision=accuracy, 214 | recall=accuracy, 215 | loss=set_loss) 216 | 217 | 218 | class MultiTask(object): 219 | 220 | def __init__(self, tasks, eval_freq=1000): 221 | self.tasks = tasks 222 | self.task_probs = [] 223 | 
self.update_probs() 224 | self.global_step = 0 225 | self.eval_freq = eval_freq 226 | self.ref_task = 0 227 | self.best_ref_score = -1.0 228 | self.best_scores = [] 229 | for task_idx, task in enumerate(self.tasks): 230 | if task.ref: 231 | self.ref_tasks = task_idx 232 | break 233 | 234 | def update_probs(self): 235 | 236 | def auto_prob(task): 237 | doc_num = len(task.train.dataset) 238 | theta_task = .1 if task.aux_task else 1 239 | theta_lang = .1 if task.aux_lang else 1 240 | prob = doc_num ** .5 * theta_task * theta_lang 241 | return prob 242 | 243 | task_probs = [auto_prob(t) for t in self.tasks] 244 | task_prob_sum = sum(task_probs) 245 | self.task_probs = [p / task_prob_sum for p in task_probs] 246 | 247 | def step(self): 248 | self.global_step += 1 249 | task = choice(self.tasks,p=self.task_probs) 250 | task.learning_rate_decay() 251 | task.step() 252 | 253 | if self.global_step % self.eval_freq == 0: 254 | scores = [] 255 | ref_score = 0 256 | for task_idx, task in enumerate(self.tasks): 257 | if task.require_eval: 258 | dev_scores = task.eval('dev') 259 | test_scores = task.eval('test') 260 | if task_idx == self.ref_task: 261 | ref_score = dev_scores.fscore 262 | scores.append((task_idx, dev_scores, test_scores)) 263 | if ref_score > self.best_ref_score: 264 | self.best_ref_score = ref_score 265 | self.best_scores = scores 266 | 267 | 268 | def compute_metadata(datasets): 269 | """Compute tokens, labels, and characters in the given data sets. 270 | 271 | :param datasets: A list of data sets. 272 | :return: dicts of token, label, and character counts. 
273 | """ 274 | token_count = defaultdict(int) 275 | label_count = defaultdict(int) 276 | char_count = defaultdict(int) 277 | 278 | for dataset in datasets: 279 | if dataset: 280 | t, l, c = dataset.metadata() 281 | for k, v in t.items(): 282 | token_count[k] += v 283 | for k, v in l.items(): 284 | label_count[k] += v 285 | for k, v in c.items(): 286 | char_count[k] += v 287 | 288 | return token_count, label_count, char_count 289 | 290 | 291 | def build_tasks_from_file(conf_path, options=None): 292 | if type(conf_path) is str: 293 | conf = Config.read(conf_path) 294 | elif type(conf_path) is Config: 295 | conf = conf_path 296 | else: 297 | raise TypeError('Unknown configuration type. Expect str or Config.') 298 | 299 | if options: 300 | for k, v in options: 301 | conf.update_value(k, v) 302 | 303 | # Create data sets 304 | logger.info('Loading data sets') 305 | datasets = {} 306 | lang_datasets = defaultdict(list) 307 | task_datasets = defaultdict(list) 308 | for dataset in conf.datasets: 309 | parser = create_parser(dataset.parser.format, dataset.parser) 310 | ( 311 | train_conf, dev_conf, test_conf 312 | ) = dataset.clone(), dataset.clone(), dataset.clone() 313 | train_conf.update({'path': dataset.files.train, 314 | 'parser': parser}) 315 | dev_conf.update({'path': dataset.files.dev, 316 | 'parser': parser, 317 | 'sample': None}) 318 | train_dataset = create_dataset(dataset.type, train_conf) 319 | dev_dataset = create_dataset(dataset.type, dev_conf) 320 | if hasattr(dataset.files, 'test'): 321 | test_conf.update({'path': dataset.files.test, 322 | 'parser': parser, 323 | 'sample': None}) 324 | test_dataset = create_dataset(dataset.type, test_conf) 325 | datasets[dataset.name] = { 326 | 'train': train_dataset, 327 | 'dev': dev_dataset, 328 | 'test': test_dataset, 329 | 'language': dataset.language, 330 | 'task': dataset.task 331 | } 332 | lang_datasets[dataset.language].append(dataset.name) 333 | task_datasets[dataset.task].append(dataset.name) 334 | 335 | # Create 
vocabs 336 | # I only keep words in the data sets to save memory 337 | # If the model will be applied to an unknown test set, it is better to keep 338 | # all words in pre-trained embeddings. 339 | logger.info('Creating vocabularies') 340 | dataset_counts = {} 341 | lang_token_vocabs = {} 342 | task_label_vocabs = {} 343 | for name, ds in datasets.items(): 344 | dataset_counts[name] = compute_metadata( 345 | [ds['train'], ds['dev'], ds['test']] 346 | ) 347 | for lang, ds in lang_datasets.items(): 348 | counts = [dataset_counts[d][0] for d in ds] 349 | lang_token_vocabs[lang] = count2vocab(counts, 350 | ignore_case=True, 351 | start_idx=2) 352 | for task, ds in task_datasets.items(): 353 | counts = [dataset_counts[d][1] for d in ds] 354 | task_label_vocabs[task] = count2vocab(counts, 355 | ignore_case=False, 356 | start_idx=0, 357 | sort=True) 358 | char_vocab = count2vocab([c[2] for c in dataset_counts.values()], 359 | ignore_case=False, start_idx=1) 360 | 361 | # Report stats 362 | for lang, vocab in lang_token_vocabs.items(): 363 | logger.info('#{} token: {}'.format(lang, len(vocab))) 364 | for task, vocab in task_label_vocabs.items(): 365 | logger.info('#{} label: {}'.format(task, len(vocab))) 366 | logger.info(vocab) 367 | 368 | # Numberize datasets 369 | logger.info('Numberizing data sets') 370 | numberize_conf = [] 371 | for ds in datasets.values(): 372 | numberize_conf.append((ds['train'], 373 | lang_token_vocabs[ds['language']], 374 | task_label_vocabs[ds['task']], 375 | char_vocab)) 376 | numberize_conf.append((ds['dev'], 377 | lang_token_vocabs[ds['language']], 378 | task_label_vocabs[ds['task']], 379 | char_vocab)) 380 | numberize_conf.append((ds['test'], 381 | lang_token_vocabs[ds['language']], 382 | task_label_vocabs[ds['task']], 383 | char_vocab)) 384 | numberize_datasets(numberize_conf, 385 | token_ignore_case=True, 386 | label_ignore_case=False, 387 | char_ignore_case=False) 388 | 389 | # Initialize component confs 390 | logger.info('Initializing 
component configurations') 391 | word_embed_dim = char_embed_dim = lstm_output_dim = 0 392 | cpnt_confs = {} 393 | for cpnt in conf.components: 394 | if cpnt.model == 'embedding': 395 | cpnt.embedding_dim = cpnt.dimension 396 | word_embed_dim = cpnt.dimension 397 | elif cpnt.model == 'char_cnn': 398 | cpnt.vocab_size = len(char_vocab) 399 | char_embed_dim = sum([x[1] for x in cpnt.filters]) 400 | elif cpnt.model == 'lstm': 401 | lstm_output_dim = cpnt.hidden_size * (2 if cpnt.bidirectional else 1) 402 | cpnt_confs[cpnt.name] = cpnt.clone() 403 | 404 | # Update component configurations 405 | target_task = '' 406 | target_lang = '' 407 | for task_conf in conf.tasks: 408 | language = task_conf.language 409 | task = task_conf.task 410 | if task_conf.get('ref', False): 411 | target_lang = language 412 | target_task = task 413 | model_conf = task_conf.model 414 | if model_conf.model != 'lstm_crf': 415 | continue 416 | # Update word embedding configuration 417 | cpnt_confs[model_conf.word_embed].num_embeddings = len( 418 | lang_token_vocabs[language]) 419 | cpnt_confs[model_conf.word_embed].vocab = lang_token_vocabs[language] 420 | # Update output layer configuration 421 | cpnt_confs[model_conf.univ_layer].out_features = len( 422 | task_label_vocabs[task] 423 | ) 424 | if hasattr(model_conf, 'spec_layer'): 425 | cpnt_confs[model_conf.spec_layer].out_features = len( 426 | task_label_vocabs[task] 427 | ) 428 | # Update CRF configuration 429 | cpnt_confs[model_conf.crf].label_vocab = task_label_vocabs[task] 430 | 431 | for _, cpnt_conf in cpnt_confs.items(): 432 | if cpnt_conf.model == 'linear' and cpnt_conf.position == 'output': 433 | cpnt_conf.in_features = lstm_output_dim 434 | if cpnt_conf.model == 'lstm': 435 | cpnt_conf.input_size = char_embed_dim + word_embed_dim 436 | if cpnt_conf.model == 'highway' and cpnt_conf.position == 'char': 437 | cpnt_conf.size = char_embed_dim 438 | 439 | # Create components 440 | logger.info('Creating components') 441 | components = {k: 
create_module(v.model, v) for k, v in cpnt_confs.items()} 442 | 443 | # Construct models 444 | tasks = [] 445 | for task_conf in conf.tasks: 446 | model_conf = task_conf.model 447 | language = task_conf.language 448 | task = task_conf.task 449 | if model_conf.model == 'lstm_crf': 450 | model = LstmCrf( 451 | lang_token_vocabs[language], 452 | task_label_vocabs[task], 453 | char_vocab, 454 | word_embedding=components[model_conf.word_embed], 455 | char_embedding=components[model_conf.char_embed] if hasattr( 456 | model_conf, 'char_embed') else None, 457 | crf=components[model_conf.crf], 458 | lstm=components[model_conf.lstm], 459 | input_layer=None, 460 | univ_fc_layer=components[model_conf.univ_layer], 461 | spec_fc_layer=components[model_conf.spec_layer] if hasattr( 462 | model_conf, 'spec_linear') else None, 463 | embed_dropout_prob=model_conf.embed_dropout, 464 | lstm_dropout_prob=model_conf.lstm_dropout, 465 | linear_dropout_prob=model_conf.linear_dropout, 466 | char_highway=components[model_conf.char_highway] if hasattr( 467 | model_conf, 'char_highway') else None, 468 | use_char_embedding=model_conf.use_char_embedding if hasattr( 469 | model_conf, 'use_char_embedding') else True, 470 | ) 471 | # elif model_conf.model == 'cbow': 472 | # pass 473 | else: 474 | raise ValueError('Unknown model: {}'.format(model_conf.model)) 475 | logger.debug(model) 476 | 477 | task_classes = {'ner': NameTagging, 'pos': PosTagging} 478 | if task in task_classes: 479 | task_obj = task_classes[task]( 480 | task_conf.name, 481 | model, 482 | datasets=datasets[task_conf.dataset], 483 | vocabs={ 484 | 'token': lang_token_vocabs[language], 485 | 'label': task_label_vocabs[task], 486 | 'char': char_vocab 487 | }, 488 | gpu=task_conf.gpu, 489 | # TODO: 'gpu' -> global config 490 | prob=getattr(task_conf, 'prob', 1.0), 491 | lr=getattr(task_conf, 'learning_rate', .001), 492 | momentum=getattr(task_conf, 'momentum', .9), 493 | decay_rate=getattr(task_conf, 'decay_rate', .9), 494 | 
decay_step=getattr(task_conf, 'decay_step', 10000), 495 | gradient_clipping=getattr(task_conf, 'gradient_clipping', 5.0), 496 | require_eval=getattr(task_conf, 'require_eval', True), 497 | ref=getattr(task_conf, 'ref', False), 498 | aux_task=task_conf.task != target_task, 499 | aux_lang=task_conf.language != target_lang, 500 | ) 501 | else: 502 | raise ValueError('Unknown task {}'.format(task)) 503 | tasks.append(task_obj) 504 | 505 | return tasks, { 506 | 'lang_token_vocabs': lang_token_vocabs, 507 | 'task_token_vocabs': task_label_vocabs, 508 | 'components': components 509 | } 510 | 511 | 512 | 513 | 514 | 515 | 516 | -------------------------------------------------------------------------------- /old/_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as I 5 | import torch.nn.utils.rnn as R 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | MODULES = {} 13 | 14 | 15 | def log_sum_exp(vec, dim=0): 16 | """Calculate LogSumExp (used in the CRF layer). 17 | 18 | :param vec: Input vector. 
19 | :param dim: 20 | :return: 21 | """ 22 | m, _ = torch.max(vec, dim) 23 | m_exp = m.unsqueeze(-1).expand_as(vec) 24 | return m + torch.log(torch.sum(torch.exp(vec - m_exp), dim)) 25 | 26 | 27 | def sequence_mask(lens, max_len=None): 28 | batch_size = lens.size(0) 29 | 30 | if max_len is None: 31 | max_len = lens.max().data[0] 32 | 33 | ranges = torch.arange(0, max_len).long() 34 | ranges = ranges.unsqueeze(0).expand(batch_size, max_len) 35 | ranges = Variable(ranges) 36 | 37 | if lens.data.is_cuda: 38 | ranges = ranges.cuda() 39 | 40 | lens_exp = lens.unsqueeze(1).expand_as(ranges) 41 | mask = ranges < lens_exp 42 | 43 | return mask 44 | 45 | 46 | def sequence_masks(lens): 47 | batch_size = lens.size(0) 48 | 49 | max_len = lens.max().data[0] 50 | 51 | ranges = torch.arange(0, max_len).long() 52 | ranges = ranges.unsqueeze(0).expand(batch_size, max_len) 53 | ranges = Variable(ranges) 54 | 55 | if lens.data.is_cuda: 56 | ranges = ranges.cuda() 57 | 58 | lens_exp = lens.unsqueeze(1).expand_as(ranges) 59 | 60 | return (ranges < lens_exp).float(), (ranges >= lens_exp).float() 61 | 62 | 63 | class Linear(nn.Linear): 64 | def __init__(self, in_features, out_features, bias=True): 65 | 66 | super(Linear, self).__init__(in_features, out_features, bias=bias) 67 | self.in_features = in_features 68 | self.out_features = out_features 69 | self.initialize() 70 | 71 | def initialize(self): 72 | I.orthogonal(self.weight.data) 73 | 74 | def zero(self): 75 | self.weight.data.fill_(0.0) 76 | self.bias.data.fill_(0.0) 77 | 78 | 79 | class Highway(nn.Module): 80 | 81 | def __init__(self, size, num_layers, activation='relu'): 82 | super(Highway, self).__init__() 83 | 84 | self.size = size 85 | self.num_layers = num_layers 86 | self.f = getattr(F, activation) 87 | self.nonlinear = nn.ModuleList( 88 | [Linear(size, size) for _ in range(num_layers)]) 89 | self.linear = nn.ModuleList( 90 | [Linear(size, size) for _ in range(num_layers)]) 91 | self.gate = nn.ModuleList( 92 | [Linear(size, 
size) for _ in range(num_layers)]) 93 | 94 | def forward(self, x): 95 | """ 96 | :param x: tensor with shape of [batch_size, size] 97 | :return: tensor with shape of [batch_size, size] 98 | applies σ(x) ⨀ (f(G(x))) + (1 - σ(x)) ⨀ (Q(x)) transformation | G and Q is affine transformation, 99 | f is non-linear transformation, σ(x) is affine transformation with sigmoid non-linearition 100 | and ⨀ is element-wise multiplication 101 | """ 102 | 103 | for layer in range(self.num_layers): 104 | gate = F.sigmoid(self.gate[layer](x)) 105 | 106 | nonlinear = self.f(self.nonlinear[layer](x)) 107 | linear = self.linear[layer](x) 108 | 109 | x = gate * nonlinear + (1 - gate) * linear 110 | 111 | return x 112 | 113 | 114 | class Embedding(nn.Embedding): 115 | def __init__(self, 116 | num_embeddings, 117 | embedding_dim, 118 | padding_idx=None, 119 | max_norm=None, 120 | norm_type=2, 121 | scale_grad_by_freq=False, 122 | sparse=False, 123 | trainable=False, 124 | padding=0, 125 | file=None, 126 | stats=False, 127 | vocab=None): 128 | 129 | self.num_embeddings = num_embeddings 130 | self.embedding_dim = embedding_dim 131 | self.padding_idx = padding_idx 132 | self.max_norm = max_norm 133 | self.norm_type = norm_type 134 | self.scale_grad_by_freq = scale_grad_by_freq 135 | self.sparse = sparse 136 | self.trainable = trainable 137 | self.padding = padding 138 | self.file = file 139 | self.stats = stats 140 | self.vocab = vocab 141 | 142 | super(Embedding, self).__init__(num_embeddings + padding, 143 | embedding_dim, 144 | padding_idx, 145 | max_norm, 146 | norm_type, 147 | scale_grad_by_freq, 148 | sparse) 149 | # self.gpu = False 150 | self.output_size = embedding_dim 151 | if not trainable: 152 | self.weight.requires_grad = False 153 | if file and vocab: 154 | self.load(file, vocab, stats=stats) 155 | else: 156 | self.initialize() 157 | 158 | def initialize(self): 159 | I.xavier_normal(self.weight.data) 160 | 161 | # def cuda(self, device=None): 162 | # self.gpu = True 163 | # if 
self.allow_gpu: 164 | # return super(Embedding, self).cuda(device=None) 165 | # 166 | # def cpu(self): 167 | # self.gpu = False 168 | # return super(Embedding, self).cpu() 169 | 170 | def save(self, path, vocab, stats=False): 171 | """Save embedding to file. 172 | 173 | :param path: Path to the embedding file. 174 | :param vocab: Token vocab. 175 | :param stats: Write stats line (default=False). 176 | """ 177 | embeds = self.weight.data.cpu().numpy() 178 | with open(path, 'w', encoding='utf-8') as w: 179 | if stats: 180 | embed_num, embed_dim = self.weight.data.size() 181 | w.write('{} {}\n'.format(embed_num, embed_dim)) 182 | for token, idx in vocab.items(): 183 | embed = ' '.join(map(lambda x: str(x), embeds[idx])) 184 | w.write('{} {}\n'.format(token, embed)) 185 | 186 | def load(self, path, vocab, stats=False): 187 | logger.info('Loading embedding from {}'.format(path)) 188 | with open(path, 'r', encoding='utf-8') as r: 189 | if stats: 190 | r.readline() 191 | try: 192 | for line in r: 193 | line = line.strip().split(' ') 194 | token = line[0] 195 | if token in vocab: 196 | vector = self.weight.data.new( 197 | [float(v) for v in line[1:]]) 198 | self.weight.data[vocab[token]] = vector 199 | except UnicodeDecodeError as e: 200 | print(e) 201 | 202 | 203 | class LSTM(nn.LSTM): 204 | 205 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, 206 | batch_first=False, dropout=0, bidirectional=False, 207 | forget_bias=0): 208 | self.input_size = input_size 209 | self.hidden_size = hidden_size 210 | self.num_layers = num_layers 211 | self.bias = bias 212 | self.batch_first = batch_first 213 | self.dropout = dropout 214 | self.bidirectional = bidirectional 215 | self.forget_bias = forget_bias 216 | self.output_size = hidden_size * (2 if bidirectional else 1) 217 | 218 | super(LSTM, self).__init__(input_size=input_size, 219 | hidden_size=hidden_size, 220 | num_layers=num_layers, 221 | bias=bias, 222 | batch_first=batch_first, 223 | dropout=dropout, 224 | 
bidirectional=bidirectional) 225 | self.initialize() 226 | 227 | def initialize(self): 228 | for n, p in self.named_parameters(): 229 | if 'weight' in n: 230 | I.orthogonal(p) 231 | elif 'bias' in n: 232 | bias_size = p.size(0) 233 | p.data[bias_size // 4:bias_size // 2].fill_(self.forget_bias) 234 | 235 | 236 | class CharCNN(nn.Module): 237 | def __init__(self, 238 | vocab_size, 239 | dimension, 240 | filters 241 | ): 242 | super(CharCNN, self).__init__() 243 | 244 | self.output_size = sum([x[1] for x in filters]) 245 | self.embedding = Embedding(vocab_size, 246 | dimension, 247 | padding_idx=0, 248 | sparse=True, 249 | padding=2) 250 | self.convs = nn.ModuleList([nn.Conv2d(1, x[1], (x[0], dimension)) 251 | for x in filters]) 252 | 253 | def forward(self, inputs, lens=None): 254 | inputs_embed = self.embedding.forward(inputs) 255 | # input channel 256 | inputs_embed = inputs_embed.unsqueeze(1) 257 | # sequeeze output channel 258 | conv_outputs = [conv.forward(inputs_embed).squeeze(3) 259 | for conv in self.convs] 260 | max_pool_outputs = [F.max_pool1d(i, i.size(2)).squeeze(2) 261 | for i in conv_outputs] 262 | outputs = torch.cat(max_pool_outputs, 1) 263 | return outputs 264 | 265 | 266 | class CRF(nn.Module): 267 | 268 | def __init__(self, label_vocab): 269 | 270 | super(CRF, self).__init__() 271 | 272 | self.label_vocab = label_vocab 273 | self.label_size = len(label_vocab) + 2 274 | self.start = self.label_size - 2 275 | self.end = self.label_size - 1 276 | transition = torch.randn(self.label_size, self.label_size) 277 | self.transition = nn.Parameter(transition) 278 | self.initialize() 279 | 280 | def initialize(self): 281 | self.transition.data[:, self.end] = -100.0 282 | self.transition.data[self.start, :] = -100.0 283 | 284 | def pad_logits(self, logits, lens): 285 | lens = lens.data 286 | batch_size, seq_len, label_num = logits.size() 287 | pads = Variable(logits.data.new(batch_size, seq_len, 2).fill_(-100.0), 288 | requires_grad=False) 289 | logits = 
torch.cat([logits, pads], dim=2) 290 | # e_s = logits.data.new([-100.0] * label_num + [0, 100]) 291 | # e_s_mat = logits.data.new(logits.size()).fill_(0) 292 | # for i in range(batch_size): 293 | # if lens[i] < seq_len: 294 | # # logits[i][lens[i]] += e_s 295 | # e_s_mat[i][lens[i]] = e_s 296 | # logits += Variable(e_s_mat) 297 | return logits 298 | 299 | def calc_binary_score(self, labels, lens): 300 | batch_size, seq_len = labels.size() 301 | 302 | labels_ext = Variable(labels.data.new(batch_size, seq_len + 2)) 303 | labels_ext[:, 0] = self.start 304 | labels_ext[:, 1:-1] = labels 305 | mask = sequence_mask(lens + 1, max_len=(seq_len + 2)).long() 306 | pad_stop = Variable(labels.data.new(1).fill_(self.end)) 307 | pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2) 308 | labels_ext = (1 - mask) * pad_stop + mask * labels_ext 309 | labels = labels_ext 310 | 311 | trn = self.transition 312 | trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size()) 313 | lbl_r = labels[:, 1:] 314 | lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0)) 315 | trn_row = torch.gather(trn_exp, 1, lbl_rexp) 316 | 317 | lbl_lexp = labels[:, :-1].unsqueeze(-1) 318 | trn_scr = torch.gather(trn_row, 2, lbl_lexp) 319 | trn_scr = trn_scr.squeeze(-1) 320 | 321 | mask = sequence_mask(lens + 1).float() 322 | trn_scr = trn_scr * mask 323 | score = trn_scr 324 | 325 | return score 326 | 327 | def calc_unary_score(self, logits, labels, lens): 328 | labels_exp = labels.unsqueeze(-1) 329 | scores = torch.gather(logits, 2, labels_exp).squeeze(-1) 330 | mask = sequence_mask(lens).float() 331 | scores = scores * mask 332 | return scores 333 | 334 | def calc_gold_score(self, logits, labels, lens): 335 | unary_score = self.calc_unary_score(logits, labels, lens).sum( 336 | 1).squeeze(-1) 337 | binary_score = self.calc_binary_score(labels, lens).sum(1).squeeze(-1) 338 | return unary_score + binary_score 339 | 340 | def calc_norm_score(self, logits, lens): 341 | batch_size, seq_len, 
feat_dim = logits.size() 342 | alpha = logits.data.new(batch_size, self.label_size).fill_(-10000.0) 343 | alpha[:, self.start] = 0 344 | alpha = Variable(alpha) 345 | lens_ = lens.clone() 346 | 347 | logits_t = logits.transpose(1, 0) 348 | for logit in logits_t: 349 | logit_exp = logit.unsqueeze(-1).expand(batch_size, 350 | *self.transition.size()) 351 | alpha_exp = alpha.unsqueeze(1).expand(batch_size, 352 | *self.transition.size()) 353 | trans_exp = self.transition.unsqueeze(0).expand_as(alpha_exp) 354 | mat = logit_exp + alpha_exp + trans_exp 355 | alpha_nxt = log_sum_exp(mat, 2).squeeze(-1) 356 | 357 | mask = (lens_ > 0).float().unsqueeze(-1).expand_as(alpha) 358 | alpha = mask * alpha_nxt + (1 - mask) * alpha 359 | lens_ = lens_ - 1 360 | 361 | alpha = alpha + self.transition[self.end].unsqueeze(0).expand_as(alpha) 362 | norm = log_sum_exp(alpha, 1).squeeze(-1) 363 | 364 | return norm 365 | 366 | def viterbi_decode(self, logits, lens): 367 | """Borrowed from pytorch tutorial 368 | Arguments: 369 | logits: [batch_size, seq_len, n_labels] FloatTensor 370 | lens: [batch_size] LongTensor 371 | """ 372 | batch_size, seq_len, n_labels = logits.size() 373 | vit = logits.data.new(batch_size, self.label_size).fill_(-10000) 374 | vit[:, self.start] = 0 375 | vit = Variable(vit) 376 | c_lens = lens.clone() 377 | 378 | logits_t = logits.transpose(1, 0) 379 | pointers = [] 380 | for logit in logits_t: 381 | vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels) 382 | trn_exp = self.transition.unsqueeze(0).expand_as(vit_exp) 383 | vit_trn_sum = vit_exp + trn_exp 384 | vt_max, vt_argmax = vit_trn_sum.max(2) 385 | 386 | vt_max = vt_max.squeeze(-1) 387 | vit_nxt = vt_max + logit 388 | pointers.append(vt_argmax.squeeze(-1).unsqueeze(0)) 389 | 390 | mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt) 391 | vit = mask * vit_nxt + (1 - mask) * vit 392 | 393 | mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt) 394 | vit += mask * 
self.transition[self.end].unsqueeze( 395 | 0).expand_as(vit_nxt) 396 | 397 | c_lens = c_lens - 1 398 | 399 | pointers = torch.cat(pointers) 400 | scores, idx = vit.max(1) 401 | idx = idx.squeeze(-1) 402 | paths = [idx.unsqueeze(1)] 403 | for argmax in reversed(pointers): 404 | idx_exp = idx.unsqueeze(-1) 405 | idx = torch.gather(argmax, 1, idx_exp) 406 | idx = idx.squeeze(-1) 407 | 408 | paths.insert(0, idx.unsqueeze(1)) 409 | 410 | paths = torch.cat(paths[1:], 1) 411 | scores = scores.squeeze(-1) 412 | 413 | return scores, paths 414 | 415 | 416 | class LstmCrf(nn.Module): 417 | def __init__(self, 418 | token_vocab, 419 | label_vocab, 420 | char_vocab, 421 | 422 | word_embedding, 423 | char_embedding, 424 | crf, 425 | lstm, 426 | input_layer=None, 427 | univ_layer=None, 428 | spec_layer=None, 429 | 430 | embedding_dropout_prob=0, 431 | lstm_dropout_prob=0, 432 | linear_dropout_prob=0, 433 | use_char_embedding=True, 434 | char_highway=None, 435 | ): 436 | super(LstmCrf, self).__init__() 437 | 438 | self.token_vocab = token_vocab 439 | self.label_vocab = label_vocab 440 | self.char_vocab = char_vocab 441 | self.idx_label = {idx: label for label, idx in label_vocab.items()} 442 | self.use_char_embedding = use_char_embedding 443 | 444 | self.word_embedding = word_embedding 445 | self.char_embedding = char_embedding 446 | self.feat_dim = word_embedding.output_size 447 | if use_char_embedding: 448 | self.feat_dim += char_embedding.output_size 449 | 450 | self.lstm = lstm 451 | self.input_layer = input_layer 452 | self.univ_layer = univ_layer 453 | self.spec_layer = spec_layer 454 | self.crf = crf 455 | self.char_highway = char_highway 456 | self.lstm_dropout = nn.Dropout(p=lstm_dropout_prob) 457 | self.embedding_dropout = nn.Dropout(p=embedding_dropout_prob) 458 | self.linear_dropout = nn.Dropout(p=linear_dropout_prob) 459 | self.label_size = len(label_vocab) 460 | if spec_layer: 461 | self.spec_gate = Linear(spec_layer.in_features, 462 | spec_layer.out_features) 463 | 
464 | def cuda(self, device=None): 465 | for module in self.children(): 466 | module.cuda(device) 467 | return self 468 | 469 | def cpu(self): 470 | for module in self.children(): 471 | module.cpu() 472 | return self 473 | 474 | def forward_model(self, inputs, lens, chars=None, char_lens=None): 475 | """From the input to the linear layer, not including the CRF layer. 476 | 477 | :param inputs: Input tensor of size batch_size * max_seq_len (word indexes). 478 | :param lens: Sequence length tensor of size batch_size (sequence lengths). 479 | :param chars: Input character tensor of size batch_size * max_seq_len * max_word_len (character indexes). 480 | :param char_lens: Word length tensor of size (batch_size * max_seq_len) * max_word_len. 481 | :return: Linear layer output tensor of size batch_size * max_seq_len * label_num. 482 | """ 483 | batch_size, seq_len = inputs.size() 484 | 485 | # Word embedding 486 | inputs_embed = self.word_embedding.forward(inputs) 487 | # Character embedding 488 | if self.use_char_embedding: 489 | chars_embed = self.char_embedding.forward(chars, char_lens) 490 | if self.char_highway: 491 | chars_embed = self.char_highway.forward(chars_embed) 492 | chars_embed = chars_embed.view(batch_size, seq_len, -1) 493 | inputs_embed = torch.cat([inputs_embed, chars_embed], dim=2) 494 | inputs_embed = self.embedding_dropout.forward(inputs_embed) 495 | 496 | # LSTM layer 497 | inputs_packed = R.pack_padded_sequence(inputs_embed, 498 | lens.data.tolist(), 499 | batch_first=True) 500 | lstm_out, _ = self.lstm.forward(inputs_packed) 501 | lstm_out, _ = R.pad_packed_sequence(lstm_out, batch_first=True) 502 | lstm_out = lstm_out.contiguous().view(-1, self.lstm.output_size) 503 | lstm_out = self.lstm_dropout.forward(lstm_out) 504 | 505 | # Linear layer 506 | univ_feats = self.univ_layer.forward(lstm_out) 507 | if self.spec_layer: 508 | spec_feats = self.spec_layer.forward(lstm_out) 509 | gate = F.sigmoid(self.spec_gate.forward(lstm_out)) 510 | outputs = gate 
* spec_feats + (1 - gate) * univ_feats 511 | else: 512 | outputs = univ_feats 513 | outputs = outputs.view(batch_size, seq_len, self.label_size) 514 | 515 | return outputs 516 | 517 | def predict(self, inputs, labels, lens, chars=None, char_lens=None): 518 | """From the input to the CRF output (prediction mode). 519 | 520 | :param inputs: Input tensor of size batch_size * max_seq_len (word indexes). 521 | :param labels: Gold labels. 522 | :param lens: Sequence length tensor of size batch_size (sequence lengths). 523 | :param chars: Input character tensor of size batch_size * max_seq_len * max_word_len (character indexes). 524 | :param char_lens: Word length tensor of size (batch_size * max_seq_len) * max_word_len. 525 | :return: Prediction and loss. 526 | """ 527 | self.eval() 528 | loglik, logits = self.loglik(inputs, labels, lens, chars, char_lens) 529 | loss = -loglik.mean() 530 | scores, preds = self.crf.viterbi_decode(logits, lens) 531 | self.train() 532 | return preds, loss 533 | 534 | def loglik(self, inputs, labels, lens, chars=None, char_lens=None): 535 | logits = self.forward_model(inputs, lens, chars, char_lens) 536 | logits = self.crf.pad_logits(logits, lens) 537 | norm_score = self.crf.calc_norm_score(logits, lens) 538 | gold_score = self.crf.calc_gold_score(logits, labels, lens) 539 | loglik = gold_score - norm_score 540 | 541 | return loglik, logits --------------------------------------------------------------------------------