├── .gitignore ├── README.md ├── compute_all_L_metrics.py ├── compute_lm_metrics_basic.py ├── compute_mauve_metrics.py ├── compute_ref_metrics.py ├── compute_self_bleu_metric.py ├── download_generations.md ├── environment.yml ├── generate_basic.py ├── generate_ref.py ├── human_eval-compute_BT_scores.ipynb ├── human_evaluation.md ├── library ├── DRMM.py └── spreadingvectors │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── crossvalidate.sh │ ├── eval.py │ ├── lattices │ ├── Makefile │ ├── Zn_lattice.py │ ├── __init__.py │ ├── bench_Zn_decoder.py │ ├── c_lattices.swig │ ├── lattice_Zn.cpp │ ├── lattice_Zn.h │ ├── lattice_utils.cpp │ ├── lattice_utils.h │ └── test_Zn.py │ ├── reproduce.sh │ ├── requirements.txt │ ├── train.py │ └── train_spv.py ├── local_scripts ├── download_data.py ├── make_output_dirs.sh ├── parallelize.sh └── webtext │ ├── mauve_metrics_drmm.sh │ ├── mauve_metrics_kmeans.sh │ ├── mauve_metrics_spv.sh │ ├── run_all_L_metrics.sh │ ├── run_lm_metrics.sh │ └── run_self_bleu.sh ├── requirements.txt ├── slurm_scripts └── webtext │ ├── arr_generate_basic.sh │ └── generate_ref.sh └── src ├── __init__.py ├── generation_utils.py ├── mauve_metrics.py ├── metrics.py ├── model_utils.py ├── transformers_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | pycharm 2 | .idea 3 | 4 | # compiled 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # vim 31 | [._]*.sw[a-p] 32 | [._]s[a-rt-v][a-z] 33 | [._]ss[a-gi-z] 34 | [._]sw[a-p] 35 | 36 | # Session 37 | Session.vim 38 | Sessionx.vim 39 | 40 | # Temporary 41 | .netrwhist 42 | *~ 43 | # Auto-generated tag files 44 | tags 45 | # Persistent undo 46 | [._]*.un~ 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mauve-experiments 2 | 3 | This repository contains the code and the scripts to reproduce the experiments 4 | [in this paper](https://arxiv.org/pdf/2102.01454.pdf) published at **NeurIPS 2021** and awarded an **Outstanding Paper Award**. 5 | The paper introduces MAUVE, a comparison measure for open-ended text generation. 6 | 7 | MAUVE directly compares the distribution of machine-generated text to 8 | that of human language via the area under a divergence curve between the two distributions. 9 | MAUVE summarizes the trade-off between two types of errors: 10 | those arising from parts of the human distribution that the model fails to 11 | capture, and those arising from model generations that are unlikely under the human distribution. 12 | 13 | _**Standalone package**: For a self-contained package to compute MAUVE, installable 14 | via `pip install mauve-text`, please 15 | see [this repository](https://github.com/krishnap25/mauve)._ 16 | 17 | **Summary**: 18 | * The rest of the README describes how to create the generations and reproduce the experiments described in the paper. 19 | * To download the generations used in the papers, see [here](https://github.com/krishnap25/mauve-experiments/blob/main/download_generations.md). This will still require you to featurize the generations and compute the metrics.
20 | * The data we collected in the human evaluations can be found [here](https://github.com/krishnap25/mauve-experiments/blob/main/human_evaluation.md). The code to compute the corresponding Bradley-Terry coefficients can be found [here](https://github.com/krishnap25/mauve-experiments/blob/main/human_eval-compute_BT_scores.ipynb). 21 | 22 | ## Dependencies 23 | The code is written in Python and the dependencies are: 24 | - Python >= 3.6 25 | - PyTorch >= 1.1 26 | - Huggingface Transformers >= 4.2.0 27 | - NLTK >= 3.4.5 28 | - scikit-learn >= 0.22.1 29 | - faiss-gpu >= 1.7.0 30 | - tqdm >= 4.40.0 31 | 32 | **Conda Environment**: 33 | We recommend using a [conda environment](https://docs.conda.io/en/latest/miniconda.html) 34 | for Python 3.8. 35 | To set up the environment, run 36 | ```bash 37 | conda env create --file environment.yml 38 | # activate the environment 39 | conda activate mauve-experiments 40 | ``` 41 | In addition, you will have to install the following manually: 42 | - PyTorch, version 1.7: [instructions](https://pytorch.org/get-started/locally/), 43 | - HuggingFace Transformers, version 4.2.0: [instructions](https://huggingface.co/transformers). 44 | 45 | The code is compatible with PyTorch >= 1.1.0 and transformers >= 3.2.0 but 46 | we have not thoroughly tested it in this configuration. 47 | 48 | 49 | **Install Dependencies via Pip**: 50 | Install PyTorch, version 1.7 ([instructions here](https://pytorch.org/get-started/locally)) 51 | and then run 52 | ```bash 53 | pip install -r requirements.txt 54 | ``` 55 | 56 | ## Datasets 57 | We use the webtext data from the [GPT-2 output dataset repository](https://github.com/openai/gpt-2-output-dataset). 58 | For the purpose of reproducing these experiments, 59 | it suffices to simply download the test set of webtext. 60 | To this end, run: 61 | ```bash 62 | python local_scripts/download_data.py 63 | ``` 64 | The data is downloaded to the folder `./data`; pass `--data_dir ./data` to all the scripts below. 65 | 66 | ## Experimental Pipeline 67 | For each dataset, once we have the pretrained models, the experimental pipeline is as follows: 68 | 1. generate samples and featurize samples (GPU needed) 69 | 2. compute MAUVE (CPU suffices, highly parallelizable) 70 | 3. compute LM metrics such as perplexity, sparsemax score, Jensen-Shannon score, etc. (GPU needed) 71 | 4. compute self-BLEU (CPU only, embarrassingly parallelizable between multiple cores) 72 | 5. compute all other metrics (CPU only) 73 | 6. run steps 4 and 5 on the human data 74 | 75 | The generation of samples (Step 1) must be run first. Other steps can proceed in any order. 76 | 77 | Here is how to find the scripts step-by-step for webtext. 78 | The variables which need to be set are detailed at the top of each script. 79 | 80 | **Step 0. Prepare directory:** 81 | Run `bash local_scripts/make_output_dirs.sh` to create the necessary output directories. 82 | 83 | **Step 1. Generate the samples:** 84 | Run `slurm_scripts/webtext/arr_generate_basic.sh ${model_size}` to generate samples from the basic methods 85 | (pure sampling, top-k sampling, temperature sampling, nucleus sampling and greedy decoding). 86 | `${model_size}` is one of `['small', 'medium', 'large', 'xl']`. 87 | 88 | It is written as a Slurm array job. 89 | For each configuration and model size, we generate five sets of 5000 samples each 90 | using prompts from the dataset. This script internally calls the file `generate_basic.py`.
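If you are not on a Slurm cluster, `generate_basic.py` can also be launched directly. Below is a minimal sketch using Python's `subprocess`; the flag names are inferred from how `generate_basic.py` reads its arguments, so verify them against `make_basic_parser` in `src/utils.py` before relying on this:

```python
import subprocess

# Sketch only: flag names (--top_p, --top_k, --temp, --use_large_feats, ...)
# are inferred from the args used in generate_basic.py; check
# src/utils.py:make_basic_parser for the authoritative list.
for seed in range(5):  # five sets of samples per configuration
    subprocess.run([
        "python", "generate_basic.py",
        "--device", "0",
        "--data_dir", "./data",
        "--model_name", "gpt2-large",
        "--datasplit", "test",
        "--top_p", "0.95", "--top_k", "0", "--temp", "1.0",
        "--seed", str(seed),
        "--use_large_feats",
    ], check=True)
```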
91 | The outputs are stored in `./outputs/{dataset_name}_{model_name}/generations/basic`. 92 | The running time for each run varies from around 1 hour (GPT-2 small/medium) to around 3-4 hours (GPT-2 large) 93 | and 12 hours (GPT-2 XL) on an NVIDIA Quadro GPU with 24G of memory. 94 | If you use a GPU with 12G of memory, it will likely take around twice as long. 95 | 96 | This creates the following in `./outputs/{dataset_name}_{model_name}/generations/basic`. 97 | - `sentences_test_p${topp}_k${topk}_t${temp}_seed${seed}.p` (e.g. `sentences_test_p0.99_k0_t1.0_seed0.p`): 98 | contains the raw samples in string form. If you load this using pickle, you will find 99 | two lists: (1) list of strings which are the actual samples generated, and, 100 | (2) list of booleans, denoting termination, i.e., whether an `<|endoftext|>` (EOS) token was generated. 101 | - `sample_test_p${topp}_k${topk}_t${temp}_seed${seed}.p` (e.g. `sample_test_p0.99_k0_t1.0_seed0.p`): 102 | contains the samples after tokenization. If you load this using pickle, you will find 103 | a list of 5 entries: (1) list of list of ints, each of which is the BPE tokenized representation 104 | of the samples generated above, 105 | (2) list of booleans, denoting termination (same as above), 106 | (3) unique n-gram fraction, for n in 1 to 6, 107 | (4) perplexity of the generated text under the model, and, 108 | (5) the parsed arguments of the script `generate_basic.py`. 109 | - `featsL${max_length}_test_p${topp}_k${topk}_t${temp}_seed${seed}.pt` (e.g. `featsL1024_test_p0.99_k0_t1.0_seed0.pt`): 110 | feature representations (i.e., the terminal hidden state) 111 | under the GPT-2 large model. Each such file is 25M in size. 112 | For each configuration, we create 4 files with 113 | `max_length` in `{128, 256, 512, 1024}`. 114 | 115 | Next, run `slurm_scripts/webtext/generate_ref.sh` to featurize the human-written text (i.e., webtext test set). 116 | The output is created in `./outputs/{dataset_name}_{model_name}/generations/ref`. 117 | 118 | 119 | **Step 2. Compute MAUVE:** 120 | Run `local_scripts/webtext/mauve_metrics_*.sh`. 121 | - `local_scripts/webtext/mauve_metrics_kmeans.sh`: use k-means for quantization. 122 | Runs on CPU within a few minutes per run. It is massively parallelizable. 123 | - `local_scripts/webtext/mauve_metrics_drmm.sh`: use deep residual mixture models (DRMM) for quantization (Hämäläinen et al., 2020). 124 | It is copied with minor edits from the [original repo](https://github.com/PerttuHamalainen/DRMM) 125 | (note: this requires TensorFlow 1.12 to be installed. A CPU-only install suffices). 126 | A CPU-only run takes around 2 hours. It is also massively parallelizable. 127 | - `local_scripts/webtext/mauve_metrics_spv.sh`: use spreading vectors for quantization (Sablayrolles et al., 2018). 128 | It is copied with minor edits from the [original repo](https://github.com/facebookresearch/spreadingvectors). 129 | It runs in <10 minutes on a GPU. 130 | 131 | The outputs are written to 132 | `./outputs/{dataset_name}_{model_name}/metrics/basic`. 133 | The filenames are: 134 | - k-means: `mauve_L${max_length}_test_p${topp}_k${topk}_t${temp}_seed${seed}_kmeans_l2_${num_clusters}_${lower_dim_explained_variance}.p` (e.g., `mauve_L1024_test_p1.0_k50_t1.0_seed2_kmeans_l2_500_0.9.p`): 135 | arguments are `num_clusters` (number of clusters) and 136 | `lower_dim_explained_variance` (lower dimensionality after PCA is chosen with at least this much explained variance).
137 | - DRMM: `mauve_L${max_length}_test_p${topp}_k${topk}_t${temp}_seed${seed}_drmm_${num_layers}_${num_components_per_layer}.p` (e.g., `mauve_L1024_test_p1.0_k50_t1.0_seed2_drmm_3_10.p`): 138 | arguments are `num_layers` (number of layers in the DRMM) and `num_components_per_layer` (number of components in each layer). 139 | The equivalent number of k-means clusters would be `${num_components_per_layer}^${num_layers}`. 140 | - SPV: `mauve_L${max_length}_test_p${topp}_k${topk}_t${temp}_seed${seed}_spv.p` (e.g., `mauve_L1024_test_p1.0_k50_t1.0_seed2_spv.p`) 141 | 142 | Each of these outputs is a pickle file. In each output, we have 143 | `[mauve_metrics, p_hist, q_hist]` (the order in which `compute_mauve_metrics.py` saves them), where `p_hist` and `q_hist` are 144 | respectively the multinomial distributions 145 | obtained after quantization. 146 | 147 | **Step 3. Compute LM metrics:** 148 | Run `local_scripts/webtext/run_lm_metrics.sh`, 149 | which in turn invokes `compute_lm_metrics_basic.py`. 150 | Output files are written to 151 | `./outputs/{dataset_name}_{model_name}/metrics/basic` 152 | with name `lm_test_p${topp}_k${topk}_t${temp}.p` (e.g., `lm_test_p1.0_k5_t1.0.p`). 153 | Only one job is run per configuration rather than per seed 154 | (since the metrics depend on the model but not on the actual generations). 155 | 156 | **Step 4. Compute Self-BLEU**: 157 | Run `local_scripts/webtext/run_self_bleu.sh`, 158 | which in turn calls `compute_self_bleu_metric.py`. 159 | Takes around 7 hours on a single processor core, but is embarrassingly parallel. 160 | The current script uses one process per job but 161 | runs multiple jobs in parallel. 162 | Output files are written to 163 | `./outputs/{dataset_name}_{model_name}/metrics/basic` 164 | with name `bleu_test_p${topp}_k${topk}_t${temp}_seed${seed}.p` 165 | (e.g., `bleu_test_p1.0_k500_t1.0_seed4.p`). 166 | 167 | **Step 5. Compute all other metrics**: 168 | Run `local_scripts/webtext/run_all_L_metrics.sh`. 169 | It calls `compute_all_L_metrics.py` under the hood 170 | and computes other metrics such as the Zipf coefficient 171 | and repetition ratio. Runs in a few seconds. 172 | 173 | Output files are written to 174 | `./outputs/{dataset_name}_{model_name}/metrics/basic` 175 | with name `all_L_test_p${topp}_k${topk}_t${temp}_seed${seed}.p` 176 | (e.g., `all_L_test_p0.92_k0_t1.0_seed3.p`). 177 | 178 | **Step 6. Compute metrics on human data**: 179 | To perform steps 4 and 5 on the human-written text, run 180 | ```bash 181 | python compute_ref_metrics.py --datasplit test --device 0 --parallel_bleu --n_proc_bleu 24 182 | ``` 183 | The self-BLEU computation is the most time-consuming (~7 hours with one process) 184 | and its running time 185 | depends on how many processes are allowed (`--n_proc_bleu`). 186 | Outputs are written to 187 | `./outputs/{dataset_name}_{model_name}/metrics/ref`.
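Once these steps have run, the saved metrics can be inspected by unpickling the files described above. A minimal sketch (the paths, decoding parameters, and the k-means setting `500`/`0.9` below are example values; substitute your own configuration):

```python
import pickle as pkl

metrics_dir = "./outputs/webtext_gpt2-large/metrics/basic"  # example path

# Step 2 output: [mauve_metrics, p_hist, q_hist]
with open(f"{metrics_dir}/mauve_L1024_test_p0.95_k0_t1.0_seed0_kmeans_l2_500_0.9.p", "rb") as f:
    mauve_metrics, p_hist, q_hist = pkl.load(f)

# Step 3 output: [metrics, metric_fn_names]
with open(f"{metrics_dir}/lm_test_p0.95_k0_t1.0.p", "rb") as f:
    lm_metrics, metric_fn_names = pkl.load(f)

# Step 4 output: a list of self-BLEU-n scores for n = 1, ..., 5
with open(f"{metrics_dir}/bleu_test_p0.95_k0_t1.0_seed0.p", "rb") as f:
    bleu_scores = pkl.load(f)

# Step 5 output: a dict with keys such as 'distinct-n', 'perplexity',
# 'zipf', 'repetition', and 'non-termination-ratio'
with open(f"{metrics_dir}/all_L_test_p0.95_k0_t1.0_seed0.p", "rb") as f:
    all_L_metrics = pkl.load(f)
```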
188 | 189 | 190 | ## Citation 191 | If you find this repository useful, or you use it in your research, please cite the following papers: 192 | ``` 193 | 194 | 195 | @article{pillutla-etal:mauve:jmlr2023, 196 | title={{MAUVE Scores for Generative Models: Theory and Practice}}, 197 | author={Pillutla, Krishna and Liu, Lang and Thickstun, John and Welleck, Sean and Swayamdipta, Swabha and Zellers, Rowan and Oh, Sewoong and Choi, Yejin and Harchaoui, Zaid}, 198 | journal={JMLR}, 199 | year={2023} 200 | } 201 | 202 | @inproceedings{pillutla-etal:mauve:neurips2021, 203 | title={MAUVE: Measuring the Gap Between Neural Text and Human Text using Divergence Frontiers}, 204 | author={Pillutla, Krishna and Swayamdipta, Swabha and Zellers, Rowan and Thickstun, John and Welleck, Sean and Choi, Yejin and Harchaoui, Zaid}, 205 | booktitle={NeurIPS}, 206 | year={2021} 207 | } 208 | 209 | @inproceedings{liu-etal:mauve-theory:neurips2021, 210 | title={{Divergence Frontiers for Generative Models: Sample Complexity, Quantization Effects, and Frontier Integrals}}, 211 | author={Liu, Lang and Pillutla, Krishna and Welleck, Sean and Oh, Sewoong and Choi, Yejin and Harchaoui, Zaid}, 212 | booktitle={NeurIPS}, 213 | year={2021} 214 | } 215 | 216 | ``` 217 | 218 | ## Acknowledgements 219 | This work was supported by NSF CCF-2019844, the DARPA MCS program through NIWC Pacific (N66001-19-2-4031), 220 | the CIFAR program "Learning in Machines and Brains", 221 | a Qualcomm Innovation Fellowship, and faculty research awards. 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /compute_all_L_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import torch 4 | import src.utils as utils 5 | import src.metrics 6 | 7 | 8 | def main(): 9 | parser = utils.make_metrics_parser() 10 | args = parser.parse_args() 11 | main_metrics(args) 12 | 13 | def main_metrics(args): 14 | print(f'device: {args.device}') 15 | device = utils.get_device_from_arg(args.device) 16 | print(f'Using device: {device}') 17 | 18 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 19 | filename = f'{args.datasplit}_p{args.top_p}_k{args.top_k}_t{args.temp}_seed{args.generate_seed}' 20 | folder_name = f'{save_directory}/generations/basic' 21 | 22 | 23 | input_file_name = f'{folder_name}/sample_{filename}.p' 24 | if not os.path.isfile(input_file_name): 25 | print(f'File {input_file_name} does not exist. Quitting!') 26 | return 27 | with open(input_file_name, 'rb') as f: 28 | all_sentences, is_completed = pkl.load(f)[:2] 29 | 30 | savefilename = f'{save_directory}/metrics/basic/all_L_{filename}.p' 31 | if os.path.isfile(savefilename) and not args.force: 32 | print('All metrics already computed.
Exiting') 33 | return 34 | 35 | model, tokenizer = utils.get_model_and_tokenizer(model_name='gpt2-large', device=device) 36 | 37 | metrics_all = {} 38 | # Distinct-n 39 | n_lst = [1, 2, 3, 4, 5, 6] 40 | unique_ngram_frac = src.metrics.get_unique_ngram_fraction(all_sentences, n_lst) 41 | metrics_all['distinct-n'] = unique_ngram_frac 42 | 43 | # PPL 44 | samples_2 = [torch.LongTensor(x).view(1, -1).to(device) for x in all_sentences] 45 | ppl = src.metrics.get_perplexity_from_samples(model, samples_2) 46 | metrics_all['perplexity'] = ppl 47 | 48 | # Zipf 49 | metrics_all['zipf'] = src.metrics.zipf_coeff(all_sentences) 50 | 51 | # Repetition 52 | metrics_all['repetition'] = src.metrics.get_repetition_fraction(all_sentences) 53 | 54 | # Non-termination 55 | metrics_all['non-termination-ratio'] = src.metrics.get_nontermination_ratio(all_sentences, is_completed) 56 | 57 | # save 58 | with open(savefilename, 'wb') as f: 59 | pkl.dump(metrics_all, f) 60 | print(f'Done. Saved "{savefilename}". Bye!') 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /compute_lm_metrics_basic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, time, pickle as pkl 3 | 4 | import src.utils as utils 5 | import src.metrics 6 | 7 | def make_parser(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--device', type=int, default=-1, 10 | help='choose one of [0, 1, 2, 3] for GPU, or CPU otherwise') 11 | parser.add_argument('--data_dir', type=str, default='./data') 12 | parser.add_argument('--datasplit', type=str, default='valid') 13 | parser.add_argument('--model_name', type=str, default='gpt2') 14 | parser.add_argument('--max_len', type=int, default=1024) 15 | parser.add_argument('--max_num_data', type=int, default=5000) 16 | return parser 17 | 18 | def get_metrics(param, metric_fn_lst, model, ds_tokens, datasplit, metric_fn_names, save_directory): 19 | # param = (top_p, top_k, temp) 20 | p, k, temp = param 21 | output_file_name = f'{save_directory}/metrics/basic/lm_{datasplit}_p{p}_k{k}_t{temp}.p' 22 | if os.path.isfile(output_file_name): 23 | print(f'{output_file_name} exists.
Exiting.') 24 | return 25 | t1 = time.time() 26 | metrics = src.metrics.compute_metrics_from_probs( 27 | model, ds_tokens, metric_fn_lst, eppl_eps_lst=[1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 0], 28 | temperature=temp, top_k=k, top_p=p, 29 | ) 30 | t2 = time.time() 31 | print(metrics, round(t2-t1, 2)) 32 | 33 | with open(output_file_name, 'wb') as f: 34 | pkl.dump([metrics, metric_fn_names], f) 35 | 36 | def main(): 37 | parser = make_parser() 38 | args = parser.parse_args() 39 | print(args) 40 | 41 | device = utils.get_device_from_arg(args.device) 42 | print(f'Using device: {device}') 43 | 44 | model, tokenizer = utils.get_model_and_tokenizer(model_name=args.model_name, device=device) 45 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 46 | 47 | ds_tokens = utils.load_and_tokenize_data(tokenizer, args.data_dir, 48 | args.max_len, args.max_num_data, split=args.datasplit) 49 | 50 | metric_fn_lst = src.metrics.get_probs_metric_fn_lst() 51 | metric_fn_names = src.metrics.get_metric_names() 52 | print(metric_fn_names) 53 | 54 | for p in [0.8, 0.9, 0.92, 0.95, 0.99]: # 5 55 | param = (p, 0, 1.0) 56 | get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory) 57 | 58 | for k in [1, 5, 10, 50, 100, 500, 1000, 2000, 5000, 10000]: # 10 59 | param = (1.0, k, 1.0) 60 | get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory) 61 | 62 | for t in [0.7, 0.8, 0.9, 0.95, 1.0]: # 5 63 | param = (1.0, 0, t) 64 | get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory) 65 | 66 | for t in [0.75, 0.9]: # 4 67 | for k in [10, 100]: 68 | param = (1.0, k, t) 69 | get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /compute_mauve_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import torch 4 | 5 | import src.utils as utils 6 | import src.mauve_metrics as mauve_metrics 7 | 8 | 9 | def main(): 10 | parser = utils.make_metrics_parser() 11 | args = parser.parse_args() 12 | print(args) 13 | torch.manual_seed(args.seed) 14 | device = utils.get_device_from_arg(args.device) 15 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 16 | 17 | if args.use_large_feats: 18 | feats_suffix = f'L{args.max_len}' 19 | elif args.use_bert_feats: 20 | feats_suffix = f'B{args.max_len}' 21 | else: 22 | feats_suffix = '' 23 | 24 | if args.use_large_feats: 25 | print('---------------Using features from GPT-2 Large!!!! Suffix =', feats_suffix) 26 | elif args.use_bert_feats: 27 | print('---------------Using features from Roberta Large!!!! 
Suffix =', feats_suffix) 28 | else: 29 | print('---------------Using features from model used for generations!!!!') 30 | 31 | if not os.path.isfile(f'{save_directory}/generations/ref/feats{feats_suffix}_{args.datasplit}.pt'): 32 | raise FileNotFoundError(f'Generations {save_directory}/generations/ref/feats{feats_suffix}_{args.datasplit}.pt do not exist') 33 | p_feats = torch.load(f'{save_directory}/generations/ref/feats{feats_suffix}_{args.datasplit}.pt') 34 | folder, filename = utils.get_save_filename_from_args(args) 35 | 36 | algo_name = mauve_metrics.get_discretization_algo_name( 37 | discretization_algo=args.discretization, 38 | kmeans_num_clusters=args.kmeans_num_clusters, kmeans_explained_var=args.kmeans_explained_var, 39 | drmm_num_epochs=args.drmm_num_epochs, drmm_n_layer=args.drmm_n_layer, 40 | drmm_n_comp_per_layer=args.drmm_n_component_per_layer, 41 | spv_num_epochs=args.spv_num_epochs, device=device, seed=args.seed+1 42 | ) 43 | savefilename = f'{save_directory}/metrics/{folder}/mauve_{feats_suffix}_{filename}_{algo_name}.p' 44 | if os.path.isfile(savefilename) and not args.force: 45 | print('Metrics already exist. Exiting') 46 | return 47 | 48 | if not os.path.isfile(f'{save_directory}/generations/{folder}/feats{feats_suffix}_{filename}.pt'): 49 | raise FileNotFoundError(f'Generations {save_directory}/generations/{folder}/feats{feats_suffix}_{filename}.pt do not exist') 50 | 51 | q_feats = torch.load(f'{save_directory}/generations/{folder}/feats{feats_suffix}_{filename}.pt') 52 | 53 | p_quant, q_quant, metrics = mauve_metrics.compute_mauve_metrics( 54 | p_feats, q_feats, discretization_algo=args.discretization, 55 | kmeans_num_clusters=args.kmeans_num_clusters, kmeans_explained_var=args.kmeans_explained_var, 56 | drmm_num_epochs=args.drmm_num_epochs, drmm_n_layer=args.drmm_n_layer, 57 | drmm_n_comp_per_layer=args.drmm_n_component_per_layer, 58 | spv_num_epochs=args.spv_num_epochs, device=device, seed=args.seed+1 59 | ) 60 | print('Mauve metric:', metrics) 61 | 62 | # save 63 | with open(savefilename, 'wb') as f: 64 | pkl.dump([metrics, p_quant, q_quant], f) 65 | print(f'Done. Saved "{savefilename}". 
Bye!') 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /compute_ref_metrics.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import pickle as pkl 4 | import time 5 | import torch 6 | from nltk.translate.bleu_score import SmoothingFunction 7 | from functools import partial 8 | from multiprocessing.pool import Pool 9 | 10 | import src.utils as utils 11 | from src.generation_utils import self_bleu_one_sentence, get_bleu_weight_for_ngram 12 | from src.utils import tqdm 13 | import src.metrics 14 | 15 | 16 | def main(): 17 | parser = utils.make_metrics_parser() 18 | args = parser.parse_args() 19 | main_metrics(args) 20 | main_bleu(args) 21 | 22 | # Pass args: datasplit, data_dir, parallel_bleu, n_proc_bleu 23 | # time python -u compute_ref_metrics.py --datasplit test --device 1 --parallel_bleu --n_proc_bleu 12 > outs/ref/test 2>&1 24 | 25 | def main_metrics(args): 26 | device = utils.get_device_from_arg(args.device) 27 | print(f'Using device: {device}') 28 | 29 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 30 | model, tokenizer = utils.get_model_and_tokenizer(model_name=args.model_name, device=device) 31 | folder = 'ref' 32 | if args.ds_name is None: 33 | filename = args.datasplit 34 | else: 35 | filename = f'{args.ds_name}_{args.datasplit}' 36 | 37 | ds_tokens = utils.load_and_tokenize_data( 38 | tokenizer, args.data_dir, args.max_len, args.max_num_data, 39 | ds_name=args.ds_name, split=args.datasplit 40 | ) 41 | savefilename = f'{save_directory}/metrics/{folder}/all_{filename}.p' 42 | if os.path.isfile(savefilename) and not args.force: 43 | print('All metrics already computed. Exiting') 44 | return 45 | 46 | all_sentences = [x[0].numpy().tolist() for x in ds_tokens] 47 | is_completed = [True for _ in all_sentences] 48 | 49 | metrics_all = {} 50 | 51 | # Distinct-n 52 | n_lst = [1, 2, 3, 4, 5, 6] 53 | unique_ngram_frac = src.metrics.get_unique_ngram_fraction(all_sentences, n_lst) 54 | metrics_all['distinct-n'] = unique_ngram_frac 55 | 56 | # PPL 57 | samples_2 = [torch.LongTensor(x).view(1, -1).to(device) for x in all_sentences] 58 | ppl = src.metrics.get_perplexity_from_samples(model, samples_2) 59 | metrics_all['perplexity'] = ppl 60 | 61 | # Zipf 62 | metrics_all['zipf'] = src.metrics.zipf_coeff(all_sentences) 63 | 64 | # Repetition 65 | metrics_all['repetition'] = src.metrics.get_repetition_fraction(all_sentences) 66 | 67 | # Non-termination 68 | metrics_all['non-termination-ratio'] = src.metrics.get_nontermination_ratio(all_sentences, is_completed) 69 | 70 | # save 71 | with open(savefilename, 'wb') as f: 72 | pkl.dump(metrics_all, f) 73 | print(f'Done. Saved "{savefilename}".
Bye!') 74 | 75 | 76 | def main_bleu(args): 77 | rng = random.Random(args.seed) 78 | 79 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 80 | _, tokenizer = utils.get_model_and_tokenizer(model_name=args.model_name, device=utils.CPU_DEVICE) 81 | folder = 'ref' 82 | if args.ds_name is None: 83 | filename = args.datasplit 84 | else: 85 | filename = f'{args.ds_name}_{args.datasplit}' 86 | 87 | ds_tokens = utils.load_and_tokenize_data( 88 | tokenizer, args.data_dir, args.max_len, args.max_num_data, 89 | ds_name=args.ds_name, split=args.datasplit 90 | ) 91 | all_sentences = [x[0].numpy().tolist() for x in ds_tokens] 92 | 93 | savefilename = f'{save_directory}/metrics/{folder}/bleu_{filename}.p' 94 | if os.path.isfile(savefilename) and not args.force: 95 | print('Bleu metrics already computed. Exiting') 96 | return 97 | 98 | smoothing_function = SmoothingFunction().method1 99 | 100 | start_time = time.time() 101 | if args.parallel_bleu: 102 | bleu_scores = compute_bleus_parallel(all_sentences, smoothing_function, rng, args) 103 | else: 104 | bleu_scores = compute_bleus_sequential(all_sentences, smoothing_function, rng, args) 105 | print('Total time for self bleu:', round(time.time() - start_time), 's') 106 | 107 | 108 | # save 109 | with open(savefilename, 'wb') as f: 110 | pkl.dump(bleu_scores, f) 111 | print(f'Done. Saved "{savefilename}". Bye!') 112 | 113 | 114 | def compute_bleus_sequential(all_sentences, smoothing_function, rng, args): 115 | bleu_scores = [] 116 | for n in range(1, 6): 117 | start_time = time.time() 118 | weights = get_bleu_weight_for_ngram(n) 119 | bleu_n_lst = [ 120 | self_bleu_one_sentence(weights, all_sentences, smoothing_function, i) 121 | for i in rng.sample(range(len(all_sentences)), min(len(all_sentences), args.n_sample_bleu)) 122 | ] 123 | bleu_scores.append(sum(bleu_n_lst) / len(bleu_n_lst)) 124 | print(f'Total time for self bleu-{n}:', round(time.time() - start_time), 's') 125 | return bleu_scores 126 | 127 | 128 | def compute_bleus_parallel(all_sentences, smoothing_function, rng, args): 129 | pool = Pool(processes=min(args.n_proc_bleu, os.cpu_count())) 130 | bleu_scores = [] 131 | for n in range(1, 6): 132 | start_time = time.time() 133 | weights = get_bleu_weight_for_ngram(n) 134 | bleu_n_lst = list(tqdm( 135 | pool.imap_unordered( 136 | partial(self_bleu_one_sentence, weights, all_sentences, smoothing_function), 137 | rng.sample(range(len(all_sentences)), min(len(all_sentences), args.n_sample_bleu))), 138 | total=args.n_sample_bleu)) 139 | bleu_scores.append(sum(bleu_n_lst) / len(bleu_n_lst)) 140 | print(f'Total time for self bleu-{n}:', round(time.time() - start_time), 's') 141 | return bleu_scores 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /compute_self_bleu_metric.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import pickle as pkl 4 | import time 5 | from nltk.translate.bleu_score import SmoothingFunction 6 | from functools import partial 7 | from multiprocessing.pool import Pool 8 | 9 | import src.utils as utils 10 | from src.generation_utils import self_bleu_one_sentence, get_bleu_weight_for_ngram 11 | from src.utils import tqdm 12 | 13 | # Inspired by https://github.com/ari-holtzman/degen/blob/master/metrics/self_bleu.py 14 | 15 | # Run time (serial): ~7 hours 16 | 17 | def main(): 18 | parser = 
utils.make_metrics_parser() 19 | args = parser.parse_args() 20 | rng = random.Random(args.seed) 21 | 22 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 23 | folder, filename = utils.get_save_filename_from_args(args) 24 | if not os.path.isfile(f'{save_directory}/generations/{folder}/sample_{filename}.p'): 25 | raise FileNotFoundError(f'Generations {save_directory}/generations/{folder}/sample_{filename}.p do not exist') 26 | 27 | savefilename = f'{save_directory}/metrics/{folder}/bleu_{filename}.p' 28 | if os.path.isfile(savefilename) and not args.force: 29 | print('Bleu metrics already computed. Exiting') 30 | return 31 | 32 | with open(f'{save_directory}/generations/{folder}/sample_{filename}.p', 'rb') as f: 33 | all_sentences = pkl.load(f)[0] 34 | smoothing_function = SmoothingFunction().method1 35 | 36 | start_time = time.time() 37 | if args.parallel_bleu: 38 | bleu_scores = compute_bleus_parallel(all_sentences, smoothing_function, rng, args) 39 | else: 40 | bleu_scores = compute_bleus_sequential(all_sentences, smoothing_function, rng, args) 41 | print('Total time for self bleu:', round(time.time() - start_time), 's') 42 | 43 | 44 | # save 45 | with open(savefilename, 'wb') as f: 46 | pkl.dump(bleu_scores, f) 47 | print(f'Done. Saved "{savefilename}". Bye!') 48 | 49 | 50 | def compute_bleus_sequential(all_sentences, smoothing_function, rng, args): 51 | bleu_scores = [] 52 | for n in range(1, 6): 53 | start_time = time.time() 54 | weights = get_bleu_weight_for_ngram(n) 55 | bleu_n_lst = [ 56 | self_bleu_one_sentence(weights, all_sentences, smoothing_function, i) 57 | for i in rng.sample(range(len(all_sentences)), args.n_sample_bleu) 58 | ] 59 | bleu_scores.append(sum(bleu_n_lst) / len(bleu_n_lst)) 60 | print(f'Total time for self bleu-{n}:', round(time.time() - start_time), 's') 61 | return bleu_scores 62 | 63 | 64 | def compute_bleus_parallel(all_sentences, smoothing_function, rng, args): 65 | pool = Pool(processes=min(args.n_proc_bleu, os.cpu_count())) 66 | bleu_scores = [] 67 | for n in range(1, 6): 68 | start_time = time.time() 69 | weights = get_bleu_weight_for_ngram(n) 70 | bleu_n_lst = list(tqdm( 71 | pool.imap_unordered( 72 | partial(self_bleu_one_sentence, weights, all_sentences, smoothing_function), 73 | rng.sample(range(len(all_sentences)), args.n_sample_bleu)), 74 | total=args.n_sample_bleu)) 75 | bleu_scores.append(sum(bleu_n_lst) / len(bleu_n_lst)) 76 | print(f'Total time for self bleu-{n}:', round(time.time() - start_time), 's') 77 | return bleu_scores 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /download_generations.md: -------------------------------------------------------------------------------- 1 | # Download Text Generations 2 | 3 | We provide the text generations from GPT-2 (various sizes) on the webtext datasets. 4 | These samples were used in much of the empirical evaluation in our [NeurIPS 2021 paper](https://arxiv.org/pdf/2102.01454.pdf) 5 | and the subsequent [longer version](https://arxiv.org/pdf/2212.14578.pdf) (under review as of June 2023). 6 | 7 | 8 | **Trigger Warning**: 9 | The generated text is sampled from GPT-2 models of various sizes. It could be biased, harmful, racist, sexist, toxic, and potentially upsetting. 10 | Please use at your own risk. 11 | Do not treat model outputs as substitutes for human judgment or as sources of truth. Please use responsibly. 
12 | 13 | 14 | ## Download the data 15 | The data can be found at [this Google Drive link](https://drive.google.com/file/d/1DlmEQ3zgaBMKDRA-Yu5VFD-xw0JvJOft/view?usp=sharing). 16 | The MD5 checksum of `mauve_generations.tgz` is `63bae977e3ce5f3c86d9e35188c1b8e6`. 17 | 18 | You can alternatively download it via the command line using [gdown](https://github.com/wkentaro/gdown) as 19 | 20 | ```bash 21 | pip install gdown # Install gdown if you do not have it 22 | file_id="1DlmEQ3zgaBMKDRA-Yu5VFD-xw0JvJOft" # ID of the file on Google Drive 23 | gdown https://drive.google.com/uc?id=${file_id} # Download the generations 24 | md5sum mauve_generations.tgz # verify that it is "63bae977e3ce5f3c86d9e35188c1b8e6" 25 | tar -zvxf mauve_generations.tgz # Uncompress the generations 26 | ``` 27 | 28 | This downloads `mauve_generations.tgz` whose size is 992M compressed and 2.1G uncompressed. 29 | 30 | ## Data format 31 | The folder structure is `mauve_generations/webtext_${model_name}/sample_test_p${p}_k${k}_t1.0_seed${seed}.p`, where 32 | * `model_name` is the name of the model and takes values from `['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']` corresponding to the four sizes of GPT-2 33 | * `p` is the top-p parameter of nucleus sampling, and takes values `[0.9, 0.92, 0.95, 0.99, 1.0]` 34 | * `k` is the top-k parameter for top-k sampling, and takes values `[0, 1]` (the latter for greedy decoding) 35 | * `seed` is the random seed and takes values in `[0, 1, 2, 3, 4]`. 36 | 37 | The generations are stored as Python Pickle archives. 38 | Follow the [recommended precautions](https://docs.python.org/3/library/pickle.html) when dealing with pickle archives and use at your own risk. 39 | 40 | ## Load the generations 41 | Each pickle file can be loaded as follows: 42 | ```python 43 | import pickle as pkl 44 | filename = "mauve_generations/webtext_gpt2-large/sample_test_p0.95_k0_t1.0_seed1.p" # Or choose your own 45 | with open(filename, "rb") as f: 46 | generations = pkl.load(f)[0] 47 | ``` 48 | 49 | The object `generations` is a list of length 5000, one for each example of the test set of webtext (available from [here](https://github.com/openai/gpt-2-output-dataset); see also the [README](https://github.com/krishnap25/mauve-experiments/blob/main/README.md)). 50 | Each entry is a list of integers, representing the BPE tokens used by GPT-2. To get the raw detokenized text, you can run 51 | 52 | ```python 53 | from transformers import GPT2Tokenizer 54 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 55 | print(tokenizer.decode(generations[0])) # => de-tokenized generations 56 | ``` 57 | 58 | ## Plugging the generations into the rest of the experimental pipeline 59 | 60 | The format of files written by the scripts of this repository is described in the [README](https://github.com/krishnap25/mauve-experiments/blob/main/README.md). 61 | We provide only the `sample*` files, while the `sentences*` and `feats*` files will still have to be created. 62 | 63 | To this end, first move each file from `mauve_generations/webtext_{model_name}/` to `./outputs/webtext_{model_name}/generations/basic/` (a short sketch of this move follows below). Then, follow the instructions [here](https://github.com/krishnap25/mauve-experiments/blob/main/README.md#experimental-pipeline). 64 | This will skip the generation and proceed to featurizing the samples directly (as enforced by [this check](https://github.com/krishnap25/mauve-experiments/blob/main/generate_basic.py#L43)).
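A minimal Python sketch of the move described above, assuming `mauve_generations/` and `./outputs/` both sit under the current working directory:

```python
import glob
import os
import shutil

# Move the downloaded sample_*.p files into the output tree expected
# by the experimental pipeline.
for model_name in ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]:
    src_dir = f"mauve_generations/webtext_{model_name}"
    dst_dir = f"./outputs/webtext_{model_name}/generations/basic"
    os.makedirs(dst_dir, exist_ok=True)
    for path in glob.glob(os.path.join(src_dir, "sample_*.p")):
        shutil.move(path, os.path.join(dst_dir, os.path.basename(path)))
```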
65 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mauve-experiments 2 | channels: 3 | - conda-forge 4 | - anaconda 5 | dependencies: 6 | - python=3.8 7 | - pip 8 | - numpy=1.19.2 9 | - scikit-learn=0.24.1 10 | - nltk=3.4.5 11 | - tqdm=4.40.0 12 | - requests 13 | - pip: 14 | - faiss-gpu==1.7.0 15 | -------------------------------------------------------------------------------- /generate_basic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys, os, time, pickle as pkl 3 | import time 4 | import torch 5 | 6 | import src.model_utils 7 | from src import utils, generation_utils as gen_utils 8 | import src.metrics 9 | 10 | if __name__ == '__main__': 11 | parser = utils.make_basic_parser() 12 | args = parser.parse_args() 13 | print(args) 14 | torch.manual_seed(args.seed) 15 | 16 | print('*********Prompt size =', args.prompt_size) 17 | 18 | if not args.use_large_feats: 19 | raise ValueError('Need to use large feats') 20 | 21 | # check if we have to run 22 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 23 | name = f'{args.datasplit}_p{args.top_p}_k{args.top_k}_t{args.temp}_seed{args.seed}' 24 | folder_name = f'{save_directory}/generations/basic' 25 | if os.path.isfile(f'{folder_name}/feats_{name}.pt'): 26 | print(f'File: {folder_name}/feats_{name}.pt already exists. Exiting') 27 | sys.exit(-1) 28 | else: 29 | print(f'File: {folder_name}/feats_{name}.pt does not exist. Proceeding with generation') 30 | 31 | 32 | device = utils.get_device_from_arg(args.device) 33 | print(f'Using device: {device}') 34 | 35 | model, tokenizer = utils.get_model_and_tokenizer(model_name=args.model_name, device=device) 36 | 37 | if args.max_len is None: 38 | args.max_len = tokenizer.model_max_length 39 | 40 | ds_tokens = utils.load_and_tokenize_data(tokenizer, args.data_dir, args.max_len, args.max_num_generations, 41 | min_len=args.prompt_size, split=args.datasplit) 42 | 43 | if os.path.isfile(f'{folder_name}/sample_{name}.p'): 44 | print(f'Undecoded samples: {folder_name}/sample_{name}.p already exists.
Skipping generation.') 45 | with open(f'{folder_name}/sample_{name}.p', 'rb') as f: 46 | samples, is_completed, unique_ngram_frac, ppl = pkl.load(f)[:4] 47 | samples_2 = [torch.LongTensor(x).view(1, -1).to(device) for x in samples] 48 | else: 49 | batch_size = gen_utils.get_default_batch_size(args.model_name, device) 50 | n_lst = [1, 2, 3, 4, 5, 6] 51 | 52 | sample_fn = gen_utils.create_sample_fn(model, args.max_len, 53 | top_p=args.top_p, top_k=args.top_k, temperature=args.temp 54 | ) 55 | t1 = time.time() 56 | samples, is_completed = gen_utils.get_samples_from_sample_fn( 57 | sample_fn, ds_tokens, tokenizer.eos_token_id, 58 | prompt_size=args.prompt_size, batch_size=batch_size 59 | ) 60 | t2 = time.time() 61 | print('sampling time:', round(t2-t1, 2)) 62 | unique_ngram_frac = src.metrics.get_unique_ngram_fraction(samples, n_lst) 63 | print('n-gram frac:', unique_ngram_frac) 64 | t1 = time.time() 65 | samples_2 = [torch.LongTensor(x).view(1, -1).to(device) for x in samples] 66 | ppl = src.metrics.get_perplexity_from_samples(model, samples_2) 67 | t2 = time.time() 68 | print('ppl time:', round(t2-t1, 2), ppl) 69 | 70 | output_file_name = f'{folder_name}/sample_{name}.p' # un-decoded samples 71 | with open(output_file_name, 'wb') as f: 72 | pkl.dump([samples, is_completed, unique_ngram_frac, ppl, args], f) 73 | 74 | # decode samples 75 | print('Decoding...') 76 | if os.path.isfile(f'{folder_name}/sentences_{name}.p'): 77 | print(f'Decoded samples: {folder_name}/sentences_{name}.p already exist. Skipping.') 78 | else: 79 | decoded_samples = utils.decode_samples_from_lst(tokenizer, samples) 80 | with open(f'{folder_name}/sentences_{name}.p', 'wb') as f: 81 | pkl.dump([decoded_samples, is_completed], f) 82 | 83 | # featurize samples 84 | print('Featurizing...') 85 | feats_prefix = '' 86 | if args.use_large_feats: 87 | del model 88 | model, _ = utils.get_model_and_tokenizer(model_name=args.featurize_model_name, device=device) 89 | for l in {128, 256, 512, args.max_len}: 90 | feats_prefix = f'L{l}' 91 | feats_out_fn = f'{folder_name}/feats{feats_prefix}_{name}.pt' 92 | if os.path.isfile(feats_out_fn): 93 | print(f'Feats {feats_out_fn} exists.
Skipping') 94 | continue 95 | else: 96 | print(f'Featurizing l = {l}...') 97 | samples_3 = [x[:, :l] for x in samples_2] 98 | feats = src.model_utils.featurize_sequential(model, samples_3) 99 | torch.save(feats, feats_out_fn) 100 | else: # use features from model 101 | feats = src.model_utils.featurize_sequential(model, samples_2) 102 | torch.save(feats, f'{folder_name}/feats_{name}.pt') 103 | 104 | -------------------------------------------------------------------------------- /generate_ref.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import src.model_utils 5 | from src import utils 6 | import src.metrics 7 | 8 | if __name__ == '__main__': 9 | parser = utils.make_basic_parser() 10 | args = parser.parse_args() 11 | print(args) 12 | torch.manual_seed(args.seed) 13 | 14 | if not args.use_large_feats: 15 | raise ValueError('Use large feats!') 16 | 17 | # check if we have to run 18 | save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' 19 | if args.ds_name is None: 20 | name = args.datasplit 21 | else: 22 | name = f'{args.ds_name}_{args.datasplit}' 23 | folder_name = f'{save_directory}/generations/ref' 24 | 25 | 26 | device = utils.get_device_from_arg(args.device) 27 | print(f'Using device: {device}') 28 | 29 | 30 | ###### OLD 31 | ## featurize samples 32 | # feats = src.model_utils.featurize_sequential(model, ds_tokens) 33 | # torch.save(feats, f'{folder_name}/feats_{name}.pt') 34 | 35 | 36 | feats_prefix = '' 37 | if args.use_large_feats: 38 | model, tokenizer = utils.get_model_and_tokenizer(model_name=args.featurize_model_name, device=device) 39 | ds_tokens = utils.load_and_tokenize_data(tokenizer, args.data_dir, args.max_len, args.max_num_generations, 40 | ds_name=args.ds_name, split=args.datasplit) 41 | for l in {128, 256, 512, args.max_len}: 42 | feats_prefix = f'L{l}' 43 | feats_out_fn = f'{folder_name}/feats{feats_prefix}_{name}.pt' 44 | if os.path.isfile(feats_out_fn): 45 | print(f'Feats {feats_out_fn} exists. Skipping') 46 | continue 47 | else: 48 | print(f'Featurizing l = {l}...') 49 | samples_3 = [x[:, :l] for x in ds_tokens] 50 | feats = src.model_utils.featurize_sequential(model, samples_3) 51 | torch.save(feats, feats_out_fn) 52 | else: # use features from model 53 | model, tokenizer = utils.get_model_and_tokenizer(model_name=args.model_name, device=device) 54 | ds_tokens = utils.load_and_tokenize_data(tokenizer, args.data_dir, 55 | args.max_len, args.max_num_generations, split=args.datasplit) 56 | feats = src.model_utils.featurize_sequential(model, ds_tokens) 57 | torch.save(feats, f'{folder_name}/feats_{name}.pt') 58 | 59 | -------------------------------------------------------------------------------- /human_evaluation.md: -------------------------------------------------------------------------------- 1 | # Human Evaluation Data 2 | 3 | The raw (anonymized) human evaluations we obtained can be downloaded 4 | [in this csv file](https://drive.google.com/file/d/1doC-NjhUlt4YDnK1qm23tyrfbogVlZ29/view?usp=sharing).
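A minimal sketch for a first look at the data, assuming pandas is installed (the filename is a placeholder for wherever you save the download); the columns it references are documented below:

```python
import pandas as pd

# "human_eval_data.csv" is a placeholder name for the downloaded CSV.
df = pd.read_csv("human_eval_data.csv")

# Response counts for "Which continuation is more likely to be
# written by a human?" (see the key to the Answer.q* fields below).
print(df["Answer.q3"].value_counts())
```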
5 | 6 | The columns are: 7 | - `HITId`: Integer indexing the row 8 | - `WorkerId`: Unique identifier of the crowd-worker 9 | - `WorkTimeInSeconds`: Amount of time the HIT was open on AMT 10 | - `Input.idx`: Index of the prompt 11 | - `Input.ctx`: Context/prompt that each completion is based upon 12 | - `Input.model_a`: Name of player A 13 | - `Input.completiona`: Completion generated by player A 14 | - `Input.len_a`: Total length (prompt + generation) of player A's text 15 | - `Input.model_b`: Name of player B 16 | - `Input.completionb`: Completion generated by player B 17 | - `Input.len_b`: Total length (prompt + generation) of player B's text 18 | - `Answer.q1`: Answer of crowd-worker to the question: "Which continuation is more interesting or creative, given the context?" 19 | - `Answer.q2`: Answer of crowd-worker to the question: "Which continuation makes more sense, given the context?" 20 | - `Answer.q3`: Answer of crowd-worker to the question: "Which continuation is more likely to be written by a human?" 21 | - `Answer.te`: Our (pessimistic) estimate of the amount of time the crowd-worker took to answer the question. 22 | 23 | 24 | Key to `Answer.q*` fields: The responses of the crowd-workers to each question are stored with the following key: 25 | - Definitely A: 2a 26 | - Slightly A: 1a 27 | - Tie: 1a 28 | - Slightly B: 1b 29 | - Definitely B: 2b 30 | 31 | Note that both "Tie" and "Slightly A" are recorded as `1a`. Since for each pair, the choice of A versus B is randomized, this amounts to randomly assigning each tie as a win to one of the two players. 32 | -------------------------------------------------------------------------------- /library/spreadingvectors/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /library/spreadingvectors/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to spreadingvectors 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue.
28 | 29 | ## License 30 | By contributing to spreadingvectors, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /library/spreadingvectors/LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation.
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. 
Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. 
For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the "Licensor." The text 394 | of the Creative Commons 395 | public licenses is dedicated to the public domain under the CC0 Public 396 | Domain Dedication. Except for the limited purpose of indicating that 397 | material is shared under a Creative Commons public license or as 398 | otherwise permitted by the Creative Commons policies published at 399 | creativecommons.org/policies, Creative Commons does not authorize the 400 | use of the trademark "Creative Commons" or any other trademark or logo 401 | of Creative Commons without its prior written consent including, 402 | without limitation, in connection with any unauthorized modifications 403 | to any of its public licenses or any other arrangements, 404 | understandings, or agreements concerning use of licensed material. For 405 | the avoidance of doubt, this paragraph does not form part of the 406 | public licenses. 407 | 408 | Creative Commons may be contacted at creativecommons.org. 409 | 410 | -------------------------------------------------------------------------------- /library/spreadingvectors/README.md: -------------------------------------------------------------------------------- 1 | # Spreading vectors for similarity search 2 | 3 | This is the open source implementation of the neural Catalyzer for similarity search.
4 | This code reproduces the results from the ICLR'2019 paper Spreading Vectors for Similarity Search. 5 | 6 | 7 | ## Install 8 | 9 | The basic install only requires NumPy and PyTorch 1.0: 10 | ```bash 11 | conda install numpy 12 | # See http://pytorch.org for details 13 | conda install pytorch -c pytorch 14 | ``` 15 | 16 | 17 | This code can run as is on a standard computer; if a GPU is present, it is detected and used automatically. 18 | 19 | ### (optional) GPU Faiss 20 | 21 | If you want to further accelerate the code, you can install [Faiss](https://github.com/facebookresearch/faiss) with GPU support: 22 | ```bash 23 | # Make sure you have CUDA installed before installing faiss-gpu; otherwise it falls back to the CPU version 24 | conda install faiss-gpu -c pytorch # [DEFAULT] For CUDA 8.0 25 | conda install faiss-gpu cuda90 -c pytorch # For CUDA 9.0 26 | conda install faiss-gpu cuda92 -c pytorch # For CUDA 9.2 27 | ``` 28 | 29 | ### (optional) Install the C lattice quantizer 30 | 31 | The lattice quantizer can be run much faster with custom C extensions. 32 | We provide a C implementation of the lattice quantizer, wrapped in Python using SWIG. 33 | First, you need to download and install SWIG from your system's package manager or from [the website](http://www.swig.org/download.html). 34 | 35 | The C code can then be compiled: 36 | ``` 37 | cd lattices 38 | make all 39 | ``` 40 | 41 | 42 | ## Evaluating a model 43 | 44 | To benchmark our method, we use the two standard benchmark datasets [BigANN](http://corpus-texmex.irisa.fr/) and [Deep1b](https://yadi.sk/d/11eDCm7Dsn9GA); see [here](https://github.com/facebookresearch/faiss/tree/master/benchs#getting-bigann) for more info on how to download them. 45 | You need to indicate the path to these in lib/data.py: 46 | ```python 47 | # lib/data.py 48 | def getBasedir(s): 49 | paths = { 50 | "bigann": "/path/to/bigann", 51 | "deep1b": "/path/to/deep1b" 52 | } 53 | 54 | return paths[s] 55 | ``` 56 | Note that for both Bigann and Deep1b, only the first 1M vectors of the dataset are used (hence they are called Deep1M and Bigann1M in the paper). 57 | To evaluate a trained model, run: 58 | ``` 59 | python eval.py --ckpt-path test.pth --quantizer zn_79 60 | ``` 61 | 62 | ### Pre-trained models 63 | 64 | We provide pre-trained models. 65 | The script [reproduce.sh](reproduce.sh) downloads the models and reproduces the paper's main results. 66 | 67 | ## Training a model 68 | 69 | Run training: 70 | ``` 71 | python train.py --num_learn 500000 --database bigann --lambda_uniform 0.02 --dint 1024 --dout 24 72 | ``` 73 | 74 | Typical output: 75 | ``` 76 | load dataset deep1b 77 | keeping 500000/357380000 training vectors 78 | computing training ground truth 79 | build network 80 | Lr schedule ... 81 | Forward pass 82 | Distances 83 | Train 84 | epoch 0, times: [hn 3.99 s epoch 55.19 s val 0.00 s] lr = 0.100000 loss = -0.00585795 = 0.00175652 + lam * -3.80723, offending 17773 85 | Forward pass 86 | Distances 87 | Train 88 | epoch 1, times: [hn 4.07 s epoch 57.41 s val 0.00 s] lr = 0.100000 loss = -0.0034838 = 0.00245264 + lam * -2.96822, offending 56211 89 | Forward pass 90 | Distances 91 | Train 92 | 93 | ....
94 | 95 | epoch 8, times: [hn 4.04 s epoch 55.10 s val 0.00 s] lr = 0.100000 loss = -0.00382894 = 0.00203354 + lam * -2.93124, offending 75412 96 | Forward pass 97 | Distances 98 | Train 99 | Valiation at epoch 9 100 | zn_3 nbit= 14: 0.0000 0.0003 0.0028 101 | zn_10 nbit= 32: 0.0009 0.0073 0.0437 102 | zn_79 nbit= 64: 0.0331 0.1581 0.4756 103 | storing test_ckpt/0.002/checkpoint.pth 104 | zn_79,rank=10 score improves (0.15814 > 0), keeping as best 105 | ``` 106 | 107 | Training uses a small part of the learn set, split between 500k training vectors and 1M+10k validation vectors (1M database, 10k queries). The rest of the data is unused. 108 | 109 | The ground-truth nearest neighbors are computed for the 500k training vectors; this is fast enough on a GPU. 110 | 111 | The stats that are logged are: 112 | - `lr`: the current learning rate, which depends on the type of schedule 113 | - `loss`: total_loss = triplet_loss + lambda * entropy_loss. 114 | - `offending`: the number of triplets that caused a non-zero loss (should be decreasing) 115 | - `times`: hard-negative mining time, training time, and validation time. 116 | 117 | Validation is performed every 10 epochs (by default). 118 | For a few quantizers (selected with --quantizers) it performs the search on a validation set and reports the 1-recall at ranks 1, 10, and 100. 119 | Then it keeps the best model based on one of the evaluated quantizers (zn_79 in this case: the Zn lattice with r^2 = 79). 120 | 121 | Training for 160 epochs takes less than 3 hours on a P100 GPU, but 90% of the final performance should already be reached in only 10 epochs (around 10 minutes). 122 | 123 | ### Cross-validation 124 | 125 | The only parameter that we cross-validated is the lambda. 126 | The script [crossvalidate.sh](crossvalidate.sh) does the grid search and tests on the best result. 127 | It runs the grid search sequentially; for a faster result, it is worthwhile to run it on a cluster of machines. 128 | A typical output is [this gist](https://gist.github.com/mdouze/bd34ceb6b17c3616e0b4e6a45e387cb7), which corresponds to the line "Catalyzer + lattice" of table 1 in the paper. 129 | 130 | ## Zn quantizer 131 | 132 | The spherical Zn quantizer uses as its codebook the points of the hypersphere of radius r that have integer coordinates. We provide here the common (squared) radii that correspond to 16, 32, and 64 bits for commonly used dimensions. 133 | 134 | 135 | | d | 16 bits | 32 bits | 64 bits | 136 | |----|--------:|--------:|--------:| 137 | | 24 | 3 | 10 | 79 | 138 | | 32 | 3 | 8 | 36 | 139 | | 40 | 2 | 7 | 24 | 140 | 141 | To find out the number of bits needed to encode the vertices of a sphere in 48 dimensions with squared radius 30, use: 142 | ```python 143 | from lattices.Zn_lattice import ZnCodec 144 | import math 145 | 146 | d, r2 = 48, 30 147 | # number of distinct vertices 148 | nv = ZnCodec(d, r2).nv 149 | 150 | # number of bits needed to encode this 151 | nbit = math.ceil(math.log2(nv)) 152 | ``` 153 | In this case, nv = 311097167722066085728512 and nbit = 79. 154 | 155 | ### Search performance 156 | 157 | A typical use case for the lattice quantizer is to perform asymmetric nearest-neighbor searches. 158 | In that case, a set of n vectors is encoded. 159 | At search time, a query vector x is compared to each of the encoded vectors. 160 | Thus, each vector is decoded to y and the distance to x is computed. 161 | The nearest vector id is then returned. 162 | In general, this is done simultaneously for a batch of query vectors.
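To make the asymmetric procedure concrete, here is a minimal sketch using the `ZnCodec` API from `lattices/Zn_lattice.py` (included later in this repository). It assumes the C extension has been built (`make all` in `lattices/`), since the pure-Python fallback only implements `quantize`; the data here is random and purely illustrative.

```python
import numpy as np
from lattices.Zn_lattice import ZnCodec

d, r2 = 24, 79  # 64-bit codes, per the table above
codec = ZnCodec(d, r2)

# random unit-norm database and query vectors, for illustration only
rs = np.random.RandomState(0)
xb = rs.randn(10000, d).astype('float32')
xb /= np.linalg.norm(xb, axis=1, keepdims=True)
xq = rs.randn(100, d).astype('float32')
xq /= np.linalg.norm(xq, axis=1, keepdims=True)

# compression: one uint64 code per database vector
codes = codec.encode(xb)

# asymmetric search: each code is decoded and compared against the
# uncompressed queries; find_nn returns ids and dot products
ids, dis = codec.find_nn(codes, xq)
```

This mirrors the description above: the database lives in compressed form, and only the decoder runs at query time.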
163 | 164 | The benchmark [bench_Zn_decoder.py](lattices/bench_Zn_decoder.py) performs this operation for 1M database vectors and 1k queries. 165 | It compares the performance with that of a Faiss PQ index. 166 | Typical result in [this gist](https://gist.github.com/mdouze/0b3ae8c88ba62aae234cdb8507164934): the lattice decoder is a bit slower than PQ. 167 | This is understandable because PQ performs comparisons in the compressed domain. 168 | 169 | 170 | ## Adding datasets 171 | 172 | We provide dataloaders for the standard BigANN and Deep1b. 173 | Our code can easily be extended to other datasets. 174 | First, add the path to your dataset in lib/data.py: 175 | ```python 176 | # lib/data.py 177 | def getBasedir(s): 178 | paths = { 179 | "bigann": "/path/to/bigann", 180 | "deep1b": "/path/to/deep1b", 181 | "my_data": "/path/to/mydata" 182 | } 183 | 184 | return paths[s] 185 | ``` 186 | 187 | Then, modify lib/data.py to handle the loading of your dataset: 188 | ```python 189 | # lib/data.py 190 | 191 | def load_mydata(device, size=10 ** 6, test=True, qsize=10 ** 5): 192 | basedir = getBasedir("my_data") 193 | 194 | # Example code to load your data 195 | xt = np.load(join(basedir, "my_trainingset.npy")) 196 | if test: 197 | xb = np.load(join(basedir, "my_database.npy")) 198 | xq = np.load(join(basedir, "my_queries.npy")) 199 | else: 200 | xb = xt[:size] 201 | xq = xt[size:size+qsize] 202 | xt = xt[size+qsize:] 203 | gt = get_nearestneighbors(xq, xb, 100, device) 204 | 205 | return xt, xb, xq, gt 206 | 207 | 208 | def load_dataset(name, device, size=10**6, test=True): 209 | # ... 210 | elif name == "my_data": 211 | return load_mydata(device, size=size, test=test) 212 | 213 | ``` 214 | 215 | 216 | ## License 217 | 218 | This repository is licensed under the CC BY-NC 4.0 license. 219 | -------------------------------------------------------------------------------- /library/spreadingvectors/crossvalidate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) 2015-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the CC-by-NC license found in the 7 | # LICENSE file in the root directory of this source tree.
8 | # 9 | set -ex 10 | 11 | # try these values of lambda 12 | lambdas="0.05 0.1 0.2 0.5 1.0 2.0 5.0 10.0 20.0 50.0 100.0" 13 | 14 | dout=4 # output dimension 15 | db="valid_feats" # use this dataset 16 | quant=zn_50 # cross-validate using this quantizer 17 | best_lambda=-1 18 | best_perf="0.0000" 19 | 20 | for lambda in $lambdas; do 21 | mkdir -p test_ckpt/$lambda 22 | time python -u train.py \ 23 | --dout $dout \ 24 | --database $db \ 25 | --lambda_uniform $lambda \ 26 | --checkpoint_dir test_ckpt/$lambda \ 27 | > >(tee -a test_ckpt/$lambda.stdout) 2> >(tee -a test_ckpt/$lambda.log) 28 | 29 | # extract validation accuracy from the "keeping as best" log line 30 | perf=$(tac test_ckpt/$lambda.stdout | 31 | grep -m1 'keeping as best' | 32 | grep -o '(.*>' | grep -o '[0-9\.]*') 33 | 34 | echo $perf 35 | # keep the lambda with the highest validation accuracy 36 | if [[ "$perf" > "$best_perf" ]]; then 37 | best_perf=$perf 38 | best_lambda=$lambda 39 | fi 40 | done 41 | 42 | echo "Best value of lambda: $best_lambda" 43 | 44 | python eval.py \ 45 | --database $db \ 46 | --quantizer $quant \ 47 | --ckpt-path test_ckpt/$best_lambda/checkpoint.pth.best 48 | -------------------------------------------------------------------------------- /library/spreadingvectors/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import division 8 | try: 9 | import faiss 10 | except: 11 | pass 12 | import numpy as np 13 | import torch 14 | import argparse 15 | import os 16 | import time 17 | from lib.metrics import evaluate, evaluate_k 18 | from lib.net import Normalize 19 | join = os.path.join 20 | import torch.nn as nn 21 | from lib.data import load_dataset 22 | 23 | 24 | if __name__ == "__main__": 25 | global args 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--candidates", type=int, default=10) 29 | parser.add_argument("--ckpt-path", type=str, required=True) 30 | parser.add_argument("--database", type=str) #, choices=["bigann", "deep1b"]) 31 | parser.add_argument("--device", choices=["cpu", "cuda", "auto"], 32 | default="auto") 33 | parser.add_argument("--gpu", action='store_true', default=False) 34 | parser.add_argument("--quantizer", required=True) 35 | parser.add_argument("--size-base", type=int, default=int(1e6)) 36 | parser.add_argument("--val", action='store_false', dest='test') 37 | parser.set_defaults(gpu=False, test=True) 38 | 39 | args = parser.parse_args() 40 | if args.device == "auto": 41 | args.device = "cuda" if torch.cuda.is_available() else "cpu" 42 | 43 | start = time.time() 44 | if os.path.exists(args.ckpt_path): 45 | print("Loading net") 46 | ckpt = torch.load(args.ckpt_path) 47 | d = vars(args) 48 | for k, v in vars(ckpt['args']).items(): 49 | d[k] = v 50 | 51 | (xt, xb, xq, gt) = load_dataset(args.database, args.device, size=args.size_base, test=args.test) 52 | dim = xb.shape[1] 53 | dint, dout = args.dint, args.dout 54 | 55 | net = nn.Sequential( 56 | nn.Linear(in_features=dim, out_features=dint, bias=True), 57 | nn.BatchNorm1d(dint), 58 | nn.ReLU(), 59 | nn.Linear(in_features=dint, out_features=dint, bias=True), 60 | nn.BatchNorm1d(dint), 61 | nn.ReLU(), 62 | nn.Linear(in_features=dint, out_features=dout, bias=True), 63 | Normalize() 64 | ) 65 | net.load_state_dict(ckpt['state_dict']) 66 | net = net.to(args.device) 67 | net = net.eval() 68 | 69 | elif
args.ckpt_path.startswith("pca-"): 70 | assert args.database is not None 71 | (xt, xb, xq, gt) = load_dataset(args.database, args.device, size=args.size_base, test=args.test) 72 | args.dim = int(args.ckpt_path[4:]) 73 | 74 | mu = np.mean(xb, axis=0, keepdims=True) 75 | xb -= mu 76 | xq -= mu 77 | 78 | cov = np.dot(xb.T, xb) / xb.shape[0] 79 | eigvals, eigvecs = np.linalg.eig(cov) 80 | o = eigvals.argsort()[::-1] 81 | PCA = eigvecs[:, o[:args.dim]].astype(np.float32) 82 | 83 | xb = np.dot(xb, PCA) 84 | xb /= np.linalg.norm(xb, axis=1, keepdims=True) 85 | xq = np.dot(xq, PCA) 86 | xq /= np.linalg.norm(xq, axis=1, keepdims=True) 87 | net = nn.Sequential() 88 | else: 89 | print("Main argument not understood: should be the path to a net checkpoint") 90 | import sys;sys.exit(1) 91 | 92 | evaluate_k(net, xb, [args.quantizer], device=args.device) 93 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/Makefile: -------------------------------------------------------------------------------- 1 | 2 | 3 | .SUFFIXES: .cpp .o .cxx 4 | 5 | 6 | CPPFLAGS=-fPIC -Wall -Wno-sign-compare -g -O3 -mavx -std=c++11 -fopenmp 7 | 8 | all: _c_lattices.so 9 | 10 | SWIG=swig 11 | 12 | c_lattices_wrap.cxx: c_lattices.swig lattice_utils.h lattice_Zn.h 13 | $(SWIG) -c++ -python c_lattices.swig 14 | 15 | c_lattices_wrap.o: c_lattices_wrap.cxx lattice_utils.h lattice_Zn.h 16 | 17 | lattice_Zn.o: lattice_utils.h lattice_Zn.h 18 | 19 | .cxx.o: 20 | g++ $(CPPFLAGS) $(EXTRACFLAGS) -c $< 21 | 22 | .cpp.o: 23 | g++ $(CPPFLAGS) $(EXTRACFLAGS) -c $< 24 | 25 | c_lattices_wrap.o: EXTRACFLAGS= \ 26 | -I $(shell python -c "import distutils.sysconfig; print(distutils.sysconfig.get_python_inc())" ) \ 27 | -I $(shell python -c "import numpy ; print(numpy.get_include())") 28 | 29 | # linux-specific link line 30 | _c_lattices.so: c_lattices_wrap.o lattice_utils.o lattice_Zn.o 31 | g++ -g -shared -fopenmp -o $@ $^ 32 | 33 | clean: 34 | rm -f *.o _c_lattices.so c_lattices_wrap.cxx c_lattices.py 35 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/Zn_lattice.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import numpy as np 8 | 9 | try: 10 | from lattices import c_lattices 11 | swig_ptr = c_lattices.swig_ptr 12 | except ImportError: 13 | c_lattices = None 14 | 15 | 16 | class Comb: 17 | """ a Pascal triangle """ 18 | def __init__(self, npas=100): 19 | pascal = [[1] for i in range(npas)] 20 | for i in range(npas): 21 | for j in range(1, i + 1): 22 | pascal[i].append(pascal[i - 1][j] + pascal[i - 1][j - 1]) 23 | pascal[i].append(0) 24 | self.pascal = pascal 25 | 26 | def __call__(self, n, k): 27 | return self.pascal[n][k] if k <= n else 0 28 | 29 | # function to compute a binomial coefficient 30 | if c_lattices is None: 31 | comb = Comb() 32 | else: 33 | comb = c_lattices.cvar.comb 34 | 35 | def count_comb(x): 36 | """count number of distinct permutations of an array that contains 37 | duplicate values""" 38 | x = np.array(x) 39 | n = len(x) 40 | accu = 1 41 | for v in np.unique(x): 42 | nv = int((x == v).sum()) 43 | accu *= int(comb(n, nv)) 44 | n -= nv 45 | # bits used for signs 46 | accu *= 2 ** int((x != 0).sum()) 47 | return accu 48 | 49 | 50 | def sum_of_sq(total, v, n): 51 | """find all positive integer vectors of size n: 52 | - whose squared elements sum to total 53 | - maximium value is v 54 | """ 55 | 56 | if total < 0: 57 | return [] 58 | elif total == 0: 59 | return [[0] * n] 60 | elif n == 1: 61 | while v * v > total: 62 | v -= 1 63 | if v * v == total: 64 | return [[v]] 65 | else: 66 | return [] 67 | else: 68 | res = [] 69 | for vi in range(v, -1, -1): 70 | res += [[vi] + vv for vv in 71 | sum_of_sq(total - vi * vi, vi, n - 1)] 72 | return res 73 | 74 | 75 | def compute_atoms(d, r2): 76 | """Find atoms that define the Zn sphere of dimension d and squared 77 | radius r2""" 78 | v = int(1 + np.sqrt(r2)) # max value of a component 79 | if c_lattices is None: 80 | atoms = sum_of_sq(r2, v, d) 81 | return np.array(atoms) 82 | else: 83 | atoms = c_lattices.sum_of_sq(r2, v, d) 84 | return c_lattices.vector_to_array(atoms).reshape(-1, d) 85 | 86 | 87 | 88 | class ZnCodecC: 89 | 90 | def __init__(self, d, r2): 91 | self.znc = c_lattices.ZnSphereCodec(d, r2) 92 | atoms = c_lattices.vector_to_array(self.znc.voc) 93 | atoms = atoms.reshape(-1, d) 94 | # recompute instead of using self.znc.nv because it is limited 95 | # to 64 bit 96 | self.nv = sum([count_comb(atom) for atom in atoms]) 97 | self.code_size = self.znc.code_size 98 | 99 | if d & (d - 1) == 0: 100 | # d is a power of 2. 
Then we can use a ZnSphereCodecRec as 101 | # codec (faster for decoding) 102 | self.znc_rec = c_lattices.ZnSphereCodecRec(d, r2) 103 | else: 104 | self.znc_rec = None 105 | 106 | 107 | def quantize(self, x): 108 | x = np.ascontiguousarray(x, dtype='float32') 109 | n, d = x.shape 110 | assert d == self.znc.dim 111 | c = np.empty((n, d), dtype='float32') 112 | dps = np.empty(n, dtype='float32') 113 | self.znc.search_multi(n, 114 | swig_ptr(x), swig_ptr(c), 115 | swig_ptr(dps)) 116 | return c 117 | 118 | def encode(self, x): 119 | assert self.nv < 2 ** 64 120 | n, d = x.shape 121 | assert d == self.znc.dim 122 | codes = np.empty(n, dtype='uint64') 123 | if not self.znc_rec: 124 | self.znc.encode_multi(n, swig_ptr(x), 125 | swig_ptr(codes)) 126 | else: 127 | # first quantizer then encode 128 | centroids = self.quantize(x) 129 | self.znc_rec.encode_multi(n, swig_ptr(centroids), 130 | swig_ptr(codes)) 131 | return codes 132 | 133 | def decode(self, codes): 134 | n, = codes.shape 135 | x = np.empty((n, self.znc.dim), dtype='float32') 136 | decoder = self.znc_rec or self.znc 137 | decoder.decode_multi(n, swig_ptr(codes), 138 | swig_ptr(x)) 139 | return x 140 | 141 | def find_nn(self, codes, xq): 142 | """ find the nearest code of each vector of xq 143 | (returns dot products, not distances) 144 | """ 145 | assert self.nv < 2 ** 64 146 | nc, = codes.shape 147 | nq, d = xq.shape 148 | assert d == self.znc.dim 149 | ids = np.empty(nq, dtype='int64') 150 | dis = np.empty(nq, dtype='float32') 151 | decoder = self.znc_rec or self.znc 152 | decoder.find_nn(nc, swig_ptr(codes), nq, swig_ptr(xq), 153 | swig_ptr(ids), swig_ptr(dis)) 154 | 155 | return ids, dis 156 | 157 | 158 | 159 | class ZnCodecPy: 160 | 161 | def __init__(self, d, r2): 162 | self.atoms = compute_atoms(d, r2) 163 | self.atoms = np.sort(self.atoms, axis=1) 164 | self.nv = sum([count_comb(atom) for atom in self.atoms]) 165 | 166 | def quantize(self, x): 167 | n, d = x.shape 168 | assert d == self.atoms.shape[1] 169 | x_abs = np.abs(x) 170 | x_mod = np.sort(x_abs, axis=1) 171 | x_order = np.argsort(np.argsort(x_abs, axis=1), axis=1) 172 | matches = np.argmax(np.dot(x_mod, self.atoms.T), axis=1) 173 | x_recons = self.atoms[matches] 174 | q_abs = x_recons[np.tile(np.arange(n).reshape(-1, 1), d), x_order] 175 | q = q_abs * np.sign(x) 176 | 177 | return q.astype('float32') 178 | 179 | if c_lattices is None: 180 | ZnCodec = ZnCodecPy 181 | else: 182 | ZnCodec = ZnCodecC 183 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/krishnap25/mauve-experiments/d1753fdf09396606defe5fa5749f4a7e8fe24c96/library/spreadingvectors/lattices/__init__.py -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/bench_Zn_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import numpy as np 8 | from lattices import Zn_lattice 9 | import time 10 | import faiss 11 | import sys 12 | 13 | ## single-thread benchmark 14 | faiss.omp_set_num_threads(1) 15 | 16 | 17 | ## all data is random, we are not looking at correctness here 18 | if True: 19 | dim = 32 20 | r2 = 36 21 | else: 22 | dim = 24 23 | r2 = 79 24 | 25 | codec = Zn_lattice.ZnCodec(dim, r2) 26 | 27 | 28 | # set number of queries 29 | nb = 10**6 30 | nq = 1000 31 | k = 1 32 | 33 | print("nb=%d nq=%d" % (nb, nq)) 34 | 35 | # sample queries 36 | rs = np.random.RandomState(123) 37 | xq = rs.randn(nq, dim).astype('float32') 38 | 39 | 40 | print("init dim=%d r2=%d" % (dim, r2)) 41 | 42 | codes = rs.randint(1<<31, size=nb).astype('uint64') 43 | 44 | print("code_size=%d nv=%d %.2f bits" % ( 45 | codec.code_size, codec.nv, np.log2(codec.nv))) 46 | assert codec.code_size == 8 47 | 48 | dis = np.empty((nq, k), dtype='float32') 49 | labels = np.empty((nq, k), dtype='int64') 50 | 51 | t0 = time.time() 52 | codec.find_nn(codes, xq) 53 | t1 = time.time() 54 | 55 | print ("time for code_size=%d nq=%d nb=%d: %.3f s (%.3f ms/query)" % ( 56 | codec.code_size, nq, nb, t1 - t0, 57 | (t1 - t0) * 1000 / nq)) 58 | 59 | index = faiss.IndexPQ(dim, 8, 8) 60 | 61 | xb = rs.randn(nb, dim).astype('float32') 62 | print("train") 63 | index.train(xb) 64 | print("add") 65 | index.add(xb) 66 | 67 | t0 = time.time() 68 | index.search(xq, 1) 69 | t1 = time.time() 70 | 71 | print ("time for IndexPQ code_size=%d nq=%d nb=%d: %.3f s (%.3f ms/query)" % ( 72 | index.pq.code_size, nq, nb, t1 - t0, 73 | (t1 - t0) * 1000 / nq)) 74 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/c_lattices.swig: -------------------------------------------------------------------------------- 1 | // -*- c++ -*- 2 | /** 3 | * Copyright (c) 2015-present, Facebook, Inc. 4 | * All rights reserved. 5 | * 6 | * This source code is licensed under the BSD+Patents license found in the 7 | * LICENSE file in the root directory of this source tree. 
8 | */ 9 | %module c_lattices 10 | 11 | // release GIL by default for all functions 12 | %exception { 13 | Py_BEGIN_ALLOW_THREADS 14 | $action 15 | Py_END_ALLOW_THREADS 16 | } 17 | 18 | %{ 19 | #include 20 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 21 | #include 22 | #include "lattice_Zn.h" 23 | %} 24 | 25 | // simplified interface for vector 26 | namespace std { 27 | 28 | template 29 | class vector { 30 | public: 31 | vector(); 32 | ~vector(); 33 | void push_back(T); 34 | void clear(); 35 | T * data(); 36 | size_t size(); 37 | T at (size_t n) const; 38 | void resize (size_t n); 39 | }; 40 | }; 41 | 42 | 43 | %template(FloatVector) std::vector; 44 | %template(RepeatVector) std::vector; 45 | 46 | 47 | typedef unsigned long uint64_t; 48 | typedef unsigned int int32_t; 49 | typedef unsigned char uint8_t; 50 | 51 | %include "lattice_utils.h" 52 | %include "lattice_Zn.h" 53 | 54 | 55 | %inline %{ 56 | // make sure all these pointer types are instanciated 57 | void dummy_func(double * , int *, unsigned char *, long * , unsigned long *) { 58 | } 59 | 60 | %} 61 | 62 | %{ 63 | PyObject *swig_ptr (PyObject *a) 64 | { 65 | if(!PyArray_Check(a)) { 66 | PyErr_SetString(PyExc_ValueError, "input not a numpy array"); 67 | return NULL; 68 | } 69 | PyArrayObject *ao = (PyArrayObject *)a; 70 | 71 | if(!PyArray_ISCONTIGUOUS(ao)) { 72 | PyErr_SetString(PyExc_ValueError, "array is not C-contiguous"); 73 | return NULL; 74 | } 75 | void * data = PyArray_DATA(ao); 76 | if(PyArray_TYPE(ao) == NPY_FLOAT32) { 77 | return SWIG_NewPointerObj(data, SWIGTYPE_p_float, 0); 78 | } 79 | if(PyArray_TYPE(ao) == NPY_FLOAT64) { 80 | return SWIG_NewPointerObj(data, SWIGTYPE_p_double, 0); 81 | } 82 | if(PyArray_TYPE(ao) == NPY_INT32) { 83 | return SWIG_NewPointerObj(data, SWIGTYPE_p_int, 0); 84 | } 85 | if(PyArray_TYPE(ao) == NPY_UINT8) { 86 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_char, 0); 87 | } 88 | if(PyArray_TYPE(ao) == NPY_UINT64) { 89 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long, 0); 90 | } 91 | if(PyArray_TYPE(ao) == NPY_INT64) { 92 | return SWIG_NewPointerObj(data, SWIGTYPE_p_long, 0); 93 | } 94 | PyErr_SetString(PyExc_ValueError, "did not recognize array type"); 95 | return NULL; 96 | } 97 | 98 | 99 | %} 100 | 101 | void omp_set_num_threads(int ); 102 | void *memcpy(void *dest, const void *src, size_t n); 103 | 104 | %exception; 105 | 106 | %exception; 107 | 108 | 109 | %init %{ 110 | /* needed, else crash at runtime */ 111 | import_array(); 112 | %} 113 | 114 | // return a pointer usable as input for functions that expect pointers 115 | PyObject *swig_ptr (PyObject *a); 116 | 117 | %pythoncode %{ 118 | 119 | import numpy as np 120 | 121 | vector_name_map = { 122 | 'Float': 'float32', 123 | 'Byte': 'uint8', 124 | 'Uint64': 'uint64', 125 | 'Long': 'int64', 126 | 'Int': 'int32', 127 | 'Double': 'float64' 128 | } 129 | 130 | def vector_to_array(v): 131 | """ convert a C++ vector to a numpy array """ 132 | classname = v.__class__.__name__ 133 | assert classname.endswith('Vector') 134 | dtype = np.dtype(vector_name_map[classname[:-6]]) 135 | a = np.empty(v.size(), dtype=dtype) 136 | memcpy(swig_ptr(a), v.data(), a.nbytes) 137 | return a 138 | 139 | 140 | def vector_float_to_array(v): 141 | return vector_to_array(v) 142 | 143 | 144 | def copy_array_to_vector(a, v): 145 | """ copy a numpy array to a vector """ 146 | n, = a.shape 147 | classname = v.__class__.__name__ 148 | assert classname.endswith('Vector') 149 | dtype = np.dtype(vector_name_map[classname[:-6]]) 150 | assert dtype == 
a.dtype, ( 151 | 'cannot copy a %s array to a %s (should be %s)' % ( 152 | a.dtype, classname, dtype)) 153 | v.resize(n) 154 | memcpy(v.data(), swig_ptr(a), a.nbytes) 155 | 156 | %} 157 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/lattice_Zn.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD+Patents license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | #include "lattice_Zn.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | 22 | void VectorCodec::encode_multi(size_t n, const float *c, 23 | uint64_t * codes) const 24 | { 25 | #pragma omp parallel if (n > 1000) 26 | { 27 | #pragma omp for 28 | for(int i = 0; i < n; i++) { 29 | codes[i] = encode(c + i * dim); 30 | } 31 | } 32 | } 33 | 34 | 35 | void VectorCodec::decode_multi(size_t n, const uint64_t * codes, 36 | float *c) const 37 | { 38 | #pragma omp parallel if (n > 1000) 39 | { 40 | #pragma omp for 41 | for(int i = 0; i < n; i++) { 42 | decode(codes[i], c + i * dim); 43 | } 44 | } 45 | } 46 | 47 | void VectorCodec::find_nn ( 48 | size_t nc, const uint64_t * codes, 49 | size_t nq, const float *xq, 50 | long *labels, float *distances) 51 | { 52 | for (long i = 0; i < nq; i++) { 53 | distances[i] = -1e20; 54 | labels[i] = -1; 55 | } 56 | 57 | float c[dim]; 58 | for(long i = 0; i < nc; i++) { 59 | uint64_t code = codes[i]; 60 | decode(code, c); 61 | for (long j = 0; j < nq; j++) { 62 | const float *x = xq + j * dim; 63 | float dis = fvec_inner_product(x, c, dim); 64 | if (dis > distances[j]) { 65 | distances[j] = dis; 66 | labels[j] = i; 67 | } 68 | } 69 | } 70 | 71 | } 72 | 73 | 74 | /********************************************************** 75 | * ZnSphereSearch 76 | **********************************************************/ 77 | 78 | 79 | ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) { 80 | voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim); 81 | natom = voc.size() / dim; 82 | } 83 | 84 | float ZnSphereSearch::search(const float *x, float *c) const { 85 | float tmp[dimS * 2]; 86 | int tmp_int[dimS]; 87 | return search(x, c, tmp, tmp_int); 88 | } 89 | 90 | float ZnSphereSearch::search(const float *x, float *c, 91 | float *tmp, // size 2 *dim 92 | int *tmp_int, // size dim 93 | int *ibest_out 94 | ) const { 95 | int dim = dimS; 96 | assert (natom > 0); 97 | int *o = tmp_int; 98 | float *xabs = tmp; 99 | float *xperm = tmp + dim; 100 | 101 | // argsort 102 | for (int i = 0; i < dim; i++) { 103 | o[i] = i; 104 | xabs[i] = fabsf(x[i]); 105 | } 106 | std::sort(o, o + dim, [xabs](int a, int b) { 107 | return xabs[a] > xabs[b]; 108 | }); 109 | for (int i = 0; i < dim; i++) { 110 | xperm[i] = xabs[o[i]]; 111 | } 112 | // find best 113 | int ibest = -1; 114 | float dpbest = -100; 115 | for (int i = 0; i < natom; i++) { 116 | float dp = fvec_inner_product (voc.data() + i * dim, xperm, dim); 117 | if (dp > dpbest) { 118 | dpbest = dp; 119 | ibest = i; 120 | } 121 | } 122 | // revert sort 123 | const float *cin = voc.data() + ibest * dim; 124 | for (int i = 0; i < dim; i++) { 125 | c[o[i]] = copysignf (cin[i], x[o[i]]); 126 | } 127 | if (ibest_out) 128 | *ibest_out = ibest; 129 | return dpbest; 130 | } 131 | 132 | void ZnSphereSearch::search_multi(int n, const
float *x, 133 | float *c_out, 134 | float *dp_out) { 135 | #pragma omp parallel if (n > 1000) 136 | { 137 | #pragma omp for 138 | for(int i = 0; i < n; i++) { 139 | dp_out[i] = search(x + i * dimS, c_out + i * dimS); 140 | } 141 | } 142 | } 143 | 144 | 145 | /********************************************************** 146 | * ZnSphereCodec 147 | **********************************************************/ 148 | 149 | ZnSphereCodec::ZnSphereCodec(int dim, int r2): 150 | ZnSphereSearch(dim, r2), 151 | VectorCodec(dim) 152 | { 153 | nv = 0; 154 | for (int i = 0; i < natom; i++) { 155 | Repeats repeats(dim, &voc[i * dim]); 156 | CodeSegment cs(repeats); 157 | cs.c0 = nv; 158 | Repeat &br = repeats.repeats.back(); 159 | cs.signbits = br.val == 0 ? dim - br.n : dim; 160 | code_segments.push_back(cs); 161 | nv += repeats.count() << cs.signbits; 162 | } 163 | 164 | uint64_t nvx = nv; 165 | code_size = 0; 166 | while (nvx > 0) { 167 | nvx >>= 8; 168 | code_size++; 169 | } 170 | } 171 | 172 | uint64_t ZnSphereCodec::search_and_encode(const float *x) const { 173 | float tmp[dim * 2]; 174 | int tmp_int[dim]; 175 | int ano; // atom number 176 | float c[dim]; 177 | search(x, c, tmp, tmp_int, &ano); 178 | uint64_t signs = 0; 179 | float cabs[dim]; 180 | int nnz = 0; 181 | for (int i = 0; i < dim; i++) { 182 | cabs[i] = fabs(c[i]); 183 | if (c[i] != 0) { 184 | if (c[i] < 0) 185 | signs |= 1UL << nnz; 186 | nnz ++; 187 | } 188 | } 189 | const CodeSegment &cs = code_segments[ano]; 190 | assert(nnz == cs.signbits); 191 | uint64_t code = cs.c0 + signs; 192 | code += cs.encode(cabs) << cs.signbits; 193 | return code; 194 | } 195 | 196 | uint64_t ZnSphereCodec::encode(const float *x) const 197 | { 198 | return search_and_encode(x); 199 | } 200 | 201 | 202 | void ZnSphereCodec::decode(uint64_t code, float *c) const { 203 | int i0 = 0, i1 = natom; 204 | while (i0 + 1 < i1) { 205 | int imed = (i0 + i1) / 2; 206 | if (code_segments[imed].c0 <= code) i0 = imed; 207 | else i1 = imed; 208 | } 209 | const CodeSegment &cs = code_segments[i0]; 210 | code -= cs.c0; 211 | uint64_t signs = code; 212 | code >>= cs.signbits; 213 | cs.decode(code, c); 214 | 215 | int nnz = 0; 216 | for (int i = 0; i < dim; i++) { 217 | if (c[i] != 0) { 218 | if (signs & (1UL << nnz)) 219 | c[i] = -c[i]; 220 | nnz ++; 221 | } 222 | } 223 | } 224 | 225 | 226 | /************************************************************** 227 | * ZnSphereCodecRec 228 | **************************************************************/ 229 | 230 | uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const 231 | { 232 | return all_nv[ld * (r2 + 1) + r2a]; 233 | } 234 | 235 | 236 | uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const 237 | { 238 | return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a]; 239 | } 240 | 241 | void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) 242 | { 243 | all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum; 244 | } 245 | 246 | 247 | ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2): 248 | VectorCodec(dim), r2(r2) 249 | { 250 | log2_dim = 0; 251 | while (dim > (1 << log2_dim)) 252 | log2_dim++; 253 | assert(dim == (1 << log2_dim) || 254 | !"dimension must be a power of 2"); 255 | 256 | all_nv.resize((log2_dim + 1) * (r2 + 1)); 257 | all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1)); 258 | 259 | for (int r2a = 0; r2a <= r2; r2a++) { 260 | int r = int(sqrt(r2a)); 261 | if (r * r == r2a) { 262 | all_nv[r2a] = r == 0 ? 
1 : 2; 263 | } else { 264 | all_nv[r2a] = 0; 265 | } 266 | } 267 | 268 | for (int ld = 1; ld <= log2_dim; ld++) { 269 | 270 | for (int r2sub = 0; r2sub <= r2; r2sub++) { 271 | uint64_t nv = 0; 272 | for (int r2a = 0; r2a <= r2sub; r2a++) { 273 | int r2b = r2sub - r2a; 274 | set_nv_cum(ld, r2sub, r2a, nv); 275 | nv += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b); 276 | } 277 | all_nv[ld * (r2 + 1) + r2sub] = nv; 278 | } 279 | } 280 | nv = get_nv(log2_dim, r2); 281 | 282 | uint64_t nvx = nv; 283 | code_size = 0; 284 | while (nvx > 0) { 285 | nvx >>= 8; 286 | code_size++; 287 | } 288 | 289 | int cache_level = std::min(3, log2_dim - 1); 290 | decode_cache_ld = 0; 291 | assert(cache_level <= log2_dim); 292 | decode_cache.resize((r2 + 1)); 293 | 294 | for (int r2sub = 0; r2sub <= r2; r2sub++) { 295 | int ld = cache_level; 296 | uint64_t nvi = get_nv(ld, r2sub); 297 | std::vector &cache = decode_cache[r2sub]; 298 | int dimsub = (1 << cache_level); 299 | cache.resize (nvi * dimsub); 300 | float c[dim]; 301 | uint64_t code0 = get_nv_cum(cache_level + 1, r2, 302 | r2 - r2sub); 303 | for (int i = 0; i < nvi; i++) { 304 | decode(i + code0, c); 305 | memcpy(&cache[i * dimsub], c + dim - dimsub, 306 | dimsub * sizeof(*c)); 307 | } 308 | } 309 | decode_cache_ld = cache_level; 310 | } 311 | 312 | uint64_t ZnSphereCodecRec::encode(const float *c) const 313 | { 314 | return encode_centroid(c); 315 | } 316 | 317 | 318 | 319 | uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const 320 | { 321 | uint64_t codes[dim]; 322 | int norm2s[dim]; 323 | for(int i = 0; i < dim; i++) { 324 | if (c[i] == 0) { 325 | codes[i] = 0; 326 | norm2s[i] = 0; 327 | } else { 328 | int r2i = int(c[i] * c[i]); 329 | norm2s[i] = r2i; 330 | codes[i] = c[i] >= 0 ? 0 : 1; 331 | } 332 | } 333 | int dim2 = dim / 2; 334 | for(int ld = 1; ld <= log2_dim; ld++) { 335 | for (int i = 0; i < dim2; i++) { 336 | int r2a = norm2s[2 * i]; 337 | int r2b = norm2s[2 * i + 1]; 338 | 339 | uint64_t code_a = codes[2 * i]; 340 | uint64_t code_b = codes[2 * i + 1]; 341 | 342 | codes[i] = 343 | get_nv_cum(ld, r2a + r2b, r2a) + 344 | code_a * get_nv(ld - 1, r2b) + 345 | code_b; 346 | norm2s[i] = r2a + r2b; 347 | } 348 | dim2 /= 2; 349 | } 350 | return codes[0]; 351 | } 352 | 353 | 354 | 355 | void ZnSphereCodecRec::decode(uint64_t code, float *c) const 356 | { 357 | uint64_t codes[dim]; 358 | int norm2s[dim]; 359 | codes[0] = code; 360 | norm2s[0] = r2; 361 | 362 | int dim2 = 1; 363 | for(int ld = log2_dim; ld > decode_cache_ld; ld--) { 364 | for (int i = dim2 - 1; i >= 0; i--) { 365 | int r2sub = norm2s[i]; 366 | int i0 = 0, i1 = r2sub + 1; 367 | uint64_t codei = codes[i]; 368 | const uint64_t *cum = 369 | &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)]; 370 | while (i1 > i0 + 1) { 371 | int imed = (i0 + i1) / 2; 372 | if (cum[imed] <= codei) 373 | i0 = imed; 374 | else 375 | i1 = imed; 376 | } 377 | int r2a = i0, r2b = r2sub - i0; 378 | codei -= cum[r2a]; 379 | norm2s[2 * i] = r2a; 380 | norm2s[2 * i + 1] = r2b; 381 | 382 | uint64_t code_a = codei / get_nv(ld - 1, r2b); 383 | uint64_t code_b = codei % get_nv(ld - 1, r2b); 384 | 385 | codes[2 * i] = code_a; 386 | codes[2 * i + 1] = code_b; 387 | 388 | } 389 | dim2 *= 2; 390 | } 391 | 392 | if (decode_cache_ld == 0) { 393 | for(int i = 0; i < dim; i++) { 394 | if (norm2s[i] == 0) { 395 | c[i] = 0; 396 | } else { 397 | float r = sqrt(norm2s[i]); 398 | assert(r * r == norm2s[i]); 399 | c[i] = codes[i] == 0 ? 
r : -r; 400 | } 401 | } 402 | } else { 403 | int subdim = 1 << decode_cache_ld; 404 | assert ((dim2 * subdim) == dim); 405 | 406 | for(int i = 0; i < dim2; i++) { 407 | 408 | const std::vector & cache = 409 | decode_cache[norm2s[i]]; 410 | assert(codes[i] < cache.size()); 411 | memcpy(c + i * subdim, 412 | &cache[codes[i] * subdim], 413 | sizeof(*c)* subdim); 414 | } 415 | } 416 | } 417 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/lattice_Zn.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD+Patents license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | #pragma once 9 | 10 | #include "lattice_utils.h" 11 | 12 | #include 13 | 14 | 15 | 16 | /** returns the nearest vertex in the sphere to a query. Retunrs only 17 | * the coordinates, not an id. 18 | * 19 | * Algorithm: all points are derived from a one atom vector up to a 20 | * permutation and sign changes. The search function finds the most 21 | * appropriate atom and transformation. 22 | */ 23 | struct ZnSphereSearch { 24 | int dimS, r2; 25 | int natom; 26 | 27 | /// size dim * ntatom 28 | std::vector voc; 29 | 30 | ZnSphereSearch(int dim, int r2); 31 | 32 | /// find nearest centroid 33 | float search(const float *x, float *c) const; 34 | 35 | /// full call. Requires externally-allocated temp space 36 | float search(const float *x, float *c, 37 | float *tmp, // size 2 *dim 38 | int *tmp_int, // size dim 39 | int *ibest_out = nullptr 40 | ) const; 41 | 42 | // multi-threaded 43 | void search_multi(int n, const float *x, 44 | float *c_out, 45 | float *dp_out); 46 | 47 | }; 48 | 49 | 50 | /*************************************************************************** 51 | * Support ids as well. 52 | * 53 | * Limitations: ids are limited to 64 bit 54 | ***************************************************************************/ 55 | 56 | 57 | struct VectorCodec { 58 | /// size of the collection 59 | uint64_t nv; 60 | int dim; 61 | 62 | VectorCodec(int dim): nv(0), dim(dim) {} 63 | 64 | /// encode a vector from a collection 65 | virtual uint64_t encode(const float *x) const = 0; 66 | 67 | /// decode it 68 | virtual void decode(uint64_t code, float *c) const = 0; 69 | 70 | // call encode on nc vectors 71 | void encode_multi (size_t nc, const float *c, 72 | uint64_t * codes) const; 73 | 74 | // call decode on nc codes 75 | void decode_multi (size_t nc, const uint64_t * codes, 76 | float *c) const; 77 | 78 | // find the nearest neighbor of each xq 79 | // (decodes and computes distances) 80 | void find_nn (size_t n, const uint64_t * codes, 81 | size_t nq, const float *xq, 82 | long *idx, float *dis); 83 | 84 | virtual ~VectorCodec() {} 85 | 86 | }; 87 | 88 | /** codec that can return ids for the encoded vectors 89 | * 90 | * uses the ZnSphereSearch to encode the vector by encoding the 91 | * permutation and signs. 
Depends on ZnSphereSearch because it uses 92 | * the atom numbers */ 93 | struct ZnSphereCodec: ZnSphereSearch, VectorCodec { 94 | 95 | struct CodeSegment:Repeats { 96 | CodeSegment(const Repeats & r): Repeats(r) {} 97 | uint64_t c0; // first code assigned to segment 98 | int signbits; 99 | }; 100 | 101 | std::vector code_segments; 102 | uint64_t nv; 103 | size_t code_size; 104 | 105 | ZnSphereCodec(int dim, int r2); 106 | 107 | uint64_t search_and_encode(const float *x) const; 108 | 109 | void decode(uint64_t code, float *c) const override; 110 | 111 | /// takes vectors that do not need to be centroids 112 | uint64_t encode(const float *x) const override; 113 | 114 | virtual ~ZnSphereCodec() {} 115 | }; 116 | 117 | /** recursive sphere codec 118 | * 119 | * Uses a recursive decomposition on the dimensions to encode 120 | * centroids found by the ZnSphereSearch. The codes are *not* 121 | * compatible with the ones of ZnSpehreCodec 122 | */ 123 | struct ZnSphereCodecRec: VectorCodec { 124 | 125 | int r2; 126 | 127 | int log2_dim; 128 | int code_size; 129 | 130 | ZnSphereCodecRec(int dim, int r2); 131 | 132 | 133 | uint64_t encode_centroid(const float *c) const; 134 | 135 | void decode(uint64_t code, float *c) const override; 136 | 137 | /// vectors need to be centroids 138 | uint64_t encode(const float *x) const override; 139 | 140 | std::vector all_nv; 141 | std::vector all_nv_cum; 142 | 143 | int decode_cache_ld; 144 | std::vector > decode_cache; 145 | 146 | // nb of vectors in the sphere in dim 2^ld with r2 radius 147 | uint64_t get_nv(int ld, int r2a) const; 148 | 149 | // cumulative version 150 | uint64_t get_nv_cum(int ld, int r2t, int r2a) const; 151 | void set_nv_cum(int ld, int r2t, int r2a, uint64_t v); 152 | 153 | virtual ~ZnSphereCodecRec() {} 154 | 155 | }; 156 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/lattice_utils.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD+Patents license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | #include "lattice_utils.h" 9 | 10 | #include 11 | 12 | #include 13 | 14 | #include 15 | 16 | 17 | // compute combinations of n integer values <= v that sum up to total (squared) 18 | point_list_t sum_of_sq (float total, int v, int n, float add) { 19 | if (total < 0) { 20 | return point_list_t(); 21 | } else if (n == 1) { 22 | while (sqr(v + add) > total) v--; 23 | if (sqr(v + add) == total) { 24 | return point_list_t(1, v + add); 25 | } else { 26 | return point_list_t(); 27 | } 28 | } else { 29 | point_list_t res; 30 | while (v >= 0) { 31 | point_list_t sub_points = sum_of_sq (total - sqr(v + add), v, n - 1, add); 32 | for (size_t i = 0; i < sub_points.size(); i += n - 1) { 33 | res.push_back (v + add); 34 | for (int j = 0; j < n - 1; j++) 35 | res.push_back(sub_points[i + j]); 36 | } 37 | v--; 38 | } 39 | return res; 40 | } 41 | } 42 | 43 | 44 | 45 | Comb::Comb(int nmax):nmax(nmax) { 46 | tab.resize(nmax * nmax, 0); 47 | tab[0] = 1; 48 | for(int i = 1; i < nmax; i++) { 49 | tab[i * nmax] = 1; 50 | for(int j = 1; j <= i; j++) { 51 | tab[i * nmax + j] = 52 | tab[(i - 1) * nmax + j] + 53 | tab[(i - 1) * nmax + (j - 1)]; 54 | } 55 | 56 | } 57 | } 58 | 59 | Comb comb(100); 60 | 61 | 62 | 63 | 64 | Repeats::Repeats (int dim, const float *c): dim(dim) 65 | { 66 | for(int i = 0; i < dim; i++) { 67 | int j = 0; 68 | for(;;) { 69 | if (j == repeats.size()) { 70 | repeats.push_back(Repeat{c[i], 1}); 71 | break; 72 | } 73 | if (repeats[j].val == c[i]) { 74 | repeats[j].n++; 75 | break; 76 | } 77 | j++; 78 | } 79 | } 80 | } 81 | 82 | 83 | long Repeats::count () const 84 | { 85 | long accu = 1; 86 | int remain = dim; 87 | for (int i = 0; i < repeats.size(); i++) { 88 | accu *= comb(remain, repeats[i].n); 89 | remain -= repeats[i].n; 90 | } 91 | return accu; 92 | } 93 | 94 | // optimized version for < 64 bits 95 | static long repeats_encode_64 ( 96 | const std::vector & repeats, 97 | int dim, const float *c) 98 | { 99 | uint64_t coded = 0; 100 | int nfree = dim; 101 | uint64_t code = 0, shift = 1; 102 | for (auto r = repeats.begin(); r != repeats.end(); ++r) { 103 | int rank = 0, occ = 0; 104 | uint64_t code_comb = 0; 105 | uint64_t tosee = ~coded; 106 | for(;;) { 107 | // directly jump to next available slot. 
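// tosee keeps one set bit per slot that has not been assigned yet;
// __builtin_ctzl returns the index of its lowest set bit, i.e. the next
// free slot, so each iteration skips over all already-coded positions.
// The loop terminates because occ < r->n guarantees that a slot holding
// r->val is still unassigned, so tosee is never 0 when ctzl is called.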
108 | int i = __builtin_ctzl(tosee); 109 | tosee &= ~(1UL << i) ; 110 | if (c[i] == r->val) { 111 | code_comb += comb(rank, occ + 1); 112 | occ++; 113 | coded |= 1UL << i; 114 | if (occ == r->n) break; 115 | } 116 | rank++; 117 | } 118 | uint64_t max_comb = comb(nfree, r->n); 119 | code += shift * code_comb; 120 | shift *= max_comb; 121 | nfree -= r->n; 122 | } 123 | return code; 124 | } 125 | 126 | 127 | 128 | // version with a bool vector that works for > 64 dim 129 | long Repeats::encode(const float *c) const 130 | { 131 | if (dim < 64) 132 | return repeats_encode_64 (repeats, dim, c); 133 | std::vector coded(dim, false); 134 | int nfree = dim; 135 | uint64_t code = 0, shift = 1; 136 | for (auto r = repeats.begin(); r != repeats.end(); ++r) { 137 | int rank = 0, occ = 0; 138 | uint64_t code_comb = 0; 139 | for (int i = 0; i < dim; i++) { 140 | if (!coded[i]) { 141 | if (c[i] == r->val) { 142 | code_comb += comb(rank, occ + 1); 143 | occ++; 144 | coded[i] = true; 145 | if (occ == r->n) break; 146 | } 147 | rank++; 148 | } 149 | } 150 | uint64_t max_comb = comb(nfree, r->n); 151 | code += shift * code_comb; 152 | shift *= max_comb; 153 | nfree -= r->n; 154 | } 155 | return code; 156 | } 157 | 158 | 159 | 160 | 161 | 162 | static int decode_comb_1 (uint64_t *n, int k1, int r) { 163 | while (comb(r, k1) > *n) 164 | r--; 165 | *n -= comb(r, k1); 166 | return r; 167 | } 168 | 169 | 170 | static void repeats_decode_64( 171 | const std::vector & repeats, 172 | int dim, uint64_t code, float *c) 173 | { 174 | uint64_t decoded = 0; 175 | int nfree = dim; 176 | for (auto r = repeats.begin(); r != repeats.end(); ++r) { 177 | uint64_t max_comb = comb(nfree, r->n); 178 | uint64_t code_comb = code % max_comb; 179 | code /= max_comb; 180 | 181 | int occ = 0; 182 | int rank = nfree; 183 | int next_rank = decode_comb_1 (&code_comb, r->n, rank); 184 | uint64_t tosee = ((1UL << dim) - 1) ^ decoded; 185 | for(;;) { 186 | int i = 63 - __builtin_clzl(tosee); 187 | tosee &= ~(1UL << i); 188 | rank--; 189 | if (rank == next_rank) { 190 | decoded |= 1UL << i; 191 | c[i] = r->val; 192 | occ++; 193 | if (occ == r->n) break; 194 | next_rank = decode_comb_1 ( 195 | &code_comb, r->n - occ, next_rank); 196 | } 197 | } 198 | nfree -= r->n; 199 | } 200 | 201 | } 202 | 203 | 204 | 205 | void Repeats::decode(uint64_t code, float *c) const 206 | { 207 | if (dim < 64) { 208 | repeats_decode_64 (repeats, dim, code, c); 209 | return; 210 | } 211 | 212 | std::vector decoded(dim, false); 213 | int nfree = dim; 214 | for (auto r = repeats.begin(); r != repeats.end(); ++r) { 215 | uint64_t max_comb = comb(nfree, r->n); 216 | uint64_t code_comb = code % max_comb; 217 | code /= max_comb; 218 | 219 | int occ = 0; 220 | int rank = nfree; 221 | int next_rank = decode_comb_1 (&code_comb, r->n, rank); 222 | for (int i = dim - 1; i >= 0; i--) { 223 | if (!decoded[i]) { 224 | rank--; 225 | if (rank == next_rank) { 226 | decoded[i] = true; 227 | c[i] = r->val; 228 | occ++; 229 | if (occ == r->n) break; 230 | next_rank = decode_comb_1 ( 231 | &code_comb, r->n - occ, next_rank); 232 | } 233 | } 234 | } 235 | nfree -= r->n; 236 | } 237 | 238 | } 239 | 240 | 241 | 242 | // reads 0 <= d < 4 floats as __m128 243 | static inline __m128 masked_read (int d, const float *x) 244 | { 245 | assert (0 <= d && d < 4); 246 | __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0}; 247 | switch (d) { 248 | case 3: 249 | buf[2] = x[2]; 250 | case 2: 251 | buf[1] = x[1]; 252 | case 1: 253 | buf[0] = x[0]; 254 | } 255 | return _mm_load_ps (buf); 256 | // 
cannot use AVX2 _mm_mask_set1_epi32 257 | } 258 | 259 | float fvec_inner_product (const float * x, 260 | const float * y, 261 | size_t d) 262 | { 263 | __m256 msum1 = _mm256_setzero_ps(); 264 | 265 | while (d >= 8) { 266 | __m256 mx = _mm256_loadu_ps (x); x += 8; 267 | __m256 my = _mm256_loadu_ps (y); y += 8; 268 | msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my)); 269 | d -= 8; 270 | } 271 | 272 | __m128 msum2 = _mm256_extractf128_ps(msum1, 1); 273 | msum2 += _mm256_extractf128_ps(msum1, 0); 274 | 275 | if (d >= 4) { 276 | __m128 mx = _mm_loadu_ps (x); x += 4; 277 | __m128 my = _mm_loadu_ps (y); y += 4; 278 | msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my)); 279 | d -= 4; 280 | } 281 | 282 | if (d > 0) { 283 | __m128 mx = masked_read (d, x); 284 | __m128 my = masked_read (d, y); 285 | msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my)); 286 | } 287 | 288 | msum2 = _mm_hadd_ps (msum2, msum2); 289 | msum2 = _mm_hadd_ps (msum2, msum2); 290 | return _mm_cvtss_f32 (msum2); 291 | } 292 | 293 | float fvec_L2sqr (const float * x, 294 | const float * y, 295 | size_t d) 296 | { 297 | __m256 msum1 = _mm256_setzero_ps(); 298 | 299 | while (d >= 8) { 300 | __m256 mx = _mm256_loadu_ps (x); x += 8; 301 | __m256 my = _mm256_loadu_ps (y); y += 8; 302 | const __m256 a_m_b1 = mx - my; 303 | msum1 += a_m_b1 * a_m_b1; 304 | d -= 8; 305 | } 306 | 307 | __m128 msum2 = _mm256_extractf128_ps(msum1, 1); 308 | msum2 += _mm256_extractf128_ps(msum1, 0); 309 | 310 | if (d >= 4) { 311 | __m128 mx = _mm_loadu_ps (x); x += 4; 312 | __m128 my = _mm_loadu_ps (y); y += 4; 313 | const __m128 a_m_b1 = mx - my; 314 | msum2 += a_m_b1 * a_m_b1; 315 | d -= 4; 316 | } 317 | 318 | if (d > 0) { 319 | __m128 mx = masked_read (d, x); 320 | __m128 my = masked_read (d, y); 321 | __m128 a_m_b1 = mx - my; 322 | msum2 += a_m_b1 * a_m_b1; 323 | } 324 | 325 | msum2 = _mm_hadd_ps (msum2, msum2); 326 | msum2 = _mm_hadd_ps (msum2, msum2); 327 | return _mm_cvtss_f32 (msum2); 328 | } 329 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/lattice_utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD+Patents license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | #pragma once 9 | 10 | #include <vector> 11 | #include <stdint.h> 12 | #include <stddef.h> 13 | 14 | 15 | inline float sqr(float x) { 16 | return x * x; 17 | } 18 | 19 | inline int popcount64(uint64_t x) { 20 | return __builtin_popcountl(x); 21 | } 22 | 23 | typedef std::vector<float> point_list_t; 24 | 25 | /** compute combinations of n integer values <= v that sum up to total 26 | * (squared).
27 | * 28 | * if ret is the returned point_list_t, the number of vectors is 29 | * ret.size() / n 30 | */ 31 | point_list_t sum_of_sq (float total, int v, int n, float add=0); 32 | 33 | 34 | inline float dotprod(int n, const float *a, const float *b) { 35 | float accu = 0; 36 | for(int i = 0; i < n; i++) 37 | accu += a[i] * b[i]; 38 | return accu; 39 | } 40 | 41 | struct Comb { 42 | std::vector<uint64_t> tab; // Pascal's triangle 43 | int nmax; 44 | Comb(int nmax); 45 | uint64_t operator()(int n, int p) const { 46 | if (p > n) return 0; 47 | return tab[n * nmax + p]; 48 | } 49 | }; 50 | 51 | // initialized with nmax=100 52 | extern Comb comb; 53 | 54 | struct Repeat { 55 | float val; 56 | int n; 57 | }; 58 | 59 | 60 | /** Repeats: used to encode a vector that has n occurrences of 61 | * val. Encodes the signs and permutation of the vector. Useful for 62 | * atoms. 63 | */ 64 | struct Repeats { 65 | int dim; 66 | std::vector<Repeat> repeats; 67 | 68 | // initialize from a template of the atom. 69 | Repeats(int dim = 0, const float *c = nullptr); 70 | 71 | // count number of possible codes for this atom 72 | long count() const; 73 | 74 | long encode(const float *c) const; 75 | 76 | void decode(uint64_t code, float *c) const; 77 | }; 78 | 79 | 80 | // optimized inner product 81 | float fvec_inner_product (const float * x, 82 | const float * y, 83 | size_t d); 84 | -------------------------------------------------------------------------------- /library/spreadingvectors/lattices/test_Zn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree.
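The `Repeats::encode`/`Repeats::decode` pair in lattice_Zn.cpp above ranks the slots occupied by each repeated value with a combinatorial number system: a value occurring n times among nfree still-free slots maps to an integer in [0, comb(nfree, n)), and the per-value codes are combined in a mixed-radix expansion via the running `shift`. As a reference for the tests below, here is a minimal Python sketch of the same ranking; the helper name `rank_multiset` and the toy check are illustrative only, not part of this repository:

```python
# Pure-Python sketch of the ranking performed by Repeats::encode above:
# scan the free slots in ascending order and accumulate comb(rank, occ + 1)
# for each slot holding the current value, then combine codes mixed-radix.
from math import comb

def rank_multiset(vec):
    values = sorted(set(vec))
    free = list(range(len(vec)))      # slot indices not yet encoded
    code, shift = 0, 1
    for v in values:
        n = sum(1 for i in free if vec[i] == v)
        code_comb, occ, rank = 0, 0, 0
        remaining = []
        for i in free:                # rank counts free slots scanned so far
            if vec[i] == v:
                code_comb += comb(rank, occ + 1)
                occ += 1
            else:
                remaining.append(i)
            rank += 1
        code += shift * code_comb     # mixed-radix combination per value
        shift *= comb(len(free), n)
        free = remaining
    return code

# The 4!/2! = 12 distinct arrangements of (0, 0, 1, 2) get distinct codes in [0, 12).
from itertools import permutations
codes = {rank_multiset(p) for p in set(permutations((0.0, 0.0, 1.0, 2.0)))}
assert codes == set(range(12))
```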
6 | # 7 | import sys 8 | import numpy as np 9 | import pdb 10 | from lattices import c_lattices 11 | import unittest 12 | from lattices import Zn_lattice 13 | 14 | class TestZnCodec(unittest.TestCase): 15 | 16 | def test_codec(self): 17 | self.do_test(32, 14) 18 | 19 | def test_codec_rec(self): 20 | self.do_test(24, 79) 21 | 22 | def do_test(self, dim, r2): 23 | codec = Zn_lattice.ZnCodec(dim, r2) 24 | # print("nb atoms", codec.natom) 25 | rs = np.random.RandomState(123) 26 | 27 | n = 2000 28 | x = rs.randn(n, dim).astype('float32') 29 | x /= np.sqrt((x ** 2).sum(1)).reshape(-1, 1) 30 | quant = codec.quantize(x) 31 | 32 | codes = codec.encode(x) 33 | x_decoded = codec.decode(codes) 34 | 35 | assert np.all(x_decoded == quant) 36 | 37 | codec2 = Zn_lattice.ZnCodecPy(dim, r2) 38 | 39 | quant2 = codec2.quantize(x) 40 | assert np.all(quant == quant2) 41 | 42 | ##################################################################### 43 | # Low-level tests 44 | ##################################################################### 45 | 46 | 47 | swig_ptr = c_lattices.swig_ptr 48 | 49 | 50 | class BasicTest(unittest.TestCase): 51 | 52 | def test_comb(self): 53 | assert c_lattices.cvar.comb(2, 1) == 2 54 | 55 | def test_repeats(self): 56 | rs = np.random.RandomState(123) 57 | dim = 32 58 | for i in range(1000): 59 | vec = np.floor((rs.rand(dim) ** 7) * 3).astype('float32') 60 | vecs = vec.copy(); vecs.sort() 61 | repeats = c_lattices.Repeats(dim, swig_ptr(vecs)) 62 | rr = [repeats.repeats.at(i) for i in range(repeats.repeats.size())] 63 | # print([(r.val, r.n) for r in rr]) 64 | code = repeats.encode(swig_ptr(vec)) 65 | #print(vec, code) 66 | vec2 = np.zeros(dim, dtype='float32') 67 | repeats.decode(code, swig_ptr(vec2)) 68 | # print(vec2) 69 | assert np.all(vec == vec2) 70 | 71 | 72 | class TestZnSphereCodec(unittest.TestCase): 73 | 74 | def test_codec(self): 75 | 76 | dim = 32 77 | r2 = 14 78 | codec = c_lattices.ZnSphereCodec(dim, r2) 79 | # print("nb atoms", codec.natom) 80 | rs = np.random.RandomState(123) 81 | for i in range(1000): 82 | x = rs.randn(dim).astype('float32') 83 | ref_c = np.zeros(dim, dtype='float32') 84 | codec.search(swig_ptr(x), swig_ptr(ref_c)) 85 | code = codec.search_and_encode(swig_ptr(x)) 86 | # print(x, code) 87 | c = np.zeros(dim, dtype='float32') 88 | codec.decode(code, swig_ptr(c)) 89 | # print(ref_c, c) 90 | 91 | 92 | class TestZnSphereCodecRec(unittest.TestCase): 93 | 94 | def test_encode_centroid(self): 95 | dim = 8 96 | r2 = 5 97 | ref_codec = c_lattices.ZnSphereCodec(dim, r2) 98 | codec = c_lattices.ZnSphereCodecRec(dim, r2) 99 | # print(ref_codec.nv, codec.nv) 100 | assert ref_codec.nv == codec.nv 101 | s = set() 102 | for i in range(ref_codec.nv): 103 | c = np.zeros(dim, dtype='float32') 104 | ref_codec.decode(i, swig_ptr(c)) 105 | code = codec.encode_centroid(swig_ptr(c)) 106 | assert 0 <= code < codec.nv 107 | s.add(code) 108 | assert len(s) == codec.nv 109 | 110 | def test_codec(self): 111 | dim = 16 112 | r2 = 6 113 | codec = c_lattices.ZnSphereCodecRec(dim, r2) 114 | # print("nv=", codec.nv) 115 | for i in range(codec.nv): 116 | c = np.zeros(dim, dtype='float32') 117 | codec.decode(i, swig_ptr(c)) 118 | code = codec.encode_centroid(swig_ptr(c)) 119 | assert code == i 120 | 121 | 122 | 123 | if __name__ == '__main__': 124 | unittest.main() 125 | -------------------------------------------------------------------------------- /library/spreadingvectors/reproduce.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 
2 | # 3 | # Copyright (c) 2015-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the CC-by-NC license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | set -e 10 | 11 | if [ ! -d ckpts/ ]; then 12 | wget http://dl.fbaipublicfiles.com/spreadingvectors/ckpt.zip 13 | unzip ckpt.zip 14 | fi 15 | 16 | echo "reproduce Catalyzer+sign in table 2" 17 | for f in ckpts/binary/{bigann,deep1b}/ckpt_{16,32,64,128}.pth; do 18 | echo $f 19 | python eval.py --quantizer binary --ckpt-path $f 20 | done 21 | 22 | echo "reproduce Catalyzer+Lattice (+end2end) in table 1" 23 | for f in ckpts/lattice/**/*.pth; do 24 | echo $f 25 | python eval.py --quantizer zn_79 --ckpt-path $f 26 | done 27 | 28 | echo "reproduce Catalyzer+OPQ in table 1" 29 | for f in ckpts/lattice/**/ckpt.pth; do 30 | echo $f 31 | python eval.py --quantizer opq_64 --ckpt-path $f 32 | done 33 | -------------------------------------------------------------------------------- /library/spreadingvectors/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | blas=1.0=mkl 5 | ca-certificates=2018.12.5=0 6 | certifi=2018.11.29=py37_0 7 | cffi=1.11.5=py37he75722e_1 8 | intel-openmp=2019.1=144 9 | libedit=3.1.20181209=hc058e9b_0 10 | libffi=3.2.1=hd88cf55_4 11 | libgcc-ng=8.2.0=hdf63c60_1 12 | libgfortran-ng=7.3.0=hdf63c60_0 13 | libstdcxx-ng=8.2.0=hdf63c60_1 14 | mkl=2019.1=144 15 | mkl_fft=1.0.10=py37ha843d7b_0 16 | mkl_random=1.0.2=py37hd81dba3_0 17 | ncurses=6.1=he6710b0_1 18 | ninja=1.8.2=py37h6bb024c_1 19 | numpy=1.15.4=py37h7e9f1db_0 20 | numpy-base=1.15.4=py37hde5b4d6_0 21 | openssl=1.1.1a=h7b6447c_0 22 | pip=18.1=py37_0 23 | pycparser=2.19=py37_0 24 | python=3.7.2=h0371630_0 25 | pytorch-cpu=1.0.0=py3.7_cpu_1 26 | readline=7.0=h7b6447c_5 27 | setuptools=40.6.3=py37_0 28 | sqlite=3.26.0=h7b6447c_0 29 | tk=8.6.8=hbc83047_0 30 | wheel=0.32.3=py37_0 31 | xz=5.2.4=h14c3975_4 32 | zlib=1.2.11=h7b6447c_3 33 | -------------------------------------------------------------------------------- /library/spreadingvectors/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import division 8 | from lib.data import load_dataset 9 | import time 10 | import argparse 11 | import numpy as np 12 | from torch import nn, optim 13 | from lib.metrics import ValidationFunction, ValidationFunction_k, get_nearestneighbors, sanitize 14 | from lib.net import Normalize, forward_pass, StraightThroughQuantizer 15 | from lib.quantizers import Zn 16 | import torch.nn.functional as F 17 | import torch 18 | import itertools 19 | import pickle as pkl 20 | 21 | 22 | def repeat(l, r): 23 | return list(itertools.chain.from_iterable(itertools.repeat(x, r) for x in l)) 24 | 25 | 26 | def pairwise_NNs_inner(x): 27 | """ 28 | Pairwise nearest neighbors for L2-normalized vectors. 29 | Uses Torch rather than Faiss to remain on GPU.
30 | """ 31 | # parwise dot products (= inverse distance) 32 | dots = torch.mm(x, x.t()) 33 | n = x.shape[0] 34 | dots.view(-1)[::(n+1)].fill_(-1) # Trick to fill diagonal with -1 35 | _, I = torch.max(dots, 1) # max inner prod -> min distance 36 | return I 37 | 38 | 39 | def triplet_optimize(xt, gt_nn, net, args, val_func): 40 | """ 41 | train a triplet loss on the training set xt (a numpy array) 42 | gt_nn: ground-truth nearest neighbors in input space 43 | net: network to optimize 44 | args: various runtime arguments 45 | val_func: callback called periodically to evaluate the network 46 | """ 47 | 48 | lr_schedule = [float(x.rstrip().lstrip()) for x in args.lr_schedule.split(",")] 49 | assert args.epochs % len(lr_schedule) == 0 50 | lr_schedule = repeat(lr_schedule, args.epochs // len(lr_schedule)) 51 | print("Lr schedule", lr_schedule) 52 | 53 | N, kpos = gt_nn.shape 54 | 55 | if args.quantizer_train != "": 56 | assert args.quantizer_train.startswith("zn_") 57 | r2 = int(args.quantizer_train.split("_")[1]) 58 | qt = StraightThroughQuantizer(Zn(r2)) 59 | else: 60 | qt = lambda x: x 61 | 62 | xt_var = torch.from_numpy(xt).to(args.device) 63 | 64 | # prepare optimizer 65 | optimizer = optim.SGD(net.parameters(), lr_schedule[0], momentum=args.momentum) 66 | pdist = nn.PairwiseDistance(2) 67 | 68 | all_logs = [] 69 | for epoch in range(args.epochs): 70 | # Update learning rate 71 | args.lr = lr_schedule[epoch] 72 | for param_group in optimizer.param_groups: 73 | param_group['lr'] = args.lr 74 | 75 | t0 = time.time() 76 | 77 | # Sample positives for triplet 78 | rank_pos = np.random.choice(kpos, size=N) 79 | positive_idx = gt_nn[np.arange(N), rank_pos] 80 | 81 | # Sample negatives for triplet 82 | net.eval() 83 | print(" Forward pass") 84 | xl_net = forward_pass(net, xt, 1024) 85 | print(" Distances") 86 | I = get_nearestneighbors(xl_net, qt(xl_net), args.rank_negative, args.device, needs_exact=False) 87 | negative_idx = I[:, -1] 88 | 89 | # training pass 90 | print(" Train") 91 | net.train() 92 | avg_triplet, avg_uniform, avg_loss = 0, 0, 0 93 | offending = idx_batch = 0 94 | 95 | # process dataset in a random order 96 | perm = np.random.permutation(N) 97 | 98 | t1 = time.time() 99 | 100 | for i0 in range(0, N, args.batch_size): 101 | i1 = min(i0 + args.batch_size, N) 102 | n = i1 - i0 103 | 104 | data_idx = perm[i0:i1] 105 | 106 | # anchor, positives, negatives 107 | ins = xt_var[data_idx] 108 | pos = xt_var[positive_idx[data_idx]] 109 | neg = xt_var[negative_idx[data_idx]] 110 | 111 | # do the forward pass (+ record gradients) 112 | ins, pos, neg = net(ins), net(pos), net(neg) 113 | pos, neg = qt(pos), qt(neg) 114 | 115 | # triplet loss 116 | per_point_loss = pdist(ins, pos) - pdist(ins, neg) 117 | per_point_loss = F.relu(per_point_loss) 118 | loss_triplet = per_point_loss.mean() 119 | offending += torch.sum(per_point_loss.data > 0).item() 120 | 121 | # entropy loss 122 | I = pairwise_NNs_inner(ins.data) 123 | distances = pdist(ins, ins[I]) 124 | loss_uniform = - torch.log(n * distances).mean() 125 | 126 | # combined loss 127 | loss = loss_triplet + args.lambda_uniform * loss_uniform 128 | 129 | # collect some stats 130 | avg_triplet += loss_triplet.data.item() 131 | avg_uniform += loss_uniform.data.item() 132 | avg_loss += loss.data.item() 133 | 134 | optimizer.zero_grad() 135 | loss.backward() 136 | optimizer.step() 137 | 138 | idx_batch += 1 139 | 140 | avg_triplet /= idx_batch 141 | avg_uniform /= idx_batch 142 | avg_loss /= idx_batch 143 | 144 | logs = { 145 | 'epoch': epoch, 146 | 
'loss_triplet': avg_triplet, 147 | 'loss_uniform': avg_uniform, 148 | 'loss': avg_loss, 149 | 'offending': offending, 150 | 'lr': args.lr 151 | } 152 | all_logs.append(logs) 153 | 154 | t2 = time.time() 155 | # maybe perform a validation run 156 | if (epoch + 1) % args.val_freq == 0: 157 | logs['val'] = val_func(net, epoch, args, all_logs) 158 | 159 | t3 = time.time() 160 | 161 | # synthetic logging 162 | print ('epoch %d, times: [hn %.2f s epoch %.2f s val %.2f s]' 163 | ' lr = %f' 164 | ' loss = %g = %g + lam * %g, offending %d' % ( 165 | epoch, t1 - t0, t2 - t1, t3 - t2, 166 | args.lr, 167 | avg_loss, avg_triplet, avg_uniform, offending 168 | )) 169 | 170 | logs['times'] = (t1 - t0, t2 - t1, t3 - t2) 171 | 172 | return all_logs 173 | 174 | def quantize_and_get_hist(x, qnt): 175 | # assume first half rows of x are class 1 and the rest are class 2 176 | q = qnt.quantize(x) 177 | n = x.shape[0] // 2 178 | 179 | inv = np.unique(q, return_inverse=True, axis=0)[1] 180 | 181 | n_cluster = np.unique(inv).shape[0] 182 | print('support size of quantized mutlinomial =', n_cluster) 183 | 184 | inv1 = inv[:n] 185 | inv2 = inv[n:] 186 | cl_ids_1, counts_1 = np.unique(inv1, return_counts=True) 187 | cl_ids_2, counts_2 = np.unique(inv2, return_counts=True) 188 | counts_dict_1 = dict(zip(cl_ids_1, counts_1)) 189 | counts_dict_2 = dict(zip(cl_ids_2, counts_2)) 190 | counts_1 = np.asarray([counts_dict_1.get(i, 0) for i in range(n_cluster)]) 191 | counts_2 = np.asarray([counts_dict_2.get(i, 0) for i in range(n_cluster)]) 192 | return counts_1/counts_1.sum(), counts_2/counts_2.sum() 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser() 197 | 198 | def aa(*args, **kwargs): 199 | group.add_argument(*args, **kwargs) 200 | 201 | group = parser.add_argument_group('dataset options') 202 | aa("--database", default="deep1b") # can be "bigann", "deep1b" or "*.fvecs" 203 | aa("--size_base", type=int, default=int(1e6), 204 | help="size of evaluation dataset") 205 | aa("--num_learn", type=int, default=int(5e5), 206 | help="nb of learning vectors") 207 | 208 | group = parser.add_argument_group('Model hyperparameters') 209 | aa("--dint", type=int, default=1024, 210 | help="size of hidden states") 211 | aa("--dout", type=int, default=16, 212 | help="output dimension") 213 | aa("--lambda_uniform", type=float, default=0.05, 214 | help="weight of the uniformity loss") 215 | 216 | group = parser.add_argument_group('Training hyperparameters') 217 | aa("--batch_size", type=int, default=64) 218 | aa("--epochs", type=int, default=160) 219 | aa("--momentum", type=float, default=0.9) 220 | aa("--rank_positive", type=int, default=10, 221 | help="this number of vectors are considered positives") 222 | aa("--rank_negative", type=int, default=50, 223 | help="these are considered negatives") 224 | 225 | group = parser.add_argument_group('Computation params') 226 | aa("--seed", type=int, default=1234) 227 | aa("--checkpoint_dir", type=str, default="", 228 | help="checkpoint directory") 229 | aa("--init_name", type=str, default="", 230 | help="checkpoint to load from") 231 | aa("--save_best_criterion", type=str, default="", 232 | help="for example r2=4,rank=10") 233 | aa("--quantizer_train", type=str, default="") 234 | aa("--lr_schedule", type=str, default="0.1,0.1,0.05,0.01") 235 | aa("--device", choices=["cuda", "cpu", "auto"], default="auto") 236 | aa("--val_freq", type=int, default=10, 237 | help="frequency of validation calls") 238 | aa("--validation_quantizers", type=str, default="", 239 | help="r2 
values to try in validation") 240 | 241 | args = parser.parse_args() 242 | 243 | if args.device == "auto": 244 | args.device = "cuda" if torch.cuda.is_available() else "cpu" 245 | 246 | np.random.seed(args.seed) 247 | torch.manual_seed(args.seed) 248 | 249 | # Radiuses that correspond to 16, 32 and 64 bits for Zn 250 | radiuses = { 251 | 4: [20, 25, 30, 50], 252 | 8: [2, 3, 4, 5], 253 | 16: [4, 21, 200], 254 | 24: [3, 10, 79], 255 | 32: [3, 8, 36], 256 | 40: [2, 7, 24], 257 | } 258 | # Validation quantizers default to Zn 259 | if args.validation_quantizers == "": 260 | args.validation_quantizers = ["zn_%d" % x for x in radiuses[args.dout]] 261 | else: 262 | args.validation_quantizers = [x.rstrip().lstrip() for x in args.validation_quantizers.split(",")] 263 | # Default save_best is 64 bits for Zn 264 | if args.save_best_criterion == "": 265 | args.save_best_criterion = "zn_%d,rank=10" % radiuses[args.dout][-1] 266 | print(args) 267 | 268 | print ("load dataset %s" % args.database) 269 | (xt, xb, xq, gt) = load_dataset(args.database, args.device, size=args.size_base, test=False) 270 | 271 | print ("keeping %d/%d training vectors" % (args.num_learn, xt.shape[0])) 272 | xt = sanitize(xt[:args.num_learn]) 273 | 274 | print ("computing training ground truth") 275 | xt_gt = get_nearestneighbors(xt, xt, args.rank_positive, device=args.device) 276 | 277 | print ("build network") 278 | 279 | dim = xb.shape[1] 280 | dint, dout = args.dint, args.dout 281 | 282 | net = nn.Sequential( 283 | nn.Linear(in_features=dim, out_features=dint, bias=True), 284 | nn.BatchNorm1d(dint), 285 | nn.ReLU(), 286 | nn.Linear(in_features=dint, out_features=dint, bias=True), 287 | nn.BatchNorm1d(dint), 288 | nn.ReLU(), 289 | nn.Linear(in_features=dint, out_features=dout, bias=True), 290 | Normalize() 291 | ) 292 | 293 | if args.init_name != '': 294 | print ("loading state from %s" % args.init_name) 295 | ckpt = torch.load(args.init_name) 296 | net.load_state_dict(ckpt['state_dict']) 297 | start_epoch = ckpt['epoch'] 298 | 299 | net.to(args.device) 300 | 301 | val = ValidationFunction_k(xq, xb, gt, args.checkpoint_dir, 302 | validation_key=args.save_best_criterion, 303 | quantizers=args.validation_quantizers) 304 | 305 | all_logs = triplet_optimize(xt, xt_gt, net, args, val) 306 | xt_torch = torch.from_numpy(xt).to(args.device) 307 | with torch.no_grad(): 308 | feats = net(xt_torch) 309 | feats = feats.cpu().numpy() 310 | print(f'feats shape: {feats.shape}') 311 | 312 | quant = Zn(r2=50, d=feats.shape[1]) 313 | hist1, hist2 = quantize_and_get_hist(feats, quant) 314 | print('hist1:', hist1.tolist()) 315 | print('hist2:', hist2.tolist()) 316 | 317 | 318 | -------------------------------------------------------------------------------- /library/spreadingvectors/train_spv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from __future__ import division 8 | from lib.data import load_dataset 9 | import time 10 | import argparse 11 | import numpy as np 12 | from torch import nn, optim 13 | from lib.metrics import ValidationFunction, ValidationFunction_k, get_nearestneighbors, sanitize 14 | from lib.net import Normalize, forward_pass, StraightThroughQuantizer 15 | from lib.quantizers import Zn 16 | import torch.nn.functional as F 17 | import torch 18 | import itertools 19 | import pickle as pkl 20 | 21 | 22 | def repeat(l, r): 23 | return list(itertools.chain.from_iterable(itertools.repeat(x, r) for x in l)) 24 | 25 | 26 | def pairwise_NNs_inner(x): 27 | """ 28 | Pairwise nearest neighbors for L2-normalized vectors. 29 | Uses Torch rather than Faiss to remain on GPU. 30 | """ 31 | # pairwise dot products (= inverse distance) 32 | dots = torch.mm(x, x.t()) 33 | n = x.shape[0] 34 | dots.view(-1)[::(n+1)].fill_(-1) # Trick to fill diagonal with -1 35 | _, I = torch.max(dots, 1) # max inner prod -> min distance 36 | return I 37 | 38 | 39 | def triplet_optimize(xt, gt_nn, net, args, val_func): 40 | """ 41 | train a triplet loss on the training set xt (a numpy array) 42 | gt_nn: ground-truth nearest neighbors in input space 43 | net: network to optimize 44 | args: various runtime arguments 45 | val_func: callback called periodically to evaluate the network 46 | """ 47 | 48 | lr_schedule = [float(x.rstrip().lstrip()) for x in args.lr_schedule.split(",")] 49 | assert args.epochs % len(lr_schedule) == 0 50 | lr_schedule = repeat(lr_schedule, args.epochs // len(lr_schedule)) 51 | print("Lr schedule", lr_schedule) 52 | 53 | N, kpos = gt_nn.shape 54 | 55 | if args.quantizer_train != "": 56 | assert args.quantizer_train.startswith("zn_") 57 | r2 = int(args.quantizer_train.split("_")[1]) 58 | qt = StraightThroughQuantizer(Zn(r2)) 59 | else: 60 | qt = lambda x: x 61 | 62 | xt_var = torch.from_numpy(xt).to(args.device) 63 | 64 | # prepare optimizer 65 | optimizer = optim.SGD(net.parameters(), lr_schedule[0], momentum=args.momentum) 66 | pdist = nn.PairwiseDistance(2) 67 | 68 | all_logs = [] 69 | for epoch in range(args.epochs): 70 | # Update learning rate 71 | args.lr = lr_schedule[epoch] 72 | for param_group in optimizer.param_groups: 73 | param_group['lr'] = args.lr 74 | 75 | t0 = time.time() 76 | 77 | # Sample positives for triplet 78 | rank_pos = np.random.choice(kpos, size=N) 79 | positive_idx = gt_nn[np.arange(N), rank_pos] 80 | 81 | # Sample negatives for triplet 82 | net.eval() 83 | print(" Forward pass") 84 | xl_net = forward_pass(net, xt, 1024) 85 | print(" Distances") 86 | I = get_nearestneighbors(xl_net, qt(xl_net), args.rank_negative, args.device, needs_exact=False) 87 | negative_idx = I[:, -1] 88 | 89 | # training pass 90 | print(" Train") 91 | net.train() 92 | avg_triplet, avg_uniform, avg_loss = 0, 0, 0 93 | offending = idx_batch = 0 94 | 95 | # process dataset in a random order 96 | perm = np.random.permutation(N) 97 | 98 | t1 = time.time() 99 | 100 | for i0 in range(0, N, args.batch_size): 101 | i1 = min(i0 + args.batch_size, N) 102 | n = i1 - i0 103 | 104 | data_idx = perm[i0:i1] 105 | 106 | # anchor, positives, negatives 107 | ins = xt_var[data_idx] 108 | pos = xt_var[positive_idx[data_idx]] 109 | neg = xt_var[negative_idx[data_idx]] 110 | 111 | # do the forward pass (+ record gradients) 112 | ins, pos, neg = net(ins), net(pos), net(neg) 113 | pos, neg = qt(pos), qt(neg) 114 | 115 | # triplet loss 116 | per_point_loss = pdist(ins, pos) - pdist(ins, neg) 117 | per_point_loss =
F.relu(per_point_loss) 118 | loss_triplet = per_point_loss.mean() 119 | offending += torch.sum(per_point_loss.data > 0).item() 120 | 121 | # entropy loss 122 | I = pairwise_NNs_inner(ins.data) 123 | distances = pdist(ins, ins[I]) 124 | loss_uniform = - torch.log(n * distances).mean() 125 | 126 | # combined loss 127 | loss = loss_triplet + args.lambda_uniform * loss_uniform 128 | 129 | # collect some stats 130 | avg_triplet += loss_triplet.data.item() 131 | avg_uniform += loss_uniform.data.item() 132 | avg_loss += loss.data.item() 133 | 134 | optimizer.zero_grad() 135 | loss.backward() 136 | optimizer.step() 137 | 138 | idx_batch += 1 139 | 140 | avg_triplet /= idx_batch 141 | avg_uniform /= idx_batch 142 | avg_loss /= idx_batch 143 | 144 | logs = { 145 | 'epoch': epoch, 146 | 'loss_triplet': avg_triplet, 147 | 'loss_uniform': avg_uniform, 148 | 'loss': avg_loss, 149 | 'offending': offending, 150 | 'lr': args.lr 151 | } 152 | all_logs.append(logs) 153 | 154 | t2 = time.time() 155 | # maybe perform a validation run 156 | if (epoch + 1) % args.val_freq == 0: 157 | logs['val'] = val_func(net, epoch, args, all_logs) 158 | 159 | t3 = time.time() 160 | 161 | # synthetic logging 162 | print ('epoch %d, times: [hn %.2f s epoch %.2f s val %.2f s]' 163 | ' lr = %f' 164 | ' loss = %g = %g + lam * %g, offending %d' % ( 165 | epoch, t1 - t0, t2 - t1, t3 - t2, 166 | args.lr, 167 | avg_loss, avg_triplet, avg_uniform, offending 168 | )) 169 | 170 | logs['times'] = (t1 - t0, t2 - t1, t3 - t2) 171 | 172 | return all_logs 173 | 174 | def quantize_and_get_hist(x, qnt): 175 | # assume first half rows of x are class 1 and the rest are class 2 176 | q = qnt.quantize(x) 177 | n = x.shape[0] // 2 178 | 179 | inv = np.unique(q, return_inverse=True, axis=0)[1] 180 | 181 | n_cluster = np.unique(inv).shape[0] 182 | print('support size of quantized multinomial =', n_cluster) 183 | 184 | inv1 = inv[:n] 185 | inv2 = inv[n:] 186 | cl_ids_1, counts_1 = np.unique(inv1, return_counts=True) 187 | cl_ids_2, counts_2 = np.unique(inv2, return_counts=True) 188 | counts_dict_1 = dict(zip(cl_ids_1, counts_1)) 189 | counts_dict_2 = dict(zip(cl_ids_2, counts_2)) 190 | counts_1 = np.asarray([counts_dict_1.get(i, 0) for i in range(n_cluster)]) 191 | counts_2 = np.asarray([counts_dict_2.get(i, 0) for i in range(n_cluster)]) 192 | return counts_1/counts_1.sum(), counts_2/counts_2.sum() 193 | 194 | def process_torch_feats(device, data_lst): 195 | x_lst = [] 196 | for x in data_lst: 197 | x_lst.append(x) 198 | xt = np.concatenate(x_lst) 199 | xb = xt 200 | xq = xt 201 | gt = get_nearestneighbors(xq, xb, 100, device) 202 | return xt, xb, xq, gt 203 | 204 | def train_spv_and_quantize(p_feats, q_feats, epochs=160, lambda_uniform=1.0, 205 | dint=768, dout=4, device=torch.device('cpu'), 206 | num_learn=100000, seed=25041993): 207 | device = 'cpu' if device == torch.device('cpu') else 'cuda' 208 | args = argparse.Namespace( 209 | epochs=epochs, lambda_uniform=lambda_uniform, 210 | dint=dint, dout=dout, device=device, batch_size=64, 211 | rank_positive=10, rank_negative=50, seed=seed, 212 | num_learn=num_learn, checkpoint_dir=None, quantizer_train="", 213 | lr_schedule="0.1,0.1,0.05,0.01", momentum=0.9, val_freq=10, 214 | validation_quantizers="" 215 | ) 216 | np.random.seed(args.seed) 217 | torch.manual_seed(args.seed) 218 | 219 | # Radiuses that correspond to 16, 32 and 64 bits for Zn 220 | radiuses = { 221 | 4: [20, 25, 30, 50], 222 | 8: [2, 3, 4, 5], 223 | 16: [4, 21, 200], 224 | 24: [3, 10, 79], 225 | 32: [3, 8, 36], 226 | 40: [2, 7, 24], 227
| } 228 | # Validation quantizers default to Zn 229 | args.validation_quantizers = ["zn_%d" % x for x in radiuses[args.dout]] 230 | # Default save_best is 64 bits for Zn 231 | args.save_best_criterion = "zn_%d,rank=10" % radiuses[args.dout][-1] 232 | print('args to spreadingvectors:', args) 233 | 234 | (xt, xb, xq, gt) = process_torch_feats(device, [p_feats, q_feats]) 235 | 236 | print ("keeping %d/%d training vectors" % (args.num_learn, xt.shape[0])) 237 | xt = sanitize(xt[:args.num_learn]) 238 | 239 | print ("computing training ground truth") 240 | xt_gt = get_nearestneighbors(xt, xt, args.rank_positive, device=args.device) 241 | 242 | print ("build network") 243 | 244 | dim = xb.shape[1] 245 | # dint, dout = args.dint, args.dout 246 | 247 | net = nn.Sequential( 248 | nn.Linear(in_features=dim, out_features=dint, bias=True), 249 | nn.BatchNorm1d(dint), 250 | nn.ReLU(), 251 | nn.Linear(in_features=dint, out_features=dint, bias=True), 252 | nn.BatchNorm1d(dint), 253 | nn.ReLU(), 254 | nn.Linear(in_features=dint, out_features=dout, bias=True), 255 | Normalize() 256 | ) 257 | 258 | net.to(args.device) 259 | 260 | val = ValidationFunction_k(xq, xb, gt, args.checkpoint_dir, 261 | validation_key=args.save_best_criterion, 262 | quantizers=args.validation_quantizers) 263 | 264 | all_logs = triplet_optimize(xt, xt_gt, net, args, val) 265 | xt_torch = torch.from_numpy(xt).to(args.device) 266 | with torch.no_grad(): 267 | feats = net(xt_torch) 268 | feats = feats.cpu().numpy() 269 | print(f'feats shape: {feats.shape}') 270 | 271 | quant = Zn(r2=50, d=feats.shape[1]) 272 | hist1, hist2 = quantize_and_get_hist(feats, quant) 273 | return hist1, hist2 274 | 275 | 276 | -------------------------------------------------------------------------------- /local_scripts/download_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download human written text and text generated by GPT-2. 3 | The data is hosted by OpenAI. This script is adapted from this repository: 4 | https://github.com/openai/gpt-2-output-dataset 5 | This script might need to be updated in the future if the data hosting service changes. 6 | Future changes must track this script: 7 | https://github.com/openai/gpt-2-output-dataset/blob/master/download_dataset.py 8 | """ 9 | import os 10 | import requests 11 | from tqdm import tqdm 12 | 13 | DATA_URL = "https://openaipublic.azureedge.net/gpt-2/output-dataset/v1/" 14 | subdir = 'data' 15 | 16 | if __name__ == '__main__': 17 | if not os.path.exists(subdir): 18 | os.makedirs(subdir) 19 | subdir = subdir.replace('\\','/') # needed for Windows 20 | 21 | for ds in ['webtext']: 22 | for split in ['test']: 23 | filename = ds + "." 
+ split + '.jsonl' 24 | r = requests.get(DATA_URL + filename, stream=True) 25 | 26 | with open(os.path.join(subdir, filename), 'wb') as f: 27 | file_size = int(r.headers["content-length"]) 28 | chunk_size = 1000 29 | with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar: 30 | # 1k for chunk_size, since Ethernet packet size is around 1500 bytes 31 | for chunk in r.iter_content(chunk_size=chunk_size): 32 | f.write(chunk) 33 | pbar.update(chunk_size) 34 | -------------------------------------------------------------------------------- /local_scripts/make_output_dirs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | for outer_dir in "outputs/webtext_gpt2" "outputs/webtext_gpt2-large" "outputs/webtext_gpt2-medium" "outputs/webtext_gpt2-xl" 5 | do 6 | for dir in outs save generations metrics 7 | do 8 | mkdir -p ${outer_dir}/${dir} 9 | mkdir -p ${outer_dir}/${dir}/ref 10 | mkdir -p ${outer_dir}/${dir}/basic 11 | done 12 | done 13 | 14 | 15 | -------------------------------------------------------------------------------- /local_scripts/parallelize.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Original author: deajan 3 | # Code below was obtained from https://stackoverflow.com/a/39189370 4 | 5 | function f_log { 6 | echo "$1" 7 | } 8 | 9 | # Takes a list of commands to run and executes them, with at most numberOfProcesses commands running simultaneously 10 | # Returns the number of non-zero exit codes from commands 11 | function f_ParallelExec { 12 | local numberOfProcesses="${1}" # Number of simultaneous commands to run 13 | local commandsArg="${2}" # Semi-colon separated list of commands 14 | 15 | local pid 16 | local runningPids=0 17 | local counter=0 18 | local commandsArray 19 | local pidsArray 20 | local newPidsArray 21 | local retval 22 | local retvalAll=0 23 | local pidState 24 | local commandsArrayPid 25 | 26 | IFS=';' read -r -a commandsArray <<< "$commandsArg" 27 | 28 | f_log "Running ${#commandsArray[@]} commands in $numberOfProcesses simultaneous processes." 29 | 30 | while [ $counter -lt "${#commandsArray[@]}" ] || [ ${#pidsArray[@]} -gt 0 ]; do 31 | 32 | while [ $counter -lt "${#commandsArray[@]}" ] && [ ${#pidsArray[@]} -lt $numberOfProcesses ]; do 33 | f_log "Running command [${commandsArray[$counter]}]." 34 | eval "${commandsArray[$counter]}" & 35 | pid=$! 36 | pidsArray+=($pid) 37 | commandsArrayPid[$pid]="${commandsArray[$counter]}" 38 | counter=$((counter+1)) 39 | done 40 | 41 | 42 | newPidsArray=() 43 | for pid in "${pidsArray[@]}"; do 44 | # Handle uninterruptible sleep state or zombies by omitting them from the running process array (how to kill what is already dead? :) 45 | if kill -0 $pid > /dev/null 2>&1; then 46 | pidState=$(ps -p$pid -o state= 2>/dev/null) 47 | if [ "$pidState" != "D" ] && [ "$pidState" != "Z" ]; then 48 | newPidsArray+=($pid) 49 | fi 50 | else 51 | # pid is dead, get its exit code from the wait command 52 | wait $pid 53 | retval=$? 54 | if [ $retval -ne 0 ]; then 55 | f_log "Command [${commandsArrayPid[$pid]}] failed with exit code [$retval]."
56 | retvalAll=$((retvalAll+1)) 57 | fi 58 | fi 59 | done 60 | pidsArray=("${newPidsArray[@]}") 61 | 62 | # Add a trivial sleep time so bash won't eat all CPU 63 | sleep .05 64 | done 65 | 66 | return $retvalAll 67 | } 68 | -------------------------------------------------------------------------------- /local_scripts/webtext/mauve_metrics_drmm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # TODO: set `njobs` (number of jobs to run at once) and `data_dir` 4 | 5 | source local_scripts/parallelize.sh 6 | njobs=32 # TODO 7 | cmds="" 8 | 9 | set -u # x: stack trace 10 | export MKL_NUM_THREADS=1 11 | export NUMEXPR_NUM_THREADS=1 12 | export OMP_NUM_THREADS=1 13 | 14 | export CUDA_VISIBLE_DEVICES="" 15 | export DISABLE_TQDM=True 16 | 17 | # options 18 | datasplit="test" 19 | dataset="webtext" 20 | model_name="gpt2-large" 21 | max_len=${1} 22 | 23 | # Default args 24 | if [ ${dataset} == "webtext" ]; then 25 | data_dir="./data" #TODO 26 | else 27 | data_dir="UNKNOWN dataset ${dataset}" 28 | exit 100 29 | fi 30 | args=" --data_dir ${data_dir} --model_name ${model_name} " 31 | 32 | for generate_seed in 0 1 2 3 4 33 | do 34 | for discretization in "drmm" 35 | do 36 | 37 | options=" ${args} --datasplit ${datasplit} --discretization ${discretization} --device -1" 38 | options="${options} --drmm_num_epochs 20 --drmm_n_layer 3 --drmm_n_component_per_layer 10" 39 | options="${options} --generate_seed ${generate_seed} --seed 1234" 40 | options="${options} --use_large_feats --max_len ${max_len}" 41 | 42 | ################## 43 | # basic 44 | ################## 45 | # nucleus 46 | for p in 0.8 0.9 0.92 0.95 0.99 47 | do 48 | cmds="$cmds ; time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_p ${p} > outs/basic/mauve_${discretization}_p_${p}_${generate_seed} 2>&1 " 49 | done 50 | 51 | # top-k 52 | for k in 1 5 10 50 100 500 1000 2000 5000 53 | do 54 | cmds="$cmds ; time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_k ${k} > outs/basic/mauve_${discretization}_k_${k}_${generate_seed} 2>&1 " 55 | done 56 | 57 | # temperature 58 | for t in 0.7 0.8 0.9 0.95 1.0 59 | do 60 | cmds="$cmds ; time python -u compute_mauve_metrics.py ${options} --generation_type basic --temp ${t} > outs/basic/mauve_${discretization}_t_${t}_${generate_seed} 2>&1 " 61 | done 62 | 63 | 64 | # top-k + temp 65 | for t in 0.75 0.9 66 | do 67 | for k in 10 100 68 | do 69 | cmds="$cmds ; time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_k ${k} --temp ${t} > outs/basic/mauve_${discretization}_k_${k}_t_${t}_${generate_seed} 2>&1 " 70 | done 71 | done 72 | 73 | 74 | 75 | 76 | done # discretization 77 | done # seed 78 | 79 | ############ DONE ########### 80 | 81 | echo "executing..."
82 | date 83 | set +u # for parallel exec to work (unbound variables) 84 | f_ParallelExec $njobs "$cmds" 85 | date 86 | -------------------------------------------------------------------------------- /local_scripts/webtext/mauve_metrics_kmeans.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # TODO: set `njobs` (number of jobs to run at once) and `data_dir` 4 | 5 | source local_scripts/parallelize.sh 6 | njobs=24 # TODO 7 | cmds="" 8 | 9 | 10 | set -u # x: stack trace 11 | export MKL_NUM_THREADS=1 12 | export NUMEXPR_NUM_THREADS=1 13 | export OMP_NUM_THREADS=1 14 | 15 | export CUDA_VISIBLE_DEVICES="" 16 | export DISABLE_TQDM=True 17 | 18 | # options 19 | datasplit="test" 20 | dataset="webtext" 21 | #model_name="gpt2-large" 22 | #max_len=${1} 23 | 24 | # Default args 25 | if [ ${dataset} == "webtext" ]; then 26 | data_dir="./data" # TODO 27 | else 28 | data_dir="UNKNOWN dataset ${dataset}" 29 | exit 100 30 | fi 31 | 32 | discretization="kmeans_l2" 33 | kmeans_num_clusters=500 34 | 35 | for max_len in 1024 512 256 128 36 | do 37 | for generate_seed in 0 1 2 3 4 38 | do 39 | for model_name in "gpt2" "gpt2-medium" "gpt2-large" "gpt2-xl" 40 | do 41 | 42 | args=" --data_dir ${data_dir} --model_name ${model_name} " 43 | sn="${discretization}_${kmeans_num_clusters}_${model_name}_${max_len}" 44 | 45 | options="${args} --datasplit ${datasplit} --discretization ${discretization} --device -1" 46 | options="${options} --kmeans_num_clusters ${kmeans_num_clusters}" 47 | options="${options} --generate_seed ${generate_seed} --seed 1234" 48 | options="${options} --use_large_feats --max_len ${max_len} --kmeans_explained_var 0.9" 49 | 50 | ################## 51 | # basic 52 | ################## 53 | # nucleus 54 | for p in 0.9 0.92 0.95 0.99 55 | do 56 | cmds="$cmds ; time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_p ${p} > outs/basic/mauve_${sn}_p_${p}_seed${generate_seed} 2>&1 " 57 | done 58 | 59 | # top-k 60 | for k in 1 61 | do 62 | cmds="$cmds ; time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_k ${k} > outs/basic/mauve_${sn}_k_${k}_seed${generate_seed} 2>&1 " 63 | done 64 | 65 | # temperature 66 | for t in 1.0 67 | do 68 | cmds="$cmds ; time python -u compute_mauve_metrics.py ${options} --generation_type basic --temp ${t} > outs/basic/mauve_${sn}_t_${t}_seed${generate_seed} 2>&1 " 69 | done 70 | 71 | done # model_name 72 | done # seed 73 | done # length 74 | 75 | ############ DONE ########### 76 | 77 | echo "executing..."
78 | date 79 | set +u # for parallel exec to work (unbound variables) 80 | f_ParallelExec $njobs "$cmds" 81 | date 82 | -------------------------------------------------------------------------------- /local_scripts/webtext/mauve_metrics_spv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # TODO: set `data_dir` and `device` 4 | njobs=1 5 | cmds="" 6 | 7 | 8 | set -exu # x: stack trace 9 | export MKL_NUM_THREADS=1 10 | export NUMEXPR_NUM_THREADS=1 11 | export OMP_NUM_THREADS=1 12 | 13 | export DISABLE_TQDM=True 14 | 15 | # options 16 | datasplit="test" 17 | dataset="webtext" 18 | model_name="gpt2-large" 19 | 20 | # Default args 21 | if [ ${dataset} == "webtext" ]; then 22 | data_dir="./data" # TODO 23 | else 24 | data_dir="UNKNOWN dataset ${dataset}" 25 | exit 100 26 | fi 27 | args=" --data_dir ${data_dir} --model_name ${model_name} " 28 | 29 | for generate_seed in 0 1 2 3 4 30 | do 31 | discretization="spv" 32 | device=0 # TODO 33 | options="${args} --datasplit ${datasplit} --discretization ${discretization} --device ${device}" 34 | options="${options} --spv_num_epochs 200 " 35 | options="${options} --generate_seed ${generate_seed} --seed 1234" 36 | options="${options} --use_large_feats" 37 | 38 | ################## 39 | # basic 40 | ################## 41 | # nucleus 42 | for p in 0.8 0.9 0.92 0.95 0.99 43 | do 44 | time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_p ${p} > outs/basic/mauve_${discretization}_p_${p} 2>&1 45 | done 46 | 47 | # top-k 48 | for k in 1 5 10 50 100 500 1000 2000 5000 49 | do 50 | time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_k ${k} > outs/basic/mauve_${discretization}_k_${k} 2>&1 51 | done 52 | 53 | # temperature 54 | for t in 0.7 0.8 0.9 0.95 1.0 55 | do 56 | time python -u compute_mauve_metrics.py ${options} --generation_type basic --temp ${t} > outs/basic/mauve_${discretization}_t_${t} 2>&1 57 | done 58 | 59 | 60 | # top-k + temp 61 | for t in 0.75 0.9 62 | do 63 | for k in 10 100 64 | do 65 | time python -u compute_mauve_metrics.py ${options} --generation_type basic --top_k ${k} --temp ${t} > outs/basic/mauve_${discretization}_k_${k}_t_${t} 2>&1 66 | done 67 | done 68 | 69 | 70 | 71 | done # seed 72 | 73 | date 74 | -------------------------------------------------------------------------------- /local_scripts/webtext/run_all_L_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # TODO: set `data_dir` and `device` 5 | 6 | set -exu # x: stack trace 7 | export MKL_NUM_THREADS=1 8 | export NUMEXPR_NUM_THREADS=1 9 | export OMP_NUM_THREADS=1 10 | 11 | export DISABLE_TQDM=True 12 | 13 | # options 14 | datasplit="test" 15 | 16 | dataset="webtext" 17 | model_size=${1} 18 | model_name="gpt2-${model_size}" 19 | 20 | device=${2} 21 | 22 | if [ ${model_name} == "gpt2-small" ]; then 23 | model_name="gpt2" 24 | fi 25 | 26 | 27 | # Default args 28 | if [ ${dataset} == "webtext" ]; then 29 | data_dir="./data" # TODO 30 | else 31 | data_dir="UNKNOWN dataset ${dataset}" 32 | exit 100 33 | fi 34 | args=" --data_dir ${data_dir} --model_name ${model_name} --device ${device} " 35 | 36 | for generate_seed in 0 1 2 3 4 37 | do 38 | 39 | options="--datasplit ${datasplit} ${args}" 40 | options="${options} --generate_seed ${generate_seed} --seed 1234" 41 | 42 | ################## 43 | # basic 44 | ################## 45 | # nucleus 46 | for p in 0.9 0.92 0.95 0.99 1.0 47 | do 48 
| time python -u compute_all_L_metrics.py ${options} --generation_type basic --top_p ${p} > outs/basic/all_p_${p}_${generate_seed}_${model_size} 2>&1 49 | done 50 | 51 | # top-k 52 | for k in 1 53 | do 54 | time python -u compute_all_L_metrics.py ${options} --generation_type basic --top_k ${k} > outs/basic/all_k_${k}_${generate_seed}_${model_size} 2>&1 55 | done 56 | 57 | done # seed 58 | 59 | ############ DONE ########### 60 | 61 | date 62 | -------------------------------------------------------------------------------- /local_scripts/webtext/run_lm_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -exu 3 | 4 | device=1 5 | data_dir="./data" # TODO 6 | datasplit="test" 7 | max_len=1024 8 | model_name="gpt2-large" 9 | 10 | echo "Starting at $(date)" 11 | 12 | time python -u compute_lm_metrics_basic.py \ 13 | --device ${device} \ 14 | --data_dir ${data_dir} \ 15 | --datasplit ${datasplit} \ 16 | --model_name ${model_name} \ 17 | --max_len ${max_len} \ 18 | > outs/lm_large 2>&1 19 | 20 | echo "Done at $(date)" 21 | -------------------------------------------------------------------------------- /local_scripts/webtext/run_self_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # TODO: set `njobs` (number of jobs to run at once) and `data_dir` 4 | 5 | source local_scripts/parallelize.sh 6 | njobs=28 # TODO 7 | cmds="" 8 | 9 | 10 | set -u # x: stack trace 11 | export MKL_NUM_THREADS=1 12 | export NUMEXPR_NUM_THREADS=1 13 | export OMP_NUM_THREADS=1 14 | 15 | export CUDA_VISIBLE_DEVICES="" 16 | export DISABLE_TQDM=True 17 | 18 | # options 19 | datasplit="test" 20 | dataset="webtext" 21 | model_name="gpt2-large" 22 | 23 | # Default args 24 | if [ ${dataset} == "webtext" ]; then 25 | data_dir="./data" # TODO 26 | else 27 | data_dir="UNKNOWN dataset ${dataset}" 28 | exit 100 29 | fi 30 | args=" --data_dir ${data_dir} --model_name ${model_name} " 31 | 32 | for generate_seed in 0 1 2 3 4 33 | do 34 | 35 | options="${args} --datasplit ${datasplit} --device -1" 36 | options="${options} --generate_seed ${generate_seed} --seed 1234" 37 | 38 | ################## 39 | # basic 40 | ################## 41 | # nucleus 42 | for p in 0.8 0.9 0.92 0.95 0.99 43 | do 44 | cmds="$cmds ; time python -u compute_self_bleu_metric.py ${options} --generation_type basic --top_p ${p} > outs/basic/bleu_p_${p}_${generate_seed} 2>&1 " 45 | done 46 | 47 | # top-k 48 | for k in 1 5 10 50 100 500 1000 2000 5000 10000 49 | do 50 | cmds="$cmds ; time python -u compute_self_bleu_metric.py ${options} --generation_type basic --top_k ${k} > outs/basic/bleu_k_${k}_${generate_seed} 2>&1 " 51 | done 52 | 53 | # temperature 54 | for t in 0.7 0.8 0.9 0.95 1.0 55 | do 56 | cmds="$cmds ; time python -u compute_self_bleu_metric.py ${options} --generation_type basic --temp ${t} > outs/basic/bleu_t_${t}_${generate_seed} 2>&1 " 57 | done 58 | 59 | 60 | # top-k + temp 61 | for t in 0.75 0.9 62 | do 63 | for k in 10 100 64 | do 65 | cmds="$cmds ; time python -u compute_self_bleu_metric.py ${options} --generation_type basic --top_k ${k} --temp ${t} > outs/basic/bleu_k_${k}_t_${t}_${generate_seed} 2>&1 " 66 | done 67 | done 68 | 69 | ################## 70 | # beam 71 | ################## 72 | for bs in 4 8 73 | do 74 | for t in 1.0 0.9 75 | do 76 | for nr in 0 4 77 | do 78 | 79 | fn=outs/beam/bleu_b${bs}_t${t}_n${nr}_${generate_seed} 80 | cmds="$cmds ; time python -u compute_self_bleu_metric.py ${options}
--generation_type beam --beam_size ${bs} --temp ${t} --no_repeat_ngram ${nr} > ${fn} 2>&1 " 81 | 82 | done # nr 83 | done # temp 84 | done # bs 85 | 86 | ################## 87 | # entmax 88 | ################## 89 | for alpha in "1.2" 90 | do 91 | fn=outs/entmax/bleu_entmax${alpha}_${generate_seed} 92 | cmds="$cmds ; time python -u compute_self_bleu_metric.py ${options} --generation_type entmax --entmax_alpha ${alpha} > ${fn} 2>&1 " 93 | done 94 | 95 | 96 | done # seed 97 | 98 | ############ DONE ########### 99 | 100 | echo "executing..." 101 | date 102 | set +u # for parallel exec to work (unbound variables) 103 | f_ParallelExec $njobs "$cmds" 104 | date -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.4.5 2 | transformers==4.2.0 3 | scikit-learn==0.22.1 4 | faiss-gpu==1.7.0 5 | tqdm==4.40.0 6 | requests 7 | -------------------------------------------------------------------------------- /slurm_scripts/webtext/arr_generate_basic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## To change: `data_dir` (line 43) and the output directory. Pass in the model size as an argument. 4 | 5 | #SBATCH --job-name=gen_basic 6 | #SBATCH --comment="Generate all baselines" 7 | #SBATCH --array=0-29%4 8 | #SBATCH --output=TODO 9 | #SBATCH --nodes=1 10 | #SBATCH --cpus-per-task=4 11 | #SBATCH --mem=40G 12 | #SBATCH --gres=gpu:1 13 | #SBATCH --time=12:00:00 14 | #SBATCH --open-mode=append 15 | #SBATCH --mail-type=ALL 16 | 17 | 18 | # Initialize conda into the right environment + modules. 19 | source ~/.bashrc 20 | conda activate pyt17 # cuda 10.1 21 | #conda activate pyt14_tf1 # cuda 10.0 22 | export DISABLE_TQDM=True 23 | 24 | echo "Running [ ${0} ${@} ] on $(hostname), starting at $(date)" 25 | echo "Job id = ${SLURM_JOB_ID}, task id = ${SLURM_ARRAY_TASK_ID}" 26 | echo "PWD = $(pwd)" 27 | 28 | set -exu 29 | 30 | 31 | model_size=$1 # pass model size as argument 32 | prompt_size=35 33 | 34 | dataset="webtext" 35 | model_name="gpt2-${model_size}" 36 | 37 | if [ ${model_name} == "gpt2-small" ]; then 38 | model_name="gpt2" 39 | fi 40 | 41 | # Default args 42 | if [ ${dataset} == "webtext" ]; then 43 | data_dir="./data" ###TODO 44 | else 45 | data_dir="UNKNOWN dataset ${dataset}" 46 | exit 100 47 | fi 48 | 49 | 50 | list_of_jobs=() 51 | 52 | for seed in 0 1 2 3 4 53 | do 54 | for datasplit in test 55 | do 56 | 57 | # nucleus 58 | for p in 0.9 0.92 0.95 0.99 59 | do 60 | k=0 61 | t=1 62 | job="--top_p ${p} --top_k ${k} --temp ${t} --seed ${seed}" 63 | list_of_jobs+=("${job}") 64 | done 65 | 66 | # top-k 67 | #for k in 1 5 10 50 100 500 1000 2000 5000 10000 68 | for k in 1 69 | do 70 | p=1 71 | t=1 72 | job="--top_p ${p} --top_k ${k} --temp ${t} --seed ${seed}" 73 | list_of_jobs+=("${job}") 74 | done 75 | 76 | # temperature 77 | #for t in 0.7 0.8 0.9 0.95 1.0 78 | for t in 1.0 79 | do 80 | p=1 81 | k=0 82 | job="--top_p ${p} --top_k ${k} --temp ${t} --seed ${seed}" 83 | list_of_jobs+=("${job}") 84 | done 85 | # 86 | ## top-k + temperature 87 | #for t in 0.75 0.9 88 | #do 89 | #for k in 10 100 90 | #do 91 | # p=1 92 | # job="--top_p ${p} --top_k ${k} --temp ${t} --seed ${seed}" 93 | # list_of_jobs+=("${job}") 94 | #done 95 | #done 96 | 97 | done # datasplit 98 | done # seed 99 | 100 | num_jobs=${#list_of_jobs[@]} 101 | 102 | job_id=${SLURM_ARRAY_TASK_ID} 103 | 104 | if [ ${job_id} -ge ${num_jobs} ] ; then 105 | echo "Invalid
job id; quitting" 106 | exit 2 107 | fi 108 | 109 | echo "-------- STARTING JOB ${job_id}/${num_jobs}" 110 | 111 | args=${list_of_jobs[${job_id}]} 112 | 113 | 114 | time python -u generate_basic.py ${args} \ 115 | --device 0 \ 116 | --datasplit ${datasplit} \ 117 | --data_dir ${data_dir} \ 118 | --model_name ${model_name} \ 119 | --prompt_size ${prompt_size} \ 120 | --use_large_feats 121 | 122 | echo "Job completed at $(date)" 123 | -------------------------------------------------------------------------------- /slurm_scripts/webtext/generate_ref.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## TODO: set output dirs (line 7) and data_dir (line 39) 4 | 5 | #SBATCH --job-name=gen_ref 6 | #SBATCH --comment="Generate all baselines" 7 | #SBATCH --output=#TODO 8 | #SBATCH --nodes=1 9 | #SBATCH --cpus-per-task=4 10 | #SBATCH --mem=40G 11 | #SBATCH --gres=gpu:1 12 | #SBATCH --time=2:00:00 13 | 14 | 15 | # Initialize conda into the right environment + modules. 16 | source ~/.bashrc 17 | conda activate pyt17 # cuda 10.1 18 | export DISABLE_TQDM=True 19 | 20 | echo "Running [ ${0} ${@} ] on $(hostname), starting at $(date)" 21 | echo "Job id = ${SLURM_JOB_ID}, task id = ${SLURM_ARRAY_TASK_ID}" 22 | echo "PWD = $(pwd)" 23 | 24 | set -exu 25 | 26 | # TODO: set dataset and model_name 27 | 28 | model_size="large" 29 | 30 | dataset="webtext" 31 | model_name="gpt2-${model_size}" 32 | 33 | if [ ${model_name} == "gpt2-small" ]; then 34 | model_name="gpt2" 35 | fi 36 | 37 | # Default args 38 | if [ ${dataset} == "webtext" ]; then 39 | data_dir="./data" #TODO 40 | else 41 | data_dir="UNKNOWN dataset ${dataset}" 42 | exit 100 43 | fi 44 | 45 | 46 | for datasplit in "test" "valid" 47 | do 48 | 49 | time python -u generate_ref.py \ 50 | --device 0 --seed 0 \ 51 | --datasplit ${datasplit} \ 52 | --data_dir ${data_dir} \ 53 | --model_name ${model_name} \ 54 | --use_large_feats 55 | 56 | done 57 | 58 | echo "Job completed at $(date)" 59 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/krishnap25/mauve-experiments/d1753fdf09396606defe5fa5749f4a7e8fe24c96/src/__init__.py -------------------------------------------------------------------------------- /src/mauve_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import time 4 | import sklearn.metrics 5 | 6 | import src.utils as utils 7 | 8 | try: 9 | import faiss 10 | from sklearn.preprocessing import normalize 11 | from sklearn.decomposition import PCA 12 | FOUND_FAISS = True 13 | except (ImportError, ModuleNotFoundError): 14 | print('faiss or sklearn not found', file=sys.stderr) 15 | FOUND_FAISS = False 16 | 17 | try: 18 | sys.path.append('library/spreadingvectors') 19 | from train_spv import train_spv_and_quantize 20 | FOUND_SPV = True 21 | except ImportError: 22 | print('SpreadingVectors not found', file=sys.stderr) 23 | FOUND_SPV = False 24 | 25 | try: 26 | sys.path.append('library') 27 | from DRMM import train_drmm_and_quantize 28 | FOUND_DRMM = True 29 | except (ImportError, ModuleNotFoundError): 30 | print('DRMM or TensorFlow not found', file=sys.stderr) 31 | FOUND_DRMM = False 32 | 33 | 34 | # PR metrics 35 | def compute_mauve_metrics(p_feats, q_feats, discretization_algo='kmeans_l1', 36 | kmeans_num_clusters=100, kmeans_explained_var=0.99, 37 |
device=utils.CPU_DEVICE, spv_num_epochs=160, 38 | drmm_num_epochs=4, drmm_n_layer=3, drmm_n_comp_per_layer=10, 39 | seed=25041993): 40 | """ 41 | p_feats, q_feats are torch.Tensor 42 | """ 43 | t1 = time.time() 44 | if discretization_algo == 'kmeans_l1': 45 | if not FOUND_FAISS: 46 | print('Faiss or sklearn not found. Exiting') 47 | sys.exit(-1) 48 | p, q = cluster_feats(p_feats.detach().cpu().numpy(), 49 | q_feats.detach().cpu().numpy(), 50 | num_clusters=kmeans_num_clusters, 51 | norm='l1', whiten=True, min_var=kmeans_explained_var, 52 | seed=seed) 53 | elif discretization_algo == 'kmeans_l2': 54 | if not FOUND_FAISS: 55 | print('Faiss or sklearn not found. Exiting') 56 | sys.exit(-1) 57 | p, q = cluster_feats(p_feats.detach().cpu().numpy(), 58 | q_feats.detach().cpu().numpy(), 59 | num_clusters=kmeans_num_clusters, 60 | norm='l2', whiten=False, min_var=kmeans_explained_var, 61 | seed=seed) 62 | elif discretization_algo in ['spv', 'spreadingvectors', 'lattice']: 63 | if not FOUND_SPV: 64 | print('SpreadingVectors not found. Exiting') 65 | sys.exit(-1) 66 | num_epochs = (spv_num_epochs // 4) * 4 # make it divisible by 4 67 | p, q = train_spv_and_quantize(p_feats, q_feats, 68 | device=device, 69 | epochs=num_epochs, seed=seed) 70 | # p, q: (744,) 71 | elif discretization_algo == 'drmm': 72 | if not FOUND_DRMM: 73 | print('DRMM or tensorflow not found. Exiting') 74 | sys.exit(-1) 75 | p, q = train_drmm_and_quantize( 76 | p_feats.detach().cpu().numpy(), q_feats.detach().cpu().numpy(), seed=seed, 77 | nEpoch=drmm_num_epochs, nComponentsPerLayer=drmm_n_comp_per_layer, nLayers=drmm_n_layer, 78 | ) 79 | # p, q: at most (drmm_n_comp_per_layer ** drmm_n_layer,) 80 | else: 81 | raise ValueError('Unknown discretization algo: ', discretization_algo) 82 | t2 = time.time() 83 | print('discretization time:', round(t2-t1, 2)) 84 | metrics = get_mauve_score(p, q) 85 | return p, q, metrics 86 | 87 | 88 | # PR metrics 89 | def get_discretization_algo_name( 90 | discretization_algo='kmeans_l1', kmeans_num_clusters=100, kmeans_explained_var=0.99, 91 | device=utils.CPU_DEVICE, spv_num_epochs=160, seed=25041993, 92 | drmm_num_epochs=4, drmm_n_layer=3, drmm_n_comp_per_layer=10 93 | ): 94 | assert 0 < kmeans_explained_var < 1 95 | kmeans_args = f'{kmeans_num_clusters}_{kmeans_explained_var}' if kmeans_explained_var != 0.99 else kmeans_num_clusters 96 | if discretization_algo == 'kmeans_l1': 97 | name = f'kmeans_l1_{kmeans_args}' 98 | elif discretization_algo == 'kmeans_l2': 99 | name = f'kmeans_l2_{kmeans_args}' 100 | elif discretization_algo in ['spv', 'spreadingvectors', 'lattice']: 101 | name = 'spv' 102 | elif discretization_algo == 'drmm': 103 | name = f'drmm_{drmm_n_layer}_{drmm_n_comp_per_layer}' 104 | else: 105 | raise ValueError('Unknown discretization algo: ', discretization_algo) 106 | return name 107 | 108 | ################## 109 | # Helper functions 110 | ################## 111 | def cluster_feats(p, q, num_clusters, 112 | norm='none', whiten=True, min_var=0.99, 113 | niter=500, seed=0): 114 | """ p, q are numpy arrays""" 115 | assert 0 < min_var < 1 116 | print(f'seed = {seed}') 117 | assert norm in ['none', 'l2', 'l1', None] 118 | data1 = np.vstack([q, p]) 119 | if norm in ['l2', 'l1']: 120 | data1 = normalize(data1, norm=norm, axis=1) 121 | pca = PCA(n_components=None, whiten=whiten, random_state=seed+1) 122 | pca.fit(data1) 123 | s = np.cumsum(pca.explained_variance_ratio_) 124 | idx = np.argmax(s >= min_var) # last index to consider 125 | print(f'lower dimensionality = {idx}') 126 | data1 = 
pca.transform(data1)[:, :idx+1] 127 | # Cluster 128 | data1 = data1.astype(np.float32) 129 | d = data1.shape[1] 130 | t1 = time.time() 131 | kmeans = faiss.Kmeans(data1.shape[1], num_clusters, niter=niter, verbose=True, 132 | nredo=5, update_index=True, seed=seed+2) 133 | kmeans.train(data1) 134 | _, labels = kmeans.index.search(data1, 1) 135 | labels = labels.reshape(-1) 136 | t2 = time.time() 137 | print('kmeans time:', round(t2-t1, 2)) 138 | 139 | q_labels = labels[:len(q)] 140 | p_labels = labels[len(q):] 141 | 142 | q_bins = np.histogram(q_labels, bins=num_clusters, 143 | range=[0, num_clusters], density=True)[0] 144 | p_bins = np.histogram(p_labels, bins=num_clusters, 145 | range=[0, num_clusters], density=True)[0] 146 | return p_bins / p_bins.sum(), q_bins / q_bins.sum() 147 | 148 | 149 | def kl_multinomial(p, q): 150 | assert p.shape == q.shape 151 | if np.logical_and(p != 0, q == 0).any(): 152 | return np.inf 153 | else: 154 | idxs = np.logical_and(p != 0, q != 0) 155 | return np.sum(p[idxs] * np.log(p[idxs] / q[idxs])) 156 | 157 | def get_mauve_score(p, q, mixture_weights=np.linspace(0, 1, 100), scaling_factor=5): 158 | angles = np.linspace(1e-6, np.pi / 2 - 1e-6, 25) 159 | mixture_weights = np.cos(angles) # on an angular grid 160 | divergence_curve = [[0, np.inf]] # extreme point 161 | for w in np.sort(mixture_weights): 162 | r = w * p + (1 - w) * q 163 | divergence_curve.append([kl_multinomial(q, r), kl_multinomial(p, r)]) 164 | divergence_curve.append([np.inf, 0]) # other extreme point 165 | divergence_curve = np.exp(-scaling_factor * np.asarray(divergence_curve)) 166 | mauve = sklearn.metrics.auc(divergence_curve[:, 0], divergence_curve[:, 1]) 167 | return mauve 168 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import operator 3 | from sklearn import linear_model 4 | import torch 5 | 6 | import src.model_utils 7 | from src.utils import tqdm 8 | from nltk.util import ngrams as ngrams_fn_nltk 9 | from collections import Counter 10 | 11 | import src.utils as utils 12 | 13 | ###################### 14 | # Sparsemax score 15 | ###################### 16 | def sp_score_1(p, sen): 17 | # logp: (b, seq_len, vocab_size) 18 | n = p.shape[1] # seq_len 19 | p = p[0, :-1, :] # (n - 1, vocab_size) 20 | labels = sen[0, 1:] # (n - 1) 21 | count = labels.shape[0] 22 | sp = p[torch.arange(n-1), labels] + 0.5 * (1 - torch.norm(p, dim=1)**2) # (n-1) 23 | return sp.sum().item(), count 24 | 25 | ###################### 26 | # Jensen-Shannon score 27 | ###################### 28 | 29 | def kl(p, q): 30 | idxs = (p != 0) 31 | return (p[idxs] * torch.log(p[idxs] / q[idxs])).sum() 32 | 33 | def js_score_1_naive(p, sen): 34 | # not numerically stable; direct implementation from the formula 35 | n = p.shape[1] # seq_len 36 | p = p[0, :-1, :] # (n - 1, vocab_size) 37 | labels = sen[0, 1:] # (n - 1) 38 | count = labels.shape[0] 39 | p_true = torch.zeros_like(p) 40 | p_true[torch.arange(n-1), labels] = 1.0 # one hot 41 | m = 0.5 * (p + p_true) 42 | js = 0.5 * kl(p, m) + 0.5 * kl(p_true, m) 43 | return js.item(), count 44 | 45 | def js_score_1(p, sen): 46 | # numerically stable version of JS score 47 | n = p.shape[1] # seq_len 48 | p = p[0, :-1, :] 49 | logp = torch.log(p) # (n - 1, vocab_size) 50 | labels = sen[0, 1:] # (n - 1) 51 | count = labels.shape[0] # (n-1) 52 | idxs = torch.arange(n-1) 53 | p_true = p[idxs, labels] # (n-1) 54 | logp_true = 
logp[idxs, labels] 55 | js1 = np.log(2) + torch.where(p_true > 0, 56 | p_true * (logp_true - torch.log1p(p_true)), 57 | torch.zeros_like(p_true) 58 | ) 59 | js2 = np.log(2) - torch.log1p(p[idxs, labels]) 60 | return (js1 + js2).sum().item() * 0.5, count 61 | 62 | ###################### 63 | # eps-perplexity score 64 | ###################### 65 | def eps_perplexity(p, sen, eps, vocab_size): 66 | n = p.shape[1] # seq_len 67 | gold_probs = p[0, torch.arange(n-1), sen[0, 1:]] 68 | return torch.log(gold_probs + eps) - np.log(1 + eps * vocab_size), n-1 69 | 70 | def eps_perplexity_lst(p, sen, eps_lst, vocab_size): 71 | n = p.shape[1] # seq_len 72 | gold_probs = p[0, torch.arange(n-1), sen[0, 1:]] 73 | ppl = (torch.log(gold_probs[None, :] + eps_lst[:, None]) 74 | - torch.log(1 + eps_lst[:, None] * vocab_size)).sum(dim=1) 75 | return ppl, n-1 76 | 77 | ####################################### 78 | # Repetition Statistics of Greedy Token 79 | ####################################### 80 | def rep_score_1(p, sen, hist_size): 81 | p = p[0, :-1, :] # (n-1, vocab_size) 82 | labels = sen[0, 1:] # (n-1) 83 | count = labels.shape[0] 84 | greedy = p.argmax(dim=1) # (n-1) 85 | reps = sum([1 for i in range(labels.shape[0]) 86 | if greedy[i] in labels[max(0, i-hist_size):i]]) 87 | return reps, count 88 | 89 | def wrep_score_1(p, sen, hist_size): 90 | p = p[0, :-1, :] # (n-1, vocab_size) 91 | labels = sen[0, 1:] # (n-1) 92 | count = labels.shape[0] 93 | greedy = p.argmax(dim=1) # (n-1) 94 | reps = sum([1 for i in range(labels.shape[0]) 95 | if (greedy[i] in labels[max(0, i-hist_size):i] 96 | and greedy[i] != labels[i]) 97 | ]) 98 | return reps, count 99 | 100 | 101 | ####################################################################### 102 | # Compute Metrics Based on Performance of Recalibrated Model on Dev Set 103 | ####################################################################### 104 | def compute_metrics_from_probs( 105 | model, dataset, metric_fn_lst, eppl_eps_lst=[], 106 | temperature=1.0, top_k=0, top_p=1.0, 107 | vocab_size=50257 108 | ): 109 | l = len(metric_fn_lst) 110 | num_metrics = len(metric_fn_lst) + len(eppl_eps_lst) 111 | device = next(model.parameters()).device 112 | eppl_eps_lst = torch.from_numpy(np.asarray(eppl_eps_lst)).to(device) 113 | m_numer = np.zeros(num_metrics) 114 | m_denom = np.zeros(num_metrics) 115 | metrics_final = np.zeros(num_metrics) 116 | device = next(model.parameters()).device 117 | for sen in utils.tqdm(dataset): 118 | sen = sen.to(device) 119 | logp = src.model_utils.get_tokenwise_log_probs_seq( 120 | model, sen, temperature=temperature, top_k=top_k, top_p=top_p,) 121 | p = torch.exp(logp) 122 | for i, fn in enumerate(metric_fn_lst): 123 | m, c = fn(p, sen) 124 | m_numer[i] += m 125 | m_denom[i] += c 126 | if eppl_eps_lst.shape[0] > 0: 127 | e_ppl, c = eps_perplexity_lst(p, sen, eppl_eps_lst, vocab_size) 128 | m_numer[l:] += e_ppl.cpu().numpy() 129 | m_denom[l:] += c 130 | metrics_final[:l] = m_numer[:l] / m_denom[:l] # fraction metrics 131 | metrics_final[l:] = np.exp(-m_numer[l:] / m_denom[l:]) # perplexity metrics 132 | return metrics_final 133 | 134 | def get_probs_metric_fn_lst(ls=[64]): 135 | reps = [lambda p, s, l=l: rep_score_1(p, s, l) for l in ls] # l=l binds the loop variable at definition time 136 | wreps = [lambda p, s, l=l: wrep_score_1(p, s, l) for l in ls] # (otherwise all lambdas would share the last l) 137 | return [sp_score_1, js_score_1, *reps, *wreps] 138 | 139 | def get_metric_names(ls=[64]): 140 | reps = [f'rep-{l}' for l in ls] 141 | wreps = [f'wrep-{l}' for l in ls] 142 | return ['sp-score', 'js-score', *reps, *wreps] 143 | 144 | 
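# Example wiring of the metric helpers above (an illustrative sketch;
# `model` and `dataset` are placeholders for a GPT-2 LM and a list of
# tokenized sequences from src.utils, not names defined in this file):
#
#   metric_fns = get_probs_metric_fn_lst(ls=[16, 64])
#   names = get_metric_names(ls=[16, 64])
#   vals = compute_metrics_from_probs(
#       model, dataset, metric_fns,
#       eppl_eps_lst=[1e-6, 1e-4],   # appended to the output as eps-perplexities
#       top_p=0.95,                  # score under nucleus-recalibrated probabilities
#   )
#   results = dict(zip(names, vals[:len(names)]))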
145 | ####################################### 146 | # Distinct-n metrics 147 | ####################################### 148 | def get_ngram_freqs(samples, n): 149 | ngram_freq = Counter() 150 | for sen in samples: 151 | ngrams = ngrams_fn_nltk(sen, n) 152 | ngram_freq.update(ngrams) 153 | uniq = len(ngram_freq) 154 | total = sum(ngram_freq.values()) 155 | return uniq, total 156 | 157 | def get_unique_ngram_fraction(samples, n_lst): 158 | # distinct-n 159 | out = [] 160 | for n in n_lst: 161 | a, b = get_ngram_freqs(samples, n) 162 | freq = a * 1.0 / b if b > 0 else 0 163 | out.append(freq) 164 | return out 165 | 166 | ####################################### 167 | # Perplexity of Generations 168 | ####################################### 169 | def _get_perplexity_from_prob(logp, num_tokens): 170 | return torch.exp(-logp.sum() / num_tokens).item() 171 | 172 | 173 | def get_perplexity_from_samples(model, ds_tokens): 174 | logp, num_tokens = src.model_utils.get_log_probs_of_ds(model, ds_tokens) 175 | return _get_perplexity_from_prob(logp, num_tokens) 176 | 177 | ####################################### 178 | # Zipf Coefficient 179 | ####################################### 180 | def zipf_coeff(samples, min_num=1, max_num=5000, stretch_factor=15): 181 | # samples: list of lists of tokens; max_num: how many top frequency words to consider 182 | counter = Counter() 183 | for s in samples: 184 | counter.update(s) 185 | top_freqs = np.array(sorted(counter.values(), key=operator.neg)[:max_num]) 186 | # log scale overweights tail, so subsample the tail 187 | # this also helps the best-fit line look more reasonable when plotted in log-scale. 188 | xs, idxs_u = np.unique(np.round( 189 | stretch_factor * np.log(np.arange(min_num, min(len(counter), max_num)).astype(np.float64))) / stretch_factor, 190 | return_index=True) 191 | ys = np.log(top_freqs[idxs_u]) 192 | 193 | lr = linear_model.LinearRegression() 194 | lr.fit(xs.reshape(-1, 1), ys) 195 | slope = lr.coef_[0] 196 | 197 | return slope 198 | 199 | ####################################### 200 | # Repetition 201 | ####################################### 202 | def get_repetition_fraction(samples, max_n=500): 203 | # from https://github.com/ari-holtzman/degen/blob/master/metrics/repetition.py 204 | n_repeated_examples = 0 205 | for gen in samples: 206 | rev_gen = list(reversed(gen)) 207 | last_n_repeats = [0] * max_n 208 | for n in range(1, max_n + 1): 209 | n_repeat = 1 210 | while len(rev_gen[n * n_repeat:n * (n_repeat + 1)]) == n and \ 211 | rev_gen[n * n_repeat:n * (n_repeat + 1)] == rev_gen[:n]: 212 | n_repeat += 1 213 | last_n_repeats[n-1] = n_repeat 214 | max_repeated_n = max(range(max_n), key=lambda x: last_n_repeats[x]) 215 | if last_n_repeats[max_repeated_n] > 1 and (max_repeated_n+1 >= 3 or last_n_repeats[max_repeated_n] > 50): 216 | # repetition detected 217 | n_repeated_examples += 1 218 | return n_repeated_examples / len(samples) 219 | 220 | ####################################### 221 | # Non-termination Ratio 222 | ####################################### 223 | def get_nontermination_ratio(samples, is_completed): 224 | return sum(is_completed) / len(is_completed) 225 | -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import math 4 | import sys, os, time, pickle as pkl 5 | import json 6 | import random 7 | from tqdm.auto import tqdm as tqdm_original 8 | from typing import 
Optional 9 | import torch 10 | from torch.nn.functional import softmax, log_softmax, relu 11 | 12 | from src.utils import tqdm 13 | 14 | 15 | @torch.no_grad() 16 | def my_top_k_top_p_filtering( 17 | logits: torch.Tensor, 18 | top_k: int = 0, 19 | top_p: float = 1.0, 20 | filter_value: float = -float("Inf"), 21 | min_tokens_to_keep: int = 1, 22 | ) -> torch.Tensor: 23 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 24 | Args: 25 | logits: logits distribution shape (batch size, vocabulary size) 26 | NOTE: the hidden state must be from the step prior to the token output at these logits; 27 | pass in first_token_p if a reliable hidden state is not available 28 | """ 29 | if top_k > 0: 30 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 31 | # Remove all tokens with a probability less than the last token of the top-k 32 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 33 | logits[indices_to_remove] = filter_value 34 | 35 | if top_p < 1.0: 36 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 37 | cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1) 38 | 39 | # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept) 40 | sorted_indices_to_remove = (cumulative_probs > top_p) 41 | if min_tokens_to_keep > 1: 42 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 43 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 44 | # Shift the indices to the right to keep also the first token above the threshold 45 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 46 | sorted_indices_to_remove[..., 0] = 0 47 | 48 | # scatter sorted tensors to original indexing 49 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 50 | logits[indices_to_remove] = filter_value 51 | return logits 52 | 53 | 54 | def reshape_logit_scores( 55 | scores, temperature=1.0, top_k=0, top_p=1.0, 56 | ): 57 | # scores: (batch_size, length, vocab_size) 58 | assert temperature > 1e-10 59 | if temperature != 1.0: 60 | scores = scores / temperature 61 | shape = scores.shape 62 | # top_k_top_p_filtering requires 2D input 63 | scores = my_top_k_top_p_filtering( 64 | scores.view(-1, shape[-1]), 65 | top_k=top_k, top_p=top_p 66 | ).view(shape).contiguous() 67 | return scores 68 | 69 | 70 | @torch.no_grad() 71 | def get_tokenwise_log_probs_seq( 72 | model, sen, temperature=1.0, top_k=0, top_p=1.0, 73 | ): 74 | # TODO: only works for batch size 1 75 | device = next(model.parameters()).device 76 | sen = sen.to(device) 77 | outs = model(input_ids=sen, past_key_values=None, 78 | output_hidden_states=True, return_dict=True) 79 | logits = reshape_logit_scores( 80 | outs.logits, temperature, top_k, top_p, 81 | ) 82 | log_probs = log_softmax(logits, dim=2) 83 | return log_probs # (b, seq_len, vocab_size) 84 | 85 | 86 | @torch.no_grad() 87 | def get_log_probs_and_hidden_states(model, sen, hidden_layer=-1): 88 | device = next(model.parameters()).device 89 | sen = sen.to(device) 90 | outs = model(input_ids=sen, past_key_values=None, 91 | output_hidden_states=True, return_dict=True) 92 | log_probs = log_softmax(outs.logits, dim=2) # (b, seq_len, vocab_size) 93 | hs = outs.hidden_states[hidden_layer] # (b, seq_len, hidden_dim) 94 | return log_probs, hs 95 | 96 | 97 | @torch.no_grad() 98 | def get_logprob_of_seq_from_logits(logits, seq): 99 | # logits: (batch_size, seq_len, vocab_size) 100 | 
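# Illustrative shape walk-through (hypothetical sizes): with batch_size=2 and
# seq_len=4, the index grids `i` and `j` built below via torch.ger are both
# (2, 3), so the advanced indexing log_probs[i, j, seq_next] gathers, for each
# batch element b and position t, the log-probability the model assigned to
# the observed next token seq[b, t+1]; the result has shape (2, 3).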
# works only if all elements in the batch have the same shape 101 | batch_size, seq_len = logits.shape[:2] 102 | log_probs = log_softmax(logits, dim=2) # (b, seq_len, vocab_size) 103 | # seq_next = (seq[1], seq[2], ..., seq[-1]) 104 | permutation = torch.arange(1, seq_len) 105 | seq_next = seq[:, permutation] # (batch_size, seq_len-1) 106 | # pick up log-probs corresponding to observed sequence 107 | i = torch.ger(torch.arange(batch_size), torch.ones(seq_len-1, dtype=torch.long)) 108 | j = torch.ger(torch.ones(batch_size, dtype=torch.long), torch.arange(seq_len-1)) 109 | return log_probs[i, j, seq_next] # (batch_size, seq_len-1) 110 | 111 | 112 | @torch.no_grad() 113 | def get_reshaped_log_probs_of_ds(model, ds_tokens, top_p=1.0, top_k=0, temperature=1.0): 114 | log_probs = [] 115 | device = next(model.parameters()).device 116 | for sen in tqdm(ds_tokens): 117 | sen = sen.to(device) 118 | outs = model(input_ids=sen, past_key_values=None, 119 | output_hidden_states=False, return_dict=True) 120 | logits = reshape_logit_scores( 121 | outs.logits, temperature, top_k, top_p, 122 | ) 123 | # log_p: (batch_size, seq_len-1) 124 | log_p = get_logprob_of_seq_from_logits(logits, sen) 125 | log_probs.append(log_p.detach().cpu()) 126 | return log_probs 127 | 128 | @torch.no_grad() 129 | def get_log_probs_of_ds(model, ds_tokens): 130 | log_probs = [] 131 | device = next(model.parameters()).device 132 | num_tokens = 0 133 | for sen in ds_tokens: 134 | num_tokens += sen.view(-1).shape[0] 135 | sen = sen.to(device) 136 | outs = model(input_ids=sen, past_key_values=None, 137 | output_hidden_states=False, return_dict=True) 138 | # log_p: (batch_size, seq_len-1) 139 | log_p = get_logprob_of_seq_from_logits(outs.logits, sen) 140 | log_probs.append(log_p.sum(axis=1)) 141 | return torch.cat(log_probs), num_tokens 142 | 143 | 144 | @torch.no_grad() 145 | def featurize_sequential(model, ds_tokens): 146 | device = next(model.parameters()).device 147 | t1 = time.time() 148 | feats = [] 149 | for sen in tqdm(ds_tokens): 150 | sen = sen.to(device) 151 | outs = model(input_ids=sen, past_key_values=None, 152 | output_hidden_states=True, return_dict=True) 153 | h = outs.hidden_states[-1] # (batch_size, seq_len, dim) 154 | feats.append(h[:, -1, :].cpu()) 155 | t2 = time.time() 156 | print(f'Featurize time: {round(t2-t1, 2)}') 157 | return torch.cat(feats) 158 | 159 | 160 | -------------------------------------------------------------------------------- /src/transformers_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List 2 | import torch 3 | 4 | ############################## 5 | # Legacy transformers code 6 | ############################## 7 | def postprocess_next_token_scores( 8 | scores, 9 | input_ids, 10 | no_repeat_ngram_size, 11 | bad_words_ids, 12 | cur_len, 13 | min_length, 14 | max_length, 15 | eos_token_id, 16 | repetition_penalty, 17 | batch_size, 18 | num_beams, 19 | ): 20 | # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) 21 | if repetition_penalty != 1.0: 22 | enforce_repetition_penalty_( 23 | scores, 24 | batch_size, 25 | num_beams, 26 | input_ids, 27 | repetition_penalty, 28 | ) 29 | 30 | # set eos token prob to zero if min_length is not reached 31 | if eos_token_id is not None and cur_len < min_length: 32 | scores[:, eos_token_id] = -float("inf") 33 | 34 | if no_repeat_ngram_size > 0: 35 | # calculate a list of banned tokens to prevent repetitively generating the same ngrams 36 | num_batch_hypotheses = batch_size * num_beams 37 | # from 
fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 38 | banned_batch_tokens = calc_banned_ngram_tokens( 39 | input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len 40 | ) 41 | for i, banned_tokens in enumerate(banned_batch_tokens): 42 | scores[i, banned_tokens] = -float("inf") 43 | 44 | if bad_words_ids is not None: 45 | # Exclude EOS token (already processed) 46 | bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) 47 | # calculate a list of banned tokens according to bad words 48 | banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) 49 | # Modify the scores in place by setting the banned tokens logits to `-inf` 50 | set_scores_to_inf_for_banned_tokens(scores, banned_tokens) 51 | 52 | return scores 53 | 54 | def enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): 55 | """ 56 | Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__). 57 | """ 58 | for i in range(batch_size * num_beams): 59 | for previous_token in set(prev_output_tokens[i].tolist()): 60 | # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability 61 | if lprobs[i, previous_token] < 0: 62 | lprobs[i, previous_token] *= repetition_penalty 63 | else: 64 | lprobs[i, previous_token] /= repetition_penalty 65 | 66 | 67 | def calc_banned_ngram_tokens(prev_input_ids: torch.Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int): 68 | """Copied from fairseq for no_repeat_ngram in beam_search""" 69 | if cur_len + 1 < no_repeat_ngram_size: 70 | # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet 71 | return [[] for _ in range(num_hypos)] 72 | generated_ngrams = [{} for _ in range(num_hypos)] 73 | for idx in range(num_hypos): 74 | gen_tokens = prev_input_ids[idx].tolist() 75 | generated_ngram = generated_ngrams[idx] 76 | for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): 77 | prev_ngram_tuple = tuple(ngram[:-1]) 78 | generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] 79 | 80 | def _get_generated_ngrams(hypo_idx): 81 | # Before decoding the next token, prevent decoding of ngrams that have already appeared 82 | start_idx = cur_len + 1 - no_repeat_ngram_size 83 | ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) 84 | return generated_ngrams[hypo_idx].get(ngram_idx, []) 85 | 86 | banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] 87 | return banned_tokens 88 | 89 | def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]: 90 | banned_tokens = [] 91 | 92 | def _tokens_match(prev_tokens, tokens): 93 | if len(tokens) == 0: 94 | # if bad word tokens is just one token always ban it 95 | return True 96 | if len(tokens) > len(prev_tokens): 97 | # if bad word tokens are longer than prev tokens they can't be equal 98 | return False 99 | 100 | if prev_tokens[-len(tokens) :] == tokens: 101 | # if tokens match 102 | return True 103 | else: 104 | return False 105 | 106 | for prev_input_ids_slice in prev_input_ids: 107 | banned_tokens_slice = [] 108 | 109 | for banned_token_seq in bad_words_ids: 110 | assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( 111 | bad_words_ids 112 | ) 113 | 114 | if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: 
115 | # if tokens do not match continue 116 | continue 117 | 118 | banned_tokens_slice.append(banned_token_seq[-1]) 119 | 120 | banned_tokens.append(banned_tokens_slice) 121 | 122 | return banned_tokens 123 | 124 | def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: 125 | """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be 126 | a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...] 127 | Args: 128 | scores: logits distribution of shape (batch size, vocabulary size) 129 | banned_tokens: list of list of tokens to ban of length (batch_size) 130 | """ 131 | banned_mask_list = [] 132 | for idx, batch_banned_tokens in enumerate(banned_tokens): 133 | for token in batch_banned_tokens: 134 | banned_mask_list.append([idx, token]) 135 | if not banned_mask_list: 136 | return 137 | banned_mask = torch.LongTensor(banned_mask_list) 138 | indices = torch.ones(len(banned_mask)) 139 | # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: 140 | # [ 0 1 1 ] 141 | # [ 0 0 0 ] 142 | # [ 1 0 0 ] 143 | 144 | banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() 145 | scores.masked_fill_(banned_mask, -float("inf")) 146 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os, time 4 | import json 5 | import random 6 | from tqdm.auto import tqdm as tqdm_original 7 | import torch 8 | 9 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 10 | from transformers import RobertaModel, RobertaTokenizer 11 | 12 | 13 | CPU_DEVICE = torch.device('cpu') 14 | tqdm = lambda *args, **kwargs: tqdm_original( 15 | *args, **kwargs, disable=os.environ.get("DISABLE_TQDM", False)) 16 | NEWLINE=198 17 | 18 | 19 | def make_basic_parser(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--device', type=int, 22 | help='choose one of [0, 1, 2, 3] for GPU, or CPU otherwise') 23 | parser.add_argument('--data_dir', type=str, default='./data') 24 | parser.add_argument('--ds_name', type=str) 25 | parser.add_argument('--datasplit', type=str) 26 | parser.add_argument('--model_name', type=str, default='gpt2') 27 | parser.add_argument('--featurize_model_name', type=str, default='gpt2-large') 28 | parser.add_argument('--use_large_feats', action='store_true') 29 | parser.add_argument('--seed', type=int, default=25041993) 30 | parser.add_argument('--prefix_len', type=int, default=10) 31 | parser.add_argument('--max_len', type=int, default=1024) 32 | parser.add_argument('--max_num_generations', type=int, default=5000) 33 | parser.add_argument('--prompt_size', type=int, default=10) 34 | parser.add_argument('--top_p', type=float, default=1.0) 35 | parser.add_argument('--top_k', type=int, default=0) 36 | parser.add_argument('--temp', type=float, default=1.0) 37 | parser.add_argument('--beam_size', type=int, default=4) 38 | parser.add_argument('--no_repeat_ngram', type=int, default=0) 39 | parser.add_argument('--entmax_alpha', type=float, default=1.1) 40 | return parser 41 | 42 | def make_metrics_parser(): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--device', type=int, default=-1, 45 | help='choose one of [0, 1, 2, 3] for GPU, or CPU otherwise') 46 | 
parser.add_argument('--datasplit', type=str) 47 | parser.add_argument('--ref_name', type=str, help='name of human generations to search for') 48 | parser.add_argument('--max_len', type=int, default=1024) 49 | parser.add_argument('--ds_name', type=str) 50 | parser.add_argument('--max_num_data', type=int, default=5000) 51 | parser.add_argument('--model_name', type=str, default='gpt2') 52 | parser.add_argument('--data_dir', type=str, default='./data') 53 | parser.add_argument('--seed', type=int, default=1234) 54 | parser.add_argument('--force', action='store_true', help='Redo computation even if it already exists') 55 | parser.add_argument('--use_large_feats', action='store_true', help='Use feats from gpt2-large if true') 56 | parser.add_argument('--use_bert_feats', action='store_true', help='Use feats from Roberta-large if true') 57 | parser.add_argument('--subsample_frac', type=float, default=1.0) 58 | parser.add_argument('--subsample_seed', type=int) 59 | 60 | ########################## 61 | # Generation Types Args 62 | ########################## 63 | gen_parser = parser.add_argument_group('gen_args', 'Generation Type args') 64 | 65 | gen_parser.add_argument('--generation_type', type=str, 66 | choices=['basic', 'regr', 'beam', 'entmax']) 67 | # basic 68 | gen_parser.add_argument('--top_k', type=int, default=0) 69 | gen_parser.add_argument('--top_p', type=float, default=1.0) 70 | gen_parser.add_argument('--temp', type=float, default=1.0) 71 | gen_parser.add_argument('--generate_seed', type=int, default=1) 72 | # beam 73 | gen_parser.add_argument('--beam_size', type=int, default=4) 74 | gen_parser.add_argument('--no_repeat_ngram', type=int, default=0) 75 | # entmax 76 | gen_parser.add_argument('--entmax_alpha', type=float, default=1.1) 77 | 78 | ########################## 79 | # PR Args 80 | ########################## 81 | pr_parser = parser.add_argument_group('pr_args', 'Arguments to compute PR metrics') 82 | pr_parser.add_argument('--discretization', type=str, 83 | choices=['kmeans_l1', 'kmeans_l2', 'spv', 'drmm']) 84 | # kmeans 85 | pr_parser.add_argument('--kmeans_num_clusters', type=int, default=100) 86 | pr_parser.add_argument('--kmeans_explained_var', type=float, default=0.99) 87 | # spv 88 | pr_parser.add_argument('--spv_num_epochs', type=int, default=160) 89 | # drmm 90 | pr_parser.add_argument('--drmm_num_epochs', type=int, default=20) 91 | pr_parser.add_argument('--drmm_n_layer', type=int, default=3) 92 | pr_parser.add_argument('--drmm_n_component_per_layer', type=int, default=10) 93 | 94 | ########################## 95 | # Bleu args 96 | ########################## 97 | bleu_parser = parser.add_argument_group('bleu_args') 98 | bleu_parser.add_argument('--n_sample_bleu', type=int, default=500, 99 | help='how many sentences to sample to calculate self-bleu') 100 | bleu_parser.add_argument('--parallel_bleu', action='store_true', help='run in parallel') 101 | bleu_parser.add_argument('--n_proc_bleu', default=6, type=int) 102 | 103 | return parser 104 | 105 | def get_save_filename_from_args(args): 106 | if args.generation_type == 'basic': 107 | folder = 'basic' 108 | filename = f'{args.datasplit}_p{args.top_p}_k{args.top_k}_t{args.temp}_seed{args.generate_seed}' 109 | elif args.generation_type == 'beam': 110 | folder = 'beam' 111 | filename = f'{args.datasplit}_b{args.beam_size}_t{args.temp}_nr{args.no_repeat_ngram}_seed{args.generate_seed}' 112 | elif args.generation_type == 'entmax': 113 | folder = 'entmax' 114 | filename = 
f'{args.datasplit}_entmax{args.entmax_alpha}_seed{args.generate_seed}' 115 | else: 116 | raise ValueError('Unknown generation type', args.generation_type) 117 | print('folder, filename:', (folder, filename)) 118 | return folder, filename 119 | 120 | def split_dataset(ds, split_point=500, seed=0): 121 | rng = random.Random(seed) 122 | rng.shuffle(ds) 123 | return ds[:split_point], ds[split_point:] 124 | 125 | def get_device_from_arg(device_id): 126 | if (device_id is not None and 127 | torch.cuda.is_available() and 128 | 0 <= device_id < torch.cuda.device_count()): 129 | return torch.device(f'cuda:{device_id}') 130 | else: 131 | return CPU_DEVICE 132 | 133 | def get_model_and_tokenizer(model_name='gpt2', device=CPU_DEVICE): 134 | if 'gpt3' in model_name: # For GPT-3 evals, use GPT-2 large 135 | model_name = 'gpt2-large' 136 | if 'gpt2' in model_name: 137 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 138 | model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).to(device) 139 | model = model.eval() 140 | elif 'roberta' in model_name: 141 | tokenizer = RobertaTokenizer.from_pretrained(model_name) 142 | model = RobertaModel.from_pretrained(model_name) 143 | else: 144 | raise ValueError(f'Unknown model: {model_name}') 145 | return model, tokenizer 146 | 147 | 148 | def get_dataset_name_from_datapath(datapath): 149 | known_datasets = ['gpt2_output_dataset', 'webtext', 'wp', 'grover'] 150 | transform = {'grover': 'grover', 'webtext': 'webtext', 'wp': 'writingPrompts', 151 | 'gpt2_output_dataset':'gpt2_output_dataset'} 152 | for ds_name in known_datasets: 153 | if transform[ds_name].lower() in datapath.lower(): 154 | return ds_name 155 | raise ValueError('Unknown dataset', datapath) 156 | 157 | 158 | def get_model_basename(model_name): 159 | if 'gpt2-large' in model_name: 160 | return 'gpt2-large' 161 | elif 'gpt2-xl' in model_name: 162 | return 'gpt2-xl' 163 | elif 'gpt2-medium' in model_name: 164 | return 'gpt2-medium' 165 | elif 'gpt2' in model_name: 166 | return 'gpt2' 167 | elif 'gpt3-ada' in model_name: 168 | return 'gpt3-ada' 169 | elif 'gpt3-babbage' in model_name: 170 | return 'gpt3-babbage' 171 | elif 'gpt3-curie' in model_name: 172 | return 'gpt3-curie' 173 | elif 'gpt3-davinci' in model_name: 174 | return 'gpt3-davinci' 175 | else: 176 | raise ValueError(f'Unknown model name {model_name}') 177 | 178 | def load_json_dataset(data_dir, dataset_name, split=None, max_num_data=np.inf): 179 | if split is None: 180 | path = os.path.join(data_dir, f'{dataset_name}.jsonl') 181 | else: 182 | path = os.path.join(data_dir, f'{dataset_name}.{split}.jsonl') 183 | texts = [] 184 | for i, line in enumerate(open(path)): 185 | if i >= max_num_data: 186 | break 187 | texts.append(json.loads(line)['text']) 188 | return texts 189 | 190 | def load_and_tokenize_data(tokenizer, data_dir, max_len, max_num_data, min_len=None, ds_name=None, split='valid'): 191 | assert max_len <= 1024 and max_num_data >= 2000, f"max_len={max_len}, max_num_data={max_num_data} are insufficient" 192 | t1 = time.time() 193 | if ds_name is None: 194 | ds_name = get_dataset_name_from_datapath(data_dir) 195 | texts = load_json_dataset(data_dir, ds_name, split=split, max_num_data=max_num_data) 196 | t2 = time.time() 197 | print(f'dataset load time: {round(t2-t1, 2)}') 198 | t1 = time.time() 199 | if min_len is None: 200 | tokenized_texts = [tokenizer.encode(sen, return_tensors='pt', truncation=True, max_length=max_len) 201 | for sen in texts] 202 | else: 203 | assert 0 <= min_len <= 100 204 | 
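# texts are encoded to plain lists here (not tensors) so that sequences
# shorter than min_len can be right-padded below with the GPT-2 newline
# token (NEWLINE == 198) before conversion to LongTensors.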
tokenized_texts = [tokenizer.encode(sen, truncation=True, max_length=max_len) 205 | for sen in texts] 206 | # append with newline if necessary 207 | for i in range(len(tokenized_texts)): 208 | if len(tokenized_texts[i]) < min_len: 209 | num_tokens_to_append = min_len - len(tokenized_texts[i]) 210 | tokenized_texts[i].extend([NEWLINE] * num_tokens_to_append) 211 | tokenized_texts = [torch.LongTensor(sen).unsqueeze(0) for sen in tokenized_texts] 212 | 213 | t2 = time.time() 214 | print(f'tokenizing time: {round(t2-t1, 2)}') 215 | return tokenized_texts 216 | 217 | def decode_samples_from_lst(tokenizer, lst): 218 | t1 = time.time() 219 | output = [] 220 | for l in lst: 221 | o = tokenizer.decode(torch.LongTensor(l), skip_special_tokens=True) 222 | output.append(o) 223 | t2 = time.time() 224 | print(f'de-tokenizing time: {round(t2-t1, 2)}') 225 | return output 226 | 227 | 228 | # def shift_hidden_state(hs): 229 | # # shift hidden state up so that hs[i] corresponds to what was seen before logits[i] 230 | # n = hs.shape[1] 231 | # hs = hs.squeeze(0) # (n, dim) 232 | # hs2 = hs.clone() 233 | # hs2[1:n] = hs[0:n-1] 234 | # hs2[0] = 0 # initial hidden state 235 | # return hs2[None] # (1, n, dim) 236 | 237 | 238 | --------------------------------------------------------------------------------
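For reference, here is a minimal end-to-end sketch of how the utilities above fit together to compute MAUVE. It is illustrative only: it compares the webtext `test` and `valid` splits (two human samples, for which MAUVE should be close to 1) and assumes the data layout produced by `local_scripts/download_data.py`.

```python
import src.utils as utils
from src.model_utils import featurize_sequential
from src.mauve_metrics import compute_mauve_metrics

# Load the featurizing model; falls back to CPU if GPU 0 is unavailable.
device = utils.get_device_from_arg(0)
model, tokenizer = utils.get_model_and_tokenizer('gpt2-large', device=device)

# Tokenize the two samples to compare (here, two human splits of webtext).
p_tokens = utils.load_and_tokenize_data(tokenizer, './data', max_len=1024,
                                        max_num_data=5000, ds_name='webtext', split='test')
q_tokens = utils.load_and_tokenize_data(tokenizer, './data', max_len=1024,
                                        max_num_data=5000, ds_name='webtext', split='valid')

# Featurize each sequence by its last-token hidden state.
p_feats = featurize_sequential(model, p_tokens)
q_feats = featurize_sequential(model, q_tokens)

# Quantize the two feature sets jointly with k-means and compute the
# area under the divergence curve.
p_hist, q_hist, mauve = compute_mauve_metrics(p_feats, q_feats,
                                              discretization_algo='kmeans_l1',
                                              kmeans_num_clusters=100)
print('MAUVE:', mauve)
```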