# ================================================================================
# Repository dump: directory structure
# ================================================================================
# ├── .idea
# │   ├── .gitignore
# │   ├── ds.iml
# │   ├── inspectionProfiles
# │   │   └── profiles_settings.xml
# │   └── modules.xml
# ├── ds.py
# ├── generate
# │   ├── __init__.py
# │   ├── debug_outputs.py
# │   ├── generate.py
# │   └── utils.py
# ├── model_configs
# │   ├── ds.full.yaml
# │   └── lm_prior.yaml
# ├── models
# │   ├── __init__.py
# │   ├── ds_losses.py
# │   ├── ds_trainer.py
# │   ├── ds_utils.py
# │   └── sent_lm_trainer.py
# ├── modules
# │   ├── __init__.py
# │   ├── data
# │   │   ├── __init__.py
# │   │   ├── collates.py
# │   │   ├── datasets.py
# │   │   ├── datasets_ds.py
# │   │   ├── samplers.py
# │   │   ├── utils.py
# │   │   └── vocab.py
# │   ├── helpers.py
# │   ├── layers.py
# │   ├── models.py
# │   ├── modules.py
# │   └── training
# │       ├── __init__.py
# │       ├── base_trainer.py
# │       └── trainer.py
# ├── mylogger
# │   ├── __init__.py
# │   ├── attention.py
# │   ├── db.json
# │   ├── experiment.py
# │   ├── helpers.py
# │   ├── inspection.py
# │   └── plotting.py
# ├── rouge-test.py
# ├── sys_config.py
# └── utils
#     ├── __init__.py
#     ├── _logging.py
#     ├── config.py
#     ├── data_parsing.py
#     ├── eval.py
#     ├── generic.py
#     ├── load_embeddings.py
#     ├── opts.py
#     ├── training.py
#     ├── transfer.py
#     └── viz.py
# --------------------------------------------------------------------------------
# /.idea/.gitignore:
# --------------------------------------------------------------------------------
#   # Default ignored files
#   /shelf/
#   /workspace.xml
# --------------------------------------------------------------------------------
# /.idea/ds.iml: (12 lines; file body not present in this dump)
# --------------------------------------------------------------------------------
# /.idea/inspectionProfiles/profiles_settings.xml: (6 lines; body not present)
# --------------------------------------------------------------------------------
# /.idea/modules.xml: (8 lines; file body not present in this dump)
# --------------------------------------------------------------------------------
# /ds.py:
# --------------------------------------------------------------------------------
# Training script for the Seq2Seq2Seq summarization model: loads the config,
# an optional prior LM and the datasets, builds the model/optimizer/trainer
# and runs the epoch loop with logging, ROUGE evaluation and checkpointing.
import itertools
import math
import os
import warnings

import numpy
import codecs
import torch
from tabulate import tabulate
from torch import nn
from torch.distributions import Categorical
from torch.utils.data import DataLoader
from generate.utils import devectorize, devectorize_generate
from models.ds_trainer import DsTrainer
from models.ds_utils import compute_dataset_idf
from modules.data.collates import Seq2SeqCollate, Seq2SeqOOVCollate
from modules.data.datasets_ds import DsDataset
from modules.data.samplers import BucketBatchSampler
from modules.models import Seq2Seq2Seq
from modules.modules import SeqReader, transfer_weigths
from mylogger.attention import samples2html
from mylogger.experiment import Experiment
from sys_config import EXP_DIR, EMBS_PATH, MODEL_CNF_DIR
from utils.eval import rouge_file_list, pprint_rouge_scores, rouge_files, rouge_files_simple
from utils.generic import number_h
from utils.opts import seq2seq2seq_options
from utils.training import load_checkpoint
from utils.transfer import freeze_module
from tensorboardX import SummaryWriter

####################################################################
# Settings
####################################################################
opts, config = seq2seq2seq_options()
# Make sure the directory that receives the ROUGE result files exists.
# exist_ok=True replaces the racy "exists() then makedirs()" pair.
os.makedirs(config["data"]["result_path"], exist_ok=True)
vocab = None
writer = SummaryWriter(config["tensorboard_path"])

# Optional checkpoint of the full model; its weights are loaded into `model`
# right before the trainer is built.
if config["main_model"] is not None:
    main_model_checkpoint = load_checkpoint(config["main_model"])


####################################################################
# Weight Transfer (pre-train language model)
####################################################################
if config["model"]["prior_loss"] and config["prior"] is not None:
    print("Loading Oracle LM ...")
    oracle_cp = load_checkpoint(config["prior"])
    # The LM's vocabulary is handed to the datasets below, so the model is
    # built over the same token ids as the frozen prior.
    vocab = oracle_cp["vocab"]

    oracle = SeqReader(len(vocab), **oracle_cp["config"]["model"])
    oracle.load_state_dict(oracle_cp["model"])
    oracle.to(opts.device)
    freeze_module(oracle)
else:
    oracle = None


####################################################################
# Data Loading and Preprocessing
####################################################################
print("Building training dataset...")
train_data = DsDataset(config["data"]["train_path"],
                       batch_size=config["batch_size"],
                       sent_num=config["data"]["sent_num"],
                       sent_len=config["data"]["sent_len"],
                       mode="train",
                       sent_sim_len=config["data"]["sent_sim_len"],
                       k=config["model"]["k"],
                       vocab=vocab,
                       vocab_size=config["vocab"]["size"],
                       seq_len=config["data"]["seq_len"],
                       dec_seq_len=config["data"]["dec_seq_len"],
                       oovs=config["data"]["oovs"])


print("Building validation dataset...")
val_data = DsDataset(config["data"]["val_path"],
                     batch_size=config["batch_size"],
                     sent_num=config["data"]["sent_num"],
                     sent_len=config["data"]["sent_len"],
                     mode="test",
                     k=config["model"]["k"],
                     sent_sim_len=config["data"]["sent_sim_len"],
                     summary_path=config["data"]["ref_path"],
                     nsent_path=config["data"]["ref_nsent_path"],
                     org_dia_path=config["data"]["org_dia_path"],
                     vocab=vocab,
                     vocab_size=config["vocab"]["size"],
                     seq_len=config["data"]["seq_len"],
                     dec_seq_len=config["data"]["dec_seq_len"],
                     return_oov=True,
                     oovs=config["data"]["oovs"])

# The validation set must use the training vocabulary so ids mean the same thing.
val_data.vocab = train_data.vocab
vocab = train_data.vocab

# define a dataloader, which handles the way a dataset will be loaded,
# like batching, shuffling and so on ...
train_lengths = [len(x) for x in train_data.data]

train_sampler = BucketBatchSampler(train_lengths, config["batch_size"])
train_loader = DataLoader(train_data, batch_sampler=train_sampler,
                          num_workers=config["num_workers"],
                          collate_fn=Seq2SeqCollate(config["data"]["sent_sim_len"],
                                                    config["data"]["sent_num"],
                                                    config["data"]["sent_len"]))
val_loader = DataLoader(val_data, batch_size=config["batch_size"],
                        num_workers=config["num_workers"], shuffle=False,
                        collate_fn=Seq2SeqOOVCollate(config["data"]["sent_sim_len"],
                                                     config["data"]["sent_num"],
                                                     config["data"]["sent_len"]))

####################################################################
# Model Definition
# - additional layer initializations
# - weight / layer tying
####################################################################

# Define the model
n_tokens = len(train_data.vocab)
model = Seq2Seq2Seq(n_tokens, **config["model"])
# Targets with id 0 are ignored by the loss (presumably the padding id -- confirm in vocab).
criterion = nn.CrossEntropyLoss(ignore_index=0)


def word_embedding(model):
    """Initialize ``model``'s embeddings from a pretrained embeddings file.

    Does nothing unless ``config["vocab"]["embeddings"]`` is set. Otherwise the
    file under ``EMBS_PATH`` is read through the training vocabulary, the input
    embeddings are initialized with it (trainable per ``embed_trainable``), and
    ``compressor.Wo`` / ``decompressor.Wo`` are seeded with the same matrix.

    Args:
        model: the ``Seq2Seq2Seq`` instance, modified in place.
    """
    if "embeddings" in config["vocab"] and config["vocab"]["embeddings"]:
        emb_file = os.path.join(EMBS_PATH, config["vocab"]["embeddings"])
        dims = config["vocab"]["embeddings_dim"]

        embs, emb_mask, missing = train_data.vocab.read_embeddings(emb_file, dims)
        model.initialize_embeddings(embs, config["model"]["embed_trainable"])

        # initialize the output layers with the pretrained embeddings,
        # regardless of whether they will be tied
        try:
            model.compressor.Wo.weight.data.copy_(torch.from_numpy(embs))
            model.decompressor.Wo.weight.data.copy_(torch.from_numpy(embs))
        except RuntimeError:
            # Tensor.copy_ raises RuntimeError when the shapes do not match
            # (output layer size != embedding size). A bare ``except`` here
            # used to hide unrelated errors and even KeyboardInterrupt.
            print("Can't init outputs from embeddings. Dim mismatch!")

        if config["model"]["embed_masked"] and config["model"]["embed_trainable"]:
            model.set_embedding_gradient_mask(emb_mask)


def topic_pre(model):
    """Initialize IDF weights for the topic loss (only if ``topic_loss`` and ``topic_idf``).

    IDF values are computed over the training set; the padding token's IDF is
    set to 1 to neutralize it.
    """
    if config["model"]["topic_loss"] and config["model"]["topic_idf"]:
        print("Computing IDF values...")
        idf = compute_dataset_idf(train_data, train_data.vocab.tok2id)
        # idf[vocab.tok2id[vocab.SOS]] = 1  # neutralize padding token
        # idf[vocab.tok2id[vocab.EOS]] = 1  # neutralize padding token
        idf[vocab.tok2id[vocab.PAD]] = 1  # neutralize padding token
        model.initialize_embeddings_idf(idf)


def tie_models(model):
    """Tie encoder/decoder parts of ``model`` according to the config flags.

    * ``tie_embedding``: every sub-module shares ``inp_encoder``'s embedding.
    * ``tie_embedding_outputs``: the decoders' output projections share the
      input embedding matrix (skipped with a warning if the sizes differ).
    * ``tie_decoders`` / ``tie_encoders``: ``transfer_weigths`` is applied
      between the listed modules.
    * both ``tie_encoders`` and ``tie_decoders``: a single bridge is shared.
    """

    # tie the embedding layers
    if config["model"]["tie_embedding"]:
        model.cmp_encoder.embed = model.inp_encoder.embed
        model.compressor.embed = model.inp_encoder.embed
        model.decompressor.embed = model.inp_encoder.embed
        model.original_task.embed = model.inp_encoder.embed

    # tie the output layers of the decoders (disabled):
    # if config["model"]["tie_decoder_outputs"]:
    #     model.compressor.Wo = model.decompressor.Wo

    # tie the embedding to the output layers
    if config["model"]["tie_embedding_outputs"]:
        emb_size = model.compressor.embed.embedding.weight.size(1)
        rnn_size = model.compressor.Wo.weight.size(1)

        if emb_size != rnn_size:
            warnings.warn("Can't tie outputs, since emb_size != rnn_size.")
        else:
            model.compressor.Wo.weight = model.inp_encoder.embed.embedding.weight
            model.decompressor.Wo.weight = model.inp_encoder.embed.embedding.weight
            model.original_task.Wo.weight = model.inp_encoder.embed.embedding.weight

    if config["model"]["tie_decoders"]:
        # model.compressor = model.decompressor
        # model.decompressor = model.original_task
        transfer_weigths(model.decompressor, model.original_task)
        transfer_weigths(model.compressor, model.original_task)

    if config["model"]["tie_encoders"]:
        # model.cmp_encoder = model.inp_encoder
        transfer_weigths(model.cmp_encoder, model.inp_encoder)

    # then we need only one bridge
    if config["model"]["tie_encoders"] and config["model"]["tie_decoders"]:
        model.src_bridge = model.trg_bridge


word_embedding(model)
# topic_pre(model)
tie_models(model)

####################################################################
# Optimizer and parameter summary
####################################################################
# Only parameters with requires_grad=True are handed to the optimizer.
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters,
                             lr=config["lr"],
                             weight_decay=config["weight_decay"])

model.to(opts.device)
print(model)

total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Total Params:", number_h(total_params))
print("Total Trainable Params:", number_h(total_trainable_params))
trainable_params = sorted([[n] for n, p in model.named_parameters() if p.requires_grad])


####################################################################
# Experiment Logging and Visualization
####################################################################
def exp_log_visual(model):
    """Create the ``Experiment`` logger and register every logged metric/value.

    Side effect: when a prior LM is configured, its name is appended to
    ``opts.name`` (used later for checkpoint and sample file names).

    Args:
        model: unused; kept for interface compatibility.

    Returns:
        tuple: ``(exp, step_tags)`` -- the experiment, and one tag per enabled
        loss in the same order (and under the same conditions) as the loss
        weights passed to the trainer.
    """
    if config["prior"] is not None:
        opts.name += "_" + config["prior"]

    exp = Experiment(opts.name, config, src_dirs=opts.source, output_dir=EXP_DIR)

    # Order matters: it must mirror the order of the loss weights.
    step_tags = ["NSENT"]

    if config["model"]["n_sent_sum_loss"]:
        step_tags.append("NSENTSUM")
    if config["model"]["prior_loss"] and config["prior"] is not None:
        step_tags.append("PRIOR")
    if config["model"]["topic_loss"]:
        step_tags.append("TOPIC")
    if config["model"]["length_loss"]:
        step_tags.append("LENGTH")
    if config["model"]["doc_sum_kl_loss"]:
        step_tags.append("DOCSUMKL")
    if config["model"]["doc_sum_sim_loss"]:
        step_tags.append("DOCSUMSIM")
    if config["model"]["sum_loss"]:
        step_tags.append("SUMMARY")
    if config["model"]["nsent_classification"]:
        step_tags.append("CLS")
    if config["model"]["nsent_classification_sum"]:
        step_tags.append("CLSSUM")
    if config["model"]["nsent_classification_kl"]:
        step_tags.append("CLSKL")

    exp.add_metric("loss", "line", tags=step_tags)
    exp.add_metric("ppl", "line", title="perplexity", tags=step_tags)
    exp.add_metric("rouge", "line", title="ROUGE (F1)", tags=["R-1", "R-2", "R-L"])
    exp.add_value("grads", "text", title="gradients")

    # Compressor gradient norms are plotted for every enabled loss except
    # LENGTH. Filtering by membership (instead of the former prefix slice
    # step_tags[:n + 1]) keeps the right tags when LENGTH is enabled together
    # with losses that come after it.
    norm_tags = {"NSENT", "NSENTSUM", "PRIOR", "TOPIC", "DOCSUMKL", "DOCSUMSIM",
                 "SUMMARY", "CLS", "CLSSUM", "CLSKL"}
    exp.add_metric("c_norm", "line", title="Compressor Grad Norms",
                   tags=[tag for tag in step_tags if tag in norm_tags])
    exp.add_value("progress", "text", title="training progress")
    exp.add_value("epoch", "text", title="epoch summary")
    exp.add_value("samples", "text", title="Samples")
    exp.get_value("samples").pre = False
    exp.add_value("weights", "text")
    exp.add_value("rouge-stats", "text")  # registered once (was duplicated)
    exp.add_value("states", "scatter")
    exp.add_metric("lr", "line", "Learning Rate")

    return exp, step_tags


exp, step_tags = exp_log_visual(model)

####################################################################
#
# Training Pipeline
# - batch/epoch callbacks for logging, checkpoints, visualization...
# - initialize trainer
# - initialize training loop
#
####################################################################
def stats_callback(batch, losses, loss_list, batch_outputs, epoch):
    """Every ``log_interval`` steps, log gradient norms and mean loss/perplexity.

    Args:
        batch, loss_list, batch_outputs: unused; part of the trainer callback signature.
        losses: per-step loss history; each entry is sliced to ``len(step_tags)``
            values, presumably one per enabled loss in ``step_tags`` order (confirm
            in DsTrainer).
        epoch: current epoch, forwarded to ``exp.log_metrics``.
    """
    if trainer.step % config["log_interval"] != 0:
        return

    # log gradient norms
    grads = sorted(trainer.grads(), key=lambda tup: tup[1], reverse=True)
    grads_table = tabulate(grads, numalign="right", floatfmt=".5f",
                           headers=['Parameter', 'Grad(Norm)'])
    exp.update_value("grads", grads_table)

    _losses = losses[-config["log_interval"]:]
    mono_losses = numpy.array([x[:len(step_tags)] for x in _losses]).mean(0)
    for loss, tag in zip(mono_losses, step_tags):
        exp.update_metric("loss", loss, tag)
        try:
            ppl = math.exp(loss)
        except OverflowError:
            # math.exp overflows for losses above ~709; report an infinite
            # perplexity instead of aborting training from a logging hook.
            ppl = float("inf")
        exp.update_metric("ppl", ppl, tag)

    ################################################
    losses_log = exp.log_metrics(["loss", "ppl"], epoch)
    exp.update_value("progress", trainer.progress_log + "\n" + losses_log)

    # "\033[K" clears the rest of the terminal line; the cursor-up variant is disabled.
    print("\n\033[K" + losses_log)
    # print("\033[F" * (len(losses_log.split("\n")) + 2))


def samples_to_text(tensor):
    """Convert a batch of token-id tensors to token lists (EOS kept, no post-processing)."""
    return devectorize(tensor.tolist(), train_data.vocab.id2tok,
                       train_data.vocab.tok2id[vocab.EOS],
                       strip_eos=False, pp=False)


def outs_callback(batch, losses, loss_list, batch_outputs, epoch):
    """Every ``log_interval`` steps, plot compressor grad norms and dump samples.

    The samples (source, argmax tokens of ``dec1[3]``, optional LM prior,
    optional temperatures, argmax tokens of ``dec2[0]``) are rendered to HTML,
    pushed to the experiment and written to ``EXP_DIR/<opts.name>.samples.html``.
    """
    if trainer.step % config["log_interval"] != 0:
        return

    (_prob, _enc, _enc_filter, dec1, dec2,
     _sent_num, _dialog_pre, _summary_pre) = batch_outputs['model_outputs']

    if config["plot_norms"]:
        norms = batch_outputs['grad_norm']
        # ``norms`` is indexed per loss (index 0 is NSENT), and losses are
        # ordered and enabled exactly like ``step_tags``. Indexing by tag
        # position avoids a hand-maintained name->index mapping (the previous
        # lookups used keys such as "clssum"/"clskl" that were never produced,
        # and a "DOC-SUM" tag that never exists).
        for position, tag in enumerate(step_tags):
            # LENGTH has no compressor grad-norm curve registered
            if tag == "LENGTH" or position >= len(norms):
                continue
            exp.update_metric("c_norm", norms[position], tag)

    if len(batch) == 2:
        inp = batch[0][0]
    else:
        inp = batch[0]
    src = samples_to_text(inp)
    hyp = samples_to_text(dec1[3].max(dim=2)[1])
    nsent = samples_to_text(dec2[0].max(dim=2)[1])

    # prior outputs
    has_prior = "prior" in batch_outputs
    if has_prior:
        prior_loss = batch_outputs['prior'][0].squeeze().tolist()
        prior_logits = batch_outputs['prior'][1]

        prior_argmax = prior_logits.max(dim=2)[1]
        prior_entropy = Categorical(logits=prior_logits).entropy().tolist()

        prior = samples_to_text(prior_argmax)

    if "attention" in batch_outputs:
        att_scores = batch_outputs['attention'][0].squeeze().tolist()
    else:
        att_scores = None

    if config["model"]["learn_tau"]:
        temps = dec1[5].cpu().data.numpy().round(2)
    else:
        temps = None

    nsent_losses = batch_outputs['n_sent'].tolist()

    samples = []
    for i in range(len(src)):
        sample = []

        if att_scores is not None:
            _src = 'SRC', (src[i], att_scores[i]), "255, 0, 0"
        else:
            _src = 'SRC', src[i], "0, 0, 0"
        sample.append(_src)

        if has_prior:
            _hyp = 'HYP', (hyp[i], prior_loss[i]), "0, 0, 255"
            _pri = 'LM ', (prior[i], prior_entropy[i]), "0, 255, 0"
            sample.append(_hyp)
            sample.append(_pri)
        else:
            _hyp = 'HYP', hyp[i], "0, 0, 255"
            sample.append(_hyp)

        if temps is not None:
            _tmp = 'TMP', (list(map(str, temps[i])), temps[i]), "255, 0, 0"
            sample.append(_tmp)

        _nsent = 'NSENT', (nsent[i], nsent_losses[i]), "255, 0, 0"
        sample.append(_nsent)

        samples.append(sample)

    html_samples = samples2html(samples)
    exp.update_value("samples", html_samples)
    # Explicit UTF-8 so non-ASCII samples don't depend on the locale encoding.
    with open(os.path.join(EXP_DIR, f"{opts.name}.samples.html"), 'w',
              encoding="utf-8") as f:
        f.write(html_samples)


best_test_score = -1.


def eval_callback(batch, losses, loss_list, batch_outputs, epoch):
    """Periodically checkpoint and evaluate extractive summaries with ROUGE.

    * Every ``checkpoint_interval`` steps: save a tagged checkpoint and the experiment.
    * Every ``eval_interval`` steps:
      1. run ``trainer.eval_epoch``;
      2. build one summary per validation dialogue from the selected sentence indices;
      3. score the summaries with ROUGE and log the scores to the experiment and TensorBoard;
      4. write the summaries to ``dec_summ_path`` when ROUGE-1 F1 improves;
      5. call ``save_best()``.
    """
    global best_test_score
    if trainer.step % config["checkpoint_interval"] == 0:
        tags = [trainer.epoch, trainer.step]
        trainer.checkpoint(name=opts.name, tags=tags)
        exp.save()

    if trainer.step % config["eval_interval"] != 0:
        return

    original_dialogue = val_data.inputs
    results_sent, oov_maps, k_indices = trainer.eval_epoch(config["batch_size"])
    results_sent = list(itertools.chain.from_iterable(results_sent))
    oov_maps = list(itertools.chain.from_iterable(oov_maps))

    # generate pruned summary (only consumed by the commented-out ROUGE call below)
    v = train_data.vocab
    tokens = devectorize_generate(results_sent, v.id2tok, v.tok2id[v.EOS], True, oov_maps)
    hyps = [" ".join(x) for x in tokens]

    # generate original summary: concatenate the selected original sentences
    hyps_org = []
    flat_k_indices = [x for j in k_indices for x in j]
    for index, selected in enumerate(flat_k_indices):
        try:
            k_sent_index = selected.tolist()
            hyps_org.append("".join(original_dialogue[index][j] + " "
                                    for j in k_sent_index[0]))
        except Exception:
            # Best effort: fall back to a fixed placeholder ("被告", Chinese
            # for "defendant") so there is still one hypothesis per dialogue.
            hyps_org.append("被告")

    # evaluate summary
    # scores = rouge_files(config["data"]["result_path"], config["data"]["ref_path"], hyps)
    scores = rouge_files(config["data"]["result_path"], config["data"]["ref_path"], hyps_org)
    rouge_table = pprint_rouge_scores(scores)
    exp.update_value("rouge-stats", rouge_table)
    exp.update_metric("rouge", scores['rouge-1']['f'], "R-1")
    exp.update_metric("rouge", scores['rouge-2']['f'], "R-2")
    exp.update_metric("rouge", scores['rouge-l']['f'], "R-L")

    epoch_times = trainer.step / config["eval_interval"]
    writer.add_scalar('Test/rouge-1', scores['rouge-1']['f'], epoch_times)
    writer.add_scalar('Test/rouge-2', scores['rouge-2']['f'], epoch_times)
    writer.add_scalar('Test/rouge-l', scores['rouge-l']['f'], epoch_times)
    writer.flush()

    if scores['rouge-1']['f'] > best_test_score:
        best_test_score = scores['rouge-1']['f']
        # save the best decode results (based on ROUGE-1 F1). The context
        # manager closes the file; the former codecs.open() handle was never
        # closed and used the locale encoding.
        with open(config["data"]["dec_summ_path"], "w", encoding="utf-8") as decode_file:
            for hyp in hyps_org:
                decode_file.write(hyp.replace("\n", "") + "\n")

    save_best()


####################################################################
# Loss Weight: order matters!
####################################################################
def loss_id():
    """Collect the weight of every enabled loss and each loss's position.

    Order matters: ``loss_weights`` is passed positionally to ``DsTrainer``
    (presumably aligned with the order the trainer computes the losses --
    confirm in DsTrainer), and it mirrors the order of ``step_tags``.

    Returns:
        tuple: ``(loss_ids, loss_weights)``. ``loss_ids`` maps a loss name to
        its index in ``loss_weights``.
    """
    model_cfg = config["model"]
    loss_ids = {}
    loss_weights = []

    def _add(name, weight_key):
        # append the weight and remember where it landed
        loss_weights.append(model_cfg[weight_key])
        loss_ids[name] = len(loss_weights) - 1

    _add("nsent", "loss_weight_nsent")
    if model_cfg["n_sent_sum_loss"]:
        _add("nsent_sum", "loss_weight_nsent_sum")
    if model_cfg["prior_loss"] and config["prior"] is not None:
        _add("prior", "loss_weight_prior")
    if model_cfg["topic_loss"]:
        _add("topic", "loss_weight_topic")
    if model_cfg["length_loss"]:
        _add("length", "loss_weight_length")
    if model_cfg["doc_sum_kl_loss"]:
        _add("doc_sum_kl", "loss_weight_doc_sum")
    if model_cfg["doc_sum_sim_loss"]:
        _add("doc_sum_sim", "loss_weight_doc_sum_sim")
    if model_cfg["sum_loss"]:
        _add("sum", "loss_weight_sum")
    if model_cfg["nsent_classification"]:
        _add("cls", "loss_weight_classification")
    if model_cfg["nsent_classification_sum"]:
        _add("cls_sum", "loss_weight_classification_sum")
    if model_cfg["nsent_classification_kl"]:
        # own key; it used to overwrite the "cls_sum" entry
        _add("cls_kl", "loss_weight_classification_kl")
    # Previously this returned the function object ``loss_id`` instead of the
    # dict, so ``loss_ids`` was never bound at module level.
    return loss_ids, loss_weights


loss_ids, loss_weights = loss_id()
if config["main_model"] is not None:
    # optimizer.load_state_dict(main_model_checkpoint["optimizers"])
    model.load_state_dict(main_model_checkpoint["model"])
trainer = DsTrainer(model, train_loader, val_loader,
                    criterion, optimizer, config, opts.device,
                    batch_end_callbacks=[stats_callback,
                                         outs_callback,
                                         eval_callback],
                    loss_weights=loss_weights, oracle=oracle)

####################################################################
# Training Loop
####################################################################

# Explicit checks instead of ``assert`` so they still run under ``python -O``.
if train_data.vocab.is_corrupt():
    raise RuntimeError("Training vocabulary is corrupt.")
if val_data.vocab.is_corrupt():
    raise RuntimeError("Validation vocabulary is corrupt.")

best_score = None


def save_best():
    """Checkpoint the trainer when the latest ROUGE-2 F1 is the best so far.

    The experiment log is saved on every call. If no ROUGE score has been
    logged yet (``eval_interval`` not reached), no checkpoint is taken.
    """
    global best_score
    try:
        # assumes ``values`` maps tag -> list of logged values (confirm in mylogger)
        _score = exp.get_metric("rouge").values["R-2"][-1]
    except (KeyError, IndexError):
        _score = None

    # ``is None`` instead of ``not best_score`` so a best score of 0.0 is respected.
    if _score is not None and (best_score is None or _score > best_score):
        best_score = _score
        trainer.checkpoint()
    exp.save()


