├── README.md ├── assets ├── adaptive_sampling.png ├── dis_diffusion_v2.png └── model_arch.png ├── basic_utils.py ├── datasets └── readme.md ├── diffuseq ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── gaussian_diffusion.cpython-37.pyc │ ├── gaussian_diffusion.cpython-38.pyc │ ├── gaussian_diffusion.cpython-39.pyc │ ├── rounding.cpython-37.pyc │ ├── rounding.cpython-38.pyc │ ├── rounding.cpython-39.pyc │ ├── step_sample.cpython-37.pyc │ ├── step_sample.cpython-39.pyc │ ├── text_datasets.cpython-37.pyc │ ├── text_datasets.cpython-38.pyc │ ├── text_datasets.cpython-39.pyc │ ├── transformer_model.cpython-37.pyc │ ├── transformer_model.cpython-38.pyc │ └── transformer_model.cpython-39.pyc ├── config.json ├── gaussian_diffusion.py ├── rounding copy.py ├── rounding.py ├── step_sample.py ├── text_datasets.py ├── transformer_model.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── dist_util.cpython-37.pyc │ ├── dist_util.cpython-38.pyc │ ├── dist_util.cpython-39.pyc │ ├── fp16_util.cpython-37.pyc │ ├── fp16_util.cpython-39.pyc │ ├── logger.cpython-37.pyc │ ├── logger.cpython-38.pyc │ ├── logger.cpython-39.pyc │ ├── losses.cpython-37.pyc │ ├── losses.cpython-38.pyc │ ├── losses.cpython-39.pyc │ ├── nn.cpython-37.pyc │ ├── nn.cpython-38.pyc │ └── nn.cpython-39.pyc │ ├── dist_util.py │ ├── fp16_util.py │ ├── logger.py │ ├── losses.py │ └── nn.py ├── evaluation ├── __pycache__ │ ├── tokenizer.cpython-37.pyc │ └── tokenizer.cpython-39.pyc ├── eval.py ├── quick_eval.py └── tokenizer.py ├── generation_cases ├── qqp_20 │ ├── mbr.jsonl │ ├── out.log │ ├── seed10_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed11_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed12_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed13_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed14_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed20_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed21_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed22_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed23_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ └── seed24_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json └── ts-20 │ ├── mbr.jsonl │ ├── mbr.jsonl_bk │ ├── out.log │ ├── out.log2 │ ├── seed11_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed13_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed14_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed15_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── 
seed23_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed24_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed25_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed35_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ ├── seed36_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json │ └── seed37_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json ├── requirement.txt ├── sample_seq2seq.py ├── scripts ├── eval_seq2seq.py ├── run_decode.py └── run_train.py ├── train.py └── train_util.py /README.md: -------------------------------------------------------------------------------- 1 | # Bridging the Gap between Training and Inference for Diffusion Model 2 | 3 | This is the official code for [Can Diffusion Model Achieve Better Performance in Text Generation? Bridging the Gap between Training and Inference!](https://arxiv.org/pdf/2305.04465.pdf) 4 | 5 | # Highlight 6 | 7 | 1. You can post-train your own diffusion model with the two methods below to ``accelerate the inference speed`` and ``achieve better performance``! 8 | 9 | 2. Extensive experiments show that our method can generate a full sequence of 128 tokens in only ``4`` denoising steps! 10 | 11 | ## Model Architecture 12 |

<img src="assets/model_arch.png" alt="Logo">

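The backbone pairs a Transformer network (`TransformerNetModel`) with a spaced Gaussian diffusion process. As a minimal sketch (not an official entry point; the training and sampling scripts do this setup for you), the pair can be built from the helpers in `basic_utils.py`, assuming the placeholder fields in `diffuseq/config.json` have been filled in first:

```python
# Minimal sketch: build the backbone and diffusion from this repo's helpers.
# The concrete config_name / vocab_size values below are illustrative
# assumptions; in practice they are derived from the tokenizer setup.
from basic_utils import load_defaults_config, create_model_and_diffusion

config = load_defaults_config()              # reads diffuseq/config.json
config['config_name'] = 'bert-base-uncased'  # assumed BERT backbone config
config['vocab_size'] = 30522                 # assumed bert-base-uncased vocab
model, diffusion = create_model_and_diffusion(**config)
print(type(model).__name__)                  # -> TransformerNetModel
```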
13 | 14 | ## Down-Sampling Strategy 15 | 16 |

<img src="assets/adaptive_sampling.png" alt="Logo">

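Down-sampling shrinks the 2000 training timesteps to a handful of inference steps (e.g. ``--decode_respacing "adp_20"`` keeps roughly 20 adaptively chosen steps; the `adp` schedule itself is implemented in `diffuseq/gaussian_diffusion.py`). For intuition, here is a sketch of the plain uniform-stride baseline that the adaptive strategy improves on:

```python
# Illustration only: a uniform down-sampling baseline. The adaptive "adp_K"
# schedule used by this repo lives in diffuseq/gaussian_diffusion.py.
import numpy as np

def uniform_respacing(train_steps=2000, infer_steps=20):
    # keep an evenly spaced subset of the training timesteps
    return np.linspace(0, train_steps - 1, infer_steps, dtype=int).tolist()

print(uniform_respacing())  # [0, 105, 210, ..., 1999]
```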
17 | 18 | # Dataset & Model Preparation 19 | 20 | ## Dataset 21 | We provide the download links for all the data used in our paper: 22 | 23 | | Task | Dataset | Samples | Used in our paper | 24 | |------|---------| ---------| ---------| 25 | |Text Simplification| [WIKI AUTO](https://github.com/chaojiang06/wiki-auto) | 677k | [download](https://drive.google.com/drive/folders/1yIo3qploLvtSc9CAzohAeKlHjOoNRfLg?usp=sharing)| 26 | | Paraphrase | [Quora Question Pairs](https://www.kaggle.com/c/quora-question-pairs) | 114k | [download](https://drive.google.com/drive/folders/1kclZh3KTS1IOD3tre6ybsX7UhRkwEPeW?usp=share_link)| 27 | | Story Generation | [ROC Story](https://cs.rochester.edu/nlp/rocstories/) | 88k | [download](https://drive.google.com/drive/folders/1bvjIroxJaACGIkACwSxCCJHPh1PBR3Zv?usp=sharing) | 28 | | Question Generation | [Quasar-T](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) | 117k | [download](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) | 29 | | E2E (Semantic / Syntax) | [E2E](http://www.macs.hw.ac.uk/) | 88k | [download](https://drive.google.com/drive/folders/1YJwa3SIqg2d0VkfzCrVEo8QtrZwkxcBX?usp=sharing) | 30 | 31 | Please download the data and place it under the ``./datasets`` folder. 32 | 33 | ## Backbone Model 34 | Please refer to the following repos for more details: 35 | 36 | [DiffuSeq: Sequence to Sequence Text Generation with Diffusion Models](https://github.com/Shark-NLP/DiffuSeq) 37 | 38 | [Diffusion-LM Improves Controllable Text Generation](https://github.com/XiangLi1999/Diffusion-LM) 39 | 40 | ``Note`` We also provide two post-trained models [link](https://drive.google.com/drive/folders/1UvcN9mKOv-nVZuQpJaAOQWCG22sGk_2O?usp=sharing) for a quick check. 41 | 42 | 43 | # Quick Start 44 | We provide the code for post-training on the QQP (Paraphrase) dataset. 45 | 46 | ## Environment 47 | ``` 48 | conda create -n diffusion python=3.9 49 | conda activate diffusion 50 | pip install -r requirement.txt 51 | ``` 52 | 53 | ## Training 54 | We conduct experiments on 4 NVIDIA A100 (40GB) GPUs. 55 | ```bash 56 | cd scripts 57 | export CUDA_VISIBLE_DEVICES=0,1,2,3; 58 | 59 | DISTRIBUTE_ARGS=" 60 | --nproc_per_node=4 \ 61 | --use_env 62 | " 63 | 64 | TRAIN_ARGS=" 65 | --diff_steps 2000 \ 66 | --microbatch 100 \ 67 | --lr 0.0001 \ 68 | --learning_steps 320000 \ 69 | --save_interval 2500 \ 70 | --seed 109 \ 71 | --noise_schedule sqrt \ 72 | --hidden_dim 128 \ 73 | --bsz 100 \ 74 | --dataset qqp \ 75 | --data_dir datasets/QQP \ 76 | --vocab bert \ 77 | --seq_len 128 \ 78 | --simi_penalty l2_noise_random \ 79 | --simi_lambda -2 \ 80 | --simi_step 10 \ 81 | --simi_noise 0.05 \ 82 | --resume_checkpoint /path/to/checkpoint \ 83 | --schedule_sampler lossaware \ 84 | --notes qqp 85 | " 86 | 87 | python -m torch.distributed.launch $DISTRIBUTE_ARGS run_train.py $TRAIN_ARGS 88 | 89 | ``` 90 | 91 | 92 | # Inference 93 | 94 | ```bash 95 | python sample_seq2seq.py \ 96 | --model_path /path/to/checkpoint \ 97 | --step 2000 \ 98 | --batch_size 16 \ 99 | --seed2 10 \ 100 | --split test \ 101 | --out_dir generation_outputs \ 102 | --decode_respacing "adp_20" 103 | ``` 104 | 105 | 106 | 107 | # Acknowledgement 108 | We appreciate the open-source code of the following projects: 109 | 110 | [DiffuSeq](https://github.com/Shark-NLP/DiffuSeq)&nbsp; 111 | [Diffusion-LM](https://github.com/XiangLi1999/Diffusion-LM)&nbsp; 112 | -------------------------------------------------------------------------------- /assets/adaptive_sampling.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/assets/adaptive_sampling.png -------------------------------------------------------------------------------- /assets/dis_diffusion_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/assets/dis_diffusion_v2.png -------------------------------------------------------------------------------- /assets/model_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/assets/model_arch.png -------------------------------------------------------------------------------- /basic_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json, os 4 | import time 5 | 6 | from diffuseq import gaussian_diffusion as gd 7 | from diffuseq.gaussian_diffusion import SpacedDiffusion, space_timesteps 8 | from diffuseq.transformer_model import TransformerNetModel 9 | from transformers import AutoTokenizer, PreTrainedTokenizerFast 10 | 11 | class myTokenizer(): 12 | """ 13 | Load the tokenizer from a bert config or a user-defined BPE vocab dict 14 | """ 15 | ################################################## 16 | ### You can customize your own tokenizer here. ### 17 | ################################################## 18 | def __init__(self, args): 19 | if args.vocab == 'bert': 20 | tokenizer = AutoTokenizer.from_pretrained(args.config_name) 21 | self.tokenizer = tokenizer 22 | self.sep_token_id = tokenizer.sep_token_id 23 | self.pad_token_id = tokenizer.pad_token_id 24 | # save 25 | tokenizer.save_pretrained(args.checkpoint_path) 26 | else: 27 | # load vocab from the path 28 | print('#'*30, 'load vocab from', args.vocab) 29 | vocab_dict = {'[START]': 0, '[END]': 1, '[UNK]':2, '[PAD]':3} 30 | with open(args.vocab, 'r', encoding='utf-8') as f: 31 | for row in f: 32 | vocab_dict[row.strip().split(' ')[0]] = len(vocab_dict) 33 | self.tokenizer = vocab_dict 34 | self.rev_tokenizer = {v: k for k, v in vocab_dict.items()} 35 | self.sep_token_id = vocab_dict['[END]'] 36 | self.pad_token_id = vocab_dict['[PAD]'] 37 | # save 38 | if int(os.environ['LOCAL_RANK']) == 0: 39 | path_save_vocab = f'{args.checkpoint_path}/vocab.json' 40 | with open(path_save_vocab, 'w') as f: 41 | json.dump(vocab_dict, f) 42 | 43 | self.vocab_size = len(self.tokenizer) 44 | args.vocab_size = self.vocab_size # update vocab size in args 45 | 46 | def encode_token(self, sentences): 47 | if isinstance(self.tokenizer, dict): 48 | input_ids = [[0] + [self.tokenizer.get(x, self.tokenizer['[UNK]']) for x in seq.split()] + [1] for seq in sentences] 49 | elif isinstance(self.tokenizer, PreTrainedTokenizerFast): 50 | input_ids = self.tokenizer(sentences, add_special_tokens=True)['input_ids'] 51 | else: 52 | assert False, "invalid type of vocab_dict" 53 | return input_ids 54 | 55 | def decode_token(self, seq): 56 | if isinstance(self.tokenizer, dict): 57 | seq = seq.squeeze(-1).tolist() 58 | while len(seq)>0 and seq[-1] == self.pad_token_id: 59 | seq.pop() 60 | tokens = " ".join([self.rev_tokenizer[x] for x in seq]).replace('__ ', '').replace('@@ ', '') 61 | elif isinstance(self.tokenizer, PreTrainedTokenizerFast): 62 | seq =
seq.squeeze(-1).tolist() 63 | while len(seq)>0 and seq[-1] == self.pad_token_id: 64 | seq.pop() 65 | tokens = self.tokenizer.decode(seq) 66 | else: 67 | assert False, "invalid type of vocab_dict" 68 | return tokens 69 | 70 | 71 | def load_model_emb(args, tokenizer): 72 | ### random emb or pre-defined embedding like glove embedding. You can customize your own init here. 73 | model = torch.nn.Embedding(tokenizer.vocab_size, args.hidden_dim) 74 | path_save = '{}/random_emb.torch'.format(args.checkpoint_path) 75 | if int(os.environ['LOCAL_RANK']) == 0: 76 | if os.path.exists(path_save): 77 | print('reload the random embeddings', model) 78 | model.load_state_dict(torch.load(path_save)) 79 | else: 80 | print('initializing the random embeddings', model) 81 | torch.nn.init.normal_(model.weight) 82 | torch.save(model.state_dict(), path_save) 83 | else: 84 | while not os.path.exists(path_save): 85 | time.sleep(1) 86 | print('reload the random embeddings', model) 87 | model.load_state_dict(torch.load(path_save)) 88 | 89 | return model, tokenizer 90 | 91 | 92 | def load_tokenizer(args): 93 | tokenizer = myTokenizer(args) 94 | return tokenizer 95 | 96 | def load_defaults_config(): 97 | """ 98 | Load defaults for training args. 99 | """ 100 | with open('diffuseq/config.json', 'r') as f: 101 | return json.load(f) 102 | 103 | 104 | def create_model_and_diffusion( 105 | hidden_t_dim, 106 | hidden_dim, 107 | vocab_size, 108 | config_name, 109 | use_plm_init, 110 | dropout, 111 | diffusion_steps, 112 | noise_schedule, 113 | learn_sigma, 114 | timestep_respacing, 115 | predict_xstart, 116 | rescale_timesteps, 117 | sigma_small, 118 | rescale_learned_sigmas, 119 | use_kl, 120 | notes, 121 | **kwargs, 122 | ): 123 | model = TransformerNetModel( 124 | input_dims=hidden_dim, 125 | output_dims=(hidden_dim if not learn_sigma else hidden_dim*2), 126 | hidden_t_dim=hidden_t_dim, 127 | dropout=dropout, 128 | config_name=config_name, 129 | vocab_size=vocab_size, 130 | init_pretrained=use_plm_init 131 | ) 132 | 133 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 134 | 135 | if not timestep_respacing: 136 | timestep_respacing = [diffusion_steps] 137 | elif timestep_respacing.startswith("ddim"): 138 | pass 139 | elif timestep_respacing.startswith("x2"): 140 | pass 141 | elif timestep_respacing.startswith("adp"): 142 | pass 143 | else: 144 | timestep_respacing = json.loads(timestep_respacing) 145 | 146 | diffusion = SpacedDiffusion( 147 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 148 | betas=betas, 149 | rescale_timesteps=rescale_timesteps, 150 | predict_xstart=predict_xstart, 151 | learn_sigmas = learn_sigma, 152 | sigma_small = sigma_small, 153 | use_kl = use_kl, 154 | rescale_learned_sigmas=rescale_learned_sigmas 155 | ) 156 | 157 | return model, diffusion 158 | 159 | 160 | def add_dict_to_argparser(parser, default_dict): 161 | for k, v in default_dict.items(): 162 | v_type = type(v) 163 | if v is None: 164 | v_type = str 165 | elif isinstance(v, bool): 166 | v_type = str2bool 167 | parser.add_argument(f"--{k}", default=v, type=v_type) 168 | 169 | 170 | def args_to_dict(args, keys): 171 | return {k: getattr(args, k) for k in keys} 172 | 173 | 174 | def str2bool(v): 175 | """ 176 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 177 | """ 178 | if isinstance(v, bool): 179 | return v 180 | if v.lower() in ("yes", "true", "t", "y", "1"): 181 | return True 182 | elif v.lower() in ("no", "false", "f", "n", "0"): 183 | return False 184 | else: 185
| raise argparse.ArgumentTypeError("boolean value expected") 186 | -------------------------------------------------------------------------------- /datasets/readme.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | | Task | Dataset | Samples | Used in our paper | 4 | |------|---------| ---------| ---------| 5 | |Text Simplification| [WIKI AUTO](https://github.com/chaojiang06/wiki-auto) | 677k | [download](https://drive.google.com/drive/folders/1yIo3qploLvtSc9CAzohAeKlHjOoNRfLg?usp=sharing)| 6 | | Paraphrase | [Quora Question Pairs](https://www.kaggle.com/c/quora-question-pairs) | 114k | [download](https://drive.google.com/drive/folders/1kclZh3KTS1IOD3tre6ybsX7UhRkwEPeW?usp=share_link)| 7 | | Story Generation | [ROC Story](https://cs.rochester.edu/nlp/rocstories/) | 88k | [download](https://drive.google.com/drive/folders/1bvjIroxJaACGIkACwSxCCJHPh1PBR3Zv?usp=sharing) | 8 | | Question Generation | [Quasar-T](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) | 117k | [download](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) | 9 | | E2E (Semantic / Syntax) | [E2E](http://www.macs.hw.ac.uk/) | 88k | [download](https://drive.google.com/drive/folders/1YJwa3SIqg2d0VkfzCrVEo8QtrZwkxcBX?usp=sharing) | 10 | -------------------------------------------------------------------------------- /diffuseq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__init__.py -------------------------------------------------------------------------------- /diffuseq/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/gaussian_diffusion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/gaussian_diffusion.cpython-37.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/gaussian_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/gaussian_diffusion.cpython-38.pyc
-------------------------------------------------------------------------------- /diffuseq/__pycache__/gaussian_diffusion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/gaussian_diffusion.cpython-39.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/rounding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/rounding.cpython-37.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/rounding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/rounding.cpython-38.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/rounding.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/rounding.cpython-39.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/step_sample.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/step_sample.cpython-37.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/step_sample.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/step_sample.cpython-39.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/text_datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/text_datasets.cpython-37.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/text_datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/text_datasets.cpython-38.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/text_datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/text_datasets.cpython-39.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/transformer_model.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/transformer_model.cpython-37.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/transformer_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/transformer_model.cpython-38.pyc -------------------------------------------------------------------------------- /diffuseq/__pycache__/transformer_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/transformer_model.cpython-39.pyc -------------------------------------------------------------------------------- /diffuseq/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr": 0.0001, 3 | "batch_size": 2048, 4 | "microbatch": 64, 5 | "learning_steps": 320000, 6 | "log_interval": 5, 7 | "save_interval": 100, 8 | "eval_interval": 5000, 9 | "ema_rate": "0.9999", 10 | "resume_checkpoint": "none", 11 | "schedule_sampler": "lossaware", 12 | "diffusion_steps": 2000, 13 | "noise_schedule": "sqrt", 14 | "timestep_respacing": "", 15 | "vocab": "bert", 16 | "use_plm_init": "no", 17 | "vocab_size": 0, 18 | "config_name": "huggingface-config", 19 | "notes": "folder-notes", 20 | "data_dir": "data-dir", 21 | "dataset": "dataset-name", 22 | "checkpoint_path": "checkpoint-path", 23 | "seq_len": 128, 24 | "hidden_t_dim": 128, 25 | "hidden_dim": 128, 26 | "dropout": 0.1, 27 | "use_fp16": false, 28 | "fp16_scale_growth": 0.001, 29 | "seed": 102, 30 | "gradient_clipping": -1.0, 31 | "weight_decay": 0.0, 32 | "learn_sigma": false, 33 | "use_kl": false, 34 | "predict_xstart": true, 35 | "rescale_timesteps": true, 36 | "rescale_learned_sigmas": false, 37 | "sigma_small": false, 38 | "emb_scale_factor": 1.0, 39 | "simi_penalty": "", 40 | "simi_lambda": 0.01, 41 | "simi_step": 10, 42 | "near_step": 0, 43 | "far_step": 0, 44 | "near_lambda": 0.0, 45 | "far_lambda": 0.0, 46 | "simi_noise": 0.05 47 | } 48 | -------------------------------------------------------------------------------- /diffuseq/rounding copy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # bert results 3 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, GPT2TokenizerFast 4 | import sys, yaml, os 5 | import json 6 | 7 | import numpy as np 8 | 9 | def get_knn(model_emb, text_emb, dist='cos'): 10 | if dist == 'cos': 11 | adjacency = model_emb @ text_emb.transpose(1, 0).to(model_emb.device) 12 | elif dist == 'l2': 13 | adjacency = model_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 14 | model_emb.size(0), -1, -1) 15 | adjacency = -torch.norm(adjacency, dim=-1) 16 | topk_out = torch.topk(adjacency, k=6, dim=0) 17 | return topk_out.values, topk_out.indices 18 | 19 | def get_efficient_knn(model_emb, text_emb): 20 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab 21 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen 22 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1 23 | # print(emb_norm.shape, arr_norm.shape) 24 | dist = emb_norm + 
arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen) 25 | dist = torch.clamp(dist, 0.0, np.inf) 26 | # print(dist.shape) 27 | topk_out = torch.topk(-dist, k=1, dim=0) 28 | return topk_out.values, topk_out.indices 29 | 30 | def rounding_func(text_emb_lst, model, tokenizer, emb_scale_factor=1.0): 31 | decoded_out_lst = [] 32 | 33 | model_emb = model.weight # input_embs 34 | down_proj_emb2 = None 35 | 36 | dist = 'l2' 37 | 38 | for text_emb in text_emb_lst: 39 | import torch 40 | text_emb = torch.tensor(text_emb) 41 | # print(text_emb.shape) 42 | if len(text_emb.shape) > 2: 43 | text_emb = text_emb.view(-1, text_emb.size(-1)) 44 | else: 45 | text_emb = text_emb 46 | val, indices = get_knn((down_proj_emb2 if dist == 'cos' else model_emb), 47 | text_emb.to(model_emb.device), dist=dist) 48 | 49 | decoded_out_lst.append(tokenizer.decode_token(indices[0])) 50 | 51 | return decoded_out_lst 52 | 53 | def compute_logp(args, model, x, input_ids): 54 | word_emb = model.weight 55 | sigma = 0.1 56 | if args.model_arch == '1d-unet': 57 | x = x.permute(0, 2, 1) 58 | 59 | bsz, seqlen, dim = x.shape 60 | 61 | x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) # 1, bsz*sample*seqlen, dim 62 | word_emb_flat = word_emb.unsqueeze(1) # vocab, 1, dim 63 | diff = (x_flat - word_emb_flat) ** 2 # vocab, seqlen, dim 64 | 65 | logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) # vocab, seqlen 66 | logp_expanded = logp_expanded.permute((1, 0)) 67 | 68 | ce = torch.nn.CrossEntropyLoss(reduction='none') 69 | loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen) 70 | 71 | return loss 72 | 73 | def get_weights(model, args): 74 | if hasattr(model, 'transformer'): 75 | input_embs = model.transformer.wte # input_embs 76 | down_proj = model.down_proj 77 | model_emb = down_proj(input_embs.weight) 78 | print(model_emb.shape) 79 | model = torch.nn.Embedding(model_emb.size(0), model_emb.size(1)) 80 | print(args.emb_scale_factor) 81 | model.weight.data = model_emb * args.emb_scale_factor 82 | 83 | elif hasattr(model, 'weight'): 84 | pass 85 | else: 86 | raise NotImplementedError 87 | 88 | model.weight.requires_grad = False 89 | return model 90 | 91 | def denoised_fn_round(args, model, text_emb, t): 92 | # print(text_emb.shape) # bsz, seqlen, dim 93 | model_emb = model.weight # input_embs 94 | # print(t) 95 | old_shape = text_emb.shape 96 | old_device = text_emb.device 97 | 98 | if len(text_emb.shape) > 2: 99 | text_emb = text_emb.reshape(-1, text_emb.size(-1)) 100 | else: 101 | text_emb = text_emb 102 | # val, indices = get_knn(model_emb, text_emb.to(model_emb.device), dist=dist) 103 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 104 | rounded_tokens = indices[0] 105 | # print(rounded_tokens.shape) 106 | new_embeds = model(rounded_tokens).view(old_shape).to(old_device) 107 | 108 | return new_embeds -------------------------------------------------------------------------------- /diffuseq/rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # bert results 3 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, GPT2TokenizerFast 4 | import sys, yaml, os 5 | import json 6 | 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | # def get_knn(model_emb, text_emb, dist='cos'): 11 | # if dist == 'cos': 12 | # adjacency = model_emb @ text_emb.transpose(1, 0).to(model_emb.device) 13 | # elif dist == 'l2': 14 | # adjacency =
model_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 15 | # model_emb.size(0), -1, -1) 16 | # adjacency = -torch.norm(adjacency, dim=-1) 17 | # topk_out = torch.topk(adjacency, k=6, dim=0) 18 | # return topk_out.values, topk_out.indices 19 | 20 | # def enforce_repetition_penalty_(self, lprobs, batch_size, prev_output_tokens, repetition_penalty): 21 | # """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """ 22 | # for i in range(batch_size): 23 | # for previous_token in set(prev_output_tokens[i].tolist()): 24 | # # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability 25 | # if lprobs[i, previous_token] < 0: 26 | # lprobs[i, previous_token] *= repetition_penalty 27 | # else: 28 | # lprobs[i, previous_token] /= repetition_penalty 29 | 30 | def get_efficient_knn(model_emb, text_emb): 31 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab 32 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen 33 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1 34 | # print(emb_norm.shape, arr_norm.shape) 35 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen) 36 | dist = torch.clamp(dist, 0.0, np.inf) 37 | # print(dist.shape) 38 | topk_out = torch.topk(-dist, k=1, dim=0) 39 | return topk_out.values, topk_out.indices 40 | 41 | def get_efficient_knn_top_k(model_emb, text_emb, top_k, tau = 1.0): 42 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab 43 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen 44 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1 45 | # print(emb_norm.shape, arr_norm.shape) 46 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen) 47 | dist = torch.clamp(dist, 0.0, np.inf) 48 | # print(dist.shape) 49 | topk_out = torch.topk(-dist, k=top_k, dim=0) 50 | 51 | sftmx = torch.nn.Softmax(dim=-1) 52 | indices = topk_out.indices.transpose(0,1) 53 | values = sftmx(topk_out.values.transpose(0,1)/ tau) 54 | idx = torch.multinomial(values, 1) 55 | indices = torch.gather(indices, -1, idx).transpose(0,1) 56 | values = torch.gather(values, -1, idx).transpose(0,1) 57 | 58 | return values, indices 59 | 60 | def get_efficient_knn_top_p(model_emb, text_emb, top_p, tau = 1.0, scale = 1.0): 61 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab 62 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen 63 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1 64 | # print(emb_norm.shape, arr_norm.shape) 65 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen) 66 | dist = torch.clamp(dist, 0.0, np.inf) 67 | # print(dist.shape) 68 | # topk_out = torch.topk(-dist, k=top_k, dim=0) 69 | 70 | # dist = 71 | dist = -dist.transpose(0,1) 72 | sorted_logits, sorted_indices = torch.sort(dist, descending=True) 73 | if scale == "last": 74 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits / tau, dim=-1), dim=-1) 75 | else: 76 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 77 | sorted_indices_to_remove = cumulative_probs > top_p 78 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 79 | sorted_indices_to_remove[..., 0] = 0 80 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, 
sorted_indices_to_remove) 81 | dist[indices_to_remove] = -float("Inf") 82 | 83 | # tau *= scale 84 | probs = F.softmax(dist / tau, dim=-1) 85 | 86 | idx = torch.multinomial(probs, 1) 87 | 88 | return -dist.transpose(0,1), idx.transpose(0,1) 89 | 90 | def get_efficient_cos_top_p(model_emb, text_emb, top_p, tau = 1.0, scale = 1.0): 91 | dist = model_emb @ text_emb.transpose(1, 0).to(model_emb.device) 92 | 93 | # dist = 94 | dist = dist.transpose(0,1) 95 | sorted_logits, sorted_indices = torch.sort(dist, descending=True) 96 | if scale == "last": 97 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits / tau, dim=-1), dim=-1) 98 | else: 99 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 100 | sorted_indices_to_remove = cumulative_probs > top_p 101 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 102 | sorted_indices_to_remove[..., 0] = 0 103 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 104 | dist[indices_to_remove] = -float("Inf") 105 | 106 | # tau *= scale 107 | probs = F.softmax(dist / tau, dim=-1) 108 | 109 | idx = torch.multinomial(probs, 1) 110 | 111 | return -dist.transpose(0,1), idx.transpose(0,1) 112 | 113 | 114 | def get_efficient_knn_top_l(model_emb, text_emb, top_l): 115 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab 116 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen 117 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1 118 | # squared L2 distance via ||e||^2 + ||x||^2 - 2 e.x 119 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen) 120 | dist = torch.clamp(dist, 0.0, np.inf) 121 | # NOTE: topk_out was left undefined in the original; this helper is not 122 | # called elsewhere in this file, so we assume top_l is a candidate count. 123 | topk_out = torch.topk(-dist, k=int(top_l), dim=0) 124 | 125 | sftmx = torch.nn.Softmax(dim=-1) 126 | indices = topk_out.indices.transpose(0,1) 127 | values = sftmx(topk_out.values.transpose(0,1)) 128 | idx = torch.multinomial(values, 1) 129 | indices = torch.gather(indices, -1, idx).transpose(0,1) 130 | values = torch.gather(values, -1, idx).transpose(0,1) 131 | 132 | return values, indices 133 | 134 | def rounding_func(text_emb_lst, model, tokenizer, emb_scale_factor=1.0, scale=1.0): 135 | decoded_out_lst = [] 136 | 137 | model_emb = model.weight # input_embs 138 | down_proj_emb2 = None 139 | 140 | dist = 'l2' 141 | 142 | for text_emb in text_emb_lst: 143 | import torch 144 | text_emb = torch.tensor(text_emb) 145 | # print(text_emb.shape) 146 | if len(text_emb.shape) > 2: 147 | text_emb = text_emb.view(-1, text_emb.size(-1)) 148 | else: 149 | text_emb = text_emb 150 | # get_knn is commented out above, so use the efficient L2 nearest-neighbour lookup 151 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 152 | 153 | decoded_out_lst.append(tokenizer.decode_token(indices[0])) 154 | 155 | return decoded_out_lst 156 | 157 | def compute_logp(args, model, x, input_ids): 158 | word_emb = model.weight 159 | sigma = 0.1 160 | if args.model_arch == '1d-unet': 161 | x = x.permute(0, 2, 1) 162 | 163 | bsz, seqlen, dim = x.shape 164 | 165 | x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) # 1, bsz*sample*seqlen, dim 166 | word_emb_flat = word_emb.unsqueeze(1) # vocab, 1, dim 167 | diff = (x_flat - word_emb_flat) ** 2 # vocab, seqlen, dim 168 | 169 | logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) # vocab, seqlen 170 | logp_expanded = logp_expanded.permute((1, 0)) 171 | 172 | ce = torch.nn.CrossEntropyLoss(reduction='none') 173 | loss = ce(logp_expanded,
input_ids.view(-1)).view(bsz, seqlen) 174 | 175 | return loss 176 | 177 | def get_weights(model, args): 178 | if hasattr(model, 'transformer'): 179 | input_embs = model.transformer.wte # input_embs 180 | down_proj = model.down_proj 181 | model_emb = down_proj(input_embs.weight) 182 | print(model_emb.shape) 183 | model = torch.nn.Embedding(model_emb.size(0), model_emb.size(1)) 184 | print(args.emb_scale_factor) 185 | model.weight.data = model_emb * args.emb_scale_factor 186 | 187 | elif hasattr(model, 'weight'): 188 | pass 189 | else: 190 | raise NotImplementedError 191 | 192 | model.weight.requires_grad = False 193 | return model 194 | 195 | def denoised_fn_round(args, model, text_emb, t): 196 | # print(text_emb.shape) # bsz, seqlen, dim 197 | model_emb = model.weight # input_embs 198 | # print(t) 199 | if args.clamp_skip == 1: 200 | if t[0].item() % 2 == 0: 201 | print(t[0].item(), ": clamp skip") 202 | return text_emb 203 | 204 | elif args.clamp_skip == 0: 205 | if t[0].item() % 2 == 1: 206 | # print(t[0].item(), ": clamp skip") 207 | return text_emb 208 | 209 | # print(t[0].item(), ": clamp do") 210 | 211 | old_shape = text_emb.shape 212 | old_device = text_emb.device 213 | 214 | if len(text_emb.shape) > 2: 215 | text_emb = text_emb.reshape(-1, text_emb.size(-1)) 216 | else: 217 | text_emb = text_emb 218 | # val, indices = get_knn(model_emb, text_emb.to(model_emb.device), dist=dist) 219 | 220 | if args.top_k != 0: 221 | # try: 222 | # T = (sum(json.loads(args.timestep_respacing)) if args.timestep_respacing else args.step) - 1 223 | # except: 224 | # T = int(args.timestep_respacing[4:]) 225 | # scale = (t[0].item() / T + args.scale_end * (1 - t[0].item() / T)) 226 | scale = 1 227 | top_k = max(int(args.top_k * scale), 1) 228 | val, indices = get_efficient_knn_top_k(model_emb, text_emb.to(model_emb.device), top_k, args.tau) 229 | elif args.top_p != 0: 230 | if args.scale_end == "last": 231 | if t[0].item() == 0: 232 | print(t[0].item(), ":do topp") 233 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end) 234 | else: 235 | print(t[0].item(), ":do greedy") 236 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 237 | elif args.scale_end == "odd": 238 | if t[0].item() % 2 == 1: 239 | print(t[0].item(), ":do topp") 240 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end) 241 | else: 242 | print(t[0].item(), ":do greedy") 243 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 244 | elif args.scale_end == "": 245 | print(t[0].item(), ":do topp") 246 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end) 247 | elif args.scale_end.startswith('last_'): 248 | t_topp = int(args.scale_end.split('_')[-1]) 249 | if t[0].item() < t_topp: 250 | print(t[0].item(), ":do topp") 251 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end) 252 | else: 253 | print(t[0].item(), ":do greedy") 254 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 255 | elif args.scale_end.startswith('first_'): 256 | t_topp = int(args.scale_end.split('_')[-1]) 257 | if t[0].item() >= t_topp: 258 | print(t[0].item(), ":do topp") 259 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end) 260 | else: 261 | print(t[0].item(), ":do
greedy") 262 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 263 | elif args.scale_end == 'cos': 264 | print(t[0].item(), ":do cos topp") 265 | val, indices = get_efficient_cos_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end) 266 | else: 267 | raise NotImplementedError("Unkown args.scale_end:", args.scale_end) 268 | 269 | 270 | else: 271 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 272 | rounded_tokens = indices[0] 273 | # print(rounded_tokens.shape) 274 | new_embeds = model(rounded_tokens).view(old_shape).to(old_device) 275 | 276 | return new_embeds -------------------------------------------------------------------------------- /diffuseq/step_sample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "lossaware": 18 | return LossSecondMomentResampler(diffusion) 19 | elif name == "fixstep": 20 | return FixSampler(diffusion) 21 | else: 22 | raise NotImplementedError(f"unknown schedule sampler: {name}") 23 | 24 | 25 | class ScheduleSampler(ABC): 26 | """ 27 | A distribution over timesteps in the diffusion process, intended to reduce 28 | variance of the objective. 29 | 30 | By default, samplers perform unbiased importance sampling, in which the 31 | objective's mean is unchanged. 32 | However, subclasses may override sample() to change how the resampled 33 | terms are reweighted, allowing for actual changes in the objective. 34 | """ 35 | 36 | @abstractmethod 37 | def weights(self): 38 | """ 39 | Get a numpy array of weights, one per diffusion step. 40 | 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | 48 | :param batch_size: the number of timesteps. 49 | :param device: the torch device to save to. 50 | :return: a tuple (timesteps, weights): 51 | - timesteps: a tensor of timestep indices. 52 | - weights: a tensor of weights to scale the resulting losses. 53 | """ 54 | w = self.weights() 55 | p = w / np.sum(w) 56 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 57 | indices = th.from_numpy(indices_np).long().to(device) 58 | weights_np = 1 / (len(p) * p[indices_np]) 59 | weights = th.from_numpy(weights_np).float().to(device) 60 | return indices, weights 61 | 62 | 63 | class UniformSampler(ScheduleSampler): 64 | def __init__(self, diffusion): 65 | self.diffusion = diffusion 66 | self._weights = np.ones([diffusion.num_timesteps]) 67 | 68 | def weights(self): 69 | return self._weights 70 | 71 | class FixSampler(ScheduleSampler): 72 | def __init__(self, diffusion): 73 | self.diffusion = diffusion 74 | 75 | ############################################################### 76 | ### You can custome your own sampling weight of steps here. 
### 77 | ############################################################### 78 | self._weights = np.concatenate([np.ones([diffusion.num_timesteps//2]), np.zeros([diffusion.num_timesteps//2]) + 0.5]) 79 | 80 | def weights(self): 81 | return self._weights 82 | 83 | 84 | class LossAwareSampler(ScheduleSampler): 85 | def update_with_local_losses(self, local_ts, local_losses): 86 | """ 87 | Update the reweighting using losses from a model. 88 | 89 | Call this method from each rank with a batch of timesteps and the 90 | corresponding losses for each of those timesteps. 91 | This method will perform synchronization to make sure all of the ranks 92 | maintain the exact same reweighting. 93 | 94 | :param local_ts: an integer Tensor of timesteps. 95 | :param local_losses: a 1D Tensor of losses. 96 | """ 97 | batch_sizes = [ 98 | th.tensor([0], dtype=th.int32, device=local_ts.device) 99 | for _ in range(dist.get_world_size()) 100 | ] 101 | dist.all_gather( 102 | batch_sizes, 103 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 104 | ) 105 | 106 | # Pad all_gather batches to be the maximum batch size. 107 | batch_sizes = [x.item() for x in batch_sizes] 108 | max_bs = max(batch_sizes) 109 | 110 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 111 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 112 | dist.all_gather(timestep_batches, local_ts) 113 | dist.all_gather(loss_batches, local_losses) 114 | timesteps = [ 115 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 116 | ] 117 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 118 | self.update_with_all_losses(timesteps, losses) 119 | 120 | @abstractmethod 121 | def update_with_all_losses(self, ts, losses): 122 | """ 123 | Update the reweighting using losses from a model. 124 | 125 | Sub-classes should override this method to update the reweighting 126 | using losses from the model. 127 | 128 | This method directly updates the reweighting without synchronizing 129 | between workers. It is called by update_with_local_losses from all 130 | ranks with identical arguments. Thus, it should have deterministic 131 | behavior to maintain state across workers. 132 | 133 | :param ts: a list of int timesteps. 134 | :param losses: a list of float losses, one per timestep. 135 | """ 136 | 137 | 138 | class LossSecondMomentResampler(LossAwareSampler): 139 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 140 | self.diffusion = diffusion 141 | self.history_per_term = history_per_term 142 | self.uniform_prob = uniform_prob 143 | self._loss_history = np.zeros( 144 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 145 | ) 146 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int) 147 | 148 | def weights(self): 149 | if not self._warmed_up(): 150 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 151 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 152 | weights /= np.sum(weights) 153 | weights *= 1 - self.uniform_prob 154 | weights += self.uniform_prob / len(weights) 155 | return weights 156 | 157 | def update_with_all_losses(self, ts, losses): 158 | for t, loss in zip(ts, losses): 159 | if self._loss_counts[t] == self.history_per_term: 160 | # Shift out the oldest loss term.
161 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 162 | self._loss_history[t, -1] = loss 163 | else: 164 | self._loss_history[t, self._loss_counts[t]] = loss 165 | self._loss_counts[t] += 1 166 | 167 | def _warmed_up(self): 168 | return (self._loss_counts == self.history_per_term).all() 169 | -------------------------------------------------------------------------------- /diffuseq/text_datasets.py: -------------------------------------------------------------------------------- 1 | # import blobfile as bf 2 | import numpy as np 3 | from torch.utils.data import DataLoader, Dataset 4 | 5 | import torch 6 | import json 7 | import psutil 8 | import datasets 9 | from datasets import Dataset as Dataset2 10 | 11 | def load_data_text( 12 | batch_size, 13 | seq_len, 14 | deterministic=False, 15 | data_args=None, 16 | model_emb=None, 17 | split='train', 18 | loaded_vocab=None, 19 | loop=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (seqs, kwargs) pairs. 23 | 24 | Each seq is an (bsz, len, h) float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 26 | The kwargs dict can be used for some meta information. 27 | 28 | :param batch_size: the batch size of each returned pair. 29 | :param seq_len: the max sequence length (one-side). 30 | :param deterministic: if True, yield results in a deterministic order. 31 | :param data_args: including dataset directory, num of dataset, basic settings, etc. 32 | :param model_emb: loaded word embeddings. 33 | :param loaded_vocab: loaded word vocabs. 34 | :param loop: loop to get batch data or not. 35 | """ 36 | 37 | print('#'*30, '\nLoading text data...') 38 | 39 | training_data = get_corpus(data_args, seq_len, split=split, loaded_vocab=loaded_vocab) 40 | 41 | dataset = TextDataset( 42 | training_data, 43 | data_args, 44 | model_emb=model_emb 45 | ) 46 | 47 | data_loader = DataLoader( 48 | dataset, 49 | batch_size=batch_size, # 20, 50 | # drop_last=True, 51 | shuffle=not deterministic, 52 | num_workers=0, 53 | ) 54 | if loop: 55 | return infinite_loader(data_loader) 56 | else: 57 | # print(data_loader) 58 | return iter(data_loader) 59 | 60 | def infinite_loader(data_loader): 61 | while True: 62 | yield from data_loader 63 | 64 | def helper_tokenize(sentence_lst, vocab_dict, seq_len): 65 | # Process.memory_info is expressed in bytes, so convert to megabytes 66 | print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") 67 | raw_datasets = Dataset2.from_dict(sentence_lst) 68 | print(raw_datasets) 69 | print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") 70 | 71 | def tokenize_function(examples): 72 | input_id_x = vocab_dict.encode_token(examples['src']) 73 | input_id_y = vocab_dict.encode_token(examples['trg']) 74 | result_dict = {'input_id_x': input_id_x, 'input_id_y': input_id_y} 75 | 76 | return result_dict 77 | 78 | tokenized_datasets = raw_datasets.map( 79 | tokenize_function, 80 | batched=True, 81 | num_proc=4, 82 | remove_columns=['src', 'trg'], 83 | load_from_cache_file=True, 84 | desc="Running tokenizer on dataset", 85 | ) 86 | print('### tokenized_datasets', tokenized_datasets) 87 | print('### tokenized_datasets...example', tokenized_datasets['input_id_x'][0]) 88 | print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") 89 | 90 | def merge_and_mask(group_lst): 91 | lst = [] 92 | mask = [] 93 | for i in range(len(group_lst['input_id_x'])): 94 | end_token = group_lst['input_id_x'][i][-1] 95 | src 
= group_lst['input_id_x'][i][:-1] 96 | trg = group_lst['input_id_y'][i][:-1] 97 | while len(src) + len(trg) > seq_len - 3: 98 | if len(src)>len(trg): 99 | src.pop() 100 | elif len(src) maxlen else s 84 | 85 | def writeseq(self, seq): 86 | seq = list(seq) 87 | for (i, elem) in enumerate(seq): 88 | self.file.write(elem) 89 | if i < len(seq) - 1: # add space unless this is the last one 90 | self.file.write(" ") 91 | self.file.write("\n") 92 | self.file.flush() 93 | 94 | def close(self): 95 | if self.own_file: 96 | self.file.close() 97 | 98 | 99 | class JSONOutputFormat(KVWriter): 100 | def __init__(self, filename): 101 | self.file = open(filename, "wt") 102 | 103 | def writekvs(self, kvs): 104 | for k, v in sorted(kvs.items()): 105 | if hasattr(v, "dtype"): 106 | kvs[k] = float(v) 107 | self.file.write(json.dumps(kvs) + "\n") 108 | self.file.flush() 109 | 110 | def close(self): 111 | self.file.close() 112 | 113 | 114 | class CSVOutputFormat(KVWriter): 115 | def __init__(self, filename): 116 | self.file = open(filename, "w+t") 117 | self.keys = [] 118 | self.sep = "," 119 | 120 | def writekvs(self, kvs): 121 | # Add our current row to the history 122 | extra_keys = list(kvs.keys() - self.keys) 123 | extra_keys.sort() 124 | if extra_keys: 125 | self.keys.extend(extra_keys) 126 | self.file.seek(0) 127 | lines = self.file.readlines() 128 | self.file.seek(0) 129 | for (i, k) in enumerate(self.keys): 130 | if i > 0: 131 | self.file.write(",") 132 | self.file.write(k) 133 | self.file.write("\n") 134 | for line in lines[1:]: 135 | self.file.write(line[:-1]) 136 | self.file.write(self.sep * len(extra_keys)) 137 | self.file.write("\n") 138 | for (i, k) in enumerate(self.keys): 139 | if i > 0: 140 | self.file.write(",") 141 | v = kvs.get(k) 142 | if v is not None: 143 | self.file.write(str(v)) 144 | self.file.write("\n") 145 | self.file.flush() 146 | 147 | def close(self): 148 | self.file.close() 149 | 150 | 151 | class TensorBoardOutputFormat(KVWriter): 152 | """ 153 | Dumps key/value pairs into TensorBoard's numeric format. 154 | """ 155 | 156 | def __init__(self, dir): 157 | os.makedirs(dir, exist_ok=True) 158 | self.dir = dir 159 | self.step = 1 160 | prefix = "events" 161 | path = osp.join(osp.abspath(dir), prefix) 162 | import tensorflow as tf 163 | from tensorflow.python import pywrap_tensorflow 164 | from tensorflow.core.util import event_pb2 165 | from tensorflow.python.util import compat 166 | 167 | self.tf = tf 168 | self.event_pb2 = event_pb2 169 | self.pywrap_tensorflow = pywrap_tensorflow 170 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 171 | 172 | def writekvs(self, kvs): 173 | def summary_val(k, v): 174 | kwargs = {"tag": k, "simple_value": float(v)} 175 | return self.tf.Summary.Value(**kwargs) 176 | 177 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 178 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 179 | event.step = ( 180 | self.step 181 | ) # is there any reason why you'd want to specify the step? 
182 | self.writer.WriteEvent(event) 183 | self.writer.Flush() 184 | self.step += 1 185 | 186 | def close(self): 187 | if self.writer: 188 | self.writer.Close() 189 | self.writer = None 190 | 191 | 192 | def make_output_format(format, ev_dir, log_suffix=""): 193 | os.makedirs(ev_dir, exist_ok=True) 194 | if format == "stdout": 195 | return HumanOutputFormat(sys.stdout) 196 | elif format == "log": 197 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 198 | elif format == "json": 199 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 200 | elif format == "csv": 201 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 202 | elif format == "tensorboard": 203 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 204 | else: 205 | raise ValueError("Unknown format specified: %s" % (format,)) 206 | 207 | 208 | # ================================================================ 209 | # API 210 | # ================================================================ 211 | 212 | 213 | def logkv(key, val): 214 | """ 215 | Log a value of some diagnostic 216 | Call this once for each diagnostic quantity, each iteration 217 | If called many times, last value will be used. 218 | """ 219 | get_current().logkv(key, val) 220 | 221 | 222 | def logkv_mean(key, val): 223 | """ 224 | The same as logkv(), but if called many times, values averaged. 225 | """ 226 | get_current().logkv_mean(key, val) 227 | 228 | 229 | def logkvs(d): 230 | """ 231 | Log a dictionary of key-value pairs 232 | """ 233 | for (k, v) in d.items(): 234 | logkv(k, v) 235 | 236 | 237 | def dumpkvs(): 238 | """ 239 | Write all of the diagnostics from the current iteration 240 | """ 241 | return get_current().dumpkvs() 242 | 243 | 244 | def getkvs(): 245 | return get_current().name2val 246 | 247 | 248 | def log(*args, level=INFO): 249 | """ 250 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 251 | """ 252 | get_current().log(*args, level=level) 253 | 254 | 255 | def debug(*args): 256 | log(*args, level=DEBUG) 257 | 258 | 259 | def info(*args): 260 | log(*args, level=INFO) 261 | 262 | 263 | def warn(*args): 264 | log(*args, level=WARN) 265 | 266 | 267 | def error(*args): 268 | log(*args, level=ERROR) 269 | 270 | 271 | def set_level(level): 272 | """ 273 | Set logging threshold on current logger. 274 | """ 275 | get_current().set_level(level) 276 | 277 | 278 | def set_comm(comm): 279 | get_current().set_comm(comm) 280 | 281 | 282 | def get_dir(): 283 | """ 284 | Get directory that log files are being written to. 
285 | will be None if there is no output directory (i.e., if you didn't call start) 286 | """ 287 | return get_current().get_dir() 288 | 289 | 290 | record_tabular = logkv 291 | dump_tabular = dumpkvs 292 | 293 | 294 | @contextmanager 295 | def profile_kv(scopename): 296 | logkey = "wait_" + scopename 297 | tstart = time.time() 298 | try: 299 | yield 300 | finally: 301 | get_current().name2val[logkey] += time.time() - tstart 302 | 303 | 304 | def profile(n): 305 | """ 306 | Usage: 307 | @profile("my_func") 308 | def my_func(): code 309 | """ 310 | 311 | def decorator_with_name(func): 312 | def func_wrapper(*args, **kwargs): 313 | with profile_kv(n): 314 | return func(*args, **kwargs) 315 | 316 | return func_wrapper 317 | 318 | return decorator_with_name 319 | 320 | 321 | # ================================================================ 322 | # Backend 323 | # ================================================================ 324 | 325 | 326 | def get_current(): 327 | if Logger.CURRENT is None: 328 | _configure_default_logger() 329 | 330 | return Logger.CURRENT 331 | 332 | 333 | class Logger(object): 334 | DEFAULT = None # A logger with no output files. (See right below class definition) 335 | # So that you can still log to the terminal without setting up any output files 336 | CURRENT = None # Current logger being used by the free functions above 337 | 338 | def __init__(self, dir, output_formats, comm=None): 339 | self.name2val = defaultdict(float) # values this iteration 340 | self.name2cnt = defaultdict(int) 341 | self.level = INFO 342 | self.dir = dir 343 | self.output_formats = output_formats 344 | self.comm = comm 345 | 346 | # Logging API, forwarded 347 | # ---------------------------------------- 348 | def logkv(self, key, val): 349 | self.name2val[key] = val 350 | 351 | def logkv_mean(self, key, val): 352 | oldval, cnt = self.name2val[key], self.name2cnt[key] 353 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 354 | self.name2cnt[key] = cnt + 1 355 | 356 | def dumpkvs(self, prefix=None): 357 | if self.comm is None: 358 | d = self.name2val 359 | else: 360 | d = mpi_weighted_mean( 361 | self.comm, 362 | { 363 | name: (val, self.name2cnt.get(name, 1)) 364 | for (name, val) in self.name2val.items() 365 | }, 366 | ) 367 | if self.comm.rank != 0: 368 | d["dummy"] = 1 # so we don't get a warning about empty dict 369 | # LISA 370 | out = d.copy() # Return the dict for unit testing purposes 371 | if int(os.environ['LOCAL_RANK']) == 0: 372 | wandb.log({**d}) 373 | for fmt in self.output_formats: 374 | if isinstance(fmt, KVWriter): 375 | fmt.writekvs(d) 376 | self.name2val.clear() 377 | self.name2cnt.clear() 378 | return out 379 | 380 | def log(self, *args, level=INFO): 381 | if self.level <= level: 382 | self._do_log(args) 383 | 384 | # Configuration 385 | # ---------------------------------------- 386 | def set_level(self, level): 387 | self.level = level 388 | 389 | def set_comm(self, comm): 390 | self.comm = comm 391 | 392 | def get_dir(self): 393 | return self.dir 394 | 395 | def close(self): 396 | for fmt in self.output_formats: 397 | fmt.close() 398 | 399 | # Misc 400 | # ---------------------------------------- 401 | def _do_log(self, args): 402 | for fmt in self.output_formats: 403 | if isinstance(fmt, SeqWriter): 404 | fmt.writeseq(map(str, args)) 405 | 406 | 407 | def get_rank_without_mpi_import(): 408 | # check environment variables here instead of importing mpi4py 409 | # to avoid calling MPI_Init() when this module is imported 410 | for varname in ["PMI_RANK", 
"OMPI_COMM_WORLD_RANK"]: 411 | if varname in os.environ: 412 | return int(os.environ[varname]) 413 | return 0 414 | 415 | 416 | def mpi_weighted_mean(comm, local_name2valcount): 417 | """ 418 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 419 | Perform a weighted average over dicts that are each on a different node 420 | Input: local_name2valcount: dict mapping key -> (value, count) 421 | Returns: key -> mean 422 | """ 423 | all_name2valcount = comm.gather(local_name2valcount) 424 | if comm.rank == 0: 425 | name2sum = defaultdict(float) 426 | name2count = defaultdict(float) 427 | for n2vc in all_name2valcount: 428 | for (name, (val, count)) in n2vc.items(): 429 | try: 430 | val = float(val) 431 | except ValueError: 432 | if comm.rank == 0: 433 | warnings.warn( 434 | "WARNING: tried to compute mean on non-float {}={}".format( 435 | name, val 436 | ) 437 | ) 438 | else: 439 | name2sum[name] += val * count 440 | name2count[name] += count 441 | return {name: name2sum[name] / name2count[name] for name in name2sum} 442 | else: 443 | return {} 444 | 445 | 446 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 447 | """ 448 | If comm is provided, average all numerical stats across that comm 449 | """ 450 | if dir is None: 451 | dir = os.getenv("OPENAI_LOGDIR") 452 | if dir is None: 453 | dir = osp.join( 454 | tempfile.gettempdir(), 455 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 456 | ) 457 | assert isinstance(dir, str) 458 | dir = os.path.expanduser(dir) 459 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 460 | 461 | rank = get_rank_without_mpi_import() 462 | if rank > 0: 463 | log_suffix = log_suffix + "-rank%03i" % rank 464 | 465 | if format_strs is None: 466 | if rank == 0: 467 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 468 | else: 469 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 470 | format_strs = filter(None, format_strs) 471 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 472 | 473 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 474 | if output_formats: 475 | log("Logging to %s" % dir) 476 | 477 | 478 | def _configure_default_logger(): 479 | configure() 480 | Logger.DEFAULT = Logger.CURRENT 481 | 482 | 483 | def reset(): 484 | if Logger.CURRENT is not Logger.DEFAULT: 485 | Logger.CURRENT.close() 486 | Logger.CURRENT = Logger.DEFAULT 487 | log("Reset logger") 488 | 489 | 490 | @contextmanager 491 | def scoped_configure(dir=None, format_strs=None, comm=None): 492 | prevlogger = Logger.CURRENT 493 | configure(dir=dir, format_strs=format_strs, comm=comm) 494 | try: 495 | yield 496 | finally: 497 | Logger.CURRENT.close() 498 | Logger.CURRENT = prevlogger 499 | 500 | -------------------------------------------------------------------------------- /diffuseq/utils/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 
15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | # print(logvar2.shape) 34 | # temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2)) 35 | # print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, mse={((mean1 - mean2) ** 2).mean().item()}') 36 | 37 | return 0.5 * ( 38 | -1.0 39 | + logvar2 40 | - logvar1 41 | + th.exp(logvar1 - logvar2) 42 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 43 | ) 44 | 45 | 46 | def approx_standard_normal_cdf(x): 47 | """ 48 | A fast approximation of the cumulative distribution function of the 49 | standard normal. 50 | """ 51 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 52 | 53 | 54 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 55 | """ 56 | Compute the log-likelihood of a Gaussian distribution discretizing to a 57 | given image. 58 | 59 | :param x: the target images. It is assumed that this was uint8 values, 60 | rescaled to the range [-1, 1]. 61 | :param means: the Gaussian mean Tensor. 62 | :param log_scales: the Gaussian log stddev Tensor. 63 | :return: a tensor like x of log probabilities (in nats). 64 | """ 65 | assert x.shape == means.shape == log_scales.shape 66 | centered_x = x - means 67 | inv_stdv = th.exp(-log_scales) 68 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 69 | cdf_plus = approx_standard_normal_cdf(plus_in) 70 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 71 | cdf_min = approx_standard_normal_cdf(min_in) 72 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 73 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 74 | cdf_delta = cdf_plus - cdf_min 75 | log_probs = th.where( 76 | x < -0.999, 77 | log_cdf_plus, 78 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 79 | ) 80 | assert log_probs.shape == x.shape 81 | return log_probs 82 | 83 | def gaussian_density(x, *, means, log_scales): 84 | from torch.distributions import Normal 85 | normal_dist = Normal(means, log_scales.exp()) 86 | logp = normal_dist.log_prob(x) 87 | return logp 88 | 89 | 90 | def discretized_text_log_likelihood(x, *, means, log_scales): 91 | """ 92 | Compute the log-likelihood of a Gaussian distribution discretizing to a 93 | given image. 94 | 95 | :param x: the target images. It is assumed that this was uint8 values, 96 | rescaled to the range [-1, 1]. 97 | :param means: the Gaussian mean Tensor. 98 | :param log_scales: the Gaussian log stddev Tensor. 99 | :return: a tensor like x of log probabilities (in nats). 
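    Note: despite the name, the body below mirrors discretized_gaussian_log_likelihood
    above, including the image-style 1/255 bin width, so x is still expected on a
    [-1, 1] scale rather than being raw token ids.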
100 | """ 101 | print(x.shape, means.shape) 102 | # assert x.shape == means.shape == log_scales.shape 103 | print(x, means) 104 | centered_x = x - means 105 | inv_stdv = th.exp(-log_scales) 106 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 107 | cdf_plus = approx_standard_normal_cdf(plus_in) 108 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 109 | cdf_min = approx_standard_normal_cdf(min_in) 110 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 111 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 112 | cdf_delta = cdf_plus - cdf_min 113 | log_probs = th.where( 114 | x < -0.999, 115 | log_cdf_plus, 116 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 117 | ) 118 | assert log_probs.shape == x.shape 119 | return log_probs 120 | -------------------------------------------------------------------------------- /diffuseq/utils/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 11 | class SiLU(nn.Module): 12 | def forward(self, x): 13 | return x * th.sigmoid(x) 14 | 15 | 16 | class GroupNorm32(nn.GroupNorm): 17 | def forward(self, x): 18 | return super().forward(x.float()).type(x.dtype) 19 | 20 | def linear(*args, **kwargs): 21 | """ 22 | Create a linear module. 23 | """ 24 | return nn.Linear(*args, **kwargs) 25 | 26 | 27 | def avg_pool_nd(dims, *args, **kwargs): 28 | """ 29 | Create a 1D, 2D, or 3D average pooling module. 30 | """ 31 | if dims == 1: 32 | return nn.AvgPool1d(*args, **kwargs) 33 | elif dims == 2: 34 | return nn.AvgPool2d(*args, **kwargs) 35 | elif dims == 3: 36 | return nn.AvgPool3d(*args, **kwargs) 37 | raise ValueError(f"unsupported dimensions: {dims}") 38 | 39 | 40 | def update_ema(target_params, source_params, rate=0.99): 41 | """ 42 | Update target parameters to be closer to those of source parameters using 43 | an exponential moving average. 44 | 45 | :param target_params: the target parameter sequence. 46 | :param source_params: the source parameter sequence. 47 | :param rate: the EMA rate (closer to 1 means slower). 48 | """ 49 | for targ, src in zip(target_params, source_params): 50 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 51 | 52 | 53 | def zero_module(module): 54 | """ 55 | Zero out the parameters of a module and return it. 56 | """ 57 | for p in module.parameters(): 58 | p.detach().zero_() 59 | return module 60 | 61 | 62 | def scale_module(module, scale): 63 | """ 64 | Scale the parameters of a module and return it. 65 | """ 66 | for p in module.parameters(): 67 | p.detach().mul_(scale) 68 | return module 69 | 70 | 71 | def mean_flat(tensor): 72 | """ 73 | Take the mean over all non-batch dimensions. 74 | """ 75 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 76 | 77 | 78 | def normalization(channels): 79 | """ 80 | Make a standard normalization layer. 81 | 82 | :param channels: number of input channels. 83 | :return: an nn.Module for normalization. 84 | """ 85 | return GroupNorm32(32, channels) 86 | 87 | 88 | def timestep_embedding(timesteps, dim, max_period=10000): 89 | """ 90 | Create sinusoidal timestep embeddings. 91 | 92 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 93 | These may be fractional. 94 | :param dim: the dimension of the output. 95 | :param max_period: controls the minimum frequency of the embeddings. 
96 | :return: an [N x dim] Tensor of positional embeddings. 97 | """ 98 | half = dim // 2 99 | freqs = th.exp( 100 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 101 | ).to(device=timesteps.device) 102 | args = timesteps[:, None].float() * freqs[None] 103 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 104 | if dim % 2: 105 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 106 | return embedding 107 | -------------------------------------------------------------------------------- /evaluation/__pycache__/tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/evaluation/__pycache__/tokenizer.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/evaluation/__pycache__/tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import pdb 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | from nltk import ngrams 7 | from tokenizer import SimpleTokenizer 8 | from nltk.tokenize import word_tokenize, wordpunct_tokenize 9 | from transformers import AutoTokenizer,AutoModelForCausalLM 10 | import nltk 11 | import copy 12 | import torch 13 | import evaluate 14 | from evaluate import load 15 | import os 16 | import mauve 17 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 18 | from multiprocessing.pool import Pool 19 | from tqdm import tqdm 20 | import random 21 | from functools import partial 22 | from torchmetrics.text.rouge import ROUGEScore 23 | rougeScore = ROUGEScore() 24 | device = "cuda" if torch.cuda.is_available() else "cpu" 25 | rougeScore.to(device) 26 | import csv 27 | from bert_score import score 28 | 29 | 30 | 31 | 32 | tokenizer = SimpleTokenizer(method="nltk") 33 | 34 | 35 | 36 | def bleu(refs, cands): 37 | 38 | result = {} 39 | 40 | for i in range(1, 5): 41 | res = [] 42 | for ref,cand in zip(refs,cands): 43 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands, weights=tuple([1./i for j in range(i)]),smoothing_function=SmoothingFunction().method4)) 44 | res.append(sentence_bleu([ref], cand, smoothing_function=SmoothingFunction().method4,weights=tuple([1./i for j in range(i)]))) 45 | result["bleu-%d"%i] = np.mean(res) 46 | 47 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands)) 48 | return result 49 | 50 | def distinct_n_gram(hypn_lst,n): 51 | dist_list_fin = [] 52 | for hypn in hypn_lst: 53 | hypn = [hypn] 54 | dist_list = [] 55 | for hyp in hypn: 56 | hyp_ngrams = [] 57 | hyp_ngrams += nltk.ngrams(hyp.split(), n) 58 | total_ngrams = len(hyp_ngrams) 59 | unique_ngrams = len(list(set(hyp_ngrams))) 60 | if total_ngrams == 0: 61 | continue 62 | dist_list.append(unique_ngrams/total_ngrams) 63 | if total_ngrams == 0: 64 | continue 65 | dist_list_fin.append(np.mean(dist_list)) 66 | return np.mean(dist_list_fin) 67 | 68 | 69 | 70 | def repetition_distinct(hyps, times): 71 | dis_result, 
lex_rep = dict(), dict() 72 | for i in range(1, 5): 73 | num, all_ngram, all_ngram_num = 0, {}, 0 74 | for tokens in hyps: 75 | 76 | ngs = ["_".join(c) for c in ngrams(tokens, i)] 77 | all_ngram_num += len(ngs) 78 | for s in ngs: 79 | if s in all_ngram: 80 | all_ngram[s] += 1 81 | else: 82 | all_ngram[s] = 1 83 | for s in set(ngs): 84 | if ngs.count(s) > times: 85 | num += 1 86 | break 87 | lex_rep["repetition-%d"%i] = "%.4f"%(num / float(len(hyps))) 88 | dis_result["distinct-%d"%i] = "%.4f"%(len(all_ngram) / float(all_ngram_num)) 89 | 90 | return dis_result, lex_rep 91 | 92 | def length_(cands): 93 | lengths = [] 94 | for i in cands: 95 | lengths.append(len(i)) 96 | return sum(lengths) / len(lengths) 97 | 98 | 99 | import scipy 100 | from transformers import AutoTokenizer, AutoModel 101 | def sent_semantic_repetition(device, model_name_or_path, hyps, source): 102 | 103 | sbert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 104 | sbert_model = AutoModel.from_pretrained(model_name_or_path).to(device) 105 | 106 | def mean_pooling(model_output, attention_mask): 107 | token_embeddings = model_output[0] # First element of model_output contains all token embeddings 108 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 109 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) 110 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 111 | return sum_embeddings / sum_mask 112 | 113 | all_distance, ss_max, ss_min, ss_mean = [], [], [], [] 114 | count = 0 115 | for k, (ipt, cand) in enumerate(zip(source, hyps)): 116 | if k % 1000 == 0: 117 | print("processing %d lines"%k) 118 | sen_list = sent_tokenize("%s %s"%(ipt.strip(), cand.strip())) 119 | if len(sen_list) == 0: 120 | continue 121 | with torch.no_grad(): 122 | encoder_input = sbert_tokenizer(sen_list,padding=True, truncation=True, return_tensors='pt', max_length=128).to(device) 123 | model_output = sbert_model(**encoder_input) 124 | sentence_embeddings = mean_pooling(model_output, encoder_input['attention_mask']).cpu().numpy() 125 | 126 | max_dis, min_dis, all_dis = [-1, -1, -1.], [-1, -1, 1.], [] 127 | if len(sentence_embeddings) == 1: 128 | continue 129 | 130 | for i, sen1 in enumerate(sentence_embeddings): 131 | for j, sen2 in enumerate(sentence_embeddings): 132 | if i < j: 133 | distances = 1 - scipy.spatial.distance.cdist([sen1], [sen2], "cosine")[0][0] 134 | all_dis.append(distances) 135 | 136 | if distances > max_dis[2]: 137 | max_dis = [i, j, distances] 138 | if distances < min_dis[2]: 139 | min_dis = [i, j, distances] 140 | sort_dis = sorted(all_dis) 141 | all_distance.append([]) 142 | for i in range(1, 11): 143 | all_distance[-1].append(np.mean(sort_dis[-i:])) 144 | 145 | ss_max.append(max_dis[2]) 146 | ss_min.append(min_dis[2]) 147 | ss_mean.append(np.mean(all_dis)) 148 | count += 1 149 | all_distance = np.mean(all_distance, 0) 150 | return { 151 | "sent_rept_max": "%.4f"%(np.mean(ss_max)), 152 | "sent_rept_min": "%.4f"%(np.mean(ss_min)), 153 | "sent_rept_mean": "%.4f"%(np.mean(ss_mean)), 154 | "all_distance": " ".join(["%.4f"%tmpf for tmpf in all_distance]), 155 | } 156 | 157 | def rouge_score(preds, golds): 158 | 159 | rouge_results = {} 160 | rouge1 =[] 161 | rouge2 = [] 162 | rougeL = [] 163 | for srcs, tgts in zip(preds, golds): 164 | # predictions = [" ".join(srcs)] 165 | # references = [[" ".join(tgts)]] 166 | # rouge.add_batch(predictions=predictions, references=references) 167 | references = " ".join(tgts) 168 | predictions = " ".join(srcs) 
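        # torchmetrics' ROUGEScore returns a dict keyed like "rouge1_fmeasure",
        # "rouge1_precision", "rouge1_recall"; only the f-measures are aggregated here.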
169 | res = rougeScore(predictions, references) 170 | rouge1.append(res["rouge1_fmeasure"]) 171 | rouge2.append(res["rouge2_fmeasure"]) 172 | rougeL.append(res["rougeL_fmeasure"]) 173 | rouge_results["rouge1"] = np.mean(rouge1) 174 | rouge_results["rouge2"] = np.mean(rouge2) 175 | rouge_results["rougeL"] = np.mean(rougeL) 176 | # rouge_results = rouge.compute() 177 | 178 | 179 | # rouge_results = rouge.compute(predictions=predictions, references=references, tokenizer=wordpunct_tokenize) 180 | return rouge_results 181 | 182 | 183 | def show_result(res_dict): 184 | for k, v in res_dict.items(): 185 | print(f"{k:} : {v:}") 186 | 187 | def ori_pro(s, name=""): 188 | s = s.strip() 189 | # for i in range(10): 190 | # s = s.replace("[%d]"%i, "") 191 | s = s.replace("", " ") 192 | s = " ".join(s.strip().split()) 193 | # s = roberta_tokenizer.decode(roberta_tokenizer.convert_tokens_to_ids(roberta_tokenizer.tokenize(s))) 194 | return s 195 | 196 | def pro(token_list, tokenizer): 197 | token_list = "".join(token_list.split(" ")) 198 | token_list = tokenizer(token_list)['input_ids'] 199 | for i, t in enumerate(token_list): 200 | if t not in [0, 2]: 201 | break 202 | token_list = token_list[i:] 203 | string = tokenizer.decode(token_list, skip_special_tokens=False) 204 | string = string.replace("", " ") 205 | string = string[:string.find("")].strip() 206 | return string 207 | 208 | def bleu_i(weights, all_sentences, smoothing_function, i): 209 | # noinspection PyTypeChecker 210 | return sentence_bleu( 211 | references=all_sentences[:i] + all_sentences[i + 1:], 212 | hypothesis=all_sentences[i], 213 | weights=weights, 214 | smoothing_function=smoothing_function) 215 | 216 | def self_bleu(generations_df, n_sample=1000): 217 | 218 | # import spacy 219 | random.seed(0) 220 | # nlp = spacy.load('en', disable=['parser', 'tagger', 'ner']) 221 | # nlp.add_pipe(nlp.create_pipe('sentencizer')) 222 | 223 | smoothing_function = SmoothingFunction().method1 224 | # all_sentences = [] 225 | # for i, row in generations_df.iterrows(): 226 | # # gens = row['tokens'] 227 | # gens = [[str(token) for token in tokens] for tokens in row['tokens']]# for gen in row['generations']] {'prompt':"", tokens: [[1,2,3], [3,4,5], [5,6,7], ....]} 228 | # all_sentences += gens 229 | 230 | all_sentences = generations_df 231 | 232 | pool = Pool(processes=os.cpu_count()) 233 | bleu_scores = [] 234 | for n_gram in range(1, 6): 235 | 236 | if n_gram == 1: 237 | weights = (1.0, 0, 0, 0) 238 | elif n_gram == 2: 239 | weights = (0.5, 0.5, 0, 0) 240 | elif n_gram == 3: 241 | weights = (1.0 / 3, 1.0 / 3, 1.0 / 3, 0) 242 | elif n_gram == 4: 243 | weights = (0.25, 0.25, 0.25, 0.25) 244 | elif n_gram == 5: 245 | weights = (0.2, 0.2, 0.2, 0.2, 0.2) 246 | else: 247 | raise ValueError 248 | bleu_scores.append( 249 | list(tqdm( 250 | pool.imap_unordered( 251 | partial(bleu_i, weights, all_sentences, smoothing_function), 252 | random.sample(range(len(all_sentences)), min(n_sample, len(all_sentences)))), 253 | total=min(n_sample, len(all_sentences)), 254 | smoothing=0.0, 255 | desc=f"bleu-{n_gram}"))) 256 | # print(f"\n\nbleu-{n_gram} = {sum(bleu_scores[n_gram - 1]) / n_sample}") 257 | 258 | pool.close() 259 | pool.join() 260 | 261 | bleus = [] 262 | for n_gram in range(5): 263 | bleus.append(sum(bleu_scores[n_gram]) / n_sample) 264 | # print(f"bleu-{n_gram + 1} = {sum(bleu_scores[n_gram]) / n_sample}") 265 | 266 | return bleus 267 | 268 | 269 | model = AutoModelForCausalLM.from_pretrained( 270 | 'gpt2-large' # path to the AR model trained for LMing this 
task.
271 | ).cuda()
272 | tokenizer_ppl = AutoTokenizer.from_pretrained('gpt2-large')
273 | 
274 | def mean_pooling(model_output, attention_mask):
275 |     token_embeddings = model_output[0] # First element of model_output contains all token embeddings
276 |     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
277 |     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
278 | 
279 | tokenizer_prompt = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
280 | model_prompt = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
281 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
282 | model_prompt.to(device)
283 | 
284 | def similarity(golds,preds,sources):
285 |     sen_score_lst = []
286 | 
287 |     for gold,pred,source in zip(golds,preds,sources):
288 | 
289 |         embeddings1 = tokenizer_prompt(pred, padding=True, truncation=True, return_tensors='pt')
290 |         embeddings2 = tokenizer_prompt(source, padding=True, truncation=True, return_tensors='pt')
291 |         embeddings1 = embeddings1.to(device)
292 |         embeddings2 = embeddings2.to(device)
293 | 
294 |         with torch.no_grad():
295 |             e1 = model_prompt(**embeddings1)
296 |             e2 = model_prompt(**embeddings2)
297 |             e1 = mean_pooling(e1, embeddings1['attention_mask'])
298 |             e2 = mean_pooling(e2, embeddings2['attention_mask'])
299 |             sen_score_lst.append(torch.dist(e1,e2,p=2))  # L2 distance: lower means closer to the source
300 | 
301 |     return sen_score_lst
302 | 
303 | def eval_ppl(text_samples):
304 |     '''
305 |     Evaluate corpus perplexity with a GPT-2 causal LM (ideally one finetuned on this task).
306 |     :param text_samples: list of strings to score.
307 |     :return: perplexity, i.e. e ** (mean NLL over samples).
308 |     '''
309 | 
310 | 
311 |     # print('finished loading models.')
312 | 
313 | 
314 | 
315 |     full_score = []
316 |     agg_loss = []
317 |     count = 0
318 | 
319 |     for x in text_samples:
320 | 
321 | 
322 |         # print(x)
323 |         # should also add BOS EOS token?
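        # The causal-LM loss returned by model(input_ids, labels=labels) is the
        # mean token-level cross-entropy, so the corpus perplexity reported at
        # the end is e ** mean(per-sample NLL).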
324 | 325 | tokenized_x = tokenizer_ppl(x, truncation=True,return_tensors='pt') #[reverse_tokenizer[s] for s in x] 326 | input_ids = tokenized_x['input_ids'].cuda() 327 | labels = input_ids.clone() 328 | 329 | # print(tokenized_x) 330 | # tokenized_x = torch.LongTensor(tokenized_x).cuda() 331 | # labels = tokenized_x.clone() 332 | # labels[labels == reverse_tokenizer['PAD']] = -100 333 | model_output = model(input_ids, labels=labels) 334 | if not math.isnan(model_output.loss.item()): 335 | agg_loss.append(model_output.loss.item()) 336 | else: 337 | count += 1 338 | 339 | print("nan count:{}:{}".format(count,len(text_samples))) 340 | 341 | example_mean_score = torch.tensor(agg_loss).mean() 342 | 343 | full_score.append(example_mean_score) 344 | 345 | full_score_ = np.array(full_score).mean() 346 | 347 | # print(f'full NLL score is {full_score_} for {len(full_score)}') 348 | # print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}') 349 | 350 | return np.e ** full_score_ 351 | 352 | 353 | 354 | tokenizer_gpt = AutoTokenizer.from_pretrained("gpt2") 355 | 356 | if __name__ == "__main__": 357 | parser = ArgumentParser() 358 | parser.add_argument('--source-file', '-s', dest="source_file", help='source file', default="./src.txt") 359 | parser.add_argument('--golden-file', '-t', dest="golden_file", help='Input data file, one golden per line.', default="./gold.txt") 360 | parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.', default="./pred.txt") 361 | parser.add_argument('--times', '-k', help='calculate the lexical repetitation of different datasets', default="4") 362 | parser.add_argument('--model_path_or_name', '-p', help='where the config and tokenizer store') 363 | parser.add_argument('--folder') 364 | parser.add_argument('--save_dir', default="./") 365 | args = parser.parse_args() 366 | 367 | cnt = 0 368 | paths = sorted(glob.glob(glob.escape(f"{args.folder}")+"/*json")) 369 | print(paths) 370 | 371 | 372 | 373 | for path in tqdm(paths): 374 | print(path) 375 | bleu_1 = [] 376 | bleu_2 = [] 377 | bleu_3 = [] 378 | bleu_4 = [] 379 | self_bleu_1 = [] 380 | self_bleu_2 = [] 381 | self_bleu_3 = [] 382 | self_bleu_4 = [] 383 | self_bleu_5 = [] 384 | times2_repetition_1 = [] 385 | times2_repetition_2 = [] 386 | times2_repetition_3 = [] 387 | times2_repetition_4 = [] 388 | times4_repetition_1 = [] 389 | times4_repetition_2 = [] 390 | times4_repetition_3 = [] 391 | times4_repetition_4 = [] 392 | rouge1 = [] 393 | rouge2 = [] 394 | rougeL = [] 395 | bert_prec = [] 396 | bert_recall = [] 397 | bert_f1 = [] 398 | dist1 = [] 399 | dist2 = [] 400 | dist3 = [] 401 | dist4 = [] 402 | text_ppls = [] 403 | gold_ppls = [] 404 | deltas = [] 405 | m_score = [] 406 | sen_score = [] 407 | 408 | import json 409 | with open(path, "r") as f: 410 | lst = f.readlines() 411 | lst = [json.loads(i) for i in lst] 412 | 413 | 414 | golds = [] 415 | preds = [] 416 | sources = [] 417 | for d in lst: 418 | d["reference"] = d["reference"].replace("[CLS]","").replace("[SEP]","").strip() 419 | d["recover"] = d["recover"].replace("[CLS]","").replace("[SEP]","").strip() 420 | d["source"] = d["source"].replace("[CLS]","").replace("[SEP]","").strip() 421 | if d["reference"] == "" or d["recover"] == "": 422 | continue 423 | golds.append(d["reference"].replace("[CLS]","").replace("[SEP]","").strip()) 424 | preds.append(d["recover"].replace("[CLS]","").replace("[SEP]","").strip()) 425 | sources.append(d["source"].replace("[CLS]","").replace("[SEP]","").strip()) 426 | 427 | source_golds = [] 428 | 
source_preds = [] 429 | for i,j,z in zip(golds,preds,sources): 430 | source_golds.append(z+" "+i) 431 | source_preds.append(z+" "+j) 432 | 433 | sen_score = similarity(golds,preds,sources) 434 | 435 | 436 | preds_str = preds 437 | golds_str = golds 438 | 439 | preds_bleu = [] 440 | golds_bleu = [] 441 | for i,j in zip(preds,golds): 442 | preds_bleu.append(i.split()) 443 | golds_bleu.append(j.split()) 444 | 445 | preds, golds = [tokenizer.tokenize(i) for i in preds], [tokenizer.tokenize(i) for i in golds] 446 | 447 | bleu_result = bleu(refs = golds_bleu, cands = preds_bleu) 448 | 449 | dis_result, lex_rep2 = repetition_distinct(preds, 2) 450 | dis_result, lex_rep4 = repetition_distinct(preds, 4) 451 | 452 | 453 | 454 | 455 | 456 | len_ = length_(preds) 457 | len_golds = length_(golds) 458 | 459 | # bertscore_result = bert_score(preds, golds) 460 | torch.cuda.empty_cache() 461 | P, R, F1 = score(preds_str, golds_str, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True) 462 | 463 | P = torch.mean(P) 464 | R = torch.mean(R) 465 | F1 = torch.mean(F1) 466 | 467 | rouge_result = rouge_score(preds, golds) 468 | print(path) 469 | text_ppl = eval_ppl(preds_str) 470 | gold_ppl = eval_ppl(golds_str) 471 | delta = text_ppl / gold_ppl 472 | delta = math.log(delta) 473 | 474 | 475 | 476 | recovers = [] 477 | for i in preds_str: 478 | recover = tokenizer_gpt.encode(i) 479 | recover = list(map(str,recover)) 480 | recovers.append(recover) 481 | self_bleus = self_bleu(recovers, n_sample=1000) 482 | 483 | # if "roc" in path: 484 | # print("roc",path) 485 | # m = mauve.compute_mauve(p_text=source_golds, q_text=source_preds, device_id=0, max_text_length=256, verbose=False) 486 | # else: 487 | # m = mauve.compute_mauve(p_text=golds_str, q_text=preds_str, device_id=0, max_text_length=256, verbose=False) 488 | 489 | 490 | # pdb.set_trace() 491 | # print(m) 492 | 493 | 494 | 495 | # m_score.append(m.mauve) 496 | 497 | 498 | # bert_prec.append(bertscore_result["precision"]) 499 | # bert_recall.append(bertscore_result["recall"]) 500 | # bert_f1.append(bertscore_result["f1"]) 501 | 502 | 503 | bert_prec.append(P) 504 | bert_recall.append(R) 505 | bert_f1.append(F1) 506 | bleu_1.append(float(bleu_result["bleu-1"])) 507 | bleu_2.append(float(bleu_result["bleu-2"])) 508 | bleu_3.append(float(bleu_result["bleu-3"])) 509 | bleu_4.append(float(bleu_result["bleu-4"])) 510 | times2_repetition_1.append(float(lex_rep2["repetition-1"])) 511 | times2_repetition_2.append(float(lex_rep2["repetition-2"])) 512 | times2_repetition_3.append(float(lex_rep2["repetition-3"])) 513 | times2_repetition_4.append(float(lex_rep2["repetition-4"])) 514 | times4_repetition_1.append(float(lex_rep4["repetition-1"])) 515 | times4_repetition_2.append(float(lex_rep4["repetition-2"])) 516 | times4_repetition_3.append(float(lex_rep4["repetition-3"])) 517 | times4_repetition_4.append(float(lex_rep4["repetition-4"])) 518 | self_bleu_1.append(self_bleus[0]) 519 | self_bleu_2.append(self_bleus[1]) 520 | self_bleu_3.append(self_bleus[2]) 521 | self_bleu_4.append(self_bleus[3]) 522 | self_bleu_5.append(self_bleus[4]) 523 | rouge1.append(rouge_result["rouge1"]) 524 | rouge2.append(rouge_result["rouge2"]) 525 | rougeL.append(rouge_result["rougeL"]) 526 | dist1.append(float(dis_result["distinct-1"])) 527 | dist2.append(float(dis_result["distinct-2"])) 528 | dist3.append(float(dis_result["distinct-3"])) 529 | dist4.append(float(dis_result["distinct-4"])) 530 | text_ppls.append(text_ppl) 531 | gold_ppls.append(gold_ppl) 532 | 
deltas.append(abs(delta)) 533 | 534 | 535 | 536 | evaluate = [bleu_1,bleu_2,bleu_3,bleu_4,times2_repetition_1,times2_repetition_2,times2_repetition_3,times2_repetition_4,times4_repetition_1,times4_repetition_2,times4_repetition_3,times4_repetition_4,self_bleu_1,self_bleu_2,self_bleu_3,self_bleu_4,self_bleu_5,rouge1,rouge2,rougeL,dist1,dist2,dist3,dist4,text_ppls,gold_ppls,deltas,m_score,bert_prec,bert_recall,bert_f1,sen_score] 537 | 538 | evaluate_name = ["bleu_1","bleu_2","bleu_3","bleu_4","times2_repetition_1","times2_repetition_2","times2_repetition_3","times2_repetition_4","times4_repetition_1","times4_repetition_2","times4_repetition_3","times4_repetition_4","self_bleu_1","self_bleu_2","self_bleu_3","self_bleu_4","self_bleu_5","rouge1","rouge2","rougeL","dist1","dist2","dist3","dist4","text_ppls","gold_ppls","deltas","mauve_score","bert_prec","bert_recall","bert_f1","sim"] 539 | 540 | print("folder_path:",path) 541 | for name,eva in zip(evaluate_name,evaluate): 542 | if len(eva) != 0: 543 | print("{}:".format(name),sum(eva)/len(eva)) 544 | 545 | -------------------------------------------------------------------------------- /evaluation/quick_eval.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import pdb 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | from nltk import ngrams 7 | from tokenizer import SimpleTokenizer 8 | from nltk.tokenize import word_tokenize, wordpunct_tokenize 9 | from transformers import AutoTokenizer,AutoModelForCausalLM 10 | import nltk 11 | import copy 12 | import torch 13 | import evaluate 14 | from evaluate import load 15 | import os 16 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 17 | from multiprocessing.pool import Pool 18 | from tqdm import tqdm 19 | import random 20 | from functools import partial 21 | from torchmetrics.text.rouge import ROUGEScore 22 | rougeScore = ROUGEScore() 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | rougeScore.to(device) 25 | import csv 26 | # from bert_score import score 27 | 28 | 29 | tokenizer = SimpleTokenizer(method="nltk") 30 | 31 | 32 | 33 | def bleu(refs, cands): 34 | 35 | result = {} 36 | 37 | for i in range(1, 5): 38 | res = [] 39 | for ref,cand in zip(refs,cands): 40 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands, weights=tuple([1./i for j in range(i)]),smoothing_function=SmoothingFunction().method4)) 41 | try: 42 | res.append(sentence_bleu([ref], cand, smoothing_function=SmoothingFunction().method4,weights=tuple([1./i for j in range(i)]))) 43 | except: 44 | pass 45 | result["bleu-%d"%i] = np.mean(res) 46 | 47 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands)) 48 | return result 49 | 50 | def distinct_n_gram(hypn_lst,n): 51 | dist_list_fin = [] 52 | for hypn in hypn_lst: 53 | hypn = [hypn] 54 | dist_list = [] 55 | for hyp in hypn: 56 | hyp_ngrams = [] 57 | hyp_ngrams += nltk.ngrams(hyp.split(), n) 58 | total_ngrams = len(hyp_ngrams) 59 | unique_ngrams = len(list(set(hyp_ngrams))) 60 | if total_ngrams == 0: 61 | continue 62 | dist_list.append(unique_ngrams/total_ngrams) 63 | if total_ngrams == 0: 64 | continue 65 | dist_list_fin.append(np.mean(dist_list)) 66 | return np.mean(dist_list_fin) 67 | 68 | 69 | 70 | def repetition_distinct(hyps, times): 71 | dis_result, lex_rep = dict(), dict() 72 | for i in range(1, 5): 73 | num, all_ngram, all_ngram_num = 0, {}, 0 74 | for tokens in 
hyps: 75 | 76 | ngs = ["_".join(c) for c in ngrams(tokens, i)] 77 | all_ngram_num += len(ngs) 78 | for s in ngs: 79 | if s in all_ngram: 80 | all_ngram[s] += 1 81 | else: 82 | all_ngram[s] = 1 83 | for s in set(ngs): 84 | if ngs.count(s) > times: 85 | num += 1 86 | break 87 | lex_rep["repetition-%d"%i] = "%.4f"%(num / float(len(hyps))) 88 | dis_result["distinct-%d"%i] = "%.4f"%(len(all_ngram) / float(all_ngram_num)) 89 | 90 | return dis_result, lex_rep 91 | 92 | def length_(cands): 93 | lengths = [] 94 | for i in cands: 95 | lengths.append(len(i)) 96 | return sum(lengths) / len(lengths) 97 | 98 | 99 | 100 | def rouge_score(preds, golds): 101 | 102 | rouge_results = {} 103 | rouge1 =[] 104 | rouge2 = [] 105 | rougeL = [] 106 | for srcs, tgts in zip(preds, golds): 107 | # predictions = [" ".join(srcs)] 108 | # references = [[" ".join(tgts)]] 109 | # rouge.add_batch(predictions=predictions, references=references) 110 | references = " ".join(tgts) 111 | predictions = " ".join(srcs) 112 | res = rougeScore(predictions, references) 113 | rouge1.append(res["rouge1_fmeasure"]) 114 | rouge2.append(res["rouge2_fmeasure"]) 115 | rougeL.append(res["rougeL_fmeasure"]) 116 | rouge_results["rouge1"] = np.mean(rouge1) 117 | rouge_results["rouge2"] = np.mean(rouge2) 118 | rouge_results["rougeL"] = np.mean(rougeL) 119 | # rouge_results = rouge.compute() 120 | 121 | 122 | # rouge_results = rouge.compute(predictions=predictions, references=references, tokenizer=wordpunct_tokenize) 123 | return rouge_results 124 | 125 | 126 | def show_result(res_dict): 127 | for k, v in res_dict.items(): 128 | print(f"{k:} : {v:}") 129 | 130 | def ori_pro(s, name=""): 131 | s = s.strip() 132 | # for i in range(10): 133 | # s = s.replace("[%d]"%i, "") 134 | s = s.replace("", " ") 135 | s = " ".join(s.strip().split()) 136 | # s = roberta_tokenizer.decode(roberta_tokenizer.convert_tokens_to_ids(roberta_tokenizer.tokenize(s))) 137 | return s 138 | 139 | def pro(token_list, tokenizer): 140 | token_list = "".join(token_list.split(" ")) 141 | token_list = tokenizer(token_list)['input_ids'] 142 | for i, t in enumerate(token_list): 143 | if t not in [0, 2]: 144 | break 145 | token_list = token_list[i:] 146 | string = tokenizer.decode(token_list, skip_special_tokens=False) 147 | string = string.replace("", " ") 148 | string = string[:string.find("")].strip() 149 | return string 150 | 151 | def bleu_i(weights, all_sentences, smoothing_function, i): 152 | # noinspection PyTypeChecker 153 | return sentence_bleu( 154 | references=all_sentences[:i] + all_sentences[i + 1:], 155 | hypothesis=all_sentences[i], 156 | weights=weights, 157 | smoothing_function=smoothing_function) 158 | 159 | def self_bleu(generations_df, n_sample=1000): 160 | 161 | # import spacy 162 | random.seed(0) 163 | # nlp = spacy.load('en', disable=['parser', 'tagger', 'ner']) 164 | # nlp.add_pipe(nlp.create_pipe('sentencizer')) 165 | 166 | smoothing_function = SmoothingFunction().method1 167 | # all_sentences = [] 168 | # for i, row in generations_df.iterrows(): 169 | # # gens = row['tokens'] 170 | # gens = [[str(token) for token in tokens] for tokens in row['tokens']]# for gen in row['generations']] {'prompt':"", tokens: [[1,2,3], [3,4,5], [5,6,7], ....]} 171 | # all_sentences += gens 172 | 173 | all_sentences = generations_df 174 | 175 | pool = Pool(processes=os.cpu_count()) 176 | bleu_scores = [] 177 | for n_gram in range(1, 6): 178 | 179 | if n_gram == 1: 180 | weights = (1.0, 0, 0, 0) 181 | elif n_gram == 2: 182 | weights = (0.5, 0.5, 0, 0) 183 | elif n_gram == 3: 184 
| weights = (1.0 / 3, 1.0 / 3, 1.0 / 3, 0) 185 | elif n_gram == 4: 186 | weights = (0.25, 0.25, 0.25, 0.25) 187 | elif n_gram == 5: 188 | weights = (0.2, 0.2, 0.2, 0.2, 0.2) 189 | else: 190 | raise ValueError 191 | bleu_scores.append( 192 | list(tqdm( 193 | pool.imap_unordered( 194 | partial(bleu_i, weights, all_sentences, smoothing_function), 195 | random.sample(range(len(all_sentences)), min(n_sample, len(all_sentences)))), 196 | total=min(n_sample, len(all_sentences)), 197 | smoothing=0.0, 198 | desc=f"bleu-{n_gram}"))) 199 | # print(f"\n\nbleu-{n_gram} = {sum(bleu_scores[n_gram - 1]) / n_sample}") 200 | 201 | pool.close() 202 | pool.join() 203 | 204 | bleus = [] 205 | for n_gram in range(5): 206 | bleus.append(sum(bleu_scores[n_gram]) / n_sample) 207 | # print(f"bleu-{n_gram + 1} = {sum(bleu_scores[n_gram]) / n_sample}") 208 | 209 | return bleus 210 | 211 | 212 | tokenizer_gpt = AutoTokenizer.from_pretrained("./models/gpt2") 213 | 214 | if __name__ == "__main__": 215 | parser = ArgumentParser() 216 | parser.add_argument('--source-file', '-s', dest="source_file", help='source file', default="./src.txt") 217 | parser.add_argument('--golden-file', '-t', dest="golden_file", help='Input data file, one golden per line.', default="./gold.txt") 218 | parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.', default="./pred.txt") 219 | parser.add_argument('--times', '-k', help='calculate the lexical repetitation of different datasets', default="4") 220 | parser.add_argument('--model_path_or_name', '-p', help='where the config and tokenizer store') 221 | parser.add_argument('--folder') 222 | parser.add_argument('--save_dir', default="./") 223 | args = parser.parse_args() 224 | 225 | cnt = 0 226 | paths = sorted(glob.glob(glob.escape(f"{args.folder}")+"/*json")) 227 | print(paths) 228 | 229 | 230 | 231 | for path in tqdm(paths): 232 | print(path) 233 | bleu_1 = [] 234 | bleu_2 = [] 235 | bleu_3 = [] 236 | bleu_4 = [] 237 | self_bleu_1 = [] 238 | self_bleu_2 = [] 239 | self_bleu_3 = [] 240 | self_bleu_4 = [] 241 | self_bleu_5 = [] 242 | times2_repetition_1 = [] 243 | times2_repetition_2 = [] 244 | times2_repetition_3 = [] 245 | times2_repetition_4 = [] 246 | times4_repetition_1 = [] 247 | times4_repetition_2 = [] 248 | times4_repetition_3 = [] 249 | times4_repetition_4 = [] 250 | rouge1 = [] 251 | rouge2 = [] 252 | rougeL = [] 253 | bert_prec = [] 254 | bert_recall = [] 255 | bert_f1 = [] 256 | dist1 = [] 257 | dist2 = [] 258 | dist3 = [] 259 | dist4 = [] 260 | text_ppls = [] 261 | gold_ppls = [] 262 | deltas = [] 263 | m_score = [] 264 | sen_score = [] 265 | 266 | import json 267 | with open(path, "r") as f: 268 | lst = f.readlines() 269 | lst = [json.loads(i) for i in lst] 270 | 271 | 272 | golds = [] 273 | preds = [] 274 | sources = [] 275 | for d in lst: 276 | d["reference"] = d["reference"].replace("[CLS]","").replace("[SEP]","").strip() 277 | d["recover"] = d["recover"].replace("[CLS]","").replace("[SEP]","").strip() 278 | # d["source"] = d["source"].replace("[CLS]","").replace("[SEP]","").strip() 279 | if d["reference"] == "" or d["recover"] == "": 280 | continue 281 | golds.append(d["reference"].replace("[CLS]","").replace("[SEP]","").strip()) 282 | preds.append(d["recover"].replace("[CLS]","").replace("[SEP]","").strip()) 283 | # sources.append(d["source"].replace("[CLS]","").replace("[SEP]","").strip()) 284 | 285 | # source_golds = [] 286 | # source_preds = [] 287 | # for i,j,z in zip(golds,preds,sources): 288 | # source_golds.append(z+" "+i) 289 | # 
source_preds.append(z+" "+j) 290 | 291 | 292 | 293 | 294 | preds_str = preds 295 | golds_str = golds 296 | 297 | preds_bleu = [] 298 | golds_bleu = [] 299 | for i,j in zip(preds,golds): 300 | preds_bleu.append(i.split()) 301 | golds_bleu.append(j.split()) 302 | 303 | preds, golds = [tokenizer.tokenize(i) for i in preds], [tokenizer.tokenize(i) for i in golds] 304 | 305 | bleu_result = bleu(refs = golds_bleu, cands = preds_bleu) 306 | 307 | dis_result, lex_rep2 = repetition_distinct(preds, 2) 308 | dis_result, lex_rep4 = repetition_distinct(preds, 4) 309 | 310 | 311 | 312 | 313 | 314 | len_ = length_(preds) 315 | len_golds = length_(golds) 316 | 317 | # bertscore_result = bert_score(preds, golds) 318 | torch.cuda.empty_cache() 319 | # P, R, F1 = score(preds_str, golds_str, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True) 320 | 321 | # P = torch.mean(P) 322 | # R = torch.mean(R) 323 | # F1 = torch.mean(F1) 324 | 325 | rouge_result = rouge_score(preds, golds) 326 | print(path) 327 | 328 | 329 | 330 | 331 | recovers = [] 332 | for i in preds_str: 333 | recover = tokenizer_gpt.encode(i) 334 | recover = list(map(str,recover)) 335 | recovers.append(recover) 336 | self_bleus = self_bleu(recovers, n_sample=1000) 337 | 338 | # bert_prec.append(P) 339 | # bert_recall.append(R) 340 | # bert_f1.append(F1) 341 | bleu_1.append(float(bleu_result["bleu-1"])) 342 | bleu_2.append(float(bleu_result["bleu-2"])) 343 | bleu_3.append(float(bleu_result["bleu-3"])) 344 | bleu_4.append(float(bleu_result["bleu-4"])) 345 | times2_repetition_1.append(float(lex_rep2["repetition-1"])) 346 | times2_repetition_2.append(float(lex_rep2["repetition-2"])) 347 | times2_repetition_3.append(float(lex_rep2["repetition-3"])) 348 | times2_repetition_4.append(float(lex_rep2["repetition-4"])) 349 | times4_repetition_1.append(float(lex_rep4["repetition-1"])) 350 | times4_repetition_2.append(float(lex_rep4["repetition-2"])) 351 | times4_repetition_3.append(float(lex_rep4["repetition-3"])) 352 | times4_repetition_4.append(float(lex_rep4["repetition-4"])) 353 | self_bleu_1.append(self_bleus[0]) 354 | self_bleu_2.append(self_bleus[1]) 355 | self_bleu_3.append(self_bleus[2]) 356 | self_bleu_4.append(self_bleus[3]) 357 | self_bleu_5.append(self_bleus[4]) 358 | rouge1.append(rouge_result["rouge1"]) 359 | rouge2.append(rouge_result["rouge2"]) 360 | rougeL.append(rouge_result["rougeL"]) 361 | dist1.append(float(dis_result["distinct-1"])) 362 | dist2.append(float(dis_result["distinct-2"])) 363 | dist3.append(float(dis_result["distinct-3"])) 364 | dist4.append(float(dis_result["distinct-4"])) 365 | 366 | 367 | 368 | 369 | evaluate = [bleu_1,bleu_2,bleu_3,bleu_4,times2_repetition_1,times2_repetition_2,times2_repetition_3,times2_repetition_4,times4_repetition_1,times4_repetition_2,times4_repetition_3,times4_repetition_4,self_bleu_1,self_bleu_2,self_bleu_3,self_bleu_4,self_bleu_5,rouge1,rouge2,rougeL,dist1,dist2,dist3,dist4] 370 | 371 | evaluate_name = ["bleu_1","bleu_2","bleu_3","bleu_4","times2_repetition_1","times2_repetition_2","times2_repetition_3","times2_repetition_4","times4_repetition_1","times4_repetition_2","times4_repetition_3","times4_repetition_4","self_bleu_1","self_bleu_2","self_bleu_3","self_bleu_4","self_bleu_5","rouge1","rouge2","rougeL","dist1","dist2","dist3","dist4"] 372 | 373 | print("folder_path:",path) 374 | for name,eva in zip(evaluate_name,evaluate): 375 | if len(eva) != 0: 376 | print("{}:".format(name),sum(eva)/len(eva)) 377 | 378 | 
-------------------------------------------------------------------------------- /evaluation/tokenizer.py: --------------------------------------------------------------------------------
1 | """A module for Tokenizer"""
2 | import re
3 | 
4 | from nltk.tokenize import WordPunctTokenizer
5 | import spacy
6 | 
7 | class SimpleTokenizer():
8 |     '''
9 |     A simple tokenizer. ``method`` can be ``nltk``, ``space``, ``spacy`` or ``spacy_zh``.
10 |     If ``nltk``, use ``WordPunctTokenizer`` from ``nltk.tokenize``.
11 |     If ``space``, use ``str.split(" ")``.
12 |     If ``spacy``, use ``spacy.load("en_core_web_sm")``.
13 |     If ``spacy_zh``, use ``spacy.load("zh_core_web_sm")``.
14 |     Arguments:
15 |         method (str): the tokenization method, one of the four options above.
16 |         special_tokens (List[str]): special tokens not to tokenize, such as ````.
17 |     '''
18 |     def __init__(self, method, special_tokens = None):
19 |         self.method = method
20 |         self.special_tokens = special_tokens
21 | 
22 |         if method == "nltk":
23 |             self._callable_tokenizer = WordPunctTokenizer().tokenize
24 |         elif method == "space":
25 |             self._callable_tokenizer = str.split
26 |         elif method == "spacy":
27 |             # python -m spacy download en_core_web_sm
28 |             self._callable_tokenizer = spacy.load('en_core_web_sm')
29 |         elif method == "spacy_zh":
30 |             # python -m spacy download zh_core_web_sm
31 |             self._callable_tokenizer = spacy.load('zh_core_web_sm')
32 |         else:
33 |             raise ValueError('`method` is invalid value {}, should be "nltk", "space", "spacy" or "spacy_zh"'.format(method))
34 | 
35 |     def tokenize(self, sentence):
36 |         '''Tokenize a sentence to a list of tokens.
37 |         Arguments:
38 |             sentence (str): a sentence to tokenize.
39 |         '''
40 |         if self.special_tokens is None:
41 |             return self._callable_tokenizer(sentence)
42 |         regexPattern = '(' + '|'.join(map(re.escape, self.special_tokens)) + ')'
43 |         segments = re.split(regexPattern, sentence)
44 |         sent = []
45 |         for seg in segments:
46 |             if seg not in self.special_tokens:
47 |                 sent += self._callable_tokenizer(seg.strip())
48 |             else:
49 |                 sent += [seg]
50 |         return sent
51 | 
52 |     def convert_tokens_to_sentence(self, tokens):
53 |         '''Convert tokens to a sentence.
54 |         It usually works like the reverse operation of :meth:`tokenize`, but this is not guaranteed.
55 |         It behaves like ``" ".join(tokens)``, but special conditions and tokens are taken care of.
56 |         Arguments:
57 |             tokens (List[str]): tokenized sentence
58 |         '''
59 |         if self.method == "nltk":
60 |             sent = " ".join(tokens)
61 |             out_string = sent.replace(' .', '.').replace(' ?', '?'). \
62 |                 replace(' !', '!').replace(' ,', ',').replace(" ' ", "'"). \
63 |                 replace(" n't", "n't").replace(" 'm", "'m"). \
64 |                 replace(" 's", "'s"). \
65 |                 replace(" 've", "'ve").replace(" 're", "'re")
66 |             return out_string
67 |         elif self.method == "space" or self.method == "spacy" or self.method == "spacy_zh":
68 |             return " ".join(tokens)
69 |         else:
70 |             raise RuntimeError("No such tokenizer %s" % self.method)
71 | 
72 |     def name(self):
73 |         return "SimpleTokenizer/" + self.method
74 | 
75 | class PretrainedTokenizer():
76 |     '''Bases: :class:`.dataloader.Tokenizer`
77 |     A wrapper for a pretrained tokenizer from the ``transformers`` package.
78 |     If you don't want to do tokenization on some special tokens, see
79 |     ``transformers.PreTrainedTokenizer.add_special_tokens``.
80 |     Arguments:
81 |         method (str): a pretrained model name (e.g. a GPT-2 or BERT checkpoint);
82 |             the matching ``transformers`` tokenizer is loaded from it.
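    Example (hypothetical checkpoint name): ``PretrainedTokenizer("bert-base-uncased")``
    loads the matching BERT tokenizer.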
83 |     '''
84 |     def __init__(self, method):
85 |         if "gpt" in method:
86 |             from transformers import GPT2Tokenizer
87 |             self.tokenizer = GPT2Tokenizer.from_pretrained(method)
88 |         elif "bert" in method:
89 |             from transformers import BertTokenizer
90 |             self.tokenizer = BertTokenizer.from_pretrained(method)
91 |         else:
92 |             raise ValueError('`method` is invalid value {}, should contain "gpt" or "bert"'.format(method))
93 |         self.method = method  # name() below relies on this attribute
94 |         self._tokenizer_class_name = self.tokenizer.__class__.__name__
95 | 
96 |     def tokenize(self, sentence):
97 |         return self.tokenizer.tokenize(sentence)
98 | 
99 |     def convert_tokens_to_sentence(self, tokens):
100 |         return self.tokenizer.convert_tokens_to_string(tokens)
101 | 
102 |     def name(self):
103 |         return "PretrainedTokenizer/" + self.method
-------------------------------------------------------------------------------- /requirement.txt: --------------------------------------------------------------------------------
1 | aiohttp==3.8.4
2 | aiosignal==1.3.1
3 | appdirs==1.4.4
4 | async-timeout==4.0.2
5 | asynctest==0.13.0
6 | attrs==23.1.0
7 | blobfile==2.0.2
8 | certifi
9 | charset-normalizer==3.1.0
10 | click==8.1.3
11 | datasets==2.9.0
12 | dill==0.3.6
13 | docker-pycreds==0.4.0
14 | filelock==3.12.0
15 | frozenlist==1.3.3
16 | fsspec==2023.1.0
17 | gitdb==4.0.10
18 | GitPython==3.1.31
19 | huggingface-hub==0.14.1
20 | idna==3.4
21 | importlib-metadata==6.6.0
22 | lxml==4.9.2
23 | multidict==6.0.4
24 | multiprocess==0.70.14
25 | numpy==1.21.6
26 | nvidia-cublas-cu11==11.10.3.66
27 | nvidia-cuda-nvrtc-cu11==11.7.99
28 | nvidia-cuda-runtime-cu11==11.7.99
29 | nvidia-cudnn-cu11==8.5.0.96
30 | packaging==23.1
31 | pandas==1.3.5
32 | pathtools==0.1.2
33 | protobuf==4.22.4
34 | psutil==5.9.5
35 | pyarrow==12.0.0
36 | pycryptodomex==3.17
37 | python-dateutil==2.8.2
38 | pytz==2023.3
39 | PyYAML==6.0
40 | regex==2023.5.5
41 | requests==2.30.0
42 | responses==0.18.0
43 | sentry-sdk==1.22.1
44 | setproctitle==1.3.2
45 | six==1.16.0
46 | smmap==5.0.0
47 | tokenizers==0.12.1
48 | torch==1.13.1
49 | tqdm==4.65.0
50 | transformers==4.19.4
51 | typing_extensions==4.5.0
52 | urllib3==1.26.15
53 | wandb==0.15.2
54 | xxhash==3.2.0
55 | yarl==1.9.2
56 | zipp==3.15.0
57 | 
-------------------------------------------------------------------------------- /sample_seq2seq.py: --------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
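Adapted here for seq2seq text diffusion: given a source prefix, the script
denoises the masked target positions with the trained model and writes the
decoded {"recover", "reference", "source"} triples as JSON lines.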
4 | """ 5 | 6 | import argparse 7 | import os, json 8 | from tracemalloc import start 9 | 10 | import numpy as np 11 | import torch as th 12 | import time 13 | import torch.distributed as dist 14 | from transformers import set_seed 15 | from diffuseq.rounding import denoised_fn_round, get_weights 16 | from diffuseq.text_datasets import load_data_text 17 | 18 | # from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 19 | 20 | import time 21 | from diffuseq.utils import dist_util, logger 22 | from functools import partial 23 | from basic_utils import ( 24 | load_defaults_config, 25 | create_model_and_diffusion, 26 | add_dict_to_argparser, 27 | args_to_dict, 28 | load_model_emb, 29 | load_tokenizer 30 | ) 31 | 32 | def create_argparser(): 33 | defaults = dict(model_path='', step=0, out_dir='',) 34 | decode_defaults = dict(split='valid', clamp_step=0, seed2=105, clip_denoised=False, eta=-1.0, decode_respacing="", \ 35 | top_l = 0.0 , top_k = 0, noised_l = 0.0, scale_end = "", top_p = 0.0, tau = 1.0, clamp_skip = -1, p_d="") 36 | defaults.update(load_defaults_config()) 37 | defaults.update(decode_defaults) 38 | parser = argparse.ArgumentParser() 39 | add_dict_to_argparser(parser, defaults) 40 | return parser 41 | 42 | 43 | def main(): 44 | args = create_argparser().parse_args() 45 | 46 | dist_util.setup_dist() 47 | logger.configure() 48 | 49 | # load configurations. 50 | config_path = os.path.join(os.path.split(args.model_path)[0], "training_args.json") 51 | print(config_path) 52 | # sys.setdefaultencoding('utf-8') 53 | with open(config_path, 'rb', ) as f: 54 | training_args = json.load(f) 55 | training_args['batch_size'] = args.batch_size 56 | args.__dict__.update(training_args) 57 | 58 | # mycode 59 | args.diffusion_steps = args.step 60 | args.timestep_respacing = args.decode_respacing 61 | 62 | 63 | logger.log("### Creating model and diffusion...") 64 | model, diffusion = create_model_and_diffusion( 65 | **args_to_dict(args, load_defaults_config().keys()) 66 | ) 67 | 68 | 69 | model.load_state_dict( 70 | dist_util.load_state_dict(args.model_path, map_location="cpu") 71 | ) 72 | 73 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 74 | logger.log(f'### The parameter count is {pytorch_total_params}') 75 | 76 | model.to(dist_util.dev()) 77 | model.eval() 78 | 79 | tokenizer = load_tokenizer(args) 80 | model_emb, tokenizer = load_model_emb(args, tokenizer) 81 | 82 | model_emb.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu()) 83 | model_emb_copy = get_weights(model_emb, args) 84 | 85 | set_seed(args.seed2) 86 | 87 | print("### Sampling...on", args.split) 88 | 89 | ## load data 90 | data_valid = load_data_text( 91 | batch_size=args.batch_size, 92 | seq_len=args.seq_len, 93 | deterministic=True, 94 | data_args=args, 95 | split=args.split, 96 | loaded_vocab=tokenizer, 97 | model_emb=model_emb.cpu(), # using the same embedding wight with tranining data 98 | loop=False 99 | ) 100 | 101 | start_t = time.time() 102 | 103 | # batch, cond = next(data_valid) 104 | # print(batch.shape) 105 | 106 | model_base_name = os.path.basename(os.path.split(args.model_path)[0]) + f'.{os.path.split(args.model_path)[1]}' 107 | out_dir = os.path.join(args.out_dir, f"{model_base_name.split('.ema')[0]}") 108 | if not os.path.isdir(out_dir): 109 | os.mkdir(out_dir) 110 | 111 | out_path = os.path.join(out_dir, f"ema{model_base_name.split('.ema')[1]}.samples") 112 | if not os.path.isdir(out_path): 113 | os.mkdir(out_path) 114 | out_path = os.path.join(out_path, 
f"seed{args.seed2}_step{args.step}_eta{args.eta}_respace_{args.decode_respacing}_{args.split}_topp{args.top_p}_tau{args.tau}\ 115 | _topl{args.top_l}_topk{args.top_k}_noisedl{args.noised_l}_scend{args.scale_end}_clamp_skip{args.clamp_skip}_pd_{args.p_d}.json") 116 | # fout = open(out_path, 'a') 117 | 118 | all_test_data = [] 119 | 120 | try: 121 | while True: 122 | batch, cond = next(data_valid) 123 | # print(batch.shape) 124 | all_test_data.append(cond) 125 | 126 | except StopIteration: 127 | print('### End of reading iteration...') 128 | 129 | from tqdm import tqdm 130 | 131 | # if args.ana_save: os.mkdir(f"/data/tzc/DiffuSeqs/myDiffuSeq/analyze_data/qqp-valid-embedding/{args.ana_save}") 132 | 133 | 134 | t_lst = [] 135 | for cond in tqdm(all_test_data): 136 | 137 | input_ids_x = cond.pop('input_ids').to(dist_util.dev()) 138 | x_start = model.get_embeds(input_ids_x) 139 | input_ids_mask = cond.pop('input_mask') 140 | input_ids_mask_ori = input_ids_mask 141 | 142 | noise = th.randn_like(x_start) 143 | input_ids_mask = th.broadcast_to(input_ids_mask.unsqueeze(dim=-1), x_start.shape).to(dist_util.dev()) 144 | x_noised = th.where(input_ids_mask==0, x_start, noise) 145 | 146 | model_kwargs = {} 147 | 148 | step_gap = 1 149 | 150 | # if args.step == args.diffusion_steps: 151 | # args.use_ddim = False 152 | # step_gap = 1 153 | # else: 154 | # args.use_ddim = True 155 | # step_gap = args.diffusion_steps//args.step 156 | if args.eta >= 0: 157 | args.use_ddim = True 158 | else: 159 | args.use_ddim = False 160 | 161 | if args.p_d: 162 | args.use_ddim = False 163 | 164 | sample_fn = ( 165 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 166 | ) 167 | 168 | # sample_shape = (batch.shape[0], args.seq_len, args.hidden_dim) 169 | sample_shape = (x_start.shape[0], args.seq_len, args.hidden_dim) 170 | 171 | t1 = time.time() 172 | samples = sample_fn( 173 | model, 174 | sample_shape, 175 | noise=x_noised, 176 | clip_denoised=args.clip_denoised, 177 | denoised_fn=partial(denoised_fn_round, args, model_emb_copy.cuda()), 178 | model_kwargs=model_kwargs, 179 | top_p=args.top_p, 180 | clamp_step=args.clamp_step, 181 | clamp_first=True, 182 | mask=input_ids_mask, 183 | x_start=x_start, 184 | gap=step_gap, 185 | eta=args.eta, 186 | ddim_signal=args.p_d, 187 | # top_l = args.top_l, 188 | # top_k = args.top_k, 189 | # noised_l = args.noised_l, 190 | # scale_end = args.scale_end, 191 | ) 192 | t2 = time.time() 193 | t_lst.append(t2 - t1) 194 | 195 | model_emb_copy = model_emb_copy.cpu() 196 | # print(samples[0].shape) # samples for each step 197 | 198 | sample = samples[-1] 199 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 200 | dist.all_gather(gathered_samples, sample) 201 | all_sentence = [sample.cpu().numpy() for sample in gathered_samples] 202 | 203 | # print('sampling takes {:.2f}s .....'.format(time.time() - start_t)) 204 | 205 | word_lst_recover = [] 206 | word_lst_ref = [] 207 | word_lst_source = [] 208 | 209 | 210 | arr = np.concatenate(all_sentence, axis=0) 211 | x_t = th.tensor(arr).cuda() 212 | print('decoding for seq2seq', ) 213 | print(arr.shape) 214 | 215 | reshaped_x_t = x_t 216 | logits = model.get_logits(reshaped_x_t) # bsz, seqlen, vocab 217 | 218 | cands = th.topk(logits, k=1, dim=-1) 219 | sample = cands.indices 220 | # tokenizer = load_tokenizer(args) 221 | 222 | for seq, input_mask in zip(cands.indices, input_ids_mask_ori): 223 | len_x = args.seq_len - sum(input_mask).tolist() 224 | tokens = tokenizer.decode_token(seq[len_x:]) 
225 | word_lst_recover.append(tokens) 226 | 227 | for seq, input_mask in zip(input_ids_x, input_ids_mask_ori): 228 | # tokens = tokenizer.decode_token(seq) 229 | len_x = args.seq_len - sum(input_mask).tolist() 230 | word_lst_source.append(tokenizer.decode_token(seq[:len_x])) 231 | word_lst_ref.append(tokenizer.decode_token(seq[len_x:])) 232 | 233 | 234 | fout = open(out_path, 'a') 235 | for (recov, ref, src) in zip(word_lst_recover, word_lst_ref, word_lst_source): 236 | print(json.dumps({"recover": recov, "reference": ref, "source": src}), file=fout) 237 | fout.close() 238 | 239 | # for (recov, ref, src) in zip(word_lst_recover, word_lst_ref, word_lst_source): 240 | # print(json.dumps({"recover": recov, "reference": ref, "source": src})) 241 | 242 | 243 | print('### Total takes {:.2f}s .....'.format(time.time() - start_t)) 244 | print(f'### Wrote the decoded output to {out_path}') 245 | print('### Per-batch sampling time (s):', t_lst) 246 | print('### Mean sampling time per batch (s):', np.mean(t_lst)) 247 | 248 | if __name__ == "__main__": 249 | main() 250 | -------------------------------------------------------------------------------- /scripts/eval_seq2seq.py: -------------------------------------------------------------------------------- 1 | import os, sys, glob, json 2 | import numpy as np 3 | import argparse 4 | import torch 5 | 6 | from torchmetrics.text.rouge import ROUGEScore 7 | rougeScore = ROUGEScore() 8 | from bert_score import score 9 | 10 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 11 | import nltk 12 | 13 | def get_bleu(recover, reference): # recover is the hypothesis, reference the gold target 14 | return sentence_bleu([reference.split()], recover.split(), smoothing_function=SmoothingFunction().method4,) 15 | 16 | def selectBest(sentences): 17 | selfBleu = [[] for i in range(len(sentences))] 18 | for i, s1 in enumerate(sentences): 19 | for j, s2 in enumerate(sentences): 20 | score = get_bleu(s1, s2) 21 | selfBleu[i].append(score) 22 | for i, s1 in enumerate(sentences): 23 | selfBleu[i][i] = 0 24 | idx = np.argmax(np.sum(selfBleu, -1)) 25 | return sentences[idx] 26 | 27 | 28 | 29 | 30 | def diversityOfSet(sentences): 31 | selfBleu = [] 32 | # print(sentences) 33 | for i, sentence in enumerate(sentences): 34 | for j in range(i+1, len(sentences)): 35 | # print(sentence, sentences[j]) 36 | score = get_bleu(sentence, sentences[j]) 37 | selfBleu.append(score) 38 | if len(selfBleu)==0: 39 | selfBleu.append(0) 40 | div4 = distinct_n_gram_inter_sent(sentences, 4) 41 | return np.mean(selfBleu), div4 42 | 43 | 44 | def distinct_n_gram(hypn,n): 45 | dist_list = [] 46 | for hyp in hypn: 47 | hyp_ngrams = [] 48 | hyp_ngrams += nltk.ngrams(hyp.split(), n) 49 | total_ngrams = len(hyp_ngrams) 50 | unique_ngrams = len(list(set(hyp_ngrams))) 51 | if total_ngrams == 0: 52 | dist_list.append(0) # score an empty hypothesis as 0 instead of aborting the whole batch 53 | else: dist_list.append(unique_ngrams/total_ngrams) 54 | return np.mean(dist_list) 55 | 56 | 57 | def distinct_n_gram_inter_sent(hypn, n): 58 | hyp_ngrams = [] 59 | for hyp in hypn: 60 | hyp_ngrams += nltk.ngrams(hyp.split(), n) 61 | total_ngrams = len(hyp_ngrams) 62 | unique_ngrams = len(list(set(hyp_ngrams))) 63 | if total_ngrams == 0: 64 | return 0 65 | dist_n = unique_ngrams/total_ngrams 66 | return dist_n 67 | 68 | if __name__ == '__main__': 69 | 70 | parser = argparse.ArgumentParser(description='decoding args.') 71 | parser.add_argument('--folder', type=str, default='', help='path to the folder of decoded texts') 72 | parser.add_argument('--mbr', action='store_true',
help='mbr decoding or not') 73 | parser.add_argument('--sos', type=str, default='[CLS]', help='start token of the sentence') 74 | parser.add_argument('--eos', type=str, default='[SEP]', help='end token of the sentence') 75 | parser.add_argument('--sep', type=str, default='[SEP]', help='sep token of the sentence') 76 | parser.add_argument('--pad', type=str, default='[PAD]', help='pad token of the sentence') 77 | 78 | args = parser.parse_args() 79 | 80 | files = sorted(glob.glob(f"{args.folder}/*json")) 81 | sample_num = 0 82 | with open(files[0], 'r') as f: 83 | for row in f: 84 | sample_num += 1 85 | 86 | sentenceDict = {} 87 | referenceDict = {} 88 | sourceDict = {} 89 | for i in range(sample_num): 90 | sentenceDict[i] = [] 91 | referenceDict[i] = [] 92 | sourceDict[i] = [] 93 | 94 | div4 = [] 95 | selfBleu = [] 96 | 97 | AVG_bleu = [] 98 | AVG_rougel = [] 99 | AVG_dist2 = [] 100 | AVG_dist4 = [] 101 | AVG_f1 = [] 102 | AVG_len = [] 103 | 104 | for path in files: 105 | print(path) 106 | sources = [] 107 | references = [] 108 | recovers = [] 109 | bleu = [] 110 | rougel = [] 111 | avg_len = [] 112 | dist2 = [] 113 | dist4 = [] 114 | 115 | with open(path, 'r') as f: 116 | cnt = 0 117 | for row in f: 118 | 119 | source = json.loads(row)['source'].strip() 120 | reference = json.loads(row)['reference'].strip() 121 | recover = json.loads(row)['recover'].strip() 122 | source = source.replace(args.eos, '').replace(args.sos, '') 123 | reference = reference.replace(args.eos, '').replace(args.sos, '').replace(args.sep, '') 124 | recover = recover.replace(args.eos, '').replace(args.sos, '').replace(args.sep, '').replace(args.pad, '') 125 | 126 | sources.append(source) 127 | references.append(reference) 128 | recovers.append(recover) 129 | avg_len.append(len(recover.split(' '))) 130 | 131 | 132 | bleu.append(get_bleu(recover, reference)) 133 | rougel.append(rougeScore(recover, reference)['rougeL_fmeasure'].tolist()) 134 | dist2.append(distinct_n_gram([recover], 2)) 135 | dist4.append(distinct_n_gram([recover], 4)) 136 | 137 | sentenceDict[cnt].append(recover) 138 | referenceDict[cnt].append(reference) 139 | sourceDict[cnt].append(source) 140 | cnt += 1 141 | 142 | P, R, F1 = score(recovers, references, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True) 143 | 144 | print('*'*30) 145 | print('avg BLEU score', np.mean(bleu)) 146 | AVG_bleu.append(np.mean(bleu).item()) 147 | print('avg ROUGE-L score', np.mean(rougel)) 148 | AVG_rougel.append(np.mean(rougel).item()) 149 | print('avg bertscore', torch.mean(F1)) 150 | AVG_f1.append(torch.mean(F1).item()) 151 | print('avg dist2 score', np.mean(dist2)) 152 | AVG_dist2.append(np.mean(dist2).item()) 153 | print('avg dist4 score', np.mean(dist4)) 154 | AVG_dist4.append(np.mean(dist4).item()) 155 | print('avg len', np.mean(avg_len)) 156 | AVG_len.append(np.mean(avg_len).item()) 157 | 158 | if len(files)>1: 159 | if not args.mbr: 160 | print('*'*30) 161 | print('Compute diversity...') 162 | print('*'*30) 163 | for k, v in sentenceDict.items(): 164 | if len(v) == 0: 165 | continue 166 | sb, d4 = diversityOfSet(v) 167 | selfBleu.append(sb) 168 | div4.append(d4) 169 | 170 | print('avg selfBleu score', np.mean(selfBleu)) 171 | print('avg div4 score', np.mean(div4)) 172 | 173 | else: 174 | print('*'*30) 175 | print('MBR...') 176 | print('*'*30) 177 | bleu = [] 178 | rougel = [] 179 | avg_len = [] 180 | dist2 = [] 181 | dist4 = [] 182 | recovers = [] 183 | references = [] 184 | sources = [] 185 | 186 | 187 | for k, v in sentenceDict.items(): 188 | if
len(v) == 0 or len(referenceDict[k]) == 0: 189 | continue 190 | 191 | recovers.append(selectBest(v)) 192 | references.append(referenceDict[k][0]) 193 | sources.append(sourceDict[k][0]) 194 | 195 | for (source, reference, recover) in zip(sources, references, recovers): 196 | bleu.append(get_bleu(recover, reference)) 197 | rougel.append(rougeScore(recover, reference)['rougeL_fmeasure'].tolist()) 198 | avg_len.append(len(recover.split(' '))) 199 | dist2.append(distinct_n_gram([recover], 2)) 200 | dist4.append(distinct_n_gram([recover], 4)) 201 | 202 | # print(len(recovers), len(references), len(recovers)) 203 | 204 | P, R, F1 = score(recovers, references, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True) 205 | 206 | print('*'*30) 207 | print('MBR BLEU score', np.mean(bleu)) 208 | print('MBR ROUGE-L score', np.mean(rougel)) 209 | print('MBR bertscore', torch.mean(F1)) 210 | print('MBR dist2 score', np.mean(dist2)) 211 | print('MBR dist4 score', np.mean(dist4)) 212 | 213 | print('*'*30) 214 | print('AVG BLEU score', sum(AVG_bleu) / len(AVG_bleu), len(AVG_bleu)) 215 | print('AVG ROUGE-L score', sum(AVG_rougel) / len(AVG_rougel)) 216 | print('AVG bertscore', sum(AVG_f1) / len(AVG_f1)) 217 | print('AVG dist2 score', sum(AVG_dist2) / len(AVG_dist2)) 218 | print('AVG dist4 score', sum(AVG_dist4) / len(AVG_dist4)) 219 | 220 | print('*'*30) 221 | print('BLEUs', AVG_bleu) 222 | print('ROUGE-Ls', AVG_rougel) 223 | print('bertscores', AVG_f1) 224 | print('dist2 scores', AVG_dist2) 225 | print('dist4 scores', AVG_dist4) 226 | -------------------------------------------------------------------------------- /scripts/run_decode.py: -------------------------------------------------------------------------------- 1 | import os, sys, glob 2 | import argparse 3 | sys.path.append('.') 4 | sys.path.append('..') 5 | 6 | if __name__ == '__main__': 7 | 8 | parser = argparse.ArgumentParser(description='decoding args.') 9 | parser.add_argument('--model_dir', type=str, default='', help='path to the folder of diffusion model') 10 | parser.add_argument('--seed', type=int, default=101, help='random seed') 11 | parser.add_argument('--step', type=int, default=2000, help='if less than diffusion training steps, like 1000, use ddim sampling') 12 | parser.add_argument('--step_gap', type=int, default=0) 13 | 14 | parser.add_argument('--bsz', type=int, default=50, help='batch size') 15 | parser.add_argument('--split', type=str, default='test', choices=['train', 'valid', 'test'], help='dataset split used to decode') 16 | 17 | parser.add_argument('--top_p', type=float, default=-1, help='top p used in sampling, default is off') 18 | parser.add_argument('--pattern', type=str, default='ema', help='checkpoint filename prefix to decode (e.g. ema)') 19 | 20 | args = parser.parse_args() 21 | 22 | # set working dir to the upper folder 23 | abspath = os.path.abspath(sys.argv[0]) 24 | dname = os.path.dirname(abspath) 25 | dname = os.path.dirname(dname) 26 | os.chdir(dname) 27 | 28 | output_lst = [] 29 | for lst in glob.glob(args.model_dir): 30 | print(lst) 31 | checkpoints = sorted(glob.glob(f"{lst}/{args.pattern}*.pt"))[::-1] 32 | 33 | out_dir = 'generation_outputs' 34 | if not os.path.isdir(out_dir): 35 | os.mkdir(out_dir) 36 | 37 | for checkpoint_one in checkpoints: 38 | 39 | COMMAND = f'python sample_seq2seq.py ' \ 40 | f'--model_path {checkpoint_one} --step {args.step} ' \ 41 | f'--batch_size {args.bsz} --seed2 {args.seed} --split {args.split} ' \ 42 | f'--out_dir {out_dir} --top_p {args.top_p} --step_gap {args.step_gap}' 43 | print(COMMAND) 44 | 45 |
os.system(COMMAND) 46 | 47 | print('#'*30, 'decoding finished...') -------------------------------------------------------------------------------- /scripts/run_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import time 5 | sys.path.append('.') 6 | 7 | if __name__ == '__main__': 8 | 9 | parser = argparse.ArgumentParser(description='training args.') 10 | parser.add_argument('--dataset', type=str, default='', help='name of training dataset') 11 | parser.add_argument('--data_dir', type=str, default='', help='path to training dataset') 12 | 13 | parser.add_argument('--noise_schedule', type=str, default='cosine', choices=['linear', 'cosine', 'sqrt', 'trunc_cos', 'trunc_lin', 'pw_lin'], help='the noise schedule') 14 | parser.add_argument('--diff_steps', type=int, default=4000, help='diffusion steps') 15 | parser.add_argument('--schedule_sampler', type=str, default='uniform', choices=['uniform', 'lossaware', 'fixstep'], help='schedule sampler of timesteps') 16 | 17 | parser.add_argument('--seq_len', type=int, default=128, help='max len of input sequence') 18 | parser.add_argument('--hidden_t_dim', type=int, default=128, help='hidden size of time embedding') 19 | parser.add_argument('--hidden_dim', type=int, default=128, help='hidden size of word embedding') 20 | parser.add_argument('--learning_steps', type=int, default=40000, help='total steps of learning') 21 | parser.add_argument('--save_interval', type=int, default=10000, help='save step') 22 | parser.add_argument('--resume_checkpoint', type=str, default='none', help='path to resume checkpoint, like xxx/xxx.pt') 23 | parser.add_argument('--lr', type=float, default=1e-04, help='learning rate') 24 | parser.add_argument('--bsz', type=int, default=64, help='batch size') 25 | parser.add_argument('--microbatch', type=int, default=64, help='microbatch size') 26 | parser.add_argument('--seed', type=int, default=101, help='random seed') 27 | 28 | parser.add_argument('--use_fp16', action='store_true') 29 | parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='config of pre-trained models') 30 | parser.add_argument('--vocab', type=str, default='bert', help='use bert vocab or load external vocab dict if given as path') 31 | parser.add_argument('--use_plm_init', type=str, default='no', choices=['no', 'bert'], help='load init parameter from the pre-trained lm') 32 | 33 | parser.add_argument('--notes', type=str, default='-', help='as training notes or special args') 34 | parser.add_argument('--app', type=str, default='', help='other input args') 35 | 36 | parser.add_argument('--simi_lambda', type=float, default = 0.01) 37 | parser.add_argument('--simi_step', type=int, default = 10) 38 | parser.add_argument('--simi_penalty', type=str, default = "cosine") 39 | 40 | parser.add_argument('--near_step', type=int, default = 0) 41 | parser.add_argument('--near_lambda', type=float, default = 0) 42 | parser.add_argument('--far_step', type=int, default = 0) 43 | parser.add_argument('--far_lambda', type=float, default = 0) 44 | parser.add_argument('--simi_noise', type=float, default = 0.05) 45 | 46 | args = parser.parse_args() 47 | 48 | # set working dir to the upper folder 49 | abspath = os.path.abspath(sys.argv[0]) 50 | dname = os.path.dirname(abspath) 51 | dname = os.path.dirname(dname) 52 | os.chdir(dname) 53 | 54 | folder_name = "post_train_diffusion_models/" 55 | 56 | if int(os.environ['LOCAL_RANK']) == 0: 57 | if not
os.path.isdir(folder_name): 58 | os.mkdir(folder_name) 59 | 60 | Model_FILE = f"diffuseq_{args.dataset}_h{args.hidden_dim}_lr{args.lr}" \ 61 | f"_t{args.diff_steps}_{args.noise_schedule}_{args.schedule_sampler}" \ 62 | f"_seed{args.seed}" 63 | if args.notes: 64 | args.notes += time.strftime("%Y%m%d-%H:%M:%S") 65 | Model_FILE = Model_FILE + f'_{args.notes}' 66 | Model_FILE = os.path.join(folder_name, Model_FILE) 67 | 68 | if int(os.environ['LOCAL_RANK']) == 0: 69 | if not os.path.isdir(Model_FILE): 70 | os.mkdir(Model_FILE) 71 | 72 | COMMANDLINE = f" OPENAI_LOGDIR={Model_FILE} " \ 73 | f"TOKENIZERS_PARALLELISM=false " \ 74 | f"python train.py " \ 75 | f"--checkpoint_path {Model_FILE} " \ 76 | f"--dataset {args.dataset} --data_dir {args.data_dir} --vocab {args.vocab} --use_plm_init {args.use_plm_init} " \ 77 | f"--lr {args.lr} " \ 78 | f"--batch_size {args.bsz} --microbatch {args.microbatch} " \ 79 | f"--diffusion_steps {args.diff_steps} " \ 80 | f"--noise_schedule {args.noise_schedule} " \ 81 | f"--schedule_sampler {args.schedule_sampler} --resume_checkpoint {args.resume_checkpoint} " \ 82 | f"--seq_len {args.seq_len} --hidden_t_dim {args.hidden_t_dim} --seed {args.seed} " \ 83 | f"--hidden_dim {args.hidden_dim} " \ 84 | f"--learning_steps {args.learning_steps} --save_interval {args.save_interval} " \ 85 | f"--config_name {args.config_name} --notes {args.notes} " \ 86 | f"--simi_lambda {args.simi_lambda} --simi_step {args.simi_step} --simi_penalty {args.simi_penalty} " \ 87 | f"--near_lambda {args.near_lambda} --near_step {args.near_step} " \ 88 | f"--far_lambda {args.far_lambda} --far_step {args.far_step} --simi_noise {args.simi_noise}" 89 | 90 | 91 | COMMANDLINE += " " + args.app 92 | 93 | if int(os.environ['LOCAL_RANK']) == 0: 94 | with open(os.path.join(Model_FILE, 'saved_bash.sh'), 'w') as f: 95 | print(COMMANDLINE, file=f) 96 | 97 | print(COMMANDLINE) 98 | 99 | os.system(COMMANDLINE) 100 | 101 | print("FINISHED!!!!") 102 | time.sleep(300) 103 | 104 | # OPENAI_LOGDIR=diffusion_models/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_qqp20221123-21:24:15 TOKENIZERS_PARALLELISM=false python train.py --checkpoint_path diffusion_models/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_qqp20221123-21:24:15 --dataset qqp --data_dir datasets/QQP --vocab bert --use_plm_init no --lr 0.0001 --batch_size 2048 --microbatch 64 --diffusion_steps 2000 --noise_schedule sqrt --schedule_sampler lossaware --resume_checkpoint none --seq_len 128 --hidden_t_dim 128 --seed 102 --hidden_dim 128 --learning_steps 50000 --save_interval 10000 --config_name bert-base-uncased --notes qqp20221123-21:24:15 --simi_lambda 0.01 --simi_step 10 --simi_penalty cosine -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on text (seq2seq).
3 | """ 4 | 5 | import argparse 6 | import json, torch, os 7 | import numpy as np 8 | from diffuseq.utils import dist_util, logger 9 | from diffuseq.text_datasets import load_data_text 10 | from diffuseq.step_sample import create_named_schedule_sampler 11 | from basic_utils import ( 12 | load_defaults_config, 13 | create_model_and_diffusion, 14 | args_to_dict, 15 | add_dict_to_argparser, 16 | load_model_emb, 17 | load_tokenizer 18 | ) 19 | from train_util import TrainLoop 20 | from transformers import set_seed 21 | import wandb 22 | 23 | ### custom your wandb setting here ### 24 | # os.environ["WANDB_API_KEY"] = "" 25 | os.environ["WANDB_MODE"] = "offline" 26 | 27 | def create_argparser(): 28 | defaults = dict() 29 | defaults.update(load_defaults_config()) 30 | parser = argparse.ArgumentParser() 31 | add_dict_to_argparser(parser, defaults) # update latest args according to argparse 32 | return parser 33 | 34 | def main(): 35 | args = create_argparser().parse_args() 36 | set_seed(args.seed) 37 | dist_util.setup_dist() 38 | logger.configure() 39 | logger.log("### Creating data loader...") 40 | 41 | tokenizer = load_tokenizer(args) 42 | model_weight, tokenizer = load_model_emb(args, tokenizer) 43 | 44 | data = load_data_text( 45 | batch_size=args.batch_size, 46 | seq_len=args.seq_len, 47 | data_args = args, 48 | loaded_vocab=tokenizer, 49 | model_emb=model_weight # use model's weights as init 50 | ) 51 | next(data) 52 | 53 | data_valid = load_data_text( 54 | batch_size=args.batch_size, 55 | seq_len=args.seq_len, 56 | data_args=args, 57 | split='valid', 58 | deterministic=True, 59 | loaded_vocab=tokenizer, 60 | model_emb=model_weight # using the same embedding wight with tranining data 61 | ) 62 | 63 | print('#'*30, 'size of vocab', args.vocab_size) 64 | 65 | logger.log("### Creating model and diffusion...") 66 | # print('#'*30, 'CUDA_VISIBLE_DEVICES', os.environ['CUDA_VISIBLE_DEVICES']) 67 | model, diffusion = create_model_and_diffusion( 68 | **args_to_dict(args, load_defaults_config().keys()) 69 | ) 70 | # print('#'*30, 'cuda', dist_util.dev()) 71 | model.to(dist_util.dev()) # DEBUG ** 72 | # model.cuda() # DEBUG ** 73 | 74 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 75 | 76 | logger.log(f'### The parameter count is {pytorch_total_params}') 77 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 78 | 79 | logger.log(f'### Saving the hyperparameters to {args.checkpoint_path}/training_args.json') 80 | with open(f'{args.checkpoint_path}/training_args.json', 'w') as f: 81 | json.dump(args.__dict__, f, indent=2) 82 | 83 | if ('LOCAL_RANK' not in os.environ) or (int(os.environ['LOCAL_RANK']) == 0): 84 | wandb.init( 85 | project=os.getenv("WANDB_PROJECT", "DiffuSeq"), 86 | name=args.checkpoint_path, 87 | ) 88 | wandb.config.update(args.__dict__, allow_val_change=True) 89 | 90 | logger.log("### Training...") 91 | 92 | TrainLoop( 93 | model=model, 94 | diffusion=diffusion, 95 | data=data, 96 | batch_size=args.batch_size, 97 | microbatch=args.microbatch, 98 | lr=args.lr, 99 | ema_rate=args.ema_rate, 100 | log_interval=args.log_interval, 101 | save_interval=args.save_interval, 102 | resume_checkpoint=args.resume_checkpoint, 103 | use_fp16=args.use_fp16, 104 | fp16_scale_growth=args.fp16_scale_growth, 105 | schedule_sampler=schedule_sampler, 106 | weight_decay=args.weight_decay, 107 | learning_steps=args.learning_steps, 108 | checkpoint_path=args.checkpoint_path, 109 | gradient_clipping=args.gradient_clipping, 110 | eval_data=data_valid, 111 | 
eval_interval=args.eval_interval, 112 | simi_penalty = args.simi_penalty, 113 | simi_lambda = args.simi_lambda, 114 | simi_step = args.simi_step, 115 | near_step = args.near_step, 116 | far_step = args.far_step, 117 | near_lambda = args.near_lambda, 118 | far_lambda = args.far_lambda, 119 | simi_noise = args.simi_noise 120 | ).run_loop() 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import numpy as np 7 | import torch as th 8 | import torch.distributed as dist 9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 10 | from torch.optim import AdamW 11 | import io 12 | import time 13 | 14 | from diffuseq.utils import dist_util, logger 15 | from diffuseq.utils.fp16_util import ( 16 | make_master_params, 17 | master_params_to_model_params, 18 | model_grads_to_master_grads, 19 | unflatten_master_params, 20 | zero_grad, 21 | ) 22 | from diffuseq.utils.nn import update_ema 23 | from diffuseq.step_sample import LossAwareSampler, UniformSampler 24 | 25 | # For ImageNet experiments, this was a good default value. 26 | # We found that the lg_loss_scale quickly climbed to 27 | # 20-21 within the first ~1K steps of training. 28 | INITIAL_LOG_LOSS_SCALE = 20.0 29 | 30 | 31 | class TrainLoop: 32 | def __init__( 33 | self, 34 | *, 35 | model, 36 | diffusion, 37 | data, 38 | batch_size, 39 | microbatch, 40 | lr, 41 | ema_rate, 42 | log_interval, 43 | save_interval, 44 | resume_checkpoint, 45 | use_fp16=False, 46 | fp16_scale_growth=1e-3, 47 | schedule_sampler=None, 48 | weight_decay=0.0, 49 | learning_steps=0, 50 | checkpoint_path='', 51 | gradient_clipping=-1., 52 | eval_data=None, 53 | eval_interval=-1, 54 | simi_penalty = "", 55 | simi_lambda = 0.01, 56 | simi_step = 10, 57 | near_step = 0, 58 | near_lambda = 0, 59 | far_step = 0, 60 | far_lambda = 0, 61 | simi_noise=0.05, 62 | ): 63 | self.simi_noise = simi_noise 64 | self.model = model 65 | self.diffusion = diffusion 66 | self.data = data 67 | self.eval_data = eval_data 68 | self.batch_size = batch_size 69 | self.microbatch = microbatch if microbatch > 0 else batch_size 70 | self.lr = lr 71 | self.ema_rate = ( 72 | [ema_rate] 73 | if isinstance(ema_rate, float) 74 | else [float(x) for x in ema_rate.split(",")] 75 | ) 76 | self.log_interval = log_interval 77 | self.eval_interval = eval_interval 78 | self.save_interval = save_interval 79 | self.resume_checkpoint = resume_checkpoint 80 | self.use_fp16 = use_fp16 81 | self.fp16_scale_growth = fp16_scale_growth 82 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 83 | self.weight_decay = weight_decay 84 | self.learning_steps = learning_steps 85 | self.gradient_clipping = gradient_clipping 86 | 87 | self.step = 0 88 | self.resume_step = 0 89 | self.global_batch = self.batch_size * dist.get_world_size() 90 | 91 | self.model_params = list(self.model.parameters()) 92 | self.master_params = self.model_params 93 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 94 | self.sync_cuda = th.cuda.is_available() 95 | 96 | self.checkpoint_path = checkpoint_path # DEBUG ** 97 | 98 | self.simi_penalty = simi_penalty 99 | self.simi_lambda = simi_lambda 100 | self.simi_step = simi_step 101 | 102 | self.near_step = near_step 103 | self.far_step = far_step 104 | self.near_lambda = near_lambda 105 | self.far_lambda = 
far_lambda 106 | 107 | self._load_and_sync_parameters() 108 | if self.use_fp16: 109 | self._setup_fp16() 110 | 111 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 112 | if self.resume_step: 113 | # self._load_optimizer_state() 114 | frac_done = (self.step + self.resume_step) / self.learning_steps 115 | lr = self.lr * (1 - frac_done) if not resume_checkpoint else self.lr # anneal lr on auto-restart; keep base lr when resuming an explicit checkpoint 116 | self.opt = AdamW(self.master_params, lr=lr, weight_decay=self.weight_decay) 117 | # Model was resumed, either due to a restart or a checkpoint 118 | # being specified at the command line. 119 | self.ema_params = [ 120 | self._load_ema_parameters(rate) for rate in self.ema_rate 121 | ] 122 | else: 123 | self.ema_params = [ 124 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 125 | ] 126 | 127 | if th.cuda.is_available(): # DEBUG ** 128 | self.use_ddp = True 129 | print(dist_util.dev()) 130 | self.ddp_model = DDP( 131 | self.model, 132 | device_ids=[dist_util.dev()], 133 | output_device=dist_util.dev(), 134 | broadcast_buffers=False, 135 | bucket_cap_mb=128, 136 | find_unused_parameters=False, 137 | ) 138 | else: 139 | if dist.get_world_size() > 1: 140 | logger.warn( 141 | "Distributed training requires CUDA. " 142 | "Gradients will not be synchronized properly!" 143 | ) 144 | self.use_ddp = False 145 | self.ddp_model = self.model 146 | 147 | def _load_and_sync_parameters(self): 148 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 149 | 150 | if resume_checkpoint[-3:] == '.pt': 151 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 152 | if dist.get_rank() == 0: 153 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 154 | self.model.load_state_dict( 155 | dist_util.load_state_dict( 156 | actual_model_path(resume_checkpoint), map_location=dist_util.dev() 157 | ) 158 | ) 159 | 160 | dist_util.sync_params(self.model.parameters()) 161 | 162 | def _load_ema_parameters(self, rate): 163 | ema_params = copy.deepcopy(self.master_params) 164 | 165 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 166 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 167 | if ema_checkpoint: 168 | if dist.get_rank() == 0: 169 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 170 | state_dict = dist_util.load_state_dict( 171 | actual_model_path(ema_checkpoint), map_location=dist_util.dev() 172 | ) 173 | ema_params = self._state_dict_to_master_params(state_dict) 174 | 175 | dist_util.sync_params(ema_params) 176 | return ema_params 177 | 178 | def _load_optimizer_state(self): 179 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 180 | if bf.exists(main_checkpoint): 181 | logger.log(f"loading optimizer state from checkpoint: {main_checkpoint}") 182 | state_dict = dist_util.load_state_dict( 183 | actual_model_path(main_checkpoint), map_location=dist_util.dev() 184 | ) 185 | self.opt.load_state_dict(state_dict) 186 | 187 | def _setup_fp16(self): 188 | self.master_params = make_master_params(self.model_params) 189 | self.model.convert_to_fp16() 190 | 191 | def run_loop(self): 192 | while ( 193 | not self.learning_steps 194 | or self.step + self.resume_step < self.learning_steps 195 | ): 196 | batch, cond = next(self.data) 197 | self.run_step(batch, cond) 198 | # print(self.step, ":", time.time()) 199 | if self.step % self.log_interval == 0: 200 | logger.dumpkvs() 201 | if self.eval_data is not None and self.step % self.eval_interval == 0: 202 | batch_eval,
cond_eval = next(self.eval_data) 203 | self.forward_only(batch_eval, cond_eval) 204 | print('eval on validation set') 205 | logger.dumpkvs() 206 | if self.step % self.save_interval == 0: 207 | self.save() 208 | # Run for a finite amount of time in integration tests. 209 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 210 | return 211 | self.step += 1 212 | # print('done') 213 | # exit() 214 | # Save the last checkpoint if it wasn't already saved. 215 | if (self.step - 1) % self.save_interval != 0: 216 | self.save() 217 | 218 | def run_step(self, batch, cond): 219 | self.forward_backward(batch, cond, ) 220 | if self.use_fp16: 221 | self.optimize_fp16() 222 | else: 223 | self.optimize_normal() 224 | self.log_step() 225 | 226 | def forward_only(self, batch, cond): 227 | with th.no_grad(): 228 | zero_grad(self.model_params) 229 | for i in range(0, batch.shape[0], self.microbatch): 230 | micro = batch[i: i + self.microbatch].to(dist_util.dev()) 231 | micro_cond = { 232 | k: v[i: i + self.microbatch].to(dist_util.dev()) 233 | for k, v in cond.items() 234 | } 235 | last_batch = (i + self.microbatch) >= batch.shape[0] 236 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 237 | # print(micro_cond.keys()) 238 | 239 | if self.near_step !=0 and self.far_step != 0: loss_type = "near and far" 240 | elif self.simi_penalty == "kl": loss_type = "kl" 241 | else: loss_type = "" 242 | 243 | if loss_type == "near and far": 244 | compute_losses = functools.partial( 245 | self.diffusion.training_losses, 246 | self.ddp_model, 247 | loss_type, 248 | micro, 249 | t, 250 | model_kwargs=micro_cond, 251 | simi_penalty = self.simi_penalty, 252 | near_step = self.near_step, 253 | near_lambda = self.near_lambda, 254 | far_step = self.far_step, 255 | far_lambda = self.far_lambda, 256 | ) 257 | else: 258 | compute_losses = functools.partial( 259 | self.diffusion.training_losses, 260 | self.ddp_model, 261 | loss_type, 262 | micro, 263 | t, 264 | model_kwargs=micro_cond, 265 | simi_penalty = self.simi_penalty, 266 | simi_lambda = self.simi_lambda, 267 | simi_step = self.simi_step, 268 | simi_noise = self.simi_noise, 269 | ) 270 | 271 | 272 | if last_batch or not self.use_ddp: 273 | losses = compute_losses() 274 | else: 275 | with self.ddp_model.no_sync(): 276 | losses = compute_losses() 277 | 278 | log_loss_dict( 279 | self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()} 280 | ) 281 | 282 | 283 | def forward_backward(self, batch, cond): 284 | zero_grad(self.model_params) 285 | for i in range(0, batch.shape[0], self.microbatch): 286 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 287 | micro_cond = { 288 | k: v[i : i + self.microbatch].to(dist_util.dev()) 289 | for k, v in cond.items() 290 | } 291 | last_batch = (i + self.microbatch) >= batch.shape[0] 292 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 293 | # print(micro_cond.keys()) 294 | 295 | if self.near_step !=0 and self.far_step != 0: loss_type = "near and far" 296 | elif self.simi_penalty == "kl": loss_type = "kl" 297 | else: loss_type = "" 298 | 299 | if loss_type == "near and far": 300 | compute_losses = functools.partial( 301 | self.diffusion.training_losses, 302 | self.ddp_model, 303 | loss_type, 304 | micro, 305 | t, 306 | model_kwargs=micro_cond, 307 | simi_penalty = self.simi_penalty, 308 | near_step = self.near_step, 309 | near_lambda = self.near_lambda, 310 | far_step = self.far_step, 311 | far_lambda = self.far_lambda, 312 | ) 313 | else: 314 | 
compute_losses = functools.partial( 315 | self.diffusion.training_losses, 316 | self.ddp_model, 317 | loss_type, 318 | micro, 319 | t, 320 | model_kwargs=micro_cond, 321 | simi_penalty = self.simi_penalty, 322 | simi_lambda = self.simi_lambda, 323 | simi_step = self.simi_step, 324 | simi_noise = self.simi_noise, 325 | ) 326 | 327 | if last_batch or not self.use_ddp: 328 | losses = compute_losses() 329 | else: 330 | with self.ddp_model.no_sync(): 331 | losses = compute_losses() 332 | 333 | if isinstance(self.schedule_sampler, LossAwareSampler): 334 | self.schedule_sampler.update_with_local_losses( 335 | t, losses["loss"].detach() 336 | ) 337 | 338 | loss = (losses["loss"] * weights).mean() 339 | log_loss_dict( 340 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 341 | ) 342 | if self.use_fp16: 343 | raise NotImplementedError('use_fp16 is disabled in this training loop') 344 | loss_scale = 2 ** self.lg_loss_scale 345 | (loss * loss_scale).backward() 346 | else: 347 | loss.backward() 348 | 349 | def optimize_fp16(self): 350 | if any(not th.isfinite(p.grad).all() for p in self.model_params): 351 | self.lg_loss_scale -= 1 352 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 353 | return 354 | 355 | model_grads_to_master_grads(self.model_params, self.master_params) 356 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 357 | self._log_grad_norm() 358 | self._anneal_lr() 359 | self.opt.step() 360 | for rate, params in zip(self.ema_rate, self.ema_params): 361 | update_ema(params, self.master_params, rate=rate) 362 | master_params_to_model_params(self.model_params, self.master_params) 363 | self.lg_loss_scale += self.fp16_scale_growth 364 | 365 | def grad_clip(self): 366 | # print('doing gradient clipping') 367 | max_grad_norm = self.gradient_clipping # e.g. 3.0 368 | if hasattr(self.opt, "clip_grad_norm"): 369 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 370 | self.opt.clip_grad_norm(max_grad_norm) 371 | # else: 372 | # assert False 373 | # elif hasattr(self.model, "clip_grad_norm_"): 374 | # # Some models (like FullyShardedDDP) have a specific way to do gradient clipping 375 | # self.model.clip_grad_norm_(args.max_grad_norm) 376 | else: 377 | # Revert to normal clipping otherwise, handling Apex or full precision 378 | th.nn.utils.clip_grad_norm_( 379 | self.model.parameters(), #amp.master_params(self.opt) if self.use_apex else 380 | max_grad_norm, 381 | ) 382 | 383 | def optimize_normal(self): 384 | if self.gradient_clipping > 0: 385 | self.grad_clip() 386 | self._log_grad_norm() 387 | # self._anneal_lr() 388 | self.opt.step() 389 | for rate, params in zip(self.ema_rate, self.ema_params): 390 | update_ema(params, self.master_params, rate=rate) 391 | 392 | def _log_grad_norm(self): 393 | sqsum = 0.0 394 | # cnt = 0 395 | for p in self.master_params: 396 | # print(cnt, p) ## DEBUG 397 | # print(cnt, p.grad) 398 | # cnt += 1 399 | if p.grad is not None: 400 | sqsum += (p.grad ** 2).sum().item() 401 | logger.logkv_mean("grad_norm", np.sqrt(sqsum)) 402 | 403 | def _anneal_lr(self): 404 | if not self.learning_steps: 405 | return 406 | frac_done = (self.step + self.resume_step) / self.learning_steps 407 | lr = self.lr * (1 - frac_done) 408 | for param_group in self.opt.param_groups: 409 | param_group["lr"] = lr 410 | 411 | def log_step(self): 412 | logger.logkv("step", self.step + self.resume_step) 413 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 414 | if self.use_fp16: 415 | logger.logkv("lg_loss_scale", self.lg_loss_scale) 416 | 417 | 
def save(self): 418 | def save_checkpoint(rate, params): 419 | state_dict = self._master_params_to_state_dict(params) 420 | if dist.get_rank() == 0: 421 | logger.log(f"saving model {rate}...") 422 | if not rate: 423 | filename = f"model{(self.step+self.resume_step):06d}.pt" 424 | else: 425 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 426 | print('writing to', bf.join(get_blob_logdir(), filename)) 427 | print('writing to', bf.join(self.checkpoint_path, filename)) 428 | # with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 429 | # th.save(state_dict, f) 430 | with bf.BlobFile(bf.join(self.checkpoint_path, filename), "wb") as f: # DEBUG ** 431 | th.save(state_dict, f) # save locally 432 | # pass # save empty 433 | 434 | # save_checkpoint(0, self.master_params) 435 | for rate, params in zip(self.ema_rate, self.ema_params): 436 | save_checkpoint(rate, params) 437 | 438 | dist.barrier() 439 | 440 | def _master_params_to_state_dict(self, master_params): 441 | if self.use_fp16: 442 | master_params = unflatten_master_params( 443 | list(self.model.parameters()), master_params # DEBUG ** 444 | ) 445 | state_dict = self.model.state_dict() 446 | for i, (name, _value) in enumerate(self.model.named_parameters()): 447 | assert name in state_dict 448 | state_dict[name] = master_params[i] 449 | return state_dict 450 | 451 | def _state_dict_to_master_params(self, state_dict): 452 | params = [state_dict[name] for name, _ in self.model.named_parameters()] 453 | if self.use_fp16: 454 | return make_master_params(params) 455 | else: 456 | return params 457 | 458 | 459 | def parse_resume_step_from_filename(filename): 460 | """ 461 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 462 | checkpoint's number of steps. 463 | """ 464 | if filename[-3:] == '.pt': 465 | return int(filename[-9:-3]) 466 | else: 467 | return 0 468 | 469 | 470 | def get_blob_logdir(): 471 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir()) 472 | 473 | 474 | def find_resume_checkpoint(): 475 | # On your infrastructure, you may want to override this to automatically 476 | # discover the latest checkpoint on your blob storage, etc. 477 | return None 478 | 479 | 480 | def find_ema_checkpoint(main_checkpoint, step, rate): 481 | if main_checkpoint is None: 482 | return None 483 | filename = f"ema_{rate}_{(step):06d}.pt" 484 | path = bf.join(bf.dirname(main_checkpoint), filename) 485 | if bf.exists(path): 486 | return path 487 | return None 488 | 489 | 490 | def log_loss_dict(diffusion, ts, losses): 491 | for key, values in losses.items(): 492 | logger.logkv_mean(key, values.mean().item()) 493 | # Log the quantiles (four quartiles, in particular). 494 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 495 | quartile = int(4 * sub_t / diffusion.num_timesteps) 496 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 497 | 498 | 499 | def actual_model_path(model_path): 500 | return model_path 501 | --------------------------------------------------------------------------------
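A typical post-train / decode / evaluate pass, sketched only from the argument parsers above (the folder names in angle brackets are placeholders, not repo defaults; scripts/run_train.py reads LOCAL_RANK, so launch it under torchrun or export LOCAL_RANK=0 for a single process):

LOCAL_RANK=0 python scripts/run_train.py --dataset qqp --data_dir datasets/QQP --diff_steps 2000 --noise_schedule sqrt
python scripts/run_decode.py --model_dir "post_train_diffusion_models/<run_folder>" --step 2000 --split test
python scripts/eval_seq2seq.py --folder "generation_outputs/<decoded_folder>" --mbr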