├── README.md
├── assets
│   ├── adaptive_sampling.png
│   ├── dis_diffusion_v2.png
│   └── model_arch.png
├── basic_utils.py
├── datasets
│   └── readme.md
├── diffuseq
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── gaussian_diffusion.cpython-37.pyc
│   │   ├── gaussian_diffusion.cpython-38.pyc
│   │   ├── gaussian_diffusion.cpython-39.pyc
│   │   ├── rounding.cpython-37.pyc
│   │   ├── rounding.cpython-38.pyc
│   │   ├── rounding.cpython-39.pyc
│   │   ├── step_sample.cpython-37.pyc
│   │   ├── step_sample.cpython-39.pyc
│   │   ├── text_datasets.cpython-37.pyc
│   │   ├── text_datasets.cpython-38.pyc
│   │   ├── text_datasets.cpython-39.pyc
│   │   ├── transformer_model.cpython-37.pyc
│   │   ├── transformer_model.cpython-38.pyc
│   │   └── transformer_model.cpython-39.pyc
│   ├── config.json
│   ├── gaussian_diffusion.py
│   ├── rounding copy.py
│   ├── rounding.py
│   ├── step_sample.py
│   ├── text_datasets.py
│   ├── transformer_model.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-37.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── __init__.cpython-39.pyc
│       │   ├── dist_util.cpython-37.pyc
│       │   ├── dist_util.cpython-38.pyc
│       │   ├── dist_util.cpython-39.pyc
│       │   ├── fp16_util.cpython-37.pyc
│       │   ├── fp16_util.cpython-39.pyc
│       │   ├── logger.cpython-37.pyc
│       │   ├── logger.cpython-38.pyc
│       │   ├── logger.cpython-39.pyc
│       │   ├── losses.cpython-37.pyc
│       │   ├── losses.cpython-38.pyc
│       │   ├── losses.cpython-39.pyc
│       │   ├── nn.cpython-37.pyc
│       │   └── nn.cpython-39.pyc
│       ├── dist_util.py
│       ├── fp16_util.py
│       ├── logger.py
│       ├── losses.py
│       └── nn.py
├── evaluation
│   ├── __pycache__
│   │   ├── tokenizer.cpython-37.pyc
│   │   └── tokenizer.cpython-39.pyc
│   ├── eval.py
│   ├── quick_eval.py
│   └── tokenizer.py
├── generation_cases
│   ├── qqp_20
│   │   ├── mbr.jsonl
│   │   ├── out.log
│   │   ├── seed10_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed11_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed12_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed13_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed14_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed20_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed21_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed22_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   ├── seed23_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   │   └── seed24_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│   └── ts-20
│       ├── mbr.jsonl
│       ├── mbr.jsonl_bk
│       ├── out.log
│       ├── out.log2
│       ├── seed11_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed13_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed14_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed15_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed23_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed24_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed25_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed35_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       ├── seed36_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
│       └── seed37_step2000_eta-1.0_respace_adp_20_test_topp0.0_tau1.0_topl0.0_topk0_noisedl0.0_scend_clamp_skip-1_pd_.json
├── requirement.txt
├── sample_seq2seq.py
├── scripts
│   ├── eval_seq2seq.py
│   ├── run_decode.py
│   └── run_train.py
├── train.py
└── train_util.py
/README.md:
--------------------------------------------------------------------------------
1 | # Bridging the Gap between Training and Inference for Diffusion Models
2 |
3 | This is the official code for [Can Diffusion Model Achieve Better Performance in Text Generation? Bridging the Gap between Training and Inference!](https://arxiv.org/pdf/2305.04465.pdf)
4 |
5 | # Highlight
6 |
7 | 1. You can post-train your own diffusion model with the two methods below to ``accelerate inference`` and ``achieve better performance``!
8 |
9 | 2. Extensive experiments show that our method can generate a full 128-token sequence in only ``4`` denoising steps!
10 |
11 | ## Model Architecture
12 | ![model architecture](assets/model_arch.png)
13 |
14 | ## Down-Sampling Strategy
15 |
16 | ![adaptive down-sampling](assets/adaptive_sampling.png)
17 |
18 | # Dataset & Model Preparation
19 |
20 | ## Dataset
21 | We provide download links for all the data used in our paper:
22 |
23 | | Task | Dataset | Samples | Used in our paper |
24 | |------|---------| ---------| ---------|
25 | |Text Simplification| [WIKI AUTO](https://github.com/chaojiang06/wiki-auto) | 677k | [download](https://drive.google.com/drive/folders/1yIo3qploLvtSc9CAzohAeKlHjOoNRfLg?usp=sharing)|
26 | | Paraphrase | [Quora Question Pairs](https://www.kaggle.com/c/quora-question-pairs) | 114k | [download](https://drive.google.com/drive/folders/1kclZh3KTS1IOD3tre6ybsX7UhRkwEPeW?usp=share_link)|
27 | | Story Generation | [ROC Story](https://cs.rochester.edu/nlp/rocstories/) | 88k | [download](https://drive.google.com/drive/folders/1bvjIroxJaACGIkACwSxCCJHPh1PBR3Zv?usp=sharing) |
28 | | Question Generation | [Quasar-T](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) | 117k | [download](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) |
29 | | E2E (Semantic / Syntax) | [E2E](http://www.macs.hw.ac.uk/) | 88k | [download](https://drive.google.com/drive/folders/1YJwa3SIqg2d0VkfzCrVEo8QtrZwkxcBX?usp=sharing) |
30 |
31 | Please download the data and place it under the ``./datasets`` folder.
32 |
33 | ## Backbone Model
34 | Please refer to the following repos for more details:
35 |
36 | [DiffuSeq: Sequence to Sequence Text Generation with Diffusion Models](https://github.com/Shark-NLP/DiffuSeq)
37 |
38 | [Diffusion-LM Improves Controllable Text Generation](https://github.com/XiangLi1999/Diffusion-LM)
39 |
40 | ``Note``: We also provide the two post-trained models ([link](https://drive.google.com/drive/folders/1UvcN9mKOv-nVZuQpJaAOQWCG22sGk_2O?usp=sharing)) for a quick check.
41 |
42 |
43 | # Quick Start
44 | We provide the code for post-training on the QQP (Paraphrase) dataset.
45 |
46 | ## Environment
47 | ```
48 | conda create -n diffusion python=3.9
49 | conda activate diffusion
50 | pip install -r requirement.txt
51 | ```
52 |
53 | ## Training
54 | We conduct experiments with 4 NVIDIA A100 (40GB) GPUs.
55 | ```bash
56 | cd scripts
57 | export CUDA_VISIBLE_DEVICES=0,1,2,3;
58 |
59 | DISTRIBUTE_ARGS="
60 | --nproc_per_node=4 \
61 | --use_env
62 | "
63 |
64 | TRAIN_ARGS="
65 | --diff_steps 2000 \
66 | --microbatch 100 \
67 | --lr 0.0001 \
68 | --learning_steps 320000 \
69 | --save_interval 2500 \
70 | --seed 109 \
71 | --noise_schedule sqrt \
72 | --hidden_dim 128 \
73 | --bsz 100 \
74 | --dataset qqp \
75 | --data_dir datasets/QQP \
76 | --vocab bert \
77 | --seq_len 128 \
78 | --simi_penalty l2_noise_random \
79 | --simi_lambda -2 \
80 | --simi_step 10 \
81 | --simi_noise 0.05 \
82 | --resume_checkpoint /path/to/checkpoint \
83 | --schedule_sampler lossaware \
84 | --notes qqp
85 | "
86 |
87 | python -m torch.distributed.launch $DISTRIBUTE_ARGS run_train.py $TRAIN_ARGS
88 |
89 | ```
90 |
91 |
92 | # Inference
93 |
94 | ```bash
95 | python sample_seq2seq.py \
96 | --model_path /path/to/checkpoint \
97 | --step 2000 \
98 | --batch_size 16 \
99 | --seed2 10 \
100 | --split test \
101 | --out_dir generation_outputs \
102 | --decode_respacing "adp_20"
103 | ```
104 |
105 |
106 |
107 | # Acknowledgement
108 | We appreciate the following open-source projects:
109 |
110 | [DiffuSeq](https://github.com/Shark-NLP/DiffuSeq)
111 | [Diffusion-LM](https://github.com/XiangLi1999/Diffusion-LM)
112 |
--------------------------------------------------------------------------------
/assets/adaptive_sampling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/assets/adaptive_sampling.png
--------------------------------------------------------------------------------
/assets/dis_diffusion_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/assets/dis_diffusion_v2.png
--------------------------------------------------------------------------------
/assets/model_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/assets/model_arch.png
--------------------------------------------------------------------------------
/basic_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import json, os
4 | import time
5 |
6 | from diffuseq import gaussian_diffusion as gd
7 | from diffuseq.gaussian_diffusion import SpacedDiffusion, space_timesteps
8 | from diffuseq.transformer_model import TransformerNetModel
9 | from transformers import AutoTokenizer, PreTrainedTokenizerFast
10 |
11 | class myTokenizer():
12 | """
13 | Load tokenizer from bert config or defined BPE vocab dict
14 | """
15 | ################################################
16 |     ### You can customize your own tokenizer here. ###
17 | ################################################
18 | def __init__(self, args):
19 | if args.vocab == 'bert':
20 | tokenizer = AutoTokenizer.from_pretrained(args.config_name)
21 | self.tokenizer = tokenizer
22 | self.sep_token_id = tokenizer.sep_token_id
23 | self.pad_token_id = tokenizer.pad_token_id
24 | # save
25 | tokenizer.save_pretrained(args.checkpoint_path)
26 | else:
27 | # load vocab from the path
28 | print('#'*30, 'load vocab from', args.vocab)
29 | vocab_dict = {'[START]': 0, '[END]': 1, '[UNK]':2, '[PAD]':3}
30 | with open(args.vocab, 'r', encoding='utf-8') as f:
31 | for row in f:
32 | vocab_dict[row.strip().split(' ')[0]] = len(vocab_dict)
33 | self.tokenizer = vocab_dict
34 | self.rev_tokenizer = {v: k for k, v in vocab_dict.items()}
35 | self.sep_token_id = vocab_dict['[END]']
36 | self.pad_token_id = vocab_dict['[PAD]']
37 | # save
38 | if int(os.environ['LOCAL_RANK']) == 0:
39 | path_save_vocab = f'{args.checkpoint_path}/vocab.json'
40 | with open(path_save_vocab, 'w') as f:
41 | json.dump(vocab_dict, f)
42 |
43 | self.vocab_size = len(self.tokenizer)
44 | args.vocab_size = self.vocab_size # update vocab size in args
45 |
46 | def encode_token(self, sentences):
47 | if isinstance(self.tokenizer, dict):
48 | input_ids = [[0] + [self.tokenizer.get(x, self.tokenizer['[UNK]']) for x in seq.split()] + [1] for seq in sentences]
49 | elif isinstance(self.tokenizer, PreTrainedTokenizerFast):
50 | input_ids = self.tokenizer(sentences, add_special_tokens=True)['input_ids']
51 | else:
52 | assert False, "invalid type of vocab_dict"
53 | return input_ids
54 |
55 | def decode_token(self, seq):
56 | if isinstance(self.tokenizer, dict):
57 | seq = seq.squeeze(-1).tolist()
58 | while len(seq)>0 and seq[-1] == self.pad_token_id:
59 | seq.pop()
60 | tokens = " ".join([self.rev_tokenizer[x] for x in seq]).replace('__ ', '').replace('@@ ', '')
61 | elif isinstance(self.tokenizer, PreTrainedTokenizerFast):
62 | seq = seq.squeeze(-1).tolist()
63 | while len(seq)>0 and seq[-1] == self.pad_token_id:
64 | seq.pop()
65 | tokens = self.tokenizer.decode(seq)
66 | else:
67 | assert False, "invalid type of vocab_dict"
68 | return tokens
69 |
70 |
71 | def load_model_emb(args, tokenizer):
72 |     ### Random emb or a pre-defined embedding like GloVe. You can customize your own init here.
73 | model = torch.nn.Embedding(tokenizer.vocab_size, args.hidden_dim)
74 | path_save = '{}/random_emb.torch'.format(args.checkpoint_path)
75 | if int(os.environ['LOCAL_RANK']) == 0:
76 | if os.path.exists(path_save):
77 | print('reload the random embeddings', model)
78 | model.load_state_dict(torch.load(path_save))
79 | else:
80 | print('initializing the random embeddings', model)
81 | torch.nn.init.normal_(model.weight)
82 | torch.save(model.state_dict(), path_save)
83 | else:
84 | while not os.path.exists(path_save):
85 | time.sleep(1)
86 | print('reload the random embeddings', model)
87 | model.load_state_dict(torch.load(path_save))
88 |
89 | return model, tokenizer
90 |
91 |
92 | def load_tokenizer(args):
93 | tokenizer = myTokenizer(args)
94 | return tokenizer
95 |
96 | def load_defaults_config():
97 | """
98 | Load defaults for training args.
99 | """
100 | with open('diffuseq/config.json', 'r') as f:
101 | return json.load(f)
102 |
103 |
104 | def create_model_and_diffusion(
105 | hidden_t_dim,
106 | hidden_dim,
107 | vocab_size,
108 | config_name,
109 | use_plm_init,
110 | dropout,
111 | diffusion_steps,
112 | noise_schedule,
113 | learn_sigma,
114 | timestep_respacing,
115 | predict_xstart,
116 | rescale_timesteps,
117 | sigma_small,
118 | rescale_learned_sigmas,
119 | use_kl,
120 | notes,
121 | **kwargs,
122 | ):
123 | model = TransformerNetModel(
124 | input_dims=hidden_dim,
125 | output_dims=(hidden_dim if not learn_sigma else hidden_dim*2),
126 | hidden_t_dim=hidden_t_dim,
127 | dropout=dropout,
128 | config_name=config_name,
129 | vocab_size=vocab_size,
130 | init_pretrained=use_plm_init
131 | )
132 |
133 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
134 |
135 | if not timestep_respacing:
136 | timestep_respacing = [diffusion_steps]
137 | elif timestep_respacing.startswith("ddim"):
138 | pass
139 | elif timestep_respacing.startswith("x2"):
140 | pass
141 | elif timestep_respacing.startswith("adp"):
142 | pass
143 | else:
144 | timestep_respacing = json.loads(timestep_respacing)
145 |
146 | diffusion = SpacedDiffusion(
147 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
148 | betas=betas,
149 | rescale_timesteps=rescale_timesteps,
150 | predict_xstart=predict_xstart,
151 | learn_sigmas = learn_sigma,
152 | sigma_small = sigma_small,
153 | use_kl = use_kl,
154 | rescale_learned_sigmas=rescale_learned_sigmas
155 | )
156 |
157 | return model, diffusion
158 |
159 |
160 | def add_dict_to_argparser(parser, default_dict):
161 | for k, v in default_dict.items():
162 | v_type = type(v)
163 | if v is None:
164 | v_type = str
165 | elif isinstance(v, bool):
166 | v_type = str2bool
167 | parser.add_argument(f"--{k}", default=v, type=v_type)
168 |
169 |
170 | def args_to_dict(args, keys):
171 | return {k: getattr(args, k) for k in keys}
172 |
173 |
174 | def str2bool(v):
175 | """
176 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
177 | """
178 | if isinstance(v, bool):
179 | return v
180 | if v.lower() in ("yes", "true", "t", "y", "1"):
181 | return True
182 | elif v.lower() in ("no", "false", "f", "n", "0"):
183 | return False
184 | else:
185 | raise argparse.ArgumentTypeError("boolean value expected")
186 |
--------------------------------------------------------------------------------
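A minimal usage sketch (not part of the repo) for the argparse helpers above, assuming it runs from the repo root so that `load_defaults_config` can find `diffuseq/config.json`: `add_dict_to_argparser` turns each config key into a `--flag` (booleans routed through `str2bool`), and `args_to_dict` extracts the subset a downstream call needs.

```python
import argparse
from basic_utils import load_defaults_config, add_dict_to_argparser, args_to_dict

parser = argparse.ArgumentParser()
defaults = load_defaults_config()        # reads diffuseq/config.json
add_dict_to_argparser(parser, defaults)  # one --flag per key; bool defaults use str2bool

# override a float and a bool from the command line
args = parser.parse_args(['--lr', '0.001', '--use_fp16', 'true'])
print(args.lr, args.use_fp16)            # 0.001 True

# pick out only the keys another function needs
print(args_to_dict(args, ['hidden_dim', 'noise_schedule']))
# {'hidden_dim': 128, 'noise_schedule': 'sqrt'}
```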
/datasets/readme.md:
--------------------------------------------------------------------------------
1 | # Dataset Preparation
2 |
3 | | Task | Dataset | Samples | Used in our paper |
4 | |------|---------| ---------| ---------|
5 | |Text Simplification| [WIKI AUTO](https://github.com/chaojiang06/wiki-auto) | 677k | [download](https://drive.google.com/drive/folders/1yIo3qploLvtSc9CAzohAeKlHjOoNRfLg?usp=sharing)|
6 | | Paraphrase | [Quora Question Pairs](https://www.kaggle.com/c/quora-question-pairs) | 114k | [download](https://drive.google.com/drive/folders/1kclZh3KTS1IOD3tre6ybsX7UhRkwEPeW?usp=share_link)|
7 | | Story Generation | [ROC Story](https://cs.rochester.edu/nlp/rocstories/) | 88k | [download](https://drive.google.com/drive/folders/1bvjIroxJaACGIkACwSxCCJHPh1PBR3Zv?usp=sharing) |
8 | | Question Generation | [Quasar-T](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) | 117k | [download](https://drive.google.com/drive/folders/122YK0IElSnGZbPMigXrduTVL1geB4wEW?usp=sharing) |
9 | | E2E (Semantic / Syntax) | [E2E](http://www.macs.hw.ac.uk/) | 88k | [download](https://drive.google.com/drive/folders/1YJwa3SIqg2d0VkfzCrVEo8QtrZwkxcBX?usp=sharing) |
10 |
--------------------------------------------------------------------------------
/diffuseq/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__init__.py
--------------------------------------------------------------------------------
/diffuseq/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/gaussian_diffusion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/gaussian_diffusion.cpython-37.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/gaussian_diffusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/gaussian_diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/gaussian_diffusion.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/gaussian_diffusion.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/rounding.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/rounding.cpython-37.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/rounding.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/rounding.cpython-38.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/rounding.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/rounding.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/step_sample.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/step_sample.cpython-37.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/step_sample.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/step_sample.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/text_datasets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/text_datasets.cpython-37.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/text_datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/text_datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/text_datasets.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/text_datasets.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/transformer_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/transformer_model.cpython-37.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/transformer_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/transformer_model.cpython-38.pyc
--------------------------------------------------------------------------------
/diffuseq/__pycache__/transformer_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/diffuseq/__pycache__/transformer_model.cpython-39.pyc
--------------------------------------------------------------------------------
/diffuseq/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "lr": 0.0001,
3 | "batch_size": 2048,
4 | "microbatch": 64,
5 | "learning_steps": 320000,
6 | "log_interval": 5,
7 | "save_interval": 100,
8 | "eval_interval": 5000,
9 | "ema_rate": "0.9999",
10 | "resume_checkpoint": "none",
11 | "schedule_sampler": "lossaware",
12 | "diffusion_steps": 2000,
13 | "noise_schedule": "sqrt",
14 | "timestep_respacing": "",
15 | "vocab": "bert",
16 | "use_plm_init": "no",
17 | "vocab_size": 0,
18 | "config_name": "huggingface-config",
19 | "notes": "folder-notes",
20 | "data_dir": "data-dir",
21 | "dataset": "dataset-name",
22 | "checkpoint_path": "checkpoint-path",
23 | "seq_len": 128,
24 | "hidden_t_dim": 128,
25 | "hidden_dim": 128,
26 | "dropout": 0.1,
27 | "use_fp16": false,
28 | "fp16_scale_growth": 0.001,
29 | "seed": 102,
30 | "gradient_clipping": -1.0,
31 | "weight_decay": 0.0,
32 | "learn_sigma": false,
33 | "use_kl": false,
34 | "predict_xstart": true,
35 | "rescale_timesteps": true,
36 | "rescale_learned_sigmas": false,
37 | "sigma_small": false,
38 | "emb_scale_factor": 1.0,
39 | "simi_penalty": "",
40 | "simi_lambda": 0.01,
41 | "simi_step": 10,
42 | "near_step": 0,
43 | "far_step": 0,
44 | "near_lambda": 0.0,
45 | "far_lambda": 0.0,
46 | "simi_noise": 0.05
47 | }
48 |
--------------------------------------------------------------------------------
/diffuseq/rounding copy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # bert results
3 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, GPT2TokenizerFast
4 | import sys, yaml, os
5 | import json
6 |
7 | import numpy as np
8 |
9 | def get_knn(model_emb, text_emb, dist='cos'):
10 | if dist == 'cos':
11 | adjacency = model_emb @ text_emb.transpose(1, 0).to(model_emb.device)
12 | elif dist == 'l2':
13 | adjacency = model_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
14 | model_emb.size(0), -1, -1)
15 | adjacency = -torch.norm(adjacency, dim=-1)
16 | topk_out = torch.topk(adjacency, k=6, dim=0)
17 | return topk_out.values, topk_out.indices
18 |
19 | def get_efficient_knn(model_emb, text_emb):
20 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab
21 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
22 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
23 | # print(emb_norm.shape, arr_norm.shape)
24 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen)
25 | dist = torch.clamp(dist, 0.0, np.inf)
26 | # print(dist.shape)
27 | topk_out = torch.topk(-dist, k=1, dim=0)
28 | return topk_out.values, topk_out.indices
29 |
30 | def rounding_func(text_emb_lst, model, tokenizer, emb_scale_factor=1.0):
31 | decoded_out_lst = []
32 |
33 | model_emb = model.weight # input_embs
34 | down_proj_emb2 = None
35 |
36 | dist = 'l2'
37 |
38 | for text_emb in text_emb_lst:
39 | import torch
40 | text_emb = torch.tensor(text_emb)
41 | # print(text_emb.shape)
42 | if len(text_emb.shape) > 2:
43 | text_emb = text_emb.view(-1, text_emb.size(-1))
44 | else:
45 | text_emb = text_emb
46 | val, indices = get_knn((down_proj_emb2 if dist == 'cos' else model_emb),
47 | text_emb.to(model_emb.device), dist=dist)
48 |
49 | decoded_out_lst.append(tokenizer.decode_token(indices[0]))
50 |
51 | return decoded_out_lst
52 |
53 | def compute_logp(args, model, x, input_ids):
54 | word_emb = model.weight
55 | sigma = 0.1
56 | if args.model_arch == '1d-unet':
57 | x = x.permute(0, 2, 1)
58 |
59 | bsz, seqlen, dim = x.shape
60 |
61 | x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) # 1, bsz*sample*seqlen, dim
62 | word_emb_flat = word_emb.unsqueeze(1) # vocab, 1, dim
63 | diff = (x_flat - word_emb_flat) ** 2 # vocab, seqlen, dim
64 |
65 | logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) # vocab, seqlen
66 | logp_expanded = logp_expanded.permute((1, 0))
67 |
68 | ce = torch.nn.CrossEntropyLoss(reduction='none')
69 | loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen)
70 |
71 | return loss
72 |
73 | def get_weights(model, args):
74 | if hasattr(model, 'transformer'):
75 | input_embs = model.transformer.wte # input_embs
76 | down_proj = model.down_proj
77 | model_emb = down_proj(input_embs.weight)
78 | print(model_emb.shape)
79 | model = torch.nn.Embedding(model_emb.size(0), model_emb.size(1))
80 | print(args.emb_scale_factor)
81 | model.weight.data = model_emb * args.emb_scale_factor
82 |
83 | elif hasattr(model, 'weight'):
84 | pass
85 | else:
86 |         raise NotImplementedError
87 |
88 | model.weight.requires_grad = False
89 | return model
90 |
91 | def denoised_fn_round(args, model, text_emb, t):
92 | # print(text_emb.shape) # bsz, seqlen, dim
93 | model_emb = model.weight # input_embs
94 | # print(t)
95 | old_shape = text_emb.shape
96 | old_device = text_emb.device
97 |
98 | if len(text_emb.shape) > 2:
99 | text_emb = text_emb.reshape(-1, text_emb.size(-1))
100 | else:
101 | text_emb = text_emb
102 | # val, indices = get_knn(model_emb, text_emb.to(model_emb.device), dist=dist)
103 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device))
104 | rounded_tokens = indices[0]
105 | # print(rounded_tokens.shape)
106 | new_embeds = model(rounded_tokens).view(old_shape).to(old_device)
107 |
108 | return new_embeds
--------------------------------------------------------------------------------
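`compute_logp` above treats rounding as classification under an isotropic Gaussian: the logit for word `w` at a position is `-||x - e_w||^2 / (2 * sigma^2)`, and a cross-entropy against the reference ids gives a per-token loss. A self-contained sketch of that computation (dummy shapes, no `args` object, `1d-unet` branch omitted):

```python
import torch

bsz, seqlen, dim, vocab = 2, 5, 16, 50
emb = torch.nn.Embedding(vocab, dim)               # word embedding table
x = torch.randn(bsz, seqlen, dim)                  # denoised embeddings
input_ids = torch.randint(0, vocab, (bsz, seqlen)) # reference token ids
sigma = 0.1

x_flat = x.reshape(-1, dim).unsqueeze(0)           # (1, bsz*seqlen, dim)
w = emb.weight.unsqueeze(1)                        # (vocab, 1, dim)
logits = -((x_flat - w) ** 2).sum(-1) / (2 * sigma ** 2)  # (vocab, bsz*seqlen)
logits = logits.permute(1, 0)                      # (bsz*seqlen, vocab)

ce = torch.nn.CrossEntropyLoss(reduction='none')
loss = ce(logits, input_ids.view(-1)).view(bsz, seqlen)
print(loss.shape)  # torch.Size([2, 5])
```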
/diffuseq/rounding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # bert results
3 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, GPT2TokenizerFast
4 | import sys, yaml, os
5 | import json
6 |
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 | def get_knn(model_emb, text_emb, dist='cos'):  # still called by rounding_func below
11 |     if dist == 'cos':
12 |         adjacency = model_emb @ text_emb.transpose(1, 0).to(model_emb.device)
13 |     elif dist == 'l2':
14 |         adjacency = model_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
15 |             model_emb.size(0), -1, -1)
16 |         adjacency = -torch.norm(adjacency, dim=-1)
17 |     topk_out = torch.topk(adjacency, k=6, dim=0)
18 |     return topk_out.values, topk_out.indices
19 |
20 | # def enforce_repetition_penalty_(self, lprobs, batch_size, prev_output_tokens, repetition_penalty):
21 | # """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
22 | # for i in range(batch_size):
23 | # for previous_token in set(prev_output_tokens[i].tolist()):
24 | # # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
25 | # if lprobs[i, previous_token] < 0:
26 | # lprobs[i, previous_token] *= repetition_penalty
27 | # else:
28 | # lprobs[i, previous_token] /= repetition_penalty
29 |
30 | def get_efficient_knn(model_emb, text_emb):
31 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab
32 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
33 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
34 | # print(emb_norm.shape, arr_norm.shape)
35 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen)
36 | dist = torch.clamp(dist, 0.0, np.inf)
37 | # print(dist.shape)
38 | topk_out = torch.topk(-dist, k=1, dim=0)
39 | return topk_out.values, topk_out.indices
40 |
41 | def get_efficient_knn_top_k(model_emb, text_emb, top_k, tau = 1.0):
42 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab
43 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
44 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
45 | # print(emb_norm.shape, arr_norm.shape)
46 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen)
47 | dist = torch.clamp(dist, 0.0, np.inf)
48 | # print(dist.shape)
49 | topk_out = torch.topk(-dist, k=top_k, dim=0)
50 |
51 | sftmx = torch.nn.Softmax(dim=-1)
52 | indices = topk_out.indices.transpose(0,1)
53 | values = sftmx(topk_out.values.transpose(0,1)/ tau)
54 | idx = torch.multinomial(values, 1)
55 | indices = torch.gather(indices, -1, idx).transpose(0,1)
56 | values = torch.gather(values, -1, idx).transpose(0,1)
57 |
58 | return values, indices
59 |
60 | def get_efficient_knn_top_p(model_emb, text_emb, top_p, tau = 1.0, scale = 1.0):
61 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab
62 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
63 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
64 | # print(emb_norm.shape, arr_norm.shape)
65 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen)
66 | dist = torch.clamp(dist, 0.0, np.inf)
67 | # print(dist.shape)
68 | # topk_out = torch.topk(-dist, k=top_k, dim=0)
69 |
70 | # dist =
71 | dist = -dist.transpose(0,1)
72 | sorted_logits, sorted_indices = torch.sort(dist, descending=True)
73 | if scale == "last":
74 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits / tau, dim=-1), dim=-1)
75 | else:
76 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
77 | sorted_indices_to_remove = cumulative_probs > top_p
78 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
79 | sorted_indices_to_remove[..., 0] = 0
80 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
81 | dist[indices_to_remove] = -float("Inf")
82 |
83 | # tau *= scale
84 | probs = F.softmax(dist / tau, dim=-1)
85 |
86 | idx = torch.multinomial(probs, 1)
87 |
88 | return -dist.transpose(0,1), idx.transpose(0,1)
89 |
90 | def get_efficient_cos_top_p(model_emb, text_emb, top_p, tau = 1.0, scale = 1.0):
91 | dist = model_emb @ text_emb.transpose(1, 0).to(model_emb.device)
92 |
93 | # dist =
94 | dist = dist.transpose(0,1)
95 | sorted_logits, sorted_indices = torch.sort(dist, descending=True)
96 | if scale == "last":
97 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits / tau, dim=-1), dim=-1)
98 | else:
99 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
100 | sorted_indices_to_remove = cumulative_probs > top_p
101 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
102 | sorted_indices_to_remove[..., 0] = 0
103 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
104 | dist[indices_to_remove] = -float("Inf")
105 |
106 | # tau *= scale
107 | probs = F.softmax(dist / tau, dim=-1)
108 |
109 | idx = torch.multinomial(probs, 1)
110 |
111 | return -dist.transpose(0,1), idx.transpose(0,1)
112 |
113 |
114 | def get_efficient_knn_top_l(model_emb, text_emb, top_p):
115 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) # vocab
116 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
117 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
118 | # print(emb_norm.shape, arr_norm.shape)
119 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) # (vocab, d) x (d, bsz*seqlen)
120 | dist = torch.clamp(dist, 0.0, np.inf)
121 | # print(dist.shape)
122 |     # top_p doubles as an integer k in this variant; without this line,
123 |     # the indices/values below would be undefined
124 |     topk_out = torch.topk(-dist, k=top_p, dim=0)
125 | sftmx = torch.nn.Softmax(dim=-1)
126 | indices = topk_out.indices.transpose(0,1)
127 | values = sftmx(topk_out.values.transpose(0,1))
128 | idx = torch.multinomial(values, 1)
129 | indices = torch.gather(indices, -1, idx).transpose(0,1)
130 | values = torch.gather(values, -1, idx).transpose(0,1)
131 |
132 | return values, indices
133 |
134 | def rounding_func(text_emb_lst, model, tokenizer, emb_scale_factor=1.0, scale=1.0):
135 | decoded_out_lst = []
136 |
137 | model_emb = model.weight # input_embs
138 | down_proj_emb2 = None
139 |
140 | dist = 'l2'
141 |
142 | for text_emb in text_emb_lst:
143 | import torch
144 | text_emb = torch.tensor(text_emb)
145 | # print(text_emb.shape)
146 | if len(text_emb.shape) > 2:
147 | text_emb = text_emb.view(-1, text_emb.size(-1))
148 | else:
149 | text_emb = text_emb
150 | val, indices = get_knn((down_proj_emb2 if dist == 'cos' else model_emb),
151 | text_emb.to(model_emb.device), dist=dist)
152 |
153 | decoded_out_lst.append(tokenizer.decode_token(indices[0]))
154 |
155 | return decoded_out_lst
156 |
157 | def compute_logp(args, model, x, input_ids):
158 | word_emb = model.weight
159 | sigma = 0.1
160 | if args.model_arch == '1d-unet':
161 | x = x.permute(0, 2, 1)
162 |
163 | bsz, seqlen, dim = x.shape
164 |
165 | x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) # 1, bsz*sample*seqlen, dim
166 | word_emb_flat = word_emb.unsqueeze(1) # vocab, 1, dim
167 | diff = (x_flat - word_emb_flat) ** 2 # vocab, seqlen, dim
168 |
169 | logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) # vocab, seqlen
170 | logp_expanded = logp_expanded.permute((1, 0))
171 |
172 | ce = torch.nn.CrossEntropyLoss(reduction='none')
173 | loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen)
174 |
175 | return loss
176 |
177 | def get_weights(model, args):
178 | if hasattr(model, 'transformer'):
179 | input_embs = model.transformer.wte # input_embs
180 | down_proj = model.down_proj
181 | model_emb = down_proj(input_embs.weight)
182 | print(model_emb.shape)
183 | model = torch.nn.Embedding(model_emb.size(0), model_emb.size(1))
184 | print(args.emb_scale_factor)
185 | model.weight.data = model_emb * args.emb_scale_factor
186 |
187 | elif hasattr(model, 'weight'):
188 | pass
189 | else:
190 |         raise NotImplementedError
191 |
192 | model.weight.requires_grad = False
193 | return model
194 |
195 | def denoised_fn_round(args, model, text_emb, t):
196 | # print(text_emb.shape) # bsz, seqlen, dim
197 | model_emb = model.weight # input_embs
198 | # print(t)
199 | if args.clamp_skip == 1:
200 | if t[0].item() % 2 == 0:
201 | print(t[0].item(), ": clamp skip")
202 | return text_emb
203 |
204 | elif args.clamp_skip == 0:
205 | if t[0].item() % 2 == 1:
206 | # print(t[0].item(), ": clamp skip")
207 | return text_emb
208 |
209 | # print(t[0].item(), ": clamp do")
210 |
211 | old_shape = text_emb.shape
212 | old_device = text_emb.device
213 |
214 | if len(text_emb.shape) > 2:
215 | text_emb = text_emb.reshape(-1, text_emb.size(-1))
216 | else:
217 | text_emb = text_emb
218 | # val, indices = get_knn(model_emb, text_emb.to(model_emb.device), dist=dist)
219 |
220 | if args.top_k != 0:
221 | # try:
222 | # T = (sum(json.loads(args.timestep_respacing)) if args.timestep_respacing else args.step) - 1
223 | # except:
224 | # T = int(args.timestep_respacing[4:])
225 | # scale = (t[0].item() / T + args.scale_end * (1 - t[0].item() / T))
226 | scale = 1
227 | top_k = max(int(args.top_k * scale), 1)
228 | val, indices = get_efficient_knn_top_k(model_emb, text_emb.to(model_emb.device), top_k, args.tau)
229 | elif args.top_p != 0:
230 | if args.scale_end == "last":
231 | if t[0].item() == 0:
232 | print(t[0].item(), ":do topp")
233 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end)
234 | else:
235 | print(t[0].item(), ":do greedy")
236 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device))
237 | elif args.scale_end == "odd":
238 | if t[0].item() % 2 == 1:
239 | print(t[0].item(), ":do topp")
240 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end)
241 | else:
242 | print(t[0].item(), ":do greedy")
243 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device))
244 | elif args.scale_end == "":
245 | print(t[0].item(), ":do topp")
246 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end)
247 | elif args.scale_end.startswith('last_'):
248 | t_topp = int(args.scale_end.split('_')[-1])
249 | if t[0].item() < t_topp:
250 | print(t[0].item(), ":do topp")
251 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end)
252 | else:
253 | print(t[0].item(), ":do greedy")
254 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device))
255 | elif args.scale_end.startswith('first_'):
256 | t_topp = int(args.scale_end.split('_')[-1])
257 | if t[0].item() >= t_topp:
258 | print(t[0].item(), ":do topp")
259 | val, indices = get_efficient_knn_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end)
260 | else:
261 | print(t[0].item(), ":do greedy")
262 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device))
263 | elif args.scale_end == 'cos':
264 | print(t[0].item(), ":do cos topp")
265 | val, indices = get_efficient_cos_top_p(model_emb, text_emb.to(model_emb.device), args.top_p, args.tau, args.scale_end)
266 | else:
267 | raise NotImplementedError("Unkown args.scale_end:", args.scale_end)
268 |
269 |
270 | else:
271 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device))
272 | rounded_tokens = indices[0]
273 | # print(rounded_tokens.shape)
274 | new_embeds = model(rounded_tokens).view(old_shape).to(old_device)
275 |
276 | return new_embeds
--------------------------------------------------------------------------------
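All of the `get_efficient_knn*` variants above use the same expansion of squared Euclidean distance, `||e - x||^2 = ||e||^2 + ||x||^2 - 2 e·x`, so that a single matrix multiply scores every (vocab word, token position) pair at once. A standalone sketch of the trick with random tensors (not repo code), cross-checked against a brute-force `torch.cdist` argmin:

```python
import torch

vocab, d, n = 50, 16, 8
model_emb = torch.randn(vocab, d)   # embedding table (vocab, d)
text_emb = torch.randn(n, d)        # denoised vectors (bsz*seqlen, d)

emb_norm = (model_emb ** 2).sum(-1).view(-1, 1)   # (vocab, 1)
arr_norm = (text_emb ** 2).sum(-1).view(-1, 1)    # (n, 1)
dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * model_emb @ text_emb.t()  # (vocab, n)
dist = torch.clamp(dist, 0.0, float('inf'))       # guard tiny negative values

nearest = torch.topk(-dist, k=1, dim=0).indices   # (1, n) nearest token ids

# brute-force check: true distances, argmin over the vocab axis (ties aside)
brute = torch.cdist(text_emb, model_emb).argmin(dim=-1)  # (n,)
assert torch.equal(nearest[0], brute)
```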
/diffuseq/step_sample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "lossaware":
18 | return LossSecondMomentResampler(diffusion)
19 | elif name == "fixstep":
20 | return FixSampler(diffusion)
21 | else:
22 | raise NotImplementedError(f"unknown schedule sampler: {name}")
23 |
24 |
25 | class ScheduleSampler(ABC):
26 | """
27 | A distribution over timesteps in the diffusion process, intended to reduce
28 | variance of the objective.
29 |
30 | By default, samplers perform unbiased importance sampling, in which the
31 | objective's mean is unchanged.
32 | However, subclasses may override sample() to change how the resampled
33 | terms are reweighted, allowing for actual changes in the objective.
34 | """
35 |
36 | @abstractmethod
37 | def weights(self):
38 | """
39 | Get a numpy array of weights, one per diffusion step.
40 |
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 |
48 | :param batch_size: the number of timesteps.
49 | :param device: the torch device to save to.
50 | :return: a tuple (timesteps, weights):
51 | - timesteps: a tensor of timestep indices.
52 | - weights: a tensor of weights to scale the resulting losses.
53 | """
54 | w = self.weights()
55 | p = w / np.sum(w)
56 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
57 | indices = th.from_numpy(indices_np).long().to(device)
58 | weights_np = 1 / (len(p) * p[indices_np])
59 | weights = th.from_numpy(weights_np).float().to(device)
60 | return indices, weights
61 |
62 |
63 | class UniformSampler(ScheduleSampler):
64 | def __init__(self, diffusion):
65 | self.diffusion = diffusion
66 | self._weights = np.ones([diffusion.num_timesteps])
67 |
68 | def weights(self):
69 | return self._weights
70 |
71 | class FixSampler(ScheduleSampler):
72 | def __init__(self, diffusion):
73 | self.diffusion = diffusion
74 |
75 | ###############################################################
76 |         ### You can customize your own sampling weights over steps here. ###
77 | ###############################################################
78 | self._weights = np.concatenate([np.ones([diffusion.num_timesteps//2]), np.zeros([diffusion.num_timesteps//2]) + 0.5])
79 |
80 | def weights(self):
81 | return self._weights
82 |
83 |
84 | class LossAwareSampler(ScheduleSampler):
85 | def update_with_local_losses(self, local_ts, local_losses):
86 | """
87 | Update the reweighting using losses from a model.
88 |
89 | Call this method from each rank with a batch of timesteps and the
90 | corresponding losses for each of those timesteps.
91 | This method will perform synchronization to make sure all of the ranks
92 | maintain the exact same reweighting.
93 |
94 | :param local_ts: an integer Tensor of timesteps.
95 | :param local_losses: a 1D Tensor of losses.
96 | """
97 | batch_sizes = [
98 | th.tensor([0], dtype=th.int32, device=local_ts.device)
99 | for _ in range(dist.get_world_size())
100 | ]
101 | dist.all_gather(
102 | batch_sizes,
103 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
104 | )
105 |
106 | # Pad all_gather batches to be the maximum batch size.
107 | batch_sizes = [x.item() for x in batch_sizes]
108 | max_bs = max(batch_sizes)
109 |
110 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
111 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
112 | dist.all_gather(timestep_batches, local_ts)
113 | dist.all_gather(loss_batches, local_losses)
114 | timesteps = [
115 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
116 | ]
117 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
118 | self.update_with_all_losses(timesteps, losses)
119 |
120 | @abstractmethod
121 | def update_with_all_losses(self, ts, losses):
122 | """
123 | Update the reweighting using losses from a model.
124 |
125 | Sub-classes should override this method to update the reweighting
126 | using losses from the model.
127 |
128 | This method directly updates the reweighting without synchronizing
129 | between workers. It is called by update_with_local_losses from all
130 | ranks with identical arguments. Thus, it should have deterministic
131 | behavior to maintain state across workers.
132 |
133 | :param ts: a list of int timesteps.
134 | :param losses: a list of float losses, one per timestep.
135 | """
136 |
137 |
138 | class LossSecondMomentResampler(LossAwareSampler):
139 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
140 | self.diffusion = diffusion
141 | self.history_per_term = history_per_term
142 | self.uniform_prob = uniform_prob
143 | self._loss_history = np.zeros(
144 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
145 | )
146 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy >= 1.24
147 |
148 | def weights(self):
149 | if not self._warmed_up():
150 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
151 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
152 | weights /= np.sum(weights)
153 | weights *= 1 - self.uniform_prob
154 | weights += self.uniform_prob / len(weights)
155 | return weights
156 |
157 | def update_with_all_losses(self, ts, losses):
158 | for t, loss in zip(ts, losses):
159 | if self._loss_counts[t] == self.history_per_term:
160 | # Shift out the oldest loss term.
161 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
162 | self._loss_history[t, -1] = loss
163 | else:
164 | self._loss_history[t, self._loss_counts[t]] = loss
165 | self._loss_counts[t] += 1
166 |
167 | def _warmed_up(self):
168 | return (self._loss_counts == self.history_per_term).all()
169 |
--------------------------------------------------------------------------------
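A small sketch (not repo code) of how these samplers are consumed: the `diffusion` argument below is a stand-in exposing only the `num_timesteps` attribute the samplers actually read; the real object comes from `create_model_and_diffusion` in basic_utils.py.

```python
import torch as th
from types import SimpleNamespace
from diffuseq.step_sample import create_named_schedule_sampler

# stand-in diffusion object: samplers only read .num_timesteps
diffusion = SimpleNamespace(num_timesteps=2000)
sampler = create_named_schedule_sampler('uniform', diffusion)

t, weights = sampler.sample(4, th.device('cpu'))
print(t)        # e.g. tensor([1724,   63,  901,  388]), uniform over [0, 2000)
print(weights)  # tensor([1., 1., 1., 1.]) -- 1 / (T * (1/T)) for the uniform case

# the 'lossaware' sampler (LossSecondMomentResampler) is used the same way, but its
# update_with_local_losses() performs all_gather, so it needs an initialized
# torch.distributed process group.
```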
/diffuseq/text_datasets.py:
--------------------------------------------------------------------------------
1 | # import blobfile as bf
2 | import numpy as np
3 | from torch.utils.data import DataLoader, Dataset
4 |
5 | import torch
6 | import json
7 | import psutil
8 | import datasets
9 | from datasets import Dataset as Dataset2
10 |
11 | def load_data_text(
12 | batch_size,
13 | seq_len,
14 | deterministic=False,
15 | data_args=None,
16 | model_emb=None,
17 | split='train',
18 | loaded_vocab=None,
19 | loop=True,
20 | ):
21 | """
22 | For a dataset, create a generator over (seqs, kwargs) pairs.
23 |
24 |     Each seq is a (bsz, len, h) float tensor, and the kwargs dict contains zero or
25 |     more keys, each of which maps to a batched Tensor of its own.
26 | The kwargs dict can be used for some meta information.
27 |
28 | :param batch_size: the batch size of each returned pair.
29 | :param seq_len: the max sequence length (one-side).
30 | :param deterministic: if True, yield results in a deterministic order.
31 | :param data_args: including dataset directory, num of dataset, basic settings, etc.
32 | :param model_emb: loaded word embeddings.
33 | :param loaded_vocab: loaded word vocabs.
34 | :param loop: loop to get batch data or not.
35 | """
36 |
37 | print('#'*30, '\nLoading text data...')
38 |
39 | training_data = get_corpus(data_args, seq_len, split=split, loaded_vocab=loaded_vocab)
40 |
41 | dataset = TextDataset(
42 | training_data,
43 | data_args,
44 | model_emb=model_emb
45 | )
46 |
47 | data_loader = DataLoader(
48 | dataset,
49 | batch_size=batch_size, # 20,
50 | # drop_last=True,
51 | shuffle=not deterministic,
52 | num_workers=0,
53 | )
54 | if loop:
55 | return infinite_loader(data_loader)
56 | else:
57 | # print(data_loader)
58 | return iter(data_loader)
59 |
60 | def infinite_loader(data_loader):
61 | while True:
62 | yield from data_loader
63 |
64 | def helper_tokenize(sentence_lst, vocab_dict, seq_len):
65 | # Process.memory_info is expressed in bytes, so convert to megabytes
66 | print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
67 | raw_datasets = Dataset2.from_dict(sentence_lst)
68 | print(raw_datasets)
69 | print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
70 |
71 | def tokenize_function(examples):
72 | input_id_x = vocab_dict.encode_token(examples['src'])
73 | input_id_y = vocab_dict.encode_token(examples['trg'])
74 | result_dict = {'input_id_x': input_id_x, 'input_id_y': input_id_y}
75 |
76 | return result_dict
77 |
78 | tokenized_datasets = raw_datasets.map(
79 | tokenize_function,
80 | batched=True,
81 | num_proc=4,
82 | remove_columns=['src', 'trg'],
83 | load_from_cache_file=True,
84 | desc="Running tokenizer on dataset",
85 | )
86 | print('### tokenized_datasets', tokenized_datasets)
87 | print('### tokenized_datasets...example', tokenized_datasets['input_id_x'][0])
88 | print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
89 |
90 | def merge_and_mask(group_lst):
91 | lst = []
92 | mask = []
93 | for i in range(len(group_lst['input_id_x'])):
94 | end_token = group_lst['input_id_x'][i][-1]
95 | src = group_lst['input_id_x'][i][:-1]
96 | trg = group_lst['input_id_y'][i][:-1]
97 | while len(src) + len(trg) > seq_len - 3:
98 | if len(src)>len(trg):
99 | src.pop()
100 |             elif len(src) < len(trg):
101 |                 trg.pop()
--------------------------------------------------------------------------------
/diffuseq/utils/logger.py:
--------------------------------------------------------------------------------
83 |         return s[: maxlen - 3] + "..." if len(s) > maxlen else s
84 |
85 | def writeseq(self, seq):
86 | seq = list(seq)
87 | for (i, elem) in enumerate(seq):
88 | self.file.write(elem)
89 | if i < len(seq) - 1: # add space unless this is the last one
90 | self.file.write(" ")
91 | self.file.write("\n")
92 | self.file.flush()
93 |
94 | def close(self):
95 | if self.own_file:
96 | self.file.close()
97 |
98 |
99 | class JSONOutputFormat(KVWriter):
100 | def __init__(self, filename):
101 | self.file = open(filename, "wt")
102 |
103 | def writekvs(self, kvs):
104 | for k, v in sorted(kvs.items()):
105 | if hasattr(v, "dtype"):
106 | kvs[k] = float(v)
107 | self.file.write(json.dumps(kvs) + "\n")
108 | self.file.flush()
109 |
110 | def close(self):
111 | self.file.close()
112 |
113 |
114 | class CSVOutputFormat(KVWriter):
115 | def __init__(self, filename):
116 | self.file = open(filename, "w+t")
117 | self.keys = []
118 | self.sep = ","
119 |
120 | def writekvs(self, kvs):
121 | # Add our current row to the history
122 | extra_keys = list(kvs.keys() - self.keys)
123 | extra_keys.sort()
124 | if extra_keys:
125 | self.keys.extend(extra_keys)
126 | self.file.seek(0)
127 | lines = self.file.readlines()
128 | self.file.seek(0)
129 | for (i, k) in enumerate(self.keys):
130 | if i > 0:
131 | self.file.write(",")
132 | self.file.write(k)
133 | self.file.write("\n")
134 | for line in lines[1:]:
135 | self.file.write(line[:-1])
136 | self.file.write(self.sep * len(extra_keys))
137 | self.file.write("\n")
138 | for (i, k) in enumerate(self.keys):
139 | if i > 0:
140 | self.file.write(",")
141 | v = kvs.get(k)
142 | if v is not None:
143 | self.file.write(str(v))
144 | self.file.write("\n")
145 | self.file.flush()
146 |
147 | def close(self):
148 | self.file.close()
149 |
150 |
151 | class TensorBoardOutputFormat(KVWriter):
152 | """
153 | Dumps key/value pairs into TensorBoard's numeric format.
154 | """
155 |
156 | def __init__(self, dir):
157 | os.makedirs(dir, exist_ok=True)
158 | self.dir = dir
159 | self.step = 1
160 | prefix = "events"
161 | path = osp.join(osp.abspath(dir), prefix)
162 | import tensorflow as tf
163 | from tensorflow.python import pywrap_tensorflow
164 | from tensorflow.core.util import event_pb2
165 | from tensorflow.python.util import compat
166 |
167 | self.tf = tf
168 | self.event_pb2 = event_pb2
169 | self.pywrap_tensorflow = pywrap_tensorflow
170 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
171 |
172 | def writekvs(self, kvs):
173 | def summary_val(k, v):
174 | kwargs = {"tag": k, "simple_value": float(v)}
175 | return self.tf.Summary.Value(**kwargs)
176 |
177 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
178 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
179 | event.step = (
180 | self.step
181 | ) # is there any reason why you'd want to specify the step?
182 | self.writer.WriteEvent(event)
183 | self.writer.Flush()
184 | self.step += 1
185 |
186 | def close(self):
187 | if self.writer:
188 | self.writer.Close()
189 | self.writer = None
190 |
191 |
192 | def make_output_format(format, ev_dir, log_suffix=""):
193 | os.makedirs(ev_dir, exist_ok=True)
194 | if format == "stdout":
195 | return HumanOutputFormat(sys.stdout)
196 | elif format == "log":
197 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
198 | elif format == "json":
199 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
200 | elif format == "csv":
201 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
202 | elif format == "tensorboard":
203 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
204 | else:
205 | raise ValueError("Unknown format specified: %s" % (format,))
206 |
207 |
208 | # ================================================================
209 | # API
210 | # ================================================================
211 |
212 |
213 | def logkv(key, val):
214 | """
215 | Log a value of some diagnostic
216 | Call this once for each diagnostic quantity, each iteration
217 | If called many times, last value will be used.
218 | """
219 | get_current().logkv(key, val)
220 |
221 |
222 | def logkv_mean(key, val):
223 | """
224 | The same as logkv(), but if called many times, values averaged.
225 | """
226 | get_current().logkv_mean(key, val)
227 |
228 |
229 | def logkvs(d):
230 | """
231 | Log a dictionary of key-value pairs
232 | """
233 | for (k, v) in d.items():
234 | logkv(k, v)
235 |
236 |
237 | def dumpkvs():
238 | """
239 | Write all of the diagnostics from the current iteration
240 | """
241 | return get_current().dumpkvs()
242 |
243 |
244 | def getkvs():
245 | return get_current().name2val
246 |
247 |
248 | def log(*args, level=INFO):
249 | """
250 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
251 | """
252 | get_current().log(*args, level=level)
253 |
254 |
255 | def debug(*args):
256 | log(*args, level=DEBUG)
257 |
258 |
259 | def info(*args):
260 | log(*args, level=INFO)
261 |
262 |
263 | def warn(*args):
264 | log(*args, level=WARN)
265 |
266 |
267 | def error(*args):
268 | log(*args, level=ERROR)
269 |
270 |
271 | def set_level(level):
272 | """
273 | Set logging threshold on current logger.
274 | """
275 | get_current().set_level(level)
276 |
277 |
278 | def set_comm(comm):
279 | get_current().set_comm(comm)
280 |
281 |
282 | def get_dir():
283 | """
284 | Get directory that log files are being written to.
285 | will be None if there is no output directory (i.e., if you didn't call start)
286 | """
287 | return get_current().get_dir()
288 |
289 |
290 | record_tabular = logkv
291 | dump_tabular = dumpkvs
292 |
293 |
294 | @contextmanager
295 | def profile_kv(scopename):
296 | logkey = "wait_" + scopename
297 | tstart = time.time()
298 | try:
299 | yield
300 | finally:
301 | get_current().name2val[logkey] += time.time() - tstart
302 |
303 |
304 | def profile(n):
305 | """
306 | Usage:
307 | @profile("my_func")
308 | def my_func(): code
309 | """
310 |
311 | def decorator_with_name(func):
312 | def func_wrapper(*args, **kwargs):
313 | with profile_kv(n):
314 | return func(*args, **kwargs)
315 |
316 | return func_wrapper
317 |
318 | return decorator_with_name
319 |
320 |
321 | # ================================================================
322 | # Backend
323 | # ================================================================
324 |
325 |
326 | def get_current():
327 | if Logger.CURRENT is None:
328 | _configure_default_logger()
329 |
330 | return Logger.CURRENT
331 |
332 |
333 | class Logger(object):
334 | DEFAULT = None # A logger with no output files. (See right below class definition)
335 | # So that you can still log to the terminal without setting up any output files
336 | CURRENT = None # Current logger being used by the free functions above
337 |
338 | def __init__(self, dir, output_formats, comm=None):
339 | self.name2val = defaultdict(float) # values this iteration
340 | self.name2cnt = defaultdict(int)
341 | self.level = INFO
342 | self.dir = dir
343 | self.output_formats = output_formats
344 | self.comm = comm
345 |
346 | # Logging API, forwarded
347 | # ----------------------------------------
348 | def logkv(self, key, val):
349 | self.name2val[key] = val
350 |
351 | def logkv_mean(self, key, val):
352 | oldval, cnt = self.name2val[key], self.name2cnt[key]
353 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
354 | self.name2cnt[key] = cnt + 1
355 |
356 | def dumpkvs(self, prefix=None):
357 | if self.comm is None:
358 | d = self.name2val
359 | else:
360 | d = mpi_weighted_mean(
361 | self.comm,
362 | {
363 | name: (val, self.name2cnt.get(name, 1))
364 | for (name, val) in self.name2val.items()
365 | },
366 | )
367 | if self.comm.rank != 0:
368 | d["dummy"] = 1 # so we don't get a warning about empty dict
369 | # LISA
370 | out = d.copy() # Return the dict for unit testing purposes
371 |         if int(os.environ.get('LOCAL_RANK', 0)) == 0:  # avoid a KeyError when not launched via torchrun
372 | wandb.log({**d})
373 | for fmt in self.output_formats:
374 | if isinstance(fmt, KVWriter):
375 | fmt.writekvs(d)
376 | self.name2val.clear()
377 | self.name2cnt.clear()
378 | return out
379 |
380 | def log(self, *args, level=INFO):
381 | if self.level <= level:
382 | self._do_log(args)
383 |
384 | # Configuration
385 | # ----------------------------------------
386 | def set_level(self, level):
387 | self.level = level
388 |
389 | def set_comm(self, comm):
390 | self.comm = comm
391 |
392 | def get_dir(self):
393 | return self.dir
394 |
395 | def close(self):
396 | for fmt in self.output_formats:
397 | fmt.close()
398 |
399 | # Misc
400 | # ----------------------------------------
401 | def _do_log(self, args):
402 | for fmt in self.output_formats:
403 | if isinstance(fmt, SeqWriter):
404 | fmt.writeseq(map(str, args))
405 |
406 |
407 | def get_rank_without_mpi_import():
408 | # check environment variables here instead of importing mpi4py
409 | # to avoid calling MPI_Init() when this module is imported
410 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
411 | if varname in os.environ:
412 | return int(os.environ[varname])
413 | return 0
414 |
415 |
416 | def mpi_weighted_mean(comm, local_name2valcount):
417 | """
418 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
419 | Perform a weighted average over dicts that are each on a different node
420 | Input: local_name2valcount: dict mapping key -> (value, count)
421 | Returns: key -> mean
422 | """
423 | all_name2valcount = comm.gather(local_name2valcount)
424 | if comm.rank == 0:
425 | name2sum = defaultdict(float)
426 | name2count = defaultdict(float)
427 | for n2vc in all_name2valcount:
428 | for (name, (val, count)) in n2vc.items():
429 | try:
430 | val = float(val)
431 | except ValueError:
432 | if comm.rank == 0:
433 | warnings.warn(
434 | "WARNING: tried to compute mean on non-float {}={}".format(
435 | name, val
436 | )
437 | )
438 | else:
439 | name2sum[name] += val * count
440 | name2count[name] += count
441 | return {name: name2sum[name] / name2count[name] for name in name2sum}
442 | else:
443 | return {}
444 |
445 |
446 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
447 | """
448 | If comm is provided, average all numerical stats across that comm
449 | """
450 | if dir is None:
451 | dir = os.getenv("OPENAI_LOGDIR")
452 | if dir is None:
453 | dir = osp.join(
454 | tempfile.gettempdir(),
455 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
456 | )
457 | assert isinstance(dir, str)
458 | dir = os.path.expanduser(dir)
459 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
460 |
461 | rank = get_rank_without_mpi_import()
462 | if rank > 0:
463 | log_suffix = log_suffix + "-rank%03i" % rank
464 |
465 | if format_strs is None:
466 | if rank == 0:
467 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
468 | else:
469 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
470 | format_strs = filter(None, format_strs)
471 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
472 |
473 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
474 | if output_formats:
475 | log("Logging to %s" % dir)
476 |
477 |
478 | def _configure_default_logger():
479 | configure()
480 | Logger.DEFAULT = Logger.CURRENT
481 |
482 |
483 | def reset():
484 | if Logger.CURRENT is not Logger.DEFAULT:
485 | Logger.CURRENT.close()
486 | Logger.CURRENT = Logger.DEFAULT
487 | log("Reset logger")
488 |
489 |
490 | @contextmanager
491 | def scoped_configure(dir=None, format_strs=None, comm=None):
492 | prevlogger = Logger.CURRENT
493 | configure(dir=dir, format_strs=format_strs, comm=comm)
494 | try:
495 | yield
496 | finally:
497 | Logger.CURRENT.close()
498 | Logger.CURRENT = prevlogger
499 |
500 |
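A minimal usage sketch of the API above (not part of the file). Note that this fork's dumpkvs() also forwards metrics to wandb, so wandb should be initialized (or explicitly disabled) first:

import wandb
import diffuseq.utils.logger as logger

wandb.init(mode="disabled")  # dumpkvs() calls wandb.log() on local rank 0
logger.configure(dir="/tmp/demo_logs", format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 1.0 / (step + 1))  # running mean within an iteration
    logger.dumpkvs()  # write this iteration's diagnostics to every output format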
--------------------------------------------------------------------------------
/diffuseq/utils/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 |     Compute the KL divergence between two Gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | # print(logvar2.shape)
34 | # temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2))
35 | # print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, mse={((mean1 - mean2) ** 2).mean().item()}')
36 |
37 | return 0.5 * (
38 | -1.0
39 | + logvar2
40 | - logvar1
41 | + th.exp(logvar1 - logvar2)
42 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
43 | )
44 |
45 |
46 | def approx_standard_normal_cdf(x):
47 | """
48 | A fast approximation of the cumulative distribution function of the
49 | standard normal.
50 | """
51 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
52 |
53 |
54 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
55 | """
56 | Compute the log-likelihood of a Gaussian distribution discretizing to a
57 | given image.
58 |
59 | :param x: the target images. It is assumed that this was uint8 values,
60 | rescaled to the range [-1, 1].
61 | :param means: the Gaussian mean Tensor.
62 | :param log_scales: the Gaussian log stddev Tensor.
63 | :return: a tensor like x of log probabilities (in nats).
64 | """
65 | assert x.shape == means.shape == log_scales.shape
66 | centered_x = x - means
67 | inv_stdv = th.exp(-log_scales)
68 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
69 | cdf_plus = approx_standard_normal_cdf(plus_in)
70 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
71 | cdf_min = approx_standard_normal_cdf(min_in)
72 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
73 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
74 | cdf_delta = cdf_plus - cdf_min
75 | log_probs = th.where(
76 | x < -0.999,
77 | log_cdf_plus,
78 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
79 | )
80 | assert log_probs.shape == x.shape
81 | return log_probs
82 |
83 | def gaussian_density(x, *, means, log_scales):  # NB: despite the name, returns the log-density of x under N(means, exp(log_scales)^2)
84 | from torch.distributions import Normal
85 | normal_dist = Normal(means, log_scales.exp())
86 | logp = normal_dist.log_prob(x)
87 | return logp
88 |
89 |
90 | def discretized_text_log_likelihood(x, *, means, log_scales):
91 | """
92 |     Compute the log-likelihood of a Gaussian distribution discretizing to a
93 |     given target. Same computation as discretized_gaussian_log_likelihood
94 |     above, applied to text targets rather than images.
95 |     :param x: the target values, assumed to be
96 | rescaled to the range [-1, 1].
97 | :param means: the Gaussian mean Tensor.
98 | :param log_scales: the Gaussian log stddev Tensor.
99 | :return: a tensor like x of log probabilities (in nats).
100 | """
101 |     # print(x.shape, means.shape)  # debug
102 |     # assert x.shape == means.shape == log_scales.shape
103 |     # print(x, means)  # debug
104 | centered_x = x - means
105 | inv_stdv = th.exp(-log_scales)
106 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
107 | cdf_plus = approx_standard_normal_cdf(plus_in)
108 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
109 | cdf_min = approx_standard_normal_cdf(min_in)
110 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
111 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
112 | cdf_delta = cdf_plus - cdf_min
113 | log_probs = th.where(
114 | x < -0.999,
115 | log_cdf_plus,
116 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
117 | )
118 | assert log_probs.shape == x.shape
119 | return log_probs
120 |
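A small sanity check (not part of the file): normal_kl matches the closed form KL(N(m1, s1^2) || N(m2, s2^2)) = log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2:

import torch as th
from diffuseq.utils.losses import normal_kl

mean1, logvar1 = th.zeros(4), th.zeros(4)        # N(0, 1)
mean2, logvar2 = th.ones(4), th.full((4,), 0.5)  # N(1, e^0.5)
print(normal_kl(mean1, logvar1, mean2, logvar2))  # ~0.3565 nats per element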
--------------------------------------------------------------------------------
/diffuseq/utils/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
11 | class SiLU(nn.Module):
12 | def forward(self, x):
13 | return x * th.sigmoid(x)
14 |
15 |
16 | class GroupNorm32(nn.GroupNorm):
17 | def forward(self, x):
18 | return super().forward(x.float()).type(x.dtype)
19 |
20 | def linear(*args, **kwargs):
21 | """
22 | Create a linear module.
23 | """
24 | return nn.Linear(*args, **kwargs)
25 |
26 |
27 | def avg_pool_nd(dims, *args, **kwargs):
28 | """
29 | Create a 1D, 2D, or 3D average pooling module.
30 | """
31 | if dims == 1:
32 | return nn.AvgPool1d(*args, **kwargs)
33 | elif dims == 2:
34 | return nn.AvgPool2d(*args, **kwargs)
35 | elif dims == 3:
36 | return nn.AvgPool3d(*args, **kwargs)
37 | raise ValueError(f"unsupported dimensions: {dims}")
38 |
39 |
40 | def update_ema(target_params, source_params, rate=0.99):
41 | """
42 | Update target parameters to be closer to those of source parameters using
43 | an exponential moving average.
44 |
45 | :param target_params: the target parameter sequence.
46 | :param source_params: the source parameter sequence.
47 | :param rate: the EMA rate (closer to 1 means slower).
48 | """
49 | for targ, src in zip(target_params, source_params):
50 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
51 |
52 |
53 | def zero_module(module):
54 | """
55 | Zero out the parameters of a module and return it.
56 | """
57 | for p in module.parameters():
58 | p.detach().zero_()
59 | return module
60 |
61 |
62 | def scale_module(module, scale):
63 | """
64 | Scale the parameters of a module and return it.
65 | """
66 | for p in module.parameters():
67 | p.detach().mul_(scale)
68 | return module
69 |
70 |
71 | def mean_flat(tensor):
72 | """
73 | Take the mean over all non-batch dimensions.
74 | """
75 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
76 |
77 |
78 | def normalization(channels):
79 | """
80 | Make a standard normalization layer.
81 |
82 | :param channels: number of input channels.
83 | :return: an nn.Module for normalization.
84 | """
85 | return GroupNorm32(32, channels)
86 |
87 |
88 | def timestep_embedding(timesteps, dim, max_period=10000):
89 | """
90 | Create sinusoidal timestep embeddings.
91 |
92 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
93 | These may be fractional.
94 | :param dim: the dimension of the output.
95 | :param max_period: controls the minimum frequency of the embeddings.
96 | :return: an [N x dim] Tensor of positional embeddings.
97 | """
98 | half = dim // 2
99 | freqs = th.exp(
100 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
101 | ).to(device=timesteps.device)
102 | args = timesteps[:, None].float() * freqs[None]
103 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
104 | if dim % 2:
105 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
106 | return embedding
107 |
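For illustration (not part of the file), timestep_embedding produces the standard transformer-style sinusoidal features, concatenating cosines and sines over geometrically spaced frequencies:

import torch as th
from diffuseq.utils.nn import timestep_embedding

t = th.tensor([0, 10, 500, 1999])
emb = timestep_embedding(t, dim=128)
print(emb.shape)   # torch.Size([4, 128])
print(emb[0, :4])  # at t=0 the cosine half is all ones (and the sine half all zeros)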
--------------------------------------------------------------------------------
/evaluation/__pycache__/tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/evaluation/__pycache__/tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/__pycache__/tokenizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LCM-Lab/Bridge_Gap_Diffusion/e16c6a86ce66e171c58dba5c1ab9225cea39008a/evaluation/__pycache__/tokenizer.cpython-39.pyc
--------------------------------------------------------------------------------
/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import math
3 | import pdb
4 | import numpy as np
5 | from argparse import ArgumentParser
6 | from nltk import ngrams
7 | from tokenizer import SimpleTokenizer
8 | from nltk.tokenize import word_tokenize, wordpunct_tokenize, sent_tokenize
9 | from transformers import AutoTokenizer,AutoModelForCausalLM
10 | import nltk
11 | import copy
12 | import torch
13 | import evaluate
14 | from evaluate import load
15 | import os
16 | import mauve
17 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
18 | from multiprocessing.pool import Pool
19 | from tqdm import tqdm
20 | import random
21 | from functools import partial
22 | from torchmetrics.text.rouge import ROUGEScore
23 | rougeScore = ROUGEScore()
24 | device = "cuda" if torch.cuda.is_available() else "cpu"
25 | rougeScore.to(device)
26 | import csv
27 | from bert_score import score
28 |
29 |
30 |
31 |
32 | tokenizer = SimpleTokenizer(method="nltk")
33 |
34 |
35 |
36 | def bleu(refs, cands):
37 |
38 | result = {}
39 |
40 | for i in range(1, 5):
41 | res = []
42 | for ref,cand in zip(refs,cands):
43 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands, weights=tuple([1./i for j in range(i)]),smoothing_function=SmoothingFunction().method4))
44 | res.append(sentence_bleu([ref], cand, smoothing_function=SmoothingFunction().method4,weights=tuple([1./i for j in range(i)])))
45 | result["bleu-%d"%i] = np.mean(res)
46 |
47 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands))
48 | return result
49 |
50 | def distinct_n_gram(hypn_lst,n):
51 | dist_list_fin = []
52 | for hypn in hypn_lst:
53 | hypn = [hypn]
54 | dist_list = []
55 | for hyp in hypn:
56 | hyp_ngrams = []
57 | hyp_ngrams += nltk.ngrams(hyp.split(), n)
58 | total_ngrams = len(hyp_ngrams)
59 | unique_ngrams = len(list(set(hyp_ngrams)))
60 | if total_ngrams == 0:
61 | continue
62 | dist_list.append(unique_ngrams/total_ngrams)
63 | if total_ngrams == 0:
64 | continue
65 | dist_list_fin.append(np.mean(dist_list))
66 | return np.mean(dist_list_fin)
67 |
68 |
69 |
70 | def repetition_distinct(hyps, times):
71 | dis_result, lex_rep = dict(), dict()
72 | for i in range(1, 5):
73 | num, all_ngram, all_ngram_num = 0, {}, 0
74 | for tokens in hyps:
75 |
76 | ngs = ["_".join(c) for c in ngrams(tokens, i)]
77 | all_ngram_num += len(ngs)
78 | for s in ngs:
79 | if s in all_ngram:
80 | all_ngram[s] += 1
81 | else:
82 | all_ngram[s] = 1
83 | for s in set(ngs):
84 | if ngs.count(s) > times:
85 | num += 1
86 | break
87 | lex_rep["repetition-%d"%i] = "%.4f"%(num / float(len(hyps)))
88 | dis_result["distinct-%d"%i] = "%.4f"%(len(all_ngram) / float(all_ngram_num))
89 |
90 | return dis_result, lex_rep
91 |
92 | def length_(cands):
93 | lengths = []
94 | for i in cands:
95 | lengths.append(len(i))
96 | return sum(lengths) / len(lengths)
97 |
98 |
99 | import scipy
100 | from transformers import AutoTokenizer, AutoModel
101 | def sent_semantic_repetition(device, model_name_or_path, hyps, source):
102 |
103 | sbert_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
104 | sbert_model = AutoModel.from_pretrained(model_name_or_path).to(device)
105 |
106 | def mean_pooling(model_output, attention_mask):
107 | token_embeddings = model_output[0] # First element of model_output contains all token embeddings
108 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
109 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
110 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
111 | return sum_embeddings / sum_mask
112 |
113 | all_distance, ss_max, ss_min, ss_mean = [], [], [], []
114 | count = 0
115 | for k, (ipt, cand) in enumerate(zip(source, hyps)):
116 | if k % 1000 == 0:
117 | print("processing %d lines"%k)
118 | sen_list = sent_tokenize("%s %s"%(ipt.strip(), cand.strip()))
119 | if len(sen_list) == 0:
120 | continue
121 | with torch.no_grad():
122 | encoder_input = sbert_tokenizer(sen_list,padding=True, truncation=True, return_tensors='pt', max_length=128).to(device)
123 | model_output = sbert_model(**encoder_input)
124 | sentence_embeddings = mean_pooling(model_output, encoder_input['attention_mask']).cpu().numpy()
125 |
126 | max_dis, min_dis, all_dis = [-1, -1, -1.], [-1, -1, 1.], []
127 | if len(sentence_embeddings) == 1:
128 | continue
129 |
130 | for i, sen1 in enumerate(sentence_embeddings):
131 | for j, sen2 in enumerate(sentence_embeddings):
132 | if i < j:
133 | distances = 1 - scipy.spatial.distance.cdist([sen1], [sen2], "cosine")[0][0]
134 | all_dis.append(distances)
135 |
136 | if distances > max_dis[2]:
137 | max_dis = [i, j, distances]
138 | if distances < min_dis[2]:
139 | min_dis = [i, j, distances]
140 | sort_dis = sorted(all_dis)
141 | all_distance.append([])
142 | for i in range(1, 11):
143 | all_distance[-1].append(np.mean(sort_dis[-i:]))
144 |
145 | ss_max.append(max_dis[2])
146 | ss_min.append(min_dis[2])
147 | ss_mean.append(np.mean(all_dis))
148 | count += 1
149 | all_distance = np.mean(all_distance, 0)
150 | return {
151 | "sent_rept_max": "%.4f"%(np.mean(ss_max)),
152 | "sent_rept_min": "%.4f"%(np.mean(ss_min)),
153 | "sent_rept_mean": "%.4f"%(np.mean(ss_mean)),
154 | "all_distance": " ".join(["%.4f"%tmpf for tmpf in all_distance]),
155 | }
156 |
157 | def rouge_score(preds, golds):
158 |
159 | rouge_results = {}
160 | rouge1 =[]
161 | rouge2 = []
162 | rougeL = []
163 | for srcs, tgts in zip(preds, golds):
164 | # predictions = [" ".join(srcs)]
165 | # references = [[" ".join(tgts)]]
166 | # rouge.add_batch(predictions=predictions, references=references)
167 | references = " ".join(tgts)
168 | predictions = " ".join(srcs)
169 | res = rougeScore(predictions, references)
170 | rouge1.append(res["rouge1_fmeasure"])
171 | rouge2.append(res["rouge2_fmeasure"])
172 | rougeL.append(res["rougeL_fmeasure"])
173 | rouge_results["rouge1"] = np.mean(rouge1)
174 | rouge_results["rouge2"] = np.mean(rouge2)
175 | rouge_results["rougeL"] = np.mean(rougeL)
176 | # rouge_results = rouge.compute()
177 |
178 |
179 | # rouge_results = rouge.compute(predictions=predictions, references=references, tokenizer=wordpunct_tokenize)
180 | return rouge_results
181 |
182 |
183 | def show_result(res_dict):
184 | for k, v in res_dict.items():
185 | print(f"{k:} : {v:}")
186 |
187 | def ori_pro(s, name=""):
188 | s = s.strip()
189 | # for i in range(10):
190 | # s = s.replace("[%d]"%i, "")
191 | s = s.replace("", " ")
192 | s = " ".join(s.strip().split())
193 | # s = roberta_tokenizer.decode(roberta_tokenizer.convert_tokens_to_ids(roberta_tokenizer.tokenize(s)))
194 | return s
195 |
196 | def pro(token_list, tokenizer):
197 | token_list = "".join(token_list.split(" "))
198 | token_list = tokenizer(token_list)['input_ids']
199 | for i, t in enumerate(token_list):
200 | if t not in [0, 2]:
201 | break
202 | token_list = token_list[i:]
203 | string = tokenizer.decode(token_list, skip_special_tokens=False)
204 | string = string.replace("", " ")
205 | string = string[:string.find("")].strip()
206 | return string
207 |
208 | def bleu_i(weights, all_sentences, smoothing_function, i):
209 | # noinspection PyTypeChecker
210 | return sentence_bleu(
211 | references=all_sentences[:i] + all_sentences[i + 1:],
212 | hypothesis=all_sentences[i],
213 | weights=weights,
214 | smoothing_function=smoothing_function)
215 |
216 | def self_bleu(generations_df, n_sample=1000):
217 |
218 | # import spacy
219 | random.seed(0)
220 | # nlp = spacy.load('en', disable=['parser', 'tagger', 'ner'])
221 | # nlp.add_pipe(nlp.create_pipe('sentencizer'))
222 |
223 | smoothing_function = SmoothingFunction().method1
224 | # all_sentences = []
225 | # for i, row in generations_df.iterrows():
226 | # # gens = row['tokens']
227 | # gens = [[str(token) for token in tokens] for tokens in row['tokens']]# for gen in row['generations']] {'prompt':"", tokens: [[1,2,3], [3,4,5], [5,6,7], ....]}
228 | # all_sentences += gens
229 |
230 | all_sentences = generations_df
231 |
232 | pool = Pool(processes=os.cpu_count())
233 | bleu_scores = []
234 | for n_gram in range(1, 6):
235 |
236 | if n_gram == 1:
237 | weights = (1.0, 0, 0, 0)
238 | elif n_gram == 2:
239 | weights = (0.5, 0.5, 0, 0)
240 | elif n_gram == 3:
241 | weights = (1.0 / 3, 1.0 / 3, 1.0 / 3, 0)
242 | elif n_gram == 4:
243 | weights = (0.25, 0.25, 0.25, 0.25)
244 | elif n_gram == 5:
245 | weights = (0.2, 0.2, 0.2, 0.2, 0.2)
246 | else:
247 | raise ValueError
248 | bleu_scores.append(
249 | list(tqdm(
250 | pool.imap_unordered(
251 | partial(bleu_i, weights, all_sentences, smoothing_function),
252 | random.sample(range(len(all_sentences)), min(n_sample, len(all_sentences)))),
253 | total=min(n_sample, len(all_sentences)),
254 | smoothing=0.0,
255 | desc=f"bleu-{n_gram}")))
256 | # print(f"\n\nbleu-{n_gram} = {sum(bleu_scores[n_gram - 1]) / n_sample}")
257 |
258 | pool.close()
259 | pool.join()
260 |
261 | bleus = []
262 | for n_gram in range(5):
263 |         bleus.append(sum(bleu_scores[n_gram]) / len(bleu_scores[n_gram]))  # use the actual sample count, which may be < n_sample
264 | # print(f"bleu-{n_gram + 1} = {sum(bleu_scores[n_gram]) / n_sample}")
265 |
266 | return bleus
267 |
268 |
269 | model = AutoModelForCausalLM.from_pretrained(
270 | 'gpt2-large' # path to the AR model trained for LMing this task.
271 | ).cuda()
272 | tokenizer_ppl = AutoTokenizer.from_pretrained('gpt2-large')
273 |
274 | def mean_pooling(model_output, attention_mask):
275 | token_embeddings = model_output[0] #First element of model_output contains all token embeddings
276 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
277 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
278 |
279 | tokenizer_prompt = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
280 | model_prompt = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
281 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
282 | model_prompt.to(device)
283 |
284 | def similarity(golds,preds,sources):
285 | sen_score_lst = []
286 |
287 | for gold,pred,source in zip(golds,preds,sources):
288 |
289 | embeddings1 = tokenizer_prompt(pred, padding=True, truncation=True, return_tensors='pt')
290 | embeddings2 = tokenizer_prompt(source, padding=True, truncation=True, return_tensors='pt')
291 | embeddings1 = embeddings1.to(device)
292 | embeddings2 = embeddings2.to(device)
293 |
294 | with torch.no_grad():
295 | e1 = model_prompt(**embeddings1)
296 | e2 = model_prompt(**embeddings2)
297 | e1 = mean_pooling(e1, embeddings1['attention_mask'])
298 | e2 = mean_pooling(e2, embeddings2['attention_mask'])
299 | sen_score_lst.append(torch.dist(e1,e2,p=2))
300 |
301 | return sen_score_lst
302 |
303 | def eval_ppl(text_samples):
304 | '''
305 |     Evaluate perplexity using a GPT-2 model finetuned on this task.
306 |     :param text_samples: a list of strings to score.
307 |     :return: the corpus-level perplexity (e ** mean NLL).
308 | '''
309 |
310 |
311 | # print('finished loading models.')
312 |
313 |
314 |
315 | full_score = []
316 | agg_loss = []
317 | count = 0
318 | import math
319 | for x in text_samples:
320 |
321 |
322 | # print(x)
323 | # should also add BOS EOS token?
324 |
325 | tokenized_x = tokenizer_ppl(x, truncation=True,return_tensors='pt') #[reverse_tokenizer[s] for s in x]
326 | input_ids = tokenized_x['input_ids'].cuda()
327 | labels = input_ids.clone()
328 |
329 | # print(tokenized_x)
330 | # tokenized_x = torch.LongTensor(tokenized_x).cuda()
331 | # labels = tokenized_x.clone()
332 | # labels[labels == reverse_tokenizer['PAD']] = -100
333 | model_output = model(input_ids, labels=labels)
334 | if not math.isnan(model_output.loss.item()):
335 | agg_loss.append(model_output.loss.item())
336 | else:
337 | count += 1
338 |
339 | print("nan count:{}:{}".format(count,len(text_samples)))
340 |
341 | example_mean_score = torch.tensor(agg_loss).mean()
342 |
343 | full_score.append(example_mean_score)
344 |
345 | full_score_ = np.array(full_score).mean()
346 |
347 | # print(f'full NLL score is {full_score_} for {len(full_score)}')
348 | # print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}')
349 |
350 | return np.e ** full_score_
351 |
352 |
353 |
354 | tokenizer_gpt = AutoTokenizer.from_pretrained("gpt2")
355 |
356 | if __name__ == "__main__":
357 | parser = ArgumentParser()
358 | parser.add_argument('--source-file', '-s', dest="source_file", help='source file', default="./src.txt")
359 | parser.add_argument('--golden-file', '-t', dest="golden_file", help='Input data file, one golden per line.', default="./gold.txt")
360 | parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.', default="./pred.txt")
361 |     parser.add_argument('--times', '-k', help='calculate the lexical repetition of different datasets', default="4")
362 | parser.add_argument('--model_path_or_name', '-p', help='where the config and tokenizer store')
363 | parser.add_argument('--folder')
364 | parser.add_argument('--save_dir', default="./")
365 | args = parser.parse_args()
366 |
367 | cnt = 0
368 | paths = sorted(glob.glob(glob.escape(f"{args.folder}")+"/*json"))
369 | print(paths)
370 |
371 |
372 |
373 | for path in tqdm(paths):
374 | print(path)
375 | bleu_1 = []
376 | bleu_2 = []
377 | bleu_3 = []
378 | bleu_4 = []
379 | self_bleu_1 = []
380 | self_bleu_2 = []
381 | self_bleu_3 = []
382 | self_bleu_4 = []
383 | self_bleu_5 = []
384 | times2_repetition_1 = []
385 | times2_repetition_2 = []
386 | times2_repetition_3 = []
387 | times2_repetition_4 = []
388 | times4_repetition_1 = []
389 | times4_repetition_2 = []
390 | times4_repetition_3 = []
391 | times4_repetition_4 = []
392 | rouge1 = []
393 | rouge2 = []
394 | rougeL = []
395 | bert_prec = []
396 | bert_recall = []
397 | bert_f1 = []
398 | dist1 = []
399 | dist2 = []
400 | dist3 = []
401 | dist4 = []
402 | text_ppls = []
403 | gold_ppls = []
404 | deltas = []
405 | m_score = []
406 | sen_score = []
407 |
408 | import json
409 | with open(path, "r") as f:
410 | lst = f.readlines()
411 | lst = [json.loads(i) for i in lst]
412 |
413 |
414 | golds = []
415 | preds = []
416 | sources = []
417 | for d in lst:
418 | d["reference"] = d["reference"].replace("[CLS]","").replace("[SEP]","").strip()
419 | d["recover"] = d["recover"].replace("[CLS]","").replace("[SEP]","").strip()
420 | d["source"] = d["source"].replace("[CLS]","").replace("[SEP]","").strip()
421 | if d["reference"] == "" or d["recover"] == "":
422 | continue
423 |             golds.append(d["reference"])  # already stripped above
424 |             preds.append(d["recover"])
425 |             sources.append(d["source"])
426 |
427 | source_golds = []
428 | source_preds = []
429 | for i,j,z in zip(golds,preds,sources):
430 | source_golds.append(z+" "+i)
431 | source_preds.append(z+" "+j)
432 |
433 | sen_score = similarity(golds,preds,sources)
434 |
435 |
436 | preds_str = preds
437 | golds_str = golds
438 |
439 | preds_bleu = []
440 | golds_bleu = []
441 | for i,j in zip(preds,golds):
442 | preds_bleu.append(i.split())
443 | golds_bleu.append(j.split())
444 |
445 | preds, golds = [tokenizer.tokenize(i) for i in preds], [tokenizer.tokenize(i) for i in golds]
446 |
447 | bleu_result = bleu(refs = golds_bleu, cands = preds_bleu)
448 |
449 | dis_result, lex_rep2 = repetition_distinct(preds, 2)
450 | dis_result, lex_rep4 = repetition_distinct(preds, 4)
451 |
452 |
453 |
454 |
455 |
456 | len_ = length_(preds)
457 | len_golds = length_(golds)
458 |
459 | # bertscore_result = bert_score(preds, golds)
460 | torch.cuda.empty_cache()
461 | P, R, F1 = score(preds_str, golds_str, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True)
462 |
463 | P = torch.mean(P)
464 | R = torch.mean(R)
465 | F1 = torch.mean(F1)
466 |
467 | rouge_result = rouge_score(preds, golds)
468 | print(path)
469 | text_ppl = eval_ppl(preds_str)
470 | gold_ppl = eval_ppl(golds_str)
471 | delta = text_ppl / gold_ppl
472 | delta = math.log(delta)
473 |
474 |
475 |
476 | recovers = []
477 | for i in preds_str:
478 | recover = tokenizer_gpt.encode(i)
479 | recover = list(map(str,recover))
480 | recovers.append(recover)
481 | self_bleus = self_bleu(recovers, n_sample=1000)
482 |
483 | # if "roc" in path:
484 | # print("roc",path)
485 | # m = mauve.compute_mauve(p_text=source_golds, q_text=source_preds, device_id=0, max_text_length=256, verbose=False)
486 | # else:
487 | # m = mauve.compute_mauve(p_text=golds_str, q_text=preds_str, device_id=0, max_text_length=256, verbose=False)
488 |
489 |
490 | # pdb.set_trace()
491 | # print(m)
492 |
493 |
494 |
495 | # m_score.append(m.mauve)
496 |
497 |
498 | # bert_prec.append(bertscore_result["precision"])
499 | # bert_recall.append(bertscore_result["recall"])
500 | # bert_f1.append(bertscore_result["f1"])
501 |
502 |
503 | bert_prec.append(P)
504 | bert_recall.append(R)
505 | bert_f1.append(F1)
506 | bleu_1.append(float(bleu_result["bleu-1"]))
507 | bleu_2.append(float(bleu_result["bleu-2"]))
508 | bleu_3.append(float(bleu_result["bleu-3"]))
509 | bleu_4.append(float(bleu_result["bleu-4"]))
510 | times2_repetition_1.append(float(lex_rep2["repetition-1"]))
511 | times2_repetition_2.append(float(lex_rep2["repetition-2"]))
512 | times2_repetition_3.append(float(lex_rep2["repetition-3"]))
513 | times2_repetition_4.append(float(lex_rep2["repetition-4"]))
514 | times4_repetition_1.append(float(lex_rep4["repetition-1"]))
515 | times4_repetition_2.append(float(lex_rep4["repetition-2"]))
516 | times4_repetition_3.append(float(lex_rep4["repetition-3"]))
517 | times4_repetition_4.append(float(lex_rep4["repetition-4"]))
518 | self_bleu_1.append(self_bleus[0])
519 | self_bleu_2.append(self_bleus[1])
520 | self_bleu_3.append(self_bleus[2])
521 | self_bleu_4.append(self_bleus[3])
522 | self_bleu_5.append(self_bleus[4])
523 | rouge1.append(rouge_result["rouge1"])
524 | rouge2.append(rouge_result["rouge2"])
525 | rougeL.append(rouge_result["rougeL"])
526 | dist1.append(float(dis_result["distinct-1"]))
527 | dist2.append(float(dis_result["distinct-2"]))
528 | dist3.append(float(dis_result["distinct-3"]))
529 | dist4.append(float(dis_result["distinct-4"]))
530 | text_ppls.append(text_ppl)
531 | gold_ppls.append(gold_ppl)
532 | deltas.append(abs(delta))
533 |
534 |
535 |
536 | evaluate = [bleu_1,bleu_2,bleu_3,bleu_4,times2_repetition_1,times2_repetition_2,times2_repetition_3,times2_repetition_4,times4_repetition_1,times4_repetition_2,times4_repetition_3,times4_repetition_4,self_bleu_1,self_bleu_2,self_bleu_3,self_bleu_4,self_bleu_5,rouge1,rouge2,rougeL,dist1,dist2,dist3,dist4,text_ppls,gold_ppls,deltas,m_score,bert_prec,bert_recall,bert_f1,sen_score]
537 |
538 | evaluate_name = ["bleu_1","bleu_2","bleu_3","bleu_4","times2_repetition_1","times2_repetition_2","times2_repetition_3","times2_repetition_4","times4_repetition_1","times4_repetition_2","times4_repetition_3","times4_repetition_4","self_bleu_1","self_bleu_2","self_bleu_3","self_bleu_4","self_bleu_5","rouge1","rouge2","rougeL","dist1","dist2","dist3","dist4","text_ppls","gold_ppls","deltas","mauve_score","bert_prec","bert_recall","bert_f1","sim"]
539 |
540 | print("folder_path:",path)
541 | for name,eva in zip(evaluate_name,evaluate):
542 | if len(eva) != 0:
543 | print("{}:".format(name),sum(eva)/len(eva))
544 |
545 |
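A typical invocation (illustrative): point --folder at a directory of generated *.json files, where each line is a JSON object with "source", "reference" and "recover" fields, such as the samples under generation_cases/:

python evaluation/eval.py --folder generation_cases/qqp_20

Note that the script loads gpt2-large and a sentence-transformers model onto the GPU at import time (.cuda()), so a CUDA device is effectively required.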
--------------------------------------------------------------------------------
/evaluation/quick_eval.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import math
3 | import pdb
4 | import numpy as np
5 | from argparse import ArgumentParser
6 | from nltk import ngrams
7 | from tokenizer import SimpleTokenizer
8 | from nltk.tokenize import word_tokenize, wordpunct_tokenize
9 | from transformers import AutoTokenizer,AutoModelForCausalLM
10 | import nltk
11 | import copy
12 | import torch
13 | import evaluate
14 | from evaluate import load
15 | import os
16 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
17 | from multiprocessing.pool import Pool
18 | from tqdm import tqdm
19 | import random
20 | from functools import partial
21 | from torchmetrics.text.rouge import ROUGEScore
22 | rougeScore = ROUGEScore()
23 | device = "cuda" if torch.cuda.is_available() else "cpu"
24 | rougeScore.to(device)
25 | import csv
26 | # from bert_score import score
27 |
28 |
29 | tokenizer = SimpleTokenizer(method="nltk")
30 |
31 |
32 |
33 | def bleu(refs, cands):
34 |
35 | result = {}
36 |
37 | for i in range(1, 5):
38 | res = []
39 | for ref,cand in zip(refs,cands):
40 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands, weights=tuple([1./i for j in range(i)]),smoothing_function=SmoothingFunction().method4))
41 | try:
42 | res.append(sentence_bleu([ref], cand, smoothing_function=SmoothingFunction().method4,weights=tuple([1./i for j in range(i)])))
43 |             except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
44 | pass
45 | result["bleu-%d"%i] = np.mean(res)
46 |
47 | # result["bleu-%d"%i] = "%.4f"%(nltk.translate.bleu_score.corpus_bleu([[r] for r in refs], cands))
48 | return result
49 |
50 | def distinct_n_gram(hypn_lst,n):
51 | dist_list_fin = []
52 | for hypn in hypn_lst:
53 | hypn = [hypn]
54 | dist_list = []
55 | for hyp in hypn:
56 | hyp_ngrams = []
57 | hyp_ngrams += nltk.ngrams(hyp.split(), n)
58 | total_ngrams = len(hyp_ngrams)
59 | unique_ngrams = len(list(set(hyp_ngrams)))
60 | if total_ngrams == 0:
61 | continue
62 | dist_list.append(unique_ngrams/total_ngrams)
63 | if total_ngrams == 0:
64 | continue
65 | dist_list_fin.append(np.mean(dist_list))
66 | return np.mean(dist_list_fin)
67 |
68 |
69 |
70 | def repetition_distinct(hyps, times):
71 | dis_result, lex_rep = dict(), dict()
72 | for i in range(1, 5):
73 | num, all_ngram, all_ngram_num = 0, {}, 0
74 | for tokens in hyps:
75 |
76 | ngs = ["_".join(c) for c in ngrams(tokens, i)]
77 | all_ngram_num += len(ngs)
78 | for s in ngs:
79 | if s in all_ngram:
80 | all_ngram[s] += 1
81 | else:
82 | all_ngram[s] = 1
83 | for s in set(ngs):
84 | if ngs.count(s) > times:
85 | num += 1
86 | break
87 | lex_rep["repetition-%d"%i] = "%.4f"%(num / float(len(hyps)))
88 | dis_result["distinct-%d"%i] = "%.4f"%(len(all_ngram) / float(all_ngram_num))
89 |
90 | return dis_result, lex_rep
91 |
92 | def length_(cands):
93 | lengths = []
94 | for i in cands:
95 | lengths.append(len(i))
96 | return sum(lengths) / len(lengths)
97 |
98 |
99 |
100 | def rouge_score(preds, golds):
101 |
102 | rouge_results = {}
103 | rouge1 =[]
104 | rouge2 = []
105 | rougeL = []
106 | for srcs, tgts in zip(preds, golds):
107 | # predictions = [" ".join(srcs)]
108 | # references = [[" ".join(tgts)]]
109 | # rouge.add_batch(predictions=predictions, references=references)
110 | references = " ".join(tgts)
111 | predictions = " ".join(srcs)
112 | res = rougeScore(predictions, references)
113 | rouge1.append(res["rouge1_fmeasure"])
114 | rouge2.append(res["rouge2_fmeasure"])
115 | rougeL.append(res["rougeL_fmeasure"])
116 | rouge_results["rouge1"] = np.mean(rouge1)
117 | rouge_results["rouge2"] = np.mean(rouge2)
118 | rouge_results["rougeL"] = np.mean(rougeL)
119 | # rouge_results = rouge.compute()
120 |
121 |
122 | # rouge_results = rouge.compute(predictions=predictions, references=references, tokenizer=wordpunct_tokenize)
123 | return rouge_results
124 |
125 |
126 | def show_result(res_dict):
127 | for k, v in res_dict.items():
128 | print(f"{k:} : {v:}")
129 |
130 | def ori_pro(s, name=""):
131 | s = s.strip()
132 | # for i in range(10):
133 | # s = s.replace("[%d]"%i, "")
134 | s = s.replace("", " ")
135 | s = " ".join(s.strip().split())
136 | # s = roberta_tokenizer.decode(roberta_tokenizer.convert_tokens_to_ids(roberta_tokenizer.tokenize(s)))
137 | return s
138 |
139 | def pro(token_list, tokenizer):
140 | token_list = "".join(token_list.split(" "))
141 | token_list = tokenizer(token_list)['input_ids']
142 | for i, t in enumerate(token_list):
143 | if t not in [0, 2]:
144 | break
145 | token_list = token_list[i:]
146 | string = tokenizer.decode(token_list, skip_special_tokens=False)
147 | string = string.replace("", " ")
148 | string = string[:string.find("")].strip()
149 | return string
150 |
151 | def bleu_i(weights, all_sentences, smoothing_function, i):
152 | # noinspection PyTypeChecker
153 | return sentence_bleu(
154 | references=all_sentences[:i] + all_sentences[i + 1:],
155 | hypothesis=all_sentences[i],
156 | weights=weights,
157 | smoothing_function=smoothing_function)
158 |
159 | def self_bleu(generations_df, n_sample=1000):
160 |
161 | # import spacy
162 | random.seed(0)
163 | # nlp = spacy.load('en', disable=['parser', 'tagger', 'ner'])
164 | # nlp.add_pipe(nlp.create_pipe('sentencizer'))
165 |
166 | smoothing_function = SmoothingFunction().method1
167 | # all_sentences = []
168 | # for i, row in generations_df.iterrows():
169 | # # gens = row['tokens']
170 | # gens = [[str(token) for token in tokens] for tokens in row['tokens']]# for gen in row['generations']] {'prompt':"", tokens: [[1,2,3], [3,4,5], [5,6,7], ....]}
171 | # all_sentences += gens
172 |
173 | all_sentences = generations_df
174 |
175 | pool = Pool(processes=os.cpu_count())
176 | bleu_scores = []
177 | for n_gram in range(1, 6):
178 |
179 | if n_gram == 1:
180 | weights = (1.0, 0, 0, 0)
181 | elif n_gram == 2:
182 | weights = (0.5, 0.5, 0, 0)
183 | elif n_gram == 3:
184 | weights = (1.0 / 3, 1.0 / 3, 1.0 / 3, 0)
185 | elif n_gram == 4:
186 | weights = (0.25, 0.25, 0.25, 0.25)
187 | elif n_gram == 5:
188 | weights = (0.2, 0.2, 0.2, 0.2, 0.2)
189 | else:
190 | raise ValueError
191 | bleu_scores.append(
192 | list(tqdm(
193 | pool.imap_unordered(
194 | partial(bleu_i, weights, all_sentences, smoothing_function),
195 | random.sample(range(len(all_sentences)), min(n_sample, len(all_sentences)))),
196 | total=min(n_sample, len(all_sentences)),
197 | smoothing=0.0,
198 | desc=f"bleu-{n_gram}")))
199 | # print(f"\n\nbleu-{n_gram} = {sum(bleu_scores[n_gram - 1]) / n_sample}")
200 |
201 | pool.close()
202 | pool.join()
203 |
204 | bleus = []
205 | for n_gram in range(5):
206 |         bleus.append(sum(bleu_scores[n_gram]) / len(bleu_scores[n_gram]))  # use the actual sample count, which may be < n_sample
207 | # print(f"bleu-{n_gram + 1} = {sum(bleu_scores[n_gram]) / n_sample}")
208 |
209 | return bleus
210 |
211 |
212 | tokenizer_gpt = AutoTokenizer.from_pretrained("./models/gpt2")
213 |
214 | if __name__ == "__main__":
215 | parser = ArgumentParser()
216 | parser.add_argument('--source-file', '-s', dest="source_file", help='source file', default="./src.txt")
217 | parser.add_argument('--golden-file', '-t', dest="golden_file", help='Input data file, one golden per line.', default="./gold.txt")
218 | parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.', default="./pred.txt")
219 |     parser.add_argument('--times', '-k', help='calculate the lexical repetition of different datasets', default="4")
220 | parser.add_argument('--model_path_or_name', '-p', help='where the config and tokenizer store')
221 | parser.add_argument('--folder')
222 | parser.add_argument('--save_dir', default="./")
223 | args = parser.parse_args()
224 |
225 | cnt = 0
226 | paths = sorted(glob.glob(glob.escape(f"{args.folder}")+"/*json"))
227 | print(paths)
228 |
229 |
230 |
231 | for path in tqdm(paths):
232 | print(path)
233 | bleu_1 = []
234 | bleu_2 = []
235 | bleu_3 = []
236 | bleu_4 = []
237 | self_bleu_1 = []
238 | self_bleu_2 = []
239 | self_bleu_3 = []
240 | self_bleu_4 = []
241 | self_bleu_5 = []
242 | times2_repetition_1 = []
243 | times2_repetition_2 = []
244 | times2_repetition_3 = []
245 | times2_repetition_4 = []
246 | times4_repetition_1 = []
247 | times4_repetition_2 = []
248 | times4_repetition_3 = []
249 | times4_repetition_4 = []
250 | rouge1 = []
251 | rouge2 = []
252 | rougeL = []
253 | bert_prec = []
254 | bert_recall = []
255 | bert_f1 = []
256 | dist1 = []
257 | dist2 = []
258 | dist3 = []
259 | dist4 = []
260 | text_ppls = []
261 | gold_ppls = []
262 | deltas = []
263 | m_score = []
264 | sen_score = []
265 |
266 | import json
267 | with open(path, "r") as f:
268 | lst = f.readlines()
269 | lst = [json.loads(i) for i in lst]
270 |
271 |
272 | golds = []
273 | preds = []
274 | sources = []
275 | for d in lst:
276 | d["reference"] = d["reference"].replace("[CLS]","").replace("[SEP]","").strip()
277 | d["recover"] = d["recover"].replace("[CLS]","").replace("[SEP]","").strip()
278 | # d["source"] = d["source"].replace("[CLS]","").replace("[SEP]","").strip()
279 | if d["reference"] == "" or d["recover"] == "":
280 | continue
281 |             golds.append(d["reference"])  # already stripped above
282 |             preds.append(d["recover"])
283 | # sources.append(d["source"].replace("[CLS]","").replace("[SEP]","").strip())
284 |
285 | # source_golds = []
286 | # source_preds = []
287 | # for i,j,z in zip(golds,preds,sources):
288 | # source_golds.append(z+" "+i)
289 | # source_preds.append(z+" "+j)
290 |
291 |
292 |
293 |
294 | preds_str = preds
295 | golds_str = golds
296 |
297 | preds_bleu = []
298 | golds_bleu = []
299 | for i,j in zip(preds,golds):
300 | preds_bleu.append(i.split())
301 | golds_bleu.append(j.split())
302 |
303 | preds, golds = [tokenizer.tokenize(i) for i in preds], [tokenizer.tokenize(i) for i in golds]
304 |
305 | bleu_result = bleu(refs = golds_bleu, cands = preds_bleu)
306 |
307 | dis_result, lex_rep2 = repetition_distinct(preds, 2)
308 | dis_result, lex_rep4 = repetition_distinct(preds, 4)
309 |
310 |
311 |
312 |
313 |
314 | len_ = length_(preds)
315 | len_golds = length_(golds)
316 |
317 | # bertscore_result = bert_score(preds, golds)
318 | torch.cuda.empty_cache()
319 | # P, R, F1 = score(preds_str, golds_str, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True)
320 |
321 | # P = torch.mean(P)
322 | # R = torch.mean(R)
323 | # F1 = torch.mean(F1)
324 |
325 | rouge_result = rouge_score(preds, golds)
326 | print(path)
327 |
328 |
329 |
330 |
331 | recovers = []
332 | for i in preds_str:
333 | recover = tokenizer_gpt.encode(i)
334 | recover = list(map(str,recover))
335 | recovers.append(recover)
336 | self_bleus = self_bleu(recovers, n_sample=1000)
337 |
338 | # bert_prec.append(P)
339 | # bert_recall.append(R)
340 | # bert_f1.append(F1)
341 | bleu_1.append(float(bleu_result["bleu-1"]))
342 | bleu_2.append(float(bleu_result["bleu-2"]))
343 | bleu_3.append(float(bleu_result["bleu-3"]))
344 | bleu_4.append(float(bleu_result["bleu-4"]))
345 | times2_repetition_1.append(float(lex_rep2["repetition-1"]))
346 | times2_repetition_2.append(float(lex_rep2["repetition-2"]))
347 | times2_repetition_3.append(float(lex_rep2["repetition-3"]))
348 | times2_repetition_4.append(float(lex_rep2["repetition-4"]))
349 | times4_repetition_1.append(float(lex_rep4["repetition-1"]))
350 | times4_repetition_2.append(float(lex_rep4["repetition-2"]))
351 | times4_repetition_3.append(float(lex_rep4["repetition-3"]))
352 | times4_repetition_4.append(float(lex_rep4["repetition-4"]))
353 | self_bleu_1.append(self_bleus[0])
354 | self_bleu_2.append(self_bleus[1])
355 | self_bleu_3.append(self_bleus[2])
356 | self_bleu_4.append(self_bleus[3])
357 | self_bleu_5.append(self_bleus[4])
358 | rouge1.append(rouge_result["rouge1"])
359 | rouge2.append(rouge_result["rouge2"])
360 | rougeL.append(rouge_result["rougeL"])
361 | dist1.append(float(dis_result["distinct-1"]))
362 | dist2.append(float(dis_result["distinct-2"]))
363 | dist3.append(float(dis_result["distinct-3"]))
364 | dist4.append(float(dis_result["distinct-4"]))
365 |
366 |
367 |
368 |
369 | evaluate = [bleu_1,bleu_2,bleu_3,bleu_4,times2_repetition_1,times2_repetition_2,times2_repetition_3,times2_repetition_4,times4_repetition_1,times4_repetition_2,times4_repetition_3,times4_repetition_4,self_bleu_1,self_bleu_2,self_bleu_3,self_bleu_4,self_bleu_5,rouge1,rouge2,rougeL,dist1,dist2,dist3,dist4]
370 |
371 | evaluate_name = ["bleu_1","bleu_2","bleu_3","bleu_4","times2_repetition_1","times2_repetition_2","times2_repetition_3","times2_repetition_4","times4_repetition_1","times4_repetition_2","times4_repetition_3","times4_repetition_4","self_bleu_1","self_bleu_2","self_bleu_3","self_bleu_4","self_bleu_5","rouge1","rouge2","rougeL","dist1","dist2","dist3","dist4"]
372 |
373 | print("folder_path:",path)
374 | for name,eva in zip(evaluate_name,evaluate):
375 | if len(eva) != 0:
376 | print("{}:".format(name),sum(eva)/len(eva))
377 |
378 |
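A worked example of the metric helpers above (illustrative data; assumes the module-level imports succeed, including the local ./models/gpt2 tokenizer):

from quick_eval import bleu, repetition_distinct

preds = ["the cat sat on the mat".split(), "dogs dogs dogs dogs".split()]
golds = ["a cat sat on the mat".split(), "the dogs barked".split()]
print(bleu(refs=golds, cands=preds))     # {'bleu-1': ..., ..., 'bleu-4': ...}
dis, rep = repetition_distinct(preds, 2)
print(dis)  # distinct-n: unique n-grams / total n-grams over all hypotheses
print(rep)  # repetition-n: fraction of hypotheses with an n-gram repeated more than 2 times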
--------------------------------------------------------------------------------
/evaluation/tokenizer.py:
--------------------------------------------------------------------------------
1 | """A module for Tokenizer"""
2 | import re
3 |
4 | from nltk.tokenize import WordPunctTokenizer
5 | import spacy
6 |
7 | class SimpleTokenizer():
8 | '''
9 |     A simple tokenizer. ``method`` can be ``nltk``, ``space``, ``spacy`` or ``spacy_zh``.
10 | If ``nltk``, use ``WordPunctTokenizer`` from ``nltk.tokenize``.
11 | If ``space``, use ``str.split(" ")``.
12 | If ``spacy``, use ``spacy.load("en_core_web_sm")``.
13 | If ``spacy_zh``, use ``spacy.load("zh_core_web_sm")``.
14 | Arguments:
15 |         method (str): the tokenization method, one of ``nltk``, ``space``, ``spacy``, ``spacy_zh``.
16 |         special_tokens (List[str]): special tokens not to tokenize.
17 | '''
18 | def __init__(self, method, special_tokens = None):
19 | self.method = method
20 | self.special_tokens = special_tokens
21 |
22 | if method == "nltk":
23 | self._callable_tokenizer = WordPunctTokenizer().tokenize
24 | elif method == "space":
25 | self._callable_tokenizer = str.split
26 | elif method == "spacy":
27 | # python -m spacy download en_core_web_sm
28 | self._callable_tokenizer = spacy.load('en_core_web_sm')
29 | elif method == "spacy_zh":
30 | # python -m spacy download zh_core_web_sm
31 | self._callable_tokenizer = spacy.load('zh_core_web_sm')
32 | else:
33 |             raise ValueError('`method` has invalid value {}, should be "nltk", "space", "spacy" or "spacy_zh"'.format(method))
34 |
35 | def tokenize(self, sentence):
36 | '''Tokenize a sentence to a list of tokens.
37 | Arguments:
38 | sentence (str): a sentence to tokenize.
39 | '''
40 | if self.special_tokens is None:
41 | return self._callable_tokenizer(sentence)
42 | regexPattern = '(' + '|'.join(map(re.escape, self.special_tokens)) + ')'
43 | segments = re.split(regexPattern, sentence)
44 | sent = []
45 | for seg in segments:
46 | if seg not in self.special_tokens:
47 | sent += self._callable_tokenizer(seg.strip())
48 | else:
49 | sent += [seg]
50 | return sent
51 |
52 | def convert_tokens_to_sentence(self, tokens):
53 | '''Convert tokens to sentence.
54 |         It usually works like the reverse operation of :meth:`tokenize`, but it is not guaranteed.
55 |         It may work like ``" ".join(tokens)``, but some special conditions and tokens will be taken care of.
56 | Arguments:
57 | tokens(List[str]): tokenized sentence
58 | '''
59 | if self.method == "nltk":
60 | sent = " ".join(tokens)
61 | out_string = sent.replace(' .', '.').replace(' ?', '?'). \
62 | replace(' !', '!').replace(' ,', ',').replace(" ' ", "'"). \
63 | replace(" n't", "n't").replace(" 'm", "'m"). \
64 | replace(" 's", "'s"). \
65 | replace(" 've", "'ve").replace(" 're", "'re")
66 | return out_string
67 | elif self.method == "space" or self.method == "spacy" or self.method == "spacy_zh":
68 | return " ".join(tokens)
69 | else:
70 | raise RuntimeError("No such tokenizer %s" % self.method)
71 |
72 | def name(self):
73 | return "SimpleTokenizer/" + self.method
74 |
75 | class PretrainedTokenizer():
76 |     '''Bases: :class:`.dataloader.Tokenizer`
77 |     A wrapper for ``PreTrainedTokenizer`` from the ``transformers`` package.
78 |     If you don't want to do tokenization on some special tokens, see
79 |     ``transformers.PreTrainedTokenizer.add_special_tokens``.
80 |     Arguments:
81 |         method (str): a pretrained model name containing "gpt" or "bert",
82 |             passed to ``from_pretrained``.
83 |     '''
84 | def __init__(self, method):
85 | if "gpt" in method:
86 | from transformers.tokenization_gpt2 import GPT2Tokenizer
87 | self.tokenizer = GPT2Tokenizer.from_pretrained(method)
88 | elif "bert" in method:
89 | from transformers.tokenization_bert import BertTokenizer
90 | self.tokenizer = BertTokenizer.from_pretrained(method)
91 | else:
92 |             raise ValueError('`method` has invalid value {}, should contain "gpt" or "bert"'.format(method))
93 |
94 | self._tokenizer_class_name = self.tokenizer.__class__.__name__
95 |
96 | def tokenize(self, sentence):
97 | return self.tokenizer.tokenize(sentence)
98 |
99 | def convert_tokens_to_sentence(self, tokens):
100 | return self.tokenizer.convert_tokens_to_string(tokens)
101 |
102 | def name(self):
103 | return "PretrainedTokenizer/" + self.method
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.4
2 | aiosignal==1.3.1
3 | appdirs==1.4.4
4 | async-timeout==4.0.2
5 | asynctest==0.13.0
6 | attrs==23.1.0
7 | blobfile==2.0.2
8 | certifi
9 | charset-normalizer==3.1.0
10 | click==8.1.3
11 | datasets==2.9.0
12 | dill==0.3.6
13 | docker-pycreds==0.4.0
14 | filelock==3.12.0
15 | frozenlist==1.3.3
16 | fsspec==2023.1.0
17 | gitdb==4.0.10
18 | GitPython==3.1.31
19 | huggingface-hub==0.14.1
20 | idna==3.4
21 | importlib-metadata==6.6.0
22 | lxml==4.9.2
23 | multidict==6.0.4
24 | multiprocess==0.70.14
25 | numpy==1.21.6
26 | nvidia-cublas-cu11==11.10.3.66
27 | nvidia-cuda-nvrtc-cu11==11.7.99
28 | nvidia-cuda-runtime-cu11==11.7.99
29 | nvidia-cudnn-cu11==8.5.0.96
30 | packaging==23.1
31 | pandas==1.3.5
32 | pathtools==0.1.2
33 | protobuf==4.22.4
34 | psutil==5.9.5
35 | pyarrow==12.0.0
36 | pycryptodomex==3.17
37 | python-dateutil==2.8.2
38 | pytz==2023.3
39 | PyYAML==6.0
40 | regex==2023.5.5
41 | requests==2.30.0
42 | responses==0.18.0
43 | sentry-sdk==1.22.1
44 | setproctitle==1.3.2
45 | six==1.16.0
46 | smmap==5.0.0
47 | tokenizers==0.12.1
48 | torch==1.13.1
49 | tqdm==4.65.0
50 | transformers==4.19.4
51 | typing_extensions==4.5.0
52 | urllib3==1.26.15
53 | wandb==0.15.2
54 | xxhash==3.2.0
55 | yarl==1.9.2
56 | zipp==3.15.0
57 |
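To reproduce the environment (note the file is named requirement.txt rather than the conventional requirements.txt):

pip install -r requirement.txt

Some packages imported by the evaluation scripts (nltk, torchmetrics, bert_score, mauve, evaluate, spacy) are not pinned here and must be installed separately; pip package names may differ from import names (e.g. mauve is typically published as mauve-text).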
--------------------------------------------------------------------------------
/sample_seq2seq.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of seq2seq text samples from a model and save them
3 | as JSON lines. This can be used to produce samples for evaluation.
4 | """
5 |
6 | import argparse
7 | import os, json
8 |
9 |
10 | import numpy as np
11 | import torch as th
12 | import time
13 | import torch.distributed as dist
14 | from transformers import set_seed
15 | from diffuseq.rounding import denoised_fn_round, get_weights
16 | from diffuseq.text_datasets import load_data_text
17 |
18 | # from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
19 |
20 |
21 | from diffuseq.utils import dist_util, logger
22 | from functools import partial
23 | from basic_utils import (
24 | load_defaults_config,
25 | create_model_and_diffusion,
26 | add_dict_to_argparser,
27 | args_to_dict,
28 | load_model_emb,
29 | load_tokenizer
30 | )
31 |
32 | def create_argparser():
33 | defaults = dict(model_path='', step=0, out_dir='',)
34 | decode_defaults = dict(split='valid', clamp_step=0, seed2=105, clip_denoised=False, eta=-1.0, decode_respacing="", \
35 | top_l = 0.0 , top_k = 0, noised_l = 0.0, scale_end = "", top_p = 0.0, tau = 1.0, clamp_skip = -1, p_d="")
36 | defaults.update(load_defaults_config())
37 | defaults.update(decode_defaults)
38 | parser = argparse.ArgumentParser()
39 | add_dict_to_argparser(parser, defaults)
40 | return parser
41 |
42 |
43 | def main():
44 | args = create_argparser().parse_args()
45 |
46 | dist_util.setup_dist()
47 | logger.configure()
48 |
49 | # load configurations.
50 | config_path = os.path.join(os.path.split(args.model_path)[0], "training_args.json")
51 | print(config_path)
52 | # sys.setdefaultencoding('utf-8')
53 | with open(config_path, 'rb', ) as f:
54 | training_args = json.load(f)
55 | training_args['batch_size'] = args.batch_size
56 | args.__dict__.update(training_args)
57 |
58 | # mycode
59 | args.diffusion_steps = args.step
60 | args.timestep_respacing = args.decode_respacing
61 |
62 |
63 | logger.log("### Creating model and diffusion...")
64 | model, diffusion = create_model_and_diffusion(
65 | **args_to_dict(args, load_defaults_config().keys())
66 | )
67 |
68 |
69 | model.load_state_dict(
70 | dist_util.load_state_dict(args.model_path, map_location="cpu")
71 | )
72 |
73 | pytorch_total_params = sum(p.numel() for p in model.parameters())
74 | logger.log(f'### The parameter count is {pytorch_total_params}')
75 |
76 | model.to(dist_util.dev())
77 | model.eval()
78 |
79 | tokenizer = load_tokenizer(args)
80 | model_emb, tokenizer = load_model_emb(args, tokenizer)
81 |
82 | model_emb.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu())
83 | model_emb_copy = get_weights(model_emb, args)
84 |
85 | set_seed(args.seed2)
86 |
87 | print("### Sampling...on", args.split)
88 |
89 | ## load data
90 | data_valid = load_data_text(
91 | batch_size=args.batch_size,
92 | seq_len=args.seq_len,
93 | deterministic=True,
94 | data_args=args,
95 | split=args.split,
96 | loaded_vocab=tokenizer,
97 |         model_emb=model_emb.cpu(),  # use the same embedding weights as the training data
98 | loop=False
99 | )
100 |
101 | start_t = time.time()
102 |
103 | # batch, cond = next(data_valid)
104 | # print(batch.shape)
105 |
106 | model_base_name = os.path.basename(os.path.split(args.model_path)[0]) + f'.{os.path.split(args.model_path)[1]}'
107 | out_dir = os.path.join(args.out_dir, f"{model_base_name.split('.ema')[0]}")
108 | if not os.path.isdir(out_dir):
109 | os.mkdir(out_dir)
110 |
111 | out_path = os.path.join(out_dir, f"ema{model_base_name.split('.ema')[1]}.samples")
112 | if not os.path.isdir(out_path):
113 | os.mkdir(out_path)
114 | out_path = os.path.join(out_path, f"seed{args.seed2}_step{args.step}_eta{args.eta}_respace_{args.decode_respacing}_{args.split}_topp{args.top_p}_tau{args.tau}\
115 | _topl{args.top_l}_topk{args.top_k}_noisedl{args.noised_l}_scend{args.scale_end}_clamp_skip{args.clamp_skip}_pd_{args.p_d}.json")
116 | # fout = open(out_path, 'a')
117 |
118 | all_test_data = []
119 |
120 | try:
121 | while True:
122 | batch, cond = next(data_valid)
123 | # print(batch.shape)
124 | all_test_data.append(cond)
125 |
126 | except StopIteration:
127 | print('### End of reading iteration...')
128 |
129 | from tqdm import tqdm
130 |
131 | # if args.ana_save: os.mkdir(f"/data/tzc/DiffuSeqs/myDiffuSeq/analyze_data/qqp-valid-embedding/{args.ana_save}")
132 |
133 |
134 | t_lst = []
135 | for cond in tqdm(all_test_data):
136 |
137 | input_ids_x = cond.pop('input_ids').to(dist_util.dev())
138 | x_start = model.get_embeds(input_ids_x)
139 | input_ids_mask = cond.pop('input_mask')
140 | input_ids_mask_ori = input_ids_mask
141 |
142 | noise = th.randn_like(x_start)
143 | input_ids_mask = th.broadcast_to(input_ids_mask.unsqueeze(dim=-1), x_start.shape).to(dist_util.dev())
144 | x_noised = th.where(input_ids_mask==0, x_start, noise)
145 |
146 | model_kwargs = {}
147 |
148 | step_gap = 1
149 |
150 | # if args.step == args.diffusion_steps:
151 | # args.use_ddim = False
152 | # step_gap = 1
153 | # else:
154 | # args.use_ddim = True
155 | # step_gap = args.diffusion_steps//args.step
156 |         if args.eta >= 0:
157 |             args.use_ddim = True   # a non-negative eta selects DDIM sampling
158 |         else:
159 |             args.use_ddim = False  # eta < 0 keeps the ancestral p_sample loop
160 |
161 |         if args.p_d:
162 |             args.use_ddim = False  # a p_d signal always routes through p_sample_loop (passed below as ddim_signal)
163 |
164 | sample_fn = (
165 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
166 | )
167 |
168 | # sample_shape = (batch.shape[0], args.seq_len, args.hidden_dim)
169 | sample_shape = (x_start.shape[0], args.seq_len, args.hidden_dim)
170 |
171 | t1 = time.time()
172 | samples = sample_fn(
173 | model,
174 | sample_shape,
175 | noise=x_noised,
176 | clip_denoised=args.clip_denoised,
177 | denoised_fn=partial(denoised_fn_round, args, model_emb_copy.cuda()),
178 | model_kwargs=model_kwargs,
179 | top_p=args.top_p,
180 | clamp_step=args.clamp_step,
181 | clamp_first=True,
182 | mask=input_ids_mask,
183 | x_start=x_start,
184 | gap=step_gap,
185 | eta=args.eta,
186 | ddim_signal=args.p_d,
187 | # top_l = args.top_l,
188 | # top_k = args.top_k,
189 | # noised_l = args.noised_l,
190 | # scale_end = args.scale_end,
191 | )
192 | t2 = time.time()
193 | t_lst.append(t2 - t1)
194 |
195 | model_emb_copy = model_emb_copy.cpu()
196 | # print(samples[0].shape) # samples for each step
197 |
198 | sample = samples[-1]
199 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
200 | dist.all_gather(gathered_samples, sample)
201 | all_sentence = [sample.cpu().numpy() for sample in gathered_samples]
202 |
203 | # print('sampling takes {:.2f}s .....'.format(time.time() - start_t))
204 |
205 | word_lst_recover = []
206 | word_lst_ref = []
207 | word_lst_source = []
208 |
209 |
210 | arr = np.concatenate(all_sentence, axis=0)
211 | x_t = th.tensor(arr).cuda()
212 | print('decoding for seq2seq', )
213 | print(arr.shape)
214 |
215 | reshaped_x_t = x_t
216 | logits = model.get_logits(reshaped_x_t) # bsz, seqlen, vocab
217 |
218 | cands = th.topk(logits, k=1, dim=-1)
219 | sample = cands.indices
220 | # tokenizer = load_tokenizer(args)
221 |
222 | for seq, input_mask in zip(cands.indices, input_ids_mask_ori):
223 |             len_x = args.seq_len - sum(input_mask).tolist()  # number of source positions (mask == 0)
224 |             tokens = tokenizer.decode_token(seq[len_x:])  # decode only the generated target segment
225 | word_lst_recover.append(tokens)
226 |
227 | for seq, input_mask in zip(input_ids_x, input_ids_mask_ori):
228 | # tokens = tokenizer.decode_token(seq)
229 | len_x = args.seq_len - sum(input_mask).tolist()
230 | word_lst_source.append(tokenizer.decode_token(seq[:len_x]))
231 | word_lst_ref.append(tokenizer.decode_token(seq[len_x:]))
232 |
233 |
234 | fout = open(out_path, 'a')
235 | for (recov, ref, src) in zip(word_lst_recover, word_lst_ref, word_lst_source):
236 | print(json.dumps({"recover": recov, "reference": ref, "source": src}), file=fout)
237 | fout.close()
238 |
239 | # for (recov, ref, src) in zip(word_lst_recover, word_lst_ref, word_lst_source):
240 | # print(json.dumps({"recover": recov, "reference": ref, "source": src}))
241 |
242 |
243 | print('### Total takes {:.2f}s .....'.format(time.time() - start_t))
244 | print(f'### Written the decoded output to {out_path}')
245 |     print('per-batch sampling times (s):', t_lst)
246 |     print('mean per-batch sampling time (s):', np.mean(t_lst))
247 |
248 | if __name__ == "__main__":
249 | main()
250 |
--------------------------------------------------------------------------------
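Note: the decode loop above seeds the reverse process with a partially noised sequence: positions where input_mask == 0 (the source segment) keep their ground-truth embeddings, while target positions are replaced by Gaussian noise, so only the target is denoised. A minimal, self-contained sketch of that masking step (tensor sizes are illustrative, not taken from the repo):

    import torch as th

    batch, seq_len, hidden = 2, 8, 16           # illustrative sizes
    x_start = th.randn(batch, seq_len, hidden)  # embeddings of [source; target]
    mask = th.zeros(batch, seq_len)
    mask[:, 4:] = 1                             # 0 = source (kept), 1 = target (noised)

    noise = th.randn_like(x_start)
    mask_b = th.broadcast_to(mask.unsqueeze(-1), x_start.shape)
    x_noised = th.where(mask_b == 0, x_start, noise)

    assert th.equal(x_noised[:, :4], x_start[:, :4])  # source embeddings untouched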
/scripts/eval_seq2seq.py:
--------------------------------------------------------------------------------
1 | import os, sys, glob, json
2 | import numpy as np
3 | import argparse
4 | import torch
5 |
6 | from torchmetrics.text.rouge import ROUGEScore
7 | rougeScore = ROUGEScore()
8 | from bert_score import score
9 |
10 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
11 | import nltk
12 |
13 | def get_bleu(recover, reference):
14 |     return sentence_bleu([reference.split()], recover.split(), smoothing_function=SmoothingFunction().method4,)
15 |
16 | def selectBest(sentences):
17 | selfBleu = [[] for i in range(len(sentences))]
18 | for i, s1 in enumerate(sentences):
19 | for j, s2 in enumerate(sentences):
20 | score = get_bleu(s1, s2)
21 | selfBleu[i].append(score)
22 | for i, s1 in enumerate(sentences):
23 | selfBleu[i][i] = 0
24 | idx = np.argmax(np.sum(selfBleu, -1))
25 | return sentences[idx]
26 |
29 |
30 | def diversityOfSet(sentences):
31 | selfBleu = []
32 | # print(sentences)
33 | for i, sentence in enumerate(sentences):
34 | for j in range(i+1, len(sentences)):
35 | # print(sentence, sentences[j])
36 | score = get_bleu(sentence, sentences[j])
37 | selfBleu.append(score)
38 | if len(selfBleu)==0:
39 | selfBleu.append(0)
40 | div4 = distinct_n_gram_inter_sent(sentences, 4)
41 | return np.mean(selfBleu), div4
42 |
43 |
44 | def distinct_n_gram(hypn, n):
45 |     dist_list = []
46 |     for hyp in hypn:
47 |         hyp_ngrams = list(nltk.ngrams(hyp.split(), n))
48 |         total_ngrams = len(hyp_ngrams)
49 |         if total_ngrams == 0:
50 |             dist_list.append(0.0)  # empty hypothesis scores 0 instead of aborting the whole loop
51 |             continue
52 |         unique_ngrams = len(set(hyp_ngrams))
53 |         dist_list.append(unique_ngrams / total_ngrams)
54 |     return np.mean(dist_list)
55 |
56 |
57 | def distinct_n_gram_inter_sent(hypn, n):
58 | hyp_ngrams = []
59 | for hyp in hypn:
60 | hyp_ngrams += nltk.ngrams(hyp.split(), n)
61 | total_ngrams = len(hyp_ngrams)
62 | unique_ngrams = len(list(set(hyp_ngrams)))
63 | if total_ngrams == 0:
64 | return 0
65 | dist_n = unique_ngrams/total_ngrams
66 | return dist_n
67 |
68 | if __name__ == '__main__':
69 |
70 | parser = argparse.ArgumentParser(description='decoding args.')
71 | parser.add_argument('--folder', type=str, default='', help='path to the folder of decoded texts')
72 | parser.add_argument('--mbr', action='store_true', help='mbr decoding or not')
73 | parser.add_argument('--sos', type=str, default='[CLS]', help='start token of the sentence')
74 | parser.add_argument('--eos', type=str, default='[SEP]', help='end token of the sentence')
75 | parser.add_argument('--sep', type=str, default='[SEP]', help='sep token of the sentence')
76 | parser.add_argument('--pad', type=str, default='[PAD]', help='pad token of the sentence')
77 |
78 | args = parser.parse_args()
79 |
80 | files = sorted(glob.glob(f"{args.folder}/*json"))
81 | sample_num = 0
82 | with open(files[0], 'r') as f:
83 | for row in f:
84 | sample_num += 1
85 |
86 | sentenceDict = {}
87 | referenceDict = {}
88 | sourceDict = {}
89 | for i in range(sample_num):
90 | sentenceDict[i] = []
91 | referenceDict[i] = []
92 | sourceDict[i] = []
93 |
94 | div4 = []
95 | selfBleu = []
96 |
97 | AVG_bleu = []
98 | AVG_rougel = []
99 | AVG_dist2 = []
100 | AVG_dist4 = []
101 | AVG_f1 = []
102 | AVG_len = []
103 |
104 | for path in files:
105 | print(path)
106 | sources = []
107 | references = []
108 | recovers = []
109 | bleu = []
110 | rougel = []
111 | avg_len = []
112 | dist2 = []
113 | dist4 = []
114 |
115 | with open(path, 'r') as f:
116 | cnt = 0
117 | for row in f:
118 |
119 | source = json.loads(row)['source'].strip()
120 | reference = json.loads(row)['reference'].strip()
121 | recover = json.loads(row)['recover'].strip()
122 | source = source.replace(args.eos, '').replace(args.sos, '')
123 | reference = reference.replace(args.eos, '').replace(args.sos, '').replace(args.sep, '')
124 | recover = recover.replace(args.eos, '').replace(args.sos, '').replace(args.sep, '').replace(args.pad, '')
125 |
126 | sources.append(source)
127 | references.append(reference)
128 | recovers.append(recover)
129 | avg_len.append(len(recover.split(' ')))
130 |
131 |
132 | bleu.append(get_bleu(recover, reference))
133 | rougel.append(rougeScore(recover, reference)['rougeL_fmeasure'].tolist())
134 | dist2.append(distinct_n_gram([recover], 2))
135 | dist4.append(distinct_n_gram([recover], 4))
136 |
137 | sentenceDict[cnt].append(recover)
138 | referenceDict[cnt].append(reference)
139 | sourceDict[cnt].append(source)
140 | cnt += 1
141 |
142 | P, R, F1 = score(recovers, references, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True)
143 |
144 | print('*'*30)
145 | print('avg BLEU score', np.mean(bleu))
146 | AVG_bleu.append(np.mean(bleu).item())
147 | print('avg ROUGE-L score', np.mean(rougel))
148 | AVG_rougel.append(np.mean(rougel).item())
149 |     print('avg bertscore', torch.mean(F1))
150 | AVG_f1.append(torch.mean(F1).item())
151 | print('avg dist2 score', np.mean(dist2))
152 | AVG_dist2.append(np.mean(dist2).item())
153 | print('avg dist4 score', np.mean(dist4))
154 | AVG_dist4.append(np.mean(dist4).item())
155 | print('avg len', np.mean(avg_len))
156 | AVG_len.append(np.mean(avg_len).item())
157 |
158 | if len(files)>1:
159 | if not args.mbr:
160 | print('*'*30)
161 | print('Compute diversity...')
162 | print('*'*30)
163 | for k, v in sentenceDict.items():
164 | if len(v) == 0:
165 | continue
166 | sb, d4 = diversityOfSet(v)
167 | selfBleu.append(sb)
168 | div4.append(d4)
169 |
170 | print('avg selfBleu score', np.mean(selfBleu))
171 | print('avg div4 score', np.mean(div4))
172 |
173 | else:
174 | print('*'*30)
175 | print('MBR...')
176 | print('*'*30)
177 | bleu = []
178 | rougel = []
179 | avg_len = []
180 | dist2 = []
181 | dist4 = []
182 | recovers = []
183 | references = []
184 | sources = []
185 |
186 |
187 | for k, v in sentenceDict.items():
188 | if len(v) == 0 or len(referenceDict[k]) == 0:
189 | continue
190 |
191 | recovers.append(selectBest(v))
192 | references.append(referenceDict[k][0])
193 | sources.append(sourceDict[k][0])
194 |
195 | for (source, reference, recover) in zip(sources, references, recovers):
196 | bleu.append(get_bleu(recover, reference))
197 | rougel.append(rougeScore(recover, reference)['rougeL_fmeasure'].tolist())
198 | avg_len.append(len(recover.split(' ')))
199 | dist2.append(distinct_n_gram([recover], 2))
200 | dist4.append(distinct_n_gram([recover], 4))
201 |
202 | # print(len(recovers), len(references), len(recovers))
203 |
204 | P, R, F1 = score(recovers, references, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True)
205 |
206 | print('*'*30)
207 | print('MBR BLEU score', np.mean(bleu))
208 |         print('MBR ROUGE-L score', np.mean(rougel))
209 |         print('MBR bertscore', torch.mean(F1))
210 | print('MBR dist2 score', np.mean(dist2))
211 | print('MBR dist4 score', np.mean(dist4))
212 |
213 | print('*'*30)
214 | print('AVG BLEU score', sum(AVG_bleu) / len(AVG_bleu), len(AVG_bleu))
215 |     print('AVG ROUGE-L score', sum(AVG_rougel) / len(AVG_rougel))
216 |     print('AVG bertscore', sum(AVG_f1) / len(AVG_f1))
217 | print('AVG dist2 score', sum(AVG_dist2) / len(AVG_dist2))
218 | print('AVG dist4 score', sum(AVG_dist4) / len(AVG_dist4))
219 |
220 | print('*'*30)
221 | print('BLEUs', AVG_bleu)
222 |     print('ROUGE-Ls', AVG_rougel)
223 |     print('bertscores', AVG_f1)
224 | print('dist2 scores', AVG_dist2)
225 | print('dist4 scores', AVG_dist4)
226 |
--------------------------------------------------------------------------------
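Note: a toy illustration of the selection and diversity helpers above (the candidate strings are made up). selectBest scores each candidate by its summed BLEU against the full candidate set and returns the argmax, i.e. the candidate that agrees most with the others (an MBR-style pick); diversityOfSet returns the mean pairwise self-BLEU plus the inter-sentence distinct-4 ratio.

    candidates = [
        "how can i learn python quickly ?",
        "how do i learn python quickly ?",
        "what is the capital of france ?",
    ]
    best = selectBest(candidates)                 # likely one of the two paraphrases
    self_bleu, div4 = diversityOfSet(candidates)
    print(best, self_bleu, div4)                  # low self-BLEU / high div4 => diverse set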
/scripts/run_decode.py:
--------------------------------------------------------------------------------
1 | import os, sys, glob
2 | import argparse
3 | sys.path.append('.')
4 | sys.path.append('..')
5 |
6 | if __name__ == '__main__':
7 |
8 | parser = argparse.ArgumentParser(description='decoding args.')
9 | parser.add_argument('--model_dir', type=str, default='', help='path to the folder of diffusion model')
10 | parser.add_argument('--seed', type=int, default=101, help='random seed')
11 |     parser.add_argument('--step', type=int, default=2000, help='if smaller than the number of diffusion training steps (e.g. 1000), DDIM sampling is used')
12 | parser.add_argument('--step_gap', type=int, default=0)
13 |
14 | parser.add_argument('--bsz', type=int, default=50, help='batch size')
15 | parser.add_argument('--split', type=str, default='test', choices=['train', 'valid', 'test'], help='dataset split used to decode')
16 |
17 |     parser.add_argument('--top_p', type=float, default=-1, help='top-p (nucleus) threshold used in sampling; a negative value disables it')
18 | parser.add_argument('--pattern', type=str, default='ema', help='training pattern')
19 |
20 | args = parser.parse_args()
21 |
22 | # set working dir to the upper folder
23 | abspath = os.path.abspath(sys.argv[0])
24 | dname = os.path.dirname(abspath)
25 | dname = os.path.dirname(dname)
26 | os.chdir(dname)
27 |
28 | output_lst = []
29 | for lst in glob.glob(args.model_dir):
30 | print(lst)
31 | checkpoints = sorted(glob.glob(f"{lst}/{args.pattern}*.pt"))[::-1]
32 |
33 | out_dir = 'generation_outputs'
34 | if not os.path.isdir(out_dir):
35 | os.mkdir(out_dir)
36 |
37 | for checkpoint_one in checkpoints:
38 |
39 | COMMAND = f'python sample_seq2seq.py ' \
40 | f'--model_path {checkpoint_one} --step {args.step} ' \
41 | f'--batch_size {args.bsz} --seed2 {args.seed} --split {args.split} ' \
42 | f'--out_dir {out_dir} --top_p {args.top_p} --step_gap {args.step_gap}'
43 | print(COMMAND)
44 |
45 | os.system(COMMAND)
46 |
47 | print('#'*30, 'decoding finished...')
--------------------------------------------------------------------------------
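Note: an example invocation of run_decode.py (the paths, seed and step are placeholders). --model_dir is passed through glob.glob, so a wildcard can sweep several checkpoint folders at once, and every matching ema*.pt checkpoint found is decoded via sample_seq2seq.py:

    python scripts/run_decode.py \
        --model_dir "diffusion_models/diffuseq_qqp_*" \
        --seed 123 --step 2000 --split test --bsz 50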
/scripts/run_train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 | import time
5 | sys.path.append('.')
6 |
7 | if __name__ == '__main__':
8 |
9 | parser = argparse.ArgumentParser(description='training args.')
10 | parser.add_argument('--dataset', type=str, default='', help='name of training dataset')
11 | parser.add_argument('--data_dir', type=str, default='', help='path to training dataset')
12 |
13 |     parser.add_argument('--noise_schedule', type=str, default='cosine', choices=['linear', 'cosine', 'sqrt', 'trunc_cos', 'trunc_lin', 'pw_lin'], help='noise schedule of the diffusion process')
14 | parser.add_argument('--diff_steps', type=int, default=4000, help='diffusion steps')
15 | parser.add_argument('--schedule_sampler', type=str, default='uniform', choices=['uniform', 'lossaware', 'fixstep'], help='schedule sampler of timesteps')
16 |
17 | parser.add_argument('--seq_len', type=int, default=128, help='max len of input sequence')
18 | parser.add_argument('--hidden_t_dim', type=int, default=128, help='hidden size of time embedding')
19 | parser.add_argument('--hidden_dim', type=int, default=128, help='hidden size of word embedding')
20 | parser.add_argument('--learning_steps', type=int, default=40000, help='total steps of learning')
21 | parser.add_argument('--save_interval', type=int, default=10000, help='save step')
22 | parser.add_argument('--resume_checkpoint', type=str, default='none', help='path to resume checkpoint, like xxx/xxx.pt')
23 | parser.add_argument('--lr', type=float, default=1e-04, help='learning rate')
24 | parser.add_argument('--bsz', type=int, default=64, help='batch size')
25 | parser.add_argument('--microbatch', type=int, default=64, help='microbatch size')
26 | parser.add_argument('--seed', type=int, default=101, help='random seed')
27 |
28 | parser.add_argument('--use_fp16', action='store_true')
29 | parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='config of pre-trained models')
30 | parser.add_argument('--vocab', type=str, default='bert', help='use bert vocab or load external vocab dict if given as path')
31 | parser.add_argument('--use_plm_init', type=str, default='no', choices=['no', 'bert'], help='load init parameter from the pre-trained lm')
32 |
33 |     parser.add_argument('--notes', type=str, default='-', help='training notes or special args')
34 | parser.add_argument('--app', type=str, default='', help='other input args')
35 |
36 | parser.add_argument('--simi_lambda', type=float, default = 0.01)
37 | parser.add_argument('--simi_step', type=int, default = 10)
38 | parser.add_argument('--simi_penalty', type=str, default = "cosine")
39 |
40 | parser.add_argument('--near_step', type=int, default = 0)
41 | parser.add_argument('--near_lambda', type=float, default = 0)
42 | parser.add_argument('--far_step', type=int, default = 0)
43 | parser.add_argument('--far_lambda', type=float, default = 0)
44 | parser.add_argument('--simi_noise', type=float, default = 0.05)
45 |
46 | args = parser.parse_args()
47 |
48 | # set working dir to the upper folder
49 | abspath = os.path.abspath(sys.argv[0])
50 | dname = os.path.dirname(abspath)
51 | dname = os.path.dirname(dname)
52 | os.chdir(dname)
53 |
54 | folder_name = "post_train_diffusion_models/"
55 |
56 | if int(os.environ['LOCAL_RANK']) == 0:
57 | if not os.path.isdir(folder_name):
58 | os.mkdir(folder_name)
59 |
60 | Model_FILE = f"diffuseq_{args.dataset}_h{args.hidden_dim}_lr{args.lr}" \
61 | f"_t{args.diff_steps}_{args.noise_schedule}_{args.schedule_sampler}" \
62 | f"_seed{args.seed}"
63 | if args.notes:
64 | args.notes += time.strftime("%Y%m%d-%H:%M:%S")
65 | Model_FILE = Model_FILE + f'_{args.notes}'
66 | Model_FILE = os.path.join(folder_name, Model_FILE)
67 |
68 | if int(os.environ['LOCAL_RANK']) == 0:
69 | if not os.path.isdir(Model_FILE):
70 | os.mkdir(Model_FILE)
71 |
72 | COMMANDLINE = f" OPENAI_LOGDIR={Model_FILE} " \
73 | f"TOKENIZERS_PARALLELISM=false " \
74 | f"python train.py " \
75 | f"--checkpoint_path {Model_FILE} " \
76 | f"--dataset {args.dataset} --data_dir {args.data_dir} --vocab {args.vocab} --use_plm_init {args.use_plm_init} " \
77 | f"--lr {args.lr} " \
78 | f"--batch_size {args.bsz} --microbatch {args.microbatch} " \
79 | f"--diffusion_steps {args.diff_steps} " \
80 | f"--noise_schedule {args.noise_schedule} " \
81 | f"--schedule_sampler {args.schedule_sampler} --resume_checkpoint {args.resume_checkpoint} " \
82 | f"--seq_len {args.seq_len} --hidden_t_dim {args.hidden_t_dim} --seed {args.seed} " \
83 | f"--hidden_dim {args.hidden_dim} " \
84 | f"--learning_steps {args.learning_steps} --save_interval {args.save_interval} " \
85 | f"--config_name {args.config_name} --notes {args.notes} " \
86 | f"--simi_lambda {args.simi_lambda} --simi_step {args.simi_step} --simi_penalty {args.simi_penalty} " \
87 | f"--near_lambda {args.near_lambda} --near_step {args.near_step} " \
88 |         f"--far_lambda {args.far_lambda} --far_step {args.far_step} --simi_noise {args.simi_noise}"
89 |
90 |
91 | COMMANDLINE += " " + args.app
92 |
93 | if int(os.environ['LOCAL_RANK']) == 0:
94 | with open(os.path.join(Model_FILE, 'saved_bash.sh'), 'w') as f:
95 | print(COMMANDLINE, file=f)
96 |
97 | print(COMMANDLINE)
98 |
99 | os.system(COMMANDLINE)
100 |
101 | print("FINISHED!!!!")
102 | time.sleep(300)
103 |
104 | # OPENAI_LOGDIR=diffusion_models/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_qqp20221123-21:24:15 TOKENIZERS_PARALLELISM=false python train.py --checkpoint_path diffusion_models/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_qqp20221123-21:24:15 --dataset qqp --data_dir datasets/QQP --vocab bert --use_plm_init no --lr 0.0001 --batch_size 2048 --microbatch 64 --diffusion_steps 2000 --noise_schedule sqrt --schedule_sampler lossaware --resume_checkpoint none --seq_len 128 --hidden_t_dim 128 --seed 102 --hidden_dim 128 --learning_steps 50000 --save_interval 10000 --config_name bert-base-uncased --notes qqp20221123-21:24:15 --simi_lambda 0.01 --simi_step 10 --simi_penalty cosine
--------------------------------------------------------------------------------
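Note: run_train.py reads LOCAL_RANK directly from the environment, so it must run either under a distributed launcher that sets it or with LOCAL_RANK exported by hand. An example single-GPU invocation, with dataset arguments modeled on the commented command above:

    LOCAL_RANK=0 python scripts/run_train.py \
        --dataset qqp --data_dir datasets/QQP \
        --diff_steps 2000 --noise_schedule sqrt --schedule_sampler lossaware \
        --seq_len 128 --hidden_dim 128 --lr 1e-4 --bsz 64 --seed 102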
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on text (seq2seq) data.
3 | """
4 |
5 | import argparse
6 | import json, torch, os
7 | import numpy as np
8 | from diffuseq.utils import dist_util, logger
9 | from diffuseq.text_datasets import load_data_text
10 | from diffuseq.step_sample import create_named_schedule_sampler
11 | from basic_utils import (
12 | load_defaults_config,
13 | create_model_and_diffusion,
14 | args_to_dict,
15 | add_dict_to_argparser,
16 | load_model_emb,
17 | load_tokenizer
18 | )
19 | from train_util import TrainLoop
20 | from transformers import set_seed
21 | import wandb
22 |
23 | ### custom your wandb setting here ###
24 | # os.environ["WANDB_API_KEY"] = ""
25 | os.environ["WANDB_MODE"] = "offline"
26 |
27 | def create_argparser():
28 | defaults = dict()
29 | defaults.update(load_defaults_config())
30 | parser = argparse.ArgumentParser()
31 | add_dict_to_argparser(parser, defaults) # update latest args according to argparse
32 | return parser
33 |
34 | def main():
35 | args = create_argparser().parse_args()
36 | set_seed(args.seed)
37 | dist_util.setup_dist()
38 | logger.configure()
39 | logger.log("### Creating data loader...")
40 |
41 | tokenizer = load_tokenizer(args)
42 | model_weight, tokenizer = load_model_emb(args, tokenizer)
43 |
44 | data = load_data_text(
45 | batch_size=args.batch_size,
46 | seq_len=args.seq_len,
47 | data_args = args,
48 | loaded_vocab=tokenizer,
49 | model_emb=model_weight # use model's weights as init
50 | )
51 | next(data)
52 |
53 | data_valid = load_data_text(
54 | batch_size=args.batch_size,
55 | seq_len=args.seq_len,
56 | data_args=args,
57 | split='valid',
58 | deterministic=True,
59 | loaded_vocab=tokenizer,
60 |         model_emb=model_weight # use the same embedding weights as the training data
61 | )
62 |
63 | print('#'*30, 'size of vocab', args.vocab_size)
64 |
65 | logger.log("### Creating model and diffusion...")
66 | # print('#'*30, 'CUDA_VISIBLE_DEVICES', os.environ['CUDA_VISIBLE_DEVICES'])
67 | model, diffusion = create_model_and_diffusion(
68 | **args_to_dict(args, load_defaults_config().keys())
69 | )
70 | # print('#'*30, 'cuda', dist_util.dev())
71 | model.to(dist_util.dev()) # DEBUG **
72 | # model.cuda() # DEBUG **
73 |
74 | pytorch_total_params = sum(p.numel() for p in model.parameters())
75 |
76 | logger.log(f'### The parameter count is {pytorch_total_params}')
77 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
78 |
79 | logger.log(f'### Saving the hyperparameters to {args.checkpoint_path}/training_args.json')
80 | with open(f'{args.checkpoint_path}/training_args.json', 'w') as f:
81 | json.dump(args.__dict__, f, indent=2)
82 |
83 | if ('LOCAL_RANK' not in os.environ) or (int(os.environ['LOCAL_RANK']) == 0):
84 | wandb.init(
85 | project=os.getenv("WANDB_PROJECT", "DiffuSeq"),
86 | name=args.checkpoint_path,
87 | )
88 | wandb.config.update(args.__dict__, allow_val_change=True)
89 |
90 | logger.log("### Training...")
91 |
92 | TrainLoop(
93 | model=model,
94 | diffusion=diffusion,
95 | data=data,
96 | batch_size=args.batch_size,
97 | microbatch=args.microbatch,
98 | lr=args.lr,
99 | ema_rate=args.ema_rate,
100 | log_interval=args.log_interval,
101 | save_interval=args.save_interval,
102 | resume_checkpoint=args.resume_checkpoint,
103 | use_fp16=args.use_fp16,
104 | fp16_scale_growth=args.fp16_scale_growth,
105 | schedule_sampler=schedule_sampler,
106 | weight_decay=args.weight_decay,
107 | learning_steps=args.learning_steps,
108 | checkpoint_path=args.checkpoint_path,
109 | gradient_clipping=args.gradient_clipping,
110 | eval_data=data_valid,
111 | eval_interval=args.eval_interval,
112 | simi_penalty = args.simi_penalty,
113 | simi_lambda = args.simi_lambda,
114 | simi_step = args.simi_step,
115 | near_step = args.near_step,
116 | far_step = args.far_step,
117 | near_lambda = args.near_lambda,
118 | far_lambda = args.far_lambda,
119 | simi_noise = args.simi_noise
120 | ).run_loop()
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
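Note: train.py dumps the full argument namespace to training_args.json inside the checkpoint folder, and sample_seq2seq.py later reloads that file and overlays it onto its own namespace; this round trip is how decoding recovers the training-time model configuration. A minimal sketch of the mechanism (the Namespace fields here are illustrative):

    import json, argparse

    args = argparse.Namespace(hidden_dim=128, seq_len=128, diffusion_steps=2000)

    # save side, as in train.py
    with open('training_args.json', 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    # load side, as in sample_seq2seq.py
    with open('training_args.json', 'rb') as f:
        training_args = json.load(f)
    args.__dict__.update(training_args)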
/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import numpy as np
7 | import torch as th
8 | import torch.distributed as dist
9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
10 | from torch.optim import AdamW
11 | import io
12 | import time
13 |
14 | from diffuseq.utils import dist_util, logger
15 | from diffuseq.utils.fp16_util import (
16 | make_master_params,
17 | master_params_to_model_params,
18 | model_grads_to_master_grads,
19 | unflatten_master_params,
20 | zero_grad,
21 | )
22 | from diffuseq.utils.nn import update_ema
23 | from diffuseq.step_sample import LossAwareSampler, UniformSampler
24 |
25 | # For ImageNet experiments, this was a good default value.
26 | # We found that the lg_loss_scale quickly climbed to
27 | # 20-21 within the first ~1K steps of training.
28 | INITIAL_LOG_LOSS_SCALE = 20.0
29 |
30 |
31 | class TrainLoop:
32 | def __init__(
33 | self,
34 | *,
35 | model,
36 | diffusion,
37 | data,
38 | batch_size,
39 | microbatch,
40 | lr,
41 | ema_rate,
42 | log_interval,
43 | save_interval,
44 | resume_checkpoint,
45 | use_fp16=False,
46 | fp16_scale_growth=1e-3,
47 | schedule_sampler=None,
48 | weight_decay=0.0,
49 | learning_steps=0,
50 | checkpoint_path='',
51 | gradient_clipping=-1.,
52 | eval_data=None,
53 | eval_interval=-1,
54 | simi_penalty = "",
55 | simi_lambda = 0.01,
56 | simi_step = 10,
57 | near_step = 0,
58 | near_lambda = 0,
59 | far_step = 0,
60 | far_lambda = 0,
61 | simi_noise=0.05,
62 | ):
63 | self.simi_noise = simi_noise
64 | self.model = model
65 | self.diffusion = diffusion
66 | self.data = data
67 | self.eval_data = eval_data
68 | self.batch_size = batch_size
69 | self.microbatch = microbatch if microbatch > 0 else batch_size
70 | self.lr = lr
71 | self.ema_rate = (
72 | [ema_rate]
73 | if isinstance(ema_rate, float)
74 | else [float(x) for x in ema_rate.split(",")]
75 | )
76 | self.log_interval = log_interval
77 | self.eval_interval = eval_interval
78 | self.save_interval = save_interval
79 | self.resume_checkpoint = resume_checkpoint
80 | self.use_fp16 = use_fp16
81 | self.fp16_scale_growth = fp16_scale_growth
82 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
83 | self.weight_decay = weight_decay
84 | self.learning_steps = learning_steps
85 | self.gradient_clipping = gradient_clipping
86 |
87 | self.step = 0
88 | self.resume_step = 0
89 | self.global_batch = self.batch_size * dist.get_world_size()
90 |
91 | self.model_params = list(self.model.parameters())
92 | self.master_params = self.model_params
93 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
94 | self.sync_cuda = th.cuda.is_available()
95 |
96 | self.checkpoint_path = checkpoint_path # DEBUG **
97 |
98 | self.simi_penalty = simi_penalty
99 | self.simi_lambda = simi_lambda
100 | self.simi_step = simi_step
101 |
102 | self.near_step = near_step
103 | self.far_step = far_step
104 | self.near_lambda = near_lambda
105 | self.far_lambda = far_lambda
106 |
107 | self._load_and_sync_parameters()
108 | if self.use_fp16:
109 | self._setup_fp16()
110 |
111 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
112 | if self.resume_step:
113 | # self._load_optimizer_state()
114 | frac_done = (self.step + self.resume_step) / self.learning_steps
115 |             if not resume_checkpoint: lr = self.lr * (1 - frac_done)  # anneal only when no explicit checkpoint was passed; otherwise `lr` stays the constructor argument
116 | self.opt = AdamW(self.master_params, lr=lr, weight_decay=self.weight_decay)
117 | # Model was resumed, either due to a restart or a checkpoint
118 | # being specified at the command line.
119 | self.ema_params = [
120 | self._load_ema_parameters(rate) for rate in self.ema_rate
121 | ]
122 | else:
123 | self.ema_params = [
124 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
125 | ]
126 |
127 | if th.cuda.is_available(): # DEBUG **
128 | self.use_ddp = True
129 | print(dist_util.dev())
130 | self.ddp_model = DDP(
131 | self.model,
132 | device_ids=[dist_util.dev()],
133 | output_device=dist_util.dev(),
134 | broadcast_buffers=False,
135 | bucket_cap_mb=128,
136 | find_unused_parameters=False,
137 | )
138 | else:
139 | if dist.get_world_size() > 1:
140 | logger.warn(
141 | "Distributed training requires CUDA. "
142 | "Gradients will not be synchronized properly!"
143 | )
144 | self.use_ddp = False
145 | self.ddp_model = self.model
146 |
147 | def _load_and_sync_parameters(self):
148 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
149 |
150 |         if resume_checkpoint.endswith('.pt'):
151 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
152 | if dist.get_rank() == 0:
153 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
154 | self.model.load_state_dict(
155 | dist_util.load_state_dict(
156 | actual_model_path(resume_checkpoint), map_location=dist_util.dev()
157 | )
158 | )
159 |
160 | dist_util.sync_params(self.model.parameters())
161 |
162 | def _load_ema_parameters(self, rate):
163 | ema_params = copy.deepcopy(self.master_params)
164 |
165 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
166 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
167 | if ema_checkpoint:
168 | if dist.get_rank() == 0:
169 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
170 | state_dict = dist_util.load_state_dict(
171 | actual_model_path(ema_checkpoint), map_location=dist_util.dev()
172 | )
173 | ema_params = self._state_dict_to_master_params(state_dict)
174 |
175 | dist_util.sync_params(ema_params)
176 | return ema_params
177 |
178 | def _load_optimizer_state(self):
179 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
180 | if bf.exists(main_checkpoint):
181 | logger.log(f"loading optimizer state from checkpoint: {main_checkpoint}")
182 | state_dict = dist_util.load_state_dict(
183 | actual_model_path(main_checkpoint), map_location=dist_util.dev()
184 | )
185 | self.opt.load_state_dict(state_dict)
186 |
187 | def _setup_fp16(self):
188 | self.master_params = make_master_params(self.model_params)
189 | self.model.convert_to_fp16()
190 |
191 | def run_loop(self):
192 | while (
193 | not self.learning_steps
194 | or self.step + self.resume_step < self.learning_steps
195 | ):
196 | batch, cond = next(self.data)
197 | self.run_step(batch, cond)
198 | # print(self.step, ":", time.time())
199 | if self.step % self.log_interval == 0:
200 | logger.dumpkvs()
201 | if self.eval_data is not None and self.step % self.eval_interval == 0:
202 | batch_eval, cond_eval = next(self.eval_data)
203 | self.forward_only(batch_eval, cond_eval)
204 | print('eval on validation set')
205 | logger.dumpkvs()
206 | if self.step % self.save_interval == 0:
207 | self.save()
208 | # Run for a finite amount of time in integration tests.
209 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
210 | return
211 | self.step += 1
212 | # print('done')
213 | # exit()
214 | # Save the last checkpoint if it wasn't already saved.
215 | if (self.step - 1) % self.save_interval != 0:
216 | self.save()
217 |
218 | def run_step(self, batch, cond):
219 | self.forward_backward(batch, cond, )
220 | if self.use_fp16:
221 | self.optimize_fp16()
222 | else:
223 | self.optimize_normal()
224 | self.log_step()
225 |
226 | def forward_only(self, batch, cond):
227 | with th.no_grad():
228 | zero_grad(self.model_params)
229 | for i in range(0, batch.shape[0], self.microbatch):
230 | micro = batch[i: i + self.microbatch].to(dist_util.dev())
231 | micro_cond = {
232 | k: v[i: i + self.microbatch].to(dist_util.dev())
233 | for k, v in cond.items()
234 | }
235 | last_batch = (i + self.microbatch) >= batch.shape[0]
236 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
237 | # print(micro_cond.keys())
238 |
239 | if self.near_step !=0 and self.far_step != 0: loss_type = "near and far"
240 | elif self.simi_penalty == "kl": loss_type = "kl"
241 | else: loss_type = ""
242 |
243 | if loss_type == "near and far":
244 | compute_losses = functools.partial(
245 | self.diffusion.training_losses,
246 | self.ddp_model,
247 | loss_type,
248 | micro,
249 | t,
250 | model_kwargs=micro_cond,
251 | simi_penalty = self.simi_penalty,
252 | near_step = self.near_step,
253 | near_lambda = self.near_lambda,
254 | far_step = self.far_step,
255 | far_lambda = self.far_lambda,
256 | )
257 | else:
258 | compute_losses = functools.partial(
259 | self.diffusion.training_losses,
260 | self.ddp_model,
261 | loss_type,
262 | micro,
263 | t,
264 | model_kwargs=micro_cond,
265 | simi_penalty = self.simi_penalty,
266 | simi_lambda = self.simi_lambda,
267 | simi_step = self.simi_step,
268 | simi_noise = self.simi_noise,
269 | )
270 |
271 |
272 | if last_batch or not self.use_ddp:
273 | losses = compute_losses()
274 | else:
275 | with self.ddp_model.no_sync():
276 | losses = compute_losses()
277 |
278 | log_loss_dict(
279 | self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()}
280 | )
281 |
282 |
283 | def forward_backward(self, batch, cond):
284 | zero_grad(self.model_params)
285 | for i in range(0, batch.shape[0], self.microbatch):
286 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
287 | micro_cond = {
288 | k: v[i : i + self.microbatch].to(dist_util.dev())
289 | for k, v in cond.items()
290 | }
291 | last_batch = (i + self.microbatch) >= batch.shape[0]
292 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
293 | # print(micro_cond.keys())
294 |
295 | if self.near_step !=0 and self.far_step != 0: loss_type = "near and far"
296 | elif self.simi_penalty == "kl": loss_type = "kl"
297 | else: loss_type = ""
298 |
299 | if loss_type == "near and far":
300 | compute_losses = functools.partial(
301 | self.diffusion.training_losses,
302 | self.ddp_model,
303 | loss_type,
304 | micro,
305 | t,
306 | model_kwargs=micro_cond,
307 | simi_penalty = self.simi_penalty,
308 | near_step = self.near_step,
309 | near_lambda = self.near_lambda,
310 | far_step = self.far_step,
311 | far_lambda = self.far_lambda,
312 | )
313 | else:
314 | compute_losses = functools.partial(
315 | self.diffusion.training_losses,
316 | self.ddp_model,
317 | loss_type,
318 | micro,
319 | t,
320 | model_kwargs=micro_cond,
321 | simi_penalty = self.simi_penalty,
322 | simi_lambda = self.simi_lambda,
323 | simi_step = self.simi_step,
324 | simi_noise = self.simi_noise,
325 | )
326 |
327 | if last_batch or not self.use_ddp:
328 | losses = compute_losses()
329 | else:
330 | with self.ddp_model.no_sync():
331 | losses = compute_losses()
332 |
333 | if isinstance(self.schedule_sampler, LossAwareSampler):
334 | self.schedule_sampler.update_with_local_losses(
335 | t, losses["loss"].detach()
336 | )
337 |
338 | loss = (losses["loss"] * weights).mean()
339 | log_loss_dict(
340 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
341 | )
342 | if self.use_fp16:
343 |             exit()  # the fp16 path is disabled here; the loss-scaling lines below are unreachable
344 | loss_scale = 2 ** self.lg_loss_scale
345 | (loss * loss_scale).backward()
346 | else:
347 | loss.backward()
348 |
349 | def optimize_fp16(self):
350 | if any(not th.isfinite(p.grad).all() for p in self.model_params):
351 | self.lg_loss_scale -= 1
352 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
353 | return
354 |
355 | model_grads_to_master_grads(self.model_params, self.master_params)
356 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
357 | self._log_grad_norm()
358 | self._anneal_lr()
359 | self.opt.step()
360 | for rate, params in zip(self.ema_rate, self.ema_params):
361 | update_ema(params, self.master_params, rate=rate)
362 | master_params_to_model_params(self.model_params, self.master_params)
363 | self.lg_loss_scale += self.fp16_scale_growth
364 |
365 | def grad_clip(self):
366 | # print('doing gradient clipping')
367 | max_grad_norm=self.gradient_clipping #3.0
368 | if hasattr(self.opt, "clip_grad_norm"):
369 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
370 | self.opt.clip_grad_norm(max_grad_norm)
371 | # else:
372 | # assert False
373 | # elif hasattr(self.model, "clip_grad_norm_"):
374 | # # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
375 | # self.model.clip_grad_norm_(args.max_grad_norm)
376 | else:
377 | # Revert to normal clipping otherwise, handling Apex or full precision
378 | th.nn.utils.clip_grad_norm_(
379 | self.model.parameters(), #amp.master_params(self.opt) if self.use_apex else
380 | max_grad_norm,
381 | )
382 |
383 | def optimize_normal(self):
384 | if self.gradient_clipping > 0:
385 | self.grad_clip()
386 | self._log_grad_norm()
387 | # self._anneal_lr()
388 | self.opt.step()
389 | for rate, params in zip(self.ema_rate, self.ema_params):
390 | update_ema(params, self.master_params, rate=rate)
391 |
392 | def _log_grad_norm(self):
393 | sqsum = 0.0
394 | # cnt = 0
395 | for p in self.master_params:
396 | # print(cnt, p) ## DEBUG
397 | # print(cnt, p.grad)
398 | # cnt += 1
399 |             if p.grad is not None:
400 | sqsum += (p.grad ** 2).sum().item()
401 | logger.logkv_mean("grad_norm", np.sqrt(sqsum))
402 |
403 | def _anneal_lr(self):
404 | if not self.learning_steps:
405 | return
406 | frac_done = (self.step + self.resume_step) / self.learning_steps
407 | lr = self.lr * (1 - frac_done)
408 | for param_group in self.opt.param_groups:
409 | param_group["lr"] = lr
410 |
411 | def log_step(self):
412 | logger.logkv("step", self.step + self.resume_step)
413 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
414 | if self.use_fp16:
415 | logger.logkv("lg_loss_scale", self.lg_loss_scale)
416 |
417 | def save(self):
418 | def save_checkpoint(rate, params):
419 | state_dict = self._master_params_to_state_dict(params)
420 | if dist.get_rank() == 0:
421 | logger.log(f"saving model {rate}...")
422 | if not rate:
423 | filename = f"model{(self.step+self.resume_step):06d}.pt"
424 | else:
425 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
426 | print('writing to', bf.join(get_blob_logdir(), filename))
427 | print('writing to', bf.join(self.checkpoint_path, filename))
428 | # with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
429 | # th.save(state_dict, f)
430 | with bf.BlobFile(bf.join(self.checkpoint_path, filename), "wb") as f: # DEBUG **
431 | th.save(state_dict, f) # save locally
432 | # pass # save empty
433 |
434 | # save_checkpoint(0, self.master_params)
435 | for rate, params in zip(self.ema_rate, self.ema_params):
436 | save_checkpoint(rate, params)
437 |
438 | dist.barrier()
439 |
440 | def _master_params_to_state_dict(self, master_params):
441 | if self.use_fp16:
442 | master_params = unflatten_master_params(
443 | list(self.model.parameters()), master_params # DEBUG **
444 | )
445 | state_dict = self.model.state_dict()
446 | for i, (name, _value) in enumerate(self.model.named_parameters()):
447 | assert name in state_dict
448 | state_dict[name] = master_params[i]
449 | return state_dict
450 |
451 | def _state_dict_to_master_params(self, state_dict):
452 | params = [state_dict[name] for name, _ in self.model.named_parameters()]
453 | if self.use_fp16:
454 | return make_master_params(params)
455 | else:
456 | return params
457 |
458 |
459 | def parse_resume_step_from_filename(filename):
460 | """
461 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
462 | checkpoint's number of steps.
463 | """
464 |     if filename.endswith('.pt'):
465 | return int(filename[-9:-3])
466 | else:
467 | return 0
468 |
469 |
470 | def get_blob_logdir():
471 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())
472 |
473 |
474 | def find_resume_checkpoint():
475 | # On your infrastructure, you may want to override this to automatically
476 | # discover the latest checkpoint on your blob storage, etc.
477 | return None
478 |
479 |
480 | def find_ema_checkpoint(main_checkpoint, step, rate):
481 | if main_checkpoint is None:
482 | return None
483 | filename = f"ema_{rate}_{(step):06d}.pt"
484 | path = bf.join(bf.dirname(main_checkpoint), filename)
485 | if bf.exists(path):
486 | return path
487 | return None
488 |
489 |
490 | def log_loss_dict(diffusion, ts, losses):
491 | for key, values in losses.items():
492 | logger.logkv_mean(key, values.mean().item())
493 | # Log the quantiles (four quartiles, in particular).
494 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
495 | quartile = int(4 * sub_t / diffusion.num_timesteps)
496 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
497 |
498 |
499 | def actual_model_path(model_path):
500 | return model_path
501 |
--------------------------------------------------------------------------------
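Note: the checkpoints written by TrainLoop.save are the EMA weights (ema_{rate}_{step}.pt), maintained by update_ema from diffuseq.utils.nn. A minimal sketch of the update rule, assuming it follows the standard exponential moving average targ = rate * targ + (1 - rate) * src:

    import torch as th

    def update_ema_sketch(ema_params, model_params, rate=0.9999):
        # ema <- rate * ema + (1 - rate) * current weights, updated in place
        with th.no_grad():
            for targ, src in zip(ema_params, model_params):
                targ.mul_(rate).add_(src, alpha=1 - rate)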