for epoch in range(config["epochs"]):
    batch_num = len(train_data) / config["batch_size"]
    train_loss = trainer.train_epoch(config["pre_train_epochs"], batch_num, writer)

    # Checkpoint if the latest ROUGE-2 score is the best we've seen so far.
    save_best()

# --------------------------------------------------------------------------------
# /generate/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/xiyan524/RepSum/cddc47188e445ccdc23b30e9d8d5f2daa16b7c1d/generate/__init__.py
# --------------------------------------------------------------------------------
# /generate/debug_outputs.py:
# --------------------------------------------------------------------------------
# Debug script: runs compress_seq3 in "debug" mode and writes
# source / latent / reconstruction triples to evaluation/<checkpoint>_preds.txt.
import os

import torch

from sys_config import DATA_DIR, BASE_DIR
from generate.utils import compress_seq3
from utils.viz import seq3_attentions

checkpoint = "seq3.full"
seed = 1
device = "cpu"
verbose = True
out_file = ""
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
#
# src_file = os.path.join(DATA_DIR, "gigaword/test_1951/input.txt")
# out_file = os.path.join(DATA_DIR, "gigaword/test_1951/preds.txt")

# src_file = os.path.join(DATA_DIR, "gigaword/test_1951/input_min8.txt")
src_file = os.path.join(DATA_DIR, "gigaword/small/valid.article.filter.4K.txt")
# src_file = os.path.join(DATA_DIR, "gigaword/dev/valid.src.small.txt")
# src_file = os.path.join(BASE_DIR, "evaluation/DUC2003/input.txt")
# src_file = os.path.join(BASE_DIR, "evaluation/DUC2004/input.txt")

out_file = os.path.join(BASE_DIR, f"evaluation/{checkpoint}_preds.txt")
# pass the `verbose` setting instead of a hard-coded literal
results = compress_seq3(checkpoint, src_file, out_file, device, verbose,
                        mode="debug")

# seq3_attentions(results[:15], file=checkpoint + ".pdf")
# --------------------------------------------------------------------------------
# /generate/generate.py:
# --------------------------------------------------------------------------------
import os

import torch

from generate.utils import compress_seq3
from sys_config import BASE_DIR
# (continued) /generate/generate.py -- batch generation over the evaluation sets
checkpoint = "seq3"
seed = 1
# Fall back to CPU so the script also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
verbose = False
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

datasets = {
    "gigaword": os.path.join(BASE_DIR,
                             "evaluation/gigaword/input_min8.txt"),
    "DUC2003": os.path.join(BASE_DIR, "evaluation/DUC2003/input.txt"),
    "DUC2004": os.path.join(BASE_DIR, "evaluation/DUC2004/input.txt"),
}

for name, src_file in datasets.items():

    # NOTE(review): `length` is never passed on -- compress_seq3 has no length
    # parameter and derives target lengths from the source lengths.
    if name == "gigaword":
        length = None
    else:
        length = 17

    out_file = os.path.join(BASE_DIR,
                            f"evaluation/hyps/{name}_{checkpoint}_preds.txt")

    compress_seq3(checkpoint, src_file, out_file, device,
                  verbose=verbose, mode="results")
# --------------------------------------------------------------------------------
# /generate/utils.py:
# --------------------------------------------------------------------------------
import math
from itertools import groupby

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from modules.data.collates import Seq2SeqOOVCollate
from modules.data.datasets import AEDataset
from modules.models import Seq2Seq2Seq
from utils.training import load_checkpoint


