├── README.md
├── data
│   └── .gitignore
├── logs
│   └── .gitignore
├── models
│   └── .gitignore
├── results
│   └── .gitignore
├── src
│   ├── models
│   │   ├── __init__.py
│   │   ├── attentions.py
│   │   ├── data_loader.py
│   │   ├── encoder.py
│   │   ├── model_builder.py
│   │   ├── optimizers.py
│   │   ├── reporter.py
│   │   ├── stats.py
│   │   └── trainer.py
│   ├── others
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   └── pyrouge.py
│   └── train.py
└── temp
    └── .gitignore

/README.md:
--------------------------------------------------------------------------------
1 | # SUMO
2 | 
3 | **This code is for the paper `Single Document Summarization as Tree Induction`**
4 | 
5 | **Python version**: Python 3.6
6 | 
7 | **Package Requirements**: pytorch, tensorboardX, pyrouge
8 | 
9 | Some code is borrowed from OpenNMT (https://github.com/OpenNMT/OpenNMT-py).
10 | 
11 | ## Data Preparation
12 | 
13 | Download the processed CNN/DailyMail data from https://drive.google.com/open?id=1BM9wvnyXx9JvgW2um0Fk9bgQRrx03Tol, then unzip the archive and copy its contents to `data/`.
14 | 
15 | ## Model Training
16 | 
17 | ```
18 | python train.py -mode train -onmt_path ../data/cnndm_data/cnndm -batch_size 50000 -visible_gpu 1 -report_every 100 -optim adam -lr 1 -save_checkpoint_steps 1000 -train_steps 150000 -model_path ../models/str_l5_i3 -log_file ../logs/str_l5_i3 -local_layers 5 -inter_layers 3 -dropout 0.1 -emb_size 128 -hidden_size 128 -heads 4 -ff_size 512 -decay_method noam -warmup_steps 8000 -structured
19 | ```
20 | 
21 | * `-mode` can be one of {`train`, `validate`, `test`}: `validate` inspects the model directory and evaluates the model on each newly saved checkpoint; `test` must be used together with `-test_from`, which indicates the checkpoint you want to use.
22 | 
23 | ## Model Evaluation
24 | 
25 | After training has finished, run
26 | ```
27 | python train.py -mode validate -onmt_path ../data/cnndm_data/cnndm -batch_size 50000 -visible_gpu 1 -report_every 100 -optim adam -lr 1 -save_checkpoint_steps 1000 -train_steps 150000 -model_path ../models/str_l5_i3 -log_file ../logs/str_l5_i3 -local_layers 5 -inter_layers 3 -dropout 0.1 -emb_size 128 -hidden_size 128 -heads 4 -ff_size 512 -decay_method noam -warmup_steps 8000 -structured -test_all
28 | ```
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/results/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlpyang/SUMO/f40c52de24381f8b58a90fbc2e57abd93bad56b7/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/attentions.py:
--------------------------------------------------------------------------------
1 | """
Multi-Head Attention module """ 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | # from onmt.utils.misc import aeq 8 | 9 | 10 | class MultiHeadedAttention(nn.Module): 11 | def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): 12 | assert model_dim % head_count == 0 13 | self.dim_per_head = model_dim // head_count 14 | self.model_dim = model_dim 15 | 16 | super(MultiHeadedAttention, self).__init__() 17 | self.head_count = head_count 18 | 19 | self.linear_keys = nn.Linear(model_dim, 20 | head_count * self.dim_per_head) 21 | self.linear_values = nn.Linear(model_dim, 22 | head_count * self.dim_per_head) 23 | self.linear_query = nn.Linear(model_dim, 24 | head_count * self.dim_per_head) 25 | self.softmax = nn.Softmax(dim=-1) 26 | self.dropout = nn.Dropout(dropout) 27 | self.use_final_linear = use_final_linear 28 | if (self.use_final_linear): 29 | self.final_linear = nn.Linear(model_dim, model_dim) 30 | 31 | def forward(self, key, value, query, mask=None, 32 | layer_cache=None, type=None, predefined_graph_1=None): 33 | batch_size = key.size(0) 34 | dim_per_head = self.dim_per_head 35 | head_count = self.head_count 36 | 37 | def shape(x): 38 | """ projection """ 39 | return x.view(batch_size, -1, head_count, dim_per_head) \ 40 | .transpose(1, 2) 41 | 42 | def unshape(x): 43 | """ compute context """ 44 | return x.transpose(1, 2).contiguous() \ 45 | .view(batch_size, -1, head_count * dim_per_head) 46 | 47 | # 1) Project key, value, and query. 48 | if layer_cache is not None: 49 | if type == "self": 50 | query, key, value = self.linear_query(query), \ 51 | self.linear_keys(query), \ 52 | self.linear_values(query) 53 | 54 | key = shape(key) 55 | value = shape(value) 56 | 57 | if layer_cache is not None: 58 | device = key.device 59 | if layer_cache["self_keys"] is not None: 60 | key = torch.cat( 61 | (layer_cache["self_keys"].to(device), key), 62 | dim=2) 63 | if layer_cache["self_values"] is not None: 64 | value = torch.cat( 65 | (layer_cache["self_values"].to(device), value), 66 | dim=2) 67 | layer_cache["self_keys"] = key 68 | layer_cache["self_values"] = value 69 | elif type == "context": 70 | query = self.linear_query(query) 71 | if layer_cache is not None: 72 | if layer_cache["memory_keys"] is None: 73 | key, value = self.linear_keys(key), \ 74 | self.linear_values(value) 75 | key = shape(key) 76 | value = shape(value) 77 | else: 78 | key, value = layer_cache["memory_keys"], \ 79 | layer_cache["memory_values"] 80 | layer_cache["memory_keys"] = key 81 | layer_cache["memory_values"] = value 82 | else: 83 | key, value = self.linear_keys(key), \ 84 | self.linear_values(value) 85 | key = shape(key) 86 | value = shape(value) 87 | else: 88 | key = self.linear_keys(key) 89 | value = self.linear_values(value) 90 | query = self.linear_query(query) 91 | key = shape(key) 92 | value = shape(value) 93 | 94 | query = shape(query) 95 | 96 | 97 | # 2) Calculate and scale scores. 98 | query = query / math.sqrt(dim_per_head) 99 | scores = torch.matmul(query, key.transpose(2, 3)) 100 | 101 | if mask is not None: 102 | mask = mask.unsqueeze(1).expand_as(scores) 103 | scores = scores.masked_fill(mask, -1e18) 104 | 105 | # 3) Apply attention dropout and compute context vectors. 
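# The block below: (a) normalizes the scaled scores into attention
# weights with a softmax over the key dimension; (b) if a predefined
# graph (predefined_graph_1) is supplied, reweights the last attention
# head by that graph and renormalizes it; (c) applies dropout and takes
# the attention-weighted sum of the value vectors, optionally followed
# by the final output projection.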
106 | 107 | attn = self.softmax(scores) 108 | 109 | if (not predefined_graph_1 is None): 110 | attn_masked = attn[:, -1] * predefined_graph_1 111 | attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9) 112 | 113 | attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1) 114 | 115 | drop_attn = self.dropout(attn) 116 | if (self.use_final_linear): 117 | context = unshape(torch.matmul(drop_attn, value)) 118 | output = self.final_linear(context) 119 | return output 120 | else: 121 | context = torch.matmul(drop_attn, value) 122 | return context 123 | 124 | 125 | 126 | def _getMatrixTree_multi(scores, root): 127 | A = scores.exp() 128 | R = root.exp() 129 | 130 | L = torch.sum(A, 1) 131 | L = torch.diag_embed(L) 132 | L = L - A 133 | LL = L + torch.diag_embed(R) 134 | LL_inv = torch.inverse(LL) # batch_l, doc_l, doc_l 135 | LL_inv_diag = torch.diagonal(LL_inv, 0, 1, 2) 136 | d0 = R * LL_inv_diag 137 | LL_inv_diag = torch.unsqueeze(LL_inv_diag, 2) 138 | 139 | _A = torch.transpose(A, 1, 2) 140 | _A = _A * LL_inv_diag 141 | tmp1 = torch.transpose(_A, 1, 2) 142 | tmp2 = A * torch.transpose(LL_inv, 1, 2) 143 | 144 | d = tmp1 - tmp2 145 | return d, d0 146 | 147 | 148 | class StructuredAttention(nn.Module): 149 | def __init__(self, model_dim, dropout=0.1): 150 | self.model_dim = model_dim 151 | 152 | super(StructuredAttention, self).__init__() 153 | 154 | self.linear_keys = nn.Linear(model_dim, self.model_dim) 155 | self.linear_query = nn.Linear(model_dim, self.model_dim) 156 | self.linear_root = nn.Linear(model_dim, 1) 157 | self.dropout = nn.Dropout(dropout) 158 | 159 | def forward(self, x, mask=None): 160 | 161 | 162 | key = self.linear_keys(x) 163 | query = self.linear_query(x) 164 | root = self.linear_root(x).squeeze(-1) 165 | 166 | query = query / math.sqrt(self.model_dim) 167 | scores = torch.matmul(query, key.transpose(1, 2)) 168 | 169 | mask = mask.float() 170 | root = root - mask.squeeze(1) * 50 171 | root = torch.clamp(root, min=-40) 172 | scores = scores - mask * 50 173 | scores = scores - torch.transpose(mask, 1, 2) * 50 174 | scores = torch.clamp(scores, min=-40) 175 | # _logits = _logits + (tf.transpose(bias, [0, 2, 1]) - 1) * 40 176 | # _logits = tf.clip_by_value(_logits, -40, 1e10) 177 | 178 | d, d0 = _getMatrixTree_multi(scores, root) 179 | attn = torch.transpose(d, 1,2) 180 | if mask is not None: 181 | mask = mask.expand_as(scores).byte() 182 | attn = attn.masked_fill(mask, 0) 183 | 184 | return attn, d0 185 | 186 | 187 | class MultiHeadedPooling(nn.Module): 188 | def __init__(self, model_dim): 189 | self.model_dim = model_dim 190 | super(MultiHeadedPooling, self).__init__() 191 | self.linear_keys = nn.Linear(model_dim, 1) 192 | self.softmax = nn.Softmax(dim=-1) 193 | 194 | def forward(self, x, mask=None): 195 | 196 | scores = self.linear_keys(x).squeeze(-1) 197 | 198 | 199 | if mask is not None: 200 | scores = scores.masked_fill(1 - mask, -1e18) 201 | 202 | attn = self.softmax(scores).unsqueeze(-1) 203 | output = torch.sum(attn * x, -2) 204 | return output 205 | -------------------------------------------------------------------------------- /src/models/data_loader.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import glob 3 | import random 4 | import torch 5 | 6 | from others.logging import logger 7 | 8 | 9 | class Batch(object): 10 | def _pad(self, data, height, width, pad_id): 11 | rtn_data = [d + [pad_id] * (width - len(d)) for d in data] 12 | rtn_length = [len(d) for d in data] 13 | rtn_data = 
rtn_data + [[pad_id] * width] * (height - len(data)) 14 | rtn_length = rtn_length + [0] * (height - len(data)) 15 | return rtn_data, rtn_length 16 | 17 | def _pad2(self, data, width, pad_id): 18 | rtn_data = [d + [pad_id] * (width - len(d)) for d in data] 19 | return rtn_data 20 | 21 | def __init__(self, data=None, pad_id=None, device=None, is_test=False): 22 | """Create a Batch from a list of examples.""" 23 | if data is not None: 24 | self.batch_size = len(data) 25 | src = [x[0] for x in data] 26 | labels = [x[1] for x in data] 27 | max_nsent = max([len(e) for e in src]) 28 | max_ntoken = max([max([len(p) for p in e]) for e in src]) 29 | labels = self._pad2(labels, max_nsent, 0) 30 | labels = torch.tensor(labels).float() 31 | 32 | _src = [self._pad(e, max_nsent, max_ntoken, pad_id) for e in src] 33 | src = torch.stack([torch.tensor(e[0]) for e in _src]) # batch_size, n_block, block_size 34 | src_length = torch.tensor([sum(e[1]) for e in _src]) 35 | 36 | setattr(self, 'src', src.to(device)) 37 | setattr(self, 'src_length', src_length.to(device)) 38 | setattr(self, 'labels', labels.to(device)) 39 | 40 | 41 | # _tgt = self._pad(tgt, width=max([len(d) for d in tgt]), height=len(tgt), pad_id=pad_id) 42 | # tgt = torch.tensor(_tgt[0]).transpose(0, 1) # tgt_len * batch_size 43 | # setattr(self, 'tgt', tgt.to(device)) 44 | 45 | if (is_test): 46 | src_str = [x[2] for x in data] 47 | setattr(self, 'src_str', src_str) 48 | tgt_str = [x[3] for x in data] 49 | setattr(self, 'tgt_str', tgt_str) 50 | 51 | def __len__(self): 52 | return self.batch_size 53 | 54 | 55 | def batch(data, batch_size): 56 | minibatch, size_so_far = [], 0 57 | for ex in data: 58 | minibatch.append(ex) 59 | size_so_far = simple_batch_size_fn(ex, len(minibatch)) 60 | if size_so_far == batch_size: 61 | yield minibatch 62 | minibatch, size_so_far = [], 0 63 | elif size_so_far > batch_size: 64 | yield minibatch[:-1] 65 | minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1) 66 | if minibatch: 67 | yield minibatch 68 | 69 | 70 | def load_dataset(args, corpus_type, shuffle): 71 | assert corpus_type in ["train", "valid", "test"] 72 | 73 | def _lazy_dataset_loader(pt_file, corpus_type): 74 | dataset = torch.load(pt_file) 75 | logger.info('Loading %s dataset from %s, number of examples: %d' % 76 | (corpus_type, pt_file, len(dataset))) 77 | return dataset 78 | 79 | # Sort the glob output by file name (by increasing indexes). 80 | pts = sorted(glob.glob(args.onmt_path + '.' + corpus_type + '.[0-9]*.pt')) 81 | if pts: 82 | if (shuffle): 83 | random.shuffle(pts) 84 | 85 | for pt in pts: 86 | yield _lazy_dataset_loader(pt, corpus_type) 87 | else: 88 | # Only one inputters.*Dataset, simple! 89 | pt = args.onmt_path + '.' 
+ corpus_type + '.pt' 90 | yield _lazy_dataset_loader(pt, corpus_type) 91 | 92 | 93 | def simple_batch_size_fn(new, count): 94 | src, labels = new[0], new[1] 95 | global max_n_sents, max_n_tokens, max_size 96 | if count == 1: 97 | max_size = 0 98 | max_n_sents=0 99 | max_n_tokens=0 100 | max_n_sents = max(max_n_sents, len(src)) 101 | max_n_tokens = max(max_n_tokens, max([len(s) for s in src])) 102 | max_size = max(max_size, max_n_sents*max_n_tokens) 103 | src_elements = count * max_size 104 | return src_elements 105 | 106 | 107 | class Dataloader(object): 108 | def __init__(self, args, datasets, symbols, batch_size, 109 | device, shuffle, is_test): 110 | self.args = args 111 | self.datasets = datasets 112 | self.symbols = symbols 113 | self.batch_size = batch_size 114 | self.device = device 115 | self.shuffle = shuffle 116 | self.is_test = is_test 117 | self.cur_iter = self._next_dataset_iterator(datasets) 118 | assert self.cur_iter is not None 119 | 120 | def __iter__(self): 121 | dataset_iter = (d for d in self.datasets) 122 | while self.cur_iter is not None: 123 | for batch in self.cur_iter: 124 | yield batch 125 | self.cur_iter = self._next_dataset_iterator(dataset_iter) 126 | 127 | 128 | def _next_dataset_iterator(self, dataset_iter): 129 | try: 130 | # Drop the current dataset for decreasing memory 131 | if hasattr(self, "cur_dataset"): 132 | self.cur_dataset = None 133 | gc.collect() 134 | del self.cur_dataset 135 | gc.collect() 136 | 137 | self.cur_dataset = next(dataset_iter) 138 | except StopIteration: 139 | return None 140 | 141 | return DataIterator(args = self.args, 142 | dataset=self.cur_dataset, symbols=self.symbols, batch_size=self.batch_size, 143 | device=self.device, shuffle=self.shuffle, is_test=self.is_test) 144 | 145 | 146 | class DataIterator(object): 147 | def __init__(self, args, dataset, symbols, batch_size, device=None, is_test=False, 148 | shuffle=True): 149 | self.args = args 150 | self.batch_size, self.is_test, self.dataset = batch_size, is_test, dataset 151 | self.iterations = 0 152 | self.device = device 153 | self.shuffle = shuffle 154 | 155 | self.sort_key = lambda x: len(x[1]) 156 | 157 | self._iterations_this_epoch = 0 158 | 159 | self.symbols = symbols 160 | 161 | def data(self): 162 | if self.shuffle: 163 | random.shuffle(self.dataset) 164 | xs = self.dataset 165 | return xs 166 | 167 | def preprocess(self, ex): 168 | src = ex['src'] 169 | labels = [0]*len(src) 170 | for l in ex['labels'][0]: 171 | labels[l] = 1 172 | idxs = [i for i,s in enumerate(ex['src']) if (len(s)>self.args.min_src_ntokens)] 173 | 174 | 175 | src = [src[i][:self.args.max_src_ntokens] for i in idxs] 176 | labels = [labels[i] for i in idxs] 177 | src = src[:self.args.max_nsents] 178 | labels = labels[:self.args.max_nsents] 179 | 180 | if(len(src) batch_size: 207 | yield minibatch[:-1] 208 | minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1) 209 | if minibatch: 210 | yield minibatch 211 | 212 | def create_batches(self): 213 | """ Create batches """ 214 | data = self.data() 215 | for buffer in self.batch_buffer(data, self.batch_size * 100): 216 | 217 | p_batch = sorted(buffer, key=lambda x: max([len(s) for s in x[0]])) 218 | p_batch = sorted(p_batch, key=lambda x: len(x[0])) 219 | p_batch = batch(p_batch, self.batch_size) 220 | 221 | p_batch = list(p_batch) 222 | if (self.shuffle): 223 | random.shuffle(p_batch) 224 | for b in p_batch: 225 | yield b 226 | 227 | def __iter__(self): 228 | while True: 229 | self.batches = self.create_batches() 230 | for idx, minibatch in 
enumerate(self.batches): 231 | # fast-forward if loaded from state 232 | if self._iterations_this_epoch > idx: 233 | continue 234 | self.iterations += 1 235 | self._iterations_this_epoch += 1 236 | batch = Batch(minibatch, self.symbols['PAD'], self.device, self.is_test) 237 | 238 | yield batch 239 | return 240 | -------------------------------------------------------------------------------- /src/models/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch 8 | from models.attentions import MultiHeadedAttention, MultiHeadedPooling, StructuredAttention 9 | 10 | 11 | 12 | class PositionwiseFeedForward(nn.Module): 13 | def __init__(self, d_model, d_ff, dropout=0.1): 14 | super(PositionwiseFeedForward, self).__init__() 15 | self.w_1 = nn.Linear(d_model, d_ff) 16 | self.w_2 = nn.Linear(d_ff, d_model) 17 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 18 | self.dropout_1 = nn.Dropout(dropout) 19 | self.relu = nn.ReLU() 20 | self.dropout_2 = nn.Dropout(dropout) 21 | 22 | def forward(self, x): 23 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 24 | output = self.dropout_2(self.w_2(inter)) 25 | return output + x 26 | 27 | class PositionalEncoding(nn.Module): 28 | 29 | def __init__(self, dropout, dim, max_len=5000): 30 | pe = torch.zeros(max_len, dim) 31 | position = torch.arange(0, max_len).unsqueeze(1) 32 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 33 | -(math.log(10000.0) / dim))) 34 | pe[:, 0::2] = torch.sin(position.float() * div_term) 35 | pe[:, 1::2] = torch.cos(position.float() * div_term) 36 | pe = pe.unsqueeze(0) 37 | super(PositionalEncoding, self).__init__() 38 | self.register_buffer('pe', pe) 39 | self.dropout = nn.Dropout(p=dropout) 40 | self.dim = dim 41 | 42 | def forward(self, emb, step=None): 43 | emb = emb * math.sqrt(self.dim) 44 | if (step): 45 | emb = emb + self.pe[:, step][:, None, :] 46 | 47 | else: 48 | emb = emb + self.pe[:, :emb.size(1)] 49 | emb = self.dropout(emb) 50 | return emb 51 | 52 | def get_emb(self, emb): 53 | return self.pe[:, :emb.size(1)] 54 | 55 | class TransformerEncoderLayer(nn.Module): 56 | def __init__(self, d_model, heads, d_ff, dropout): 57 | super(TransformerEncoderLayer, self).__init__() 58 | 59 | self.self_attn = MultiHeadedAttention( 60 | heads, d_model, dropout=dropout) 61 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 62 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 63 | self.dropout = nn.Dropout(dropout) 64 | 65 | def forward(self, query, inputs, mask): 66 | input_norm = self.layer_norm(inputs) 67 | mask = mask.unsqueeze(1) 68 | context = self.self_attn(input_norm, input_norm, input_norm, 69 | mask=mask) 70 | out = self.dropout(context) + inputs 71 | return self.feed_forward(out) 72 | 73 | 74 | 75 | class TransformerInterEncoder(nn.Module): 76 | def __init__(self, d_model, d_ff, heads, dropout, embeddings, num_local_layers=0, num_inter_layers=0): 77 | super(TransformerInterEncoder, self).__init__() 78 | self.d_model = d_model 79 | self.num_local_layers = num_local_layers 80 | self.num_inter_layers = num_inter_layers 81 | self.embeddings = embeddings 82 | self.pos_emb = PositionalEncoding(dropout, int(self.embeddings.embedding_dim)) 83 | self.transformer_local = nn.ModuleList( 84 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 85 | for _ in range(num_local_layers)]) 86 | self.transformer_inter = 
nn.ModuleList( 87 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 88 | for _ in range(num_inter_layers)]) 89 | self.pooling = MultiHeadedPooling(d_model) 90 | self.dropout = nn.Dropout(dropout) 91 | self.layer_norm1 = nn.LayerNorm(d_model, eps=1e-6) 92 | self.layer_norm2 = nn.LayerNorm(d_model, eps=1e-6) 93 | self.wo = nn.Linear(d_model,1,bias=True) 94 | self.sigmoid = nn.Sigmoid() 95 | 96 | 97 | def forward(self, src): 98 | batch_size, n_blocks, n_tokens = src.size() 99 | emb = self.embeddings(src) 100 | padding_idx = self.embeddings.padding_idx 101 | mask_local = 1 - src.data.eq(padding_idx).view(batch_size * n_blocks, n_tokens) 102 | mask_block = torch.sum(mask_local.view(batch_size, n_blocks, n_tokens), -1) > 0 103 | 104 | 105 | local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(batch_size, n_blocks, n_tokens, self.embeddings.embedding_dim) 106 | emb = emb * math.sqrt(self.embeddings.embedding_dim) 107 | emb = emb + local_pos_emb 108 | emb = self.pos_emb.dropout(emb) 109 | 110 | 111 | word_vec = emb.view(batch_size * n_blocks, n_tokens, -1) 112 | 113 | for i in range(self.num_local_layers): 114 | word_vec = self.transformer_local[i](word_vec, word_vec, 1 - mask_local) # all_sents * max_tokens * dim 115 | 116 | 117 | mask_hier = mask_local[:, :, None].float() 118 | word_vec = word_vec * mask_hier 119 | word_vec = self.layer_norm1(word_vec) 120 | 121 | sent_vec = self.pooling(word_vec, mask_local) 122 | sent_vec = sent_vec.view(batch_size, n_blocks, -1) 123 | global_pos_emb = self.pos_emb.pe[:, :n_blocks] 124 | sent_vec = sent_vec+global_pos_emb 125 | 126 | for i in range(self.num_inter_layers): 127 | sent_vec = self.transformer_inter[i](sent_vec, sent_vec, 1 - mask_block) # all_sents * max_tokens * dim 128 | 129 | 130 | sent_vec = self.layer_norm2(sent_vec) 131 | sent_scores = self.sigmoid(self.wo(sent_vec)) 132 | sent_scores = sent_scores.squeeze(-1) * mask_block.float() 133 | 134 | return sent_scores, mask_block 135 | 136 | class StructuredEncoder(nn.Module): 137 | def __init__(self, d_model, d_ff, heads, dropout, embeddings, num_local_layers=0, num_inter_layers=0): 138 | super(StructuredEncoder, self).__init__() 139 | self.d_model = d_model 140 | self.num_local_layers = num_local_layers 141 | self.num_inter_layers = num_inter_layers 142 | self.embeddings = embeddings 143 | self.pos_emb = PositionalEncoding(dropout, int(self.embeddings.embedding_dim)) 144 | self.transformer_local = nn.ModuleList( 145 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 146 | for _ in range(num_local_layers)]) 147 | self.transformer_inter = nn.ModuleList([TMTLayer(d_model, d_ff, dropout, i) for i in range(num_inter_layers)]) 148 | self.pooling = MultiHeadedPooling(d_model) 149 | self.dropout = nn.Dropout(dropout) 150 | self.layer_norm1= nn.LayerNorm(d_model, eps=1e-6) 151 | self.layer_norm2 = nn.LayerNorm(d_model, eps=1e-6) 152 | self.wo = nn.Linear(d_model,1,bias=True) 153 | self.sigmoid = nn.Sigmoid() 154 | 155 | 156 | def forward(self, src): 157 | """ See :obj:`EncoderBase.forward()`""" 158 | batch_size, n_blocks, n_tokens = src.size() 159 | emb = self.embeddings(src) 160 | padding_idx = self.embeddings.padding_idx 161 | mask_local = 1 - src.data.eq(padding_idx).view(batch_size * n_blocks, n_tokens) 162 | mask_block = torch.sum(mask_local.view(batch_size, n_blocks, n_tokens), -1) > 0 163 | 164 | 165 | local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(batch_size, n_blocks, n_tokens, self.embeddings.embedding_dim) 166 | emb = emb * 
math.sqrt(self.embeddings.embedding_dim) 167 | emb = emb + local_pos_emb 168 | emb = self.pos_emb.dropout(emb) 169 | 170 | 171 | word_vec = emb.view(batch_size * n_blocks, n_tokens, -1) 172 | 173 | for i in range(self.num_local_layers): 174 | word_vec = self.transformer_local[i](word_vec, word_vec, 1 - mask_local) 175 | 176 | mask_hier = mask_local[:, :, None].float() 177 | word_vec = word_vec * mask_hier 178 | word_vec = self.layer_norm1(word_vec) 179 | sent_vec = self.pooling(word_vec, mask_local) 180 | sent_vec = sent_vec.view(batch_size, n_blocks, -1) 181 | 182 | global_pos_emb = self.pos_emb.pe[:, :n_blocks] 183 | sent_vec = sent_vec+global_pos_emb 184 | 185 | sent_vec = self.layer_norm2(sent_vec)* mask_block.unsqueeze(-1).float() 186 | 187 | structure_vec = sent_vec 188 | roots = [] 189 | for i in range(self.num_inter_layers): 190 | structure_vec, root = self.transformer_inter[i](sent_vec, structure_vec, 1 - mask_block) 191 | roots.append(root) 192 | 193 | 194 | 195 | return roots, mask_block 196 | 197 | 198 | 199 | class TMTLayer(nn.Module): 200 | def __init__(self, d_model, d_ff, dropout, iter): 201 | super(TMTLayer, self).__init__() 202 | 203 | self.iter = iter 204 | self.self_attn = StructuredAttention( d_model, dropout) 205 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 206 | self.dropout = nn.Dropout(dropout) 207 | self.linears1 = nn.ModuleList([nn.Linear(2*d_model,d_model) for _ in range(iter)]) 208 | self.relu = nn.ReLU() 209 | self.linears2 = nn.ModuleList([nn.Linear(d_model,d_model) for _ in range(iter)]) 210 | self.layer_norm = nn.ModuleList([nn.LayerNorm(d_model, eps=1e-6) for _ in range(iter)]) 211 | 212 | def forward(self, x, structure_vec, mask): 213 | vecs = [x] 214 | 215 | mask = mask.unsqueeze(1) 216 | attn, root = self.self_attn(structure_vec,mask=mask) 217 | for i in range(self.iter): 218 | context = torch.matmul(attn, vecs[-1]) 219 | new_c = self.linears2[i](self.relu(self.linears1[i](torch.cat([vecs[-1], context], -1)))) 220 | new_c = self.layer_norm[i](new_c) 221 | vecs.append(new_c) 222 | 223 | return vecs[-1], root 224 | -------------------------------------------------------------------------------- /src/models/model_builder.py: -------------------------------------------------------------------------------- 1 | from models.encoder import TransformerInterEncoder, StructuredEncoder 2 | from models.optimizers import Optimizer 3 | import torch.nn as nn 4 | from torch.nn.init import xavier_uniform_ 5 | 6 | import torch 7 | 8 | 9 | def build_optim(args, model, checkpoint): 10 | """ Build optimizer """ 11 | saved_optimizer_state_dict = None 12 | 13 | if args.train_from != '': 14 | optim = checkpoint['optim'] 15 | saved_optimizer_state_dict = optim.optimizer.state_dict() 16 | else: 17 | optim = Optimizer( 18 | args.optim, args.lr, args.max_grad_norm, 19 | beta1=args.beta1, beta2=args.beta2, 20 | decay_method=args.decay_method, 21 | warmup_steps=args.warmup_steps, model_size=args.hidden_size) 22 | 23 | # Stage 1: 24 | # Essentially optim.set_parameters (re-)creates and optimizer using 25 | # model.paramters() as parameters that will be stored in the 26 | # optim.optimizer.param_groups field of the torch optimizer class. 27 | # Importantly, this method does not yet load the optimizer state, as 28 | # essentially it builds a new optimizer with empty optimizer state and 29 | # parameters from the model. 
30 | optim.set_parameters(list(model.named_parameters())) 31 | 32 | if args.train_from != '': 33 | # Stage 2: In this stage, which is only performed when loading an 34 | # optimizer from a checkpoint, we load the saved_optimizer_state_dict 35 | # into the re-created optimizer, to set the optim.optimizer.state 36 | # field, which was previously empty. For this, we use the optimizer 37 | # state saved in the "saved_optimizer_state_dict" variable for 38 | # this purpose. 39 | # See also: https://github.com/pytorch/pytorch/issues/2830 40 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 41 | # Convert back the state values to cuda type if applicable 42 | if args.visible_gpu != '-1': 43 | for state in optim.optimizer.state.values(): 44 | for k, v in state.items(): 45 | if torch.is_tensor(v): 46 | state[k] = v.cuda() 47 | 48 | # We want to make sure that indeed we have a non-empty optimizer state 49 | # when we loaded an existing model. This should be at least the case 50 | # for Adam, which saves "exp_avg" and "exp_avg_sq" state 51 | # (Exponential moving average of gradient and squared gradient values) 52 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 53 | raise RuntimeError( 54 | "Error: loaded Adam optimizer from existing model" + 55 | " but optimizer state is empty") 56 | 57 | return optim 58 | 59 | 60 | 61 | 62 | class Summarizer(nn.Module): 63 | def __init__(self, args, word_padding_idx, vocab_size, device, checkpoint=None, multigpu=False): 64 | self.multigpu = multigpu 65 | super(Summarizer, self).__init__() 66 | self.vocab_size = vocab_size 67 | self.device = device 68 | 69 | src_embeddings = torch.nn.Embedding(self.vocab_size, args.emb_size, padding_idx=word_padding_idx) 70 | if(args.structured): 71 | self.encoder = StructuredEncoder(args.hidden_size, args.ff_size, args.heads, args.dropout, src_embeddings, 72 | args.local_layers, args.inter_layers) 73 | else: 74 | self.encoder = TransformerInterEncoder(args.hidden_size, args.ff_size, args.heads, args.dropout, src_embeddings, 75 | args.local_layers, args.inter_layers) 76 | if checkpoint is not None: 77 | # checkpoint['model'] 78 | keys = list(checkpoint['model'].keys()) 79 | self.load_state_dict(checkpoint['model'], strict=True) 80 | else: 81 | for p in self.parameters(): 82 | if p.dim() > 1: 83 | xavier_uniform_(p) 84 | 85 | self.to(device) 86 | 87 | def forward(self, src, labels, src_lengths): 88 | sent_scores, mask_block = self.encoder(src) 89 | 90 | return sent_scores, mask_block 91 | -------------------------------------------------------------------------------- /src/models/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.nn.utils import clip_grad_norm_ 4 | 5 | 6 | def use_gpu(opt): 7 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 8 | (hasattr(opt, 'gpu') and opt.gpu > -1) 9 | 10 | 11 | 12 | class MultipleOptimizer(object): 13 | """ Implement multiple optimizers needed for sparse adam """ 14 | 15 | def __init__(self, op): 16 | """ ? """ 17 | self.optimizers = op 18 | 19 | def zero_grad(self): 20 | """ ? """ 21 | for op in self.optimizers: 22 | op.zero_grad() 23 | 24 | def step(self): 25 | """ ? """ 26 | for op in self.optimizers: 27 | op.step() 28 | 29 | @property 30 | def state(self): 31 | """ ? """ 32 | return {k: v for op in self.optimizers for k, v in op.state.items()} 33 | 34 | def state_dict(self): 35 | """ ? 
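Return the state dicts of the wrapped optimizers, in order.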
""" 36 | return [op.state_dict() for op in self.optimizers] 37 | 38 | def load_state_dict(self, state_dicts): 39 | """ ? """ 40 | assert len(state_dicts) == len(self.optimizers) 41 | for i in range(len(state_dicts)): 42 | self.optimizers[i].load_state_dict(state_dicts[i]) 43 | 44 | 45 | class Optimizer(object): 46 | """ 47 | Controller class for optimization. Mostly a thin 48 | wrapper for `optim`, but also useful for implementing 49 | rate scheduling beyond what is currently available. 50 | Also implements necessary methods for training RNNs such 51 | as grad manipulations. 52 | 53 | Args: 54 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam] 55 | lr (float): learning rate 56 | lr_decay (float, optional): learning rate decay multiplier 57 | start_decay_steps (int, optional): step to start learning rate decay 58 | beta1, beta2 (float, optional): parameters for adam 59 | adagrad_accum (float, optional): initialization parameter for adagrad 60 | decay_method (str, option): custom decay options 61 | warmup_steps (int, option): parameter for `noam` decay 62 | model_size (int, option): parameter for `noam` decay 63 | 64 | We use the default parameters for Adam that are suggested by 65 | the original paper https://arxiv.org/pdf/1412.6980.pdf 66 | These values are also used by other established implementations, 67 | e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 68 | https://keras.io/optimizers/ 69 | Recently there are slightly different values used in the paper 70 | "Attention is all you need" 71 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98 72 | was used there however, beta2=0.999 is still arguably the more 73 | established value, so we use that here as well 74 | """ 75 | 76 | def __init__(self, method, learning_rate, max_grad_norm, 77 | lr_decay=1, start_decay_steps=None, decay_steps=None, 78 | beta1=0.9, beta2=0.999, 79 | adagrad_accum=0.0, 80 | decay_method=None, 81 | warmup_steps=4000, 82 | model_size=None): 83 | self.last_ppl = None 84 | self.learning_rate = learning_rate 85 | self.original_lr = learning_rate 86 | self.max_grad_norm = max_grad_norm 87 | self.method = method 88 | self.lr_decay = lr_decay 89 | self.start_decay_steps = start_decay_steps 90 | self.decay_steps = decay_steps 91 | self.start_decay = False 92 | self._step = 0 93 | self.betas = [beta1, beta2] 94 | self.adagrad_accum = adagrad_accum 95 | self.decay_method = decay_method 96 | self.warmup_steps = warmup_steps 97 | self.model_size = model_size 98 | 99 | def set_parameters(self, params): 100 | """ ? 
""" 101 | self.params = [] 102 | self.sparse_params = [] 103 | for k, p in params: 104 | if p.requires_grad: 105 | if self.method != 'sparseadam' or "embed" not in k: 106 | self.params.append(p) 107 | else: 108 | self.sparse_params.append(p) 109 | if self.method == 'sgd': 110 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate) 111 | elif self.method == 'adagrad': 112 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate) 113 | for group in self.optimizer.param_groups: 114 | for p in group['params']: 115 | self.optimizer.state[p]['sum'] = self.optimizer\ 116 | .state[p]['sum'].fill_(self.adagrad_accum) 117 | elif self.method == 'adadelta': 118 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate) 119 | elif self.method == 'adam': 120 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate, 121 | betas=self.betas, eps=1e-9) 122 | elif self.method == 'sparseadam': 123 | self.optimizer = MultipleOptimizer( 124 | [optim.Adam(self.params, lr=self.learning_rate, 125 | betas=self.betas, eps=1e-8), 126 | optim.SparseAdam(self.sparse_params, lr=self.learning_rate, 127 | betas=self.betas, eps=1e-8)]) 128 | else: 129 | raise RuntimeError("Invalid optim method: " + self.method) 130 | 131 | def _set_rate(self, learning_rate): 132 | self.learning_rate = learning_rate 133 | if self.method != 'sparseadam': 134 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 135 | else: 136 | for op in self.optimizer.optimizers: 137 | op.param_groups[0]['lr'] = self.learning_rate 138 | 139 | def step(self): 140 | """Update the model parameters based on current gradients. 141 | 142 | Optionally, will employ gradient modification or update learning 143 | rate. 144 | """ 145 | self._step += 1 146 | 147 | # Decay method used in tensor2tensor. 
148 | if self.decay_method == "noam": 149 | self._set_rate( 150 | self.original_lr * 151 | (self.model_size ** (-0.5) * 152 | min(self._step ** (-0.5), 153 | self._step * self.warmup_steps**(-1.5)))) 154 | 155 | # self._set_rate(self.original_lr *self.model_size ** (-0.5) *min(1.0, self._step / self.warmup_steps)*max(self._step, self.warmup_steps)**(-0.5)) 156 | # Decay based on start_decay_steps every decay_steps 157 | else: 158 | if ((self.start_decay_steps is not None) and ( 159 | self._step >= self.start_decay_steps)): 160 | self.start_decay = True 161 | if self.start_decay: 162 | if ((self._step - self.start_decay_steps) 163 | % self.decay_steps == 0): 164 | self.learning_rate = self.learning_rate * self.lr_decay 165 | 166 | if self.method != 'sparseadam': 167 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 168 | 169 | if self.max_grad_norm: 170 | clip_grad_norm_(self.params, self.max_grad_norm) 171 | self.optimizer.step() 172 | -------------------------------------------------------------------------------- /src/models/reporter.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | from datetime import datetime 4 | 5 | from models.stats import Statistics 6 | from others.logging import logger 7 | 8 | 9 | def build_report_manager(opt): 10 | if opt.tensorboard: 11 | from tensorboardX import SummaryWriter 12 | tensorboard_log_dir = opt.tensorboard_log_dir 13 | 14 | if not opt.train_from: 15 | tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S") 16 | 17 | writer = SummaryWriter(tensorboard_log_dir, 18 | comment="Unmt") 19 | else: 20 | writer = None 21 | 22 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 23 | tensorboard_writer=writer) 24 | return report_mgr 25 | 26 | 27 | class ReportMgrBase(object): 28 | """ 29 | Report Manager Base class 30 | Inherited classes should override: 31 | * `_report_training` 32 | * `_report_step` 33 | """ 34 | 35 | def __init__(self, report_every, start_time=-1.): 36 | """ 37 | Args: 38 | report_every(int): Report status every this many sentences 39 | start_time(float): manually set report start time. Negative values 40 | means that you will need to set it later or use `start()` 41 | """ 42 | self.report_every = report_every 43 | self.progress_step = 0 44 | self.start_time = start_time 45 | 46 | def start(self): 47 | self.start_time = time.time() 48 | 49 | def log(self, *args, **kwargs): 50 | logger.info(*args, **kwargs) 51 | 52 | def report_training(self, step, num_steps, learning_rate, 53 | report_stats, multigpu=False): 54 | """ 55 | This is the user-defined batch-level traing progress 56 | report function. 57 | 58 | Args: 59 | step(int): current step count. 60 | num_steps(int): total number of batches. 61 | learning_rate(float): current learning rate. 62 | report_stats(Statistics): old Statistics instance. 63 | Returns: 64 | report_stats(Statistics): updated Statistics instance. 
65 | """ 66 | if self.start_time < 0: 67 | raise ValueError("""ReportMgr needs to be started 68 | (set 'start_time' or use 'start()'""") 69 | 70 | if step % self.report_every == 0: 71 | if multigpu: 72 | report_stats = \ 73 | Statistics.all_gather_stats(report_stats) 74 | self._report_training( 75 | step, num_steps, learning_rate, report_stats) 76 | self.progress_step += 1 77 | return Statistics() 78 | else: 79 | return report_stats 80 | 81 | def _report_training(self, *args, **kwargs): 82 | """ To be overridden """ 83 | raise NotImplementedError() 84 | 85 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 86 | """ 87 | Report stats of a step 88 | 89 | Args: 90 | train_stats(Statistics): training stats 91 | valid_stats(Statistics): validation stats 92 | lr(float): current learning rate 93 | """ 94 | self._report_step( 95 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 96 | 97 | def _report_step(self, *args, **kwargs): 98 | raise NotImplementedError() 99 | 100 | 101 | class ReportMgr(ReportMgrBase): 102 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 103 | """ 104 | A report manager that writes statistics on standard output as well as 105 | (optionally) TensorBoard 106 | 107 | Args: 108 | report_every(int): Report status every this many sentences 109 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 110 | The TensorBoard Summary writer to use or None 111 | """ 112 | super(ReportMgr, self).__init__(report_every, start_time) 113 | self.tensorboard_writer = tensorboard_writer 114 | 115 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 116 | if self.tensorboard_writer is not None: 117 | stats.log_tensorboard( 118 | prefix, self.tensorboard_writer, learning_rate, step) 119 | 120 | def _report_training(self, step, num_steps, learning_rate, 121 | report_stats): 122 | """ 123 | See base class method `ReportMgrBase.report_training`. 124 | """ 125 | report_stats.output(step, num_steps, 126 | learning_rate, self.start_time) 127 | 128 | # Log the progress using the number of batches on the x-axis. 129 | self.maybe_log_tensorboard(report_stats, 130 | "progress", 131 | learning_rate, 132 | self.progress_step) 133 | report_stats = Statistics() 134 | 135 | return report_stats 136 | 137 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 138 | """ 139 | See base class method `ReportMgrBase.report_step`. 140 | """ 141 | if train_stats is not None: 142 | self.log('Train xent: %g' % train_stats.xent()) 143 | 144 | self.maybe_log_tensorboard(train_stats, 145 | "train", 146 | lr, 147 | step) 148 | 149 | if valid_stats is not None: 150 | self.log('Validation xent: %g' % valid_stats.xent()) 151 | 152 | self.maybe_log_tensorboard(valid_stats, 153 | "valid", 154 | lr, 155 | step) 156 | -------------------------------------------------------------------------------- /src/models/stats.py: -------------------------------------------------------------------------------- 1 | """ Statistics calculation utility """ 2 | from __future__ import division 3 | import time 4 | import sys 5 | 6 | from others.logging import logger 7 | 8 | 9 | class Statistics(object): 10 | """ 11 | Accumulator for loss statistics. 
12 | Currently calculates: 13 | 14 | * accuracy 15 | * perplexity 16 | * elapsed time 17 | """ 18 | 19 | def __init__(self, loss=0, n_docs=0, n_correct=0): 20 | self.loss = loss 21 | self.n_docs = n_docs 22 | self.start_time = time.time() 23 | 24 | @staticmethod 25 | def all_gather_stats(stat, max_size=4096): 26 | """ 27 | Gather a `Statistics` object accross multiple process/nodes 28 | 29 | Args: 30 | stat(:obj:Statistics): the statistics object to gather 31 | accross all processes/nodes 32 | max_size(int): max buffer size to use 33 | 34 | Returns: 35 | `Statistics`, the update stats object 36 | """ 37 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 38 | return stats[0] 39 | 40 | @staticmethod 41 | def all_gather_stats_list(stat_list, max_size=4096): 42 | """ 43 | Gather a `Statistics` list accross all processes/nodes 44 | 45 | Args: 46 | stat_list(list([`Statistics`])): list of statistics objects to 47 | gather accross all processes/nodes 48 | max_size(int): max buffer size to use 49 | 50 | Returns: 51 | our_stats(list([`Statistics`])): list of updated stats 52 | """ 53 | from torch.distributed import get_rank 54 | from onmt.utils.distributed import all_gather_list 55 | 56 | # Get a list of world_size lists with len(stat_list) Statistics objects 57 | all_stats = all_gather_list(stat_list, max_size=max_size) 58 | 59 | our_rank = get_rank() 60 | our_stats = all_stats[our_rank] 61 | for other_rank, stats in enumerate(all_stats): 62 | if other_rank == our_rank: 63 | continue 64 | for i, stat in enumerate(stats): 65 | our_stats[i].update(stat, update_n_src_words=True) 66 | return our_stats 67 | 68 | def update(self, stat, update_n_src_words=False): 69 | """ 70 | Update statistics by suming values with another `Statistics` object 71 | 72 | Args: 73 | stat: another statistic object 74 | update_n_src_words(bool): whether to update (sum) `n_src_words` 75 | or not 76 | 77 | """ 78 | self.loss += stat.loss 79 | 80 | self.n_docs += stat.n_docs 81 | 82 | def xent(self): 83 | """ compute cross entropy """ 84 | return self.loss/self.n_docs 85 | 86 | 87 | def elapsed_time(self): 88 | """ compute elapsed time """ 89 | return time.time() - self.start_time 90 | 91 | def output(self, step, num_steps, learning_rate, start): 92 | """Write out statistics to stdout. 93 | 94 | Args: 95 | step (int): current step 96 | n_batch (int): total batches 97 | start (int): start time of step. 
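            learning_rate (float): current learning rate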
98 | """ 99 | t = self.elapsed_time() 100 | step_fmt = "%2d" % step 101 | if num_steps > 0: 102 | step_fmt = "%s/%5d" % (step_fmt, num_steps) 103 | logger.info( 104 | ("Step %s; xent: %4.2f; " + 105 | "lr: %7.5f; %3.0f docs/s; %6.0f sec") 106 | % (step_fmt, 107 | self.xent(), 108 | learning_rate, 109 | self.n_docs / (t + 1e-5), 110 | time.time() - start)) 111 | sys.stdout.flush() 112 | 113 | def log_tensorboard(self, prefix, writer, learning_rate, step): 114 | """ display statistics to tensorboard """ 115 | t = self.elapsed_time() 116 | writer.add_scalar(prefix + "/xent", self.xent(), step) 117 | writer.add_scalar(prefix + "/lr", learning_rate, step) 118 | -------------------------------------------------------------------------------- /src/models/trainer.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import shutil 3 | import time 4 | import torch 5 | import os 6 | from others import pyrouge 7 | from models.reporter import ReportMgr 8 | from models.stats import Statistics 9 | from others.logging import logger 10 | from tensorboardX import SummaryWriter 11 | import numpy as np 12 | 13 | 14 | def _get_ngrams(n, text): 15 | ngram_set = set() 16 | text_length = len(text) 17 | max_index_ngram_start = text_length - n 18 | for i in range(max_index_ngram_start + 1): 19 | ngram_set.add(tuple(text[i:i + n])) 20 | return ngram_set 21 | 22 | 23 | def _tally_parameters(model): 24 | n_params = sum([p.nelement() for p in model.parameters()]) 25 | enc = 0 26 | dec = 0 27 | for name, param in model.named_parameters(): 28 | if 'encoder' in name: 29 | enc += param.nelement() 30 | elif 'decoder' or 'generator' in name: 31 | dec += param.nelement() 32 | return n_params, enc, dec 33 | 34 | 35 | def build_trainer(args, device_id, model, 36 | optim): 37 | grad_accum_count = args.accum_count 38 | n_gpu = 1 39 | if device_id < 0: 40 | n_gpu = 0 41 | 42 | gpu_rank = 0 43 | 44 | tensorboard_log_dir = args.model_path 45 | 46 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 47 | 48 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer) 49 | 50 | trainer = Trainer(args, model, optim, grad_accum_count, n_gpu, gpu_rank, report_manager) 51 | 52 | n_params, enc, dec = _tally_parameters(model) 53 | logger.info('encoder: %d' % enc) 54 | logger.info('decoder: %d' % dec) 55 | logger.info('* number of parameters: %d' % n_params) 56 | 57 | return trainer 58 | 59 | 60 | def process(temp_dir,candidates, references): 61 | cnt = len(candidates) 62 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 63 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) 64 | if not os.path.isdir(tmp_dir): 65 | os.mkdir(tmp_dir) 66 | os.mkdir(tmp_dir + "/candidate") 67 | os.mkdir(tmp_dir + "/reference") 68 | try: 69 | 70 | for i in range(cnt): 71 | if len(references[i]) < 1: 72 | continue 73 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 74 | encoding="utf-8") as f: 75 | f.write(candidates[i]) 76 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 77 | encoding="utf-8") as f: 78 | f.write(references[i]) 79 | r = pyrouge.Rouge155(temp_dir = temp_dir) 80 | r.model_dir = tmp_dir + "/reference/" 81 | r.system_dir = tmp_dir + "/candidate/" 82 | r.model_filename_pattern = 'ref.#ID#.txt' 83 | r.system_filename_pattern = r'cand.(\d+).txt' 84 | rouge_results = r.convert_and_evaluate() 85 | print(rouge_results) 86 | results_dict = r.output_to_dict(rouge_results) 87 | finally: 88 | pass 89 | if 
os.path.isdir(tmp_dir): 90 | shutil.rmtree(tmp_dir) 91 | return results_dict 92 | 93 | def test_rouge(temp_dir, cand, ref, num_processes): 94 | candidates = [line.strip() for line in cand] 95 | references = [line.strip() for line in ref] 96 | 97 | print(len(candidates)) 98 | print(len(references)) 99 | assert len(candidates) == len(references) 100 | results = process(temp_dir,candidates,references) 101 | return results 102 | 103 | def rouge_results_to_str(results_dict): 104 | return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 105 | results_dict["rouge_1_f_score"] * 100, 106 | results_dict["rouge_2_f_score"] * 100, 107 | results_dict["rouge_l_f_score"] * 100, 108 | results_dict["rouge_1_recall"] * 100, 109 | results_dict["rouge_2_recall"] * 100, 110 | results_dict["rouge_l_recall"] * 100 111 | 112 | # ,results_dict["rouge_su*_f_score"] * 100 113 | ) 114 | class Trainer(object): 115 | 116 | def __init__(self, args, model, optim, 117 | grad_accum_count=1, n_gpu=1, gpu_rank=1, 118 | report_manager=None): 119 | # Basic attributes. 120 | self.args = args 121 | self.save_checkpoint_steps = args.save_checkpoint_steps 122 | self.model = model 123 | self.optim = optim 124 | self.grad_accum_count = grad_accum_count 125 | self.n_gpu = n_gpu 126 | self.gpu_rank = gpu_rank 127 | self.report_manager = report_manager 128 | 129 | 130 | self.loss = torch.nn.BCELoss(reduction='none') 131 | assert grad_accum_count > 0 132 | # Set model in training mode. 133 | self.model.train() 134 | 135 | def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1): 136 | logger.info('Start training...') 137 | 138 | step = self.optim._step + 1 139 | true_batchs = [] 140 | accum = 0 141 | normalization = 0 142 | train_iter = train_iter_fct() 143 | 144 | total_stats = Statistics() 145 | report_stats = Statistics() 146 | self._start_report_manager(start_time=total_stats.start_time) 147 | 148 | while step <= train_steps: 149 | 150 | reduce_counter = 0 151 | for i, batch in enumerate(train_iter): 152 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank): 153 | 154 | true_batchs.append(batch) 155 | normalization += batch.batch_size 156 | accum += 1 157 | if accum == self.grad_accum_count: 158 | reduce_counter += 1 159 | self._gradient_accumulation( 160 | true_batchs, normalization, total_stats, 161 | report_stats) 162 | 163 | report_stats = self._maybe_report_training( 164 | step, train_steps, 165 | self.optim.learning_rate, 166 | report_stats) 167 | 168 | true_batchs = [] 169 | accum = 0 170 | normalization = 0 171 | if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0): 172 | self._save(step) 173 | 174 | step += 1 175 | if step > train_steps: 176 | break 177 | train_iter = train_iter_fct() 178 | 179 | return total_stats 180 | 181 | def validate(self, valid_iter): 182 | # Set model in validating mode. 
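        # For the structured model, validation scores each sentence by the
        # marginal probability of it being the root of the induced
        # dependency tree (roots[-1], from the last iterative layer); the
        # probabilities are clamped away from 0 and 1 so that the
        # unreduced BCE loss stays finite, and padded positions are masked
        # out before summing.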
183 |         self.model.eval()
184 |         stats = Statistics()
185 | 
186 |         with torch.no_grad():
187 |             for batch in valid_iter:
188 |                 src = batch.src
189 |                 src_lengths = batch.src_length
190 |                 labels = batch.labels
191 | 
192 |                 if(self.args.structured):
193 |                     roots, mask = self.model(src, labels, src_lengths)
194 |                     r = torch.clamp(roots[-1], 1e-5, 1 - 1e-5)
195 |                     loss = self.loss(r, labels)
196 | 
197 |                 else:
198 |                     sent_scores, mask = self.model(src, labels, src_lengths)
199 |                     loss = self.loss(sent_scores, labels)
200 |                 loss = (loss * mask.float()).sum()
201 |                 batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
202 | 
203 |                 stats.update(batch_stats)
204 |         return stats
205 | 
206 | 
207 |     def test(self, test_iter, step):
208 |         def _block_tri(c, p):
209 |             tri_c = _get_ngrams(3, c.split())
210 |             for s in p:
211 |                 tri_s = _get_ngrams(3, s.split())
212 |                 if len(tri_c.intersection(tri_s)) > 0:
213 |                     return True
214 |             return False
215 | 
216 |         self.model.eval()
217 |         can_path = '%s_step%d.candidate' % (self.args.result_path, step)
218 |         gold_path = '%s_step%d.gold' % (self.args.result_path, step)
219 |         with open(can_path, 'w') as save_pred:
220 |             with open(gold_path, 'w') as save_gold:
221 |                 with torch.no_grad():
222 |                     for batch in test_iter:
223 |                         src = batch.src
224 |                         src_lengths = batch.src_length
225 |                         labels = batch.labels
226 |                         gold = []
227 |                         pred = []
228 |                         if(self.args.structured):
229 |                             roots, mask = self.model(src, labels, src_lengths)
230 |                             sent_scores = roots[-1] + mask.float()
231 |                         else:
232 |                             sent_scores, mask = self.model(src, labels, src_lengths)
233 |                             sent_scores = sent_scores + mask.float()
234 |                         sent_scores = sent_scores.cpu().data.numpy()
235 |                         selected_ids = np.argsort(-sent_scores, 1)
236 |                         # selected_ids = np.sort(selected_ids,1)
237 |                         for i, idx in enumerate(selected_ids):
238 |                             _pred = []
239 |                             if(len(batch.src_str[i]) == 0):
240 |                                 continue
241 |                             for j in selected_ids[i][:len(batch.src_str[i])]:
242 |                                 candidate = batch.src_str[i][j].strip()
243 |                                 if(not _block_tri(candidate, _pred)):
244 |                                     _pred.append(candidate)
245 | 
246 |                                 if(len(_pred) == 3):
247 |                                     break
248 |                             pred.append('<q>'.join(_pred))
249 |                             gold.append(batch.tgt_str[i])
250 | 
251 |                         for i in range(len(gold)):
252 |                             save_gold.write(gold[i].strip() + '\n')
253 |                         for i in range(len(pred)):
254 |                             save_pred.write(pred[i].strip() + '\n')
255 |         if(step != -1 and self.args.report_rouge):
256 |             rouges = self._report_rouge(gold_path, can_path)
257 |             logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
258 | 
259 | 
260 |     def _report_rouge(self, gold_path, can_path):
261 |         logger.info("Calculating Rouge")
262 | 
263 |         candidates = codecs.open(can_path, encoding="utf-8")
264 |         references = codecs.open(gold_path, encoding="utf-8")
265 |         results_dict = test_rouge(self.args.temp_dir, candidates, references, 1)
266 |         return results_dict
267 | 
268 |     def _gradient_accumulation(self, true_batchs, normalization, total_stats,
269 |                                report_stats):
270 |         if self.grad_accum_count > 1:
271 |             self.model.zero_grad()
272 | 
273 |         for batch in true_batchs:
274 |             src = batch.src
275 | 
276 |             src_lengths = batch.src_length
277 |             labels = batch.labels
278 | 
279 |             if self.grad_accum_count == 1:
280 |                 self.model.zero_grad()
281 | 
282 | 
283 |             if(self.args.structured):
284 |                 roots, mask = self.model(src, labels, src_lengths)
285 |                 loss = 0
286 |                 for r in roots:
287 |                     r = torch.clamp(r, 1e-5, 1 - 1e-5)
288 |                     _loss = self.loss(r, labels)
289 |                     _loss = (_loss * mask.float()).sum()
290 |                     loss += _loss
291 |                 loss = loss / len(roots)
292 |                 (loss /
loss.numel()).backward() 293 | 294 | 295 | else: 296 | sent_scores, mask = self.model(src, labels, src_lengths) 297 | loss = self.loss(sent_scores, labels) 298 | loss = (loss*mask.float()).sum() 299 | (loss/loss.numel()).backward() 300 | # loss.div(float(normalization)).backward() 301 | 302 | batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization) 303 | 304 | 305 | total_stats.update(batch_stats) 306 | report_stats.update(batch_stats) 307 | 308 | # 4. Update the parameters and statistics. 309 | if self.grad_accum_count == 1: 310 | self.optim.step() 311 | 312 | # in case of multi step gradient accumulation, 313 | # update only after accum batches 314 | if self.grad_accum_count > 1: 315 | self.optim.step() 316 | 317 | def _save(self, step): 318 | real_model = self.model 319 | # real_generator = (self.generator.module 320 | # if isinstance(self.generator, torch.nn.DataParallel) 321 | # else self.generator) 322 | 323 | model_state_dict = real_model.state_dict() 324 | # generator_state_dict = real_generator.state_dict() 325 | checkpoint = { 326 | 'model': model_state_dict, 327 | # 'generator': generator_state_dict, 328 | 'opt': self.args, 329 | 'optim': self.optim, 330 | } 331 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step) 332 | logger.info("Saving checkpoint %s" % checkpoint_path) 333 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step) 334 | if (not os.path.exists(checkpoint_path)): 335 | torch.save(checkpoint, checkpoint_path) 336 | return checkpoint, checkpoint_path 337 | 338 | def _start_report_manager(self, start_time=None): 339 | """ 340 | Simple function to start report manager (if any) 341 | """ 342 | if self.report_manager is not None: 343 | if start_time is None: 344 | self.report_manager.start() 345 | else: 346 | self.report_manager.start_time = start_time 347 | 348 | def _maybe_gather_stats(self, stat): 349 | """ 350 | Gather statistics in multi-processes cases 351 | 352 | Args: 353 | stat(:obj:onmt.utils.Statistics): a Statistics object to gather 354 | or None (it returns None in this case) 355 | 356 | Returns: 357 | stat: the updated (or unchanged) stat object 358 | """ 359 | if stat is not None and self.n_gpu > 1: 360 | return Statistics.all_gather_stats(stat) 361 | return stat 362 | 363 | def _maybe_report_training(self, step, num_steps, learning_rate, 364 | report_stats): 365 | """ 366 | Simple function to report training stats (if report_manager is set) 367 | see `onmt.utils.ReportManagerBase.report_training` for doc 368 | """ 369 | if self.report_manager is not None: 370 | return self.report_manager.report_training( 371 | step, num_steps, learning_rate, report_stats, 372 | multigpu=self.n_gpu > 1) 373 | 374 | def _report_step(self, learning_rate, step, train_stats=None, 375 | valid_stats=None): 376 | """ 377 | Simple function to report stats (if report_manager is set) 378 | see `onmt.utils.ReportManagerBase.report_step` for doc 379 | """ 380 | if self.report_manager is not None: 381 | return self.report_manager.report_step( 382 | learning_rate, step, train_stats=train_stats, 383 | valid_stats=valid_stats) 384 | 385 | def _maybe_save(self, step): 386 | """ 387 | Save the model if a model saver is set 388 | """ 389 | if self.model_saver is not None: 390 | self.model_saver.maybe_save(step) 391 | -------------------------------------------------------------------------------- /src/others/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nlpyang/SUMO/f40c52de24381f8b58a90fbc2e57abd93bad56b7/src/others/__init__.py
--------------------------------------------------------------------------------
/src/others/logging.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import absolute_import

import logging

logger = logging.getLogger()


def init_logger(log_file=None, log_file_level=logging.NOTSET):
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)
    logger.handlers = [console_handler]

    if log_file and log_file != '':
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(log_file_level)
        file_handler.setFormatter(log_format)
        logger.addHandler(file_handler)

    return logger
--------------------------------------------------------------------------------
/src/others/pyrouge.py:
--------------------------------------------------------------------------------
from __future__ import print_function, unicode_literals, division

import os
import re
import codecs
import platform

from subprocess import check_output
from tempfile import mkdtemp
from functools import partial

try:
    from configparser import ConfigParser
except ImportError:
    from ConfigParser import ConfigParser

from pyrouge.utils import log
from pyrouge.utils.file_utils import verify_dir


class DirectoryProcessor:

    @staticmethod
    def process(input_dir, output_dir, function):
        """
        Apply function to all files in input_dir and save the resulting
        output files in output_dir.

        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        logger = log.get_global_console_logger()
        logger.info("Processing files in {}.".format(input_dir))
        input_file_names = os.listdir(input_dir)
        for input_file_name in input_file_names:
            input_file = os.path.join(input_dir, input_file_name)
            with codecs.open(input_file, "r", encoding="UTF-8") as f:
                input_string = f.read()
            output_string = function(input_string)
            output_file = os.path.join(output_dir, input_file_name)
            with codecs.open(output_file, "w", encoding="UTF-8") as f:
                f.write(output_string.lower())
        logger.info("Saved processed files to {}.".format(output_dir))
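# Usage sketch (illustration only; this call does not appear anywhere in SUMO):
# apply a string -> string function to every file in a hypothetical `raw/`
# directory, writing the results to `clean/`:
#
#     DirectoryProcessor.process("raw", "clean", lambda text: text.strip())
#
# Note that process() additionally lower-cases whatever the supplied function
# returns before writing it out, as the code above shows.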
class Rouge155(object):
    """
    This is a wrapper for the ROUGE 1.5.5 summary evaluation package.
    This class is designed to simplify the evaluation process by:

        1) Converting summaries into a format ROUGE understands.
        2) Generating the ROUGE configuration file automatically based
           on filename patterns.

    This class can be used within Python like this:

        rouge = Rouge155()
        rouge.system_dir = 'test/systems'
        rouge.model_dir = 'test/models'

        # The system filename pattern should contain one group that
        # matches the document ID.
        rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html'

        # The model filename pattern has '#ID#' as a placeholder for the
        # document ID. If there are multiple model summaries, pyrouge
        # will use the provided regex to automatically match them with
        # the corresponding system summary. Here, [A-Z] matches
        # multiple model summaries for a given #ID#.
        rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html'

        rouge_output = rouge.evaluate()
        print(rouge_output)
        output_dict = rouge.output_to_dict(rouge_output)
        print(output_dict)
        ->    {'rouge_1_f_score': 0.95652,
               'rouge_1_f_score_cb': 0.95652,
               'rouge_1_f_score_ce': 0.95652,
               'rouge_1_precision': 0.95652,
               [...]


    To evaluate multiple systems:

        rouge = Rouge155()
        rouge.system_dir = '/PATH/TO/systems'
        rouge.model_dir = 'PATH/TO/models'
        for system_id in ['id1', 'id2', 'id3']:
            rouge.system_filename_pattern = \
                'SL.P.10.R.{}.SL062003-(\d+).html'.format(system_id)
            rouge.model_filename_pattern = \
                'SL.P.10.R.[A-Z].SL062003-#ID#.html'
            rouge_output = rouge.evaluate(system_id)
            print(rouge_output)

    """

    def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None):
        """
        Create a Rouge155 object.

            rouge_dir:  Directory containing Rouge-1.5.5.pl
            rouge_args: Arguments to pass through to ROUGE if you
                        don't want to use the default pyrouge
                        arguments.

        """
        self.temp_dir = temp_dir
        self.log = log.get_global_console_logger()
        self.__set_dir_properties()
        self._config_file = None
        self._settings_file = self.__get_config_path()
        self.__set_rouge_dir(rouge_dir)
        self.args = self.__clean_rouge_args(rouge_args)
        self._system_filename_pattern = None
        self._model_filename_pattern = None

    def save_home_dir(self):
        config = ConfigParser()
        section = 'pyrouge settings'
        config.add_section(section)
        config.set(section, 'home_dir', self._home_dir)
        with open(self._settings_file, 'w') as f:
            config.write(f)
        self.log.info("Set ROUGE home directory to {}.".format(self._home_dir))
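    # For reference, the settings file written above is a standard ConfigParser
    # INI file (its path comes from __get_config_path(), defined with the other
    # private helpers below). The home_dir value shown is a placeholder, not a
    # path shipped with SUMO:
    #
    #     [pyrouge settings]
    #     home_dir = /path/to/ROUGE-1.5.5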
"SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model 175 | filenames in the SPL2003/system folder of the ROUGE SPL 176 | example in the "sample-test" folder. 177 | 178 | "#ID#" is a placeholder for the document ID which has been 179 | matched by the "(\d+)" part of the system filename pattern. 180 | The different model summaries for a given document ID are 181 | matched by the "[A-Z]" part. 182 | 183 | """ 184 | return self._model_filename_pattern 185 | 186 | @model_filename_pattern.setter 187 | def model_filename_pattern(self, pattern): 188 | self._model_filename_pattern = pattern 189 | 190 | @property 191 | def config_file(self): 192 | return self._config_file 193 | 194 | @config_file.setter 195 | def config_file(self, path): 196 | config_dir, _ = os.path.split(path) 197 | verify_dir(config_dir, "configuration file") 198 | self._config_file = path 199 | 200 | def split_sentences(self): 201 | """ 202 | ROUGE requires texts split into sentences. In case the texts 203 | are not already split, this method can be used. 204 | 205 | """ 206 | from pyrouge.utils.sentence_splitter import PunktSentenceSplitter 207 | self.log.info("Splitting sentences.") 208 | ss = PunktSentenceSplitter() 209 | sent_split_to_string = lambda s: "\n".join(ss.split(s)) 210 | process_func = partial( 211 | DirectoryProcessor.process, function=sent_split_to_string) 212 | self.__process_summaries(process_func) 213 | 214 | @staticmethod 215 | def convert_summaries_to_rouge_format(input_dir, output_dir): 216 | """ 217 | Convert all files in input_dir into a format ROUGE understands 218 | and saves the files to output_dir. The input files are assumed 219 | to be plain text with one sentence per line. 220 | 221 | input_dir: Path of directory containing the input files. 222 | output_dir: Path of directory in which the converted files 223 | will be saved. 224 | 225 | """ 226 | DirectoryProcessor.process( 227 | input_dir, output_dir, Rouge155.convert_text_to_rouge_format) 228 | 229 | @staticmethod 230 | def convert_text_to_rouge_format(text, title="dummy title"): 231 | """ 232 | Convert a text to a format ROUGE understands. The text is 233 | assumed to contain one sentence per line. 234 | 235 | text: The text to convert, containg one sentence per line. 236 | title: Optional title for the text. The title will appear 237 | in the converted file, but doesn't seem to have 238 | any other relevance. 239 | 240 | Returns: The converted text as string. 241 | 242 | """ 243 | # sentences = text.split("\n") 244 | sentences = text.split("") 245 | sent_elems = [ 246 | "[{i}] " 247 | "{text}".format(i=i, text=sent) 248 | for i, sent in enumerate(sentences, start=1)] 249 | html = """ 250 | 251 | {title} 252 | 253 | 254 | {elems} 255 | 256 | """.format(title=title, elems="\n".join(sent_elems)) 257 | 258 | return html 259 | 260 | @staticmethod 261 | def write_config_static(system_dir, system_filename_pattern, 262 | model_dir, model_filename_pattern, 263 | config_file_path, system_id=None): 264 | """ 265 | Write the ROUGE configuration file, which is basically a list 266 | of system summary files and their corresponding model summary 267 | files. 268 | 269 | pyrouge uses regular expressions to automatically find the 270 | matching model summary files for a given system summary file 271 | (cf. docstrings for system_filename_pattern and 272 | model_filename_pattern). 273 | 274 | system_dir: Path of directory containing 275 | system summaries. 276 | system_filename_pattern: Regex string for matching 277 | system summary filenames. 
    @staticmethod
    def write_config_static(system_dir, system_filename_pattern,
                            model_dir, model_filename_pattern,
                            config_file_path, system_id=None):
        """
        Write the ROUGE configuration file, which is basically a list
        of system summary files and their corresponding model summary
        files.

        pyrouge uses regular expressions to automatically find the
        matching model summary files for a given system summary file
        (cf. docstrings for system_filename_pattern and
        model_filename_pattern).

            system_dir:              Path of directory containing
                                     system summaries.
            system_filename_pattern: Regex string for matching
                                     system summary filenames.
            model_dir:               Path of directory containing
                                     model summaries.
            model_filename_pattern:  Regex string for matching model
                                     summary filenames.
            config_file_path:        Path of the configuration file.
            system_id:               Optional system ID string which
                                     will appear in the ROUGE output.

        """
        system_filenames = [f for f in os.listdir(system_dir)]
        system_models_tuples = []

        system_filename_pattern = re.compile(system_filename_pattern)
        for system_filename in sorted(system_filenames):
            match = system_filename_pattern.match(system_filename)
            if match:
                id = match.groups(0)[0]
                model_filenames = [model_filename_pattern.replace('#ID#', id)]
                # model_filenames = Rouge155.__get_model_filenames_for_id(
                #     id, model_dir, model_filename_pattern)
                system_models_tuples.append(
                    (system_filename, sorted(model_filenames)))
        if not system_models_tuples:
            raise Exception(
                "Did not find any files matching the pattern {} "
                "in the system summaries directory {}.".format(
                    system_filename_pattern.pattern, system_dir))

        with codecs.open(config_file_path, 'w', encoding='utf-8') as f:
            f.write('<ROUGE-EVAL version="1.55">')
            for task_id, (system_filename, model_filenames) in enumerate(
                    system_models_tuples, start=1):
                eval_string = Rouge155.__get_eval_string(
                    task_id, system_id,
                    system_dir, system_filename,
                    model_dir, model_filenames)
                f.write(eval_string)
            f.write("</ROUGE-EVAL>")

    def write_config(self, config_file_path=None, system_id=None):
        """
        Write the ROUGE configuration file, which is basically a list
        of system summary files and their matching model summary files.

        This is a non-static version of write_config_static().

            config_file_path: Path of the configuration file.
            system_id:        Optional system ID string which will
                              appear in the ROUGE output.

        """
        if not system_id:
            system_id = 1
        if (not config_file_path) or (not self._config_dir):
            self._config_dir = mkdtemp(dir=self.temp_dir)
            config_filename = "rouge_conf.xml"
        else:
            config_dir, config_filename = os.path.split(config_file_path)
            verify_dir(config_dir, "configuration file")
        self._config_file = os.path.join(self._config_dir, config_filename)
        Rouge155.write_config_static(
            self._system_dir, self._system_filename_pattern,
            self._model_dir, self._model_filename_pattern,
            self._config_file, system_id)
        self.log.info(
            "Written ROUGE configuration to {}".format(self._config_file))
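    # Shape of the generated configuration file (a sketch assembled from the
    # <ROUGE-EVAL> writes above and __get_eval_string below; the paths are
    # placeholders, shown only for orientation):
    #
    #     <ROUGE-EVAL version="1.55">
    #     <EVAL ID="1">
    #         <MODEL-ROOT>/path/to/models</MODEL-ROOT>
    #         <PEER-ROOT>/path/to/systems</PEER-ROOT>
    #         ...
    #     </EVAL>
    #     </ROUGE-EVAL>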
    def evaluate(self, system_id=1, rouge_args=None):
        """
        Run ROUGE to evaluate the system summaries in system_dir against
        the model summaries in model_dir. The summaries are assumed to
        be in the one-sentence-per-line HTML format ROUGE understands.

            system_id: Optional system ID which will be printed in
                       ROUGE's output.

        Returns: Rouge output as string.

        """
        self.write_config(system_id=system_id)
        options = self.__get_options(rouge_args)
        command = [self._bin_path] + options
        self.log.info(
            "Running ROUGE with command {}".format(" ".join(command)))
        rouge_output = check_output(command).decode("UTF-8")
        return rouge_output

    def convert_and_evaluate(self, system_id=1,
                             split_sentences=False, rouge_args=None):
        """
        Convert plain text summaries to ROUGE format and run ROUGE to
        evaluate the system summaries in system_dir against the model
        summaries in model_dir. Optionally split texts into sentences
        in case they aren't already.

        This is just a convenience method combining
        convert_summaries_to_rouge_format() and evaluate().

            split_sentences: Optional argument specifying if
                             sentences should be split.
            system_id:       Optional system ID which will be printed
                             in ROUGE's output.

        Returns: ROUGE output as string.

        """
        if split_sentences:
            self.split_sentences()
        self.__write_summaries()
        rouge_output = self.evaluate(system_id, rouge_args)
        return rouge_output

    def output_to_dict(self, output):
        """
        Convert the ROUGE output into python dictionary for further
        processing.

        """
        # 0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632)
        pattern = re.compile(
            r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) "
            r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)")
        results = {}
        for line in output.split("\n"):
            match = pattern.match(line)
            if match:
                sys_id, rouge_type, measure, result, conf_begin, conf_end = \
                    match.groups()
                measure = {
                    'Average_R': 'recall',
                    'Average_P': 'precision',
                    'Average_F': 'f_score'
                }[measure]
                rouge_type = rouge_type.lower().replace("-", '_')
                key = "{}_{}".format(rouge_type, measure)
                results[key] = float(result)
                results["{}_cb".format(key)] = float(conf_begin)
                results["{}_ce".format(key)] = float(conf_end)
        return results
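    # Worked example for output_to_dict (follows directly from the regex and
    # the measure mapping above): the ROUGE output line
    #
    #     1 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632)
    #
    # yields the entries
    #
    #     {'rouge_1_recall': 0.02632,
    #      'rouge_1_recall_cb': 0.02632,
    #      'rouge_1_recall_ce': 0.02632}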
    ###################################################################
    # Private methods

    def __set_rouge_dir(self, home_dir=None):
        """
        Verify presence of ROUGE-1.5.5.pl and data folder, and set
        those paths.

        """
        if not home_dir:
            self._home_dir = self.__get_rouge_home_dir_from_settings()
        else:
            self._home_dir = home_dir
            self.save_home_dir()
        self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl')
        self.data_dir = os.path.join(self._home_dir, 'data')
        if not os.path.exists(self._bin_path):
            raise Exception(
                "ROUGE binary not found at {}. Please set the "
                "correct path by running pyrouge_set_rouge_path "
                "/path/to/rouge/home.".format(self._bin_path))

    def __get_rouge_home_dir_from_settings(self):
        config = ConfigParser()
        with open(self._settings_file) as f:
            if hasattr(config, "read_file"):
                config.read_file(f)
            else:
                # use deprecated python 2.x method
                config.readfp(f)
        rouge_home_dir = config.get('pyrouge settings', 'home_dir')
        return rouge_home_dir

    @staticmethod
    def __get_eval_string(
            task_id, system_id,
            system_dir, system_filename,
            model_dir, model_filenames):
        """
        ROUGE can evaluate several system summaries for a given text
        against several model summaries, i.e. there is an m-to-n
        relation between system and model summaries. The system
        summaries are listed in the <PEERS> tag and the model summaries
        in the <MODELS> tag. pyrouge currently only supports one system
        summary per text, i.e. it assumes a 1-to-n relation between
        system and model summaries.

        """
        peer_elems = "<P ID=\"{id}\">{name}</P>".format(
            id=system_id, name=system_filename)

        model_elems = ["<M ID=\"{id}\">{name}</M>".format(
            id=chr(65 + i), name=name)
            for i, name in enumerate(model_filenames)]

        model_elems = "\n\t\t\t".join(model_elems)
        eval_string = """
<EVAL ID="{task_id}">
    <MODEL-ROOT>{model_root}</MODEL-ROOT>
    <PEER-ROOT>{peer_root}</PEER-ROOT>
    <INPUT-FORMAT TYPE="SEE">
    </INPUT-FORMAT>
    <PEERS>
        {peer_elems}
    </PEERS>
    <MODELS>
        {model_elems}
    </MODELS>
</EVAL>
""".format(
            task_id=task_id,
            model_root=model_dir, model_elems=model_elems,
            peer_root=system_dir, peer_elems=peer_elems)
        return eval_string

    def __process_summaries(self, process_func):
        """
        Helper method that applies process_func to the files in the
        system and model folders and saves the resulting files to new
        system and model folders.

        """
        temp_dir = mkdtemp(dir=self.temp_dir)
        new_system_dir = os.path.join(temp_dir, "system")
        os.mkdir(new_system_dir)
        new_model_dir = os.path.join(temp_dir, "model")
        os.mkdir(new_model_dir)
        self.log.info(
            "Processing summaries. Saving system files to {} and "
            "model files to {}.".format(new_system_dir, new_model_dir))
        process_func(self._system_dir, new_system_dir)
        process_func(self._model_dir, new_model_dir)
        self._system_dir = new_system_dir
        self._model_dir = new_model_dir

    def __write_summaries(self):
        self.log.info("Writing summaries.")
        self.__process_summaries(self.convert_summaries_to_rouge_format)

    @staticmethod
    def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern):
        pattern = re.compile(model_filenames_pattern.replace('#ID#', id))
        model_filenames = [
            f for f in os.listdir(model_dir) if pattern.match(f)]
        if not model_filenames:
            raise Exception(
                "Could not find any model summaries for the system"
                " summary with ID {}. Specified model filename pattern was: "
                "{}".format(id, model_filenames_pattern))
        return model_filenames

    def __get_options(self, rouge_args=None):
        """
        Get supplied command line arguments for ROUGE or use default
        ones.

        """
        if self.args:
            options = self.args.split()
        elif rouge_args:
            options = rouge_args.split()
        else:
            options = [
                '-e', self._data_dir,
                '-c', 95,
                # '-2',
                # '-1',
                # '-U',
                '-m',
                # '-v',
                '-r', 1000,
                '-n', 2,
                # '-w', 1.2,
                '-a',
            ]
        options = list(map(str, options))

        options = self.__add_config_option(options)
        return options

    def __create_dir_property(self, dir_name, docstring):
        """
        Generate getter and setter for a directory property.

        """
        property_name = "{}_dir".format(dir_name)
        private_name = "_" + property_name
        setattr(self, private_name, None)

        def fget(self):
            return getattr(self, private_name)

        def fset(self, path):
            verify_dir(path, dir_name)
            setattr(self, private_name, path)

        p = property(fget=fget, fset=fset, doc=docstring)
        setattr(self.__class__, property_name, p)

    def __set_dir_properties(self):
        """
        Automatically generate the properties for directories.

        """
        directories = [
            ("home", "The ROUGE home directory."),
            ("data", "The path of the ROUGE 'data' directory."),
            ("system", "Path of the directory containing system summaries."),
            ("model", "Path of the directory containing model summaries."),
        ]
        for (dirname, docstring) in directories:
            self.__create_dir_property(dirname, docstring)

    def __clean_rouge_args(self, rouge_args):
        """
        Remove enclosing quotation marks, if any.

        """
        if not rouge_args:
            return
        quot_mark_pattern = re.compile('"(.+)"')
        match = quot_mark_pattern.match(rouge_args)
        if match:
            cleaned_args = match.group(1)
            return cleaned_args
        else:
            return rouge_args

    def __add_config_option(self, options):
        return options + [self._config_file]

    def __get_config_path(self):
        if platform.system() == "Windows":
            parent_dir = os.getenv("APPDATA")
            config_dir_name = "pyrouge"
        elif os.name == "posix":
            parent_dir = os.path.expanduser("~")
            config_dir_name = ".pyrouge"
        else:
            parent_dir = os.path.dirname(__file__)
            config_dir_name = ""
        config_dir = os.path.join(parent_dir, config_dir_name)
        if not os.path.exists(config_dir):
            os.makedirs(config_dir)
        return os.path.join(config_dir, 'settings.ini')


if __name__ == "__main__":
    import argparse
    from utils.argparsers import rouge_path_parser

    parser = argparse.ArgumentParser(parents=[rouge_path_parser])
    args = parser.parse_args()

    rouge = Rouge155(args.rouge_home)
    rouge.save_home_dir()
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
Main training workflow
"""
from __future__ import division

import argparse
import glob
import os
import time

import sentencepiece

from models import data_loader, model_builder
from models.data_loader import load_dataset
from models.model_builder import Summarizer
from models.trainer import build_trainer
from others.logging import logger, init_logger
import torch
import random

model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'local_layers', 'inter_layers', 'structured']


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def wait_and_validate(args, device_id):
    timestep = 0
    if (args.test_all):
        cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
        cp_files.sort(key=os.path.getmtime)
        xent_lst = []
        for i, cp in enumerate(cp_files):
            step = int(cp.split('.')[-2].split('_')[-1])
            xent = validate(args, device_id, cp, step)
            xent_lst.append((xent, cp))
            max_step = xent_lst.index(min(xent_lst))
            if (i - max_step > 10):
                break
        xent_lst = sorted(xent_lst, key=lambda x: x[0])[:5]
        logger.info('PPL %s' % str(xent_lst))
        for xent, cp in xent_lst:
            step = int(cp.split('.')[-2].split('_')[-1])
            test(args, device_id, cp, step)
    else:
        while (True):
            cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
            cp_files.sort(key=os.path.getmtime)
            if (cp_files):
                cp = cp_files[-1]
                time_of_cp = os.path.getmtime(cp)
                if (not os.path.getsize(cp) > 0):
                    time.sleep(60)
                    continue
                if (time_of_cp > timestep):
                    timestep = time_of_cp
                    step = int(cp.split('.')[-2].split('_')[-1])
                    validate(args, device_id, cp, step)
                    test(args, device_id, cp, step)

            cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
            cp_files.sort(key=os.path.getmtime)
            if (cp_files):
                cp = cp_files[-1]
                time_of_cp = os.path.getmtime(cp)
                if (time_of_cp > timestep):
                    continue
            else:
                time.sleep(300)
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpu == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    vocab_size = len(spm)
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False), {'PAD': word_padding_idx},
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter)
    trainer._report_step(0, step, valid_stats=stats)
    return stats.xent()


def test(args, device_id, pt, step):
    device = "cpu" if args.visible_gpu == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    vocab_size = len(spm)
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False), {'PAD': word_padding_idx},
                                       args.batch_size, device,
                                       shuffle=False, is_test=True)
    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
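# Note on the checkpoint format (grounded in trainer._save above): checkpoints
# are plain dicts of the form
#
#     {'model': model.state_dict(), 'opt': args, 'optim': optimizer}
#
# which is why validate() and test() recover the training-time flags with
# vars(checkpoint['opt']) before rebuilding the Summarizer.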
def train(args, device_id):
    init_logger(args.log_file)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    vocab_size = len(spm)

    def train_iter_fct():
        # return data_loader.AbstractiveDataloader(load_dataset('train', True), symbols, FLAGS.batch_size, device, True)
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), {'PAD': word_padding_idx},
                                      args.batch_size, device,
                                      shuffle=True, is_test=False)

    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("-mode", default='train', type=str)
    parser.add_argument("-onmt_path", default='../data/onmt_data/cnndm')
    parser.add_argument("-data_path", default='../data/')
    parser.add_argument("-raw_path", default='../line_data')
    parser.add_argument("-vocab_path", default='../data/spm.cnndm.model')
    parser.add_argument("-model_path", default='../models/')
    parser.add_argument("-result_path", default='../results/cnndm')
    parser.add_argument("-temp_dir", default='../temp')

    parser.add_argument("-batch_size", default=10000, type=int)
    parser.add_argument('-min_nsents', default=3, type=int)
    parser.add_argument('-max_nsents', default=100, type=int)
    parser.add_argument('-min_src_ntokens', default=5, type=int)
    parser.add_argument('-max_src_ntokens', default=200, type=int)

    parser.add_argument("-structured", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("-hidden_size", default=128, type=int)
    parser.add_argument("-ff_size", default=512, type=int)
    parser.add_argument("-heads", default=8, type=int)
    parser.add_argument("-emb_size", default=128, type=int)
    parser.add_argument("-local_layers", default=5, type=int)
    parser.add_argument("-inter_layers", default=2, type=int)

    parser.add_argument("-dropout", default=0.2, type=float)
    parser.add_argument("-optim", default='adam', type=str)
    parser.add_argument("-lr", default=0.15, type=float)
    parser.add_argument("-beta1", default=0.9, type=float)
    parser.add_argument("-beta2", default=0.999, type=float)
    parser.add_argument("-decay_method", default='', type=str)
    parser.add_argument("-warmup_steps", default=8000, type=int)
    parser.add_argument("-max_grad_norm", default=0, type=float)

    parser.add_argument("-save_checkpoint_steps", default=5, type=int)
    parser.add_argument("-accum_count", default=1, type=int)
    parser.add_argument("-report_every", default=10, type=int)
    parser.add_argument("-train_steps", default=1000, type=int)

    parser.add_argument('-visible_gpu', default='-1', type=str)
    parser.add_argument('-gpu_ranks', default=[0], type=list)
    parser.add_argument('-log_file', default='../logs/cnndm.log')
    parser.add_argument('-dataset', default='')
    parser.add_argument('-seed', default=666, type=int)

    parser.add_argument("-test_all", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("-train_from", default='')
    parser.add_argument("-test_from", default='')
    parser.add_argument("-report_rouge", type=str2bool, nargs='?', const=True, default=True)

    args = parser.parse_args()
    print(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu

    init_logger(args.log_file)
    device = "cpu" if args.visible_gpu == '-1' else "cuda"
    device_id = 0 if device == "cuda" else -1

    if (args.mode == 'train'):
        train(args, device_id)
    elif (args.mode == 'validate'):
        wait_and_validate(args, device_id)
    elif (args.mode == 'test'):
        test(args, device_id, args.test_from, 0)
--------------------------------------------------------------------------------
/temp/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------