├── .gitignore ├── LICENSE ├── README.md ├── generate_demo.ipynb ├── requirements.txt ├── scripts ├── binarize_data.sh ├── collect_results.py ├── compress_glove.py ├── deduplicate.py ├── edit_eval.py ├── generate.sh ├── generation_to_analysis_file.py ├── get_mask_ids.py ├── guu_lm_query.txt ├── inferece_analysis.sh ├── interpolation_ppl.py ├── lower_case.py ├── multi-bleu.perl ├── pos_eval.py ├── precompute_bert.py ├── template_to_analysis_file.py ├── train.sh └── train.slurm.sh ├── sparse_prototype ├── __init__.py ├── distribution │ ├── __init__.py │ └── vmf_batch.py ├── guu_criterion.py ├── inv_editor │ ├── __init__.py │ ├── inv_editor.py │ ├── inv_editor_guu.py │ └── inv_editor_levenshtein.py ├── language_pair_map_dataset.py ├── lm_criterion.py ├── prepare_data.py ├── retrieve_prototype_dataset.py ├── retriever │ ├── __init__.py │ ├── bert.py │ ├── cnn_text.py │ ├── precompute_emb.py │ └── sent_bert.py ├── sequence_generator.py ├── sp_criterion.py ├── sp_hub_interface.py ├── sp_model.py ├── sp_task.py ├── topk_criterion.py ├── vae.py └── vmf_vae.py ├── train.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # JetBrains PyCharm IDE 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # macOS dir files 13 | .DS_Store 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # Checkpoints 35 | checkpoints 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | .venv 96 | venv/ 97 | ENV/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | # Generated files 113 | /fairseq/temporal_convolution_tbc 114 | /fairseq/modules/*_layer/*_forward.cu 115 | /fairseq/modules/*_layer/*_backward.cu 116 | 117 | # data 118 | data-bin/ 119 | 120 | # reranking 121 | /examples/reranking/rerank_data 122 | 123 | # Cython-generated C++ source files 124 | /fairseq/data/data_utils_fast.cpp 125 | /fairseq/data/token_block_utils_fast.cpp 126 | 127 | # VSCODE 128 | .vscode/ftp-sync.json 129 | .vscode/settings.json 130 | 131 | # Experimental Folder 132 | experimental/* 133 | 134 | # custom 135 | *.swp 136 | *.pyc 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Junxian He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse Neural Editor 2 | This repo is the PyTorch implementation of this [paper](https://arxiv.org/abs/2006.16336): 3 | 4 | ``` 5 | Learning Sparse Prototypes for Text Generation 6 | Junxian He, Taylor Berg-Kirkpatrick, Graham Neubig 7 | NeurIPS 2020 8 | ``` 9 | 10 | In this repo, we implement a generative model of text that generates sentences by editing non-parametric prototypes. The prototype support set is encouraged to be sparse during training to improve the memory/time efficiency at test time. 11 | 12 | 13 | ## Dependencies 14 | 15 | The code mainly requires [PyTorch](https://pytorch.org/) (>=1.4.0) and [fairseq](https://github.com/pytorch/fairseq) (we run our experiments based on this specific [commit](https://github.com/pytorch/fairseq/commit/b65a85b692544e36f9e83ada91cf4ef529791c69)). 16 | 17 | Install dependencies: 18 | 19 | ```bash 20 | # install fairseq from a specific commit 21 | git clone git@github.com:pytorch/fairseq.git fairseq_local 22 | cd fairseq_local 23 | git reset --hard b65a85b 24 | 25 | # a modified sequence_generator.py to use edit vectors 26 | cp ../sparse_prototype/sequence_generator.py fairseq 27 | 28 | pip install --editable ./ 29 | 30 | cd .. 31 | 32 | # install additional dependencies 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | 37 | 38 | ## Prepare Data 39 | 40 | ```bash 41 | # download coco data 42 | gdown https://drive.google.com/uc?id=1fMBZnMZz46qC0Im6y53MnDDQGRuwoC_M 43 | 44 | # download yelp medium data 45 | gdown https://drive.google.com/uc?id=1Bgk94NZeoexdCWF_WPMoIFPLRjJsbuBF 46 | 47 | # download yelp large data 48 | gdown https://drive.google.com/uc?id=1Z6wc4n5UBghwyNOo-C41vXEdNG5CE1Pa 49 | 50 | 51 | mkdir datasets 52 | 53 | # take coco dataset as an example 54 | tar -xvzf coco40k.tar.gz -C datasets 55 | 56 | # binarize dataset for fairseq 57 | bash scripts/binarize_data.sh coco40k 58 | 59 | # generate a mask file which is used to avoid selecting 60 | # exactly the same example as prototype during training 61 | python scripts/get_mask_ids.py coco40k 62 | ``` 63 | 64 | 65 | 66 | ## Training 67 | 68 | We first pre-compute the sentence embeddings for all data examples offline and save them in memory-mapped files using `np.memmap`. During training/evaluation, a bilinear transformation is applied between these data embeddings and prototype embeddings to obtain the retrieval distribution.
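To make this concrete, below is a minimal, illustrative sketch (not part of the repo) of how the memory-mapped embedding file written by `scripts/precompute_bert.py` can be read back and scored against prototype embeddings with a bilinear transformation; the coco40k path follows that script's naming, while the batch size, the number of training lines, the random prototype matrix, and the identity-initialized bilinear weight are placeholder assumptions:

```python
import numpy as np
import torch

hidden_size = 768     # bert-base hidden size
num_train = 39577     # assumed: the line count of datasets/coco40k/train.txt

# the offline embeddings are read lazily from disk via np.memmap
dstore = np.memmap(
    'precompute_embedding_datasets/coco40k/coco40k.bert.train.npy',
    dtype='float16', mode='r', shape=(num_train, hidden_size))

data_emb = torch.from_numpy(np.array(dstore[:32])).float()  # a batch of data embeddings
proto_emb = torch.randn(1000, hidden_size)                   # prototype embeddings (placeholder)
W = torch.nn.Parameter(torch.eye(hidden_size))               # bilinear weight, learned in the real model

scores = data_emb @ W @ proto_emb.t()           # (batch, num_prototypes) bilinear scores
retrieval_dist = torch.softmax(scores, dim=-1)  # retrieval distribution over prototypes
```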
Here we use BERT as the offline encoder: 69 | 70 | ```bash 71 | # embeddings are saved into precompute_embedding_datasets/[dataset name] 72 | CUDA_VISIBLE_DEVICES=xx python scripts/precompute_bert.py [dataset name] 73 | ``` 74 | 75 | 76 | 77 | [GloVe](https://github.com/stanfordnlp/GloVe) embeddings are used in the paper to initialize word embeddings: 78 | 79 | ```bash 80 | wget http://nlp.stanford.edu/data/wordvecs/glove.6B.zip 81 | mkdir glove_embeddings 82 | unzip glove.6B.zip -d glove_embeddings 83 | 84 | # compress glove embeddings to generate a new embedding file 85 | # that only contains the dictionary of the dataset 86 | python scripts/compress_glove.py \ 87 | --embed-path glove_embeddings/glove.6B.300d.txt \ 88 | --dict-path data-bin/[dataset_name]/dict.txt \ 89 | > glove_embeddings/[dataset_name]_glove.txt 90 | ``` 91 | 92 | 93 | 94 | ##### Train the model: 95 | 96 | ```bash 97 | # train the sparse neural editor 98 | # [GPUs] can be multiple ids to perform data-parallel training 99 | # some hyperparameters can be specified (e.g. -a [alpha]), see 100 | # details in the script 101 | bash scripts/train.sh -g [GPUs] -d [dataset name] 102 | 103 | # train lm baseline 104 | bash scripts/train.sh -g [GPUs] -c lm_baseline -d [dataset name] 105 | ``` 106 | 107 | 108 | 109 | ## Evaluation 110 | 111 | ##### Compute PPL: 112 | 113 | ```bash 114 | # approximate importance-weighted ppl 115 | bash scripts/train.sh -g [GPUs] -d [dataset name] -e iw -p [checkpoint directory] 116 | 117 | # pruning prototypes can be performed at eval time 118 | # [prune num] is the number of prototypes kept 119 | bash scripts/train.sh -g [GPUs] -d [dataset name] -u [prune num] -e iw -p [checkpoint directory] 120 | ``` 121 | 122 | ## Template-based Generation 123 | See the notebook `generate_demo.ipynb` (mainly the `sample_from_cluster` function) for examples of loading the pretrained model and generating from given templates. 124 | 125 | 126 | ## Citation 127 | 128 | ``` 129 | @inproceedings{he2020learning, 130 | title={Learning Sparse Prototypes for Text Generation}, 131 | author={He, Junxian and Berg-Kirkpatrick, Taylor and Neubig, Graham}, 132 | booktitle={Proceedings of NeurIPS}, 133 | year={2020} 134 | } 135 | ``` 136 | 137 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | stanza 2 | tensorboardX 3 | 4 | scipy==1.5.2 5 | transformers==3.1.0 6 | edlib==1.3.8.post1 -------------------------------------------------------------------------------- /scripts/binarize_data.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # 3 | # preprocess.sh 4 | # Copyright (C) 2020-02-09 Junxian 5 | # 6 | # Distributed under terms of the MIT license.
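# Usage: bash scripts/binarize_data.sh [dataset name]
# Runs fairseq-preprocess (source side only, 10k-word vocabulary) over
# datasets/[dataset name]/{train,valid,template,test}.txt and writes the
# binarized output to data-bin/[dataset name].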
7 | # 8 | 9 | dataset=$1 10 | datadir=datasets/${dataset} 11 | 12 | fairseq-preprocess \ 13 | --only-source \ 14 | --trainpref ${datadir}/train.txt \ 15 | --validpref ${datadir}/valid.txt,${datadir}/template.txt \ 16 | --testpref ${datadir}/test.txt \ 17 | --destdir data-bin/${dataset} \ 18 | --nwordssrc 10000 \ 19 | --workers 20 \ 20 | 21 | -------------------------------------------------------------------------------- /scripts/collect_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script collect all results in a directory 3 | """ 4 | 5 | import os 6 | import argparse 7 | 8 | def is_number(s): 9 | try: 10 | float(s) 11 | return True 12 | except ValueError: 13 | return False 14 | 15 | def parse_line(line): 16 | line_s = line.strip().split(" | ") 17 | result = {} 18 | for entry in line_s: 19 | if len(entry.split()) == 2: 20 | k, v = entry.split() 21 | if is_number(v): 22 | result[k] = v 23 | 24 | return result 25 | 26 | 27 | parser = argparse.ArgumentParser(description="collect results script") 28 | 29 | parser.add_argument('--outdir', type=str, help="an high level experiment dir") 30 | 31 | args = parser.parse_args() 32 | 33 | best_results = [] 34 | best_nll = 1000 35 | best_ppl = 1000 36 | 37 | with open(os.path.join(args.outdir, "summary.txt"), "w") as fout: 38 | for root, subdirs, files in os.walk(args.outdir): 39 | valid_result = {} 40 | best_nll = best_ppl = 1000 41 | best_line_str = '' 42 | last_line_str = '' 43 | 44 | best_line_dict = {} 45 | last_line_dict = {} 46 | for file in files: 47 | if file == "stdout.log": 48 | print("processing {}".format(os.path.join(root, file))) 49 | fin = open(os.path.join(root, file)) 50 | for line in fin: 51 | if "valid on" in line and 'ppl_iw' not in line: 52 | valid_result = parse_line(line) 53 | if len(valid_result) > 0 and float(valid_result['loss']) > 0: 54 | ppl = float(valid_result["ppl"]) 55 | if ppl < best_ppl: 56 | best_ppl = ppl 57 | best_line_str = line.rstrip() 58 | best_line_dict = valid_result 59 | 60 | last_line_str = line.rstrip() 61 | last_line_dict = valid_result 62 | 63 | fin.close() 64 | break 65 | 66 | if len(valid_result) > 0: 67 | try: 68 | fout.write("{}\n".format(os.path.abspath(root))) 69 | fout.write("valid best loss: {}, best ppl: {}\n".format( 70 | best_line_dict["loss"], best_line_dict["ppl"])) 71 | fout.write("valid last loss: {}, last ppl: {}\n".format( 72 | last_line_dict["loss"], last_line_dict["ppl"])) 73 | fout.write("{}\n".format(best_line_str)) 74 | fout.write("\n-----------------------------------\n\n") 75 | except KeyError: 76 | pass -------------------------------------------------------------------------------- /scripts/compress_glove.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | 5 | def parse_embedding(embed_path): 6 | """Parse embedding text file into a dictionary of word and embedding tensors. 7 | 8 | The first line can have vocabulary size and dimension. The following lines 9 | should contain word and embedding separated by spaces. 
10 | 11 | Example: 12 | 2 5 13 | the -0.0230 -0.0264 0.0287 0.0171 0.1403 14 | at -0.0395 -0.1286 0.0275 0.0254 -0.0932 15 | """ 16 | embed_dict = {} 17 | with open(embed_path) as f_embed: 18 | # next(f_embed) # skip header 19 | for line in f_embed: 20 | pieces = line.rstrip().split(" ") 21 | embed_dict[pieces[0]] = pieces[1:] 22 | 23 | return embed_dict 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser(description='simplify glove word embedding') 28 | parser.add_argument('--embed-path', type=str, help='the original glove embed path') 29 | parser.add_argument('--dict-path', type=str, default='dict path') 30 | 31 | args = parser.parse_args() 32 | 33 | embed_dict = parse_embedding(args.embed_path) 34 | sample = embed_dict['a'] 35 | embed_dim = len(sample) 36 | with open(args.dict_path) as fin: 37 | vocab_size = len(fin.readlines()) 38 | 39 | print('{} {}'.format(vocab_size, embed_dim)) 40 | with open(args.dict_path) as fin: 41 | for line in fin: 42 | word = line.split()[0] 43 | if word in embed_dict: 44 | print('{} {}'.format(word, ' '.join(embed_dict[word]))) 45 | else: 46 | print('{} {}'.format(word, ' '.join(['0'] * embed_dim))) 47 | 48 | 49 | -------------------------------------------------------------------------------- /scripts/deduplicate.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2020-03-02 Junxian 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | import sys 10 | 11 | if __name__ == '__main__': 12 | save = set() 13 | for line in sys.stdin: 14 | key = line.rstrip() 15 | 16 | # for mscoco 17 | # key = '\t'.join(line.rstrip().split('\t')[1:]) 18 | if key not in save: 19 | print(line.rstrip()) 20 | save.update([key]) 21 | 22 | -------------------------------------------------------------------------------- /scripts/edit_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | import random 5 | import edlib 6 | from typing import List 7 | from collections import Counter 8 | 9 | import stanza 10 | 11 | def flat_cigar(cigar): 12 | r = [] 13 | pointer = 0 14 | 15 | while pointer < len(cigar): 16 | num = [] 17 | while cigar[pointer].isdigit(): 18 | num.append(cigar[pointer]) 19 | pointer += 1 20 | num = int(''.join(num)) 21 | 22 | r.extend([cigar[pointer]] * num) 23 | pointer += 1 24 | 25 | return r 26 | 27 | class ExtractMetric(object): 28 | """used for precision recall""" 29 | def __init__(self, nume=0, denom_p=0, denom_r=0, precision=0, recall=0, f1=0): 30 | super(ExtractMetric, self).__init__() 31 | self.nume = nume 32 | self.denom_p = denom_p 33 | self.denom_r = denom_r 34 | self.precision = precision 35 | self.recall = recall 36 | self.f1 = f1 37 | 38 | def read_file(fname): 39 | res1, res2 = [], [] 40 | with open(fname) as fin: 41 | for line in fin: 42 | x, y = line.rstrip().split('\t') 43 | res1.append(x) 44 | res2.append(y) 45 | 46 | return res1, res2 47 | 48 | def write_file(fname: str, data: List[str]): 49 | with open(fname, 'w') as fout: 50 | for sent in data: 51 | if isinstance(sent, list): 52 | fout.write('{}\n'.format(' '.join(sent))) 53 | else: 54 | fout.write('{}\n'.format(sent)) 55 | 56 | 57 | parser = argparse.ArgumentParser(description='Evaluate analysis metrics') 58 | parser.add_argument('--prefix', type=str, choices=['inference', 'generation'], 59 | help='prediction file prefix') 60 | 
parser.add_argument('--exp-dir', type=str, help='output directory') 61 | 62 | args = parser.parse_args() 63 | 64 | fout = open(os.path.join(args.exp_dir, 'edit_analysis_{}_res.txt'.format(args.prefix)), 'w') 65 | 66 | prototypes, examples = read_file(os.path.join(args.exp_dir, '{}_analysis_input.txt'.format(args.prefix))) 67 | examples_rand = random.sample(examples, len(examples)) 68 | prototype_path = os.path.join(args.exp_dir, 'prototype.txt') 69 | prototype_pos_path = os.path.join(args.exp_dir, 'prototype_pos.txt') 70 | 71 | 72 | example_path = os.path.join(args.exp_dir, 'example.txt') 73 | example_rand_path = os.path.join(args.exp_dir, 'example_rand.txt') 74 | example_pos_path = os.path.join(args.exp_dir, 'example_pos.txt') 75 | example_pos_rand_path = os.path.join(args.exp_dir, 'example_pos_rand.txt') 76 | 77 | 78 | 79 | write_file(prototype_path, prototypes) 80 | write_file(example_path, examples) 81 | write_file(example_rand_path, examples_rand) 82 | 83 | # surface BLEU 84 | bleu = subprocess.getoutput( 85 | "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_path, example_rand_path)) 86 | print('Regular BLEU (random baseline): \n{}'.format(bleu)) 87 | fout.write('Regular BLEU (random baseline): \n{}'.format(bleu)) 88 | 89 | fout.write('\n\n\n') 90 | 91 | bleu = subprocess.getoutput( 92 | "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_path, example_path)) 93 | print('Regular BLEU: \n{}'.format(bleu)) 94 | fout.write('Regular BLEU: \n{}'.format(bleu)) 95 | 96 | fout.write('\n\n\n') 97 | 98 | # POS tagging 99 | print('POS tagging') 100 | nlp = stanza.Pipeline(lang='en', processors='tokenize,mwt,pos', tokenize_pretokenized=True) 101 | prototype_doc = nlp('\n'.join(prototypes)) 102 | example_doc = nlp('\n'.join(examples)) 103 | 104 | prototypes_pos = [[word.upos for word in sent.words] for sent in prototype_doc.sentences] 105 | examples_pos = [[word.upos for word in sent.words] for sent in example_doc.sentences] 106 | 107 | example_rand_doc = random.sample(list(example_doc.sentences), len(example_doc.sentences)) 108 | examples_pos_rand = [[word.upos for word in sent.words]for sent in example_rand_doc] 109 | 110 | write_file(prototype_pos_path, prototypes_pos) 111 | write_file(example_pos_path, examples_pos) 112 | write_file(example_pos_rand_path, examples_pos_rand) 113 | 114 | 115 | # POS BLEU 116 | bleu = subprocess.getoutput( 117 | "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_pos_path, example_pos_rand_path)) 118 | print('POS BLEU (random baseline): \n{}'.format(bleu)) 119 | fout.write('POS BLEU (random baseline): \n{}'.format(bleu)) 120 | 121 | fout.write('\n\n\n') 122 | 123 | bleu = subprocess.getoutput( 124 | "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_pos_path, example_pos_path)) 125 | print('POS BLEU: \n{}'.format(bleu)) 126 | fout.write('POS BLEU: \n{}'.format(bleu)) 127 | 128 | fout.write('\n\n\n') 129 | 130 | # break down precision and recall 131 | print("compute precision, recall, f1") 132 | assert len(prototypes) == len(prototypes_pos) 133 | assert len(examples) == len(examples_pos) 134 | 135 | res = eval_f1(list(prototype_doc.sentences), example_rand_doc) 136 | 137 | res = sorted(res.items(), key=lambda item: -item[1].f1) 138 | 139 | fout.write('random baseline precision-recall\n') 140 | fout.write('POS recall precision f1\n') 141 | for k, v in res: 142 | fout.write('{} {} {} {}\n'.format(k, v.recall, v.precision, v.f1)) 143 | 144 | fout.write('\n\n\n') 145 | 146 | res = 
eval_f1(list(prototype_doc.sentences), list(example_doc.sentences)) 147 | res = sorted(res.items(), key=lambda item: -item[1].f1) 148 | 149 | fout.write('precision-recall\n') 150 | fout.write('POS recall precision f1\n') 151 | for k, v in res: 152 | fout.write('{} {} {} {}\n'.format(k, v.recall, v.precision, v.f1)) 153 | 154 | fout.close() 155 | -------------------------------------------------------------------------------- /scripts/generate.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # 3 | # train.sh 4 | # Copyright (C) 2020-02-09 Junxian 5 | # 6 | # Distributed under terms of the MIT license. 7 | # 8 | 9 | # evaluation script 10 | # Usage: 11 | # bash eval.sh -g [GPU] -p [model_dir] 12 | 13 | iw_nsamples=1000 14 | valid_subset="valid" 15 | 16 | 17 | DATE=`date +%Y%m%d` 18 | data_bin="yelp_large" 19 | emb_type="bert" 20 | eval_mode='gen_interpolation' 21 | decode_strategy='beam' 22 | kappa=30 23 | max_epoch=-1 24 | retriever=pretrained_embed 25 | linear_bias=0 26 | stop_bert_grad=1 27 | freeze_retriever=0 28 | reinforce=1 29 | temperature=1 30 | grad_lambda=0 31 | forget_rate=0.8 32 | decay_rate=1 33 | copy=0 34 | inv_editor='levenshtein' 35 | edit_embed_dim=10 36 | entropy_w=1 37 | term2_w=1 38 | rescale_factor=1. 39 | criterion=sp_elbo 40 | log_format=simple 41 | 42 | load_name='' 43 | 44 | retrieve_split=train 45 | 46 | glove_path=glove_embeddings/${data_bin}_glove.txt 47 | opt=adam 48 | template_emb_file=pretrained_sent_embeddings/${data_bin}.template.${emb_type}.hdf5 49 | 50 | if [ "$data_bin" = "ptb" ]; 51 | then 52 | num_class=41088 53 | max_tokens=2048 54 | save_interval_updates=100 55 | warmup_updates=1000 56 | retrieve_split=valid1 57 | ns=5 58 | elif [ "$data_bin" = "ptb10" ]; 59 | then 60 | num_class=5703 61 | max_tokens=512 62 | save_interval_updates=0 63 | warmup_updates=800 64 | ns=20 65 | elif [ "$data_bin" = "coco40k" ]; 66 | then 67 | num_class=39577 68 | max_tokens=2048 69 | save_interval_updates=0 70 | warmup_updates=1000 71 | max_epoch=30 72 | retrieve_split=valid1 73 | ns=10 74 | 75 | if [ "$retriever" = "sentence-bert" ]; 76 | then 77 | template_emb_file=pretrained_sent_embeddings/${data_bin}.template.bert.hdf5 78 | echo "read bert embeddings" 79 | fi 80 | elif [ "$data_bin" = "yelp" ]; 81 | then 82 | num_class=50000 83 | max_tokens=2048 84 | save_interval_updates=0 85 | warmup_updates=10000 86 | max_update=300000 87 | retrieve_split=valid1 88 | log_interval=100 89 | retriever=bert 90 | ns=10 91 | log_format=tqdm 92 | elif [ "$data_bin" = "yelp_large" ]; 93 | then 94 | num_class=100000 95 | max_tokens=1024 # distributed on two gpus 96 | save_interval_updates=5000 97 | warmup_updates=150000 98 | max_update=500000 99 | kappa=40 100 | lambda_config="0:0,150000:1" 101 | retrieve_split=valid1 102 | log_interval=100 103 | retriever=bert 104 | validate_interval=1000 105 | ns=10 106 | else 107 | num_class=0 108 | max_tokens=0 109 | save_interval_updates=0 110 | warmup_updates=0 111 | ns=0 112 | fi 113 | 114 | if [ "$opt" = "adam" ]; 115 | then 116 | warmup_init_lr='1e-03' 117 | # add_opt_string="--adam-betas '(0.9, 0.98)'" 118 | add_opt_string='' 119 | lr=0.001 120 | else 121 | warmup_init_lr='1' 122 | add_opt_string="" 123 | warmup_updates=8000 124 | lr=1.0 125 | fi 126 | 127 | 128 | 129 | GPU=0 130 | alpha=0.01 131 | separate="1" 132 | prune_num="-1" 133 | gen_prototype=200 134 | gen_nz=10 135 | 136 | while getopts ":g:a:p:k:e:l:t:s:r:u:c:n:z:" arg; do 137 | case $arg in 138 | g) GPU="$OPTARG" 139 | ;; 140 | 
a) alpha="$OPTARG" 141 | ;; 142 | p) LOADDIR="$OPTARG" 143 | ;; 144 | k) kappa="$OPTARG" 145 | ;; 146 | e) edit_embed_dim="$OPTARG" 147 | ;; 148 | l) lambda_momentum="$OPTARG" 149 | ;; 150 | t) temperature="$OPTARG" 151 | ;; 152 | s) separate="$OPTARG" 153 | ;; 154 | r) rescale_factor="$OPTARG" 155 | ;; 156 | u) prune_num="$OPTARG" 157 | ;; 158 | c) criterion="$OPTARG" 159 | ;; 160 | n) gen_prototype="$OPTARG" 161 | ;; 162 | z) gen_nz="$OPTARG" 163 | ;; 164 | \?) echo "Invalid option -$OPTARG" >&2 165 | ;; 166 | esac 167 | done 168 | 169 | max_tokens=$(( max_tokens * ns / gen_nz )) 170 | 171 | if [ "$criterion" = "lm_baseline" ]; 172 | then 173 | ns=1 174 | eval_mode='none' 175 | fi 176 | 177 | GPUSTR=$(printf "$GPU" | tr , _) 178 | 179 | if [[ -v LOADDIR ]]; 180 | then 181 | # add_load_string="--reset-meters" 182 | add_load_string="--reset-meters --reset-optimizer" 183 | cstring="_continue" 184 | restore_file=checkpoint_load.pt 185 | save_interval_updates=50 186 | if [ "$opt" = "adam" ]; 187 | then 188 | lr=0.0003 189 | warmup_init_lr=0.0003 190 | warmup_updates=8000 191 | fi 192 | if [ "$stop_bert_grad" = 0 ]; 193 | then 194 | max_tokens=1024 195 | ns=10 196 | lr=0.1 197 | warmup_init_lr='1e-3' 198 | warmup_updates=5000 199 | fi 200 | else 201 | add_load_string="" 202 | cstring="" 203 | restore_file=null.pt 204 | fi 205 | 206 | if [ "$eval_mode" = "entropy" ]; 207 | then 208 | LOADDIR="tmp" 209 | fi 210 | 211 | enc_opt_freq=10 212 | dec_opt_freq=1 213 | 214 | if [ "$separate" = "0" ]; 215 | then 216 | separate_str="" 217 | train_separate_str="" 218 | train_script=train.py 219 | else 220 | echo "separate training" 221 | separate_str="_sep_eof${enc_opt_freq}_dof${dec_opt_freq}" 222 | train_separate_str="--dec-opt-freq ${dec_opt_freq} --enc-opt-freq ${enc_opt_freq}" 223 | train_script=train_junxian.py 224 | warmup_updates=$(( warmup_updates*2 )) 225 | max_epoch=$(( max_epoch*2 )) 226 | fi 227 | 228 | if [ "$prune_num" != "-1" ]; 229 | then 230 | prune_str="_prune${prune_num}" 231 | else 232 | prune_str="" 233 | fi 234 | # declare -a rescale_list=("0.01") 235 | # declare -a alpha_list=("0.01") 236 | 237 | if [ "$decode_strategy" = "beam" ]; 238 | then 239 | decode_str="" 240 | elif [ "$decode_strategy" = "sample" ]; 241 | then 242 | decode_str="--sampling --sampling-topk 100 --beam 1" 243 | else 244 | decode_str="" 245 | fi 246 | 247 | echo "start evaluation" 248 | 249 | 250 | CUDA_VISIBLE_DEVICES=${GPU} python generate_junxian.py \ 251 | support_prototype/data-bin/${data_bin} \ 252 | --arch ${data_bin} --task support_prototype \ 253 | --dropout 0.3 \ 254 | --edit-embed-dim ${edit_embed_dim} --embed-init-rescale ${rescale_factor} \ 255 | ${train_separate_str} \ 256 | --retrieve-embed ${template_emb_file} \ 257 | --train-embed pretrained_sent_embeddings/${data_bin}.train.${emb_type}.hdf5 \ 258 | --valid-embed pretrained_sent_embeddings/${data_bin}.${valid_subset}.${emb_type}.hdf5 \ 259 | --reinforce ${reinforce} --infer-ns ${ns} --reinforce-temperature ${temperature} \ 260 | --freeze-retriever ${freeze_retriever} --decoder-copy ${copy} \ 261 | --inveditor-embed-path ${glove_path} --encoder-embed-path ${glove_path} --decoder-embed-path ${glove_path} \ 262 | --encoder-embed-dim 300 --decoder-embed-dim 300 \ 263 | --grad-lambda ${grad_lambda} --entropy-weight ${entropy_w} --term2-weight ${term2_w} \ 264 | --user-dir support_prototype \ 265 | --forget-rate ${forget_rate} --decay-rate ${decay_rate} --retrieve-split ${retrieve_split} --alpha ${alpha} --vmf-kappa ${kappa} \ 266 | --linear-bias 
${linear_bias} --stop-bert-grad ${stop_bert_grad} \ 267 | --criterion ${criterion} --label-smoothing 0. --num-workers 0 \ 268 | --max-tokens ${max_tokens} --num-class ${num_class} \ 269 | --log-format ${log_format} --log-interval 5 \ 270 | --retriever ${retriever} --inv-editor ${inv_editor} \ 271 | --path ${LOADDIR}/checkpoint_best.pt \ 272 | --eval-mode ${eval_mode} ${decode_str} --gen-np ${gen_prototype} --gen-nz ${gen_nz}\ 273 | > ${LOADDIR}/eval_${eval_mode}_${decode_strategy}.log 274 | -------------------------------------------------------------------------------- /scripts/generation_to_analysis_file.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser( 7 | description='take generation file to a tab splitted prototype-generation file') 8 | parser.add_argument('--input', type=str, default='eval_gen_sample_beam.log', help='the file name') 9 | parser.add_argument('--exp-dir', type=str, help='exp dir') 10 | 11 | args = parser.parse_args() 12 | fout = open(os.path.join(args.exp_dir, 'generation_analysis_input.txt'), 'w') 13 | 14 | with open(os.path.join(args.exp_dir, args.input)) as fin: 15 | while True: 16 | line = fin.readline() 17 | if not line: 18 | break 19 | if line.startswith('S-'): 20 | prototype = line.rstrip().split('\t')[1] 21 | examples = [] 22 | while True: 23 | tmp = fin.readline() 24 | if not tmp or tmp == '\n': 25 | break 26 | 27 | if '-generations-' in tmp: 28 | continue 29 | 30 | if len(tmp.rstrip().split('\t')) == 3: 31 | example = tmp.rstrip().split('\t')[2] 32 | examples.append(example) 33 | 34 | for example in examples: 35 | fout.write('{}\t{}\n'.format(prototype, example)) 36 | 37 | fout.close() -------------------------------------------------------------------------------- /scripts/get_mask_ids.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | if __name__ == '__main__': 5 | parser = argparse.ArgumentParser(description='pre-compute mask id files') 6 | parser.add_argument('dataset', type=str, help='the path to training file') 7 | 8 | args = parser.parse_args() 9 | 10 | template2id = {} 11 | with open(f'datasets/{args.dataset}/template.txt') as fin: 12 | for i, line in enumerate(fin): 13 | template2id[line.strip()] = i 14 | 15 | 16 | with open(f'datasets/{args.dataset}/train.txt') as fin, \ 17 | open(f'data-bin/{args.dataset}/mask_id.txt', 'w') as fout: 18 | for i, line in enumerate(fin): 19 | fout.write(f'{template2id.get(line.rstrip(), -1)}\n') 20 | if i % 10000 == 0: 21 | print("processed {} lines".format(i)) 22 | 23 | -------------------------------------------------------------------------------- /scripts/inferece_analysis.sh: -------------------------------------------------------------------------------- 1 | 2 | gpu=$1 3 | exp_dir=$2 4 | 5 | python support_prototype/scripts/template_to_analysis_file.py --exp-dir ${exp_dir} 6 | CUDA_VISIBLE_DEVICES=${gpu} python support_prototype/scripts/pos_eval.py --prefix inference --exp-dir ${exp_dir} 7 | -------------------------------------------------------------------------------- /scripts/interpolation_ppl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | from collections import namedtuple 5 | 6 | LL=namedtuple('Sent', ['tokens', 'll']) 7 | 8 | 9 | def read_file(fname, guu_lm=False, length=None): 10 | res = {} 11 | with open(fname) as 
fin: 12 | for i, line in enumerate(fin): 13 | if not guu_lm: 14 | id_, tokens, ll = line.rstrip().split() 15 | id_ = int(id_) 16 | tokens = int(tokens) 17 | ll = float(ll) 18 | assert id_ not in res 19 | else: 20 | id_ = i 21 | tokens = 0 22 | ll = -float(line.rstrip()) * np.log(10) 23 | res[id_] = LL(tokens=tokens, ll=ll) 24 | 25 | if length is not None and (i+1) == length: 26 | break 27 | 28 | return res, len(res) 29 | 30 | def compute_ppl(res, input_tokens=None): 31 | ntokens = 0 32 | ll = 0 33 | for k,v in res.items(): 34 | ntokens += v.tokens 35 | ll += v.ll 36 | 37 | if ntokens == 0: 38 | ntokens = input_tokens 39 | 40 | return -ll, ntokens, np.exp(-ll/ntokens) 41 | 42 | def combined_ppl(data1, data2, discount): 43 | ntokens = 0 44 | ll_new = [] 45 | for k in data1: 46 | if data2[k].tokens != 0: 47 | assert data1[k].tokens == data2[k].tokens 48 | ntokens += data1[k].tokens 49 | ll_new.append(np.logaddexp(data1[k].ll + np.log(discount), 50 | data2[k].ll + np.log(1. - discount))) 51 | 52 | return np.exp(-np.sum(ll_new) / ntokens) 53 | 54 | 55 | 56 | 57 | 58 | parser = argparse.ArgumentParser(description='computer ppl with interpolation') 59 | parser.add_argument('--input1', type=str, help='the path to the input text file') 60 | parser.add_argument('--input2', type=str, help='the prefix to the saving file') 61 | parser.add_argument('--discount', type=float, help='the prefix to the saving file') 62 | parser.add_argument('--guu_lm', action='store_true', default=False, help='input2 is guu query file') 63 | 64 | args = parser.parse_args() 65 | 66 | data1, length = read_file(args.input1) 67 | data2, _ = read_file(args.input2, args.guu_lm, length) 68 | 69 | total_loss, ntokens, ppl = compute_ppl(data1) 70 | print("input 1 total loss: {}, ntokens: {}, ppl: {}".format(total_loss, ntokens, ppl)) 71 | 72 | total_loss, ntokens, ppl = compute_ppl(data2, ntokens) 73 | print("input 2 total loss: {}, ntokens: {}, ppl: {}".format(total_loss, ntokens, ppl)) 74 | 75 | print("interpolated ppl: {}".format(combined_ppl(data1, data2, args.discount))) 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /scripts/lower_case.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2020-03-02 Junxian 6 | # 7 | # Distributed under terms of the MIT license. 
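# Usage: python scripts/lower_case.py < input.txt > output.txt
# Lowercases every line read from stdin and prints it to stdout.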
8 | 9 | import sys 10 | 11 | if __name__ == '__main__': 12 | save = set() 13 | for line in sys.stdin: 14 | print(line.rstrip().lower()) 15 | 16 | -------------------------------------------------------------------------------- /scripts/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | # $Id$ 4 | use strict; 5 | 6 | my $lowercase = 0; 7 | if ($ARGV[0] eq "-lc") { 8 | $lowercase = 1; 9 | shift; 10 | } 11 | 12 | my $stem = $ARGV[0]; 13 | if (!defined $stem) { 14 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 15 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 16 | exit(1); 17 | } 18 | 19 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 20 | 21 | my @REF; 22 | my $ref=0; 23 | while(-e "$stem$ref") { 24 | &add_to_ref("$stem$ref",\@REF); 25 | $ref++; 26 | } 27 | &add_to_ref($stem,\@REF) if -e $stem; 28 | die("ERROR: could not find reference file $stem") unless scalar @REF; 29 | 30 | sub add_to_ref { 31 | my ($file,$REF) = @_; 32 | my $s=0; 33 | open(REF,$file) or die "Can't read $file"; 34 | while(<REF>) { 35 | chop; 36 | push @{$$REF[$s++]}, $_; 37 | } 38 | close(REF); 39 | } 40 | 41 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 42 | my $s=0; 43 | while(<STDIN>) { 44 | chop; 45 | $_ = lc if $lowercase; 46 | my @WORD = split; 47 | my %REF_NGRAM = (); 48 | my $length_translation_this_sentence = scalar(@WORD); 49 | my ($closest_diff,$closest_length) = (9999,9999); 50 | foreach my $reference (@{$REF[$s]}) { 51 | # print "$s $_ <=> $reference\n"; 52 | $reference = lc($reference) if $lowercase; 53 | my @WORD = split(' ',$reference); 54 | my $length = scalar(@WORD); 55 | my $diff = abs($length_translation_this_sentence-$length); 56 | if ($diff < $closest_diff) { 57 | $closest_diff = $diff; 58 | $closest_length = $length; 59 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 60 | } elsif ($diff == $closest_diff) { 61 | $closest_length = $length if $length < $closest_length; 62 | # from two references with the same closeness to me 63 | # take the *shorter* into account, not the "first" one. 64 | } 65 | for(my $n=1;$n<=4;$n++) { 66 | my %REF_NGRAM_N = (); 67 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 68 | my $ngram = "$n"; 69 | for(my $w=0;$w<$n;$w++) { 70 | $ngram .= " ".$WORD[$start+$w]; 71 | } 72 | $REF_NGRAM_N{$ngram}++; 73 | } 74 | foreach my $ngram (keys %REF_NGRAM_N) { 75 | if (!defined($REF_NGRAM{$ngram}) || 76 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 77 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 78 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<br>
\n"; 79 | } 80 | } 81 | } 82 | } 83 | $length_translation += $length_translation_this_sentence; 84 | $length_reference += $closest_length; 85 | for(my $n=1;$n<=4;$n++) { 86 | my %T_NGRAM = (); 87 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 88 | my $ngram = "$n"; 89 | for(my $w=0;$w<$n;$w++) { 90 | $ngram .= " ".$WORD[$start+$w]; 91 | } 92 | $T_NGRAM{$ngram}++; 93 | } 94 | foreach my $ngram (keys %T_NGRAM) { 95 | $ngram =~ /^(\d+) /; 96 | my $n = $1; 97 | # my $corr = 0; 98 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 99 | $TOTAL[$n] += $T_NGRAM{$ngram}; 100 | if (defined($REF_NGRAM{$ngram})) { 101 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 102 | $CORRECT[$n] += $T_NGRAM{$ngram}; 103 | # $corr = $T_NGRAM{$ngram}; 104 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 105 | } 106 | else { 107 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 108 | # $corr = $REF_NGRAM{$ngram}; 109 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 110 | } 111 | } 112 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 113 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 114 | } 115 | } 116 | $s++; 117 | } 118 | my $brevity_penalty = 1; 119 | my $bleu = 0; 120 | 121 | my @bleu=(); 122 | 123 | for(my $n=1;$n<=4;$n++) { 124 | if (defined ($TOTAL[$n])){ 125 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 126 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 127 | }else{ 128 | $bleu[$n]=0; 129 | } 130 | } 131 | 132 | if ($length_reference==0){ 133 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 134 | exit(1); 135 | } 136 | 137 | if ($length_translation<$length_reference) { 138 | $brevity_penalty = exp(1-$length_reference/$length_translation); 139 | } 140 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 141 | my_log( $bleu[2] ) + 142 | my_log( $bleu[3] ) + 143 | my_log( $bleu[4] ) ) / 4) ; 144 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 145 | 100*$bleu, 146 | 100*$bleu[1], 147 | 100*$bleu[2], 148 | 100*$bleu[3], 149 | 100*$bleu[4], 150 | $brevity_penalty, 151 | $length_translation / $length_reference, 152 | $length_translation, 153 | $length_reference; 154 | 155 | sub my_log { 156 | return -9999999999 unless $_[0]; 157 | return log($_[0]); 158 | } 159 | -------------------------------------------------------------------------------- /scripts/pos_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | import random 5 | import edlib 6 | from typing import List 7 | from collections import Counter 8 | 9 | import stanza 10 | 11 | class ExtractMetric(object): 12 | """used for precision recall""" 13 | def __init__(self, nume=0, denom_p=0, denom_r=0, precision=0, recall=0, f1=0): 14 | super(ExtractMetric, self).__init__() 15 | self.nume = nume 16 | self.denom_p = denom_p 17 | self.denom_r = denom_r 18 | self.precision = precision 19 | self.recall = recall 20 | self.f1 = f1 21 | 22 | def read_file(fname, len_cut): 23 | res1, res2 = [], [] 24 | with open(fname) as fin: 25 | for line in fin: 26 | x, y = line.rstrip().split('\t') 27 | if len(x.split()) > len_cut or len(y.split()) > len_cut: 28 | continue 29 | res1.append(x) 30 | res2.append(y) 31 | 32 | return res1, res2 33 | 34 | def write_file(fname: str, data: List[str]): 35 | with open(fname, 'w') as fout: 36 | for sent in data: 37 | if isinstance(sent, list): 38 | fout.write('{}\n'.format(' '.join(sent))) 39 | else: 40 | fout.write('{}\n'.format(sent)) 41 | 42 | def eval_edit(prototype, example): 43 | 44 | def flat_cigar(cigar): 45 | """flatten the result path returned by edlib.align 46 | """ 47 | r = [] 48 | pointer = 0 49 | 50 | while pointer < len(cigar): 51 | num = [] 52 | while cigar[pointer].isdigit(): 53 | num.append(cigar[pointer]) 54 | pointer += 1 55 | num = int(''.join(num)) 56 | 57 | r.extend([cigar[pointer]] * num) 58 | pointer += 1 59 | 60 | return r 61 | 62 | 63 | res = {} 64 | for p_sent, e_sent in zip(prototype, example): 65 | p_pos = [x.upos for x in p_sent.words] 66 | e_pos = [x.upos for x in e_sent.words] 67 | 68 | p_text = [x.text for x in p_sent.words] 69 | e_text = [x.text for x in e_sent.words] 70 | 71 | edit_operation = edlib.align(e_text, p_text, task='path') 72 | edit_operation = flat_cigar(edit_operation['cigar']) 73 | 74 | new_p_text = [] 75 | new_e_text = [] 76 | new_p_pos = [] 77 | new_e_pos = [] 78 | src_cur = tgt_cur = 0 79 | 80 | 
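# Walk the aligned edit path returned by edlib ('=' match, 'X' substitution, 'I' insertion, 'D' deletion),
# padding the missing side with -1 so the prototype and example token/POS sequences stay index-aligned.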
for edit in edit_operation: 81 | if edit == '=' or edit == 'X': 82 | new_p_text.append(p_text[src_cur]) 83 | new_p_pos.append(p_pos[src_cur]) 84 | new_e_text.append(e_text[tgt_cur]) 85 | new_e_pos.append(e_pos[tgt_cur]) 86 | src_cur += 1 87 | tgt_cur += 1 88 | elif edit == 'I': 89 | new_p_text.append(-1) 90 | new_p_pos.append(-1) 91 | new_e_text.append(e_text[tgt_cur]) 92 | new_e_pos.append(e_pos[tgt_cur]) 93 | tgt_cur += 1 94 | elif edit == 'D': 95 | new_p_text.append(p_text[src_cur]) 96 | new_p_pos.append(p_pos[src_cur]) 97 | new_e_text.append(-1) 98 | new_e_pos.append(-1) 99 | src_cur += 1 100 | else: 101 | raise ValueError('{} edit operation is invalid!'.format(edit)) 102 | 103 | for i, edit in enumerate(edit_operation): 104 | if edit not in res: 105 | res[edit] = Counter() 106 | 107 | if edit == '=': 108 | res[edit]['{}={}'.format(new_p_pos[i], new_e_pos[i])] += 1 109 | elif edit == 'X': 110 | res[edit]['{}->{}'.format(new_p_pos[i], new_e_pos[i])] += 1 111 | elif edit == 'I': 112 | res[edit]['+{}'.format(new_e_pos[i])] += 1 113 | elif edit == 'D': 114 | res[edit]['-{}'.format(new_p_pos[i])] += 1 115 | else: 116 | raise ValueError 117 | 118 | return res 119 | 120 | 121 | 122 | 123 | 124 | def eval_f1(prototype, example): 125 | res = {} 126 | for p_sent, e_sent in zip(prototype, example): 127 | p_pos = [x.upos for x in p_sent.words] 128 | e_pos = [x.upos for x in e_sent.words] 129 | 130 | p_text = [x.text for x in p_sent.words] 131 | e_text = [x.text for x in e_sent.words] 132 | 133 | e_word_counter = Counter(e_text) 134 | for word, pos in zip(p_text, p_pos): 135 | if pos not in res: 136 | res[pos] = ExtractMetric( 137 | nume=0, 138 | denom_p=0, 139 | denom_r=0, 140 | precision=0, 141 | recall=0, 142 | f1=0 143 | ) 144 | 145 | res[pos].denom_r += 1 146 | if e_word_counter[word] > 0: 147 | e_word_counter[word] -= 1 148 | res[pos].nume += 1 149 | 150 | e_pos_counter = Counter(e_pos) 151 | for k, v in e_pos_counter.items(): 152 | if k not in res: 153 | res[k] = ExtractMetric( 154 | nume=0, 155 | denom_p=0, 156 | denom_r=0, 157 | precision=0, 158 | recall=0, 159 | f1=0 160 | ) 161 | 162 | res[k].denom_p += v 163 | 164 | for k, v in res.items(): 165 | if res[k].denom_p != 0 and res[k].denom_r != 0 and res[k].nume != 0: 166 | res[k].precision = res[k].nume / res[k].denom_p 167 | res[k].recall = res[k].nume / res[k].denom_r 168 | res[k].f1 = 2 * res[k].precision * res[k].recall / (res[k].precision + res[k].recall) 169 | 170 | return res 171 | 172 | 173 | def sentence_bleu(ref_path, hypo_path): 174 | sent_bleu = subprocess.getoutput( 175 | "fairseq-score --ref {} --sys {} --sentence-bleu".format(ref_path, hypo_path)) 176 | bleu_list = [float(line.split()[3].rstrip(',')) for line in sent_bleu.split('\n')[1:]] 177 | return sum(bleu_list) / len(bleu_list) 178 | 179 | def generate_rand_prototype(exp_dir, num): 180 | dataset_to_template = { 181 | "coco40k": "support_prototype/datasets/coco/coco.template.40k.txt", 182 | "yelp": "support_prototype/datasets/yelp_data/yelp.template.50k.lower.txt", 183 | "yelp_large": "support_prototype/datasets/yelp_large_data/yelp_large.template.100k.txt", 184 | } 185 | 186 | def parse_exp_dir(name): 187 | dataset = name.rstrip('/').split('/')[-1].split('_')[0] 188 | return dataset 189 | 190 | dataset = parse_exp_dir(exp_dir) 191 | 192 | return subprocess.getoutput( 193 | "shuf -n {} {}".format(num, dataset_to_template[dataset])).split('\n') 194 | 195 | 196 | parser = argparse.ArgumentParser(description='Evaluate analysis metrics') 197 | parser.add_argument('--prefix', 
type=str, choices=['inference', 'generation'], 198 | help='prediction file prefix') 199 | parser.add_argument('--exp-dir', type=str, help='output directory') 200 | 201 | args = parser.parse_args() 202 | 203 | fout = open(os.path.join(args.exp_dir, 'analysis_{}_res.txt'.format(args.prefix)), 'w') 204 | len_cut = 1000 205 | prototypes, examples = read_file(os.path.join(args.exp_dir, '{}_analysis_input.txt'.format(args.prefix)), len_cut=len_cut) 206 | prototype_path = os.path.join(args.exp_dir, 'prototype.txt') 207 | prototype_pos_path = os.path.join(args.exp_dir, 'prototype_pos.txt') 208 | 209 | prototype_rand_path = os.path.join(args.exp_dir, 'prototype_rand.txt') 210 | prototype_pos_rand_path = os.path.join(args.exp_dir, 'prototype_pos_rand.txt') 211 | 212 | example_path = os.path.join(args.exp_dir, 'example.txt') 213 | example_pos_path = os.path.join(args.exp_dir, 'example_pos.txt') 214 | 215 | prototypes_rand = generate_rand_prototype(args.exp_dir, len(examples)) 216 | 217 | write_file(prototype_path, prototypes) 218 | write_file(example_path, examples) 219 | write_file(prototype_rand_path, prototypes_rand) 220 | 221 | # surface BLEU 222 | # bleu = subprocess.getoutput( 223 | # "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_path, example_rand_path)) 224 | bleu = sentence_bleu(prototype_rand_path, example_path) 225 | print('Regular BLEU (random baseline): \n{}'.format(bleu)) 226 | fout.write('Regular BLEU (random baseline): \n{}'.format(bleu)) 227 | 228 | fout.write('\n\n\n') 229 | 230 | # bleu = subprocess.getoutput( 231 | # "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_path, example_path)) 232 | bleu = sentence_bleu(prototype_path, example_path) 233 | print('Regular BLEU: \n{}'.format(bleu)) 234 | fout.write('Regular BLEU: \n{}'.format(bleu)) 235 | 236 | fout.write('\n\n\n') 237 | 238 | # POS tagging 239 | print('POS tagging') 240 | nlp = stanza.Pipeline(lang='en', processors='tokenize,mwt,pos', tokenize_pretokenized=True) 241 | prototype_doc = nlp('\n'.join(prototypes)) 242 | example_doc = nlp('\n'.join(examples)) 243 | prototype_rand_doc = nlp('\n'.join(prototypes_rand)) 244 | 245 | prototypes_pos = [[word.upos for word in sent.words] for sent in prototype_doc.sentences] 246 | examples_pos = [[word.upos for word in sent.words] for sent in example_doc.sentences] 247 | 248 | prototypes_pos_rand = [[word.upos for word in sent.words]for sent in prototype_rand_doc.sentences] 249 | 250 | write_file(prototype_pos_path, prototypes_pos) 251 | write_file(example_pos_path, examples_pos) 252 | write_file(prototype_pos_rand_path, prototypes_pos_rand) 253 | 254 | 255 | # POS BLEU 256 | # bleu = subprocess.getoutput( 257 | # "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_pos_path, example_pos_rand_path)) 258 | bleu = sentence_bleu(prototype_pos_rand_path, example_pos_path) 259 | print('POS BLEU (random baseline): \n{}'.format(bleu)) 260 | fout.write('POS BLEU (random baseline): \n{}'.format(bleu)) 261 | 262 | fout.write('\n\n\n') 263 | 264 | # bleu = subprocess.getoutput( 265 | # "./support_prototype/scripts/multi-bleu.perl {} < {}".format(prototype_pos_path, example_pos_path)) 266 | bleu = sentence_bleu(prototype_pos_path, example_pos_path) 267 | print('POS BLEU: \n{}'.format(bleu)) 268 | fout.write('POS BLEU: \n{}'.format(bleu)) 269 | 270 | fout.write('\n\n\n') 271 | 272 | # break down precision and recall 273 | print("compute precision, recall, f1") 274 | assert len(prototypes) == len(prototypes_pos) 275 | assert 
len(examples) == len(examples_pos) 276 | 277 | res = eval_f1(list(prototype_rand_doc.sentences), list(example_doc.sentences)) 278 | 279 | res = sorted(res.items(), key=lambda item: -item[1].f1) 280 | 281 | fout.write('random baseline precision-recall\n') 282 | fout.write('POS recall precision f1\n') 283 | for k, v in res: 284 | fout.write('{} {} {} {}\n'.format(k, v.recall, v.precision, v.f1)) 285 | 286 | fout.write('\n\n\n') 287 | 288 | res = eval_f1(list(prototype_doc.sentences), list(example_doc.sentences)) 289 | res = sorted(res.items(), key=lambda item: -item[1].f1) 290 | 291 | fout.write('precision-recall\n') 292 | fout.write('POS recall precision f1\n') 293 | for k, v in res: 294 | fout.write('{} {} {} {}\n'.format(k, v.recall, v.precision, v.f1)) 295 | 296 | 297 | fout.write('\n\n\n') 298 | 299 | # edit operations 300 | print("edit analysis") 301 | res = eval_edit(list(prototype_doc.sentences), list(example_doc.sentences)) 302 | total = sum([sum(v.values()) for k, v in res.items()]) 303 | fout.write('total: {}\n'.format(total)) 304 | res = sorted(res.items(), key=lambda item: (-sum(item[1].values()))) 305 | for k, v in res: 306 | fout.write('{}: {}\n'.format(k, sum(v.values()))) 307 | for k1, v1 in v.most_common(): 308 | fout.write('{}: {} ({:.3f}), '.format(k1, v1, v1 / sum(v.values()))) 309 | fout.write('\n\n') 310 | 311 | fout.close() 312 | 313 | -------------------------------------------------------------------------------- /scripts/precompute_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | import torch 5 | import json 6 | # import h5py 7 | import gzip, csv 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | from torch.nn.utils.rnn import pad_sequence 13 | from transformers import * 14 | 15 | 16 | 17 | def get_sentence_features(batches, tokenizer, model, device, maxlen=500): 18 | features = tokenizer.batch_encode_plus(batches, padding=True, 19 | return_attention_mask=True, return_token_type_ids=True, 20 | truncation=True, max_length=maxlen) 21 | attention_mask = torch.tensor(features['attention_mask'], device=device) 22 | input_ids = torch.tensor(features['input_ids'], device=device) 23 | token_type_ids=torch.tensor(features['token_type_ids'], device=device) 24 | 25 | # (batch, seq_len, nfeature) 26 | token_embeddings = model(input_ids=input_ids, 27 | attention_mask=attention_mask, 28 | token_type_ids=token_type_ids)[0] 29 | 30 | # mean of embeddings as sentence embeddings 31 | embeddings = (attention_mask.unsqueeze(-1) * token_embeddings).sum(1) / attention_mask.sum(1).unsqueeze(-1) 32 | 33 | return embeddings 34 | 35 | 36 | def hdf5_create_dataset(group, input_file, fp16=False): 37 | global tokenizer, model, device 38 | 39 | print(f'precompute embeddings for {input_file}') 40 | pbar = tqdm() 41 | with open(input_file, 'r') as fin: 42 | batches = [] 43 | cur = 0 44 | for i, line in enumerate(fin): 45 | batches.append(line.strip()) 46 | if (i+1) % batch_size == 0: 47 | with torch.no_grad(): 48 | embeddings = get_sentence_features(batches, tokenizer, model, device) 49 | 50 | for j, embed in enumerate(embeddings): 51 | embed = embed.cpu().numpy() 52 | if fp16: 53 | embed = embed.astype('float16') 54 | group.create_dataset(f'{cur}', embed.shape, 55 | dtype='float32' if not fp16 else 'float16', data=embed) 56 | cur += 1 57 | 58 | pbar.update(len(batches)) 59 | batches = [] 60 | 61 | if len(batches) > 0: 62 | with torch.no_grad(): 63 | embeddings = 
get_sentence_features(batches, tokenizer, model, device) 64 | 65 | for j, embed in enumerate(embeddings): 66 | embed = embed.cpu().numpy() 67 | if fp16: 68 | embed = embed.astype('float16') 69 | group.create_dataset(f'{cur}', embed.shape, 70 | dtype='float32' if not fp16 else 'float16', data=embed) 71 | cur += 1 72 | 73 | def jsonl_create_dataset(output_file, input_file, fp16=False): 74 | global tokenizer, model, device 75 | 76 | print(f'precompute embeddings for {input_file}') 77 | pbar = tqdm() 78 | fout = open(output_file, 'w') 79 | 80 | with open(input_file, 'r') as fin: 81 | batches = [] 82 | cur = 0 83 | for i, line in enumerate(fin): 84 | batches.append(line.strip()) 85 | if (i+1) % batch_size == 0: 86 | with torch.no_grad(): 87 | embeddings = get_sentence_features(batches, tokenizer, model, device) 88 | 89 | for j, embed in enumerate(embeddings): 90 | embed = embed.cpu().numpy() 91 | if fp16: 92 | embed = embed.astype('float16') 93 | fout.write(json.dumps({cur: embed.tolist()})) 94 | fout.write('\n') 95 | cur += 1 96 | 97 | pbar.update(len(batches)) 98 | batches = [] 99 | 100 | if len(batches) > 0: 101 | with torch.no_grad(): 102 | embeddings = get_sentence_features(batches, tokenizer, model, device) 103 | 104 | for j, embed in enumerate(embeddings): 105 | embed = embed.cpu().numpy() 106 | if fp16: 107 | embed = embed.astype('float16') 108 | fout.write(json.dumps({cur: embed.tolist()})) 109 | fout.write('\n') 110 | cur += 1 111 | fout.close() 112 | 113 | def csv_create_dataset(output_file, input_file, fp16=False): 114 | global tokenizer, model, device 115 | 116 | print(f'precompute embeddings for {input_file}') 117 | pbar = tqdm() 118 | fout = gzip.open(output_file, 'wt') 119 | # fout = open(output_file, 'w') 120 | 121 | fieldnames = ['embedding'] 122 | writer = csv.DictWriter(fout, fieldnames=fieldnames) 123 | 124 | writer.writeheader() 125 | with open(input_file, 'r') as fin: 126 | batches = [] 127 | cur = 0 128 | for i, line in enumerate(fin): 129 | batches.append(line.strip()) 130 | if (i+1) % batch_size == 0: 131 | with torch.no_grad(): 132 | embeddings = get_sentence_features(batches, tokenizer, model, device) 133 | 134 | for j, embed in enumerate(embeddings): 135 | embed = embed.cpu().numpy() 136 | if fp16: 137 | embed = embed.astype('float16') 138 | writer.writerow({'embedding': embed.tolist()}) 139 | cur += 1 140 | 141 | pbar.update(len(batches)) 142 | batches = [] 143 | 144 | if len(batches) > 0: 145 | with torch.no_grad(): 146 | embeddings = get_sentence_features(batches, tokenizer, model, device) 147 | 148 | for j, embed in enumerate(embeddings): 149 | embed = embed.cpu().numpy() 150 | if fp16: 151 | embed = embed.astype('float16') 152 | writer.writerow({'embedding': embed.tolist()}) 153 | cur += 1 154 | fout.close() 155 | 156 | 157 | def np_create_dataset(output_file, input_file, fp16=False): 158 | global tokenizer, model, device 159 | 160 | print(f'precompute embeddings for {input_file}') 161 | pbar = tqdm() 162 | # fout = open(output_file, 'w') 163 | 164 | proc = subprocess.run(['wc', '-l', input_file], capture_output=True) 165 | dstore_size = int(proc.stdout.decode('utf-8').split()[0]) 166 | 167 | dtype = 'float16' if fp16 else 'float32' 168 | print(f'{dstore_size} examples') 169 | dstore = np.memmap(output_file, 170 | dtype=dtype, 171 | mode='w+', 172 | shape=(dstore_size, model.config.hidden_size), 173 | ) 174 | 175 | with open(input_file, 'r') as fin: 176 | batches = [] 177 | cur = 0 178 | for i, line in enumerate(fin): 179 | batches.append(line.strip()) 180 | if 
(i+1) % batch_size == 0: 181 | with torch.no_grad(): 182 | embeddings = get_sentence_features(batches, tokenizer, model, device) 183 | 184 | dstore[cur:cur+embeddings.size(0)] = embeddings.cpu().numpy().astype(dtype) 185 | cur += embeddings.size(0) 186 | 187 | assert model.config.hidden_size == embeddings.size(1) 188 | 189 | pbar.update(len(batches)) 190 | batches = [] 191 | 192 | if len(batches) > 0: 193 | with torch.no_grad(): 194 | embeddings = get_sentence_features(batches, tokenizer, model, device) 195 | 196 | dstore[cur:cur+embeddings.size(0)] = embeddings.cpu().numpy().astype(dtype) 197 | cur += embeddings.size(0) 198 | 199 | if __name__ == '__main__': 200 | parser = argparse.ArgumentParser(description='pre-compute the Bert embeddings') 201 | parser.add_argument('dataset', type=str, help='the path to the dataset name') 202 | parser.add_argument('--split', type=str, default=None, 203 | help='if specified, only compute for this split') 204 | parser.add_argument('--fp32', action='store_true', default=False, 205 | help='whether to use half float point. It uses half float by default') 206 | parser.add_argument('--sent-bert', action='store_true', default=False, 207 | help='whether to use sentence-BERT') 208 | 209 | args = parser.parse_args() 210 | args.cuda = torch.cuda.is_available() 211 | 212 | save_dir = f"precompute_embedding_datasets/{args.dataset}" 213 | 214 | os.makedirs(save_dir, exist_ok=True) 215 | 216 | device = "cuda" if args.cuda else "cpu" 217 | 218 | model_name = 'bert-base-uncased' if not args.sent_bert else 'sentence-transformers/bert-base-nli-mean-tokens' 219 | model_short = 'bert' if not args.sent_bert else 'sentbert' 220 | 221 | model = AutoModel.from_pretrained(model_name) 222 | tokenizer = AutoTokenizer.from_pretrained(model_name) 223 | 224 | model.to(device) 225 | model.eval() 226 | 227 | gname_list = [args.split] if args.split is not None else ['valid', 'test', 'template', 'train'] 228 | batch_size = 128 229 | 230 | for gname in gname_list: 231 | if os.path.isfile(f'datasets/{args.dataset}/{gname}.txt'): 232 | np_create_dataset(os.path.join(save_dir, f'{args.dataset}.{model_short}.{gname}.npy'), 233 | os.path.join(f'datasets/{args.dataset}/{gname}.txt'), not args.fp32) 234 | 235 | # for gname in gname_list: 236 | # if os.path.isfile(f'datasets/{args.dataset}/{gname}.txt'): 237 | # csv_create_dataset(os.path.join(save_dir, f'{args.dataset}.{model_short}.{gname}.csv.gz'), 238 | # os.path.join(f'datasets/{args.dataset}/{gname}.txt'), args.fp16) 239 | 240 | # for gname in gname_list: 241 | # if os.path.isfile(f'datasets/{args.dataset}/{gname}.txt'): 242 | # with h5py.File(os.path.join(save_dir, f'{args.dataset}.{model_short}.{gname}.hdf5'), 'w') as fout: 243 | # hdf5_create_dataset(fout, os.path.join(f'datasets/{args.dataset}/{gname}.txt')) 244 | -------------------------------------------------------------------------------- /scripts/template_to_analysis_file.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser( 7 | description='take template file to a tab splitted prototype-example file') 8 | parser.add_argument('--input', type=str, default='templates_eval_valid.txt', help='the file name') 9 | parser.add_argument('--exp-dir', type=str, help='the prefix to the saving file') 10 | 11 | args = parser.parse_args() 12 | fout = open(os.path.join(args.exp_dir, 'inference_analysis_input.txt'), 'w') 13 | 14 | with open(os.path.join(args.exp_dir, 
args.input)) as fin: 15 | while True: 16 | line = fin.readline() 17 | if not line: 18 | break 19 | if line.startswith('src:'): 20 | example = ' '.join(line.rstrip().split()[1:]) 21 | while True: 22 | tmp = fin.readline() 23 | if not tmp: 24 | raise ValueError 25 | 26 | if '-top K templates-' in tmp: 27 | prototype = fin.readline() 28 | prototype = prototype.rstrip().split('\t')[2] 29 | break 30 | fout.write('{}\t{}\n'.format(prototype, example)) 31 | 32 | fout.close() 33 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # 3 | # train.sh 4 | # Copyright (C) 2020-02-09 Junxian 5 | # 6 | # Distributed under terms of the MIT license. 7 | # 8 | 9 | DATE=`date +%Y%m%d` 10 | 11 | # general configuration 12 | data_bin="coco40k" 13 | emb_type="bert" 14 | kappa=30 15 | copy=0 # whether to have copy mechanism in the decoder 16 | inv_editor='levenshtein' 17 | edit_embed_dim=10 18 | retrieve_split=valid1 19 | criterion="sp_elbo" 20 | 21 | 22 | # optimization hyperparameters 23 | warmup_init_lr='1e-03' 24 | lr=0.001 25 | max_update=30000 26 | retriever=precompute_emb 27 | linear_bias=0 28 | weight_decay=0.0001 29 | stop_bert_grad=1 30 | freeze_retriever=0 31 | reinforce=1 32 | opt=adam 33 | update_freq=1 34 | 35 | # hyperparameters for lambda update 36 | forget_rate=0.8 37 | decay_rate=1 38 | 39 | # some important hyperparameters that we may need to change 40 | rescale_factor=0.3 # rescaling factor for reinforce sampling 41 | free_bits=0 # KL free bits for mitigating posterior collapse 42 | alpha=0.1 # Dirichlet hyperparameters 43 | lambda_config="0:0,1500:1" # KL weights annealing schedule 44 | GPU=0 45 | 46 | 47 | # evaluation parameters, only used during evaluationa after training 48 | eval_mode="none" # perform training by default 49 | prune_num="-1" 50 | valid_subset="valid" # use "valid" to test on valid set 51 | iw_nsamples=100 52 | 53 | while getopts ":g:a:p:k:r:f:c:u:e:d:" arg; do 54 | case $arg in 55 | g) GPU="$OPTARG" 56 | ;; 57 | a) alpha="$OPTARG" 58 | ;; 59 | p) LOADDIR="$OPTARG" 60 | ;; 61 | k) kappa="$OPTARG" 62 | ;; 63 | r) rescale_factor="$OPTARG" 64 | ;; 65 | f) free_bits="$OPTARG" 66 | ;; 67 | c) criterion="$OPTARG" 68 | ;; 69 | u) prune_num="$OPTARG" 70 | ;; 71 | e) eval_mode="$OPTARG" 72 | ;; 73 | d) data_bin="$OPTARG" 74 | ;; 75 | \?) 
echo "Invalid option -$OPTARG" >&2 76 | ;; 77 | esac 78 | done 79 | 80 | glove_path=glove_embeddings/${data_bin}_glove.txt 81 | emb_dataset_file=precompute_embedding_datasets/${data_bin}/${data_bin}.${emb_type} 82 | 83 | 84 | if [ "$eval_mode" = "none" ]; 85 | then 86 | stdout="stdout.log" 87 | eval_params="--eval-mode none" 88 | else 89 | stdout="eval_${valid_subset}_${eval_mode}_prune${prune_num}.log" 90 | eval_params="--eval-mode ${eval_mode} --iw-nsamples ${iw_nsamples} \ 91 | --valid-subset ${valid_subset} --prune-num ${prune_num} \ 92 | --reset-meters --reset-optimizer --write-loss-path loss_per_sent_${valid_subset}.txt" 93 | fi 94 | 95 | if [ "$criterion" = "topk_elbo" ]; 96 | then 97 | reinforce=0 98 | fi 99 | 100 | if [ "$data_bin" = "coco40k" ]; 101 | then 102 | max_tokens=2048 103 | save_interval_updates=0 104 | warmup_updates=1000 105 | max_update=15000 106 | log_interval=20 107 | validate_interval=1 108 | ns=10 109 | elif [ "$data_bin" = "yelp" ]; 110 | then 111 | max_tokens=2048 112 | save_interval_updates=5000 113 | warmup_updates=20000 114 | max_update=300000 115 | lambda_config="0:0,20000:1" 116 | log_interval=100 117 | validate_interval=1000 118 | ns=10 119 | elif [ "$data_bin" = "yelp_large" ]; 120 | then 121 | max_tokens=2048 # distributed on two gpus 122 | save_interval_updates=5000 123 | warmup_updates=150000 124 | max_update=500000 125 | kappa=40 126 | lambda_config="0:0,150000:1" 127 | log_interval=100 128 | validate_interval=1000 129 | ns=10 130 | else 131 | max_tokens=0 132 | save_interval_updates=0 133 | warmup_updates=0 134 | ns=0 135 | fi 136 | 137 | 138 | if [ "$criterion" = "lm_baseline" ]; 139 | then 140 | ns=1 141 | fi 142 | 143 | GPUSTR=$(printf "$GPU" | tr , _) 144 | 145 | lambda_conifg_str=$(printf "$lambda_config" | tr , _) 146 | lambda_conifg_str=$(printf "$lambda_conifg_str" | tr : _) 147 | 148 | if [[ -v LOADDIR && eval_mode = "none" ]]; 149 | then 150 | add_load_string="" 151 | cstring="_continue" 152 | restore_file=checkpoint_load.pt 153 | else 154 | add_load_string="" 155 | cstring="" 156 | restore_file=null.pt 157 | fi 158 | 159 | 160 | if [[ -v LOADDIR && eval_mode != "none" ]]; 161 | then 162 | SAVE=${LOADDIR} 163 | TENSORBOARD=${SAVE}/tensorboard 164 | else 165 | SAVE_ROOT="checkpoint/${data_bin}/${DATE}/${data_bin}_${opt}_noeditvec_alpha${alpha}_kappa${kappa}" 166 | SAVE_ROOT="${SAVE_ROOT}_ns${ns}" 167 | SAVE_ROOT="${SAVE_ROOT}_editdim${edit_embed_dim}" 168 | SAVE_ROOT="${SAVE_ROOT}_rtr${retriever}_fr${forget_rate}_dr${decay_rate}_rf${rescale_factor}_fb${free_bits}" 169 | SAVE_ROOT="${SAVE_ROOT}_embt${emb_type}_lc${lambda_conifg_str}_uf${update_freq}_gpu${GPUSTR}_c${criterion}${cstring}" 170 | 171 | SAVE=${SAVE_ROOT} 172 | TENSORBOARD=${SAVE}/tensorboard 173 | 174 | rm -r ${SAVE}; mkdir -p ${SAVE} ${TENSORBOARD} 175 | fi 176 | 177 | if [[ -v LOADDIR ]]; 178 | then 179 | cp ${LOADDIR}/checkpoint_best.pt ${SAVE}/checkpoint_load.pt 180 | fi 181 | 182 | CUDA_VISIBLE_DEVICES=${GPU} python train.py \ 183 | data-bin/${data_bin} \ 184 | --arch ${data_bin} --task sparse_prototype \ 185 | --optimizer ${opt} --adam-betas '(0.9, 0.98)' \ 186 | --lr ${lr} --lr-scheduler inverse_sqrt --warmup-updates ${warmup_updates} \ 187 | --warmup-init-lr ${warmup_init_lr} \ 188 | --weight-decay ${weight_decay} \ 189 | --edit-embed-dim ${edit_embed_dim} --embed-init-rescale ${rescale_factor} \ 190 | --free-bits ${free_bits} --lambda-t-config ${lambda_config} \ 191 | --emb-dataset-file ${emb_dataset_file} \ 192 | --reinforce ${reinforce} --infer-ns ${ns} \ 193 | 
--freeze-retriever ${freeze_retriever} --decoder-copy ${copy} \ 194 | --inveditor-embed-path ${glove_path} --encoder-embed-path ${glove_path} --decoder-embed-path ${glove_path} \ 195 | --encoder-embed-dim 300 --decoder-embed-dim 300 \ 196 | --user-dir sparse_prototype \ 197 | --forget-rate ${forget_rate} --decay-rate ${decay_rate} --retrieve-split ${retrieve_split} --alpha ${alpha} --vmf-kappa ${kappa} \ 198 | --linear-bias ${linear_bias} --stop-bert-grad ${stop_bert_grad} \ 199 | --criterion ${criterion} --label-smoothing 0. --num-workers 0 \ 200 | --max-tokens ${max_tokens} \ 201 | --log-format simple --log-interval ${log_interval} \ 202 | --retriever ${retriever} --inv-editor ${inv_editor} \ 203 | --max-update ${max_update} --update-freq ${update_freq} \ 204 | --validate-interval ${validate_interval} --best-checkpoint-metric ppl --no-epoch-checkpoints \ 205 | --no-last-checkpoints \ 206 | --save-interval-updates ${save_interval_updates} --keep-interval-updates 1 \ 207 | --save-dir ${SAVE} --tensorboard-logdir ${TENSORBOARD} \ 208 | ${add_load_string} --restore-file ${SAVE}/checkpoint_load.pt \ 209 | ${eval_params} \ 210 | | tee -a ${SAVE}/${stdout} 211 | -------------------------------------------------------------------------------- /scripts/train.slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=slurm_out/slurm-%A_%a.out 3 | #SBATCH --error=slurm_out/slurm-%A_%a.err 4 | #SBATCH --array=0-0%1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --mem=20g 7 | #SBATCH --cpus-per-task=10 8 | ##SBATCH --exclude=compute-0-31,compute-0-19,compute-0-15 9 | ##SBATCH —nodelist=compute-0-31,compute-0-30 10 | #SBATCH -t 0 11 | 12 | 13 | alpha=0.1 14 | free_bits=5 15 | 16 | bash scripts/train.sh -g 0 -a ${alpha} -f -------------------------------------------------------------------------------- /sparse_prototype/__init__.py: -------------------------------------------------------------------------------- 1 | from . import sp_criterion 2 | from . import sp_model 3 | from . import sp_task 4 | from . import guu_criterion 5 | from . import lm_criterion 6 | from . import topk_criterion -------------------------------------------------------------------------------- /sparse_prototype/distribution/__init__.py: -------------------------------------------------------------------------------- 1 | from .vmf_batch import * -------------------------------------------------------------------------------- /sparse_prototype/distribution/vmf_batch.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is upon some modification based on 3 | https://github.com/jiacheng-xu/vmf_vae_nlp 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | from scipy import special as sp 9 | 10 | 11 | class vMF(torch.nn.Module): 12 | def __init__(self, hid_dim, lat_dim, kappa=1, cuda=True): 13 | """ 14 | von Mises-Fisher distribution class with batch support and manual tuning kappa value. 15 | Implementation follows description of my paper and Guu's. 
16 | """ 17 | 18 | super().__init__() 19 | self.hid_dim = hid_dim 20 | self.lat_dim = lat_dim 21 | self.kappa = kappa 22 | self.device = torch.device('cuda' if cuda else 'cpu') 23 | # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim) 24 | self.func_mu = torch.nn.Linear(hid_dim, lat_dim) 25 | 26 | self.kld = torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float().to(self.device) 27 | self.log_norm = vMF._log_normalization_constant(kappa, lat_dim) 28 | 29 | # set to 1e-10 to avoid numeric error 30 | self.log_norm_uniform = vMF._log_normalization_constant(0, lat_dim) 31 | 32 | print('KLD: {}'.format(self.kld.data[0])) 33 | 34 | def estimate_param(self, latent_code): 35 | ret_dict = {} 36 | ret_dict['kappa'] = self.kappa 37 | 38 | # Only compute mu, use mu/mu_norm as mu, 39 | # use 1 as norm, use diff(mu_norm, 1) as redundant_norm 40 | mu = self.func_mu(latent_code) 41 | 42 | norm = torch.norm(mu, 2, 1, keepdim=True) 43 | mu_norm_sq_diff_from_one = torch.pow(torch.add(norm, -1), 2) 44 | redundant_norm = torch.sum(mu_norm_sq_diff_from_one, dim=1, keepdim=True) 45 | ret_dict['norm'] = torch.ones_like(mu) 46 | ret_dict['redundant_norm'] = redundant_norm 47 | 48 | mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True) 49 | ret_dict['mu'] = mu 50 | 51 | return ret_dict 52 | 53 | def compute_KLD(self, tup, batch_sz): 54 | return self.kld.expand(batch_sz) 55 | 56 | @staticmethod 57 | def _vmf_kld(k, d): 58 | tmp = (k * ((sp.iv(d / 2.0 + 1.0, k) + sp.iv(d / 2.0, k) * d / (2.0 * k)) / sp.iv(d / 2.0, k) - d / (2.0 * k)) \ 59 | + d * np.log(k) / 2.0 - np.log(sp.iv(d / 2.0, k)) \ 60 | - sp.loggamma(d / 2 + 1) - d * np.log(2) / 2).real 61 | if tmp != tmp: 62 | exit() 63 | return np.array([tmp]) 64 | 65 | @staticmethod 66 | def _vmf_kld_davidson(k, d): 67 | """ 68 | This should be the correct KLD. 69 | Empirically we find that _vmf_kld (as in the Guu paper) only deviates a little (<2%) in most cases we use. 
70 | """ 71 | tmp = k * sp.iv(d / 2, k) / sp.iv(d / 2 - 1, k) + (d / 2 - 1) * torch.log(k) - torch.log( 72 | sp.iv(d / 2 - 1, k)) + np.log(np.pi) * d / 2 + np.log(2) - sp.loggamma(d / 2).real - (d / 2) * np.log( 73 | 2 * np.pi) 74 | if tmp != tmp: 75 | exit() 76 | return np.array([tmp]) 77 | 78 | @staticmethod 79 | def _log_normalization_constant(k, d): 80 | # When k=0, we go to the limit case: C = gamma(d/2 + 1) / (d * pi*{d/2}) 81 | # reference: https://en.wikipedia.org/wiki/N-sphere 82 | if k == 0: 83 | tmp = sp.loggamma(d / 2.0 + 1) - np.log(d) - d / 2.0 * np.log(np.pi) 84 | return tmp 85 | 86 | tmp = (d / 2.0 - 1) * np.log(k) - d / 2.0 * np.log(2 * np.pi) - np.log(sp.iv(d / 2.0 - 1.0, k)) 87 | 88 | if tmp != tmp: 89 | exit() 90 | return tmp 91 | 92 | def log_density(self, k, x, mu=None): 93 | """compute log density under vmf distribution 94 | Args: 95 | mu: tensor with shape (batch, *, latent_dim) 96 | x: tensor with shape (batch, *, latent_dim) 97 | 98 | Returns: 99 | density: tensor with shape (batch, *) 100 | 101 | """ 102 | 103 | if k == 0: 104 | return self.log_norm_uniform + k * (x.sum(dim=-1)) 105 | 106 | return self.log_norm + k * (mu * x).sum(dim=-1) 107 | 108 | 109 | def build_bow_rep(self, lat_code, n_sample): 110 | batch_sz = lat_code.size()[0] 111 | tup = self.estimate_param(latent_code=lat_code) 112 | mu = tup['mu'] 113 | norm = tup['norm'] 114 | kappa = tup['kappa'] 115 | 116 | kld = self.compute_KLD(tup, batch_sz) 117 | vecs = [] 118 | if n_sample == 1: 119 | return tup, kld, self.sample_cell(mu, norm, kappa) 120 | for n in range(n_sample): 121 | sample = self.sample_cell(mu, norm, kappa) 122 | vecs.append(sample) 123 | vecs = torch.cat(vecs, dim=1) 124 | return tup, kld, vecs 125 | 126 | def sample_cell(self, mu, norm, kappa): 127 | batch_sz, lat_dim = mu.size() 128 | # mu = GVar(mu) 129 | mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True) 130 | w = self._sample_weight_batch(kappa, lat_dim, batch_sz) 131 | w = w.unsqueeze(1) 132 | 133 | # batch version 134 | w_var = (w * torch.ones(batch_sz, lat_dim)).to(self.device) 135 | v = self._sample_ortho_batch(mu, lat_dim) 136 | scale_factr = torch.sqrt( 137 | torch.ones(batch_sz, lat_dim, device=self.device) - torch.pow(w_var, 2)) 138 | orth_term = v * scale_factr 139 | muscale = mu * w_var 140 | sampled_vec = orth_term + muscale 141 | 142 | return sampled_vec.unsqueeze(1) 143 | 144 | def _sample_weight_batch(self, kappa, dim, batch_sz=1): 145 | # result = torch.FloatTensor((batch_sz)) 146 | result = [0] * batch_sz 147 | for b in range(batch_sz): 148 | result[b] = self._sample_weight(kappa, dim) 149 | return torch.FloatTensor(result) 150 | 151 | def _sample_weight(self, kappa, dim): 152 | """Rejection sampling scheme for sampling distance from center on 153 | surface of the sphere. 154 | """ 155 | dim = dim - 1 # since S^{n-1} 156 | b = dim / (np.sqrt(4. * kappa ** 2 + dim ** 2) + 2 * kappa) # b= 1/(sqrt(4.* kdiv**2 + 1) + 2 * kdiv) 157 | x = (1. - b) / (1. + b) 158 | c = kappa * x + dim * np.log(1 - x ** 2) # dim * (kdiv *x + np.log(1-x**2)) 159 | 160 | while True: 161 | z = np.random.beta(dim / 2., dim / 2.) # concentrates towards 0.5 as d-> inf 162 | w = (1. - (1. + b) * z) / (1. - (1. - b) * z) 163 | u = np.random.uniform(low=0, high=1) 164 | if kappa * w + dim * np.log(1. - x * w) - c >= np.log( 165 | u): # thresh is dim *(kdiv * (w-x) + log(1-x*w) -log(1-x**2)) 166 | return w 167 | 168 | def _sample_ortho_batch(self, mu, dim): 169 | """ 170 | 171 | :param mu: Variable, [batch size, latent dim] 172 | :param dim: scala. 
=latent dim 173 | :return: 174 | """ 175 | _batch_sz, _lat_dim = mu.size() 176 | assert _lat_dim == dim 177 | squeezed_mu = mu.unsqueeze(1) 178 | 179 | v = torch.randn(_batch_sz, dim, 1, device=self.device) # TODO random 180 | 181 | # v = GVar(torch.linspace(-1, 1, steps=dim)) 182 | # v = v.expand(_batch_sz, dim).unsqueeze(2) 183 | 184 | rescale_val = torch.bmm(squeezed_mu, v).squeeze(2) 185 | proj_mu_v = mu * rescale_val 186 | ortho = v.squeeze() - proj_mu_v 187 | ortho_norm = torch.norm(ortho, p=2, dim=1, keepdim=True) 188 | y = ortho / ortho_norm 189 | return y 190 | 191 | def _sample_orthonormal_to(self, mu, dim): 192 | """Sample point on sphere orthogonal to mu. 193 | """ 194 | v = torch.randn(dim, device=self.device) # TODO random 195 | 196 | # v = GVar(torch.linspace(-1,1,steps=dim)) 197 | 198 | rescale_value = mu.dot(v) / mu.norm() 199 | proj_mu_v = mu * rescale_value.expand(dim) 200 | ortho = v - proj_mu_v 201 | ortho_norm = torch.norm(ortho) 202 | return ortho / ortho_norm.expand_as(ortho) 203 | 204 | # vmf = vMF_fast(50, 100, 100) 205 | # batchsz = 100 206 | # 207 | # mu = torch.FloatTensor(np.random.uniform(0, 1, 20 * batchsz)) 208 | # mu = mu.view(batchsz, -1) 209 | # mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True) 210 | # vmf.sample_cell(mu, None, 100) 211 | # x = vMF(10,lat_dim=50,kappa=50) 212 | -------------------------------------------------------------------------------- /sparse_prototype/guu_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from fairseq import metrics, utils 12 | from fairseq.criterions import LegacyFairseqCriterion, FairseqCriterion, register_criterion 13 | 14 | 15 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): 16 | """compute labeled smoothed nll loss 17 | Returns: 18 | loss: the actual loss to be optimized (after smoothing), with 19 | shape (batch) if reduce is true else (batch, seq_len) 20 | nll_loss: the NLL loss with shape (batch) if reduce is true else 21 | (batch, seq_len) 22 | """ 23 | if target.dim() == lprobs.dim() - 1: 24 | target = target.unsqueeze(-1) 25 | nll_loss = -lprobs.gather(dim=-1, index=target) 26 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 27 | if ignore_index is not None: 28 | pad_mask = target.eq(ignore_index) 29 | if pad_mask.any(): 30 | nll_loss.masked_fill_(pad_mask, 0.) 31 | smooth_loss.masked_fill_(pad_mask, 0.) 32 | 33 | nll_loss = nll_loss.squeeze(-1) 34 | smooth_loss = smooth_loss.squeeze(-1) 35 | 36 | # (batch, seq_len) --> (batch) 37 | if reduce: 38 | nll_loss = nll_loss.sum(-1) 39 | smooth_loss = smooth_loss.sum(-1) 40 | eps_i = epsilon / lprobs.size(-1) 41 | loss = (1. 
- epsilon) * nll_loss + eps_i * smooth_loss 42 | return loss, nll_loss 43 | 44 | 45 | @register_criterion('guu_elbo') 46 | class GuuELBO(LegacyFairseqCriterion): 47 | 48 | def __init__(self, args, task): 49 | super().__init__(args, task) 50 | self.eps = args.label_smoothing 51 | 52 | @staticmethod 53 | def add_args(parser): 54 | """Add criterion-specific arguments to the parser.""" 55 | # fmt: off 56 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 57 | help='epsilon for label smoothing, 0 means no label smoothing') 58 | # fmt: on 59 | 60 | def forward(self, model, sample, data_len, reduce=True): 61 | """Compute the loss for the given sample. 62 | 63 | Returns a tuple with three elements: 64 | 1) the loss 65 | 2) the sample size, which is used as the denominator for the gradient 66 | 3) logging outputs to display while training 67 | """ 68 | net_output = model.guu_forward(**sample['net_input'], data_len=data_len) 69 | loss, neg_elbo, recon_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 70 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 71 | 72 | nsentences = sample['target'].size(0) 73 | logging_output = { 74 | 'loss': utils.item(loss.data) if reduce else loss.data, 75 | 'neg_elbo': utils.item(neg_elbo.data) if reduce else neg_elbo.data, 76 | 'recon_loss': utils.item(recon_loss.data) if reduce else recon_loss.data, 77 | 'ntokens': sample['ntokens'] / model.infer_ns, 78 | 'nsentences': sample['target'].size(0) / model.infer_ns, 79 | 'sample_size': sample_size / model.infer_ns, 80 | } 81 | 82 | return loss, sample_size, logging_output 83 | 84 | # compute the ELBO loss, involving reinforcement learning 85 | def compute_loss(self, model, net_output, sample, reduce=True): 86 | lprobs = model.get_normalized_probs(net_output['recon_out'], log_probs=True) 87 | # lprobs = lprobs.view(-1, lprobs.size(-1)) 88 | target = model.get_targets(sample, net_output) 89 | smoothed_nll_loss, nll_loss = label_smoothed_nll_loss( 90 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 91 | ) 92 | 93 | revert_order = sample['net_input']['revert_order'] 94 | 95 | nll_loss = nll_loss.index_select(0, revert_order) 96 | smoothed_nll_loss = smoothed_nll_loss.index_select(0, revert_order) 97 | 98 | 99 | nsentences = sample['target'].size(0) / model.infer_ns 100 | 101 | loss = smoothed_nll_loss.view(-1, model.infer_ns).mean(1).sum() 102 | 103 | with torch.no_grad(): 104 | neg_elbo = nll_loss.view(-1, model.infer_ns).mean(1).sum() + \ 105 | math.log(model.num_prototypes / model.infer_ns) * nsentences 106 | 107 | return loss, neg_elbo, nll_loss.view(-1, model.infer_ns).mean(1).sum() 108 | 109 | def iw_eval_new(self, model, sample, data_len, iw_nsample, retrieve_dataset, reduce=True): 110 | """Compute the importance-weighted loss for the given sample. 
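# Aside (illustrative, not part of the original file): shape behaviour of
# label_smoothed_nll_loss defined at the top of this file. With reduce=True both
# returned losses are summed over the time dimension, giving one value per sentence:
#   lprobs = torch.log_softmax(torch.randn(2, 5, 10), dim=-1)  # (batch, seq_len, vocab)
#   target = torch.randint(0, 10, (2, 5))                      # (batch, seq_len)
#   loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1,
#                                            ignore_index=1, reduce=True)
#   loss.shape, nll_loss.shape   # both torch.Size([2])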
111 | Returns a tuple with three elements: 112 | 1) the loss 113 | 2) the sample size, which is used as the denominator for the gradient 114 | 3) logging outputs to display while training 115 | """ 116 | 117 | net_output = model.guu_forward(**sample['net_input'], data_len=data_len) 118 | nll_iw = self.compute_loss_iw(model, net_output, sample, reduce=reduce) 119 | 120 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 121 | nsentences = sample['target'].size(0) / model.infer_ns 122 | 123 | logging_output = { 124 | 'nll_iw': utils.item(nll_iw.data) if reduce else nll_iw.data, 125 | 'ntokens': sample['ntokens'] / model.infer_ns, 126 | 'nsentences': nsentences, 127 | 'sample_size': sample_size / model.infer_ns, 128 | } 129 | 130 | return nll_iw, sample_size, logging_output 131 | 132 | def compute_loss_iw(self, model, net_output, sample, reduce=True): 133 | lprobs = model.get_normalized_probs(net_output['recon_out'], log_probs=True) 134 | # lprobs = lprobs.view(-1, lprobs.size(-1)) 135 | target = model.get_targets(sample, net_output) 136 | smoothed_nll_loss, nll_loss = label_smoothed_nll_loss( 137 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 138 | ) 139 | 140 | revert_order = sample['net_input']['revert_order'] 141 | 142 | nll_loss = nll_loss.index_select(0, revert_order) 143 | 144 | 145 | nsentences = sample['target'].size(0) / model.infer_ns 146 | 147 | nll_iw = (torch.logsumexp(-nll_loss.view(-1, model.infer_ns), dim=1) + \ 148 | math.log(1.0 / model.num_class)).sum() 149 | 150 | nll_iw = -nll_iw 151 | 152 | return nll_iw 153 | 154 | @staticmethod 155 | def reduce_metrics(logging_outputs) -> None: 156 | """Aggregate logging outputs from data parallel training.""" 157 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 158 | neg_elbo_sum = sum(log.get('neg_elbo', 0) for log in logging_outputs) 159 | recon_loss_sum = sum(log.get('recon_loss', 0) for log in logging_outputs) 160 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 161 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 162 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 163 | 164 | if 'nll_iw' in logging_outputs[0]: 165 | nll_iw_sum = sum(log.get('nll_iw', 0) for log in logging_outputs) 166 | metrics.log_scalar('nll_iw_s', nll_iw_sum / nsentences, 167 | nsentences, round=3, priority=4) 168 | metrics.log_scalar('nll_iw_t', nll_iw_sum / ntokens / math.log(2), 169 | ntokens, round=3, priority=5) 170 | metrics.log_derived('ppl_iw', lambda meters: utils.get_perplexity(meters['nll_iw_t'].avg), priority=6) 171 | else: 172 | metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), 173 | sample_size, round=3, priority=3) 174 | 175 | metrics.log_scalar('neg_elbo_s', neg_elbo_sum / nsentences, 176 | nsentences, round=3, priority=4) 177 | metrics.log_scalar('recon_loss_s', recon_loss_sum / nsentences, 178 | nsentences, round=3, priority=4) 179 | 180 | metrics.log_scalar('neg_elbo_t', neg_elbo_sum / ntokens / math.log(2), 181 | ntokens, round=3, priority=5) 182 | metrics.log_scalar('recon_loss_t', recon_loss_sum / ntokens / math.log(2), 183 | ntokens, round=3, priority=5) 184 | 185 | metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['neg_elbo_t'].avg), priority=6) 186 | metrics.log_derived('recon_ppl', lambda meters: utils.get_perplexity(meters['recon_loss_t'].avg), priority=7) 187 | 188 | @staticmethod 189 | def logging_outputs_can_be_summed() -> bool: 190 | """ 191 | Whether 
the logging outputs returned by `forward` can be summed 192 | across workers prior to calling `reduce_metrics`. Setting this 193 | to True will improves distributed training speed. 194 | """ 195 | return True 196 | -------------------------------------------------------------------------------- /sparse_prototype/inv_editor/__init__.py: -------------------------------------------------------------------------------- 1 | from .inv_editor_guu import * 2 | from .inv_editor_levenshtein import * -------------------------------------------------------------------------------- /sparse_prototype/inv_editor/inv_editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def Embedding(num_embeddings, embedding_dim, padding_idx): 6 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 7 | nn.init.uniform_(m.weight, -0.1, 0.1) 8 | nn.init.constant_(m.weight[padding_idx], 0) 9 | return m 10 | 11 | class InvEditorBase(nn.Module): 12 | """Base class for Inverse Editor p(z|t, x)""" 13 | def __init__(self, embed_dim): 14 | super(InvEditorBase, self).__init__() 15 | self.embed_dim = embed_dim 16 | 17 | def forward(self, src_tokens, temp_tokens, **kwargs): 18 | """ 19 | Args: 20 | src_tokens (LongTensor): (batch, seq_len) 21 | temp_tokens (LongTensor): (batch, seq_len) 22 | 23 | Returns: Tensor1 24 | Tensor1: the representation with shape [batch, embed_dim] 25 | """ 26 | 27 | raise NotImplementedError -------------------------------------------------------------------------------- /sparse_prototype/inv_editor/inv_editor_guu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .inv_editor import InvEditorBase, Embedding 5 | 6 | 7 | 8 | class GuuInvEditor(InvEditorBase): 9 | """the inverse editor from https://arxiv.org/abs/1709.08878""" 10 | def __init__(self, embed_dim, dictionary, pretrained_embed=None, cuda=True): 11 | super(GuuInvEditor, self).__init__(embed_dim) 12 | 13 | self.embed_dim = embed_dim 14 | self.padding_idx = dictionary.pad() 15 | num_embeddings = len(dictionary) 16 | self.device = torch.device('cuda' if cuda else 'cpu') 17 | if pretrained_embed is None: 18 | self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) 19 | else: 20 | self.embed_tokens = pretrained_embed 21 | 22 | 23 | def forward(self, src_tokens, temp_tokens, **kwargs): 24 | """ 25 | Args: 26 | src_tokens (LongTensor): (batch, seq_len) 27 | temp_tokens (LongTensor): (batch, seq_len) 28 | 29 | Returns: Tensor1 30 | Tensor1: the representation with shape [batch, embed_dim] 31 | """ 32 | 33 | res = [] 34 | 35 | for src_tokens_, temp_tokens_ in zip(src_tokens, temp_tokens): 36 | src_token_list, temp_token_list = src_tokens_.tolist(), temp_tokens_.tolist() 37 | delete_words = set(src_token_list) - set(temp_token_list) - set([self.padding_idx]) 38 | insert_words = set(temp_token_list) - set(src_token_list) - set([self.padding_idx]) 39 | 40 | delete_words_t = torch.tensor(list(delete_words), dtype=torch.long, device=self.device) 41 | insert_words_t = torch.tensor(list(insert_words), dtype=torch.long, device=self.device) 42 | 43 | res.append(torch.cat((self.embed_tokens(delete_words_t).sum(0), 44 | self.embed_tokens(insert_words_t).sum(0)), dim=0).unsqueeze(0)) 45 | 46 | return torch.cat(res, dim=0) 47 | 48 | @property 49 | def output_units(self): 50 | return 2 * self.embed_dim 51 | 
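# Illustrative sketch (not part of the repository): the set-difference edit features
# computed by GuuInvEditor.forward above, traced on toy token ids (pad = 1):
src_token_list  = [4, 7, 9]    # e.g. "the cat sat"
temp_token_list = [4, 7, 12]   # e.g. "the cat slept"
delete_words = set(src_token_list) - set(temp_token_list) - {1}   # {9}
insert_words = set(temp_token_list) - set(src_token_list) - {1}   # {12}
# Each set is embedded and summed, and the two sums are concatenated, so every
# (sentence, prototype) pair is represented by a single 2 * embed_dim vector.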
-------------------------------------------------------------------------------- /sparse_prototype/inv_editor/inv_editor_levenshtein.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .inv_editor import InvEditorBase, Embedding 5 | 6 | 7 | 8 | class LevenshteinInvEditor(InvEditorBase): 9 | """the inverse editor from https://arxiv.org/abs/1709.08878""" 10 | def __init__(self, token_embed_dim, edit_embed_dim, hidden_size, 11 | tgt_dict, edit_dict, num_layers=1, pretrained_token_embed=None): 12 | super(LevenshteinInvEditor, self).__init__(hidden_size) 13 | 14 | 15 | self.hidden_size = hidden_size 16 | self.padding_idx = tgt_dict.pad() 17 | num_token_embeddings = len(tgt_dict) 18 | num_edit_embeddings = len(edit_dict) 19 | 20 | if pretrained_token_embed is None: 21 | self.embed_tokens = Embedding(num_token_embeddings, token_embed_dim, self.padding_idx) 22 | else: 23 | self.embed_tokens = pretrained_token_embed 24 | 25 | self.embed_edit = Embedding(num_edit_embeddings, edit_embed_dim, self.padding_idx) 26 | self.num_layers=num_layers 27 | 28 | self.lstm = nn.LSTM( 29 | input_size=token_embed_dim * 2 + edit_embed_dim, 30 | hidden_size=hidden_size, 31 | num_layers=self.num_layers, 32 | bidirectional=True, 33 | ) 34 | 35 | 36 | def forward(self, src_aligned, tgt_aligned, edit_aligned, aligned_length, **kwargs): 37 | """ 38 | Args: 39 | src_aligned (LongTensor): (batch, seq_len) 40 | tgt_aligned (LongTensor): (batch, seq_len) 41 | 42 | Returns: Tensor1 43 | Tensor1: the representation with shape [batch, embed_dim] 44 | """ 45 | 46 | bsz, seqlen = src_aligned.size() 47 | 48 | edit_embed = self.embed_edit(edit_aligned) 49 | src_embed = self.embed_tokens(src_aligned) 50 | tgt_embed = self.embed_tokens(tgt_aligned) 51 | 52 | x = torch.cat((edit_embed, src_embed, tgt_embed), -1) 53 | 54 | # B x T x C -> T x B x C 55 | x = x.transpose(0, 1) 56 | 57 | packed_x = nn.utils.rnn.pack_padded_sequence(x, aligned_length.data.tolist(), enforce_sorted=False) 58 | state_size = 2 * self.num_layers, bsz, self.hidden_size 59 | 60 | h0 = x.new_zeros(*state_size) 61 | c0 = x.new_zeros(*state_size) 62 | 63 | packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) 64 | x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_idx) 65 | 66 | def combine_bidir(outs): 67 | out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() 68 | return out.view(self.num_layers, bsz, -1) 69 | 70 | 71 | return combine_bidir(final_hiddens)[-1] 72 | 73 | @property 74 | def output_units(self): 75 | return 2 * self.hidden_size -------------------------------------------------------------------------------- /sparse_prototype/language_pair_map_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
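# Aside (illustrative, not part of the repository): the LevenshteinInvEditor in the
# previous file consumes three aligned (batch, seq_len) tensors that line up
# token-for-token; they are built later by RetrievePrototypeDataset from an edlib
# edit path. For prototype "the cat sat" and target "the cat slept quietly" the
# aligned triple would look like
#   src_aligned : the  cat  sat    <unk>
#   tgt_aligned : the  cat  slept  quietly
#   edit_aligned:  =    =    X     I
# ('I' puts <unk> on the prototype side, 'D' would put <unk> on the target side).
# The per-position embeddings of the three sequences are concatenated and fed to a
# bidirectional LSTM whose final hidden state (2 * hidden_size) is the edit vector.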
5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from fairseq import utils 10 | 11 | from fairseq.data import ( 12 | data_utils, 13 | FairseqDataset, 14 | Dictionary, 15 | LanguagePairDataset, 16 | ) 17 | 18 | def collate( 19 | samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, 20 | input_feeding=True, 21 | ): 22 | if len(samples) == 0: 23 | return {} 24 | 25 | def merge(key, left_pad, move_eos_to_beginning=False): 26 | return data_utils.collate_tokens( 27 | [s[key] for s in samples], 28 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 29 | ) 30 | 31 | def check_alignment(alignment, src_len, tgt_len): 32 | if alignment is None or len(alignment) == 0: 33 | return False 34 | if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1: 35 | logger.warning("alignment size mismatch found, skipping alignment!") 36 | return False 37 | return True 38 | 39 | def compute_alignment_weights(alignments): 40 | """ 41 | Given a tensor of shape [:, 2] containing the source-target indices 42 | corresponding to the alignments, a weight vector containing the 43 | inverse frequency of each target index is computed. 44 | For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then 45 | a tensor containing [1., 0.5, 0.5, 1] should be returned (since target 46 | index 3 is repeated twice) 47 | """ 48 | align_tgt = alignments[:, 1] 49 | _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True) 50 | align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] 51 | return 1. / align_weights.float() 52 | 53 | id = torch.LongTensor([s['id'] for s in samples]) 54 | src_tokens = merge('source', left_pad=left_pad_source) 55 | # sort by descending source length 56 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 57 | src_lengths, sort_order = src_lengths.sort(descending=True) 58 | id = id.index_select(0, sort_order) 59 | src_tokens = src_tokens.index_select(0, sort_order) 60 | 61 | prev_output_tokens = None 62 | target = None 63 | if samples[0].get('target', None) is not None: 64 | target = merge('target', left_pad=left_pad_target) 65 | target = target.index_select(0, sort_order) 66 | tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order) 67 | ntokens = sum(len(s['target']) for s in samples) 68 | 69 | if input_feeding: 70 | # we create a shifted version of targets for feeding the 71 | # previous output token(s) into the next decoder step 72 | prev_output_tokens = merge( 73 | 'target', 74 | left_pad=left_pad_target, 75 | move_eos_to_beginning=True, 76 | ) 77 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 78 | else: 79 | ntokens = sum(len(s['source']) for s in samples) 80 | 81 | batch = { 82 | 'id': id, 83 | 'nsentences': len(samples), 84 | 'ntokens': ntokens, 85 | 'net_input': { 86 | 'src_tokens': src_tokens, 87 | 'src_lengths': src_lengths, 88 | }, 89 | 'target': target, 90 | } 91 | if prev_output_tokens is not None: 92 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 93 | 94 | if samples[0].get('alignment', None) is not None: 95 | bsz, tgt_sz = batch['target'].shape 96 | src_sz = batch['net_input']['src_tokens'].shape[1] 97 | 98 | offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) 99 | offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz) 100 | if left_pad_source: 101 | offsets[:, 0] += (src_sz - src_lengths) 102 | if left_pad_target: 103 | offsets[:, 1] += 
(tgt_sz - tgt_lengths) 104 | 105 | alignments = [ 106 | alignment + offset 107 | for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) 108 | for alignment in [samples[align_idx]['alignment'].view(-1, 2)] 109 | if check_alignment(alignment, src_len, tgt_len) 110 | ] 111 | 112 | if len(alignments) > 0: 113 | alignments = torch.cat(alignments, dim=0) 114 | align_weights = compute_alignment_weights(alignments) 115 | 116 | batch['alignments'] = alignments 117 | batch['align_weights'] = align_weights 118 | 119 | return batch 120 | 121 | 122 | class LanguagePairMapDataset(LanguagePairDataset): 123 | """A slight addon to LanguagePairDataset that supports 124 | index mapping 125 | """ 126 | def __init__( 127 | self, src, src_sizes, src_dict, index_map=None, 128 | tgt=None, tgt_sizes=None, tgt_dict=None, 129 | left_pad_source=True, left_pad_target=False, 130 | max_source_positions=1024, max_target_positions=1024, 131 | shuffle=True, input_feeding=True, 132 | remove_eos_from_source=False, append_eos_to_target=False, 133 | align_dataset=None, 134 | append_bos=False, eos=None 135 | ): 136 | super(LanguagePairMapDataset, self).__init__( 137 | src, src_sizes, src_dict, 138 | tgt=tgt, tgt_sizes=tgt_sizes, tgt_dict=tgt_dict, 139 | left_pad_source=left_pad_source, left_pad_target=left_pad_target, 140 | max_source_positions=max_source_positions, max_target_positions=max_target_positions, 141 | shuffle=shuffle, input_feeding=input_feeding, 142 | remove_eos_from_source=remove_eos_from_source, append_eos_to_target=append_eos_to_target, 143 | align_dataset=align_dataset, 144 | append_bos=append_bos, eos=eos 145 | ) 146 | 147 | self.index_map = index_map 148 | 149 | def __getitem__(self, index): 150 | orig_index = index 151 | index = self.index_map[index] 152 | 153 | tgt_item = self.tgt[index] if self.tgt is not None else None 154 | src_item = self.src[index] 155 | # Append EOS to end of tgt sentence if it does not have an EOS and remove 156 | # EOS from end of src sentence if it exists. 
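# Aside (illustrative, not part of the original file): the index_map indirection applied
# above lets the same underlying dataset be re-ordered or sub-sampled without copying,
# while 'id' still reports the position the caller asked for, e.g.
#   ds = LanguagePairMapDataset(src, src_sizes, src_dict, index_map=[2, 0, 1])
#   ds[0]   # returns the example stored at underlying position 2, with 'id' == 0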
This is useful when we use 157 | # use existing datasets for opposite directions i.e., when we want to 158 | # use tgt_dataset as src_dataset and vice versa 159 | if self.append_eos_to_target: 160 | eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() 161 | if self.tgt and self.tgt[index][-1] != eos: 162 | tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) 163 | 164 | if self.append_bos: 165 | bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() 166 | if self.tgt and self.tgt[index][0] != bos: 167 | tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) 168 | 169 | bos = self.src_dict.bos() 170 | if self.src[index][-1] != bos: 171 | src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) 172 | 173 | if self.remove_eos_from_source: 174 | eos = self.src_dict.eos() 175 | if self.src[index][-1] == eos: 176 | src_item = self.src[index][:-1] 177 | 178 | example = { 179 | 'id': orig_index, 180 | 'source': src_item, 181 | 'target': tgt_item, 182 | } 183 | if self.align_dataset is not None: 184 | example['alignment'] = self.align_dataset[index] 185 | return example 186 | 187 | def set_index_map(self, index_map): 188 | self.index_map = index_map 189 | 190 | def reset_index_map(self): 191 | self.index_map = None 192 | -------------------------------------------------------------------------------- /sparse_prototype/lm_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from fairseq import metrics, utils 13 | from fairseq.criterions import LegacyFairseqCriterion, FairseqCriterion, register_criterion 14 | 15 | 16 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): 17 | """compute labeled smoothed nll loss 18 | Returns: 19 | loss: the actual loss to be optimized (after smoothing), with 20 | shape (batch) if reduce is true else (batch, seq_len) 21 | nll_loss: the NLL loss with shape (batch) if reduce is true else 22 | (batch, seq_len) 23 | """ 24 | if target.dim() == lprobs.dim() - 1: 25 | target = target.unsqueeze(-1) 26 | nll_loss = -lprobs.gather(dim=-1, index=target) 27 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 28 | if ignore_index is not None: 29 | pad_mask = target.eq(ignore_index) 30 | if pad_mask.any(): 31 | nll_loss.masked_fill_(pad_mask, 0.) 32 | smooth_loss.masked_fill_(pad_mask, 0.) 33 | 34 | nll_loss = nll_loss.squeeze(-1) 35 | smooth_loss = smooth_loss.squeeze(-1) 36 | 37 | # (batch, seq_len) --> (batch) 38 | if reduce: 39 | nll_loss = nll_loss.sum(-1) 40 | smooth_loss = smooth_loss.sum(-1) 41 | eps_i = epsilon / lprobs.size(-1) 42 | loss = (1. 
- epsilon) * nll_loss + eps_i * smooth_loss 43 | return loss, nll_loss 44 | 45 | def write_loss(ll_batch, sample, infer_ns, fout): 46 | revert_order = sample['net_input']['revert_order'] 47 | id_list = sample['tgt_id'].index_select(0, revert_order).view(-1, infer_ns)[:, 0] 48 | length_list = sample['net_input']['src_lengths'].index_select(0, revert_order).view(-1, infer_ns)[:, 0] 49 | 50 | for id_, ntoken, ll in zip(id_list, length_list, ll_batch): 51 | fout.write('{} {} {}\n'.format(id_.item(), ntoken.item(), ll.item())) 52 | 53 | 54 | @register_criterion('lm_baseline') 55 | class LMBaseline(LegacyFairseqCriterion): 56 | 57 | def __init__(self, args, task): 58 | super().__init__(args, task) 59 | self.eps = args.label_smoothing 60 | 61 | if args.write_loss_path is not None: 62 | self.f_loss = open(os.path.join(args.save_dir, args.write_loss_path), 'w') 63 | else: 64 | self.f_loss = None 65 | 66 | @staticmethod 67 | def add_args(parser): 68 | """Add criterion-specific arguments to the parser.""" 69 | # fmt: off 70 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 71 | help='epsilon for label smoothing, 0 means no label smoothing') 72 | # fmt: on 73 | 74 | def forward(self, model, sample, data_len, reduce=True): 75 | """Compute the loss for the given sample. 76 | 77 | Returns a tuple with three elements: 78 | 1) the loss 79 | 2) the sample size, which is used as the denominator for the gradient 80 | 3) logging outputs to display while training 81 | """ 82 | net_output = model.lm_forward(**sample['net_input']) 83 | loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 84 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 85 | 86 | nsentences = sample['target'].size(0) 87 | logging_output = { 88 | 'loss': utils.item(loss.data) if reduce else loss.data, 89 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 90 | 'ntokens': sample['ntokens'] / model.infer_ns, 91 | 'nsentences': sample['target'].size(0) / model.infer_ns, 92 | 'sample_size': sample_size / model.infer_ns, 93 | } 94 | 95 | return loss, sample_size, logging_output 96 | 97 | # compute the ELBO loss, involving reinforcement learning 98 | def compute_loss(self, model, net_output, sample, reduce=True): 99 | lprobs = model.get_normalized_probs(net_output['recon_out'], log_probs=True) 100 | # lprobs = lprobs.view(-1, lprobs.size(-1)) 101 | target = model.get_targets(sample, net_output) 102 | smoothed_nll_loss, nll_loss = label_smoothed_nll_loss( 103 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 104 | ) 105 | 106 | loss = smoothed_nll_loss.sum() 107 | 108 | if self.f_loss is not None: 109 | revert_order = sample['net_input']['revert_order'] 110 | nll_loss_reorder = nll_loss.index_select(0, revert_order).view(-1, model.infer_ns).mean(1) 111 | write_loss(-nll_loss_reorder, sample, model.infer_ns, self.f_loss) 112 | 113 | 114 | return loss, nll_loss.sum() 115 | 116 | @staticmethod 117 | def reduce_metrics(logging_outputs) -> None: 118 | """Aggregate logging outputs from data parallel training.""" 119 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 120 | nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs) 121 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 122 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 123 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 124 | 125 | metrics.log_scalar('loss', loss_sum / 
sample_size / math.log(2), 126 | sample_size, round=3, priority=3) 127 | 128 | metrics.log_scalar('nll_loss_s', nll_loss_sum / nsentences, 129 | nsentences, round=3, priority=4) 130 | 131 | metrics.log_scalar('nll_loss_t', nll_loss_sum / ntokens / math.log(2), 132 | ntokens, round=3, priority=5) 133 | 134 | metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss_t'].avg), priority=6) 135 | 136 | @staticmethod 137 | def logging_outputs_can_be_summed() -> bool: 138 | """ 139 | Whether the logging outputs returned by `forward` can be summed 140 | across workers prior to calling `reduce_metrics`. Setting this 141 | to True will improves distributed training speed. 142 | """ 143 | return True 144 | -------------------------------------------------------------------------------- /sparse_prototype/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import requests 3 | import tarfile 4 | import os 5 | 6 | def download_file_from_google_drive(id, destination): 7 | URL = "https://docs.google.com/uc?export=download" 8 | 9 | session = requests.Session() 10 | 11 | response = session.get(URL, params = { 'id' : id }, stream = True) 12 | token = get_confirm_token(response) 13 | 14 | if token: 15 | params = { 'id' : id, 'confirm' : token } 16 | response = session.get(URL, params = params, stream = True) 17 | 18 | save_response_content(response, destination) 19 | 20 | def get_confirm_token(response): 21 | for key, value in response.cookies.items(): 22 | if key.startswith('download_warning'): 23 | return value 24 | 25 | return None 26 | 27 | def save_response_content(response, destination): 28 | CHUNK_SIZE = 32768 29 | 30 | with open(destination, "wb") as f: 31 | for chunk in response.iter_content(CHUNK_SIZE): 32 | if chunk: # filter out keep-alive new chunks 33 | f.write(chunk) 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser(description="data downloading") 37 | parser.add_argument('--dataset', choices=["synthetic", "yahoo", "yelp", "ptb", "all"], 38 | default="ptb", help='dataset to use') 39 | 40 | args = parser.parse_args() 41 | 42 | if not os.path.exists("datasets"): 43 | os.makedirs("datasets") 44 | 45 | os.chdir("datasets") 46 | 47 | synthetic_id = "1pEHLedf3ZSo7UrHdvR1VWPfWNTcN6oWH" 48 | yahoo_id = "13azGlTuGdzWLCmgDmQPmvb_jcexVWX7i" 49 | yelp_id = "1FT49oLNV8syhmGXEgiK6XTjEfMNqqEJJ" 50 | ptb_id = "1Lh-kDhGtUuQ0inlyVor1ea6eGvVsmeUb" 51 | 52 | if args.dataset == "synthetic": 53 | file_id = [synthetic_id] 54 | elif args.dataset == "yahoo": 55 | file_id = [yahoo_id] 56 | elif args.dataset == "yelp": 57 | file_id = [yelp_id] 58 | elif args.dataset == "ptb": 59 | file_id = [ptb_id] 60 | else: 61 | file_id = [synthetic_id, yahoo_id, yelp_id, ptb_id] 62 | 63 | destination = "datasets.tar.gz" 64 | 65 | for file_id_e in file_id: 66 | download_file_from_google_drive(file_id_e, destination) 67 | tar = tarfile.open(destination, "r:gz") 68 | tar.extractall() 69 | tar.close() 70 | os.remove(destination) 71 | 72 | os.chdir("../") 73 | -------------------------------------------------------------------------------- /sparse_prototype/retrieve_prototype_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
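# Aside (illustrative, not part of the repository): the download helper in the previous
# file (sparse_prototype/prepare_data.py) is a standalone script; a typical invocation
# from the repository root would be
#   python sparse_prototype/prepare_data.py --dataset yelp
# which fetches the archive from Google Drive, unpacks it under datasets/, and removes
# the temporary datasets.tar.gz.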
5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from fairseq import utils 10 | 11 | from fairseq.data import data_utils, FairseqDataset, Dictionary 12 | 13 | 14 | 15 | def output_collate( 16 | samples, logits, logits_topk, sample_orig, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, 17 | input_feeding=True, 18 | ): 19 | if len(samples) == 0: 20 | return {} 21 | 22 | def merge(key, left_pad, move_eos_to_beginning=False): 23 | return data_utils.collate_tokens( 24 | [s[key] for s in samples], 25 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 26 | ) 27 | 28 | def check_alignment(alignment, src_len, tgt_len): 29 | if alignment is None or len(alignment) == 0: 30 | return False 31 | if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1: 32 | print("| alignment size mismatch found, skipping alignment!") 33 | return False 34 | return True 35 | 36 | def compute_alignment_weights(alignments): 37 | """ 38 | Given a tensor of shape [:, 2] containing the source-target indices 39 | corresponding to the alignments, a weight vector containing the 40 | inverse frequency of each target index is computed. 41 | For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then 42 | a tensor containing [1., 0.5, 0.5, 1] should be returned (since target 43 | index 3 is repeated twice) 44 | """ 45 | align_tgt = alignments[:, 1] 46 | _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True) 47 | align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] 48 | return 1. / align_weights.float() 49 | 50 | id = torch.LongTensor([s['id'] for s in samples]) 51 | src_tokens = merge('source', left_pad=left_pad_source) 52 | # sort by descending source length 53 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 54 | src_lengths, sort_order = src_lengths.sort(descending=True) 55 | id = id.index_select(0, sort_order) 56 | src_tokens = src_tokens.index_select(0, sort_order) 57 | 58 | prev_output_tokens = None 59 | target = None 60 | if samples[0].get('target', None) is not None: 61 | target = merge('target', left_pad=left_pad_target) 62 | target = target.index_select(0, sort_order) 63 | tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order) 64 | ntokens = sum(len(s['target']) for s in samples) 65 | 66 | tgt_id = torch.LongTensor([s['tgt_id'] for s in samples]) 67 | tgt_id = tgt_id.index_select(0, sort_order) 68 | 69 | if input_feeding: 70 | # we create a shifted version of targets for feeding the 71 | # previous output token(s) into the next decoder step 72 | prev_output_tokens = merge( 73 | 'target', 74 | left_pad=left_pad_target, 75 | move_eos_to_beginning=True, 76 | ) 77 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 78 | else: 79 | ntokens = sum(len(s['source']) for s in samples) 80 | 81 | _, revert_order = sort_order.sort() 82 | batch = { 83 | 'id': id, 84 | 'tgt_id': tgt_id, 85 | 'nsentences': len(samples), 86 | 'ntokens': ntokens, 87 | 'sample_orig': sample_orig, 88 | 'net_input': { 89 | 'temp_ids': id, 90 | 'temp_tokens': src_tokens, 91 | 'temp_lengths': src_lengths, 92 | 'src_tokens': target, 93 | 'src_lengths': tgt_lengths, 94 | 'logits_topk': logits_topk, 95 | 'logits': logits, 96 | 'revert_order': revert_order, 97 | }, 98 | 'target': target, 99 | } 100 | 101 | # aligned pairs and edit path are required 102 | if samples[0].get('src_aligned', None) is not None: 103 | src_aligned = merge('src_aligned', 
left_pad=False).index_select(0, sort_order) 104 | aligned_length = torch.LongTensor([s['src_aligned'].numel() for s in samples]).index_select(0, sort_order) 105 | 106 | tgt_aligned = merge('tgt_aligned', left_pad=False).index_select(0, sort_order) 107 | edit_aligned = merge('edit_aligned', left_pad=False).index_select(0, sort_order) 108 | 109 | assert src_aligned.size(1) == edit_aligned.size(1) 110 | batch['net_input']['src_aligned'] = src_aligned 111 | batch['net_input']['tgt_aligned'] = tgt_aligned 112 | batch['net_input']['edit_aligned'] = edit_aligned 113 | batch['net_input']['aligned_length'] = aligned_length 114 | 115 | 116 | 117 | if prev_output_tokens is not None: 118 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 119 | 120 | if samples[0].get('alignment', None) is not None: 121 | bsz, tgt_sz = batch['target'].shape 122 | src_sz = batch['net_input']['src_tokens'].shape[1] 123 | 124 | offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) 125 | offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz) 126 | if left_pad_source: 127 | offsets[:, 0] += (src_sz - src_lengths) 128 | if left_pad_target: 129 | offsets[:, 1] += (tgt_sz - tgt_lengths) 130 | 131 | alignments = [ 132 | alignment + offset 133 | for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) 134 | for alignment in [samples[align_idx]['alignment'].view(-1, 2)] 135 | if check_alignment(alignment, src_len, tgt_len) 136 | ] 137 | 138 | if len(alignments) > 0: 139 | alignments = torch.cat(alignments, dim=0) 140 | align_weights = compute_alignment_weights(alignments) 141 | 142 | batch['alignments'] = alignments 143 | batch['align_weights'] = align_weights 144 | 145 | return batch 146 | 147 | def lang_pair_collate( 148 | samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, 149 | input_feeding=True, 150 | ): 151 | if len(samples) == 0: 152 | return {} 153 | 154 | def merge(key, left_pad, move_eos_to_beginning=False): 155 | return data_utils.collate_tokens( 156 | [s[key] for s in samples], 157 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 158 | ) 159 | 160 | def check_alignment(alignment, src_len, tgt_len): 161 | if alignment is None or len(alignment) == 0: 162 | return False 163 | if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1: 164 | logger.warning("alignment size mismatch found, skipping alignment!") 165 | return False 166 | return True 167 | 168 | def compute_alignment_weights(alignments): 169 | """ 170 | Given a tensor of shape [:, 2] containing the source-target indices 171 | corresponding to the alignments, a weight vector containing the 172 | inverse frequency of each target index is computed. 173 | For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then 174 | a tensor containing [1., 0.5, 0.5, 1] should be returned (since target 175 | index 3 is repeated twice) 176 | """ 177 | align_tgt = alignments[:, 1] 178 | _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True) 179 | align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] 180 | return 1. 
/ align_weights.float() 181 | 182 | id = torch.LongTensor([s['id'] for s in samples]) 183 | src_tokens = merge('source', left_pad=left_pad_source) 184 | # sort by descending source length 185 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 186 | src_lengths, sort_order = src_lengths.sort(descending=True) 187 | id = id.index_select(0, sort_order) 188 | src_tokens = src_tokens.index_select(0, sort_order) 189 | 190 | prev_output_tokens = None 191 | target = None 192 | if samples[0].get('target', None) is not None: 193 | target = merge('target', left_pad=left_pad_target) 194 | target = target.index_select(0, sort_order) 195 | tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order) 196 | ntokens = sum(len(s['target']) for s in samples) 197 | 198 | if input_feeding: 199 | # we create a shifted version of targets for feeding the 200 | # previous output token(s) into the next decoder step 201 | prev_output_tokens = merge( 202 | 'target', 203 | left_pad=left_pad_target, 204 | move_eos_to_beginning=True, 205 | ) 206 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 207 | else: 208 | ntokens = sum(len(s['source']) for s in samples) 209 | 210 | _, revert_order = sort_order.sort() 211 | 212 | batch = { 213 | 'id': id, 214 | 'nsentences': len(samples), 215 | 'ntokens': ntokens, 216 | 'net_input': { 217 | 'src_tokens': src_tokens, 218 | 'src_lengths': src_lengths, 219 | 'revert_order': revert_order, 220 | }, 221 | 'target': target, 222 | } 223 | if prev_output_tokens is not None: 224 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 225 | 226 | if samples[0].get('alignment', None) is not None: 227 | bsz, tgt_sz = batch['target'].shape 228 | src_sz = batch['net_input']['src_tokens'].shape[1] 229 | 230 | offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) 231 | offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz) 232 | if left_pad_source: 233 | offsets[:, 0] += (src_sz - src_lengths) 234 | if left_pad_target: 235 | offsets[:, 1] += (tgt_sz - tgt_lengths) 236 | 237 | alignments = [ 238 | alignment + offset 239 | for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) 240 | for alignment in [samples[align_idx]['alignment'].view(-1, 2)] 241 | if check_alignment(alignment, src_len, tgt_len) 242 | ] 243 | 244 | if len(alignments) > 0: 245 | alignments = torch.cat(alignments, dim=0) 246 | align_weights = compute_alignment_weights(alignments) 247 | 248 | batch['alignments'] = alignments 249 | batch['align_weights'] = align_weights 250 | 251 | return batch 252 | 253 | 254 | class RetrievePrototypeDataset(FairseqDataset): 255 | """ 256 | Sets up a prototype dataset which takes a tgt batch, generates 257 | the prototype id with the classification function, and returns 258 | the corresponding `{generated prototype, input tgt}` batch. 259 | 260 | Args: 261 | tgt_dataset (~fairseq.data.FairseqDataset): the input dataset to be 262 | classified. 263 | tgt_dict (~fairseq.data.Dictionary): the dictionary of sentences. 264 | retrieve_fn (callable, optional): function to call to generate 265 | prototype ids. This is typically the `forward` method of a 266 | classification network. Pass in None when it is not available at initialization time, and 267 | use set_retrieve_fn function to set it when available. 
268 | output_collater (callable, optional): function to call on the 269 | backtranslated samples to create the final batch 270 | (default: ``tgt_dataset.collater``). 271 | cuda: use GPU for generation 272 | """ 273 | 274 | # this is class attribute 275 | # should be of type fairseq.data.FairseqDataset 276 | 277 | def __init__( 278 | self, 279 | tgt_dataset, 280 | tgt_dict, 281 | retrieve_dataset=None, 282 | retrieve_fn=None, 283 | cuda=True, 284 | num_samples=1, 285 | temperature=1, 286 | sampling=True, 287 | edit_dict=None, 288 | split=None, 289 | masks=None, 290 | **kwargs 291 | ): 292 | self.tgt_dataset = tgt_dataset 293 | self.retrieve_fn = retrieve_fn 294 | self.cuda = cuda if torch.cuda.is_available() else False 295 | self.tgt_dict = tgt_dict 296 | self.num_samples = num_samples 297 | self.temperature = temperature 298 | self.tgt_dict = tgt_dict 299 | self.sampling = sampling 300 | 301 | self.retrieve_dataset = retrieve_dataset 302 | self.edit_align = (edit_dict is not None) 303 | 304 | self.edit_dict = edit_dict 305 | self.split = split 306 | 307 | self.masks = masks 308 | 309 | @classmethod 310 | def get_edit_dict(self): 311 | # unchanged, substitute, delete, add 312 | tag_list = ['=', 'X', 'D', 'I'] 313 | edit_dict = Dictionary() 314 | for tag in tag_list: 315 | edit_dict.add_symbol(tag) 316 | 317 | return edit_dict 318 | 319 | 320 | def __getitem__(self, index): 321 | """ 322 | Returns a single sample from *tgt_dataset*. 323 | """ 324 | return self.tgt_dataset[index] 325 | 326 | def __len__(self): 327 | return len(self.tgt_dataset) 328 | 329 | def get_string(self, index): 330 | return self.tgt_dict.string(self.tgt_dataset[index]) 331 | 332 | def set_retrieve_fn(self, retrieve_fn): 333 | self.retrieve_fn = retrieve_fn 334 | 335 | def set_sampling(self, val): 336 | self.sampling = val 337 | 338 | def wrap_collate(self, samples): 339 | return lang_pair_collate( 340 | samples, pad_idx=self.tgt_dict.pad(), 341 | eos_idx=self.tgt_dict.eos(), left_pad_source=self.tgt_dataset.left_pad_source, 342 | left_pad_target=self.tgt_dataset.left_pad_target, 343 | input_feeding=self.tgt_dataset.input_feeding, 344 | ) 345 | 346 | def retrieve_prototypes(self, samples, dataset, collate_fn, 347 | retrieve_fn, num_samples=1, 348 | temperature=1, cuda=True, 349 | sampling=True, 350 | edit_align=False, 351 | edit_dict=None): 352 | """retrieve a list of samples. 353 | 354 | Given an input (*samples*) of the form: 355 | 356 | [{'id': source_id, 'source': 'hallo world'}] 357 | 358 | this will return: 359 | 360 | [{'id': prototype_id, 'source': *prototype*, 'target': 'hallo world'}] 361 | 362 | Args: 363 | samples (List[dict]): Individual samples are expected to have a 'source' key, 364 | which will become the 'target' after retrieving. 365 | dataset (~fairseq.data.FairseqDataset): the dataset to be used for indexing. Only 366 | the source side of this dataset will be used. After retrieving, the source 367 | sentences in this dataset will still be returns as source prototypes. 368 | collate_fn (callable): function to collate samples into a mini-batch ready to input 369 | to retrieve_fn. 370 | generate_fn (callable): function to generate classfication logits. 
371 | cuda (bool): use GPU for generation (default: ``True``) 372 | 373 | Returns: 374 | dict: contains `logits` and `samples`, which are an updated list of samples 375 | with a retrieved prototype source 376 | """ 377 | 378 | assert dataset is not None 379 | 380 | collated_samples = collate_fn(samples) 381 | 382 | collated_samples = utils.move_to_cuda(collated_samples) if cuda else collated_samples 383 | 384 | 385 | logits = retrieve_fn(collated_samples, self.split) 386 | logits = logits.index_select(0, collated_samples['net_input']['revert_order']) 387 | 388 | # logits = logits / temperature 389 | 390 | # avoid selecting self as templates at training time 391 | if self.masks is not None: 392 | logits_min, min_index = torch.min(logits, 1) 393 | mask_ids = [self.masks[s['id']] if self.masks[s['id']] != -1 else min_index[i].item() for i, s in enumerate(samples)] 394 | mask_ids = torch.LongTensor(mask_ids) 395 | if cuda: 396 | mask_ids = mask_ids.cuda() 397 | 398 | logits.index_fill_(1, mask_ids, logits_min.min()) 399 | 400 | 401 | bs = logits.size(0) 402 | 403 | # (batch, nsample) -> (batch * nsample) 404 | if sampling: 405 | prototype_ids = torch.multinomial(F.softmax(logits / temperature, dim=1), 406 | num_samples, replacement=True).view(-1) 407 | logits_topk = None 408 | # prototype_ids = torch.multinomial(F.softmax(logits, dim=1), 409 | # num_samples, replacement=True).view(-1) 410 | else: 411 | logits_topk, prototype_ids = torch.topk(logits, num_samples, dim=1) 412 | prototype_ids = prototype_ids.view(-1) 413 | 414 | samples_expand = [] 415 | for i in range(bs): 416 | samples_expand.extend([samples[i]] * num_samples) 417 | 418 | 419 | # List[dict] 420 | prototypes = [dataset[id_.item()] for id_ in prototype_ids] 421 | assert prototypes[0]['id'] == prototype_ids[0].item() 422 | # s = utils.move_to_cuda(collated_samples) if cuda else collated_samples 423 | # generated_sources = generate_fn(s) 424 | 425 | # find the minimum edit path from source to target 426 | if edit_align: 427 | import edlib 428 | 429 | def flat_cigar(cigar): 430 | r = [] 431 | pointer = 0 432 | 433 | while pointer < len(cigar): 434 | num = [] 435 | while cigar[pointer].isdigit(): 436 | num.append(cigar[pointer]) 437 | pointer += 1 438 | num = int(''.join(num)) 439 | 440 | r.extend([cigar[pointer]] * num) 441 | pointer += 1 442 | 443 | return r 444 | 445 | src_aligned_l = [] 446 | tgt_aligned_l = [] 447 | edit_aligned_l = [] 448 | for prototype_s, tgt_s in zip(prototypes, samples_expand): 449 | query, answer = prototype_s['source'], tgt_s['source'] 450 | query = [x.item() for x in query] 451 | answer = [x.item() for x in answer] 452 | res = edlib.align(answer, query, task='path') 453 | 454 | _edit_aligned = flat_cigar(res['cigar']) 455 | 456 | _edit_aligned_l = [] 457 | _src_aligned_l = [] 458 | _tgt_aligned_l = [] 459 | src_cur = tgt_cur = 0 460 | 461 | for edit in _edit_aligned: 462 | if edit == '=' or edit == 'X': 463 | _src_aligned_l.append(query[src_cur]) 464 | _tgt_aligned_l.append(answer[tgt_cur]) 465 | src_cur += 1 466 | tgt_cur += 1 467 | elif edit == 'I': 468 | _src_aligned_l.append(self.tgt_dict.unk_index) 469 | _tgt_aligned_l.append(answer[tgt_cur]) 470 | tgt_cur += 1 471 | elif edit == 'D': 472 | _src_aligned_l.append(query[src_cur]) 473 | _tgt_aligned_l.append(self.tgt_dict.unk_index) 474 | src_cur += 1 475 | else: 476 | raise ValueError('{} edit operation is invalid!'.format(edit)) 477 | 478 | _edit_aligned_l.append(edit_dict.index(edit)) 479 | 480 | assert len(_src_aligned_l) == len(_tgt_aligned_l) == 
len(_edit_aligned_l) 481 | src_aligned_l.append(torch.LongTensor(_src_aligned_l)) 482 | tgt_aligned_l.append(torch.LongTensor(_tgt_aligned_l)) 483 | edit_aligned_l.append(torch.LongTensor(_edit_aligned_l)) 484 | 485 | 486 | if not edit_align: 487 | # Note that the 'id' here is the prototype id instead of the input target ids 488 | return { 489 | 'logits': logits, 490 | 'logits_topk': logits_topk, 491 | 'sample_orig': samples, # to compute importance weighted likelihood 492 | 'samples': [ 493 | {'id': prototype_s['id'], 'tgt_id': tgt_s['id'], 'source': prototype_s['source'], 'target': tgt_s['source']} 494 | for tgt_s, prototype_s in zip(samples_expand, prototypes) 495 | ] 496 | } 497 | else: 498 | return { 499 | 'logits': logits, 500 | 'logits_topk': logits_topk, 501 | 'sample_orig': samples, # to compute importance weighted likelihood 502 | 'samples': [ 503 | {'id': prototype_s['id'], 'tgt_id': tgt_s['id'], 'source': prototype_s['source'], 'target': tgt_s['source'], 504 | 'src_aligned': src_a, 'tgt_aligned': tgt_a, 'edit_aligned': edit_a} 505 | for tgt_s, prototype_s, src_a, tgt_a, edit_a in zip(samples_expand, prototypes, src_aligned_l, tgt_aligned_l, edit_aligned_l) 506 | ] 507 | } 508 | 509 | 510 | def collater(self, samples): 511 | """Retrieve prototypes for a list of samples and merge them to form a mini-batch. 512 | 513 | Using the samples from *tgt_dataset*, load a collated target sample to 514 | feed to the retrieve function. Then sample prototype indices and index the samples from 515 | *retrieve_dataset* as prototypes. 516 | 517 | Note: we expect *tgt_dataset* to provide a function `collater()` that 518 | will collate samples into the format expected by *retrieve_fn*. 519 | After retrieving and indexing, we will feed the new list of samples (i.e., the 520 | `(retrieved source, original target)` pairs) to *output_collate* 521 | and return the result. 522 | 523 | Args: 524 | samples (List[dict]): samples to classify and collate 525 | 526 | Returns: 527 | dict: a mini-batch with keys coming from *output_collate* 528 | """ 529 | if len(samples) == 0: 530 | return {} 531 | 532 | 533 | if samples[0].get('is_dummy', False): 534 | return samples 535 | samples = self.retrieve_prototypes( 536 | samples=samples, 537 | dataset=self.retrieve_dataset, 538 | collate_fn=self.wrap_collate, 539 | retrieve_fn=( 540 | lambda net_input, split: self.retrieve_fn(net_input, split) 541 | ), 542 | num_samples=self.num_samples, 543 | temperature=self.temperature, 544 | cuda=self.cuda, 545 | sampling=self.sampling, 546 | edit_align=self.edit_align, 547 | edit_dict=self.edit_dict, 548 | ) 549 | 550 | return output_collate( 551 | samples['samples'], samples['logits'], logits_topk=samples['logits_topk'], sample_orig=samples['sample_orig'], 552 | pad_idx=self.tgt_dict.pad(), 553 | eos_idx=self.tgt_dict.eos(), left_pad_source=self.tgt_dataset.left_pad_source, 554 | left_pad_target=self.tgt_dataset.left_pad_target, 555 | input_feeding=self.tgt_dataset.input_feeding, 556 | ) 557 | 558 | def num_tokens(self, index): 559 | """Just use the tgt dataset num_tokens""" 560 | return self.tgt_dataset.num_tokens(index) 561 | 562 | def ordered_indices(self): 563 | """Just use the tgt dataset ordered_indices""" 564 | return self.tgt_dataset.ordered_indices() 565 | 566 | def size(self, index): 567 | """Return an example's size as a float or tuple. This value is used 568 | when filtering a dataset with ``--max-positions``.
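For reference, the edit-alignment block above expands the CIGAR string returned by `edlib.align(..., task='path')` into one tag per position with `flat_cigar` before the '='/'X'/'I'/'D' cases build the aligned token lists. A standalone copy of that helper with a worked check (the CIGAR string here is invented):

def flat_cigar(cigar):
    # e.g. '3=1X2I' -> ['=', '=', '=', 'X', 'I', 'I']
    r, pointer = [], 0
    while pointer < len(cigar):
        num = []
        while cigar[pointer].isdigit():
            num.append(cigar[pointer])
            pointer += 1
        r.extend([cigar[pointer]] * int(''.join(num)))
        pointer += 1
    return r

assert flat_cigar('3=1X2I') == ['=', '=', '=', 'X', 'I', 'I']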
569 | 570 | Note: we use *tgt_dataset* to approximate the length of the source 571 | sentence, since we do not know the actual length until after 572 | backtranslation. 573 | """ 574 | tgt_size = self.tgt_dataset.size(index)[0] 575 | return (tgt_size, tgt_size) 576 | 577 | @property 578 | def supports_prefetch(self): 579 | return getattr(self.tgt_dataset, 'supports_prefetch', False) 580 | 581 | def prefetch(self, indices): 582 | return self.tgt_dataset.prefetch(indices) 583 | -------------------------------------------------------------------------------- /sparse_prototype/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .cnn_text import * 2 | from .precompute_emb import * 3 | from .bert import * 4 | -------------------------------------------------------------------------------- /sparse_prototype/retriever/bert.py: -------------------------------------------------------------------------------- 1 | # import h5py 2 | import os 3 | import subprocess 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from transformers import * 11 | from fairseq import utils 12 | # from datasets import load_dataset 13 | 14 | def get_file_len(file): 15 | proc = subprocess.run(['wc', '-l', file], capture_output=True) 16 | return int(proc.stdout.decode('utf-8').split()[0]) 17 | 18 | class BertRetriever(nn.Module): 19 | """the retriever module based on pretrained sentence-Bert embeddings""" 20 | def __init__(self, args, dictionary, emb_dataset_path, 21 | rescale=1., linear_bias=False, stop_grad=False, 22 | freeze=False, cuda=True, sentbert=False, emb_size=768): 23 | 24 | super(BertRetriever, self).__init__() 25 | 26 | self.dict = dictionary 27 | self.stop_grad = stop_grad 28 | self.device = torch.device("cuda" if cuda else "cpu") 29 | 30 | infer_dataset = args.data.split('/')[-1] 31 | if os.path.isfile(f'{emb_dataset_path}.template.npy'): 32 | 33 | # emb_size is defaulted too 768 (BERT) 34 | # here we use float16 since the embeddings are presaved in float16 35 | example_size = get_file_len(f'datasets/{infer_dataset}/template.txt') 36 | template_weight = np.memmap(f'{emb_dataset_path}.template.npy', 37 | dtype='float16', mode='r', shape=(example_size, emb_size)) 38 | else: 39 | raise ValueError('template numpy file does not exist') 40 | 41 | num_template = len(template_weight) 42 | 43 | template_weight = torch.tensor(template_weight) 44 | 45 | print('read template embeddings complete!') 46 | 47 | nfeat = template_weight.size(1) 48 | 49 | self.linear1 = nn.Linear(nfeat, nfeat, bias=False) 50 | self.linear2 = nn.Linear(nfeat, num_template, bias=linear_bias) 51 | 52 | # this should be consistent with pre-saved template embeddings 53 | model_name = 'bert-base-uncased' if not sentbert else 'sentence-transformers/bert-base-nli-mean-tokens' 54 | 55 | self.encoder = AutoModel.from_pretrained(model_name) 56 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 57 | 58 | if stop_grad: 59 | for param in self.encoder.parameters(): 60 | param.requires_grad = False 61 | 62 | with torch.no_grad(): 63 | nn.init.eye_(self.linear1.weight) 64 | self.linear1.weight.data = self.linear1.weight.data / rescale 65 | self.linear2.weight.copy_(template_weight) 66 | 67 | if linear_bias: 68 | self.linear2.bias.zero_() 69 | 70 | self.linear2.weight.requires_grad = False 71 | 72 | # self.linear1.weight.requires_grad = False 73 | 74 | if freeze: 75 | for param in self.parameters(): 76 | param.requires_grad = 
False 77 | 78 | self.prune_index = None 79 | self.prune_linear2_weight = None 80 | 81 | 82 | def encode(self, batches, maxlen=500): 83 | features = self.tokenizer.batch_encode_plus(batches, padding=True, 84 | return_attention_mask=True, return_token_type_ids=True, 85 | truncation=True, max_length=maxlen, return_tensors='pt') 86 | attention_mask = features['attention_mask'].to(self.device) 87 | input_ids = features['input_ids'].to(self.device) 88 | token_type_ids= features['token_type_ids'].to(self.device) 89 | 90 | # (batch, seq_len, nfeature) 91 | token_embeddings = self.encoder(input_ids=input_ids, 92 | attention_mask=attention_mask, 93 | token_type_ids=token_type_ids)[0] 94 | 95 | # mean of context embeddings as sentence embeddings 96 | embeddings = (attention_mask.unsqueeze(-1) * token_embeddings).sum(1) / attention_mask.sum(1).unsqueeze(-1) 97 | 98 | return embeddings 99 | 100 | def forward(self, samples, split=None, key=None): 101 | """ 102 | Args: 103 | samples (dict): input dict with keys 'net_input', 'id', etc. 104 | 105 | 106 | Returns: 107 | logits (tensor): shape (B, num_template) 108 | """ 109 | 110 | net_input = samples['net_input'] 111 | x = net_input['src_tokens'] 112 | bs = x.size(0) 113 | sent_strings = [] 114 | for sent in x: 115 | sent = utils.strip_pad(sent, self.dict.pad()) 116 | sent_strings.append(self.dict.string(sent).strip('\n')) 117 | 118 | # (bs x nfeats) 119 | if self.stop_grad: 120 | self.eval() 121 | with torch.no_grad(): 122 | embeddings = self.encode(sent_strings) 123 | 124 | else: 125 | embeddings = self.encode(sent_strings) 126 | 127 | 128 | if self.prune_linear2_weight is None: 129 | logits = self.linear2(self.linear1(embeddings)) 130 | else: 131 | logits = F.linear(self.linear1(embeddings), self.prune_linear2_weight) 132 | 133 | # mask itself only during training to mitigate overfitting 134 | # while this is pretty rough, but should be enough considering both time and memory efficiency 135 | 136 | # if split == 'train': 137 | # with torch.no_grad(): 138 | # mask = (self.linear2(embeddings) - (embeddings * embeddings).sum(1).unsqueeze(1)).abs() < 1e-5 139 | 140 | # logits = logits.masked_fill(mask, logits.min().item()) 141 | # mask itself 142 | # mask = logits.new_zeros(logits.size(), dtype=torch.bool) 143 | # for i, id_ in enumerate(samples['id']): 144 | # mask[i, id_.item()] = 1 145 | # logits = logits.masked_fill(mask, logits.min().item()) 146 | 147 | # prune at test time 148 | if self.prune_index is not None: 149 | logits.index_fill_(1, self.prune_index, logits.min().item() - 1e3) 150 | 151 | return logits 152 | 153 | def set_prune_index(self, index): 154 | # self.prune_index = index 155 | self.prune_linear2_weight = self.linear2.weight[index] 156 | 157 | 158 | def reset_prune_index(self): 159 | # self.prune_index = None 160 | 161 | self.prune_linear2_weight = None 162 | -------------------------------------------------------------------------------- /sparse_prototype/retriever/cnn_text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def Embedding(num_embeddings, embedding_dim, padding_idx): 7 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 8 | nn.init.uniform_(m.weight, -0.1, 0.1) 9 | nn.init.constant_(m.weight[padding_idx], 0) 10 | return m 11 | 12 | class CNN_Text(nn.Module): 13 | """A CNN classifier, originally from 14 | https://github.com/Shawn1993/cnn-text-classification-pytorch 15 | """ 16 
| def __init__(self, dictionary, class_num, embed_dim=512, 17 | kernel_num=100, kernel_sizes='3,4,5', dropout=0.5, 18 | pretrained_embed=None, 19 | ): 20 | super(CNN_Text, self).__init__() 21 | 22 | V = len(dictionary) 23 | D = embed_dim 24 | C = class_num 25 | Ci = 1 26 | Co = kernel_num 27 | Ks = [int(x) for x in kernel_sizes.split(',')] 28 | 29 | self.padding_idx = dictionary.pad() 30 | if pretrained_embed is None: 31 | self.embed = Embedding(V, embed_dim, self.padding_idx) 32 | else: 33 | self.embed = pretrained_embed 34 | # self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks] 35 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks]) 36 | ''' 37 | self.conv13 = nn.Conv2d(Ci, Co, (3, D)) 38 | self.conv14 = nn.Conv2d(Ci, Co, (4, D)) 39 | self.conv15 = nn.Conv2d(Ci, Co, (5, D)) 40 | ''' 41 | self.dropout = nn.Dropout(dropout) 42 | self.fc1 = nn.Linear(len(Ks)*Co, C) 43 | 44 | def conv_and_pool(self, x, conv): 45 | x = F.relu(conv(x)).squeeze(3) # (N, Co, W) 46 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 47 | return x 48 | 49 | def forward(self, net_input, split=None): 50 | """ 51 | Args: 52 | net_input (dict): input dict with keys 'src_tokens' and 'src_lengths' 53 | 54 | 55 | Returns: 56 | logits (tensor): shape (B, class_num) 57 | """ 58 | 59 | x = net_input['src_tokens'] 60 | # (N, W, D) <-> (batch_size, seq_len, feature) 61 | x = self.embed(x) # (N, W, D) 62 | 63 | # if self.args.static: 64 | # x = Variable(x) 65 | 66 | x = x.unsqueeze(1) # (N, Ci, W, D) 67 | 68 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks) 69 | 70 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks) 71 | 72 | # (N, Co * len(Ks)) <-> (batch_size, kernel_num * len(kernel_size)) 73 | x = torch.cat(x, 1) 74 | 75 | ''' 76 | x1 = self.conv_and_pool(x,self.conv13) #(N,Co) 77 | x2 = self.conv_and_pool(x,self.conv14) #(N,Co) 78 | x3 = self.conv_and_pool(x,self.conv15) #(N,Co) 79 | x = torch.cat((x1, x2, x3), 1) # (N,len(Ks)*Co) 80 | ''' 81 | x = self.dropout(x) # (N, len(Ks)*Co) 82 | logit = self.fc1(x) # (N, C) 83 | return logit 84 | -------------------------------------------------------------------------------- /sparse_prototype/retriever/precompute_emb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import subprocess 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from itertools import chain 11 | from fairseq import utils 12 | # from datasets import load_dataset 13 | 14 | 15 | def get_file_len(file): 16 | proc = subprocess.run(['wc', '-l', file], capture_output=True) 17 | return int(proc.stdout.decode('utf-8').split()[0]) 18 | 19 | class PrecomputeEmbedRetriever(nn.Module): 20 | """the retriever module based on pretrained embeddings""" 21 | def __init__(self, args, dictionary, emb_dataset_path, rescale=1., 22 | linear_bias=False, freeze=False, nlayers=0, emb_size=768): 23 | 24 | super(PrecomputeEmbedRetriever, self).__init__() 25 | 26 | self.dict = dictionary 27 | 28 | self.dataset = {} 29 | split_list = ['template', 'train', 'valid', 'test'] 30 | # for split in split_list: 31 | # if os.path.isfile(f'{emb_dataset_path}.{split}.csv.gz'): 32 | # self.dataset[split] = load_dataset('csv', 33 | # data_files=f'{emb_dataset_path}.{split}.csv.gz', 34 | # cache_dir='hf_dataset_cache') 35 | 36 | infer_dataset = args.data.split('/')[-1] 37 | for split in split_list: 38 | if 
os.path.isfile(f'{emb_dataset_path}.{split}.npy'): 39 | 40 | # emb_size is defaulted too 768 (BERT) 41 | # here we use float16 since the embeddings are presaved in float16 42 | example_size = get_file_len(f'datasets/{infer_dataset}/{split}.txt') 43 | 44 | dataset_name = args.data.split('/')[1] 45 | if dataset_name != 'cocov': 46 | self.dataset[split] = np.memmap(f'{emb_dataset_path}.{split}.npy', 47 | dtype='float16', mode='r', shape=(example_size, emb_size)) 48 | else: 49 | # temp change, for cocov 50 | emb_size=1024 51 | self.dataset[split] = np.memmap(f'{emb_dataset_path}.{split}.npy', 52 | dtype='float32', mode='r', shape=(example_size, emb_size)) 53 | 54 | template_weight = self.dataset['template'] 55 | num_template = len(template_weight) 56 | # template_weight = [] 57 | 58 | # for i in range(num_template): 59 | # template_weight.append(json.loads(template_group[i]['embedding'])) 60 | 61 | template_weight = torch.tensor(np.array(template_weight)) 62 | 63 | print('read template embeddings complete!') 64 | 65 | nfeat = template_weight.size(1) 66 | 67 | self.linear1 = nn.Linear(nfeat, nfeat, bias=False) 68 | 69 | modules = [] 70 | for _ in range(nlayers): 71 | linear_tmp = nn.Linear(nfeat, nfeat) 72 | with torch.no_grad(): 73 | nn.init.eye_(linear_tmp.weight) 74 | nn.init.zeros_(linear_tmp.bias) 75 | 76 | modules.extend([linear_tmp, nn.ReLU()]) 77 | 78 | self.middle = nn.Sequential(*modules) 79 | 80 | # output layer 81 | self.linear2 = nn.Linear(nfeat, num_template, bias=linear_bias) 82 | 83 | 84 | with torch.no_grad(): 85 | nn.init.eye_(self.linear1.weight) 86 | self.linear1.weight.data = self.linear1.weight.data / rescale 87 | self.linear2.weight.copy_(template_weight) 88 | 89 | if linear_bias: 90 | self.linear2.bias.zero_() 91 | 92 | self.linear2.weight.requires_grad = False 93 | 94 | if freeze: 95 | for param in self.parameters(): 96 | param.requires_grad = False 97 | 98 | self.prune_index = None 99 | self.prune_linear2_weight = None 100 | 101 | def forward(self, samples, split, key='id'): 102 | """ 103 | Args: 104 | samples (dict): input dict with keys 'net_input', 'id', etc. 
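`PrecomputeEmbedRetriever` scores a sentence embedding against the fixed template table by initializing `linear1` to a (rescaled) identity and copying the template embeddings into a frozen `linear2`. A minimal sketch of that setup, using toy dimensions and random stand-ins for the memmapped embeddings:

import torch
import torch.nn as nn

num_templates, emb_size, rescale = 5, 8, 1.0       # toy sizes
template_weight = torch.randn(num_templates, emb_size)   # stands in for the presaved fp16 embeddings

linear1 = nn.Linear(emb_size, emb_size, bias=False)
linear2 = nn.Linear(emb_size, num_templates, bias=False)
with torch.no_grad():
    nn.init.eye_(linear1.weight)
    linear1.weight.data /= rescale
    linear2.weight.copy_(template_weight)
linear2.weight.requires_grad = False               # templates stay fixed; only linear1 adapts

query_emb = torch.randn(2, emb_size)               # precomputed sentence embeddings for a batch
logits = linear2(linear1(query_emb))               # (batch, num_templates) retrieval scores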
105 | 106 | 107 | Returns: 108 | logits (tensor): shape (B, class_num) 109 | """ 110 | 111 | id_ = samples[key] 112 | 113 | # embeddings = [json.loads(self.dataset[split]['train'][i.item()]['embedding']) for i in id_] 114 | embeddings = [self.dataset[split][i.item()] for i in id_] 115 | embeddings = self.linear1.weight.new_tensor(embeddings) 116 | 117 | if self.prune_linear2_weight is None: 118 | logits = self.linear2(self.middle(self.linear1(embeddings))) 119 | else: 120 | logits = F.linear(self.middle(self.linear1(embeddings)), self.prune_linear2_weight) 121 | 122 | # mask itself only during training to mitigate overfitting 123 | # while this is pretty rough, but should be enough considering both time and memory efficiency 124 | 125 | # if split == 'train': 126 | # with torch.no_grad(): 127 | # mask = (self.linear2(embeddings) - (embeddings * embeddings).sum(1).unsqueeze(1)).abs() < 1e-5 128 | 129 | # logits = logits.masked_fill(mask, logits.min().item()) 130 | 131 | # prune at test time 132 | if self.prune_index is not None: 133 | logits.index_fill_(1, self.prune_index, logits.min().item() - 1e3) 134 | 135 | # mask itself 136 | # mask = logits.new_zeros(logits.size(), dtype=torch.bool) 137 | # for i, id_ in enumerate(samples['id']): 138 | # mask[i, id_.item()] = 1 139 | # logits = logits.masked_fill(mask, logits.min().item()) 140 | 141 | return logits 142 | 143 | def set_prune_index(self, index): 144 | # self.prune_index = index 145 | self.prune_linear2_weight = self.linear2.weight[index] 146 | 147 | 148 | def reset_prune_index(self): 149 | # self.prune_index = None 150 | 151 | self.prune_linear2_weight = None 152 | 153 | -------------------------------------------------------------------------------- /sparse_prototype/retriever/sent_bert.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from fairseq import utils 8 | from sentence_transformers import SentenceTransformer 9 | 10 | class SentBert(nn.Module): 11 | """the retriever module based on pretrained sentence-Bert embeddings""" 12 | def __init__(self, class_num, dictionary, retrieve_embed, 13 | linear_bias=False, stop_grad=False, freeze=False): 14 | 15 | super(SentBert, self).__init__() 16 | 17 | self.dict = dictionary 18 | self.stop_grad = stop_grad 19 | 20 | sent_embed = [] 21 | with h5py.File(retrieve_embed, 'r') as fin: 22 | for i in range(class_num): 23 | sent_embed.append(fin[str(i)].value) 24 | 25 | print('read Bert embed from {} complete!'.format(retrieve_embed)) 26 | 27 | sent_embed = torch.tensor(sent_embed) 28 | nfeat = sent_embed.size(1) 29 | 30 | self.linear1 = nn.Linear(nfeat, nfeat, bias=False) 31 | self.linear2 = nn.Linear(nfeat, class_num, bias=linear_bias) 32 | 33 | self.encoder = SentenceTransformer('bert-base-nli-mean-tokens') 34 | 35 | if stop_grad: 36 | for param in self.encoder.parameters(): 37 | param.requires_grad = False 38 | 39 | with torch.no_grad(): 40 | nn.init.eye_(self.linear1.weight) 41 | self.linear2.weight.copy_(sent_embed) 42 | 43 | if linear_bias: 44 | self.linear2.bias.zero_() 45 | 46 | self.linear2.weight.requires_grad = False 47 | 48 | # self.linear1.weight.requires_grad = False 49 | 50 | if freeze: 51 | for param in self.parameters(): 52 | param.requires_grad = False 53 | 54 | def forward(self, samples, split=None): 55 | """ 56 | Args: 57 | samples (dict): input dict with keys 'net_input', 'id', etc. 
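The retrievers above also support removing templates at test time, either by slicing `linear2.weight` into `prune_linear2_weight` or by pushing the pruned columns of the logits far below the minimum with `index_fill_`. A tiny sketch of the logit-masking variant (the logits and pruned indices are hypothetical):

import torch

logits = torch.randn(2, 6)                   # (batch, num_templates), toy values
prune_index = torch.LongTensor([1, 4])       # templates to disable at test time

# pruned columns drop far below the minimum, so neither sampling nor top-k picks them
logits.index_fill_(1, prune_index, logits.min().item() - 1e3)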
58 | 59 | 60 | Returns: 61 | logits (tensor): shape (B, class_num) 62 | """ 63 | 64 | net_input = samples['net_input'] 65 | x = net_input['src_tokens'] 66 | bs = x.size(0) 67 | sent_strings = [] 68 | for sent in x: 69 | sent = utils.strip_pad(sent, self.dict.pad()) 70 | sent_strings.append(self.dict.string(sent).strip('\n')) 71 | 72 | # (bs x nfeats) 73 | if self.stop_grad: 74 | with torch.no_grad(): 75 | embeddings = self.encoder.encode(sent_strings, 76 | batch_size=bs, 77 | online=True, 78 | show_progress_bar=False, 79 | ) 80 | else: 81 | embeddings = self.encoder.encode(sent_strings, 82 | batch_size=bs, 83 | online=True, 84 | show_progress_bar=False, 85 | ) 86 | 87 | logits = self.linear2(self.linear1(embeddings)) 88 | 89 | # mask itself only during training to mitigate overfitting 90 | # while this is pretty rough, but should be enough considering both time and memory efficiency 91 | 92 | if split == 'train': 93 | with torch.no_grad(): 94 | mask = (self.linear2(embeddings) - (embeddings * embeddings).sum(1).unsqueeze(1)).abs() < 1e-5 95 | 96 | logits = logits.masked_fill(mask, logits.min().item()) 97 | # mask itself 98 | # mask = logits.new_zeros(logits.size(), dtype=torch.bool) 99 | # for i, id_ in enumerate(samples['id']): 100 | # mask[i, id_.item()] = 1 101 | # logits = logits.masked_fill(mask, logits.min().item()) 102 | 103 | return logits 104 | 105 | -------------------------------------------------------------------------------- /sparse_prototype/sp_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import math 8 | import copy 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from fairseq import metrics, utils 14 | from fairseq.criterions import LegacyFairseqCriterion, FairseqCriterion, register_criterion 15 | 16 | 17 | def apply_to_sample(f, sample): 18 | if len(sample) == 0: 19 | return {} 20 | 21 | def _apply(x): 22 | if torch.is_tensor(x): 23 | return f(x) 24 | elif isinstance(x, dict): 25 | return {key: _apply(value) for key, value in x.items()} 26 | elif isinstance(x, list): 27 | return [_apply(x) for x in x] 28 | else: 29 | return x 30 | 31 | return _apply(sample) 32 | 33 | 34 | def move_to_cpu(sample): 35 | def _move_to_cpu(tensor): 36 | return tensor.cpu() 37 | 38 | return apply_to_sample(_move_to_cpu, sample) 39 | 40 | 41 | def prepare_sample(sample, cuda=True, fp16=False): 42 | if sample == "DUMMY": 43 | raise Exception( 44 | "Trying to use an uninitialized 'dummy' batch. This usually indicates " 45 | "that the total number of batches is smaller than the number of " 46 | "participating GPUs. Try reducing the batch size or using fewer GPUs." 
47 | ) 48 | 49 | if sample is None or len(sample) == 0: 50 | return None 51 | 52 | if cuda: 53 | sample = utils.move_to_cuda(sample) 54 | 55 | def apply_half(t): 56 | if t.dtype is torch.float32: 57 | return t.half() 58 | return t 59 | 60 | if fp16: 61 | sample = utils.apply_to_sample(apply_half, sample) 62 | 63 | return sample 64 | 65 | 66 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): 67 | """compute labeled smoothed nll loss 68 | Returns: 69 | loss: the actual loss to be optimized (after smoothing), with 70 | shape (batch) if reduce is true else (batch, seq_len) 71 | nll_loss: the NLL loss with shape (batch) if reduce is true else 72 | (batch, seq_len) 73 | """ 74 | if target.dim() == lprobs.dim() - 1: 75 | target = target.unsqueeze(-1) 76 | nll_loss = -lprobs.gather(dim=-1, index=target) 77 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 78 | if ignore_index is not None: 79 | pad_mask = target.eq(ignore_index) 80 | if pad_mask.any(): 81 | nll_loss.masked_fill_(pad_mask, 0.) 82 | smooth_loss.masked_fill_(pad_mask, 0.) 83 | 84 | nll_loss = nll_loss.squeeze(-1) 85 | smooth_loss = smooth_loss.squeeze(-1) 86 | 87 | # (batch, seq_len) --> (batch) 88 | if reduce: 89 | nll_loss = nll_loss.sum(-1) 90 | smooth_loss = smooth_loss.sum(-1) 91 | eps_i = epsilon / lprobs.size(-1) 92 | loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss 93 | return loss, nll_loss 94 | 95 | def write_loss(ll_batch, sample, infer_ns, fout): 96 | revert_order = sample['net_input']['revert_order'] 97 | id_list = sample['tgt_id'].index_select(0, revert_order).view(-1, infer_ns)[:, 0] 98 | length_list = sample['net_input']['src_lengths'].index_select(0, revert_order).view(-1, infer_ns)[:, 0] 99 | 100 | for id_, ntoken, ll in zip(id_list, length_list, ll_batch): 101 | fout.write('{} {} {}\n'.format(id_.item(), ntoken.item(), ll.item())) 102 | 103 | 104 | @register_criterion('sp_elbo') 105 | class SupportPrototypeELBO(LegacyFairseqCriterion): 106 | 107 | def __init__(self, args, task): 108 | super().__init__(args, task) 109 | self.eps = args.label_smoothing 110 | self.free_bits = args.free_bits 111 | 112 | if args.write_loss_path is not None and args.eval_mode != 'time': 113 | self.f_loss = open(os.path.join(args.save_dir, args.write_loss_path), 'w') 114 | else: 115 | self.f_loss = None 116 | 117 | @staticmethod 118 | def add_args(parser): 119 | """Add criterion-specific arguments to the parser.""" 120 | # fmt: off 121 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 122 | help='epsilon for label smoothing, 0 means no label smoothing') 123 | # fmt: on 124 | 125 | def forward(self, model, sample, data_len, reduce=True): 126 | """Compute the loss for the given sample. 
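`label_smoothed_nll_loss` above mixes the usual NLL with a uniform smoothing term, `(1 - eps) * nll + (eps / V) * smooth`, and sums over the sequence when `reduce` is true. A toy re-derivation of the per-token computation (vocabulary size, targets, and `eps` are arbitrary):

import torch
import torch.nn.functional as F

eps, V = 0.1, 5
lprobs = F.log_softmax(torch.randn(1, 3, V), dim=-1)   # (batch, seq_len, vocab)
target = torch.LongTensor([[0, 2, 4]])                 # (batch, seq_len)

nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)   # per-token NLL
smooth = -lprobs.sum(dim=-1)                                           # sum over the vocabulary
loss = (1. - eps) * nll + (eps / V) * smooth     # label-smoothed per-token loss
loss_per_sent = loss.sum(-1)                     # reduce=True sums over the sequence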
127 | 128 | Returns a tuple with three elements: 129 | 1) the loss 130 | 2) the sample size, which is used as the denominator for the gradient 131 | 3) logging outputs to display while training 132 | """ 133 | net_output = model(**sample['net_input'], data_len=data_len) 134 | loss, neg_elbo, recon_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 135 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 136 | 137 | nsentences = sample['target'].size(0) / model.infer_ns 138 | lambda_stats = model.measure_lambda_sparsity() 139 | logging_output = { 140 | 'loss': utils.item(loss.data) if reduce else loss.data, 141 | 'neg_elbo': utils.item(neg_elbo.data) if reduce else neg_elbo.data, 142 | 'recon_loss': utils.item(recon_loss.data) if reduce else recon_loss.data, 143 | 'ntokens': sample['ntokens'] / model.infer_ns, 144 | 'nsentences': nsentences, 145 | 'sample_size': sample_size / model.infer_ns, 146 | 'KLz': utils.item(net_output['KLz'].sum().data / model.infer_ns), 147 | 'KLt': utils.item(net_output['KLt'].sum().data), 148 | 'KLtheta': utils.item(net_output['KLtheta'] * nsentences), 149 | 'lambda_t': model.lambda_t, 150 | } 151 | 152 | logging_output.update(lambda_stats) 153 | return loss, sample_size, logging_output 154 | 155 | # compute the ELBO loss, involving reinforcement learning 156 | def compute_loss(self, model, net_output, sample, reduce=True): 157 | lprobs = model.get_normalized_probs(net_output['recon_out'], log_probs=True) 158 | # lprobs = lprobs.view(-1, lprobs.size(-1)) 159 | target = model.get_targets(sample, net_output) 160 | smoothed_nll_loss, nll_loss = label_smoothed_nll_loss( 161 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 162 | ) 163 | 164 | revert_order = sample['net_input']['revert_order'] 165 | 166 | KLz = net_output['KLz'] 167 | KLt = net_output['KLt'] 168 | KLtheta = net_output['KLtheta'] 169 | logq = net_output['logq'] 170 | 171 | nll_loss = nll_loss.index_select(0, revert_order) 172 | smoothed_nll_loss = smoothed_nll_loss.index_select(0, revert_order) 173 | KLz = KLz.index_select(0, revert_order) 174 | 175 | cls_reward = (nll_loss + KLz).detach() 176 | cls_reward_reshape = cls_reward.view(-1, model.infer_ns) 177 | # use average reward as the baseline, shape (batch) 178 | cls_reward = cls_reward_reshape - cls_reward_reshape.mean(dim=1, keepdim=True) 179 | # cls_reward = cls_reward_reshape - cls_reward_reshape.mean().item() 180 | 181 | nsentences = sample['target'].size(0) / model.infer_ns 182 | 183 | if self.free_bits > 0: 184 | lower_bound = KLt.new_full(KLt.size(), self.free_bits) 185 | KLt_fake, _ = torch.stack((KLt, lower_bound)).max(dim=0) 186 | else: 187 | KLt_fake = KLt 188 | 189 | loss = ((cls_reward * logq).mean(1) + model.lambda_t * KLt_fake + 190 | smoothed_nll_loss.view(-1, model.infer_ns).mean(1)).sum() + KLtheta * nsentences 191 | 192 | with torch.no_grad(): 193 | neg_elbo = ((nll_loss + KLz).view(-1, model.infer_ns).mean(1) + KLt).sum() + KLtheta * nsentences 194 | 195 | return loss, neg_elbo, nll_loss.view(-1, model.infer_ns).mean(1).sum() 196 | 197 | def iw_eval(self, model, sample, data_len, iw_nsample, retrieve_dataset, reduce=True): 198 | """Compute the importance-weighted loss for the given sample. 
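`compute_loss` above trains the retriever with a score-function (REINFORCE) term: the detached reconstruction-plus-KLz cost of each of the `infer_ns` sampled prototypes is centered by its per-example mean before being multiplied with `log q`, and the prototype KL gets a free-bits floor. A toy sketch of those two pieces (all tensors below are fabricated):

import torch

infer_ns, free_bits = 4, 0.5                          # toy values
reward = torch.randn(3, infer_ns)                     # stand-in for detached (nll_loss + KLz) per sampled prototype
logq = torch.randn(3, infer_ns, requires_grad=True)   # log q(t|x) of the sampled prototypes

# mean reward over the infer_ns samples of the same example acts as the baseline
centered = reward - reward.mean(dim=1, keepdim=True)
score_fn_term = (centered * logq).mean(1)             # REINFORCE surrogate, shape (batch,)

# free-bits floor on the prototype KL term
KLt = torch.rand(3)
KLt_floored, _ = torch.stack((KLt, KLt.new_full(KLt.size(), free_bits))).max(dim=0)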
199 | 200 | Returns a tuple with three elements: 201 | 1) the loss 202 | 2) the sample size, which is used as the denominator for the gradient 203 | 3) logging outputs to display while training 204 | """ 205 | 206 | tmp = [] 207 | new_sample = sample 208 | cuda = next(model.parameters()).is_cuda 209 | for _ in range(int(iw_nsample / model.infer_ns)): 210 | net_output = model.iw_forward(**new_sample['net_input'], data_len=data_len) 211 | 212 | # log [p(x, t, z) / q(t, z |x)] 213 | # (batch, infer_ns) 214 | log_ratio = self._compulte_iw_loss(model, net_output, new_sample, reduce=reduce) 215 | tmp.append(log_ratio) 216 | 217 | sample_orig_cpu = move_to_cpu(sample['sample_orig']) 218 | new_sample = retrieve_dataset.collater(sample_orig_cpu) 219 | new_sample = prepare_sample(new_sample, cuda=cuda, fp16=self.args.fp16) 220 | 221 | 222 | 223 | # (batch) 224 | ll_iw = torch.logsumexp(torch.cat(tmp, dim=-1), dim=-1) - math.log(iw_nsample) 225 | if self.f_loss is not None: 226 | write_loss(ll_iw, sample, model.infer_ns, self.f_loss) 227 | ll_iw = -ll_iw.sum() 228 | 229 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 230 | 231 | nsentences = sample['target'].size(0) / model.infer_ns 232 | 233 | logging_output = { 234 | 'nll_iw': utils.item(ll_iw.data) if reduce else ll_iw.data, 235 | 'ntokens': sample['ntokens'] / model.infer_ns, 236 | 'nsentences': nsentences, 237 | 'sample_size': sample_size / model.infer_ns, 238 | } 239 | 240 | return ll_iw, sample_size, logging_output 241 | 242 | def iw_eval_new(self, model, sample, data_len, iw_nsample, retrieve_dataset, reduce=True): 243 | """Compute the importance-weighted loss for the given sample. 244 | 245 | Returns a tuple with three elements: 246 | 1) the loss 247 | 2) the sample size, which is used as the denominator for the gradient 248 | 3) logging outputs to display while training 249 | """ 250 | 251 | tmp = [] 252 | for _ in range(int(iw_nsample / model.infer_ns)): 253 | net_output = model.iw_forward(**sample['net_input'], data_len=data_len) 254 | 255 | # log [p(x, t, z) / q(t, z |x)] 256 | # (batch, infer_ns) 257 | log_ratio = self._compulte_iw_loss(model, net_output, sample, reduce=reduce) 258 | tmp.append(log_ratio.unsqueeze(-1)) 259 | 260 | tmp_cat = torch.cat(tmp, dim=-1) 261 | 262 | # (batch, infer_ns) 263 | ll_iw_z = torch.logsumexp(tmp_cat, dim=-1) - math.log(tmp_cat.size(-1)) 264 | 265 | # (batch) 266 | ll_iw = torch.logsumexp(net_output['log_pt'] + ll_iw_z, dim=1) 267 | 268 | if self.f_loss is not None: 269 | write_loss(ll_iw, sample, model.infer_ns, self.f_loss) 270 | ll_iw = -ll_iw.sum() 271 | 272 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 273 | 274 | nsentences = sample['target'].size(0) / model.infer_ns 275 | 276 | logging_output = { 277 | 'nll_iw': utils.item(ll_iw.data) if reduce else ll_iw.data, 278 | 'ntokens': sample['ntokens'] / model.infer_ns, 279 | 'nsentences': nsentences, 280 | 'sample_size': sample_size / model.infer_ns, 281 | } 282 | 283 | return ll_iw, sample_size, logging_output 284 | 285 | def _compulte_iw_loss(self, model, net_output, sample, reduce=True): 286 | """compute the importance weighted loss 287 | """ 288 | lprobs = model.get_normalized_probs(net_output['recon_out'], log_probs=True) 289 | # lprobs = lprobs.view(-1, lprobs.size(-1)) 290 | target = model.get_targets(sample, net_output) 291 | smoothed_nll_loss, nll_loss = label_smoothed_nll_loss( 292 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 293 
| ) 294 | 295 | revert_order = sample['net_input']['revert_order'] 296 | 297 | # (batch, infer_ns) 298 | log_pxtz = -nll_loss.index_select(0, revert_order).view(-1, model.infer_ns) 299 | 300 | # log_ratio = net_output['log_pz'] + net_output['log_pt'] + log_pxtz \ 301 | # - net_output['log_qz'] - net_output['log_qt'] 302 | log_ratio = net_output['log_pz'] + log_pxtz \ 303 | - net_output['log_qz'] 304 | 305 | return log_ratio 306 | 307 | def entropy_eval(self, model, sample, data_len, reduce=True): 308 | """Compute the importance-weighted loss for the given sample. 309 | 310 | Returns a tuple with three elements: 311 | 1) the loss 312 | 2) the sample size, which is used as the denominator for the gradient 313 | 3) logging outputs to display while training 314 | """ 315 | 316 | net_output = model.entropy_forward(**sample['net_input'], data_len=data_len) 317 | entropy = net_output['entropy'].sum() 318 | 319 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 320 | 321 | nsentences = sample['target'].size(0) / model.infer_ns 322 | 323 | logging_output = { 324 | 'entropy': utils.item(entropy.data) if reduce else entropy.data, 325 | 'ntokens': sample['ntokens'] / model.infer_ns, 326 | 'nsentences': nsentences, 327 | 'sample_size': sample_size / model.infer_ns, 328 | } 329 | 330 | return 0, sample_size, logging_output 331 | 332 | @staticmethod 333 | def reduce_metrics(logging_outputs) -> None: 334 | """Aggregate logging outputs from data parallel training.""" 335 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 336 | neg_elbo_sum = sum(log.get('neg_elbo', 0) for log in logging_outputs) 337 | recon_loss_sum = sum(log.get('recon_loss', 0) for log in logging_outputs) 338 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 339 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 340 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 341 | KLz_sum = sum(log.get('KLz', 0) for log in logging_outputs) 342 | KLt_sum = sum(log.get('KLt', 0) for log in logging_outputs) 343 | KLtheta_sum = sum(log.get('KLtheta', 0) for log in logging_outputs) 344 | 345 | if 'nll_iw' in logging_outputs[0]: 346 | nll_iw_sum = sum(log.get('nll_iw', 0) for log in logging_outputs) 347 | metrics.log_scalar('nll_iw_s', nll_iw_sum / nsentences, 348 | nsentences, round=3, priority=4) 349 | metrics.log_scalar('nll_iw_t', nll_iw_sum / ntokens / math.log(2), 350 | ntokens, round=3, priority=5) 351 | metrics.log_derived('ppl_iw', lambda meters: utils.get_perplexity(meters['nll_iw_t'].avg), priority=6) 352 | elif 'entropy' in logging_outputs[0]: 353 | entropy_sum = sum(log.get('entropy', 0) for log in logging_outputs) 354 | metrics.log_scalar('entropy_s', entropy_sum / nsentences, 355 | nsentences, round=3, priority=4) 356 | 357 | else: 358 | metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), 359 | sample_size, round=3, priority=3) 360 | 361 | metrics.log_scalar('neg_elbo_s', neg_elbo_sum / nsentences, 362 | nsentences, round=3, priority=4) 363 | metrics.log_scalar('recon_loss_s', recon_loss_sum / nsentences, 364 | nsentences, round=3, priority=4) 365 | 366 | metrics.log_scalar('neg_elbo_t', neg_elbo_sum / ntokens / math.log(2), 367 | ntokens, round=3, priority=5) 368 | metrics.log_scalar('recon_loss_t', recon_loss_sum / ntokens / math.log(2), 369 | ntokens, round=3, priority=5) 370 | 371 | metrics.log_scalar('KLz', KLz_sum / nsentences, nsentences, round=1, priority=8) 372 | metrics.log_scalar('KLt', KLt_sum / 
nsentences, nsentences, round=1, priority=8) 373 | metrics.log_scalar('KLtheta', KLtheta_sum / nsentences, nsentences, round=1, priority=8) 374 | 375 | metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['neg_elbo_t'].avg), priority=6) 376 | metrics.log_derived('recon_ppl', lambda meters: utils.get_perplexity(meters['recon_loss_t'].avg), priority=7) 377 | 378 | 379 | if 'lambda_t' in logging_outputs[0]: 380 | metrics.log_scalar('lambda_t', logging_outputs[0]['lambda_t'], weight=0, round=2, priority=10) 381 | 382 | if 'active' in logging_outputs[0]: 383 | metrics.log_scalar('active', logging_outputs[0]['active'], weight=0, round=1, priority=10) 384 | metrics.log_scalar('percent', logging_outputs[0]['percent'], weight=0, round=2, priority=10) 385 | # metrics.log_scalar('nlow', logging_outputs[0]['nlow'], weight=0, priority=10) 386 | # metrics.log_scalar('nhigh', logging_outputs[0]['nhigh'], weight=0, priority=10) 387 | 388 | @staticmethod 389 | def logging_outputs_can_be_summed() -> bool: 390 | """ 391 | Whether the logging outputs returned by `forward` can be summed 392 | across workers prior to calling `reduce_metrics`. Setting this 393 | to True will improves distributed training speed. 394 | """ 395 | return False 396 | -------------------------------------------------------------------------------- /sparse_prototype/sp_hub_interface.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import logging 4 | import os 5 | from typing import List, Dict, Iterator, Tuple, Any 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from fairseq import utils 11 | from fairseq import hub_utils 12 | from fairseq.data import encoders 13 | 14 | 15 | # logger = logging.getLogger(__name__) 16 | 17 | 18 | class TemplateHubInterface(hub_utils.GeneratorHubInterface): 19 | """The hub interface 20 | """ 21 | 22 | def __init__(self, args, task, models): 23 | super().__init__(args, task, models) 24 | 25 | def generate( 26 | self, 27 | tokenized_sentences: List[torch.LongTensor], 28 | beam: int = 5, 29 | verbose: bool = False, 30 | skip_invalid_size_inputs=False, 31 | inference_step_args=None, 32 | **kwargs, 33 | ) -> List[List[Dict[str, torch.Tensor]]]: 34 | if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1: 35 | return self.generate( 36 | tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs 37 | )[0] 38 | 39 | # build generator using current args as well as any kwargs 40 | gen_args = copy.copy(self.args) 41 | gen_args.beam = beam 42 | for k, v in kwargs.items(): 43 | setattr(gen_args, k, v) 44 | generator = self.task.build_generator(gen_args) 45 | 46 | results = [] 47 | for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): 48 | batch = utils.apply_to_sample(lambda t: t.to(self.device), batch) 49 | translations = self.task.inference_step(generator, self.models, batch, **inference_step_args) 50 | for id, hypos in zip(batch["id"].tolist(), translations): 51 | results.append((id, hypos)) 52 | 53 | # sort output to match input order 54 | outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])] 55 | 56 | if verbose: 57 | 58 | def getarg(name, default): 59 | return getattr(gen_args, name, getattr(self.args, name, default)) 60 | 61 | for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): 62 | src_str_with_unk = self.string(source_tokens) 63 | logger.info('S\t{}'.format(src_str_with_unk)) 64 | for hypo in target_hypotheses: 65 | 
hypo_str = self.decode(hypo['tokens']) 66 | logger.info('H\t{}\t{}'.format(hypo['score'], hypo_str)) 67 | logger.info('P\t{}'.format( 68 | ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist())) 69 | )) 70 | if hypo['alignment'] is not None and getarg('print_alignment', False): 71 | logger.info('A\t{}'.format( 72 | ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu())) 73 | )) 74 | return outputs 75 | -------------------------------------------------------------------------------- /sparse_prototype/topk_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from fairseq import metrics, utils 12 | from fairseq.criterions import LegacyFairseqCriterion, FairseqCriterion, register_criterion 13 | 14 | 15 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): 16 | """compute labeled smoothed nll loss 17 | Returns: 18 | loss: the actual loss to be optimized (after smoothing), with 19 | shape (batch) if reduce is true else (batch, seq_len) 20 | nll_loss: the NLL loss with shape (batch) if reduce is true else 21 | (batch, seq_len) 22 | """ 23 | if target.dim() == lprobs.dim() - 1: 24 | target = target.unsqueeze(-1) 25 | nll_loss = -lprobs.gather(dim=-1, index=target) 26 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 27 | if ignore_index is not None: 28 | pad_mask = target.eq(ignore_index) 29 | if pad_mask.any(): 30 | nll_loss.masked_fill_(pad_mask, 0.) 31 | smooth_loss.masked_fill_(pad_mask, 0.) 32 | 33 | nll_loss = nll_loss.squeeze(-1) 34 | smooth_loss = smooth_loss.squeeze(-1) 35 | 36 | # (batch, seq_len) --> (batch) 37 | if reduce: 38 | nll_loss = nll_loss.sum(-1) 39 | smooth_loss = smooth_loss.sum(-1) 40 | eps_i = epsilon / lprobs.size(-1) 41 | loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss 42 | return loss, nll_loss 43 | 44 | 45 | @register_criterion('topk_elbo') 46 | class TopkELBO(LegacyFairseqCriterion): 47 | 48 | def __init__(self, args, task): 49 | super().__init__(args, task) 50 | self.eps = args.label_smoothing 51 | self.free_bits = args.free_bits 52 | 53 | @staticmethod 54 | def add_args(parser): 55 | """Add criterion-specific arguments to the parser.""" 56 | # fmt: off 57 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 58 | help='epsilon for label smoothing, 0 means no label smoothing') 59 | # fmt: on 60 | 61 | def forward(self, model, sample, data_len, reduce=True): 62 | """Compute the loss for the given sample. 
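In the top-k criterion, `compute_loss` further below weights the per-prototype reconstruction losses by the softmax-normalized retrieval scores of the K retained prototypes instead of sampling them. A small sketch of that weighting (batch size, K, and all values are invented):

import torch
import torch.nn.functional as F

batch, K = 2, 3
per_proto_loss = torch.rand(batch * K)       # smoothed NLL + KLz for each (example, prototype) pair
logits_topk = torch.randn(batch, K)          # retrieval scores of the K prototypes kept per example

weights = F.softmax(logits_topk, dim=1)      # normalize over the K prototypes
loss = (per_proto_loss.view(batch, K) * weights).sum(1)   # (batch,) weighted reconstruction term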
63 | 64 | Returns a tuple with three elements: 65 | 1) the loss 66 | 2) the sample size, which is used as the denominator for the gradient 67 | 3) logging outputs to display while training 68 | """ 69 | net_output = model.topk_forward(**sample['net_input'], data_len=data_len) 70 | loss, neg_elbo, recon_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 71 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 72 | 73 | nsentences = sample['target'].size(0) / model.infer_ns 74 | lambda_stats = model.measure_lambda_sparsity() 75 | logging_output = { 76 | 'loss': utils.item(loss.data) if reduce else loss.data, 77 | 'neg_elbo': utils.item(neg_elbo.data) if reduce else neg_elbo.data, 78 | 'recon_loss': utils.item(recon_loss.data) if reduce else recon_loss.data, 79 | 'ntokens': sample['ntokens'] / model.infer_ns, 80 | 'nsentences': nsentences, 81 | 'sample_size': sample_size / model.infer_ns, 82 | 'KLz': utils.item(net_output['KLz'].sum().data / model.infer_ns), 83 | 'KLt': utils.item(net_output['KLt'].sum().data), 84 | 'KLtheta': utils.item(net_output['KLtheta'] * nsentences) 85 | } 86 | 87 | logging_output.update(lambda_stats) 88 | return loss, sample_size, logging_output 89 | 90 | # compute the ELBO loss, involving reinforcement learning 91 | def compute_loss(self, model, net_output, sample, reduce=True): 92 | lprobs = model.get_normalized_probs(net_output['recon_out'], log_probs=True) 93 | # lprobs = lprobs.view(-1, lprobs.size(-1)) 94 | target = model.get_targets(sample, net_output) 95 | smoothed_nll_loss, nll_loss = label_smoothed_nll_loss( 96 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 97 | ) 98 | 99 | revert_order = sample['net_input']['revert_order'] 100 | 101 | KLz = net_output['KLz'] 102 | KLt = net_output['KLt'] 103 | KLtheta = net_output['KLtheta'] 104 | logits_topk = net_output['logits_topk'] 105 | # logq = net_output['logq'] 106 | 107 | nll_loss = nll_loss.index_select(0, revert_order) 108 | smoothed_nll_loss = smoothed_nll_loss.index_select(0, revert_order) 109 | KLz = KLz.index_select(0, revert_order) 110 | 111 | nsentences = sample['target'].size(0) / model.infer_ns 112 | 113 | if self.free_bits > 0: 114 | lower_bound = KLt.new_full(KLt.size(), self.free_bits) 115 | KLt_fake, _ = torch.stack((KLt, lower_bound)).max(dim=0) 116 | else: 117 | KLt_fake = KLt 118 | 119 | if model.training: 120 | loss = (((smoothed_nll_loss + KLz).view(-1, model.infer_ns) * F.softmax(logits_topk, dim=1)).sum(1) 121 | + KLt_fake).sum() + KLtheta * nsentences 122 | else: 123 | loss = ((nll_loss + KLz).view(-1, model.infer_ns).mean(1) + KLt).sum() + KLtheta * nsentences 124 | 125 | return loss, loss, nll_loss.view(-1, model.infer_ns).mean(1).sum() 126 | 127 | def iw_eval(self, model, sample, data_len, iw_nsample, reduce=True): 128 | """Compute the importance-weighted loss for the given sample. 
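`iw_eval` estimates the marginal log-likelihood by importance sampling: it concatenates `log [p(x, t, z) / q(t, z | x)]` across draws and applies a log-mean-exp, mirroring the `logsumexp(...) - log(iw_nsample)` line below. A minimal numerical sketch with fabricated log-ratios:

import math
import torch

iw_nsample = 8
log_ratios = torch.randn(4, iw_nsample)      # (batch, samples), stand-in for log p(x,t,z) - log q(t,z|x)

# log-mean-exp over samples gives the importance-weighted log-likelihood estimate
ll_iw = torch.logsumexp(log_ratios, dim=-1) - math.log(iw_nsample)
nll_iw = -ll_iw.sum()                        # summed negative log-likelihood, as reported under 'nll_iw'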
129 | 130 | Returns a tuple with three elements: 131 | 1) the loss 132 | 2) the sample size, which is used as the denominator for the gradient 133 | 3) logging outputs to display while training 134 | """ 135 | 136 | tmp = [] 137 | for _ in range(int(iw_nsample / model.infer_ns)): 138 | net_output = model.iw_forward(**sample['net_input'], data_len=data_len) 139 | 140 | # log [p(x, t, z) / q(t, z |x)] 141 | # (batch, infer_ns) 142 | log_ratio = self._compulte_iw_loss(model, net_output, sample, reduce=reduce) 143 | tmp.append(log_ratio) 144 | 145 | # (batch) 146 | ll_iw = torch.logsumexp(torch.cat(tmp, dim=-1), dim=-1) - math.log(iw_nsample) 147 | ll_iw = -ll_iw.sum() 148 | 149 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 150 | 151 | nsentences = sample['target'].size(0) / model.infer_ns 152 | 153 | logging_output = { 154 | 'nll_iw': utils.item(ll_iw.data) if reduce else ll_iw.data, 155 | 'ntokens': sample['ntokens'] / model.infer_ns, 156 | 'nsentences': nsentences, 157 | 'sample_size': sample_size / model.infer_ns, 158 | } 159 | 160 | return ll_iw, sample_size, logging_output 161 | 162 | def _compulte_iw_loss(self, model, net_output, sample, reduce=True): 163 | """compute the importance weighted loss 164 | """ 165 | lprobs = model.get_normalized_probs(net_output['recon_out'], log_probs=True) 166 | # lprobs = lprobs.view(-1, lprobs.size(-1)) 167 | target = model.get_targets(sample, net_output) 168 | smoothed_nll_loss, nll_loss = label_smoothed_nll_loss( 169 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 170 | ) 171 | 172 | revert_order = sample['net_input']['revert_order'] 173 | 174 | # (batch, infer_ns) 175 | log_pxtz = -nll_loss.index_select(0, revert_order).view(-1, model.infer_ns) 176 | 177 | log_ratio = net_output['log_pz'] + net_output['log_pt'] + log_pxtz \ 178 | - net_output['log_qz'] - net_output['log_qt'] 179 | 180 | return log_ratio 181 | 182 | @staticmethod 183 | def reduce_metrics(logging_outputs) -> None: 184 | """Aggregate logging outputs from data parallel training.""" 185 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 186 | neg_elbo_sum = sum(log.get('neg_elbo', 0) for log in logging_outputs) 187 | recon_loss_sum = sum(log.get('recon_loss', 0) for log in logging_outputs) 188 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 189 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 190 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 191 | KLz_sum = sum(log.get('KLz', 0) for log in logging_outputs) 192 | KLt_sum = sum(log.get('KLt', 0) for log in logging_outputs) 193 | KLtheta_sum = sum(log.get('KLtheta', 0) for log in logging_outputs) 194 | 195 | if 'nll_iw' in logging_outputs[0]: 196 | nll_iw_sum = sum(log.get('nll_iw', 0) for log in logging_outputs) 197 | metrics.log_scalar('nll_iw_s', nll_iw_sum / nsentences, 198 | nsentences, round=3, priority=4) 199 | metrics.log_scalar('nll_iw_t', nll_iw_sum / ntokens / math.log(2), 200 | ntokens, round=3, priority=5) 201 | metrics.log_derived('ppl_iw', lambda meters: utils.get_perplexity(meters['nll_iw_t'].avg), priority=6) 202 | 203 | else: 204 | metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), 205 | sample_size, round=3, priority=3) 206 | 207 | metrics.log_scalar('neg_elbo_s', neg_elbo_sum / nsentences, 208 | nsentences, round=3, priority=4) 209 | metrics.log_scalar('recon_loss_s', recon_loss_sum / nsentences, 210 | nsentences, round=3, priority=4) 211 | 212 | 
metrics.log_scalar('neg_elbo_t', neg_elbo_sum / ntokens / math.log(2), 213 | ntokens, round=3, priority=5) 214 | metrics.log_scalar('recon_loss_t', recon_loss_sum / ntokens / math.log(2), 215 | ntokens, round=3, priority=5) 216 | 217 | metrics.log_scalar('KLz', KLz_sum / nsentences, nsentences, round=1, priority=8) 218 | metrics.log_scalar('KLt', KLt_sum / nsentences, nsentences, round=1, priority=8) 219 | metrics.log_scalar('KLtheta', KLtheta_sum / nsentences, nsentences, round=1, priority=8) 220 | 221 | metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['neg_elbo_t'].avg), priority=6) 222 | metrics.log_derived('recon_ppl', lambda meters: utils.get_perplexity(meters['recon_loss_t'].avg), priority=7) 223 | 224 | if 'active' in logging_outputs[0]: 225 | metrics.log_scalar('active', logging_outputs[0]['active'], weight=0, round=1, priority=10) 226 | metrics.log_scalar('percent', logging_outputs[0]['percent'], weight=0, round=2, priority=10) 227 | # metrics.log_scalar('nlow', logging_outputs[0]['nlow'], weight=0, priority=10) 228 | # metrics.log_scalar('nhigh', logging_outputs[0]['nhigh'], weight=0, priority=10) 229 | 230 | @staticmethod 231 | def logging_outputs_can_be_summed() -> bool: 232 | """ 233 | Whether the logging outputs returned by `forward` can be summed 234 | across workers prior to calling `reduce_metrics`. Setting this 235 | to True will improves distributed training speed. 236 | """ 237 | return True 238 | -------------------------------------------------------------------------------- /sparse_prototype/vae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .distribution import vMF 6 | 7 | class VAEEncoder(nn.Module): 8 | """VAEEncoder class""" 9 | def __init__(self, encoder, hidden_dim, latent_dim, 10 | nsamples=1, dist="vmf", kappa=1, cuda=True 11 | ): 12 | super(VAEEncoder, self).__init__() 13 | self.encoder = encoder 14 | 15 | # self.args = args 16 | 17 | self.nsamples = nsamples 18 | self.kappa = kappa 19 | 20 | if dist == "vmf": 21 | self.dist = vMF(hidden_dim, latent_dim, kappa, cuda=cuda) 22 | else: 23 | raise ValueError("the {} distribution is not supported".format(dist)) 24 | 25 | 26 | # loc = torch.zeros(self.nz, device=args.device) 27 | # scale = torch.ones(self.nz, device=args.device) 28 | 29 | # self.prior = torch.distributions.normal.Normal(loc, scale) 30 | 31 | def forward(self, src_tokens, src_lengths, temp_tokens, temp_lengths, **kwargs): 32 | """ 33 | Args: 34 | src_tokens (LongTensor): (batch, seq_len) 35 | temp_tokens (LongTensor): (batch, seq_len) 36 | 37 | Returns: Tensor1, Tensor2 38 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 39 | Tensor2: the tenor of KL for each x with shape [batch] 40 | """ 41 | bs, seq_len = src_tokens.size() 42 | 43 | # (batch, hidden_dim) 44 | hidden_code = self.encoder(src_tokens=src_tokens, 45 | src_lengths=src_lengths, 46 | temp_tokens=temp_tokens, 47 | temp_lengths=temp_lengths, 48 | **kwargs) 49 | 50 | # ret_dict: a dict that contains parameters of the approx posterior 51 | # KL: (batch) 52 | # z: (batch, nsamples, nz) 53 | ret_dict, KL, z = self.dist.build_bow_rep(hidden_code, self.nsamples) 54 | 55 | return z, KL, ret_dict 56 | 57 | def log_prior_vmf_density(self, x): 58 | """compute log density under the uniform vmf prior 59 | Args: 60 | x: tensor with shape (batch, *, latent_dim) 61 | """ 62 | 63 | 64 | return self.dist.log_density(0, x) 65 | 66 | def log_vmf_density(self, x, 
mu): 67 | 68 | return self.dist.log_density(self.kappa, x, mu) 69 | 70 | 71 | 72 | def encode_stats(self, x): 73 | """ 74 | Returns: Tensor1, Tensor2 75 | Tensor1: the mean of latent z with shape [batch, nz] 76 | Tensor2: the logvar of latent z with shape [batch, nz] 77 | """ 78 | 79 | return self.encoder(x) 80 | 81 | def decode(self, z, strategy, K=5): 82 | """generate samples from z given strategy 83 | Args: 84 | z: [batch, nsamples, nz] 85 | strategy: "beam" or "greedy" or "sample" 86 | K: the beam width parameter 87 | Returns: List1 88 | List1: a list of decoded word sequence 89 | """ 90 | 91 | if strategy == "beam": 92 | return self.decoder.beam_search_decode(z, K) 93 | elif strategy == "greedy": 94 | return self.decoder.greedy_decode(z) 95 | elif strategy == "sample": 96 | return self.decoder.sample_decode(z) 97 | else: 98 | raise ValueError("the decoding strategy is not supported") 99 | 100 | def reconstruct(self, x, decoding_strategy="greedy", K=5): 101 | """reconstruct from input x 102 | Args: 103 | x: (batch, *) 104 | decoding_strategy: "beam" or "greedy" or "sample" 105 | K: the beam width parameter (if applicable) 106 | Returns: List1 107 | List1: a list of decoded word sequence 108 | """ 109 | z = self.sample_from_inference(x).squeeze(1) 110 | 111 | return self.decode(z, decoding_strategy, K) 112 | 113 | 114 | def loss(self, x, kl_weight, nsamples=1): 115 | """ 116 | Args: 117 | x: if the data is constant-length, x is the data tensor with 118 | shape (batch, *). Otherwise x is a tuple that contains 119 | the data tensor and length list 120 | Returns: Tensor1, Tensor2, Tensor3 121 | Tensor1: total loss [batch] 122 | Tensor2: reconstruction loss shape [batch] 123 | Tensor3: KL loss shape [batch] 124 | """ 125 | 126 | z, KL = self.encode(x, nsamples) 127 | 128 | # (batch) 129 | reconstruct_err = self.decoder.reconstruct_error(x, z).mean(dim=1) 130 | 131 | 132 | return reconstruct_err + kl_weight * KL, reconstruct_err, KL 133 | 134 | def nll_iw(self, x, nsamples, ns=100): 135 | """compute the importance weighting estimate of the log-likelihood 136 | Args: 137 | x: if the data is constant-length, x is the data tensor with 138 | shape (batch, *). 
Otherwise x is a tuple that contains 139 | the data tensor and length list 140 | nsamples: Int 141 | the number of samples required to estimate marginal data likelihood 142 | Returns: Tensor1 143 | Tensor1: the estimate of log p(x), shape [batch] 144 | """ 145 | 146 | # compute iw in chunks of ns samples to address the memory issue 147 | # nsamples = 500, ns = 100 148 | # nsamples = 500, ns = 10 149 | tmp = [] 150 | for _ in range(int(nsamples / ns)): 151 | # [batch, ns, nz] 152 | # param is the parameters required to evaluate q(z|x) 153 | z, param = self.encoder.sample(x, ns) 154 | 155 | # [batch, ns] 156 | log_comp_ll = self.eval_complete_ll(x, z) 157 | log_infer_ll = self.eval_inference_dist(x, z, param) 158 | 159 | tmp.append(log_comp_ll - log_infer_ll) 160 | 161 | ll_iw = torch.logsumexp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples) 162 | 163 | return -ll_iw 164 | 165 | def KL(self, x): 166 | _, KL = self.encode(x, 1) 167 | 168 | return KL 169 | 170 | def eval_prior_dist(self, zrange): 171 | """evaluate the log-density of the prior at the given z points 172 | Args: 173 | zrange: tensor 174 | different z points that will be evaluated, with 175 | shape (k^2, nz), where k=(zmax - zmin)/space 176 | """ 177 | 178 | # (k^2) 179 | return self.prior.log_prob(zrange).sum(dim=-1) 180 | 181 | def eval_complete_ll(self, x, z): 182 | """compute log p(z,x) 183 | Args: 184 | x: Tensor 185 | input with shape [batch, seq_len] 186 | z: Tensor 187 | evaluation points with shape [batch, nsamples, nz] 188 | Returns: Tensor1 189 | Tensor1: log p(z,x) Tensor with shape [batch, nsamples] 190 | """ 191 | 192 | # [batch, nsamples] 193 | log_prior = self.eval_prior_dist(z) 194 | log_gen = self.eval_cond_ll(x, z) 195 | 196 | return log_prior + log_gen 197 | 198 | def eval_cond_ll(self, x, z): 199 | """compute log p(x|z) 200 | """ 201 | 202 | return self.decoder.log_probability(x, z) 203 | 204 | def eval_log_model_posterior(self, x, grid_z): 205 | """perform grid search to calculate the true posterior 206 | this function computes log p(z|x) 207 | Args: 208 | grid_z: tensor 209 | different z points that will be evaluated, with 210 | shape (k^2, nz), where k=(zmax - zmin)/space 211 | Returns: Tensor 212 | Tensor: the log posterior distribution log p(z|x) with 213 | shape [batch_size, K^2] 214 | """ 215 | try: 216 | batch_size = x.size(0) 217 | except: 218 | batch_size = x[0].size(0) 219 | 220 | # (batch_size, k^2, nz) 221 | grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous() 222 | 223 | # (batch_size, k^2) 224 | log_comp = self.eval_complete_ll(x, grid_z) 225 | 226 | # normalize to posterior 227 | log_posterior = log_comp - torch.logsumexp(log_comp, dim=1, keepdim=True) 228 | 229 | return log_posterior 230 | 231 | def sample_from_prior(self, nsamples): 232 | """sample from the prior distribution 233 | Returns: Tensor 234 | Tensor: samples from prior with shape (nsamples, nz) 235 | """ 236 | return self.prior.sample((nsamples,)) 237 | 238 | 239 | def sample_from_inference(self, x, nsamples=1): 240 | """perform sampling from the inference net 241 | Returns: Tensor 242 | Tensor: samples from the inference net with 243 | shape (batch_size, nsamples, nz) 244 | """ 245 | z, _ = self.encoder.sample(x, nsamples) 246 | 247 | return z 248 | 249 | 250 | def sample_from_posterior(self, x, nsamples): 251 | """perform MH sampling from model posterior 252 | Returns: Tensor 253 | Tensor: samples from model posterior with 254 | shape (batch_size, nsamples, nz) 255 | """ 256 | 257 | # use the samples from inference net as initial
points 258 | # for MCMC sampling. [batch_size, nsamples, nz] 259 | cur = self.encoder.sample_from_inference(x, 1) 260 | cur_ll = self.eval_complete_ll(x, cur) 261 | total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin 262 | samples = [] 263 | for iter_ in range(total_iter): 264 | next = torch.normal(mean=cur, 265 | std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std)) 266 | # [batch_size, 1] 267 | next_ll = self.eval_complete_ll(x, next) 268 | ratio = next_ll - cur_ll 269 | 270 | accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size())) 271 | 272 | uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_() 273 | 274 | # [batch_size, 1] 275 | mask = (uniform_t < accept_prob).float() 276 | 277 | mask_ = mask.unsqueeze(2) 278 | 279 | cur = mask_ * next + (1 - mask_) * cur 280 | cur_ll = mask * next_ll + (1 - mask) * cur_ll 281 | 282 | if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0: 283 | samples.append(cur.unsqueeze(1)) 284 | 285 | 286 | return torch.cat(samples, dim=1) 287 | -------------------------------------------------------------------------------- /sparse_prototype/vmf_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from vae import VAEBase 5 | 6 | class VMFVAE(VAEBase): 7 | """VAE with a vMF latent distribution""" 8 | def __init__(self, encoder, kappa,): 9 | super(VMFVAE, self).__init__(encoder) 10 | 11 | # self.args = args 12 | 13 | # loc = torch.zeros(self.nz, device=args.device) 14 | # scale = torch.ones(self.nz, device=args.device) 15 | 16 | # self.prior = torch.distributions.normal.Normal(loc, scale) 17 | 18 | def encode(self, src_tokens, src_lengths, nsamples=1, **kwargs): 19 | """ 20 | Returns: Tensor1, Tensor2 21 | Tensor1: the latent tensor z with shape [batch, nsamples, nz] 22 | Tensor2: the tensor of KL for each x with shape [batch] 23 | """ 24 | 25 | encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Train a new model on one or across multiple GPUs.
8 | """ 9 | 10 | import logging 11 | import math 12 | import os 13 | import random 14 | import sys 15 | import time 16 | 17 | import numpy as np 18 | import torch 19 | 20 | from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils 21 | from fairseq.data import iterators 22 | from fairseq.logging import meters, metrics, progress_bar 23 | from trainer import Trainer 24 | 25 | 26 | logging.basicConfig( 27 | format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 28 | datefmt='%Y-%m-%d %H:%M:%S', 29 | level=logging.INFO, 30 | stream=sys.stdout, 31 | ) 32 | logger = logging.getLogger('fairseq_cli.train') 33 | 34 | 35 | def main(args, init_distributed=False): 36 | utils.import_user_module(args) 37 | 38 | assert args.max_tokens is not None or args.max_sentences is not None, \ 39 | 'Must specify batch size either with --max-tokens or --max-sentences' 40 | 41 | # Initialize CUDA and distributed training 42 | if torch.cuda.is_available() and not args.cpu: 43 | torch.cuda.set_device(args.device_id) 44 | np.random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | if init_distributed: 47 | args.distributed_rank = distributed_utils.distributed_init(args) 48 | 49 | if distributed_utils.is_master(args): 50 | checkpoint_utils.verify_checkpoint_directory(args.save_dir) 51 | 52 | # Print args 53 | logger.info(args) 54 | 55 | # Setup task, e.g., translation, language modeling, etc. 56 | task = tasks.setup_task(args) 57 | 58 | # Load valid dataset (we load training data below, based on the latest checkpoint) 59 | for valid_sub_split in args.valid_subset.split(','): 60 | task.load_dataset(valid_sub_split, combine=False, epoch=1) 61 | 62 | # Build model and criterion 63 | model = task.build_model(args) 64 | criterion = task.build_criterion(args) 65 | logger.info(model) 66 | logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) 67 | logger.info('num. model params: {} (num. 
trained: {})'.format( 68 | sum(p.numel() for p in model.parameters()), 69 | sum(p.numel() for p in model.parameters() if p.requires_grad), 70 | )) 71 | 72 | # Build trainer 73 | trainer = Trainer(args, task, model, criterion) 74 | logger.info('training on {} GPUs'.format(args.distributed_world_size)) 75 | logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format( 76 | args.max_tokens, 77 | args.max_sentences, 78 | )) 79 | 80 | # Load the latest checkpoint if one is available and restore the 81 | # corresponding train iterator 82 | extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) 83 | 84 | # Train until the learning rate gets too small 85 | max_epoch = args.max_epoch or math.inf 86 | max_update = args.max_update or math.inf 87 | lr = trainer.get_lr() 88 | train_meter = meters.StopwatchMeter() 89 | train_meter.start() 90 | valid_subsets = args.valid_subset.split(',') 91 | 92 | if args.eval_mode != 'none': 93 | start_val_time = time.time() 94 | with torch.no_grad(): 95 | if args.eval_mode != 'entropy': 96 | _ = validate(args, trainer, task, epoch_itr, valid_subsets, args.prune_num) 97 | print('elapsed time (seconds): {}'.format(time.time() - start_val_time)) 98 | 99 | _ = validate_iw(args, trainer, task, epoch_itr, valid_subsets, args.prune_num, mode=args.eval_mode) 100 | return 101 | 102 | 103 | while ( 104 | lr > args.min_lr 105 | and epoch_itr.next_epoch_idx <= max_epoch 106 | and trainer.get_num_updates() < max_update 107 | ): 108 | # train for one epoch 109 | train(args, trainer, task, epoch_itr) 110 | 111 | if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0: 112 | valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) 113 | else: 114 | valid_losses = [None] 115 | 116 | # only use first validation loss to update the learning rate 117 | lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) 118 | 119 | # save checkpoint 120 | if epoch_itr.epoch % args.save_interval == 0: 121 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) 122 | 123 | # early stop 124 | if should_stop_early(args, valid_losses[0]): 125 | logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) 126 | break 127 | 128 | epoch_itr = trainer.get_train_iterator( 129 | epoch_itr.next_epoch_idx, 130 | # sharded data: get train iterator for next epoch 131 | load_dataset=(os.pathsep in getattr(args, 'data', '')), 132 | ) 133 | logger.info('done training in {:.1f} seconds'.format(train_meter.sum)) 134 | 135 | # _ = validate_iw(args, trainer, task, epoch_itr, valid_subsets) 136 | 137 | train_meter.stop() 138 | 139 | 140 | def should_stop_early(args, valid_loss): 141 | # skip check if no validation was done in the current epoch 142 | if valid_loss is None: 143 | return False 144 | if args.patience <= 0: 145 | return False 146 | 147 | def is_better(a, b): 148 | return a > b if args.maximize_best_checkpoint_metric else a < b 149 | 150 | prev_best = getattr(should_stop_early, 'best', None) 151 | if prev_best is None or is_better(valid_loss, prev_best): 152 | should_stop_early.best = valid_loss 153 | should_stop_early.num_runs = 0 154 | return False 155 | else: 156 | should_stop_early.num_runs += 1 157 | return should_stop_early.num_runs > args.patience 158 | 159 | 160 | @metrics.aggregate('train') 161 | def train(args, trainer, task, epoch_itr): 162 | """Train the model for one epoch.""" 163 | # Initialize data iterator 164 | itr = epoch_itr.next_epoch_itr( 165 | 
fix_batches_to_gpus=args.fix_batches_to_gpus, 166 | shuffle=(epoch_itr.next_epoch_idx > args.curriculum), 167 | ) 168 | update_freq = ( 169 | args.update_freq[epoch_itr.epoch - 1] 170 | if epoch_itr.epoch <= len(args.update_freq) 171 | else args.update_freq[-1] 172 | ) 173 | itr = iterators.GroupedIterator(itr, update_freq) 174 | progress = progress_bar.progress_bar( 175 | itr, 176 | log_format=args.log_format, 177 | log_interval=args.log_interval, 178 | epoch=epoch_itr.epoch, 179 | tensorboard_logdir=( 180 | args.tensorboard_logdir if distributed_utils.is_master(args) else None 181 | ), 182 | default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), 183 | ) 184 | 185 | # task specific setup per epoch 186 | task.begin_epoch(epoch_itr.epoch, trainer.get_model()) 187 | 188 | valid_subsets = args.valid_subset.split(',') 189 | max_update = args.max_update or math.inf 190 | for samples in progress: 191 | with metrics.aggregate('train_inner'): 192 | log_output = trainer.train_step(samples) 193 | if log_output is None: # OOM, overflow, ... 194 | continue 195 | 196 | # log mid-epoch stats 197 | num_updates = trainer.get_num_updates() 198 | if num_updates % args.log_interval == 0: 199 | stats = get_training_stats(metrics.get_smoothed_values('train_inner')) 200 | progress.log(stats, tag='train_inner', step=num_updates) 201 | 202 | # reset mid-epoch stats after each log interval 203 | # the end-of-epoch stats will still be preserved 204 | metrics.reset_meters('train_inner') 205 | 206 | if ( 207 | not args.disable_validation 208 | and args.save_interval_updates > 0 209 | and num_updates % args.save_interval_updates == 0 210 | and num_updates > 0 211 | ): 212 | valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) 213 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) 214 | 215 | if num_updates >= max_update: 216 | break 217 | 218 | # log end-of-epoch stats 219 | stats = get_training_stats(metrics.get_smoothed_values('train')) 220 | progress.print(stats, tag='train', step=num_updates) 221 | 222 | # reset epoch-level meters 223 | metrics.reset_meters('train') 224 | 225 | 226 | def get_training_stats(stats): 227 | if 'nll_loss' in stats and 'ppl' not in stats: 228 | stats['ppl'] = utils.get_perplexity(stats['nll_loss']) 229 | stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) 230 | return stats 231 | 232 | 233 | def validate(args, trainer, task, epoch_itr, subsets, prune=-1): 234 | """Evaluate the model on the validation set(s) and return the losses.""" 235 | 236 | if args.fixed_validation_seed is not None: 237 | # set fixed seed for every validation 238 | utils.set_torch_seed(args.fixed_validation_seed) 239 | 240 | valid_losses = [] 241 | for subset in subsets: 242 | # Initialize data iterator 243 | itr = task.get_batch_iterator( 244 | dataset=task.dataset(subset), 245 | max_tokens=args.max_tokens_valid, 246 | max_sentences=args.max_sentences_valid, 247 | max_positions=utils.resolve_max_positions( 248 | task.max_positions(), 249 | trainer.get_model().max_positions(), 250 | ), 251 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 252 | required_batch_size_multiple=args.required_batch_size_multiple, 253 | seed=args.seed, 254 | num_shards=args.distributed_world_size, 255 | shard_id=args.distributed_rank, 256 | num_workers=args.num_workers, 257 | ).next_epoch_itr(shuffle=False) 258 | progress = progress_bar.progress_bar( 259 | itr, 260 | log_format=args.log_format, 261 | log_interval=args.log_interval, 262 | 
epoch=epoch_itr.epoch, 263 | prefix=f"valid on '{subset}' subset", 264 | tensorboard_logdir=( 265 | args.tensorboard_logdir if distributed_utils.is_master(args) else None 266 | ), 267 | default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), 268 | ) 269 | 270 | # added by Junxian 271 | if prune > 0: 272 | index_map = trainer.get_model().set_prune_index(prune) 273 | task.set_index_map(index_map) 274 | 275 | # not write templates for time profiling 276 | write_template_flag = False if args.eval_mode == 'time' else True 277 | 278 | # only one worker deals with the template file in DDP 279 | if args.distributed_rank == 0 and write_template_flag: 280 | print('write template files') 281 | 282 | if args.eval_mode == 'none': 283 | fout = open(os.path.join(args.save_dir, 284 | 'templates_{}_{}.txt'.format(epoch_itr.epoch, trainer.get_num_updates())), 'w') 285 | else: 286 | fout = open(os.path.join(args.save_dir,'templates_eval_{}.txt'.format(subset)), 'w') 287 | 288 | if prune <= 0: 289 | task.write_lambda(fout, trainer.get_model()) 290 | else: 291 | fout = None 292 | 293 | # create a new root metrics aggregator so validation metrics 294 | # don't pollute other aggregators (e.g., train meters) 295 | with metrics.aggregate(new_root=True) as agg: 296 | for sample in progress: 297 | trainer.valid_step(sample, split=subset) 298 | 299 | # added by Junxian 300 | if args.distributed_rank == 0: 301 | task.write_template(sample, trainer.get_model(), fout) 302 | 303 | if fout is not None: 304 | fout.close() 305 | 306 | # log validation stats 307 | stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) 308 | progress.print(stats, tag=subset, step=trainer.get_num_updates()) 309 | 310 | valid_losses.append(stats[args.best_checkpoint_metric]) 311 | return valid_losses 312 | 313 | 314 | def validate_iw(args, trainer, task, epoch_itr, subsets, prune=-1, mode='iw'): 315 | """Evaluate the model on the validation set(s) and return the losses.""" 316 | 317 | if mode == 'none' or mode == 'time' or args.criterior == 'lm_baseline': 318 | return [0] 319 | 320 | # top k instead of sampling to approximate sum of prototypes for evaluation 321 | for subset in subsets: 322 | task.dataset(subset).set_sampling(False) 323 | 324 | if args.fixed_validation_seed is not None: 325 | # set fixed seed for every validation 326 | utils.set_torch_seed(args.fixed_validation_seed) 327 | 328 | valid_losses = [] 329 | for subset in subsets: 330 | # Initialize data iterator 331 | itr = task.get_batch_iterator( 332 | dataset=task.dataset(subset), 333 | max_tokens=args.max_tokens_valid, 334 | max_sentences=args.max_sentences_valid, 335 | max_positions=utils.resolve_max_positions( 336 | task.max_positions(), 337 | trainer.get_model().max_positions(), 338 | ), 339 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 340 | required_batch_size_multiple=args.required_batch_size_multiple, 341 | seed=args.seed, 342 | num_shards=args.distributed_world_size, 343 | shard_id=args.distributed_rank, 344 | num_workers=args.num_workers, 345 | ).next_epoch_itr(shuffle=False) 346 | progress = progress_bar.progress_bar( 347 | itr, 348 | log_format=args.log_format, 349 | log_interval=args.log_interval, 350 | epoch=1, 351 | prefix=f"valid on '{subset}' subset", 352 | tensorboard_logdir=( 353 | args.tensorboard_logdir if distributed_utils.is_master(args) else None 354 | ), 355 | default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), 356 | ) 357 | 358 | if prune > 0: 359 | index_map = 
trainer.get_model().set_prune_index(prune) 360 | task.set_index_map(index_map) 361 | 362 | # create a new root metrics aggregator so validation metrics 363 | # don't pollute other aggregators (e.g., train meters) 364 | with metrics.aggregate(new_root=True) as agg: 365 | for sample in progress: 366 | trainer.valid_iw_step(sample, mode=mode) 367 | 368 | # log validation stats 369 | stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) 370 | progress.print(stats, tag='valid_iw', step=trainer.get_num_updates()) 371 | 372 | # valid_losses.append(stats[args.best_checkpoint_metric]) 373 | 374 | if prune > 0: 375 | trainer.get_model().reset_prune_index() 376 | task.reset_index_map() 377 | 378 | return valid_losses 379 | 380 | 381 | def get_valid_stats(args, trainer, stats): 382 | if 'nll_loss' in stats and 'ppl' not in stats: 383 | stats['ppl'] = utils.get_perplexity(stats['nll_loss']) 384 | stats['num_updates'] = trainer.get_num_updates() 385 | if hasattr(checkpoint_utils.save_checkpoint, 'best'): 386 | key = 'best_{0}'.format(args.best_checkpoint_metric) 387 | best_function = max if args.maximize_best_checkpoint_metric else min 388 | stats[key] = best_function( 389 | checkpoint_utils.save_checkpoint.best, 390 | stats[args.best_checkpoint_metric], 391 | ) 392 | return stats 393 | 394 | 395 | def distributed_main(i, args, start_rank=0): 396 | args.device_id = i 397 | if args.distributed_rank is None: # torch.multiprocessing.spawn 398 | args.distributed_rank = start_rank + i 399 | main(args, init_distributed=True) 400 | 401 | 402 | def cli_main(modify_parser=None): 403 | parser = options.get_training_parser() 404 | args = options.parse_args_and_arch(parser, modify_parser=modify_parser) 405 | 406 | if args.distributed_init_method is None: 407 | distributed_utils.infer_init_method(args) 408 | 409 | if args.distributed_init_method is not None: 410 | # distributed training 411 | if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: 412 | start_rank = args.distributed_rank 413 | args.distributed_rank = None # assign automatically 414 | torch.multiprocessing.spawn( 415 | fn=distributed_main, 416 | args=(args, start_rank), 417 | nprocs=torch.cuda.device_count(), 418 | ) 419 | else: 420 | distributed_main(args.device_id, args) 421 | elif args.distributed_world_size > 1: 422 | # fallback for single node with multiple GPUs 423 | assert args.distributed_world_size <= torch.cuda.device_count() 424 | port = random.randint(10000, 20000) 425 | args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) 426 | args.distributed_rank = None # set based on device id 427 | torch.multiprocessing.spawn( 428 | fn=distributed_main, 429 | args=(args, ), 430 | nprocs=args.distributed_world_size, 431 | ) 432 | else: 433 | # single GPU training 434 | main(args) 435 | 436 | 437 | if __name__ == '__main__': 438 | cli_main() 439 | --------------------------------------------------------------------------------
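A note on the metrics logged in sparse_prototype/sp_criterion.py above: neg_elbo_t and recon_loss_t are divided by math.log(2) to convert per-token nats into bits, and fairseq's utils.get_perplexity exponentiates that base-2 loss with base 2, so the reported ppl equals exp(per-token negative ELBO in nats). A minimal standalone sketch of that identity (toy number, not repo code):

    import math

    neg_elbo_nats = 4.2                       # hypothetical per-token negative ELBO in nats
    bits_per_token = neg_elbo_nats / math.log(2)
    ppl = 2 ** bits_per_token                 # what utils.get_perplexity reports from the base-2 loss
    assert abs(ppl - math.exp(neg_elbo_nats)) < 1e-6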
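For reference, here is a self-contained sketch of the chunked importance-weighted likelihood estimate that nll_iw in sparse_prototype/vae.py implements: log p(x) is approximated by logsumexp_k [log p(x, z_k) - log q(z_k | x)] - log(nsamples), drawing ns posterior samples at a time to bound memory. The log_joint and sample_posterior callables below are hypothetical stand-ins for eval_complete_ll and encoder.sample:

    import math
    import torch

    def iw_log_likelihood(log_joint, sample_posterior, x, nsamples=500, ns=100):
        """log_joint(x, z) -> [batch, ns]; sample_posterior(x, ns) -> (z, log_q) with log_q [batch, ns]."""
        chunks = []
        for _ in range(nsamples // ns):
            z, log_q = sample_posterior(x, ns)        # z: [batch, ns, nz]
            chunks.append(log_joint(x, z) - log_q)    # log importance weights, [batch, ns]
        # logsumexp over all nsamples draws, then subtract log(nsamples) -> [batch]
        return torch.logsumexp(torch.cat(chunks, dim=-1), dim=-1) - math.log(nsamples)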
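Similarly, a minimal sketch of the Metropolis-Hastings accept/reject update performed inside sample_from_posterior in sparse_prototype/vae.py, assuming a Gaussian random-walk proposal; log_joint is again a hypothetical stand-in for eval_complete_ll:

    import torch

    def mh_step(cur, cur_ll, log_joint, x, proposal_std=1.0):
        """One MH update; cur: [batch, 1, nz], cur_ll: [batch, 1]."""
        # random-walk proposal around the current state
        proposal = torch.normal(mean=cur, std=torch.full_like(cur, proposal_std))
        prop_ll = log_joint(x, proposal)                               # [batch, 1]
        # acceptance probability min(1, p(x, z') / p(x, z))
        accept_prob = torch.clamp((prop_ll - cur_ll).exp(), max=1.0)
        accept = (torch.rand_like(accept_prob) < accept_prob).float()  # [batch, 1]
        mask = accept.unsqueeze(-1)                                    # broadcast over the latent dim
        new_cur = mask * proposal + (1 - mask) * cur
        new_ll = accept * prop_ll + (1 - accept) * cur_ll
        return new_cur, new_ll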