def compress_seq3(checkpoint, src_file, out_file,
                  device, verbose=False, mode="attention"):
    """Run a trained checkpoint over ``src_file`` and write its outputs.

    Modes:
        "debug": for every batch, write source / argmax of ``dec1[3]`` /
            argmax of ``dec2[0]`` triples to ``out_file``.
        "attention": collect token lists plus ``dec1[4]`` / ``dec2[4]`` for the
            first batch only and return them.
        anything else: run ``model.generate`` and write one prediction per
            line to ``out_file``.

    Returns:
        list: the collected samples in "attention" mode, otherwise ``[]``.

    NOTE(review): this builds an ``AEDataset`` and calls ``Seq2SeqOOVCollate()``
    with no arguments, while ds.py constructs the collate with three
    sentence-limit arguments -- confirm it still matches the current interfaces.
    """
    checkpoint = load_checkpoint(checkpoint)
    config = checkpoint["config"]
    vocab = checkpoint["vocab"]

    def giga_tokenizer(x):
        return x.strip().lower().split()

    dataset = AEDataset(src_file,
                        preprocess=giga_tokenizer,
                        vocab=checkpoint["vocab"],
                        seq_len=config["data"]["seq_len"],
                        return_oov=True,
                        oovs=config["data"]["oovs"])

    data_loader = DataLoader(dataset, batch_size=config["batch_size"],
                             num_workers=0, collate_fn=Seq2SeqOOVCollate())
    n_tokens = len(dataset.vocab)
    model = Seq2Seq2Seq(n_tokens, **config["model"]).to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    ##############################################

    n_batches = math.ceil(len(data_loader.dataset) / data_loader.batch_size)

    if verbose:
        iterator = tqdm(enumerate(data_loader, 1), total=n_batches)
    else:
        iterator = enumerate(data_loader, 1)

    def devect(ids, oov, strip_eos, pp):
        return devectorize(ids.tolist(), vocab.id2tok, vocab.tok2id[vocab.EOS],
                           strip_eos=strip_eos, oov_map=oov, pp=pp)

    def id2txt(ids, oov=None, lengths=None, strip_eos=True):
        if lengths:
            return [" ".join(x[:l]) for l, x in
                    zip(lengths, devect(ids, oov, strip_eos, pp=True))]
        else:
            return [" ".join(x) for x in devect(ids, oov, strip_eos, pp=True)]

    results = []
    with open(out_file, "w", encoding="utf-8") as f, torch.no_grad():
        for _, batch in iterator:
            batch_oov_map = batch[-1]
            batch = [x.to(device) for x in batch[:-1]]
            (inp_src, out_src, inp_trg, out_trg,
             src_lengths, trg_lengths) = batch

            # Floor division keeps the lengths integral. With true division
            # ("/"), recent PyTorch returns floats, and the later x[:l]
            # slicing in id2txt raises TypeError.
            trg_lengths = torch.clamp(src_lengths // 2, min=5, max=30) + 1

            #############################################################
            # Debug
            #############################################################
            if mode in ("attention", "debug"):

                outputs = model(inp_src, inp_trg, src_lengths, trg_lengths,
                                sampling=0)
                enc1, dec1, enc2, dec2 = outputs

                if mode == "debug":

                    src = id2txt(inp_src)
                    latent = id2txt(dec1[3].max(-1)[1])
                    rec = id2txt(dec2[0].max(-1)[1])

                    for sample in zip(src, latent, rec):
                        f.write("\n".join(sample) + "\n\n")

                else:  # mode == "attention"
                    src = devect(inp_src, None, strip_eos=False, pp=False)
                    latent = devect(dec1[3].max(-1)[1],
                                    None, strip_eos=False, pp=False)
                    rec = devect(dec2[0].max(-1)[1],
                                 None, strip_eos=False, pp=False)

                    _results = [src, latent, dec1[4], rec, dec2[4]]

                    results += list(zip(*_results))

                    # only the first batch is collected in this mode
                    break
            else:
                enc1, dec1 = model.generate(inp_src, src_lengths,
                                            trg_lengths)

                preds = id2txt(dec1[0].max(-1)[1],
                               batch_oov_map, trg_lengths.tolist())

                for sample in preds:
                    f.write(sample + "\n")
    return results


def devectorize(data, id2tok, eos, strip_eos=True, oov_map=None, pp=True):
    """Convert rows of token ids into rows of token strings.

    Args:
        data: list of id rows (e.g. ``tensor.tolist()``). Truncated in place
            at the first ``eos`` when ``strip_eos`` is True.
        id2tok: mapping id -> token; unknown ids become "".
        eos: end-of-sequence id.
        strip_eos: cut each row before its first ``eos`` (rows without one are kept).
        oov_map: optional per-row mapping applied to the produced tokens.
        pp: apply the token replacement ``rules`` below.

    Returns:
        list[list[str]]: one token list per row.
    """
    if strip_eos:
        for i, row in enumerate(data):
            try:
                data[i] = row[:list(row).index(eos)]
            except (ValueError, TypeError):
                # ValueError: no EOS in this row; TypeError: row is not
                # iterable. Either way keep the row unchanged (the former bare
                # ``except`` also swallowed KeyboardInterrupt).
                continue

    # ids to words
    data = [[id2tok.get(x, "") for x in seq] for seq in data]

    if oov_map is not None:
        data = [[m.get(x, x) for x in seq] for seq, m in zip(data, oov_map)]

    if pp:
        # NOTE(review): most keys below are empty strings as written, so only
        # "unk" is effectively remapped; they were presumably special tokens
        # (e.g. OOV / unk / sos / eos / pad markers) -- verify against the vocab.
        rules = {f"": "UNK" for i in range(10)}
        rules["unk"] = "UNK"
        rules[""] = "UNK"
        rules[""] = ""
        rules[""] = ""
        rules[""] = ""

        data = [[rules.get(x, x) for x in seq] for seq in data]

    # remove repetitions
    # data = [[x[0] for x in groupby(seq)] for seq in data]

    return data


def devectorize_generate(data, id2tok, eos, strip_eos=True, oov_map=None, pp=True):
    """Flatten each sample's generated sentences into one list of token strings.

    Args:
        data: list of samples; each sample is assumed to be an iterable of
            1-D id tensors, one per sentence (from ``sent.tolist()`` below).
        id2tok: mapping id -> token; ids that map to "" (or are unknown) are dropped.
        eos: end-of-sequence id.
        strip_eos: try to cut each sample at ``eos``. With multi-token sentence
            tensors the comparison raises RuntimeError (ambiguous tensor truth
            value) and the sample is left unchanged.
        oov_map: optional per-sample mapping applied to the produced tokens.
        pp: apply the token replacement ``rules`` below.

    Returns:
        list[list[str]]: one flat token list per sample.
    """
    if strip_eos:
        for i in range(len(data)):
            try:
                data[i] = data[i][:list(data[i]).index(eos)]
            except (ValueError, RuntimeError, TypeError):
                # no EOS found / tensor comparison not reducible to a bool /
                # non-iterable sample: keep the sample as is
                continue

    # ids to words, dropping empty tokens
    results = []
    for sample in data:
        sample_rst = []
        for sent in sample:
            for wrd in sent.tolist():
                wrd_real = id2tok.get(wrd, "")
                if wrd_real != "":
                    sample_rst.append(wrd_real)
        results.append(sample_rst)
    data = results

    if oov_map is not None:
        data = [[m.get(x, x) for x in seq] for seq, m in zip(data, oov_map)]

    if pp:
        # NOTE(review): see devectorize -- the empty-string keys look like
        # stripped special tokens; verify against the vocab.
        rules = {f"": "UNK" for i in range(10)}
        rules["unk"] = "UNK"
        rules[""] = "UNK"
        rules[""] = ""
        rules[""] = ""
        rules[""] = ""

        data = [[rules.get(x, x) for x in seq] for seq in data]

    # remove repetitions
    # data = [[x[0] for x in groupby(seq)] for seq in data]

    return data
# --------------------------------------------------------------------------------
# /model_configs/ds.full.yaml (lines 1-14):
# --------------------------------------------------------------------------------
#   checkpoint_interval: # how often (batches) to save a checkpoint
#   eval_interval: # how often (batches) to evaluate the model on the dev set
#   log_interval: # how often (batches) to log the training process to console
#   batch_size: # batch size
#   epochs: # number of epochs
#   pre_train_epochs: # number of pre-train epochs
#   num_workers:
#
#   plot_norms: False # Plot the gradient norms of each loss w.r.t. the compressor
#
#   lr: # Learning rate of the optimizer
#   weight_decay: # Weight decay value of the optimizer
#
#   # The checkpoint of the pretrained LM to be used as prior.
# Use only the prefix of the file (without .pt) 16 | prior: 17 | #main_model: 18 | main_model: 19 | tensorboard_path: 20 | 21 | data: 22 | train_path: 23 | val_path: 24 | ref_path: 25 | ref_nsent_path: 26 | dec_summ_path: 27 | dec_nsent_path: 28 | org_dia_path: 29 | result_path: 30 | 31 | sent_num: 32 | sent_len: 33 | sent_sim_len: 34 | 35 | seq_len: # maximum length of source texts 36 | dec_seq_len: # maximum length of internal summary 37 | ext_sum_len: # extractive summary length 38 | oovs: # number of special OOV tokens (www.aclweb.org/anthology/K18-1040) 39 | swaps: # percentage of local token swaps to the source text 40 | 41 | vocab: 42 | embeddings: # pretrained word embeddings file 43 | embeddings_dim: # pretrained word embeddings dimensionality 44 | size: # size of the vocabulary. Top-N frequent words. 45 | 46 | model: 47 | clip: # value of clipping the norms of the gradients 48 | k: # extract top k sentences as ground truth 49 | pack: # use packed_sequences 50 | 51 | batch_size: 52 | sent_num: 53 | sent_len: 54 | 55 | ################################################ 56 | # LOSSES 57 | ################################################ 58 | 59 | # Annealing: If you want to anneal the value of a hyper-parameter, 60 | # you can do so, by replacing the value with a list: [from, to]. 61 | # For example, to anneal the value of the weight of the prior: 62 | # loss_weight_prior: [0.001, 0.5] 63 | # Note that the starting value cannot be zero.
64 | 65 | #------------------------------------ 66 | # Nsent 67 | #------------------------------------ 68 | loss_weight_nsent: # weight of the nth sentence loss - λ_R 69 | 70 | #------------------------------------ 71 | # Nsent_Sum 72 | #------------------------------------ 73 | n_sent_sum_loss: True 74 | loss_weight_nsent_sum: 75 | 76 | #------------------------------------ 77 | # Prior 78 | #------------------------------------ 79 | prior_loss: False # enable/disable the prior loss 80 | loss_weight_prior: # weight of the prior loss - λ_P 81 | 82 | #------------------------------------ 83 | # Topic 84 | #------------------------------------ 85 | topic_loss: False # enable/disable the topic loss 86 | loss_weight_topic: # weight of the topic loss - λ_T 87 | topic_idf: True # weight the input embeddings by their IDF 88 | topic_distance: cosine # distance metric for topic loss. Options: cosine, euclidean, dot 89 | 90 | #------------------------------------ 91 | # Length 92 | #------------------------------------ 93 | length_loss: False # enable/disable the length loss 94 | loss_weight_length: # weight of the length loss - λ_L 95 | 96 | #------------------------------------ 97 | # Doc-Sum-KL 98 | #------------------------------------ 99 | doc_sum_kl_loss: False # enable/disable the document summary generation same sentence loss 100 | loss_weight_doc_sum: # weight of the document summary generation loss - λ_L 101 | 102 | #------------------------------------ 103 | # Doc-Sum-Sim 104 | #------------------------------------ 105 | doc_sum_sim_loss: False 106 | loss_weight_doc_sum_sim: 107 | docsim_distance: cosine 108 | 109 | #------------------------------------ 110 | # Summary 111 | #------------------------------------ 112 | sum_loss: False 113 | loss_weight_sum: 114 | 115 | #------------------------------------ 116 | # Nsent Classification 117 | #------------------------------------ 118 | nsent_classification: True 119 | loss_weight_classification: 120 |
#------------------------------------ 122 | # Nsent Classification Sum 123 | #------------------------------------ 124 | nsent_classification_sum: True 125 | loss_weight_classification_sum: 126 | 127 | #------------------------------------ 128 | # Nsent KL 129 | #------------------------------------ 130 | nsent_classification_kl: True 131 | loss_weight_classification_kl: 132 | 133 | ################################################ 134 | # SUMMARY LENGTHS 135 | ################################################ 136 | min_ratio: # min % of the sampled summary lengths 137 | max_ratio: # max % of the sampled summary lengths 138 | min_length: # absolute min length (words) of the sampled summary length 139 | max_length: # absolute max length (words) of the sampled summary length 140 | test_min_ratio: # same as above but for inference 141 | test_max_ratio: # same as above but for inference 142 | test_min_length: # same as above but for inference 143 | test_max_length: # same as above but for inference 144 | 145 | ################################################ 146 | # PARAMETER SHARING 147 | ################################################ 148 | tie_decoder_outputs: False # tie the output layers of both decoders (projections to vocab) 149 | tie_embedding_outputs: False # tie the embedding and output layers of both decoders 150 | tie_embedding: False # tie all the embedding layers together 151 | tie_decoders: False # tie the decoders of the compressor and reconstructor 152 | tie_encoders: False # tie the encoders of the compressor and reconstructor 153 | 154 | ################################################ 155 | # INIT DECODER 156 | ################################################ 157 | length_control: True # If true, use the countdown parameter for the decoders, 158 | # as well as the target length-aware initialization for each decoder 159 | bridge_hidden: True # use a bridge layer (hidden) between the last layer of the encoder and the initial state of the decoder 160 | 
bridge_non_linearity: tanh # apply a non-linearity to the bridge layer. Options: tanh, relu 161 | 162 | emb_size: 300 # the size of the embedding layer(s) 163 | embed_dropout: 0.0 # dropout probability for the embedding layer(s) 164 | embed_trainable: True # Finetune the embeddings 165 | embed_masked: False # Finetune only the words not included in the pretrained embeddings. 166 | layer_norm: True # Apply layer normalization to the outputs of the decoders 167 | enc_token_dropout: 0.0 # % of words to drop from the input 168 | dec_token_dropout: 0.5 # % of words to drop from the reconstruction 169 | enc_rnn_size: 200 # the size of the encoder(s) 170 | dec_rnn_size: 200 # the size of the decoder(s) 171 | rnn_layers: 2 # number of layers for encoders and decoders 172 | rnn_dropout: 0.0 # dropout probability for the outputs of each RNN 173 | rnn_bidirectional: True # Use bidirectional encoder(s) 174 | attention: True # Use attentional seq2seq. False not implemented! 175 | attention_fn: general # The attention function. Options: general, additive, concat 176 | attention_coverage: False # Include a coverage vector to the attention mechanism 177 | input_feeding: True # Use input feeding (Luong et al., 2015) 178 | input_feeding_learnt: True # Learn the first value of the input feed 179 | out_non_linearity: tanh # Apply a non-linearity to the output vector (before projection to vocab) 180 | 181 | sampling: 0.0 # Probability of scheduled sampling to the reconstructor 182 | top: False # Use argmax for sampling in the latent sequence. True not implemented!
183 | hard: True # Use Straight-Through, i.e., discretize the output distributions in the forward pass 184 | gumbel: True # Use Gumbel-Softmax instead of softmax in the latent sequence 185 | tau: 0.5 # Temperature of the distributions in the latent sequence 186 | learn_tau: False # Learn the value of the temperature, as a function of the output of the decoder(s) 187 | tau_0: 0.5 # Hyper-parameter that controls the upper-bound of the temperature. -------------------------------------------------------------------------------- /model_configs/lm_prior.yaml: -------------------------------------------------------------------------------- 1 | checkpoint_interval: 0 2 | log_interval: 50 3 | batch_size: 32 4 | epochs: 50 5 | 6 | lr: 0.001 7 | scheduler: step 8 | step_size: 10 9 | eta_min: 0.0001 10 | weight_decay: 0.0 11 | gamma: 0.5 12 | milestones: [5,15] 13 | 14 | parallel: False 15 | data: 16 | train_path: 17 | val_path: 18 | seq_len: 150 # maximum length of texts (in words) 19 | stateful: False # Leave it to False. True is used for document level language modeling. 20 | sos: True # Add a Start-of-Sequence (SOS) token 21 | oovs: 10 # number of special OOV tokens (www.aclweb.org/anthology/K18-1040) 22 | # the LM is trained with the same trick, in order to be able to compute meaningful KL in the compression task 23 | vocab: 24 | vocab_path: 25 | size: 30000 26 | subword: False 27 | subword_path: 28 | model: 29 | emb_size: 300 # the size of the embedding layer(s) 30 | embed_noise: 0.0 # additive Gaussian noise with given sigma 31 | embed_dropout: 0.2 # dropout probability for the embedding layer(s) 32 | rnn_size: 256 # size of the RNN 33 | rnn_layers: 2 # number of RNN layers 34 | rnn_dropout: 0.5 # dropout probability for outputs of the RNN 35 | decode: True # leave it to True. It means that the outputs 36 | # will be projected to the vocabulary (decoded) 37 | tie_weights: True # tie the embedding and the output layer 38 | countdown: True # add a countdown input.
Leave it to True 39 | pack: True # use packed_sequences 40 | clip: 1 # value of clipping the norms of the gradients -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.insert(0, '.') 4 | sys.path.insert(0, '..') 5 | sys.path.insert(0, '../../') 6 | sys.path.insert(0, '../../../') 7 | -------------------------------------------------------------------------------- /models/ds_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from modules.helpers import sequence_mask 5 | 6 | 7 | def _kl_div(inp_logits, trg_logits, lengths, tau=1): 8 | """ 9 | Compute the prior loss using a pretrained "oracle" LM. 10 | The loss is computed using the produced posteriors over the vocabulary 11 | produced by a generator and the posteriors of the "oracle" LM. 12 | 13 | Args: 14 | logits: the logits of the generator 15 | words: the argmax of the logits 16 | oracle: the oracle LM 17 | tau: the temperature of the softmax 18 | lengths: the lengths of the target sequence. Used for masking the loss. 19 | 20 | 21 | Debug = -F.softmax(_logits, -1) * torch.log(F.softmax(logits, -1) / 22 | F.softmax(_logits, -1)) 23 | 24 | Returns: 25 | the average KL Divergence per timestep (word) 26 | 27 | """ 28 | mask = sequence_mask(lengths).unsqueeze(-1).float() 29 | 30 | input_logp = F.log_softmax(inp_logits * mask / tau, -1) 31 | target_p = F.softmax(trg_logits * mask / tau, -1) 32 | 33 | # shape: batch x seq_length x tokens 34 | loss = F.kl_div(input_logp, target_p, reduction='none') 35 | 36 | # sum over words/vocab (KL per word/timestep !) 
37 | # shape: batch x length 38 | loss = loss.sum(-1) 39 | 40 | # zero losses for padded timesteps 41 | loss = loss * mask.squeeze() 42 | 43 | total_loss = loss.sum() / mask.sum() 44 | 45 | return total_loss, loss 46 | 47 | 48 | def _global_prior(logits, word_idx, lengths): 49 | """ 50 | Evaluate the probability of a sequence, under a language model 51 | 52 | """ 53 | 54 | mask = sequence_mask(lengths) 55 | labels = (word_idx * mask.long()).contiguous().view(-1) 56 | _logits = logits.contiguous().view(-1, logits.size(-1)) 57 | loss = F.cross_entropy(_logits, labels, ignore_index=0, reduction='none') 58 | 59 | # normalize by length to avoid mode collapse 60 | total = loss.sum() / mask.float().sum() 61 | 62 | return total, loss.view(mask.size()) 63 | 64 | 65 | def kl_length(logits, lengths, eos): 66 | """ 67 | Length control loss, using a sequence of length labels (with eos token). 68 | 69 | Args: 70 | logits: 71 | lengths: 72 | eos: 73 | 74 | Returns: 75 | 76 | """ 77 | mask = sequence_mask(lengths - 1, lengths.max()) 78 | eos_labels = ((1 - mask) * eos).long().contiguous().view(-1) 79 | 80 | _logits = logits.contiguous().view(-1, logits.size(-1)) 81 | loss = F.cross_entropy(_logits, eos_labels, ignore_index=0) 82 | 83 | return loss 84 | 85 | 86 | def pairwise_loss(a, b, dist="cosine"): 87 | if dist == "euclidean": 88 | return F.pairwise_distance(a, b).mean() 89 | elif dist == "cosine": 90 | return 1 - F.cosine_similarity(a, b).mean() 91 | elif dist == "dot": 92 | dot = torch.bmm(a.unsqueeze(1), b.unsqueeze(-1)).squeeze() 93 | scaled_dot = dot.mean() / a.size(1) 94 | return - scaled_dot 95 | else: 96 | raise ValueError 97 | -------------------------------------------------------------------------------- /models/ds_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | #from models.seq3_losses import _kl_div, 
kl_length, pairwise_loss 7 | from models.ds_losses import _kl_div, kl_length, pairwise_loss 8 | from models.ds_utils import sample_lengths 9 | #from models.seq3_utils import sample_lengths 10 | from modules.helpers import sequence_mask, avg_vectors, module_grad_wrt_loss, kl_categorical 11 | from modules.training.trainer import Trainer 12 | 13 | 14 | class DsTrainer(Trainer): 15 | 16 | def __init__(self, *args, **kwargs): 17 | 18 | super().__init__(*args, **kwargs) 19 | 20 | self.oracle = kwargs.get("oracle", None) 21 | self.top = self.config["model"]["top"] 22 | self.hard = self.config["model"]["hard"] 23 | self.sampling = self.anneal_init(self.config["model"]["sampling"]) 24 | self.tau = self.anneal_init(self.config["model"]["tau"]) 25 | self.len_min_rt = self.anneal_init(self.config["model"]["min_ratio"]) 26 | self.len_max_rt = self.anneal_init(self.config["model"]["max_ratio"]) 27 | self.len_min = self.anneal_init(self.config["model"]["min_length"]) 28 | self.len_max = self.anneal_init(self.config["model"]["max_length"]) 29 | 30 | def _debug_grads(self): 31 | return list(sorted([(n, p.grad) for n, p in 32 | self.model.named_parameters() if p.requires_grad])) 33 | 34 | def _debug_grad_norms(self, reconstruct_loss, prior_loss, topic_loss, kl_loss): 35 | c_grad_norm = [] 36 | c_grad_norm.append( 37 | module_grad_wrt_loss(self.optimizers, self.model.compressor, 38 | reconstruct_loss, 39 | "rnn")) 40 | 41 | if self.config["model"]["topic_loss"]: 42 | c_grad_norm.append( 43 | module_grad_wrt_loss(self.optimizers, self.model.compressor, 44 | topic_loss, 45 | "rnn")) 46 | 47 | if self.config["model"]["prior_loss"] and self.oracle is not None: 48 | c_grad_norm.append( 49 | module_grad_wrt_loss(self.optimizers, self.model.compressor, 50 | prior_loss, 51 | "rnn")) 52 | 53 | if self.config["model"]["doc_sum_kl_loss"] and self.oracle is not None: 54 | c_grad_norm.append( 55 | module_grad_wrt_loss(self.optimizers, self.model.compressor, 56 | kl_loss, 57 | "rnn")) 58 | 59 | 
return c_grad_norm 60 | 61 | def _topic_loss(self, inp, dec1, src_lengths, trg_lengths): 62 | """ 63 | Compute the pairwise distance of various outputs of the seq^3 architecture. 64 | Args: 65 | enc1: the outputs of the first encoder (input sequence) 66 | dec1: the outputs of the first decoder (latent sequence) 67 | src_lengths: the lengths of the input sequence 68 | trg_lengths: the lengths of the targer sequence (summary) 69 | 70 | """ 71 | 72 | enc_mask = sequence_mask(src_lengths).unsqueeze(-1).float() 73 | dec_mask = sequence_mask(trg_lengths - 1).unsqueeze(-1).float() 74 | 75 | enc_embs = self.model.inp_encoder.embed(inp) 76 | dec_embs = self.model.compressor.embed.expectation(dec1[3]) 77 | 78 | if self.config["model"]["topic_idf"]: 79 | enc1_energies = self.model.idf(inp) 80 | # dec1_energies = expected_vecs(dec1[3], self.model.idf.weight) 81 | 82 | x_emb, att_x = avg_vectors(enc_embs, enc_mask, enc1_energies) 83 | # y_emb, att_y = avg_vectors(dec_reps, dec_mask, dec1_energies) 84 | y_emb, att_y = avg_vectors(dec_embs, dec_mask) 85 | 86 | else: 87 | x_emb, att_x = avg_vectors(enc_embs, enc_mask) 88 | y_emb, att_y = avg_vectors(dec_embs, dec_mask) 89 | 90 | distance = self.config["model"]["topic_distance"] 91 | loss = pairwise_loss(x_emb, y_emb, distance) 92 | 93 | return loss, (att_x, att_y) 94 | 95 | #def _doc_sum_loss(self, enc1, enc2, doc_lengths, sum_lengths): 96 | def _doc_sum_loss(self, inp, attn_dis, src_lengths, trg_lengths): 97 | """ 98 | Compute the loss of semantic representation between document and summary 99 | Args: 100 | enc1: the outputs of the first encoder (input sequence) 101 | enc2: the outputs of the first decoder (decode summary) 102 | """ 103 | 104 | """doc_mask = sequence_mask(doc_lengths).unsqueeze(-1).float() 105 | sum_mask = sequence_mask(sum_lengths - 1).unsqueeze(-1).float() 106 | 107 | doc_vec = enc1[0] * doc_mask.float() 108 | doc_mean = doc_vec.sum(1) / doc_mask.sum(1) 109 | 110 | sum_vec = enc2[0] * sum_mask.float() 111 | 
sum_mean = sum_vec.sum(1) / sum_mask.sum(1) 112 | 113 | loss = torch.cosine_similarity(doc_mean, sum_mean) 114 | loss = torch.mean(torch.abs(loss))""" 115 | 116 | enc_mask = sequence_mask(src_lengths).unsqueeze(-1).float() 117 | enc_embs = self.model.inp_encoder.embed(inp) 118 | 119 | if self.config["model"]["topic_idf"]: 120 | enc1_energies = self.model.idf(inp) 121 | x_emb, att_x = avg_vectors(enc_embs, enc_mask, enc1_energies) 122 | #tmp = attn_dis.contiguous().view(-1, attn_dis.size(-1)) 123 | attn_dis = torch.sum(attn_dis, 1) 124 | att_x = torch.squeeze(att_x) 125 | 126 | #loss = torch.dist(att_x, attn_dis, 1) 127 | distance = self.config["model"]["docsim_distance"] 128 | loss = pairwise_loss(att_x, attn_dis) 129 | 130 | return loss 131 | 132 | def _prior_loss(self, outputs, latent_lengths): 133 | """ 134 | Prior Loss 135 | Args: 136 | outputs: 137 | latent_lengths: 138 | 139 | Returns: 140 | 141 | """ 142 | enc1, dec1, enc2, dec2, dec3 = outputs 143 | _vocab = self._get_vocab() 144 | 145 | logits_dec1, outs_dec1, hn_dec1, dists_dec1, _, _ = dec1 146 | 147 | # dists_dec1 contain the distributions from which 148 | # the samples were taken. It contains one less element than the logits 149 | # because the last logit is only used for computing the NLL of EOS. 
150 | words_dec1 = dists_dec1.max(-1)[1] 151 | 152 | # sos + the sampled sentence 153 | sos_id = _vocab.tok2id[_vocab.SOS] 154 | sos = torch.zeros_like(words_dec1[:, :1]).fill_(sos_id) 155 | oracle_inp = torch.cat([sos, words_dec1], -1) 156 | 157 | logits_oracle, _, _ = self.oracle(oracle_inp, None, 158 | latent_lengths) 159 | 160 | prior_loss, prior_loss_time = _kl_div(logits_dec1, 161 | logits_oracle, 162 | latent_lengths) 163 | 164 | return prior_loss, prior_loss_time, logits_oracle 165 | 166 | def _process_batch(self, inp_x, inp_sim, out_sim, inp_y, out_y, sim_len, sent_len, sent_num, y_lengths): 167 | 168 | self.model.train() 169 | 170 | outputs = self.model(self.config["model"]["k"], inp_x, inp_sim, inp_y, sim_len, sent_len, sent_num, y_lengths) 171 | 172 | sent_prob, outs_enc, outs_enc_filter, dec1, dec2, sent_len, dialogue_pre, summary_pre = outputs 173 | 174 | batch_outputs = {"model_outputs": outputs} 175 | 176 | # -------------------------------------------------------------- 177 | # 1 - Predict nth sentence 178 | # -------------------------------------------------------------- 179 | _dec1_logits = dec1[0].contiguous().view(-1, dec1[0].size(-1)) 180 | _x_labels = out_y.contiguous().view(-1) 181 | nsent_loss = F.cross_entropy(_dec1_logits, _x_labels, ignore_index=0, reduction='none') 182 | 183 | nsent_loss_token = nsent_loss.view(out_y.size()) 184 | batch_outputs["n_sent"] = nsent_loss_token 185 | mean_rec_loss = nsent_loss.sum() / y_lengths.float().sum() 186 | losses = [mean_rec_loss] 187 | 188 | # -------------------------------------------------------------- 189 | # 1.5 - Predict nth sentence from summary 190 | # -------------------------------------------------------------- 191 | if self.config["model"]["n_sent_sum_loss"]: 192 | _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1)) 193 | nsent_loss_sum = F.cross_entropy(_dec2_logits, _x_labels, ignore_index=0, reduction='none') 194 | nsent_loss_token_sum = nsent_loss_sum.view(out_y.size()) 
195 | batch_outputs["n_sent_sum"] = nsent_loss_token_sum 196 | mean_rec_sum_loss = nsent_loss_sum.sum() / y_lengths.float().sum() 197 | losses.append(mean_rec_sum_loss) 198 | else: 199 | mean_rec_sum_loss = None 200 | 201 | # -------------------------------------------------------------- 202 | # 2 - DOCUEMNT+SUMMARY DISTRIBUTION 203 | # -------------------------------------------------------------- 204 | if self.config["model"]["doc_sum_kl_loss"]: 205 | _dec1_logits = dec1[0].contiguous().view(-1, dec1[0].size(-1)) 206 | _dec2_logits = dec2[0].contiguous().view(-1, dec2[0].size(-1)) 207 | #kl_loss = torch.nn.functional.kl_div(_dec2_logits, _dec2_logits, size_average=None, reduce=True, reduction='mean') 208 | kl_loss = kl_categorical(_dec1_logits, _dec2_logits) 209 | losses.append(kl_loss) 210 | else: 211 | kl_loss = None 212 | 213 | # -------------------------------------------------------------- 214 | # 3 - LENGTH 215 | # -------------------------------------------------------------- 216 | if self.config["model"]["length_loss"]: 217 | _, topk_indices = torch.topk(sent_prob, k=self.config["model"]["k"], dim=1) 218 | topk_indices = torch.squeeze(topk_indices, -1) 219 | sum_length = torch.gather(sent_len, dim=1, index=topk_indices) 220 | sum_length = torch.sum(sum_length, dim=1) 221 | tmp = torch.sub(self.config["data"]["ext_sum_len"], sum_length) 222 | length_loss = torch.mean(tmp.float()) 223 | losses.append(length_loss) 224 | else: 225 | length_loss = None 226 | 227 | # -------------------------------------------------------------- 228 | # 4 - DOCUEMNT SUMMARY SIMILARITY 229 | # -------------------------------------------------------------- 230 | if self.config["model"]["doc_sum_sim_loss"]: 231 | dialog_rep = torch.squeeze(torch.sum(outs_enc, 1)) 232 | summary_rep = torch.squeeze(torch.sum(outs_enc_filter, 1)) 233 | sim_loss = pairwise_loss(dialog_rep, summary_rep) 234 | losses.append(sim_loss) 235 | else: 236 | sim_loss = None 237 | 238 | # 
-------------------------------------------------------------- 239 | # 4 - N SENTENCE CLASSIFICATION 240 | # -------------------------------------------------------------- 241 | if self.config["model"]["nsent_classification"]: 242 | criterion = torch.nn.BCEWithLogitsLoss() 243 | dia_pre_loss = criterion(dialogue_pre, out_sim.float()) 244 | sum_pre_loss = criterion(summary_pre, out_sim.float()) 245 | pre_kl_loss = kl_categorical(dialogue_pre, summary_pre) 246 | losses.append(dia_pre_loss) 247 | losses.append(sum_pre_loss) 248 | losses.append(pre_kl_loss) 249 | else: 250 | dia_pre_loss = None 251 | sum_pre_loss = None 252 | pre_kl_loss = None 253 | 254 | prior_loss = None 255 | topic_loss = None 256 | kl_loss = None 257 | # -------------------------------------------------------------- 258 | # Plot Norms of loss gradient wrt to the compressor 259 | # -------------------------------------------------------------- 260 | if self.config["plot_norms"] and self.step % self.config["log_interval"] == 0: 261 | batch_outputs["grad_norm"] = self._debug_grad_norms( 262 | mean_rec_loss, 263 | prior_loss, 264 | topic_loss, 265 | kl_loss) 266 | 267 | return losses, batch_outputs 268 | 269 | def eval_epoch(self, batch_size): 270 | """ 271 | Evaluate the network for one epoch and return the average loss. 
272 | 273 | Returns: 274 | loss (float, list(float)): list of mean losses 275 | 276 | """ 277 | self.model.eval() 278 | 279 | results = [] 280 | k_indices = [] 281 | oov_maps = [] 282 | 283 | self.len_min_rt = self.anneal_init( 284 | self.config["model"]["test_min_ratio"]) 285 | self.len_max_rt = self.anneal_init( 286 | self.config["model"]["test_max_ratio"]) 287 | self.len_min = self.anneal_init( 288 | self.config["model"]["test_min_length"]) 289 | self.len_max = self.anneal_init( 290 | self.config["model"]["test_max_length"]) 291 | 292 | iterator = self.valid_loader 293 | with torch.no_grad(): 294 | for i_batch, batch in enumerate(iterator, 1): 295 | batch_oov_map = batch[-1] 296 | batch = batch[:-1] 297 | 298 | batch = list(map(lambda x: x.to(self.device), batch)) 299 | (inp_src, inp_sim, out_sim, out_src, out_trg, sim_len, sent_len, sent_num, trg_lengths) = batch 300 | 301 | if inp_src.size()[0] != batch_size: 302 | continue 303 | 304 | sent_prob = self.model.summary(inp_src, sent_len, sent_num) 305 | sent_prob = torch.squeeze(sent_prob) 306 | _, topk_indices = torch.topk(sent_prob, k=self.config["model"]["k"], dim=1) 307 | inp_src = inp_src.view(inp_src.size(0), self.config["model"]["sent_num"], self.config["model"]["sent_len"]) 308 | 309 | inp_src = inp_src.chunk(batch_size, dim=0) 310 | topk_indices = topk_indices.chunk(batch_size, dim=0) 311 | result_sents = [] 312 | for inp, indice in zip(inp_src, topk_indices): 313 | inp = torch.squeeze(inp) 314 | indice = torch.squeeze(indice) 315 | sum_sent = torch.index_select(inp, 0, indice) 316 | result_sents.append(sum_sent) 317 | 318 | oov_maps.append(batch_oov_map) 319 | results.append(result_sents) 320 | k_indices.append(topk_indices) 321 | 322 | return results, oov_maps, k_indices 323 | 324 | def _get_vocab(self): 325 | if isinstance(self.train_loader, (list, tuple)): 326 | dataset = self.train_loader[0].dataset 327 | else: 328 | dataset = self.train_loader.dataset 329 | 330 | if dataset.subword: 331 | _vocab = 
dataset.subword_path 332 | else: 333 | _vocab = dataset.vocab 334 | 335 | return _vocab 336 | 337 | def get_state(self): 338 | 339 | state = { 340 | "config": self.config, 341 | "epoch": self.epoch, 342 | "step": self.step, 343 | "model": self.model.state_dict(), 344 | "model_class": self.model.__class__.__name__, 345 | "optimizers": [x.state_dict() for x in self.optimizers], 346 | "vocab": self._get_vocab(), 347 | } 348 | 349 | return state 350 | -------------------------------------------------------------------------------- /models/ds_utils.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | import torch 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | 6 | from generate.utils import devectorize 7 | 8 | 9 | def compute_dataset_idf(dataset, vocab): 10 | def identity_func(doc): 11 | return doc 12 | 13 | my_data = [dataset.read_sample(i) for i in range(len(dataset))] 14 | 15 | tf = TfidfVectorizer(lowercase=False, 16 | tokenizer=identity_func, 17 | preprocessor=identity_func, 18 | use_idf=True, 19 | vocabulary=vocab) 20 | tf.fit(my_data) 21 | return tf.idf_ 22 | 23 | 24 | def sample2text(word_ids, vocab): 25 | tokens = devectorize(word_ids, 26 | vocab.id2tok, 27 | vocab.tok2id[vocab.EOS], 28 | strip_eos=False, 29 | pp=False) 30 | text = [" ".join(out) for out in tokens] 31 | lengths = [len(t) for t in tokens] 32 | return text, lengths 33 | 34 | 35 | def str2tree(ls): 36 | tree = {} 37 | for item in ls: 38 | t = tree 39 | for part in item.split('.'): 40 | t = t.setdefault(part, {}) 41 | pprint(tree) 42 | 43 | 44 | def sample_lengths(src_lengths, 45 | min_ratio, max_ratio, 46 | min_length, max_length): 47 | """ 48 | Sample summary lengths from a list of source lengths. 
49 | 50 | """ 51 | t = torch.empty(len(src_lengths), device=src_lengths.device) 52 | samples = t.uniform_(min_ratio, max_ratio) 53 | lengths = (src_lengths.float() * samples).long() 54 | lengths = lengths.clamp(min=min_length, max=max_length) 55 | return lengths 56 | -------------------------------------------------------------------------------- /models/sent_lm_trainer.py: -------------------------------------------------------------------------------- 1 | from modules.training.trainer import Trainer 2 | 3 | 4 | class LMTrainer(Trainer): 5 | 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | def _seq_loss(self, predictions, labels): 10 | _labels = labels.contiguous().view(-1) 11 | 12 | _logits = predictions[0] 13 | _logits = _logits.contiguous().view(-1, _logits.size(-1)) 14 | loss = self.criterion(_logits, _labels) 15 | 16 | return loss 17 | 18 | def _process_batch(self, inputs, labels, lengths): 19 | predictions = self.model(inputs, None, lengths) 20 | 21 | loss = self._seq_loss(predictions, labels) 22 | del predictions 23 | predictions = None 24 | 25 | return loss, predictions 26 | 27 | def get_state(self): 28 | if self.train_loader.dataset.subword: 29 | _vocab = self.train_loader.dataset.subword_path 30 | else: 31 | _vocab = self.train_loader.dataset.vocab 32 | 33 | state = { 34 | "config": self.config, 35 | "epoch": self.epoch, 36 | "step": self.step, 37 | "model": self.model.state_dict(), 38 | "model_class": self.model.__class__.__name__, 39 | "optimizers": [x.state_dict() for x in self.optimizers], 40 | "vocab": _vocab, 41 | } 42 | 43 | return state 44 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.insert(0, '.') 4 | sys.path.insert(0, '..') 5 | sys.path.insert(0, '../../') 6 | sys.path.insert(0, '../../../') 7 | 
-------------------------------------------------------------------------------- /modules/data/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.insert(0, '.') 4 | sys.path.insert(0, '..') 5 | sys.path.insert(0, '../../') 6 | sys.path.insert(0, '../../../') 7 | -------------------------------------------------------------------------------- /modules/data/collates.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | class SeqCollate: 6 | """ 7 | Base Class. 8 | A variant of callate_fn that pads according to the longest sequence in 9 | a batch of sequences 10 | """ 11 | 12 | def __init__(self, sim_len, sent_num, sent_len, sort=False, batch_first=True): 13 | self.sort = sort 14 | self.batch_first = batch_first 15 | self.sim_len = sim_len 16 | self.sent_num = sent_num 17 | self.sent_len = sent_len 18 | 19 | def pad_samples(self, samples): 20 | return pad_sequence([torch.LongTensor(x) for x in samples], self.batch_first) 21 | 22 | def _collate(self, *args): 23 | raise NotImplementedError 24 | 25 | def __call__(self, batch): 26 | batch = list(zip(*batch)) 27 | return self._collate(*batch) 28 | 29 | def pad_dialogues(self, samples): 30 | pad_lst = [0] * self.sent_len 31 | results_lst = [] 32 | for sample in samples: 33 | for index in range(len(sample)): 34 | if len(sample[index]) >= self.sent_len: 35 | sample[index] = sample[index][:self.sent_len] 36 | else: 37 | while len(sample[index]) < self.sent_len: 38 | sample[index].append(0) 39 | if len(sample) >= self.sent_num: 40 | sample = sample[:self.sent_num] 41 | else: 42 | while len(sample) < self.sent_num: 43 | sample.append(pad_lst) 44 | 45 | #results.append(torch.LongTensor([i for item in sample for i in item])) 46 | results_lst.append([i for item in sample for i in item]) 47 | results = torch.LongTensor(results_lst) 48 | return results 
49 | 50 | def pad_sim_sent(self, samples): 51 | results_lst = [] 52 | for sample in samples: 53 | for index in range(len(sample)): 54 | if len(sample[index]) >= self.sim_len: 55 | sample[index] = sample[index][:self.sim_len] 56 | else: 57 | while len(sample[index]) < self.sim_len: 58 | sample[index].append(0) 59 | results_lst.append([i for item in sample for i in item]) 60 | results = torch.LongTensor(results_lst) 61 | return results 62 | 63 | 64 | class LMCollate(SeqCollate): 65 | def __init__(self, *args): 66 | super().__init__(*args) 67 | 68 | def _collate(self, inputs, targets, lengths): 69 | inputs = self.pad_samples(inputs) 70 | targets = self.pad_samples(targets) 71 | lengths = torch.LongTensor(lengths) 72 | return inputs, targets, lengths 73 | 74 | 75 | class CondLMCollate(SeqCollate): 76 | def __init__(self, *args): 77 | super().__init__(*args) 78 | 79 | def _collate(self, inputs, targets, attributes, lengths): 80 | inputs = self.pad_samples(inputs) 81 | targets = self.pad_samples(targets) 82 | attributes = self.pad_samples(attributes) 83 | lengths = torch.LongTensor(lengths) 84 | return inputs, targets, attributes, lengths 85 | 86 | 87 | class Seq2SeqCollate(SeqCollate): 88 | def __init__(self, *args): 89 | super().__init__(*args) 90 | 91 | def _collate(self, inp_src, inp_sim, out_sim, inp_trg, out_trg, sim_len, sent_len, sent_num, len_trg): 92 | inp_src = self.pad_dialogues(inp_src) 93 | inp_sim = self.pad_sim_sent(inp_sim) 94 | inp_trg = self.pad_samples(inp_trg) 95 | out_trg = self.pad_samples(out_trg) 96 | 97 | for index in range(len(sent_len)): 98 | while len(sent_len[index]) < self.sent_num: 99 | sent_len[index].append(5) 100 | 101 | out_sim = torch.LongTensor(out_sim) 102 | sim_len = torch.LongTensor(sim_len) 103 | sent_len = torch.LongTensor(sent_len) 104 | sent_num = torch.LongTensor(sent_num) 105 | len_trg = torch.LongTensor(len_trg) 106 | 107 | return inp_src, inp_sim, out_sim, inp_trg, out_trg, sim_len, sent_len, sent_num, len_trg 108 | 109 | 
class Seq2SeqOOVCollate(SeqCollate):
    """Like Seq2SeqCollate, but each batch also carries a per-sample ``oov_map``."""

    def __init__(self, *args):
        super().__init__(*args)

    def _collate(self, inp_src, inp_sim, out_sim, inp_trg, out_trg, sim_len, sent_len, sent_num, len_trg, oov_map):
        """
        Pad and tensorize one batch; ``oov_map`` is passed through unchanged.

        Each per-sample list in ``sent_len`` is padded up to ``self.sent_num``
        entries without modifying the caller's lists.
        """
        inp_src = self.pad_dialogues(inp_src)
        # pad_sim_sent already returns a torch.LongTensor, so no further
        # conversion is needed.
        inp_sim = self.pad_sim_sent(inp_sim)
        inp_trg = self.pad_samples(inp_trg)
        out_trg = self.pad_samples(out_trg)

        # NOTE(review): the filler length is 10 here but 5 in Seq2SeqCollate
        # -- confirm which value is intended.
        sent_len = [list(lens) + [10] * (self.sent_num - len(lens))
                    for lens in sent_len]

        out_sim = torch.LongTensor(out_sim)
        sim_len = torch.LongTensor(sim_len)
        sent_len = torch.LongTensor(sent_len)
        sent_num = torch.LongTensor(sent_num)
        len_trg = torch.LongTensor(len_trg)

        return inp_src, inp_sim, out_sim, inp_trg, out_trg, sim_len, sent_len, sent_num, len_trg, oov_map

# --------------------------------------------------------------------------------
# /modules/data/datasets.py:
# --------------------------------------------------------------------------------
import os
import codecs
from abc import ABC

from nltk import word_tokenize
from tabulate import tabulate
from torch.utils.data import Dataset

from modules.data.utils import vectorize, read_corpus, read_corpus_subw, \
    unks_per_sample, token_swaps


def _vectorize_with_map(tokens, vocab, oovs):
    """
    Call ``vectorize`` and always return an ``(ids, oov_map)`` pair.

    ``vectorize`` returns a bare id list when ``oovs`` is not positive and an
    ``(ids, oov_map)`` pair otherwise, so unpacking its result directly breaks
    when ``oovs == 0``.
    """
    result = vectorize(tokens, vocab, oovs)
    if oovs > 0:
        return result
    return result, {}


class BaseLMDataset(Dataset, ABC):
    def __init__(self, input, preprocess=None,
                 vocab=None, vocab_size=None,
                 subword=False, subword_path=None, verbose=True, **kwargs):
        """
        Base Dataset for Language Modeling.

        Args:
            input (str, list): the path to the data file, or a list of samples.
            preprocess (callable): preprocessing callable, which takes as input
                a string and returns a list of tokens
            vocab (Vocab): a vocab instance. If None, then build a new one
                from the Datasets data.
            vocab_size(int): if given, then trim the vocab to the given number.
            subword(bool): whether the dataset will be
                tokenized using subword units, using the SentencePiece package.
            subword_path(str): path prefix of the sentencepiece model
                (``.model`` is appended when loading it).
            verbose(bool): print useful statistics about the dataset.
        """
        if isinstance(input, str):
            # remember the path so __str__ can report it
            self._source_path = input
            with codecs.open(input) as f:
                self.input = f.read().split("\n")
            corpus = input
        else:
            # an in-memory list of samples (materialized once so that it can
            # be both stored and tokenized)
            self._source_path = None
            self.input = list(input)
            corpus = self.input

        self.subword = subword
        self.subword_path = subword_path

        if preprocess is not None:
            self.preprocess = preprocess

        # tokenize the dataset
        if self.subword:
            self.vocab, self.data = read_corpus_subw(corpus, subword_path)
        else:
            self.vocab, self.data = read_corpus(corpus, self.preprocess)

        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab.build(vocab_size)

        if verbose:
            print(self)
            print()

    def __str__(self):
        """Return a one-row table of dataset statistics."""
        props = []
        source = getattr(self, "_source_path", None)
        if source is not None:
            props.append(("source", os.path.basename(source)))

        _covarage = unks_per_sample(self.vocab.tok2id.keys(), self.data)
        _covarage = str(_covarage.round(4)) + " %"

        try:
            props.append(("size", len(self)))
        except (TypeError, NotImplementedError):
            # this base class defines no __len__ of its own
            pass
        props.append(("vocab size", len(self.vocab)))
        props.append(("unique tokens", len(self.vocab.vocab)))
        props.append(("UNK per sample", _covarage))
        props.append(("subword", self.subword))

        if hasattr(self, 'seq_len'):
            props.append(("max seq length", self.seq_len))
        if hasattr(self, 'bptt'):
            props.append(("BPTT", self.bptt))
        if hasattr(self, 'attributes'):
            props.append(("attributes", len(self.attributes[0])))

        return tabulate([[x[1] for x in props]], headers=[x[0] for x in props])

    def truncate(self, n):
        """Keep only the first ``n`` samples."""
        self.data = self.data[:n]

    @staticmethod
    def preprocess(text, lower=True):
        """Optionally lowercase ``text`` and tokenize it with NLTK."""
        if lower:
            text = text.lower()
        return word_tokenize(text)


class SentenceLMDataset(BaseLMDataset):
    def __init__(self, *args, seq_len=1000, **kwargs):
        """
        Dataset for sentence-level Language Modeling.

        Extra kwargs: ``sos`` (prepend SOS to every sentence) and ``oovs``
        (number of OOV placeholder slots passed to ``vectorize``).
        """
        super().__init__(*args, **kwargs)
        # todo: find more elegant way to ignore seq_len
        self.seq_len = seq_len
        self.sos = kwargs.get("sos", False)
        self.oovs = kwargs.get("oovs", 0)

        for i in range(self.oovs):
            # NOTE(review): this f-string is empty in this copy, so every
            # iteration adds the same "" token; an index-bearing OOV token
            # name was presumably intended -- confirm.
            self.vocab.add_token(f"")
        print()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return (input ids, next-token target ids, input length)."""
        sentence = self.data[index]
        sentence = sentence + [self.vocab.EOS]

        if self.sos:
            sentence = [self.vocab.SOS] + sentence

        sentence = sentence[:self.seq_len]
        inputs = sentence[:-1]
        targets = sentence[1:]

        length = len(inputs)

        inputs_vec, _ = _vectorize_with_map(inputs, self.vocab, self.oovs)
        targets_vec, _ = _vectorize_with_map(targets, self.vocab, self.oovs)

        assert len(inputs_vec) == len(targets_vec)

        return inputs_vec, targets_vec, length


class AEDataset(BaseLMDataset):
    def __init__(self, *args, seq_len=250, **kwargs):
        """
        Dataset for sequence autoencoder.

        Extra kwargs: ``oovs`` (OOV placeholder slots), ``return_oov`` (append
        the OOV map to every sample) and ``swaps`` (token-swap noise factor
        applied to the encoder input).
        """
        super().__init__(*args, **kwargs)
        # todo: find more elegant way to ignore seq_len
        self.seq_len = seq_len
        self.oovs = kwargs.get("oovs", 0)
        self.return_oov = kwargs.get("return_oov", False)
        self.swaps = kwargs.get("swaps", 0.0)

        for i in range(self.oovs):
            # NOTE(review): empty f-string, see SentenceLMDataset -- confirm.
            self.vocab.add_token(f"")
        print()

    def __len__(self):
        return len(self.data)

    def read_sample(self, index):
        """Return sample ``index`` (SOS/EOS-wrapped) as tokens after OOV mapping."""
        sample = self.data[index][:self.seq_len]
        sample = [self.vocab.SOS] + sample + [self.vocab.EOS]
        sample, _ = _vectorize_with_map(sample, self.vocab, self.oovs)
        return list(map(self.vocab.id2tok.get, sample))

    def __getitem__(self, index):
        """
        Return ``(inp_x, out_x, inp_xhat, out_xhat, len(inp_x), len(inp_xhat))``
        plus the OOV map of ``inp_x`` when ``self.return_oov`` is set.
        """
        inp_x = self.data[index][:self.seq_len]
        out_x = inp_x[1:] + [self.vocab.EOS]

        inp_xhat = [self.vocab.SOS] + self.data[index][:self.seq_len]
        out_xhat = inp_xhat[1:] + [self.vocab.EOS]

        if self.subword:
            raise NotImplementedError

        inp_x, oov_map = _vectorize_with_map(inp_x, self.vocab, self.oovs)
        out_x, _ = _vectorize_with_map(out_x, self.vocab, self.oovs)
        inp_xhat, _ = _vectorize_with_map(inp_xhat, self.vocab, self.oovs)
        out_xhat, _ = _vectorize_with_map(out_xhat, self.vocab, self.oovs)

        # add noise in the form of token swaps ! after the OOV replacements
        inp_x = token_swaps(inp_x, self.swaps)

        sample = inp_x, out_x, inp_xhat, out_xhat, len(inp_x), len(inp_xhat)

        if self.return_oov:
            sample = sample + (oov_map,)

        return sample

# --------------------------------------------------------------------------------
# /modules/data/datasets_ds.py:
# --------------------------------------------------------------------------------
import os
import codecs
import json
import numpy as np
from abc import ABC

from nltk import word_tokenize
from tabulate import tabulate
from torch.utils.data import Dataset

from modules.data.utils import vectorize, read_corpus, read_corpus_subw, \
    unks_per_sample, token_swaps, read_corpus_dialogue, unks_per_sample_dialogue, shuffle


def _vectorize_with_map(tokens, vocab, oovs):
    """
    Call ``vectorize`` and always return an ``(ids, oov_map)`` pair.

    ``vectorize`` returns a bare id list when ``oovs`` is not positive and an
    ``(ids, oov_map)`` pair otherwise, so unpacking its result directly breaks
    when ``oovs == 0``.
    """
    result = vectorize(tokens, vocab, oovs)
    if oovs > 0:
        return result
    return result, {}


class BaseLMDataset(Dataset, ABC):
    def __init__(self, input_path, mode, k, batch_size=None, summary_path=None, nsent_path=None, org_dia_path=None, preprocess=None,
                 vocab=None, vocab_size=None, subword=False, subword_path=None, verbose=True, **kwargs):
        """
        Base Dataset for dialogue summarization.

        Args:
            input_path (str): directory containing one JSON file per dialogue.
            mode (str): ``"test"`` loads via ``load_data_cl_test`` (which also
                writes reference files); any other value uses ``load_data_cl_sec``.
            k (int): stored as ``self.k``; DsDataset uses it as the number of
                0-labelled candidate entries.
            batch_size (int): test mode only; required there to drop the
                trailing partial batch from the written reference files.
            summary_path (str): test mode only; reference-summary output file.
            nsent_path (str): test mode only; ``n_sent`` output file.
            org_dia_path: test mode only; forwarded to ``load_data_cl_test``,
                which does not use it.
            preprocess (callable): preprocessing callable, which takes as input
                a string and returns a list of tokens
            vocab (Vocab): a vocab instance. If None, then build a new one
                from the Datasets data.
            vocab_size(int): if given, then trim the vocab to the given number.
            subword(bool): whether the dataset will be
                tokenized using subword units, using the SentencePiece package.
            subword_path(str): path prefix of the sentencepiece model
            verbose(bool): print useful statistics about the dataset.
        """
        self.mode = mode
        self.batch_size = batch_size
        self.k = k
        if mode != "test":
            self.inputs, self.inputs_sim, self.decoder_inputs = load_data_cl_sec(input_path)
            # AMI variant: load_data_cl_ami(input_path) (not defined in this module)
        else:
            self.inputs, self.inputs_sim, self.decoder_inputs = load_data_cl_test(input_path, batch_size, summary_path, nsent_path, org_dia_path)
            # AMI variant: load_data_cl_test_ami(...) (not defined in this module)
        self.subword = subword
        self.subword_path = subword_path

        if preprocess is not None:
            self.preprocess = preprocess

        # tokenize the dataset; the vocabulary is built from the dialogues only
        if self.subword:
            self.vocab, self.data = read_corpus_subw(self.inputs, subword_path)
        else:
            _, self.decoder_inputs = read_corpus(self.decoder_inputs, self.preprocess)
            self.vocab, self.data = read_corpus_dialogue(self.inputs, self.preprocess)
            _, self.data_sim = read_corpus_dialogue(self.inputs_sim, self.preprocess)

        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab.build(vocab_size)

        if verbose:
            print(self)
            print()

    def __str__(self):
        """Return a one-row table of dataset statistics."""
        props = []
        _covarage = unks_per_sample_dialogue(self.vocab.tok2id.keys(), self.data)
        _covarage = str(_covarage.round(4)) + " %"

        try:
            props.append(("size", len(self)))
        except (TypeError, NotImplementedError):
            # this base class defines no __len__ of its own
            pass
        props.append(("vocab size", len(self.vocab)))
        props.append(("unique tokens", len(self.vocab.vocab)))
        props.append(("UNK per sample", _covarage))
        props.append(("subword", self.subword))

        if hasattr(self, 'seq_len'):
            props.append(("max seq length", self.seq_len))
        if hasattr(self, 'bptt'):
            props.append(("BPTT", self.bptt))
        if hasattr(self, 'attributes'):
            props.append(("attributes", len(self.attributes[0])))

        return tabulate([[x[1] for x in props]], headers=[x[0] for x in props])

    def truncate(self, n):
        """Keep only the first ``n`` dialogues."""
        self.data = self.data[:n]

    @staticmethod
    def preprocess(text, lower=True):
        """Optionally lowercase ``text`` and tokenize it with NLTK."""
        if lower:
            text = text.lower()
        return word_tokenize(text)


class DsDataset(BaseLMDataset):
    def __init__(self, *args, sent_num, sent_len, sent_sim_len, dec_seq_len, **kwargs):
        """
        Dataset for sequence dialogue summarization.

        Args:
            sent_num: max number of dialogue sentences kept per sample.
            sent_len: cap applied to the reported dialogue sentence lengths.
            sent_sim_len: cap applied to the reported candidate sentence lengths.
            dec_seq_len: max number of decoder tokens (before SOS/EOS).
            **kwargs: also read for ``oovs``, ``return_oov`` and ``swaps``.
        """
        super().__init__(*args, **kwargs)
        self.sent_num = sent_num
        self.sent_len = sent_len
        self.sent_sim_len = sent_sim_len
        self.dec_seq_len = dec_seq_len
        self.oovs = kwargs.get("oovs", 0)
        self.return_oov = kwargs.get("return_oov", False)
        self.swaps = kwargs.get("swaps", 0.0)

        for i in range(self.oovs):
            # NOTE(review): empty f-string -- every iteration adds the same ""
            # token; an index-bearing OOV token name was presumably intended.
            self.vocab.add_token(f"")

    def __len__(self):
        return len(self.data)

    def read_sample(self, index):
        """
        Return sample ``index`` (SOS/EOS-wrapped) as tokens after OOV mapping.

        NOTE(review): copied from AEDataset and not adapted: DsDataset never
        sets ``self.seq_len`` (AttributeError), and ``self.data[index]`` is a
        list of sentences, not tokens -- this method looks unusable as written.
        """
        sample = self.data[index][:self.seq_len]
        sample = [self.vocab.SOS] + sample + [self.vocab.EOS]
        sample, _ = _vectorize_with_map(sample, self.vocab, self.oovs)
        return list(map(self.vocab.id2tok.get, sample))

    def __getitem__(self, index):
        """
        Build one example.

        Returns ``(inp_x, inp_sim, out_sim, inp_y, out_y, sim_len,
        dialogue_len, n_sents, len(inp_y))`` plus ``oov_map`` when
        ``self.return_oov`` is set, where:

        - ``inp_x``: the first ``sent_num`` dialogue sentences as id lists;
        - ``inp_sim``/``out_sim``: the stored candidates plus the gold
          ``decoder_inputs[index]`` and their 0/1 labels (1 = gold), shuffled
          together;
        - ``inp_y``/``out_y``: SOS-prefixed decoder input / EOS-suffixed target;
        - ``sim_len``/``dialogue_len``: sentence lengths capped at
          ``sent_sim_len``/``sent_len``;
        - ``n_sents``: number of dialogue sentences, capped at ``sent_num``.
        """
        inp_x = self.data[index][:self.sent_num]
        # Copy: the gold sentence is appended below, and appending to
        # self.data_sim[index] itself grew the stored list on every access.
        inp_sim = list(self.data_sim[index])
        # NOTE(review): shuffle() zips inp_sim with out_sim, so this assumes
        # len(self.data_sim[index]) == self.k; otherwise pairs are truncated
        # and the gold label can be misaligned -- confirm.
        out_sim = [0] * self.k
        inp_y = [self.vocab.SOS] + self.decoder_inputs[index][:self.dec_seq_len]
        out_y = self.decoder_inputs[index][:self.dec_seq_len] + [self.vocab.EOS]

        if self.subword:
            raise NotImplementedError

        inp_x_vec = []
        dialogue_len = []
        for sent in inp_x:
            sent_ids, _ = _vectorize_with_map(sent, self.vocab, self.oovs)
            inp_x_vec.append(sent_ids)
            dialogue_len.append(min(len(sent), self.sent_len))

        # NOTE(review): each sentence above gets its own OOV numbering, while
        # oov_map numbers OOVs over the concatenated dialogue; the two agree
        # only when at most one sentence contains OOV tokens -- confirm.
        flat_x = [tok for sent in inp_x for tok in sent]
        _, oov_map = _vectorize_with_map(flat_x, self.vocab, self.oovs)

        inp_sim.append(self.decoder_inputs[index])
        out_sim.append(1)
        inp_sim, out_sim = shuffle(inp_sim, out_sim)

        inp_sim_vec = []
        sim_len = []
        for sent in inp_sim:
            sent_ids, _ = _vectorize_with_map(sent, self.vocab, self.oovs)
            inp_sim_vec.append(sent_ids)
            sim_len.append(min(len(sent), self.sent_sim_len))

        inp_y, _ = _vectorize_with_map(inp_y, self.vocab, self.oovs)
        out_y, _ = _vectorize_with_map(out_y, self.vocab, self.oovs)

        sample = (inp_x_vec, inp_sim_vec, out_sim, inp_y, out_y, sim_len,
                  dialogue_len, min(len(inp_x_vec), self.sent_num), len(inp_y))

        if self.return_oov:
            sample = sample + (oov_map,)

        return sample


def load_data_cl_sec(path):
    """
    Load training examples from every JSON file in directory ``path``.

    Each file holds ``section_num`` entries ``"section<i>"`` (each with
    ``dialogue``, ``n_sent`` and ``n_sent_sim``) plus top-level ``dialogue``,
    ``n_sent`` and ``n_sent_sim``. Every section, and the whole dialogue,
    becomes one example; dialogues with fewer than 10 entries are skipped.
    Files that fail to parse are reported by name and skipped.

    Returns:
        (inputs, inputs_sim, decoder_inputs): dialogues, candidate lists (at
        most 6 per section, 3 for the whole dialogue) and ``n_sent`` targets.
    """
    inputs = []
    inputs_sim = []
    decoder_inputs = []

    for file_name in os.listdir(path):
        with open(os.path.join(path, file_name), 'r') as load_f:
            try:
                file_dict = json.load(load_f)
            except Exception:
                # best effort: report and skip unreadable files
                print(file_name)
                continue

        for i in range(file_dict["section_num"]):
            sec_tmp = file_dict["section" + str(i)]
            sec_dialogue = sec_tmp["dialogue"]
            sec_n_sent = sec_tmp["n_sent"]
            sec_sim = sec_tmp["n_sent_sim"][:6]

            if len(sec_dialogue) < 10:
                continue

            inputs.append(sec_dialogue)
            inputs_sim.append(sec_sim)
            decoder_inputs.append(sec_n_sent)

        sec_dialogue = file_dict["dialogue"]
        sec_n_sent = file_dict["n_sent"]
        sec_sim = file_dict["n_sent_sim"][:3]

        if len(sec_dialogue) < 10:
            continue

        inputs.append(sec_dialogue)
        inputs_sim.append(sec_sim)
        decoder_inputs.append(sec_n_sent)

    return inputs, inputs_sim, decoder_inputs


def load_data_cl_test(path, batch_size, summary_path, nsent_path, org_dia_path):
    """
    Load test examples from every JSON file in directory ``path``.

    Files missing any of ``unsupervised_dialogue``, ``n_sent_sim``, ``n_sent``
    or ``summary`` (or failing to parse) are reported by name and skipped, as
    are dialogues with fewer than 10 entries.

    Side effects: writes one reference summary per line to ``summary_path``
    and one ``n_sent`` string per line to ``nsent_path``, for every example
    except the trailing ``len(inputs) % batch_size`` ones.
    ``org_dia_path`` is accepted but not used.

    Returns:
        (inputs, inputs_sim, decoder_inputs) for all loaded examples.
    """
    inputs = []
    inputs_sim = []
    decoder_inputs = []
    summaries = []

    for file_name in os.listdir(path):
        with open(os.path.join(path, file_name), 'r') as load_f:
            try:
                file_dict = json.load(load_f)
                input_dialogue = file_dict['unsupervised_dialogue']
                input_sim = file_dict["n_sent_sim"][:3]
                decoder_sentence = file_dict['n_sent']
                summary = file_dict['summary']
            except Exception:
                # best effort: report and skip malformed files
                print(file_name)
                continue

        if len(input_dialogue) < 10:
            continue

        inputs.append(input_dialogue)
        inputs_sim.append(input_sim)
        decoder_inputs.append(decoder_sentence)
        summaries.append(summary)

    # Drop the trailing partial batch. An explicit end index is used because
    # ``x[:-abandon_sample]`` selects nothing when the remainder is 0.
    abandon_sample = len(inputs) % batch_size
    keep = len(inputs) - abandon_sample

    with codecs.open(summary_path, 'w') as summary_file:
        for sumf in summaries[:keep]:
            summary_file.write(sumf + "\n")

    with codecs.open(nsent_path, 'w') as nsent_file:
        for nsent_tmp in decoder_inputs[:keep]:
            nsent_file.write(nsent_tmp + "\n")

    return inputs, inputs_sim, decoder_inputs

# --------------------------------------------------------------------------------
# /modules/data/samplers.py:
# --------------------------------------------------------------------------------
import math

import numpy
import torch
from torch.utils.data import Sampler


def divide_chunks(l, n):
    """Yield consecutive slices of ``l`` of length ``n`` (the last may be shorter)."""
    for i in range(0, len(l), n):
        yield l[i:i + n]


class SortedSampler(Sampler):
    """
    Defines a strategy for drawing samples from the dataset,
    in ascending or descending order, based in the sample lengths.
    """

    def __init__(self, lengths, descending=False):
        self.lengths = lengths
        self.desc = descending

    def __iter__(self):
        order = numpy.array(self.lengths).argsort()
        if self.desc:
            order = numpy.flip(order, 0)
        return iter(order)

    def __len__(self):
        return len(self.lengths)


class BucketBatchSampler(Sampler):
    """
    Defines a strategy for drawing batches of samples from the dataset,
    in ascending or descending order, based in the sample lengths.

    With ``even=True`` batches are consecutive chunks of exactly
    ``batch_size`` sorted indices (last may be smaller); otherwise the sorted
    indices are split into ``ceil(len / batch_size)`` near-equal sections.
    ``drop_last`` removes the final batch of the (possibly reversed) list.
    """

    def __init__(self, lengths, batch_size,
                 shuffle=False, even=False, drop_last=False, reverse=False):
        sorted_indices = numpy.array(lengths).argsort()
        num_sections = math.ceil(len(lengths) / batch_size)
        if even:
            self.batches = list(divide_chunks(sorted_indices, batch_size))
        else:
            self.batches = numpy.array_split(sorted_indices, num_sections)

        if reverse:
            self.batches = list(reversed(self.batches))

        if drop_last:
            del self.batches[-1]

        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            return iter(self.batches[i]
                        for i in torch.randperm(len(self.batches)))
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

# --------------------------------------------------------------------------------
# /modules/data/utils.py:
# --------------------------------------------------------------------------------
import collections
import functools
import hashlib
import inspect
import os
import pickle
import random
from subprocess import check_output

import numpy
import sentencepiece as spm
from matplotlib import pyplot as plt
from tqdm import tqdm

from modules.data.vocab import Vocab
from sys_config import BASE_DIR


def shuffle(lst1, lst2):
    """
    Shuffle two sequences with the same random permutation.

    Pairs are formed with ``zip``, so the result is truncated to the shorter
    input. Returns two tuples (two empty tuples for empty input, which
    previously raised ValueError).
    """
    pairs = list(zip(lst1, lst2))
    if not pairs:
        return (), ()
    random.shuffle(pairs)
    lst1, lst2 = zip(*pairs)
    return lst1, lst2


def vectorize(tokens, vocab, oovs=0):
    """
    Convert a list of tokens to a list of ids.

    In-vocabulary tokens map to their id. The first ``oovs`` distinct unknown
    tokens are assigned OOV placeholder entries (remembered in ``oov2tok``);
    any further unknown tokens map to ``vocab.UNK``.

    Args:
        tokens (list): list of tokens
        vocab (Vocab): the vocabulary (must already contain the placeholders)
        oovs (int): number of OOV placeholder slots

    Returns:
        ``(ids, oov2tok)`` when ``oovs > 0``, otherwise just ``ids``.
    """
    ids = []
    oov2tok = {}
    tok2oov = {}
    for token in tokens:

        # if token in the vocabulary, add its token id
        if token in vocab.tok2id:
            ids.append(vocab.tok2id[token])

        # if an OOV token has already been encountered, use its token_id
        elif token in tok2oov:
            ids.append(vocab.tok2id[tok2oov[token]])

        # if this is a new OOV token, add it to oov2tok and use its token_id
        elif oovs > len(oov2tok):
            # NOTE(review): this f-string is empty in this copy, so all OOV
            # placeholders share one name; an index-bearing name was
            # presumably intended -- confirm.
            _oov = f""
            ids.append(vocab.tok2id[_oov])
            oov2tok[_oov] = token
            tok2oov[token] = _oov

        # if the OOV token exceed our limit, use the generic UNK token
        else:
            ids.append(vocab.tok2id[vocab.UNK])
    if oovs > 0:
        return ids, oov2tok
    else:
        return ids


def wc(filename):
    """Return the line count of ``filename`` as reported by ``wc -l``."""
    return int(check_output(["wc", "-l", filename]).split()[0])


def args_to_str(args):
    """Stringify arguments for cache keys; callables contribute their source."""
    _str = []
    for x in args:
        if callable(x):
            _str.append(inspect.getsource(x))
        else:
            _str.append(str(x))
    return _str


def disk_memoize(func):
    """
    Cache ``func``'s results on disk, pickled under ``BASE_DIR/_cache``.

    The key is an md5 of the stringified positional arguments, followed by
    ``name=value`` entries for keyword arguments (only when any are given, so
    keys for positional-only calls are unchanged).

    Security: cache files are loaded with ``pickle`` -- only use this with a
    trusted cache directory.
    """
    cache_dir = os.path.join(BASE_DIR, "_cache")
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    @functools.wraps(func)
    def wrapper_decorator(*args, **kwargs):
        key_parts = args_to_str(args)
        if kwargs:
            names = sorted(kwargs)
            values = args_to_str([kwargs[name] for name in names])
            key_parts += [f"{name}={value}" for name, value in zip(names, values)]
        args_str = ''.join(key_parts)
        key = hashlib.md5(args_str.encode()).hexdigest()
        cache_file = os.path.join(cache_dir, key)

        if os.path.exists(cache_file):
            print(f"Loading {cache_file} from cache!")
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        else:
            print(f"No cache file for {cache_file}...")
            data = func(*args, **kwargs)

            with open(cache_file, 'wb') as pickle_file:
                pickle.dump(data, pickle_file)

            return data

    return wrapper_decorator


def iterate_data(data):
    """
    Yield the items of ``data``.

    If ``data`` is a path, yield its non-blank lines (newline included) with a
    progress bar; if it is any other iterable, yield its elements; otherwise
    yield nothing.
    """
    # ``collections.Iterable`` was removed in Python 3.10
    from collections.abc import Iterable

    if isinstance(data, str):
        assert os.path.exists(data), f"path `{data}` does not exist!"
        with open(data, "r") as f:
            for line in tqdm(f, total=wc(data), desc=f"Reading {data}..."):
                if len(line.strip()) > 0:
                    yield line

    elif isinstance(data, Iterable):
        for x in data:
            yield x


# @disk_memoize
def read_corpus(file, tokenize):
    """
    Split each line of ``file`` (path or iterable of strings) on single spaces.

    ``tokenize`` is accepted but not used. A trailing newline is stripped
    first, so the last token of a file line no longer carries ``"\\n"``.

    Returns:
        (Vocab with counted tokens, list of token lists)
    """
    _vocab = Vocab()

    _data = []
    for line in iterate_data(file):
        tokens = line.rstrip("\n").split(" ")
        _vocab.read_sequence(tokens)
        _data.append(tokens)

    return _vocab, _data


def read_corpus_dialogue(file, tokenize):
    """
    Like ``read_corpus`` for dialogues: each item of ``file`` is a sequence of
    sentence strings, each split on single spaces. ``tokenize`` is unused.

    Returns:
        (Vocab with counted tokens, list of dialogues as lists of token lists)
    """
    _vocab = Vocab()
    _data = []

    for line in iterate_data(file):
        dia = []
        for sent in line:
            tokens = sent.split(" ")
            _vocab.read_sequence(tokens)
            dia.append(tokens)
        _data.append(dia)

    return _vocab, _data


# @disk_memoize
def build_vocab_from_file(file, tokenize):
    """Count the ``tokenize``-d tokens of every item of ``file`` into a Vocab."""
    _vocab = Vocab()

    for line in iterate_data(file):
        tokens = tokenize(line)
        _vocab.read_sequence(tokens)

    return _vocab


# @disk_memoize
def read_corpus_subw(file, subword_path):
    """
    Tokenize ``file`` with the SentencePiece model ``subword_path + ".model"``.

    The vocabulary is loaded from ``subword_path`` (skipping 4 lines) and the
    processor is attached as ``vocab.subword``.
    """
    subword = spm.SentencePieceProcessor()
    subword.Load(subword_path + ".model")

    vocab = Vocab(sos="", eos="", unk="")
    vocab.from_file(subword_path, skip=4)

    _data = []
    for line in iterate_data(file):
        tokens = subword.EncodeAsPieces(line.rstrip().encode('utf-8'))
        _data.append(tokens)

    vocab.subword = subword

    return vocab, _data


def hist_dataset(data, seq_len):
    """Plot a histogram of sample lengths with a dashed line at ``seq_len``."""
    lengths = [len(x) for x in data]
    plt.hist(lengths, density=1, bins=20)
    plt.axvline(seq_len, color='k', linestyle='dashed', linewidth=1)
    plt.show()


def covarage(vocab, top_n):
    """Fraction of all token occurrences covered by the ``top_n`` most common."""
    occurences = [freq for tok, freq in vocab.most_common()]
    total = sum(occurences)
    cov = sum(occurences[:top_n]) / total
    return cov


def unks_per_sample_dialogue(keys, data):
    """
    Mean number of distinct unknown tokens per dialogue, multiplied by 100.

    NOTE(review): unlike ``unks_per_sample`` this is a count, not a fraction,
    so the ``* 100`` and the "%" label used by callers look off -- confirm.
    """
    known = set(keys)
    _coverage = []
    for exm in data:
        tokens = [tok for sent in exm for tok in sent]
        _coverage.append(len(set(tokens) - known))
    return numpy.mean(_coverage) * 100


def unks_per_sample(keys, data):
    """Mean percentage of distinct unknown tokens per sample (relative to its length)."""
    known = set(keys)
    _coverage = [len(set(x) - known) / len(x) for x in data]
    return numpy.mean(_coverage) * 100


def token_shuffle(words, factor):
    """Swap ``int(len * factor)`` random token pairs (samples shorter than 5 are kept)."""
    words = list(words)
    length = len(words)
    shuffles = int(length * factor)

    if len(words) < 5:
        return words

    for _ in range(shuffles):
        i, j = (int(random.random() * length) for _ in range(2))
        words[i], words[j] = words[j], words[i]
    return words


def token_swaps(words, factor):
    """Swap ``int(len * factor)`` random adjacent token pairs (samples shorter than 4 are kept)."""
    if not factor > 0:
        return words

    words = list(words)
    length = len(words)
    shuffles = int(length * factor)

    if len(words) < 4:
        return words

    for it in range(shuffles):
        j = random.randint(0, length - 2)
        words[j], words[j + 1] = words[j + 1], words[j]

    return words

# --------------------------------------------------------------------------------
# /modules/data/vocab.py:
# --------------------------------------------------------------------------------
from collections import Counter
2 | 3 | import numpy 4 | from gensim.models import FastText 5 | from tqdm import tqdm 6 | 7 | from utils.load_embeddings import load_word_vectors 8 | 9 | 10 | class Vocab(object): 11 | """ 12 | The Vocab Class, holds the vocabulary of a corpus and 13 | mappings from tokens to indices and vice versa. 14 | """ 15 | 16 | def __init__(self, pad="", sos="", eos="", unk="", 17 | oovs=0): 18 | self.PAD = pad 19 | self.SOS = sos 20 | self.EOS = eos 21 | self.UNK = unk 22 | self.oovs = oovs 23 | 24 | self.vocab = Counter() 25 | 26 | self.tok2id = dict() 27 | self.id2tok = dict() 28 | 29 | self.size = 0 30 | 31 | self.subword = None 32 | 33 | def read_sequence(self, tokens): 34 | self.vocab.update(tokens) 35 | 36 | def trim(self, size): 37 | self.tok2id = dict() 38 | self.id2tok = dict() 39 | self.build(size) 40 | 41 | def read_embeddings(self, file, dim): 42 | """ 43 | Create an Embeddings Matrix, in which each row corresponds to 44 | the word vector from the pretrained word embeddings. 45 | If a word is missing from the provided pretrained word vectors, then 46 | sample a new embedding, from the gaussian of the pretrained embeddings. 
47 | 48 | Args: 49 | file: 50 | dim: 51 | 52 | Returns: 53 | 54 | """ 55 | word2idx, idx2word, embeddings = load_word_vectors(file, dim) 56 | 57 | mu = embeddings.mean(axis=0) 58 | sigma = embeddings.std(axis=0) 59 | 60 | filtered_embeddings = numpy.zeros((len(self), embeddings.shape[1])) 61 | 62 | mask = numpy.zeros(len(self)) 63 | missing = [] 64 | 65 | for token_id, token in tqdm(self.id2tok.items(), 66 | desc="Reading embeddings...", 67 | total=len(self.id2tok.items())): 68 | if token not in word2idx or token == "": 69 | # todo: smart sampling per dim distribution 70 | # sample = numpy.random.uniform(low=-0.5, high=0.5, 71 | # size=embeddings.shape[1]) 72 | sample = numpy.random.normal(mu, sigma / 4) 73 | filtered_embeddings[token_id] = sample 74 | 75 | mask[token_id] = 1 76 | missing.append(token_id) 77 | else: 78 | filtered_embeddings[token_id] = embeddings[word2idx[token]] 79 | 80 | print(f"Missing tokens from the pretrained embeddings: {len(missing)}") 81 | 82 | return filtered_embeddings, mask, missing 83 | 84 | def read_fasttext(self, file): 85 | """ 86 | Create an Embeddings Matrix, in which each row corresponds to 87 | the word vector from the pretrained word embeddings. 88 | If a word is missing then obtain a representation on-the-fly 89 | using fasttext. 
90 | 91 | Args: 92 | file: 93 | dim: 94 | 95 | Returns: 96 | 97 | """ 98 | model = FastText.load_fasttext_format(file) 99 | 100 | embeddings = numpy.zeros((len(self), model.vector_size)) 101 | 102 | missing = [] 103 | 104 | for token_id, token in tqdm(self.id2tok.items(), 105 | desc="Reading embeddings...", 106 | total=len(self.id2tok.items())): 107 | if token not in model.wv.vocab: 108 | missing.append(token) 109 | embeddings[token_id] = model[token] 110 | 111 | print(f"Missing tokens from the pretrained embeddings: {len(missing)}") 112 | 113 | return embeddings, missing 114 | 115 | def add_token(self, token): 116 | index = len(self.tok2id) 117 | 118 | if token not in self.tok2id: 119 | self.tok2id[token] = index 120 | self.id2tok[index] = token 121 | self.size = len(self) 122 | 123 | def __add_special_tokens(self): 124 | self.add_token(self.PAD) 125 | self.add_token(self.SOS) 126 | self.add_token(self.EOS) 127 | self.add_token(self.UNK) 128 | 129 | def from_file(self, file, skip=0): 130 | self.__add_special_tokens() 131 | 132 | lines = open(file).readlines()[skip:] 133 | for line in lines: 134 | token = line.split()[0] 135 | self.add_token(token) 136 | 137 | def to_file(self, file): 138 | with open(file, "w") as f: 139 | f.write("\n".join(self.tok2id.keys())) 140 | 141 | def is_corrupt(self): 142 | return len([tok for tok, index in self.tok2id.items() 143 | if self.id2tok[index] != tok]) > 0 144 | 145 | def get_tokens(self): 146 | return [self.id2tok[key] for key in sorted(self.id2tok.keys())] 147 | 148 | def build(self, size=None): 149 | self.__add_special_tokens() 150 | 151 | for w, k in self.vocab.most_common(size): 152 | self.add_token(w) 153 | 154 | def __len__(self): 155 | return len(self.tok2id) 156 | -------------------------------------------------------------------------------- /modules/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from numpy import mean 3 | from torch.nn import 
functional as F 4 | from torch.nn.functional import _gumbel_softmax_sample 5 | 6 | 7 | def sequence_mask(lengths, max_len=None): 8 | """ 9 | Creates a boolean mask from sequence lengths. 10 | """ 11 | batch_size = lengths.numel() 12 | max_len = max_len or lengths.max() 13 | return (torch.arange(0, max_len, device=lengths.device) 14 | .type_as(lengths) 15 | .unsqueeze(0).expand(batch_size, max_len) 16 | .lt(lengths.unsqueeze(1))) 17 | 18 | 19 | def masked_normalization(logits, mask): 20 | scores = F.softmax(logits, dim=-1) 21 | 22 | # apply the mask - zero out masked timesteps 23 | masked_scores = scores * mask.float() 24 | 25 | # re-normalize the masked scores 26 | normed_scores = masked_scores.div(masked_scores.sum(-1, keepdim=True)) 27 | 28 | return normed_scores 29 | 30 | 31 | def masked_mean(vecs, mask): 32 | masked_vecs = vecs * mask.float() 33 | 34 | mean = masked_vecs.sum(1) / mask.sum(1) 35 | 36 | return mean 37 | 38 | 39 | def masked_normalization_inf(logits, mask): 40 | logits.masked_fill_(1 - mask, float('-inf')) 41 | # energies.masked_fill_(1 - mask, -1e18) 42 | 43 | scores = F.softmax(logits, dim=-1) 44 | 45 | return scores 46 | 47 | 48 | def expected_vecs(dists, vecs): 49 | flat_probs = dists.contiguous().view(dists.size(0) * dists.size(1), 50 | dists.size(2)) 51 | flat_embs = flat_probs.mm(vecs) 52 | embs = flat_embs.view(dists.size(0), dists.size(1), flat_embs.size(1)) 53 | return embs 54 | 55 | 56 | def straight_softmax(logits, tau=1, hard=False, target_mask=None): 57 | y_soft = F.softmax(logits.squeeze() / tau, dim=1) 58 | 59 | if target_mask is not None: 60 | y_soft = y_soft * target_mask.float() 61 | y_soft.div(y_soft.sum(-1, keepdim=True)) 62 | 63 | if hard: 64 | shape = logits.size() 65 | _, k = y_soft.max(-1) 66 | y_hard = logits.new_zeros(*shape).scatter_(-1, k.view(-1, 1), 1.0) 67 | y = y_hard - y_soft.detach() + y_soft 68 | return y 69 | else: 70 | return y_soft 71 | 72 | 73 | def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, 
target_mask=None): 74 | r""" 75 | Sample from the Gumbel-Softmax distribution and optionally discretize. 76 | 77 | Args: 78 | logits: `[batch_size, num_features]` unnormalized log probabilities 79 | tau: non-negative scalar temperature 80 | hard: if ``True``, the returned samples will be discretized as one-hot vectors, 81 | but will be differentiated as if it is the soft sample in autograd 82 | 83 | Returns: 84 | Sampled tensor of shape ``batch_size x num_features`` from the Gumbel-Softmax distribution. 85 | If ``hard=True``, the returned samples will be one-hot, otherwise they will 86 | be probability distributions that sum to 1 across features 87 | 88 | Constraints: 89 | 90 | - Currently only work on 2D input :attr:`logits` tensor of shape ``batch_size x num_features`` 91 | 92 | Based on 93 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 94 | (MIT license) 95 | """ 96 | shape = logits.size() 97 | assert len(shape) == 2 98 | y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps) 99 | 100 | if target_mask is not None: 101 | y_soft = y_soft * target_mask.float() 102 | y_soft.div(y_soft.sum(-1, keepdim=True)) 103 | 104 | if hard: 105 | _, k = y_soft.max(-1) 106 | # this bit is based on 107 | # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5 108 | y_hard = logits.new_zeros(*shape).scatter_(-1, k.view(-1, 1), 1.0) 109 | # this cool bit of code achieves two things: 110 | # - makes the output value exactly one-hot (since we add then 111 | # subtract y_soft value) 112 | # - makes the gradient equal to y_soft gradient (since we strip 113 | # all other gradients) 114 | y = y_hard - y_soft.detach() + y_soft 115 | else: 116 | y = y_soft 117 | return y 118 | 119 | 120 | def avg_vectors(vectors, mask, energies=None): 121 | if energies is None: 122 | centroid = masked_mean(vectors, mask) 123 | return centroid, None 124 | 125 | else: 126 | masked_scores = energies * mask.float() 127 | 
normed_scores = masked_scores.div(masked_scores.sum(1, keepdim=True)) 128 | centroid = (vectors * normed_scores).sum(1) 129 | return centroid, normed_scores 130 | 131 | 132 | def aeq(*args): 133 | """ 134 | Assert all arguments have the same value 135 | """ 136 | arguments = (arg for arg in args) 137 | first = next(arguments) 138 | assert all(arg == first for arg in arguments), \ 139 | "Not all arguments have the same value: " + str(args) 140 | 141 | 142 | def module_grad_wrt_loss(optimizers, module, loss, prefix=None): 143 | loss.backward(retain_graph=True) 144 | 145 | grad_norms = [(n, p.grad.norm().item()) 146 | for n, p in module.named_parameters()] 147 | 148 | if prefix is not None: 149 | grad_norms = [g for g in grad_norms if g[0].startswith(prefix)] 150 | 151 | mean_norm = mean([gn for n, gn in grad_norms]) 152 | 153 | for optimizer in optimizers: 154 | optimizer.zero_grad() 155 | 156 | return mean_norm 157 | 158 | 159 | def index_mask(mask_row, mask_col, index): 160 | A = torch.zeros((mask_row, mask_col)) 161 | B = index.float() 162 | AA = 1 - A 163 | seq_len = torch.sum(AA, dim=-1) 164 | word_offset = torch.cumsum(seq_len, dim=0).cuda() 165 | BB = B + word_offset.unsqueeze(dim=-1) 166 | 167 | flag_A = A.view(-1) 168 | flag_BB = BB.view(-1).long() 169 | flag_BB = torch.sub(flag_BB, mask_col) 170 | flag_A[flag_BB] = 1 171 | A = flag_A.view(mask_row, -1) 172 | 173 | return A 174 | 175 | 176 | def kl_categorical(p_logit, q_logit): 177 | p = F.softmax(p_logit, dim=-1) 178 | _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1) - F.log_softmax(q_logit, dim=-1)), 1) 179 | return torch.mean(_kl) 180 | -------------------------------------------------------------------------------- /modules/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | 5 | from modules.helpers import sequence_mask, masked_normalization_inf 6 | 7 | 8 | class 
GaussianNoise(nn.Module): 9 | def __init__(self, stddev, mean=.0): 10 | """ 11 | Additive Gaussian Noise layer 12 | Args: 13 | stddev (float): the standard deviation of the distribution 14 | mean (float): the mean of the distribution 15 | """ 16 | super().__init__() 17 | self.stddev = stddev 18 | self.mean = mean 19 | 20 | def forward(self, x): 21 | if self.training: 22 | # todo data_bug 23 | noise = Variable(x.data.new(x.size()).normal_(self.mean, 24 | self.stddev)) 25 | return x + noise 26 | return x 27 | 28 | def __repr__(self): 29 | return '{} (mean={}, stddev={})'.format(self.__class__.__name__, 30 | str(self.mean), 31 | str(self.stddev)) 32 | 33 | 34 | class Embed(nn.Module): 35 | def __init__(self, 36 | num_embeddings, 37 | embedding_dim, 38 | embeddings=None, 39 | noise=.0, 40 | dropout=.0, 41 | trainable=True, grad_mask=None, norm=False): 42 | """ 43 | Define the layer of the model and perform the initializations 44 | of the layers (wherever it is necessary) 45 | Args: 46 | embeddings (numpy.ndarray): the 2D ndarray with the word vectors 47 | noise (float): 48 | dropout (float): 49 | trainable (bool): 50 | """ 51 | super(Embed, self).__init__() 52 | 53 | self.norm = norm 54 | 55 | # define the embedding layer, with the corresponding dimensions 56 | self.embedding = nn.Embedding(num_embeddings=num_embeddings, 57 | embedding_dim=embedding_dim) 58 | 59 | # initialize the weights of the Embedding layer, 60 | # with the given pre-trained word vectors 61 | if embeddings is not None: 62 | print("Initializing Embedding layer with pre-trained weights!") 63 | self.init_embeddings(embeddings, trainable) 64 | 65 | # the dropout "layer" for the word embeddings 66 | self.dropout = nn.Dropout(dropout) 67 | 68 | # the gaussian noise "layer" for the word embeddings 69 | self.noise = GaussianNoise(noise) 70 | 71 | self.grad_mask = grad_mask 72 | 73 | if self.norm: 74 | self.layer_norm = nn.LayerNorm(embedding_dim) 75 | 76 | if self.grad_mask is not None: 77 | 
self.set_grad_mask(self.grad_mask) 78 | 79 | def _emb_hook(self, grad): 80 | return grad * Variable(self.grad_mask.unsqueeze(1)).type_as(grad) 81 | 82 | def set_grad_mask(self, mask): 83 | self.grad_mask = torch.from_numpy(mask) 84 | self.embedding.weight.register_hook(self._emb_hook) 85 | 86 | def init_embeddings(self, weights, trainable): 87 | self.embedding.weight = nn.Parameter(torch.from_numpy(weights), 88 | requires_grad=trainable) 89 | 90 | def regularize(self, embeddings): 91 | if self.noise.stddev > 0: 92 | embeddings = self.noise(embeddings) 93 | 94 | if self.dropout.p > 0: 95 | embeddings = self.dropout(embeddings) 96 | 97 | return embeddings 98 | 99 | def expectation(self, dists): 100 | """ 101 | Obtain a weighted sum (expectation) of all the embeddings, from a 102 | given probability distribution. 103 | 104 | """ 105 | flat_probs = dists.contiguous().view(dists.size(0) * dists.size(1), dists.size(2)) 106 | flat_embs = flat_probs.mm(self.embedding.weight) 107 | embs = flat_embs.view(dists.size(0), dists.size(1), flat_embs.size(1)) 108 | 109 | # apply layer normalization on the expectation 110 | if self.norm: 111 | embs = self.layer_norm(embs) 112 | 113 | # apply all embedding layer's regularizations 114 | embs = self.regularize(embs) 115 | 116 | return embs 117 | 118 | def forward(self, x): 119 | """ 120 | This is the heart of the model. This function, defines how the data 121 | passes through the network. 
122 | Args: 123 | x (): the input data (the sentences) 124 | 125 | Returns: the logits for each class 126 | 127 | """ 128 | embeddings = self.embedding(x) 129 | 130 | if self.norm: 131 | embeddings = self.layer_norm(embeddings) 132 | 133 | embeddings = self.regularize(embeddings) 134 | 135 | return embeddings 136 | 137 | 138 | class Attention(nn.Module): 139 | def __init__(self, 140 | input_size, 141 | context_size, 142 | batch_first=True, 143 | non_linearity="tanh", 144 | method="general", 145 | coverage=False): 146 | super(Attention, self).__init__() 147 | 148 | self.batch_first = batch_first 149 | self.method = method 150 | self.coverage = coverage 151 | 152 | if self.method not in ["dot", "general", "concat", "additive"]: 153 | raise ValueError("Please select a valid attention type.") 154 | 155 | if self.coverage: 156 | self.W_c = nn.Linear(1, context_size, bias=False) 157 | self.method = "additive" 158 | 159 | if non_linearity == "relu": 160 | self.activation = nn.ReLU() 161 | else: 162 | self.activation = nn.Tanh() 163 | 164 | if self.method == "general": 165 | self.W_h = nn.Linear(input_size, context_size) 166 | 167 | elif self.method == "additive": 168 | self.W_h = nn.Linear(input_size, context_size) 169 | self.W_s = nn.Linear(context_size, context_size) 170 | self.W_v = nn.Linear(context_size, 1) 171 | 172 | elif self.method == "concat": 173 | self.W_h = nn.Linear(input_size + context_size, context_size) 174 | self.W_v = nn.Linear(context_size, 1) 175 | 176 | def score(self, sequence, query, coverage=None): 177 | batch_size, max_length, feat_size = sequence.size() 178 | 179 | if self.method == "dot": 180 | energies = torch.matmul(sequence, query.unsqueeze(2)).squeeze(2) 181 | 182 | elif self.method == "additive": 183 | enc = self.W_h(sequence) 184 | dec = self.W_s(query) 185 | sums = enc + dec.unsqueeze(1) 186 | 187 | if self.coverage: 188 | cov = self.W_c(coverage.unsqueeze(-1)) 189 | sums = sums + cov 190 | 191 | energies = 
self.W_v(self.activation(sums)).squeeze(2) 192 | 193 | elif self.method == "general": 194 | h = self.W_h(sequence) 195 | energies = torch.matmul(h, query.unsqueeze(2)).squeeze(2) 196 | 197 | elif self.method == "concat": 198 | c = query.unsqueeze(1).expand(-1, max_length, -1) 199 | u = self.W_h(torch.cat([sequence, c], -1)) 200 | energies = self.W_v(self.activation(u)).squeeze(2) 201 | 202 | else: 203 | raise ValueError 204 | 205 | return energies 206 | 207 | def forward(self, sequence, query, lengths, coverage=None): 208 | 209 | energies = self.score(sequence, query, coverage) 210 | 211 | # construct a mask, based on sentence lengths 212 | mask = sequence_mask(lengths, energies.size(1)) 213 | 214 | scores = masked_normalization_inf(energies, mask) 215 | # scores = self.masked_normalization(energies, mask) 216 | 217 | contexts = (sequence * scores.unsqueeze(-1)).sum(1) 218 | 219 | return contexts, scores 220 | -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from modules.modules import RecurrentHelper, AttSeqDecoder, SeqReader 6 | from modules.helpers import sequence_mask, avg_vectors, index_mask 7 | 8 | 9 | class Seq2Seq2Seq(nn.Module, RecurrentHelper): 10 | 11 | def __init__(self, n_tokens, **kwargs): 12 | super(Seq2Seq2Seq, self).__init__() 13 | 14 | ############################################ 15 | # Attributes 16 | ############################################ 17 | self.n_tokens = n_tokens 18 | self.bridge_hidden = kwargs.get("bridge_hidden", False) 19 | self.bridge_non_linearity = kwargs.get("bridge_non_linearity", None) 20 | self.detach_hidden = kwargs.get("detach_hidden", False) 21 | self.input_feeding = kwargs.get("input_feeding", False) 22 | self.length_control = kwargs.get("length_control", False) 23 | self.bi_encoder = 
kwargs.get("rnn_bidirectional", False) 24 | self.rnn_type = kwargs.get("rnn_type", "LSTM") 25 | self.layer_norm = kwargs.get("layer_norm", False) 26 | self.sos = kwargs.get("sos", 1) 27 | self.sample_embed_noise = kwargs.get("sample_embed_noise", 0) 28 | self.topic_idf = kwargs.get("topic_idf", False) 29 | self.dec_token_dropout = kwargs.get("dec_token_dropout", .0) 30 | self.enc_token_dropout = kwargs.get("enc_token_dropout", .0) 31 | 32 | self.batch_size = kwargs.get("batch_size") 33 | self.sent_num = kwargs.get("sent_num") 34 | self.sent_len = kwargs.get("sent_len") 35 | 36 | # tie embedding layers to output layers (vocabulary projections) 37 | kwargs["tie_weights"] = kwargs.get("tie_embedding_outputs", False) 38 | 39 | ############################################ 40 | # Layers 41 | ############################################ 42 | 43 | # backward-compatibility for older version of the project 44 | kwargs["rnn_size"] = kwargs.get("enc_rnn_size", kwargs.get("rnn_size")) 45 | self.inp_encoder = SeqReader(self.n_tokens, **kwargs) 46 | enc_size = self.inp_encoder.rnn_size 47 | self.sent_classification = torch.nn.Linear(enc_size, 1) 48 | self.sent_similar = torch.nn.Linear(enc_size*2, 1) 49 | 50 | # backward-compatibility for older version of the project 51 | kwargs["rnn_size"] = kwargs.get("dec_rnn_size", kwargs.get("rnn_size")) 52 | self.dia_nsent = AttSeqDecoder(self.n_tokens, enc_size, **kwargs) 53 | self.sum_nsent = AttSeqDecoder(self.n_tokens, enc_size, **kwargs) 54 | 55 | # create a dummy embedding layer, which will retrieve the idf values 56 | # of each word, given the word ids 57 | if self.topic_idf: 58 | self.idf = nn.Embedding(num_embeddings=n_tokens, embedding_dim=1) 59 | self.idf.weight.requires_grad = False 60 | 61 | if self.bridge_hidden: 62 | self._initialize_bridge(enc_size, 63 | kwargs["dec_rnn_size"], 64 | kwargs["rnn_layers"]) 65 | 66 | def _initialize_bridge(self, enc_hidden_size, dec_hidden_size, num_layers): 67 | """ 68 | adapted from 69 | 
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/encoders/rnn_encoder.py#L85 70 | """ 71 | 72 | # LSTM has hidden and cell state, other only one 73 | number_of_states = 2 if self.rnn_type == "LSTM" else 1 74 | 75 | if self.length_control: 76 | # add a parameter, for scaling the absolute target length 77 | self.Wl = nn.Parameter(torch.rand(1)) 78 | # the length information will contain 2 additional dimensions, 79 | # - the target length 80 | # - the expansion / compression ratio given the source length 81 | enc_hidden_size += 2 82 | 83 | # Build a linear layer for each 84 | self.src_bridge = nn.ModuleList([nn.Linear(enc_hidden_size, 85 | dec_hidden_size) 86 | for _ in range(number_of_states)]) 87 | self.trg_bridge = nn.ModuleList([nn.Linear(enc_hidden_size, 88 | dec_hidden_size) 89 | for _ in range(number_of_states)]) 90 | 91 | def _bridge(self, bridge, hidden, src_lengths=None, trg_lengths=None): 92 | """Forward hidden state through bridge.""" 93 | 94 | def _fix_hidden(_hidden): 95 | # The encoder hidden is (layers*directions) x batch x dim. 96 | # We need to convert it to layers x batch x (directions*dim). 
97 | fwd_final = _hidden[0:_hidden.size(0):2] 98 | bwd_final = _hidden[1:_hidden.size(0):2] 99 | final = torch.cat([fwd_final, bwd_final], dim=2) 100 | return final 101 | 102 | def bottle_hidden(linear, states, length_feats=None): 103 | if length_feats is not None: 104 | lf = length_feats.unsqueeze(0).repeat(states.size(0), 1, 1) 105 | _states = torch.cat([states, lf], -1) 106 | result = linear(_states) 107 | else: 108 | result = linear(states) 109 | 110 | if self.bridge_non_linearity == "tanh": 111 | result = torch.tanh(result) 112 | elif self.bridge_non_linearity == "relu": 113 | result = F.relu(result) 114 | 115 | return result 116 | 117 | if self.length_control: 118 | ratio = trg_lengths.float() / src_lengths.float() 119 | lengths = trg_lengths.float() * self.Wl 120 | L = torch.stack([ratio, lengths], -1) 121 | else: 122 | L = None 123 | 124 | if isinstance(hidden, tuple): # LSTM 125 | # concat directions 126 | hidden = tuple(_fix_hidden(h) for h in hidden) 127 | outs = tuple([bottle_hidden(state, hidden[ix], L) 128 | for ix, state in enumerate(bridge)]) 129 | else: 130 | outs = bottle_hidden(bridge[0], hidden) 131 | 132 | return outs 133 | 134 | def initialize_embeddings(self, embs, trainable=False): 135 | 136 | freeze = not trainable 137 | 138 | embeddings = torch.from_numpy(embs).float() 139 | embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze) 140 | 141 | self.inp_encoder.embed.embedding = embedding_layer 142 | self.cmp_encoder.embed.embedding = embedding_layer 143 | self.compressor.embed.embedding = embedding_layer 144 | self.decompressor.embed.embedding = embedding_layer 145 | self.original_task.embed.embedding = embedding_layer 146 | 147 | def initialize_embeddings_idf(self, idf): 148 | idf_embs = torch.from_numpy(idf).float().unsqueeze(-1) 149 | self.idf = nn.Embedding.from_pretrained(idf_embs, freeze=True) 150 | 151 | def set_embedding_gradient_mask(self, mask): 152 | self.inp_encoder.embed.set_grad_mask(mask) 153 | 
self.cmp_encoder.embed.set_grad_mask(mask) 154 | self.compressor.embed.set_grad_mask(mask) 155 | self.decompressor.embed.set_grad_mask(mask) 156 | self.original_task.embed.set_grad_mask(mask) 157 | 158 | def _fake_inputs(self, inputs, latent_lengths, pad=1): 159 | batch_size, seq_len = inputs.size() 160 | 161 | if latent_lengths is not None: 162 | max_length = max(latent_lengths) 163 | else: 164 | max_length = seq_len + pad 165 | 166 | fakes = torch.zeros(batch_size, max_length, device=inputs.device) 167 | fakes = fakes.type_as(inputs) 168 | fakes[:, 0] = self.sos 169 | return fakes 170 | 171 | #def generate(self, inputs, nsent, src_lengths, trg_seq_len, nsent_len, sampling): 172 | def generate(self, inputs, src_lengths, trg_seq_len): 173 | # ENCODER1 174 | enc1_results = self.inp_encoder(inputs, None, src_lengths) 175 | outs_enc1, hn_enc1 = enc1_results[-2:] 176 | 177 | # DECODER1 178 | dec_init = self._bridge(self.src_bridge, hn_enc1, src_lengths, trg_seq_len) 179 | inp_fake = self._fake_inputs(inputs, trg_seq_len) 180 | dec1_results = self.compressor(inp_fake, outs_enc1, dec_init, 181 | argmax=True, 182 | enc_lengths=src_lengths, 183 | sampling_prob=1., 184 | desired_lengths=trg_seq_len) 185 | 186 | return enc1_results, dec1_results 187 | 188 | def summary(self, inp_src, sent_len, sent_num): 189 | imp_src_org = inp_src.view(self.batch_size, self.sent_num, self.sent_len) 190 | inp_src = imp_src_org.view(imp_src_org.size(0) * imp_src_org.size(1), imp_src_org.size(2)) 191 | inp_length = imp_src_org.size(2) 192 | sent_len = sent_len.view(self.batch_size * self.sent_num) 193 | enc1_results = self.inp_encoder(inp_src, None, sent_len, word_dropout=self.enc_token_dropout) 194 | outs_enc, hn_enc = enc1_results[-2:] 195 | 196 | sent_len_mask = torch.unsqueeze(sequence_mask(sent_len, max_len=self.sent_len), -1).float() 197 | outs_enc = torch.mul(outs_enc, sent_len_mask) 198 | outs_enc = outs_enc.view(self.batch_size, self.sent_num, self.sent_len, -1) 199 | outs_enc = 
torch.sum(outs_enc, dim=2) 200 | 201 | sent_num_mask = torch.unsqueeze(sequence_mask(sent_num, max_len=self.sent_num), -1).float() 202 | sent_sum_prb = self.sent_classification(outs_enc) 203 | sent_sum_prb = nn.functional.softmax(sent_sum_prb, dim=1) 204 | sent_sum_prb = torch.mul(sent_sum_prb, sent_num_mask) 205 | """_, top_k_index = torch.topk(sent_sum_prb, k=k, dim=1) 206 | top_k_index = torch.squeeze(top_k_index) 207 | top_k_mask = index_mask(self.batch_size, self.sent_num, top_k_index) 208 | top_k_mask = torch.unsqueeze(top_k_mask, dim=-1).cuda() 209 | outs_enc_filter = outs_enc.mul(top_k_mask)""" 210 | 211 | return sent_sum_prb 212 | 213 | def forward(self, k, inp_src, inp_sim, inp_trg, sim_len, sent_len, sent_num, trg_lengths): 214 | """ 215 | enc1------------------>dec1 216 | | | 217 | | | 218 | summary------>enc2---->dec2 219 | 220 | (extrative-based summarization) 221 | 222 | inp_src: input source (batch x sent_num x sent_len) 223 | inp_sim: k similar sentences to nth sentence (batch x k x sim_len) 224 | inp_trg: input nsent (batch x nsent_len) 225 | sim_len: length of each sentence of similar sentences 226 | sent_len: length of each sentence in a dialogue 227 | sent_num: sentence number in a dialogue 228 | trg_lenghts: nth sentence length 229 | """ 230 | 231 | # -------------------------------------------- 232 | # ENCODER (encode each sentence) 233 | # -------------------------------------------- 234 | # encode dialogue 235 | imp_src_org = inp_src.view(self.batch_size, self.sent_num, self.sent_len) 236 | inp_src = imp_src_org.view(imp_src_org.size(0) * imp_src_org.size(1), imp_src_org.size(2)) 237 | inp_length = imp_src_org.size(2) 238 | sent_len = sent_len.view(self.batch_size * self.sent_num) 239 | enc1_results = self.inp_encoder(inp_src, None, sent_len, word_dropout=self.enc_token_dropout) 240 | outs_enc, hn_enc = enc1_results[-2:] 241 | 242 | sent_len_mask = torch.unsqueeze(sequence_mask(sent_len, max_len=self.sent_len), -1).float() 243 | outs_enc = 
torch.mul(outs_enc, sent_len_mask) 244 | outs_enc = outs_enc.view(self.batch_size, self.sent_num, self.sent_len, -1) 245 | outs_enc = torch.sum(outs_enc, dim=2) 246 | 247 | sent_num_mask = torch.unsqueeze(sequence_mask(sent_num, max_len=self.sent_num), -1).float() 248 | sent_sum_prb = self.sent_classification(outs_enc) 249 | sent_sum_prb = nn.functional.softmax(sent_sum_prb, dim=1) 250 | sent_sum_prb = torch.mul(sent_sum_prb, sent_num_mask) 251 | _, top_k_index = torch.topk(sent_sum_prb, k=k, dim=1) 252 | top_k_index = torch.squeeze(top_k_index) 253 | top_k_mask = index_mask(self.batch_size, self.sent_num, top_k_index) 254 | top_k_mask = torch.unsqueeze(top_k_mask, dim=-1).cuda() 255 | outs_enc_filter = outs_enc.mul(top_k_mask) 256 | 257 | # encode k similar sentences to nth sentence 258 | k_num = sim_len.size(1) 259 | inp_sim = inp_sim.view(self.batch_size, k_num, -1) 260 | inp_sim = inp_sim.view(self.batch_size * k_num, -1) 261 | sim_len = sim_len.view(self.batch_size * k_num) 262 | enc2_results = self.inp_encoder(inp_sim, None, sim_len, word_dropout=self.enc_token_dropout) 263 | outs_enc_sim, hn_enc_sim = enc2_results[-2:] 264 | outs_enc_sim = torch.sum(outs_enc_sim, dim=1) 265 | outs_enc_sim = outs_enc_sim.view(self.batch_size, k_num, -1) 266 | 267 | ## initiate decoder 268 | hn_enc_rst = [] 269 | for index, hn_emc_tmp in enumerate(hn_enc): 270 | hn_emc_tmp = hn_emc_tmp.chunk(self.batch_size, dim=1) 271 | rst = [] 272 | for _, sample in enumerate(hn_emc_tmp): 273 | sample = torch.sum(sample, dim=1) 274 | rst.append(sample) 275 | hn_emc_tmp = torch.stack(rst, dim=1) 276 | hn_enc_rst.append(hn_emc_tmp) 277 | hn_enc = tuple(hn_enc_rst) 278 | 279 | sent_len= sent_len.view(self.batch_size, self.sent_num) 280 | _dec_init = self._bridge(self.src_bridge, hn_enc, sent_num, trg_lengths) 281 | 282 | # ------------------------------------------------------------- 283 | # DECODER-1 (generate nth sentence based on original dialogue) 284 | # 
------------------------------------------------------------- 285 | dec1_results = self.dia_nsent(inp_trg, outs_enc, _dec_init, 286 | enc_lengths=sent_num, 287 | sampling_prob=1., 288 | desired_lengths=trg_lengths) 289 | 290 | # -------------------------------------------------- 291 | # DECODER-2 (generate nth sentence based on summary) 292 | # -------------------------------------------------- 293 | dec2_results = self.sum_nsent(inp_trg, outs_enc_filter, _dec_init, 294 | enc_lengths=sent_num, 295 | sampling_prob=1., 296 | desired_lengths=trg_lengths) 297 | 298 | # -------------------------------------------------- 299 | # Predict similar sentences 300 | # -------------------------------------------------- 301 | outs_enc_pre = torch.unsqueeze(torch.sum(outs_enc, dim=1), dim=1) 302 | outs_enc_filter_pre = torch.unsqueeze(torch.sum(outs_enc_filter, dim=1), dim=1) 303 | 304 | outs_enc_pre = outs_enc_pre.expand(outs_enc_sim.size(0), outs_enc_sim.size(1), outs_enc_sim.size(2)) 305 | outs_enc_filter_pre = outs_enc_filter_pre.expand(outs_enc_sim.size(0), outs_enc_sim.size(1), outs_enc_sim.size(2)) 306 | outs_enc_pre = torch.cat((outs_enc_pre, outs_enc_sim), dim=-1) 307 | outs_enc_filter_pre = torch.cat((outs_enc_filter_pre, outs_enc_sim), dim=-1) 308 | 309 | outs_enc_pre = self.sent_similar(outs_enc_pre) 310 | outs_enc_filter_pre = self.sent_similar(outs_enc_filter_pre) 311 | 312 | dialog_pre = torch.squeeze(outs_enc_pre, dim=-1) 313 | summary_pre = torch.squeeze(outs_enc_filter_pre, dim=-1) 314 | 315 | return sent_sum_prb, outs_enc, outs_enc_filter, dec1_results, dec2_results, sent_len, dialog_pre, summary_pre 316 | -------------------------------------------------------------------------------- /modules/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiyan524/RepSum/cddc47188e445ccdc23b30e9d8d5f2daa16b7c1d/modules/training/__init__.py 
-------------------------------------------------------------------------------- /modules/training/base_trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy 4 | import torch 5 | 6 | class BaseTrainer: 7 | def __init__(self, train_loader, valid_loader, 8 | config, device, 9 | batch_end_callbacks=None, loss_weights=None, 10 | parallel=False, 11 | **kwargs): 12 | 13 | self.train_loader = train_loader 14 | self.valid_loader = valid_loader 15 | self.device = device 16 | self.loss_weights = loss_weights 17 | 18 | self.config = config 19 | self.log_interval = self.config["log_interval"] 20 | self.batch_size = self.config["batch_size"] 21 | self.checkpoint_interval = self.config["checkpoint_interval"] 22 | self.clip = self.config["model"]["clip"] 23 | 24 | if batch_end_callbacks is None: 25 | self.batch_end_callbacks = [] 26 | else: 27 | self.batch_end_callbacks = [c for c in batch_end_callbacks if callable(c)] 28 | 29 | self.epoch = 0 30 | self.step = 0 31 | self.progress_log = None 32 | 33 | # init dataset 34 | self.train_set_size = self._get_dataset_size(self.train_loader) 35 | self.val_set_size = self._get_dataset_size(self.valid_loader) 36 | 37 | self.n_batches = math.ceil( 38 | float(self.train_set_size) / self.batch_size) 39 | self.total_steps = self.n_batches * self.config["epochs"] 40 | 41 | if self.loss_weights is not None: 42 | self.loss_weights = [self.anneal_init(w) for w in 43 | self.loss_weights] 44 | 45 | @staticmethod 46 | def _roll_seq(x, dim=1, shift=1): 47 | length = x.size(dim) - shift 48 | 49 | seq = torch.cat([x.narrow(dim, shift, length), 50 | torch.zeros_like(x[:, :1])], dim) 51 | 52 | return seq 53 | 54 | @staticmethod 55 | def _get_dataset_size(loader): 56 | """ 57 | If the trainer holds multiple datasets, then the size 58 | is estimated based on the largest one. 
59 | """ 60 | if isinstance(loader, (tuple, list)): 61 | return len(loader[0].dataset) 62 | else: 63 | return len(loader.dataset) 64 | 65 | def anneal_init(self, param, steps=None): 66 | if isinstance(param, list): 67 | if steps is None: 68 | steps = self.total_steps 69 | return numpy.geomspace(param[0], param[1], num=steps).tolist() 70 | else: 71 | return param 72 | 73 | def anneal_step(self, param): 74 | if isinstance(param, list): 75 | try: 76 | _val = param[self.step] 77 | except: 78 | _val = param[-1] 79 | else: 80 | _val = param 81 | 82 | return _val 83 | 84 | def _tensors_to_device(self, batch): 85 | """batch_trans = [] 86 | for sample in batch: 87 | batch_trans.append([i for item in sample for i in item]) 88 | res_tmp = list(map(lambda x: x.to(self.device), batch_trans))""" 89 | return list(map(lambda x: x.to(self.device), batch)) 90 | 91 | def _batch_to_device(self, batch): 92 | 93 | if torch.is_tensor(batch[0]): 94 | batch = self._tensors_to_device(batch) 95 | else: 96 | batch = list(map(lambda x: self._tensors_to_device(x), batch)) 97 | 98 | return batch 99 | 100 | @staticmethod 101 | def _multi_dataset_iter(loader, strategy, step=1): 102 | # todo: generalize to N datasets. For now works only with 2. 
103 | sizes = [len(x) for x in loader] 104 | 105 | iter_a = iter(loader[0]) 106 | iter_b = iter(loader[1]) 107 | 108 | if strategy == "spread": 109 | step = math.floor((sizes[0] - sizes[1]) / (sizes[1] - 1)) 110 | 111 | for i in range(max(sizes)): 112 | if i % (step + 1) == 0: 113 | batch_a = next(iter_a) 114 | batch_b = next(iter_b, None) 115 | 116 | if batch_b is not None: 117 | yield batch_a, batch_b 118 | else: 119 | yield batch_a 120 | else: 121 | yield next(iter_a) 122 | 123 | if strategy == "modulo": 124 | for i in range(max(sizes)): 125 | if i % step == 0: 126 | batch_a = next(iter_a) 127 | batch_b = next(iter_b, None) 128 | 129 | if batch_b is None: # reset iterator b 130 | iter_b = iter(loader[1]) 131 | batch_b = next(iter_b, None) 132 | 133 | yield batch_a, batch_b 134 | else: 135 | yield next(iter_a) 136 | 137 | elif strategy == "cycle": 138 | for i in range(max(sizes)): 139 | batch_a = next(iter_a) 140 | batch_b = next(iter_b, None) 141 | 142 | if batch_b is None: # reset iterator b 143 | iter_b = iter(loader[1]) 144 | batch_b = next(iter_b, None) 145 | 146 | yield batch_a, batch_b 147 | 148 | elif strategy == "beginning": 149 | for i in range(max(sizes)): 150 | batch_a = next(iter_a) 151 | batch_b = next(iter_b, None) 152 | 153 | if batch_b is not None: 154 | yield batch_a, batch_b 155 | else: 156 | yield batch_a 157 | else: 158 | raise ValueError("Invalid iteration strategy!") 159 | 160 | def _dataset_iterator(self, loader, strategy=None, step=1): 161 | # if all datasets have the same size 162 | if isinstance(loader, (tuple, list)): 163 | if len(set(len(x) for x in loader)) == 1: 164 | return zip(*loader) 165 | else: 166 | return self._multi_dataset_iter(loader, strategy, step) 167 | else: 168 | return loader 169 | 170 | def _aggregate_losses(self, batch_losses, loss_weights=None): 171 | """ 172 | This function computes a weighted sum of the models losses 173 | Args: 174 | batch_losses(torch.Tensor, tuple): 175 | 176 | Returns: 177 | loss_sum (int): 
the aggregation of the constituent losses 178 | loss_list (list, int): the constituent losses 179 | 180 | """ 181 | if isinstance(batch_losses, (tuple, list)): 182 | 183 | if loss_weights is None: 184 | loss_weights = self.loss_weights 185 | loss_weights = [self.anneal_step(w) for w in loss_weights] 186 | 187 | if loss_weights is None: 188 | loss_sum = sum(batch_losses) 189 | loss_list = [x.item() for x in batch_losses] 190 | else: 191 | loss_sum = sum(w * x for x, w in 192 | zip(batch_losses, loss_weights)) 193 | 194 | loss_list = [w * x.item() for x, w in 195 | zip(batch_losses, loss_weights)] 196 | else: 197 | loss_sum = batch_losses 198 | loss_list = batch_losses.item() 199 | return loss_sum, loss_list 200 | -------------------------------------------------------------------------------- /modules/training/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy 4 | import torch 5 | from torch.nn.utils import clip_grad_norm_ 6 | 7 | from modules.training.base_trainer import BaseTrainer 8 | from utils._logging import epoch_progress 9 | from utils.training import save_checkpoint 10 | 11 | 12 | class Trainer(BaseTrainer): 13 | """ 14 | An abstract class representing a Trainer. 15 | A Trainer object, is responsible for handling the training process and 16 | provides various helper methods. 17 | 18 | All other trainers should subclass it. 19 | All subclasses should override process_batch, which handles the way 20 | you feed the input data to the model and performs a forward pass. 
21 | """ 22 | 23 | def __init__(self, model, train_loader, valid_loader, criterion, 24 | optimizers, config, device, 25 | batch_end_callbacks=None, loss_weights=None, **kwargs): 26 | 27 | super().__init__(train_loader, valid_loader, config, device, 28 | batch_end_callbacks, loss_weights, **kwargs) 29 | 30 | self.model = model 31 | self.criterion = criterion 32 | self.optimizers = optimizers 33 | 34 | if not isinstance(self.optimizers, (tuple, list)): # a single optimizer is accepted too; below it is always iterated as a list 35 | self.optimizers = [self.optimizers] 36 | 37 | def _process_batch(self, *args): # abstract: subclasses return (batch_losses, batch_outputs), as unpacked by train_epoch/eval_epoch 38 | raise NotImplementedError 39 | 40 | def _seq_loss(self, logits, labels): 41 | 42 | """ 43 | Compute a sequence loss (i.e. per timestep). 44 | Used for tasks such as Translation, Language Modeling and 45 | Sequence Labelling. 46 | """ 47 | _logits = logits.contiguous().view(-1, logits.size(-1)) # merge all leading dims into one, keeping the last (presumably class/vocab) dim 48 | _labels = labels.contiguous().view(-1) # flatten labels; assumes one label per row of _logits 49 | loss = self.criterion(_logits, _labels) 50 | 51 | return loss 52 | 53 | def grads(self): 54 | """ 55 | Get the list of the norms of the gradients for each parameter, as (name, norm) pairs; frozen parameters and parameters without a grad are skipped 56 | """ 57 | return [(name, parameter.grad.norm().item()) 58 | for name, parameter in self.model.named_parameters() 59 | if parameter.requires_grad and parameter.grad is not None] 60 | 61 | def train_epoch(self, pre_train_epoch, batch_num, writer): # NOTE(review): pre_train_epoch and batch_num are not used in this body 62 | """ 63 | Train the network for one epoch and return the average loss. 64 | * This will be a pessimistic approximation of the true loss 65 | of the network, as the loss of the first batches will be higher 66 | than the true.
67 | Args: writer: logger exposing add_scalar(tag, value, step) and flush() (presumably a TensorBoard SummaryWriter); pre_train_epoch, batch_num: unused. 68 | Returns: 69 | loss (float, list(float)): list of mean losses 70 | 71 | """ 72 | self.model.train() 73 | losses = [] 74 | 75 | self.epoch += 1 76 | epoch_start = time.time() 77 | 78 | iterator = self._dataset_iterator(self.train_loader) 79 | for i_batch, batch in enumerate(iterator, 1): 80 | 81 | self.step += 1 # NOTE(review): incremented before the size check below, so skipped batches still advance the step 82 | 83 | # zero gradients 84 | for optimizer in self.optimizers: 85 | optimizer.zero_grad() 86 | 87 | batch = self._batch_to_device(batch) 88 | if batch[0].size(0) != self.batch_size: # skip batches whose leading dim is not batch_size (presumably the final, partial batch) 89 | continue 90 | 91 | # return here only the first batch losses, in order to avoid 92 | # breaking the existing framework 93 | # pre-train enc1-dec3 by self-supervised training 94 | batch_losses, batch_outputs = self._process_batch(*batch) 95 | 96 | # aggregate the losses into a single loss value 97 | loss_sum, loss_list = self._aggregate_losses(batch_losses) 98 | losses.append(loss_list) 99 | writer.add_scalar('Train/loss', loss_sum, self.step) 100 | loss_count = 0 # index into loss_list; assumes its entries follow the order of the config flags tested below (TODO confirm against _process_batch) 101 | writer.add_scalar('Train/nsent1_loss', loss_list[loss_count], self.step) 102 | if self.config["model"]["n_sent_sum_loss"]: 103 | loss_count += 1 104 | writer.add_scalar('Train/nsent2_loss', loss_list[loss_count], self.step) 105 | if self.config["model"]["prior_loss"]: 106 | loss_count += 1 107 | writer.add_scalar('Train/lm_loss', loss_list[loss_count], self.step) 108 | if self.config["model"]["topic_loss"]: 109 | loss_count += 1 110 | writer.add_scalar('Train/topic_loss', loss_list[loss_count], self.step) 111 | if self.config["model"]["length_loss"]: 112 | loss_count += 1 113 | writer.add_scalar('Train/length_loss', loss_list[loss_count], self.step) 114 | if self.config["model"]["doc_sum_kl_loss"]: 115 | loss_count += 1 116 | writer.add_scalar('Train/kl_loss', loss_list[loss_count], self.step) 117 | if self.config["model"]["doc_sum_sim_loss"]: 118 | loss_count += 1 119 | writer.add_scalar('Train/doc_sim_loss', loss_list[loss_count], self.step) 120 | if self.config["model"]["sum_loss"]: 121 |
loss_count += 1 122 | writer.add_scalar('Train/sum_loss', loss_list[loss_count], self.step) 123 | if self.config["model"]["nsent_classification"]: 124 | loss_count += 1 125 | writer.add_scalar('Train/cls_loss', loss_list[loss_count], self.step) 126 | if self.config["model"]["nsent_classification_sum"]: 127 | loss_count += 1 128 | writer.add_scalar('Train/cls_sum_loss', loss_list[loss_count], self.step) 129 | if self.config["model"]["nsent_classification_kl"]: 130 | loss_count += 1 131 | writer.add_scalar('Train/cla_kl_loss', loss_list[loss_count], self.step) 132 | writer.flush() # NOTE(review): flushes on every step, which may be slow 133 | 134 | # back-propagate 135 | loss_sum.backward() 136 | 137 | if self.clip is not None: 138 | # clip_grad_norm_(self.model.parameters(), self.clip) 139 | for optimizer in self.optimizers: # clip the norm jointly over all param groups of each optimizer 140 | clip_grad_norm_((p for group in optimizer.param_groups 141 | for p in group['params']), self.clip) 142 | 143 | # update weights 144 | for optimizer in self.optimizers: 145 | optimizer.step() 146 | 147 | if self.step % self.log_interval == 0: 148 | self.progress_log = epoch_progress(self.epoch, i_batch, 149 | self.batch_size, 150 | self.train_set_size, 151 | epoch_start) 152 | 153 | for c in self.batch_end_callbacks: 154 | if callable(c): 155 | c(batch, losses, loss_list, batch_outputs, self.epoch) 156 | try: 157 | return numpy.array(losses).mean(axis=0) 158 | except: # parallel losses. NOTE(review): bare except; presumably reached when per-batch loss lists differ in length, in which case each is truncated to len(self.loss_weights) - 1 entries 159 | return numpy.array([x[:len(self.loss_weights) - 1] 160 | for x in losses]).mean(axis=0) 161 | 162 | def eval_epoch(self): 163 | """ 164 | Evaluate the network for one epoch and return the average loss.
165 | Runs under torch.no_grad() and, unlike train_epoch, keeps batches of any size. 166 | Returns: 167 | loss (float, list(float)): list of mean losses 168 | 169 | """ 170 | self.model.eval() 171 | losses = [] 172 | 173 | iterator = self._dataset_iterator(self.valid_loader) 174 | with torch.no_grad(): 175 | for i_batch, batch in enumerate(iterator, 1): 176 | batch = self._batch_to_device(batch) 177 | 178 | batch_losses, batch_outputs = self._process_batch(*batch) 179 | 180 | # aggregate the losses into a single loss value 181 | loss, _losses = self._aggregate_losses(batch_losses) # only the per-loss list is averaged; the aggregate is unused 182 | losses.append(_losses) 183 | 184 | return numpy.array(losses).mean(axis=0) 185 | 186 | def get_state(self): 187 | """ 188 | Return a dictionary with the current state of the model. 189 | The state should contain all the important properties which will 190 | be saved when taking a model checkpoint. 191 | Returns: 192 | state (dict) 193 | 194 | """ 195 | state = { 196 | "config": self.config, 197 | "epoch": self.epoch, 198 | "step": self.step, 199 | "model": self.model.state_dict(), 200 | "model_class": self.model.__class__.__name__, 201 | "optimizers": [x.state_dict() for x in self.optimizers], 202 | } 203 | 204 | return state 205 | 206 | def checkpoint(self, name=None, timestamp=False, tags=None, verbose=False): # save get_state() via save_checkpoint; name defaults to config["name"] and tags is passed on as tag= 207 | 208 | if name is None: 209 | name = self.config["name"] 210 | 211 | return save_checkpoint(self.get_state(), 212 | name=name, tag=tags, timestamp=timestamp, 213 | verbose=verbose) 214 | -------------------------------------------------------------------------------- /mylogger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiyan524/RepSum/cddc47188e445ccdc23b30e9d8d5f2daa16b7c1d/mylogger/__init__.py -------------------------------------------------------------------------------- /mylogger/attention.py: -------------------------------------------------------------------------------- 1 | import html 2 | import os 3 | 4 | import numpy 5 | import numpy as np # NOTE(review): numpy is imported twice; this file uses both the numpy and the np name 6 | 7 |
def viz_sequence(words, scores=None, color="255, 0, 0"): # render the HTML-escaped words as one string; scores (zeros when omitted) are normalized first 9 | text = [] 10 | 11 | if scores is None: 12 | scores = [0] * len(words) 13 | else: 14 | # mean = numpy.mean(scores) 15 | # std = numpy.std(scores) 16 | # scores = [(x - mean) / (6 * std) for x in scores] 17 | 18 | length = len([x for x in scores if x != 0]) # NOTE(review): scores[:length] below assumes the non-zero scores form a prefix (zero padding at the end); confirm 19 | scores = [x / sum(scores) for x in scores] # NOTE(review): sum() is recomputed for every element, and an all-zero input divides by zero 20 | mean = numpy.mean(scores[:length]) 21 | std = numpy.std(scores[:length]) 22 | 23 | scores = [max(0, (x - mean) / (6 * std)) for x in scores] # NOTE(review): std == 0 (e.g. a single non-zero score) divides by zero 24 | 25 | # score = (score - this.att_mean) / (4 * this.att_std); 26 | for word, score in zip(words, scores): # NOTE(review): score and color are unused in the visible append and line 28 is missing; the per-word markup seems stripped from this copy, confirm against upstream 27 | text.append(f"{html.escape(word)}") 29 | return "".join(text) 30 | 31 | 32 | def viz_summary(seqs): # one rendered row per (name, data, color) triple; data is a word list or a (words, scores) tuple 33 | txt = "" 34 | for name, data, color in seqs: 35 | if isinstance(data, tuple): 36 | _text = viz_sequence(data[0], data[1], color=color) 37 | length = len(data[0]) 38 | else: 39 | _text = viz_sequence(data) 40 | length = len(data) 41 | 42 | txt += f"
{name}({length}): {_text}
" 43 | 44 | return f"
{txt}
" 45 | 46 | 47 | def sample(words): # random weights, one per word, drawn from a flat Dirichlet so they sum to 1 48 | return np.random.dirichlet(np.ones(len(words))) 49 | 50 | 51 | def samples2dom(samples): # NOTE(review): the template text inside the triple-quoted strings appears stripped in this copy; confirm against upstream 52 | dom = """ 53 | 54 | 55 | 56 | 57 | 79 | 80 | """ 81 | for s in samples: 82 | dom += viz_summary(s) 83 | 84 | dom += """ 85 | 86 | 87 | """ 88 | return dom 89 | 90 | 91 | def samples2html(samples): # same loop as samples2dom with a different template; NOTE(review): template text appears stripped in this copy 92 | dom = """ 93 | 123 |
124 | """ 125 | 126 | for s in samples: 127 | dom += viz_summary(s) 128 | 129 | dom += """ 130 |
131 | """ 132 | return dom 133 | 134 | 135 | def viz_seq3(dom): # write dom to attention.html next to this module, overwriting it 136 | # or simply save in an html file and open in browser 137 | file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 138 | 'attention.html') 139 | 140 | with open(file, 'w') as f: # NOTE(review): no encoding given, so the platform default is used 141 | f.write(dom) 142 | 143 | # samples = [] 144 | # for i in range(10): 145 | # source = lorem.sentence().split() 146 | # scores = sample(source) 147 | # summary = lorem.sentence().split() 148 | # reconstruction = lorem.sentence().split() 149 | # samples.append(((source, scores), summary, reconstruction)) 150 | # viz_seq3(samples2html(samples)) 151 | -------------------------------------------------------------------------------- /mylogger/db.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiyan524/RepSum/cddc47188e445ccdc23b30e9d8d5f2daa16b7c1d/mylogger/db.json -------------------------------------------------------------------------------- /mylogger/experiment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import sys 5 | import time 6 | from collections import defaultdict 7 | from datetime import datetime 8 | 9 | from pymongo import MongoClient 10 | from tabulate import tabulate 11 | 12 | from mylogger.helpers import dict_to_html, files_to_dict 13 | from mylogger.plotting import Visualizer 14 | from sys_config import VIS, BASE_DIR 15 | 16 | 17 | class Experiment(object): 18 | """ 19 | Experiment class: tracks metrics (histories of values) and values (latest state), optionally plots them with visdom, and saves its state (save() writes a pickle) 20 | """ 21 | 22 | def __init__(self, name, config, desc=None, 23 | output_dir=None, 24 | src_dirs=None, 25 | use_db=True, 26 | db_host="localhost", 27 | db_port=27017, 28 | db_uri=None, 29 | db_name="experiments"): 30 | """ 31 | 32 | Metrics = history of values 33 | Values = state of values 34 | Args: 35 | name: experiment name; also the visdom env and the prefix of the .vis/.json/.pickle files 36 | config: configuration dict, shown as the "config" text value 37 | desc: free-text description, stored on the instance 38 | output_dir: directory for output files; defaults to BASE_DIR 39 | src_dirs: directories whose *.py files are captured with files_to_dict 40 | use_db: connect to MongoDB (only when visualization is also enabled) 41 | db_host: MongoDB host, used when db_uri is not given 42 | db_port: MongoDB port, used when db_uri is not given 43 | db_uri:
mongodb://[username:password@]host1[:port1] 44 | db_name: MongoDB database name 45 | """ 46 | self.name = name 47 | self.desc = desc 48 | self.config = config 49 | self.metrics = defaultdict(Metric) # NOTE(review): Metric() needs key and vis_type, so reading an unregistered key raises TypeError rather than creating an entry 50 | self.values = defaultdict(Value) # NOTE(review): same for Value 51 | 52 | self.use_db = use_db 53 | self.db_host = db_host 54 | self.db_port = db_port 55 | self.db_uri = db_uri 56 | self.db_name = db_name 57 | 58 | # the src files (dirs) to backup 59 | if src_dirs is not None: 60 | self.src = files_to_dict(src_dirs) 61 | else: 62 | self.src = None 63 | 64 | # the currently running script 65 | self.src_main = sys.argv[0] 66 | 67 | self.timestamp_start = datetime.now() 68 | self.timestamp_update = datetime.now() 69 | self.last_update = time.time() 70 | 71 | if output_dir is not None: 72 | self.output_dir = output_dir 73 | else: 74 | self.output_dir = BASE_DIR 75 | 76 | server = VIS["server"] 77 | port = VIS["port"] 78 | base_url = VIS["base_url"] 79 | http_proxy_host = VIS["http_proxy_host"] 80 | http_proxy_port = VIS["http_proxy_port"] 81 | self.enabled = VIS["enabled"] 82 | vis_log_file = os.path.join(self.output_dir, f"{self.name}.vis") 83 | 84 | if self.enabled: 85 | self.viz = Visualizer(env=name, 86 | server=server, 87 | port=port, 88 | base_url=base_url, 89 | http_proxy_host=http_proxy_host, 90 | http_proxy_port=http_proxy_port, 91 | log_to_filename=vis_log_file) 92 | 93 | self.add_value("config", "text") 94 | self.update_value("config", dict_to_html(self.config)) 95 | 96 | # connect to MongoDB 97 | if self.use_db and self.enabled: # NOTE(review): db_client, db and db_collection only exist when this branch runs; to_db() assumes they do 98 | if self.db_uri: 99 | self.db_client = MongoClient(self.db_uri) 100 | else: 101 | self.db_client = MongoClient(self.db_host, self.db_port) 102 | 103 | self.db = self.db_client[self.db_name] 104 | self.db_collection = self.db.experiments 105 | self.db_record = None 106 | 107 | ############################################################# 108 | # Metric 109 | ############################################################# 110 | def add_metric(self, key, vis_type, title=None,
tags=None): 111 | """ 112 | Add a new metric to the experiment. 113 | Metrics hold a history of all the inserted values. 114 | The last value(s) will be used for presentation (plotting and console) 115 | Args: 116 | key (str): the name of the value. This will be used for getting 117 | a handle of the metric 118 | vis_type (str): the visualization type 119 | tags (list): list of tags e.g. ["train_set", "val_set"] 120 | title (str): used for presentation purposes (figure, console...) 121 | 122 | Returns: 123 | None 124 | """ 125 | self.metrics[key] = Metric(key, vis_type, tags, title) 126 | 127 | def get_metric(self, key): 128 | """ 129 | Returns a handle to the metric with the given key 130 | Args: 131 | key: name given to add_metric 132 | 133 | Returns: 134 | Metric: the stored metric object 135 | """ 136 | return self.metrics[key] 137 | 138 | def update_metric(self, key, value, tag=None): 139 | """ 140 | Add new value to the given metric 141 | Args: 142 | key: name given to add_metric 143 | value: value appended to the metric's history 144 | tag: tag whose history receives the value (tagged metrics only) 145 | 146 | Returns: 147 | None 148 | """ 149 | self.get_metric(key).add(value, tag) 150 | 151 | try: 152 | if self.enabled: 153 | self.__plot_metric(key) 154 | 155 | except IndexError as e: # presumably raised while a tagged series is still empty (values[tag][-1]); ignored 156 | pass 157 | 158 | except Exception as e: # NOTE(review): any other plotting failure is only printed, without the exception itself 159 | print(f"An error occurred while trying to plot metric:{key}") 160 | 161 | def __plot_metric(self, key): # plot the newest point of every series; only "line" metrics are supported 162 | 163 | metric = self.get_metric(key) 164 | 165 | if metric.vis_type == "line": 166 | 167 | if metric.tags is not None: 168 | x = [[len(metric.values[tag])] for tag in metric.tags] 169 | y = [[metric.values[tag][-1]] for tag in metric.tags] 170 | else: 171 | x = [len(metric.values)] 172 | y = [metric.values[-1]] 173 | self.viz.plot_line(y, x, metric.title, metric.tags) 174 | 175 | elif metric.vis_type == "scatter": 176 | raise NotImplementedError 177 | elif metric.vis_type == "bar": 178 | raise NotImplementedError 179 | else: 180 | raise NotImplementedError 181 | 182 | ############################################################# 183 | # Value 184 |
############################################################# 185 | def add_value(self, key, vis_type, title=None, tags=None, init=None): # NOTE(review): init is accepted but unused 186 | self.values[key] = Value(key, vis_type, tags, title) 187 | 188 | def get_value(self, key): 189 | return self.values[key] 190 | 191 | def update_value(self, key, value, tag=None): 192 | """ 193 | Update the state of the given value 194 | Args: 195 | key: name given to add_value 196 | value: new state; replaces the previous one 197 | tag: tag whose state is replaced (tagged values only) 198 | 199 | Returns: 200 | None 201 | """ 202 | self.get_value(key).update(value, tag) 203 | 204 | try: 205 | if self.enabled: 206 | self.__plot_value(key) 207 | 208 | except IndexError as e: 209 | pass 210 | 211 | except Exception as e: 212 | print(f"An error occurred while trying to plot value:{key}") 213 | 214 | def __plot_value(self, key): # for scatter/heatmap/bar the state is a (data, labels) pair; tagged ones raise NotImplementedError 215 | value = self.get_value(key) 216 | 217 | if value.vis_type == "text": 218 | self.viz.plot_text(value.value, value.title, pre=value.pre) 219 | elif value.vis_type == "scatter": 220 | if value.tags is not None: 221 | raise NotImplementedError 222 | else: 223 | data = value.value 224 | 225 | self.viz.plot_scatter(data[0], data[1], value.title) 226 | elif value.vis_type == "heatmap": 227 | if value.tags is not None: 228 | raise NotImplementedError 229 | else: 230 | data = value.value 231 | 232 | self.viz.plot_heatmap(data[0], data[1], value.title) 233 | elif value.vis_type == "bar": 234 | if value.tags is not None: 235 | raise NotImplementedError 236 | else: 237 | data = value.value 238 | 239 | self.viz.plot_bar(data[0], data[1], value.title) 240 | else: 241 | raise NotImplementedError 242 | 243 | ############################################################# 244 | # Persistence 245 | ############################################################# 246 | def _state_dict(self): # instance attributes minus the MongoDB handles; NOTE(review): viz is kept, which may make pickling fail when visualization is enabled 247 | omit = ["db", "db_client", "db_collection"] 248 | state = {k: v for k, v in self.__dict__.items() if k not in omit} 249 | 250 | return state 251 | 252 | def to_db(self): 253 | self.timestamp_update = datetime.now() 254 | # record =
self._state_dict() 255 | 256 | # todo: avoid this workaround 257 | record = json.loads(self._serialize()) 258 | 259 | if self.db_record is None: 260 | self.db_record = self.db_collection.insert(record) # NOTE(review): Collection.insert was deprecated in PyMongo 3 and removed in 4; insert_one(record).inserted_id replaces it 261 | else: 262 | self.db_collection.replace_one({"_id": self.db_record}, record) 263 | 264 | def _serialize(self): 265 | 266 | data = json.dumps(self._state_dict(), 267 | default=lambda o: getattr(o, '__dict__', str(o))) # non-JSON objects are dumped via their __dict__, else str() 268 | return data 269 | 270 | def to_json(self): 271 | self.timestamp_update = datetime.now() 272 | name = self.name + "_{}.json".format(self.get_timestamp()) 273 | filename = os.path.join(self.output_dir, name) 274 | with open(filename, 'w', encoding='utf-8') as f: 275 | f.write(self._serialize()) 276 | 277 | def get_timestamp(self): 278 | return self.timestamp_start.strftime("%y-%m-%d_%H:%M:%S") # NOTE(review): used in output file names, and ':' is not allowed in Windows file names 279 | 280 | def to_pickle(self): 281 | self.timestamp_update = datetime.now() 282 | name = self.name + "_{}.pickle".format(self.get_timestamp()) 283 | filename = os.path.join(self.output_dir, name) 284 | with open(filename, 'wb') as f: 285 | pickle.dump(self._state_dict(), f) 286 | 287 | def save(self): 288 | try: 289 | self.to_pickle() 290 | except: # NOTE(review): bare except; also swallows KeyboardInterrupt and hides the error 291 | print("Failed to save to pickle...") 292 | 293 | # try: 294 | # self.to_json() 295 | # except: 296 | # print("Failed to save to json...") 297 | 298 | # try: 299 | # self.to_db() 300 | # except: 301 | # print("Failed to save to db...") 302 | 303 | def log_metrics(self, keys, epoch): # tabulate the latest value of each metric, using the tags of the first one; epoch is unused 304 | 305 | _metrics = [self.metrics[key] for key in keys] 306 | _tags = _metrics[0].tags 307 | if _tags is not None: 308 | values = [[tag] + [metric.values[tag][-1] for metric in _metrics] for tag in _tags] 309 | headers = ["TAG"] + [metric.title.upper() for metric in _metrics] 310 | else: 311 | values = [[metric.values[-1] for metric in _metrics]] 312 | headers = [metric.title.upper() for metric in _metrics] 313 | 314 | log_output = tabulate(values, headers, floatfmt=".4f") 315 | 316 | return log_output 317 | 318 | 319 |
class Metric(object): 320 | """ 321 | Metric hold the data of a value of the model that is being monitored 322 | 323 | A Metric object has to have a name, 324 | a vis_type which defines how it will be visualized 325 | and a dataset on which it will be attached to. 326 | """ 327 | 328 | def __init__(self, key, vis_type, tags=None, title=None): 329 | """ 330 | 331 | Args: 332 | key (str): the name of the metric 333 | vis_type (str): the visualization type 334 | tags (list): list of tags 335 | title (str): used for presentation purposes (figure, console...) 336 | """ 337 | self.key = key 338 | self.title = title 339 | self.vis_type = vis_type 340 | self.tags = tags 341 | 342 | assert vis_type in ["line"] 343 | 344 | if tags is not None: 345 | self.values = {tag: [] for tag in tags} 346 | else: 347 | self.values = [] 348 | 349 | if title is None: 350 | self.title = key 351 | 352 | def add(self, value, tag=None): 353 | """ 354 | Add a value to the list of values of this metric 355 | Args: 356 | value (int, float): 357 | tag (str): 358 | 359 | Returns: 360 | 361 | """ 362 | if self.tags is not None: 363 | self.values[tag].append(value) 364 | else: 365 | self.values.append(value) 366 | 367 | 368 | class Value(object): 369 | """ 370 | 371 | """ 372 | 373 | def __init__(self, key, vis_type, tags=None, title=None, pre=True): 374 | """ 375 | 376 | Args: 377 | key (str): the name of the value 378 | vis_type (str): the visualization type 379 | tags (list): list of tags 380 | title (str): used for presentation purposes (figure, console...) 
381 | """ 382 | self.key = key 383 | self.title = title 384 | self.vis_type = vis_type 385 | self.tags = tags 386 | self.pre = pre 387 | 388 | assert vis_type in ["text", "scatter", "bar", "heatmap"] 389 | 390 | if tags is not None: 391 | self.value = {tag: [] for tag in tags} 392 | else: 393 | self.value = [] 394 | 395 | if title is None: 396 | self.title = key 397 | 398 | def update(self, value, tag=None): 399 | """ 400 | Update the value 401 | Args: 402 | value (int, float): 403 | tag (str): 404 | 405 | Returns: 406 | 407 | """ 408 | if self.tags is not None: 409 | self.value[tag] = value 410 | else: 411 | self.value = value 412 | -------------------------------------------------------------------------------- /mylogger/helpers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | from glob2 import glob 7 | 8 | 9 | def files_to_dict(dirs, safe=True): 10 | file_data = defaultdict(dict) 11 | 12 | for dir in dirs: 13 | for file in glob(os.path.join(dir + "/*.py")): 14 | _dir = os.path.split(dir)[1] 15 | filename = os.path.basename(file) 16 | if safe: 17 | filename = filename.replace('.', '[dot]') 18 | file_data[_dir][filename] = Path(file).read_text() 19 | 20 | return file_data 21 | 22 | 23 | def dict_to_html(config): 24 | indent = 2 25 | msg = json.dumps(config, indent=indent) 26 | msg = "\n".join([line[2:].rstrip() for line in msg.split("\n") 27 | if len(line.strip()) > 3]) 28 | # format with html 29 | msg = msg.replace('{', '') 30 | msg = msg.replace('}', '') 31 | # msg = msg.replace('\n', '
') 32 | return msg 33 | -------------------------------------------------------------------------------- /mylogger/inspection.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from visdom import Visdom 3 | 4 | from mylogger.plotting import plot_line 5 | 6 | 7 | class Inspector(object): 8 | """ 9 | Class for inspecting the internals of neural networks 10 | """ 11 | 12 | def __init__(self, model, stats): 13 | """ 14 | 15 | Args: 16 | model (torch.nn.Module): the PyTorch model 17 | stats (list): list of stats names. e.g. ["std", "mean"] 18 | """ 19 | 20 | # watch only trainable layers 21 | self.watched_layers = {} 22 | for name, module in self.get_watched_modules(model): 23 | self.watched_layers[name] = {stat: [] for stat in stats} 24 | 25 | self.viz = Visdom() 26 | self.update_state(model) 27 | 28 | def get_watched_modules(self, model): 29 | all_modules = [] 30 | for name, module in model.named_modules(): 31 | if len(list(module.parameters())) > 0 and all( 32 | param.requires_grad for param in module.parameters()): 33 | all_modules.append((name, module)) 34 | 35 | # filter parent nodes 36 | fitered_modules = [] 37 | for name, module in all_modules: 38 | if not any( 39 | [(name in n and name is not n) for n, m in all_modules]): 40 | fitered_modules.append((name, module)) 41 | 42 | return fitered_modules 43 | 44 | def plot_layer(self, name, weights): 45 | self.viz.histogram(X=weights, 46 | win=name, 47 | opts=dict(title="{} weights dist".format(name), 48 | numbins=40)) 49 | for stat_name, stat_val in self.watched_layers[name].items(): 50 | stat_val.append(getattr(numpy, stat_name)(weights)) 51 | 52 | plot_name = "{}-{}".format(name, stat_name) 53 | plot_line(self.viz, numpy.array(stat_val), plot_name, [plot_name]) 54 | 55 | def update_state(self, model): 56 | gen = (child for child in model.named_modules() 57 | if child[0] in self.watched_layers) 58 | for name, layer in gen: 59 | weights = 
[param.data.cpu().numpy() for param in 60 | layer.parameters()] 61 | if len(weights) > 0: 62 | weights = numpy.concatenate([w.ravel() for w in weights]) 63 | self.plot_layer(name, weights) 64 | -------------------------------------------------------------------------------- /mylogger/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from visdom import Visdom 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | 6 | 7 | class Visualizer: 8 | 9 | def __init__(self, env="main", 10 | server="http://localhost", 11 | port=8097, 12 | base_url="/", 13 | http_proxy_host=None, 14 | http_proxy_port=None, 15 | log_to_filename=None): 16 | self._viz = Visdom(env=env, 17 | server=server, 18 | port=port, 19 | http_proxy_host=http_proxy_host, 20 | http_proxy_port=http_proxy_port, 21 | log_to_filename=log_to_filename, 22 | use_incoming_socket=False) 23 | self._viz.close(env=env) 24 | 25 | def plot_line(self, values, steps, name, legend=None): 26 | if legend is None: 27 | opts = dict(title=name) 28 | else: 29 | opts = dict(title=name, legend=legend) 30 | 31 | self._viz.line( 32 | X=numpy.column_stack(steps), 33 | Y=numpy.column_stack(values), 34 | win=name, 35 | update='append', 36 | opts=opts 37 | ) 38 | 39 | def plot_text(self, text, title, pre=True): 40 | _width = max([len(x) for x in text.split("\n")]) * 10 41 | _heigth = len(text.split("\n")) * 20 42 | _heigth = max(_heigth, 120) 43 | if pre: 44 | text = "
{}
".format(text) 45 | 46 | self._viz.text(text, win=title, opts=dict(title=title, 47 | width=min(_width, 400), 48 | height=min(_heigth, 400))) 49 | 50 | def plot_bar(self, data, labels, title): 51 | self._viz.bar(win=title, X=data, 52 | opts=dict(legend=labels, stacked=False, title=title)) 53 | 54 | def plot_scatter(self, data, labels, title): 55 | X = numpy.concatenate(data, axis=0) 56 | Y = numpy.concatenate([numpy.full(len(d), i) 57 | for i, d in enumerate(data, 1)], axis=0) 58 | self._viz.scatter(win=title, X=X, Y=Y, 59 | opts=dict(legend=labels, title=title, 60 | markersize=5, 61 | webgl=True, 62 | width=400, 63 | height=400, 64 | markeropacity=0.5)) 65 | 66 | def plot_heatmap(self, data, labels, title): 67 | self._viz.heatmap(win=title, 68 | X=data, 69 | opts=dict( 70 | title=title, 71 | columnnames=labels[1], 72 | rownames=labels[0], 73 | width=700, 74 | height=700, 75 | layoutopts={'plotly': { 76 | 'xaxis': { 77 | 'side': 'top', 78 | 'tickangle': -60, 79 | # 'autorange': "reversed" 80 | }, 81 | 'yaxis': { 82 | 'autorange': "reversed" 83 | }, 84 | } 85 | } 86 | )) 87 | -------------------------------------------------------------------------------- /rouge-test.py: -------------------------------------------------------------------------------- 1 | import files2rouge 2 | import chardet 3 | import codecs 4 | import os 5 | 6 | dec_path = "" 7 | ref_path = "" 8 | result_id_path = "" 9 | 10 | 11 | def tokens_to_ids(token_list1, token_list2): 12 | ids = {} 13 | out1 = [] 14 | out2 = [] 15 | for token in token_list1: 16 | out1.append(ids.setdefault(token, len(ids))) 17 | for token in token_list2: 18 | out2.append(ids.setdefault(token, len(ids))) 19 | 20 | return out1, out2 21 | 22 | def write_id(id_lst, file): 23 | for id in id_lst: 24 | file.write(str(id)+" ") 25 | file.write("\n") 26 | 27 | def trans_id(dec_path, ref_path): 28 | dec_files = codecs.open(dec_path, encoding="utf-8").read().split("\n") 29 | ref_files = codecs.open(ref_path, 
encoding="utf-8").read().split("\n") 30 | 31 | dec_files_id = codecs.open(os.path.join(result_id_path, "decode_id_tmp.txt"), 'a') 32 | ref_files_id = codecs.open(os.path.join(result_id_path, "reference_id_tmp.txt"), 'a') 33 | 34 | sample_num = len(dec_files) 35 | for index in range(sample_num): 36 | dec_file = dec_files[index].split(" ") 37 | ref_file = ref_files[index].split(" ") 38 | dec_id, ref_id = tokens_to_ids(dec_file, ref_file) 39 | write_id(dec_id, dec_files_id) 40 | write_id(ref_id, ref_files_id) 41 | 42 | 43 | #trans_id(dec_path, ref_path) 44 | #files2rouge.run(os.path.join(result_id_path, "decode_id_tmp.txt"), 45 | #os.path.join(result_id_path, "reference_id_tmp.txt"), 46 | #os.path.join(result_id_path, "results.txt")) 47 | 48 | files2rouge.run(ref_path, dec_path) -------------------------------------------------------------------------------- /sys_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | print("torch:", torch.__version__) 6 | print("Cuda:", torch.backends.cudnn.cuda) 7 | print("CuDNN:", torch.backends.cudnn.version()) 8 | 9 | CPU_CORES = 4 10 | RANDOM_SEED = 1618 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | MODEL_CNF_DIR = os.path.join(BASE_DIR, "model_configs") 15 | 16 | TRAINED_PATH = os.path.join(BASE_DIR, "checkpoints") 17 | 18 | EMBS_PATH = os.path.join(BASE_DIR, "embeddings") 19 | 20 | DATA_DIR = os.path.join(BASE_DIR, 'datasets') 21 | 22 | EXP_DIR = os.path.join(BASE_DIR, 'experiments') 23 | 24 | MODEL_DIRS = ["models", "modules", "utils"] 25 | 26 | VIS = { 27 | "server": "http://localhost", 28 | "enabled": False, 29 | "port": 8097, 30 | "base_url": "/", 31 | "http_proxy_host": None, 32 | "http_proxy_port": None, 33 | "log_to_filename": os.path.join(BASE_DIR, "vis_logger.json") 34 | } 35 | -------------------------------------------------------------------------------- /utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.insert(0, '.') # NOTE(review): these entries are resolved against the current working directory, not this package 4 | sys.path.insert(0, '..') 5 | sys.path.insert(0, '../../') 6 | sys.path.insert(0, '../../../') 7 | -------------------------------------------------------------------------------- /utils/_logging.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import time 4 | 5 | from tabulate import tabulate 6 | 7 | 8 | def erase_line(): 9 | sys.stdout.write("\033[K") # ANSI escape: clear from the cursor to the end of the line 10 | 11 | 12 | def asMinutes(s): # format a duration in seconds as "Xm Ys" 13 | m = math.floor(s / 60) 14 | s -= m * 60 15 | return '%dm %ds' % (m, s) 16 | 17 | 18 | def timeSince(since, percent): # (elapsed, estimated remaining) since the start time, given the fraction done; percent == 0 raises ZeroDivisionError 19 | now = time.time() 20 | s = now - since 21 | es = s / percent 22 | rs = es - s 23 | return asMinutes(s), asMinutes(rs) 24 | 25 | 26 | def log_seq3_losses(L1_LM, L1_AE, L2_LM, L2_AE, 27 | L1_LMD, L1_TRANSD, L2_LMD, L2_TRANSD, 28 | L1_TRANSG, L2_TRANSG): 29 | losses = [] # NOTE(review): math.exp below raises OverflowError for losses above about 709 30 | losses.append(["L1", L1_LM, L1_AE, math.exp(L1_LM), math.exp(L1_AE), 31 | L1_LMD, L1_TRANSD, L1_TRANSG]) 32 | losses.append(["L2", L2_LM, L2_AE, math.exp(L2_LM), math.exp(L2_AE), 33 | L2_LMD, L2_TRANSD, L2_TRANSG]) 34 | return tabulate(losses, 35 | headers=['Lang', 'LM Loss', 'AE Loss', 'LM PPL', 'AE PPL', 36 | 'LM-D Loss', 'TRANS-D Loss', 'TRANS-G Loss'], 37 | floatfmt=".4f") 38 | 39 | 40 | def progress_bar(percentage, bar_len=20): # text bar such as "[=====-----]"; percentage is a 0..1 fraction 41 | filled_len = int(round(bar_len * percentage)) 42 | bar = '=' * filled_len + '-' * (bar_len - filled_len) 43 | return "[{}]".format(bar) 44 | 45 | 46 | def epoch_progress(epoch, batch, batch_size, dataset_size, start): # write and return a one-line progress report; batch == 0 makes timeSince divide by zero 47 | n_batches = math.ceil(float(dataset_size) / batch_size) 48 | percentage = batch / n_batches 49 | 50 | # stats = 'Epoch:{}, Batch:{}/{} ({0:.2f}%)'.format(epoch, batch, n_batches, 51 | # percentage) 52 | stats = f'Epoch:{epoch}, Batch:{batch}/{n_batches} ' \ 53 | f'({100* percentage:.0f}%)' 54 | # stats = f'Epoch:{epoch}, Batch:{batch}
({100* percentage:.0f}%)' 55 | 56 | elapsed, eta = timeSince(start, batch / n_batches) 57 | time_info = 'Time: {} (-{})'.format(elapsed, eta) 58 | 59 | # clean every line and then add the text output 60 | # log_output = stats + " " + progress_bar + ", " + time_info 61 | 62 | # log_output = " ".join([stats, time_info]) 63 | log_output = " ".join([stats, progress_bar(percentage), time_info]) 64 | 65 | sys.stdout.write("\r \r\033[K" + log_output) # carriage return + erase-line so the report overwrites the previous one 66 | sys.stdout.flush() 67 | return log_output 68 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import yaml 4 | 5 | from sys_config import DATA_DIR 6 | 7 | 8 | def get_parser(): 9 | """Get parser object with a required -cfg/--config FILE argument.""" 10 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 11 | parser = ArgumentParser(description=__doc__, # NOTE(review): this module has no docstring, so the description is None 12 | formatter_class=ArgumentDefaultsHelpFormatter) 13 | parser.add_argument("-cfg", "--config", 14 | dest="cfg", 15 | help="experiment definition file", 16 | metavar="FILE", 17 | required=True) 18 | return parser 19 | 20 | 21 | def make_paths(cfg): 22 | """ 23 | Make all values for keys ending with `_path` absolute by joining them onto DATA_DIR; nested dicts are processed recursively, and cfg is modified in place and also returned.
24 | """ 25 | for key in cfg.keys(): 26 | if key.endswith("_path"): 27 | if cfg[key] is not None: 28 | cfg[key] = os.path.join(DATA_DIR, cfg[key]) # os.path.join keeps an already-absolute value unchanged 29 | cfg[key] = os.path.abspath(cfg[key]) 30 | if type(cfg[key]) is dict: # recurse into nested dicts 31 | cfg[key] = make_paths(cfg[key]) 32 | return cfg 33 | 34 | 35 | def load_config(file): # read a YAML config file and absolutize its _path entries 36 | with open(file, 'r') as stream: 37 | cfg = yaml.load(stream) # NOTE(review): yaml.load without Loader= warns on PyYAML 5.1+ and raises TypeError on 6.0+; yaml.safe_load is the usual drop-in for config files 38 | cfg = make_paths(cfg) 39 | 40 | return cfg 41 | -------------------------------------------------------------------------------- /utils/data_parsing.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import random 5 | from collections import Counter, defaultdict 6 | from matplotlib import pyplot as plt 7 | 8 | import numpy 9 | from glob2 import glob 10 | from sklearn.model_selection import train_test_split 11 | from sklearn.preprocessing import LabelBinarizer 12 | 13 | from sys_config import DATA_DIR 14 | 15 | 16 | def read_amazon(file): # parse a JSON-lines file into (reviewText, summary, overall) lists 17 | reviews = [] 18 | summaries = [] 19 | labels = [] 20 | 21 | with open(file) as f: 22 | for line in f: 23 | entry = json.loads(line) 24 | reviews.append(entry["reviewText"]) 25 | summaries.append(entry["summary"]) 26 | labels.append(entry["overall"]) 27 | 28 | return reviews, summaries, labels 29 | 30 | 31 | def read_semeval(): 32 | def read_dataset(d): 33 | with open(os.path.join(DATA_DIR, "semeval", "E-c", 34 | "E-c-En-{}.txt".format(d))) as f: 35 | reader = csv.reader(f, delimiter='\t') 36 | labels = next(reader)[2:] # header row; the label names are not used afterwards 37 | 38 | _X = [] 39 | _y = [] 40 | for row in reader: 41 | _X.append(row[1]) 42 | _y.append([int(x) for x in row[2:]]) 43 | return _X, _y 44 | 45 | X_train, y_train = read_dataset("train") 46 | X_dev, y_dev = read_dataset("dev") 47 | X_test, y_test = read_dataset("test") 48 | 49 | X_train = X_train + X_test # NOTE(review): the test split is merged into training; only dev is held out 50 | y_train = y_train + y_test 51 | 52 | return X_train, numpy.array(y_train), X_dev, numpy.array(y_dev) 53 | 54 | 55 | def imdb_get_index(): # map keys like "train_pos" or "test_unsup" to the .txt files under DATA_DIR/imdb/split/label 56 |
index = defaultdict(list) 57 | 58 | dirs = ["pos", "neg", "unsup"] 59 | sets = ["train", "test"] 60 | 61 | for s in sets: 62 | for d in dirs: 63 | for file in glob(os.path.join(DATA_DIR, "imdb", s, d) + "/*.txt"): 64 | index["_".join([s, d])].append(file) 65 | return index 66 | 67 | 68 | def get_imdb(): # every IMDB review text, across all splits and labels 69 | index = imdb_get_index() 70 | 71 | data = [] 72 | 73 | for ki, vi in index.items(): 74 | for f in vi: # NOTE(review): the file opened below is never explicitly closed 75 | data.append(" ".join(open(f).readlines()).replace('
', '')) 76 | 77 | return data 78 | 79 | 80 | def read_emoji(split=0.1, min_freq=100, max_ex=1000000, top_n=None): # load tab-separated emoji/text pairs, keep frequent labels, down-sample, split and one-hot encode the labels 81 | X = [] 82 | y = [] 83 | with open(os.path.join(DATA_DIR, "emoji", "emoji_1m.txt")) as f: 84 | for i, line in enumerate(f): 85 | if i > max_ex: # NOTE(review): lines 0..max_ex are kept, i.e. max_ex + 1 examples 86 | break 87 | emoji, text = line.rstrip().split("\t") 88 | X.append(text) 89 | y.append(emoji) 90 | 91 | counter = Counter(y) 92 | top = set(l for l, f in counter.most_common(top_n) if f > min_freq) # labels among the top_n most common that occur more than min_freq times 93 | 94 | data = [(_x, _y) for _x, _y in zip(X, y) if _y in top] 95 | 96 | total = len(data) 97 | 98 | data = [(_x, _y) for _x, _y in data if 99 | random.random() > counter[_y] / total] # drop each example with probability counter[label] / total, down-sampling frequent labels 100 | 101 | X = [x[0] for x in data] 102 | y = [x[1] for x in data] 103 | 104 | X_train, X_test, y_train, y_test = train_test_split(X, y, 105 | test_size=split, 106 | stratify=y, 107 | random_state=0) 108 | 109 | lb = LabelBinarizer() 110 | lb.fit(y_train) 111 | y_train = lb.transform(y_train) 112 | y_test = lb.transform(y_test) 113 | 114 | return X_train, y_train, X_test, y_test 115 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import pandas 2 | 3 | import rouge 4 | import codecs 5 | import os 6 | import files2rouge 7 | from tabulate import tabulate 8 | 9 | def rouge_lists(refs, hyps): # ROUGE-1/2/L scores averaged over paired hyps/refs (rouge.Rouge with apply_avg=True) 10 | evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], 11 | max_n=2, 12 | limit_length=True, 13 | length_limit=100, 14 | length_limit_type='words', 15 | apply_avg=True, 16 | apply_best=False, 17 | alpha=0.5, # Default F1_score 18 | weight_factor=1.2, 19 | stemming=True) 20 | scores = evaluator.get_scores(hyps, refs) 21 | 22 | return scores 23 | 24 | 25 | def tokens_to_ids(token_list1, token_list2): # map tokens of both lists to shared integer ids, numbered by first appearance 26 | ids = {} 27 | out1 = [] 28 | out2 = [] 29 | for token in token_list1: 30 | out1.append(ids.setdefault(token, len(ids))) 31 | for token in token_list2: 32 | out2.append(ids.setdefault(token, len(ids))) 33 |
return out1, out2 34 | 35 | 36 | def write_id(id_lst, file): # write the ids as one line, each followed by a space 37 | for id in id_lst: 38 | file.write(str(id)+" ") 39 | file.write("\n") 40 | 41 | 42 | def filter_file(file_lst): # drop one trailing empty entry, mutating the list in place; NOTE(review): an empty list raises IndexError 43 | if "" == file_lst[-1]: 44 | del(file_lst[-1]) 45 | return file_lst 46 | 47 | 48 | # translate words to ids to escape from (e.g. Chinese) characters in ROUGE scoring 49 | def trans_id(path, hyps, ref_path): 50 | #dec_files = codecs.open(dec_path, encoding="utf-8").read().split("\n") 51 | dec_files = hyps 52 | ref_files = codecs.open(ref_path, encoding="utf-8").read().split("\n") # NOTE(review): this handle is never closed 53 | dec_files = filter_file(dec_files) # also strips the trailing "" from the caller's hyps list, since dec_files is hyps 54 | ref_files = filter_file(ref_files) 55 | 56 | dec_files_id = codecs.open(os.path.join(path, "decode_id_tmp.txt"), 'w') # NOTE(review): neither id file is closed or flushed before files2rouge.run reads them below, so the scorer may see truncated files 57 | ref_files_id = codecs.open(os.path.join(path, "reference_id_tmp.txt"), 'w') 58 | 59 | results_path = os.path.join(path, "results.txt") # NOTE(review): unused 60 | 61 | sample_num = len(dec_files) 62 | for index in range(sample_num): # NOTE(review): assumes ref_files has at least as many lines as dec_files 63 | dec_file = dec_files[index].split(" ") 64 | ref_file = ref_files[index].split(" ") 65 | dec_id, ref_id = tokens_to_ids(dec_file, ref_file) 66 | write_id(dec_id, dec_files_id) 67 | write_id(ref_id, ref_files_id) 68 | 69 | scores_str = files2rouge.run(os.path.join(path, "decode_id_tmp.txt"), os.path.join(path, "reference_id_tmp.txt")) 70 | return scores_str 71 | 72 | 73 | def rouge_files(path, refs_file, hyps): # score hyps against refs_file and parse R/P/F of ROUGE-1/2/L from the report text 74 | #refs = open(refs_file).readlines() 75 | #hyps = open(hyps_file).readlines() 76 | #scores = rouge_lists(refs, hyps) 77 | scores_str = trans_id(path, hyps, refs_file) 78 | result_file = codecs.open(os.path.join(path, "result.txt"), 'a') # NOTE(review): appended to but never closed 79 | result_file.write(scores_str) 80 | 81 | r1_r = scores_str[scores_str.find("ROUGE-1 Average_R:")+19:scores_str.find("ROUGE-1 Average_R:")+26] # fixed 7-character slice after each label; assumes the files2rouge report layout 82 | r2_r = scores_str[scores_str.find("ROUGE-2 Average_R:")+19:scores_str.find("ROUGE-2 Average_R:")+26] 83 | rl_r = scores_str[scores_str.find("ROUGE-L Average_R:")+19:scores_str.find("ROUGE-L Average_R:")+26] 84 | 85 | r1_p = scores_str[scores_str.find("ROUGE-1
Average_P:")+19:scores_str.find("ROUGE-1 Average_P:")+26] 86 | r2_p = scores_str[scores_str.find("ROUGE-2 Average_P:")+19:scores_str.find("ROUGE-2 Average_P:")+26] 87 | rl_p = scores_str[scores_str.find("ROUGE-L Average_P:")+19:scores_str.find("ROUGE-L Average_P:")+26] 88 | 89 | r1_f = scores_str[scores_str.find("ROUGE-1 Average_F:")+19:scores_str.find("ROUGE-1 Average_F:")+26] 90 | r2_f = scores_str[scores_str.find("ROUGE-2 Average_F:")+19:scores_str.find("ROUGE-2 Average_F:")+26] 91 | rl_f = scores_str[scores_str.find("ROUGE-L Average_F:")+19:scores_str.find("ROUGE-L Average_F:")+26] 92 | 93 | scores = {} 94 | scores['rouge-1'] = {} 95 | scores['rouge-2'] = {} 96 | scores['rouge-l'] = {} 97 | 98 | scores['rouge-1']['r'] = float(r1_r) 99 | scores['rouge-1']['p'] = float(r1_p) 100 | scores['rouge-1']['f'] = float(r1_f) 101 | 102 | scores['rouge-2']['r'] = float(r2_r) 103 | scores['rouge-2']['p'] = float(r2_p) 104 | scores['rouge-2']['f'] = float(r2_f) 105 | 106 | scores['rouge-l']['r'] = float(rl_r) 107 | scores['rouge-l']['p'] = float(rl_p) 108 | scores['rouge-l']['f'] = float(rl_f) 109 | return scores # {'rouge-1'|'rouge-2'|'rouge-l': {'r'|'p'|'f': float}} 110 | 111 | 112 | def rouge_files_simple(path, refs_file, hyps): 113 | scores_str = trans_id(path, hyps, refs_file) 114 | result_file = codecs.open(os.path.join(path, "result_nsent.txt"), 'a') 115 | result_file.write(scores_str) 116 | 117 | r1_r = scores_str[scores_str.find("ROUGE-1 Average_R:")+19:scores_str.find("ROUGE-1 Average_R:")+26] 118 | r2_r = scores_str[scores_str.find("ROUGE-2 Average_R:")+19:scores_str.find("ROUGE-2 Average_R:")+26] 119 | rl_r = scores_str[scores_str.find("ROUGE-L Average_R:")+19:scores_str.find("ROUGE-L Average_R:")+26] 120 | 121 | r1_p = scores_str[scores_str.find("ROUGE-1 Average_P:")+19:scores_str.find("ROUGE-1 Average_P:")+26] 122 | r2_p = scores_str[scores_str.find("ROUGE-2 Average_P:")+19:scores_str.find("ROUGE-2 Average_P:")+26] 123 | rl_p = scores_str[scores_str.find("ROUGE-L Average_P:")+19:scores_str.find("ROUGE-L
Average_P:")+26] 124 | 125 | r1_f = scores_str[scores_str.find("ROUGE-1 Average_F:")+19:scores_str.find("ROUGE-1 Average_F:")+26] 126 | r2_f = scores_str[scores_str.find("ROUGE-2 Average_F:")+19:scores_str.find("ROUGE-2 Average_F:")+26] 127 | rl_f = scores_str[scores_str.find("ROUGE-L Average_F:")+19:scores_str.find("ROUGE-L Average_F:")+26] 128 | 129 | return r1_f, r2_f, rl_f 130 | 131 | 132 | def rouge_file_list(refs_file, hyps_list): 133 | refs = open(refs_file).readlines() 134 | scores = rouge_lists(refs, hyps_list) 135 | 136 | return scores 137 | 138 | 139 | def pprint_rouge_scores(scores, pivot=False): 140 | pdt = pandas.DataFrame(scores) 141 | 142 | if pivot: 143 | pdt = pdt.T 144 | 145 | table = tabulate(pdt, 146 | headers='keys', 147 | floatfmt=".4f", tablefmt="psql") 148 | 149 | return table 150 | -------------------------------------------------------------------------------- /utils/generic.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | 3 | import numpy 4 | import umap 5 | from sklearn.decomposition import PCA 6 | 7 | 8 | def merge_dicts(a, b): 9 | a.update({k: v for k, v in b.items() if k in a}) 10 | return a 11 | 12 | 13 | def number_h(num): 14 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 15 | if abs(num) < 1000.0: 16 | return "%3.1f%s" % (num, unit) 17 | num /= 1000.0 18 | return "%.1f%s" % (num, 'Yi') 19 | 20 | 21 | def group(lst, n): 22 | """group([0,3,4,10,2,3], 2) => [(0,3), (4,10), (2,3)] 23 | 24 | Group a list into consecutive n-tuples. Incomplete tuples are 25 | discarded e.g. 
26 | 27 | >>> group(range(10), 3) 28 | [(0, 1, 2), (3, 4, 5), (6, 7, 8)] 29 | """ 30 | return zip(*[lst[i::n] for i in range(n)]) 31 | 32 | 33 | def pairwise(iterable): 34 | it = iter(iterable) 35 | a = next(it, None) 36 | 37 | for b in it: 38 | yield (a, b) 39 | a = b 40 | 41 | 42 | def concat_multiline_strings(a, b): 43 | str = [] 44 | for line1, line2 in zip_longest(a.split("\n"), b.split("\n"), 45 | fillvalue=''): 46 | str.append("\t".join([line1, line2])) 47 | 48 | return "\n".join(str) 49 | 50 | 51 | def dim_reduce(data_sets, n_components=2, method="PCA"): 52 | data = numpy.vstack(data_sets) 53 | splits = numpy.cumsum([0] + [len(x) for x in data_sets]) 54 | if method == "PCA": 55 | reducer = PCA(random_state=20, n_components=n_components) 56 | embedding = reducer.fit_transform(data) 57 | elif method == "UMAP": 58 | reducer = umap.UMAP(random_state=20, 59 | n_components=n_components, 60 | min_dist=0.5) 61 | embedding = reducer.fit_transform(data) 62 | else: 63 | reducer_linear = PCA(random_state=20, n_components=50) 64 | linear_embedding = reducer_linear.fit_transform(data) 65 | reducer_nonlinear = umap.UMAP(random_state=20, 66 | n_components=n_components, 67 | min_dist=0.5) 68 | embedding = reducer_nonlinear.fit_transform(linear_embedding) 69 | 70 | return [embedding[start:stop] for start, stop in pairwise(splits)] 71 | -------------------------------------------------------------------------------- /utils/load_embeddings.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import pickle 4 | 5 | import numpy 6 | 7 | 8 | def file_cache_name(file): 9 | head, tail = os.path.split(file) 10 | filename, ext = os.path.splitext(tail) 11 | return os.path.join(head, filename + ".p") 12 | 13 | 14 | def write_cache_word_vectors(file, data): 15 | with open(file_cache_name(file), 'wb') as pickle_file: 16 | pickle.dump(data, pickle_file) 17 | 18 | 19 | def load_cache_word_vectors(file): 20 | with 
open(file_cache_name(file), 'rb') as f: 21 | return pickle.load(f) 22 | 23 | 24 | def load_word_vectors(file, dim): 25 | """ 26 | Read the word vectors from a text file 27 | Args: 28 | file (): the filename 29 | dim (): the dimensions of the word vectors 30 | 31 | Returns: 32 | word2idx (dict): dictionary of words to ids 33 | idx2word (dict): dictionary of ids to words 34 | embeddings (numpy.ndarray): the word embeddings matrix 35 | 36 | """ 37 | # in order to avoid this time consuming operation, cache the results 38 | try: 39 | cache = load_cache_word_vectors(file) 40 | print("Loaded word embeddings from cache.") 41 | return cache 42 | except OSError: 43 | print("Didn't find embeddings cache file {}".format(file)) 44 | 45 | # create the necessary dictionaries and the word embeddings matrix 46 | if os.path.exists(file): 47 | print('Indexing file {} ...'.format(file)) 48 | 49 | word2idx = {} # dictionary of words to ids 50 | idx2word = {} # dictionary of ids to words 51 | embeddings = [] # the word embeddings matrix 52 | 53 | # create the 2D array, which will be used for initializing 54 | # the Embedding layer of a NN. 55 | # We reserve the first row (idx=0), as the word embedding, 56 | # which will be used for zero padding (word with id = 0). 
57 | embeddings.append(numpy.zeros(dim)) 58 | 59 | # flag indicating whether the first row of the embeddings file 60 | # has a header 61 | header = False 62 | 63 | # read file, line by line 64 | with open(file, "r", encoding="utf-8") as f: 65 | for i, line in enumerate(f, 1): 66 | 67 | # skip the first row if it is a header 68 | if i == 1: 69 | if len(line.split()) < dim: 70 | header = True 71 | continue 72 | 73 | values = line.split(" ") 74 | word = values[0] 75 | vector = numpy.asarray(values[1:], dtype='float32') 76 | 77 | index = i - 1 if header else i 78 | 79 | idx2word[index] = word 80 | word2idx[word] = index 81 | embeddings.append(vector) 82 | 83 | # add an unk token, for OOV words 84 | if "" not in word2idx: 85 | idx2word[len(idx2word) + 1] = "" 86 | word2idx[""] = len(word2idx) + 1 87 | embeddings.append( 88 | numpy.random.uniform(low=-0.05, high=0.05, size=dim)) 89 | 90 | print(set([len(x) for x in embeddings])) 91 | 92 | print('Found %s word vectors.' % len(embeddings)) 93 | embeddings = numpy.array(embeddings, dtype='float32') 94 | 95 | # write the data to a cache file 96 | write_cache_word_vectors(file, (word2idx, idx2word, embeddings)) 97 | 98 | return word2idx, idx2word, embeddings 99 | 100 | else: 101 | print("{} not found!".format(file)) 102 | raise OSError(errno.ENOENT, os.strerror(errno.ENOENT), file) 103 | -------------------------------------------------------------------------------- /utils/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import signal 4 | import subprocess 5 | import sys 6 | 7 | import torch 8 | 9 | from sys_config import BASE_DIR 10 | from utils.config import load_config 11 | 12 | 13 | def spawn_visdom(): 14 | try: 15 | subprocess.run(["visdom > visdom.txt 2>&1 &"], shell=True) 16 | except: 17 | print("Visdom is already running...") 18 | 19 | def signal_handler(signal, frame): 20 | subprocess.run(["pkill visdom"], shell=True) 21 | print("Killing 
Visdom server...") 22 | sys.exit(0) 23 | 24 | signal.signal(signal.SIGINT, signal_handler) 25 | 26 | 27 | def train_options(): 28 | print(os.getcwd()) 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--config', default="model_configs/lm_prior.yaml") 31 | parser.add_argument('--name', default="lm_3w") 32 | parser.add_argument('--desc') 33 | parser.add_argument('--resume') 34 | parser.add_argument('--transfer') 35 | parser.add_argument('--visdom', action='store_true') 36 | parser.add_argument('--vocab') 37 | parser.add_argument('--cp-vocab') 38 | parser.add_argument('--device', default="auto") 39 | parser.add_argument('--cores', type=int, default=1) 40 | parser.add_argument('--source', nargs='*', 41 | default=["models", "modules", "utils"]) 42 | 43 | args = parser.parse_args() 44 | config = load_config(args.config) 45 | 46 | if args.name is None: 47 | config_filename = os.path.basename(args.config) 48 | args.name = os.path.splitext(config_filename)[0] 49 | 50 | config["name"] = args.name 51 | config["desc"] = args.desc 52 | 53 | if args.device == "auto": 54 | args.device = torch.device("cuda" if torch.cuda.is_available() 55 | else "cpu") 56 | 57 | if args.source is None: 58 | args.source = [] 59 | 60 | args.source = [os.path.join(BASE_DIR, dir) for dir in args.source] 61 | 62 | if args.visdom: 63 | spawn_visdom() 64 | 65 | for arg in vars(args): 66 | print("{}:{}".format(arg, getattr(args, arg))) 67 | print() 68 | 69 | return args, config 70 | 71 | 72 | def seq2seq2seq_options(): 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--config', default="model_configs/ds.full.yaml") 75 | parser.add_argument('--name', default="test-justice-out") 76 | parser.add_argument('--desc') 77 | parser.add_argument('--resume') 78 | parser.add_argument('--visdom', action='store_true') 79 | parser.add_argument('--transfer-lm') 80 | parser.add_argument('--device', default="auto") 81 | parser.add_argument('--cores', type=int, default=4) 82 | 
parser.add_argument('--source', nargs='*', default=["models", "modules", "utils"]) 83 | 84 | args = parser.parse_args() 85 | config = load_config(args.config) 86 | 87 | if args.name is None: 88 | config_filename = os.path.basename(args.config) 89 | args.name = os.path.splitext(config_filename)[0] 90 | 91 | config["name"] = args.name 92 | config["desc"] = args.desc 93 | 94 | if args.device == "auto": 95 | args.device = torch.device("cuda" if torch.cuda.is_available() 96 | else "cpu") 97 | 98 | if args.source is None: 99 | args.source = [] 100 | 101 | args.source = [os.path.join(BASE_DIR, dir) for dir in args.source] 102 | 103 | if args.visdom: 104 | spawn_visdom() 105 | 106 | for arg in vars(args): 107 | print("{}:{}".format(arg, getattr(args, arg))) 108 | print() 109 | 110 | return args, config 111 | -------------------------------------------------------------------------------- /utils/training.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import torch 5 | 6 | from sys_config import BASE_DIR 7 | 8 | 9 | def save_checkpoint(state, name, path=None, timestamp=False, tag=None, verbose=False): 10 | """ 11 | Save a trained model, along with its optimizer, in order to be able to 12 | resume training 13 | Args: 14 | path (str): the directory, in which to save the checkpoints 15 | timestamp (bool): whether to keep only one model (latest), or keep every 16 | checkpoint 17 | 18 | Returns: 19 | 20 | """ 21 | now = datetime.datetime.now().strftime("%y-%m-%d_%H:%M:%S") 22 | 23 | if tag is not None: 24 | if isinstance(tag, str): 25 | name += "_{}".format(tag) 26 | elif isinstance(tag, list): 27 | for t in tag: 28 | name += "_{}".format(t) 29 | else: 30 | raise ValueError("invalid tag type!") 31 | 32 | if timestamp: 33 | name += "_{}".format(now) 34 | 35 | name += ".pt" 36 | 37 | if path is None: 38 | path = os.path.join(BASE_DIR, "checkpoints") 39 | 40 | file = os.path.join(path, name) 41 | 42 | if 
verbose: 43 | print("saving checkpoint:{} ...".format(name)) 44 | 45 | torch.save(state, file) 46 | 47 | return name 48 | 49 | 50 | def load_checkpoint(name, path=None, device=None): 51 | """ 52 | Load a trained model, along with its optimizer 53 | Args: 54 | name (str): the name of the model 55 | path (str): the directory, in which the model is saved 56 | 57 | Returns: 58 | model, optimizer 59 | 60 | """ 61 | if path is None: 62 | path = os.path.join(BASE_DIR, "checkpoints") 63 | 64 | model_fname = os.path.join(path, "{}.pt".format(name)) 65 | 66 | print("Loading checkpoint `{}` ...".format(model_fname), end=" ") 67 | 68 | with open(model_fname, 'rb') as f: 69 | state = torch.load(f, map_location="cpu") 70 | 71 | print("done!") 72 | 73 | return state 74 | -------------------------------------------------------------------------------- /utils/transfer.py: -------------------------------------------------------------------------------- 1 | def freeze_module(layer, depth=None): 2 | if depth is None: 3 | for param in layer.parameters(): 4 | param.requires_grad = False 5 | else: 6 | for weight in layer.all_weights[depth]: 7 | weight.requires_grad = False 8 | 9 | 10 | def train_module(layer, depth=None): 11 | if depth is None: 12 | for param in layer.parameters(): 13 | param.requires_grad = True 14 | else: 15 | for weight in layer.all_weights[depth]: 16 | weight.requires_grad = True 17 | 18 | 19 | def dict_rename_by_pattern(from_dict, patterns): 20 | for k in list(from_dict.keys()): 21 | v = from_dict.pop(k) 22 | p = list(filter(lambda x: x in k, patterns.keys())) 23 | if len(p) > 0: 24 | new_key = k.replace(p[0], patterns[p[0]]) 25 | from_dict[new_key] = v 26 | else: 27 | from_dict[k] = v 28 | 29 | 30 | def load_state_dict_subset(model, pretrained_dict): 31 | model_dict = model.state_dict() 32 | 33 | # 1. filter out unnecessary keys 34 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 35 | if k in model_dict} 36 | # 2. 
overwrite entries in the existing state dict 37 | model_dict.update(pretrained_dict) 38 | 39 | # 3. load the new state dict 40 | model.load_state_dict(model_dict) 41 | -------------------------------------------------------------------------------- /utils/viz.py: -------------------------------------------------------------------------------- 1 | from matplotlib.backends.backend_pdf import PdfPages 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | from graphviz import Digraph 5 | from torch.autograd import Variable 6 | 7 | 8 | def make_dot_2(var): 9 | node_attr = dict(style='filled', 10 | shape='box', 11 | align='left', 12 | fontsize='12', 13 | ranksep='0.1', 14 | height='0.2') 15 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) 16 | seen = set() 17 | 18 | def add_nodes(var): 19 | if var not in seen: 20 | if isinstance(var, Variable): 21 | value = '(' + (', ').join(['%d' % v for v in var.size()]) + ')' 22 | dot.node(str(id(var)), str(value), fillcolor='lightblue') 23 | else: 24 | dot.node(str(id(var)), str(type(var).__name__)) 25 | seen.add(var) 26 | if hasattr(var, 'previous_functions'): 27 | for u in var.previous_functions: 28 | dot.edge(str(id(u[0])), str(id(var))) 29 | add_nodes(u[0]) 30 | 31 | add_nodes(var.creator) 32 | return dot 33 | 34 | 35 | def attention_heatmap_subplot(src, trg, attention, ax=None): 36 | g = sns.heatmap(attention, 37 | # cmap="Greys_r", 38 | cmap="viridis", 39 | cbar=False, 40 | # annot=True, 41 | vmin=0, vmax=1, 42 | robust=False, 43 | fmt=".2f", 44 | annot_kws={'size': 12}, 45 | xticklabels=trg, 46 | yticklabels=src, 47 | # square=True, 48 | ax=ax) 49 | g.set_yticklabels(g.get_yticklabels(), rotation=0, fontsize=12) 50 | g.set_xticklabels(g.get_xticklabels(), rotation=60, fontsize=12) 51 | 52 | # g.set_xticks(numpy.arange(len(src)), src, rotation=0) 53 | # g.set_yticks(numpy.arange(len(trg)), trg, rotation=60) 54 | 55 | 56 | def visualize_translations(lang, prefix_trg2src=False): 57 | for s1, s2, 
a12, s3, a23 in lang: 58 | # attention_heatmap(i, o, a[:len(o), :len(i)].t().cpu().numpy()) 59 | if prefix_trg2src: 60 | s2_enc = [""] + s2[:-1] 61 | else: 62 | s2_enc = s2 63 | attention_heatmap_pair(s1, s2, s2_enc, s3, 64 | a12.t()[:len(s1), :len(s2)].cpu().numpy(), 65 | a23.t()[:len(s2_enc), :len(s3)].cpu().numpy()) 66 | 67 | 68 | def visualize_compression(lang, prefix_trg2src=False): 69 | for s1, s2, a12, s3, a23 in lang: 70 | # attention_heatmap(i, o, a[:len(o), :len(i)].t().cpu().numpy()) 71 | if prefix_trg2src: 72 | s2_enc = [""] + s2[:-1] 73 | else: 74 | s2_enc = s2 75 | attention_heatmap_pair(s1, s2, s2_enc, s3, 76 | a12.t()[:len(s1), :len(s2)].cpu().numpy(), 77 | a23.t()[:len(s2_enc), :len(s3)].cpu().numpy()) 78 | 79 | 80 | def seq3_attentions(sent, file='foo.pdf'): 81 | from matplotlib import rc 82 | rc('font', **{'family': 'serif', 'serif': ['CMU Serif']}) 83 | # rc('text', usetex=True) 84 | # rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']}) 85 | # rc('text', usetex=True) 86 | 87 | with PdfPages(file) as pdf: 88 | for s1, s2, a12, s3, a23 in sent: 89 | s1 = s1[:s1.index(".") + 1] 90 | s12 = s2[:s2.index("") + 1] 91 | s23 = s2[:s2.index("")] 92 | s3 = s3[:len(s1)] 93 | 94 | att12 = a12.t()[:len(s1), :len(s12)].cpu().numpy() 95 | att23 = a23.t()[:len(s23), :len(s3)].cpu().numpy() 96 | 97 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6)) 98 | attention_heatmap_subplot(s1, s12, att12, ax=ax1) 99 | attention_heatmap_subplot(s23, s3, att23, ax=ax2) 100 | ax1.set_title("Source to Compression") 101 | ax2.set_title("Compression to Reconstruction") 102 | fig.tight_layout() 103 | 104 | pdf.savefig(fig) 105 | 106 | 107 | def attention_heatmap(src, trg, attention): 108 | fig, ax = plt.subplots(figsize=(11, 5)) 109 | attention_heatmap_subplot(src, trg, attention) 110 | fig.tight_layout() 111 | plt.show() 112 | 113 | 114 | def attention_heatmap_pair(s1, s2, s3, s4, att12, att23): 115 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6)) 116 
| attention_heatmap_subplot(s1, s2, att12, ax=ax1) 117 | attention_heatmap_subplot(s3, s4, att23, ax=ax2) 118 | ax1.set_title("RNN1 -> RNN2") 119 | ax2.set_title("RNN2 -> RNN3") 120 | fig.tight_layout() 121 | plt.show() 122 | --------------------------------------------------------------------------------